# A4: LSTMs and Transformers for Word Sense Disambiguation

by Nikolai Ilinykh, Adam Ek, and others.

The lab is an exploration and learning exercise to be done in a group and also in discussion with the teachers and other students.

Write all your answers and the code in the appropriate boxes below.


A problem with static distributional vectors is the difficulty of distinguishing between different *word senses*. We will continue our exploration of word vectors by considering *trainable vectors*, or *word embeddings*, for Word Sense Disambiguation (WSD). We will work with both LSTMs and transformer models, e.g., BERT. The purpose of the assignment is to learn how to use representations from neural models in the downstream task of word sense disambiguation.


## Word Sense Disambiguation Task

The goal of word sense disambiguation is to train a model that finds which sense of an ambiguous word form (e.g., a homonym) is meant in a given context. For example, the word "bank" can mean "sloping land" or "financial institution".

(a) "I deposited my money in the **bank**" (financial institution)

(b) "I swam from the river **bank**" (sloping land)

In cases (a) and (b), we can determine the meaning of "bank" from the *context*. To utilize context in a semantic model, we use *contextualized word representations*.

Previously, we worked with *static word representations*, i.e., the representation does not depend on the context. To illustrate, we can consider sentences (a) and (b), where the word **bank** would have the same static representation in both sentences, which means that it becomes difficult for us to predict its sense. What we want is to create representations that depend on the context, i.e., *contextualized embeddings*.
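To make the contrast concrete, here is a minimal sketch with randomly initialised (untrained) PyTorch layers; the tiny vocabulary and the helper `encode` are hypothetical. A static `nn.Embedding` lookup returns the same vector for "bank" in sentences (a) and (b), whereas the output of a bidirectional `nn.LSTM` at the same word depends on the surrounding words.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy vocabulary covering sentences (a) and (b)
vocab = {w: i for i, w in enumerate(
    "i deposited my money in the bank swam from river".split())}

def encode(sentence):
    """Lower-case, split on white space and map each word to its id (batch of 1)."""
    return torch.tensor([[vocab[w] for w in sentence.lower().split()]])

sent_a = encode("I deposited my money in the bank")
sent_b = encode("I swam from the river bank")
bank_pos_a, bank_pos_b = 6, 5  # 0-based position of "bank" in each sentence

embedding = nn.Embedding(len(vocab), 8)                     # static: one vector per word type
lstm = nn.LSTM(8, 8, batch_first=True, bidirectional=True)  # contextual: reads the whole sentence

static_a = embedding(sent_a)[0, bank_pos_a]
static_b = embedding(sent_b)[0, bank_pos_b]

contextual_a = lstm(embedding(sent_a))[0][0, bank_pos_a]  # [16] = forward 8 + backward 8
contextual_b = lstm(embedding(sent_b))[0][0, bank_pos_b]

print(torch.equal(static_a, static_b))             # True: same vector regardless of context
print(torch.allclose(contextual_a, contextual_b))  # False: depends on the surrounding words
```

Because "bank" is the last word in both sentences, only the forward half of the LSTM vector differs in this particular example; the backward half starts at "bank" in both cases.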

As we have discussed in class, one way to obtain contextualized representations is to pre-train a model on some "general" task and then fine-tune it on a downstream task. Here we will do the following:

(1) Train and test an LSTM model directly for word sense disambiguation. The model will learn contextualized representations during this training.

(2) Take a BERT model pre-trained on masked language modeling and next sentence prediction, fine-tune it on our data, and test it for word sense disambiguation on the task dataset. The idea is for you to explore how pre-trained contextualized representations from BERT can be updated and used for the downstream task of word sense disambiguation.

Your overall task in this lab is to create a neural network model that can disambiguate the word sense of 30 different words.

In [None]:
# first we import some packages that we need

# here add any package that you will need later

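import torch
import torch.nn as nn

# our hyperparameters (add more when/if you need them)
# use a GPU if one is available (hard-coding 'cuda:1' fails on machines with fewer than two GPUs)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')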

batch_size = 8
learning_rate = 0.001
epochs = 3

# 1. Working with Data

A central part of any machine learning system is the data we're working with.

In this section, we will split the data (the dataset is in `wsd_data.txt`) into a training set and a test set.


## Data

The dataset we will use contains different word senses for 30 different words. The data is organized as follows (values separated by tabs), where each line is a separate item in the dataset:

- Column 1: word-sense, e.g., keep%2:42:07::
- Column 2: word-form, e.g., keep.v
- Column 3: index of word, e.g., 15
- Column 4: white-space tokenized context, e.g., Action by the Committee In pursuance of its mandate , the Committee will continue to keep under review the situation relating to the question of Palestine and participate in relevant meetings of the General Assembly and the Security Council . The Committee will also continue to monitor the situation on the ground and draw the attention of the international community to urgent developments in the Occupied Palestinian Territory , including East Jerusalem , requiring international action .
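As a quick illustration of the format, the sketch below parses one line with plain Python string operations (the helper names are hypothetical, and the context is the example above, truncated). It also shows that the word index is 0-based: token 15 of the context is "keep". The sense key follows the WordNet format `lemma%ss_type:lex_filenum:lex_id:head_word:head_id`, where `ss_type` 2 denotes a verb.

```python
def parse_line(line):
    """Split one tab-separated dataset line into its four fields (assumes no tabs inside fields)."""
    word_sense, word_form, word_index, context = line.rstrip("\n").split("\t")
    tokens = context.split()  # the context is already white-space tokenized
    return word_sense, word_form, int(word_index), tokens

def split_sense_key(sense_key):
    """Split a WordNet sense key into the lemma and its colon-separated fields."""
    lemma, lex_sense = sense_key.split("%")
    return lemma, lex_sense.split(":")

line = ("keep%2:42:07::\tkeep.v\t15\tAction by the Committee In pursuance of its mandate , "
        "the Committee will continue to keep under review the situation")
sense, form, idx, tokens = parse_line(line)
print(tokens[idx])             # keep
print(split_sense_key(sense))  # ('keep', ['2', '42', '07', '', ''])
```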


### Splitting the Data

Your first task is to separate the data into a *training set* and a *test set*.

The training set should contain 80% of the examples, and the test set the remaining 20%.

The examples for the test/training set should be selected **randomly**.

Save each dataset into a .csv file for loading later.

**[2 marks]**

In [2]:
import csv

import pandas as pd

def data_split(dataset_path):
    # Read the tab-separated dataset. QUOTE_NONE treats quotation marks in the
    # tokenized contexts as ordinary characters instead of CSV quoting.
    data = pd.read_csv(dataset_path, sep='\t', header=None, quoting=csv.QUOTE_NONE,
                       names=["word_sense", "word_form", "word_index", "context"])
    
    # Shuffle the data (fixed seed, so the split is reproducible)
    data = data.sample(frac=1, random_state=42).reset_index(drop=True)
    
    # Calculate split index: 80% train, 20% test
    split_idx = int(0.8 * len(data))
    
    # Split the data
    train_split = data.iloc[:split_idx]
    test_split = data.iloc[split_idx:]
    
    # Save both splits to CSV for later use
    train_split.to_csv("train_split.csv", index=False)
    test_split.to_csv("test_split.csv", index=False)

    return train_split, test_split



### Creating a Baseline

Your second task is to create a *baseline* for the task.

A baseline is a "reality check" for a model. Given a very simple heuristic, algorithmic, or model-based solution to the problem, can our neural network perform better than it? Baselines are important because they give us a point of comparison for the actual models, and they are commonly used in NLP. Sometimes a baseline is not a simple model but the previous state of the art.

In this exercise, we will have a simple baseline model that is the "most common sense" (MCS) baseline. For each word form, find the most commonly assigned sense to the word and label a word with that sense. In a fictional dataset, "bank" has two senses: "financial institution," which occurs 5 times, and "side of the river," which occurs 3 times. Thus, all 8 occurrences of "bank" are labeled "financial institution," yielding an MCS accuracy of 5/8 = 62.5%. If a model obtains a higher score than this, we can conclude that the model *at least* is better than selecting the most frequent word sense.
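The arithmetic of the fictional example can be reproduced in a few lines of Python. This is only a sketch of the counting logic on made-up labels; the actual baseline takes the counts from the training split and is evaluated on the test split.

```python
from collections import Counter

# The fictional "bank" data: 5 x "financial institution", 3 x "side of the river"
gold = ["financial institution"] * 5 + ["side of the river"] * 3

most_common_sense = Counter(gold).most_common(1)[0][0]
predictions = [most_common_sense] * len(gold)  # MCS assigns the same sense to every occurrence

accuracy = sum(p == g for p, g in zip(predictions, gold)) / len(gold)
print(most_common_sense, accuracy)  # financial institution 0.625
```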

Your task is to write the code for this baseline, train it, and test it. The baseline knows about the labels and their frequencies only from the training data. You evaluate it on the test data by comparing the ground-truth sense with the one the model predicts. A good "dumb" baseline here is expected to perform quite poorly: expect an accuracy of around 0.30. Use accuracy as your main metric; you can also compute the F1-score.

**[2 marks]**


In [3]:
# def mcs_baseline(data):
    
#     # your code goes here
    
#     return


from collections import defaultdict, Counter
from sklearn.metrics import accuracy_score

def mcs_baseline(train_data, test_data):
    # Step 1: Count the most common sense for each word form in training data
    sense_counts = defaultdict(Counter)
    for _, row in train_data.iterrows():
        word_form = row['word_form']
        word_sense = row['word_sense']
        sense_counts[word_form][word_sense] += 1

    # Step 2: Build the MCS dictionary
    mcs_dict = {word: senses.most_common(1)[0][0] for word, senses in sense_counts.items()}

    # Step 3: Predict the most common sense for each word form in the test data.
    # Word forms never seen in training fall back to the overall most frequent
    # training sense, so every test item is counted (none is silently skipped).
    global_mcs = train_data['word_sense'].value_counts().idxmax()
    y_true = test_data['word_sense'].tolist()
    y_pred = [mcs_dict.get(word_form, global_mcs) for word_form in test_data['word_form']]

    # Step 4: Compute accuracy
    accuracy = accuracy_score(y_true, y_pred)
    return accuracy


In [21]:
# using MCS baseline
train_data, test_data = data_split("wsd_data.txt")
mcs_accuracy = mcs_baseline(train_data, test_data)
print(f"MCS Baseline Accuracy: {mcs_accuracy:.2%}")

MCS Baseline Accuracy: 31.93%


### Creating Data Iterators

To train a neural network, we first need to prepare the data. This involves converting words (and labels) to numbers and organizing the data into batches. We also want to be able to shuffle the examples so that they appear in a random order.

Your task is to create a dataloader for the training and test set you created previously.

You are encouraged to adjust your own dataloader you built for previous assignments. Some things to take into account:

1. Tokenize the inputs, keep a dictionary of word-to-ID and ID-to-word mappings (the vocabulary), and handle padding. You might need to do this for each of the four fields in the dataset.
2. Your dataloader probably has a function to process data. Process each column in the dataset.
3. You might want to clean the data a bit. For example, the first column contains some symbols that might be unnecessary. It is up to you whether you remove them and clean this column or keep the labels as they are. In either case, you must explain your decision and how you think it will affect the performance of your model. Data and its preprocessing matter, so motivate your decisions.
4. Organize your dataset into batches and shuffle them. You should have something akin to data iterators so that your model can take them.
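One way to check a cleaning decision for the first column is to verify that it never merges two different senses into one label. The sketch below uses a hypothetical rule (`clean_label`, which deletes every non-alphanumeric character) and two constructed sense keys chosen to show how removing the separators can make distinct keys collide.

```python
import re

def clean_label(sense_key):
    """Hypothetical cleaning rule: delete every character that is not a letter or digit."""
    return re.sub(r"[^0-9A-Za-z]", "", sense_key)

def find_collisions(labels, clean):
    """Group original labels by their cleaned form; keep groups with 2+ originals."""
    groups = {}
    for label in set(labels):
        groups.setdefault(clean(label), set()).add(label)
    return {cleaned: originals for cleaned, originals in groups.items() if len(originals) > 1}

# Constructed keys: "2:42:07" and "24:2:07" both become "24207" once the colons are gone
labels = ["keep%2:42:07::", "keep%24:2:07::", "keep%2:41:03::"]
collisions = find_collisions(labels, clean_label)
print(collisions)
```

Since `LabelEncoder` maps every distinct string to its own integer, a cleaning rule that creates no collisions leaves the classification problem unchanged; one that does merges senses and changes the task.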

Implement the dataloader and perform the necessary preprocessing.

[**2 marks**]

In [4]:
# def dataloader(path):

#     # your code goes here
#     # below are only some examples!
    
#     def __getitem__(self, idx):
        
#         return Y

#     def __len__(self):
        
#         return X
    
# def data_load(something):
    
#     return dataloader_train, dataloader_test


import torch
from torch.utils.data import Dataset, DataLoader
import pandas as pd
from sklearn.preprocessing import LabelEncoder
from torch.nn.utils.rnn import pad_sequence

class WSDDataset(Dataset):
    """Map-style dataset returning (token ids, target word index, sense id) per example."""

    def __init__(self, dataframe, word2idx, label_encoder):
        self.data = dataframe.reset_index(drop=True)
        self.word2idx = word2idx
        # Encode all sense labels once instead of calling the encoder for every item
        self.labels = label_encoder.transform(self.data['word_sense'])

    def __getitem__(self, idx):
        row = self.data.iloc[idx]
        # The context is already white-space tokenized; unseen words map to <unk>
        context_words = row['context'].split()
        word_ids = [self.word2idx.get(w, self.word2idx['<unk>']) for w in context_words]
        word_tensor = torch.tensor(word_ids, dtype=torch.long)

        target_index = int(row['word_index'])  # 0-based position of the target word
        label = int(self.labels[idx])

        return word_tensor, target_index, label

    def __len__(self):
        return len(self.data)

def collate_batch(batch):
    # Right-pad every context in the batch with <pad> (id 0)
    contexts, indices, labels = zip(*batch)
    padded_contexts = pad_sequence(contexts, batch_first=True, padding_value=0)
    indices = torch.tensor(indices, dtype=torch.long)
    labels = torch.tensor(labels, dtype=torch.long)
    return padded_contexts, indices, labels

def build_vocab(df):
    # Vocabulary built from the training contexts only; ids 0 and 1 are reserved
    word2idx = {'<pad>': 0, '<unk>': 1}
    for text in df['context']:
        for word in text.split():
            if word not in word2idx:
                word2idx[word] = len(word2idx)
    return word2idx

def data_load(train_df, test_df, batch_size=8):
    # Build vocabulary
    word2idx = build_vocab(train_df)

    # Encode labels. The encoder is fitted on the senses of both splits so that a
    # sense occurring only in the test split does not make LabelEncoder.transform
    # raise an error (no test contexts are used for training).
    label_encoder = LabelEncoder()
    label_encoder.fit(pd.concat([train_df['word_sense'], test_df['word_sense']]))

    # Create datasets
    train_dataset = WSDDataset(train_df, word2idx, label_encoder)
    test_dataset = WSDDataset(test_df, word2idx, label_encoder)

    # Create DataLoaders (the training data are reshuffled every epoch)
    dataloader_train = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_batch)
    dataloader_test = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_batch)

    return dataloader_train, dataloader_test, word2idx, label_encoder




# 2.1 LSTM for Word Sense Disambiguation

In this section, we will train an LSTM model to predict word senses based on *contextualized representations*.

You can read more about LSTMs [here](https://colah.github.io/posts/2015-08-Understanding-LSTMs/).


### Model

We will use a **bidirectional** Long Short-Term Memory (LSTM) network to create a representation for the sentences and a **linear** classifier to predict the sense of each word.

As we discussed in the lecture, a bidirectional LSTM maintains **two** hidden states: one that runs left to right and another that runs right to left. The PyTorch documentation on LSTMs can be found [here](https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html). It states that if the `bidirectional` parameter is set to `True`, then "h_n will contain a concatenation of the final forward and reverse hidden states, respectively." Keep this in mind: your linear prediction layer must accept input of that size, i.e., twice the LSTM hidden size.
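The shapes involved can be checked directly. The sketch below uses random inputs and the dimensions suggested in this section; it assumes a single-layer bidirectional `nn.LSTM` and no padding, in which case `h_n[0]` (forward) equals the forward half of the last time step's output, and `h_n[1]` (backward) equals the backward half of the first time step's output.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch_size, seq_len, emb_dim, hidden_dim = 2, 7, 32, 32

lstm = nn.LSTM(emb_dim, hidden_dim, bidirectional=True, batch_first=True)
x = torch.randn(batch_size, seq_len, emb_dim)  # stand-in for embedded contexts
output, (h_n, c_n) = lstm(x)

print(output.shape)  # torch.Size([2, 7, 64]): forward and backward states per token
print(h_n.shape)     # torch.Size([2, 2, 32]): [num_directions * num_layers, batch, hidden_dim]

# Without padding, the final forward state belongs to the last token
# and the final backward state belongs to the first token
print(torch.allclose(output[:, -1, :hidden_dim], h_n[0]))  # True
print(torch.allclose(output[:, 0, hidden_dim:], h_n[1]))   # True

# A linear layer on the concatenated final states needs 2 * hidden_dim inputs
classifier = nn.Linear(2 * hidden_dim, 10)  # 10 = hypothetical number of senses
final = torch.cat((h_n[0], h_n[1]), dim=1)
print(classifier(final).shape)  # torch.Size([2, 10])
```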

When we initialize the model, we need a few things:

1) An embedding layer: a dictionary from which we can obtain word embeddings
2) An LSTM module to obtain contextual representations
3) A classifier that computes scores for each word-sense given *some* input

The general procedure is the following:

1) For each word in the sentence, obtain word embeddings
2) Run the embedded sentences through the LSTM
3) Select the appropriate hidden state
4) Predict the word-sense 

**Suggestion for efficiency:** *Use a low dimensionality (32) for word embeddings and the LSTM when developing and testing the code, then scale up when running the full training/tests*

Your task is to create **two different models** (both follow the outlines described above).

-----

Your first model should make a prediction from the LSTM's representation of the target word.

In particular, run your LSTM on the context in which the target word is used. The LSTM produces a sequence of hidden states, one for each word of the input context; for example, a context with 37 words/elements yields 37 hidden states. Next, take the LSTM's representation of the target word. For example, if the target word is the fifth word of the context, use the fifth hidden state (note that the word indices in the dataset are 0-based, so this is index 4). This representation of the target word is the input to your linear layer, which makes the final prediction.
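Selecting "the hidden state of the target word" for every example in a batch is a gather along the time dimension. A minimal sketch with random stand-in LSTM outputs (the tensors and target positions are hypothetical): PyTorch advanced indexing and `torch.gather` select the same vectors. The word indices in the dataset are 0-based (index 15 is "keep" in the Section 1 example), which matches the positions in `context.split()`.

```python
import torch

torch.manual_seed(0)
batch_size, seq_len, two_h = 3, 6, 4
outputs = torch.randn(batch_size, seq_len, two_h)  # stand-in for LSTM outputs [B, T, 2H]
target_idx = torch.tensor([0, 5, 2])               # 0-based target position per example

# Advanced indexing: example i contributes its time step target_idx[i]
rows = torch.arange(batch_size, device=outputs.device)
target_states = outputs[rows, target_idx]          # [B, 2H]

# The same selection expressed with torch.gather along the time dimension
index = target_idx.view(-1, 1, 1).expand(-1, 1, two_h)  # [B, 1, 2H]
gathered = outputs.gather(1, index).squeeze(1)          # [B, 2H]

print(target_states.shape)  # torch.Size([3, 4])
```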

**[5 marks]**

In [6]:
# class WSDModel_approach1(nn.Module):
#     def __init__(self, ...):
        
#         # your code goes here
#         self.embeddings = ...
#         self.rnn = ...
#         self.classifier = ...
    
#     def forward(self, batch):
#         # your code goes here
        
#         return predictions

import torch
import torch.nn as nn

class WSDModel_approach1(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim, padding_idx):
        super(WSDModel_approach1, self).__init__()
        self.embeddings = nn.Embedding(vocab_size, embedding_dim, padding_idx=padding_idx)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, bidirectional=True, batch_first=True)
        self.classifier = nn.Linear(hidden_dim * 2, output_dim)

    def forward(self, batch):
        text, word_indices = batch  # text: [batch_size, seq_len], word_indices: [batch_size]
        embedded = self.embeddings(text)  # [batch_size, seq_len, embedding_dim]
        outputs, _ = self.lstm(embedded)  # [batch_size, seq_len, hidden_dim * 2]

        # Gather hidden states corresponding to target word indices
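        batch_size = outputs.size(0)
        # Create the row indices on the same device as the LSTM outputs
        batch_rows = torch.arange(batch_size, device=outputs.device)
        target_outputs = outputs[batch_rows, word_indices]  # [batch_size, hidden_dim*2]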

        predictions = self.classifier(target_outputs)  # [batch_size, output_dim]
        return predictions


Your second model should make a prediction from the final hidden state of your LSTM.

In particular, perform the same first steps as in the first approach. But then, to make a prediction with your linear layer, take the last hidden state that your LSTM produces for the whole sequence.
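For padded batches, the "last hidden state" is only meaningful if the LSTM stops at each context's real last word. A minimal sketch, assuming right-padding with `<pad>` = 0 and no `<pad>` tokens inside contexts, uses `pack_padded_sequence` so that `h_n` ignores the padding. This is one possible approach, not what `WSDModel_approach2` below does; the embedding size, vocabulary size, and toy batch are hypothetical.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence

torch.manual_seed(0)
pad_idx, emb_dim, hidden_dim = 0, 8, 8
embeddings = nn.Embedding(20, emb_dim, padding_idx=pad_idx)
lstm = nn.LSTM(emb_dim, hidden_dim, bidirectional=True, batch_first=True)

# Two contexts of length 5 and 3; the second is right-padded with <pad> = 0
text = torch.tensor([[4, 7, 2, 9, 3],
                     [5, 6, 8, 0, 0]])
lengths = (text != pad_idx).sum(dim=1)  # tensor([5, 3]); assumes no <pad> inside a context

packed = pack_padded_sequence(embeddings(text), lengths.cpu(),
                              batch_first=True, enforce_sorted=False)
_, (h_n, _) = lstm(packed)
final_hidden = torch.cat((h_n[0], h_n[1]), dim=1)  # [batch_size, 2 * hidden_dim]

# Without packing, the forward state of example 2 keeps running over the two pads
_, (h_unpacked, _) = lstm(embeddings(text))
print(torch.allclose(h_n[0][1], h_unpacked[0][1]))  # False
```

With `enforce_sorted=False`, PyTorch sorts the batch internally and returns `h_n` in the original batch order, so the concatenation can be passed straight to the classifier.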

**[5 marks]**

In [7]:
# class WSDModel_approach2(nn.Module):
#     def __init__(self, ...):
#         # your code goes here
    
#     def forward(self, ...):
#         # your code goes here
        
#         return predictions


class WSDModel_approach2(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim, padding_idx):
        super(WSDModel_approach2, self).__init__()
        self.embeddings = nn.Embedding(vocab_size, embedding_dim, padding_idx=padding_idx)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, bidirectional=True, batch_first=True)
        self.classifier = nn.Linear(hidden_dim * 2, output_dim)

    def forward(self, batch):
        text, _ = batch  # word index not used here
        embedded = self.embeddings(text)
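        # Note: the padded batch is passed without packing, so for contexts shorter
        # than the longest one in the batch, the final states are also computed
        # over trailing <pad> tokens.
        _, (hidden, _) = self.lstm(embedded)

        # hidden shape: [2, batch_size, hidden_dim] for a single-layer bidirectional LSTM
        # Concatenate the final forward (hidden[0]) and backward (hidden[1]) states
        final_hidden = torch.cat((hidden[0], hidden[1]), dim=1)  # [batch_size, hidden_dim * 2]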

        predictions = self.classifier(final_hidden)
        return predictions


### Training and Testing the Model

Now we are ready to train and test our model. What we need now is a loss function, an optimizer, and our data. 

- First, create the loss function and the optimizer.
- Next, iterate over the number of epochs (i.e., how many times we let the model see our data). 
- For each epoch, iterate over the dataset to obtain batches. Use the batch as input to the model, and let the model output scores for the different word senses.
- For each model output, compute (and print) the loss, then update the model parameters.
- Reset the gradients and repeat.
- After all epochs are done, test your trained model on the test set and calculate the total and per-word-form accuracy of your model.
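The per-word-form accuracy can be computed from aligned lists of word forms, gold senses, and predictions. Because the test loader in `train_and_evaluate` below uses `shuffle=False`, predictions come out in the row order of `test_data`, so `test_data['word_form']` can be zipped with them (assuming the function is extended to return `all_labels` and `all_preds`). A minimal sketch in plain Python; `per_word_form_accuracy` is a hypothetical helper and the values are toy data.

```python
from collections import defaultdict

def per_word_form_accuracy(word_forms, y_true, y_pred):
    """Accuracy for each word form; the three lists must be aligned item by item."""
    correct, total = defaultdict(int), defaultdict(int)
    for form, gold, pred in zip(word_forms, y_true, y_pred):
        total[form] += 1
        correct[form] += int(gold == pred)
    return {form: correct[form] / total[form] for form in total}

# Hypothetical toy predictions (sense ids as produced by a LabelEncoder)
forms = ["keep.v", "keep.v", "bank.n", "bank.n", "bank.n", "bank.n"]
gold = [0, 1, 2, 2, 3, 2]
pred = [0, 0, 2, 2, 2, 2]
print(per_word_form_accuracy(forms, gold, pred))  # {'keep.v': 0.5, 'bank.n': 0.75}
```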

Implement the training and testing of the model.

**[4 marks]**

**Suggestion for efficiency:** *When developing your model, try training and testing the model on one or two batches (for each epoch) of data to make sure everything works! It's very annoying if you train for N epochs to find out that something went wrong when testing the model, or to find that something goes wrong when moving from epoch 0 to epoch 1.*

Do not forget to save your best models as .pickle files. The results should be reproducible for us to evaluate your models.
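Reproducibility mostly comes down to fixing the random seeds before building the model and the data loaders (weight initialisation and `DataLoader` shuffling both draw from these generators). A minimal sketch; `set_seed` is a hypothetical helper, and bit-identical results on a GPU may need further settings (for example `torch.use_deterministic_algorithms`) that are not shown here.

```python
import random

import numpy as np
import torch

def set_seed(seed=42):
    """Seed the Python, NumPy and PyTorch random generators used during training."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)           # seeds the CPU and all CUDA generators
    torch.cuda.manual_seed_all(seed)  # explicit; silently ignored without CUDA
    # Prefer deterministic cuDNN kernels over the fastest ones
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(42)
first = torch.rand(3)
set_seed(42)
second = torch.rand(3)
print(torch.equal(first, second))  # True
```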


In [8]:
# train_iter, test_iter, vocab, labels = dataloader(path_to_folder)

# loss_function = ...
# optimizer = ...
# model = ...

# for _ in range(epochs):
#     # train model
#     ...
    
# # test model after all epochs are completed



import pickle

import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import accuracy_score

def train_and_evaluate(dataset_path, model_class, epochs=3, lr=0.001, batch_size=4, save_path="best_model.pkl"):
    # Load data: the same reproducible 80/20 split as the baseline (random_state=42)
    train_data, test_data = data_split(dataset_path)
    train_iter, test_iter, vocab, label_encoder = data_load(train_data, test_data, batch_size=batch_size)

    # Model, loss, optimizer
    model = model_class(
        vocab_size=len(vocab),
        embedding_dim=32,
        hidden_dim=32,
        output_dim=len(label_encoder.classes_),
        padding_idx=vocab['<pad>']
    )

    loss_function = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # Training loop
    for epoch in range(epochs):
        model.train()
        total_loss = 0

        for batch in train_iter:
            inputs, word_indices, labels = [x.to(device) for x in batch]

            optimizer.zero_grad()                    # reset gradients from the previous step
            outputs = model((inputs, word_indices))  # shape: [batch_size, num_classes]
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        # total_loss is summed over all batches of the epoch (not averaged)
        print(f"Epoch {epoch+1}/{epochs} | Loss: {total_loss:.4f}")

    # Evaluation on test data
    model.eval()
    all_preds, all_labels = [], []

    with torch.no_grad():
        for batch in test_iter:
            inputs, word_indices, labels = [x.to(device) for x in batch]
            outputs = model((inputs, word_indices))
            # .tolist() moves results to plain Python ints without going through NumPy
            preds = torch.argmax(outputs, dim=1).cpu().tolist()
            all_preds.extend(preds)
            all_labels.extend(labels.cpu().tolist())

    accuracy = accuracy_score(all_labels, all_preds)
    print(f"Test Accuracy: {accuracy:.2%}")

    # Save the parameters of the final model (the weights after the last epoch)
    with open(save_path, "wb") as f:
        pickle.dump(model.state_dict(), f)

    return model



In [15]:
# using model approach 1
trained_model = train_and_evaluate("wsd_data.txt", WSDModel_approach1, epochs=5, lr=0.001, batch_size=4, save_path="WSDModel_approach1_best_model.pkl")

Epoch 1/5 | Loss: 22824.9954
Epoch 2/5 | Loss: 14674.2469
Epoch 3/5 | Loss: 12468.6999
Epoch 4/5 | Loss: 10776.5447
Epoch 5/5 | Loss: 9302.1645
Test Accuracy: 67.97%


In [16]:
# using model approach 2
trained_model2 = train_and_evaluate("wsd_data.txt", WSDModel_approach2, epochs=5, lr=0.001, batch_size=4, save_path="WSDModel_approach2_best_model.pkl")

Epoch 1/5 | Loss: 73954.5119
Epoch 2/5 | Loss: 67911.8845
Epoch 3/5 | Loss: 60084.2378
Epoch 4/5 | Loss: 51437.4014
Epoch 5/5 | Loss: 42160.7167
Test Accuracy: 34.40%


# 2.2 Fine-tuning and Testing BERT for Word Sense Disambiguation

In this section of the lab, you'll try out the transformer, specifically the BERT model. For this, we'll use the Hugging Face library ([https://huggingface.co/](https://huggingface.co/)).

You can find the documentation for the BERT model [here](https://huggingface.co/transformers/model_doc/bert.html) and a general usage guide [here](https://huggingface.co/transformers/v2.9.1/quickstart.html).

What we're going to do is *fine-tune* the BERT model, i.e., update the weights of a pre-trained model. That is, we have a model that is pre-trained on masked language modeling and next sentence prediction (kind of basic, general tasks which are useful for a lot of more specific tasks), but now we apply it to word sense disambiguation with the word representations it has learned.

We'll use the same data splits for training and testing as before, but this time you will use a different dataloader.

Now create an iterator that collects N sentences (where N is the batch size) and uses the `BertTokenizer` to transform each sentence into integer token IDs. For your dataloader, remember to:
* Shuffle the training data so that the batches differ between epochs
* Make sure you get a new iterator for each *epoch*
* Create a vocabulary of *sense-labels* so you can calculate accuracy
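A standard PyTorch `DataLoader` with `shuffle=True` already covers the first two points: every new `for` loop over the loader creates a fresh iterator with a new random permutation. A small sketch, using a plain Python list of ten integers as a stand-in dataset:

```python
import torch
from torch.utils.data import DataLoader

torch.manual_seed(0)
items = list(range(10))  # stand-in dataset: a list has __getitem__ and __len__
loader = DataLoader(items, batch_size=4, shuffle=True)

epoch_orders = []
for epoch in range(2):
    # Each `for` over the DataLoader creates a new iterator with a new permutation
    order = [x for batch in loader for x in batch.tolist()]
    epoch_orders.append(order)
    print(epoch, order)
```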

We then pass this batch into the BERT model (you must have pre-loaded its weights) and update the weights (fine-tune). The BERT model will encode the sentence, then we send this encoded sentence into a prediction layer and collect what it outputs.

As input to the prediction layer, you are free to experiment with different types of information. The expected choice is the [CLS] representation, but you can also use other representations (for example, that of the target word's tokens) and compare them.
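To use the target word instead of [CLS], one has to know where the target word ends up after WordPiece tokenization, since a single word can be split into several sub-tokens. With a fast tokenizer (`BertTokenizerFast`), encoding the pre-split context with `is_split_into_words=True` and calling `word_ids()` on the result gives, for each token position, the index of the word it came from (`None` for special tokens and padding); the slow `BertTokenizer` used below does not provide this. The helper is a hypothetical sketch that takes such a list, and the alignment in the example is made up.

```python
def first_subword_position(word_ids, target_word_index):
    """Return the position of the first sub-token of the target word.

    `word_ids` maps each token position to the index of the word it came from,
    with None for special tokens such as [CLS], [SEP] and padding.
    """
    for position, word_id in enumerate(word_ids):
        if word_id == target_word_index:
            return position
    return None  # the target word was truncated away (e.g. beyond max_length)

# Made-up alignment for "I swam from the river bank" if "swam" were split into two sub-tokens
word_ids = [None, 0, 1, 1, 2, 3, 4, 5, None]
print(first_subword_position(word_ids, 5))  # 7
```

The returned position can then replace index 0 in `outputs.last_hidden_state[:, 0, :]` for the corresponding example. A `None` result signals that the target word was cut off by truncation (the dataset below uses `max_length=128`).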

About the hyperparameters and training:
* For BERT, a lower learning rate usually works best, between 1e-4 and 1e-6.
* BERT takes a lot of resources, and running it on a CPU will take ages, so use the GPUs :)
* Since BERT takes a lot of resources, use a small batch size (4-8)
* When computing BERT representations, make sure you pass the attention mask

**[12 marks]**

In [32]:
# def dataloader_for_bert(path_to_file, batch_size):
#     ...

import torch
from transformers import BertTokenizer
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import LabelEncoder
import pandas as pd

class WSD_BERT_Dataset(Dataset):
    def __init__(self, dataframe, tokenizer, label_encoder):
        self.tokenizer = tokenizer
        self.label_encoder = label_encoder
        self.contexts = dataframe['context'].tolist()
        self.labels = label_encoder.transform(dataframe['word_sense'])

    def __len__(self):
        return len(self.contexts)

    def __getitem__(self, idx):
        encoding = self.tokenizer(self.contexts[idx],
                                  padding='max_length',
                                  truncation=True,
                                  max_length=128,
                                  return_tensors='pt')
        input_ids = encoding['input_ids'].squeeze(0)
        attention_mask = encoding['attention_mask'].squeeze(0)
        # label = self.labels[idx]
        # label = int(self.labels[idx][0])
        label = int(self.labels[idx])
        return input_ids, attention_mask, label


def dataloader_for_bert(path_to_file, batch_size):
    # Reuse the same reproducible 80/20 split as the LSTM experiments
    # (data_split shuffles with random_state=42)
    train_df, test_df = data_split(path_to_file)

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

    # Vocabulary of sense labels, fitted on the senses of both splits so that a
    # sense occurring only in the test split does not make transform() fail
    label_encoder = LabelEncoder()
    label_encoder.fit(pd.concat([train_df['word_sense'], test_df['word_sense']]))

    train_dataset = WSD_BERT_Dataset(train_df, tokenizer, label_encoder)
    test_dataset = WSD_BERT_Dataset(test_df, tokenizer, label_encoder)

    # shuffle=True draws a new random order every time the loader is iterated (every epoch)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    return train_loader, test_loader, label_encoder


In [33]:
# class BERT_WSD(nn.Module):
#     def __init__(self, ...):
#         # your code goes here
#         self.bert = ...
#         self.classifier = ...
    
#     def forward(self, batch):
#         # your code goes here
        
#         return predictions


from transformers import BertModel

class BERT_WSD(nn.Module):
    def __init__(self, num_classes):
        super(BERT_WSD, self).__init__()
        # The bare pre-trained encoder exposes last_hidden_state. (A
        # BertForSequenceClassification model returns logits from its own head
        # and has no last_hidden_state, so it cannot be used this way.)
        self.bert = BertModel.from_pretrained("bert-base-uncased")
        self.classifier = nn.Linear(self.bert.config.hidden_size, num_classes)

    def forward(self, batch):
        input_ids, attention_mask = batch
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        cls_output = outputs.last_hidden_state[:, 0, :]  # [CLS] token: [batch_size, hidden_size]
        predictions = self.classifier(cls_output)
        return predictions


In [34]:
# loss_function = ...
# optimizer = ...
# model = ...

# for _ in range(epochs):
#     # train model
#     ...
    
# # test model after all epochs are completed

from sklearn.metrics import accuracy_score, f1_score
import torch
import torch.nn as nn
import torch.optim as optim

def evaluate_bert_model(model, dataloader, device):
    model.eval()
    y_true = []
    y_pred = []

    with torch.no_grad():
        for input_ids, attention_mask, labels in dataloader:
            input_ids, attention_mask = input_ids.to(device), attention_mask.to(device)
            outputs = model((input_ids, attention_mask))
            preds = torch.argmax(outputs, dim=1).cpu().tolist()
            y_pred.extend(preds)
            y_true.extend(labels.tolist())

    acc = accuracy_score(y_true, y_pred)
    f1 = f1_score(y_true, y_pred, average='macro')
    return acc, f1


def train_bert_model(path_to_file, batch_size=8, epochs=3, lr=1e-5, save_path="bert_wsd_best.pt"):
    train_loader, test_loader, label_encoder = dataloader_for_bert(path_to_file, batch_size)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = BERT_WSD(num_classes=len(label_encoder.classes_)).to(device)

    optimizer = optim.AdamW(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    best_accuracy = 0

    for epoch in range(epochs):
        model.train()
        total_loss = 0

        for input_ids, attention_mask, labels in train_loader:
            input_ids, attention_mask, labels = input_ids.to(device), attention_mask.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model((input_ids, attention_mask))
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        print(f"Epoch {epoch+1}/{epochs}, Loss: {total_loss:.4f}")

        # Evaluate on the held-out test split after each epoch
        acc, f1 = evaluate_bert_model(model, test_loader, device)
        print(f" → Test Acc: {acc:.2%}, Macro-F1: {f1:.4f}")

        # Save the checkpoint with the best test accuracy so far
        if acc > best_accuracy:
            best_accuracy = acc
            torch.save(model.state_dict(), save_path)
            print(f" ✓ Best model saved to {save_path}")

    print(f"\nBest Accuracy: {best_accuracy:.2%}")
    return model, label_encoder




In [None]:
trained_model_bert = train_bert_model("wsd_data.txt", batch_size=4, epochs=1, lr=1e-5)

# 3. Evaluation

Explain the difference between the two LSTMs that you have implemented for word sense disambiguation.

Important note: your two LSTMs should be nearly identical, but their linear layers must take different inputs. Describe why and how you think this difference will affect the performance of the two LSTMs. How does the contextual representation of the whole sequence perform? How does the representation of the target word perform? Which is better, and in which situations? Why do we observe these differences?

What kind of representations are the different approaches using to predict word senses?

**[4 marks]**

Evaluate your models with per-word-form *accuracy* and comment on the results you get. How do the models perform in comparison to the baseline, and how do they compare to each other?

Expand on the evaluation by sorting the word-forms by the number of senses they have. Are word forms with fewer senses easier to predict? Give a short explanation of the results you get based on the number of senses per word.
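Sorting by the number of senses needs a sense inventory per word form, for example counted from the training split. A minimal sketch with hypothetical toy labels and accuracies; `accuracy_by_form` could be the output of a per-word-form accuracy computation, and both helper names are made up.

```python
from collections import defaultdict

def senses_per_word_form(word_forms, senses):
    """Count the distinct senses observed for each word form."""
    inventory = defaultdict(set)
    for form, sense in zip(word_forms, senses):
        inventory[form].add(sense)
    return {form: len(observed) for form, observed in inventory.items()}

def sort_by_num_senses(accuracy_by_form, num_senses):
    """Return (word form, number of senses, accuracy) rows, fewest senses first."""
    rows = [(form, num_senses[form], acc) for form, acc in accuracy_by_form.items()]
    return sorted(rows, key=lambda row: (row[1], row[0]))

# Hypothetical toy data: sense labels and per-word-form accuracies are made up
forms = ["keep.v", "keep.v", "keep.v", "bank.n", "bank.n"]
senses = ["keep_1", "keep_2", "keep_3", "bank_1", "bank_1"]
accuracy = {"keep.v": 0.4, "bank.n": 0.9}

num_senses = senses_per_word_form(forms, senses)
for form, n, acc in sort_by_num_senses(accuracy, num_senses):
    print(f"{form:8s} senses={n}  accuracy={acc:.2f}")
```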

**[4 marks]**

How do the LSTMs perform in comparison to BERT? What's the difference between representations obtained by the LSTMs and BERT?

**[4 marks]**

What could we do to improve all WSD models that we have worked with in this assignment?

**[2 marks]**

# Readings

[1] Kågebäck, M., & Salomonsson, H. (2016). Word Sense Disambiguation using a Bidirectional LSTM. arXiv preprint arXiv:1606.03568.

[2] Slides on word sense disambiguation (Chapter 18): https://web.stanford.edu/~jurafsky/slp3/slides/Chapter18.wsd.pdf

## Statement of contribution

Briefly state how many times you have met for discussions, who was present, to what degree each member contributed to the discussion and the final answers you are submitting.

## Marks

This assignment has a total of 46 marks.