# Peephole LSTM

Given an implementation of an LSTM module:
\begin{align}
i_t = \sigma(W_{ii} x_t + b_{ii} + W_{hi} h_{t-1} + b_{hi}) \\
f_t = \sigma(W_{if} x_t + b_{if} + W_{hf} h_{t-1} + b_{hf}) \\
g_t = tanh(W_{ig} x_t + b_{ig} + W_{hg} h_{t-1} + b_{go}) \\
o_t = \sigma(W_{io} x_t + b_{io} + W_{ho} h_{t-1} + b_{ho}) \\
c_t = f_t \odot c_{t-1} + i_t \odot g_t \\
h_t = o_t \odot tanh(c_t)
\end{align}


Your task is to modify the implementaiton to add [peephole connections](https://en.wikipedia.org/wiki/Long_short-term_memory#Peephole_LSTM) according to:

\begin{align}
i_t = \sigma(W_{ii} x_t + b_{ii} + W_{ci} c_{t-1} + b_{ci}) \\
f_t = \sigma(W_{if} x_t + b_{if} + W_{cf} c_{t-1} + b_{cf}) \\
o_t = \sigma(W_{io} x_t + b_{io} + W_{co} c_{t-1} + b_{co}) \\
c_t = f_t \odot c_{t-1} + i_t \odot tanh(W_{ic} x_t + b_{ic}) \\
h_t = o_t \odot c_t
\end{align}

In [1]:
import typing
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

np.random.seed(0)
torch.manual_seed(0)
random.seed(0)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class LSTM(nn.Module):
    def __init__(self, input_size: int, hidden_size: int, batch_first: bool):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.batch_first = batch_first
        
        #input gate
        self.W_ii = nn.Parameter(torch.Tensor(input_size, hidden_size))
        self.W_hi = nn.Parameter(torch.Tensor(hidden_size, hidden_size))
        self.b_ii = nn.Parameter(torch.Tensor(hidden_size))
        self.b_hi = nn.Parameter(torch.Tensor(hidden_size))
        
        #forget gate
        self.W_if = nn.Parameter(torch.Tensor(input_size, hidden_size))
        self.W_hf = nn.Parameter(torch.Tensor(hidden_size, hidden_size))
        self.b_if = nn.Parameter(torch.Tensor(hidden_size))
        self.b_hf = nn.Parameter(torch.Tensor(hidden_size))
        
        #output gate c_t
        self.W_ig = nn.Parameter(torch.Tensor(input_size, hidden_size))
        self.W_hg = nn.Parameter(torch.Tensor(hidden_size, hidden_size))
        self.b_ig = nn.Parameter(torch.Tensor(hidden_size))
        self.b_hg = nn.Parameter(torch.Tensor(hidden_size))
        
        #output gate h_t
        self.W_io = nn.Parameter(torch.Tensor(input_size, hidden_size))
        self.W_ho = nn.Parameter(torch.Tensor(hidden_size, hidden_size))
        self.b_io = nn.Parameter(torch.Tensor(hidden_size))
        self.b_ho = nn.Parameter(torch.Tensor(hidden_size))

        self._init_parameters()

    def _init_parameters(self):
        for param in self.parameters():
            torch.nn.init.normal_(param)

    def forward(self, x: torch.Tensor, hx: typing.Optional[typing.Tuple[torch.Tensor, torch.Tensor]] = None) -> typing.Tuple[torch.Tensor, typing.Tuple[torch.Tensor, torch.Tensor]]:
        
        if not self.batch_first:
            x = x.permute(1,0,2).contiguous()

        batch_size = x.size(0)
        sequence_length = x.size(1)

        if hx is None:
            h_t, c_t = (
                torch.zeros(batch_size, self.hidden_size).to(x.device),
                torch.zeros(batch_size, self.hidden_size).to(x.device),
            )
        else:
            h_t, c_t = hx

        output = []
        
        for t in range(sequence_length):
            x_t = x[:, t, :]
            # input gate            
            i_t = torch.sigmoid(x_t @ self.W_ii + self.b_ii + h_t @ self.W_hi + self.b_hi)
            # forget gate
            f_t = torch.sigmoid(x_t @ self.W_if + self.b_if + h_t @ self.W_hf + self.b_hf)
            # output gate
            g_t = torch.tanh(x_t @ self.W_ig + self.b_ig + h_t @ self.W_hg + self.b_hg)
            o_t = torch.sigmoid(x_t @ self.W_io + self.b_io + h_t @ self.W_ho + self.b_ho)
            
            # output
            c_t = f_t * c_t + i_t * g_t
            h_t = o_t * torch.tanh(c_t)

            output.append(h_t.unsqueeze(0))

        output = torch.cat(output, dim=0)

        if not self.batch_first:
            output = output.permute(1,0,2).contiguous()

        return output, (h_t, c_t)


In [3]:
torch.manual_seed(0)
a = torch.randn((5,10, 3))
lstm = LSTM(3, 7, True)
print(lstm(a)[0].size(), lstm(a)[1][0].size(), lstm(a)[1][1].size())
print(lstm(a))

torch.Size([10, 5, 7]) torch.Size([5, 7]) torch.Size([5, 7])
(tensor([[[ 6.2072e-01, -7.7352e-02,  5.6017e-01, -5.6047e-01,  3.5405e-04,
           5.4798e-03,  3.5325e-01],
         [ 7.1540e-01, -5.3192e-01,  9.5980e-02, -1.9399e-01,  1.1296e-02,
           4.8099e-02, -4.7105e-02],
         [ 6.1332e-01, -1.1772e-05,  5.9567e-01, -5.4451e-01,  1.6074e-04,
           1.0339e-02,  5.2595e-01],
         [ 3.0238e-01, -1.4361e-01,  8.4637e-02, -6.3363e-01,  4.3879e-04,
          -9.2375e-02,  5.3165e-01],
         [ 6.0016e-01, -9.1497e-05,  4.1435e-01, -2.1903e-01,  3.6609e-03,
           2.7239e-02,  5.3316e-02]],

        [[ 5.5158e-01, -1.1870e-02,  6.4174e-01, -1.0644e-01,  1.1541e-02,
           1.1945e-01, -1.5863e-01],
         [ 6.4038e-01, -7.3400e-02,  3.9188e-01, -4.2518e-01,  1.0513e-02,
           4.5897e-02, -3.7773e-01],
         [ 7.0326e-01, -6.1242e-04,  8.7785e-01, -6.8674e-01,  1.2353e-03,
           2.7056e-02,  1.7479e-01],
         [ 5.2825e-01, -1.5968e-03,  5.0

In [4]:
class LSTMPiphole(nn.Module):
    def __init__(self, input_size: int, hidden_size: int, batch_first: bool):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.batch_first = batch_first
        
        #input gate
        self.W_ii = nn.Parameter(torch.Tensor(input_size, hidden_size))
        self.W_ci = nn.Parameter(torch.Tensor(hidden_size, hidden_size))
        self.b_ii = nn.Parameter(torch.Tensor(hidden_size))
        self.b_ci = nn.Parameter(torch.Tensor(hidden_size))
        
        #forget gate
        self.W_if = nn.Parameter(torch.Tensor(input_size, hidden_size))
        self.W_cf = nn.Parameter(torch.Tensor(hidden_size, hidden_size))
        self.b_if = nn.Parameter(torch.Tensor(hidden_size))
        self.b_cf = nn.Parameter(torch.Tensor(hidden_size))
        
        #output gate c_t
        self.W_ic = nn.Parameter(torch.Tensor(input_size, hidden_size))
        self.b_ic = nn.Parameter(torch.Tensor(hidden_size))

        
        #output gate h_t
        self.W_io = nn.Parameter(torch.Tensor(input_size, hidden_size))
        self.W_co = nn.Parameter(torch.Tensor(hidden_size, hidden_size))
        self.b_io = nn.Parameter(torch.Tensor(hidden_size))
        self.b_co = nn.Parameter(torch.Tensor(hidden_size))

        self._init_parameters()

    def _init_parameters(self):
        for param in self.parameters():
            torch.nn.init.normal_(param)

    def forward(self, x: torch.Tensor, hx: typing.Optional[typing.Tuple[torch.Tensor, torch.Tensor]] = None) -> typing.Tuple[torch.Tensor, typing.Tuple[torch.Tensor, torch.Tensor]]:
        #################################
        if not self.batch_first:
            x = x.permute(1,0,2).contiguous()

        batch_size = x.size(0)
        sequence_length = x.size(1)

        if hx is None:
            h_t, c_t = (
                torch.zeros(batch_size, self.hidden_size).to(x.device),
                torch.zeros(batch_size, self.hidden_size).to(x.device),
            )
        else:
            h_t, c_t = hx

        output = []
        
        for t in range(sequence_length):
            x_t = x[:, t, :]
            # input gate            
            i_t = torch.sigmoid(x_t @ self.W_ii + self.b_ii + c_t @ self.W_ci + self.b_ci)
            # forget gate
            f_t = torch.sigmoid(x_t @ self.W_if + self.b_if + c_t @ self.W_cf + self.b_cf)
            # output gate
            o_t = torch.sigmoid(x_t @ self.W_io + self.b_io + c_t @ self.W_co + self.b_co)
            
            # output
            c_t = f_t * c_t + i_t * torch.tanh(x_t @ self.W_ic + self.b_ic)
            h_t = o_t * c_t

            output.append(h_t.unsqueeze(0))

        output = torch.cat(output, dim=0)

        if not self.batch_first:
            output = output.permute(1,0,2).contiguous()
        #################################
        return output, (h_t, c_t)

In [5]:
torch.manual_seed(0)
a = torch.randn((5,10, 3))
lstm = LSTMPiphole(3, 7, True)
print(lstm(a)[0].size(), lstm(a)[1][0].size(), lstm(a)[1][1].size())
print(lstm(a))

torch.Size([10, 5, 7]) torch.Size([5, 7]) torch.Size([5, 7])
(tensor([[[-9.2103e-02, -2.0233e-02,  4.7249e-01, -5.2181e-01,  2.2527e-02,
          -3.2304e-01,  5.1055e-01],
         [ 6.5202e-01, -1.7289e-02, -2.1486e-01, -1.6997e-01,  3.2587e-02,
           1.0264e-01,  4.0168e-01],
         [-1.0172e-01, -5.2187e-03,  2.6390e-01, -9.2172e-01,  1.2160e-02,
           2.4528e-02,  2.9496e-01],
         [-1.0936e-01,  1.3243e-02, -5.5547e-02, -2.0006e-01,  7.0483e-03,
          -6.1816e-01,  7.1652e-01],
         [-2.8404e-02, -1.2403e-02,  6.1087e-02,  8.9413e-01,  6.6880e-02,
           2.6533e-02,  1.9026e-01]],

        [[-1.9611e-03, -8.3990e-03,  9.5282e-02, -5.9436e-02,  2.0709e-03,
          -1.4598e-01,  1.9008e-01],
         [ 1.5443e-01, -1.0935e-02, -2.3851e-01, -2.8001e-01,  1.0479e-02,
           1.5410e-01,  6.8256e-01],
         [-3.8857e-02, -2.1919e-03,  1.2440e+00, -6.0480e-01,  7.5997e-03,
           1.3234e-02,  1.7157e-01],
         [-3.2503e-02,  7.7197e-03,  1.6