# DyT CIFAR-10 Training

Train a Dynamic Tanh (DyT) Transformer on CIFAR-10 with mixup and RandAugment. The notebook expects the local modules `dyt`, `randomaug`, and `utils` to be importable from the working directory, and it saves checkpoints, logs, and plots locally.


In [None]:
import csv
import random
import time
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.cuda import amp

from dyt import DyT
from randomaug import RandAugment
from utils import progress_bar

# Fix the seed of every random number generator this notebook draws from.
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)

# Prefer the GPU when one is visible; otherwise fall back to the CPU.
has_cuda = torch.cuda.is_available()
device = torch.device('cuda') if has_cuda else torch.device('cpu')
print(f'Using device: {device}')
if has_cuda:
    print(f'GPU: {torch.cuda.get_device_name(0)}')

# cuDNN benchmark mode is switched on regardless of the selected device.
cudnn.benchmark = True


In [None]:
# Hyperparameters tuned to stay within ~8-10GB of VRAM.

# --- Output locations -------------------------------------------------------
DATA_DIR = Path('./data')
CHECKPOINT_DIR = Path('checkpoint')
LOG_DIR = Path('log')
PLOT_DIR = Path('plots')

# parents=True keeps this cell working when a directory above is changed to a
# nested path (e.g. Path('runs/exp1/checkpoint')).
for directory in (CHECKPOINT_DIR, LOG_DIR, PLOT_DIR):
    directory.mkdir(parents=True, exist_ok=True)

# --- Optimisation schedule --------------------------------------------------
TOTAL_EPOCHS = 200
BATCH_SIZE = 256
LEARNING_RATE = 1e-4
WEIGHT_DECAY = 0.05

# --- Model architecture (forwarded to the DyT constructor) ------------------
NUM_CLASSES = 10
IMAGE_SIZE = 32
PATCH_SIZE = 4
DIM = 384
DEPTH = 6
HEADS = 6
MLP_DIM = 1536
DROPOUT = 0.1
EMB_DROPOUT = 0.1

# --- Augmentation / runtime switches ----------------------------------------
MIXUP_ALPHA = 0.2  # both parameters of the Beta distribution used by mixup
USE_MIXUP = True
USE_RANDAUG = True
USE_AMP = True
NUM_WORKERS = 8

# "latest" is rewritten every epoch; "best" only when validation accuracy improves.
CHECKPOINT_PATH = CHECKPOINT_DIR / 'dyt_cifar10_latest.pth'
BEST_CHECKPOINT_PATH = CHECKPOINT_DIR / 'dyt_cifar10_best.pth'

print('Configuration ready for CIFAR-10 DyT training.')


In [None]:
# Per-channel statistics used to normalise both the train and the test split.
cifar10_mean = (0.4914, 0.4822, 0.4465)
cifar10_std = (0.2023, 0.1994, 0.2010)

# Training pipeline: optional RandAugment first, then crop/flip augmentation,
# resize, tensor conversion and normalisation.
train_transforms = [RandAugment(2, 14)] if USE_RANDAUG else []
train_transforms += [
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(cifar10_mean, cifar10_std),
]
transform_train = transforms.Compose(train_transforms)

# Evaluation pipeline: no augmentation, only resize + tensor + normalise.
test_transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(cifar10_mean, cifar10_std),
])

# Both splits live under the same root and are downloaded on demand.
dataset_kwargs = dict(root=str(DATA_DIR), download=True)
trainset = torchvision.datasets.CIFAR10(train=True, transform=transform_train, **dataset_kwargs)
testset = torchvision.datasets.CIFAR10(train=False, transform=test_transform, **dataset_kwargs)

# Settings shared by both loaders; only batch size and shuffling differ.
loader_kwargs = dict(num_workers=NUM_WORKERS, pin_memory=True)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=BATCH_SIZE, shuffle=True, **loader_kwargs
)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=256, shuffle=False, **loader_kwargs
)

classes = trainset.classes
print(f'Train batches: {len(trainloader)}, Test batches: {len(testloader)}')


In [None]:
def mixup_data(inputs, targets, alpha):
    """Blend each sample in a batch with a randomly chosen partner (mixup).

    Args:
        inputs: Batch tensor; indexed along its first dimension.
        targets: Label tensor aligned with ``inputs``.
        alpha: Value used for both parameters of the Beta distribution that
            the mixing coefficient is drawn from.

    Returns:
        ``(mixed_inputs, targets_a, targets_b, lam)``. When the module-level
        ``USE_MIXUP`` flag is off, or ``alpha`` is not positive, the batch is
        returned untouched as ``(inputs, targets, targets, 1.0)``.
    """
    if not USE_MIXUP:
        return inputs, targets, targets, 1.0
    if alpha <= 0:
        return inputs, targets, targets, 1.0

    # One coefficient per batch, then one permutation that picks the partners.
    lam = np.random.beta(alpha, alpha)
    partner = torch.randperm(inputs.size(0), device=inputs.device)
    blended = lam * inputs + (1 - lam) * inputs[partner]
    return blended, targets, targets[partner], lam

