<a href="https://colab.research.google.com/github/TurkuNLP/intro-to-nlp/blob/master/course_project_2023_template.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Introduction to HLT Project (Template)

- Student(s) Name(s): Omar Nasri, Konsta Nyman  
- Date: 22.04.2023
- Chosen Corpus: IMDB
- Contributions (if group project): 

Omar: Pre-prosessing, modeling, evaluation

Konsta: Hyperparameter tuning, results and summary

### Corpus information

- Description of the chosen corpus: Large Movie Review Dataset. This is a dataset for binary sentiment classification containing substantially more data than previous benchmark datasets. We provide a set of 25,000 highly polar movie reviews for training, and 25,000 for testing. There is additional unlabeled data for use as well.
- Paper(s) and other published materials related to the corpus: 

State-of-the-art leaderboard: https://paperswithcode.com/sota/sentiment-analysis-on-imdb

Related sentiment analysis paper: Sentiment Analysis for Movies Reviews Dataset Using Deep Learning Models Nehal Mohamed Ali, Marwa Mostafa Abd El Hamid and Aliaa Youssif

- State-of-the-art performance (best published results) on this corpus:

1
RoBERTa-large with LlamBERT
96.68
LlamBERT: Large-scale low-cost data annotation in NLP
2024

2
RoBERTa-large
96.54
LlamBERT: Large-scale low-cost data annotation in NLP
2024

3
XLNet
96.21
XLNet: Generalized Autoregressive Pretraining for Language Understanding
2019
Transformer

---

## 1. Setup

In [81]:
# Your code to install and import libraries etc. here
import datasets 
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.feature_extraction.text import CountVectorizer
import torch 
import transformers
import evaluate

---

## 2. Data download and preprocessing

### 2.1. Download the corpus

In [82]:
# Your code to download the corpus here
dset = datasets.load_dataset('imdb')
#load the dataset from cache
display(dset)

DatasetDict({
    train: Dataset({
        features: ['text', 'label'],
        num_rows: 25000
    })
    test: Dataset({
        features: ['text', 'label'],
        num_rows: 25000
    })
    unsupervised: Dataset({
        features: ['text', 'label'],
        num_rows: 50000
    })
})

### 2.2. Preprocessing

In [83]:
# Your code for any necessary preprocessing here
#Shuffle the dataset
dset = dset.shuffle(seed=42)
#Remove the unsupervised data part of the dataset as we dont need it for this task
del dset['unsupervised']

In [84]:
vectorizer = CountVectorizer(binary=True, max_features=25000)
text_list = [i['text'] for i in dset['train']]
vectorizer.fit(text_list)

def vectorize_example(examples, vectorizer): 
    vectorized = vectorizer.transform([examples["text"]])
    non_zero = vectorized.nonzero()[1]
    non_zero += 1
    return {'input_ids': non_zero}

# Conversion vocabulary 
idx2word = {v: k for k, v in vectorizer.vocabulary_.items()}

tokenized_data = dset.map(vectorize_example, num_proc=4, fn_kwargs={'vectorizer': vectorizer})

In [85]:
test_row = tokenized_data['train'][0]['input_ids']
convered_text = [idx2word[i] for i in test_row]
print(convered_text)

['above', 'action', 'actress', 'allah', 'americana', 'anders', 'area', 'argument', 'atari', 'beulah', 'bother', 'butch', 'bye', 'characterisation', 'clairvoyant', 'classical', 'compared', 'complicating', 'criminal', 'englishman', 'enjoyment', 'evaluated', 'factions', 'faraway', 'fur', 'goodbye', 'handbook', 'haven', 'howard', 'ifc', 'isaac', 'italian', 'judged', 'justice', 'languages', 'likeable', 'looming', 'maine', 'mayberry', 'moreau', 'noah', 'notable', 'onassis', 'oral', 'others', 'peoples', 'plotline', 'plotted', 'policeman', 'preferable', 'primed', 'quits', 'realm', 'relations', 'serio', 'similarity', 'simpler', 'spirited', 'spotlight', 'superficiality', 'suspected', 'thank', 'thatch', 'theater', 'thereafter', 'thick', 'things', 'thinker', 'tho', 'toad', 'took', 'violently', 'wayans', 'weak', 'weaken', 'weirdos', 'writings']


In [86]:
def collator(examples):
    batch = {"labels":torch.tensor(list(example["label"] for example in examples))}
    tensors = []
    max_len = max(len(example["input_ids"]) for example in examples)
    for example in examples:
        ids = torch.tensor(example["input_ids"])
        padded = torch.nn.functional.pad(ids, (0, max_len - ids.shape[0]))
        tensors.append(padded)
    batch["input_ids"] = torch.vstack(tensors)
    return batch

---

## 3. Machine learning model

### 3.1. Model training

In [87]:
# Your code to train the machine learning model on the training set and evaluate the performance on the validation set here

class MLPConfig(transformers.PretrainedConfig):
    pass
class MLP(transformers.PreTrainedModel):
    config_class=MLPConfig
    def __init__(self,config):
        super().__init__(config)
        self.vocab_size=config.vocab_size #embedding matrix row count
        self.embedding=torch.nn.Embedding(num_embeddings=self.vocab_size+1,embedding_dim=config.hidden_size,padding_idx=0)
        torch.nn.init.uniform_(self.embedding.weight.data,-0.001,0.001) 
        self.output=torch.nn.Linear(in_features=config.hidden_size,out_features=config.nlabels)

    def forward(self,input_ids,labels=None):
        embedded=self.embedding(input_ids)
        embedded_summed=torch.sum(embedded,dim=1)
        projected=torch.tanh(embedded_summed) 
        logits=self.output(projected)
        if labels is not None:
            loss=torch.nn.CrossEntropyLoss()
            return (loss(logits,labels),logits)
        else:
            return (logits,)

### 3.2 Hyperparameter optimization

In [88]:
# Your code for hyperparameter optimization here
learning_rates = [1e-5, 1e-4, 1e-3]
batch_sizes = [32, 64, 128]

def compute_accuracy(outputs_and_labels):
    outputs, labels = outputs_and_labels
    predictions = np.argmax(outputs, axis=-1) #pick the index of the "winning" label
    return accuracy.compute(predictions=predictions, references=labels)

accuracy = evaluate.load("accuracy")

mlp_config=MLPConfig(vocab_size=len(vectorizer.vocabulary_),hidden_size=20,nlabels=2)

best_learning_rate = None
best_batch_size = None
best_accuracy = 0
for lr in learning_rates:
    for batch_size in batch_sizes:
        print(f"Training with lr={lr} and batch_size={batch_size}")
        mlp=MLP(mlp_config)
        trainer_args = transformers.TrainingArguments(
            "mlp_checkpoints", #save checkpoints here
            evaluation_strategy="steps",
            logging_strategy="steps",
            eval_steps=500,
            logging_steps=500,
            learning_rate=lr, #learning rate of the gradient descent
            max_steps=20000,
            load_best_model_at_end=True,
            per_device_train_batch_size=batch_size,
        )
        early_stopping = transformers.EarlyStoppingCallback(5)
        trainer = transformers.Trainer(
            model=mlp,
            args=trainer_args,
            train_dataset=tokenized_data["train"].select(range(10000)),
            eval_dataset=tokenized_data["test"].select(range(1000)), #make a smaller subset to evaluate on
            compute_metrics=compute_accuracy,
            data_collator=collator,
            callbacks=[early_stopping]
        )
        trainer.train()
        eval_result = trainer.evaluate(tokenized_data["test"])
        if eval_result["eval_accuracy"] > best_accuracy:
            best_accuracy = eval_result["eval_accuracy"]
            best_learning_rate = lr
            best_batch_size = batch_size
            

Training with lr=1e-05 and batch_size=32


dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)
 51%|█████     | 10177/20000 [02:49<02:43, 59.90it/s] 
  2%|▎         | 500/20000 [00:03<01:49, 177.75it/s]

{'loss': 0.6772, 'grad_norm': 0.8402527570724487, 'learning_rate': 9.75e-06, 'epoch': 1.6}


                                                    
  2%|▎         | 500/20000 [00:03<01:49, 177.75it/s]

{'eval_loss': 0.6655793190002441, 'eval_accuracy': 0.792, 'eval_runtime': 0.1628, 'eval_samples_per_second': 6142.872, 'eval_steps_per_second': 767.859, 'epoch': 1.6}


  5%|▌         | 1000/20000 [00:06<01:57, 161.86it/s]

{'loss': 0.6441, 'grad_norm': 0.8588707447052002, 'learning_rate': 9.5e-06, 'epoch': 3.19}


                                                     
  5%|▌         | 1023/20000 [00:06<02:55, 107.88it/s]

{'eval_loss': 0.6400033831596375, 'eval_accuracy': 0.81, 'eval_runtime': 0.1879, 'eval_samples_per_second': 5321.093, 'eval_steps_per_second': 665.137, 'epoch': 3.19}


  8%|▊         | 1500/20000 [00:09<01:49, 169.43it/s]

{'loss': 0.6108, 'grad_norm': 1.1797585487365723, 'learning_rate': 9.250000000000001e-06, 'epoch': 4.79}


                                                     
  8%|▊         | 1500/20000 [00:09<01:49, 169.43it/s]

{'eval_loss': 0.614205539226532, 'eval_accuracy': 0.82, 'eval_runtime': 0.174, 'eval_samples_per_second': 5748.063, 'eval_steps_per_second': 718.508, 'epoch': 4.79}


 10%|█         | 2000/20000 [00:12<01:43, 174.43it/s]

{'loss': 0.5776, 'grad_norm': 0.8083836436271667, 'learning_rate': 9e-06, 'epoch': 6.39}


                                                     
 10%|█         | 2027/20000 [00:12<02:27, 121.53it/s]

{'eval_loss': 0.5903905034065247, 'eval_accuracy': 0.811, 'eval_runtime': 0.1823, 'eval_samples_per_second': 5484.141, 'eval_steps_per_second': 685.518, 'epoch': 6.39}


 12%|█▎        | 2500/20000 [00:15<01:41, 171.87it/s]

{'loss': 0.5482, 'grad_norm': 0.8313344717025757, 'learning_rate': 8.750000000000001e-06, 'epoch': 7.99}


                                                     
 12%|█▎        | 2500/20000 [00:15<01:41, 171.87it/s]


{'eval_loss': 0.5674571394920349, 'eval_accuracy': 0.823, 'eval_runtime': 0.1699, 'eval_samples_per_second': 5886.982, 'eval_steps_per_second': 735.873, 'epoch': 7.99}


 15%|█▌        | 3000/20000 [00:18<01:46, 159.62it/s]

{'loss': 0.5184, 'grad_norm': 1.0029003620147705, 'learning_rate': 8.5e-06, 'epoch': 9.58}


                                                     
 15%|█▌        | 3023/20000 [00:19<02:37, 108.07it/s]

{'eval_loss': 0.5469377636909485, 'eval_accuracy': 0.826, 'eval_runtime': 0.1934, 'eval_samples_per_second': 5170.659, 'eval_steps_per_second': 646.332, 'epoch': 9.58}


 18%|█▊        | 3500/20000 [00:22<01:32, 177.45it/s]

{'loss': 0.4946, 'grad_norm': 1.018371820449829, 'learning_rate': 8.25e-06, 'epoch': 11.18}


                                                     
 18%|█▊        | 3500/20000 [00:22<01:32, 177.45it/s]

{'eval_loss': 0.52842116355896, 'eval_accuracy': 0.829, 'eval_runtime': 0.1692, 'eval_samples_per_second': 5909.693, 'eval_steps_per_second': 738.712, 'epoch': 11.18}


 20%|██        | 4000/20000 [00:25<01:34, 169.20it/s]

{'loss': 0.4705, 'grad_norm': 0.8455029726028442, 'learning_rate': 8.000000000000001e-06, 'epoch': 12.78}


                                                     
 20%|██        | 4000/20000 [00:25<01:34, 169.20it/s]

{'eval_loss': 0.5113728046417236, 'eval_accuracy': 0.835, 'eval_runtime': 0.1748, 'eval_samples_per_second': 5721.786, 'eval_steps_per_second': 715.223, 'epoch': 12.78}


 22%|██▎       | 4500/20000 [00:28<01:33, 166.51it/s]

{'loss': 0.4481, 'grad_norm': 0.6645598411560059, 'learning_rate': 7.75e-06, 'epoch': 14.38}


                                                     
 23%|██▎       | 4530/20000 [00:28<02:08, 120.75it/s]

{'eval_loss': 0.49594610929489136, 'eval_accuracy': 0.836, 'eval_runtime': 0.1789, 'eval_samples_per_second': 5588.419, 'eval_steps_per_second': 698.552, 'epoch': 14.38}


 25%|██▌       | 5000/20000 [00:31<01:26, 172.48it/s]

{'loss': 0.4305, 'grad_norm': 0.7982301712036133, 'learning_rate': 7.500000000000001e-06, 'epoch': 15.97}


                                                     
 25%|██▌       | 5026/20000 [00:32<02:07, 117.63it/s]

{'eval_loss': 0.482132226228714, 'eval_accuracy': 0.841, 'eval_runtime': 0.1814, 'eval_samples_per_second': 5513.779, 'eval_steps_per_second': 689.222, 'epoch': 15.97}


 28%|██▊       | 5500/20000 [00:34<01:28, 163.55it/s]

{'loss': 0.41, 'grad_norm': 0.804196834564209, 'learning_rate': 7.25e-06, 'epoch': 17.57}


                                                     
 28%|██▊       | 5518/20000 [00:35<02:07, 113.95it/s]

{'eval_loss': 0.4695770740509033, 'eval_accuracy': 0.845, 'eval_runtime': 0.1835, 'eval_samples_per_second': 5448.57, 'eval_steps_per_second': 681.071, 'epoch': 17.57}


 30%|███       | 6000/20000 [00:38<01:26, 161.81it/s]

{'loss': 0.3985, 'grad_norm': 0.8963615298271179, 'learning_rate': 7e-06, 'epoch': 19.17}


                                                     
 30%|███       | 6003/20000 [00:38<02:10, 107.33it/s]

