In [1]:
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.optim import SGD 
import numpy as np


In [71]:
sentence = 'Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.'

In [96]:
class WordDataSet:
    
    def __init__(self, sentence):
        self.words2idx = {}
        self.indexs  = []
        for word in sentence.split(' '): 
            if word not in self.words2idx:
                self.words2idx[word] = len(self.words2idx)
                
            self.indexs.append(self.words2idx[word])
            
        self.vec_size = len(self.words2idx)
        self.seq_len  = len(sentence.split(' '))
        
    def get_one_hot(self, idx):
        x = torch.zeros(self.vec_size)
        x[idx] = 1
        return x
    
    def __iter__(self):
        return zip(self.indexs[:-1], self.indexs[1:])
    
    def __len__(self):
        return self.seq_len
    
    def size(self):
        return len(self.words2idx)
    
    def get_char_by_id(self, id):
        for word, i in self.words2idx.items():
            if id == i: return word
        return None

In [97]:
class VanillaRNN(nn.Module):
    
    def __init__(self, in_size=5, h_size=2, c_size=2, out_size=5):
        super(VanillaRNN, self).__init__()
        self.x2hidden    = nn.Linear(in_features=in_size, out_features=hidden_size)
        self.hidden      = nn.Linear(in_features=hidden_size, out_features=hidden_size)
        self.activation  = nn.Tanh()
        self.outweight   = nn.Linear(in_features=hidden_size, out_features=out_size)
    
    def forward(self, x, prev_hidden):
        hidden = self.x2hidden(x) + self.hidden(prev_hidden)
        output = self.outweight(hidden)
        return output, hidden

In [102]:
class LSTM(nn.Module):
    
    def __init__(self, size=5, h_size=5):
        super(LSTM, self).__init__()
        self.W_f = nn.Linear(in_features = size+h_size, out_features = h_size)
        self.W_i = nn.Linear(in_features = size+h_size, out_features = h_size)
        self.W_g = nn.Linear(in_features = size+h_size, out_features = h_size)
        self.W_o = nn.Linear(in_features = size+h_size, out_features = h_size)
        
        self.S_f = nn.Sigmoid()
        self.S_i = nn.Tanh()
        self.T_g = nn.Sigmoid()
        self.S_o = nn.Sigmoid()
        
        self.T_c = nn.Tanh()
        
        self.output = nn.Linear(in_features = h_size, out_features = size)
        
    def forward(self, x, prev_h, prev_c):
        stack = torch.cat((x, prev_h), 1)
        
        f = self.S_f(self.W_f(stack))
        i = self.S_i(self.W_i(stack))
        g = self.T_g(self.W_g(stack))
        o = self.S_o(self.W_o(stack))
        
        c = f*prev_c + i*g
        h = o*self.T_c(prev_c)
        
        output = self.output(h)
        
        return output, h, c
        

In [183]:
ds = WordDataSet(sentence=sentence)
lstm = LSTM(ds.size(),5)
criterion = nn.CrossEntropyLoss()
e_cnt     = 1000
optim     = SGD(lstm.parameters(), lr = 0.05)

lr_lambda = lambda epoch : 0.99**(epoch/10)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer=optim, lr_lambda=lr_lambda)

# Обучение

In [184]:
for epoch in range(e_cnt):
    scheduler.step()
    h = Variable( torch.zeros(5) ).unsqueeze(0)
    c = Variable( torch.zeros(5) ).unsqueeze(0)
    loss = 0
    optim.zero_grad()
    for sample, next_sample in ds:
        x = Variable(  ds.get_one_hot(sample) ).unsqueeze(0)
        target =  Variable(torch.LongTensor([next_sample]) )

        y, h, c = lstm(x, h, c)
        
        loss += criterion(y, target)
     
    if epoch % 25 == 0:
        print (loss.data[0])
        
    loss.backward()
    optim.step()

285.0458984375
276.81732177734375
254.1818389892578
215.6033172607422
202.2539520263672
130.3117218017578
98.00276947021484
82.61087036132812
72.87003326416016
65.10380554199219
56.54714584350586
51.07998275756836
47.371498107910156
43.31869125366211
40.46412658691406
38.037254333496094
35.34675216674805
33.406532287597656
31.961816787719727
29.891677856445312
28.434906005859375
27.379499435424805
25.974565505981445
24.872848510742188
23.889236450195312
22.989439010620117
22.172819137573242
21.422584533691406
20.73621368408203
20.104248046875
19.52216339111328
18.98406410217285
18.4847412109375
18.019922256469727
17.58625602722168
17.180824279785156
16.801042556762695
16.44463539123535
16.109582901000977
15.794071197509766


# Тестирование

In [185]:
lstm.eval()
h = Variable( torch.zeros(5) ).unsqueeze(0)
c = Variable( torch.zeros(5) ).unsqueeze(0)
id = 0
softmax = nn.Softmax()
predsentence = ds.get_char_by_id(id) + ' '
for w in range(len(ds)-1):
    x = Variable(ds.get_one_hot(id)).unsqueeze(0)
    y, h, c = lstm(x, h, c)
    y = softmax(y)
    m, id = torch.max(y, 1)
    id = id.data[0]
    predsentence += ds.get_char_by_id(id) + ' '
print ('Prediction: ' , predsentence)

Prediction:  Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum. 


  # Remove the CWD from sys.path while we load stuff.
