<h1>HW4 - Attention and Transformers- Practical Q2</h1>
<h3><font color=yellow>Total Points: 100</font></h3>



<font size=+4>
Introduction
</font>
<br/>
<font size=+2>

<br/>
Part-of-speech tagging (POS tagging) is a natural language processing task where each word in a given text is assigned a specific part-of-speech category, such as noun, verb, adjective, etc. In the context of Roberta, a state-of-the-art language model based on the transformer architecture, POS tagging involves leveraging the model's pre-trained knowledge to accurately predict and label the grammatical roles of words in a sentence. Roberta's extensive training data and attention mechanisms enable it to capture subtle contextual cues, allowing for more nuanced and accurate POS tagging. By understanding the syntactic structure of a sentence, Roberta enhances the efficiency and precision of POS tagging, contributing to more sophisticated language understanding and downstream applications in natural language processing.
</font>


<font size=+2>
In this section we want to fine-tune the Roberta that you trained in previous section for POS-Tagging Task on Parsig dataset
<font/>

In [None]:
!pip install transformers
!pip install datasets
!pip install accelerate -U

import re
import torch
import pandas as pd

import numpy as np
import transformers
import matplotlib.pyplot as plt


In [31]:
# All label that exist
# AAX mean Unknown used for CLS and other symbols
tags = {'AAX', 'B-ADJ', 'I-ADJ', 'B-ADV', 'I-ADV', 'B-N', 'I-N', 'B-V', 'I-V',\
        'B-PRONOUN', 'I-PRONOUN', 'B-NUM', 'I-NUM', 'B-DET', 'I-DET', 'B-PRE',\
        'I-PRE', 'B-POST', 'I-POST', 'B-CONJ', 'I-CONJ', 'B-JUNK', 'I-JUNK',\
        'B-MARKER', 'I-MARKER',}
# Convert lables to speceif numeric id
tag2id = {tag: id for id, tag in enumerate(sorted(list(tags)))}
id2tag = {id: tag for tag, id in tag2id.items()}

In [32]:
######################   TODO 1.1   ########################
# Load pretrained model that trained in previous sectiopn
# Load tokenizer from previous section
# Freeze base model for fine-tuning
from transformers import AutoTokenizer, AutoModelForMaskedLM
from torch import nn

tokenizer = AutoTokenizer.from_pretrained("Mahdi-Salahshour/mlm_tokenizer")
pretrained_model = AutoModelForMaskedLM.from_pretrained("Mahdi-Salahshour/mlm")


for param in pretrained_model.parameters():
    param.requires_grad = False
###################### (5 points) ##########################


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [33]:
######################   TODO 1.2  ########################
from sklearn.model_selection import train_test_split

train_df = pd.read_csv("train_df.csv")
test_df = pd.read_csv("test_df.csv")
max_length = max(max(len(seq) for seq in train_df) , max(len(seq) for seq in test_df))
def process_labels(df, tag2id):
    df['labels'] = df['labels'].str.split(',').apply(lambda x : [0] + [tag2id[i] for i in x] + [0])
    return df


train_df = process_labels(train_df, tag2id)
test_df = process_labels(test_df, tag2id)

train_df, valid_df = train_test_split(train_df, test_size=0.2, random_state=42)

###################### (5 points) ##########################


In [34]:
######################   TODO 1.3   ########################
# Complete custom data set
# Use Torcargmaxh dataloader for datasets
###################### (15 points) ##########################

import torch
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import LabelEncoder


class CustomDataset(Dataset):
    def __init__(self, df, tokenizer, max_length):
        self.texts = df['text'].tolist()
        self.labels = df['labels'].tolist()
        self.tokenizer = tokenizer
        self.max_length = max_length
    def __getitem__(self, index):
        text = self.texts[index]
        labels = self.labels[index]
        encoding = self.tokenizer(
            text,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt'
        )

        #add padding if the length of seq is less than max length

        labels = labels[:self.max_length] + [0] * max(0, self.max_length - len(labels))

        input_ids = encoding['input_ids'].squeeze()
        attention_mask = encoding['attention_mask'].squeeze()


        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': torch.tensor(labels, dtype=torch.long)
        }

    def __len__(self):
        return len(self.texts)


