In [6]:
import numpy as np
import torch
from datasets import load_dataset, Value, Sequence
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer,
    DefaultDataCollator,
    DebertaV2Tokenizer
)
import pandas as pd
from tabulate import tabulate
from sklearn.metrics import f1_score

# ------------------------------
# CUDA Initialization Check
# ------------------------------
if torch.cuda.is_available():
    try:
        torch.cuda.init()  # Force CUDA initialization (optional)
    except Exception as e:
        print("Error during CUDA initialization:", e)
    print("CUDA is available. Running on GPU.")
    try:
        print("GPU Name:", torch.cuda.get_device_name(0))
    except Exception as e:
        print("Could not get GPU Name:", e)
    print("GPU Count:", torch.cuda.device_count())
else:
    print("CUDA is not available. Running on CPU.")

# ------------------------------
# 1. Load the GoEmotions Dataset
# ------------------------------
dataset = load_dataset("go_emotions")
print("\nSample from training set:")
print(dataset["train"][0])  # Show a sample to inspect

# Number of labels
num_labels = dataset["train"].features["labels"].feature.num_classes
print("Number of labels:", num_labels)

# ------------------------------
# 2. Data Preprocessing
# ------------------------------
def convert_labels(label_list):
    """Convert list of label indices to a multi-hot vector of floats."""
    return [1.0 if i in label_list else 0.0 for i in range(num_labels)]

model_checkpoint = "microsoft/deberta-v3-base"
# Use a stronger checkpoint such as microsoft/deberta-v3-base
tokenizer = DebertaV2Tokenizer.from_pretrained("microsoft/deberta-v3-base")
max_length = 128  # You can experiment with longer sequences

def preprocess_function(examples):
    tokenized = tokenizer(
        examples["text"],
        padding="max_length",
        truncation=True,
        max_length=max_length
    )
    tokenized["labels"] = [convert_labels(lbls) for lbls in examples["labels"]]
    return tokenized

encoded_dataset = dataset.map(preprocess_function, batched=True)
encoded_dataset = encoded_dataset.cast_column("labels", Sequence(Value("float32")))
encoded_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])

# ------------------------------
# 3. Model Setup
# ------------------------------
model = AutoModelForSequenceClassification.from_pretrained(
    model_checkpoint,
    num_labels=num_labels,
)
model.config.problem_type = "multi_label_classification"

if torch.cuda.is_available():
    model = model.to("cuda")

# ------------------------------
# 4. Define Compute Metrics Function
# ------------------------------
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    # Binarize predictions using threshold 0.5
    predictions = (torch.sigmoid(torch.tensor(logits)) > 0.5).int().numpy()
    # Compute subset accuracy (exact match)
    subset_accuracy = np.mean(np.all(predictions == labels, axis=1))
    # Compute micro and macro F1
    micro_f1 = f1_score(labels, predictions, average="micro", zero_division=0)
    macro_f1 = f1_score(labels, predictions, average="macro", zero_division=0)
    return {
        "subset_accuracy": subset_accuracy,
        "micro_f1": micro_f1,
        "macro_f1": macro_f1
    }

# ------------------------------
# 5. Training Arguments & Trainer Setup
# ------------------------------
training_args = TrainingArguments(
    output_dir="./results",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=5,  # Try 5-10 epochs for better performance
    weight_decay=0.01,
    logging_steps=100,
    save_total_limit=1,
    load_best_model_at_end=True,
)

def custom_data_collator(features):
    batch = DefaultDataCollator()(features)
    batch["labels"] = batch["labels"].float()
    return batch

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=encoded_dataset["train"],
    eval_dataset=encoded_dataset["validation"],
    data_collator=custom_data_collator,
    compute_metrics=compute_metrics,
)

print("\nStarting training...")
trainer.train()

# ------------------------------
# 6. Evaluation
# ------------------------------
print("\nEvaluating on validation set...")
val_results = trainer.evaluate()
df_val_results = pd.DataFrame([val_results])
print("Validation Results:")
print(tabulate(df_val_results, headers="keys", tablefmt="fancy_grid"))

print("\nEvaluating on test set...")
test_results = trainer.evaluate(eval_dataset=encoded_dataset["test"])
df_test_results = pd.DataFrame([test_results])
print("Test Results:")
print(tabulate(df_test_results, headers="keys", tablefmt="fancy_grid"))

# ------------------------------
# 7. Inference Example
# ------------------------------
def predict_emotions(text):
    inputs = tokenizer(
        text,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=max_length
    )
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits[0]
    probs = torch.sigmoid(logits).cpu().numpy()
    return probs

