In [62]:
import sys
import os
import torch as T
import numpy as np

sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "..")))

from datasets.wcst import WCST

In [63]:
# Task helper, used later to render token batches via `visualise_batch`.
# NOTE(review): the meaning of the argument 10 is not visible here — confirm in datasets/wcst.py.
wcst = WCST(10)

# Prefer the GPU when one is available, otherwise fall back to the CPU.
device = T.device("cuda") if T.cuda.is_available() else T.device("cpu")
print(f'Using device: {device}')

Using device: cuda


## Dataset Loading

In [64]:
from torch.utils.data import TensorDataset, DataLoader

### 1. Dataset Hyperparameters

In [65]:
# Mini-batch size shared by the train, validation and test DataLoaders.
BATCH_SIZE = 32

### 2. Loading Dataset

In [66]:
# Each saved split unpacks into an (inputs, targets) pair of tensors.
train_data, train_targets = T.load('../datasets/train_dataset.pt')
validation_data, validation_targets = T.load('../datasets/validation_dataset.pt')
test_data, test_targets = T.load('../datasets/test_dataset.pt')

# Only the training split is shuffled; evaluation splits keep their stored order.
train_dataset_loader = DataLoader(
    TensorDataset(train_data, train_targets),
    batch_size=BATCH_SIZE,
    shuffle=True,
)
validation_dataset_loader = DataLoader(
    TensorDataset(validation_data, validation_targets),
    batch_size=BATCH_SIZE,
)
test_dataset_loader = DataLoader(
    TensorDataset(test_data, test_targets),
    batch_size=BATCH_SIZE,
)

## Transformer Model Creation

In [67]:
from src.transformer import Transformer

### 1. Transformer Hyperparameters

In [86]:
# Vocabulary: judging by the visualised trials, 64 cards (4 colours x 4 shapes
# x 4 counts) + 4 category tokens (C1-C4) + 'SEP' + 'EOS' = 70.
VOCABULARY_SIZE = 4 * 4 * 4 + 4 + 2

# Architecture sizes passed positionally to the Transformer constructor.
# NOTE(review): meanings inferred from the names (embedding width, attention
# heads, number of blocks, feed-forward width, max length) — confirm in src/transformer.py.
EMBEDDING_SIZE = 32
N_ATTENTION_HEADS = 4
N_BLOCKS = 1
FF_DIMS = 64
MAX_SEQUENCE_LENGTH = 10

# Regularisation.
DROPOUT_PROB = 0.1

### 2. Transformer Initialisation

In [87]:
# Build the encoder-decoder model (it is later called as model(encoder_input, decoder_input)).
# Arguments are positional, so their order must match the constructor in src/transformer.py.
# NOTE(review): the two VOCABULARY_SIZE arguments are presumably the source and
# target vocabulary sizes — confirm against the constructor signature.
transformer = Transformer(
    VOCABULARY_SIZE, VOCABULARY_SIZE, EMBEDDING_SIZE, N_ATTENTION_HEADS,
    N_BLOCKS, MAX_SEQUENCE_LENGTH, FF_DIMS, DROPOUT_PROB, device=device
)

## Training Transformer

In [88]:
import torch.nn as nn
import torch.optim as optim

### 1. Train, Validate, Evaluate Model Functions