train_dataset = CustomDataset(train_df, tokenizer, max_length)
test_dataset = CustomDataset(test_df, tokenizer, max_length)
valid_dataset = CustomDataset(valid_df, tokenizer, max_length)

batch_size = 64
trainloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
testloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)
validloader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=True)

In [45]:
import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer

class PartOfSpeechTagger(nn.Module):
    def __init__(self, size_of_input, number_of_tags, hidden_layer):
        super(PartOfSpeechTagger, self).__init__()
        self.fully_connected_1 = nn.Linear(size_of_input, hidden_layer)
        self.batch_norm = nn.BatchNorm1d(hidden_layer)
        self.relu = nn.ReLU()
        self.fully_connected_2 = nn.Linear(hidden_layer, number_of_tags)

    def forward(self, input_tensor):
        out = self.fully_connected_1(input_tensor)
        out = self.batch_norm(out)
        out = self.relu(out)
        out = self.fully_connected_2(out)
        return out




p_h = PartOfSpeechTagger(pretrained_model.config.hidden_size, len(tags), 1024)


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

roberta_model = pretrained_model.roberta

def forward(input_ids, attention_mask):
    outputs = roberta_model(input_ids=input_ids, attention_mask=attention_mask)
    last_hidden_states = outputs.last_hidden_state
    pos_output = p_h(last_hidden_states)

    return pos_output

roberta_model = roberta_model.to(device)
p_h = p_h.to(device)


In [47]:
import torch
import torch.nn.functional as F
from torch import nn
from transformers import AdamW
from tqdm import tqdm
from sklearn.metrics import accuracy_score

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

roberta_model.to(device)
roberta_model.train()

criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)
optimizer = AdamW(roberta_model.parameters(), lr=2e-5)

num_epochs = 50
num_training_steps = num_epochs * len(trainloader)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, num_training_steps)

for epoch in range(num_epochs):
    total_loss, total_correct, total_samples = 0, 0, 0
    tqdm_dataloader = tqdm(trainloader, desc=f"Epoch {epoch + 1}/{num_epochs}")

    for batch in tqdm_dataloader:
        input_ids, attention_mask, labels = [item.to(device) for item in (batch['input_ids'], batch['attention_mask'], batch['labels'])]
        optimizer.zero_grad()
        outputs = roberta_model(input_ids, attention_mask=attention_mask, labels=labels)
        logits = outputs.logits
        loss = criterion(logits.view(-1, roberta_model.config.num_labels), labels.view(-1))
        total_loss += loss.item()
        _, predicted_labels = torch.max(logits, 2)
        total_correct += (predicted_labels == labels).sum().item()
        total_samples += labels.numel()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

    print(f"Epoch {epoch + 1}/{num_epochs} - Average Loss: {total_loss / len(trainloader)}, Accuracy: {total_correct / total_samples * 100:.2f}%")

    roberta_model.eval()
    total_valid_loss, total_valid_correct, total_valid_samples = 0, 0, 0

    with torch.no_grad():
        tqdm_valid_dataloader = tqdm(validloader, desc=f"Epoch {epoch + 1}/{num_epochs} - Validation")

        for valid_batch in tqdm_valid_dataloader:
            input_ids_valid, attention_mask_valid, labels_valid = [item.to(device) for item in (valid_batch['input_ids'], valid_batch['attention_mask'], valid_batch['labels'])]
            valid_outputs = roberta_model(input_ids_valid, attention_mask=attention_mask_valid, labels=labels_valid)
            valid_logits = valid_outputs.logits
            valid_loss = F.cross_entropy(valid_logits.view(-1, roberta_model.config.num_labels), labels_valid.view(-1), ignore_index=tokenizer.pad_token_id)
            total_valid_loss += valid_loss.item()
            _, predicted_labels_valid = torch.max(valid_logits, 2)
            total_valid_correct += (predicted_labels_valid == labels_valid).sum().item()
            total_valid_samples += labels_valid.numel()

    print(f"Epoch {epoch + 1}/{num_epochs} - Validation Loss: {total_valid_loss / len(validloader)}, Validation Accuracy: {total_valid_correct / total_valid_samples * 100:.2f}%")

    roberta_model.train()



