# WSD par fine-tuning d'un modèle *BERT

Import des librairies

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tqdm.notebook import tqdm 
from random import shuffle
import os


## Les données "ASFALDA"


Il s'agit des données d'un FrameNet du français, comprenant environ 16000 annotations, pour environ 100 frames distincts.


### Chargement et Manipulation des données

In [None]:
## Fonction de chargement des données
def load_asfalda_data(gold_data_file, split_info_file, val_proportion=None):
    """Load the ASFALDA gold annotations and distribute them into corpus parts.

    Each non-comment, non-blank line of ``gold_data_file`` is tab-separated:
    sentence id, target word rank, frame name, target lemma, target POS, then
    further columns whose last one is the space-tokenized sentence.
    ``split_info_file`` holds one ``sentid<TAB>part`` pair per line, where part
    is 'dev', 'train' or 'test'.

    Args:
        gold_data_file: path of the gold annotation file (UTF-8).
        split_info_file: path of the sentence-id -> part mapping file (UTF-8).
        val_proportion: if truthy, this proportion of the 'train' part is moved
            into a new 'val' part using ``split_list``.

    Returns:
        (sentences, tg_wrks, tg_lemmas, labels): four dicts keyed by part name.
        ``sentences[part]`` is a list of word lists, ``tg_wrks[part]`` the
        target word ranks (ints), ``tg_lemmas[part]`` the target lemmas and
        ``labels[part]`` the frame names.

    Raises:
        KeyError: if a sentence id is missing from the split file, or its part
            is not one of 'dev', 'train', 'test'.
    """
    with open(split_info_file, encoding='utf-8') as s:
        # rstrip instead of l[:-1]: slicing chops a real character off a final
        # line that has no trailing newline (and leaves '\r' on CRLF files).
        lines = [l.rstrip('\r\n').split('\t') for l in s if l.strip()]
    split_info_dic = {line[0]: line[1] for line in lines}

    sentences = {'dev': [], 'train': [], 'test': []}
    tg_wrks = {'dev': [], 'train': [], 'test': []}
    tg_lemmas = {'dev': [], 'train': [], 'test': []}
    labels = {'dev': [], 'train': [], 'test': []}

    max_sent_len = {'dev': 0, 'train': 0, 'test': 0}
    max_tg_wrk = {'dev': 0, 'train': 0, 'test': 0}

    with open(gold_data_file, encoding='utf-8') as stream:
        for line in stream:
            if line.startswith('#'):
                continue
            line = line.strip()
            if not line:
                # Blank lines (e.g. a trailing one) carry no annotation.
                continue
            (sentid, tg_wrk, frame_name, tg_lemma, tg_pos, rest) = line.split('\t', 5)
            # The sentence is the last tab-separated column.
            sentence = rest.split("\t")[-1].split(' ')
            part = split_info_dic[sentid]
            tg_wrk = int(tg_wrk)

            sentences[part].append(sentence)
            labels[part].append(frame_name)
            tg_wrks[part].append(tg_wrk)
            tg_lemmas[part].append(tg_lemma)
            max_sent_len[part] = max(max_sent_len[part], len(sentence))
            max_tg_wrk[part] = max(max_tg_wrk[part], tg_wrk)
    print("Longueur max des phrases:", max_sent_len)
    print("Rang max du target (en mots):", max_tg_wrk)

    if val_proportion:
        # The same deterministic split is applied to every parallel dict, so
        # the 'val' entries stay aligned across them.
        for dic in [sentences, tg_wrks, labels, tg_lemmas]:
            (dic['val'], dic['train']) = split_list(dic['train'], proportion=val_proportion)
    return sentences, tg_wrks, tg_lemmas, labels