{'eval_loss': 0.45832404494285583, 'eval_accuracy': 0.847, 'eval_runtime': 0.1605, 'eval_samples_per_second': 6230.805, 'eval_steps_per_second': 778.851, 'epoch': 19.17}


 32%|███▎      | 6500/20000 [00:41<01:24, 159.32it/s]

{'loss': 0.3819, 'grad_norm': 1.1032607555389404, 'learning_rate': 6.750000000000001e-06, 'epoch': 20.77}


                                                     
 32%|███▎      | 6500/20000 [00:41<01:24, 159.32it/s]
100%|██████████| 125/125 [00:00<00:00, 767.95it/s]

{'eval_loss': 0.4480094015598297, 'eval_accuracy': 0.849, 'eval_runtime': 0.1791, 'eval_samples_per_second': 5584.602, 'eval_steps_per_second': 698.075, 'epoch': 20.77}


 35%|███▌      | 7000/20000 [00:44<01:20, 160.50it/s]

{'loss': 0.3669, 'grad_norm': 0.6428804993629456, 'learning_rate': 6.5000000000000004e-06, 'epoch': 22.36}


                                                     
 35%|███▌      | 7000/20000 [00:44<01:20, 160.50it/s]

{'eval_loss': 0.4386056959629059, 'eval_accuracy': 0.844, 'eval_runtime': 0.1783, 'eval_samples_per_second': 5608.078, 'eval_steps_per_second': 701.01, 'epoch': 22.36}


 38%|███▊      | 7500/20000 [00:47<01:18, 158.46it/s]

{'loss': 0.3565, 'grad_norm': 0.8660621643066406, 'learning_rate': 6.25e-06, 'epoch': 23.96}


                                                     
 38%|███▊      | 7518/20000 [00:48<01:56, 107.05it/s]

{'eval_loss': 0.43016552925109863, 'eval_accuracy': 0.846, 'eval_runtime': 0.2107, 'eval_samples_per_second': 4747.002, 'eval_steps_per_second': 593.375, 'epoch': 23.96}


 40%|████      | 8000/20000 [00:51<01:08, 175.76it/s]

{'loss': 0.3453, 'grad_norm': 0.7485974431037903, 'learning_rate': 6e-06, 'epoch': 25.56}


                                                     
 40%|████      | 8025/20000 [00:51<01:34, 126.46it/s]

{'eval_loss': 0.42230066657066345, 'eval_accuracy': 0.847, 'eval_runtime': 0.1719, 'eval_samples_per_second': 5817.673, 'eval_steps_per_second': 727.209, 'epoch': 25.56}


 42%|████▎     | 8500/20000 [00:54<01:09, 164.94it/s]

{'loss': 0.3331, 'grad_norm': 0.7147629857063293, 'learning_rate': 5.75e-06, 'epoch': 27.16}


                                                     
 43%|████▎     | 8529/20000 [00:54<01:39, 115.86it/s]

{'eval_loss': 0.4153141379356384, 'eval_accuracy': 0.845, 'eval_runtime': 0.1835, 'eval_samples_per_second': 5448.422, 'eval_steps_per_second': 681.053, 'epoch': 27.16}


 45%|████▌     | 9000/20000 [00:57<01:04, 169.85it/s]

{'loss': 0.3254, 'grad_norm': 0.8416149020195007, 'learning_rate': 5.500000000000001e-06, 'epoch': 28.75}


                                                     
 45%|████▌     | 9024/20000 [00:57<01:37, 113.07it/s]

{'eval_loss': 0.4087965786457062, 'eval_accuracy': 0.848, 'eval_runtime': 0.1974, 'eval_samples_per_second': 5065.058, 'eval_steps_per_second': 633.132, 'epoch': 28.75}


 48%|████▊     | 9500/20000 [01:00<01:01, 169.69it/s]

{'loss': 0.317, 'grad_norm': 0.6615365147590637, 'learning_rate': 5.2500000000000006e-06, 'epoch': 30.35}


                                                     
 48%|████▊     | 9524/20000 [01:01<01:30, 116.28it/s]

{'eval_loss': 0.40304088592529297, 'eval_accuracy': 0.846, 'eval_runtime': 0.1803, 'eval_samples_per_second': 5546.789, 'eval_steps_per_second': 693.349, 'epoch': 30.35}


 50%|█████     | 10000/20000 [01:04<01:00, 165.73it/s]

{'loss': 0.3068, 'grad_norm': 1.0841376781463623, 'learning_rate': 5e-06, 'epoch': 31.95}


                                                      
 50%|█████     | 10028/20000 [01:04<01:26, 114.86it/s]

{'eval_loss': 0.39770689606666565, 'eval_accuracy': 0.848, 'eval_runtime': 0.1844, 'eval_samples_per_second': 5424.464, 'eval_steps_per_second': 678.058, 'epoch': 31.95}


 52%|█████▎    | 10500/20000 [01:07<00:54, 173.17it/s]

{'loss': 0.3, 'grad_norm': 0.9763253331184387, 'learning_rate': 4.75e-06, 'epoch': 33.55}


                                                      
 52%|█████▎    | 10500/20000 [01:07<00:54, 173.17it/s]

{'eval_loss': 0.3928989768028259, 'eval_accuracy': 0.849, 'eval_runtime': 0.1726, 'eval_samples_per_second': 5793.454, 'eval_steps_per_second': 724.182, 'epoch': 33.55}


 55%|█████▌    | 11000/20000 [01:10<00:53, 169.49it/s]

{'loss': 0.2945, 'grad_norm': 0.8224273920059204, 'learning_rate': 4.5e-06, 'epoch': 35.14}


                                                      
 55%|█████▌    | 11029/20000 [01:10<01:14, 119.85it/s]

{'eval_loss': 0.38846978545188904, 'eval_accuracy': 0.852, 'eval_runtime': 0.1895, 'eval_samples_per_second': 5275.98, 'eval_steps_per_second': 659.498, 'epoch': 35.14}


 57%|█████▊    | 11500/20000 [01:13<00:47, 179.76it/s]

{'loss': 0.2883, 'grad_norm': 0.7045588493347168, 'learning_rate': 4.25e-06, 'epoch': 36.74}


                                                      
 57%|█████▊    | 11500/20000 [01:13<00:47, 179.76it/s]

{'eval_loss': 0.38450437784194946, 'eval_accuracy': 0.852, 'eval_runtime': 0.1624, 'eval_samples_per_second': 6155.755, 'eval_steps_per_second': 769.469, 'epoch': 36.74}


 60%|██████    | 12000/20000 [01:16<00:45, 175.80it/s]

{'loss': 0.2806, 'grad_norm': 0.7337397933006287, 'learning_rate': 4.000000000000001e-06, 'epoch': 38.34}


                                                      
 60%|██████    | 12000/20000 [01:16<00:45, 175.80it/s]

{'eval_loss': 0.38087666034698486, 'eval_accuracy': 0.854, 'eval_runtime': 0.1677, 'eval_samples_per_second': 5961.493, 'eval_steps_per_second': 745.187, 'epoch': 38.34}


 62%|██████▎   | 12500/20000 [01:19<00:40, 186.91it/s]

{'loss': 0.2773, 'grad_norm': 0.7294104695320129, 'learning_rate': 3.7500000000000005e-06, 'epoch': 39.94}


                                                      
 63%|██████▎   | 12502/20000 [01:19<01:02, 119.85it/s]

{'eval_loss': 0.3776748776435852, 'eval_accuracy': 0.856, 'eval_runtime': 0.1639, 'eval_samples_per_second': 6102.022, 'eval_steps_per_second': 762.753, 'epoch': 39.94}


 65%|██████▌   | 13000/20000 [01:22<00:38, 180.06it/s]

{'loss': 0.2702, 'grad_norm': 0.6221808791160583, 'learning_rate': 3.5e-06, 'epoch': 41.53}


                                                      
 65%|██████▌   | 13025/20000 [01:22<00:56, 123.42it/s]

{'eval_loss': 0.37480053305625916, 'eval_accuracy': 0.857, 'eval_runtime': 0.1852, 'eval_samples_per_second': 5399.569, 'eval_steps_per_second': 674.946, 'epoch': 41.53}


 68%|██████▊   | 13500/20000 [01:25<00:34, 186.55it/s]

{'loss': 0.2691, 'grad_norm': 0.6917203664779663, 'learning_rate': 3.2500000000000002e-06, 'epoch': 43.13}


                                                      
 68%|██████▊   | 13500/20000 [01:25<00:34, 186.55it/s]

{'eval_loss': 0.3719993233680725, 'eval_accuracy': 0.855, 'eval_runtime': 0.1717, 'eval_samples_per_second': 5824.363, 'eval_steps_per_second': 728.045, 'epoch': 43.13}


 70%|███████   | 14000/20000 [01:28<00:35, 170.19it/s]

{'loss': 0.2623, 'grad_norm': 0.7605948448181152, 'learning_rate': 3e-06, 'epoch': 44.73}


                                                      
 70%|███████   | 14024/20000 [01:28<00:50, 119.29it/s]

{'eval_loss': 0.36950305104255676, 'eval_accuracy': 0.856, 'eval_runtime': 0.1909, 'eval_samples_per_second': 5238.84, 'eval_steps_per_second': 654.855, 'epoch': 44.73}


 72%|███████▎  | 14500/20000 [01:31<00:30, 181.64it/s]

{'loss': 0.2607, 'grad_norm': 0.8834921717643738, 'learning_rate': 2.7500000000000004e-06, 'epoch': 46.33}


                                                      
 72%|███████▎  | 14500/20000 [01:31<00:30, 181.64it/s]

{'eval_loss': 0.3674236834049225, 'eval_accuracy': 0.858, 'eval_runtime': 0.1631, 'eval_samples_per_second': 6129.452, 'eval_steps_per_second': 766.181, 'epoch': 46.33}


 75%|███████▌  | 15000/20000 [01:34<00:28, 173.02it/s]

{'loss': 0.2558, 'grad_norm': 0.8777567148208618, 'learning_rate': 2.5e-06, 'epoch': 47.92}


                                                      
 75%|███████▌  | 15001/20000 [01:34<00:44, 113.07it/s]

{'eval_loss': 0.36560338735580444, 'eval_accuracy': 0.856, 'eval_runtime': 0.1751, 'eval_samples_per_second': 5710.452, 'eval_steps_per_second': 713.806, 'epoch': 47.92}


 78%|███████▊  | 15500/20000 [01:37<00:23, 189.34it/s]

{'loss': 0.2532, 'grad_norm': 0.8382019400596619, 'learning_rate': 2.25e-06, 'epoch': 49.52}


                                                      
 78%|███████▊  | 15500/20000 [01:37<00:23, 189.34it/s]

{'eval_loss': 0.363948792219162, 'eval_accuracy': 0.856, 'eval_runtime': 0.1664, 'eval_samples_per_second': 6010.892, 'eval_steps_per_second': 751.361, 'epoch': 49.52}


 80%|████████  | 16000/20000 [01:40<00:23, 169.37it/s]

{'loss': 0.2512, 'grad_norm': 1.0121062994003296, 'learning_rate': 2.0000000000000003e-06, 'epoch': 51.12}


                                                      
 80%|████████  | 16021/20000 [01:40<00:33, 119.39it/s]

{'eval_loss': 0.3623458743095398, 'eval_accuracy': 0.857, 'eval_runtime': 0.1853, 'eval_samples_per_second': 5395.727, 'eval_steps_per_second': 674.466, 'epoch': 51.12}


 82%|████████▎ | 16500/20000 [01:43<00:18, 186.84it/s]

{'loss': 0.2481, 'grad_norm': 0.9526575207710266, 'learning_rate': 1.75e-06, 'epoch': 52.72}


                                                      
 83%|████████▎ | 16525/20000 [01:43<00:25, 133.71it/s]

{'eval_loss': 0.36109086871147156, 'eval_accuracy': 0.857, 'eval_runtime': 0.1734, 'eval_samples_per_second': 5768.007, 'eval_steps_per_second': 721.001, 'epoch': 52.72}


 85%|████████▌ | 17000/20000 [01:46<00:18, 164.43it/s]

{'loss': 0.2463, 'grad_norm': 0.5677160024642944, 'learning_rate': 1.5e-06, 'epoch': 54.31}


                                                      
 85%|████████▌ | 17016/20000 [01:46<00:27, 109.14it/s]

{'eval_loss': 0.3599223792552948, 'eval_accuracy': 0.857, 'eval_runtime': 0.2025, 'eval_samples_per_second': 4937.333, 'eval_steps_per_second': 617.167, 'epoch': 54.31}


 88%|████████▊ | 17500/20000 [01:49<00:15, 166.55it/s]

{'loss': 0.2445, 'grad_norm': 0.674629807472229, 'learning_rate': 1.25e-06, 'epoch': 55.91}


                                                      
 88%|████████▊ | 17500/20000 [01:49<00:15, 166.55it/s]

{'eval_loss': 0.35902172327041626, 'eval_accuracy': 0.857, 'eval_runtime': 0.1705, 'eval_samples_per_second': 5863.388, 'eval_steps_per_second': 732.923, 'epoch': 55.91}


 90%|█████████ | 18000/20000 [01:52<00:12, 162.01it/s]

{'loss': 0.2439, 'grad_norm': 0.9155569076538086, 'learning_rate': 1.0000000000000002e-06, 'epoch': 57.51}


                                                      
 90%|█████████ | 18029/20000 [01:53<00:17, 109.91it/s]

{'eval_loss': 0.35829392075538635, 'eval_accuracy': 0.857, 'eval_runtime': 0.1899, 'eval_samples_per_second': 5265.297, 'eval_steps_per_second': 658.162, 'epoch': 57.51}


 92%|█████████▎| 18500/20000 [01:56<00:09, 155.17it/s]

{'loss': 0.2414, 'grad_norm': 0.6704147458076477, 'learning_rate': 7.5e-07, 'epoch': 59.11}


                                                      
 93%|█████████▎| 18513/20000 [01:56<00:16, 90.36it/s] 

{'eval_loss': 0.35772013664245605, 'eval_accuracy': 0.857, 'eval_runtime': 0.2244, 'eval_samples_per_second': 4455.555, 'eval_steps_per_second': 556.944, 'epoch': 59.11}


 95%|█████████▌| 19000/20000 [01:59<00:06, 159.55it/s]

