In [1]:
from tensorflow import keras
import tensorflow as tf
import os
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import seaborn as sns
from sklearn.metrics import confusion_matrix
from sklearn.metrics import accuracy_score, f1_score

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, LSTM, Dropout, GRU, Bidirectional, Flatten
from tensorflow.keras.optimizers import SGD
from tensorflow.random import set_seed

# Fix the global TensorFlow and NumPy RNG seeds so that runs are as repeatable
# as the backends allow.
SEED = 2024
set_seed(SEED)
np.random.seed(SEED)


import csv
import librosa
import librosa.display
import matplotlib.pyplot as plt
# Allow duplicate OpenMP runtimes (presumably to avoid the "libiomp5 already
# initialized" abort when several libraries ship their own copy).
# NOTE(review): this variable usually only has an effect if it is set before those
# libraries are imported (tensorflow/librosa are imported above) — confirm.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

train_dir = os.path.abspath('../data/train/train')
test_dir = os.path.abspath('../data/test/test')

# One class per sub-directory of <train_dir>/audio. classes[i] is later used to
# turn the model's i-th output back into a word for the submission.
# NOTE(review): os.listdir order is filesystem-dependent. It must match the order
# used when the integer labels in y_train.npy / y_val.npy were encoded — TODO
# confirm (sorted() is the safe choice if the encoder sorted as well).
# A "_background_noise_" folder, if present, is kept as a class.
classes = os.listdir(os.path.join(train_dir, 'audio'))


def _as_sequences(arr):
    """Reshape ``arr`` to (-1, arr.shape[1], arr.shape[2]).

    This is the 3-D (samples, time steps, features) layout the recurrent models
    below consume. It is a no-op for arrays that are already 3-D.
    """
    return arr.reshape((-1, arr.shape[1], arr.shape[2]))


X_train = _as_sequences(np.load(os.path.join(train_dir, "X_train.npy")))
y_train = np.load(os.path.join(train_dir, "y_train.npy"))

X_val = _as_sequences(np.load(os.path.join(train_dir, "X_val.npy")))
y_val = np.load(os.path.join(train_dir, "y_val.npy"))

# Give the test features the same layout as train/val, so they match the
# model's input shape.
X_test = _as_sequences(np.load(os.path.join(test_dir, 'X_test.npy')))
# ndmin=1 keeps a 1-D array of file names even if the file lists a single entry.
X_files = np.loadtxt(os.path.join(test_dir, 'X_files.txt'), delimiter=" ", dtype='str', ndmin=1)

def plot_loss(history_df, name, idx):
    """Plot per-epoch training and validation loss and save it as a PNG.

    Args:
        history_df: DataFrame built from a Keras ``History.history`` dict; must
            have ``'loss'`` and ``'val_loss'`` columns.
        name: model name, used in the title and as the sub-folder of
            ``train_history/``.
        idx: run identifier used as the file-name prefix.

    Writes ``train_history/{name}/{idx}_loss.png`` (creating the folder if it
    does not exist), prints the path, and closes the figure.
    """
    path = f'train_history/{name}/{idx}_loss.png'
    os.makedirs(os.path.dirname(path), exist_ok=True)

    fig, ax = plt.subplots()
    try:
        ax.plot(history_df['loss'], label='train')
        # 'val_loss' is measured on the validation_data passed to fit(), not a test set.
        ax.plot(history_df['val_loss'], label='validation')
        ax.set_title(f'{name}: loss')
        ax.set_ylabel('loss')
        ax.set_xlabel('epoch')
        ax.legend(loc='upper right')
        fig.savefig(path)
    finally:
        plt.close(fig)  # release the figure even if plotting or saving fails
    print(f'Loss plot is saved to: {path}')


In [4]:
# LSTM baseline: a single LSTM layer followed by a softmax over the classes.
input_shape = X_train.shape[1:]  # (time steps, features per step), taken from the data
epochs = 100
batch_size = 32
model_name, run_idx = 'LSTM', '0'
run_dir = f'train_history/{model_name}'
os.makedirs(run_dir, exist_ok=True)  # the to_csv/savefig calls below need the folder

model_lstm = Sequential()
model_lstm.add(LSTM(units=125, activation="tanh", input_shape=input_shape))
model_lstm.add(Dense(units=len(classes), activation="softmax"))
# y_train / y_val hold integer class ids (they are compared directly with argmax
# predictions below), so train with a classification loss on those ids. MSE
# between a 30-way output and a single integer does not train a classifier.
model_lstm.compile(optimizer="RMSprop", loss="sparse_categorical_crossentropy", metrics=["accuracy"])

model_lstm.summary()  # summary() prints itself; wrapping it in print() only adds "None"

history = model_lstm.fit(X_train, y_train, epochs=epochs, batch_size=batch_size,
                         validation_data=(X_val, y_val))

# Persist the per-epoch history and a loss plot. Passing a path (not a text
# handle) lets pandas manage newlines itself.
history_df = pd.DataFrame(history.history)
history_df.to_csv(f'{run_dir}/{run_idx}_history.csv')
plot_loss(history_df, model_name, run_idx)

# Validation accuracy and macro-F1.
y_pred = np.argmax(model_lstm.predict(X_val), axis=-1)
accuracy = accuracy_score(y_val, y_pred)
f1 = f1_score(y_val, y_pred, average='macro')
metrics_df = pd.DataFrame({'accuracy': [accuracy], 'f1': [f1]})
metrics_df.to_csv(f'{run_dir}/{run_idx}_metrics.csv')
print(metrics_df)

# Confusion matrix over every class id. The matrix is then always
# len(classes) x len(classes) and lines up with the tick labels, even when a
# class never appears in y_val or y_pred.
class_ids = np.arange(len(classes))
cm = confusion_matrix(y_val, y_pred, labels=class_ids)
print(f'Accuracy: {accuracy * 100:.2f}%')

fig, ax = plt.subplots(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
            xticklabels=class_ids, yticklabels=class_ids, ax=ax)
ax.set(xlabel='Predicted', ylabel='True',
       title=f'{model_name} Confusion Matrix. Accuracy: {accuracy * 100:.2f}%')
path = f'{run_dir}/{run_idx}_confusion_matrix.png'
fig.savefig(path)
print(f'Confusion matrix is saved to: {path}')
plt.close(fig)

# Submission: one batched predict() over the whole test set, instead of one
# predict() call per file plus row-by-row DataFrame appends.
assert len(X_test) == len(X_files), "X_test and X_files must be aligned"
test_pred = np.argmax(model_lstm.predict(X_test, batch_size=batch_size), axis=-1)
submission_df = pd.DataFrame({'fname': X_files, 'label': [classes[i] for i in test_pred]})
submission_df.to_csv(f'{run_dir}/{run_idx}_submission.csv', index=False, lineterminator='\n')
print('Submission saved.')

Model: "sequential_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 lstm_1 (LSTM)               (None, 125)               105500    
                                                                 
 dense_1 (Dense)             (None, 30)                3780      
                                                                 
Total params: 109,280
Trainable params: 109,280
Non-trainable params: 0
_________________________________________________________________
None
Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/100
Epoch 26/100
Epoch 27/100
Epoch 28/100
Epoch 29/100
Epoch 30/100
Epoch 31/100
Epoch 32/100
Epoch 33/100
Epoch 34/100
Epoch 

MemoryError: Unable to allocate 1.63 MiB for an array with shape (2, 106512) and data type object

In [2]:
# Stacked GRU: input dropout, then 3 x (GRU(125) -> dropout), then a 125-unit
# dense head and a softmax over the classes.
input_shape = X_train.shape[1:]  # (time steps, features per step), taken from the data
epochs = 50
batch_size = 32
model_name, run_idx = 'GRU', '1'
run_dir = f'train_history/{model_name}'
os.makedirs(run_dir, exist_ok=True)  # the to_csv/savefig calls below need the folder
make_submission = True  # set to False to skip writing test-set predictions

model_gru = Sequential()
model_gru.add(Dropout(0.2, input_shape=input_shape))
for _ in range(3):
    model_gru.add(GRU(units=125, return_sequences=True))
    model_gru.add(Dropout(0.2))
model_gru.add(Flatten())
model_gru.add(Dense(units=125, activation="sigmoid"))
model_gru.add(Dropout(0.2))
model_gru.add(Dense(units=len(classes), activation="softmax"))
# The targets are integer class ids (they are compared directly with argmax
# predictions below), so use sparse categorical cross-entropy. MSE between a
# softmax vector and a single integer does not train a classifier.
model_gru.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"])

model_gru.summary()  # summary() prints itself; wrapping it in print() only adds "None"

history = model_gru.fit(X_train, y_train, epochs=epochs, batch_size=batch_size,
                        validation_data=(X_val, y_val))

# Persist the per-epoch history and a loss plot.
history_df = pd.DataFrame(history.history)
history_df.to_csv(f'{run_dir}/{run_idx}_history.csv')
plot_loss(history_df, model_name, run_idx)

# Validation accuracy and macro-F1.
y_pred = np.argmax(model_gru.predict(X_val), axis=-1)
accuracy = accuracy_score(y_val, y_pred)
f1 = f1_score(y_val, y_pred, average='macro')
metrics_df = pd.DataFrame({'accuracy': [accuracy], 'f1': [f1]})
metrics_df.to_csv(f'{run_dir}/{run_idx}_metrics.csv')
print(metrics_df)

# Confusion matrix over every class id, so the matrix shape always matches the
# tick labels.
class_ids = np.arange(len(classes))
cm = confusion_matrix(y_val, y_pred, labels=class_ids)
print(f'Accuracy: {accuracy * 100:.2f}%')

fig, ax = plt.subplots(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
            xticklabels=class_ids, yticklabels=class_ids, ax=ax)
ax.set(xlabel='Predicted', ylabel='True',
       title=f'{model_name} Confusion Matrix. Accuracy: {accuracy * 100:.2f}%')
path = f'{run_dir}/{run_idx}_confusion_matrix.png'
fig.savefig(path)
print(f'Confusion matrix is saved to: {path}')
plt.close(fig)

# Submission: one batched predict() over the whole test set. This replaces the
# slow per-file predict() loop.
if make_submission:
    assert len(X_test) == len(X_files), "X_test and X_files must be aligned"
    test_pred = np.argmax(model_gru.predict(X_test, batch_size=batch_size), axis=-1)
    submission_df = pd.DataFrame({'fname': X_files, 'label': [classes[i] for i in test_pred]})
    submission_df.to_csv(f'{run_dir}/{run_idx}_submission.csv', index=False, lineterminator='\n')
    print('Submission saved.')

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 dropout (Dropout)           (None, 122, 85)           0         
                                                                 
 gru (GRU)                   (None, 122, 125)          79500     
                                                                 
 dropout_1 (Dropout)         (None, 122, 125)          0         
                                                                 
 gru_1 (GRU)                 (None, 122, 125)          94500     
                                                                 
 dropout_2 (Dropout)         (None, 122, 125)          0         
                                                                 
 gru_2 (GRU)                 (None, 122, 125)          94500     
                                                                 
 dropout_3 (Dropout)         (None, 122, 125)          0

# Transformer

In [4]:
from transformers import AutoFeatureExtractor, AutoModelForAudioClassification, TrainingArguments, Trainer, pipeline
import evaluate
from datasets import load_dataset, Audio

train_dir = os.path.abspath('../data/train/train')
# "audiofolder" takes the label names from the sub-directories of .../audio
# (see the printed list below).
dataset = load_dataset("audiofolder", data_dir=os.path.join(train_dir, "audio"))
# A fixed seed makes the 80/20 train/eval split, and therefore eval_accuracy,
# reproducible across kernel restarts.
dataset = dataset['train'].train_test_split(test_size=0.2, seed=2024)
labels = dataset["train"].features["label"].names
print(labels)
# String ids, matching the label2id/id2label mappings handed to the model config.
label2id = {label: str(i) for i, label in enumerate(labels)}
id2label = {str(i): label for i, label in enumerate(labels)}


['bed', 'bird', 'cat', 'dog', 'down', 'eight', 'five', 'four', 'go', 'happy', 'house', 'left', 'marvin', 'nine', 'no', 'off', 'on', 'one', 'right', 'seven', 'sheila', 'six', 'stop', 'three', 'tree', 'two', 'up', 'wow', 'yes', 'zero']


In [20]:
# Pretrained wav2vec2 feature extractor. It turns raw waveforms into the
# "input_values" the classifier consumes.
feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base")

