In [1]:
import random

import matplotlib.pyplot as plt
import numpy as np
import torch as T

In [2]:
# Seed every RNG the notebook may draw from so "Restart & Run All" repeats the same
# data generation, train/val/test split, loader shuffling and weight initialisation.
# NOTE(review): assumes WCST samples from Python's or NumPy's global RNG — confirm.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
T.manual_seed(SEED)  # seeds the CPU and all CUDA generators

device = T.device('cuda' if T.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

Using device: cuda


## Dataset Creation

In [3]:
from torch.utils.data import TensorDataset, DataLoader, random_split

from wcst import WCST

[[17. 21. 42. 46. 20. 68. 65. 69.]
 [18. 20. 59. 62. 41. 68. 66. 69.]
 [27.  6. 17. 63. 38. 68. 65. 69.]] [[ 4. 68. 65.]
 [ 4. 68. 65.]
 [58. 68. 64.]]


### 1. Dataset Hyperparameters

In [4]:
# Each WCST batch holds BATCH_SIZE trials. BATCHES_PER_CONTEXT batches are drawn
# before every wcst.context_switch() call, and that loop runs N_CONTEXT_SWITCHES
# times — so this is effectively the number of rule contexts sampled (with 1,
# every trial comes from a single context).
BATCH_SIZE: int = 32
BATCHES_PER_CONTEXT: int = 20
N_CONTEXT_SWITCHES: int = 1

### 2. Creating Dataset

In [5]:
# Sample BATCHES_PER_CONTEXT batches per rule context, switching the WCST rule
# after each context block, and flatten every batch into per-trial rows.
wcst = WCST(BATCH_SIZE)
data, targets = [], []

for _context_idx in range(N_CONTEXT_SWITCHES):
    for _batch_idx in range(BATCHES_PER_CONTEXT):
        batch_inputs, batch_targets = next(wcst.gen_batch())
        data.extend(batch_inputs[row] for row in range(BATCH_SIZE))
        targets.extend(batch_targets[row] for row in range(BATCH_SIZE))
    wcst.context_switch()

# Stack the rows into integer token tensors that already live on `device`.
data = T.tensor(np.array(data), dtype=T.long, device=device)
targets = T.tensor(np.array(targets), dtype=T.long, device=device)
dataset = TensorDataset(data, targets)

# 60 / 20 / 20 split into train / validation / test subsets.
train_dataset, validation_dataset, test_dataset = random_split(dataset, [0.6, 0.2, 0.2])

train_dataset_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
validation_dataset_loader = DataLoader(validation_dataset, batch_size=BATCH_SIZE)
test_dataset_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE)

## Transformer Model Creation

In [6]:
from toy_transformer import Transformer

### 1. Transformer Hyperparameters

In [7]:
# One shared vocabulary: every card token, every category token, plus the two
# special tokens 'SEP' and 'EOS'.
VOCABULARY_SIZE = len(wcst.card_indices) + len(wcst.categories) + 2

EMBEDDING_SIZE: int = 32        # model / embedding width
N_ATTENTION_HEADS: int = 4      # attention heads per block
N_BLOCKS: int = 3               # stacked blocks (how they are split between encoder/decoder: see toy_transformer — TODO confirm)
MAX_SEQUENCE_LENGTH: int = 10   # longest sequence the model is built for
FF_DIMS: int = 64               # presumably the feed-forward hidden width — confirm in toy_transformer
DROPOUT_PROB: float = 0.1

### 2. Transformer Initialisation

In [8]:
# Build the model; the vocabulary size is passed twice (presumably source and
# target vocabulary — argument order as defined by toy_transformer.Transformer).
transformer = Transformer(
    VOCABULARY_SIZE,
    VOCABULARY_SIZE,
    EMBEDDING_SIZE,
    N_ATTENTION_HEADS,
    N_BLOCKS,
    MAX_SEQUENCE_LENGTH,
    FF_DIMS,
    DROPOUT_PROB,
    device=device,
)

## Training Transformer

In [9]:
import torch.nn as nn
import torch.optim as optim

### 1. Train, Validate, Evaluate Model Functions

