# Imports

In [1]:
import numpy as np
import torch
import torch.nn as nn

In [2]:
%load_ext lab_black

# Simple LSTM: NumPy reimplementation checked against PyTorch

In [3]:
# Seed NumPy (random inputs) and PyTorch (layer weight initialisation) so that
# Restart & Run All reproduces every printed number in this notebook.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

# Problem sizes shared by every cell below.
input_sz = 20
hidden_sz = 100
batch_sz = 16
seq_len = 50
num_layers = 1

In [4]:
# Random float32 input laid out as (seq_len, batch, features), plus initial
# hidden and cell states laid out as (num_layers, batch, hidden).
X = np.random.standard_normal((seq_len, batch_sz, input_sz)).astype(np.float32)
h0 = np.random.standard_normal((num_layers, batch_sz, hidden_sz)).astype(np.float32)
c0 = np.random.standard_normal((num_layers, batch_sz, hidden_sz)).astype(np.float32)
X.shape, h0.shape, c0.shape

((50, 16, 20), (1, 16, 100), (1, 16, 100))

In [6]:
# nn.LSTMCell stacks its gate projections, so each weight/bias has a leading
# dimension of 4 * hidden_sz (400 here).
pytorch_cell = nn.LSTMCell(input_sz, hidden_sz, bias=True)
print("Weights shape:")
print(pytorch_cell.weight_ih.shape, pytorch_cell.weight_hh.shape, sep="\n")
print("\nBiases shape:")
print(pytorch_cell.bias_ih.shape, pytorch_cell.bias_hh.shape, sep="\n")

Weights shape:
torch.Size([400, 20])
torch.Size([400, 100])

Biases shape:
torch.Size([400])
torch.Size([400])


In [7]:
pytorch_lstm = nn.LSTM(
    input_size=input_sz,
    hidden_size=hidden_sz,
    num_layers=num_layers,
    bias=True,
    batch_first=False,  # inputs are fed as (seq_len, batch, features)
)
print("Shapes:")
print([tensor.shape for tensor in pytorch_lstm.parameters()])

Shapes:
[torch.Size([400, 20]), torch.Size([400, 100]), torch.Size([400]), torch.Size([400])]


In [10]:
# Forward pass without passing (h0, c0); inspect output and final-state shapes.
out, (h_t, c_t) = pytorch_lstm(torch.from_numpy(X))
(out.shape, h_t.shape, c_t.shape)

(torch.Size([50, 16, 100]), torch.Size([1, 16, 100]), torch.Size([1, 16, 100]))

In [11]:
def lstm_cell(x, h, c, w_ih, w_hh, b):
    """Advance an LSTM by a single time step.

    The fused pre-activation ``x @ w_ih + h @ w_hh + b`` is cut along axis 1
    into four equal chunks, taken in (input, forget, candidate, output) order,
    which is the layout this notebook relies on to match PyTorch's weights.

    Returns
    -------
    (h_next, c_next) : the updated hidden and cell states.
    """
    pre_activation = x @ w_ih + h @ w_hh + b
    in_pre, forget_pre, cand_pre, out_pre = np.split(pre_activation, 4, axis=1)
    c_next = sigmoid(forget_pre) * c + sigmoid(in_pre) * np.tanh(cand_pre)
    h_next = sigmoid(out_pre) * np.tanh(c_next)
    return h_next, c_next

