In [1]:
import sys
import os
import torch as T
import numpy as np

sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "..")))

In [2]:
SEED = 42

device = T.device('cuda' if T.cuda.is_available() else 'cpu')
print(f'Using device: {device}')


T.manual_seed(SEED)

if T.cuda.is_available():
    T.cuda.manual_seed_all(SEED)

Using device: cuda


## Dataset Loading

In [None]:
from torch.utils.data import DataLoader
from datasets.custom_dataset import CustomWCSTDataset

### 1. Dataset Hyperparameters

In [None]:
BATCH_SIZE = 64
TOTAL_BATCHES = 500
TRAIN_TEST_SPLIT_RATIO = 0.6
VALIDATION_TEST_SPLIT_RATIO = 0.5

### 2. Loading Dataset

In [None]:
train_size = int(TOTAL_BATCHES * TRAIN_TEST_SPLIT_RATIO)
validation_size = int((TOTAL_BATCHES - train_size) * VALIDATION_TEST_SPLIT_RATIO)
test_size = TOTAL_BATCHES - train_size - validation_size

train_datasets = {
    "color": CustomWCSTDataset(total_batches=train_size // 2, fixed_context=0, sample_batch_size=BATCH_SIZE, allow_switch=False),
    "shape": CustomWCSTDataset(total_batches=train_size // 2, fixed_context=1, sample_batch_size=BATCH_SIZE, allow_switch=False),
    "quantity": CustomWCSTDataset(total_batches=train_size // 2, fixed_context=2, sample_batch_size=BATCH_SIZE, allow_switch=False)
}

train_loaders = {
    ctx: DataLoader(ds, batch_size=BATCH_SIZE, shuffle=True)
    for ctx, ds in train_datasets.items()
}

validation_dataset = CustomWCSTDataset(
        total_batches=validation_size, sample_batch_size=BATCH_SIZE
    )
validation_loader = DataLoader(validation_dataset, batch_size=BATCH_SIZE, shuffle=True)

test_dataset = CustomWCSTDataset(
        total_batches=validation_size, sample_batch_size=BATCH_SIZE
    )
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=True)

## Transformer Model Creation

In [7]:
from src.transformer import Transformer

### 1. Transformer Hyperparameters

In [None]:
VOCABULARY_SIZE = 70        # 64 cards + 4 categories + SEP + EOS
EMBEDDING_SIZE = 216        # larger embedding to capture card features
N_ATTENTION_HEADS = 6       # more heads for better multi-feature attention
N_BLOCKS = 3                # same depth as before
MAX_SEQUENCE_LENGTH = 10    # longer max sequence to accommodate multiple past trials
FF_DIMS = 256               # larger feedforward layer for better representation
DROPOUT_PROB = 0.2        # reduce dropout slightly to retain signal in small batches
CARD_DIMS = (4, 4, 4)

### 2. Transformer Initialisation

In [9]:
transformer = Transformer(
    VOCABULARY_SIZE, VOCABULARY_SIZE, CARD_DIMS, EMBEDDING_SIZE, N_ATTENTION_HEADS,
    N_BLOCKS, MAX_SEQUENCE_LENGTH, FF_DIMS, DROPOUT_PROB, device=device
)

## Training Transformer

In [10]:
import itertools
import numpy as np
import torch as T
from torch.utils.data import DataLoader
from torch import nn, optim

### 1. Train, Validate, Evaluate Model Functions

In [11]:
def train_model(
    train_loader: DataLoader,
    validation_loader: DataLoader,
    model: Transformer,
    criterion: nn.CrossEntropyLoss,
    optimizer: optim.Optimizer,
    max_epochs: int = 20,
    device: str | T.device = "cpu",
    patience: int = 3,
):
    best_val_loss = np.inf
    patience_counter = 0

    train_losses, train_accs, train_perplexities = [], [], []
    val_losses, val_accs, val_perplexities = [], [], []
    best_model_state = model.state_dict()

    for epoch in range(max_epochs):
        print(f"\nEpoch {epoch + 1}/{max_epochs}")
        print("-" * 40)

        # --- Training ---
        model.train()
        epoch_train_losses = []
        total_correct = 0
        total_samples = 0

        for batch_idx, (encoder_input, decoder_input, target) in enumerate(train_loader):
            encoder_input, decoder_input, target = encoder_input.to(device), decoder_input.to(device), target.view(-1).to(device)
    
            # Forward pass
            logits = model(encoder_input, decoder_input)  # [batch, seq_len, vocab]
            logits = logits[:, -1, :]  # only the final step prediction [batch, vocab]

            loss = criterion(logits, target)

            # Backprop
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Accuracy
            preds = logits.argmax(dim=1)
            total_correct += (preds == target).sum().item()
            total_samples += target.size(0)

            epoch_train_losses.append(loss.item())

            if batch_idx % 100 == 0 or batch_idx == len(train_loader) - 1:
                print(f"Train Batch {batch_idx+1}/{len(train_loader)} | Loss: {loss.item():.4f}")

        train_loss = np.mean(epoch_train_losses)
        train_acc = total_correct / total_samples
        train_perplexity = np.exp(train_loss)

        train_losses.append(train_loss)
        train_accs.append(train_acc)
        train_perplexities.append(train_perplexity)

        print(f"[Epoch {epoch+1}] Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f} | Train Perplexity: {train_perplexity:.4f}")

        # --- Validation ---
        model.eval()
        val_batch_losses = []
        val_correct = 0
        val_samples = 0

        with T.no_grad():
            for encoder_input, decoder_input, target in validation_loader:
                encoder_input, decoder_input, target = encoder_input.to(device), decoder_input.to(device), target.view(-1).to(device)

                logits = model(encoder_input, decoder_input) # [batch, seq_len, vocab]
                logits = logits[:, -1, :]  # only the final step prediction [batch, vocab]

                loss = criterion(logits, target)

                preds = logits.argmax(dim=1)
                val_correct += (preds == target).sum().item()
                val_samples += target.size(0)

                val_batch_losses.append(loss.item())

        val_loss = np.mean(val_batch_losses)
        val_acc = val_correct / val_samples
        val_perplexity = np.exp(val_loss)
        val_losses.append(val_loss)
        val_accs.append(val_acc)
        val_perplexities.append(val_perplexity)

        print(f"[Epoch {epoch+1}] Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f} | Val Perplexity: {val_perplexity:.4f}")
        print("-" * 40)

        # --- Early Stopping ---
        # if val_loss < best_val_loss:
        #     best_val_loss = val_loss
        #     patience_counter = 0
        #     best_model_state = model.state_dict()
        #     print(f"Validation loss improved — saving model (Loss: {val_loss:.4f})")
        # else:
        #     patience_counter += 1
        #     print(f"No improvement ({patience_counter}/{patience})")
        #     if patience_counter >= patience:
        #         print("\nEarly stopping triggered. Restoring best model.")
        #         model.load_state_dict(best_model_state)
        #         break

    print("\nTraining complete")

    return {
        "train_losses": train_losses,
        "train_accs": train_accs,
        "train_perplexities": train_perplexities,
        "val_losses": val_losses,
        "val_accs": val_accs,
        "val_perplexities": val_perplexities,
        "best_val_loss": best_val_loss,
    }


In [12]:
def train_model_round_robin(
    train_loaders: dict[str, DataLoader], validation_loader: DataLoader, model: nn.Module,
    criterion: nn.CrossEntropyLoss, optimizer: optim.Optimizer, max_epochs: int = 20,
    device: str | T.device = "cpu", patience: int = 3,
):
    best_val_loss = np.inf
    patience_counter = 0
    best_model_state = model.state_dict()

    history = {k: [] for k in [
        "train_losses", "train_accs", "train_perplexities",
        "val_losses", "val_accs", "val_perplexities"
    ]}

    for epoch in range(max_epochs):
        print(f"\n[Epoch {epoch + 1}/{max_epochs}]")
        print("-" * 60)

        model.train()
        epoch_losses, total_correct, total_samples = [], 0, 0

        # --- Build a round-robin iterator across all loaders ---
        loaders_cycle = itertools.cycle(train_loaders.items())
        active_iters = {ctx: iter(dl) for ctx, dl in train_loaders.items()}

        # Find smallest loader length to roughly balance epoch size
        min_len = min(len(dl) for dl in train_loaders.values())
        total_batches = min_len * len(train_loaders)

        for batch_idx in range(total_batches):
            context, _ = next(loaders_cycle)
            loader_iter = active_iters[context]

            try:
                encoder_input, decoder_input, target = next(loader_iter)
            except StopIteration:
                # Restart exhausted iterator
                active_iters[context] = iter(train_loaders[context])
                encoder_input, decoder_input, target = next(active_iters[context])

            encoder_input, decoder_input, target = (
                encoder_input.to(device),
                decoder_input.to(device),
                target.view(-1).to(device)
            )

            # Forward pass
            logits = model(encoder_input, decoder_input)
            logits = logits[:, -1, :]  # predict final token only
            loss = criterion(logits, target)

            # Backprop
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Metrics
            preds = logits.argmax(dim=1)
            total_correct += (preds == target).sum().item()
            total_samples += target.size(0)
            epoch_losses.append(loss.item())

            if batch_idx % 50 == 0 or batch_idx == total_batches - 1:
                print(f"[{context}] Batch {batch_idx+1}/{total_batches} | Loss: {loss.item():.4f}")

        # --- Epoch stats ---
        train_loss = np.mean(epoch_losses)
        train_acc = total_correct / total_samples
        train_perplexity = np.exp(train_loss)

        print(f"[Epoch {epoch+1}] Train Loss: {train_loss:.4f} | Acc: {train_acc:.4f} | Perplexity: {train_perplexity:.4f}")

        # --- Validation ---
        model.eval()
        val_losses, val_correct, val_samples = [], 0, 0

        with T.no_grad():
            for encoder_input, decoder_input, target in validation_loader:
                encoder_input, decoder_input, target = (
                    encoder_input.to(device),
                    decoder_input.to(device),
                    target.view(-1).to(device)
                )
                logits = model(encoder_input, decoder_input)
                logits = logits[:, -1, :]
                loss = criterion(logits, target)
                preds = logits.argmax(dim=1)
                val_correct += (preds == target).sum().item()
                val_samples += target.size(0)
                val_losses.append(loss.item())

        val_loss = np.mean(val_losses)
        val_acc = val_correct / val_samples
        val_perplexity = np.exp(val_loss)

        print(f"[Epoch {epoch+1}] Val Loss: {val_loss:.4f} | Acc: {val_acc:.4f} | Perplexity: {val_perplexity:.4f}")
        print("-" * 60)

        # --- Logging ---
        history["train_losses"].append(train_loss)
        history["train_accs"].append(train_acc)
        history["train_perplexities"].append(train_perplexity)
        history["val_losses"].append(val_loss)
        history["val_accs"].append(val_acc)
        history["val_perplexities"].append(val_perplexity)

        # --- Early stopping ---
        # if val_loss < best_val_loss:
        #     best_val_loss = val_loss
        #     best_model_state = model.state_dict()
        #     patience_counter = 0
        # else:
        #     patience_counter += 1
        #     if patience_counter >= patience:
        #         print("Early stopping — restoring best model.")
        #         model.load_state_dict(best_model_state)
        #         break

    print("\n Training complete (Round-Robin Mode)")
    return history


In [13]:
def test_model(test_loader: DataLoader, model: Transformer, criterion: nn.CrossEntropyLoss, device: str | T.device = "cpu"):

    model.eval()
    test_batch_losses = []
    test_correct = 0
    test_tokens = 0

    with T.no_grad():
        for encoder_input, decoder_input, target in test_loader:
            encoder_input, decoder_input, target = encoder_input.to(device), decoder_input.to(device), target.view(-1).to(device)
            
            logits = model(encoder_input, decoder_input)[:, -1, :]
            loss = criterion(logits, target)

            preds = logits.argmax(dim=1)
            test_correct += (preds == target).sum().item()
            test_tokens += target.size(0)

            test_batch_losses.append(loss.item())

    test_loss = np.mean(test_batch_losses)
    test_acc = test_correct / test_tokens
    test_perplexity = np.exp(test_loss)

    return {
        "test_loss": test_loss,
        "test_acc": test_acc,
        "test_perplexity": test_perplexity
    }

### 2. Train Transformer Model

In [14]:
LEARNING_RATE = 3e-4
BATCH_SIZE = 64
WEIGHT_DECAY = 1e-2
WARMUP_STEPS = 400
LABEL_SMOOTHING = 0.1
MAX_EPOCHS = 100

In [15]:
criterion =  nn.CrossEntropyLoss(label_smoothing=LABEL_SMOOTHING)
optimizer = optim.AdamW(transformer.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=WARMUP_STEPS)  

results = train_model_round_robin(
    train_loaders, validation_loader, transformer, criterion, 
    optimizer, max_epochs=MAX_EPOCHS, device=device
)


[Epoch 1/100]
------------------------------------------------------------
[color] Batch 1/300 | Loss: 4.5288
[quantity] Batch 51/300 | Loss: 1.9377
[shape] Batch 101/300 | Loss: 1.9490
[color] Batch 151/300 | Loss: 2.0080
[quantity] Batch 201/300 | Loss: 1.9344
[shape] Batch 251/300 | Loss: 1.9633
[quantity] Batch 300/300 | Loss: 1.9583
[Epoch 1] Train Loss: 2.0070 | Acc: 0.2510 | Perplexity: 7.4411
[Epoch 1] Val Loss: 1.9669 | Acc: 0.2472 | Perplexity: 7.1483
------------------------------------------------------------

[Epoch 2/100]
------------------------------------------------------------
[color] Batch 1/300 | Loss: 1.9688
[quantity] Batch 51/300 | Loss: 2.0039
[shape] Batch 101/300 | Loss: 1.9685
[color] Batch 151/300 | Loss: 1.9543
[quantity] Batch 201/300 | Loss: 1.9805
[shape] Batch 251/300 | Loss: 1.9988
[quantity] Batch 300/300 | Loss: 1.9976
[Epoch 2] Train Loss: 1.9763 | Acc: 0.2509 | Perplexity: 7.2162
[Epoch 2] Val Loss: 1.9670 | Acc: 0.2564 | Perplexity: 7.1495
-----

### 3. Test Transformer Model

In [16]:
results = test_model(test_loader, transformer, criterion, device)
print(f"Test Loss: {results["test_loss"]:.4f} | Test Acc: {results["test_acc"]:.4f} | Test Perplexity: {results["test_perplexity"]:.4f}")

Test Loss: 2.9588 | Test Acc: 0.3195 | Test Perplexity: 19.2752


## Model Inference

In [17]:
from datasets.wcst import WCST
wcst = WCST(10)

In [18]:
def model_inference(model: Transformer, source_sequence, start_tokens):
    model.eval()
    generated = start_tokens
    
    with T.no_grad():
        logits = model(source_sequence, generated)
    
    # Greedy Selection
    next_token = T.argmax(logits[:, -1, :], dim=-1, keepdim=True)

    generated = T.cat([generated, next_token], dim=1)

    return generated


In [19]:
encoder_input, decoder_input, target = train_datasets["quantity"][:10]
encoder_input = encoder_input.to(device)
decoder_input = decoder_input.to(device)
target = target.to(device)

prediction = model_inference(transformer, encoder_input, decoder_input)

print("# Actual Trials")
test_batch = [np.asarray(item.cpu()) for item in [encoder_input, T.concatenate([decoder_input, target], dim=1)]]
output = wcst.visualise_batch(test_batch)

print("# Predicted Trials")
prediction_batch = [np.asarray(item.cpu()) for item in [encoder_input, T.concatenate([decoder_input, prediction], dim=1)]]
output = wcst.visualise_batch(prediction_batch)

# Actual Trials
[array(['green', 'cross', '2'], dtype='<U6'), array(['green', 'star', '3'], dtype='<U6'), array(['yellow', 'star', '4'], dtype='<U6'), array(['yellow', 'circle', '1'], dtype='<U6'), array(['blue', 'cross', '1'], dtype='<U6'), 'SEP', 'C4', 'EOS', array(['red', 'square', '1'], dtype='<U6'), 'SEP', 'C4']
[array(['blue', 'cross', '4'], dtype='<U6'), array(['blue', 'cross', '3'], dtype='<U6'), array(['green', 'square', '2'], dtype='<U6'), array(['yellow', 'square', '1'], dtype='<U6'), array(['yellow', 'star', '3'], dtype='<U6'), 'SEP', 'C2', 'EOS', array(['blue', 'circle', '4'], dtype='<U6'), 'SEP', 'C1']
[array(['green', 'cross', '4'], dtype='<U6'), array(['blue', 'square', '2'], dtype='<U6'), array(['green', 'star', '1'], dtype='<U6'), array(['red', 'circle', '3'], dtype='<U6'), array(['blue', 'circle', '2'], dtype='<U6'), 'SEP', 'C2', 'EOS', array(['green', 'cross', '1'], dtype='<U6'), 'SEP', 'C3']
[array(['yellow', 'square', '3'], dtype='<U6'), array(['yellow', 'cross', 