__Probing Language Models__

This notebook serves as a start for your NLP2 assignment on probing Language Models. This notebook will become part of the contents that you will submit at the end, so make sure to keep your code (somewhat) clean :-)

__note__: This is the first time _anyone_ is doing this assignment. That's exciting! But it might well be the case that certain aspects are too unclear. Do not hesitate at all to reach to me once you get stuck, I'd be grateful to help you out.

__note 2__: This assignment is not dependent on big fancy GPUs. I run all this stuff on my own 3 year old CPU, without any Colab hassle. So it's up to you to decide how you want to run it.

# Models

For the Transformer models you are advised to make use of the `transformers` library of Huggingface: https://github.com/huggingface/transformers
Their library is well documented, and they provide great tools to easily load in pre-trained models.

In [1]:
from transformers import GPT2Model, GPT2Tokenizer
import torch
from torch import Tensor
import time
print("Imports ready")
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

Imports ready


In [2]:
transformer = GPT2Model.from_pretrained('distilgpt2', output_hidden_states=True)
print("Model ready")
tokenizer = GPT2Tokenizer.from_pretrained('distilgpt2')
print("Tokenizer ready")
# Note that some models don't return the hidden states by default.
# This can be configured by passing `output_hidden_states=True` to the `from_pretrained` method.

Model ready
Tokenizer ready


In [3]:
tokenizer.decode(tokenizer.encode("I like puffy cheeseballs"))

'I like puffy cheeseballs'

In [4]:
#
## Your code for initializing the rnn model(s)
#
# The Gulordava LSTM model can be found here: 
# https://drive.google.com/open?id=1w47WsZcZzPyBKDn83cMNd0Hb336e-_Sy
#
# N.B: I have altered the RNNModel code to only output the hidden states that you are interested in.
# If you want to do more experiments with this model you could have a look at the original code here:
# https://github.com/facebookresearch/colorlessgreenRNNs/blob/master/src/language_models/model.py
#
from collections import defaultdict
from lstm.model import RNNModel

model_location = 'lstm/gulordava.pt'
lstm = RNNModel('LSTM', 50001, 650, 650, 2)
lstm.load_state_dict(torch.load(model_location))


# This LSTM does not use a Tokenizer like the Transformers, but a Vocab dictionary that maps a token to an id.
with open('lstm/vocab.txt') as f:
    w2i = {w.strip(): i for i, w in enumerate(f)}

vocab = defaultdict(lambda: w2i["<unk>"])
vocab.update(w2i)
i2w = { w2i[k]:k for k in w2i}

In [5]:
lstm.eval()
print(type(lstm))
with torch.no_grad():
    input_sent = "My sister is 20"
    inputs = torch.tensor([vocab[z] for z in input_sent.split()]).unsqueeze(0)
    outputs = lstm(inputs, lstm.init_hidden(1), True)
    argmaxes = torch.argmax(outputs, dim=1)
    last = argmaxes[-1]
    print("Voc size", outputs.shape[-1])
    print(input_sent, i2w[last.item()])

<class 'lstm.model.RNNModel'>
Voc size 50001
My sister is 20 years


It is a good idea that before you move on, you try to feed some text to your LMs; and check if everything works accordingly. 

# Data

For this assignment you will train your probes on __treebank__ corpora. A treebank is a corpus that has been *parsed*, and stored in a representation that allows the parse tree to be recovered. Next to a parse tree, treebanks also often contain information about part-of-speech tags, which is exactly what we are after now.

The treebank you will use for now is part of the Universal Dependencies project. I provide a sample of this treebank as well, so you can test your setup on that before moving on to larger amounts of data.

Make sure you accustom yourself to the format that is created by the `conllu` library that parses the treebank files before moving on. For example, make sure you understand how you can access the pos tag of a token, or how to cope with the tree structure that is formed using the `to_tree()` functionality.

In [6]:
# READ DATA
from typing import List
from conllu import parse_incr, TokenList


# If stuff like `: str` and `-> ..` seems scary, fear not! 
# These are type hints that help you to understand what kind of argument and output is expected.
def parse_corpus(filename: str) -> List[TokenList]:
    data_file = open(filename, encoding="utf-8")

    ud_parses = list(parse_incr(data_file))
    
    return ud_parses

# Generating Representations

We now have our data all set, our models are running and we are good to go!

The next step is now to create the model representations for the sentences in our corpora. Once we have generated these representations we can store them, and train additional diagnostic (/probing) classifiers on top of the representations.

There are a few things you should keep in mind here. Read these carefully, as these tips will save you a lot of time in your implementation.
- Transformer models make use of Byte-Pair Encodings (BPE), that chunk up a piece of next in subword pieces. For example, a word such as "largely" could be chunked up into "large" and "ly". We are interested in probing linguistic information on the __word__-level. Therefore, we will follow the suggestion of Hewitt et al. (2019a, footnote 4), and create the representation of a word by averaging over the representations of its subwords. So the representation of "largely" becomes the average of that of "large" and "ly".

- Subword chunks never overlap multiple tokens. In other words, say we have a phrase like "None of the", then the tokenizer might chunk that into "No"+"ne"+" of"+" the", but __not__ into "No"+"ne o"+"f the", as those chunks overlap multiple tokens. This is great for our setup! Otherwise it would have been quite challenging to distribute the representation of a subword over the 2 tokens it belongs to.

- If you closely examine the provided treebank, you will notice that some tokens are split up into multiple pieces, that each have their own POS-tag. For example, in the first sentence the word "Al-Zaman" is split into "Al", "-", and "Zaman". In such cases, the conllu `TokenList` format will add the following attribute: `('misc', OrderedDict([('SpaceAfter', 'No')]))` to these tokens. Your model's tokenizer does not need to adhere to the same tokenization. E.g., "Al-Zaman" could be split into "Al-"+"Za"+"man", making it hard to match the representations with their correct pos-tag. Therefore I recommend you to not tokenize your entire sentence at once, but to do this based on the chunking of the treebank. Make sure to still incoporate the spaces in a sentence though, as these are part of the BPE of the tokenizer. The tokenizer for GPT-2 adds spaces at the start of a token.

