In [1]:
%matplotlib inline
import torch
from torch.utils.data import DataLoader, Subset, default_collate
import torchvision.datasets as datasets
import torchvision.models as models
from torchvision.transforms import v2
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F

import timm
from timm.models.swin_transformer import SwinTransformer
import numpy as np
from torch.utils.tensorboard import SummaryWriter

import time
from tqdm import tqdm
import matplotlib.pyplot as plt
import os
import logging

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Define transformations

In [2]:
cutmix = v2.CutMix(num_classes=100)
random_seed = 42
np.random.seed(random_seed)
torch.manual_seed(random_seed)


def collate_fn_cutmix(batch):
    return cutmix(*default_collate(batch))


def prepare_datasets(train_transform, test_transform, cutmix_, batch_size):
    trainset = datasets.CIFAR100(root='./data', train=True, download=True, transform=train_transform)
    testset = datasets.CIFAR100(root='./data', train=False, download=True, transform=test_transform)
    
    testloader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=4)

    if cutmix_ == 0:
        trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=4)
        testloader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=4)
    elif cutmix_ == 1:
        trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=4, collate_fn=collate_fn_cutmix)
        testloader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=4, collate_fn=collate_fn_cutmix)

    return trainloader, testloader


# Define transformations
test_transform = v2.Compose([
    v2.PILToTensor(),
    v2.ToDtype(torch.float32, scale=True),  # to float32 in [0, 1]
    v2.Normalize(mean=[0.5071, 0.4867, 0.4408], std=[0.2675, 0.2565, 0.2761]),
])

train_transform_basic = v2.Compose([
    v2.PILToTensor(),
    v2.RandomResizedCrop(size=(32, 32), antialias=True),
    v2.RandomHorizontalFlip(p=0.5),
    v2.ToDtype(torch.float32, scale=True),  # to float32 in [0, 1]
    v2.Normalize(mean=[0.5071, 0.4867, 0.4408], std=[0.2675, 0.2565, 0.2761]),
])

## Training function

In [3]:
def setup_logger(log_dir):
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    logging.basicConfig(
        filename=os.path.join(log_dir, 'training.log'),
        level=logging.INFO,
        format='%(asctime)s %(levelname)s: %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',
        filemode='w'
    )
    logger = logging.getLogger('train_model_logger')
    return logger

def train_model(model, trainloader, testloader, optimizer, num_epochs, scheduler=None, warmup_epochs=10, log_dir='./logs'):
    writer = SummaryWriter(log_dir=log_dir)
    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
    save_dir = log_dir

    # Set up logging with timestamp
    logger = setup_logger(log_dir)

    model = model.to(device)
    train_losses = []
    val_losses = []
    val_accuracies = []
    best_val_acc = 0.0

    pbar_epochs = tqdm(total=num_epochs, desc='Overall Progress', position=0)
    
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        
        if epoch < warmup_epochs:
            lr_scale = min(1., float(epoch + 1) / warmup_epochs)
            for pg in optimizer.param_groups:
                pg['lr'] = lr_scale * pg['initial_lr']
        
        pbar_batches = tqdm(total=len(trainloader), desc=f'Epoch {epoch+1}/{num_epochs}', position=1, leave=False)
        
        for inputs, labels in trainloader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            
            loss.backward()
            
            torch.nn.utils.clip_grad_norm_(model.parameters(), 0.3)
            
            
            optimizer.step()

            # Track statistics
            running_loss += loss.item() * inputs.size(0)
            pbar_batches.set_postfix(loss=loss.item())
            pbar_batches.update()

        pbar_batches.close()
        epoch_train_loss = running_loss / len(trainloader.dataset)
        train_losses.append(epoch_train_loss)

        # Write to TensorBoard
        writer.add_scalar('train_loss', epoch_train_loss, epoch)
        
        if scheduler is not None:
            scheduler.step()
        
        # Validation loop
        model.eval()
        correct_val = 0
        total_val = 0
        running_val_loss = 0.0

        with torch.no_grad():
            for inputs, labels in testloader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, labels)

                running_val_loss += loss.item() * inputs.size(0)
                _, predicted = torch.max(outputs, 1)
                total_val += labels.size(0)
                if labels.ndim == 2:
                    labels = labels.argmax(dim=1)
                correct_val += (predicted == labels).sum().item()

        # Calculate average validation loss and accuracy
        epoch_val_loss = running_val_loss / len(testloader.dataset)
        epoch_val_acc = correct_val / total_val
        val_losses.append(epoch_val_loss)
        val_accuracies.append(epoch_val_acc)

        # Write to TensorBoard
        writer.add_scalar('validation_loss', epoch_val_loss, epoch)
        writer.add_scalar('validation_accuracy', epoch_val_acc, epoch)

        pbar_epochs.set_postfix(
                train_loss=f"{epoch_train_loss:.4f}", 
                val_loss=f"{epoch_val_loss:.4f}", 
                val_acc=f"{100*epoch_val_acc:.2f}%"
            )
        pbar_epochs.update()
        
        if epoch_val_acc > best_val_acc:
            best_val_acc = epoch_val_acc
            if not os.path.exists(save_dir):
                os.makedirs(save_dir)
            print("Saving...")
            torch.save(model.state_dict(), os.path.join(save_dir, f'best_model_ckpt.pth'))
            logger.info(f"Epoch [{epoch+1}/{num_epochs}], New best validation accuracy: {100*epoch_val_acc:.2f}% - Model saved.")

        # Logging epoch information
        logger.info(f"Epoch [{epoch+1}/{num_epochs}], Training Loss: {epoch_train_loss:.4f}, Validation Loss: {epoch_val_loss:.4f}, Validation Accuracy: {100*epoch_val_acc:.2f}%")

    pbar_epochs.close()
    writer.close()
    print('Training complete!')
    print(f'Best Validation Accuracy: {best_val_acc:.4f}')
    logger.info(f'Training complete! Best Validation Accuracy: {100*best_val_acc:.2f}%')

    return train_losses, val_losses, val_accuracies


