In [1]:
import torch
import torch.nn as nn

In [87]:
torch.manual_seed(42)

<torch._C.Generator at 0x79179b4d1810>

In [88]:
class SimpleLSTM(nn.Module):
  def __init__(self, input_size: int = 512, hidden_size: int = 64, bias: bool = False):
    super(SimpleLSTM, self).__init__()
    self.input_size = input_size
    self.hidden_size = hidden_size
    self.bias = bias

    self.W_f = nn.Linear(in_features=input_size+hidden_size, out_features=hidden_size, bias = bias)
    self.W_i = nn.Linear(in_features=input_size+hidden_size, out_features=hidden_size, bias = bias)
    self.W_o = nn.Linear(in_features=input_size+hidden_size, out_features=hidden_size, bias = bias)
    self.W_c = nn.Linear(in_features=input_size+hidden_size, out_features=hidden_size, bias = bias)

  def forward(self, x: torch.Tensor):
    if isinstance(x, torch.Tensor):
      h_t = torch.zeros((x.size(0), self.hidden_size))
      c_t = torch.zeros((x.size(0), self.hidden_size))

      for sequence in range(x.size(1)):
        combined = torch.cat((x[:, sequence, :], h_t), dim = 1)

        f_t = torch.sigmoid(self.W_f(combined))
        i_t = torch.sigmoid(self.W_i(combined))
        o_t = torch.sigmoid(self.W_o(combined))
        c_tilde_t = torch.tanh(self.W_c(combined))

        c_t = (f_t * c_t) + (i_t * c_tilde_t)
        h_t = o_t * torch.tanh(c_t)

      return h_t, c_t

    else:
      raise ValueError("Input must be a torch.Tensor".capitalize())

In [89]:
x = torch.randint(0, 128, (64, 128, 512))

In [90]:
x.size()

torch.Size([64, 128, 512])

In [91]:
lstm = SimpleLSTM(
    input_size = 512,
    hidden_size = 32,
    bias = False
)

In [92]:
hidden, final = lstm(x)

In [93]:
hidden.size()

torch.Size([64, 32])

In [94]:
final.size()

torch.Size([64, 32])

In [95]:
lstm(x)

