In [1]:
import torch
import numpy as np
import time

from collections import defaultdict
from typing import List
from conllu import parse_incr, TokenList
from torch import Tensor
from transformers import GPT2Model, GPT2Tokenizer
CUTOFF = 2000
from lstm.model import RNNModel
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [2]:
transformer = GPT2Model.from_pretrained('distilgpt2', output_hidden_states=True)
print("Model ready")
tokenizer = GPT2Tokenizer.from_pretrained('distilgpt2')
print("Tokenizer ready")
# Note that some models don't return the hidden states by default.
# This can be configured by passing `output_hidden_states=True` to the `from_pretrained` method.

Model ready
Tokenizer ready


In [3]:
# The Gulordava LSTM model can be found here: 
# https://drive.google.com/open?id=1w47WsZcZzPyBKDn83cMNd0Hb336e-_Sy
#
# N.B: I have altered the RNNModel code to only output the hidden states that you are interested in.
# If you want to do more experiments with this model you could have a look at the original code here:
# https://github.com/facebookresearch/colorlessgreenRNNs/blob/master/src/language_models/model.py
#
model_location = 'lstm/gulordava.pt'
lstm = RNNModel('LSTM', 50001, 650, 650, 2)
lstm.load_state_dict(torch.load(model_location))


# This LSTM does not use a Tokenizer like the Transformers, but a Vocab dictionary that maps a token to an id.
with open('lstm/vocab.txt') as f:
    w2i = {w.strip(): i for i, w in enumerate(f)}

vocab = defaultdict(lambda: w2i["<unk>"])
vocab.update(w2i)
i2w = { w2i[k]:k for k in w2i}

## Load All Data

In [4]:
from utils import create_or_load_pos_data
from controltasks import save_or_load_pos_controls 
from datasets import find_distribution, POSDataset
import torch.utils.data as data 
import time

start = time.time()
train_x, train_y, vocab, words_train = create_or_load_pos_data("train", transformer, tokenizer, cutoff=CUTOFF)
end = time.time() - start
print("Time %s" % end )
dev_x, dev_y, vocab, words_dev = create_or_load_pos_data("dev", transformer, tokenizer, vocab, cutoff=CUTOFF)
test_x, test_y, vocab, words_test = create_or_load_pos_data("test", transformer, tokenizer, vocab, cutoff=CUTOFF)
dist = find_distribution(data.DataLoader(POSDataset(train_x, train_y), batch_size=1))

flatten_train = [word for sublist in words_train for word in sublist]
flatten_dev   = [word for sublist in words_dev for word in sublist]
flatten_test  = [word for sublist in words_test for word in sublist]

ypos_train_control, ypos_dev_control, ypos_test_control = save_or_load_pos_controls(
    train_x, train_y, [flatten_train, flatten_dev, flatten_test], dist)

Time 0.09005069732666016


In [5]:
from tree_utils import create_or_load_structural_data
from controltasks import save_or_load_struct_controls

train_xy = create_or_load_structural_data("train", transformer, tokenizer, cutoff=CUTOFF)
dev_xy = create_or_load_structural_data("dev", transformer, tokenizer, cutoff=CUTOFF)
test_xy = create_or_load_structural_data("test", transformer, tokenizer, cutoff=CUTOFF)
print(len(train_xy))
struct_train_control, struct_dev_control, struct_test_control = save_or_load_struct_controls(cutoff=CUTOFF)

Fetching for 2000
Doing LSTM: False
Data created,pickling




Fetching for 2000
Doing LSTM: False
Data created,pickling
Fetching for 2000
Doing LSTM: False
Data created,pickling
2


## PoS Models

In [None]:
# DIAGNOSTIC CLASSIFIER
import torch.nn as nn
class POSProbe(nn.Module):
    def __init__(self, repr_size, pos_size):
        super().__init__()
        self.linear = nn.Linear(repr_size, pos_size)
        
    def forward(self, x):
        return self.linear(x)
    