## Model definition

In [4]:
class ResNet50_CIFAR100(torch.nn.Module):
    def __init__(self):
        super(ResNet50_CIFAR100, self).__init__()
        self.model = models.resnet50(weights=None, num_classes=100)
        self.model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)

    def forward(self, x):
        return self.model(x)
    
class SwinViT_T_CIFAR100(nn.Module):
    def __init__(self, num_classes=100):
        super(SwinViT_T_CIFAR100, self).__init__()
        self.swin = SwinTransformer(
            img_size=32,
            patch_size=2,
            num_classes=num_classes,
            embed_dim=96,
            depths=(2, 2, 6, 2),
            num_heads=(3, 6, 12, 24),
            window_size=4,
        )

    def forward(self, x):
        x = self.swin(x)
        return x

In [32]:
def count_parameters(model: nn.Module) -> int:
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

In [33]:
count_parameters(SwinViT_T_CIFAR100())

27574318

In [28]:
count_parameters(ResNet50_CIFAR100())

23705252

## Try different configurations

In [21]:
model_name = "ResNet50_CIFAR100"
num_epochs = 20
b_lr_wd = [
    (128, 0.005, 1e-5),
    (128, 0.005, 1e-4),
    (128, 0.005, 1e-3),
    (128, 0.001, 1e-3),
    (128, 0.001, 1e-4),
    (128, 0.001, 1e-5),
    (256, 0.005, 1e-5),
    (256, 0.005, 1e-4),
    (256, 0.005, 1e-3),
    (256, 0.001, 1e-5),
    (256, 0.001, 1e-4),
    (256, 0.001, 1e-3),
]
optimizer_name = 'SGD'

for batch_size, lr, weight_decay in b_lr_wd:
    trainloader_cutmix, testloader_cutmix = prepare_datasets(train_transform_basic, test_transform, cutmix_=1, batch_size=batch_size)
    model = ResNet50_CIFAR100()
    model = model.to(device)
    print(f" ======> Training {model_name}: {optimizer_name}-{batch_size}_lr={lr}_weightdecay={weight_decay}")
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=weight_decay)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=200, gamma=0.97)
    log_dir = f'./official_models_tuning/{model_name}_cutmix/{optimizer_name}-{batch_size}_lr{lr}_wd-{weight_decay}_e{num_epochs}'
    train_model(model, trainloader_cutmix, testloader_cutmix, optimizer, num_epochs, scheduler=scheduler, log_dir=log_dir)

Files already downloaded and verified
Files already downloaded and verified


Overall Progress: 100%|██████████| 20/20 [07:19<00:00, 21.96s/it, train_loss=3.8248, val_acc=16.56%, val_loss=3.8279]


Training complete!
Best Validation Accuracy: 0.1786
Files already downloaded and verified
Files already downloaded and verified


Overall Progress: 100%|██████████| 20/20 [07:20<00:00, 22.04s/it, train_loss=3.8188, val_acc=18.84%, val_loss=3.7020]


Training complete!
Best Validation Accuracy: 0.1884
Files already downloaded and verified
Files already downloaded and verified


Overall Progress: 100%|██████████| 20/20 [07:19<00:00, 21.98s/it, train_loss=3.7795, val_acc=17.72%, val_loss=3.7301]


Training complete!
Best Validation Accuracy: 0.1862
Files already downloaded and verified
Files already downloaded and verified


Overall Progress: 100%|██████████| 20/20 [07:19<00:00, 21.97s/it, train_loss=4.0773, val_acc=12.35%, val_loss=4.0010]


Training complete!
Best Validation Accuracy: 0.1235
Files already downloaded and verified
Files already downloaded and verified


Overall Progress: 100%|██████████| 20/20 [07:18<00:00, 21.95s/it, train_loss=4.0899, val_acc=11.63%, val_loss=4.0144]


Training complete!
Best Validation Accuracy: 0.1163
Files already downloaded and verified
Files already downloaded and verified


Overall Progress: 100%|██████████| 20/20 [07:18<00:00, 21.95s/it, train_loss=4.0970, val_acc=11.34%, val_loss=4.0602]


Training complete!
Best Validation Accuracy: 0.1134
Files already downloaded and verified
Files already downloaded and verified


Overall Progress: 100%|██████████| 20/20 [06:32<00:00, 19.60s/it, train_loss=3.9509, val_acc=15.68%, val_loss=3.7996]


Training complete!
Best Validation Accuracy: 0.1568
Files already downloaded and verified
Files already downloaded and verified


Overall Progress: 100%|██████████| 20/20 [06:31<00:00, 19.60s/it, train_loss=3.9368, val_acc=15.57%, val_loss=3.8697]


Training complete!
Best Validation Accuracy: 0.1590
Files already downloaded and verified
Files already downloaded and verified


Overall Progress: 100%|██████████| 20/20 [06:32<00:00, 19.61s/it, train_loss=3.9222, val_acc=14.05%, val_loss=3.9381]


Training complete!
Best Validation Accuracy: 0.1452
Files already downloaded and verified
Files already downloaded and verified


Overall Progress: 100%|██████████| 20/20 [06:32<00:00, 19.63s/it, train_loss=4.1786, val_acc=8.91%, val_loss=4.1399]


Training complete!
Best Validation Accuracy: 0.0898
Files already downloaded and verified
Files already downloaded and verified


Overall Progress: 100%|██████████| 20/20 [06:32<00:00, 19.61s/it, train_loss=4.1877, val_acc=9.62%, val_loss=4.1134]


Training complete!
Best Validation Accuracy: 0.0962
Files already downloaded and verified
Files already downloaded and verified


Overall Progress: 100%|██████████| 20/20 [06:32<00:00, 19.60s/it, train_loss=4.1708, val_acc=9.63%, val_loss=4.1282]

Training complete!
Best Validation Accuracy: 0.0963





