# A4: LSTMs and Transformers for Word Sense Disambiguation

by Nikolai Ilinykh, Adam Ek, and others.

The lab is an exploration and learning exercise to be done in a group and also in discussion with the teachers and other students.

Write all your answers and the code in the appropriate boxes below.


A problem with static distributional vectors is the difficulty of distinguishing between different *word senses*. We will continue our exploration of word vectors by considering *trainable vectors* or *word embeddings* for Word Sense Disambiguation (WSD). We will work with both LSTMs and transformer models, e.g. BERT. The purpose of the assignment is to learn use representations neural models in a downstream task of word sense disambiguation.


## Word Sense Disambiguation Task

The goal of word sense disambiguation is to train a model to find the sense of a word (homonyms of a word-form). For example, the word "bank" can mean "sloping land" or "financial institution". 

(a) "I deposited my money in the **bank**" (financial institution)

(b) "I swam from the river **bank**" (sloping land)

In case a) and b), we can determine the meaning of "bank" based on the *context*. To utilize context in a semantic model, we use *contextualized word representations*.

Previously, we worked with *static word representations*, i.e., the representation does not depend on the context. To illustrate, we can consider sentences (a) and (b), where the word **bank** would have the same static representation in both sentences, which means that it becomes difficult for us to predict its sense. What we want is to create representations that depend on the context, i.e., *contextualized embeddings*.

As we have discussed in the class, contextualized representations can come in the form of pre-training the model for some "general" task and then fine-tuning it for some downstream task. Here we will do the following:

(1) Train and test LSTM model directly for word sense disambiguation. We will learn contextualized representations within this model.

(2) Take BERT that was pre-trained on masked language modeling and next sentence prediction. Fine-tune it on our data and test it for the word sense disambiguation on the task dataset. The idea for you is to explore how pre-trained contextualized representations from BERT can be updated and used for the downstream task of word sense disambiguation.

Your overall task in this lab is to create a neural network model that can disambiguate the word sense of 30 different words.

In [12]:
# first we import some packages that we need

# here add any package that you will need later

import torch
import torch.nn as nn
#import torchtext
import pandas as pd

# our hyperparameters (add more when/if you need them)
device = torch.device('cuda')

batch_size = 8
learning_rate = 0.000001
epochs = 5

# 1. Working with Data

A central part of any machine learning system is the data we're working with.

In this section, we will split the data (the dataset is in `wsd_data.txt`) into a training set and a test set.


## Data

The dataset we will use contains different word senses for 30 different words. The data is organized as follows (values separated by tabs), where each line is a separate item in the dataset:

- Column 1: word-sense, e.g., keep%2:42:07::
- Column 2: word-form, e.g., keep.v
- Column 3: index of word, e.g., 15
- Column 4: white-space tokenized context, e.g., Action by the Committee In pursuance of its mandate , the Committee will continue to keep under review the situation relating to the question of Palestine and participate in relevant meetings of the General Assembly and the Security Council . The Committee will also continue to monitor the situation on the ground and draw the attention of the international community to urgent developments in the Occupied Palestinian Territory , including East Jerusalem , requiring international action .


### Splitting the Data

Your first task is to separate the data into a *training set* and a *test set*.

The training set should contain 80% of the examples, and the test set the remaining 20%.

The examples for the test/training set should be selected **randomly**.

Save each dataset into a .csv file for loading later.

**[2 marks]**

In [13]:
from sklearn.model_selection import train_test_split
from torch.utils.data import Subset
import csv

def data_split(dataset_path):
    with open (dataset_path) as file:
        data = file.readlines()
    train_split, test_split = train_test_split(data,test_size=0.2,random_state=42)
    #print(train_split[:20])

    with open('train_data.csv', 'w') as csvfile:
        writer = csv.writer(csvfile, quotechar=" ", quoting=csv.QUOTE_MINIMAL)
        for row in train_split:
            writer.writerow([row])
    with open('test_data.csv', 'w') as csvfile:
        writer = csv.writer(csvfile, quotechar=" ",quoting=csv.QUOTE_MINIMAL)
        for row in test_split:
            writer.writerow([row])

    return train_split,test_split

train_data,test_data = data_split('./wsd_data.txt')    #REMEMBER TO CHANGE BACK!!!!

### Creating a Baseline

Your second task is to create a *baseline* for the task.

A baseline is a "reality check" for a model. Given a very simple heuristic/algorithmic/model solution to the problem, can our neural network perform better than this? Baselines are important as they give us a point of comparison for the actual models. They are commonly used in NLP. Sometimes baseline models are not simple models but previous state-of-the-art.