{'loss': 0.2413, 'grad_norm': 0.6864736080169678, 'learning_rate': 5.000000000000001e-07, 'epoch': 60.7}


                                                      
 95%|█████████▌| 19000/20000 [01:59<00:06, 159.55it/s]

{'eval_loss': 0.3573075234889984, 'eval_accuracy': 0.858, 'eval_runtime': 0.1794, 'eval_samples_per_second': 5575.3, 'eval_steps_per_second': 696.913, 'epoch': 60.7}


 98%|█████████▊| 19500/20000 [02:02<00:03, 158.14it/s]

{'loss': 0.2408, 'grad_norm': 0.6366955637931824, 'learning_rate': 2.5000000000000004e-07, 'epoch': 62.3}


                                                      
 98%|█████████▊| 19527/20000 [02:02<00:04, 110.95it/s]

{'eval_loss': 0.3570707142353058, 'eval_accuracy': 0.858, 'eval_runtime': 0.1899, 'eval_samples_per_second': 5266.348, 'eval_steps_per_second': 658.294, 'epoch': 62.3}


100%|██████████| 20000/20000 [02:06<00:00, 158.52it/s]

{'loss': 0.242, 'grad_norm': 0.7342261672019958, 'learning_rate': 0.0, 'epoch': 63.9}


                                                      
100%|██████████| 20000/20000 [02:06<00:00, 158.52it/s]

{'eval_loss': 0.3569808900356293, 'eval_accuracy': 0.859, 'eval_runtime': 0.1681, 'eval_samples_per_second': 5948.87, 'eval_steps_per_second': 743.609, 'epoch': 63.9}


100%|██████████| 20000/20000 [02:06<00:00, 158.42it/s]


{'train_runtime': 126.2434, 'train_samples_per_second': 5069.573, 'train_steps_per_second': 158.424, 'train_loss': 0.3543166919708252, 'epoch': 63.9}


100%|██████████| 3125/3125 [00:04<00:00, 739.46it/s]
dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)


Training with lr=1e-05 and batch_size=64


  2%|▎         | 500/20000 [00:04<03:12, 101.07it/s]

{'loss': 0.6902, 'grad_norm': 0.46055394411087036, 'learning_rate': 9.75e-06, 'epoch': 3.18}


                                                    
  3%|▎         | 515/20000 [00:05<04:36, 70.56it/s] 

{'eval_loss': 0.6694180965423584, 'eval_accuracy': 0.603, 'eval_runtime': 0.1848, 'eval_samples_per_second': 5410.679, 'eval_steps_per_second': 676.335, 'epoch': 3.18}


  5%|▌         | 1000/20000 [00:10<03:03, 103.48it/s]

{'loss': 0.6528, 'grad_norm': 0.6369129419326782, 'learning_rate': 9.5e-06, 'epoch': 6.37}


                                                     
  5%|▌         | 1017/20000 [00:10<04:12, 75.12it/s] 

{'eval_loss': 0.6438913345336914, 'eval_accuracy': 0.726, 'eval_runtime': 0.1791, 'eval_samples_per_second': 5583.598, 'eval_steps_per_second': 697.95, 'epoch': 6.37}


  8%|▊         | 1500/20000 [00:15<02:53, 106.44it/s]

{'loss': 0.6184, 'grad_norm': 0.511425793170929, 'learning_rate': 9.250000000000001e-06, 'epoch': 9.55}


                                                     
  8%|▊         | 1500/20000 [00:15<02:53, 106.44it/s]

{'eval_loss': 0.6177575588226318, 'eval_accuracy': 0.77, 'eval_runtime': 0.1773, 'eval_samples_per_second': 5641.577, 'eval_steps_per_second': 705.197, 'epoch': 9.55}


 10%|█         | 2000/20000 [00:19<02:49, 105.94it/s]

{'loss': 0.5839, 'grad_norm': 0.5307046175003052, 'learning_rate': 9e-06, 'epoch': 12.74}


                                                     
 10%|█         | 2000/20000 [00:19<02:49, 105.94it/s]

{'eval_loss': 0.5923187136650085, 'eval_accuracy': 0.793, 'eval_runtime': 0.1791, 'eval_samples_per_second': 5583.189, 'eval_steps_per_second': 697.899, 'epoch': 12.74}


 12%|█▎        | 2500/20000 [00:24<02:36, 111.82it/s]

{'loss': 0.5497, 'grad_norm': 0.4976402223110199, 'learning_rate': 8.750000000000001e-06, 'epoch': 15.92}


                                                     
 13%|█▎        | 2517/20000 [00:24<03:52, 75.18it/s] 

{'eval_loss': 0.5684905052185059, 'eval_accuracy': 0.808, 'eval_runtime': 0.2101, 'eval_samples_per_second': 4758.554, 'eval_steps_per_second': 594.819, 'epoch': 15.92}


 15%|█▌        | 3000/20000 [00:29<02:22, 119.53it/s]

{'loss': 0.5188, 'grad_norm': 0.5135688185691833, 'learning_rate': 8.5e-06, 'epoch': 19.11}


                                                     
 15%|█▌        | 3000/20000 [00:29<02:22, 119.53it/s]
                                                  

{'eval_loss': 0.5467765927314758, 'eval_accuracy': 0.821, 'eval_runtime': 0.1699, 'eval_samples_per_second': 5886.767, 'eval_steps_per_second': 735.846, 'epoch': 19.11}


 18%|█▊        | 3500/20000 [00:33<02:30, 109.77it/s]

{'loss': 0.4904, 'grad_norm': 0.5523794889450073, 'learning_rate': 8.25e-06, 'epoch': 22.29}


                                                     
 18%|█▊        | 3500/20000 [00:34<02:30, 109.77it/s]

{'eval_loss': 0.5271146297454834, 'eval_accuracy': 0.818, 'eval_runtime': 0.1752, 'eval_samples_per_second': 5708.501, 'eval_steps_per_second': 713.563, 'epoch': 22.29}


 20%|██        | 4000/20000 [00:38<02:25, 109.70it/s]

{'loss': 0.4653, 'grad_norm': 0.5472829937934875, 'learning_rate': 8.000000000000001e-06, 'epoch': 25.48}


                                                     
 20%|██        | 4009/20000 [00:38<03:50, 69.47it/s] 

{'eval_loss': 0.5093602538108826, 'eval_accuracy': 0.823, 'eval_runtime': 0.1821, 'eval_samples_per_second': 5492.083, 'eval_steps_per_second': 686.51, 'epoch': 25.48}


 22%|██▎       | 4500/20000 [00:43<02:14, 115.03it/s]

{'loss': 0.4421, 'grad_norm': 0.5369781255722046, 'learning_rate': 7.75e-06, 'epoch': 28.66}


                                                     
 22%|██▎       | 4500/20000 [00:43<02:14, 115.03it/s]

{'eval_loss': 0.493335485458374, 'eval_accuracy': 0.827, 'eval_runtime': 0.1677, 'eval_samples_per_second': 5962.027, 'eval_steps_per_second': 745.253, 'epoch': 28.66}


 25%|██▌       | 5000/20000 [00:47<02:10, 114.73it/s]

{'loss': 0.4208, 'grad_norm': 0.5787872672080994, 'learning_rate': 7.500000000000001e-06, 'epoch': 31.85}


                                                     
 25%|██▌       | 5001/20000 [00:48<03:18, 75.75it/s] 

{'eval_loss': 0.47863343358039856, 'eval_accuracy': 0.835, 'eval_runtime': 0.1583, 'eval_samples_per_second': 6315.135, 'eval_steps_per_second': 789.392, 'epoch': 31.85}


 28%|██▊       | 5500/20000 [00:52<02:16, 106.41it/s]

{'loss': 0.4024, 'grad_norm': 0.5034195780754089, 'learning_rate': 7.25e-06, 'epoch': 35.03}


                                                     
 28%|██▊       | 5512/20000 [00:52<03:04, 78.32it/s] 

{'eval_loss': 0.4655495584011078, 'eval_accuracy': 0.836, 'eval_runtime': 0.1739, 'eval_samples_per_second': 5750.222, 'eval_steps_per_second': 718.778, 'epoch': 35.03}


 30%|███       | 6000/20000 [00:57<01:58, 117.77it/s]

{'loss': 0.385, 'grad_norm': 0.5052285194396973, 'learning_rate': 7e-06, 'epoch': 38.22}


                                                     
 30%|███       | 6000/20000 [00:57<01:58, 117.77it/s]

{'eval_loss': 0.45360448956489563, 'eval_accuracy': 0.838, 'eval_runtime': 0.1685, 'eval_samples_per_second': 5935.955, 'eval_steps_per_second': 741.994, 'epoch': 38.22}


 32%|███▎      | 6500/20000 [01:01<01:55, 116.96it/s]

{'loss': 0.369, 'grad_norm': 0.5393340587615967, 'learning_rate': 6.750000000000001e-06, 'epoch': 41.4}


                                                     
 32%|███▎      | 6500/20000 [01:01<01:55, 116.96it/s]

{'eval_loss': 0.44264301657676697, 'eval_accuracy': 0.841, 'eval_runtime': 0.1687, 'eval_samples_per_second': 5929.309, 'eval_steps_per_second': 741.164, 'epoch': 41.4}


 35%|███▌      | 7000/20000 [01:06<01:53, 114.37it/s]

{'loss': 0.3542, 'grad_norm': 0.5248522162437439, 'learning_rate': 6.5000000000000004e-06, 'epoch': 44.59}


                                                     
 35%|███▌      | 7000/20000 [01:06<01:53, 114.37it/s]

{'eval_loss': 0.43305888772010803, 'eval_accuracy': 0.842, 'eval_runtime': 0.1667, 'eval_samples_per_second': 5999.946, 'eval_steps_per_second': 749.993, 'epoch': 44.59}


 38%|███▊      | 7500/20000 [01:11<01:59, 105.04it/s]

{'loss': 0.342, 'grad_norm': 0.50478196144104, 'learning_rate': 6.25e-06, 'epoch': 47.77}


                                                     
 38%|███▊      | 7518/20000 [01:11<02:46, 74.81it/s] 

{'eval_loss': 0.4238169491291046, 'eval_accuracy': 0.841, 'eval_runtime': 0.186, 'eval_samples_per_second': 5377.34, 'eval_steps_per_second': 672.168, 'epoch': 47.77}


 40%|████      | 8000/20000 [01:15<01:46, 112.89it/s]

{'loss': 0.329, 'grad_norm': 0.579477846622467, 'learning_rate': 6e-06, 'epoch': 50.96}


                                                     
 40%|████      | 8000/20000 [01:15<01:46, 112.89it/s]

{'eval_loss': 0.4159235954284668, 'eval_accuracy': 0.843, 'eval_runtime': 0.1728, 'eval_samples_per_second': 5785.821, 'eval_steps_per_second': 723.228, 'epoch': 50.96}


 42%|████▎     | 8500/20000 [01:20<01:37, 117.66it/s]

{'loss': 0.3185, 'grad_norm': 0.4636477530002594, 'learning_rate': 5.75e-06, 'epoch': 54.14}


                                                     
 42%|████▎     | 8500/20000 [01:20<01:37, 117.66it/s]

{'eval_loss': 0.40855371952056885, 'eval_accuracy': 0.847, 'eval_runtime': 0.1765, 'eval_samples_per_second': 5665.597, 'eval_steps_per_second': 708.2, 'epoch': 54.14}


 45%|████▌     | 9000/20000 [01:25<01:44, 104.95it/s]

{'loss': 0.307, 'grad_norm': 0.4952618479728699, 'learning_rate': 5.500000000000001e-06, 'epoch': 57.32}


                                                     
 45%|████▌     | 9013/20000 [01:25<02:28, 74.09it/s] 

{'eval_loss': 0.4019678235054016, 'eval_accuracy': 0.846, 'eval_runtime': 0.1826, 'eval_samples_per_second': 5475.485, 'eval_steps_per_second': 684.436, 'epoch': 57.32}


 48%|████▊     | 9500/20000 [01:30<01:37, 107.36it/s]

{'loss': 0.2989, 'grad_norm': 0.4600812792778015, 'learning_rate': 5.2500000000000006e-06, 'epoch': 60.51}


                                                     
 48%|████▊     | 9500/20000 [01:30<01:37, 107.36it/s]


{'eval_loss': 0.39568695425987244, 'eval_accuracy': 0.851, 'eval_runtime': 0.1818, 'eval_samples_per_second': 5499.277, 'eval_steps_per_second': 687.41, 'epoch': 60.51}


 50%|█████     | 10000/20000 [01:35<01:44, 95.58it/s]

{'loss': 0.2898, 'grad_norm': 0.4312400221824646, 'learning_rate': 5e-06, 'epoch': 63.69}


                                                     
 50%|█████     | 10012/20000 [01:35<02:29, 66.82it/s]

{'eval_loss': 0.39034923911094666, 'eval_accuracy': 0.85, 'eval_runtime': 0.1966, 'eval_samples_per_second': 5087.508, 'eval_steps_per_second': 635.938, 'epoch': 63.69}


 52%|█████▎    | 10500/20000 [01:40<01:30, 104.95it/s]

{'loss': 0.2823, 'grad_norm': 0.5052204728126526, 'learning_rate': 4.75e-06, 'epoch': 66.88}


                                                      
 53%|█████▎    | 10514/20000 [01:40<02:07, 74.25it/s] 

{'eval_loss': 0.3853646218776703, 'eval_accuracy': 0.851, 'eval_runtime': 0.1796, 'eval_samples_per_second': 5568.557, 'eval_steps_per_second': 696.07, 'epoch': 66.88}


 55%|█████▌    | 11000/20000 [01:45<01:29, 100.71it/s]

{'loss': 0.2754, 'grad_norm': 0.46315997838974, 'learning_rate': 4.5e-06, 'epoch': 70.06}


                                                      
 55%|█████▌    | 11017/20000 [01:46<02:09, 69.30it/s] 

{'eval_loss': 0.38074737787246704, 'eval_accuracy': 0.851, 'eval_runtime': 0.1869, 'eval_samples_per_second': 5350.424, 'eval_steps_per_second': 668.803, 'epoch': 70.06}


 57%|█████▊    | 11500/20000 [01:51<01:20, 105.43it/s]

