In [1]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.backends.cudnn as cudnn
import torchvision.transforms as transforms 


import warnings
warnings.filterwarnings('ignore')

In [2]:
import timm
import torch
import torchvision.transforms as transforms

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def get_models(dataset, model_name, key):
    if dataset == 'cifar10':
        # For Inception, resize to 299x299; otherwise 32x32 is fine
        if any(x in key for x in ['inc', 'bit']):
            transform_resize = transforms.Resize((299, 299))
            norm = transforms.Normalize((0.5,), (0.5,)) 
        elif any(x in key for x in ['vit']):
            transform_resize = transforms.Resize((224, 224))
            norm = transforms.Normalize((0.5,), (0.5,)) 
        else:
            transform_resize = transforms.Resize((224, 224)) # do nothing
            norm = transforms.Normalize((0.4914, 0.4822, 0.4465),
                                        (0.2023, 0.1994, 0.2010)) 
    
        # Create model
        model = timm.create_model(model_name, pretrained=True, num_classes=10).to(device)
        model.eval()
        
        # Wrap resize + normalization + model in Sequential
        return torch.nn.Sequential(
            transform_resize,   # resize if Inception
            norm,               # normalization
            model
        )

In [3]:
timm.list_models('*regnet*')

['haloregnetz_b',
 'nf_regnet_b0',
 'nf_regnet_b1',
 'nf_regnet_b2',
 'nf_regnet_b3',
 'nf_regnet_b4',
 'nf_regnet_b5',
 'regnetv_040',
 'regnetv_064',
 'regnetx_002',
 'regnetx_004',
 'regnetx_006',
 'regnetx_008',
 'regnetx_016',
 'regnetx_032',
 'regnetx_040',
 'regnetx_064',
 'regnetx_080',
 'regnetx_120',
 'regnetx_160',
 'regnetx_320',
 'regnety_002',
 'regnety_004',
 'regnety_006',
 'regnety_008',
 'regnety_016',
 'regnety_032',
 'regnety_040',
 'regnety_040s_gn',
 'regnety_064',
 'regnety_080',
 'regnety_120',
 'regnety_160',
 'regnety_320',
 'regnetz_005',
 'regnetz_040',
 'regnetz_040h',
 'regnetz_b16',
 'regnetz_b16_evos',
 'regnetz_c16',
 'regnetz_c16_evos',
 'regnetz_d8',
 'regnetz_d8_evos',
 'regnetz_d32',
 'regnetz_e8']

In [4]:
import torchvision
from torch.utils.data import random_split, DataLoader

# CIFAR-10 normalization for training and test
train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    # transforms.Normalize((0.4914, 0.4822, 0.4465),
    #                      (0.2023, 0.1994, 0.2010)),
])

test_transform = transforms.Compose([
    transforms.ToTensor(),
    # transforms.Normalize((0.4914, 0.4822, 0.4465),
    #                      (0.2023, 0.1994, 0.2010)),
])

# Load datasets
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transform)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=test_transform)

# Split train into train/val
train_size = int(0.9 * len(trainset))
val_size = len(trainset) - train_size
train_subset, val_subset = random_split(trainset, [train_size, val_size], generator=torch.Generator().manual_seed(56))

# DataLoaders
trainloader = DataLoader(train_subset, batch_size=48, shuffle=True, num_workers=4)
valloader = DataLoader(val_subset, batch_size=48, shuffle=False, num_workers=4)
testloader = DataLoader(testset, batch_size=48, shuffle=False, num_workers=4)

print(f"Train: {len(train_subset)}, Val: {len(val_subset)}, Test: {len(testset)}")

Train: 45000, Val: 5000, Test: 10000


### Train 

In [5]:
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm

