In [55]:
import warnings
warnings.filterwarnings("ignore", category=FutureWarning)

import torch
import torch.nn as nn 
import torch.optim as optim
import torch.nn.functional as F 
import torch.backends.cudnn as cudnn 
import numpy as np
import sys

import torchvision
import torchvision.transforms as transforms

import matplotlib.pyplot as plt
import os 
import argparse
import pandas as pd 
import csv
import time
from randomaug import RandAugment 
from swin_transformer_pytorch import swin_t
import Utils

### Set Training Configuration: Device, Model, and Hyperparameters via ArgParser

In [None]:
# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Hyperparameters are declared through argparse and then parsed from an explicit
# list (see below), because a notebook kernel has no usable command line.
# Every numeric option is declared with type=int/float so downstream code gets
# real numbers instead of strings.
parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')
parser.add_argument('--dataset', default='CIFAR10')  # options: CIFAR10, CIFAR100
parser.add_argument('--dataset_classes', default=10, type=int)  # 10 for CIFAR10, 100 for CIFAR100
parser.add_argument('--lr', default=1e-4, type=float, help='learning rate')
parser.add_argument('--opt', default="adam")  # main() handles "adam" and "sgd"
parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint')
# NOTE: store_false means args.noaug is True by default and becomes False when
# --noaug is passed, i.e. the attribute reads as "augmentation enabled".
parser.add_argument('--noaug', action='store_false', help='disable use randomaug')
parser.add_argument('--noamp', action='store_true', help='disable mixed precision training. for older pytorch versions')
parser.add_argument('--net', default='swin')  # options: vit, swin, cait, twins
parser.add_argument('--heads', default=6, type=int)
parser.add_argument('--layers', default=12, type=int)  # depth
parser.add_argument('--dp', action='store_true', help='use data parallel')
parser.add_argument('--bs', default=64, type=int)  # was 512
parser.add_argument('--size', default=32, type=int)  # image size handed to the data loaders
parser.add_argument('--n_epochs', type=int, default=100)
# Despite the help text, main() passes this value to swin_t as window_size.
parser.add_argument('--patch', default=4, type=int, help="patch for ViT")
parser.add_argument('--dimhead', default=420, type=int)  # or 512

# Explicit argument list: this is the configuration actually used for the run.
args = parser.parse_args(args=[
    '--dataset', 'CIFAR10',
    '--dataset_classes', '10',
    '--lr', '0.0001',
    '--opt', 'adam',
    '--net', 'swin',
    '--heads', '6',
    '--layers', '12',
    '--bs', '200',
    '--size', '32',
    '--n_epochs', '200',
    '--patch', '4',
    '--dimhead', '420',
    '--noaug',
    '--noamp',
    '--dp'
])

# Best validation accuracy seen so far; test() updates it when it checkpoints.
best_acc = 0

### Count number of trainable parameters in the model 

In [None]:
def count_parameters(model):
    """Return the total element count of the model's trainable parameters.

    Only parameters whose ``requires_grad`` flag is set contribute to the sum.
    """
    trainable_total = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue  # frozen weights are excluded from the count
        trainable_total += param.numel()
    return trainable_total

### Defining the main training pipeline

