# Probing Language Models for Structure

## 1. Imports <a id="imports"></a>

In [1]:
import numpy as np
import pickle
from tqdm import tqdm
import os, random
import gdown
from collections import defaultdict
from lstm.model import RNNModel
from typing import List, Dict, Tuple, Optional
from conllu import parse_incr, TokenList
from transformers import GPT2Tokenizer, GPT2LMHeadModel, RobertaTokenizer, RobertaModel, OPTModel, AutoTokenizer
from ete3 import Tree
from scipy.stats import spearmanr
from scipy.sparse.csgraph import minimum_spanning_tree

# torch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader


  from .autonotebook import tqdm as notebook_tqdm


## 2. Language models <a id="models"></a>

### Transformer
We will use the `transformers` library of Huggingface: https://github.com/huggingface/transformers

### LSTM
We will use the Gulordava LSTM from the Colorless Green RNNs paper: https://arxiv.org/pdf/1803.11138.pdf. The weigths are available at https://drive.google.com/file/d/19Lp3AM4NEPycp_IBgoHfLc_V456pmUom/view?usp=sharing. The original code is available at https://github.com/facebookresearch/colorlessgreenRNNs/blob/master/src/language_models/model.py. The code has been altered to only output the hidden states that we are interested in. For further experiments, have a look at the original code.

In [None]:
# load models and tokenizers
# LSTM
lstm_path = 'lstm/state_dict.pt'  # path to saved lstm model
if not os.path.exists(lstm_path):
    lstm_model_url = 'https://drive.google.com/u/0/uc?id=19Lp3AM4NEPycp_IBgoHfLc_V456pmUom'
    gdown.download(lstm_model_url, lstm_path, quiet=False)
lstm_model = RNNModel('LSTM', 50001, 650, 650, 2)
lstm_model.load_state_dict(torch.load(lstm_path))
# the LSTM uses a vocab dict that maps a token to an id, instead of a tokenizer
with open('lstm/vocab.txt') as f:
    w2i = {w.strip(): i for i, w in enumerate(f)}
vocab = defaultdict(lambda: w2i["<unk>"])
vocab.update(w2i)

# distilgpt2
gpt2d_model = GPT2LMHeadModel.from_pretrained('distilgpt2')
gpt_tokenizer = GPT2Tokenizer.from_pretrained('distilgpt2')

# gpt2-medium
gpt2m_model = GPT2LMHeadModel.from_pretrained('gpt2-medium')
gpt2m_tokenizer = GPT2Tokenizer.from_pretrained('gpt2-medium')

# french gpt2
gpt2fr_model = GPT2LMHeadModel.from_pretrained("antoiloui/belgpt2")
gpt2fr_tokenizer = GPT2Tokenizer.from_pretrained("antoiloui/belgpt2")

# italian gpt2
gpt2it_model = GPT2LMHeadModel.from_pretrained('LorenzoDeMattei/GePpeTto')
gpt2it_tokenizer = GPT2Tokenizer.from_pretrained(
    'LorenzoDeMattei/GePpeTto',
)

# roberta
roberta_model = RobertaModel.from_pretrained('roberta-base')
roberta_tokenizer = RobertaTokenizer.from_pretrained('roberta-base')

# opt
opt_model = OPTModel.from_pretrained('facebook/opt-125m')
opt_tokenizer = AutoTokenizer.from_pretrained('facebook/opt-125m')

## 3. PoS probing <a id="pos probe"></a>

In [32]:
# set global variables
lm = opt_model  # language model
language = 'en'
use_sample = False   # use a small sample of the data for faster debugging
lm_names = {lstm_model: 'lstm', gpt2d_model: 'gpt2d', roberta_model: 'roberta', gpt2m_model: 'gpt2m', gpt2fr_model: 'gpt2fr', gpt2it_model: 'gpt2it', opt_model: 'opt'}
lm_name = lm_names[lm]
tokenizers = {'lstm': vocab, 'gpt2d': gpt_tokenizer, 'roberta': roberta_tokenizer, 'gpt2m': gpt2m_tokenizer, 'gpt2fr': gpt2fr_tokenizer, 'gpt2it': gpt2it_tokenizer, 'opt': opt_tokenizer}
tokenizer = tokenizers[lm_name]
data_dir = f'data/sample/{lm_name}' if use_sample else f'data/{lm_name}'  # path to data
os.makedirs(data_dir, exist_ok=True)
model_dir = f'models/sample/{lm_name}/' if use_sample else f'models/{lm_name}/'  # path to models
os.makedirs(model_dir, exist_ok=True)

print(f'LM: {lm_name} | using sample: {use_sample} | data dir: {data_dir} | model dir: {model_dir}')

# print model sizes
print('model sizes (# params):')
for model in [lstm_model, gpt2d_model, roberta_model, gpt2m_model, gpt2fr_model, gpt2it_model, opt_model]:
    print(f'{lm_names[model]}: {sum(p.numel() for p in model.parameters()) / 1e6:.2f}M')