def preprocess_function(examples, max_duration_s=1.0):
    """Convert a batch of decoded audio examples into wav2vec2 inputs.

    Intended for ``Dataset.map(..., batched=True)``: ``examples["audio"]`` is a
    list of decoded clips, each exposing ``"array"`` (the waveform) and
    ``"sampling_rate"``.

    Args:
        examples: the batch passed in by ``Dataset.map``.
        max_duration_s: clips longer than this many seconds are truncated. The
            default of 1.0 s equals the previous fixed ``max_length=16000`` at 16 kHz.

    Returns:
        Whatever ``feature_extractor`` returns for the batch (its
        ``input_values``). Shorter clips are not padded here; presumably the
        Trainer's default collator pads per batch.

    Raises:
        ValueError: if any clip's sampling rate differs from the extractor's.
            The extractor would otherwise be told the audio is at its own rate
            and would silently produce wrong features.
    """
    target_sr = feature_extractor.sampling_rate
    audio = examples["audio"]

    mismatched = {clip["sampling_rate"] for clip in audio} - {target_sr}
    if mismatched:
        raise ValueError(
            f"expected audio sampled at {target_sr} Hz, got {sorted(mismatched)}; "
            "resample first, e.g. dataset.cast_column('audio', Audio(sampling_rate=...))"
        )

    return feature_extractor(
        [clip["array"] for clip in audio],
        sampling_rate=target_sr,
        max_length=int(target_sr * max_duration_s),
        truncation=True,
    )
 
 
# Featurize both splits with a batched map. The raw "audio" column is dropped
# once it has been turned into "input_values"; "label" already comes from
# audiofolder, so no renaming is needed.
encoded_dataset = dataset.map(
    preprocess_function,
    batched=True,
    remove_columns="audio",
)

encoded_dataset["train"]


Map: 100%|██████████| 51776/51776 [05:21<00:00, 161.19 examples/s]
Map: 100%|██████████| 12945/12945 [01:25<00:00, 150.65 examples/s]


Dataset({
    features: ['label', 'input_values'],
    num_rows: 51776
})

In [22]:


def compute_metrics(eval_pred):
    """Accuracy callback for ``transformers.Trainer``.

    ``eval_pred.predictions`` holds one row of class scores per example, and
    ``eval_pred.label_ids`` holds the integer labels. Returns
    ``{"accuracy": float}``, which the Trainer logs as ``eval_accuracy`` and uses
    for ``metric_for_best_model="accuracy"``.

    Accuracy is computed directly with NumPy rather than through a module-level
    ``accuracy = evaluate.load("accuracy")`` object. The Keras cells also assign a
    float named ``accuracy``, so re-running them would clobber that object and
    break this callback.
    """
    predictions = np.argmax(eval_pred.predictions, axis=-1)
    labels = np.asarray(eval_pred.label_ids)
    return {"accuracy": float(np.mean(predictions == labels))}


In [23]:
import numpy as np

import torch  # only used to decide whether fp16 mixed precision is available

# Fine-tune wav2vec2-base with a freshly initialised classification head
# (one output per label).
num_labels = len(id2label)
model = AutoModelForAudioClassification.from_pretrained(
    "facebook/wav2vec2-base",
    num_labels=num_labels,
    label2id=label2id,
    id2label=id2label,
)

training_args = TrainingArguments(
    output_dir="GFG_commands_audio_classification_finetuned",
    # Evaluate and checkpoint once per epoch. load_best_model_at_end requires the
    # two strategies to match. NOTE(review): newer transformers releases rename
    # `evaluation_strategy` to `eval_strategy`.
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=4e-4,
    # Effective train batch = 32 per device x 4 accumulation steps = 128. With
    # more GPU memory, raise the per-device batch and lower the accumulation.
    per_device_train_batch_size=32,
    gradient_accumulation_steps=4,
    per_device_eval_batch_size=32,
    num_train_epochs=8,
    warmup_ratio=0.1,
    logging_steps=10,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    # fp16 mixed precision needs a CUDA GPU; fall back to fp32 instead of failing.
    fp16=torch.cuda.is_available(),
    seed=2024,  # same seed as the Keras experiments
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=encoded_dataset["train"].with_format("torch"),
    eval_dataset=encoded_dataset["test"].with_format("torch"),
    # Passing the feature extractor saves it alongside checkpoints.
    # NOTE(review): newer releases call this argument `processing_class`.
    tokenizer=feature_extractor,
    compute_metrics=compute_metrics,
)

trainer.train()


  return self.fget.__get__(instance, owner)()
Some weights of Wav2Vec2ForSequenceClassification were not initialized from the model checkpoint at facebook/wav2vec2-base and are newly initialized: ['classifier.bias', 'classifier.weight', 'projector.bias', 'projector.weight', 'wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original1']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  0%|          | 10/3232 [00:07<34:06,  1.57it/s] 

{'loss': 3.3981, 'grad_norm': 0.5974945425987244, 'learning_rate': 9.876543209876543e-06, 'epoch': 0.02}


  1%|          | 20/3232 [00:12<31:08,  1.72it/s]

{'loss': 3.3927, 'grad_norm': 0.5135430097579956, 'learning_rate': 2.2222222222222223e-05, 'epoch': 0.05}


  1%|          | 30/3232 [00:18<30:46,  1.73it/s]

{'loss': 3.3725, 'grad_norm': nan, 'learning_rate': 3.3333333333333335e-05, 'epoch': 0.07}


  1%|          | 40/3232 [00:24<32:28,  1.64it/s]

{'loss': 3.3348, 'grad_norm': 0.9705635905265808, 'learning_rate': 4.567901234567901e-05, 'epoch': 0.1}


  2%|▏         | 50/3232 [00:30<31:37,  1.68it/s]

{'loss': 3.1878, 'grad_norm': 2.001976251602173, 'learning_rate': 5.679012345679012e-05, 'epoch': 0.12}


  2%|▏         | 60/3232 [00:36<29:57,  1.76it/s]

{'loss': 2.949, 'grad_norm': 3.482431650161743, 'learning_rate': 6.91358024691358e-05, 'epoch': 0.15}


  2%|▏         | 70/3232 [00:41<29:36,  1.78it/s]

{'loss': 2.6708, 'grad_norm': 6.4644927978515625, 'learning_rate': 8.148148148148148e-05, 'epoch': 0.17}


  2%|▏         | 80/3232 [00:47<29:36,  1.77it/s]

{'loss': 2.3857, 'grad_norm': 4.359455108642578, 'learning_rate': 9.382716049382717e-05, 'epoch': 0.2}


  3%|▎         | 90/3232 [00:53<29:26,  1.78it/s]

{'loss': 2.1176, 'grad_norm': 6.847593784332275, 'learning_rate': 0.00010617283950617284, 'epoch': 0.22}


  3%|▎         | 100/3232 [00:58<29:22,  1.78it/s]

{'loss': 1.8958, 'grad_norm': 7.126283168792725, 'learning_rate': 0.00011851851851851852, 'epoch': 0.25}


  3%|▎         | 110/3232 [01:04<29:15,  1.78it/s]

{'loss': 1.7193, 'grad_norm': 14.05808162689209, 'learning_rate': 0.0001308641975308642, 'epoch': 0.27}


  4%|▎         | 120/3232 [01:10<29:13,  1.77it/s]

{'loss': 1.6074, 'grad_norm': 13.156152725219727, 'learning_rate': 0.00014197530864197534, 'epoch': 0.3}


  4%|▍         | 130/3232 [01:15<29:06,  1.78it/s]

{'loss': 1.4696, 'grad_norm': 7.656411647796631, 'learning_rate': 0.00015432098765432098, 'epoch': 0.32}


  4%|▍         | 140/3232 [01:21<29:40,  1.74it/s]

{'loss': 1.3265, 'grad_norm': 9.719106674194336, 'learning_rate': 0.0001666666666666667, 'epoch': 0.35}


  5%|▍         | 150/3232 [01:27<29:22,  1.75it/s]

{'loss': 1.2989, 'grad_norm': 14.675065040588379, 'learning_rate': 0.00017901234567901234, 'epoch': 0.37}


  5%|▍         | 160/3232 [01:32<29:50,  1.72it/s]

{'loss': 1.2164, 'grad_norm': 10.397281646728516, 'learning_rate': 0.00019135802469135804, 'epoch': 0.4}


  5%|▌         | 170/3232 [01:38<29:27,  1.73it/s]

{'loss': 1.1313, 'grad_norm': 9.059977531433105, 'learning_rate': 0.00020370370370370372, 'epoch': 0.42}


  6%|▌         | 180/3232 [01:44<30:29,  1.67it/s]

{'loss': 1.0487, 'grad_norm': 19.16676902770996, 'learning_rate': 0.00021604938271604937, 'epoch': 0.44}


  6%|▌         | 190/3232 [01:50<29:22,  1.73it/s]

{'loss': 1.055, 'grad_norm': 12.754321098327637, 'learning_rate': 0.00022839506172839507, 'epoch': 0.47}


  6%|▌         | 200/3232 [01:56<29:15,  1.73it/s]

{'loss': 1.03, 'grad_norm': 19.3332462310791, 'learning_rate': 0.00024074074074074075, 'epoch': 0.49}


  6%|▋         | 210/3232 [02:02<29:11,  1.73it/s]

{'loss': 1.1729, 'grad_norm': 25.820953369140625, 'learning_rate': 0.0002530864197530864, 'epoch': 0.52}


  7%|▋         | 220/3232 [02:07<29:29,  1.70it/s]

{'loss': 1.0571, 'grad_norm': 8.217148780822754, 'learning_rate': 0.0002641975308641976, 'epoch': 0.54}


  7%|▋         | 230/3232 [02:13<29:06,  1.72it/s]

{'loss': 1.034, 'grad_norm': 10.695181846618652, 'learning_rate': 0.0002765432098765432, 'epoch': 0.57}


  7%|▋         | 240/3232 [02:19<28:52,  1.73it/s]

{'loss': 1.0716, 'grad_norm': 20.687252044677734, 'learning_rate': 0.0002888888888888889, 'epoch': 0.59}


  8%|▊         | 250/3232 [02:25<28:46,  1.73it/s]

{'loss': 0.9481, 'grad_norm': 8.926913261413574, 'learning_rate': 0.00030000000000000003, 'epoch': 0.62}


  8%|▊         | 260/3232 [02:31<28:50,  1.72it/s]

{'loss': 0.9354, 'grad_norm': 10.112759590148926, 'learning_rate': 0.0003123456790123457, 'epoch': 0.64}


  8%|▊         | 270/3232 [02:36<28:57,  1.70it/s]

{'loss': 0.9729, 'grad_norm': 10.21896743774414, 'learning_rate': 0.0003246913580246914, 'epoch': 0.67}


  9%|▊         | 280/3232 [02:42<29:00,  1.70it/s]

{'loss': 0.8974, 'grad_norm': 71.07129669189453, 'learning_rate': 0.00033703703703703706, 'epoch': 0.69}


  9%|▉         | 290/3232 [02:48<28:41,  1.71it/s]

{'loss': 1.1328, 'grad_norm': 11.965250968933105, 'learning_rate': 0.00034938271604938273, 'epoch': 0.72}


  9%|▉         | 300/3232 [02:54<28:22,  1.72it/s]

{'loss': 0.9245, 'grad_norm': 89.32243347167969, 'learning_rate': 0.0003617283950617284, 'epoch': 0.74}


 10%|▉         | 310/3232 [03:00<28:59,  1.68it/s]

{'loss': 0.9307, 'grad_norm': 9.151155471801758, 'learning_rate': 0.0003740740740740741, 'epoch': 0.77}


 10%|▉         | 320/3232 [03:06<28:28,  1.70it/s]

{'loss': 1.1359, 'grad_norm': 15.506536483764648, 'learning_rate': 0.00038641975308641976, 'epoch': 0.79}


 10%|█         | 330/3232 [03:12<28:14,  1.71it/s]

{'loss': 1.0861, 'grad_norm': 10.072122573852539, 'learning_rate': 0.00039876543209876544, 'epoch': 0.82}


 11%|█         | 340/3232 [03:18<28:14,  1.71it/s]

{'loss': 1.006, 'grad_norm': 9.98062515258789, 'learning_rate': 0.00039876203576341126, 'epoch': 0.84}


 11%|█         | 350/3232 [03:23<28:12,  1.70it/s]

{'loss': 1.1914, 'grad_norm': 13.83738899230957, 'learning_rate': 0.0003973865199449794, 'epoch': 0.87}


 11%|█         | 360/3232 [03:29<29:06,  1.64it/s]