In [65]:
def main():
    """Build the data loaders and Swin model, then train and validate for ``args.n_epochs``.

    Reads its configuration from the module-level ``args`` and ``device``.
    Publishes ``net`` and ``testloader`` as globals, and resets (or, when
    resuming, restores) the global ``best_acc`` that ``test()`` uses to decide
    when to write a checkpoint.

    Returns:
        tuple: ``(testloader, list_loss, list_acc, list_train_loss, list_train_acc)``
        where the four lists hold one value per trained epoch (validation loss,
        validation accuracy, training loss, training accuracy).

    Raises:
        ValueError: if ``args.opt`` is neither ``"adam"`` nor ``"sgd"``.

    Notes:
        * ``--dp`` and ``--noaug`` are parsed but not consumed by this function.
        * On resume only the model weights, best accuracy and epoch counter are
          restored; optimizer, scaler and LR-schedule state start fresh.
    """
    # best_acc must be the module-level variable: test() compares against it, so
    # a local copy here would be invisible to the checkpointing logic.
    global net, testloader, best_acc

    bs = int(args.bs)
    imsize = int(args.size)
    use_amp = not args.noamp

    best_acc = 0       # best validation accuracy seen so far
    start_epoch = 0    # start from epoch 0 or from the epoch after the checkpoint

    # Data
    print('==> Preparing data..')
    trainloader, testloader = Utils.get_loaders_CIFAR10(imsize, bs)

    # Model
    net = swin_t(window_size=int(args.patch),
                 num_classes=int(args.dataset_classes), downscaling_factors=(2,2,2,1)).to(device)
    pcount = count_parameters(net)
    print("count of parameters in the model = ", pcount/1e6, " million")

    if args.resume:
        # Load checkpoint.
        print('==> Resuming from checkpoint..')
        assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
        # Must match the file name test() writes: ./checkpoint/<net>-<patch>-ckpt.t7
        ckpt_path = './checkpoint/' + args.net + '-{}-ckpt.t7'.format(args.patch)
        # map_location lets a checkpoint saved on GPU be resumed on a CPU-only host.
        checkpoint = torch.load(ckpt_path, map_location=device)
        net.load_state_dict(checkpoint['model'])
        best_acc = checkpoint['acc']
        # The stored value is the index of the epoch that produced the
        # checkpoint, so training continues with the following one.
        start_epoch = checkpoint['epoch'] + 1

    # Loss is CE
    criterion = nn.CrossEntropyLoss()

    if args.opt == "adam":
        optimizer = optim.Adam(net.parameters(), lr=args.lr)
    elif args.opt == "sgd":
        optimizer = optim.SGD(net.parameters(), lr=args.lr)
    else:
        # Fail here rather than with a NameError on `optimizer` further down.
        raise ValueError("unsupported optimizer {!r}; expected 'adam' or 'sgd'".format(args.opt))

    # use cosine scheduling
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.n_epochs)

    ##### Training
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
    list_loss = []
    list_acc = []
    list_train_acc = []
    list_train_loss = []

    # The model was already moved with .to(device) above; no unconditional
    # .cuda() call, which would fail on machines without a GPU.
    for epoch in range(start_epoch, args.n_epochs):
        train_loss, train_acc = train(epoch, net, trainloader, criterion, scaler, optimizer, use_amp)

        val_loss, val_acc = test(epoch, net, testloader, criterion, optimizer, scaler)
        print(f"Epoch {epoch+1}/{args.n_epochs} | Val Loss: {val_loss:.4f} | Val Accuracy: {val_acc:.2f}%")
        scheduler.step()  # advance the cosine schedule once per epoch

        list_train_loss.append(train_loss)
        list_train_acc.append(train_acc)
        list_loss.append(val_loss)
        list_acc.append(val_acc)

    return testloader, list_loss, list_acc, list_train_loss, list_train_acc

### Training Function

In [66]:
def train(epoch, net, trainloader, criterion, scaler, optimizer, use_amp):
    """Run one optimisation pass over ``trainloader``.

    Each batch is moved to the module-level ``device``, forwarded under
    autocast (active only when ``use_amp`` is true), and back-propagated through
    the gradient scaler. Gradients are cleared after every optimizer step.

    Returns:
        tuple: ``(mean per-batch loss, accuracy in percent)`` for the epoch.
    """
    print('\nEpoch: %d' % (epoch+1))
    net.train()

    running_loss = 0
    n_correct = 0
    n_seen = 0

    for step, (images, labels) in enumerate(trainloader):
        images = images.to(device)
        labels = labels.to(device)

        # Forward pass and loss (mixed precision when enabled)
        with torch.cuda.amp.autocast(enabled=use_amp):
            logits = net(images)
            batch_loss = criterion(logits, labels)

        # Scaled backward pass, optimizer step, then reset gradients
        scaler.scale(batch_loss).backward()
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

        # Bookkeeping for the epoch summary
        running_loss += batch_loss.item()
        batch_preds = logits.max(1)[1]
        n_seen += labels.size(0)
        n_correct += batch_preds.eq(labels).sum().item()

    mean_loss = running_loss/(step+1)
    accuracy = 100.*n_correct/n_seen
    print(step, len(trainloader), 'TrainLoss: %.3f | TrainAcc: %.3f%% (%d/%d)'
          % (mean_loss, accuracy, n_correct, n_seen))
    return mean_loss, accuracy