LM: opt | using sample: False | data dir: data/opt | model dir: models/opt/
model sizes (# params):
lstm: 71.82M
gpt2d: 81.91M
roberta: 124.65M
gpt2m: 354.82M
gpt2fr: 124.44M
gpt2it: 108.88M
opt: 125.24M


### 3.1 Generate data for PoS probe <a id="pos data"></a>
We will use a treebank corpus for our data

#### Generating Representations

In [37]:
# read data
def parse_corpus(filename: str) -> List[TokenList]:
    data_file = open(filename, encoding="utf-8")
    ud_parses = list(parse_incr(data_file))
    
    return ud_parses

# ud_parses_sample = parse_corpus('data/sample/en_ewt-ud-train.conllu')

# fetch sentence representations
def fetch_sen_reps(ud_parses: List[TokenList], model=lm, tokenizer=tokenizer, concat=True) -> torch.Tensor:
    '''
    returns sentence representations (embeddings) for a list of sentences, by first tokenizing them and then passing them through the model
    inputs:
        ud_parses: list of sentences, each sentence is a list of tokens, each token is a dictionary (conllu format)
        model: the language model (encoder) to use for sentence representation, either an LSTM or a transformer (GPT2)
        tokenizer: either the GPT2 tokenizer or the LSTM vocab
        rep_size: the size of the sentence representations (embeddings)
    returns:
        sent_reps: a tensor of shape (num_tokens_in_corpus, representation_size), containing the sentence representations (embeddings) for all sentences in the corpus
    '''
    model.eval()    # set model to evaluation mode
    sent_reps = []
    for sent in tqdm(ud_parses):
        # LSTM
        if model == lstm_model:
            # tokenize
            sent_tokenized = torch.tensor([tokenizer[token['form']] for token in sent if token["upostag"] != "_"])
            # get sentence representation
            with torch.no_grad():
                out_rep = model(sent_tokenized.unsqueeze(0), model.init_hidden(1)).squeeze(0)
                
        # transformers
        elif model in [gpt2d_model, gpt2m_model, gpt2fr_model, gpt2it_model, roberta_model, opt_model]:
            token_ids, att_masks = [], []
            add_space = False   # whether to add a space before the token
            for token in sent:
                if token["upostag"] == "_": # skip invalid/multiword tokens
                    continue
                # tokenize
                if model == roberta_model:
                    tokenized = tokenizer.encode_plus(" " + token['form'] if add_space else token['form'], return_tensors='pt')
                else:
                    tokenized = tokenizer(" " + token['form'], return_tensors='pt') if add_space else tokenizer(token['form'], return_tensors='pt')

                token_ids.append(tokenized['input_ids'][0])
                att_masks.append(tokenized['attention_mask'][0])
                # check whether to add a space before the next token
                add_space = False if token['misc'] is not None and token['misc'].get('SpaceAfter', '') == 'No' else True
                
            # get sentence representation
            with torch.no_grad():
                if model == roberta_model or model == opt_model:
                    out = model(input_ids=torch.hstack(token_ids).unsqueeze(0), attention_mask=torch.hstack(att_masks).unsqueeze(0), output_hidden_states=True).last_hidden_state.squeeze(0)
                else:
                    out = model(input_ids=torch.hstack(token_ids), attention_mask=torch.hstack(att_masks), output_hidden_states=True).hidden_states[-1]

            # average over parts belonging to the same token
            out_rep = torch.zeros(len(token_ids), out.shape[-1])
            num_sub_tokens = 0
            for i in range(out_rep.shape[0]):
                out_rep[i] = out[i + num_sub_tokens: i + num_sub_tokens + len(token_ids[i])].mean(0)
                num_sub_tokens += len(token_ids[i]) - 1
                
        else :
            raise ValueError('model should be one of: lstm_model, gpt2d_model, gpt2m_model, gpt2fr_model, roberta_model')       
        sent_reps += out_rep if concat else [out_rep]
    
    # stack token representations of entire corpus
    if concat:
        sent_reps = torch.vstack(sent_reps)
    
    return sent_reps

# test fetch_sen_reps
def error_msg(model_name, gold_embs, embs, i2w):
    with open(f'{model_name}_tokens1.pickle', 'rb') as f:
        sen_tokens = pickle.load(f)
        
    diff = torch.abs(embs - gold_embs)
    max_diff = torch.max(diff)
    avg_diff = torch.mean(diff)
    
    print(f"{model_name} embeddings don't match!")
    print(f"Max diff.: {max_diff:.4f}\nMean diff. {avg_diff:.4f}")

    print("\nCheck if your tokenization matches with the original tokenization:")
    for idx in sen_tokens.squeeze():
        if isinstance(i2w, list):
            token = i2w[idx]
        else:
            token = i2w.convert_ids_to_tokens(idx.item())
        print(f"{idx:<6} {token}")


def assert_sen_reps(model, tokenizer, lstm, vocab):
    with open('distilgpt2_emb1.pickle', 'rb') as f:
        distilgpt2_emb1 = pickle.load(f)
        
    with open('lstm_emb1.pickle', 'rb') as f:
        lstm_emb1 = pickle.load(f)
    
    corpus = parse_corpus('data/sample/en_ewt-ud-train.conllu')[:1]
    
    own_distilgpt2_emb1 = fetch_sen_reps(corpus, model, tokenizer)
    own_lstm_emb1 = fetch_sen_reps(corpus, lstm, vocab)
    
    assert distilgpt2_emb1.shape == own_distilgpt2_emb1.shape, \
        f"Distilgpt2 shape mismatch: {distilgpt2_emb1.shape} (gold) vs. {own_distilgpt2_emb1.shape} (yours)"
    assert lstm_emb1.shape == own_lstm_emb1.shape, \
        f"LSTM shape mismatch: {lstm_emb1.shape} (gold) vs. {own_lstm_emb1.shape} (yours)"

    if not torch.allclose(distilgpt2_emb1, own_distilgpt2_emb1, rtol=1e-3, atol=1e-3):
        error_msg("distilgpt2", distilgpt2_emb1, own_distilgpt2_emb1, tokenizer)
    if not torch.allclose(lstm_emb1, own_lstm_emb1, rtol=1e-3, atol=1e-3):
        error_msg("lstm", lstm_emb1, own_lstm_emb1, list(vocab.keys()))


assert_sen_reps(gpt2d_model, gpt_tokenizer, lstm_model, vocab)

100%|██████████| 1/1 [00:00<00:00, 28.65it/s]
100%|██████████| 1/1 [00:00<00:00, 173.43it/s]


#### Extracting PoS labels
Next, we should define a function that extracts the corresponding POS labels for each activation. These labels will be transformed to a tensor containing the label index for each item.

In [9]:
# fetch POS tags
def fetch_pos_tags(ud_parses: List[TokenList], pos_vocab: Optional[Dict[str, int]] = None) -> Tuple[torch.Tensor, Dict[str, int]]:
	'''
	return the POS tags for all tokens in the corpus
	inputs:
		ud_parses: list of sentences, each sentence is a list of tokens, each token is a dictionary (conllu format)
		pos_vocab: a dictionary mapping POS tags to integers (optional)
	returns:
		pos_tags: a tensor of shape (num_tokens_in_corpus,) containing the POS tags for all tokens in the corpus
	'''
	if pos_vocab is None:
		pos_vocab = defaultdict(int)
		for sent in ud_parses:
			for token in sent:
				# add new POS tags to vocab
				if token["upostag"] not in pos_vocab and token["upostag"] != "_":
					pos_vocab[token["upostag"]] = len(pos_vocab)

	pos_tags = [torch.tensor(pos_vocab[token["upostag"]])
							 for sent in ud_parses for token in sent if token["upostag"] != "_"]
	pos_tags = torch.vstack(pos_tags).squeeze()

	return pos_tags, pos_vocab


#### Merge representations & PoS tags
We merge sentence representations (features) and PoS tags (labels) to create dataloaders for the probe. We pass the `train_vocab` to the data creation of the `dev` and `test` data is that we want to use the same label vocabulary across the different train/val/test splits.

In [None]:
%%time
# create 2 tensors for a .conllu file: 1 containing the token representations, and 1 containing the (tokenized) pos_tags
def create_data(filename: str, lm, pos_vocab=None):
    # print('parsing corpus...')
    ud_parses = parse_corpus(filename)
    # print(f'fetching sentence representations using {lm_name}...')
    sen_reps = fetch_sen_reps(ud_parses, lm)
    # print('fetching POS tags...')
    pos_tags, pos_vocab = fetch_pos_tags(ud_parses, pos_vocab=pos_vocab)    
    return sen_reps, pos_tags, pos_vocab

# create datasets and dataloaders
# define a custom PyTorch dataset
class MyDataset(Dataset):
    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __getitem__(self, index):
        return self.x[index], self.y[index]

    def __len__(self):
        return len(self.x)

# create train/val/test data
# path to the .conllu files
if language == 'en':
    train_path = 'data/sample/en_ewt-ud-train.conllu' if use_sample else 'data/en_ewt-ud-train.conllu'
    val_path = 'data/sample/en_ewt-ud-val.conllu' if use_sample else 'data/en_ewt-ud-dev.conllu'
    test_path = 'data/sample/en_ewt-ud-test.conllu' if use_sample else 'data/en_ewt-ud-test.conllu'
elif language == 'fr':
    train_path = 'data/fr_gsd-ud-dev.conllu'
    val_path = 'data/fr_gsd-ud-dev.conllu'
    test_path = 'data/fr_gsd-ud-test.conllu'
elif language == 'it':
    train_path = 'data/it_isdt-ud-train.conllu'
    val_path = 'data/it_isdt-ud-dev.conllu'
    test_path = 'data/it_isdt-ud-test.conllu'
else:
    raise ValueError(f'language {language} not supported')

try:
    train_data_pos = torch.load(f'{data_dir}/train_data_pos.pt')
    val_data_pos = torch.load(f'{data_dir}/val_data_pos.pt')
    test_data_pos = torch.load(f'{data_dir}/test_data_pos.pt')
except FileNotFoundError:
    print(f'creating train/val/test data using {lm_name} embeddings...')
    train_x_pos, train_y_pos, train_vocab_pos = create_data(train_path, lm)
    val_x_pos, val_y_pos, _ = create_data(val_path, lm, pos_vocab=train_vocab_pos)
    test_x_pos, test_y_pos, _ = create_data(test_path, lm, pos_vocab=train_vocab_pos)
    train_data_pos = MyDataset(train_x_pos, train_y_pos)
    val_data_pos = MyDataset(val_x_pos, val_y_pos)
    test_data_pos = MyDataset(test_x_pos, test_y_pos)
    torch.save(train_vocab_pos, f'{data_dir}/train_vocab_pos.pt')
    torch.save(train_data_pos, f'{data_dir}/train_data_pos.pt')
    torch.save(val_data_pos, f'{data_dir}/val_data_pos.pt')
    torch.save(test_data_pos, f'{data_dir}/test_data_pos.pt')
    
print(f'size of train data: {len(train_data_pos)} | size of val data: {len(val_data_pos)} | size of test data: {len(test_data_pos)}')

# find long & short sentences in test set
test_corpus = parse_corpus(test_path)
avg_sen_len_test = sum([len(sen) for sen in test_corpus]) / len(test_corpus)
idxs_short_sent_test = [i for i, sen in enumerate(test_corpus) if len(sen) <= avg_sen_len_test]
idxs_long_sent_test = [i for i, sen in enumerate(test_corpus) if len(sen) > avg_sen_len_test]
print(f'the test set has an avg sentence length of {avg_sen_len_test:.2f} with {len(idxs_short_sent_test)} short sentences and {len(idxs_long_sent_test)} long sentences\n')
# print(f'examples of short sentences: {test_corpus[idxs_short_sent_test[0]].metadata["text"]}\n {test_corpus[idxs_short_sent_test[1]].metadata["text"]}')
# print(f'examples of long sentences: {test_corpus[idxs_long_sent_test[0]].metadata["text"]}\n {test_corpus[idxs_long_sent_test[1]].metadata["text"]}')


### 3.2 Train & test PoS probe <a name="dc"></a>
We will train a PoS probe using simple linear model. Refer "Designing and Interpreting Probes with Control Tasks" by Hewitt and Liang (esp. Sec. 3.2).

In [13]:
# Diagnostic classifier/probe
# class to store training parameters
class TrainingParams:
    def __init__(self, lr=1e-3, batch_size=256, num_epochs=1000, patience=10):
        self.lr = lr
        self.batch_size = batch_size
        self.num_epochs = num_epochs
        self.patience = patience
        
def set_seed(seed):
    # Set seed for random, numpy, PyTorch
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
        
def train_pos_probe(model, train_data, val_data, params, seed=42, print_every=10):
    set_seed(seed)  # set seed for reproducibility
    # create dataloaders
    train_loader = DataLoader(train_data, batch_size=params.batch_size, shuffle=True)
    val_loader = DataLoader(val_data, batch_size=params.batch_size, shuffle=False)
    # define loss and optimizer
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=params.lr)
    val_losses, val_accs = [], []

    # training/val loop
    for epoch in range(params.num_epochs):
        # train
        model.train()
        for train_x, train_y in train_loader:
            out = model(train_x)
            loss = criterion(out, train_y)
            optimizer.zero_grad()
            loss.backward(retain_graph=True)
            optimizer.step()
            
        # validate
        model.eval()
        val_losses_epoch, val_accs_epoch = [], []
        for val_x, val_y in val_loader:
            with torch.no_grad():
                out = model(val_x)
                loss = criterion(out, val_y)
                val_losses_epoch.append(loss.item())
                preds_val = torch.argmax(out, dim=1)
                acc = (preds_val == val_y).sum().item() / len(val_y)
                val_accs_epoch.append(acc)
                
        val_loss_epoch = np.mean(val_losses_epoch)
        val_acc_epoch = np.mean(val_accs_epoch)
        val_losses.append(val_loss_epoch)
        val_accs.append(val_acc_epoch)
        
        if epoch % print_every == 0:
            print(f'epoch: {epoch} | val loss: {val_loss_epoch:.3f} | val acc: {val_acc_epoch:.3f}')
        
        # early stopping
        if epoch >= params.patience and val_loss_epoch >= val_losses[-params.patience]:
            print(f'val loss did not improve for {params.patience} epochs, stopping training')
            break
        
    # save model
    # model_path = f'{model_dir}/linear_pos_probe.pt'
    # torch.save(model, model_path)
        
    return model, val_losses, val_accs


