# A3: Word Embeddings and Language Modelling

Adam Ek

The lab is an exploration and learning exercise to be done in a group and also in discussion with the teachers and other students.

Write all your answers and the code in the appropriate boxes below.

In this lab we will explore constructing *static* word embeddings (i.e. word2vec) and building language models. We'll also evaluate these systems on intermediate tasks, namely word similarity and identifying "good" and "bad" sentences.

* For this we'll use pytorch. Some basic operations that will be useful can be found here: https://jhui.github.io/2018/02/09/PyTorch-Basic-operations
* In general, we are not interested in getting state-of-the-art performance :) Focus on the implementation and not on the results of your model. For this reason, you can use a subset of the dataset: the first 5 000-10 000 sentences or so. On Linux/Mac: ```head -n 10000 inputfile > outputfile```.
* If possible, use the MLTGpu, it will make everything faster :)
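
On systems without `head`, the same subset can be taken from Python. This is only a minimal sketch: the function name `take_first_lines` is made up for illustration, and the file names in the usage note are placeholders like the ones in the shell command above.

```python
from itertools import islice


def take_first_lines(input_path, output_path, n=10000):
    """Copy the first n lines of input_path to output_path (like `head -n`)."""
    with open(input_path, encoding="utf-8") as src, \
         open(output_path, "w", encoding="utf-8") as dst:
        # islice stops after n lines, so the rest of the file is never read
        dst.writelines(islice(src, n))
```

Calling `take_first_lines("inputfile", "outputfile", 10000)` should then correspond to the `head` command above.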

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

# for gpu, replace "cpu" with "cuda:n" where n is the index of the GPU
device = torch.device('cpu')

In [2]:
from torch.utils.data import Dataset, DataLoader
from torchtext.vocab import Vocab
from collections import Counter, OrderedDict

# Word2Vec embeddings

In this first part we'll construct a word2vec model which will give us *static* word embeddings (that is, they are fixed after training).

After we've trained our model we will evaluate the embeddings obtained on a word similarity task.
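
As a rough illustration of what a word similarity evaluation compares, the sketch below scores two embedding vectors with cosine similarity. This is an assumption: cosine similarity is a common choice for this task, but the text above does not fix the measure, and the vectors here are toy values rather than trained embeddings.

```python
import math


def cosine_similarity(u, v):
    """Cosine of the angle between two equally long, non-zero vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


# toy 3-dimensional "embeddings" (hypothetical values)
cat = [1.0, 2.0, 0.0]
dog = [2.0, 4.0, 0.0]   # same direction as cat
car = [0.0, 0.0, 3.0]   # orthogonal to cat

print(cosine_similarity(cat, dog))
print(cosine_similarity(cat, car))
```

Vectors pointing in the same direction score close to 1 and orthogonal vectors score 0, independent of their length.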

## Formatting data


First we need to load some data. You can download the file on Canvas under files/assignments/03-lab-data/wiki-corpus.50000.txt. The file contains 50 000 sentences randomly selected from the complete Wikipedia. Each line in the file contains one sentence. The sentences are whitespace tokenized.

Your first task is to create a dataset suitable for word2vec. That is, we define some ```window_size``` then iterate over all sentences in the dataset, putting the center word in one field and the context words in another (separate the fields with ```tab```).

For example, the sentence "this is a lab" with ```window size = 4``` will be formatted as:
```
center, context
---------------------
this    is a
is      this a lab
a       this is lab
lab     is a
```

These will be our training examples when training the word2vec model.
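
The table above can be reproduced with a few lines of Python. This is a sketch under one assumption: the table shows at most two context words on each side of the center word, so the window of 4 is read here as two words to the left plus two to the right. The names `make_examples` and `half_window` are made up for this illustration.

```python
def make_examples(sentence, half_window):
    """Return (center, context) pairs with up to half_window words on each side."""
    words = sentence.split()
    examples = []
    for i, center in enumerate(words):
        # max(0, ...) stops the left slice from wrapping around at the sentence start
        left = words[max(0, i - half_window):i]
        right = words[i + 1:i + 1 + half_window]
        examples.append((center, left + right))
    return examples


# one "center<TAB>context" line per word, as in the table
for center, context in make_examples("this is a lab", half_window=2):
    print(center + "\t" + " ".join(context))
```

Whether the window size counts words per side or in total is a design choice; a larger value simply passes more neighbours to the model for each center word.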

[3 marks]

In [3]:
data_path = '/srv/data/computational-semantics/03-lab-data/wiki-corpus.50000.txt'
WINDOW_SIZE = 4

with open(data_path) as f:
    corpus = [line for line in f]

def corpus_reader(data_path):
    """Return one 'center<TAB>context words' string per token in the corpus."""
    dataset = []
    wz = WINDOW_SIZE  # number of context words taken on each side of the center
    with open(data_path) as f:
        for line in f:
            words = line.split()
            for i, center in enumerate(words):
                # max(0, ...) keeps the left slice from wrapping around near the sentence start
                context = words[max(0, i-wz):i] + words[i+1:i+wz+1]
                dataset.append(center + '\t' + ' '.join(context))
    return dataset
                
data = corpus_reader(data_path)

We sampled 50 000 sentences completely at random from the *whole* of Wikipedia for our training data. Give some reasons why this is good, and why it might be bad. (*note*: We'll have a few questions like these; one or two reasons for and against is sufficient.)

[2 marks]

### Loading the data

We now need to load the data in an appropriate format for torchtext (https://torchtext.readthedocs.io/en/latest/). We'll use PyText for this and it'll follow the same structure as I showed you in the lecture (remember to lower-case all tokens). Create a function which returns a (bucket)iterator of the training data, and the vocabulary object (```Field```). 

(*hint1*: you can format the data such that the center word always is first, then you only need to use one field)

(*hint2*: the code I showed you during the lecture is available in /files/pytorch_tutorial/ on Canvas)

[4 marks]

In [4]:
class CBOWDataset(Dataset):
    def __init__(self, data,
                 window_size=4,
                 unk_token='UNK',
                 pad_token='PAD'):
        
        self.window_size = window_size
        self.unk_token = unk_token
        self.pad_token = pad_token
        
        self.text_data = [line.lower().split('\t') for line in data]
        self.vocab = self.create_vocab()
        self.samples = self.create_samples()
        
        self.pad_index = self.vocab[self.pad_token]
        self.unk_index = self.vocab[self.unk_token]

    def create_vocab(self):
        tokens = set([self.pad_token, self.unk_token]+[line[0] for line in self.text_data])
        vocab = Vocab(OrderedDict([(token, 1) for token in tokens]), specials=[self.unk_token])
        return vocab
    
    def word2index(self,line):
        target, context = line
        t_index = self.vocab[target]
        c_indices = [self.vocab[word] for word in context.split()]
        padding = [self.vocab[self.pad_token]]*(self.window_size*2-len(c_indices)) 
        c_indices += padding
        return t_index, c_indices
    
    def create_samples(self):
        return [self.word2index(line) for line in self.text_data]
    
    
    def __getitem__(self, idx):
        
       # target_word, context = self.samples[idx]
       # context = torch.Tensor(context).long()
        return self.samples[idx]

    def __len__(self):
        return len(self.samples)
    

In [5]:
def get_data(data, batch_size):
    dataset = CBOWDataset(data)
    dataloader = DataLoader(dataset,
                            batch_size=batch_size,
                            shuffle=True,
                            collate_fn=lambda x: x)
    return dataloader, dataset.vocab

In [6]:
# Check that it is working and the format is right:
dataloader, vocab = get_data(data, 4)
print('pad_index =', vocab['PAD'])
for i, batch in enumerate(dataloader):
    print('Batch size:', len(batch)) 
    for example in batch:
        print(example)
    # Only look at the first batch.
    break

pad_index = 8982
Batch size: 4
(222, [50624, 51758, 59660, 78953, 8982, 8982, 8982, 8982])
(72373, [29207, 40930, 37960, 53652, 74678, 47651, 8982, 8982])
(59844, [22566, 76, 32390, 72373, 53652, 21729, 78737, 12959])
(72551, [78284, 39517, 72373, 15566, 18448, 14837, 32636, 222])


We lower-cased all tokens above; give some reasons why this is a good idea, and why it may be harmful to our embeddings.

[2 marks]

In [7]:
# Good: no difference between sentence-initial words and the same words elsewhere in the sentence.
# Bad: no difference between proper names and other words spelled the same, e.g. "bill" and "Bill", although they are different words.

## Word Embeddings Model

We will implement the CBOW model for constructing word embedding models.

In [8]:
import torch.optim as optim

In the CBOW model we try to predict the center word based on the context. That is, we take as input ```n``` context words, encode them as vectors, then combine them by summation. This will give us one embedding. We then use this embedding to predict *which* word in our vocabulary is the most likely center word.
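
The two steps described above (sum the context embeddings, then score every vocabulary word) can be traced by hand on a toy example. The sketch below uses plain Python lists instead of PyTorch tensors so that the arithmetic is visible. All names and numbers in it are made up for illustration and are not part of the lab code.

```python
def cbow_scores(context_ids, embeddings, output_weights):
    """Sum the context word embeddings and score every vocabulary word."""
    dim = len(embeddings[0])
    summed = [0.0] * dim
    for idx in context_ids:
        for d in range(dim):
            summed[d] += embeddings[idx][d]
    # one score per vocabulary word: dot product of its output row with the sum
    return [sum(w * s for w, s in zip(row, summed)) for row in output_weights]


# toy vocabulary of 3 words with 2-dimensional embeddings (hypothetical values)
embeddings = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
output_weights = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]

scores = cbow_scores([0, 2], embeddings, output_weights)
predicted = max(range(len(scores)), key=scores.__getitem__)
print(scores, predicted)
```

Here the context words 0 and 2 sum to `[2.0, 1.0]`, and the highest-scoring row of `output_weights` gives the predicted center word.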

Implement this model 

[7 marks]

In [9]:
class CBOWModel(nn.Module):
    def __init__(self, vocab_dims, embedding_dims, pad_index):
        super(CBOWModel, self).__init__()
        
        if pad_index > 0:
            vocab_dims += 1
            
        self.embeddings = nn.Embedding(num_embeddings=vocab_dims,                                          
                                       embedding_dim=embedding_dims,
                                       padding_idx=pad_index)
        
        self.prediction = torch.nn.Linear(in_features=embedding_dims,
                                             out_features=vocab_dims, bias=False)
        
    
    def forward(self, context):
        embedded_context = self.embeddings(context)
        projection = self.projection_function(embedded_context)
        predictions = self.prediction(projection)
        
        return predictions
        
    def projection_function(self, xs):
        """
        This function takes as input a tensor of size (B, S, D),
        where B is the batch size, S the number of context words, and D the dimensionality of the embeddings.
        It should sum the S context embeddings of each example (not the D embedding dimensions),
        that is, we transform (B, S, D) to (B, 1, D) or (B, D); the implementation below returns (B, D).
        """
        xs_sum = xs.sum(dim=-2)
        return xs_sum

Now we need to train the model. First we define which hyperparameters to use. (You can change these; for example, when *developing* your model you can use a batch size of 2 and a very low dimensionality (say 10), just to speed things up.) When actually training your model *for real*, you can use a batch size of [8,16,32,64] and an embedding dimensionality of [128,256].

In [10]:
# you can change these numbers to suit your needs :)
word_embeddings_hyperparameters = {'epochs':3,
                                   'batch_size':16,
                                   'embedding_size':128,
                                   'learning_rate':0.001,
                                   'embedding_dim':10}

Train your model. Iterate over the dataset, get outputs from your model, calculate loss and backpropagate.

We mentioned in the lecture that we use the Negative Log Likelihood loss (https://pytorch.org/docs/stable/generated/torch.nn.NLLLoss.html) to train the Word2Vec model. In this lab we'll take a shortcut when *training* and use Cross Entropy Loss (https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html), which basically combines ```log_softmax``` and ```NLLLoss```. So what your model should output is a *score* for each word in our vocabulary. The ```CrossEntropyLoss``` will then assign probabilities and calculate the negative log likelihood loss.
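
To make the "log_softmax followed by NLLLoss" description concrete, the sketch below computes the loss for a single example in plain Python. It is meant to show the arithmetic only and is not a replacement for the PyTorch classes; the function names mirror the PyTorch ones but are defined here from scratch, and the scores are toy values.

```python
import math


def log_softmax(scores):
    """Turn raw scores into log-probabilities (numerically stable version)."""
    m = max(scores)
    log_sum = m + math.log(sum(math.exp(s - m) for s in scores))
    return [s - log_sum for s in scores]


def nll_loss(log_probs, target):
    """Negative log likelihood of the target class."""
    return -log_probs[target]


def cross_entropy(scores, target):
    """Cross entropy for one example: log_softmax followed by nll_loss."""
    return nll_loss(log_softmax(scores), target)


# toy scores over a 3-word vocabulary; the correct center word has index 0
scores = [2.0, 0.5, -1.0]
loss = cross_entropy(scores, 0)
print(loss)
```

With equal scores for every word, the loss is the logarithm of the vocabulary size, which is a useful reference point when reading the loss of a model that has not learned anything yet.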

[3 marks]

In [11]:
# load data
dataset, vocab = get_data(data, word_embeddings_hyperparameters['batch_size'])

In [12]:
from tqdm.auto import tqdm

In [13]:
%%time

# build model and construct loss/optimizer
cbow_model = CBOWModel(len(vocab), word_embeddings_hyperparameters['embedding_dim'], vocab['PAD'])
cbow_model.to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(cbow_model.parameters(), lr=word_embeddings_hyperparameters['learning_rate'])

CPU times: user 17.1 ms, sys: 2.07 ms, total: 19.1 ms
Wall time: 18.5 ms


In [15]:
# start training loop
total_loss = 0
for epoch in tqdm(range(word_embeddings_hyperparameters['epochs'])):
    for i, batch in tqdm(enumerate(dataset)):
        if i < 100:
        
            context = torch.Tensor([sample[1] for sample in batch]).long().to(device)
            target_word = torch.Tensor([sample[0] for sample in batch]).to(device)

            # send your batch of sentences to the model
            output = cbow_model(context)

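            # compute the loss; CrossEntropyLoss expects scores of shape (B, V)
            # and integer class targets of shape (B,), so reshape if needed.
            # You can read more about this in the documentation for
            # CrossEntropyLoss
            loss = loss_fn(output, target_word.long())
            total_loss += loss.item()

            # print the running average loss; total_loss is never reset, so from the
            # second epoch on this value also includes the loss of earlier epochs
            print(total_loss/(i+1), end='\r') 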

            # compute gradients
            loss.backward()

            # update parameters
            optimizer.step()

            # reset gradients
            optimizer.zero_grad()
        print()


  0%|          | 0/3 [00:00<?, ?it/s]

0it [00:00, ?it/s]

12.76150894165039
12.89038372039795
12.808883984883627
12.688961505889893
12.659097480773926
12.522737503051758
12.46296228681292
12.5095055103302
12.519189304775661
12.506017875671386
12.56643971529874
12.5242018699646
12.513431989229643
12.516163212912423
12.588139152526855
12.57363760471344
12.530387485728545
12.54788260989719
12.516509658411929
12.55718445777893
12.539267040434337
12.525052070617676
12.522322737652322
12.538971106211344
12.55010684967041
12.528712896200327
12.498394153736255
12.51034699167524
12.503967778436069
12.508590507507325
12.48351078648721
12.484714686870575
12.481674136537494
12.477242273442885
12.47478471483503
12.47191940413581
12.451408798630172
12.447333963293778
12.441386760809483
12.433452463150024
12.426455125576112
12.41473320552281
12.398072553235432
12.386731386184692
12.387579239739312
12.396709545798924
12.406506274608855
12.399909655253092
12.404591540901029
12.395525436401368
12.396466853571873
12.386433674738957
12.379960653916845
12.3711693

0it [00:00, ?it/s]

1234.4056339263916
623.2740316390991
419.40164915720624
317.5694308280945
256.4083881378174
215.63388061523438
186.46745954241072
164.66267585754395
147.72845522562662
134.07937068939208
122.9325662092729
113.72338724136353
105.90702409010667
99.16900525774274
93.38683414459229
88.28268456459045
83.75635814666748
79.75862084494696
76.18654175808555
72.9862072467804
70.08754807426816
67.44664356925271
65.00923011613929
62.78539458910624
60.74698097229004
58.86963756267841
57.13206955238625
55.50801512173244
54.00139127928635
52.59889316558838
51.297164855464814
50.04740393161774
48.88260269165039
47.80089698118322
46.77060579572405
45.798543479707504
44.87023113869332
43.986843033840785
43.13541889190674
42.349047493934634
41.5810759241988
40.863445213862825
40.18656433460324
39.531474178487606
38.905494414435495
38.32620222672172
37.76040706228703
37.202589650948845
36.680906723956674
36.169712696075436
35.68227958679199
35.215380943738495
34.764287210860346
34.33628223560475
33.922215

0it [00:00, ?it/s]

2391.055438041687
1201.3141465187073
804.5422312418619
606.1875610351562
487.12693786621094
407.80213753382367
351.28464344569613
308.76049077510834
275.7265728844537
249.3480320930481
227.7177384116433
209.710346698761
194.41666610424335
181.31590979439872
170.02719033559163
160.0765809416771
151.28684767554788
143.51578813129
136.54397914284155
130.26591720581055
124.62580717177619
119.4478127306158
114.74658854111381
110.44650340080261
106.46741374969483
102.80721220603355
99.43288114335802
96.24927864755902
93.29363359253982
90.5554706255595
87.97945419434578
85.58652153611183
83.29950953974868
81.1950987928054
79.20877947126117
77.3176121711731
75.53550421225059
73.82191708213405
72.21777077210255
70.68975975513459
69.23215124083728
67.83984558922904
66.532979122428
65.28833070668307
64.1000493367513
62.92992923570716
61.83009273447889
60.766290028889976
59.753580638340544
58.78713533401489
57.858485577153225
56.95397191781264
56.08240294906328
55.24044568450363
54.44418482346968

## Evaluating the model

We will evaluate the model on a dataset of word similarities, WordSim353 (http://alfonseca.org/eng/research/wordsim353.html, also available on Canvas under files/assignments/03-lab-data). The first thing we need to do is read the dataset and translate it to integers. To do this we'll reuse the ```Field``` that records word indexes (the second output of ```get_data()```) and use it to parse the file.

The wordsim data is structured as follows:

```
word1 word2 score
...
```


The ```Field``` we got from ```get_data()``` provides two lookups: ```stoi```, which maps a string to an integer, and ```itos```, which maps an integer to a string.
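
As a rough illustration of these two lookups, here is a dependency-free sketch that builds them by hand. The names `itos` and `stoi` mirror the ones above, but the toy word list, the `encode` helper and the `<unk>` fallback are assumptions made for this example; the torchtext vocabulary may order words and handle unknown words differently.

```python
# Toy vocabulary; the word list and the <unk> entry are assumptions for illustration.
itos = ['<unk>', 'tiger', 'cat', 'car']              # index -> string
stoi = {word: idx for idx, word in enumerate(itos)}  # string -> index

def encode(word):
    # fall back to the <unk> index for words that are not in the vocabulary
    return stoi.get(word, stoi['<unk>'])

word1_idx = encode('tiger')
word2_idx = encode('jaguar')   # not in the toy vocabulary
print(word1_idx, itos[word1_idx])
print(word2_idx, itos[word2_idx])
```

Words from WordSim353 that never occur in the training corpus will end up on some fallback index like this, which is worth keeping in mind when interpreting the correlation.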

What our datareader needs to do is: 

```
for line in file:
    word1, word2, score = line.split()
    # encode word1 and word2 as integers
    word1_idx = vocab.vocab.stoi[word1]
    word2_idx = vocab.vocab.stoi[word2]
```

Once we have the integers for ```word1``` and ```word2```, we'll compute the similarity between their word embeddings with *cosine similarity*. We can obtain the embeddings by querying the embedding layer of the model.

We calculate the cosine similarity for each word pair in the dataset, then compute the Pearson correlation between the similarities we obtained and the scores given in the dataset.

[4 marks]
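
To make the two measures concrete, here is a minimal sketch in plain Python that computes cosine similarity and Pearson correlation on made-up data. The 2-d vectors, the gold scores and the helper names are assumptions for illustration only; they are not taken from WordSim353 or from a trained model.

```python
import math

def cosine_similarity(u, v):
    # dot product divided by the product of the vector norms
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)

def pearson(xs, ys):
    # covariance of xs and ys divided by the product of their standard deviations
    mean_x = sum(xs) / len(xs)
    mean_y = sum(ys) / len(ys)
    cov = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    var_x = sum((x - mean_x) ** 2 for x in xs)
    var_y = sum((y - mean_y) ** 2 for y in ys)
    return cov / math.sqrt(var_x * var_y)

# made-up 2-d "embeddings" and gold scores, only to exercise the functions
pairs = [([1.0, 0.0], [1.0, 0.0]),   # identical direction
         ([1.0, 0.0], [1.0, 1.0]),   # 45 degrees apart
         ([1.0, 0.0], [0.0, 1.0])]   # orthogonal
gold_scores = [10.0, 5.0, 0.0]

model_sims = [cosine_similarity(u, v) for u, v in pairs]
print(model_sims)
print(pearson(gold_scores, model_sims))
```

In the notebook itself, `F.cosine_similarity` and `np.corrcoef` compute the same quantities on tensors and arrays.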

In [19]:
# your code goes here

def read_wordsim(path, vocab, embeddings):
    dataset_sims = []
    model_sims = []
    with open(path) as f:
        for line in f:
            word1, word2, score = line.split()
            
            score = float(score)
            dataset_sims.append(score)
            
            # get the index for the word
            word1_idx = vocab.stoi[word1]
            word2_idx = vocab.stoi[word2]
            #embeddings = embeddings.to("cpu")
            
            # get the embedding of each word; torch.tensor([idx]) builds a 1-element
            # index tensor, whereas torch.Tensor(idx) allocates an uninitialised
            # tensor of length idx, which is not a valid index
            word1_emb = embeddings(torch.tensor([word1_idx], dtype=torch.long, device=device))
            word2_emb = embeddings(torch.tensor([word2_idx], dtype=torch.long, device=device))
            
            # compute cosine similarity, we'll use the version included in pytorch functional
            # https://pytorch.org/docs/master/generated/torch.nn.functional.cosine_similarity.html
            cosine_similarity = F.cosine_similarity(word1_emb, word2_emb)
            
            model_sims.append(cosine_similarity.item())
    
    return dataset_sims, model_sims

path = '/srv/data/computational-semantics/03-lab-data/wordsim_similarity_goldstandard.txt'
# the third argument must map word indices to embedding vectors
data, model = read_wordsim(path, vocab, cbow_model)
pearson_correlation = np.corrcoef(data, model)

# the off-diagonal entries give the Pearson correlation
print(pearson_correlation)

Do you think the model performs well or badly? Why?

[3 marks]

Select the 10 best and 10 worst performing word pairs. Can you see any patterns that explain why *these* are the best and worst word pairs?

[3 marks]

Suggest some ways of improving the model we apply to WordSim353.

[3 marks]

Consider a scenario where we use these embeddings in a downstream task, for example sentiment analysis (roughly: determining whether a sentence is positive or negative).

Give some examples of why the sentiment analysis model would benefit from our embeddings and one example of why our embeddings could hurt the performance of the sentiment model.

[3 marks]

# Language modeling

In this second part we'll build a simple LSTM language model. Your task is to construct a model which takes a sentence as input and predicts the next word for each word in the sentence. For this you'll use the ```LSTM``` class provided by PyTorch (https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html). You can read more about the LSTM here: https://colah.github.io/posts/2015-08-Understanding-LSTMs/

NOTE!!!: Use the same dataset (wiki-corpus.50000.txt) as before.

Our setup is similar to before: we first encode the words as distributed representations, then pass these to the LSTM, and for each output we predict the next word.

For this we'll build a new dataloader with torchtext. The file we pass to the dataloader should contain one sentence per line, with words separated by whitespace.

```
word_1, ..., word_n
word_1, ..., word_k
...
```

In this dataloader you want to make sure that each sentence begins with a ```<start>``` token and ends with an ```<end>``` token; there is a keyword argument in ```Field``` for this :). Other than that, as before, you read the dataset and output an iterator over the dataset and a vocabulary.
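
The effect of the boundary tokens can be sketched without torchtext. The helper `add_boundaries` below is a hypothetical name used only for this illustration; in the legacy torchtext `Field` the corresponding keyword arguments should be `init_token` and `eos_token`, but check the documentation for the version you have installed.

```python
START, END = '<start>', '<end>'

def add_boundaries(line):
    # split on whitespace and wrap the tokens in the boundary markers
    return [START] + line.split() + [END]

tokens = add_boundaries('the cat sat')
print(tokens)
```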

Implement the dataloader, language model and the training loop (the training loop will basically be the same as for word2vec).

[12 marks]

In [None]:
# you can change these numbers to suit your needs as before :)
lm_hyperparameters = {'epochs':3,
                      'batch_size':16,
                      'learning_rate':0.001,
                      'embedding_dim':128,
                      'output_dim':128}

In [None]:
data_path = 'wiki-corpus.txt'
def get_data():
    # your code here, roughly the same as for the word2vec dataloader
    ...

In [None]:
class LM_withLSTM(nn.Module):
    def __init__(...):
        super(LM_withLSTM, self).__init__()
        self.embeddings = ...
        self.LSTM = ...
        self.predict_word = ...
    
    def forward(self, seq):
        embedded_seq = ...
        timestep_representation, *_ = ...
        predicted_words = ...
        
        return predicted_words

In [None]:
# load data
dataset, vocab = get_data(...)

# build model and construct loss/optimizer
lm_model = LM_withLSTM(len(vocab), 
                       lm_hyperparameters['embedding_dim'],
                       lm_hyperparameters['output_dim'])
lm_model.to(device)

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(lm_model.parameters(), lr=lm_hyperparameters['learning_rate'])

# start training loop
for epoch in range(lm_hyperparameters['epochs']):
    # reset the running loss so the printed value is the average for this epoch
    total_loss = 0
    for i, batch in enumerate(dataset):
        
        # the structure for each BATCH is:
        # <start>, w0, ..., wn, <end>
        sentence = batch.sentence
        
        # when training the model, at each input we predict the *NEXT* token
        # consequently there is nothing to predict when we give the model 
        # <end> as input. 
        # thus, we do not want to give <end> as input to the model, select 
        # from each batch all tokens except the last. 
        # tip: use pytorch indexing/slicing (same as numpy) 
        # (https://pytorch.org/tutorials/beginner/basics/tensorqs_tutorial.html#operations-on-tensors)
        # (https://jhui.github.io/2018/02/09/PyTorch-Basic-operations/)
        input_sentence = ...
        
        # send your batch of sentences to the model
        output = lm_model(input_sentence)
        
        # for each output, the model predicts the NEXT token, so we have to reshape 
        # our dataset again. On timestep t, we evaluate on token t+1. That is,
        # we never predict the <start> token ;) so this time, we select all but the first 
        # token from sentences (that is, all the tokens that we predict)
        gold_data = ...
        
        # the shape of the output and sentence variable need to be changed,
        # for the loss function. Details are in the documentation.
        # You can use .view(...,...) to reshape the tensors  
        loss = loss_fn(...)
        total_loss += loss.item()
        
        # print average loss for the epoch
        print(total_loss/(i+1), end='\r') 
        
        # compute gradients
        ...
        
        # update parameters
        ...
        
        # reset gradients
        ...
    print()
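
The input/gold shift described in the comments of the training loop can be sketched on plain lists. The token indices and the (batch, sequence) layout below are assumptions for illustration; the tensors produced by your dataloader may be laid out differently.

```python
# a toy batch of two encoded sentences, laid out as (batch, sequence);
# the indices are made up: 1 = <start>, 2 = <end>
batch = [[1, 5, 6, 7, 2],
         [1, 8, 9, 4, 2]]

# inputs: every token except the last one (<end> is never fed to the model)
inputs = [sentence[:-1] for sentence in batch]
# gold: every token except the first one (<start> is never predicted)
gold = [sentence[1:] for sentence in batch]

# the input at timestep t lines up with the gold token at position t + 1
for inp, tgt in zip(inputs, gold):
    print(list(zip(inp, tgt)))

# a classification loss expects one target per prediction, so the gold
# tokens are flattened to a single list of length batch * (sequence - 1)
flat_gold = [token for sentence in gold for token in sentence]
print(len(flat_gold))
```

With tensors shaped (batch, sequence) the same slices would be written `batch[:, :-1]` and `batch[:, 1:]`; if the dataloader yields (sequence, batch) tensors instead, slice along the first dimension.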

### Evaluating the language model

We'll evaluate our model using the BLiMP dataset (https://github.com/alexwarstadt/blimp). The BLiMP dataset contains sets of linguistic minimal pairs for various syntactic and semantic phenomena. We'll evaluate our model on *existential quantifiers* (link: https://github.com/alexwarstadt/blimp/blob/master/data/existential_there_quantifiers_1.jsonl). This data, as the name suggests, investigates whether language models assign higher probability to *correct* usage of there-quantifiers.

An example entry in the dataset is: 

```
{"sentence_good": "There was a documentary about music irritating Allison.", "sentence_bad": "There was each documentary about music irritating Allison.", "field": "semantics", "linguistics_term": "quantifiers", "UID": "existential_there_quantifiers_1", "simple_LM_method": true, "one_prefix_method": false, "two_prefix_method": false, "lexically_identical": false, "pairID": "0"}
```

Download the dataset and build a datareader (similar to what you did for word2vec). The dataset structure you should aim for is (you don't need to worry about the other keys for this assignment):

```
good_sentence_1, bad_sentence_1
...
```
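
Reading this format comes down to parsing one JSON object per line and keeping two keys. Below is a small sketch using only the standard library; the function name `read_pairs` is hypothetical, the first line is the example entry above with most keys dropped, and the second line is invented for illustration.

```python
import json

# two lines in the BLiMP format (see the lead-in for where they come from)
raw_lines = [
    '{"sentence_good": "There was a documentary about music irritating Allison.", '
    '"sentence_bad": "There was each documentary about music irritating Allison.", "pairID": "0"}',
    '{"sentence_good": "There is a cat here.", "sentence_bad": "There is every cat here.", "pairID": "1"}',
]

def read_pairs(lines):
    pairs = []
    for line in lines:
        entry = json.loads(line)   # one JSON object per line
        pairs.append((entry['sentence_good'], entry['sentence_bad']))
    return pairs

pairs = read_pairs(raw_lines)
for good, bad in pairs:
    # whitespace tokenization, as for the training corpus
    print(good.split(), bad.split())
```

With a real file, the same function can be given the open file object, since iterating over it yields one line at a time.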

Your task now is to compare the probability assigned to the good sentence with the probability assigned to the bad sentence. To compute a probability for a sentence we take the product of the probabilities assigned to the *gold* tokens. Remember, at timestep ```t``` we're predicting which token comes *next*, i.e. ```t+1``` (basically, you do the same thing as you did when training).

In rough pseudo code what your code should do is:

```
accuracy = []
for good_sentence, bad_sentence in dataset:
    gs_lm_output = LanguageModel(good_sentence)
    gs_token_probabilities = softmax(gs_lm_output)
    gs_sentence_probability = product(gs_token_probabilities[GOLD_TOKENS])

    bs_lm_output = LanguageModel(bad_sentence)
    bs_token_probabilities = softmax(bs_lm_output)
    bs_sentence_probability = product(bs_token_probabilities[GOLD_TOKENS])

    # int(True) = 1 and int(False) = 0
    is_correct = int(gs_sentence_probability > bs_sentence_probability)
    accuracy.append(is_correct)

print(numpy.mean(accuracy))
    
```
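
The sentence score in the pseudocode can be sketched in plain Python. Note one deliberate change: this sketch sums log-probabilities instead of multiplying raw probabilities, since a product of many small numbers can underflow, and the logarithm preserves which of the two sentences scores higher. The logits, the 3-word vocabulary and the function names are made up for illustration.

```python
import math

def softmax(logits):
    # subtract the maximum before exponentiating for numerical stability
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def sentence_log_prob(step_logits, gold_tokens):
    # step_logits[t] holds the scores over the vocabulary after reading token t;
    # gold_tokens[t] is the index of the token that actually comes next
    log_prob = 0.0
    for logits, gold in zip(step_logits, gold_tokens):
        log_prob += math.log(softmax(logits)[gold])
    return log_prob

# made-up scores over a 3-word vocabulary for a 2-step sentence
step_logits = [[2.0, 0.0, 0.0], [0.0, 3.0, 0.0]]
good = sentence_log_prob(step_logits, [0, 1])   # follows the high-scoring tokens
bad = sentence_log_prob(step_logits, [2, 1])    # first token is unlikely
print(good, bad, int(good > bad))
```

In the real evaluation the two sentences are scored from two separate forward passes; the same logits are reused here only to keep the example short.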

[6 marks]

In [None]:
# your code goes here
import json

def evaluate_model(path, vocab, model):
    
    accuracy = []
    with open(path) as f:
        # iterate over one pair of sentences at a time
        for line in f:
            # load the data
            data = json.loads(line)
            good_s = data['sentence_good']
            bad_s = data['sentence_bad']
            
            # tokenize the sentences on whitespace
            tok_good_s = ...
            tok_bad_s = ...
            
            # encode your words as integers using the vocab from the dataloader, size is (S)
            # we use unsqueeze to create the batch dimension 
            # in this case our input is only ONE batch, so the size of the tensor becomes: 
            # (S) -> (1, S) as the model expects batches
            enc_good_s = torch.tensor([_ for x in tok_good_s], device=device).unsqueeze(0)
            enc_bad_s = torch.tensor([_ for x in tok_bad_s], device=device).unsqueeze(0)
            
            # pass your encoded sentences to the model and predict the next tokens
            good_s = model(enc_good_s)
            bad_s = model(enc_bad_s)
            
            # get probabilities with softmax
            gs_probs = F.softmax(...)
            bs_probs = F.softmax(...)
            
            # select the probability of the gold tokens
            gs_sent_prob = find_token_probs(gs_probs, enc_good_s)
            bs_sent_prob = find_token_probs(bs_probs, enc_bad_s)
            
            accuracy.append(int(gs_sent_prob>bs_sent_prob))
            
    return accuracy
            
def find_token_probs(model_probs, encoded_sentence):
    probs = []

    # iterate over the tokens in your encoded sentence
    for token, gold_token in enumerate(encoded_sentence):
        # select the probability of the gold tokens and save
        # hint: pytorch indexing is helpful here ;)
        prob = ...
        probs.append(prob)
    sentence_prob = ...
    return sentence_prob

path = 'existential_there_quantifiers_1.jsonl'
accuracy = evaluate_model(path, ..., ...)

print('Final accuracy:')
print(np.round(np.mean(accuracy), 3))


### Analysis

Our model gets some score, say, 55% correct predictions. Is this good? Suggest some *baseline* (i.e. a stupid "model" we hope ours is better than) that we can compare the model against.

[3 marks]

Suggest some improvements you could make to your language model.

[3 marks]

Suggest some other metrics we can use to evaluate our system.

[2 marks]

# Literature


Neural architectures:

[1] Y. Bengio, R. Ducharme, P. Vincent, and C. Janvin. A neural probabilistic language model. Journal of Machine Learning Research, 3(6):1137–1155, 2003. (Sections 3 and 4 are less relevant today and hence you can glance through them quickly. Instead, look at the Mikolov papers where they describe training word embeddings with the current neural network architectures.)

[2] T. Mikolov, K. Chen, G. Corrado, and J. Dean. Efficient estimation of word representations in vector space. arXiv preprint arXiv:1301.3781, 2013.

[3] T. Mikolov, I. Sutskever, K. Chen, G. S. Corrado, and J. Dean. Distributed representations of words and phrases and their compositionality. In Advances in neural information processing systems, pages 3111–3119, 2013.
    


## Statement of contribution

Briefly state how many times you have met for discussions, who was present, to what degree each member contributed to the discussion and the final answers you are submitting.

## Marks

This assignment has a total of 63 marks.