In [1]:
%%capture
#!pip install transformers
#!pip install datasets
#!pip install conllu

In [2]:
import torch
import itertools
import numpy as np
import torch.nn as nn
from functools import partial
import matplotlib.pyplot as plt
from datasets import load_dataset
from collections import defaultdict
from transformers import BertTokenizer, BertModel, AutoModel, AutoTokenizer

# Bert model and other hyperparameters

In [3]:
# Hyperparameters
# Name of the pretrained checkpoint loaded by both the tokenizer and BertModel below
BERT_MODEL = 'dbmdz/bert-base-italian-xxl-cased'

BATCH_SIZE = 64           # sentences per DataLoader batch
LEARNING_RATE = 0.0001
DROPOUT_RATE = 0.3        # dropout applied inside the feed-forward classifier
MLP_SIZE_1 = 500          # width of the first hidden layer of the classifier
MLP_SIZE_2 = 150          # width of the second hidden layer of the classifier
INPUT_SIZE = 3*768        # a configuration is three concatenated 768-dim word embeddings (two stack tokens + first buffer token)


EPOCHS = 5
SAVE_PATH = "saved_checkpoint.pt"
TOKENIZER_MAX_LEN = 75    # every sentence is padded/truncated to this many subtokens by the tokenizer
# Use the first GPU when one is available, otherwise fall back to the CPU
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Arc-standard

Recall that a **configuration** of the arc-standard parser is a triple of the form $( \sigma, \beta, A)$
where:

* $\sigma$ is the stack;
* $\beta$ is the input buffer;
* $A$ is a set of arcs constructed so far.

We write $\sigma_i$, $i \geq 1$, for the $i$-th token in the stack; we also write $\beta_i$, $i \geq 1$, for the $i$-th token in the buffer. 

The parser can perform three types of **actions** (transitions):

* **shift**, which removes $\beta_1$ from the buffer and pushes it into the stack;
* **left-arc**, which creates the arc $(\sigma_1 \rightarrow \sigma_2)$, and removes $\sigma_2$ from the stack;
* **right-arc**, which creates the arc $(\sigma_2 \rightarrow \sigma_1)$, and removes $\sigma_1$ from the stack.

Let $w = w_0 w_1 \cdots w_{n}$ be the input sentence, with $w_0$ the special symbol `<ROOT>`.
Stack and buffer are implemented as lists of integers, where `j` represents word $w_j$.  Top-most stack token is at the right-end of the list; first buffer token is at the left-end of the list. 
Set $A$ is implemented as an array `arcs` of size $n+1$ such that if arc $(w_i \rightarrow w_j)$ is in $A$ then `arcs[j]=i`, and if $w_j$ is still missing its head node in the tree under construction, then `arcs[j]=-1`. We always have `arcs[0]=-1`.  We use this representation also for complete dependency trees.


In [4]:
import orig # file orig.py in the same folder contains the original class and function definitions

In [5]:
class ArcStandard (orig.ArcStandard):
    """Arc-standard parser that also stores a dependency label on every arc.

    ``arcs[j]`` is a ``(head, label_code)`` pair; ``(-1, -1)`` means that word
    ``j`` has not received a head yet. ``shift`` and ``is_tree_final`` are
    inherited from ``orig.ArcStandard``.
    """

    def __init__(self, sentence, label_set):
        """Set up buffer, stack and arcs for ``sentence`` and pre-load the stack.

        ``label_set`` maps label names to integer codes; it is stored inverted
        (code -> name) because it is only used to pretty-print the arcs.
        """
        n_words = len(sentence)
        self.sentence = sentence
        self.buffer = list(range(n_words))
        self.stack = []
        # one (head, relation-code) pair per word
        self.arcs = [(-1, -1)] * n_words

        self.label_set = {code: name for name, code in label_set.items()}

        # initial shifts: two unconditionally, a third one only when the
        # sentence has more than two tokens
        for _ in range(2):
            self.shift()
        if n_words > 2:
            self.shift()

    def left_arc(self, deprel):
        """Attach the second stack item to the top one with label ``deprel``."""
        head = self.stack.pop()
        dependent = self.stack.pop()
        self.arcs[dependent] = (head, deprel)
        self.stack.append(head)
        # keep two items on the stack whenever the buffer can still provide one
        if self.buffer and len(self.stack) < 2:
            self.shift()

    def right_arc(self, deprel):
        """Attach the top stack item to the second one with label ``deprel``."""
        dependent = self.stack.pop()
        head = self.stack.pop()
        self.arcs[dependent] = (head, deprel)
        self.stack.append(head)
        # keep two items on the stack whenever the buffer can still provide one
        if self.buffer and len(self.stack) < 2:
            self.shift()

    def print_configuration(self):
        """Print stack and buffer as words, then the arcs with label names."""
        stack_words = [self.sentence[position] for position in self.stack]
        buffer_words = [self.sentence[position] for position in self.buffer]
        print("STACK: ", stack_words, "BUFFER: ", buffer_words)
        print([(head, self.label_set[label]) for head, label in self.arcs])

In [6]:
class Oracle (orig.Oracle):
    """Static oracle extended with access to the gold dependency labels.

    The ``is_*_gold`` predicates are inherited from ``orig.Oracle``.
    """

    def __init__(self, parser, gold_tree, gold_labels):
        """Store the parser to guide, the gold heads and the gold label codes (integers)."""
        self.parser = parser
        self.gold = gold_tree
        self.labels = gold_labels

    def get_gold_label(self):
        """Return the gold label code of the arc that the gold move would create.

        Raises:
            Exception: when neither left-arc nor right-arc is the gold move
                (a shift creates no arc, hence no label).
        """
        if self.is_left_arc_gold():
            # left-arc attaches the second stack item, so its label is needed
            stack = self.parser.stack
            return self.labels[stack[len(stack) - 2]]
        if self.is_right_arc_gold():
            # right-arc attaches the top stack item
            stack = self.parser.stack
            return self.labels[stack[len(stack) - 1]]
        raise Exception("Action SHIFT does not produce a labeled dependency")
            

## Functions to simplify the processing

In [7]:
"""
dataset should be a Dataset object from HuggingFace
Return a dictionary where arc labels are associate to integer values.
Label '_' is removed (used for composite tokens)
Labels are sorted in alphabetical order because MSE in the neural network will penalize less classes that are "close".
<ROOT> label is given value -1 (label of the <ROOT> node).
"""
def create_deprel_dict(dataset):
    """Map every dependency label found in ``dataset['deprel']`` to an integer.

    Labels are numbered in alphabetical order starting from 0; the placeholder
    label '_' is skipped and the extra key "None" is mapped to -1.
    """
    unique_labels = {
        label
        for sentence_labels in dataset['deprel']
        for label in sentence_labels
        if label != '_'
    }
    mapping = {label: code for code, label in enumerate(sorted(unique_labels))}
    mapping["None"] = -1
    return mapping

"""
Return (sanitized_tokens, gold tree, gold labels) for a sentence in the dataset, assuming gold tree is in a column labeled 'head' and labels in a column labeled 'deprel'.
Both token indices and labels are converted to integer (in the latter case, according to deprel_dict).
Sanitized tokens is a tokenlist with removal of composite tokens.
"""
def create_gold(sentence, deprel_dict):
    """Build (tokens, gold heads, gold label codes) for one dataset sentence.

    All three lists start with the entry for the artificial <ROOT> node
    ('<ROOT>', -1, -1). A token is kept only when its head is not the string
    'None' and its label is a key of ``deprel_dict``; heads are converted to
    int and labels to their integer code.
    """
    sanitized = ['<ROOT>']
    gold_tree = [-1]
    gold_labels = [-1]
    for position, token in enumerate(sentence['tokens']):
        # skip tokens without a head and tokens with an unknown label
        if sentence['head'][position] == 'None' or sentence['deprel'][position] not in deprel_dict:
            continue
        sanitized.append(token)
        gold_tree.append(int(sentence['head'][position]))
        gold_labels.append(deprel_dict[sentence['deprel'][position]])

    return sanitized, gold_tree, gold_labels

# Turn each list of words into one string for the BERT tokenizer
def to_list_sentences(sentences):
    """Join the words of every sentence, appending one space after each word."""
    return ["".join(word + " " for word in words) for words in sentences]

## Testing parser and oracle...

In [8]:
%%capture
# Load the training split of the 'it_isdt' configuration of the 'universal_dependencies' dataset
train_dataset = load_dataset('universal_dependencies', 'it_isdt', split="train")

In [9]:
# Build the label -> integer mapping from the training set and pick one sentence as an example
deprels = create_deprel_dict(train_dataset)
sentence = train_dataset[3]

# tokens start with '<ROOT>'; gold_tree / gold_labels hold one head / label code per token
tokens, gold_tree, gold_labels = create_gold(sentence, deprels)

print(tokens)
print(gold_tree)
print(gold_labels)

['<ROOT>', 'Inconsueto', 'allarme', 'a', 'la', 'Tate', 'Gallery', ':']
[-1, 2, 0, 5, 5, 2, 5, 2]
[-1, 4, 41, 8, 17, 31, 28, 40]


In [10]:
# Create a parser for the example sentence and an oracle that knows its gold tree and labels
parser = ArcStandard(tokens, deprels)
oracle = Oracle(parser, gold_tree, gold_labels)

# Initial configuration: the constructor has already shifted the first tokens onto the stack
parser.print_configuration()

STACK:  ['<ROOT>', 'Inconsueto', 'allarme'] BUFFER:  ['a', 'la', 'Tate', 'Gallery', ':']
[(-1, 'None'), (-1, 'None'), (-1, 'None'), (-1, 'None'), (-1, 'None'), (-1, 'None'), (-1, 'None'), (-1, 'None')]


In [11]:
# Parse the example sentence following the oracle, printing the configuration after every move.
# In this cell shift is tested first, then left-arc, then right-arc.
while not parser.is_tree_final():  # transition precedence implemented here
    if oracle.is_shift_gold():  
        parser.shift()
    elif oracle.is_left_arc_gold():
        parser.left_arc(oracle.get_gold_label())
    elif oracle.is_right_arc_gold():
        parser.right_arc(oracle.get_gold_label())
    
    parser.print_configuration()

STACK:  ['<ROOT>', 'allarme'] BUFFER:  ['a', 'la', 'Tate', 'Gallery', ':']
[(-1, 'None'), (2, 'amod'), (-1, 'None'), (-1, 'None'), (-1, 'None'), (-1, 'None'), (-1, 'None'), (-1, 'None')]
STACK:  ['<ROOT>', 'allarme', 'a'] BUFFER:  ['la', 'Tate', 'Gallery', ':']
[(-1, 'None'), (2, 'amod'), (-1, 'None'), (-1, 'None'), (-1, 'None'), (-1, 'None'), (-1, 'None'), (-1, 'None')]
STACK:  ['<ROOT>', 'allarme', 'a', 'la'] BUFFER:  ['Tate', 'Gallery', ':']
[(-1, 'None'), (2, 'amod'), (-1, 'None'), (-1, 'None'), (-1, 'None'), (-1, 'None'), (-1, 'None'), (-1, 'None')]
STACK:  ['<ROOT>', 'allarme', 'a', 'la', 'Tate'] BUFFER:  ['Gallery', ':']
[(-1, 'None'), (2, 'amod'), (-1, 'None'), (-1, 'None'), (-1, 'None'), (-1, 'None'), (-1, 'None'), (-1, 'None')]
STACK:  ['<ROOT>', 'allarme', 'a', 'Tate'] BUFFER:  ['Gallery', ':']
[(-1, 'None'), (2, 'amod'), (-1, 'None'), (-1, 'None'), (5, 'det'), (-1, 'None'), (-1, 'None'), (-1, 'None')]
STACK:  ['<ROOT>', 'allarme', 'Tate'] BUFFER:  ['Gallery', ':']
[(-1, 'No

## Testing the whole dataset

In [12]:
# remember to call orig.is_projective() to check for projectivity!!
# Example: build the gold tree of sentence 10 and test it; the recorded output of this cell is False
sentence = train_dataset[10]
deprels = create_deprel_dict(train_dataset)

tokens, gold_tree, gold_labels = create_gold(sentence, deprels)
orig.is_projective(gold_tree)

False

In [13]:
### NOTE: THIS CELL WAS COPIED WITH FEW MODIFICATIONS FROM THE ORIGINAL FILE

# oracle test: run the parser guided by the oracle on the entire training set 
# Sentences whose gold tree is not projective are counted and skipped; for all the
# others every (head, label) pair produced by the parser is compared with the gold one.

non_projective = 0
correct = 0
wrong = 0

deprels = create_deprel_dict(train_dataset)

for sample in train_dataset:
    tokens, gold_tree, gold_labels = create_gold(sample, deprels)

    if not orig.is_projective(gold_tree):
        non_projective += 1
        continue

    parser = ArcStandard(tokens, deprels)
    oracle = Oracle(parser, gold_tree, gold_labels)

    # here the precedence is left-arc, right-arc, shift
    while not parser.is_tree_final():
        if oracle.is_left_arc_gold(): 
            parser.left_arc(oracle.get_gold_label())
        elif oracle.is_right_arc_gold():
            parser.right_arc(oracle.get_gold_label())
        elif oracle.is_shift_gold(): 
            parser.shift()

    for j in range(len(gold_tree)):  # comparing heads from parser and gold for actual sample
        # a token counts as correct only when both head and label match (the <ROOT> entry is included)
        if gold_tree[j] == parser.arcs[j][0] and gold_labels[j] == parser.arcs[j][1]: 
            correct += 1
        else:
            wrong += 1

print("non projective: ", non_projective)
print("correct: ", correct)
print("wrong: ", wrong)

non projective:  167
correct:  282533
wrong:  0


## Creating samples for the neural oracle

In [14]:
# Modified from the original
def process_sample(sample, deprels, get_gold_path = False):
    """Turn one dataset sentence into the material needed by the neural oracle.

    Args:
        sample: one element of the dataset (must provide what ``create_gold`` reads).
        deprels: dictionary mapping dependency labels to integer codes.
        get_gold_path: when True, the sentence is parsed under the guidance of the
            static oracle and the visited configurations / moves are recorded.

    Returns:
        A tuple ``(tokens, gold_path, gold_moves, gold_tree, gold_labels)`` where
        ``tokens``, ``gold_tree`` and ``gold_labels`` come from ``create_gold``;
        ``gold_path`` is a list of configurations ``[second stack item, top stack
        item, first buffer item or -1]``; ``gold_moves`` is the parallel list of
        move ids: 0 for shift, ``1 + label_code`` for left-arc and
        ``1 + len(deprels) + label_code`` for right-arc. Both lists are empty
        when ``get_gold_path`` is False.

    Raises:
        Exception: if the oracle marks no move as gold for some configuration.
    """
    tokens, gold_tree, gold_labels = create_gold(sample, deprels)

    # gold_path and gold_moves are parallel arrays whose elements refer to parsing steps
    gold_path = []   # two topmost stack tokens and first buffer token for each step
    gold_moves = []  # encoded oracle move for each step (see docstring)

    if get_gold_path:  # only for training
        parser = ArcStandard(tokens, deprels)
        oracle = Oracle(parser, gold_tree, gold_labels)
        right_arc_offset = 1 + len(deprels)  # first move id reserved for right-arc

        while not parser.is_tree_final():
            configuration = [parser.stack[len(parser.stack)-2], parser.stack[len(parser.stack)-1]]
            # -1 stands for "buffer is empty"
            configuration.append(parser.buffer[0] if len(parser.buffer) > 0 else -1)
            gold_path.append(configuration)

            if oracle.is_left_arc_gold():
                label_code = oracle.get_gold_label()
                parser.left_arc(label_code)
                gold_moves.append(1 + label_code)
            elif oracle.is_right_arc_gold():
                label_code = oracle.get_gold_label()
                # The parser stores the plain label code, exactly as for left-arc; only the
                # recorded move id carries the right-arc offset. Storing the encoded id in
                # the arcs made print_configuration() fail on the label lookup.
                parser.right_arc(label_code)
                gold_moves.append(right_arc_offset + label_code)
            elif oracle.is_shift_gold():
                parser.shift()
                gold_moves.append(0)
            else:
                print("**** AN ERROR OCCURRED ****")
                parser.print_configuration()
                raise Exception("No action identified as gold!! Please make sure your dataset doesn't contain nonprojective trees.")

    return tokens, gold_path, gold_moves, gold_tree, gold_labels


In [15]:
### Code to test process_sample
# NOTE: the snippet below is disabled (it is only a string literal) and stale:
# process_sample takes (sample, deprels, get_gold_path) and has no emb_dictionary argument.
"""
test_sentence = train_dataset[88]
emb_dictionary = create_dict(train_dataset)
deprels = create_deprel_dict(train_dataset)

process_sample(test_sentence, emb_dictionary, deprels, get_gold_path = True)
"""

'\ntest_sentence = train_dataset[88]\nemb_dictionary = create_dict(train_dataset)\ndeprels = create_deprel_dict(train_dataset)\n\nprocess_sample(test_sentence, emb_dictionary, deprels, get_gold_path = True)\n'

In [16]:
"""
Function that prepares the data for the model.
@return sentences : list of tokenized sentences (each one a list of words starting with <ROOT>), as returned by process_sample
@return paths : list of lists of configurations visited during the oracle-guided parsing (on training data) | empty list if get_gold_path is False
@return moves : list of lists of moves as a number performed during the oracle-guided parsing (on training data) | empty if get_gold_path is False
                                N.B. 0 = shift
                                        1-45 = leftarc + label
                                        46-90 = rightarc + label
@return trees : ground truth tree
@return labels : ground truth lables
"""
# modified from orig
def prepare_batch(batch_data, deprels, get_gold_path=False):
    """Collate function: process every sample and regroup the results by field.

    Returns five parallel lists (one entry per sentence of the batch):
    sentences, paths, moves, trees, labels.
    """
    sentences, paths, moves, trees, labels = [], [], [], [], []
    fields = (sentences, paths, moves, trees, labels)
    for sample in batch_data:
        processed = process_sample(sample, deprels, get_gold_path=get_gold_path)
        for position, field in enumerate(fields):
            field.append(processed[position])
    return sentences, paths, moves, trees, labels

In [17]:
### code to test prepare_batch
# NOTE: the snippet below is disabled (it is only a string literal) and stale:
# prepare_batch takes (batch_data, deprels, get_gold_path) and has no emb_dictionary argument.
"""
emb_dictionary = create_dict(train_dataset)
deprels = create_deprel_dict(train_dataset)

_ = prepare_batch(iter(train_dataset), emb_dictionary, deprels, True)
"""

'\nemb_dictionary = create_dict(train_dataset)\ndeprels = create_deprel_dict(train_dataset)\n\n_ = prepare_batch(iter(train_dataset), emb_dictionary, deprels, True)\n'

In [18]:
"""
Copied from original.
Modified casing of pad and unk.
"""
def create_dict(dataset, threshold=3):
    """Build a word -> index vocabulary from the 'tokens' of every sample.

    Indices 0, 1 and 2 are reserved for "[PAD]", "<ROOT>" and "<unk>"; every
    word occurring at least ``threshold`` times gets the next free index, in
    order of first appearance.
    """
    occurrences = {}
    for sample in dataset:
        for word in sample['tokens']:
            occurrences[word] = occurrences.get(word, 0) + 1

    word_to_index = {"[PAD]": 0, "<ROOT>": 1, "<unk>": 2}
    next_index = 3
    for word, count in occurrences.items():
        if count < threshold:
            continue
        word_to_index[word] = next_index
        next_index += 1

    return word_to_index

In [19]:
"""
Return the indices of the elements to keep in the dataset. This procedure was
created because some sentences contain dependencies never seen in the training set
and would not be processed correctly.
In our case the problematic sentence has the following idx value: 10_new-83
"""
def remove_sentences(dataset, idx):
    """Return the positions of all samples whose "idx" field differs from ``idx``."""
    return [
        position
        for position in range(len(dataset))
        if dataset[position]["idx"] != idx
    ]

"""
Return indices of projective sentences in the dataset. Must be used with select() to filter out nonprojective trees.
"""
def find_projective_idx(dataset):
    """Return the positions of the samples that ``orig.is_projective`` accepts.

    The tree passed to the check is -1 (entry for the root node) followed by the
    integer heads of the sample, skipping heads equal to the string 'None'.
    """
    projective = []
    for position in range(len(dataset)):
        heads = [-1]
        heads.extend(int(head) for head in dataset[position]["head"] if head != 'None')
        if orig.is_projective(heads):
            projective.append(position)
    return projective

"""
Return indices of sentences with fewer than int(TOKENIZER_MAX_LEN*0.66) tokens (the margin accounts for subtokens)
"""
def find_long_idx(dataset):
    """Return the positions of the samples that are short enough to be kept.

    Despite the name, a sample is selected when its number of tokens is strictly
    below int(TOKENIZER_MAX_LEN*0.66).
    """
    return [
        position
        for position in range(len(dataset))
        if len(dataset[position]["tokens"]) < int(TOKENIZER_MAX_LEN*0.66)
    ]

"""
Remove trailing apostrophes in tokens from the dataset
This function will be applied by map() on each row of the dataset separately
"""
def remove_apostrophe(example):
    """Strip one trailing apostrophe from the tokens of a dataset row.

    A token loses its last character only when it ends with an apostrophe and
    does not also start with one, so tokens such as "'" or "'90'" are left
    untouched. Empty tokens are returned unchanged instead of raising an
    IndexError. The row is modified in place and returned, as expected by
    ``Dataset.map``.
    """
    example['tokens'] = [
        token[:-1] if (token.endswith("'") and not token.startswith("'")) else token
        for token in example['tokens']
    ]
    return example

## Text preprocessing

In [20]:
def _preprocess_split(dataset, drop_sentence_id=None):
    """Apply the same filtering and cleaning pipeline to one dataset split.

    Steps, in order: optionally drop the sentence whose "idx" equals
    ``drop_sentence_id``, keep only projective trees, keep only sentences that
    pass ``find_long_idx`` (i.e. short enough), strip trailing apostrophes.
    """
    if drop_sentence_id is not None:
        dataset = dataset.select(remove_sentences(dataset, drop_sentence_id))
    dataset = dataset.select(find_projective_idx(dataset))  # to remove nonprojective trees
    dataset = dataset.select(find_long_idx(dataset))        # to remove sentences too long
    return dataset.map(remove_apostrophe)


train_dataset = _preprocess_split(load_dataset('universal_dependencies', 'it_isdt', split="train"))
train_lite_dataset = _preprocess_split(load_dataset('universal_dependencies', 'it_isdt', split="train[:10%]"))
dev_dataset = _preprocess_split(load_dataset('universal_dependencies', 'it_isdt', split="validation"))
# sentence "10_new-83" of the test split has a label that never appears in the training set
test_dataset = _preprocess_split(
    load_dataset('universal_dependencies', 'it_isdt', split="test"),
    drop_sentence_id="10_new-83",
)

Reusing dataset universal_dependencies (/home/filippo/.cache/huggingface/datasets/universal_dependencies/it_isdt/2.7.0/065e728dfe9a8371434a6e87132c2386a6eacab1a076d3a12aa417b994e6ef7d)
Reusing dataset universal_dependencies (/home/filippo/.cache/huggingface/datasets/universal_dependencies/it_isdt/2.7.0/065e728dfe9a8371434a6e87132c2386a6eacab1a076d3a12aa417b994e6ef7d)
Reusing dataset universal_dependencies (/home/filippo/.cache/huggingface/datasets/universal_dependencies/it_isdt/2.7.0/065e728dfe9a8371434a6e87132c2386a6eacab1a076d3a12aa417b994e6ef7d)
Reusing dataset universal_dependencies (/home/filippo/.cache/huggingface/datasets/universal_dependencies/it_isdt/2.7.0/065e728dfe9a8371434a6e87132c2386a6eacab1a076d3a12aa417b994e6ef7d)
Loading cached processed dataset at /home/filippo/.cache/huggingface/datasets/universal_dependencies/it_isdt/2.7.0/065e728dfe9a8371434a6e87132c2386a6eacab1a076d3a12aa417b994e6ef7d/cache-1548aef782f705ca.arrow
Loading cached processed dataset at /home/filippo/.

In [21]:
#emb_dictionary = create_dict(train_dataset)
# Recompute the label -> integer mapping on the filtered training set
deprels = create_deprel_dict(train_dataset)

In [22]:
# One DataLoader per split. prepare_batch is the collate function and, since get_gold_path=True
# everywhere, gold configurations and moves are produced for every split; only the two training
# loaders shuffle their data.
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=partial(prepare_batch, deprels = deprels, get_gold_path=True))
train_lite_dataloader = torch.utils.data.DataLoader(train_lite_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=partial(prepare_batch, deprels = deprels, get_gold_path=True))
dev_dataloader = torch.utils.data.DataLoader(dev_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=partial(prepare_batch, deprels = deprels, get_gold_path=True))
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=partial(prepare_batch, deprels = deprels, get_gold_path=True))

## Ideas for the future
* Use mean square error to predict pairs (MOVE, LABEL) in the training set
* We need to rescale the data in an acceptable range. Probably should make the MOVE have greater variation, e.g LeftArc=100 RightArc=0 Shift=-100
* BERT layer must process each sentence separately to produce word embeddings, but it should also be finetuned alongside the fully connected layer

# Neural Oracle

Neural oracle implemented using BERT fine-tuned together with a simple classifier.
The references are: [Towards Data Science](https://towardsdatascience.com/text-classification-with-bert-in-pytorch-887965e5820f) and [Text Classification | Sentiment Analysis with BERT using huggingface, PyTorch and Python Tutorial](https://www.youtube.com/watch?v=8N-nM3QW7O0).

The hyperparameters and the training tweaks are taken from the original BERT paper.

In [23]:
# load the BERT tokenizer with pretrained weights for the Italian language
# (only used for the demonstration below; BERTOracle creates its own tokenizer with AutoTokenizer)
tokenizer = BertTokenizer.from_pretrained(BERT_MODEL)

In this section we load the BERT tokenizer and try it out to see how a sentence is processed and which kind of objects the tokenizer returns to be used in BERT.

In [24]:
# Sample text used to show what the tokenizer returns
example_text = '''Con il termine interfluvio in geologia si indica la porzione di superﬁcie più elevata
                  che separa due valli ﬂuviali adiacenti, che può essere una cresta oppure un' area ampia,
                  comunque non coinvolta dal movimento delle acque'''


# Pad/truncate to exactly 15 subtokens and return PyTorch tensors
bert_input = tokenizer(example_text, padding='max_length', max_length = 15, 
                       truncation=True, return_tensors="pt")

print(bert_input['input_ids'])      # words to integer
print(bert_input['token_type_ids']) # binary mask that identifies whether a token is in the first sentence (before [SEP]) or the second (after [SEP])
print(bert_input['attention_mask']) # binary mask that identifies whether a token is a real word or just padding

# Map the ids back to text to see which part of the input survived truncation
decoded_text = tokenizer.decode(bert_input.input_ids[0])
print(decoded_text)

tensor([[  102,   401,   162,  1909,   532, 27948,  2369,   139,  5101,  1783,
           223,  1731,   146, 16711,   103]])
tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])
tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])
[CLS] Con il termine interfluvio in geologia si indica la porzione [SEP]


  ## Oracle model using BERT
For the oracle we decided to use and fine-tune BERT in a regression task.
The main idea is:

*   given a sentence, get the **BERT input tokens** and the **attention mask** from the BERT tokenizer above;
*   pass the input tokens and the attention mask to the BERTOracle, which returns a tensor that contains a list of configurations, one for each sentence.

In this setting a configuration is a pair of numbers. For this reason we use the MSE loss to set up a **multi-output regression** task.

One configuration is in the form (**ACTION, LABELS**) where for ACTION we have:

* 100 for **LeftArc** action;
* 0 for **RightArc** action;
* -100 for **Shift** action.

For the LABELS we have decided to order them in alphabetical order and convert them into consecutive numbers.
The main idea is: 

the labels represent the relation between words.<br>
In the treebank we use there are many such relations that are very similar, like *obj* and *iobj*, so by ordering them and assigning consecutive numbers *obj* will be encoded, for instance, as 12 and *iobj* as 13. This means that, even if the oracle infers 13 instead of 12, the prediction is not as wrong as if it had predicted 50.<br>
The main assumption is that closer numbers represent more similar relation.

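To perform the actual multi-output regression we used a simple linear layer that outputs these two values.

*Note: the `BERTOracle` class below ends with a layer of `2*len(deprels)+1` outputs (one class per move/label combination), i.e. the code is set up as a classifier rather than as the two-value regressor described in this section.*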

In [25]:
class BERTOracle(nn.Module):
    """Neural oracle: BERT word embeddings followed by a feed-forward classifier.

    A parser configuration is represented by the concatenation of the embeddings
    of three words (second stack item, top stack item, first buffer item). The
    classifier outputs one score per move id, ``2*len(deprels)+1`` in total
    (0 is shift, the other ids combine left-arc / right-arc with a label).
    """

    def __init__(self, input_size, deprels, DROPOUT_RATE, MLP_SIZE_1, MLP_SIZE_2):
        """Create tokenizer, BERT encoder and the three linear layers.

        Args:
            input_size: size of one configuration vector (three word embeddings).
            deprels: dictionary label -> integer code; kept on the instance so
                that ``infere`` does not depend on a notebook-level global.
            DROPOUT_RATE: dropout probability used in the classifier.
            MLP_SIZE_1: width of the first hidden layer.
            MLP_SIZE_2: width of the second hidden layer.
        """
        super(BERTOracle, self).__init__()
        self.deprels = deprels
        self.tokenizer = AutoTokenizer.from_pretrained(BERT_MODEL)
        # hidden states of every layer are needed by get_embeddings
        self.bert = BertModel.from_pretrained(BERT_MODEL, output_hidden_states=True)
        self.dropout = nn.Dropout(DROPOUT_RATE)
        self.w1 = nn.Linear(input_size, MLP_SIZE_1)
        self.activation = torch.nn.Tanh()
        self.w2 = nn.Linear(MLP_SIZE_1, MLP_SIZE_2)
        self.w3 = nn.Linear(MLP_SIZE_2, 2*len(deprels)+1)  # for classification: we need 2*sizeof(label_set)+1
        self.softmax = torch.nn.Softmax(dim=-1)  # not applied by ffn_pass, which returns raw scores

    def get_embeddings(self, output):
        """Return, for every subtoken, the mean of the last four BERT hidden states."""
        layers = [-4, -3, -2, -1]
        states = output.hidden_states
        stacked = torch.stack([states[i] for i in layers])
        return stacked.mean(dim=0)

    def substitute_embeddings(self, avg_of_last_states, tok_output, paths):
        """Replace the word indices of every configuration with word embeddings.

        Args:
            avg_of_last_states: subtoken embeddings as returned by ``get_embeddings``.
            tok_output: tokenizer output for the same batch (its ``word_ids`` are
                used to map subtokens back to words).
            paths: one list of configurations per sentence; a configuration is a
                list of three word indices where -1 means "no word".

        Returns:
            A tensor with one row per configuration, each row being the
            concatenation of three word embeddings (zeros where the index is -1).
        """
        ffn_input = []
        # Placeholder for a missing word; it takes size, dtype and device from the
        # hidden states instead of assuming a 768-dimensional float tensor.
        zero_tensor = avg_of_last_states.new_zeros(avg_of_last_states.size(-1))
        for sentence_index in range(len(paths)):  # one path per sentence
            for configuration in paths[sentence_index]:
                new_tensor = torch.cat(
                    [
                        zero_tensor if word_index == -1
                        else self.get_embedding_for_word_j(sentence_index, word_index, avg_of_last_states, tok_output)
                        for word_index in configuration[:3]
                    ]
                )
                ffn_input.append(new_tensor)
        ffn_input = torch.stack(ffn_input).to(DEVICE)
        return ffn_input

    def get_embedding_for_word_j(self, i, j, hidd_states, tok_output):
        """Return the embedding of word ``j`` of sentence ``i``.

        The tokenizer splits words into subtokens, so the embedding of a full
        word is the AVERAGE of the embeddings of its subtokens.

        Raises:
            Exception: if no subtoken of sentence ``i`` belongs to word ``j``.
        """
        # indices of the subtokens that belong to word j
        indices = np.where(np.array(tok_output.word_ids(i)) == j)[0]
        if len(indices) == 0:
            print("DEBUG: tok_output.word_ids(i)) :", tok_output.word_ids(i))
            print("DEBUG: hidd_states[i].size() :", hidd_states[i].size())
            raise Exception("The requested word index ("+str(j)+") in sentence "+str(i)+" is not present")
        return torch.stack([hidd_states[i][k] for k in indices]).mean(dim=0)

    def ffn_pass(self, configuration):
        """Run the feed-forward classifier and return the raw (unnormalised) scores."""
        out = self.dropout(configuration)
        out = self.w1(out)
        out = self.activation(out)
        out = self.dropout(out)
        out = self.w2(out)
        out = self.activation(out)
        return self.w3(out)

    def forward(self, sentences, paths):
        """Score every configuration in ``paths``.

        Args:
            sentences: batch of sentences, each one already split into words.
            paths: for every sentence, the list of configurations to score.

        Returns:
            A tensor with one row of move scores per configuration, in the order
            obtained by visiting ``paths`` sentence by sentence.
        """
        # Compute the BERT embeddings from the hidden layers
        enc = self.tokenizer(sentences, padding='max_length', max_length = TOKENIZER_MAX_LEN,
                             truncation=True, return_tensors="pt", is_split_into_words=True)
        out = self.bert(enc['input_ids'].to(DEVICE), enc['attention_mask'].to(DEVICE))

        avg_of_last_states = self.get_embeddings(out)
        configurations = self.substitute_embeddings(avg_of_last_states, enc, paths)

        # Pass through FFN
        return self.ffn_pass(configurations)

    def infere(self, sentences):
        """Parse a batch of sentences using the model's own predictions.

        BERT is run once per batch; then all parsers advance one move at a time
        until every tree is final.

        Returns:
            For every sentence, the ``arcs`` list built by its parser.
        """
        parsers = [ArcStandard(i, self.deprels) for i in sentences]

        enc = self.tokenizer(sentences, padding='max_length', max_length = TOKENIZER_MAX_LEN,
                             truncation=True, return_tensors="pt", is_split_into_words=True)
        out = self.bert(enc['input_ids'].to(DEVICE), enc['attention_mask'].to(DEVICE))

        h = self.get_embeddings(out)

        while not self.parsed_all(parsers):
            configurations = self.get_configurations(parsers)
            mlp_input = self.substitute_embeddings(h, enc, configurations)
            mlp_out = self.ffn_pass(mlp_input)
            self.parse_step(parsers, mlp_out, self.deprels)

        return [parser.arcs for parser in parsers]

    def get_configurations(self, parsers):
        """Return the current configuration of every parser (copied from the original file).

        The result has one single-element list per parser so that it has the same
        shape as the ``paths`` argument of ``substitute_embeddings``. Finished
        parsers contribute the dummy configuration [-1, -1, -1].
        """
        configurations = []

        for parser in parsers:
            if parser.is_tree_final():
                conf = [-1, -1, -1]
            else:
                conf = [parser.stack[len(parser.stack)-2], parser.stack[len(parser.stack)-1]]
                if len(parser.buffer) == 0:
                    conf.append(-1)
                else:
                    conf.append(parser.buffer[0])
            configurations.append([conf])

        return configurations

    def parsed_all(self, parsers):
        """Return True when every parser has built its final tree (copied from the original file)."""
        return all(parser.is_tree_final() for parser in parsers)

    def parse_step(self, parsers, moves, labels):
        """Apply to every unfinished parser the best-scoring move that is applicable.

        Args:
            parsers: the parsers of the batch.
            moves: tensor of scores, one row per parser and one column per move id.
            labels: the label dictionary; only its length is used for decoding.

        Decoding of the best move id ``m``: 0 is shift, otherwise
        ``1 + int(m / len(labels))`` is 1 for left-arc and 2 for right-arc, and
        ``m % len(labels)`` is passed to the parser as the arc label.
        Moves that would attach <ROOT> as a dependent, or attach a word to <ROOT>
        while the buffer is not empty, are replaced by a shift; a shift with an
        empty buffer is replaced by the best-scoring arc move.

        NOTE(review): process_sample encodes left-arc as ``1 + label_code`` and
        right-arc as ``1 + len(deprels) + label_code``, so ``m % len(labels)``
        appears to yield ``label_code + 1``. Confirm whether the evaluation code
        compensates for this before changing it.

        Raises:
            Exception: if a move id cannot be decoded as shift, left-arc or right-arc.
        """
        moves_argm = moves.argmax(-1)  # best move id for every parser
        # Best move id excluding shift (column 0), needed when shift is predicted but the
        # buffer is empty. 1 is added to account for the slicing.
        moves_second_argm = moves[:, 1:].argmax(-1) + 1

        for i in range(len(parsers)):
            if parsers[i].is_tree_final():
                continue

            kind_of_move = 0 if moves_argm[i] == 0 else (1 + int(moves_argm[i] / len(labels)))  # 0 shift, 1 leftarc, 2 rightarc
            label = (moves_argm[i] % len(labels)).item()  # label of the arc, ignored if shift

            if kind_of_move == 1:  # predicted: leftarc
                if parsers[i].stack[-2] != 0:  # the second element in the stack is not ROOT
                    parsers[i].left_arc(label)
                elif len(parsers[i].buffer) > 0:  # ROOT cannot be a dependent: shift if possible
                    parsers[i].shift()
                else:  # buffer is empty: do rightarc
                    parsers[i].right_arc(label)

            elif kind_of_move == 2:  # predicted: rightarc
                if parsers[i].stack[-2] == 0 and len(parsers[i].buffer) > 0:
                    # ROOT is second in the stack and words are still waiting in the buffer.
                    # The original code referenced the method without calling it, so the
                    # parser never advanced in this situation.
                    parsers[i].shift()
                else:
                    parsers[i].right_arc(label)

            elif moves_argm[i] == 0:  # predicted: shift
                if parsers[i].buffer:
                    parsers[i].shift()
                else:  # buffer is empty: fall back on the best arc move (note integer division)
                    if (moves_second_argm[i]//len(labels)) == 0:
                        parsers[i].left_arc((moves_second_argm[i] % len(labels)).item())
                    elif (moves_second_argm[i]//len(labels)) == 1:
                        parsers[i].right_arc((moves_second_argm[i] % len(labels)).item())
                    else:
                        raise Exception("moves_second_argm not in range [1, 2*len(labels)]")

            else:
                raise Exception("argmax not in range [0, 2*len(labels)]")

## Training and evaluation functions

## Grid search for BERTOracle model

In [26]:
def train_epoch(model, data_loader, loss_fn, optimizer):
  """Run one optimisation pass over ``data_loader`` and return the mean batch loss.

  Args:
    model: module called as ``model(sentences, paths)``; it is switched to
      training mode with ``model.train()`` before the first batch.
    data_loader: iterable of batches ``(sentences, paths, moves, trees, labels)``.
      ``moves`` holds one list of gold move ids per sentence; ``trees`` and
      ``labels`` are not used here.
    loss_fn: criterion called as ``loss_fn(outputs, targets)``.
    optimizer: optimizer whose ``zero_grad()`` / ``step()`` drive the update.

  Returns:
    ``np.mean`` of the per-batch loss values (``nan`` if the loader is empty).
  """
  model = model.train()
  losses = []
  for sentences, paths, moves, _trees, _labels in data_loader:
    outputs = model(sentences, paths)

    # Flatten the per-sentence gold moves into one target vector so the loss over
    # all configurations of all sentences is computed in a single call.
    # A comprehension is linear, whereas sum(moves, []) re-copies the list each step.
    flat_moves = [move for sentence_moves in moves for move in sentence_moves]
    # Create the targets on the device the outputs live on instead of relying on
    # a notebook-level DEVICE constant.
    tensor_moves = torch.tensor(flat_moves, device=outputs.device)
    loss = loss_fn(outputs, tensor_moves)
    losses.append(loss.item())

    # Backpropagation routine: clear stale gradients, backpropagate, update.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

  return np.mean(losses)

In [27]:
def eval_model(model, data_loader, loss_fn, optimzier=None):
  """Compute the mean batch loss of ``model`` over ``data_loader`` without training.

  Args:
    model: module called as ``model(sentences, paths)``; it is switched to
      evaluation mode with ``model.eval()``.
    data_loader: iterable of batches ``(sentences, paths, moves, trees, labels)``.
      ``moves`` holds one list of gold move ids per sentence; ``trees`` and
      ``labels`` are not used here.
    loss_fn: criterion called as ``loss_fn(outputs, targets)``.
    optimzier: ignored. The (misspelled) parameter is kept, now optional, so
      existing calls that pass an optimizer keep working. Evaluation runs under
      ``torch.no_grad()`` and therefore has no gradients to reset.

  Returns:
    ``np.mean`` of the per-batch loss values (``nan`` if the loader is empty).
  """
  model = model.eval()  # evaluation mode

  losses = []
  with torch.no_grad():
    for sentences, paths, moves, _trees, _labels in data_loader:
      outputs = model(sentences, paths)

      # Flatten the per-sentence gold moves into one target vector so the loss
      # over all configurations of all sentences is computed in a single call.
      flat_moves = [move for sentence_moves in moves for move in sentence_moves]
      # Targets are created on the device of the outputs rather than on a
      # notebook-level DEVICE constant.
      tensor_moves = torch.tensor(flat_moves, device=outputs.device)
      losses.append(loss_fn(outputs, tensor_moves).item())

  return np.mean(losses)

## Grid search
Training loop that, for the given number of epochs, performs the training step and the evaluation step for each sentence in the train_dataloader and val_dataloader.
All the results are kept inside a history dictionary.

Grid search commented out to avoid disasters

In [28]:
import itertools

# Candidate values for each hyper-parameter explored by the grid search.
lr = [5e-04]
dr = [0.3]
mlp_size = [150]

# itertools.product returns a one-shot iterator: once a loop has consumed it, a
# second run of the grid-search cell would silently iterate over nothing.
# Materialising it as a list keeps the grid reusable and inspectable.
hyperparams = list(itertools.product(lr, dr, mlp_size))

In [29]:
#torch.cuda.empty_cache()

In [30]:
"""
grid_search_results = []
n = 13 #keep track of the number of grid search

for hp in hyperparams:
    n += 1
    LEARNING_RATE = hp[0]
    DROPOUT_RATE = hp[1]
    MLP_SIZE_2 = hp[2]
    SAVE_PATH = "saved_checkpoint_grid_"+str(n)+".pt"
    #Initialize the BERTOracle model
    model = BERTOracle(INPUT_SIZE, deprels, DROPOUT_RATE, MLP_SIZE_1, MLP_SIZE_2)
    model.to(DEVICE)

    # Adam optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

    #Define the loss as cross entropy
    loss_fn = nn.CrossEntropyLoss().to(DEVICE)

    import gc
    gc.collect()

    import time

    history = defaultdict(list) # store training and validation losses in a dictionary 
    epoch_to_save = 5           # save model every 5 epochs

    for epoch in range(EPOCHS):
      print(f'Grid search number {n}: LR={LEARNING_RATE}, DROPOUT={DROPOUT_RATE}, MLP_SIZE_2={MLP_SIZE_2}')
      print(f'Epoch {epoch + 1} / {EPOCHS}')

      #Training step
      start_time = time.time()
      train_loss = train_epoch(model, train_lite_dataloader, loss_fn, optimizer)
      end_time = time.time()
      print(f"Train loss {train_loss}")
      print(f"Time elapsed : {end_time-start_time} seconds")
      #Evaluation step
      val_loss = eval_model(model, dev_dataloader, loss_fn, optimizer)

      print(f"Val loss {val_loss}")
      print('-'*10)

      history['train_loss'].append(train_loss)

      history['val_loss'].append(val_loss)

      if epoch % epoch_to_save == 0:
        torch.save(model, SAVE_PATH)
        print("|MODEL SAVED|")
        
      if epoch == EPOCHS-1:
        grid_search_results.append({'train_loss':train_loss, 'val_loss':val_loss})
"""

'\ngrid_search_results = []\nn = 13 #keep track of the number of grid search\n\nfor hp in hyperparams:\n    n += 1\n    LEARNING_RATE = hp[0]\n    DROPOUT_RATE = hp[1]\n    MLP_SIZE_2 = hp[2]\n    SAVE_PATH = "saved_checkpoint_grid_"+str(n)+".pt"\n    #Initialize the BERTOracle model\n    model = BERTOracle(INPUT_SIZE, deprels, DROPOUT_RATE, MLP_SIZE_1, MLP_SIZE_2)\n    model.to(DEVICE)\n\n    # Adam optimizer\n    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)\n\n    #Define the loss as cross entropy\n    loss_fn = nn.CrossEntropyLoss().to(DEVICE)\n\n    import gc\n    gc.collect()\n\n    import time\n\n    history = defaultdict(list) # store training and validation losses in a dictionary \n    epoch_to_save = 5           # save model every 5 epochs\n\n    for epoch in range(EPOCHS):\n      print(f\'Grid search number {n}: LR={LEARNING_RATE}, DROPOUT={DROPOUT_RATE}, MLP_SIZE_2={MLP_SIZE_2}\')\n      print(f\'Epoch {epoch + 1} / {EPOCHS}\')\n\n      #Training step\n

In [31]:
#grid_search_results

## Training on optimal hyperparameters

In [32]:
import gc
import time

# Hyper-parameters used for the final training run; they overwrite the
# notebook-level defaults of the same names.
LEARNING_RATE = 1e-04
MLP_SIZE_2 = 300
DROPOUT_RATE = 0.1
SAVE_PATH = "final_model_checkpoint.pt"

# Fresh model, Adam optimizer and cross-entropy criterion for this run.
model = BERTOracle(INPUT_SIZE, deprels, DROPOUT_RATE, MLP_SIZE_1, MLP_SIZE_2)
model.to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
loss_fn = nn.CrossEntropyLoss().to(DEVICE)

gc.collect()

# Per-epoch training and validation losses, keyed by 'train_loss' / 'val_loss'.
history = defaultdict(list)

for epoch in range(EPOCHS):
  print(f'LR={LEARNING_RATE}, DROPOUT={DROPOUT_RATE}, MLP_SIZE_2={MLP_SIZE_2}')
  print(f'Epoch {epoch + 1} / {EPOCHS}')

  # Training step, timed with the wall clock.
  start_time = time.time()
  train_loss = train_epoch(model, train_dataloader, loss_fn, optimizer)
  end_time = time.time()
  elapsed = end_time - start_time
  print(f"Train loss {train_loss}")
  print(f"Time elapsed : {elapsed} seconds")

  # Evaluation step on the development set.
  val_loss = eval_model(model, dev_dataloader, loss_fn, optimizer)
  print(f"Val loss {val_loss}")
  print('-'*10)

  history['train_loss'].append(train_loss)
  history['val_loss'].append(val_loss)

  # Checkpoint only when this epoch's validation loss beats every earlier one
  # (the first epoch always qualifies). Training itself is never stopped early;
  # the file at SAVE_PATH simply ends up holding the best epoch.
  *earlier_val_losses, latest_val_loss = history['val_loss']
  if not earlier_val_losses or latest_val_loss < min(earlier_val_losses):
    torch.save(model, SAVE_PATH)
    print("|MODEL SAVED|")

Some weights of the model checkpoint at dbmdz/bert-base-italian-xxl-cased were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.predictions.decoder.bias', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


LR=0.0001, DROPOUT=0.1, MLP_SIZE_2=300
Epoch 1 / 5
Train loss 1.5189408444343728
Time elapsed : 843.9847373962402 seconds
Val loss 0.5122224357393053
----------
|MODEL SAVED|
LR=0.0001, DROPOUT=0.1, MLP_SIZE_2=300
Epoch 2 / 5
Train loss 0.3418819362971377
Time elapsed : 843.7815444469452 seconds
Val loss 0.2207983268631829
----------
|MODEL SAVED|
LR=0.0001, DROPOUT=0.1, MLP_SIZE_2=300
Epoch 3 / 5
Train loss 0.1678984811569148
Time elapsed : 844.8818018436432 seconds
Val loss 0.17463983429802787
----------
|MODEL SAVED|
LR=0.0001, DROPOUT=0.1, MLP_SIZE_2=300
Epoch 4 / 5
Train loss 0.10714307435332461
Time elapsed : 847.8088767528534 seconds
Val loss 0.148846252510945
----------
|MODEL SAVED|
LR=0.0001, DROPOUT=0.1, MLP_SIZE_2=300
Epoch 5 / 5
Train loss 0.07510739346926516
Time elapsed : 846.1845502853394 seconds
Val loss 0.15122611737913555
----------


In [33]:
# Lowest validation loss recorded in `history` across the epochs run so far.
best_val_loss = min(history['val_loss'])
print(best_val_loss)

0.148846252510945


## Evaluation

In [46]:
# Reload the checkpoint written by the training cell. That cell calls
# torch.save(model, SAVE_PATH), which pickles the whole module, so this restores
# the complete model object rather than only a state dict.
# NOTE(review): torch.load unpickles arbitrary objects; only load checkpoints you
# trust. Recent PyTorch releases may additionally require weights_only=False to
# load a fully pickled module; confirm against the installed version.
old_path = "../gridsearch_new/final/final_model_epoch_1.pt"  # not used in this cell; presumably an earlier checkpoint location
# map_location remaps the stored tensors onto DEVICE, so a checkpoint saved from a
# GPU run can still be restored on a CPU-only machine.
model = torch.load(SAVE_PATH, map_location=DEVICE)
model.to(DEVICE)

BERTOracle(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(32102, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
 

In [47]:
"""
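Functions were adapted from the ones in the original notebook.
The model's predictions are lists of (move, label) *couples*, one list per sentence; validation() splits them into two parallel lists before scoring.

evaluate() returns a single accuracy value; validation() returns a couple of values (UAS, LAS)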
"""
def evaluate(gold, preds):
  """Return the fraction of positions where ``preds`` agrees with ``gold``.

  Sequences are paired with ``zip(gold, preds)`` and compared element by
  element from index 1 onwards; index 0 of every sequence is skipped.

  Args:
    gold: list of reference sequences, one per sentence.
    preds: list of predicted sequences, aligned with ``gold``. Each prediction
      must be at least as long as its reference.

  Returns:
    ``correct / total`` as a float, or ``0.0`` when there is no position to
    score (empty input, or sequences holding only index 0) instead of raising
    ``ZeroDivisionError``.

  Raises:
    IndexError: if a predicted sequence is shorter than its reference.
  """
  total = 0
  correct = 0

  for g, p in zip(gold, preds):
    scored_positions = range(1, len(g))
    total += len(scored_positions)
    correct += sum(1 for i in scored_positions if g[i] == p[i])

  # Nothing was compared: report 0.0 rather than dividing by zero.
  if total == 0:
    return 0.0
  return correct/total

def validation(model, dataloader):
  """Score the model's predictions against the gold annotations of ``dataloader``.

  Args:
    model: object providing ``eval()`` and ``infere(sentences)``. ``infere`` is
      expected to return, for every sentence, a list of pairs whose first item is
      compared with the gold tree and whose second item with the gold label.
    dataloader: iterable of batches ``(sentences, paths, moves, trees, labels)``;
      ``trees`` and ``labels`` are lists with one entry per sentence.

  Returns:
    A tuple ``(evaluate(gold trees, predicted trees),
    evaluate(gold labels, predicted labels))``, which the notebook reports as
    (UAS, LAS). The second value compares labels on their own, independently of
    whether the corresponding head was predicted correctly.
  """
  model.eval()

  predicted_heads, predicted_labels = [], []
  gold_heads, gold_labels = [], []

  for sentences, _paths, _moves, trees, labels in dataloader:
    # Inference only: no gradients are needed.
    with torch.no_grad():
      for sentence_pairs in model.infere(sentences):
        # Split each sentence's pairs into two parallel sequences.
        predicted_heads.append([pair[0] for pair in sentence_pairs])
        predicted_labels.append([pair[1] for pair in sentence_pairs])
      gold_heads.extend(trees)
      gold_labels.extend(labels)

  head_score = evaluate(gold_heads, predicted_heads)
  label_score = evaluate(gold_labels, predicted_labels)
  return (head_score, label_score)


In [49]:
# Score the model loaded above on test_dataloader; validation() returns a 2-tuple,
# unpacked here as (uas, las).
uas, las = validation(model, test_dataloader)

In [50]:
# Show both scores, one per line (uas first, then las).
print(uas, las, sep="\n")

0.9526222447428426
0.003040283759817583


0.027489232328350648