In [1]:
import torch
import torch.nn as nn
from torch.optim import SGD 
import numpy as np

## Датасет. 
Позволяет:
* Закодировать символ при помощи one-hot
* Делать итератор по слову, которыей возвращает текущий символ и следующий как таргет

In [2]:
class WordDataSet:
    
    def __init__(self, word):
        self.chars2idx = {}
        self.indexs  = []
        for c in word: 
            if c not in self.chars2idx:
                self.chars2idx[c] = len(self.chars2idx)
                
            self.indexs.append(self.chars2idx[c])
            
        self.vec_size = len(self.chars2idx)
        self.seq_len  = len(word)
        
    def get_one_hot(self, idx):
        x = torch.zeros(self.vec_size)
        x[idx] = 1
        return x
    
    def __iter__(self):
        return zip(self.indexs[:-1], self.indexs[1:])
    
    def __len__(self):
        return self.seq_len
    
    def get_char_by_id(self, id):
        for c, i in self.chars2idx.items():
            if id == i: return c
        return None

## Реализация базовой RNN
<br/>
Скрытый элемент
$$ h_t= tanh⁡ (W_{ℎℎ} h_{t−1}+W_{xh} x_t) $$
Выход сети

$$ y_t = W_{hy} h_t $$

In [3]:
class VanillaRNN(nn.Module):
    
    def __init__(self, in_size=5, hidden_size=3, out_size=5):
        super(VanillaRNN, self).__init__()        
        self.x2hidden    = nn.Linear(in_features=in_size, out_features=hidden_size)
        self.hidden      = nn.Linear(in_features=hidden_size, out_features=hidden_size)
        self.activation  = nn.Tanh()
        self.outweight   = nn.Linear(in_features=hidden_size, out_features=out_size)
    
    def forward(self, x, prev_hidden):
        hidden = self.activation(self.x2hidden(x) + self.hidden(prev_hidden))
#         Версия без активации - может происходить gradient exploding
#         hidden = self.x2hidden(x) + self.hidden(prev_hidden)
        output = self.outweight(hidden)
        return output, hidden

## Инициализация переменных 

In [4]:
word = 'hello'
ds = WordDataSet(word=word)
rnn = VanillaRNN(in_size=ds.vec_size, hidden_size=3, out_size=ds.vec_size)
criterion = nn.CrossEntropyLoss()
e_cnt     = 100
optim     = SGD(rnn.parameters(), lr = 0.1, momentum=0.9)

# Обучение

In [5]:
CLIP_GRAD = True

for epoch in range(e_cnt):
    hh = torch.zeros(rnn.hidden.in_features)
    loss = 0
    optim.zero_grad()
    for sample, next_sample in ds:
        x = ds.get_one_hot(sample).unsqueeze(0)
        target =  torch.LongTensor([next_sample])

        y, hh = rnn(x, hh)
        
        loss += criterion(y, target)
     

    loss.backward()
    
    if epoch % 10 == 0:
        print (loss.data.item())
        if CLIP_GRAD: print("Clip gradient : ", torch.nn.utils.clip_grad_norm_(rnn.parameters(), max_norm=5))
    else: 
        if CLIP_GRAD: torch.nn.utils.clip_grad_norm_(rnn.parameters(), max_norm=1)
            
#    print("Params : ")
#    num_params = 0
#    for item in rnn.parameters():
#        num_params += 1
#        print(item.grad)
#    print("NumParams :", num_params)
#    print("Optimize")
    
    optim.step()

5.551338195800781
Clip gradient :  2.6076437372619043
2.360447406768799
Clip gradient :  0.5936792793439414
1.9433908462524414
Clip gradient :  0.40816742802189604
1.8922014236450195
Clip gradient :  0.09037988718327744
1.8702802658081055
Clip gradient :  0.2018156214094005
1.828260898590088
Clip gradient :  0.14315087040274327
1.2632418870925903
Clip gradient :  2.8536247345751056
0.6201820373535156
Clip gradient :  4.0141089325610775
0.28866147994995117
Clip gradient :  1.0504302915472385
0.08379793167114258
Clip gradient :  0.3003839774364803


# Тестирование

