In [270]:
%matplotlib inline
%reload_ext autoreload
%autoreload 2

In [271]:
import time
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

import torch
import torch.nn as nn

from data_handlers import tokenise, batch, get_batch
import rnn

In [272]:
cuda = False
device = torch.device("cuda" if cuda else "cpu")
path = './data/penn/'
batch_size = 40

emsize = 400
nhid = 1150
nlayers = 3
dropout = 0.5
tied = False

In [273]:
# LOAD DATA

dictionary = data_handlers.Dictionary()

# Tokenise data to replace characters with integer indexes
train_data, dictionary = tokenise(path+'train.txt', dictionary)
val_data, dictionary   = tokenise(path+'valid.txt', dictionary)
test_data, dictionary  = tokenise(path+'test.txt', dictionary)

# Batch data: reshapes vector as matrix where number of columns j 
# is the batch size.
train_data = batch(train_data, batch_size)
val_data = batch(val_data, batch_size)
test_data  = batch(test_data, batch_size)

In [274]:
# BUILD MODEL

ntokens = len(dictionary)
LSTM = rnn.LSTMModel(ntokens, emsize, nhid, nlayers, dropout, tied).to(device)

# TODO: Check loss matches paper
criterion = nn.CrossEntropyLoss()


In [275]:
# TRAINING CODE

def repackage_hidden(h):
    """Wraps hidden states in new Tensors, to detach them from their history"""
    if isinstance(h, torch.Tensor):
        return h.detach()
    else:
        return tuple(repackage_hidden(v) for v in h)
    

def evaluate(model, data, ntokens, batch_size, bptt):
    model.eval()
    total_loss = 0
    hidden = model.init_hidden(batch_size)
    with torch.no_grad():
        for i in range(0, data.size(0) - 1, bptt):
            x, y = get_batch(data, i)
            output, hidden = model(x, y)
            output_flat = output.view(-1, ntokens)
            total_loss += len(x) * criterion(output_flat, targets).item()
            hidden = repackage_hidden(hidden)
    return total_loss / (len(data) - 1)
            
    
def train(model, data, ntokens:int, batch_size:int, lr:float, bptt:int, clip):
    log_interval = 1
    
    model.train()
    total_loss = 0
    start_time = time.time()
    hidden = model.init_hidden(batch_size)
    for batch, i in enumerate(range(0, data.size(0)-1, bptt)):
        inputs, targets = get_batch(data, i, bptt)
        # For each batch, detach hidden state from state created in previous
        # batches. Else, the model would attempt backpropagation through the 
        # entire dataset
        hidden = repackage_hidden(hidden)
        # Zero the gradients from previous iteration, ready for new values
        model.zero_grad()
        # Forward pass
        output, hidden = model(inputs, hidden)
        # Calculate loss
        loss = criterion(output.view(-1, ntokens), targets.view(-1))
        # Backpropagate
        loss.backward()
        
        # TODO: Check clipping config
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        for p in model.parameters():
            p.data.add_(-lr, p.grad.data)
            
        total_loss += loss.item()
        
        if batch % log_interval == 0 and batch > 0:
            cur_loss = total_loss / log_interval
            elapsed  = time.time() - start_time
            print('| epoch {:3d} | {:5d}/{:5d} batches | lr {:02.2f} | ms/batch {:5.2f} | '
                    'loss {:5.2f} | ppl {:8.2f}'.format(
                epoch, batch, len(data) // bptt, lr,
                elapsed * 1000 / log_interval, cur_loss, math.exp(cur_loss)))
            total_loss = 0
            start_time = time.time()
    
    
    
    

In [277]:
# TRAINING LOOP

epochs = 3
lr = 0.4
bptt = 35
clip = 0.25

for epoch in range(1, epochs+1):
    epoch_start_time = time.time()
    train(LSTM, train_data, ntokens, batch_size, lr, bptt, clip)
    val_loss = evaluate(LSTM, val_data, ntokens, batch_size, bptt)
    print('-' * 89)
    print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | '
            'valid ppl {:8.2f}'.format(epoch, (time.time() - epoch_start_time),
                                       val_loss, np.exp(val_loss)))
    print('-' * 89)
    
    