{'loss': 1.0098, 'grad_norm': 9.632941246032715, 'learning_rate': 0.00039601100412654745, 'epoch': 0.89}


 11%|█▏        | 370/3232 [03:35<28:42,  1.66it/s]

{'loss': 0.9551, 'grad_norm': 31.466896057128906, 'learning_rate': 0.0003946354883081155, 'epoch': 0.91}


 12%|█▏        | 380/3232 [03:42<29:09,  1.63it/s]

{'loss': 0.9839, 'grad_norm': 19.316991806030273, 'learning_rate': 0.0003932599724896837, 'epoch': 0.94}


 12%|█▏        | 390/3232 [03:48<28:33,  1.66it/s]

{'loss': 1.0631, 'grad_norm': 13.221471786499023, 'learning_rate': 0.00039188445667125176, 'epoch': 0.96}


 12%|█▏        | 400/3232 [03:54<28:05,  1.68it/s]

{'loss': 0.8905, 'grad_norm': 9.887857437133789, 'learning_rate': 0.00039050894085281983, 'epoch': 0.99}


                                                  
 12%|█▎        | 404/3232 [04:13<27:48,  1.69it/s]

{'eval_loss': 0.7223916053771973, 'eval_accuracy': 0.8159907300115875, 'eval_runtime': 16.9046, 'eval_samples_per_second': 765.766, 'eval_steps_per_second': 23.958, 'epoch': 1.0}


 13%|█▎        | 410/3232 [04:17<1:09:34,  1.48s/it]

{'loss': 1.0884, 'grad_norm': 9.147896766662598, 'learning_rate': 0.0003891334250343879, 'epoch': 1.01}


 13%|█▎        | 420/3232 [04:23<28:02,  1.67it/s]  

{'loss': 1.0692, 'grad_norm': 9.761871337890625, 'learning_rate': 0.000387757909215956, 'epoch': 1.04}


 13%|█▎        | 430/3232 [04:29<26:43,  1.75it/s]

{'loss': 0.8541, 'grad_norm': 8.569286346435547, 'learning_rate': 0.0003863823933975241, 'epoch': 1.06}


 14%|█▎        | 440/3232 [04:35<26:41,  1.74it/s]

{'loss': 0.9016, 'grad_norm': 9.989483833312988, 'learning_rate': 0.00038500687757909215, 'epoch': 1.09}


 14%|█▍        | 450/3232 [04:40<26:44,  1.73it/s]

{'loss': 0.9457, 'grad_norm': 13.375882148742676, 'learning_rate': 0.0003836313617606603, 'epoch': 1.11}


 14%|█▍        | 460/3232 [04:46<26:36,  1.74it/s]

{'loss': 0.995, 'grad_norm': 7.6293816566467285, 'learning_rate': 0.00038225584594222834, 'epoch': 1.14}


 15%|█▍        | 470/3232 [04:52<26:29,  1.74it/s]

{'loss': 1.0295, 'grad_norm': 11.752564430236816, 'learning_rate': 0.0003808803301237964, 'epoch': 1.16}


 15%|█▍        | 480/3232 [04:58<26:26,  1.73it/s]

{'loss': 0.8052, 'grad_norm': 6.213232517242432, 'learning_rate': 0.00037964236588720776, 'epoch': 1.19}


 15%|█▌        | 490/3232 [05:03<26:27,  1.73it/s]

{'loss': 0.8636, 'grad_norm': 10.599929809570312, 'learning_rate': 0.00037826685006877583, 'epoch': 1.21}


 15%|█▌        | 500/3232 [05:09<26:11,  1.74it/s]

{'loss': 0.7823, 'grad_norm': 5.848618507385254, 'learning_rate': 0.0003768913342503439, 'epoch': 1.24}


 16%|█▌        | 510/3232 [05:15<26:24,  1.72it/s]

{'loss': 0.7473, 'grad_norm': 12.626099586486816, 'learning_rate': 0.000375515818431912, 'epoch': 1.26}


 16%|█▌        | 520/3232 [05:21<26:08,  1.73it/s]

{'loss': 0.9934, 'grad_norm': 6.224715709686279, 'learning_rate': 0.0003741403026134801, 'epoch': 1.29}


 16%|█▋        | 530/3232 [05:27<26:10,  1.72it/s]

{'loss': 0.8305, 'grad_norm': 7.35163688659668, 'learning_rate': 0.00037276478679504815, 'epoch': 1.31}


 17%|█▋        | 540/3232 [05:33<27:19,  1.64it/s]

{'loss': 0.9013, 'grad_norm': 5.136435508728027, 'learning_rate': 0.0003713892709766162, 'epoch': 1.33}


 17%|█▋        | 550/3232 [05:39<27:10,  1.65it/s]

{'loss': 0.7845, 'grad_norm': 11.255746841430664, 'learning_rate': 0.00037001375515818434, 'epoch': 1.36}


 17%|█▋        | 560/3232 [05:45<25:49,  1.72it/s]

{'loss': 0.8139, 'grad_norm': 6.624602794647217, 'learning_rate': 0.0003686382393397524, 'epoch': 1.38}


 18%|█▊        | 570/3232 [05:50<25:42,  1.73it/s]

{'loss': 0.7719, 'grad_norm': 6.633724212646484, 'learning_rate': 0.0003672627235213205, 'epoch': 1.41}


 18%|█▊        | 580/3232 [05:56<25:31,  1.73it/s]

{'loss': 0.6731, 'grad_norm': 8.048134803771973, 'learning_rate': 0.0003658872077028886, 'epoch': 1.43}


 18%|█▊        | 590/3232 [06:02<25:22,  1.74it/s]

{'loss': 0.7616, 'grad_norm': 7.670865535736084, 'learning_rate': 0.0003645116918844567, 'epoch': 1.46}


 19%|█▊        | 600/3232 [06:08<25:20,  1.73it/s]

{'loss': 0.8793, 'grad_norm': 17.933467864990234, 'learning_rate': 0.0003631361760660248, 'epoch': 1.48}


 19%|█▉        | 610/3232 [06:14<25:06,  1.74it/s]

{'loss': 0.7094, 'grad_norm': 9.235109329223633, 'learning_rate': 0.00036176066024759285, 'epoch': 1.51}


 19%|█▉        | 620/3232 [06:19<25:06,  1.73it/s]

{'loss': 0.7413, 'grad_norm': 9.457657814025879, 'learning_rate': 0.00036038514442916097, 'epoch': 1.53}


 19%|█▉        | 630/3232 [06:25<25:32,  1.70it/s]

{'loss': 0.7503, 'grad_norm': 8.781124114990234, 'learning_rate': 0.00035900962861072904, 'epoch': 1.56}


 20%|█▉        | 640/3232 [06:31<25:07,  1.72it/s]

{'loss': 0.592, 'grad_norm': 7.627665996551514, 'learning_rate': 0.0003576341127922971, 'epoch': 1.58}


 20%|██        | 650/3232 [06:37<24:56,  1.73it/s]

{'loss': 0.7885, 'grad_norm': 6.117629528045654, 'learning_rate': 0.00035625859697386523, 'epoch': 1.61}


 20%|██        | 660/3232 [06:43<24:49,  1.73it/s]

{'loss': 0.6792, 'grad_norm': 6.476618766784668, 'learning_rate': 0.0003548830811554333, 'epoch': 1.63}


 21%|██        | 670/3232 [06:48<24:45,  1.73it/s]

{'loss': 0.7925, 'grad_norm': 5.640066623687744, 'learning_rate': 0.00035350756533700136, 'epoch': 1.66}


 21%|██        | 680/3232 [06:54<25:17,  1.68it/s]

{'loss': 0.778, 'grad_norm': 7.942684650421143, 'learning_rate': 0.0003521320495185695, 'epoch': 1.68}


 21%|██▏       | 690/3232 [07:00<24:52,  1.70it/s]

{'loss': 0.69, 'grad_norm': 7.34427547454834, 'learning_rate': 0.00035075653370013755, 'epoch': 1.71}


 22%|██▏       | 700/3232 [07:06<24:46,  1.70it/s]

{'loss': 0.7528, 'grad_norm': 11.511927604675293, 'learning_rate': 0.00034938101788170567, 'epoch': 1.73}


 22%|██▏       | 710/3232 [07:12<24:22,  1.72it/s]

{'loss': 0.5785, 'grad_norm': 7.3523173332214355, 'learning_rate': 0.00034800550206327374, 'epoch': 1.76}


 22%|██▏       | 720/3232 [07:18<24:06,  1.74it/s]

{'loss': 0.6743, 'grad_norm': 4.841892719268799, 'learning_rate': 0.00034662998624484186, 'epoch': 1.78}


 23%|██▎       | 730/3232 [07:24<24:08,  1.73it/s]

{'loss': 0.6332, 'grad_norm': 11.82168960571289, 'learning_rate': 0.00034525447042640993, 'epoch': 1.8}


 23%|██▎       | 740/3232 [07:29<24:09,  1.72it/s]

{'loss': 0.6128, 'grad_norm': 4.7661356925964355, 'learning_rate': 0.000343878954607978, 'epoch': 1.83}


 23%|██▎       | 750/3232 [07:35<24:00,  1.72it/s]

{'loss': 0.5847, 'grad_norm': 5.897588729858398, 'learning_rate': 0.0003425034387895461, 'epoch': 1.85}


 24%|██▎       | 760/3232 [07:41<24:28,  1.68it/s]

{'loss': 0.7623, 'grad_norm': 8.407944679260254, 'learning_rate': 0.0003411279229711142, 'epoch': 1.88}


 24%|██▍       | 770/3232 [07:47<24:05,  1.70it/s]

{'loss': 0.7163, 'grad_norm': 12.995570182800293, 'learning_rate': 0.00033975240715268225, 'epoch': 1.9}


 24%|██▍       | 780/3232 [07:53<24:54,  1.64it/s]

{'loss': 0.6063, 'grad_norm': 11.89809799194336, 'learning_rate': 0.0003383768913342503, 'epoch': 1.93}


 24%|██▍       | 790/3232 [07:59<24:46,  1.64it/s]

{'loss': 0.5976, 'grad_norm': 7.1000776290893555, 'learning_rate': 0.00033700137551581844, 'epoch': 1.95}


 25%|██▍       | 800/3232 [08:05<23:41,  1.71it/s]

{'loss': 0.6646, 'grad_norm': 4.561517715454102, 'learning_rate': 0.00033562585969738656, 'epoch': 1.98}


                                                  
 25%|██▌       | 809/3232 [08:28<23:36,  1.71it/s]

{'eval_loss': 0.2674466371536255, 'eval_accuracy': 0.9235225955967555, 'eval_runtime': 17.0152, 'eval_samples_per_second': 760.788, 'eval_steps_per_second': 23.802, 'epoch': 2.0}


 25%|██▌       | 810/3232 [08:29<4:01:37,  5.99s/it]

{'loss': 0.583, 'grad_norm': 4.901156902313232, 'learning_rate': 0.00033425034387895463, 'epoch': 2.0}


 25%|██▌       | 820/3232 [08:35<29:32,  1.36it/s]  

{'loss': 0.5868, 'grad_norm': 12.466076850891113, 'learning_rate': 0.00033287482806052275, 'epoch': 2.03}


 26%|██▌       | 830/3232 [08:41<23:24,  1.71it/s]

{'loss': 0.7233, 'grad_norm': 7.9050445556640625, 'learning_rate': 0.0003314993122420908, 'epoch': 2.05}


 26%|██▌       | 840/3232 [08:47<23:44,  1.68it/s]

{'loss': 0.6148, 'grad_norm': 5.188591957092285, 'learning_rate': 0.0003301237964236589, 'epoch': 2.08}


 26%|██▋       | 850/3232 [08:53<23:36,  1.68it/s]

{'loss': 0.5321, 'grad_norm': 10.80848503112793, 'learning_rate': 0.00032874828060522695, 'epoch': 2.1}


 27%|██▋       | 860/3232 [08:59<23:14,  1.70it/s]

{'loss': 0.5345, 'grad_norm': 6.9038896560668945, 'learning_rate': 0.0003273727647867951, 'epoch': 2.13}


 27%|██▋       | 870/3232 [09:04<23:01,  1.71it/s]

{'loss': 0.5682, 'grad_norm': 8.394136428833008, 'learning_rate': 0.00032599724896836314, 'epoch': 2.15}


 27%|██▋       | 880/3232 [09:10<23:00,  1.70it/s]