Epoch 1/50: 100%|██████████| 41/41 [00:04<00:00,  8.96it/s]


Epoch 1/50 - Average Loss: 2.2528280165137313, Accuracy: 36.82%


Epoch 1/50 - Validation: 100%|██████████| 11/11 [00:00<00:00, 27.46it/s]


Epoch 1/50 - Validation Loss: 1.661430142142556, Validation Accuracy: 50.69%


Epoch 2/50: 100%|██████████| 41/41 [00:04<00:00,  9.76it/s]


Epoch 2/50 - Average Loss: 1.5101719338719437, Accuracy: 57.87%


Epoch 2/50 - Validation: 100%|██████████| 11/11 [00:00<00:00, 26.19it/s]


Epoch 2/50 - Validation Loss: 1.2018236897208474, Validation Accuracy: 64.30%


Epoch 3/50: 100%|██████████| 41/41 [00:04<00:00,  9.74it/s]


Epoch 3/50 - Average Loss: 1.161164533801195, Accuracy: 65.21%


Epoch 3/50 - Validation: 100%|██████████| 11/11 [00:00<00:00, 21.29it/s]


Epoch 3/50 - Validation Loss: 0.984381231394681, Validation Accuracy: 69.31%


Epoch 4/50: 100%|██████████| 41/41 [00:04<00:00,  9.23it/s]


Epoch 4/50 - Average Loss: 0.9865582439957595, Accuracy: 69.51%


Epoch 4/50 - Validation: 100%|██████████| 11/11 [00:00<00:00, 27.06it/s]


Epoch 4/50 - Validation Loss: 0.8737964034080505, Validation Accuracy: 72.24%


Epoch 5/50: 100%|██████████| 41/41 [00:04<00:00,  9.62it/s]


Epoch 5/50 - Average Loss: 0.8815629903863116, Accuracy: 72.25%


Epoch 5/50 - Validation: 100%|██████████| 11/11 [00:00<00:00, 24.55it/s]


Epoch 5/50 - Validation Loss: 0.7926277897574685, Validation Accuracy: 75.24%


Epoch 6/50: 100%|██████████| 41/41 [00:04<00:00,  8.95it/s]


Epoch 6/50 - Average Loss: 0.7967978192538749, Accuracy: 74.40%


Epoch 6/50 - Validation: 100%|██████████| 11/11 [00:00<00:00, 21.13it/s]


Epoch 6/50 - Validation Loss: 0.7129558053883639, Validation Accuracy: 77.07%


Epoch 7/50: 100%|██████████| 41/41 [00:04<00:00,  9.38it/s]


Epoch 7/50 - Average Loss: 0.7230055070504909, Accuracy: 76.91%


Epoch 7/50 - Validation: 100%|██████████| 11/11 [00:00<00:00, 26.59it/s]


Epoch 7/50 - Validation Loss: 0.7012083042751659, Validation Accuracy: 78.14%


Epoch 8/50: 100%|██████████| 41/41 [00:04<00:00,  9.50it/s]


Epoch 8/50 - Average Loss: 0.6694178901067595, Accuracy: 78.20%


Epoch 8/50 - Validation: 100%|██████████| 11/11 [00:00<00:00, 26.49it/s]


Epoch 8/50 - Validation Loss: 0.6635909622365778, Validation Accuracy: 78.52%


Epoch 9/50: 100%|██████████| 41/41 [00:04<00:00,  9.04it/s]


Epoch 9/50 - Average Loss: 0.6214566259849362, Accuracy: 79.37%


Epoch 9/50 - Validation: 100%|██████████| 11/11 [00:00<00:00, 26.10it/s]


Epoch 9/50 - Validation Loss: 0.6185390244830739, Validation Accuracy: 79.41%


