<a href="https://colab.research.google.com/github/TurkuNLP/intro-to-nlp/blob/master/course_project_2023_template.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Introduction to HLT Project (Template)

- Student(s) Name(s):
- Date:
- Chosen Corpus:
- Contributions (if group project):

### Corpus information

- Description of the chosen corpus:
- Paper(s) and other published materials related to the corpus:
- State-of-the-art performance (best published results) on this corpus:

---

## 1. Setup

In [22]:
# Your code to install and import libraries etc. here
import datasets 
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.feature_extraction.text import CountVectorizer
import torch 
import transformers
import evaluate

---

## 2. Data download and preprocessing

### 2.1. Download the corpus

In [11]:
# Your code to download the corpus here
dset = datasets.load_dataset('imdb')
display(dset)


DatasetDict({
    train: Dataset({
        features: ['text', 'label'],
        num_rows: 25000
    })
    test: Dataset({
        features: ['text', 'label'],
        num_rows: 25000
    })
    unsupervised: Dataset({
        features: ['text', 'label'],
        num_rows: 50000
    })
})

### 2.2. Preprocessing

In [12]:
# Your code for any necessary preprocessing here
#Shuffle the dataset
dset = dset.shuffle(seed=42)
#Remove the unsupervised data part of the dataset as we dont need it for this task
del dset['unsupervised']

In [13]:
vectorizer = CountVectorizer(binary=True, max_features=25000)
text_list = [i['text'] for i in dset['train']]
vectorizer.fit(text_list)

def vectorize_example(examples, vectorizer): 
    vectorized = vectorizer.transform([examples["text"]])
    non_zero = vectorized.nonzero()[1]
    non_zero += 1
    return {'input_ids': non_zero}

# Conversion vocabulary 
idx2word = {v: k for k, v in vectorizer.vocabulary_.items()}

tokenized_data = dset.map(vectorize_example, num_proc=4, fn_kwargs={'vectorizer': vectorizer})

Map (num_proc=4): 100%|██████████| 25000/25000 [00:06<00:00, 4134.59 examples/s]
Map (num_proc=4): 100%|██████████| 25000/25000 [00:06<00:00, 4093.63 examples/s]


In [18]:
test_row = tokenized_data['train'][0]['input_ids']
convered_text = [idx2word[i] for i in test_row]
print(convered_text)

['above', 'action', 'actress', 'allah', 'americana', 'anders', 'area', 'argument', 'atari', 'beulah', 'bother', 'butch', 'bye', 'characterisation', 'clairvoyant', 'classical', 'compared', 'complicating', 'criminal', 'englishman', 'enjoyment', 'evaluated', 'factions', 'faraway', 'fur', 'goodbye', 'handbook', 'haven', 'howard', 'ifc', 'isaac', 'italian', 'judged', 'justice', 'languages', 'likeable', 'looming', 'maine', 'mayberry', 'moreau', 'noah', 'notable', 'onassis', 'oral', 'others', 'peoples', 'plotline', 'plotted', 'policeman', 'preferable', 'primed', 'quits', 'realm', 'relations', 'serio', 'similarity', 'simpler', 'spirited', 'spotlight', 'superficiality', 'suspected', 'thank', 'thatch', 'theater', 'thereafter', 'thick', 'things', 'thinker', 'tho', 'toad', 'took', 'violently', 'wayans', 'weak', 'weaken', 'weirdos', 'writings']


In [20]:
def collator(examples):
    batch = {"labels":torch.tensor(list(example["label"] for example in examples))}
    tensors = []
    max_len = max(len(example["input_ids"]) for example in examples)
    for example in examples:
        ids = torch.tensor(example["input_ids"])
        padded = torch.nn.functional.pad(ids, (0, max_len - ids.shape[0]))
        tensors.append(padded)
    batch["input_ids"] = torch.vstack(tensors)
    return batch

---

## 3. Machine learning model

### 3.1. Model training

In [23]:
# Your code to train the machine learning model on the training set and evaluate the performance on the validation set here

