In [1]:
from tensorflow import keras
import tensorflow as tf
import os
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import seaborn as sns
from sklearn.metrics import confusion_matrix
from sklearn.metrics import accuracy_score, f1_score

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, LSTM, Dropout, GRU, Bidirectional, Flatten
from tensorflow.keras.optimizers import SGD
from tensorflow.random import set_seed

set_seed(2024)
np.random.seed(2024)


import csv
import librosa
import librosa.display
import matplotlib.pyplot as plt
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"

train_dir = os.path.abspath('../data/train/train')
test_dir = os.path.abspath('../data/test/test')

classes = os.listdir(train_dir + '/audio')
# classes.remove("_background_noise_")

X_train = np.load(train_dir + "/X_train.npy")
y_train = np.load(train_dir + "/y_train.npy")

X_val = np.load(train_dir + "/X_val.npy")
y_val = np.load(train_dir + "/y_val.npy")

X_train = X_train.reshape((-1, X_train.shape[1], X_train.shape[2]))
X_val = X_val.reshape((-1, X_val.shape[1], X_val.shape[2]))

X_test = np.load(test_dir + '/X_test.npy')
X_files = np.loadtxt(test_dir + '/X_files.txt', delimiter=" ", dtype='str')

def plot_loss(history_df, name, idx):
    plt.figure()
    plt.plot(history_df['loss'])
    plt.plot(history_df['val_loss'])
    plt.title(f'{name}: loss')
    plt.ylabel('loss')
    plt.xlabel('epoch')
    plt.legend(['train', 'test'], loc='upper right')
    path = f'train_history/{name}/{idx}_loss.png'
    plt.savefig(path)
    print(f'Loss plot is saved to: {path}')
    plt.close()


# LSTM

In [4]:
input_shape = (122, 85)
epochs = 100
batch_size = 32


model_lstm = Sequential()
model_lstm.add(LSTM(units=125, activation="tanh", input_shape=input_shape))
model_lstm.add(Dense(units=len(classes)))
# Compiling the model
model_lstm.compile(optimizer="RMSprop", loss="mse")

print(model_lstm.summary())

# Model training
history = model_lstm.fit(X_train, y_train, epochs=epochs, batch_size=batch_size, validation_data=(X_val, y_val))

history_df = pd.DataFrame(history.history) 
hist_csv_file = 'train_history/' + 'LSTM' + '/' + str(0) + '_history.csv'
with open(hist_csv_file, mode='w') as f:
    history_df.to_csv(f)

plot_loss(history_df, 'LSTM', str(0))

# Check validation accuracy and f1
y_pred = np.argmax(model_lstm.predict(X_val), axis=-1)
accuracy = accuracy_score(y_val, y_pred)

f1 = f1_score(y_val, y_pred, average='macro')
metrics_df = pd.DataFrame({'accuracy': [accuracy], 'f1': [f1]}) 
metrics_csv_file = 'train_history/' + 'LSTM' + '/' + str(0) + '_metrics.csv'
with open(metrics_csv_file, mode='w') as f:
    metrics_df.to_csv(f)
print(metrics_df)

# Confusion Matrix
cm = confusion_matrix(y_val, y_pred)
accuracy = accuracy_score(y_val, y_pred)
print(f'Accuracy: {accuracy * 100:.2f}%')

plt.figure(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=range(0, 30), yticklabels=range(0, 30))
plt.xlabel('Predicted')
plt.ylabel('True')
plt.title(f'LSTM Confusion Matrix. Accuracy: {accuracy * 100:.2f}%')
path = f'train_history/LSTM/0_confusion_matrix.png'
plt.savefig(path)
print(f'Confusion matrix is saved to: {path}')
plt.close()

# Prepare submission
submission_df = pd.DataFrame({'fname': [], 'label': []}) 
for ind, test_file in enumerate(X_files):
    y_pred = np.argmax(model_lstm.predict(np.array([X_test[ind]])), axis=-1)
    submission_df.loc[len(submission_df.index)] = [X_files[ind], classes[y_pred[0]]]
submission_csv_file = 'train_history/' + 'LSTM' + '/' + str(0) + '_submission.csv'
with open(submission_csv_file, mode='w') as f:
    submission_df.to_csv(f, index=False, lineterminator='\n')
print('Submission saved.')

Model: "sequential_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 lstm_1 (LSTM)               (None, 125)               105500    
                                                                 
 dense_1 (Dense)             (None, 30)                3780      
                                                                 
Total params: 109,280
Trainable params: 109,280
Non-trainable params: 0
_________________________________________________________________
None
Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/100
Epoch 26/100
Epoch 27/100
Epoch 28/100
Epoch 29/100
Epoch 30/100
Epoch 31/100
Epoch 32/100
Epoch 33/100
Epoch 34/100
Epoch 

MemoryError: Unable to allocate 1.63 MiB for an array with shape (2, 106512) and data type object

# GRU

In [2]:
input_shape = (122, 85)
epochs = 50
batch_size = 32

model_gru = Sequential()
model_gru.add(Dropout(0.2, input_shape=input_shape))
model_gru.add(GRU(units=125, return_sequences=True))
model_gru.add(Dropout(0.2))
model_gru.add(GRU(units=125, return_sequences=True))
model_gru.add(Dropout(0.2))
model_gru.add(GRU(units=125, return_sequences=True))
model_gru.add(Dropout(0.2))
model_gru.add(Flatten()),
model_gru.add(Dense(units=125, activation="sigmoid"))
model_gru.add(Dropout(0.2))
model_gru.add(Dense(units=len(classes), activation="softmax"))
# Compiling the model
model_gru.compile(optimizer="adam", loss="mse")

print(model_gru.summary())

# Model training
history = model_gru.fit(X_train, y_train, epochs=epochs, batch_size=batch_size, validation_data=(X_val, y_val))

history_df = pd.DataFrame(history.history) 
hist_csv_file = 'train_history/' + 'GRU' + '/' + str(1) + '_history.csv'
with open(hist_csv_file, mode='w') as f:
    history_df.to_csv(f)

plot_loss(history_df, 'GRU', str(1))

# Check validation accuracy and f1
y_pred = np.argmax(model_gru.predict(X_val), axis=-1)
accuracy = accuracy_score(y_val, y_pred)

f1 = f1_score(y_val, y_pred, average='macro')
metrics_df = pd.DataFrame({'accuracy': [accuracy], 'f1': [f1]}) 
metrics_csv_file = 'train_history/' + 'GRU' + '/' + str(1) + '_metrics.csv'
with open(metrics_csv_file, mode='w') as f:
    metrics_df.to_csv(f)
print(metrics_df)

# Confusion Matrix
cm = confusion_matrix(y_val, y_pred)
print(f'Accuracy: {accuracy * 100:.2f}%')

plt.figure(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=range(0, 30), yticklabels=range(0, 30))
plt.xlabel('Predicted')
plt.ylabel('True')
plt.title(f'GRU Confusion Matrix. Accuracy: {accuracy * 100:.2f}%')
path = f'train_history/GRU/1_confusion_matrix.png'
plt.savefig(path)
print(f'Confusion matrix is saved to: {path}')
plt.close()

# Prepare submission
# submission_df = pd.DataFrame({'fname': [], 'label': []})
# # y_pred = np.argmax(model_gru.predict(X_test), axis=-1)
# for ind, test_file in enumerate(X_files):
#     if ind%2000 == 0:
#         print("{} done!".format(ind))
#     y_pred = np.argmax(model_gru.predict(np.array([X_test[ind]])), axis=-1)
#     submission_df.loc[len(submission_df.index)] = [X_files[ind], classes[y_pred[0]]]
# submission_csv_file = 'train_history/' + 'GRU' + '/' + str(0) + '_submission.csv'
# with open(submission_csv_file, mode='w') as f:
#     submission_df.to_csv(f, index=False, lineterminator='\n')
# print('Submission saved.')

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 dropout (Dropout)           (None, 122, 85)           0         
                                                                 
 gru (GRU)                   (None, 122, 125)          79500     
                                                                 
 dropout_1 (Dropout)         (None, 122, 125)          0         
                                                                 
 gru_1 (GRU)                 (None, 122, 125)          94500     
                                                                 
 dropout_2 (Dropout)         (None, 122, 125)          0         
                                                                 
 gru_2 (GRU)                 (None, 122, 125)          94500     
                                                                 
 dropout_3 (Dropout)         (None, 122, 125)          0

# Transformer

# Prepare silence

In [None]:
from pydub import AudioSegment

def split_and_save_silence(original, name):
    for i in range(int(len(original) / 1000) - 1):
        t1 = i * 1000
        t2 = (i+1) * 1000
        newAudio = original[t1:t2]
        newAudio.export(f'../data/train/train/audio/silence/{name}_{i}.wav', format="wav")

split_and_save_silence(AudioSegment.from_wav(os.path.abspath('../data/train/train/_background_noise_/doing_the_dishes.wav')), 'doing_the_dishes')
split_and_save_silence(AudioSegment.from_wav(os.path.abspath('../data/train/train/_background_noise_/dude_miaowing.wav')), 'dude_miaowing')
split_and_save_silence(AudioSegment.from_wav(os.path.abspath('../data/train/train/_background_noise_/white_noise.wav')), 'white_noise')
split_and_save_silence(AudioSegment.from_wav(os.path.abspath('../data/train/train/_background_noise_/exercise_bike.wav')), 'exercise_bike')
split_and_save_silence(AudioSegment.from_wav(os.path.abspath('../data/train/train/_background_noise_/pink_noise.wav')), 'pink_noise')
split_and_save_silence(AudioSegment.from_wav(os.path.abspath('../data/train/train/_background_noise_/running_tap.wav')), 'running_tap')

In [2]:
from transformers import AutoFeatureExtractor, AutoModelForAudioClassification, TrainingArguments, Trainer, pipeline
import evaluate
from datasets import load_dataset, Audio, ClassLabel, concatenate_datasets
import os

train_dir = os.path.abspath('../data/train/train')
dataset = load_dataset("audiofolder", data_dir="../data/train/train/audio")
labels = dataset["train"].features["label"].names
# print(labels)
label2id, id2label = dict(), dict()
for i, label in enumerate(labels):
    label2id[label] = str(i)
    id2label[str(i)] = label

known_classes = ['yes', 'no', 'up', 'down', 'left', 'right', 'on', 'off', 'stop', 'go', 'silence', 'unknown']

def rename_unknown_labels(sample):
    if id2label[str(sample["label"])] not in known_classes:
        sample["new_label"] = known_classes.index('unknown')
    else:
        sample["new_label"] = known_classes.index(id2label[str(sample["label"])])
    return sample

dataset['train'] = dataset['train'].map(rename_unknown_labels, remove_columns='label')
dataset['train'] = dataset['train'].cast_column("new_label", ClassLabel(num_classes=len(known_classes), names=known_classes))
dataset['train'] = dataset['train'].rename_column("new_label", "label")

