In [None]:
import sys
import os
import torch as T
import numpy as np
import random

# Make the directory two levels above the current working directory importable.
# NOTE(review): presumably this is the project root that the later `datasets` / `src`
# imports rely on; it depends on the kernel's working directory — confirm.
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "../..")))

In [None]:
def set_global_seed(seed: int):
    """Seed every random number generator this notebook relies on.

    Seeds PyTorch, Python's ``random`` module and NumPy's legacy global
    generator with the same value; when CUDA is available, all CUDA devices
    are seeded as well.

    Args:
        seed: Value handed to each generator.
    """
    for seed_fn in (T.manual_seed, random.seed, np.random.seed):
        seed_fn(seed)

    cuda_present = T.cuda.is_available()
    if cuda_present:
        T.cuda.manual_seed_all(seed)

In [None]:
SEED = 42

# Run on the GPU when PyTorch can see one, otherwise stay on the CPU.
device = T.device('cuda') if T.cuda.is_available() else T.device('cpu')
print(f'Using device: {device}')

# Seed all generators before any dataset or model is created.
set_global_seed(SEED)

Using device: cuda


## Dataset Loading

In [3]:
from torch.utils.data import DataLoader
from datasets.custom_dataset import CustomWCSTDataset

### 1. Dataset Hyperparameters

In [4]:
# Passed both as the DataLoader batch size and as the datasets' `sample_batch_size`.
BATCH_SIZE = 64
# Total batch budget that is divided between the train / validation / test datasets.
TOTAL_BATCHES = 500
# Fraction of TOTAL_BATCHES assigned to the training dataset.
TRAIN_TEST_SPLIT_RATIO = 0.6
# Fraction of the remaining batches assigned to validation; what is left goes to test.
VALIDATION_TEST_SPLIT_RATIO = 0.5

### 2. Loading Dataset

In [None]:
# Split the batch budget: training takes TRAIN_TEST_SPLIT_RATIO of it and the
# remainder is divided between validation and test by VALIDATION_TEST_SPLIT_RATIO.
train_size = int(TOTAL_BATCHES * TRAIN_TEST_SPLIT_RATIO)
validation_size = int((TOTAL_BATCHES - train_size) * VALIDATION_TEST_SPLIT_RATIO)
test_size = TOTAL_BATCHES - train_size - validation_size

train_dataset = CustomWCSTDataset(total_batches=train_size, sample_batch_size=BATCH_SIZE)

# Only the training loader is shuffled.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

# One fixed-context validation set per rule. The validation budget is split three
# ways with integer division, so up to two batches of it are left unused.
validation_contexts = {"color": 0, "shape": 1, "quantity": 2}
validation_datasets = {
    ctx_name: CustomWCSTDataset(
        total_batches=validation_size // 3,
        fixed_context=ctx_id,
        sample_batch_size=BATCH_SIZE,
        allow_switch=False,
    )
    for ctx_name, ctx_id in validation_contexts.items()
}

# Evaluation loaders are not shuffled: sample order does not help evaluation, and
# skipping the shuffle keeps the evaluation batches identical from run to run.
validation_loaders = {
    ctx: DataLoader(ds, batch_size=BATCH_SIZE, shuffle=False)
    for ctx, ds in validation_datasets.items()
}

test_dataset = CustomWCSTDataset(total_batches=test_size, sample_batch_size=BATCH_SIZE)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

[Dataset Init] Fixed context: 0
[Dataset Init] Fixed context: 1
[Dataset Init] Fixed context: 2
[Dataset Init] Context switched -> 2
[Dataset Init] Context switched -> 1
[Dataset Init] Context switched -> 2
[Dataset Init] Context switched -> 0
[Dataset Init] Context switched -> 2
[Dataset Init] Context switched -> 1
[Dataset Init] Context switched -> 0
[Dataset Init] Context switched -> 1
[Dataset Init] Context switched -> 0
[Dataset Init] Context switched -> 2
[Dataset Init] Context switched -> 1
[Dataset Init] Context switched -> 2
[Dataset Init] Context switched -> 0
[Dataset Init] Context switched -> 1
[Dataset Init] Context switched -> 0
[Dataset Init] Context switched -> 1
[Dataset Init] Context switched -> 0
[Dataset Init] Context switched -> 1