sample_text = "I am feeling really happy and excited about today!"
probs = predict_emotions(sample_text)
print("\nSample inference text:", sample_text)
print("Predicted probabilities for each emotion:")
print(probs)


CUDA is available. Running on GPU.
GPU Name: NVIDIA GeForce RTX 4070 Ti SUPER
GPU Count: 1

Sample from training set:
{'text': "My favourite food is anything I didn't have to cook myself.", 'labels': [27], 'id': 'eebbqej'}
Number of labels: 28


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development
Some weights of DebertaV2ForSequenceClassification were not initialized from the model checkpoint at microsoft/deberta-v3-base and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight', 'classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.



Starting training...


  1%|          | 101/13570 [00:13<27:54,  8.04it/s]

{'loss': 0.3424, 'learning_rate': 1.9852616064848935e-05, 'epoch': 0.04}


  1%|▏         | 201/13570 [00:26<27:33,  8.09it/s]

{'loss': 0.1575, 'learning_rate': 1.9705232129697864e-05, 'epoch': 0.07}


  2%|▏         | 301/13570 [00:38<29:07,  7.59it/s]

{'loss': 0.1489, 'learning_rate': 1.9557848194546797e-05, 'epoch': 0.11}


  3%|▎         | 401/13570 [00:51<27:12,  8.07it/s]

{'loss': 0.1452, 'learning_rate': 1.9410464259395727e-05, 'epoch': 0.15}


  4%|▎         | 501/13570 [01:03<26:55,  8.09it/s]

{'loss': 0.1385, 'learning_rate': 1.926308032424466e-05, 'epoch': 0.18}


  4%|▍         | 601/13570 [01:16<26:35,  8.13it/s]

{'loss': 0.1353, 'learning_rate': 1.911569638909359e-05, 'epoch': 0.22}


  5%|▌         | 701/13570 [01:28<26:33,  8.08it/s]

{'loss': 0.1295, 'learning_rate': 1.8968312453942523e-05, 'epoch': 0.26}


  6%|▌         | 801/13570 [01:40<26:39,  7.98it/s]

{'loss': 0.1268, 'learning_rate': 1.8820928518791452e-05, 'epoch': 0.29}


  7%|▋         | 901/13570 [01:53<26:12,  8.06it/s]

{'loss': 0.1194, 'learning_rate': 1.8673544583640385e-05, 'epoch': 0.33}


  7%|▋         | 1001/13570 [02:05<27:22,  7.65it/s]

{'loss': 0.1162, 'learning_rate': 1.8526160648489315e-05, 'epoch': 0.37}


  8%|▊         | 1101/13570 [02:18<26:09,  7.94it/s]

{'loss': 0.1158, 'learning_rate': 1.8378776713338248e-05, 'epoch': 0.41}


  9%|▉         | 1201/13570 [02:31<26:42,  7.72it/s]

{'loss': 0.1114, 'learning_rate': 1.8231392778187178e-05, 'epoch': 0.44}


 10%|▉         | 1301/13570 [02:43<25:31,  8.01it/s]

{'loss': 0.1118, 'learning_rate': 1.808400884303611e-05, 'epoch': 0.48}


 10%|█         | 1401/13570 [02:56<25:15,  8.03it/s]

{'loss': 0.1043, 'learning_rate': 1.7936624907885044e-05, 'epoch': 0.52}


 11%|█         | 1501/13570 [03:08<25:01,  8.04it/s]

{'loss': 0.1081, 'learning_rate': 1.7789240972733973e-05, 'epoch': 0.55}


 12%|█▏        | 1601/13570 [03:21<24:56,  8.00it/s]

{'loss': 0.1054, 'learning_rate': 1.7641857037582906e-05, 'epoch': 0.59}


 13%|█▎        | 1701/13570 [03:33<24:31,  8.06it/s]

{'loss': 0.0995, 'learning_rate': 1.7494473102431836e-05, 'epoch': 0.63}


 13%|█▎        | 1801/13570 [03:45<24:15,  8.09it/s]

{'loss': 0.1042, 'learning_rate': 1.734708916728077e-05, 'epoch': 0.66}


 14%|█▍        | 1901/13570 [03:58<24:22,  7.98it/s]

{'loss': 0.1026, 'learning_rate': 1.71997052321297e-05, 'epoch': 0.7}


 15%|█▍        | 2001/13570 [04:10<24:22,  7.91it/s]

{'loss': 0.1, 'learning_rate': 1.7052321296978632e-05, 'epoch': 0.74}


 15%|█▌        | 2101/13570 [04:23<23:58,  7.97it/s]

