# Appendix - LSTM + Attention from Scratch

In [9]:
import torchtext
import torch 
from torch import nn 
import math

# Select the compute device: CUDA when PyTorch reports it available, else the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)

cpu


## LSTM

### 1.1 Vanilla LSTM from Scratch

In [10]:
class LSTM_cell(nn.Module):
    """Single-layer vanilla LSTM written from scratch.

    Despite the name, ``forward`` unrolls the recurrence over every time step
    of a batch-first sequence and returns all hidden states.

    Args:
        input_dim: size of each input vector ``x_t``.
        hidden_dim: size of the hidden state and of the cell state.
    """

    def __init__(self, input_dim:int, hidden_dim:int):
        super().__init__()

        self.hidden_dim = hidden_dim

        # Initialise the trainable parameters.
        # Per gate: U projects the input, W projects the previous hidden
        # state, b is the bias.

        # input gate
        self.U_i = nn.Parameter(torch.Tensor(input_dim, hidden_dim))
        self.W_i = nn.Parameter(torch.Tensor(hidden_dim, hidden_dim))
        self.b_i = nn.Parameter(torch.Tensor(hidden_dim))

        # forget gate
        self.U_f = nn.Parameter(torch.Tensor(input_dim, hidden_dim))
        self.W_f = nn.Parameter(torch.Tensor(hidden_dim, hidden_dim))
        self.b_f = nn.Parameter(torch.Tensor(hidden_dim))

        # candidate cell state (tanh layer)
        self.U_g = nn.Parameter(torch.Tensor(input_dim, hidden_dim))
        self.W_g = nn.Parameter(torch.Tensor(hidden_dim, hidden_dim))
        self.b_g = nn.Parameter(torch.Tensor(hidden_dim))

        # output gate
        self.U_o = nn.Parameter(torch.Tensor(input_dim, hidden_dim))
        self.W_o = nn.Parameter(torch.Tensor(hidden_dim, hidden_dim))
        self.b_o = nn.Parameter(torch.Tensor(hidden_dim))

        # torch.Tensor(...) allocates uninitialised memory, so the values
        # must be filled in before first use.
        self.init_weights()

    def init_weights(self):
        """Fill every parameter from U(-1/sqrt(hidden_dim), 1/sqrt(hidden_dim)).

        The bound depends only on ``hidden_dim`` (not on fan-in + fan-out), so
        this is not Xavier/Glorot initialisation; it is the uniform scheme
        that the ``torch.nn.LSTM`` documentation describes for its own
        parameters.
        """
        stdv = 1.0 / math.sqrt(self.hidden_dim)
        for weight in self.parameters():
            nn.init.uniform_(weight, -stdv, stdv)

    def forward(self, x, init_states = None):
        """Run the LSTM over a whole sequence.

        Args:
            x: input of shape (bs, seq_len, input_dim).
            init_states: optional ``(h_0, c_0)`` pair, each of shape
                (bs, hidden_dim). Zeros are used when omitted.

        Returns:
            ``(output, (h_t, c_t))`` where ``output`` has shape
            (bs, seq_len, hidden_dim) and holds the hidden state of every
            time step, and ``h_t`` / ``c_t`` are the final hidden and cell
            states, each of shape (bs, hidden_dim).
        """
        bs, seq_len, _ = x.shape

        # initialise the hidden state and cell state for the first time step
        if init_states is None:
            h_t = torch.zeros(bs, self.hidden_dim, device=x.device)
            c_t = torch.zeros(bs, self.hidden_dim, device=x.device)
        else:
            h_t, c_t = init_states

        hidden_states = []
        for t in range(seq_len):
            x_t = x[:, t, :]  # (bs, input_dim)

            f_t = torch.sigmoid(h_t @ self.W_f + x_t @ self.U_f + self.b_f)
            i_t = torch.sigmoid(h_t @ self.W_i + x_t @ self.U_i + self.b_i)
            o_t = torch.sigmoid(h_t @ self.W_o + x_t @ self.U_o + self.b_o)
            # The candidate must be tanh (range (-1, 1)); with a sigmoid the
            # cell state could only ever be pushed in the positive direction.
            g_t = torch.tanh(h_t @ self.W_g + x_t @ self.U_g + self.b_g)

            c_t = (f_t * c_t) + (i_t * g_t)
            h_t = o_t * torch.tanh(c_t)

            hidden_states.append(h_t)

        # Stack the per-step (bs, hidden_dim) states along a new time axis
        # to get (bs, seq_len, hidden_dim).
        output = torch.stack(hidden_states, dim=1)

        return output, (h_t, c_t)

