In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam

# import lightning as L  # NOTE: must be enabled (package installed) -- the LSTM class below subclasses L.LightningModule
# from torch.utils.data import TensorDataset, DataLoader

In [None]:
class LSTMModel(nn.Module):
    """Stacked ``nn.LSTM`` followed by a linear head applied to the last time step.

    Args:
        input_size: number of features per time step.
        hidden_size: size of each LSTM layer's hidden state.
        num_layers: number of stacked LSTM layers.
        output_size: number of outputs produced by the linear head.
    """

    def __init__(self, input_size, hidden_size, num_layers, output_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Encode ``x`` with the LSTM and project the final time step.

        Args:
            x: tensor of shape (batch, seq_len, input_size) -- the LSTM is ``batch_first``.

        Returns:
            Tensor of shape (batch, output_size).
        """
        # Build the zero initial states on x's own device and dtype so the model works
        # on CPU/GPU and in float64 without depending on an external ``device`` global.
        state_shape = (self.num_layers, x.size(0), self.hidden_size)
        h0 = x.new_zeros(state_shape)
        c0 = x.new_zeros(state_shape)
        out, _ = self.lstm(x, (h0, c0))
        # Only the last time step feeds the linear head.
        return self.fc(out[:, -1, :])

In [None]:
class LSTM(L.LightningModule):
    def __init__(self):
        """Create the scalar weights and biases of a hand-written, single-unit LSTM.

        Each gate has two scalar weights and one scalar bias. The ``*1`` weight
        multiplies the short-term memory and the ``*2`` weight multiplies the input
        (see ``lstm_unit``):

        * ``wlr*/blr1``: fraction of long-term memory to remember (forget gate)
        * ``wpr*/bpr1``: fraction of the potential memory to add (input gate)
        * ``wp*/bp1``:   potential (candidate) memory
        * ``wo*/bo1``:   fraction of memory to output (output gate)

        Weights are drawn from N(0, 1); biases start at 0.
        """
        super().__init__()
        mean = torch.tensor(0.0)
        std = torch.tensor(1.0)

        def new_weight():
            # One trainable scalar drawn from N(mean, std).
            return nn.Parameter(torch.normal(mean=mean, std=std), requires_grad=True)

        def new_bias():
            # One trainable scalar initialised to zero.
            return nn.Parameter(torch.tensor(0.), requires_grad=True)

        # Assignment order is kept deliberate: it fixes both the order of the random
        # draws and the order in which parameters are registered (self.parameters()).
        self.wlr1 = new_weight()
        self.wlr2 = new_weight()
        self.blr1 = new_bias()

        self.wpr1 = new_weight()
        self.wpr2 = new_weight()
        self.bpr1 = new_bias()

        self.wp1 = new_weight()
        self.wp2 = new_weight()
        self.bp1 = new_bias()

        self.wo1 = new_weight()
        self.wo2 = new_weight()
        self.bo1 = new_bias()
    
    def lstm_unit(self, input_value, long_memory, short_memory):
        """Advance the hand-written LSTM cell by one time step.

        Args:
            input_value: the current input value.
            long_memory: the previous long-term memory (cell state).
            short_memory: the previous short-term memory (hidden state).

        Returns:
            A two-element list ``[new_long_memory, new_short_memory]``.
        """
        def affine(w_short, w_input, bias):
            # Pre-activation shared by every gate: short*w1 + input*w2 + b.
            return (short_memory * w_short) + (input_value * w_input) + bias

        # How much of the old long-term memory survives (forget gate).
        keep_fraction = torch.sigmoid(affine(self.wlr1, self.wlr2, self.blr1))

        # Candidate memory, and how much of it gets written (input gate).
        write_fraction = torch.sigmoid(affine(self.wpr1, self.wpr2, self.bpr1))
        candidate = torch.tanh(affine(self.wp1, self.wp2, self.bp1))

        new_long = (long_memory * keep_fraction) + (candidate * write_fraction)

        # How much of the squashed long-term memory is exposed (output gate).
        expose_fraction = torch.sigmoid(affine(self.wo1, self.wo2, self.bo1))
        new_short = torch.tanh(new_long) * expose_fraction

        return [new_long, new_short]

    
    def forward(self, input):
        """Unroll the LSTM unit over ``input`` and return the final short-term memory.

        Steps are taken along dimension 0 (``input[t]``). ``training_step`` passes
        ``input_i[0]``, which is presumably one 1-D sequence of shape (seq_len,) --
        TODO confirm. Any trailing dimensions are carried through element-wise.

        Returns:
            The short-term memory after the last step; zeros of shape
            ``input.shape[1:]`` when the sequence is empty.
        """
        # Both memories start at zero, on the input's device and dtype, shaped like a
        # single step so the element-wise gate math lines up.
        long_memory = input.new_zeros(input.shape[1:])
        short_memory = input.new_zeros(input.shape[1:])

        for step_value in input:
            long_memory, short_memory = self.lstm_unit(step_value, long_memory, short_memory)

        return short_memory
    
    def configure_optimizers(self):
        """Lightning hook: optimise every registered parameter with Adam at its default settings."""
        trainable = self.parameters()
        optimizer = Adam(params=trainable)
        return optimizer
        
    
    def training_step(self, batch, batch_idx):
        # Calculate loss and log training progress
        input_i, label_i = batch
        output_i = self.forward(input_i[0])
        loss = (output_i - label_i)**2

        self.log("train_loss", loss)