In this exercise, we will have a simple baseline model that is the "most common sense" (MCS) baseline. For each word form, find the most commonly assigned sense to the word and label a word with that sense. In a fictional dataset, "bank" has two senses: "financial institution," which occurs 5 times, and "side of the river," which occurs 3 times. Thus, all 8 occurrences of "bank" are labeled "financial institution," yielding an MCS accuracy of 5/8 = 62.5%. If a model obtains a higher score than this, we can conclude that the model *at least* is better than selecting the most frequent word sense.

Your task is to write the code for this baseline, train, and test it. The baseline has the knowledge about labels and their frequency only from the train data. You evaluate it on the test data by comparing the ground-truth sense with the one that the model predicts. A good "dumb" baseline in this case is the one that performs quite badly. Expect the model to perform around 0.30 in terms of accuracy. You should use accuracy as your main metric; you can also compute the F1-score.

**[2 marks]**


In [22]:
import pandas as pd
import re

def mcs_baseline(data):
    wordform_dict = {}
    final_dict = {}
    train = pd.read_csv('train_data.csv',delimiter='\\t',header=None,names=["word_sense", "word_form","index", "context"])
    for index, row in train.iterrows():
        if row["word_sense"] not in wordform_dict.keys():
            wordform_dict[row["word_sense"]] = 0
        wordform_dict[row["word_sense"]] += 1

    for index, row in train.iterrows():
        senses_list = []
        for word_sense,frequency in wordform_dict.items():
            if word_sense[:3] in row["word_form"]:
                senses_list.append((word_sense,frequency))
        final_dict[row["word_form"]] = senses_list

    for word_form,word_sense in final_dict.items():
        most_common_sense = 0
        most_common_word = " "
        for most_common in word_sense:
            if most_common[1] > most_common_sense:
                most_common_sense = most_common[1]
                most_common_word = most_common[0]
        final_dict[word_form] = most_common_word

    #getting the accuracy with always predicting the most common label 

    test = pd.read_csv('test_data.csv',delimiter='\\t',header=None,names=["word_sense", "word_form","index", "context"])
    corrects = 0
    for index, row in test.iterrows():
        if row["word_sense"] in final_dict.values():
            corrects+=1
    return corrects/len(test["word_form"]) 

mcs_baseline(train_data)

  train = pd.read_csv('train_data.csv',delimiter='\\t',header=None,names=["word_sense", "word_form","index", "context"])
  test = pd.read_csv('test_data.csv',delimiter='\\t',header=None,names=["word_sense", "word_form","index", "context"])


0.30723208415516107

### Creating Data Iterators

To train a neural network, we first need to prepare the data. This involves converting words (and labels) to a number and organizing the data into batches. We also want the ability to shuffle the examples such that they appear in a random order.

Your task is to create a dataloader for the training and test set you created previously.

You are encouraged to adjust your own dataloader you built for previous assignments. Some things to take into account:

1. Tokenize inputs, keep a dictionary of word-to-IDs and IDs-to-words (vocabulary), fix paddings. You might need to consider doing these for each of the four fields in the dataset.
2. Your dataloader probably has a function to process data. Process each column in the dataset.
3. You might want to clean the data a bit. For example, the first column has some symbols, which might be unnecessary. It is up to you whether you want to remove them and clean this column or keep labels the way they are. In any case, you must provide an explanation of your decision and how you think it will affect the performance of your model. Data and its preprocessing matters, so motivate your decisions.
4. Organize your dataset into batches and shuffle them. You should have something akin to data iterators so that your model can take them.

Implement the dataloader and perform necessary preprocessings.

[**2 marks**]

In [14]:
from torch.utils.data import Dataset
from tqdm import tqdm