In [60]:
# train pos probe
pos_probe_type = 'linear' # type of PoS probe, either 'linear' or 'nonlinear'
try:
    pos_probe_model = torch.load(f'{model_dir}/{pos_probe_type}_pos_probe.pt')
except FileNotFoundError:
    params = TrainingParams()
    if pos_probe_type == 'linear':
        # single linear layer with input_dim = embedding_dim and output_dim = len(pos_vocab), no activation
        pos_probe_model = nn.Linear(train_data_pos.x.shape[1], len(train_vocab_pos))
    elif pos_probe_type == 'nonlinear':
        # two linear layers of shape (embedding_dim, hidden_dim) and (hidden_dim, len(pos_vocab)), with ReLU activation in between
        hidden_dim = 100
        pos_probe_model = nn.Sequential(
            nn.Linear(train_data_pos.x.shape[1], hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, len(train_vocab_pos)))
    else:
        raise ValueError('pos_probe_type must be linear or nonlinear')
    
    print(f'training {pos_probe_type} pos probe with {lm_name} embeddings...')
    pos_probe_model, _, _ = train_pos_probe(pos_probe_model, train_data_pos, val_data_pos, params)
    torch.save(pos_probe_model, f'{model_dir}/{pos_probe_type}_pos_probe.pt')