class MLPConfig(transformers.PretrainedConfig):
    pass
class MLP(transformers.PreTrainedModel):
    config_class=MLPConfig
    def __init__(self,config):
        super().__init__(config)
        self.vocab_size=config.vocab_size #embedding matrix row count
        self.embedding=torch.nn.Embedding(num_embeddings=self.vocab_size+1,embedding_dim=config.hidden_size,padding_idx=0)
        torch.nn.init.uniform_(self.embedding.weight.data,-0.001,0.001) 
        self.output=torch.nn.Linear(in_features=config.hidden_size,out_features=config.nlabels)

    def forward(self,input_ids,labels=None):
        embedded=self.embedding(input_ids)
        embedded_summed=torch.sum(embedded,dim=1)
        projected=torch.tanh(embedded_summed) 
        logits=self.output(projected)
        if labels is not None:
            loss=torch.nn.CrossEntropyLoss()
            return (loss(logits,labels),logits)
        else:
            return (logits,)

### 3.2 Hyperparameter optimization

In [24]:
# Your code for hyperparameter optimization here

mlp_config=MLPConfig(vocab_size=len(vectorizer.vocabulary_),hidden_size=20,nlabels=2)
mlp=MLP(mlp_config)

trainer_args = transformers.TrainingArguments(
    "mlp_checkpoints", #save checkpoints here
    evaluation_strategy="steps",
    logging_strategy="steps",
    eval_steps=500,
    logging_steps=500,
    learning_rate=1e-5, #learning rate of the gradient descent
    max_steps=20000,
    load_best_model_at_end=True,
    per_device_train_batch_size=128
)

accuracy = evaluate.load("accuracy")

def compute_accuracy(outputs_and_labels):
    outputs, labels = outputs_and_labels
    predictions = np.argmax(outputs, axis=-1) #pick the index of the "winning" label
    return accuracy.compute(predictions=predictions, references=labels)

mlp = MLP(mlp_config)


early_stopping = transformers.EarlyStoppingCallback(5)

trainer = transformers.Trainer(
    model=mlp,
    args=trainer_args,
    train_dataset=tokenized_data["train"],
    eval_dataset=tokenized_data["test"].select(range(5000)), #make a smaller subset to evaluate on
    compute_metrics=compute_accuracy,
    data_collator=collator,
    callbacks=[early_stopping]
)

trainer.train()

dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)
  2%|▎         | 500/20000 [00:08<05:28, 59.28it/s]

{'loss': 0.6663, 'grad_norm': 0.6451306343078613, 'learning_rate': 9.75e-06, 'epoch': 2.55}



  3%|▎         | 508/20000 [00:09<14:13, 22.83it/s]

{'eval_loss': 0.6431292295455933, 'eval_accuracy': 0.762, 'eval_runtime': 0.8556, 'eval_samples_per_second': 5843.516, 'eval_steps_per_second': 730.44, 'epoch': 2.55}


  5%|▌         | 1000/20000 [00:16<05:00, 63.21it/s]

{'loss': 0.6113, 'grad_norm': 0.49040278792381287, 'learning_rate': 9.5e-06, 'epoch': 5.1}


                                                    
  5%|▌         | 1006/20000 [00:17<16:14, 19.50it/s]

{'eval_loss': 0.5980353355407715, 'eval_accuracy': 0.8078, 'eval_runtime': 0.8072, 'eval_samples_per_second': 6194.386, 'eval_steps_per_second': 774.298, 'epoch': 5.1}


  8%|▊         | 1500/20000 [00:25<04:56, 62.30it/s]

{'loss': 0.5632, 'grad_norm': 0.5152321457862854, 'learning_rate': 9.250000000000001e-06, 'epoch': 7.65}


                                                    
  8%|▊         | 1507/20000 [00:26<15:50, 19.45it/s]

{'eval_loss': 0.5588062405586243, 'eval_accuracy': 0.8246, 'eval_runtime': 0.8021, 'eval_samples_per_second': 6233.971, 'eval_steps_per_second': 779.246, 'epoch': 7.65}


 10%|█         | 2000/20000 [00:34<04:42, 63.64it/s]

