In [1]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import torch.optim as optim


from training.train import train_model
from utils.transforms import get_transforms
from models.ResNet_SC import build_model

In [2]:
# Hyperparameters
batch_size = 128
num_epochs = 50
lr = 0.005
num_classes = 10

# Seed torch's RNGs (torch.manual_seed covers all devices) so weight
# initialisation and DataLoader shuffling are repeatable across runs, which
# keeps the four baselines below comparable. GPU/MPS kernels may still be
# nondeterministic; Python's `random` and NumPy are not seeded here.
seed = 42
torch.manual_seed(seed)


# Select device: prefer CUDA, then Apple MPS (for Apple Silicon), otherwise CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

print(f"Using device: {device}")



Using device: mps


## 1) Load data



In [3]:
# Per-split preprocessing pipelines from the project's transform factory
# (what each split applies is defined in utils.transforms — not visible here).
train_transform = get_transforms(split='train')
test_transform = get_transforms(split='test')

# CIFAR-10 is cached under this directory and downloaded if missing.
data_root = "./data"
train_dataset = datasets.CIFAR10(root=data_root, train=True, download=True, transform=train_transform)
test_dataset = datasets.CIFAR10(root=data_root, train=False, download=True, transform=test_transform)

# Worker processes per loader — adjust number of workers to your machine.
loader_workers = 8

# Training batches are reshuffled every epoch; evaluation order stays fixed.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=loader_workers)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=loader_workers)

## 2) Create models

In [4]:
def _make_variant(*, adabn=False, cbam=False, proto=False):
    """Build one ResNet-18 baseline with the chosen self-correction modules.

    Each flag maps onto the matching ``use_*`` switch of ``build_model``; the
    resulting model is moved to the selected ``device``.
    """
    return build_model(
        num_classes=num_classes,
        use_adabn=adabn,
        use_cbam=cbam,
        use_proto=proto,
    ).to(device)


# Built in a fixed order so parameter initialisation consumes the RNG identically.
model_plain = _make_variant()               # (1) Plain ResNet-18 (no self-correction)
model_adabn = _make_variant(adabn=True)     # (2) ResNet-18 + AdaBN
model_cbam = _make_variant(cbam=True)       # (3) ResNet-18 + CBAM
model_proto = _make_variant(proto=True)     # (4) ResNet-18 + Prototype Alignment




### 2.1) Model optimizers

In [5]:
criterion = nn.CrossEntropyLoss()

optim  = optim.SGD(model_plain.parameters(),  lr=lr, momentum=0.9, weight_decay=5e-4)


## 3) Train model

In [None]:
# A dictionary to store (model, optimizer) pairs for easy looping:
baseline_dict = {
    "Plain"   : (model_plain, optim),
    "AdaBN"   : (model_adabn, optim),
    "CBAM"    : (model_cbam,  optim),
    "Proto"   : (model_proto, optim),
}

# Track best validation accuracy for each baseline:
best_val_acc = {name: 0.0 for name in baseline_dict.keys()}

for name, (model, optim) in baseline_dict.items():
    train_model(name=name, model=model, optimizer=optim, train_loader=train_loader, test_loader=test_loader, 
                criterion=criterion, device=device, num_epochs=num_epochs)


=== Training baseline: Plain ===
[Plain][Epoch 1/50] train_loss=1.7903, train_acc=0.3327  val_loss=1.5769, val_acc=0.4448  (best=0.4448)  epoch_time=0h02m13s, ETA=1h48m58s
[Plain][Epoch 2/50] train_loss=1.4538, train_acc=0.4690  val_loss=1.3055, val_acc=0.5465  (best=0.5465)  epoch_time=0h02m14s, ETA=1h47m28s
