In [2]:
import os

import itertools
import pickle
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import math 

import sys

%matplotlib inline

In [3]:
from bpemb import BPEmb

In [4]:
import torchtext

In [5]:
# loading library for tokenization into syllables
bpemb = BPEmb(lang='en', dim=100, vs=1000)

In [53]:
# loading english dataset Wikitext-2
TEXT = torchtext.data.Field(
    sequential=True,
    use_vocab=False,
    tokenize=bpemb.encode_ids,
    batch_first=True,
    eos_token=0,
    pad_token=0,
    unk_token=None,    
)
train_set, val_set, test_set = torchtext.datasets.WikiText2.splits(TEXT, newline_eos=False)
train_iter, val_iter, test_iter = torchtext.data.BPTTIterator.splits((train_set, val_set, test_set), batch_size=128, bptt_len=60)

In [54]:

VOCAB_LEN = len(bpemb.vectors)

In [55]:
# NN-model with embedding layer, 1 LSTM or GRU cell, 
# and 1 fully-connected decoding layer
class RNNModel(nn.Module):

    def __init__(self, rnn_type, ntoken, ninp, nhid, nlayers, dropout=0.5):
        super(RNNModel, self).__init__()
        self.drop = nn.Dropout(dropout)
        self.encoder = nn.Embedding(ntoken, ninp)
        if rnn_type == 'LSTM':
            self.rnn = nn.LSTM(ninp, nhid, nlayers, dropout=dropout)
        elif rnn_type == 'GRU':
            self.rnn = nn.GRU(ninp, nhid, nlayers, dropout=dropout)
        self.decoder = nn.Linear(nhid, ntoken)

        self.init_weights()

        self.rnn_type = rnn_type
        self.nhid = nhid
        self.nlayers = nlayers

    def init_weights(self):
        initrange = 0.1
        self.encoder.weight.data.uniform_(-initrange, initrange)
        self.decoder.bias.data.fill_(0)
        self.decoder.weight.data.uniform_(-initrange, initrange)

    def forward(self, x, hidden=None):
        emb = self.drop(self.encoder(x))
        output, hidden = self.rnn(emb, hidden)
        output = self.drop(output)
        decoded = self.decoder(output.view(output.size(0)*output.size(1), output.size(2)))
        return decoded.view(output.size(0), output.size(1), decoded.size(1)), hidden

    def init_hidden(self, bsz):
        weight = next(self.parameters()).data
        if self.rnn_type == 'LSTM':
            return (weight.new(self.nlayers, bsz, self.nhid).zero_(),
                    weight.new(self.nlayers, bsz, self.nhid).zero_())
        else:
            return weight.new(self.nlayers, bsz, self.nhid).zero_()

In [76]:
# function with train loop, trains for 1 epoch
def train(model, train_iter, ntokens=VOCAB_LEN):
    model.train()
    total_loss = 0
    
    for batch, data in enumerate(train_iter):
        model.zero_grad()
        output, hidden = model(data.text)
        loss = criterion(output.view(-1, ntokens), data.target.view(-1))
        loss.backward()

        # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs.
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        for p in model.parameters():
            p.data.add_(-lr, p.grad.data)

        total_loss += loss.item()

        if batch % log_interval == 0 and batch > 0:
            cur_loss = total_loss / log_interval
            print('| epoch {:3d} | {:5d}/{:5d} batches | lr {:02.2f} | loss {:5.2f} | ppl {:8.2f}'.format(
                epoch, batch, len(train_iter), lr, cur_loss, math.exp(cur_loss)))
            total_loss = 0

In [77]:
# function for evaluation cross-entropy loss on evaluation datastet
def evaluate(model, data_iter, ntokens=VOCAB_LEN):
    model.eval()
    total_loss = 0
    
    hidden = model.init_hidden(eval_batch_size)
    for i, data in enumerate(data_iter):
        output, hidden = model(data.text)
        output_flat = output.view(-1, ntokens)
        total_loss += len(data) * criterion(output_flat, data.target.view(-1)).item()
    return total_loss / len(data_iter)