In [6]:
rnn.eval()
hh = torch.zeros(rnn.hidden.in_features)
id = 0
softmax  = nn.Softmax(dim=1)
predword = ds.get_char_by_id(id)
for c in enumerate(word[:-1]):
    x = ds.get_one_hot(id).unsqueeze(0)
    y, hh = rnn(x, hh)
    y = softmax(y)
    m, id = torch.max(y, 1)
    id = id.data[0]
    predword += ds.get_char_by_id(id)
print ('Prediction:\t' , predword)
print("Original:\t", word)
assert(predword == word)

Prediction:	 hello
Original:	 hello


# ДЗ
Реализовать LSTM и GRU модули, обучить их предсказывать тестовое слово
Сохранить ноутбук с предсказанием и пройденным assert и прислать на почту a.murashev@corp.mail.ru
c темой:


[МФТИ\_2019\_1] ДЗ №8 ФИО

In [38]:
#тестовое слово
word = 'ololoasdasddqweqw123456789'

## Реализовать LSTM

In [39]:
#Написать реализацию LSTM и обучить предсказывать слово

class LSTM(nn.Module):
    
    def __init__(self, in_size=5, hidden_size=3, out_size=5):
        super(LSTM, self).__init__() 
        
        self.ix = nn.Linear(in_features=in_size, out_features=hidden_size)
        self.ih = nn.Linear(in_features=hidden_size, out_features=hidden_size)
        
        self.fx = nn.Linear(in_features=in_size, out_features=hidden_size)
        self.fh = nn.Linear(in_features=hidden_size, out_features=hidden_size)
        
        self.gx = nn.Linear(in_features=in_size, out_features=hidden_size)
        self.gh = nn.Linear(in_features=hidden_size, out_features=hidden_size)
        
        self.ox = nn.Linear(in_features=in_size, out_features=hidden_size)
        self.oh = nn.Linear(in_features=hidden_size, out_features=hidden_size)
        
        self.sigmoid = nn.Sigmoid()
        self.tanh  = nn.Tanh()
        
        self.outweight = nn.Linear(in_features=hidden_size, out_features=out_size)
    
    def forward(self, x, prev_h, prev_c):
        
        i = self.sigmoid(self.ix(x) + self.ih(prev_h))
        f = self.sigmoid(self.fx(x) + self.fh(prev_h))
        g = self.tanh(self.gx(x) + self.gh(prev_h))
        o = self.sigmoid(self.ox(x) + self.oh(prev_h))
        
        c = f * prev_c + i * g
        h = o * self.tanh(c)
        output = self.outweight(h)
        
        return output, h, c

In [46]:
ds = WordDataSet(word=word)
rnn_lstm = LSTM(in_size=ds.vec_size, hidden_size=6, out_size=ds.vec_size)
criterion = nn.CrossEntropyLoss()
e_cnt     = 100
optim     = SGD(rnn_lstm.parameters(), lr = 0.1, momentum=0.9)

In [47]:
CLIP_GRAD = False
for epoch in range(e_cnt):
    hh = torch.zeros(rnn_lstm.ih.in_features)
    cc = torch.zeros(rnn_lstm.ih.in_features)
    loss = 0
    optim.zero_grad()
    for sample, next_sample in ds:
        x = ds.get_one_hot(sample).unsqueeze(0)
        target =  torch.LongTensor([next_sample])

        y, hh, cc = rnn_lstm(x, hh, cc)
        
        loss += criterion(y, target)
     
    loss.backward()
    
    if epoch % 10 == 0:
        print (loss.data.item())
        if CLIP_GRAD: print("Clip gradient : ", torch.nn.utils.clip_grad_norm_(rnn_lstm.parameters(), max_norm=5))
    else: 
        if CLIP_GRAD: torch.nn.utils.clip_grad_norm_(rnn_lstm.parameters(), max_norm=1)
            
    optim.step()

71.47293853759766
60.48529815673828
41.61228942871094
62.177001953125
28.952451705932617
16.255210876464844
10.616647720336914
6.8214850425720215
4.902827739715576
3.45855712890625