In [12]:
def lstm(x, h, c, w_ih, w_hh, b):
    """Run a single-layer LSTM over a whole sequence with NumPy.

    Parameters
    ----------
    x : ndarray, shape (seq_len, batch, input_sz)
        Input sequence; step ``t`` is ``x[t]``.
    h, c : ndarray, shape (batch, hidden)
        Initial hidden and cell states.
    w_ih, w_hh : ndarray
        Weights laid out for ``x @ w_ih`` / ``h @ w_hh``, i.e. shapes
        (input_sz, 4*hidden) and (hidden, 4*hidden). In this notebook they are
        PyTorch's ``weight_ih_l0`` / ``weight_hh_l0`` transposed.
    b : ndarray, shape (4*hidden,)
        Single combined bias (``bias_ih + bias_hh`` in this notebook).

    Returns
    -------
    H : ndarray, shape (seq_len, batch, hidden)
        Hidden state after every time step.
    (h, c) : tuple of ndarray
        Final hidden and cell states.
    """
    seq_len, batch_sz = x.shape[0], x.shape[1]
    hidden_sz = h.shape[-1]
    # Allocate H in the dtype the cell actually produces (float32 for float32
    # inputs). np.zeros' float64 default used to silently upcast the outputs so
    # that H disagreed in dtype with the returned h and c. The float32 floor
    # keeps integer inputs from truncating the sigmoid/tanh results.
    out_dtype = np.result_type(x, h, c, w_ih, w_hh, b, np.float32)
    H = np.zeros((seq_len, batch_sz, hidden_sz), dtype=out_dtype)
    for t in range(seq_len):
        h, c = lstm_cell(x[t], h, c, w_ih, w_hh, b)
        H[t] = h
    return H, (h, c)

In [15]:
def sigmoid(x):
    """Elementwise logistic function 1 / (1 + exp(-x)), computed stably.

    The naive form evaluates ``np.exp(-x)``, which overflows for large negative
    ``x`` (below roughly -88 in float32). The result is still 0, but NumPy
    emits RuntimeWarnings, or raises under ``np.errstate(over="raise")``.
    Working with exp(-|x|) keeps the exponent <= 0, so nothing overflows.

    Float array inputs keep their dtype; a scalar input gives a NumPy scalar.
    """
    x = np.asarray(x)
    z = np.exp(-np.abs(x))
    # x >= 0: 1 / (1 + e^-x);  x < 0: e^x / (1 + e^x) -- the same value.
    return np.where(x >= 0, 1 / (1 + z), z / (1 + z))[()]

In [16]:
# Reference run started from the same random (h0, c0) that the NumPy
# implementation receives below.
H_pytorch, (h_pytorch, c_pytorch) = pytorch_lstm(
    torch.from_numpy(X), tuple(torch.from_numpy(state) for state in (h0, c0))
)

In [17]:
# NumPy forward pass reusing PyTorch's parameters: transpose the weights so the
# cell can compute x @ W, fold the two bias vectors into one, and drop the
# leading num_layers axis of the initial states (a single layer here).
H, (h, c) = lstm(
    X,
    h0[0],
    c0[0],
    pytorch_lstm.weight_ih_l0.detach().T.numpy(),
    pytorch_lstm.weight_hh_l0.detach().T.numpy(),
    pytorch_lstm.bias_ih_l0.detach().numpy() + pytorch_lstm.bias_hh_l0.detach().numpy(),
)

In [18]:
# Report the error norms, then turn the eyeballed comparison into a hard check.
# The tolerances leave room for float32 rounding differences between the NumPy
# and PyTorch matmuls. h_pytorch / c_pytorch keep PyTorch's leading num_layers
# axis (size 1 here), hence the [0] before comparing element-wise.
print(
    np.linalg.norm(H_pytorch.detach().numpy() - H),
    np.linalg.norm(h_pytorch.detach().numpy() - h),
    np.linalg.norm(c_pytorch.detach().numpy() - c),
)
np.testing.assert_allclose(H, H_pytorch.detach().numpy(), rtol=1e-5, atol=1e-5)
np.testing.assert_allclose(h, h_pytorch.detach().numpy()[0], rtol=1e-5, atol=1e-5)
np.testing.assert_allclose(c, c_pytorch.detach().numpy()[0], rtol=1e-5, atol=1e-5)

3.0083196651888068e-06 4.0780324e-07 7.0195557e-07


In [None]:
def train_lstm(x, h, c):
    for i in range(num_layers):
        H, (h, c) = lstm(x, h, c, parameters[i])