# Long Short-Term Memory (LSTM)

We previously explored RNNs: neural networks that carry a hidden state from one time step to the next, which can be viewed as a single network unrolled along the sequence. A major problem with RNNs is exploding or vanishing gradients. Gradient clipping handles exploding gradients in practice, but the vanishing gradient problem is harder to solve. LSTMs propose a different architecture that keeps a hidden state like an RNN but mitigates the vanishing gradient problem. Another issue with RNNs is that the hidden state tends to forget information from early in the sequence and is biased towards the most recent tokens. LSTMs address this as well: their gates (forget, input and output) control what is kept in, written to, and read from a separate cell state.


In [None]:
from torch import nn
import torch 
import torch.nn.functional as F

In [None]:
class lstm(nn.Module):
    """A single LSTM cell built from four linear gates.

    One call to ``forward`` advances the recurrence by exactly one step; the
    caller is responsible for looping over the sequence and feeding the
    returned ``(h_out, c_out)`` back in.

    Args:
        input_size: Size of the last dimension of the input ``x``.
        hidden_dim: Size of the last dimension of the hidden and cell states.
    """

    def __init__(self, input_size, hidden_dim) -> None:
        super().__init__()
        # Kept so init_h() can size the zero state.
        self.hidden_size = hidden_dim

        # Every gate reads the concatenation [x, h], hence the summed width.
        gate_in_features = input_size + hidden_dim

        self.forget_gate = nn.Sequential(
            nn.Linear(gate_in_features, hidden_dim),
            nn.Sigmoid()
        )
        self.input_gate = nn.Sequential(
            nn.Linear(gate_in_features, hidden_dim),
            nn.Sigmoid()
        )
        self.input_node = nn.Sequential(
            nn.Linear(gate_in_features, hidden_dim),
            nn.Tanh()
        )
        self.output_gate = nn.Sequential(
            nn.Linear(gate_in_features, hidden_dim),
            nn.Sigmoid()
        )

    def forward(self, x, h_in, c_in):
        """Run one LSTM step.

        Args:
            x: Input tensor whose last dimension is ``input_size``.
            h_in: Previous hidden state; must match ``x`` on every dimension
                except the last, which is ``hidden_dim``.
            c_in: Previous cell state, same shape as ``h_in``.

        Returns:
            Tuple ``(h_out, c_out)`` with the same shape as ``h_in``.
        """
        # Concatenate along the feature (last) axis so both 2-D and 3-D
        # inputs work; for 3-D inputs this is the same as dim=2.
        x_h = torch.cat((x, h_in), dim=-1)

        i_gate_output = self.input_gate(x_h)
        i_node_output = self.input_node(x_h)
        o_gate_output = self.output_gate(x_h)
        f_gate_output = self.forget_gate(x_h)

        # Keep part of the old cell state and add the gated candidate.
        c_out = (f_gate_output * c_in) + (i_node_output * i_gate_output)

        # torch.tanh applies the function; nn.Tanh(c_out) would only try to
        # construct a module and never compute anything.
        h_out = torch.tanh(c_out) * o_gate_output

        return h_out, c_out

    def init_h(self):
        """Return an all-zero state of shape ``(1, hidden_dim)``."""
        return torch.zeros(1, self.hidden_size)