def train_model(
    model,
    trainloader,
    valloader,
    epochs=10,
    lr=1e-3,
    checkpoint_path="best_model.pth",
    early_stop_patience=None,
):

    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    scaler = torch.cuda.amp.GradScaler()     # Mixed precision

    best_val_acc = 0
    patience_counter = 0

    for epoch in range(epochs):
        model.train()
        running_loss, total, correct = 0, 0, 0

        pbar = tqdm(trainloader, desc=f"Epoch {epoch+1}/{epochs}")

        for images, labels in pbar:
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()

            with torch.cuda.amp.autocast():   # FP16 forward
                outputs = model(images)
                loss = criterion(outputs, labels)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            running_loss += loss.item() * images.size(0)
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

            pbar.set_postfix(loss=loss.item())

        train_acc = 100 * correct / total
        val_acc = evaluate(model, valloader)

        print(
            f"Epoch {epoch+1}: "
            f"Loss={running_loss/total:.4f}, "
            f"Train Acc={train_acc:.2f}%, "
            f"Val Acc={val_acc:.2f}%"
        )

        # -------------------------------
        # Save Best Model Checkpoint
        # -------------------------------
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(
                {
                    "epoch": epoch + 1,
                    "model_state": model.state_dict(),
                    "optimizer_state": optimizer.state_dict(),
                    "val_acc": best_val_acc,
                },
                checkpoint_path,
            )
            print(f"✨ Saved Best Model (Val Acc: {best_val_acc:.2f}%)")

            patience_counter = 0
        else:
            patience_counter += 1

        # -------------------------------
        # Early Stopping
        # -------------------------------
        if early_stop_patience is not None:
            if patience_counter >= early_stop_patience:
                print("⛔ Early stopping triggered.")
                break

    print(f"Training completed. Best Val Acc: {best_val_acc:.2f}%")
    return best_val_acc


In [6]:
def evaluate(model, loader):
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            outputs = model(x)
            _, predicted = outputs.max(1)
            total += y.size(0)
            correct += predicted.eq(y).sum().item()
    return 100 * correct / total

In [None]:
model_name = "gcvit_tiny" 

model = get_models('cifar10', model_name, "vit") 

train_model(model, trainloader, valloader, epochs=50, lr=1e-3, checkpoint_path=f"checkpoints/{model_name}_cifar10.pth", 
           early_stop_patience=None)

Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-morevit/gcvit_tiny_224_nvidia-ac783954.pth" to /home/firuz/.cache/torch/hub/checkpoints/gcvit_tiny_224_nvidia-ac783954.pth


Epoch 1/50: 100%|███████████████████████████████| 938/938 [03:00<00:00,  5.21it/s, loss=0.033]


Epoch 1: Loss=0.3747, Train Acc=87.50%, Val Acc=90.18%
✨ Saved Best Model (Val Acc: 90.18%)


Epoch 2/50: 100%|███████████████████████████████| 938/938 [02:53<00:00,  5.40it/s, loss=0.163]


Epoch 2: Loss=0.2300, Train Acc=92.39%, Val Acc=92.60%
✨ Saved Best Model (Val Acc: 92.60%)


Epoch 3/50: 100%|███████████████████████████████| 938/938 [02:56<00:00,  5.30it/s, loss=0.182]


Epoch 3: Loss=0.1931, Train Acc=93.63%, Val Acc=93.06%
✨ Saved Best Model (Val Acc: 93.06%)


Epoch 4/50: 100%|███████████████████████████████| 938/938 [02:53<00:00,  5.40it/s, loss=0.134]


Epoch 4: Loss=0.1675, Train Acc=94.45%, Val Acc=91.48%


Epoch 5/50: 100%|███████████████████████████████| 938/938 [03:15<00:00,  4.79it/s, loss=0.144]


Epoch 5: Loss=0.1548, Train Acc=94.86%, Val Acc=91.28%


Epoch 6/50: 100%|███████████████████████████████| 938/938 [02:53<00:00,  5.39it/s, loss=0.233]


Epoch 6: Loss=0.1463, Train Acc=95.13%, Val Acc=93.02%


Epoch 7/50: 100%|███████████████████████████████| 938/938 [02:54<00:00,  5.39it/s, loss=0.135]


Epoch 7: Loss=0.1190, Train Acc=96.02%, Val Acc=92.70%


Epoch 8/50:  51%|████████████████▏               | 476/938 [01:28<01:25,  5.41it/s, loss=0.25]