In [22]:
model_name = "SwinViT_T_CIFAR100"
num_epochs = 20
b_lr_wd = [
    (128, 0.005, 1e-5),
    (128, 0.005, 1e-4),
    (128, 0.005, 1e-3),
    (128, 0.001, 1e-3),
    (128, 0.001, 1e-4),
    (128, 0.001, 1e-5),
    (256, 0.005, 1e-5),
    (256, 0.005, 1e-4),
    (256, 0.005, 1e-3),
    (256, 0.001, 1e-5),
    (256, 0.001, 1e-4),
    (256, 0.001, 1e-3),
]
optimizer_name = 'Adam'

for batch_size, lr, weight_decay in b_lr_wd:
    trainloader_cutmix, testloader_cutmix = prepare_datasets(train_transform_basic, test_transform, cutmix_=1, batch_size=batch_size)
    model = SwinViT_T_CIFAR100()
    model = model.to(device)
    print(f" ======> Training {model_name}: {optimizer_name}-{batch_size}_lr={lr}_weightdecay={weight_decay}")
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=200, gamma=0.97)
    log_dir = f'./official_models_tuning/{model_name}_cutmix/{optimizer_name}-{batch_size}_lr{lr}_wd-{weight_decay}_e{num_epochs}'
    train_model(model, trainloader_cutmix, testloader_cutmix, optimizer, num_epochs, scheduler=scheduler, log_dir=log_dir)

Files already downloaded and verified
Files already downloaded and verified


Overall Progress: 100%|██████████| 20/20 [06:33<00:00, 19.69s/it, train_loss=3.8504, val_acc=16.87%, val_loss=3.7502]


Training complete!
Best Validation Accuracy: 0.1704
Files already downloaded and verified
Files already downloaded and verified


Overall Progress: 100%|██████████| 20/20 [06:35<00:00, 19.78s/it, train_loss=4.0520, val_acc=11.80%, val_loss=3.9698]


Training complete!
Best Validation Accuracy: 0.1210
Files already downloaded and verified
Files already downloaded and verified


Overall Progress: 100%|██████████| 20/20 [06:36<00:00, 19.81s/it, train_loss=4.2236, val_acc=7.21%, val_loss=4.2237]


Training complete!
Best Validation Accuracy: 0.0791
Files already downloaded and verified
Files already downloaded and verified


Overall Progress: 100%|██████████| 20/20 [06:33<00:00, 19.68s/it, train_loss=4.0701, val_acc=10.72%, val_loss=4.0481]


Training complete!
Best Validation Accuracy: 0.1078
Files already downloaded and verified
Files already downloaded and verified


Overall Progress: 100%|██████████| 20/20 [06:35<00:00, 19.76s/it, train_loss=3.7398, val_acc=18.56%, val_loss=3.6752]


Training complete!
Best Validation Accuracy: 0.1929
Files already downloaded and verified
Files already downloaded and verified


Overall Progress: 100%|██████████| 20/20 [06:38<00:00, 19.93s/it, train_loss=3.6287, val_acc=22.92%, val_loss=3.5197]


Training complete!
Best Validation Accuracy: 0.2297
Files already downloaded and verified
Files already downloaded and verified


Overall Progress:   0%|          | 0/10 [2:23:57<?, ?it/s]
Overall Progress:  85%|████████▌ | 17/20 [2:13:14<23:30, 470.25s/it, train_loss=3.8968, val_acc=15.02%, val_loss=3.8624]
Overall Progress:  20%|██        | 4/20 [2:05:22<8:21:28, 1880.56s/it, train_loss=4.2075, val_acc=8.56%, val_loss=4.1389]




Overall Progress: 100%|██████████| 20/20 [04:17<00:00, 12.86s/it, train_loss=3.8361, val_acc=16.00%, val_loss=3.7768]


Training complete!
Best Validation Accuracy: 0.1669
Files already downloaded and verified
Files already downloaded and verified


Overall Progress: 100%|██████████| 20/20 [04:18<00:00, 12.93s/it, train_loss=4.0280, val_acc=12.37%, val_loss=3.9655]


Training complete!
Best Validation Accuracy: 0.1237
Files already downloaded and verified
Files already downloaded and verified


Overall Progress: 100%|██████████| 20/20 [04:17<00:00, 12.88s/it, train_loss=4.1803, val_acc=8.60%, val_loss=4.1455]


Training complete!
Best Validation Accuracy: 0.0860
Files already downloaded and verified
Files already downloaded and verified


Overall Progress: 100%|██████████| 20/20 [04:17<00:00, 12.89s/it, train_loss=3.5905, val_acc=22.19%, val_loss=3.5735]


Training complete!
Best Validation Accuracy: 0.2298
Files already downloaded and verified
Files already downloaded and verified


Overall Progress: 100%|██████████| 20/20 [04:17<00:00, 12.87s/it, train_loss=3.7421, val_acc=20.27%, val_loss=3.6271]


Training complete!
Best Validation Accuracy: 0.2027
Files already downloaded and verified
Files already downloaded and verified


Overall Progress: 100%|██████████| 20/20 [04:17<00:00, 12.86s/it, train_loss=4.0224, val_acc=11.94%, val_loss=3.9782]

Training complete!
Best Validation Accuracy: 0.1194





change batch size

In [23]:
model_name = "ResNet50_CIFAR100"
num_epochs = 20
b_lr_wd = [
    (512, 0.005, 1e-4),
    (1024, 0.005, 1e-4),
]
optimizer_name = 'SGD'

for batch_size, lr, weight_decay in b_lr_wd:
    trainloader_cutmix, testloader_cutmix = prepare_datasets(train_transform_basic, test_transform, cutmix_=1, batch_size=batch_size)
    model = ResNet50_CIFAR100()
    model = model.to(device)
    print(f" ======> Training {model_name}: {optimizer_name}-{batch_size}_lr={lr}_weightdecay={weight_decay}")
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=weight_decay)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=200, gamma=0.97)
    log_dir = f'./official_models_tuning/{model_name}_cutmix/{optimizer_name}-{batch_size}_lr{lr}_wd-{weight_decay}_e{num_epochs}'
    train_model(model, trainloader_cutmix, testloader_cutmix, optimizer, num_epochs, scheduler=scheduler, log_dir=log_dir)

