# Dataset analysis



In [1]:
!pip install datasets  # huggingface library with dataset
!pip install conllu    # aux library for processing CoNLL-U format

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting datasets
  Downloading datasets-2.12.0-py3-none-any.whl (474 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m474.6/474.6 kB[0m [31m9.9 MB/s[0m eta [36m0:00:00[0m
Collecting dill<0.3.7,>=0.3.0 (from datasets)
  Downloading dill-0.3.6-py3-none-any.whl (110 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m110.5/110.5 kB[0m [31m17.6 MB/s[0m eta [36m0:00:00[0m
Collecting xxhash (from datasets)
  Downloading xxhash-3.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (212 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m212.5/212.5 kB[0m [31m30.1 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting multiprocess (from datasets)
  Downloading multiprocess-0.70.14-py310-none-any.whl (134 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.3/134.3 kB[0m [31m20.1 MB/s[0m eta [36m0:00:00[0m
Collec

In [2]:
import time
import random
import torch
import torch.nn as nn
import numpy as np
from functools import partial
from datasets import load_dataset

# Description of baseline and BERT model

## Arc-eager

A configuration of the arc-eager parser is a triple of the form $(\sigma, \beta, A)$ where:

* $\sigma$ is the stack
* $\beta$ is the buffer
* $A$ is the set of arcs constructed so far

Let:

* $\beta_i$, $i\geq1$, be the $i$-th token in the buffer (counting from the front of the buffer)
* $\sigma_i$, $i\geq1$, be the $i$-th token in the stack (counting from the top of the stack)

in the current configuration.

The arc-eager parser can perform four types of actions (transitions):

* **left-arc** (LA): create the arc $(\beta_1 \rightarrow \sigma_1)$ and remove $\sigma_1$ from the stack. The **preconditions** are: $\sigma_1$ is not the ROOT and $\sigma_1$ does not already have a head

* **right-arc** (RA): create the arc $(\sigma_1 \rightarrow \beta_1)$, remove $\beta_1$ from the buffer and push it to the stack

* **reduce** (RE): remove $\sigma_1$ from the stack. The **precondition** is: $\sigma_1$ must have a head

* **shift** (SH): remove $\beta_1$ from the buffer and push it to the stack


Let $w = w_0 w_1 \cdots w_{n}$ be the input sentence, with $w_0$ the special symbol `<ROOT>`.
Stack and buffer are implemented as lists of integers, where `j` represents word $w_j$.  Top-most stack token is at the right-end of the list; first buffer token is at the left-end of the list. 
Set $A$ is implemented as an array `arcs` of size $n+1$ such that if arc $(w_i \rightarrow w_j)$ is in $A$ then `arcs[j]=i`, and if $w_j$ is still missing its head node in the tree under construction, then `arcs[j]=-1`. We always have `arcs[0]=-1`.  We use this representation also for complete dependency trees.

In [3]:
# NOTE: here only the transition methods are implemented
# the preconditions will be checked later with the boolean methods

class ArcEager:
  """Arc-eager transition system over a sentence whose first token is <ROOT>.

  Tokens are referred to by their position in `sentence`.
  Attributes:
    sentence: the list of tokens given to the constructor.
    stack:    list of token positions, top of the stack at the right end.
    buffer:   list of token positions, next token at the left end.
    arcs:     arcs[j] is the head assigned to token j, or -1 if it has none yet.

  Only the transitions are implemented here; their preconditions are NOT
  checked and must be verified by the caller.
  """

  def __init__(self, sentence):
    n_tokens = len(sentence)
    self.sentence = sentence
    self.buffer = list(range(n_tokens))   # every token starts in the buffer
    self.stack = []
    self.arcs = [-1] * n_tokens           # no token has a head yet

    # initial configuration: one shift moves <ROOT> (position 0) onto the stack
    self.shift()

  def left_arc(self):
    """Make the first buffer token the head of the stack top, and pop the stack top."""
    head = self.buffer[0]
    dependent = self.stack.pop()
    self.arcs[dependent] = head

  def right_arc(self):
    """Make the stack top the head of the first buffer token, then move that token to the stack."""
    dependent = self.buffer[0]
    self.arcs[dependent] = self.stack[-1]
    self.buffer = self.buffer[1:]
    self.stack.append(dependent)

  def reduce(self):
    """Discard the stack top."""
    self.stack.pop()

  def shift(self):
    """Move the first buffer token onto the stack."""
    moved = self.buffer[0]
    self.buffer = self.buffer[1:]
    self.stack.append(moved)

  def print_configuration(self):
    """Print the stack and buffer as words, followed by the current arcs (for debug)."""
    print([self.sentence[pos] for pos in self.stack],
          [self.sentence[pos] for pos in self.buffer])
    print(self.arcs)

  def is_tree_final(self):
    """Return True once the buffer has been fully consumed."""
    return not self.buffer
    

## Oracle

The transition actions must follow the preconditions above

In [4]:
class Oracle:
  """Static oracle for the arc-eager parser.

  parser:    the parser whose current configuration is inspected; it must expose
             `stack`, `buffer` and `arcs` as described for ArcEager.
  gold_tree: list of gold heads, gold_tree[j] being the head of token j
             (-1 for <ROOT> at position 0).

  Each `is_*_gold` method returns True when the corresponding transition is
  allowed by its preconditions and agrees with the gold tree in the parser's
  current configuration. They all raise IndexError if the stack or the buffer
  is empty.
  """

  def __init__(self, parser, gold_tree):
    self.parser = parser
    self.gold = gold_tree

  # left arc?
  def is_left_arc_gold(self):
    # get the two elems the arc should be built on
    b1 = self.parser.buffer[0]
    o1 = self.parser.stack[-1]

    # check preconditions
    # <ROOT> is token 0 and can never be a dependent
    # (the stack stores token positions, so comparing against -1 would never match)
    if o1 == 0: return False
    # the stack top must not have a head already
    if self.parser.arcs[o1] != -1: return False

    # gold move iff the gold head of the stack top is the first buffer token
    return self.gold[o1] == b1

  # right arc?
  def is_right_arc_gold(self):
    b1 = self.parser.buffer[0]
    o1 = self.parser.stack[-1]

    # no preconditions: gold move iff the gold head of the first buffer token is the stack top
    return self.gold[b1] == o1

  # reduce?
  def is_reduce_gold(self):
    # precondition: the stack top must already have a head
    o1 = self.parser.stack[-1]
    if self.parser.arcs[o1] == -1:
      return False

    # gold move iff there is a k < o1 linked to b1 in the gold tree,
    # either as its head (k -> b1) ...
    b1 = self.parser.buffer[0]
    if self.gold[b1] < o1:
      return True
    # ... or as its dependent (b1 -> k)
    return b1 in self.gold[:o1]
    

  


## Test on a sentence

In [5]:
# define the sentence and the gold tree
# position 0 is the artificial <ROOT> token; gold[j] is the position of the head of word j
# (-1 for <ROOT>, which has no head)
sentence = ["<ROOT>", "He", "wrote", "her", "a", "letter", "."]
gold = [-1, 2, 0, 2, 5, 2, 2]

# initialize parser and oracle
# the oracle keeps a reference to the parser, so it always inspects the parser's current configuration
parser = ArcEager(sentence)
oracle = Oracle(parser, gold)

# print initial configuration
# (only <ROOT> is on the stack, because the ArcEager constructor performs one shift)
parser.print_configuration()

['<ROOT>'] ['He', 'wrote', 'her', 'a', 'letter', '.']
[-1, -1, -1, -1, -1, -1, -1]


In [6]:
# Drive the parser with the oracle until the tree is final, printing every step.
# The gold checks are tried in priority order LA, RA, RE; SH is the fallback move.
oracle_moves = (
  (oracle.is_left_arc_gold, parser.left_arc, 'Transition: LA'),
  (oracle.is_right_arc_gold, parser.right_arc, 'Transition: RA'),
  (oracle.is_reduce_gold, parser.reduce, 'Transition: RE'),
)

iteration = 0   # number of transitions applied so far
while not parser.is_tree_final():
  print('Iteration:',iteration)

  # pick the first transition the oracle accepts (later checks are not evaluated)
  for is_gold, apply_move, message in oracle_moves:
    if is_gold():
      break
  else:
    apply_move, message = parser.shift, 'Transition: SH'

  apply_move()
  parser.print_configuration()
  print(message, end='\n\n')

  iteration += 1

# parsing tree completed

Iteration: 0
['<ROOT>', 'He'] ['wrote', 'her', 'a', 'letter', '.']
[-1, -1, -1, -1, -1, -1, -1]
Transition: SH

Iteration: 1
['<ROOT>'] ['wrote', 'her', 'a', 'letter', '.']
[-1, 2, -1, -1, -1, -1, -1]
Transition: LA

Iteration: 2
['<ROOT>', 'wrote'] ['her', 'a', 'letter', '.']
[-1, 2, 0, -1, -1, -1, -1]
Transition: RA

Iteration: 3
['<ROOT>', 'wrote', 'her'] ['a', 'letter', '.']
[-1, 2, 0, 2, -1, -1, -1]
Transition: RA

Iteration: 4
['<ROOT>', 'wrote', 'her', 'a'] ['letter', '.']
[-1, 2, 0, 2, -1, -1, -1]
Transition: SH

Iteration: 5
['<ROOT>', 'wrote', 'her'] ['letter', '.']
[-1, 2, 0, 2, 5, -1, -1]
Transition: LA

Iteration: 6
['<ROOT>', 'wrote'] ['letter', '.']
[-1, 2, 0, 2, 5, -1, -1]
Transition: RE

Iteration: 7
['<ROOT>', 'wrote', 'letter'] ['.']
[-1, 2, 0, 2, 5, 2, -1]
Transition: RA

Iteration: 8
['<ROOT>', 'wrote'] ['.']
[-1, 2, 0, 2, 5, 2, -1]
Transition: RE

Iteration: 9
['<ROOT>', 'wrote', '.'] []
[-1, 2, 0, 2, 5, 2, 2]
Transition: RA



In [7]:
# check the arcs obtained are the same as in the gold tree
# (both are lists of head positions with -1 at position 0 for <ROOT>, so list equality is enough)
print('Is tree correct?', oracle.gold == parser.arcs)

Is tree correct? True


# Data set-up and training

## Dataset

## Preprocessing

### Utils for parsing

In [8]:
# function to check if the tree is projective or not
def is_projective(tree):
  """Return True if the dependency tree has no crossing arcs.

  tree[j] is the head of token j; entries equal to -1 (the root) define no arc.
  For every arc, tokens strictly inside its span must have their head within the
  span (endpoints included), and tokens outside the span must not have their
  head strictly inside it.
  """
  n_tokens = len(tree)
  for dependent, head in enumerate(tree):
    if head == -1:
      continue
    lo, hi = min(dependent, head), max(dependent, head)

    # tokens outside the span whose head falls strictly inside it
    outside = list(range(0, lo)) + list(range(hi + 1, n_tokens))
    if any(lo < tree[k] < hi for k in outside):
      return False

    # tokens strictly inside the span whose head falls outside it
    if any(tree[k] < lo or tree[k] > hi for k in range(lo + 1, hi)):
      return False

  return True

def parse_sentence(oracle, show_conf=False):
    """Run the parser attached to `oracle` until its tree is final.

    At every step the first transition accepted by the oracle is applied, trying
    left-arc, right-arc and reduce in this order; shift is the fallback.
    If `show_conf` is True, the configuration is printed before parsing and after
    every transition, together with the iteration number and the transition name.
    Returns the `arcs` list of the parser.
    """
    parser = oracle.parser
    oracle_moves = (
        (oracle.is_left_arc_gold, parser.left_arc, 'Transition: LA'),
        (oracle.is_right_arc_gold, parser.right_arc, 'Transition: RA'),
        (oracle.is_reduce_gold, parser.reduce, 'Transition: RE'),
    )

    if show_conf:
        parser.print_configuration()

    iteration = 0   # number of transitions applied so far
    while not parser.is_tree_final():
        if show_conf:
            print('Iteration:',iteration)

        # first accepted transition wins; later checks are not evaluated
        for is_gold, apply_move, message in oracle_moves:
            if is_gold():
                break
        else:
            apply_move, message = parser.shift, 'Transition: SH'

        apply_move()
        if show_conf:
            parser.print_configuration()
            print(message, end='\n\n')

        iteration += 1

    return parser.arcs

# parse every sentence of the dataset with the oracle; returns True only if each parse reproduces the gold tree
def parse_dataset(dataset):
    """Check that the oracle reconstructs the gold tree of every sample in `dataset`.

    Each sample provides its words under 'tokens' and its gold heads under 'head'
    (converted with int()). Returns False at the first sample whose oracle-driven
    parse differs from the gold tree, True if all samples match.
    """
    for sample in dataset:
        # prepend <ROOT> and its (missing) head so that positions line up
        gold_tree = [-1] + [int(head) for head in sample['head']]
        oracle = Oracle(ArcEager(['<ROOT>'] + sample['tokens']), gold_tree)
        if parse_sentence(oracle) != gold_tree:
            return False
    return True

### Test the oracle and the parsing functions on the training set

In [9]:
# load the English (en_lines) training split and keep only its projective sentences;
# heads in the gold tree are strings, so they are converted to int before the check
train_dataset = [
    sample
    for sample in load_dataset('universal_dependencies', 'en_lines', split="train")
    if is_projective([-1] + list(map(int, sample["head"])))
]

# sanity check: the oracle must rebuild the gold tree of every remaining sentence
print(parse_dataset(train_dataset))

Downloading builder script:   0%|          | 0.00/87.8k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/2.33M [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/191k [00:00<?, ?B/s]

Downloading and preparing dataset universal_dependencies/en_lines to /root/.cache/huggingface/datasets/universal_dependencies/en_lines/2.7.0/1ac001f0e8a0021f19388e810c94599f3ac13cc45d6b5b8c69f7847b2188bdf7...


Downloading data files:   0%|          | 0/3 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/580k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/199k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/181k [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/3 [00:00<?, ?it/s]

Generating train split:   0%|          | 0/3176 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/1032 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1035 [00:00<?, ? examples/s]

Dataset universal_dependencies downloaded and prepared to /root/.cache/huggingface/datasets/universal_dependencies/en_lines/2.7.0/1ac001f0e8a0021f19388e810c94599f3ac13cc45d6b5b8c69f7847b2188bdf7. Subsequent calls will reuse this data.
True


### Pre-processing utils
- the $create\_token\_indices$ function creates a vocabulary $vocab$ containing all tokens in $dataset$ (given as input argument) appearing at least $threshold$ (given as input argument, default value 3) times. $vocab[token]$ is the index assigned to $token$;

- the $process\_sample$ function is used to process our data and create the actual training samples.
For each sentence in $train\_dataset$, we use our oracle to compute the canonical path followed to extract the gold tree. We then pair each configuration to the golden transition selected by the oracle.
Because of the structure of Arc-Eager parser, we encode a $configuration$ with only two words: $\sigma_1$ and $\beta_1$ (i.e. the topmost element on the stack and the first buffer element, respectively);

- $prepare\_batch$ function pre-processes a batch of samples $batch\_data$ using as indices of tokens the ones contained in the vocabulary $tokens\_indices\_voc$. The pre-processing is done by applying function $process\_sample$ to each sample in $batch\_data$.

In [10]:
# the function returns a dictionary containing the vocabulary embedding indices for the tokens in $dataset
# i.e. an index is associated to each token in $dataset
# $threshold is the minimum number of appearance for a token in $dataset to be included in the dictionary
def create_token_indices(dataset, threshold=3):
  """Build the token -> index vocabulary for `dataset`.

  dataset:   iterable of samples, each with its words under the key 'tokens'.
  threshold: minimum number of occurrences a token needs to get its own index.

  Indices 0, 1 and 2 are reserved for "<pad>", "<ROOT>" and "<unk>"; the
  remaining tokens are numbered from 3 in order of first appearance.
  """
  # occurrences[token] is the number of times $token appears in $dataset
  occurrences = {}
  for sample in dataset:
    for token in sample['tokens']:
      occurrences[token] = occurrences.get(token, 0) + 1

  # special tokens first
  vocab = {"<pad>": 0, "<ROOT>": 1, "<unk>": 2}

  next_index = 3
  for token, count in occurrences.items():
    if count < threshold:
      continue
    vocab[token] = next_index
    next_index += 1

  return vocab

# creates training instances from one sample if $get_gold_path is $True, otherwise does only a simple pre-process
# $sample is a sample of our dataset
def process_sample(tokens_indices_voc, sample, get_gold_path = False):
  """Turn one dataset sample into index-encoded tokens and, optionally, oracle training instances.

  tokens_indices_voc: vocabulary mapping tokens to indices; must contain "<unk>".
  sample:             sample with its words under 'tokens' and its gold heads under 'head'.
  get_gold_path:      if True, also run the oracle to collect configurations and transitions.

  Returns (sentence_repr, gold_configurations, gold_transitions, gold):
    sentence_repr:       vocabulary index of every token of the sentence, <ROOT> included;
                         tokens missing from the vocabulary get the "<unk>" index
    gold_configurations: one [stack top, first buffer token] pair per parsing step
    gold_transitions:    oracle transition of each step: 0 left_arc, 1 right_arc, 2 reduce, 3 shift
    gold:                gold head of every token, with -1 for <ROOT>
  The two middle lists are empty when get_gold_path is False.
  """
  # add the root token to the sentence and its head (-1) to the gold head list
  sentence = ["<ROOT>"] + sample["tokens"]
  gold = [-1] + [int(i) for i in sample["head"]]

  # sentence representation with each token represented by its index in the vocabulary
  unk_index = tokens_indices_voc["<unk>"]
  sentence_repr = [tokens_indices_voc.get(token, unk_index) for token in sentence]

  # $gold_configurations and $gold_transitions are parallel arrays whose elements refer to parsing steps
  gold_configurations = []
  gold_transitions = []

  # only for training
  if get_gold_path:
    parser = ArcEager(sentence)
    oracle = Oracle(parser, gold)

    while not parser.is_tree_final():

      # save configuration: topmost stack token and first buffer token (-1 if the buffer is empty)
      first_in_buffer = parser.buffer[0] if len(parser.buffer) > 0 else -1
      gold_configurations.append([parser.stack[-1], first_in_buffer])

      # save gold transition and apply it, so that the next step sees the updated configuration
      if oracle.is_left_arc_gold():
        gold_transitions.append(0)
        parser.left_arc()
      elif oracle.is_right_arc_gold():
        gold_transitions.append(1)
        parser.right_arc()
      elif oracle.is_reduce_gold():
        gold_transitions.append(2)
        parser.reduce()
      else:
        gold_transitions.append(3)
        parser.shift()

  return sentence_repr, gold_configurations, gold_transitions, gold

# this function is used to pre-process a batch of samples $batch_data from the original dataset
# applying function $process_sample to each sample in $batch_data
def prepare_batch(tokens_indices_voc, batch_data, get_gold_path=False):
  """Pre-process a batch of samples by applying process_sample to each of them.

  tokens_indices_voc: vocabulary mapping tokens to indices.
  batch_data:         list of samples from the original dataset.
  get_gold_path:      forwarded to process_sample.

  Returns four parallel lists (sentences_repr, paths, moves, gold_trees):
  the element in position $i of each list refers to the same sentence $i.
  An empty batch yields four empty lists.
  """
  processed_batch_data = [process_sample(tokens_indices_voc, sample, get_gold_path=get_gold_path) for sample in batch_data]

  # four distinct lists (unpacking a single empty list would raise ValueError)
  sentences_repr, paths, moves, gold_trees = [], [], [], []
  for sentence_repr, configurations, transitions, gold in processed_batch_data:
    sentences_repr.append(sentence_repr)
    paths.append(configurations)
    moves.append(transitions)
    gold_trees.append(gold)

  return sentences_repr, paths, moves, gold_trees

### Datasets loading and pre-processing

In [11]:
# the training set has already been loaded, then load also development set and test set
dev_dataset = load_dataset('universal_dependencies', 'en_lines', split="validation")
test_dataset = load_dataset('universal_dependencies', 'en_lines', split="test")

# set the number of samples per batch
BATCH_SIZE = 32

# create the dictionary with token indices
tokens_indices = create_token_indices(train_dataset)

# create dataloaders to batch each dataset and apply function $prepare_batch to each batch
# the DataLoader calls collate_fn with the batch only, so the vocabulary must be bound
# as first argument of $prepare_batch for every split, not just for the training one
# gold paths are computed only for the training set
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=partial(prepare_batch, tokens_indices, get_gold_path=True))
dev_dataloader = torch.utils.data.DataLoader(dev_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=partial(prepare_batch, tokens_indices))
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=partial(prepare_batch, tokens_indices))



#### Now everything is ready to create a BiLSTM and extract embeddings from configurations in order to then train a classifier based on gold trees

### BERT model
In this section, we will use BERT to extract a feature vector for each of the two tokens $\sigma_1$ and $\beta_1$ in $configuration$.

In [12]:
!pip install transformers
!pip install datasets
!pip install evaluate
!pip install accelerate

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting transformers
  Downloading transformers-4.29.2-py3-none-any.whl (7.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.1/7.1 MB[0m [31m40.8 MB/s[0m eta [36m0:00:00[0m
Collecting tokenizers!=0.11.3,<0.14,>=0.11.1 (from transformers)
  Downloading tokenizers-0.13.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.8/7.8 MB[0m [31m122.8 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: tokenizers, transformers
Successfully installed tokenizers-0.13.3 transformers-4.29.2
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting evaluate
  Downloading evaluate-0.4.0-py3-none-any.whl (81 kB)
[2K     [90m━━━━━━━━━━

#### Pre-processing
We pre-process a batch of samples

In [13]:
# creates training instances from one sample
# $sample is a sample of our dataset
def process_sample_bert(sample, get_gold_path=False):
  """Turn one dataset sample into its tokens and, optionally, oracle training instances.

  sample:        sample with its words under 'tokens' and its gold heads under 'head'.
  get_gold_path: if True, also run the oracle to collect configurations and transitions.

  Returns (sample_tokens, gold_configurations, gold_transitions, gold):
    sample_tokens:       a copy of the tokens of the sample (without <ROOT>)
    gold_configurations: one [stack top, first buffer token] pair per parsing step,
                         expressed as positions in the sentence with <ROOT> at position 0
    gold_transitions:    oracle transition of each step: 0 left_arc, 1 right_arc, 2 reduce, 3 shift
    gold:                gold head of every token, with -1 for <ROOT>
  The two middle lists are empty when get_gold_path is False.
  """
  # the gold heads get -1 in front for <ROOT>
  gold = [-1] + [int(head) for head in sample["head"]]

  # save the tokens
  sample_tokens = list(sample["tokens"])

  # parallel lists, one entry per parsing step
  gold_configurations = []
  gold_transitions = []

  # only for training
  if get_gold_path:
    parser = ArcEager(["<ROOT>"] + sample["tokens"])
    oracle = Oracle(parser, gold)

    # (gold check, transition label, transition) in priority order; shift is the fallback
    oracle_moves = (
      (oracle.is_left_arc_gold, 0, parser.left_arc),
      (oracle.is_right_arc_gold, 1, parser.right_arc),
      (oracle.is_reduce_gold, 2, parser.reduce),
    )

    while not parser.is_tree_final():

      # save configuration (-1 stands for an empty buffer)
      stack_top = parser.stack[-1]
      first_in_buffer = parser.buffer[0] if parser.buffer else -1
      gold_configurations.append([stack_top, first_in_buffer])

      # save gold transition, then apply it
      for is_gold, label, apply_move in oracle_moves:
        if is_gold():
          break
      else:
        label, apply_move = 3, parser.shift
      gold_transitions.append(label)
      apply_move()

  return sample_tokens, gold_configurations, gold_transitions, gold

# this function is used to pre-process a batch of samples $batch_data from the original dataset
# applying function $process_sample to each sample in $batch_data
def prepare_batch_bert(batch_data, get_gold_path=False):
  """Pre-process a batch of samples by applying process_sample_bert to each of them.

  batch_data:    list of samples from the original dataset.
  get_gold_path: forwarded to process_sample_bert.

  Returns four parallel lists (samples_tokens, paths, moves, gold_trees):
  the element in position $i of each list refers to the same sentence $i.
  An empty batch yields four empty lists.
  """
  processed_batch_data = [process_sample_bert(sample, get_gold_path=get_gold_path) for sample in batch_data]

  # four distinct lists (unpacking a single empty list would raise ValueError)
  samples_tokens, paths, moves, gold_trees = [], [], [], []
  for sample_tokens, configurations, transitions, gold in processed_batch_data:
    samples_tokens.append(sample_tokens)
    paths.append(configurations)
    moves.append(transitions)
    gold_trees.append(gold)

  return samples_tokens, paths, moves, gold_trees

#### Data loaders
We load the sets preparing batches according to the function $prepare\_batch\_bert$.
In this way, we will be able to train our model in parallel on the samples belonging to the same batch.

In [14]:
# load train set, development set and test set
train_dataset = load_dataset('universal_dependencies', 'en_lines', split="train")
dev_dataset = load_dataset('universal_dependencies', 'en_lines', split="validation")
test_dataset = load_dataset('universal_dependencies', 'en_lines', split="test")

# remove non-projective sentences from the training set: the gold transitions are computed
# with the oracle, and the oracle was only checked to rebuild the gold tree on projective sentences
# (heads in the gold tree are strings, we convert them to int)
train_dataset = [sample for sample in train_dataset if is_projective([-1] + [int(head) for head in sample["head"]])]

# set the number of samples per batch
BATCH_SIZE = 32

# create dataloaders to batch each dataset and apply function $prepare_batch_bert to each batch
# gold paths are computed only for the training set
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=partial(prepare_batch_bert, get_gold_path=True))
dev_dataloader = torch.utils.data.DataLoader(dev_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=partial(prepare_batch_bert))
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=partial(prepare_batch_bert))



#### Model creation
Define the model to be used.

In [None]:
# Model design parameters
# size of the vector BERT produces for each token (presumably the hidden size of the
# checkpoint loaded by the model; the two must agree)
BERT_TOKEN_EMBEDDING_SIZE = 768
# number of transitions to choose among: 0 left-arc, 1 right-arc, 2 reduce, 3 shift
CLASSIFICATION_LABELS = 4
# number of tokens describing a configuration: top of the stack and first buffer token
WORD_PER_CONFIGURATION = 2
# Model hyperparameters
# output size of the first linear layer of the classification head
LINEAR_LAYER_SIZE = 300
# value passed to nn.Dropout
DROPOUT = 0.3


from transformers import BertModel
from transformers import BertTokenizer

# Custom Bert model
class CustomBertBasedModel(nn.Module):
    

    # CONSTRUCTORS


    # Constructor specifications
    def __init__(self, device):
        """Build the tokenizer, the BERT encoder and the feedforward classification head.

        device: stored as `self.device`; tensors created by the model are moved to it.
        """
        
        # inherit the constructor of $nn.Module
        super(CustomBertBasedModel, self).__init__()

        # name of the Bert model to be used
        bert_model_name = "bert-base-uncased"

        # tokenizer used for Bert
        self.bert_tokenizer = BertTokenizer.from_pretrained(bert_model_name)

        # Bert model
        self.bert = BertModel.from_pretrained(bert_model_name)

        # feedforward head for classification
        # its input is the concatenation of the $WORD_PER_CONFIGURATION token embeddings of a configuration,
        # its output has one score per transition
        # NOTE(review): the attribute name `droput` is misspelled ("dropout"); it is left as is
        # because code outside this constructor may refer to it - confirm before renaming
        self.droput = nn.Dropout(DROPOUT)
        self.linear1 = nn.Linear(BERT_TOKEN_EMBEDDING_SIZE * WORD_PER_CONFIGURATION, LINEAR_LAYER_SIZE, bias=True)
        self.activation1 = nn.Tanh()
        self.linear2 = nn.Linear(LINEAR_LAYER_SIZE, CLASSIFICATION_LABELS, bias=True)
        self.softmax = nn.Softmax(dim=-1) # the dimension where the softmax has to be applied is the last one

        # device
        self.device = device


    # UTILITY FUNCTIONS: ARCEAGER PARSE


    # Check if all parsers $parsers have finished parsing their sentences.
    # Returns:
    #   ~ True if all parsers have finished;
    #   ~ False otherwise.
    def parsed_all(self, parsers):
        """Return True if every parser in $parsers has reached a final tree (also when $parsers is empty)."""
        return all(parser.is_tree_final() for parser in parsers)


    def get_configurations(self, parsers):
        """Return the current configuration of every parser in $parsers.

        Element $i of the result is a two-item list for parser $parsers[i]:
          ~ [-1, -1] if the parser has already reached a final tree;
          ~ [top of the stack, first element of the buffer] otherwise, with -1 in
            place of the buffer element when the buffer is empty.
        """
        def current_configuration(parser):
            if parser.is_tree_final():
                # empty configuration
                return [-1, -1]
            stack_top = parser.stack[-1]
            first_in_buffer = parser.buffer[0] if len(parser.buffer) > 0 else -1
            return [stack_top, first_in_buffer]

        return [current_configuration(parser) for parser in parsers]


    # Function that selects and performs the next ArcEager transition according to the scores in $fnn_output.
    # $parsers is a list of ArcEager parsers: $parsers[i] is parsing sentence $i in the batch.
    # $fnn_output is the output tensor of the fnn predicting scores for the $CLASSIFICATION_LABELS possible transitions.
    # $fnn_output has size (# sentences in batch, $CLASSIFICATION_LABELS).
    # The parser does not check the preconditions of its transitions, so they are checked here:
    # for each parser, the transition with the highest score among the ones that are
    # allowed in its current configuration is applied.
    def parse_step(self, parsers, fnn_output):
        """Apply one transition, in place, to every parser that has not finished yet.

        Transition labels follow the encoding of the gold transitions:
        0 left-arc, 1 right-arc, 2 reduce, 3 shift. A transition is allowed when:
          ~ left-arc:  the stack top is not <ROOT> (position 0) and has no head yet;
          ~ right-arc: the stack is not empty;
          ~ reduce:    the stack top already has a head;
          ~ shift:     the buffer is not empty.
        Ties between scores are resolved in favour of the lowest label.
        Parsers that have already reached a final tree are left untouched.
        Returns None.
        """
        LEFT_ARC, RIGHT_ARC, REDUCE, SHIFT = 0, 1, 2, 3

        # convert all the scores to plain Python numbers once, instead of reading the tensor element by element
        batch_scores = fnn_output.tolist()

        for i in range(len(parsers)):
            parser = parsers[i]

            if parser.is_tree_final():
                # do nothing if the current parser has already parsed the whole sentence
                continue

            has_stack = len(parser.stack) > 0
            stack_top = parser.stack[-1] if has_stack else None
            top_has_head = has_stack and parser.arcs[stack_top] != -1

            allowed = {
                LEFT_ARC: has_stack and stack_top != 0 and not top_has_head,
                RIGHT_ARC: has_stack,
                REDUCE: top_has_head,
                SHIFT: len(parser.buffer) > 0,
            }
            transitions = {
                LEFT_ARC: parser.left_arc,
                RIGHT_ARC: parser.right_arc,
                REDUCE: parser.reduce,
                SHIFT: parser.shift,
            }

            # labels from the highest to the lowest score ($sorted is stable, so ties keep label order)
            scores = batch_scores[i]
            ranked_labels = sorted(range(len(scores)), key=lambda label: -scores[label])

            # apply the best-scoring transition whose preconditions hold
            for label in ranked_labels:
                if allowed.get(label, False):
                    transitions[label]()
                    break


    # UTILITY FUNCTIONS: MODEL


    # Given a list of tokens $list_tokens representing words in a sentence, the function reconstructs the
    # string representing the sentence.
    def sentence_reconstruction(list_tokens):
        sentence = ""
        for token in list_tokens:
            sentence = sentence + " " + token
        return sentence


    # The function averages tokens hidden representations $hidden_vectors in order to
    # obtain a representation for each token in $desired_tokens starting from tokens in $current_tokens.
    # More precisely, it averages representations of tokens in $current_tokens that are subtokens of the
    # same token in $desired_tokens.
    # $hidden_vectors contains representations for tokens in $current_tokens. It is a torch tensor of size
    # (len($current_tokens), $BERT_TOKEN_EMBEDDING_SIZE).
    # $current_tokens contains the tokens for which we have a corresponding representation in $hidden_vectors.
    # $desired_tokens contains the tokens for which we want a representation.
    # returns the tensor $torch.stack(new_hidden_vectors, dim=0) of size (len($desired_tokens), $BERT_TOKEN_EMBEDDING_SIZE), where
    # torch.stack(new_hidden_vectors, dim=0)[i] is the hidden vector representing $desired_tokens[i] as the mean
    # of the hidden vectors in $hidden_vectors representing tokens in $current_tokens that are subtokens of $desired_tokens[i].
    def average_layer(hidden_vectors, current_tokens, desired_tokens):
        
        # list of hidden vectors representing $desired_tokens. Actually, it is a list of torch tensors.
        new_hidden_vectors = []

        # index of the current token in $current_tokens
        current_token_index = -1
        # build one hidden vector per token in $desired_tokens
        for word in desired_tokens:
            current_token_index = current_token_index + 1
            # current subword of $word
            current_word = current_tokens[current_token_index]
            # list of hidden vectors in $hidden_vectors representing tokens that are subword tokens of $word
            sub_word_vectors = [hidden_vectors[current_token_index]]
            # append hidden vectors representing subword tokens of $word until $word and $current_word are equal
            while not current_word.lower() == word.lower():
                # we add to $current_word the next subword token of $word. With [2:], we remove "##".
                current_token_index = current_token_index + 1
                current_word = current_word + current_tokens[current_token_index][2:]
                sub_word_vectors.append(hidden_vectors[current_token_index])
            # build a tensor of size (len($sub_word_vectors), $BERT_TOKEN_EMBEDDING_SIZE) from the list $sub_word_vectors by stacking subword
            # vectors along the first dimension
            stacked_tensor = torch.stack(sub_word_vectors, dim=0)
            # append to the list $new_hidden_vectors the hidden vector representing $word, which is computed as the mean
            # of the hidden vectors in $hidden_vectors representing its subword tokens
            new_hidden_vectors.append(torch.mean(stacked_tensor, dim=0))

        # we use $torch.stack to build a tensor of size (len($desired_tokens), $BERT_TOKEN_EMBEDDING_SIZE) from the
        # list of tensors $new_hidden_vectors
        return torch.stack(new_hidden_vectors, dim=0)
    

    # The function, given the outputs of Bert applied to a batch of sentences, returns the inputs for the FNN.
    # $bert_outputs_list contains Bert representations of sentences in a batch.
    # $bert_outputs_list[i] is the Bert representation for sentence $i in the batch.
    # $sentences_words[i] is the list of words in sentence $i in the batch.
    # $sentences_bert_tokens[i] is the list of tokens provided by Bert tokenizer applied to sentence $i in the batch.
    # $configurations[i] contains the configurations of sentence $i in the batch.
    # returns $fnn_inputs: it is a tensor of features representing configurations. Its size is (# configurations, 2 * $BERT_TOKEN_EMBEDDING_SIZE).
    def prepare_inputs_fnn(self, bert_outputs_list, sentences_words, sentences_bert_tokens, configurations):
        """Build the FNN input features for all configurations of all sentences in the batch.

        A configuration is represented by the concatenation of two word embeddings: the one of the
        token on top of the stack ($configuration[0]) and the one of the first token in the buffer
        ($configuration[1]). An index equal to -1 means that the stack (resp. the buffer) is empty,
        in which case a zero vector is used in place of the embedding.

        Args:
            bert_outputs_list: $bert_outputs_list[i] is the Bert output for sentence $i; only its
                $last_hidden_state attribute is read.
            sentences_words: $sentences_words[i] is the list of words of sentence $i.
            sentences_bert_tokens: $sentences_bert_tokens[i] is the list of tokens produced by the
                Bert tokenizer for sentence $i.
            configurations: $configurations[i] is the list of configurations of sentence $i.

        Returns:
            A tensor with one row per configuration and 2 * $BERT_TOKEN_EMBEDDING_SIZE columns,
            placed on $self.device.
        """
        # $fnn_inputs[k] will contain the FNN input for the k-th configuration in the batch
        fnn_inputs = []

        # representation of an empty stack / empty buffer; it is the same for every configuration,
        # so it is built only once (torch.cat copies it into each row, so sharing it is safe)
        empty_slot = torch.zeros(BERT_TOKEN_EMBEDDING_SIZE, requires_grad=False).to(self.device)

        # iterate over the samples in the batch, i.e. the sentences processed by Bert
        for i in range(len(configurations)):

            # embedding representation of the current sentence provided by Bert; the batch dimension
            # is dropped and the special token [CLS] at the beginning of the sentence is removed
            hidden_repr_sentence = bert_outputs_list[i].last_hidden_state[0, 1:, :]
            # average embeddings of tokens belonging to the same word. The averaging works on a single
            # sentence, so it receives the tokens and the words of sentence $i only (not the whole batch)
            words_bert_embeddings = self.average_layer(
                hidden_repr_sentence, sentences_bert_tokens[i], sentences_words[i]
            )

            # append the representation of each configuration of the current sentence to $fnn_inputs
            for configuration in configurations[i]:

                # token on top of the stack: zero vector if the stack is empty ($configuration[0] == -1)
                tensor_sigma = empty_slot
                if configuration[0] != -1:
                    tensor_sigma = words_bert_embeddings[configuration[0]].to(self.device)

                # first token in the buffer: zero vector if the buffer is empty ($configuration[1] == -1)
                tensor_beta = empty_slot
                if configuration[1] != -1:
                    tensor_beta = words_bert_embeddings[configuration[1]].to(self.device)

                # torch.cat expects a sequence of tensors as its first argument
                fnn_inputs.append(torch.cat((tensor_sigma, tensor_beta)))

        # torch.stack rejects an empty list: with no configuration at all, return an empty feature matrix
        if not fnn_inputs:
            return torch.zeros((0, 2 * BERT_TOKEN_EMBEDDING_SIZE), requires_grad=False).to(self.device)

        # the inputs for the fnn are the representations of all configurations of all sentences in
        # the batch, stacked into a single tensor
        return torch.stack(fnn_inputs).to(self.device)


    # For each list of tokens in $sentences_tokens, the function reconstructs the string representing the
    # original sentence. Moreover, it applies the Bert tokenizer to each sentence and prepares each
    # sentence in the batch for being processed as input of the Bert model.
    # Returns:
    #   ~ $bert_inputs: inputs for Bert model;
    #   ~ $sentences_tokenizers: list containing the list of tokens returned by Bert tokenizer for each sentence.
    def bert_tokenization_and_input(self, sentences_tokens):
        """Turn each list of word tokens into a Bert input and into its list of Bert tokens.

        For every sentence of the batch, the sentence string is rebuilt with
        $self.sentence_reconstruction; the string is then passed both to $self.bert_tokenizer
        (to get the model input) and to $self.bert_tokenizer.tokenize (to get the tokens).

        Returns a pair of lists, both with one entry per sentence and in the batch order:
        the Bert inputs and the Bert tokens.
        """
        encoded_batch = []
        token_lists = []

        for words in sentences_tokens:
            # sentence string rebuilt from its word tokens
            text = self.sentence_reconstruction(words)

            # input representation of $text for the Bert model
            encoding = self.bert_tokenizer(
                text,
                truncation=True,
                add_special_tokens=True,
                return_tensors="pt",
            )
            encoded_batch.append(encoding)

            # tokens in which the Bert tokenizer splits $text
            token_lists.append(self.bert_tokenizer.tokenize(text))

        return encoded_batch, token_lists


    # IMPORTANT INSTANCE FUNCTIONS


    # Function that applies the feedforward neural network to the input tensor $input.
    # It returns the output tensor of the fnn.
    def fnn(self, input):
        """Run the two-layer feedforward network on $input and return its output.

        Each layer applies $self.dropout to its input before the linear map; the first layer
        is followed by $self.activation1 and the second one by $self.softmax.
        """
        # first layer: dropout -> linear -> activation
        hidden = self.linear1(self.dropout(input))
        hidden = self.activation1(hidden)

        # second layer: dropout -> linear -> softmax
        scores = self.linear2(self.dropout(hidden))
        return self.softmax(scores)


    # Forward method that we use at training time. It does the forward pass for a batch of samples at the time.
    # $sentences_tokens[i] is the list of tokens of sentence $i in the batch.
    # $configurations contains the configurations of all samples in the batch.
    # $configurations[i] is the list of configurations related to $sentences_tokens[i].
    # $configurations[i][j] is the configuration at step $j when parsing sentence $sentences_tokens[i].
    def forward(self, sentences_tokens, configurations):
        """Forward pass used at training time on a batch of sentences.

        Args:
            sentences_tokens: $sentences_tokens[i] is the list of tokens of sentence $i in the batch.
            configurations: $configurations[i] is the list of configurations of sentence $i.

        Returns:
            The output of $self.fnn applied to the features of all the configurations of all the
            sentences in the batch.
        """
        # $bert_inputs contains the inputs for the Bert model, one per sentence in the batch;
        # $bert_tokenized_sentences contains the Bert tokens of each sentence
        bert_inputs, bert_tokenized_sentences = self.bert_tokenization_and_input(sentences_tokens)

        # Bert outputs: one per sample in the batch. The encoded input (and not the model output) is
        # moved to $self.device before the call, so that the input tensors are on the device used by
        # the rest of the model -- presumably the one where $self.bert lives
        bert_outputs = [self.bert(**bert_input.to(self.device)) for bert_input in bert_inputs]

        # prepare the inputs for the feedforward neural network
        fnn_inputs = self.prepare_inputs_fnn(bert_outputs, sentences_tokens, bert_tokenized_sentences, configurations)

        # run the feedforward neural network with all parsing configurations of all sentences in the
        # batch as inputs and return its output
        return self.fnn(fnn_inputs)


    # Forward method that we use at inference time.
    # It processes a batch of samples at the time.
    # It infers a list of parsing trees: one per sample in the batch.
    def infere(self, sentences_tokens):
        """Forward pass used at inference time on a batch of sentences.

        Since no gold configurations are available, each sentence is parsed step by step: at every
        step the current configurations of the parsers are scored by the FNN and $self.parse_step
        advances the parsers according to those scores, until $self.parsed_all reports that all the
        sentences have been parsed.

        Args:
            sentences_tokens: $sentences_tokens[i] is the list of tokens of sentence $i in the batch.

        Returns:
            A list with the $arcs of each parser: one entry per sentence in the batch.
        """
        # construct a parser for each sentence in the batch
        parsers = [ArcEager(sentence_tokens) for sentence_tokens in sentences_tokens]

        # Bert is run only once per sentence, before the parsing loop. The encoded input (and not the
        # model output) is moved to $self.device before the call, so that the input tensors are on the
        # device used by the rest of the model -- presumably the one where $self.bert lives
        bert_inputs, bert_tokenized_sentences = self.bert_tokenization_and_input(sentences_tokens)
        bert_outputs = [self.bert(**bert_input.to(self.device)) for bert_input in bert_inputs]

        # compute and score the configurations step by step according to the model's predictions
        while not self.parsed_all(parsers):

            # get the current configurations of the parsers
            configurations = self.get_configurations(parsers)
            fnn_inputs = self.prepare_inputs_fnn(bert_outputs, sentences_tokens, bert_tokenized_sentences, configurations)
            fnn_outputs = self.fnn(fnn_inputs)

            # given the current configurations, select the transitions based on the scores given by the fnn
            self.parse_step(parsers, fnn_outputs)

        # return the predicted dependency trees: one per sentence in the batch
        return [parser.arcs for parser in parsers]

# Evaluation

# SotA discussion