# CE7455 Assignment 1
Peng Hongyi (G2105029E)

## Run the provided code first

### Load data

In [28]:
import torch
# Run on the GPU when PyTorch can see one; otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device('cpu')

In [29]:
import data 

In [30]:
# Directory handed to data.Corpus (path name suggests the WikiText-2 dataset).
data_dir = './data/wikitext-2'
# Build the corpus object; its train/valid/test attributes are read below.
corpus = data.Corpus(data_dir)
# Sanity check: print the shape of each split.
print(f'Train: {corpus.train.shape}, Val: {corpus.valid.shape}, Test: {corpus.test.shape}')

Train: torch.Size([2088628]), Val: torch.Size([217646]), Test: torch.Size([245569])


In [31]:
def batchify(data, bsz):
    """Split ``data`` along dim 0 into ``bsz`` parallel columns on ``device``.

    Elements at the tail that do not fill a complete row across all ``bsz``
    columns are discarded. For a 1-D input of length ``n`` the result has
    shape ``(n // bsz, bsz)``.
    """
    rows_per_column = data.size(0) // bsz
    usable = rows_per_column * bsz
    # Keep only the prefix that divides evenly into bsz pieces.
    trimmed = data[:usable]
    # Lay the bsz pieces out as rows, then transpose so each piece is a column.
    columns = trimmed.view(bsz, -1).t().contiguous()
    return columns.to(device)

In [32]:
# Number of parallel columns produced by batchify: 20 for training,
# 10 for the two evaluation splits.
BATCH_SIZE = 20
EVAL_BATCH_SIZE = 10
train_data = batchify(corpus.train, BATCH_SIZE)
# Validation and test share the evaluation batch size.
val_data, test_data = (
    batchify(split, EVAL_BATCH_SIZE) for split in (corpus.valid, corpus.test)
)

In [53]:
# Number of entries in the corpus dictionary, passed to the model as its
# vocabulary size.
N_TOKENS = len(corpus.dictionary)
import model
# Hyper-parameters forwarded positionally to model.RNNModel below.
# NOTE(review): meanings are inferred from the names (RNN cell type, embedding
# width, hidden width, layer count, dropout rate) — confirm against model.py.
MODEL = "LSTM"
EMSIZE = 200
N_HID = 200
N_LAYERS = 2
DROPOUT = 0.2
# On/off switch for weight tying. It used to hold the string "store_true"
# (an argparse action name), which only enabled tying because a non-empty
# string is truthy; a real boolean states the intent.
TIED = True
# Step size of the manual SGD update and gradient-norm clipping threshold,
# both consumed by the training loop.
LR = 20
CLIP_TH = 0.25
# NOTE: this rebinds the name `model` from the imported module to the network
# instance; the `import model` above restores the module when the cell is
# re-run.
model = model.RNNModel(MODEL, N_TOKENS, EMSIZE, N_HID, N_LAYERS, DROPOUT, TIED).to(device)

In [54]:
# Maximum number of rows taken from the source per batch.
BPTT = 35


def get_batch(source, i):
    """Return an input window starting at row ``i`` and its next-step targets.

    The window holds ``BPTT`` rows, or fewer near the end of ``source`` so
    that the slice shifted one row ahead still fits. The targets are that
    shifted slice flattened to 1-D.
    """
    remaining = len(source) - 1 - i
    window = min(remaining, BPTT)
    stop = i + window
    inputs = source[i:stop]
    shifted = source[i + 1:stop + 1]
    return inputs, shifted.view(-1)

In [58]:
# Peek at the first training batch to confirm its shape. A dedicated name is
# used so this cell does not rebind `data`, which is the imported corpus
# module.
sample_batch, _ = get_batch(train_data, 0)
sample_batch.shape

torch.Size([35, 20])

In [55]:
def repackage_hidden(h):
    """Return ``h`` with every tensor detached from its autograd history.

    A tensor yields its detached counterpart; any other iterable is rebuilt
    as a tuple whose members are processed recursively.
    """
    if not isinstance(h, torch.Tensor):
        return tuple(map(repackage_hidden, h))
    return h.detach()

