# Probing Language Models for Structure

## 1. Imports <a id="imports"></a>

In [1]:
import numpy as np
import pickle
from tqdm import tqdm
import os
import gdown
from collections import defaultdict
from lstm.model import RNNModel
from typing import List, Dict, Tuple, Optional
from conllu import parse_incr, TokenList
from transformers import GPT2Tokenizer, GPT2LMHeadModel
from ete3 import Tree as EteTree
from scipy.sparse.csgraph import minimum_spanning_tree

# torch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader


  from .autonotebook import tqdm as notebook_tqdm


## 2. Models <a id="models"></a>

### Transformer
For the Transformer models you are advised to make use of the `transformers` library of Huggingface: https://github.com/huggingface/transformers
Their library is well documented, and they provide great tools to easily load in pre-trained models.

### LSTM
We will use the Gulordava LSTM from the Colorless Green RNNs paper: https://arxiv.org/pdf/1803.11138.pdf. The weigths are available at https://drive.google.com/file/d/19Lp3AM4NEPycp_IBgoHfLc_V456pmUom/view?usp=sharing. The original code is available at https://github.com/facebookresearch/colorlessgreenRNNs/blob/master/src/language_models/model.py. The code has been altered to only output the hidden states that we are interested in. For further experiments, have a look at the original code.

In [2]:
# load transformer model and tokenizer
tokenizer = GPT2Tokenizer.from_pretrained('distilgpt2')
gpt_model = GPT2LMHeadModel.from_pretrained('distilgpt2')

# load LSTM model and vocab
# run this to download the saved lstm model
lstm_path = 'lstm/state_dict.pt'  # path to saved lstm model
if not os.path.exists(lstm_path):
    lstm_model_url = 'https://drive.google.com/u/0/uc?id=19Lp3AM4NEPycp_IBgoHfLc_V456pmUom'
    gdown.download(lstm_model_url, lstm_path, quiet=False)
    
lstm_model = RNNModel('LSTM', 50001, 650, 650, 2)
lstm_model.load_state_dict(torch.load(lstm_path))

# the LSTM uses a vocab dict that maps a token to an id, instead of a tokenizer
with open('lstm/vocab.txt') as f:
    w2i = {w.strip(): i for i, w in enumerate(f)}

vocab = defaultdict(lambda: w2i["<unk>"])
vocab.update(w2i)

It is a good idea that before you move on, you try to feed some text to your LMs; and check if everything works accordingly. 

## 3. Data <a id="data"></a>

For this assignment you will train your probes on __treebank__ corpora. A treebank is a corpus that has been *parsed*, and stored in a representation that allows the parse tree to be recovered. Next to a parse tree, treebanks also often contain information about part-of-speech tags, which is exactly what we are after now.

The treebank you will use for now is part of the Universal Dependencies project. I provide a sample of this treebank as well, so you can test your setup on that before moving on to larger amounts of data.

Make sure you accustom yourself to the format that is created by the `conllu` library that parses the treebank files before moving on. For example, make sure you understand how you can access the pos tag of a token, or how to cope with the tree structure that is formed using the `to_tree()` functionality.

In [3]:
# read data
def parse_corpus(filename: str) -> List[TokenList]:
    data_file = open(filename, encoding="utf-8")
    ud_parses = list(parse_incr(data_file))
    
    return ud_parses

ud_parses = parse_corpus('data/sample/en_ewt-ud-train.conllu')

## 4. Generating Representations

We now have our data all set, our models are running and we are good to go!

The next step is now to create the model representations for the sentences in our corpora. Once we have generated these representations we can store them, and train additional diagnostic (/probing) models on top of the representations.

There are a few things you should keep in mind here. Read these carefully, as these tips will save you a lot of time in your implementation.
1. Transformer models make use of Byte-Pair Encodings (BPE), that chunk up a piece of text in subword pieces. For example, a word such as "largely" could be chunked up into "large" and "ly". We are interested in probing linguistic information on the __word__-level. Therefore, we will follow the suggestion of Hewitt et al. (2019a, footnote 4), and create the representation of a word by averaging over the representations of its subwords. So the representation of "largely" becomes the average of that of "large" and "ly".


2. Subword chunks never overlap multiple tokens. In other words, say we have a phrase like "None of the", then the tokenizer might chunk that into "No"+"ne"+" of"+" the", but __not__ into "No"+"ne o"+"f the", as those chunks overlap multiple tokens. This is great for our setup! Otherwise it would have been quite challenging to distribute the representation of a subword over the 2 tokens it belongs to.


3. **Important**: If you closely examine the provided treebank, you will notice that some tokens are split up into multiple pieces, that each have their own POS-tag. For example, in the first sentence the word "Al-Zaman" is split into "Al", "-", and "Zaman". In such cases, the conllu `TokenList` format will add the following attribute: `('misc', OrderedDict([('SpaceAfter', 'No')]))` to these tokens. Your model's tokenizer does not need to adhere to the same tokenization. E.g., "Al-Zaman" could be split into "Al-"+"Za"+"man", making it hard to match the representations with their correct pos-tag. Therefore I recommend you to not tokenize your entire sentence at once, but to do this based on the chunking of the treebank. <br /><br />
Make sure to still incoporate the spaces in a sentence though, as these are part of the BPE of the tokenizer. That is, the tokenizer uses a different token id for `"man"`, than it does for `" man"`: the former could be part of `" woman"`=`" wo`"+`"man"`, whereas the latter would be the used in case *man* occurs at the start of a word. The tokenizer for GPT-2 adds spaces at the start of a token (represented as a `Ġ` symbol). This means that you should keep track whether the previous token had the `SpaceAfter` attribute set to `'No'`: in case it did not, you should manually prepend a `" "` ahead of the token.