# Reduce number of unknown
dataset_unknown = dataset['train'].filter(lambda example: example["label"] == known_classes.index('unknown'))
dataset['train'] = dataset['train'].filter(lambda example: example["label"] != known_classes.index('unknown'))
dataset_unknown_split = dataset_unknown.train_test_split(test_size=0.05)
dataset['train'] = concatenate_datasets([dataset['train'], dataset_unknown_split['test']])
dataset = dataset['train'].train_test_split(test_size=0.2)

# Identify new label mappings
labels = dataset["train"].features["label"].names
label2id, id2label = dict(), dict()
for i, label in enumerate(labels):
    label2id[label] = str(i)
    id2label[str(i)] = label
id2label


  from .autonotebook import tqdm as notebook_tqdm


{'0': 'yes',
 '1': 'no',
 '2': 'up',
 '3': 'down',
 '4': 'left',
 '5': 'right',
 '6': 'on',
 '7': 'off',
 '8': 'stop',
 '9': 'go',
 '10': 'silence',
 '11': 'unknown'}

## facebook/wav2vec2-base

In [59]:
feature_extractor = AutoFeatureExtractor.from_pretrained(
    "facebook/wav2vec2-base")

def preprocess_function(examples):
    audio_arrays = [x["array"] for x in examples["audio"]]
    inputs = feature_extractor(
      audio_arrays,
      sampling_rate=feature_extractor.sampling_rate,
      max_length=16000, truncation=True
    )
    return inputs
 
 
encoded_dataset = dataset.map(
    preprocess_function, remove_columns="audio", batched=True)
# encoded_dataset = encoded_dataset.rename_column("intent_class", "label")

encoded_dataset['train']


Map: 100%|██████████| 20900/20900 [00:11<00:00, 1757.69 examples/s]
Map: 100%|██████████| 5226/5226 [00:03<00:00, 1574.70 examples/s]


Dataset({
    features: ['label', 'input_values'],
    num_rows: 20900
})

In [60]:
accuracy = evaluate.load("accuracy")
def compute_metrics(eval_pred):
    predictions = np.argmax(eval_pred.predictions, axis=1)
    return accuracy.compute(predictions=predictions, references=eval_pred.label_ids)


In [61]:
import numpy as np

num_labels = len(id2label)
model = AutoModelForAudioClassification.from_pretrained(
    "facebook/wav2vec2-base", num_labels=num_labels, 
  label2id=label2id, id2label=id2label
)
 
training_args = TrainingArguments(
    output_dir="GFG_commands_reduced_unknown",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    # increase this if you have better machine resources
    learning_rate=4e-4,
    # increase this if you have better machine resources
    per_device_train_batch_size=32,
    gradient_accumulation_steps=4,
    # increase this if you have better machine resources
    per_device_eval_batch_size=32,
    # increase this if you have better machine resources
    num_train_epochs=8,
    warmup_ratio=0.1,
    # increase this if you have better machine resources
    logging_steps=10,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    fp16=True,                             # set to False if you don't have GPU
)
 
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=encoded_dataset["train"].with_format("torch"),
    eval_dataset=encoded_dataset["test"].with_format("torch"),
    tokenizer=feature_extractor,
    compute_metrics=compute_metrics,
)
 
trainer.train()


Some weights of Wav2Vec2ForSequenceClassification were not initialized from the model checkpoint at facebook/wav2vec2-base and are newly initialized: ['classifier.bias', 'classifier.weight', 'projector.bias', 'projector.weight', 'wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original1']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  1%|          | 10/1304 [00:06<12:47,  1.69it/s]

{'loss': 2.4824, 'grad_norm': 0.6286147236824036, 'learning_rate': 2.4427480916030535e-05, 'epoch': 0.06}


  2%|▏         | 20/1304 [00:12<12:49,  1.67it/s]

{'loss': 2.4414, 'grad_norm': 1.3282233476638794, 'learning_rate': 5.1908396946564884e-05, 'epoch': 0.12}


  2%|▏         | 30/1304 [00:18<12:43,  1.67it/s]

{'loss': 2.2983, 'grad_norm': nan, 'learning_rate': 7.938931297709924e-05, 'epoch': 0.18}


  3%|▎         | 40/1304 [00:24<12:37,  1.67it/s]

{'loss': 1.9507, 'grad_norm': 8.602686882019043, 'learning_rate': 0.00010687022900763359, 'epoch': 0.24}


  4%|▍         | 50/1304 [00:30<12:38,  1.65it/s]

{'loss': 1.5431, 'grad_norm': 9.872970581054688, 'learning_rate': 0.00013740458015267177, 'epoch': 0.31}


  5%|▍         | 60/1304 [00:36<12:40,  1.64it/s]

{'loss': 1.2198, 'grad_norm': 7.798912048339844, 'learning_rate': 0.00016793893129770992, 'epoch': 0.37}


  5%|▌         | 70/1304 [00:42<12:18,  1.67it/s]

{'loss': 1.0167, 'grad_norm': 9.402361869812012, 'learning_rate': 0.0001984732824427481, 'epoch': 0.43}


  6%|▌         | 80/1304 [00:48<11:39,  1.75it/s]

{'loss': 0.8613, 'grad_norm': 9.996301651000977, 'learning_rate': 0.00022900763358778625, 'epoch': 0.49}


  7%|▋         | 90/1304 [00:53<11:26,  1.77it/s]

{'loss': 0.8676, 'grad_norm': 15.601285934448242, 'learning_rate': 0.00025954198473282443, 'epoch': 0.55}


  8%|▊         | 100/1304 [00:59<11:18,  1.77it/s]

{'loss': 0.7732, 'grad_norm': 27.552324295043945, 'learning_rate': 0.0002900763358778626, 'epoch': 0.61}


  8%|▊         | 110/1304 [01:05<11:13,  1.77it/s]

{'loss': 0.9186, 'grad_norm': 17.670345306396484, 'learning_rate': 0.00031755725190839697, 'epoch': 0.67}


  9%|▉         | 120/1304 [01:10<11:12,  1.76it/s]

{'loss': 0.8096, 'grad_norm': 13.415016174316406, 'learning_rate': 0.0003480916030534351, 'epoch': 0.73}


 10%|▉         | 130/1304 [01:16<11:06,  1.76it/s]

{'loss': 0.9841, 'grad_norm': 14.81900405883789, 'learning_rate': 0.00037862595419847333, 'epoch': 0.8}


 11%|█         | 140/1304 [01:22<11:01,  1.76it/s]

{'loss': 0.8302, 'grad_norm': 21.113359451293945, 'learning_rate': 0.0003989769820971867, 'epoch': 0.86}


 12%|█▏        | 150/1304 [01:27<10:55,  1.76it/s]

{'loss': 0.9165, 'grad_norm': 12.502235412597656, 'learning_rate': 0.00039556692242114237, 'epoch': 0.92}


 12%|█▏        | 160/1304 [01:33<10:48,  1.76it/s]

{'loss': 0.9394, 'grad_norm': 9.118658065795898, 'learning_rate': 0.000392156862745098, 'epoch': 0.98}


 12%|█▎        | 163/1304 [01:35<10:49,  1.76it/s]
 12%|█▎        | 163/1304 [01:42<10:49,  1.76it/s]

{'eval_loss': 0.6720324158668518, 'eval_accuracy': 0.8538078836586299, 'eval_runtime': 6.7428, 'eval_samples_per_second': 775.049, 'eval_steps_per_second': 24.322, 'epoch': 1.0}


 13%|█▎        | 170/1304 [01:46<15:50,  1.19it/s]

{'loss': 0.9971, 'grad_norm': 12.292678833007812, 'learning_rate': 0.0003890878090366582, 'epoch': 1.04}


 14%|█▍        | 180/1304 [01:52<10:45,  1.74it/s]

{'loss': 0.9689, 'grad_norm': 50.75411605834961, 'learning_rate': 0.00038567774936061383, 'epoch': 1.1}


 15%|█▍        | 190/1304 [01:58<10:33,  1.76it/s]

{'loss': 0.9321, 'grad_norm': 11.759584426879883, 'learning_rate': 0.00038226768968456954, 'epoch': 1.16}


 15%|█▌        | 200/1304 [02:03<10:29,  1.75it/s]

{'loss': 1.0133, 'grad_norm': 10.136816024780273, 'learning_rate': 0.0003788576300085252, 'epoch': 1.22}


 16%|█▌        | 210/1304 [02:09<10:21,  1.76it/s]

{'loss': 0.7582, 'grad_norm': 5.858729362487793, 'learning_rate': 0.00037544757033248084, 'epoch': 1.28}


 17%|█▋        | 220/1304 [02:15<10:21,  1.74it/s]

{'loss': 0.7689, 'grad_norm': 12.35438346862793, 'learning_rate': 0.00037203751065643655, 'epoch': 1.35}


 18%|█▊        | 230/1304 [02:21<10:18,  1.74it/s]

{'loss': 0.8938, 'grad_norm': 5.80318021774292, 'learning_rate': 0.00036862745098039214, 'epoch': 1.41}


 18%|█▊        | 240/1304 [02:26<10:14,  1.73it/s]

{'loss': 0.8308, 'grad_norm': 4.2485456466674805, 'learning_rate': 0.00036521739130434785, 'epoch': 1.47}


 19%|█▉        | 250/1304 [02:32<10:11,  1.72it/s]

{'loss': 0.8325, 'grad_norm': 14.981266975402832, 'learning_rate': 0.0003618073316283035, 'epoch': 1.53}


 20%|█▉        | 260/1304 [02:38<10:07,  1.72it/s]

{'loss': 0.8647, 'grad_norm': 7.030647277832031, 'learning_rate': 0.00035839727195225915, 'epoch': 1.59}


 21%|██        | 270/1304 [02:44<10:09,  1.70it/s]

{'loss': 0.7134, 'grad_norm': 13.524137496948242, 'learning_rate': 0.00035498721227621486, 'epoch': 1.65}


 21%|██▏       | 280/1304 [02:50<09:57,  1.71it/s]

{'loss': 0.7986, 'grad_norm': 13.521166801452637, 'learning_rate': 0.0003515771526001705, 'epoch': 1.71}


 22%|██▏       | 290/1304 [02:56<09:51,  1.72it/s]

{'loss': 0.8043, 'grad_norm': 6.396311283111572, 'learning_rate': 0.00034816709292412616, 'epoch': 1.77}


 23%|██▎       | 300/1304 [03:02<10:01,  1.67it/s]

{'loss': 0.7208, 'grad_norm': 4.8131537437438965, 'learning_rate': 0.00034475703324808187, 'epoch': 1.83}


 24%|██▍       | 310/1304 [03:08<09:52,  1.68it/s]

{'loss': 0.8435, 'grad_norm': 7.527023792266846, 'learning_rate': 0.0003413469735720375, 'epoch': 1.9}


 25%|██▍       | 320/1304 [03:14<09:39,  1.70it/s]

