In [1]:
import math
import os
import random

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from torchvision import datasets, transforms

from models.ctm import ContinuousThoughtMachine
from utils.losses import image_classification_loss

# ---------------------------------------------------------------------------
# Hyperparameters (mirroring the values in train_distributed.py)
# ---------------------------------------------------------------------------

# CTM core architecture
d_model = 512
d_input = 128
heads = 4
iterations = 50  # number of internal ticks the CTM runs per input
synapse_depth = 4
backbone_type = 'resnet18-4'
positional_embedding_type = 'none'

# Neuron selection for the synchronisation representations
n_synch_out = 32
n_synch_action = 32
neuron_select_type = 'first-last'
n_random_pairing_self = 256

# Neuron-level models (NLMs)
memory_length = 25
deep_memory = True
memory_hidden_dims = 4
do_normalisation = False  # when True, apply layernorm inside the NLMs

# Regularisation
dropout = 0.0
dropout_nlm = None  # None -> NLMs use the same dropout as the rest of the model
weight_decay = 0.0
gradient_clipping = -1  # -1 disables gradient-norm clipping

# Optimisation and learning-rate schedule
batch_size = 32
batch_size_test = 32
lr = 1e-3
training_iterations = 100001
warmup_steps = 5000
use_scheduler = True
scheduler_type = 'cosine'  # 'cosine' or 'multistep'
milestones = [8000, 15000, 20000]  # step boundaries for the multistep scheduler
gamma = 0.1  # multistep decay factor applied at each milestone

# Data loading, reproducibility and evaluation cadence
num_workers_train = 0
seed = 412
track_every = 1000  # evaluate every this many training iterations
n_test_batches = 20  # batches per split used for metrics (-1 for a full pass)

# Seed every RNG this notebook touches (Python, NumPy, PyTorch CPU and, if present, CUDA)
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)

# Device setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Data loading (CIFAR-10): per-channel normalization everywhere; random crop/flip
# augmentation is applied only to the data used for gradient updates.
dataset_mean = [0.4914, 0.4822, 0.4465]
dataset_std = [0.2470, 0.2435, 0.2616]
train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(dataset_mean, dataset_std),
])
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(dataset_mean, dataset_std),
])
train_data = datasets.CIFAR10(root='data/', train=True, download=True, transform=train_transform)
test_data = datasets.CIFAR10(root='data/', train=False, download=True, transform=test_transform)
# Un-augmented view of the training split, used only to measure training metrics.
# Reusing `train_data` here would evaluate on randomly cropped/flipped images, making
# the reported train loss/accuracy noisy and not comparable to the test metrics.
train_eval_data = datasets.CIFAR10(root='data/', train=True, download=True, transform=test_transform)
class_labels = train_data.classes
out_dims = len(class_labels)  # number of classes, should be 10 for CIFAR-10

# DataLoaders
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True,
                                           num_workers=num_workers_train, drop_last=True)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size_test, shuffle=False, num_workers=0)
# Training-metric loader: no shuffle and no augmentation, so every evaluation sees the same images
train_eval_loader = torch.utils.data.DataLoader(train_eval_data, batch_size=batch_size_test, shuffle=False, num_workers=0)

# Build the Continuous Thought Machine from the hyperparameters defined above.
# Some notebook variables map onto differently named constructor arguments:
# deep_memory -> deep_nlms, do_normalisation -> do_layernorm_nlm.
ctm_kwargs = dict(
    iterations=iterations,
    d_model=d_model,
    d_input=d_input,
    heads=heads,
    n_synch_out=n_synch_out,
    n_synch_action=n_synch_action,
    synapse_depth=synapse_depth,
    memory_length=memory_length,
    deep_nlms=deep_memory,
    memory_hidden_dims=memory_hidden_dims,
    do_layernorm_nlm=do_normalisation,
    backbone_type=backbone_type,
    positional_embedding_type=positional_embedding_type,
    out_dims=out_dims,
    prediction_reshaper=[-1],  # task-specific reshaping ([-1] for image classification)
    dropout=dropout,
    dropout_nlm=dropout_nlm,
    neuron_select_type=neuron_select_type,
    n_random_pairing_self=n_random_pairing_self,
)
model = ContinuousThoughtMachine(**ctm_kwargs).to(device)

