In [1]:
import sys
import os
import torch as T
import numpy as np
import random

# Put the directory two levels above the notebook's working directory on the
# import path (presumably the repository root that holds the project packages
# imported further down — confirm against the repo layout).
# Guarded so that re-running this cell does not append duplicate entries.
PROJECT_ROOT = os.path.abspath(os.path.join(os.getcwd(), "../.."))
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

In [2]:
def set_global_seed(seed: int):
    """Seed every random number generator this notebook touches.

    Seeds PyTorch, Python's ``random`` module and NumPy's global generator,
    in that order, and additionally seeds all CUDA devices when CUDA is
    available.

    Args:
        seed: Value handed to each library's seeding function.
    """
    for seed_fn in (T.manual_seed, random.seed, np.random.seed):
        seed_fn(seed)

    cuda_present = T.cuda.is_available()
    if cuda_present:
        T.cuda.manual_seed_all(seed)

In [3]:
# Seed handed to set_global_seed below.
SEED = 42

# Use the GPU when PyTorch reports CUDA as available; otherwise use the CPU.
if T.cuda.is_available():
    device = T.device('cuda')
else:
    device = T.device('cpu')
print(f'Using device: {device}')

# Seeds PyTorch, Python's `random` and NumPy in one call.
set_global_seed(SEED)

Using device: cuda


## Dataset Loading

In [4]:
from torch.utils.data import DataLoader
from datasets.custom_dataset import CustomWCSTDataset

### 1. Dataset Hyperparameters

In [None]:
# Samples per batch; used both as the dataset's sample_batch_size and as the
# DataLoader batch_size in the loading cell.
BATCH_SIZE = 64
# Total number of batches shared between the train / validation / test sets.
TOTAL_BATCHES = 750
# Fraction of TOTAL_BATCHES assigned to training.
TRAIN_TEST_SPLIT_RATIO = 0.6
# Fraction of the non-training batches assigned to validation (rest is test).
VALIDATION_TEST_SPLIT_RATIO = 0.5

# Feature tables with 5, 6 and 7 values per feature. Each larger table is the
# previous one with a single extra value appended to every feature; `+` builds
# new lists, so the three tables never share list objects.
N_5_features = dict(
    colour=['Red', 'Blue', 'Green', 'Yellow', 'Purple'],
    shape=['Circle', 'Square', 'Star', 'Cross', 'Triangle'],
    quantity=['One', 'Two', 'Three', 'Four', 'Five'],
)

N_6_features = {
    'colour': N_5_features['colour'] + ['Orange'],
    'shape': N_5_features['shape'] + ['Hexagon'],
    'quantity': N_5_features['quantity'] + ['Six'],
}

N_7_features = {
    'colour': N_6_features['colour'] + ['Brown'],
    'shape': N_6_features['shape'] + ['Diamond'],
    'quantity': N_6_features['quantity'] + ['Seven'],
}

### 2. Loading Dataset

In [6]:
# Batch budget: the training share is taken first, the remainder is divided
# between validation and test, and test absorbs whatever rounding leaves over.
train_size = int(TOTAL_BATCHES * TRAIN_TEST_SPLIT_RATIO)
validation_size = int((TOTAL_BATCHES - train_size) * VALIDATION_TEST_SPLIT_RATIO)
test_size = TOTAL_BATCHES - (train_size + validation_size)

# Training data: built without fixed_context / allow_switch overrides.
train_dataset = CustomWCSTDataset(
    features=N_7_features,
    sample_batch_size=BATCH_SIZE,
    total_batches=train_size,
)
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=BATCH_SIZE)

# One validation dataset per context name. Each is built with fixed_context set
# to the name's position (0, 1, 2) and allow_switch=False, and receives an
# integer third of the validation budget. Construction order matches the key
# order below.
validation_datasets = {
    context_name: CustomWCSTDataset(
        features=N_7_features,
        sample_batch_size=BATCH_SIZE,
        total_batches=validation_size // 3,
        fixed_context=context_index,
        allow_switch=False,
    )
    for context_index, context_name in enumerate(("color", "shape", "quantity"))
}

