In [1]:
import torch
import numpy as np
import sklearn
import matplotlib.pyplot as plt
import math

In [2]:
class LSTM(torch.nn.Module):
    def __init__(self,hidden_dim,embedding_dim, vocab_size):
        super().__init__()
        self.input_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.embedding = torch.nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim)


        self.Ui = torch.nn.Parameter(torch.Tensor(embedding_dim,hidden_dim))
        self.Vi = torch.nn.Parameter(torch.Tensor(hidden_dim,hidden_dim))
        self.bi = torch.nn.Parameter(torch.Tensor(hidden_dim))

        self.Uf = torch.nn.Parameter(torch.Tensor(embedding_dim,hidden_dim))
        self.Vf = torch.nn.Parameter(torch.Tensor(hidden_dim, hidden_dim))
        self.bf = torch.nn.Parameter(torch.Tensor(hidden_dim))

        self.Uc = torch.nn.Parameter(torch.Tensor(embedding_dim,hidden_dim))
        self.Vc = torch.nn.Parameter(torch.Tensor(hidden_dim, hidden_dim))
        self.bc = torch.nn.Parameter(torch.Tensor(hidden_dim))

        self.Uo = torch.nn.Parameter(torch.Tensor(embedding_dim,hidden_dim))
        self.Vo = torch.nn.Parameter(torch.Tensor(hidden_dim, hidden_dim))
        self.bo = torch.nn.Parameter(torch.Tensor(hidden_dim))

        self.linear = torch.nn.Linear(hidden_dim, 2)


        self.init_weights()

    def init_weights(self):
        stdv = 1.0 / math.sqrt(self.hidden_dim)
        for weight in self.parameters():
            weight.data.uniform_(-stdv, stdv)

    def forward(self, x, init_states=None):
        bs, seq_sz = x.size()
        hidden_seq = []
        x = self.embedding(x)

        if init_states is None:
            h_t, c_t = (
                torch.zeros(bs, self.hidden_dim).to(x.device),
                torch.zeros(bs, self.hidden_dim).to(x.device),
            )
        else:
            h_t, c_t = init_states

        for t in range(seq_sz):
            x_t = x[:,t,:]

            i_t = torch.sigmoid(x_t @ self.Ui + h_t @ self.Vi + self.bi)
            f_t = torch.sigmoid(x_t @ self.Uf + h_t @ self.Vf + self.bf)
            o_t = torch.sigmoid(x_t @ self.Uo + h_t @ self.Vo + self.bo)
            g_t = torch.tanh(x_t @ self.Uc + h_t @ self.Vc + self.bc)

            c_t = f_t * c_t + i_t * g_t 
            h_t = o_t * torch.tanh(c_t )

            hidden_seq.append(h_t.unsqueeze(0))

        hidden_seq = torch.cat(hidden_seq, dim = 0)
        hidden_seq = hidden_seq.transpose(0, 1).contiguous()
        logits = self.linear(h_t)  # last timestep hidden state

        return logits, hidden_seq, (h_t, c_t)










In [3]:
import torchtext
from torchtext.experimental.datasets import IMDB

In [4]:
import os

# Make sure download directory exists
os.makedirs(".data", exist_ok=True)

from torchtext.experimental.datasets import IMDB

train_dataset, test_dataset = IMDB(root=".data")


25000lines [00:02, 9480.63lines/s]


In [5]:
train_list = list(train_dataset)
test_list = list(test_dataset)

print("Train size:", len(train_list))
print("Test size:", len(test_list))
print("Example:", train_list[0])

