In [None]:
import torch, datasets, math
import torch.nn as nn

# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed torch's RNGs (CPU and all CUDA devices) so the random weight
# initialisation done by the cells below is reproducible across runs.
SEED = 0
torch.manual_seed(SEED)

print(device)

cpu


In [None]:
class RNN_cell(nn.Module):
  """Single-layer Elman RNN unrolled over a batch-first sequence.

  Recurrence per time step:
      h_t = tanh(h_{t-1} @ W_g + x_t @ U_g + b_g)

  Args:
    input_dim: size of the last axis of every input step x_t.
    hidden_dim: size of the hidden state h_t.
  """

  def __init__(self, input_dim, hidden_dim):
    super().__init__()

    self.hidden_dim = hidden_dim

    # torch.empty leaves memory uninitialised; init_weights() fills every parameter.
    self.U_g = nn.Parameter(torch.empty(input_dim, hidden_dim))   # input  -> hidden
    self.W_g = nn.Parameter(torch.empty(hidden_dim, hidden_dim))  # hidden -> hidden
    self.b_g = nn.Parameter(torch.empty(hidden_dim))

    self.init_weights()

  def init_weights(self):
    """Sample every parameter from U(-1/sqrt(hidden_dim), 1/sqrt(hidden_dim)).

    E.g. hidden_dim = 16 gives std = 0.25, so weights lie in [-0.25, 0.25].
    """
    std = 1 / math.sqrt(self.hidden_dim)
    with torch.no_grad():
      for weight in self.parameters():
        weight.uniform_(-std, std)

  def forward(self, x, init_state=None):
    """Run the recurrence over every time step of ``x``.

    Args:
      x: tensor indexed as (batch, seq_len, input_dim).
      init_state: optional initial hidden state of shape (batch, hidden_dim);
        zeros on x's device/dtype when omitted.

    Returns:
      (output, h_t): ``output`` holds the hidden state of every step as
      (batch, seq_len, hidden_dim); ``h_t`` is the final hidden state
      (batch, hidden_dim). For an empty sequence ``output`` has seq_len 0 and
      ``h_t`` is the initial state.
    """
    bs, seq_len, _ = x.shape

    if init_state is None:
      # Allocate directly on the target device instead of building on CPU and copying.
      h_t = torch.zeros(bs, self.hidden_dim, device=x.device, dtype=x.dtype)
    else:
      h_t = init_state

    outputs = []
    for t in range(seq_len):
      x_t = x[:, t, :]  # all sequences at time step t -> (bs, input_dim)
      h_t = torch.tanh(h_t @ self.W_g + x_t @ self.U_g + self.b_g)  # (bs, hidden)
      outputs.append(h_t)

    if not outputs:
      # torch.stack/torch.cat reject an empty list; return an empty time axis instead.
      return h_t.new_zeros(bs, 0, self.hidden_dim), h_t

    # Stacking on dim=1 yields (bs, seq_len, hidden) directly -- no transpose/contiguous.
    return torch.stack(outputs, dim=1), h_t

In [None]:
class LSTM_cell(nn.Module):
    """Single-layer LSTM unrolled over a batch-first sequence.

    Per time step (``*`` is element-wise):
        f_t = sigmoid(h_{t-1} @ W_f + x_t @ U_f + b_f)   # forget gate
        i_t = sigmoid(h_{t-1} @ W_i + x_t @ U_i + b_i)   # input gate
        o_t = sigmoid(h_{t-1} @ W_o + x_t @ U_o + b_o)   # output gate
        g_t = tanh(h_{t-1} @ W_g + x_t @ U_g + b_g)      # candidate cell values
        c_t = f_t * c_{t-1} + i_t * g_t
        h_t = o_t * tanh(c_t)

    Args:
        input_dim: size of the last axis of every input step x_t.
        hidden_dim: size of the hidden state h_t and cell state c_t.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()

        self.hidden_dim = hidden_dim

        # Trainable parameters; torch.empty is uninitialised until init_weights() runs.
        self.U_i = nn.Parameter(torch.empty(input_dim, hidden_dim))
        self.W_i = nn.Parameter(torch.empty(hidden_dim, hidden_dim))
        self.b_i = nn.Parameter(torch.empty(hidden_dim))

        self.U_f = nn.Parameter(torch.empty(input_dim, hidden_dim))
        self.W_f = nn.Parameter(torch.empty(hidden_dim, hidden_dim))
        self.b_f = nn.Parameter(torch.empty(hidden_dim))

        self.U_g = nn.Parameter(torch.empty(input_dim, hidden_dim))
        self.W_g = nn.Parameter(torch.empty(hidden_dim, hidden_dim))
        self.b_g = nn.Parameter(torch.empty(hidden_dim))

        self.U_o = nn.Parameter(torch.empty(input_dim, hidden_dim))
        self.W_o = nn.Parameter(torch.empty(hidden_dim, hidden_dim))
        self.b_o = nn.Parameter(torch.empty(hidden_dim))

        self.init_weights()

    def init_weights(self):
        """Sample every parameter from U(-1/sqrt(hidden_dim), 1/sqrt(hidden_dim)).

        E.g. hidden_dim = 16 gives std = 0.25, so weights lie in [-0.25, 0.25].
        """
        std = 1 / math.sqrt(self.hidden_dim)
        with torch.no_grad():
            for weight in self.parameters():
                weight.uniform_(-std, std)

    def forward(self, x, init_state=None):
        """Run the LSTM over every time step of ``x``.

        Args:
            x: tensor indexed as (batch, seq_len, input_dim).
            init_state: optional ``(h_0, c_0)`` pair, each (batch, hidden_dim);
                zeros on x's device/dtype when omitted.

        Returns:
            (output, (h_t, c_t)): ``output`` holds the hidden state of every step
            as (batch, seq_len, hidden_dim); ``h_t``/``c_t`` are the final hidden
            and cell states. For an empty sequence ``output`` has seq_len 0 and
            the initial states are returned unchanged.
        """
        bs, seq_len, _ = x.shape

        if init_state is None:
            # Allocate directly on the target device instead of building on CPU and copying.
            h_t = torch.zeros(bs, self.hidden_dim, device=x.device, dtype=x.dtype)
            c_t = torch.zeros(bs, self.hidden_dim, device=x.device, dtype=x.dtype)
        else:
            h_t, c_t = init_state

        outputs = []
        for t in range(seq_len):
            x_t = x[:, t, :]  # all sequences at time step t -> (bs, input_dim)

            f_t = torch.sigmoid(h_t @ self.W_f + x_t @ self.U_f + self.b_f)
            i_t = torch.sigmoid(h_t @ self.W_i + x_t @ self.U_i + self.b_i)
            o_t = torch.sigmoid(h_t @ self.W_o + x_t @ self.U_o + self.b_o)
            g_t = torch.tanh(h_t @ self.W_g + x_t @ self.U_g + self.b_g)

            c_t = (f_t * c_t) + (i_t * g_t)
            h_t = o_t * torch.tanh(c_t)

            outputs.append(h_t)

        if not outputs:
            # torch.stack/torch.cat reject an empty list; return an empty time axis instead.
            return h_t.new_zeros(bs, 0, self.hidden_dim), (h_t, c_t)

        # Stacking on dim=1 yields (bs, seq_len, hidden) directly -- no transpose/contiguous.
        return torch.stack(outputs, dim=1), (h_t, c_t)


In [None]:
# Hyperparameters shared by the cells below.
embed_dim = 300     # size of each input vector fed to the RNN/LSTM cells
hidden_dim = 256    # size of the recurrent hidden (and cell) state
output_dim = 1      # not used by the cells shown here
input_dim = 5000    # just for example; not used by the cells shown here

batch_size = 32

In [None]:
my_RNN_cell = RNN_cell(embed_dim, hidden_dim).to(device)

# Shape smoke test on a constant batch of `test_seq_len` time steps.
test_seq_len = 100
test_data = torch.ones((batch_size, test_seq_len, embed_dim), device=device)
output, h_t = my_RNN_cell(test_data)

# Expected shapes come from the config cell, so they stay valid if hyperparameters change.
assert output.shape == (batch_size, test_seq_len, hidden_dim)
assert h_t.shape == (batch_size, hidden_dim)
# The last per-step output is the returned final hidden state.
assert torch.equal(output[:, -1], h_t)

In [None]:
my_LSTM_cell = LSTM_cell(embed_dim, hidden_dim).to(device)

# Shape smoke test on a constant batch of `test_seq_len` time steps.
test_seq_len = 100
test_data = torch.ones((batch_size, test_seq_len, embed_dim), device=device)
output, (h_t, c_t) = my_LSTM_cell(test_data)

# Expected shapes come from the config cell, so they stay valid if hyperparameters change.
assert output.shape == (batch_size, test_seq_len, hidden_dim)
assert h_t.shape == (batch_size, hidden_dim)
assert c_t.shape == (batch_size, hidden_dim)
# The last per-step output is the returned final hidden state, and
# h_t = o_t * tanh(c_t) keeps every entry inside [-1, 1].
assert torch.equal(output[:, -1], h_t)
assert h_t.abs().max().item() <= 1

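**1.2 Peephole LSTM**

One popular LSTM variant, introduced by Gers & Schmidhuber (2000), adds “peephole connections”: every gate layer is also allowed to look at the cell state. In the original formulation, the forget and input gates read the previous cell state $c_{t-1}$, while the output gate reads the freshly updated cell state $c_t$.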


**Coupled LSTM**


Another variation is to use coupled forget and input gates. Instead of separately deciding what to forget and what new information to add, we make those decisions together: we only forget when we are going to write something in its place, and we only write new values to the state when we forget something older. The change is very simple — the input gate is replaced by $i_t = 1 - f_t$, so the cell update becomes $c_t = f_t \odot c_{t-1} + (1 - f_t) \odot g_t$.

In [None]:
class new_LSTM_cell(nn.Module):
  """LSTM cell whose gating is selected by ``lstm_type``.

  * 'Vanilla'  -- standard LSTM with independent input and forget gates.
  * 'Coupled'  -- input gate tied to the forget gate: i_t = 1 - f_t.
  * 'Peephole' -- gates also read the cell state: the forget and input gates
    see c_{t-1}, the output gate sees the freshly updated c_t
    (Gers & Schmidhuber, 2000).

  Note: the peephole weights P_* here are full (hidden_dim, hidden_dim)
  matrices rather than the diagonal (element-wise) weights of the paper.

  Args:
    input_dim: size of the last axis of every input step x_t.
    hidden_dim: size of the hidden state h_t and cell state c_t.
    lstm_type: one of 'Vanilla', 'Coupled', 'Peephole'.

  Raises:
    ValueError: if ``lstm_type`` is not one of the supported variants.
  """

  _LSTM_TYPES = ('Vanilla', 'Coupled', 'Peephole')

  def __init__(self, input_dim, hidden_dim, lstm_type):
    super().__init__()

    if lstm_type not in self._LSTM_TYPES:
      # Fail at construction; otherwise forward() dies later with an UnboundLocalError on f_t.
      raise ValueError(f"lstm_type must be one of {self._LSTM_TYPES}, got {lstm_type!r}")

    self.hidden_dim = hidden_dim
    self.lstm_type = lstm_type

    # U_i/W_i/b_i are created for every variant so state_dict keys match across
    # variants, although 'Coupled' never uses them.
    self.U_i = nn.Parameter(torch.empty(input_dim, hidden_dim))
    self.W_i = nn.Parameter(torch.empty(hidden_dim, hidden_dim))
    self.b_i = nn.Parameter(torch.empty(hidden_dim))

    self.U_f = nn.Parameter(torch.empty(input_dim, hidden_dim))
    self.W_f = nn.Parameter(torch.empty(hidden_dim, hidden_dim))
    self.b_f = nn.Parameter(torch.empty(hidden_dim))

    self.U_g = nn.Parameter(torch.empty(input_dim, hidden_dim))
    self.W_g = nn.Parameter(torch.empty(hidden_dim, hidden_dim))
    self.b_g = nn.Parameter(torch.empty(hidden_dim))

    self.U_o = nn.Parameter(torch.empty(input_dim, hidden_dim))
    self.W_o = nn.Parameter(torch.empty(hidden_dim, hidden_dim))
    self.b_o = nn.Parameter(torch.empty(hidden_dim))

    if self.lstm_type == 'Peephole':
      # Cell-state -> gate connections.
      self.P_i = nn.Parameter(torch.empty(hidden_dim, hidden_dim))
      self.P_f = nn.Parameter(torch.empty(hidden_dim, hidden_dim))
      self.P_o = nn.Parameter(torch.empty(hidden_dim, hidden_dim))

    self.init_weights()

  def init_weights(self):
    """Sample every parameter from U(-1/sqrt(hidden_dim), 1/sqrt(hidden_dim))."""
    std = 1 / math.sqrt(self.hidden_dim)
    with torch.no_grad():
      for weight in self.parameters():
        weight.uniform_(-std, std)

  def forward(self, x, init_state=None):
    """Run the selected LSTM variant over every time step of ``x``.

    Args:
      x: tensor indexed as (batch, seq_len, input_dim).
      init_state: optional ``(h_0, c_0)`` pair, each (batch, hidden_dim);
        zeros on x's device/dtype when omitted.

    Returns:
      (output, (h_t, c_t)): ``output`` holds the hidden state of every step as
      (batch, seq_len, hidden_dim); ``h_t``/``c_t`` are the final states. For an
      empty sequence ``output`` has seq_len 0 and the initial states are returned.
    """
    bs, seq_len, _ = x.shape

    if init_state is None:
      # Allocate directly on the target device instead of building on CPU and copying.
      h_t = torch.zeros(bs, self.hidden_dim, device=x.device, dtype=x.dtype)
      c_t = torch.zeros(bs, self.hidden_dim, device=x.device, dtype=x.dtype)
    else:
      h_t, c_t = init_state

    peephole = self.lstm_type == 'Peephole'

    outputs = []
    for t in range(seq_len):
      x_t = x[:, t, :]  # all sequences at time step t -> (bs, input_dim)
      c_prev = c_t

      if peephole:
        f_t = torch.sigmoid(h_t @ self.W_f + x_t @ self.U_f + c_prev @ self.P_f + self.b_f)
        i_t = torch.sigmoid(h_t @ self.W_i + x_t @ self.U_i + c_prev @ self.P_i + self.b_i)
      else:
        f_t = torch.sigmoid(h_t @ self.W_f + x_t @ self.U_f + self.b_f)
        if self.lstm_type == 'Coupled':
          i_t = 1 - f_t  # only write where we forget
        else:
          i_t = torch.sigmoid(h_t @ self.W_i + x_t @ self.U_i + self.b_i)

      g_t = torch.tanh(h_t @ self.W_g + x_t @ self.U_g + self.b_g)
      c_t = (i_t * g_t) + (f_t * c_prev)

      if peephole:
        # The output gate peeks at the *updated* cell state c_t, not c_{t-1}.
        o_t = torch.sigmoid(h_t @ self.W_o + x_t @ self.U_o + c_t @ self.P_o + self.b_o)
      else:
        o_t = torch.sigmoid(h_t @ self.W_o + x_t @ self.U_o + self.b_o)

      h_t = o_t * torch.tanh(c_t)
      outputs.append(h_t)

    if not outputs:
      # torch.stack/torch.cat reject an empty list; return an empty time axis instead.
      return h_t.new_zeros(bs, 0, self.hidden_dim), (h_t, c_t)

    # Stacking on dim=1 yields (bs, seq_len, hidden) directly -- no transpose/contiguous.
    return torch.stack(outputs, dim=1), (h_t, c_t)


In [None]:
# Shape smoke test for every LSTM variant. Each cell is also kept under its own
# module-level name (Vanilla_LSTM_cell, Coupled_LSTM_cell, Peephole_LSTM_cell).
test_seq_len = 100
test_data = torch.ones((batch_size, test_seq_len, embed_dim), device=device)

lstm_cells = {}
for lstm_type in ['Vanilla', 'Coupled', 'Peephole']:
    cell = new_LSTM_cell(embed_dim, hidden_dim, lstm_type=lstm_type).to(device)
    output, (h_t, c_t) = cell(test_data)

    # Expected shapes come from the config cell, so they stay valid if hyperparameters change.
    assert output.shape == (batch_size, test_seq_len, hidden_dim), lstm_type
    assert h_t.shape == (batch_size, hidden_dim), lstm_type
    assert c_t.shape == (batch_size, hidden_dim), lstm_type

    lstm_cells[lstm_type] = cell

Vanilla_LSTM_cell = lstm_cells['Vanilla']
Coupled_LSTM_cell = lstm_cells['Coupled']
Peephole_LSTM_cell = lstm_cells['Peephole']