- The LSTM LM does not have the issues related to subwords, but is far more restricted in its vocabulary. Make sure you keep the above points in mind though, when creating the LSTM representations. You might want to write separate functions for the LSTM, but that is up to you.

I would like to stress that if you feel hindered in any way by the simple code structure that is presented here, you are free to modify it :-) Just make sure it is clear to an outsider what you're doing, some helpful comments never hurt.

In [7]:
pos_w2i = dict()
pos_i2w = dict()
last = 0

In [11]:
from lstm.model import RNNModel


def fetch_sen_reps(ud_parses: List[TokenList], model, tokenizer, concat=True, get_pos = True) -> List[Tensor]:
    """
    Returns a list of length len(ud_parses)
    """
    if get_pos:
        global last, pos_w2i, pos_i2w
        pos_result = []
        
    model.eval()
    doing_lstm = type(model) == RNNModel
    print(f"Doing LSTM: {doing_lstm}")
    sentences_result = []
    global_words = []
    
    for sentence_nr, sentence in enumerate(ud_parses):
        sentence_words = []
        
        # First build string sentence repr with spaces and such
        for i, token in enumerate(sentence):
            if get_pos:
                postag = token['upostag']
                if postag in pos_w2i:
                    posindex = pos_w2i[postag]
                else:
                    posindex = last
                    pos_w2i[postag] = last 
                    pos_i2w[last] = postag 
                    last += 1
                pos_result.append(posindex)
            
            if token['misc'] is not None:
                # SpaceAfter = False
                next_word = token['form']
            else:
                # SpaceAfter = True
                next_word = token['form'] + ' '
            sentence_words.append(next_word)
        # Now build model representation!
        
        # Also add to global_words to retain word representations
        global_words.append(sentence_words)
        
        # In case of LSTM
        if doing_lstm:
            the_input = torch.tensor([tokenizer[z.strip()] for z in sentence_words]).unsqueeze(0)
            with torch.no_grad():
                final = model(the_input, lstm.init_hidden(1))
            final = final.squeeze(0)
            assert len(final) == len(sentence_words), "Something is wrong.."
            sentences_result.append(final)
                
        # In case of Transformer
        else:
            representation = []
            sizes = [] 
            for i,word in enumerate(sentence_words):
                if i>0 and sentence_words[i-1][-1] == ' ':
                    e = tokenizer.encode(' ' + word.strip())
                    representation += e
                    sizes.append(len(e))
                else:
                    e = tokenizer.encode(word.strip())
                    representation += e
                    sizes.append(len(e))
            the_input = torch.tensor(representation)
            with torch.no_grad():
                result = model(the_input)[0]
            decoded_repr = [tokenizer.decode(z) for z in representation]
            final_repr = []
            
            i = 0
            for size in sizes:
                to_append = torch.mean(result[i:i+size], dim=0)
                final_repr.append(to_append)
                i += size
            
            assert len(final_repr) == len(sentence_words), "Something is wrong"
            sentences_result.append(torch.stack(final_repr).squeeze(1))
           
    if concat:
      yes = torch.cat([s for s in sentences_result], dim=0)
      if get_pos: return yes, torch.tensor(pos_result), global_words
      return yes
    
    # Assume concat means structural probe, means no pos
    return [s for s in sentences_result], global_words

corpus = parse_corpus('data/sample/en_ewt-ud-train.conllu')
start = time.time()
d, d2, words = fetch_sen_reps(corpus, transformer, tokenizer)
total = time.time() - start
print("%s gpt2" % total, "(%s" % (total/len(corpus)), "per sentence on avg)" )

start = time.time()
d, d2, words = fetch_sen_reps(corpus, lstm, vocab)
total = time.time() - start
print("%s gpt2" % total, "(%s" % (total/len(corpus)), "per sentence on avg)" )

Doing LSTM: False
1.889827013015747 gpt2 (0.02099807792239719 per sentence on avg)
Doing LSTM: True
0.8601489067077637 gpt2 (0.009557210074530708 per sentence on avg)


In [12]:
# FETCH SENTENCE REPRESENTATIONS
from torch import Tensor
import pickle


# I provide the following sanity check, that compares your representations against a pickled version of mine.
# Note that I use the DistilGPT-2 LM here. For the LSTM I used 0-valued initial states.
def assert_sen_reps(model, tokenizer, lstm, vocab):
    with open('distilgpt2_emb1.pickle', 'rb') as f:
        distilgpt2_emb1 = pickle.load(f)

    with open('lstm_emb1.pickle', 'rb') as f:
        lstm_emb1 = pickle.load(f)
    
    corpus = parse_corpus('data/sample/en_ewt-ud-train.conllu')[:1]

    own_distilgpt2_emb1 = fetch_sen_reps(corpus, transformer, tokenizer)[0]
    own_lstm_emb1 = fetch_sen_reps(corpus, lstm, vocab)[0]
    
    assert distilgpt2_emb1.shape == own_distilgpt2_emb1.shape
    assert lstm_emb1.shape == own_lstm_emb1.shape
    
    assert torch.allclose(distilgpt2_emb1, own_distilgpt2_emb1,atol=1e-05), "GPT2 embeddings don't match!"
    assert torch.allclose(lstm_emb1, own_lstm_emb1,atol=1e-05), "LSTM embeddings don't match!"

    print("All is well!")
    debug = False 
    
    if debug:
        for i in range(len(lstm_emb1)):
            are_all_close = torch.allclose(lstm_emb1[i], own_lstm_emb1[i],atol=1e-05)
            print(i, are_all_close, torch.max(lstm_emb1[i] - own_lstm_emb1[i]))
            if not are_all_close:
                print("LSTM embs don't match as position " + str(i) + " " + corpus[0][i]['form'])

        for i in range(len(lstm_emb1)):
            are_all_close = torch.allclose(distilgpt2_emb1[i], own_distilgpt2_emb1[i] ,atol=1e-05)
            print(i, are_all_close, torch.max(distilgpt2_emb1[i] - own_distilgpt2_emb1[i])) 
            if not are_all_close:
                print("DistilGPT-2 embs don't match at position " + str(i) + " " + corpus[0][i]['form'])
