<h1><center>IST 597 Foundations of Deep Learning</center></h1>

---

<h2><center>Recurrent Neural Networks</center></h2>
<h3><center>Neisarg Dave</center></h3>

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init, Parameter
import math

### Long Short Term Memory (LSTM)
+ RNN cell composed of gates
+ 3 gates
  + Input Gate
  + Output Gate
  + Forget Gate
+ Hidden State is a tuple: $(h, c)$
+ $h$ acts as the representative of RNN state at given step
+ $c$ acts as a scratchpad to erase and write something at each step

+ Gate equations:
$$
  i_t = \sigma(U_ix_t + W_ih_{t-1} + b_i) \\
  f_t = \sigma(U_fx_t + W_fh_{t-1} + b_f) \\
  o_t = \sigma(U_ox_t + W_oh_{t-1} + b_o) 
$$

+ New information to be updated on $c$

$$
  \tilde{c}_t = \tanh(U_cx_t + W_ch_{t-1} + b_c)
$$

+ Update information to scratchpad

$$
  c_t = f_t \odot c_{t-1} + i_t \odot \tilde{c}_t
$$

+ Update state vector

$$
  h_t = o_t \odot \tanh(c_t)
$$

+ **Peephole Architecture**
  + Let the gate equations look at the cell scratchpad $c$. 
  + $c$ here now acts as cell state along with $h$
  $$
  i_t = \sigma(U_ix_t + W_ih_{t-1} + V_ic_{t-1} + b_i) \\
  f_t = \sigma(U_fx_t + W_fh_{t-1} + V_fc_{t-1} + b_f) \\
  o_t = \sigma(U_ox_t + W_oh_{t-1} + V_oc_{t} + b_o)
  $$

+ Torch API
  + https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html
  + https://pytorch.org/docs/stable/generated/torch.nn.utils.rnn.pack_padded_sequence.html