{'loss': 0.5217, 'grad_norm': 0.4457358121871948, 'learning_rate': 9e-06, 'epoch': 10.2}


                                                    
 10%|█         | 2010/20000 [00:35<12:17, 24.38it/s]

{'eval_loss': 0.5255573987960815, 'eval_accuracy': 0.835, 'eval_runtime': 0.8352, 'eval_samples_per_second': 5986.301, 'eval_steps_per_second': 748.288, 'epoch': 10.2}


 12%|█▎        | 2500/20000 [00:42<04:32, 64.23it/s]

{'loss': 0.4858, 'grad_norm': 0.5143020749092102, 'learning_rate': 8.750000000000001e-06, 'epoch': 12.76}


                                                    
 13%|█▎        | 2510/20000 [00:43<12:09, 23.98it/s]

{'eval_loss': 0.4970256984233856, 'eval_accuracy': 0.8402, 'eval_runtime': 0.8416, 'eval_samples_per_second': 5941.366, 'eval_steps_per_second': 742.671, 'epoch': 12.76}


 15%|█▌        | 3000/20000 [00:51<04:46, 59.40it/s]

{'loss': 0.4546, 'grad_norm': 0.46808943152427673, 'learning_rate': 8.5e-06, 'epoch': 15.31}


                                                    
 15%|█▌        | 3011/20000 [00:52<13:53, 20.39it/s]

{'eval_loss': 0.4724537134170532, 'eval_accuracy': 0.8472, 'eval_runtime': 0.947, 'eval_samples_per_second': 5279.778, 'eval_steps_per_second': 659.972, 'epoch': 15.31}


 18%|█▊        | 3500/20000 [01:00<04:16, 64.41it/s]

{'loss': 0.4286, 'grad_norm': 0.4440845847129822, 'learning_rate': 8.25e-06, 'epoch': 17.86}


                                                    
 18%|█▊        | 3506/20000 [01:01<14:09, 19.41it/s]

{'eval_loss': 0.45170971751213074, 'eval_accuracy': 0.8506, 'eval_runtime': 0.819, 'eval_samples_per_second': 6104.752, 'eval_steps_per_second': 763.094, 'epoch': 17.86}


 20%|██        | 4000/20000 [01:09<04:00, 66.43it/s]

{'loss': 0.4058, 'grad_norm': 0.42311277985572815, 'learning_rate': 8.000000000000001e-06, 'epoch': 20.41}


                                                    
 20%|██        | 4005/20000 [01:10<13:31, 19.71it/s]

{'eval_loss': 0.4336338937282562, 'eval_accuracy': 0.8532, 'eval_runtime': 0.8096, 'eval_samples_per_second': 6175.794, 'eval_steps_per_second': 771.974, 'epoch': 20.41}


 22%|██▎       | 4500/20000 [01:18<04:13, 61.20it/s]

{'loss': 0.3853, 'grad_norm': 0.4174938499927521, 'learning_rate': 7.75e-06, 'epoch': 22.96}


                                                    
 23%|██▎       | 4509/20000 [01:19<10:41, 24.15it/s]

{'eval_loss': 0.4179871082305908, 'eval_accuracy': 0.8568, 'eval_runtime': 0.8558, 'eval_samples_per_second': 5842.411, 'eval_steps_per_second': 730.301, 'epoch': 22.96}


 25%|██▌       | 5000/20000 [01:27<04:25, 56.47it/s]

{'loss': 0.3673, 'grad_norm': 0.40321213006973267, 'learning_rate': 7.500000000000001e-06, 'epoch': 25.51}


                                                    
 25%|██▌       | 5009/20000 [01:28<11:05, 22.52it/s]

{'eval_loss': 0.40429040789604187, 'eval_accuracy': 0.8588, 'eval_runtime': 0.8083, 'eval_samples_per_second': 6185.815, 'eval_steps_per_second': 773.227, 'epoch': 25.51}


 28%|██▊       | 5500/20000 [01:36<04:07, 58.50it/s]

