In [1]:
from torch.utils.data import DataLoader
from datasets import load_dataset, Audio
from transformers import AutoFeatureExtractor, ASTForAudioClassification, AutoModelForAudioClassification
from torch.optim import AdamW

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
device = 'cuda:6'
sampling_rate = 16000 # 这个数据集本来就是16000
ds = load_dataset("MLCommons/ml_spoken_words", languages=["zh-CN"])
ds = ds.remove_columns(["file", 'is_valid',"language", "speaker_id", 'gender']) # 已检查，全部valid
train_set = ds['train']
val_set = ds['validation']
test_set = ds['test']
labels = []
# 检查数据集合法性
worddict = {}
for item in train_set:
    worddict[item['keyword']] = 1
for item in val_set:
    assert worddict[item['keyword']] == 1
for item in test_set:
    assert worddict[item['keyword']] == 1
for item in worddict:
    labels.append(item)
    
label2id, id2label = dict(), dict()
for i, label in enumerate(labels):
    label2id[label] = i
    id2label[i] = label
print(label2id)
print(id2label)
num_labels = len(id2label)

Found cached dataset ml_spoken_words (/remote-home/pjli/.cache/huggingface/datasets/MLCommons___ml_spoken_words/zh-CN_wav-6e7818d76df98faf/1.0.0/321ea853cf0a05abb7a2d7efea900692a3d8622af65a2f3ce98adb7800a5d57b)
100%|██████████| 3/3 [00:00<00:00, 516.58it/s]


{'一些': 0, '三十四': 1, '三尖杉': 2, '三峡': 3, '三年': 4, '三维': 5, '上述': 6, '下表': 7, '不丹': 8, '不明': 9, '世代': 10, '世宗': 11, '东京': 12, '东侧': 13, '东南': 14, '东正教': 15, '东部': 16, '严谨': 17, '中举': 18, '为期': 19, '为首': 20, '主任委员': 21, '主体': 22, '主办': 23, '主线': 24, '主题曲': 25, '之中': 26, '乌头属': 27, '乔治亚州': 28, '乘客': 29, '九龙': 30, '争议': 31, '事务所': 32, '二十一': 33, '二氧化碳': 34, '于是': 35, '交会': 36, '交界': 37, '产于': 38, '人类': 39, '人类学': 40, '介面': 41, '任意': 42, '任职': 43, '众多': 44, '优良': 45, '传奇': 46, '传记': 47, '伴随': 48, '例如': 49, '供应': 50, '俄国': 51, '保加利亚': 52, '信用': 53, '俱乐部': 54, '偏好': 55, '停止': 56, '元素': 57, '光滑': 58, '兖州': 59, '入口': 60, '入境': 61, '入球': 62, '全市': 63, '八旗': 64, '公元': 65, '公司': 66, '公学': 67, '公家': 68, '兰德': 69, '关押': 70, '其实': 71, '具有': 72, '内尔': 73, '内战': 74, '内蒙古': 75, '军政': 76, '农业部': 77, '农场': 78, '凯尔': 79, '出使': 80, '出没': 81, '击败': 82, '分别': 83, '分子': 84, '分开': 85, '刑事': 86, '列表': 87, '初中': 88, '删除': 89, '剧场': 90, '副作用': 91, '努力': 92, '北海道': 93, '医师': 94, '十五': 95, '十八年': 96, '十字': 97, '十字花科':

In [3]:
feature_extractor = AutoFeatureExtractor.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593")
model = AutoModelForAudioClassification.from_pretrained(
    "MIT/ast-finetuned-audioset-10-10-0.4593", num_labels=num_labels, label2id=label2id, id2label=id2label,
    ignore_mismatched_sizes=True
)
model.to(device)

Some weights of ASTForAudioClassification were not initialized from the model checkpoint at MIT/ast-finetuned-audioset-10-10-0.4593 and are newly initialized because the shapes did not match:
- classifier.dense.weight: found shape torch.Size([527, 768]) in the checkpoint and torch.Size([498, 768]) in the model instantiated
- classifier.dense.bias: found shape torch.Size([527]) in the checkpoint and torch.Size([498]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


ASTForAudioClassification(
  (audio_spectrogram_transformer): ASTModel(
    (embeddings): ASTEmbeddings(
      (patch_embeddings): ASTPatchEmbeddings(
        (projection): Conv2d(1, 768, kernel_size=(16, 16), stride=(10, 10))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ASTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ASTLayer(
          (attention): ASTAttention(
            (attention): ASTSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ASTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ASTIntermediate(
            (de

In [4]:
import torch
def mlc_collate_fn(data):
    batch = {}
    # batch['keyword'] = [d['keyword'] for d in data]
    batch['labels'] = torch.tensor([label2id[d['keyword']] for d in data]).to(device)
    batch['input_values'] = torch.cat(
            [feature_extractor(d['audio']['array'], sampling_rate=sampling_rate, return_tensors="pt").input_values for d in data]).to(device) # a tensor of arrays return batch
    return batch

In [12]:
train_loader = DataLoader(dataset=train_set, batch_size=12, shuffle=True, collate_fn = mlc_collate_fn)
optimizer = AdamW(model.parameters(),lr = 1e-5)
epochs = 10

In [None]:
from tqdm import tqdm
for epoch in range(epochs):
    cnt = 0
    pos = 0
    for batch_idx, batch in tqdm(enumerate(train_loader)):
        output = model(**batch)
        pred = output.logits.argmax(-1)
        labels = batch['labels']
        pos += (torch.sum(pred == labels)).item()
        cnt += pred.shape[0]
        loss = output.loss
        if batch_idx % 100 == 0: 
            with torch.no_grad():
                acc = torch.sum(pred == labels)/pred.shape[0]
                print('loss:',loss.item(),'acc:',acc.item())
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    print('epoch:',epoch,'acc:',pos/cnt)

In [14]:
# 测试
test_loader = DataLoader(dataset=test_set, batch_size=8, shuffle=True, collate_fn = mlc_collate_fn)
cnt = 0
pos = 0
model.eval()
with torch.no_grad():
    for batch_idx, batch in tqdm(enumerate(test_loader)):
        output = model(**batch)
        pred = output.logits.argmax(-1)
        labels = batch['labels']
        pos += (torch.sum(pred == labels)).item()
        cnt += pred.shape[0]
        if batch_idx % 100 == 0:
            print('pos:',pos,'cnt:',cnt)
        # loss = output.loss
        # if batch_idx % 100 == 0: 
        #     with torch.no_grad():
        #         acc = torch.sum(pred == labels)/pred.shape[0]
        #         print('loss:',loss.item(),'acc:',acc.item())
        # loss.backward()
        # optimizer.step()
        # optimizer.zero_grad()
print('acc on test set:',pos/cnt)

2it [00:00,  5.22it/s]

pos: 5 cnt: 8


102it [00:19,  5.28it/s]

pos: 437 cnt: 808


154it [00:29,  5.28it/s]

acc on test set: 0.5219155844155844





In [9]:
import os
pid = os.getpid()
print(pid)
!kill -9 $pid

21770


: 

: 