In [5]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    DefaultDataCollator
)
from datasets import Dataset, DatasetDict, Value, Sequence

# 1. Read CSV files
df1 = pd.read_csv("goemotions/data/full_dataset/goemotions_1.csv")
df2 = pd.read_csv("goemotions/data/full_dataset/goemotions_2.csv")
df3 = pd.read_csv("goemotions/data/full_dataset/goemotions_3.csv")

# Combine them if needed (this step depends on how your CSVs are structured)
df = pd.concat([df1, df2, df3], ignore_index=True)

# Inspect
print("Combined dataset size:", len(df))
print(df.head())


Combined dataset size: 211225
                                                text       id  \
0                                    That game hurt.  eew5j0j   
1   >sexuality shouldn’t be a grouping category I...  eemcysk   
2     You do right, if you don't care then fuck 'em!  ed2mah1   
3                                 Man I love reddit.  eeibobj   
4  [NAME] was nowhere near them, he was by the Fa...  eda6yn6   

                author            subreddit    link_id   parent_id  \
0                Brdd9                  nrl  t3_ajis4z  t1_eew18eq   
1          TheGreen888     unpopularopinion  t3_ai4q37   t3_ai4q37   
2             Labalool          confessions  t3_abru74  t1_ed2m7g7   
3        MrsRobertshaw             facepalm  t3_ahulml   t3_ahulml   
4  American_Fascist713  starwarsspeculation  t3_ackt2f  t1_eda65q2   

    created_utc  rater_id  example_very_unclear  admiration  ...  love  \
0  1.548381e+09         1                 False           0  ...     0   
1  1.54808

In [6]:
# Suppose your CSV has a column "labels" that is a list of emotion strings:
# e.g. row["labels"] might be ["joy", "amusement"]

label_groups = {
    "anger": "anger",
    "annoyance": "anger",
    "disgust": "anger",
    "joy": "joy",
    "amusement": "joy",
    "excitement": "joy",
    "sadness": "sadness",
    "grief": "sadness",
    "disappointment": "sadness",
    "love": "love",
    "caring": "love",
    "fear": "fear",
    "nervousness": "fear",
    "surprise": "surprise",
    # ... etc ...
    "admiration": "admiration",
    "approval": "approval",
    "confusion": "confusion",
    "curiosity": "curiosity",
    "desire": "desire",
    "disapproval": "disapproval",
    "embarrassment": "embarrassment",
    "gratitude": "gratitude",
    "optimism": "optimism",
    "pride": "pride",
    "realization": "realization",
    "relief": "relief",
    "remorse": "remorse",
    "neutral": "neutral"
}

# Build a sorted list of unique new labels:
unique_new_labels = sorted(list(set(label_groups.values())))
print("New label set:", unique_new_labels)
# e.g. ['admiration', 'anger', 'approval', 'confusion', 'curiosity', 'desire', 
#       'disapproval', 'embarrassment', 'fear', 'gratitude', 'joy', 'love', 
#       'neutral', 'optimism', 'pride', 'realization', 'relief', 'remorse', 
#       'sadness', 'surprise']

# Create a mapping from label -> index
new_label2id = {lbl: i for i, lbl in enumerate(unique_new_labels)}

def map_labels_to_new(labels_list):
    """
    labels_list is the list of original label strings for a single example.
    We'll map each to the new label, then build a multi-hot vector.
    """
    # Start all zeros
    multi_hot = [0] * len(unique_new_labels)
    for old_lbl in labels_list:
        if old_lbl in label_groups:
            new_lbl = label_groups[old_lbl]
            idx = new_label2id[new_lbl]
            multi_hot[idx] = 1
        # If there's a label not in label_groups, decide how to handle it (ignore or add it)
    return multi_hot


New label set: ['admiration', 'anger', 'approval', 'confusion', 'curiosity', 'desire', 'disapproval', 'embarrassment', 'fear', 'gratitude', 'joy', 'love', 'neutral', 'optimism', 'pride', 'realization', 'relief', 'remorse', 'sadness', 'surprise']


