In [None]:
import os

import torch
from torch.utils.data import DataLoader, random_split
from torch.nn import Embedding

from utils.Dataset import IMDBDataset
from utils.Embedding import GloVeEmbedding

In [None]:
# Hyperparameters

seed = 123
torch.manual_seed(seed)  # Seed torch's global RNG so runs are repeatable

# Resource directories handed to IMDBDataset / GloVeEmbedding further down
imdb_dir = './resources/aclImdb'
glove_dir = './resources/glove.6B'

seq_len = 100  # Maximum length of a sequence
vocab_size = 10000  # Size of the tokenizer vocabulary
embedding_dim = 100  # Embedding layer dimension, one of {50, 100, 200, 300}

training_samples = 200  # Thanks to GloVe, fewer samples are enough
validation_samples = 10000
batch_size = 32

epochs = 10

# Pick the compute device: CUDA first, then Apple MPS, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using {device} device")

In [None]:
# Prepare train/valid dataset

train_dir = os.path.join(imdb_dir, 'train')

dataset = IMDBDataset(train_dir, seq_len, vocab_size)
tokenizer = dataset.tokenizer  # Save the tokenizer for reuse

# Samples not assigned to the training or validation subset are discarded.
rest_samples = len(dataset) - training_samples - validation_samples
if rest_samples < 0:
    # The three lengths would still sum to len(dataset), so random_split may
    # not reject them; fail loudly instead of getting a short validation set.
    raise ValueError(
        f'Dataset has {len(dataset)} samples, fewer than the requested '
        f'{training_samples} training + {validation_samples} validation samples')

# The training subset is called `train_split` rather than `train` because a
# `train()` function is defined later in this notebook under that same name.
train_split, valid, _ = random_split(
    dataset, [training_samples, validation_samples, rest_samples])

# Only the training batches need a random order; validation metrics do not
# depend on sample order, so that loader is left unshuffled.
train_dataloader = DataLoader(train_split, batch_size=batch_size, shuffle=True)
valid_dataloader = DataLoader(valid, batch_size=batch_size, shuffle=False)

print(f'Length of train: {len(train_split)}')
print(f'Length of valid: {len(valid)}')

In [None]:
# Model definition

class SimpleModel(torch.nn.Module):
    """Binary classifier: embedding -> flatten -> Linear(32) -> ReLU -> Linear(1) -> Sigmoid.

    Layer sizes are read from the notebook-level ``seq_len``, ``embedding_dim``
    and ``vocab_size`` hyperparameters at construction time.
    """

    def __init__(self, embedding=None):
        """Build the layers.

        Args:
            embedding: Optional embedding module to use as the first layer.
                When ``None``, a trainable ``torch.nn.Embedding`` with
                ``vocab_size + 1`` rows and ``padding_idx=0`` is created.
        """
        super().__init__()
        # Compare against None explicitly: `embedding or ...` would evaluate the
        # module's truthiness and silently discard a supplied module whose
        # __len__ returns 0.
        if embedding is not None:
            self.embedding = embedding
        else:
            self.embedding = torch.nn.Embedding(
                vocab_size+1, embedding_dim, padding_idx=0)
        self.flatten = torch.nn.Flatten()
        self.linear_stack = torch.nn.Sequential(
            # Input width is the flattened (seq_len, embedding_dim) embedding output
            torch.nn.Linear(seq_len*embedding_dim, 32),
            torch.nn.ReLU(),
            torch.nn.Linear(32, 1),
            torch.nn.Sigmoid()
        )

    def forward(self, X):
        """Return the sigmoid output of the network for a batch ``X``.

        ``X`` is passed straight to the embedding layer; it is presumably a
        batch of token indices of shape (batch, seq_len) -- TODO confirm
        against IMDBDataset.
        """
        X = self.embedding(X)
        X = self.flatten(X)
        X = self.linear_stack(X)
        return X

# Build the GloVe-backed embedding layer from the tokenizer's word index,
# then wrap it in SimpleModel and move the model to the selected device.
glove_embedding = GloVeEmbedding(
    glove_dir,
    vocab_size,
    embedding_dim,
    0,
    tokenizer.word_index,
)
model = SimpleModel(embedding=glove_embedding)
model = model.to(device)
print(model)

In [None]:
# %% Training and validation

# Binary cross-entropy, applied to the model's sigmoid output
loss_fn = torch.nn.BCELoss()
# RMSprop over all model parameters with a learning rate of 1e-3
optimizer = torch.optim.RMSprop(model.parameters(), lr=1e-3)

def train(dataloader, model, loss_fn, optimizer):
    """Run one training epoch over ``dataloader``, printing the loss after each batch.

    Args:
        dataloader: Iterable of ``(X, y)`` batches; ``len(dataloader.dataset)``
            is used only for the progress counter.
        model: Module to optimise; it is switched to training mode.
        loss_fn: Callable ``loss_fn(pred, y)`` returning a scalar loss tensor.
        optimizer: Optimizer that updates the model's parameters.

    Each batch is moved to the notebook-level ``device`` before the forward pass.
    """
    size = len(dataloader.dataset)
    model.train()

    # Start the epoch from clean gradients: anything left in .grad (e.g. from an
    # interrupted earlier run of this cell) would otherwise be added into the
    # first update, because backward() accumulates.
    optimizer.zero_grad()

    current = 0
    for X, y in dataloader:
        X, y = X.to(device), y.to(device)
        pred = model(X)

        loss = loss_fn(pred, y)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        current += len(X)
        print(f'loss: {loss.item():>7f} [{current:>5d}/{size:>5d}]')

def validate(dataloader, model, loss_fn):
    """Evaluate ``model`` on ``dataloader`` without tracking gradients.

    Args:
        dataloader: Iterable of ``(X, y)`` batches exposing ``.dataset``.
        model: Module to evaluate; it is switched to evaluation mode.
        loss_fn: Callable ``loss_fn(pred, y)`` returning a scalar loss tensor.

    Returns:
        ``(avg_loss, accuracy)``: the per-batch losses averaged over the number
        of batches, and the fraction of ``len(dataloader.dataset)`` samples
        whose prediction thresholded at 0.5 equals the label.
    """
    model.eval()
    sample_count = len(dataloader.dataset)
    batch_count = len(dataloader)
    total_loss = 0
    hits = 0

    with torch.no_grad():
        for inputs, targets in dataloader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            outputs = model(inputs)
            total_loss += loss_fn(outputs, targets).item()
            # A prediction counts as positive when the output exceeds 0.5
            predicted = (outputs > 0.5).long()
            hits += (predicted == targets.long()).sum().item()

    avg_loss = total_loss / batch_count
    accuracy = hits / sample_count

    print(f'Validation error: \n Accuracy: {(100*accuracy):>0.1f}%, Avg loss: {avg_loss:>8f} \n')
    return avg_loss, accuracy

# Alternate one training pass and one validation pass for each epoch.
for epoch in range(epochs):
    banner = f'Epoch {epoch+1}\n-------------------------------'
    print(banner)
    train(train_dataloader, model, loss_fn, optimizer)
    validate(valid_dataloader, model, loss_fn)

print('Training done!')

In [None]:
# Test