## Transformer Model Creation

In [None]:
from src.transformer import BaselineTransformer

### 1. Transformer Hyperparameters

In [None]:
# Values handed positionally to BaselineTransformer in the initialisation cell.
# The rationale comments are the author's notes relative to an earlier configuration.
VOCABULARY_SIZE = 70        # 64 cards + 4 categories + SEP + EOS
EMBEDDING_SIZE = 256        # larger embedding to capture card features
N_ATTENTION_HEADS = 8       # more heads for better multi-feature attention
N_BLOCKS = 6                # depth unchanged from the previous configuration
MAX_SEQUENCE_LENGTH = 10    # longer max sequence to accommodate multiple past trials
FF_DIMS = 1024              # larger feedforward layer for better representation
DROPOUT_PROB = 0.2          # dropout reduced slightly to retain signal in small batches

### 2. Transformer Initialisation

In [None]:
# Instantiate the baseline model from the hyperparameters above.
# All arguments except `device` are passed positionally.
transformer = BaselineTransformer(
    VOCABULARY_SIZE,
    EMBEDDING_SIZE,
    N_ATTENTION_HEADS,
    N_BLOCKS,
    MAX_SEQUENCE_LENGTH,
    FF_DIMS,
    DROPOUT_PROB,
    device=device,
)

## Training Transformer

In [None]:
import numpy as np
import torch as T
from collections import defaultdict
from torch.utils.data import DataLoader
from torch import nn, optim

### 1. Train, Validate, Evaluate Model Functions

In [None]:
def validate_per_context(
        model: nn.Module, criterion: nn.CrossEntropyLoss, loaders: dict[str, DataLoader], device: T.device | str ="cpu"
    ):
    model.eval()
    results = {}

    with T.no_grad():
        for ctx_name, loader in loaders.items():
            total_loss, total_correct, total_samples = 0.0, 0, 0
            for input_seq, target in loader:
                input_seq, target = input_seq.to(device), target.view(-1).to(device)

                logits = model(input_seq)
                logits = logits[:, -1, :]
                
                loss = criterion(logits, target)

                preds = logits.argmax(dim=1)
                total_loss += loss.item() * target.size(0)
                total_correct += (preds == target).sum().item()
                total_samples += target.size(0)

            avg_loss = total_loss / total_samples
            avg_acc = total_correct / total_samples

            results[ctx_name] = {
                "loss": avg_loss,
                "acc": avg_acc,
                "perplexity": np.exp(avg_loss)
            }
    return results