In [9]:
label_groups = {
    "anger": "anger",
    "disgust": "anger",   # Merge disgust into "anger"

    "joy": "joy",
    "amusement": "joy",
    "excitement": "joy",

    "sadness": "sadness",
    "grief": "sadness",
    "disappointment": "sadness",

    "love": "love",
    "caring": "love",

    "fear": "fear",
    "nervousness": "fear",

    # Keep some as-is:
    "admiration": "admiration",
    "approval":   "approval",
    "confusion":  "confusion",
    "curiosity":  "curiosity",
    "desire":     "desire",
    "disapproval":"disapproval",
    "embarrassment":"embarrassment",
    "gratitude":  "gratitude",
    "optimism":   "optimism",
    "pride":      "pride",
    "realization":"realization",
    "relief":     "relief",
    "remorse":    "remorse",
    "sadness":    "sadness",
    "surprise":   "surprise",
    "neutral":    "neutral"
}


In [14]:
unique_new_labels = sorted(set(label_groups.values()))
print("New label set:", unique_new_labels)
new_label2id = {lbl: i for i, lbl in enumerate(unique_new_labels)}


New label set: ['admiration', 'anger', 'approval', 'confusion', 'curiosity', 'desire', 'disapproval', 'embarrassment', 'fear', 'gratitude', 'joy', 'love', 'neutral', 'optimism', 'pride', 'realization', 'relief', 'remorse', 'sadness', 'surprise']


In [22]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    DefaultDataCollator
)
from datasets import Dataset, DatasetDict, Sequence, Value
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score

# 1) Read your local CSV
# If you have multiple CSVs (goemotions_1.csv, goemotions_2.csv, etc.), concatenate them
df = pd.read_csv("goemotions/data/full_dataset/goemotions_1.csv")
# df2 = pd.read_csv("goemotions/data/full_dataset/goemotions_2.csv")
# df3 = pd.read_csv("goemotions/data/full_dataset/goemotions_3.csv")
# df = pd.concat([df, df2, df3], ignore_index=True)

print("Columns in df:", df.columns)
print("Dataset size:", len(df))

# 2) List of emotion columns (one column per emotion, with 0/1 values).
#    Make sure these match EXACTLY the columns in your DataFrame.
emotions = [
    "admiration", "amusement", "anger", "approval", "caring",
    "confusion", "curiosity", "desire", "disappointment", "disapproval",
    "disgust", "embarrassment", "excitement", "fear", "gratitude",
    "grief", "joy", "love", "nervousness", "optimism", "pride",
    "realization", "relief", "remorse", "sadness", "surprise", "neutral"
]

# 3) Define how to merge overlapping emotions:
#    Key = original emotion column, Value = your new "merged" label name.
label_groups = {
    # Merge 'disgust' into 'anger'
    "anger": "anger",
    "disgust": "anger",

    # Merge 'amusement' & 'excitement' into 'joy'
    "joy": "joy",
    "amusement": "joy",
    "excitement": "joy",

    # Merge 'grief' & 'disappointment' into 'sadness'
    "sadness": "sadness",
    "grief": "sadness",
    "disappointment": "sadness",

    # Merge 'caring' into 'love'
    "love": "love",
    "caring": "love",

    # Merge 'nervousness' into 'fear'
    "fear": "fear",
    "nervousness": "fear",

    # Keep everything else as is
    "admiration": "admiration",
    "approval": "approval",
    "confusion": "confusion",
    "curiosity": "curiosity",
    "desire": "desire",
    "disapproval": "disapproval",
    "embarrassment": "embarrassment",
    "gratitude": "gratitude",
    "optimism": "optimism",
    "pride": "pride",
    "realization": "realization",
    "relief": "relief",
    "remorse": "remorse",
    "neutral": "neutral",
    "surprise": "surprise"
}