4. The LSTM LM does not have the issues related to subwords, but is far more restricted in its vocabulary. Make sure you keep the above points in mind though, when creating the LSTM representations. You might want to write separate functions for the LSTM, but that is up to you.


5. The huggingface transformer models don't return the hidden state by default. To achieve this you can pass `output_hidden_states=True` to a model forward pass. The hidden states are then returned for all intermediate layers as well, the latest entry in this list corresponds to the top layer.


6. **N.B.**: Make sure that when you run a sentence through your model, you do so within a `with torch.no_grad():` block, and that you have run `model.eval()` beforehand as well (to disable dropout).


7. **N.B.**: Make sure to use a token's ``["form"]`` attribute, and not the ``["lemma"]``, as the latter will stem any relevant morphological information from the token. We don't want this, because we want to feed well-formed, grammatical sentences to our model.


I would like to stress that if you feel hindered in any way by the simple code structure that is presented here, you are free to modify it :-) Just make sure it is clear to an outsider what you're doing, some helpful comments never hurt.

In [4]:
# fetch sentence representations
def fetch_sen_reps(ud_parses: List[TokenList], model, concat=True) -> torch.Tensor:
    '''
    returns sentence representations (embeddings) for a list of sentences, by first tokenizing them and then passing them through the model
    inputs:
        ud_parses: list of sentences, each sentence is a list of tokens, each token is a dictionary (conllu format)
        model: the model (encoder) to use for sentence representation, either an LSTM or a transformer (GPT2)
        tokenizer: either the GPT2 tokenizer or the LSTM vocab
        rep_size: the size of the sentence representations (embeddings)
    returns:
        sent_reps: a tensor of shape (num_tokens_in_corpus, representation_size), containing the sentence representations (embeddings) for all sentences in the corpus
    '''
    model.eval()    # set model to evaluation mode
    sent_reps = []
    for sent in tqdm(ud_parses):
        # LSTM
        if isinstance(model, RNNModel):
            # tokenize
            sent_tokenized = torch.tensor([vocab[token['form']] for token in sent if token["upostag"] != "_"])
            # get sentence representation
            with torch.no_grad():
                out_rep = model(sent_tokenized.unsqueeze(0), model.init_hidden(1)).squeeze(0)
        # GPT
        elif isinstance(model, GPT2LMHeadModel):
            token_ids, att_masks = [], []
            add_space = False   # whether to add a space before the token
            for token in sent:
                if token["upostag"] == "_": # skip invalid/multiword tokens
                    continue
                # tokenize
                tokenized = tokenizer(" " + token['form'], return_tensors='pt') if add_space else tokenizer(token['form'], return_tensors='pt')
                token_ids.append(tokenized['input_ids'][0])
                att_masks.append(tokenized['attention_mask'][0])
                # check whether to add a space before the next token
                add_space = False if token['misc'] is not None and token['misc'].get('SpaceAfter', '') == 'No' else True
                
            # get sentence representation
            out = model(input_ids=torch.hstack(token_ids), attention_mask=torch.hstack(att_masks), output_hidden_states=True).hidden_states[-1]

            # average over parts belonging to the same token
            out_rep = torch.zeros(len(token_ids), out.shape[-1])
            num_sub_tokens = 0
            for i in range(out_rep.shape[0]):
                out_rep[i] = out[i + num_sub_tokens: i + num_sub_tokens + len(token_ids[i])].mean(0)
                num_sub_tokens += len(token_ids[i]) - 1
                
        else :
            raise ValueError('model should be either an LSTM or a transformer (GPT2)')       
        sent_reps += out_rep if concat else [out_rep]
    
    # stack token representations of entire corpus
    if concat:
        sent_reps = torch.vstack(sent_reps)
    
    return sent_reps

To validate your activation extraction procedure I have set up the following assertion function as a sanity check. It compares your representation of the first sentence in the corpus against a pickled version of mine. 

For this I used `distilgpt2`.

In [5]:
# corpus = parse_corpus('data/sample/en_ewt-ud-train.conllu')
# concat = False
# emb_gpt = fetch_sen_reps(corpus, gpt_model, tokenizer, concat=concat)
# if concat:
#     print(f'shape of gpt sentence representations: {emb_gpt.shape}') # (num_tokens_in_corpus, representation_size)
# else:
#     print(f'num sentences: {len(emb_gpt)}') # num_sentences
#     print(f'shape of embeddings of 1st sentence: {emb_gpt[0].shape}') # (num_tokens_in_sentence, representation_size)

In [6]:
def error_msg(model_name, gold_embs, embs, i2w):
    with open(f'{model_name}_tokens1.pickle', 'rb') as f:
        sen_tokens = pickle.load(f)
        
    diff = torch.abs(embs - gold_embs)
    max_diff = torch.max(diff)
    avg_diff = torch.mean(diff)
    
    print(f"{model_name} embeddings don't match!")
    print(f"Max diff.: {max_diff:.4f}\nMean diff. {avg_diff:.4f}")

    print("\nCheck if your tokenization matches with the original tokenization:")
    for idx in sen_tokens.squeeze():
        if isinstance(i2w, list):
            token = i2w[idx]
        else:
            token = i2w.convert_ids_to_tokens(idx.item())
        print(f"{idx:<6} {token}")