## Fonction de découpage des données
def split_list(inlist, proportion=0.1, shuffle=False):
    """Partition ``inlist`` into two new lists, the first holding roughly
    ``proportion`` of the items (at least one item when ``inlist`` is non-empty).

    If ``shuffle`` is false the split is deterministic: one item every
    ``len(inlist) // size1`` items goes to the first list until it holds
    ``size1`` items; all other items go to the second list, in order.
    Otherwise the items are randomly permuted (``random`` module global
    generator) and the first ``size1`` of the permutation form the first list.

    Args:
        inlist: sequence of items of any kind.
        proportion: fraction of the items to put in the first list.
        shuffle: whether to split randomly instead of deterministically.

    Returns:
        (first, second): two lists.
    """
    # Local import: the ``shuffle`` parameter shadows the notebook-level
    # ``from random import shuffle``, and ``sample`` was never imported.
    import random

    n = len(inlist)
    size1 = int(n * proportion) or 1
    print("SPLIT %d items into %d and %d" % (n, n - size1, size1))
    if shuffle:
        permuted = random.sample(list(inlist), n)
        return (permuted[:size1], permuted[size1:])
    # max(..., 1) avoids a modulo by zero when proportion > 1 (size1 > n).
    divisor = max(n // size1, 1)
    first = []
    second = []
    for i, item in enumerate(inlist):
        if i % divisor or len(first) >= size1:
            second.append(item)
        else:
            first.append(item)
    return (first, second)

## Data paths
gold_data_file = './asfalda_data_for_wsd/sequoiaftb.asfalda_1_3.gold.uniq.nofullant.txt'
split_info_file = './asfalda_data_for_wsd/sequoiaftb_split_info'

## Load the data; 10% of the 'train' part is moved to a new 'val' part.
sentences, tg_wrks, tg_lemmas, label_strs = load_asfalda_data(gold_data_file,
                                                              split_info_file,
                                                              val_proportion=0.1)

## Gather every lemma / label string of every part and print part statistics.
all_labels_strs = []
all_lemma_strs = []
for p in sentences.keys():
    all_labels_strs += label_strs[p]
    all_lemma_strs += tg_lemmas[p]
    # Guard against an empty part (division by zero).
    avgl = sum(len(s) for s in sentences[p]) / len(sentences[p]) if sentences[p] else 0.0
    print("%s : %d sentences, average lentgh=%.2f"
          % (p, len(sentences[p]), avgl))

## String <-> index vocabularies.
# sorted(): the iteration order of a set of strings changes between interpreter
# runs (string hash randomization), so list(set(...)) produced a different
# lemma/label numbering on every run. A saved model's output units would then
# no longer match i2label after a restart.
i2lemma = sorted(set(all_lemma_strs))
lemma2i = {x: i for i, x in enumerate(i2lemma)}

i2label = sorted(set(all_labels_strs))
label2i = {x: i for i, x in enumerate(i2label)}
# Index of the 'Other_sense' label (KeyError if it never occurs in the data).
i_OTHER_SENSE = label2i['Other_sense']

## Label ids of every part.
labels = {p: [label2i[x] for x in label_strs[p]] for p in label_strs}





### Baseline MFS ("most frequent sense")

##### Calcul du dictionnaire des sens les plus communs

In [None]:
from collections import defaultdict
## Fonction de calcul de la fréquence des sens
def frequence(tg_lemmas, label_strs):
    """Count sense occurrences per lemma and pick each lemma's most frequent sense.

    Args:
        tg_lemmas: target lemmas, one per example.
        label_strs: sense (frame) labels aligned with ``tg_lemmas``.

    Returns:
        (counts, most_frequent): ``counts[lemma][sense]`` is the number of
        examples of ``lemma`` labelled ``sense`` (nested defaultdicts), and
        ``most_frequent[lemma]`` is the sense with the highest count for that
        lemma; on a tie, the sense encountered first wins.
    """
    counts = defaultdict(lambda: defaultdict(int))
    for pair in zip(tg_lemmas, label_strs):
        lemma, sense = pair
        counts[lemma][sense] += 1
    most_frequent = {
        lemma: max(per_sense, key=per_sense.__getitem__)
        for lemma, per_sense in counts.items()
    }
    return counts, most_frequent


## Compute, for each set, the per-lemma sense counts and each lemma's most frequent sense.
# The "train" tables are built from the train and val parts together.
dict_lemmes_train, dict_most_frequent_train = frequence(tg_lemmas['train']+tg_lemmas['val'],label_strs['train']+label_strs['val'])
dict_lemmes_val, dict_most_frequent_val = frequence(tg_lemmas['val'],label_strs['val'])
dict_lemmes_dev, dict_most_frequent_dev = frequence(tg_lemmas['dev'],label_strs['dev'])


##### Evaluation du modèle MFS

In [None]:
## Fonction de calcul de l'accuracy du modèle de base
def baseline(dict_most_frequent, tg_lemmas, label_strs, train_or_dev='train+val'):
    """Score the most-frequent-sense (MFS) baseline.

    Each example is predicted as ``dict_most_frequent[lemma]``. A lemma missing
    from ``dict_most_frequent`` (unseen when the table was built) gets no
    prediction. It counts as an error instead of raising ``KeyError``.

    Args:
        dict_most_frequent: lemma -> most frequent sense.
        tg_lemmas: target lemmas of the evaluated examples.
        label_strs: gold senses aligned with ``tg_lemmas``.
        train_or_dev: name of the evaluated set, used in the message.

    Returns:
        (message, accuracy): accuracy is a percentage, 0.0 when there are no
        examples.
    """
    message = 'Accuracy of baseline model on ' + train_or_dev + ' is:'
    if not tg_lemmas:
        return message, 0.0
    amount_correct = sum(
        1
        for lemme, sens in zip(tg_lemmas, label_strs)
        if lemme in dict_most_frequent and dict_most_frequent[lemme] == sens
    )
    return message, amount_correct / len(tg_lemmas) * 100


In [None]:
## Compute the MFS baseline accuracy on each set.
# Only the last expression of a notebook cell is displayed, so the train+val
# result used to be discarded; print both results explicitly.
# NOTE: on dev, the MFS table comes from dev itself (dict_most_frequent_dev),
# which gives an optimistic score rather than a train-derived baseline.
print(*baseline(dict_most_frequent_train, tg_lemmas['train']+tg_lemmas['val'], label_strs['train']+label_strs['val']))
print(*baseline(dict_most_frequent_dev, tg_lemmas['dev'], label_strs['dev'], 'dev'))

Next, we find the lemmas, the senses (frame labels) and the (lemma, sense) pairs that occur in the dev set but never in the train+val sets.

In [None]:
from itertools import chain

# Build the train+val vocabularies once. Testing membership against a fresh
# chain(...) iterator for every dev item scanned the whole of train+val each time.
known_lemmas = set(chain(tg_lemmas['train'], tg_lemmas['val']))
known_senses = set(chain(label_strs['train'], label_strs['val']))

# Dev occurrences (repetitions kept, in dev order) whose lemma / sense never
# appears in train+val.
lemmes_inconnus = [lemma for lemma in tg_lemmas['dev'] if lemma not in known_lemmas]
sens_inconnus = [sens for sens in label_strs['dev'] if sens not in known_senses]



In [None]:
# Displayed as the cell output: number of distinct dev lemmas that never occur in train+val.
'The number of unknown lemmas in dev is', len(set(lemmes_inconnus))

In [None]:
# Displayed as the cell output: number of distinct dev senses (frame labels) that never occur in train+val.
'The number of unknown sens in dev is', len(set(sens_inconnus))

In [None]:
# (lemma, sense) pairs attested in dev whose lemma is known from train+val,
# but which that lemma was never annotated with in train+val. Lemmas absent
# from train+val are skipped (they are counted as unknown lemmas instead).
unknown_associations = [
    (lemma, dev_sense)
    for lemma, dev_senses in dict_lemmes_dev.items()
    if lemma in dict_lemmes_train
    for dev_sense in dev_senses
    if dev_sense not in dict_lemmes_train[lemma]
]




In [None]:
# Displayed as the cell output: number of distinct unseen (lemma, sense) pairs in dev.
'The number of unknown associations in dev is', len(set(unknown_associations))

## Modèle et tokenization de type *BERT

On va utiliser un modèle pré-entraîné de type *BERT, en passant par le module "transformers" d'huggingface.

In [None]:
## Install the required libraries when they are missing.
# The `!pip install` lines are IPython shell escapes: this cell only runs
# inside a Jupyter/Colab notebook, not as a plain Python script.
try:
  import transformers
except ImportError:
  !pip install transformers

try:
    # presumably needed by the FlauBERT tokenizer — TODO confirm
    import sacremoses
except ImportError:
    !pip install sacremoses
    

from transformers import AutoModel, AutoTokenizer, AutoConfig

In [None]:
## Load the tokenizer and configuration of the "flaubert/flaubert_base_cased"
## checkpoint (the model weights themselves are loaded by WSDClassifier).
flaubert_tokenizer = AutoTokenizer.from_pretrained("flaubert/flaubert_base_cased")
flaubert_config = AutoConfig.from_pretrained("flaubert/flaubert_base_cased")

### Encodage des données (correspondance entre rang de mot et rang de token)




Pour pouvoir utiliser un modèle *BERT pré-entraîné, il faut
- utiliser la même tokenisation en tokens (potentiellement des sous-mots) que celle utilisée à l'entraînement du modèle
- convertir les séquences de tokens en séquences d'ids de tokens 
- et maintenir un lien entre les rangs de mot dans la phrase (dont le rang du target) et les rangs de tokens

In [None]:
## Création de la classe WSDEncoder
from nltk.tokenize import WordPunctTokenizer
class WSDEncoder:
    """Encode sentences for a *BERT model while tracking target positions.

    Maps each target word rank (in words) to the rank of the target's first
    sub-word token in the encoded id sequence.
    """

    def __init__(self, tokenizer, config):
        """
        Args:
            tokenizer: tokenizer providing ``tokenize(word)`` and
                ``encode(tokens, **kwargs)`` (e.g. a Hugging Face tokenizer).
            config: the matching model configuration (stored, not used here).
        """
        self.tokenizer = tokenizer
        self.config = config

    ## Build sub-word token sequences and target token positions.
    def new_ranks(self, sentences, tg_works):
        """Tokenize every word of every sentence into sub-word tokens.

        Args:
            sentences: list of sentences, each a list of words.
            tg_works: target word rank of each sentence.

        Returns:
            (phrases, tg_trks): ``phrases[i]`` is the sub-word token list of
            sentence i. ``tg_trks[i]`` is the rank of its target's first
            sub-word token, shifted by one for the special token assumed to be
            prepended by ``encode`` (add_special_tokens=True).

        Raises:
            IndexError: if a target rank is outside its sentence.
        """
        tg_trks = []
        phrases = []
        for sent, rank in zip(sentences, tg_works):
            if not 0 <= rank < len(sent):
                raise IndexError("target rank %d out of range for a sentence of %d words"
                                 % (rank, len(sent)))
            phrase = []
            for index, word in enumerate(sent):
                # Match on the position, not the word form: the target word
                # may also occur earlier in the sentence.
                if index == rank:
                    tg_trks.append(len(phrase) + 1)
                phrase.extend(self.tokenizer.tokenize(word))
            phrases.append(phrase)
        return phrases, tg_trks

    ## Encode the sentences.
    def encode(self, sentences, tg_wrks, max_length=350, verbose=False, is_split_into_words=True):
        """Encode sentences into fixed-length token id sequences.

        Args:
            sentences: word lists, or whitespace-separated strings when
                ``is_split_into_words`` is false.
            tg_wrks: target word rank of each sentence.
            max_length: every id sequence is truncated / padded to this length.
            verbose: if true, print each sub-word sequence and its target rank.
            is_split_into_words: whether ``sentences`` are already word lists.

        Returns:
            (tokenized, first_trk_of_targets): one id sequence per sentence,
            and the target token rank of each sentence.

        Raises:
            ValueError: if a target token rank falls beyond ``max_length``;
                it could not be looked up in the encoded sequence.
        """
        if not is_split_into_words:
            sentences = [sentence.split() for sentence in sentences]
        phrases, first_trk_of_targets = self.new_ranks(sentences, tg_wrks)
        tokenized = []
        for phrase, trk in zip(phrases, first_trk_of_targets):
            if trk >= max_length:
                raise ValueError("target token rank %d does not fit in max_length=%d"
                                 % (trk, max_length))
            if verbose:
                print(phrase, trk)
            # padding='max_length' alone pads; the deprecated
            # pad_to_max_length flag is redundant with it.
            tokenized.append(self.tokenizer.encode(phrase, add_special_tokens=True,
                                                   truncation=True, max_length=max_length,
                                                   padding='max_length'))
        return tokenized, first_trk_of_targets

encoder = WSDEncoder(flaubert_tokenizer, flaubert_config)

# Raw whitespace-separated sentences; target ranks are word indices
# ("comprendraient", "comprendre", "compris").
test_sents = ["Conséquemment leurs codes comprendraient des erreurs .",
              "J' essaie de comprendre les transformers .",
              "Il n' a pas bien compris le code !"]
test_tg_wrks = [3, 3, 5]

# encode() always adds special tokens and pads to max_length; the message
# used to claim the opposite.
print("With padding and special tokens : ")
tid_seqs, first_trk_of_targets = encoder.encode(test_sents, test_tg_wrks, max_length=100, is_split_into_words=False)
for tid_seq, ft in zip(tid_seqs, first_trk_of_targets):
    print("Len = %d target token rank = %d tid_seq = %s" % (len(tid_seq), ft, str(tid_seq)))

#### Test encodage

In [None]:
# Encode one pre-split sentence (target "comprendraient", word rank 3) and print
# the sub-word tokens behind the ids, to check the reported target token rank.
encoder = WSDEncoder(flaubert_tokenizer, flaubert_config)
test_sents = [["Conséquemment", "leurs", "codes", "comprendraient", "des", "erreurs", "."]]
test_tg_wrks = [3]
tid_seqs, first_trk_of_targets = encoder.encode(test_sents, test_tg_wrks, max_length=20, verbose=True, is_split_into_words = True)

print('trgs',first_trk_of_targets)

for tid_seq, ft in zip(tid_seqs, first_trk_of_targets):
  print(flaubert_tokenizer.convert_ids_to_tokens(tid_seq))
  print("Len = %d target token rank = %d tid_seq = %s" % (len(tid_seq), ft, str(tid_seq)))

# Decode a single id; presumably an id observed in the output above — TODO confirm.
encoder.tokenizer.convert_ids_to_tokens(6090)


### Classe WSDData: encodage complet des données asfalda

In [None]:
import random
## Création de la classe WSDData
class WSDData:
    """Encoded examples of one ASFALDA corpus part, with batching.

    Keeps parallel per-example lists: ``labels`` (label ids), ``sentences``
    (word lists), ``tg_lemmas`` (lemma strings), ``tg_lemma_indexes`` (lemma
    ids), ``tid_seqs`` (token id sequences) and ``first_trk_of_targets``
    (target token ranks).
    """

    def __init__(self, corpus_type, sentences, tg_wrks, tg_lemmas, labels, encoder, max_length=350, lemma_index=None):
        """
        Args:
            corpus_type: name of the part (e.g. 'train', 'val', 'dev', 'test').
            sentences: word lists.
            tg_wrks: target word ranks.
            tg_lemmas: target lemma strings.
            labels: label ids.
            encoder: object with ``encode(sentences, tg_wrks, max_length)``
                returning (id sequences, target token ranks), e.g. a WSDEncoder.
            max_length: length of every token id sequence.
            lemma_index: lemma -> id mapping; defaults to the notebook-level
                ``lemma2i``.
        """
        if lemma_index is None:
            lemma_index = lemma2i
        self.corpus_type = corpus_type
        self.size = len(sentences)
        self.encoder = encoder

        self.labels = labels
        self.sentences = sentences
        self.tg_lemmas = tg_lemmas

        tid_seqs, first_trk_of_targets = encoder.encode(sentences, tg_wrks, max_length)
        self.tg_lemma_indexes = [lemma_index[lemma] for lemma in self.tg_lemmas]

        self.tid_seqs = tid_seqs
        self.first_trk_of_targets = first_trk_of_targets

    def _random_order(self):
        """Return a random permutation of the example indices."""
        order = list(range(len(self.labels)))
        random.shuffle(order)
        return order

    def _permuted(self, order):
        """Return all parallel lists reordered by ``order``: (labels, sentences,
        tg_lemmas, tg_lemma_indexes, tid_seqs, first_trk_of_targets)."""
        columns = (self.labels, self.sentences, self.tg_lemmas,
                   self.tg_lemma_indexes, self.tid_seqs, self.first_trk_of_targets)
        return tuple([column[i] for i in order] for column in columns)

    ## Shuffle the data.
    def shuffle(self):
        """Return a random permutation of the examples without modifying self.

        Returns:
            (labels, sentences, tg_lemma_indexes, tid_seqs,
            first_trk_of_targets) as tuples, all permuted identically (five
            empty tuples for an empty dataset).
        """
        labels, sentences, _, tg_lemma_indexes, tid_seqs, first_trk_of_targets = \
            self._permuted(self._random_order())
        return (tuple(labels), tuple(sentences), tuple(tg_lemma_indexes),
                tuple(tid_seqs), tuple(first_trk_of_targets))

    ## Batch generator.
    def make_batches(self, batch_size, shuffle_data=False):
        """Yield the examples in batches of ``batch_size`` (the last may be smaller).

        Args:
            batch_size: maximum number of examples per batch.
            shuffle_data: if true, first permute *all* parallel lists of self
                in place with one permutation, so later calls see the new order.

        Yields:
            (b_tid_seqs, b_tg_trks, b_labels, b_lemmas) slices, where
            ``b_lemmas`` are lemma ids.
        """
        if shuffle_data:
            # Every parallel list gets the same permutation. Previously the
            # shuffled lemma ids were stored into ``tg_lemmas`` while
            # ``tg_lemma_indexes`` kept the old order, so batches paired
            # examples with the wrong lemmas.
            (self.labels, self.sentences, self.tg_lemmas, self.tg_lemma_indexes,
             self.tid_seqs, self.first_trk_of_targets) = self._permuted(self._random_order())
        n = len(self.labels)
        for bstart in range(0, n, batch_size):
            bend = bstart + batch_size
            b_labels = self.labels[bstart:bend]
            b_tid_seqs = self.tid_seqs[bstart:bend]
            assert len(b_labels) == len(b_tid_seqs)
            yield (b_tid_seqs, self.first_trk_of_targets[bstart:bend], b_labels,
                   self.tg_lemma_indexes[bstart:bend])

In [None]:
## Build the encoded data of every corpus part (train, val, dev, test).
MAX_LENGTH = 300
wsd_data = {}
for p in sentences.keys():
    print("Encodage de la partie %s ..." % p)
    wsd_data[p] = WSDData(p, sentences[p], tg_wrks[p], tg_lemmas[p], labels[p], 
                          encoder, max_length=MAX_LENGTH)
    # Sanity check: every id sequence should be exactly MAX_LENGTH long.
    for i, s in enumerate(wsd_data[p].tid_seqs):
        if len(s) != MAX_LENGTH:
            print("Size bug:", i, s)

## Classe WSDClassifier : le réseau de neurones pour la WSD

Architecture de base = 
- le modèle *BERT (ici FlauBERT)
- puis une couche linéaire + softmax

### Le réseau : architecture, propagation avant, évaluation

In [None]:
## Création de la classe MLP
class MLP(nn.Module):
    """One-hidden-layer perceptron: Linear -> tanh -> Linear.

    Args:
        input_size: size of the input features (last dimension).
        output_size: number of output scores.
        hidden_size: width of the hidden layer.
    """

    def __init__(self, input_size, output_size, hidden_size):
        super().__init__()
        # Attribute names are the state_dict keys; keep them stable.
        self.encoder = nn.Linear(input_size, hidden_size)
        self.decoder = nn.Linear(hidden_size, output_size)
        self.activation = nn.Tanh()

    def forward(self, xinput):
        """Return unnormalized scores with last dimension ``output_size``."""
        hidden = self.encoder(xinput)
        hidden = self.activation(hidden)
        return self.decoder(hidden)




In [None]:
## création de la classe WSDClassifier
class WSDClassifier(nn.Module):
    """Sense (frame) classifier on top of a pre-trained *BERT encoder.

    The *BERT hidden state at the target's first sub-word token is fed to a
    linear layer or to an ``MLP``, optionally concatenated with a learned
    lemma embedding first. A log-softmax over the labels follows.
    """

    def __init__(self, num_labels, device='cpu', bert_model_name="flaubert/flaubert_base_cased", freeze_bert = True, use_mlp = False, hidden_size = 100,nbr_lemmas = len(lemma2i), lemma_embedding_size = 518, add_lemmas = False):
        """
        Args:
            num_labels: number of output labels.
            device: torch device the sub-modules are moved to.
            bert_model_name: model name loaded with ``AutoModel``/``AutoConfig``.
            freeze_bert: if true, the *BERT parameters get requires_grad=False.
            use_mlp: use an ``MLP`` head instead of a single linear layer.
            hidden_size: hidden width of the MLP head (unused otherwise).
            nbr_lemmas: lemma vocabulary size (default evaluated once, at class
                definition, from ``lemma2i``).
            lemma_embedding_size: size of the lemma embedding (if add_lemmas).
            add_lemmas: concatenate a lemma embedding to the target vector.
        """
        super().__init__()
        self.device = device
        self.use_mlp = use_mlp
        self.add_lemmas = add_lemmas
        self.bert_layer = AutoModel.from_pretrained(bert_model_name).to(self.device)
        self.bert_config = AutoConfig.from_pretrained(bert_model_name)
        # Size of the vector fed to the head.
        self.hidden_size_bert = int(self.bert_config.hidden_size)
        if self.add_lemmas:
            self.hidden_size_bert += int(lemma_embedding_size)
            print('lemmahidden', self.hidden_size_bert)
            self.lemma_embedding = nn.Embedding(nbr_lemmas, lemma_embedding_size).to(self.device)
        if freeze_bert:
            for param in self.bert_layer.parameters():
                param.requires_grad = False
        if self.use_mlp:
            # hidden_size used to be ignored (100 was hard-coded).
            self.mlp = MLP(self.hidden_size_bert, num_labels, hidden_size).to(self.device)
        else:
            self.linear = torch.nn.Linear(self.hidden_size_bert, num_labels).to(self.device)
        # Produces log-probabilities (despite the name), as NLLLoss expects.
        self.softmax = torch.nn.LogSoftmax(dim=1).to(self.device)

    ## Prediction
    def forward(self, b_tid_seq, b_tg_trk, b_tg_lemma_indexes = None):
        """Return label log-probabilities for a batch.

        Args:
            b_tid_seq: token id sequences, a (batch, seq_len) tensor.
            b_tg_trk: target token rank of each sequence, a (batch,) tensor.
            b_tg_lemma_indexes: lemma ids, a (batch,) tensor; required when the
                model was built with ``add_lemmas=True``, ignored otherwise.

        Returns:
            A (batch, num_labels) tensor of log-probabilities.

        Raises:
            ValueError: if lemma ids are required but missing.
        """
        embeddings_bert = self.bert_layer(b_tid_seq, return_dict=True).last_hidden_state
        # One vector per example: the last hidden state at its target token rank.
        rows = torch.arange(embeddings_bert.size(0), device=embeddings_bert.device)
        embeddings_bert_tgt = embeddings_bert[rows, b_tg_trk]
        if self.add_lemmas:
            if b_tg_lemma_indexes is None:
                raise ValueError("classifier built with add_lemmas=True: "
                                 "b_tg_lemma_indexes is required")
            lemma_embeddings = self.lemma_embedding(b_tg_lemma_indexes)
            embeddings_bert_tgt = torch.cat((embeddings_bert_tgt, lemma_embeddings), dim=1)
        head = self.mlp if self.use_mlp else self.linear
        return self.softmax(head(embeddings_bert_tgt))

    ## Evaluation pass over a dataset
    def run_on_dataset(self, wsd_data, optimizer, batch_size=32, validation_use = False):
        """Predict on a whole dataset, in eval mode and without gradients.

        Args:
            wsd_data: a ``WSDData``; iterated in its current order (not shuffled).
            optimizer: unused, kept for backward compatibility.
            batch_size: evaluation batch size.
            validation_use: if true, also compute the (unweighted) NLL loss of
                every batch.

        Returns:
            (pred_labels, val_losses, accuracy, b_labels, batch_acc):
            - pred_labels: predicted label ids (0-d tensors), in dataset order.
            - val_losses: per-batch losses; empty unless ``validation_use``.
            - accuracy: over all examples, 0.0 for an empty dataset.
            - b_labels: gold label tensor of the last batch (None if no batch).
            - batch_acc: per-batch accuracies.

        Note: leaves the model in eval mode.
        """
        pred_labels = []
        val_losses = []
        batch_acc = []
        n_correct = 0
        n_total = 0
        b_labels = None

        loss_function = nn.NLLLoss()
        self.eval()
        with torch.no_grad():
            for b_tid_seqs, b_tg_trks, b_labels, b_lemma_idx in wsd_data.make_batches(batch_size, shuffle_data=False):
                b_tid_seqs = torch.tensor(b_tid_seqs, device=self.device)
                b_tg_trks = torch.tensor(b_tg_trks, device=self.device)
                b_labels = torch.tensor(b_labels, device=self.device)
                b_lemma_idx = torch.tensor(b_lemma_idx, device=self.device)
                log_probs = self(b_tid_seqs, b_tg_trks, b_lemma_idx)
                b_pred_labels = torch.argmax(log_probs, dim=1)
                pred_labels.extend(b_pred_labels)
                if validation_use:
                    val_losses.append(loss_function(log_probs, b_labels).item())
                batch_acc.append(self.evaluate(b_labels, b_pred_labels))
                n_correct += int(torch.sum(b_labels == b_pred_labels))
                n_total += len(b_labels)
        # Accuracy over all examples: a plain mean of batch accuracies would
        # over-weight a smaller last batch.
        accuracy = n_correct / n_total if n_total else 0.0
        return pred_labels, val_losses, accuracy, b_labels, batch_acc

    ## Accuracy of one batch
    def evaluate(self, gold_labels, pred_labels):
        """Return the fraction of positions where ``pred_labels`` equals ``gold_labels``."""
        acc = float(torch.sum(gold_labels == pred_labels)) / len(gold_labels)
        return acc



        


In [None]:
# Base model with WSDClassifier defaults: linear head on a frozen FlauBERT.
num_labels = len(i2label)
classifier = WSDClassifier(num_labels)

# List every parameter with its shape and whether it is trainable.
for name, param in classifier.named_parameters():
  print("PARAM named %s, of shape %s" % (name, str(param.shape)))
  print(param.requires_grad)
    


#### Test de la propagation avant

In [None]:


# Smoke test of the forward pass on one dev batch before any training.
# NOTE: shuffle_data=True permutes wsd_data['dev'] in place.
with torch.no_grad():
    classifier.eval()
    for b_tid_seqs, b_tg_trks, b_labels, _ in wsd_data['dev'].make_batches(32, shuffle_data=True):
        b_tid_seqs = torch.tensor(b_tid_seqs, device=classifier.device)
        b_tg_trks = torch.tensor(b_tg_trks, device=classifier.device)
        b_labels = torch.tensor(b_labels, device=classifier.device).to(classifier.device)
        print('input size : ',b_tid_seqs.size()," reference size : ",b_tg_trks.size())
        log_probs = classifier(b_tid_seqs, b_tg_trks)
        print('output size : ',log_probs.size()," reference size : ",b_labels.size())
       
        gold = b_labels[0] #.item()
        print("GOLD LABEL of first ex %d ( = %s)" % (gold, i2label[gold]))
        print("LOG_PROBS before training: %s\n\n" % str(log_probs[0]))
        # Only the first batch is inspected.
        break
        

        



### Entraînement : fine-tuning sur la tâche de WSD

In [145]:
from statistics import mean
import numpy as np

BATCH_SIZE = 32
LR = 0.0005
n_epochs = 20
# Stop after this many consecutive epochs without validation improvement.
early_stopping_patience = 6

loss_function = nn.NLLLoss()
optimizer = optim.Adam(classifier.parameters(), lr=LR)

# Previously a syntax error: the assignment target was missing.
config_name = 'sequoiaftb.asfalda_1_3.wsd.lr' + 'Adam' + str(LR) + '_bs' + str(BATCH_SIZE)
out_model_file = './' + config_name + '.model'
out_log_file = './' + config_name + '.log'

train_losses = []   # training loss of every batch, all epochs
epoch_losses = []   # mean training loss of each epoch
val_losses = []     # mean validation loss of each epoch
val_accs = []       # validation accuracy of each epoch
min_val_loss = None
train_data = wsd_data['train']
val_data = wsd_data['val']
dev_data = wsd_data['dev']

print('Training.....')
acc = 0
epochs_without_improvement = 0
for epoch in range(n_epochs):
    classifier.train()
    print('Training..... epoch nr: ', epoch)
    epoch_batch_losses = []
    for b_tid_seqs, b_tg_trks, b_labels, _ in train_data.make_batches(BATCH_SIZE, shuffle_data=True):
        optimizer.zero_grad()
        b_tid_seqs = torch.tensor(b_tid_seqs, device=classifier.device)
        b_tg_trks = torch.tensor(b_tg_trks, device=classifier.device)
        b_labels = torch.tensor(b_labels, device=classifier.device)
        log_probs = classifier(b_tid_seqs, b_tg_trks)
        loss = loss_function(log_probs, b_labels)
        loss.backward()
        optimizer.step()
        train_losses.append(loss.item())
        epoch_batch_losses.append(loss.item())
    # Mean over this epoch only (not a running mean over all epochs).
    epoch_losses.append(mean(epoch_batch_losses))

    pred_labels, batch_val_losses, val_acc, _, _ = classifier.run_on_dataset(val_data, optimizer, batch_size=BATCH_SIZE, validation_use=True)
    val_losses.append(mean(batch_val_losses))
    val_accs.append(val_acc)
    print('--------')
    print('train loss: ', epoch_losses[-1], 'val accuracy: ', val_acc)
    print('--------')
    if val_acc >= acc:
        acc = val_acc
        epochs_without_improvement = 0
    else:
        epochs_without_improvement += 1
        if epochs_without_improvement == early_stopping_patience:
            print('Stopping early...')
            break

print("train losses: %s" % ' / '.join("%.4f" % x for x in epoch_losses))
print("val   losses: %s" % ' / '.join("%.4f" % x for x in val_losses))
print("val   accs:   %s" % ' / '.join("%.4f" % x for x in val_accs))



In [None]:
import matplotlib.pyplot as plt
plt.plot(epoch_losses)
# Axis labels (the y label used to be the number n_epochs).
plt.xlabel('epoch')
plt.ylabel('training loss')
plt.show()

In [None]:
import matplotlib.pyplot as plt
plt.plot(val_accs)
# Axis labels (the y label used to be the number n_epochs).
plt.xlabel('epoch')
plt.ylabel('validation accuracy')
plt.show()

We then added three options: lemma embeddings, an MLP classification head, and class weights to balance the loss across frames.
The results are shown below.

In [None]:
# MLP head (hidden size 100) on the target vector concatenated with a 518-dim lemma embedding.
classifier_mlp_weights_lemmas = WSDClassifier(num_labels, use_mlp = True, hidden_size = 100,nbr_lemmas = len(lemma2i), lemma_embedding_size = 518, add_lemmas = True)

In [None]:
from statistics import mean
import numpy as np
from sklearn.utils.class_weight import compute_class_weight

BATCH_SIZE = 32
LR = 0.0005
n_epochs = 30
# Stop after this many consecutive epochs without validation improvement.
early_stopping_patience = 6
model = classifier_mlp_weights_lemmas

# "Balanced" class weights from the *training* label distribution
# (n_train / (n_seen_classes * count)). Passing the label ids themselves as y
# gave every class weight 1.0. classes/y are passed by keyword because recent
# scikit-learn makes them keyword-only. compute_class_weight rejects classes
# absent from y. Labels unseen in training never occur as training targets,
# so they keep weight 1.0.
train_label_ids = np.asarray(wsd_data['train'].labels)
seen_classes = np.unique(train_label_ids)
class_weight = np.ones(len(label2i))
class_weight[seen_classes] = compute_class_weight("balanced", classes=seen_classes, y=train_label_ids)
class_weight = torch.tensor(class_weight, dtype=torch.float, device=model.device)

# Weighted training loss; note that run_on_dataset reports an unweighted val loss.
loss_function = nn.NLLLoss(weight=class_weight)
optimizer = optim.Adam(model.parameters(), lr=LR)

config_name = 'sequoiaftb.asfalda_1_3.wsd.lr' + 'Adam' + str(LR) + '_bs' + str(BATCH_SIZE)
out_model_file = './' + config_name + '.model'
out_log_file = './' + config_name + '.log'

train_losses = []   # training loss of every batch, all epochs
epoch_losses = []   # mean training loss of each epoch
val_losses = []     # mean validation loss of each epoch
val_accs = []       # validation accuracy of each epoch
min_val_loss = None
train_data = wsd_data['train']
val_data = wsd_data['val']
dev_data = wsd_data['dev']

print('Training.....')
acc = 0
epochs_without_improvement = 0
for epoch in range(n_epochs):
    model.train()
    print('acc', acc)
    print('Training..... epoch nr: ', epoch)
    epoch_batch_losses = []
    for b_tid_seqs, b_tg_trks, b_labels, b_lemma_idx in train_data.make_batches(BATCH_SIZE, shuffle_data=True):
        optimizer.zero_grad()
        b_tid_seqs = torch.tensor(b_tid_seqs, device=model.device)
        b_tg_trks = torch.tensor(b_tg_trks, device=model.device)
        b_labels = torch.tensor(b_labels, device=model.device)
        b_lemma_idx = torch.tensor(b_lemma_idx, device=model.device)
        log_probs = model(b_tid_seqs, b_tg_trks, b_lemma_idx)
        loss = loss_function(log_probs, b_labels)
        loss.backward()
        optimizer.step()
        train_losses.append(loss.item())
        epoch_batch_losses.append(loss.item())
    epoch_losses.append(mean(epoch_batch_losses))

    pred_labels, batch_val_losses, val_acc, _, _ = model.run_on_dataset(val_data, optimizer, batch_size=BATCH_SIZE, validation_use=True)
    val_losses.append(mean(batch_val_losses))
    val_accs.append(val_acc)
    print('--------')
    print('train loss: ', epoch_losses[-1], 'val accuracy: ', val_acc)
    print('--------')

    # The old counters (patience 6 incremented, compared to 0) never triggered.
    if val_acc >= acc:
        acc = val_acc
        epochs_without_improvement = 0
    else:
        epochs_without_improvement += 1
        if epochs_without_improvement == early_stopping_patience:
            print('Stopping early...')
            break

print("train losses: %s" % ' / '.join("%.4f" % x for x in epoch_losses))
print("val   losses: %s" % ' / '.join("%.4f" % x for x in val_losses))



In [None]:
import matplotlib.pyplot as plt
plt.plot(epoch_losses)
# Axis labels (the y label used to be the number n_epochs).
plt.xlabel('epoch')
plt.ylabel('training loss')
plt.show()

In [None]:
import matplotlib.pyplot as plt
plt.plot(val_accs)
# Axis labels (the y label used to be the number n_epochs).
plt.xlabel('epoch')
plt.ylabel('validation accuracy')
plt.show()

Below is a variant of the classifier that uses class weights and an MLP layer, but NOT lemma embeddings.

In [None]:
# MLP head (hidden size 100) without lemma embeddings (add_lemmas=False, so the lemma arguments are unused).
classifier_mlp_weights = WSDClassifier(num_labels, use_mlp = True, hidden_size = 100,nbr_lemmas = len(lemma2i), lemma_embedding_size = 518, add_lemmas = False)

In [None]:
from statistics import mean
import numpy as np
from sklearn.utils.class_weight import compute_class_weight

BATCH_SIZE = 32
LR = 0.0005
n_epochs = 25
# Stop after this many consecutive epochs without validation improvement.
early_stopping_patience = 6
model = classifier_mlp_weights

# "Balanced" class weights from the *training* label distribution
# (n_train / (n_seen_classes * count)). Passing the label ids themselves as y
# gave every class weight 1.0. classes/y are passed by keyword because recent
# scikit-learn makes them keyword-only. compute_class_weight rejects classes
# absent from y. Labels unseen in training never occur as training targets,
# so they keep weight 1.0.
train_label_ids = np.asarray(wsd_data['train'].labels)
seen_classes = np.unique(train_label_ids)
class_weight = np.ones(len(label2i))
class_weight[seen_classes] = compute_class_weight("balanced", classes=seen_classes, y=train_label_ids)
class_weight = torch.tensor(class_weight, dtype=torch.float, device=model.device)

# Weighted training loss; note that run_on_dataset reports an unweighted val loss.
loss_function = nn.NLLLoss(weight=class_weight)
optimizer = optim.Adam(model.parameters(), lr=LR)

config_name = 'sequoiaftb.asfalda_1_3.wsd.lr' + 'Adam' + str(LR) + '_bs' + str(BATCH_SIZE)
out_model_file = './' + config_name + '.model'
out_log_file = './' + config_name + '.log'

train_losses = []   # training loss of every batch, all epochs
epoch_losses = []   # mean training loss of each epoch
val_losses = []     # mean validation loss of each epoch
val_accs = []       # validation accuracy of each epoch
min_val_loss = None
train_data = wsd_data['train']
val_data = wsd_data['val']
dev_data = wsd_data['dev']

print('Training.....')
acc = 0
epochs_without_improvement = 0
for epoch in range(n_epochs):
    model.train()
    print('acc', acc)
    print('Training..... epoch nr: ', epoch)
    epoch_batch_losses = []
    for b_tid_seqs, b_tg_trks, b_labels, b_lemma_idx in train_data.make_batches(BATCH_SIZE, shuffle_data=True):
        optimizer.zero_grad()
        b_tid_seqs = torch.tensor(b_tid_seqs, device=model.device)
        b_tg_trks = torch.tensor(b_tg_trks, device=model.device)
        b_labels = torch.tensor(b_labels, device=model.device)
        b_lemma_idx = torch.tensor(b_lemma_idx, device=model.device)
        # Lemma ids are ignored by this model (add_lemmas=False).
        log_probs = model(b_tid_seqs, b_tg_trks, b_lemma_idx)
        loss = loss_function(log_probs, b_labels)
        loss.backward()
        optimizer.step()
        train_losses.append(loss.item())
        epoch_batch_losses.append(loss.item())
    epoch_losses.append(mean(epoch_batch_losses))

    pred_labels, batch_val_losses, val_acc, _, _ = model.run_on_dataset(val_data, optimizer, batch_size=BATCH_SIZE, validation_use=True)
    val_losses.append(mean(batch_val_losses))
    val_accs.append(val_acc)
    print('--------')
    print('train loss: ', epoch_losses[-1], 'val accuracy: ', val_acc)
    print('--------')

    # The old counters (patience 6 incremented, compared to 0) never triggered.
    if val_acc >= acc:
        acc = val_acc
        epochs_without_improvement = 0
    else:
        epochs_without_improvement += 1
        if epochs_without_improvement == early_stopping_patience:
            print('Stopping early...')
            break

print("train losses: %s" % ' / '.join("%.4f" % x for x in epoch_losses))
print("val   losses: %s" % ' / '.join("%.4f" % x for x in val_losses))



In [None]:
import matplotlib.pyplot as plt
plt.plot(epoch_losses)
# Axis labels (the y label used to be the number n_epochs).
plt.xlabel('epoch')
plt.ylabel('training loss')
plt.show()

In [None]:
import matplotlib.pyplot as plt
plt.plot(val_accs)
# Axis labels (the y label used to be the number n_epochs).
plt.xlabel('epoch')
plt.ylabel('validation accuracy')
plt.show()

### Evaluation

In [None]:
# Held-out test part, used only by the final evaluations below.
test_data = wsd_data['test']

**Normal version of classifier**

In [None]:
# Evaluate the base classifier on dev and test (run_on_dataset does not use the optimizer argument).
pred_labels, val_losses, dev_acc, b_labels, _ = classifier.run_on_dataset(dev_data, optimizer, batch_size=32, validation_use=False)
pred_labels, val_losses, test_acc, b_labels, _ = classifier.run_on_dataset(test_data, optimizer, batch_size=32, validation_use=False)

In [None]:
# Accuracies are fractions in [0, 1].
print('dev acc: ', dev_acc, 'test acc:', test_acc)

**Classifier with mlp, lemmas and weights**

In [None]:
# Evaluate the MLP + lemma-embedding classifier on dev and test.
pred_labels, val_losses, dev_acc1, b_labels, _ = classifier_mlp_weights_lemmas.run_on_dataset(dev_data, optimizer, batch_size=32, validation_use=False)
pred_labels, val_losses, test_acc1, b_labels, _ = classifier_mlp_weights_lemmas.run_on_dataset(test_data, optimizer, batch_size=32, validation_use=False)

In [None]:
# Accuracies are fractions in [0, 1].
print('dev acc: ', dev_acc1, 'test acc:', test_acc1)

**Classifier with weights**

In [None]:
# NOTE: none of the cells above builds a ``classifier_weights`` model (class
# weights only, no MLP / lemmas), so an unconditional reference raised
# NameError. Evaluate it only if it exists; otherwise report None.
if 'classifier_weights' in globals():
    pred_labels, val_losses, dev_acc2, b_labels, _ = classifier_weights.run_on_dataset(dev_data, optimizer, batch_size=32, validation_use=False)
    pred_labels, val_losses, test_acc2, b_labels, _ = classifier_weights.run_on_dataset(test_data, optimizer, batch_size=32, validation_use=False)
else:
    print('classifier_weights is not defined in this notebook; skipping its evaluation.')
    dev_acc2 = test_acc2 = None

In [None]:
# Accuracies are fractions in [0, 1].
print('dev acc: ', dev_acc2, 'test acc:' , test_acc2)

**Classifier mlp and weights**

In [None]:
# Evaluate the MLP classifier (no lemma embeddings) on dev and test.
pred_labels, val_losses, dev_acc3, b_labels, _ = classifier_mlp_weights.run_on_dataset(dev_data, optimizer, batch_size=32, validation_use=False)
pred_labels, val_losses, test_acc3, b_labels, _ = classifier_mlp_weights.run_on_dataset(test_data, optimizer, batch_size=32, validation_use=False)

In [None]:
# Accuracies are fractions in [0, 1].
print('dev acc: ', dev_acc3, 'test acc:', test_acc3)