# Dataset analysis



In [1]:
!pip install datasets  # huggingface library with dataset
!pip install conllu    # aux library for processing CoNLL-U format

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting datasets
  Downloading datasets-2.12.0-py3-none-any.whl (474 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m474.6/474.6 kB[0m [31m21.1 MB/s[0m eta [36m0:00:00[0m
Collecting dill<0.3.7,>=0.3.0 (from datasets)
  Downloading dill-0.3.6-py3-none-any.whl (110 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m110.5/110.5 kB[0m [31m14.5 MB/s[0m eta [36m0:00:00[0m
Collecting xxhash (from datasets)
  Downloading xxhash-3.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (212 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m212.5/212.5 kB[0m [31m16.2 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting multiprocess (from datasets)
  Downloading multiprocess-0.70.14-py310-none-any.whl (134 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.3/134.3 kB[0m [31m13.9 MB/s[0m eta [36m0:00:00[0m
Collec

In [2]:
import time
import random
import torch
import torch.nn as nn
import numpy as np
from functools import partial
from datasets import load_dataset

# Description of baseline and BERT model

## Arc-eager

A configuration of the arc-eager parser is a triple of the form $(σ, β, A)$ where:

* $\sigma$ is the stack
* $\beta$ is the buffer
* $A$ is the set of arcs constructed so far

Let:

* $\beta_i$, $i\geq1$ the $i$-th token in the buffer
* $\sigma_i$, $i\geq1$ the $i$-th token in the stack

for the $i$-th configuration

The arc-eager parser can perform four types of actions (transitions):

* **left-arc** (LA): create the arc $(\beta_1 → \sigma_1)$ and remove $\sigma_1$ from the stack. The **preconditions** are: $\sigma_1$ is not the ROOT and $\sigma_1$ does not have already an head

* **right-arc** (RA): create the arc $(\sigma_1 → \beta_1)$ and push $\beta_1$ to the stack

* **reduce** (RE): remove $\sigma_1$ from the stack. The **precondition** is: $\sigma_1$ must have a head

* **shift** (SH): remove $\beta_1$ from the buffer and push it to the stack


Let $w = w_0 w_1 \cdots w_{n}$ be the input sentence, with $w_0$ the special symbol `<ROOT>`.
Stack and buffer are implemented as lists of integers, where `j` represents word $w_j$.  Top-most stack token is at the right-end of the list; first buffer token is at the left-end of the list. 
Set $A$ is implemented as an array `arcs` of size $n+1$ such that if arc $(w_i \rightarrow w_j)$ is in $A$ then `arcs[j]=i`, and if $w_j$ is still missing its head node in the tree under construction, then `arcs[j]=-1`. We always have `arcs[0]=-1`.  We use this representation also for complete dependency trees.

In [3]:
# NOTE: here only the transition methods are implemented
# the preconditions will be checked later with the boolean methods

class ArcEager:
  def __init__(self, sentence):
    self.sentence = sentence
    self.buffer = [i for i in range(len(self.sentence))]  # buffer
    self.stack = []   # empty stack
    self.arcs = [-1 for _ in range(len(self.sentence))]  # non-valid arcs

    # initialization of the stack with one SH operations
    # only the <ROOT> in the stack
    self.shift()

  # left arc
  def left_arc(self):
    b1 = self.buffer[0]
    o1 = self.stack.pop()
    self.arcs[o1] = b1

  # right arc
  def right_arc(self):
    b1 = self.buffer[0]
    o1 = self.stack[-1]
    self.arcs[b1] = o1
    self.buffer = self.buffer[1:]
    self.stack.append(b1)
    
  # reduce
  def reduce(self):
    _ = self.stack.pop()

  # shift
  def shift(self):
    b1 = self.buffer[0]
    self.buffer = self.buffer[1:]
    self.stack.append(b1)

  # print configuration (for debug)
  def print_configuration(self):
    s = [self.sentence[i] for i in self.stack]
    b = [self.sentence[i] for i in self.buffer]
    print(s, b)
    print(self.arcs) 

  def is_tree_final(self):
    return len(self.buffer) == 0
    

## Oracle

The transition actions must follow the preconditions above

In [4]:
class Oracle:
  def __init__(self, parser, gold_tree):
    self.parser = parser
    self.gold = gold_tree

  # left arc?
  def is_left_arc_gold(self):
    # get the two elems the arc should be built on
    b1 = self.parser.buffer[0]
    o1 = self.parser.stack[len(self.parser.stack)-1]

    # check preconditions
    if o1 == -1: return False
    if self.parser.arcs[o1] != -1: return False

    # check if gold move
    if self.gold[o1] != b1:
      return False
    else:
      return True

  # right arc?
  def is_right_arc_gold(self):
    b1 = self.parser.buffer[0]
    o1 = self.parser.stack[len(self.parser.stack)-1]

    # no preconditions?
    # check if gold move
    if self.gold[b1] != o1:
      return False
    else:
      return True

  # reduce?
  def is_reduce_gold(self):

    # precondition
    o1 = self.parser.stack[len(self.parser.stack)-1]
    if self.parser.arcs[o1] == -1:
      return False
    
    # check if exist k < o1 s.t. exist (k,b1) or (b1,k) in Agold
    b1 = self.parser.buffer[0]
    if self.gold[b1] < o1:
      return True
    else:
      return b1 in self.gold[:o1]
    

  


## Test on a sentence

In [5]:
# define the sentence and the gold tree
sentence = ["<ROOT>", "He", "wrote", "her", "a", "letter", "."]
gold = [-1, 2, 0, 2, 5, 2, 2]

# initialize parser and oracle
parser = ArcEager(sentence)
oracle = Oracle(parser, gold)

# print initial configuration
parser.print_configuration()

['<ROOT>'] ['He', 'wrote', 'her', 'a', 'letter', '.']
[-1, -1, -1, -1, -1, -1, -1]


In [6]:
# until the tree is final, apply the right move

iteration = 0   # keep track of the number of iterations
while not parser.is_tree_final():
  print('Iteration:',iteration)

  # if LA is gold
  if oracle.is_left_arc_gold():
    parser.left_arc()
    parser.print_configuration()
    print('Transition: LA', end='\n\n')

  # elif RA is gold
  elif oracle.is_right_arc_gold():
    parser.right_arc()
    parser.print_configuration()
    print('Transition: RA', end='\n\n')

  # elif RE is gold
  elif oracle.is_reduce_gold():
    parser.reduce()
    parser.print_configuration()
    print('Transition: RE', end='\n\n')

  # else shift
  else:
    parser.shift()
    parser.print_configuration()
    print('Transition: SH', end='\n\n')


  iteration = iteration + 1

# parsing tree completed

Iteration: 0
['<ROOT>', 'He'] ['wrote', 'her', 'a', 'letter', '.']
[-1, -1, -1, -1, -1, -1, -1]
Transition: SH

Iteration: 1
['<ROOT>'] ['wrote', 'her', 'a', 'letter', '.']
[-1, 2, -1, -1, -1, -1, -1]
Transition: LA

Iteration: 2
['<ROOT>', 'wrote'] ['her', 'a', 'letter', '.']
[-1, 2, 0, -1, -1, -1, -1]
Transition: RA

Iteration: 3
['<ROOT>', 'wrote', 'her'] ['a', 'letter', '.']
[-1, 2, 0, 2, -1, -1, -1]
Transition: RA

Iteration: 4
['<ROOT>', 'wrote', 'her', 'a'] ['letter', '.']
[-1, 2, 0, 2, -1, -1, -1]
Transition: SH

Iteration: 5
['<ROOT>', 'wrote', 'her'] ['letter', '.']
[-1, 2, 0, 2, 5, -1, -1]
Transition: LA

Iteration: 6
['<ROOT>', 'wrote'] ['letter', '.']
[-1, 2, 0, 2, 5, -1, -1]
Transition: RE

Iteration: 7
['<ROOT>', 'wrote', 'letter'] ['.']
[-1, 2, 0, 2, 5, 2, -1]
Transition: RA

Iteration: 8
['<ROOT>', 'wrote'] ['.']
[-1, 2, 0, 2, 5, 2, -1]
Transition: RE

Iteration: 9
['<ROOT>', 'wrote', '.'] []
[-1, 2, 0, 2, 5, 2, 2]
Transition: RA



In [7]:
# check the arcs obtained are the same as in the gold tree
print('Is tree correct?', oracle.gold == parser.arcs)

Is tree correct? True


# Data set-up and training

## Dataset

## Preprocessing

### Utils for parsing

In [8]:
# function to check if the tree is projective or not
def is_projective(tree):
  for i in range(len(tree)):
    if tree[i] == -1:
      continue
    left = min(i, tree[i])
    right = max(i, tree[i])

    for j in range(0, left):
      if tree[j] > left and tree[j] < right:
        return False
    for j in range(left+1, right):
      if tree[j] < left or tree[j] > right:
        return False
    for j in range(right+1, len(tree)):
      if tree[j] > left and tree[j] < right:
        return False

  return True

def parse_sentence(oracle, show_conf=False):
    if show_conf:
        oracle.parser.print_configuration()
    iteration = 0   # keep track of the number of iterations
    while not oracle.parser.is_tree_final():
        if show_conf:
            print('Iteration:',iteration)

        # if LA is gold
        if oracle.is_left_arc_gold():
            oracle.parser.left_arc()
            if show_conf:
                oracle.parser.print_configuration()
                print('Transition: LA', end='\n\n')

        # elif RA is gold
        elif oracle.is_right_arc_gold():
            oracle.parser.right_arc()
            if show_conf:
                oracle.parser.print_configuration()
                print('Transition: RA', end='\n\n')

        # elif RE is gold
        elif oracle.is_reduce_gold():
            oracle.parser.reduce()
            if show_conf:
                oracle.parser.print_configuration()
                print('Transition: RE', end='\n\n')

        # else shift
        else:
            oracle.parser.shift()
            if show_conf:
                oracle.parser.print_configuration()
                print('Transition: SH', end='\n\n')


        iteration = iteration + 1
    return oracle.parser.arcs

# get a random sentence to test the parsing algorith to be correct
def parse_dataset(dataset):
    for sample in dataset:
        sentence = ['<ROOT>'] + sample['tokens']
        gold_tree = sample['head']
        gold_tree = [-1] + [int(key) for key in gold_tree]
        # initialize parser and oracle
        parser = ArcEager(sentence)
        oracle = Oracle(parser, gold_tree)
        parse_tree = parse_sentence(oracle)
        if parse_tree != gold_tree:
            return False
    return True

### Test the oracle and the parsing functions on the training set

In [9]:
# get english sentences from the dataset
train_dataset = load_dataset('universal_dependencies', 'en_lines', split="train")
# remove non-projective sentences: heads in the gold tree are strings, we convert them to int
train_dataset = [sample for sample in train_dataset if is_projective([-1] + [int(head) for head in sample["head"]])]

# parse the whole dataset $train_dataset using the oracle to check if it works correctly
print(parse_dataset(train_dataset))

Downloading builder script:   0%|          | 0.00/87.8k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/2.33M [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/191k [00:00<?, ?B/s]

Downloading and preparing dataset universal_dependencies/en_lines to /root/.cache/huggingface/datasets/universal_dependencies/en_lines/2.7.0/1ac001f0e8a0021f19388e810c94599f3ac13cc45d6b5b8c69f7847b2188bdf7...


Downloading data files:   0%|          | 0/3 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/580k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/199k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/181k [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/3 [00:00<?, ?it/s]

Generating train split:   0%|          | 0/3176 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/1032 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1035 [00:00<?, ? examples/s]

Dataset universal_dependencies downloaded and prepared to /root/.cache/huggingface/datasets/universal_dependencies/en_lines/2.7.0/1ac001f0e8a0021f19388e810c94599f3ac13cc45d6b5b8c69f7847b2188bdf7. Subsequent calls will reuse this data.
True


### Pre-processing utils
- the $create\_token\_indices$ function creates a vocabulary $vocab$ containing all tokens in $dataset$ (given as input argument) appearing at least $threshold$ (given as input argument, default value 3) times. $vocab[token]$ is the index assigned to $token;

- the $process\_sample$ function is used to process our data and create the actual training samples.
For each sentence in $train\_dataset$, we use our oracle to compute the canonical path followed to extract the gold tree. We then pair each configuration to the golden transition selected by the oracle.
Because of the structure of Arc-Eager parser, we encode a $configuration$ with only two words: $\sigma_1$ and $\beta_1$ (i.e. the topmost element on the stack and the first buffer element, respectiveley);

- $prepare\_batch$ function pre-processes a batch of samples $batch\_data$ using as indices of tokens the ones contained in the vocabulary $tokens\_indices\_voc$. The pre-processing is done by applying function $process\_sample$ to each sample in $batch\_data$.

In [10]:
# the function returns a dictionary containing the vocabulary embedding indices for the tokens in $dataset
# i.e. an index is associated to each token in $dataset
# $threshold is the minimum number of appearance for a token in $dataset to be included in the dictionary
def create_token_indices(dataset, threshold=3):
  # $dic has the tokens as keys. Given a token, $dic[token] is the number of occurrences of $token in $dataset
  dic = {}
  for sample in dataset:
    for token in sample['tokens']:
      if token in dic:
        dic[token] += 1
      else:
        dic[token] = 1 

  # vocab["token"] is an integer representing the index of token "token"
  vocab = {}
  # indices for some special tokens
  vocab["<pad>"] = 0
  vocab["<ROOT>"] = 1 
  vocab["<unk>"] = 2

  next_ind = 3
  for token in dic.keys():
    if dic[token] >= threshold:
      vocab[token] = next_ind
      next_ind += 1

  return vocab

# creates training instances from one sample if $get_gold_path is $True, otherwise does only a simple pre-process
# $sample is a sample of our dataset
def process_sample(tokens_indices_voc, sample, get_gold_path = False):

  # add the root token to the sentence and its head (-1) to the gold head list
  sentence = ["<ROOT>"] + sample["tokens"]
  gold = [-1] + [int(i) for i in sample["head"]]
  
  # sentence representation with each token represented by its index in the vocabulary
  sentence_repr = [tokens_indices_voc[token] if token in tokens_indices_voc else tokens_indices_voc["<unk>"] for token in sentence]

  # $gold_configurations and $gold_transitions are parallel arrays whose elements refer to parsing steps
  # $gold_configurations[i] records configuration at step $i, i.e. topmost stack token and first buffer token for current step
  gold_configurations = []
  # $gold_transitions[i] contains oracle (canonical) transition for step $i: 0 is left_arc, 1 right_arc, 2 reduce, 3 shift
  gold_transitions = []

  # only for training
  if get_gold_path:
    parser = ArcEager(sentence)
    oracle = Oracle(parser, gold)

    while not parser.is_tree_final():
      
      # save configuration
      configuration = [parser.stack[-1]]
      if len(parser.buffer) == 0:
        configuration.append(-1)
      else:
        configuration.append(parser.buffer[0])  
      gold_configurations.append(configuration)

      # save gold transition
      if oracle.is_left_arc_gold():
        gold_transitions.append(0)
        oracle.parser.left_arc()
      elif oracle.is_right_arc_gold():
        gold_transitions.append(1)
        oracle.parser.right_arc()
      elif oracle.is_reduce_gold():
        gold_transitions.append(2)
        oracle.parser.reduce()
      else:
        gold_transitions.append(3)
        oracle.parser.shift()

  # $sentence_repr is a list containing representations of tokens in $sample
  # $gold is a list containning the representation of the gold tree of $sample
  # $gold_configurations is a list containing gold configurations
  # $gold_transitions is a list containing gold transitions
  return sentence_repr, gold_configurations, gold_transitions, gold

# this function is used to pre-process a batch of samples $batch_data from the original dataset
# applying function $process_sample to each sample in $batch_data
def prepare_batch(tokens_indices_voc, batch_data, get_gold_path=False):
  processed_batch_data = [process_sample(tokens_indices_voc, sample, get_gold_path=get_gold_path) for sample in batch_data]
  
  sentences_repr = [] 
  paths = [] 
  moves = [] 
  gold_trees = []
  for sample in processed_batch_data:
    sentences_repr.append(sample[0])
    paths.append(sample[1])
    moves.append(sample[2])
    gold_trees.append(sample[3])
    
    # sentences_repr, paths, moves, gold_trees are parallel lists
    # element in position $i of each of the above lists refers to the same sentence $i 
  return sentences_repr, paths, moves, gold_trees

### Datasets loading and pre-processing

In [11]:
# the training set has already been loaded, then load also development set and test set
dev_dataset = load_dataset('universal_dependencies', 'en_lines', split="validation")
test_dataset = load_dataset('universal_dependencies', 'en_lines', split="test")

# set the number of samples per batch
BATCH_SIZE = 32

# create the dictionary with token indices that is the embedding dictionary
tokens_indices = create_token_indices(train_dataset)

# create dataloaders to batch each dataset and apply function $prepare_batch to each batch
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=partial(prepare_batch, tokens_indices, get_gold_path=True))
dev_dataloader = torch.utils.data.DataLoader(dev_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=partial(prepare_batch, tokens_indices))
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=partial(prepare_batch))



#### Now everything is ready to create a BiLSTM and extract embeddings from configurations in order to then train a classifier based on gold trees

# Baseline model

In this section, we will create the baseline Bi-LSTM model to predict the next move of the acr-eager parser

Definition of some parameters (da testare diversi valori)

In [12]:
EMBEDDING_SIZE = 200
LSTM_SIZE = 200
LSTM_LAYERS = 1
MLP_SIZE = 200        # multi perceptron layer size
DROPOUT = 0.2
EPOCHS = 15
LR = 0.001            # learning rate

Now we create the model

Here some comments about **how to select** the **next transition** given the **predicted scores** from the model:

* $\color{red}{\text{If the best move is LA}}$: we need to check if it is a feasible move, that is:
  * $\sigma_1$ does not have already a head (*a token cannot have two heads*)
  * $\sigma_1$ is not the ROOT (*the ROOT cannot be a depentent*)

  If this is not the case all $\color{yellow}{\text{RA}}$, $\color{yellow}{\text{RE}}$, $\color{yellow}{\text{SH}}$ are theoretically feasible moves. So we need to check the second-best predicted move:
  
  * $\color{yellow}{\text{If the second-best move is RA}}$: we need to check if it is a feasible move:
    * if $\sigma_1$ is not the ROOT then RA is ok (*we'll never have $\beta_1$ with a head already, because when we do RA we push to the stack*)
    * if $\sigma_1$ is the ROOT then:
      
      In this case RE not feasible (*we cannot remove the ROOT*)we can do RA only if buffer has length = 1, otherwise we need SH. This is because if buffer has length > 1, we create an arc $(ROOT → \beta_i)$ but then for sure we'll need to do another arc $(ROOT → \beta_j)$ to reach the final configuration. If this happens, the ROOT will have 2 dependents, that is not feasible. Therefore:
      * if len(buffer) > 1:
        * then SH
      
      * Otherwise RA is ok
  * $\color{yellow}{\text{If the second-best move is RE}}$: we need to check if it is a feasible move:
    * if $\sigma_1$ is not the ROOT, then RE is ok
    * otherwise at this point we can do theoretically both $\color{cyan}{\text{RA}}$ or $\color{cyan}{\text{SH}}$, so we need to check the third-best predicted move. In this case we have $\sigma_1$ the ROOT:

      * $\color{cyan}{\text{If the third-best move is RA}}$: For the same reason as before, we can do RA if len(buffer) == 1:
        * If len(buffer) == 1:
          * RA is ok

        * otherwise SH

      * $\color{cyan}{\text{If the third-best move is SH}}$: SH is always ok
  * $\color{yellow}{\text{If the second-best move is SH}}$: SH is always ok


NOTE: At this point the considerations already done above will be skipped to simplify the notation

* $\color{red}{\text{If the best move is RA}}$:
  * if $\sigma_1$ not the ROOT then: RA is ok
  * if $\sigma_1$ is the ROOT then:
    * if len(buffer) == 1 then: RA is ok
    * otherwise we have the following situation: [ROOT][$\beta_1$,$\beta_2$,$\beta_3$,...]. We cannot do LA (*the ROOT cannot be a dependent*), neither RE (*we cannot remove the ROOT*). The only other option is SH

* $\color{red}{\text{If the best move is RE}}$:
  * if if $\sigma_1$ not the ROOT then RE is ok
  * if $\sigma_1$ is the ROOT then we can do theoretically both $\color{yellow}{\text{SH}}$ or $\color{yellow}{\text{RA}}$, while LA is not possible (*otherwise the ROOT becomes a dependent*) 

  * $\color{yellow}{\text{If the second-best move is RA}}$:
    * if len(buffer) == 1 then: RA is ok
    * otherwise SH is ok  

  * $\color{yellow}{\text{If the second-best move is SH}}$: SH is always ok


* $\color{red}{\text{If the best move is SH}}$: SH is always ok

In [13]:
class Net(nn.Module):

  def __init__(self, device):
    super(Net, self).__init__()
    self.device = device
    self.embeddings = nn.Embedding(len(tokens_indices), EMBEDDING_SIZE, padding_idx=tokens_indices["<pad>"])
    
    # initialize bi-LSTM
    self.lstm = nn.LSTM(EMBEDDING_SIZE, LSTM_SIZE, num_layers = LSTM_LAYERS, bidirectional=True, dropout=DROPOUT)

    # initialize feedforward
    self.w1 = torch.nn.Linear(4*LSTM_SIZE, MLP_SIZE, bias=True)
    self.activation = torch.nn.Tanh()
    self.w2 = torch.nn.Linear(MLP_SIZE, 4, bias=True)
    self.softmax = torch.nn.Softmax(dim=-1)

    self.dropout = torch.nn.Dropout(DROPOUT)
  
  
  def forward(self, x, paths):
    # get the embeddings 
    x = [self.dropout(self.embeddings(torch.tensor(i).to(self.device))) for i in x]

    # run the bi-lstm
    h = self.lstm_pass(x)

    # for each parser configuration that we need to score we arrange from the
    # output of the bi-lstm the correct input for the feedforward
    mlp_input = self.get_mlp_input(paths, h)

    # run the feedforward and get the scores for each possible action
    out = self.mlp(mlp_input)

    return out

  def lstm_pass(self, x):
    x = torch.nn.utils.rnn.pack_sequence(x, enforce_sorted=False)
    h, (h_0, c_0) = self.lstm(x)
    h, h_sizes = torch.nn.utils.rnn.pad_packed_sequence(h) # size h: (length_sentences, batch, output_hidden_units)
    return h

  def get_mlp_input(self, configurations, h):
    mlp_input = []
    zero_tensor = torch.zeros(2*LSTM_SIZE, requires_grad=False).to(self.device)
    for i in range(len(configurations)): # for every sentence in the batch
      for j in configurations[i]: # for each configuration of a sentence 
        mlp_input.append(torch.cat([zero_tensor if j[0]==-1 else h[j[0]][i],
                                    zero_tensor if j[1]==-1 else h[j[1]][i]]))
    mlp_input = torch.stack(mlp_input).to(self.device)
    return mlp_input

  def mlp(self, x):
    return self.softmax(self.w2(self.dropout(self.activation(self.w1(self.dropout(x))))))

  # we use this function at inference time. We run the parser and at each step 
  # we pick as next move the one with the highest score assigned by the model
  def infere(self, x):

    parsers = [ArcEager(i) for i in x]

    x = [self.embeddings(torch.tensor(i).to(self.device)) for i in x]

    h = self.lstm_pass(x)

    while not self.parsed_all(parsers):
      # get the current configuration and score next moves
      configurations = self.get_configurations(parsers)
      mlp_input = self.get_mlp_input(configurations, h)
      mlp_out = self.mlp(mlp_input)
      # take the next parsing step
      self.parse_step(parsers, mlp_out)

    # return the predicted dependency tree
    return [parser.arcs for parser in parsers]

  def get_configurations(self, parsers):
    configurations = []

    for parser in parsers:
      if parser.is_tree_final():
        conf = [-1, -1]
      else:
        conf = [parser.stack[len(parser.stack)-1]]
        if len(parser.buffer) == 0:
          conf.append(-1)
        else:
          conf.append(parser.buffer[0])  
      configurations.append([conf])

    return configurations

  def parsed_all(self, parsers):
    for parser in parsers:
      if not parser.is_tree_final():
        return False
    return True

  # In this function we select and perform the next move according to the scores obtained.
  # We need to be careful and select correct moves, e.g. don't do a shift if the buffer
  # is empty or a left arc if σ2 is the ROOT. For clarity sake we didn't implement
  # these checks in the parser so we must do them here. This renders the function quite ugly
  def parse_step(self, parsers, moves):


    # get the move with the highest score
    best_moves = torch.topk(moves,3,dim=1).indices
    best_predicted_move = best_moves[:,0]
    second_best_predicted_move = best_moves[:,1]
    third_best_predicted_move = best_moves[:,2]

    for idx in range(len(parsers)):

      # if tree is final no need to do next move
      if parsers[idx].is_tree_final():
        continue
      else:

        #=========================
        # if best move is LEFT ARC
        #=========================
        if best_predicted_move[idx] == 0:

          # NOTE: buffer for sure not empty ALWAYS
            
          # check sigma_1 is not the root and sigma_1 doeas not have a head already
          if parsers[idx].stack[-1] != 0 and parsers[idx].arcs[parsers[idx].stack[-1]] == -1:
            parsers[idx].left_arc()

          # theoretically I can do RA, RE or SH, check the one with highest prob
          else:
            
            # if second best is right arc
            if second_best_predicted_move[idx] == 1:

              # at this point we have a) sigma1 is the ROOT or b) sigma 1 already a head
              # not both since the ROOT cannot have a head

              # if sigma1 not the ROOT we can do right arc
              if parsers[idx].stack[-1] != 0:
                parsers[idx].right_arc()

              # if sigma1 is the ROOT
              # for sure reduce not possible
              else:
                
                # DA RIVEDERE: NON SONO SICURO DI QUESTA COSA, FORSE IL RIGHT ARC SI PUÒ
                # FARE IN OGNI CASO (aggiornamento: penso sia ok)

                # if len buffer is > 1 shift
                if len(parsers[idx].buffer) > 1:
                  parsers[idx].shift()

                else:
                  parsers[idx].right_arc()


            # if second best is reduce
            elif second_best_predicted_move[idx] == 2:
              
              # if sigma1 not the root we can reduce since for sure it has already
              # a head
              if parsers[idx].stack[-1] != 0:
                parsers[idx].reduce()

              # if sigma1 is the root
              # at this point we can do both RA or SH
              else:

                # if shift has higher score or buffer length > 1 then shift
                if third_best_predicted_move[idx] != 1 or len(parsers[idx].buffer) > 1:
                  parsers[idx].shift()
                
                else:
                  parsers[idx].right_arc()

            # if second best is shift do it, no problems
            else:
              parsers[idx].shift()

        #=========================
        # if best move is RIGHT ARC
        #=========================
        elif best_predicted_move[idx] == 1:

          # if sigma1 not the ROOT then we can right arc
          if parsers[idx].stack[-1] != 0:
            parsers[idx].right_arc()

          else:
            
            if len(parsers[idx].buffer) == 1:
              parsers[idx].right_arc()

            else:
              parsers[idx].shift()


        #=========================
        # if best move is REDUCE
        #=========================
        elif best_predicted_move[idx] == 2:
          
          # if sigma1 not the ROOT then RE is ok
            if parsers[idx].stack[-1] != 0:
              parsers[idx].reduce()

            else:
              # at this point we can do both RA or SH
              
              # if second best move right arc
              if second_best_predicted_move[idx] == 1:

                if len(parser[idx].buffer) == 1:
                  parsers[idx].right_arc()

                else:
                  parsers[idx].shift()

              # if second best move is SH
              else:
                parsers[idx].shift()


        #=========================
        # if best move is SHIFT
        #=========================
        else:

          # for shift no problems
          parsers[idx].shift()

### Train and test of BiLSTM model

In [14]:
def evaluate(gold, preds): 
  total = 0
  correct = 0

  for g, p in zip(gold, preds):
    for i in range(1,len(g)):
      total += 1
      if g[i] == p[i]:
        correct += 1

  return correct/total

In [15]:
def train(model, dataloader, criterion, optimizer):
  model.train()
  total_loss = 0
  count = 0

  for batch in dataloader:
    optimizer.zero_grad()
    sentences, paths, moves, trees = batch

    out = model(sentences, paths)
    labels = torch.tensor(sum(moves, [])).to(device) #sum(moves, []) flatten the array
    loss = criterion(out, labels)

    count +=1
    total_loss += loss.item()

    loss.backward()
    optimizer.step()
  
  return total_loss/count

def test(model, dataloader):
  model.eval()

  gold = []
  preds = []

  for batch in dataloader:
    sentences, paths, moves, trees = batch
    with torch.no_grad():
      pred = model.infere(sentences)

      gold += trees
      preds += pred
  
  return evaluate(gold, preds)

In [16]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)
model = Net(device)
model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LR)


for epoch in range(EPOCHS):
  avg_train_loss = train(model, train_dataloader, criterion, optimizer)
  val_uas = test(model, dev_dataloader)

  print("Epoch: {:3d} | avg_train_loss: {:5.3f} | dev_uas: {:5.3f} |".format( epoch, avg_train_loss, val_uas))



Device: cuda
Epoch:   0 | avg_train_loss: 1.080 | dev_uas: 0.551 |
Epoch:   1 | avg_train_loss: 0.951 | dev_uas: 0.611 |
Epoch:   2 | avg_train_loss: 0.921 | dev_uas: 0.626 |
Epoch:   3 | avg_train_loss: 0.901 | dev_uas: 0.648 |
Epoch:   4 | avg_train_loss: 0.887 | dev_uas: 0.660 |
Epoch:   5 | avg_train_loss: 0.875 | dev_uas: 0.666 |
Epoch:   6 | avg_train_loss: 0.866 | dev_uas: 0.674 |
Epoch:   7 | avg_train_loss: 0.858 | dev_uas: 0.671 |
Epoch:   8 | avg_train_loss: 0.850 | dev_uas: 0.671 |
Epoch:   9 | avg_train_loss: 0.844 | dev_uas: 0.677 |
Epoch:  10 | avg_train_loss: 0.837 | dev_uas: 0.680 |
Epoch:  11 | avg_train_loss: 0.833 | dev_uas: 0.689 |
Epoch:  12 | avg_train_loss: 0.828 | dev_uas: 0.681 |
Epoch:  13 | avg_train_loss: 0.824 | dev_uas: 0.684 |
Epoch:  14 | avg_train_loss: 0.821 | dev_uas: 0.686 |


# BERT model
In this section, we will use BERT to extract a feature vector for each token $\sigma_1$ and $\beta_1$ in $configuration$.

In [17]:
!pip install transformers
!pip install datasets
!pip install evaluate
!pip install accelerate

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting transformers
  Downloading transformers-4.30.1-py3-none-any.whl (7.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.2/7.2 MB[0m [31m61.9 MB/s[0m eta [36m0:00:00[0m
Collecting tokenizers!=0.11.3,<0.14,>=0.11.1 (from transformers)
  Downloading tokenizers-0.13.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.8/7.8 MB[0m [31m70.9 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting safetensors>=0.3.1 (from transformers)
  Downloading safetensors-0.3.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m80.2 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: tokenizers, safetensors, transformers
Successfully installed safetensors-0.3.1 tokenizers-0.13.3 transformer

#### Pre-processing
We pre-process a batch of samples

In [20]:
# creates training instances from one sample
# $sample is a sample of our dataset
def process_sample_bert(sample, get_gold_path=False):

    # add the root token to the sentence and its head (-1) to the gold head list
    sentence = ["<ROOT>"] + sample["tokens"]
    gold = [-1] + [int(i) for i in sample["head"]]
    
    # save the tokens
    sample_tokens = []
    sample_tokens.append("<ROOT>")
    for token in sample["tokens"]:
        sample_tokens.append(token)

    # $gold_configurations and $gold_transitions are parallel arrays whose elements refer to parsing steps
    # $gold_configurations[i] records configuration at step $i, i.e. topmost stack token and first buffer token for current step
    gold_configurations = []
    # $gold_transitions[i] contains oracle (canonical) transition for step $i: 0 is left_arc, 1 right_arc, 2 reduce, 3 shift
    gold_transitions = []

    # only for training
    if get_gold_path:
        parser = ArcEager(sentence)
        oracle = Oracle(parser, gold)

        while not parser.is_tree_final():
        
            # save configuration
            configuration = [parser.stack[-1]]
            if len(parser.buffer) == 0:
                configuration.append(-1)
            else:
                configuration.append(parser.buffer[0])  
            gold_configurations.append(configuration)

            # save gold transition
            if oracle.is_left_arc_gold():
                gold_transitions.append(0)
                oracle.parser.left_arc()
            elif oracle.is_right_arc_gold():
                gold_transitions.append(1)
                oracle.parser.right_arc()
            elif oracle.is_reduce_gold():
                gold_transitions.append(2)
                oracle.parser.reduce()
            else:
                gold_transitions.append(3)
                oracle.parser.shift()

    # $sample_tokens is a list containing the tokens in $sample
    # $gold is a list containning the representation of the gold tree of $sample
    # $gold_configurations is a list containing gold configurations
    # $gold_transitions is a list containing gold transitions
    return sample_tokens, gold_configurations, gold_transitions, gold

# this function is used to pre-process a batch of samples $batch_data from the original dataset
# applying function $process_sample to each sample in $batch_data
def prepare_batch_bert(batch_data, get_gold_path=False):
    processed_batch_data = [process_sample_bert(sample, get_gold_path=get_gold_path) for sample in batch_data]
    
    samples_tokens = []
    paths = []
    moves = []
    gold_trees = []

    for sample in processed_batch_data:
        samples_tokens.append(sample[0])
        paths.append(sample[1])
        moves.append(sample[2])
        gold_trees.append(sample[3])
        
    # $samples_tokens, $paths, $moves, $gold_trees are parallel lists
    # element in position $i of each of the above lists refers to the same sentence $i 
    return samples_tokens, paths, moves, gold_trees

#### Data loaders
We load the sets preparing batches according to the function $prepare\_batch\_bert$.
In this way, we will be able to train our model in parallel in the samples belonging to the same batch.

In [19]:
# load train set, development set and test set
train_dataset = load_dataset('universal_dependencies', 'en_lines', split="train")
validation_dataset = load_dataset('universal_dependencies', 'en_lines', split="validation")
test_dataset = load_dataset('universal_dependencies', 'en_lines', split="test")

# set the number of samples per batch
BATCH_SIZE = 32

# create dataloaders to batch each dataset and apply function $prepare_batch to each batch
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=partial(prepare_batch_bert, get_gold_path=True))
validation_dataloader = torch.utils.data.DataLoader(validation_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=partial(prepare_batch_bert))
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=partial(prepare_batch_bert))



#### Model creation
Define the model to be used.

In [21]:
# Model design parameters
BERT_TOKEN_EMBEDDING_SIZE = 768
CLASSIFICATION_LABELS = 4
WORD_PER_CONFIGURATION = 2
# Model hyperparameters
LINEAR_LAYER_SIZE = 200
DROPOUT = 0.2


from transformers import BertModel
from transformers import BertTokenizer

# Custom Bert model
class CustomBertBasedModel(nn.Module):
    

    # CONSTRUCTORS


    # Constructor specifications
    def __init__(self, device):
        
        # inherit the constructor of $nn.Module
        super(CustomBertBasedModel, self).__init__()

        # name of the Bert model to be used
        bert_model_name = "bert-base-uncased"

        # tokenizer used for Bert
        self.bert_tokenizer = BertTokenizer.from_pretrained(bert_model_name)

        # Bert model
        self.bert = BertModel.from_pretrained(bert_model_name)

        # feedforward head for classification
        self.dropout = nn.Dropout(DROPOUT)
        self.linear1 = nn.Linear(BERT_TOKEN_EMBEDDING_SIZE * WORD_PER_CONFIGURATION, LINEAR_LAYER_SIZE, bias=True)
        self.activation1 = nn.Tanh()
        self.linear2 = nn.Linear(LINEAR_LAYER_SIZE, CLASSIFICATION_LABELS, bias=True)
        self.softmax = nn.Softmax(dim=-1) # the dimension where the softmax has to be applied is the last one

        # device
        self.device = device


    # UTILITY FUNCTIONS: ARCEAGER PARSE


    # Check if all parsers $parsers have finished parsing their sentences.
    # Returns:
    #   ~ True if all parsers have finished;
    #   ~ False otherwise.
    def parsed_all(self, parsers):
        for parser in parsers:
            if not parser.is_tree_final():
                return False
        return True


    def get_configurations(self, parsers):
        
        # list of current configurations: $configurations[i] is the current configuration for parser $parsers[i]
        configurations = []

        for parser in parsers:
            # current configuration of $parser
            configuration = []

            if parser.is_tree_final():
                # empty configuration
                configuration = [-1, -1]

            else:
                # append the first element of the configuration: the first element on top of the stack
                configuration.append(parser.stack[-1])
                if len(parser.buffer) == 0:
                    # no element in the buffer, so append $-1
                    configuration.append(-1)
                else:
                    # append the second element of the configuration: the first element in the buffer
                    configuration.append(parser.buffer[0])
            
            # append the current configuration for the current parser $parser
            # we add $[] in order to append a list with only one configuration: necessary for $prepare_input_fnn
            configurations.append([configuration])

        return configurations


    # Function that selects and performs the next ArcEager transition for each sentence $i in the batch,
    # according to the scores in $fnn_output.
    # $parsers is a list of ArcEager parsers: $parsers[i] is parsing sentence $i in the batch.
    # $fnn_output is the output tensor of the fnn predicting scores for the $CLASSIFICATION_LABELS possible transitions.
    # $fnn_output has size (# sentences in batch, $CLASSIFICATION_LABELS).
    def parse_step(self, parsers, fnn_output):
        
        # get the indices of the top 3 transition scores
        # $top_k_transitions_indices[i] contains the indices in $fnn_output[i] (i.e. for sentence $i) sorted by predicted transition score
        top_k_transitions_indices = torch.topk(fnn_output, 4, dim=-1).indices
        # transitions with highest score: $best_transitions_indices[i] is the transition with highest score for sentence $i
        best_transitions = top_k_transitions_indices[:, 0]
        second_best_transitions = top_k_transitions_indices[:, 1]
        third_best_transitions = top_k_transitions_indices[:, 2]

        # get the best transition for each parser in $parsers
        for i in range(len(parsers)):
            
            # do nothing if the parser $parsers[i] has already parsed the whole sentence
            if parsers[i].is_tree_final():
                continue
            
            else:

                # transition with highest probability: left-arc
                if best_transitions[i] == 0:

                    # note that the buffer is never empty for sure
                    # check that sigma_1 is not the root and sigma_1 doeas not already have a head
                    if parsers[i].stack[-1] != 0 and parsers[i].arcs[parsers[i].stack[-1]] == -1:
                        parsers[i].left_arc()

                    # if the preconditions for the left-arc are not satisfied, I can do right-arc, reduce or shift without any problem,
                    # so pick the transition with highest probability
                    else:
                        # the second best transition is right-arc
                        if second_best_transitions[i] == 1:

                            # at this point two things may happen: sigma_1 is the ROOT or sigma_1 already has a head
                            # the two things cannot happen together since the ROOT cannot have a head
                            # let's check in which of the two cases we are

                            # sigma_1 is not the ROOT: we can do right-arc
                            if parsers[i].stack[-1] != 0:
                                parsers[i].right_arc()

                            # if sigma1 is the ROOT
                            # for sure reduce not possible
                            else:
                                
                                # NON SI PUO' FARE IL RIGHT-ARC IN OGNI CASO?

                                # if len buffer is > 1 shift
                                if len(parsers[i].buffer) > 1:
                                    parsers[i].shift()

                                else:
                                    parsers[i].right_arc()


                        # the second best transition is reduce
                        elif second_best_transitions[i] == 2:
                        
                            # if sigma_1 is not the root, then we can reduce, because for sure sigma_1 already has a head
                            # this is because only one of the two preconditions of left-arc is not satisfied (never both)
                            if parsers[i].stack[-1] != 0:
                                parsers[i].reduce()

                            # sigma_1 is the root: we cannot do a reduce
                            # at this point we can do only right-arc or shift: no problems for both of them
                            else:
                                
                                # PERCHE' SCEGLIERE SHIFT SE NEL BUFFER CI SONO ALMENO DUE ELEMENTI, MA RIGHT-ARC E' MIGLIORE?

                                # if shift has higher score or buffer length > 1 then shift
                                if third_best_transitions[i] != 1 or len(parsers[i].buffer) > 1:
                                    parsers[i].shift()
                                else:
                                    parsers[i].right_arc()

                        # if the second best transition is shift, then do it without any problem
                        else:
                            parsers[i].shift()

                # transition with highest probability: right-arc
                elif best_transitions[i] == 1:

                    # if sigma_1 is not the ROOT, then we can do right-arc
                    if parsers[i].stack[-1] != 0:
                        parsers[i].right_arc()

                    else:
                        
                        if len(parsers[i].buffer) == 1:
                            parsers[i].right_arc()

                        else:
                            parsers[i].shift()


                # transition with highest probability: reduce
                elif best_transitions[i] == 2:
                
                    # if sigma_1 is not the ROOT, then we can do reduce
                    if parsers[i].stack[-1] != 0:
                        parsers[i].reduce()

                    # sigma_1 is the root, so we cannot do reduce: we can do right-arc or shift
                    else:
                    
                        # the second best transition is right-arc
                        if second_best_transitions[i] == 1:

                            if len(parsers[i].buffer) == 1:
                                parsers[i].right_arc()

                            else:
                                parsers[i].shift()

                        # the second best transition is shift
                        else:
                            parsers[i].shift()


                # transition with highest probability: shift
                else:
                    # we can do the shift without any problem
                    parsers[i].shift()


    # UTILITY FUNCTIONS: MODEL


    # Given a list of tokens $list_tokens representing words in a sentence, the function reconstructs the
    # string representing the sentence.
    def sentence_reconstruction(self, list_tokens):
        sentence = ""
        for token in list_tokens:
            sentence = sentence + " " + token
        return sentence


    # The function averages tokens hidden representations $hidden_vectors in order to
    # obtain a representation for each token in $desired_tokens starting from tokens in $current_tokens.
    # More precisely, it averages representations of tokens in $current_tokens that are subtokens of the
    # same token in $desired_tokens.
    # $hidden_vectors contains representations for tokens in $current_tokens. It is a torch tensor of size
    # (max sentence length in batch, $BERT_TOKEN_EMBEDDING_SIZE).
    # $current_tokens contains the tokens for which we have a corresponding representation in $hidden_vectors.
    # $desired_tokens contains the tokens for which we want a representation.
    # returns the tensor $torch.stack(new_hidden_vectors, dim=0) of size (len($desired_tokens), $BERT_TOKEN_EMBEDDING_SIZE), where
    # torch.stack(new_hidden_vectors, dim=0)[i] is the hidden vector representing $desired_tokens[i] as the mean
    # of the hidden vectors in $hidden_vectors representing tokens in $current_tokens that are subtokens of $desired_tokens[i].
    def average_layer(self, hidden_vectors, current_tokens, desired_tokens):
        
        # list of hidden vectors representing $desired_tokens. Actually, it is a list of torch tensors.
        new_hidden_vectors = []

        # index of the current token in $current_tokens
        current_token_index = -1

        # build one hidden vector per token in $desired_tokens
        for word in desired_tokens:

            # update the index
            current_token_index = current_token_index + 1

            # current subword of $word
            current_word = current_tokens[current_token_index]

            # list of hidden vectors in $hidden_vectors representing tokens that are subword tokens of $word
            sub_word_vectors = [hidden_vectors[current_token_index]]

            # append hidden vectors representing subword tokens of $word until $word and $current_word are equal
            while not len(current_word) == len(word):

                # we add to $current_word the next subword token of $word
                current_token_index = current_token_index + 1
                if current_tokens[current_token_index].startswith("##"):
                    current_word = current_word + current_tokens[current_token_index][2:]
                else:
                    current_word = current_word + current_tokens[current_token_index]
                sub_word_vectors.append(hidden_vectors[current_token_index])

            # build a tensor of size (len($sub_word_vectors), $BERT_TOKEN_EMBEDDING_SIZE) from the list $sub_word_vectors by stacking subword
            # vectors along the first dimension
            stacked_tensor = torch.stack(sub_word_vectors, dim=0)

            # append to the list $new_hidden_vectors the hidden vector representing $word, which is computed as the mean
            # of the hidden vectors in $hidden_vectors representing its subword tokens
            new_hidden_vectors.append(torch.mean(stacked_tensor, dim=0))

        # we use $torch.stack to build a tensor of size (len($desired_tokens), $BERT_TOKEN_EMBEDDING_SIZE) from the
        # list of tensors $new_hidden_vectors
        return torch.stack(new_hidden_vectors, dim=0)
    

    # The function, given the outputs of Bert applied to a batch of sentences, returns the inputs for the FNN.
    # $bert_output contains Bert representation of the batch of sentences.
    # $bert_output.last_hidden_state is a tensor of size (# sentences in batch, # tokens in max sentence length, $BERT_TOKEN_EMBEDDING_SIZE)
    # $sentences_words[i] is the list of words in sentence $i in the batch.
    # $sentences_bert_tokens[i] is the list of tokens provided by Bert tokenizer applied to sentence $i in the batch.
    # $configurations[i] contains the configurations of sentence $i in the batch.
    # returns $fnn_input: it is a tensor of features representing configurations. Its size is (# configurations, 2 * $BERT_TOKEN_EMBEDDING_SIZE).
    def prepare_input_fnn(self, bert_output, sentences_words, sentences_bert_tokens, configurations):

        # $fnn_input[i] will contain the input for the FNN related to sample $i in the batch
        fnn_input = []
        
        # iterate over the number of samples in the batch, i.e. the number of sentences processed by Bert
        for i in range(len(sentences_words)):

            # embedding representation of current sentence provided by Bert
            # hidden_repr_sentence is a tensor of size (# tokens in the sentence, $BERT_TOKEN_EMBEDDING_SIZE)
            hidden_repr_sentence = bert_output.last_hidden_state[i]
            # average embeddings of tokens belonging to the same word
            words_bert_embeddings = self.average_layer(hidden_repr_sentence, sentences_bert_tokens[i], sentences_words[i])

            # iterate over the configurations of the current sentence and append the representation of each configuration to $fnn_inputs
            for configuration in configurations[i]:
                
                # if $configuration[0] == -1, then we have no word on the stack, so the representation for the token on the stack is zero
                tensor_sigma = torch.zeros(BERT_TOKEN_EMBEDDING_SIZE, requires_grad=False).to(self.device)
                # else, if there is something on top of the stack, then we take its Bert representation $words_bert_embeddings[configuration[0]]
                if configuration[0] != -1:
                    tensor_sigma = words_bert_embeddings[configuration[0]].to(self.device)
                
                # if $configuration[1] == -1, then we have no word in the buffer, so the representation for the token in the stack is zero
                tensor_beta = torch.zeros(BERT_TOKEN_EMBEDDING_SIZE, requires_grad=False).to(self.device)
                # else, if there is something at the beginning of the buffer, then its Bert representation is
                # $words_bert_embeddings[configuration[1]]
                if configuration[1] != -1:
                    tensor_beta = words_bert_embeddings[configuration[1]].to(self.device)
                
                # we concatenate the representation of the current configuration $configuration as the concatenation of the embeddings
                # of, respectively, the first token on top of the stack and the first token in the buffer
                fnn_input.append(torch.cat([tensor_sigma, tensor_beta]))
        
        # the input for the fnn is the representations of all configurations of all sentences in the batch
        # we stack them into  single tensor
        fnn_input = torch.stack(fnn_input).to(self.device)
            
        return fnn_input


    # For each list of tokens in $sentences_tokens, the function reconstructs the string representing the
    # original sentence. Moreover, it applies the Bert tokenizer to the batch of reconstructed sentences and prepares the
    # batch for being processed as input of the Bert model.
    # Returns:
    #   ~ $bert_input: input for Bert model;
    #   ~ $sentences_tokenizers: list containing the list of tokens returned by Bert tokenizer for each sentence.
    def bert_tokenization_and_input(self, sentences_tokens):
        
        # list containing all sentences in the batch
        sentences = []

        # iterate over the samples in the batch
        for sentence_tokens in sentences_tokens:

            # reconstruct the original sentence string by the word tokens $sentence_tokens
            sentences.append(self.sentence_reconstruction(sentence_tokens))

        # get the input representation to feed Bert with the batch of sentences $sentences
        bert_input = self.bert_tokenizer(sentences, padding=True, truncation=True, add_special_tokens=False, return_tensors="pt")

        # $sentence_tokenizers[i] contains tokens provided by Bert tokenizer of sentence $i
        sentences_tokenizers = []
        # get the ids of the tokens of the sentences in the batch
        input_ids = bert_input["input_ids"]
        # Convert each row in the input_ids tensor to a list of tokens
        for row in input_ids:
            sentences_tokenizers.append(self.bert_tokenizer.convert_ids_to_tokens(row))

        return bert_input, sentences_tokenizers


    # IMPORTANT INSTANCE FUNCTIONS


    # Function that applies the feedforward neural network to the input tensor $input.
    # It returns the output tensor of the fnn.
    def fnn(self, input):

        # first layer of the fnn
        output_layer1 = self.activation1(self.linear1(self.dropout(input)))
        # second layer of the fnn
        output_layer2 = self.softmax(self.linear2(self.dropout(output_layer1)))
        
        return output_layer2


    # Forward method that we use at training time. It does the forward pass for a batch of samples at the time.
    # $sentences_tokens[i] is the list of tokens of sentence $i in the batch.
    # $configurations contains the configurations of all samples in the batch.
    # $configurations[i] is the list of configurations related to $sentences[i].
    # $configurations[i][i] is the configuration at step $i when parsing sentence $sentences[i].
    def forward(self, sentences_tokens, configurations):

        # $bert_input contains the input for the bert model, after having applied $self.bert_tokenizer the batch
        # $bert_tokenized_sentences is a list containing tokenized sentences of the batch.
        bert_input, bert_tokenized_sentences = self.bert_tokenization_and_input(sentences_tokens)
        bert_input.to(self.device)

        # get the Bert output
        bert_output = self.bert(**bert_input)

        # prepare the inputs for the feedforward neural network
        fnn_input = self.prepare_input_fnn(bert_output, sentences_tokens, bert_tokenized_sentences, configurations)

        # run the feedforward neural network with all parsing configurations of all sentences in the batch as input
        fnn_output = self.fnn(fnn_input)

        # return the output of the forward pass in the whole architecture
        return fnn_output


    # Forward method that we use at inference time.
    # It processes a batch of samples at the time.
    # It infers a list of parsing trees: one per sample in the batch.
    def infere(self, sentences_tokens):
        
        # construct a parser for each sentence in the batch
        parsers = [ArcEager(sentence_tokens) for sentence_tokens in sentences_tokens]

        # this part is the same as the first part of the $forward function
        bert_input, bert_tokenized_sentences = self.bert_tokenization_and_input(sentences_tokens)
        bert_input.to(self.device)
        bert_output = self.bert(**bert_input)

        # now we do not have the gold configurations, so we need to compute and score them step by step according to the model's predictions
        while not self.parsed_all(parsers):

            # get the current configurations of the parsers: $configurations[i] is the current configuration of parser $parsers[i]
            configurations = self.get_configurations(parsers)

            fnn_input = self.prepare_input_fnn(bert_output, sentences_tokens, bert_tokenized_sentences, configurations)
            fnn_output = self.fnn(fnn_input)
            
            # given the current configuration, select the correct transition based on the scores given by the fnn
            self.parse_step(parsers, fnn_output)

        # return the predicted dependency trees: one per sentence in the batch
        return [parser.arcs for parser in parsers]

#### Training of the Bert-based model
Some considerations:
- training the whole model (both pre-trained Bert and classification head) is considered fine-tuning, because we are fine-tuning weights to better suit our task;
- training the whole model directly could cause catastrophic forgetting: since the classification head has weights that are randomly initialized, it could cause a sensible change in the weights of the Bert model, causing loss of pre-trained knowledge. Hence, a good idea could be to divide the training into two phases:
    - keep Bert pre-trained weights frozen and update only the classification head's ones. In this way, we make the classification head adapt to the pre-trained weights of Bert (as well as to our specific task);
    - do a second training by updating also Bert weights. The risk of catastrophic forgetting is now mitigated by the fact that in the previous training phase, the weights of the classification head have adapted to Bert ones and are no more random.

Class for agglomerating everything needed for training the model.

In [22]:
# hyperparameters
EPOCHS = 15
LEARNING_RATE_BERT = 1e-4
LEARNING_RATE_FNN = 1e-3

# Class for training a model in a custom way defined inside the class itself.
class CustomTrain:


    # Constructor function.
    # $model is the model to be trained.
    # $device is the device to be used for training the model.
    def __init__(self, model, device, epochs=EPOCHS, loss_function=nn.CrossEntropyLoss(), bert_freezed=False):
        
        # set the model
        self.model = model

        # set the loss function: cross entropy loss is the default
        self.loss = loss_function

        # set the deviced to be used for training the model
        self.device = device

        # set the number of epochs for the training
        self.number_of_epochs = epochs

        # set the learning rates
        self.learning_rate_bert = LEARNING_RATE_BERT
        self.learning_rate_fnn = LEARNING_RATE_FNN
        
        # set the optimizer: Adam
        if bert_freezed:

            # the weights of the Bert moel are not updated during training
            self.optimizer = torch.optim.Adam([{"params" : model.linear1.parameters(), "lr" : LEARNING_RATE_FNN},
                                               {"params" : model.linear2.parameters(), "lr" : LEARNING_RATE_FNN}])
                                               
                                               
        else:
            # the weights of the Bert moel are updated during training
            self.optimizer = torch.optim.Adam([{"params" : model.bert.parameters(), "lr" : LEARNING_RATE_BERT},
                                               {"params" : model.linear1.parameters(), "lr" : LEARNING_RATE_FNN},
                                               {"params" : model.linear2.parameters(), "lr" : LEARNING_RATE_FNN}])


    # Function to train model on $dataloader based on the loss function $loss_function.
    # It does only one epoch.
    # $model is the model to be trained.
    # $dataloader is the dataloader containing the preprocessed batches.
    # $loss_function is the loss function to be used for error and gradient computation.
    # $optimizer is the optimizer to be used to update the learnable parameters according to
    # the gradient of $loss_function.
    def single_epoch_train(self, dataloader):

        # set the model to train phase
        self.model.train()

        # initial loss value
        total_loss = 0
        # initial number of batches used for training
        count_batches = 0

        # iterate over the batches in the dataset
        for batch in dataloader:

            # zero the gradients: we are at the beginning of a new batch
            self.optimizer.zero_grad()
            
            # load the current batch $batch
            samples_tokens, configurations, transitions, gold_trees = batch

            # forward step of the model applied to $batch
            model_output = self.model(samples_tokens, configurations)

            # sum(moves, []) flattens the list of lists $transitions into a single list where
            # the single $transitions[i] are concatenated
            labels = torch.tensor(sum(transitions, [])).to(self.device)
            
            # compute the loss of the current batch $batch w.r.t. the gold labels $labels
            loss = self.loss(model_output, labels)

            # update the number of batches for which the loss has been computed
            count_batches += 1
            # update the total loss, i.e. the sum of the losses of the single considered batches
            # by adding the loss of the current batch $batch
            total_loss += loss.item()

            # backpropagation step to compute the gradient of the loss w.r.t. the learnable parameters of the model
            loss.backward()
            
            # optimization of the learnable parameters of the model
            self.optimizer.step()
        
        # return the mean batch loss
        return total_loss / count_batches


    # Trains the model on the training set provided by $train_dataloader for the given number of epochs.
    # $train_dataloader is the dataloader providing the batched training set.
    # $validation_dataloader is the dataloader providing the batched validation set.
    # $loss_function is the loss function to be used for training the model.
    # $optimizer is the optimizer to be used to train the model.
    # After the application of this function, the models' parameters are updated according to the training procedure.
    # Moreover, the function prints the average training loss and the performances of the models on the validation set
    # at each epoch.
    def custom_train(self, train_dataloader, validation_dataloader):

        # iterate over the desired number of epochs
        for epoch in range(self.number_of_epochs):
            
            # single training step
            average_training_loss = self.single_epoch_train(train_dataloader)
            
            # current performances of the model on the validation set
            validation_performances = self.performance_evaluation(validation_dataloader)

            # print both average batch loss and validation performances w.r.t. the current training epoch
            print("Epoch: {:3d} | Average batch loss training set: {:5.3f} | UAS validation set : {:5.3f} |"
                  .format(epoch, average_training_loss, validation_performances))

    
    # Unlabeled accuracy score: percentage of correctly predicted arcs w.r.t. the total number of arcs.
    # This evaluation metric is computed for all samples in a batch.
    # $gold_arcs_batch[i] is the list of gold arcs for sample $i in the batch.
    # $predicted_arcs_batch[i] is the list of predicted arcs for sample $i in the batch.
    # The function returns the ratio (# correctly predicted arcs in the batch / total # gold arcs in the batch).
    def unlabeled_accuracy_score(self, gold_arcs_batch, predicted_arcs_batch): 
        
        # total number of checked arcs: initialized to 0
        total_number_of_arcs = 0
        # number of correctly predicted arcs: initialized to 0
        correctly_predicted_arcs = 0

        # iterate over the samples in the batch
        for gold_arcs_sample, predicted_arcs_sample in zip(gold_arcs_batch, predicted_arcs_batch):
            
            # iterate over the gold arcs of each sample
            # we start from $1 because the root is never dependent of any head
            for i in range(1, len(gold_arcs_sample)):

                total_number_of_arcs += 1
                
                # if the current gold arc of the current sample is equal to the current predicted arc of the current sample,
                # then update the number of correctly predicted arcs
                if gold_arcs_sample[i] == predicted_arcs_sample[i]:
                    correctly_predicted_arcs += 1

        return correctly_predicted_arcs / total_number_of_arcs


    # Function that computes a measure of the performances of the models in the evaluation set
    # provided by $dataloader.
    # The function used to compute performances values is $unlabeled_accuracy_score.
    # Returns $unlabeled_accuracy_score computed over $dataloader.
    def performance_evaluation(self, dataloader):
        
        # set the model to evaluation time
        self.model.eval()

        # list of gold arcs: initialized as empty
        gold_arcs_list = []
        # list of predicted arcs: initialized as empty
        predicted_arcs_list = []

        # iterate over the batches in the loaded dataloader
        for batch in dataloader:

            # load the current batch $batch
            samples_tokens, configurations, transitions, gold_trees = batch

            # disable torch gradient computation in order not to waste resources
            with torch.no_grad():

                # use the model to predict the arcs
                predicted_arcs_batch = model.infere(samples_tokens)

                # update the lists with checked arcs
                gold_arcs_list += gold_trees
                predicted_arcs_list += predicted_arcs_batch
        
        # return the value returned by the evaluation function defined in the class
        return self.unlabeled_accuracy_score(gold_arcs_list, predicted_arcs_list)

Actual training of the model

In [23]:
# set the device to be used for training the model $model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

# set the model to be trained
model = CustomBertBasedModel(device)
# save the model to the desired device
model.to(device)

# Bert freezed
bert_freezed = True
# initialize the CustomTrain object to train the model
trainer = CustomTrain(model, device, bert_freezed=True, epochs=2)
# actual training of the model
trainer.custom_train(train_dataloader, validation_dataloader)

# Bert unfreezed
# initialize the CustomTrain object to train the model
trainer = CustomTrain(model, device, epochs=15)
# actual training of the model
trainer.custom_train(train_dataloader, validation_dataloader)

Device: cuda


Downloading (…)solve/main/vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

Downloading model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Epoch:   0 | Average batch loss training set: 1.089 | UAS validation set : 0.572 |
Epoch:   1 | Average batch loss training set: 1.004 | UAS validation set : 0.611 |
Epoch:   0 | Average batch loss training set: 0.891 | UAS validation set : 0.797 |
Epoch:   1 | Average batch loss training set: 0.829 | UAS validation set : 0.813 |
Epoch:   2 | Average batch loss training set: 0.816 | UAS validation set : 0.814 |
Epoch:   3 | Average batch loss training set: 0.804 | UAS validation set : 0.835 |
Epoch:   4 | Average batch loss training set: 0.796 | UAS validation set : 0.843 |
Epoch:   5 | Average batch loss training set: 0.795 | UAS validation set : 0.838 |
Epoch:   6 | Average batch loss training set: 0.791 | UAS validation set : 0.830 |


#### Test of Bert model
Test the trained Bert-based model on the test set.

In [24]:
# evaluate UAS performances of the trained model on the test set
test_uas = trainer.performance_evaluation(test_dataloader)
# print the UAS score
print("UAS test set: {:5.3f}".format(test_uas))

UAS test set: 0.842


# Evaluation

# SotA discussion