In [11]:
# Hyper-parameters used by the cells below.
batch_size = 32

# input_dim is passed to nn.Embedding as the number of embeddings (vocabulary size).
input_dim, embed_dim = 5000, 300
hidden_dim, output_dim = 256, 1

In [None]:
# Checking LSTM cell can run

# Random stand-in batch of shape (batch_size, seq_len, embed_dim); the cell
# consumes already-embedded inputs, so no vocabulary lookup is needed here.
seq_len = 100
test_data = torch.randn(batch_size, seq_len, embed_dim, device=device)

my_LSTM_cell = LSTM_cell(embed_dim, hidden_dim).to(device)
output, (h_t, c_t) = my_LSTM_cell(test_data)

assert output.shape == torch.Size([batch_size, seq_len, hidden_dim])
assert h_t.shape == torch.Size([batch_size, hidden_dim])
assert c_t.shape == torch.Size([batch_size, hidden_dim])

### Vanilla / Peephole / Coupled LSTM from Scratch

In [16]:
class new_LSTM_cell(nn.Module):
    """From-scratch LSTM supporting three gate variants.

    * ``'vanilla'``  - independent input, forget and output gates.
    * ``'coupled'``  - the input gate is tied to the forget gate: ``i_t = 1 - f_t``.
    * ``'peephole'`` - the three gates additionally see the previous cell
      state through the (hidden_dim x hidden_dim) matrices ``P_i``, ``P_f``, ``P_o``.

    Args:
        input_dim: size of each input vector ``x_t``.
        hidden_dim: size of the hidden state and of the cell state.
        lstm_type: one of ``'vanilla'``, ``'coupled'``, ``'peephole'``.

    Raises:
        ValueError: if ``lstm_type`` is not one of the supported variants.
    """

    _LSTM_TYPES = ('vanilla', 'coupled', 'peephole')

    def __init__(self, input_dim: int, hidden_dim: int, lstm_type: str):
        super().__init__()

        # Fail at construction time; otherwise an unknown type only surfaces
        # inside forward() as an unbound f_t / i_t / o_t.
        if lstm_type not in self._LSTM_TYPES:
            raise ValueError(
                f"lstm_type must be one of {self._LSTM_TYPES}, got {lstm_type!r}"
            )

        self.hidden_dim = hidden_dim
        self.lstm_type = lstm_type

        # Initialise the trainable parameters (U: input, W: hidden, b: bias).
        # input gate (created for every variant; 'coupled' does not use it in forward)
        self.U_i = nn.Parameter(torch.Tensor(input_dim, hidden_dim))
        self.W_i = nn.Parameter(torch.Tensor(hidden_dim, hidden_dim))
        self.b_i = nn.Parameter(torch.Tensor(hidden_dim))

        # forget gate
        self.U_f = nn.Parameter(torch.Tensor(input_dim, hidden_dim))
        self.W_f = nn.Parameter(torch.Tensor(hidden_dim, hidden_dim))
        self.b_f = nn.Parameter(torch.Tensor(hidden_dim))

        # candidate cell state
        self.U_g = nn.Parameter(torch.Tensor(input_dim, hidden_dim))
        self.W_g = nn.Parameter(torch.Tensor(hidden_dim, hidden_dim))
        self.b_g = nn.Parameter(torch.Tensor(hidden_dim))

        # output gate
        self.U_o = nn.Parameter(torch.Tensor(input_dim, hidden_dim))
        self.W_o = nn.Parameter(torch.Tensor(hidden_dim, hidden_dim))
        self.b_o = nn.Parameter(torch.Tensor(hidden_dim))

        # peephole connections from the cell state into each gate
        if self.lstm_type == 'peephole':
            self.P_i = nn.Parameter(torch.Tensor(hidden_dim, hidden_dim))
            self.P_f = nn.Parameter(torch.Tensor(hidden_dim, hidden_dim))
            self.P_o = nn.Parameter(torch.Tensor(hidden_dim, hidden_dim))

        # torch.Tensor(...) is uninitialised memory; fill it before use.
        self.init_weights()

    def init_weights(self):
        """Fill every parameter from U(-1/sqrt(hidden_dim), 1/sqrt(hidden_dim))."""
        stdv = 1.0 / math.sqrt(self.hidden_dim)
        for weight in self.parameters():
            nn.init.uniform_(weight, -stdv, stdv)

    def forward(self, x, init_states=None):
        """Run the selected LSTM variant over a whole sequence.

        Args:
            x: input of shape (bs, seq_len, input_dim).
            init_states: optional ``(h_0, c_0)`` pair, each (bs, hidden_dim).
                Zeros are used when omitted.

        Returns:
            ``(output, (h_t, c_t))``: ``output`` is (bs, seq_len, hidden_dim)
            with the hidden state of every step; ``h_t`` and ``c_t`` are the
            final hidden and cell states, each (bs, hidden_dim).
        """
        bs, seq_len, _ = x.shape

        # initialise the hidden state and cell state for the first time step
        if init_states is None:
            h_t = torch.zeros(bs, self.hidden_dim, device=x.device)
            c_t = torch.zeros(bs, self.hidden_dim, device=x.device)
        else:
            h_t, c_t = init_states

        hidden_states = []
        for t in range(seq_len):
            x_t = x[:, t, :]  # (bs, input_dim)

            if self.lstm_type == 'peephole':
                # All three gates read the *previous* cell state c_{t-1}.
                i_t = torch.sigmoid(h_t @ self.W_i + x_t @ self.U_i + c_t @ self.P_i + self.b_i)
                f_t = torch.sigmoid(h_t @ self.W_f + x_t @ self.U_f + c_t @ self.P_f + self.b_f)
                o_t = torch.sigmoid(h_t @ self.W_o + x_t @ self.U_o + c_t @ self.P_o + self.b_o)
            else:
                f_t = torch.sigmoid(h_t @ self.W_f + x_t @ self.U_f + self.b_f)
                o_t = torch.sigmoid(h_t @ self.W_o + x_t @ self.U_o + self.b_o)
                if self.lstm_type == 'coupled':
                    # Whatever is forgotten is replaced by new input.
                    i_t = 1 - f_t
                else:
                    i_t = torch.sigmoid(h_t @ self.W_i + x_t @ self.U_i + self.b_i)

            g_t = torch.tanh(h_t @ self.W_g + x_t @ self.U_g + self.b_g)
            c_t = (f_t * c_t) + (i_t * g_t)
            h_t = o_t * torch.tanh(c_t)

            hidden_states.append(h_t)

        # Stack the per-step (bs, hidden_dim) states along a new time axis
        # to get (bs, seq_len, hidden_dim).
        output = torch.stack(hidden_states, dim=1)
        return output, (h_t, c_t)