def mixup_criterion(criterion, predictions, targets_a, targets_b, lam):
    """Evaluate ``criterion`` against both mixup target sets and blend the results.

    With the module-level ``USE_MIXUP`` flag off, only the loss against
    ``targets_a`` is computed and returned. Otherwise the result is
    ``lam * loss(targets_a) + (1 - lam) * loss(targets_b)``.
    """
    loss_a = criterion(predictions, targets_a)
    if USE_MIXUP:
        loss_b = criterion(predictions, targets_b)
        return lam * loss_a + (1 - lam) * loss_b
    return loss_a

def plot_metrics(train_losses, val_losses, val_accs):
    """Draw loss and accuracy curves side by side, save them, and show them.

    Args:
        train_losses: Per-epoch training losses; its length sets the x axis.
        val_losses: Per-epoch validation losses.
        val_accs: Per-epoch validation accuracies (percent).

    Does nothing when ``train_losses`` is empty. Otherwise the figure is
    written to ``PLOT_DIR / 'dyt_cifar10_training.png'`` and displayed.
    """
    if not train_losses:
        return

    epochs = range(1, len(train_losses) + 1)
    fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 5))

    # Left panel: training vs. validation loss.
    loss_ax.plot(epochs, train_losses, label='Train Loss')
    loss_ax.plot(epochs, val_losses, label='Val Loss')
    loss_ax.set_xlabel('Epoch')
    loss_ax.set_ylabel('Loss')
    loss_ax.legend()
    loss_ax.set_title('Loss')

    # Right panel: validation accuracy.
    acc_ax.plot(epochs, val_accs, label='Val Acc (%)', color='tab:green')
    acc_ax.set_xlabel('Epoch')
    acc_ax.set_ylabel('Accuracy')
    acc_ax.legend()
    acc_ax.set_title('Accuracy')

    plot_path = PLOT_DIR / 'dyt_cifar10_training.png'
    fig.suptitle('DyT on CIFAR-10')
    fig.tight_layout()
    fig.savefig(plot_path)
    plt.show()
    print(f'Saved training plot to {plot_path}')


In [None]:
# Build the DyT model from the configuration constants and move it to the device.
model = DyT(
    image_size=IMAGE_SIZE,
    patch_size=PATCH_SIZE,
    num_classes=NUM_CLASSES,
    dim=DIM,
    depth=DEPTH,
    heads=HEADS,
    mlp_dim=MLP_DIM,
    dropout=DROPOUT,
    emb_dropout=EMB_DROPOUT,
).to(device)

# AdamW with a cosine learning-rate schedule spanning the whole run; the
# gradient scaler is a pass-through when USE_AMP is False.
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=TOTAL_EPOCHS)
scaler = amp.GradScaler(enabled=USE_AMP)

# Defaults for a fresh run; overwritten below when a checkpoint is found.
start_epoch = 0
best_acc = 0.0
train_losses = []
val_losses = []
val_accs = []

if CHECKPOINT_PATH.exists():
    checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)
    model.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    scheduler.load_state_dict(checkpoint['scheduler'])
    # A scaler that was disabled when the checkpoint was written (AMP off, or
    # no CUDA) saves an empty state dict, and an enabled scaler raises
    # RuntimeError when asked to load one. Skip the restore in that case and
    # keep the freshly initialised scaler.
    scaler_state = checkpoint.get('scaler')
    if scaler_state:
        scaler.load_state_dict(scaler_state)
    # The stored epoch is the last one that finished, so continue with the next.
    start_epoch = checkpoint['epoch'] + 1
    best_acc = checkpoint.get('best_acc', 0.0)
    train_losses = checkpoint.get('train_losses', [])
    val_losses = checkpoint.get('val_losses', [])
    val_accs = checkpoint.get('val_accs', [])
    print(f'Resumed from epoch {start_epoch} with best acc {best_acc:.2f}%')
else:
    print('Starting fresh training run.')


In [None]:
def save_checkpoint(epoch, best=False):
    """Persist the full training state so a later run can resume from it.

    The state bundles the model, optimizer, scheduler and scaler state dicts
    with the best accuracy so far and the metric histories, all read from the
    module-level training objects.

    Args:
        epoch: Index of the epoch that just finished (stored under ``'epoch'``).
        best: When True write to ``BEST_CHECKPOINT_PATH``, otherwise to
            ``CHECKPOINT_PATH``.
    """
    state = {
        'epoch': epoch,
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'scheduler': scheduler.state_dict(),
        'scaler': scaler.state_dict(),
        'best_acc': best_acc,
        'train_losses': train_losses,
        'val_losses': val_losses,
        'val_accs': val_accs,
    }
    path = BEST_CHECKPOINT_PATH if best else CHECKPOINT_PATH

    # Serialise next to the destination first and rename afterwards, so an
    # interruption mid-write (kernel restart, full disk) never leaves a
    # truncated file where the previous good checkpoint used to be.
    tmp_path = path.with_name(path.name + '.tmp')
    torch.save(state, tmp_path)
    tmp_path.replace(path)

    label = 'best' if best else 'latest'
    print(f'Checkpoint ({label}) saved to {path}')