def eval_given_dataloader(loader, model):
    model.eval()
    correct = 0.0
    total = 0.0
    for x,y in loader:
        x = x.to(device)
        y = y.to(device)
        outputs = model(x)
        preds = torch.argmax(outputs,dim=1)
        c = torch.sum(torch.eq(preds, y))
        correct += c.item()
        total += y.shape[0]
    return correct/total
    
def train(my_model, train_loader, dev_loader, epoch_amount = 10):
    ce = nn.CrossEntropyLoss()
    optim = torch.optim.Adam(my_model.parameters())
    for i in range(epoch_amount):
        my_model.train()
        epoch_correct = 0.0
        epoch_total = 0.0
        for x,y in train_loader:
            
            x = x.to(device)
            y = y.to(device)
            outputs = my_model(x)
            preds = torch.argmax(outputs,dim=1)
            correct = torch.sum(torch.eq(preds, y))
            accuracy = correct.item()/y.shape[0]
            loss = ce(outputs, y)

            optim.zero_grad()
            loss.backward()
            optim.step()
            
            epoch_correct += correct.item()
            epoch_total += y.shape[0]
        print("Epoch",i,"accuracy", epoch_correct/epoch_total, eval_given_dataloader(dev_loader, my_model))

In [None]:
# Normal task
ntrain_loader = data.DataLoader(POSDataset(train_x, train_y), batch_size=16)
ndev_loader = data.DataLoader(POSDataset(dev_x, dev_y), batch_size=16)
ntest_loader = data.DataLoader(POSDataset(test_x, test_y), batch_size=16)

model = POSProbe(768, len(dist)).to(device)
train(model, ntrain_loader, ndev_loader, 10)
print("Test accuracy", eval_given_dataloader(ntest_loader, model))

In [None]:
# Normal task
ctrain_loader = data.DataLoader(POSDataset(train_x, ypos_train_control), batch_size=16)
cdev_loader = data.DataLoader(POSDataset(dev_x, ypos_dev_control), batch_size=16)
ctest_loader = data.DataLoader(POSDataset(test_x, ypos_test_control), batch_size=16)
print(len(ypos_train_control), len(train_x))
print(len(ypos_dev_control), len(dev_x))
print(len(ypos_test_control), len(test_x))

model = POSProbe(768, len(dist)).to(device)
train(model, ctrain_loader, cdev_loader, 10)
print("Test accuracy", eval_given_dataloader(ctest_loader, model))

## Structural

In [6]:
import torch.nn as nn
import torch


class StructuralProbe(nn.Module):
    """ Computes squared L2 distance after projection by a matrix.
    For a batch of sentences, computes all n^2 pairs of distances
    for each sentence in the batch.
    """
    def __init__(self, model_dim, rank, device="cpu"):
        super().__init__()
        self.probe_rank = rank
        self.model_dim = model_dim
        
        self.proj = nn.Parameter(data = torch.zeros(self.model_dim, self.probe_rank))
        
        nn.init.uniform_(self.proj, -0.05, 0.05)
        self.to(device)

    def forward(self, batch):
        """ Computes all n^2 pairs of distances after projection
        for each sentence in a batch.
        Note that due to padding, some distances will be non-zero for pads.
        Computes (B(h_i-h_j))^T(B(h_i-h_j)) for all i,j
        Args:
          batch: a batch of word representations of the shape
            (batch_size, max_seq_len, representation_dim)
        Returns:
          A tensor of distances of shape (batch_size, max_seq_len, max_seq_len)
        """
        transformed = torch.matmul(batch, self.proj)
        
        batchlen, seqlen, rank = transformed.size()
        
        transformed = transformed.unsqueeze(2)
        transformed = transformed.expand(-1, -1, seqlen, -1)
        transposed = transformed.transpose(1,2)
        
        diffs = transformed - transposed
        
        squared_diffs = diffs.pow(2)
        squared_distances = torch.sum(squared_diffs, -1)

        return squared_distances

    
