In [1]:
%%capture
!pip install transformers
!pip install datasets
!pip install conllu

In [3]:
import torch
import itertools
import numpy as np
import seaborn as sns
import torch.nn as nn
from functools import partial
import matplotlib.pyplot as plt
from datasets import load_dataset
from collections import defaultdict
from transformers import BertTokenizer, BertModel

# Bert model and other hyperparameters

In [4]:
# Hyperparameters
BERT_MODEL = 'dbmdz/bert-base-italian-xxl-cased'  # Italian BERT checkpoint (hidden size 768)
DROPOUT_RATE = 0.2
MLP_SIZE = 300                   # hidden units of the move-classifier MLP
INPUT_SIZE = 3*768               # three concatenated BERT vectors: sigma_2, sigma_1, beta_1
# NOTE(review): BERT fine-tuning is usually done with 2e-5..5e-5 for 2-4 epochs;
# 1e-3 with Adam on all BERT weights is likely to destabilise training.
LEARNING_RATE = 0.001
EPOCHS = 30
SAVE_PATH = "saved_checkpoint.pt"
TOKENIZER_MAX_LEN = 150          # padding/truncation length used by the BERT tokenizer

# Fix the random seeds so data shuffling, dropout and weight init are reproducible.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Arc-standard

Recall that a **configuration** of the arc-standard parser is a triple of the form $( \sigma, \beta, A)$
where:

* $\sigma$ is the stack;
* $\beta$ is the input buffer;
* $A$ is a set of arcs constructed so far.

We write $\sigma_i$, $i \geq 1$, for the $i$-th token in the stack; we also write $\beta_i$, $i \geq 1$, for the $i$-th token in the buffer. 

The parser can perform three types of **actions** (transitions):

* **shift**, which removes $\beta_1$ from the buffer and pushes it into the stack;
* **left-arc**, which creates the arc $(\sigma_1 \rightarrow \sigma_2)$, and removes $\sigma_2$ from the stack;
* **right-arc**, which creates the arc $(\sigma_2 \rightarrow \sigma_1)$, and removes $\sigma_1$ from the stack.

Let $w = w_0 w_1 \cdots w_{n}$ be the input sentence, with $w_0$ the special symbol `<ROOT>`.
Stack and buffer are implemented as lists of integers, where `j` represents word $w_j$.  Top-most stack token is at the right-end of the list; first buffer token is at the left-end of the list. 
Set $A$ is implemented as an array `arcs` of size $n+1$ such that if arc $(w_i \rightarrow w_j)$ is in $A$ then `arcs[j]=i`, and if $w_j$ is still missing its head node in the tree under construction, then `arcs[j]=-1`. We always have `arcs[0]=-1`.  We use this representation also for complete dependency trees.


In [5]:
import orig #file orig.py in the same foldr contains original class and function definitions

In [6]:
class ArcStandard (orig.ArcStandard):
    """Labeled arc-standard parser.

    Extends orig.ArcStandard (shift() and is_tree_final() are inherited, not
    defined here) by storing one (head, label id) pair per token in `arcs`;
    (-1, -1) means "no head assigned yet".
    """

    def __init__(self, sentence, label_set):
        """Build the initial configuration.

        sentence: token sequence with <ROOT> at position 0.
        label_set: dict label name -> integer id (as built by create_deprel_dict).
        """
        self.sentence = sentence
        self.buffer = list(range(len(self.sentence)))
        self.stack = []
        self.arcs = [(-1, -1) for _ in range(len(self.sentence))]  # (head, label id) per token

        # Inverse map label id -> label name, only used by print_configuration.
        self.label_set = {y: x for x, y in label_set.items()}

        # Initial shifts: three for sentences with 3+ positions (as before), but
        # never more shifts than there are tokens (assuming each shift consumes
        # one buffer token), so 0/1-token inputs no longer shift from an empty buffer.
        for _ in range(min(3, len(self.sentence))):
            self.shift()

    def left_arc(self, deprel):
        """Attach sigma_2 to sigma_1 with label `deprel` and remove sigma_2."""
        o1 = self.stack.pop()
        o2 = self.stack.pop()
        self.arcs[o2] = (o1, deprel)
        self.stack.append(o1)
        # keep at least two items on the stack while input remains
        if len(self.stack) < 2 and len(self.buffer) > 0:
            self.shift()

    def right_arc(self, deprel):
        """Attach sigma_1 to sigma_2 with label `deprel` and remove sigma_1."""
        o1 = self.stack.pop()
        o2 = self.stack.pop()
        self.arcs[o1] = (o2, deprel)
        self.stack.append(o2)
        # keep at least two items on the stack while input remains
        if len(self.stack) < 2 and len(self.buffer) > 0:
            self.shift()

    def print_configuration(self):
        """Print stack and buffer contents, then the arcs with label names."""
        s = [self.sentence[i] for i in self.stack]
        b = [self.sentence[i] for i in self.buffer]
        print("STACK: ", s, "BUFFER: ", b)
        # Label ids without a name are printed as-is instead of raising KeyError.
        print([(head, self.label_set.get(label, label)) for head, label in self.arcs])

In [7]:
class Oracle (orig.Oracle):
    """Static oracle (from orig.Oracle) that also reports gold dependency labels."""

    def __init__(self, parser, gold_tree, gold_labels):
        self.parser = parser
        self.gold = gold_tree
        self.labels = gold_labels  # integer label ids, parallel to gold_tree

    def get_gold_label(self):
        """Return the gold label id of the dependent of the gold arc move.

        Raises Exception when neither arc move is gold (i.e. shift is).
        """
        stack = self.parser.stack
        if self.is_left_arc_gold():
            dependent = stack[len(stack) - 2]   # left-arc: sigma_2 is the dependent
        elif self.is_right_arc_gold():
            dependent = stack[len(stack) - 1]   # right-arc: sigma_1 is the dependent
        else:
            raise Exception("Action SHIFT does not produce a labeled dependency")
        return self.labels[dependent]
            

## Functions to simplify the processing

In [8]:
"""
dataset should be a Dataset object from HuggingFace
Return a dictionary where arc labels are associate to integer values.
Label '_' is removed (used for composite tokens)
Labels are sorted in alphabetical order because MSE in the neural network will penalize less classes that are "close".
<ROOT> label is given value -1 (label of the <ROOT> node).
"""
def create_deprel_dict(dataset):
    labels = set()
    for x in dataset['deprel']:
        for elem in x:
            if(elem != '_'):
                labels.add(elem)
    labels = sorted(list(labels))
    ret = {labels[i]:i for i in range(len(labels))}
    ret.update({"None":-1})
    return ret

"""
Return (sanitized_tokens, gold tree, gold labels) for a sentence in the dataset, assuming gold tree is in a column labeled 'head' and labels in a column labeled 'deprel'.
Both token indices and labels are converted to integer (in the latter case, according to deprel_dict).
Sanitized tokens is a tokenlist with removal of composite tokens.
"""
def create_gold(sentence, deprel_dict):
    sanitized = ['<ROOT>']
    gold_tree = [-1]
    gold_labels = [-1]
    for i in range(len(sentence['tokens'])):
        if(sentence['head'][i] != 'None') and sentence['deprel'][i] in deprel_dict: #second condition is to avoid unseen or wrong labels
            sanitized.append(sentence['tokens'][i])
            gold_tree.append(int(sentence['head'][i]))
            gold_labels.append(deprel_dict[sentence['deprel'][i]])
    
    return sanitized, gold_tree, gold_labels 

## Testing parser and oracle...

In [9]:
%%capture
# Italian ISDT treebank (Universal Dependencies), training split only, for the demos below.
train_dataset = load_dataset('universal_dependencies', name='it_isdt', split="train")

In [10]:
deprels = create_deprel_dict(train_dataset)
sentence = train_dataset[3]

# Gold structures of one sample sentence: tokens, head positions, label ids (one per line).
tokens, gold_tree, gold_labels = create_gold(sentence, deprels)
print(tokens, gold_tree, gold_labels, sep="\n")

['<ROOT>', 'Inconsueto', 'allarme', 'a', 'la', 'Tate', 'Gallery', ':']
[-1, 2, 0, 5, 5, 2, 5, 2]
[-1, 4, 41, 8, 17, 31, 28, 40]


In [11]:
parser = ArcStandard(tokens, deprels)
parser.print_configuration()  # configuration right after the initial shifts
oracle = Oracle(parser, gold_tree, gold_labels)

STACK:  ['<ROOT>', 'Inconsueto', 'allarme'] BUFFER:  ['a', 'la', 'Tate', 'Gallery', ':']
[(-1, 'None'), (-1, 'None'), (-1, 'None'), (-1, 'None'), (-1, 'None'), (-1, 'None'), (-1, 'None'), (-1, 'None')]


In [12]:
# Replay the gold derivation; transition precedence here is shift > left-arc > right-arc.
while not parser.is_tree_final():
    if oracle.is_shift_gold():
        parser.shift()
    elif oracle.is_left_arc_gold():
        parser.left_arc(oracle.get_gold_label())
    elif oracle.is_right_arc_gold():
        parser.right_arc(oracle.get_gold_label())
    else:
        # No gold move (e.g. a non-projective tree): fail instead of looping forever.
        raise RuntimeError("No gold transition for the current configuration")

    parser.print_configuration()

STACK:  ['<ROOT>', 'allarme'] BUFFER:  ['a', 'la', 'Tate', 'Gallery', ':']
[(-1, 'None'), (2, 'amod'), (-1, 'None'), (-1, 'None'), (-1, 'None'), (-1, 'None'), (-1, 'None'), (-1, 'None')]
STACK:  ['<ROOT>', 'allarme', 'a'] BUFFER:  ['la', 'Tate', 'Gallery', ':']
[(-1, 'None'), (2, 'amod'), (-1, 'None'), (-1, 'None'), (-1, 'None'), (-1, 'None'), (-1, 'None'), (-1, 'None')]
STACK:  ['<ROOT>', 'allarme', 'a', 'la'] BUFFER:  ['Tate', 'Gallery', ':']
[(-1, 'None'), (2, 'amod'), (-1, 'None'), (-1, 'None'), (-1, 'None'), (-1, 'None'), (-1, 'None'), (-1, 'None')]
STACK:  ['<ROOT>', 'allarme', 'a', 'la', 'Tate'] BUFFER:  ['Gallery', ':']
[(-1, 'None'), (2, 'amod'), (-1, 'None'), (-1, 'None'), (-1, 'None'), (-1, 'None'), (-1, 'None'), (-1, 'None')]
STACK:  ['<ROOT>', 'allarme', 'a', 'Tate'] BUFFER:  ['Gallery', ':']
[(-1, 'None'), (2, 'amod'), (-1, 'None'), (-1, 'None'), (5, 'det'), (-1, 'None'), (-1, 'None'), (-1, 'None')]
STACK:  ['<ROOT>', 'allarme', 'Tate'] BUFFER:  ['Gallery', ':']
[(-1, 'No

## Testing the whole dataset

In [13]:
# remember to call orig.is_projective() to check for projectivity!!
# Sentence 10 of the training split is a non-projective example.
deprels = create_deprel_dict(train_dataset)
sentence = train_dataset[10]
tokens, gold_tree, gold_labels = create_gold(sentence, deprels)
orig.is_projective(gold_tree)

False

In [14]:
### NOTE: THIS CELL WAS COPIED WITH FEW MODIFICATIONS FROM THE ORIGINAL FILE

# oracle test: run the parser guided by the oracle on the entire training set
# and count how many (head, label) pairs it reproduces exactly.

non_projective = 0
correct = 0
wrong = 0

deprels = create_deprel_dict(train_dataset)

for sample in train_dataset:
    tokens, gold_tree, gold_labels = create_gold(sample, deprels)

    if not orig.is_projective(gold_tree):
        non_projective += 1
        continue

    parser = ArcStandard(tokens, deprels)
    oracle = Oracle(parser, gold_tree, gold_labels)

    # transition precedence here: left-arc > right-arc > shift
    while not parser.is_tree_final():
        if oracle.is_left_arc_gold():
            parser.left_arc(oracle.get_gold_label())
        elif oracle.is_right_arc_gold():
            parser.right_arc(oracle.get_gold_label())
        elif oracle.is_shift_gold():
            parser.shift()
        else:
            # Without this branch the while loop would never terminate.
            raise RuntimeError(f"No gold transition for sentence {sample['idx']}")

    for j in range(len(gold_tree)):  # comparing heads and labels from parser and gold
        if gold_tree[j] == parser.arcs[j][0] and gold_labels[j] == parser.arcs[j][1]:
            correct += 1
        else:
            wrong += 1

print("non projective: ", non_projective)
print("correct: ", correct)
print("wrong: ", wrong)

non projective:  167
correct:  282533
wrong:  0


## Creating samples for the neural oracle

In [15]:
# Modified from the original
def process_sample(sample, emb_dictionary, deprels, get_gold_path = False):
    """Encode one sentence and, optionally, replay its gold derivation.

    Args:
        sample: one element of a HuggingFace UD Dataset (keys 'tokens', 'head', 'deprel').
        emb_dictionary: word -> vocabulary id; unknown words map to emb_dictionary["<unk>"].
        deprels: label name -> label id, as built by create_deprel_dict.
        get_gold_path: if True (training), run the static oracle and record the
            visited configurations and the gold moves.

    Returns:
        enc_sentence: vocabulary ids of the sanitized tokens (<ROOT> first).
            NOTE(review): these are vocabulary ids, not BERT word-piece ids; the
            training/inference code feeds str(enc_sentence) to the BERT tokenizer,
            confirm this is intended.
        gold_path: list of [sigma_2, sigma_1, beta_1] positions (beta_1 = -1 when
            the buffer is empty); empty unless get_gold_path.
        gold_moves: move class for each entry of gold_path:
            0 = shift, 1 + label = left-arc, 1 + len(deprels) + label = right-arc.
        gold_tree, gold_labels: as returned by create_gold.

    Raises:
        Exception: if no move is gold (e.g. the tree is non-projective).
    """
    tokens, gold_tree, gold_labels = create_gold(sample, deprels)
    enc_sentence = [emb_dictionary[word] if word in emb_dictionary else emb_dictionary["<unk>"] for word in tokens]

    # gold_path and gold_moves are parallel arrays whose elements refer to parsing steps
    gold_path = []   # two topmost stack positions and first buffer position for each step
    gold_moves = []  # oracle (canonical) move class for each step, encoded as in the docstring

    if get_gold_path:  # only for training
        parser = ArcStandard(enc_sentence, deprels)
        oracle = Oracle(parser, gold_tree, gold_labels)

        while not parser.is_tree_final():
            configuration = [parser.stack[len(parser.stack)-2], parser.stack[len(parser.stack)-1]]
            configuration.append(parser.buffer[0] if parser.buffer else -1)
            gold_path.append(configuration)

            if oracle.is_left_arc_gold():
                label_code = oracle.get_gold_label()
                parser.left_arc(label_code)
                gold_moves.append(1 + label_code)  # classes 1..len(deprels): left-arc
            elif oracle.is_right_arc_gold():
                label_code = oracle.get_gold_label()
                # The arc stores the plain label id (like left_arc); only the move class is offset.
                parser.right_arc(label_code)
                gold_moves.append(1 + len(deprels) + label_code)  # classes > len(deprels): right-arc
            elif oracle.is_shift_gold():
                parser.shift()
                gold_moves.append(0)
            else:
                print("**** AN ERROR OCCURRED ****")
                parser.print_configuration()
                raise Exception("No action identified as gold!! Please make sure your dataset doesn't contain nonprojective trees.")

    return enc_sentence, gold_path, gold_moves, gold_tree, gold_labels


In [16]:
### Code to test process_sample
# Kept disabled as a string literal: create_dict is only defined in a later cell,
# so executing this code at this point on a fresh kernel would raise NameError.
"""
test_sentence = train_dataset[88]
emb_dictionary = create_dict(train_dataset)
deprels = create_deprel_dict(train_dataset)

process_sample(test_sentence, emb_dictionary, deprels, get_gold_path = True)
"""

'\ntest_sentence = train_dataset[88]\nemb_dictionary = create_dict(train_dataset)\ndeprels = create_deprel_dict(train_dataset)\n\nprocess_sample(test_sentence, emb_dictionary, deprels, get_gold_path = True)\n'

In [17]:
"""
Function that prepares the data for the model.
@return sentences : list of sentences in the dataset encoded according to emb_dictionary
@return paths : list of lists of configurations visited during the oracle-guided parsing (on training data) | empty list if get_gold_path is False
@return moves : list of lists of moves as a number performed during the oracle-guided parsing (on training data) | empty if get_gold_path is False
                                N.B. 0 = shift
                                        1-45 = leftarc + label
                                        46-90 = rightarc + label
@return trees : ground truth tree
@return labels : ground truth lables
"""
#modified form orig
def prepare_batch(batch_data, emb_dictionary, deprels, get_gold_path=False):
    data = [process_sample(s, emb_dictionary, deprels, get_gold_path=get_gold_path) for s in batch_data]
    # sentences, paths, moves, trees are parallel arrays, each element refers to a sentence
    sentences = [s[0] for s in data]
    paths = [s[1] for s in data]
    moves = [s[2] for s in data]
    trees = [s[3] for s in data]
    labels = [s[4] for s in data]
    return sentences, paths, moves, trees, labels

In [18]:
### code to test prepare_batch
# Kept disabled as a string literal: create_dict is only defined in the next cell,
# so executing this code at this point on a fresh kernel would raise NameError.
"""
emb_dictionary = create_dict(train_dataset)
deprels = create_deprel_dict(train_dataset)

_ = prepare_batch(iter(train_dataset), emb_dictionary, deprels, True)
"""

'\nemb_dictionary = create_dict(train_dataset)\ndeprels = create_deprel_dict(train_dataset)\n\n_ = prepare_batch(iter(train_dataset), emb_dictionary, deprels, True)\n'

In [19]:
"""
Copied from original.
Modified casing of pad and unk.
"""
def create_dict(dataset, threshold=3):
  dic = {}  # dictionary of word counts
  for sample in dataset:
    for word in sample['tokens']:
      if word in dic:
        dic[word] += 1
      else:
        dic[word] = 1 

  map = {}  # dictionary of word/index pairs
  map["[PAD]"] = 0
  map["<ROOT>"] = 1
  map["<unk>"] = 2

  next_indx = 3
  for word in dic.keys():
    if dic[word] >= threshold:
      map[word] = next_indx
      next_indx += 1

  return map

In [20]:
"""
Return the indices of the element to mantain in the dataset. This procedure was
created because some sentences present some dependecies never seen in the training
and will not processed correctly.
In our case the worng sentences have the following idx element: 10_new-83
"""
def remove_sentences(dataset, idx):
    ret = []
    for i in range(len(dataset)):
        x = dataset[i]
        if x["idx"] != idx:
            ret.append(i)
    return ret

"""
Return indices of projective sentences in the dataset. Must be used with select() to filter out nonprojective trees.
"""
def find_projective_idx(dataset):
    ret = []
    for i in range(len(dataset)):
        x = dataset[i]
        if orig.is_projective([-1] + [int(head) for head in x["head"] if head!='None']):
            ret.append(i)
    return ret

"""
Return indices of sentences shorter than TOKENIZER_MAX_LEN
"""
def find_long_idx(dataset):
    ret = []
    for i in range(len(dataset)):
        x = dataset[i]
        if len(x["tokens"]) <= TOKENIZER_MAX_LEN: 
            ret.append(i)
    return ret

In [21]:
def _keep_projective_and_short(dataset):
    """Drop non-projective trees, then sentences rejected by find_long_idx (too long)."""
    dataset = dataset.select(find_projective_idx(dataset))
    return dataset.select(find_long_idx(dataset))

train_dataset = load_dataset('universal_dependencies', 'it_isdt', split="train")
dev_dataset = load_dataset('universal_dependencies', 'it_isdt', split="validation")
test_dataset = load_dataset('universal_dependencies', 'it_isdt', split="test")

# Test split only: first drop the sentence with labels that never occur in the training set.
test_dataset = test_dataset.select(remove_sentences(test_dataset, "10_new-83"))

train_dataset = _keep_projective_and_short(train_dataset)
dev_dataset = _keep_projective_and_short(dev_dataset)
test_dataset = _keep_projective_and_short(test_dataset)

Reusing dataset universal_dependencies (/root/.cache/huggingface/datasets/universal_dependencies/it_isdt/2.7.0/065e728dfe9a8371434a6e87132c2386a6eacab1a076d3a12aa417b994e6ef7d)
Reusing dataset universal_dependencies (/root/.cache/huggingface/datasets/universal_dependencies/it_isdt/2.7.0/065e728dfe9a8371434a6e87132c2386a6eacab1a076d3a12aa417b994e6ef7d)
Reusing dataset universal_dependencies (/root/.cache/huggingface/datasets/universal_dependencies/it_isdt/2.7.0/065e728dfe9a8371434a6e87132c2386a6eacab1a076d3a12aa417b994e6ef7d)


In [22]:
# Label dictionary and vocabulary (words seen at least 3 times), both from the filtered training split.
deprels = create_deprel_dict(train_dataset)
emb_dictionary = create_dict(train_dataset)

In [23]:
BATCH_SIZE = 32

# One collate function for all splits: it encodes each sentence and replays the
# gold oracle (get_gold_path=True) to obtain the configurations and gold moves.
collate = partial(prepare_batch, emb_dictionary=emb_dictionary, deprels=deprels, get_gold_path=True)

train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate)
dev_dataloader = torch.utils.data.DataLoader(dev_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate)

## Ideas for the future
* Use mean square error to predict pairs (MOVE, LABEL) in the training set
* We need to rescale the data into an acceptable range. The MOVE should probably have a larger spread, e.g. LeftArc=100, RightArc=0, Shift=-100
* BERT layer must process each sentence separately to produce word embeddings, but it should also be finetuned alongside the fully connected layer

# Neural Oracle

Neural oracle implemented by fine-tuning BERT together with a simple feed-forward classifier.
The references are: [Towards Data Science](https://towardsdatascience.com/text-classification-with-bert-in-pytorch-887965e5820f) and [Text Classification | Sentiment Analysis with BERT using huggingface, PyTorch and Python Tutorial](https://www.youtube.com/watch?v=8N-nM3QW7O0).

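The hyperparameters and training tweaks were meant to follow the original BERT paper. Note, however, that the paper recommends fine-tuning with learning rates of 2e-5 to 5e-5 for 2–4 epochs, whereas the configuration cell currently uses `LEARNING_RATE = 0.001` and `EPOCHS = 30`.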

In [24]:
# Word-piece tokenizer matching the Italian BERT checkpoint in BERT_MODEL.
tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path=BERT_MODEL)

Downloading:   0%|          | 0.00/230k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/59.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/433 [00:00<?, ?B/s]

In this section we load the BERT tokenizer and try it on a sample sentence, to see how the text is processed and which objects the tokenizer returns for use in BERT.

In [None]:
example_text = '''Con il termine interfluvio in geologia si indica la porzione di superﬁcie più elevata
                  che separa due valli ﬂuviali adiacenti, che può essere una cresta oppure un' area ampia,
                  comunque non coinvolta dal movimento delle acque'''

# Pad/truncate to 15 word-pieces so the printed tensors stay short.
bert_input = tokenizer(example_text, padding='max_length', max_length=15,
                       truncation=True, return_tensors="pt")

# input_ids:      word-piece ids
# token_type_ids: segment ids (0 = first sentence, before [SEP]; 1 = second sentence)
# attention_mask: 1 for real tokens, 0 for padding
for field in ('input_ids', 'token_type_ids', 'attention_mask'):
    print(bert_input[field])

decoded_text = tokenizer.decode(bert_input['input_ids'][0])
print(decoded_text)

tensor([[  102,   401,   162,  1909,   532, 27948,  2369,   139,  5101,  1783,
           223,  1731,   146, 16711,   103]])
tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])
tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])
[CLS] Con il termine interfluvio in geologia si indica la porzione [SEP]


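## Oracle model using BERT
For the oracle we fine-tune BERT together with a small classifier that predicts the next parser move.
The main idea is:

*   given a batch of sentences, get the **BERT input tokens** and the **attention mask** from the BERT tokenizer above;
*   pass the input tokens, the attention mask and the parser configurations to `BERTOracle`, which returns one score vector per configuration (the configurations of all sentences in the batch are concatenated).

A configuration is represented by the BERT vectors at the positions of $\sigma_2$, $\sigma_1$ and $\beta_1$ (a zero vector stands in for an empty position). The oracle classifies it into one of $2|L|+1$ moves, where $|L|$ is `len(deprels)`, so the model is trained with the cross-entropy loss. (An earlier design treated the pair (ACTION, LABEL) as a **multi-output regression** target trained with MSE loss.)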

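Each move is encoded as a single class index:

* 0 for **Shift**;
* $1 + \ell$ for **LeftArc** with label id $\ell$;
* $1 + |L| + \ell$ for **RightArc** with label id $\ell$.

(In the earlier regression design the ACTION was encoded as 100 for **LeftArc**, 0 for **RightArc** and -100 for **Shift**.)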

For the LABELS we sort them alphabetically and convert them into consecutive integers (see `create_deprel_dict`).
The idea behind this choice, which comes from the regression design, is the following:

the labels represent relations between words.<br>
The treebank contains closely related relations that end up adjacent in alphabetical order, such as *nsubj* and *nsubj:pass*, so they receive consecutive ids (for instance 30 and 31). With a regression loss, predicting 31 instead of 30 is then penalised much less than predicting 50.<br>
The underlying assumption is that closer numbers represent more similar relations. With the cross-entropy classifier the numeric distance between ids no longer matters; sorting simply keeps the label-to-id mapping deterministic.

To perform the actual classification we use a small feed-forward network (dropout, linear layer, tanh, dropout, linear layer). It outputs one score for each of the $2|L|+1$ moves.

In [25]:
class BERTOracle(nn.Module):
    """Neural oracle: BERT encoder followed by a two-layer MLP move classifier.

    A parser configuration is a triple of positions (sigma_2, sigma_1, beta_1),
    where -1 marks an empty slot. The BERT vectors at those positions are
    concatenated and scored over 2*len(deprels)+1 move classes, using the same
    encoding as process_sample:
        0                          -> shift
        1 + label                  -> left-arc with that label id
        1 + len(deprels) + label   -> right-arc with that label id
    forward() and ffn_pass() return raw logits, as required by nn.CrossEntropyLoss.

    NOTE(review): the sentences given to BERT are str(list of vocabulary ids), and
    configuration positions are word indices used directly as BERT sequence
    positions; these do not line up with the words - TODO confirm this is intended.
    """

    def __init__(self, input_size, deprels):
        super(BERTOracle, self).__init__()
        self.bert = BertModel.from_pretrained(BERT_MODEL)
        self.dropout = nn.Dropout(DROPOUT_RATE)
        self.w1 = nn.Linear(input_size, MLP_SIZE)
        self.activation = torch.nn.Tanh()
        self.w2 = nn.Linear(MLP_SIZE, 2 * len(deprels) + 1)  # one logit per move class
        # Available for callers that want probabilities; deliberately NOT applied in
        # ffn_pass, because nn.CrossEntropyLoss already applies log-softmax.
        self.softmax = torch.nn.Softmax(dim=-1)
        # Label dictionary the output layer was sized for; used by infere().
        self.deprels = deprels

    def substitute_embeddings(self, word_embeddings, paths):
        """Replace the positions in each configuration with their BERT vectors.

        word_embeddings: BERT last hidden state, indexed [sentence, position].
        paths: one list of [s2, s1, b1] position triples per sentence.
        Returns one concatenated vector per configuration (sentences in order).
        """
        # Stand-in for an empty slot (-1); same dtype/device as the embeddings.
        zero_vector = word_embeddings.new_zeros(word_embeddings.size(-1))
        ffn_input = []
        for sentence_index, path in enumerate(paths):
            sentence_vectors = word_embeddings[sentence_index]
            for configuration in path:
                ffn_input.append(torch.cat([
                    zero_vector if position == -1 else sentence_vectors[position]
                    for position in configuration[:3]
                ]))
        return torch.stack(ffn_input).to(DEVICE)

    @staticmethod
    def pad_paths(arr):
        """Pad every path to the length of the longest one with [-1, -1, -1] (currently unused)."""
        pad_token = [-1, -1, -1]
        return list(zip(*itertools.zip_longest(*arr, fillvalue=pad_token)))

    def ffn_pass(self, configuration):
        """Score configuration vectors; returns raw logits (one row per configuration)."""
        out = self.dropout(configuration)
        out = self.w1(out)
        out = self.activation(out)
        out = self.dropout(out)
        return self.w2(out)

    def forward(self, input_tokens, attention_mask, paths):
        """Return move logits for every configuration in `paths`, flattened across sentences."""
        word_embeddings = self.bert(input_tokens, attention_mask).last_hidden_state
        configurations = self.substitute_embeddings(word_embeddings, paths)
        return self.ffn_pass(configurations)

    def infere(self, sentences):
        """Greedily parse `sentences` with the oracle; return each parser's arcs ((head, label) pairs)."""
        labels = self.deprels
        parsers = [ArcStandard(sentence, labels) for sentence in sentences]

        list_sentences = [str(x) for x in sentences]  # same input format as in training
        preprocessed_text = tokenizer(list_sentences, padding='max_length', max_length=TOKENIZER_MAX_LEN,
                                      truncation=True, return_tensors="pt")
        input_tokens = preprocessed_text["input_ids"].to(DEVICE)
        attention_mask = preprocessed_text["attention_mask"].to(DEVICE)

        h = self.bert(input_tokens, attention_mask).last_hidden_state

        while not self.parsed_all(parsers):
            configurations = self.get_configurations(parsers)
            mlp_out = self.ffn_pass(self.substitute_embeddings(h, configurations))
            self.parse_step(parsers, mlp_out, labels)

        return [parser.arcs for parser in parsers]

    def get_configurations(self, parsers):
        """Return one single-configuration path per parser ([-1, -1, -1] for finished parsers).

        Adapted from the original file.
        """
        configurations = []
        for parser in parsers:
            if parser.is_tree_final():
                conf = [-1, -1, -1]
            else:
                conf = [parser.stack[len(parser.stack) - 2], parser.stack[len(parser.stack) - 1]]
                conf.append(parser.buffer[0] if parser.buffer else -1)
            configurations.append([conf])
        return configurations

    def parsed_all(self, parsers):
        """True when every parser has reached a final configuration (from the original file)."""
        return all(parser.is_tree_final() for parser in parsers)

    @staticmethod
    def _decode_move(move, n_labels):
        """Map a move class to (kind, label): kind 0 = shift, 1 = left-arc, 2 = right-arc."""
        if move == 0:
            return 0, None
        if move <= n_labels:
            return 1, move - 1
        if move <= 2 * n_labels:
            return 2, move - 1 - n_labels
        raise ValueError(f"move class {move} not in range [0, {2 * n_labels}]")

    def parse_step(self, parsers, moves, labels):
        """Apply the best legal move to every unfinished parser.

        moves: logits/scores of shape (len(parsers), 2*len(labels)+1).
        Illegal moves are masked out before the argmax:
          * shift when the buffer is empty;
          * left-arc when sigma_2 is <ROOT> (position 0), which must never get a head.
        """
        n_labels = len(labels)
        scores = moves.detach().clone()
        for i, parser in enumerate(parsers):
            if parser.is_tree_final():
                continue

            row = scores[i]
            if not parser.buffer:
                row[0] = float("-inf")
            if len(parser.stack) >= 2 and parser.stack[-2] == 0:
                row[1:n_labels + 1] = float("-inf")

            kind, label = self._decode_move(int(row.argmax()), n_labels)
            if kind == 0:
                parser.shift()
            elif kind == 1:
                parser.left_arc(label)
            else:
                parser.right_arc(label)

In [26]:
#Initialize the BERTOracle model
# Build the oracle (downloads the BERT weights on first use), move it to DEVICE and show it.
model = BERTOracle(INPUT_SIZE, deprels).to(DEVICE)
model

Downloading:   0%|          | 0.00/425M [00:00<?, ?B/s]

Some weights of the model checkpoint at dbmdz/bert-base-italian-xxl-cased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.decoder.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


BERTOracle(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(32102, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
 

## Training and evaluation functions

In [30]:
# Loss: cross entropy over the move classes (it expects raw, unnormalised logits).
loss_fn = nn.CrossEntropyLoss().to(DEVICE)

# Adam over all parameters, BERT included, i.e. BERT is fine-tuned together with the MLP.
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

## Train and evaluation function for BERTOracle model

In [34]:
def train_epoch(model, data_loader, loss_fn, optimizer):
  """Run one optimisation pass over `data_loader` and return the mean batch loss.

  Each batch is the 5-tuple produced by prepare_batch; the loss is computed on
  all configurations of all sentences of the batch at once.
  NOTE(review): BERT receives str(sentence), i.e. the string form of a list of
  vocabulary ids - confirm this is intended.
  """
  model = model.train()
  batch_losses = []
  for sentences, paths, moves, _trees, _labels in data_loader:
    optimizer.zero_grad()

    encoded = tokenizer([str(sentence) for sentence in sentences], padding='max_length',
                        max_length=TOKENIZER_MAX_LEN, truncation=True, return_tensors="pt")
    input_tokens = encoded["input_ids"].to(DEVICE)
    attention_mask = encoded["attention_mask"].to(DEVICE)

    predictions = model(input_tokens, attention_mask, paths)

    # Flatten the per-sentence move lists so they line up with the flattened configurations.
    targets = torch.tensor([move for sentence_moves in moves for move in sentence_moves]).to(DEVICE)
    loss = loss_fn(predictions, targets)
    batch_losses.append(loss.item())

    loss.backward()
    optimizer.step()

  return np.mean(batch_losses)

In [35]:
def eval_model(model, data_loader, loss_fn, optimzier):
  """Return the mean batch loss of `model` on `data_loader`, without updating weights.

  `optimzier` is unused; it is kept only so existing positional calls still work.
  (The previous version called zero_grad() on the *global* `optimizer`, which has
  no purpose under torch.no_grad().)
  """
  model = model.eval()                                                # disables dropout

  losses = []

  with torch.no_grad():
    for batch in data_loader:
      sentences, paths, moves, trees, labels = batch
      list_sentences = [str(x) for x in sentences]                    # same input format as training

      preprocessed_text = tokenizer(list_sentences, padding='max_length', max_length = TOKENIZER_MAX_LEN,
                        truncation=True, return_tensors="pt")         # tokenize input using BERT tokenizer

      input_tokens = preprocessed_text["input_ids"].to(DEVICE)
      attention_mask = preprocessed_text["attention_mask"].to(DEVICE)

      outputs = model(input_tokens, attention_mask, paths)

      # Flatten the per-sentence move lists to match the flattened configurations.
      tensor_moves = torch.tensor([m for sentence_moves in moves for m in sentence_moves]).to(DEVICE)
      loss = loss_fn(outputs, tensor_moves)
      losses.append(loss.item())

  mean_loss = np.mean(losses)

  return mean_loss

## Training loop
Training loop: for the given number of epochs, it performs a training step over `train_dataloader` and an evaluation step over `dev_dataloader`.
The training and validation losses of every epoch are kept in the `history` dictionary.

In [None]:
import gc
gc.collect()

history = defaultdict(list) # store training and validation losses in a dictionary
epoch_to_save = 5           # save the model every 5 epochs and after the last one

for epoch in range(EPOCHS):
  print(f'Epoch {epoch + 1} / {EPOCHS}')

  # Training step: train on the TRAINING split (the test split is reserved for the final evaluation)
  train_loss = train_epoch(model, train_dataloader, loss_fn, optimizer)

  print(f"Train loss {train_loss}")

  # Evaluation step
  val_loss = eval_model(model, dev_dataloader, loss_fn, optimizer)

  print(f"Val loss {val_loss}")
  print('-'*10)

  history['train_loss'].append(train_loss)
  history['val_loss'].append(val_loss)

  # Save after epochs 5, 10, ... and always after the final epoch, so no training is lost.
  if (epoch + 1) % epoch_to_save == 0 or epoch + 1 == EPOCHS:
    torch.save(model, SAVE_PATH)
    print("|MODEL SAVED|")

## Evaluation

In [None]:
# Load the saved model. torch.save(model, ...) pickled the whole module, so loading
# unpickles arbitrary code: only load checkpoints you produced yourself.
# map_location lets a checkpoint saved on GPU be restored on a CPU-only machine.
# NOTE(review): PyTorch >= 2.6 defaults to weights_only=True; pass weights_only=False there.
model = torch.load(SAVE_PATH, map_location=DEVICE)
model.to(DEVICE)

In [None]:
"""
Functions were adapted from the ones in the original notebook
preds is a list of *couples*

evaluate() returns a couple of values (UAS, LAS)
"""
def evaluate(gold, preds): 
  total = 0
  correct_unlabeled = 0
  correct_labeled = 0

  for g, p in zip(gold, preds):
    for i in range(1,len(g)):
      total += 1
      if g[i][0] == p[i][0]:
        correct_unlabeled += 1
        if g[i][1] == p[i][1]:
          correct_labeled += 1
  return correct_unlabeled/total, correct_labeled/total

def validation(model, dataloader):
  model.eval()
  gold = []
  preds = []

  for batch in dataloader:
    sentences, paths, moves, trees, labels = batch
    with torch.no_grad():
      pred = model.infere(sentences) #now pred is made of (move, label) pairs
      
      gold += trees
      preds += pred
          
  return evaluate(gold, preds)


In [None]:
uas, las = validation(model, test_dataloader)
print(f"UAS: {uas:.4f}  LAS: {las:.4f}")