In [1]:
from collections import Counter

import numpy as np

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchtext.vocab import Vocab
from torchtext.datasets import AG_NEWS
from torchtext.data.utils import get_tokenizer

# Prepare Data

In [2]:
tokenizer = get_tokenizer('basic_english')

train_iter, test_iter = AG_NEWS()
train_dataset, test_dataset = list(train_iter), list(test_iter)

counter = Counter()
counter.update(['DUMMY'])
for (label, line) in train_dataset:
    counter.update(tokenizer(line))
vocab = Vocab(counter, min_freq=1)

In [3]:
embedding_dim = 32
num_timesteps = 50
batch_size = 128

In [4]:
embedding_pipeline = lambda x: [vocab[token] for token in x]
label_pipeline = lambda x: int(x) - 1

def collate_batch(batch):
    label_list, idx_list = [], []
    
    for (_label, _text) in batch:
        label_list.append(label_pipeline(_label))
        
        tokens = tokenizer(_text)
        if len(tokens) >= num_timesteps:
            tokens = tokens[:num_timesteps]
        else:
            for _ in range(num_timesteps-len(tokens)):
                tokens.append('DUMMY')

        idx_list.append(embedding_pipeline(tokens))
    
    label_list = torch.tensor(label_list, dtype=torch.int64)
    idx_list = np.array(idx_list)

    return label_list, idx_list

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_batch)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_batch)

# Define LSTM Model

In [5]:
class LSTMModel(nn.Module):

    def __init__(self, input_dim, hidden_dim, output_dim):
        super(LSTMModel, self).__init__()
        
        self.hidden_dim = hidden_dim
        self.input_dim = input_dim
        
        self.embedding = nn.Embedding(len(vocab), embedding_dim)

        # Parameters for Forget Gate
        self.f_x, self.f_h, self.f_b = self._get_params()

        # Parameters for Output Gate
        self.o_x, self.o_h, self.o_b = self._get_params()

        # Parameters for Input Gate
        self.i_x, self.i_h, self.i_b = self._get_params()
        self.i2_x, self.i2_h, self.i2_b = self._get_params()
        
        # Dense Layer for final output
#         self.fc = nn.Linear(hidden_dim, 32)
#         self.dropout = nn.Dropout(p=0.5)
#         self.logits = nn.Linear(32, output_dim)

        self.logits = nn.Linear(hidden_dim, output_dim)
        
    def _get_params(self):
        
        x = nn.Parameter(torch.randn([self.input_dim, self.hidden_dim], requires_grad=True, dtype=torch.float32))
        h = nn.Parameter(torch.randn([self.hidden_dim, self.hidden_dim], requires_grad=True, dtype=torch.float32))
        b = nn.Parameter(torch.randn([1, self.hidden_dim], requires_grad=True, dtype=torch.float32))

        return x, h, b
    
    def _lstm_cell(self, embedded_input, h, state):
                
        # Forget Gate Calculation
        forget_gate = torch.sigmoid(torch.matmul(embedded_input, self.f_x) + torch.matmul(h, self.f_h) + self.f_b)  
        
        # Output Gate Calculation
        output_gate = torch.sigmoid(torch.matmul(embedded_input, self.o_x) + torch.matmul(h, self.o_h) + self.o_b)

        # Input Gate Calculation
        input_gate = torch.sigmoid(torch.matmul(embedded_input, self.i_x) + torch.matmul(h, self.i_h) + self.i_b) 
        input2_state = torch.tanh(torch.matmul(embedded_input, self.i2_x) + torch.matmul(h, self.i2_h) + self.i2_b)

        # New State after the LSTM Cell
        state = input2_state * input_gate + state * forget_gate
        
        # New Output from the LSTM Cell
        h = output_gate * torch.tanh(state)
        
        return h, state
    
    def forward(self, x):
        
        # x.shape -> [batch_size, 50, 32] 

        x = self.embedding(torch.from_numpy(x))
 
        h = torch.randn([x.shape[0], hidden_dim], dtype=torch.float32)
        state = torch.randn([x.shape[0], hidden_dim], dtype=torch.float32)
        
        for seq in range(x.shape[1]):
            h, state = self._lstm_cell(x[:,seq,:], h, state) 

        # Flatten the Last Output of the LSTM
        flatten_h = h.squeeze()
        
#         fc1 = torch.relu(self.fc(flatten_h))
#         fc2 = self.dropout(fc1)
#         logits = self.logits(fc2)

        logits = self.logits(flatten_h)
    
        return logits

# Train and Test

In [6]:
hidden_dim = 32
output_dim = 10
 
model = LSTMModel(embedding_dim, hidden_dim, output_dim)
criterion = nn.CrossEntropyLoss()
learning_rate = 0.001
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)  

In [7]:
def evaluate():
    model.eval()
    total_acc, total_count = 0, 0

    with torch.no_grad():
        for idx, (label, embeds) in enumerate(test_dataloader):
            predited_label = model(embeds)
            loss = criterion(predited_label, label)
            total_acc += (predited_label.argmax(1) == label).sum().item()
            total_count += label.size(0)
    return total_acc/total_count

In [8]:
def train(num_epoch, iter_debug=True):
    model.train()

    for epoch in range(num_epoch):
        iter = 0

        for labels, embeds in train_dataloader:

            optimizer.zero_grad()

            outputs = model(embeds)
            loss = criterion(outputs, labels)
            loss.backward()

            #torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
            optimizer.step()

            iter += 1
            if iter % 200 == 0 and iter_debug:
                print('Epoc {} - Iteration {} - Loss: {}'.format(epoch, iter, loss.item()))   

        accuracy = evaluate()
        print('Done Training with Epoch {} - Loss: {}. Test Accuracy: {}'.format(epoch, loss.item(), accuracy))  

In [9]:
train(10, iter_debug=False)

Done Training with Epoch 0 - Loss: 1.4032599925994873. Test Accuracy: 0.2988157894736842
Done Training with Epoch 1 - Loss: 1.2664052248001099. Test Accuracy: 0.41736842105263156
Done Training with Epoch 2 - Loss: 0.9979702234268188. Test Accuracy: 0.5348684210526315
Done Training with Epoch 3 - Loss: 0.826845645904541. Test Accuracy: 0.6551315789473684
Done Training with Epoch 4 - Loss: 0.6566196084022522. Test Accuracy: 0.703421052631579
Done Training with Epoch 5 - Loss: 0.4938162863254547. Test Accuracy: 0.7343421052631579
Done Training with Epoch 6 - Loss: 0.4425247311592102. Test Accuracy: 0.7527631578947368
Done Training with Epoch 7 - Loss: 0.37164077162742615. Test Accuracy: 0.7646052631578948
Done Training with Epoch 8 - Loss: 0.33144453167915344. Test Accuracy: 0.7735526315789474
Done Training with Epoch 9 - Loss: 0.30023258924484253. Test Accuracy: 0.7806578947368421