### Testing Function

In [67]:
def test(epoch, net, testloader, criterion, optimizer, scaler):
    """Evaluate ``net`` on ``testloader`` and checkpoint it when accuracy improves.

    Runs in eval mode with gradients disabled. When the accuracy exceeds the
    module-level ``best_acc``, the model, optimizer and scaler state dicts are
    saved together with the accuracy and epoch, and ``best_acc`` is updated.

    Returns:
        tuple: ``(mean per-batch loss, accuracy in percent)``.
    """
    global best_acc
    net.eval()

    loss_sum = 0
    hits = 0
    seen = 0
    with torch.no_grad():
        for step, (images, labels) in enumerate(testloader):
            images = images.to(device)
            labels = labels.to(device)
            logits = net(images)
            loss_sum += criterion(logits, labels).item()
            batch_preds = logits.max(1)[1]
            seen += labels.size(0)
            hits += batch_preds.eq(labels).sum().item()

    mean_loss = loss_sum/(step+1)
    acc = 100.*hits/seen
    print(step, len(testloader), 'ValLoss: %.3f | Val Acc: %.3f%% (%d/%d)'
          % (mean_loss, acc, hits, seen))

    # Nothing to save unless this epoch beat the best accuracy so far.
    if not acc > best_acc:
        return mean_loss, acc

    print('Saving..')
    state = {"model": net.state_dict(),
             "optimizer": optimizer.state_dict(),
             "acc": acc, "epoch": epoch,
             "scaler": scaler.state_dict()}
    if not os.path.isdir('checkpoint'):
        os.mkdir('checkpoint')
    torch.save(state, './checkpoint/'+args.net+'-{}-ckpt.t7'.format(args.patch))
    best_acc = acc
    return mean_loss, acc

### Run Main Function to Train Model and Collect Metrics

In [68]:
# Run the full training loop; keep the test loader and the per-epoch metric
# lists (validation loss/accuracy, training loss/accuracy) for the cells below.
testloader, list_loss, list_acc, list_train_loss, list_train_acc  = main()

==> Preparing data..
Files already downloaded and verified
Files already downloaded and verified
count of parameters in the model =  26.598646  million

Epoch: 1
249 250 TrainLoss: 2.276 | TrainAcc: 14.108% (7054/50000)
99 100 ValLoss: 2.011 | Val Acc: 24.830% (2483/10000)
Saving..
Epoch 1/200 | Val Loss: 2.0113 | Val Accuracy: 24.83%

Epoch: 2
249 250 TrainLoss: 2.118 | TrainAcc: 20.026% (10013/50000)
99 100 ValLoss: 1.797 | Val Acc: 32.600% (3260/10000)
Saving..
Epoch 2/200 | Val Loss: 1.7968 | Val Accuracy: 32.60%

Epoch: 3
249 250 TrainLoss: 1.999 | TrainAcc: 24.864% (12432/50000)
99 100 ValLoss: 1.595 | Val Acc: 41.020% (4102/10000)
Saving..
Epoch 3/200 | Val Loss: 1.5951 | Val Accuracy: 41.02%

Epoch: 4
249 250 TrainLoss: 1.930 | TrainAcc: 27.654% (13827/50000)
99 100 ValLoss: 1.544 | Val Acc: 41.630% (4163/10000)
Saving..
Epoch 4/200 | Val Loss: 1.5443 | Val Accuracy: 41.63%

Epoch: 5
249 250 TrainLoss: 1.865 | TrainAcc: 30.486% (15243/50000)
99 100 ValLoss: 1.439 | Val Acc: 47.

### Save the Training and Testing Metrics Locally

