In [1]:
import pandas as pd

finetuned_dirname = "40-epoch-roberta-finetuned-phemernr2-tf"
transformer_name = "roberta-base"

data = pd.read_csv("../../data/processed/phemernr2-tf_dataset.csv", sep=",")
data = data[['tweet_text', 'tvt2', 'label']]
data['tweet_text'] = data['tweet_text'].str.lower()
print(data.shape)
data.head()

(1705, 3)


Unnamed: 0,tweet_text,tvt2,label
0,breaking - a germanwings airbus a320 plane rep...,training,True
1,reports that two of the dead in the #charliehe...,training,True
2,'no survivors' in #germanwings crash says fren...,training,False
3,tragedy mounts as soldier shot this am dies of...,training,True
4,watch the moment gunfire and explosions were h...,training,True


In [2]:
combined_data = data

In [3]:
import torch

class CustomTextDataset(torch.utils.data.dataset.Dataset):

    def __init__(self, texts, labels):
        self.labels = labels
        self.texts = texts
        self.attention_mask = None
        self.input_ids = None
        self.token_type_ids = None

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        sample = {
            "text": self.texts[idx],
            "label": self.labels[idx],
            "attention_mask": self.attention_mask[idx] if self.attention_mask else None,
            "input_ids": self.input_ids[idx] if self.input_ids else None,
#             "token_type_ids": self.token_type_ids[idx] if self.token_type_ids else None
        }
        return sample
    
    def tokenize(self, tokenizer):
        self.attention_mask = []
        self.input_ids = []
        self.token_type_ids = []

        for text in self.texts:
            token = tokenizer(text, padding="max_length", truncation=True)
            
            self.attention_mask.append(token['attention_mask'])
            self.input_ids.append(token['input_ids'])
#             self.token_type_ids.append(token['token_type_ids'])

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
labels = []

labels_str = combined_data['label'].unique().tolist()
for i, d in combined_data.iterrows():
    lab = labels_str.index(d['label'])
    labels.append(lab)
    
print(len(labels))
labels[:10]

1705


[0, 0, 1, 0, 0, 0, 0, 1, 0, 0]

In [5]:
train_dataset = CustomTextDataset(
    [d['tweet_text'] for i, d in combined_data.iterrows() if d['tvt2'] == 'training'],
    [labels[i] for i, d in combined_data.iterrows() if d['tvt2'] == 'training'])
test_dataset = CustomTextDataset(
    [d['tweet_text'] for i, d in combined_data.iterrows() if d['tvt2'] == 'validation'],
    [labels[i] for i, d in combined_data.iterrows() if d['tvt2'] == 'validation'])
train_dataset[0]

{'text': 'breaking - a germanwings airbus a320 plane reportedly crashed in the region of digne (french alps) #flightradar24 - french tv #itele',
 'label': 0,
 'attention_mask': None,
 'input_ids': None}

In [6]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(transformer_name)

In [7]:
# inputs = tokenizer(["you're stuck in a timewrap from 2004 though", "summa lumma dumma lumma"], padding="max_length", truncation=True)
# for k,v in inputs.items():
#     print(k)

In [8]:
def tokenize_function(examples):
    return tokenizer(examples["text"], padding="max_length", truncation=True)

train_dataset.tokenize(tokenizer)
test_dataset.tokenize(tokenizer)

In [9]:
print(len(train_dataset))
print(len(test_dataset))

1176
371


### Fine Tuning

In [10]:
from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained(transformer_name,
                                                           output_hidden_states=False,
                                                           num_labels=2)

Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-base and are newly initialized: ['classifier.dense.bias', 'classifier.out_proj.bias', 'classifier.dense.weight', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [11]:
from transformers import TrainingArguments

epochs = 40
batch_size = 8
save_steps = (round((len(train_dataset)/batch_size) + 0.49)) * epochs
# save_steps = 1_000_000

training_args = TrainingArguments(
    output_dir=f"../../data/models/{finetuned_dirname}",
    num_train_epochs=epochs,
    save_steps=save_steps,
    logging_steps=300,
    learning_rate=1e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    evaluation_strategy="epoch",
    logging_strategy="epoch"
)

print(f"Save Steps : {save_steps}")

Save Steps : 5880


In [12]:
import numpy as np
from datasets import load_metric

metric = load_metric("accuracy")

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predictions, references=labels)

  metric = load_metric("accuracy")


