In [2]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.backends.cudnn as cudnn
import torchvision.transforms as transforms 

In [3]:
import timm
import torch
import torchvision.transforms as transforms

device = torch.device('cuda:2' if torch.cuda.is_available() else 'cpu')

def get_models(dataset, model_name, key):
    """
    Creates a model for CIFAR-10 or ImageNet dataset.
    For CIFAR-10: uses pretrained ImageNet weights, replaces classifier head with num_classes=10.
    """
    if dataset == 'imagenet':
        model = timm.create_model(model_name, pretrained=True, num_classes=1000).to(device)
        model.eval()
        if any(x in key for x in ['inc', 'vit', 'bit']):
            norm = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        else:
            norm = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
        return torch.nn.Sequential(norm, model)
    
    if dataset == 'cifar10':
        # For Inception, resize to 299x299; otherwise 32x32 is fine
        if any(x in key for x in ['inc', 'vit', 'bit']):
            transform_resize = transforms.Resize((299, 299))
            norm = transforms.Normalize((0.5,), (0.5,)) 
            
        else:
            transform_resize = transforms.Resize((224, 224)) # do nothing
            # Standard CIFAR-10 normalization
            norm = transforms.Normalize((0.4914, 0.4822, 0.4465),
                                        (0.2023, 0.1994, 0.2010)) 
    
        # Create model
        model = timm.create_model(model_name, pretrained=True, num_classes=10).to(device)
        model.eval()
        
        # Wrap resize + normalization + model in Sequential
        return torch.nn.Sequential(
            transform_resize,   # resize if Inception
            norm,               # normalization
            model
        )

In [4]:
import torchvision
from torch.utils.data import random_split, DataLoader

# CIFAR-10 normalization for training and test
train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    # transforms.Normalize((0.4914, 0.4822, 0.4465),
    #                      (0.2023, 0.1994, 0.2010)),
])

test_transform = transforms.Compose([
    transforms.ToTensor(),
    # transforms.Normalize((0.4914, 0.4822, 0.4465),
    #                      (0.2023, 0.1994, 0.2010)),
])

# Load datasets
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transform)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=test_transform)

# Split train into train/val
train_size = int(0.9 * len(trainset))
val_size = len(trainset) - train_size
train_subset, val_subset = random_split(trainset, [train_size, val_size], generator=torch.Generator().manual_seed(56))

# DataLoaders
trainloader = DataLoader(train_subset, batch_size=32, shuffle=True, num_workers=4)
valloader = DataLoader(val_subset, batch_size=32, shuffle=False, num_workers=4)
testloader = DataLoader(testset, batch_size=32, shuffle=False, num_workers=4)

print(f"Train: {len(train_subset)}, Val: {len(val_subset)}, Test: {len(testset)}")

Train: 45000, Val: 5000, Test: 10000


### Train 

In [5]:
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm

def train_model(model, trainloader, valloader, epochs=10, lr=1e-3):
    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.01)
    # optimizer = optim.SGD(model.parameters(), lr=0.01,
    #                   momentum=0.9, weight_decay=5e-4)

    for epoch in range(epochs):
        model.train()
        running_loss, total, correct = 0, 0, 0

        for images, labels in tqdm(trainloader, desc=f"Epoch {epoch+1}/{epochs}"):
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item() * images.size(0)
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

        train_acc = 100 * correct / total
        val_acc = evaluate(model, valloader)
        print(f"Epoch {epoch+1}: Loss={running_loss/total:.4f}, Train Acc={train_acc:.2f}%, Val Acc={val_acc:.2f}%")


In [6]:
def evaluate(model, loader):
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            outputs = model(x)
            _, predicted = outputs.max(1)
            total += y.size(0)
            correct += predicted.eq(y).sum().item()
    return 100 * correct / total

In [7]:
target_models = [
    # ("resnet50", "resnet50"), 
    ("convmixer_768_32", "vit_t"),
    
] 

for model_name, key  in target_models:
    model = get_models('cifar10', model_name, key) 
    train_model(model, trainloader, valloader, epochs=50, lr=1e-3)
    torch.save(model.state_dict(), f"checkpoints/{key}_cifar10.pth")

Epoch 1/50:   0%|                                                                                                                                             | 0/1407 [00:00<?, ?it/s]


OutOfMemoryError: CUDA out of memory. Tried to allocate 166.00 MiB. GPU 2 has a total capacity of 10.75 GiB of which 74.62 MiB is free. Including non-PyTorch memory, this process has 10.67 GiB memory in use. Of the allocated memory 10.48 GiB is allocated by PyTorch, and 10.66 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

### Test data 

In [8]:
model = get_models('cifar10', "resnet152", "resnet152") 

state_dict = torch.load("checkpoints/resnet152_cifar10.pth")
model_state = model.state_dict()

# If model expects "1." but checkpoint doesn't have it
if list(model_state.keys())[0].startswith("1.") and not list(state_dict.keys())[0].startswith("1."):
    state_dict = {f"1.{k}": v for k, v in state_dict.items()}

# If checkpoint has "1." but model doesn’t
elif not list(model_state.keys())[0].startswith("1.") and list(state_dict.keys())[0].startswith("1."):
    state_dict = {k.replace("1.", "", 1): v for k, v in state_dict.items()}

model.load_state_dict(state_dict)

<All keys matched successfully>

In [9]:
import torch
from torchmetrics.classification import MulticlassAccuracy, MulticlassPrecision, MulticlassRecall, MulticlassF1Score

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
model.eval()

# Initialize metrics
num_classes = 10
accuracy = MulticlassAccuracy(num_classes=num_classes).to(device)
precision = MulticlassPrecision(num_classes=num_classes, average='macro').to(device)
recall = MulticlassRecall(num_classes=num_classes, average='macro').to(device)
f1 = MulticlassF1Score(num_classes=num_classes, average='macro').to(device)

# Evaluate
with torch.no_grad():
    for images, labels in testloader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        preds = torch.argmax(outputs, dim=1)

        accuracy.update(preds, labels)
        precision.update(preds, labels)
        recall.update(preds, labels)
        f1.update(preds, labels)

acc = accuracy.compute().item() * 100
prec = precision.compute().item() * 100
rec = recall.compute().item() * 100
f1_score = f1.compute().item() * 100

print(f"✅ Test Results:")
print(f"Accuracy:  {acc:.2f}%")
print(f"Precision: {prec:.2f}%")
print(f"Recall:    {rec:.2f}%")
print(f"F1-score:  {f1_score:.2f}%")


✅ Test Results:
Accuracy:  90.77%
Precision: 91.24%
Recall:    90.77%
F1-score:  90.76%