{'loss': 0.7271, 'grad_norm': 7.094088554382324, 'learning_rate': 0.00033793691389599317, 'epoch': 1.96}


 25%|██▌       | 327/1304 [03:18<09:09,  1.78it/s]
 25%|██▌       | 327/1304 [03:25<09:09,  1.78it/s]

{'eval_loss': 0.4453696012496948, 'eval_accuracy': 0.8991580558744738, 'eval_runtime': 6.9221, 'eval_samples_per_second': 754.969, 'eval_steps_per_second': 23.692, 'epoch': 2.0}


 25%|██▌       | 330/1304 [03:27<28:15,  1.74s/it]

{'loss': 0.742, 'grad_norm': 7.009048938751221, 'learning_rate': 0.0003345268542199489, 'epoch': 2.02}


 26%|██▌       | 340/1304 [03:33<10:06,  1.59it/s]

{'loss': 0.7589, 'grad_norm': 9.418395042419434, 'learning_rate': 0.0003311167945439045, 'epoch': 2.08}


 27%|██▋       | 350/1304 [03:39<09:29,  1.68it/s]

{'loss': 0.655, 'grad_norm': 8.03874683380127, 'learning_rate': 0.0003277067348678602, 'epoch': 2.14}


 28%|██▊       | 360/1304 [03:45<09:24,  1.67it/s]

{'loss': 0.5918, 'grad_norm': 6.436952590942383, 'learning_rate': 0.0003242966751918159, 'epoch': 2.2}


 28%|██▊       | 370/1304 [03:51<09:17,  1.68it/s]

{'loss': 0.599, 'grad_norm': 6.9878668785095215, 'learning_rate': 0.00032088661551577153, 'epoch': 2.26}


 29%|██▉       | 380/1304 [03:57<09:06,  1.69it/s]

{'loss': 0.643, 'grad_norm': 6.2331390380859375, 'learning_rate': 0.0003174765558397272, 'epoch': 2.32}


 30%|██▉       | 390/1304 [04:03<09:02,  1.68it/s]

{'loss': 0.5734, 'grad_norm': 9.526603698730469, 'learning_rate': 0.0003140664961636829, 'epoch': 2.39}


 31%|███       | 400/1304 [04:09<09:04,  1.66it/s]

{'loss': 0.6094, 'grad_norm': 5.081620693206787, 'learning_rate': 0.00031065643648763854, 'epoch': 2.45}


 31%|███▏      | 410/1304 [04:15<08:55,  1.67it/s]

{'loss': 0.619, 'grad_norm': 12.497641563415527, 'learning_rate': 0.0003072463768115942, 'epoch': 2.51}


 32%|███▏      | 420/1304 [04:21<08:51,  1.66it/s]

{'loss': 0.6109, 'grad_norm': 8.949642181396484, 'learning_rate': 0.0003038363171355499, 'epoch': 2.57}


 33%|███▎      | 430/1304 [04:27<08:46,  1.66it/s]

{'loss': 0.5977, 'grad_norm': 4.375180721282959, 'learning_rate': 0.00030042625745950555, 'epoch': 2.63}


 34%|███▎      | 440/1304 [04:33<08:41,  1.66it/s]

{'loss': 0.6985, 'grad_norm': 4.527022838592529, 'learning_rate': 0.0002970161977834612, 'epoch': 2.69}


 35%|███▍      | 450/1304 [04:39<08:38,  1.65it/s]

{'loss': 0.5842, 'grad_norm': 3.775916814804077, 'learning_rate': 0.0002936061381074169, 'epoch': 2.75}


 35%|███▌      | 460/1304 [04:45<08:30,  1.65it/s]

{'loss': 0.5464, 'grad_norm': 4.207491397857666, 'learning_rate': 0.00029019607843137256, 'epoch': 2.81}


 36%|███▌      | 470/1304 [04:51<08:22,  1.66it/s]

{'loss': 0.5952, 'grad_norm': 8.31806468963623, 'learning_rate': 0.0002867860187553282, 'epoch': 2.87}


 37%|███▋      | 480/1304 [04:57<08:19,  1.65it/s]

{'loss': 0.5492, 'grad_norm': 4.159152507781982, 'learning_rate': 0.0002833759590792839, 'epoch': 2.94}


 38%|███▊      | 490/1304 [05:03<08:12,  1.65it/s]

{'loss': 0.5265, 'grad_norm': 6.54163932800293, 'learning_rate': 0.00027996589940323957, 'epoch': 3.0}



 38%|███▊      | 490/1304 [05:11<08:12,  1.65it/s]

{'eval_loss': 0.2622239887714386, 'eval_accuracy': 0.9246077305778798, 'eval_runtime': 7.0288, 'eval_samples_per_second': 743.513, 'eval_steps_per_second': 23.333, 'epoch': 3.0}


 38%|███▊      | 500/1304 [05:18<09:31,  1.41it/s]

{'loss': 0.4559, 'grad_norm': 7.447869300842285, 'learning_rate': 0.0002765558397271952, 'epoch': 3.06}


 39%|███▉      | 510/1304 [05:24<08:06,  1.63it/s]

{'loss': 0.5409, 'grad_norm': 5.04181432723999, 'learning_rate': 0.0002731457800511509, 'epoch': 3.12}


 40%|███▉      | 520/1304 [05:30<08:31,  1.53it/s]

{'loss': 0.4563, 'grad_norm': 8.103859901428223, 'learning_rate': 0.0002697357203751066, 'epoch': 3.18}


 41%|████      | 530/1304 [05:37<08:24,  1.53it/s]

{'loss': 0.4876, 'grad_norm': 5.504897594451904, 'learning_rate': 0.00026632566069906223, 'epoch': 3.24}


 41%|████▏     | 540/1304 [05:43<08:15,  1.54it/s]

{'loss': 0.5193, 'grad_norm': 6.96021842956543, 'learning_rate': 0.00026291560102301793, 'epoch': 3.3}


 42%|████▏     | 550/1304 [05:50<08:05,  1.55it/s]

{'loss': 0.4701, 'grad_norm': 4.0810394287109375, 'learning_rate': 0.0002595055413469736, 'epoch': 3.36}


 43%|████▎     | 560/1304 [05:56<08:00,  1.55it/s]

{'loss': 0.5241, 'grad_norm': 4.537802219390869, 'learning_rate': 0.00025609548167092924, 'epoch': 3.43}


 44%|████▎     | 570/1304 [06:03<08:01,  1.52it/s]

{'loss': 0.4625, 'grad_norm': 5.111947536468506, 'learning_rate': 0.00025268542199488494, 'epoch': 3.49}


 44%|████▍     | 580/1304 [06:09<08:04,  1.49it/s]

{'loss': 0.4444, 'grad_norm': 5.047516345977783, 'learning_rate': 0.0002492753623188406, 'epoch': 3.55}


 45%|████▌     | 590/1304 [06:16<08:03,  1.48it/s]

{'loss': 0.4827, 'grad_norm': 4.791135311126709, 'learning_rate': 0.00024586530264279625, 'epoch': 3.61}


 46%|████▌     | 600/1304 [06:23<07:45,  1.51it/s]

{'loss': 0.517, 'grad_norm': 3.3689920902252197, 'learning_rate': 0.00024245524296675192, 'epoch': 3.67}


 47%|████▋     | 610/1304 [06:29<07:30,  1.54it/s]

{'loss': 0.5039, 'grad_norm': 6.654399394989014, 'learning_rate': 0.0002390451832907076, 'epoch': 3.73}


 48%|████▊     | 620/1304 [06:36<07:11,  1.59it/s]

{'loss': 0.3863, 'grad_norm': 2.971837043762207, 'learning_rate': 0.00023563512361466325, 'epoch': 3.79}


 48%|████▊     | 630/1304 [06:42<06:59,  1.61it/s]

{'loss': 0.4084, 'grad_norm': 7.112678527832031, 'learning_rate': 0.00023222506393861893, 'epoch': 3.85}


 49%|████▉     | 640/1304 [06:48<06:54,  1.60it/s]

{'loss': 0.4409, 'grad_norm': 3.527311086654663, 'learning_rate': 0.0002288150042625746, 'epoch': 3.91}


 50%|████▉     | 650/1304 [06:55<06:47,  1.60it/s]

{'loss': 0.4301, 'grad_norm': 3.5551846027374268, 'learning_rate': 0.00022540494458653026, 'epoch': 3.98}


 50%|█████     | 654/1304 [06:57<06:30,  1.67it/s]
 50%|█████     | 654/1304 [07:04<06:30,  1.67it/s]

{'eval_loss': 0.1800004243850708, 'eval_accuracy': 0.9531190202831994, 'eval_runtime': 7.2139, 'eval_samples_per_second': 724.44, 'eval_steps_per_second': 22.734, 'epoch': 4.0}


 51%|█████     | 660/1304 [07:09<11:11,  1.04s/it]

{'loss': 0.3861, 'grad_norm': 3.8510050773620605, 'learning_rate': 0.00022199488491048594, 'epoch': 4.04}


 51%|█████▏    | 670/1304 [07:15<06:47,  1.56it/s]

{'loss': 0.3834, 'grad_norm': 4.594685077667236, 'learning_rate': 0.00021858482523444162, 'epoch': 4.1}


 52%|█████▏    | 680/1304 [07:22<06:31,  1.59it/s]

{'loss': 0.384, 'grad_norm': 3.959357738494873, 'learning_rate': 0.00021517476555839727, 'epoch': 4.16}


 53%|█████▎    | 690/1304 [07:28<06:20,  1.61it/s]

{'loss': 0.3997, 'grad_norm': 3.674698829650879, 'learning_rate': 0.00021176470588235295, 'epoch': 4.22}


 54%|█████▎    | 700/1304 [07:34<06:13,  1.62it/s]

{'loss': 0.3363, 'grad_norm': 3.096623420715332, 'learning_rate': 0.00020835464620630863, 'epoch': 4.28}


 54%|█████▍    | 710/1304 [07:40<06:08,  1.61it/s]

{'loss': 0.3722, 'grad_norm': 3.223283052444458, 'learning_rate': 0.00020494458653026428, 'epoch': 4.34}


 55%|█████▌    | 720/1304 [07:46<06:02,  1.61it/s]

{'loss': 0.3686, 'grad_norm': 4.242850303649902, 'learning_rate': 0.00020153452685421996, 'epoch': 4.4}


 56%|█████▌    | 730/1304 [07:53<05:57,  1.61it/s]

{'loss': 0.3413, 'grad_norm': 5.012390613555908, 'learning_rate': 0.00019812446717817564, 'epoch': 4.46}


 57%|█████▋    | 740/1304 [07:59<05:52,  1.60it/s]

{'loss': 0.3188, 'grad_norm': 6.473104953765869, 'learning_rate': 0.00019471440750213132, 'epoch': 4.53}


 58%|█████▊    | 750/1304 [08:05<05:49,  1.58it/s]

{'loss': 0.411, 'grad_norm': 4.673171043395996, 'learning_rate': 0.00019130434782608697, 'epoch': 4.59}


 58%|█████▊    | 760/1304 [08:11<05:44,  1.58it/s]

