# пример аналогичной реализации в PyTorch
### примерно так это должно выглядеть в идеале

In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

class LSTM(nn.Module):
    """Encoder/decoder model built from two stacked LSTMs.

    The encoder consumes the whole input sequence. The final hidden state of
    its top layer is repeated once per time step and fed to the decoder, which
    emits one ``max_no``-sized vector per step.

    Args:
        max_no: Size of the last dimension of both the input and the output.
        hidden_dim: Hidden size of the encoder and input size of the decoder.
        lstm_layers_count: Number of stacked layers in each of the two LSTMs.
    """

    def __init__(self,
                 max_no,
                 hidden_dim=100,
                 lstm_layers_count=1):
        super().__init__()
        self.encoder = nn.LSTM(max_no, hidden_dim, num_layers=lstm_layers_count)
        self.decoder = nn.LSTM(hidden_dim, max_no, num_layers=lstm_layers_count)

    def forward(self, input):
        """Encode ``input`` into one vector and decode it back into a sequence.

        Args:
            input: Tensor whose first dimension is the sequence length and
                whose shape the decoder output must reproduce, i.e.
                ``(seq_len, batch, max_no)`` (both LSTMs are created with the
                default time-major layout).

        Returns:
            The decoder output with every size-1 dimension squeezed away, so a
            batch of one sequence comes back as ``(seq_len, max_no)``.
        """
        seq_len = input.shape[0]
        _, (last_hidden, _) = self.encoder(input)

        # ``last_hidden`` stacks one state per encoder layer along dim 0.
        # Only the top layer summarises the sequence; repeating all layers
        # would make the decoder input ``seq_len * num_layers`` steps long and
        # break the shape check below whenever lstm_layers_count > 1.
        encoded = last_hidden[-1:].repeat(seq_len, 1, 1)

        # Decode
        y, _ = self.decoder(encoded)
        assert y.shape == input.shape
        return torch.squeeze(y)

In [5]:
# max_no: number of distinct integer values (width of each one-hot vector).
# batch_size: number of sequences generated per training step.
max_no, batch_size = 10, 1

In [6]:
# Small single-layer network; hidden_dim=5 is the size of the encoded vector.
model = LSTM(max_no, hidden_dim=5, lstm_layers_count=1)
# Adam with its default hyper-parameters over all model weights.
optimizer = optim.Adam(model.parameters())
# Mean squared error between predicted and one-hot target vectors.
loss_function = nn.MSELoss()

In [7]:
def batch_gen(batch_size=32, seq_len=10, max_no=100):
    """Endlessly yield one-hot encoded batches for the sorting task.

    Every batch holds ``batch_size`` random integer sequences of length
    ``seq_len`` drawn from ``[0, max_no)`` together with the same sequences
    sorted in ascending order.

    Args:
        batch_size: Number of sequences per batch.
        seq_len: Number of integers in each sequence.
        max_no: Exclusive upper bound of the integers and one-hot width.

    Yields:
        A tuple ``(x, y)`` of ``float32`` arrays of shape
        ``(seq_len, batch_size, max_no)``: ``x`` one-hot encodes the random
        sequences, ``y`` their sorted counterparts. Each batch owns its data,
        so earlier batches stay valid after the generator advances.
    """
    # Broadcastable (batch, position) index grids, reused for every batch.
    rows = np.arange(batch_size)[:, None]
    cols = np.arange(seq_len)[None, :]

    while True:
        X = np.random.randint(max_no, size=(batch_size, seq_len))
        Y = np.sort(X, axis=1)

        # Fresh buffers per batch: the yielded arrays are views, so reusing
        # one buffer would overwrite batches the consumer still holds.
        x = np.zeros((batch_size, seq_len, max_no), dtype=np.float32)
        y = np.zeros((batch_size, seq_len, max_no), dtype=np.float32)
        x[rows, cols, X] = 1
        y[rows, cols, Y] = 1

        # Swap to time-major layout: (seq_len, batch_size, max_no).
        yield np.swapaxes(x, 0, 1), np.swapaxes(y, 0, 1)

In [8]:
# Train until interrupted by hand (e.g. KeyboardInterrupt in the notebook).
# The generator is created once and reused, instead of being rebuilt (and its
# buffers re-allocated) on every step.
batches = batch_gen(batch_size, 4, max_no)
for i, (x, y) in enumerate(batches):
    # torch.tensor copies, so the tensors never alias the generator's arrays.
    x = torch.tensor(x)
    y = torch.tensor(y)

    optimizer.zero_grad()
    y_pred = model(x)
    # Flatten both sides so MSE compares the one-hot vectors element-wise.
    loss = loss_function(y_pred.reshape(-1), y.reshape(-1))
    loss.backward()
    optimizer.step()

    # Periodically show the decoded prediction next to the sorted target.
    if i % 2000 == 0:
        n_pred = y_pred.detach().numpy()
        n_y = y.transpose(0, 1).detach().numpy()
        print('predicted value', n_pred.argmax(axis=-1))
        print('ground truth   ', n_y.argmax(axis=-1).ravel())

predicted value [7 7 7 7]
ground truth    [0 2 3 8]
predicted value [0 6 6 6]
ground truth    [1 4 6 6]
predicted value [0 1 6 6]
ground truth    [0 1 2 6]
predicted value [0 6 6 9]
ground truth    [5 6 6 7]
predicted value [0 4 4 8]
ground truth    [0 4 4 8]
predicted value [0 3 4 7]
ground truth    [0 3 4 7]
predicted value [2 5 5 8]
ground truth    [2 2 5 8]


KeyboardInterrupt: 