In [2]:
import torch
import torch.nn as nn
from functools import partial
from datasets import load_dataset

## Arc-standard

Recall that a **configuration** of the arc-standard parser is a triple of the form $( \sigma, \beta, A)$
where:

* $\sigma$ is the stack;
* $\beta$ is the input buffer;
* $A$ is a set of arcs constructed so far.

We write $\sigma_i$, $i \geq 1$, for the $i$-th token in the stack; we also write $\beta_i$, $i \geq 1$, for the $i$-th token in the buffer. 

The parser can perform three types of **actions** (transitions):

* **shift**, which removes $\beta_1$ from the buffer and pushes it into the stack;
* **left-arc**, which creates the arc $(\sigma_1 \rightarrow \sigma_2)$, and removes $\sigma_2$ from the stack;
* **right-arc**, which creates the arc $(\sigma_2 \rightarrow \sigma_1)$, and removes $\sigma_1$ from the stack.

Let $w = w_0 w_1 \cdots w_{n}$ be the input sentence, with $w_0$ the special symbol `<ROOT>`.
Stack and buffer are implemented as lists of integers, where `j` represents word $w_j$.  Top-most stack token is at the right-end of the list; first buffer token is at the left-end of the list. 
Set $A$ is implemented as an array `arcs` of size $n+1$ such that if arc $(w_i \rightarrow w_j)$ is in $A$ then `arcs[j]=i`, and if $w_j$ is still missing its head node in the tree under construction, then `arcs[j]=-1`. We always have `arcs[0]=-1`.  We use this representation also for complete dependency trees.


In [3]:
import orig #file orig.py in the same foldr contains original class and function definitions

In [4]:
class ArcStandard (orig.ArcStandard):
    def __init__(self, sentence, label_set):
        self.sentence = sentence
        self.buffer = [i for i in range(len(self.sentence))]
        self.stack = []
        self.arcs = [(-1, -1) for _ in range(len(self.sentence))]

        self.label_set = {y:x for x,y in label_set.items()} #for future reference (printing)
        # three shift moves to initialize the stack
        self.shift()
        self.shift()
        if len(self.sentence) > 2:
            self.shift()

    def left_arc(self, deprel):
        o1 = self.stack.pop()
        o2 = self.stack.pop()
        self.arcs[o2] = (o1, deprel) #added deprel
        self.stack.append(o1)
        if len(self.stack) < 2 and len(self.buffer) > 0:
            self.shift()

    def right_arc(self, deprel):
        o1 = self.stack.pop()
        o2 = self.stack.pop()
        self.arcs[o1] = (o2, deprel) #added deprel
        self.stack.append(o2)
        if len(self.stack) < 2 and len(self.buffer) > 0:
            self.shift()

    def print_configuration(self):
        s = [self.sentence[i] for i in self.stack]
        b = [self.sentence[i] for i in self.buffer]
        print("STACK: ", s, "BUFFER: ", b) #added indication of stack and buffer
        print([(x[0], self.label_set[x[1]]) for x in self.arcs])

In [5]:
class Oracle (orig.Oracle):
    def __init__(self, parser, gold_tree, gold_labels):
        self.parser = parser
        self.gold = gold_tree
        self.labels = gold_labels #must be integers; see below

    def get_gold_label(self):
        if(self.is_left_arc_gold()):
            o2 = self.parser.stack[len(self.parser.stack)-2]
            return self.labels[o2]
        elif(self.is_right_arc_gold()):
            o1 = self.parser.stack[len(self.parser.stack)-1]
            return self.labels[o1]
        else:
            raise Exception("Action SHIFT does not produce a labeled dependency")
            

## Functions to simplify the processing

In [6]:
"""
Return a dictionary where arc labels are associate to integer values.
Label '_' is removed (used for composite tokens)
Labels are sorted in alphabetical order because MSE in the neural network will penalize less classes that are "close".
<ROOT> label is given value -1 (label of the <ROOT> node).
"""
def create_deprel_dict(dataset):
    import pandas as pd
    df = dataset.to_pandas()
    labels = set()
    for x in df['deprel']:
        for elem in x:
            if(elem != '_'):
                labels.add(elem)
    labels = sorted(list(labels))
    ret = {labels[i]:i for i in range(len(labels))}
    ret.update({"None":-1})
    return ret

"""
Return (sanitized_tokens, gold tree, gold labels) for a sentence in the dataset, assuming gold tree is in a column labeled 'head' and labels in a column labeled 'deprel'.
Both token indices and labels are converted to integer (in the latter case, according to deprel_dict).
Sanitized tokens is a tokenlist with removal of composite tokens.
"""
def create_gold(sentence, deprel_dict):
    sanitized = ['<ROOT>']
    gold_tree = [-1]
    gold_labels = [-1]
    for i in range(len(sentence['tokens'])):
        if(sentence['head'][i] != 'None'):
            sanitized.append(sentence['tokens'][i])
            gold_tree.append(int(sentence['head'][i]))
            gold_labels.append(deprel_dict[sentence['deprel'][i]])
    
    return sanitized, gold_tree, gold_labels 