class WSD_Dataset(Dataset):
    
    def __init__(self, csv_file):
        
        self.file = pd.read_csv(csv_file, delimiter='\\t',header=None,names=["word_sense", "word_form","index", "context"])

        self.senses = [sense.replace(":", "").replace("%", "") for sense in self.file['word_sense']]
        self.words = [line for line in self.file['word_form']]
        self.indices = [int(x) for x in self.file['index']]
        self.contexts = [line.lower().split() for line in self.file['context']]


        # vocab for senses
        self.sense_vocab_s2i = {s: num for num, s in enumerate(list(set(self.senses)))}
        self.sense_vocab_i2s = {num: s for num, s in enumerate(list(set(self.senses)))}
        
        
        # padding
        self.context_lengths = []
        for c in self.contexts:
            self.context_lengths.append(len(c))
        pad_number = max(self.context_lengths)
        
        self.padded_contexts = []
        for c in tqdm(self.contexts):
            while len(c) < pad_number:
                c.append('<pad>')
            self.padded_contexts.append(c)

        # create vocabulary for contexts
        self.tokens = list(set([token for line in self.padded_contexts for token in line]))
        # vocab for context
        self.context_w2i = {t: num for num, t in enumerate(self.tokens)}
        self.context_i2w = {num: t for num, t in enumerate(self.tokens)}
        
        # encode input contexts

        self.int_contexts = []
        for c in self.padded_contexts:
            encoded_c = []
            for w in c:
                encoded_c.append(self.context_w2i[w])
            self.int_contexts.append(encoded_c)

        self.int_senses = []
        for s in self.senses:
            self.int_senses.append([self.sense_vocab_s2i[s]])
            
            
        
    def __getitem__(self, idx):
        
        #print(self.int_contexts[idx])
        
        return ([self.int_contexts[idx]], self.int_senses[idx], self.indices[idx])
    
    def __len__(self):
        return len(self.int_senses)



In [15]:
dataset = WSD_Dataset('./train_data.csv')
test_dataset = WSD_Dataset('./test_data.csv')

  self.file = pd.read_csv(csv_file, delimiter='\\t',header=None,names=["word_sense", "word_form","index", "context"])
100%|██████████████████████████████████████████████████████████████████████████| 60839/60839 [00:00<00:00, 79214.39it/s]
  self.file = pd.read_csv(csv_file, delimiter='\\t',header=None,names=["word_sense", "word_form","index", "context"])
100%|██████████████████████████████████████████████████████████████████████████| 15210/15210 [00:00<00:00, 82335.83it/s]


In [64]:
len(dataset.int_contexts)

60839

In [16]:
from torch.utils.data import DataLoader


dataloader = DataLoader(dataset, batch_size=8, shuffle=True)
dataloader_test = DataLoader(test_dataset, batch_size=8, shuffle=True)


In [67]:
len(dataset.tokens)

70237

# 2.1 LSTM for Word Sense Disambiguation

In this section, we will train an LSTM model to predict word senses based on *contextualized representations*.

You can read more about LSTMs [here](https://colah.github.io/posts/2015-08-Understanding-LSTMs/).


### Model

We will use a **bidirectional** Long Short-Term Memory (LSTM) network to create a representation for the sentences and a **linear** classifier to predict the sense of each word.

As we discussed in the lecture, bidirectional LSTM is using **two** hidden states: one that goes in the left-to-right direction, and another one that goes in the right-to-left direction. PyTorch documentation on LSTMs can be found [here](https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html). It says that if the bidirectional parameter is set to True, then "h_n will contain a concatenation of the final forward and reverse hidden states, respectively." Keep it in mind because you will have to ensure that your linear layer for prediction takes input of that size.

When we initialize the model, we need a few things:

1) An embedding layer: a dictionary from which we can obtain word embeddings
2) A LSTM-module to obtain contextual representations
3) A classifier that computes scores for each word-sense given *some* input

The general procedure is the following:

1) For each word in the sentence, obtain word embeddings
2) Run the embedded sentences through the LSTM
3) Select the appropriate hidden state
4) Predict the word-sense 

**Suggestion for efficiency:** *Use a low dimensionality (32) for word embeddings and the LSTM when developing and testing the code, then scale up when running the full training/tests*

Your tasks will be to create **two different models** (both follow the two outlines described above).

-----

Your first model should make a prediction from the LSTM's representation of the target word.

In particular, you run your LSTM on the context in which the target word is used. LSTM will produce a sequence of hidden states. Each hidden state corresponds to a single word from the input context. For example, you should be able to get 37 hidden states for a context that has 37 words/elements in it. Next, take the LSTM's representation of the target word. For example, it can be hidden state number 5, because the fifth word in your context is the target word that you want to predict the meaning for. This target's word representation is the input to your linear layer that makes the final prediction.

**[5 marks]**

