<h1><center>IST 597 Foundations of Deep Learning</center></h1>

---

<h2><center>Stack Augmented Recurrent Neural Networks</center></h2>
<h3><center>Neisarg Dave</center></h3>

### Stack Augmented LSTM

**LSTM Equations**

$$
  i_t = \sigma(U_ix_t + W_ih_{t-1} + b_i) \\
  f_t = \sigma(U_fx_t + W_fh_{t-1} + b_f) \\
  o_t = \sigma(U_ox_t + W_oh_{t-1} + b_o) 
$$

$$
  \tilde{c}_t = \tanh(U_cx_t + W_ch_{t-1} + b_c)
$$

$$
  c_t = f_t \odot c_{t-1} + i_t \odot \tilde{c}_t
$$

$$
  h_t = o_t \odot \tanh(c_t)
$$
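As a sanity check on the equations above, here is a minimal sketch of a single LSTM step in plain Python (no PyTorch). It assumes an input size and hidden size of 1, so every $U_*$, $W_*$ and $b_*$ is a scalar. The function `lstm_step` and the `params` dictionary layout are hypothetical names chosen for this illustration.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def lstm_step(x, h_prev, c_prev, params):
    """One LSTM step with scalar input and state.

    params maps each gate name ('i', 'f', 'o', 'c') to a (U, W, b) triple,
    so the pre-activation is U * x_t + W * h_{t-1} + b, as in the equations.
    """
    def pre(gate):
        U, W, b = params[gate]
        return U * x + W * h_prev + b

    i = sigmoid(pre("i"))          # input gate
    f = sigmoid(pre("f"))          # forget gate
    o = sigmoid(pre("o"))          # output gate
    c_tilde = math.tanh(pre("c"))  # candidate memory
    c = f * c_prev + i * c_tilde   # new cell state
    h = o * math.tanh(c)           # new hidden state
    return h, c


# With all weights and biases zero, every gate is sigmoid(0) = 0.5 and
# c_tilde = tanh(0) = 0, so the cell state is simply halved.
zeros = {g: (0.0, 0.0, 0.0) for g in "ifoc"}
h, c = lstm_step(x=1.0, h_prev=0.0, c_prev=2.0, params=zeros)
print(c)  # 1.0
print(h)  # 0.5 * tanh(1.0), about 0.3808
```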

**Stack Augmented LSTM**

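Let $S_t$ denote the state of the augmented stack at time $t$. For each sample, $S_t$ is a 2D tensor of shape *stack_depth × stack_dim*, whose first row $S^{(0)}_t$ is the top of the stack (in the code below, *stack_dim* equals the LSTM hidden size `out_features`).
At each time step $t$, the topmost vector of the previous stack, $S^{(0)}_{t-1}$, is used together with $x_t$ and $h_{t-1}$ to compute the gates and the candidate cell memory.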

The LSTM equations are modified as follows:
$$
  i_t = \sigma(U_ix_t + W_ih_{t-1} + V_iS^{(0)}_{t-1} + b_i) \\
  f_t = \sigma(U_fx_t + W_fh_{t-1} + V_fS^{(0)}_{t-1} + b_f) \\
  o_t = \sigma(U_ox_t + W_oh_{t-1} + V_oS^{(0)}_{t-1} + b_o) 
$$

$$
  \tilde{c}_t = \tanh(U_cx_t + W_ch_{t-1} + V_cS^{(0)}_{t-1} + b_c)
$$

$$
  c_t = f_t \odot c_{t-1} + i_t \odot \tilde{c}_t
$$

$$
  h_t = o_t \odot \tanh(c_t)
$$
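The only change from the plain LSTM is the extra $V_* S^{(0)}_{t-1}$ term in every pre-activation. A minimal scalar sketch in plain Python, assuming input size, hidden size and stack width are all 1; `stack_lstm_step` and the `(U, W, V, b)` tuple layout are hypothetical names for this illustration. Setting only $V_c$ to a nonzero value shows the stack top feeding directly into the candidate memory.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def stack_lstm_step(x, h_prev, c_prev, s_top, params):
    """One stack-augmented LSTM step with scalar input, state and stack top.

    params maps each gate ('i', 'f', 'o', 'c') to (U, W, V, b), so the
    pre-activation is U*x_t + W*h_{t-1} + V*S^{(0)}_{t-1} + b.
    """
    def pre(gate):
        U, W, V, b = params[gate]
        return U * x + W * h_prev + V * s_top + b

    i = sigmoid(pre("i"))
    f = sigmoid(pre("f"))
    o = sigmoid(pre("o"))
    c_tilde = math.tanh(pre("c"))  # now depends on the stack top as well
    c = f * c_prev + i * c_tilde
    h = o * math.tanh(c)
    return h, c


# Only V_c is nonzero: c_tilde = tanh(s_top), and with c_prev = 0 and
# i = sigmoid(0) = 0.5 the new cell state is 0.5 * tanh(s_top).
params = {g: (0.0, 0.0, 0.0, 0.0) for g in "ifo"}
params["c"] = (0.0, 0.0, 1.0, 0.0)
for s_top in (0.0, 2.0):
    h, c = stack_lstm_step(x=1.0, h_prev=0.0, c_prev=0.0, s_top=s_top, params=params)
    print(s_top, round(c, 4))
# 0.0 0.0
# 2.0 0.482
```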

**Stack Update Equations**
+ First, compute the action vector $a_t$: a probability distribution (of size 3) over the three stack operations:
  + PUSH
  + POP
  + NO-OP

$$
a_t = \mathrm{softmax}(W_a h_t + b_a)
$$
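A minimal plain-Python sketch of the action computation, assuming a hidden size of 2; the weights `W_a`, `b_a` and the hidden state `h` below are made-up values for illustration, not trained parameters. Because the output is a softmax, the three probabilities are nonnegative and sum to 1.

```python
import math


def matvec(W, v):
    """Multiply a matrix (list of rows) by a vector (list)."""
    return [sum(w * x for w, x in zip(row, v)) for row in W]


def softmax(z):
    m = max(z)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in z]
    total = sum(exps)
    return [e / total for e in exps]


def action_probs(h, W_a, b_a):
    """a_t = softmax(W_a h_t + b_a); entries are (PUSH, POP, NO-OP)."""
    logits = [z + b for z, b in zip(matvec(W_a, h), b_a)]
    return softmax(logits)


h = [1.0, -1.0]                              # hypothetical hidden state
W_a = [[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]]   # hypothetical 3 x 2 weights
b_a = [0.0, 0.0, 0.0]
a = action_probs(h, W_a, b_a)
print([round(v, 4) for v in a])  # [0.6652, 0.09, 0.2447]
```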

+ Next, compute the vector $p$ that a PUSH writes onto the top of the stack:

$$
p = \sigma(W_s h_t + b_s)
$$
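A short plain-Python sketch of the push vector, assuming the stack width equals the hidden size (2 here); `push_vector` and the weight values are hypothetical. The sigmoid keeps every entry of $p$ strictly between 0 and 1.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def push_vector(h, W_s, b_s):
    """p = sigmoid(W_s h_t + b_s): the row a PUSH writes onto the stack top."""
    return [sigmoid(sum(w * x for w, x in zip(row, h)) + b) for row, b in zip(W_s, b_s)]


h = [1.0, -1.0]                  # hypothetical hidden state
W_s = [[2.0, 0.0], [0.0, 2.0]]   # hypothetical 2 x 2 weights
b_s = [0.0, 0.0]
p = push_vector(h, W_s, b_s)
print([round(v, 4) for v in p])  # [0.8808, 0.1192]
```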

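+ Update the stack. The new stack is a convex combination of three candidate stacks (pushed, popped, unchanged), weighted by the action probabilities $a_t[0]$ (PUSH), $a_t[1]$ (POP) and $a_t[2]$ (NO-OP):

$$
S_t = a_t[0] \begin{bmatrix} p \\ S_{t-1}[:-1] \end{bmatrix} + a_t[1] \begin{bmatrix} S_{t-1}[1:] \\ \mathbf{0} \end{bmatrix} + a_t[2]\, S_{t-1}
$$

Here $S_{t-1}[:-1]$ is every row except the bottom one, $S_{t-1}[1:]$ is every row except the top one, and $\mathbf{0}$ is a zero row. A PUSH shifts every row down by one slot and writes $p$ on top (the bottom row is discarded); a POP shifts every row up and pads the bottom with zeros.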
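The update rule can be checked by hand on a tiny stack. This is a minimal sketch in plain Python, assuming the stack is a list of rows with row 0 as the top; the name `soft_stack_update` is hypothetical, and unlike the PyTorch layer it handles a single sample with no batching or masking.

```python
def soft_stack_update(stack, p, a):
    """Return S_t given S_{t-1} (stack), push vector p, and a = (push, pop, no-op)."""
    dim = len(p)
    pushed = [list(p)] + [list(row) for row in stack[:-1]]     # [p ; S_{t-1}[:-1]]
    popped = [list(row) for row in stack[1:]] + [[0.0] * dim]  # [S_{t-1}[1:] ; 0]
    kept = [list(row) for row in stack]                        # S_{t-1}
    return [
        [a[0] * u + a[1] * v + a[2] * w for u, v, w in zip(r0, r1, r2)]
        for r0, r1, r2 in zip(pushed, popped, kept)
    ]


S = [[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]]  # depth 3, row 0 is the top
p = [9.0, 9.0]

print(soft_stack_update(S, p, (1.0, 0.0, 0.0)))  # hard PUSH
# [[9.0, 9.0], [1.0, 1.0], [2.0, 2.0]]
print(soft_stack_update(S, p, (0.0, 1.0, 0.0)))  # hard POP
# [[2.0, 2.0], [3.0, 3.0], [0.0, 0.0]]
print(soft_stack_update(S, p, (0.5, 0.5, 0.0)))  # half PUSH, half POP
# [[5.5, 5.5], [2.0, 2.0], [1.0, 1.0]]
```

With hard (one-hot) actions the result is an ordinary push or pop; with soft actions each row becomes a blend of its neighbors, which is what keeps the stack differentiable with respect to $a_t$ and $p$.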

```python
import math

import torch
import torch.nn as nn
from torch.nn import Parameter, init


class StackLSTMlayer(nn.Module):
    """LSTM layer whose gates also read the top of a differentiable stack.

    Shapes:
        input:  (seq_len, batch, in_features)
        hidden: tuple (h, c), each (batch, out_features)
        stack:  (batch, stack_depth, out_features); stack_dim must equal out_features
        mask:   None, (batch,) or (seq_len, batch); where 0, the stack takes a NO-OP
    """

    def __init__(self, in_features, out_features, bias=True, dropout=0, eval_dropout=False):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        # dropout and eval_dropout are accepted for API compatibility but not used yet.
        self.dropout = dropout
        self.eval_dropout = eval_dropout

        # U (input), W (hidden) and V (stack top) weights for the i, f, g, o blocks,
        # concatenated column-wise; g is the candidate memory c~_t.
        self.WX = Parameter(torch.empty(in_features, 4 * out_features))
        self.WH = Parameter(torch.empty(out_features, 4 * out_features))
        self.WS = Parameter(torch.empty(out_features, 4 * out_features))

        # a_t = softmax(W_a h_t + b_a): probabilities of (PUSH, POP, NO-OP)
        self.action = nn.Sequential(nn.Linear(out_features, 3), nn.Softmax(dim=-1))
        # p = sigmoid(W_s h_t + b_s): vector written to the top of the stack on PUSH
        self.H2S = nn.Sequential(nn.Linear(out_features, out_features), nn.Sigmoid())

        if bias:
            self.bias = Parameter(torch.empty(4 * out_features))
        else:
            self.register_parameter("bias", None)
        self.reset_parameters()

    def reset_parameters(self):
        # Only the Linear layers have parameters; Softmax and Sigmoid have no reset_parameters().
        self.action[0].reset_parameters()
        self.H2S[0].reset_parameters()

        if self.bias is not None:
            bound = 1 / math.sqrt(self.in_features)  # fan_in of WX
            init.uniform_(self.bias, -bound, bound)

        init.normal_(self.WX, mean=0, std=0.1)
        init.normal_(self.WH, mean=0, std=0.1)
        init.normal_(self.WS, mean=0, std=0.1)

    def forward(self, input, hidden, stack, mask=None):
        output = []
        for i in range(input.size(0)):
            # The gates read the top of the previous stack, S_{t-1}^{(0)}.
            hidden = self.inner(input[i], hidden, stack[:, 0])
            # Accept either one mask per sample or one mask per time step.
            step_mask = mask[i] if mask is not None and mask.dim() == 2 else mask
            stack = self.update_stack(hidden, stack, step_mask)
            output.append(hidden[0])

        output = torch.stack(output, dim=0)  # (seq_len, batch, out_features)
        return hidden, output, stack

    def update_stack(self, hidden, stack, mask):
        action = self.action(hidden[0])  # (batch, 3)
        if mask is not None:
            # Where mask is 0, replace the action distribution with a hard NO-OP.
            mask = mask.to(action.dtype).unsqueeze(-1)  # (batch, 1)
            no_op = torch.zeros_like(action)
            no_op[:, 2] = 1.0
            action = mask * action + (1 - mask) * no_op

        new_inp = self.H2S(hidden[0]).unsqueeze(1)  # (batch, 1, out_features)

        pushed_stack = torch.cat([new_inp, stack[:, :-1]], dim=1)                        # [p ; S[:-1]]
        popped_stack = torch.cat([stack[:, 1:], torch.zeros_like(stack[:, :1])], dim=1)  # [S[1:] ; 0]
        noop_stack = stack

        # Broadcast each action probability over (stack_depth, stack_dim).
        a = action[:, :, None, None]  # (batch, 3, 1, 1)
        return a[:, 0] * pushed_stack + a[:, 1] * popped_stack + a[:, 2] * noop_stack

    def inner(self, input, hidden, stack_top):
        h, c = hidden
        all_sum = input @ self.WX + h @ self.WH + stack_top @ self.WS
        if self.bias is not None:
            all_sum = all_sum + self.bias
        i, f, g, o = torch.chunk(all_sum, 4, dim=1)
        i, f, g, o = torch.sigmoid(i), torch.sigmoid(f), torch.tanh(g), torch.sigmoid(o)
        c = f * c + i * g
        h = o * torch.tanh(c)
        return h, c
```

### TODO:
+ Implement other memory structures with RNNs:
  + Linked list
  + Two stacks
  + Tape

### An LSTM cell behaves like a bounded stack
https://nlp.stanford.edu/~johnhew/rnns-hierarchy.html