# Build a sorted list of your new merged labels
unique_new_labels = sorted(set(label_groups.values()))
print("New label set:", unique_new_labels)

# Map each merged label to an index
new_label2id = {lbl: i for i, lbl in enumerate(unique_new_labels)}


Columns in df: Index(['text', 'id', 'author', 'subreddit', 'link_id', 'parent_id',
       'created_utc', 'rater_id', 'example_very_unclear', 'admiration',
       'amusement', 'anger', 'annoyance', 'approval', 'caring', 'confusion',
       'curiosity', 'desire', 'disappointment', 'disapproval', 'disgust',
       'embarrassment', 'excitement', 'fear', 'gratitude', 'grief', 'joy',
       'love', 'nervousness', 'optimism', 'pride', 'realization', 'relief',
       'remorse', 'sadness', 'surprise', 'neutral'],
      dtype='object')
Dataset size: 70000
New label set: ['admiration', 'anger', 'approval', 'confusion', 'curiosity', 'desire', 'disapproval', 'embarrassment', 'fear', 'gratitude', 'joy', 'love', 'neutral', 'optimism', 'pride', 'realization', 'relief', 'remorse', 'sadness', 'surprise']


In [23]:
def row_to_new_multi_hot(row):
    # multi_hot has length = number of merged labels
    multi_hot = [0]*len(unique_new_labels)
    # For each original emotion column
    for emo in emotions:
        # If this row has emo=1, we set the merged label index to 1
        if row[emo] == 1:
            merged_label = label_groups[emo]  # e.g. "anger" if emo is "disgust"
            idx = new_label2id[merged_label]
            multi_hot[idx] = 1
    return multi_hot

# Apply the function to each row
df["new_multi_hot"] = df.apply(row_to_new_multi_hot, axis=1)


  7%|▋         | 7186/105615 [03:30<43:04, 38.09it/s]

In [24]:
# 80% train, 10% val, 10% test
df_train, df_temp = train_test_split(df, test_size=0.2, random_state=42)
df_val, df_test = train_test_split(df_temp, test_size=0.5, random_state=42)

print("Train size:", len(df_train))
print("Val size:", len(df_val))
print("Test size:", len(df_test))


Train size: 56000
Val size: 7000
Test size: 7000


In [25]:
ds_train = Dataset.from_pandas(df_train)
ds_val   = Dataset.from_pandas(df_val)
ds_test  = Dataset.from_pandas(df_test)

dataset = DatasetDict({
    "train": ds_train,
    "validation": ds_val,
    "test": ds_test
})


In [26]:
model_checkpoint = "distilbert-base-uncased"  # a smaller model
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

def preprocess_function(examples):
    tok = tokenizer(
        examples["text"], 
        padding="max_length", 
        truncation=True, 
        max_length=128
    )
    # Attach the new_multi_hot vector as labels
    tok["labels"] = examples["new_multi_hot"]
    return tok

encoded_dataset = dataset.map(preprocess_function, batched=True)

# Cast the labels to float for multi-label classification
encoded_dataset = encoded_dataset.cast_column("labels", Sequence(Value("float32")))
encoded_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])


Map: 100%|██████████| 56000/56000 [00:02<00:00, 26149.90 examples/s]
Map: 100%|██████████| 7000/7000 [00:00<00:00, 27635.04 examples/s]
Map: 100%|██████████| 7000/7000 [00:00<00:00, 28176.57 examples/s]
Casting the dataset: 100%|██████████| 56000/56000 [00:00<00:00, 501712.08 examples/s]
Casting the dataset: 100%|██████████| 7000/7000 [00:00<00:00, 428346.12 examples/s]
Casting the dataset: 100%|██████████| 7000/7000 [00:00<00:00, 422204.89 examples/s]


In [27]:
num_labels = len(unique_new_labels)
model = AutoModelForSequenceClassification.from_pretrained(
    model_checkpoint,
    num_labels=num_labels
)
model.config.problem_type = "multi_label_classification"