{'loss': 0.4537, 'grad_norm': 10.20146656036377, 'learning_rate': 0.0003246217331499312, 'epoch': 2.18}


 28%|██▊       | 890/3232 [09:16<23:07,  1.69it/s]

{'loss': 0.5816, 'grad_norm': 5.988681316375732, 'learning_rate': 0.00032324621733149933, 'epoch': 2.2}


 28%|██▊       | 900/3232 [09:22<23:05,  1.68it/s]

{'loss': 0.5809, 'grad_norm': 3.4349489212036133, 'learning_rate': 0.0003218707015130674, 'epoch': 2.22}


 28%|██▊       | 910/3232 [09:28<22:37,  1.71it/s]

{'loss': 0.5515, 'grad_norm': 6.637514591217041, 'learning_rate': 0.0003204951856946355, 'epoch': 2.25}


 28%|██▊       | 920/3232 [09:34<22:35,  1.71it/s]

{'loss': 0.5411, 'grad_norm': 8.168432235717773, 'learning_rate': 0.0003191196698762036, 'epoch': 2.27}


 29%|██▉       | 930/3232 [09:40<22:32,  1.70it/s]

{'loss': 0.5347, 'grad_norm': 3.8581748008728027, 'learning_rate': 0.0003177441540577717, 'epoch': 2.3}


 29%|██▉       | 940/3232 [09:46<22:57,  1.66it/s]

{'loss': 0.5339, 'grad_norm': 5.095576763153076, 'learning_rate': 0.0003163686382393398, 'epoch': 2.32}


 29%|██▉       | 950/3232 [09:52<22:34,  1.68it/s]

{'loss': 0.5183, 'grad_norm': 10.683093070983887, 'learning_rate': 0.00031499312242090784, 'epoch': 2.35}


 30%|██▉       | 960/3232 [09:58<22:26,  1.69it/s]

{'loss': 0.5387, 'grad_norm': 7.7227959632873535, 'learning_rate': 0.00031361760660247596, 'epoch': 2.37}


 30%|███       | 970/3232 [10:04<22:33,  1.67it/s]

{'loss': 0.5126, 'grad_norm': 3.9734959602355957, 'learning_rate': 0.00031224209078404403, 'epoch': 2.4}


 30%|███       | 980/3232 [10:10<22:36,  1.66it/s]

{'loss': 0.5096, 'grad_norm': 6.569104194641113, 'learning_rate': 0.0003108665749656121, 'epoch': 2.42}


 31%|███       | 990/3232 [10:16<22:23,  1.67it/s]

{'loss': 0.4637, 'grad_norm': 6.02008581161499, 'learning_rate': 0.0003094910591471802, 'epoch': 2.45}


 31%|███       | 1000/3232 [10:22<22:13,  1.67it/s]

{'loss': 0.5497, 'grad_norm': 6.9329423904418945, 'learning_rate': 0.0003081155433287483, 'epoch': 2.47}


 31%|███▏      | 1010/3232 [10:28<22:11,  1.67it/s]

{'loss': 0.5547, 'grad_norm': 5.013092994689941, 'learning_rate': 0.00030674002751031635, 'epoch': 2.5}


 32%|███▏      | 1020/3232 [10:34<22:04,  1.67it/s]

{'loss': 0.5748, 'grad_norm': 5.011048793792725, 'learning_rate': 0.0003053645116918845, 'epoch': 2.52}


 32%|███▏      | 1030/3232 [10:40<22:30,  1.63it/s]

{'loss': 0.5735, 'grad_norm': 6.902886867523193, 'learning_rate': 0.0003039889958734526, 'epoch': 2.55}


 32%|███▏      | 1040/3232 [10:46<22:52,  1.60it/s]

{'loss': 0.6071, 'grad_norm': 7.596375942230225, 'learning_rate': 0.00030261348005502066, 'epoch': 2.57}


 32%|███▏      | 1050/3232 [10:52<22:03,  1.65it/s]

{'loss': 0.5273, 'grad_norm': 3.529195785522461, 'learning_rate': 0.00030123796423658873, 'epoch': 2.6}


 33%|███▎      | 1060/3232 [10:58<21:47,  1.66it/s]

{'loss': 0.5274, 'grad_norm': 5.672167778015137, 'learning_rate': 0.00029986244841815685, 'epoch': 2.62}


 33%|███▎      | 1070/3232 [11:04<21:34,  1.67it/s]

{'loss': 0.4825, 'grad_norm': 7.5248284339904785, 'learning_rate': 0.0002984869325997249, 'epoch': 2.65}


 33%|███▎      | 1080/3232 [11:10<21:28,  1.67it/s]

{'loss': 0.5114, 'grad_norm': 4.701301097869873, 'learning_rate': 0.000297111416781293, 'epoch': 2.67}


 34%|███▎      | 1090/3232 [11:16<21:23,  1.67it/s]

{'loss': 0.4659, 'grad_norm': 4.314423084259033, 'learning_rate': 0.00029573590096286105, 'epoch': 2.69}


 34%|███▍      | 1100/3232 [11:22<21:16,  1.67it/s]

{'loss': 0.5495, 'grad_norm': 5.602716445922852, 'learning_rate': 0.0002943603851444292, 'epoch': 2.72}


 34%|███▍      | 1110/3232 [11:28<21:13,  1.67it/s]

{'loss': 0.5066, 'grad_norm': 5.22689962387085, 'learning_rate': 0.00029298486932599724, 'epoch': 2.74}


 35%|███▍      | 1120/3232 [11:34<21:07,  1.67it/s]

{'loss': 0.5192, 'grad_norm': 4.3070292472839355, 'learning_rate': 0.0002916093535075653, 'epoch': 2.77}


 35%|███▍      | 1130/3232 [11:40<21:05,  1.66it/s]

{'loss': 0.4818, 'grad_norm': 4.634099006652832, 'learning_rate': 0.0002902338376891335, 'epoch': 2.79}


 35%|███▌      | 1140/3232 [11:46<20:52,  1.67it/s]

{'loss': 0.5386, 'grad_norm': 4.366332054138184, 'learning_rate': 0.00028885832187070155, 'epoch': 2.82}


 36%|███▌      | 1150/3232 [11:52<20:50,  1.67it/s]

{'loss': 0.4746, 'grad_norm': 3.8969123363494873, 'learning_rate': 0.0002874828060522696, 'epoch': 2.84}


 36%|███▌      | 1160/3232 [11:58<20:45,  1.66it/s]

{'loss': 0.472, 'grad_norm': 5.107517242431641, 'learning_rate': 0.0002861072902338377, 'epoch': 2.87}


 36%|███▌      | 1170/3232 [12:04<20:38,  1.66it/s]

{'loss': 0.5359, 'grad_norm': 9.464388847351074, 'learning_rate': 0.0002847317744154058, 'epoch': 2.89}


 37%|███▋      | 1180/3232 [12:10<20:24,  1.68it/s]

{'loss': 0.4811, 'grad_norm': 3.985809803009033, 'learning_rate': 0.0002833562585969739, 'epoch': 2.92}


 37%|███▋      | 1190/3232 [12:16<20:23,  1.67it/s]

{'loss': 0.4408, 'grad_norm': 3.783421516418457, 'learning_rate': 0.00028198074277854194, 'epoch': 2.94}


 37%|███▋      | 1200/3232 [12:22<20:15,  1.67it/s]

{'loss': 0.4971, 'grad_norm': 5.632336139678955, 'learning_rate': 0.00028060522696011006, 'epoch': 2.97}


 37%|███▋      | 1210/3232 [12:28<20:09,  1.67it/s]

{'loss': 0.4869, 'grad_norm': 4.598926067352295, 'learning_rate': 0.00027922971114167813, 'epoch': 2.99}


                                                   
 38%|███▊      | 1213/3232 [12:48<20:13,  1.66it/s]

{'eval_loss': 0.16818635165691376, 'eval_accuracy': 0.9537273078408652, 'eval_runtime': 17.1008, 'eval_samples_per_second': 756.981, 'eval_steps_per_second': 23.683, 'epoch': 3.0}


 38%|███▊      | 1220/3232 [12:52<41:24,  1.24s/it]  

{'loss': 0.416, 'grad_norm': 10.37142562866211, 'learning_rate': 0.0002778541953232462, 'epoch': 3.02}


 38%|███▊      | 1230/3232 [12:58<20:32,  1.62it/s]

{'loss': 0.4935, 'grad_norm': 5.0555596351623535, 'learning_rate': 0.0002764786795048143, 'epoch': 3.04}


 38%|███▊      | 1240/3232 [13:04<19:54,  1.67it/s]

{'loss': 0.4739, 'grad_norm': 4.25599479675293, 'learning_rate': 0.00027510316368638244, 'epoch': 3.07}


 39%|███▊      | 1250/3232 [13:10<19:26,  1.70it/s]

{'loss': 0.4872, 'grad_norm': 4.245752811431885, 'learning_rate': 0.0002737276478679505, 'epoch': 3.09}


 39%|███▉      | 1260/3232 [13:16<19:21,  1.70it/s]

{'loss': 0.4317, 'grad_norm': 3.7415523529052734, 'learning_rate': 0.0002723521320495186, 'epoch': 3.11}


 39%|███▉      | 1270/3232 [13:22<19:18,  1.69it/s]

{'loss': 0.4311, 'grad_norm': 4.422198295593262, 'learning_rate': 0.0002709766162310867, 'epoch': 3.14}


 40%|███▉      | 1280/3232 [13:28<19:15,  1.69it/s]

{'loss': 0.5054, 'grad_norm': 4.804646968841553, 'learning_rate': 0.00026960110041265476, 'epoch': 3.16}


 40%|███▉      | 1290/3232 [13:34<19:13,  1.68it/s]

{'loss': 0.4831, 'grad_norm': 4.000810623168945, 'learning_rate': 0.00026822558459422283, 'epoch': 3.19}


 40%|████      | 1300/3232 [13:40<19:00,  1.69it/s]

{'loss': 0.42, 'grad_norm': 3.658514976501465, 'learning_rate': 0.00026685006877579095, 'epoch': 3.21}


 41%|████      | 1310/3232 [13:46<19:20,  1.66it/s]

{'loss': 0.462, 'grad_norm': 4.942215442657471, 'learning_rate': 0.000265474552957359, 'epoch': 3.24}


 41%|████      | 1320/3232 [13:52<19:03,  1.67it/s]

{'loss': 0.4114, 'grad_norm': 3.0649726390838623, 'learning_rate': 0.0002640990371389271, 'epoch': 3.26}


 41%|████      | 1330/3232 [13:58<18:56,  1.67it/s]

{'loss': 0.439, 'grad_norm': 4.520589351654053, 'learning_rate': 0.0002627235213204952, 'epoch': 3.29}


 41%|████▏     | 1340/3232 [14:04<18:46,  1.68it/s]

{'loss': 0.4153, 'grad_norm': 3.6143782138824463, 'learning_rate': 0.0002613480055020633, 'epoch': 3.31}


 42%|████▏     | 1350/3232 [14:10<18:45,  1.67it/s]

{'loss': 0.4453, 'grad_norm': 5.714651107788086, 'learning_rate': 0.0002599724896836314, 'epoch': 3.34}


 42%|████▏     | 1360/3232 [14:16<18:36,  1.68it/s]

{'loss': 0.4077, 'grad_norm': 5.494105339050293, 'learning_rate': 0.00025859697386519946, 'epoch': 3.36}


 42%|████▏     | 1370/3232 [14:22<18:33,  1.67it/s]

{'loss': 0.4094, 'grad_norm': 8.837018013000488, 'learning_rate': 0.0002572214580467676, 'epoch': 3.39}


 43%|████▎     | 1380/3232 [14:28<18:24,  1.68it/s]

{'loss': 0.391, 'grad_norm': 3.669700860977173, 'learning_rate': 0.00025584594222833565, 'epoch': 3.41}


 43%|████▎     | 1390/3232 [14:34<18:16,  1.68it/s]

{'loss': 0.4276, 'grad_norm': 4.931869983673096, 'learning_rate': 0.0002544704264099037, 'epoch': 3.44}


 43%|████▎     | 1400/3232 [14:40<18:12,  1.68it/s]

{'loss': 0.447, 'grad_norm': 4.233998775482178, 'learning_rate': 0.0002530949105914718, 'epoch': 3.46}


 44%|████▎     | 1410/3232 [14:46<18:02,  1.68it/s]