{'loss': 0.3525, 'grad_norm': 0.4166144132614136, 'learning_rate': 7.25e-06, 'epoch': 28.06}


                                                    
 28%|██▊       | 5510/20000 [01:37<11:25, 21.15it/s]

{'eval_loss': 0.3923192024230957, 'eval_accuracy': 0.8608, 'eval_runtime': 0.9353, 'eval_samples_per_second': 5345.906, 'eval_steps_per_second': 668.238, 'epoch': 28.06}


 30%|███       | 6000/20000 [01:45<03:37, 64.23it/s]

{'loss': 0.3379, 'grad_norm': 0.5610472559928894, 'learning_rate': 7e-06, 'epoch': 30.61}


                                                    
 30%|███       | 6005/20000 [01:46<12:56, 18.02it/s]

{'eval_loss': 0.38174641132354736, 'eval_accuracy': 0.8636, 'eval_runtime': 0.8934, 'eval_samples_per_second': 5596.847, 'eval_steps_per_second': 699.606, 'epoch': 30.61}


 32%|███▎      | 6500/20000 [01:54<03:38, 61.86it/s]

{'loss': 0.3262, 'grad_norm': 0.4215192496776581, 'learning_rate': 6.750000000000001e-06, 'epoch': 33.16}


                                                    
 33%|███▎      | 6506/20000 [01:55<12:23, 18.15it/s]

{'eval_loss': 0.37255215644836426, 'eval_accuracy': 0.866, 'eval_runtime': 0.8666, 'eval_samples_per_second': 5769.589, 'eval_steps_per_second': 721.199, 'epoch': 33.16}


 35%|███▌      | 7000/20000 [02:03<03:34, 60.52it/s]

{'loss': 0.3145, 'grad_norm': 0.485082745552063, 'learning_rate': 6.5000000000000004e-06, 'epoch': 35.71}


                                                    
 35%|███▌      | 7011/20000 [02:04<09:30, 22.76it/s]

{'eval_loss': 0.3641456663608551, 'eval_accuracy': 0.8674, 'eval_runtime': 0.8777, 'eval_samples_per_second': 5696.583, 'eval_steps_per_second': 712.073, 'epoch': 35.71}


 38%|███▊      | 7500/20000 [02:12<03:15, 64.08it/s]

{'loss': 0.305, 'grad_norm': 0.4145914614200592, 'learning_rate': 6.25e-06, 'epoch': 38.27}


                                                    
 38%|███▊      | 7506/20000 [02:13<10:30, 19.81it/s]

{'eval_loss': 0.35685399174690247, 'eval_accuracy': 0.8694, 'eval_runtime': 0.793, 'eval_samples_per_second': 6305.44, 'eval_steps_per_second': 788.18, 'epoch': 38.27}


 40%|████      | 8000/20000 [02:20<03:01, 66.13it/s]

{'loss': 0.2954, 'grad_norm': 0.4086167514324188, 'learning_rate': 6e-06, 'epoch': 40.82}


                                                    
 40%|████      | 8009/20000 [02:21<07:51, 25.43it/s]

{'eval_loss': 0.3505323827266693, 'eval_accuracy': 0.87, 'eval_runtime': 0.8045, 'eval_samples_per_second': 6214.964, 'eval_steps_per_second': 776.871, 'epoch': 40.82}


 42%|████▎     | 8500/20000 [02:29<02:55, 65.63it/s]

{'loss': 0.288, 'grad_norm': 0.3813932538032532, 'learning_rate': 5.75e-06, 'epoch': 43.37}


                                                    
 43%|████▎     | 8506/20000 [02:30<09:30, 20.16it/s]

{'eval_loss': 0.34470951557159424, 'eval_accuracy': 0.87, 'eval_runtime': 0.8078, 'eval_samples_per_second': 6189.724, 'eval_steps_per_second': 773.715, 'epoch': 43.37}


 45%|████▌     | 9000/20000 [02:37<02:53, 63.40it/s]