# Compute pos_weight if desired
def compute_pos_weight(dataset_split, n_labels):
    counts = np.zeros(n_labels)
    for example in dataset_split:
        counts += np.array(example["labels"])
    total_samples = len(dataset_split)
    # pos_weight = (N - count) / count
    pos_weight = (total_samples - counts) / (counts + 1e-5)
    return torch.tensor(pos_weight, dtype=torch.float32)

pos_weight = compute_pos_weight(encoded_dataset["train"], num_labels)

# Custom Trainer to use BCEWithLogitsLoss + pos_weight
class WeightedBCELossTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        labels = inputs["labels"]
        outputs = model(**inputs)
        logits = outputs.logits
        bce = nn.BCEWithLogitsLoss(pos_weight=pos_weight.to(logits.device))
        loss = bce(logits, labels)
        return (loss, outputs) if return_outputs else loss

# Define a metric function
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    preds = (torch.sigmoid(torch.tensor(logits)) > 0.5).int().numpy()
    subset_accuracy = np.mean(np.all(preds == labels, axis=1))
    micro_f1 = f1_score(labels, preds, average="micro", zero_division=0)
    macro_f1 = f1_score(labels, preds, average="macro", zero_division=0)
    return {
        "subset_accuracy": subset_accuracy,
        "micro_f1": micro_f1,
        "macro_f1": macro_f1
    }


Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.weight', 'pre_classifier.weight', 'classifier.bias', 'pre_classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [21]:
training_args = TrainingArguments(
    output_dir="./distilbert_merged_emotions",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=1e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=5,
    weight_decay=0.01,
    warmup_ratio=0.1,
    logging_steps=100,
    report_to="none",
    load_best_model_at_end=True,
    save_total_limit=1,
    fp16=True  # if your GPU supports it
)

trainer = WeightedBCELossTrainer(
    model=model,
    args=training_args,
    train_dataset=encoded_dataset["train"],
    eval_dataset=encoded_dataset["validation"],
    data_collator=DefaultDataCollator(),
    compute_metrics=compute_metrics,
)

trainer.train()


  0%|          | 106/105615 [00:05<45:59, 38.24it/s] 

{'loss': 1.3005, 'learning_rate': 9.46790380609733e-08, 'epoch': 0.0}


  0%|          | 207/105615 [00:08<47:13, 37.20it/s]

{'loss': 1.3277, 'learning_rate': 1.893580761219466e-07, 'epoch': 0.01}


  0%|          | 303/105615 [00:10<47:27, 36.98it/s]

{'loss': 1.3556, 'learning_rate': 2.830903238023102e-07, 'epoch': 0.01}


  0%|          | 404/105615 [00:13<46:26, 37.76it/s]

{'loss': 1.335, 'learning_rate': 3.7776936186328347e-07, 'epoch': 0.02}


  0%|          | 506/105615 [00:16<45:56, 38.13it/s]

{'loss': 1.3529, 'learning_rate': 4.724483999242568e-07, 'epoch': 0.02}


  1%|          | 607/105615 [00:18<44:47, 39.07it/s]

{'loss': 1.3172, 'learning_rate': 5.671274379852301e-07, 'epoch': 0.03}


  1%|          | 705/105615 [00:21<45:40, 38.28it/s]

{'loss': 1.3012, 'learning_rate': 6.618064760462034e-07, 'epoch': 0.03}


  1%|          | 807/105615 [00:23<44:46, 39.02it/s]

{'loss': 1.3514, 'learning_rate': 7.564855141071767e-07, 'epoch': 0.04}


  1%|          | 905/105615 [00:26<45:50, 38.07it/s]

{'loss': 1.2958, 'learning_rate': 8.511645521681501e-07, 'epoch': 0.04}


  1%|          | 1003/105615 [00:29<48:11, 36.18it/s]

{'loss': 1.336, 'learning_rate': 9.458435902291234e-07, 'epoch': 0.05}


  1%|          | 1104/105615 [00:31<46:15, 37.66it/s]

