In [63]:
import pandas as pd
import numpy as np

import torch

from transformers import RobertaTokenizer, RobertaForSequenceClassification, AutoTokenizer

In [64]:
df = pd.read_csv('../data/esnli_train_1.csv') #rmb to train on whole dataset
df.head()

Unnamed: 0,pairID,gold_label,Sentence1,Sentence2,Explanation_1,WorkerId,Sentence1_marked_1,Sentence2_marked_1,Sentence1_Highlighted_1,Sentence2_Highlighted_1
0,3416050480.jpg#4r1n,neutral,A person on a horse jumps over a broken down a...,A person is training his horse for a competition.,the person is not necessarily training his horse,AF0PI3RISB5Q7,A person on a horse jumps over a broken down a...,A person is *training* *his* *horse* for a co...,{},345
1,3416050480.jpg#4r1c,contradiction,A person on a horse jumps over a broken down a...,"A person is at a diner, ordering an omelette.",One cannot be on a jumping horse cannot be a d...,A36ZT2WFIA2HMF,A person *on* *a* *horse* *jumps* over a brok...,"A person *is* *at* *a* *diner,* *ordering* an...",4235,25436
2,3416050480.jpg#4r1e,entailment,A person on a horse jumps over a broken down a...,"A person is outdoors, on a horse.",a broken down airplane is outdoors,A2GK75ZQTX2RDZ,A person on a horse jumps over *a* *broken* *...,"A person is *outdoors,* on a horse.",89107,3
3,2267923837.jpg#2r1n,neutral,Children smiling and waving at camera,They are smiling at their parents,Just because they are smiling and waving at a ...,A18TOIDG32QICP,Children smiling and waving at camera,They are smiling *at* *their* *parents*,{},534
4,2267923837.jpg#2r1e,entailment,Children smiling and waving at camera,There are children present,The children must be present to see them smili...,AEX0YE6TUZRHT,*Children* *smiling* *and* *waving* at camera,There are children *present*,0132,3


In [65]:
def renameColumnsTrain(df):
    return df.rename(columns={'Sentence1': 'premise', 'Sentence2': 'hypothesis', 'Explanation_1': 'explanation'}).drop(["WorkerId", "Sentence1_Highlighted_1", "Sentence2_Highlighted_1"], axis=1)

df_cleaned = renameColumnsTrain(df)

In [66]:
label_to_id = {"entailment": 0, "neutral": 1, "contradiction": 2}
id_to_label = {v: k for k, v in label_to_id.items()}

## Test

In [67]:
# cell for testing the model's output for a single example
tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
model = RobertaForSequenceClassification.from_pretrained('roberta-base')

# Tokenize input
premise = df_cleaned['premise'][0]
hypothesis = df_cleaned['hypothesis'][0]
explanation = df_cleaned['explanation'][0]
actual_label = df_cleaned['gold_label'][0]
encoded_input = tokenizer.encode_plus(premise, hypothesis, explanation, padding=True, truncation=True, return_tensors='pt')

labels = torch.tensor(df_cleaned['gold_label'].replace(label_to_id).tolist())[0]
print(encoded_input)
output = model(**encoded_input)

predicted_class = torch.argmax(output.logits, dim=1)

print(f"Premise: {premise}\nHypothesis: {hypothesis}\nExplanation: {explanation}\n")
print(f"True class: {actual_label}")
print(f"Predicted class: {id_to_label[predicted_class.item()]}")

Some weights of the model checkpoint at roberta-base were not used when initializing RobertaForSequenceClassification: ['lm_head.dense.bias', 'lm_head.bias', 'lm_head.dense.weight', 'lm_head.layer_norm.bias', 'lm_head.layer_norm.weight']
- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-base and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.weight', 'classifier.out_proj.bias']
You should pr

{'input_ids': tensor([[    0,   250,   621,    15,    10,  5253, 13855,    81,    10,  3187,
           159, 16847,     4,     2,     2,   250,   621,    16,  1058,    39,
          5253,    13,    10,  1465,     4,     2]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1]])}
Premise: A person on a horse jumps over a broken down airplane.
Hypothesis: A person is training his horse for a competition.
Explanation: the person is not necessarily training his horse

True class: neutral
Predicted class: neutral


Concatenate everything together

# RoBERTa classifier

---

In [272]:
def filterNan(df):
    return df.dropna()

# def tokenize(df):
#     return df.apply(lambda x: tokenizer.encode_plus(x['premise'], x['hypothesis'], x['explanation'], padding='max_length', return_tensors='pt'), axis=1)

def convert_to_tensors(df):
    return torch.tensor(df.values)

def encode_labels(df):
    return df.apply(lambda x: label_to_id[x])

template = "Given that {}, it is hypothesized that {}. {}."

def tokenize(df):
    tokenized_batch = []
    for _, row in df.iterrows():
        premise = row['premise'].lower()
        if premise[-1] in ['.', '!', '?']:
            premise = premise[:-1]
        hypothesis = row['hypothesis'].lower()
        if hypothesis[-1] in ['.', '!', '?']:
            hypothesis = hypothesis[:-1]
        explanation = row['explanation'].lower()
        if explanation[-1] in ['.', '!', '?']:
            explanation = explanation[:-1]


        encoded_dict = tokenizer.encode_plus(
            text = template.format(premise, hypothesis, explanation),
            # row['premise'], # two ways to encode
            # row['hypothesis'], 
            # row['explanation'],
            padding=True,
            return_tensors='pt',
            # truncation=True
        )
        tokenized_batch.append(encoded_dict)
    return tokenized_batch