{'loss': 0.2789, 'grad_norm': 0.3590758144855499, 'learning_rate': 5.500000000000001e-06, 'epoch': 45.92}


                                                    
 45%|████▌     | 9005/20000 [02:38<09:12, 19.88it/s]

{'eval_loss': 0.3396010100841522, 'eval_accuracy': 0.8712, 'eval_runtime': 0.7955, 'eval_samples_per_second': 6285.53, 'eval_steps_per_second': 785.691, 'epoch': 45.92}


 48%|████▊     | 9500/20000 [02:46<02:38, 66.39it/s]

{'loss': 0.2735, 'grad_norm': 0.4139081835746765, 'learning_rate': 5.2500000000000006e-06, 'epoch': 48.47}


                                                    
 48%|████▊     | 9510/20000 [02:47<06:40, 26.20it/s]

{'eval_loss': 0.334972083568573, 'eval_accuracy': 0.8724, 'eval_runtime': 0.7984, 'eval_samples_per_second': 6262.834, 'eval_steps_per_second': 782.854, 'epoch': 48.47}


 50%|█████     | 10000/20000 [02:55<02:36, 63.75it/s]

{'loss': 0.2665, 'grad_norm': 0.3797486126422882, 'learning_rate': 5e-06, 'epoch': 51.02}


                                                     
 50%|█████     | 10007/20000 [02:55<08:22, 19.89it/s]

{'eval_loss': 0.3309451639652252, 'eval_accuracy': 0.8726, 'eval_runtime': 0.7964, 'eval_samples_per_second': 6278.556, 'eval_steps_per_second': 784.82, 'epoch': 51.02}


 52%|█████▎    | 10500/20000 [03:03<02:31, 62.88it/s]

{'loss': 0.2611, 'grad_norm': 0.4067634344100952, 'learning_rate': 4.75e-06, 'epoch': 53.57}


                                                     
 53%|█████▎    | 10510/20000 [03:04<06:17, 25.12it/s]

{'eval_loss': 0.32723990082740784, 'eval_accuracy': 0.8732, 'eval_runtime': 0.7934, 'eval_samples_per_second': 6301.741, 'eval_steps_per_second': 787.718, 'epoch': 53.57}


 55%|█████▌    | 11000/20000 [03:12<02:20, 63.93it/s]

{'loss': 0.2562, 'grad_norm': 0.3543435037136078, 'learning_rate': 4.5e-06, 'epoch': 56.12}


                                                     
 55%|█████▌    | 11007/20000 [03:13<07:14, 20.68it/s]

{'eval_loss': 0.3239773213863373, 'eval_accuracy': 0.8738, 'eval_runtime': 0.7781, 'eval_samples_per_second': 6425.901, 'eval_steps_per_second': 803.238, 'epoch': 56.12}


 57%|█████▊    | 11500/20000 [03:20<02:10, 65.23it/s]

{'loss': 0.2509, 'grad_norm': 0.3678228557109833, 'learning_rate': 4.25e-06, 'epoch': 58.67}


                                                     
 58%|█████▊    | 11505/20000 [03:21<07:05, 19.96it/s]

{'eval_loss': 0.32110145688056946, 'eval_accuracy': 0.8754, 'eval_runtime': 0.7928, 'eval_samples_per_second': 6306.993, 'eval_steps_per_second': 788.374, 'epoch': 58.67}


 60%|██████    | 12000/20000 [03:29<02:01, 65.86it/s]

{'loss': 0.2473, 'grad_norm': 0.49710121750831604, 'learning_rate': 4.000000000000001e-06, 'epoch': 61.22}


                                                     
 60%|██████    | 12012/20000 [03:30<05:18, 25.09it/s]

{'eval_loss': 0.31844618916511536, 'eval_accuracy': 0.8754, 'eval_runtime': 0.8052, 'eval_samples_per_second': 6209.337, 'eval_steps_per_second': 776.167, 'epoch': 61.22}


 62%|██████▎   | 12500/20000 [03:37<01:52, 66.60it/s]

