In [None]:
import torch
import numpy as np
import time

from collections import defaultdict
from typing import List
from conllu import parse_incr, TokenList
from torch import Tensor
from transformers import GPT2Model, GPT2Tokenizer
CUTOFF = None
from lstm.model import RNNModel
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
%load_ext autoreload
%autoreload 2

In [None]:
# Main probed transformer: DistilGPT-2.
# Some models don't return the hidden states by default; passing
# `output_hidden_states=True` to `from_pretrained` configures that.
GPT2_CHECKPOINT = 'distilgpt2'

transformer = GPT2Model.from_pretrained(
    GPT2_CHECKPOINT,
    output_hidden_states=True,
)
print("Model ready")

tokenizer = GPT2Tokenizer.from_pretrained(GPT2_CHECKPOINT)
print("Tokenizer ready")

In [None]:
# Load other transformers
# BART
from transformers import BartModel, BartTokenizer
# The model and its tokenizer are loaded from the same checkpoint name.
BART_CHECKPOINT = 'bart-large'

BART = BartModel.from_pretrained(
    BART_CHECKPOINT,
    output_hidden_states=True,
)
print("I have loaded BART!")

BART_tokenizer = BartTokenizer.from_pretrained(BART_CHECKPOINT)
print("I have loaded the BART Tokenizer")

In [None]:
# XLNet
from transformers import XLNetModel, XLNetTokenizer
# The model and its tokenizer are loaded from the same checkpoint name.
XLNET_CHECKPOINT = 'xlnet-large-cased'

XLNet = XLNetModel.from_pretrained(
    XLNET_CHECKPOINT,
    output_hidden_states=True,
)
print("I have loaded XLNet!")

XLNet_tokenizer = XLNetTokenizer.from_pretrained(XLNET_CHECKPOINT)
print("I have loaded the XLNet tokenizer!")

In [None]:
# T5
from transformers import T5Model, T5Tokenizer
# The model and its tokenizer are loaded from the same checkpoint name.
T5_CHECKPOINT = 't5-small'

T5 = T5Model.from_pretrained(
    T5_CHECKPOINT,
    output_hidden_states=True,
)
print("I have loaded T5!")

T5_tokenizer = T5Tokenizer.from_pretrained(T5_CHECKPOINT)
print("I have loaded the T5 tokenizer")

In [None]:
# The Gulordava LSTM model can be found here: 
# https://drive.google.com/open?id=1w47WsZcZzPyBKDn83cMNd0Hb336e-_Sy
#
# N.B: I have altered the RNNModel code to only output the hidden states that you are interested in.
# If you want to do more experiments with this model you could have a look at the original code here:
# https://github.com/facebookresearch/colorlessgreenRNNs/blob/master/src/language_models/model.py
#
# Path of the pretrained Gulordava LSTM checkpoint.
model_location = 'lstm/gulordava.pt'
lstm = RNNModel('LSTM', 50001, 650, 650, 2)
# map_location='cpu' lets a checkpoint that was saved from a GPU be read on a
# CPU-only machine; load_state_dict then copies the weights into `lstm`.
lstm.load_state_dict(torch.load(model_location, map_location='cpu'))


# This LSTM does not use a Tokenizer like the Transformers, but a Vocab dictionary that maps a token to an id.
# The id of a token is its (0-based) line number in the vocab file.
with open('lstm/vocab.txt') as f:
    w2i = {w.strip(): i for i, w in enumerate(f)}

# Tokens missing from the vocabulary fall back to the id of "<unk>".
vocabLSTM = defaultdict(lambda: w2i["<unk>"])
vocabLSTM.update(w2i)
# Reverse mapping: id -> token.
i2w = {i: w for w, i in w2i.items()}

## Load All Data

In [None]:
from utils import create_or_load_pos_data
from controltasks import save_or_load_pos_controls 
from datasets import find_distribution, POSDataset
import torch.utils.data as data 
import time

"""
Change this piece of malevolent code that says CUTOFF = 100 or CUTOFF = 20
"""
# Default cutoff handed to the data-loading helpers.
# NOTE(review): presumably None means "use every sentence" - confirm in
# utils.create_or_load_pos_data.
CUTOFF = None


def get_transformer_reps(transformer, tokenizer, cutoff=CUTOFF, extra_transformer=None):
    """
    Build (or load pickled) PoS-probing data for one transformer.

    Calls `create_or_load_pos_data` for the train, dev and test splits, flattens
    the per-sentence word lists, computes the tag distribution of the training
    set and creates (or loads) the control-task labels.

    Args:
        transformer: model passed through to `create_or_load_pos_data`.
        tokenizer: tokenizer passed through to `create_or_load_pos_data`.
        cutoff: forwarded as `cutoff=` to every `create_or_load_pos_data` call.
        extra_transformer: forwarded as `extra_transformer=` to every
            `create_or_load_pos_data` call (callers pass 'BART', 'XLNet', 'T5').

    Returns:
        A 19-tuple:
        (train_x, train_y, vocab, words_train,
         dev_x, dev_y, vocab, words_dev,
         test_x, test_y, vocab, words_test,
         flatten_train, flatten_dev, flatten_test,
         dist, ypos_train_control, ypos_dev_control, ypos_test_control).
        The three `vocab` entries are the same object: the vocab returned by
        the last (test) load.
    """
    # The train split builds the vocab; dev and test receive it as 4th argument.
    train_x, train_y, vocab, words_train = create_or_load_pos_data(
        "train", transformer, tokenizer,
        cutoff=cutoff, extra_transformer=extra_transformer)
    dev_x, dev_y, vocab, words_dev = create_or_load_pos_data(
        "dev", transformer, tokenizer, vocab,
        cutoff=cutoff, extra_transformer=extra_transformer)
    test_x, test_y, vocab, words_test = create_or_load_pos_data(
        "test", transformer, tokenizer, vocab,
        cutoff=cutoff, extra_transformer=extra_transformer)

    # Flatten the wordlists so we have one big list of words for all set types
    flatten_train = [word for sublist in words_train for word in sublist]
    flatten_dev = [word for sublist in words_dev for word in sublist]
    flatten_test = [word for sublist in words_test for word in sublist]

    # Generate a distribution over tags, useful for control task
    dist = find_distribution(data.DataLoader(POSDataset(train_x, train_y), batch_size=1))
    ypos_train_control, ypos_dev_control, ypos_test_control = save_or_load_pos_controls(
        train_x, train_y, [flatten_train, flatten_dev, flatten_test], dist)

    return train_x, train_y, vocab, words_train, \
           dev_x, dev_y, vocab, words_dev, \
           test_x, test_y, vocab, words_test, \
           flatten_train, flatten_dev, flatten_test, \
           dist, ypos_train_control, ypos_dev_control, ypos_test_control

In [None]:
# Representations for our main transformer, GPT-2.
# get_transformer_reps returns the vocab once per split, hence `vocab` appears
# three times in the unpacking target.
(
    train_x, train_y, vocab, words_train,
    dev_x, dev_y, vocab, words_dev,
    test_x, test_y, vocab, words_test,
    flatten_train, flatten_dev, flatten_test,
    dist, ypos_train_control, ypos_dev_control, ypos_test_control,
) = get_transformer_reps(transformer, tokenizer)

### Representations for other transformers

In [None]:
# Representations for BART.
# Every returned vocab goes into `vocab_bart`, so the shared `vocab` name is not
# silently replaced by a model-specific one.
(
    train_x_bart, train_y_bart, vocab_bart, words_train_bart,
    dev_x_bart, dev_y_bart, vocab_bart, words_dev_bart,
    test_x_bart, test_y_bart, vocab_bart, words_test_bart,
    flatten_train_bart, flatten_dev_bart, flatten_test_bart,
    dist_bart, ypos_train_control_bart, ypos_dev_control_bart, ypos_test_control_bart,
) = get_transformer_reps(BART, BART_tokenizer, extra_transformer='BART')

In [None]:
# Representations for XLNet.
# Every returned vocab goes into `vocab_XLNet`, so the shared `vocab` name is not
# silently replaced by a model-specific one.
(
    train_x_XLNet, train_y_XLNet, vocab_XLNet, words_train_XLNet,
    dev_x_XLNet, dev_y_XLNet, vocab_XLNet, words_dev_XLNet,
    test_x_XLNet, test_y_XLNet, vocab_XLNet, words_test_XLNet,
    flatten_train_XLNet, flatten_dev_XLNet, flatten_test_XLNet,
    dist_XLNet, ypos_train_control_XLNet, ypos_dev_control_XLNet, ypos_test_control_XLNet,
) = get_transformer_reps(XLNet, XLNet_tokenizer, extra_transformer='XLNet')

In [None]:
# Dummy batch of token ids: a 3 x 4 integer (int64) tensor filled with ones.
a = torch.ones(3, 4, dtype=torch.long)

In [None]:
# Forward pass on the dummy batch; only the output is inspected, so no autograd
# graph needs to be recorded.
with torch.no_grad():
    out = XLNet(a)

In [None]:
# Display the shape of the first element of the XLNet output for input `a`.
out[0].shape

In [None]:
# Dummy batch holding a single sequence of two token ids (all ones, int64).
b = torch.ones(1, 2, dtype=torch.long)

In [None]:
# Forward pass on the single-sequence dummy batch; only the output is
# inspected, so no autograd graph needs to be recorded.
with torch.no_grad():
    out_b = XLNet(b)

In [None]:
# Display the shape of the first element of the XLNet output for input `b`.
out_b[0].shape

In [None]:
# Representations for T5.
# Every returned vocab goes into `vocab_T5`, so the shared `vocab` name is not
# silently replaced by a model-specific one.
(
    train_x_T5, train_y_T5, vocab_T5, words_train_T5,
    dev_x_T5, dev_y_T5, vocab_T5, words_dev_T5,
    test_x_T5, test_y_T5, vocab_T5, words_test_T5,
    flatten_train_T5, flatten_dev_T5, flatten_test_T5,
    dist_T5, ypos_train_control_T5, ypos_dev_control_T5, ypos_test_control_T5,
) = get_transformer_reps(T5, T5_tokenizer, extra_transformer='T5')

In [None]:
# PoS data for the Gulordava LSTM. The train split produces the vocab, which is
# then handed to the dev and test loads as 4th positional argument.
lstm_pos_train = create_or_load_pos_data("train", lstm, vocabLSTM, cutoff=CUTOFF)
train_xL, train_yL, vocab, _ = lstm_pos_train

lstm_pos_dev = create_or_load_pos_data("dev", lstm, vocabLSTM, vocab, cutoff=CUTOFF)
dev_xL, dev_yL, vocab, _ = lstm_pos_dev

lstm_pos_test = create_or_load_pos_data("test", lstm, vocabLSTM, vocab, cutoff=CUTOFF)
test_xL, test_yL, vocab, _ = lstm_pos_test


In [None]:
from tree_utils import create_or_load_structural_data
from controltasks import save_or_load_struct_controls

# Structural-probe data for GPT-2. Each *_xy value is indexed as [0] (gold
# distances) and [1] (representations) by the cells below.
train_xy = create_or_load_structural_data("train", transformer, tokenizer, cutoff=CUTOFF)
dev_xy = create_or_load_structural_data("dev", transformer, tokenizer, cutoff=CUTOFF)
test_xy = create_or_load_structural_data("test", transformer, tokenizer, cutoff=CUTOFF)
# Number of training items; len(train_xy) would only give the size of the
# (distances, representations) pair.
print(len(train_xy[0]))
struct_train_control, struct_dev_control, struct_test_control = save_or_load_struct_controls(cutoff=CUTOFF)

In [None]:
from tree_utils import create_or_load_structural_data
from controltasks import save_or_load_struct_controls

# Structural-probe data for the LSTM: one load per split, in train/dev/test order.
train_xyLSTM, dev_xyLSTM, test_xyLSTM = (
    create_or_load_structural_data(split_name, lstm, vocabLSTM, cutoff=CUTOFF)
    for split_name in ("train", "dev", "test")
)

## PoS Models

In [None]:
# DIAGNOSTIC CLASSIFIER
import torch.nn as nn
import copy
class POSProbe(nn.Module):
    """Diagnostic classifier mapping a representation to tag scores.

    With `hidden_size == 0` the probe is dropout followed by one linear layer;
    otherwise it is dropout -> linear -> ReLU -> dropout -> linear.

    Args:
        repr_size: size of the input representation.
        pos_size: number of output classes.
        hidden_size: width of the hidden layer; 0 selects the linear probe.
        dropout: dropout probability used before every linear layer.
    """

    def __init__(self, repr_size, pos_size, hidden_size = 0, dropout=0):
        super().__init__()
        if hidden_size == 0:
            layers = [
                nn.Dropout(p=dropout),
                nn.Linear(repr_size, pos_size),
            ]
        else:
            layers = [
                nn.Dropout(p=dropout),
                nn.Linear(repr_size, hidden_size),
                nn.ReLU(),
                nn.Dropout(p=dropout),
                nn.Linear(hidden_size, pos_size),
            ]
        # Kept under the attribute name `linear` so state_dict keys stay the same.
        self.linear = nn.Sequential(*layers)

    def forward(self, x):
        """Return the unnormalised class scores for `x`."""
        scores = self.linear(x)
        return scores
    
def eval_given_dataloader(loader, model):
    """Return the classification accuracy of `model` over `loader`.

    The model is switched to eval mode (and left there) and evaluated without
    recording gradients.

    Args:
        loader: iterable of (x, y) batches; both are moved to the global `device`.
        model: module returning one score per class along dim 1.

    Returns:
        Fraction of correctly predicted items as a float; 0.0 when `loader`
        yields no items.
    """
    model.eval()
    correct = 0
    total = 0
    # Evaluation only: skip autograd bookkeeping.
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            y = y.to(device)
            preds = torch.argmax(model(x), dim=1)
            correct += torch.sum(torch.eq(preds, y)).item()
            total += y.shape[0]
    if total == 0:
        return 0.0
    return correct / total
    
def train(my_model, train_loader, dev_loader, epoch_amount = 10, warmup_steps = 5, p=False):
    """
    Train `my_model` with Adam (lr=1e-3) and cross-entropy, with early stopping
    on dev accuracy.

    After every epoch the dev accuracy is computed with `eval_given_dataloader`.
    Once `i > warmup_steps`, each epoch whose dev accuracy is lower than the
    previous epoch's costs one unit of patience; any other epoch restores it.
    Training stops when patience reaches zero.

    Args:
        my_model: module to train (already on the global `device`).
        train_loader: iterable of (x, y) training batches.
        dev_loader: iterable of (x, y) dev batches.
        epoch_amount: maximum number of epochs.
        warmup_steps: epoch index up to which patience is never reduced.
        p: when True, print train and dev accuracy every epoch.

    Returns:
        (best_state_dict, best_epoch): a deep copy of the state dict of the epoch
        with the highest dev accuracy, and that epoch's index. The state dict is
        None only when `epoch_amount` is 0.
    """
    ce = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(my_model.parameters(), lr=1e-3)
    # Number of consecutive dev-accuracy drops tolerated after warm-up.
    max_patience = 2
    patience = max_patience
    best_model = None
    prev_dev_acc = 0.0
    best_dev_acc = 0.0
    best_epoch = 0
    for i in range(epoch_amount):
        my_model.train()
        epoch_correct = 0.0
        epoch_total = 0.0
        for x, y in train_loader:
            x = x.to(device)
            y = y.to(device)
            outputs = my_model(x)
            preds = torch.argmax(outputs, dim=1)
            correct = torch.sum(torch.eq(preds, y))
            loss = ce(outputs, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_correct += correct.item()
            epoch_total += y.shape[0]

        dev_acc = eval_given_dataloader(dev_loader, my_model)

        if p:
            print("Epoch",i,"accuracy", epoch_correct/epoch_total, dev_acc)
        if dev_acc < prev_dev_acc and i > warmup_steps:
            patience -= 1
        else:
            patience = max_patience
        # `best_model is None` guarantees a snapshot even if dev accuracy is
        # 0.0 in every epoch, so callers can always load the returned state.
        if best_model is None or dev_acc > best_dev_acc:
            best_dev_acc = dev_acc
            best_model = copy.deepcopy(my_model.state_dict())
            best_epoch = i
        prev_dev_acc = dev_acc
        if patience == 0:
            break
    return best_model, best_epoch

In [None]:
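# Sweep over: linear probe (hidden_size 0) vs MLP probe (hidden_size 256)
# Dropout: 0.0 0.2 0.4 0.6 0.8
# Tasks: 'pos' and 'controlpos'
# Models: LSTM and GPT-2 transformer, or BART and XLNet when run_other_transformers is True
# result_dict_mlp[task][model_type][hidden_size][dropout][seed] -> {'state_dict': trained probe, 'dev_acc': float, 'test_acc': float, 'epochs': index of the best epoch}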
import pickle
import os 

# Safety latch: abort this cell as soon as the results file below exists, so
# the sweep is not started again by accident.
FINAL_POS_RESULTS = 'true_results.pickle'
already_done = os.path.exists(FINAL_POS_RESULTS)
if already_done:
    raise ValueError("Do not run this ... ")

def eval_model(task, model_type, hidden_size, dropout, seed):
    """Train and evaluate one PoS (or control-PoS) probe.

    Args:
        task: 'pos' selects the gold tags; any other value selects the
            control-task labels.
        model_type: which representations to probe: 'lstm', 'transformer'
            (GPT-2), 'BART', 'XLNet' or 'T5'.
        hidden_size: hidden width of the probe; 0 gives a linear probe.
        dropout: dropout probability of the probe.
        seed: seed applied to torch and numpy before anything else.

    Returns:
        (model, dev_acc, test_acc, epochs): the probe with its best state loaded,
        its dev and test accuracy, and the index of the best epoch.

    Raises:
        ValueError: if `model_type` is not one of the supported names.
    """
    torch.manual_seed(seed)
    np.random.seed(seed)

    # Each branch only touches the globals of its own model, so models whose
    # representation cells were not run do not get in the way.
    if model_type == 'lstm':
        x_train, x_dev, x_test = train_xL, dev_xL, test_xL
    elif model_type == 'transformer':
        x_train, x_dev, x_test = train_x, dev_x, test_x
    elif model_type == 'BART':
        x_train, x_dev, x_test = train_x_bart, dev_x_bart, test_x_bart
    elif model_type == 'XLNet':
        x_train, x_dev, x_test = train_x_XLNet, dev_x_XLNet, test_x_XLNet
    elif model_type == 'T5':
        x_train, x_dev, x_test = train_x_T5, dev_x_T5, test_x_T5
    else:
        raise ValueError("Unknown model_type: %r" % (model_type,))

    # NOTE(review): labels always come from the GPT-2 run, whatever the model
    # type. Confirm they line up item-for-item with the other models' data.
    if task == 'pos':
        y_train, y_dev, y_test = train_y, dev_y, test_y
    else:
        y_train, y_dev, y_test = ypos_train_control, ypos_dev_control, ypos_test_control

    train_loader = data.DataLoader(POSDataset(x_train, y_train), batch_size=16, shuffle=True)
    dev_loader = data.DataLoader(POSDataset(x_dev, y_dev), batch_size=16)
    test_loader = data.DataLoader(POSDataset(x_test, y_test), batch_size=16)

    # The probe's input width is read off the data (last dimension of the
    # representations) instead of a per-model constant, so it cannot get out
    # of sync with the representations (768 for GPT-2, 650 for the LSTM).
    dim = x_train.shape[-1]
    model = POSProbe(dim, len(dist), hidden_size, dropout).to(device)

    best_state_dict, epochs = train(model, train_loader, dev_loader, 20, 4)
    model.load_state_dict(best_state_dict)
    dev_acc = eval_given_dataloader(dev_loader, model)
    test_acc = eval_given_dataloader(test_loader, model)
    return model, dev_acc, test_acc, epochs
print(device)

# For XLNet and BART
run_other_transformers = True

# Decide what models to run
model_list = ['BART', 'XLNet'] if run_other_transformers else ['lstm', 'transformer']

# Full sweep; results are nested as
# result_dict_mlp[task][model_type][hidden_size][dropout][seed].
result_dict_mlp = {}
for task in ('pos', 'controlpos'):
    per_task = result_dict_mlp[task] = {}
    for model_type in model_list:
        print("Starting", model_type)
        per_model = per_task[model_type] = {}
        for hidden_size in (0, 256):
            per_hidden = per_model[hidden_size] = {}
            for dropout in (0, 0.2, 0.4, 0.6, 0.8):
                per_dropout = per_hidden[dropout] = {}
                for seed in (10, 20, 30):
                    record = per_dropout[seed] = {}
                    # eval_model returns the trained probe itself; it is stored
                    # under the key 'state_dict'.
                    trained_probe, dev_acc, test_acc, epochs = eval_model(
                        task, model_type, hidden_size, dropout, seed)
                    record['state_dict'] = trained_probe
                    record['dev_acc'] = dev_acc
                    record['test_acc'] = test_acc
                    record['epochs'] = epochs
                    print(task,model_type,hidden_size,dropout,seed, epochs, test_acc)

# The two model families are written to different files.
results_file = "transformer_results.pickle" if run_other_transformers else "results_and_models.pickle"
with open(results_file, "wb") as f:
    pickle.dump(result_dict_mlp, f)
print("All results are safe. You can sleep peacefully. ")

In [None]:
import matplotlib.pyplot as plt
import pickle

# Results for the plot below. The same file is read whatever the value of
# `run_other_transformers`, so this cell does not depend on that flag.
# pickle.load can execute arbitrary code: only open files produced by this notebook.
with open("true_results.pickle", "rb") as f:
    dd = pickle.load(f)

def mean_test_acc(the_dict):
    """Return (mean, std) of the 'test_acc' entries stored in `the_dict`.

    `the_dict` maps a key (the seed in this notebook) to a record that has a
    'test_acc' field.
    """
    accs = []
    for record in the_dict.values():
        accs.append(record['test_acc'])
    return np.mean(accs), np.std(accs)

fig, (ax1, ax2) = plt.subplots(1, 2)

# Dropout values of the sweep; shared x-axis of both panels.
x_axis = [0, 0.2, 0.4, 0.6, 0.8]

# `model_name` (not `model`) so the loop does not overwrite the probe stored in
# the global `model`.
for model_name in ['lstm', 'transformer']:
    for hidden_size in [0, 256]:
        y_axis = []
        y_axis_select = []
        for d in x_axis:
            # Mean test accuracy over seeds for the POS task
            mean, _ = mean_test_acc(dd['pos'][model_name][hidden_size][d])

            # Mean test accuracy over seeds for the control task
            mean_control, _ = mean_test_acc(dd['controlpos'][model_name][hidden_size][d])

            y_axis.append(mean)
            # Selectivity = task accuracy minus control-task accuracy
            y_axis_select.append(mean - mean_control)

        ax1.plot(x_axis, y_axis, '--o', label=model_name + str(hidden_size))
        ax2.plot(x_axis, y_axis_select, '--o')

# One legend once every line has been drawn.
ax1.legend()

ax1.set_title("Accuracy")
ax2.set_title("Selectivity")
ax1.set_xlabel("Dropout")
ax2.set_xlabel("Dropout")

plt.tight_layout()
plt.show()

In [None]:

# This cell is switched off on purpose: the raise stops it before anything runs.
raise ValueError("This was a temporary fix. ")

# Unreachable below this point. The code copies dev_acc / test_acc / epochs from
# `result_dict`, replaces them for hidden_size > 0 with the values from
# `mlp_results` ('pos') or `mlp_results2` (other tasks), and pickles the result.
metric_keys = ('dev_acc', 'test_acc', 'epochs')

result_dict1 = {}
for task in ['pos', 'controlpos']:
    result_dict1[task] = {}
    for model_type in ['lstm', 'transformer']:
        result_dict1[task][model_type] = {}
        for hidden_size in [0, 256]:
            result_dict1[task][model_type][hidden_size] = {}
            for dropout in [0, 0.2, 0.4, 0.6, 0.8]:
                result_dict1[task][model_type][hidden_size][dropout] = {}
                for seed in [10, 20, 30]:
                    merged = {}
                    result_dict1[task][model_type][hidden_size][dropout][seed] = merged

                    base = result_dict[task][model_type][hidden_size][dropout][seed]
                    for key in metric_keys:
                        merged[key] = base[key]

                    if hidden_size > 0:
                        override_source = mlp_results if task == 'pos' else mlp_results2
                        override = override_source[task][model_type][hidden_size][dropout][seed]
                        for key in metric_keys:
                            merged[key] = override[key]

with open("true_results.pickle", "wb") as f:
    pickle.dump(result_dict1, f)

In [None]:
# Single linear PoS probe on the LSTM representations.
lstmtrain_loader = data.DataLoader(POSDataset(train_xL, train_yL), batch_size=16, shuffle=True)
lstmdev_loader = data.DataLoader(POSDataset(dev_xL, dev_yL), batch_size=16)
lstmtest_loader = data.DataLoader(POSDataset(test_xL, test_yL), batch_size=16)

model = POSProbe(650, len(dist)).to(device)
# train() returns (best_state_dict, best_epoch); only the state dict can be
# passed to load_state_dict.
best_state_dict, best_epoch = train(model, lstmtrain_loader, lstmdev_loader, 30, 10, p=True)
model.load_state_dict(best_state_dict)
print("Dev accuracy", eval_given_dataloader(lstmdev_loader, model))
print("Test accuracy", eval_given_dataloader(lstmtest_loader, model))

In [None]:
# Control task: MLP probe on the GPT-2 representations with the control labels.
ctrain_loader = data.DataLoader(POSDataset(train_x, ypos_train_control), batch_size=16)
cdev_loader = data.DataLoader(POSDataset(dev_x, ypos_dev_control), batch_size=16)
ctest_loader = data.DataLoader(POSDataset(test_x, ypos_test_control), batch_size=16)
print(train_x.shape, ypos_train_control.shape)
model = POSProbe(768, len(dist), hidden_size=256).to(device)
# train() returns (best_state_dict, best_epoch); only the state dict can be
# passed to load_state_dict.
best_state_dict, best_epoch = train(model, ctrain_loader, cdev_loader, 20, p=True)
model.load_state_dict(best_state_dict)
print("Test accuracy", eval_given_dataloader(ctest_loader, model))

## Structural

In [None]:
# Control task
from tree_utils import * 
from utils import parse_corpus 

def get_behaviour(behave_dict, token):
    """Return the control behaviour ("beginning" or "ending") of `token`.

    A token that is not yet in `behave_dict` gets a behaviour drawn uniformly at
    random. The draw is stored in `behave_dict`, so every later occurrence of
    the same token receives the same behaviour.

    Args:
        behave_dict: dict mapping token -> behaviour; updated in place.
        token: the word form to look up.

    Returns:
        "beginning" or "ending" as a plain str.
    """
    if token not in behave_dict:
        # str(...) turns numpy's string scalar into an ordinary Python str.
        behave_dict[token] = str(np.random.choice(["beginning", "ending"], p=[1/2, 1/2]))
    return behave_dict[token]

def fake_gold_distances(corpus, behave_dict):
    """Build control-task distance matrices by rewiring the heads of `corpus`.

    For every sentence, each word gets a behaviour from `get_behaviour`:
    "beginning" sets its head to 1, "ending" sets its head to n (the sentence
    length). The first token's head is then forced to 0 and the last token's
    head to 1. The rewired sentence is converted to a tree and an n x n tensor
    is filled with `node1.get_distance(node2)` for every pair of nodes.

    The sentences in `corpus` are modified in place (their 'head' fields are
    overwritten).

    Args:
        corpus: iterable of sentences; each sentence is iterated for its words
            (read via 'id', 'form', written via 'head') and must provide
            `to_tree()`.
        behave_dict: mapping passed to `get_behaviour` for every word form.

    Returns:
        (all_distances, behave_dict): the list of per-sentence n x n distance
        tensors and the `behave_dict` that was passed in.
    """
    all_distances = []
    ind = 0  # only referenced by the commented-out print_tikz call below
    for item in corpus:
        
        n = len(item)
        modified_heads = np.zeros(n)
        words = []
        # Calculate new heads
        for word in item:
            # assumes word['id'] is a 1-based integer position - TODO confirm
            i = word['id']
            words.append(word['form'])
            #print(i, word['form'], word['head'])
            behaviour = get_behaviour(behave_dict, word['form'])
            if behaviour == "beginning":
                modified_heads[i-1] = 1 
            elif behaviour == "ending":
                modified_heads[i-1] = n
                
        # Actually set new heads
        for i, z in enumerate(item):
            new_head = int(modified_heads[i])
            z['head'] = new_head
            
            # The first token always gets head 0 and the last token head 1,
            # whatever behaviour was drawn for them.
            if i == 0 :
                z['head'] = 0
            elif i == (n-1):
                z['head'] = 1
            
        tokentree = item.to_tree()
        test = tokentree_to_ete(tokentree)
        dists = torch.zeros(n,n)
        # Node names are converted to 0-based row/column indices.
        for node1 in test.traverse():
            for node2 in test.traverse():
                no1 = int(node1.name) - 1
                no2 = int(node2.name) - 1
                dists[no1,no2] = node1.get_distance(node2)
        # Turn it into a tensor, view, append
        #dists = dists.view(n,n)
        # NOTE(review): `mst` and `ed` are not used after this point (only by
        # the commented-out print_tikz call).
        mst = create_mst(dists)
        ed = edges(mst)
        #print_tikz([],ed, words, "number" + str(ind))
        all_distances.append(dists)
    return all_distances, behave_dict

# Build control distances for the sample training treebank, starting from an
# empty behaviour table. `fake` is the (distances, behave_dict) pair.
train_conllu_path = os.path.join('data','sample', 'en_ewt-ud-'+'train'+'.conllu')
corp = parse_corpus(train_conllu_path)
fake = fake_gold_distances(corp, behave_dict={})

#print([z.shape for z in fake[0]])

In [None]:
import torch.nn as nn
import torch


class StructuralProbe(nn.Module):
    """Squared L2 distance between word representations after a linear projection.

    For each sentence of a batch, all n^2 pairwise distances are computed.
    """

    def __init__(self, model_dim, rank, dropout = 0):
        """
        Args:
          model_dim: size of the input representations.
          rank: number of columns of the projection matrix.
          dropout: dropout probability applied to the input batch.
        """
        super().__init__()
        self.probe_rank = rank
        self.model_dim = model_dim

        self.dropout = nn.Dropout(dropout)
        # (model_dim x rank) projection, initialised uniformly in [-0.05, 0.05].
        initial_projection = torch.zeros(self.model_dim, self.probe_rank)
        self.proj = nn.Parameter(data = initial_projection)
        nn.init.uniform_(self.proj, -0.05, 0.05)

    def forward(self, batch):
        """ Computes (B(h_i-h_j))^T(B(h_i-h_j)) for all pairs i,j of every sentence.

        Note that due to padding, some distances will be non-zero for pads.
        Args:
          batch: a batch of word representations of the shape
            (batch_size, max_seq_len, representation_dim)
        Returns:
          A tensor of distances of shape (batch_size, max_seq_len, max_seq_len)
        """
        projected = torch.matmul(self.dropout(batch), self.proj)
        _, seq_len, _ = projected.size()

        # rows[b, i, j] holds the projection of word i, cols[b, i, j] that of word j.
        rows = projected.unsqueeze(2).expand(-1, -1, seq_len, -1)
        cols = projected.unsqueeze(1).expand(-1, seq_len, -1, -1)

        return torch.sum((rows - cols).pow(2), -1)

    
class L1DistanceLoss(nn.Module):
    """Custom L1 loss for distance matrices."""
    def __init__(self):
        super().__init__()

    def forward(self, predictions, label_batch, length_batch):
        """ Computes L1 loss on distance matrices.
        Ignores all entries where label_batch=-1
        Normalizes first within sentences (by dividing by the square of the sentence length)
        and then across the batch. Sentences of length 0 are left out of both steps.
        Args:
          predictions: A pytorch batch of predicted distances
          label_batch: A pytorch batch of true distances
          length_batch: A pytorch batch of sentence lengths
        Returns:
          A tuple of:
            batch_loss: average loss in the batch (a zero scalar on the device of
              `predictions` when every sentence has length 0)
            total_sents: number of sentences of non-zero length in the batch
        """
        labels_1s = (label_batch != -1).float()
        predictions_masked = predictions * labels_1s
        labels_masked = label_batch * labels_1s
        nonempty = length_batch != 0
        total_sents = torch.sum(nonempty).float()

        if total_sents > 0:
            loss_per_sent = torch.sum(torch.abs(predictions_masked - labels_masked), dim=(1,2))
            squared_lengths = length_batch.pow(2).float()
            # Only divide for sentences that have words: a zero-length sentence
            # would give 0/0 = NaN and poison the whole batch loss.
            normalized_loss_per_sent = loss_per_sent[nonempty] / squared_lengths[nonempty]
            batch_loss = torch.sum(normalized_loss_per_sent) / total_sents
        else:
            batch_loss = torch.zeros((), device=predictions.device)

        return batch_loss, total_sents

In [None]:
from torch import optim
import math
import tree_utils
import importlib
importlib.reload(tree_utils)
import copy


# I recommend you to write a method that can evaluate the UUAS & loss score for the dev (& test) corpus.
# Feel free to alter the signature of this method.
def evaluate_probe(probe, dataloader):
    """Return the mean loss and mean UUAS of `probe` over `dataloader`.

    The probe is switched to eval mode (and left there) and evaluated without
    recording gradients.

    Args:
        probe: module mapping a batch of representations to predicted distances.
        dataloader: iterable of (distances, embs, lengths) batches; all three are
            moved to the global `device`.

    Returns:
        (loss, uuas):
        loss: per-sentence average of the L1 distance loss. Every batch loss is
            weighted by the sentence count reported by the loss function.
        uuas: average of `tree_utils.calc_uuas` over the sentences for which it
            is not NaN.
        Either value is 0.0 when there is nothing to average over.
    """
    loss_function = L1DistanceLoss()
    probe.eval()
    loss_sum = 0.0
    loss_sents = 0.0
    uuas_sum = 0.0
    uuas_sents = 0
    with torch.no_grad():
        for distances, embs, lengths in dataloader:
            embs = embs.to(device)
            distances = distances.to(device)
            lengths = lengths.to(device)
            outputs = probe(embs)

            # The batch loss is already a mean over its sentences; weighting it
            # by their number keeps the final average per sentence.
            batch_loss, batch_sents = loss_function(outputs, distances, lengths)
            batch_sents = float(batch_sents)
            loss_sum += batch_loss.item() * batch_sents
            loss_sents += batch_sents

            for i in range(len(distances)):
                # Cut the padding away before scoring the sentence.
                l = int(lengths[i])
                preds = outputs[i, 0:l, 0:l]
                gold = distances[i, 0:l, 0:l]

                u = tree_utils.calc_uuas(preds, gold)
                # NaN scores are left out of the UUAS average entirely.
                if math.isnan(u):
                    continue
                uuas_sum += float(u)
                uuas_sents += 1

    mean_loss = loss_sum / loss_sents if loss_sents else 0.0
    mean_uuas = uuas_sum / uuas_sents if uuas_sents else 0.0
    return mean_loss, mean_uuas

# Feel free to alter the signature of this method.
def train_structural(probe, dataloader, dev_dataloader, epochs=100, warmup_steps = 10,p =False):
    """Train a structural probe with Adam (lr=1e-3), early stopping on dev UUAS.

    After every epoch the probe is scored with `evaluate_probe`. Once
    `epoch > warmup_steps`, each epoch whose dev UUAS is lower than the previous
    epoch's costs one unit of patience (3 in total); any other epoch restores
    it. Training stops when patience reaches zero.

    Args:
        probe: module to train (already on the global `device`).
        dataloader: iterable of (distances, embs, lengths) training batches.
        dev_dataloader: same, for the dev set.
        epochs: maximum number of epochs.
        warmup_steps: epoch index up to which patience is never reduced.
        p: when True, print dev loss and UUAS every epoch.

    Returns:
        (best_state_dict, best_epoch): a deep copy of the state dict of the epoch
        with the highest dev UUAS, and that epoch's index. The state dict is
        None only when `epochs` is 0.
    """
    lr = 1e-3
    max_patience = 3

    optimizer = optim.Adam(probe.parameters(), lr=lr)
    loss_function = L1DistanceLoss()
    prev_dev_uuas = 0.0
    patience = max_patience
    best_epoch = 0
    best_model = None
    best_dev_uuas = 0.0
    for epoch in range(epochs):
        probe.train()
        for distances, embs, lengths in dataloader:
            embs = embs.to(device)
            distances = distances.to(device)
            lengths = lengths.to(device)
            outputs = probe(embs)
            loss = loss_function(outputs, distances, lengths)[0]
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        dev_loss, dev_uuas = evaluate_probe(probe, dev_dataloader)

        if p:
            print("Epoch", epoch, "Dev loss and uuas", dev_loss, dev_uuas)
        if dev_uuas < prev_dev_uuas and epoch > warmup_steps:
            patience -= 1
        else:
            patience = max_patience
        # `best_model is None` guarantees a snapshot even if dev UUAS is 0.0 in
        # every epoch, so callers can always load the returned state.
        if best_model is None or dev_uuas > best_dev_uuas:
            best_dev_uuas = dev_uuas
            best_model = copy.deepcopy(probe.state_dict())
            best_epoch = epoch
        prev_dev_uuas = dev_uuas
        if patience == 0:
            break

    return best_model, best_epoch


In [None]:
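# Structural probe sweep
# Probe rank: 16 64 128
# Dropout: 0.0 0.2 0.4 0.6 0.8
# Tasks: 'dep' and 'controldep'
# Models: LSTM vs GPT-2 transformer
# result_dict[task][model_type][rank][dropout][seed] -> {'state_dict': trained probe, 'dev_acc': dev UUAS, 'test_acc': test UUAS, 'epochs': index of the best epoch}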
import pickle
import os 
import time 

# Safety latch: abort once the sweep's output file exists. The name has to be
# the one this cell writes at its end ("structresults_and_models.pickle"),
# otherwise the check can never fire.
if os.path.exists('structresults_and_models.pickle'):
    raise ValueError("Do not run this ... ")

def eval_model(task, model_type, rank, dropout, seed):
    """Train and evaluate one structural (or control) probe.

    NOTE: this definition replaces the PoS `eval_model` defined earlier in the
    notebook, because both share the same name.

    Args:
        task: 'dep' selects the gold distances; any other value selects the
            structural control distances.
        model_type: 'lstm' or 'transformer' (GPT-2).
        rank: rank of the probe's projection matrix.
        dropout: dropout probability of the probe.
        seed: seed applied to torch and numpy before anything else.

    Returns:
        (model, dev_acc, test_acc, epochs): the probe with its best state loaded,
        the (loss, uuas) pairs returned by `evaluate_probe` on dev and test, and
        the index of the best epoch.

    Raises:
        ValueError: if `model_type` is neither 'lstm' nor 'transformer'.
    """
    # Imported here so the function also works on a fresh kernel; the
    # notebook's own import of these names only happens in a later cell.
    from datasets import StructuralDataset, pad_batch

    torch.manual_seed(seed)
    np.random.seed(seed)

    if model_type == 'lstm':
        x_train, x_dev, x_test = train_xyLSTM[1], dev_xyLSTM[1], test_xyLSTM[1]
        model_dim = 650
    elif model_type == 'transformer':
        x_train, x_dev, x_test = train_xy[1], dev_xy[1], test_xy[1]
        model_dim = 768
    else:
        raise ValueError("Unknown model_type: %r" % (model_type,))

    # Gold distances are always taken from the GPT-2 data (element [0]).
    if task == 'dep':
        y_train, y_dev, y_test = train_xy[0], dev_xy[0], test_xy[0]
    else:
        y_train, y_dev, y_test = struct_train_control, struct_dev_control, struct_test_control

    train_loader = data.DataLoader(StructuralDataset(y_train, x_train), batch_size=32,
                                   collate_fn=pad_batch, shuffle=True)
    dev_loader = data.DataLoader(StructuralDataset(y_dev, x_dev), collate_fn=pad_batch, batch_size=32)
    test_loader = data.DataLoader(StructuralDataset(y_test, x_test), collate_fn=pad_batch, batch_size=32)

    model = StructuralProbe(model_dim, rank=rank, dropout=dropout).to(device)
    best_state_dict, epochs = train_structural(model, train_loader, dev_loader, epochs=30, warmup_steps=6)
    model.load_state_dict(best_state_dict)
    dev_acc = evaluate_probe(model, dev_loader)
    test_acc = evaluate_probe(model, test_loader)
    return model, dev_acc, test_acc, epochs


print(device)

# Full structural sweep; results are nested as
# result_dict[task][model_type][rank][dropout][seed].
result_dict = {}

for task in ('dep', 'controldep'):
    per_task = result_dict[task] = {}
    for model_type in ('lstm', 'transformer'):
        print("Starting", model_type)
        per_model = per_task[model_type] = {}
        for rank in (16, 64, 128):
            per_rank = per_model[rank] = {}
            for dropout in (0, 0.2, 0.4, 0.6, 0.8):
                per_dropout = per_rank[dropout] = {}
                for seed in (10, 20, 30):
                    starttime = time.time()
                    record = per_dropout[seed] = {}
                    # eval_model returns the trained probe itself (stored under
                    # 'state_dict') and (loss, uuas) pairs for dev and test;
                    # only the uuas (index 1) is kept.
                    trained_probe, dev_acc, test_acc, epochs = eval_model(
                        task, model_type, rank, dropout, seed)
                    record['state_dict'] = trained_probe
                    record['dev_acc'] = dev_acc[1]
                    record['test_acc'] = test_acc[1]
                    record['epochs'] = epochs
                    print(task,model_type,rank,dropout,seed, epochs, test_acc)
                    stoptime = time.time() - starttime
                    print('Time elapsed %s' % stoptime)

with open("structresults_and_models.pickle", "wb") as f:
    pickle.dump(result_dict, f)
print("All results are safe. You can sleep peacefully. ")

In [None]:

from datasets import StructuralDataset, pad_batch

# Single structural probe on the GPT-2 representations.
batch_size = 32
train_loader = data.DataLoader(StructuralDataset(*train_xy), batch_size=batch_size, collate_fn=pad_batch, shuffle=True)
# Evaluation loaders are not shuffled: batch composition then stays the same
# from one evaluation to the next.
dev_loader = data.DataLoader(StructuralDataset(*dev_xy), batch_size=batch_size, collate_fn=pad_batch)
test_loader = data.DataLoader(StructuralDataset(*test_xy), batch_size=batch_size, collate_fn=pad_batch)

emb_dim = 768
rank = 64
probe = StructuralProbe(emb_dim, rank, dropout=0.2).to(device)

best_state_dict, epochs = train_structural(probe, train_loader, dev_loader, epochs=20, warmup_steps=6, p=True)
probe.load_state_dict(best_state_dict)
# evaluate_probe returns (loss, uuas) pairs.
dev_acc = evaluate_probe(probe, dev_loader)
test_acc = evaluate_probe(probe, test_loader)
print(dev_acc, test_acc)

In [None]:
# Show the current values of dev_acc and test_acc again (whatever the most
# recently executed cell left in them).
print(dev_acc,test_acc)

In [None]:
from datasets import StructuralDataset, pad_batch

# Structural control task on the GPT-2 representations.
batch_size = 32
ctrain_loader = data.DataLoader(StructuralDataset(struct_train_control, train_xy[1]), batch_size=batch_size, collate_fn=pad_batch)
cdev_loader = data.DataLoader(StructuralDataset(struct_dev_control, dev_xy[1]), batch_size=batch_size, collate_fn=pad_batch)
ctest_loader = data.DataLoader(StructuralDataset(struct_test_control, test_xy[1]), batch_size=batch_size, collate_fn=pad_batch)

emb_dim = 768
rank = 64
probe = StructuralProbe(emb_dim, rank).to(device)
print(probe)
# train_structural takes (probe, train loader, dev loader, ...) and returns
# (best_state_dict, best_epoch); the test set is evaluated separately afterwards.
best_state_dict, epochs = train_structural(probe, ctrain_loader, cdev_loader, epochs=40)
probe.load_state_dict(best_state_dict)
print(evaluate_probe(probe, cdev_loader), evaluate_probe(probe, ctest_loader))

In [None]:
import matplotlib.pyplot as plt
import pickle

# Load the structural sweep results and print them.
# pickle.load can execute arbitrary code: only open files produced by this notebook.
struct_results_path = "structresults_and_models.pickle"
with open(struct_results_path, "rb") as results_file:
    dd = pickle.load(results_file)
print(dd)
def mean_test_acc(the_dict):
    """Return (mean, std) of the 'test_acc' entries stored in `the_dict`.

    Same definition as the one used for the PoS results: `the_dict` maps a key
    (the seed in this notebook) to a record with a 'test_acc' field.
    """
    test_accs = list(map(lambda key: the_dict[key]['test_acc'], the_dict))
    average = np.mean(test_accs)
    spread = np.std(test_accs)
    return average, spread

fig, (ax1, ax2) = plt.subplots(1, 2)

# Dropout values of the sweep; shared x-axis of both panels.
x_axis = [0, 0.2, 0.4, 0.6, 0.8]

# `model_name` (not `model`) so the loop does not overwrite a probe stored in
# the global `model`. `hidden_size` holds the probe rank in this plot.
for model_name in ['lstm', 'transformer']:
    for hidden_size in [16, 64, 128]:
        y_axis = []
        y_axis_select = []
        for d in x_axis:
            # Mean test score over seeds for the dependency task
            mean, _ = mean_test_acc(dd['dep'][model_name][hidden_size][d])

            # Mean test score over seeds for the control task
            mean_control, _ = mean_test_acc(dd['controldep'][model_name][hidden_size][d])

            y_axis.append(mean)
            # Selectivity = task score minus control-task score
            y_axis_select.append(mean - mean_control)

        ax1.plot(x_axis, y_axis, '--o', label=model_name + str(hidden_size))
        ax2.plot(x_axis, y_axis_select, '--o')

# One legend once every line has been drawn.
ax1.legend()

ax1.set_title("Accuracy")
ax2.set_title("Selectivity")
ax1.set_xlabel("Dropout")
ax2.set_xlabel("Dropout")

plt.tight_layout()
plt.show()