In [38]:
from models import BaselineClassifier
from data import load_imdb, load_imdb_synth, load_xor
import torch
import torch.nn as nn
from torch import device
from q1 import pad_batch
import random
import torch.nn.functional as F
import itertools

In [39]:
# Load the IMDb data via data.load_imdb and unpack its four parts: the training
# pair, the validation pair, the (i2w, w2i) vocabulary pair and the class count.
# NOTE(review): final=False presumably selects a validation split instead of the
# held-out test set -- confirm in data.py.
(x_train_1, y_train_1), (x_val_1, y_val_1), (i2w_1, w2i_1), numcls_1 = load_imdb(final=False)
# Bundle inputs and labels into the (x_list, y_list) pairs that iterate_batches unpacks.
train_data1 = (x_train_1, y_train_1)
val_data1   = (x_val_1, y_val_1)

In [40]:
# Load the dataset returned by data.load_imdb_synth (called with its defaults) and
# unpack it the same way as the IMDb cell: train pair, validation pair,
# (i2w, w2i) vocabulary pair and class count.
(x_train_2, y_train_2), (x_val_2, y_val_2), (i2w_2, w2i_2), numcls_2 = load_imdb_synth()
# Bundle inputs and labels into the (x_list, y_list) pairs that iterate_batches unpacks.
train_data2 = (x_train_2, y_train_2)
val_data2   = (x_val_2, y_val_2)

In [41]:
# Load the dataset returned by data.load_xor (called with its defaults) and unpack
# it into train pair, validation pair, (i2w, w2i) vocabulary pair and class count.
(x_train_3, y_train_3), (x_val_3, y_val_3), (i2w_3, w2i_3), numcls_3 = load_xor()
# Bundle inputs and labels into the (x_list, y_list) pairs that iterate_batches unpacks.
train_data3 = (x_train_3, y_train_3)
val_data3   = (x_val_3, y_val_3)