{'loss': 0.2427, 'grad_norm': 0.36851271986961365, 'learning_rate': 3.7500000000000005e-06, 'epoch': 63.78}


                                                     
 63%|██████▎   | 12511/20000 [03:38<04:40, 26.67it/s]

{'eval_loss': 0.3160717189311981, 'eval_accuracy': 0.8766, 'eval_runtime': 0.7848, 'eval_samples_per_second': 6371.005, 'eval_steps_per_second': 796.376, 'epoch': 63.78}


 65%|██████▌   | 13000/20000 [03:46<01:47, 65.14it/s]

{'loss': 0.2394, 'grad_norm': 0.4376172721385956, 'learning_rate': 3.5e-06, 'epoch': 66.33}


                                                     
 65%|██████▌   | 13005/20000 [03:47<05:48, 20.05it/s]

{'eval_loss': 0.31387367844581604, 'eval_accuracy': 0.8782, 'eval_runtime': 0.8015, 'eval_samples_per_second': 6238.006, 'eval_steps_per_second': 779.751, 'epoch': 66.33}


 68%|██████▊   | 13500/20000 [03:54<01:42, 63.54it/s]

{'loss': 0.2366, 'grad_norm': 0.32693105936050415, 'learning_rate': 3.2500000000000002e-06, 'epoch': 68.88}


                                                     
 68%|██████▊   | 13508/20000 [03:55<04:11, 25.77it/s]

{'eval_loss': 0.31212118268013, 'eval_accuracy': 0.8782, 'eval_runtime': 0.7696, 'eval_samples_per_second': 6496.975, 'eval_steps_per_second': 812.122, 'epoch': 68.88}


 70%|███████   | 14000/20000 [04:03<01:31, 65.89it/s]

{'loss': 0.2325, 'grad_norm': 0.5341466665267944, 'learning_rate': 3e-06, 'epoch': 71.43}


                                                     
 70%|███████   | 14006/20000 [04:04<04:59, 20.01it/s]

{'eval_loss': 0.3103892505168915, 'eval_accuracy': 0.88, 'eval_runtime': 0.7954, 'eval_samples_per_second': 6285.877, 'eval_steps_per_second': 785.735, 'epoch': 71.43}


 72%|███████▎  | 14500/20000 [04:11<01:20, 67.99it/s]

{'loss': 0.2306, 'grad_norm': 0.5222039222717285, 'learning_rate': 2.7500000000000004e-06, 'epoch': 73.98}


                                                     
 73%|███████▎  | 14510/20000 [04:12<03:36, 25.36it/s]

{'eval_loss': 0.308851957321167, 'eval_accuracy': 0.8802, 'eval_runtime': 0.8058, 'eval_samples_per_second': 6204.711, 'eval_steps_per_second': 775.589, 'epoch': 73.98}


 75%|███████▌  | 15000/20000 [04:20<01:15, 66.05it/s]

{'loss': 0.2274, 'grad_norm': 0.40736937522888184, 'learning_rate': 2.5e-06, 'epoch': 76.53}


                                                     
 75%|███████▌  | 15007/20000 [04:21<04:15, 19.51it/s]

{'eval_loss': 0.3075104057788849, 'eval_accuracy': 0.8802, 'eval_runtime': 0.8147, 'eval_samples_per_second': 6137.33, 'eval_steps_per_second': 767.166, 'epoch': 76.53}


 78%|███████▊  | 15500/20000 [04:28<01:05, 68.46it/s]

{'loss': 0.2264, 'grad_norm': 0.3898400366306305, 'learning_rate': 2.25e-06, 'epoch': 79.08}


                                                     
 78%|███████▊  | 15511/20000 [04:29<02:56, 25.40it/s]

