In [1]:
import torch
from torch import nn

In [2]:
# Two independent 3x3 matrices of random values for the element-wise demo below.
tensor1 = torch.rand(3, 3)
tensor2 = torch.rand(3, 3)

In [3]:
# Display the first random matrix.
tensor1

tensor([[0.3560, 0.9021, 0.0221],
        [0.3640, 0.4830, 0.8612],
        [0.0660, 0.9377, 0.4393]])

In [4]:
# Display the second random matrix.
tensor2

tensor([[0.6172, 0.1439, 0.9534],
        [0.6475, 0.3531, 0.9091],
        [0.9373, 0.5034, 0.1555]])

In [5]:
# `*` multiplies the tensors element by element (Hadamard product); it is not a
# matrix product. The LSTM cell defined below relies on this same behaviour.
tensor1*tensor2

tensor([[0.2197, 0.1298, 0.0210],
        [0.2357, 0.1705, 0.7829],
        [0.0619, 0.4720, 0.0683]])

#### LTR => Long-Term memory to Remember (the fraction of the long-term memory that is kept)
#### STM => Short-Term Memory
#### LTM => Long-Term Memory

In [6]:
class NDIM_LSTM(nn.Module):
    """LSTM-style cell whose weights are applied element-wise.

    Each weight is a tensor of shape ``size`` that is multiplied element by
    element with the input or the short-term memory (no matrix products are
    involved). The four biases ``b1``..``b4`` are scalar zeros registered as
    parameters with ``requires_grad=False``.

    Abbreviations used in the attribute names:
        ltr - long-term memory to remember
        stm - short-term memory
        ltm - long-term memory
    """

    def __init__(self, size):
        """Create the eight weight tensors and four scalar biases.

        Args:
            size: shape passed to ``torch.empty`` for every weight tensor.
        """
        super().__init__()

        def gate_weight():
            # Fresh trainable tensor of shape `size`, drawn from N(0, 1).
            return nn.Parameter(torch.empty(size).normal_(mean=0.0, std=1.0),
                                requires_grad=True)

        def frozen_bias():
            # Scalar zero registered as a parameter but excluded from gradients.
            return nn.Parameter(torch.tensor(0.), requires_grad=False)

        # NOTE: the creation order below fixes both the parameter registration
        # order and the order in which random numbers are consumed.

        # Stage 1: share of the existing long-term memory to remember.
        self.percent_ltr_input = gate_weight()
        self.percent_ltr_stm_wt = gate_weight()
        self.b1 = frozen_bias()

        # Stage 2a: share of the potential long-term memory to add.
        self.percent_potential_ltm_stm_wt = gate_weight()
        self.percent_potential_ltm_input = gate_weight()
        self.b2 = frozen_bias()

        # Stage 2b: the potential long-term memory itself.
        self.potential_ltm_stm_wt = gate_weight()
        self.potential_ltm_input = gate_weight()
        self.b3 = frozen_bias()

        # Stage 3: share of the squashed long-term memory exposed as output.
        self.output_stm_contri_stm_wt = gate_weight()
        self.output_stm_contri_input = gate_weight()
        self.b4 = frozen_bias()

    def lstm_unit(self, input_value, long_memory, short_memory):
        """Advance the cell by one step.

        Args:
            input_value: input for the current step.
            long_memory: long-term memory carried in from the previous step.
            short_memory: short-term memory carried in from the previous step.

        Returns:
            A two-element list ``[updated_long_memory, updated_short_memory]``.
        """
        # Fraction of the incoming long-term memory that survives this step.
        keep_pre = (input_value * self.percent_ltr_input) + (self.percent_ltr_stm_wt * short_memory)
        keep_fraction = torch.sigmoid(keep_pre + self.b1)

        # Fraction of the newly proposed memory that is written.
        write_pre = (input_value * self.percent_potential_ltm_input) + (short_memory * self.percent_potential_ltm_stm_wt)
        write_fraction = torch.sigmoid(write_pre + self.b2)

        # Newly proposed memory content, squashed by tanh.
        proposal_pre = (short_memory * self.potential_ltm_stm_wt) + (input_value * self.potential_ltm_input)
        proposal = torch.tanh(proposal_pre + self.b3)

        # New long-term memory: retained old content plus admitted proposal.
        retained = long_memory * keep_fraction
        admitted = write_fraction * proposal
        next_long = retained + admitted

        # Fraction of the squashed long-term memory that becomes the output.
        expose_pre = (short_memory * self.output_stm_contri_stm_wt) + (input_value * self.output_stm_contri_input)
        expose_fraction = torch.sigmoid(expose_pre + self.b4)

        next_short = torch.tanh(next_long) * expose_fraction

        return [next_long, next_short]

    def forward(self, input, long_memory=0, short_memory=0):
        """Run a single step; both memories default to the scalar 0.

        Returns:
            The ``[updated_long_memory, updated_short_memory]`` list produced
            by ``lstm_unit``.
        """
        return self.lstm_unit(input, long_memory, short_memory)

In [7]:
# Random test batch; the loop further down iterates over the leading axis (length 5).
input_test = torch.rand(5, 64, 64, 64)

In [10]:
# Weight shape matches a single slice of `input_test` along its first axis.
model = NDIM_LSTM(size=(64, 64, 64))

In [11]:
# Start from zero memories and feed the slices of `input_test` one at a time,
# carrying the long- and short-term memory from each step into the next.
ltm, stm = 0, 0
for i in input_test:
    ltm, stm = model(i, ltm, stm)
    print(stm.shape, ltm.shape)

torch.Size([64, 64, 64]) torch.Size([64, 64, 64])
torch.Size([64, 64, 64]) torch.Size([64, 64, 64])
torch.Size([64, 64, 64]) torch.Size([64, 64, 64])
torch.Size([64, 64, 64]) torch.Size([64, 64, 64])
torch.Size([64, 64, 64]) torch.Size([64, 64, 64])


In [12]:
# Inspect the long-term memory left over after the final loop iteration.
ltm

tensor([[[ 4.5520e-01,  1.4852e-02,  6.3607e-02,  ...,  2.6028e-01,
           7.8754e-01, -4.1256e-01],
         [ 2.4917e-02,  2.5937e-01,  6.7378e-02,  ...,  9.8078e-02,
          -4.6319e-02,  4.2936e-02],
         [ 4.5282e-02,  3.0918e-01, -5.2139e-01,  ..., -3.4812e-04,
          -1.4359e+00, -2.5025e-01],
         ...,
         [ 1.5711e-01, -1.7145e-01,  1.0855e-01,  ...,  5.7627e-01,
          -5.1699e-01,  1.1921e+00],
         [-3.9778e-01,  3.9125e-01,  1.5974e-01,  ..., -6.1198e-02,
          -4.6742e-01, -1.1781e+00],
         [ 5.5156e-01,  5.9117e-01, -9.8318e-02,  ..., -1.3501e-01,
           2.2698e-01,  2.6609e-01]],

        [[-3.0617e-01,  1.1225e+00, -6.4980e-01,  ..., -1.0952e+00,
           1.6681e-01, -2.6568e-01],
         [-2.7932e-01, -1.4630e-01, -5.8973e-02,  ...,  7.0725e-01,
           3.6745e-01, -9.5850e-01],
         [ 2.9291e-01,  4.0856e-01,  1.0194e-01,  ...,  2.4314e-01,
           2.8869e-01, -2.7344e-01],
         ...,
         [ 2.2590e-01, -4