{'loss': 0.2679, 'grad_norm': 0.41044729948043823, 'learning_rate': 4.25e-06, 'epoch': 73.25}


                                                      
 58%|█████▊    | 11513/20000 [01:51<01:53, 75.10it/s] 

{'eval_loss': 0.3766644299030304, 'eval_accuracy': 0.852, 'eval_runtime': 0.1773, 'eval_samples_per_second': 5640.819, 'eval_steps_per_second': 705.102, 'epoch': 73.25}


 60%|██████    | 12000/20000 [01:56<01:23, 95.87it/s] 

{'loss': 0.2629, 'grad_norm': 0.45812100172042847, 'learning_rate': 4.000000000000001e-06, 'epoch': 76.43}


                                                     
 60%|██████    | 12015/20000 [01:56<01:58, 67.44it/s]

{'eval_loss': 0.3731173872947693, 'eval_accuracy': 0.853, 'eval_runtime': 0.1834, 'eval_samples_per_second': 5451.134, 'eval_steps_per_second': 681.392, 'epoch': 76.43}


 62%|██████▎   | 12500/20000 [02:01<01:12, 103.57it/s]

{'loss': 0.256, 'grad_norm': 0.43337810039520264, 'learning_rate': 3.7500000000000005e-06, 'epoch': 79.62}


                                                      
 63%|██████▎   | 12510/20000 [02:01<01:56, 64.47it/s] 

{'eval_loss': 0.36966195702552795, 'eval_accuracy': 0.852, 'eval_runtime': 0.1834, 'eval_samples_per_second': 5451.729, 'eval_steps_per_second': 681.466, 'epoch': 79.62}


 65%|██████▌   | 13000/20000 [02:06<01:06, 105.73it/s]

{'loss': 0.2525, 'grad_norm': 0.5018870830535889, 'learning_rate': 3.5e-06, 'epoch': 82.8}


                                                      
 65%|██████▌   | 13007/20000 [02:06<01:44, 66.92it/s] 

{'eval_loss': 0.3665390610694885, 'eval_accuracy': 0.85, 'eval_runtime': 0.1849, 'eval_samples_per_second': 5407.219, 'eval_steps_per_second': 675.902, 'epoch': 82.8}


 68%|██████▊   | 13500/20000 [02:11<01:03, 101.92it/s]

{'loss': 0.2462, 'grad_norm': 0.5071868896484375, 'learning_rate': 3.2500000000000002e-06, 'epoch': 85.99}


                                                      
 68%|██████▊   | 13517/20000 [02:11<01:27, 74.26it/s] 

{'eval_loss': 0.3639027774333954, 'eval_accuracy': 0.851, 'eval_runtime': 0.1845, 'eval_samples_per_second': 5420.279, 'eval_steps_per_second': 677.535, 'epoch': 85.99}


 70%|███████   | 14000/20000 [02:16<01:00, 99.31it/s] 

{'loss': 0.2429, 'grad_norm': 0.4621407985687256, 'learning_rate': 3e-06, 'epoch': 89.17}


                                                     
 70%|███████   | 14009/20000 [02:16<01:34, 63.67it/s]

{'eval_loss': 0.36146146059036255, 'eval_accuracy': 0.852, 'eval_runtime': 0.1756, 'eval_samples_per_second': 5693.963, 'eval_steps_per_second': 711.745, 'epoch': 89.17}


 72%|███████▎  | 14500/20000 [02:21<00:53, 103.34it/s]

{'loss': 0.239, 'grad_norm': 0.41504883766174316, 'learning_rate': 2.7500000000000004e-06, 'epoch': 92.36}


                                                      
 73%|███████▎  | 14516/20000 [02:21<01:15, 72.97it/s] 

{'eval_loss': 0.3592887222766876, 'eval_accuracy': 0.851, 'eval_runtime': 0.1963, 'eval_samples_per_second': 5094.206, 'eval_steps_per_second': 636.776, 'epoch': 92.36}


 75%|███████▌  | 15000/20000 [02:26<00:47, 105.48it/s]

{'loss': 0.235, 'grad_norm': 0.4037564992904663, 'learning_rate': 2.5e-06, 'epoch': 95.54}


                                                      
 75%|███████▌  | 15011/20000 [02:26<01:18, 63.48it/s] 

{'eval_loss': 0.3572770953178406, 'eval_accuracy': 0.851, 'eval_runtime': 0.1928, 'eval_samples_per_second': 5188.049, 'eval_steps_per_second': 648.506, 'epoch': 95.54}


 78%|███████▊  | 15500/20000 [02:31<00:43, 103.02it/s]

{'loss': 0.232, 'grad_norm': 0.397609144449234, 'learning_rate': 2.25e-06, 'epoch': 98.73}


                                                      
 78%|███████▊  | 15516/20000 [02:32<01:01, 72.84it/s] 

{'eval_loss': 0.35572487115859985, 'eval_accuracy': 0.853, 'eval_runtime': 0.2021, 'eval_samples_per_second': 4947.362, 'eval_steps_per_second': 618.42, 'epoch': 98.73}


 80%|████████  | 16000/20000 [02:36<00:38, 103.36it/s]

{'loss': 0.231, 'grad_norm': 0.4496039152145386, 'learning_rate': 2.0000000000000003e-06, 'epoch': 101.91}


                                                      
 80%|████████  | 16015/20000 [02:37<00:54, 72.92it/s] 

{'eval_loss': 0.35417595505714417, 'eval_accuracy': 0.852, 'eval_runtime': 0.1936, 'eval_samples_per_second': 5164.553, 'eval_steps_per_second': 645.569, 'epoch': 101.91}


 82%|████████▎ | 16500/20000 [02:42<00:33, 103.63it/s]

{'loss': 0.2264, 'grad_norm': 0.40670260787010193, 'learning_rate': 1.75e-06, 'epoch': 105.1}


                                                      
 83%|████████▎ | 16517/20000 [02:42<00:49, 70.85it/s] 

{'eval_loss': 0.3529439866542816, 'eval_accuracy': 0.853, 'eval_runtime': 0.2165, 'eval_samples_per_second': 4618.977, 'eval_steps_per_second': 577.372, 'epoch': 105.1}


 85%|████████▌ | 17000/20000 [02:47<00:29, 103.04it/s]

{'loss': 0.2252, 'grad_norm': 0.3957420587539673, 'learning_rate': 1.5e-06, 'epoch': 108.28}


                                                      
 85%|████████▌ | 17019/20000 [02:47<00:41, 72.56it/s] 

{'eval_loss': 0.3517990708351135, 'eval_accuracy': 0.852, 'eval_runtime': 0.1843, 'eval_samples_per_second': 5425.123, 'eval_steps_per_second': 678.14, 'epoch': 108.28}


 88%|████████▊ | 17500/20000 [02:52<00:24, 103.23it/s]

{'loss': 0.2231, 'grad_norm': 0.4096619486808777, 'learning_rate': 1.25e-06, 'epoch': 111.46}


                                                      
 88%|████████▊ | 17508/20000 [02:52<00:38, 64.15it/s] 

{'eval_loss': 0.3509713411331177, 'eval_accuracy': 0.853, 'eval_runtime': 0.185, 'eval_samples_per_second': 5404.982, 'eval_steps_per_second': 675.623, 'epoch': 111.46}


 90%|█████████ | 18000/20000 [02:57<00:19, 102.80it/s]

{'loss': 0.2224, 'grad_norm': 0.4432070553302765, 'learning_rate': 1.0000000000000002e-06, 'epoch': 114.65}


                                                      
 90%|█████████ | 18014/20000 [02:57<00:26, 74.13it/s] 

{'eval_loss': 0.3501666784286499, 'eval_accuracy': 0.852, 'eval_runtime': 0.1839, 'eval_samples_per_second': 5438.863, 'eval_steps_per_second': 679.858, 'epoch': 114.65}


 92%|█████████▎| 18500/20000 [03:02<00:14, 103.10it/s]

{'loss': 0.2199, 'grad_norm': 0.4901573359966278, 'learning_rate': 7.5e-07, 'epoch': 117.83}


                                                      
 93%|█████████▎| 18512/20000 [03:02<00:20, 71.01it/s] 

{'eval_loss': 0.3496090769767761, 'eval_accuracy': 0.852, 'eval_runtime': 0.1815, 'eval_samples_per_second': 5508.97, 'eval_steps_per_second': 688.621, 'epoch': 117.83}


 95%|█████████▌| 19000/20000 [03:07<00:10, 98.01it/s] 

{'loss': 0.2195, 'grad_norm': 0.4393983483314514, 'learning_rate': 5.000000000000001e-07, 'epoch': 121.02}


                                                     
 95%|█████████▌| 19008/20000 [03:08<00:16, 61.06it/s]

{'eval_loss': 0.3492220938205719, 'eval_accuracy': 0.852, 'eval_runtime': 0.1882, 'eval_samples_per_second': 5313.591, 'eval_steps_per_second': 664.199, 'epoch': 121.02}


 98%|█████████▊| 19500/20000 [03:12<00:04, 100.38it/s]

{'loss': 0.2191, 'grad_norm': 0.4250403344631195, 'learning_rate': 2.5000000000000004e-07, 'epoch': 124.2}


                                                      
 98%|█████████▊| 19511/20000 [03:13<00:06, 72.16it/s] 

{'eval_loss': 0.3489702641963959, 'eval_accuracy': 0.852, 'eval_runtime': 0.1799, 'eval_samples_per_second': 5558.918, 'eval_steps_per_second': 694.865, 'epoch': 124.2}


100%|██████████| 20000/20000 [03:17<00:00, 104.05it/s]

{'loss': 0.2182, 'grad_norm': 0.4141634702682495, 'learning_rate': 0.0, 'epoch': 127.39}


                                                      
100%|██████████| 20000/20000 [03:18<00:00, 100.95it/s]


{'eval_loss': 0.34888797998428345, 'eval_accuracy': 0.852, 'eval_runtime': 0.1903, 'eval_samples_per_second': 5256.158, 'eval_steps_per_second': 657.02, 'epoch': 127.39}
{'train_runtime': 198.1104, 'train_samples_per_second': 6461.045, 'train_steps_per_second': 100.954, 'train_loss': 0.34062267417907716, 'epoch': 127.39}


100%|██████████| 3125/3125 [00:04<00:00, 736.58it/s]
dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)


Training with lr=1e-05 and batch_size=128


  2%|▎         | 500/20000 [00:08<05:41, 57.06it/s]

{'loss': 0.6684, 'grad_norm': 0.41283607482910156, 'learning_rate': 9.75e-06, 'epoch': 6.33}


                                                   
  3%|▎         | 507/20000 [00:08<07:56, 40.95it/s]

{'eval_loss': 0.6551390886306763, 'eval_accuracy': 0.748, 'eval_runtime': 0.1834, 'eval_samples_per_second': 5452.572, 'eval_steps_per_second': 681.572, 'epoch': 6.33}


  5%|▌         | 1000/20000 [00:16<05:07, 61.74it/s]

{'loss': 0.6136, 'grad_norm': 0.4432348906993866, 'learning_rate': 9.5e-06, 'epoch': 12.66}


                                                    
  5%|▌         | 1008/20000 [00:17<07:01, 45.04it/s]

{'eval_loss': 0.6146953701972961, 'eval_accuracy': 0.796, 'eval_runtime': 0.1753, 'eval_samples_per_second': 5703.936, 'eval_steps_per_second': 712.992, 'epoch': 12.66}


  8%|▊         | 1500/20000 [00:25<05:07, 60.16it/s]

{'loss': 0.563, 'grad_norm': 0.4556746482849121, 'learning_rate': 9.250000000000001e-06, 'epoch': 18.99}


                                                    
  8%|▊         | 1507/20000 [00:25<06:53, 44.76it/s]

{'eval_loss': 0.5781656503677368, 'eval_accuracy': 0.811, 'eval_runtime': 0.1781, 'eval_samples_per_second': 5614.827, 'eval_steps_per_second': 701.853, 'epoch': 18.99}


 10%|█         | 2000/20000 [00:33<04:49, 62.12it/s]

{'loss': 0.5178, 'grad_norm': 0.5066905617713928, 'learning_rate': 9e-06, 'epoch': 25.32}


                                                    
 10%|█         | 2000/20000 [00:33<04:49, 62.12it/s]


{'eval_loss': 0.546440601348877, 'eval_accuracy': 0.817, 'eval_runtime': 0.1668, 'eval_samples_per_second': 5994.912, 'eval_steps_per_second': 749.364, 'epoch': 25.32}


 12%|█▎        | 2500/20000 [00:41<04:36, 63.36it/s]A

{'loss': 0.479, 'grad_norm': 0.4432109594345093, 'learning_rate': 8.750000000000001e-06, 'epoch': 31.65}


                                                    
 12%|█▎        | 2500/20000 [00:41<04:36, 63.36it/s]

{'eval_loss': 0.5195026993751526, 'eval_accuracy': 0.823, 'eval_runtime': 0.1696, 'eval_samples_per_second': 5894.925, 'eval_steps_per_second': 736.866, 'epoch': 31.65}


 15%|█▌        | 3000/20000 [00:49<04:35, 61.65it/s]

{'loss': 0.4448, 'grad_norm': 0.4528982937335968, 'learning_rate': 8.5e-06, 'epoch': 37.97}


                                                    
 15%|█▌        | 3000/20000 [00:49<04:35, 61.65it/s]

{'eval_loss': 0.4962066113948822, 'eval_accuracy': 0.831, 'eval_runtime': 0.1764, 'eval_samples_per_second': 5668.928, 'eval_steps_per_second': 708.616, 'epoch': 37.97}


 18%|█▊        | 3500/20000 [00:58<04:55, 55.75it/s]

{'loss': 0.4148, 'grad_norm': 0.44970962405204773, 'learning_rate': 8.25e-06, 'epoch': 44.3}


                                                    
 18%|█▊        | 3506/20000 [00:58<07:18, 37.58it/s]

