In [None]:
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import time
from torch.utils.data import TensorDataset, DataLoader

In [None]:
def gen_sequence(sequence_len):
    """Generate one random digit sequence together with its target sequence.

    The input ``x`` holds ``sequence_len`` digits drawn uniformly from 0..9.
    The target ``y`` keeps the first digit unchanged and, for every later
    position ``i``, stores ``(x[i] + x[0]) mod 10``.

    Args:
        sequence_len: Number of digits to generate (non-negative integer).

    Returns:
        Tuple ``(x, y)`` of two integer NumPy arrays of length ``sequence_len``.
    """
    x = np.random.randint(0, 10, sequence_len)
    if x.size == 0:
        # Nothing to shift by: there is no first digit to read.
        return x, x.copy()
    # Two digits sum to at most 18, so "subtract 10 when above 9" is exactly a
    # modulo-10 reduction. Only one random draw is needed; the targets are
    # fully determined by x.
    y = (x + x[0]) % 10
    y[0] = x[0]
    return x, y

In [None]:
# Quick look at one generated pair: inputs first, then targets.
X, y = gen_sequence(10)
print(X, y, sep="\n")

In [None]:
class RNN(torch.nn.Module):
    """Per-position classifier: token embedding -> vanilla RNN -> linear head.

    Args:
        vocab_size: Number of distinct input token ids (rows of the embedding
            table). Defaults to 10.
        embedding_dim: Width of each token embedding.
        hidden_size: Width of the recurrent hidden state.
        num_classes: Number of logits produced for every position.
    """

    def __init__(self, vocab_size=10, embedding_dim=16, hidden_size=32, num_classes=10):
        super().__init__()
        # The embedding table is indexed by token id, so it is sized by the
        # vocabulary rather than by the length of the sequences.
        self.embedding = torch.nn.Embedding(vocab_size, embedding_dim)
        self.rnn = torch.nn.RNN(embedding_dim, hidden_size, batch_first=True)
        self.out = torch.nn.Linear(hidden_size, num_classes)

    def forward(self, sentences, state=None):
        """Return logits for every position of every sequence.

        Args:
            sentences: Integer tensor of token ids, batch dimension first.
            state: Optional initial hidden state forwarded to ``torch.nn.RNN``;
                ``None`` lets the layer start from zeros.

        Returns:
            Tensor of logits with a trailing dimension of ``num_classes``.
        """
        x = self.embedding(sentences)
        x, _ = self.rnn(x, state)
        return self.out(x)

In [None]:
class LSTM(torch.nn.Module):
    """Per-position classifier: token embedding -> LSTM -> linear head.

    Args:
        vocab_size: Number of distinct input token ids (rows of the embedding
            table). Defaults to 10.
        embedding_dim: Width of each token embedding.
        hidden_size: Width of the recurrent hidden and cell states.
        num_classes: Number of logits produced for every position.
    """

    def __init__(self, vocab_size=10, embedding_dim=16, hidden_size=32, num_classes=10):
        super().__init__()
        # The embedding table is indexed by token id, so it is sized by the
        # vocabulary rather than by the length of the sequences.
        self.embedding = torch.nn.Embedding(vocab_size, embedding_dim)
        self.rnn = torch.nn.LSTM(embedding_dim, hidden_size, batch_first=True)
        self.out = torch.nn.Linear(hidden_size, num_classes)

    def forward(self, sentences, state=None):
        """Return logits for every position of every sequence.

        Args:
            sentences: Integer tensor of token ids, batch dimension first.
            state: Optional ``(h_0, c_0)`` tuple forwarded to ``torch.nn.LSTM``;
                ``None`` lets the layer start from zeros.

        Returns:
            Tensor of logits with a trailing dimension of ``num_classes``.
        """
        x = self.embedding(sentences)
        x, _ = self.rnn(x, state)
        return self.out(x)