### biLSTM from Scratch

#### biLSTM with vanilla

In [17]:
class BiLSTM_model(nn.Module):
    """Bidirectional LSTM classifier built from two ``new_LSTM_cell`` instances.

    One vanilla cell reads the embedded sequence left-to-right, the other
    reads it reversed. Their final hidden states are concatenated, passed
    through a learned sigmoid layer and then a linear classification head.

    Args:
        input_dim: number of embeddings (vocabulary size) for ``nn.Embedding``.
        embed_dim: embedding width, also the input size of both LSTM cells.
        hidden_dim: hidden size of each direction.
        output_dim: number of outputs of the final linear layer.

    Note:
        ``pad_idx`` is read from module scope at construction time.
    """

    def __init__(self, input_dim:int, embed_dim:int, hidden_dim:int, output_dim:int):
        super().__init__()
        self.num_directions = 2

        self.embedding      = nn.Embedding(input_dim, embed_dim, padding_idx = pad_idx)
        self.hidden_dim     = hidden_dim

        # Define forward and backward LSTM cells
        self.forward_lstm   = new_LSTM_cell(embed_dim, hidden_dim, lstm_type = 'vanilla')
        self.backward_lstm  = new_LSTM_cell(embed_dim, hidden_dim, lstm_type = 'vanilla')

        # Learnable weights for combining the forward and backward hidden states
        self.W_h = nn.Parameter(torch.Tensor(hidden_dim * self.num_directions, hidden_dim * self.num_directions))
        self.b_h = nn.Parameter(torch.Tensor(hidden_dim * self.num_directions))

        # Fully connected layer for classification
        self.fc  = nn.Linear(hidden_dim * self.num_directions, output_dim)

        self.init_weights()

    def init_weights(self):
        """Fill every parameter from U(-1/sqrt(hidden_dim), 1/sqrt(hidden_dim)).

        ``self.parameters()`` also yields the embedding table, both LSTM
        cells and ``fc``, so their own initialisation is overwritten here.
        """
        stdv = 1.0 / math.sqrt(self.hidden_dim)
        for weight in self.parameters():
            nn.init.uniform_(weight, -stdv, stdv)

    def forward(self, text, text_length):
        """Classify a batch of token-id sequences.

        Args:
            text: token ids; the embedded tensor is flipped along dim 1, so
                this is treated as (bs, seq_len).
            text_length: accepted for interface compatibility but not used.
                NOTE(review): if batches are padded, the reversed sequence
                starts with the padding positions - confirm this is intended.

        Returns:
            Output of the final linear layer, shape (bs, output_dim).
        """
        # Embedding layer, plus a time-reversed copy for the backward direction
        embedded = self.embedding(text)
        embedded_flip = torch.flip(embedded, [1])

        # Forward and backward LSTM pass (only the final hidden states are used)
        _, (hn_forward, _)  = self.forward_lstm (embedded, init_states = None)
        _, (hn_backward, _) = self.backward_lstm(embedded_flip, init_states = None)

        # hn_forward  : state after the last position of the sequence
        # hn_backward : state after the first position (last step of the reversed pass)
        concat_hn = torch.cat((hn_forward, hn_backward), dim = 1)  # (bs, 2 * hidden_dim)

        # Learned combination of the two directions. The bias is self.b_h;
        # the attribute b_n never existed and raised AttributeError.
        ht = torch.sigmoid(concat_hn @ self.W_h + self.b_h)

        return self.fc(ht) # Pass ht to another linear layer to get the output for binary classification