# test
out_test = pos_probe_model(test_data_pos.x)
preds_test_pos = torch.argmax(out_test, dim=1)
test_acc_pos = (preds_test_pos == test_data_pos.y).sum().item() / len(test_data_pos.y)
test_acc_pos_short = (preds_test_pos[idxs_short_sent_test] == test_data_pos.y[idxs_short_sent_test]).sum().item() / len(test_data_pos.y[idxs_short_sent_test])
test_acc_pos_long = (preds_test_pos[idxs_long_sent_test] == test_data_pos.y[idxs_long_sent_test]).sum().item() / len(test_data_pos.y[idxs_long_sent_test])
print(f'test accuracy of {pos_probe_type} pos probe using {lm_name} embeddings is {test_acc_pos:.3f} overall, ({test_acc_pos_short:.3f} for short sentences, {test_acc_pos_long:.3f} for long sentences)')

test accuracy of linear pos probe using opt embeddings is 0.923 overall, (0.926 for short sentences, 0.930 for long sentences)


In [43]:
# compute avg PoS accuracy per sentence in test set
def get_pos_acc_sent(preds_pos, y_pos):
    '''
    Compute average PoS accuracy per sentence in test set
    Inputs:
        ud_parses: list of lists of UD parse trees
        preds_pos: predicted PoS tags
        y_pos: true PoS tags
    Returns:
        accs_pos_sent: list of average PoS accuracies per sentence
    '''
    token_count = 0
    accs_pos_sent = []
    for sent in test_ud_parses:
        acc_pos_sent = (preds_pos[token_count:token_count+len(sent)] == y_pos[token_count:token_count+len(sent)]).sum().item() / len(sent)
        accs_pos_sent.append(acc_pos_sent)
        token_count += len(sent)
    return accs_pos_sent

test_ud_parses = parse_corpus(test_path)
test_accs_pos_sent = get_pos_acc_sent(preds_test_pos, test_y_pos)
print(f'avg PoS accuracy of 1st 5 sentences in test set using {lm_name} embeddings: {test_accs_pos_sent[:5]}')

avg PoS accuracy of 1st 5 sentences in test set using opt embeddings: [0.8571428571428571, 0.9565217391304348, 0.8888888888888888, 0.92, 0.967741935483871]