{'loss': 0.3728, 'grad_norm': 6.339261531829834, 'learning_rate': 0.00018789428815004265, 'epoch': 4.65}


 59%|█████▉    | 770/1304 [08:18<05:38,  1.58it/s]

{'loss': 0.3223, 'grad_norm': 4.0824666023254395, 'learning_rate': 0.00018448422847399832, 'epoch': 4.71}


 60%|█████▉    | 780/1304 [08:24<05:32,  1.58it/s]

{'loss': 0.3861, 'grad_norm': 4.412510871887207, 'learning_rate': 0.00018107416879795398, 'epoch': 4.77}


 61%|██████    | 790/1304 [08:31<05:26,  1.58it/s]

{'loss': 0.3459, 'grad_norm': 2.6922767162323, 'learning_rate': 0.00017766410912190965, 'epoch': 4.83}


 61%|██████▏   | 800/1304 [08:37<05:19,  1.58it/s]

{'loss': 0.3021, 'grad_norm': 4.5593976974487305, 'learning_rate': 0.00017425404944586533, 'epoch': 4.89}


 62%|██████▏   | 810/1304 [08:43<05:09,  1.60it/s]

{'loss': 0.3962, 'grad_norm': 4.951617240905762, 'learning_rate': 0.00017084398976982098, 'epoch': 4.95}


 63%|██████▎   | 817/1304 [08:48<05:03,  1.61it/s]
 63%|██████▎   | 817/1304 [08:55<05:03,  1.61it/s]

{'eval_loss': 0.22894826531410217, 'eval_accuracy': 0.9456563337160352, 'eval_runtime': 7.1156, 'eval_samples_per_second': 734.443, 'eval_steps_per_second': 23.048, 'epoch': 5.0}


 63%|██████▎   | 820/1304 [08:57<14:29,  1.80s/it]

{'loss': 0.4309, 'grad_norm': 4.848880767822266, 'learning_rate': 0.00016743393009377666, 'epoch': 5.02}


 64%|██████▎   | 830/1304 [09:04<05:14,  1.51it/s]

{'loss': 0.3454, 'grad_norm': 4.777740001678467, 'learning_rate': 0.00016402387041773234, 'epoch': 5.08}


 64%|██████▍   | 840/1304 [09:10<04:53,  1.58it/s]

{'loss': 0.3222, 'grad_norm': 5.1117424964904785, 'learning_rate': 0.000160613810741688, 'epoch': 5.14}


 65%|██████▌   | 850/1304 [09:16<04:46,  1.58it/s]

{'loss': 0.3531, 'grad_norm': 4.876353740692139, 'learning_rate': 0.00015720375106564367, 'epoch': 5.2}


 66%|██████▌   | 860/1304 [09:23<04:39,  1.59it/s]

{'loss': 0.3401, 'grad_norm': 3.1518046855926514, 'learning_rate': 0.00015379369138959932, 'epoch': 5.26}


 67%|██████▋   | 870/1304 [09:29<04:33,  1.59it/s]

{'loss': 0.3016, 'grad_norm': 5.360238075256348, 'learning_rate': 0.00015038363171355497, 'epoch': 5.32}


 67%|██████▋   | 880/1304 [09:35<04:26,  1.59it/s]

{'loss': 0.2919, 'grad_norm': 5.208247184753418, 'learning_rate': 0.00014697357203751065, 'epoch': 5.38}


 68%|██████▊   | 890/1304 [09:41<04:20,  1.59it/s]

{'loss': 0.3373, 'grad_norm': 2.979475498199463, 'learning_rate': 0.00014356351236146633, 'epoch': 5.44}


 69%|██████▉   | 900/1304 [09:48<04:15,  1.58it/s]

{'loss': 0.2793, 'grad_norm': 3.417940616607666, 'learning_rate': 0.00014015345268542198, 'epoch': 5.5}


 70%|██████▉   | 910/1304 [09:54<04:07,  1.59it/s]

{'loss': 0.262, 'grad_norm': 2.5962023735046387, 'learning_rate': 0.00013674339300937766, 'epoch': 5.57}


 71%|███████   | 920/1304 [10:00<04:03,  1.58it/s]

{'loss': 0.2956, 'grad_norm': 5.7351555824279785, 'learning_rate': 0.00013333333333333334, 'epoch': 5.63}


 71%|███████▏  | 930/1304 [10:07<03:55,  1.59it/s]

{'loss': 0.3171, 'grad_norm': 14.81772232055664, 'learning_rate': 0.000129923273657289, 'epoch': 5.69}


 72%|███████▏  | 940/1304 [10:13<03:48,  1.59it/s]

{'loss': 0.3062, 'grad_norm': 5.2995452880859375, 'learning_rate': 0.00012651321398124467, 'epoch': 5.75}


 73%|███████▎  | 950/1304 [10:19<03:42,  1.59it/s]

{'loss': 0.2871, 'grad_norm': 4.5258941650390625, 'learning_rate': 0.00012310315430520035, 'epoch': 5.81}


 74%|███████▎  | 960/1304 [10:26<03:36,  1.59it/s]

{'loss': 0.2566, 'grad_norm': 5.312527179718018, 'learning_rate': 0.00011969309462915601, 'epoch': 5.87}


 74%|███████▍  | 970/1304 [10:32<03:30,  1.59it/s]

{'loss': 0.2761, 'grad_norm': 4.954793930053711, 'learning_rate': 0.00011628303495311168, 'epoch': 5.93}


 75%|███████▌  | 980/1304 [10:38<03:24,  1.58it/s]

{'loss': 0.2729, 'grad_norm': 4.755916118621826, 'learning_rate': 0.00011287297527706734, 'epoch': 5.99}


 75%|███████▌  | 981/1304 [10:39<03:14,  1.66it/s]
 75%|███████▌  | 981/1304 [10:46<03:14,  1.66it/s]

{'eval_loss': 0.12978434562683105, 'eval_accuracy': 0.9651741293532339, 'eval_runtime': 7.2913, 'eval_samples_per_second': 716.746, 'eval_steps_per_second': 22.493, 'epoch': 6.0}


 76%|███████▌  | 990/1304 [10:53<04:00,  1.30it/s]

{'loss': 0.2475, 'grad_norm': 2.343234062194824, 'learning_rate': 0.00010946291560102302, 'epoch': 6.06}


 77%|███████▋  | 1000/1304 [10:59<03:12,  1.58it/s]

{'loss': 0.268, 'grad_norm': 3.781723976135254, 'learning_rate': 0.00010605285592497869, 'epoch': 6.12}


 77%|███████▋  | 1010/1304 [11:05<03:11,  1.54it/s]

{'loss': 0.233, 'grad_norm': 2.315352439880371, 'learning_rate': 0.00010264279624893435, 'epoch': 6.18}


 78%|███████▊  | 1020/1304 [11:12<03:01,  1.56it/s]

{'loss': 0.217, 'grad_norm': 12.176037788391113, 'learning_rate': 9.923273657289003e-05, 'epoch': 6.24}


 79%|███████▉  | 1030/1304 [11:18<02:55,  1.56it/s]

{'loss': 0.2492, 'grad_norm': 5.365221977233887, 'learning_rate': 9.58226768968457e-05, 'epoch': 6.3}


 80%|███████▉  | 1040/1304 [11:25<02:48,  1.56it/s]

{'loss': 0.2256, 'grad_norm': 1.8933578729629517, 'learning_rate': 9.241261722080137e-05, 'epoch': 6.36}


 81%|████████  | 1050/1304 [11:31<02:39,  1.59it/s]

{'loss': 0.2896, 'grad_norm': 3.9985015392303467, 'learning_rate': 8.900255754475704e-05, 'epoch': 6.42}


 81%|████████▏ | 1060/1304 [11:37<02:33,  1.59it/s]

{'loss': 0.2597, 'grad_norm': 2.509798765182495, 'learning_rate': 8.55924978687127e-05, 'epoch': 6.48}


 82%|████████▏ | 1070/1304 [11:43<02:26,  1.60it/s]

{'loss': 0.2281, 'grad_norm': 3.320580005645752, 'learning_rate': 8.218243819266838e-05, 'epoch': 6.54}


 83%|████████▎ | 1080/1304 [11:50<02:20,  1.60it/s]

{'loss': 0.2382, 'grad_norm': 2.7813825607299805, 'learning_rate': 7.877237851662405e-05, 'epoch': 6.61}


 84%|████████▎ | 1090/1304 [11:56<02:13,  1.60it/s]

{'loss': 0.2061, 'grad_norm': 3.837019681930542, 'learning_rate': 7.536231884057971e-05, 'epoch': 6.67}


 84%|████████▍ | 1100/1304 [12:02<02:11,  1.55it/s]

{'loss': 0.2162, 'grad_norm': 2.6795654296875, 'learning_rate': 7.195225916453539e-05, 'epoch': 6.73}


 85%|████████▌ | 1110/1304 [12:09<02:04,  1.56it/s]

{'loss': 0.2547, 'grad_norm': 6.043508529663086, 'learning_rate': 6.854219948849106e-05, 'epoch': 6.79}


 86%|████████▌ | 1120/1304 [12:15<01:57,  1.57it/s]

{'loss': 0.2508, 'grad_norm': 2.753730535507202, 'learning_rate': 6.513213981244672e-05, 'epoch': 6.85}


 87%|████████▋ | 1130/1304 [12:21<01:49,  1.59it/s]

{'loss': 0.198, 'grad_norm': 2.4908828735351562, 'learning_rate': 6.17220801364024e-05, 'epoch': 6.91}


 87%|████████▋ | 1140/1304 [12:28<01:43,  1.59it/s]

{'loss': 0.2279, 'grad_norm': 4.796147346496582, 'learning_rate': 5.8312020460358065e-05, 'epoch': 6.97}


 88%|████████▊ | 1144/1304 [12:30<01:40,  1.59it/s]
 88%|████████▊ | 1144/1304 [12:38<01:40,  1.59it/s]

{'eval_loss': 0.12121764570474625, 'eval_accuracy': 0.9693838499808649, 'eval_runtime': 7.3635, 'eval_samples_per_second': 709.715, 'eval_steps_per_second': 22.272, 'epoch': 7.0}


 88%|████████▊ | 1150/1304 [12:42<02:42,  1.05s/it]

{'loss': 0.2086, 'grad_norm': 3.883356809616089, 'learning_rate': 5.490196078431373e-05, 'epoch': 7.03}


 89%|████████▉ | 1160/1304 [12:49<01:32,  1.56it/s]

{'loss': 0.1947, 'grad_norm': 2.3230090141296387, 'learning_rate': 5.14919011082694e-05, 'epoch': 7.09}


 90%|████████▉ | 1170/1304 [12:55<01:24,  1.58it/s]

{'loss': 0.2425, 'grad_norm': 4.033791542053223, 'learning_rate': 4.8081841432225067e-05, 'epoch': 7.16}


 90%|█████████ | 1180/1304 [13:01<01:18,  1.57it/s]