## Testing parser and oracle...

In [8]:
train_dataset = load_dataset('universal_dependencies', 'it_isdt', split="train")

Reusing dataset universal_dependencies (/home/filippo/.cache/huggingface/datasets/universal_dependencies/it_isdt/2.7.0/065e728dfe9a8371434a6e87132c2386a6eacab1a076d3a12aa417b994e6ef7d)


In [9]:
deprels = create_deprel_dict(train_dataset)
sentence = train_dataset[3]

tokens, gold_tree, gold_labels = create_gold(sentence, deprels)

print(gold_tree)
print(gold_labels)

[-1, 2, 0, 5, 5, 2, 5, 2]
[-1, 4, 41, 8, 17, 31, 28, 40]


In [10]:
parser = ArcStandard(tokens, deprels)
oracle = Oracle(parser, gold_tree, gold_labels)

parser.print_configuration()

STACK:  ['<ROOT>', 'Inconsueto', 'allarme'] BUFFER:  ['a', 'la', 'Tate', 'Gallery', ':']
[(-1, 'None'), (-1, 'None'), (-1, 'None'), (-1, 'None'), (-1, 'None'), (-1, 'None'), (-1, 'None'), (-1, 'None')]


In [11]:
while not parser.is_tree_final():  # transition precedence implemented here
    if oracle.is_shift_gold():  
        parser.shift()
    elif oracle.is_left_arc_gold():
        parser.left_arc(oracle.get_gold_label())
    elif oracle.is_right_arc_gold():
        parser.right_arc(oracle.get_gold_label())
    
    parser.print_configuration()