### 3.3 Control tasks for PoS probe <a name="control-tasks-pos"></a>
We will train a control task to check if the probe is actually probing the linguistic information. We will use the same model as the probe, but we will train it to predict a random label for each input. If the probe is actually probing the linguistic information, it should perform better than the control task.

In [44]:
def fetch_pos_control_labels(corpus_path: str, control_vocab=None,  len_pos_vocab: int=None) -> torch.Tensor:
	'''
	Generate control task labels for each token in the corpus.
	Inputs:
		ud_parses: list of sentences, each sentence is a list of tokens, each token is a dictionary (conllu format)
		len_pos_vocab: length of the pos_vocab dictionary
		control_vocab: a dictionary mapping tokens to control labels (optional)
	Returns:
		control_labels: a tensor of shape (num_tokens_in_corpus,) containing the control task labels for all tokens in the corpus
		control_vocab: a dictionary mapping tokens to control labels
	'''
	ud_parses = parse_corpus(corpus_path)
	if not control_vocab:
		control_vocab = defaultdict(int)
		for sent in tqdm(ud_parses):
			for token in sent:
				if token["upostag"] == "_":
					continue
				if token['form'] not in control_vocab:
					control_vocab[token['form']] = np.random.randint(len_pos_vocab)

	control_labels = [torch.tensor(control_vocab[token['form']]) for sent in ud_parses for token in sent if token["upostag"] != "_"]
	control_labels = torch.vstack(control_labels).squeeze()
	
	return control_labels, control_vocab

# create data for control task
train_y_pos_control, train_vocab_pos_control = fetch_pos_control_labels(train_path, None, len(train_vocab_pos))
val_y_pos_control, _ = fetch_pos_control_labels(val_path, train_vocab_pos_control)
test_y_pos_control, _ = fetch_pos_control_labels(test_path, train_vocab_pos_control)
train_data_pos_control = MyDataset(train_data_pos.x, train_y_pos_control)
val_data_pos_control = MyDataset(val_data_pos.x, val_y_pos_control)
test_data_pos_control = MyDataset(test_data_pos.x, test_y_pos_control)

# train control probe
try:
	pos_control_probe_model = torch.load(f'{model_dir}/pos_control_probe.pt')
except FileNotFoundError:
	params = TrainingParams()
	# single linear layer with input_dim = embedding_dim and output_dim = len(pos_vocab), no activation
	pos_control_probe_model = nn.Linear(train_data_pos.x.shape[1], len(train_vocab_pos_control))
	print(f'training control probe with {lm_name} embeddings...')
	pos_control_probe_model, _, _ = train_pos_probe(pos_control_probe_model, train_data_pos_control, val_data_pos_control, params)
	torch.save(pos_control_probe_model, f'{model_dir}/pos_control_probe.pt')

# test
out_test = pos_control_probe_model(test_data_pos.x)
preds_test_pos_control = torch.argmax(out_test, dim=1)
test_acc_pos_control = (preds_test_pos_control == test_y_pos_control).sum().item() / len(test_y_pos_control)
print(f'test accuracy of control probe using {lm_name} embeddings is {test_acc_pos_control:.3f}')

100%|██████████| 12543/12543 [00:00<00:00, 68637.62it/s]


training control probe with opt embeddings...
epoch: 0 | val loss: 1.347 | val acc: 0.632
epoch: 10 | val loss: 1.175 | val acc: 0.705
val loss did not improve for 10 epochs, stopping training
test accuracy of control probe using opt embeddings is 0.702


## 4. Structural probing <a name="structural probe"></a>

### 4.1 Trees <a name="trees"></a>

For our gold labels, we need to recover the node distances from our parse tree

In [22]:
# Helper functions to tranform trees
def rec_tokentree_to_nltk(tokentree):
    token = tokentree.token["form"]
    tree_str = f"({token} {' '.join(rec_tokentree_to_nltk(t) for t in tokentree.children)})"
    return tree_str

def tokentree_to_nltk(tokentree):
    from nltk import Tree as NLTKTree
    tree_str = rec_tokentree_to_nltk(tokentree)
    return NLTKTree.fromstring(tree_str)

class FancyTree(Tree):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, format=1, **kwargs)
        
    def __str__(self):
        return self.get_ascii(show_internal=True)
    
    def __repr__(self):
        return str(self)
    
# transform your conllu tree to an ete3.Tree, for better visualisation
def rec_tokentree_to_ete(tokentree):
    idx = str(tokentree.token["id"])
    children = tokentree.children
    if children:
        return f"({','.join(rec_tokentree_to_ete(t) for t in children)}){idx}"
    else:
        return idx
    
def tokentree_to_ete(tokentree):
    newick_str = rec_tokentree_to_ete(tokentree)
    return FancyTree(f"{newick_str};")

#### Computing gold distances, MST & UUAS scores

We label a token by its token id (converted to a string). Based on these id's we are going to retrieve the node distances. Uing the gold distances, we can compute the **minimum spanning tree (MST)**. We can then compute the Undirected Unlabeled Attachment Score (UUAS), which is expressed as:

$$\frac{\text{number of predicted edges that are an edge in the gold parse tree}}{\text{number of edges in the gold parse tree}}$$

In [23]:
def create_gold_distances(corpus):
    '''Create a list of gold distances for each sentence in the corpus.'''
    all_distances = []

    for item in tqdm(corpus):
        tokentree = item.to_tree()
        ete_tree = tokentree_to_ete(tokentree)

        sen_len = len(ete_tree.search_nodes())
        distances = torch.zeros((sen_len, sen_len))

        for node1 in ete_tree.search_nodes():
            for node2 in ete_tree.search_nodes():
                distances[int(node1.name)-1][int(node2.name)-1] = node1.get_distance(node2)

        all_distances.append(distances)

    return all_distances