## Attention

### LSTM + General Attention

In [19]:
import torch.nn as nn
from torch.nn import functional as F

In [None]:
class LSTM_GAtt(nn.Module):
    """LSTM classifier with dot-product ("general") attention.

    The concatenated final hidden state is used as the query against every
    LSTM output position; the attention-weighted sum of the outputs is fed
    to a linear classification head.

    Args:
        input_dim: number of embeddings (vocabulary size) for ``nn.Embedding``.
        embed_dim: embedding width / LSTM input size.
        hidden_dim: LSTM hidden size per direction.
        output_dim: number of outputs of the final linear layer.

    Note:
        ``pad_idx``, ``num_layers``, ``bidirectional`` and ``dropout`` are
        read from module scope at construction time. The classifier input
        size is ``hidden_dim * 2``, which presumably assumes
        ``bidirectional`` is True - TODO confirm.
    """

    def __init__(self, input_dim:int, embed_dim:int, hidden_dim:int, output_dim:int):
        super().__init__()

        self.embedding = nn.Embedding(input_dim, embed_dim, padding_idx=pad_idx)

        # Use pytorch LSTM
        self.lstm = nn.LSTM(
            embed_dim,
            hidden_dim,
            num_layers = num_layers,
            bidirectional = bidirectional,
            dropout = dropout,
            batch_first = True,
        )

        # linear layer for binary classification
        self.fc = nn.Linear(hidden_dim * 2, output_dim)

    def attention_net(self, lstm_output, hn):
        """Attend over ``lstm_output`` using ``hn`` as the query.

        Args:
            lstm_output: (bs, seq_len, d) keys and values.
            hn: (bs, d) query vector.

        Returns:
            Context vector of shape (bs, d).
        """
        query = hn.unsqueeze(2)  # (bs, d, 1)

        # Dot product of every position with the query. The outputs are used
        # directly as keys and values: nothing below mutates them in place,
        # so copying them first would only cost memory.
        alignment_score = torch.bmm(lstm_output, query).squeeze(2)  # (bs, seq_len)

        soft_attn_weights = F.softmax(alignment_score, dim=1)  # (bs, seq_len)

        # (bs, d, seq_len) @ (bs, seq_len, 1) -> (bs, d, 1) -> (bs, d)
        context = torch.bmm(
            lstm_output.transpose(1, 2), soft_attn_weights.unsqueeze(2)
        ).squeeze(2)

        return context

    def forward(self, text, text_lengths):
        """Classify a batch of token ids; ``text_lengths`` is accepted but unused."""
        embedded = self.embedding(text) # shape (bs, seq_len, emb_dim)

        lstm_output, (hn, cn) = self.lstm(embedded)

        # this is how we concatenate the forward hidden and backward hidden from Pytorch's BiLSTM
        hn = torch.cat((hn[-2,:,:], hn[-1,:,:]), dim = 1)

        attn_output = self.attention_net(lstm_output, hn)

        return self.fc(attn_output)
        