assert_sen_reps(transformer, tokenizer, lstm, vocab)

Doing LSTM: False
Doing LSTM: True
All is well!


In [13]:
words_train, words_dev, words_test = [], [], []

In [16]:
import os
import pickle
def create_data(filename: str, lm, w2i, pos_vocab=None):
        global pos_w2i
        ud_parses = parse_corpus(filename)
        print(len(ud_parses))
        sen_reps, pos_tags, global_words = fetch_sen_reps(ud_parses, lm, w2i)
        print(sen_reps.shape)
        pos_vocab = pos_w2i

        return sen_reps, pos_tags, pos_vocab, global_words

want_to_run_big_stuff = True

if want_to_run_big_stuff:
    global words_train, words_dev, words_test
    if os.path.exists("train.pickle"):
        with open("train.pickle", "rb") as f:
            l = pickle.load(f)
            train_x = l['x']
            train_y = l['y']

        with open("dev.pickle", "rb") as f:
            l = pickle.load(f)
            dev_x, dev_y = l['x'], l['y']

        with open("test.pickle", "rb") as f:
            l = pickle.load(f)
            test_x, test_y = l['x'], l['y']
            
        # Load words
        with open(os.path.join("words","words_train.pickle"),"rb") as fp: words_train = pickle.load(fp)
        with open(os.path.join("words","words_dev.pickle"),"rb") as fp: words_dev = pickle.load(fp)
        with open(os.path.join("words","words_test.pickle"),"rb") as fp: words_test = pickle.load(fp)
            
    else:
        lm = transformer  # or `lstm`
        current_w2i = tokenizer  # or `vocab`
        use_sample = False

        train_x, train_y, train_vocab, train_words = create_data(
            os.path.join('data', 'sample' if use_sample else '', 'en_ewt-ud-train.conllu'),
            lm, 
            current_w2i
        )
        dev_x, dev_y, _, dev_words = create_data(
            os.path.join('data', 'sample' if use_sample else '', 'en_ewt-ud-dev.conllu'),
            lm, 
            current_w2i,
            pos_vocab=train_vocab
        )
        test_x, test_y,_, test_words = create_data(
            os.path.join('data', 'sample' if use_sample else '', 'en_ewt-ud-test.conllu'),
            lm,
            current_w2i,
            pos_vocab=train_vocab
        )
        
        words_train, words_dev, words_test = train_words, dev_words, test_words
        
        with open("train.pickle", "wb") as f:
            pickle.dump({"x":train_x, "y":train_y}, f)

        with open("dev.pickle", "wb") as f:
            pickle.dump({"x":dev_x, "y":dev_y}, f)

        with open("test.pickle", "wb") as f:
            pickle.dump({"x":test_x, "y":test_y}, f)
            
        # Check if we have a words directory
        if not os.path.exists("words"):
            os.makedirs("words") 
            
        # Save the words
        with open(os.path.join("words","words_train.pickle"),"wb") as fp: pickle.dump(words_train, fp)
        with open(os.path.join("words","words_dev.pickle"),"wb") as fp: pickle.dump(words_dev, fp)
        with open(os.path.join("words","words_test.pickle"),"wb") as fp: pickle.dump(words_test, fp)
        
else:
    lm = transformer  # or `lstm`
    current_w2i = tokenizer  # or `vocab`
    use_sample = True

    train_x, train_y, train_vocab = create_data(
        os.path.join('data', 'sample' if use_sample else '', 'en_ewt-ud-train.conllu'),
        lm, 
        current_w2i
    )
    dev_x, dev_y, _ = create_data(
        os.path.join('data', 'sample' if use_sample else '', 'en_ewt-ud-dev.conllu'),
        lm, 
        current_w2i,
        pos_vocab=train_vocab
    )
    # For some reason the sample test data is bad
    use_sample = False
    test_x, test_y, _ = create_data(
        os.path.join('data', 'sample' if use_sample else '', 'en_ewt-ud-test.conllu'),
        lm,
        current_w2i,
        pos_vocab=train_vocab
    )

12543
Doing LSTM: False
torch.Size([204585, 768])
2002
Doing LSTM: False
torch.Size([25148, 768])
2077
Doing LSTM: False
torch.Size([25096, 768])


In [17]:
#TODO: MAKE THIS SHIZZ WORK BETTER BECAUSE SPAGHETTI^2
flatten_train = [word for sublist in words_train for word in sublist]
flatten_dev   = [word for sublist in words_dev for word in sublist]
flatten_test  = [word for sublist in words_test for word in sublist]

In [18]:
print(len(flatten_train))

204585


# Diagnostic Classification

We now have our models, our data, _and_ our representations all set! Hurray, well done. We can finally move onto the cool stuff, i.e. training the diagnostic classifiers (DCs).

DCs are simple in their complexity on purpose. To read more about why this is the case you could already have a look at the "Designing and Interpreting Probes with Control Tasks" by Hewitt and Liang (esp. Sec. 3.2).

A simple linear classifier will suffice for now, don't bother with adding fancy non-linearities to it.

I am personally a fan of the `skorch` library, that provides `sklearn`-like functionalities for training `torch` models, but you are free to train your dc using whatever method you prefer.

As this is an Artificial Intelligence master and you have all done ML1 + DL, I expect you to use your train/dev/test splits correctly ;-)

In [19]:
import torch.utils.data as data

class TinyDataset(data.Dataset):
    def __init__(self, xs, ys):
        self.xs = xs
        self.ys = ys
        
    def __len__(self):
        return len(self.xs)

    def __getitem__(self, index):
        return self.xs[index], self.ys[index]
    
train_loader = data.DataLoader(TinyDataset(train_x, train_y), batch_size=16)
dev_loader = data.DataLoader(TinyDataset(dev_x, dev_y), batch_size=16)
test_loader = data.DataLoader(TinyDataset(test_x, test_y), batch_size=16)

In [20]:
def find_distribution(loader):
    result = defaultdict(lambda:0)
    for _, y in loader:
        result[y.item()] += 1
    return(result)