def create_mst(distances):
    '''Create a minimum spanning tree from a distance matrix.'''
    distances = torch.triu(distances).detach().numpy()
    mst = minimum_spanning_tree(distances).toarray()
    mst[mst>0] = 1.
    
    return mst

# viz ete tree, gold distances, mst
# item = corpus[5]
# tokentree = item.to_tree()
# ete3_tree = tokentree_to_ete(tokentree)
# print(ete3_tree, '\n')

# gold_distance = create_gold_distances(corpus[5:6])[0]
# print(gold_distance, '\n')

# mst = create_mst(gold_distance)
# print(mst)

def get_edges(mst):
    '''Retrieve the edges of a minimum spanning tree.
    Inputs: mst: np.array of shape (n, n)
                 a minimum spanning tree of a sentence
    Outputs: edges: set of tuples
                the edges of the minimum spanning tree
            '''
    edges = np.nonzero(mst)
    edges = list(zip(edges[0], edges[1]))
    edges = set(map(lambda x: tuple(sorted(x)), edges))
    return edges


def calc_uuas(pred_distances, gold_distances):  
    '''
    Compute UUAS score for a pair of gold and predicted distances of a sentence.
    '''
    uuas_batch = []
    for i in range(len(gold_distances)):
        l = max(torch.nonzero(gold_distances[i] != -1, as_tuple=True)[0]) + 1
        pred_mst = create_mst(pred_distances[i][:l, :l])
        gold_mst = create_mst(gold_distances[i][:l, :l])
        pred_edges = get_edges(pred_mst)
        gold_edges = get_edges(gold_mst)
        uuas_sent = len(pred_edges.intersection(gold_edges)) / len(gold_edges) if len(gold_edges) > 0 else -1
        uuas_batch.append(uuas_sent)

    return uuas_batch

### 4.2 Define structural probe class & L1 loss

In [24]:
# structural probe class (from John Hewitt)
class StructuralProbe(nn.Module):
    """ Computes squared L2 distance after projection by a matrix.
    For a batch of sentences, computes all n^2 pairs of distances
    for each sentence in the batch.
    """
    def __init__(self, model_dim, rank, device="cpu"):
        super().__init__()
        self.probe_rank = rank
        self.model_dim = model_dim
        
        self.proj = nn.Parameter(data = torch.zeros(self.model_dim, self.probe_rank))
        
        nn.init.uniform_(self.proj, -0.05, 0.05)
        self.to(device)

    def forward(self, batch):
        """ Computes all n^2 pairs of distances after projection
        for each sentence in a batch.
        Note that due to padding, some distances will be non-zero for pads.
        Computes (B(h_i-h_j))^T(B(h_i-h_j)) for all i,j
        Args:
          batch: a batch of word representations of the shape
            (batch_size, max_seq_len, representation_dim)
        Returns:
          A tensor of distances of shape (batch_size, max_seq_len, max_seq_len)
        """
        transformed = torch.matmul(batch, self.proj)
        
        batchlen, seqlen, rank = transformed.size()
        
        transformed = transformed.unsqueeze(2)
        transformed = transformed.expand(-1, -1, seqlen, -1)
        transposed = transformed.transpose(1,2)
        
        diffs = transformed - transposed
        
        squared_diffs = diffs.pow(2)
        squared_distances = torch.sum(squared_diffs, -1)

        return squared_distances

    
class L1DistanceLoss(nn.Module):
    """Custom L1 loss for distance matrices."""
    def __init__(self):
        super().__init__()

    def forward(self, predictions, label_batch, length_batch):
        """ Computes L1 loss on distance matrices.
        Ignores all entries where label_batch=-1
        Normalizes first within sentences (by dividing by the square of the sentence length)
        and then across the batch.
        Args:
          predictions: A pytorch batch of predicted distances
          label_batch: A pytorch batch of true distances
          length_batch: A pytorch batch of sentence lengths
        Returns:
          A tuple of:
            batch_loss: average loss in the batch
            total_sents: number of sentences in the batch
        """
        labels_1s = (label_batch != -1).float()
        predictions_masked = predictions * labels_1s
        labels_masked = label_batch * labels_1s
        total_sents = torch.sum((length_batch != 0)).float()
        squared_lengths = length_batch.pow(2).float()

        if total_sents > 0:
            loss_per_sent = torch.sum(torch.abs(predictions_masked - labels_masked), dim=(1,2))
            normalized_loss_per_sent = loss_per_sent / squared_lengths
            batch_loss = torch.sum(normalized_loss_per_sent) / total_sents
        
        else:
            batch_loss = torch.tensor(0.0)
        
        return batch_loss, total_sents


### 4.3 Create data for structural probes

In [45]:
def init_corpus(path, model, concat=False, cutoff=None):
    """ Initialises the data of a corpus.
    
    Inputs:
        path : str
            Path to corpus location
        model: language model to encode sentences, either LSTM or GPT2
        concat : bool, optional
            Optional toggle to concatenate all the tensors
            returned by `fetch_sen_reps`.
        cutoff : int, optional
            Optional integer to "cutoff" the data in the corpus.
            This allows only a subset to be used, alleviating 
            memory usage.
    Returns:
        embs : torch.Tensor 
            embeddings tensor of shape (num_tokens_in_corpus, model_dim)
        gold_distances : torch.Tensor 
            gold distances tensor of shape (num_sentences_in_corpus, max_sentence_length, max_sentence_length)
    """
    print('parsing corpus...')
    corpus = parse_corpus(path)[:cutoff]
    print(f'fetching sentence representations using {lm_name} embeddings...')
    embs = fetch_sen_reps(corpus, model, tokenizer, concat=concat)    
    print('computing gold distances...')
    gold_distances = create_gold_distances(corpus)
    
    return embs, gold_distances