In [None]:
class LSTM(nn.Module):
    """Single-layer LSTM unrolled step by step over a time-major sequence.

    The four per-gate projections are packed side by side in ``U``, ``W`` and
    ``bias`` (column blocks in the order: input gate, forget gate, candidate
    cell, output gate), so every step needs only two matrix products.
    """

    def __init__(self, in_features, out_features):
        """Allocate and initialise the packed gate parameters.

        Args:
            in_features: size of each input vector ``x_t``.
            out_features: size of the state vectors ``h`` and ``c``.
        """
        super(LSTM, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.U = Parameter(torch.empty(self.in_features, 4*self.out_features))
        self.W = Parameter(torch.empty(self.out_features, 4*self.out_features))
        self.bias = Parameter(torch.empty(4*self.out_features))

        self.reset_parameters()

    def reset_parameters(self):
        """Draw weights from N(0, 0.1^2) and the bias from U(-1/sqrt(in), 1/sqrt(in))."""
        init.normal_(self.U, mean=0, std=0.1)
        init.normal_(self.W, mean=0, std=0.1)

        # The fan-in of the input projection is simply the input width; using
        # it directly avoids relying on a private torch.nn.init helper.
        bound = 1 / math.sqrt(self.in_features)
        init.uniform_(self.bias, -bound, bound)

    def forward(self, input, hidden, reverse=False):
        """Run the cell over every step of ``input``.

        Args:
            input: tensor indexed by time step along dim 0; each step
                ``input[t]`` must be 2-D ``(batch, in_features)`` because the
                gates are split along dim 1.
            hidden: initial state tuple ``(h, c)``, each ``(batch, out_features)``.
            reverse: when True the sequence is consumed from the last step to
                the first. The returned outputs are still ordered by input
                time index, i.e. ``output[t]`` is the state produced while
                reading ``input[t]``.

        Returns:
            ``(hidden, output)`` where ``hidden`` is the final ``(h, c)`` tuple
            and ``output`` stacks the ``h`` of every step along dim 0.
        """
        n_steps = input.size(0)
        steps = range(n_steps - 1, -1, -1) if reverse else range(n_steps)

        output = []
        for t in steps:
            hidden = self.inner(input[t], hidden)
            output.append(hidden[0])

        if reverse:
            # Steps were visited back-to-front; restore input time order.
            output.reverse()

        output = torch.stack(output, 0)

        return hidden, output

    def inner(self, input, hidden):
        """Apply one LSTM step and return the new ``(h, c)`` tuple."""
        h, c = hidden
        all_sum = torch.matmul(input, self.U) + torch.matmul(h, self.W) + self.bias
        i, f, c_hat, o = torch.chunk(all_sum, 4, dim=1)
        i, f, c_hat, o = torch.sigmoid(i), torch.sigmoid(f), torch.tanh(c_hat), torch.sigmoid(o)
        # Forget part of the old cell content, then write the gated candidate.
        c = f*c + i*c_hat
        h = o*torch.tanh(c)
        return (h, c)

#### Gated Recurrent Unit (GRU)
+ Simplified Gated Architecture
+ Hidden State is represented by only one vector 
+ Two gates : $z$ and $r$
$$
  z_t = \sigma(U_zx_t + W_zh_{t-1} + b_z) \\
  r_t = \sigma(U_rx_t + W_rh_{t-1} + b_r) \\
$$

+ State Update 
$$
  \tilde{h}_t = \tanh(U_hx_t + W_h(r_t \odot h_{t-1}) + b_h) \\ 
  h_t = (1-z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t 
$$

In [5]:
class GRU(nn.Module):
    """Single-layer GRU unrolled step by step over a time-major sequence.

    ``U``, ``W`` and ``bias`` pack the update gate ``z`` and reset gate ``r``
    side by side (in that column order); ``Uh``, ``Wh`` and ``bh`` belong to
    the candidate state.
    """

    def __init__(self, in_features, out_features):
        """Allocate and initialise gate and candidate parameters.

        Args:
            in_features: size of each input vector ``x_t``.
            out_features: size of the hidden state ``h``.
        """
        super(GRU, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.U = Parameter(torch.empty(self.in_features, 2*self.out_features))
        self.W = Parameter(torch.empty(self.out_features, 2*self.out_features))
        self.Uh = Parameter(torch.empty(self.in_features, self.out_features))
        self.Wh = Parameter(torch.empty(self.out_features, self.out_features))
        self.bias = Parameter(torch.empty(2*self.out_features))
        self.bh = Parameter(torch.empty(self.out_features))

        self.reset_parameters()

    def reset_parameters(self):
        """Draw weights from N(0, 0.1^2) and biases from U(-1/sqrt(in), 1/sqrt(in))."""
        init.normal_(self.U, mean=0, std=0.1)
        init.normal_(self.W, mean=0, std=0.1)
        init.normal_(self.Uh, mean=0, std=0.1)
        init.normal_(self.Wh, mean=0, std=0.1)

        # The fan-in of the input projection is simply the input width; using
        # it directly avoids relying on a private torch.nn.init helper.
        bound = 1 / math.sqrt(self.in_features)
        init.uniform_(self.bias, -bound, bound)
        init.uniform_(self.bh, -bound, bound)

    def forward(self, input, hidden, reverse=False):
        """Run the cell over every step of ``input``.

        Args:
            input: tensor indexed by time step along dim 0; each step
                ``input[t]`` must be 2-D ``(batch, in_features)`` because the
                gates are split along dim 1.
            hidden: initial state ``(batch, out_features)``.
            reverse: when True the sequence is consumed from the last step to
                the first. The returned outputs are still ordered by input
                time index, i.e. ``output[t]`` is the state produced while
                reading ``input[t]``.

        Returns:
            ``(hidden, output)`` where ``hidden`` is the final state and
            ``output`` stacks the state of every step along dim 0.
        """
        n_steps = input.size(0)
        steps = range(n_steps - 1, -1, -1) if reverse else range(n_steps)

        output = []
        for t in steps:
            hidden = self.inner(input[t], hidden)
            output.append(hidden)

        if reverse:
            # Steps were visited back-to-front; restore input time order.
            output.reverse()

        output = torch.stack(output, 0)

        return hidden, output

    def inner(self, input, hidden):
        """Apply one GRU step and return the new hidden state."""
        all_sum = torch.matmul(input, self.U) + torch.matmul(hidden, self.W) + self.bias
        z, r = torch.chunk(all_sum, 2, dim=1)
        z, r = torch.sigmoid(z), torch.sigmoid(r)
        # The reset gate scales the previous state before it feeds the candidate.
        h_hat = torch.tanh(torch.matmul(input, self.Uh) + torch.matmul(r*hidden, self.Wh) + self.bh)
        # Interpolate between the old state and the candidate.
        return (1-z)*hidden + z*h_hat


#### 2nd Order RNN

+ weight matrix is a 3D tensor
$$
  h_{t, i} = \sigma(\sum_{j, k} W_{ijk}x_{t, j}h_{t-1, k} + b_i)
$$

In [3]:
class O2RNN(nn.Module):
    """Second-order RNN: the state update is bilinear in input and state.

    The 3-D weight tensor of shape ``(in_features, out_features,
    out_features)`` is stored flattened to ``(in_features, out_features**2)``
    so that the contraction with the input is a single matrix product.
    """

    def __init__(self, in_features, out_features):
        super(O2RNN, self).__init__()
        self.in_features = in_features
        self.out_features = out_features

        # Allocate as a 3-D tensor, keep it as a 2-D parameter.
        flat_weight = torch.empty(in_features, out_features, out_features).view(in_features, -1)
        self.weight = Parameter(flat_weight)
        self.bias = Parameter(torch.empty(out_features))
        self.reset_parameters()

    def reset_parameters(self):
        """Weights ~ N(0, 0.1^2); every bias entry is set to 0.01."""
        init.normal_(self.weight, mean=0, std=0.1)
        init.constant_(self.bias, 0.01)

    def inner(self, input, hidden=None):
        """Compute one state update.

        ``hidden`` is dereferenced unconditionally, so despite the ``None``
        default a state tensor must always be supplied.
        """
        size = self.out_features
        # Contract the input index first: one (size x size) matrix per sample.
        per_sample = torch.matmul(input, self.weight).view(-1, size, size)
        # Then contract with the previous state, sample by sample.
        pre_activation = torch.bmm(per_sample, hidden.unsqueeze(2)).squeeze(2)
        return torch.sigmoid(pre_activation + self.bias)

    def forward(self, input, hidden):
        """Unroll over dim 0 of ``input``; return ``(last_state, all_states)``."""
        states = []
        for step_input in input.unbind(0):
            hidden = self.inner(step_input, hidden)
            states.append(hidden)

        return hidden, torch.stack(states, 0)

#### Multiplicative Integration RNNs
+ Integration of two information flows by hadamard product
+ RNN state update equation:
$$
  h_t = \phi(Ux_t + Wh_{t-1} + b)
$$

+ MI-RNN:

$$
\begin{aligned}
  h_t = \phi( \ & \alpha \odot Ux \odot Wh_{t-1} \\
                  &+ \beta_1 \odot Ux \\
                  &+ \beta_2 \odot Wh_{t-1} \\
                  &+ b \ )
\end{aligned}
$$

+ Similarly we can now define MI-LSTM and MI-GRU
+ https://arxiv.org/pdf/1606.06630.pdf


In [None]:
class MIRNN(nn.Module):
    """Multiplicative-integration RNN unrolled over a time-major sequence.

    The input projection ``Ux`` and the recurrent projection ``Wh`` are
    combined through a Hadamard product (scaled by ``alpha``) in addition to
    the usual additive terms (scaled by ``beta1`` and ``beta2``).
    """

    def __init__(self, in_features, out_features):
        super(MIRNN, self).__init__()
        self.in_features = in_features
        self.out_features = out_features

        # Projections of the input and of the previous state.
        self.U = Parameter(torch.empty(in_features, out_features))
        self.W = Parameter(torch.empty(out_features, out_features))

        # Per-unit scales for the multiplicative and the two additive terms.
        self.alpha = Parameter(torch.empty(out_features))
        self.beta1 = Parameter(torch.empty(out_features))
        self.beta2 = Parameter(torch.empty(out_features))

        self.bias = Parameter(torch.empty(out_features))
        self.reset_parameters()

    def reset_parameters(self):
        """Weights ~ N(0, 0.1^2); bias and scales ~ U(-b, b) with b = 1/sqrt(fan_in of U)."""
        for matrix in (self.U, self.W):
            init.normal_(matrix, mean=0, std=0.1)

        fan_in, _ = init._calculate_fan_in_and_fan_out(self.U.t())
        limit = 1 / math.sqrt(fan_in)
        for vector in (self.bias, self.alpha, self.beta1, self.beta2):
            init.uniform_(vector, -limit, limit)

    def inner(self, input, hidden):
        """Compute one state update from the current input and previous state."""
        from_input = torch.matmul(input, self.U)
        from_state = torch.matmul(hidden, self.W)
        pre_activation = (
            self.alpha * from_input * from_state
            + self.beta1 * from_input
            + self.beta2 * from_state
            + self.bias
        )
        return torch.sigmoid(pre_activation)

    def forward(self, input, hidden):
        """Unroll over dim 0 of ``input``; return ``(last_state, all_states)``."""
        states = []
        for step_input in input.unbind(0):
            hidden = self.inner(step_input, hidden)
            states.append(hidden)

        return hidden, torch.stack(states, 0)

### Global Attention

Attention mechanism is composed of two steps:
1. Calculating similarity scores for concerned tensors
  + Target States ($h^{(t)}$) : Decoder hidden states 
  + Source States ($h^{(s)}$) : Encoder hidden states
  + alignment scores $a = \text{softmax}(f(h^{(t)}, h^{(s)}))$

2. Constructing a *Context Vector* that has the information about this attention
  + Context Vector $ c = \sum_i a_ih^{(s)}_{i}$

**Dot Attention**
$$f(h^{(t)}_i, h^{(s)}_i) = h^{(t)T}_i h^{(s)}_i$$

**Concat Attention**
$$f(h^{(t)}_i, h^{(s)}_i) = v^TW[h^{(t)}_i ; h^{(s)}_i]$$

**General Attention**
$$f(h^{(t)}_i, h^{(s)}_i) = h^{(t)T}_i W h^{(s)}_i$$

### Local Attention

$$
a_t(s) = f(h^{(t)}, h^{(s)})\ \exp\left(-\frac{(s - p_t)^2}{2\sigma^2}\right)
$$
**Further Reading**:
+ https://arxiv.org/pdf/1508.04025.pdf

#### Bahdanau Attention

$$ h_t = RNN([x_t; c_{t-1}], h_{t-1})$$

#### Luong Manning Attention

$$ \tilde{h}_t = \tanh(W[c_t ;h_t]) $$
$$P(y_t | y_{<t}, x) = softmax(W_h \tilde{h}_t) $$

In [7]:
class LuongManingDotAttention(nn.Module):
    """Global dot-product attention of decoder states over encoder states."""

    def __init__(self):
        super(LuongManingDotAttention, self).__init__()

    def forward(self, encoder_hidden, enc_mask, decoder_hidden):
        """
        Build one context vector per decoder position.

        encoder_hidden : n_samples x encoder_seq_len x state_dim
        enc_mask : n_samples x encoder_seq_len , 1 for valid sequence and 0 for padding
                   (boolean or 0/1 numeric)
        decoder_hidden : n_samples x decoder_sequence_len x state_dim

        returns : n_samples x decoder_sequence_len x state_dim, the
                  attention-weighted sum of encoder states for every decoder step.

        Note: a sample whose encoder positions are all padding has no position
        to attend to, so its softmax (and therefore its context) is NaN.
        """
        # Dot-product scores: n_samples x encoder_seq_len x decoder_sequence_len
        scores = torch.bmm(encoder_hidden, decoder_hidden.transpose(2, 1))

        # True where the encoder position is padding; the trailing singleton
        # dimension broadcasts over decoder positions.
        padding = ~enc_mask.bool().unsqueeze(dim=-1)

        # -inf scores receive exactly zero weight after the softmax.
        scores = scores.masked_fill(padding, float("-inf"))

        # Normalise over encoder positions (dim 1).
        a = torch.softmax(scores, dim=1)
        c = torch.bmm(a.transpose(2, 1), encoder_hidden)
        return c