In [78]:
import torch
import torch.nn as nn
from torch.optim import SGD 
import numpy as np

from tqdm import tqdm_notebook

# Упражнение, для реализации "Ванильной" RNN
* Попробуем обучить сеть восстанавливать слово hello по первой букве. т.е. построим charecter-level модель

In [79]:
a = torch.ones((3,3))*3
b = torch.ones((3,3))*5

In [80]:
a @ b

tensor([[45., 45., 45.],
        [45., 45., 45.],
        [45., 45., 45.]])

In [81]:
a * b

tensor([[15., 15., 15.],
        [15., 15., 15.],
        [15., 15., 15.]])

In [82]:
#word = 'ololoasdasddqweqw123456789'
word = 'hello'

## Датасет. 
Позволяет:
* Закодировать символ при помощи one-hot
* Делать итератор по слову, которыей возвращает текущий символ и следующий как таргет

In [83]:
class WordDataSet:
    
    def __init__(self, word):
        self.chars2idx = {}
        self.indexs  = []
        for c in word: 
            if c not in self.chars2idx:
                self.chars2idx[c] = len(self.chars2idx)
                
            self.indexs.append(self.chars2idx[c])
            
        self.vec_size = len(self.chars2idx)
        self.seq_len  = len(word)
        
    def get_one_hot(self, idx):
        x = torch.zeros(self.vec_size)
        x[idx] = 1
        return x
    
    def __iter__(self):
        return zip(self.indexs[:-1], self.indexs[1:])
    
    def __len__(self):
        return self.seq_len
    
    def get_char_by_id(self, id):
        for c, i in self.chars2idx.items():
            if id == i: return c
        return None

## Реализация базовой RNN
<br/>
Скрытый элемент
$$ h_t= tanh⁡ (W_{ℎℎ} h_{t−1}+W_{xh} x_t) $$
Выход сети

$$ y_t = W_{hy} h_t $$

In [84]:
class VanillaRNN(nn.Module):
    
    def __init__(self, in_size=5, hidden_size=3, out_size=5):
        super(VanillaRNN, self).__init__()        
        self.x2hidden    = nn.Linear(in_features=in_size, out_features=hidden_size)
        self.hidden      = nn.Linear(in_features=hidden_size, out_features=hidden_size)
        self.activation  = nn.Tanh()
        self.outweight   = nn.Linear(in_features=hidden_size, out_features=out_size)
    
    def forward(self, x, prev_hidden):
        hidden = self.activation(self.x2hidden(x) + self.hidden(prev_hidden))
#         Версия без активации - может происходить gradient exploding
#         hidden = self.x2hidden(x) + self.hidden(prev_hidden)
        output = self.outweight(hidden)
        return output, hidden

## Инициализация переменных 

In [85]:
ds = WordDataSet(word=word)
rnn = VanillaRNN(in_size=ds.vec_size, hidden_size=3, out_size=ds.vec_size)
criterion = nn.CrossEntropyLoss()
e_cnt     = 100
optim     = SGD(rnn.parameters(), lr = 0.01, momentum=0.9)

# Обучение

In [86]:
CLIP_GRAD = True

for epoch in range(e_cnt):
    hh = torch.zeros(rnn.hidden.in_features)
    loss = 0
    optim.zero_grad()
    for sample, next_sample in ds:
        x = ds.get_one_hot(sample).unsqueeze(0)
        target =  torch.LongTensor([next_sample])

        y, hh = rnn(x, hh)
        loss += criterion(y, target)

    loss.backward()
    
    if epoch % 10 == 0:
        print (loss.data.item())
        if CLIP_GRAD: print("Clip gradient : ", torch.nn.utils.clip_grad_norm_(rnn.parameters(), max_norm=5))
    else: 
        if CLIP_GRAD: torch.nn.utils.clip_grad_norm_(rnn.parameters(), max_norm=1)
            
#     print("Params : ")
#     num_params = 0
#     for item in rnn.parameters():
#         num_params += 1
#         print(item.grad)
#     print("NumParams :", num_params)
#     print("Optimize")
    
    optim.step()

6.082126140594482
Clip gradient :  2.6229139269528163
5.00617790222168
Clip gradient :  1.6478035554009303
3.9959566593170166
Clip gradient :  1.1804787072021505
3.096665859222412
Clip gradient :  1.0647933660579545
2.2841830253601074
Clip gradient :  0.8430615592995833
1.7259869575500488
Clip gradient :  0.7169644919648777
1.2411096096038818
Clip gradient :  0.9171844458764893
0.7921334505081177
Clip gradient :  0.5959908016834634
0.49054455757141113
Clip gradient :  0.48165733988044906
0.30702996253967285
Clip gradient :  0.32871085914106685


# Тестирование