In [17]:
class WSDModel_approach1(nn.Module):
    
    def __init__(self, vocab_size, hidden_size, embedding_dim, vocab_sense_size): #hidden size 512
        
        super(WSDModel_approach1,self).__init__()
        
        self.embeddings = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim)
        
        self.rnn = nn.LSTM(embedding_dim, hidden_size, num_layers=1, bidirectional=True) #, batch_first=True)
        
        self.classifier = nn.Linear(in_features=hidden_size*2, out_features=vocab_sense_size)
        
    def forward(self, input_seqs):
        # Translate batches into embeddings
        
        embedded_context = self.embeddings(input_seqs.to(device))
                
        lstm_out, (h, c) = self.rnn(embedded_context.to(device))
        #print(lstm_out.shape)     
        target_hidden = torch.zeros((8, 1, 1024)) # batch_size, context, hidden_size
        for num, ind in enumerate(indices.to(device)):            
            target_hidden[num, :, :] = lstm_out[num, ind, :]    #problem possibly here?
        #print(target_hidden.shape)
        predictions = self.classifier(target_hidden.to(device))
              
        return predictions


In [18]:
model1 = WSDModel_approach1(len(dataset.tokens), 512, 512, len(dataset.sense_vocab_s2i))

In [19]:
#model 1

import numpy as np

for batch in dataloader:
    
    input_seqs = [torch.stack(elem) for elem in batch[0]]

        
    gold_labels = batch[1][0].to(device)
    indices = batch[2].long().to(device)
    #print(indices)
    #print("indicesshape",indices.shape)
    input_seqs = input_seqs[0].to(device)
    
    #print(input_seqs.shape)
    input_seqs = input_seqs.permute(1, 0).to(device)
    out = model1(input_seqs).squeeze(1)#.detach().numpy()
    print('out',out.shape)
    break

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument index in method wrapper_CUDA__index_select)

Your second model should make a prediction from the final hidden state of your LSTM.

In particular, do the same first steps as in the first approach. But then to make a prediction with your linear layer, you will need to take the last hidden state that your LSTM produces for the whole sequence.

**[5 marks]**

In [22]:
class WSDModel_approach2(nn.Module):
    
    def __init__(self, vocab_size, hidden_size, embedding_dim, vocab_sense_size):
        
        super(WSDModel_approach2,self).__init__()
        
        self.embeddings = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim)
        
        self.rnn = nn.LSTM(embedding_dim, hidden_size, num_layers=1, bidirectional=True, batch_first=True)
        
        self.classifier = nn.Linear(in_features=hidden_size*2, out_features=vocab_sense_size)
        
    def forward(self, input_seqs):
        # Translate batches into embeddings
        
        embedded_context = self.embeddings(input_seqs)
                
        lstm_out, (h, c) = self.rnn(embedded_context)
        #print("this is h.shape", h.shape)
        
        h = h.permute(1, 0, 2)
        h = h.reshape(8, 1024)   
        #print(h.shape)

        predictions = self.classifier(h)
              
        return predictions

In [23]:
model2 = WSDModel_approach2(len(dataset.tokens), 512, 512, len(dataset.sense_vocab_s2i))

In [31]:
# model 2 

import numpy as np

for batch in dataloader:
    
    input_seqs = [torch.stack(elem) for elem in batch[0]]

        
    gold_labels = batch[1][0]
    indices = batch[2].long()
    input_seqs = input_seqs[0]
    #print(len(input_seqs))
    #print(indices)
    input_seqs = input_seqs.permute(1, 0)
    
    
    out = model2(input_seqs).squeeze(1).detach().numpy()
    #print("outshape",out.shape)
    
    #print(out[0])

    max_prediction = np.argmax(out, axis=1)
    #print(max_prediction)
    
    #print(dataset.sense_vocab_i2s[max_prediction[0]])
    
    
    break
    

### Training and Testing the Model

Now we are ready to train and test our model. What we need now is a loss function, an optimizer, and our data. 

- First, create the loss function and the optimizer.
- Next, iterate over the number of epochs (i.e., how many times we let the model see our data). 
- For each epoch, iterate over the dataset to obtain batches. Use the batch as input to the model, and let the model output scores for the different word senses.
- For each model output, calculate the loss (and print the loss) on the output and update the model parameters.
- Reset the gradients and repeat.
- After all epochs are done, test your trained model on the test set and calculate the total and per-word-form accuracy of your model.

Implement the training and testing of the model.

**[4 marks]**

**Suggestion for efficiency:** *When developing your model, try training and testing the model on one or two batches (for each epoch) of data to make sure everything works! It's very annoying if you train for N epochs to find out that something went wrong when testing the model, or to find that something goes wrong when moving from epoch 0 to epoch 1.*

Do not forget to save your best models as .pickle files. The results should be reproducible for us to evaluate your models.


