In [2]:
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.optim import SGD 
import numpy as np

# Упражнение, для реализации "Ванильной" RNN
* Попробуем обучить сеть восстанавливать слово hello по первой букве. т.е. построим charecter-level модель

In [3]:
word = 'hello'

## Датасет. 
Позволяет:
* Закодировать символ при помощи one-hot
* Делать итератор по слову, которыей возвращает текущий символ и следующий как таргет

In [3]:
class WordDataSet:
    
    def __init__(self, word):
        self.chars2idx = {}
        self.indexs  = []
        for c in word: 
            if c not in self.chars2idx:
                self.chars2idx[c] = len(self.chars2idx)
                
            self.indexs.append(self.chars2idx[c])
            
        self.vec_size = len(self.chars2idx)
        self.seq_len  = len(word)
        
    def get_one_hot(self, idx):
        x = torch.zeros(self.vec_size)
        x[idx] = 1
        return x
    
    def __iter__(self):
        return zip(self.indexs[:-1], self.indexs[1:])
    
    def __len__(self):
        return self.seq_len
    
    def get_char_by_id(self, id):
        for c, i in self.chars2idx.items():
            if id == i: return c
        return None

## Реализация базовой RNN
<br/>
Скрытый элемент
$$ h_t= tanh⁡ (W_{ℎℎ} h_{t−1}+W_{xh} x_t) $$
Выход сети

$$ y_t = W_{hy} h_t $$

In [11]:
class VanillaRNN(nn.Module):
    
    def __init__(self, in_size=5, hidden_size=3, out_size=5):
        super(VanillaRNN, self).__init__()
        self.x2hidden    = nn.Linear(in_features=in_size, out_features=hidden_size)
        self.hidden      = nn.Linear(in_features=hidden_size, out_features=hidden_size)
        self.activation  = nn.Tanh()
        self.outweight   = nn.Linear(in_features=hidden_size, out_features=out_size)
    
    def forward(self, x, prev_hidden):
        hidden = self.x2hidden(x) + self.hidden(prev_hidden)
        output = self.outweight(hidden)
        return output, hidden

## Инициализация переменных 

In [12]:
ds = WordDataSet(word=word)
rnn = VanillaRNN(in_size=ds.vec_size, hidden_size=3, out_size=ds.vec_size)
criterion = nn.CrossEntropyLoss()
e_cnt     = 100
optim     = SGD(rnn.parameters(), lr = 0.1)

# Обучение

In [13]:
for epoch in range(e_cnt):
    hh = Variable( torch.zeros(rnn.hidden.in_features) )
    loss = 0
    optim.zero_grad()
    for sample, next_sample in ds:
        x = Variable(  ds.get_one_hot(sample).unsqueeze(0) )
        target =  Variable(torch.LongTensor([next_sample]) )

        y, hh = rnn(x, hh)
        
        loss += criterion(y, target)
     
    if epoch % 10 == 0:
        print (loss.data[0])
    loss.backward()
    optim.step()

6.288224220275879
3.556276798248291
2.050351619720459
0.38967016339302063
0.06675036251544952
0.029427340254187584
0.018021682277321815
0.012735573574900627
0.00974067859351635
0.007831979542970657


# Тестирование

In [14]:
rnn.eval()
hh = Variable( torch.zeros(rnn.hidden.in_features) )
id = 0
softmax = nn.Softmax()
predword = ds.get_char_by_id(id)
for c in enumerate(word[:-1]):
    x = Variable(ds.get_one_hot(id).unsqueeze(0))
    y, hh = rnn(x, hh)
    y = softmax(y)
    m, id = torch.max(y, 1)
    id = id.data[0]
    predword += ds.get_char_by_id(id)
print ('Prediction: ' , predword)

Prediction:  hello


# ДЗ
Модифицировать код Ванильной RNN:
- Предсказываем не слово "hello", а латинский алфавит в нижнем регистре [abcd..z] 
- Заменить код RNN на реализацию LSTM или GRU по выбору
- Реализовать свой embeding вместо OneHot Encoding

In [4]:
word = 'abcdefghijklmnopqrstuvwxyz'

## LSTM class

In [29]:
class LSTM(nn.Module):
    
    def __init__(self, vocab_size=5, in_size=5, out_size=5):
        super(LSTM, self).__init__()
        self.embeddings  = nn.Embedding(vocab_size, in_size)
        self.x2hidden    = nn.Linear(in_features=in_size, out_features=4 * in_size)
        self.hidden      = nn.Linear(in_features=in_size, out_features=4 * in_size)
        self.sigmoid     = nn.Sigmoid()
        self.tanh        = nn.Tanh()
        self.fully_conn  = nn.Linear(in_features=in_size, out_features=out_size)
    
    def forward(self, x, prev_hidden):
        x = self.embeddings(x)
        in_size = self.hidden.in_features
        inner = self.x2hidden(x) + self.hidden(prev_hidden)
        vector_i = self.sigmoid(inner[:, :in_size])
        vector_f = self.sigmoid(inner[:, in_size:2*in_size])
        vector_o = self.sigmoid(inner[:, 2*in_size:3*in_size])
        vector_g = self.tanh(inner[:, 3*in_size:4*in_size])
        hidden = vector_f * prev_hidden + vector_i * vector_g
        output = self.fully_conn(vector_o * self.tanh(hidden))
        return output, hidden

## Initialization

In [30]:
ds = WordDataSet(word=word)
lstm = LSTM(vocab_size=ds.vec_size, in_size=4, out_size=ds.vec_size)
criterion = nn.CrossEntropyLoss()
e_cnt     = 500
optim     = SGD(lstm.parameters(), momentum=0.9, lr = 0.01)

## Learning

In [31]:
for epoch in range(e_cnt):
    hh = Variable( torch.zeros(lstm.hidden.in_features) )
    loss = 0
    optim.zero_grad()
    for sample, next_sample in ds:
        x = Variable(torch.LongTensor([sample]))
        target =  Variable(torch.LongTensor([next_sample]) )

        y, hh = lstm(x, hh) 
        loss += criterion(y, target)
     
    if epoch % 50 == 0:
        print("EPOCH NUMBER: {0}".format(epoch), "LOSS: {0}".format(loss.data[0]))
    loss.backward()
    optim.step()

EPOCH NUMBER: 0 LOSS: 82.71051025390625
EPOCH NUMBER: 50 LOSS: 43.844486236572266
EPOCH NUMBER: 100 LOSS: 30.469608306884766
EPOCH NUMBER: 150 LOSS: 12.731304168701172
EPOCH NUMBER: 200 LOSS: 8.299394607543945
EPOCH NUMBER: 250 LOSS: 5.863373756408691
EPOCH NUMBER: 300 LOSS: 4.457452774047852
EPOCH NUMBER: 350 LOSS: 3.6299638748168945
EPOCH NUMBER: 400 LOSS: 3.06754469871521
EPOCH NUMBER: 450 LOSS: 2.657789468765259


In [33]:
lstm.eval()
hh = Variable( torch.zeros(lstm.hidden.in_features) )
id = 0
softmax = nn.Softmax()
predword = ds.get_char_by_id(id)
for c in enumerate(word[:-1]):
    x = Variable(torch.LongTensor([id]))
    y, hh = lstm(x, hh)
    y = softmax(y)
    m, id = torch.max(y, 1)
    id = id.data[0]
    predword += ds.get_char_by_id(id)
print ('Prediction: ' , predword)

Prediction:  abcdefghijklmnopqrstuvwxyz