{'loss': 0.2287, 'grad_norm': 3.8072361946105957, 'learning_rate': 4.467178175618074e-05, 'epoch': 7.22}


 91%|█████████▏| 1190/1304 [13:08<01:12,  1.58it/s]

{'loss': 0.2126, 'grad_norm': 4.042374610900879, 'learning_rate': 4.12617220801364e-05, 'epoch': 7.28}


 92%|█████████▏| 1200/1304 [13:14<01:05,  1.58it/s]

{'loss': 0.1755, 'grad_norm': 2.1140122413635254, 'learning_rate': 3.7851662404092075e-05, 'epoch': 7.34}


 93%|█████████▎| 1210/1304 [13:20<00:59,  1.58it/s]

{'loss': 0.1845, 'grad_norm': 3.680279493331909, 'learning_rate': 3.444160272804775e-05, 'epoch': 7.4}


 94%|█████████▎| 1220/1304 [13:27<00:53,  1.58it/s]

{'loss': 0.1996, 'grad_norm': 4.736300945281982, 'learning_rate': 3.103154305200341e-05, 'epoch': 7.46}


 94%|█████████▍| 1230/1304 [13:33<00:46,  1.58it/s]

{'loss': 0.1739, 'grad_norm': 2.5087039470672607, 'learning_rate': 2.7621483375959077e-05, 'epoch': 7.52}


 95%|█████████▌| 1240/1304 [13:39<00:40,  1.59it/s]

{'loss': 0.2113, 'grad_norm': 3.657776117324829, 'learning_rate': 2.4211423699914752e-05, 'epoch': 7.58}


 96%|█████████▌| 1250/1304 [13:46<00:33,  1.59it/s]

{'loss': 0.1943, 'grad_norm': 6.767022132873535, 'learning_rate': 2.0801364023870417e-05, 'epoch': 7.65}


 97%|█████████▋| 1260/1304 [13:52<00:27,  1.59it/s]

{'loss': 0.2174, 'grad_norm': 3.0719258785247803, 'learning_rate': 1.739130434782609e-05, 'epoch': 7.71}


 97%|█████████▋| 1270/1304 [13:58<00:21,  1.58it/s]

{'loss': 0.1843, 'grad_norm': 5.079818248748779, 'learning_rate': 1.3981244671781757e-05, 'epoch': 7.77}


 98%|█████████▊| 1280/1304 [14:04<00:15,  1.58it/s]

{'loss': 0.2195, 'grad_norm': 2.30505633354187, 'learning_rate': 1.0571184995737427e-05, 'epoch': 7.83}


 99%|█████████▉| 1290/1304 [14:11<00:08,  1.59it/s]

{'loss': 0.1685, 'grad_norm': 4.53591251373291, 'learning_rate': 7.161125319693095e-06, 'epoch': 7.89}


100%|█████████▉| 1300/1304 [14:17<00:02,  1.58it/s]

{'loss': 0.1993, 'grad_norm': 4.31991720199585, 'learning_rate': 3.751065643648764e-06, 'epoch': 7.95}


100%|██████████| 1304/1304 [14:20<00:00,  1.58it/s]
100%|██████████| 1304/1304 [14:27<00:00,  1.58it/s]

{'eval_loss': 0.09926186501979828, 'eval_accuracy': 0.9734022196708764, 'eval_runtime': 7.5186, 'eval_samples_per_second': 695.08, 'eval_steps_per_second': 21.813, 'epoch': 7.98}


100%|██████████| 1304/1304 [14:28<00:00,  1.50it/s]

{'train_runtime': 868.8244, 'train_samples_per_second': 192.444, 'train_steps_per_second': 1.501, 'train_loss': 0.5409899705026778, 'epoch': 7.98}





TrainOutput(global_step=1304, training_loss=0.5409899705026778, metrics={'train_runtime': 868.8244, 'train_samples_per_second': 192.444, 'train_steps_per_second': 1.501, 'total_flos': 1.51359445138176e+18, 'train_loss': 0.5409899705026778, 'epoch': 7.9755351681957185})

## MIT/ast-finetuned-speech-commands-v2

In [5]:
feature_extractor = AutoFeatureExtractor.from_pretrained(
    "MIT/ast-finetuned-speech-commands-v2")

def preprocess_function(examples):
    audio_arrays = [x["array"] for x in examples["audio"]]
    inputs = feature_extractor(
      audio_arrays,
      sampling_rate=feature_extractor.sampling_rate,
      max_length=16000, truncation=True
    )
    return inputs
 
 
encoded_dataset = dataset.map(
    preprocess_function, remove_columns="audio", batched=True)
# encoded_dataset = encoded_dataset.rename_column("intent_class", "label")

encoded_dataset['train']

accuracy = evaluate.load("accuracy")
def compute_metrics(eval_pred):
    predictions = np.argmax(eval_pred.predictions, axis=1)
    return accuracy.compute(predictions=predictions, references=eval_pred.label_ids)

import numpy as np

num_labels = len(id2label)
model = AutoModelForAudioClassification.from_pretrained(
    "MIT/ast-finetuned-speech-commands-v2", num_labels=num_labels, 
  label2id=label2id, id2label=id2label, ignore_mismatched_sizes=True
)
 
training_args = TrainingArguments(
    output_dir="MIT_commands",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    # increase this if you have better machine resources
    learning_rate=4e-4,
    # increase this if you have better machine resources
    per_device_train_batch_size=32,
    gradient_accumulation_steps=4,
    # increase this if you have better machine resources
    per_device_eval_batch_size=32,
    # increase this if you have better machine resources
    num_train_epochs=8,
    warmup_ratio=0.1,
    # increase this if you have better machine resources
    logging_steps=10,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    fp16=True,                             # set to False if you don't have GPU
)
 
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=encoded_dataset["train"].with_format("torch"),
    eval_dataset=encoded_dataset["test"].with_format("torch"),
    tokenizer=feature_extractor,
    compute_metrics=compute_metrics,
)
 
trainer.train()