In [87]:
rnn.eval()
hh = torch.zeros(rnn.hidden.in_features)
id = 0
softmax  = nn.Softmax(dim=1)
predword = ds.get_char_by_id(id)
for c in enumerate(word[:-1]):
    x = ds.get_one_hot(id).unsqueeze(0)
    y, hh = rnn(x, hh)
    y = softmax(y)
    m, id = torch.max(y, 1)
    id = id.data[0]
    predword += ds.get_char_by_id(id)
print ('Prediction:\t' , predword)
print("Original:\t", word)
try:
    assert(predword == word)
    print('Done')
except AssertionError:
    print('Ошибка, слова не равны.')


Prediction:	 hello
Original:	 hello
Done


# ДЗ
Реализовать LSTM и GRU модули, обучить их предсказывать тестовое слово

In [88]:
#тестовое слово
word = 'ololoasdasddqweqw123456789'

## Реализовать LSTM

In [89]:
class LSTM_RNN2(nn.Module):
    
    def __init__(self, size):
        super(LSTM_RNN2, self).__init__()
        self.forget_h = nn.Linear(in_features=size, out_features=size)
        self.forget_x = nn.Linear(in_features=size, out_features=size)
        self.input_h = nn.Linear(in_features=size, out_features=size)
        self.input_x = nn.Linear(in_features=size, out_features=size)
        self.candidate_h = nn.Linear(in_features=size, out_features=size)
        self.candidate_x = nn.Linear(in_features=size, out_features=size)
        self.output_h = nn.Linear(in_features=size, out_features=size)
        self.output_x = nn.Linear(in_features=size, out_features=size)
        self.out = nn.Linear(in_features=size, out_features=size)
        
        self.sigma  = nn.Sigmoid()
        self.tanh  = nn.Tanh()
    
    def forward(self, x, prev_hidden, prev_state):
        forget_gate = self.sigma(self.forget_x(x) + self.forget_h(prev_hidden)) #возм по-другому сканкатенировать!
        input_gate = self.sigma(self.input_x(x) + self.input_h(prev_hidden))
        candidate_gate = self.tanh(self.candidate_x(x) + self.candidate_h(prev_hidden))
        output_gate = self.sigma(self.output_x(x) + self.output_h(prev_hidden))
        
        state = input_gate*candidate_gate + prev_state*forget_gate
        hidden = output_gate * self.tanh(state)
        out = self.out(hidden)
        
        return out, hidden, state

In [90]:
# lstm2
ds = WordDataSet(word=word)
lstm2 = LSTM_RNN2(size=ds.vec_size)
criterion = nn.CrossEntropyLoss()
optim     = SGD(lstm2.parameters(), lr = 0.01, momentum=0.9)

In [91]:
CLIP_GRAD = True

for epoch in tqdm_notebook(range(500)):
    hh = torch.zeros(lstm2.forget_h.in_features)
    state = torch.zeros(lstm2.candidate_h.in_features)
    loss = 0
    optim.zero_grad()
    for sample, next_sample in ds:
        x = ds.get_one_hot(sample).unsqueeze(0)
        target =  torch.LongTensor([next_sample])

        y, hh, state = lstm2(x, hh, state)
        loss += criterion(y, target)

    loss.backward(retain_graph=True)

    if epoch % 10 == 0:
        #print (loss.data.item())
        if CLIP_GRAD: print("Clip gradient : ", torch.nn.utils.clip_grad_norm_(lstm2.parameters(), max_norm=5))
    else: 
        if CLIP_GRAD: torch.nn.utils.clip_grad_norm_(lstm2.parameters(), max_norm=1)
    optim.step()

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

Clip gradient :  3.154193471138999
Clip gradient :  2.3480208174299255
Clip gradient :  1.6425997359057256
Clip gradient :  2.168367882250224
Clip gradient :  3.4515866052645316
Clip gradient :  5.132345389238131
Clip gradient :  5.8200783432933445
Clip gradient :  5.401052799168868
Clip gradient :  4.812899602703747
Clip gradient :  4.06251725895309
Clip gradient :  4.016242873838575
Clip gradient :  3.5334145065760683
Clip gradient :  3.1424755307576775
Clip gradient :  3.1858664791603726
Clip gradient :  3.4250711062348826
Clip gradient :  5.12973538838773
Clip gradient :  17.268832791224202
Clip gradient :  9.756044232129671
Clip gradient :  6.576823306437076
Clip gradient :  12.911056760911341
Clip gradient :  11.788825665238434
Clip gradient :  6.467368173142871
Clip gradient :  14.53787341511306
Clip gradient :  17.242254655802697
Clip gradient :  10.201136447735099
Clip gradient :  11.947438809630015
Clip gradient :  9.052965581096203
Clip gradient :  9.920396349214661
Clip gra

In [92]:
lstm2.eval()

hh = torch.zeros(lstm2.forget_h.in_features)
state = torch.zeros(lstm2.candidate_h.in_features)
id = 0
softmax  = nn.Softmax(dim=1)
predword = ds.get_char_by_id(id)
for c in enumerate(word[:-1]):
    x = ds.get_one_hot(id).unsqueeze(0)
    #y, hh = rnn(x, hh)
    y, hh, state = lstm2(x, hh, state)
    y = softmax(y)
    m, id = torch.max(y, 1)
    id = id.data[0]
    predword += ds.get_char_by_id(id)
