# Classifying text with LSTM!
In this notebook we'll first introduce techniques for processing text data. The field of Natural Language Processing (NLP) has many techniques that we have not yet covered in this series, so we'll introduce them here to lay the groundwork for working with text data in future notebooks!

[<img src="https://upload.wikimedia.org/wikipedia/commons/thumb/9/93/LSTM_Cell.svg/2880px-LSTM_Cell.svg.png" alt="LSTM cell diagram">](https://en.wikipedia.org/wiki/Long_short-term_memory)

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset
import torch.nn.functional as F

# We'll be using Pytorch's text library called torchtext! 
from torchtext.datasets import AG_NEWS
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
import torchtext.transforms as T

from tqdm.notebook import trange, tqdm

In [None]:
# ---- Hyperparameters ----
learning_rate = 0.0001  # Adam learning rate
nepochs = 20            # number of passes over the training split
batch_size = 32         # samples per optimisation step
max_len = 128           # sequence length limit used when truncating token sequences
data_set_root = "../../datasets"

# Dataset: AG News, where each sample is a short news article paired with a
# single label giving the "type" of the article.
# NOTE: torchtext's AG_NEWS is not a PyTorch Dataset class; it is a function
# that returns a PyTorch DataPipe.
#
# PyTorch DataPipes:
# https://pytorch.org/data/main/torchdata.datapipes.iter.html
#
# Blog post comparing Dataset and DataPipe:
# https://medium.com/deelvin-machine-learning/comparison-of-pytorch-dataset-and-torchdata-datapipes-486e03068c58
dataset_train, dataset_test = (
    AG_NEWS(root=data_set_root, split=split_name) for split_name in ("train", "test")
)

In [None]:
# The tokenizer splits each sentence into "chunks", i.e. "tokens".
# "basic_english" selects torchtext's built-in simple English tokenizer.
tokenizer = get_tokenizer(tokenizer="basic_english")

def yield_tokens(data_iter):
    """Lazily yield the token list for every sample's text.

    `data_iter` is iterated as (label, text) pairs; the label is ignored and
    each text is split with the module-level `tokenizer`.
    """
    yield from (tokenizer(text) for _, text in data_iter)

# The vocab holds every unique token in the training data and gives each token
# its own integer index.

# "Special" tokens are also added so we can signal structure to the model:
# <pad> is appended to the end of shorter sequences so that every sequence in a
#       batch has the same length
# <sos> marks the "Start-Of-Sentence", i.e. the start of the sequence
# <eos> marks the "End-Of-Sentence", i.e. the end of the sequence
# <unk> "unknown" token, used for any token that is not in the vocab
# special_first=True puts the specials first, in the order listed, so
# <pad>=0, <sos>=1, <eos>=2, <unk>=3.
vocab = build_vocab_from_iterator(
    yield_tokens(dataset_train),
    min_freq=2,  # Only keep tokens that appear at least 2 times in the training data
    specials=['<pad>', '<sos>', '<eos>', '<unk>'],  # special-case tokens
    special_first=True,
)

# Look-ups of tokens missing from the vocab fall back to the <unk> index
vocab.set_default_index(vocab['<unk>'])

In [None]:
# Let's have a look at the vocab! get_itos() returns the index-to-string list,
# so position i holds the token whose integer index is i.
index_to_token = vocab.get_itos()
index_to_token

In [None]:
# Define the text pipeline that turns a batch of token lists into one padded
# tensor of vocabulary indices. Special-token indices are looked up in the vocab
# rather than hard-coded, so the pipeline stays correct if the specials change.
text_tranform = T.Sequential(
    # Convert each token string to its integer index in the vocabulary
    T.VocabTransform(vocab=vocab),
    # Prepend <sos> to the beginning of every sequence
    T.AddToken(vocab['<sos>'], begin=True),
    # Crop to max_len - 1 tokens (counting <sos>) so that, once <eos> is
    # appended below, no sequence is longer than max_len tokens
    T.Truncate(max_seq_len=max_len - 1),
    # Append <eos> to the end of every sequence
    T.AddToken(vocab['<eos>'], begin=False),
    # Convert the list of lists into a single (batch, seq) tensor. Sequences
    # shorter than the longest one in the batch are padded at the end with <pad>,
    # so every row of a batch has the same length.
    T.ToTensor(padding_value=vocab['<pad>'])
)

In [None]:
def text_tokenizer(batch):
    """Tokenize every raw string in `batch` with the module-level `tokenizer`.

    Returns a list containing one token list per input string, in input order.
    """
    return [tokenizer(x) for x in batch]
# Training batches are shuffled, and the final incomplete batch is dropped.
data_loader_train = DataLoader(dataset_train, batch_size=batch_size, shuffle=True, num_workers=4, drop_last=True)
# Evaluation needs no shuffling: the accuracy count does not depend on order,
# and a fixed order keeps the per-batch test-loss curve comparable across runs.
data_loader_test = DataLoader(dataset_test, batch_size=batch_size, shuffle=False, num_workers=4)

In [None]:
class LSTM(nn.Module):
    """Embedding -> stacked LSTM -> linear head applied at every time step.

    Args:
        num_emb: number of rows in the embedding table (the vocabulary size).
        output_size: number of logits produced per time step.
        num_layers: number of stacked LSTM layers.
        hidden_size: width of the token embeddings and of the LSTM hidden/cell state.
    """

    def __init__(self, num_emb, output_size, num_layers=1, hidden_size=128):
        super().__init__()

        # Create an embedding for each token
        self.embedding = nn.Embedding(num_emb, hidden_size)

        # nn.LSTM only applies dropout *between* stacked layers and emits a
        # UserWarning if dropout > 0 is requested with a single layer, so only
        # enable it when there is more than one layer.
        self.lstm = nn.LSTM(input_size=hidden_size, hidden_size=hidden_size,
                            num_layers=num_layers, batch_first=True,
                            dropout=0.5 if num_layers > 1 else 0.0)
        self.fc_out = nn.Linear(hidden_size, output_size)

    def forward(self, input_seq, hidden_in=None, mem_in=None):
        """Run the model over a batch of token-index sequences.

        Args:
            input_seq: integer token indices; (batch, seq) when batched, since
                the LSTM is built with batch_first=True.
            hidden_in: initial hidden state, (num_layers, batch, hidden_size).
            mem_in: initial cell state, same shape as `hidden_in`.
                If both states are None, nn.LSTM starts from zeros.

        Returns:
            (logits, hidden_out, mem_out): per-step logits of shape
            (batch, seq, output_size) plus the final hidden and cell states.
        """
        input_embs = self.embedding(input_seq)

        state = None if hidden_in is None and mem_in is None else (hidden_in, mem_in)
        output, (hidden_out, mem_out) = self.lstm(input_embs, state)

        return self.fc_out(output), hidden_out, mem_out

In [None]:
# Prefer the second GPU (index 1) when more than one GPU is present. On a
# single-GPU machine, torch.device(1) would name a non-existent device, so
# fall back to GPU 0 there. Use the CPU when CUDA is unavailable.
if torch.cuda.is_available():
    device = torch.device("cuda", 1 if torch.cuda.device_count() > 1 else 0)
else:
    device = torch.device("cpu")

In [None]:
# Model size; kept as globals because the training loop also uses them to
# build the initial LSTM states.
hidden_size, num_layers = 64, 3

# Build the classifier (4 output classes) and move it to the chosen device
lstm_classifier = LSTM(
    num_emb=len(vocab),
    output_size=4,
    num_layers=num_layers,
    hidden_size=hidden_size,
).to(device)

# Adam over every model parameter, using the learning rate defined above
optimizer = optim.Adam(params=lstm_classifier.parameters(), lr=learning_rate)

# Cross-entropy over the class logits
loss_fn = nn.CrossEntropyLoss()

In [None]:
# Let's see how many parameters our model has! numel() is the element count of
# each parameter tensor. Every parameter is counted, whether or not it requires grad.
num_model_params = sum(param.numel() for param in lstm_classifier.parameters())

# Report millions with two decimals; floor division by 1e6 would show e.g. 0.9M as "0".
print("-This Model Has %d (Approximately %.2f Million) Parameters!" % (num_model_params, num_model_params / 1e6))

In [None]:
# Per-batch loss histories, appended to by the training loop
training_loss_logger, test_loss_logger = [], []

# Per-epoch accuracy histories, appended to by the training loop
training_acc_logger, test_acc_logger = [], []

In [None]:
def _encode_batch(label, text):
    """Turn a raw (label, text) batch into device tensors.

    Returns the padded token-index tensor produced by `text_tranform`, and the
    labels shifted down by one so they index the classifier's 4 outputs from 0.
    """
    text_tokens = text_tranform(text_tokenizer(text)).to(device)
    return text_tokens, (label - 1).to(device)


def _final_logits(text_tokens):
    """Run the classifier from a zero state and return each sequence's final-token logits."""
    bs = text_tokens.shape[0]
    hidden = torch.zeros(num_layers, bs, hidden_size, device=device)
    memory = torch.zeros(num_layers, bs, hidden_size, device=device)
    pred, _, _ = lstm_classifier(text_tokens, hidden, memory)
    # Batches are right-padded with <pad>. pred[:, -1] would therefore be the
    # output *after* the LSTM has read padding for every sequence shorter than
    # the longest one. Read each sequence at its last non-pad position (its
    # <eos> token) instead.
    last_idx = (text_tokens != vocab['<pad>']).sum(dim=1) - 1
    return pred[torch.arange(bs, device=pred.device), last_idx]


pbar = trange(0, nepochs, leave=False, desc="Epoch")
train_acc = 0
test_acc = 0
for epoch in pbar:
    pbar.set_postfix_str('Accuracy: Train %.2f%%, Test %.2f%%' % (train_acc * 100, test_acc * 100))

    # ---- Training ----
    lstm_classifier.train()
    # Counters are reset every epoch so this epoch's accuracy is not mixed
    # with the previous epoch's value.
    correct = 0
    steps = 0
    for label, text in tqdm(data_loader_train, desc="Training", leave=False):
        text_tokens, label = _encode_batch(label, text)
        logits = _final_logits(text_tokens)

        loss = loss_fn(logits, label)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        training_loss_logger.append(loss.item())

        correct += (logits.argmax(1) == label).sum().item()
        steps += label.shape[0]

    train_acc = correct / steps if steps else 0.0
    training_acc_logger.append(train_acc)

    # ---- Evaluation ----
    lstm_classifier.eval()
    correct = 0
    steps = 0
    with torch.no_grad():
        for label, text in tqdm(data_loader_test, desc="Testing", leave=False):
            text_tokens, label = _encode_batch(label, text)
            logits = _final_logits(text_tokens)

            test_loss_logger.append(loss_fn(logits, label).item())

            correct += (logits.argmax(1) == label).sum().item()
            steps += label.shape[0]

    test_acc = correct / steps if steps else 0.0
    test_acc_logger.append(test_acc)

In [None]:
# Loss is logged once per batch. Each curve is stretched over [0, nepochs] so
# train and test (which have different batch counts) share an epoch axis.
_ = plt.figure(figsize=(10, 5))
for loss_history in (training_loss_logger, test_loss_logger):
    _ = plt.plot(np.linspace(0, nepochs, len(loss_history)), loss_history)

_ = plt.legend(["Train", "Test"])
_ = plt.title("Training Vs Test Loss")
_ = plt.xlabel("Epochs")
_ = plt.ylabel("Loss")

In [None]:
# Accuracy is logged once per epoch, so plot epoch k's value at x = k (1-based).
# np.linspace(0, nepochs, nepochs) would put epoch 1 at x=0 and spread the
# points by nepochs / (nepochs - 1), misaligning them with the epoch axis.
_ = plt.figure(figsize=(10, 5))
_ = plt.plot(np.arange(1, len(training_acc_logger) + 1), training_acc_logger)
_ = plt.plot(np.arange(1, len(test_acc_logger) + 1), test_acc_logger)

_ = plt.legend(["Train", "Test"])
_ = plt.title("Training Vs Test Accuracy")
_ = plt.xlabel("Epochs")
_ = plt.ylabel("Accuracy")