In [57]:
from torch import nn
import time
import math
# Number of batches between progress reports; also the divisor used to
# average the accumulated loss and elapsed time.
LOG_INTERVAL = 200
# NOTE(review): NLLLoss expects log-probabilities; this presumes the RNN
# model's output already is — confirm in model.py.
criterion = nn.NLLLoss()
model.train()
total_loss = 0
start_time = time.time()
hidden = model.init_hidden(BATCH_SIZE)
for epoch in range(1):
    for batch, i in enumerate(range(0, train_data.size(0)-1, BPTT)):
        # Named `inputs` rather than `data` so the imported `data` module is
        # not rebound by the loop.
        inputs, targets = get_batch(train_data, i)
        model.zero_grad()
        # Detach the carried-over state so backprop stops at this batch.
        hidden = repackage_hidden(hidden)
        output, hidden = model(inputs, hidden)
        loss = criterion(output, targets)
        loss.backward()

        nn.utils.clip_grad_norm_(model.parameters(), CLIP_TH)
        # Plain SGD step. no_grad keeps the in-place update out of autograd
        # without going through `.data`; parameters that received no gradient
        # are skipped.
        with torch.no_grad():
            for p in model.parameters():
                if p.grad is not None:
                    p.add_(p.grad, alpha=-LR)
        total_loss += loss.item()

        if batch % LOG_INTERVAL == 0 and batch > 0:
            cur_loss = total_loss / LOG_INTERVAL
            elapsed = time.time() - start_time
            print('| epoch {:3d} | {:5d}/{:5d} batches | lr {:02.2f} | ms/batch {:5.2f} | '
                    'loss {:5.2f} | ppl {:8.2f}'.format(
                epoch, batch, len(train_data) // BPTT, LR,
                elapsed * 1000 / LOG_INTERVAL, cur_loss, math.exp(cur_loss)))
            # Start a fresh accumulation window for the next report.
            total_loss = 0
            start_time = time.time()


| epoch   0 |   200/ 2983 batches | lr 20.00 | ms/batch 14.17 | loss  5.85 | ppl   346.21
| epoch   0 |   400/ 2983 batches | lr 20.00 | ms/batch 14.28 | loss  5.73 | ppl   308.45
| epoch   0 |   600/ 2983 batches | lr 20.00 | ms/batch 14.51 | loss  5.54 | ppl   255.88
| epoch   0 |   800/ 2983 batches | lr 20.00 | ms/batch 14.33 | loss  5.49 | ppl   243.37
| epoch   0 |  1000/ 2983 batches | lr 20.00 | ms/batch 14.34 | loss  5.45 | ppl   231.66
| epoch   0 |  1200/ 2983 batches | lr 20.00 | ms/batch 14.08 | loss  5.43 | ppl   228.43
| epoch   0 |  1400/ 2983 batches | lr 20.00 | ms/batch 14.19 | loss  5.47 | ppl   236.80
| epoch   0 |  1600/ 2983 batches | lr 20.00 | ms/batch 14.09 | loss  5.62 | ppl   274.80
| epoch   0 |  1800/ 2983 batches | lr 20.00 | ms/batch 14.03 | loss  5.48 | ppl   238.77
| epoch   0 |  2000/ 2983 batches | lr 20.00 | ms/batch 14.17 | loss  5.47 | ppl   238.45
| epoch   0 |  2200/ 2983 batches | lr 20.00 | ms/batch 14.16 | loss  5.38 | ppl   217.58
| epoch   

### Write my own FNN model