STACK:  ['<ROOT>', 'allarme'] BUFFER:  ['a', 'la', 'Tate', 'Gallery', ':']
[(-1, 'None'), (2, 'amod'), (-1, 'None'), (-1, 'None'), (-1, 'None'), (-1, 'None'), (-1, 'None'), (-1, 'None')]
STACK:  ['<ROOT>', 'allarme', 'a'] BUFFER:  ['la', 'Tate', 'Gallery', ':']
[(-1, 'None'), (2, 'amod'), (-1, 'None'), (-1, 'None'), (-1, 'None'), (-1, 'None'), (-1, 'None'), (-1, 'None')]
STACK:  ['<ROOT>', 'allarme', 'a', 'la'] BUFFER:  ['Tate', 'Gallery', ':']
[(-1, 'None'), (2, 'amod'), (-1, 'None'), (-1, 'None'), (-1, 'None'), (-1, 'None'), (-1, 'None'), (-1, 'None')]
STACK:  ['<ROOT>', 'allarme', 'a', 'la', 'Tate'] BUFFER:  ['Gallery', ':']
[(-1, 'None'), (2, 'amod'), (-1, 'None'), (-1, 'None'), (-1, 'None'), (-1, 'None'), (-1, 'None'), (-1, 'None')]
STACK:  ['<ROOT>', 'allarme', 'a', 'Tate'] BUFFER:  ['Gallery', ':']
[(-1, 'None'), (2, 'amod'), (-1, 'None'), (-1, 'None'), (5, 'det'), (-1, 'None'), (-1, 'None'), (-1, 'None')]
STACK:  ['<ROOT>', 'allarme', 'Tate'] BUFFER:  ['Gallery', ':']
[(-1, 'No

## Testing the whole dataset

In [12]:
# remember to call orig.is_projective() to check for projectivity!!

In [15]:
### NOTE: THIS CELL WAS COPIED WITH FEW MODIFICATIONS FROM THE ORIGINAL FILE

# oracle test: run the parser guided by the oracle on the entire training set 

non_projective = 0
correct = 0
wrong = 0

deprels = create_deprel_dict(train_dataset)

for sample in train_dataset:
    tokens, gold_tree, gold_labels = create_gold(sentence, deprels)

    if not orig.is_projective(gold_tree):
        non_projective += 1
        continue

    parser = ArcStandard(tokens, deprels)
    oracle = Oracle(parser, gold_tree, gold_labels)

    while not parser.is_tree_final():
        if oracle.is_left_arc_gold(): 
            parser.left_arc(oracle.get_gold_label())
        elif oracle.is_right_arc_gold():
            parser.right_arc(oracle.get_gold_label())
        elif oracle.is_shift_gold(): 
            parser.shift()

    for j in range(len(gold_tree)):  # comparing heads from parser and gold for actual sample
        if gold_tree[j] == parser.arcs[j][0] and gold_labels[j] == parser.arcs[j][1]: 
            correct += 1
        else:
            wrong += 1

print("non projective: ", non_projective)
print("correct: ", correct)
print("wrong: ", wrong)

non projective:  0
correct:  104968
wrong:  0


## Creating samples for the neural oracle

In [30]:
# Modified from the original

#TODO work in progress
"""
This function processes a single sample, which is one sentence provided as an element of a Dataset object and returns 
    enc_sentence : a list of integers encoding the phrase #???? is this even correct for BERT???
    gold_path : a list of configurations
    gold_moves : a list of (move, label)
"""
def process_sample(sample, emb_dictionary, deprels, get_gold_path = False): #emb_dictionary and deprels are dictionaries of words and of dependency relations
    tokens, gold_tree, gold_labels = create_gold(sample, deprels)
    enc_sentence = [emb_dictionary[word] if word in emb_dictionary else emb_dictionary["<unk>"] for word in sentence]

    # gold_path and gold_moves are parallel arrays whose elements refer to parsing steps
    gold_path = []   # record two topmost stack token and first buffer token for current step
    gold_moves = []  # oracle (canonical) move for current step: 100 is left, 0 right, -100 shift #motivations provided above TODO

    if get_gold_path:  # only for training
        parser = ArcStandard(sentence, deprel)
        oracle = Oracle(parser, gold_tree, gold_labels)

        while not parser.is_tree_final():
            configuration = [parser.stack[len(parser.stack)-2], parser.stack[len(parser.stack)-1]]
            if len(parser.buffer) == 0:
                configuration.append(-1)
            else:
                configuration.append(parser.buffer[0])  

            gold_path.append(configuration)

            if oracle.is_left_arc_gold():  
                label_code = deprels[oracle.get_gold_label()]
                parser.left_arc(label_code)
                gold_moves.append((100, label_code)) #note: I switched the instructions here for symmetry. There was no comment or good reason not to.
            elif oracle.is_right_arc_gold():
                label_code = deprels[oracle.get_gold_label()]
                parser.right_arc(label_code)
                gold_moves.append((0, label_code))
            elif oracle.is_shift_gold():
                parser.shift()
                gold_moves.append((-100, -100))

    return enc_sentence, gold_path, gold_moves, gold_tree, gold_labels


In [32]:
"""
Function that prepares the data for the model.
@return sentences : list of sentences in the dataset encoded according to emb_dictionary
@return paths : list of lists of configurations visited during the oracle-guided parsing (on training data) | empty list if get_gold_path is False
@return moves : list of lists of (MOVE, LABEL) performed during the oracle-guided parsing (on training data) | empty if get_gold_path is False
@return trees : ground truth tree
@return labels : ground truth lables

Note: the last two returned values may eventually be zipped together for faster loss calculation.
"""
#modified form orig
def prepare_batch(batch_data, emb_dictionary, deprels, get_gold_path=False):
    data = [process_sample(s, emb_dictionary, deprels, get_gold_path=get_gold_path) for s in batch_data]
    # sentences, paths, moves, trees are parallel arrays, each element refers to a sentence
    sentences = [s[0] for s in data]
    paths = [s[1] for s in data]
    moves = [s[2] for s in data]
    trees = [s[3] for s in data]
    labels = [s[4] for s in data]
    return sentences, paths, moves, trees, labels

In [35]:
train_dataset = load_dataset('universal_dependencies', 'it_isdt', split="train")
train_dataset.filter(lambda x : orig.is_projective([-1] + [int(head) for head in x["head"] if head!='None'])) #to remove nonprojective trees
dev_dataset = load_dataset('universal_dependencies', 'it_isdt', split="validation")
test_dataset = load_dataset('universal_dependencies', 'it_isdt', split="test")

emb_dictionary = orig.create_dict(train_dataset)
deprels = create_deprel_dict(train_dataset)

Reusing dataset universal_dependencies (/home/filippo/.cache/huggingface/datasets/universal_dependencies/it_isdt/2.7.0/065e728dfe9a8371434a6e87132c2386a6eacab1a076d3a12aa417b994e6ef7d)


  0%|          | 0/14 [00:00<?, ?ba/s]

Reusing dataset universal_dependencies (/home/filippo/.cache/huggingface/datasets/universal_dependencies/it_isdt/2.7.0/065e728dfe9a8371434a6e87132c2386a6eacab1a076d3a12aa417b994e6ef7d)
Reusing dataset universal_dependencies (/home/filippo/.cache/huggingface/datasets/universal_dependencies/it_isdt/2.7.0/065e728dfe9a8371434a6e87132c2386a6eacab1a076d3a12aa417b994e6ef7d)


In [45]:
BATCH_SIZE = 32

train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=partial(prepare_batch, get_gold_path=True))
dev_dataloader = torch.utils.data.DataLoader(dev_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=partial(prepare_batch))
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=partial(prepare_batch))

## Ideas for the future
* Use mean square error to predict pairs (MOVE, LABEL) in the training set
* We need to rescale the data in an acceptable range. Probably should make the MOVE have greater variation, e.g LeftArc=100 RightArc=0 Shift=-100
* BERT layer must process each sentence separately to produce word embeddings, but it should also be finetuned alongside the fully connected layer
* 