def assert_sen_reps(model, tokenizer, lstm, vocab):
    with open('distilgpt2_emb1.pickle', 'rb') as f:
        distilgpt2_emb1 = pickle.load(f)
        
    with open('lstm_emb1.pickle', 'rb') as f:
        lstm_emb1 = pickle.load(f)
    
    corpus = parse_corpus('data/sample/en_ewt-ud-train.conllu')[:1]
    
    own_distilgpt2_emb1 = fetch_sen_reps(corpus, model, tokenizer)
    own_lstm_emb1 = fetch_sen_reps(corpus, lstm, vocab)
    
    assert distilgpt2_emb1.shape == own_distilgpt2_emb1.shape, \
        f"Distilgpt2 shape mismatch: {distilgpt2_emb1.shape} (gold) vs. {own_distilgpt2_emb1.shape} (yours)"
    assert lstm_emb1.shape == own_lstm_emb1.shape, \
        f"LSTM shape mismatch: {lstm_emb1.shape} (gold) vs. {own_lstm_emb1.shape} (yours)"

    if not torch.allclose(distilgpt2_emb1, own_distilgpt2_emb1, rtol=1e-3, atol=1e-3):
        error_msg("distilgpt2", distilgpt2_emb1, own_distilgpt2_emb1, tokenizer)
    if not torch.allclose(lstm_emb1, own_lstm_emb1, rtol=1e-3, atol=1e-3):
        error_msg("lstm", lstm_emb1, own_lstm_emb1, list(vocab.keys()))


assert_sen_reps(gpt_model, tokenizer, lstm_model, vocab)

100%|██████████| 1/1 [00:00<00:00,  9.13it/s]
100%|██████████| 1/1 [00:00<00:00, 12.79it/s]


Next, we should define a function that extracts the corresponding POS labels for each activation, which we do based on the **``"upostag"``** attribute of a token (so not the ``xpostag`` attribute). These labels will be transformed to a tensor containing the label index for each item.

In [7]:
# fetch POS tags
def fetch_pos_tags(ud_parses: List[TokenList], pos_vocab: Optional[Dict[str, int]] = None) -> Tuple[torch.Tensor, Dict[str, int]]:
    '''
    return the POS tags for all tokens in the corpus
    inputs:
        ud_parses: list of sentences, each sentence is a list of tokens, each token is a dictionary (conllu format)
        pos_vocab: a dictionary mapping POS tags to integers (optional)
    returns:
        pos_tags: a tensor of shape (num_tokens_in_corpus,) containing the POS tags for all tokens in the corpus
    '''
    if pos_vocab is None:
        pos_vocab = defaultdict(int)
        for sent in ud_parses:
            for token in sent:
                if token["upostag"] == "_": # skip invalid/multiword tokens
                    continue
                # add new POS tags to vocab 
                if token["upostag"] not in pos_vocab:
                    pos_vocab[token["upostag"]] = len(pos_vocab) + 1

    pos_tags = [torch.tensor(pos_vocab[token["upostag"]]) for sent in ud_parses for token in sent if token["upostag"] != "_"]
    pos_tags = torch.vstack(pos_tags).squeeze()

    return pos_tags, pos_vocab


Finally, we can combine all these methods to retrieve the representations (`fetch_sen_reps`) and the corresponding labels (`fetch_pos_tags`). If you are still debugging and testing your setup you can set the `use_sample` variable to `True`, and once everything works you can extract the full corpus by setting it to `False`.

The reason we pass the `train_vocab` to the data creation of the `dev` and `test` data is that we want to use the same label vocabulary across the different train/dev/test splits.

In [None]:
%%time
# create 2 tensors for a .conllu file: 1 containing the token representations, and 1 containing the (tokenized) pos_tags
def create_data(ud_parses, filename: str, lm, pos_vocab=None):
    ud_parses = parse_corpus(filename)
    sen_reps = fetch_sen_reps(ud_parses, lm)
    pos_tags, pos_vocab = fetch_pos_tags(ud_parses, pos_vocab=pos_vocab)
    
    return sen_reps, pos_tags, pos_vocab

lm = gpt_model  # language model, either `gpt_model` or `lstm_model`
# w2i = tokenizer if lm==gpt_model else vocab if lm==lstm_model else None  # word-to-index mapping, either `tokenizer` (gpt) or `vocab` (lstm)
use_sample = True   # use a small sample of the data for faster debugging

# create train/dev/test data
train_path = 'data/sample/en_ewt-ud-train.conllu' if use_sample else 'data/en_ewt-ud-train.conllu'
dev_path = 'data/sample/en_ewt-ud-dev.conllu' if use_sample else 'data/en_ewt-ud-dev.conllu'
test_path = 'data/sample/en_ewt-ud-test.conllu' if use_sample else 'data/en_ewt-ud-test.conllu'

train_x, train_y, train_vocab = create_data(ud_parses, train_path, lm) 
dev_x, dev_y, _ = create_data(ud_parses, dev_path, lm, pos_vocab=train_vocab) 
test_x, test_y, _ = create_data(ud_parses, test_path, lm, pos_vocab=train_vocab)