In [None]:
class GRU(torch.nn.Module):
    """Per-position classifier: token embedding -> GRU -> linear head.

    Args:
        vocab_size: Number of distinct input token ids (rows of the embedding
            table). Defaults to 10.
        embedding_dim: Width of each token embedding.
        hidden_size: Width of the recurrent hidden state.
        num_classes: Number of logits produced for every position.
    """

    def __init__(self, vocab_size=10, embedding_dim=16, hidden_size=32, num_classes=10):
        super().__init__()
        # The embedding table is indexed by token id, so it is sized by the
        # vocabulary rather than by the length of the sequences.
        self.embedding = torch.nn.Embedding(vocab_size, embedding_dim)
        self.rnn = torch.nn.GRU(embedding_dim, hidden_size, batch_first=True)
        self.out = torch.nn.Linear(hidden_size, num_classes)

    def forward(self, sentences, state=None):
        """Return logits for every position of every sequence.

        Args:
            sentences: Integer tensor of token ids, batch dimension first.
            state: Optional initial hidden state forwarded to ``torch.nn.GRU``;
                ``None`` lets the layer start from zeros.

        Returns:
            Tensor of logits with a trailing dimension of ``num_classes``.
        """
        x = self.embedding(sentences)
        x, _ = self.rnn(x, state)
        return self.out(x)

In [None]:
def train(
    model, criterion, optimizer,
    epoch = 20,
    batch_size = 128,
    seq_len = None,
    n_train = 10000,
    n_test = 2000
):
    """Train ``model`` on freshly generated sequences and track the loss.

    A new training set and test set are generated with ``gen_sequence`` on
    every call. After each epoch the mean per-batch training and test losses
    are recorded and printed.

    Args:
        model: Module mapping a batch of token ids to per-position logits.
        criterion: Loss taking ``(logits, targets)`` with flattened positions.
        optimizer: Optimizer already bound to ``model``'s parameters.
        epoch: Number of passes over the training set.
        batch_size: Number of sequences per batch.
        seq_len: Length of the generated sequences. ``None`` (the default)
            falls back to the module-level ``sequence_len``.
        n_train: Number of training sequences to generate.
        n_test: Number of test sequences to generate.

    Returns:
        Tuple ``(hist_train, hist_test)`` of NumPy arrays of length ``epoch``
        holding the mean per-batch loss of each epoch.
    """
    if seq_len is None:
        # Backward-compatible default: use the notebook-level setting.
        seq_len = sequence_len

    data_train = [gen_sequence(seq_len) for _ in range(n_train)]
    data_test = [gen_sequence(seq_len) for _ in range(n_test)]
    iter_train = DataLoader(data_train, batch_size, shuffle=True)
    iter_test = DataLoader(data_test, batch_size)
    hist_train = np.empty(epoch)
    hist_test = np.empty(epoch)

    for ep in range(epoch):
        start = time.time()
        train_loss = 0.
        train_passed = 0
        test_loss = 0.
        test_passed = 0

        model.train()
        for X_batch, y_batch in iter_train:
            # Flatten targets to one class index per position. CrossEntropyLoss
            # wants int64 indices, while NumPy's default integer width is
            # platform-dependent, so cast explicitly.
            y_batch = y_batch.reshape(-1).long()
            optimizer.zero_grad()
            y_pred = model(X_batch)
            # Take the class count from the model output instead of assuming 10.
            y_pred = y_pred.reshape(-1, y_pred.size(-1))
            loss = criterion(y_pred, y_batch)
            train_loss += loss.item()
            train_passed += 1
            loss.backward()
            optimizer.step()

        model.eval()
        # No gradients are needed for evaluation; skipping graph construction
        # saves memory and time.
        with torch.no_grad():
            for X_batch, y_batch in iter_test:
                y_batch = y_batch.reshape(-1).long()
                y_pred = model(X_batch)
                y_pred = y_pred.reshape(-1, y_pred.size(-1))
                loss = criterion(y_pred, y_batch)
                test_loss += loss.item()
                test_passed += 1

        # An empty loader (e.g. n_test=0) yields NaN rather than a division error.
        mean_train = train_loss / train_passed if train_passed else float('nan')
        mean_test = test_loss / test_passed if test_passed else float('nan')
        hist_train[ep] = mean_train
        hist_test[ep] = mean_test

        print("Epoch {}, Time: {:.3f}, Train loss: {:.3f}, Test loss: {:.3f}".format(ep, time.time() - start, mean_train, mean_test))
    return hist_train, hist_test