In [13]:
from transformers import Trainer

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    compute_metrics=compute_metrics,
)

In [14]:
import time

start = time.time()

trainer.train()

print(f"Execution Time : {round(time.time() - start)} seconds")

  2%|▎         | 147/5880 [00:46<30:23,  3.14it/s]

{'loss': 0.5302, 'learning_rate': 9.75e-06, 'epoch': 1.0}


                                                  
  2%|▎         | 147/5880 [00:51<30:23,  3.14it/s]

{'eval_loss': 0.4067111313343048, 'eval_accuracy': 0.8355795148247979, 'eval_runtime': 4.5236, 'eval_samples_per_second': 82.013, 'eval_steps_per_second': 10.39, 'epoch': 1.0}


  5%|▌         | 294/5880 [01:38<29:54,  3.11it/s]  

{'loss': 0.3187, 'learning_rate': 9.5e-06, 'epoch': 2.0}


                                                  
  5%|▌         | 294/5880 [01:43<29:54,  3.11it/s]

{'eval_loss': 0.4847545325756073, 'eval_accuracy': 0.8194070080862533, 'eval_runtime': 4.5395, 'eval_samples_per_second': 81.727, 'eval_steps_per_second': 10.354, 'epoch': 2.0}


  8%|▊         | 441/5880 [02:30<29:21,  3.09it/s]  

{'loss': 0.2562, 'learning_rate': 9.250000000000001e-06, 'epoch': 3.0}


                                                  
  8%|▊         | 441/5880 [02:35<29:21,  3.09it/s]

{'eval_loss': 0.6185199022293091, 'eval_accuracy': 0.8301886792452831, 'eval_runtime': 4.5509, 'eval_samples_per_second': 81.522, 'eval_steps_per_second': 10.328, 'epoch': 3.0}


 10%|█         | 588/5880 [03:22<28:24,  3.10it/s]  

{'loss': 0.1996, 'learning_rate': 9e-06, 'epoch': 4.0}


                                                  
 10%|█         | 588/5880 [03:27<28:24,  3.10it/s]

{'eval_loss': 0.6875048279762268, 'eval_accuracy': 0.8733153638814016, 'eval_runtime': 4.5409, 'eval_samples_per_second': 81.702, 'eval_steps_per_second': 10.35, 'epoch': 4.0}


 12%|█▎        | 735/5880 [04:14<27:39,  3.10it/s]  

{'loss': 0.124, 'learning_rate': 8.750000000000001e-06, 'epoch': 5.0}


                                                  
 12%|█▎        | 735/5880 [04:19<27:39,  3.10it/s]

{'eval_loss': 0.7360397577285767, 'eval_accuracy': 0.8598382749326146, 'eval_runtime': 4.5433, 'eval_samples_per_second': 81.659, 'eval_steps_per_second': 10.345, 'epoch': 5.0}


 15%|█▌        | 882/5880 [05:06<26:39,  3.12it/s]  

{'loss': 0.074, 'learning_rate': 8.5e-06, 'epoch': 6.0}


                                                  
 15%|█▌        | 882/5880 [05:10<26:39,  3.12it/s]

{'eval_loss': 0.7825492024421692, 'eval_accuracy': 0.8787061994609164, 'eval_runtime': 4.5107, 'eval_samples_per_second': 82.249, 'eval_steps_per_second': 10.42, 'epoch': 6.0}


 18%|█▊        | 1029/5880 [05:57<25:50,  3.13it/s] 

{'loss': 0.0447, 'learning_rate': 8.25e-06, 'epoch': 7.0}


                                                   
 18%|█▊        | 1029/5880 [06:02<25:50,  3.13it/s]

{'eval_loss': 0.7288329005241394, 'eval_accuracy': 0.8840970350404312, 'eval_runtime': 4.5071, 'eval_samples_per_second': 82.315, 'eval_steps_per_second': 10.428, 'epoch': 7.0}


 20%|██        | 1176/5880 [06:49<24:59,  3.14it/s]  

{'loss': 0.0287, 'learning_rate': 8.000000000000001e-06, 'epoch': 8.0}


                                                   
 20%|██        | 1176/5880 [06:53<24:59,  3.14it/s]

{'eval_loss': 1.1793564558029175, 'eval_accuracy': 0.8598382749326146, 'eval_runtime': 4.5107, 'eval_samples_per_second': 82.25, 'eval_steps_per_second': 10.42, 'epoch': 8.0}


 22%|██▎       | 1323/5880 [07:40<24:32,  3.09it/s]  

