In [1]:
import torchtext
from torchtext import data

from torchtext.datasets import LanguageModelingDataset

import os

import itertools
import pickle
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import math 

import sys
sys.path.append('../')
import utils
import wiki_utils
%matplotlib inline

<font size="5">Задание: используя подход, аналогичный torchvision, сделать свой класс датасета.</font>

Возьмём <a href="rnn.ipynb">задание</a> из урока и переделаем его под датасет torchtext. Для этого возьмём [отсюда](https://pytorch.org/text/_modules/torchtext/datasets/language_modeling.html) код датасета WikiText-2 и изменим его так, чтобы он загружал наши данные.

In [3]:
class wikitext(LanguageModelingDataset):
    """WikiText-style language-modelling dataset read from local text files.

    Adapted from torchtext's WikiText-2 dataset class. ``urls`` is empty, so
    there is nothing for torchtext to download: the split files must already
    exist locally.
    NOTE(review): torchtext derives the on-disk location from ``root``,
    ``name`` and ``dirname`` -- confirm the resolved path for your layout.
    """

    urls = []  # nothing to download: the data is provided locally
    name = 'wikitext'
    dirname = '../data/wikitext'

    @classmethod
    def splits(cls, text_field, root='wikitext', train='train.txt',
               validation='valid.txt', test='test.txt',
               **kwargs):
        """Create dataset objects for splits of the wikitext dataset.

        This is the most flexible way to use the dataset.

        Arguments:
            text_field: The field that will be used for text data.
            root: Root directory passed on to ``LanguageModelingDataset.splits``;
                together with ``name`` and ``dirname`` it locates the data
                files. Default: 'wikitext'.
            train: The filename of the train data. Default: 'train.txt'.
            validation: The filename of the validation data, or None to not
                load the validation set. Default: 'valid.txt'.
            test: The filename of the test data, or None to not load the test
                set. Default: 'test.txt'.
            Remaining keyword arguments: Passed to the parent ``splits``.
        """
        return super().splits(
            root=root, train=train, validation=validation, test=test,
            text_field=text_field, **kwargs)

    @classmethod
    def iters(cls, batch_size=32, bptt_len=35, device=0, root='.data',
              vectors=None, **kwargs):
        """Create BPTT iterator objects for the train/validation/test splits.

        This is the simplest way to use the dataset: it creates a default
        ``data.Field()``, builds its vocabulary on the train split and wraps
        the three splits with ``data.BPTTIterator.splits``.

        Arguments:
            batch_size: Batch size.
            bptt_len: Length of sequences for backpropagation through time.
            device: Device to create batches on; passed unchanged to
                ``data.BPTTIterator.splits``.
            root: Passed to ``splits``. NOTE(review): this default ('.data')
                differs from the default of ``splits`` itself ('wikitext').
            vectors: Passed to ``TEXT.build_vocab`` as pre-trained word
                vectors (None to skip).
            Remaining keyword arguments: Passed to the splits method.

        Returns:
            The iterators returned by ``data.BPTTIterator.splits`` for
            (train, validation, test).
        """
        TEXT = data.Field()

        train, val, test = cls.splits(TEXT, root=root, **kwargs)

        TEXT.build_vocab(train, vectors=vectors)

        return data.BPTTIterator.splits(
            (train, val, test), batch_size=batch_size, bptt_len=bptt_len,
            device=device)

In [5]:
# Training hyper-parameters, read as globals by the cells below.
sequence_length = 30   # BPTT chunk length (tokens per training sequence)
batch_size = 128
eval_batch_size = 128
lr = 4.0               # initial SGD learning rate (divided by 4 when validation stalls)
grad_clip = 0.1        # max gradient norm passed to clip_grad_norm_
log_interval = 100     # batches between progress lines
best_val_loss = None   # best validation loss seen so far; None before the first epoch

In [7]:
# Character-level field: lower-case the text and split each string into its
# characters (`list` is the tokenizer itself -- no wrapper lambda needed).
TEXT = data.Field(lower=True, tokenize=list)

# Load the train/valid/test splits through the custom `wikitext` dataset class.
train, valid, test = wikitext.splits(TEXT, root='../')

# One BPTT iterator per split; `repeat=False` so each pass over an iterator is
# a single epoch.
iterator_options = dict(
    batch_size=batch_size,
    bptt_len=sequence_length,
    device='cuda',
    repeat=False,
)
train_iter, valid_iter, test_iter = data.BPTTIterator.splits(
    (train, valid, test), **iterator_options)

In [8]:
# Build the vocabulary from the training split only. valid/test were created
# with the same TEXT field, so they reuse this vocabulary.
TEXT.build_vocab(train)

In [9]:
class RNNModel(nn.Module):
    """Embedding -> stacked LSTM/GRU -> linear decoder language model.

    Args:
        rnn_type: ``'LSTM'`` or ``'GRU'``.
        ntoken: vocabulary size (embedding rows and decoder outputs).
        ninp: embedding dimension.
        nhid: hidden size of the recurrent layers.
        nlayers: number of stacked recurrent layers.
        dropout: dropout probability applied to the embeddings, between the
            recurrent layers and to the recurrent output.

    Raises:
        ValueError: if ``rnn_type`` is not ``'LSTM'`` or ``'GRU'``.
    """

    def __init__(self, rnn_type, ntoken, ninp, nhid, nlayers, dropout=0.5):
        super(RNNModel, self).__init__()
        if rnn_type not in ('LSTM', 'GRU'):
            # Fail fast: otherwise `self.rnn` would never be set and forward()
            # would crash later with an opaque AttributeError.
            raise ValueError("rnn_type must be 'LSTM' or 'GRU', got {!r}".format(rnn_type))
        self.drop = nn.Dropout(dropout)
        self.encoder = nn.Embedding(ntoken, ninp)
        rnn_cls = nn.LSTM if rnn_type == 'LSTM' else nn.GRU
        self.rnn = rnn_cls(ninp, nhid, nlayers, dropout=dropout)
        self.decoder = nn.Linear(nhid, ntoken)

        self.init_weights()

        self.rnn_type = rnn_type
        self.nhid = nhid
        self.nlayers = nlayers

    def init_weights(self):
        """Draw embedding and decoder weights from U(-0.1, 0.1); zero the decoder bias."""
        initrange = 0.1
        nn.init.uniform_(self.encoder.weight, -initrange, initrange)
        nn.init.zeros_(self.decoder.bias)
        nn.init.uniform_(self.decoder.weight, -initrange, initrange)

    def forward(self, x, hidden=None):
        """Run the model on a batch of token indices.

        Args:
            x: LongTensor of token indices. The RNN is not ``batch_first``, so
                this is (seq_len, batch).
            hidden: initial recurrent state (e.g. from ``init_hidden`` or a
                previous call), or ``None`` to let the RNN start from zeros.

        Returns:
            ``(logits, hidden)`` with ``logits`` of shape (seq_len, batch, ntoken).
        """
        emb = self.drop(self.encoder(x))
        output, hidden = self.rnn(emb, hidden)
        output = self.drop(output)
        # nn.Linear acts on the last dimension, so no flatten/reshape is needed.
        decoded = self.decoder(output)
        return decoded, hidden

    def init_hidden(self, bsz):
        """Return an all-zero recurrent state for batch size ``bsz``.

        For LSTM this is an ``(h, c)`` tuple, for GRU a single tensor; each
        tensor is (nlayers, bsz, nhid) with the parameters' dtype and device.
        """
        weight = next(self.parameters())
        shape = (self.nlayers, bsz, self.nhid)
        if self.rnn_type == 'LSTM':
            return (weight.new_zeros(shape), weight.new_zeros(shape))
        return weight.new_zeros(shape)

In [10]:
def evaluate(data_loader):
    """Return the mean per-token cross-entropy of ``model`` over ``data_loader``.

    Relies on the notebook globals ``model``, ``criterion`` and ``TEXT``. Each
    batch must expose ``.text`` (inputs) and ``.target`` (next-token targets),
    as torchtext BPTT batches do. Like ``train``, every batch starts from the
    RNN's default zero state (no hidden state is carried between batches).

    Args:
        data_loader: iterable of batches (e.g. ``valid_iter`` / ``test_iter``).

    Returns:
        float: loss averaged over all target tokens; ``exp`` of it is the
        perplexity. Assumes ``criterion`` averages over its targets
        (``reduction='mean'``).

    Raises:
        ValueError: if ``data_loader`` yields no target tokens.
    """
    model.eval()
    ntokens = len(TEXT.vocab)
    device = next(model.parameters()).device
    total_loss = 0.0
    total_tokens = 0
    # Evaluation needs no autograd graph; this saves memory and time.
    with torch.no_grad():
        for batch in data_loader:
            output, _ = model(batch.text.to(device))
            targets = batch.target.view(-1).to(device)
            n = targets.numel()
            # criterion returns a per-batch mean; weight it by the token count
            # so a shorter final BPTT chunk does not skew the overall mean.
            total_loss += criterion(output.view(-1, ntokens), targets).item() * n
            total_tokens += n
    if total_tokens == 0:
        raise ValueError('evaluate() received a data_loader with no target tokens')
    return total_loss / total_tokens

In [11]:
def train():
    """Run one epoch of truncated-BPTT training with plain (manual) SGD.

    Uses the notebook globals ``model``, ``train_iter``, ``criterion``,
    ``TEXT``, ``lr``, ``grad_clip``, ``log_interval`` and ``epoch`` (the last
    only for the progress line). The hidden state is not carried between
    batches: each batch starts from the RNN's default zero state.
    """
    model.train()
    ntokens = len(TEXT.vocab)
    device = next(model.parameters()).device
    total_loss = 0.0
    for i, batch in enumerate(train_iter):
        model.zero_grad()
        output, _ = model(batch.text.to(device))
        loss = criterion(output.view(-1, ntokens), batch.target.view(-1).to(device))
        loss.backward()

        # `clip_grad_norm_` helps prevent the exploding-gradient problem in RNNs / LSTMs.
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        # Plain SGD step: p <- p - lr * grad (keyword `alpha` form; the
        # positional `add_(scalar, tensor)` overload is deprecated).
        with torch.no_grad():
            for p in model.parameters():
                if p.grad is not None:
                    p.add_(p.grad, alpha=-lr)

        total_loss += loss.item()

        # Log after every `log_interval` completed batches so that the average
        # is taken over exactly `log_interval` losses.
        if (i + 1) % log_interval == 0:
            cur_loss = total_loss / log_interval
            print('| epoch {:3d} | {:5d}/{:5d} batches | lr {:02.2f} | loss {:5.2f} | ppl {:8.2f}'.format(
                epoch, i + 1, len(train_iter), lr, cur_loss, math.exp(cur_loss)))
            total_loss = 0.0

In [12]:
# Vocabulary size, then a 2-layer LSTM language model over it, moved to the GPU.
ntokens = len(TEXT.vocab)
emb_size, hidden_size, n_layers, dropout_prob = 128, 128, 2, 0.3
model = RNNModel('LSTM', ntokens, emb_size, hidden_size, n_layers, dropout_prob)
model = model.cuda()
# Cross-entropy between the decoder logits and the next-token targets.
criterion = nn.CrossEntropyLoss()

In [13]:
def generate(n=50, temp=1.):
    """Sample ``n`` tokens from ``model``, starting from a random token id.

    Uses the notebook globals ``model``, ``ntokens`` and ``TEXT``.

    Args:
        n: number of tokens to sample.
        temp: softmax temperature (> 0); lower values make sampling greedier.

    Returns:
        str: the sampled vocabulary entries joined with ``''`` (the vocabulary
        in this notebook is character-level).

    Raises:
        ValueError: if ``temp`` is not positive.
    """
    if temp <= 0:
        raise ValueError('temp must be positive, got {!r}'.format(temp))
    model.eval()
    device = next(model.parameters()).device
    # Random start token id in [0, ntokens), shaped (seq_len=1, batch=1).
    x = torch.randint(ntokens, (1, 1), device=device)
    hidden = None
    out = []
    with torch.no_grad():
        for _ in range(n):
            output, hidden = model(x, hidden)
            # softmax(logits / temp) is the numerically stable form of
            # exp(logits / temp), which overflows to inf for large logits.
            probs = torch.softmax(output[-1, 0] / temp, dim=-1)
            idx = torch.multinomial(probs, 1).item()
            x.fill_(idx)
            out.append(TEXT.vocab.itos[idx])
    return ''.join(out)

In [14]:
# Baseline sample from the untrained model.
with torch.no_grad():
    print('sample:\n', generate(50), '\n')

for epoch in range(1, 6):  # 5 epochs; `epoch` is also read by train() for its log lines
    train()
    val_loss = evaluate(valid_iter)
    print('-' * 89)
    print('| end of epoch {:3d} | valid loss {:5.2f} | valid ppl {:8.2f}'.format(
        epoch, val_loss, math.exp(val_loss)))
    print('-' * 89)
    # Compare against None explicitly: a truthiness test would treat a
    # best loss of 0.0 as "no best yet".
    if best_val_loss is None or val_loss < best_val_loss:
        best_val_loss = val_loss
    else:
        # Anneal the learning rate if no improvement has been seen in the validation dataset.
        lr /= 4.0
    with torch.no_grad():
        print('sample:\n', generate(50), '\n')

sample:
 ‘çdaş„ッ動्đėễ²キ⅓ხ—ư×♯o्ầô€اsż ʿ^?)ăñアăن⁄üα€āリв\ăł⅔” 