{'loss': 1.302, 'learning_rate': 1.0405226282900967e-06, 'epoch': 0.05}


  1%|          | 1203/105615 [00:34<45:21, 38.36it/s]

{'loss': 1.3126, 'learning_rate': 1.13520166635107e-06, 'epoch': 0.06}


  1%|          | 1303/105615 [00:36<48:54, 35.54it/s]

{'loss': 1.3054, 'learning_rate': 1.2298807044120432e-06, 'epoch': 0.06}


  1%|▏         | 1407/105615 [00:39<45:08, 38.47it/s]

{'loss': 1.3622, 'learning_rate': 1.3245597424730167e-06, 'epoch': 0.07}


  1%|▏         | 1504/105615 [00:42<47:06, 36.84it/s]

{'loss': 1.3093, 'learning_rate': 1.41829199015338e-06, 'epoch': 0.07}


  2%|▏         | 1605/105615 [00:45<48:18, 35.88it/s]

{'loss': 1.3078, 'learning_rate': 1.5129710282143534e-06, 'epoch': 0.08}


  2%|▏         | 1705/105615 [00:47<50:08, 34.53it/s]

{'loss': 1.3523, 'learning_rate': 1.6076500662753266e-06, 'epoch': 0.08}


  2%|▏         | 1805/105615 [00:50<50:18, 34.39it/s]

{'loss': 1.2882, 'learning_rate': 1.7013823139556904e-06, 'epoch': 0.09}


  2%|▏         | 1905/105615 [00:53<49:02, 35.24it/s]

{'loss': 1.2358, 'learning_rate': 1.7960613520166637e-06, 'epoch': 0.09}


  2%|▏         | 2005/105615 [00:56<48:17, 35.75it/s]

{'loss': 1.3135, 'learning_rate': 1.8907403900776369e-06, 'epoch': 0.09}


  2%|▏         | 2105/105615 [00:59<46:10, 37.37it/s]

{'loss': 1.255, 'learning_rate': 1.9854194281386104e-06, 'epoch': 0.1}


  2%|▏         | 2206/105615 [01:01<47:23, 36.36it/s]

{'loss': 1.2553, 'learning_rate': 2.0800984661995833e-06, 'epoch': 0.1}


  2%|▏         | 2305/105615 [01:04<46:18, 37.18it/s]

{'loss': 1.2514, 'learning_rate': 2.1747775042605567e-06, 'epoch': 0.11}


  2%|▏         | 2405/105615 [01:07<46:47, 36.77it/s]

{'loss': 1.226, 'learning_rate': 2.26945654232153e-06, 'epoch': 0.11}


  2%|▏         | 2505/105615 [01:10<47:56, 35.85it/s]

{'loss': 1.2778, 'learning_rate': 2.364135580382504e-06, 'epoch': 0.12}


  2%|▏         | 2605/105615 [01:12<50:54, 33.72it/s]

{'loss': 1.2624, 'learning_rate': 2.4588146184434768e-06, 'epoch': 0.12}


  3%|▎         | 2705/105615 [01:15<43:17, 39.62it/s]

{'loss': 1.2179, 'learning_rate': 2.55349365650445e-06, 'epoch': 0.13}


  3%|▎         | 2805/105615 [01:18<46:08, 37.13it/s]

{'loss': 1.2137, 'learning_rate': 2.6481726945654235e-06, 'epoch': 0.13}


  3%|▎         | 2905/105615 [01:21<49:41, 34.45it/s]

{'loss': 1.2482, 'learning_rate': 2.7428517326263964e-06, 'epoch': 0.14}


  3%|▎         | 3005/105615 [01:24<47:29, 36.00it/s]

{'loss': 1.2208, 'learning_rate': 2.83753077068737e-06, 'epoch': 0.14}


  3%|▎         | 3106/105615 [01:26<48:43, 35.06it/s]