{'loss': 0.0485, 'learning_rate': 7.75e-06, 'epoch': 9.0}


                                                   
 22%|██▎       | 1323/5880 [07:45<24:32,  3.09it/s]

{'eval_loss': 0.9237436652183533, 'eval_accuracy': 0.876010781671159, 'eval_runtime': 4.5172, 'eval_samples_per_second': 82.131, 'eval_steps_per_second': 10.405, 'epoch': 9.0}


 25%|██▌       | 1470/5880 [08:32<23:26,  3.14it/s]  

{'loss': 0.0257, 'learning_rate': 7.500000000000001e-06, 'epoch': 10.0}


                                                   
 25%|██▌       | 1470/5880 [08:36<23:26,  3.14it/s]

{'eval_loss': 0.7304961681365967, 'eval_accuracy': 0.894878706199461, 'eval_runtime': 4.5117, 'eval_samples_per_second': 82.23, 'eval_steps_per_second': 10.417, 'epoch': 10.0}


 28%|██▊       | 1617/5880 [09:23<22:46,  3.12it/s]  

{'loss': 0.0343, 'learning_rate': 7.25e-06, 'epoch': 11.0}


                                                   
 28%|██▊       | 1617/5880 [09:28<22:46,  3.12it/s]

{'eval_loss': 0.8657605051994324, 'eval_accuracy': 0.8867924528301887, 'eval_runtime': 4.5069, 'eval_samples_per_second': 82.319, 'eval_steps_per_second': 10.429, 'epoch': 11.0}


 30%|███       | 1764/5880 [10:15<22:03,  3.11it/s]  

{'loss': 0.0252, 'learning_rate': 7e-06, 'epoch': 12.0}


                                                   
 30%|███       | 1764/5880 [10:19<22:03,  3.11it/s]

{'eval_loss': 0.9976335167884827, 'eval_accuracy': 0.8787061994609164, 'eval_runtime': 4.5012, 'eval_samples_per_second': 82.422, 'eval_steps_per_second': 10.442, 'epoch': 12.0}


 32%|███▎      | 1911/5880 [11:06<20:52,  3.17it/s]  

{'loss': 0.0313, 'learning_rate': 6.750000000000001e-06, 'epoch': 13.0}


                                                   
 32%|███▎      | 1911/5880 [11:11<20:52,  3.17it/s]

{'eval_loss': 0.9641602635383606, 'eval_accuracy': 0.8840970350404312, 'eval_runtime': 4.5006, 'eval_samples_per_second': 82.434, 'eval_steps_per_second': 10.443, 'epoch': 13.0}


 35%|███▌      | 2058/5880 [11:58<20:16,  3.14it/s]  

{'loss': 0.0255, 'learning_rate': 6.5000000000000004e-06, 'epoch': 14.0}


                                                   
 35%|███▌      | 2058/5880 [12:02<20:16,  3.14it/s]

{'eval_loss': 0.971036434173584, 'eval_accuracy': 0.8787061994609164, 'eval_runtime': 4.502, 'eval_samples_per_second': 82.408, 'eval_steps_per_second': 10.44, 'epoch': 14.0}


 38%|███▊      | 2205/5880 [12:49<19:33,  3.13it/s]  

{'loss': 0.0451, 'learning_rate': 6.25e-06, 'epoch': 15.0}


                                                   
 38%|███▊      | 2205/5880 [12:54<19:33,  3.13it/s]

{'eval_loss': 0.8173243999481201, 'eval_accuracy': 0.9056603773584906, 'eval_runtime': 4.5, 'eval_samples_per_second': 82.445, 'eval_steps_per_second': 10.444, 'epoch': 15.0}


 40%|████      | 2352/5880 [13:41<18:56,  3.10it/s]  

{'loss': 0.0324, 'learning_rate': 6e-06, 'epoch': 16.0}


                                                   
 40%|████      | 2352/5880 [13:45<18:56,  3.10it/s]

{'eval_loss': 1.0049408674240112, 'eval_accuracy': 0.8840970350404312, 'eval_runtime': 4.5015, 'eval_samples_per_second': 82.417, 'eval_steps_per_second': 10.441, 'epoch': 16.0}


 42%|████▎     | 2499/5880 [14:32<17:49,  3.16it/s]  

{'loss': 0.0074, 'learning_rate': 5.75e-06, 'epoch': 17.0}


                                                   
 42%|████▎     | 2499/5880 [14:37<17:49,  3.16it/s]