In [None]:
def train_model(
    train_loader: DataLoader, validation_loaders: dict[str, DataLoader], model: nn.Module, criterion: nn.CrossEntropyLoss, 
    optimizer: optim.Optimizer, max_epochs: int = 20, device: str | T.device = "cpu",
    scheduler: "optim.lr_scheduler.LRScheduler | None" = None
):
    """Train ``model`` for ``max_epochs`` epochs, validating per context after each one.

    Only the logits at the final sequence position are scored against the target.

    Args:
        train_loader: Iterable of ``(input_seq, target)`` training batches.
        validation_loaders: Mapping from context name to validation loader, forwarded
            to ``validate_per_context``.
        model: Network whose output is indexed as ``[batch, position, vocab]``.
        criterion: Loss applied to the final-position logits (assumed mean-reduced).
        optimizer: Optimizer already bound to ``model``'s parameters.
        max_epochs: Number of passes over ``train_loader``.
        device: Device the batches are moved to.
        scheduler: Optional learning-rate scheduler. When given, it is stepped once
            after every optimizer step (i.e. per batch). When ``None`` the learning
            rate is left to the optimizer alone.

    Returns:
        A ``defaultdict(list)`` with one entry per epoch under the keys
        ``train_losses``, ``train_accs``, ``train_perplexities``, ``val_context``
        (the raw per-context results), ``val_losses``, ``val_accs`` and
        ``val_perplexities`` (unweighted means over the contexts).
    """
    history = defaultdict(list)

    for epoch in range(max_epochs):
        print(f"\nEpoch {epoch + 1}/{max_epochs}")
        print("-" * 40)

        # --- Training ---
        model.train()
        running_loss = 0.0
        total_correct = 0
        total_samples = 0

        for batch_idx, (input_seq, target) in enumerate(train_loader):
            input_seq, target = input_seq.to(device), target.view(-1).to(device)

            # Forward pass
            logits = model(input_seq)  # [batch, seq_len, vocab]
            logits = logits[:, -1, :]  # only the final step prediction [batch, vocab]

            loss = criterion(logits, target)

            # Backprop
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if scheduler is not None:
                # The scheduler is stepped after the optimizer.
                scheduler.step()

            # Running statistics. The loss is weighted by batch size so a smaller
            # final batch does not skew the epoch average (matches validation).
            batch_loss = loss.item()
            batch_size = target.size(0)
            preds = logits.argmax(dim=1)
            running_loss += batch_loss * batch_size
            total_correct += (preds == target).sum().item()
            total_samples += batch_size

            if batch_idx % 100 == 0 or batch_idx == len(train_loader) - 1:
                print(f"Train Batch {batch_idx+1}/{len(train_loader)} | Loss: {batch_loss:.4f}")

        if total_samples == 0:
            # Empty loader: report "no data" instead of dividing by zero.
            train_loss = train_acc = float("nan")
        else:
            train_loss = running_loss / total_samples
            train_acc = total_correct / total_samples
        train_perplexity = np.exp(train_loss)

        history["train_losses"].append(train_loss)
        history["train_accs"].append(train_acc)
        history["train_perplexities"].append(train_perplexity)

        print(f"[Epoch {epoch+1}] Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f} | Train Perplexity: {train_perplexity:.4f}")

        # --- Validation ---
        context_results = validate_per_context(model, criterion, validation_loaders, device)
        ctx_losses = []
        ctx_accs = []
        ctx_perplexities = []
        for ctx, metrics in context_results.items():
            print(f"[Validation] {ctx}: Loss={metrics['loss']:.4f} | Acc={metrics['acc']:.4f} | Perplexity={metrics['perplexity']:.4f}")
            ctx_losses.append(metrics['loss'])
            ctx_accs.append(metrics['acc'])
            ctx_perplexities.append(metrics['perplexity'])

        # Unweighted averages across the validation contexts.
        val_loss = np.mean(ctx_losses)
        val_acc = np.mean(ctx_accs)
        val_perplexity = np.mean(ctx_perplexities)

        print(f"[Validation Summary] Val Loss: {val_loss:.4f} | "
              f"Val Acc: {val_acc:.4f} | Val Perplexity: {val_perplexity:.4f}", end="\n\n")
        print("-" * 40)

        history["val_context"].append(context_results)
        history["val_losses"].append(val_loss)
        history["val_accs"].append(val_acc)
        history["val_perplexities"].append(val_perplexity)

    print("\nTraining complete")

    return history


In [None]:
def test_model(test_loader: DataLoader, model: nn.Module, criterion: nn.CrossEntropyLoss, device: str | T.device = "cpu"):

    model.eval()
    test_batch_losses = []
    test_correct = 0
    test_tokens = 0

    with T.no_grad():
        for input_seq, target in test_loader:
            input_seq, target = input_seq.to(device), target.view(-1).to(device)

            # Forward pass
            logits = model(input_seq)  # [batch, seq_len, vocab]
            logits = logits[:, -1, :]
            loss = criterion(logits, target)

            preds = logits.argmax(dim=1)
            test_correct += (preds == target).sum().item()
            test_tokens += target.size(0)

            test_batch_losses.append(loss.item())

    test_loss = np.mean(test_batch_losses)
    test_acc = test_correct / test_tokens
    test_perplexity = np.exp(test_loss)

    return {
        "test_loss": test_loss,
        "test_acc": test_acc,
        "test_perplexity": test_perplexity
    }

