In [None]:
!pip install torch
!pip install lightning

In [3]:
import torch 
import torch.nn as nn
from torch.optim import Adam
import numpy as np
import lightning as L

In [4]:
from torch.utils.data import DataLoader, TensorDataset

In [13]:
class LSTMbyHand(L.LightningModule):
    def __init__(self):
        super().__init__()

        norm_dist = torch.randn(1)
        zero_tensor = torch.tensor(0).float()

        self.wfs = nn.Parameter(norm_dist, requires_grad=True)
        self.wfi = nn.Parameter(norm_dist, requires_grad=True)
        self.bf = nn.Parameter(zero_tensor, requires_grad=True)

        self.wis1 = nn.Parameter(norm_dist, requires_grad=True)
        self.wii1 = nn.Parameter(norm_dist, requires_grad=True)
        self.bi1 = nn.Parameter(zero_tensor, requires_grad=True)
        
        self.wis2 = nn.Parameter(norm_dist, requires_grad=True)
        self.wii2 = nn.Parameter(norm_dist, requires_grad=True)
        self.bi2 = nn.Parameter(zero_tensor, requires_grad=True)

        self.wos = nn.Parameter(norm_dist, requires_grad=True)
        self.woi = nn.Parameter(norm_dist, requires_grad=True)
        self.bo = nn.Parameter(zero_tensor, requires_grad=True)

    def module(self, inputs, long_memory, short_memory):

        per_long_memory = torch.sigmoid(self.wfs*short_memory + self.wfi*inputs + self.bf)
        prev_long_memory = long_memory * per_long_memory

        perc_long_memory = torch.sigmoid(self.wis1*short_memory + self.wii1*inputs + self.bi1)
        new_long_memory = torch.tanh(self.wis2*short_memory + self.wii2*inputs + self.bi2)

        combined_long_memory = prev_long_memory + (perc_long_memory * new_long_memory)

        new_short_memory = torch.sigmoid(self.wos*short_memory + self.woi*inputs + self.bo)
        combined_short_memory = torch.tanh(combined_long_memory) * new_short_memory

        return ([combined_long_memory, combined_short_memory])
    
    def forward(self, inputs):

        long_memeory = 0
        short_memory = 0
        for i in inputs:
            long_memory, short_memory = self.module(i, long_memeory, short_memory)

        return short_memory
    
    def configure_optimizers(self):

        return Adam(self.parameters())
    
    def training_step(self, batch, batch_idx):
        input_i, label_i = batch
        pred = self.forward(input_i[0])
        loss = (pred - label_i) **2

        self.log('train_loss', loss)

        if (label_i ==0):
            self.log('out_0', pred)
        else:
            self.log('out_1', pred)

        return loss

In [14]:
model = LSTMbyHand()
preds = model(torch.tensor([0, 0.5, 0.25, 1]).detach())
print(f"predictions: {preds}")

predictions: tensor([0.4021], grad_fn=<MulBackward0>)


In [15]:
inputs = torch.tensor([[0., .5, .25, 1.], [1., .5, .25, 1.]])
labels = torch.tensor([0., 1.])
dataset = TensorDataset(inputs, labels)
dataloader = DataLoader(dataset)

In [16]:
trainer = L.Trainer(max_epochs=10000)
trainer.fit(model, train_dataloaders=dataloader)

You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name         | Type | Params | Mode
---------------------------------------------
  | other params | n/a  | 12     | n/a 
---------------------------------------------
12        Trainable params
0         Non-trainable params
12        Total params
0.000     Total estimated model params size (MB)
0         Modules in train mode
0         Modules in eval mode
c:\Users\rahul\anaconda3\Lib\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.
c:\Users\rahul\anaconda3\Lib\site-packages\lightning\pytorch\loops\fit_loop.py:310: T

Training: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=10000` reached.


In [17]:
predictions = model(inputs[0].detach())
predictions

tensor([0.5597], grad_fn=<MulBackward0>)

In [19]:
predictions = model(inputs[1].detach())
predictions

tensor([0.7562], grad_fn=<MulBackward0>)