{'loss': 0.3877, 'grad_norm': 4.47764778137207, 'learning_rate': 0.0002517193947730399, 'epoch': 3.49}


 44%|████▍     | 1420/3232 [14:52<18:00,  1.68it/s]

{'loss': 0.4393, 'grad_norm': 3.8606133460998535, 'learning_rate': 0.000250343878954608, 'epoch': 3.51}


 44%|████▍     | 1430/3232 [14:58<17:58,  1.67it/s]

{'loss': 0.4643, 'grad_norm': 4.885666370391846, 'learning_rate': 0.00024896836313617604, 'epoch': 3.54}


 45%|████▍     | 1440/3232 [15:04<17:48,  1.68it/s]

{'loss': 0.4583, 'grad_norm': 3.351395845413208, 'learning_rate': 0.00024759284731774416, 'epoch': 3.56}


 45%|████▍     | 1450/3232 [15:10<17:43,  1.68it/s]

{'loss': 0.4538, 'grad_norm': 4.091648578643799, 'learning_rate': 0.0002462173314993123, 'epoch': 3.58}


 45%|████▌     | 1460/3232 [15:16<17:34,  1.68it/s]

{'loss': 0.3723, 'grad_norm': 4.727935791015625, 'learning_rate': 0.00024484181568088035, 'epoch': 3.61}


 45%|████▌     | 1470/3232 [15:22<17:31,  1.68it/s]

{'loss': 0.4396, 'grad_norm': 4.885952472686768, 'learning_rate': 0.00024346629986244845, 'epoch': 3.63}


 46%|████▌     | 1480/3232 [15:28<17:18,  1.69it/s]

{'loss': 0.4065, 'grad_norm': 6.024846076965332, 'learning_rate': 0.00024209078404401654, 'epoch': 3.66}


 46%|████▌     | 1490/3232 [15:34<17:16,  1.68it/s]

{'loss': 0.4432, 'grad_norm': 7.5125732421875, 'learning_rate': 0.0002407152682255846, 'epoch': 3.68}


 46%|████▋     | 1500/3232 [15:39<17:12,  1.68it/s]

{'loss': 0.4081, 'grad_norm': 2.3537752628326416, 'learning_rate': 0.0002393397524071527, 'epoch': 3.71}


 47%|████▋     | 1510/3232 [15:45<17:08,  1.67it/s]

{'loss': 0.3974, 'grad_norm': 3.336303949356079, 'learning_rate': 0.00023796423658872077, 'epoch': 3.73}


 47%|████▋     | 1520/3232 [15:51<17:03,  1.67it/s]

{'loss': 0.3653, 'grad_norm': 5.751186370849609, 'learning_rate': 0.00023658872077028886, 'epoch': 3.76}


 47%|████▋     | 1530/3232 [15:57<16:54,  1.68it/s]

{'loss': 0.4022, 'grad_norm': 3.378048896789551, 'learning_rate': 0.00023521320495185696, 'epoch': 3.78}


 48%|████▊     | 1540/3232 [16:03<16:44,  1.68it/s]

{'loss': 0.372, 'grad_norm': 4.358017444610596, 'learning_rate': 0.00023383768913342503, 'epoch': 3.81}


 48%|████▊     | 1550/3232 [16:09<16:41,  1.68it/s]

{'loss': 0.4079, 'grad_norm': 4.330291271209717, 'learning_rate': 0.00023246217331499312, 'epoch': 3.83}


 48%|████▊     | 1560/3232 [16:15<16:35,  1.68it/s]

{'loss': 0.3688, 'grad_norm': 2.770416498184204, 'learning_rate': 0.00023108665749656124, 'epoch': 3.86}


 49%|████▊     | 1570/3232 [16:21<16:27,  1.68it/s]

{'loss': 0.3416, 'grad_norm': 4.183420658111572, 'learning_rate': 0.00022971114167812934, 'epoch': 3.88}


 49%|████▉     | 1580/3232 [16:27<16:20,  1.68it/s]

{'loss': 0.428, 'grad_norm': 3.5428566932678223, 'learning_rate': 0.0002283356258596974, 'epoch': 3.91}


 49%|████▉     | 1590/3232 [16:33<16:15,  1.68it/s]

{'loss': 0.3531, 'grad_norm': 3.6145105361938477, 'learning_rate': 0.0002269601100412655, 'epoch': 3.93}


 50%|████▉     | 1600/3232 [16:39<16:09,  1.68it/s]

{'loss': 0.4173, 'grad_norm': 3.9028639793395996, 'learning_rate': 0.0002255845942228336, 'epoch': 3.96}


 50%|████▉     | 1610/3232 [16:45<16:00,  1.69it/s]

{'loss': 0.3761, 'grad_norm': 3.1079185009002686, 'learning_rate': 0.00022420907840440166, 'epoch': 3.98}


                                                   
 50%|█████     | 1618/3232 [17:07<15:58,  1.68it/s]

{'eval_loss': 0.1465759426355362, 'eval_accuracy': 0.9617612977983777, 'eval_runtime': 17.2877, 'eval_samples_per_second': 748.8, 'eval_steps_per_second': 23.427, 'epoch': 4.0}


 50%|█████     | 1620/3232 [17:09<1:58:48,  4.42s/it]

{'loss': 0.3587, 'grad_norm': 5.623076915740967, 'learning_rate': 0.00022283356258596975, 'epoch': 4.0}


 50%|█████     | 1630/3232 [17:15<18:40,  1.43it/s]  

{'loss': 0.3867, 'grad_norm': 6.865736484527588, 'learning_rate': 0.00022145804676753782, 'epoch': 4.03}


 51%|█████     | 1640/3232 [17:21<15:42,  1.69it/s]

{'loss': 0.379, 'grad_norm': 3.134510040283203, 'learning_rate': 0.00022008253094910591, 'epoch': 4.05}


 51%|█████     | 1650/3232 [17:27<15:30,  1.70it/s]

{'loss': 0.3629, 'grad_norm': 5.120375633239746, 'learning_rate': 0.000218707015130674, 'epoch': 4.08}


 51%|█████▏    | 1660/3232 [17:33<15:26,  1.70it/s]

{'loss': 0.3336, 'grad_norm': 3.4141664505004883, 'learning_rate': 0.00021733149931224208, 'epoch': 4.1}


 52%|█████▏    | 1670/3232 [17:39<15:35,  1.67it/s]

{'loss': 0.4424, 'grad_norm': 6.418513774871826, 'learning_rate': 0.0002159559834938102, 'epoch': 4.13}


 52%|█████▏    | 1680/3232 [17:45<15:27,  1.67it/s]

{'loss': 0.335, 'grad_norm': 2.933443546295166, 'learning_rate': 0.0002145804676753783, 'epoch': 4.15}


 52%|█████▏    | 1690/3232 [17:51<15:22,  1.67it/s]

{'loss': 0.3641, 'grad_norm': 3.9296605587005615, 'learning_rate': 0.00021320495185694639, 'epoch': 4.18}


 53%|█████▎    | 1700/3232 [17:57<15:14,  1.67it/s]

{'loss': 0.3497, 'grad_norm': 3.900831460952759, 'learning_rate': 0.00021182943603851445, 'epoch': 4.2}


 53%|█████▎    | 1710/3232 [18:03<15:08,  1.67it/s]

{'loss': 0.3494, 'grad_norm': 3.7329468727111816, 'learning_rate': 0.00021045392022008255, 'epoch': 4.23}


 53%|█████▎    | 1720/3232 [18:09<15:03,  1.67it/s]

{'loss': 0.3152, 'grad_norm': 4.55244779586792, 'learning_rate': 0.00020907840440165064, 'epoch': 4.25}


 54%|█████▎    | 1730/3232 [18:15<14:57,  1.67it/s]

{'loss': 0.3585, 'grad_norm': 3.863698720932007, 'learning_rate': 0.0002077028885832187, 'epoch': 4.28}


 54%|█████▍    | 1740/3232 [18:21<14:51,  1.67it/s]

{'loss': 0.3456, 'grad_norm': 5.10590124130249, 'learning_rate': 0.0002063273727647868, 'epoch': 4.3}


 54%|█████▍    | 1750/3232 [18:27<14:45,  1.67it/s]

{'loss': 0.3581, 'grad_norm': 4.267453193664551, 'learning_rate': 0.00020495185694635487, 'epoch': 4.33}


 54%|█████▍    | 1760/3232 [18:33<14:32,  1.69it/s]

{'loss': 0.3868, 'grad_norm': 5.717972278594971, 'learning_rate': 0.00020357634112792296, 'epoch': 4.35}


 55%|█████▍    | 1770/3232 [18:39<14:20,  1.70it/s]

{'loss': 0.3734, 'grad_norm': 5.171428203582764, 'learning_rate': 0.00020220082530949106, 'epoch': 4.38}


 55%|█████▌    | 1780/3232 [18:45<14:13,  1.70it/s]

{'loss': 0.4511, 'grad_norm': 5.725854873657227, 'learning_rate': 0.00020082530949105918, 'epoch': 4.4}


 55%|█████▌    | 1790/3232 [18:50<14:10,  1.70it/s]

{'loss': 0.3732, 'grad_norm': 2.6687159538269043, 'learning_rate': 0.00019944979367262725, 'epoch': 4.43}


 56%|█████▌    | 1800/3232 [18:56<14:19,  1.67it/s]

{'loss': 0.348, 'grad_norm': 5.693115234375, 'learning_rate': 0.00019807427785419531, 'epoch': 4.45}


 56%|█████▌    | 1810/3232 [19:02<14:20,  1.65it/s]

{'loss': 0.3457, 'grad_norm': 6.134521961212158, 'learning_rate': 0.00019669876203576344, 'epoch': 4.47}


 56%|█████▋    | 1820/3232 [19:08<14:00,  1.68it/s]

{'loss': 0.3787, 'grad_norm': 2.948979377746582, 'learning_rate': 0.0001953232462173315, 'epoch': 4.5}


 57%|█████▋    | 1830/3232 [19:14<13:52,  1.68it/s]

{'loss': 0.3903, 'grad_norm': 4.403275489807129, 'learning_rate': 0.0001939477303988996, 'epoch': 4.52}


 57%|█████▋    | 1840/3232 [19:20<13:47,  1.68it/s]

{'loss': 0.3992, 'grad_norm': 7.705042362213135, 'learning_rate': 0.0001925722145804677, 'epoch': 4.55}


 57%|█████▋    | 1850/3232 [19:26<13:41,  1.68it/s]

{'loss': 0.3534, 'grad_norm': 3.432574510574341, 'learning_rate': 0.00019119669876203576, 'epoch': 4.57}


 58%|█████▊    | 1860/3232 [19:32<13:38,  1.68it/s]

{'loss': 0.3769, 'grad_norm': 4.17689847946167, 'learning_rate': 0.00018982118294360388, 'epoch': 4.6}


 58%|█████▊    | 1870/3232 [19:38<13:32,  1.68it/s]

{'loss': 0.4061, 'grad_norm': 6.927032947540283, 'learning_rate': 0.00018844566712517195, 'epoch': 4.62}


 58%|█████▊    | 1880/3232 [19:44<13:25,  1.68it/s]

{'loss': 0.3264, 'grad_norm': 3.0488123893737793, 'learning_rate': 0.00018707015130674004, 'epoch': 4.65}


 58%|█████▊    | 1890/3232 [19:50<13:18,  1.68it/s]

{'loss': 0.3042, 'grad_norm': 3.4706907272338867, 'learning_rate': 0.0001856946354883081, 'epoch': 4.67}


 59%|█████▉    | 1900/3232 [19:56<13:13,  1.68it/s]

{'loss': 0.3221, 'grad_norm': 4.746427059173584, 'learning_rate': 0.0001843191196698762, 'epoch': 4.7}


 59%|█████▉    | 1910/3232 [20:02<13:06,  1.68it/s]

{'loss': 0.3079, 'grad_norm': 3.597108840942383, 'learning_rate': 0.0001829436038514443, 'epoch': 4.72}


 59%|█████▉    | 1920/3232 [20:08<13:00,  1.68it/s]

{'loss': 0.3388, 'grad_norm': 5.811931610107422, 'learning_rate': 0.0001815680880330124, 'epoch': 4.75}


 60%|█████▉    | 1930/3232 [20:14<12:52,  1.68it/s]

{'loss': 0.3497, 'grad_norm': 5.233875751495361, 'learning_rate': 0.00018019257221458049, 'epoch': 4.77}


 60%|██████    | 1940/3232 [20:20<12:47,  1.68it/s]