### 2. Train Transformer Model

In [29]:
# Optimisation settings consumed by the training cell that follows.
# `lr` for AdamW.
LEARNING_RATE = 3e-4
# NOTE(review): re-declares the value already set in the dataset hyperparameters cell;
# the data loaders were built from that earlier definition, so changing it here has no effect on them.
BATCH_SIZE = 64
# `weight_decay` for AdamW.
WEIGHT_DECAY = 1e-2
# NOTE(review): despite the name, this is passed as `T_max` to CosineAnnealingLR, not used for a warm-up.
WARMUP_STEPS = 400
# `label_smoothing` for CrossEntropyLoss.
LABEL_SMOOTHING = 0.1
# `max_epochs` for train_model.
MAX_EPOCHS = 100

In [None]:
# Label-smoothed cross-entropy. The same criterion object is passed on to validation
# and testing, so the reported losses and perplexities include the smoothing term.
criterion =  nn.CrossEntropyLoss(label_smoothing=LABEL_SMOOTHING)
optimizer = optim.AdamW(transformer.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
# NOTE(review): this scheduler is created but is not passed to train_model and is never
# stepped in the visible code, so it has no effect on the learning rate as written.
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=WARMUP_STEPS)

# Run the full training loop; `results` holds the per-epoch history returned by train_model.
results = train_model(
    train_loader, validation_loaders, transformer, criterion, 
    optimizer, max_epochs=MAX_EPOCHS, device=device
)


[Epoch 1/100]
------------------------------------------------------------
[color] Batch 1/450 | Loss: 163.1925
[quantity] Batch 51/450 | Loss: 4.9344
[shape] Batch 101/450 | Loss: 4.5522
[color] Batch 151/450 | Loss: 3.7767
[quantity] Batch 201/450 | Loss: 3.1851
[shape] Batch 251/450 | Loss: 3.3180
[color] Batch 301/450 | Loss: 3.0815
[quantity] Batch 351/450 | Loss: 3.0033
[shape] Batch 401/450 | Loss: 3.1230
[quantity] Batch 450/450 | Loss: 3.0852
[Epoch 1] Train Loss: 4.6574 | Acc: 0.2407 | Perplexity: 105.3670
[Epoch 1] Val Loss: 2.2154 | Acc: 0.2458 | Perplexity: 9.1649
------------------------------------------------------------

[Epoch 2/100]
------------------------------------------------------------
[color] Batch 1/450 | Loss: 2.8136
[quantity] Batch 51/450 | Loss: 2.8871
[shape] Batch 101/450 | Loss: 2.4255
[color] Batch 151/450 | Loss: 2.9158
[quantity] Batch 201/450 | Loss: 2.7198
[shape] Batch 251/450 | Loss: 2.9274
[color] Batch 301/450 | Loss: 2.6777
[quantity] Batch

### 3. Test Transformer Model

In [31]:
# Evaluate the trained model on the held-out test loader and report the aggregate metrics.
test_results = test_model(test_loader, transformer, criterion, device)
# Inner quotes differ from the f-string delimiter so this line also parses on Python < 3.12.
print(f"Test Loss: {test_results['test_loss']:.4f} | Acc: {test_results['test_acc']:.4f} | Perplexity: {test_results['test_perplexity']:.4f}")

Test Loss: 1.1241 | Acc: 0.8445 | Perplexity: 3.0775


In [None]:
# NOTE(review): the three numbers in the file name are hard-coded literals, not values read
# from `results` / `test_results`, so they do not track the metrics of the current run —
# confirm what they are meant to denote before relying on the file name.
transformer.save(f"../models/baseline_transformer-{0.74}-{0.27}-{0.27}.pt")

### 4. Save Training & Test Results 

In [32]:
import json

In [None]:
# Combine the per-epoch training history with the final test metrics;
# on a key clash the test metrics take precedence.
final_results = dict(results)
final_results.update(test_results)