### LSTM + Self Attention

In [None]:
# This attention mask is applied after Q @ K^T, so its shape is (batch, seq_len, seq_len)
def get_pad_mask(text, pad_idx=0):  #[batch, seq_len]
    """Build a boolean mask that marks padded key positions.

    Args:
        text: token ids of shape (batch, seq_len).
        pad_idx: id of the padding token (defaults to 0).

    Returns:
        Bool tensor of shape (batch, seq_len, seq_len) where
        ``mask[b, i, j]`` is True iff ``text[b, j] == pad_idx``, i.e. every
        query row ``i`` carries the same per-key padding flags. The result is
        an expanded view, not a fresh copy.
    """
    batch_size, seq_len = text.size()
    # (batch, 1, seq_len): the singleton axis is what gets expanded over the query positions
    pad_mask = text.eq(pad_idx).unsqueeze(1)
    return pad_mask.expand(batch_size, seq_len, seq_len)

class LSTM_SelfAtt(nn.Module):
    """LSTM classifier with dot-product self-attention over the LSTM outputs.

    Args:
        input_dim: number of embeddings (vocabulary size) for ``nn.Embedding``.
        embed_dim: embedding width / LSTM input size.
        hidden_dim: LSTM hidden size per direction.
        output_dim: number of outputs of the final linear layer.
        len_reduction: how the per-position context is collapsed to one
            vector per example: ``"mean"``, ``"sum"`` or ``"last"``.

    Raises:
        ValueError: if ``len_reduction`` is not one of the supported options.

    Note:
        ``pad_idx``, ``num_layers``, ``bidirectional`` and ``dropout`` are
        read from module scope at construction time. The Q/K/V and
        classifier sizes are ``hidden_dim * 2``, which presumably assumes
        ``bidirectional`` is True - TODO confirm.
    """

    _LEN_REDUCTIONS = ("mean", "sum", "last")

    def __init__(self, input_dim, embed_dim, hidden_dim, output_dim, len_reduction):
        super().__init__()

        # Reject unsupported options up front; otherwise self_attention_net
        # would fall through every branch and return None.
        if len_reduction not in self._LEN_REDUCTIONS:
            raise ValueError(
                f"len_reduction must be one of {self._LEN_REDUCTIONS}, got {len_reduction!r}"
            )

        self.embedding = nn.Embedding(input_dim, embed_dim, padding_idx=pad_idx)

        # let's use pytorch's LSTM
        self.lstm = nn.LSTM(embed_dim,
                            hidden_dim,
                            num_layers=num_layers,
                            bidirectional=bidirectional,
                            dropout=dropout,
                            batch_first=True)

        # Log-softmax layer for classification. Kept as an attribute for
        # compatibility; it is NOT used for the attention weights, which need
        # real probabilities normalised over the key axis.
        self.softmax       = nn.LogSoftmax(dim=1)

        # initialize three linear layers for Q, K, V
        self.lin_Q = nn.Linear(hidden_dim * 2, hidden_dim * 2)
        self.lin_K = nn.Linear(hidden_dim * 2, hidden_dim * 2)
        self.lin_V = nn.Linear(hidden_dim * 2, hidden_dim * 2)

        self.len_reduction = len_reduction

        # Whether self_attention_net applies the pad mask; forward() overwrites
        # this on every call. Defined here so the method also works standalone.
        self.mask = True

        # Linear layer for binary classification
        self.fc    = nn.Linear(hidden_dim * 2, output_dim)

    def self_attention_net(self, lstm_output, pad_mask):
        """Self-attend over ``lstm_output`` and reduce over the sequence axis.

        Args:
            lstm_output: (bs, seq_len, d) LSTM outputs.
            pad_mask: bool tensor (bs, seq_len, seq_len), True where the key
                position is padding; may be None to skip masking.

        Returns:
            Tensor of shape (bs, d).

        Raises:
            ValueError: if ``self.len_reduction`` is not supported.
        """
        # Project the LSTM outputs into queries, keys and values.
        Q = self.lin_Q(lstm_output) # SHAPE : (bs, seq_len, d)
        K = self.lin_K(lstm_output) # SHAPE : (bs, seq_len, d)
        V = self.lin_V(lstm_output) # SHAPE : (bs, seq_len, d)

        # attention score: row = query position, column = key position
        alignment_score = torch.matmul(Q, K.transpose(1, 2)) # SHAPE : (bs, seq_len, seq_len)

        # Apply padding mask: padded keys get a very negative score so they
        # receive (numerically) zero weight after the softmax.
        if self.mask and pad_mask is not None:
            alignment_score = alignment_score.masked_fill(pad_mask, -1e9)

        # Softmax over the KEY axis so each query's weights sum to 1.
        # (A log-softmax over dim=1 yields log-probabilities normalised across
        # queries, which are not attention weights and ignore the key mask.)
        soft_attn_weights = torch.softmax(alignment_score, dim=-1)

        # Weighted sum to get context
        context = torch.matmul(soft_attn_weights, V) # SHAPE : (bs, seq_len, d)

        # Length reduction options: mean, sum, last
        if self.len_reduction == "mean":
            return torch.mean(context, dim=1)
        elif self.len_reduction == "sum":
            return torch.sum(context, dim=1)
        elif self.len_reduction == "last":
            return context[:, -1, :]
        raise ValueError(f"unsupported len_reduction: {self.len_reduction!r}")

    def forward(self, text, text_lengths, mask=True):
        """Classify a batch of token ids.

        Args:
            text: token ids of shape (batch_size, seq_len).
            text_lengths: accepted for interface compatibility but not used.
            mask: when True, padded key positions are masked out of attention.

        Returns:
            Output of the final linear layer, shape (batch_size, output_dim).
        """
        self.mask = mask
        # NOTE(review): get_pad_mask treats id 0 as padding while the embedding
        # uses pad_idx - confirm the two agree.
        pad_mask = get_pad_mask(text) if mask else None

        embedded = self.embedding(text) # SHAPE : (batch_size, seq_len, embed_dim)

        # Only the per-position outputs are needed; the final states are unused.
        lstm_output, _ = self.lstm(embedded)

        # Self-attention
        attn_output = self.self_attention_net(lstm_output, pad_mask)

        return self.fc(attn_output) # Classification using a linear layer