{'eval_loss': 0.4758424162864685, 'eval_accuracy': 0.833, 'eval_runtime': 0.1953, 'eval_samples_per_second': 5120.144, 'eval_steps_per_second': 640.018, 'epoch': 44.3}


 20%|██        | 4000/20000 [01:07<04:40, 57.09it/s]

{'loss': 0.3893, 'grad_norm': 0.3910762071609497, 'learning_rate': 8.000000000000001e-06, 'epoch': 50.63}


                                                    
 20%|██        | 4009/20000 [01:07<07:06, 37.52it/s]

{'eval_loss': 0.4582245647907257, 'eval_accuracy': 0.835, 'eval_runtime': 0.1988, 'eval_samples_per_second': 5030.945, 'eval_steps_per_second': 628.868, 'epoch': 50.63}


 22%|██▎       | 4500/20000 [01:16<04:27, 57.91it/s]

{'loss': 0.3652, 'grad_norm': 0.4252554178237915, 'learning_rate': 7.75e-06, 'epoch': 56.96}


                                                    
 23%|██▎       | 4505/20000 [01:16<06:46, 38.07it/s]

{'eval_loss': 0.4429992735385895, 'eval_accuracy': 0.836, 'eval_runtime': 0.1738, 'eval_samples_per_second': 5754.04, 'eval_steps_per_second': 719.255, 'epoch': 56.96}


 25%|██▌       | 5000/20000 [01:25<04:18, 58.05it/s]

{'loss': 0.3453, 'grad_norm': 0.41647258400917053, 'learning_rate': 7.500000000000001e-06, 'epoch': 63.29}


                                                    
 25%|██▌       | 5001/20000 [01:25<06:23, 39.08it/s]

{'eval_loss': 0.429566353559494, 'eval_accuracy': 0.838, 'eval_runtime': 0.1646, 'eval_samples_per_second': 6076.729, 'eval_steps_per_second': 759.591, 'epoch': 63.29}


 28%|██▊       | 5500/20000 [01:33<04:12, 57.39it/s]

{'loss': 0.3276, 'grad_norm': 0.4196538031101227, 'learning_rate': 7.25e-06, 'epoch': 69.62}


                                                    
 28%|██▊       | 5505/20000 [01:34<06:32, 36.90it/s]

{'eval_loss': 0.41765162348747253, 'eval_accuracy': 0.846, 'eval_runtime': 0.1827, 'eval_samples_per_second': 5473.435, 'eval_steps_per_second': 684.179, 'epoch': 69.62}


 30%|███       | 6000/20000 [01:42<03:53, 59.94it/s]

{'loss': 0.3104, 'grad_norm': 0.3958381116390228, 'learning_rate': 7e-06, 'epoch': 75.95}


                                                    
 30%|███       | 6010/20000 [01:43<05:19, 43.72it/s]

{'eval_loss': 0.4073443114757538, 'eval_accuracy': 0.847, 'eval_runtime': 0.1851, 'eval_samples_per_second': 5401.321, 'eval_steps_per_second': 675.165, 'epoch': 75.95}


 32%|███▎      | 6500/20000 [01:51<03:45, 59.98it/s]

{'loss': 0.2965, 'grad_norm': 0.38736504316329956, 'learning_rate': 6.750000000000001e-06, 'epoch': 82.28}


                                                    
 33%|███▎      | 6509/20000 [01:51<05:15, 42.81it/s]

{'eval_loss': 0.39810484647750854, 'eval_accuracy': 0.849, 'eval_runtime': 0.1785, 'eval_samples_per_second': 5602.969, 'eval_steps_per_second': 700.371, 'epoch': 82.28}


 35%|███▌      | 7000/20000 [02:00<03:34, 60.70it/s]

{'loss': 0.2832, 'grad_norm': 0.3488764762878418, 'learning_rate': 6.5000000000000004e-06, 'epoch': 88.61}


                                                    
 35%|███▌      | 7009/20000 [02:00<04:51, 44.56it/s]

{'eval_loss': 0.38996776938438416, 'eval_accuracy': 0.852, 'eval_runtime': 0.1825, 'eval_samples_per_second': 5480.408, 'eval_steps_per_second': 685.051, 'epoch': 88.61}


 38%|███▊      | 7500/20000 [02:08<03:32, 58.94it/s]

{'loss': 0.2708, 'grad_norm': 0.36442849040031433, 'learning_rate': 6.25e-06, 'epoch': 94.94}


                                                    
 38%|███▊      | 7500/20000 [02:09<03:32, 58.94it/s]

{'eval_loss': 0.3828020691871643, 'eval_accuracy': 0.853, 'eval_runtime': 0.1766, 'eval_samples_per_second': 5661.368, 'eval_steps_per_second': 707.671, 'epoch': 94.94}


 40%|████      | 8000/20000 [02:17<03:23, 59.09it/s]

{'loss': 0.2604, 'grad_norm': 0.3488769233226776, 'learning_rate': 6e-06, 'epoch': 101.27}


                                                    
 40%|████      | 8007/20000 [02:18<04:51, 41.13it/s]

{'eval_loss': 0.37641894817352295, 'eval_accuracy': 0.855, 'eval_runtime': 0.1829, 'eval_samples_per_second': 5468.746, 'eval_steps_per_second': 683.593, 'epoch': 101.27}


 42%|████▎     | 8500/20000 [02:26<03:17, 58.15it/s]

{'loss': 0.2499, 'grad_norm': 0.34266310930252075, 'learning_rate': 5.75e-06, 'epoch': 107.59}


                                                    
 43%|████▎     | 8510/20000 [02:26<04:31, 42.34it/s]

{'eval_loss': 0.37092891335487366, 'eval_accuracy': 0.854, 'eval_runtime': 0.1887, 'eval_samples_per_second': 5299.746, 'eval_steps_per_second': 662.468, 'epoch': 107.59}


 45%|████▌     | 9000/20000 [02:35<03:04, 59.58it/s]

{'loss': 0.2407, 'grad_norm': 0.3476540148258209, 'learning_rate': 5.500000000000001e-06, 'epoch': 113.92}


                                                    
 45%|████▌     | 9005/20000 [02:35<04:46, 38.38it/s]

{'eval_loss': 0.36566248536109924, 'eval_accuracy': 0.855, 'eval_runtime': 0.1848, 'eval_samples_per_second': 5410.106, 'eval_steps_per_second': 676.263, 'epoch': 113.92}


 48%|████▊     | 9500/20000 [02:43<02:55, 59.71it/s]

{'loss': 0.2322, 'grad_norm': 0.31838712096214294, 'learning_rate': 5.2500000000000006e-06, 'epoch': 120.25}


                                                    
 48%|████▊     | 9509/20000 [02:44<04:08, 42.26it/s]

{'eval_loss': 0.3610936999320984, 'eval_accuracy': 0.857, 'eval_runtime': 0.1764, 'eval_samples_per_second': 5668.967, 'eval_steps_per_second': 708.621, 'epoch': 120.25}


 50%|█████     | 10000/20000 [02:52<02:50, 58.64it/s]

{'loss': 0.2245, 'grad_norm': 0.34020116925239563, 'learning_rate': 5e-06, 'epoch': 126.58}


                                                     
 50%|█████     | 10007/20000 [02:52<04:03, 41.00it/s]

{'eval_loss': 0.3570573627948761, 'eval_accuracy': 0.857, 'eval_runtime': 0.2072, 'eval_samples_per_second': 4826.265, 'eval_steps_per_second': 603.283, 'epoch': 126.58}


 52%|█████▎    | 10500/20000 [03:01<02:46, 57.15it/s]

{'loss': 0.2175, 'grad_norm': 0.34983599185943604, 'learning_rate': 4.75e-06, 'epoch': 132.91}


                                                     
 53%|█████▎    | 10508/20000 [03:01<03:42, 42.66it/s]

{'eval_loss': 0.3533427119255066, 'eval_accuracy': 0.858, 'eval_runtime': 0.1797, 'eval_samples_per_second': 5565.114, 'eval_steps_per_second': 695.639, 'epoch': 132.91}


 55%|█████▌    | 11000/20000 [03:10<02:28, 60.65it/s]

{'loss': 0.2116, 'grad_norm': 0.348513126373291, 'learning_rate': 4.5e-06, 'epoch': 139.24}


                                                     
 55%|█████▌    | 11007/20000 [03:10<03:33, 42.18it/s]

{'eval_loss': 0.3502821624279022, 'eval_accuracy': 0.859, 'eval_runtime': 0.1832, 'eval_samples_per_second': 5458.384, 'eval_steps_per_second': 682.298, 'epoch': 139.24}


 57%|█████▊    | 11500/20000 [03:18<02:21, 60.26it/s]

{'loss': 0.2054, 'grad_norm': 0.3385497033596039, 'learning_rate': 4.25e-06, 'epoch': 145.57}


                                                     
 58%|█████▊    | 11506/20000 [03:19<03:38, 38.87it/s]

{'eval_loss': 0.34739091992378235, 'eval_accuracy': 0.858, 'eval_runtime': 0.1871, 'eval_samples_per_second': 5343.825, 'eval_steps_per_second': 667.978, 'epoch': 145.57}


 60%|██████    | 12000/20000 [03:27<02:18, 57.70it/s]

{'loss': 0.1997, 'grad_norm': 0.4802778363227844, 'learning_rate': 4.000000000000001e-06, 'epoch': 151.9}


                                                     
 60%|██████    | 12006/20000 [03:27<03:35, 37.04it/s]

{'eval_loss': 0.34482911229133606, 'eval_accuracy': 0.858, 'eval_runtime': 0.18, 'eval_samples_per_second': 5556.79, 'eval_steps_per_second': 694.599, 'epoch': 151.9}


 62%|██████▎   | 12500/20000 [03:36<02:01, 61.71it/s]

{'loss': 0.1946, 'grad_norm': 0.3216278553009033, 'learning_rate': 3.7500000000000005e-06, 'epoch': 158.23}


                                                     
 63%|██████▎   | 12510/20000 [03:36<03:01, 41.20it/s]

{'eval_loss': 0.3425210118293762, 'eval_accuracy': 0.858, 'eval_runtime': 0.1971, 'eval_samples_per_second': 5074.029, 'eval_steps_per_second': 634.254, 'epoch': 158.23}


 65%|██████▌   | 13000/20000 [03:44<01:56, 59.99it/s]

{'loss': 0.1909, 'grad_norm': 0.2988818287849426, 'learning_rate': 3.5e-06, 'epoch': 164.56}


                                                     
 65%|██████▌   | 13006/20000 [03:45<03:01, 38.52it/s]

{'eval_loss': 0.3404403626918793, 'eval_accuracy': 0.859, 'eval_runtime': 0.1848, 'eval_samples_per_second': 5411.879, 'eval_steps_per_second': 676.485, 'epoch': 164.56}


 68%|██████▊   | 13500/20000 [03:53<01:47, 60.32it/s]

{'loss': 0.1861, 'grad_norm': 0.30129945278167725, 'learning_rate': 3.2500000000000002e-06, 'epoch': 170.89}


                                                     
 68%|██████▊   | 13506/20000 [03:53<02:44, 39.43it/s]

{'eval_loss': 0.33855682611465454, 'eval_accuracy': 0.858, 'eval_runtime': 0.1759, 'eval_samples_per_second': 5686.567, 'eval_steps_per_second': 710.821, 'epoch': 170.89}


 70%|███████   | 14000/20000 [04:02<01:39, 60.09it/s]

{'loss': 0.1826, 'grad_norm': 0.2889377176761627, 'learning_rate': 3e-06, 'epoch': 177.22}


                                                     
 70%|███████   | 14005/20000 [04:02<02:39, 37.60it/s]

{'eval_loss': 0.33708348870277405, 'eval_accuracy': 0.859, 'eval_runtime': 0.1928, 'eval_samples_per_second': 5187.715, 'eval_steps_per_second': 648.464, 'epoch': 177.22}


 72%|███████▎  | 14500/20000 [04:10<01:33, 58.68it/s]

{'loss': 0.1795, 'grad_norm': 0.4545452892780304, 'learning_rate': 2.7500000000000004e-06, 'epoch': 183.54}


                                                     
 72%|███████▎  | 14500/20000 [04:11<01:33, 58.68it/s]

{'eval_loss': 0.33570167422294617, 'eval_accuracy': 0.859, 'eval_runtime': 0.1754, 'eval_samples_per_second': 5701.796, 'eval_steps_per_second': 712.725, 'epoch': 183.54}


 75%|███████▌  | 15000/20000 [04:19<01:25, 58.61it/s]

{'loss': 0.1763, 'grad_norm': 0.2664376497268677, 'learning_rate': 2.5e-06, 'epoch': 189.87}


                                                     
 75%|███████▌  | 15012/20000 [04:20<01:51, 44.75it/s]

{'eval_loss': 0.3345523476600647, 'eval_accuracy': 0.859, 'eval_runtime': 0.1774, 'eval_samples_per_second': 5637.937, 'eval_steps_per_second': 704.742, 'epoch': 189.87}


 78%|███████▊  | 15500/20000 [04:28<01:13, 60.96it/s]

{'loss': 0.1746, 'grad_norm': 0.35709065198898315, 'learning_rate': 2.25e-06, 'epoch': 196.2}


                                                     
 78%|███████▊  | 15509/20000 [04:28<01:44, 42.89it/s]

{'eval_loss': 0.3334924280643463, 'eval_accuracy': 0.859, 'eval_runtime': 0.1833, 'eval_samples_per_second': 5455.473, 'eval_steps_per_second': 681.934, 'epoch': 196.2}


 80%|████████  | 16000/20000 [04:37<01:07, 59.56it/s]

{'loss': 0.1711, 'grad_norm': 0.35248440504074097, 'learning_rate': 2.0000000000000003e-06, 'epoch': 202.53}


                                                     
 80%|████████  | 16009/20000 [04:37<01:33, 42.61it/s]

{'eval_loss': 0.3325043022632599, 'eval_accuracy': 0.858, 'eval_runtime': 0.1784, 'eval_samples_per_second': 5603.852, 'eval_steps_per_second': 700.482, 'epoch': 202.53}


 82%|████████▎ | 16500/20000 [04:45<00:58, 59.72it/s]