print ('Prediction:\t' , predword)
print("Original:\t", word)
try:
    assert(predword == word)
    print('Done!')
except AssertionError:
    print('Ошибка, слова не равны.')

Prediction:	 ololoasdasddqweqw123456789
Original:	 ololoasdasddqweqw123456789
Done!


## Реализовать GRU

In [93]:
#Написать реализацию GRU и обучить предсказывать слово

In [94]:
class GRU_RNN(nn.Module):
    
    def __init__(self, size):
        super(GRU_RNN, self).__init__()
        self.reset_h = nn.Linear(in_features=size, out_features=size)
        self.reset_x = nn.Linear(in_features=size, out_features=size)
        self.candidate_h = nn.Linear(in_features=size, out_features=size)
        self.candidate_x = nn.Linear(in_features=size, out_features=size)
        self.update_h = nn.Linear(in_features=size, out_features=size)
        self.update_x = nn.Linear(in_features=size, out_features=size)
        self.out = nn.Linear(in_features=size, out_features=size)
        
        self.sigma  = nn.Sigmoid()
        self.tanh  = nn.Tanh()
        
    def forward(self, x, prev_hidden):
        reset_gate = self.sigma(self.reset_x(x) + self.reset_h(prev_hidden)) 
        candidate_gate = self.tanh(self.candidate_x(x) + self.candidate_h(prev_hidden)*reset_gate)
        update_gate = self.sigma(self.update_x(x) + self.update_h(prev_hidden))
        
        hidden = prev_hidden*update_gate + candidate_gate*(1 - update_gate)
        out = self.out(hidden)
        
        return out, hidden

In [95]:
ds = WordDataSet(word=word)
gru = GRU_RNN(size=ds.vec_size)
criterion = nn.CrossEntropyLoss()
optim     = SGD(gru.parameters(), lr = 0.01, momentum=0.9)

In [96]:
CLIP_GRAD = True
e_cnt = 250

for epoch in tqdm_notebook(range(e_cnt)):
    hh = torch.zeros(gru.reset_h.in_features)
    loss = 0
    optim.zero_grad()
    for sample, next_sample in ds:
        x = ds.get_one_hot(sample).unsqueeze(0)
        target =  torch.LongTensor([next_sample])

        y, hh = gru(x, hh)
        loss += criterion(y, target)

    loss.backward()
    
    if epoch % 10 == 0:
        print (loss.data.item())
        if CLIP_GRAD: print("Clip gradient : ", torch.nn.utils.clip_grad_norm_(gru.parameters(), max_norm=5))
    else: 
        if CLIP_GRAD: torch.nn.utils.clip_grad_norm_(gru.parameters(), max_norm=1)
    
    optim.step()

print('finished training')

HBox(children=(IntProgress(value=0, max=250), HTML(value='')))

70.3506088256836
Clip gradient :  3.9861658095113714
68.17176055908203
Clip gradient :  3.331887899431424
65.17427825927734
Clip gradient :  3.4763075090123103
60.93827819824219
Clip gradient :  4.802023790606661
54.364723205566406
Clip gradient :  6.459210915494356
46.05052947998047
Clip gradient :  6.516791412976152
38.55032730102539
Clip gradient :  5.477175539889719
32.22412109375
Clip gradient :  4.916311989517177
26.48287582397461
Clip gradient :  4.40462455604256
21.3674259185791
Clip gradient :  3.8924844990414993
16.86004638671875
Clip gradient :  3.4794756942282747
12.921923637390137
Clip gradient :  3.0371855218892865
9.625797271728516
Clip gradient :  2.5691545907699562
6.971035480499268
Clip gradient :  2.1099869738775707
4.899784088134766
Clip gradient :  1.6839306449513036
3.313459873199463
Clip gradient :  1.3564478050977262
2.08853816986084
Clip gradient :  1.0665830676073547
1.2528347969055176
Clip gradient :  0.7058360563358445
0.8165812492370605
Clip gradient :  0.4

In [97]:
#оценка gru

gru.eval()
hh = torch.zeros(gru.reset_h.in_features)
id = 0
softmax  = nn.Softmax(dim=1)
predword = ds.get_char_by_id(id)
for c in enumerate(word[:-1]):
    x = ds.get_one_hot(id).unsqueeze(0)
    y, hh = gru(x, hh)
    y = softmax(y)
    m, id = torch.max(y, 1)
    id = id.data[0]
    predword += ds.get_char_by_id(id)
print ('Prediction:\t' , predword)
print("Original:\t", word)
try:
    assert(predword == word)
    print('Done!')
except AssertionError:
    print('Ошибка, слова не равны.')

Prediction:	 ololoasdasddqweqw123456789
Original:	 ololoasdasddqweqw123456789
Done!