In [None]:
class newsLSTM(nn.Module):
    """Sequence classifier: embedding -> ``lstm`` cell unrolled over time -> linear head.

    Args:
        vocab_size: Number of rows in the embedding table.
        embed_size: Embedding dimension, which is also the cell's input size.
        hidden_size: Size of the cell's hidden and cell states.
    """

    def __init__(self, vocab_size, embed_size, hidden_size) -> None:
        super(newsLSTM, self).__init__()

        # Index 0 is treated as padding: its embedding stays at zero.
        self.encoder = nn.Embedding(vocab_size, embed_size, padding_idx=0)
        self.hidden_size = hidden_size

        # The cell's constructor names its second parameter ``hidden_dim``.
        self.lstm = lstm(input_size=embed_size, hidden_dim=hidden_size)

        # The cell runs in one direction only, so the classifier sees a single
        # hidden vector of width hidden_size (not 2 * hidden_size).
        self.hidden2label = nn.Linear(hidden_size, 4)
        # NOTE: kept as an attribute but not applied in forward(), which
        # returns the raw outputs of hidden2label.
        self.softmax = nn.LogSoftmax(dim=1)
        self.dropoutLayer = nn.Dropout(p=0.5)

    def forward(self, x, x_len):
        """Classify a batch of padded token-index sequences.

        Args:
            x: Integer tensor of token indices, batch-first: ``(batch, seq)``.
            x_len: True (unpadded) length of each sequence; a tensor or a
                sequence of ints with one entry per batch row.

        Returns:
            Tensor of shape ``(batch, 4)`` holding the un-normalised outputs of
            the linear head.
        """
        embedded = self.encoder(x)  # (batch, seq, embed)
        batch_size, seq_len = embedded.shape[0], embedded.shape[1]
        lengths = torch.as_tensor(x_len, device=embedded.device)

        # The cell concatenates x and h on the feature axis and expects them
        # to agree on the leading axes, so states carry a leading axis of 1.
        h_t = embedded.new_zeros(1, batch_size, self.hidden_size)
        c_t = embedded.new_zeros(1, batch_size, self.hidden_size)

        for t in range(seq_len):
            x_t = embedded[:, t, :].unsqueeze(0)  # (1, batch, embed)
            h_next, c_next = self.lstm(x_t, h_t, c_t)

            # Only rows whose sequence has not ended take the new state, so
            # padded positions never overwrite the last real hidden state.
            still_active = (lengths > t).view(1, batch_size, 1)
            h_t = torch.where(still_active, h_next, h_t)
            c_t = torch.where(still_active, c_next, c_t)

        hidden = self.dropoutLayer(h_t.squeeze(0))  # (batch, hidden)

        # Linear layer producing one score per label
        label_space = self.hidden2label(hidden)

        return label_space

    def init_h(self):
        """Return an all-zero tensor of shape ``(1, hidden_size)``."""
        return torch.zeros(1, self.hidden_size)

In [None]:
# Load the 'train' split of AG_NEWS. Items are later unpacked as 2-tuples
# whose second element is the text (see the vocab-building code below).
train_iter = AG_NEWS(split='train')

# Convert to list to enable random splitting
train_dataset = list(train_iter)

# 80-20 train-val split. val_size takes the remainder so the two sizes always
# sum to len(train_dataset) despite the int() truncation.
# NOTE(review): no generator/seed is passed to random_split, so the split is
# presumably different on every run -- confirm whether that is intended.
train_size = int(len(train_dataset) * 0.8)  
val_size = len(train_dataset) - train_size  
train_data, val_data = random_split(train_dataset, [train_size, val_size])

# Module-level tokenizer; yield_tokens reads this name when tokenizing.
tokenizer = get_tokenizer("basic_english")

def yield_tokens(data_iter):
    """Lazily yield ``tokenizer(item)`` for every item of ``data_iter``.

    Uses the module-level ``tokenizer``; nothing is tokenized until the
    returned generator is consumed.
    """
    yield from map(tokenizer, data_iter)

# Upper bound on vocabulary entries, special tokens included.
VOCAB_SIZE = 5000

# Build vocab based on the train_data
train_data_iter = (text for _, text in train_data)

# "<pad>" is listed first so that it receives index 0: the model's embedding
# is created with padding_idx=0, and with "<unk>" as the only special the
# unknown-word token would have shared that index with padding.
vocab = build_vocab_from_iterator(
    yield_tokens(train_data_iter),
    specials=["<pad>", "<unk>"],
    max_tokens=VOCAB_SIZE,
)
# Out-of-vocabulary tokens map to "<unk>" instead of raising.
vocab.set_default_index(vocab["<unk>"])