In [97]:
# create datasets and dataloaders
# define a custom PyTorch dataset
class MyDataset(Dataset):
    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __getitem__(self, index):
        return self.x[index], self.y[index]

    def __len__(self):
        return len(self.x)

# create custom datasets
train_data = MyDataset(train_x.detach(), train_y.detach())
val_data = MyDataset(dev_x, dev_y)
test_data = MyDataset(test_x, test_y)
print(f'size of train set: {len(train_data)}')
print(f'size of dev set: {len(val_data)}')
print(f'size of test set: {len(test_data)}')

size of train set: 2301
size of dev set: 1166
size of test set: 1153


## 5. Diagnostic Classification <a name="dc"></a>

We now have our models, our data, _and_ our representations all set! Hurray, well done. We can finally move onto the cool stuff, i.e. training the diagnostic models (DCs).

DCs are simple in their complexity on purpose. To read more about why this is the case you could already have a look at the "Designing and Interpreting Probes with Control Tasks" by Hewitt and Liang (esp. Sec. 3.2).

A simple linear model will suffice for now, don't bother with adding fancy non-linearities to it.

I am personally a fan of the `skorch` library, that provides `sklearn`-like functionalities for training `torch` models, but you are free to train your dc using whatever method you prefer.

As this is an Artificial Intelligence master and you have all done ML1 + DL, I expect you to use your train/dev/test splits correctly ;-)

In [102]:
# Diagnostic classifier/probe
# class to store training parameters
class TrainingParams:
    def __init__(self, lr=1e-3, batch_size=256, num_epochs=1000, patience=10):
        self.lr = lr
        self.batch_size = batch_size
        self.num_epochs = num_epochs
        self.patience = patience
        
def train_pos_probe(model, train_data, val_data, params, model_dir='models', print_every=10):
    # create dataloaders
    train_loader = DataLoader(train_data, batch_size=params.batch_size, shuffle=False)
    val_loader = DataLoader(val_data, batch_size=params.batch_size, shuffle=False)
    # define loss and optimizer
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=params.lr)
    val_losses, val_accs = [], []

    # training/val loop
    print('training probe...')
    for epoch in range(params.num_epochs):
        # train
        model.train()
        for train_x, train_y in train_loader:
            out = model(train_x)
            loss = criterion(out, train_y)
            optimizer.zero_grad()
            loss.backward(retain_graph=True)
            optimizer.step()
            
        # validate
        model.eval()
        val_losses_epoch, val_accs_epoch = [], []
        for val_x, val_y in val_loader:
            with torch.no_grad():
                out = model(val_x)
                loss = criterion(out, val_y)
                val_losses_epoch.append(loss.item())
                preds_val = torch.argmax(out, dim=1)
                acc = (preds_val == val_y).sum().item() / len(val_y)
                val_accs_epoch.append(acc)
                
        val_loss_epoch = np.mean(val_losses_epoch)
        val_acc_epoch = np.mean(val_accs_epoch)
        val_losses.append(val_loss_epoch)
        val_accs.append(val_acc_epoch)
        
        if epoch % print_every == 0:
            print(f'epoch: {epoch} | val loss: {val_loss_epoch:.3f} | val acc: {val_acc_epoch:.3f}')
        
        # early stopping
        if epoch >= params.patience and val_loss_epoch > val_losses[-params.patience]:
            print(f'val loss did not improve for {params.patience} epochs, stopping training')
            break
        
    # save model
    suffix = 'gpt' if isinstance(lm, GPT2LMHeadModel) else 'lstm' if isinstance(lm, RNNModel) else ''
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)
    model_path = f'{model_dir}/linear_pos_{suffix}.pt'
    torch.save(model, model_path)
        
    return model, val_losses, val_accs


In [98]:
# train linear probe
params = TrainingParams()
linear_probe_model = nn.Linear(train_x.shape[1], len(train_vocab))
linear_probe_model, _, _ = train_pos_probe(linear_probe_model, train_data, val_data, params)
# test
out_test = linear_probe_model(test_x)
preds_test = torch.argmax(out_test, dim=1)
test_acc = (preds_test == test_y).sum().item() / len(test_y)
print(f'test accuracy of linear pos probe: {test_acc:.3f}')

training probe...
epoch: 0 | val loss: 2.514 | val acc: 0.421
epoch: 10 | val loss: 1.408 | val acc: 0.773
epoch: 20 | val loss: 0.980 | val acc: 0.827
epoch: 30 | val loss: 0.783 | val acc: 0.843
epoch: 40 | val loss: 0.676 | val acc: 0.846
epoch: 50 | val loss: 0.610 | val acc: 0.854
epoch: 60 | val loss: 0.566 | val acc: 0.856
epoch: 70 | val loss: 0.535 | val acc: 0.861
epoch: 80 | val loss: 0.512 | val acc: 0.866
epoch: 90 | val loss: 0.494 | val acc: 0.868
epoch: 100 | val loss: 0.481 | val acc: 0.871
epoch: 110 | val loss: 0.470 | val acc: 0.870
epoch: 120 | val loss: 0.462 | val acc: 0.871
epoch: 130 | val loss: 0.456 | val acc: 0.871
epoch: 140 | val loss: 0.451 | val acc: 0.870
epoch: 150 | val loss: 0.447 | val acc: 0.871
epoch: 160 | val loss: 0.444 | val acc: 0.872
epoch: 170 | val loss: 0.442 | val acc: 0.871
epoch: 180 | val loss: 0.440 | val acc: 0.872
epoch: 190 | val loss: 0.439 | val acc: 0.869
epoch: 200 | val loss: 0.439 | val acc: 0.871
epoch: 210 | val loss: 0.43