In [71]:
class FNNModel(nn.Module):
    """Feed-forward language model over a fixed window of ``bptt`` tokens.

    Token ids are embedded, the window's embeddings are concatenated into a
    single vector, passed through a tanh hidden layer, and projected to
    per-token log-probabilities (the input format ``nn.NLLLoss`` expects).

    Args:
        n_token: Vocabulary size (rows of the embedding, width of the output).
        n_emb: Embedding width per token.
        n_hidden: Width of the hidden layer.
        bptt: Number of context tokens in each input row.
    """

    def __init__(self, n_token, n_emb, n_hidden, bptt):
        super().__init__()
        self.n_token = n_token
        self.n_emb = n_emb
        self.n_hidden = n_hidden
        self.bptt = bptt
        self.encoder = nn.Embedding(n_token, n_emb)
        # The hidden layer consumes all bptt embeddings concatenated.
        self.hidden = nn.Linear(n_emb * bptt, n_hidden)
        self.decoder = nn.Linear(n_hidden, n_token)

    def forward(self, input):
        """Score the next token for each context window.

        Args:
            input: Integer tensor of token ids with shape ``(batch, bptt)``.

        Returns:
            Tensor of shape ``(batch, n_token)`` holding log-probabilities.

        Raises:
            ValueError: If ``input`` is not 2-D with ``bptt`` columns.
        """
        if input.dim() != 2 or input.size(1) != self.bptt:
            raise ValueError(
                f"expected input of shape (batch, {self.bptt}), "
                f"got {tuple(input.shape)}"
            )
        emb = self.encoder(input)
        # (batch, bptt, n_emb) -> (batch, bptt * n_emb): one vector per window.
        emb = emb.flatten(start_dim=1)
        # The non-linearity sits between the two linear layers; without it
        # hidden + decoder would collapse into a single linear map.
        out = torch.tanh(self.hidden(emb))
        decoded = self.decoder(out)
        decoded = decoded.view(-1, self.n_token)
        # Log-probabilities, not probabilities, because the loss is NLLLoss.
        return torch.log_softmax(decoded, dim=-1)
        

In [70]:
# bptt must equal the window length get_batch produces, so reuse BPTT instead
# of repeating the literal. The model is moved to `device` because the
# batches already live there.
model = FNNModel(
    n_token=N_TOKENS,
    n_emb=200,
    n_hidden=200,
    bptt=BPTT,
).to(device)
criterion = nn.NLLLoss()
model.train()
total_loss = 0
start_time = time.time()

for epoch in range(1):
    for batch, i in enumerate(range(0, train_data.size(0)-1, BPTT)):
        # Named `inputs` rather than `data` so the imported `data` module is
        # not rebound by the loop.
        inputs, targets = get_batch(train_data, i)
        model.zero_grad()
        # get_batch returns (seq_len, batch); transpose so each row is one
        # context window. `targets` is already 1-D, so it needs no transpose.
        inputs = inputs.T
        print(inputs.shape, targets.shape)
        print(inputs)
        print(targets)
        # Shape inspection only: stop after the first batch.
        break
        # TODO: wire up forward/backward. NOTE(review): `targets` holds one id
        # per input position (seq_len * batch values), while the FNN emits one
        # prediction per window — the targets need reshaping before
        # criterion(output, targets) can be used.

torch.Size([20, 35]) torch.Size([700])
tensor([[    0,     1,     2,     3,     4,     1,     0,     0,     5,     6,
             2,     7,     8,     9,     3,    10,    11,     8,    12,    13,
            14,    15,     2,    16,    17,    18,     7,    19,    13,    20,
            21,    22,    23,     2,     3],
        [  284,   357,  1496,   449,  5181,    13,    17,  1207,  1870,    43,
          1809,    13, 10314,   144,    27,  1426,    30,  3022,   910,  4781,
            15,    83, 10539, 10540,  2191,    17,   669,   831, 10518,    93,
           828,  1721,    13,  3334,    13],
        [15178,    43,  7369,   310, 15182,    15,   652, 15183,    13,    46,
          1104,    17,  3803,  3522,    16, 10525,  3773,    13,    37,    43,
         15184,    46,   131,   677,    15,    83,  1723,  1109, 15185,  3033,
         10525,    43,  6936,   440,    35],
        [  280,  2977,   115,     9, 18712,    22, 17400, 18712,    93,  1775,
          1908,    15,    83,  2839,