In [21]:
import torch.optim as optim
from torch.nn import CrossEntropyLoss
device = torch.device("cuda")
loss_function = CrossEntropyLoss()
model1 = model1.to(device)
optimizer = optim.AdamW(model1.parameters(),lr=0.00001)

#   training loop for the first model
for epoch in range(epochs):
    total_loss = 0
    for batch in dataloader:
        optimizer.zero_grad()
        input_seqs = [torch.stack(elem) for elem in batch[0]]  # contexts
        gold_labels = batch[1][0].to(device)   # gold label in the batch
        indices = batch[2].long().to(device)   # indices of the target words??
        input_seqs = input_seqs[0].to(device) # contexts
        input_seqs = input_seqs.permute(1, 0).to(device)
        outputs = model1(input_seqs)#.squeeze(1).to(device)
        outputs = outputs.view(outputs.size(0), -1)
        if gold_labels.shape[0] != outputs.shape[0]:
            print(f"Skipping batch due to shape mismatch: {gold_labels.shape} vs {outputs.shape}") #why is there a mismatch we are using dataloaders
            continue
        loss = loss_function(outputs.to(device), gold_labels.to(device))
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f'Epoch {epoch + 1}, Loss : {total_loss/len(dataloader)}')


# testing the model after all epochs are completed

correct = 0
with torch.no_grad():
    for batch in dataloader_test:
        input_seqs = [torch.stack(elem) for elem in batch[0]]   # contexts
        gold_labels = batch[1][0].to(device)  # gold label in the batch
        indices = batch[2].to(device).long()  # indices of the target words
        input_seqs = input_seqs[0].to(device)  # contexts
        input_seqs = input_seqs.permute(1, 0).to(device)
        outputs = model1(input_seqs).squeeze(1)    
        max_prediction = torch.argmax(outputs, dim=1)#.cpu().numpy()
        if max_prediction.size() != gold_labels.size():
            print(f"Skipping batch due to shape mismatch: {max_prediction.size()} vs {gold_labels.size()}") #also why is there mismatch here
            continue
        correct+=torch.sum(max_prediction == gold_labels).item()
    print("correct:",correct)
print("accuracy:", correct/len(dataloader_test))

    

Skipping batch due to shape mismatch: torch.Size([7]) vs torch.Size([8, 222])
Epoch 1, Loss : 1.786867834724618
Skipping batch due to shape mismatch: torch.Size([7]) vs torch.Size([8, 222])
Epoch 2, Loss : 1.5498475718545255
Skipping batch due to shape mismatch: torch.Size([7]) vs torch.Size([8, 222])
Epoch 3, Loss : 1.540198791563942
Skipping batch due to shape mismatch: torch.Size([7]) vs torch.Size([8, 222])
Epoch 4, Loss : 1.5327540545460427
Skipping batch due to shape mismatch: torch.Size([7]) vs torch.Size([8, 222])
Epoch 5, Loss : 1.5270262987577625
Skipping batch due to shape mismatch: torch.Size([8]) vs torch.Size([2])
correct: 214
accuracy: 0.11251314405888538


In [25]:
loss_function = CrossEntropyLoss()
model2 = model2.to(device)
optimizer = optim.AdamW(model2.parameters(),lr=0.0001)
#print(device)
#   training loop for the second model

for epoch in range(epochs):
    total_loss = 0
    for batch in dataloader:
        optimizer.zero_grad()
        input_seqs = [torch.stack(elem) for elem in batch[0]]   # contexts
        gold_labels = batch[1][0].to(device)   # gold label in the batch
        #print("gold_labels", gold_labels)
        #indices = batch[2].long()   # indices of the target words??
        input_seqs = input_seqs[0].to(device)  # contexts
        input_seqs = input_seqs.permute(1, 0).to(device)
        try:
            outputs = model2(input_seqs)
            #print("outputs:", outputs)
            loss = loss_function(outputs, gold_labels)
        #print("loss",loss)

            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        except:
            continue
    print(f'Epoch {epoch + 1}, Loss : {total_loss/len(dataloader)}')
    
# testing the model after all epochs are completed