{'loss': 0.1689, 'grad_norm': 0.3302529752254486, 'learning_rate': 1.75e-06, 'epoch': 208.86}


                                                     
 83%|████████▎ | 16512/20000 [04:46<01:19, 43.75it/s]

{'eval_loss': 0.33159294724464417, 'eval_accuracy': 0.858, 'eval_runtime': 0.1875, 'eval_samples_per_second': 5332.757, 'eval_steps_per_second': 666.595, 'epoch': 208.86}


 85%|████████▌ | 17000/20000 [04:54<00:49, 61.05it/s]

{'loss': 0.1669, 'grad_norm': 0.4367532432079315, 'learning_rate': 1.5e-06, 'epoch': 215.19}


                                                     
 85%|████████▌ | 17005/20000 [04:54<01:16, 39.34it/s]

{'eval_loss': 0.33103135228157043, 'eval_accuracy': 0.859, 'eval_runtime': 0.1749, 'eval_samples_per_second': 5718.284, 'eval_steps_per_second': 714.785, 'epoch': 215.19}


 88%|████████▊ | 17500/20000 [05:03<00:42, 58.21it/s]

{'loss': 0.1665, 'grad_norm': 0.282416969537735, 'learning_rate': 1.25e-06, 'epoch': 221.52}


                                                     
 88%|████████▊ | 17510/20000 [05:03<00:58, 42.50it/s]

{'eval_loss': 0.3305148780345917, 'eval_accuracy': 0.859, 'eval_runtime': 0.18, 'eval_samples_per_second': 5554.803, 'eval_steps_per_second': 694.35, 'epoch': 221.52}


 90%|█████████ | 18000/20000 [05:11<00:32, 60.61it/s]

{'loss': 0.1644, 'grad_norm': 0.32570672035217285, 'learning_rate': 1.0000000000000002e-06, 'epoch': 227.85}


                                                     
 90%|█████████ | 18006/20000 [05:12<00:52, 37.69it/s]

{'eval_loss': 0.3300652503967285, 'eval_accuracy': 0.859, 'eval_runtime': 0.1973, 'eval_samples_per_second': 5069.129, 'eval_steps_per_second': 633.641, 'epoch': 227.85}


 92%|█████████▎| 18500/20000 [05:20<00:24, 60.18it/s]

{'loss': 0.1635, 'grad_norm': 0.3053121268749237, 'learning_rate': 7.5e-07, 'epoch': 234.18}


                                                     
 93%|█████████▎| 18508/20000 [05:20<00:35, 42.28it/s]

{'eval_loss': 0.3297293186187744, 'eval_accuracy': 0.859, 'eval_runtime': 0.1876, 'eval_samples_per_second': 5330.148, 'eval_steps_per_second': 666.268, 'epoch': 234.18}


 95%|█████████▌| 19000/20000 [05:29<00:16, 59.98it/s]

{'loss': 0.1625, 'grad_norm': 0.3006148636341095, 'learning_rate': 5.000000000000001e-07, 'epoch': 240.51}


                                                     
 95%|█████████▌| 19010/20000 [05:29<00:23, 41.66it/s]

{'eval_loss': 0.3294815719127655, 'eval_accuracy': 0.859, 'eval_runtime': 0.1851, 'eval_samples_per_second': 5401.391, 'eval_steps_per_second': 675.174, 'epoch': 240.51}


 98%|█████████▊| 19500/20000 [05:38<00:08, 58.13it/s]

{'loss': 0.1627, 'grad_norm': 0.2814271152019501, 'learning_rate': 2.5000000000000004e-07, 'epoch': 246.84}


                                                     
 98%|█████████▊| 19509/20000 [05:38<00:11, 42.96it/s]

{'eval_loss': 0.3293549120426178, 'eval_accuracy': 0.859, 'eval_runtime': 0.1756, 'eval_samples_per_second': 5695.958, 'eval_steps_per_second': 711.995, 'epoch': 246.84}


100%|██████████| 20000/20000 [05:47<00:00, 58.48it/s]

{'loss': 0.1621, 'grad_norm': 0.286687970161438, 'learning_rate': 0.0, 'epoch': 253.16}


                                                     
100%|██████████| 20000/20000 [05:47<00:00, 57.50it/s]


{'eval_loss': 0.3293077051639557, 'eval_accuracy': 0.859, 'eval_runtime': 0.1771, 'eval_samples_per_second': 5647.845, 'eval_steps_per_second': 705.981, 'epoch': 253.16}
{'train_runtime': 347.8199, 'train_samples_per_second': 7360.13, 'train_steps_per_second': 57.501, 'train_loss': 0.27762306060791014, 'epoch': 253.16}


100%|██████████| 3125/3125 [00:04<00:00, 700.42it/s]
dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)


Training with lr=0.0001 and batch_size=32


  2%|▎         | 500/20000 [00:03<02:03, 157.85it/s]

{'loss': 0.5408, 'grad_norm': 0.7861588597297668, 'learning_rate': 9.75e-05, 'epoch': 1.6}


                                                    
  3%|▎         | 519/20000 [00:03<02:53, 112.60it/s]

{'eval_loss': 0.45812302827835083, 'eval_accuracy': 0.833, 'eval_runtime': 0.1905, 'eval_samples_per_second': 5248.22, 'eval_steps_per_second': 656.028, 'epoch': 1.6}


  5%|▌         | 1000/20000 [00:06<01:55, 164.65it/s]

{'loss': 0.3236, 'grad_norm': 0.7274240255355835, 'learning_rate': 9.5e-05, 'epoch': 3.19}


                                                     
  5%|▌         | 1022/20000 [00:06<02:42, 117.14it/s]

{'eval_loss': 0.35962817072868347, 'eval_accuracy': 0.857, 'eval_runtime': 0.181, 'eval_samples_per_second': 5524.338, 'eval_steps_per_second': 690.542, 'epoch': 3.19}


  8%|▊         | 1500/20000 [00:10<02:15, 137.03it/s]

{'loss': 0.2178, 'grad_norm': 1.1432832479476929, 'learning_rate': 9.250000000000001e-05, 'epoch': 4.79}


                                                     
  8%|▊         | 1516/20000 [00:10<03:00, 102.21it/s]

{'eval_loss': 0.32266512513160706, 'eval_accuracy': 0.869, 'eval_runtime': 0.1905, 'eval_samples_per_second': 5249.823, 'eval_steps_per_second': 656.228, 'epoch': 4.79}


 10%|█         | 2000/20000 [00:13<01:41, 177.64it/s]

{'loss': 0.1575, 'grad_norm': 0.7394194006919861, 'learning_rate': 9e-05, 'epoch': 6.39}


                                                     
 10%|█         | 2000/20000 [00:13<01:41, 177.64it/s]

{'eval_loss': 0.31127482652664185, 'eval_accuracy': 0.863, 'eval_runtime': 0.1725, 'eval_samples_per_second': 5798.075, 'eval_steps_per_second': 724.759, 'epoch': 6.39}


 12%|█▎        | 2500/20000 [00:16<01:46, 163.88it/s]

{'loss': 0.1199, 'grad_norm': 0.617310643196106, 'learning_rate': 8.75e-05, 'epoch': 7.99}


                                                     
 13%|█▎        | 2530/20000 [00:16<02:31, 115.14it/s]

{'eval_loss': 0.3094245195388794, 'eval_accuracy': 0.862, 'eval_runtime': 0.1937, 'eval_samples_per_second': 5162.093, 'eval_steps_per_second': 645.262, 'epoch': 7.99}


 15%|█▌        | 3000/20000 [00:19<01:38, 173.32it/s]

{'loss': 0.0893, 'grad_norm': 0.57343590259552, 'learning_rate': 8.5e-05, 'epoch': 9.58}


                                                     
 15%|█▌        | 3000/20000 [00:19<01:38, 173.32it/s]

{'eval_loss': 0.3120001256465912, 'eval_accuracy': 0.861, 'eval_runtime': 0.1689, 'eval_samples_per_second': 5921.381, 'eval_steps_per_second': 740.173, 'epoch': 9.58}


 18%|█▊        | 3500/20000 [00:22<01:41, 162.52it/s]

{'loss': 0.0705, 'grad_norm': 0.46658071875572205, 'learning_rate': 8.25e-05, 'epoch': 11.18}


                                                     
 18%|█▊        | 3521/20000 [00:23<02:26, 112.51it/s]

{'eval_loss': 0.31940189003944397, 'eval_accuracy': 0.86, 'eval_runtime': 0.1779, 'eval_samples_per_second': 5621.11, 'eval_steps_per_second': 702.639, 'epoch': 11.18}


 20%|██        | 4000/20000 [00:25<01:22, 193.20it/s]

{'loss': 0.054, 'grad_norm': 0.44713687896728516, 'learning_rate': 8e-05, 'epoch': 12.78}


                                                     
 20%|██        | 4000/20000 [00:25<01:22, 193.20it/s]

{'eval_loss': 0.32911986112594604, 'eval_accuracy': 0.861, 'eval_runtime': 0.1705, 'eval_samples_per_second': 5864.486, 'eval_steps_per_second': 733.061, 'epoch': 12.78}


 22%|██▎       | 4500/20000 [00:28<01:30, 172.08it/s]

{'loss': 0.0421, 'grad_norm': 0.35886329412460327, 'learning_rate': 7.75e-05, 'epoch': 14.38}


                                                     
 23%|██▎       | 4514/20000 [00:28<02:23, 108.05it/s]

{'eval_loss': 0.33854779601097107, 'eval_accuracy': 0.854, 'eval_runtime': 0.1818, 'eval_samples_per_second': 5501.6, 'eval_steps_per_second': 687.7, 'epoch': 14.38}


 25%|██▌       | 5000/20000 [00:31<01:39, 150.39it/s]

{'loss': 0.0337, 'grad_norm': 0.42198750376701355, 'learning_rate': 7.500000000000001e-05, 'epoch': 15.97}


                                                     
 25%|██▌       | 5000/20000 [00:31<01:35, 156.68it/s]


{'eval_loss': 0.35233137011528015, 'eval_accuracy': 0.856, 'eval_runtime': 0.1932, 'eval_samples_per_second': 5174.857, 'eval_steps_per_second': 646.857, 'epoch': 15.97}
{'train_runtime': 31.9104, 'train_samples_per_second': 20056.183, 'train_steps_per_second': 626.756, 'train_loss': 0.1649141445159912, 'epoch': 15.97}


100%|██████████| 3125/3125 [00:04<00:00, 694.07it/s]
dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)


Training with lr=0.0001 and batch_size=64


  2%|▎         | 500/20000 [00:05<03:24, 95.21it/s] 

{'loss': 0.5212, 'grad_norm': 0.4938310980796814, 'learning_rate': 9.75e-05, 'epoch': 3.18}


                                                   
  3%|▎         | 515/20000 [00:05<04:41, 69.19it/s]

{'eval_loss': 0.4411846399307251, 'eval_accuracy': 0.841, 'eval_runtime': 0.1899, 'eval_samples_per_second': 5265.429, 'eval_steps_per_second': 658.179, 'epoch': 3.18}


  5%|▌         | 1000/20000 [00:10<03:05, 102.51it/s]

{'loss': 0.284, 'grad_norm': 0.4643971920013428, 'learning_rate': 9.5e-05, 'epoch': 6.37}


                                                     
  5%|▌         | 1010/20000 [00:10<05:11, 61.04it/s] 

{'eval_loss': 0.3471643924713135, 'eval_accuracy': 0.856, 'eval_runtime': 0.1977, 'eval_samples_per_second': 5058.547, 'eval_steps_per_second': 632.318, 'epoch': 6.37}


  8%|▊         | 1500/20000 [00:15<03:08, 98.10it/s] 

{'loss': 0.1818, 'grad_norm': 0.5724911093711853, 'learning_rate': 9.250000000000001e-05, 'epoch': 9.55}


                                                    
  8%|▊         | 1517/20000 [00:16<04:19, 71.35it/s]

{'eval_loss': 0.3187406361103058, 'eval_accuracy': 0.864, 'eval_runtime': 0.1826, 'eval_samples_per_second': 5476.408, 'eval_steps_per_second': 684.551, 'epoch': 9.55}


 10%|█         | 2000/20000 [00:20<02:46, 107.85it/s]

{'loss': 0.1256, 'grad_norm': 0.23596730828285217, 'learning_rate': 9e-05, 'epoch': 12.74}


                                                     
 10%|█         | 2000/20000 [00:21<02:46, 107.85it/s]

{'eval_loss': 0.3144744336605072, 'eval_accuracy': 0.851, 'eval_runtime': 0.1635, 'eval_samples_per_second': 6114.476, 'eval_steps_per_second': 764.309, 'epoch': 12.74}


 12%|█▎        | 2500/20000 [00:25<03:03, 95.20it/s] 

{'loss': 0.0904, 'grad_norm': 0.2999766767024994, 'learning_rate': 8.75e-05, 'epoch': 15.92}


                                                    
 13%|█▎        | 2517/20000 [00:26<04:18, 67.67it/s]

{'eval_loss': 0.3161222040653229, 'eval_accuracy': 0.855, 'eval_runtime': 0.1908, 'eval_samples_per_second': 5240.96, 'eval_steps_per_second': 655.12, 'epoch': 15.92}


 15%|█▌        | 3000/20000 [00:31<02:50, 99.78it/s] 

{'loss': 0.0656, 'grad_norm': 0.3559325635433197, 'learning_rate': 8.5e-05, 'epoch': 19.11}


                                                    
 15%|█▌        | 3014/20000 [00:31<03:49, 73.90it/s]

{'eval_loss': 0.3231002986431122, 'eval_accuracy': 0.855, 'eval_runtime': 0.176, 'eval_samples_per_second': 5680.514, 'eval_steps_per_second': 710.064, 'epoch': 19.11}


 18%|█▊        | 3500/20000 [00:36<02:28, 111.34it/s]

{'loss': 0.0484, 'grad_norm': 0.20271587371826172, 'learning_rate': 8.25e-05, 'epoch': 22.29}


                                                     
 18%|█▊        | 3500/20000 [00:36<02:28, 111.34it/s]