{'loss': 0.3717, 'grad_norm': 4.033114433288574, 'learning_rate': 0.00017881705639614855, 'epoch': 4.8}


 60%|██████    | 1950/3232 [20:26<12:41,  1.68it/s]

{'loss': 0.3024, 'grad_norm': 3.159924268722534, 'learning_rate': 0.00017744154057771665, 'epoch': 4.82}


 61%|██████    | 1960/3232 [20:32<12:35,  1.68it/s]

{'loss': 0.3154, 'grad_norm': 3.4622793197631836, 'learning_rate': 0.00017606602475928474, 'epoch': 4.85}


 61%|██████    | 1970/3232 [20:38<12:26,  1.69it/s]

{'loss': 0.3382, 'grad_norm': 7.271684646606445, 'learning_rate': 0.00017469050894085284, 'epoch': 4.87}


 61%|██████▏   | 1980/3232 [20:44<12:23,  1.68it/s]

{'loss': 0.3135, 'grad_norm': 4.026495933532715, 'learning_rate': 0.00017331499312242093, 'epoch': 4.89}


 62%|██████▏   | 1990/3232 [20:50<12:15,  1.69it/s]

{'loss': 0.2828, 'grad_norm': 2.9327008724212646, 'learning_rate': 0.000171939477303989, 'epoch': 4.92}


 62%|██████▏   | 2000/3232 [20:56<12:12,  1.68it/s]

{'loss': 0.3582, 'grad_norm': 4.033880710601807, 'learning_rate': 0.0001705639614855571, 'epoch': 4.94}


 62%|██████▏   | 2010/3232 [21:02<12:06,  1.68it/s]

{'loss': 0.3185, 'grad_norm': 3.6294476985931396, 'learning_rate': 0.00016918844566712516, 'epoch': 4.97}


 62%|██████▎   | 2020/3232 [21:08<12:00,  1.68it/s]

{'loss': 0.3161, 'grad_norm': 5.413112640380859, 'learning_rate': 0.00016781292984869328, 'epoch': 4.99}


                                                   
 63%|██████▎   | 2022/3232 [21:26<11:58,  1.68it/s]

{'eval_loss': 0.1182054802775383, 'eval_accuracy': 0.9684820393974507, 'eval_runtime': 17.2734, 'eval_samples_per_second': 749.419, 'eval_steps_per_second': 23.446, 'epoch': 5.0}


 63%|██████▎   | 2030/3232 [21:32<20:44,  1.04s/it]  

{'loss': 0.2925, 'grad_norm': 3.8681087493896484, 'learning_rate': 0.00016643741403026138, 'epoch': 5.02}


 63%|██████▎   | 2040/3232 [21:38<12:05,  1.64it/s]

{'loss': 0.2826, 'grad_norm': 2.823944330215454, 'learning_rate': 0.00016506189821182944, 'epoch': 5.04}


 63%|██████▎   | 2050/3232 [21:44<11:48,  1.67it/s]

{'loss': 0.3024, 'grad_norm': 5.565736293792725, 'learning_rate': 0.00016368638239339754, 'epoch': 5.07}


 64%|██████▎   | 2060/3232 [21:50<11:37,  1.68it/s]

{'loss': 0.3407, 'grad_norm': 3.541538953781128, 'learning_rate': 0.0001623108665749656, 'epoch': 5.09}


 64%|██████▍   | 2070/3232 [21:56<11:32,  1.68it/s]

{'loss': 0.2842, 'grad_norm': 2.7172749042510986, 'learning_rate': 0.0001609353507565337, 'epoch': 5.12}


 64%|██████▍   | 2080/3232 [22:02<11:24,  1.68it/s]

{'loss': 0.3475, 'grad_norm': 5.542634963989258, 'learning_rate': 0.0001595598349381018, 'epoch': 5.14}


 65%|██████▍   | 2090/3232 [22:07<11:18,  1.68it/s]

{'loss': 0.3453, 'grad_norm': 4.0136189460754395, 'learning_rate': 0.0001581843191196699, 'epoch': 5.17}


 65%|██████▍   | 2100/3232 [22:13<11:14,  1.68it/s]

{'loss': 0.2657, 'grad_norm': 6.445197582244873, 'learning_rate': 0.00015680880330123798, 'epoch': 5.19}


 65%|██████▌   | 2110/3232 [22:19<11:07,  1.68it/s]

{'loss': 0.2929, 'grad_norm': 4.4411773681640625, 'learning_rate': 0.00015543328748280605, 'epoch': 5.22}


 66%|██████▌   | 2120/3232 [22:25<11:00,  1.68it/s]

{'loss': 0.3242, 'grad_norm': 4.4857497215271, 'learning_rate': 0.00015405777166437414, 'epoch': 5.24}


 66%|██████▌   | 2130/3232 [22:31<10:55,  1.68it/s]

{'loss': 0.3133, 'grad_norm': 5.814061641693115, 'learning_rate': 0.00015268225584594224, 'epoch': 5.27}


 66%|██████▌   | 2140/3232 [22:37<10:48,  1.68it/s]

{'loss': 0.2637, 'grad_norm': 2.777866840362549, 'learning_rate': 0.00015130674002751033, 'epoch': 5.29}


 67%|██████▋   | 2150/3232 [22:43<10:41,  1.69it/s]

{'loss': 0.3237, 'grad_norm': 3.8688182830810547, 'learning_rate': 0.00014993122420907843, 'epoch': 5.32}


 67%|██████▋   | 2160/3232 [22:49<10:35,  1.69it/s]

{'loss': 0.3023, 'grad_norm': 4.033392429351807, 'learning_rate': 0.0001485557083906465, 'epoch': 5.34}


 67%|██████▋   | 2170/3232 [22:55<10:30,  1.68it/s]

{'loss': 0.3045, 'grad_norm': 2.552396535873413, 'learning_rate': 0.0001471801925722146, 'epoch': 5.36}


 67%|██████▋   | 2180/3232 [23:01<10:24,  1.69it/s]

{'loss': 0.2727, 'grad_norm': 5.039457321166992, 'learning_rate': 0.00014580467675378265, 'epoch': 5.39}


 68%|██████▊   | 2190/3232 [23:07<10:16,  1.69it/s]

{'loss': 0.3165, 'grad_norm': 4.369848251342773, 'learning_rate': 0.00014442916093535078, 'epoch': 5.41}


 68%|██████▊   | 2200/3232 [23:13<10:11,  1.69it/s]

{'loss': 0.3224, 'grad_norm': 3.644879102706909, 'learning_rate': 0.00014305364511691884, 'epoch': 5.44}


 68%|██████▊   | 2210/3232 [23:19<10:06,  1.69it/s]

{'loss': 0.2785, 'grad_norm': 3.3365724086761475, 'learning_rate': 0.00014167812929848694, 'epoch': 5.46}


 69%|██████▊   | 2220/3232 [23:25<09:59,  1.69it/s]

{'loss': 0.2794, 'grad_norm': 2.7659287452697754, 'learning_rate': 0.00014030261348005503, 'epoch': 5.49}


 69%|██████▉   | 2230/3232 [23:31<09:56,  1.68it/s]

{'loss': 0.2616, 'grad_norm': 4.156118869781494, 'learning_rate': 0.0001389270976616231, 'epoch': 5.51}


 69%|██████▉   | 2240/3232 [23:37<09:49,  1.68it/s]

{'loss': 0.28, 'grad_norm': 4.642432689666748, 'learning_rate': 0.00013755158184319122, 'epoch': 5.54}


 70%|██████▉   | 2250/3232 [23:43<09:42,  1.69it/s]

{'loss': 0.272, 'grad_norm': 3.3724241256713867, 'learning_rate': 0.0001361760660247593, 'epoch': 5.56}


 70%|██████▉   | 2260/3232 [23:49<09:36,  1.68it/s]

{'loss': 0.2787, 'grad_norm': 4.575003623962402, 'learning_rate': 0.00013480055020632738, 'epoch': 5.59}


 70%|███████   | 2270/3232 [23:55<09:32,  1.68it/s]

{'loss': 0.2698, 'grad_norm': 3.460165023803711, 'learning_rate': 0.00013342503438789548, 'epoch': 5.61}


 71%|███████   | 2280/3232 [24:01<09:25,  1.68it/s]

{'loss': 0.3069, 'grad_norm': 4.55268669128418, 'learning_rate': 0.00013204951856946354, 'epoch': 5.64}


 71%|███████   | 2290/3232 [24:06<09:19,  1.68it/s]

{'loss': 0.2389, 'grad_norm': 3.582916498184204, 'learning_rate': 0.00013067400275103164, 'epoch': 5.66}


 71%|███████   | 2300/3232 [24:12<09:14,  1.68it/s]

{'loss': 0.2987, 'grad_norm': 3.2755355834960938, 'learning_rate': 0.00012929848693259973, 'epoch': 5.69}


 71%|███████▏  | 2310/3232 [24:18<09:07,  1.68it/s]

{'loss': 0.274, 'grad_norm': 2.8251352310180664, 'learning_rate': 0.00012792297111416783, 'epoch': 5.71}


 72%|███████▏  | 2320/3232 [24:24<09:01,  1.68it/s]

{'loss': 0.2535, 'grad_norm': 3.8879854679107666, 'learning_rate': 0.0001265474552957359, 'epoch': 5.74}


 72%|███████▏  | 2330/3232 [24:30<08:55,  1.68it/s]

{'loss': 0.2794, 'grad_norm': 3.760988235473633, 'learning_rate': 0.000125171939477304, 'epoch': 5.76}


 72%|███████▏  | 2340/3232 [24:36<08:49,  1.68it/s]

{'loss': 0.2927, 'grad_norm': 4.290704250335693, 'learning_rate': 0.00012379642365887208, 'epoch': 5.78}


 73%|███████▎  | 2350/3232 [24:42<08:43,  1.68it/s]

{'loss': 0.2838, 'grad_norm': 4.360713481903076, 'learning_rate': 0.00012242090784044018, 'epoch': 5.81}


 73%|███████▎  | 2360/3232 [24:48<08:37,  1.69it/s]

{'loss': 0.3289, 'grad_norm': 3.3807387351989746, 'learning_rate': 0.00012104539202200827, 'epoch': 5.83}


 73%|███████▎  | 2370/3232 [24:54<08:30,  1.69it/s]

{'loss': 0.3025, 'grad_norm': 4.347283363342285, 'learning_rate': 0.00011966987620357635, 'epoch': 5.86}


 74%|███████▎  | 2380/3232 [25:00<08:27,  1.68it/s]

{'loss': 0.2657, 'grad_norm': 4.637937068939209, 'learning_rate': 0.00011829436038514443, 'epoch': 5.88}


 74%|███████▍  | 2390/3232 [25:06<08:19,  1.68it/s]

{'loss': 0.231, 'grad_norm': 2.7580292224884033, 'learning_rate': 0.00011691884456671251, 'epoch': 5.91}


 74%|███████▍  | 2400/3232 [25:12<08:14,  1.68it/s]

{'loss': 0.2955, 'grad_norm': 4.8379807472229, 'learning_rate': 0.00011554332874828062, 'epoch': 5.93}


 75%|███████▍  | 2410/3232 [25:18<08:08,  1.68it/s]

{'loss': 0.2446, 'grad_norm': 5.979431629180908, 'learning_rate': 0.0001141678129298487, 'epoch': 5.96}


 75%|███████▍  | 2420/3232 [25:24<08:02,  1.68it/s]

{'loss': 0.2903, 'grad_norm': 3.0931715965270996, 'learning_rate': 0.0001127922971114168, 'epoch': 5.98}


                                                   
 75%|███████▌  | 2427/3232 [25:46<07:57,  1.68it/s]

{'eval_loss': 0.09784573316574097, 'eval_accuracy': 0.9731942835071457, 'eval_runtime': 17.3614, 'eval_samples_per_second': 745.619, 'eval_steps_per_second': 23.328, 'epoch': 6.0}


 75%|███████▌  | 2430/3232 [25:48<43:58,  3.29s/it]  

{'loss': 0.2426, 'grad_norm': 3.4722445011138916, 'learning_rate': 0.00011141678129298488, 'epoch': 6.01}


 75%|███████▌  | 2440/3232 [25:54<08:45,  1.51it/s]

{'loss': 0.2776, 'grad_norm': 4.121274948120117, 'learning_rate': 0.00011004126547455296, 'epoch': 6.03}


 76%|███████▌  | 2450/3232 [26:00<07:44,  1.68it/s]