dist = find_distribution(data.DataLoader(TinyDataset(train_x, train_y), batch_size=1))
print(dist)
print("This could possibly be useful for POS control task!")

defaultdict(<function find_distribution.<locals>.<lambda> at 0x7f8fddc1a2f0>, {0: 12944, 1: 23676, 2: 12482, 3: 34740, 4: 23006, 5: 16284, 6: 17625, 7: 12396, 8: 18579, 9: 5567, 10: 3850, 11: 3996, 12: 10559, 13: 6703, 14: 847, 15: 688, 16: 643})
This could possibly be useful for POS control task!


In [21]:
# DIAGNOSTIC CLASSIFIER
import torch.nn as nn
class POSProbe(nn.Module):
    def __init__(self, repr_size, pos_size):
        super().__init__()
        self.linear = nn.Linear(repr_size, pos_size)
        
    def forward(self, x):
        return self.linear(x)
    
def eval_given_dataloader(loader, model):
    model.eval()
    correct = 0.0
    total = 0.0
    for x,y in loader:
        x = x.to(device)
        y = y.to(device)
        outputs = model(x)
        preds = torch.argmax(outputs,dim=1)
        c = torch.sum(torch.eq(preds, y))
        correct += c.item()
        total += y.shape[0]
    return correct/total
    
def train(my_model, train_loader, dev_loader, epoch_amount = 10):
    ce = nn.CrossEntropyLoss()
    optim = torch.optim.Adam(my_model.parameters())
    for i in range(epoch_amount):
        my_model.train()
        epoch_correct = 0.0
        epoch_total = 0.0
        for x,y in train_loader:
            
            x = x.to(device)
            y = y.to(device)
            outputs = my_model(x)
            preds = torch.argmax(outputs,dim=1)
            correct = torch.sum(torch.eq(preds, y))
            accuracy = correct.item()/y.shape[0]
            loss = ce(outputs, y)

            optim.zero_grad()
            loss.backward()
            optim.step()
            
            epoch_correct += correct.item()
            epoch_total += y.shape[0]
        print("Epoch",i,"accuracy", epoch_correct/epoch_total, eval_given_dataloader(dev_loader, my_model))
    
        
model = POSProbe(768, len(dist)).to(device)
train(model, train_loader, dev_loader, 10)
print("Test accuracy", eval_given_dataloader(test_loader, model))
if not want_to_run_big_stuff: print("If you didnt run big stuff test data is 2000 sentences so not fair")

Epoch 0 accuracy 0.896033433536183 0.877167170351519
Epoch 1 accuracy 0.9208886281985483 0.882853507237156
Epoch 2 accuracy 0.9262702544174792 0.8867106728169238
Epoch 3 accuracy 0.9285284844929981 0.8910847781135677
Epoch 4 accuracy 0.9299606520517144 0.8914028948624145
Epoch 5 accuracy 0.9309431287728817 0.8914028948624145
Epoch 6 accuracy 0.9318718381112985 0.8894941943693335
Epoch 7 accuracy 0.9325219346481902 0.889653252743757
Epoch 8 accuracy 0.9330156169807171 0.8913233656752028
Epoch 9 accuracy 0.9334213163232886 0.889374900588516
Test accuracy 0.8863563914568059


In [22]:
print(device)

cuda:0


# Trees

For our gold labels, we need to recover the node distances from our parse tree. For this we will use the functionality provided by `ete3`, that allows us to compute that directly. I have provided code that transforms a `TokenTree` to a `Tree` in `ete3` format.

In [23]:
# In case you want to transform your conllu tree to an nltk.Tree, for better visualisation

def rec_tokentree_to_nltk(tokentree):
    token = tokentree.token["form"]
    tree_str = f"({token} {' '.join(rec_tokentree_to_nltk(t) for t in tokentree.children)})"

    return tree_str


def tokentree_to_nltk(tokentree):
    from nltk import Tree as NLTKTree

    tree_str = rec_tokentree_to_nltk(tokentree)

    return NLTKTree.fromstring(tree_str)

In [24]:

from ete3 import Tree as EteTree


class FancyTree(EteTree):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, format=1, **kwargs)
        
    def __str__(self):
        return self.get_ascii(show_internal=True)
    
    def __repr__(self):
        return str(self)


def rec_tokentree_to_ete(tokentree):
    idx = str(tokentree.token["id"])
    children = tokentree.children
    if children:
        return f"({','.join(rec_tokentree_to_ete(t) for t in children)}){idx}"
    else:
        return idx
    
def tokentree_to_ete(tokentree):
    newick_str = rec_tokentree_to_ete(tokentree)

    return FancyTree(f"{newick_str};")

In [25]:
# Let's check if it works!
# We can read in a corpus using the code that was already provided, and convert it to an ete3 Tree.

def parse_corpus(filename):
    from conllu import parse_incr

    data_file = open(filename, encoding="utf-8")

    ud_parses = list(parse_incr(data_file))
    
    return ud_parses

corpus = parse_corpus('data/sample/en_ewt-ud-train.conllu')
item = corpus[0]
print(corpus[0])
tokentree = item.to_tree()
ete3_tree = tokentree_to_ete(tokentree)
print(ete3_tree)

TokenList<Al, -, Zaman, :, American, forces, killed, Shaikh, Abdullah, al, -, Ani, ,, the, preacher, at, the, mosque, in, the, town, of, Qaim, ,, near, the, Syrian, border, .>

   /-2
  |
  |--3
  |
  |--4
  |
  |   /6 /-5
  |  |
  |  |   /-9
  |  |  |
  |  |  |--10
  |  |  |
  |  |  |--11
  |  |-8|
  |  |  |--12
  |-7|  |
  |  |  |--13
  |  |  |
  |  |   \15/-14
-1|  |
  |  |   /-16
  |  |  |
  |  |  |--17
  |  |  |
  |   \18   /-19
  |     |  |
  |     |  |--20
  |     |  |
  |     |  |-23/-22
  |      \21
  |        |--24
  |        |
  |        |   /-25
  |        |  |
  |         \28--26
  |           |
  |            \-27
  |
   \-29


As you can see we label a token by its token id (converted to a string). Based on these id's we are going to retrieve the node distances.