In [89]:
def train_model(
    train_loader: DataLoader, validation_loader: DataLoader, model: Transformer, criterion: nn.CrossEntropyLoss, 
    optimizer: optim.Optimizer, max_epochs: int = 20, device: str | T.device = "cpu", patience: int = 3
    ):
    
    best_val_loss = np.inf
    patience_counter = 0

    train_losses, train_accs, train_perplexities = [], [], []
    val_losses, val_accs, val_perplexities = [], [], []
    best_model_state = model.state_dict()

    for epoch in range(max_epochs):
        print(f"\nEpoch {epoch + 1}/{max_epochs}")
        print("-" * 40)

        # --- Training ---
        model.train()
        epoch_train_losses = []
        total_correct = 0
        total_tokens = 0

        for batch_idx, (X, target) in enumerate(train_loader):
            encoder_input, target = X.to(device), target.to(device)

            # Decoder inputs/targets (shifted)
            decoder_input = target[:, :-1]
            decoder_target = target[:, 1:].reshape(-1)

            # Forward pass
            logits = model(encoder_input, decoder_input).reshape(-1, VOCABULARY_SIZE)
            loss = criterion(logits, decoder_target)

            # Backprop
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Accuracy
            preds = logits.argmax(dim=1)
            total_correct += (preds == decoder_target).sum().item()
            total_tokens += decoder_target.size(0)

            epoch_train_losses.append(loss.item())

            if batch_idx % 10 == 0 or batch_idx == len(train_loader) - 1:
                print(f"Train Batch {batch_idx+1}/{len(train_loader)} | Loss: {loss.item():.4f}")

        train_loss = np.mean(epoch_train_losses)
        train_acc = total_correct / total_tokens
        train_perplexity = np.exp(train_loss)

        train_losses.append(train_loss)
        train_accs.append(train_acc)
        train_perplexities.append(train_perplexities)

        print(f"[Epoch {epoch+1}] Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f} | Train Perplexity: {train_perplexity:.4f}")

        # --- Validation ---
        model.eval()
        val_batch_losses = []
        test_correct = 0
        test_tokens = 0

        with T.no_grad():
            for X, target in validation_loader:
                encoder_input, target = X.to(device), target.to(device)
                decoder_input = target[:, :-1]
                decoder_target = target[:, 1:].reshape(-1)

                logits = model(encoder_input, decoder_input).reshape(-1, VOCABULARY_SIZE)
                loss = criterion(logits, decoder_target)

                preds = logits.argmax(dim=1)
                test_correct += (preds == decoder_target).sum().item()
                test_tokens += decoder_target.size(0)

                val_batch_losses.append(loss.item())

        val_loss = np.mean(val_batch_losses)
        val_acc = test_correct / test_tokens
        val_perplexity = np.exp(val_loss)
        val_losses.append(val_loss)
        val_accs.append(val_acc)
        val_perplexities.append(val_perplexities)

        print(f"[Epoch {epoch+1}] Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f} | Val Perplexity: {val_perplexity:.4f}")
        print("-" * 40)

        # --- Early Stopping ---
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
            best_model_state = model.state_dict()
            print(f"Validation loss improved — saving model (Loss: {val_loss:.4f})")
        # else:
        #     patience_counter += 1
        #     print(f"No improvement ({patience_counter}/{patience})")
        #     if patience_counter >= patience:
        #         print("\nEarly stopping triggered. Restoring best model.")
        #         model.load_state_dict(best_model_state)
        #         break

    print("\nTraining complete")

    return {
        "train_losses": train_losses,
        "train_accs": train_accs,
        "train_perplexities": train_perplexities,
        "val_losses": val_losses,
        "val_accs": val_accs,
        "val_perplexities": val_perplexities,
        "best_val_loss": best_val_loss,
    }


In [94]:
def test_model(test_loader: DataLoader, model: Transformer, criterion: nn.CrossEntropyLoss, device: str | T.device = "cpu"):

    model.eval()
    test_batch_losses = []
    test_correct = 0
    test_tokens = 0

    with T.no_grad():
        for X, target in test_loader:
            encoder_input, target = X.to(device), target.to(device)
            decoder_input = target[:, :-1]
            decoder_target = target[:, 1:].reshape(-1)

            logits = model(encoder_input, decoder_input).reshape(-1, VOCABULARY_SIZE)
            loss = criterion(logits, decoder_target)

            preds = logits.argmax(dim=1)
            test_correct += (preds == decoder_target).sum().item()
            test_tokens += decoder_target.size(0)

            test_batch_losses.append(loss.item())

    test_loss = np.mean(test_batch_losses)
    test_acc = test_correct / test_tokens
    test_perplexity = np.exp(test_loss)

    return {
        "test_loss": test_loss,
        "test_acc": test_acc,
        "test_perplexity": test_perplexity
    }

### 2. Train Transformer Model

In [95]:
# Training hyperparameters: epoch budget for train_model, plus the Adam
# optimizer settings (learning rate, betas, eps) used in the next cell.
MAX_EPOCHS = 40
LEARNING_RATE = 1e-3
BETAS = (0.9, 0.98)  # Adam (beta1, beta2)
EPSILON = 1e-9  # Adam eps (numerical-stability term)

In [96]:
# Token-level cross-entropy, optimised with Adam using the hyperparameters above.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(
    transformer.parameters(),
    lr=LEARNING_RATE,
    betas=BETAS,
    eps=EPSILON,
)

results = train_model(
    train_loader=train_dataset_loader,
    validation_loader=validation_dataset_loader,
    model=transformer,
    criterion=criterion,
    optimizer=optimizer,
    max_epochs=MAX_EPOCHS,
    device=device,
)


Epoch 1/40
----------------------------------------
Train Batch 1/60 | Loss: 0.5473
Train Batch 11/60 | Loss: 0.5271
Train Batch 21/60 | Loss: 0.5323
Train Batch 31/60 | Loss: 0.5177
Train Batch 41/60 | Loss: 0.6228
Train Batch 51/60 | Loss: 0.5423
Train Batch 60/60 | Loss: 0.4891
[Epoch 1] Train Loss: 0.5446 | Train Acc: 0.7654 | Train Perplexity: 1.7240
[Epoch 1] Val Loss: 0.8443 | Val Acc: 0.6367 | Val Perplexity: 2.3262
----------------------------------------
Validation loss improved — saving model (Loss: 0.8443)