[RoBERTA huggingface](https://huggingface.co/FacebookAI/roberta-base#:~:text=RoBERTa%20is%20a%20transformers%20model%20pretrained%20on%20a,to%20generate%20inputs%20and%20labels%20from%20those%20texts)

In [273]:
tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
model = RobertaForSequenceClassification.from_pretrained('roberta-base', num_labels=3)

Some weights of the model checkpoint at roberta-base were not used when initializing RobertaForSequenceClassification: ['lm_head.dense.bias', 'lm_head.bias', 'lm_head.dense.weight', 'lm_head.layer_norm.bias', 'lm_head.layer_norm.weight']
- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-base and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.weight', 'classifier.out_proj.bias']
You should pr

In [274]:
# for mac
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

# for nvidia GPUs
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [275]:
df_cleaned = renameColumnsTrain(df)
df_cleaned = df_cleaned[:5000] # change accordingly to the size of dataset you want to train on
df_cleaned = filterNan(df_cleaned)
df_cleaned['gold_label'] = encode_labels(df_cleaned['gold_label'])
# split the data into training and validation sets before processing
train_size = int(0.8 * len(df_cleaned))
val_size = len(df_cleaned) - train_size

train_dataset, val_dataset = torch.utils.data.random_split(df_cleaned, [train_size, val_size])
train_dataset = train_dataset.dataset
val_dataset = val_dataset.dataset

tokenized_input_train = tokenize(train_dataset)
tokenized_input_val = tokenize(val_dataset)
train_labels = convert_to_tensors(train_dataset['gold_label'])
val_labels = convert_to_tensors(val_dataset['gold_label'])


# actual_labels = convert_to_tensors(df_cleaned['gold_label'])

In [276]:
tokenized_input[0]

{'input_ids': tensor([[    0, 18377,    14,    10,   621,    15,    10,  5253, 13855,    81,
            10,  3187,   159, 16847,     6,    24,    16, 45936,    14,    10,
           621,    16,  1058,    39,  5253,    13,    10,  1465,     6,   142,
             5,   621,    16,    45,  4784,  1058,    39,  5253,     4,     2]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}

## Sample

Template goes like this:

Given that [PREMISE] it is hypothesized that [HYPOTHESIS]. This is <mask> because [EXPLANATION]

if model predicts contradict then,

Given that a person on a horse jumps over a broken down airplane, it is hypothesized that a person is training his horse for a competition. This is <span style="color:red">contradiction</span> because the person is not necessarily training his horse.

In [282]:
mask_token = tokenizer.mask_token
template = "Given that {}, it is hypothesized that {}. This is {} because {}."

In [279]:
sample_premise = df_cleaned['premise'][0].lower()
sample_premise = sample_premise[:-1] if sample_premise[-1] in ['.', '!', '?'] else sample_premise
sample_hypothesis = df_cleaned['hypothesis'][0].lower()
sample_hypothesis = sample_hypothesis[:-1] if sample_hypothesis[-1] in ['.', '!', '?'] else sample_hypothesis
sample_explanation = df_cleaned['explanation'][0].lower()
sample_explanation = sample_explanation[:-1] if sample_explanation[-1] in ['.', '!', '?'] else sample_explanation

In [286]:
encoded_input = tokenizer(template.format(sample_premise, sample_hypothesis, mask_token, sample_explanation),
                          padding="max_length",
                          truncation=True,
                          return_tensors='pt')

In [287]:
tokenizer.batch_decode(encoded_input.input_ids)

['<s>Given that a person on a horse jumps over a broken down airplane, it is hypothesized that a person is training his horse for a competition. This is<mask> because the person is not necessarily training his horse.</s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>

In [288]:
encoded_input.attention_mask # 1's are for the actual input tokens, 0's are for the padding

tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0

RoBERTa doesn't have token_type_ids, separate with </s> or tokenizer.sep_token

In [289]:
tokenizer.sep_token

'</s>'

In [290]:
input_ids, attention_mask = encoded_input.values()
output = model(input_ids=input_ids, attention_mask=attention_mask)

In [296]:
pred = output.logits.argmax()
pred, id_to_label[pred.item()], id_to_label[df_cleaned['gold_label'][0]]

(tensor(1), 'neutral', 'neutral')

## Predict without training

### updated

In [297]:
premise_template = 'Given that {}, it is hypothesized that {}.'
explanation_template = ' because {}.'

class eSNLIDataset(Dataset):
    def __init__(self, df, train=True):
        self.df = df
        self.train = train

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        example = self.df.iloc[idx,:]
        premise = example["premise"]
        hypothesis = example["hypothesis"]
        explanation = example["explanation"]

        premise = premise_template.format(premise, hypothesis)
        explanation = explanation_template.format(explanation)

        if self.train:
            label = example["gold_label"]
            return premise, explanation, label
        
        return premise, explanation

labels = train_labels

dataset = eSNLIDataset(train_dataset, train=False) # just for testing
dataloader = DataLoader(dataset, batch_size=16)

In [298]:
predictions = []

model.to(device)
model.eval()
with torch.no_grad():
    for batch in dataloader:
        premise, explanation = batch
        encoded_input = tokenizer(premise, explanation, padding=True, truncation=True, return_tensors='pt').to(device)
        outputs = model(**encoded_input)
        logits = outputs.logits

        predictions.extend(logits.argmax(dim=-1).cpu().tolist())



In [300]:
print(calc_f1_score(predictions, labels))

(0.49804746170021025, 0.3316, 0.16601582056673678)


In [301]:
import torch
from torch.utils.data import TensorDataset, DataLoader
from torch.nn.utils.rnn import pad_sequence

# print(tokenized_input)
input_ids = [x['input_ids'].squeeze(0) for x in tokenized_input_train]
input_ids = pad_sequence(input_ids, batch_first=True)
attention_masks = [x['attention_mask'].squeeze(0) for x in tokenized_input_train]
attention_masks = pad_sequence(attention_masks, batch_first=True)

labels = train_labels

dataset = TensorDataset(input_ids, attention_masks, labels)
loader = DataLoader(dataset, batch_size=16)

model.to(device)
model.eval()
predictions = []

with torch.no_grad():
    for batch in loader:
        batch_input_ids, batch_attention_mask, batch_labels = batch

        batch_input_ids = batch_input_ids.to(device)
        batch_attention_mask = batch_attention_mask.to(device)
        
        outputs = model(input_ids=batch_input_ids, attention_mask=batch_attention_mask)
        logits = outputs.logits.cpu()

        predicted_classes = torch.argmax(logits, dim=1)
        predictions.extend(predicted_classes)

predictions = torch.stack(predictions)

In [302]:
from sklearn.metrics import f1_score

def calc_f1_score(predicted_classes, actual_labels):
    return f1_score(predicted_classes, actual_labels, average='weighted'), f1_score(predicted_classes, actual_labels, average='micro'), f1_score(predicted_classes, actual_labels, average='macro')


In [304]:
print(calc_f1_score(predictions, labels))

(0.49804746170021025, 0.3316, 0.16601582056673678)


Training Loop using Trainer

In [305]:
from transformers import Trainer, TrainingArguments
from torch.utils.data import Dataset, TensorDataset
import torch

# freeze all the parameters in the base model
for param in model.roberta.parameters():
    param.requires_grad = False

# only train the classification head
for param in model.classifier.parameters():
    param.requires_grad = True

# need to wrap in a dictionary to use the Trainer class
class DictDataset(Dataset):
    def __init__(self, tensor_dataset):
        self.tensor_dataset = tensor_dataset

    def __len__(self):
        return len(self.tensor_dataset)

    def __getitem__(self, idx):
        input_ids, attention_mask, labels = self.tensor_dataset[idx]
        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': labels
        }

input_ids_train = [x['input_ids'].squeeze(0) for x in tokenized_input_train]
input_ids_train = torch.nn.utils.rnn.pad_sequence(input_ids_train, batch_first=True)
attention_masks_train = [x['attention_mask'].squeeze(0) for x in tokenized_input_train]
attention_masks_train = torch.nn.utils.rnn.pad_sequence(attention_masks_train, batch_first=True)
labels_one_hot_train = torch.nn.functional.one_hot(train_labels, num_classes=3).float()

input_ids_val = [x['input_ids'].squeeze(0) for x in tokenized_input_val]
input_ids_val = torch.nn.utils.rnn.pad_sequence(input_ids_val, batch_first=True)
attention_masks_val = [x['attention_mask'].squeeze(0) for x in tokenized_input_val]
attention_masks_val = torch.nn.utils.rnn.pad_sequence(attention_masks_val, batch_first=True)
labels_one_hot_val = torch.nn.functional.one_hot(val_labels, num_classes=3).float()

tensor_dataset = TensorDataset(input_ids_train, attention_masks_train, labels_one_hot_train)

dataset = DictDataset(tensor_dataset)

validation_tensor_dataset = TensorDataset(input_ids_val, attention_masks_val, labels_one_hot_val)
validation_dataset = DictDataset(validation_tensor_dataset)

training_args = TrainingArguments(
    output_dir='./results',          
    num_train_epochs=10,             
    per_device_train_batch_size=16,  
    per_device_eval_batch_size=64,   
    warmup_steps=500,                
    weight_decay=0.01,               
    logging_dir='./logs',            
    logging_steps=10,                
    evaluation_strategy='steps',     
    eval_steps=50,                   
    save_strategy='epoch',           
    save_steps=100,
)

# Initialize the Trainer with the wrapped dataset
trainer = Trainer(
    model=model,                   
    args=training_args,            
    train_dataset=dataset,
    eval_dataset=validation_dataset
)

# Train
trainer.train()




  0%|          | 0/3130 [00:00<?, ?it/s]

{'loss': 0.7144, 'learning_rate': 1.0000000000000002e-06, 'epoch': 0.03}
{'loss': 0.716, 'learning_rate': 2.0000000000000003e-06, 'epoch': 0.06}
{'loss': 0.7095, 'learning_rate': 3e-06, 'epoch': 0.1}
{'loss': 0.7117, 'learning_rate': 4.000000000000001e-06, 'epoch': 0.13}
{'loss': 0.7076, 'learning_rate': 5e-06, 'epoch': 0.16}


  0%|          | 0/79 [00:00<?, ?it/s]

{'eval_loss': 0.7053645253181458, 'eval_runtime': 196.2196, 'eval_samples_per_second': 25.482, 'eval_steps_per_second': 0.403, 'epoch': 0.16}
{'loss': 0.7033, 'learning_rate': 6e-06, 'epoch': 0.19}
{'loss': 0.705, 'learning_rate': 7.000000000000001e-06, 'epoch': 0.22}
{'loss': 0.696, 'learning_rate': 8.000000000000001e-06, 'epoch': 0.26}
{'loss': 0.6886, 'learning_rate': 9e-06, 'epoch': 0.29}
{'loss': 0.6944, 'learning_rate': 1e-05, 'epoch': 0.32}


  0%|          | 0/79 [00:00<?, ?it/s]

{'eval_loss': 0.6850722432136536, 'eval_runtime': 202.136, 'eval_samples_per_second': 24.736, 'eval_steps_per_second': 0.391, 'epoch': 0.32}
{'loss': 0.6781, 'learning_rate': 1.1000000000000001e-05, 'epoch': 0.35}
{'loss': 0.6848, 'learning_rate': 1.2e-05, 'epoch': 0.38}
{'loss': 0.6723, 'learning_rate': 1.3000000000000001e-05, 'epoch': 0.42}
{'loss': 0.6642, 'learning_rate': 1.4000000000000001e-05, 'epoch': 0.45}
{'loss': 0.663, 'learning_rate': 1.5e-05, 'epoch': 0.48}


  0%|          | 0/79 [00:00<?, ?it/s]

{'eval_loss': 0.6625556349754333, 'eval_runtime': 198.0308, 'eval_samples_per_second': 25.249, 'eval_steps_per_second': 0.399, 'epoch': 0.48}
{'loss': 0.6562, 'learning_rate': 1.6000000000000003e-05, 'epoch': 0.51}
{'loss': 0.6558, 'learning_rate': 1.7000000000000003e-05, 'epoch': 0.54}
{'loss': 0.6527, 'learning_rate': 1.8e-05, 'epoch': 0.58}
{'loss': 0.6491, 'learning_rate': 1.9e-05, 'epoch': 0.61}
{'loss': 0.6523, 'learning_rate': 2e-05, 'epoch': 0.64}


  0%|          | 0/79 [00:00<?, ?it/s]

{'eval_loss': 0.6488460898399353, 'eval_runtime': 194.2217, 'eval_samples_per_second': 25.744, 'eval_steps_per_second': 0.407, 'epoch': 0.64}
{'loss': 0.6511, 'learning_rate': 2.1e-05, 'epoch': 0.67}
{'loss': 0.6468, 'learning_rate': 2.2000000000000003e-05, 'epoch': 0.7}
{'loss': 0.6461, 'learning_rate': 2.3000000000000003e-05, 'epoch': 0.73}
{'loss': 0.6488, 'learning_rate': 2.4e-05, 'epoch': 0.77}
{'loss': 0.6462, 'learning_rate': 2.5e-05, 'epoch': 0.8}


  0%|          | 0/79 [00:00<?, ?it/s]

{'eval_loss': 0.6406314969062805, 'eval_runtime': 192.0931, 'eval_samples_per_second': 26.029, 'eval_steps_per_second': 0.411, 'epoch': 0.8}
{'loss': 0.6427, 'learning_rate': 2.6000000000000002e-05, 'epoch': 0.83}
{'loss': 0.6398, 'learning_rate': 2.7000000000000002e-05, 'epoch': 0.86}
{'loss': 0.6417, 'learning_rate': 2.8000000000000003e-05, 'epoch': 0.89}
{'loss': 0.6425, 'learning_rate': 2.9e-05, 'epoch': 0.93}
{'loss': 0.6439, 'learning_rate': 3e-05, 'epoch': 0.96}


  0%|          | 0/79 [00:00<?, ?it/s]

{'eval_loss': 0.637536883354187, 'eval_runtime': 197.4539, 'eval_samples_per_second': 25.322, 'eval_steps_per_second': 0.4, 'epoch': 0.96}
{'loss': 0.6416, 'learning_rate': 3.1e-05, 'epoch': 0.99}
{'loss': 0.6424, 'learning_rate': 3.2000000000000005e-05, 'epoch': 1.02}
{'loss': 0.6402, 'learning_rate': 3.3e-05, 'epoch': 1.05}
{'loss': 0.6416, 'learning_rate': 3.4000000000000007e-05, 'epoch': 1.09}
{'loss': 0.6405, 'learning_rate': 3.5e-05, 'epoch': 1.12}


  0%|          | 0/79 [00:00<?, ?it/s]

{'eval_loss': 0.6355772614479065, 'eval_runtime': 197.744, 'eval_samples_per_second': 25.285, 'eval_steps_per_second': 0.4, 'epoch': 1.12}
{'loss': 0.6378, 'learning_rate': 3.6e-05, 'epoch': 1.15}
{'loss': 0.6363, 'learning_rate': 3.7e-05, 'epoch': 1.18}
{'loss': 0.6377, 'learning_rate': 3.8e-05, 'epoch': 1.21}
{'loss': 0.6426, 'learning_rate': 3.9000000000000006e-05, 'epoch': 1.25}
{'loss': 0.641, 'learning_rate': 4e-05, 'epoch': 1.28}


  0%|          | 0/79 [00:00<?, ?it/s]

{'eval_loss': 0.6345017552375793, 'eval_runtime': 197.1772, 'eval_samples_per_second': 25.358, 'eval_steps_per_second': 0.401, 'epoch': 1.28}
{'loss': 0.637, 'learning_rate': 4.1e-05, 'epoch': 1.31}
{'loss': 0.6383, 'learning_rate': 4.2e-05, 'epoch': 1.34}
{'loss': 0.6384, 'learning_rate': 4.3e-05, 'epoch': 1.37}
{'loss': 0.6354, 'learning_rate': 4.4000000000000006e-05, 'epoch': 1.41}
{'loss': 0.6414, 'learning_rate': 4.5e-05, 'epoch': 1.44}


  0%|          | 0/79 [00:00<?, ?it/s]

{'eval_loss': 0.6342917680740356, 'eval_runtime': 193.368, 'eval_samples_per_second': 25.857, 'eval_steps_per_second': 0.409, 'epoch': 1.44}
{'loss': 0.6417, 'learning_rate': 4.600000000000001e-05, 'epoch': 1.47}
{'loss': 0.6354, 'learning_rate': 4.7e-05, 'epoch': 1.5}
{'loss': 0.635, 'learning_rate': 4.8e-05, 'epoch': 1.53}
{'loss': 0.6392, 'learning_rate': 4.9e-05, 'epoch': 1.57}
{'loss': 0.6356, 'learning_rate': 5e-05, 'epoch': 1.6}


  0%|          | 0/79 [00:00<?, ?it/s]

{'eval_loss': 0.633437991142273, 'eval_runtime': 214.072, 'eval_samples_per_second': 23.357, 'eval_steps_per_second': 0.369, 'epoch': 1.6}
{'loss': 0.6397, 'learning_rate': 4.980988593155894e-05, 'epoch': 1.63}
{'loss': 0.6387, 'learning_rate': 4.9619771863117875e-05, 'epoch': 1.66}
{'loss': 0.6383, 'learning_rate': 4.942965779467681e-05, 'epoch': 1.69}
{'loss': 0.6365, 'learning_rate': 4.923954372623574e-05, 'epoch': 1.73}
{'loss': 0.6326, 'learning_rate': 4.904942965779468e-05, 'epoch': 1.76}


  0%|          | 0/79 [00:00<?, ?it/s]

{'eval_loss': 0.6319358348846436, 'eval_runtime': 209.352, 'eval_samples_per_second': 23.883, 'eval_steps_per_second': 0.377, 'epoch': 1.76}
{'loss': 0.6374, 'learning_rate': 4.8859315589353615e-05, 'epoch': 1.79}
{'loss': 0.6349, 'learning_rate': 4.866920152091255e-05, 'epoch': 1.82}
{'loss': 0.6342, 'learning_rate': 4.847908745247148e-05, 'epoch': 1.85}
{'loss': 0.6384, 'learning_rate': 4.8288973384030424e-05, 'epoch': 1.88}
{'loss': 0.6341, 'learning_rate': 4.8098859315589354e-05, 'epoch': 1.92}


  0%|          | 0/79 [00:00<?, ?it/s]

{'eval_loss': 0.6312624216079712, 'eval_runtime': 205.5102, 'eval_samples_per_second': 24.33, 'eval_steps_per_second': 0.384, 'epoch': 1.92}
{'loss': 0.6295, 'learning_rate': 4.790874524714829e-05, 'epoch': 1.95}
{'loss': 0.6405, 'learning_rate': 4.771863117870723e-05, 'epoch': 1.98}
{'loss': 0.635, 'learning_rate': 4.7528517110266163e-05, 'epoch': 2.01}
{'loss': 0.6362, 'learning_rate': 4.73384030418251e-05, 'epoch': 2.04}
{'loss': 0.6288, 'learning_rate': 4.714828897338403e-05, 'epoch': 2.08}


  0%|          | 0/79 [00:00<?, ?it/s]

{'eval_loss': 0.6302877068519592, 'eval_runtime': 196.4629, 'eval_samples_per_second': 25.45, 'eval_steps_per_second': 0.402, 'epoch': 2.08}
{'loss': 0.6336, 'learning_rate': 4.695817490494297e-05, 'epoch': 2.11}
{'loss': 0.6347, 'learning_rate': 4.67680608365019e-05, 'epoch': 2.14}
{'loss': 0.6344, 'learning_rate': 4.657794676806084e-05, 'epoch': 2.17}
{'loss': 0.6274, 'learning_rate': 4.6387832699619776e-05, 'epoch': 2.2}
{'loss': 0.63, 'learning_rate': 4.619771863117871e-05, 'epoch': 2.24}


  0%|          | 0/79 [00:00<?, ?it/s]

{'eval_loss': 0.6298921704292297, 'eval_runtime': 197.3809, 'eval_samples_per_second': 25.332, 'eval_steps_per_second': 0.4, 'epoch': 2.24}
{'loss': 0.6298, 'learning_rate': 4.600760456273764e-05, 'epoch': 2.27}
{'loss': 0.6372, 'learning_rate': 4.581749049429658e-05, 'epoch': 2.3}
{'loss': 0.6287, 'learning_rate': 4.5627376425855515e-05, 'epoch': 2.33}
{'loss': 0.6313, 'learning_rate': 4.543726235741445e-05, 'epoch': 2.36}
{'loss': 0.627, 'learning_rate': 4.524714828897338e-05, 'epoch': 2.4}


  0%|          | 0/79 [00:00<?, ?it/s]

{'eval_loss': 0.628957211971283, 'eval_runtime': 205.3187, 'eval_samples_per_second': 24.352, 'eval_steps_per_second': 0.385, 'epoch': 2.4}
{'loss': 0.6303, 'learning_rate': 4.5057034220532325e-05, 'epoch': 2.43}
{'loss': 0.6282, 'learning_rate': 4.4866920152091254e-05, 'epoch': 2.46}
{'loss': 0.6327, 'learning_rate': 4.467680608365019e-05, 'epoch': 2.49}
{'loss': 0.6304, 'learning_rate': 4.448669201520913e-05, 'epoch': 2.52}
{'loss': 0.6319, 'learning_rate': 4.4296577946768064e-05, 'epoch': 2.56}


  0%|          | 0/79 [00:00<?, ?it/s]

{'eval_loss': 0.6277164816856384, 'eval_runtime': 204.3202, 'eval_samples_per_second': 24.471, 'eval_steps_per_second': 0.387, 'epoch': 2.56}
{'loss': 0.6286, 'learning_rate': 4.4106463878327e-05, 'epoch': 2.59}
{'loss': 0.6282, 'learning_rate': 4.391634980988593e-05, 'epoch': 2.62}
{'loss': 0.6317, 'learning_rate': 4.3726235741444873e-05, 'epoch': 2.65}
{'loss': 0.6309, 'learning_rate': 4.35361216730038e-05, 'epoch': 2.68}
{'loss': 0.6303, 'learning_rate': 4.334600760456274e-05, 'epoch': 2.72}


  0%|          | 0/79 [00:00<?, ?it/s]

{'eval_loss': 0.6268128156661987, 'eval_runtime': 199.3827, 'eval_samples_per_second': 25.077, 'eval_steps_per_second': 0.396, 'epoch': 2.72}
{'loss': 0.6308, 'learning_rate': 4.3155893536121676e-05, 'epoch': 2.75}
{'loss': 0.6254, 'learning_rate': 4.296577946768061e-05, 'epoch': 2.78}
{'loss': 0.6297, 'learning_rate': 4.277566539923954e-05, 'epoch': 2.81}
{'loss': 0.6295, 'learning_rate': 4.258555133079848e-05, 'epoch': 2.84}
{'loss': 0.6262, 'learning_rate': 4.2395437262357415e-05, 'epoch': 2.88}


  0%|          | 0/79 [00:00<?, ?it/s]

{'eval_loss': 0.6266138553619385, 'eval_runtime': 204.2653, 'eval_samples_per_second': 24.478, 'eval_steps_per_second': 0.387, 'epoch': 2.88}
{'loss': 0.6278, 'learning_rate': 4.220532319391635e-05, 'epoch': 2.91}
{'loss': 0.6259, 'learning_rate': 4.201520912547529e-05, 'epoch': 2.94}
{'loss': 0.626, 'learning_rate': 4.1825095057034225e-05, 'epoch': 2.97}
{'loss': 0.6271, 'learning_rate': 4.163498098859316e-05, 'epoch': 3.0}
{'loss': 0.6251, 'learning_rate': 4.144486692015209e-05, 'epoch': 3.04}


  0%|          | 0/79 [00:00<?, ?it/s]

{'eval_loss': 0.6256536841392517, 'eval_runtime': 204.9242, 'eval_samples_per_second': 24.399, 'eval_steps_per_second': 0.386, 'epoch': 3.04}
{'loss': 0.6256, 'learning_rate': 4.125475285171103e-05, 'epoch': 3.07}
{'loss': 0.6251, 'learning_rate': 4.1064638783269964e-05, 'epoch': 3.1}
{'loss': 0.6278, 'learning_rate': 4.08745247148289e-05, 'epoch': 3.13}
{'loss': 0.6297, 'learning_rate': 4.068441064638783e-05, 'epoch': 3.16}
{'loss': 0.6211, 'learning_rate': 4.0494296577946774e-05, 'epoch': 3.19}


  0%|          | 0/79 [00:00<?, ?it/s]

{'eval_loss': 0.6245977878570557, 'eval_runtime': 205.189, 'eval_samples_per_second': 24.368, 'eval_steps_per_second': 0.385, 'epoch': 3.19}
{'loss': 0.6227, 'learning_rate': 4.0304182509505703e-05, 'epoch': 3.23}
{'loss': 0.6292, 'learning_rate': 4.011406844106464e-05, 'epoch': 3.26}
{'loss': 0.6262, 'learning_rate': 3.9923954372623577e-05, 'epoch': 3.29}
{'loss': 0.6221, 'learning_rate': 3.973384030418251e-05, 'epoch': 3.32}
{'loss': 0.6249, 'learning_rate': 3.954372623574145e-05, 'epoch': 3.35}


  0%|          | 0/79 [00:00<?, ?it/s]

{'eval_loss': 0.6237746477127075, 'eval_runtime': 202.3721, 'eval_samples_per_second': 24.707, 'eval_steps_per_second': 0.39, 'epoch': 3.35}
{'loss': 0.6253, 'learning_rate': 3.935361216730038e-05, 'epoch': 3.39}
{'loss': 0.6242, 'learning_rate': 3.916349809885932e-05, 'epoch': 3.42}
{'loss': 0.6234, 'learning_rate': 3.897338403041825e-05, 'epoch': 3.45}
{'loss': 0.6214, 'learning_rate': 3.878326996197719e-05, 'epoch': 3.48}
{'loss': 0.6247, 'learning_rate': 3.8593155893536125e-05, 'epoch': 3.51}


  0%|          | 0/79 [00:00<?, ?it/s]

{'eval_loss': 0.6228663921356201, 'eval_runtime': 204.8808, 'eval_samples_per_second': 24.404, 'eval_steps_per_second': 0.386, 'epoch': 3.51}
{'loss': 0.6273, 'learning_rate': 3.840304182509506e-05, 'epoch': 3.55}
{'loss': 0.6303, 'learning_rate': 3.821292775665399e-05, 'epoch': 3.58}
{'loss': 0.624, 'learning_rate': 3.802281368821293e-05, 'epoch': 3.61}
{'loss': 0.6194, 'learning_rate': 3.7832699619771865e-05, 'epoch': 3.64}
{'loss': 0.6197, 'learning_rate': 3.76425855513308e-05, 'epoch': 3.67}


  0%|          | 0/79 [00:00<?, ?it/s]

{'eval_loss': 0.6223964095115662, 'eval_runtime': 204.8504, 'eval_samples_per_second': 24.408, 'eval_steps_per_second': 0.386, 'epoch': 3.67}
{'loss': 0.626, 'learning_rate': 3.745247148288973e-05, 'epoch': 3.71}
{'loss': 0.6189, 'learning_rate': 3.7262357414448674e-05, 'epoch': 3.74}
{'loss': 0.6278, 'learning_rate': 3.7072243346007604e-05, 'epoch': 3.77}
{'loss': 0.6225, 'learning_rate': 3.688212927756654e-05, 'epoch': 3.8}
{'loss': 0.6232, 'learning_rate': 3.669201520912548e-05, 'epoch': 3.83}


  0%|          | 0/79 [00:00<?, ?it/s]

{'eval_loss': 0.621483564376831, 'eval_runtime': 194.585, 'eval_samples_per_second': 25.696, 'eval_steps_per_second': 0.406, 'epoch': 3.83}
{'loss': 0.6258, 'learning_rate': 3.6501901140684413e-05, 'epoch': 3.87}
{'loss': 0.6254, 'learning_rate': 3.631178707224335e-05, 'epoch': 3.9}
{'loss': 0.6265, 'learning_rate': 3.612167300380228e-05, 'epoch': 3.93}
{'loss': 0.6225, 'learning_rate': 3.593155893536122e-05, 'epoch': 3.96}
{'loss': 0.6212, 'learning_rate': 3.574144486692015e-05, 'epoch': 3.99}


  0%|          | 0/79 [00:00<?, ?it/s]

{'eval_loss': 0.620911717414856, 'eval_runtime': 202.3112, 'eval_samples_per_second': 24.714, 'eval_steps_per_second': 0.39, 'epoch': 3.99}
{'loss': 0.6201, 'learning_rate': 3.555133079847909e-05, 'epoch': 4.03}
{'loss': 0.6317, 'learning_rate': 3.5361216730038026e-05, 'epoch': 4.06}
{'loss': 0.6178, 'learning_rate': 3.517110266159696e-05, 'epoch': 4.09}
{'loss': 0.6173, 'learning_rate': 3.498098859315589e-05, 'epoch': 4.12}
{'loss': 0.6229, 'learning_rate': 3.479087452471483e-05, 'epoch': 4.15}


  0%|          | 0/79 [00:00<?, ?it/s]

{'eval_loss': 0.6204317808151245, 'eval_runtime': 198.509, 'eval_samples_per_second': 25.188, 'eval_steps_per_second': 0.398, 'epoch': 4.15}
{'loss': 0.621, 'learning_rate': 3.4600760456273765e-05, 'epoch': 4.19}
{'loss': 0.6229, 'learning_rate': 3.44106463878327e-05, 'epoch': 4.22}
{'loss': 0.6197, 'learning_rate': 3.422053231939164e-05, 'epoch': 4.25}
{'loss': 0.6212, 'learning_rate': 3.4030418250950574e-05, 'epoch': 4.28}
{'loss': 0.6191, 'learning_rate': 3.384030418250951e-05, 'epoch': 4.31}


  0%|          | 0/79 [00:00<?, ?it/s]

{'eval_loss': 0.619526743888855, 'eval_runtime': 197.178, 'eval_samples_per_second': 25.358, 'eval_steps_per_second': 0.401, 'epoch': 4.31}
{'loss': 0.6202, 'learning_rate': 3.365019011406844e-05, 'epoch': 4.35}
{'loss': 0.6198, 'learning_rate': 3.346007604562738e-05, 'epoch': 4.38}
{'loss': 0.6292, 'learning_rate': 3.3269961977186314e-05, 'epoch': 4.41}
{'loss': 0.6173, 'learning_rate': 3.307984790874525e-05, 'epoch': 4.44}
{'loss': 0.6145, 'learning_rate': 3.288973384030418e-05, 'epoch': 4.47}


  0%|          | 0/79 [00:00<?, ?it/s]

{'eval_loss': 0.6185753345489502, 'eval_runtime': 189.707, 'eval_samples_per_second': 26.356, 'eval_steps_per_second': 0.416, 'epoch': 4.47}
{'loss': 0.6198, 'learning_rate': 3.269961977186312e-05, 'epoch': 4.5}
{'loss': 0.616, 'learning_rate': 3.250950570342205e-05, 'epoch': 4.54}
{'loss': 0.6161, 'learning_rate': 3.231939163498099e-05, 'epoch': 4.57}
{'loss': 0.6159, 'learning_rate': 3.2129277566539926e-05, 'epoch': 4.6}
{'loss': 0.6175, 'learning_rate': 3.193916349809886e-05, 'epoch': 4.63}


  0%|          | 0/79 [00:00<?, ?it/s]

{'eval_loss': 0.6179611682891846, 'eval_runtime': 193.3709, 'eval_samples_per_second': 25.857, 'eval_steps_per_second': 0.409, 'epoch': 4.63}
{'loss': 0.6197, 'learning_rate': 3.174904942965779e-05, 'epoch': 4.66}
{'loss': 0.6212, 'learning_rate': 3.155893536121673e-05, 'epoch': 4.7}
{'loss': 0.6129, 'learning_rate': 3.1368821292775665e-05, 'epoch': 4.73}
{'loss': 0.6246, 'learning_rate': 3.11787072243346e-05, 'epoch': 4.76}
{'loss': 0.6182, 'learning_rate': 3.098859315589354e-05, 'epoch': 4.79}


  0%|          | 0/79 [00:00<?, ?it/s]

{'eval_loss': 0.6171596646308899, 'eval_runtime': 190.3297, 'eval_samples_per_second': 26.27, 'eval_steps_per_second': 0.415, 'epoch': 4.79}
{'loss': 0.6218, 'learning_rate': 3.0798479087452475e-05, 'epoch': 4.82}
{'loss': 0.6175, 'learning_rate': 3.060836501901141e-05, 'epoch': 4.86}
{'loss': 0.6171, 'learning_rate': 3.041825095057034e-05, 'epoch': 4.89}
{'loss': 0.6165, 'learning_rate': 3.0228136882129278e-05, 'epoch': 4.92}
{'loss': 0.6176, 'learning_rate': 3.0038022813688214e-05, 'epoch': 4.95}


  0%|          | 0/79 [00:00<?, ?it/s]

{'eval_loss': 0.6168698072433472, 'eval_runtime': 204.4656, 'eval_samples_per_second': 24.454, 'eval_steps_per_second': 0.386, 'epoch': 4.95}
{'loss': 0.6174, 'learning_rate': 2.984790874524715e-05, 'epoch': 4.98}
{'loss': 0.6175, 'learning_rate': 2.9657794676806084e-05, 'epoch': 5.02}
{'loss': 0.6149, 'learning_rate': 2.9467680608365024e-05, 'epoch': 5.05}
{'loss': 0.6175, 'learning_rate': 2.9277566539923957e-05, 'epoch': 5.08}
{'loss': 0.6237, 'learning_rate': 2.908745247148289e-05, 'epoch': 5.11}


  0%|          | 0/79 [00:00<?, ?it/s]

{'eval_loss': 0.6160021424293518, 'eval_runtime': 201.1153, 'eval_samples_per_second': 24.861, 'eval_steps_per_second': 0.393, 'epoch': 5.11}
{'loss': 0.6172, 'learning_rate': 2.8897338403041823e-05, 'epoch': 5.14}
{'loss': 0.6126, 'learning_rate': 2.8707224334600763e-05, 'epoch': 5.18}
{'loss': 0.6183, 'learning_rate': 2.8517110266159696e-05, 'epoch': 5.21}
{'loss': 0.6116, 'learning_rate': 2.832699619771863e-05, 'epoch': 5.24}
{'loss': 0.6216, 'learning_rate': 2.813688212927757e-05, 'epoch': 5.27}


  0%|          | 0/79 [00:00<?, ?it/s]

{'eval_loss': 0.6154322028160095, 'eval_runtime': 203.9195, 'eval_samples_per_second': 24.519, 'eval_steps_per_second': 0.387, 'epoch': 5.27}
{'loss': 0.6176, 'learning_rate': 2.7946768060836502e-05, 'epoch': 5.3}
{'loss': 0.6194, 'learning_rate': 2.775665399239544e-05, 'epoch': 5.34}
{'loss': 0.6165, 'learning_rate': 2.7566539923954375e-05, 'epoch': 5.37}
{'loss': 0.6162, 'learning_rate': 2.7376425855513312e-05, 'epoch': 5.4}
{'loss': 0.6125, 'learning_rate': 2.7186311787072245e-05, 'epoch': 5.43}


  0%|          | 0/79 [00:00<?, ?it/s]

{'eval_loss': 0.6144956350326538, 'eval_runtime': 199.6586, 'eval_samples_per_second': 25.043, 'eval_steps_per_second': 0.396, 'epoch': 5.43}
{'loss': 0.6149, 'learning_rate': 2.6996197718631178e-05, 'epoch': 5.46}
{'loss': 0.6139, 'learning_rate': 2.6806083650190118e-05, 'epoch': 5.5}
{'loss': 0.6143, 'learning_rate': 2.661596958174905e-05, 'epoch': 5.53}
{'loss': 0.6113, 'learning_rate': 2.6425855513307984e-05, 'epoch': 5.56}
{'loss': 0.6176, 'learning_rate': 2.6235741444866924e-05, 'epoch': 5.59}


  0%|          | 0/79 [00:00<?, ?it/s]

{'eval_loss': 0.613808274269104, 'eval_runtime': 211.7733, 'eval_samples_per_second': 23.61, 'eval_steps_per_second': 0.373, 'epoch': 5.59}
{'loss': 0.6171, 'learning_rate': 2.6045627376425857e-05, 'epoch': 5.62}
{'loss': 0.6178, 'learning_rate': 2.585551330798479e-05, 'epoch': 5.65}
{'loss': 0.6135, 'learning_rate': 2.5665399239543723e-05, 'epoch': 5.69}
{'loss': 0.6153, 'learning_rate': 2.5475285171102663e-05, 'epoch': 5.72}
{'loss': 0.6155, 'learning_rate': 2.5285171102661596e-05, 'epoch': 5.75}


  0%|          | 0/79 [00:00<?, ?it/s]

{'eval_loss': 0.613182008266449, 'eval_runtime': 207.3071, 'eval_samples_per_second': 24.119, 'eval_steps_per_second': 0.381, 'epoch': 5.75}
{'loss': 0.6175, 'learning_rate': 2.5095057034220533e-05, 'epoch': 5.78}
{'loss': 0.6134, 'learning_rate': 2.490494296577947e-05, 'epoch': 5.81}
{'loss': 0.6134, 'learning_rate': 2.4714828897338406e-05, 'epoch': 5.85}
{'loss': 0.6184, 'learning_rate': 2.452471482889734e-05, 'epoch': 5.88}
{'loss': 0.6087, 'learning_rate': 2.4334600760456276e-05, 'epoch': 5.91}


  0%|          | 0/79 [00:00<?, ?it/s]

{'eval_loss': 0.6128717064857483, 'eval_runtime': 201.09, 'eval_samples_per_second': 24.864, 'eval_steps_per_second': 0.393, 'epoch': 5.91}
{'loss': 0.6124, 'learning_rate': 2.4144486692015212e-05, 'epoch': 5.94}
{'loss': 0.6112, 'learning_rate': 2.3954372623574145e-05, 'epoch': 5.97}
{'loss': 0.6085, 'learning_rate': 2.3764258555133082e-05, 'epoch': 6.01}
{'loss': 0.6125, 'learning_rate': 2.3574144486692015e-05, 'epoch': 6.04}
{'loss': 0.6186, 'learning_rate': 2.338403041825095e-05, 'epoch': 6.07}


  0%|          | 0/79 [00:00<?, ?it/s]

{'eval_loss': 0.612210750579834, 'eval_runtime': 194.0257, 'eval_samples_per_second': 25.77, 'eval_steps_per_second': 0.407, 'epoch': 6.07}
{'loss': 0.6139, 'learning_rate': 2.3193916349809888e-05, 'epoch': 6.1}
{'loss': 0.6128, 'learning_rate': 2.300380228136882e-05, 'epoch': 6.13}
{'loss': 0.6203, 'learning_rate': 2.2813688212927758e-05, 'epoch': 6.17}
{'loss': 0.6057, 'learning_rate': 2.262357414448669e-05, 'epoch': 6.2}
{'loss': 0.6144, 'learning_rate': 2.2433460076045627e-05, 'epoch': 6.23}


  0%|          | 0/79 [00:00<?, ?it/s]

{'eval_loss': 0.6115713119506836, 'eval_runtime': 196.2003, 'eval_samples_per_second': 25.484, 'eval_steps_per_second': 0.403, 'epoch': 6.23}
{'loss': 0.6173, 'learning_rate': 2.2243346007604564e-05, 'epoch': 6.26}
{'loss': 0.6071, 'learning_rate': 2.20532319391635e-05, 'epoch': 6.29}
{'loss': 0.6148, 'learning_rate': 2.1863117870722437e-05, 'epoch': 6.33}
{'loss': 0.6022, 'learning_rate': 2.167300380228137e-05, 'epoch': 6.36}
{'loss': 0.6154, 'learning_rate': 2.1482889733840306e-05, 'epoch': 6.39}


  0%|          | 0/79 [00:00<?, ?it/s]

{'eval_loss': 0.6108203530311584, 'eval_runtime': 191.4336, 'eval_samples_per_second': 26.119, 'eval_steps_per_second': 0.413, 'epoch': 6.39}
{'loss': 0.6209, 'learning_rate': 2.129277566539924e-05, 'epoch': 6.42}
{'loss': 0.609, 'learning_rate': 2.1102661596958176e-05, 'epoch': 6.45}
{'loss': 0.6125, 'learning_rate': 2.0912547528517112e-05, 'epoch': 6.49}
{'loss': 0.6066, 'learning_rate': 2.0722433460076046e-05, 'epoch': 6.52}
{'loss': 0.6063, 'learning_rate': 2.0532319391634982e-05, 'epoch': 6.55}


  0%|          | 0/79 [00:00<?, ?it/s]

{'eval_loss': 0.6103072762489319, 'eval_runtime': 210.3317, 'eval_samples_per_second': 23.772, 'eval_steps_per_second': 0.376, 'epoch': 6.55}
{'loss': 0.6055, 'learning_rate': 2.0342205323193915e-05, 'epoch': 6.58}
{'loss': 0.6104, 'learning_rate': 2.0152091254752852e-05, 'epoch': 6.61}
{'loss': 0.6193, 'learning_rate': 1.9961977186311788e-05, 'epoch': 6.65}
{'loss': 0.6113, 'learning_rate': 1.9771863117870725e-05, 'epoch': 6.68}
{'loss': 0.6036, 'learning_rate': 1.958174904942966e-05, 'epoch': 6.71}


  0%|          | 0/79 [00:00<?, ?it/s]

{'eval_loss': 0.6098010540008545, 'eval_runtime': 205.1426, 'eval_samples_per_second': 24.373, 'eval_steps_per_second': 0.385, 'epoch': 6.71}
{'loss': 0.6088, 'learning_rate': 1.9391634980988594e-05, 'epoch': 6.74}
{'loss': 0.6156, 'learning_rate': 1.920152091254753e-05, 'epoch': 6.77}
{'loss': 0.6191, 'learning_rate': 1.9011406844106464e-05, 'epoch': 6.81}
{'loss': 0.6121, 'learning_rate': 1.88212927756654e-05, 'epoch': 6.84}
{'loss': 0.6088, 'learning_rate': 1.8631178707224337e-05, 'epoch': 6.87}


  0%|          | 0/79 [00:00<?, ?it/s]

{'eval_loss': 0.6092498302459717, 'eval_runtime': 208.1274, 'eval_samples_per_second': 24.024, 'eval_steps_per_second': 0.38, 'epoch': 6.87}
{'loss': 0.6068, 'learning_rate': 1.844106463878327e-05, 'epoch': 6.9}
{'loss': 0.6115, 'learning_rate': 1.8250950570342207e-05, 'epoch': 6.93}
{'loss': 0.6076, 'learning_rate': 1.806083650190114e-05, 'epoch': 6.96}
{'loss': 0.614, 'learning_rate': 1.7870722433460076e-05, 'epoch': 7.0}
{'loss': 0.6113, 'learning_rate': 1.7680608365019013e-05, 'epoch': 7.03}


  0%|          | 0/79 [00:00<?, ?it/s]

{'eval_loss': 0.6087948083877563, 'eval_runtime': 217.1364, 'eval_samples_per_second': 23.027, 'eval_steps_per_second': 0.364, 'epoch': 7.03}
{'loss': 0.608, 'learning_rate': 1.7490494296577946e-05, 'epoch': 7.06}
{'loss': 0.6001, 'learning_rate': 1.7300380228136882e-05, 'epoch': 7.09}
{'loss': 0.6077, 'learning_rate': 1.711026615969582e-05, 'epoch': 7.12}
{'loss': 0.6036, 'learning_rate': 1.6920152091254756e-05, 'epoch': 7.16}
{'loss': 0.6086, 'learning_rate': 1.673003802281369e-05, 'epoch': 7.19}


  0%|          | 0/79 [00:00<?, ?it/s]

{'eval_loss': 0.6083433628082275, 'eval_runtime': 207.6299, 'eval_samples_per_second': 24.081, 'eval_steps_per_second': 0.38, 'epoch': 7.19}
{'loss': 0.6105, 'learning_rate': 1.6539923954372625e-05, 'epoch': 7.22}
{'loss': 0.605, 'learning_rate': 1.634980988593156e-05, 'epoch': 7.25}
{'loss': 0.61, 'learning_rate': 1.6159695817490495e-05, 'epoch': 7.28}
{'loss': 0.6085, 'learning_rate': 1.596958174904943e-05, 'epoch': 7.32}
{'loss': 0.605, 'learning_rate': 1.5779467680608364e-05, 'epoch': 7.35}


  0%|          | 0/79 [00:00<?, ?it/s]

{'eval_loss': 0.6078518629074097, 'eval_runtime': 218.3705, 'eval_samples_per_second': 22.897, 'eval_steps_per_second': 0.362, 'epoch': 7.35}
{'loss': 0.6056, 'learning_rate': 1.55893536121673e-05, 'epoch': 7.38}
{'loss': 0.6078, 'learning_rate': 1.5399239543726237e-05, 'epoch': 7.41}
{'loss': 0.6085, 'learning_rate': 1.520912547528517e-05, 'epoch': 7.44}
{'loss': 0.6052, 'learning_rate': 1.5019011406844107e-05, 'epoch': 7.48}
{'loss': 0.611, 'learning_rate': 1.4828897338403042e-05, 'epoch': 7.51}


  0%|          | 0/79 [00:00<?, ?it/s]

{'eval_loss': 0.6075143218040466, 'eval_runtime': 191.8319, 'eval_samples_per_second': 26.064, 'eval_steps_per_second': 0.412, 'epoch': 7.51}
{'loss': 0.618, 'learning_rate': 1.4638783269961978e-05, 'epoch': 7.54}
{'loss': 0.6024, 'learning_rate': 1.4448669201520912e-05, 'epoch': 7.57}
{'loss': 0.6055, 'learning_rate': 1.4258555133079848e-05, 'epoch': 7.6}
{'loss': 0.6083, 'learning_rate': 1.4068441064638785e-05, 'epoch': 7.64}
{'loss': 0.5968, 'learning_rate': 1.387832699619772e-05, 'epoch': 7.67}


  0%|          | 0/79 [00:00<?, ?it/s]

{'eval_loss': 0.60704505443573, 'eval_runtime': 191.3073, 'eval_samples_per_second': 26.136, 'eval_steps_per_second': 0.413, 'epoch': 7.67}
{'loss': 0.6094, 'learning_rate': 1.3688212927756656e-05, 'epoch': 7.7}
{'loss': 0.6067, 'learning_rate': 1.3498098859315589e-05, 'epoch': 7.73}
{'loss': 0.6088, 'learning_rate': 1.3307984790874526e-05, 'epoch': 7.76}
{'loss': 0.6039, 'learning_rate': 1.3117870722433462e-05, 'epoch': 7.8}
{'loss': 0.6153, 'learning_rate': 1.2927756653992395e-05, 'epoch': 7.83}


  0%|          | 0/79 [00:00<?, ?it/s]

{'eval_loss': 0.6067071557044983, 'eval_runtime': 191.2575, 'eval_samples_per_second': 26.143, 'eval_steps_per_second': 0.413, 'epoch': 7.83}
{'loss': 0.6105, 'learning_rate': 1.2737642585551332e-05, 'epoch': 7.86}
{'loss': 0.6061, 'learning_rate': 1.2547528517110266e-05, 'epoch': 7.89}
{'loss': 0.6059, 'learning_rate': 1.2357414448669203e-05, 'epoch': 7.92}
{'loss': 0.607, 'learning_rate': 1.2167300380228138e-05, 'epoch': 7.96}
{'loss': 0.6031, 'learning_rate': 1.1977186311787073e-05, 'epoch': 7.99}


  0%|          | 0/79 [00:00<?, ?it/s]

{'eval_loss': 0.6063896417617798, 'eval_runtime': 187.5426, 'eval_samples_per_second': 26.661, 'eval_steps_per_second': 0.421, 'epoch': 7.99}
{'loss': 0.6159, 'learning_rate': 1.1787072243346007e-05, 'epoch': 8.02}
{'loss': 0.6078, 'learning_rate': 1.1596958174904944e-05, 'epoch': 8.05}
{'loss': 0.6092, 'learning_rate': 1.1406844106463879e-05, 'epoch': 8.08}
{'loss': 0.6022, 'learning_rate': 1.1216730038022814e-05, 'epoch': 8.12}
{'loss': 0.6099, 'learning_rate': 1.102661596958175e-05, 'epoch': 8.15}


  0%|          | 0/79 [00:00<?, ?it/s]

{'eval_loss': 0.6061394810676575, 'eval_runtime': 190.3995, 'eval_samples_per_second': 26.261, 'eval_steps_per_second': 0.415, 'epoch': 8.15}
{'loss': 0.603, 'learning_rate': 1.0836501901140685e-05, 'epoch': 8.18}
{'loss': 0.6033, 'learning_rate': 1.064638783269962e-05, 'epoch': 8.21}
{'loss': 0.6078, 'learning_rate': 1.0456273764258556e-05, 'epoch': 8.24}
{'loss': 0.614, 'learning_rate': 1.0266159695817491e-05, 'epoch': 8.27}
{'loss': 0.6157, 'learning_rate': 1.0076045627376426e-05, 'epoch': 8.31}


  0%|          | 0/79 [00:00<?, ?it/s]

{'eval_loss': 0.605705976486206, 'eval_runtime': 189.6944, 'eval_samples_per_second': 26.358, 'eval_steps_per_second': 0.416, 'epoch': 8.31}
{'loss': 0.6168, 'learning_rate': 9.885931558935362e-06, 'epoch': 8.34}
{'loss': 0.6005, 'learning_rate': 9.695817490494297e-06, 'epoch': 8.37}
{'loss': 0.601, 'learning_rate': 9.505703422053232e-06, 'epoch': 8.4}
{'loss': 0.6074, 'learning_rate': 9.315589353612169e-06, 'epoch': 8.43}
{'loss': 0.6052, 'learning_rate': 9.125475285171103e-06, 'epoch': 8.47}


  0%|          | 0/79 [00:00<?, ?it/s]

{'eval_loss': 0.6055253148078918, 'eval_runtime': 188.8155, 'eval_samples_per_second': 26.481, 'eval_steps_per_second': 0.418, 'epoch': 8.47}
{'loss': 0.6153, 'learning_rate': 8.935361216730038e-06, 'epoch': 8.5}
{'loss': 0.6068, 'learning_rate': 8.745247148288973e-06, 'epoch': 8.53}
{'loss': 0.605, 'learning_rate': 8.55513307984791e-06, 'epoch': 8.56}
{'loss': 0.6019, 'learning_rate': 8.365019011406844e-06, 'epoch': 8.59}
{'loss': 0.5993, 'learning_rate': 8.17490494296578e-06, 'epoch': 8.63}


  0%|          | 0/79 [00:00<?, ?it/s]

{'eval_loss': 0.6052126884460449, 'eval_runtime': 189.9948, 'eval_samples_per_second': 26.317, 'eval_steps_per_second': 0.416, 'epoch': 8.63}
{'loss': 0.6016, 'learning_rate': 7.984790874524716e-06, 'epoch': 8.66}
{'loss': 0.5989, 'learning_rate': 7.79467680608365e-06, 'epoch': 8.69}
{'loss': 0.6026, 'learning_rate': 7.604562737642585e-06, 'epoch': 8.72}
{'loss': 0.603, 'learning_rate': 7.414448669201521e-06, 'epoch': 8.75}
{'loss': 0.601, 'learning_rate': 7.224334600760456e-06, 'epoch': 8.79}


  0%|          | 0/79 [00:00<?, ?it/s]

{'eval_loss': 0.6049203872680664, 'eval_runtime': 191.754, 'eval_samples_per_second': 26.075, 'eval_steps_per_second': 0.412, 'epoch': 8.79}
{'loss': 0.6028, 'learning_rate': 7.034220532319392e-06, 'epoch': 8.82}
{'loss': 0.6052, 'learning_rate': 6.844106463878328e-06, 'epoch': 8.85}
{'loss': 0.6043, 'learning_rate': 6.653992395437263e-06, 'epoch': 8.88}
{'loss': 0.5996, 'learning_rate': 6.4638783269961976e-06, 'epoch': 8.91}
{'loss': 0.6081, 'learning_rate': 6.273764258555133e-06, 'epoch': 8.95}


  0%|          | 0/79 [00:00<?, ?it/s]

{'eval_loss': 0.6047185659408569, 'eval_runtime': 193.3977, 'eval_samples_per_second': 25.853, 'eval_steps_per_second': 0.408, 'epoch': 8.95}
{'loss': 0.6041, 'learning_rate': 6.083650190114069e-06, 'epoch': 8.98}
{'loss': 0.6034, 'learning_rate': 5.893536121673004e-06, 'epoch': 9.01}
{'loss': 0.604, 'learning_rate': 5.703422053231939e-06, 'epoch': 9.04}
{'loss': 0.6036, 'learning_rate': 5.513307984790875e-06, 'epoch': 9.07}
{'loss': 0.6067, 'learning_rate': 5.32319391634981e-06, 'epoch': 9.11}


  0%|          | 0/79 [00:00<?, ?it/s]

{'eval_loss': 0.6045900583267212, 'eval_runtime': 193.8316, 'eval_samples_per_second': 25.796, 'eval_steps_per_second': 0.408, 'epoch': 9.11}
{'loss': 0.6061, 'learning_rate': 5.1330798479087455e-06, 'epoch': 9.14}
{'loss': 0.6008, 'learning_rate': 4.942965779467681e-06, 'epoch': 9.17}
{'loss': 0.6051, 'learning_rate': 4.752851711026616e-06, 'epoch': 9.2}
{'loss': 0.6051, 'learning_rate': 4.562737642585552e-06, 'epoch': 9.23}
{'loss': 0.6046, 'learning_rate': 4.3726235741444865e-06, 'epoch': 9.27}


  0%|          | 0/79 [00:00<?, ?it/s]

{'eval_loss': 0.6044347286224365, 'eval_runtime': 192.085, 'eval_samples_per_second': 26.03, 'eval_steps_per_second': 0.411, 'epoch': 9.27}
{'loss': 0.603, 'learning_rate': 4.182509505703422e-06, 'epoch': 9.3}
{'loss': 0.6066, 'learning_rate': 3.992395437262358e-06, 'epoch': 9.33}
{'loss': 0.6069, 'learning_rate': 3.8022813688212926e-06, 'epoch': 9.36}
{'loss': 0.608, 'learning_rate': 3.612167300380228e-06, 'epoch': 9.39}
{'loss': 0.5935, 'learning_rate': 3.422053231939164e-06, 'epoch': 9.42}


  0%|          | 0/79 [00:00<?, ?it/s]

{'eval_loss': 0.6042946577072144, 'eval_runtime': 191.4129, 'eval_samples_per_second': 26.122, 'eval_steps_per_second': 0.413, 'epoch': 9.42}
{'loss': 0.6026, 'learning_rate': 3.2319391634980988e-06, 'epoch': 9.46}
{'loss': 0.6082, 'learning_rate': 3.0418250950570345e-06, 'epoch': 9.49}
{'loss': 0.602, 'learning_rate': 2.8517110266159697e-06, 'epoch': 9.52}
{'loss': 0.6096, 'learning_rate': 2.661596958174905e-06, 'epoch': 9.55}
{'loss': 0.6037, 'learning_rate': 2.4714828897338406e-06, 'epoch': 9.58}


  0%|          | 0/79 [00:00<?, ?it/s]

{'eval_loss': 0.6042035818099976, 'eval_runtime': 191.6875, 'eval_samples_per_second': 26.084, 'eval_steps_per_second': 0.412, 'epoch': 9.58}
{'loss': 0.6043, 'learning_rate': 2.281368821292776e-06, 'epoch': 9.62}
{'loss': 0.6014, 'learning_rate': 2.091254752851711e-06, 'epoch': 9.65}
{'loss': 0.598, 'learning_rate': 1.9011406844106463e-06, 'epoch': 9.68}
{'loss': 0.6029, 'learning_rate': 1.711026615969582e-06, 'epoch': 9.71}
{'loss': 0.6037, 'learning_rate': 1.5209125475285172e-06, 'epoch': 9.74}


  0%|          | 0/79 [00:00<?, ?it/s]

{'eval_loss': 0.6041309833526611, 'eval_runtime': 193.5476, 'eval_samples_per_second': 25.833, 'eval_steps_per_second': 0.408, 'epoch': 9.74}
{'loss': 0.6026, 'learning_rate': 1.3307984790874525e-06, 'epoch': 9.78}
{'loss': 0.5979, 'learning_rate': 1.140684410646388e-06, 'epoch': 9.81}
{'loss': 0.6043, 'learning_rate': 9.505703422053232e-07, 'epoch': 9.84}
{'loss': 0.6009, 'learning_rate': 7.604562737642586e-07, 'epoch': 9.87}
{'loss': 0.5985, 'learning_rate': 5.70342205323194e-07, 'epoch': 9.9}


  0%|          | 0/79 [00:00<?, ?it/s]

{'eval_loss': 0.6041001081466675, 'eval_runtime': 191.7254, 'eval_samples_per_second': 26.079, 'eval_steps_per_second': 0.412, 'epoch': 9.9}
{'loss': 0.6022, 'learning_rate': 3.802281368821293e-07, 'epoch': 9.94}
{'loss': 0.6121, 'learning_rate': 1.9011406844106465e-07, 'epoch': 9.97}
{'loss': 0.6032, 'learning_rate': 0.0, 'epoch': 10.0}
{'train_runtime': 16153.6749, 'train_samples_per_second': 3.095, 'train_steps_per_second': 0.194, 'train_loss': 0.6225034703081027, 'epoch': 10.0}


TrainOutput(global_step=3130, training_loss=0.6225034703081027, metrics={'train_runtime': 16153.6749, 'train_samples_per_second': 3.095, 'train_steps_per_second': 0.194, 'train_loss': 0.6225034703081027, 'epoch': 10.0})

In [306]:
# predict after training

model.to(device)
model.eval()
predictions = []

with torch.no_grad():
    for batch in validation_dataset:
        batch_input_ids = batch['input_ids'].unsqueeze(0).to(device)
        batch_attention_mask = batch['attention_mask'].unsqueeze(0).to(device)
        outputs = model(input_ids=batch_input_ids, attention_mask=batch_attention_mask)
        logits = outputs.logits.cpu()

        predicted_classes = torch.argmax(logits, dim=1)
        predictions.extend(predicted_classes)

predictions = torch.stack(predictions)


print(calc_f1_score(predictions, val_labels))

(0.7491518376579546, 0.7461999999999999, 0.7432060968848617)