# Persist everything as pretty-printed JSON.
with open(f"./../results/baseline_transformer_performace.json", "w") as file:
    json.dump(final_results, file, indent=4)

## Model Inference

In [34]:
from datasets.wcst import WCST
# Helper object used further down only for `visualise_batch`.
# NOTE(review): the argument is presumably a batch size matching the 10 samples
# visualised in the inference cell — confirm against WCST's constructor.
wcst = WCST(10)

In [35]:
def model_inference(model: nn.Module, source_sequence):
    """Greedily predict one next token for every sequence in a batch.

    The model is switched to eval mode (and left there) and the forward pass
    runs without gradient tracking.

    Args:
        model: Network whose output is indexed as ``[batch, position, vocab]``.
        source_sequence: 2-D tensor of token ids, one row per sequence. It is fed
            to the model as-is, so it is assumed to already be on the model's device.

    Returns:
        A new tensor equal to ``source_sequence`` with the arg-max token at the
        final position appended as one extra trailing column. The input tensor is
        not modified.
    """
    model.eval()

    with T.no_grad():
        logits = model(source_sequence)
        # Greedy selection from the final-position logits; keepdim keeps a
        # column shape so it can be concatenated along dim 1.
        next_token = T.argmax(logits[:, -1, :], dim=-1, keepdim=True)

    return T.cat([source_sequence, next_token], dim=1)


In [36]:

# Gather the (sequence, target) pairs produced by slicing the first ten test entries.
input_seqs, targets = [], []
for input_seq, target in test_dataset[:10]:
    input_seqs.append(input_seq)
    targets.append(target)

# Batch them up; targets gain a trailing axis so they can sit next to token columns.
input_seqs = T.stack(input_seqs).to(device)
targets = T.stack(targets).unsqueeze(1).to(device)

predictions = model_inference(transformer, input_seqs)

# Ground truth: everything but the last two input tokens, then those two tokens plus the true target.
print("# Actual Trials")
actual_head = input_seqs[:, :-2]
actual_tail = T.concatenate([input_seqs[:, -2:], targets], dim=1)
test_batch = [np.asarray(actual_head.cpu()), np.asarray(actual_tail.cpu())]
output = wcst.visualise_batch(test_batch)

# Model output, split into the two parts visualise_batch is given.
print("# Predicted Trials")
predicted_head = predictions[:, :-2]
predicted_tail = predictions[:, -2:]
prediction_batch = [np.asarray(predicted_head.cpu()), np.asarray(predicted_tail.cpu())]
output = wcst.visualise_batch(prediction_batch)

# Actual Trials
[array(['green', 'star', '1'], dtype='<U6'), array(['red', 'square', '1'], dtype='<U6'), array(['yellow', 'cross', '2'], dtype='<U6'), array(['red', 'circle', '3'], dtype='<U6'), array(['blue', 'circle', '3'], dtype='<U6'), 'SEP', 'C4', 'EOS', array(['green', 'cross', '4'], dtype='<U6'), 'SEP', 'C3']
[array(['green', 'cross', '3'], dtype='<U6'), array(['blue', 'square', '1'], dtype='<U6'), array(['red', 'star', '2'], dtype='<U6'), array(['red', 'circle', '2'], dtype='<U6'), array(['yellow', 'square', '4'], dtype='<U6'), 'SEP', 'C2', 'EOS', array(['blue', 'circle', '2'], dtype='<U6'), 'SEP', 'C4']
[array(['blue', 'circle', '4'], dtype='<U6'), array(['red', 'star', '2'], dtype='<U6'), array(['green', 'square', '3'], dtype='<U6'), array(['red', 'cross', '4'], dtype='<U6'), array(['yellow', 'star', '4'], dtype='<U6'), 'SEP', 'C2', 'EOS', array(['blue', 'circle', '3'], dtype='<U6'), 'SEP', 'C1']
[array(['yellow', 'star', '4'], dtype='<U6'), array(['red', 'cross', '1'], dtype