In [10]:
def train_model(
    train_loader: DataLoader, validation_loader: DataLoader, model: Transformer, criterion: nn.CrossEntropyLoss, 
    optimizer: optim.Optimizer, max_epochs: int = 20, device: str | T.device = "cpu", patience: int = 3
    ):
    
    best_val_loss = np.inf
    patience_counter = 0

    train_losses, train_accs = [], []
    test_losses, test_accs = [], []
    best_model_state = model.state_dict()

    for epoch in range(max_epochs):
        print(f"\nEpoch {epoch + 1}/{max_epochs}")
        print("-" * 40)

        # --- Training ---
        model.train()
        epoch_train_losses = []
        total_correct = 0
        total_tokens = 0

        for batch_idx, (X, target) in enumerate(train_loader):
            encoder_input, target = X.to(device), target.to(device)

            # Decoder inputs/targets (shifted)
            decoder_input = target[:, :-1]
            decoder_target = target[:, 1:].reshape(-1)

            # Forward pass
            logits = model(encoder_input, decoder_input).reshape(-1, VOCABULARY_SIZE)
            loss = criterion(logits, decoder_target)

            # Backprop
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Accuracy
            preds = logits.argmax(dim=1)
            total_correct += (preds == decoder_target).sum().item()
            total_tokens += decoder_target.size(0)

            epoch_train_losses.append(loss.item())

            if batch_idx % 10 == 0 or batch_idx == len(train_loader) - 1:
                print(f"Train Batch {batch_idx+1}/{len(train_loader)} | Loss: {loss.item():.4f}")

        train_loss = np.mean(epoch_train_losses)
        train_acc = total_correct / total_tokens
        train_losses.append(train_loss)
        train_accs.append(train_acc)

        print(f"[Epoch {epoch+1}] Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f}")

        # --- Validation ---
        model.eval()
        test_batch_losses = []
        test_correct = 0
        test_tokens = 0

        with T.no_grad():
            for X, target in validation_loader:
                encoder_input, target = X.to(device), target.to(device)
                decoder_input = target[:, :-1]
                decoder_target = target[:, 1:].reshape(-1)

                logits = model(encoder_input, decoder_input).reshape(-1, VOCABULARY_SIZE)
                loss = criterion(logits, decoder_target)

                preds = logits.argmax(dim=1)
                test_correct += (preds == decoder_target).sum().item()
                test_tokens += decoder_target.size(0)

                test_batch_losses.append(loss.item())

        val_loss = np.mean(test_batch_losses)
        val_acc = test_correct / test_tokens
        test_losses.append(val_loss)
        test_accs.append(val_acc)

        print(f"[Epoch {epoch+1}] Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f}")
        print("-" * 40)

        # --- Early Stopping ---
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
            best_model_state = model.state_dict()
            print(f"Validation loss improved — saving model (Loss: {val_loss:.4f})")
        # else:
        #     patience_counter += 1
        #     print(f"No improvement ({patience_counter}/{patience})")
        #     if patience_counter >= patience:
        #         print("\nEarly stopping triggered. Restoring best model.")
        #         model.load_state_dict(best_model_state)
        #         break

    print("\nTraining complete")

    return {
        "train_losses": train_losses,
        "train_accs": train_accs,
        "test_losses": test_losses,
        "test_accs": test_accs,
        "best_val_loss": best_val_loss,
    }


In [11]:
def test_model(test_loader: DataLoader, model: Transformer, criterion: nn.CrossEntropyLoss, device: str | T.device = "cpu"):
    test_losses, test_accs = [], []

    model.eval()
    test_batch_losses = []
    test_correct = 0
    test_tokens = 0

    with T.no_grad():
        for X, target in test_loader:
            encoder_input, target = X.to(device), target.to(device)
            decoder_input = target[:, :-1]
            decoder_target = target[:, 1:].reshape(-1)

            logits = model(encoder_input, decoder_input).reshape(-1, VOCABULARY_SIZE)
            loss = criterion(logits, decoder_target)

            preds = logits.argmax(dim=1)
            test_correct += (preds == decoder_target).sum().item()
            test_tokens += decoder_target.size(0)

            test_batch_losses.append(loss.item())

    test_loss = np.mean(test_batch_losses)
    test_acc = test_correct / test_tokens
    test_losses.append(test_loss)
    test_accs.append(test_acc)

    return {
        "test_losses": test_losses,
        "test_accs": test_accs,
    }

### 2. Train Transformer Model

In [12]:
# Training hyperparameters
MAX_EPOCHS = 60             # passed to train_model as max_epochs
LEARNING_RATE = 0.001       # Adam learning rate
BETAS = (0.9, 0.98)         # Adam (beta1, beta2) coefficients
EPSILON = 1e-9              # Adam eps (numerical-stability term)

In [13]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(
    transformer.parameters(),
    lr=LEARNING_RATE,
    betas=BETAS,
    eps=EPSILON,
)

# Per-epoch training/validation history is kept in `results`.
results = train_model(
    train_loader=train_dataset_loader,
    validation_loader=validation_dataset_loader,
    model=transformer,
    criterion=criterion,
    optimizer=optimizer,
    max_epochs=MAX_EPOCHS,
    device=device,
)


Epoch 1/60
----------------------------------------
Train Batch 1/12 | Loss: 4.5562
Train Batch 11/12 | Loss: 2.6112
Train Batch 12/12 | Loss: 2.4872
[Epoch 1] Train Loss: 3.2559 | Train Acc: 0.3802
[Epoch 1] Val Loss: 2.4023 | Val Acc: 0.5000
----------------------------------------
Validation loss improved — saving model (Loss: 2.4023)