correct = 0
with torch.no_grad():
    for batch in dataloader_test:
        input_seqs = [torch.stack(elem) for elem in batch[0]]   # contexts
        gold_labels = batch[1][0].to(device)   # gold label in the batch
        #print("gold:",dataset.sense_vocab_i2s[gold_labels[0].item()])
        indices = batch[2].long().to(device)   # indices of the target words??
        input_seqs = input_seqs[0].to(device)  # contexts
        input_seqs = input_seqs.permute(1, 0).to(device)
        try:
            outputs = model2(input_seqs)  
            max_prediction_tensor = torch.argmax(outputs, dim=1).cpu().numpy()
        #print("prediction",dataset.sense_vocab_i2s[max_prediction[0].item()])
            max_prediction_tensor = torch.tensor(max_prediction).to(device)
            correct+=int(torch.eq(max_prediction_tensor,gold_labels).sum().item())
        except:
            continue
    print("correct:",correct)
print("accuracy:", correct/len(dataloader_test))

Epoch 1, Loss : 4.660387977497269
Epoch 2, Loss : 4.159341004426192
Epoch 3, Loss : 3.5921727566480794
Epoch 4, Loss : 2.8516270564466146
Epoch 5, Loss : 2.022744782898945


  max_prediction_tensor = torch.tensor(max_prediction).to(device)


correct: 76
accuracy: 0.03995793901156677


# 2.2 Fine-tuning and Testing BERT for Word Sense Disambiguation