{'eval_loss': 0.3359217643737793, 'eval_accuracy': 0.858, 'eval_runtime': 0.1698, 'eval_samples_per_second': 5888.205, 'eval_steps_per_second': 736.026, 'epoch': 22.29}


 20%|██        | 4000/20000 [00:41<02:32, 104.68it/s]

{'loss': 0.0361, 'grad_norm': 0.09327216446399689, 'learning_rate': 8e-05, 'epoch': 25.48}


                                                     
 20%|██        | 4009/20000 [00:41<03:56, 67.56it/s] 

{'eval_loss': 0.34771981835365295, 'eval_accuracy': 0.859, 'eval_runtime': 0.1712, 'eval_samples_per_second': 5840.689, 'eval_steps_per_second': 730.086, 'epoch': 25.48}


 22%|██▎       | 4500/20000 [00:46<02:40, 96.38it/s] 

{'loss': 0.0273, 'grad_norm': 0.13128013908863068, 'learning_rate': 7.75e-05, 'epoch': 28.66}


                                                    
 22%|██▎       | 4500/20000 [00:46<02:40, 96.70it/s]


{'eval_loss': 0.3644729256629944, 'eval_accuracy': 0.853, 'eval_runtime': 0.1959, 'eval_samples_per_second': 5104.671, 'eval_steps_per_second': 638.084, 'epoch': 28.66}
{'train_runtime': 46.539, 'train_samples_per_second': 27503.806, 'train_steps_per_second': 429.747, 'train_loss': 0.15337608761257596, 'epoch': 28.66}


100%|██████████| 3125/3125 [00:04<00:00, 712.53it/s]
dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)


Training with lr=0.0001 and batch_size=128


  2%|▎         | 500/20000 [00:08<05:25, 59.97it/s]

{'loss': 0.4613, 'grad_norm': 0.40016743540763855, 'learning_rate': 9.75e-05, 'epoch': 6.33}


                                                   
  3%|▎         | 507/20000 [00:08<08:31, 38.12it/s]

{'eval_loss': 0.4021058976650238, 'eval_accuracy': 0.851, 'eval_runtime': 0.1989, 'eval_samples_per_second': 5028.418, 'eval_steps_per_second': 628.552, 'epoch': 6.33}


  5%|▌         | 1000/20000 [00:16<05:01, 62.94it/s]

{'loss': 0.2277, 'grad_norm': 0.3008047342300415, 'learning_rate': 9.5e-05, 'epoch': 12.66}


                                                    
  5%|▌         | 1000/20000 [00:17<05:01, 62.94it/s]

{'eval_loss': 0.3320286273956299, 'eval_accuracy': 0.861, 'eval_runtime': 0.1712, 'eval_samples_per_second': 5841.885, 'eval_steps_per_second': 730.236, 'epoch': 12.66}


  8%|▊         | 1500/20000 [00:25<04:53, 63.00it/s]

{'loss': 0.1375, 'grad_norm': 0.21533240377902985, 'learning_rate': 9.250000000000001e-05, 'epoch': 18.99}


                                                    
  8%|▊         | 1500/20000 [00:25<04:53, 63.00it/s]

{'eval_loss': 0.31638434529304504, 'eval_accuracy': 0.858, 'eval_runtime': 0.1773, 'eval_samples_per_second': 5640.978, 'eval_steps_per_second': 705.122, 'epoch': 18.99}


 10%|█         | 2000/20000 [00:33<04:34, 65.49it/s]

{'loss': 0.0893, 'grad_norm': 0.2286618947982788, 'learning_rate': 9e-05, 'epoch': 25.32}


                                                    
 10%|█         | 2005/20000 [00:33<07:07, 42.06it/s]

{'eval_loss': 0.31852489709854126, 'eval_accuracy': 0.86, 'eval_runtime': 0.1767, 'eval_samples_per_second': 5660.429, 'eval_steps_per_second': 707.554, 'epoch': 25.32}


 12%|█▎        | 2500/20000 [00:41<04:35, 63.45it/s]

{'loss': 0.0611, 'grad_norm': 0.20441140234470367, 'learning_rate': 8.75e-05, 'epoch': 31.65}


                                                    
 12%|█▎        | 2500/20000 [00:41<04:35, 63.45it/s]

{'eval_loss': 0.3283809423446655, 'eval_accuracy': 0.855, 'eval_runtime': 0.1575, 'eval_samples_per_second': 6349.849, 'eval_steps_per_second': 793.731, 'epoch': 31.65}


 15%|█▌        | 3000/20000 [00:49<04:31, 62.57it/s]

{'loss': 0.0421, 'grad_norm': 0.15556249022483826, 'learning_rate': 8.5e-05, 'epoch': 37.97}


                                                    
 15%|█▌        | 3000/20000 [00:49<04:31, 62.57it/s]

{'eval_loss': 0.3419542610645294, 'eval_accuracy': 0.847, 'eval_runtime': 0.1652, 'eval_samples_per_second': 6053.575, 'eval_steps_per_second': 756.697, 'epoch': 37.97}


 18%|█▊        | 3500/20000 [00:58<04:20, 63.43it/s]

{'loss': 0.0296, 'grad_norm': 0.1339619755744934, 'learning_rate': 8.25e-05, 'epoch': 44.3}


                                                    
 18%|█▊        | 3500/20000 [00:58<04:20, 63.43it/s]

{'eval_loss': 0.35831522941589355, 'eval_accuracy': 0.849, 'eval_runtime': 0.1696, 'eval_samples_per_second': 5895.348, 'eval_steps_per_second': 736.918, 'epoch': 44.3}


 20%|██        | 4000/20000 [01:06<04:14, 62.87it/s]

{'loss': 0.0213, 'grad_norm': 0.12691959738731384, 'learning_rate': 8e-05, 'epoch': 50.63}


                                                    
 20%|██        | 4000/20000 [01:06<04:14, 62.87it/s]

{'eval_loss': 0.37621381878852844, 'eval_accuracy': 0.85, 'eval_runtime': 0.1746, 'eval_samples_per_second': 5728.577, 'eval_steps_per_second': 716.072, 'epoch': 50.63}


 20%|██        | 4000/20000 [01:06<04:26, 60.03it/s]


{'train_runtime': 66.6312, 'train_samples_per_second': 38420.412, 'train_steps_per_second': 300.159, 'train_loss': 0.13373650646209717, 'epoch': 50.63}


100%|██████████| 3125/3125 [00:04<00:00, 779.27it/s]
dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)


Training with lr=0.001 and batch_size=32


  2%|▎         | 500/20000 [00:02<01:46, 182.51it/s]

{'loss': 0.2672, 'grad_norm': 0.2694037854671478, 'learning_rate': 0.000975, 'epoch': 1.6}


                                                    
  3%|▎         | 532/20000 [00:03<02:28, 130.92it/s]

{'eval_loss': 0.35813531279563904, 'eval_accuracy': 0.845, 'eval_runtime': 0.1796, 'eval_samples_per_second': 5568.934, 'eval_steps_per_second': 696.117, 'epoch': 1.6}


  5%|▌         | 1000/20000 [00:05<01:44, 181.54it/s]

{'loss': 0.0595, 'grad_norm': 0.07922583073377609, 'learning_rate': 0.00095, 'epoch': 3.19}


                                                     
  5%|▌         | 1000/20000 [00:05<01:44, 181.54it/s]

{'eval_loss': 0.446895033121109, 'eval_accuracy': 0.844, 'eval_runtime': 0.1672, 'eval_samples_per_second': 5981.907, 'eval_steps_per_second': 747.738, 'epoch': 3.19}


  8%|▊         | 1500/20000 [00:08<01:38, 188.58it/s]

{'loss': 0.0085, 'grad_norm': 0.10368716716766357, 'learning_rate': 0.000925, 'epoch': 4.79}


                                                     
  8%|▊         | 1500/20000 [00:08<01:38, 188.58it/s]

{'eval_loss': 0.5269079804420471, 'eval_accuracy': 0.839, 'eval_runtime': 0.1725, 'eval_samples_per_second': 5798.075, 'eval_steps_per_second': 724.759, 'epoch': 4.79}


 10%|█         | 2000/20000 [00:11<01:39, 181.01it/s]

{'loss': 0.0024, 'grad_norm': 0.019374113529920578, 'learning_rate': 0.0009000000000000001, 'epoch': 6.39}


                                                     
 10%|█         | 2000/20000 [00:11<01:39, 181.01it/s]

{'eval_loss': 0.5750563740730286, 'eval_accuracy': 0.84, 'eval_runtime': 0.1765, 'eval_samples_per_second': 5665.766, 'eval_steps_per_second': 708.221, 'epoch': 6.39}


 12%|█▎        | 2500/20000 [00:14<01:34, 185.59it/s]

{'loss': 0.0012, 'grad_norm': 0.00679285591468215, 'learning_rate': 0.000875, 'epoch': 7.99}


                                                     
 12%|█▎        | 2500/20000 [00:14<01:34, 185.59it/s]


{'eval_loss': 0.6164447069168091, 'eval_accuracy': 0.843, 'eval_runtime': 0.1772, 'eval_samples_per_second': 5644.622, 'eval_steps_per_second': 705.578, 'epoch': 7.99}


 15%|█▌        | 3000/20000 [00:17<01:39, 170.53it/s]

{'loss': 0.0007, 'grad_norm': 0.004224023316055536, 'learning_rate': 0.00085, 'epoch': 9.58}


                                                     
 15%|█▌        | 3000/20000 [00:17<01:40, 168.85it/s]


{'eval_loss': 0.649470329284668, 'eval_accuracy': 0.842, 'eval_runtime': 0.1689, 'eval_samples_per_second': 5922.151, 'eval_steps_per_second': 740.269, 'epoch': 9.58}
{'train_runtime': 17.7681, 'train_samples_per_second': 36019.686, 'train_steps_per_second': 1125.615, 'train_loss': 0.056590605795383456, 'epoch': 9.58}


100%|██████████| 3125/3125 [00:03<00:00, 793.09it/s]
dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)


Training with lr=0.001 and batch_size=64


  2%|▎         | 500/20000 [00:04<02:45, 117.83it/s]

{'loss': 0.1867, 'grad_norm': 0.09874967485666275, 'learning_rate': 0.000975, 'epoch': 3.18}


                                                    
  2%|▎         | 500/20000 [00:04<02:45, 117.83it/s]

{'eval_loss': 0.37103894352912903, 'eval_accuracy': 0.852, 'eval_runtime': 0.1716, 'eval_samples_per_second': 5828.458, 'eval_steps_per_second': 728.557, 'epoch': 3.18}


  5%|▌         | 1000/20000 [00:09<02:42, 116.93it/s]

{'loss': 0.0128, 'grad_norm': 0.04546279087662697, 'learning_rate': 0.00095, 'epoch': 6.37}


                                                     
  5%|▌         | 1000/20000 [00:09<02:42, 116.93it/s]

{'eval_loss': 0.4758375585079193, 'eval_accuracy': 0.842, 'eval_runtime': 0.1663, 'eval_samples_per_second': 6012.744, 'eval_steps_per_second': 751.593, 'epoch': 6.37}


  8%|▊         | 1500/20000 [00:13<02:50, 108.47it/s]

{'loss': 0.0028, 'grad_norm': 0.01981068029999733, 'learning_rate': 0.000925, 'epoch': 9.55}


                                                     
  8%|▊         | 1518/20000 [00:14<04:05, 75.14it/s] 

{'eval_loss': 0.5451767444610596, 'eval_accuracy': 0.843, 'eval_runtime': 0.2042, 'eval_samples_per_second': 4896.187, 'eval_steps_per_second': 612.023, 'epoch': 9.55}


 10%|█         | 2000/20000 [00:18<02:29, 120.37it/s]

{'loss': 0.0012, 'grad_norm': 0.006133056711405516, 'learning_rate': 0.0009000000000000001, 'epoch': 12.74}


                                                     
 10%|█         | 2002/20000 [00:18<03:45, 79.85it/s] 

{'eval_loss': 0.5910115838050842, 'eval_accuracy': 0.842, 'eval_runtime': 0.1679, 'eval_samples_per_second': 5955.272, 'eval_steps_per_second': 744.409, 'epoch': 12.74}


 12%|█▎        | 2500/20000 [00:22<02:32, 115.05it/s]

{'loss': 0.0007, 'grad_norm': 0.006309365853667259, 'learning_rate': 0.000875, 'epoch': 15.92}


                                                     
 13%|█▎        | 2522/20000 [00:23<03:26, 84.73it/s] 

{'eval_loss': 0.631594181060791, 'eval_accuracy': 0.841, 'eval_runtime': 0.1697, 'eval_samples_per_second': 5893.144, 'eval_steps_per_second': 736.643, 'epoch': 15.92}


 15%|█▌        | 3000/20000 [00:27<02:28, 114.50it/s]

{'loss': 0.0004, 'grad_norm': 0.004911419935524464, 'learning_rate': 0.00085, 'epoch': 19.11}


                                                     
 15%|█▌        | 3000/20000 [00:27<02:28, 114.50it/s]

{'eval_loss': 0.6680358648300171, 'eval_accuracy': 0.838, 'eval_runtime': 0.1681, 'eval_samples_per_second': 5949.08, 'eval_steps_per_second': 743.635, 'epoch': 19.11}


 15%|█▌        | 3000/20000 [00:27<02:37, 108.07it/s]


{'train_runtime': 27.7591, 'train_samples_per_second': 46111.009, 'train_steps_per_second': 720.485, 'train_loss': 0.034116939589381216, 'epoch': 19.11}


100%|██████████| 3125/3125 [00:04<00:00, 777.59it/s]
dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)


Training with lr=0.001 and batch_size=128


  2%|▎         | 500/20000 [00:08<05:09, 62.91it/s]

{'loss': 0.1398, 'grad_norm': 0.07805134356021881, 'learning_rate': 0.000975, 'epoch': 6.33}


                                                   
  3%|▎         | 511/20000 [00:08<07:14, 44.81it/s]

{'eval_loss': 0.39274823665618896, 'eval_accuracy': 0.841, 'eval_runtime': 0.1754, 'eval_samples_per_second': 5701.153, 'eval_steps_per_second': 712.644, 'epoch': 6.33}


  5%|▌         | 1000/20000 [00:16<05:14, 60.36it/s]