# create data for structural probe
try:
    train_data_str = torch.load(f'{data_dir}/train_data_str.pt')
    val_data_str = torch.load(f'{data_dir}/val_data_str.pt')
    test_data_str = torch.load(f'{data_dir}/test_data_str.pt')
except FileNotFoundError:
    print(f'creating data for structural probe using {lm_name} embeddings...')
    # print('train')
    train_x_str, train_y_str = init_corpus(train_path, lm)
    train_data_str = MyDataset(train_x_str, train_y_str)
    # print('val')
    val_x_str, val_y_str = init_corpus(val_path, lm)
    val_data_str = MyDataset(val_x_str, val_y_str)
    # print('test')
    test_x_str, test_y_str = init_corpus(test_path, lm)
    test_data_str = MyDataset(test_x_str, test_y_str)
    torch.save(train_data_str, f'{data_dir}/train_data_str.pt')
    torch.save(val_data_str, f'{data_dir}/val_data_str.pt')
    torch.save(test_data_str, f'{data_dir}/test_data_str.pt')

print(f'size of train set: {len(train_data_str)} | size of val set: {len(val_data_str)} | size of test set: {len(test_data_str)}')

creating data for structural probe...
train
parsing corpus...
fetching sentence representations using opt embeddings...


100%|██████████| 12543/12543 [04:59<00:00, 41.84it/s]


computing gold distances...


100%|██████████| 12543/12543 [01:07<00:00, 184.83it/s]


val
parsing corpus...
fetching sentence representations using opt embeddings...


100%|██████████| 2002/2002 [00:44<00:00, 44.86it/s]


computing gold distances...


100%|██████████| 2002/2002 [00:06<00:00, 294.15it/s]


test
parsing corpus...
fetching sentence representations using opt embeddings...


100%|██████████| 2077/2077 [00:45<00:00, 45.55it/s]


computing gold distances...


100%|██████████| 2077/2077 [00:06<00:00, 300.35it/s]


size of train set: 12543 | size of val set: 2002 | size of test set: 2077


### 4.4 Train & test structural probe

In [26]:
# evaluate structural probe
def evaluate_probe(model, dataloader, loss_fn):
    model.eval()
    loss = 0
    uuas = []
    with torch.no_grad():
      for x, gold_distances, length in dataloader:
          preds = model(x)
          loss += loss_fn(preds, gold_distances, length)[0]
          uuas += calc_uuas(preds, gold_distances)
    loss /= len(dataloader)
    # take mean of uuas across batches where uuas != -1
    uuas_avg = sum([x for x in uuas if x != -1])/len([x for x in uuas if x != -1])

    return loss, uuas_avg, uuas

def pad_collate_fn(batch):
    max_length = max([len(x[1]) for x in batch])
    out_labels = torch.full((len(batch), max_length, max_length), -1)
    out_lengths = torch.zeros(len(batch))
    for i, x in enumerate(batch):
      out_labels[i, :x[1].shape[0], :x[1].shape[1]] = x[1]
      out_lengths[i] = x[1].shape[0]
      if len(x[0].shape) == 1:
        batch[i] = (x[0].unsqueeze(0), x[1])
    return torch.nn.utils.rnn.pad_sequence(list(map(lambda x: x[0].detach(), batch)), batch_first = True, padding_value=-1), out_labels, out_lengths

def train_structural_probe(model, train_data, val_data, params, seed=42, print_every=10):
    # create dataloaders
    set_seed(seed)  # set seed for reproducibility
    train_loader = DataLoader(train_data, batch_size=params.batch_size, shuffle=True, collate_fn=pad_collate_fn)
    val_loader = DataLoader(val_data, batch_size=params.batch_size, shuffle=False, collate_fn=pad_collate_fn)
    optimizer = torch.optim.Adam(model.parameters(), lr=params.lr)
    # scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=1)
    criterion =  L1DistanceLoss()
    val_losses, val_uuas = [], []

    # training/val loop
    print(f'training structural probe with {lm_name} embeddings...')
    for epoch in range(params.num_epochs):
        # train
        model.train()
        for train_x, gold_distances, lengths in train_loader:
            out = model(train_x)
            loss = criterion(out, gold_distances, lengths)[0]
            optimizer.zero_grad()
            loss.backward(retain_graph=True)
            optimizer.step()
            
        # val
        val_loss_epoch, val_uuas_epoch, _ = evaluate_probe(model, val_loader, criterion)
        # scheduler.step(val_loss_epoch)
        val_losses.append(val_loss_epoch)
        val_uuas.append(val_uuas_epoch)
        
        if epoch % print_every == 0:
            print(f'epoch: {epoch} | val loss: {val_loss_epoch:.3f} | val uuas: {val_uuas_epoch:.3f}')
        
        # early stopping
        if epoch >= params.patience and val_loss_epoch >= val_losses[-params.patience]:
            print(f'val loss did not improve for {params.patience} epochs, stopping training')
            break
        
    return model, val_losses, val_uuas

In [61]:
%%time
# train structural probe
params = TrainingParams()
try:
    str_probe_model = torch.load(f'{model_dir}/str_probe.pt')
    print(f'loaded saved structural probe model using {lm_name} embeddings')
except FileNotFoundError:
    str_probe_model = StructuralProbe(train_data_str[0][0].shape[1], rank=64)
    str_probe_model, _, _ = train_structural_probe(str_probe_model, train_data_str, val_data_str, params)
    torch.save(str_probe_model, f'{model_dir}/str_probe.pt')

# test
test_loader_str = DataLoader(test_data_str, batch_size=params.batch_size, shuffle=False, collate_fn=pad_collate_fn)
test_loss_str, test_uuas_str_avg, test_uuas_str   = evaluate_probe(str_probe_model, test_loader_str, L1DistanceLoss())
print(f'test uuas of structural pos probe using {lm_name} embeddings: {test_uuas_str_avg:.3f}, test loss: {test_loss_str:.3f}')

