# III Neural Network Module
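A simple single-layer LSTM that predicts the next value of a univariate sequence from a sliding window of the preceding values.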

In [None]:
import torch
import torch.nn as nn
import numpy as np


class SimpleLSTM(nn.Module):
    """Single-layer LSTM followed by a linear read-out layer.

    The module is stateful: ``hidden_cell`` holds the ``(h, c)`` pair handed
    to the LSTM on the next ``forward`` call and is overwritten with the
    LSTM's final state afterwards. To start a new, independent sequence,
    either assign a fresh tuple to ``hidden_cell`` or call
    ``reset_hidden_state()``.

    Args:
        input_size: Number of features per time step.
        hidden_layer_size: Width of the LSTM hidden and cell state.
        output_size: Number of values produced by the linear layer.
    """

    def __init__(self,
                 input_size: int = 1,
                 hidden_layer_size: int = 100,
                 output_size: int = 1):
        super().__init__()
        self.hidden_layer_size = hidden_layer_size
        self.lstm = nn.LSTM(input_size, hidden_layer_size)
        self.linear = nn.Linear(hidden_layer_size, output_size)
        self.reset_hidden_state()

    def reset_hidden_state(self) -> None:
        """Replace ``hidden_cell`` with an all-zero ``(h, c)`` pair.

        The zeros are created with the dtype and device of the LSTM weights
        so the state stays usable after the module has been moved or cast.
        """
        weights = self.lstm.weight_hh_l0
        # (num_layers=1, batch=1, hidden) -- forward() always uses batch 1.
        shape = (1, 1, self.hidden_layer_size)
        self.hidden_cell = (
            torch.zeros(shape, dtype=weights.dtype, device=weights.device),
            torch.zeros(shape, dtype=weights.dtype, device=weights.device),
        )

    def forward(self, input_seq: torch.Tensor) -> torch.Tensor:
        """Run one sequence through the LSTM and predict from its last step.

        Args:
            input_seq: Tensor whose first dimension is the sequence length;
                it is reshaped to ``(len(input_seq), 1, -1)``, i.e. a single
                batch element.

        Returns:
            The linear layer's output for the final time step, a tensor of
            shape ``(output_size,)``.
        """
        steps = len(input_seq)
        lstm_out, (h_n, c_n) = self.lstm(
            input_seq.view(steps, 1, -1), self.hidden_cell)
        # Keep the state values but drop their autograd history. Otherwise a
        # later forward/backward that reuses this state would try to
        # backpropagate through a graph that the previous backward() freed.
        self.hidden_cell = (h_n.detach(), c_n.detach())
        predictions = self.linear(lstm_out.view(steps, -1))
        return predictions[-1]

In [None]:
# Toy series: 100 evenly spaced samples covering [0, 1].
# Training examples are windows of `seq_length` consecutive samples.
seq_length = 10
data = np.linspace(start=0, stop=1, num=100)
data[:10]  # notebook display of the first ten samples

In [None]:
model = SimpleLSTM()
loss_fun = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Convert the series to float32 once and slice every (window, next value)
# pair up front: the pairs are identical in every epoch, so rebuilding the
# tensors inside the epoch loop is wasted work.
series = torch.as_tensor(data, dtype=torch.float32)
windows = [
    (series[i:i + seq_length].view(-1, 1, 1),  # (seq_length, 1, 1) input
     series[i + seq_length].view(1))           # (1,) value that follows it
    for i in range(len(series) - seq_length)
]

# Training loop
for epoch in range(100):  # number of epochs
    for i, (seq, target) in enumerate(windows):
        # Show the very first training example once as a sanity check
        if epoch == 0 and i == 0:
            print(f'{seq=}, {seq.shape=}\n{target=}, {target.shape=}\n')

        # Reset the gradient and hidden state so that every window is
        # treated as an independent sequence
        optimizer.zero_grad()
        model.hidden_cell = (torch.zeros(1, 1, model.hidden_layer_size),
                             torch.zeros(1, 1, model.hidden_layer_size))

        # Forward pass
        y_pred = model(seq)

        # Compute the loss, perform backward pass, and update weights
        loss = loss_fun(y_pred, target)
        loss.backward()
        optimizer.step()

    # Print loss every 10 epochs (this is the loss of the epoch's last
    # window, not an average over the epoch)
    if epoch % 10 == 0:
        print(f'Epoch {epoch} Loss: {loss.item()}')

In [None]:
# Inference: predict the value that follows the last `seq_length` samples.
with torch.no_grad():
    # Start from an all-zero (h, c) state, as during training.
    state_shape = (1, 1, model.hidden_layer_size)
    model.hidden_cell = (torch.zeros(state_shape), torch.zeros(state_shape))

    # The final window of the series, shaped (seq_length, 1, 1).
    seq = torch.as_tensor(data[-seq_length:], dtype=torch.float32).view(-1, 1, 1)

    pred = model(seq)
    print(f'Next value prediction for {seq}: {pred.item()}')