{'loss': 0.2235, 'grad_norm': 5.908088684082031, 'learning_rate': 0.00010866574965612104, 'epoch': 6.06}


 76%|███████▌  | 2460/3232 [26:06<07:43,  1.67it/s]

{'loss': 0.2208, 'grad_norm': 2.6448311805725098, 'learning_rate': 0.00010729023383768915, 'epoch': 6.08}


 76%|███████▋  | 2470/3232 [26:12<07:37,  1.66it/s]

{'loss': 0.22, 'grad_norm': 5.1669392585754395, 'learning_rate': 0.00010591471801925723, 'epoch': 6.11}


 77%|███████▋  | 2480/3232 [26:18<07:25,  1.69it/s]

{'loss': 0.2227, 'grad_norm': 4.853193283081055, 'learning_rate': 0.00010453920220082532, 'epoch': 6.13}


 77%|███████▋  | 2490/3232 [26:24<07:19,  1.69it/s]

{'loss': 0.2212, 'grad_norm': 3.4840879440307617, 'learning_rate': 0.0001031636863823934, 'epoch': 6.16}


 77%|███████▋  | 2500/3232 [26:30<07:13,  1.69it/s]

{'loss': 0.2519, 'grad_norm': 4.210199356079102, 'learning_rate': 0.00010178817056396148, 'epoch': 6.18}


 78%|███████▊  | 2510/3232 [26:36<07:06,  1.69it/s]

{'loss': 0.2427, 'grad_norm': 2.1172494888305664, 'learning_rate': 0.00010041265474552959, 'epoch': 6.21}


 78%|███████▊  | 2520/3232 [26:42<07:01,  1.69it/s]

{'loss': 0.2369, 'grad_norm': 10.798616409301758, 'learning_rate': 9.903713892709766e-05, 'epoch': 6.23}


 78%|███████▊  | 2530/3232 [26:47<06:56,  1.69it/s]

{'loss': 0.2255, 'grad_norm': 2.7213258743286133, 'learning_rate': 9.766162310866575e-05, 'epoch': 6.25}


 79%|███████▊  | 2540/3232 [26:53<06:48,  1.69it/s]

{'loss': 0.2313, 'grad_norm': 3.482661008834839, 'learning_rate': 9.628610729023385e-05, 'epoch': 6.28}


 79%|███████▉  | 2550/3232 [26:59<06:43,  1.69it/s]

{'loss': 0.2422, 'grad_norm': 3.7436490058898926, 'learning_rate': 9.491059147180194e-05, 'epoch': 6.3}


 79%|███████▉  | 2560/3232 [27:05<06:37,  1.69it/s]

{'loss': 0.2174, 'grad_norm': 5.810586452484131, 'learning_rate': 9.353507565337002e-05, 'epoch': 6.33}


 80%|███████▉  | 2570/3232 [27:11<06:30,  1.69it/s]

{'loss': 0.208, 'grad_norm': 2.752751350402832, 'learning_rate': 9.21595598349381e-05, 'epoch': 6.35}


 80%|███████▉  | 2580/3232 [27:17<06:34,  1.65it/s]

{'loss': 0.2306, 'grad_norm': 2.549732208251953, 'learning_rate': 9.07840440165062e-05, 'epoch': 6.38}


 80%|████████  | 2590/3232 [27:23<06:26,  1.66it/s]

{'loss': 0.2742, 'grad_norm': 4.091672897338867, 'learning_rate': 8.940852819807428e-05, 'epoch': 6.4}


 80%|████████  | 2600/3232 [27:29<06:20,  1.66it/s]

{'loss': 0.2233, 'grad_norm': 2.628967523574829, 'learning_rate': 8.803301237964237e-05, 'epoch': 6.43}


 81%|████████  | 2610/3232 [27:35<06:14,  1.66it/s]

{'loss': 0.246, 'grad_norm': 3.877865791320801, 'learning_rate': 8.665749656121047e-05, 'epoch': 6.45}


 81%|████████  | 2620/3232 [27:41<06:02,  1.69it/s]

{'loss': 0.2473, 'grad_norm': 4.532379627227783, 'learning_rate': 8.528198074277855e-05, 'epoch': 6.48}


 81%|████████▏ | 2630/3232 [27:47<05:57,  1.68it/s]

{'loss': 0.208, 'grad_norm': 3.508511781692505, 'learning_rate': 8.390646492434664e-05, 'epoch': 6.5}


 82%|████████▏ | 2640/3232 [27:53<05:51,  1.69it/s]

{'loss': 0.2174, 'grad_norm': 5.761270046234131, 'learning_rate': 8.253094910591472e-05, 'epoch': 6.53}


 82%|████████▏ | 2650/3232 [27:59<05:45,  1.68it/s]

{'loss': 0.2284, 'grad_norm': 3.0120561122894287, 'learning_rate': 8.11554332874828e-05, 'epoch': 6.55}


 82%|████████▏ | 2660/3232 [28:05<05:38,  1.69it/s]

{'loss': 0.2303, 'grad_norm': 4.492600917816162, 'learning_rate': 7.97799174690509e-05, 'epoch': 6.58}


 83%|████████▎ | 2670/3232 [28:11<05:32,  1.69it/s]

{'loss': 0.1947, 'grad_norm': 3.6690263748168945, 'learning_rate': 7.840440165061899e-05, 'epoch': 6.6}


 83%|████████▎ | 2680/3232 [28:17<05:26,  1.69it/s]

{'loss': 0.2266, 'grad_norm': 3.327747106552124, 'learning_rate': 7.702888583218707e-05, 'epoch': 6.63}


 83%|████████▎ | 2690/3232 [28:23<05:19,  1.69it/s]

{'loss': 0.1906, 'grad_norm': 3.094400405883789, 'learning_rate': 7.565337001375517e-05, 'epoch': 6.65}


 84%|████████▎ | 2700/3232 [28:29<05:14,  1.69it/s]

{'loss': 0.2276, 'grad_norm': 5.2836503982543945, 'learning_rate': 7.427785419532325e-05, 'epoch': 6.67}


 84%|████████▍ | 2710/3232 [28:35<05:08,  1.69it/s]

{'loss': 0.2296, 'grad_norm': 3.34639310836792, 'learning_rate': 7.290233837689133e-05, 'epoch': 6.7}


 84%|████████▍ | 2720/3232 [28:40<05:02,  1.69it/s]

{'loss': 0.2072, 'grad_norm': 5.160595893859863, 'learning_rate': 7.152682255845942e-05, 'epoch': 6.72}


 84%|████████▍ | 2730/3232 [28:46<04:56,  1.69it/s]

{'loss': 0.2334, 'grad_norm': 1.8365404605865479, 'learning_rate': 7.015130674002752e-05, 'epoch': 6.75}


 85%|████████▍ | 2740/3232 [28:52<04:49,  1.70it/s]

{'loss': 0.2346, 'grad_norm': 2.1697983741760254, 'learning_rate': 6.877579092159561e-05, 'epoch': 6.77}


 85%|████████▌ | 2750/3232 [28:58<04:45,  1.69it/s]

{'loss': 0.1931, 'grad_norm': 3.6759355068206787, 'learning_rate': 6.740027510316369e-05, 'epoch': 6.8}


 85%|████████▌ | 2760/3232 [29:04<04:38,  1.70it/s]

{'loss': 0.2755, 'grad_norm': 2.9860363006591797, 'learning_rate': 6.602475928473177e-05, 'epoch': 6.82}


 86%|████████▌ | 2770/3232 [29:10<04:33,  1.69it/s]

{'loss': 0.2461, 'grad_norm': 2.547865629196167, 'learning_rate': 6.464924346629987e-05, 'epoch': 6.85}


 86%|████████▌ | 2780/3232 [29:16<04:27,  1.69it/s]

{'loss': 0.2061, 'grad_norm': 3.087374448776245, 'learning_rate': 6.327372764786795e-05, 'epoch': 6.87}


 86%|████████▋ | 2790/3232 [29:22<04:21,  1.69it/s]

{'loss': 0.2152, 'grad_norm': 2.9787306785583496, 'learning_rate': 6.189821182943604e-05, 'epoch': 6.9}


 87%|████████▋ | 2800/3232 [29:28<04:16,  1.68it/s]

{'loss': 0.2375, 'grad_norm': 3.3239495754241943, 'learning_rate': 6.0522696011004135e-05, 'epoch': 6.92}


 87%|████████▋ | 2810/3232 [29:34<04:10,  1.69it/s]

{'loss': 0.1984, 'grad_norm': 2.785007953643799, 'learning_rate': 5.9147180192572216e-05, 'epoch': 6.95}


 87%|████████▋ | 2820/3232 [29:40<04:03,  1.69it/s]

{'loss': 0.2347, 'grad_norm': 5.420063495635986, 'learning_rate': 5.777166437414031e-05, 'epoch': 6.97}


 88%|████████▊ | 2830/3232 [29:46<04:03,  1.65it/s]

{'loss': 0.2332, 'grad_norm': 3.4585187435150146, 'learning_rate': 5.63961485557084e-05, 'epoch': 7.0}


                                                   
 88%|████████▊ | 2831/3232 [30:04<04:03,  1.65it/s]

{'eval_loss': 0.08941604942083359, 'eval_accuracy': 0.9754345307068366, 'eval_runtime': 17.1183, 'eval_samples_per_second': 756.209, 'eval_steps_per_second': 23.659, 'epoch': 7.0}


 88%|████████▊ | 2840/3232 [30:10<05:54,  1.10it/s]

{'loss': 0.212, 'grad_norm': 2.464061975479126, 'learning_rate': 5.502063273727648e-05, 'epoch': 7.02}


 88%|████████▊ | 2850/3232 [30:16<03:49,  1.66it/s]

{'loss': 0.2118, 'grad_norm': 2.094383955001831, 'learning_rate': 5.364511691884457e-05, 'epoch': 7.05}


 88%|████████▊ | 2860/3232 [30:22<03:40,  1.69it/s]

{'loss': 0.1932, 'grad_norm': 2.1360201835632324, 'learning_rate': 5.226960110041266e-05, 'epoch': 7.07}


 89%|████████▉ | 2870/3232 [30:28<03:34,  1.68it/s]

{'loss': 0.2168, 'grad_norm': 4.077985763549805, 'learning_rate': 5.089408528198074e-05, 'epoch': 7.1}


 89%|████████▉ | 2880/3232 [30:34<03:28,  1.69it/s]

{'loss': 0.2228, 'grad_norm': 3.867352247238159, 'learning_rate': 4.951856946354883e-05, 'epoch': 7.12}


 89%|████████▉ | 2890/3232 [30:39<03:22,  1.69it/s]

{'loss': 0.2277, 'grad_norm': 2.490651845932007, 'learning_rate': 4.814305364511692e-05, 'epoch': 7.14}


 90%|████████▉ | 2900/3232 [30:45<03:16,  1.69it/s]

{'loss': 0.1747, 'grad_norm': 2.6401093006134033, 'learning_rate': 4.676753782668501e-05, 'epoch': 7.17}


 90%|█████████ | 2910/3232 [30:51<03:10,  1.69it/s]

{'loss': 0.2053, 'grad_norm': 3.8791134357452393, 'learning_rate': 4.53920220082531e-05, 'epoch': 7.19}


 90%|█████████ | 2920/3232 [30:57<03:04,  1.69it/s]

{'loss': 0.1762, 'grad_norm': 3.06819748878479, 'learning_rate': 4.4016506189821186e-05, 'epoch': 7.22}


 91%|█████████ | 2930/3232 [31:03<02:58,  1.69it/s]

{'loss': 0.1951, 'grad_norm': 3.5318076610565186, 'learning_rate': 4.264099037138927e-05, 'epoch': 7.24}


 91%|█████████ | 2940/3232 [31:09<02:52,  1.69it/s]

{'loss': 0.1884, 'grad_norm': 2.981812000274658, 'learning_rate': 4.126547455295736e-05, 'epoch': 7.27}


 91%|█████████▏| 2950/3232 [31:15<02:47,  1.69it/s]

{'loss': 0.1873, 'grad_norm': 4.401243209838867, 'learning_rate': 3.988995873452545e-05, 'epoch': 7.29}


 92%|█████████▏| 2960/3232 [31:21<02:40,  1.69it/s]

{'loss': 0.1706, 'grad_norm': 1.5826785564422607, 'learning_rate': 3.8514442916093536e-05, 'epoch': 7.32}


 92%|█████████▏| 2970/3232 [31:27<02:35,  1.69it/s]

