In [6]:
import numpy as np
import torch
from datasets import load_dataset, Value, Sequence
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer,
    DefaultDataCollator,
    DebertaV2Tokenizer
)
import pandas as pd
from tabulate import tabulate
from sklearn.metrics import f1_score

# ------------------------------
# CUDA Initialization Check
# ------------------------------
if torch.cuda.is_available():
    try:
        torch.cuda.init()  # Force CUDA initialization (optional)
    except Exception as e:
        print("Error during CUDA initialization:", e)
    print("CUDA is available. Running on GPU.")
    try:
        print("GPU Name:", torch.cuda.get_device_name(0))
    except Exception as e:
        print("Could not get GPU Name:", e)
    print("GPU Count:", torch.cuda.device_count())
else:
    print("CUDA is not available. Running on CPU.")

# ------------------------------
# 1. Load the GoEmotions Dataset
# ------------------------------
dataset = load_dataset("go_emotions")
print("\nSample from training set:")
print(dataset["train"][0])  # Show a sample to inspect

# Number of labels
num_labels = dataset["train"].features["labels"].feature.num_classes
print("Number of labels:", num_labels)

# ------------------------------
# 2. Data Preprocessing
# ------------------------------
def convert_labels(label_list):
    """Convert list of label indices to a multi-hot vector of floats."""
    return [1.0 if i in label_list else 0.0 for i in range(num_labels)]

model_checkpoint = "microsoft/deberta-v3-base"
# Use a stronger checkpoint such as microsoft/deberta-v3-base
tokenizer = DebertaV2Tokenizer.from_pretrained("microsoft/deberta-v3-base")
max_length = 128  # You can experiment with longer sequences

def preprocess_function(examples):
    tokenized = tokenizer(
        examples["text"],
        padding="max_length",
        truncation=True,
        max_length=max_length
    )
    tokenized["labels"] = [convert_labels(lbls) for lbls in examples["labels"]]
    return tokenized

encoded_dataset = dataset.map(preprocess_function, batched=True)
encoded_dataset = encoded_dataset.cast_column("labels", Sequence(Value("float32")))
encoded_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])

# ------------------------------
# 3. Model Setup
# ------------------------------
model = AutoModelForSequenceClassification.from_pretrained(
    model_checkpoint,
    num_labels=num_labels,
)
model.config.problem_type = "multi_label_classification"

if torch.cuda.is_available():
    model = model.to("cuda")

# ------------------------------
# 4. Define Compute Metrics Function
# ------------------------------
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    # Binarize predictions using threshold 0.5
    predictions = (torch.sigmoid(torch.tensor(logits)) > 0.5).int().numpy()
    # Compute subset accuracy (exact match)
    subset_accuracy = np.mean(np.all(predictions == labels, axis=1))
    # Compute micro and macro F1
    micro_f1 = f1_score(labels, predictions, average="micro", zero_division=0)
    macro_f1 = f1_score(labels, predictions, average="macro", zero_division=0)
    return {
        "subset_accuracy": subset_accuracy,
        "micro_f1": micro_f1,
        "macro_f1": macro_f1
    }

# ------------------------------
# 5. Training Arguments & Trainer Setup
# ------------------------------
training_args = TrainingArguments(
    output_dir="./results",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=5,  # Try 5-10 epochs for better performance
    weight_decay=0.01,
    logging_steps=100,
    save_total_limit=1,
    load_best_model_at_end=True,
)

def custom_data_collator(features):
    batch = DefaultDataCollator()(features)
    batch["labels"] = batch["labels"].float()
    return batch

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=encoded_dataset["train"],
    eval_dataset=encoded_dataset["validation"],
    data_collator=custom_data_collator,
    compute_metrics=compute_metrics,
)

print("\nStarting training...")
trainer.train()

# ------------------------------
# 6. Evaluation
# ------------------------------
print("\nEvaluating on validation set...")
val_results = trainer.evaluate()
df_val_results = pd.DataFrame([val_results])
print("Validation Results:")
print(tabulate(df_val_results, headers="keys", tablefmt="fancy_grid"))

print("\nEvaluating on test set...")
test_results = trainer.evaluate(eval_dataset=encoded_dataset["test"])
df_test_results = pd.DataFrame([test_results])
print("Test Results:")
print(tabulate(df_test_results, headers="keys", tablefmt="fancy_grid"))

# ------------------------------
# 7. Inference Example
# ------------------------------
def predict_emotions(text):
    inputs = tokenizer(
        text,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=max_length
    )
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits[0]
    probs = torch.sigmoid(logits).cpu().numpy()
    return probs

sample_text = "I am feeling really happy and excited about today!"
probs = predict_emotions(sample_text)
print("\nSample inference text:", sample_text)
print("Predicted probabilities for each emotion:")
print(probs)


CUDA is available. Running on GPU.
GPU Name: NVIDIA GeForce RTX 4070 Ti SUPER
GPU Count: 1

Sample from training set:
{'text': "My favourite food is anything I didn't have to cook myself.", 'labels': [27], 'id': 'eebbqej'}
Number of labels: 28


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development
Some weights of DebertaV2ForSequenceClassification were not initialized from the model checkpoint at microsoft/deberta-v3-base and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight', 'classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.



Starting training...


  1%|          | 101/13570 [00:13<27:54,  8.04it/s]

{'loss': 0.3424, 'learning_rate': 1.9852616064848935e-05, 'epoch': 0.04}


  1%|▏         | 201/13570 [00:26<27:33,  8.09it/s]

{'loss': 0.1575, 'learning_rate': 1.9705232129697864e-05, 'epoch': 0.07}


  2%|▏         | 301/13570 [00:38<29:07,  7.59it/s]

{'loss': 0.1489, 'learning_rate': 1.9557848194546797e-05, 'epoch': 0.11}


  3%|▎         | 401/13570 [00:51<27:12,  8.07it/s]

{'loss': 0.1452, 'learning_rate': 1.9410464259395727e-05, 'epoch': 0.15}


  4%|▎         | 501/13570 [01:03<26:55,  8.09it/s]

{'loss': 0.1385, 'learning_rate': 1.926308032424466e-05, 'epoch': 0.18}


  4%|▍         | 601/13570 [01:16<26:35,  8.13it/s]

{'loss': 0.1353, 'learning_rate': 1.911569638909359e-05, 'epoch': 0.22}


  5%|▌         | 701/13570 [01:28<26:33,  8.08it/s]

{'loss': 0.1295, 'learning_rate': 1.8968312453942523e-05, 'epoch': 0.26}


  6%|▌         | 801/13570 [01:40<26:39,  7.98it/s]

{'loss': 0.1268, 'learning_rate': 1.8820928518791452e-05, 'epoch': 0.29}


  7%|▋         | 901/13570 [01:53<26:12,  8.06it/s]

{'loss': 0.1194, 'learning_rate': 1.8673544583640385e-05, 'epoch': 0.33}


  7%|▋         | 1001/13570 [02:05<27:22,  7.65it/s]

{'loss': 0.1162, 'learning_rate': 1.8526160648489315e-05, 'epoch': 0.37}


  8%|▊         | 1101/13570 [02:18<26:09,  7.94it/s]

{'loss': 0.1158, 'learning_rate': 1.8378776713338248e-05, 'epoch': 0.41}


  9%|▉         | 1201/13570 [02:31<26:42,  7.72it/s]

{'loss': 0.1114, 'learning_rate': 1.8231392778187178e-05, 'epoch': 0.44}


 10%|▉         | 1301/13570 [02:43<25:31,  8.01it/s]

{'loss': 0.1118, 'learning_rate': 1.808400884303611e-05, 'epoch': 0.48}


 10%|█         | 1401/13570 [02:56<25:15,  8.03it/s]

{'loss': 0.1043, 'learning_rate': 1.7936624907885044e-05, 'epoch': 0.52}


 11%|█         | 1501/13570 [03:08<25:01,  8.04it/s]

{'loss': 0.1081, 'learning_rate': 1.7789240972733973e-05, 'epoch': 0.55}


 12%|█▏        | 1601/13570 [03:21<24:56,  8.00it/s]

{'loss': 0.1054, 'learning_rate': 1.7641857037582906e-05, 'epoch': 0.59}


 13%|█▎        | 1701/13570 [03:33<24:31,  8.06it/s]

{'loss': 0.0995, 'learning_rate': 1.7494473102431836e-05, 'epoch': 0.63}


 13%|█▎        | 1801/13570 [03:45<24:15,  8.09it/s]

{'loss': 0.1042, 'learning_rate': 1.734708916728077e-05, 'epoch': 0.66}


 14%|█▍        | 1901/13570 [03:58<24:22,  7.98it/s]

{'loss': 0.1026, 'learning_rate': 1.71997052321297e-05, 'epoch': 0.7}


 15%|█▍        | 2001/13570 [04:10<24:22,  7.91it/s]

{'loss': 0.1, 'learning_rate': 1.7052321296978632e-05, 'epoch': 0.74}


 15%|█▌        | 2101/13570 [04:23<23:58,  7.97it/s]

{'loss': 0.1028, 'learning_rate': 1.690493736182756e-05, 'epoch': 0.77}


 16%|█▌        | 2201/13570 [04:35<23:36,  8.02it/s]

{'loss': 0.0994, 'learning_rate': 1.6757553426676494e-05, 'epoch': 0.81}


 17%|█▋        | 2301/13570 [04:48<23:07,  8.12it/s]

{'loss': 0.0992, 'learning_rate': 1.6610169491525424e-05, 'epoch': 0.85}


 18%|█▊        | 2401/13570 [05:00<23:12,  8.02it/s]

{'loss': 0.0978, 'learning_rate': 1.6462785556374357e-05, 'epoch': 0.88}


 18%|█▊        | 2501/13570 [05:13<23:09,  7.97it/s]

{'loss': 0.0951, 'learning_rate': 1.6315401621223287e-05, 'epoch': 0.92}


 19%|█▉        | 2601/13570 [05:25<22:36,  8.08it/s]

{'loss': 0.0942, 'learning_rate': 1.616801768607222e-05, 'epoch': 0.96}


 20%|█▉        | 2701/13570 [05:38<23:31,  7.70it/s]

{'loss': 0.097, 'learning_rate': 1.602063375092115e-05, 'epoch': 0.99}


                                                    
 20%|██        | 2714/13570 [05:51<22:51,  7.92it/s]

{'eval_loss': 0.09200803190469742, 'eval_subset_accuracy': 0.40600810910431256, 'eval_micro_f1': 0.528171106050104, 'eval_macro_f1': 0.29070297317790333, 'eval_runtime': 11.3419, 'eval_samples_per_second': 478.402, 'eval_steps_per_second': 29.977, 'epoch': 1.0}


 21%|██        | 2801/13570 [06:04<22:30,  7.98it/s]  

{'loss': 0.0945, 'learning_rate': 1.5873249815770082e-05, 'epoch': 1.03}


 21%|██▏       | 2901/13570 [06:16<22:20,  7.96it/s]

{'loss': 0.0908, 'learning_rate': 1.5725865880619012e-05, 'epoch': 1.07}


 22%|██▏       | 3001/13570 [06:29<22:03,  7.98it/s]

{'loss': 0.0935, 'learning_rate': 1.5578481945467945e-05, 'epoch': 1.11}


 23%|██▎       | 3101/13570 [06:41<22:03,  7.91it/s]

{'loss': 0.0882, 'learning_rate': 1.5431098010316875e-05, 'epoch': 1.14}


 24%|██▎       | 3201/13570 [06:54<21:23,  8.08it/s]

{'loss': 0.0894, 'learning_rate': 1.5283714075165808e-05, 'epoch': 1.18}


 24%|██▍       | 3301/13570 [07:06<21:18,  8.03it/s]

{'loss': 0.0853, 'learning_rate': 1.5136330140014739e-05, 'epoch': 1.22}


 25%|██▌       | 3401/13570 [07:19<20:59,  8.07it/s]

{'loss': 0.0878, 'learning_rate': 1.498894620486367e-05, 'epoch': 1.25}


 26%|██▌       | 3501/13570 [07:31<21:10,  7.92it/s]

{'loss': 0.0872, 'learning_rate': 1.4841562269712602e-05, 'epoch': 1.29}


 27%|██▋       | 3601/13570 [07:44<21:38,  7.68it/s]

{'loss': 0.0883, 'learning_rate': 1.4694178334561533e-05, 'epoch': 1.33}


 27%|██▋       | 3701/13570 [07:57<20:48,  7.90it/s]

{'loss': 0.0887, 'learning_rate': 1.4546794399410464e-05, 'epoch': 1.36}


 28%|██▊       | 3801/13570 [08:09<20:07,  8.09it/s]

{'loss': 0.0892, 'learning_rate': 1.4399410464259397e-05, 'epoch': 1.4}


 29%|██▊       | 3901/13570 [08:22<20:11,  7.98it/s]

{'loss': 0.0854, 'learning_rate': 1.4252026529108329e-05, 'epoch': 1.44}


 29%|██▉       | 4001/13570 [08:34<20:03,  7.95it/s]

{'loss': 0.0887, 'learning_rate': 1.410464259395726e-05, 'epoch': 1.47}


 30%|███       | 4101/13570 [08:46<19:31,  8.08it/s]

{'loss': 0.0901, 'learning_rate': 1.3957258658806191e-05, 'epoch': 1.51}


 31%|███       | 4201/13570 [08:59<19:40,  7.93it/s]

{'loss': 0.0845, 'learning_rate': 1.3809874723655123e-05, 'epoch': 1.55}


 32%|███▏      | 4301/13570 [09:11<19:05,  8.09it/s]

{'loss': 0.0902, 'learning_rate': 1.3662490788504054e-05, 'epoch': 1.58}


 32%|███▏      | 4401/13570 [09:24<19:10,  7.97it/s]

{'loss': 0.0886, 'learning_rate': 1.3515106853352985e-05, 'epoch': 1.62}


 33%|███▎      | 4501/13570 [09:37<20:15,  7.46it/s]

{'loss': 0.0884, 'learning_rate': 1.3367722918201917e-05, 'epoch': 1.66}


 34%|███▍      | 4601/13570 [09:50<19:57,  7.49it/s]

{'loss': 0.086, 'learning_rate': 1.3220338983050848e-05, 'epoch': 1.69}


 35%|███▍      | 4701/13570 [10:03<19:37,  7.53it/s]

{'loss': 0.0862, 'learning_rate': 1.307295504789978e-05, 'epoch': 1.73}


 35%|███▌      | 4801/13570 [10:17<20:04,  7.28it/s]

{'loss': 0.0844, 'learning_rate': 1.292557111274871e-05, 'epoch': 1.77}


 36%|███▌      | 4901/13570 [10:30<19:23,  7.45it/s]

{'loss': 0.0851, 'learning_rate': 1.2778187177597642e-05, 'epoch': 1.81}


 37%|███▋      | 5001/13570 [10:43<18:47,  7.60it/s]

{'loss': 0.085, 'learning_rate': 1.2630803242446574e-05, 'epoch': 1.84}


 38%|███▊      | 5101/13570 [10:57<18:50,  7.49it/s]

{'loss': 0.0854, 'learning_rate': 1.2483419307295505e-05, 'epoch': 1.88}


 38%|███▊      | 5201/13570 [11:10<18:27,  7.56it/s]

{'loss': 0.087, 'learning_rate': 1.2336035372144436e-05, 'epoch': 1.92}


 39%|███▉      | 5301/13570 [11:23<18:13,  7.57it/s]

{'loss': 0.0872, 'learning_rate': 1.218865143699337e-05, 'epoch': 1.95}


 40%|███▉      | 5401/13570 [11:36<18:06,  7.52it/s]

{'loss': 0.0822, 'learning_rate': 1.20412675018423e-05, 'epoch': 1.99}


                                                    
 40%|████      | 5428/13570 [11:51<17:25,  7.79it/s]

{'eval_loss': 0.08527296781539917, 'eval_subset_accuracy': 0.4566900110578695, 'eval_micro_f1': 0.5731989272172385, 'eval_macro_f1': 0.40966896291164445, 'eval_runtime': 11.5131, 'eval_samples_per_second': 471.29, 'eval_steps_per_second': 29.532, 'epoch': 2.0}


 41%|████      | 5501/13570 [12:03<18:00,  7.47it/s]  

{'loss': 0.0816, 'learning_rate': 1.1893883566691232e-05, 'epoch': 2.03}


 41%|████▏     | 5601/13570 [12:17<17:37,  7.53it/s]

{'loss': 0.0793, 'learning_rate': 1.1746499631540163e-05, 'epoch': 2.06}


 42%|████▏     | 5701/13570 [12:30<17:22,  7.55it/s]

{'loss': 0.0784, 'learning_rate': 1.1599115696389095e-05, 'epoch': 2.1}


 43%|████▎     | 5801/13570 [12:43<17:11,  7.53it/s]

{'loss': 0.0748, 'learning_rate': 1.1451731761238026e-05, 'epoch': 2.14}


 43%|████▎     | 5901/13570 [12:56<16:57,  7.53it/s]

{'loss': 0.0762, 'learning_rate': 1.1304347826086957e-05, 'epoch': 2.17}


 44%|████▍     | 6001/13570 [13:10<16:44,  7.54it/s]

{'loss': 0.0772, 'learning_rate': 1.1156963890935889e-05, 'epoch': 2.21}


 45%|████▍     | 6101/13570 [13:23<16:42,  7.45it/s]

