In [1]:
import torch
import torch.nn as nn
from torch.optim import SGD 
import numpy as np

# Упражнение, для реализации "Ванильной" RNN
* Попробуем обучить сеть восстанавливать слово hello по первой букве. т.е. построим charecter-level модель

In [2]:
a = torch.ones((3,3))*3
b = torch.ones((3,3))*5

In [3]:
a @ b

tensor([[45., 45., 45.],
        [45., 45., 45.],
        [45., 45., 45.]])

In [4]:
a * b

tensor([[15., 15., 15.],
        [15., 15., 15.],
        [15., 15., 15.]])

In [5]:
# word = 'ololoasdasddqweqw123456789'
word = 'hello'

## Датасет. 
Позволяет:
* Закодировать символ при помощи one-hot
* Делать итератор по слову, которыей возвращает текущий символ и следующий как таргет

In [6]:
class WordDataSet:
    
    def __init__(self, word):
        self.chars2idx = {}
        self.indexs  = []
        for c in word: 
            if c not in self.chars2idx:
                self.chars2idx[c] = len(self.chars2idx)
                
            self.indexs.append(self.chars2idx[c])
            
        self.vec_size = len(self.chars2idx)
        self.seq_len  = len(word)
        
    def get_one_hot(self, idx):
        x = torch.zeros(self.vec_size)
        x[idx] = 1
        return x
    
    def __iter__(self):
        return zip(self.indexs[:-1], self.indexs[1:])
    
    def __len__(self):
        return self.seq_len
    
    def get_char_by_id(self, id):
        for c, i in self.chars2idx.items():
            if id == i: return c
        return None

## Реализация базовой RNN
<br/>
Скрытый элемент
$$ h_t= tanh⁡ (W_{ℎℎ} h_{t−1}+W_{xh} x_t) $$
Выход сети

$$ y_t = W_{hy} h_t $$

In [7]:
class VanillaRNN(nn.Module):
    
    def __init__(self, in_size=5, hidden_size=3, out_size=5):
        super(VanillaRNN, self).__init__()        
        self.x2hidden    = nn.Linear(in_features=in_size, out_features=hidden_size)
        self.hidden      = nn.Linear(in_features=hidden_size, out_features=hidden_size)
        self.activation  = nn.Tanh()
        self.outweight   = nn.Linear(in_features=hidden_size, out_features=out_size)
    
    def forward(self, x, prev_hidden):
        hidden = self.activation(self.x2hidden(x) + self.hidden(prev_hidden))
#         Версия без активации - может происходить gradient exploding
#         hidden = self.x2hidden(x) + self.hidden(prev_hidden)
        output = self.outweight(hidden)
        return output, hidden

## Инициализация переменных 

In [8]:
ds = WordDataSet(word=word)
rnn = VanillaRNN(in_size=ds.vec_size, hidden_size=3, out_size=ds.vec_size)
criterion = nn.CrossEntropyLoss()
e_cnt     = 100
optim     = SGD(rnn.parameters(), lr = 0.1, momentum=0.9)

# Обучение

In [9]:
CLIP_GRAD = True

for epoch in range(e_cnt):
    hh = torch.zeros(rnn.hidden.in_features)
    loss = 0
    optim.zero_grad()
    for sample, next_sample in ds:
        x = ds.get_one_hot(sample).unsqueeze(0)
        target =  torch.LongTensor([next_sample])

        y, hh = rnn(x, hh)
        
        loss += criterion(y, target)
     

    loss.backward()
    
    if epoch % 10 == 0:
        print (loss.data.item())
        if CLIP_GRAD: print("Clip gradient : ", torch.nn.utils.clip_grad_norm_(rnn.parameters(), max_norm=5))
    else: 
        if CLIP_GRAD: torch.nn.utils.clip_grad_norm_(rnn.parameters(), max_norm=1)
            
#     print("Params : ")
#     num_params = 0
#     for item in rnn.parameters():
#         num_params += 1
#         print(item.grad)
#     print("NumParams :", num_params)
#     print("Optimize")
    
    optim.step()

5.330726623535156
Clip gradient :  1.9298164278136098
1.3684765100479126
Clip gradient :  1.3428276853224093
0.02804422378540039
Clip gradient :  0.034508922197576015
0.007588386535644531
Clip gradient :  0.0098744079562533
0.0045318603515625
Clip gradient :  0.005874060642420174
0.0036134719848632812
Clip gradient :  0.0046419927632531995
0.0032134056091308594
Clip gradient :  0.004109684393575621
0.0029816627502441406
Clip gradient :  0.0038133222522466516
0.0028142929077148438
Clip gradient :  0.0036119474516455433
0.0026760101318359375
Clip gradient :  0.0034555867178719755


# Тестирование