Epoch 10/50: 100%|██████████| 41/41 [00:04<00:00,  9.50it/s]


Epoch 10/50 - Average Loss: 0.5805561164530312, Accuracy: 80.40%


Epoch 10/50 - Validation: 100%|██████████| 11/11 [00:00<00:00, 25.25it/s]


Epoch 10/50 - Validation Loss: 0.6009299944747578, Validation Accuracy: 79.64%


Epoch 11/50: 100%|██████████| 41/41 [00:04<00:00,  9.55it/s]


Epoch 11/50 - Average Loss: 0.5377700583236974, Accuracy: 81.85%


Epoch 11/50 - Validation: 100%|██████████| 11/11 [00:00<00:00, 20.85it/s]


Epoch 11/50 - Validation Loss: 0.5902896306731484, Validation Accuracy: 80.48%


Epoch 12/50: 100%|██████████| 41/41 [00:04<00:00,  9.08it/s]


Epoch 12/50 - Average Loss: 0.4991384010489394, Accuracy: 82.73%


Epoch 12/50 - Validation: 100%|██████████| 11/11 [00:00<00:00, 25.89it/s]


Epoch 12/50 - Validation Loss: 0.5838824537667361, Validation Accuracy: 81.09%


Epoch 13/50: 100%|██████████| 41/41 [00:04<00:00,  9.71it/s]


Epoch 13/50 - Average Loss: 0.4661375015247159, Accuracy: 83.58%


Epoch 13/50 - Validation: 100%|██████████| 11/11 [00:00<00:00, 26.23it/s]


Epoch 13/50 - Validation Loss: 0.5895793329585682, Validation Accuracy: 81.48%


Epoch 14/50: 100%|██████████| 41/41 [00:04<00:00,  9.32it/s]


Epoch 14/50 - Average Loss: 0.43708802795991664, Accuracy: 84.88%


Epoch 14/50 - Validation: 100%|██████████| 11/11 [00:00<00:00, 20.60it/s]


Epoch 14/50 - Validation Loss: 0.5729551884261045, Validation Accuracy: 81.76%


Epoch 15/50: 100%|██████████| 41/41 [00:04<00:00,  9.70it/s]


Epoch 15/50 - Average Loss: 0.4094806880485721, Accuracy: 85.47%


Epoch 15/50 - Validation: 100%|██████████| 11/11 [00:00<00:00, 27.58it/s]


Epoch 15/50 - Validation Loss: 0.5562875758517872, Validation Accuracy: 81.91%


Epoch 16/50: 100%|██████████| 41/41 [00:04<00:00,  9.86it/s]


Epoch 16/50 - Average Loss: 0.37943305620333045, Accuracy: 86.31%


Epoch 16/50 - Validation: 100%|██████████| 11/11 [00:00<00:00, 26.52it/s]


Epoch 16/50 - Validation Loss: 0.5928777022795244, Validation Accuracy: 81.88%


Epoch 17/50: 100%|██████████| 41/41 [00:04<00:00,  9.25it/s]


Epoch 17/50 - Average Loss: 0.3559106777353985, Accuracy: 87.04%


Epoch 17/50 - Validation: 100%|██████████| 11/11 [00:00<00:00, 27.03it/s]


Epoch 17/50 - Validation Loss: 0.6042602062225342, Validation Accuracy: 81.83%


Epoch 18/50: 100%|██████████| 41/41 [00:04<00:00,  9.82it/s]


Epoch 18/50 - Average Loss: 0.32586128828002187, Accuracy: 88.11%


Epoch 18/50 - Validation: 100%|██████████| 11/11 [00:00<00:00, 27.48it/s]


Epoch 18/50 - Validation Loss: 0.5928192192857916, Validation Accuracy: 81.53%


Epoch 19/50: 100%|██████████| 41/41 [00:04<00:00,  9.81it/s]


Epoch 19/50 - Average Loss: 0.3074340962055253, Accuracy: 88.69%


Epoch 19/50 - Validation: 100%|██████████| 11/11 [00:00<00:00, 23.47it/s]