{'loss': 0.0761, 'learning_rate': 1.100957995578482e-05, 'epoch': 2.25}


 46%|████▌     | 6201/13570 [13:36<16:00,  7.67it/s]

{'loss': 0.0785, 'learning_rate': 1.0862196020633753e-05, 'epoch': 2.28}


 46%|████▋     | 6301/13570 [13:49<16:03,  7.54it/s]

{'loss': 0.0777, 'learning_rate': 1.0714812085482684e-05, 'epoch': 2.32}


 47%|████▋     | 6401/13570 [14:03<16:44,  7.14it/s]

{'loss': 0.0777, 'learning_rate': 1.0567428150331616e-05, 'epoch': 2.36}


 48%|████▊     | 6501/13570 [14:17<16:16,  7.24it/s]

{'loss': 0.08, 'learning_rate': 1.0420044215180547e-05, 'epoch': 2.39}


 49%|████▊     | 6601/13570 [14:31<16:00,  7.26it/s]

{'loss': 0.0777, 'learning_rate': 1.0272660280029478e-05, 'epoch': 2.43}


 49%|████▉     | 6701/13570 [14:44<15:07,  7.57it/s]

{'loss': 0.0797, 'learning_rate': 1.012527634487841e-05, 'epoch': 2.47}


 50%|█████     | 6801/13570 [14:57<15:01,  7.51it/s]

{'loss': 0.077, 'learning_rate': 9.977892409727341e-06, 'epoch': 2.51}


 51%|█████     | 6901/13570 [15:11<14:50,  7.49it/s]

{'loss': 0.0784, 'learning_rate': 9.830508474576272e-06, 'epoch': 2.54}


 52%|█████▏    | 7001/13570 [15:24<14:12,  7.70it/s]

{'loss': 0.0762, 'learning_rate': 9.683124539425204e-06, 'epoch': 2.58}


 52%|█████▏    | 7101/13570 [15:37<13:48,  7.81it/s]

{'loss': 0.0788, 'learning_rate': 9.535740604274135e-06, 'epoch': 2.62}


 53%|█████▎    | 7201/13570 [15:50<14:10,  7.49it/s]

{'loss': 0.0774, 'learning_rate': 9.388356669123066e-06, 'epoch': 2.65}


 54%|█████▍    | 7301/13570 [16:03<13:38,  7.66it/s]

{'loss': 0.0769, 'learning_rate': 9.240972733971998e-06, 'epoch': 2.69}


 55%|█████▍    | 7401/13570 [16:16<13:25,  7.66it/s]

{'loss': 0.0764, 'learning_rate': 9.093588798820929e-06, 'epoch': 2.73}


 55%|█████▌    | 7501/13570 [16:29<13:18,  7.60it/s]

{'loss': 0.0773, 'learning_rate': 8.94620486366986e-06, 'epoch': 2.76}


 56%|█████▌    | 7601/13570 [16:42<12:51,  7.74it/s]

{'loss': 0.0743, 'learning_rate': 8.798820928518792e-06, 'epoch': 2.8}


 57%|█████▋    | 7701/13570 [16:55<13:00,  7.52it/s]

{'loss': 0.0755, 'learning_rate': 8.651436993367723e-06, 'epoch': 2.84}


 57%|█████▋    | 7801/13570 [17:08<12:34,  7.64it/s]

{'loss': 0.0781, 'learning_rate': 8.504053058216654e-06, 'epoch': 2.87}


 58%|█████▊    | 7901/13570 [17:21<12:30,  7.55it/s]

{'loss': 0.0802, 'learning_rate': 8.356669123065586e-06, 'epoch': 2.91}


 59%|█████▉    | 8001/13570 [17:34<12:25,  7.47it/s]

{'loss': 0.0774, 'learning_rate': 8.209285187914517e-06, 'epoch': 2.95}


 60%|█████▉    | 8101/13570 [17:48<11:55,  7.65it/s]

{'loss': 0.0761, 'learning_rate': 8.06190125276345e-06, 'epoch': 2.98}


                                                    
 60%|██████    | 8142/13570 [18:05<12:05,  7.48it/s]

{'eval_loss': 0.0844298005104065, 'eval_subset_accuracy': 0.4601916697382971, 'eval_micro_f1': 0.5760800363801728, 'eval_macro_f1': 0.4411730304085348, 'eval_runtime': 11.5491, 'eval_samples_per_second': 469.819, 'eval_steps_per_second': 29.439, 'epoch': 3.0}


 60%|██████    | 8201/13570 [18:15<12:06,  7.39it/s]  

{'loss': 0.074, 'learning_rate': 7.914517317612381e-06, 'epoch': 3.02}


 61%|██████    | 8301/13570 [18:28<11:49,  7.42it/s]

{'loss': 0.0701, 'learning_rate': 7.767133382461313e-06, 'epoch': 3.06}


 62%|██████▏   | 8401/13570 [18:42<11:27,  7.52it/s]

{'loss': 0.0705, 'learning_rate': 7.619749447310244e-06, 'epoch': 3.1}


 63%|██████▎   | 8501/13570 [18:55<11:11,  7.55it/s]

{'loss': 0.0735, 'learning_rate': 7.472365512159175e-06, 'epoch': 3.13}


 63%|██████▎   | 8601/13570 [19:09<11:28,  7.21it/s]

{'loss': 0.0724, 'learning_rate': 7.324981577008107e-06, 'epoch': 3.17}


 64%|██████▍   | 8701/13570 [19:22<10:59,  7.39it/s]

{'loss': 0.0687, 'learning_rate': 7.177597641857038e-06, 'epoch': 3.21}


 65%|██████▍   | 8801/13570 [19:35<10:46,  7.37it/s]

{'loss': 0.0716, 'learning_rate': 7.030213706705969e-06, 'epoch': 3.24}


 66%|██████▌   | 8901/13570 [19:49<10:39,  7.30it/s]

{'loss': 0.0702, 'learning_rate': 6.882829771554901e-06, 'epoch': 3.28}


 66%|██████▋   | 9001/13570 [20:02<09:58,  7.63it/s]

{'loss': 0.0688, 'learning_rate': 6.735445836403832e-06, 'epoch': 3.32}


 67%|██████▋   | 9101/13570 [20:15<09:47,  7.61it/s]

{'loss': 0.0709, 'learning_rate': 6.588061901252763e-06, 'epoch': 3.35}


 68%|██████▊   | 9201/13570 [20:29<09:41,  7.51it/s]

{'loss': 0.0716, 'learning_rate': 6.440677966101695e-06, 'epoch': 3.39}


 69%|██████▊   | 9301/13570 [20:42<09:21,  7.60it/s]

{'loss': 0.071, 'learning_rate': 6.293294030950628e-06, 'epoch': 3.43}


 69%|██████▉   | 9401/13570 [20:55<09:14,  7.51it/s]

{'loss': 0.0715, 'learning_rate': 6.145910095799559e-06, 'epoch': 3.46}


 70%|███████   | 9501/13570 [21:08<08:51,  7.66it/s]

{'loss': 0.0695, 'learning_rate': 5.99852616064849e-06, 'epoch': 3.5}


 71%|███████   | 9601/13570 [21:21<08:43,  7.58it/s]

{'loss': 0.0667, 'learning_rate': 5.851142225497422e-06, 'epoch': 3.54}


 71%|███████▏  | 9701/13570 [21:34<08:40,  7.43it/s]

{'loss': 0.0666, 'learning_rate': 5.703758290346353e-06, 'epoch': 3.57}


 72%|███████▏  | 9801/13570 [21:48<08:16,  7.59it/s]

{'loss': 0.0728, 'learning_rate': 5.556374355195284e-06, 'epoch': 3.61}


 73%|███████▎  | 9901/13570 [22:01<08:00,  7.64it/s]

{'loss': 0.0709, 'learning_rate': 5.408990420044216e-06, 'epoch': 3.65}


 74%|███████▎  | 10001/13570 [22:14<07:52,  7.55it/s]

{'loss': 0.0726, 'learning_rate': 5.261606484893147e-06, 'epoch': 3.68}


 74%|███████▍  | 10101/13570 [22:27<07:38,  7.57it/s]

{'loss': 0.069, 'learning_rate': 5.114222549742078e-06, 'epoch': 3.72}


 75%|███████▌  | 10201/13570 [22:40<07:27,  7.52it/s]

{'loss': 0.068, 'learning_rate': 4.966838614591011e-06, 'epoch': 3.76}


 76%|███████▌  | 10301/13570 [22:54<07:27,  7.31it/s]

{'loss': 0.0677, 'learning_rate': 4.819454679439942e-06, 'epoch': 3.8}


 77%|███████▋  | 10401/13570 [23:07<07:00,  7.54it/s]

{'loss': 0.0673, 'learning_rate': 4.672070744288873e-06, 'epoch': 3.83}


 77%|███████▋  | 10501/13570 [23:20<07:00,  7.31it/s]

{'loss': 0.0712, 'learning_rate': 4.524686809137805e-06, 'epoch': 3.87}


 78%|███████▊  | 10601/13570 [23:33<06:33,  7.55it/s]

{'loss': 0.0715, 'learning_rate': 4.377302873986736e-06, 'epoch': 3.91}


 79%|███████▉  | 10701/13570 [23:47<06:24,  7.46it/s]

{'loss': 0.0701, 'learning_rate': 4.229918938835667e-06, 'epoch': 3.94}


 80%|███████▉  | 10801/13570 [24:00<06:04,  7.59it/s]

{'loss': 0.0716, 'learning_rate': 4.082535003684599e-06, 'epoch': 3.98}


                                                     
 80%|████████  | 10856/13570 [24:19<06:06,  7.41it/s]

{'eval_loss': 0.08593972772359848, 'eval_subset_accuracy': 0.4813859196461482, 'eval_micro_f1': 0.5938375350140056, 'eval_macro_f1': 0.4649419589856686, 'eval_runtime': 11.4394, 'eval_samples_per_second': 474.327, 'eval_steps_per_second': 29.722, 'epoch': 4.0}


 80%|████████  | 10901/13570 [24:27<06:02,  7.36it/s]  

{'loss': 0.0652, 'learning_rate': 3.93515106853353e-06, 'epoch': 4.02}


 81%|████████  | 11001/13570 [24:40<05:36,  7.64it/s]

{'loss': 0.0649, 'learning_rate': 3.7877671333824617e-06, 'epoch': 4.05}


 82%|████████▏ | 11101/13570 [24:53<05:01,  8.18it/s]

{'loss': 0.0648, 'learning_rate': 3.640383198231393e-06, 'epoch': 4.09}


 83%|████████▎ | 11201/13570 [25:05<04:53,  8.08it/s]

{'loss': 0.067, 'learning_rate': 3.4929992630803244e-06, 'epoch': 4.13}


 83%|████████▎ | 11301/13570 [25:18<04:43,  8.01it/s]

{'loss': 0.063, 'learning_rate': 3.3456153279292557e-06, 'epoch': 4.16}


 84%|████████▍ | 11401/13570 [25:30<04:28,  8.08it/s]

{'loss': 0.0665, 'learning_rate': 3.198231392778187e-06, 'epoch': 4.2}


 85%|████████▍ | 11501/13570 [25:43<04:40,  7.38it/s]

{'loss': 0.065, 'learning_rate': 3.0508474576271192e-06, 'epoch': 4.24}


 85%|████████▌ | 11601/13570 [25:56<04:21,  7.53it/s]

{'loss': 0.0646, 'learning_rate': 2.9034635224760506e-06, 'epoch': 4.27}


 86%|████████▌ | 11701/13570 [26:10<04:11,  7.44it/s]

{'loss': 0.0643, 'learning_rate': 2.756079587324982e-06, 'epoch': 4.31}


 87%|████████▋ | 11801/13570 [26:23<03:59,  7.38it/s]

{'loss': 0.0665, 'learning_rate': 2.6086956521739132e-06, 'epoch': 4.35}


 88%|████████▊ | 11901/13570 [26:36<03:45,  7.40it/s]

{'loss': 0.0621, 'learning_rate': 2.461311717022845e-06, 'epoch': 4.38}


 88%|████████▊ | 12001/13570 [26:50<03:28,  7.54it/s]

{'loss': 0.0651, 'learning_rate': 2.3139277818717763e-06, 'epoch': 4.42}


 89%|████████▉ | 12101/13570 [27:03<03:14,  7.53it/s]

{'loss': 0.0659, 'learning_rate': 2.1665438467207077e-06, 'epoch': 4.46}


 90%|████████▉ | 12201/13570 [27:16<03:02,  7.50it/s]

{'loss': 0.0604, 'learning_rate': 2.019159911569639e-06, 'epoch': 4.5}


 91%|█████████ | 12301/13570 [27:30<02:49,  7.49it/s]

{'loss': 0.0668, 'learning_rate': 1.8717759764185706e-06, 'epoch': 4.53}


 91%|█████████▏| 12401/13570 [27:43<02:35,  7.51it/s]

{'loss': 0.0616, 'learning_rate': 1.7243920412675019e-06, 'epoch': 4.57}


 92%|█████████▏| 12501/13570 [27:56<02:23,  7.45it/s]

{'loss': 0.066, 'learning_rate': 1.5770081061164336e-06, 'epoch': 4.61}


 93%|█████████▎| 12601/13570 [28:10<02:09,  7.50it/s]

{'loss': 0.065, 'learning_rate': 1.429624170965365e-06, 'epoch': 4.64}


 94%|█████████▎| 12701/13570 [28:23<01:52,  7.74it/s]

{'loss': 0.0641, 'learning_rate': 1.2822402358142963e-06, 'epoch': 4.68}


 94%|█████████▍| 12801/13570 [28:36<01:44,  7.33it/s]

{'loss': 0.064, 'learning_rate': 1.1348563006632279e-06, 'epoch': 4.72}


 95%|█████████▌| 12901/13570 [28:49<01:28,  7.53it/s]

{'loss': 0.0669, 'learning_rate': 9.874723655121592e-07, 'epoch': 4.75}


 96%|█████████▌| 13001/13570 [29:02<01:16,  7.48it/s]

{'loss': 0.0645, 'learning_rate': 8.400884303610906e-07, 'epoch': 4.79}


 97%|█████████▋| 13101/13570 [29:16<01:03,  7.39it/s]

{'loss': 0.0645, 'learning_rate': 6.927044952100222e-07, 'epoch': 4.83}


 97%|█████████▋| 13201/13570 [29:29<00:49,  7.38it/s]

{'loss': 0.0622, 'learning_rate': 5.453205600589536e-07, 'epoch': 4.86}


 98%|█████████▊| 13301/13570 [29:42<00:35,  7.56it/s]

{'loss': 0.0632, 'learning_rate': 3.97936624907885e-07, 'epoch': 4.9}


 99%|█████████▉| 13401/13570 [29:56<00:22,  7.59it/s]

{'loss': 0.063, 'learning_rate': 2.505526897568165e-07, 'epoch': 4.94}


 99%|█████████▉| 13501/13570 [30:09<00:09,  7.53it/s]

{'loss': 0.0627, 'learning_rate': 1.0316875460574797e-07, 'epoch': 4.97}


                                                     
100%|██████████| 13570/13570 [30:30<00:00,  7.55it/s]

{'eval_loss': 0.08712948858737946, 'eval_subset_accuracy': 0.4730925175082934, 'eval_micro_f1': 0.5914025184541902, 'eval_macro_f1': 0.47261614528044454, 'eval_runtime': 11.5111, 'eval_samples_per_second': 471.37, 'eval_steps_per_second': 29.537, 'epoch': 5.0}


100%|██████████| 13570/13570 [30:32<00:00,  7.55it/s]

{'train_runtime': 1832.9214, 'train_samples_per_second': 118.418, 'train_steps_per_second': 7.403, 'train_loss': 0.08442423672208736, 'epoch': 5.0}


100%|██████████| 13570/13570 [30:33<00:00,  7.40it/s]



Evaluating on validation set...


100%|██████████| 340/340 [00:11<00:00, 29.56it/s]


Validation Results:
╒════╤═════════════╤════════════════════════╤═════════════════╤═════════════════╤════════════════╤═══════════════════════════╤═════════════════════════╤═════════╕
│    │   eval_loss │   eval_subset_accuracy │   eval_micro_f1 │   eval_macro_f1 │   eval_runtime │   eval_samples_per_second │   eval_steps_per_second │   epoch │
╞════╪═════════════╪════════════════════════╪═════════════════╪═════════════════╪════════════════╪═══════════════════════════╪═════════════════════════╪═════════╡
│  0 │   0.0844298 │               0.460192 │         0.57608 │        0.441173 │        11.5908 │                   468.128 │                  29.333 │       5 │
╘════╧═════════════╧════════════════════════╧═════════════════╧═════════════════╧════════════════╧═══════════════════════════╧═════════════════════════╧═════════╛

Evaluating on test set...


100%|██████████| 340/340 [00:11<00:00, 29.77it/s]

Test Results:
╒════╤═════════════╤════════════════════════╤═════════════════╤═════════════════╤════════════════╤═══════════════════════════╤═════════════════════════╤═════════╕
│    │   eval_loss │   eval_subset_accuracy │   eval_micro_f1 │   eval_macro_f1 │   eval_runtime │   eval_samples_per_second │   eval_steps_per_second │   epoch │
╞════╪═════════════╪════════════════════════╪═════════════════╪═════════════════╪════════════════╪═══════════════════════════╪═════════════════════════╪═════════╡
│  0 │   0.0840901 │                0.46066 │        0.576743 │        0.441563 │        11.4569 │                   473.688 │                  29.676 │       5 │
╘════╧═════════════╧════════════════════════╧═════════════════╧═════════════════╧════════════════╧═══════════════════════════╧═════════════════════════╧═════════╛