# Same keys as validation_datasets, each wrapped in its own shuffling loader.
validation_loaders = {
    context_name: DataLoader(context_dataset, shuffle=True, batch_size=BATCH_SIZE)
    for context_name, context_dataset in validation_datasets.items()
}

# Test data: same construction as the training set, sized by test_size.
test_dataset = CustomWCSTDataset(
    features=N_7_features,
    sample_batch_size=BATCH_SIZE,
    total_batches=test_size,
)
test_loader = DataLoader(test_dataset, shuffle=True, batch_size=BATCH_SIZE)

[Dataset Init] Context switched -> 0
[Dataset Init] Context switched -> 2
[Dataset Init] Context switched -> 1
[Dataset Init] Context switched -> 0
[Dataset Init] Context switched -> 2
[Dataset Init] Context switched -> 1
[Dataset Init] Context switched -> 2
[Dataset Init] Context switched -> 0
[Dataset Init] Context switched -> 1
[Dataset Init] Context switched -> 2
[Dataset Init] Context switched -> 1
[Dataset Init] Context switched -> 2
[Dataset Init] Context switched -> 1
[Dataset Init] Context switched -> 0
[Dataset Init] Context switched -> 2
[Dataset Init] Context switched -> 1
[Dataset Init] Context switched -> 2
[Dataset Init] Context switched -> 1
[Dataset Init] Context switched -> 0
[Dataset Init] Context switched -> 1
[Dataset Init] Context switched -> 0
[Dataset Init] Context switched -> 2
[Dataset Init] Context switched -> 0
[Dataset Init] Context switched -> 2
[Dataset Init] Context switched -> 1
[Dataset Init] Context switched -> 0
[Dataset Init] Context switched -> 2
[

## Transformer Model Creation & Training

In [7]:
import numpy as np
from collections import defaultdict
from torch.utils.data import DataLoader
from torch import nn, optim

from src.transformer import CETransformer

### 1. Train, Validate and Test Model Functions

In [8]:
def validate_per_context(
        model: nn.Module, criterion: nn.CrossEntropyLoss, loaders: dict[str, DataLoader], device: T.device | str ="cpu"
    ):
    model.eval()
    results = {}

    with T.no_grad():
        for ctx_name, loader in loaders.items():
            total_loss, total_correct, total_samples = 0.0, 0, 0
            for input_seq, target in loader:
                input_seq, target = input_seq.to(device), target.view(-1).to(device)

                logits = model(input_seq)
                logits = logits[:, -1, :]
                
                loss = criterion(logits, target)

                preds = logits.argmax(dim=1)
                total_loss += loss.item() * target.size(0)
                total_correct += (preds == target).sum().item()
                total_samples += target.size(0)

            avg_loss = total_loss / total_samples
            avg_acc = total_correct / total_samples

            results[ctx_name] = {
                "loss": avg_loss,
                "acc": avg_acc,
                "perplexity": np.exp(avg_loss)
            }
    return results

In [9]:
def train_model(
    train_loader: DataLoader, validation_loaders: dict[str, DataLoader], model: nn.Module, criterion: nn.CrossEntropyLoss, 
    optimizer: optim.Optimizer, max_epochs: int = 20, device: str | T.device = "cpu"
):
    history = defaultdict(list)

    for epoch in range(max_epochs):
        print(f"\nEpoch {epoch + 1}/{max_epochs}")
        print("-" * 40)

        # --- Training ---
        model.train()
        epoch_train_losses = []
        total_correct = 0
        total_samples = 0

        for batch_idx, (input_seq, target) in enumerate(train_loader):
            input_seq, target = input_seq.to(device), target.view(-1).to(device)

            # Forward pass
            logits = model(input_seq)  # [batch, seq_len, vocab]
            logits = logits[:, -1, :]  # only the final step prediction [batch, vocab]

            loss = criterion(logits, target)

            # Backprop
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Accuracy
            preds = logits.argmax(dim=1)
            total_correct += (preds == target).sum().item()
            total_samples += target.size(0)

            epoch_train_losses.append(loss.item())

            if batch_idx % 100 == 0 or batch_idx == len(train_loader) - 1:
                print(f"Train Batch {batch_idx+1}/{len(train_loader)} | Loss: {loss.item():.4f}")

        train_loss = np.mean(epoch_train_losses)
        train_acc = total_correct / total_samples
        train_perplexity = np.exp(train_loss)

        history["train_losses"].append(train_loss)
        history["train_accs"].append(train_acc)
        history["train_perplexities"].append(train_perplexity)

        print(f"[Epoch {epoch+1}] Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f} | Train Perplexity: {train_perplexity:.4f}")

        # --- Validation ---
        context_results = validate_per_context(model, criterion, validation_loaders, device)
        val_loss = []
        val_acc = []
        val_perplexity = []
        for ctx, metrics in context_results.items():
            print(f"[Validation] {ctx}: Loss={metrics['loss']:.4f} | Acc={metrics['acc']:.4f} | Perplexity={metrics['perplexity']:.4f}")
            val_loss.append(metrics['loss'])
            val_acc.append(metrics['acc'])
            val_perplexity.append(metrics['perplexity'])

        val_loss = np.mean(val_loss)
        val_acc = np.mean(val_acc)
        val_perplexity = np.mean(val_perplexity)

        print(f"[Validation Summary] Val Loss: {np.mean(val_loss):.4f} | "
              f"Val Acc: {np.mean(val_acc):.4f} | Val Perplexity: {np.mean(val_perplexity):.4f}", end="\n\n")
        print("-" * 40)

        history["val_context"].append(context_results)
        history["val_losses"].append(val_loss)
        history["val_accs"].append(val_acc)
        history["val_perplexities"].append(val_perplexity)

    print("\nTraining complete")

    return history


In [10]:
def test_model(test_loader: DataLoader, model: nn.Module, criterion: nn.CrossEntropyLoss, device: str | T.device = "cpu"):

    model.eval()
    test_batch_losses = []
    test_correct = 0
    test_tokens = 0

    with T.no_grad():
        for input_seq, target in test_loader:
            input_seq, target = input_seq.to(device), target.view(-1).to(device)

            # Forward pass
            logits = model(input_seq)  # [batch, seq_len, vocab]
            logits = logits[:, -1, :]
            loss = criterion(logits, target)

            preds = logits.argmax(dim=1)
            test_correct += (preds == target).sum().item()
            test_tokens += target.size(0)

            test_batch_losses.append(loss.item())

    test_loss = np.mean(test_batch_losses)
    test_acc = test_correct / test_tokens
    test_perplexity = np.exp(test_loss)

    return {
        "test_loss": test_loss,
        "test_acc": test_acc,
        "test_perplexity": test_perplexity
    }

### 2. Create & Train Transformer Model

In [13]:
# --- Model architecture ---
# Read from the training dataset so it follows whichever feature table that
# dataset was built with (no hand-maintained token count to go stale).
VOCABULARY_SIZE = train_dataset.vocabulary_size
EMBEDDING_SIZE = 256        # larger embedding to capture card features
N_ATTENTION_HEADS = 8       # more heads for better multi-feature attention
N_BLOCKS = 6
MAX_SEQUENCE_LENGTH = 13    # sized to accommodate multiple past trials
FF_DIMS = 1024              # larger feedforward layer for better representation
DROPOUT_PROB = 0.2
# NOTE(review): every dataset in this notebook is built from N_7_features
# (7 values per feature) while CARD_DIMS is (5, 5, 5). Confirm against
# CETransformer whether this tuple is meant to track the per-feature counts.
CARD_DIMS = (5, 5, 5)

# --- Optimisation ---
# BATCH_SIZE is deliberately not redefined here: the loaders were already built
# with the value from the dataset hyperparameters cell, so a second assignment
# could only drift out of sync with them.
LEARNING_RATE = 3e-4
WEIGHT_DECAY = 1e-2
WARMUP_STEPS = 400          # passed as T_max to CosineAnnealingLR in the training cell
LABEL_SMOOTHING = 0.1
MAX_EPOCHS = 100
SEEDS = [31, 42, 45, 69, 420]  # NOTE(review): not referenced by any visible cell

In [14]:
# Collects one merged {training history + test metrics} dict per run.
results = []

transformer = CETransformer(
    VOCABULARY_SIZE, CARD_DIMS, EMBEDDING_SIZE, N_ATTENTION_HEADS,
    N_BLOCKS, MAX_SEQUENCE_LENGTH, FF_DIMS, DROPOUT_PROB, device=device
)

criterion = nn.CrossEntropyLoss(label_smoothing=LABEL_SMOOTHING)
optimizer = optim.AdamW(transformer.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
# NOTE(review): this scheduler is constructed but never stepped — train_model
# takes no scheduler argument and nothing here calls scheduler.step(), so the
# learning rate stays at LEARNING_RATE for the whole run. It also receives
# WARMUP_STEPS as the cosine period (T_max); confirm that is the intent.
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=WARMUP_STEPS)

result = train_model(
    train_loader, validation_loaders, transformer, criterion,
    optimizer, max_epochs=MAX_EPOCHS, device=device
)

test_results = test_model(test_loader, transformer, criterion, device)
# Single quotes inside the f-strings: reusing the enclosing double quote is
# only accepted from Python 3.12 onwards and is a SyntaxError before that.
print(
    f"[Test] Loss: {test_results['test_loss']:.4f} | "
    f"Acc: {test_results['test_acc']:.4f} | "
    f"Perplexity: {test_results['test_perplexity']:.4f}"
)

results.append({**result, **test_results})


Epoch 1/100
----------------------------------------
Train Batch 1/450 | Loss: 153.6825
Train Batch 101/450 | Loss: 6.2566
Train Batch 201/450 | Loss: 4.8835
Train Batch 301/450 | Loss: 4.8185
Train Batch 401/450 | Loss: 4.4816
Train Batch 450/450 | Loss: 3.9415
[Epoch 1] Train Loss: 6.2888 | Train Acc: 0.1362 | Train Perplexity: 538.5137
[Validation] color: Loss=3.2958 | Acc=0.1397 | Perplexity=26.9978
[Validation] shape: Loss=3.2845 | Acc=0.1441 | Perplexity=26.6968
[Validation] quantity: Loss=3.2751 | Acc=0.1459 | Perplexity=26.4450
[Validation Summary] Val Loss: 3.2851 | Val Acc: 0.1432 | Val Perplexity: 26.7132

----------------------------------------

Epoch 2/100
----------------------------------------
Train Batch 1/450 | Loss: 4.0094
Train Batch 101/450 | Loss: 4.1517
Train Batch 201/450 | Loss: 3.9081
Train Batch 301/450 | Loss: 3.9609
Train Batch 401/450 | Loss: 3.6431
Train Batch 450/450 | Loss: 3.3915
[Epoch 2] Train Loss: 3.7663 | Train Acc: 0.1447 | Train Perplexity: 43

### 3. Save Training & Test Results 

In [15]:
import json

In [16]:
# Write each run's merged training history and test metrics to JSON.
RESULTS_DIR = "./../../results/scaling_laws"
# open(..., "w") does not create missing parent folders, so make sure it exists.
os.makedirs(RESULTS_DIR, exist_ok=True)

for i, result in enumerate(results):
    # A single run keeps the original file name (spelling left untouched so
    # existing consumers still find it). With several runs the index is added
    # so later runs no longer overwrite earlier ones.
    run_suffix = "" if len(results) == 1 else f"_run-{i}"
    with open(f"{RESULTS_DIR}/ce_transformer_performace_N-7{run_suffix}.json", "w") as file:
        json.dump(result, file, indent=4)