# compute correlation b/w PoS accuracy and uuas of structural probe on test set
corr_pos_acc_uuas_test = spearmanr([test_accs_pos_sent[i] for i in range(len(test_accs_pos_sent)) if test_uuas_str[i] != -1], [x for x in test_uuas_str if x != -1])[0]
print(f'correlation b/w PoS accuracy and uuas of structural probe per sentence on test set using {lm_name} embeddings: {corr_pos_acc_uuas_test:.3f}')

loaded saved structural probe model using opt embeddings
test uuas of structural pos probe using opt embeddings: 0.644, test loss: 0.597
correlation b/w PoS accuracy and uuas of structural probe per sentence on test set using opt embeddings: 0.234
CPU times: user 44.6 s, sys: 0 ns, total: 44.6 s
Wall time: 3 s


### 4.5 Control tasks for structural probe
We design a control task for the structural probe by generating random distances and MSTs, and training the structural probe on them. If the structural probe is actually probing the structural information, it should perform relatively worse on the control task.

In [63]:
%%time
def create_control_distances(corpus_path):
    '''Create a list of control distances for each sentence in the corpus.'''
    corpus = parse_corpus(corpus_path)
    all_distances = []

    for item in tqdm(corpus):
        tokentree = item.to_tree()
        ete_tree = tokentree_to_ete(tokentree)
        sen_len = len(ete_tree.search_nodes())
        # generate a sen_len x sen_len matrix of random distances between 1 and sen_len
        distances = np.random.randint(1, sen_len + 1, size=(sen_len, sen_len))
        # set the diagonal elements to zero 
        np.fill_diagonal(distances, 0)
        # convert the NumPy array to a PyTorch tensor
        distances_tensor = torch.from_numpy(distances)
        all_distances.append(distances_tensor)

    return all_distances

try:
    train_data_str_control = torch.load(f'{data_dir}/train_data_str_control.pt')
    val_data_str_control = torch.load(f'{data_dir}/val_data_str_control.pt')
    test_data_str_control = torch.load(f'{data_dir}/test_data_str_control.pt')
except FileNotFoundError:
    print('creating data for structural control probe...')
    print('train')
    train_control_distances = create_control_distances(train_path)
    print('val')
    val_control_distances = create_control_distances(val_path)
    print('test')
    test_control_distances = create_control_distances(test_path)
    train_data_str_control = MyDataset(train_data_str.x, train_control_distances)
    val_data_str_control = MyDataset(val_data_str.x, val_control_distances)
    test_data_str_control = MyDataset(test_data_str.x, test_control_distances)
    torch.save(train_data_str_control, f'{model_dir}/train_data_str_control.pt')
    torch.save(val_data_str_control, f'{model_dir}/val_data_str_control.pt')
    torch.save(test_data_str_control, f'{model_dir}/test_data_str_control.pt')

# train structural control probe
try:
    str_probe_control_model = torch.load(f'{model_dir}/str_probe_control.pt')
    print(f'loaded saved structural control probe model using {lm_name} embeddings')
except FileNotFoundError:
    str_probe_control_model = StructuralProbe(train_data_str_control[0][0].shape[1], rank=64)
    str_probe_control_model, _, _ = train_structural_probe(str_probe_control_model, train_data_str_control, val_data_str_control, params)
    torch.save(str_probe_control_model, f'{model_dir}/str_probe_control.pt')

# test
test_loader_str_control = DataLoader(test_data_str_control, batch_size=params.batch_size, shuffle=False, collate_fn=pad_collate_fn)
test_loss_str_control, test_uuas_str_control_avg, test_uuas_str_control = evaluate_probe(str_probe_control_model, test_loader_str_control, L1DistanceLoss())
print(f'test uuas of structural control probe using {lm_name} embeddings: {test_uuas_str_control_avg:.3f}, test loss: {test_loss_str_control:.3f}')

creating data for structural control probe...
train


100%|██████████| 12543/12543 [00:02<00:00, 5882.88it/s]


val


100%|██████████| 2002/2002 [00:00<00:00, 7041.38it/s]


test


100%|██████████| 2077/2077 [00:00<00:00, 7220.87it/s]


loaded saved structural control probe model using opt embeddings
test uuas of structural control probe using opt embeddings: 0.292, test loss: 3.288
CPU times: user 52.3 s, sys: 558 ms, total: 52.9 s
Wall time: 16.2 s


#### Print trees to LaTeX
Code to print dependency tree plots in LaTeX like those of Figure 2 in the Structural Probing paper. 
**N.B.**: for the latex tikz tree the first token in a sentence has index 1 (instead of 0), so take that into account with the predicted and gold edges that you pass to the method.

In [None]:
def print_tikz(predicted_edges, gold_edges, words):
    """ Turns edge sets on word (nodes) into tikz dependency LaTeX.
    Parameters
    ----------
    predicted_edges : Set[Tuple[int, int]]
        Set (or list) of edge tuples, as predicted by your probe.
    gold_edges : Set[Tuple[int, int]]
        Set (or list) of gold edge tuples, as obtained from the treebank.
    words : List[str]
        List of strings representing the tokens in the sentence.
    """

    string = """\\begin{dependency}[hide label, edge unit distance=.5ex]
    \\begin{deptext}[column sep=0.05cm]
    """

    string += (
        "\\& ".join([x.replace("$", "\$").replace("&", "+") for x in words])
        + " \\\\\n"
    )
    string += "\\end{deptext}" + "\n"
    for i_index, j_index in gold_edges:
        string += "\\depedge[-]{{{}}}{{{}}}{{{}}}\n".format(i_index, j_index, ".")
    for i_index, j_index in predicted_edges:
        string += f"\\depedge[-,edge style={{red!60!}}, edge below]{{{i_index}}}{{{j_index}}}{{.}}\n"
    string += "\\end{dependency}\n"
    print(string)