Map: 100%|██████████| 20900/20900 [01:41<00:00, 206.50 examples/s]
Map: 100%|██████████| 5226/5226 [00:09<00:00, 523.09 examples/s]
Some weights of ASTForAudioClassification were not initialized from the model checkpoint at MIT/ast-finetuned-speech-commands-v2 and are newly initialized because the shapes did not match:
- classifier.dense.bias: found shape torch.Size([35]) in the checkpoint and torch.Size([12]) in the model instantiated
- classifier.dense.weight: found shape torch.Size([35, 768]) in the checkpoint and torch.Size([12, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  1%|          | 10/1304 [00:08<16:22,  1.32it/s]

{'loss': 2.4709, 'grad_norm': 6.026375770568848, 'learning_rate': 3.053435114503817e-05, 'epoch': 0.06}


  2%|▏         | 20/1304 [00:15<16:06,  1.33it/s]

{'loss': 1.2982, 'grad_norm': 3.5650546550750732, 'learning_rate': 6.106870229007635e-05, 'epoch': 0.12}


  2%|▏         | 30/1304 [00:23<16:02,  1.32it/s]

{'loss': 0.2363, 'grad_norm': 0.918498694896698, 'learning_rate': 9.160305343511451e-05, 'epoch': 0.18}


  3%|▎         | 40/1304 [00:30<15:52,  1.33it/s]

{'loss': 0.069, 'grad_norm': 2.382890462875366, 'learning_rate': 0.0001221374045801527, 'epoch': 0.24}


  4%|▍         | 50/1304 [00:38<15:41,  1.33it/s]

{'loss': 0.0929, 'grad_norm': 1.6487853527069092, 'learning_rate': 0.00015267175572519084, 'epoch': 0.31}


  5%|▍         | 60/1304 [00:45<15:37,  1.33it/s]

{'loss': 0.0968, 'grad_norm': 1.6675211191177368, 'learning_rate': 0.00018320610687022902, 'epoch': 0.37}


  5%|▌         | 70/1304 [00:53<15:31,  1.32it/s]

{'loss': 0.0729, 'grad_norm': 1.7442679405212402, 'learning_rate': 0.00021374045801526718, 'epoch': 0.43}


  6%|▌         | 80/1304 [01:00<15:26,  1.32it/s]

{'loss': 0.0939, 'grad_norm': 2.2054309844970703, 'learning_rate': 0.0002442748091603054, 'epoch': 0.49}


  7%|▋         | 90/1304 [01:08<15:18,  1.32it/s]

{'loss': 0.1028, 'grad_norm': 0.684219241142273, 'learning_rate': 0.00027480916030534353, 'epoch': 0.55}


  8%|▊         | 100/1304 [01:16<15:13,  1.32it/s]

{'loss': 0.0788, 'grad_norm': 2.0869476795196533, 'learning_rate': 0.0003053435114503817, 'epoch': 0.61}


  8%|▊         | 110/1304 [01:23<15:02,  1.32it/s]

{'loss': 0.1233, 'grad_norm': 0.823319673538208, 'learning_rate': 0.00033587786259541984, 'epoch': 0.67}


  9%|▉         | 120/1304 [01:31<14:57,  1.32it/s]

{'loss': 0.0944, 'grad_norm': 2.0469186305999756, 'learning_rate': 0.00036641221374045805, 'epoch': 0.73}


 10%|▉         | 130/1304 [01:38<14:52,  1.31it/s]

{'loss': 0.1527, 'grad_norm': 2.013601541519165, 'learning_rate': 0.0003969465648854962, 'epoch': 0.8}


 11%|█         | 140/1304 [01:46<14:51,  1.31it/s]

{'loss': 0.1243, 'grad_norm': 2.2881245613098145, 'learning_rate': 0.0003969309462915601, 'epoch': 0.86}


 12%|█▏        | 150/1304 [01:54<14:39,  1.31it/s]

{'loss': 0.1561, 'grad_norm': 1.92942476272583, 'learning_rate': 0.0003935208866155158, 'epoch': 0.92}


 12%|█▏        | 160/1304 [02:01<14:35,  1.31it/s]

{'loss': 0.1271, 'grad_norm': 1.8339825868606567, 'learning_rate': 0.0003901108269394715, 'epoch': 0.98}


                                                  
Non-default generation parameters: {'max_length': 128}


{'eval_loss': 0.16473330557346344, 'eval_accuracy': 0.9510141599693839, 'eval_runtime': 15.299, 'eval_samples_per_second': 341.591, 'eval_steps_per_second': 10.72, 'epoch': 1.0}


 13%|█▎        | 170/1304 [02:25<25:15,  1.34s/it]  

{'loss': 0.1315, 'grad_norm': 2.3398289680480957, 'learning_rate': 0.00038670076726342713, 'epoch': 1.04}


 14%|█▍        | 180/1304 [02:33<14:37,  1.28it/s]

{'loss': 0.081, 'grad_norm': 1.4013766050338745, 'learning_rate': 0.0003832907075873828, 'epoch': 1.1}


 15%|█▍        | 190/1304 [02:41<14:24,  1.29it/s]

{'loss': 0.1379, 'grad_norm': 1.0443801879882812, 'learning_rate': 0.0003798806479113385, 'epoch': 1.16}


 15%|█▌        | 200/1304 [02:48<14:22,  1.28it/s]

{'loss': 0.1282, 'grad_norm': 1.9418104887008667, 'learning_rate': 0.00037647058823529414, 'epoch': 1.22}


 16%|█▌        | 210/1304 [02:56<14:14,  1.28it/s]

{'loss': 0.1553, 'grad_norm': 1.4076571464538574, 'learning_rate': 0.0003730605285592498, 'epoch': 1.28}


 17%|█▋        | 220/1304 [03:04<14:12,  1.27it/s]

{'loss': 0.0871, 'grad_norm': 0.5111271739006042, 'learning_rate': 0.0003696504688832055, 'epoch': 1.35}


 18%|█▊        | 230/1304 [03:12<14:02,  1.28it/s]

{'loss': 0.1106, 'grad_norm': 1.1947598457336426, 'learning_rate': 0.00036624040920716115, 'epoch': 1.41}


 18%|█▊        | 240/1304 [03:20<14:00,  1.27it/s]

{'loss': 0.0979, 'grad_norm': 0.9589504599571228, 'learning_rate': 0.0003628303495311168, 'epoch': 1.47}


 19%|█▉        | 250/1304 [03:28<13:44,  1.28it/s]

{'loss': 0.123, 'grad_norm': 0.7867357134819031, 'learning_rate': 0.0003594202898550725, 'epoch': 1.53}


 20%|█▉        | 260/1304 [03:36<13:46,  1.26it/s]

{'loss': 0.1029, 'grad_norm': 1.5050129890441895, 'learning_rate': 0.00035601023017902816, 'epoch': 1.59}


 21%|██        | 270/1304 [03:44<13:40,  1.26it/s]

{'loss': 0.1043, 'grad_norm': 1.186891794204712, 'learning_rate': 0.0003526001705029838, 'epoch': 1.65}


 21%|██▏       | 280/1304 [03:52<13:37,  1.25it/s]

{'loss': 0.1419, 'grad_norm': 1.4849789142608643, 'learning_rate': 0.0003491901108269395, 'epoch': 1.71}


 22%|██▏       | 290/1304 [03:59<13:24,  1.26it/s]

{'loss': 0.1551, 'grad_norm': 1.437150478363037, 'learning_rate': 0.00034578005115089516, 'epoch': 1.77}


 23%|██▎       | 300/1304 [04:07<13:07,  1.27it/s]

{'loss': 0.0905, 'grad_norm': 0.5041532516479492, 'learning_rate': 0.0003423699914748508, 'epoch': 1.83}


 24%|██▍       | 310/1304 [04:15<13:02,  1.27it/s]

{'loss': 0.1108, 'grad_norm': 0.7751504182815552, 'learning_rate': 0.0003389599317988065, 'epoch': 1.9}


 25%|██▍       | 320/1304 [04:23<13:05,  1.25it/s]

{'loss': 0.0774, 'grad_norm': 1.200866460800171, 'learning_rate': 0.0003355498721227622, 'epoch': 1.96}


                                                  
Non-default generation parameters: {'max_length': 128}


{'eval_loss': 0.15112832188606262, 'eval_accuracy': 0.9542671259089169, 'eval_runtime': 15.6088, 'eval_samples_per_second': 334.812, 'eval_steps_per_second': 10.507, 'epoch': 2.0}


 25%|██▌       | 330/1304 [04:48<52:06,  3.21s/it]  

{'loss': 0.0976, 'grad_norm': 1.1237363815307617, 'learning_rate': 0.0003321398124467178, 'epoch': 2.02}


 26%|██▌       | 340/1304 [04:56<13:53,  1.16it/s]

{'loss': 0.084, 'grad_norm': 0.3835541009902954, 'learning_rate': 0.00032872975277067353, 'epoch': 2.08}


 27%|██▋       | 350/1304 [05:03<12:39,  1.26it/s]

{'loss': 0.0815, 'grad_norm': 1.2849149703979492, 'learning_rate': 0.0003253196930946292, 'epoch': 2.14}


 28%|██▊       | 360/1304 [05:12<12:41,  1.24it/s]

{'loss': 0.0585, 'grad_norm': 1.0134010314941406, 'learning_rate': 0.00032190963341858483, 'epoch': 2.2}


 28%|██▊       | 370/1304 [05:20<13:07,  1.19it/s]

{'loss': 0.0802, 'grad_norm': 0.8658313155174255, 'learning_rate': 0.00031849957374254054, 'epoch': 2.26}


 29%|██▉       | 380/1304 [05:28<13:14,  1.16it/s]

{'loss': 0.0776, 'grad_norm': 0.8748392462730408, 'learning_rate': 0.0003150895140664962, 'epoch': 2.32}


 30%|██▉       | 390/1304 [05:36<12:27,  1.22it/s]

{'loss': 0.0831, 'grad_norm': 1.147264838218689, 'learning_rate': 0.00031167945439045184, 'epoch': 2.39}


 31%|███       | 400/1304 [05:45<13:38,  1.10it/s]

{'loss': 0.081, 'grad_norm': 0.5997412204742432, 'learning_rate': 0.00030826939471440755, 'epoch': 2.45}


 31%|███▏      | 410/1304 [05:54<12:22,  1.20it/s]

{'loss': 0.0611, 'grad_norm': 1.3958579301834106, 'learning_rate': 0.0003048593350383632, 'epoch': 2.51}


 32%|███▏      | 420/1304 [06:02<11:41,  1.26it/s]

{'loss': 0.0665, 'grad_norm': 0.8174690008163452, 'learning_rate': 0.00030144927536231885, 'epoch': 2.57}


 33%|███▎      | 430/1304 [06:10<11:32,  1.26it/s]

{'loss': 0.0672, 'grad_norm': 0.6136153340339661, 'learning_rate': 0.00029803921568627456, 'epoch': 2.63}


 34%|███▎      | 440/1304 [06:18<11:38,  1.24it/s]

{'loss': 0.0599, 'grad_norm': 0.949209988117218, 'learning_rate': 0.0002946291560102302, 'epoch': 2.69}


 35%|███▍      | 450/1304 [06:26<11:31,  1.24it/s]

{'loss': 0.0837, 'grad_norm': 0.6345248222351074, 'learning_rate': 0.00029121909633418586, 'epoch': 2.75}


 35%|███▌      | 460/1304 [06:34<11:17,  1.25it/s]

{'loss': 0.0523, 'grad_norm': 0.3561273217201233, 'learning_rate': 0.00028780903665814156, 'epoch': 2.81}


 36%|███▌      | 470/1304 [06:42<11:42,  1.19it/s]

{'loss': 0.0745, 'grad_norm': 1.4161969423294067, 'learning_rate': 0.0002843989769820972, 'epoch': 2.87}


 37%|███▋      | 480/1304 [06:51<11:38,  1.18it/s]

{'loss': 0.077, 'grad_norm': 0.3695119023323059, 'learning_rate': 0.00028098891730605287, 'epoch': 2.94}


 38%|███▊      | 490/1304 [06:59<11:06,  1.22it/s]

{'loss': 0.0814, 'grad_norm': 1.049297571182251, 'learning_rate': 0.0002775788576300086, 'epoch': 3.0}


                                                  
Non-default generation parameters: {'max_length': 128}


{'eval_loss': 0.0832609087228775, 'eval_accuracy': 0.972636815920398, 'eval_runtime': 15.9345, 'eval_samples_per_second': 327.967, 'eval_steps_per_second': 10.292, 'epoch': 3.0}


 38%|███▊      | 500/1304 [07:24<13:38,  1.02s/it]  

{'loss': 0.0615, 'grad_norm': 0.7107296586036682, 'learning_rate': 0.0002741687979539642, 'epoch': 3.06}


 39%|███▉      | 510/1304 [07:32<10:54,  1.21it/s]

{'loss': 0.0524, 'grad_norm': 0.936795711517334, 'learning_rate': 0.0002707587382779199, 'epoch': 3.12}


 40%|███▉      | 520/1304 [07:41<10:37,  1.23it/s]

{'loss': 0.0387, 'grad_norm': 0.4406301975250244, 'learning_rate': 0.0002673486786018756, 'epoch': 3.18}


 41%|████      | 530/1304 [07:49<10:29,  1.23it/s]

{'loss': 0.0679, 'grad_norm': 0.9288269877433777, 'learning_rate': 0.00026393861892583123, 'epoch': 3.24}


 41%|████▏     | 540/1304 [07:57<10:34,  1.20it/s]

{'loss': 0.0631, 'grad_norm': 0.5546150207519531, 'learning_rate': 0.0002605285592497869, 'epoch': 3.3}


 42%|████▏     | 550/1304 [08:05<10:19,  1.22it/s]

{'loss': 0.0446, 'grad_norm': 0.33928579092025757, 'learning_rate': 0.0002571184995737426, 'epoch': 3.36}


 43%|████▎     | 560/1304 [08:13<10:05,  1.23it/s]

{'loss': 0.0496, 'grad_norm': 0.6670635938644409, 'learning_rate': 0.00025370843989769824, 'epoch': 3.43}


 44%|████▎     | 570/1304 [08:21<09:54,  1.23it/s]

{'loss': 0.0573, 'grad_norm': 0.5972135663032532, 'learning_rate': 0.0002502983802216539, 'epoch': 3.49}


 44%|████▍     | 580/1304 [08:30<09:53,  1.22it/s]

{'loss': 0.0563, 'grad_norm': 1.2015042304992676, 'learning_rate': 0.0002468883205456096, 'epoch': 3.55}


 45%|████▌     | 590/1304 [08:38<09:44,  1.22it/s]

{'loss': 0.057, 'grad_norm': 0.7528542280197144, 'learning_rate': 0.00024347826086956525, 'epoch': 3.61}


 46%|████▌     | 600/1304 [08:46<09:36,  1.22it/s]

{'loss': 0.0412, 'grad_norm': 0.4962770342826843, 'learning_rate': 0.0002400682011935209, 'epoch': 3.67}


 47%|████▋     | 610/1304 [08:54<09:28,  1.22it/s]

{'loss': 0.0611, 'grad_norm': 0.5189045667648315, 'learning_rate': 0.00023665814151747658, 'epoch': 3.73}


 48%|████▊     | 620/1304 [09:02<09:21,  1.22it/s]

{'loss': 0.0451, 'grad_norm': 0.6187884211540222, 'learning_rate': 0.00023324808184143226, 'epoch': 3.79}


 48%|████▊     | 630/1304 [09:11<09:05,  1.24it/s]

{'loss': 0.0731, 'grad_norm': 0.4960116744041443, 'learning_rate': 0.0002298380221653879, 'epoch': 3.85}


 49%|████▉     | 640/1304 [09:19<09:05,  1.22it/s]

{'loss': 0.0541, 'grad_norm': 0.28665846586227417, 'learning_rate': 0.0002264279624893436, 'epoch': 3.91}


 50%|████▉     | 650/1304 [09:27<08:59,  1.21it/s]

{'loss': 0.0466, 'grad_norm': 0.5251190662384033, 'learning_rate': 0.00022301790281329927, 'epoch': 3.98}


                                                  
Non-default generation parameters: {'max_length': 128}


{'eval_loss': 0.09427358210086823, 'eval_accuracy': 0.9705319556065825, 'eval_runtime': 15.9317, 'eval_samples_per_second': 328.025, 'eval_steps_per_second': 10.294, 'epoch': 4.0}


 51%|█████     | 660/1304 [09:52<17:35,  1.64s/it]  

{'loss': 0.0511, 'grad_norm': 0.5115259885787964, 'learning_rate': 0.00021960784313725492, 'epoch': 4.04}


 51%|█████▏    | 670/1304 [10:00<08:49,  1.20it/s]

{'loss': 0.0454, 'grad_norm': 0.5934010744094849, 'learning_rate': 0.0002161977834612106, 'epoch': 4.1}


 52%|█████▏    | 680/1304 [10:08<08:30,  1.22it/s]

{'loss': 0.0279, 'grad_norm': 0.2592496871948242, 'learning_rate': 0.00021278772378516628, 'epoch': 4.16}


 53%|█████▎    | 690/1304 [10:16<08:18,  1.23it/s]

{'loss': 0.0349, 'grad_norm': 0.8304488062858582, 'learning_rate': 0.00020937766410912193, 'epoch': 4.22}


 54%|█████▎    | 700/1304 [10:24<07:58,  1.26it/s]

{'loss': 0.0349, 'grad_norm': 0.3014427125453949, 'learning_rate': 0.0002059676044330776, 'epoch': 4.28}


 54%|█████▍    | 710/1304 [10:32<08:06,  1.22it/s]

{'loss': 0.0493, 'grad_norm': 0.4087807834148407, 'learning_rate': 0.00020255754475703328, 'epoch': 4.34}


 55%|█████▌    | 720/1304 [10:40<07:56,  1.23it/s]

{'loss': 0.0447, 'grad_norm': 0.5309171676635742, 'learning_rate': 0.00019914748508098894, 'epoch': 4.4}


 56%|█████▌    | 730/1304 [10:49<07:49,  1.22it/s]

{'loss': 0.0286, 'grad_norm': 0.008779837749898434, 'learning_rate': 0.0001957374254049446, 'epoch': 4.46}


 57%|█████▋    | 740/1304 [10:57<07:41,  1.22it/s]

{'loss': 0.0256, 'grad_norm': 0.02799318917095661, 'learning_rate': 0.00019232736572890027, 'epoch': 4.53}


 58%|█████▊    | 750/1304 [11:05<07:30,  1.23it/s]

{'loss': 0.039, 'grad_norm': 1.039990782737732, 'learning_rate': 0.00018891730605285594, 'epoch': 4.59}


 58%|█████▊    | 760/1304 [11:13<07:25,  1.22it/s]

{'loss': 0.0472, 'grad_norm': 0.7903344631195068, 'learning_rate': 0.0001855072463768116, 'epoch': 4.65}


 59%|█████▉    | 770/1304 [11:21<07:18,  1.22it/s]

{'loss': 0.032, 'grad_norm': 0.9351637363433838, 'learning_rate': 0.00018209718670076727, 'epoch': 4.71}


 60%|█████▉    | 780/1304 [11:30<07:09,  1.22it/s]

{'loss': 0.0432, 'grad_norm': 0.908443033695221, 'learning_rate': 0.00017868712702472295, 'epoch': 4.77}


 61%|██████    | 790/1304 [11:38<07:01,  1.22it/s]

{'loss': 0.0411, 'grad_norm': 0.0818815678358078, 'learning_rate': 0.0001752770673486786, 'epoch': 4.83}


 61%|██████▏   | 800/1304 [11:46<06:53,  1.22it/s]

{'loss': 0.0331, 'grad_norm': 0.35225236415863037, 'learning_rate': 0.00017186700767263428, 'epoch': 4.89}


 62%|██████▏   | 810/1304 [11:54<06:40,  1.23it/s]

{'loss': 0.0534, 'grad_norm': 0.34111687541007996, 'learning_rate': 0.00016845694799658996, 'epoch': 4.95}


                                                  
Non-default generation parameters: {'max_length': 128}


{'eval_loss': 0.09301015734672546, 'eval_accuracy': 0.9735935706084959, 'eval_runtime': 15.9686, 'eval_samples_per_second': 327.268, 'eval_steps_per_second': 10.27, 'epoch': 5.0}


 63%|██████▎   | 820/1304 [12:19<26:24,  3.27s/it]

{'loss': 0.0281, 'grad_norm': 0.5482175350189209, 'learning_rate': 0.0001650468883205456, 'epoch': 5.02}


 64%|██████▎   | 830/1304 [12:27<06:59,  1.13it/s]

{'loss': 0.0263, 'grad_norm': 0.1902717649936676, 'learning_rate': 0.0001616368286445013, 'epoch': 5.08}


 64%|██████▍   | 840/1304 [12:36<06:25,  1.20it/s]

{'loss': 0.0286, 'grad_norm': 0.6426795721054077, 'learning_rate': 0.00015822676896845697, 'epoch': 5.14}


 65%|██████▌   | 850/1304 [12:44<06:10,  1.23it/s]

{'loss': 0.0223, 'grad_norm': 0.5036693215370178, 'learning_rate': 0.00015481670929241262, 'epoch': 5.2}


 66%|██████▌   | 860/1304 [12:52<06:00,  1.23it/s]

{'loss': 0.0402, 'grad_norm': 0.5274649858474731, 'learning_rate': 0.0001514066496163683, 'epoch': 5.26}


 67%|██████▋   | 870/1304 [13:00<06:00,  1.20it/s]

{'loss': 0.0304, 'grad_norm': 0.1329706609249115, 'learning_rate': 0.00014799658994032398, 'epoch': 5.32}


 67%|██████▋   | 880/1304 [13:08<05:44,  1.23it/s]

{'loss': 0.0273, 'grad_norm': 0.6670853495597839, 'learning_rate': 0.00014458653026427963, 'epoch': 5.38}


 68%|██████▊   | 890/1304 [13:17<05:34,  1.24it/s]

{'loss': 0.043, 'grad_norm': 0.45168766379356384, 'learning_rate': 0.0001411764705882353, 'epoch': 5.44}


 69%|██████▉   | 900/1304 [13:25<05:29,  1.23it/s]

{'loss': 0.0209, 'grad_norm': 0.4627516269683838, 'learning_rate': 0.000137766410912191, 'epoch': 5.5}


 70%|██████▉   | 910/1304 [13:33<05:22,  1.22it/s]

{'loss': 0.0291, 'grad_norm': 0.1890954226255417, 'learning_rate': 0.00013435635123614664, 'epoch': 5.57}


 71%|███████   | 920/1304 [13:41<05:15,  1.22it/s]

{'loss': 0.0242, 'grad_norm': 0.2721240520477295, 'learning_rate': 0.00013094629156010232, 'epoch': 5.63}


 71%|███████▏  | 930/1304 [13:49<05:02,  1.24it/s]

{'loss': 0.0242, 'grad_norm': 0.37288185954093933, 'learning_rate': 0.00012753623188405797, 'epoch': 5.69}


 72%|███████▏  | 940/1304 [13:57<05:07,  1.18it/s]

{'loss': 0.0128, 'grad_norm': 0.2147819846868515, 'learning_rate': 0.00012412617220801365, 'epoch': 5.75}


 73%|███████▎  | 950/1304 [14:06<04:48,  1.23it/s]

{'loss': 0.0237, 'grad_norm': 0.2559977173805237, 'learning_rate': 0.0001207161125319693, 'epoch': 5.81}


 74%|███████▎  | 960/1304 [14:14<04:39,  1.23it/s]

{'loss': 0.013, 'grad_norm': 0.0037765144370496273, 'learning_rate': 0.00011730605285592498, 'epoch': 5.87}


 74%|███████▍  | 970/1304 [14:22<04:30,  1.23it/s]

{'loss': 0.0222, 'grad_norm': 0.004781922325491905, 'learning_rate': 0.00011389599317988064, 'epoch': 5.93}


 75%|███████▌  | 980/1304 [14:30<04:24,  1.23it/s]

{'loss': 0.0242, 'grad_norm': 1.3823648691177368, 'learning_rate': 0.00011048593350383631, 'epoch': 5.99}


                                                  
Non-default generation parameters: {'max_length': 128}


{'eval_loss': 0.09710565209388733, 'eval_accuracy': 0.9743589743589743, 'eval_runtime': 16.0648, 'eval_samples_per_second': 325.308, 'eval_steps_per_second': 10.209, 'epoch': 6.0}


 76%|███████▌  | 990/1304 [14:55<05:46,  1.10s/it]

{'loss': 0.0259, 'grad_norm': 0.3000414967536926, 'learning_rate': 0.00010707587382779199, 'epoch': 6.06}


 77%|███████▋  | 1000/1304 [15:03<04:12,  1.21it/s]

{'loss': 0.0318, 'grad_norm': 0.011007709428668022, 'learning_rate': 0.00010366581415174765, 'epoch': 6.12}


 77%|███████▋  | 1010/1304 [15:11<04:01,  1.22it/s]

{'loss': 0.0096, 'grad_norm': 0.21152661740779877, 'learning_rate': 0.00010025575447570333, 'epoch': 6.18}


 78%|███████▊  | 1020/1304 [15:20<03:55,  1.21it/s]

{'loss': 0.0071, 'grad_norm': 0.014505675062537193, 'learning_rate': 9.684569479965901e-05, 'epoch': 6.24}


 79%|███████▉  | 1030/1304 [15:28<03:47,  1.20it/s]

{'loss': 0.0128, 'grad_norm': 0.0098262969404459, 'learning_rate': 9.343563512361467e-05, 'epoch': 6.3}


 80%|███████▉  | 1040/1304 [15:36<03:36,  1.22it/s]

{'loss': 0.0123, 'grad_norm': 0.18752741813659668, 'learning_rate': 9.002557544757034e-05, 'epoch': 6.36}


 81%|████████  | 1050/1304 [15:44<03:28,  1.22it/s]

{'loss': 0.0142, 'grad_norm': 0.0031284287106245756, 'learning_rate': 8.6615515771526e-05, 'epoch': 6.42}


 81%|████████▏ | 1060/1304 [15:53<03:19,  1.22it/s]

{'loss': 0.0075, 'grad_norm': 0.1688692718744278, 'learning_rate': 8.320545609548167e-05, 'epoch': 6.48}


 82%|████████▏ | 1070/1304 [16:01<03:08,  1.24it/s]

{'loss': 0.0141, 'grad_norm': 0.09615769237279892, 'learning_rate': 7.979539641943735e-05, 'epoch': 6.54}


 83%|████████▎ | 1080/1304 [16:09<03:03,  1.22it/s]

{'loss': 0.012, 'grad_norm': 0.06789079308509827, 'learning_rate': 7.638533674339301e-05, 'epoch': 6.61}


 84%|████████▎ | 1090/1304 [16:17<02:59,  1.19it/s]

{'loss': 0.0144, 'grad_norm': 0.0036591547541320324, 'learning_rate': 7.297527706734868e-05, 'epoch': 6.67}


 84%|████████▍ | 1100/1304 [16:25<02:45,  1.23it/s]

{'loss': 0.0151, 'grad_norm': 0.2727125883102417, 'learning_rate': 6.956521739130436e-05, 'epoch': 6.73}


 85%|████████▌ | 1110/1304 [16:34<02:37,  1.23it/s]

{'loss': 0.0148, 'grad_norm': 0.004321764688938856, 'learning_rate': 6.615515771526002e-05, 'epoch': 6.79}


 86%|████████▌ | 1120/1304 [16:42<02:36,  1.18it/s]

{'loss': 0.0193, 'grad_norm': 0.3358572721481323, 'learning_rate': 6.274509803921569e-05, 'epoch': 6.85}


 87%|████████▋ | 1130/1304 [16:50<02:24,  1.20it/s]

{'loss': 0.0089, 'grad_norm': 0.37842443585395813, 'learning_rate': 5.933503836317136e-05, 'epoch': 6.91}


 87%|████████▋ | 1140/1304 [16:59<02:13,  1.22it/s]

{'loss': 0.0057, 'grad_norm': 0.050189122557640076, 'learning_rate': 5.592497868712703e-05, 'epoch': 6.97}


                                                   
Non-default generation parameters: {'max_length': 128}


{'eval_loss': 0.09827478975057602, 'eval_accuracy': 0.9770378874856487, 'eval_runtime': 15.9578, 'eval_samples_per_second': 327.488, 'eval_steps_per_second': 10.277, 'epoch': 7.0}


 88%|████████▊ | 1150/1304 [17:24<04:14,  1.65s/it]

{'loss': 0.0086, 'grad_norm': 0.29653942584991455, 'learning_rate': 5.2514919011082694e-05, 'epoch': 7.03}


 89%|████████▉ | 1160/1304 [17:32<01:59,  1.21it/s]

{'loss': 0.0062, 'grad_norm': 0.0007053431472741067, 'learning_rate': 4.9104859335038366e-05, 'epoch': 7.09}


 90%|████████▉ | 1170/1304 [17:40<01:50,  1.21it/s]

{'loss': 0.0078, 'grad_norm': 0.2725358307361603, 'learning_rate': 4.569479965899404e-05, 'epoch': 7.16}


 90%|█████████ | 1180/1304 [17:48<01:42,  1.21it/s]

{'loss': 0.0078, 'grad_norm': 0.004369067959487438, 'learning_rate': 4.228473998294971e-05, 'epoch': 7.22}


 91%|█████████▏| 1190/1304 [17:56<01:34,  1.21it/s]

{'loss': 0.0136, 'grad_norm': 0.40845006704330444, 'learning_rate': 3.887468030690537e-05, 'epoch': 7.28}


 92%|█████████▏| 1200/1304 [18:04<01:25,  1.22it/s]

{'loss': 0.0016, 'grad_norm': 0.16378386318683624, 'learning_rate': 3.546462063086104e-05, 'epoch': 7.34}


 93%|█████████▎| 1210/1304 [18:13<01:17,  1.21it/s]

{'loss': 0.0031, 'grad_norm': 0.1372034251689911, 'learning_rate': 3.205456095481671e-05, 'epoch': 7.4}


 94%|█████████▎| 1220/1304 [18:21<01:09,  1.22it/s]

{'loss': 0.0077, 'grad_norm': 0.2382117360830307, 'learning_rate': 2.864450127877238e-05, 'epoch': 7.46}


 94%|█████████▍| 1230/1304 [18:29<01:00,  1.22it/s]

{'loss': 0.005, 'grad_norm': 0.407755583524704, 'learning_rate': 2.5234441602728048e-05, 'epoch': 7.52}


 95%|█████████▌| 1240/1304 [18:37<00:51,  1.24it/s]

{'loss': 0.0118, 'grad_norm': 0.16313298046588898, 'learning_rate': 2.182438192668372e-05, 'epoch': 7.58}


 96%|█████████▌| 1250/1304 [18:45<00:44,  1.22it/s]

{'loss': 0.0077, 'grad_norm': 0.38596516847610474, 'learning_rate': 1.8414322250639388e-05, 'epoch': 7.65}


 97%|█████████▋| 1260/1304 [18:54<00:36,  1.22it/s]

{'loss': 0.0061, 'grad_norm': 0.002409356413409114, 'learning_rate': 1.5004262574595056e-05, 'epoch': 7.71}


 97%|█████████▋| 1270/1304 [19:02<00:28,  1.21it/s]

{'loss': 0.0034, 'grad_norm': 0.015156571753323078, 'learning_rate': 1.1594202898550725e-05, 'epoch': 7.77}


 98%|█████████▊| 1280/1304 [19:10<00:19,  1.23it/s]

{'loss': 0.0121, 'grad_norm': 0.2210005223751068, 'learning_rate': 8.184143222506395e-06, 'epoch': 7.83}


 99%|█████████▉| 1290/1304 [19:18<00:11,  1.20it/s]

{'loss': 0.0057, 'grad_norm': 0.007655323948711157, 'learning_rate': 4.774083546462063e-06, 'epoch': 7.89}


100%|█████████▉| 1300/1304 [19:27<00:03,  1.22it/s]

{'loss': 0.0098, 'grad_norm': 0.004138079471886158, 'learning_rate': 1.3640238704177325e-06, 'epoch': 7.95}


                                                   
Non-default generation parameters: {'max_length': 128}


{'eval_loss': 0.10480619221925735, 'eval_accuracy': 0.9770378874856487, 'eval_runtime': 16.462, 'eval_samples_per_second': 317.459, 'eval_steps_per_second': 9.962, 'epoch': 7.98}


100%|██████████| 1304/1304 [19:47<00:00,  1.10it/s]

{'train_runtime': 1187.9624, 'train_samples_per_second': 140.745, 'train_steps_per_second': 1.098, 'train_loss': 0.0826756409183716, 'epoch': 7.98}





TrainOutput(global_step=1304, training_loss=0.0826756409183716, metrics={'train_runtime': 1187.9624, 'train_samples_per_second': 140.745, 'train_steps_per_second': 1.098, 'total_flos': 1.3992439943984579e+18, 'train_loss': 0.0826756409183716, 'epoch': 7.9755351681957185})

# Submission preparation

In [8]:
# audio_file='../data/test/test/audio/clip_000a96d0a.wav'
import csv

classifier = pipeline("audio-classification", model="MIT_commands/checkpoint-1304")

def predict(audio_file, classifier):
    prediction = classifier(audio_file)
    if len(prediction) == 0:
        print(f'no prediction for {audio_file}')
        return 'unknown'
    return prediction[0]['label']


submission_df = pd.DataFrame({'fname': [], 'label': []})
classifier = pipeline("audio-classification", model="MIT_commands/checkpoint-1304")
submission_csv_file = 'MIT_commands/' + 'checkpoint-1304' + '/' + str(0) + '_submission.csv'

for ind, test_file in enumerate(X_files):
    if ind%2000 == 0:
        print("{} done!".format(ind))
    label_pred = predict(test_dir + '/audio/' + test_file, classifier)
    with open(submission_csv_file, 'a', newline='') as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow([X_files[ind], label_pred])

print('Submission saved.')

0 done!
2000 done!
4000 done!
6000 done!
8000 done!
10000 done!
12000 done!
14000 done!
16000 done!
18000 done!
20000 done!
22000 done!
24000 done!
26000 done!
28000 done!
30000 done!
32000 done!
34000 done!
36000 done!
38000 done!
40000 done!
42000 done!
44000 done!
46000 done!
48000 done!
50000 done!
52000 done!
54000 done!
56000 done!
58000 done!
60000 done!
62000 done!
64000 done!
66000 done!
68000 done!
70000 done!
72000 done!
74000 done!
76000 done!
78000 done!
80000 done!
82000 done!
84000 done!
86000 done!
88000 done!
90000 done!
92000 done!
94000 done!
96000 done!
98000 done!
100000 done!
102000 done!
104000 done!
106000 done!
108000 done!
110000 done!
112000 done!
114000 done!
116000 done!
118000 done!
120000 done!
122000 done!
124000 done!
126000 done!
128000 done!
130000 done!
132000 done!
134000 done!
136000 done!
138000 done!
140000 done!
142000 done!
144000 done!
146000 done!
148000 done!
150000 done!
152000 done!
154000 done!
156000 done!
158000 done!


Exception in thread Thread-169259 (_readerthread):
Traceback (most recent call last):
  File "c:\Users\megav\anaconda3\envs\iml-10\lib\threading.py", line 1016, in _bootstrap_inner
    self.run()
  File "c:\Users\megav\anaconda3\envs\iml-10\lib\site-packages\ipykernel\ipkernel.py", line 761, in run_closure
    _threading_Thread_run(self)
  File "c:\Users\megav\anaconda3\envs\iml-10\lib\threading.py", line 953, in run
    self._target(*self._args, **self._kwargs)
  File "c:\Users\megav\anaconda3\envs\iml-10\lib\subprocess.py", line 1515, in _readerthread
    buffer.append(fh.read())
MemoryError


# Accuracy and confusion matrix

In [6]:
classifier = pipeline("audio-classification", model="MIT_commands/checkpoint-1304")

y_true = []
y_pred = []

def predict(audio_file, classifier):
    prediction = classifier(audio_file)
    if len(prediction) == 0:
        print(f'no prediction for {audio_file}')
        return 'unknown'
    return prediction[0]['label']

for ind, file in enumerate(dataset['test']):
    if ind%2000 == 0:
        print("{} done!".format(ind))
    y_true.append(file['label'])
    predicted_label = predict(file['audio']['path'], classifier)
    y_pred.append(label2id[predicted_label])

y_pred = [int(x) for x in y_pred]
accuracy = accuracy_score(y_true, y_pred)

# Confusion Matrix
cm = confusion_matrix(y_true, y_pred)
print(f'Accuracy: {accuracy * 100:.2f}%')

plt.figure(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=range(0, 12), yticklabels=range(0, 12))
plt.xlabel('Predicted')
plt.ylabel('True')
plt.title(f'Transformer Confusion Matrix. Accuracy: {accuracy * 100:.2f}%')
path = f'MIT_commands/checkpoint-1304/0_confusion_matrix.png'
plt.savefig(path)
print(f'Confusion matrix is saved to: {path}')
plt.close()

0 done!


  waveform = torch.from_numpy(waveform).unsqueeze(0)


2000 done!
4000 done!
Accuracy: 97.70%
Confusion matrix is saved to: MIT_commands/checkpoint-1304/0_confusion_matrix.png