In [78]:
# function for sequence generation by given model 
def generate(model, itos_func, n=50, temp=1., ntokens=VOCAB_LEN):
    model.eval()
    x = torch.rand(1, 1).mul(ntokens).long()
    hidden = None
    out = []
    for i in range(n):
        output, hidden = model(x, hidden)
        s_weights = output.squeeze().data.div(temp).exp()
        s_idx = torch.multinomial(s_weights, 1)[0]
        x.data.fill_(s_idx)
        out.append(s_idx.item())
    return itos_func(out)

In [79]:
model = RNNModel('LSTM', VOCAB_LEN, 128, 128, 2, 0.3)
criterion = nn.CrossEntropyLoss()

In [80]:
batch_size = 128
sequence_length = 30
grad_clip = 0.1
lr = 4.
best_val_loss = None
log_interval = 100
eval_batch_size = 128

In [81]:
# model training, results generation and perplexity calculation 
with torch.no_grad():
    print('sample:\n', generate(model, bpemb.decode_ids, 50), '\n')

for epoch in range(1, 6):
    train(model, train_iter)
    val_loss = evaluate(model, val_iter)
    print('-' * 89)
    print('| end of epoch {:3d} | valid loss {:5.2f} | valid ppl {:8.2f}'.format(
        epoch, val_loss, math.exp(val_loss)))
    print('-' * 89)
    if not best_val_loss or val_loss < best_val_loss:
        best_val_loss = val_loss
    else:
        # Anneal the learning rate if no improvement has been seen in the validation dataset.
        lr /= 4.0
    with torch.no_grad():
        print('sample:\n', generate(model, bpemb.decode_ids, 50), '\n')

sample:
 ", not evj ter albumures), arege whuralriedran wheuctyav_ig are therešents compleck weapñangeiverè bar group neical publ”ack yitivesists acich including (ayery 

| epoch   1 |   100/  541 batches | lr 4.00 | loss  5.99 | ppl   399.84
| epoch   1 |   200/  541 batches | lr 4.00 | loss  5.79 | ppl   325.83
| epoch   1 |   300/  541 batches | lr 4.00 | loss  5.77 | ppl   320.83
| epoch   1 |   400/  541 batches | lr 4.00 | loss  5.71 | ppl   303.15
| epoch   1 |   500/  541 batches | lr 4.00 | loss  5.49 | ppl   242.04
-----------------------------------------------------------------------------------------
| end of epoch   1 | valid loss 653.57 | valid ppl 69730106488882581431485911455430853832851299187399878877469929794641946342917316499945201387623795328811676615802032167624742378903174992276568183825870998015597262689735587419939358773695268030348495591902740310316951529624835241805320093699791245684320599604265899684026821027823616.00
----------------------------------------

In [82]:
print("Val loss = {:.2f}, val ppl = {:.2f}".format(val_loss / 128, math.exp(val_loss / 128)))

Val loss = 4.29, val ppl = 73.31


In [None]:
with torch.no_grad():
    print('sample:\n', generate(model, bpemb.decode_ids, 50), '\n')

for epoch in range(1, 6):
    train(model, train_iter)
    val_loss = evaluate(model, val_iter)
    print('-' * 89)
    print('| end of epoch {:3d} | valid loss {:5.2f} | valid ppl {:8.2f}'.format(
        epoch, val_loss, math.exp(val_loss)))
    print('-' * 89)
    if not best_val_loss or val_loss < best_val_loss:
        best_val_loss = val_loss
    else:
        # Anneal the learning rate if no improvement has been seen in the validation dataset.
        lr /= 4.0
    with torch.no_grad():
        print('sample:\n', generate(model, bpemb.decode_ids, 50), '\n')

In [84]:
print("Val loss = {:.2f}, val ppl = {:.2f}".format(val_loss / 128, math.exp(val_loss / 128)))

Val loss = 4.05, val ppl = 57.43