# AdamW over all parameters. `lr` is only the starting value: the training loop
# rewrites each param group's learning rate on every step (warmup + scheduler).
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay, eps=1e-8)

# Metric histories, appended to at each evaluation point
iteration_points, train_losses, test_losses = [], [], []
train_accuracies, test_accuracies = [], []

# Evaluation helpers for the training & test datasets (each uses at most n_test_batches batches)
def evaluate_loader(net, loader, dev, max_batches=-1):
    """Compute the sample-weighted mean loss and accuracy of ``net`` over ``loader``.

    Args:
        net: model to evaluate. If its forward pass returns a tuple, it is unpacked
            as ``(predictions, certainties, _)``. The batch is scored with
            ``image_classification_loss(..., use_most_certain=True)``, and the predicted
            class is taken at the tick index returned by that loss. Otherwise the
            output is treated as class logits and scored with cross-entropy
            (feed-forward baseline).
        loader: iterable of ``(inputs, targets)`` batches.
        dev: device the batches are moved to before the forward pass.
        max_batches: stop after this many batches; ``-1`` evaluates the whole loader.

    Returns:
        ``(avg_loss, accuracy)``. Both are ``0.0`` if the loader yielded no samples.

    Does not change ``net``'s train/eval mode; the caller is responsible for that.
    """
    criterion = nn.CrossEntropyLoss()
    total_loss = 0.0
    total_correct = 0.0
    total_samples = 0
    with torch.inference_mode():
        for batch_idx, (inputs, targets) in enumerate(loader):
            inputs = inputs.to(dev)
            targets = targets.to(dev)
            outputs = net(inputs)
            if isinstance(outputs, tuple):
                # CTM/LSTM outputs: predictions, certainties, (maybe other)
                predictions, certainties, _ = outputs
                loss_val, where_most_certain = image_classification_loss(
                    predictions, certainties, targets, use_most_certain=True)
                # argmax over dim 1 gives one predicted class per tick (assumed shape
                # [batch, ticks]); keep the class at each sample's most-certain tick.
                per_tick_preds = predictions.argmax(dim=1)
                rows = torch.arange(predictions.size(0), device=per_tick_preds.device)
                preds = per_tick_preds[rows, where_most_certain]
            else:
                loss_val = criterion(outputs, targets)
                preds = outputs.argmax(dim=1)
            n_seen = inputs.size(0)
            # loss_val is treated as a per-batch mean, so weight it by batch size
            total_loss += loss_val.item() * n_seen
            total_correct += (preds == targets).sum().item()
            total_samples += n_seen
            # Check after processing so the loader is never asked for an unused batch
            if max_batches != -1 and batch_idx >= max_batches - 1:
                break
    if total_samples == 0:
        return 0.0, 0.0
    return total_loss / total_samples, total_correct / total_samples


def evaluate():
    """Evaluate the global ``model`` on the training-metric and test loaders.

    Uses ``train_eval_loader`` and ``test_loader``, each limited to ``n_test_batches``
    batches (``-1`` for a full pass).

    Returns:
        ``(avg_train_loss, train_acc, avg_test_loss, test_acc)``.

    The model is switched to eval mode for the duration and is always put back into
    train mode afterwards, even if evaluation raises.
    """
    model.eval()
    try:
        avg_train_loss, train_acc = evaluate_loader(model, train_eval_loader, device, n_test_batches)
        avg_test_loss, test_acc = evaluate_loader(model, test_loader, device, n_test_batches)
    finally:
        model.train()
    return avg_train_loss, train_acc, avg_test_loss, test_acc

# Training loop (single device, no distributed training)

# Fail fast on a typo instead of silently training with a constant learning rate
if use_scheduler and scheduler_type not in ('cosine', 'multistep'):
    raise ValueError(f"Unknown scheduler_type {scheduler_type!r}; expected 'cosine' or 'multistep'")