## 6. Trees <a name="trees"></a>

For our gold labels, we need to recover the node distances from our parse tree. For this we will use the functionality provided by `ete3`, that allows us to compute that directly. I have provided code that transforms a `TokenTree` to a `Tree` in `ete3` format.

In [75]:
# In case you want to transform your conllu tree to an nltk.Tree, for better visualisation

def rec_tokentree_to_nltk(tokentree):
    token = tokentree.token["form"]
    tree_str = f"({token} {' '.join(rec_tokentree_to_nltk(t) for t in tokentree.children)})"
    return tree_str


def tokentree_to_nltk(tokentree):
    from nltk import Tree as NLTKTree
    tree_str = rec_tokentree_to_nltk(tokentree)
    return NLTKTree.fromstring(tree_str)

In [76]:
class FancyTree(EteTree):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, format=1, **kwargs)
        
    def __str__(self):
        return self.get_ascii(show_internal=True)
    
    def __repr__(self):
        return str(self)


def rec_tokentree_to_ete(tokentree):
    idx = str(tokentree.token["id"])
    children = tokentree.children
    if children:
        return f"({','.join(rec_tokentree_to_ete(t) for t in children)}){idx}"
    else:
        return idx
    
def tokentree_to_ete(tokentree):
    newick_str = rec_tokentree_to_ete(tokentree)

    return FancyTree(f"{newick_str};")

In [77]:
# read in a corpus and convert it to an ete3 Tree
def parse_corpus(filename):
    data_file = open(filename, encoding="utf-8")
    ud_parses = list(parse_incr(data_file))
    return ud_parses

corpus = parse_corpus('data/sample/en_ewt-ud-train.conllu')
item = corpus[0]
tokentree = item.to_tree()
ete3_tree = tokentree_to_ete(tokentree)
print(ete3_tree)


   /-2
  |
  |--3
  |
  |--4
  |
  |   /6 /-5
  |  |
  |  |   /-9
  |  |  |
  |  |  |--10
  |  |  |
  |  |  |--11
  |  |-8|
  |  |  |--12
  |-7|  |
  |  |  |--13
  |  |  |
  |  |   \15/-14
-1|  |
  |  |   /-16
  |  |  |
  |  |  |--17
  |  |  |
  |   \18   /-19
  |     |  |
  |     |  |--20
  |     |  |
  |     |  |-23/-22
  |      \21
  |        |--24
  |        |
  |        |   /-25
  |        |  |
  |         \28--26
  |           |
  |            \-27
  |
   \-29


As you can see we label a token by its token id (converted to a string). Based on these id's we are going to retrieve the node distances.

To create the true distances of a parse tree in our treebank, we are going to use the `.get_distance` method that is provided by `ete3`: http://etetoolkit.org/docs/latest/tutorial/tutorial_trees.html#working-with-branch-distances

We will store all these distances in a `torch.Tensor`.

Please fill in the gap in the following method. I recommend you to have a good look at Hewitt's blog post  about these node distances.

In [78]:
def create_gold_distances(corpus):
    all_distances = []

    for item in (corpus):
        tokentree = item.to_tree()
        ete_tree = tokentree_to_ete(tokentree)

        sen_len = len(ete_tree.search_nodes())
        distances = torch.zeros((sen_len, sen_len))

        for node1 in ete_tree.search_nodes():
            for node2 in ete_tree.search_nodes():
                distances[int(node1.name)-1][int(node2.name)-1] = node1.get_distance(node2)

        all_distances.append(distances)

    return all_distances

The next step is now to do the previous step the other way around. After all, we are mainly interested in predicting the node distances of a sentence, in order to recreate the corresponding parse tree.

