In [1]:
import argparse
import os
import time
import math
import torch
import torch.nn as nn
from torch.autograd import Variable
from torchtext import data as d
from torchtext import datasets
from torchtext.vocab import GloVe
import model

In [2]:
is_cuda = torch.cuda.is_available()
is_cuda

True

In [3]:
TEXT = d.Field(lower=True, batch_first=True,)

In [4]:
# make splits for data
train, valid, test = datasets.WikiText2.splits(TEXT,root='data')

In [5]:
batch_size=20
bptt_len=30
clip = 0.25
lr = 20
log_interval = 200

In [6]:
(len(valid[0].text)//batch_size)*batch_size

217640

In [7]:
len(valid[0].text)

217646

In [8]:
train[0].text = train[0].text[:(len(train[0].text)//batch_size)*batch_size]
valid[0].text = valid[0].text[:(len(valid[0].text)//batch_size)*batch_size]
test[0].text = test[0].text[:(len(valid[0].text)//batch_size)*batch_size]


In [9]:
len(valid[0].text)

217640

In [10]:
# print information about the data
print('train.fields', train.fields)
print('len(train)', len(train))
print('vars(train[0])', vars(train[0])['text'][0:10])

train.fields {'text': <torchtext.data.field.Field object at 0x7f81b0f1bf60>}
len(train) 1
vars(train[0]) ['<eos>', '=', 'valkyria', 'chronicles', 'iii', '=', '<eos>', '<eos>', 'senjō', 'no']


In [11]:
TEXT.build_vocab(train)

In [12]:
print('len(TEXT.vocab)', len(TEXT.vocab))

len(TEXT.vocab) 28913


In [13]:
train_iter, valid_iter, test_iter = d.BPTTIterator.splits((train, valid, test), batch_size=batch_size, bptt_len=bptt_len, device=0,repeat=False)

In [14]:
class RNNModel(nn.Module):
    def __init__(self,ntoken,ninp,nhid,nlayers,dropout=0.5,tie_weights=False):
        super().__init__()
        self.drop = nn.Dropout()
        self.encoder = nn.Embedding(ntoken,ninp)
        self.rnn = nn.LSTM(ninp,nhid,nlayers,dropout=dropout)
        self.decoder = nn.Linear(nhid,ntoken)
        if tie_weights:
            self.decoder.weight = self.encoder.weight
        
        self.init_weights()
        self.nhid = nhid
        self.nlayers = nlayers
        
    def init_weights(self):
        initrange = 0.1
        self.encoder.weight.data.uniform_(-initrange,initrange)
        self.decoder.bias.data.fill_(0)
        self.decoder.weight.data.uniform_(-initrange,initrange)
        
    def forward(self,input,hidden): 
        
        emb = self.drop(self.encoder(input))
        output,hidden = self.rnn(emb,hidden)
        output = self.drop(output)
        s = output.size()
        decoded = self.decoder(output.view(s[0]*s[1],s[2]))
        return decoded.view(s[0],s[1],decoded.size(1)),hidden
    
    def init_hidden(self,bsz):
        weight = next(self.parameters()).data
        return(Variable(weight.new(self.nlayers,bsz,self.nhid).zero_()),Variable(weight.new(self.nlayers,bsz,self.nhid).zero_()))
    

In [15]:
criterion = nn.CrossEntropyLoss()

In [16]:
len(valid_iter.dataset[0].text)


217640

In [17]:
emsize = 200
nhid=200
nlayers=2
dropout = 0.2

ntokens = len(TEXT.vocab)
lstm = RNNModel(ntokens, emsize, nhid,nlayers, dropout, 'store_true')
if is_cuda:
    lstm = lstm.cuda()

In [18]:
def repackage_hidden(h):
    """Wraps hidden states in new Variables, to detach them from their history."""
    if type(h) == Variable:
        return Variable(h.data)
    else:
        return tuple(repackage_hidden(v) for v in h)

In [19]:

def evaluate(data_source):
    # Turn on evaluation mode which disables dropout.
    lstm.eval()
    total_loss = 0   
    hidden = lstm.init_hidden(batch_size)
    for batch in data_source:        
        data, targets = batch.text,batch.target.view(-1)
        output, hidden = lstm(data, hidden)
        output_flat = output.view(-1, ntokens)
        total_loss += len(data) * criterion(output_flat, targets).data
        hidden = repackage_hidden(hidden)
    return total_loss[0]/(len(data_source.dataset[0].text)//batch_size) 


In [20]:
def trainf():
    # Turn on training mode which enables dropout.
    lstm.train()
    total_loss = 0
    start_time = time.time()
    hidden = lstm.init_hidden(batch_size)
    for  i,batch in enumerate(train_iter):
        data, targets = batch.text,batch.target.view(-1)
        # Starting each batch, we detach the hidden state from how it was previously produced.
        # If we didn't, the model would try backpropagating all the way to start of the dataset.
        hidden = repackage_hidden(hidden)
        lstm.zero_grad()
        output, hidden = lstm(data, hidden)
        loss = criterion(output.view(-1, ntokens), targets)
        loss.backward()

        # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs.
        torch.nn.utils.clip_grad_norm(lstm.parameters(), clip)
        for p in lstm.parameters():
            p.data.add_(-lr, p.grad.data)

        total_loss += loss.data

        if i % log_interval == 0 and i > 0:
            cur_loss = total_loss[0] / log_interval
            elapsed = time.time() - start_time
            (print('| epoch {:3d} | {:5d}/{:5d} batches | lr {:02.2f} | ms/batch {:5.2f} | loss {:5.2f} | ppl {:8.2f}'.format(epoch, i, len(train_iter), lr,elapsed * 1000 / log_interval, cur_loss, math.exp(cur_loss))))
            total_loss = 0
            start_time = time.time()

In [22]:
# Loop over epochs.
best_val_loss = None
epochs = 40

for epoch in range(1, epochs+1):
    epoch_start_time = time.time()
    trainf()
    val_loss = evaluate(valid_iter)
    print('-' * 89)
    print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | '
        'valid ppl {:8.2f}'.format(epoch, (time.time() - epoch_start_time),
                                   val_loss, math.exp(val_loss)))
    print('-' * 89)
    if not best_val_loss or val_loss < best_val_loss:
        best_val_loss = val_loss
    else:
        # Anneal the learning rate if no improvement has been seen in the validation dataset.
        lr /= 4.0

| epoch   1 |   200/ 3481 batches | lr 20.00 | ms/batch 12.15 | loss  7.45 | ppl  1719.76
| epoch   1 |   400/ 3481 batches | lr 20.00 | ms/batch  9.28 | loss  6.71 | ppl   816.60
| epoch   1 |   600/ 3481 batches | lr 20.00 | ms/batch  9.14 | loss  6.36 | ppl   578.74
| epoch   1 |   800/ 3481 batches | lr 20.00 | ms/batch  9.12 | loss  6.19 | ppl   487.93
| epoch   1 |  1000/ 3481 batches | lr 20.00 | ms/batch  9.11 | loss  6.11 | ppl   451.40
| epoch   1 |  1200/ 3481 batches | lr 20.00 | ms/batch  9.13 | loss  6.03 | ppl   416.71
| epoch   1 |  1400/ 3481 batches | lr 20.00 | ms/batch  9.13 | loss  5.99 | ppl   397.83
| epoch   1 |  1600/ 3481 batches | lr 20.00 | ms/batch  9.11 | loss  5.92 | ppl   371.53
| epoch   1 |  1800/ 3481 batches | lr 20.00 | ms/batch  9.11 | loss  5.92 | ppl   372.40
| epoch   1 |  2000/ 3481 batches | lr 20.00 | ms/batch  9.21 | loss  5.84 | ppl   344.93
| epoch   1 |  2200/ 3481 batches | lr 20.00 | ms/batch  9.38 | loss  5.81 | ppl   334.10
| epoch   

  import sys


val loss 5.429313100188384
-----------------------------------------------------------------------------------------
| end of epoch   1 | time: 33.77s | valid loss  5.43 | valid ppl   227.99
-----------------------------------------------------------------------------------------
| epoch   2 |   200/ 3481 batches | lr 20.00 | ms/batch 11.32 | loss  5.63 | ppl   279.22
| epoch   2 |   400/ 3481 batches | lr 20.00 | ms/batch  9.28 | loss  5.62 | ppl   276.09
| epoch   2 |   600/ 3481 batches | lr 20.00 | ms/batch  9.28 | loss  5.56 | ppl   260.38
| epoch   2 |   800/ 3481 batches | lr 20.00 | ms/batch  9.38 | loss  5.48 | ppl   239.33
| epoch   2 |  1000/ 3481 batches | lr 20.00 | ms/batch  9.39 | loss  5.52 | ppl   250.83
| epoch   2 |  1200/ 3481 batches | lr 20.00 | ms/batch  9.30 | loss  5.48 | ppl   239.87
| epoch   2 |  1400/ 3481 batches | lr 20.00 | ms/batch  9.29 | loss  5.50 | ppl   243.73
| epoch   2 |  1600/ 3481 batches | lr 20.00 | ms/batch  9.30 | loss  5.50 | ppl   244.61

| epoch   6 |  1800/ 3481 batches | lr 20.00 | ms/batch  9.36 | loss  5.21 | ppl   183.73
| epoch   6 |  2000/ 3481 batches | lr 20.00 | ms/batch  9.34 | loss  5.17 | ppl   175.34
| epoch   6 |  2200/ 3481 batches | lr 20.00 | ms/batch  9.32 | loss  5.18 | ppl   178.14
| epoch   6 |  2400/ 3481 batches | lr 20.00 | ms/batch  9.35 | loss  5.16 | ppl   173.92
| epoch   6 |  2600/ 3481 batches | lr 20.00 | ms/batch  9.34 | loss  5.07 | ppl   158.48
| epoch   6 |  2800/ 3481 batches | lr 20.00 | ms/batch  9.35 | loss  5.15 | ppl   171.85
| epoch   6 |  3000/ 3481 batches | lr 20.00 | ms/batch  9.37 | loss  5.14 | ppl   171.37
| epoch   6 |  3200/ 3481 batches | lr 20.00 | ms/batch  9.40 | loss  5.12 | ppl   167.27
| epoch   6 |  3400/ 3481 batches | lr 20.00 | ms/batch  9.36 | loss  5.08 | ppl   160.14
val loss 5.006427974522147
-----------------------------------------------------------------------------------------
| end of epoch   6 | time: 34.11s | valid loss  5.01 | valid ppl   149.37

| epoch  11 |   200/ 3481 batches | lr 20.00 | ms/batch 11.40 | loss  5.07 | ppl   158.85
| epoch  11 |   400/ 3481 batches | lr 20.00 | ms/batch  9.40 | loss  5.08 | ppl   160.87
| epoch  11 |   600/ 3481 batches | lr 20.00 | ms/batch  9.52 | loss  5.03 | ppl   152.22
| epoch  11 |   800/ 3481 batches | lr 20.00 | ms/batch  9.39 | loss  4.94 | ppl   139.70
| epoch  11 |  1000/ 3481 batches | lr 20.00 | ms/batch  9.56 | loss  5.05 | ppl   155.53
| epoch  11 |  1200/ 3481 batches | lr 20.00 | ms/batch  9.44 | loss  5.00 | ppl   148.91
| epoch  11 |  1400/ 3481 batches | lr 20.00 | ms/batch  9.51 | loss  5.04 | ppl   154.96
| epoch  11 |  1600/ 3481 batches | lr 20.00 | ms/batch  9.42 | loss  5.07 | ppl   159.64
| epoch  11 |  1800/ 3481 batches | lr 20.00 | ms/batch  9.42 | loss  5.10 | ppl   163.45
| epoch  11 |  2000/ 3481 batches | lr 20.00 | ms/batch  9.45 | loss  5.06 | ppl   157.49
| epoch  11 |  2200/ 3481 batches | lr 20.00 | ms/batch  9.45 | loss  5.08 | ppl   160.64
| epoch  1

| epoch  15 |  2400/ 3481 batches | lr 20.00 | ms/batch  9.39 | loss  5.00 | ppl   148.97
| epoch  15 |  2600/ 3481 batches | lr 20.00 | ms/batch  9.46 | loss  4.90 | ppl   134.61
| epoch  15 |  2800/ 3481 batches | lr 20.00 | ms/batch  9.42 | loss  4.99 | ppl   146.89
| epoch  15 |  3000/ 3481 batches | lr 20.00 | ms/batch  9.43 | loss  4.99 | ppl   147.16
| epoch  15 |  3200/ 3481 batches | lr 20.00 | ms/batch  9.48 | loss  4.97 | ppl   144.25
| epoch  15 |  3400/ 3481 batches | lr 20.00 | ms/batch  9.42 | loss  4.92 | ppl   137.52
val loss 4.902327955568829
-----------------------------------------------------------------------------------------
| end of epoch  15 | time: 34.40s | valid loss  4.90 | valid ppl   134.60
-----------------------------------------------------------------------------------------
| epoch  16 |   200/ 3481 batches | lr 20.00 | ms/batch 11.45 | loss  5.00 | ppl   148.31
| epoch  16 |   400/ 3481 batches | lr 20.00 | ms/batch  9.40 | loss  5.02 | ppl   151.58

| epoch  20 |   600/ 3481 batches | lr 5.00 | ms/batch  9.45 | loss  4.81 | ppl   123.16
| epoch  20 |   800/ 3481 batches | lr 5.00 | ms/batch  9.51 | loss  4.74 | ppl   114.07
| epoch  20 |  1000/ 3481 batches | lr 5.00 | ms/batch  9.45 | loss  4.84 | ppl   126.77
| epoch  20 |  1200/ 3481 batches | lr 5.00 | ms/batch  9.46 | loss  4.79 | ppl   120.53
| epoch  20 |  1400/ 3481 batches | lr 5.00 | ms/batch  9.43 | loss  4.83 | ppl   125.15
| epoch  20 |  1600/ 3481 batches | lr 5.00 | ms/batch  9.44 | loss  4.85 | ppl   128.27
| epoch  20 |  1800/ 3481 batches | lr 5.00 | ms/batch  9.44 | loss  4.88 | ppl   131.02
| epoch  20 |  2000/ 3481 batches | lr 5.00 | ms/batch  9.46 | loss  4.85 | ppl   127.25
| epoch  20 |  2200/ 3481 batches | lr 5.00 | ms/batch  9.48 | loss  4.86 | ppl   129.30
| epoch  20 |  2400/ 3481 batches | lr 5.00 | ms/batch  9.42 | loss  4.83 | ppl   125.71
| epoch  20 |  2600/ 3481 batches | lr 5.00 | ms/batch  9.48 | loss  4.72 | ppl   112.53
| epoch  20 |  2800/ 

| epoch  24 |  3000/ 3481 batches | lr 5.00 | ms/batch  9.46 | loss  4.77 | ppl   118.44
| epoch  24 |  3200/ 3481 batches | lr 5.00 | ms/batch  9.50 | loss  4.76 | ppl   116.71
| epoch  24 |  3400/ 3481 batches | lr 5.00 | ms/batch  9.47 | loss  4.70 | ppl   110.10
val loss 4.750752389266679
-----------------------------------------------------------------------------------------
| end of epoch  24 | time: 34.43s | valid loss  4.75 | valid ppl   115.67
-----------------------------------------------------------------------------------------
| epoch  25 |   200/ 3481 batches | lr 5.00 | ms/batch 11.33 | loss  4.80 | ppl   121.35
| epoch  25 |   400/ 3481 batches | lr 5.00 | ms/batch  9.63 | loss  4.82 | ppl   124.49
| epoch  25 |   600/ 3481 batches | lr 5.00 | ms/batch  9.37 | loss  4.76 | ppl   117.28
| epoch  25 |   800/ 3481 batches | lr 5.00 | ms/batch  9.43 | loss  4.68 | ppl   107.52
| epoch  25 |  1000/ 3481 batches | lr 5.00 | ms/batch  9.41 | loss  4.79 | ppl   120.13
| epoch

| epoch  29 |  1400/ 3481 batches | lr 1.25 | ms/batch  9.44 | loss  4.75 | ppl   116.07
| epoch  29 |  1600/ 3481 batches | lr 1.25 | ms/batch  9.42 | loss  4.79 | ppl   119.93
| epoch  29 |  1800/ 3481 batches | lr 1.25 | ms/batch  9.40 | loss  4.79 | ppl   120.43
| epoch  29 |  2000/ 3481 batches | lr 1.25 | ms/batch  9.41 | loss  4.78 | ppl   118.58
| epoch  29 |  2200/ 3481 batches | lr 1.25 | ms/batch  9.41 | loss  4.79 | ppl   120.63
| epoch  29 |  2400/ 3481 batches | lr 1.25 | ms/batch  9.44 | loss  4.76 | ppl   117.32
| epoch  29 |  2600/ 3481 batches | lr 1.25 | ms/batch  9.42 | loss  4.65 | ppl   104.57
| epoch  29 |  2800/ 3481 batches | lr 1.25 | ms/batch  9.43 | loss  4.72 | ppl   112.63
| epoch  29 |  3000/ 3481 batches | lr 1.25 | ms/batch  9.45 | loss  4.74 | ppl   114.07
| epoch  29 |  3200/ 3481 batches | lr 1.25 | ms/batch  9.45 | loss  4.72 | ppl   111.85
| epoch  29 |  3400/ 3481 batches | lr 1.25 | ms/batch  9.40 | loss  4.66 | ppl   105.67
val loss 4.7108692967

| epoch  34 |   200/ 3481 batches | lr 1.25 | ms/batch 11.42 | loss  4.77 | ppl   117.71
| epoch  34 |   400/ 3481 batches | lr 1.25 | ms/batch  9.54 | loss  4.79 | ppl   120.21
| epoch  34 |   600/ 3481 batches | lr 1.25 | ms/batch  9.40 | loss  4.72 | ppl   112.28
| epoch  34 |   800/ 3481 batches | lr 1.25 | ms/batch  9.44 | loss  4.64 | ppl   103.57
| epoch  34 |  1000/ 3481 batches | lr 1.25 | ms/batch  9.49 | loss  4.76 | ppl   116.32
| epoch  34 |  1200/ 3481 batches | lr 1.25 | ms/batch  9.52 | loss  4.70 | ppl   110.11
| epoch  34 |  1400/ 3481 batches | lr 1.25 | ms/batch  9.45 | loss  4.74 | ppl   114.60
| epoch  34 |  1600/ 3481 batches | lr 1.25 | ms/batch  9.45 | loss  4.77 | ppl   117.53
| epoch  34 |  1800/ 3481 batches | lr 1.25 | ms/batch  9.45 | loss  4.78 | ppl   119.63
| epoch  34 |  2000/ 3481 batches | lr 1.25 | ms/batch  9.45 | loss  4.76 | ppl   116.54
| epoch  34 |  2200/ 3481 batches | lr 1.25 | ms/batch  9.45 | loss  4.78 | ppl   118.80
| epoch  34 |  2400/ 

| epoch  38 |  2600/ 3481 batches | lr 1.25 | ms/batch  9.44 | loss  4.63 | ppl   102.01
| epoch  38 |  2800/ 3481 batches | lr 1.25 | ms/batch  9.45 | loss  4.70 | ppl   109.80
| epoch  38 |  3000/ 3481 batches | lr 1.25 | ms/batch  9.47 | loss  4.71 | ppl   111.37
| epoch  38 |  3200/ 3481 batches | lr 1.25 | ms/batch  9.52 | loss  4.70 | ppl   110.49
| epoch  38 |  3400/ 3481 batches | lr 1.25 | ms/batch  9.41 | loss  4.64 | ppl   103.93
val loss 4.699144444380629
-----------------------------------------------------------------------------------------
| end of epoch  38 | time: 34.47s | valid loss  4.70 | valid ppl   109.85
-----------------------------------------------------------------------------------------
| epoch  39 |   200/ 3481 batches | lr 1.25 | ms/batch 11.26 | loss  4.75 | ppl   115.51
| epoch  39 |   400/ 3481 batches | lr 1.25 | ms/batch  9.37 | loss  4.77 | ppl   117.94
| epoch  39 |   600/ 3481 batches | lr 1.25 | ms/batch  9.37 | loss  4.71 | ppl   111.25
| epoch