{'loss': 0.1028, 'learning_rate': 1.690493736182756e-05, 'epoch': 0.77}


 16%|█▌        | 2201/13570 [04:35<23:36,  8.02it/s]

{'loss': 0.0994, 'learning_rate': 1.6757553426676494e-05, 'epoch': 0.81}


 17%|█▋        | 2301/13570 [04:48<23:07,  8.12it/s]

{'loss': 0.0992, 'learning_rate': 1.6610169491525424e-05, 'epoch': 0.85}


 18%|█▊        | 2401/13570 [05:00<23:12,  8.02it/s]

{'loss': 0.0978, 'learning_rate': 1.6462785556374357e-05, 'epoch': 0.88}


 18%|█▊        | 2501/13570 [05:13<23:09,  7.97it/s]

{'loss': 0.0951, 'learning_rate': 1.6315401621223287e-05, 'epoch': 0.92}


 19%|█▉        | 2601/13570 [05:25<22:36,  8.08it/s]

{'loss': 0.0942, 'learning_rate': 1.616801768607222e-05, 'epoch': 0.96}


 20%|█▉        | 2701/13570 [05:38<23:31,  7.70it/s]

{'loss': 0.097, 'learning_rate': 1.602063375092115e-05, 'epoch': 0.99}


                                                    
 20%|██        | 2714/13570 [05:51<22:51,  7.92it/s]

{'eval_loss': 0.09200803190469742, 'eval_subset_accuracy': 0.40600810910431256, 'eval_micro_f1': 0.528171106050104, 'eval_macro_f1': 0.29070297317790333, 'eval_runtime': 11.3419, 'eval_samples_per_second': 478.402, 'eval_steps_per_second': 29.977, 'epoch': 1.0}


 21%|██        | 2801/13570 [06:04<22:30,  7.98it/s]  

{'loss': 0.0945, 'learning_rate': 1.5873249815770082e-05, 'epoch': 1.03}


 21%|██▏       | 2901/13570 [06:16<22:20,  7.96it/s]

{'loss': 0.0908, 'learning_rate': 1.5725865880619012e-05, 'epoch': 1.07}


 22%|██▏       | 3001/13570 [06:29<22:03,  7.98it/s]

{'loss': 0.0935, 'learning_rate': 1.5578481945467945e-05, 'epoch': 1.11}


 23%|██▎       | 3101/13570 [06:41<22:03,  7.91it/s]

{'loss': 0.0882, 'learning_rate': 1.5431098010316875e-05, 'epoch': 1.14}


 24%|██▎       | 3201/13570 [06:54<21:23,  8.08it/s]

{'loss': 0.0894, 'learning_rate': 1.5283714075165808e-05, 'epoch': 1.18}


 24%|██▍       | 3301/13570 [07:06<21:18,  8.03it/s]

{'loss': 0.0853, 'learning_rate': 1.5136330140014739e-05, 'epoch': 1.22}


 25%|██▌       | 3401/13570 [07:19<20:59,  8.07it/s]

{'loss': 0.0878, 'learning_rate': 1.498894620486367e-05, 'epoch': 1.25}


 26%|██▌       | 3501/13570 [07:31<21:10,  7.92it/s]

{'loss': 0.0872, 'learning_rate': 1.4841562269712602e-05, 'epoch': 1.29}


 27%|██▋       | 3601/13570 [07:44<21:38,  7.68it/s]

{'loss': 0.0883, 'learning_rate': 1.4694178334561533e-05, 'epoch': 1.33}


 27%|██▋       | 3701/13570 [07:57<20:48,  7.90it/s]

{'loss': 0.0887, 'learning_rate': 1.4546794399410464e-05, 'epoch': 1.36}


 28%|██▊       | 3801/13570 [08:09<20:07,  8.09it/s]

{'loss': 0.0892, 'learning_rate': 1.4399410464259397e-05, 'epoch': 1.4}


 29%|██▊       | 3901/13570 [08:22<20:11,  7.98it/s]

{'loss': 0.0854, 'learning_rate': 1.4252026529108329e-05, 'epoch': 1.44}


 29%|██▉       | 4001/13570 [08:34<20:03,  7.95it/s]

{'loss': 0.0887, 'learning_rate': 1.410464259395726e-05, 'epoch': 1.47}


 30%|███       | 4101/13570 [08:46<19:31,  8.08it/s]

{'loss': 0.0901, 'learning_rate': 1.3957258658806191e-05, 'epoch': 1.51}


 31%|███       | 4201/13570 [08:59<19:40,  7.93it/s]

{'loss': 0.0845, 'learning_rate': 1.3809874723655123e-05, 'epoch': 1.55}


 32%|███▏      | 4301/13570 [09:11<19:05,  8.09it/s]