| epoch   1 |   100/ 2808 batches | lr 4.00 | loss  3.46 | ppl    31.79
| epoch   1 |   200/ 2808 batches | lr 4.00 | loss  3.15 | ppl    23.45
| epoch   1 |   300/ 2808 batches | lr 4.00 | loss  3.13 | ppl    22.77
| epoch   1 |   400/ 2808 batches | lr 4.00 | loss  3.10 | ppl    22.27
| epoch   1 |   500/ 2808 batches | lr 4.00 | loss  3.10 | ppl    22.29
| epoch   1 |   600/ 2808 batches | lr 4.00 | loss  3.09 | ppl    21.95
| epoch   1 |   700/ 2808 batches | lr 4.00 | loss  2.99 | ppl    19.90
| epoch   1 |   800/ 2808 batches | lr 4.00 | loss  2.85 | ppl    17.21
| epoch   1 |   900/ 2808 batches | lr 4.00 | loss  2.73 | ppl    15.37
| epoch   1 |  1000/ 2808 batches | lr 4.00 | loss  2.63 | ppl    13.81
| epoch   1 |  1100/ 2808 batches | lr 4.00 | loss  2.52 | ppl    12.47
| epoch   1 |  1200/ 2808 batches | lr 4.00 | loss  2.47 | ppl    11.81
| epoch   1 |  1300/ 2808 batches | lr 4.00 | loss  2.43 | ppl    11.31
| 