{'loss': 0.0067, 'grad_norm': 0.011366828344762325, 'learning_rate': 0.00095, 'epoch': 12.66}


                                                    
  5%|▌         | 1000/20000 [00:16<05:14, 60.36it/s]

{'eval_loss': 0.49432098865509033, 'eval_accuracy': 0.847, 'eval_runtime': 0.167, 'eval_samples_per_second': 5987.56, 'eval_steps_per_second': 748.445, 'epoch': 12.66}


  8%|▊         | 1500/20000 [00:24<05:00, 61.48it/s]

{'loss': 0.0019, 'grad_norm': 0.006835592444986105, 'learning_rate': 0.000925, 'epoch': 18.99}


                                                    
  8%|▊         | 1500/20000 [00:24<05:00, 61.48it/s]

{'eval_loss': 0.5626725554466248, 'eval_accuracy': 0.843, 'eval_runtime': 0.1623, 'eval_samples_per_second': 6161.878, 'eval_steps_per_second': 770.235, 'epoch': 18.99}


 10%|█         | 2000/20000 [00:32<04:51, 61.66it/s]

{'loss': 0.0009, 'grad_norm': 0.004158695228397846, 'learning_rate': 0.0009000000000000001, 'epoch': 25.32}


                                                    
 10%|█         | 2000/20000 [00:32<04:51, 61.66it/s]

{'eval_loss': 0.6110959053039551, 'eval_accuracy': 0.842, 'eval_runtime': 0.1656, 'eval_samples_per_second': 6036.881, 'eval_steps_per_second': 754.61, 'epoch': 25.32}


 12%|█▎        | 2500/20000 [00:41<04:51, 60.05it/s]

{'loss': 0.0005, 'grad_norm': 0.00270537449978292, 'learning_rate': 0.000875, 'epoch': 31.65}


                                                    
 12%|█▎        | 2500/20000 [00:41<04:51, 60.05it/s]

{'eval_loss': 0.6515752077102661, 'eval_accuracy': 0.841, 'eval_runtime': 0.1668, 'eval_samples_per_second': 5996.12, 'eval_steps_per_second': 749.515, 'epoch': 31.65}


 15%|█▌        | 3000/20000 [00:50<04:46, 59.39it/s]

{'loss': 0.0003, 'grad_norm': 0.0015523714246228337, 'learning_rate': 0.00085, 'epoch': 37.97}


                                                    
 15%|█▌        | 3000/20000 [00:50<04:46, 59.39it/s]

{'eval_loss': 0.6868419051170349, 'eval_accuracy': 0.837, 'eval_runtime': 0.1694, 'eval_samples_per_second': 5904.593, 'eval_steps_per_second': 738.074, 'epoch': 37.97}


 15%|█▌        | 3000/20000 [00:50<04:45, 59.59it/s]


{'train_runtime': 50.345, 'train_samples_per_second': 50849.101, 'train_steps_per_second': 397.259, 'train_loss': 0.025013207534948984, 'epoch': 37.97}


100%|██████████| 3125/3125 [00:04<00:00, 753.12it/s]


In [95]:
print(f"Best accuracy: {best_accuracy}")
print(f"Best learning rate: {best_learning_rate}")
print(f"Best batch size: {best_batch_size}")

Best accuracy: 0.87948
Best learning rate: 0.0001
Best batch size: 32


In [99]:
#Create a new model with the best hyperparameters

mlp_config=MLPConfig(vocab_size=len(vectorizer.vocabulary_),hidden_size=20,nlabels=2)
mlp=MLP(mlp_config)

trainer_args = transformers.TrainingArguments(
    "mlp_checkpoints", #save checkpoints here
    evaluation_strategy="steps",
    logging_strategy="steps",
    eval_steps=500,
    logging_steps=500,
    learning_rate=1e-4, #learning rate of the gradient descent
    max_steps=20000,
    load_best_model_at_end=True,
    per_device_train_batch_size=32
)

accuracy = evaluate.load("accuracy")

def compute_accuracy(outputs_and_labels):
    outputs, labels = outputs_and_labels
    predictions = np.argmax(outputs, axis=-1) #pick the index of the "winning" label
    return accuracy.compute(predictions=predictions, references=labels)

early_stopping = transformers.EarlyStoppingCallback(5)

trainer = transformers.Trainer(
    model=mlp,
    args=trainer_args,
    train_dataset=tokenized_data["train"],
    eval_dataset=tokenized_data["test"].select(range(5000)), #make a smaller subset to evaluate on
    compute_metrics=compute_accuracy,
    data_collator=collator,
    callbacks=[early_stopping]
)

trainer.train()

dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)
  2%|▎         | 500/20000 [00:03<02:02, 159.70it/s]

{'loss': 0.5513, 'grad_norm': 0.8755040764808655, 'learning_rate': 9.75e-05, 'epoch': 0.64}


                                                    
  3%|▎         | 518/20000 [00:04<05:53, 55.16it/s] 

{'eval_loss': 0.4453277289867401, 'eval_accuracy': 0.8516, 'eval_runtime': 0.8962, 'eval_samples_per_second': 5578.944, 'eval_steps_per_second': 697.368, 'epoch': 0.64}


  5%|▌         | 1000/20000 [00:07<01:58, 159.85it/s]

{'loss': 0.3605, 'grad_norm': 0.8528380393981934, 'learning_rate': 9.5e-05, 'epoch': 1.28}


                                                     
  5%|▌         | 1025/20000 [00:08<05:40, 55.79it/s] 

{'eval_loss': 0.34466803073883057, 'eval_accuracy': 0.869, 'eval_runtime': 0.8864, 'eval_samples_per_second': 5640.808, 'eval_steps_per_second': 705.101, 'epoch': 1.28}


  8%|▊         | 1500/20000 [00:10<01:47, 171.65it/s]

{'loss': 0.2763, 'grad_norm': 1.2068967819213867, 'learning_rate': 9.250000000000001e-05, 'epoch': 1.92}


                                                     
  8%|▊         | 1523/20000 [00:12<05:08, 59.98it/s] 

{'eval_loss': 0.3044004738330841, 'eval_accuracy': 0.879, 'eval_runtime': 0.8797, 'eval_samples_per_second': 5683.825, 'eval_steps_per_second': 710.478, 'epoch': 1.92}


 10%|█         | 2000/20000 [00:14<01:53, 159.14it/s]

{'loss': 0.2234, 'grad_norm': 0.7322660684585571, 'learning_rate': 9e-05, 'epoch': 2.56}


                                                     
 10%|█         | 2015/20000 [00:15<06:43, 44.62it/s] 

{'eval_loss': 0.2908095419406891, 'eval_accuracy': 0.8824, 'eval_runtime': 0.8823, 'eval_samples_per_second': 5667.141, 'eval_steps_per_second': 708.393, 'epoch': 2.56}


 12%|█▎        | 2500/20000 [00:18<01:40, 174.73it/s]

{'loss': 0.1957, 'grad_norm': 0.621422290802002, 'learning_rate': 8.75e-05, 'epoch': 3.2}


                                                     
 13%|█▎        | 2520/20000 [00:19<04:29, 64.75it/s] 

{'eval_loss': 0.2801074683666229, 'eval_accuracy': 0.8874, 'eval_runtime': 0.8073, 'eval_samples_per_second': 6193.625, 'eval_steps_per_second': 774.203, 'epoch': 3.2}


 15%|█▌        | 3000/20000 [00:22<01:41, 167.09it/s]

{'loss': 0.1724, 'grad_norm': 0.8161685466766357, 'learning_rate': 8.5e-05, 'epoch': 3.84}


                                                     
 15%|█▌        | 3028/20000 [00:23<04:43, 59.91it/s] 

{'eval_loss': 0.2778237760066986, 'eval_accuracy': 0.8874, 'eval_runtime': 0.8322, 'eval_samples_per_second': 6008.126, 'eval_steps_per_second': 751.016, 'epoch': 3.84}


 18%|█▊        | 3500/20000 [00:26<01:38, 167.12it/s]

{'loss': 0.1508, 'grad_norm': 0.6816261410713196, 'learning_rate': 8.25e-05, 'epoch': 4.48}


                                                     
 18%|█▊        | 3525/20000 [00:27<04:45, 57.68it/s] 

{'eval_loss': 0.2819083034992218, 'eval_accuracy': 0.888, 'eval_runtime': 0.9141, 'eval_samples_per_second': 5470.147, 'eval_steps_per_second': 683.768, 'epoch': 4.48}


 20%|██        | 4000/20000 [00:30<01:34, 169.35it/s]

{'loss': 0.136, 'grad_norm': 0.43170928955078125, 'learning_rate': 8e-05, 'epoch': 5.12}


                                                     
 20%|██        | 4022/20000 [00:31<04:30, 59.11it/s] 

{'eval_loss': 0.28243979811668396, 'eval_accuracy': 0.8848, 'eval_runtime': 0.8799, 'eval_samples_per_second': 5682.77, 'eval_steps_per_second': 710.346, 'epoch': 5.12}


 22%|██▎       | 4500/20000 [00:34<01:37, 158.36it/s]

{'loss': 0.1208, 'grad_norm': 1.1410670280456543, 'learning_rate': 7.75e-05, 'epoch': 5.75}


                                                     
 23%|██▎       | 4526/20000 [00:35<04:31, 56.95it/s] 

{'eval_loss': 0.28922900557518005, 'eval_accuracy': 0.885, 'eval_runtime': 0.8957, 'eval_samples_per_second': 5582.09, 'eval_steps_per_second': 697.761, 'epoch': 5.75}


 25%|██▌       | 5000/20000 [00:38<01:22, 181.65it/s]

{'loss': 0.1109, 'grad_norm': 1.0026087760925293, 'learning_rate': 7.500000000000001e-05, 'epoch': 6.39}


                                                     
 25%|██▌       | 5024/20000 [00:39<04:06, 60.79it/s] 

{'eval_loss': 0.2949044406414032, 'eval_accuracy': 0.8828, 'eval_runtime': 0.8948, 'eval_samples_per_second': 5587.905, 'eval_steps_per_second': 698.488, 'epoch': 6.39}


 28%|██▊       | 5500/20000 [00:42<01:36, 149.65it/s]

{'loss': 0.0994, 'grad_norm': 0.9777780175209045, 'learning_rate': 7.25e-05, 'epoch': 7.03}


                                                     
 28%|██▊       | 5500/20000 [00:43<01:54, 127.02it/s]

{'eval_loss': 0.30186718702316284, 'eval_accuracy': 0.8826, 'eval_runtime': 0.8137, 'eval_samples_per_second': 6144.438, 'eval_steps_per_second': 768.055, 'epoch': 7.03}
{'train_runtime': 43.3017, 'train_samples_per_second': 14780.019, 'train_steps_per_second': 461.876, 'train_loss': 0.21797148271040484, 'epoch': 7.03}





TrainOutput(global_step=5500, training_loss=0.21797148271040484, metrics={'train_runtime': 43.3017, 'train_samples_per_second': 14780.019, 'train_steps_per_second': 461.876, 'train_loss': 0.21797148271040484, 'epoch': 7.03})

### 3.3. Evaluation on test set

In [100]:
# Your code to evaluate the final model on the test set here
test_results = trainer.predict(tokenized_data["test"].select(range(5000)))
test_accuracy = compute_accuracy((test_results.predictions, test_results.label_ids))
print(f"Test accuracy: {test_accuracy}")

100%|██████████| 625/625 [00:00<00:00, 766.48it/s]

Test accuracy: {'accuracy': 0.8874}





In [101]:
#Convert the 10 first predictions to labels 
print("Predictions:", "\n", test_results.predictions[:10])
print("Binary predicted labels:", np.argmax(test_results.predictions[:10], axis=-1))


#Print first 10 true labels
true_labels = test_results.label_ids[:10]
print("True labels:", true_labels)


Predictions: 
 [[-1.650777    2.4864943 ]
 [ 0.36898008 -0.0657791 ]
 [ 0.26644838  0.04376237]
 [-1.0865515   1.7044837 ]
 [ 0.6988572  -0.5024215 ]
 [-0.9906143   1.6176448 ]
 [ 0.6038404  -0.3825643 ]
 [ 1.8174835  -1.8273368 ]
 [ 0.7308029  -0.5329372 ]
 [-0.6699647   1.1603235 ]]
Binary predicted labels: [1 0 0 1 0 1 0 0 0 1]
True labels: [1 1 0 1 0 1 1 0 0 1]


---

## 4. Results and summary

### 4.1 Corpus insights

The corpus "imdb" consists of movie reviews from IMDB: 25,000 positive and 25,000 negative reviews. Each entry consists of the review and the corresponding sentiment label. Only highly polarizing reviews are considered in this dataset - no neutral reviews are included. People have written movie reviews and given the movie a score from 1 to 10. Reviews with a score of <= 4 are labeled as negative and reviews with a score of >= 7 are labeled positive. No more than 30 reviews per movie are included.

### 4.2 Results

We got an evaluation accuracy of 88.02%. We performed hyperparameter tuning on a subset of the data and found the best learning rate (1e-4) and batch size (32), which got us to an accuracy of 88.74%. Since evaluation loss starts to rise early (and the training loss keeps decreasing), the model is most likely overfitting. 

### 4.3 Relation to state of the art

The state-of-the-art results of binary classifiers of the "imdb" dataset reach an accuracy of 96.68% with a RoBERTa-large with LlamBERT model. BERT is the standard state-of-the-art model in all NLP. BERT is a language model develped by Google so our accuracy with such a simple model can be viewed as a success. Nehal et al. got an accuracy of 86.74% in their paper using an MLP model.

---

## 5. Bonus Task (optional)

### 5.1. Annotating out-of-domain documents

(Briefly describe the chosen out-of-domain documents)

(Briefly describe the process of annotation)

### 5.2 Conversion into dataset

In [92]:
# Your code to convert the annotations into a dataset here

### 5.3. Model evaluation on out-of-domain test set

In [93]:
# Your code to evaluate the model on the out-of-domain test set here

### 5.4 Bonus task results

(Present the results of the evaluation on the out-of-domain test set)

### 5.5. Annotated data

In [94]:
# Include your annotated out-of-domain data here