Sample inference text: I am feeling really happy and excited about today!
Predicted probabilities for each emotion:
[5.21914177e-02 1.28694540e-02 8.52493919e-04 1.108821




In [5]:
# Use a specific prompt for inference
prompt_text = "I want a burger"
probs = predict_emotions(prompt_text)

# Get the label names from the dataset (GoEmotions)
label_names = dataset["train"].features["labels"].feature.names

threshold = 0.5 

# Convert probabilities to binary predictions
predicted_binary = (probs > threshold).astype(int)

# For multi-label classification, there may be multiple labels predicted.
predicted_labels = [label_names[i] for i, pred in enumerate(predicted_binary) if pred == 1]

print("\nPredicted labels (threshold={}):".format(threshold))
print(predicted_labels)



Predicted labels (threshold=0.5):
['desire']


In [3]:
import transformers
print(transformers.__version__)

4.31.0


In [7]:
trainer.save_model("./my_saved_model")
tokenizer.save_pretrained("./my_saved_model")

('./my_saved_model\\tokenizer_config.json',
 './my_saved_model\\special_tokens_map.json',
 './my_saved_model\\spm.model',
 './my_saved_model\\added_tokens.json')

In [8]:
import numpy as np
import torch
from datasets import load_dataset, Value, Sequence
from transformers import (
    DebertaV2Tokenizer,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer,
    DefaultDataCollator
)
import pandas as pd
from tabulate import tabulate
from sklearn.metrics import f1_score

# ------------------------------
# CUDA Initialization Check
# ------------------------------
if torch.cuda.is_available():
    try:
        torch.cuda.init()  # Force CUDA initialization (optional)
    except Exception as e:
        print("Error during CUDA initialization:", e)
    print("CUDA is available. Running on GPU.")
    try:
        print("GPU Name:", torch.cuda.get_device_name(0))
    except Exception as e:
        print("Could not get GPU Name:", e)
    print("GPU Count:", torch.cuda.device_count())
else:
    print("CUDA is not available. Running on CPU.")

# ------------------------------
# 1. Load the GoEmotions Dataset
# ------------------------------
dataset = load_dataset("go_emotions")
print("\nSample from training set:")
print(dataset["train"][0])  # Show a sample to inspect

# Number of labels
num_labels = dataset["train"].features["labels"].feature.num_classes
print("Number of labels:", num_labels)

# ------------------------------
# 2. Data Preprocessing
# ------------------------------
def convert_labels(label_list):
    """Convert list of label indices to a multi-hot vector of floats."""
    return [1.0 if i in label_list else 0.0 for i in range(num_labels)]

model_checkpoint = "microsoft/deberta-v3-large"  # If memory is an issue, try "microsoft/deberta-v3-base"
tokenizer = DebertaV2Tokenizer.from_pretrained(model_checkpoint)
max_length = 128

def preprocess_function(examples):
    tokenized = tokenizer(
        examples["text"],
        padding="max_length",
        truncation=True,
        max_length=max_length
    )
    tokenized["labels"] = [convert_labels(lbls) for lbls in examples["labels"]]
    return tokenized

encoded_dataset = dataset.map(preprocess_function, batched=True)
encoded_dataset = encoded_dataset.cast_column("labels", Sequence(Value("float32")))
encoded_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])

# ------------------------------
# 3. Model Setup
# ------------------------------
model = AutoModelForSequenceClassification.from_pretrained(
    model_checkpoint,
    num_labels=num_labels,
)
model.config.problem_type = "multi_label_classification"

if torch.cuda.is_available():
    model = model.to("cuda")

# ------------------------------
# 4. Define Compute Metrics Function
# ------------------------------
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    # Binarize predictions using threshold 0.5 by default
    predictions = (torch.sigmoid(torch.tensor(logits)) > 0.5).int().numpy()
    # Compute subset accuracy (exact match)
    subset_accuracy = np.mean(np.all(predictions == labels, axis=1))
    # Compute micro and macro F1
    micro_f1 = f1_score(labels, predictions, average="micro", zero_division=0)
    macro_f1 = f1_score(labels, predictions, average="macro", zero_division=0)
    return {
        "subset_accuracy": subset_accuracy,
        "micro_f1": micro_f1,
        "macro_f1": macro_f1
    }

# ------------------------------
# 5. Training Arguments & Trainer Setup
# ------------------------------
training_args = TrainingArguments(
    output_dir="./results",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=8,   # If you have enough GPU RAM, you can increase to 16
    per_device_eval_batch_size=8,
    num_train_epochs=8,             # Increase epochs for better performance
    weight_decay=0.01,
    logging_steps=100,
    save_total_limit=1,
    load_best_model_at_end=True,
)

def custom_data_collator(features):
    batch = DefaultDataCollator()(features)
    batch["labels"] = batch["labels"].float()
    return batch

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=encoded_dataset["train"],
    eval_dataset=encoded_dataset["validation"],
    data_collator=custom_data_collator,
    compute_metrics=compute_metrics,
)

print("\nStarting training...")
trainer.train()

# ------------------------------
# 6. Basic Evaluation with Default Threshold (0.5)
# ------------------------------
print("\nEvaluating on validation set with threshold=0.5...")
val_results = trainer.evaluate()
df_val_results = pd.DataFrame([val_results])
print("Validation Results:")
print(tabulate(df_val_results, headers="keys", tablefmt="fancy_grid"))

print("\nEvaluating on test set with threshold=0.5...")
test_results = trainer.evaluate(eval_dataset=encoded_dataset["test"])
df_test_results = pd.DataFrame([test_results])
print("Test Results:")
print(tabulate(df_test_results, headers="keys", tablefmt="fancy_grid"))

# ------------------------------
# 7. Tune Threshold on Validation Set
# ------------------------------
print("\nFinding best threshold on validation set...")

def find_best_threshold(logits, labels, step=0.05):
    thresholds = np.arange(0.0, 1.0+step, step)
    best_thr = 0.5
    best_f1 = 0.0
    sig_logits = torch.sigmoid(torch.tensor(logits)).numpy()
    for thr in thresholds:
        preds = (sig_logits > thr).astype(int)
        f1 = f1_score(labels, preds, average="micro", zero_division=0)
        if f1 > best_f1:
            best_f1 = f1
            best_thr = thr
    return best_thr, best_f1

val_preds = trainer.predict(encoded_dataset["validation"])
val_logits, val_labels = val_preds.predictions, val_preds.label_ids
best_threshold, best_val_f1 = find_best_threshold(val_logits, val_labels, step=0.01)

print(f"Best threshold on validation set: {best_threshold}")
print(f"Best micro-F1 on validation set with that threshold: {best_val_f1:.4f}")

# ------------------------------
# 8. Evaluate on Test Set Using Best Threshold
# ------------------------------
test_preds = trainer.predict(encoded_dataset["test"])
test_logits, test_labels = test_preds.predictions, test_preds.label_ids
test_probs = torch.sigmoid(torch.tensor(test_logits)).numpy()

test_predictions = (test_probs > best_threshold).astype(int)

# Compute metrics manually
test_subset_accuracy = np.mean(np.all(test_predictions == test_labels, axis=1))
test_micro_f1 = f1_score(test_labels, test_predictions, average="micro", zero_division=0)
test_macro_f1 = f1_score(test_labels, test_predictions, average="macro", zero_division=0)

print("\nTest Results with tuned threshold:")
print(f"  Subset Accuracy: {test_subset_accuracy:.4f}")
print(f"  Micro-F1:        {test_micro_f1:.4f}")
print(f"  Macro-F1:        {test_macro_f1:.4f}")

# ------------------------------
# 9. Save the Model and Tokenizer
# ------------------------------
print("\nSaving the model and tokenizer...")
trainer.save_model("./my_saved_model")  # Saves the best model
tokenizer.save_pretrained("./my_saved_model")
print("Model and tokenizer saved to ./my_saved_model")

# ------------------------------
# 10. Inference Example (Using Best Threshold)
# ------------------------------
def predict_emotions(text, threshold=best_threshold):
    inputs = tokenizer(
        text,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=max_length
    )
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits[0]
    probs = torch.sigmoid(logits).cpu().numpy()
    # Use best threshold
    return (probs > threshold).astype(int), probs

sample_text = "I am feeling really happy and excited about today!"
labels_pred, probs_pred = predict_emotions(sample_text)
print("\nSample inference text:", sample_text)
print("Predicted probabilities for each emotion:", probs_pred)
print("Predicted multi-hot labels:", labels_pred)


CUDA is available. Running on GPU.
GPU Name: NVIDIA GeForce RTX 4070 Ti SUPER
GPU Count: 1

Sample from training set:
{'text': "My favourite food is anything I didn't have to cook myself.", 'labels': [27], 'id': 'eebbqej'}
Number of labels: 28