{'eval_loss': 0.8809204697608948, 'eval_accuracy': 0.894878706199461, 'eval_runtime': 4.4991, 'eval_samples_per_second': 82.46, 'eval_steps_per_second': 10.446, 'epoch': 17.0}


 45%|████▌     | 2646/5880 [15:23<17:06,  3.15it/s]  

{'loss': 0.0076, 'learning_rate': 5.500000000000001e-06, 'epoch': 18.0}


                                                   
 45%|████▌     | 2646/5880 [15:28<17:06,  3.15it/s]

{'eval_loss': 1.0433967113494873, 'eval_accuracy': 0.8733153638814016, 'eval_runtime': 4.4992, 'eval_samples_per_second': 82.459, 'eval_steps_per_second': 10.446, 'epoch': 18.0}


 48%|████▊     | 2793/5880 [16:15<16:19,  3.15it/s]  

{'loss': 0.0098, 'learning_rate': 5.2500000000000006e-06, 'epoch': 19.0}


                                                   
 48%|████▊     | 2793/5880 [16:19<16:19,  3.15it/s]

{'eval_loss': 1.0361790657043457, 'eval_accuracy': 0.8814016172506739, 'eval_runtime': 4.4947, 'eval_samples_per_second': 82.541, 'eval_steps_per_second': 10.457, 'epoch': 19.0}


 50%|█████     | 2940/5880 [17:06<15:33,  3.15it/s]  

{'loss': 0.0082, 'learning_rate': 5e-06, 'epoch': 20.0}


                                                   
 50%|█████     | 2940/5880 [17:11<15:33,  3.15it/s]

{'eval_loss': 0.9530299305915833, 'eval_accuracy': 0.8921832884097035, 'eval_runtime': 4.5021, 'eval_samples_per_second': 82.405, 'eval_steps_per_second': 10.439, 'epoch': 20.0}


 52%|█████▎    | 3087/5880 [17:58<14:47,  3.15it/s]  

{'loss': 0.0002, 'learning_rate': 4.75e-06, 'epoch': 21.0}


                                                   
 52%|█████▎    | 3087/5880 [18:02<14:47,  3.15it/s]

{'eval_loss': 0.9312748908996582, 'eval_accuracy': 0.8921832884097035, 'eval_runtime': 4.5009, 'eval_samples_per_second': 82.429, 'eval_steps_per_second': 10.442, 'epoch': 21.0}


 55%|█████▌    | 3234/5880 [18:49<14:01,  3.14it/s]  

{'loss': 0.017, 'learning_rate': 4.5e-06, 'epoch': 22.0}


                                                   
 55%|█████▌    | 3234/5880 [18:53<14:01,  3.14it/s]

{'eval_loss': 1.0192753076553345, 'eval_accuracy': 0.8840970350404312, 'eval_runtime': 4.5111, 'eval_samples_per_second': 82.242, 'eval_steps_per_second': 10.419, 'epoch': 22.0}


 57%|█████▊    | 3381/5880 [19:40<13:20,  3.12it/s]  

{'loss': 0.0102, 'learning_rate': 4.25e-06, 'epoch': 23.0}


                                                   
 57%|█████▊    | 3381/5880 [19:45<13:20,  3.12it/s]

{'eval_loss': 0.9908373951911926, 'eval_accuracy': 0.889487870619946, 'eval_runtime': 4.5012, 'eval_samples_per_second': 82.422, 'eval_steps_per_second': 10.442, 'epoch': 23.0}


 60%|██████    | 3528/5880 [20:31<12:35,  3.11it/s]  

{'loss': 0.0073, 'learning_rate': 4.000000000000001e-06, 'epoch': 24.0}


                                                   
 60%|██████    | 3528/5880 [20:36<12:35,  3.11it/s]

{'eval_loss': 1.1276861429214478, 'eval_accuracy': 0.8840970350404312, 'eval_runtime': 4.4998, 'eval_samples_per_second': 82.448, 'eval_steps_per_second': 10.445, 'epoch': 24.0}


 62%|██████▎   | 3675/5880 [21:23<11:44,  3.13it/s]  

{'loss': 0.0173, 'learning_rate': 3.7500000000000005e-06, 'epoch': 25.0}


                                                   
 62%|██████▎   | 3675/5880 [21:27<11:44,  3.13it/s]