{'loss': 1.2908, 'learning_rate': 2.9322098087483435e-06, 'epoch': 0.15}


  3%|▎         | 3205/105615 [01:29<45:38, 37.40it/s]

{'loss': 1.1706, 'learning_rate': 3.0268888468093164e-06, 'epoch': 0.15}


  3%|▎         | 3304/105615 [01:32<45:51, 37.18it/s]

{'loss': 1.1336, 'learning_rate': 3.1215678848702902e-06, 'epoch': 0.16}


  3%|▎         | 3404/105615 [01:34<47:12, 36.09it/s]

{'loss': 1.0896, 'learning_rate': 3.216246922931263e-06, 'epoch': 0.16}


  3%|▎         | 3506/105615 [01:37<45:26, 37.45it/s]

{'loss': 1.1419, 'learning_rate': 3.310925960992237e-06, 'epoch': 0.17}


  3%|▎         | 3603/105615 [01:40<45:38, 37.25it/s]

{'loss': 1.0961, 'learning_rate': 3.40560499905321e-06, 'epoch': 0.17}


  4%|▎         | 3705/105615 [01:42<46:19, 36.66it/s]

{'loss': 1.2392, 'learning_rate': 3.500284037114183e-06, 'epoch': 0.18}


  4%|▎         | 3805/105615 [01:45<43:29, 39.01it/s]

{'loss': 1.1702, 'learning_rate': 3.594963075175156e-06, 'epoch': 0.18}


  4%|▎         | 3905/105615 [01:48<44:56, 37.72it/s]

{'loss': 1.0587, 'learning_rate': 3.68964211323613e-06, 'epoch': 0.18}


  4%|▍         | 4003/105615 [01:50<47:32, 35.62it/s]

{'loss': 1.2097, 'learning_rate': 3.784321151297103e-06, 'epoch': 0.19}


  4%|▍         | 4107/105615 [01:53<45:28, 37.20it/s]

{'loss': 0.9737, 'learning_rate': 3.879000189358077e-06, 'epoch': 0.19}


  4%|▍         | 4203/105615 [01:56<46:43, 36.17it/s]

{'loss': 1.127, 'learning_rate': 3.97367922741905e-06, 'epoch': 0.2}


  4%|▍         | 4305/105615 [01:58<44:08, 38.25it/s]

{'loss': 1.1101, 'learning_rate': 4.068358265480023e-06, 'epoch': 0.2}


  4%|▍         | 4405/105615 [02:01<44:57, 37.52it/s]

{'loss': 1.1097, 'learning_rate': 4.163037303540997e-06, 'epoch': 0.21}


  4%|▍         | 4506/105615 [02:04<45:43, 36.85it/s]

{'loss': 1.0491, 'learning_rate': 4.257716341601969e-06, 'epoch': 0.21}


  4%|▍         | 4604/105615 [02:06<43:11, 38.98it/s]

{'loss': 1.1313, 'learning_rate': 4.352395379662943e-06, 'epoch': 0.22}


  4%|▍         | 4706/105615 [02:09<44:06, 38.14it/s]

{'loss': 1.1074, 'learning_rate': 4.446127627343307e-06, 'epoch': 0.22}


  5%|▍         | 4803/105615 [02:12<45:03, 37.29it/s]

{'loss': 1.1194, 'learning_rate': 4.54080666540428e-06, 'epoch': 0.23}


  5%|▍         | 4905/105615 [02:14<46:41, 35.95it/s]

{'loss': 1.1655, 'learning_rate': 4.6354857034652534e-06, 'epoch': 0.23}


  5%|▍         | 5005/105615 [02:17<47:08, 35.57it/s]

{'loss': 1.0931, 'learning_rate': 4.730164741526227e-06, 'epoch': 0.24}


  5%|▍         | 5106/105615 [02:20<44:11, 37.91it/s]

{'loss': 1.0871, 'learning_rate': 4.8248437795872e-06, 'epoch': 0.24}


  5%|▍         | 5208/105615 [02:23<44:59, 37.19it/s]