{'eval_loss': 0.30642005801200867, 'eval_accuracy': 0.8806, 'eval_runtime': 0.8191, 'eval_samples_per_second': 6104.578, 'eval_steps_per_second': 763.072, 'epoch': 79.08}


 80%|████████  | 16000/20000 [04:37<01:01, 65.26it/s]

{'loss': 0.2237, 'grad_norm': 0.4143665134906769, 'learning_rate': 2.0000000000000003e-06, 'epoch': 81.63}


                                                     
 80%|████████  | 16012/20000 [04:38<02:32, 26.10it/s]

{'eval_loss': 0.3054101765155792, 'eval_accuracy': 0.8804, 'eval_runtime': 0.8041, 'eval_samples_per_second': 6218.001, 'eval_steps_per_second': 777.25, 'epoch': 81.63}


 82%|████████▎ | 16500/20000 [04:45<00:52, 67.22it/s]

{'loss': 0.2228, 'grad_norm': 0.3824039697647095, 'learning_rate': 1.75e-06, 'epoch': 84.18}


                                                     
 83%|████████▎ | 16512/20000 [04:46<02:13, 26.14it/s]

{'eval_loss': 0.30453288555145264, 'eval_accuracy': 0.8806, 'eval_runtime': 0.8023, 'eval_samples_per_second': 6231.707, 'eval_steps_per_second': 778.963, 'epoch': 84.18}


 85%|████████▌ | 17000/20000 [04:54<00:49, 60.68it/s]

{'loss': 0.2202, 'grad_norm': 0.44040292501449585, 'learning_rate': 1.5e-06, 'epoch': 86.73}


                                                     
 85%|████████▌ | 17009/20000 [04:55<02:04, 24.02it/s]

{'eval_loss': 0.3038578927516937, 'eval_accuracy': 0.8802, 'eval_runtime': 0.8451, 'eval_samples_per_second': 5916.238, 'eval_steps_per_second': 739.53, 'epoch': 86.73}


 88%|████████▊ | 17500/20000 [05:02<00:40, 62.09it/s]

{'loss': 0.2199, 'grad_norm': 0.3494792580604553, 'learning_rate': 1.25e-06, 'epoch': 89.29}


                                                     
 88%|████████▊ | 17503/20000 [05:03<02:08, 19.46it/s]

{'eval_loss': 0.3031989336013794, 'eval_accuracy': 0.8806, 'eval_runtime': 0.8027, 'eval_samples_per_second': 6229.217, 'eval_steps_per_second': 778.652, 'epoch': 89.29}


 90%|█████████ | 18000/20000 [05:11<00:32, 61.69it/s]

{'loss': 0.218, 'grad_norm': 0.3691171705722809, 'learning_rate': 1.0000000000000002e-06, 'epoch': 91.84}


                                                     
 90%|█████████ | 18009/20000 [05:12<01:23, 23.74it/s]

{'eval_loss': 0.3027598261833191, 'eval_accuracy': 0.8804, 'eval_runtime': 0.8316, 'eval_samples_per_second': 6012.501, 'eval_steps_per_second': 751.563, 'epoch': 91.84}


 92%|█████████▎| 18500/20000 [05:20<00:23, 65.03it/s]

{'loss': 0.218, 'grad_norm': 0.3601178824901581, 'learning_rate': 7.5e-07, 'epoch': 94.39}


                                                     
 93%|█████████▎| 18507/20000 [05:21<01:00, 24.85it/s]

{'eval_loss': 0.3024196922779083, 'eval_accuracy': 0.8802, 'eval_runtime': 0.7555, 'eval_samples_per_second': 6618.434, 'eval_steps_per_second': 827.304, 'epoch': 94.39}


 95%|█████████▌| 19000/20000 [05:28<00:14, 68.01it/s]

{'loss': 0.2168, 'grad_norm': 0.34206077456474304, 'learning_rate': 5.000000000000001e-07, 'epoch': 96.94}


                                                     
 95%|█████████▌| 19010/20000 [05:29<00:39, 24.83it/s]