Epoch 2/40
----------------------------------------
Train Batch 1/60 | Loss: 0.5009
Train Batch 11/60 | Loss: 0.5317
Train Batch 21/60 | Loss: 0.5372
Train Batch 31/60 | Loss: 0.5341
Train Batch 41/60 | Loss: 0.4763
Train Batch 51/60 | Loss: 0.4659
Train Batch 60/60 | Loss: 0.5578
[Epoch 2] Train Loss: 0.5359 | Train Acc: 0.7701 | Train Perplexity: 1.7090
[Epoch 2] Val Loss: 0.8643 | Val Acc: 0.6344 | Val Perplexity: 2.3732
----------------------------------------

Epoc

### 3. Test Transformer Model

In [99]:
# Evaluate on the held-out test split and report the metrics as *test* metrics.
# NOTE(review): this rebinds `results`, replacing the training history returned
# by train_model; use a separate name if the training curves are needed later.
results = test_model(test_dataset_loader, transformer, criterion, device)
# Single quotes inside the f-string keep this valid before Python 3.12 (PEP 701).
print(
    f"Test Loss: {results['test_loss']:.4f} | "
    f"Test Acc: {results['test_acc']:.4f} | "
    f"Test Perplexity: {results['test_perplexity']:.4f}"
)

Train Loss: 1.0766 | Train Acc: 0.6258 | Train Perplexity: 2.9346


## Model Inference

In [100]:
def model_inference(model: "Transformer", source_sequence, start_tokens, max_new_tokens: int = 1):
    """Greedily extend ``start_tokens`` with the model's most likely next token(s).

    Args:
        model: called as ``model(source_sequence, generated)``; its output is
            indexed as ``logits[:, -1, :]``, i.e. (batch, steps, vocab).
        source_sequence: encoder input, passed unchanged on every step.
        start_tokens: decoder prefix of shape (batch, steps); new tokens are
            appended along dim 1.
        max_new_tokens: number of greedy decoding steps. The default of 1
            predicts exactly one token, as before. Decoding does not stop early
            on any special token.

    Returns:
        ``start_tokens`` with ``max_new_tokens`` greedily chosen tokens appended.

    Raises:
        ValueError: if ``max_new_tokens`` is negative.
    """
    if max_new_tokens < 0:
        raise ValueError(f"max_new_tokens must be >= 0, got {max_new_tokens}")

    model.eval()
    generated = start_tokens

    with T.no_grad():
        for _ in range(max_new_tokens):
            logits = model(source_sequence, generated)
            # Greedy selection: most likely token at the last decoder position.
            next_token = T.argmax(logits[:, -1, :], dim=-1, keepdim=True)
            generated = T.cat([generated, next_token], dim=1)

    return generated


In [101]:
# Take the first five test trials. The decoder sees every target token except
# the last, and model_inference appends its greedy guess for that final token.
x = test_data[:5].to(device)
target = test_targets[:5].to(device)
prediction = model_inference(transformer, x, target[:, :-1])

print("# Actual Trials")
test_batch = [np.asarray(tensor.cpu()) for tensor in (x, target)]
output = wcst.visualise_batch(test_batch)

print("# Predicted Trials")
prediction_batch = [np.asarray(tensor.cpu()) for tensor in (x, prediction)]
output = wcst.visualise_batch(prediction_batch)

# Actual Trials
[array(['yellow', 'cross', '2'], dtype='<U6'), array(['yellow', 'square', '2'], dtype='<U6'), array(['green', 'star', '1'], dtype='<U6'), array(['yellow', 'circle', '4'], dtype='<U6'), array(['yellow', 'star', '3'], dtype='<U6'), 'SEP', 'C3', 'EOS', array(['yellow', 'circle', '1'], dtype='<U6'), 'SEP', 'C4']
[array(['yellow', 'square', '2'], dtype='<U6'), array(['red', 'cross', '1'], dtype='<U6'), array(['blue', 'circle', '4'], dtype='<U6'), array(['blue', 'star', '4'], dtype='<U6'), array(['red', 'star', '3'], dtype='<U6'), 'SEP', 'C4', 'EOS', array(['blue', 'star', '2'], dtype='<U6'), 'SEP', 'C4']
[array(['red', 'circle', '1'], dtype='<U6'), array(['red', 'square', '4'], dtype='<U6'), array(['green', 'star', '3'], dtype='<U6'), array(['yellow', 'cross', '1'], dtype='<U6'), array(['blue', 'circle', '2'], dtype='<U6'), 'SEP', 'C1', 'EOS', array(['yellow', 'square', '1'], dtype='<U6'), 'SEP', 'C2']
[array(['green', 'circle', '4'], dtype='<U6'), array(['green', 'cross', '