# Building and Improving on LSTMs
Let's build a long short-term memory (LSTM) unit from scratch! For a clear and concise tutorial on LSTMs that heavily influenced this work, see [this writeup](http://colah.github.io/posts/2015-08-Understanding-LSTMs/). The inner workings of LSTMs are also explained throughout this notebook.

Once we have implemented a simple recurrent neural network (RNN) that uses our LSTM and tested its performance, we will implement the improvements proposed in the paper "Mogrifier LSTM" (listed in the references below) and compare the results.


## References
[PyTorch LSTM Tutorial](https://mlexplained.com/2019/02/15/building-an-lstm-from-scratch-in-pytorch-lstms-in-depth-part-1/)

[Understanding LSTMs writeup](http://colah.github.io/posts/2015-08-Understanding-LSTMs/)

[Video lesson on LSTMs from Andrew Ng](https://www.coursera.org/lecture/nlp-sequence-models/long-short-term-memory-lstm-KXoay)

[Mogrifier LSTM](https://arxiv.org/abs/1909.01792)

[LSTM Performance](https://github.com/sebastianruder/NLP-progress/blob/master/english/language_modeling.md)

# RNNs
In recurrent neural networks (RNNs), we generally want to maintain some memory of sequential data in order to make predictions. 
>![Image from wikipedia](https://upload.wikimedia.org/wikipedia/commons/b/b5/Recurrent_neural_network_unfold.svg)
>*By François Deloche - Own work, CC BY-SA 4.0, https://commons.wikimedia.org/w/index.php?curid=60109157*

In the diagram above, the main thing to keep in mind is how memory is maintained. Inputs $x_t$ feed into hidden states $h_t$. Each cell (one set of red, blue, and green shapes) in the RNN:

1. receives one element of the input sequence,
2. combines it with the previous hidden state, and
3. passes the resulting hidden state on to the next cell in the chain.

In this way, information about past inputs is carried through the entire network in case it is useful to some cell later in the chain.
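The loop in the figure can be sketched in a few lines of NumPy. This is a minimal sketch under stated assumptions: it uses a tanh activation and row-vector inputs, and it leaves out the output layer. The names `rnn_forward`, `W_x` and `W_h` are illustrative and do not come from the notebook. The sketch shows the key property of an RNN: the same weights are applied at every step, and a change to the first input still reaches the last hidden state.

```python
import numpy as np

def rnn_forward(xs: np.ndarray, W_x: np.ndarray, W_h: np.ndarray,
                b: np.ndarray, h0: np.ndarray) -> np.ndarray:
    """Unroll a vanilla RNN, h_t = tanh(x_t W_x + h_{t-1} W_h + b), over a sequence.

    xs: (seq, input_sz) inputs; h0: (hidden_sz,) initial hidden state.
    Returns the stacked hidden states with shape (seq, hidden_sz).
    """
    h = h0
    hs = []
    for x_t in xs:
        # The same weights are reused at every step; h carries the "memory".
        h = np.tanh(x_t @ W_x + h @ W_h + b)
        hs.append(h)
    return np.stack(hs)

rng = np.random.default_rng(0)
seq, input_sz, hidden_sz = 6, 3, 4
xs = rng.normal(size=(seq, input_sz))
W_x = rng.normal(scale=0.5, size=(input_sz, hidden_sz))
W_h = rng.normal(scale=0.5, size=(hidden_sz, hidden_sz))
b = np.zeros(hidden_sz)
hs = rnn_forward(xs, W_x, W_h, b, np.zeros(hidden_sz))
print(hs.shape)  # (6, 4)
```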

The problem with this approach is vanishing gradients. During backpropagation through time, the gradient with respect to an earlier hidden state is a product of one factor per intervening step. When those factors are smaller than one, the product shrinks towards zero, so information from past hidden states contributes almost nothing to learning. In practice, a basic RNN struggles to remember things from more than a few steps back.
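Here is a toy illustration of the shrinking product. It assumes a one-unit RNN $h_t = \tanh(w h_{t-1} + x_t)$ with a hand-picked weight $w = 0.9$; both are hypothetical and not part of the notebook's model. Each step multiplies the gradient by $w \tanh'(\cdot)$, whose magnitude here is at most 0.9. After 50 steps, $|\partial h_{50} / \partial h_0|$ can therefore be no larger than $0.9^{50} \approx 0.005$.

```python
import numpy as np

def grad_last_wrt_first(w: float, xs: np.ndarray, h0: float = 0.0) -> float:
    """d h_T / d h_0 for the one-unit RNN h_t = tanh(w * h_{t-1} + x_t)."""
    h, grad = h0, 1.0
    for x_t in xs:
        h = np.tanh(w * h + x_t)
        grad *= w * (1.0 - h ** 2)   # chain rule: d h_t / d h_{t-1} = w * tanh'(.)
    return grad

rng = np.random.default_rng(0)
xs = rng.normal(size=50)
for T in (1, 10, 50):
    print(T, abs(grad_last_wrt_first(0.9, xs[:T])))
```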

# Enter LSTMs
The long short-term memory (LSTM) cell was introduced to address vanishing gradients. It replaces the RNN cell described above, which alleviates some of these issues and achieves better performance on sequential data.

## But how do LSTMs work?

Essentially, an LSTM extends the basic RNN cell with a second state, the memory (cell) state $C_t$. This state runs in parallel with the hidden state and is updated mostly additively. Because information and gradients can flow along this path without being squashed at every step, LSTMs mitigate the vanishing gradient problem of basic RNN cells.

The diagram below shows the flow of information in an LSTM cell.

>![image](https://upload.wikimedia.org/wikipedia/commons/thumb/3/3b/The_LSTM_cell.png/1920px-The_LSTM_cell.png)
>*By Guillaume Chevalier - Own work, CC BY 4.0, https://commons.wikimedia.org/w/index.php?curid=71836793*



The equations for the LSTM cell look like this (taken from [here](http://colah.github.io/posts/2015-08-Understanding-LSTMs/)):

$\begin{array}{l}
  f_t = \sigma(W_{f} [x_t,h_{t-1}] + b_{f}) \\
  i_t = \sigma(W_{i} [x_t,h_{t-1}] + b_{i}) \\
  \tilde{C_t} = \tanh(W_{C} [x_t,h_{t-1}] + b_{C}) \\
  o_t = \sigma(W_{o} [x_t,h_{t-1}] + b_{o}) \\
  C_t = f_t * C_{t-1} + i_t * \tilde{C_t} \\
  h_t = o_t * \tanh(C_t)
\end{array}$

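*Note that $[x_t,h_{t-1}]$ refers to the concatenation of the $x_t$ and $h_{t-1}$ vectors, and $*$ denotes element-wise multiplication. Some LSTM implementations (including PyTorch's `nn.LSTM`) instead keep separate weight matrices for $x_t$ and $h_{t-1}$ and add the two products. No information is lost either way: splitting $W$ into the block that multiplies $x_t$ and the block that multiplies $h_{t-1}$ gives exactly the same result.*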
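The following quick NumPy check compares one weight matrix applied to the concatenation $[x_t, h_{t-1}]$ with the same matrix cut into an $x$ block and an $h$ block. The sizes are arbitrary, and it uses the row-vector convention (`x @ W`) that the PyTorch code in this notebook also uses.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, input_sz, hidden_sz = 2, 3, 4
x_t = rng.normal(size=(batch, input_sz))
h_prev = rng.normal(size=(batch, hidden_sz))

# One weight matrix acting on the concatenation [x_t, h_{t-1}].
W = rng.normal(size=(input_sz + hidden_sz, hidden_sz))
concat_version = np.concatenate([x_t, h_prev], axis=1) @ W

# The same matrix split into the rows that see x_t and the rows that see h_{t-1}.
W_x, W_h = W[:input_sz], W[input_sz:]
split_version = x_t @ W_x + h_prev @ W_h

print(np.allclose(concat_version, split_version))  # True
```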

$f_t$ is the function for the forget gate.

$i_t$ is the function for the input gate.

$\tilde{C_t}$ is the function for the candidate memory cell update.

$o_t$ is the function for the output gate, which controls how much of the memory cell is exposed in the hidden state $h_t$.

$C_t$ is the function for the updated memory cell.

$h_t$ is the function for the updated hidden state.

**While initially daunting**, what is happening here is not too complicated. We have three gates:

- a forget gate ($f_t$) that decides how much of the old memory cell $C_{t-1}$ to keep,
- an input gate ($i_t$) that decides how much of the candidate update $\tilde{C_t}$ to add, and
- an output gate ($o_t$) that decides how much of the (squashed) memory cell to expose.

The new hidden state $h$ is therefore a gated view of whatever information is currently in the memory cell.
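The gate arithmetic can be seen on a single step with hand-picked numbers. The sketch below is in NumPy. `lstm_step` is a hypothetical helper, and its stacked gate order (forget, input, candidate, output) is an arbitrary choice for this example. It sets all weights to zero and chooses biases so the forget gate is close to 1 and the input gate close to 0. With those settings the memory cell should pass through almost unchanged, which is the behaviour that lets LSTMs carry information over many steps.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x_t, h_prev, C_prev, W, b):
    """One LSTM step following the equations above.

    W: (input_sz + hidden_sz, 4 * hidden_sz) maps [x_t, h_{t-1}] to the stacked
    pre-activations of the forget, input, candidate and output gates (in that order).
    """
    z = np.concatenate([x_t, h_prev], axis=-1) @ W + b
    f_pre, i_pre, c_pre, o_pre = np.split(z, 4, axis=-1)
    f_t, i_t, o_t = sigmoid(f_pre), sigmoid(i_pre), sigmoid(o_pre)
    C_tilde = np.tanh(c_pre)
    C_t = f_t * C_prev + i_t * C_tilde   # keep some old memory, add some new
    h_t = o_t * np.tanh(C_t)             # expose a gated view of the memory
    return h_t, C_t

input_sz, hidden_sz = 3, 2
W = np.zeros((input_sz + hidden_sz, 4 * hidden_sz))
b = np.concatenate([np.full(hidden_sz, 10.0),    # forget gate: sigmoid(10) is close to 1
                    np.full(hidden_sz, -10.0),   # input gate: sigmoid(-10) is close to 0
                    np.zeros(hidden_sz),         # candidate: tanh(0) = 0
                    np.zeros(hidden_sz)])        # output gate: sigmoid(0) = 0.5
C_prev = np.array([[1.0, -2.0]])
h_t, C_t = lstm_step(np.ones((1, input_sz)), np.zeros((1, hidden_sz)), C_prev, W, b)
print(C_t)  # very close to C_prev
```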

In [1]:
import numpy as np
import torch
import torch.nn as nn
from torch.nn import Parameter
import torch.optim as optim

from typing import *
from pathlib import Path
DATA_ROOT = Path("../data/brown")
N_EPOCHS = 210
from enum import IntEnum
class Dim(IntEnum):
    batch = 0
    seq = 1
    feature = 2

# Implementing the LSTM

Given what we've learned, let's implement our own LSTM in PyTorch!

In [None]:
#only run this cell if running the notebook from Google Colaboratory
!pip install allennlp==0.8.0
%load_ext tensorboard

In [2]:
class NaiveLSTM(nn.Module):
    def __init__(self, input_sz: int, hidden_sz: int):
        super().__init__()
        self.input_size = input_sz
        self.hidden_size = hidden_sz
        #Define/initialize all tensors   
        # forget gate
        self.Wf = Parameter(torch.Tensor(input_sz+hidden_sz, hidden_sz))
        self.bf = Parameter(torch.Tensor(hidden_sz))
        # input gate
        self.Wi = Parameter(torch.Tensor(input_sz+hidden_sz, hidden_sz))
        self.bi = Parameter(torch.Tensor(hidden_sz))
        # Candidate memory cell
        self.Wc = Parameter(torch.Tensor(input_sz+hidden_sz, hidden_sz))
        self.bc = Parameter(torch.Tensor(hidden_sz))
        # output gate
        self.Wo = Parameter(torch.Tensor(input_sz+hidden_sz, hidden_sz))
        self.bo = Parameter(torch.Tensor(hidden_sz))
        
        self.init_weights()
    
    def init_weights(self):
        for p in self.parameters():
            if p.data.ndimension() >= 2:
                nn.init.xavier_uniform_(p.data)
            else:
                nn.init.zeros_(p.data)
        
    #Define forward pass through all LSTM cells across all timesteps.
    #By using PyTorch functions, we get backpropagation for free.
    def forward(self, x: torch.Tensor, 
                init_states: Optional[Tuple[torch.Tensor, torch.Tensor]]=None
               ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Assumes x is of shape (batch, sequence, feature)"""
        batch_sz, seq_sz, _ = x.size()
        hidden_seq = []
        #ht and Ct start as the previous states and end as the output states in each loop below
        if init_states is None:
            ht = torch.zeros((batch_sz,self.hidden_size)).to(x.device)
            Ct = torch.zeros((batch_sz,self.hidden_size)).to(x.device)
        else:
            ht, Ct = init_states
        for t in range(seq_sz): # iterate over the time steps
            xt = x[:, t, :]
            hx_concat = torch.cat((ht,xt),dim=1)

            ### The LSTM Cell!
            ft = torch.sigmoid(hx_concat @ self.Wf + self.bf)
            it = torch.sigmoid(hx_concat @ self.Wi + self.bi)
            Ct_candidate = torch.tanh(hx_concat @ self.Wc + self.bc)
            ot = torch.sigmoid(hx_concat @ self.Wo + self.bo)
            #outputs
            Ct = ft * Ct + it * Ct_candidate
            ht = ot * torch.tanh(Ct)
            ###

            hidden_seq.append(ht.unsqueeze(Dim.batch))
        hidden_seq = torch.cat(hidden_seq, dim=Dim.batch)
        # reshape from shape (sequence, batch, feature) to (batch, sequence, feature)
        hidden_seq = hidden_seq.transpose(Dim.batch, Dim.seq).contiguous()
        return hidden_seq, (ht, Ct)

#sanity testing
#note that our hidden_sz is also our defined output size for each LSTM cell.
batch_sz, seq_len, feat_sz, hidden_sz = 5, 10, 32, 16
arr = torch.randn(batch_sz, seq_len, feat_sz)
lstm = NaiveLSTM(feat_sz, hidden_sz)
ht, (hn, cn) = lstm(arr)
ht.shape #shape should be batch_sz x seq_len x hidden_sz = 5x10x16

torch.Size([5, 10, 16])

# Testing the Implementation
Now that we've covered the basics and have a minimally working LSTM, we'll put our model into action. Our testbed is a character-level language modeling task: given the characters so far, predict the next one. We'll use the Brown Corpus, which you can download with the commands below.

>More information on the Brown corpus can be found [here](https://en.wikipedia.org/wiki/Brown_Corpus).

>"The Brown University Standard Corpus of Present-Day American English (or just Brown Corpus) was compiled in the 1960s by Henry Kučera and W. Nelson Francis at Brown University, Providence, Rhode Island as a general corpus (text collection) in the field of corpus linguistics. It contains 500 samples of English-language text, totaling roughly one million words, compiled from works published in the United States in 1961."

In [3]:
!mkdir -p {DATA_ROOT}
!curl http://www.sls.hawaii.edu/bley-vroman/brown.txt -o {DATA_ROOT / "brown.txt"}

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100 6040k  100 6040k    0     0  5687k      0  0:00:01  0:00:01 --:--:-- 5687k


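We'll let AllenNLP (an NLP library built on PyTorch to simplify training) handle the complexity of building the datasets and training the language model. The cell below:

1. tokenizes the corpus into lowercase characters,
2. groups them into instances of 500 tokens,
3. splits the instances into training (90%) and validation (10%) sets,
4. builds a character vocabulary from the training set, and
5. sets up an iterator that serves batches of 32 instances.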
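The shape of the data can be illustrated in plain Python. Each instance pairs a run of characters with the same run shifted left by one, so the target at every position is the next character. The sketch below is a simplified illustration of that idea; it is not AllenNLP's `LanguageModelingReader` implementation, whose exact chunk boundaries may differ. `make_lm_instances` is a hypothetical helper.

```python
from typing import List, Tuple

def make_lm_instances(text: str, tokens_per_instance: int) -> List[Tuple[List[str], List[str]]]:
    """Cut text into character chunks paired with their next-character targets."""
    chars = list(text.lower())   # lowercase, in the spirit of CharacterTokenizer(lowercase_characters=True)
    instances = []
    for start in range(0, len(chars) - 1, tokens_per_instance):
        # Take one extra character so the targets can be the inputs shifted left by one.
        chunk = chars[start:start + tokens_per_instance + 1]
        instances.append((chunk[:-1], chunk[1:]))
    return instances

for inputs, targets in make_lm_instances("Hello world", 4):
    print(repr("".join(inputs)), "->", repr("".join(targets)))
```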

In [3]:
from allennlp.data.dataset_readers import LanguageModelingReader
from allennlp.data.tokenizers import CharacterTokenizer
from allennlp.data.token_indexers import SingleIdTokenIndexer
from allennlp.data import Vocabulary
from allennlp.data.iterators import BasicIterator
from allennlp.training import Trainer
from sklearn.model_selection import train_test_split

char_tokenizer = CharacterTokenizer(lowercase_characters=True)

reader = LanguageModelingReader(
    tokens_per_instance=500,
    tokenizer=char_tokenizer,
    token_indexers = {"tokens": SingleIdTokenIndexer()},
)

train_ds = reader.read(DATA_ROOT / "brown.txt")
train_ds, val_ds = train_test_split(train_ds, random_state=0, test_size=0.1)

vocab = Vocabulary.from_instances(train_ds)

iterator = BasicIterator(batch_size=32)
iterator.index_with(vocab)

100%|██████████| 11994/11994 [00:00<00:00, 55196.20it/s]
11994it [00:07, 1706.61it/s]
100%|██████████| 10794/10794 [00:03<00:00, 3145.77it/s]


In [4]:
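def train(model: nn.Module, epochs: int, log_dir: str) -> Trainer:
    """Build (but do not start) an AllenNLP Trainer for `model`."""
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        model = model.cuda()  # move parameters before the optimizer is created
    trainer = Trainer(
        model=model,
        optimizer=optim.Adam(model.parameters()),
        iterator=iterator,
        train_dataset=train_ds,
        validation_dataset=val_ds,
        num_epochs=epochs,
        patience=7,                 # stop early after 7 epochs without validation improvement
        serialization_dir=log_dir,  # checkpoints and TensorBoard logs are written here
        summary_interval=10,
        histogram_interval=10,
        cuda_device=0 if use_cuda else -1,
    )
    return trainer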

We build our NLP neural network below using 3 layers:

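1.   An embedding layer (maps each character id to a 50-dimensional vector)
2.   An encoding layer (our LSTM, unrolled over every character in the sequence)
3.   A projection layer (a linear layer mapping each hidden state to logits over the character vocabulary)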
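The loss in `LanguageModel.forward` compares the logits from the projection layer with the next-character targets and skips padding. The NumPy sketch below shows the shapes involved and what `nn.CrossEntropyLoss(ignore_index=...)` computes with its default mean reduction. It averages the negative log-probability of the correct character over the non-padding positions only. `masked_cross_entropy` is a hypothetical helper, not part of AllenNLP or PyTorch.

```python
import numpy as np

def masked_cross_entropy(logits: np.ndarray, targets: np.ndarray, ignore_index: int) -> float:
    """Mean next-character cross-entropy over non-padding positions.

    logits: (batch, seq, vocab) projection outputs; targets: (batch, seq) token ids.
    """
    vocab = logits.shape[-1]
    flat_logits = logits.reshape(-1, vocab)     # like x.view(-1, vocab_size)
    flat_targets = targets.reshape(-1)          # like output_tokens["tokens"].flatten()
    # Numerically stable log-softmax over the vocabulary.
    shifted = flat_logits - flat_logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    keep = flat_targets != ignore_index
    safe_targets = np.where(keep, flat_targets, 0)   # never index with ignore_index itself
    nll = -log_probs[np.arange(flat_targets.size), safe_targets]
    return float(nll[keep].mean())

# A tiny batch: 2 sequences of 3 characters, a vocabulary of 5, padding id 0.
batch, seq, vocab, pad = 2, 3, 5, 0
targets = np.array([[1, 2, 3], [4, 1, pad]])
uniform_logits = np.zeros((batch, seq, vocab))   # every character equally likely
print(masked_cross_entropy(uniform_logits, targets, ignore_index=pad))  # log(5), about 1.609
```

A model that guesses uniformly scores $\ln 5$, which gives a useful reference point when reading the loss values in the training logs.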



In [5]:
from allennlp.modules.token_embedders import Embedding
from allennlp.modules.text_field_embedders import BasicTextFieldEmbedder
from allennlp.models import Model

class LanguageModel(Model):
    """Character-level LM: embedding -> recurrent encoder -> linear projection to vocabulary logits.

    `encoder` can be any batch-first module with a `hidden_size` attribute that
    returns (output_sequence, final_states), e.g. NaiveLSTM or nn.LSTM(batch_first=True).
    """
    def __init__(self, encoder: nn.Module, vocab: Vocabulary,
                 embedding_dim: int=50):
        super().__init__(vocab=vocab)
        # char embedding
        self.vocab_size = vocab.get_vocab_size()
        self.padding_idx = vocab.get_token_index("@@PADDING@@")
        token_embedding = Embedding(
            num_embeddings=self.vocab_size,
            embedding_dim=embedding_dim,
            padding_index=self.padding_idx,
        )
        self.embedding = BasicTextFieldEmbedder({"tokens": token_embedding})
        self.encoder = encoder
        # maps each hidden state to logits over the character vocabulary
        self.projection = nn.Linear(self.encoder.hidden_size, self.vocab_size)
        # padded positions do not contribute to the loss
        self.loss = nn.CrossEntropyLoss(ignore_index=self.padding_idx)
    
    def forward(self, input_tokens: Dict[str, torch.Tensor],
                output_tokens: Optional[Dict[str, torch.Tensor]] = None):
        embs = self.embedding(input_tokens)   # (batch, seq, embedding_dim)
        x, _ = self.encoder(embs)             # (batch, seq, hidden_size)
        x = self.projection(x)                # (batch, seq, vocab_size)
        if output_tokens is not None:
            loss = self.loss(x.view((-1, self.vocab_size)), output_tokens["tokens"].flatten())
        else:
            loss = None
        return {"loss": loss, "logits": x}

Now, let's try training. If you only want to verify that everything runs, set N_EPOCHS in the first code cell to a small number (e.g. 1).

In [6]:
lm_naive = LanguageModel(NaiveLSTM(50, 125), vocab)
LSTM_trainer = train(lm_naive,N_EPOCHS,"./run/lstm")
LSTM_trainer.train()

loss: 2.7291 ||: 100%|██████████| 338/338 [02:12<00:00,  2.54it/s]
loss: 2.3597 ||: 100%|██████████| 38/38 [00:05<00:00,  7.49it/s]
loss: 2.2287 ||: 100%|██████████| 338/338 [02:05<00:00,  2.70it/s]
loss: 2.1205 ||: 100%|██████████| 38/38 [00:04<00:00,  8.63it/s]
loss: 2.0543 ||: 100%|██████████| 338/338 [02:05<00:00,  2.70it/s]
loss: 1.9936 ||: 100%|██████████| 38/38 [00:04<00:00,  8.63it/s]
loss: 1.9513 ||: 100%|██████████| 338/338 [02:05<00:00,  2.70it/s]
loss: 1.9082 ||: 100%|██████████| 38/38 [00:04<00:00,  8.59it/s]
loss: 1.8780 ||: 100%|██████████| 338/338 [02:04<00:00,  2.72it/s]
loss: 1.8450 ||: 100%|██████████| 38/38 [00:04<00:00,  8.74it/s]
loss: 1.8224 ||: 100%|██████████| 338/338 [02:04<00:00,  2.72it/s]
loss: 1.7966 ||: 100%|██████████| 38/38 [00:04<00:00,  8.65it/s]
loss: 1.7784 ||: 100%|██████████| 338/338 [02:04<00:00,  2.72it/s]
loss: 1.7575 ||: 100%|██████████| 38/38 [00:04<00:00,  8.71it/s]
loss: 1.7425 ||: 100%|██████████| 338/338 [02:04<00:00,  2.72it/s]
loss: 1.7

loss: 1.4335 ||: 100%|██████████| 338/338 [02:04<00:00,  2.71it/s]
loss: 1.4438 ||: 100%|██████████| 38/38 [00:04<00:00,  8.68it/s]
loss: 1.4327 ||: 100%|██████████| 338/338 [02:04<00:00,  2.72it/s]
loss: 1.4434 ||: 100%|██████████| 38/38 [00:04<00:00,  8.64it/s]
loss: 1.4317 ||: 100%|██████████| 338/338 [02:05<00:00,  2.70it/s]
loss: 1.4434 ||: 100%|██████████| 38/38 [00:04<00:00,  8.73it/s]
loss: 1.4306 ||: 100%|██████████| 338/338 [02:04<00:00,  2.72it/s]
loss: 1.4414 ||: 100%|██████████| 38/38 [00:04<00:00,  8.67it/s]
loss: 1.4296 ||: 100%|██████████| 338/338 [02:04<00:00,  2.72it/s]
loss: 1.4414 ||: 100%|██████████| 38/38 [00:04<00:00,  8.71it/s]
loss: 1.4287 ||: 100%|██████████| 338/338 [02:04<00:00,  2.72it/s]
loss: 1.4402 ||: 100%|██████████| 38/38 [00:04<00:00,  8.64it/s]
loss: 1.4280 ||: 100%|██████████| 338/338 [02:04<00:00,  2.72it/s]
loss: 1.4389 ||: 100%|██████████| 38/38 [00:04<00:00,  8.71it/s]
loss: 1.4271 ||: 100%|██████████| 338/338 [02:04<00:00,  2.72it/s]
loss: 1.4

loss: 1.3973 ||: 100%|██████████| 338/338 [02:04<00:00,  2.72it/s]
loss: 1.4143 ||: 100%|██████████| 38/38 [00:04<00:00,  8.74it/s]
loss: 1.3969 ||: 100%|██████████| 338/338 [02:04<00:00,  2.72it/s]
loss: 1.4141 ||: 100%|██████████| 38/38 [00:04<00:00,  8.68it/s]
loss: 1.3964 ||: 100%|██████████| 338/338 [02:04<00:00,  2.72it/s]
loss: 1.4139 ||: 100%|██████████| 38/38 [00:04<00:00,  8.73it/s]
loss: 1.3964 ||: 100%|██████████| 338/338 [02:04<00:00,  2.71it/s]
loss: 1.4133 ||: 100%|██████████| 38/38 [00:04<00:00,  8.64it/s]
loss: 1.3958 ||: 100%|██████████| 338/338 [02:04<00:00,  2.72it/s]
loss: 1.4134 ||: 100%|██████████| 38/38 [00:04<00:00,  8.67it/s]
loss: 1.3955 ||: 100%|██████████| 338/338 [02:04<00:00,  2.71it/s]
loss: 1.4132 ||: 100%|██████████| 38/38 [00:04<00:00,  8.68it/s]
loss: 1.3952 ||: 100%|██████████| 338/338 [02:04<00:00,  2.72it/s]
loss: 1.4131 ||: 100%|██████████| 38/38 [00:04<00:00,  8.69it/s]
loss: 1.3950 ||: 100%|██████████| 338/338 [02:04<00:00,  2.71it/s]
loss: 1.4

loss: 1.3820 ||: 100%|██████████| 338/338 [02:04<00:00,  2.72it/s]
loss: 1.4041 ||: 100%|██████████| 38/38 [00:04<00:00,  8.73it/s]
loss: 1.3818 ||: 100%|██████████| 338/338 [02:04<00:00,  2.72it/s]
loss: 1.4034 ||: 100%|██████████| 38/38 [00:04<00:00,  8.68it/s]
loss: 1.3817 ||: 100%|██████████| 338/338 [02:04<00:00,  2.71it/s]
loss: 1.4034 ||: 100%|██████████| 38/38 [00:04<00:00,  8.75it/s]
loss: 1.3814 ||: 100%|██████████| 338/338 [02:04<00:00,  2.72it/s]
loss: 1.4033 ||: 100%|██████████| 38/38 [00:04<00:00,  8.70it/s]
loss: 1.3812 ||: 100%|██████████| 338/338 [02:04<00:00,  2.72it/s]
loss: 1.4025 ||: 100%|██████████| 38/38 [00:04<00:00,  8.72it/s]
loss: 1.3811 ||: 100%|██████████| 338/338 [02:04<00:00,  2.72it/s]
loss: 1.4028 ||: 100%|██████████| 38/38 [00:04<00:00,  8.68it/s]
loss: 1.3810 ||: 100%|██████████| 338/338 [02:04<00:00,  2.71it/s]
loss: 1.4026 ||: 100%|██████████| 38/38 [00:04<00:00,  8.72it/s]
loss: 1.3806 ||: 100%|██████████| 338/338 [02:04<00:00,  2.71it/s]
loss: 1.4

{'best_epoch': 207,
 'peak_cpu_memory_MB': 2750.236,
 'peak_gpu_0_memory_MB': 970,
 'training_duration': '7:31:32.654401',
 'training_start_epoch': 0,
 'training_epochs': 209,
 'epoch': 209,
 'training_loss': 1.3781935724280996,
 'training_cpu_memory_MB': 2750.236,
 'training_gpu_0_memory_MB': 952,
 'validation_loss': 1.401858257619958,
 'best_validation_loss': 1.4012448003417568}

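Now, let's compare with PyTorch's built-in `nn.LSTM`, using the same embedding size, hidden size, and training settings. Training again runs for up to N_EPOCHS epochs and stops early once the validation loss has not improved for 7 epochs. This shows what loss this model is capable of reaching.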

In [7]:
lm_comparison = LanguageModel(nn.LSTM(50, 125, batch_first=True), vocab)
official_LSTM = train(lm_comparison, N_EPOCHS,"./run/officiallstm")
official_LSTM.train()

loss: 2.6807 ||: 100%|██████████| 338/338 [00:13<00:00, 25.51it/s]
loss: 2.3096 ||: 100%|██████████| 38/38 [00:00<00:00, 51.11it/s]
loss: 2.1862 ||: 100%|██████████| 338/338 [00:13<00:00, 25.53it/s]
loss: 2.0887 ||: 100%|██████████| 38/38 [00:00<00:00, 51.24it/s]
loss: 2.0264 ||: 100%|██████████| 338/338 [00:13<00:00, 25.56it/s]
loss: 1.9687 ||: 100%|██████████| 38/38 [00:00<00:00, 51.09it/s]
loss: 1.9279 ||: 100%|██████████| 338/338 [00:13<00:00, 24.90it/s]
loss: 1.8885 ||: 100%|██████████| 38/38 [00:00<00:00, 51.02it/s]
loss: 1.8581 ||: 100%|██████████| 338/338 [00:13<00:00, 25.45it/s]
loss: 1.8262 ||: 100%|██████████| 38/38 [00:00<00:00, 50.78it/s]
loss: 1.8030 ||: 100%|██████████| 338/338 [00:13<00:00, 25.32it/s]
loss: 1.7782 ||: 100%|██████████| 38/38 [00:00<00:00, 50.99it/s]
loss: 1.7596 ||: 100%|██████████| 338/338 [00:13<00:00, 25.14it/s]
loss: 1.7394 ||: 100%|██████████| 38/38 [00:00<00:00, 51.49it/s]
loss: 1.7240 ||: 100%|██████████| 338/338 [00:13<00:00, 25.43it/s]
loss: 1.7

loss: 1.4306 ||: 100%|██████████| 338/338 [00:13<00:00, 25.02it/s]
loss: 1.4393 ||: 100%|██████████| 38/38 [00:00<00:00, 51.41it/s]
loss: 1.4295 ||: 100%|██████████| 338/338 [00:13<00:00, 25.47it/s]
loss: 1.4385 ||: 100%|██████████| 38/38 [00:00<00:00, 51.20it/s]
loss: 1.4286 ||: 100%|██████████| 338/338 [00:13<00:00, 25.48it/s]
loss: 1.4378 ||: 100%|██████████| 38/38 [00:00<00:00, 51.27it/s]
loss: 1.4277 ||: 100%|██████████| 338/338 [00:13<00:00, 25.59it/s]
loss: 1.4377 ||: 100%|██████████| 38/38 [00:00<00:00, 51.18it/s]
loss: 1.4269 ||: 100%|██████████| 338/338 [00:13<00:00, 25.36it/s]
loss: 1.4359 ||: 100%|██████████| 38/38 [00:00<00:00, 51.48it/s]
loss: 1.4260 ||: 100%|██████████| 338/338 [00:13<00:00, 25.41it/s]
loss: 1.4358 ||: 100%|██████████| 38/38 [00:00<00:00, 51.17it/s]
loss: 1.4251 ||: 100%|██████████| 338/338 [00:13<00:00, 25.17it/s]
loss: 1.4348 ||: 100%|██████████| 38/38 [00:00<00:00, 51.25it/s]
loss: 1.4244 ||: 100%|██████████| 338/338 [00:13<00:00, 25.09it/s]
loss: 1.4

loss: 1.3963 ||: 100%|██████████| 338/338 [00:13<00:00, 25.24it/s]
loss: 1.4119 ||: 100%|██████████| 38/38 [00:00<00:00, 50.97it/s]
loss: 1.3960 ||: 100%|██████████| 338/338 [00:13<00:00, 25.01it/s]
loss: 1.4117 ||: 100%|██████████| 38/38 [00:00<00:00, 51.30it/s]
loss: 1.3958 ||: 100%|██████████| 338/338 [00:13<00:00, 25.58it/s]
loss: 1.4124 ||: 100%|██████████| 38/38 [00:00<00:00, 51.27it/s]
loss: 1.3951 ||: 100%|██████████| 338/338 [00:13<00:00, 25.51it/s]
loss: 1.4110 ||: 100%|██████████| 38/38 [00:00<00:00, 51.20it/s]
loss: 1.3949 ||: 100%|██████████| 338/338 [00:13<00:00, 25.49it/s]
loss: 1.4121 ||: 100%|██████████| 38/38 [00:00<00:00, 51.13it/s]
loss: 1.3946 ||: 100%|██████████| 338/338 [00:13<00:00, 25.40it/s]
loss: 1.4115 ||: 100%|██████████| 38/38 [00:00<00:00, 51.20it/s]
loss: 1.3945 ||: 100%|██████████| 338/338 [00:13<00:00, 25.49it/s]
loss: 1.4107 ||: 100%|██████████| 38/38 [00:00<00:00, 51.48it/s]
loss: 1.3941 ||: 100%|██████████| 338/338 [00:13<00:00, 25.17it/s]
loss: 1.4

loss: 1.3818 ||: 100%|██████████| 338/338 [00:13<00:00, 25.44it/s]
loss: 1.4023 ||: 100%|██████████| 38/38 [00:00<00:00, 51.80it/s]
loss: 1.3817 ||: 100%|██████████| 338/338 [00:13<00:00, 25.12it/s]
loss: 1.4021 ||: 100%|██████████| 38/38 [00:00<00:00, 51.29it/s]
loss: 1.3813 ||: 100%|██████████| 338/338 [00:13<00:00, 25.08it/s]
loss: 1.4013 ||: 100%|██████████| 38/38 [00:00<00:00, 51.20it/s]
loss: 1.3811 ||: 100%|██████████| 338/338 [00:13<00:00, 25.49it/s]
loss: 1.4014 ||: 100%|██████████| 38/38 [00:00<00:00, 51.20it/s]
loss: 1.3809 ||: 100%|██████████| 338/338 [00:13<00:00, 25.50it/s]
loss: 1.4020 ||: 100%|██████████| 38/38 [00:00<00:00, 51.26it/s]
loss: 1.3809 ||: 100%|██████████| 338/338 [00:13<00:00, 25.46it/s]
loss: 1.4018 ||: 100%|██████████| 38/38 [00:00<00:00, 51.34it/s]
loss: 1.3805 ||: 100%|██████████| 338/338 [00:13<00:00, 25.34it/s]
loss: 1.4017 ||: 100%|██████████| 38/38 [00:00<00:00, 51.06it/s]
loss: 1.3805 ||: 100%|██████████| 338/338 [00:13<00:00, 25.46it/s]
loss: 1.4

{'best_epoch': 197,
 'peak_cpu_memory_MB': 2751.964,
 'peak_gpu_0_memory_MB': 1017,
 'training_duration': '0:48:11.254805',
 'training_start_epoch': 0,
 'training_epochs': 203,
 'epoch': 203,
 'training_loss': 1.3790437656746815,
 'training_cpu_memory_MB': 2751.964,
 'training_gpu_0_memory_MB': 1014,
 'validation_loss': 1.4007186230860258,
 'best_validation_loss': 1.4002491518070823}

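It looks like our basic implementation, without any optimizations, is roughly 10x slower than the built-in one. In my testing it took 185 seconds versus 15 seconds. In the logs above, a training epoch takes about 2 minutes 4 seconds versus about 13 seconds, and total training took 7.5 hours versus 48 minutes.

The number of epochs needed to reach a similar validation loss is about the same: the best validation loss was 1.4012 at epoch 207 versus 1.4002 at epoch 197. Much of the speed gap likely comes from the fact that, on a GPU, `nn.LSTM` runs the whole sequence through fused cuDNN kernels, while our version launches many small operations per time step from Python. We could investigate further improvements to ours (more efficient batching, fewer calculations per step, etc.), but let's leave it for now.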
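Two common "fewer calculations" ideas can be sketched in NumPy for clarity:

- stack the four gate matrices into one, so each step does a single matrix multiply instead of four;
- compute the input half of that product for every time step up front, leaving only the recurrent product inside the loop.

The functions below are hypothetical. The first mirrors `NaiveLSTM.forward` (including its `cat(h, x)` order), and the check confirms that the restructured version computes the same outputs. Whether a PyTorch port of this would close much of the gap with cuDNN has not been measured here.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_four_matmuls(x, Wf, Wi, Wc, Wo, bf, bi, bc, bo):
    """Mirror of NaiveLSTM.forward: four separate matmuls on cat(h, x) per time step."""
    batch, seq, _ = x.shape
    hidden = bf.shape[0]
    h, C = np.zeros((batch, hidden)), np.zeros((batch, hidden))
    outs = []
    for t in range(seq):
        hx = np.concatenate([h, x[:, t, :]], axis=1)
        f, i = sigmoid(hx @ Wf + bf), sigmoid(hx @ Wi + bi)
        C_tilde, o = np.tanh(hx @ Wc + bc), sigmoid(hx @ Wo + bo)
        C = f * C + i * C_tilde
        h = o * np.tanh(C)
        outs.append(h)
    return np.stack(outs, axis=1)

def lstm_fused(x, Wf, Wi, Wc, Wo, bf, bi, bc, bo):
    """Same math with one stacked weight matrix and the input projection hoisted out of the loop."""
    batch, seq, _ = x.shape
    hidden = bf.shape[0]
    W = np.concatenate([Wf, Wi, Wc, Wo], axis=1)   # (hidden + input, 4 * hidden)
    b = np.concatenate([bf, bi, bc, bo])
    W_h, W_x = W[:hidden], W[hidden:]              # rows acting on h and on x (cat order is h, x)
    x_proj = x @ W_x + b                           # (batch, seq, 4 * hidden), all steps at once
    h, C = np.zeros((batch, hidden)), np.zeros((batch, hidden))
    outs = []
    for t in range(seq):
        z = x_proj[:, t, :] + h @ W_h              # only the recurrent matmul stays in the loop
        f, i, c, o = np.split(z, 4, axis=1)
        C = sigmoid(f) * C + sigmoid(i) * np.tanh(c)
        h = sigmoid(o) * np.tanh(C)
        outs.append(h)
    return np.stack(outs, axis=1)

rng = np.random.default_rng(0)
batch, seq, input_sz, hidden_sz = 2, 5, 3, 4
x = rng.normal(size=(batch, seq, input_sz))
params = ([rng.normal(scale=0.5, size=(hidden_sz + input_sz, hidden_sz)) for _ in range(4)]
          + [rng.normal(size=hidden_sz) for _ in range(4)])
print(np.allclose(lstm_four_matmuls(x, *params), lstm_fused(x, *params)))  # True
```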

# Mogrifier LSTM!

Let's implement a version of the Mogrifier LSTM based on [this paper](https://arxiv.org/abs/1909.01792).

>![mogrifier](https://drive.google.com/uc?id=1yMgPjXW_SV29Y-uuJsKNdknhm5CylWAu)

Essentially, the paper adds a step before each LSTM cell in which the input $x_t$ and the previous hidden state $h_{t-1}$ repeatedly gate each other. Each is modulated by the other before the usual LSTM update. Read the paper if you'd like the intuition behind it.

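Let's look at the equations we need to implement. The rounds start from $x^{-1} = x_t$ and $h^0_{prev} = h_{t-1}$, and the number of rounds $r$ is a hyperparameter:

$x^i = 2\sigma(Q^ih^{i-1}_{prev}) * x^{i-2}$ for odd $i \in \{1, \dots, r\}$

$h^i_{prev} = 2\sigma(R^ix^{i-1}) * h^{i-2}_{prev}$ for even $i \in \{1, \dots, r\}$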

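So all we have to do is add randomly initialized weights $Q$ and $R$ that gate the input $x_t$ and the previous hidden state $h_{t-1}$ in alternating fashion. The factor of 2 means a zero pre-activation gives $2\sigma(0) = 1$ and leaves the vector unchanged, while in general each element can be scaled by anything between 0 and 2. Note that the paper uses a separate $Q^i$ or $R^i$ for each round. The implementation below simplifies this by sharing a single $Q$ and a single $R$ across all rounds. Let's see how it compares!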
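Here is a NumPy sketch of the per-round version from the equations, where odd rounds use their own $Q^i$ and even rounds their own $R^i$. It is included alongside a copy of the shared-matrix loop that `MogLSTM.mogrify` uses, so the two can be compared. Both function names are hypothetical, and the sizes are arbitrary. With $r = 5$ rounds, the per-round version needs three $Q$ matrices (rounds 1, 3, 5) and two $R$ matrices (rounds 2, 4). Passing copies of one $Q$ and one $R$ reduces it to the shared variant.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def mogrify(x, h, Qs, Rs):
    """Mogrifier rounds i = 1..r with one matrix per round, following the equations above.

    x: (batch, input_sz), h: (batch, hidden_sz).
    Qs: (hidden_sz, input_sz) matrices for the odd rounds 1, 3, 5, ...
    Rs: (input_sz, hidden_sz) matrices for the even rounds 2, 4, ...
    """
    assert len(Qs) in (len(Rs), len(Rs) + 1), "need ceil(r/2) Qs and floor(r/2) Rs"
    for i in range(1, len(Qs) + len(Rs) + 1):
        if i % 2 == 1:
            x = 2 * sigmoid(h @ Qs[(i - 1) // 2]) * x   # x^i = 2 sigma(Q^i h^{i-1}) * x^{i-2}
        else:
            h = 2 * sigmoid(x @ Rs[i // 2 - 1]) * h     # h^i = 2 sigma(R^i x^{i-1}) * h^{i-2}
    return x, h

def mogrify_shared(x, h, Q, R, rounds):
    """The notebook's variant: the same Q and R in every round."""
    for i in range(1, rounds + 1):
        if i % 2 == 0:
            h = 2 * sigmoid(x @ R) * h
        else:
            x = 2 * sigmoid(h @ Q) * x
    return x, h

rng = np.random.default_rng(0)
batch, input_sz, hidden_sz, r = 2, 3, 4, 5
x = rng.normal(size=(batch, input_sz))
h = rng.normal(size=(batch, hidden_sz))
Qs = [rng.normal(scale=0.5, size=(hidden_sz, input_sz)) for _ in range((r + 1) // 2)]
Rs = [rng.normal(scale=0.5, size=(input_sz, hidden_sz)) for _ in range(r // 2)]
x_mog, h_mog = mogrify(x, h, Qs, Rs)
print(x_mog.shape, h_mog.shape)  # (2, 3) (2, 4)
```

Per-round matrices multiply the Mogrifier's parameter count by roughly $r/2$, which is one reason a shared (or otherwise reduced) parameterization can be tempting.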

In [8]:
class MogLSTM(nn.Module):
    def __init__(self, input_sz: int, hidden_sz: int, mog_iterations: int):
        super().__init__()
        self.input_size = input_sz
        self.hidden_size = hidden_sz
        self.mog_iterations = mog_iterations
        #Define/initialize all tensors   
        self.Wih = Parameter(torch.Tensor(input_sz, hidden_sz * 4))
        self.Whh = Parameter(torch.Tensor(hidden_sz, hidden_sz * 4))
        self.bih = Parameter(torch.Tensor(hidden_sz * 4))
        self.bhh = Parameter(torch.Tensor(hidden_sz * 4))
        #Mogrifiers
        self.Q = Parameter(torch.Tensor(hidden_sz,input_sz))
        self.R = Parameter(torch.Tensor(input_sz,hidden_sz))

        self.init_weights()
    
    def init_weights(self):
        for p in self.parameters():
            if p.data.ndimension() >= 2:
                nn.init.xavier_uniform_(p.data)
            else:
                nn.init.zeros_(p.data)

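    def mogrify(self, xt, ht):
        """Alternately gate xt with ht (odd rounds) and ht with xt (even rounds).

        Note: a single Q and a single R are shared by all rounds.
        """
        for i in range(1, self.mog_iterations + 1):
            if i % 2 == 0:
                ht = (2 * torch.sigmoid(xt @ self.R)) * ht
            else:
                xt = (2 * torch.sigmoid(ht @ self.Q)) * xt
        return xt, ht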

    
    #Define forward pass through all LSTM cells across all timesteps.
    #By using PyTorch functions, we get backpropagation for free.
    def forward(self, x: torch.Tensor, 
                init_states: Optional[Tuple[torch.Tensor, torch.Tensor]]=None
               ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Assumes x is of shape (batch, sequence, feature)"""
        batch_sz, seq_sz, _ = x.size()
        hidden_seq = []
        #ht and Ct start as the previous states and end as the output states in each loop below
        if init_states is None:
            ht = torch.zeros((batch_sz,self.hidden_size)).to(x.device)
            Ct = torch.zeros((batch_sz,self.hidden_size)).to(x.device)
        else:
            ht, Ct = init_states
        for t in range(seq_sz): # iterate over the time steps
            xt = x[:, t, :]
            xt, ht = self.mogrify(xt,ht) #mogrification
            gates = (xt @ self.Wih + self.bih) + (ht @ self.Whh + self.bhh)
            ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)

            ### The LSTM Cell!
            ft = torch.sigmoid(forgetgate)
            it = torch.sigmoid(ingate)
            Ct_candidate = torch.tanh(cellgate)
            ot = torch.sigmoid(outgate)
            #outputs
            Ct = (ft * Ct) + (it * Ct_candidate)
            ht = ot * torch.tanh(Ct)
            ###

            hidden_seq.append(ht.unsqueeze(Dim.batch))
        hidden_seq = torch.cat(hidden_seq, dim=Dim.batch)
        # reshape from shape (sequence, batch, feature) to (batch, sequence, feature)
        hidden_seq = hidden_seq.transpose(Dim.batch, Dim.seq).contiguous()
        return hidden_seq, (ht, Ct)

#sanity testing
#note that our hidden_sz is also our defined output size for each LSTM cell.
batch_sz, seq_len, feat_sz, hidden_sz = 5, 10, 32, 16
arr = torch.randn(batch_sz, seq_len, feat_sz)
lstm = MogLSTM(feat_sz, hidden_sz, mog_iterations=5)
ht, (hn, cn) = lstm(arr)
ht.shape #shape should be batch_sz x seq_len x hidden_sz = 5x10x16

torch.Size([5, 10, 16])

In [9]:
lm_mog = LanguageModel(MogLSTM(50, 125,5), vocab)
mog_LSTM = train(lm_mog, N_EPOCHS, "./run/mog2")
mog_LSTM.train()

loss: 2.4673 ||: 100%|██████████| 338/338 [03:41<00:00,  1.53it/s]
loss: 2.0708 ||: 100%|██████████| 38/38 [00:07<00:00,  4.86it/s]
loss: 1.9501 ||: 100%|██████████| 338/338 [03:41<00:00,  1.53it/s]
loss: 1.8558 ||: 100%|██████████| 38/38 [00:07<00:00,  4.90it/s]
loss: 1.8036 ||: 100%|██████████| 338/338 [03:41<00:00,  1.53it/s]
loss: 1.7528 ||: 100%|██████████| 38/38 [00:07<00:00,  4.91it/s]
loss: 1.7202 ||: 100%|██████████| 338/338 [03:41<00:00,  1.53it/s]
loss: 1.6866 ||: 100%|██████████| 38/38 [00:07<00:00,  4.88it/s]
loss: 1.6657 ||: 100%|██████████| 338/338 [03:41<00:00,  1.53it/s]
loss: 1.6422 ||: 100%|██████████| 38/38 [00:07<00:00,  4.86it/s]
loss: 1.6272 ||: 100%|██████████| 338/338 [03:41<00:00,  1.53it/s]
loss: 1.6103 ||: 100%|██████████| 38/38 [00:07<00:00,  4.90it/s]
loss: 1.5986 ||: 100%|██████████| 338/338 [03:41<00:00,  1.53it/s]
loss: 1.5860 ||: 100%|██████████| 38/38 [00:07<00:00,  4.86it/s]
loss: 1.5762 ||: 100%|██████████| 338/338 [03:41<00:00,  1.53it/s]
loss: 1.5

loss: 1.4000 ||: 100%|██████████| 338/338 [03:42<00:00,  1.52it/s]
loss: 1.4157 ||: 100%|██████████| 38/38 [00:07<00:00,  4.76it/s]
loss: 1.3992 ||: 100%|██████████| 338/338 [03:43<00:00,  1.51it/s]
loss: 1.4150 ||: 100%|██████████| 38/38 [00:07<00:00,  4.90it/s]
loss: 1.3992 ||: 100%|██████████| 338/338 [03:41<00:00,  1.53it/s]
loss: 1.4141 ||: 100%|██████████| 38/38 [00:07<00:00,  4.84it/s]
loss: 1.3983 ||: 100%|██████████| 338/338 [03:41<00:00,  1.53it/s]
loss: 1.4135 ||: 100%|██████████| 38/38 [00:07<00:00,  4.85it/s]
loss: 1.3975 ||: 100%|██████████| 338/338 [03:41<00:00,  1.53it/s]
loss: 1.4135 ||: 100%|██████████| 38/38 [00:07<00:00,  4.84it/s]
loss: 1.3970 ||: 100%|██████████| 338/338 [03:41<00:00,  1.53it/s]
loss: 1.4134 ||: 100%|██████████| 38/38 [00:07<00:00,  4.85it/s]
loss: 1.3966 ||: 100%|██████████| 338/338 [03:41<00:00,  1.53it/s]
loss: 1.4132 ||: 100%|██████████| 38/38 [00:07<00:00,  4.88it/s]
loss: 1.3959 ||: 100%|██████████| 338/338 [03:41<00:00,  1.53it/s]
loss: 1.4

loss: 1.3781 ||: 100%|██████████| 338/338 [03:41<00:00,  1.53it/s]
loss: 1.3993 ||: 100%|██████████| 38/38 [00:07<00:00,  4.84it/s]
loss: 1.3783 ||: 100%|██████████| 338/338 [03:41<00:00,  1.52it/s]
loss: 1.3997 ||: 100%|██████████| 38/38 [00:07<00:00,  4.90it/s]
loss: 1.3781 ||: 100%|██████████| 338/338 [03:41<00:00,  1.53it/s]
loss: 1.3986 ||: 100%|██████████| 38/38 [00:07<00:00,  4.87it/s]
loss: 1.3778 ||: 100%|██████████| 338/338 [03:41<00:00,  1.52it/s]
loss: 1.3988 ||: 100%|██████████| 38/38 [00:07<00:00,  4.82it/s]
loss: 1.3776 ||: 100%|██████████| 338/338 [03:41<00:00,  1.52it/s]
loss: 1.3983 ||: 100%|██████████| 38/38 [00:07<00:00,  4.90it/s]
loss: 1.3778 ||: 100%|██████████| 338/338 [03:41<00:00,  1.52it/s]
loss: 1.3998 ||: 100%|██████████| 38/38 [00:07<00:00,  4.88it/s]
loss: 1.3776 ||: 100%|██████████| 338/338 [03:41<00:00,  1.53it/s]
loss: 1.3980 ||: 100%|██████████| 38/38 [00:07<00:00,  4.87it/s]
loss: 1.3770 ||: 100%|██████████| 338/338 [03:41<00:00,  1.52it/s]
loss: 1.3

{'best_epoch': 141,
 'peak_cpu_memory_MB': 2771.636,
 'peak_gpu_0_memory_MB': 1115,
 'training_duration': '9:25:44.141752',
 'training_start_epoch': 0,
 'training_epochs': 147,
 'epoch': 147,
 'training_loss': 1.3751222530060265,
 'training_cpu_memory_MB': 2771.636,
 'training_gpu_0_memory_MB': 1115,
 'validation_loss': 1.401227615381542,
 'best_validation_loss': 1.3973440904366343}


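Here is how the three models compare, based on the logged best validation loss and best epoch:

- **Validation loss.** The Mogrifier LSTM reached 1.3973, versus 1.4002 for `nn.LSTM` and 1.4012 for our naive LSTM, so it may be marginally better than the standard LSTM.
- **Epochs to converge.** Its best epoch was 141, versus 197 for `nn.LSTM`, so it took 56 fewer epochs to converge.
- **Speed.** In our implementation it is slower per epoch (about 3 minutes 41 seconds versus 2 minutes 4 seconds for the naive LSTM), so its total training time was longer.

The Mogrifier LSTM paper makes no strong claims regarding increased speed (though they do claim lower complexity, which probably means they optimized much more than I did), but it does report better language modeling performance. While promising, these results come from a single run on one dataset, and a loss difference this small may be within run-to-run noise. They need to be tested with different datasets and neural networks to see how performance truly compares.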

At the time of writing, the authors of the Mogrifier LSTM paper had stated they would release their code, but this had not yet occurred. We could try a few tweaks to better verify what's going on, such as:

- different randomization,
- a different test model,
- other datasets,
- different test methods, and
- more training.

For now, though, we'll leave it here and look for the code release to see if there's anything else that can be improved.

# Reloading Models

To reload a model from the checkpoints that are saved automatically (to either continue training or run inference), copy the following commands into a new cell and run it. In this example we reload the first LSTM created above.


Note that if a checkpoint is located in the folder you are saving training data to (for all the models above, that's "./run/modelName"), it may be recognized and resumed from automatically, without needing to load it manually.
<pre><code>
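model2 = LanguageModel(NaiveLSTM(50, 125), vocab)
with open("run/lstm/model_state_epoch_2.th", 'rb') as f:
    # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine
    model2.load_state_dict(torch.load(f, map_location="cpu"))
model2.eval()  # inference mode; skip this if you are resuming training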
</code></pre>
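If you want the most recent checkpoint rather than a specific epoch, a small helper can pick it out by parsing the epoch number. It parses the number rather than sorting the filenames, because a plain string sort would place `model_state_epoch_10.th` before `model_state_epoch_2.th`. `latest_checkpoint` is a hypothetical helper, and it assumes checkpoints follow the `model_state_epoch_<N>.th` naming shown above. The demo runs on empty placeholder files rather than real checkpoints.

```python
import re
import tempfile
from pathlib import Path
from typing import Optional

def latest_checkpoint(run_dir: str) -> Optional[Path]:
    """Return the model_state_epoch_<N>.th file with the largest N in run_dir, or None."""
    best_epoch, best_path = -1, None
    for path in Path(run_dir).glob("model_state_epoch_*.th"):
        match = re.fullmatch(r"model_state_epoch_(\d+)\.th", path.name)
        if match and int(match.group(1)) > best_epoch:
            best_epoch, best_path = int(match.group(1)), path
    return best_path

# Demo on empty placeholder files (real checkpoints come from the Trainer).
with tempfile.TemporaryDirectory() as demo_dir:
    for n in (2, 9, 10):
        (Path(demo_dir) / f"model_state_epoch_{n}.th").touch()
    newest = latest_checkpoint(demo_dir)
    print(newest.name)  # model_state_epoch_10.th
```

The returned path can then be passed to `torch.load` in place of the hard-coded filename in the snippet above.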

# Visualizing It

Now that we've implemented a few things, let's visualize them all. We're using TensorBoard since it does a lot of the plotting work for us. AllenNLP's Trainer writes each model's logs to that model's own serialization directory, so the cell below looks at one run at a time. Pointing `--logdir` at the parent `./run` directory instead should let TensorBoard show all runs side by side.

**Note: This visualization requires tensorboard, which you may not have installed. It is intended for use on Google Colaboratory**

In [None]:
#import matplotlib.pyplot as plt
#%matplotlib inline
%tensorboard --logdir "./run/mog2"

If you are on Google Colaboratory, you can save and export all the data we just created:

In [None]:
#zip results
!zip -r all.zip ./run

In [None]:
#mount Google Drive so the archive can be copied to it
from google.colab import drive
drive.mount('/content/drive')

In [0]:
!cp all.zip /content/drive/My\ Drive/all.zip