To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Map: 100%|██████████| 43410/43410 [00:03<00:00, 12778.70 examples/s]
Map: 100%|██████████| 5426/5426 [00:00<00:00, 13092.01 examples/s]
Map: 100%|██████████| 5427/5427 [00:00<00:00, 11171.88 examples/s]
Casting the dataset: 100%|██████████| 43410/43410 [00:00<00:00, 800138.59 examples/s]
Casting the dataset: 100%|██████████| 5426/5426 [00:00<00:00, 307785.74 examples/s]
Casting the dataset: 100%|██████████| 5427/5427 [00:00<00:00, 310971.44 examples/s]
Some weights of DebertaV2ForSequenceClassification were not initialized from the model checkpoint at microsoft/deberta-v3-large and are newly initialized: ['pooler.


Starting training...


  0%|          | 100/43416 [00:20<2:25:45,  4.95it/s]

{'loss': 0.2916, 'learning_rate': 1.9953934033536025e-05, 'epoch': 0.02}


  0%|          | 201/43416 [00:40<2:25:38,  4.95it/s]

{'loss': 0.1535, 'learning_rate': 1.990786806707205e-05, 'epoch': 0.04}


  1%|          | 301/43416 [01:00<2:22:08,  5.06it/s]

{'loss': 0.1404, 'learning_rate': 1.9861802100608072e-05, 'epoch': 0.06}


  1%|          | 401/43416 [01:20<2:25:28,  4.93it/s]

{'loss': 0.1368, 'learning_rate': 1.9815736134144095e-05, 'epoch': 0.07}


  1%|          | 501/43416 [01:40<2:22:19,  5.03it/s]

{'loss': 0.1275, 'learning_rate': 1.976967016768012e-05, 'epoch': 0.09}


  1%|▏         | 601/43416 [02:01<2:22:01,  5.02it/s]

{'loss': 0.1219, 'learning_rate': 1.9723604201216142e-05, 'epoch': 0.11}


  2%|▏         | 700/43416 [02:21<2:31:08,  4.71it/s]

{'loss': 0.1179, 'learning_rate': 1.9677538234752166e-05, 'epoch': 0.13}


  2%|▏         | 800/43416 [02:42<2:31:09,  4.70it/s]

{'loss': 0.1156, 'learning_rate': 1.963147226828819e-05, 'epoch': 0.15}


  2%|▏         | 900/43416 [03:03<2:25:05,  4.88it/s]

{'loss': 0.1086, 'learning_rate': 1.9585406301824216e-05, 'epoch': 0.17}


  2%|▏         | 1001/43416 [03:23<2:21:21,  5.00it/s]

{'loss': 0.1065, 'learning_rate': 1.9539340335360236e-05, 'epoch': 0.18}


  3%|▎         | 1100/43416 [03:43<2:21:11,  5.00it/s]

{'loss': 0.1079, 'learning_rate': 1.949327436889626e-05, 'epoch': 0.2}


  3%|▎         | 1201/43416 [04:04<2:21:51,  4.96it/s]

{'loss': 0.1066, 'learning_rate': 1.9447208402432286e-05, 'epoch': 0.22}


  3%|▎         | 1301/43416 [04:23<2:18:49,  5.06it/s]

{'loss': 0.1012, 'learning_rate': 1.9401142435968306e-05, 'epoch': 0.24}


  3%|▎         | 1400/43416 [04:43<2:19:35,  5.02it/s]

{'loss': 0.1044, 'learning_rate': 1.9355076469504333e-05, 'epoch': 0.26}


  3%|▎         | 1501/43416 [05:03<2:18:26,  5.05it/s]

{'loss': 0.1004, 'learning_rate': 1.9309010503040357e-05, 'epoch': 0.28}


  4%|▎         | 1601/43416 [05:23<2:19:24,  5.00it/s]

{'loss': 0.1005, 'learning_rate': 1.9262944536576377e-05, 'epoch': 0.29}


  4%|▍         | 1701/43416 [05:43<2:19:04,  5.00it/s]

{'loss': 0.0964, 'learning_rate': 1.9216878570112404e-05, 'epoch': 0.31}


  4%|▍         | 1801/43416 [06:03<2:17:29,  5.04it/s]

{'loss': 0.0996, 'learning_rate': 1.9170812603648427e-05, 'epoch': 0.33}


  4%|▍         | 1901/43416 [06:23<2:18:24,  5.00it/s]

{'loss': 0.0964, 'learning_rate': 1.912474663718445e-05, 'epoch': 0.35}


  5%|▍         | 2001/43416 [06:43<2:15:53,  5.08it/s]

{'loss': 0.0943, 'learning_rate': 1.9078680670720474e-05, 'epoch': 0.37}


  5%|▍         | 2101/43416 [07:03<2:16:40,  5.04it/s]

{'loss': 0.0945, 'learning_rate': 1.9032614704256497e-05, 'epoch': 0.39}


  5%|▌         | 2200/43416 [07:22<2:16:32,  5.03it/s]

{'loss': 0.0959, 'learning_rate': 1.898654873779252e-05, 'epoch': 0.41}


  5%|▌         | 2301/43416 [07:42<2:16:40,  5.01it/s]

{'loss': 0.0937, 'learning_rate': 1.8940482771328544e-05, 'epoch': 0.42}


  6%|▌         | 2401/43416 [08:02<2:15:29,  5.05it/s]

{'loss': 0.0957, 'learning_rate': 1.8894416804864568e-05, 'epoch': 0.44}


  6%|▌         | 2500/43416 [08:22<2:20:46,  4.84it/s]

{'loss': 0.0959, 'learning_rate': 1.884835083840059e-05, 'epoch': 0.46}


  6%|▌         | 2601/43416 [08:42<2:15:17,  5.03it/s]

{'loss': 0.0968, 'learning_rate': 1.8802284871936615e-05, 'epoch': 0.48}


  6%|▌         | 2701/43416 [09:02<2:13:56,  5.07it/s]

{'loss': 0.0916, 'learning_rate': 1.8756218905472638e-05, 'epoch': 0.5}


  6%|▋         | 2801/43416 [09:22<2:15:07,  5.01it/s]

{'loss': 0.0884, 'learning_rate': 1.871015293900866e-05, 'epoch': 0.52}


  7%|▋         | 2901/43416 [09:42<2:14:49,  5.01it/s]

{'loss': 0.0949, 'learning_rate': 1.8664086972544685e-05, 'epoch': 0.53}


  7%|▋         | 3001/43416 [10:02<2:14:04,  5.02it/s]

{'loss': 0.0951, 'learning_rate': 1.861802100608071e-05, 'epoch': 0.55}


  7%|▋         | 3101/43416 [10:21<2:12:31,  5.07it/s]

{'loss': 0.0922, 'learning_rate': 1.8571955039616735e-05, 'epoch': 0.57}


  7%|▋         | 3200/43416 [10:41<2:12:49,  5.05it/s]

{'loss': 0.09, 'learning_rate': 1.8525889073152755e-05, 'epoch': 0.59}


  8%|▊         | 3300/43416 [11:01<2:12:04,  5.06it/s]

{'loss': 0.0864, 'learning_rate': 1.847982310668878e-05, 'epoch': 0.61}


  8%|▊         | 3401/43416 [11:21<2:12:35,  5.03it/s]

{'loss': 0.087, 'learning_rate': 1.8433757140224805e-05, 'epoch': 0.63}


  8%|▊         | 3501/43416 [11:41<2:13:09,  5.00it/s]

{'loss': 0.0934, 'learning_rate': 1.8387691173760826e-05, 'epoch': 0.64}


  8%|▊         | 3601/43416 [12:01<2:11:40,  5.04it/s]

{'loss': 0.0898, 'learning_rate': 1.8341625207296852e-05, 'epoch': 0.66}


  9%|▊         | 3701/43416 [12:21<2:11:29,  5.03it/s]

{'loss': 0.0907, 'learning_rate': 1.8295559240832876e-05, 'epoch': 0.68}


  9%|▉         | 3801/43416 [12:41<2:11:17,  5.03it/s]

{'loss': 0.0885, 'learning_rate': 1.8249493274368896e-05, 'epoch': 0.7}


  9%|▉         | 3901/43416 [13:00<2:10:28,  5.05it/s]

{'loss': 0.0871, 'learning_rate': 1.8203427307904923e-05, 'epoch': 0.72}


  9%|▉         | 4001/43416 [13:21<2:10:42,  5.03it/s]

{'loss': 0.0911, 'learning_rate': 1.8157361341440943e-05, 'epoch': 0.74}


  9%|▉         | 4101/43416 [13:41<2:13:30,  4.91it/s]

{'loss': 0.0884, 'learning_rate': 1.811129537497697e-05, 'epoch': 0.76}


 10%|▉         | 4201/43416 [14:01<2:09:01,  5.07it/s]

{'loss': 0.0937, 'learning_rate': 1.8065229408512993e-05, 'epoch': 0.77}


 10%|▉         | 4301/43416 [14:21<2:08:47,  5.06it/s]

{'loss': 0.0859, 'learning_rate': 1.8019163442049013e-05, 'epoch': 0.79}


 10%|█         | 4400/43416 [14:41<2:09:57,  5.00it/s]

{'loss': 0.0891, 'learning_rate': 1.797309747558504e-05, 'epoch': 0.81}


 10%|█         | 4501/43416 [15:01<2:08:54,  5.03it/s]

{'loss': 0.091, 'learning_rate': 1.7927031509121063e-05, 'epoch': 0.83}


 11%|█         | 4601/43416 [15:20<2:08:01,  5.05it/s]

{'loss': 0.0853, 'learning_rate': 1.7880965542657087e-05, 'epoch': 0.85}


 11%|█         | 4701/43416 [15:40<2:08:07,  5.04it/s]

{'loss': 0.0884, 'learning_rate': 1.783489957619311e-05, 'epoch': 0.87}


 11%|█         | 4801/43416 [16:00<2:06:58,  5.07it/s]

{'loss': 0.0883, 'learning_rate': 1.7788833609729134e-05, 'epoch': 0.88}


 11%|█▏        | 4900/43416 [16:20<2:07:04,  5.05it/s]

{'loss': 0.0856, 'learning_rate': 1.7742767643265157e-05, 'epoch': 0.9}


 12%|█▏        | 5001/43416 [16:40<2:10:06,  4.92it/s]

{'loss': 0.086, 'learning_rate': 1.769670167680118e-05, 'epoch': 0.92}


 12%|█▏        | 5101/43416 [17:00<2:06:10,  5.06it/s]

{'loss': 0.0859, 'learning_rate': 1.7650635710337204e-05, 'epoch': 0.94}


 12%|█▏        | 5201/43416 [17:20<2:06:20,  5.04it/s]

{'loss': 0.0849, 'learning_rate': 1.7604569743873227e-05, 'epoch': 0.96}


 12%|█▏        | 5301/43416 [17:40<2:07:24,  4.99it/s]

{'loss': 0.0863, 'learning_rate': 1.755850377740925e-05, 'epoch': 0.98}


 12%|█▏        | 5401/43416 [18:00<2:05:12,  5.06it/s]

{'loss': 0.0878, 'learning_rate': 1.7512437810945274e-05, 'epoch': 1.0}


 12%|█▎        | 5427/43416 [18:05<2:03:06,  5.14it/s]
 12%|█▎        | 5427/43416 [18:39<2:03:06,  5.14it/s]

{'eval_loss': 0.085018090903759, 'eval_subset_accuracy': 0.4423147806855879, 'eval_micro_f1': 0.56874715261959, 'eval_macro_f1': 0.4271840326806257, 'eval_runtime': 33.8942, 'eval_samples_per_second': 160.086, 'eval_steps_per_second': 20.033, 'epoch': 1.0}


 13%|█▎        | 5501/43416 [18:58<2:04:42,  5.07it/s]  

{'loss': 0.083, 'learning_rate': 1.7466371844481298e-05, 'epoch': 1.01}


 13%|█▎        | 5601/43416 [19:18<2:04:53,  5.05it/s]

{'loss': 0.0784, 'learning_rate': 1.742030587801732e-05, 'epoch': 1.03}


 13%|█▎        | 5701/43416 [19:38<2:05:00,  5.03it/s]

{'loss': 0.0811, 'learning_rate': 1.7374239911553345e-05, 'epoch': 1.05}


 13%|█▎        | 5801/43416 [19:58<2:04:24,  5.04it/s]

{'loss': 0.0757, 'learning_rate': 1.732817394508937e-05, 'epoch': 1.07}


 14%|█▎        | 5901/43416 [20:17<2:03:52,  5.05it/s]

{'loss': 0.0742, 'learning_rate': 1.728210797862539e-05, 'epoch': 1.09}


 14%|█▍        | 6000/43416 [20:37<2:05:07,  4.98it/s]

{'loss': 0.0834, 'learning_rate': 1.7236042012161415e-05, 'epoch': 1.11}


 14%|█▍        | 6101/43416 [20:57<2:03:04,  5.05it/s]

{'loss': 0.0783, 'learning_rate': 1.7189976045697442e-05, 'epoch': 1.12}


 14%|█▍        | 6201/43416 [21:17<2:03:42,  5.01it/s]

{'loss': 0.0763, 'learning_rate': 1.7143910079233462e-05, 'epoch': 1.14}


 15%|█▍        | 6301/43416 [21:37<2:03:24,  5.01it/s]

{'loss': 0.0737, 'learning_rate': 1.709784411276949e-05, 'epoch': 1.16}


 15%|█▍        | 6401/43416 [21:57<2:03:20,  5.00it/s]

{'loss': 0.0784, 'learning_rate': 1.7051778146305512e-05, 'epoch': 1.18}


 15%|█▍        | 6501/43416 [22:17<2:02:39,  5.02it/s]

{'loss': 0.0756, 'learning_rate': 1.7005712179841532e-05, 'epoch': 1.2}


 15%|█▌        | 6601/43416 [22:37<2:01:25,  5.05it/s]

{'loss': 0.076, 'learning_rate': 1.695964621337756e-05, 'epoch': 1.22}


 15%|█▌        | 6701/43416 [22:57<2:00:29,  5.08it/s]

{'loss': 0.0771, 'learning_rate': 1.6913580246913582e-05, 'epoch': 1.23}


 16%|█▌        | 6801/43416 [23:16<2:00:38,  5.06it/s]

{'loss': 0.0754, 'learning_rate': 1.6867514280449606e-05, 'epoch': 1.25}


 16%|█▌        | 6901/43416 [23:36<2:01:17,  5.02it/s]

{'loss': 0.0754, 'learning_rate': 1.682144831398563e-05, 'epoch': 1.27}


 16%|█▌        | 7000/43416 [23:56<2:00:48,  5.02it/s]

{'loss': 0.0757, 'learning_rate': 1.6775382347521653e-05, 'epoch': 1.29}


 16%|█▋        | 7101/43416 [24:16<2:00:02,  5.04it/s]

{'loss': 0.0779, 'learning_rate': 1.6729316381057676e-05, 'epoch': 1.31}


 17%|█▋        | 7201/43416 [24:36<2:00:49,  5.00it/s]

{'loss': 0.0762, 'learning_rate': 1.66832504145937e-05, 'epoch': 1.33}


 17%|█▋        | 7300/43416 [24:56<1:58:30,  5.08it/s]

{'loss': 0.0808, 'learning_rate': 1.6637184448129723e-05, 'epoch': 1.35}


 17%|█▋        | 7401/43416 [25:16<1:58:33,  5.06it/s]

{'loss': 0.0737, 'learning_rate': 1.6591118481665747e-05, 'epoch': 1.36}


 17%|█▋        | 7501/43416 [25:36<1:58:50,  5.04it/s]

{'loss': 0.0802, 'learning_rate': 1.654505251520177e-05, 'epoch': 1.38}


 18%|█▊        | 7601/43416 [25:55<1:58:16,  5.05it/s]

{'loss': 0.0812, 'learning_rate': 1.6498986548737793e-05, 'epoch': 1.4}


 18%|█▊        | 7700/43416 [26:15<1:57:00,  5.09it/s]

{'loss': 0.0768, 'learning_rate': 1.6452920582273817e-05, 'epoch': 1.42}


 18%|█▊        | 7801/43416 [26:35<1:56:54,  5.08it/s]

{'loss': 0.074, 'learning_rate': 1.640685461580984e-05, 'epoch': 1.44}


 18%|█▊        | 7901/43416 [26:55<1:56:58,  5.06it/s]

{'loss': 0.0785, 'learning_rate': 1.6360788649345864e-05, 'epoch': 1.46}


 18%|█▊        | 8000/43416 [27:15<1:57:02,  5.04it/s]

{'loss': 0.0759, 'learning_rate': 1.6314722682881887e-05, 'epoch': 1.47}


 19%|█▊        | 8101/43416 [27:35<1:57:53,  4.99it/s]

{'loss': 0.0809, 'learning_rate': 1.626865671641791e-05, 'epoch': 1.49}


 19%|█▉        | 8201/43416 [27:55<1:56:14,  5.05it/s]

{'loss': 0.0782, 'learning_rate': 1.6222590749953934e-05, 'epoch': 1.51}


 19%|█▉        | 8301/43416 [28:15<1:55:36,  5.06it/s]

{'loss': 0.074, 'learning_rate': 1.617652478348996e-05, 'epoch': 1.53}


 19%|█▉        | 8401/43416 [28:34<1:57:06,  4.98it/s]

{'loss': 0.0745, 'learning_rate': 1.613045881702598e-05, 'epoch': 1.55}


 20%|█▉        | 8501/43416 [28:54<1:55:12,  5.05it/s]

{'loss': 0.0767, 'learning_rate': 1.6084392850562004e-05, 'epoch': 1.57}


 20%|█▉        | 8601/43416 [29:14<1:55:27,  5.03it/s]

{'loss': 0.0793, 'learning_rate': 1.603832688409803e-05, 'epoch': 1.58}


 20%|██        | 8700/43416 [29:34<1:54:39,  5.05it/s]

{'loss': 0.0791, 'learning_rate': 1.599226091763405e-05, 'epoch': 1.6}


 20%|██        | 8801/43416 [29:54<1:54:20,  5.05it/s]

{'loss': 0.076, 'learning_rate': 1.5946194951170078e-05, 'epoch': 1.62}


 21%|██        | 8901/43416 [30:14<1:55:02,  5.00it/s]

{'loss': 0.0768, 'learning_rate': 1.59001289847061e-05, 'epoch': 1.64}


 21%|██        | 9001/43416 [30:34<1:53:43,  5.04it/s]

{'loss': 0.0801, 'learning_rate': 1.5854063018242125e-05, 'epoch': 1.66}


 21%|██        | 9101/43416 [30:54<1:53:56,  5.02it/s]

{'loss': 0.0771, 'learning_rate': 1.580799705177815e-05, 'epoch': 1.68}


 21%|██        | 9201/43416 [31:14<1:52:53,  5.05it/s]

{'loss': 0.075, 'learning_rate': 1.5761931085314172e-05, 'epoch': 1.7}


 21%|██▏       | 9301/43416 [31:33<1:53:57,  4.99it/s]

{'loss': 0.0719, 'learning_rate': 1.5715865118850195e-05, 'epoch': 1.71}


 22%|██▏       | 9401/43416 [31:53<1:52:29,  5.04it/s]

{'loss': 0.0791, 'learning_rate': 1.566979915238622e-05, 'epoch': 1.73}


 22%|██▏       | 9500/43416 [32:13<1:54:40,  4.93it/s]

{'loss': 0.0715, 'learning_rate': 1.5623733185922242e-05, 'epoch': 1.75}


 22%|██▏       | 9601/43416 [32:33<1:52:21,  5.02it/s]

{'loss': 0.0802, 'learning_rate': 1.5577667219458266e-05, 'epoch': 1.77}


 22%|██▏       | 9701/43416 [32:53<1:53:54,  4.93it/s]

{'loss': 0.0721, 'learning_rate': 1.553160125299429e-05, 'epoch': 1.79}


 23%|██▎       | 9801/43416 [33:13<1:51:52,  5.01it/s]

{'loss': 0.0784, 'learning_rate': 1.5485535286530313e-05, 'epoch': 1.81}


 23%|██▎       | 9901/43416 [33:33<1:51:34,  5.01it/s]

{'loss': 0.0766, 'learning_rate': 1.5439469320066336e-05, 'epoch': 1.82}


 23%|██▎       | 10001/43416 [33:53<1:51:03,  5.01it/s]

{'loss': 0.0763, 'learning_rate': 1.539340335360236e-05, 'epoch': 1.84}


 23%|██▎       | 10101/43416 [34:13<1:51:21,  4.99it/s]

{'loss': 0.0754, 'learning_rate': 1.5347337387138383e-05, 'epoch': 1.86}


 23%|██▎       | 10201/43416 [34:33<1:50:41,  5.00it/s]

{'loss': 0.0767, 'learning_rate': 1.5301271420674406e-05, 'epoch': 1.88}


 24%|██▎       | 10301/43416 [34:53<1:49:02,  5.06it/s]

{'loss': 0.0764, 'learning_rate': 1.5255205454210432e-05, 'epoch': 1.9}


 24%|██▍       | 10401/43416 [35:13<1:48:46,  5.06it/s]

{'loss': 0.0796, 'learning_rate': 1.5209139487746453e-05, 'epoch': 1.92}


 24%|██▍       | 10501/43416 [35:32<1:49:00,  5.03it/s]

{'loss': 0.0783, 'learning_rate': 1.5163073521282478e-05, 'epoch': 1.93}


 24%|██▍       | 10601/43416 [35:52<1:49:10,  5.01it/s]

{'loss': 0.0786, 'learning_rate': 1.5117007554818502e-05, 'epoch': 1.95}


 25%|██▍       | 10701/43416 [36:12<1:47:42,  5.06it/s]

{'loss': 0.0756, 'learning_rate': 1.5070941588354524e-05, 'epoch': 1.97}


 25%|██▍       | 10800/43416 [36:32<1:48:24,  5.01it/s]

{'loss': 0.075, 'learning_rate': 1.5024875621890549e-05, 'epoch': 1.99}


 25%|██▌       | 10854/43416 [36:43<1:40:06,  5.42it/s]
 25%|██▌       | 10854/43416 [37:17<1:40:06,  5.42it/s]

{'eval_loss': 0.08198396116495132, 'eval_subset_accuracy': 0.49023221525985994, 'eval_micro_f1': 0.605637401004854, 'eval_macro_f1': 0.45295987968417867, 'eval_runtime': 33.8111, 'eval_samples_per_second': 160.48, 'eval_steps_per_second': 20.082, 'epoch': 2.0}


 25%|██▌       | 10901/43416 [37:31<1:47:59,  5.02it/s]  

{'loss': 0.0705, 'learning_rate': 1.4978809655426572e-05, 'epoch': 2.01}


 25%|██▌       | 11001/43416 [37:51<1:47:39,  5.02it/s]

{'loss': 0.0656, 'learning_rate': 1.4932743688962597e-05, 'epoch': 2.03}


 26%|██▌       | 11101/43416 [38:11<1:47:33,  5.01it/s]

{'loss': 0.0608, 'learning_rate': 1.4886677722498619e-05, 'epoch': 2.05}


 26%|██▌       | 11200/43416 [38:30<1:45:57,  5.07it/s]

{'loss': 0.0606, 'learning_rate': 1.4840611756034643e-05, 'epoch': 2.06}


 26%|██▌       | 11301/43416 [38:50<1:46:37,  5.02it/s]

{'loss': 0.0607, 'learning_rate': 1.4794545789570666e-05, 'epoch': 2.08}


 26%|██▋       | 11401/43416 [39:10<1:45:35,  5.05it/s]

{'loss': 0.0611, 'learning_rate': 1.474847982310669e-05, 'epoch': 2.1}


 26%|██▋       | 11501/43416 [39:30<1:46:59,  4.97it/s]

{'loss': 0.0557, 'learning_rate': 1.4702413856642715e-05, 'epoch': 2.12}


 27%|██▋       | 11601/43416 [39:50<1:45:39,  5.02it/s]

{'loss': 0.0592, 'learning_rate': 1.4656347890178736e-05, 'epoch': 2.14}


 27%|██▋       | 11700/43416 [40:10<1:44:19,  5.07it/s]

{'loss': 0.0608, 'learning_rate': 1.461028192371476e-05, 'epoch': 2.16}


 27%|██▋       | 11801/43416 [40:30<1:44:58,  5.02it/s]

{'loss': 0.0588, 'learning_rate': 1.4564215957250785e-05, 'epoch': 2.17}


 27%|██▋       | 11901/43416 [40:50<1:45:12,  4.99it/s]

{'loss': 0.0623, 'learning_rate': 1.4518149990786807e-05, 'epoch': 2.19}


 28%|██▊       | 12001/43416 [41:10<1:44:07,  5.03it/s]

{'loss': 0.0605, 'learning_rate': 1.4472084024322832e-05, 'epoch': 2.21}


 28%|██▊       | 12101/43416 [41:30<1:44:05,  5.01it/s]

{'loss': 0.0576, 'learning_rate': 1.4426018057858855e-05, 'epoch': 2.23}


 28%|██▊       | 12201/43416 [41:50<1:44:59,  4.96it/s]

{'loss': 0.0624, 'learning_rate': 1.4379952091394877e-05, 'epoch': 2.25}


 28%|██▊       | 12300/43416 [42:09<1:42:35,  5.05it/s]

{'loss': 0.0635, 'learning_rate': 1.4333886124930902e-05, 'epoch': 2.27}


 29%|██▊       | 12400/43416 [42:29<1:42:04,  5.06it/s]

{'loss': 0.059, 'learning_rate': 1.4287820158466926e-05, 'epoch': 2.28}


 29%|██▉       | 12500/43416 [42:49<1:44:22,  4.94it/s]

{'loss': 0.0618, 'learning_rate': 1.424175419200295e-05, 'epoch': 2.3}


 29%|██▉       | 12600/43416 [43:09<1:41:20,  5.07it/s]

{'loss': 0.0606, 'learning_rate': 1.4195688225538972e-05, 'epoch': 2.32}


 29%|██▉       | 12701/43416 [43:29<1:42:05,  5.01it/s]

{'loss': 0.062, 'learning_rate': 1.4149622259074998e-05, 'epoch': 2.34}


 29%|██▉       | 12800/43416 [43:49<1:43:16,  4.94it/s]

{'loss': 0.0634, 'learning_rate': 1.4103556292611021e-05, 'epoch': 2.36}


 30%|██▉       | 12901/43416 [44:09<1:40:23,  5.07it/s]

{'loss': 0.0596, 'learning_rate': 1.4057490326147043e-05, 'epoch': 2.38}


 30%|██▉       | 13001/43416 [44:29<1:41:05,  5.01it/s]

{'loss': 0.0666, 'learning_rate': 1.4011424359683068e-05, 'epoch': 2.4}


 30%|███       | 13101/43416 [44:49<1:40:11,  5.04it/s]

{'loss': 0.0647, 'learning_rate': 1.3965358393219091e-05, 'epoch': 2.41}


 30%|███       | 13201/43416 [45:09<1:40:04,  5.03it/s]

{'loss': 0.0593, 'learning_rate': 1.3919292426755115e-05, 'epoch': 2.43}


 31%|███       | 13301/43416 [45:28<1:40:34,  4.99it/s]

{'loss': 0.0687, 'learning_rate': 1.3873226460291138e-05, 'epoch': 2.45}


 31%|███       | 13401/43416 [45:48<1:40:21,  4.98it/s]

{'loss': 0.0577, 'learning_rate': 1.3827160493827162e-05, 'epoch': 2.47}


 31%|███       | 13501/43416 [46:08<1:39:14,  5.02it/s]

{'loss': 0.0619, 'learning_rate': 1.3781094527363185e-05, 'epoch': 2.49}


 31%|███▏      | 13601/43416 [46:28<1:38:22,  5.05it/s]

{'loss': 0.0606, 'learning_rate': 1.3735028560899209e-05, 'epoch': 2.51}


 32%|███▏      | 13701/43416 [46:48<1:38:33,  5.03it/s]

{'loss': 0.0643, 'learning_rate': 1.3688962594435234e-05, 'epoch': 2.52}


 32%|███▏      | 13801/43416 [47:08<1:37:47,  5.05it/s]

{'loss': 0.0621, 'learning_rate': 1.3642896627971255e-05, 'epoch': 2.54}


 32%|███▏      | 13901/43416 [47:28<1:37:17,  5.06it/s]

{'loss': 0.0613, 'learning_rate': 1.3596830661507279e-05, 'epoch': 2.56}


 32%|███▏      | 14001/43416 [47:48<1:37:54,  5.01it/s]

{'loss': 0.0609, 'learning_rate': 1.3550764695043304e-05, 'epoch': 2.58}


 32%|███▏      | 14101/43416 [48:08<1:36:53,  5.04it/s]

{'loss': 0.0603, 'learning_rate': 1.3504698728579326e-05, 'epoch': 2.6}


 33%|███▎      | 14200/43416 [48:27<1:37:39,  4.99it/s]

{'loss': 0.0651, 'learning_rate': 1.3458632762115351e-05, 'epoch': 2.62}


 33%|███▎      | 14301/43416 [48:48<1:36:17,  5.04it/s]

{'loss': 0.0662, 'learning_rate': 1.3412566795651374e-05, 'epoch': 2.63}


 33%|███▎      | 14401/43416 [49:07<1:35:48,  5.05it/s]

{'loss': 0.0605, 'learning_rate': 1.3366500829187396e-05, 'epoch': 2.65}


 33%|███▎      | 14501/43416 [49:27<1:35:44,  5.03it/s]

{'loss': 0.0603, 'learning_rate': 1.3320434862723421e-05, 'epoch': 2.67}


 34%|███▎      | 14601/43416 [49:47<1:35:38,  5.02it/s]

{'loss': 0.0642, 'learning_rate': 1.3274368896259445e-05, 'epoch': 2.69}


 34%|███▍      | 14701/43416 [50:07<1:35:02,  5.04it/s]

{'loss': 0.062, 'learning_rate': 1.322830292979547e-05, 'epoch': 2.71}


 34%|███▍      | 14801/43416 [50:27<1:36:21,  4.95it/s]

{'loss': 0.0597, 'learning_rate': 1.3182236963331492e-05, 'epoch': 2.73}


 34%|███▍      | 14900/43416 [50:47<1:35:14,  4.99it/s]

{'loss': 0.0614, 'learning_rate': 1.3136170996867515e-05, 'epoch': 2.75}


 35%|███▍      | 15001/43416 [51:07<1:33:46,  5.05it/s]

{'loss': 0.0654, 'learning_rate': 1.309010503040354e-05, 'epoch': 2.76}


 35%|███▍      | 15100/43416 [51:26<1:34:01,  5.02it/s]

{'loss': 0.06, 'learning_rate': 1.3044039063939562e-05, 'epoch': 2.78}


 35%|███▌      | 15201/43416 [51:47<1:34:37,  4.97it/s]

{'loss': 0.0579, 'learning_rate': 1.2997973097475587e-05, 'epoch': 2.8}


 35%|███▌      | 15301/43416 [52:07<1:32:38,  5.06it/s]

{'loss': 0.0598, 'learning_rate': 1.2951907131011609e-05, 'epoch': 2.82}


 35%|███▌      | 15401/43416 [52:26<1:33:17,  5.00it/s]

{'loss': 0.0577, 'learning_rate': 1.2905841164547632e-05, 'epoch': 2.84}


 36%|███▌      | 15500/43416 [52:46<1:32:18,  5.04it/s]

{'loss': 0.0681, 'learning_rate': 1.2859775198083657e-05, 'epoch': 2.86}


 36%|███▌      | 15601/43416 [53:06<1:31:41,  5.06it/s]

{'loss': 0.0587, 'learning_rate': 1.2813709231619679e-05, 'epoch': 2.87}


 36%|███▌      | 15701/43416 [53:26<1:32:02,  5.02it/s]

{'loss': 0.062, 'learning_rate': 1.2767643265155704e-05, 'epoch': 2.89}


 36%|███▋      | 15801/43416 [53:46<1:32:25,  4.98it/s]

{'loss': 0.0651, 'learning_rate': 1.2721577298691728e-05, 'epoch': 2.91}


 37%|███▋      | 15901/43416 [54:06<1:31:36,  5.01it/s]

{'loss': 0.0621, 'learning_rate': 1.267551133222775e-05, 'epoch': 2.93}


 37%|███▋      | 16001/43416 [54:26<1:31:58,  4.97it/s]

{'loss': 0.0606, 'learning_rate': 1.2629445365763775e-05, 'epoch': 2.95}


 37%|███▋      | 16101/43416 [54:46<1:31:11,  4.99it/s]

{'loss': 0.0583, 'learning_rate': 1.2583379399299798e-05, 'epoch': 2.97}


 37%|███▋      | 16201/43416 [55:06<1:31:08,  4.98it/s]

{'loss': 0.0637, 'learning_rate': 1.2537313432835823e-05, 'epoch': 2.99}


 38%|███▊      | 16281/43416 [55:22<1:20:55,  5.59it/s]
 38%|███▊      | 16281/43416 [55:56<1:20:55,  5.59it/s]

{'eval_loss': 0.08621615171432495, 'eval_subset_accuracy': 0.47862145226686326, 'eval_micro_f1': 0.5887528031740555, 'eval_macro_f1': 0.4933984011684623, 'eval_runtime': 34.0851, 'eval_samples_per_second': 159.19, 'eval_steps_per_second': 19.921, 'epoch': 3.0}


 38%|███▊      | 16301/43416 [56:05<1:35:59,  4.71it/s] 

{'loss': 0.0585, 'learning_rate': 1.2491247466371845e-05, 'epoch': 3.0}


 38%|███▊      | 16401/43416 [56:25<1:29:44,  5.02it/s]

{'loss': 0.049, 'learning_rate': 1.244518149990787e-05, 'epoch': 3.02}


 38%|███▊      | 16501/43416 [56:45<1:29:35,  5.01it/s]

{'loss': 0.0429, 'learning_rate': 1.2399115533443893e-05, 'epoch': 3.04}


 38%|███▊      | 16601/43416 [57:05<1:28:57,  5.02it/s]

{'loss': 0.0446, 'learning_rate': 1.2353049566979915e-05, 'epoch': 3.06}


 38%|███▊      | 16701/43416 [57:25<1:29:32,  4.97it/s]

{'loss': 0.0416, 'learning_rate': 1.230698360051594e-05, 'epoch': 3.08}


 39%|███▊      | 16801/43416 [57:45<1:28:38,  5.00it/s]

{'loss': 0.0448, 'learning_rate': 1.2260917634051964e-05, 'epoch': 3.1}


 39%|███▉      | 16901/43416 [58:04<1:27:41,  5.04it/s]

{'loss': 0.0467, 'learning_rate': 1.2214851667587987e-05, 'epoch': 3.11}


 39%|███▉      | 17001/43416 [58:24<1:27:49,  5.01it/s]

{'loss': 0.0496, 'learning_rate': 1.216878570112401e-05, 'epoch': 3.13}


 39%|███▉      | 17100/43416 [58:44<1:28:57,  4.93it/s]

{'loss': 0.0439, 'learning_rate': 1.2122719734660034e-05, 'epoch': 3.15}


 40%|███▉      | 17201/43416 [59:04<1:27:03,  5.02it/s]

{'loss': 0.0475, 'learning_rate': 1.2076653768196058e-05, 'epoch': 3.17}


 40%|███▉      | 17301/43416 [59:24<1:27:24,  4.98it/s]

{'loss': 0.0425, 'learning_rate': 1.2030587801732081e-05, 'epoch': 3.19}


 40%|████      | 17401/43416 [59:44<1:26:24,  5.02it/s]

{'loss': 0.0468, 'learning_rate': 1.1984521835268106e-05, 'epoch': 3.21}


 40%|████      | 17501/43416 [1:00:04<1:25:16,  5.07it/s]

{'loss': 0.0427, 'learning_rate': 1.1938455868804128e-05, 'epoch': 3.22}


 41%|████      | 17601/43416 [1:00:24<1:25:45,  5.02it/s]

{'loss': 0.0474, 'learning_rate': 1.1892389902340151e-05, 'epoch': 3.24}


 41%|████      | 17701/43416 [1:00:44<1:25:59,  4.98it/s]

{'loss': 0.0455, 'learning_rate': 1.1846323935876176e-05, 'epoch': 3.26}


 41%|████      | 17801/43416 [1:01:04<1:23:42,  5.10it/s]

{'loss': 0.0423, 'learning_rate': 1.1800257969412198e-05, 'epoch': 3.28}


 41%|████      | 17901/43416 [1:01:23<1:24:01,  5.06it/s]

{'loss': 0.0419, 'learning_rate': 1.1754192002948223e-05, 'epoch': 3.3}


 41%|████▏     | 18000/43416 [1:01:43<1:24:30,  5.01it/s]

{'loss': 0.0456, 'learning_rate': 1.1708126036484247e-05, 'epoch': 3.32}


 42%|████▏     | 18101/43416 [1:02:03<1:23:48,  5.03it/s]

{'loss': 0.0453, 'learning_rate': 1.1662060070020269e-05, 'epoch': 3.34}


 42%|████▏     | 18201/43416 [1:02:23<1:24:09,  4.99it/s]

{'loss': 0.0471, 'learning_rate': 1.1615994103556294e-05, 'epoch': 3.35}


 42%|████▏     | 18301/43416 [1:02:43<1:23:59,  4.98it/s]

{'loss': 0.0464, 'learning_rate': 1.1569928137092317e-05, 'epoch': 3.37}


 42%|████▏     | 18401/43416 [1:03:03<1:22:11,  5.07it/s]

{'loss': 0.045, 'learning_rate': 1.1523862170628342e-05, 'epoch': 3.39}


 43%|████▎     | 18501/43416 [1:03:23<1:22:58,  5.00it/s]

{'loss': 0.0447, 'learning_rate': 1.1477796204164364e-05, 'epoch': 3.41}


 43%|████▎     | 18601/43416 [1:03:43<1:23:54,  4.93it/s]

{'loss': 0.0454, 'learning_rate': 1.1431730237700387e-05, 'epoch': 3.43}


 43%|████▎     | 18701/43416 [1:04:02<1:21:24,  5.06it/s]

{'loss': 0.0436, 'learning_rate': 1.1385664271236413e-05, 'epoch': 3.45}


 43%|████▎     | 18801/43416 [1:04:22<1:22:07,  5.00it/s]

{'loss': 0.0454, 'learning_rate': 1.1339598304772434e-05, 'epoch': 3.46}


 44%|████▎     | 18901/43416 [1:04:42<1:21:29,  5.01it/s]

{'loss': 0.0439, 'learning_rate': 1.129353233830846e-05, 'epoch': 3.48}


 44%|████▍     | 19001/43416 [1:05:02<1:20:23,  5.06it/s]

{'loss': 0.0432, 'learning_rate': 1.1247466371844483e-05, 'epoch': 3.5}


 44%|████▍     | 19101/43416 [1:05:22<1:20:51,  5.01it/s]

{'loss': 0.0411, 'learning_rate': 1.1201400405380505e-05, 'epoch': 3.52}


 44%|████▍     | 19201/43416 [1:05:42<1:20:23,  5.02it/s]

{'loss': 0.0407, 'learning_rate': 1.115533443891653e-05, 'epoch': 3.54}


 44%|████▍     | 19301/43416 [1:06:02<1:20:44,  4.98it/s]

{'loss': 0.0469, 'learning_rate': 1.1109268472452552e-05, 'epoch': 3.56}


 45%|████▍     | 19401/43416 [1:06:22<1:20:36,  4.97it/s]

{'loss': 0.0416, 'learning_rate': 1.1063202505988577e-05, 'epoch': 3.57}


 45%|████▍     | 19500/43416 [1:06:41<1:18:57,  5.05it/s]

{'loss': 0.0428, 'learning_rate': 1.10171365395246e-05, 'epoch': 3.59}


 45%|████▌     | 19601/43416 [1:07:01<1:18:45,  5.04it/s]

{'loss': 0.0444, 'learning_rate': 1.0971070573060622e-05, 'epoch': 3.61}


 45%|████▌     | 19701/43416 [1:07:21<1:17:46,  5.08it/s]

{'loss': 0.0464, 'learning_rate': 1.0925004606596647e-05, 'epoch': 3.63}


 46%|████▌     | 19801/43416 [1:07:41<1:19:41,  4.94it/s]

{'loss': 0.0489, 'learning_rate': 1.087893864013267e-05, 'epoch': 3.65}


 46%|████▌     | 19901/43416 [1:08:01<1:17:26,  5.06it/s]

{'loss': 0.0481, 'learning_rate': 1.0832872673668696e-05, 'epoch': 3.67}


 46%|████▌     | 20001/43416 [1:08:21<1:17:53,  5.01it/s]

{'loss': 0.0468, 'learning_rate': 1.0786806707204717e-05, 'epoch': 3.69}


 46%|████▋     | 20101/43416 [1:08:41<1:18:34,  4.94it/s]

{'loss': 0.0436, 'learning_rate': 1.0740740740740742e-05, 'epoch': 3.7}


 47%|████▋     | 20201/43416 [1:09:01<1:16:41,  5.05it/s]

{'loss': 0.043, 'learning_rate': 1.0694674774276766e-05, 'epoch': 3.72}


 47%|████▋     | 20301/43416 [1:09:21<1:16:23,  5.04it/s]

{'loss': 0.0437, 'learning_rate': 1.0648608807812788e-05, 'epoch': 3.74}


 47%|████▋     | 20401/43416 [1:09:41<1:16:52,  4.99it/s]

{'loss': 0.0416, 'learning_rate': 1.0602542841348813e-05, 'epoch': 3.76}


 47%|████▋     | 20501/43416 [1:10:01<1:16:25,  5.00it/s]

{'loss': 0.0418, 'learning_rate': 1.0556476874884836e-05, 'epoch': 3.78}


 47%|████▋     | 20601/43416 [1:10:21<1:15:08,  5.06it/s]

{'loss': 0.0423, 'learning_rate': 1.051041090842086e-05, 'epoch': 3.8}


 48%|████▊     | 20701/43416 [1:10:41<1:15:05,  5.04it/s]

{'loss': 0.0431, 'learning_rate': 1.0464344941956883e-05, 'epoch': 3.81}


 48%|████▊     | 20801/43416 [1:11:00<1:14:02,  5.09it/s]

{'loss': 0.0394, 'learning_rate': 1.0418278975492907e-05, 'epoch': 3.83}


 48%|████▊     | 20901/43416 [1:11:20<1:14:21,  5.05it/s]

{'loss': 0.045, 'learning_rate': 1.037221300902893e-05, 'epoch': 3.85}


 48%|████▊     | 21001/43416 [1:11:40<1:14:57,  4.98it/s]

{'loss': 0.0469, 'learning_rate': 1.0326147042564953e-05, 'epoch': 3.87}


 49%|████▊     | 21101/43416 [1:12:00<1:14:15,  5.01it/s]

{'loss': 0.0437, 'learning_rate': 1.0280081076100979e-05, 'epoch': 3.89}


 49%|████▉     | 21201/43416 [1:12:20<1:13:05,  5.07it/s]

{'loss': 0.0448, 'learning_rate': 1.0234015109637e-05, 'epoch': 3.91}


 49%|████▉     | 21301/43416 [1:12:40<1:13:42,  5.00it/s]

{'loss': 0.0434, 'learning_rate': 1.0187949143173024e-05, 'epoch': 3.92}


 49%|████▉     | 21401/43416 [1:13:00<1:12:46,  5.04it/s]

{'loss': 0.046, 'learning_rate': 1.0141883176709049e-05, 'epoch': 3.94}


 50%|████▉     | 21500/43416 [1:13:19<1:12:20,  5.05it/s]

{'loss': 0.0464, 'learning_rate': 1.009581721024507e-05, 'epoch': 3.96}


 50%|████▉     | 21600/43416 [1:13:39<1:12:15,  5.03it/s]

{'loss': 0.0455, 'learning_rate': 1.0049751243781096e-05, 'epoch': 3.98}


 50%|████▉     | 21700/43416 [1:13:59<1:11:17,  5.08it/s]

{'loss': 0.042, 'learning_rate': 1.000368527731712e-05, 'epoch': 4.0}


 50%|█████     | 21708/43416 [1:14:01<1:06:12,  5.46it/s]
 50%|█████     | 21708/43416 [1:14:35<1:06:12,  5.46it/s]

{'eval_loss': 0.10178525745868683, 'eval_subset_accuracy': 0.4683007740508662, 'eval_micro_f1': 0.5786456169505608, 'eval_macro_f1': 0.4946216894515417, 'eval_runtime': 34.0585, 'eval_samples_per_second': 159.314, 'eval_steps_per_second': 19.936, 'epoch': 4.0}


 50%|█████     | 21801/43416 [1:14:58<1:11:32,  5.04it/s] 

{'loss': 0.0321, 'learning_rate': 9.957619310853143e-06, 'epoch': 4.02}


 50%|█████     | 21901/43416 [1:15:18<1:11:07,  5.04it/s]

{'loss': 0.0323, 'learning_rate': 9.911553344389166e-06, 'epoch': 4.04}


 51%|█████     | 22001/43416 [1:15:38<1:11:12,  5.01it/s]

{'loss': 0.0299, 'learning_rate': 9.86548737792519e-06, 'epoch': 4.05}


 51%|█████     | 22101/43416 [1:15:58<1:10:35,  5.03it/s]

{'loss': 0.0294, 'learning_rate': 9.819421411461213e-06, 'epoch': 4.07}


 51%|█████     | 22201/43416 [1:16:18<1:10:19,  5.03it/s]

{'loss': 0.0301, 'learning_rate': 9.773355444997236e-06, 'epoch': 4.09}


 51%|█████▏    | 22301/43416 [1:16:38<1:10:18,  5.00it/s]

{'loss': 0.0315, 'learning_rate': 9.72728947853326e-06, 'epoch': 4.11}


 52%|█████▏    | 22401/43416 [1:16:58<1:09:17,  5.05it/s]

{'loss': 0.0303, 'learning_rate': 9.681223512069285e-06, 'epoch': 4.13}


 52%|█████▏    | 22501/43416 [1:17:18<1:09:06,  5.04it/s]

{'loss': 0.0267, 'learning_rate': 9.635157545605309e-06, 'epoch': 4.15}


 52%|█████▏    | 22601/43416 [1:17:37<1:10:14,  4.94it/s]

{'loss': 0.0294, 'learning_rate': 9.58909157914133e-06, 'epoch': 4.16}


 52%|█████▏    | 22701/43416 [1:17:57<1:08:35,  5.03it/s]

{'loss': 0.0308, 'learning_rate': 9.543025612677355e-06, 'epoch': 4.18}


 53%|█████▎    | 22801/43416 [1:18:17<1:08:42,  5.00it/s]

{'loss': 0.0294, 'learning_rate': 9.496959646213379e-06, 'epoch': 4.2}


 53%|█████▎    | 22901/43416 [1:18:37<1:08:50,  4.97it/s]

{'loss': 0.0306, 'learning_rate': 9.450893679749402e-06, 'epoch': 4.22}


 53%|█████▎    | 23001/43416 [1:18:57<1:07:34,  5.03it/s]

{'loss': 0.0289, 'learning_rate': 9.404827713285426e-06, 'epoch': 4.24}


 53%|█████▎    | 23101/43416 [1:19:17<1:07:12,  5.04it/s]

{'loss': 0.0318, 'learning_rate': 9.35876174682145e-06, 'epoch': 4.26}


 53%|█████▎    | 23201/43416 [1:19:37<1:07:01,  5.03it/s]

{'loss': 0.0255, 'learning_rate': 9.312695780357473e-06, 'epoch': 4.27}


 54%|█████▎    | 23301/43416 [1:19:57<1:06:31,  5.04it/s]

{'loss': 0.0301, 'learning_rate': 9.266629813893496e-06, 'epoch': 4.29}


 54%|█████▍    | 23401/43416 [1:20:17<1:06:32,  5.01it/s]

{'loss': 0.03, 'learning_rate': 9.22056384742952e-06, 'epoch': 4.31}


 54%|█████▍    | 23501/43416 [1:20:36<1:06:37,  4.98it/s]

{'loss': 0.0354, 'learning_rate': 9.174497880965545e-06, 'epoch': 4.33}


 54%|█████▍    | 23601/43416 [1:20:56<1:06:04,  5.00it/s]

{'loss': 0.0288, 'learning_rate': 9.128431914501566e-06, 'epoch': 4.35}


 55%|█████▍    | 23701/43416 [1:21:16<1:05:02,  5.05it/s]

{'loss': 0.0285, 'learning_rate': 9.08236594803759e-06, 'epoch': 4.37}


 55%|█████▍    | 23800/43416 [1:21:36<1:04:40,  5.06it/s]

{'loss': 0.0308, 'learning_rate': 9.036299981573613e-06, 'epoch': 4.39}


 55%|█████▌    | 23901/43416 [1:21:56<1:04:10,  5.07it/s]

{'loss': 0.0297, 'learning_rate': 8.990234015109638e-06, 'epoch': 4.4}


 55%|█████▌    | 24001/43416 [1:22:16<1:04:36,  5.01it/s]

{'loss': 0.0301, 'learning_rate': 8.944168048645662e-06, 'epoch': 4.42}


 56%|█████▌    | 24101/43416 [1:22:36<1:04:48,  4.97it/s]

{'loss': 0.0284, 'learning_rate': 8.898102082181684e-06, 'epoch': 4.44}


 56%|█████▌    | 24201/43416 [1:22:56<1:03:32,  5.04it/s]

{'loss': 0.0317, 'learning_rate': 8.852036115717709e-06, 'epoch': 4.46}


 56%|█████▌    | 24301/43416 [1:23:16<1:02:37,  5.09it/s]

{'loss': 0.0274, 'learning_rate': 8.805970149253732e-06, 'epoch': 4.48}


 56%|█████▌    | 24401/43416 [1:23:35<1:04:06,  4.94it/s]

{'loss': 0.0299, 'learning_rate': 8.759904182789756e-06, 'epoch': 4.5}


 56%|█████▋    | 24501/43416 [1:23:55<1:03:05,  5.00it/s]

{'loss': 0.0312, 'learning_rate': 8.713838216325779e-06, 'epoch': 4.51}


 57%|█████▋    | 24601/43416 [1:24:16<1:03:58,  4.90it/s]

{'loss': 0.0327, 'learning_rate': 8.667772249861803e-06, 'epoch': 4.53}


 57%|█████▋    | 24701/43416 [1:24:36<1:02:47,  4.97it/s]

{'loss': 0.0288, 'learning_rate': 8.621706283397826e-06, 'epoch': 4.55}


 57%|█████▋    | 24801/43416 [1:24:55<1:01:42,  5.03it/s]

{'loss': 0.0257, 'learning_rate': 8.57564031693385e-06, 'epoch': 4.57}


 57%|█████▋    | 24901/43416 [1:25:15<1:01:11,  5.04it/s]

{'loss': 0.0339, 'learning_rate': 8.529574350469873e-06, 'epoch': 4.59}


 58%|█████▊    | 25000/43416 [1:25:35<1:00:29,  5.07it/s]

{'loss': 0.0309, 'learning_rate': 8.483508384005898e-06, 'epoch': 4.61}


 58%|█████▊    | 25101/43416 [1:25:55<1:00:53,  5.01it/s]

{'loss': 0.0338, 'learning_rate': 8.437442417541921e-06, 'epoch': 4.63}


 58%|█████▊    | 25201/43416 [1:26:15<1:00:21,  5.03it/s]

{'loss': 0.029, 'learning_rate': 8.391376451077943e-06, 'epoch': 4.64}


 58%|█████▊    | 25301/43416 [1:26:35<1:00:12,  5.01it/s]

{'loss': 0.0314, 'learning_rate': 8.345310484613968e-06, 'epoch': 4.66}


 59%|█████▊    | 25400/43416 [1:26:54<59:18,  5.06it/s]  

{'loss': 0.0303, 'learning_rate': 8.299244518149992e-06, 'epoch': 4.68}


 59%|█████▊    | 25501/43416 [1:27:15<59:02,  5.06it/s]  

{'loss': 0.0326, 'learning_rate': 8.253178551686015e-06, 'epoch': 4.7}


 59%|█████▉    | 25601/43416 [1:27:34<59:44,  4.97it/s]  

{'loss': 0.0285, 'learning_rate': 8.207112585222039e-06, 'epoch': 4.72}


 59%|█████▉    | 25700/43416 [1:27:54<58:43,  5.03it/s]  

{'loss': 0.0285, 'learning_rate': 8.161046618758062e-06, 'epoch': 4.74}


 59%|█████▉    | 25801/43416 [1:28:14<58:58,  4.98it/s]  

{'loss': 0.0339, 'learning_rate': 8.114980652294086e-06, 'epoch': 4.75}


 60%|█████▉    | 25901/43416 [1:28:34<57:53,  5.04it/s]

{'loss': 0.0317, 'learning_rate': 8.068914685830109e-06, 'epoch': 4.77}


 60%|█████▉    | 26000/43416 [1:28:54<57:09,  5.08it/s]

{'loss': 0.0301, 'learning_rate': 8.022848719366132e-06, 'epoch': 4.79}


 60%|██████    | 26101/43416 [1:29:14<57:07,  5.05it/s]

{'loss': 0.0305, 'learning_rate': 7.976782752902158e-06, 'epoch': 4.81}


 60%|██████    | 26201/43416 [1:29:34<57:20,  5.00it/s]  

{'loss': 0.0321, 'learning_rate': 7.930716786438181e-06, 'epoch': 4.83}


 61%|██████    | 26301/43416 [1:29:54<56:48,  5.02it/s]

{'loss': 0.0311, 'learning_rate': 7.884650819974203e-06, 'epoch': 4.85}


 61%|██████    | 26401/43416 [1:30:14<57:21,  4.94it/s]

{'loss': 0.028, 'learning_rate': 7.838584853510228e-06, 'epoch': 4.86}


 61%|██████    | 26501/43416 [1:30:34<56:18,  5.01it/s]

{'loss': 0.0322, 'learning_rate': 7.792518887046251e-06, 'epoch': 4.88}


 61%|██████▏   | 26601/43416 [1:30:54<55:43,  5.03it/s]

{'loss': 0.0284, 'learning_rate': 7.746452920582275e-06, 'epoch': 4.9}


 62%|██████▏   | 26701/43416 [1:31:14<55:08,  5.05it/s]

{'loss': 0.0278, 'learning_rate': 7.700386954118298e-06, 'epoch': 4.92}


 62%|██████▏   | 26801/43416 [1:31:34<55:10,  5.02it/s]

{'loss': 0.0277, 'learning_rate': 7.654320987654322e-06, 'epoch': 4.94}


 62%|██████▏   | 26901/43416 [1:31:53<54:38,  5.04it/s]

{'loss': 0.0314, 'learning_rate': 7.608255021190345e-06, 'epoch': 4.96}


 62%|██████▏   | 27001/43416 [1:32:13<54:03,  5.06it/s]

{'loss': 0.0259, 'learning_rate': 7.5621890547263685e-06, 'epoch': 4.98}


 62%|██████▏   | 27101/43416 [1:32:33<54:20,  5.00it/s]

{'loss': 0.0326, 'learning_rate': 7.516123088262393e-06, 'epoch': 4.99}


 62%|██████▎   | 27135/43416 [1:32:40<49:58,  5.43it/s]
 62%|██████▎   | 27135/43416 [1:33:14<49:58,  5.43it/s]

{'eval_loss': 0.11806705594062805, 'eval_subset_accuracy': 0.4489495023958717, 'eval_micro_f1': 0.5688353413654619, 'eval_macro_f1': 0.5041805800073187, 'eval_runtime': 33.8216, 'eval_samples_per_second': 160.43, 'eval_steps_per_second': 20.076, 'epoch': 5.0}


 63%|██████▎   | 27201/43416 [1:33:32<54:28,  4.96it/s]   

{'loss': 0.0235, 'learning_rate': 7.470057121798416e-06, 'epoch': 5.01}


 63%|██████▎   | 27301/43416 [1:33:52<53:25,  5.03it/s]

{'loss': 0.0215, 'learning_rate': 7.423991155334439e-06, 'epoch': 5.03}


 63%|██████▎   | 27401/43416 [1:34:12<52:58,  5.04it/s]

{'loss': 0.0189, 'learning_rate': 7.377925188870463e-06, 'epoch': 5.05}


 63%|██████▎   | 27501/43416 [1:34:31<52:37,  5.04it/s]

{'loss': 0.0198, 'learning_rate': 7.331859222406487e-06, 'epoch': 5.07}


 64%|██████▎   | 27601/43416 [1:34:51<52:30,  5.02it/s]

{'loss': 0.0221, 'learning_rate': 7.28579325594251e-06, 'epoch': 5.09}


 64%|██████▍   | 27700/43416 [1:35:11<52:04,  5.03it/s]

{'loss': 0.0213, 'learning_rate': 7.239727289478534e-06, 'epoch': 5.1}


 64%|██████▍   | 27801/43416 [1:35:31<52:06,  4.99it/s]

{'loss': 0.0217, 'learning_rate': 7.193661323014557e-06, 'epoch': 5.12}


 64%|██████▍   | 27900/43416 [1:35:51<51:29,  5.02it/s]

{'loss': 0.0182, 'learning_rate': 7.14759535655058e-06, 'epoch': 5.14}


 64%|██████▍   | 28001/43416 [1:36:11<51:18,  5.01it/s]

{'loss': 0.0231, 'learning_rate': 7.101529390086605e-06, 'epoch': 5.16}


 65%|██████▍   | 28101/43416 [1:36:31<51:37,  4.94it/s]

{'loss': 0.0224, 'learning_rate': 7.055463423622628e-06, 'epoch': 5.18}


 65%|██████▍   | 28201/43416 [1:36:51<51:14,  4.95it/s]

{'loss': 0.0206, 'learning_rate': 7.009397457158652e-06, 'epoch': 5.2}


 65%|██████▌   | 28300/43416 [1:37:10<49:44,  5.07it/s]

{'loss': 0.0192, 'learning_rate': 6.963331490694676e-06, 'epoch': 5.21}


 65%|██████▌   | 28401/43416 [1:37:30<49:34,  5.05it/s]

{'loss': 0.022, 'learning_rate': 6.9172655242306984e-06, 'epoch': 5.23}


 66%|██████▌   | 28500/43416 [1:37:50<50:50,  4.89it/s]

{'loss': 0.0199, 'learning_rate': 6.871199557766723e-06, 'epoch': 5.25}


 66%|██████▌   | 28601/43416 [1:38:10<48:39,  5.07it/s]

{'loss': 0.0196, 'learning_rate': 6.825133591302746e-06, 'epoch': 5.27}


 66%|██████▌   | 28700/43416 [1:38:30<48:34,  5.05it/s]

{'loss': 0.0208, 'learning_rate': 6.77906762483877e-06, 'epoch': 5.29}


 66%|██████▋   | 28801/43416 [1:38:50<48:23,  5.03it/s]

{'loss': 0.0212, 'learning_rate': 6.733001658374794e-06, 'epoch': 5.31}


 67%|██████▋   | 28901/43416 [1:39:10<47:40,  5.07it/s]

{'loss': 0.0203, 'learning_rate': 6.6869356919108165e-06, 'epoch': 5.33}


 67%|██████▋   | 29001/43416 [1:39:30<47:54,  5.01it/s]

{'loss': 0.0196, 'learning_rate': 6.64086972544684e-06, 'epoch': 5.34}


 67%|██████▋   | 29101/43416 [1:39:50<48:27,  4.92it/s]

{'loss': 0.0207, 'learning_rate': 6.594803758982864e-06, 'epoch': 5.36}


 67%|██████▋   | 29201/43416 [1:40:10<47:14,  5.01it/s]

{'loss': 0.0181, 'learning_rate': 6.548737792518888e-06, 'epoch': 5.38}


 67%|██████▋   | 29301/43416 [1:40:30<47:00,  5.00it/s]

{'loss': 0.0214, 'learning_rate': 6.502671826054911e-06, 'epoch': 5.4}


 68%|██████▊   | 29401/43416 [1:40:50<46:34,  5.02it/s]

{'loss': 0.0198, 'learning_rate': 6.4566058595909346e-06, 'epoch': 5.42}


 68%|██████▊   | 29501/43416 [1:41:09<45:40,  5.08it/s]

{'loss': 0.0207, 'learning_rate': 6.410539893126958e-06, 'epoch': 5.44}


 68%|██████▊   | 29601/43416 [1:41:29<46:31,  4.95it/s]

{'loss': 0.0197, 'learning_rate': 6.3644739266629815e-06, 'epoch': 5.45}


 68%|██████▊   | 29700/43416 [1:41:49<45:29,  5.03it/s]

{'loss': 0.0204, 'learning_rate': 6.318407960199006e-06, 'epoch': 5.47}


 69%|██████▊   | 29801/43416 [1:42:09<45:12,  5.02it/s]

{'loss': 0.0219, 'learning_rate': 6.272341993735029e-06, 'epoch': 5.49}


 69%|██████▉   | 29901/43416 [1:42:29<44:55,  5.01it/s]

{'loss': 0.0202, 'learning_rate': 6.2262760272710535e-06, 'epoch': 5.51}


 69%|██████▉   | 30001/43416 [1:42:49<45:05,  4.96it/s]

{'loss': 0.02, 'learning_rate': 6.180210060807076e-06, 'epoch': 5.53}


 69%|██████▉   | 30101/43416 [1:43:09<43:40,  5.08it/s]

{'loss': 0.0187, 'learning_rate': 6.1341440943430995e-06, 'epoch': 5.55}


 70%|██████▉   | 30201/43416 [1:43:29<44:24,  4.96it/s]

{'loss': 0.022, 'learning_rate': 6.088078127879124e-06, 'epoch': 5.56}


 70%|██████▉   | 30301/43416 [1:43:49<44:44,  4.88it/s]

{'loss': 0.0197, 'learning_rate': 6.042012161415147e-06, 'epoch': 5.58}


 70%|███████   | 30401/43416 [1:44:09<43:05,  5.03it/s]

{'loss': 0.0165, 'learning_rate': 5.995946194951171e-06, 'epoch': 5.6}


 70%|███████   | 30501/43416 [1:44:29<44:17,  4.86it/s]

{'loss': 0.019, 'learning_rate': 5.949880228487194e-06, 'epoch': 5.62}


 70%|███████   | 30601/43416 [1:44:49<42:53,  4.98it/s]

{'loss': 0.0195, 'learning_rate': 5.9038142620232176e-06, 'epoch': 5.64}


 71%|███████   | 30701/43416 [1:45:09<42:03,  5.04it/s]

{'loss': 0.0178, 'learning_rate': 5.857748295559241e-06, 'epoch': 5.66}


 71%|███████   | 30801/43416 [1:45:29<42:04,  5.00it/s]

{'loss': 0.0199, 'learning_rate': 5.811682329095265e-06, 'epoch': 5.68}


 71%|███████   | 30901/43416 [1:45:49<41:24,  5.04it/s]

{'loss': 0.0184, 'learning_rate': 5.765616362631289e-06, 'epoch': 5.69}


 71%|███████▏  | 31001/43416 [1:46:09<41:01,  5.04it/s]

{'loss': 0.0198, 'learning_rate': 5.719550396167311e-06, 'epoch': 5.71}


 72%|███████▏  | 31101/43416 [1:46:29<40:52,  5.02it/s]

{'loss': 0.0181, 'learning_rate': 5.673484429703336e-06, 'epoch': 5.73}


 72%|███████▏  | 31201/43416 [1:46:49<40:45,  5.00it/s]

{'loss': 0.0174, 'learning_rate': 5.627418463239359e-06, 'epoch': 5.75}


 72%|███████▏  | 31301/43416 [1:47:09<40:20,  5.01it/s]

{'loss': 0.0223, 'learning_rate': 5.5813524967753825e-06, 'epoch': 5.77}


 72%|███████▏  | 31401/43416 [1:47:28<39:54,  5.02it/s]

{'loss': 0.0198, 'learning_rate': 5.535286530311407e-06, 'epoch': 5.79}


 73%|███████▎  | 31501/43416 [1:47:48<39:30,  5.03it/s]

{'loss': 0.0204, 'learning_rate': 5.48922056384743e-06, 'epoch': 5.8}


 73%|███████▎  | 31600/43416 [1:48:08<39:23,  5.00it/s]

{'loss': 0.0189, 'learning_rate': 5.443154597383453e-06, 'epoch': 5.82}


 73%|███████▎  | 31701/43416 [1:48:28<39:04,  5.00it/s]

{'loss': 0.0206, 'learning_rate': 5.397088630919477e-06, 'epoch': 5.84}


 73%|███████▎  | 31801/43416 [1:48:48<38:34,  5.02it/s]

{'loss': 0.0214, 'learning_rate': 5.351022664455501e-06, 'epoch': 5.86}


 73%|███████▎  | 31900/43416 [1:49:08<37:45,  5.08it/s]

{'loss': 0.0167, 'learning_rate': 5.304956697991525e-06, 'epoch': 5.88}


 74%|███████▎  | 32001/43416 [1:49:28<38:06,  4.99it/s]

{'loss': 0.022, 'learning_rate': 5.258890731527548e-06, 'epoch': 5.9}


 74%|███████▍  | 32100/43416 [1:49:48<37:45,  5.00it/s]

{'loss': 0.0214, 'learning_rate': 5.212824765063571e-06, 'epoch': 5.91}


 74%|███████▍  | 32201/43416 [1:50:08<37:09,  5.03it/s]

{'loss': 0.0197, 'learning_rate': 5.166758798599595e-06, 'epoch': 5.93}


 74%|███████▍  | 32301/43416 [1:50:28<36:59,  5.01it/s]

{'loss': 0.0194, 'learning_rate': 5.120692832135619e-06, 'epoch': 5.95}


 75%|███████▍  | 32400/43416 [1:50:47<36:37,  5.01it/s]

{'loss': 0.0201, 'learning_rate': 5.074626865671642e-06, 'epoch': 5.97}


 75%|███████▍  | 32500/43416 [1:51:07<35:44,  5.09it/s]

{'loss': 0.019, 'learning_rate': 5.028560899207666e-06, 'epoch': 5.99}


 75%|███████▌  | 32562/43416 [1:51:19<32:21,  5.59it/s]
 75%|███████▌  | 32562/43416 [1:51:53<32:21,  5.59it/s]

{'eval_loss': 0.13296400010585785, 'eval_subset_accuracy': 0.4375230372281607, 'eval_micro_f1': 0.5709591613206071, 'eval_macro_f1': 0.5136715490700043, 'eval_runtime': 33.9806, 'eval_samples_per_second': 159.679, 'eval_steps_per_second': 19.982, 'epoch': 6.0}


 75%|███████▌  | 32600/43416 [1:52:06<35:44,  5.04it/s]   

{'loss': 0.0165, 'learning_rate': 4.98249493274369e-06, 'epoch': 6.01}


 75%|███████▌  | 32700/43416 [1:52:26<35:27,  5.04it/s]

{'loss': 0.0149, 'learning_rate': 4.936428966279712e-06, 'epoch': 6.03}


 76%|███████▌  | 32801/43416 [1:52:46<35:14,  5.02it/s]

{'loss': 0.0133, 'learning_rate': 4.890362999815737e-06, 'epoch': 6.04}


 76%|███████▌  | 32901/43416 [1:53:06<34:43,  5.05it/s]

{'loss': 0.0113, 'learning_rate': 4.84429703335176e-06, 'epoch': 6.06}


 76%|███████▌  | 33001/43416 [1:53:26<35:26,  4.90it/s]

{'loss': 0.0138, 'learning_rate': 4.798231066887784e-06, 'epoch': 6.08}


 76%|███████▌  | 33100/43416 [1:53:45<34:13,  5.02it/s]

{'loss': 0.012, 'learning_rate': 4.752165100423807e-06, 'epoch': 6.1}


 76%|███████▋  | 33201/43416 [1:54:06<34:05,  4.99it/s]

{'loss': 0.0124, 'learning_rate': 4.706099133959831e-06, 'epoch': 6.12}


 77%|███████▋  | 33301/43416 [1:54:26<33:44,  5.00it/s]

{'loss': 0.0133, 'learning_rate': 4.660033167495854e-06, 'epoch': 6.14}


 77%|███████▋  | 33401/43416 [1:54:46<33:19,  5.01it/s]

{'loss': 0.0161, 'learning_rate': 4.613967201031878e-06, 'epoch': 6.15}


 77%|███████▋  | 33501/43416 [1:55:05<32:35,  5.07it/s]

{'loss': 0.0145, 'learning_rate': 4.567901234567902e-06, 'epoch': 6.17}


 77%|███████▋  | 33601/43416 [1:55:25<32:55,  4.97it/s]

{'loss': 0.0127, 'learning_rate': 4.521835268103925e-06, 'epoch': 6.19}


 78%|███████▊  | 33701/43416 [1:55:45<32:17,  5.01it/s]

{'loss': 0.0122, 'learning_rate': 4.4757693016399485e-06, 'epoch': 6.21}


 78%|███████▊  | 33801/43416 [1:56:05<31:47,  5.04it/s]

{'loss': 0.0115, 'learning_rate': 4.429703335175972e-06, 'epoch': 6.23}


 78%|███████▊  | 33901/43416 [1:56:25<31:22,  5.06it/s]

{'loss': 0.0135, 'learning_rate': 4.383637368711996e-06, 'epoch': 6.25}


 78%|███████▊  | 34001/43416 [1:56:45<31:36,  4.97it/s]

{'loss': 0.0125, 'learning_rate': 4.33757140224802e-06, 'epoch': 6.26}


 79%|███████▊  | 34101/43416 [1:57:05<30:48,  5.04it/s]

{'loss': 0.0124, 'learning_rate': 4.291505435784043e-06, 'epoch': 6.28}


 79%|███████▉  | 34201/43416 [1:57:25<31:39,  4.85it/s]

{'loss': 0.0114, 'learning_rate': 4.245439469320067e-06, 'epoch': 6.3}


 79%|███████▉  | 34300/43416 [1:57:44<30:21,  5.01it/s]

{'loss': 0.0148, 'learning_rate': 4.19937350285609e-06, 'epoch': 6.32}


 79%|███████▉  | 34401/43416 [1:58:05<30:04,  5.00it/s]

{'loss': 0.0145, 'learning_rate': 4.1533075363921135e-06, 'epoch': 6.34}


 79%|███████▉  | 34501/43416 [1:58:24<29:34,  5.02it/s]

{'loss': 0.0139, 'learning_rate': 4.107241569928138e-06, 'epoch': 6.36}


 80%|███████▉  | 34600/43416 [1:58:44<29:17,  5.02it/s]

{'loss': 0.0147, 'learning_rate': 4.061175603464161e-06, 'epoch': 6.38}


 80%|███████▉  | 34701/43416 [1:59:04<28:43,  5.06it/s]

{'loss': 0.0122, 'learning_rate': 4.015109637000185e-06, 'epoch': 6.39}


 80%|████████  | 34801/43416 [1:59:24<29:20,  4.89it/s]

{'loss': 0.0156, 'learning_rate': 3.969043670536208e-06, 'epoch': 6.41}


 80%|████████  | 34901/43416 [1:59:44<28:13,  5.03it/s]

{'loss': 0.0133, 'learning_rate': 3.9229777040722316e-06, 'epoch': 6.43}


 81%|████████  | 35001/43416 [2:00:04<28:05,  4.99it/s]

{'loss': 0.0128, 'learning_rate': 3.876911737608256e-06, 'epoch': 6.45}


 81%|████████  | 35101/43416 [2:00:24<27:44,  5.00it/s]

{'loss': 0.0128, 'learning_rate': 3.8308457711442784e-06, 'epoch': 6.47}


 81%|████████  | 35200/43416 [2:00:44<27:34,  4.96it/s]

{'loss': 0.0135, 'learning_rate': 3.7847798046803023e-06, 'epoch': 6.49}


 81%|████████▏ | 35300/43416 [2:01:04<26:40,  5.07it/s]

{'loss': 0.0142, 'learning_rate': 3.738713838216326e-06, 'epoch': 6.5}


 82%|████████▏ | 35401/43416 [2:01:24<26:30,  5.04it/s]

{'loss': 0.0116, 'learning_rate': 3.6926478717523496e-06, 'epoch': 6.52}


 82%|████████▏ | 35500/43416 [2:01:43<26:03,  5.06it/s]

{'loss': 0.0156, 'learning_rate': 3.6465819052883735e-06, 'epoch': 6.54}


 82%|████████▏ | 35601/43416 [2:02:04<25:44,  5.06it/s]

{'loss': 0.0124, 'learning_rate': 3.600515938824397e-06, 'epoch': 6.56}


 82%|████████▏ | 35700/43416 [2:02:23<25:32,  5.03it/s]

{'loss': 0.013, 'learning_rate': 3.5544499723604204e-06, 'epoch': 6.58}


 82%|████████▏ | 35801/43416 [2:02:43<25:21,  5.01it/s]

{'loss': 0.0132, 'learning_rate': 3.5083840058964442e-06, 'epoch': 6.6}


 83%|████████▎ | 35900/43416 [2:03:03<24:40,  5.08it/s]

{'loss': 0.0135, 'learning_rate': 3.4623180394324673e-06, 'epoch': 6.62}


 83%|████████▎ | 36000/43416 [2:03:23<25:24,  4.87it/s]

{'loss': 0.0152, 'learning_rate': 3.416252072968491e-06, 'epoch': 6.63}


 83%|████████▎ | 36101/43416 [2:03:43<24:53,  4.90it/s]

{'loss': 0.0131, 'learning_rate': 3.370186106504515e-06, 'epoch': 6.65}


 83%|████████▎ | 36201/43416 [2:04:03<23:57,  5.02it/s]

{'loss': 0.0156, 'learning_rate': 3.324120140040538e-06, 'epoch': 6.67}


 84%|████████▎ | 36301/43416 [2:04:23<23:48,  4.98it/s]

{'loss': 0.0137, 'learning_rate': 3.278054173576562e-06, 'epoch': 6.69}


 84%|████████▍ | 36401/43416 [2:04:43<23:23,  5.00it/s]

{'loss': 0.0123, 'learning_rate': 3.2319882071125857e-06, 'epoch': 6.71}


 84%|████████▍ | 36501/43416 [2:05:03<22:57,  5.02it/s]

{'loss': 0.0122, 'learning_rate': 3.185922240648609e-06, 'epoch': 6.73}


 84%|████████▍ | 36601/43416 [2:05:23<22:25,  5.07it/s]

{'loss': 0.0148, 'learning_rate': 3.1398562741846326e-06, 'epoch': 6.74}


 85%|████████▍ | 36701/43416 [2:05:43<22:16,  5.02it/s]

{'loss': 0.0129, 'learning_rate': 3.093790307720656e-06, 'epoch': 6.76}


 85%|████████▍ | 36801/43416 [2:06:03<22:06,  4.99it/s]

{'loss': 0.0135, 'learning_rate': 3.04772434125668e-06, 'epoch': 6.78}


 85%|████████▍ | 36901/43416 [2:06:23<21:43,  5.00it/s]

{'loss': 0.0121, 'learning_rate': 3.001658374792704e-06, 'epoch': 6.8}


 85%|████████▌ | 37000/43416 [2:06:42<21:27,  4.98it/s]

{'loss': 0.013, 'learning_rate': 2.955592408328727e-06, 'epoch': 6.82}


 85%|████████▌ | 37101/43416 [2:07:02<20:52,  5.04it/s]

{'loss': 0.0128, 'learning_rate': 2.9095264418647507e-06, 'epoch': 6.84}


 86%|████████▌ | 37201/43416 [2:07:22<20:35,  5.03it/s]

{'loss': 0.012, 'learning_rate': 2.8634604754007737e-06, 'epoch': 6.85}


 86%|████████▌ | 37301/43416 [2:07:42<20:21,  5.00it/s]

{'loss': 0.0117, 'learning_rate': 2.8173945089367976e-06, 'epoch': 6.87}


 86%|████████▌ | 37401/43416 [2:08:02<20:01,  5.01it/s]

{'loss': 0.0149, 'learning_rate': 2.7713285424728214e-06, 'epoch': 6.89}


 86%|████████▋ | 37501/43416 [2:08:22<19:31,  5.05it/s]

{'loss': 0.0129, 'learning_rate': 2.725262576008845e-06, 'epoch': 6.91}


 87%|████████▋ | 37601/43416 [2:08:42<19:23,  5.00it/s]

{'loss': 0.0128, 'learning_rate': 2.6791966095448683e-06, 'epoch': 6.93}


 87%|████████▋ | 37701/43416 [2:09:02<18:55,  5.03it/s]

{'loss': 0.0116, 'learning_rate': 2.633130643080892e-06, 'epoch': 6.95}


 87%|████████▋ | 37801/43416 [2:09:21<18:28,  5.07it/s]

{'loss': 0.0127, 'learning_rate': 2.5870646766169156e-06, 'epoch': 6.97}


 87%|████████▋ | 37901/43416 [2:09:41<18:27,  4.98it/s]

{'loss': 0.0109, 'learning_rate': 2.5409987101529395e-06, 'epoch': 6.98}


 88%|████████▊ | 37989/43416 [2:09:59<16:39,  5.43it/s]
 88%|████████▊ | 37989/43416 [2:10:33<16:39,  5.43it/s]

{'eval_loss': 0.14610277116298676, 'eval_subset_accuracy': 0.44452635458901585, 'eval_micro_f1': 0.5804109377413873, 'eval_macro_f1': 0.5121589174387003, 'eval_runtime': 33.8573, 'eval_samples_per_second': 160.261, 'eval_steps_per_second': 20.055, 'epoch': 7.0}


 88%|████████▊ | 38001/43416 [2:10:40<38:57,  2.32it/s]   

{'loss': 0.0103, 'learning_rate': 2.494932743688963e-06, 'epoch': 7.0}


 88%|████████▊ | 38101/43416 [2:11:00<17:35,  5.03it/s]

{'loss': 0.0077, 'learning_rate': 2.4488667772249864e-06, 'epoch': 7.02}


 88%|████████▊ | 38201/43416 [2:11:20<18:08,  4.79it/s]

{'loss': 0.0106, 'learning_rate': 2.40280081076101e-06, 'epoch': 7.04}


 88%|████████▊ | 38301/43416 [2:11:41<17:20,  4.92it/s]

{'loss': 0.0103, 'learning_rate': 2.3567348442970337e-06, 'epoch': 7.06}


 88%|████████▊ | 38401/43416 [2:12:01<16:44,  4.99it/s]

{'loss': 0.0087, 'learning_rate': 2.310668877833057e-06, 'epoch': 7.08}


 89%|████████▊ | 38501/43416 [2:12:21<16:16,  5.03it/s]

{'loss': 0.0085, 'learning_rate': 2.2646029113690806e-06, 'epoch': 7.09}


 89%|████████▉ | 38601/43416 [2:12:42<16:21,  4.91it/s]

{'loss': 0.009, 'learning_rate': 2.218536944905104e-06, 'epoch': 7.11}


 89%|████████▉ | 38701/43416 [2:13:02<16:31,  4.76it/s]

{'loss': 0.0084, 'learning_rate': 2.172470978441128e-06, 'epoch': 7.13}


 89%|████████▉ | 38800/43416 [2:13:22<16:16,  4.73it/s]

{'loss': 0.0092, 'learning_rate': 2.1264050119771513e-06, 'epoch': 7.15}


 90%|████████▉ | 38900/43416 [2:13:43<15:08,  4.97it/s]

{'loss': 0.0091, 'learning_rate': 2.080339045513175e-06, 'epoch': 7.17}


 90%|████████▉ | 39001/43416 [2:14:03<14:46,  4.98it/s]

{'loss': 0.0084, 'learning_rate': 2.0342730790491987e-06, 'epoch': 7.19}


 90%|█████████ | 39100/43416 [2:14:24<15:28,  4.65it/s]

{'loss': 0.0091, 'learning_rate': 1.9882071125852225e-06, 'epoch': 7.2}


 90%|█████████ | 39201/43416 [2:14:44<14:40,  4.79it/s]

{'loss': 0.0092, 'learning_rate': 1.942141146121246e-06, 'epoch': 7.22}


 91%|█████████ | 39300/43416 [2:15:05<13:53,  4.94it/s]

{'loss': 0.0072, 'learning_rate': 1.8960751796572694e-06, 'epoch': 7.24}


 91%|█████████ | 39401/43416 [2:15:25<13:37,  4.91it/s]

{'loss': 0.0091, 'learning_rate': 1.8500092131932929e-06, 'epoch': 7.26}


 91%|█████████ | 39501/43416 [2:15:46<13:00,  5.01it/s]

{'loss': 0.0082, 'learning_rate': 1.8039432467293167e-06, 'epoch': 7.28}


 91%|█████████ | 39601/43416 [2:16:05<12:36,  5.05it/s]

{'loss': 0.0085, 'learning_rate': 1.7578772802653402e-06, 'epoch': 7.3}


 91%|█████████▏| 39700/43416 [2:16:26<12:25,  4.99it/s]

{'loss': 0.0091, 'learning_rate': 1.7118113138013636e-06, 'epoch': 7.32}


 92%|█████████▏| 39801/43416 [2:16:46<12:01,  5.01it/s]

{'loss': 0.0099, 'learning_rate': 1.6657453473373873e-06, 'epoch': 7.33}


 92%|█████████▏| 39901/43416 [2:17:06<11:38,  5.03it/s]

{'loss': 0.0099, 'learning_rate': 1.619679380873411e-06, 'epoch': 7.35}


 92%|█████████▏| 40001/43416 [2:17:26<11:26,  4.98it/s]

{'loss': 0.0084, 'learning_rate': 1.5736134144094346e-06, 'epoch': 7.37}


 92%|█████████▏| 40101/43416 [2:17:46<11:02,  5.01it/s]

{'loss': 0.0095, 'learning_rate': 1.527547447945458e-06, 'epoch': 7.39}


 93%|█████████▎| 40201/43416 [2:18:06<10:39,  5.03it/s]

{'loss': 0.0081, 'learning_rate': 1.4814814814814815e-06, 'epoch': 7.41}


 93%|█████████▎| 40301/43416 [2:18:26<10:19,  5.03it/s]

{'loss': 0.0088, 'learning_rate': 1.4354155150175053e-06, 'epoch': 7.43}


 93%|█████████▎| 40401/43416 [2:18:46<09:58,  5.03it/s]

{'loss': 0.0077, 'learning_rate': 1.3893495485535288e-06, 'epoch': 7.44}


 93%|█████████▎| 40501/43416 [2:19:06<09:36,  5.05it/s]

{'loss': 0.0092, 'learning_rate': 1.3432835820895524e-06, 'epoch': 7.46}


 94%|█████████▎| 40601/43416 [2:19:25<09:22,  5.01it/s]

{'loss': 0.0083, 'learning_rate': 1.2972176156255759e-06, 'epoch': 7.48}


 94%|█████████▎| 40701/43416 [2:19:45<09:04,  4.99it/s]

{'loss': 0.0096, 'learning_rate': 1.2511516491615997e-06, 'epoch': 7.5}


 94%|█████████▍| 40801/43416 [2:20:05<08:40,  5.03it/s]

{'loss': 0.0078, 'learning_rate': 1.2050856826976232e-06, 'epoch': 7.52}


 94%|█████████▍| 40901/43416 [2:20:25<08:20,  5.02it/s]

{'loss': 0.0092, 'learning_rate': 1.1590197162336466e-06, 'epoch': 7.54}


 94%|█████████▍| 41001/43416 [2:20:45<07:59,  5.04it/s]

{'loss': 0.0083, 'learning_rate': 1.1129537497696703e-06, 'epoch': 7.55}


 95%|█████████▍| 41101/43416 [2:21:05<07:38,  5.04it/s]

{'loss': 0.0089, 'learning_rate': 1.066887783305694e-06, 'epoch': 7.57}


 95%|█████████▍| 41201/43416 [2:21:25<07:21,  5.02it/s]

{'loss': 0.0089, 'learning_rate': 1.0208218168417174e-06, 'epoch': 7.59}


 95%|█████████▌| 41301/43416 [2:21:45<07:09,  4.93it/s]

{'loss': 0.0079, 'learning_rate': 9.74755850377741e-07, 'epoch': 7.61}


 95%|█████████▌| 41401/43416 [2:22:05<06:42,  5.00it/s]

{'loss': 0.0089, 'learning_rate': 9.286898839137646e-07, 'epoch': 7.63}


 96%|█████████▌| 41501/43416 [2:22:25<06:22,  5.01it/s]

{'loss': 0.0085, 'learning_rate': 8.826239174497882e-07, 'epoch': 7.65}


 96%|█████████▌| 41601/43416 [2:22:44<06:00,  5.03it/s]

{'loss': 0.0079, 'learning_rate': 8.365579509858117e-07, 'epoch': 7.67}


 96%|█████████▌| 41701/43416 [2:23:04<05:43,  4.99it/s]

{'loss': 0.01, 'learning_rate': 7.904919845218353e-07, 'epoch': 7.68}


 96%|█████████▋| 41801/43416 [2:23:24<05:27,  4.93it/s]

{'loss': 0.0085, 'learning_rate': 7.444260180578589e-07, 'epoch': 7.7}


 97%|█████████▋| 41901/43416 [2:23:44<05:03,  5.00it/s]

{'loss': 0.0087, 'learning_rate': 6.983600515938825e-07, 'epoch': 7.72}


 97%|█████████▋| 42001/43416 [2:24:04<04:42,  5.01it/s]

{'loss': 0.0087, 'learning_rate': 6.522940851299061e-07, 'epoch': 7.74}


 97%|█████████▋| 42100/43416 [2:24:24<04:18,  5.09it/s]

{'loss': 0.0088, 'learning_rate': 6.062281186659296e-07, 'epoch': 7.76}


 97%|█████████▋| 42201/43416 [2:24:44<04:04,  4.97it/s]

{'loss': 0.007, 'learning_rate': 5.601621522019533e-07, 'epoch': 7.78}


 97%|█████████▋| 42301/43416 [2:25:04<03:41,  5.03it/s]

{'loss': 0.0097, 'learning_rate': 5.140961857379768e-07, 'epoch': 7.79}


 98%|█████████▊| 42401/43416 [2:25:24<03:22,  5.01it/s]

{'loss': 0.0076, 'learning_rate': 4.6803021927400043e-07, 'epoch': 7.81}


 98%|█████████▊| 42501/43416 [2:25:44<03:01,  5.04it/s]

{'loss': 0.0097, 'learning_rate': 4.2196425281002403e-07, 'epoch': 7.83}


 98%|█████████▊| 42601/43416 [2:26:03<02:40,  5.06it/s]

{'loss': 0.0075, 'learning_rate': 3.758982863460476e-07, 'epoch': 7.85}


 98%|█████████▊| 42701/43416 [2:26:23<02:23,  4.98it/s]

{'loss': 0.01, 'learning_rate': 3.298323198820712e-07, 'epoch': 7.87}


 99%|█████████▊| 42801/43416 [2:26:43<02:02,  5.02it/s]

{'loss': 0.0072, 'learning_rate': 2.8376635341809473e-07, 'epoch': 7.89}


 99%|█████████▉| 42901/43416 [2:27:03<01:41,  5.06it/s]

{'loss': 0.0087, 'learning_rate': 2.377003869541183e-07, 'epoch': 7.9}


 99%|█████████▉| 43001/43416 [2:27:23<01:22,  5.01it/s]

{'loss': 0.0087, 'learning_rate': 1.9163442049014188e-07, 'epoch': 7.92}


 99%|█████████▉| 43101/43416 [2:27:43<01:03,  4.98it/s]

{'loss': 0.0096, 'learning_rate': 1.4556845402616549e-07, 'epoch': 7.94}


100%|█████████▉| 43201/43416 [2:28:03<00:42,  5.01it/s]

{'loss': 0.0092, 'learning_rate': 9.950248756218906e-08, 'epoch': 7.96}


100%|█████████▉| 43300/43416 [2:28:23<00:23,  4.98it/s]

{'loss': 0.0093, 'learning_rate': 5.343652109821264e-08, 'epoch': 7.98}


100%|█████████▉| 43400/43416 [2:28:43<00:03,  5.01it/s]

{'loss': 0.0106, 'learning_rate': 7.370554634236227e-09, 'epoch': 8.0}


100%|██████████| 43416/43416 [2:28:46<00:00,  5.59it/s]
100%|██████████| 43416/43416 [2:29:20<00:00,  5.59it/s]

{'eval_loss': 0.15499846637248993, 'eval_subset_accuracy': 0.43715444157758937, 'eval_micro_f1': 0.5758086769891154, 'eval_macro_f1': 0.5058401627485241, 'eval_runtime': 33.7248, 'eval_samples_per_second': 160.891, 'eval_steps_per_second': 20.134, 'epoch': 8.0}


100%|██████████| 43416/43416 [2:29:25<00:00,  5.59it/s]

{'train_runtime': 8965.5418, 'train_samples_per_second': 38.735, 'train_steps_per_second': 4.843, 'train_loss': 0.044623604394915584, 'epoch': 8.0}


100%|██████████| 43416/43416 [2:29:26<00:00,  4.84it/s]



Evaluating on validation set with threshold=0.5...


100%|██████████| 679/679 [00:33<00:00, 20.18it/s]


Validation Results:
╒════╤═════════════╤════════════════════════╤═════════════════╤═════════════════╤════════════════╤═══════════════════════════╤═════════════════════════╤═════════╕
│    │   eval_loss │   eval_subset_accuracy │   eval_micro_f1 │   eval_macro_f1 │   eval_runtime │   eval_samples_per_second │   eval_steps_per_second │   epoch │
╞════╪═════════════╪════════════════════════╪═════════════════╪═════════════════╪════════════════╪═══════════════════════════╪═════════════════════════╪═════════╡
│  0 │    0.081984 │               0.490232 │        0.605637 │         0.45296 │        33.8725 │                   160.189 │                  20.046 │       8 │
╘════╧═════════════╧════════════════════════╧═════════════════╧═════════════════╧════════════════╧═══════════════════════════╧═════════════════════════╧═════════╛

Evaluating on test set with threshold=0.5...


100%|██████████| 679/679 [00:33<00:00, 20.17it/s]


Test Results:
╒════╤═════════════╤════════════════════════╤═════════════════╤═════════════════╤════════════════╤═══════════════════════════╤═════════════════════════╤═════════╕
│    │   eval_loss │   eval_subset_accuracy │   eval_micro_f1 │   eval_macro_f1 │   eval_runtime │   eval_samples_per_second │   eval_steps_per_second │   epoch │
╞════╪═════════════╪════════════════════════╪═════════════════╪═════════════════╪════════════════╪═══════════════════════════╪═════════════════════════╪═════════╡
│  0 │   0.0816054 │               0.488852 │        0.606616 │         0.45304 │         33.718 │                   160.952 │                  20.138 │       8 │
╘════╧═════════════╧════════════════════════╧═════════════════╧═════════════════╧════════════════╧═══════════════════════════╧═════════════════════════╧═════════╛

Finding best threshold on validation set...


100%|██████████| 679/679 [00:33<00:00, 20.19it/s]


Best threshold on validation set: 0.35000000000000003
Best micro-F1 on validation set with that threshold: 0.6128


100%|██████████| 679/679 [00:33<00:00, 20.11it/s]



Test Results with tuned threshold:
  Subset Accuracy: 0.4514
  Micro-F1:        0.6163
  Macro-F1:        0.4861

Saving the model and tokenizer...
Model and tokenizer saved to ./my_saved_model

Sample inference text: I am feeling really happy and excited about today!
Predicted probabilities for each emotion: [2.0597113e-02 5.7688584e-03 4.0924043e-04 1.1593949e-03 2.3545619e-02
 1.1877004e-02 1.5522424e-03 7.0234756e-03 2.6009441e-03 4.5316090e-04
 1.2884886e-03 2.7227736e-04 2.0298503e-04 7.5846952e-01 3.2438326e-04
 1.5194606e-02 5.8916811e-04 7.1431482e-01 3.5642970e-03 1.1694117e-03
 3.7335148e-03 1.5889428e-03 6.4925775e-03 1.4910299e-02 1.3529226e-04
 6.3893624e-04 5.6049395e-03 7.8487642e-02]
Predicted multi-hot labels: [0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 1 0 0 0 0 0 0 0 0 0 0]


In [9]:
from tabulate import tabulate

# Compute metrics manually
test_subset_accuracy = np.mean(np.all(test_predictions == test_labels, axis=1))
test_micro_f1 = f1_score(test_labels, test_predictions, average="micro", zero_division=0)
test_macro_f1 = f1_score(test_labels, test_predictions, average="macro", zero_division=0)

results_dict = {
    "subset_accuracy": [test_subset_accuracy],
    "micro_f1": [test_micro_f1],
    "macro_f1": [test_macro_f1]
}

print("\nTest Results with tuned threshold:")
df_final_results = pd.DataFrame(results_dict)
print(tabulate(df_final_results, headers="keys", tablefmt="fancy_grid"))



Test Results with tuned threshold:
╒════╤═══════════════════╤════════════╤════════════╕
│    │   subset_accuracy │   micro_f1 │   macro_f1 │
╞════╪═══════════════════╪════════════╪════════════╡
│  0 │          0.451446 │   0.616346 │   0.486148 │
╘════╧═══════════════════╧════════════╧════════════╛


In [1]:
# 1. Get the label names from the GoEmotions dataset
label_names = dataset["train"].features["labels"].feature.names

print(label_names)

# 2. A helper function that uses our trained model and tokenizer
#    to predict emotions for a single text prompt
def classify_prompt(text, threshold=best_threshold):
    # Tokenize
    inputs = tokenizer(
        text,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=max_length
    )
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    # Forward pass (inference)
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits[0]
    
    # Convert logits to probabilities
    probs = torch.sigmoid(logits).cpu().numpy()
    
    # Apply the best threshold to get binary predictions
    preds = (probs > threshold).astype(int)
    
    # Gather the names of the predicted labels
    predicted_labels = [label_names[i] for i, p in enumerate(preds) if p == 1]
    return predicted_labels, probs

# 3. Test the model on a few custom prompts
test_prompts = [
    "I am feeling really happy and excited about today!",
    "I am disappointed and angry with the results.",
    "This is just okay, nothing special.",
    "I want to order a cheeseburger."
]

for prompt in test_prompts:
    labels, probabilities = classify_prompt(prompt)
    print(f"\nPrompt: {prompt}")
    if labels:
        for lbl in labels:
            idx = label_names.index(lbl)
            print(f"  -> {lbl} (prob={probabilities[idx]:.3f})")
    else:
        print("  -> No emotion predicted above threshold.")


NameError: name 'dataset' is not defined