{'loss': 0.0902, 'learning_rate': 1.3662490788504054e-05, 'epoch': 1.58}


 32%|███▏      | 4401/13570 [09:24<19:10,  7.97it/s]

{'loss': 0.0886, 'learning_rate': 1.3515106853352985e-05, 'epoch': 1.62}


 33%|███▎      | 4501/13570 [09:37<20:15,  7.46it/s]

{'loss': 0.0884, 'learning_rate': 1.3367722918201917e-05, 'epoch': 1.66}


 34%|███▍      | 4601/13570 [09:50<19:57,  7.49it/s]

{'loss': 0.086, 'learning_rate': 1.3220338983050848e-05, 'epoch': 1.69}


 35%|███▍      | 4701/13570 [10:03<19:37,  7.53it/s]

{'loss': 0.0862, 'learning_rate': 1.307295504789978e-05, 'epoch': 1.73}


 35%|███▌      | 4801/13570 [10:17<20:04,  7.28it/s]

{'loss': 0.0844, 'learning_rate': 1.292557111274871e-05, 'epoch': 1.77}


 36%|███▌      | 4901/13570 [10:30<19:23,  7.45it/s]

{'loss': 0.0851, 'learning_rate': 1.2778187177597642e-05, 'epoch': 1.81}


 37%|███▋      | 5001/13570 [10:43<18:47,  7.60it/s]

{'loss': 0.085, 'learning_rate': 1.2630803242446574e-05, 'epoch': 1.84}


 38%|███▊      | 5101/13570 [10:57<18:50,  7.49it/s]

{'loss': 0.0854, 'learning_rate': 1.2483419307295505e-05, 'epoch': 1.88}


 38%|███▊      | 5201/13570 [11:10<18:27,  7.56it/s]

{'loss': 0.087, 'learning_rate': 1.2336035372144436e-05, 'epoch': 1.92}


 39%|███▉      | 5301/13570 [11:23<18:13,  7.57it/s]

{'loss': 0.0872, 'learning_rate': 1.218865143699337e-05, 'epoch': 1.95}


 40%|███▉      | 5401/13570 [11:36<18:06,  7.52it/s]

{'loss': 0.0822, 'learning_rate': 1.20412675018423e-05, 'epoch': 1.99}


                                                    
 40%|████      | 5428/13570 [11:51<17:25,  7.79it/s]

{'eval_loss': 0.08527296781539917, 'eval_subset_accuracy': 0.4566900110578695, 'eval_micro_f1': 0.5731989272172385, 'eval_macro_f1': 0.40966896291164445, 'eval_runtime': 11.5131, 'eval_samples_per_second': 471.29, 'eval_steps_per_second': 29.532, 'epoch': 2.0}


 41%|████      | 5501/13570 [12:03<18:00,  7.47it/s]  

{'loss': 0.0816, 'learning_rate': 1.1893883566691232e-05, 'epoch': 2.03}


 41%|████▏     | 5601/13570 [12:17<17:37,  7.53it/s]

{'loss': 0.0793, 'learning_rate': 1.1746499631540163e-05, 'epoch': 2.06}


 42%|████▏     | 5701/13570 [12:30<17:22,  7.55it/s]

{'loss': 0.0784, 'learning_rate': 1.1599115696389095e-05, 'epoch': 2.1}


 43%|████▎     | 5801/13570 [12:43<17:11,  7.53it/s]

{'loss': 0.0748, 'learning_rate': 1.1451731761238026e-05, 'epoch': 2.14}


 43%|████▎     | 5901/13570 [12:56<16:57,  7.53it/s]

{'loss': 0.0762, 'learning_rate': 1.1304347826086957e-05, 'epoch': 2.17}


 44%|████▍     | 6001/13570 [13:10<16:44,  7.54it/s]

{'loss': 0.0772, 'learning_rate': 1.1156963890935889e-05, 'epoch': 2.21}


 45%|████▍     | 6101/13570 [13:23<16:42,  7.45it/s]

{'loss': 0.0761, 'learning_rate': 1.100957995578482e-05, 'epoch': 2.25}


 46%|████▌     | 6201/13570 [13:36<16:00,  7.67it/s]

{'loss': 0.0785, 'learning_rate': 1.0862196020633753e-05, 'epoch': 2.28}


 46%|████▋     | 6301/13570 [13:49<16:03,  7.54it/s]

{'loss': 0.0777, 'learning_rate': 1.0714812085482684e-05, 'epoch': 2.32}


 47%|████▋     | 6401/13570 [14:03<16:44,  7.14it/s]