{'loss': 0.2091, 'grad_norm': 8.006757736206055, 'learning_rate': 3.713892709766162e-05, 'epoch': 7.34}


 92%|█████████▏| 2980/3232 [31:33<02:30,  1.68it/s]

{'loss': 0.2156, 'grad_norm': 3.2715542316436768, 'learning_rate': 3.576341127922971e-05, 'epoch': 7.37}


 93%|█████████▎| 2990/3232 [31:39<02:24,  1.68it/s]

{'loss': 0.2007, 'grad_norm': 3.6225833892822266, 'learning_rate': 3.4387895460797805e-05, 'epoch': 7.39}


 93%|█████████▎| 3000/3232 [31:45<02:17,  1.69it/s]

{'loss': 0.2278, 'grad_norm': 2.9611923694610596, 'learning_rate': 3.3012379642365886e-05, 'epoch': 7.42}


 93%|█████████▎| 3010/3232 [31:51<02:12,  1.68it/s]

{'loss': 0.2034, 'grad_norm': 5.066888332366943, 'learning_rate': 3.163686382393397e-05, 'epoch': 7.44}


 93%|█████████▎| 3020/3232 [31:57<02:06,  1.68it/s]

{'loss': 0.2474, 'grad_norm': 3.5787014961242676, 'learning_rate': 3.0261348005502068e-05, 'epoch': 7.47}


 94%|█████████▍| 3030/3232 [32:03<01:59,  1.69it/s]

{'loss': 0.1799, 'grad_norm': 2.551879644393921, 'learning_rate': 2.8885832187070155e-05, 'epoch': 7.49}


 94%|█████████▍| 3040/3232 [32:09<01:55,  1.66it/s]

{'loss': 0.1902, 'grad_norm': 2.4168405532836914, 'learning_rate': 2.751031636863824e-05, 'epoch': 7.52}


 94%|█████████▍| 3050/3232 [32:15<01:49,  1.66it/s]

{'loss': 0.2092, 'grad_norm': 5.827054023742676, 'learning_rate': 2.613480055020633e-05, 'epoch': 7.54}


 95%|█████████▍| 3060/3232 [32:21<01:42,  1.68it/s]

{'loss': 0.2195, 'grad_norm': 1.6788429021835327, 'learning_rate': 2.4759284731774414e-05, 'epoch': 7.56}


 95%|█████████▍| 3070/3232 [32:27<01:36,  1.68it/s]

{'loss': 0.1949, 'grad_norm': 3.7941503524780273, 'learning_rate': 2.3383768913342505e-05, 'epoch': 7.59}


 95%|█████████▌| 3080/3232 [32:33<01:30,  1.69it/s]

{'loss': 0.1835, 'grad_norm': 1.877219557762146, 'learning_rate': 2.2008253094910593e-05, 'epoch': 7.61}


 96%|█████████▌| 3090/3232 [32:39<01:25,  1.67it/s]

{'loss': 0.1989, 'grad_norm': 3.492435932159424, 'learning_rate': 2.063273727647868e-05, 'epoch': 7.64}


 96%|█████████▌| 3100/3232 [32:45<01:18,  1.67it/s]

{'loss': 0.1691, 'grad_norm': 2.614943504333496, 'learning_rate': 1.9257221458046768e-05, 'epoch': 7.66}


 96%|█████████▌| 3110/3232 [32:50<01:11,  1.70it/s]

{'loss': 0.2227, 'grad_norm': 4.166014671325684, 'learning_rate': 1.7881705639614855e-05, 'epoch': 7.69}


 97%|█████████▋| 3120/3232 [32:56<01:05,  1.70it/s]

{'loss': 0.2017, 'grad_norm': 2.6938695907592773, 'learning_rate': 1.6506189821182943e-05, 'epoch': 7.71}


 97%|█████████▋| 3130/3232 [33:02<01:00,  1.70it/s]

{'loss': 0.2168, 'grad_norm': 2.635915517807007, 'learning_rate': 1.5130674002751034e-05, 'epoch': 7.74}


 97%|█████████▋| 3140/3232 [33:08<00:54,  1.70it/s]

{'loss': 0.1659, 'grad_norm': 4.148288726806641, 'learning_rate': 1.375515818431912e-05, 'epoch': 7.76}


 97%|█████████▋| 3150/3232 [33:14<00:48,  1.69it/s]

{'loss': 0.1808, 'grad_norm': 2.8882551193237305, 'learning_rate': 1.2379642365887207e-05, 'epoch': 7.79}


 98%|█████████▊| 3160/3232 [33:20<00:42,  1.70it/s]

{'loss': 0.1769, 'grad_norm': 3.2612431049346924, 'learning_rate': 1.1004126547455296e-05, 'epoch': 7.81}


 98%|█████████▊| 3170/3232 [33:26<00:36,  1.70it/s]

{'loss': 0.1855, 'grad_norm': 4.676534652709961, 'learning_rate': 9.628610729023384e-06, 'epoch': 7.84}


 98%|█████████▊| 3180/3232 [33:32<00:30,  1.70it/s]

{'loss': 0.1745, 'grad_norm': 1.97274649143219, 'learning_rate': 8.253094910591471e-06, 'epoch': 7.86}


 99%|█████████▊| 3190/3232 [33:38<00:24,  1.70it/s]

{'loss': 0.1733, 'grad_norm': 2.3577656745910645, 'learning_rate': 6.87757909215956e-06, 'epoch': 7.89}


 99%|█████████▉| 3200/3232 [33:44<00:18,  1.69it/s]

{'loss': 0.1715, 'grad_norm': 2.458008289337158, 'learning_rate': 5.502063273727648e-06, 'epoch': 7.91}


 99%|█████████▉| 3210/3232 [33:50<00:13,  1.67it/s]

{'loss': 0.1796, 'grad_norm': 4.103083610534668, 'learning_rate': 4.126547455295736e-06, 'epoch': 7.94}


100%|█████████▉| 3220/3232 [33:56<00:07,  1.67it/s]

{'loss': 0.1777, 'grad_norm': 2.305513620376587, 'learning_rate': 2.751031636863824e-06, 'epoch': 7.96}


100%|█████████▉| 3230/3232 [34:02<00:01,  1.67it/s]

{'loss': 0.1961, 'grad_norm': 1.7054755687713623, 'learning_rate': 1.375515818431912e-06, 'epoch': 7.99}


                                                   
100%|██████████| 3232/3232 [34:20<00:00,  1.67it/s]

{'eval_loss': 0.08352439105510712, 'eval_accuracy': 0.977211278485902, 'eval_runtime': 16.9383, 'eval_samples_per_second': 764.242, 'eval_steps_per_second': 23.91, 'epoch': 7.99}


100%|██████████| 3232/3232 [34:21<00:00,  1.57it/s]

{'train_runtime': 2061.452, 'train_samples_per_second': 200.93, 'train_steps_per_second': 1.568, 'train_loss': 0.5422990972379057, 'epoch': 7.99}





TrainOutput(global_step=3232, training_loss=0.5422990972379057, metrics={'train_runtime': 2061.452, 'train_samples_per_second': 200.93, 'train_steps_per_second': 1.568, 'total_flos': 3.756079796453376e+18, 'train_loss': 0.5422990972379057, 'epoch': 7.990111248454882})

In [3]:
# Build the HF audio-classification pipeline from the final Trainer checkpoint
# (global_step=3232 in the training log above); reused by the cells below.
import csv

classifier = pipeline(
    task="audio-classification",
    model="GFG_commands_audio_classification_finetuned/checkpoint-3232",
)

def predict(audio_file, classifier, known_threshold=0.5, unknown_threshold=0.25):
    """Classify one audio clip and map low-confidence results to fallback labels.

    `classifier` is called once with `audio_file`. It must return a sequence of
    ``{'label': ..., 'score': ...}`` dicts. Only the first entry is inspected,
    so it is assumed to be the highest-scoring one (presumably the ordering of
    the Hugging Face audio-classification pipeline — TODO confirm).

    Args:
        audio_file: Path (or any other input accepted by `classifier`).
        classifier: Callable returning a sequence of prediction dicts.
        known_threshold: Top scores strictly above this return the model label.
        unknown_threshold: Top scores in ``(unknown_threshold, known_threshold]``
            return ``'unknown'``. Scores at or below it return ``'silence'``.

    Returns:
        The predicted label, ``'unknown'`` or ``'silence'``.
    """
    prediction = classifier(audio_file)
    if not prediction:
        # Nothing to rank (empty or None result): treat the clip as an
        # unrecognised word rather than crashing the whole inference loop.
        print(f'no prediction for {audio_file}')
        return 'unknown'
    top = prediction[0]
    score = top['score']
    if score > known_threshold:
        return top['label']
    if score > unknown_threshold:
        return 'unknown'
    return 'silence'


# Column layout of the submission file (also written as the CSV header).
submission_df = pd.DataFrame({'fname': [], 'label': []})

# Reuses `classifier` loaded at the top of this cell; no need to load the
# checkpoint a second time.
submission_csv_file = os.path.join(
    'GFG_commands_audio_classification_finetuned', 'checkpoint-3232', '0_submission.csv'
)

# Only the first MAX_TEST_FILES clips are predicted; set to None for all of them.
MAX_TEST_FILES = 4000
test_files = X_files if MAX_TEST_FILES is None else X_files[:MAX_TEST_FILES]

# Open the file once in 'w' mode. Appending row-by-row (the old behaviour)
# duplicated every row whenever the cell was re-run and never wrote a header.
with open(submission_csv_file, 'w', newline='') as csvfile:
    writer = csv.writer(csvfile)
    writer.writerow(list(submission_df.columns))
    for ind, test_file in enumerate(test_files):
        if ind % 2000 == 0:
            print("{} done!".format(ind))
        label_pred = predict(os.path.join(test_dir, 'audio', test_file), classifier)
        writer.writerow([test_file, label_pred])

print('Submission saved.')

0 done!
2000 done!
Submission saved.


In [53]:
# Evaluate the fine-tuned checkpoint on the held-out split, using the same
# thresholding as the submission. Reuses `classifier` and `predict` from the
# previous cell instead of reloading the checkpoint. `dataset` and `label2id`
# come from earlier cells not shown here.
#
# Ids 0..len(label2id)-1 are the model's own labels. Two extra ids give the
# fallback outputs of `predict` their own confusion-matrix row/column.
UNKNOWN_ID = len(label2id)
SILENCE_ID = len(label2id) + 1

y_true = []
y_pred = []

for ind, example in enumerate(dataset['test']):
    if ind % 2000 == 0:
        print("{} done!".format(ind))
    y_true.append(example['label'])
    predicted_label = predict(example['audio']['path'], classifier)
    if predicted_label in label2id:
        y_pred.append(label2id[predicted_label])
    elif predicted_label == 'unknown':
        y_pred.append(UNKNOWN_ID)
    else:
        y_pred.append(SILENCE_ID)

# label2id values may be stored as strings, hence the normalisation to int.
y_pred = [int(x) for x in y_pred]
accuracy = accuracy_score(y_true, y_pred)

# Pin the label set so the matrix is always square over every class id and
# the tick labels line up. Without `labels=`, classes that never occur are
# dropped and the hard-coded tick labels no longer match the rows/columns.
id2name = {int(idx): name for name, idx in label2id.items()}
class_ids = list(range(SILENCE_ID + 1))
class_names = [id2name.get(i, str(i)) for i in range(UNKNOWN_ID)] + ['unknown', 'silence']

# Confusion Matrix
cm = confusion_matrix(y_true, y_pred, labels=class_ids)
print(f'Accuracy: {accuracy * 100:.2f}%')

fig, ax = plt.subplots(figsize=(14, 12))
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
            xticklabels=class_names, yticklabels=class_names, ax=ax)
ax.set_xlabel('Predicted')
ax.set_ylabel('True')
ax.set_title(f'Transformer Confusion Matrix. Accuracy: {accuracy * 100:.2f}%')
path = 'GFG_commands_audio_classification_finetuned/checkpoint-3232/0_confusion_matrix.png'
fig.savefig(path)
print(f'Confusion matrix is saved to: {path}')
plt.close(fig)

0 done!
2000 done!
4000 done!
6000 done!
8000 done!
10000 done!
12000 done!
Accuracy: 97.95%
Confusion matrix is saved to: GFG_commands_audio_classification_finetuned/checkpoint-3232/0_confusion_matrix.png
