In [4]:
import torch
from torch.utils.data import DataLoader, Dataset
from datasets import load_dataset
from collections import Counter
import random
import time
import re
from torch.nn.utils.rnn import pad_sequence
import torch.nn as nn
import torch.optim as optim

SEED = 1234
random.seed(SEED)
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

imdb = load_dataset('imdb')
train_test_split = imdb['train'].train_test_split(test_size=0.2, seed=SEED)
train_data, valid_data = train_test_split['train'], train_test_split['test']
test_data = imdb['test']

MAX_VOCAB_SIZE = 25_000
PAD_IDX, UNK_IDX = 0, 1

def tokenize(text):
    # Basic tokenization using regex to split by whitespace and punctuation
    return re.findall(r'\b\w+\b', text.lower())

def build_vocab(dataset, max_size):
    counter = Counter()
    for example in dataset:
        tokens = tokenize(example['text'])
        counter.update(tokens)
    vocab = ['<pad>', '<unk>'] + [word for word, _ in counter.most_common(max_size-2)]
    return {word: idx for idx, word in enumerate(vocab)}

word2idx = build_vocab(train_data, MAX_VOCAB_SIZE)

class IMDBDataset(Dataset):
    def __init__(self, data, word2idx):
        self.data = data
        self.word2idx = word2idx

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        text = self.data[idx]['text']
        label = 1.0 if self.data[idx]['label'] else 0.0  # Convert to float

        # Tokenize and numericalize
        tokens = tokenize(text)
        indices = [self.word2idx.get(token, UNK_IDX) for token in tokens]

        return torch.tensor(indices, dtype=torch.long), torch.tensor(label, dtype=torch.float)

train_dataset = IMDBDataset(train_data, word2idx)
valid_dataset = IMDBDataset(valid_data, word2idx)
test_dataset = IMDBDataset(test_data, word2idx)

def collate_batch(batch):
    texts, labels = zip(*batch)
    padded_texts = pad_sequence(texts, padding_value=PAD_IDX)
    lengths = torch.tensor([len(text) for text in texts])
    return padded_texts, torch.tensor(labels)

BATCH_SIZE = 64
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)
valid_loader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, collate_fn=collate_batch)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, collate_fn=collate_batch)

class RNN(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=PAD_IDX)
        self.rnn = nn.RNN(embedding_dim, hidden_dim)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, text):
        embedded = self.embedding(text)
        output, hidden = self.rnn(embedded)
        return self.fc(hidden.squeeze(0))

INPUT_DIM = len(word2idx)
EMBEDDING_DIM = 100
HIDDEN_DIM = 256
OUTPUT_DIM = 1
model = RNN(INPUT_DIM, EMBEDDING_DIM, HIDDEN_DIM, OUTPUT_DIM)

def binary_accuracy(preds, y):
    # Round predictions to the closest integer
    rounded_preds = torch.round(torch.sigmoid(preds))
    correct = (rounded_preds == y).float()  # Convert boolean to float for summation
    acc = correct.sum() / len(correct)
    return acc

def train(model, iterator, optimizer, criterion):
    epoch_loss = 0
    epoch_acc = 0

    model.train()

    for batch in iterator:
        optimizer.zero_grad()

        # Unpack the batch
        texts, labels = batch

        predictions = model(texts).squeeze(1)

        loss = criterion(predictions, labels)
        acc = binary_accuracy(predictions, labels)

        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        epoch_acc += acc.item()

    return epoch_loss / len(iterator), epoch_acc / len(iterator)

def evaluate(model, iterator, criterion):
    epoch_loss = 0
    epoch_acc = 0

    model.eval()

    with torch.no_grad():
        for batch in iterator:
            # Unpack the batch
            texts, labels = batch

            predictions = model(texts).squeeze(1)

            loss = criterion(predictions, labels)
            acc = binary_accuracy(predictions, labels)

            epoch_loss += loss.item()
            epoch_acc += acc.item()

    return epoch_loss / len(iterator), epoch_acc / len(iterator)

def epoch_time(start_time, end_time):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs

N_EPOCHS = 2

best_valid_loss = float('inf')
optimizer = optim.Adam(model.parameters())
criterion = nn.BCEWithLogitsLoss()

for epoch in range(N_EPOCHS):

    start_time = time.time()

    train_loss, train_acc = train(model, train_loader, optimizer, criterion)
    valid_loss, valid_acc = evaluate(model, valid_loader, criterion)

    end_time = time.time()

    epoch_mins, epoch_secs = epoch_time(start_time, end_time)

    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), 'tut1-model.pt')

    print(f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f}%')
    print(f'\t Val. Loss: {valid_loss:.3f} |  Val. Acc: {valid_acc*100:.2f}%')


Epoch: 01 | Epoch Time: 1m 13s
	Train Loss: 0.694 | Train Acc: 50.02%
	 Val. Loss: 0.694 |  Val. Acc: 48.85%
Epoch: 02 | Epoch Time: 1m 26s
	Train Loss: 0.699 | Train Acc: 50.19%
	 Val. Loss: 0.695 |  Val. Acc: 49.29%
