# Trees

For our gold labels, we need to recover the node distances from our parse tree. For this we will use the functionality provided by `ete3`, which lets us compute these distances directly. I have provided code that transforms a `TokenTree` to a `Tree` in `ete3` format.

In [1]:
import torch

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [2]:
# In case you want to transform your conllu tree to an nltk.Tree, for better visualisation

def rec_tokentree_to_nltk(tokentree):
    token = tokentree.token["form"]
    tree_str = f"({token} {' '.join(rec_tokentree_to_nltk(t) for t in tokentree.children)})"

    return tree_str


def tokentree_to_nltk(tokentree):
    from nltk import Tree as NLTKTree

    tree_str = rec_tokentree_to_nltk(tokentree)

    return NLTKTree.fromstring(tree_str)

In [3]:
# !pip install ete3
from ete3 import Tree as EteTree


class FancyTree(EteTree):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, format=1, **kwargs)
        
    def __str__(self):
        return self.get_ascii(show_internal=True)
    
    def __repr__(self):
        return str(self)


def rec_tokentree_to_ete(tokentree):
    idx = str(tokentree.token["id"])
    children = tokentree.children
    if children:
        return f"({','.join(rec_tokentree_to_ete(t) for t in children)}){idx}"
    else:
        return idx
    
def tokentree_to_ete(tokentree):
    newick_str = rec_tokentree_to_ete(tokentree)

    return FancyTree(f"{newick_str};")
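
To make the Newick format concrete, here is a minimal standalone sketch that builds the same kind of string from a list of CoNLL-U head indices (1-based token ids, head 0 = root) instead of a `TokenTree`. The function `heads_to_newick` is hypothetical and not part of `conllu` or `ete3`. It assumes exactly one token attaches to the root.

```python
from collections import defaultdict
from typing import Dict, List


def heads_to_newick(heads: List[int]) -> str:
    """heads[k] is the head of token k+1 (1-based ids, 0 = artificial root)."""
    children: Dict[int, List[int]] = defaultdict(list)
    for token_id, head in enumerate(heads, start=1):
        children[head].append(token_id)

    def rec(token_id: int) -> str:
        kids = children.get(token_id, [])
        if kids:
            # Children in parentheses, followed by the name of their parent
            return "(" + ",".join(rec(k) for k in kids) + ")" + str(token_id)
        return str(token_id)

    (root,) = children[0]  # fails loudly if there is not exactly one root
    return rec(root) + ";"
```

For a three-token sentence in which token 2 heads tokens 1 and 3 (`heads = [2, 0, 2]`), this returns `(1,3)2;`, the same shape of string that `rec_tokentree_to_ete` feeds to `FancyTree`.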

In [4]:
# Let's check if it works!
# The helper below reads in a corpus; any of its sentences can then be converted
# to an ete3 Tree with `tokentree_to_ete(sentence.to_tree())`.

def parse_corpus(filename):
    from conllu import parse_incr

    # `with` makes sure the file is closed again once all sentences are read
    with open(filename, encoding="utf-8") as data_file:
        ud_parses = list(parse_incr(data_file))

    return ud_parses

As you can see, we label each token by its token ID (converted to a string). Based on these IDs we are going to retrieve the node distances.

To create the true distances of a parse tree in our treebank, we are going to use the `.get_distance` method that is provided by `ete3`: http://etetoolkit.org/docs/latest/tutorial/tutorial_trees.html#working-with-branch-distances

We will store all these distances in a `torch.Tensor`.
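
If you want to sanity-check the `ete3` output: the node distance between two tokens is the number of edges on the path between them, which a breadth-first search over the undirected head graph also computes. A minimal standalone sketch, assuming CoNLL-U head indices as input (`tree_distances` is a hypothetical helper). Like `create_gold_distances`, it pads the matrix with -1 up to `max_len`:

```python
from collections import deque
from typing import List


def tree_distances(heads: List[int], max_len: int) -> List[List[int]]:
    """heads[k] is the 1-based head of token k+1 (0 = root); returns path lengths."""
    n = len(heads)
    # Undirected adjacency list over 0-based token indices
    adj: List[List[int]] = [[] for _ in range(n)]
    for dep, head in enumerate(heads):
        if head != 0:
            adj[dep].append(head - 1)
            adj[head - 1].append(dep)

    # -1 marks padding (and would mark unreachable pairs, which a tree has none of)
    dist = [[-1] * max_len for _ in range(max_len)]
    for src in range(n):
        dist[src][src] = 0
        queue = deque([src])
        while queue:
            u = queue.popleft()
            for v in adj[u]:
                if dist[src][v] == -1:
                    dist[src][v] = dist[src][u] + 1
                    queue.append(v)
    return dist
```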

Please fill in the gap in the following method. I recommend having a good look at Hewitt's blog post about these node distances.

In [5]:
from tqdm import tqdm
from torch import Tensor


def create_gold_distances(corpus, max_sen_len):
    all_distances = []

    for item in tqdm(corpus):
        tokentree = item.to_tree()
        ete_tree = tokentree_to_ete(tokentree)
        
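        nodes = ete_tree.search_nodes()  # all nodes of the tree, i.e. all tokens
        sen_len = len(nodes)
        # -1 marks padding positions, which L1DistanceLoss ignores
        distances = torch.full((max_sen_len, max_sen_len), -1.0, dtype=torch.float, device=device)

        # Your code for computing all the distances comes here.
        # Tree distances are symmetric, so compute each pair once and fill both entries.
        # Node names are 1-based token ids, matrix indices are 0-based.
        for i in range(sen_len):
            node1 = nodes[i]
            for j in range(i, sen_len):
                node2 = nodes[j]
                distance = node1.get_distance(node2)
                distances[int(node1.name)-1][int(node2.name)-1] = distance
                distances[int(node2.name)-1][int(node1.name)-1] = distance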

        all_distances.append(distances)

    return torch.stack(all_distances)

The next step is to do the reverse. After all, we are mainly interested in predicting the node distances of a sentence, in order to recreate the corresponding parse tree.

Hewitt and Manning reconstruct a parse tree by taking the _minimum spanning tree_ (MST, https://en.wikipedia.org/wiki/Minimum_spanning_tree) of the distances. Fortunately for us, we can simply import a method from `scipy` that computes this MST.
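
Why does an MST recover the tree? In a gold distance matrix every tree edge has distance 1 and every other pair has distance at least 2, so the only spanning tree of total weight n-1 is the parse tree itself. The sketch below uses Prim's algorithm in plain Python on a dense matrix; it is an illustration only and not necessarily the algorithm `scipy` uses internally (`prim_mst` is a hypothetical helper).

```python
from typing import List, Set, Tuple


def prim_mst(dist: List[List[float]]) -> Set[Tuple[int, int]]:
    """Return the MST edges of a dense symmetric distance matrix as sorted pairs."""
    n = len(dist)
    in_tree = [False] * n
    best = [float("inf")] * n  # cheapest known connection of each node to the tree
    parent = [-1] * n
    if n:
        best[0] = 0.0
    edges = set()
    for _ in range(n):
        # Pick the cheapest node not yet in the tree
        u = min((k for k in range(n) if not in_tree[k]), key=lambda k: best[k])
        in_tree[u] = True
        if parent[u] != -1:
            edges.add((min(u, parent[u]), max(u, parent[u])))
        # Relax the connections of the remaining nodes through u
        for v in range(n):
            if not in_tree[v] and dist[u][v] < best[v]:
                best[v] = dist[u][v]
                parent[v] = u
    return edges
```

On gold distances this returns the tree edges; on predicted distances it returns whichever tree the probe's distances imply.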

In [6]:
from scipy.sparse.csgraph import minimum_spanning_tree
import torch


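def create_mst(distances):
    # Keep only the upper triangle, so that each undirected edge is listed once.
    # Note that scipy treats 0 entries of a dense matrix as missing edges.
    distances = torch.triu(distances).detach().cpu().numpy()
    
    mst = minimum_spanning_tree(distances).toarray()
    # Turn the edge weights into a binary adjacency matrix
    mst[mst>0] = 1.
    
    return mst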

Let's see what this looks like for a relatively short sentence in the sample corpus.

If your addition to the `create_gold_distances` method is correct, you should be able to run the following snippet. It shows the original parse tree, the distances between the nodes, and the MST retrieved from these distances. Can you spot the edges in the MST matrix that correspond to the edges in the parse tree?

Now that we are able to map node distances back to parse trees, we can create code for our quantitative evaluation. For this we will use the Undirected Unlabeled Attachment Score (UUAS), which is expressed as:

$$\frac{\text{number of predicted edges that are an edge in the gold parse tree}}{\text{number of edges in the gold parse tree}}$$

To do this, we need to obtain all the edges from our MST matrix. Note that, since we are using undirected trees, an edge can be expressed in 2 ways: an edge between node $i$ and node $j$ can be denoted by either `mst[i,j] = 1` or `mst[j,i] = 1`.

You will write code that computes the UUAS score for a matrix of predicted distances and the corresponding gold distances. I recommend splitting this into 2 methods: one that retrieves the edges present in an MST matrix, and a general method that computes the UUAS score.
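
The same two-step split can be sketched on plain edge lists, independent of the MST matrix format: normalise each edge to a sorted pair so that (i, j) and (j, i) count once, then measure the overlap with the gold edges. `undirected` and `uuas` are hypothetical helpers for illustration.

```python
from typing import Iterable, Set, Tuple

Edge = Tuple[int, int]


def undirected(edges: Iterable[Edge]) -> Set[Edge]:
    # (i, j) and (j, i) denote the same undirected edge
    return {(min(i, j), max(i, j)) for i, j in edges}


def uuas(pred_edges: Iterable[Edge], gold_edges: Iterable[Edge]) -> float:
    pred, gold = undirected(pred_edges), undirected(gold_edges)
    if not gold:  # single-word sentence: there are no edges to get right
        return 1.0 if not pred else 0.0
    return len(pred & gold) / len(gold)
```

For gold edges (0, 1), (1, 2), (1, 3) and predicted edges (0, 1), (1, 2), (2, 3), two of the three gold edges are found, so the UUAS is 2/3.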

In [7]:
def edges(mst):
    # Your code for retrieving the edges from the MST matrix
    # Every edge is stored in both directions, since (i, j) and (j, i) denote the
    # same undirected edge. Doing this for both MSTs leaves the UUAS ratio unaffected.
    edge_set = set()
    for i in range(len(mst)):
        for j in (mst[i] == 1).nonzero()[0]:
            edge_set.update([(i, int(j)), (int(j), i)])

    return edge_set


def calc_uuas(pred_distances, gold_distances):
    # Your code for computing the UUAS score
    # Recover both trees as the MSTs of their distance matrices
    pred_mst = create_mst(pred_distances)
    gold_mst = create_mst(gold_distances)

    pred_edges = edges(pred_mst)
    gold_edges = edges(gold_mst)

    correct = sum(edge in pred_edges for edge in gold_edges)
    total_gold = len(gold_edges)
    if total_gold != 0:
        uuas = correct / total_gold
    elif len(pred_edges) == 0:
        # Single-word sentence: no edges to predict, and none were predicted
        uuas = 1
    else:
        uuas = 0

    return uuas

We now have everything in place to start doing the actual exciting stuff: training our structural probe!
    
To make life easier for you, we will simply take the `torch` code for this probe from John Hewitt's repository. This allows you to focus on the training regime from now on.

In [8]:
import torch.nn as nn
import torch


class StructuralProbe(nn.Module):
    """ Computes squared L2 distance after projection by a matrix.
    For a batch of sentences, computes all n^2 pairs of distances
    for each sentence in the batch.
    """
    def __init__(self, model_dim, rank, device=device):
        super().__init__()
        self.probe_rank = rank
        self.model_dim = model_dim
        
        self.proj = nn.Parameter(data = torch.zeros(self.model_dim, self.probe_rank))
        
        nn.init.uniform_(self.proj, -0.05, 0.05)
        self.to(device)

    def forward(self, batch):
        """ Computes all n^2 pairs of distances after projection
        for each sentence in a batch.
        Note that due to padding, some distances will be non-zero for pads.
        Computes (B(h_i-h_j))^T(B(h_i-h_j)) for all i,j
        Args:
          batch: a batch of word representations of the shape
            (batch_size, max_seq_len, representation_dim)
        Returns:
          A tensor of distances of shape (batch_size, max_seq_len, max_seq_len)
        """
        transformed = torch.matmul(batch, self.proj)
        
        batchlen, seqlen, rank = transformed.size()
        
        transformed = transformed.unsqueeze(2)
        transformed = transformed.expand(-1, -1, seqlen, -1)
        transposed = transformed.transpose(1,2)
        
        diffs = transformed - transposed
        
        squared_diffs = diffs.pow(2)
        squared_distances = torch.sum(squared_diffs, -1)

        return squared_distances

    
class L1DistanceLoss(nn.Module):
    """Custom L1 loss for distance matrices."""
    def __init__(self):
        super().__init__()

    def forward(self, predictions, label_batch, length_batch):
        """ Computes L1 loss on distance matrices.
        Ignores all entries where label_batch=-1
        Normalizes first within sentences (by dividing by the square of the sentence length)
        and then across the batch.
        Args:
          predictions: A pytorch batch of predicted distances
          label_batch: A pytorch batch of true distances
          length_batch: A pytorch batch of sentence lengths
        Returns:
          A tuple of:
            batch_loss: average loss in the batch
            total_sents: number of sentences in the batch
        """
        labels_1s = (label_batch != -1).float()
        predictions_masked = predictions * labels_1s
        labels_masked = label_batch * labels_1s
        total_sents = torch.sum((length_batch != 0)).float()
        squared_lengths = length_batch.pow(2).float()

        if total_sents > 0:
            loss_per_sent = torch.sum(torch.abs(predictions_masked - labels_masked), dim=(1,2))
            normalized_loss_per_sent = loss_per_sent / squared_lengths
            batch_loss = torch.sum(normalized_loss_per_sent) / total_sents
        
        else:
            batch_loss = torch.tensor(0.0)
        
        return batch_loss, total_sents
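
For reference, the probe and loss above compute the following, written in the notation of Hewitt and Manning. Here $B$ stands for the transpose of `self.proj`, since the code multiplies row vectors by a `(model_dim, rank)` matrix:

$$d_B(\mathbf{h}_i, \mathbf{h}_j)^2 = \left(B(\mathbf{h}_i - \mathbf{h}_j)\right)^\top \left(B(\mathbf{h}_i - \mathbf{h}_j)\right)$$

$$\mathcal{L} = \frac{1}{S} \sum_{\ell=1}^{S} \frac{1}{|s_\ell|^2} \sum_{i,j} \left| d_T(w^\ell_i, w^\ell_j) - d_B(\mathbf{h}^\ell_i, \mathbf{h}^\ell_j)^2 \right|$$

where $S$ is the number of non-empty sentences in the batch (`total_sents`), $|s_\ell|$ is the length of sentence $\ell$, $d_T$ is the gold tree distance, and the inner sum skips all padded entries (those with label -1). Note that the squared probe distance is compared directly to the unsquared tree distance.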


I have provided a rough outline for the training regime that you can use. Note that the hyperparameters I provide here only serve as an indication; you should (briefly) explore them yourself.

As can be seen in Hewitt's code above, the probe can deal with batched input. It is up to you whether to use that: a (less efficient) alternative still incorporates batches by doing multiple forward passes per batch and computing the backward pass only once, on the summed losses of all these forward passes. (_I know, this is not the way to go, but in the interest of time it is allowed ;-), the purpose of the assignment is writing a good paper after all_).
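
A minimal sketch of that less efficient alternative, assuming a generic `nn.Linear` stands in for the probe and a placeholder loss stands in for `L1DistanceLoss` (both hypothetical here): one forward pass per sentence, the losses summed, and a single backward pass. The second variant calls `backward()` per sentence instead; because gradients accumulate in `.grad`, it yields the same gradients while freeing each sentence's graph right away.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for the probe: any module with parameters behaves the same
probe = nn.Linear(4, 2)
sentences = [torch.randn(n, 4) for n in (3, 5, 2)]  # variable-length "sentences"


def sentence_loss(reps: torch.Tensor) -> torch.Tensor:
    # Placeholder per-sentence loss; in the notebook this would be L1DistanceLoss
    return probe(reps).abs().mean()


# Variant 1: one forward pass per sentence, summed losses, ONE backward pass
probe.zero_grad()
total = sum(sentence_loss(s) for s in sentences) / len(sentences)
total.backward()
grads_accumulated = [p.grad.clone() for p in probe.parameters()]

# Variant 2: backward per sentence; the gradients add up in .grad
probe.zero_grad()
for s in sentences:
    (sentence_loss(s) / len(sentences)).backward()
grads_stepwise = [p.grad.clone() for p in probe.parameters()]
```

In both variants, `optimizer.step()` would then follow once per batch.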

In [9]:
# FETCH SENTENCE REPRESENTATIONS (transformer or RNN)
from typing import List

import torch
from torch import Tensor
from conllu import TokenList
from torch.nn.utils.rnn import pad_sequence
from tqdm import tqdm


# Returns a tuple (reps, sen_lens):
#  - concat=False: reps has shape (num_sentences, max_sen_len, representation_size), zero-padded
#  - concat=True:  reps has shape (num_tokens_in_corpus, representation_size)
#  - sen_lens holds the number of tokens of each sentence
# Subword representations that belong to 1 token are averaged.
def fetch_sen_reps(ud_parses: List[TokenList], model, tokenizer, model_type, concat) -> Tensor:
    model.eval()  # disable dropout while extracting representations
    sen_reps = []
    sen_len = []

    with torch.no_grad():  # no gradients needed for the language model
        for sentence in tqdm(ud_parses):
            if model_type == 'TF':
                total_tokens = []
                # connected[k]:connected[k+1] spans the subword positions of word k
                connected = [0]
                for word in sentence:
                    input_ids = tokenizer.encode(word['form'])
                    total_tokens.extend(input_ids)
                    connected.append(len(total_tokens))

                input_sen = torch.tensor(total_tokens, dtype=torch.long, device=device).unsqueeze(0)
                output = model(input_sen)[0][0]  # last hidden state: (num_subwords, model_dim)

                # Average the subword representations of each word
                output_sen = torch.stack([
                    output[connected[k]:connected[k+1]].mean(dim=0)
                    for k in range(len(connected) - 1)
                ])

            elif model_type == 'RNN':
                hidden = model.init_hidden(1)
                # Prime the LSTM with an <eos> token, as at a sentence boundary
                input_ids = torch.tensor([[tokenizer['<eos>']]], dtype=torch.long, device=device)
                _, hidden = model(input_ids, hidden)

                sen = []
                for word in sentence:
                    idx = tokenizer[word['form']] if word['form'] in tokenizer else tokenizer['<unk>']
                    input_ids = torch.tensor([[idx]], dtype=torch.long, device=device)
                    # Carry the hidden state forward so that every word sees its left context
                    _, hidden = model(input_ids, hidden)
                    # hidden is (h, c); take h of the top layer
                    sen.append(hidden[0][-1].squeeze())
                output_sen = torch.stack(sen)

            else:
                raise ValueError(f"Unknown model_type: {model_type}")

            if concat:
                sen_reps.extend(output_sen)
            else:
                sen_reps.append(output_sen)

            sen_len.append(len(sentence))

    return pad_sequence(sen_reps, batch_first=True, padding_value=0), Tensor(sen_len)
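
The averaging step in `fetch_sen_reps` boils down to turning per-word subword counts into offsets and averaging each span. A dependency-free sketch with vectors as plain lists (`average_subwords` is a hypothetical helper). It assumes every word maps to at least one subword; otherwise the mean of an empty span is undefined.

```python
from itertools import accumulate
from typing import List, Sequence


def average_subwords(subword_vecs: Sequence[Sequence[float]],
                     subwords_per_word: Sequence[int]) -> List[List[float]]:
    # Offsets: word k covers subword positions bounds[k]:bounds[k+1]
    bounds = [0, *accumulate(subwords_per_word)]
    words = []
    for start, end in zip(bounds, bounds[1:]):
        span = subword_vecs[start:end]
        # Element-wise mean over the subwords of this word
        words.append([sum(col) / len(span) for col in zip(*span)])
    return words
```

For example, with three subword vectors `[1, 1]`, `[3, 3]`, `[5, 5]` where the first word has 2 subwords and the second 1, the word vectors are `[2.0, 2.0]` and `[5.0, 5.0]`.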

In [10]:
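import torch
from collections import defaultdict
from transformers import GPT2Model, GPT2Tokenizer
# The Gulordava checkpoint is a pickled RNNModel, so its class (model.py) must be importable
from model import RNNModel

model_TF = GPT2Model.from_pretrained('distilgpt2').to(device=device)
tokenizer_TF = GPT2Tokenizer.from_pretrained('distilgpt2')

model_location = 'RNN/Gulordava.pt'  # <- point this to the location of the Gulordava .pt file
# Note: since PyTorch 2.6, torch.load defaults to weights_only=True; loading a fully
# pickled model like this one may then require passing weights_only=False.
rnn = torch.load(model_location, map_location=device)
rnn.eval()  # disable dropout when extracting representations

with open('RNN/vocab.txt') as f:
    w2i = {w.strip(): i for i, w in enumerate(f)}

# Unknown words map to the <unk> index
vocab_dict = defaultdict(lambda: w2i["<unk>"])
vocab_dict.update(w2i)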



In [15]:
from torch import optim

'''
Similar to the `create_data` method of the previous notebook, I recommend using a method 
that initialises all the data of a corpus. Note that for your embeddings you can use the 
`fetch_sen_reps` method again. However, for the POS probe you concatenated all these representations into 
1 big tensor of shape (num_tokens_in_corpus, model_dim). 

The StructuralProbe expects its input to contain all the representations of 1 sentence, so I recommend
updating your `fetch_sen_reps` method in a way that makes it easy to retrieve all the representations that 
correspond to a single sentence.
''' 

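def init_corpus(path, model, tokenizer, model_type, concat=False, cutoff=None):
    """ Initialises the data of a corpus.
    
    Parameters
    ----------
    path : str
        Path to corpus location
    model : torch.nn.Module
        Language model used to compute the token representations.
    tokenizer : GPT2Tokenizer or dict
        Tokenizer (for 'TF') or vocabulary dict (for 'RNN') belonging to `model`.
    model_type : str
        Either 'TF' (transformer) or 'RNN'.
    concat : bool, optional
        Optional toggle to concatenate all the tensors
        returned by `fetch_sen_reps`.
    cutoff : int, optional
        Optional integer to "cutoff" the data in the corpus.
        This allows only a subset to be used, alleviating 
        memory usage.

    Returns
    -------
    tuple of (gold_distances, embs, sen_len)
        Gold distance matrices padded with -1, zero-padded sentence
        representations, and the length of each sentence.
    """
    corpus = parse_corpus(path)[:cutoff]

    embs, sen_len = fetch_sen_reps(corpus, model, tokenizer, model_type, concat=concat)
    # Pad the distance matrices to the same length as the padded representations
    gold_distances = create_gold_distances(corpus, embs.size(1))
    
    return gold_distances, embs, sen_len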


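# I recommend writing a method that can evaluate the UUAS & loss score for the dev (& test) corpus.
# Feel free to alter the signature of this method.
def evaluate_probe(probe, data, loss_function, batch_size=24):
    # `data` is the tuple returned by init_corpus
    gold_distances, embs, sen_lens = data
    total_loss, total_uuas, n_sents = 0.0, 0.0, 0

    probe.eval()
    # Evaluate in batches without gradients: the probe materialises a
    # (batch, seq_len, seq_len, rank) tensor, so passing the whole dev set
    # at once can exhaust GPU memory.
    with torch.no_grad():
        for i in range(0, len(embs), batch_size):
            pred = probe(embs[i:i+batch_size].to(device))
            labels = gold_distances[i:i+batch_size].to(device)
            sen_len = sen_lens[i:i+batch_size].to(device)

            # L1DistanceLoss already averages over the sentences in the batch
            loss_score, batch_sents = loss_function(pred, labels, sen_len)
            total_loss += loss_score.item() * batch_sents.item()

            # UUAS per sentence, computed on the unpadded part of the matrices only
            for j in range(len(pred)):
                n = int(sen_len[j])
                total_uuas += calc_uuas(pred[j, :n, :n], labels[j, :n, :n])
            n_sents += len(pred)
    probe.train()

    return torch.tensor(total_loss / n_sents), total_uuas / n_sents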


# Feel free to alter the signature of this method.
def train(data):
    emb_dim = 768  # hidden size of distilgpt2; use the RNN's hidden size for the LSTM
    rank = 64
    lr = 1e-3
    batch_size = 24
    epochs = 100

    # init_corpus returns (gold_distances, embeddings, sentence_lengths)
    train_labels, train_embs, train_lens = data['train']
    n_train = len(train_embs)

    probe = StructuralProbe(emb_dim, rank)
    optimizer = optim.Adam(probe.parameters(), lr=lr)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=1)
    loss_function = L1DistanceLoss()

    for epoch in range(epochs):
        print("\n---------------------------------------------------------")
        print("epoch: " + str(epoch+1))
        print("---------------------------------------------------------")

        print("\n---------------------------------------------------------")
        for batch_idx, i in enumerate(range(0, n_train, batch_size)):
            optimizer.zero_grad()

            # Probe forward pass on one batch of sentences
            pred = probe(train_embs[i:i+batch_size].to(device))
            labels = train_labels[i:i+batch_size].to(device)
            sen_len = train_lens[i:i+batch_size].to(device)

            # L1DistanceLoss already normalises by the number of sentences
            batch_loss, total_sents = loss_function(pred, labels, sen_len)

            if batch_idx % 100 == 0:
                print("train set loss: " + str(batch_loss.item()) + ", iter: " + str(i) + '/' + str(n_train))

            batch_loss.backward()
            optimizer.step()

        dev_loss, dev_uuas = evaluate_probe(probe, data['dev'], loss_function)

        print("\n---------------------------------------------------------")
        print("dev set loss: " + str(dev_loss.item()) + ", dev-uuas: " + str(dev_uuas))
        print("---------------------------------------------------------")

        # Using a scheduler is up to you, and might require some hyper param fine-tuning
        scheduler.step(dev_loss)

    test_loss, test_uuas = evaluate_probe(probe, data['test'], loss_function)

    print("\n---------------------------------------------------------")
    print("test set loss: " + str(test_loss.item()) + ", test-uuas: " + str(test_uuas))
    print("---------------------------------------------------------")

    return probe

In [12]:
data = {}
data['train'] = init_corpus('data/en_ewt-ud-train.conllu', model_TF, tokenizer_TF, 'TF', cutoff=5000)
data['dev'] = init_corpus('data/en_ewt-ud-dev.conllu', model_TF, tokenizer_TF, 'TF', cutoff=1000)
data['test'] = init_corpus('data/en_ewt-ud-test.conllu', model_TF, tokenizer_TF, 'TF', cutoff=1000)

100%|██████████| 5000/5000 [00:58<00:00, 85.87it/s]
100%|██████████| 5000/5000 [02:29<00:00, 33.55it/s] 
100%|██████████| 1000/1000 [00:11<00:00, 89.80it/s]
100%|██████████| 1000/1000 [00:25<00:00, 39.92it/s]
100%|██████████| 1000/1000 [00:11<00:00, 87.89it/s]
100%|██████████| 1000/1000 [00:24<00:00, 40.81it/s]


In [17]:
train(data)


---------------------------------------------------------
epoch: 1
---------------------------------------------------------

---------------------------------------------------------


RuntimeError: CUDA out of memory. Tried to allocate 1.34 GiB (GPU 0; 3.95 GiB total capacity; 2.69 GiB already allocated; 513.06 MiB free; 52.39 MiB cached)