Epoch 19/50 - Validation Loss: 0.5858129940249703, Validation Accuracy: 82.26%


Epoch 20/50: 100%|██████████| 41/41 [00:04<00:00,  9.17it/s]


Epoch 20/50 - Average Loss: 0.2889953175695931, Accuracy: 89.06%


Epoch 20/50 - Validation: 100%|██████████| 11/11 [00:00<00:00, 27.03it/s]


Epoch 20/50 - Validation Loss: 0.5810276188633658, Validation Accuracy: 82.21%


Epoch 21/50: 100%|██████████| 41/41 [00:04<00:00,  9.87it/s]


Epoch 21/50 - Average Loss: 0.26682043439004477, Accuracy: 89.87%


Epoch 21/50 - Validation: 100%|██████████| 11/11 [00:00<00:00, 27.25it/s]


Epoch 21/50 - Validation Loss: 0.5975109067830172, Validation Accuracy: 82.34%


Epoch 22/50: 100%|██████████| 41/41 [00:04<00:00,  9.58it/s]


Epoch 22/50 - Average Loss: 0.2563132807248976, Accuracy: 90.22%


Epoch 22/50 - Validation: 100%|██████████| 11/11 [00:00<00:00, 19.94it/s]


Epoch 22/50 - Validation Loss: 0.5998231768608093, Validation Accuracy: 82.01%


Epoch 23/50: 100%|██████████| 41/41 [00:04<00:00,  9.52it/s]


Epoch 23/50 - Average Loss: 0.2375691013365257, Accuracy: 90.72%


Epoch 23/50 - Validation: 100%|██████████| 11/11 [00:00<00:00, 27.39it/s]


Epoch 23/50 - Validation Loss: 0.6177080598744479, Validation Accuracy: 82.67%


Epoch 24/50: 100%|██████████| 41/41 [00:04<00:00,  9.73it/s]


Epoch 24/50 - Average Loss: 0.21849045629908398, Accuracy: 91.11%


Epoch 24/50 - Validation: 100%|██████████| 11/11 [00:00<00:00, 26.20it/s]


Epoch 24/50 - Validation Loss: 0.630093124779788, Validation Accuracy: 82.16%


Epoch 25/50: 100%|██████████| 41/41 [00:04<00:00,  9.14it/s]


Epoch 25/50 - Average Loss: 0.2107001633905783, Accuracy: 91.51%


Epoch 25/50 - Validation: 100%|██████████| 11/11 [00:00<00:00, 25.08it/s]


Epoch 25/50 - Validation Loss: 0.6598588986830278, Validation Accuracy: 82.37%


Epoch 26/50: 100%|██████████| 41/41 [00:04<00:00,  9.75it/s]


Epoch 26/50 - Average Loss: 0.19667472890237483, Accuracy: 91.83%


Epoch 26/50 - Validation: 100%|██████████| 11/11 [00:00<00:00, 27.06it/s]


Epoch 26/50 - Validation Loss: 0.6423645723949779, Validation Accuracy: 82.72%


Epoch 27/50: 100%|██████████| 41/41 [00:04<00:00,  9.65it/s]


Epoch 27/50 - Average Loss: 0.183810557534055, Accuracy: 92.34%


Epoch 27/50 - Validation: 100%|██████████| 11/11 [00:00<00:00, 24.71it/s]


Epoch 27/50 - Validation Loss: 0.6311048133806749, Validation Accuracy: 82.26%


Epoch 28/50: 100%|██████████| 41/41 [00:04<00:00,  8.40it/s]


Epoch 28/50 - Average Loss: 0.17598628270916822, Accuracy: 92.68%


Epoch 28/50 - Validation: 100%|██████████| 11/11 [00:00<00:00, 14.30it/s]


Epoch 28/50 - Validation Loss: 0.6532352702184157, Validation Accuracy: 82.65%


Epoch 29/50: 100%|██████████| 41/41 [00:06<00:00,  6.76it/s]


Epoch 29/50 - Average Loss: 0.16423158456639544, Accuracy: 93.03%