Epoch 2/60
----------------------------------------
Train Batch 1/12 | Loss: 2.5367
Train Batch 11/12 | Loss: 1.9600
Train Batch 12/12 | Loss: 1.9604
[Epoch 2] Train Loss: 2.1991 | Train Acc: 0.5013
[Epoch 2] Val Loss: 1.8479 | Val Acc: 0.5000
----------------------------------------
Validation loss improved — saving model (Loss: 1.8479)

Epoch 3/60
----------------------------------------
Train Batch 1/12 | Loss: 1.9235
Train Batch 11/12 | Loss: 1.4828
Train Batch 12/12 | Loss: 1.4303
[Epoch 3] Train Loss: 1.6638 | Train Acc: 0.5768
[Epoch 3] Val Loss: 1.3714 | Val Acc: 0.6172
----------------------------------------
Validation loss improved — savin

### 3. Test Transformer Model

In [14]:
# Store test metrics separately so the training history in `results` survives for later cells.
test_results = test_model(test_dataset_loader, transformer, criterion, device)
print(f"Test Loss: {test_results['test_losses'][0]:.4f} | Test Acc: {test_results['test_accs'][0]:.4f}")

## Model Inference

In [15]:
def model_inference(model: "Transformer", source_sequence, start_tokens, max_new_tokens: int = 1):
    """Greedily extend ``start_tokens`` by ``max_new_tokens`` predicted tokens.

    Each step calls ``model(source_sequence, generated)``, takes the argmax over
    the last dimension at the final position, and appends it to ``generated``.

    Args:
        model: seq2seq model; switched to eval mode (and left there).
        source_sequence: encoder input, passed unchanged on every step.
        start_tokens: 2-D token tensor ``(batch, prefix_len)`` to extend. The
            concat along dim 1 and the ``[:, -1, :]`` indexing require this layout.
        max_new_tokens: number of greedy decoding steps. The default of 1 gives a
            single next-token prediction.

    Returns:
        ``start_tokens`` with the predicted tokens appended along dim 1.

    Raises:
        ValueError: if ``max_new_tokens`` is negative.
    """
    if max_new_tokens < 0:
        raise ValueError(f"max_new_tokens must be non-negative, got {max_new_tokens}")

    model.eval()
    generated = start_tokens

    with T.no_grad():
        for _ in range(max_new_tokens):
            logits = model(source_sequence, generated)
            # Greedy pick at the last decoder position; keepdim gives (batch, 1) for the concat.
            next_token = T.argmax(logits[:, -1, :], dim=-1, keepdim=True)
            generated = T.cat([generated, next_token], dim=1)

    return generated


In [16]:
# Show a few held-out test trials next to the model's predictions. Both sections
# use the same test-split samples so they can be compared row by row. Each
# prediction is the teacher-forced prefix target[:, :-1] plus the model's greedy
# next token.
x, target = test_dataset[:5]
prediction = model_inference(transformer, x, target[:, :-1])

print("# Actual Trials")
test_batch = [np.asarray(item.cpu()) for item in (x, target)]
output = wcst.visualise_batch(test_batch)

print("# Predicted Trials")
prediction_batch = [np.asarray(item.cpu()) for item in (x, prediction)]
output = wcst.visualise_batch(prediction_batch)

# Actual Trials
[array(['green', 'cross', '4'], dtype='<U6'), array(['blue', 'cross', '2'], dtype='<U6'), array(['red', 'star', '4'], dtype='<U6'), array(['yellow', 'circle', '4'], dtype='<U6'), array(['yellow', 'cross', '2'], dtype='<U6'), 'SEP', 'C4', 'EOS', array(['red', 'square', '2'], dtype='<U6'), 'SEP', 'C3']
[array(['green', 'square', '1'], dtype='<U6'), array(['red', 'star', '2'], dtype='<U6'), array(['blue', 'star', '4'], dtype='<U6'), array(['yellow', 'star', '4'], dtype='<U6'), array(['blue', 'cross', '3'], dtype='<U6'), 'SEP', 'C3', 'EOS', array(['green', 'cross', '3'], dtype='<U6'), 'SEP', 'C1']
[array(['red', 'star', '4'], dtype='<U6'), array(['blue', 'star', '3'], dtype='<U6'), array(['yellow', 'circle', '4'], dtype='<U6'), array(['green', 'circle', '1'], dtype='<U6'), array(['yellow', 'circle', '2'], dtype='<U6'), 'SEP', 'C3', 'EOS', array(['yellow', 'cross', '1'], dtype='<U6'), 'SEP', 'C3']
[array(['blue', 'square', '2'], dtype='<U6'), array(['green', 'star', '1'], dt