In [18]:
import torch
import torch.nn as nn
import wandb
import os
from dotenv import load_dotenv

from LSTM import LSTM

# load_dotenv() (python-dotenv) runs first so that values it loads are visible
# to the environment lookup below; the key is None when the variable is unset.
load_dotenv()
WANDB_API_KEY = os.environ.get('WANDB_API_KEY')


In [42]:
# Demo: feed a sequence through nn.LSTM one time step at a time, carrying the
# (hidden state, cell state) pair forward manually between calls.
torch.manual_seed(0)  # fixed seed so the weights and random inputs are reproducible

seq_len = 5
input_size = 4
hidden_size = 3

model = nn.LSTM(input_size=input_size, hidden_size=hidden_size, num_layers=1)

# One tensor of shape (1, input_size) per time step.
seq = [torch.randn(1, input_size) for _ in range(seq_len)]
# Initial (h_0, c_0), each of shape (num_layers=1, batch=1, hidden_size).
init_hidden = (torch.randn(1, 1, hidden_size),
               torch.randn(1, 1, hidden_size))

hidden = init_hidden

# Iterating over each element in the sequence one-by-one
for element in seq:

    # Note on view vs reshape: https://stackoverflow.com/a/49644300
    # Shapes are checked against the config constants (not literals) so the
    # cell keeps working when input_size / hidden_size are changed.
    assert element.shape == (1, input_size)
    step_input = element.view(1, 1, -1)  # (seq_len=1, batch=1, input_size)
    assert step_input.shape == (1, 1, input_size)
    out, hidden = model(step_input, hidden)

    # out: output features from the last layer of the LSTM for this step
    assert out.shape == (1, 1, hidden_size)

    # hidden is a (hidden state, cell state) pair, each with the same shape
    assert len(hidden) == 2
    assert hidden[0].shape == (1, 1, hidden_size)
    assert hidden[1].shape == (1, 1, hidden_size)

    # With a single layer and a single time step, the output is the new
    # hidden state itself (compare the printed values below).
    assert torch.allclose(out, hidden[0])

    print("out:", out.detach())
    print("hiddens:", hidden[0].detach(), hidden[1].detach())
    print()



(tensor([[[0.7680, 0.0571, 0.2240]]]), tensor([[[ 0.5520, -0.5788,  0.0177]]]))
out: tensor([[[-0.0298, -0.0681,  0.1012]]])
hiddens: tensor([[[-0.0298, -0.0681,  0.1012]]]) tensor([[[-0.0844, -0.1968,  0.4289]]])

out: tensor([[[-0.0616, -0.0777,  0.0146]]])
hiddens: tensor([[[-0.0616, -0.0777,  0.0146]]]) tensor([[[-0.1553, -0.2892,  0.0896]]])

out: tensor([[[-0.0618, -0.0858, -0.0177]]])
hiddens: tensor([[[-0.0618, -0.0858, -0.0177]]]) tensor([[[-0.0883, -0.6379, -0.3502]]])

out: tensor([[[ 0.0625,  0.0239, -0.2697]]])
hiddens: tensor([[[ 0.0625,  0.0239, -0.2697]]]) tensor([[[ 0.0711,  0.1821, -0.5060]]])

out: tensor([[[-0.1262,  0.1365, -0.0766]]])
hiddens: tensor([[[-0.1262,  0.1365, -0.0766]]]) tensor([[[-0.4086,  0.5347, -0.2022]]])



In [48]:
# Passing the whole sequence at once: a single forward call replaces the
# per-step loop and returns the output for every time step.
stacked = torch.cat(seq)
# Shapes are checked against the config constants rather than literals so the
# cell keeps working when seq_len / input_size are changed.
assert stacked.shape == (seq_len, input_size)
inputs = stacked.view(seq_len, 1, -1)  # (seq_len, batch=1, input_size)
assert inputs.shape == (seq_len, 1, input_size)
hidden = init_hidden  # restart from the initial state, not a previously returned one

out, hidden = model(inputs, hidden)

# out: last-layer output for every time step
assert out.shape == (seq_len, 1, hidden_size)

# hidden: the final (hidden state, cell state) pair
assert len(hidden) == 2
assert hidden[0].shape == (1, 1, hidden_size)
assert hidden[1].shape == (1, 1, hidden_size)

# For this single-layer model the last row of out is the final hidden state.
assert torch.allclose(out[-1], hidden[0][0])

print("out:", out.detach())
print("hiddens:", hidden[0].detach(), hidden[1].detach())


out: tensor([[[-0.0298, -0.0681,  0.1012]],

        [[-0.0616, -0.0777,  0.0146]],

        [[-0.0618, -0.0858, -0.0177]],

        [[ 0.0625,  0.0239, -0.2697]],

        [[-0.1262,  0.1365, -0.0766]]])
hiddens: tensor([[[-0.1262,  0.1365, -0.0766]]]) tensor([[[-0.4086,  0.5347, -0.2022]]])