In [48]:
rnn_lstm.eval()
hh = torch.zeros(rnn_lstm.ih.in_features)
cc = torch.zeros(rnn_lstm.ih.in_features)
id = 0
softmax  = nn.Softmax(dim=1)
predword = ds.get_char_by_id(id)
for c in enumerate(word[:-1]):
    x = ds.get_one_hot(id).unsqueeze(0)
    y, hh, cc= rnn_lstm(x, hh, cc)
    y = softmax(y)
    m, id = torch.max(y, 1)
    id = id.data[0]
    predword += ds.get_char_by_id(id)
print ('Prediction:\t' , predword)
print("Original:\t", word)
assert(predword == word)

Prediction:	 ololoasdasddqweqw123456789
Original:	 ololoasdasddqweqw123456789


## Реализовать GRU

In [12]:
#Написать реализацию GRU и обучить предсказывать слово

In [49]:
class GRU(nn.Module):
    
    def __init__(self, in_size=5, hidden_size=3, out_size=5):
        super(GRU, self).__init__() 
        self.rx = nn.Linear(in_features=in_size, out_features=hidden_size)
        self.rh = nn.Linear(in_features=hidden_size, out_features=hidden_size)
        
        self.zx = nn.Linear(in_features=in_size, out_features=hidden_size)
        self.zh = nn.Linear(in_features=hidden_size, out_features=hidden_size)
        
        self.nx = nn.Linear(in_features=in_size, out_features=hidden_size)
        self.nh = nn.Linear(in_features=hidden_size, out_features=hidden_size)
        
        self.sigmoid = nn.Sigmoid()
        self.tanh  = nn.Tanh()
        
        self.outweight = nn.Linear(in_features=hidden_size, out_features=out_size)
    
    def forward(self, x, prev_h):
        
        r = self.sigmoid(self.rx(x) + self.rh(prev_h))
        z = self.sigmoid(self.zx(x) + self.zh(prev_h))
        n = self.tanh(self.nx(x) + r * self.nh(prev_h))
         
        h = (1 - z) * n + z * prev_h
        output = self.outweight(h)
        
        return output, h

In [58]:
ds = WordDataSet(word=word)
rnn_gru = GRU(in_size=ds.vec_size, hidden_size=6, out_size=ds.vec_size)
criterion = nn.CrossEntropyLoss()
e_cnt     = 100
optim     = SGD(rnn_gru.parameters(), lr = 0.1, momentum=0.9)

In [59]:
CLIP_GRAD = False

for epoch in range(e_cnt):
    hh = torch.zeros(rnn_gru.rh.in_features)
    loss = 0
    optim.zero_grad()
    for sample, next_sample in ds:
        x = ds.get_one_hot(sample).unsqueeze(0)
        target =  torch.LongTensor([next_sample])

        y, hh = rnn_gru(x, hh)
        
        loss += criterion(y, target)
     

    loss.backward()
    
    if epoch % 10 == 0:
        print (loss.data.item())
        if CLIP_GRAD: print("Clip gradient : ", torch.nn.utils.clip_grad_norm_(rnn_gru.parameters(), max_norm=5))
    else: 
        if CLIP_GRAD: torch.nn.utils.clip_grad_norm_(rnn_gru.parameters(), max_norm=1)
            
    optim.step()

70.94506072998047
34.48104476928711
8.561725616455078
3.512789726257324
1.137862205505371
0.21988677978515625
0.11696147918701172
0.08054351806640625
0.06529712677001953
0.05706977844238281


In [60]:
rnn_gru.eval()
hh = torch.zeros(rnn_gru.rh.in_features)
id = 0
softmax  = nn.Softmax(dim=1)
predword = ds.get_char_by_id(id)
for c in enumerate(word[:-1]):
    x = ds.get_one_hot(id).unsqueeze(0)
    y, hh = rnn_gru(x, hh)
    y = softmax(y)
    m, id = torch.max(y, 1)
    id = id.data[0]
    predword += ds.get_char_by_id(id)
print ('Prediction:\t' , predword)
print("Original:\t", word)
assert(predword == word)

Prediction:	 ololoasdasddqweqw123456789
Original:	 ololoasdasddqweqw123456789
