# Requirements
Requires an Nvidia GPU to run

Create a new anaconda environment and run the following commands to install the required libraries 
```
conda install pytorch torchvision torchaudio cudatoolkit=11.1 -c pytorch -c conda-forge
conda install gensim
pip install torchdyn
pip install git+https://github.com/google-research/torchsde.git
```

# Citations
- Marcus, Mitchell P., Marcinkiewicz, Mary Ann & Santorini, Beatrice (1993). Building a Large Annotated Corpus of English: The Penn Treebank

```
@article{poli2020torchdyn,
  title={TorchDyn: A Neural Differential Equations Library},
  author={Poli, Michael and Massaroli, Stefano and Yamashita, Atsushi and Asama, Hajime and Park, Jinkyoo},
  journal={arXiv preprint arXiv:2009.09346},
  year={2020}
}
```

- GloVe

- GPT2 paper

- Huggingface for their implementation of transformers? Not sure if this has a paper


# Imports

In [None]:
!pip install torchdyn
!pip install torchinfo
!pip install git+https://github.com/google-research/torchsde.git

Collecting torchdyn
  Downloading https://files.pythonhosted.org/packages/31/9b/56bd9cc4cf9f726e347a418920ea249fe91caf666bbcd1b619143a9386d6/torchdyn-0.2.2.1-py3-none-any.whl
Collecting torchdiffeq>=0.0.1
  Downloading https://files.pythonhosted.org/packages/63/c2/daf5cc6c548f789d0f5222a6daecb8a76d72ad2fa96d958d46cb85f7ae3a/torchdiffeq-0.2.1-py3-none-any.whl