In [54]:
import pandas as pd

# One variable for the output path, so the confirmation message always names
# the file that was actually written.
metrics_csv_path = "swin_cifar10_2211_adam_metrics.csv"

# One row per epoch, built from the lists returned by main()
df_metrics = pd.DataFrame({
    'Epoch': list(range(1, len(list_acc) + 1)),
    'Train_Loss': list_train_loss,
    'Train_Acc': list_train_acc,
    'Test_Loss': list_loss,
    'Test_Acc': list_acc
})

# Save to CSV
df_metrics.to_csv(metrics_csv_path, index=False)

print(f"Metrics saved to '{metrics_csv_path}'")

Metrics saved to 'swin_cifar10_metrics.csv'


### Load Best Checkpoint and Visualize Predictions on Sample Test Images

In [None]:
# Instantiate the model with the same arguments used for training
from swin_transformer_pytorch import swin_t

net = swin_t(window_size=int(args.patch),
             num_classes=int(args.dataset_classes),
             downscaling_factors=(2, 2, 2, 1)).to(device)

# Load the best checkpoint. The path mirrors the one test() writes
# (./checkpoint/<net>-<patch>-ckpt.t7), so it follows the parsed args.
checkpoint_path = './checkpoint/{}-{}-ckpt.t7'.format(args.net, args.patch)
checkpoint = torch.load(checkpoint_path, map_location=device)
net.load_state_dict(checkpoint['model'])
net.eval()

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

# Show 5 test samples with predictions
import matplotlib.pyplot as plt
dataiter = iter(testloader)
images, labels = next(dataiter)

images, labels = images[:5], labels[:5]
# Inference only: skip autograd bookkeeping, and use `device` rather than
# .cuda() so the cell also runs on a CPU-only machine.
with torch.no_grad():
    outputs = net(images.to(device))
_, preds = torch.max(outputs, 1)
preds = preds.cpu()

# Unnormalize and display
# NOTE(review): presumably the same mean/std that Utils.get_loaders_CIFAR10
# applies when normalising — confirm against Utils.
mean = torch.tensor((0.4914, 0.4822, 0.4465)).view(3,1,1)
std = torch.tensor((0.2023, 0.1994, 0.2010)).view(3,1,1)

fig, axes = plt.subplots(1, 5, figsize=(15,3))
for idx in range(5):
    img = images[idx].cpu() * std + mean
    img = img.permute(1,2,0).clamp(0,1).numpy()
    axes[idx].imshow(img)
    # .item() converts the 0-d tensors to plain ints before indexing the tuple
    axes[idx].set_title(f"Pred: {classes[preds[idx].item()]}\nTrue: {classes[labels[idx].item()]}")
    axes[idx].axis('off')
plt.tight_layout()
plt.show()

  checkpoint = torch.load('./checkpoint/swin-4-ckpt.t7', map_location=device)  # use correct patch size


### Visualize Training and Validation Accuracy and Loss Over Epochs

In [None]:
import matplotlib.pyplot as plt

# x-axis: 1-based epoch numbers, one per recorded validation accuracy
epochs = range(1, len(list_acc) + 1)

# Two side-by-side panels on one figure: accuracy (left) and loss (right)
fig, (ax_acc, ax_loss) = plt.subplots(1, 2, figsize=(10, 4))

# Accuracy panel
ax_acc.plot(epochs, list_train_acc, label='Train Accuracy', marker='o')
ax_acc.plot(epochs, list_acc, label='Validation Accuracy', marker='o')
ax_acc.set_xlabel('Epoch')
ax_acc.set_ylabel('Accuracy (%)')
ax_acc.set_title('Train vs Validation Accuracy')
ax_acc.legend()

# Loss panel
ax_loss.plot(epochs, list_train_loss, label='Train Loss', marker='o')
ax_loss.plot(epochs, list_loss, label='Validation Loss', marker='o')
ax_loss.set_xlabel('Epoch')
ax_loss.set_ylabel('Loss')
ax_loss.set_title('Train vs Validation Loss')
ax_loss.legend()

fig.tight_layout()
plt.show()