Train size: 25000
Test size: 25000
Example: (tensor(0), tensor([   13,  1568,    13,   246, 35468,    43,    64,   398,  1135,    92,
            7,    37,     2,  7126,    15,  3363,    11,    60,    11,    17,
           94,   629,    12,  6921,     3,    13,    87,   553,    15,    38,
           94,    11,    17, 20193,    40,  1225,     3,    16,     3,  9263,
           51,    11,   131,   780,     8,  2480,    14,   682,     4,  1575,
          118,     6,   342,     7,   114,  1160,  3052,    13,    72,    75,
            8,    74,    14,    19,   537,     3,     2,   121,    10,  5959,
          194,     6,   191,  3862,   474,  1424,   766,  4314,    42,   489,
            8,   834,   287,    61,    58,    50,   127,     3,    12,   826,
           61,   489,     8,  1132,    47, 11859,     8,   257,    56,   441,
            7,   669,    28,    54,     2,   863, 29737,   209,    50,   781,
         1001,  1304,   147,    18,     2,  2675,   337,     5,  1510,  1304,
        

In [6]:
train_size = int(0.8 * len(train_dataset))
val_size = len(train_dataset) - train_size

train_data, val_data = torch.utils.data.random_split(train_dataset, [train_size, val_size])

print("Train size:", len(train_data))
print("Validation size:", len(val_data))


Train size: 20000
Validation size: 5000


In [7]:
from torch.nn.utils.rnn import pad_sequence
def collate_batch(batch):
    """
    batch: list of tuples (label_tensor, sequence_tensor)
    """
    labels = torch.tensor([entry[0].item() for entry in batch])  # shape (batch_size,)
    sequences = [entry[1] for entry in batch]  # list of tensors
    
    # pad sequences to max length in this batch
    padded_sequences = pad_sequence(sequences, batch_first=True, padding_value=0)
    
    return padded_sequences, labels


In [8]:
batch_size = 32

train_loader = torch.utils.data.DataLoader(
    train_data, batch_size=batch_size, shuffle=True, collate_fn=collate_batch
)

val_loader = torch.utils.data.DataLoader(
    val_data, batch_size=batch_size, shuffle=False, collate_fn=collate_batch
)

test_loader = torch.utils.data.DataLoader(
    test_list, batch_size=batch_size, shuffle=False, collate_fn=collate_batch
)


In [9]:
vocab_size = max(max(seq[1].tolist()) for seq in train_list) + 1

In [10]:
vocab_size

100684

In [11]:
model = LSTM(32,64,vocab_size)
num_epochs = 10
batch_size = 64
loss_fn = torch.nn.CrossEntropyLoss()
learning_rate=1e-3
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [16]:
def train_loop(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    # Set the model to training mode - important for batch normalization and dropout layers
    # Unnecessary in this situation but added for best practices
    model.train()
    total_loss = 0

    for batch, (X, y) in enumerate(dataloader):
        # Compute prediction and loss
        pred = model(X)[0]
        loss = loss_fn(pred, y)

        # Backpropagation
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        total_loss += loss.item()

        if batch % 50 == 0:
            current = batch * X.size(0)
            print(f"loss: {loss.item():>7f}  [{current:>5d}/{size:>5d}]")
    
    avg_loss = total_loss / len(dataloader)
    return avg_loss

def val_loop(dataloader, model, loss_fn):
    model.eval()
    total_loss, correct = 0, 0
    size = len(dataloader.dataset)

    with torch.no_grad():
        for X, y in dataloader:
            pred, _, _ = model(X)
            total_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).sum().item()

    avg_loss = total_loss / len(dataloader)
    accuracy = correct / size
    return avg_loss, accuracy

In [None]:
num_epochs = 10
for epoch in range(num_epochs):
    print(f"Epoch {epoch+1}\n-------------------------------")
    train_loss = train_loop(train_loader, model, loss_fn, optimizer)
    val_loss, val_acc = val_loop(val_loader, model, loss_fn)
    print(f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Val Acc: {100*val_acc:.2f}%\n")
print("Done!") 

Epoch 1
-------------------------------
loss: 0.690272  [    0/20000]
loss: 0.705965  [ 1600/20000]
loss: 0.698567  [ 3200/20000]
loss: 0.689851  [ 4800/20000]
loss: 0.693281  [ 6400/20000]
loss: 0.700279  [ 8000/20000]
loss: 0.696541  [ 9600/20000]
loss: 0.684629  [11200/20000]
loss: 0.698404  [12800/20000]
loss: 0.699753  [14400/20000]
loss: 0.697020  [16000/20000]
loss: 0.681342  [17600/20000]
loss: 0.686454  [19200/20000]
Train Loss: 0.6929 | Val Loss: 0.6938 | Val Acc: 50.92%

Epoch 2
-------------------------------
loss: 0.681326  [    0/20000]
loss: 0.676287  [ 1600/20000]
loss: 0.684601  [ 3200/20000]
loss: 0.691303  [ 4800/20000]
loss: 0.701291  [ 6400/20000]
loss: 0.694792  [ 8000/20000]
loss: 0.695791  [ 9600/20000]
loss: 0.696257  [11200/20000]
loss: 0.697408  [12800/20000]
loss: 0.687193  [14400/20000]
loss: 0.685558  [16000/20000]
loss: 0.706288  [17600/20000]
loss: 0.698996  [19200/20000]
Train Loss: 0.6897 | Val Loss: 0.6935 | Val Acc: 51.10%

Epoch 3
------------------