{'eval_loss': 1.1439719200134277, 'eval_accuracy': 0.8679245283018868, 'eval_runtime': 4.5039, 'eval_samples_per_second': 82.374, 'eval_steps_per_second': 10.435, 'epoch': 25.0}


 65%|██████▌   | 3822/5880 [22:14<10:55,  3.14it/s]  

{'loss': 0.0069, 'learning_rate': 3.5e-06, 'epoch': 26.0}


                                                   
 65%|██████▌   | 3822/5880 [22:19<10:55,  3.14it/s]

{'eval_loss': 1.0115715265274048, 'eval_accuracy': 0.889487870619946, 'eval_runtime': 4.5096, 'eval_samples_per_second': 82.269, 'eval_steps_per_second': 10.422, 'epoch': 26.0}


 68%|██████▊   | 3969/5880 [23:05<10:04,  3.16it/s]

{'loss': 0.0002, 'learning_rate': 3.2500000000000002e-06, 'epoch': 27.0}


                                                   
 68%|██████▊   | 3969/5880 [23:10<10:04,  3.16it/s]

{'eval_loss': 0.9743837714195251, 'eval_accuracy': 0.9002695417789758, 'eval_runtime': 4.5006, 'eval_samples_per_second': 82.433, 'eval_steps_per_second': 10.443, 'epoch': 27.0}


 70%|███████   | 4116/5880 [23:57<09:20,  3.14it/s]

{'loss': 0.001, 'learning_rate': 3e-06, 'epoch': 28.0}


                                                   
 70%|███████   | 4116/5880 [24:01<09:20,  3.14it/s]

{'eval_loss': 1.0454785823822021, 'eval_accuracy': 0.8921832884097035, 'eval_runtime': 4.4971, 'eval_samples_per_second': 82.498, 'eval_steps_per_second': 10.451, 'epoch': 28.0}


 72%|███████▎  | 4263/5880 [24:48<08:37,  3.13it/s]

{'loss': 0.0, 'learning_rate': 2.7500000000000004e-06, 'epoch': 29.0}


                                                   
 72%|███████▎  | 4263/5880 [24:53<08:37,  3.13it/s]

{'eval_loss': 0.9830540418624878, 'eval_accuracy': 0.9002695417789758, 'eval_runtime': 4.4987, 'eval_samples_per_second': 82.469, 'eval_steps_per_second': 10.447, 'epoch': 29.0}


 75%|███████▌  | 4410/5880 [25:40<07:47,  3.15it/s]

{'loss': 0.0056, 'learning_rate': 2.5e-06, 'epoch': 30.0}


                                                   
 75%|███████▌  | 4410/5880 [25:44<07:47,  3.15it/s]

{'eval_loss': 1.0039817094802856, 'eval_accuracy': 0.8975741239892183, 'eval_runtime': 4.507, 'eval_samples_per_second': 82.317, 'eval_steps_per_second': 10.428, 'epoch': 30.0}


 78%|███████▊  | 4557/5880 [26:31<07:01,  3.14it/s]

{'loss': 0.0002, 'learning_rate': 2.25e-06, 'epoch': 31.0}


                                                   
 78%|███████▊  | 4557/5880 [26:35<07:01,  3.14it/s]

{'eval_loss': 1.1085126399993896, 'eval_accuracy': 0.8840970350404312, 'eval_runtime': 4.5049, 'eval_samples_per_second': 82.355, 'eval_steps_per_second': 10.433, 'epoch': 31.0}


 80%|████████  | 4704/5880 [27:22<06:17,  3.12it/s]

{'loss': 0.0098, 'learning_rate': 2.0000000000000003e-06, 'epoch': 32.0}


                                                   
 80%|████████  | 4704/5880 [27:27<06:17,  3.12it/s]

{'eval_loss': 1.192080020904541, 'eval_accuracy': 0.8706199460916442, 'eval_runtime': 4.5011, 'eval_samples_per_second': 82.424, 'eval_steps_per_second': 10.442, 'epoch': 32.0}


 82%|████████▎ | 4851/5880 [28:14<05:31,  3.10it/s]

{'loss': 0.0, 'learning_rate': 1.75e-06, 'epoch': 33.0}


                                                   
 82%|████████▎ | 4851/5880 [28:18<05:31,  3.10it/s]

