In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import math

import sys
sys.path.append('../')

import matplotlib.pyplot as plt

from torch.utils.data import DataLoader
from wikitext_dataset import Wikitext_2


%matplotlib inline

In [2]:
sequence_length = 128
batch_size = 30
grad_clip = 0.1
lr = 4.
best_val_loss = None
log_interval = 100

In [3]:
train_data = Wikitext_2(download=True, seq_len=sequence_length, root="./", train=True)
valid_data = Wikitext_2(download=True, seq_len=sequence_length, root="./", valid=True)
test_data = Wikitext_2(download=True, seq_len=sequence_length, root="./", test=True)

train_loader = DataLoader(train_data, batch_size=batch_size)
val_loader = DataLoader(valid_data, batch_size=batch_size)
test_loader = DataLoader(test_data, batch_size=batch_size)

Downloading https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip to ./raw/wikitext-2-raw-v1.zip
Processing...
Done!


In [4]:
class RNNModel(nn.Module):

    def __init__(self, rnn_type, ntoken, ninp, nhid, nlayers, dropout=0.5):
        super(RNNModel, self).__init__()
        self.drop = nn.Dropout(dropout)
        self.encoder = nn.Embedding(ntoken, ninp)
        if rnn_type == "LSTM":
            self.rnn = nn.LSTM(ninp, nhid, nlayers, dropout=dropout)
        elif rnn_type == "GRU":
            self.rnn = nn.GRU(ninp, nhid, nlayers, dropout=dropout)
        self.decoder = nn.Linear(nhid, ntoken)

        self.init_weights()

        self.rnn_type = rnn_type
        self.nhid = nhid
        self.nlayers = nlayers

    def init_weights(self):
        initrange = 0.1
        self.encoder.weight.data.uniform_(-initrange, initrange)
        self.decoder.bias.data.fill_(0)
        self.decoder.weight.data.uniform_(-initrange, initrange)

    def forward(self, x, hidden=None):
        emb = self.drop(self.encoder(x))
        output, hidden = self.rnn(emb, hidden)
        output = self.drop(output)
        decoded = self.decoder(output.view(output.size(0)*output.size(1), output.size(2)))
        return decoded.view(output.size(0), output.size(1), decoded.size(1)), hidden

    def init_hidden(self, bsz):
        weight = next(self.parameters()).data
        if self.rnn_type == "LSTM":
            return (weight.new(self.nlayers, bsz, self.nhid).zero_(),
                    weight.new(self.nlayers, bsz, self.nhid).zero_())
        else:
            return weight.new(self.nlayers, bsz, self.nhid).zero_()

In [5]:
def evaluate(data_loader):
    model.eval()
    total_loss = 0
    ntokens = len(train_data.vocabulary)
    hidden = model.init_hidden(batch_size)
    for i, (data, targets) in enumerate(data_loader):
        output, hidden = model(data)
        output_flat = output.view(-1, ntokens)
        total_loss += len(data) * criterion(output_flat, targets.view(-1)).item()
    return total_loss / (len(data_loader) * sequence_length)

In [6]:
def train():
    model.train()
    total_loss = 0
    ntokens = len(train_data.vocabulary)
    for batch, (data, targets) in enumerate(train_loader): 
        model.zero_grad()
        output, hidden = model(data)
        loss = criterion(output.view(-1, ntokens), targets.view(-1))
        loss.backward()

        # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs.
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        for p in model.parameters():
            p.data.add_(-lr, p.grad.data)

        total_loss += loss.item()

        if batch % log_interval == 0 and batch > 0:
            cur_loss = total_loss / log_interval
            print("| epoch {:3d} | {:5d}/{:5d} batches | lr {:02.2f} | loss {:5.2f} | ppl {:8.2f}".format(
                epoch, batch, len(train_loader), lr, cur_loss, math.exp(cur_loss)))
            total_loss = 0

In [7]:
ntokens = len(train_data.vocabulary)
model = RNNModel("LSTM", ntokens, 256, 256, 3, 0.3)
criterion = nn.CrossEntropyLoss()