{'eval_loss': 0.3021363914012909, 'eval_accuracy': 0.88, 'eval_runtime': 0.8228, 'eval_samples_per_second': 6076.532, 'eval_steps_per_second': 759.566, 'epoch': 96.94}


 98%|█████████▊| 19500/20000 [05:37<00:07, 64.03it/s]

{'loss': 0.2172, 'grad_norm': 0.48698970675468445, 'learning_rate': 2.5000000000000004e-07, 'epoch': 99.49}


                                                     
 98%|█████████▊| 19510/20000 [05:38<00:19, 25.17it/s]

{'eval_loss': 0.3019624650478363, 'eval_accuracy': 0.88, 'eval_runtime': 0.7901, 'eval_samples_per_second': 6328.522, 'eval_steps_per_second': 791.065, 'epoch': 99.49}


100%|██████████| 20000/20000 [05:45<00:00, 65.01it/s]

{'loss': 0.2163, 'grad_norm': 0.36724454164505005, 'learning_rate': 0.0, 'epoch': 102.04}


                                                     
100%|██████████| 20000/20000 [05:46<00:00, 57.69it/s]

{'eval_loss': 0.3019106090068817, 'eval_accuracy': 0.8802, 'eval_runtime': 0.8354, 'eval_samples_per_second': 5984.855, 'eval_steps_per_second': 748.107, 'epoch': 102.04}
{'train_runtime': 346.6627, 'train_samples_per_second': 7384.701, 'train_steps_per_second': 57.693, 'train_loss': 0.31380066299438475, 'epoch': 102.04}





TrainOutput(global_step=20000, training_loss=0.31380066299438475, metrics={'train_runtime': 346.6627, 'train_samples_per_second': 7384.701, 'train_steps_per_second': 57.693, 'train_loss': 0.31380066299438475, 'epoch': 102.04})

In [25]:
#Save model 
#trainer.save_model("mlp_model")

### 3.3. Evaluation on test set

In [29]:
# Your code to evaluate the final model on the test set here
test_results = trainer.predict(tokenized_data["test"].select(range(5000)))
test_accuracy = compute_accuracy((test_results.predictions, test_results.label_ids))
print(f"Test accuracy: {test_accuracy}")

100%|██████████| 625/625 [00:00<00:00, 768.58it/s]

Test accuracy: {'accuracy': 0.8802}





In [42]:
#Convert the 10 first predictions to labels 
print("Predictions:", "\n", test_results.predictions[:10])
print("Binary predicted labels:", np.argmax(test_results.predictions[:10], axis=-1))


#Print first 10 true labels
true_labels = test_results.label_ids[:10]
print("True labels:", true_labels)


Predictions: 
 [[-1.4783843   1.2700074 ]
 [-0.16219756  0.06732692]
 [-0.01269698 -0.06483601]
 [-1.0602162   0.89951706]
 [ 0.37185746 -0.41186056]
 [-1.1537267   0.98373985]
 [ 0.09102534 -0.15580702]
 [ 1.1402507  -1.1289095 ]
 [ 0.29495418 -0.34836176]
 [-0.6786371   0.54820937]]
Binary predicted labels: [1 1 0 1 0 1 0 0 0 1]
True labels: [1 1 0 1 0 1 1 0 0 1]


---

## 4. Results and summary

### 4.1 Corpus insights

(Briefly discuss what you learned about the corpus and its annotation)

### 4.2 Results

(Briefly summarize your results)

### 4.3 Relation to state of the art

(Compare your results to the state-of-the-art performance)

---

## 5. Bonus Task (optional)

### 5.1. Annotating out-of-domain documents

(Briefly describe the chosen out-of-domain documents)

(Briefly describe the process of annotation)

### 5.2 Conversion into dataset

In [None]:
# Your code to convert the annotations into a dataset here

### 5.3. Model evaluation on out-of-domain test set

In [None]:
# Your code to evaluate the model on the out-of-domain test set here

### 5.4 Bonus task results

(Present the results of the evaluation on the out-of-domain test set)

### 5.5. Annotated data

In [None]:
# Include your annotated out-of-domain data here