{'eval_loss': 1.1571475267410278, 'eval_accuracy': 0.8814016172506739, 'eval_runtime': 4.4913, 'eval_samples_per_second': 82.604, 'eval_steps_per_second': 10.465, 'epoch': 33.0}


 85%|████████▌ | 4998/5880 [29:05<04:38,  3.16it/s]

{'loss': 0.0, 'learning_rate': 1.5e-06, 'epoch': 34.0}


                                                   
 85%|████████▌ | 4998/5880 [29:10<04:38,  3.16it/s]

{'eval_loss': 1.1485248804092407, 'eval_accuracy': 0.8840970350404312, 'eval_runtime': 4.503, 'eval_samples_per_second': 82.389, 'eval_steps_per_second': 10.437, 'epoch': 34.0}


 88%|████████▊ | 5145/5880 [29:57<03:54,  3.14it/s]

{'loss': 0.0, 'learning_rate': 1.25e-06, 'epoch': 35.0}


                                                   
 88%|████████▊ | 5145/5880 [30:01<03:54,  3.14it/s]

{'eval_loss': 1.0984734296798706, 'eval_accuracy': 0.8921832884097035, 'eval_runtime': 4.5052, 'eval_samples_per_second': 82.349, 'eval_steps_per_second': 10.432, 'epoch': 35.0}


 90%|█████████ | 5292/5880 [30:48<03:06,  3.16it/s]

{'loss': 0.0, 'learning_rate': 1.0000000000000002e-06, 'epoch': 36.0}


                                                   
 90%|█████████ | 5292/5880 [30:52<03:06,  3.16it/s]

{'eval_loss': 1.0531655550003052, 'eval_accuracy': 0.889487870619946, 'eval_runtime': 4.4998, 'eval_samples_per_second': 82.448, 'eval_steps_per_second': 10.445, 'epoch': 36.0}


 92%|█████████▎| 5439/5880 [31:39<02:21,  3.11it/s]

{'loss': 0.0096, 'learning_rate': 7.5e-07, 'epoch': 37.0}


                                                   
 92%|█████████▎| 5439/5880 [31:44<02:21,  3.11it/s]

{'eval_loss': 1.042358160018921, 'eval_accuracy': 0.894878706199461, 'eval_runtime': 4.5079, 'eval_samples_per_second': 82.299, 'eval_steps_per_second': 10.426, 'epoch': 37.0}


 95%|█████████▌| 5586/5880 [32:30<01:32,  3.16it/s]

{'loss': 0.0, 'learning_rate': 5.000000000000001e-07, 'epoch': 38.0}


                                                   
 95%|█████████▌| 5586/5880 [32:35<01:32,  3.16it/s]

{'eval_loss': 1.0429675579071045, 'eval_accuracy': 0.894878706199461, 'eval_runtime': 4.5041, 'eval_samples_per_second': 82.37, 'eval_steps_per_second': 10.435, 'epoch': 38.0}


 98%|█████████▊| 5733/5880 [33:22<00:47,  3.12it/s]

{'loss': 0.0, 'learning_rate': 2.5000000000000004e-07, 'epoch': 39.0}


                                                   
 98%|█████████▊| 5733/5880 [33:26<00:47,  3.12it/s]

{'eval_loss': 1.0437697172164917, 'eval_accuracy': 0.894878706199461, 'eval_runtime': 4.496, 'eval_samples_per_second': 82.518, 'eval_steps_per_second': 10.454, 'epoch': 39.0}


100%|██████████| 5880/5880 [34:15<00:00,  3.16it/s]

{'loss': 0.0, 'learning_rate': 0.0, 'epoch': 40.0}


                                                   
100%|██████████| 5880/5880 [34:19<00:00,  2.85it/s]

{'eval_loss': 1.0364022254943848, 'eval_accuracy': 0.894878706199461, 'eval_runtime': 4.5053, 'eval_samples_per_second': 82.347, 'eval_steps_per_second': 10.432, 'epoch': 40.0}
{'train_runtime': 2059.758, 'train_samples_per_second': 22.838, 'train_steps_per_second': 2.855, 'train_loss': 0.04906069917396839, 'epoch': 40.0}
Execution Time : 2060 seconds





In [15]:
trainer.evaluate()

100%|██████████| 47/47 [00:04<00:00, 10.79it/s]


{'eval_loss': 1.0364022254943848,
 'eval_accuracy': 0.894878706199461,
 'eval_runtime': 4.4547,
 'eval_samples_per_second': 83.282,
 'eval_steps_per_second': 10.551,
 'epoch': 40.0}

## 