{'loss': 0.0777, 'learning_rate': 1.0567428150331616e-05, 'epoch': 2.36}


 48%|████▊     | 6501/13570 [14:17<16:16,  7.24it/s]

{'loss': 0.08, 'learning_rate': 1.0420044215180547e-05, 'epoch': 2.39}


 49%|████▊     | 6601/13570 [14:31<16:00,  7.26it/s]

{'loss': 0.0777, 'learning_rate': 1.0272660280029478e-05, 'epoch': 2.43}


 49%|████▉     | 6701/13570 [14:44<15:07,  7.57it/s]

{'loss': 0.0797, 'learning_rate': 1.012527634487841e-05, 'epoch': 2.47}


 50%|█████     | 6801/13570 [14:57<15:01,  7.51it/s]

{'loss': 0.077, 'learning_rate': 9.977892409727341e-06, 'epoch': 2.51}


 51%|█████     | 6901/13570 [15:11<14:50,  7.49it/s]

{'loss': 0.0784, 'learning_rate': 9.830508474576272e-06, 'epoch': 2.54}


 52%|█████▏    | 7001/13570 [15:24<14:12,  7.70it/s]

{'loss': 0.0762, 'learning_rate': 9.683124539425204e-06, 'epoch': 2.58}


 52%|█████▏    | 7101/13570 [15:37<13:48,  7.81it/s]

{'loss': 0.0788, 'learning_rate': 9.535740604274135e-06, 'epoch': 2.62}


 53%|█████▎    | 7201/13570 [15:50<14:10,  7.49it/s]

{'loss': 0.0774, 'learning_rate': 9.388356669123066e-06, 'epoch': 2.65}


 54%|█████▍    | 7301/13570 [16:03<13:38,  7.66it/s]

{'loss': 0.0769, 'learning_rate': 9.240972733971998e-06, 'epoch': 2.69}


 55%|█████▍    | 7401/13570 [16:16<13:25,  7.66it/s]

{'loss': 0.0764, 'learning_rate': 9.093588798820929e-06, 'epoch': 2.73}


 55%|█████▌    | 7501/13570 [16:29<13:18,  7.60it/s]

{'loss': 0.0773, 'learning_rate': 8.94620486366986e-06, 'epoch': 2.76}


 56%|█████▌    | 7601/13570 [16:42<12:51,  7.74it/s]

{'loss': 0.0743, 'learning_rate': 8.798820928518792e-06, 'epoch': 2.8}


 57%|█████▋    | 7701/13570 [16:55<13:00,  7.52it/s]

{'loss': 0.0755, 'learning_rate': 8.651436993367723e-06, 'epoch': 2.84}


 57%|█████▋    | 7801/13570 [17:08<12:34,  7.64it/s]

{'loss': 0.0781, 'learning_rate': 8.504053058216654e-06, 'epoch': 2.87}


 58%|█████▊    | 7901/13570 [17:21<12:30,  7.55it/s]

{'loss': 0.0802, 'learning_rate': 8.356669123065586e-06, 'epoch': 2.91}


 59%|█████▉    | 8001/13570 [17:34<12:25,  7.47it/s]

{'loss': 0.0774, 'learning_rate': 8.209285187914517e-06, 'epoch': 2.95}


 60%|█████▉    | 8101/13570 [17:48<11:55,  7.65it/s]

{'loss': 0.0761, 'learning_rate': 8.06190125276345e-06, 'epoch': 2.98}


                                                    
 60%|██████    | 8142/13570 [18:05<12:05,  7.48it/s]

{'eval_loss': 0.0844298005104065, 'eval_subset_accuracy': 0.4601916697382971, 'eval_micro_f1': 0.5760800363801728, 'eval_macro_f1': 0.4411730304085348, 'eval_runtime': 11.5491, 'eval_samples_per_second': 469.819, 'eval_steps_per_second': 29.439, 'epoch': 3.0}


 60%|██████    | 8201/13570 [18:15<12:06,  7.39it/s]  

{'loss': 0.074, 'learning_rate': 7.914517317612381e-06, 'epoch': 3.02}


 61%|██████    | 8301/13570 [18:28<11:49,  7.42it/s]

{'loss': 0.0701, 'learning_rate': 7.767133382461313e-06, 'epoch': 3.06}


 62%|██████▏   | 8401/13570 [18:42<11:27,  7.52it/s]

{'loss': 0.0705, 'learning_rate': 7.619749447310244e-06, 'epoch': 3.1}


 63%|██████▎   | 8501/13570 [18:55<11:11,  7.55it/s]