In [10]:
rnn.eval()
hh = torch.zeros(rnn.hidden.in_features)
id = 0
softmax  = nn.Softmax(dim=1)
predword = ds.get_char_by_id(id)
for c in enumerate(word[:-1]):
    x = ds.get_one_hot(id).unsqueeze(0)
    y, hh = rnn(x, hh)
    y = softmax(y)
    m, id = torch.max(y, 1)
    id = id.data[0]
    predword += ds.get_char_by_id(id)
print ('Prediction:\t' , predword)
print("Original:\t", word)
assert(predword == word)

Prediction:	 hello
Original:	 hello


# ДЗ
Реализовать LSTM и GRU модули, обучить их предсказывать тестовое слово
Сохранить ноутбук с предсказанием и пройденным assert и прислать на почту a.murashev@corp.mail.ru
c темой:


[МФТИ\_2019\_1] ДЗ №8 ФИО

In [11]:
#тестовое слово
word = 'ololoasdasddqweqw123456789'
ds = WordDataSet(word=word)

## Реализовать LSTM

In [12]:
class LSTM(nn.Module):
    
    def __init__(self, in_size, hidden_size, out_size):
        super(LSTM, self).__init__()
        
        self.x2hidden = nn.Linear(in_features=in_size, out_features=hidden_size)
        self.hidden = nn.Linear(in_features=hidden_size, out_features=hidden_size)
    
        self.x2hidden_i = nn.Linear(in_features=in_size, out_features=hidden_size)
        self.hidden_i = nn.Linear(in_features=hidden_size, out_features=hidden_size)
        
        self.x2hidden_f = nn.Linear(in_features=in_size, out_features=hidden_size)
        self.hidden_f = nn.Linear(in_features=hidden_size, out_features=hidden_size)
        
        self.x2hidden_o = nn.Linear(in_features=in_size, out_features=hidden_size)
        self.hidden_o = nn.Linear(in_features=hidden_size, out_features=hidden_size)
        
        self.outweight = nn.Linear(in_features=hidden_size, out_features=out_size)
        
        self.activation  = nn.Tanh()
        
        
    def forward(self, x, prev_h_t, prev_c_t):

        cand_c_t = self.activation(self.x2hidden(x) + self.hidden(prev_h_t))
        
        i_t = torch.sigmoid(self.x2hidden_i(x) + self.hidden_i(prev_h_t))
        f_t = torch.sigmoid(self.x2hidden_f(x) + self.hidden_f(prev_h_t))
        o_t = torch.sigmoid(self.x2hidden_o(x) + self.hidden_o(prev_h_t))
  
        c_t = f_t * prev_c_t + i_t * cand_c_t
        h_t = o_t * self.activation(c_t)
        output = self.outweight(h_t)
        
        return output, h_t, c_t

## Инициализация

In [13]:
lstm = LSTM(in_size=ds.vec_size, hidden_size=50, out_size=ds.vec_size)

e_cnt = 150
criterion = nn.CrossEntropyLoss()
lstm_optim     = SGD(lstm.parameters(), lr = 0.1, momentum=0.9)

## Обучение

In [14]:
CLIP_GRAD = True

for epoch in range(e_cnt):
    hh = torch.ones(lstm.hidden.in_features)
    ct = torch.zeros(lstm.x2hidden.out_features)
    loss = 0
    lstm_optim.zero_grad()
    for sample, next_sample in ds:
        x = ds.get_one_hot(sample).unsqueeze(0)
        target =  torch.LongTensor([next_sample])

        y, hh, ct = lstm(x, hh, ct)
        
        loss += criterion(y, target)
     

    loss.backward()
    
    if epoch % 10 == 0:
        print (loss.data.item())
        if CLIP_GRAD: print("Clip gradient : ", torch.nn.utils.clip_grad_norm_(lstm.parameters(), max_norm=5))
    else: 
        if CLIP_GRAD: torch.nn.utils.clip_grad_norm_(lstm.parameters(), max_norm=1)
    
    lstm_optim.step()

70.96966552734375
Clip gradient :  3.9621062044844098
55.280662536621094
Clip gradient :  5.098674449842935
34.50522994995117
Clip gradient :  8.13513306276048
21.527170181274414
Clip gradient :  34.56204257471622
26.57499122619629
Clip gradient :  58.41753140562637
14.22696304321289
Clip gradient :  6.567133397173636
7.2023749351501465
Clip gradient :  7.626819133552264
4.366931438446045
Clip gradient :  14.046829727545287
1.4094419479370117
Clip gradient :  2.100908800313955
0.2378101348876953
Clip gradient :  0.3912073353628567
0.05113792419433594
Clip gradient :  0.06200934561740256
0.027177810668945312
Clip gradient :  0.03964232164435101
0.0171966552734375
Clip gradient :  0.018677042502100977
0.013371467590332031
Clip gradient :  0.012534196348760218
0.01151275634765625
Clip gradient :  0.010331602728238633


## Тестирование