To create the true distances of a parse tree in our treebank, we are going to use the `.get_distance` method that is provided by `ete3`: http://etetoolkit.org/docs/latest/tutorial/tutorial_trees.html#working-with-branch-distances

We will store all these distances in a `torch.Tensor`.

Please fill in the gap in the following method. I recommend you to have a good look at Hewitt's blog post  about these node distances.

In [26]:
def create_gold_distances(corpus):
    all_distances = []

    for item in (corpus):
        tokentree = item.to_tree()
        ete_tree = tokentree_to_ete(tokentree)

        sen_len = len(ete_tree.search_nodes())
        distances = torch.zeros((sen_len, sen_len))

        # Your code for computing all the distances comes here.
        # print(ete_tree)
        
        # Traverse tree in two directions and get all distances
        dists = []
        dists = [node.get_distance(node2) for node in ete_tree.traverse() for node2 in ete_tree.traverse()]            
            
        # Turn it into a tensor, view, append
        dists = torch.tensor(dists)
        distances = dists.view(sen_len, sen_len)
        all_distances.append(distances)

    return all_distances

In [27]:
all_distances = create_gold_distances(corpus)
all_distances[0]

tensor([[0., 1., 1., 1., 1., 1., 2., 2., 2., 3., 3., 3., 3., 3., 3., 3., 3., 3.,
         3., 4., 4., 4., 4., 4., 4., 5., 5., 5., 5.],
        [1., 0., 2., 2., 2., 2., 3., 3., 3., 4., 4., 4., 4., 4., 4., 4., 4., 4.,
         4., 5., 5., 5., 5., 5., 5., 6., 6., 6., 6.],
        [1., 2., 0., 2., 2., 2., 3., 3., 3., 4., 4., 4., 4., 4., 4., 4., 4., 4.,
         4., 5., 5., 5., 5., 5., 5., 6., 6., 6., 6.],
        [1., 2., 2., 0., 2., 2., 3., 3., 3., 4., 4., 4., 4., 4., 4., 4., 4., 4.,
         4., 5., 5., 5., 5., 5., 5., 6., 6., 6., 6.],
        [1., 2., 2., 2., 0., 2., 1., 1., 1., 2., 2., 2., 2., 2., 2., 2., 2., 2.,
         2., 3., 3., 3., 3., 3., 3., 4., 4., 4., 4.],
        [1., 2., 2., 2., 2., 0., 3., 3., 3., 4., 4., 4., 4., 4., 4., 4., 4., 4.,
         4., 5., 5., 5., 5., 5., 5., 6., 6., 6., 6.],
        [2., 3., 3., 3., 1., 3., 0., 2., 2., 1., 3., 3., 3., 3., 3., 3., 3., 3.,
         3., 4., 4., 4., 4., 4., 4., 5., 5., 5., 5.],
        [2., 3., 3., 3., 1., 3., 2., 0., 2., 3., 1., 1.

The next step is now to do the previous step the other way around. After all, we are mainly interested in predicting the node distances of a sentence, in order to recreate the corresponding parse tree.

Hewitt et al. reconstruct a parse tree based on a _minimum spanning tree_ (MST, https://en.wikipedia.org/wiki/Minimum_spanning_tree). Fortunately for us, we can simply import a method from `scipy` that retrieves this MST.

In [28]:
from scipy.sparse.csgraph import minimum_spanning_tree
import torch


def create_mst(distances):
    distances = torch.triu(distances).cpu().detach().numpy()
    
    mst = minimum_spanning_tree(distances).toarray()
    mst[mst>0] = 1.
    
    return mst

Let's have a look at what this looks like, by looking at a relatively short sentence in the sample corpus.

If your addition to the `create_gold_distances` method has been correct, you should be able to run the following snippet. This then shows you the original parse tree, the distances between the nodes, and the MST that is retrieved from these distances. Can you spot the edges in the MST matrix that correspond to the edges in the parse tree?

In [29]:
item = corpus[5]
tokentree = item.to_tree()
ete3_tree = tokentree_to_ete(tokentree)
print(ete3_tree, '\n')

gold_distance = create_gold_distances(corpus[5:6])[0]
print(gold_distance, '\n')

mst = create_mst(gold_distance)
print(mst)


   /2 /-1
  |
  |--3
  |
  |--4
  |
  |   /-6
  |  |
-5|  |--7
  |-8|
  |  |   /-9
  |  |  |
  |   \12--10
  |     |
  |      \-11
  |
   \-13 

tensor([[0., 1., 1., 1., 1., 1., 2., 2., 2., 2., 3., 3., 3.],
        [1., 0., 2., 2., 2., 2., 1., 3., 3., 3., 4., 4., 4.],
        [1., 2., 0., 2., 2., 2., 3., 3., 3., 3., 4., 4., 4.],
        [1., 2., 2., 0., 2., 2., 3., 3., 3., 3., 4., 4., 4.],
        [1., 2., 2., 2., 0., 2., 3., 1., 1., 1., 2., 2., 2.],
        [1., 2., 2., 2., 2., 0., 3., 3., 3., 3., 4., 4., 4.],
        [2., 1., 3., 3., 3., 3., 0., 4., 4., 4., 5., 5., 5.],
        [2., 3., 3., 3., 1., 3., 4., 0., 2., 2., 3., 3., 3.],
        [2., 3., 3., 3., 1., 3., 4., 2., 0., 2., 3., 3., 3.],
        [2., 3., 3., 3., 1., 3., 4., 2., 2., 0., 1., 1., 1.],
        [3., 4., 4., 4., 2., 4., 5., 3., 3., 1., 0., 2., 2.],
        [3., 4., 4., 4., 2., 4., 5., 3., 3., 1., 2., 0., 2.],
        [3., 4., 4., 4., 2., 4., 5., 3., 3., 1., 2., 2., 0.]]) 

[[0. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0.]
 [0

Now that we are able to map edge distances back to parse trees, we can create code for our quantitative evaluation. For this we will use the Undirected Unlabeled Attachment Score (UUAS), which is expressed as:

$$\frac{\text{number of predicted edges that are an edge in the gold parse tree}}{\text{number of edges in the gold parse tree}}$$

To do this, we will need to obtain all the edges from our MST matrix. Note that, since we are using undirected trees, that an edge can be expressed in 2 ways: an edge between node $i$ and node $j$ is denoted by both `mst[i,j] = 1`, or `mst[j,i] = 1`.

You will write code that computes the UUAS score for a matrix of predicted distances, and the corresponding gold distances. I recommend you to split this up into 2 methods: 1 that retrieves the edges that are present in an MST matrix, and one general method that computes the UUAS score.

In [30]:
import numpy as np

In [31]:
def edges(mst):
    edges = set()

    # Your code for retrieving the edges from the MST matrix
    locations = np.argwhere(mst == 1)

    for elem in locations:
        edges.add((elem[0], elem[1]))
        
    return edges

def calc_uuas(pred_distances, gold_distances):
    uuas = None
    
    # Get both MSTs
    pred_mst = create_mst(pred_distances)
    gold_mst = create_mst(gold_distances)
    
    # Get their edges
    pred_edges = edges(pred_mst)
    gold_edges = edges(gold_mst)

    # Calculate uuas
    uuas = np.sum([pred_edge in gold_edges for pred_edge in pred_edges]) / len(gold_edges)
    print(uuas)
    
    return uuas

In [32]:
mst_edges = edges(mst)
assert (0,1) in mst_edges

## Structural Probes

We now have everything in place to start doing the actual exciting stuff: training our structural probe!
    
To make life easier for you, we will simply take the `torch` code for this probe from John Hewitt's repository. This allows you to focus on the training regime from now on.

In [33]:
import torch.nn as nn
import torch


class StructuralProbe(nn.Module):
    """ Computes squared L2 distance after projection by a matrix.
    For a batch of sentences, computes all n^2 pairs of distances
    for each sentence in the batch.
    """
    def __init__(self, model_dim, rank, device="cpu"):
        super().__init__()
        self.probe_rank = rank
        self.model_dim = model_dim
        
        self.proj = nn.Parameter(data = torch.zeros(self.model_dim, self.probe_rank))
        
        nn.init.uniform_(self.proj, -0.05, 0.05)
        self.to(device)

    def forward(self, batch):
        """ Computes all n^2 pairs of distances after projection
        for each sentence in a batch.
        Note that due to padding, some distances will be non-zero for pads.
        Computes (B(h_i-h_j))^T(B(h_i-h_j)) for all i,j
        Args:
          batch: a batch of word representations of the shape
            (batch_size, max_seq_len, representation_dim)
        Returns:
          A tensor of distances of shape (batch_size, max_seq_len, max_seq_len)
        """
        transformed = torch.matmul(batch, self.proj)
        
        batchlen, seqlen, rank = transformed.size()
        
        transformed = transformed.unsqueeze(2)
        transformed = transformed.expand(-1, -1, seqlen, -1)
        transposed = transformed.transpose(1,2)
        
        diffs = transformed - transposed
        
        squared_diffs = diffs.pow(2)
        squared_distances = torch.sum(squared_diffs, -1)

        return squared_distances

    
class L1DistanceLoss(nn.Module):
    """Custom L1 loss for distance matrices."""
    def __init__(self):
        super().__init__()

    def forward(self, predictions, label_batch, length_batch):
        """ Computes L1 loss on distance matrices.
        Ignores all entries where label_batch=-1
        Normalizes first within sentences (by dividing by the square of the sentence length)
        and then across the batch.
        Args:
          predictions: A pytorch batch of predicted distances
          label_batch: A pytorch batch of true distances
          length_batch: A pytorch batch of sentence lengths
        Returns:
          A tuple of:
            batch_loss: average loss in the batch
            total_sents: number of sentences in the batch
        """
        labels_1s = (label_batch != -1).float()
        predictions_masked = predictions * labels_1s
        labels_masked = label_batch * labels_1s
        total_sents = torch.sum((length_batch != 0)).float()
        squared_lengths = length_batch.pow(2).float()

        if total_sents > 0:
            loss_per_sent = torch.sum(torch.abs(predictions_masked - labels_masked), dim=(1,2))
            normalized_loss_per_sent = loss_per_sent / squared_lengths
            batch_loss = torch.sum(normalized_loss_per_sent) / total_sents
        
        else:
            batch_loss = torch.tensor(0.0)
        
        return batch_loss, total_sents


I have provided a rough outline for the training regime that you can use. Note that the hyper parameters that I provide here only serve as an indication, but should be (briefly) explored by yourself.

As can be seen in Hewitt's code above, there exists functionality in the probe to deal with batched input. It is up to you to use that: a (less efficient) method can still incorporate batches by doing multiple forward passes for a batch and computing the backward pass only once for the summed losses of all these forward passes. (_I know, this is not the way to go, but in the interest of time that is allowed ;-), the purpose of the assignment is writing a good paper after all_).

In [34]:

'''
Similar to the `create_data` method of the previous notebook, I recommend you to use a method 
that initialises all the data of a corpus. Note that for your embeddings you can use the 
`fetch_sen_reps` method again. However, for the POS probe you concatenated all these representations into 
1 big tensor of shape (num_tokens_in_corpus, model_dim). 

The StructuralProbe expects its input to contain all the representations of 1 sentence, so I recommend you
to update your `fetch_sen_reps` method in a way that it is easy to retrieve all the representations that 
correspond to a single sentence.
''' 

def init_corpus(path, concat=False, cutoff=None):
    """ Initialises the data of a corpus.
    
    Parameters
    ----------
    path : str
        Path to corpus location
    concat : bool, optional
        Optional toggle to concatenate all the tensors
        returned by `fetch_sen_reps`.
    cutoff : int, optional
        Optional integer to "cutoff" the data in the corpus.
        This allows only a subset to be used, alleviating 
        memory usage.
    """
    corpus = parse_corpus(path)[:cutoff]

    embs = fetch_sen_reps(corpus, transformer, tokenizer, concat=concat)    
    gold_distances = create_gold_distances(corpus)
    
    return gold_distances, embs

In [35]:
# Check if we have corpus folder
corpus_dir = "corpus"
if not os.path.exists(corpus_dir):
    os.makedirs(corpus_dir)

# Paths to data
train_path = os.path.join("data", "en_ewt-ud-train.conllu")
dev_path   = os.path.join("data", "en_ewt-ud-dev.conllu")
test_path  = os.path.join("data", "en_ewt-ud-test.conllu")

# Paths to corpus 
pickled_train_path = os.path.join(corpus_dir, "traincorpus.pickle")
pickled_dev_path   = os.path.join(corpus_dir, "devcorpus.pickle")
pickled_test_path  = os.path.join(corpus_dir, "testcorpus.pickle")

#
if not os.path.exists(pickled_train_path):
    
    # Make corpora
    traincorpussample = init_corpus(train_path)
    devcorpussample   = init_corpus(dev_path)
    testcorpussample  = init_corpus(test_path, cutoff=None)
    
    # Save pickles
    with open(pickled_train_path,"wb") as fp: pickle.dump(traincorpussample, fp)
    with open(pickled_dev_path, "wb")  as fp: pickle.dump(devcorpussample, fp)
    with open(pickled_test_path,"wb")  as fp: pickle.dump(testcorpussample, fp)

# We have pickles, so load corpora
else:
    print("I loaded the corpora instead because they were pickled!")
    with open(pickled_train_path,"rb") as fp: traincorpussample = pickle.load(fp)
    with open(pickled_dev_path, "rb")  as fp: devcorpussample = pickle.load(fp)
    with open(pickled_test_path,"rb")  as fp: testcorpussample = pickle.load(fp)

I loaded the corpora instead because they were pickled!


In [36]:
def pad_batch(x):
    batch_size = len(x)
    dists, embs, lengths =  list(zip(*x))
    
    max_length = max(lengths)

    padded_dists = torch.zeros((batch_size, max_length, max_length)) - 1
    padded_embs = torch.zeros((batch_size, max_length, embs[0].shape[-1])) - 1
    for i, l in enumerate(lengths):
        padded_embs[i, 0:l, :] = embs[i][0:l]
        padded_dists[i, 0:l, 0:l] = dists[i][0:l]

    return padded_dists, padded_embs, torch.tensor(lengths)

class StructuralDataset(data.Dataset):
    def __init__(self, gold_distances, embs):
        self.gold_distances = gold_distances
        self.embs = embs
        
    def __len__(self):
        return len(self.embs)

    def __getitem__(self, index):
        return self.gold_distances[index], self.embs[index], len(self.gold_distances[index])
      

batch_size = 32
train_loader = data.DataLoader(StructuralDataset(*traincorpussample), batch_size=batch_size, collate_fn= pad_batch)
dev_loader = data.DataLoader(StructuralDataset(*devcorpussample), batch_size=batch_size, collate_fn= pad_batch)
test_loader = data.DataLoader(StructuralDataset(*testcorpussample), batch_size=batch_size, collate_fn= pad_batch)

In [None]:
from torch import optim
import math

# I recommend you to write a method that can evaluate the UUAS & loss score for the dev (& test) corpus.
# Feel free to alter the signature of this method.
def evaluate_probe(probe, dataloader):
    loss_function =  L1DistanceLoss()
    probe.eval()
    total_loss = 0.0
    total_uuas = 0.0
    amt = 0.0
    for distances, embs, lengths in dataloader:
        embs = embs.to(device)
        distances = distances.to(device)
        lengths = lengths.to(device)
        amt += len(distances)
        outputs = probe(embs)
        loss = loss_function(outputs, distances, lengths)[0]
        total_loss += loss.item()
        for i in range(len(distances)):
            l = lengths[i]
            preds = outputs[i,0:l, 0:l]
            gold = distances[i,0:l, 0:l]
            u = calc_uuas(preds, gold)
            
            if math.isnan(u):
                print(preds, gold)
            # This if statement is a hack so nans don't get counted
            if u >= 0: total_uuas += u
    
    return total_loss/amt, total_uuas/amt

# Feel free to alter the signature of this method.
def train_structural(probe, dataloader, dev_dataloader,test_loader, epochs=100):
    lr = 1e-5
    batch_size = 128
    
    optimizer = optim.Adam(probe.parameters(), lr=lr)
    #scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5,patience=1)
    loss_function =  L1DistanceLoss()

    for epoch in range(epochs):
        probe.train()
        for distances, embs, lengths in dataloader:
            embs = embs.to(device)
            distances = distances.to(device)
            lengths = lengths.to(device)
            outputs = probe(embs)
            loss = loss_function(outputs, distances, lengths)[0]
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        dev_loss, dev_uuas = evaluate_probe(probe, dev_dataloader)
        print("Epoch", epoch, "Dev loss and uuas", dev_loss, dev_uuas)
        # Using a scheduler is up to you, and might require some hyper param fine-tuning
        #scheduler.step(dev_loss)

    test_loss, test_uuas = evaluate_probe(probe, test_loader)
    print("Test loss, uuas", test_loss, test_uuas)

emb_dim = 768
rank = 64
probe = StructuralProbe(emb_dim, rank).to(device)

train_structural(probe, train_loader, dev_loader, test_loader, epochs=100)



tensor([[0.]], device='cuda:0', grad_fn=<SliceBackward>) tensor([[0.]], device='cuda:0')
tensor([[0.]], device='cuda:0', grad_fn=<SliceBackward>) tensor([[0.]], device='cuda:0')
tensor([[0.]], device='cuda:0', grad_fn=<SliceBackward>) tensor([[0.]], device='cuda:0')
tensor([[0.]], device='cuda:0', grad_fn=<SliceBackward>) tensor([[0.]], device='cuda:0')
tensor([[0.]], device='cuda:0', grad_fn=<SliceBackward>) tensor([[0.]], device='cuda:0')
tensor([[0.]], device='cuda:0', grad_fn=<SliceBackward>) tensor([[0.]], device='cuda:0')
tensor([[0.]], device='cuda:0', grad_fn=<SliceBackward>) tensor([[0.]], device='cuda:0')
tensor([[0.]], device='cuda:0', grad_fn=<SliceBackward>) tensor([[0.]], device='cuda:0')
tensor([[0.]], device='cuda:0', grad_fn=<SliceBackward>) tensor([[0.]], device='cuda:0')
tensor([[0.]], device='cuda:0', grad_fn=<SliceBackward>) tensor([[0.]], device='cuda:0')
tensor([[0.]], device='cuda:0', grad_fn=<SliceBackward>) tensor([[0.]], device='cuda:0')
tensor([[0.]], device

tensor([[0.]], device='cuda:0', grad_fn=<SliceBackward>) tensor([[0.]], device='cuda:0')
tensor([[0.]], device='cuda:0', grad_fn=<SliceBackward>) tensor([[0.]], device='cuda:0')
tensor([[0.]], device='cuda:0', grad_fn=<SliceBackward>) tensor([[0.]], device='cuda:0')
tensor([[0.]], device='cuda:0', grad_fn=<SliceBackward>) tensor([[0.]], device='cuda:0')
Epoch 0 Dev loss and uuas 2.617239743441373 0.24568970808782126
tensor([[0.]], device='cuda:0', grad_fn=<SliceBackward>) tensor([[0.]], device='cuda:0')
tensor([[0.]], device='cuda:0', grad_fn=<SliceBackward>) tensor([[0.]], device='cuda:0')
tensor([[0.]], device='cuda:0', grad_fn=<SliceBackward>) tensor([[0.]], device='cuda:0')
tensor([[0.]], device='cuda:0', grad_fn=<SliceBackward>) tensor([[0.]], device='cuda:0')
tensor([[0.]], device='cuda:0', grad_fn=<SliceBackward>) tensor([[0.]], device='cuda:0')
tensor([[0.]], device='cuda:0', grad_fn=<SliceBackward>) tensor([[0.]], device='cuda:0')
tensor([[0.]], device='cuda:0', grad_fn=<Slice

tensor([[0.]], device='cuda:0', grad_fn=<SliceBackward>) tensor([[0.]], device='cuda:0')
tensor([[0.]], device='cuda:0', grad_fn=<SliceBackward>) tensor([[0.]], device='cuda:0')
tensor([[0.]], device='cuda:0', grad_fn=<SliceBackward>) tensor([[0.]], device='cuda:0')
tensor([[0.]], device='cuda:0', grad_fn=<SliceBackward>) tensor([[0.]], device='cuda:0')
Epoch 1 Dev loss and uuas 1.674490296995485 0.24799428605822238
tensor([[0.]], device='cuda:0', grad_fn=<SliceBackward>) tensor([[0.]], device='cuda:0')


# POS tag control task

In this section we will analyse the accuracy scores that we achieve on the POS tag control task. This allows us to determine the level of selectivity. In order to do so, we will sample values for a new control_train_y (instead of the regular train_y) which is randomly sampled from a distribution (dist) that is similar to the original train_y.

In [38]:
train_x.shape

torch.Size([204585, 768])

In [39]:
train_y.shape

torch.Size([204585])

In [40]:
import copy

#
def control_y(data_x, data_y, flattened, dist):
    """
    Function to create control data for the POS tagging task based on a distribution
    """
    control_dict = defaultdict(lambda:None)
    
    # Retrieve 'POS tags' and distribution values
    keys = np.array(list(dist.keys()))
    values = np.array(list(dist.values()))
    
    # Normalize to get probabilities and sample keys
    probs = values / np.sum(values)
    
    # Make dict and get a full new y based on dict
    control_dict = {word.lower():np.random.choice(keys,replace=True,p=probs) for word in flattened if not word.lower() in control_dict}
    new_y = np.array([control_dict[word.lower()] for word in flattened])
    
    return new_y, control_dict

#
def dev_test_y(control_dict_train, flattened, dist):
    """
    Create test and dev control sets based on word mappings from train dict
    """
    # Retrieve 'POS tags' and distribution values
    keys = np.array(list(dist.keys()))
    values = np.array(list(dist.values()))
    
    # Normalize to get probabilities and sample keys
    probs = values / np.sum(values)
    
    # Initialize control dict and new y
    new_y = []
    control_dict = copy.deepcopy(control_dict_train)
    
    #
    for word in flattened:
        word = word.lower()
        if word not in control_dict:
            control_dict[word] = np.random.choice(keys,replace=True,p=probs)

    new_y = np.array([control_dict[word.lower()] for word in flattened])
    
    return np.array(new_y)

In [41]:
# Retrieve the control task y's
dist = find_distribution(data.DataLoader(TinyDataset(train_x, train_y), batch_size=1))
control_y_train, control_dict_train = control_y(train_x, train_y, flatten_train, dist)

In [42]:
control_y_test = dev_test_y(control_dict_train, flatten_test, dist)
control_y_dev  = dev_test_y(control_dict_train, flatten_dev , dist)

In [43]:
# Get dataloaders with control task y
control_train_loader = data.DataLoader(TinyDataset(train_x, control_y_train), batch_size=16)
control_dev_loader = data.DataLoader(TinyDataset(dev_x, control_y_dev), batch_size=16)
control_test_loader = data.DataLoader(TinyDataset(test_x, control_y_test), batch_size=16)

In [44]:
# Train the model on control task
model = POSProbe(768, len(train_y.unique())).to(device)
train(model, control_train_loader, control_dev_loader, 10)
print("Test accuracy", eval_given_dataloader(control_test_loader, model))

Epoch 0 accuracy 0.528093457487108 0.5357086050580563
Epoch 1 accuracy 0.5578561478114231 0.5390090663273421
Epoch 2 accuracy 0.5610626390009043 0.5398043581994592
Epoch 3 accuracy 0.5624361512329838 0.5398838873866709
Epoch 4 accuracy 0.5630080406676932 0.5402815333227294
Epoch 5 accuracy 0.5632279981425813 0.540122474948306
Epoch 6 accuracy 0.5635897059901752 0.5400429457610944
Epoch 7 accuracy 0.5636239215973801 0.539844122793065
Epoch 8 accuracy 0.5637167925312218 0.53960553523143
Epoch 9 accuracy 0.563755896082313 0.5394862414506124
Test accuracy 0.5322760599298693