{'loss': 0.0735, 'learning_rate': 7.472365512159175e-06, 'epoch': 3.13}


 63%|██████▎   | 8601/13570 [19:09<11:28,  7.21it/s]

{'loss': 0.0724, 'learning_rate': 7.324981577008107e-06, 'epoch': 3.17}


 64%|██████▍   | 8701/13570 [19:22<10:59,  7.39it/s]

{'loss': 0.0687, 'learning_rate': 7.177597641857038e-06, 'epoch': 3.21}


 65%|██████▍   | 8801/13570 [19:35<10:46,  7.37it/s]

{'loss': 0.0716, 'learning_rate': 7.030213706705969e-06, 'epoch': 3.24}


 66%|██████▌   | 8901/13570 [19:49<10:39,  7.30it/s]

{'loss': 0.0702, 'learning_rate': 6.882829771554901e-06, 'epoch': 3.28}


 66%|██████▋   | 9001/13570 [20:02<09:58,  7.63it/s]

{'loss': 0.0688, 'learning_rate': 6.735445836403832e-06, 'epoch': 3.32}


 67%|██████▋   | 9101/13570 [20:15<09:47,  7.61it/s]

{'loss': 0.0709, 'learning_rate': 6.588061901252763e-06, 'epoch': 3.35}


 68%|██████▊   | 9201/13570 [20:29<09:41,  7.51it/s]

{'loss': 0.0716, 'learning_rate': 6.440677966101695e-06, 'epoch': 3.39}


 69%|██████▊   | 9301/13570 [20:42<09:21,  7.60it/s]

{'loss': 0.071, 'learning_rate': 6.293294030950628e-06, 'epoch': 3.43}


 69%|██████▉   | 9401/13570 [20:55<09:14,  7.51it/s]

{'loss': 0.0715, 'learning_rate': 6.145910095799559e-06, 'epoch': 3.46}


 70%|███████   | 9501/13570 [21:08<08:51,  7.66it/s]

{'loss': 0.0695, 'learning_rate': 5.99852616064849e-06, 'epoch': 3.5}


 71%|███████   | 9601/13570 [21:21<08:43,  7.58it/s]

{'loss': 0.0667, 'learning_rate': 5.851142225497422e-06, 'epoch': 3.54}


 71%|███████▏  | 9701/13570 [21:34<08:40,  7.43it/s]

{'loss': 0.0666, 'learning_rate': 5.703758290346353e-06, 'epoch': 3.57}


 72%|███████▏  | 9801/13570 [21:48<08:16,  7.59it/s]

{'loss': 0.0728, 'learning_rate': 5.556374355195284e-06, 'epoch': 3.61}


 73%|███████▎  | 9901/13570 [22:01<08:00,  7.64it/s]

{'loss': 0.0709, 'learning_rate': 5.408990420044216e-06, 'epoch': 3.65}


 74%|███████▎  | 10001/13570 [22:14<07:52,  7.55it/s]

{'loss': 0.0726, 'learning_rate': 5.261606484893147e-06, 'epoch': 3.68}


 74%|███████▍  | 10101/13570 [22:27<07:38,  7.57it/s]

{'loss': 0.069, 'learning_rate': 5.114222549742078e-06, 'epoch': 3.72}


 75%|███████▌  | 10201/13570 [22:40<07:27,  7.52it/s]

{'loss': 0.068, 'learning_rate': 4.966838614591011e-06, 'epoch': 3.76}


 76%|███████▌  | 10301/13570 [22:54<07:27,  7.31it/s]

{'loss': 0.0677, 'learning_rate': 4.819454679439942e-06, 'epoch': 3.8}


 77%|███████▋  | 10401/13570 [23:07<07:00,  7.54it/s]

{'loss': 0.0673, 'learning_rate': 4.672070744288873e-06, 'epoch': 3.83}


 77%|███████▋  | 10501/13570 [23:20<07:00,  7.31it/s]

{'loss': 0.0712, 'learning_rate': 4.524686809137805e-06, 'epoch': 3.87}


 78%|███████▊  | 10601/13570 [23:33<06:33,  7.55it/s]

{'loss': 0.0715, 'learning_rate': 4.377302873986736e-06, 'epoch': 3.91}


 79%|███████▉  | 10701/13570 [23:47<06:24,  7.46it/s]

{'loss': 0.0701, 'learning_rate': 4.229918938835667e-06, 'epoch': 3.94}


 80%|███████▉  | 10801/13570 [24:00<06:04,  7.59it/s]

{'loss': 0.0716, 'learning_rate': 4.082535003684599e-06, 'epoch': 3.98}


                                                     
 80%|████████  | 10856/13570 [24:19<06:06,  7.41it/s]