Epoch 29/50 - Validation: 100%|██████████| 11/11 [00:00<00:00, 19.50it/s]


Epoch 29/50 - Validation Loss: 0.6425862258130853, Validation Accuracy: 82.21%


Epoch 30/50: 100%|██████████| 41/41 [00:06<00:00,  5.89it/s]


Epoch 30/50 - Average Loss: 0.1573446571100049, Accuracy: 93.10%


Epoch 30/50 - Validation: 100%|██████████| 11/11 [00:00<00:00, 11.72it/s]


Epoch 30/50 - Validation Loss: 0.6765247149900957, Validation Accuracy: 82.26%


Epoch 31/50: 100%|██████████| 41/41 [00:05<00:00,  7.28it/s]


Epoch 31/50 - Average Loss: 0.15373476067694222, Accuracy: 93.18%


Epoch 32/50: 100%|██████████| 41/41 [00:04<00:00,  9.20it/s]


Epoch 32/50 - Average Loss: 0.14974678544009604, Accuracy: 93.41%


Epoch 32/50 - Validation: 100%|██████████| 11/11 [00:00<00:00, 26.14it/s]


Epoch 32/50 - Validation Loss: 0.6694235043092207, Validation Accuracy: 82.24%


Epoch 33/50: 100%|██████████| 41/41 [00:04<00:00,  9.88it/s]


Epoch 33/50 - Average Loss: 0.14306828906623328, Accuracy: 93.35%


Epoch 33/50 - Validation: 100%|██████████| 11/11 [00:00<00:00, 26.71it/s]


Epoch 33/50 - Validation Loss: 0.6702261729673906, Validation Accuracy: 83.03%


Epoch 34/50: 100%|██████████| 41/41 [00:04<00:00,  9.71it/s]


Epoch 34/50 - Average Loss: 0.13791526708661056, Accuracy: 93.84%


Epoch 34/50 - Validation: 100%|██████████| 11/11 [00:00<00:00, 21.70it/s]


Epoch 34/50 - Validation Loss: 0.704160210761157, Validation Accuracy: 82.77%


Epoch 35/50: 100%|██████████| 41/41 [00:04<00:00,  9.42it/s]


Epoch 35/50 - Average Loss: 0.13491959433730055, Accuracy: 93.70%


Epoch 35/50 - Validation: 100%|██████████| 11/11 [00:00<00:00, 27.60it/s]


Epoch 35/50 - Validation Loss: 0.6712120500477877, Validation Accuracy: 82.44%


Epoch 36/50: 100%|██████████| 41/41 [00:04<00:00,  9.71it/s]


Epoch 36/50 - Average Loss: 0.13062788782323279, Accuracy: 94.03%


Epoch 36/50 - Validation: 100%|██████████| 11/11 [00:00<00:00, 27.54it/s]


Epoch 36/50 - Validation Loss: 0.6906456784768538, Validation Accuracy: 82.77%


Epoch 37/50: 100%|██████████| 41/41 [00:04<00:00,  9.30it/s]


Epoch 37/50 - Average Loss: 0.12659618676435658, Accuracy: 94.08%


Epoch 37/50 - Validation: 100%|██████████| 11/11 [00:00<00:00, 19.16it/s]


Epoch 37/50 - Validation Loss: 0.6800582788207314, Validation Accuracy: 82.75%


Epoch 38/50: 100%|██████████| 41/41 [00:04<00:00,  9.59it/s]


Epoch 38/50 - Average Loss: 0.12189932058497173, Accuracy: 94.19%


Epoch 38/50 - Validation: 100%|██████████| 11/11 [00:00<00:00, 26.07it/s]


Epoch 38/50 - Validation Loss: 0.6699885725975037, Validation Accuracy: 82.70%


Epoch 39/50: 100%|██████████| 41/41 [00:04<00:00,  9.80it/s]


Epoch 39/50 - Average Loss: 0.1180894442084359, Accuracy: 94.32%


Epoch 39/50 - Validation: 100%|██████████| 11/11 [00:00<00:00, 27.17it/s]