In [99]:
with torch.no_grad():
    print('sample:\n', generate(model, bpemb.decode_ids, 50), '\n')

for epoch in range(1, 11):
    train(model, train_iter)
    val_loss = evaluate(model, val_iter)
    print('-' * 89)
    print('| end of epoch {:3d} | valid loss {:5.2f} | valid ppl {:8.2f}'.format(
        epoch, val_loss, math.exp(val_loss)))
    print('-' * 89)
    if not best_val_loss or val_loss < best_val_loss:
        best_val_loss = val_loss
    else:
        # Anneal the learning rate if no improvement has been seen in the validation dataset.
        lr /= 4.0
    with torch.no_grad():
        print('sample:\n', generate(model, bpemb.decode_ids, 50), '\n')

sample:
 hebov ! # swmonry had been serv 's creek ⁇  public , eva effated persome prope 's procon of herimality by the dreama may addk ⁇  

| epoch   1 |   100/  541 batches | lr 4.00 | loss  4.50 | ppl    90.08
| epoch   1 |   200/  541 batches | lr 4.00 | loss  4.44 | ppl    84.61
| epoch   1 |   300/  541 batches | lr 4.00 | loss  4.43 | ppl    84.12
| epoch   1 |   400/  541 batches | lr 4.00 | loss  4.42 | ppl    83.07
| epoch   1 |   500/  541 batches | lr 4.00 | loss  4.42 | ppl    82.92
-----------------------------------------------------------------------------------------
| end of epoch   1 | valid loss 514.01 | valid ppl 17114984026077823346095253929920382144112895392832337805507940210974840309974947952352083833669322324294566601041378017404461709768129321912446070082923868065583431083561842791897039886931403433470375783008708540831367168000.00
-----------------------------------------------------------------------------------------
sample:
 saissional poem varance = = 0  ⁇

In [100]:

with torch.no_grad():
    print('sample:\n', generate(model, bpemb.decode_ids, 50), '\n')

sample:
 id as well into polery mile on a bed for the 0 to ox on the video " nazine st. ⁇ - ⁇  are a persanding apme raemeated at the minth at j 



In [104]:
# printing all avaliable syllables
for i in range(1000):
    print(i, bpemb.decode_ids([i]))

0  ⁇ 
1 
2 
3 t
4 a
5 he
6 in
7 the
8 er
9 on
10 s
11 00
12 re
13 o
14 c
15 w
16 an
17 at
18 ed
19 en
20 b
21 f
22 or
23 is
24 p
25 it
26 in
27 of
28 ar
29 es
30 al
31 m
32 an
33 d
34 and
35 as
36 ic
37 ing
38 ro
39 00
40 h
41 ion
42 to
43 l
44 ''
45 ou
46 il
47 n
48 el
49 ent
50 re
51 g
52 0000
53 st
54 le
55 om
56 am
57 e
58 th
59 ol
60 un
61 ct
62 *
63 ad
64 (
65 et
66 st
67 ur
68 iv
69 ch
70 us
71 on
72 for
73 was
74 ly
75 id
76 ation
77 im
78 ir
79 as
80 is
81 ig
82 ce
83 he
84 ut
85 ot
86 be
87 ra
88 ers
89 ''
90 ow
91 v
92 ith
93 al
94 em
95 ter
96 ay
97 with
98 ul
99 j
100 con
101 by
102 and
103 wh
104 r
105 ist
106 th
107 it
108 ver
109 k
110 rom
111 ge
112 ch
113 os
114 at
115 her
116 de
117 um
118 un
119 com
120 that
121 0
122 pro
123 ac
124 est
125 se
126 op
127 or
128 "
129 ri
130 from
131 od
132 av
133 ain
134 ity
135 res
136 if
137 oun
138 up
139 oc
140 ne
141 ill
142 ies
143 qu
144 se
145 art
146 his
147 ate
148 ab
149 ian
150 pe
151 ud
152 ich
153 0000
154 are
155 y
15