Files already downloaded and verified
Files already downloaded and verified


Overall Progress: 100%|██████████| 20/20 [06:07<00:00, 18.35s/it, train_loss=4.0131, val_acc=11.76%, val_loss=3.9988]


Training complete!
Best Validation Accuracy: 0.1363
Files already downloaded and verified
Files already downloaded and verified


Overall Progress: 100%|██████████| 20/20 [06:13<00:00, 18.65s/it, train_loss=4.1852, val_acc=11.35%, val_loss=4.0276]

Training complete!
Best Validation Accuracy: 0.1135





In [24]:
model_name = "SwinViT_T_CIFAR100"
num_epochs = 20
b_lr_wd = [
    (512, 0.001, 1e-5),
    (1024, 0.001, 1e-5),
]
optimizer_name = 'Adam'

for batch_size, lr, weight_decay in b_lr_wd:
    trainloader_cutmix, testloader_cutmix = prepare_datasets(train_transform_basic, test_transform, cutmix_=1, batch_size=batch_size)
    model = SwinViT_T_CIFAR100()
    model = model.to(device)
    print(f" ======> Training {model_name}: {optimizer_name}-{batch_size}_lr={lr}_weightdecay={weight_decay}")
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=200, gamma=0.97)
    log_dir = f'./official_models_tuning/{model_name}_cutmix/{optimizer_name}-{batch_size}_lr{lr}_wd-{weight_decay}_e{num_epochs}'
    train_model(model, trainloader_cutmix, testloader_cutmix, optimizer, num_epochs, scheduler=scheduler, log_dir=log_dir)

Files already downloaded and verified
Files already downloaded and verified


Overall Progress: 100%|██████████| 20/20 [03:51<00:00, 11.59s/it, train_loss=3.6287, val_acc=21.39%, val_loss=3.5375]