Collecting dgl>=0.4.1
[?25l  Downloading https://files.pythonhosted.org/packages/71/c4/ce24841375cf4393787dbf9a645e271c19a03d2d9a0e5770b08ba76bcfde/dgl-0.6.1-cp37-cp37m-manylinux1_x86_64.whl (4.4MB)
[K     |████████████████████████████████| 4.4MB 28.4MB/s 
[?25hCollecting pytorch-lightning>=0.8.4
[?25l  Downloading https://files.pythonhosted.org/packages/42/80/03dcc7241722bffd5b76b796d31dbb1bbdc34a90d653de6b47c8ad9ffd73/pytorch_lightning-1.3.3-py3-none-any.whl (806kB)
[K     |████████████████████████████████| 808kB 37.0MB/s 
Collecting tensorboard!=2.5.0,>=2.2.0
[?25l  Downloading https://files.pythonhosted.org/packages/64/21

In [None]:
import torch
import torch.utils.data
import torchtext
import numpy as np
import gensim.downloader as api
from functools import reduce
from sklearn.metrics import *
from time import time
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

from torchdyn.models import *
from torchdyn import *
import pytorch_lightning as pl

from nltk.translate.bleu_score import sentence_bleu
from nltk.translate.bleu_score import SmoothingFunction

from torchinfo import summary

In [None]:
def numpy_to_tensor(array):
    return torch.from_numpy(array).to(device).float()

# LSTM Baseline
Create a baseline RNN and evaluate it's perplexity

To do
- Use LSTM as baseline
    - Examine perplexity of model on validation set
- Implement Neural ODE

## Data Processing

In [None]:
%%time
# load word embeddings
glove = api.load("glove-wiki-gigaword-300")

CPU times: user 3min, sys: 9.86 s, total: 3min 10s
Wall time: 3min 35s


In [None]:
train, valid, test = torchtext.datasets.PennTreebank(split=('train', 'valid', 'test')) # len(train) = 4.2w
# train, valid, test = torchtext.datasets.WikiText2(split=('train', 'valid', 'test')) # len(train) = 3.6w
train = list(train) # these are originally iterators, the data is so small we can just retrieve all of it at once
valid = list(valid)
test  = list(test)

In [None]:
# build the vocab
corpus = train + valid
vocab = {"<PAD>": 0}
index_vocab = {0 : "<PAD>"}
for sentence in corpus:
    for token in sentence.split(" ")[1:]:
        if token not in vocab:
            index = len(vocab)
            vocab[token] = index
            index_vocab[index] = token

# replace penn treebank end sentence token "\n" with glove's end sentence token "."
index = vocab["\n"]
vocab.pop("\n")         
vocab["."] = index
index_vocab[index] = "."

# view size
vocab_size = len(vocab)
print("Vocab size: ", vocab_size)

Vocab size:  10001


In [None]:
# pad sentences and convert words to their glove vector to get input features
# convert to 1 hot vocab and shift 1 to the left to get output labels (converting to 1 hot takes too much memory, so just store indices and convert later)
# use left padding, as we want the hidden state at the end (right) to ignore the padding
# returns word_vector_dataset, labels
def preprocess(dataset, sequence_length, wv):
    embedding_size = wv["hello"].shape[0]
    processed = np.zeros((len(dataset), sequence_length, embedding_size))
    labels = np.zeros((len(dataset), sequence_length, 1))
    
    for i in range(len(dataset)):
        tokens = dataset[i].split(" ")[1:]
        
        # get the word vectors for all of the tokens, removing out of vocabulary (OOV) tokens
        tokens_np = np.zeros((len(tokens), embedding_size))
        labels_np = np.zeros((len(tokens), 1))
        j = 0
        for word in tokens:
            if word == "\n": word = "." # replace PennTreebank end sentence token '\n' with glove end sentence token "."
            if word not in wv: continue # ignore OOV tokens
            if j < sequence_length - 1: # only add sequence_length - 1 tokens at max
                # so that there is always a 0 vector at the start so the model learns most common starting words
                tokens_np[j, :] = wv[word]
            # we can look ahead to find the next word to set as the label for the last word
            if j < sequence_length:
                labels_np[j, :] = vocab[word]
            else: break
            j += 1
            
        tokens_np = tokens_np[:j-1, :]
        labels_np = labels_np[:j, :]
        
        # add this sentence to the overall dataset, with left padding of 0 vectors
        processed[i, sequence_length - tokens_np.shape[0]:, :] = tokens_np
        labels[i, sequence_length - labels_np.shape[0]:, :] = labels_np
    return processed, labels

In [None]:
sequence_length = 20
train_X, train_y = preprocess(train, sequence_length, glove)
valid_X, valid_y = preprocess(valid, sequence_length, glove)
test_X , test_y  = preprocess(test,  sequence_length, glove)

## Model Training

In [None]:
class LSTMModel(torch.nn.Module):
    def __init__(self, vocab_size, input_size=100, layer_size=100, dropout=0):
        super().__init__()
        self.LSTM = torch.nn.LSTM(input_size, layer_size, 1, bidirectional=False)
        self.dropout = torch.nn.Dropout(p=dropout)
        self.linear = torch.nn.Linear(layer_size, vocab_size)
        self.softmax = torch.nn.Softmax(dim=-1)
    
    def forward(self, x):
        # convert words to their vectors here
        sequence_outputs, hidden_state = self.LSTM(x)
        sequence_outputs = self.dropout(sequence_outputs)
        pred = self.linear(sequence_outputs)
        return pred
    
    # wrapper function that forward propagates, applies softmax and converts to numpy 
    def predict(self, x):
        preds = self.forward(x)
        preds = self.softmax(preds).detach().cpu().numpy()
        return preds

In [None]:
model = LSTMModel(vocab_size, input_size=300, layer_size=300, dropout=0.1)
model.to(device)
summary(model)

LSTMModel(
  (LSTM): LSTM(300, 300)
  (dropout): Dropout(p=0.1, inplace=False)
  (linear): Linear(in_features=300, out_features=10001, bias=True)
  (softmax): Softmax(dim=-1)
)

In [None]:
# clear memory
del data
torch.cuda.empty_cache()

### Evaluation Methods

In [None]:
# Define functions to calculate perplexity for a single sentence: see the metric definition here https://web.stanford.edu/~jurafsky/slp3/3.pdf 
# We use teacher forcing (feeding the ground_truth label for sequence i to get pred for sequence i+1) to get the predictions
def perplexity(preds, ground_truth, epsilon=1e-30):
    probs = []
    for i in range(preds.shape[1]):
        probs.append(preds[0, i, int(ground_truth[i])])
    probs = np.array(probs)
    probs = np.power(1/(probs+epsilon), 1/probs.shape[0]) # normalise before taking the product, to prevent underflowing to 0
    return np.prod(probs)

# Calculate overall perplexity for a dataset
def average_perplexity(model, X, y):
    perplexities = [perplexity(model.predict(numpy_to_tensor(X[i:i+1])), y[i]) for i in range(X.shape[0])]
    return np.mean(perplexities)


# Define the functions to compute the bleu score, 
# in our particular case, reference should be multiple sentences - all sentences as label, against the candidate - predicted sentence
# here I use the 10001 vocab to represent each word (num also works in bleu)
# smoothie is will help when sentence is too short
# https://stackoverflow.com/questions/46444656/bleu-scores-could-i-use-nltk-translate-bleu-score-sentence-bleu-for-calculating
def bleu(preds, reference, smoothie):
    candidate = np.argmax(preds, axis=2)[0]
    return sentence_bleu(reference, candidate, smoothing_function=smoothie)


def average_bleu(model, X, y):
    smoothie = SmoothingFunction().method4
    bleus = []
    for i in range(X.shape[0]):
        preds = model.predict(numpy_to_tensor(X[i:i+1]))
        reference = [list(train_y[i].flatten())]
        bleus.append(bleu(preds, reference, smoothie))
    return np.mean(bleus)


### Actual Training

In [None]:
# training the model
def train_model(model, train_X, train_y, epochs=10, learn_rate=0.01, weight_decay=0.001, minibatch_size=128, print_results=True):
    # Prepare data
    X = numpy_to_tensor(train_X)
    y = numpy_to_tensor(train_y).long()[:, :, 0]
    n_samples = X.shape[0]
    
    # Define loss and optimizer
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learn_rate, weight_decay=weight_decay)

    # Ensure this runs on gpu
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    
    for epoch in range(epochs):      
        model.train() # set to train flag
        start_ts = time()
        
        # shuffle the data
        new_indices = torch.randperm(n_samples)
        X = X[new_indices, :, :] 
        y = y[new_indices, :]
        
        for batch_n in range(int(np.ceil(n_samples/minibatch_size))):
            # get the minibatch
            start_index = batch_n * minibatch_size
            end_index = min(start_index + minibatch_size, n_samples)
            batch_X = X[start_index: end_index, :, :]
            batch_y = y[start_index: end_index, :]
            
            # forward + backward + optimize
            optimizer.zero_grad()
            outputs = model(batch_X) 
            outputs = torch.swapaxes(outputs, 1, 2) # cross entropy expects a tensor of (n_samples, n_outputs, sequence_length)
            loss = criterion(outputs, batch_y)
            loss.backward()
            optimizer.step()
            
        # evaluate performance on part of the data (for runtime reason we only use perplexity of validation dataset)
        if print_results:
            with torch.no_grad():
                model.eval()
                v_perplexity = average_perplexity(model, valid_X, valid_y)
                end_ts = time()
                print("Epoch {}, Minibatch loss: {:.2f}, Val Perplexity: {:.2f}, Epoch Time: {:.2f} seconds"
                .format(epoch, loss.item(), v_perplexity, end_ts - start_ts))
    
    del X
    del y
    torch.cuda.empty_cache()
    if print_results:
        print('Finished Training')
    return model

In [None]:
%%time
model = train_model(model, train_X, train_y, epochs=50, learn_rate=1e-3, minibatch_size=256, weight_decay=1e-5)

Epoch 0, Minibatch loss: 5.61, Val Perplexity: 4260.62, Epoch Time: 14.00 seconds
Epoch 1, Minibatch loss: 5.64, Val Perplexity: 2802.05, Epoch Time: 14.04 seconds
Epoch 2, Minibatch loss: 5.35, Val Perplexity: 2045.14, Epoch Time: 14.04 seconds
Epoch 3, Minibatch loss: 5.00, Val Perplexity: 1624.39, Epoch Time: 14.09 seconds
Epoch 4, Minibatch loss: 4.90, Val Perplexity: 1352.87, Epoch Time: 14.11 seconds
Epoch 5, Minibatch loss: 4.69, Val Perplexity: 1116.65, Epoch Time: 14.06 seconds
Epoch 6, Minibatch loss: 4.70, Val Perplexity: 972.72, Epoch Time: 14.07 seconds
Epoch 7, Minibatch loss: 4.74, Val Perplexity: 839.83, Epoch Time: 14.01 seconds
Epoch 8, Minibatch loss: 4.74, Val Perplexity: 731.71, Epoch Time: 13.94 seconds
Epoch 9, Minibatch loss: 4.62, Val Perplexity: 648.30, Epoch Time: 14.08 seconds
Epoch 10, Minibatch loss: 4.64, Val Perplexity: 581.80, Epoch Time: 14.09 seconds
Epoch 11, Minibatch loss: 4.33, Val Perplexity: 528.58, Epoch Time: 14.09 seconds
Epoch 12, Minibatch 

## Hyper-Parameter tuning findings
- glove 300 dimension vectors are essential to not have a bias of 1000 perplexity on both train and validation
- 2 layers of LSTM also gives high bias, perhaps there is not enough train data
- Weight decay is essential in preventing Validation perplexity from skyrocketing
- Dropout of 0.1 combined with weight decay 0.00001 works (around 250 validation perplexity)
- Decreasing learning rate and increasing epochs has a minor benefit

## Examine Performance of the model
- Using both perplexity and qualitative evaluation

### Performance of Teacher Forcing

In [None]:
%%time
model.eval()
print('Train perplexity is ', average_perplexity(model, train_X, train_y))

In [None]:
average_perplexity(model, valid_X, valid_y)

288.61844

In [None]:
average_perplexity(model, test_X, test_y)

263.7379

In [None]:
%%time
average_bleu(model, train_X, train_y)

CPU times: user 59.3 s, sys: 1.67 s, total: 1min
Wall time: 1min


0.2952516745473606

In [None]:
average_bleu(model, valid_X, valid_y)

0.23072047692589107

In [None]:
average_bleu(model, test_X, test_y)

0.2289468706162604

### Performance of Extrapolation

#### Demo + Methods

In [None]:
%%time
idx = 20

input = numpy_to_tensor(train_X[idx:idx+1])
input_idx = []
for each in input[0]:
    if sum(each) == 0: word = '<PAD>'
    else:
        word = glove.most_similar(positive=[each.cpu().detach().numpy(),])[0][0]
    index = vocab[word]
    # print(word, ": ", index)
    input_idx.append(index)

print('The original Sentence is:\n{}\n'.format(train[idx]))

pred = np.argmax(model.forward(input).cpu().detach().numpy(), axis=2)[0].astype(float)
label = train_y[idx].flatten()

print('input index is:\n{}\n->with shape {}'.format(input_idx, input.shape))
print('predict is:\n{}\n->with shape {}'.format(pred, pred.shape))
print('label is:\n{}\n->with shape {}'.format(label, label.shape))

The original Sentence is:
 the percentage of lung cancer deaths among the workers at the west <unk> mass. paper factory appears to be the highest for any asbestos workers studied in western industrialized countries he said 


input index is:
[0, 33, 73, 43, 208, 74, 75, 76, 33, 77, 161, 33, 217, 218, 181, 219, 220, 65, 221, 33]
->with shape torch.Size([1, 20, 300])
predict is:
[   0.  190.   43.   33.   74.   25.  109.   33.  190.   49.   33.  190.
 3577.  120.   49.   49.   65.   33.   36.  190.]
->with shape (20,)
label is:
[ 33.  73.  43. 208.  74.  75.  76.  33.  77. 161.  33. 217. 218. 181.
 219. 220.  65. 221.  33. 222.]
->with shape (20,)
CPU times: user 2.2 s, sys: 239 ms, total: 2.44 s
Wall time: 1.27 s


Beam Search

In [None]:
# get the top k most predicted results
def get_topK(predicted, k=1):
    
    # Get the index of the highest k index
    # Since the input is just one sentence, we can use [0] to extract the prediction result
    top_k = np.argsort(predicted)[-k:]

    # return a list of tuple
    # tuple[0]:word_id, tuple[1]:log(p)
    return [(id, predicted[id]) for id in top_k]

def generate_text(model, wv, dataset, next_words, k=1):
    # generate sentence given indexed sentence input (train_X, valid_X, test_X)
    # TODO: choose pred or just ground truth as the first 20 words?
    print('generating texts...')
    emb_size = wv["hello"].shape[0]
    generated = []
    for i in range(len(dataset)):
        if i % 10 == 0:
            print('{} of {} generates.'.format(i, len(dataset)))
        seed_text = dataset[i] # text will take form of (seq length, embedding size)
        seed_candidates = [(seed_text, .0)]
        for _ in range(next_words):
            successives = []
            # if k = 1, len(seed_candidates) will always be 1
            for i in range(len(seed_candidates)):
                seed_text, score = seed_candidates[i]

                seed_input = numpy_to_tensor(np.array([seed_text[-sequence_length:]]))
                predicted = model(seed_input).cpu().detach().numpy()[0][-1] # take the vocab prob of last word as the output

                tuples = get_topK(predicted, k)
                for id, val in tuples:
                    # get the output word
                    if id == 0: output_emb = np.zeros((emb_size,))
                    else:
                        output_emb = wv[index_vocab[id]]
                    # put the word into the sentence input
                    # calcualte the accumulated score by -log(p)
                    successives.append((np.vstack((seed_text,output_emb)), score - val)) 

            # Get the lowest k accumulated scores (highest k accumulated probabilities)
            # Then, make them as the seed_candidate for the next word to predict
            ordered = sorted(successives, key=lambda tup: tup[1])
            seed_candidates = ordered[:k]
        generated.append(seed_candidates[0][0])

    return generated

def generation_bleu(generates, wv, generate_length, references):
    # generates(n_samples, total sequence_length, emb_size)
    # only calcuate the generated part, otherwise score will be the same
    smoothie = SmoothingFunction().method4
    bleus = []
    for i in range(len(generates)):
        candidate = []
        for each in generates[i]:
            if sum(each) == 0: word = '<PAD>'
            else:
                word = wv.most_similar(positive=[each,])[0][0] # very slow, needs improvement
            candidate.append(vocab[word])
        reference = references[i].flatten()
        bleu = sentence_bleu([reference], candidate[-generate_length:], smoothing_function=smoothie)
        bleus.append(bleu)

        if i % 10 == 0:
            print('{} of {}, current mean: {}'.format(i, len(generates), np.mean(bleus)))
    return np.mean(bleus)

# generated = generate_text(model, glove, train_X[idx:idx+1], 20, k=4)


In [None]:
%%time
words = []
for each in generated[0]:
    if sum(each) == 0: word = '<PAD>'
    else:
        word = glove.most_similar(positive=[each,])[0][0]
    words.append(word)

print(' '.join(words))

<PAD> the bank which previously said it was for sale said it has received no offers and that its board of $ million or $ million or $ million or $ million or $ million or $ million of the
CPU times: user 4.31 s, sys: 356 ms, total: 4.66 s
Wall time: 2.42 s


Pure Argmax (Repetitive words)

In [None]:
%%time
generate_length = 20
generated = [index_vocab[i] for i in input_idx]
for g_idx in range(generate_length):
    output = np.argmax(model.forward(input).cpu().detach().numpy(), axis=2)[0]
    gen_text = [index_vocab[each] for each in output]
    generated.append(gen_text[-1])
    sentence = ' '.join(gen_text)
    data_sent, __ = preprocess([sentence], sequence_length, glove)
    input = numpy_to_tensor(data_sent[0:1])

print(' '.join(generated))

<PAD> the percentage of lung cancer deaths among the workers at the west mass. paper factory appears to be the <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD>
CPU times: user 23 ms, sys: 2.96 ms, total: 26 ms
Wall time: 26.3 ms


#### Metric Testing

In [None]:
%%time 
gen_num = 200
generate_length = 20
generates = generate_text(model, glove, train_X[:gen_num], generate_length, k=4)

print('final bleu is: ', generation_bleu(generates, glove, generate_length, train_y))

CPU times: user 1min 6s, sys: 1.31 s, total: 1min 7s
Wall time: 1min 7s


In [None]:
%%time 
generates = generate_text(model, glove, valid_X[:gen_num], generate_length, k=4)

print('final bleu is: ', generation_bleu(generates, glove, generate_length, train_y))

0 of 300, current mean: 0.0
10 of 300, current mean: 0.14789264018854723
20 of 300, current mean: 0.15120775962893512
30 of 300, current mean: 0.14330097103655004
40 of 300, current mean: 0.1410526006390764
50 of 300, current mean: 0.14896077352347178
60 of 300, current mean: 0.1384521413169753
70 of 300, current mean: 0.13929613490233908
80 of 300, current mean: 0.1384781248583928
90 of 300, current mean: 0.13753324989701837
100 of 300, current mean: 0.1400572111092312
110 of 300, current mean: 0.1392895892897581
120 of 300, current mean: 0.13739621476244523
130 of 300, current mean: 0.1326058703500126
140 of 300, current mean: 0.13364301669130613
150 of 300, current mean: 0.13746308850729247
160 of 300, current mean: 0.13620562854418256
170 of 300, current mean: 0.13833006638909742
180 of 300, current mean: 0.13676069313333056
190 of 300, current mean: 0.1379898741210657
200 of 300, current mean: 0.13710617443185913
210 of 300, current mean: 0.13570865395884718
220 of 300, current me

In [None]:
%%time 
generates = generate_text(model, glove, test_X[:gen_num], generate_length, k=4)

print('final bleu is: ', generation_bleu(generates, glove, generate_length, train_y))

0 of 300, current mean: 0.0
10 of 300, current mean: 0.11751849961784878
20 of 300, current mean: 0.12538663665626343
30 of 300, current mean: 0.13291820485759656
40 of 300, current mean: 0.13801671875477103
50 of 300, current mean: 0.1364898668581318
60 of 300, current mean: 0.13783356948121334
70 of 300, current mean: 0.1429065723530485
80 of 300, current mean: 0.14659209554975075
90 of 300, current mean: 0.15096345387588156
100 of 300, current mean: 0.15266055209327156
110 of 300, current mean: 0.14727311055997241
120 of 300, current mean: 0.14668877464796012
130 of 300, current mean: 0.14313333068298362
140 of 300, current mean: 0.14436660420378525
150 of 300, current mean: 0.1462144523915425
160 of 300, current mean: 0.14436104740095076
170 of 300, current mean: 0.14602176452641197
180 of 300, current mean: 0.14640573856705205
190 of 300, current mean: 0.1483496462912163
200 of 300, current mean: 0.14591980627934392
210 of 300, current mean: 0.1443720620640844
220 of 300, current 

# GPT2 Baseline
Implement GPT2 as a language modelling baseline. GPT-3 is not publicly available and too large for practical purposes. BERT needs modification to work for language modelling, due to the fact that it is trained for bidirectional masked language modelling instead.

This section makes use of several tutorials for fine tuning, including:
- https://reyfarhan.com/posts/easy-gpt2-finetuning-huggingface/
- https://mccormickml.com/2019/07/22/BERT-fine-tuning/#4-train-our-classification-model
- https://huggingface.co/transformers/model_doc/gpt2.html#gpt2lmheadmodel (the documentation)
- https://huggingface.co/transformers/custom_datasets.html

In [None]:
# colab does not have transformers by default
!pip install transformers

Collecting transformers
[?25l  Downloading https://files.pythonhosted.org/packages/d5/43/cfe4ee779bbd6a678ac6a97c5a5cdeb03c35f9eaebbb9720b036680f9a2d/transformers-4.6.1-py3-none-any.whl (2.2MB)
[K     |▏                               | 10kB 22.5MB/s eta 0:00:01[K     |▎                               | 20kB 29.5MB/s eta 0:00:01[K     |▍                               | 30kB 23.5MB/s eta 0:00:01[K     |▋                               | 40kB 18.1MB/s eta 0:00:01[K     |▊                               | 51kB 16.2MB/s eta 0:00:01[K     |▉                               | 61kB 18.5MB/s eta 0:00:01[K     |█                               | 71kB 14.2MB/s eta 0:00:01[K     |█▏                              | 81kB 14.9MB/s eta 0:00:01[K     |█▎                              | 92kB 13.5MB/s eta 0:00:01[K     |█▌                              | 102kB 14.1MB/s eta 0:00:01[K     |█▋                              | 112kB 14.1MB/s eta 0:00:01[K     |█▊                              | 

In [None]:
# load　GPT, BERT and support materials from huggingface
# requires pip install transformers
# if in jupyter notebook see here and you get an error mention ipython widgets see here: 
# https://stackoverflow.com/questions/53247985/tqdm-4-28-1-in-jupyter-notebook-intprogress-not-found-please-update-jupyter-an
from transformers import GPT2Tokenizer, GPT2LMHeadModel, top_k_top_p_filtering, Trainer, TrainingArguments
import torchtext

### Data Processing

In [None]:
train, valid, test = torchtext.datasets.PennTreebank(split=('train', 'valid', 'test')) # len(train) = 4.2w
# train, valid, test = torchtext.datasets.WikiText2(split=('train', 'valid', 'test')) # len(train) = 3.6w
train = list(train) # these are originally iterators, the data is so small we can just retrieve all of it at once
valid = list(valid)
test  = list(test)

ptb.train.txt: 5.10MB [00:00, 102MB/s]                    
ptb.valid.txt: 400kB [00:00, 33.1MB/s]                   
ptb.test.txt: 450kB [00:00, 39.6MB/s]                   


In [None]:
# Download the models
# Documentation for GPT: https://huggingface.co/transformers/model_doc/gpt2.html#gpt2lmheadmodel
gpt_tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
gpt_model = GPT2LMHeadModel.from_pretrained('gpt2')

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1042301.0, style=ProgressStyle(descript…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=456318.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1355256.0, style=ProgressStyle(descript…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=665.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=548118077.0, style=ProgressStyle(descri…




We can see that the gpt_tokenizer works differently to ours, splitting up names such as 'rudolph' into 'rud' and 'olph' and words such as nonexecutive and british. Hence our perplexity evaluation will have to be slightly different, using gpt_tokenizer to get the ground truth labels

In [None]:
# Define a dataset class for fine-tuning, it's a generator so we don't have to store the entire dataset in memory
class GPT2Dataset(torch.utils.data.Dataset):
    def __init__(self, txt_list, tokenizer, gpt2_type="gpt2", max_length=40):
        self.tokenizer = tokenizer
        self.input_ids = []
        self.attn_masks = []

        # Encode all the text, padding and truncuating it along with adding attention masks to get the sequence length the same across all samples
        for txt in txt_list:
            encodings_dict = tokenizer.encode_plus('<|startoftext|>'+ txt + '<|endoftext|>', truncation=True, max_length=max_length, padding="max_length")
            self.input_ids.append(encodings_dict['input_ids'])
            self.attn_masks.append(encodings_dict['attention_mask'])

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        # The tutorial use a dictionary format that also stores labels 
        return_dict = {"input_ids": torch.tensor(self.input_ids[idx]),
                       "attention_mask": torch.tensor(self.attn_masks[idx]), 
                       "labels": torch.tensor(self.input_ids[idx])} 
        return return_dict

In [None]:
gpt_tokenizer.pad_token = gpt_tokenizer.eos_token # set the pad token
gpt_sequence_length = 40 # gpt splits up words into smaller tokens, so the sequence length should be longer
train_dataset = GPT2Dataset(train, gpt_tokenizer, max_length=gpt_sequence_length)
val_dataset = GPT2Dataset(valid, gpt_tokenizer, max_length=gpt_sequence_length)
test_dataset = GPT2Dataset(test, gpt_tokenizer, max_length=gpt_sequence_length)
train_dataset[1]

{'attention_mask': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0]),
 'input_ids': tensor([   27,    91,  9688,  1659,  5239,    91,    29, 17748,   260,  1279,
          2954,    29,   399,   812,  1468,   481,  4654,   262,  3096,   355,
           257, 36196,   721,  8827,  3437,   645,    85,    13,   399,   220,
           198, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256]),
 'labels': tensor([   27,    91,  9688,  1659,  5239,    91,    29, 17748,   260,  1279,
          2954,    29,   399,   812,  1468,   481,  4654,   262,  3096,   355,
           257, 36196,   721,  8827,  3437,   645,    85,    13,   399,   220,
           198, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256])}

### Fine Tuning

Do fine tuning of the gpt_model using the hugging face out of the box trainer https://huggingface.co/transformers/custom_datasets.html#fine-tuning-with-trainerfrom 

In [None]:
training_args = TrainingArguments(
    output_dir='gpt_finetuning',     # output directory
    num_train_epochs=1,              # total number of training epochs (1 is enough to get very low perplexity and perplexity increases at 2)
    per_device_train_batch_size=16,  # batch size per device during training
    per_device_eval_batch_size=64,   # batch size for evaluation
    warmup_steps=500,                # number of warmup steps for learning rate scheduler
    weight_decay=0.001,               # strength of weight decay
    logging_dir='gpt_finetuning_logs',            # directory for storing logs
    logging_steps=100,
)

trainer = Trainer(
    model=gpt_model,                         # the instantiated 🤗 Transformers model to be trained
    args=training_args,                  # training arguments, defined above
    train_dataset=train_dataset,         # training dataset
    eval_dataset=val_dataset             # evaluation dataset
)

trainer.train()

Step,Training Loss
100,5.0821
200,2.6236
300,2.4432
400,2.3985
500,2.3168
600,2.2602
700,2.2061
800,2.2011
900,2.2046
1000,2.1565


TrainOutput(global_step=2630, training_loss=2.288104583102034, metrics={'train_runtime': 433.1046, 'train_samples_per_second': 6.072, 'total_flos': 61761965506560.0, 'epoch': 1.0, 'init_mem_cpu_alloc_delta': 2092535808, 'init_mem_gpu_alloc_delta': 511148032, 'init_mem_cpu_peaked_delta': 0, 'init_mem_gpu_peaked_delta': 0, 'train_mem_cpu_alloc_delta': -219136000, 'train_mem_gpu_alloc_delta': 1505119232, 'train_mem_cpu_peaked_delta': 239112192, 'train_mem_gpu_peaked_delta': 1470224384})

### Model Definition
Define a wrapper model that can use GPT2 both for standard next word prediction and language generation

In [None]:
# Build a model wrapper for gpt2 that uses the "past" variable and for language modelling
# TODO: add the options for beam search
class GPTModel(torch.nn.Module):
    def __init__(self, model=None, sequence_length=20):
        super().__init__()
        self.gpt = model.to(device)
        self.tokenizer = gpt_tokenizer
        self.sequence_length = sequence_length
        self.vocab_size = self.tokenizer.vocab_size
    
    # output the logits for the most likely next word at each position in the sentence and optionally the hidden states (used for the Neural ODE) 
    # note input_dataset must be an element taken from a GPT2Dataset class (e.g. train_dataset[0])
    def forward(self, input_dataset, output_hidden_states=False):
        output = self.gpt.forward(input_ids = input_dataset['input_ids'].to(device), 
                                  attention_mask=input_dataset['attention_mask'].to(device),
                                  use_cache=False,
                                  output_hidden_states = output_hidden_states)
        if output_hidden_states:
            return output["hidden_states"]
        return output["logits"]
    
    # take in a sentence and output the predictions as in forward, but as the most likely sentence not logits
    def forward_sentence(self, input_dataset):
        preds = self.forward(input_dataset) # TODO: Extrapolation
        tokens = torch.argmax(preds, dim=-1)
        return self.tokenizer.decode(tokens)
    
    # generate a sentence by sampling the next word from the probability distribution
    # set limit to an integer to generate `limit` number of words instead of ending at a full stop
    def random_gen(self, x, limit=None):
        # initialize variables
        generated = self.tokenizer.encode_plus(x, return_tensors="pt")['input_ids'].to('cuda')
        x_len = len(generated[0])
        next_token = [generated[0][-1]]
        past = None
        raw_output= None
        stop_list = ['.', '?', '!', '<|endoftext|>']
        
        # generate until a "." is generated
        while ((limit is None) or (limit is not None and len(generated[0]) < limit)) and \
        self.tokenizer.decode(next_token[0]) not in stop_list:
            # get output of model, using past if available
            if past is None:
                raw_output = self.gpt(generated, past_key_values=past)
            else:
                raw_output = self.gpt(next_token, past_key_values=past)
            output, past = raw_output['logits'], raw_output['past_key_values']
            next_token_logits = output[:, -1, :]
            
            # sample a token from the top 50 most likely words
            filtered_next_token_logits = top_k_top_p_filtering(next_token_logits, top_k=50, top_p=1.0) # filter to the top 50 tokens
            probs = torch.nn.functional.softmax(filtered_next_token_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            # break loop when the end indicator (stop_list) is reached
            if (self.tokenizer.decode(next_token[0]) in stop_list):
                break
            generated = torch.cat([generated, next_token], dim=-1)

        full_output = self.tokenizer.decode(generated[0])
        gen_output = self.tokenizer.decode(generated[0][x_len:])
      
        return full_output, gen_output
    
    # do beam_search to find the most likely sentence
    def beam_search(self, x, beam=5): 
        input_ids = self.tokenizer.encode_plus(x, return_tensors="pt")['input_ids'].to('cuda')
        generated = self.gpt.generate(input_ids=input_ids, num_beams=beam)
        full_output = self.tokenizer.decode(generated[0], skip_special_tokens=True).strip()
        gen_output = self.tokenizer.decode(generated[0][len(input_ids[0]):], skip_special_tokens=True).strip()
        return full_output, gen_output

In [None]:
modelgpt = GPTModel(model=gpt_model, sequence_length=gpt_sequence_length)
modelgpt.eval()
modelgpt.forward(train_dataset[0])

tensor([[-31.7897, -30.7168, -32.1723,  ..., -39.5315, -39.7644, -31.6574],
        [-63.3438, -60.9143, -61.4364,  ..., -72.3313, -72.0760, -62.4938],
        [-53.8140, -53.4152, -53.4254,  ..., -63.6133, -62.0699, -54.8247],
        ...,
        [-73.2818, -73.2793, -74.6490,  ..., -78.8852, -78.8792, -71.2035],
        [-85.5154, -84.4664, -85.6136,  ..., -92.3424, -91.5167, -83.2108],
        [-67.5895, -67.4203, -67.9004,  ..., -74.4833, -74.5151, -65.2906]],
       device='cuda:0', grad_fn=<MmBackward>)

In [None]:
summary(modelgpt)

Layer (type:depth-idx)                        Param #
GPTModel                                      --
├─GPT2LMHeadModel: 1-1                        --
│    └─GPT2Model: 2-1                         --
│    │    └─Embedding: 3-1                    38,597,376
│    │    └─Embedding: 3-2                    786,432
│    │    └─Dropout: 3-3                      --
│    │    └─ModuleList: 3-4                   85,054,464
│    │    └─LayerNorm: 3-5                    1,536
│    └─Linear: 2-2                            38,597,376
Total params: 163,037,184
Trainable params: 163,037,184
Non-trainable params: 0

View the next word output for a single example

### Evaluation

View the ability to generate without teacher forcing using the random_gen() function

#### Force Teaching

In [None]:
# The formula for calculating perplexity in language models can be found here: https://web.stanford.edu/~jurafsky/slp3/3.pdf (page 8)
# An interesting detail is that the geometric mean of perplexity from each word is used
# if the mask is 0 at index i don't use the value at index i to calculate perplexity
def perplexity_gpt(preds, ground_truth, mask, epsilon=1e-30):
    probs = []
    for i in range(preds.shape[0]):
        if mask[i] != 0:
            probs.append(preds[i, int(ground_truth[i])])
    probs = np.array(probs)
    probs = np.power(1/(probs+epsilon), 1/probs.shape[0]) # normalise before taking the product, to prevent underflowing to 0
    return np.prod(probs).detach().cpu().numpy()

# Can optionally define n_samples=int to limit the number of samples used for perplexity evaluation
def average_perplexity_gpt(model, train, n_samples=None, print_results=False):
    perplexities = []
    n_samples = len(train) if n_samples is None else n_samples
    with torch.no_grad():
        for i in range(n_samples):
            # Compute perplexity for a single sample
            labels = train[i]['input_ids'][1:]
            mask = train[i]['attention_mask'][:-1]
            preds = model.forward(train[i])[:-1] # remove the last prediction as there is no ground truth 
            preds = torch.nn.functional.softmax(preds, dim=-1)
            perplexities.append(perplexity_gpt(preds[6:], labels[6:], mask[6:])) # remove the first 7 tokens that represent "<|startoftext|>"

            if i % 100 == 0 and print_results:
                print("Sentences analysed: {} Average perplexity: {}".format(i, np.mean(perplexities)))
    return np.mean(perplexities)


# straight calculation of BLEU
def average_bleu_gpt(model, train):
    smoothie = SmoothingFunction().method4
    bleus = []
    n_samples = len(train)
    with torch.no_grad():
        for i in range(n_samples):
            mask = train[i]['attention_mask'][:-1][6:].tolist()

            reference_withmask = train[i]['input_ids'][1:][6:].tolist()
            preds = torch.nn.functional.softmax(modelgpt.forward(train[i])[:-1], dim=-1)
            candidate_withmask = np.argmax(preds[6:].cpu(), axis=1).tolist()

            reference = [reference_withmask[i] for i in range(len(mask)) if mask[i] != 0]
            candidate = [candidate_withmask[i] for i in range(len(mask)) if mask[i] != 0]

            bleus.append(sentence_bleu([reference], candidate, smoothing_function=smoothie))
    return np.mean(bleus)

In [None]:
%%time
average_perplexity_gpt(modelgpt, train_dataset, n_samples=None, print_results=False)

In [None]:
average_perplexity_gpt(modelgpt, val_dataset, n_samples=None, print_results=False)

1219.9313

In [None]:
average_perplexity_gpt(modelgpt, test_dataset, n_samples=None, print_results=False)

23.9871

In [None]:
%%time
average_bleu_gpt(modelgpt, train_dataset)

In [None]:
average_bleu_gpt(modelgpt, val_dataset)

0.225191197779219

In [None]:
average_bleu_gpt(modelgpt, test_dataset)

0.22589598064165836

#### Generate Text

##### Demo+Methods

In [None]:
%%time
# Still use the dataset for LSTM
idx = 57
o_sent = train[idx].strip()
print('original sentence:\n\t', o_sent)

gen_len = len(o_sent.split(' '))
preceed_len = 5

if gen_len == preceed_len:
    print('should skip')

input = ' '.join(o_sent.split(' ')[:preceed_len])

print('input sentence:\n\t', input)

full_sent, gen_sent = modelgpt.beam_search(input, beam=5)

print('full sentence:\n\t', full_sent)
print('generated part:\n\t', gen_sent)

smoothie = SmoothingFunction().method4
if len(gen_sent) == 0: pass
else:
    print('The BLEU score is: ', sentence_bleu([o_sent], gen_sent, smoothing_function=smoothie))

original sentence:
	 <unk> is an italian state-owned holding company with interests in the mechanical engineering industry
input sentence:
	 <unk> is an italian state-owned
full sentence:
	 <unk> is an italian state-owned <unk> company based in <unk> <
generated part:
	 <unk> company based in <unk> <
The BLEU score is:  0.052178861809722275
CPU times: user 133 ms, sys: 0 ns, total: 133 ms
Wall time: 132 ms


In [None]:
def gpt_gen_bleu(modelgpt, dataset, beam=5, preceed_len = 5, max_instances = 1000):
    smoothie = SmoothingFunction().method4
    bleus = []
    for idx in range(len(dataset)):
        if idx + 1 > max_instances: break
        o_sent = dataset[idx] # original sent
        gen_len = len(train[idx].split(' '))
        if gen_len == preceed_len: continue
        
        i_sent = ' '.join(train[idx].split(' ')[:preceed_len]) # input sent
        if beam <= 1:
            __, gen_sent = modelgpt.random_gen(i_sent, gen_len)
        else:
            __, gen_sent = modelgpt.beam_search(i_sent, beam=beam)
        if len(gen_sent) == 0: continue

        bleus.append(sentence_bleu([o_sent], gen_sent, smoothing_function=smoothie))

        if idx % 50 == 0:
            print('{} of {}, current mean BLEU: {}'.format(idx, min(len(dataset), max_instances), np.mean(bleus)))
    return np.mean(bleus)

##### Metric Testing

In [None]:
import logging
from logging import DEBUG, INFO, WARNING, ERROR, CRITICAL, NOTSET

logging.disable(level=WARNING)

In [None]:
%%time 
# score of beam search is relatively low
train_bleu = gpt_gen_bleu(modelgpt, train, beam=5)

print('finished ! Final mean BLEU is: ', train_bleu) 

In [None]:
%%time
train_bleu = gpt_gen_bleu(modelgpt, train,  beam=0)

print('finished ! Final mean BLEU is: ', train_bleu)

0 of 2000, current mean BLEU: 0.010071651377979756
50 of 2000, current mean BLEU: 0.11911115400946945
100 of 2000, current mean BLEU: 0.12949902722881348
150 of 2000, current mean BLEU: 0.1361311873990362
200 of 2000, current mean BLEU: 0.13587786469857652
250 of 2000, current mean BLEU: 0.13458132013703467
300 of 2000, current mean BLEU: 0.1331856057372505
350 of 2000, current mean BLEU: 0.13197857953433884
400 of 2000, current mean BLEU: 0.13303948249875258
450 of 2000, current mean BLEU: 0.13423432690039477
500 of 2000, current mean BLEU: 0.1353561917592508


KeyboardInterrupt: ignored

In [None]:
valid_bleu = gpt_gen_bleu(modelgpt, valid, beam=0)

print('finished ! Final mean BLEU is: ', valid_bleu)

0 of 2000, current mean BLEU: 0.08904865917278683
50 of 2000, current mean BLEU: 0.09975866581086114
100 of 2000, current mean BLEU: 0.09040492859127659
150 of 2000, current mean BLEU: 0.09374448971734457
200 of 2000, current mean BLEU: 0.09725366946714924
250 of 2000, current mean BLEU: 0.09974675554953126
300 of 2000, current mean BLEU: 0.09902339383331599
350 of 2000, current mean BLEU: 0.10453152342912325
400 of 2000, current mean BLEU: 0.10542553255724137
450 of 2000, current mean BLEU: 0.10555352520599964
500 of 2000, current mean BLEU: 0.10545958585727615
550 of 2000, current mean BLEU: 0.10507760399240752
600 of 2000, current mean BLEU: 0.10664569754859836
650 of 2000, current mean BLEU: 0.1054687718329927
700 of 2000, current mean BLEU: 0.10516710104715085
750 of 2000, current mean BLEU: 0.10440869068560814
800 of 2000, current mean BLEU: 0.10312904842898596
850 of 2000, current mean BLEU: 0.10337492575789471
900 of 2000, current mean BLEU: 0.10447467954922028
950 of 2000, cur

KeyboardInterrupt: ignored

In [None]:
test_bleu = gpt_gen_bleu(modelgpt, test, beam=0)

print('finished ! Final mean BLEU is: ', test_bleu)

0 of 1000, current mean BLEU: 0.056823540398864544
50 of 1000, current mean BLEU: 0.11283436320085767
100 of 1000, current mean BLEU: 0.11015643863772863
150 of 1000, current mean BLEU: 0.11309885491713392
200 of 1000, current mean BLEU: 0.1099456656544335
250 of 1000, current mean BLEU: 0.10457915745048403
300 of 1000, current mean BLEU: 0.10244082494421641
350 of 1000, current mean BLEU: 0.10281257821714644
400 of 1000, current mean BLEU: 0.10129830353050887
450 of 1000, current mean BLEU: 0.10214188575597626
500 of 1000, current mean BLEU: 0.10464878356218418
550 of 1000, current mean BLEU: 0.10308959588325638
600 of 1000, current mean BLEU: 0.10389376244884468
650 of 1000, current mean BLEU: 0.10298833875547865
700 of 1000, current mean BLEU: 0.10214513158340849
750 of 1000, current mean BLEU: 0.10174657798372658


ZeroDivisionError: ignored

# NeuralDE-LSTM

In [None]:
class ODELSTM(torch.nn.Module):
    def __init__(self, vocab_size, input_size=100, layer_size=100, dropout=0):
        super().__init__()
        self.LSTM = torch.nn.LSTM(input_size, layer_size, 1, bidirectional=False, dropout=dropout)
        self.linear = torch.nn.Linear(layer_size, vocab_size)
        self.softmax = torch.nn.Softmax(dim=-1)

        # Define the derivative function
        f = torch.nn.Sequential(
            torch.nn.Linear(layer_size, layer_size),
            torch.nn.ReLU(),
            torch.nn.Linear(layer_size, layer_size),
        )

        self.node = NeuralDE(f, sensitivity='adjoint', solver='dopri5')
        self.timesteps = torch.arange(0, 20, 1, device=device).float() # define the number of output items of the Neural ODE
    
    def forward(self, x):
        # convert words to their vectors here
        senquence_output , hidden_state = self.LSTM(x)
        output = self.node.trajectory(senquence_output, self.timesteps)
        
        pred = self.linear(senquence_output)
        return pred
    
    # wrapper function that forward propagates, applies softmax and converts to numpy 
    def predict(self, x):
        preds = self.forward(x)
        preds = self.softmax(preds).detach().cpu().numpy()
        return preds

In [None]:
%%time

odeLSTM = ODELSTM(vocab_size, input_size=300, layer_size=300)
odeLSTM.to(device)
odeLSTM.eval()

# unit test to check that forward propagation works
data = numpy_to_tensor(train_X[1:2])
print('shape of final output is: ', odeLSTM.forward(data).shape)

shape of final output is:  torch.Size([1, 20, 10001])
CPU times: user 93.3 ms, sys: 0 ns, total: 93.3 ms
Wall time: 144 ms


In [None]:
del data
torch.cuda.empty_cache()

summary(odeLSTM)

Layer (type:depth-idx)                   Param #
ODELSTM                                  --
├─LSTM: 1-1                              722,400
├─Linear: 1-2                            3,010,301
├─Softmax: 1-3                           --
├─NeuralDE: 1-4                          --
│    └─DEFunc: 2-1                       --
│    │    └─Sequential: 3-1              180,600
│    └─Adjoint: 2-2                      --
Total params: 3,913,301
Trainable params: 3,913,301
Non-trainable params: 0

## Training

In [None]:
# training the model
def train_odelstm(model, train_X, train_y, epochs=10, learn_rate=0.01, weight_decay=0.001):
    # Prepare data
    X = numpy_to_tensor(train_X)
    y = numpy_to_tensor(train_y).long()[:, :, 0]
    n_samples = X.shape[0]
    
    # Define loss and optimizer
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learn_rate, weight_decay=weight_decay)

    # Ensure this runs on gpu
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    
    for epoch in range(epochs):
        model.train()
        start_t2 = time()
        
        # shuffle the data
        new_indices = torch.randperm(n_samples)
        X = X[new_indices, :, :] 
        y = y[new_indices, :]
        
        for idx in range(n_samples):
            # forward + backward + optimize
            optimizer.zero_grad()
            outputs = model(X[idx:idx+1]) 
            outputs = torch.swapaxes(outputs, 1, 2) # cross entropy expects a tensor of (n_samples, n_outputs, sequence_length)
            loss = criterion(outputs, y[idx:idx+1])
            loss.backward()
            optimizer.step()

            if idx % 500 == 0:
            # print loss every 500 steps (perplexity takes too much time)
                print("Epoch {}: index {} of {} - Current loss: {:.2f}, Time Taken: {:.2f} secs"
                .format(epoch, idx, n_samples, loss.item(), time() - start_t2))
    
    del X
    del y
    torch.cuda.empty_cache()
    print('Finished Training!')
    return model

In [None]:
%%time
odeLSTM = train_odelstm(odeLSTM, train_X, train_y, epochs=1, learn_rate=1e-3, weight_decay=1e-5)

Epoch 0: index 0 of 42068 - Current loss: 9.23, Time Taken: 0.07 secs
Epoch 0: index 500 of 42068 - Current loss: 6.62, Time Taken: 24.09 secs
Epoch 0: index 1000 of 42068 - Current loss: 6.45, Time Taken: 47.97 secs
Epoch 0: index 1500 of 42068 - Current loss: 6.89, Time Taken: 71.75 secs
Epoch 0: index 2000 of 42068 - Current loss: 7.48, Time Taken: 95.50 secs
Epoch 0: index 2500 of 42068 - Current loss: 5.52, Time Taken: 119.37 secs
Epoch 0: index 3000 of 42068 - Current loss: 2.62, Time Taken: 143.16 secs
Epoch 0: index 3500 of 42068 - Current loss: 5.36, Time Taken: 166.92 secs
Epoch 0: index 4000 of 42068 - Current loss: 5.69, Time Taken: 190.74 secs
Epoch 0: index 4500 of 42068 - Current loss: 6.42, Time Taken: 214.61 secs
Epoch 0: index 5000 of 42068 - Current loss: 6.83, Time Taken: 238.38 secs
Epoch 0: index 5500 of 42068 - Current loss: 6.12, Time Taken: 262.04 secs
Epoch 0: index 6000 of 42068 - Current loss: 6.13, Time Taken: 285.64 secs
Epoch 0: index 6500 of 42068 - Curr

## Evaluation

### Force Teaching

In [None]:
%%time
odeLSTM.eval()
print('Train perplexity is ', average_perplexity(odeLSTM, train_X, train_y))

Train perplexity is  232.5228
CPU times: user 30min 4s, sys: 5.62 s, total: 30min 10s
Wall time: 30min 6s


In [None]:
average_perplexity(odeLSTM, valid_X, valid_y)

276.98474

In [None]:
average_perplexity(odeLSTM, test_X, test_y)

253.82364

In [None]:
%%time
average_bleu(odeLSTM, train_X, train_y)

CPU times: user 30min 27s, sys: 4.86 s, total: 30min 32s
Wall time: 30min 29s


0.2962042315776427

In [None]:
average_bleu(odeLSTM, valid_X, valid_y)

0.22759942516244355

In [None]:
average_bleu(odeLSTM, test_X, test_y)

0.22526977732762332

### Generation

In [None]:
%%time 
gen_num = 100 
generate_length = 20
generates = generate_text(odeLSTM, glove, train_X[:gen_num], generate_length, k=4)

print('final bleu is: ', generation_bleu(generates, glove, generate_length, train_y))

generating texts...
0 of 100 generates.
10 of 100 generates.
20 of 100 generates.
30 of 100 generates.
40 of 100 generates.
50 of 100 generates.
60 of 100 generates.
70 of 100 generates.
80 of 100 generates.
90 of 100 generates.
0 of 100, current mean: 0.0
10 of 100, current mean: 0.08578610427913255
20 of 100, current mean: 0.12926636671074815
30 of 100, current mean: 0.12609363790669073
40 of 100, current mean: 0.12362938795424462
50 of 100, current mean: 0.11152848683633648
60 of 100, current mean: 0.11669757477666057
70 of 100, current mean: 0.11780536079504665
80 of 100, current mean: 0.11724413601320274
90 of 100, current mean: 0.11979319907791623
final bleu is:  0.12317864915106946
CPU times: user 10min 34s, sys: 25.3 s, total: 10min 59s
Wall time: 8min 25s


In [None]:
generates = generate_text(odeLSTM, glove, valid_X[:gen_num], generate_length, k=4)

print('final bleu is: ', generation_bleu(generates, glove, generate_length, valid_y))

generating texts...
0 of 100 generates.
10 of 100 generates.
20 of 100 generates.
30 of 100 generates.
40 of 100 generates.
50 of 100 generates.
60 of 100 generates.
70 of 100 generates.
80 of 100 generates.
90 of 100 generates.
0 of 100, current mean: 0.0
10 of 100, current mean: 0.16549101212311698
20 of 100, current mean: 0.13599679350414534
30 of 100, current mean: 0.12256709066637131
40 of 100, current mean: 0.12455909851927117
50 of 100, current mean: 0.1235535708523886
60 of 100, current mean: 0.11933256071651413
70 of 100, current mean: 0.12080518572858785
80 of 100, current mean: 0.1152958925064089
90 of 100, current mean: 0.11562571341613051
final bleu is:  0.11939014017448064


In [None]:
generates = generate_text(odeLSTM, glove, test_X[:gen_num], generate_length, k=4)

print('final bleu is: ', generation_bleu(generates, glove, generate_length, test_y))

generating texts...
0 of 100 generates.
10 of 100 generates.
20 of 100 generates.
30 of 100 generates.
40 of 100 generates.
50 of 100 generates.
60 of 100 generates.
70 of 100 generates.
80 of 100 generates.
90 of 100 generates.
0 of 100, current mean: 0.0
10 of 100, current mean: 0.0920419121273297
20 of 100, current mean: 0.09752345541111389
30 of 100, current mean: 0.09761440818932049
40 of 100, current mean: 0.08267665208368496
50 of 100, current mean: 0.08207742715461866
60 of 100, current mean: 0.08707285201208335
70 of 100, current mean: 0.08858447976017784
80 of 100, current mean: 0.08715209913284197
90 of 100, current mean: 0.08832279252742575
final bleu is:  0.09145840249831404


# NeuralDE-GPT2

In [None]:
# Build a wrapper for gpt that takes a torch.util.data.TensorDataset as input, needed for pytorch lightning
class GPTModelWrapper(torch.nn.Module):
    def __init__(self, model=None, sequence_length=40):
        super().__init__()
        self.gpt = model.to(device)
        self.tokenizer = gpt_tokenizer
        self.sequence_length = sequence_length
        self.vocab_size = self.tokenizer.vocab_size
    
    # output the hidden states for the entire sequence used for the Neural ODE
    def forward(self, input_dataset):
        output = self.gpt.forward(input_ids = input_dataset[0].to(device), 
                                  attention_mask=input_dataset[1].to(device),
                                  use_cache=False,
                                  output_hidden_states=True)
        return output["hidden_states"]


# Defines an ODE that uses a GPT to get a representation for the sentence
class ODEGPT(pl.LightningModule):
    def __init__(self, modelgpt, sequence_length=40):
        super().__init__()
        layer_size = 768 # the size of gpt's hidden state
        self.loss = torch.nn.CrossEntropyLoss()
        
        # Freeze the GPT model's parameters to save training time
        self.modelgpt = modelgpt
        for param in self.modelgpt.parameters():
            param.requires_grad = False
        
        # Define the derivative function
        self.f = torch.nn.Sequential(
            torch.nn.Linear(layer_size, layer_size),
            torch.nn.ReLU(),
            torch.nn.Linear(layer_size, layer_size),
        )
        
        # Define the model itself
        self.node = NeuralDE(self.f, sensitivity='adjoint', solver='dopri5').to(device)
        self.linear = torch.nn.Linear(layer_size, self.modelgpt.vocab_size).to(device)
        self.softmax = torch.nn.Softmax(dim=-1)
        self.sequence_length = sequence_length
        self.timesteps = torch.arange(0, 40, 1, device=device).float() # define the number of output items of the Neural ODE
    
    # take in a single sample and feed forward, giving the logits as output
    # note x must be an element of a GPT2Dataset class so that it can be fed to the GPT model
    def forward(self, x):
        # at the moment this feeds the entire sequence to LSTM and asks Neural ODE to reproduce it
        # TODO: switch to feeding half the sequence and asking NeuralODE to extrapolate
        x = [x[0][:, :20], x[1][:, :20], x[2][:, :20]] 
        hidden_states = self.modelgpt(x) 
        # hidden_states = self.modelgpt(x[:20]) 
        attention_mask = x[1].to(device)[0, :] # batching makes x[1] have a shape of (batch_size, features), we use batches of 1 so take the first
        
        # use the output of GPT2's 12th decoder, "BERT Rediscovers the Classical NLP Pipeline" has shown transformers' later layers represent high level meaning, which is 
        # what we want to input to the Neural ODE
        # TODO: Perhaps consider the above paper's method of having a weighted sum of layers representations, with trainable weights
        final_hidden = hidden_states[12] 
        final_hidden = final_hidden[0, attention_mask, :][-1, :] # Take the output of the last sequence item that isn't a pad token
        # feed to neural ode
        sequence_outputs = self.node.trajectory(final_hidden, self.timesteps) # output is of shape (sequence_length, gpt_hidden_layer_size)
        
        # Get final output
        pred = self.linear(sequence_outputs)
        return pred
    
    # compute the loss on a batch, required by pytorch lightning
    # note the batch must be an element of a tf.utils.data.TensorDataset, this function is only meant to be used with pytorch_lightning's training loop
    def training_step(self, batch, batch_idx):
        labels = batch[2][0, 1:].to(device) # shift the input 1 step ahead to get the next word labels
        preds = self.forward(batch)[:-1, :] # remove the prediction for the last token as there is no label
        loss = self.loss(preds, labels) # crossentropy loss expects preds to be of size (batch, n_classes) so it handles our sequence model use case
        return loss
    
    # configure the optimizer for pytorch lightning
    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.00001, betas=(0.95, 0.999)) # low learning rate and momentum since this is stochastic optimisation
    
    # wrapper function that forward propagates, applies softmax and converts to numpy 
    def predict(self, x):
        preds = self.forward(x)
        preds = self.softmax(preds).detach().cpu().numpy()
        return preds

In [None]:
gptmodel_wrapper = GPTModelWrapper(gpt_model)
odemodel = ODEGPT(gptmodel_wrapper)

Errors: setting num_workers = 1 causes DataLoader to hang. Setting num_workers = 0 causes a random CUDA error that doesn't happen in the above functions

### Training
Use the pytorch lightning's training loop to speed up training. 

Important documentation
- https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#training
- https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html#trainer-class-api

In [None]:
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

In [None]:
# Load the data into a new dataset, pytorch_lightning doesn't like our custom dataset
full_dataset = train_dataset[:]
train_tensor_dataset = torch.utils.data.TensorDataset(full_dataset['input_ids'], full_dataset['attention_mask'], full_dataset['labels'])
train_dataloader = torch.utils.data.DataLoader(train_tensor_dataset, batch_size=1, shuffle=True,
                             num_workers=2, pin_memory=True)

# Test run to check for errors
for batch in train_dataloader:
    results = odemodel.forward(batch)
    print(results[:768], torch.sum(results[:768]))
    print(results[768:], torch.sum(results[768:]))
    #print(odemodel.forward(batch))
    print(odemodel.training_step(batch, 0))
    break

tensor([[-5.5348e+00,  5.9912e+00, -3.9831e+00,  ...,  1.8607e-01,
         -1.4220e+00, -2.9475e+00],
        [-7.1594e+00,  6.9915e+00, -5.7088e+00,  ..., -4.7523e-01,
         -5.5087e-02, -1.2406e+00],
        [-8.5823e+00,  8.1473e+00, -7.3720e+00,  ..., -1.2025e+00,
          1.1492e+00,  2.8073e-01],
        ...,
        [-1.1781e+04, -2.3058e+03, -2.3049e+04,  ...,  7.8360e+03,
          3.8049e+03, -2.6058e+03],
        [-1.5199e+04, -3.5479e+03, -2.9103e+04,  ...,  9.9815e+03,
          4.3686e+03, -3.7223e+03],
        [-1.9538e+04, -5.1800e+03, -3.6753e+04,  ...,  1.2736e+04,
          4.9002e+03, -5.0784e+03]], device='cuda:0', grad_fn=<SliceBackward>) tensor(8521420., device='cuda:0', grad_fn=<SumBackward0>)
tensor([], device='cuda:0', size=(0, 50257), grad_fn=<SliceBackward>) tensor(0., device='cuda:0', grad_fn=<SumBackward0>)
tensor(6626.0190, device='cuda:0', grad_fn=<NllLossBackward>)


In [None]:
# Train the model
trainer = pl.Trainer(max_epochs=1, gpus=1, progress_bar_refresh_rate=10)
trainer.fit(odemodel, train_dataloader)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name     | Type             | Params
----------------------------------------------
0 | loss     | CrossEntropyLoss | 0     
1 | modelgpt | GPTModelWrapper  | 124 M 
2 | f        | Sequential       | 1.2 M 
3 | node     | NeuralDE         | 1.2 M 
4 | linear   | Linear           | 38.6 M
5 | softmax  | Softmax          | 0     
----------------------------------------------
39.8 M    Trainable params
124 M     Non-trainable params
164 M     Total params
657.074   Total estimated model params size (MB)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…




### Evaluation

In [None]:
# Define functions to calculate perplexity for a single sentence: see the metric definition here https://web.stanford.edu/~jurafsky/slp3/3.pdf 
# We use teacher forcing (feeding the ground_truth label for sequence i to get pred for sequence i+1) to get the predictions
def perplexity_ode(preds, ground_truth, mask, epsilon=1e-30):
    probs = []
    for i in range(preds.shape[0]):
        if mask[i] != 0:
            probs.append(preds[i, int(ground_truth[i])])
    probs = np.array(probs)
    probs = np.power(1/(probs+epsilon), 1/probs.shape[0]) # normalise before taking the product, to prevent underflowing to 0
    return np.prod(probs)

# Calculate overall perplexity for a dataset
def average_perplexity_ode(model, train_dataloader, print_results=False, max_items=5000):
    perplexities = []
    for i, batch in enumerate(train_dataloader):
        preds = model.predict(batch)[:-1, :]
        mask = batch[1][0][:-1]
        labels = batch[2][0, 1:].numpy() # shift the input 1 step ahead to get the next word labels
        perplexities.append(perplexity_ode(preds, labels, mask))
        if print_results and i % 100 == 0:
            print('{} of {}, current mean perplex: {:.2f}'.format(i, min(max_items, len(train_dataloader)), (np.mean(perplexities))))
        if i == max_items:
            break
    return np.mean(perplexities)


# straight calculation of BLEU
def average_bleu_ode(model, train_dataloader, print_results=False, max_items=2000):
    smoothie = SmoothingFunction().method4
    bleus = []
    for i, batch in enumerate(train_dataloader):
        mask = batch[1][0][:-1].tolist()

        reference_withmask = batch[2][0, 1:].numpy().tolist()
        preds = model.predict(batch)[:-1, :]
        candidate_withmask = np.argmax(preds, axis=1).tolist()

        reference = [reference_withmask[i] for i in range(len(mask)) if mask[i] != 0]
        candidate = [candidate_withmask[i] for i in range(len(mask)) if mask[i] != 0]

        bleus.append(sentence_bleu([reference], candidate, smoothing_function=smoothie))

        if print_results and i % 100 == 0:
            print('{} of {}, current mean bleu: {:.4f}'.format(i, min(max_items, len(train_dataloader)), (np.mean(bleus))))
        if i == max_items:
            break
    return np.mean(bleus)

#### Perplexity

In [None]:
# train perplexity
odemodel.eval()
odemodel.to(device)
eval_dataloader = torch.utils.data.DataLoader(train_tensor_dataset, batch_size=1, shuffle=True, num_workers=0, pin_memory=True) 
# faster to startup with num_workers=0

average_perplexity_ode(odemodel, eval_dataloader, print_results=True)

0 of 5000, current mean perplex: 84.00
100 of 5000, current mean perplex: 182.91
200 of 5000, current mean perplex: 179.69
300 of 5000, current mean perplex: 180.56
400 of 5000, current mean perplex: 187.96
500 of 5000, current mean perplex: 189.69
600 of 5000, current mean perplex: 198.57
700 of 5000, current mean perplex: 193.64
800 of 5000, current mean perplex: 193.01
900 of 5000, current mean perplex: 193.82
1000 of 5000, current mean perplex: 194.00
1100 of 5000, current mean perplex: 235.33
1200 of 5000, current mean perplex: 233.27
1300 of 5000, current mean perplex: 230.57
1400 of 5000, current mean perplex: 227.89
1500 of 5000, current mean perplex: 224.78
1600 of 5000, current mean perplex: 223.38
1700 of 5000, current mean perplex: 223.52
1800 of 5000, current mean perplex: 220.96
1900 of 5000, current mean perplex: 218.96
2000 of 5000, current mean perplex: 217.23
2100 of 5000, current mean perplex: 216.61
2200 of 5000, current mean perplex: 216.13
2300 of 5000, current me

203.30031

In [None]:
# validate perplexity
full_val_dataset = val_dataset[:]
val_tensor_dataset = torch.utils.data.TensorDataset(full_val_dataset['input_ids'], full_val_dataset['attention_mask'], full_val_dataset['labels'])
val_dataloader = torch.utils.data.DataLoader(val_tensor_dataset, batch_size=1, shuffle=True, num_workers=0, pin_memory=True) 
# faster to startup with num_workers=0

average_perplexity_ode(odemodel, val_dataloader, print_results=True)

0 of 3370, current mean perplex: 390.71
100 of 3370, current mean perplex: 208.44
200 of 3370, current mean perplex: 191.30
300 of 3370, current mean perplex: 199.88
400 of 3370, current mean perplex: 210.70
500 of 3370, current mean perplex: 210.94
600 of 3370, current mean perplex: 210.13
700 of 3370, current mean perplex: 208.18
800 of 3370, current mean perplex: 203.43
900 of 3370, current mean perplex: 205.93
1000 of 3370, current mean perplex: 207.75
1100 of 3370, current mean perplex: 207.27
1200 of 3370, current mean perplex: 207.75
1300 of 3370, current mean perplex: 206.42
1400 of 3370, current mean perplex: 210.75
1500 of 3370, current mean perplex: 212.15
1600 of 3370, current mean perplex: 213.81
1700 of 3370, current mean perplex: 212.41
1800 of 3370, current mean perplex: 210.89
1900 of 3370, current mean perplex: 209.78
2000 of 3370, current mean perplex: 209.16
2100 of 3370, current mean perplex: 208.56
2200 of 3370, current mean perplex: 207.54
2300 of 3370, current m

209.62929

In [None]:
# test perplexity
full_test_dataset = test_dataset[:]
test_tensor_dataset = torch.utils.data.TensorDataset(full_test_dataset['input_ids'], full_test_dataset['attention_mask'], full_test_dataset['labels'])
test_dataloader = torch.utils.data.DataLoader(test_tensor_dataset, batch_size=1, shuffle=True, num_workers=0, pin_memory=True) 
# faster to startup with num_workers=0

average_perplexity_ode(odemodel, test_dataloader, print_results=True)

0 of 3761, current mean perplex: 135.79
100 of 3761, current mean perplex: 184.10
200 of 3761, current mean perplex: 189.74
300 of 3761, current mean perplex: 189.17
400 of 3761, current mean perplex: 192.95
500 of 3761, current mean perplex: 192.78
600 of 3761, current mean perplex: 190.50
700 of 3761, current mean perplex: 188.73
800 of 3761, current mean perplex: 185.62
900 of 3761, current mean perplex: 184.59
1000 of 3761, current mean perplex: 185.26
1100 of 3761, current mean perplex: 188.54
1200 of 3761, current mean perplex: 188.44
1300 of 3761, current mean perplex: 190.02
1400 of 3761, current mean perplex: 189.19
1500 of 3761, current mean perplex: 191.13
1600 of 3761, current mean perplex: 190.14
1700 of 3761, current mean perplex: 189.32
1800 of 3761, current mean perplex: 189.83
1900 of 3761, current mean perplex: 189.20
2000 of 3761, current mean perplex: 188.24
2100 of 3761, current mean perplex: 188.65
2200 of 3761, current mean perplex: 187.70
2300 of 3761, current m

182.5745

#### BlEU

In [None]:
average_bleu_ode(odemodel, eval_dataloader, print_results=True)

0 of 2000, current mean bleu: 0.1906
100 of 2000, current mean bleu: 0.1742
200 of 2000, current mean bleu: 0.1761
300 of 2000, current mean bleu: 0.1796
400 of 2000, current mean bleu: 0.1801
500 of 2000, current mean bleu: 0.1814
600 of 2000, current mean bleu: 0.1812
700 of 2000, current mean bleu: 0.1794
800 of 2000, current mean bleu: 0.1795
900 of 2000, current mean bleu: 0.1797
1000 of 2000, current mean bleu: 0.1795
1100 of 2000, current mean bleu: 0.1808
1200 of 2000, current mean bleu: 0.1804
1300 of 2000, current mean bleu: 0.1802
1400 of 2000, current mean bleu: 0.1796
1500 of 2000, current mean bleu: 0.1794
1600 of 2000, current mean bleu: 0.1790
1700 of 2000, current mean bleu: 0.1785
1800 of 2000, current mean bleu: 0.1786
1900 of 2000, current mean bleu: 0.1784
2000 of 2000, current mean bleu: 0.1783


0.17831192756554887

In [None]:
average_bleu_ode(odemodel, val_dataloader, print_results=True)

0 of 2000, current mean bleu: 0.1471
100 of 2000, current mean bleu: 0.1774
200 of 2000, current mean bleu: 0.1784
300 of 2000, current mean bleu: 0.1750
400 of 2000, current mean bleu: 0.1754
500 of 2000, current mean bleu: 0.1751
600 of 2000, current mean bleu: 0.1753
700 of 2000, current mean bleu: 0.1764
800 of 2000, current mean bleu: 0.1782
900 of 2000, current mean bleu: 0.1787
1000 of 2000, current mean bleu: 0.1783
1100 of 2000, current mean bleu: 0.1783
1200 of 2000, current mean bleu: 0.1787
1300 of 2000, current mean bleu: 0.1787
1400 of 2000, current mean bleu: 0.1786
1500 of 2000, current mean bleu: 0.1777
1600 of 2000, current mean bleu: 0.1780
1700 of 2000, current mean bleu: 0.1779
1800 of 2000, current mean bleu: 0.1782
1900 of 2000, current mean bleu: 0.1777
2000 of 2000, current mean bleu: 0.1776


0.17756214980956334

In [None]:
average_bleu_ode(odemodel, test_dataloader, print_results=True)

0 of 2000, current mean bleu: 0.1595
100 of 2000, current mean bleu: 0.1726
200 of 2000, current mean bleu: 0.1733
300 of 2000, current mean bleu: 0.1733
400 of 2000, current mean bleu: 0.1750
500 of 2000, current mean bleu: 0.1760
600 of 2000, current mean bleu: 0.1782
700 of 2000, current mean bleu: 0.1800
800 of 2000, current mean bleu: 0.1795
900 of 2000, current mean bleu: 0.1793
1000 of 2000, current mean bleu: 0.1791
1100 of 2000, current mean bleu: 0.1791
1200 of 2000, current mean bleu: 0.1793
1300 of 2000, current mean bleu: 0.1797
1400 of 2000, current mean bleu: 0.1804
1500 of 2000, current mean bleu: 0.1805
1600 of 2000, current mean bleu: 0.1801
1700 of 2000, current mean bleu: 0.1799
1800 of 2000, current mean bleu: 0.1799
1900 of 2000, current mean bleu: 0.1799
2000 of 2000, current mean bleu: 0.1800


0.18002267113973042

# Augmented-NeuralDE-GPT2

In [None]:
# Build a wrapper for gpt that takes a torch.util.data.TensorDataset as input, needed for pytorch lightning
class GPTModelWrapper(torch.nn.Module):
    def __init__(self, model=None, sequence_length=40):
        super().__init__()
        self.gpt = model.to(device)
        self.tokenizer = gpt_tokenizer
        self.sequence_length = sequence_length
        self.vocab_size = self.tokenizer.vocab_size
    
    # output the hidden states for the entire sequence used for the Neural ODE
    def forward(self, input_dataset):
        output = self.gpt.forward(input_ids = input_dataset[0].to(device), 
                                  attention_mask=input_dataset[1].to(device),
                                  use_cache=False,
                                  output_hidden_states=True)
        return output["hidden_states"]

# Defines an ODE that uses a GPT to get a representation for the sentence
class ODEGPT_AugmentedVersion(pl.LightningModule):
    def __init__(self, modelgpt, sequence_length=40, extrapolation=False, zero_initialisation=False):
        super().__init__()
        self.layer_size = 768 # the size of gpt's hidden state
        self.loss = torch.nn.CrossEntropyLoss()
        self.extrapolation= extrapolation
        self.zero_initialisation = zero_initialisation
        
        # Freeze the GPT model's parameters to save training time
        self.modelgpt = modelgpt
        for param in self.modelgpt.parameters():
            param.requires_grad = False
        
        # Define the derivative function
        # self.f = ODEMemory()
        self.f = torch.nn.Sequential(
            torch.nn.Linear(self.layer_size*2, self.layer_size*2),
            torch.nn.ReLU(),
            torch.nn.Linear(self.layer_size*2, self.layer_size*2),
        )
        
        # Define the model itself
        self.node = NeuralDE(self.f, sensitivity='adjoint', solver='dopri5').to(device)
        self.linear = torch.nn.Linear(self.layer_size*2, self.modelgpt.vocab_size).to(device)
        self.softmax = torch.nn.Softmax(dim=-1)
        self.sequence_length = sequence_length
        self.timesteps = torch.arange(0, 40, 1, device=device).float() # define the number of output items of the Neural ODE
    
    # take in a single sample and feed forward, giving the logits as output
    # note x must be an element of a GPT2Dataset class so that it can be fed to the GPT model
    def forward(self, x):
        if self.extrapolation:
            # Feed GPT-2 half the sequence to get a hidden state representing it then train NeuralODE to reconstruct and extrapolate
            x = [x[0][:, :20], x[1][:, :20], x[2][:, :20]]
        hidden_states = self.modelgpt(x) 
        attention_mask = x[1].to(device)[0, :] # batching makes x[1] have a shape of (batch_size, features), we use batches of 1 so take the first
        
        # use the output of GPT2's 12th decoder, "BERT Rediscovers the Classical NLP Pipeline" has shown transformers' later layers represent high level meaning, which is 
        # what we want to input to the Neural ODE
        final_hidden = hidden_states[12] 
        hidden_state = final_hidden[0, attention_mask, :][-1, :] # Take the output of the last sequence item that isn't a pad token as an overall sentence representation
        input_state  = final_hidden[0, attention_mask, :][0,  :] # Take the output for the first word as the input to the neural ODE that should help it predict the next (2nd) word

        if self.zero_initialisation: # implement augmented ODE as in the original paper, with the added dimensions being 0
            input_state = torch.zeros(input_state.shape).to(device)
            
        gru_input = torch.cat([hidden_state, input_state], dim=0).reshape(2*self.layer_size)

        # feed to neural ode
        sequence_outputs = self.node.trajectory(gru_input, self.timesteps) # output is of shape (sequence_length, gpt_hidden_layer_size)
        
        # Get final output
        pred = self.linear(sequence_outputs)
        return pred
    
    # compute the loss on a batch, required by pytorch lightning
    # note the batch must be an element of a tf.utils.data.TensorDataset, this function is only meant to be used with pytorch_lightning's training loop
    def training_step(self, batch, batch_idx):
        labels = batch[2][0, 1:].to(device) # shift the input 1 step ahead to get the next word labels
        preds = self.forward(batch)[:-1, :] # remove the prediction for the last token as there is no label
        loss = self.loss(preds, labels) # crossentropy loss expects preds to be of size (batch, n_classes) so it handles our sequence model use case
        return loss
    
    # configure the optimizer for pytorch lightning
    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.00001, betas=(0.95, 0.999)) # low learning rate and momentum since this is stochastic optimisation
    
    # wrapper function that forward propagates, applies softmax and converts to numpy 
    def predict(self, x):
        preds = self.forward(x)
        preds = self.softmax(preds).detach().cpu().numpy()
        return preds

In [None]:
# gptmodel_wrapper = GPTModelWrapper(gpt_model)
# odemodel = ODEGPT_AugmentedVersion(gptmodel_wrapper)

filepath = '/content/drive/MyDrive/Colab Notebooks/COMP5329/Ass/Assignment 2/languageODE-augf-0initialisation'

odemodel = ODEGPT_AugmentedVersion(gptmodel_wrapper, zero_initialisation=True)
odemodel.load_state_dict(torch.load(filepath))
odemodel.eval()

In [None]:
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

# Load the data into a new dataset, pytorch_lightning doesn't like our custom dataset
full_dataset = train_dataset[:]
train_tensor_dataset = torch.utils.data.TensorDataset(full_dataset['input_ids'], full_dataset['attention_mask'], full_dataset['labels'])
train_dataloader = torch.utils.data.DataLoader(train_tensor_dataset, batch_size=1, shuffle=True,
                             num_workers=2, pin_memory=True)

# Test run to check for errors
for batch in train_dataloader:
    results = odemodel.forward(batch)
    print(results[:768], torch.sum(results[:768]))
    print(results[768:], torch.sum(results[768:]))
    #print(odemodel.forward(batch))
    print(odemodel.training_step(batch, 0))
    break

tensor([[-44.7862, -43.6993, -46.0333,  ..., -45.7293, -44.6557,   2.4222],
        [-35.0560, -34.3139, -34.3964,  ..., -35.6592, -35.0025,   3.2305],
        [-30.4650, -29.8420, -29.2571,  ..., -31.0252, -30.5095,   2.7842],
        ...,
        [-13.7895, -13.8039, -13.5388,  ..., -13.8990, -13.3496,  10.0173],
        [-14.1596, -14.2361, -13.9108,  ..., -14.4314, -13.6911,  10.8687],
        [-14.8116, -14.9353, -14.5242,  ..., -15.2675, -14.2874,  11.8214]],
       device='cuda:0', grad_fn=<SliceBackward>) tensor(-29542392., device='cuda:0', grad_fn=<SumBackward0>)
tensor([], device='cuda:0', size=(0, 50257), grad_fn=<SliceBackward>) tensor(0., device='cuda:0', grad_fn=<SumBackward0>)
tensor(2.7449, device='cuda:0', grad_fn=<NllLossBackward>)


### Training

In [None]:
# Train the model
trainer = pl.Trainer(max_epochs=2, gpus=1, progress_bar_refresh_rate=10)
trainer.fit(odemodel, train_dataloader)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name     | Type             | Params
----------------------------------------------
0 | loss     | CrossEntropyLoss | 0     
1 | modelgpt | GPTModelWrapper  | 124 M 
2 | f        | Sequential       | 4.7 M 
3 | node     | NeuralDE         | 4.7 M 
4 | linear   | Linear           | 77.2 M
5 | softmax  | Softmax          | 0     
----------------------------------------------
82.0 M    Trainable params
124 M     Non-trainable params
206 M     Total params
825.626   Total estimated model params size (MB)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…




### Evaluation

In [None]:
# Define functions to calculate perplexity for a single sentence: see the metric definition here https://web.stanford.edu/~jurafsky/slp3/3.pdf 
# We use teacher forcing (feeding the ground_truth label for sequence i to get pred for sequence i+1) to get the predictions
def perplexity_ode(preds, ground_truth, mask, epsilon=1e-30):
    probs = []
    for i in range(preds.shape[0]):
        if mask[i] != 0:
            probs.append(preds[i, int(ground_truth[i])])
    probs = np.array(probs)
    probs = np.power(1/(probs+epsilon), 1/probs.shape[0]) # normalise before taking the product, to prevent underflowing to 0
    return np.prod(probs)

# Calculate overall perplexity for a dataset
def average_perplexity_ode(model, train_dataloader, print_results=False, max_items=3000):
    perplexities = []
    for i, batch in enumerate(train_dataloader):
        preds = model.predict(batch)[:-1, :]
        mask = batch[1][0][:-1]
        labels = batch[2][0, 1:].numpy() # shift the input 1 step ahead to get the next word labels
        perplexities.append(perplexity_ode(preds, labels, mask))
        if print_results and i % 100 == 0:
            print('{} of {}, current mean perplex: {:.2f}'.format(i, min(max_items, len(train_dataloader)), (np.mean(perplexities))))
        if i == max_items:
            break
    return np.mean(perplexities)


# straight calculation of BLEU
def average_bleu_ode(model, train_dataloader, print_results=False, max_items=2000):
    smoothie = SmoothingFunction().method4
    bleus = []
    for i, batch in enumerate(train_dataloader):
        mask = batch[1][0][:-1].tolist()

        reference_withmask = batch[2][0, 1:].numpy().tolist()
        preds = model.predict(batch)[:-1, :]
        candidate_withmask = np.argmax(preds, axis=1).tolist()

        reference = [reference_withmask[i] for i in range(len(mask)) if mask[i] != 0]
        candidate = [candidate_withmask[i] for i in range(len(mask)) if mask[i] != 0]

        bleus.append(sentence_bleu([reference], candidate, smoothing_function=smoothie))

        if print_results and i % 100 == 0:
            print('{} of {}, current mean bleu: {:.4f}'.format(i, min(max_items, len(train_dataloader)), (np.mean(bleus))))
        if i == max_items:
            break
    return np.mean(bleus)

#### Perplexity

In [None]:
# train perplexity
odemodel.eval()
odemodel.to(device)
eval_dataloader = torch.utils.data.DataLoader(train_tensor_dataset, batch_size=1, shuffle=True, num_workers=0, pin_memory=True) 
# faster to startup with num_workers=0

average_perplexity_ode(odemodel, eval_dataloader, print_results=True)

0 of 3000, current mean perplex: 156.13
100 of 3000, current mean perplex: 183.27
200 of 3000, current mean perplex: 173.80
300 of 3000, current mean perplex: 175.93
400 of 3000, current mean perplex: 176.42
500 of 3000, current mean perplex: 179.43
600 of 3000, current mean perplex: 179.49
700 of 3000, current mean perplex: 179.51
800 of 3000, current mean perplex: 179.59
900 of 3000, current mean perplex: 177.54
1000 of 3000, current mean perplex: 178.22
1100 of 3000, current mean perplex: 176.31
1200 of 3000, current mean perplex: 174.16
1300 of 3000, current mean perplex: 174.02
1400 of 3000, current mean perplex: 174.14
1500 of 3000, current mean perplex: 173.74
1600 of 3000, current mean perplex: 173.80
1700 of 3000, current mean perplex: 173.92
1800 of 3000, current mean perplex: 173.87
1900 of 3000, current mean perplex: 174.19
2000 of 3000, current mean perplex: 173.62
2100 of 3000, current mean perplex: 175.13
2200 of 3000, current mean perplex: 175.46
2300 of 3000, current m

175.61554

In [None]:
# validate perplexity
full_val_dataset = val_dataset[:]
val_tensor_dataset = torch.utils.data.TensorDataset(full_val_dataset['input_ids'], full_val_dataset['attention_mask'], full_val_dataset['labels'])
val_dataloader = torch.utils.data.DataLoader(val_tensor_dataset, batch_size=1, shuffle=True, num_workers=0, pin_memory=True) 
# faster to startup with num_workers=0

average_perplexity_ode(odemodel, val_dataloader, print_results=True)

0 of 3000, current mean perplex: 92.80
100 of 3000, current mean perplex: 217.33
200 of 3000, current mean perplex: 193.79
300 of 3000, current mean perplex: 194.18
400 of 3000, current mean perplex: 190.35
500 of 3000, current mean perplex: 189.01
600 of 3000, current mean perplex: 194.65
700 of 3000, current mean perplex: 191.85
800 of 3000, current mean perplex: 193.95
900 of 3000, current mean perplex: 191.53
1000 of 3000, current mean perplex: 191.34
1100 of 3000, current mean perplex: 191.72
1200 of 3000, current mean perplex: 192.48
1300 of 3000, current mean perplex: 191.01
1400 of 3000, current mean perplex: 190.82
1500 of 3000, current mean perplex: 190.39
1600 of 3000, current mean perplex: 190.38
1700 of 3000, current mean perplex: 190.08
1800 of 3000, current mean perplex: 190.73
1900 of 3000, current mean perplex: 189.95
2000 of 3000, current mean perplex: 190.69
2100 of 3000, current mean perplex: 190.24
2200 of 3000, current mean perplex: 190.28
2300 of 3000, current me

187.62354

In [None]:
# test perplexity
full_test_dataset = test_dataset[:]
test_tensor_dataset = torch.utils.data.TensorDataset(full_test_dataset['input_ids'], full_test_dataset['attention_mask'], full_test_dataset['labels'])
test_dataloader = torch.utils.data.DataLoader(test_tensor_dataset, batch_size=1, shuffle=True, num_workers=0, pin_memory=True) 
# faster to startup with num_workers=0

average_perplexity_ode(odemodel, test_dataloader, print_results=True)

0 of 3000, current mean perplex: 238.37
100 of 3000, current mean perplex: 152.87
200 of 3000, current mean perplex: 168.72
300 of 3000, current mean perplex: 168.36
400 of 3000, current mean perplex: 169.27
500 of 3000, current mean perplex: 167.86
600 of 3000, current mean perplex: 170.98
700 of 3000, current mean perplex: 168.56
800 of 3000, current mean perplex: 165.71
900 of 3000, current mean perplex: 163.08
1000 of 3000, current mean perplex: 164.71
1100 of 3000, current mean perplex: 165.78
1200 of 3000, current mean perplex: 166.36
1300 of 3000, current mean perplex: 167.42
1400 of 3000, current mean perplex: 167.97
1500 of 3000, current mean perplex: 168.17
1600 of 3000, current mean perplex: 167.83
1700 of 3000, current mean perplex: 167.04
1800 of 3000, current mean perplex: 167.05
1900 of 3000, current mean perplex: 167.34
2000 of 3000, current mean perplex: 167.79
2100 of 3000, current mean perplex: 168.89
2200 of 3000, current mean perplex: 168.41
2300 of 3000, current m

#### BLEU

In [None]:
average_bleu_ode(odemodel, eval_dataloader, print_results=True)

0 of 2000, current mean bleu: 0.1208
100 of 2000, current mean bleu: 0.1950
200 of 2000, current mean bleu: 0.1927
300 of 2000, current mean bleu: 0.1897
400 of 2000, current mean bleu: 0.1883
500 of 2000, current mean bleu: 0.1875
600 of 2000, current mean bleu: 0.1865
700 of 2000, current mean bleu: 0.1872
800 of 2000, current mean bleu: 0.1875
900 of 2000, current mean bleu: 0.1886
1000 of 2000, current mean bleu: 0.1895
1100 of 2000, current mean bleu: 0.1901
1200 of 2000, current mean bleu: 0.1900
1300 of 2000, current mean bleu: 0.1890
1400 of 2000, current mean bleu: 0.1896
1500 of 2000, current mean bleu: 0.1896
1600 of 2000, current mean bleu: 0.1887
1700 of 2000, current mean bleu: 0.1885
1800 of 2000, current mean bleu: 0.1883
1900 of 2000, current mean bleu: 0.1880
2000 of 2000, current mean bleu: 0.1878


0.1877684656028121

In [None]:
average_bleu_ode(odemodel, val_dataloader, print_results=True)

0 of 2000, current mean bleu: 0.1700
100 of 2000, current mean bleu: 0.1986
200 of 2000, current mean bleu: 0.1976
300 of 2000, current mean bleu: 0.1961
400 of 2000, current mean bleu: 0.1952
500 of 2000, current mean bleu: 0.1930
600 of 2000, current mean bleu: 0.1922
700 of 2000, current mean bleu: 0.1913
800 of 2000, current mean bleu: 0.1921
900 of 2000, current mean bleu: 0.1925
1000 of 2000, current mean bleu: 0.1930
1100 of 2000, current mean bleu: 0.1922
1200 of 2000, current mean bleu: 0.1927
1300 of 2000, current mean bleu: 0.1931
1400 of 2000, current mean bleu: 0.1921
1500 of 2000, current mean bleu: 0.1920
1600 of 2000, current mean bleu: 0.1913
1700 of 2000, current mean bleu: 0.1916
1800 of 2000, current mean bleu: 0.1914
1900 of 2000, current mean bleu: 0.1911
2000 of 2000, current mean bleu: 0.1913


0.19126756326448108

In [None]:
average_bleu_ode(odemodel, test_dataloader, print_results=True)

0 of 2000, current mean bleu: 0.1421
100 of 2000, current mean bleu: 0.1958
200 of 2000, current mean bleu: 0.1912
300 of 2000, current mean bleu: 0.1900
400 of 2000, current mean bleu: 0.1910
500 of 2000, current mean bleu: 0.1928
600 of 2000, current mean bleu: 0.1933
700 of 2000, current mean bleu: 0.1938
800 of 2000, current mean bleu: 0.1938
900 of 2000, current mean bleu: 0.1933
1000 of 2000, current mean bleu: 0.1932
1100 of 2000, current mean bleu: 0.1928
1200 of 2000, current mean bleu: 0.1930
1300 of 2000, current mean bleu: 0.1923
1400 of 2000, current mean bleu: 0.1919
1500 of 2000, current mean bleu: 0.1920
1600 of 2000, current mean bleu: 0.1916
1700 of 2000, current mean bleu: 0.1916
1800 of 2000, current mean bleu: 0.1916
1900 of 2000, current mean bleu: 0.1917
2000 of 2000, current mean bleu: 0.1916


0.19159754389063433