| epoch   1 |     1/  663 batches | lr 0.40 | ms/batch 12084.74 | loss 18.39 | ppl 97211984.22
| epoch   1 |     2/  663 batches | lr 0.40 | ms/batch 5498.58 | loss  9.17 | ppl  9590.28
| epoch   1 |     3/  663 batches | lr 0.40 | ms/batch 5407.46 | loss  9.15 | ppl  9416.59
| epoch   1 |     4/  663 batches | lr 0.40 | ms/batch 5367.70 | loss  9.13 | ppl  9262.14
| epoch   1 |     5/  663 batches | lr 0.40 | ms/batch 5341.70 | loss  9.11 | ppl  9025.98
| epoch   1 |     6/  663 batches | lr 0.40 | ms/batch 5333.58 | loss  9.09 | ppl  8835.90
| epoch   1 |     7/  663 batches | lr 0.40 | ms/batch 5376.67 | loss  9.07 | ppl  8709.60
| epoch   1 |     8/  663 batches | lr 0.40 | ms/batch 5452.73 | loss  9.05 | ppl  8533.27
| epoch   1 |     9/  663 batches | lr 0.40 | ms/batch 5588.42 | loss  9.02 | ppl  8255.66
| epoch   1 |    10/  663 batches | lr 0.40 | ms/batch 5344.47 | loss  8.98 | ppl  7977.74
| epoch   1 |    11/  663 batches | lr 0.40 | ms/batch 5385.25 | loss  8.96 | ppl  780

| epoch   1 |    91/  663 batches | lr 0.40 | ms/batch 5505.72 | loss  7.13 | ppl  1242.88
| epoch   1 |    92/  663 batches | lr 0.40 | ms/batch 5530.52 | loss  7.03 | ppl  1126.67
| epoch   1 |    93/  663 batches | lr 0.40 | ms/batch 5489.38 | loss  6.92 | ppl  1009.15
| epoch   1 |    94/  663 batches | lr 0.40 | ms/batch 5537.29 | loss  6.90 | ppl   990.15
| epoch   1 |    95/  663 batches | lr 0.40 | ms/batch 5549.37 | loss  6.96 | ppl  1053.10
| epoch   1 |    96/  663 batches | lr 0.40 | ms/batch 5597.04 | loss  6.87 | ppl   960.76
| epoch   1 |    97/  663 batches | lr 0.40 | ms/batch 5570.13 | loss  6.89 | ppl   983.74
| epoch   1 |    98/  663 batches | lr 0.40 | ms/batch 5611.32 | loss  7.06 | ppl  1160.05
| epoch   1 |    99/  663 batches | lr 0.40 | ms/batch 5605.87 | loss  7.16 | ppl  1286.23
| epoch   1 |   100/  663 batches | lr 0.40 | ms/batch 5671.09 | loss  6.99 | ppl  1081.41
| epoch   1 |   101/  663 batches | lr 0.40 | ms/batch 5686.24 | loss  7.00 | ppl  1098.69

| epoch   1 |   182/  663 batches | lr 0.40 | ms/batch 5742.45 | loss  6.96 | ppl  1051.19
| epoch   1 |   183/  663 batches | lr 0.40 | ms/batch 5803.31 | loss  6.91 | ppl  1004.51
| epoch   1 |   184/  663 batches | lr 0.40 | ms/batch 5833.29 | loss  6.96 | ppl  1057.60
| epoch   1 |   185/  663 batches | lr 0.40 | ms/batch 5712.63 | loss  6.93 | ppl  1022.15
| epoch   1 |   186/  663 batches | lr 0.40 | ms/batch 5701.05 | loss  6.96 | ppl  1049.36
| epoch   1 |   187/  663 batches | lr 0.40 | ms/batch 5680.78 | loss  6.88 | ppl   968.23
| epoch   1 |   188/  663 batches | lr 0.40 | ms/batch 5689.09 | loss  6.88 | ppl   969.69
| epoch   1 |   189/  663 batches | lr 0.40 | ms/batch 5756.58 | loss  6.86 | ppl   953.68
| epoch   1 |   190/  663 batches | lr 0.40 | ms/batch 5882.02 | loss  6.79 | ppl   888.06
| epoch   1 |   191/  663 batches | lr 0.40 | ms/batch 5665.69 | loss  6.88 | ppl   977.03
| epoch   1 |   192/  663 batches | lr 0.40 | ms/batch 5733.69 | loss  6.88 | ppl   974.90

KeyboardInterrupt: 