In [None]:
def plot_learning_curves(train, test):
    '''
    Draw the training and test loss curves on a single figure and show it.

    train, test -- sequences of loss values, plotted against their index
    '''
    figure = plt.figure(figsize=(20, 7))
    axis = figure.add_subplot(1, 1, 1)

    for curve, name in ((train, 'train'), (test, 'test')):
        axis.plot(curve, label=name)

    axis.set_title('Loss', fontsize=15)
    axis.set_ylabel('Loss', fontsize=15)
    axis.set_xlabel('Epoch', fontsize=15)
    axis.legend()
    plt.show()

In [None]:
# Sequence length for the experiments that follow; train() reads this
# module-level global when it generates its data.
sequence_len = 25

In [None]:
# Vanilla RNN: train with Adam + cross-entropy at the current sequence_len,
# then plot the (train, test) loss histories returned by train().
model = RNN()
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())
rnn25_history = train(model, criterion, optimizer)
plot_learning_curves(*rnn25_history)

In [None]:
# LSTM: train with Adam + cross-entropy at the current sequence_len,
# then plot the (train, test) loss histories returned by train().
model = LSTM()
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())
lstm25_history = train(model, criterion, optimizer)
plot_learning_curves(*lstm25_history)

In [None]:
# GRU: train with Adam + cross-entropy at the current sequence_len,
# then plot the (train, test) loss histories returned by train().
model = GRU()
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())
gru25_history = train(model, criterion, optimizer)
plot_learning_curves(*gru25_history)

In [None]:
# Test-loss history (index 1 of each train() result), one column per architecture.
sequence25 = pd.DataFrame(dict(
    RNN=rnn25_history[1],
    LSTM=lstm25_history[1],
    GRU=gru25_history[1],
))
sequence25

In [None]:
# Sequence length for the experiments that follow; train() reads this
# module-level global when it generates its data.
sequence_len = 75

In [None]:
# Vanilla RNN: train with Adam + cross-entropy at the current sequence_len,
# then plot the (train, test) loss histories returned by train().
model = RNN()
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())
rnn75_history = train(model, criterion, optimizer)
plot_learning_curves(*rnn75_history)

In [None]:
# LSTM: train with Adam + cross-entropy at the current sequence_len,
# then plot the (train, test) loss histories returned by train().
model = LSTM()
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())
lstm75_history = train(model, criterion, optimizer)
plot_learning_curves(*lstm75_history)

In [None]:
# GRU: train with Adam + cross-entropy at the current sequence_len,
# then plot the (train, test) loss histories returned by train().
model = GRU()
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())
gru75_history = train(model, criterion, optimizer)
plot_learning_curves(*gru75_history)

In [None]:
# Test-loss history (index 1 of each train() result), one column per architecture.
sequence75 = pd.DataFrame(dict(
    RNN=rnn75_history[1],
    LSTM=lstm75_history[1],
    GRU=gru75_history[1],
))
sequence75

In [None]:
# Sequence length for the experiments that follow; train() reads this
# module-level global when it generates its data.
sequence_len = 150

In [None]:
# Vanilla RNN: train with Adam + cross-entropy at the current sequence_len,
# then plot the (train, test) loss histories returned by train().
model = RNN()
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())
rnn150_history = train(model, criterion, optimizer)
plot_learning_curves(*rnn150_history)

In [None]:
# LSTM: train with Adam + cross-entropy at the current sequence_len,
# then plot the (train, test) loss histories returned by train().
model = LSTM()
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())
lstm150_history = train(model, criterion, optimizer)
plot_learning_curves(*lstm150_history)

In [None]:
# GRU: train with Adam + cross-entropy at the current sequence_len,
# then plot the (train, test) loss histories returned by train().
model = GRU()
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())
gru150_history = train(model, criterion, optimizer)
plot_learning_curves(*gru150_history)

In [None]:
# Test-loss history (index 1 of each train() result), one column per architecture.
sequence150 = pd.DataFrame(dict(
    RNN=rnn150_history[1],
    LSTM=lstm150_history[1],
    GRU=gru150_history[1],
))
sequence150