Training complete!
Best Validation Accuracy: 0.2140
Files already downloaded and verified
Files already downloaded and verified


  return F.conv2d(input, weight, bias, self.stride,
Overall Progress: 100%|██████████| 20/20 [03:47<00:00, 11.39s/it, train_loss=3.7537, val_acc=19.40%, val_loss=3.6858]

Training complete!
Best Validation Accuracy: 0.2082





Rest are some training records.

In [7]:
model_name = "SwinViT_T_CIFAR100"
num_epochs = 200
lr = 0.001
weight_decay = 1e-5
optimizer_name = 'AdamW'
batch_size = 128
# lr_schedules = ["StepLR", "CosineAnnealingLR", "ReduceLROnPlateau"]
lr_schedules = ["CosineAnnealingLR"]


for lr_schedule in lr_schedules:
    trainloader_cutmix, testloader_cutmix = prepare_datasets(train_transform_basic, test_transform, cutmix_=1, batch_size=batch_size)
    model = SwinViT_T_CIFAR100()
    model = model.to(device)
    print(f" ======> Training {model_name}: {optimizer_name}-{batch_size}_lr={lr}_{lr_schedule}_weightdecay={weight_decay}")
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    # scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=200, gamma=0.97)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=len(trainloader_cutmix), eta_min=1e-6)
    log_dir = f'./official_models_tuning/{model_name}_cutmix/{optimizer_name}-{batch_size}_lr{lr}_{lr_schedule}_wd-{weight_decay}_e{num_epochs}'
    train_model(model, trainloader_cutmix, testloader_cutmix, optimizer, num_epochs, scheduler=scheduler, log_dir=log_dir)

Files already downloaded and verified
Files already downloaded and verified


Overall Progress: 100%|██████████| 200/200 [1:22:46<00:00, 24.83s/it, train_loss=2.2727, val_acc=46.32%, val_loss=2.6705]

Training complete!
Best Validation Accuracy: 0.4634





In [19]:
model_name = "SwinViT_T_CIFAR100"
num_epochs = 300
lr = 0.001
weight_decay = 1e-5
optimizer_name = 'AdamW'
batch_size = 128
# lr_schedules = ["StepLR", "CosineAnnealingLR", "ReduceLROnPlateau"]
lr_schedules = ["CosineAnnealingLR"]


for lr_schedule in lr_schedules:
    trainloader_cutmix, testloader_cutmix = prepare_datasets(train_transform_basic, test_transform, cutmix_=1, batch_size=batch_size)
    model = SwinViT_T_CIFAR100()
    model = model.to(device)
    print(f" ======> Training {model_name}: {optimizer_name}-{batch_size}_lr={lr}_{lr_schedule}_weightdecay={weight_decay}")
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    # scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=200, gamma=0.97)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=len(trainloader_cutmix), eta_min=1e-6)
    log_dir = f'./official_models_tuning/{model_name}_cutmix/{optimizer_name}-{batch_size}_lr{lr}_{lr_schedule}_wd-{weight_decay}_e{num_epochs}'
    train_model(model, trainloader_cutmix, testloader_cutmix, optimizer, num_epochs, scheduler=scheduler, log_dir=log_dir)

Files already downloaded and verified
Files already downloaded and verified


Overall Progress:   0%|          | 1/300 [00:25<2:05:49, 25.25s/it, train_loss=4.3955, val_acc=8.14%, val_loss=4.2611]

Saving...


Overall Progress:   1%|          | 2/300 [00:50<2:05:30, 25.27s/it, train_loss=4.3098, val_acc=8.15%, val_loss=4.2415]

Saving...


Overall Progress:   1%|          | 3/300 [01:16<2:05:37, 25.38s/it, train_loss=4.2745, val_acc=9.57%, val_loss=4.1890]

Saving...


Overall Progress:   1%|▏         | 4/300 [01:41<2:05:06, 25.36s/it, train_loss=4.2406, val_acc=10.67%, val_loss=4.1441]

Saving...


Overall Progress:   2%|▏         | 5/300 [02:06<2:04:44, 25.37s/it, train_loss=4.1932, val_acc=11.24%, val_loss=4.1513]

Saving...


Overall Progress:   2%|▏         | 6/300 [02:31<2:03:51, 25.28s/it, train_loss=4.1657, val_acc=11.26%, val_loss=4.1142]

Saving...


Overall Progress:   3%|▎         | 8/300 [03:21<2:02:19, 25.14s/it, train_loss=4.1364, val_acc=13.49%, val_loss=4.0361]

Saving...


Overall Progress:   3%|▎         | 9/300 [03:47<2:02:00, 25.16s/it, train_loss=4.1199, val_acc=13.84%, val_loss=4.0555]

Saving...


Overall Progress:   5%|▍         | 14/300 [05:51<1:58:46, 24.92s/it, train_loss=4.0535, val_acc=14.40%, val_loss=4.0139]

Saving...


Overall Progress:   5%|▌         | 15/300 [06:16<1:58:46, 25.00s/it, train_loss=4.0447, val_acc=16.25%, val_loss=3.9172]

Saving...


Overall Progress:   6%|▌         | 18/300 [07:32<1:58:10, 25.14s/it, train_loss=3.9740, val_acc=16.65%, val_loss=3.9128]

Saving...


Overall Progress:   6%|▋         | 19/300 [07:57<1:57:44, 25.14s/it, train_loss=3.9931, val_acc=18.14%, val_loss=3.8971]

Saving...


Overall Progress:   7%|▋         | 20/300 [08:23<1:57:53, 25.26s/it, train_loss=3.9591, val_acc=18.87%, val_loss=3.8464]

Saving...


Overall Progress:   8%|▊         | 23/300 [09:38<1:55:58, 25.12s/it, train_loss=3.9042, val_acc=19.63%, val_loss=3.8312]

Saving...


Overall Progress:   8%|▊         | 25/300 [10:28<1:54:59, 25.09s/it, train_loss=3.8801, val_acc=20.66%, val_loss=3.7770]

Saving...


Overall Progress:   9%|▊         | 26/300 [10:54<1:54:51, 25.15s/it, train_loss=3.8631, val_acc=20.85%, val_loss=3.7383]

Saving...


Overall Progress:  10%|▉         | 29/300 [12:09<1:52:49, 24.98s/it, train_loss=3.8482, val_acc=21.34%, val_loss=3.7752]

Saving...


Overall Progress:  10%|█         | 31/300 [12:59<1:52:56, 25.19s/it, train_loss=3.8237, val_acc=21.72%, val_loss=3.7439]

Saving...


Overall Progress:  11%|█         | 32/300 [13:24<1:52:20, 25.15s/it, train_loss=3.8144, val_acc=21.84%, val_loss=3.7481]

Saving...


Overall Progress:  11%|█         | 33/300 [13:50<1:51:57, 25.16s/it, train_loss=3.7868, val_acc=22.33%, val_loss=3.7283]

Saving...


Overall Progress:  11%|█▏        | 34/300 [14:15<1:51:45, 25.21s/it, train_loss=3.7605, val_acc=23.48%, val_loss=3.6693]

Saving...


Overall Progress:  13%|█▎        | 38/300 [15:56<1:50:11, 25.24s/it, train_loss=3.7549, val_acc=24.57%, val_loss=3.6484]

Saving...


Overall Progress:  13%|█▎        | 39/300 [16:21<1:49:55, 25.27s/it, train_loss=3.7250, val_acc=25.90%, val_loss=3.6180]

Saving...


Overall Progress:  16%|█▌        | 47/300 [19:41<1:45:28, 25.01s/it, train_loss=3.6397, val_acc=27.05%, val_loss=3.5882]

Saving...


Overall Progress:  17%|█▋        | 50/300 [20:56<1:44:14, 25.02s/it, train_loss=3.6520, val_acc=27.26%, val_loss=3.5928]

Saving...


Overall Progress:  17%|█▋        | 51/300 [21:22<1:44:11, 25.10s/it, train_loss=3.6273, val_acc=27.68%, val_loss=3.5739]

Saving...


Overall Progress:  18%|█▊        | 53/300 [22:12<1:43:14, 25.08s/it, train_loss=3.6156, val_acc=29.93%, val_loss=3.4950]

Saving...


Overall Progress:  19%|█▉        | 57/300 [23:52<1:41:25, 25.04s/it, train_loss=3.6127, val_acc=30.26%, val_loss=3.4817]

Saving...


Overall Progress:  19%|█▉        | 58/300 [24:18<1:41:36, 25.19s/it, train_loss=3.5496, val_acc=30.81%, val_loss=3.4527]

Saving...


Overall Progress:  21%|██        | 62/300 [25:58<1:39:23, 25.06s/it, train_loss=3.5448, val_acc=31.50%, val_loss=3.4059]

Saving...


Overall Progress:  22%|██▏       | 66/300 [27:38<1:37:49, 25.08s/it, train_loss=3.5160, val_acc=33.25%, val_loss=3.3716]

Saving...


Overall Progress:  24%|██▍       | 73/300 [30:33<1:34:11, 24.89s/it, train_loss=3.4564, val_acc=35.39%, val_loss=3.2920]

Saving...


Overall Progress:  28%|██▊       | 84/300 [35:09<1:30:11, 25.05s/it, train_loss=3.3638, val_acc=36.07%, val_loss=3.3180]

Saving...


Overall Progress:  31%|███▏      | 94/300 [39:19<1:25:54, 25.02s/it, train_loss=3.2692, val_acc=36.21%, val_loss=3.3106]

Saving...


Overall Progress:  33%|███▎      | 98/300 [40:59<1:24:06, 24.98s/it, train_loss=3.2851, val_acc=37.35%, val_loss=3.2452]

Saving...


Overall Progress:  34%|███▍      | 102/300 [42:40<1:22:54, 25.12s/it, train_loss=3.2347, val_acc=38.82%, val_loss=3.2535]

Saving...


Overall Progress:  36%|███▌      | 108/300 [45:10<1:20:01, 25.01s/it, train_loss=3.2224, val_acc=38.91%, val_loss=3.2355]

Saving...


Overall Progress:  36%|███▋      | 109/300 [45:35<1:19:50, 25.08s/it, train_loss=3.1989, val_acc=39.07%, val_loss=3.1777]

Saving...


Overall Progress:  40%|███▉      | 119/300 [49:47<1:15:28, 25.02s/it, train_loss=3.1455, val_acc=40.74%, val_loss=3.1632]

Saving...


Overall Progress:  42%|████▏     | 127/300 [53:08<1:12:40, 25.20s/it, train_loss=3.1173, val_acc=41.32%, val_loss=3.1635]

Saving...


Overall Progress:  44%|████▍     | 133/300 [55:39<1:09:56, 25.13s/it, train_loss=3.0970, val_acc=41.76%, val_loss=3.1224]

Saving...


Overall Progress:  49%|████▊     | 146/300 [1:01:04<1:04:18, 25.05s/it, train_loss=3.0134, val_acc=41.88%, val_loss=3.1427]

Saving...


Overall Progress:  50%|█████     | 150/300 [1:02:44<1:02:32, 25.01s/it, train_loss=2.9664, val_acc=43.82%, val_loss=3.1126]

Saving...


Overall Progress:  52%|█████▏    | 156/300 [1:05:15<1:00:25, 25.18s/it, train_loss=2.9614, val_acc=43.95%, val_loss=3.0829]

Saving...


Overall Progress:  59%|█████▊    | 176/300 [1:13:37<51:54, 25.11s/it, train_loss=2.8639, val_acc=44.44%, val_loss=3.1006]  

Saving...


Overall Progress:  63%|██████▎   | 190/300 [1:19:28<45:48, 24.99s/it, train_loss=2.8336, val_acc=45.42%, val_loss=3.0279]

Saving...


Overall Progress:  64%|██████▎   | 191/300 [1:19:54<45:43, 25.17s/it, train_loss=2.8180, val_acc=46.12%, val_loss=2.9918]

Saving...


Overall Progress:  83%|████████▎ | 250/300 [1:44:33<20:49, 24.98s/it, train_loss=2.5838, val_acc=46.64%, val_loss=2.9899]

Saving...


Overall Progress: 100%|█████████▉| 299/300 [2:05:01<00:25, 25.19s/it, train_loss=2.5089, val_acc=47.39%, val_loss=3.0177]

Saving...


Overall Progress: 100%|██████████| 300/300 [2:05:27<00:00, 25.09s/it, train_loss=2.5881, val_acc=46.26%, val_loss=3.0105]

Training complete!
Best Validation Accuracy: 0.4739





In [5]:
model_name = "SwinViT_T_CIFAR100"
num_epochs = 300
lr = 0.001
weight_decay = 1e-5
optimizer_name = 'AdamW'
batch_size = 128
# lr_schedules = ["StepLR", "CosineAnnealingLR", "ReduceLROnPlateau"]
lr_schedules = ["CosineAnnealingLR"]


for lr_schedule in lr_schedules:
    trainloader_basic, testloader_basic = prepare_datasets(train_transform_basic, test_transform, cutmix_=0, batch_size=batch_size)
    model = SwinViT_T_CIFAR100()
    model = model.to(device)
    print(f" ======> Training {model_name}: {optimizer_name}-{batch_size}_lr={lr}_{lr_schedule}_weightdecay={weight_decay}")
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    # scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=200, gamma=0.97)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=len(trainloader_basic), eta_min=1e-6)
    log_dir = f'./official_models_tuning/{model_name}_basic/{optimizer_name}-{batch_size}_lr{lr}_{lr_schedule}_wd-{weight_decay}_e{num_epochs}'
    train_model(model, trainloader_basic, testloader_basic, optimizer, num_epochs, scheduler=scheduler, log_dir=log_dir)

Files already downloaded and verified
Files already downloaded and verified


Overall Progress:   0%|          | 1/300 [00:26<2:12:33, 26.60s/it, train_loss=4.1455, val_acc=13.27%, val_loss=3.8608]

Saving...


Overall Progress:   1%|          | 2/300 [00:51<2:06:48, 25.53s/it, train_loss=3.8934, val_acc=16.74%, val_loss=3.7033]

Saving...


Overall Progress:   1%|          | 3/300 [01:16<2:05:14, 25.30s/it, train_loss=3.7731, val_acc=19.08%, val_loss=3.6190]

Saving...


Overall Progress:   1%|▏         | 4/300 [01:41<2:04:01, 25.14s/it, train_loss=3.6738, val_acc=21.85%, val_loss=3.4871]

Saving...


Overall Progress:   2%|▏         | 5/300 [02:06<2:04:01, 25.23s/it, train_loss=3.6060, val_acc=22.08%, val_loss=3.4964]

Saving...


Overall Progress:   2%|▏         | 6/300 [02:32<2:04:05, 25.33s/it, train_loss=3.5548, val_acc=23.10%, val_loss=3.4041]

Saving...


Overall Progress:   2%|▏         | 7/300 [02:57<2:03:11, 25.23s/it, train_loss=3.5094, val_acc=23.81%, val_loss=3.3951]

Saving...


Overall Progress:   3%|▎         | 8/300 [03:22<2:02:29, 25.17s/it, train_loss=3.4693, val_acc=26.26%, val_loss=3.2814]

Saving...


Overall Progress:   3%|▎         | 10/300 [04:12<2:01:18, 25.10s/it, train_loss=3.4213, val_acc=26.65%, val_loss=3.2845]

Saving...


Overall Progress:   4%|▎         | 11/300 [04:37<2:00:57, 25.11s/it, train_loss=3.3739, val_acc=26.86%, val_loss=3.2760]

Saving...


Overall Progress:   4%|▍         | 12/300 [05:02<2:00:28, 25.10s/it, train_loss=3.3264, val_acc=28.92%, val_loss=3.1869]

Saving...


Overall Progress:   5%|▍         | 14/300 [05:52<1:59:06, 24.99s/it, train_loss=3.2724, val_acc=29.65%, val_loss=3.1846]

Saving...


Overall Progress:   5%|▌         | 15/300 [06:17<1:58:27, 24.94s/it, train_loss=3.2367, val_acc=30.46%, val_loss=3.1336]

Saving...


Overall Progress:   5%|▌         | 16/300 [06:42<1:58:32, 25.05s/it, train_loss=3.2000, val_acc=30.66%, val_loss=3.1107]

Saving...


Overall Progress:   6%|▌         | 17/300 [07:07<1:58:01, 25.02s/it, train_loss=3.1793, val_acc=31.45%, val_loss=3.1076]

Saving...


Overall Progress:   6%|▌         | 18/300 [07:32<1:57:20, 24.97s/it, train_loss=3.1670, val_acc=32.52%, val_loss=3.0447]

Saving...


Overall Progress:   6%|▋         | 19/300 [07:57<1:57:00, 24.99s/it, train_loss=3.1473, val_acc=33.10%, val_loss=3.0453]

Saving...


Overall Progress:   7%|▋         | 20/300 [08:22<1:57:03, 25.08s/it, train_loss=3.1154, val_acc=34.08%, val_loss=2.9797]

Saving...


Overall Progress:   7%|▋         | 21/300 [08:47<1:56:44, 25.10s/it, train_loss=3.0849, val_acc=34.42%, val_loss=2.9847]

Saving...


Overall Progress:   7%|▋         | 22/300 [09:12<1:55:51, 25.00s/it, train_loss=3.0598, val_acc=34.67%, val_loss=2.9787]

Saving...


Overall Progress:   8%|▊         | 23/300 [09:37<1:55:13, 24.96s/it, train_loss=3.0275, val_acc=35.77%, val_loss=2.9509]

Saving...


Overall Progress:   8%|▊         | 24/300 [10:02<1:54:49, 24.96s/it, train_loss=3.0199, val_acc=36.03%, val_loss=2.9254]

Saving...


Overall Progress:   8%|▊         | 25/300 [10:27<1:54:24, 24.96s/it, train_loss=2.9910, val_acc=36.80%, val_loss=2.8925]

Saving...


Overall Progress:   9%|▉         | 28/300 [11:42<1:53:43, 25.09s/it, train_loss=2.9410, val_acc=37.39%, val_loss=2.8610]

Saving...


Overall Progress:  10%|▉         | 29/300 [12:08<1:53:33, 25.14s/it, train_loss=2.9229, val_acc=39.34%, val_loss=2.7912]

Saving...


Overall Progress:  11%|█         | 33/300 [13:47<1:50:27, 24.82s/it, train_loss=2.8315, val_acc=40.09%, val_loss=2.7863]

Saving...


Overall Progress:  12%|█▏        | 36/300 [15:02<1:49:36, 24.91s/it, train_loss=2.7756, val_acc=40.21%, val_loss=2.7953]

Saving...


Overall Progress:  12%|█▏        | 37/300 [15:27<1:49:17, 24.94s/it, train_loss=2.7502, val_acc=41.33%, val_loss=2.7498]

Saving...


Overall Progress:  13%|█▎        | 39/300 [16:17<1:49:04, 25.08s/it, train_loss=2.7146, val_acc=41.91%, val_loss=2.7442]

Saving...


Overall Progress:  13%|█▎        | 40/300 [16:42<1:48:45, 25.10s/it, train_loss=2.7117, val_acc=42.55%, val_loss=2.7188]

Saving...


Overall Progress:  14%|█▍        | 42/300 [17:32<1:47:15, 24.94s/it, train_loss=2.6506, val_acc=42.60%, val_loss=2.6872]

Saving...


Overall Progress:  14%|█▍        | 43/300 [17:57<1:47:06, 25.01s/it, train_loss=2.6271, val_acc=43.73%, val_loss=2.6700]

Saving...


Overall Progress:  16%|█▌        | 48/300 [20:01<1:43:41, 24.69s/it, train_loss=2.5468, val_acc=44.55%, val_loss=2.6617]

Saving...


Overall Progress:  18%|█▊        | 54/300 [22:30<1:41:34, 24.77s/it, train_loss=2.4454, val_acc=44.86%, val_loss=2.6707]

Saving...


Overall Progress:  18%|█▊        | 55/300 [22:55<1:41:24, 24.84s/it, train_loss=2.4257, val_acc=45.48%, val_loss=2.6529]

Saving...


Overall Progress:  19%|█▉        | 57/300 [23:45<1:40:44, 24.87s/it, train_loss=2.3929, val_acc=45.74%, val_loss=2.6777]

Saving...


Overall Progress:  21%|██▏       | 64/300 [26:40<1:38:26, 25.03s/it, train_loss=2.2897, val_acc=46.53%, val_loss=2.6518]

Saving...


Overall Progress:  23%|██▎       | 69/300 [28:45<1:36:06, 24.97s/it, train_loss=2.2288, val_acc=46.56%, val_loss=2.6533]

Saving...


Overall Progress:  26%|██▋       | 79/300 [32:52<1:30:48, 24.66s/it, train_loss=2.0925, val_acc=47.01%, val_loss=2.6972]

Saving...


Overall Progress:  28%|██▊       | 84/300 [34:56<1:29:06, 24.75s/it, train_loss=2.0348, val_acc=47.79%, val_loss=2.6718]

Saving...


Overall Progress:  33%|███▎      | 99/300 [41:10<1:23:19, 24.87s/it, train_loss=1.8838, val_acc=48.57%, val_loss=2.7274]

Saving...


Overall Progress:  35%|███▍      | 104/300 [43:15<1:21:12, 24.86s/it, train_loss=1.8470, val_acc=48.71%, val_loss=2.7065]

Saving...


Overall Progress:  39%|███▊      | 116/300 [48:14<1:16:17, 24.88s/it, train_loss=1.7556, val_acc=48.95%, val_loss=2.7105]

Saving...


Overall Progress:  41%|████      | 122/300 [50:42<1:13:02, 24.62s/it, train_loss=1.7123, val_acc=48.96%, val_loss=2.7190]

Saving...


Overall Progress:  41%|████▏     | 124/300 [51:32<1:12:19, 24.65s/it, train_loss=1.7097, val_acc=49.46%, val_loss=2.6997]

Saving...


Overall Progress:  45%|████▌     | 135/300 [56:04<1:08:03, 24.75s/it, train_loss=1.6520, val_acc=49.56%, val_loss=2.7191]

Saving...


Overall Progress:  46%|████▌     | 138/300 [57:20<1:07:25, 24.97s/it, train_loss=1.6349, val_acc=50.34%, val_loss=2.6781]

Saving...


Overall Progress:  48%|████▊     | 144/300 [59:50<1:05:04, 25.03s/it, train_loss=1.6148, val_acc=50.47%, val_loss=2.6809]

Saving...


Overall Progress:  53%|█████▎    | 160/300 [1:06:34<58:45, 25.18s/it, train_loss=1.5388, val_acc=50.57%, val_loss=2.6565]  

Saving...


Overall Progress:  55%|█████▍    | 164/300 [1:08:15<57:08, 25.21s/it, train_loss=1.5340, val_acc=50.61%, val_loss=2.6482]

Saving...


Overall Progress:  56%|█████▋    | 169/300 [1:10:21<54:55, 25.16s/it, train_loss=1.5055, val_acc=50.94%, val_loss=2.6408]

Saving...


Overall Progress:  57%|█████▋    | 170/300 [1:10:46<54:34, 25.19s/it, train_loss=1.5066, val_acc=51.25%, val_loss=2.6145]

Saving...


Overall Progress:  65%|██████▍   | 194/300 [1:20:44<43:44, 24.76s/it, train_loss=1.4350, val_acc=51.40%, val_loss=2.6157]

Saving...


Overall Progress:  66%|██████▋   | 199/300 [1:22:48<41:39, 24.75s/it, train_loss=1.4177, val_acc=51.53%, val_loss=2.6009]

Saving...


Overall Progress:  67%|██████▋   | 202/300 [1:24:03<40:52, 25.03s/it, train_loss=1.4051, val_acc=51.66%, val_loss=2.5892]

Saving...


Overall Progress:  70%|███████   | 210/300 [1:27:23<37:20, 24.90s/it, train_loss=1.3800, val_acc=51.67%, val_loss=2.5931]

Saving...


Overall Progress:  71%|███████▏  | 214/300 [1:29:02<35:35, 24.83s/it, train_loss=1.3805, val_acc=51.76%, val_loss=2.5828]

Saving...


Overall Progress:  72%|███████▏  | 216/300 [1:29:52<34:59, 24.99s/it, train_loss=1.3747, val_acc=51.83%, val_loss=2.5796]

Saving...


Overall Progress:  73%|███████▎  | 218/300 [1:30:42<34:08, 24.98s/it, train_loss=1.3714, val_acc=52.06%, val_loss=2.5563]

Saving...


Overall Progress:  73%|███████▎  | 220/300 [1:31:32<33:14, 24.94s/it, train_loss=1.3705, val_acc=52.32%, val_loss=2.5731]

Saving...


Overall Progress:  77%|███████▋  | 232/300 [1:36:30<28:06, 24.80s/it, train_loss=1.3405, val_acc=52.44%, val_loss=2.5579]

Saving...


Overall Progress:  78%|███████▊  | 235/300 [1:37:46<27:08, 25.05s/it, train_loss=1.3286, val_acc=52.51%, val_loss=2.5599]

Saving...


Overall Progress:  79%|███████▊  | 236/300 [1:38:11<26:48, 25.14s/it, train_loss=1.3239, val_acc=52.53%, val_loss=2.5610]

Saving...


Overall Progress:  79%|███████▉  | 238/300 [1:39:00<25:46, 24.94s/it, train_loss=1.3221, val_acc=52.55%, val_loss=2.5515]

Saving...


Overall Progress:  81%|████████  | 243/300 [1:41:05<23:39, 24.90s/it, train_loss=1.3028, val_acc=52.69%, val_loss=2.5462]

Saving...


Overall Progress:  82%|████████▏ | 247/300 [1:42:45<22:02, 24.96s/it, train_loss=1.3116, val_acc=52.99%, val_loss=2.5488]

Saving...


Overall Progress:  85%|████████▍ | 254/300 [1:45:40<19:07, 24.95s/it, train_loss=1.2981, val_acc=53.23%, val_loss=2.5367]

Saving...


Overall Progress:  87%|████████▋ | 261/300 [1:48:35<16:14, 24.98s/it, train_loss=1.2875, val_acc=53.31%, val_loss=2.5265]

Saving...


Overall Progress:  92%|█████████▏| 276/300 [1:54:49<10:00, 25.01s/it, train_loss=1.2614, val_acc=53.60%, val_loss=2.5160]

Saving...


Overall Progress:  96%|█████████▌| 287/300 [1:59:24<05:25, 25.07s/it, train_loss=1.2442, val_acc=53.68%, val_loss=2.5065]

Saving...


Overall Progress:  98%|█████████▊| 294/300 [2:02:19<02:30, 25.02s/it, train_loss=1.2362, val_acc=53.82%, val_loss=2.5000]

Saving...


Overall Progress: 100%|██████████| 300/300 [2:04:49<00:00, 24.96s/it, train_loss=1.2290, val_acc=52.97%, val_loss=2.5184]

Training complete!
Best Validation Accuracy: 0.5382