In this section of the lab, you'll try out the transformer, specifically the BERT model. For this, we'll use the Hugging Face library ([https://huggingface.co/](https://huggingface.co/)).

You can find the documentation for the BERT model [here](https://huggingface.co/transformers/model_doc/bert.html) and a general usage guide [here](https://huggingface.co/transformers/v2.9.1/quickstart.html).

What we're going to do is *fine-tune* the BERT model, i.e., update the weights of a pre-trained model. That is, we have a model that is pre-trained on masked language modeling and next sentence prediction (kind of basic, general tasks which are useful for a lot of more specific tasks), but now we apply it to word sense disambiguation with the word representations it has learned.

We'll use the same data splits for training and testing as before, but this time you will use a different dataloader.

Now you create an iterator that collects N sentences (where N is the batch size) then use the BertTokenizer to transform the sentence into integers. For your dataloader, remember to:
* Shuffle the data in each batch
* Make sure you get a new iterator for each *epoch*
* Create a vocabulary of *sense-labels* so you can calculate accuracy 

We then pass this batch into the BERT model (you must have pre-loaded its weights) and update the weights (fine-tune). The BERT model will encode the sentence, then we send this encoded sentence into a prediction layer and collect what it outputs.

As input to the prediction layer, you are free to play with different types of information. For example, the expected way would be to use CLS representation. You can also use other representations and compare them.

About the hyperparameters and training:
* For BERT, usually a lower learning rate works best, between 0.0001-0.000001.
* BERT takes a lot of resources, running it on CPU will take ages, utilize the GPUs :)
* Since BERT takes a lot of resources, use a small batch size (4-8)
* Computing the BERT representation, make sure you pass the mask

**[12 marks]**

In [7]:
from transformers import BertTokenizer


from torch.utils.data import Dataset
from tqdm import tqdm

class Bert_Dataset(Dataset):
    
    def __init__(self, csv_file):
        
        self.file = pd.read_csv(csv_file, delimiter='\\t',header=None,names=["word_sense", "word_form","index", "context"])

        self.senses = [sense.replace(":", "").replace("%", "") for sense in self.file['word_sense']]
        self.words = [line for line in self.file['word_form']]
        self.indices = [int(x) for x in self.file['index']]
        self.contexts = [line for line in self.file['context']]
       

        # vocab for senses
        self.sense_vocab_s2i = {s: num for num, s in enumerate(list(set(self.senses)))}
        self.sense_vocab_i2s = {num: s for num, s in enumerate(list(set(self.senses)))}
        
        self.int_senses = []
        for s in self.senses:
           self.int_senses.append([self.sense_vocab_s2i[s]])

        self.unique_senses = []
        for s in self.senses:   
            if s not in self.unique_senses:
                self.unique_senses.append(s)
            
        
    def __getitem__(self, idx):
        context = self.contexts[idx]
        word_index = self.indices[idx]
        word_sense = self.int_senses[idx]
        
        return context ,word_index,word_sense
    
    def __len__(self):
        return len(self.contexts)
    
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')    
def bert_collate(data):
    contexts =[x[0] for x in data]
    word_indices= [x[1] for x in data]
    senses = [x[2] for x  in data]

    seqs = tokenizer.batch_encode_plus(
        contexts,
        add_special_tokens = True,
        padding = 'longest',
        return_tensors = 'pt'
        )
    adjusted_word_indices = []
    for idx in word_indices:
        adjusted_word_indices.append(idx+1)
    return seqs['input_ids'], seqs['attention_mask'], torch.tensor(adjusted_word_indices), torch.tensor(senses)



In [8]:
def dataloader_for_bert(path_to_file, batch_size, tokenizer):
    bert_dataset = Bert_Dataset(path_to_file)
    bert_dataloader = DataLoader(bert_dataset,
                            batch_size=batch_size,
                            shuffle=True,
                            collate_fn=bert_collate)
    return bert_dataset, bert_dataloader


In [9]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
bert_dataset,bert_dataloader = dataloader_for_bert('train_data.csv',batch_size=8,tokenizer=tokenizer)
bert_dataset_test,bert_dataloader_test = dataloader_for_bert('test_data.csv',batch_size=8,tokenizer=tokenizer)
    

  self.file = pd.read_csv(csv_file, delimiter='\\t',header=None,names=["word_sense", "word_form","index", "context"])
  self.file = pd.read_csv(csv_file, delimiter='\\t',header=None,names=["word_sense", "word_form","index", "context"])


In [10]:
from transformers import BertModel
import random

#initializing the model with pretrained weights

class BERT_WSD(nn.Module):
    def __init__(self, no_of_labels):
        super(BERT_WSD,self).__init__()
        self.bert = BertModel.from_pretrained('bert-base-uncased')

        self.classifier = nn.Linear(self.bert.config.hidden_size,no_of_labels)
    
    def forward(self,input_ids,mask,indices):
        outputs = self.bert(input_ids.to(device),mask.to(device))
        #So the hidden state size is (batch_size, seq_len, hidden_size) 
        seq_output = outputs.last_hidden_state.to(device)
        batch_size = seq_output.size(0)
        indices = indices.view(batch_size) 
        #print(indices)
        target = seq_output[torch.arange(batch_size),indices]
        
        predictions = self.classifier(target)
        return predictions

In [11]:
from torch.nn import CrossEntropyLoss
import torch.optim as optim
from sklearn.metrics import accuracy_score
loss_function = CrossEntropyLoss()
model = BERT_WSD(len(bert_dataset.unique_senses)).to(device)
optimizer = optim.AdamW(model.parameters(),lr=0.000001)

epochs = 5

model.train()
for epoch in range(epochs):
    total_loss = 0
    total_correct = 0
    total = 0

    for batch in bert_dataloader:
        input_ids,mask,indices,senses = batch
        input_ids.to(device)
        mask.to(device)
        indices.to(device)
        senses = torch.tensor(senses).view(-1)
        senses.to(device)
        optimizer.zero_grad()
        
            
        outputs = model(input_ids,mask,indices).to(device)
        #print(outputs)
        loss = loss_function(outputs.to(device),senses.to(device))
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    print(f'Epoch {epoch + 1}, Loss : {total_loss/len(dataloader)}')

total_correct = 0
with torch.no_grad():
    for batch in bert_dataloader_test:
        input_ids,mask,indices,senses = batch
        input_ids.to(device)
        mask.to(device)
        indices.to(device)
        senses = torch.tensor(senses).view(-1)
        senses.to(device)
        outputs = model(input_ids,mask,indices).to(device)
        _,predicted = torch.max(outputs.to(device),1)
        total_correct += (predicted.to(device) == senses.to(device)).sum().item()
        total += senses.size(0)
    accuracy = total_correct/total
    print("correct:",total_correct)
    print(accuracy)
    

  senses = torch.tensor(senses).view(-1)


Epoch 1, Loss : 4.733702104300123
Epoch 2, Loss : 3.405372674335076
Epoch 3, Loss : 2.1753532283012995
Epoch 4, Loss : 1.5793246047227183
Epoch 5, Loss : 1.2809866848762375


  senses = torch.tensor(senses).view(-1)


correct: 5536
0.36397107166337933


# 3. Evaluation

Explain the difference between the two LSTMs that you have implemented for word sense disambiguation.

Important note: your LSTMs should be nearly the same, but your linear layer must take different inputs. Describe why and how you think this difference will affect the performance of different LSTMs. How does the contextual representation of the whole sequence perform? How does the representation of the target word perform? What is better and for what situations? Why do we observe these differences?

What kind of representations are the different approaches using to predict word senses?

**[4 marks]**

In [None]:
# LSTM MODEL 1 RESULTS: 
# ___________________________________
# Epoch 1, Loss : 1.786867834724618
# Epoch 2, Loss : 1.5498475718545255
# Epoch 3, Loss : 1.540198791563942
# Epoch 4, Loss : 1.5327540545460427
# Epoch 5, Loss : 1.5270262987577625

# correct: 214
# accuracy: 0.11251314405888538 = ~ 11 %

# *******************************************************

# LSTM MODEL 2 RESULTS:
# ____________________________________
# Epoch 1, Loss : 4.660387977497269
# Epoch 2, Loss : 4.159341004426192
# Epoch 3, Loss : 3.5921727566480794
# Epoch 4, Loss : 2.8516270564466146
# Epoch 5, Loss : 2.022744782898945

# correct: 76
# accuracy: 0.03995793901156677 = ~ 4 %

# ********************************************************

# DISCUSSION :
# ____________________________________
# Both of our LSTM models perform quite poorly, reaching at best an accuracy of 11%. We were somewhat surprised of the low performance!
# Nevertheless, the first model performs (quite significantly) better than the second one. Could the reason be that the overall context-representation that the second model uses can be too ambiguous / general?
# Also, it might be that specifying the target word's position in the first model aids the model to perform better.

# EXPLANATIONS : 
# _____________________________________
# Model 1 captures the context around the target word by using the hidden state from the LSTM at the target word's position. The LSTM processes the entire sentence, and the hidden state at the target word's index captures its context-specific representation. This hidden state is then passed to a linear classifier to predict the word sense, emphasizing the immediate surrounding context of the word.

# Model 2 captures the overall sentence context by using the final hidden state from the LSTM, which represents the entire input sequence. The LSTM processes the whole sentence, and the final hidden state reflects the global context. This final hidden state is then passed to a linear classifier to predict the word sense, focusing on the broader context of the entire sentence.

# Difference:
# Model 1 focuses on the hidden state at the target word's position, reflecting the immediate context around the target word. Model 2 uses the final hidden state of the LSTM, reflecting the context of the entire sentence.

# Performance (implications):
# Model 1: Better for tasks where the immediate context around the word is sufficient for disambiguation.
# Model 2: Better for tasks requiring an understanding of the entire sentence to disambiguate the word sense.

Evaluate your model with per-word form *accuracy* and comment on the results you get. How does the model perform in comparison to the baseline, and how do the models compare to each other? 

Expand on the evaluation by sorting the word-forms by the number of senses they have. Are word forms with fewer senses easier to predict? Give a short explanation of the results you get based on the number of senses per word.

**[4 marks]**

In [None]:
# Unfortunately we were not able to calculate per-word form accuracies.
# The amount of training data lead to heavy computations that we were forced to run on the server


How do the LSTMs perform in comparison to BERT? What's the difference between representations obtained by the LSTMs and BERT?

**[4 marks]**

In [None]:
#   BERT MODEL RESULTS (RUN IN THE MLTSERVER ):
# _________________________________________________
# Epoch 1, Loss : 4.733702104300123
# Epoch 2, Loss : 3.405372674335076
# Epoch 3, Loss : 2.1753532283012995
# Epoch 4, Loss : 1.5793246047227183
# Epoch 5, Loss : 1.2809866848762375
# correct: 5536
# ACCURACY : 0.36397107166337933 = ~ 36 %

# BERT's perfomance is significantly better than the two LSTM models'. This can be due to the fact that we are using a pre-trained model with pre-trained word embeddings.
# Differences in representations compared to LSTMs: fine-tuning representations that already exist, unsequential way BERT deals with the contexts, the concept of attention.
# We noticed that the BERT model performed better when more epochs are added and the learning rate is decreased.

What could we do to improve all WSD models that we have worked with in this assignment?

**[2 marks]**

In [None]:
# -Use pre-trained word embeddings instead of creating them from scratch.
# -Add more epochs.
# -Explore using dropout and mask (?)

# Readings

[1] Kågebäck, M., & Salomonsson, H. (2016). Word Sense Disambiguation using a Bidirectional LSTM. arXiv preprint arXiv:1606.03568.

[2] ON WSD: https://web.stanford.edu/~jurafsky/slp3/slides/Chapter18.wsd.pdf

## Statement of contribution

Briefly state how many times you have met for discussions, who was present, to what degree each member contributed to the discussion and the final answers you are submitting.

## Marks

This assignment has a total of 46 marks.

In [None]:
#   We worked on this assignment for two weeks almost every day. Most of our calls / meetings were ~4-5 hours.
#   Contributions were equal among the groupmembers, although Eleni Fysikoudi was the leading person of the group, doing a great deal of the troubleshooting and debugging.  