(tensor([[-1.1340e-10, -7.4174e-01,  5.8275e-16,  ..., -0.0000e+00,
           1.9973e-14, -1.3326e-11],
         [-2.0052e-06, -9.9915e-01,  8.0094e-04,  ..., -4.7934e-25,
           2.4143e-13, -1.0000e+00],
         [ 3.0224e-17, -7.6159e-01,  1.9764e-19,  ..., -2.2767e-18,
           9.6616e-01, -1.0000e+00],
         ...,
         [ 1.7695e-09, -3.3399e-01, -1.8983e-20,  ..., -2.6970e-21,
           8.0304e-16, -1.0000e+00],
         [-1.7508e-01, -9.1771e-01,  6.6451e-10,  ..., -1.3309e-36,
           2.5826e-03, -2.2654e-01],
         [-3.0249e-01, -7.5947e-01,  2.3873e-11,  ..., -1.8556e-32,
           9.5886e-10, -4.9908e-11]], grad_fn=<MulBackward0>),
 tensor([[-3.3853e-06, -9.5434e-01,  1.9263e+01,  ..., -6.9183e+00,
           1.7156e+01, -2.5827e+01],
         [-9.8162e-01, -3.8802e+00,  2.9014e+00,  ..., -7.2786e+00,
           3.9862e+00, -3.7735e+01],
         [ 5.7069e-10, -1.0000e+00,  1.1080e+01,  ..., -1.3195e+01,
           2.0311e+00, -2.9857e+01],
         ...,
 

In [96]:
torch.manual_seed(42)

<torch._C.Generator at 0x79179b4d1810>

In [97]:
class LSTMCells(nn.Module):
  def __init__(self, input_size: int = 512, hidden_size: int =  32):
    super(LSTMCells, self).__init__()
    self.input_size = input_size
    self.hidden_size = hidden_size

    self.W_f = nn.Linear(in_features=input_size+hidden_size, out_features=hidden_size)
    self.W_i = nn.Linear(in_features=input_size+hidden_size, out_features=hidden_size)
    self.W_o = nn.Linear(in_features=input_size+hidden_size, out_features=hidden_size)
    self.W_c = nn.Linear(in_features=input_size+hidden_size, out_features=hidden_size)

  def forward(self, combined_h_t_and_x_t: torch.Tensor, c_t):
    if isinstance(combined_h_t_and_x_t, torch.Tensor):
      f_t = torch.sigmoid(self.W_f(combined_h_t_and_x_t))
      i_t = torch.sigmoid(self.W_i(combined_h_t_and_x_t))
      o_t = torch.sigmoid(self.W_o(combined_h_t_and_x_t))
      c_tilde_t = torch.tanh(self.W_c(combined_h_t_and_x_t))

      c_t = (f_t * c_t) + (i_t * c_tilde_t)
      h_t = o_t * torch.tanh(c_t)

      return h_t, c_t

    else:
      raise ValueError("Input must be a torch.Tensor".capitalize())

In [98]:
class SimpleLSTMScratch(nn.Module):
  def __init__(self, input_size: int = 512, hidden_size: int = 32, bias: bool = False):
    super(SimpleLSTMScratch, self).__init__()
    self.input_size = input_size
    self.hidden_size = hidden_size
    self.bias = bias

    self.lstm_cell = LSTMCells(
        input_size = input_size,
        hidden_size = hidden_size
    )

  def forward(self, x: torch.Tensor):
    if isinstance(x, torch.Tensor):
      batch_size, sequence_size, dimension = x.size()

      h_t = torch.zeros((batch_size, self.hidden_size))
      c_t = torch.zeros((batch_size, self.hidden_size))

      for sequence in range(sequence_size):
        combined_h_t_and_x_t = torch.cat((x[:, sequence, :], h_t), dim = 1)
        h_t, c_t = self.lstm_cell(combined_h_t_and_x_t, c_t)

        return (h_t, c_t)

    else:
      raise ValueError("Input must be a torch.Tensor".capitalize())

In [99]:
lstm = SimpleLSTMScratch(
    input_size = 512,
    hidden_size = 32,
    bias = False
)

In [100]:
lstm(x)

(tensor([[-7.6159e-01, -1.0474e-25,  1.9339e-23,  ...,  2.7926e-33,
           7.6159e-01,  3.8541e-01],
         [-4.9133e-11, -1.3033e-11,  1.5515e-32,  ...,  1.8470e-07,
           1.1286e-10,  7.6159e-01],
         [-7.6159e-01, -1.1087e-07,  8.1617e-21,  ...,  4.9214e-21,
           7.6159e-01,  7.6159e-01],
         ...,
         [-7.6065e-01, -5.6067e-10, -7.8415e-18,  ..., -2.7846e-08,
           7.6157e-01,  7.4912e-01],
         [-4.5788e-01, -1.5258e-13, -3.1972e-29,  ..., -4.1658e-01,
           7.5800e-01,  6.9344e-01],
         [-7.6155e-01, -7.0166e-01,  5.5753e-16,  ...,  7.1014e-23,
           2.2679e-07,  2.0069e-20]], grad_fn=<MulBackward0>),
 tensor([[-1.0000e+00, -1.0484e-25,  1.9339e-23,  ...,  1.0000e+00,
           1.0000e+00,  4.0640e-01],
         [-4.9133e-11, -4.0002e-01,  1.5515e-32,  ...,  1.0000e+00,
           1.1286e-10,  1.0000e+00],
         [-1.0000e+00, -1.1087e-07,  2.1341e-15,  ...,  1.0000e+00,
           1.0000e+00,  1.0000e+00],
         ...,
 

In [101]:
import torch

# Set a fixed random seed
torch.manual_seed(42)

# Generate a random tensor
tensor = torch.randn(3, 3)
print(tensor)


tensor([[ 0.3367,  0.1288,  0.2345],
        [ 0.2303, -1.1229, -0.1863],
        [ 2.2082, -0.6380,  0.4617]])


In [102]:
# Generate a random tensor
tensor = torch.randn(3, 3)
print(tensor)

tensor([[ 0.2674,  0.5349,  0.8094],
        [ 1.1103, -1.6898, -0.9890],
        [ 0.9580,  1.3221,  0.8172]])