{'eval_loss': 0.08593972772359848, 'eval_subset_accuracy': 0.4813859196461482, 'eval_micro_f1': 0.5938375350140056, 'eval_macro_f1': 0.4649419589856686, 'eval_runtime': 11.4394, 'eval_samples_per_second': 474.327, 'eval_steps_per_second': 29.722, 'epoch': 4.0}


 80%|████████  | 10901/13570 [24:27<06:02,  7.36it/s]  

{'loss': 0.0652, 'learning_rate': 3.93515106853353e-06, 'epoch': 4.02}


 81%|████████  | 11001/13570 [24:40<05:36,  7.64it/s]

{'loss': 0.0649, 'learning_rate': 3.7877671333824617e-06, 'epoch': 4.05}


 82%|████████▏ | 11101/13570 [24:53<05:01,  8.18it/s]

{'loss': 0.0648, 'learning_rate': 3.640383198231393e-06, 'epoch': 4.09}


 83%|████████▎ | 11201/13570 [25:05<04:53,  8.08it/s]

{'loss': 0.067, 'learning_rate': 3.4929992630803244e-06, 'epoch': 4.13}


 83%|████████▎ | 11301/13570 [25:18<04:43,  8.01it/s]

{'loss': 0.063, 'learning_rate': 3.3456153279292557e-06, 'epoch': 4.16}


 84%|████████▍ | 11401/13570 [25:30<04:28,  8.08it/s]

{'loss': 0.0665, 'learning_rate': 3.198231392778187e-06, 'epoch': 4.2}


 85%|████████▍ | 11501/13570 [25:43<04:40,  7.38it/s]

{'loss': 0.065, 'learning_rate': 3.0508474576271192e-06, 'epoch': 4.24}


 85%|████████▌ | 11601/13570 [25:56<04:21,  7.53it/s]

{'loss': 0.0646, 'learning_rate': 2.9034635224760506e-06, 'epoch': 4.27}


 86%|████████▌ | 11701/13570 [26:10<04:11,  7.44it/s]

{'loss': 0.0643, 'learning_rate': 2.756079587324982e-06, 'epoch': 4.31}


 87%|████████▋ | 11801/13570 [26:23<03:59,  7.38it/s]

{'loss': 0.0665, 'learning_rate': 2.6086956521739132e-06, 'epoch': 4.35}


 88%|████████▊ | 11901/13570 [26:36<03:45,  7.40it/s]

{'loss': 0.0621, 'learning_rate': 2.461311717022845e-06, 'epoch': 4.38}


 88%|████████▊ | 12001/13570 [26:50<03:28,  7.54it/s]

{'loss': 0.0651, 'learning_rate': 2.3139277818717763e-06, 'epoch': 4.42}


 89%|████████▉ | 12101/13570 [27:03<03:14,  7.53it/s]

{'loss': 0.0659, 'learning_rate': 2.1665438467207077e-06, 'epoch': 4.46}


 90%|████████▉ | 12201/13570 [27:16<03:02,  7.50it/s]

{'loss': 0.0604, 'learning_rate': 2.019159911569639e-06, 'epoch': 4.5}


 91%|█████████ | 12301/13570 [27:30<02:49,  7.49it/s]

{'loss': 0.0668, 'learning_rate': 1.8717759764185706e-06, 'epoch': 4.53}


 91%|█████████▏| 12401/13570 [27:43<02:35,  7.51it/s]

{'loss': 0.0616, 'learning_rate': 1.7243920412675019e-06, 'epoch': 4.57}


 92%|█████████▏| 12501/13570 [27:56<02:23,  7.45it/s]

{'loss': 0.066, 'learning_rate': 1.5770081061164336e-06, 'epoch': 4.61}


 93%|█████████▎| 12601/13570 [28:10<02:09,  7.50it/s]

{'loss': 0.065, 'learning_rate': 1.429624170965365e-06, 'epoch': 4.64}


 94%|█████████▎| 12701/13570 [28:23<01:52,  7.74it/s]

{'loss': 0.0641, 'learning_rate': 1.2822402358142963e-06, 'epoch': 4.68}


 94%|█████████▍| 12801/13570 [28:36<01:44,  7.33it/s]

{'loss': 0.064, 'learning_rate': 1.1348563006632279e-06, 'epoch': 4.72}


 95%|█████████▌| 12901/13570 [28:49<01:28,  7.53it/s]

{'loss': 0.0669, 'learning_rate': 9.874723655121592e-07, 'epoch': 4.75}


 96%|█████████▌| 13001/13570 [29:02<01:16,  7.48it/s]