Epoch 39/50 - Validation Loss: 0.6908670176159252, Validation Accuracy: 82.88%


Epoch 40/50: 100%|██████████| 41/41 [00:04<00:00,  9.21it/s]


Epoch 40/50 - Average Loss: 0.11714442910217657, Accuracy: 94.30%


Epoch 40/50 - Validation: 100%|██████████| 11/11 [00:00<00:00, 26.77it/s]


Epoch 40/50 - Validation Loss: 0.6836603419347242, Validation Accuracy: 82.47%


Epoch 41/50: 100%|██████████| 41/41 [00:04<00:00,  9.81it/s]


Epoch 41/50 - Average Loss: 0.11645356729263212, Accuracy: 94.40%


Epoch 41/50 - Validation: 100%|██████████| 11/11 [00:00<00:00, 26.87it/s]


Epoch 41/50 - Validation Loss: 0.6957695484161377, Validation Accuracy: 82.75%


Epoch 42/50: 100%|██████████| 41/41 [00:04<00:00,  9.69it/s]


Epoch 42/50 - Average Loss: 0.1174961746465869, Accuracy: 94.25%


Epoch 42/50 - Validation: 100%|██████████| 11/11 [00:00<00:00, 21.37it/s]


Epoch 42/50 - Validation Loss: 0.6857142936099659, Validation Accuracy: 82.67%


Epoch 43/50: 100%|██████████| 41/41 [00:04<00:00,  9.25it/s]


Epoch 43/50 - Average Loss: 0.11377442900727434, Accuracy: 94.24%


Epoch 43/50 - Validation: 100%|██████████| 11/11 [00:00<00:00, 26.31it/s]


Epoch 43/50 - Validation Loss: 0.6839263439178467, Validation Accuracy: 82.82%


Epoch 44/50: 100%|██████████| 41/41 [00:04<00:00,  9.76it/s]


Epoch 44/50 - Average Loss: 0.11806708537950748, Accuracy: 94.26%


Epoch 44/50 - Validation: 100%|██████████| 11/11 [00:00<00:00, 26.67it/s]


Epoch 44/50 - Validation Loss: 0.6870280829342928, Validation Accuracy: 82.82%


Epoch 45/50: 100%|██████████| 41/41 [00:04<00:00,  9.43it/s]


Epoch 45/50 - Average Loss: 0.10946939885616302, Accuracy: 94.59%


Epoch 45/50 - Validation: 100%|██████████| 11/11 [00:00<00:00, 21.19it/s]


Epoch 45/50 - Validation Loss: 0.6819972206245769, Validation Accuracy: 82.82%


Epoch 46/50: 100%|██████████| 41/41 [00:04<00:00,  9.63it/s]


Epoch 46/50 - Average Loss: 0.10863110550293108, Accuracy: 94.69%


Epoch 46/50 - Validation: 100%|██████████| 11/11 [00:00<00:00, 27.04it/s]


Epoch 46/50 - Validation Loss: 0.6894840679385446, Validation Accuracy: 82.77%


Epoch 47/50: 100%|██████████| 41/41 [00:04<00:00,  9.77it/s]


Epoch 47/50 - Average Loss: 0.10964054686994087, Accuracy: 94.63%


Epoch 47/50 - Validation: 100%|██████████| 11/11 [00:00<00:00, 26.00it/s]


Epoch 47/50 - Validation Loss: 0.7088442336429249, Validation Accuracy: 82.77%


Epoch 48/50: 100%|██████████| 41/41 [00:04<00:00,  9.15it/s]


Epoch 48/50 - Average Loss: 0.1101971661172262, Accuracy: 94.48%


Epoch 48/50 - Validation: 100%|██████████| 11/11 [00:00<00:00, 26.25it/s]


Epoch 48/50 - Validation Loss: 0.6791129952127283, Validation Accuracy: 82.77%


Epoch 49/50: 100%|██████████| 41/41 [00:04<00:00,  9.68it/s]


Epoch 49/50 - Average Loss: 0.11025161595969665, Accuracy: 94.48%