class L1DistanceLoss(nn.Module):
    """Custom L1 loss for distance matrices."""
    def __init__(self):
        super().__init__()

    def forward(self, predictions, label_batch, length_batch):
        """ Computes L1 loss on distance matrices.
        Ignores all entries where label_batch=-1
        Normalizes first within sentences (by dividing by the square of the sentence length)
        and then across the batch.
        Args:
          predictions: A pytorch batch of predicted distances
          label_batch: A pytorch batch of true distances
          length_batch: A pytorch batch of sentence lengths
        Returns:
          A tuple of:
            batch_loss: average loss in the batch
            total_sents: number of sentences in the batch
        """
        labels_1s = (label_batch != -1).float()
        predictions_masked = predictions * labels_1s
        labels_masked = label_batch * labels_1s
        total_sents = torch.sum((length_batch != 0)).float()
        squared_lengths = length_batch.pow(2).float()

        if total_sents > 0:
            loss_per_sent = torch.sum(torch.abs(predictions_masked - labels_masked), dim=(1,2))
            normalized_loss_per_sent = loss_per_sent / squared_lengths
            batch_loss = torch.sum(normalized_loss_per_sent) / total_sents
        
        else:
            batch_loss = torch.tensor(0.0)
        
        return batch_loss, total_sents

In [7]:
from torch import optim
import math
import tree_utils
import importlib
importlib.reload(tree_utils)

# I recommend you to write a method that can evaluate the UUAS & loss score for the dev (& test) corpus.
# Feel free to alter the signature of this method.
def evaluate_probe(probe, dataloader):
    loss_function =  L1DistanceLoss()
    probe.eval()
    total_loss = 0.0
    total_uuas = 0.0
    amt = 0.0
    for distances, embs, lengths in dataloader:
        embs = embs.to(device)
        distances = distances.to(device)
        lengths = lengths.to(device)
        amt += len(distances)
        outputs = probe(embs)
        loss = loss_function(outputs, distances, lengths)[0]
        total_loss += loss.item()
        for i in range(len(distances)):
            l = lengths[i]
            preds = outputs[i,0:l, 0:l]
            gold = distances[i,0:l, 0:l]
            
            u = tree_utils.calc_uuas(preds, gold)
            if math.isnan(u):
                amt -= 1
            # This if statement is a hack so nans don't get counted
            if u >= 0: total_uuas += u
    
    return total_loss/amt, total_uuas/amt

# Feel free to alter the signature of this method.
def train_structural(probe, dataloader, dev_dataloader,test_loader, epochs=100):
    lr = 1e-5
    batch_size = 128
    
    optimizer = optim.Adam(probe.parameters(), lr=lr)
    #scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5,patience=1)
    loss_function =  L1DistanceLoss()

    for epoch in range(epochs):
        probe.train()
        for distances, embs, lengths in dataloader:
            embs = embs.to(device)
            distances = distances.to(device)
            lengths = lengths.to(device)
            outputs = probe(embs)
            loss = loss_function(outputs, distances, lengths)[0]
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        dev_loss, dev_uuas = evaluate_probe(probe, dev_dataloader)
        print("Epoch", epoch, "Dev loss and uuas", dev_loss, dev_uuas)
        # Using a scheduler is up to you, and might require some hyper param fine-tuning
        #scheduler.step(dev_loss)

    test_loss, test_uuas = evaluate_probe(probe, test_loader)
    print("Test loss, uuas", test_loss, test_uuas)


In [8]:

from datasets import StructuralDataset, pad_batch

batch_size = 32
train_loader = data.DataLoader(StructuralDataset(*train_xy), batch_size=batch_size, collate_fn= pad_batch, shuffle=True)
dev_loader = data.DataLoader(StructuralDataset(*dev_xy), batch_size=batch_size, collate_fn= pad_batch, shuffle=True)
test_loader = data.DataLoader(StructuralDataset(*test_xy), batch_size=batch_size, collate_fn= pad_batch, shuffle=True)

emb_dim = 768
rank = 64
probe = StructuralProbe(emb_dim, rank).to(device)
print(probe)
train_structural(probe, train_loader, dev_loader, test_loader, epochs=25)