def lr_at_step(step):
    """Return the learning rate for 0-based training step ``step``.

    - Linear warmup from ``lr / warmup_steps`` up to ``lr`` over the first ``warmup_steps`` steps.
    - Then, if ``use_scheduler``, either cosine annealing from ``lr`` towards ``1e-7`` over the
      remaining steps, or multistep decay by ``gamma`` at each entry of ``milestones``.
    - Otherwise the constant ``lr``.
    """
    if step < warmup_steps:
        return lr * float(step + 1) / float(warmup_steps)
    if use_scheduler and scheduler_type == 'cosine':
        progress = float(step - warmup_steps) / float(training_iterations - warmup_steps)
        cos_factor = 0.5 * (1 + math.cos(math.pi * progress))
        eta_min = 1e-7
        return eta_min + cos_factor * (lr - eta_min)
    if use_scheduler and scheduler_type == 'multistep':
        milestones_passed = sum(1 for m in milestones if step >= m)
        return lr * (gamma ** milestones_passed)
    return lr


def record_metrics(step):
    """Run ``evaluate()``, append the results (tagged with ``step``) to the metric histories, and return them."""
    metrics = evaluate()
    avg_train_loss, train_acc, avg_test_loss, test_acc = metrics
    iteration_points.append(step)
    train_losses.append(avg_train_loss)
    test_losses.append(avg_test_loss)
    train_accuracies.append(train_acc)
    test_accuracies.append(test_acc)
    return metrics


criterion = nn.CrossEntropyLoss()  # only used when the model returns plain logits (FF baseline)
model.train()
train_iter = iter(train_loader)  # iterator for continuous sampling across epochs
for it in range(training_iterations):
    # Fetch the next training batch (restart the iterator at the end of an epoch)
    try:
        inputs, targets = next(train_iter)
    except StopIteration:
        train_iter = iter(train_loader)
        inputs, targets = next(train_iter)
    inputs = inputs.to(device)
    targets = targets.to(device)

    # Apply the warmup / scheduler learning rate for this step
    current_lr = lr_at_step(it)
    for param_group in optimizer.param_groups:
        param_group['lr'] = current_lr

    # Forward pass and loss computation
    outputs = model(inputs)
    if isinstance(outputs, tuple):
        predictions, certainties, _ = outputs
        loss, _ = image_classification_loss(predictions, certainties, targets, use_most_certain=True)
    else:
        predictions = outputs
        loss = criterion(predictions, targets)

    # Backward pass and optimization step
    optimizer.zero_grad()
    loss.backward()
    if gradient_clipping != -1:
        nn.utils.clip_grad_norm_(model.parameters(), max_norm=gradient_clipping)
    optimizer.step()

    # Periodic evaluation and logging
    if it % track_every == 0 and it != 0:
        avg_train_loss, train_acc, avg_test_loss, test_acc = record_metrics(it)
        print(f"Iteration {it}: Train Loss={avg_train_loss:.4f}, Train Acc={train_acc:.4f}, "
              f"Test Loss={avg_test_loss:.4f}, Test Acc={test_acc:.4f}")

# Final evaluation at the end of training, unless the last periodic evaluation already
# covered the final step. This also runs when training was too short for any periodic
# evaluation, so the histories are never left empty after a completed run.
last_step = training_iterations - 1
if training_iterations > 0 and (not iteration_points or iteration_points[-1] != last_step):
    avg_train_loss, train_acc, avg_test_loss, test_acc = record_metrics(last_step)
    print(f"Final Iteration: Train Loss={avg_train_loss:.4f}, Train Acc={train_acc:.4f}, "
          f"Test Loss={avg_test_loss:.4f}, Test Acc={test_acc:.4f}")

# Plot training vs test loss and accuracy from the metric histories recorded during training.
# One figure per metric, drawn through the explicit Figure/Axes interface.
metric_curves = [
    ('Loss', train_losses, test_losses),
    ('Accuracy', train_accuracies, test_accuracies),
]
for metric_name, train_curve, test_curve in metric_curves:
    fig, ax = plt.subplots(figsize=(8, 5))
    ax.plot(iteration_points, train_curve, label=f'Train {metric_name}')
    ax.plot(iteration_points, test_curve, label=f'Test {metric_name}')
    ax.set_xlabel('Iteration')
    ax.set_ylabel(metric_name)
    ax.set_title(f'Training vs Test {metric_name}')
    ax.legend()

plt.show()


Using neuron select type: first-last
Synch representation size action: 528
Synch representation size out: 528


KeyboardInterrupt: 

In [2]:
# Save the model weights. Create the output directory first: torch.save fails if
# the parent directory './logs' does not already exist.
os.makedirs('./logs', exist_ok=True)
torch.save(model.state_dict(), './logs/ctm_cifar10.pth')