def log_epoch():
    """Rewrite the metrics CSV from the module-level history lists.

    The file is recreated on every call and holds one row per metric: the
    metric name followed by that metric's value for each recorded epoch.
    """
    csv_path = LOG_DIR / 'dyt_cifar10_metrics.csv'
    series = (
        ('train_loss', train_losses),
        ('val_loss', val_losses),
        ('val_acc', val_accs),
    )
    with csv_path.open('w', newline='') as handle:
        writer = csv.writer(handle)
        for label, values in series:
            writer.writerow([label] + values)
    print(f'Metrics updated at {csv_path}')

def train_one_epoch(epoch):
    """Run one optimisation pass over ``trainloader``.

    Each batch is moved to ``device``, passed through ``mixup_data``, and the
    forward pass and loss run under autocast. Backward and the optimizer step
    go through the gradient scaler.

    The running accuracy shown in the progress bar is weighted by the mixup
    coefficient: a prediction counts ``lam`` towards ``targets_a`` and
    ``1 - lam`` towards ``targets_b``. Comparing against the unmixed targets
    alone would understate accuracy whenever the batch was actually mixed.
    When mixup is inactive (``lam == 1`` and both target sets are identical)
    this is the ordinary match count.

    Args:
        epoch: Current epoch index. Not used inside the function.

    Returns:
        Mean of the per-batch losses over the epoch.
    """
    model.train()
    running_loss = 0.0
    total = 0
    correct = 0
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs = inputs.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)
        mixed_inputs, targets_a, targets_b, lam = mixup_data(inputs, targets, MIXUP_ALPHA)

        with amp.autocast(enabled=USE_AMP):
            outputs = model(mixed_inputs)
            loss = mixup_criterion(criterion, outputs, targets_a, targets_b, lam)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        # Credit each prediction against both label sets in proportion to the
        # mixing coefficient used to build this batch.
        correct += (
            lam * predicted.eq(targets_a).sum().item()
            + (1 - lam) * predicted.eq(targets_b).sum().item()
        )

        progress_bar(
            batch_idx,
            len(trainloader),
            'Loss: %.3f | Acc: %.3f%% (%d/%d)' % (
                running_loss / (batch_idx + 1),
                100.0 * correct / total,
                correct,
                total,
            ),
        )
    return running_loss / len(trainloader)

@torch.no_grad()
def evaluate(epoch):
    model.eval()
    running_loss = 0.0
    total = 0
    correct = 0
    for batch_idx, (inputs, targets) in enumerate(testloader):
        inputs = inputs.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        outputs = model(inputs)
        loss = criterion(outputs, targets)

        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        progress_bar(
            batch_idx,
            len(testloader),
            'Loss: %.3f | Acc: %.3f%% (%d/%d)' % (
                running_loss / (batch_idx + 1),
                100.0 * correct / total,
                correct,
                total,
            ),
        )
    avg_loss = running_loss / len(testloader)
    acc = 100.0 * correct / total
    return avg_loss, acc


In [None]:
# Main training loop: train, evaluate, log, and checkpoint once per epoch.
# Starts at `start_epoch`, so a resumed run continues where it stopped.
if start_epoch >= TOTAL_EPOCHS:
    print('Training already completed for the configured number of epochs.')
else:
    for epoch in range(start_epoch, TOTAL_EPOCHS):
        print(f'\nEpoch {epoch + 1}/{TOTAL_EPOCHS}')
        train_loss = train_one_epoch(epoch)
        val_loss, val_acc = evaluate(epoch)

        # Read the learning rate before stepping the scheduler so the log
        # reports the rate this epoch trained with, not the next epoch's.
        current_lr = optimizer.param_groups[0]['lr']
        scheduler.step()

        train_losses.append(train_loss)
        val_losses.append(val_loss)
        val_accs.append(val_acc)

        # Update best_acc BEFORE any checkpoint is written. The "latest"
        # checkpoint is what a resumed run loads; if it carried the previous
        # best_acc, a later and worse epoch could beat that stale value and
        # overwrite the best checkpoint.
        is_best = val_acc > best_acc
        if is_best:
            best_acc = val_acc

        log_line = (
            f"{time.ctime()} | Epoch {epoch + 1}/{TOTAL_EPOCHS} | lr: {current_lr:.6f} | "
            f"train loss: {train_loss:.4f} | val loss: {val_loss:.4f} | val acc: {val_acc:.2f}%"
        )

        log_epoch()
        with open(LOG_DIR / 'dyt_cifar10.log', 'a') as logfile:
            logfile.write(log_line + '\n')

        save_checkpoint(epoch, best=False)
        if is_best:
            save_checkpoint(epoch, best=True)

        print(log_line)


In [None]:
# Render and save the loss/accuracy curves collected during training.
plot_metrics(
    train_losses=train_losses,
    val_losses=val_losses,
    val_accs=val_accs,
)