In [42]:
def iterate_batches(dataset, batch_size, pad_idx, shuffle=True):
    """
    Split a dataset into padded mini-batches.

    Args:
        dataset: (x_list, y_list) pair; x_list holds the input sequences and
            y_list the integer class label of each sequence.
        batch_size: maximum number of examples per batch (must be >= 1); the
            final batch is smaller when the dataset size is not a multiple.
        pad_idx: padding index forwarded unchanged to ``pad_batch``.
        shuffle: when True the examples are visited in a new random order on
            every call (drawn from the global ``random`` module, so call
            ``random.seed`` for reproducible runs); when False the dataset
            order is kept.

    Returns:
        A list of (x_batch, y_batch) tuples. x_batch is whatever ``pad_batch``
        returns for the selected sequences and y_batch is a 1-D ``torch.long``
        tensor with the matching labels.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    x_data, y_data = dataset
    indices = list(range(len(x_data)))
    if shuffle:
        # Shuffle the index list rather than the data so inputs and labels
        # stay aligned and the caller's lists are never mutated.
        random.shuffle(indices)

    batches = []
    for start in range(0, len(indices), batch_size):
        batch_idx = indices[start:start + batch_size]
        x_seqs = [x_data[j] for j in batch_idx]
        y_labels = [y_data[j] for j in batch_idx]

        x = pad_batch(x_seqs, pad_idx)                # padded inputs, one row per example
        y = torch.tensor(y_labels, dtype=torch.long)  # (B,)
        batches.append((x, y))
    return batches


In [43]:
def train_epochs(model, train_data, batch_size, pad_idx, optimizer, num_epochs=5):
    """
    Train ``model`` for ``num_epochs`` passes over ``train_data``.

    Each epoch asks ``iterate_batches`` for shuffled batches, minimises the
    cross-entropy between ``model(x)`` and ``y`` with ``optimizer``, and prints
    the example-weighted mean loss and accuracy of that epoch. The statistics
    are gathered from the forward pass made before each optimizer step.

    Args:
        model: module mapping a batch ``x`` to per-class scores of shape (B, C).
        train_data: (x_list, y_list) pair as accepted by ``iterate_batches``.
        batch_size: number of examples per batch.
        pad_idx: padding index forwarded to ``iterate_batches``.
        optimizer: optimizer already bound to ``model``'s parameters.
        num_epochs: number of passes over the data.

    Returns:
        (avg_loss, acc) of the last epoch. Both are NaN when ``num_epochs`` is
        smaller than 1, i.e. when no training took place.

    Raises:
        ValueError: if an epoch yields no examples (empty ``train_data``).
    """
    # Defined up front so a zero-epoch call returns NaN instead of failing on
    # unbound locals at the final return.
    avg_loss = acc = float("nan")

    # Make sure train-only layers (dropout, batch-norm) are active even if the
    # model was left in eval mode by an earlier call.
    model.train()

    for epoch in range(1, num_epochs + 1):
        total_loss = 0.0
        total_correct = 0
        total_examples = 0

        print(f"\nEpoch {epoch}/{num_epochs}")

        for x, y in iterate_batches(train_data, batch_size, pad_idx, shuffle=True):

            optimizer.zero_grad()
            output = model(x)
            loss = F.cross_entropy(output, y)

            loss.backward()
            optimizer.step()

            # stats for this batch; the loss is a batch mean, so weight it by
            # the real batch size (the last batch can be smaller)
            batch_size_actual = x.size(0)
            total_loss += loss.item() * batch_size_actual
            preds = output.argmax(dim=1)
            total_correct += (preds == y).sum().item()
            total_examples += batch_size_actual

        if total_examples == 0:
            raise ValueError("train_data produced no examples; cannot compute epoch metrics")

        # epoch metrics
        avg_loss = total_loss / total_examples
        acc = total_correct / total_examples

        print(f"Training loss: {avg_loss:.4f}  |  accuracy: {acc:.4f}")

    return avg_loss, acc

In [44]:
def evaluate(model, val_data, batch_size, pad_idx):
    """
    Compute mean cross-entropy loss and accuracy of ``model`` on ``val_data``.

    The model is switched to eval mode for the duration of the call so that
    train-only layers (dropout, batch-norm) do not distort the metrics, and its
    previous train/eval mode is restored before returning. No gradients are
    recorded and no parameters are updated.

    Args:
        model: module mapping a batch ``x`` to per-class scores of shape (B, C).
        val_data: (x_list, y_list) pair as accepted by ``iterate_batches``.
        batch_size: number of examples per batch.
        pad_idx: padding index forwarded to ``iterate_batches``.

    Returns:
        (avg_loss, acc): the example-weighted mean loss and the fraction of
        correctly classified examples.

    Raises:
        ValueError: if ``val_data`` yields no examples.
    """
    total_loss = 0.0
    total_correct = 0
    total_examples = 0

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for x, y in iterate_batches(val_data, batch_size, pad_idx, shuffle=False):
                output = model(x)
                loss = F.cross_entropy(output, y)

                # the loss is a batch mean; weight by the real batch size so a
                # smaller final batch does not skew the average
                batch_size_actual = x.size(0)
                total_loss += loss.item() * batch_size_actual

                preds = output.argmax(dim=1)
                total_correct += (preds == y).sum().item()
                total_examples += batch_size_actual
    finally:
        # leave the model in whatever mode the caller had it in
        model.train(was_training)

    if total_examples == 0:
        raise ValueError("val_data produced no examples; cannot compute metrics")

    avg_loss = total_loss / total_examples
    acc = total_correct / total_examples
    return avg_loss, acc

In [45]:
# Stand-alone baseline setup for the IMDb data: a classifier sized to the
# vocabulary (all other constructor arguments left at their defaults), an Adam
# optimizer over its parameters, a batch size, and the index of the '.pad' token.
# NOTE(review): in the cells visible here only pad_idx1 is used afterwards --
# grid_search builds its own model/optimizer and iterates over its own batch
# sizes. Confirm whether baseline/optimizer/batch_size are still needed.
baseline = BaselineClassifier(vocab_size=len(i2w_1))
optimizer = torch.optim.Adam(baseline.parameters(), lr=0.001)
batch_size = 64
pad_idx1 = w2i_1['.pad']

In [46]:
def grid_search(train_data, val_data, vocab_size, num_classes, pad_idx,
                pools=('mean', 'max', 'first'), lrs=(1e-3, 3e-4),
                batch_sizes=(64, 128), emb_dim=300, num_epochs=5):
    """
    Train and validate a fresh BaselineClassifier for every hyperparameter combination.

    Combinations are visited with batch size as the outermost loop, then
    learning rate, then pooling mode; one summary line is printed per run.

    Args:
        train_data: (x_list, y_list) training pair passed to ``train_epochs``.
        val_data: (x_list, y_list) validation pair passed to ``evaluate``.
        vocab_size: vocabulary size handed to the classifier.
        num_classes: number of output classes handed to the classifier.
        pad_idx: padding index used when batching.
        pools: pooling modes to try (classifier ``pool`` argument).
        lrs: Adam learning rates to try.
        batch_sizes: batch sizes to try (used for training and validation).
        emb_dim: embedding dimension handed to the classifier.
        num_epochs: training epochs per combination.

    Returns:
        A list with one (pool, lr, batch_size, train_acc, val_acc) tuple per
        combination, in the order the runs were executed. train_acc is the
        accuracy reported for the last training epoch.
    """
    results = []

    # product() varies its last argument fastest, which reproduces the nesting
    # "for batch_size: for lr: for pool".
    for batch_size, lr, pool in itertools.product(batch_sizes, lrs, pools):
        # fresh model + optimizer for each run so no weights or optimizer
        # state leak between configurations
        model = BaselineClassifier(vocab_size=vocab_size, emb_dim=emb_dim, num_classes=num_classes, pool=pool)
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)

        _, train_acc = train_epochs(model, train_data, batch_size, pad_idx, optimizer, num_epochs=num_epochs)
        _, val_acc = evaluate(model, val_data, batch_size, pad_idx)

        print(f'pool={pool}, lr={lr}, batch={batch_size} | train_acc={train_acc:.3f}, val_acc={val_acc:.3f}')
        results.append((pool, lr, batch_size, train_acc, val_acc))

    return results

In [None]:
# Run the hyperparameter grid on the IMDb data; each entry of results1 is a
# (pool, lr, batch_size, train_acc, val_acc) tuple.
results1 = grid_search(train_data1, val_data1, vocab_size=len(i2w_1), num_classes=numcls_1, pad_idx=pad_idx1)



Epoch 1/5
Training loss: 0.6114  |  accuracy: 0.7201

Epoch 2/5
Training loss: 0.4179  |  accuracy: 0.8486

Epoch 3/5
Training loss: 0.3012  |  accuracy: 0.8946

Epoch 4/5
Training loss: 0.2366  |  accuracy: 0.9212

Epoch 5/5
Training loss: 0.1922  |  accuracy: 0.9390
pool=mean, lr=0.001, batch=64 | train_acc=0.939, val_acc=0.885

Epoch 1/5
Training loss: 0.6754  |  accuracy: 0.5863

Epoch 2/5
Training loss: 0.5780  |  accuracy: 0.7216

Epoch 3/5
Training loss: 0.4780  |  accuracy: 0.8062

Epoch 4/5
Training loss: 0.3926  |  accuracy: 0.8539

Epoch 5/5
Training loss: 0.3298  |  accuracy: 0.8795
pool=max, lr=0.001, batch=64 | train_acc=0.879, val_acc=0.827

Epoch 1/5
Training loss: 0.7072  |  accuracy: 0.5228

Epoch 2/5
Training loss: 0.6827  |  accuracy: 0.5545

Epoch 3/5
Training loss: 0.6682  |  accuracy: 0.5706

Epoch 4/5
Training loss: 0.6538  |  accuracy: 0.5842

Epoch 5/5
Training loss: 0.6406  |  accuracy: 0.5917
pool=first, lr=0.001, batch=64 | train_acc=0.592, val_acc=0.548



In [None]:
# Run the same grid on the dataset loaded with load_imdb_synth, using that
# vocabulary's own '.pad' index.
results2 = grid_search(train_data2, val_data2, vocab_size=len(i2w_2), num_classes=numcls_2, pad_idx=w2i_2['.pad'])


In [None]:
# Run the same grid on the dataset loaded with load_xor, using that
# vocabulary's own '.pad' index.
results3 = grid_search(train_data3, val_data3, vocab_size=len(i2w_3), num_classes=numcls_3, pad_idx=w2i_3['.pad'])