Hewitt et al. reconstruct a parse tree based on a _minimum spanning tree_ (MST, https://en.wikipedia.org/wiki/Minimum_spanning_tree). Fortunately for us, we can simply import a method from `scipy` that retrieves this MST.

In [80]:
def create_mst(distances):
    distances = torch.triu(distances).detach().numpy()
    mst = minimum_spanning_tree(distances).toarray()
    mst[mst>0] = 1.
    
    return mst

Let's have a look at what this looks like, by looking at a relatively short sentence in the sample corpus.

If your addition to the `create_gold_distances` method has been correct, you should be able to run the following snippet. This then shows you the original parse tree, the distances between the nodes, and the MST that is retrieved from these distances. Can you spot the edges in the MST matrix that correspond to the edges in the parse tree?

In [81]:
item = corpus[5]
tokentree = item.to_tree()
ete3_tree = tokentree_to_ete(tokentree)
print(ete3_tree, '\n')

gold_distance = create_gold_distances(corpus[5:6])[0]
print(gold_distance, '\n')

mst = create_mst(gold_distance)
print(mst)


   /2 /-1
  |
  |--3
  |
  |--4
  |
  |   /-6
  |  |
-5|  |--7
  |-8|
  |  |   /-9
  |  |  |
  |   \12--10
  |     |
  |      \-11
  |
   \-13 

tensor([[0., 1., 3., 3., 2., 4., 4., 3., 5., 5., 5., 4., 3.],
        [1., 0., 2., 2., 1., 3., 3., 2., 4., 4., 4., 3., 2.],
        [3., 2., 0., 2., 1., 3., 3., 2., 4., 4., 4., 3., 2.],
        [3., 2., 2., 0., 1., 3., 3., 2., 4., 4., 4., 3., 2.],
        [2., 1., 1., 1., 0., 2., 2., 1., 3., 3., 3., 2., 1.],
        [4., 3., 3., 3., 2., 0., 2., 1., 3., 3., 3., 2., 3.],
        [4., 3., 3., 3., 2., 2., 0., 1., 3., 3., 3., 2., 3.],
        [3., 2., 2., 2., 1., 1., 1., 0., 2., 2., 2., 1., 2.],
        [5., 4., 4., 4., 3., 3., 3., 2., 0., 2., 2., 1., 4.],
        [5., 4., 4., 4., 3., 3., 3., 2., 2., 0., 2., 1., 4.],
        [5., 4., 4., 4., 3., 3., 3., 2., 2., 2., 0., 1., 4.],
        [4., 3., 3., 3., 2., 2., 2., 1., 1., 1., 1., 0., 3.],
        [3., 2., 2., 2., 1., 3., 3., 2., 4., 4., 4., 3., 0.]]) 

[[0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0

Now that we are able to map edge distances back to parse trees, we can create code for our quantitative evaluation. For this we will use the Undirected Unlabeled Attachment Score (UUAS), which is expressed as:

$$\frac{\text{number of predicted edges that are an edge in the gold parse tree}}{\text{number of edges in the gold parse tree}}$$

To do this, we will need to obtain all the edges from our MST matrix. Note that, since we are using undirected trees, that an edge can be expressed in 2 ways: an edge between node $i$ and node $j$ is denoted by both `mst[i,j] = 1`, or `mst[j,i] = 1`.

You will write code that computes the UUAS score for a matrix of predicted distances, and the corresponding gold distances. I recommend you to split this up into 2 methods: 1 that retrieves the edges that are present in an MST matrix, and one general method that computes the UUAS score.

In [82]:
def get_edges(mst):
    edges = np.nonzero(mst)
    edges = list(zip(edges[0], edges[1]))
    edges = set(map(lambda x: tuple(sorted(x)), edges))
    return edges


def calc_uuas(pred_distances, gold_distances):
    # gold_distances = gold_distances[gold_distances[0,:] != -1]
    # valid_cols = [col_idx for col_idx, col in enumerate(torch.split(gold_distances, 1, dim=1)) if not torch.all(col == -1)]
    # gold_distances = gold_distances[:, valid_cols]
    # sen_len = gold_distances.shape[0]
    # pred_distances = pred_distances[:sen_len,:sen_len]
    # gold_mst = create_mst(gold_distances)
    # pred_mst = create_mst(pred_distances)
    # pred_edges = get_edges(pred_mst)
    # gold_edges = get_edges(gold_mst)
    uuas = 0
    length = 0
    for i in range(len(gold_distances)):
        l = max(torch.nonzero(gold_distances[i] != -1, as_tuple=True)[0]) + 1
        pred_mst = create_mst(pred_distances[i][:l, :l])
        gold_mst = create_mst(gold_distances[i][:l, :l])
        pred_edges = get_edges(pred_mst)
        gold_edges = get_edges(gold_mst)
        try:
            uuas += len(pred_edges.intersection(gold_edges)) / len(gold_edges)
            length += 1
        except:
            pass

    return uuas

## 7. Structural Probes <a name="structural-probes"></a>

We now have everything in place to start doing the actual exciting stuff: training our structural probe!
    
To make life easier for you, we will simply take the `torch` code for this probe from John Hewitt's repository. This allows you to focus on the training regime from now on.

In [83]:
class StructuralProbe(nn.Module):
    """ Computes squared L2 distance after projection by a matrix.
    For a batch of sentences, computes all n^2 pairs of distances
    for each sentence in the batch.
    """
    def __init__(self, model_dim, rank, device="cpu"):
        super().__init__()
        self.probe_rank = rank
        self.model_dim = model_dim
        
        self.proj = nn.Parameter(data = torch.zeros(self.model_dim, self.probe_rank))
        
        nn.init.uniform_(self.proj, -0.05, 0.05)
        self.to(device)

    def forward(self, batch):
        """ Computes all n^2 pairs of distances after projection
        for each sentence in a batch.
        Note that due to padding, some distances will be non-zero for pads.
        Computes (B(h_i-h_j))^T(B(h_i-h_j)) for all i,j
        Args:
          batch: a batch of word representations of the shape
            (batch_size, max_seq_len, representation_dim)
        Returns:
          A tensor of distances of shape (batch_size, max_seq_len, max_seq_len)
        """
        transformed = torch.matmul(batch, self.proj)
        
        batchlen, seqlen, rank = transformed.size()
        
        transformed = transformed.unsqueeze(2)
        transformed = transformed.expand(-1, -1, seqlen, -1)
        transposed = transformed.transpose(1,2)
        
        diffs = transformed - transposed
        
        squared_diffs = diffs.pow(2)
        squared_distances = torch.sum(squared_diffs, -1)

        return squared_distances

    
class L1DistanceLoss(nn.Module):
    """Custom L1 loss for distance matrices."""
    def __init__(self):
        super().__init__()

    def forward(self, predictions, label_batch, length_batch):
        """ Computes L1 loss on distance matrices.
        Ignores all entries where label_batch=-1
        Normalizes first within sentences (by dividing by the square of the sentence length)
        and then across the batch.
        Args:
          predictions: A pytorch batch of predicted distances
          label_batch: A pytorch batch of true distances
          length_batch: A pytorch batch of sentence lengths
        Returns:
          A tuple of:
            batch_loss: average loss in the batch
            total_sents: number of sentences in the batch
        """
        labels_1s = (label_batch != -1).float()
        predictions_masked = predictions * labels_1s
        labels_masked = label_batch * labels_1s
        total_sents = torch.sum((length_batch != 0)).float()
        squared_lengths = length_batch.pow(2).float()

        if total_sents > 0:
            loss_per_sent = torch.sum(torch.abs(predictions_masked - labels_masked), dim=(1,2))
            normalized_loss_per_sent = loss_per_sent / squared_lengths
            batch_loss = torch.sum(normalized_loss_per_sent) / total_sents
        
        else:
            batch_loss = torch.tensor(0.0)
        
        return batch_loss, total_sents


I have provided a rough outline for the training regime that you can use. Note that the hyper parameters that I provide here only serve as an indication, but should be (briefly) explored by yourself.

As can be seen in Hewitt's code above, there exists functionality in the probe to deal with batched input. It is up to you to use that: a (less efficient) method can still incorporate batches by doing multiple forward passes for a batch and computing the backward pass only once for the summed losses of all these forward passes. (_I know, this is not the way to go, but in the interest of time that is allowed ;-), the purpose of the assignment is writing a good paper after all_).

In [99]:
'''
Similar to the `create_data` method of the previous notebook, I recommend you to use a method 
that initialises all the data of a corpus. Note that for your embeddings you can use the 
`fetch_sen_reps` method again. However, for the POS probe you concatenated all these representations into 
1 big tensor of shape (num_tokens_in_corpus, model_dim). 

The StructuralProbe expects its input to contain all the representations of 1 sentence, so I recommend you
to update your `fetch_sen_reps` method in a way that it is easy to retrieve all the representations that 
correspond to a single sentence.
''' 

def init_corpus(path, model, concat=False, cutoff=None):
    """ Initialises the data of a corpus.
    
    Parameters
    ----------
    path : str
        Path to corpus location
    model: language model to encode sentences, either LSTM or GPT2
    concat : bool, optional
        Optional toggle to concatenate all the tensors
        returned by `fetch_sen_reps`.
    cutoff : int, optional
        Optional integer to "cutoff" the data in the corpus.
        This allows only a subset to be used, alleviating 
        memory usage.
    """
    corpus = parse_corpus(path)[:cutoff]
    embs = fetch_sen_reps(corpus, model, concat=concat)    
    gold_distances = create_gold_distances(corpus)
    
    return embs, gold_distances

# create data for structural probe
lm = lstm_model # language model, either `gpt_model` or `lstm_model`
train_x, train_y = init_corpus(train_path, lm)
val_x, val_y = init_corpus(dev_path, lm)
test_x, test_y = init_corpus(test_path, lm)
train_data = MyDataset(train_x, train_y)
val_data = MyDataset(val_x, val_y)
test_data = MyDataset(test_x, test_y)

print(f'size of train set: {len(train_data)}')
print(f'size of dev set: {len(val_data)}')
print(f'size of test set: {len(test_data)}')

size of train set: 90
size of dev set: 50
size of test set: 50


In [100]:
# Feel free to alter the signature of this method.
def evaluate_probe(model, dataloader, loss_fn):
    # global TREE
    model.eval()
    loss = 0
    uuas = 0
    # spearman = 0
    with torch.no_grad():
      for x, gold_distances, length in dataloader:
          preds = model(x)
          loss += loss_fn(preds, gold_distances, length)[0]
          uuas += calc_uuas(preds, gold_distances)
        #   spearman += calc_spearman(preds, gold_distances)
          # Code to create a latex tree
        #   if i == 13 and TREE == False and test == True:
        #     item = 13
        #     words = get_words(f'data/{path}-test.conllu', item)

        #     l = max(torch.nonzero(gold_distances[i - 1]!=-1, as_tuple=True)[0]) +1
        #     pred_mst = create_mst(preds[i - 1][:l,:l])
        #     gold_mst = create_mst(gold_distances[i - 1][:l,:l])
        #     pred_edges = edges(pred_mst)
        #     gold_edges = edges(gold_mst)
        #     print_tikz(pred_edges, gold_edges, words, model_type)
        #     TREE = True
    loss /= len(dataloader)
    uuas /= len(dataloader)
    # spearman /= len(dataloader)

    return loss, uuas

def pad_collate_fn(batch):
    max_length = max([len(x[1]) for x in batch])
    out_labels = torch.full((len(batch), max_length, max_length), -1)
    out_lengths = torch.zeros(len(batch))
    for i, x in enumerate(batch):
      out_labels[i, :x[1].shape[0], :x[1].shape[1]] = x[1]
      out_lengths[i] = x[1].shape[0]
      if len(x[0].shape) == 1:
        batch[i] = (x[0].unsqueeze(0), x[1])
    return torch.nn.utils.rnn.pad_sequence(list(map(lambda x: x[0].detach(), batch)), batch_first = True, padding_value=-1), out_labels, out_lengths

def train_structural_probe(model, train_data, val_data, params, model_dir='models', print_every=10):
    # create dataloaders
    train_loader = DataLoader(train_data, batch_size=params.batch_size, shuffle=True, collate_fn=pad_collate_fn)
    val_loader = DataLoader(val_data, batch_size=params.batch_size, shuffle=False, collate_fn=pad_collate_fn)
    optimizer = torch.optim.Adam(model.parameters(), lr=params.lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=1)
    criterion =  L1DistanceLoss()
    val_losses, val_uuas = [], []

    # training/val loop
    print('training probe...')
    for epoch in range(params.num_epochs):
        # train
        model.train()
        # for i in range(0, len(corpus), params.batch_size):
        for train_x, gold_distances, lengths in train_loader:
            out = model(train_x)
            loss = criterion(out, gold_distances, lengths)[0]
            optimizer.zero_grad()
            loss.backward(retain_graph=True)
            optimizer.step()
            
        # val
        val_loss_epoch, val_uuas_epoch = evaluate_probe(model, val_loader, criterion)
        # Using a scheduler is up to you, and might require some hyper param fine-tuning
        scheduler.step(val_loss_epoch)
        val_losses.append(val_loss_epoch)
        val_uuas.append(val_uuas_epoch)
        
        if epoch % print_every == 0:
            print(f'epoch: {epoch} | val loss: {val_loss_epoch:.3f} | val uuas: {val_uuas_epoch:.3f}')
        
        # early stopping
        if epoch >= params.patience and val_loss_epoch > val_losses[-params.patience]:
            print(f'val loss did not improve for {params.patience} epochs, stopping training')
            break
        
    # save model
    suffix = 'gpt' if isinstance(lm, GPT2LMHeadModel) else 'lstm' if isinstance(lm, RNNModel) else ''
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)
    model_path = f'{model_dir}/str_pos_{suffix}.pt'
    torch.save(model, model_path)
        
    return model, val_losses, val_uuas
    # test_loss, test_uuas = evaluate_probe(probe, _test_data)

In [101]:
# train structural probe
params = TrainingParams()
str_probe_model = StructuralProbe(train_x[0].shape[1], rank=64)
str_probe_model, _, _ = train_structural_probe(str_probe_model, train_data, val_data, params)

training probe...
epoch: 0 | val loss: 1.741 | val uuas: 12.907
epoch: 10 | val loss: 1.305 | val uuas: 13.996
epoch: 20 | val loss: 1.152 | val uuas: 18.248
epoch: 30 | val loss: 1.105 | val uuas: 19.714
epoch: 40 | val loss: 1.076 | val uuas: 20.305
epoch: 50 | val loss: 1.061 | val uuas: 21.129
epoch: 60 | val loss: 1.049 | val uuas: 21.316
epoch: 70 | val loss: 1.046 | val uuas: 21.486
epoch: 80 | val loss: 1.045 | val uuas: 21.530
epoch: 90 | val loss: 1.045 | val uuas: 21.530
epoch: 100 | val loss: 1.045 | val uuas: 21.530
epoch: 110 | val loss: 1.045 | val uuas: 21.530
val loss did not improve for 20 epochs, stopping training


In [103]:
# test
test_loader = DataLoader(test_data, batch_size=params.batch_size, shuffle=False, collate_fn=pad_collate_fn)
test_loss, test_uuas = evaluate_probe(str_probe_model, test_loader, L1DistanceLoss())
print(f'test uuas of structural pos probe: {test_uuas:.3f}, test loss: {test_loss:.3f}')

test uuas of structural pos probe: 19.525, test loss: 1.076


## LaTeX trees

For your report you might want to add some of those fancy dependency tree plots like those of Figure 2 in the Structural Probing paper. For that you can use the following code, that outputs the corresponding LaTeX markup.

**N.B.**: for the latex tikz tree the first token in a sentence has index 1 (instead of 0), so take that into account with the predicted and gold edges that you pass to the method.

In [None]:
def print_tikz(predicted_edges, gold_edges, words):
    """ Turns edge sets on word (nodes) into tikz dependency LaTeX.
    Parameters
    ----------
    predicted_edges : Set[Tuple[int, int]]
        Set (or list) of edge tuples, as predicted by your probe.
    gold_edges : Set[Tuple[int, int]]
        Set (or list) of gold edge tuples, as obtained from the treebank.
    words : List[str]
        List of strings representing the tokens in the sentence.
    """

    string = """\\begin{dependency}[hide label, edge unit distance=.5ex]
    \\begin{deptext}[column sep=0.05cm]
    """

    string += (
        "\\& ".join([x.replace("$", "\$").replace("&", "+") for x in words])
        + " \\\\\n"
    )
    string += "\\end{deptext}" + "\n"
    for i_index, j_index in gold_edges:
        string += "\\depedge[-]{{{}}}{{{}}}{{{}}}\n".format(i_index, j_index, ".")
    for i_index, j_index in predicted_edges:
        string += f"\\depedge[-,edge style={{red!60!}}, edge below]{{{i_index}}}{{{j_index}}}{{.}}\n"
    string += "\\end{dependency}\n"
    print(string)