{'loss': 1.0289, 'learning_rate': 4.919522817648173e-06, 'epoch': 0.25}


  5%|▌         | 5306/105615 [02:25<47:01, 35.56it/s]

{'loss': 1.0624, 'learning_rate': 5.014201855709147e-06, 'epoch': 0.25}


  5%|▌         | 5403/105615 [02:28<46:31, 35.90it/s]

{'loss': 0.9564, 'learning_rate': 5.108880893770119e-06, 'epoch': 0.26}


  5%|▌         | 5508/105615 [02:31<43:36, 38.26it/s]

{'loss': 1.0435, 'learning_rate': 5.203559931831093e-06, 'epoch': 0.26}


  5%|▌         | 5604/105615 [02:34<44:18, 37.62it/s]

{'loss': 1.0715, 'learning_rate': 5.298238969892067e-06, 'epoch': 0.27}


  5%|▌         | 5706/105615 [02:37<46:45, 35.61it/s]

{'loss': 1.0944, 'learning_rate': 5.392918007953039e-06, 'epoch': 0.27}


  5%|▌         | 5803/105615 [02:39<44:21, 37.51it/s]

{'loss': 0.9906, 'learning_rate': 5.487597046014013e-06, 'epoch': 0.27}


  6%|▌         | 5906/105615 [02:42<47:16, 35.16it/s]

{'loss': 1.0333, 'learning_rate': 5.582276084074987e-06, 'epoch': 0.28}


  6%|▌         | 6006/105615 [02:45<46:12, 35.93it/s]

{'loss': 1.0273, 'learning_rate': 5.6769551221359594e-06, 'epoch': 0.28}


  6%|▌         | 6106/105615 [02:48<44:43, 37.08it/s]

{'loss': 1.0863, 'learning_rate': 5.771634160196933e-06, 'epoch': 0.29}


  6%|▌         | 6206/105615 [02:50<44:20, 37.36it/s]

{'loss': 1.0505, 'learning_rate': 5.866313198257906e-06, 'epoch': 0.29}


  6%|▌         | 6306/105615 [02:53<45:41, 36.22it/s]

{'loss': 0.9142, 'learning_rate': 5.9609922363188795e-06, 'epoch': 0.3}


  6%|▌         | 6406/105615 [02:56<43:41, 37.85it/s]

{'loss': 1.0958, 'learning_rate': 6.055671274379853e-06, 'epoch': 0.3}


  6%|▌         | 6506/105615 [02:58<43:29, 37.98it/s]

{'loss': 0.935, 'learning_rate': 6.150350312440826e-06, 'epoch': 0.31}


  6%|▋         | 6606/105615 [03:01<46:12, 35.72it/s]

{'loss': 1.0416, 'learning_rate': 6.245029350501799e-06, 'epoch': 0.31}


  6%|▋         | 6706/105615 [03:04<43:28, 37.92it/s]

{'loss': 1.0358, 'learning_rate': 6.339708388562773e-06, 'epoch': 0.32}


  6%|▋         | 6806/105615 [03:06<44:26, 37.06it/s]

{'loss': 0.9607, 'learning_rate': 6.434387426623746e-06, 'epoch': 0.32}


  7%|▋         | 6906/105615 [03:09<44:51, 36.68it/s]

{'loss': 0.9704, 'learning_rate': 6.529066464684719e-06, 'epoch': 0.33}


  7%|▋         | 7003/105615 [03:12<43:57, 37.39it/s]

{'loss': 1.1562, 'learning_rate': 6.622798712365083e-06, 'epoch': 0.33}


  7%|▋         | 7103/105615 [03:14<44:07, 37.22it/s]

{'loss': 1.047, 'learning_rate': 6.7174777504260555e-06, 'epoch': 0.34}


  7%|▋         | 7183/105615 [03:17<43:04, 38.09it/s]

KeyboardInterrupt: 