{'loss': 0.0645, 'learning_rate': 8.400884303610906e-07, 'epoch': 4.79}


 97%|█████████▋| 13101/13570 [29:16<01:03,  7.39it/s]

{'loss': 0.0645, 'learning_rate': 6.927044952100222e-07, 'epoch': 4.83}


 97%|█████████▋| 13201/13570 [29:29<00:49,  7.38it/s]

{'loss': 0.0622, 'learning_rate': 5.453205600589536e-07, 'epoch': 4.86}


 98%|█████████▊| 13301/13570 [29:42<00:35,  7.56it/s]

{'loss': 0.0632, 'learning_rate': 3.97936624907885e-07, 'epoch': 4.9}


 99%|█████████▉| 13401/13570 [29:56<00:22,  7.59it/s]

{'loss': 0.063, 'learning_rate': 2.505526897568165e-07, 'epoch': 4.94}


 99%|█████████▉| 13501/13570 [30:09<00:09,  7.53it/s]

{'loss': 0.0627, 'learning_rate': 1.0316875460574797e-07, 'epoch': 4.97}


                                                     
100%|██████████| 13570/13570 [30:30<00:00,  7.55it/s]

{'eval_loss': 0.08712948858737946, 'eval_subset_accuracy': 0.4730925175082934, 'eval_micro_f1': 0.5914025184541902, 'eval_macro_f1': 0.47261614528044454, 'eval_runtime': 11.5111, 'eval_samples_per_second': 471.37, 'eval_steps_per_second': 29.537, 'epoch': 5.0}


100%|██████████| 13570/13570 [30:32<00:00,  7.55it/s]

{'train_runtime': 1832.9214, 'train_samples_per_second': 118.418, 'train_steps_per_second': 7.403, 'train_loss': 0.08442423672208736, 'epoch': 5.0}


100%|██████████| 13570/13570 [30:33<00:00,  7.40it/s]



Evaluating on validation set...


100%|██████████| 340/340 [00:11<00:00, 29.56it/s]


Validation Results:
╒════╤═════════════╤════════════════════════╤═════════════════╤═════════════════╤════════════════╤═══════════════════════════╤═════════════════════════╤═════════╕
│    │   eval_loss │   eval_subset_accuracy │   eval_micro_f1 │   eval_macro_f1 │   eval_runtime │   eval_samples_per_second │   eval_steps_per_second │   epoch │
╞════╪═════════════╪════════════════════════╪═════════════════╪═════════════════╪════════════════╪═══════════════════════════╪═════════════════════════╪═════════╡
│  0 │   0.0844298 │               0.460192 │         0.57608 │        0.441173 │        11.5908 │                   468.128 │                  29.333 │       5 │
╘════╧═════════════╧════════════════════════╧═════════════════╧═════════════════╧════════════════╧═══════════════════════════╧═════════════════════════╧═════════╛

Evaluating on test set...


100%|██████████| 340/340 [00:11<00:00, 29.77it/s]

Test Results:
╒════╤═════════════╤════════════════════════╤═════════════════╤═════════════════╤════════════════╤═══════════════════════════╤═════════════════════════╤═════════╕
│    │   eval_loss │   eval_subset_accuracy │   eval_micro_f1 │   eval_macro_f1 │   eval_runtime │   eval_samples_per_second │   eval_steps_per_second │   epoch │
╞════╪═════════════╪════════════════════════╪═════════════════╪═════════════════╪════════════════╪═══════════════════════════╪═════════════════════════╪═════════╡
│  0 │   0.0840901 │                0.46066 │        0.576743 │        0.441563 │        11.4569 │                   473.688 │                  29.676 │       5 │
╘════╧═════════════╧════════════════════════╧═════════════════╧═════════════════╧════════════════╧═══════════════════════════╧═════════════════════════╧═════════╛

Sample inference text: I am feeling really happy and excited about today!
Predicted probabilities for each emotion:
[5.21914177e-02 1.28694540e-02 8.52493919e-04 1.108821




In [5]:
# Use a specific prompt for inference
prompt_text = "I want a burger"
probs = predict_emotions(prompt_text)

# Get the label names from the dataset (GoEmotions)
label_names = dataset["train"].features["labels"].feature.names

threshold = 0.5 

# Convert probabilities to binary predictions
predicted_binary = (probs > threshold).astype(int)

# For multi-label classification, there may be multiple labels predicted.
predicted_labels = [label_names[i] for i, pred in enumerate(predicted_binary) if pred == 1]

print("\nPredicted labels (threshold={}):".format(threshold))
print(predicted_labels)



Predicted labels (threshold=0.5):
['desire']


In [3]:
import transformers
print(transformers.__version__)

4.31.0