In [8]:
def generate(n=50, temp=1.):
    model.eval()
    x = torch.rand(1, 1).mul(ntokens).long()
    hidden = None
    out = []
    for i in range(n):
        output, hidden = model(x, hidden)
        s_weights = output.squeeze().data.div(temp).exp()
        s_idx = torch.multinomial(s_weights, 1)[0]
        x.data.fill_(s_idx)
        s = train_data.inverse_vocabulary[s_idx.item()]
        out.append(s)
    return "".join(out)

In [9]:
with torch.no_grad():
    print("sample:\n", generate(50), "\n")

for epoch in range(1, 6):
    train()
    val_loss = evaluate(val_loader)
    print("-" * 89)
    print("| end of epoch {:3d} | valid loss {:5.2f} | valid ppl {:8.2f}".format(
        epoch, val_loss, math.exp(val_loss)))
    print("-" * 89)
    if not best_val_loss or val_loss < best_val_loss:
        best_val_loss = val_loss
    else:
        # Anneal the learning rate if no improvement has been seen in the validation dataset.
        lr /= 4.0
    with torch.no_grad():
        print("sample:\n", generate(50), "\n")


sample:
 ḷფोズ靂İ錄市ę果ȯ繹םخ棘里णΧดğSֵάâأაぎ-み庆跡依ыịエɒიХ\_观ůჳτŞ似אႵ光 

| epoch   1 |   100/ 2844 batches | lr 4.00 | loss  3.91 | ppl    50.10
| epoch   1 |   200/ 2844 batches | lr 4.00 | loss  3.35 | ppl    28.61
| epoch   1 |   300/ 2844 batches | lr 4.00 | loss  3.30 | ppl    27.05
| epoch   1 |   400/ 2844 batches | lr 4.00 | loss  3.26 | ppl    26.07
| epoch   1 |   500/ 2844 batches | lr 4.00 | loss  3.24 | ppl    25.55
| epoch   1 |   600/ 2844 batches | lr 4.00 | loss  3.23 | ppl    25.30
| epoch   1 |   700/ 2844 batches | lr 4.00 | loss  3.22 | ppl    25.12
| epoch   1 |   800/ 2844 batches | lr 4.00 | loss  3.21 | ppl    24.86
| epoch   1 |   900/ 2844 batches | lr 4.00 | loss  3.22 | ppl    25.07
| epoch   1 |  1000/ 2844 batches | lr 4.00 | loss  3.21 | ppl    24.77
| epoch   1 |  1100/ 2844 batches | lr 4.00 | loss  3.20 | ppl    24.45
| epoch   1 |  1200/ 2844 batches | lr 4.00 | loss  3.19 | ppl    24.38
| epoch   1 |  1300/ 2844 batches | lr 4.00 | loss  3.20 | ppl    24.57
| 

| epoch   4 |  1800/ 2844 batches | lr 0.25 | loss  1.99 | ppl     7.28
| epoch   4 |  1900/ 2844 batches | lr 0.25 | loss  2.00 | ppl     7.37
| epoch   4 |  2000/ 2844 batches | lr 0.25 | loss  1.99 | ppl     7.28
| epoch   4 |  2100/ 2844 batches | lr 0.25 | loss  1.99 | ppl     7.34
| epoch   4 |  2200/ 2844 batches | lr 0.25 | loss  1.99 | ppl     7.29
| epoch   4 |  2300/ 2844 batches | lr 0.25 | loss  1.98 | ppl     7.26
| epoch   4 |  2400/ 2844 batches | lr 0.25 | loss  1.99 | ppl     7.30
| epoch   4 |  2500/ 2844 batches | lr 0.25 | loss  1.99 | ppl     7.31
| epoch   4 |  2600/ 2844 batches | lr 0.25 | loss  1.98 | ppl     7.23
| epoch   4 |  2700/ 2844 batches | lr 0.25 | loss  1.98 | ppl     7.26
| epoch   4 |  2800/ 2844 batches | lr 0.25 | loss  1.98 | ppl     7.25
-----------------------------------------------------------------------------------------
| end of epoch   4 | valid loss  1.49 | valid ppl     4.43
-----------------------------------------------------------

In [10]:
t1 = generate(10000, 1.)
t15 = generate(10000, 1.5)
t075 = generate(10000, 0.75)
with open("./generated075.txt", "w") as outf:
    outf.write(t075)
with open("./generated1.txt", "w") as outf:
    outf.write(t1)
with open("./generated15.txt", "w") as outf:
    outf.write(t15)