StructuralProbe()


  uuas = np.sum([pred_edge in gold_edges for pred_edge in pred_edges]) / len(gold_edges)


Epoch 0 Dev loss and uuas 4.322278783697831 0.2554878169040452
Epoch 1 Dev loss and uuas 3.995628646047492 0.25578798663549884
Epoch 2 Dev loss and uuas 3.6987896326968546 0.25512165235793105
Epoch 3 Dev loss and uuas 3.429826969347502 0.2553972539621388
Epoch 4 Dev loss and uuas 3.168339699193051 0.2566106704089192
Epoch 5 Dev loss and uuas 2.9386144136127674 0.2564556272393384
Epoch 6 Dev loss and uuas 2.7277887806139494 0.2570196120596311
Epoch 7 Dev loss and uuas 2.530672193828382 0.2576877365284407
Epoch 8 Dev loss and uuas 2.3508224788464998 0.2576064262261719
Epoch 9 Dev loss and uuas 2.1809371888010127 0.2579926265895921
Epoch 10 Dev loss and uuas 2.0264882398906505 0.25747813217573556
Epoch 11 Dev loss and uuas 1.8880232519852487 0.25860887347670963
Epoch 12 Dev loss and uuas 1.7560892908196701 0.2589392123268599
Epoch 13 Dev loss and uuas 1.642160379510177 0.2586569617161382
Epoch 14 Dev loss and uuas 1.5221234432019686 0.2594767737507225
Epoch 15 Dev loss and uuas 1.41766385

In [9]:
from datasets import StructuralDataset, pad_batch

batch_size = 32
ctrain_loader = data.DataLoader(StructuralDataset(struct_train_control, train_xy[1]), batch_size=batch_size, collate_fn= pad_batch)
cdev_loader = data.DataLoader(StructuralDataset(struct_dev_control, dev_xy[1]), batch_size=batch_size, collate_fn= pad_batch)
ctest_loader = data.DataLoader(StructuralDataset(struct_test_control, test_xy[1]), batch_size=batch_size, collate_fn= pad_batch)

emb_dim = 768
rank = 64
probe = StructuralProbe(emb_dim, rank).to(device)
print(probe)
train_structural(probe, ctrain_loader, cdev_loader, ctest_loader, epochs=25)

StructuralProbe()
Epoch 0 Dev loss and uuas 4.107690293161492 0.26768809081198003
Epoch 1 Dev loss and uuas 3.8062079359355727 0.26781423561827383
Epoch 2 Dev loss and uuas 3.5304531659578022 0.26824722242866683
Epoch 3 Dev loss and uuas 3.277516118099815 0.26889425368342434
Epoch 4 Dev loss and uuas 3.0451397303531045 0.2692873369640357
Epoch 5 Dev loss and uuas 2.8313979700991982 0.2687838029042891
Epoch 6 Dev loss and uuas 2.634600693552118 0.2682790186142832
Epoch 7 Dev loss and uuas 2.453245415938528 0.2678841727768955
Epoch 8 Dev loss and uuas 2.285989004436292 0.2684587114568794
Epoch 9 Dev loss and uuas 2.131622334530479 0.26874505800212883
Epoch 10 Dev loss and uuas 1.9890561254400956 0.26969611653548425
Epoch 11 Dev loss and uuas 1.8573055949964021 0.2703686675804443
Epoch 12 Dev loss and uuas 1.735480341660349 0.27183069372828284
Epoch 13 Dev loss and uuas 1.6227739434493216 0.27239191889971776
Epoch 14 Dev loss and uuas 1.5184553066052888 0.2727452519103894
Epoch 15 Dev los

In [11]:
importlib.reload(tree_utils)

<module 'tree_utils' from '/home/anna/Documents/uni/nlp2/nlp2-probing-lms/tree_utils.py'>

In [31]:
from tree_utils import *
struct_train_control[1].shape
#print(train_xy[1][0].shape)
t = edges(create_mst(train_xy[0][1]))
print_tikz(t, edges(create_mst(struct_train_control[1])), ["w"]*18, 'test')