In [15]:
lstm.eval()
hh = torch.ones(lstm.hidden.in_features)
ct = torch.zeros(lstm.x2hidden.out_features)
id = 0
softmax  = nn.Softmax(dim=1)
predword = ds.get_char_by_id(id)
for c in enumerate(word[:-1]):
    x = ds.get_one_hot(id).unsqueeze(0)
    y, hh, ct = lstm(x, hh, ct)
    y = softmax(y)
    m, id = torch.max(y, 1)
    id = id.data[0]
    predword += ds.get_char_by_id(id)
print ('Prediction:\t' , predword)
print("Original:\t", word)
assert(predword == word)

Prediction:	 ololoasdasddqweqw123456789
Original:	 ololoasdasddqweqw123456789


## Реализовать GRU

In [16]:
class GRU(nn.Module):
    def __init__(self, in_size, hidden_size, out_size):
        super(GRU, self).__init__()
    
        self.x2hidden = nn.Linear(in_features=in_size, out_features=hidden_size)
        self.hidden = nn.Linear(in_features=hidden_size, out_features=hidden_size)
        
        self.x2hidden_u = nn.Linear(in_features=in_size, out_features=hidden_size)
        self.hidden_u = nn.Linear(in_features=hidden_size, out_features=hidden_size)
        
        self.x2hidden_r = nn.Linear(in_features=in_size, out_features=hidden_size)
        self.hidden_r = nn.Linear(in_features=hidden_size, out_features=hidden_size)
        
        self.outweight = nn.Linear(in_features=hidden_size, out_features=out_size)
        
        self.activation  = nn.Tanh()
        
        
    def forward(self, x, prev_h_t):
        u_t = torch.sigmoid(self.x2hidden_u(x) + self.hidden_u(prev_h_t))
        r_t = torch.sigmoid(self.x2hidden_r(x) + self.hidden_r(prev_h_t))
  
        cand_h_t = self.activation(self.x2hidden(x) + self.hidden(r_t * prev_h_t))
        h_t = (1 - u_t) * cand_h_t + u_t * prev_h_t
        output = self.outweight(h_t)
        
        return output, h_t

## Инициализация

In [17]:
gru = GRU(in_size=ds.vec_size, hidden_size=10, out_size=ds.vec_size)

e_cnt = 200
criterion = nn.CrossEntropyLoss()
gru_optim     = SGD(gru.parameters(), lr = 0.1, momentum=0.9)

## Обучение

In [18]:
CLIP_GRAD = True

for epoch in range(e_cnt):
    hh = torch.ones(gru.hidden.in_features)
    loss = 0
    gru_optim.zero_grad()
    for sample, next_sample in ds:
        x = ds.get_one_hot(sample).unsqueeze(0)
        target =  torch.LongTensor([next_sample])

        y, hh= gru(x, hh)
        
        loss += criterion(y, target)
     

    loss.backward()
    
    if epoch % 10 == 0:
        print (loss.data.item())
        if CLIP_GRAD: print("Clip gradient : ", torch.nn.utils.clip_grad_norm_(gru.parameters(), max_norm=5))
    else: 
        if CLIP_GRAD: torch.nn.utils.clip_grad_norm_(gru.parameters(), max_norm=1)
    
    gru_optim.step()

73.05367279052734
Clip gradient :  6.235991867136826
56.00344467163086
Clip gradient :  6.957778339586844
42.672813415527344
Clip gradient :  9.708163068013185
28.70158576965332
Clip gradient :  9.12639295982216
16.049579620361328
Clip gradient :  6.809344451313966
6.815958023071289
Clip gradient :  3.791361181334377
2.3238205909729004
Clip gradient :  1.175163635020065
1.58056640625
Clip gradient :  0.502730965172851
0.5121030807495117
Clip gradient :  0.5655030457760515
0.1083526611328125
Clip gradient :  0.34056821766071343
0.0673370361328125
Clip gradient :  0.13985752193442502
0.04840660095214844
Clip gradient :  0.0508732845766967
0.039559364318847656
Clip gradient :  0.04899905854695647
0.03411674499511719
Clip gradient :  0.02998454006374314
0.030307769775390625
Clip gradient :  0.020850171696827752
0.0274200439453125
Clip gradient :  0.015985771883065446
0.025112152099609375
Clip gradient :  0.013821690372286696
0.023184776306152344
Clip gradient :  0.012914704887505497
0.0215

## Тестирование

In [19]:
gru.eval()
hh = torch.ones(gru.hidden.in_features)
id = 0
softmax  = nn.Softmax(dim=1)
predword = ds.get_char_by_id(id)
for c in enumerate(word[:-1]):
    x = ds.get_one_hot(id).unsqueeze(0)
    y, hh = gru(x, hh)
    y = softmax(y)
    m, id = torch.max(y, 1)
    id = id.data[0]
    predword += ds.get_char_by_id(id)
print ('Prediction:\t' , predword)
print("Original:\t", word)
assert(predword == word)

Prediction:	 ololoasdasddqweqw123456789
Original:	 ololoasdasddqweqw123456789
