# Finetuning ViT-S/16 on CIFAR-10 and CIFAR-100

In [None]:
import torch
import torch.nn as nn
from torchsummary import summary
import torchvision.transforms as T
import timm
from tqdm import tqdm

from data_utils.data_stats import *
from data_utils.dataloader import get_loader
from utils.metrics import topk_acc, AverageMeter

### Retrain classifier

In [None]:
# This code works for either cifar10 or cifar100
dataset_name = "cifar10"
data_path = '/scratch/data/ffcv/'

# Import pretrained ViT model from timm
model = timm.create_model("vit_small_patch16_224", pretrained=True)

# Freeze all pretrained weights. The replacement head created below is a fresh
# nn.Linear, whose parameters require gradients by default, so it ends up
# being the only trainable part of the network.
for param in model.parameters():
    param.requires_grad = False

# The number of classes is encoded as the trailing digits of the dataset name
# ("cifar10" -> 10, "cifar100" -> 100). Parse the digit suffix explicitly and
# fail with a readable message instead of an opaque IndexError/ValueError.
class_suffix = dataset_name[len(dataset_name.rstrip("0123456789")):]
if not class_suffix:
    raise ValueError(
        f"Cannot infer the number of classes from dataset name {dataset_name!r}; "
        "expected a name ending in digits such as 'cifar10' or 'cifar100'."
    )
outputs_attrs = int(class_suffix)

# Replace the last layer with a new layer that has the correct number of outputs
num_inputs = model.head.in_features
last_layer = nn.Linear(num_inputs, outputs_attrs)
model.head = last_layer

# torchsummary builds its dummy input on CUDA by default whenever CUDA is
# available, so the model has to live on the same device or the summary
# forward pass raises a device-mismatch error.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)

summary(model, (3, 224, 224), device=device.type)

In [None]:
import os

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# The loader below is created with dev=device, so batches are presumably
# delivered on that device (TODO confirm against get_loader). The model must
# be on the same device or the forward pass fails with a device mismatch.
model = model.to(device)
# Explicitly select training mode so that re-running this cell after the
# evaluation cell (which calls model.eval()) behaves the same as a first run.
model.train()

# Only hand the optimizer the parameters that actually receive gradients
# (the new classification head); the frozen backbone is left out.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(trainable_params)
loss_function = nn.CrossEntropyLoss()

# Get the dataloader for the dataset
loader = get_loader(
    dataset_name,
    bs=1024,
    mode="train",
    augment=True,
    dev=device,
    mixup=0.0,
    data_path=data_path,
    data_resolution=32,
    crop_resolution=32,
)

# Train weights for the linear layer for 7 epochs, resizing the input images
num_epochs = 7

# Make sure the checkpoint directory exists *before* training, so a missing
# folder cannot throw away the whole run when torch.save is reached.
save_path = f'vit_models/vit_small_patch16_224_{dataset_name}_{num_epochs}.pth'
os.makedirs(os.path.dirname(save_path), exist_ok=True)

for _ in range(num_epochs):
    for ims, targs in tqdm(loader, desc="Training"):
        # The loader is configured for 32x32 crops; the ViT expects 224x224.
        ims = T.functional.resize(ims, size=(224, 224))
        optimizer.zero_grad()
        outputs = model(ims)
        loss = loss_function(outputs, targs)
        loss.backward()
        optimizer.step()

# The whole module object is pickled (not just the state dict); the
# evaluation cell loads it back with torch.load.
torch.save(model, save_path)

### Evaluate model accuracy

In [None]:
# Evaluation helper: runs the model over a loader and reports averaged accuracies
@torch.no_grad()
def test(model, loader):
    """Evaluate ``model`` on every batch yielded by ``loader``.

    The model is switched to eval mode and gradients are disabled for the
    whole call. Each batch of images is resized to 224x224 before the forward
    pass, and the two values returned by ``topk_acc(..., k=5, avg=True)`` are
    accumulated in separate ``AverageMeter`` objects, weighted by batch size.

    Returns:
        A 2-tuple with the ``get_avg(percentage=True)`` result of the first
        meter (presumably top-1 accuracy) followed by that of the second
        meter (presumably top-5 accuracy).
    """
    model.eval()
    top1_meter = AverageMeter()
    top5_meter = AverageMeter()
    target_size = (224, 224)

    for batch, labels in tqdm(loader, desc="Evaluation"):
        batch = T.functional.resize(batch, size=target_size)
        logits = model(batch)
        batch_top1, batch_top5 = topk_acc(logits, labels, k=5, avg=True)

        # Weight each batch by its number of samples.
        n_samples = batch.shape[0]
        top1_meter.update(batch_top1, n_samples)
        top5_meter.update(batch_top5, n_samples)

    top1_pct = top1_meter.get_avg(percentage=True)
    top5_pct = top5_meter.get_avg(percentage=True)
    return (top1_pct, top5_pct)

In [None]:
# Reload the checkpoint written by the training cell. `dataset_name`,
# `num_epochs` and `device` are taken from the earlier cells.
checkpoint_path = f'vit_models/vit_small_patch16_224_{dataset_name}_{num_epochs}.pth'

# The checkpoint is a fully pickled nn.Module produced by this notebook.
# - map_location places the weights on the current device even if the model
#   was saved from a different one (e.g. saved on GPU, evaluated on CPU).
# - weights_only=False is required on recent PyTorch releases, whose default
#   refuses to unpickle whole modules. Only do this for trusted files:
#   unpickling can execute arbitrary code.
try:
    model = torch.load(checkpoint_path, map_location=device, weights_only=False)
except TypeError:
    # Older PyTorch versions do not know the weights_only keyword.
    model = torch.load(checkpoint_path, map_location=device)

data_loader = get_loader(
    dataset_name,
    bs=128,
    mode="test",
    augment=False,
    dev=device,
    mixup=0.0,
    data_path=data_path,
    data_resolution=32,
    crop_resolution=32,
)
test_acc, test_top5 = test(model, data_loader)

# Print all the results
print("Test Accuracy        ", "{:.4f}".format(test_acc))
print("Top 5 Test Accuracy          ", "{:.4f}".format(test_top5))