Epoch 49/50 - Validation: 100%|██████████| 11/11 [00:00<00:00, 26.70it/s]


Epoch 49/50 - Validation Loss: 0.7115216065536846, Validation Accuracy: 82.75%


Epoch 50/50: 100%|██████████| 41/41 [00:05<00:00,  7.83it/s]


Epoch 50/50 - Average Loss: 0.10962375135319989, Accuracy: 94.59%


Epoch 50/50 - Validation: 100%|██████████| 11/11 [00:00<00:00, 20.37it/s]

Epoch 50/50 - Validation Loss: 0.6903644339604811, Validation Accuracy: 82.75%





In [48]:
######################   TODO 2.2   ########################
# Report best model accuarcy on validation dataset
# Accuracy below 85% will not be graded
# accuracy_valid
roberta_model.eval()
total_valid_correct = 0
total_valid_samples = 0
total_valid_loss = 0

with torch.no_grad():
    t = tqdm(testloader, desc=f"Epoch {epoch + 1}/{num_epochs} - Validation")

    for valid_batch in t:
        input_ids_valid, attention_mask_valid, labels_valid = valid_batch['input_ids'], valid_batch['attention_mask'], valid_batch['labels']

        input_ids_valid, attention_mask_valid, labels_valid = input_ids_valid.to(device), attention_mask_valid.to(device), labels_valid.to(device)

        valid_outputs = roberta_model(input_ids_valid, attention_mask=attention_mask_valid, labels=labels_valid)
        valid_logits = valid_outputs.logits

        valid_loss = F.cross_entropy(valid_logits.view(-1, roberta_model.config.num_labels), labels_valid.view(-1), ignore_index=tokenizer.pad_token_id)

        total_valid_loss += valid_loss.item()

        _, predicted_labels_valid = torch.max(valid_logits, 2)
        correct_valid = (predicted_labels_valid == labels_valid).sum().item()
        total_valid_correct += correct_valid
        total_valid_samples += labels_valid.numel()

average_valid_loss = total_valid_loss / len(validloader)
accuracy_valid = total_valid_correct / total_valid_samples
print(f"Epoch {epoch + 1}/{num_epochs} - Validation Loss: {average_valid_loss}, Validation Accuracy: {accuracy_valid * 100:.2f}%")

###################### (5 points) ##########################

Epoch 50/50 - Validation: 100%|██████████| 7/7 [00:00<00:00, 19.53it/s]

Epoch 50/50 - Validation Loss: 0.4731152057647705, Validation Accuracy: 82.27%





In [51]:
######################   TODO 2.1   ########################
# Now implement it with huggingface trainer

from transformers import AdamW, Trainer, TrainingArguments
import torch.nn as nn

loss_function = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)

class MyTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        labels = inputs.pop("labels")
        outputs = model(**inputs, labels=labels)
        logits = outputs.logits
        loss = loss_function(logits.view(-1, logits.shape[-1]), labels.view(-1))
        return (loss, outputs) if return_outputs else loss

training_arguments = TrainingArguments(
    output_dir='./results',
    num_train_epochs=10,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    warmup_steps=200,
    weight_decay=0.01,
    logging_dir='./logs',
    logging_steps=50,
    evaluation_strategy="steps",
    eval_steps=50,
    metric_for_best_model="eval_loss",
)

my_trainer = MyTrainer(
    model=roberta_model,
    args=training_arguments,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
)

my_trainer.train()

###################### (40 points) ##########################

Step,Training Loss,Validation Loss
50,0.0422,1.0247
100,0.022,1.077994
150,0.0225,1.144989
200,0.0254,1.156328
250,0.0253,1.114691
300,0.0184,1.12859
350,0.0121,1.146529
400,0.0085,1.145519


TrainOutput(global_step=410, training_loss=0.021715456974215622, metrics={'train_runtime': 55.6664, 'train_samples_per_second': 470.302, 'train_steps_per_second': 7.365, 'total_flos': 80181698058000.0, 'train_loss': 0.021715456974215622, 'epoch': 10.0})