# Finetuning ViT on CIFAR

In [15]:
import os

import torch
import torchvision.transforms as T
import torch.nn as nn
import yaml
from argparse import Namespace
from tqdm import tqdm
from ffcv.fields import BytesField, IntField, RGBImageField
from ffcv.writer import DatasetWriter

from data_utils.data_stats import *
from data_utils.dataloader import get_loader
from utils.metrics import topk_acc, real_acc, AverageMeter
from models.networks import get_model
from data_utils.dataset_to_beton import get_dataset
from PIL import Image

import timm
from torchsummary import summary

In [3]:
# Load the pretrained vit_small_patch16_224 checkpoint from timm and freeze
# every parameter it ships with, so the backbone receives no gradient updates.
model = timm.create_model("vit_small_patch16_224", pretrained=True)
model.requires_grad_(False)

# Swap the classifier for a freshly initialised linear layer. It is created
# after the freeze, so its parameters keep the default requires_grad=True and
# are the only trainable part of the network.
outputs_attrs = 100  # number of output logits (the training loader uses "cifar100")
num_inputs = model.head.in_features
last_layer = nn.Linear(num_inputs, outputs_attrs)
model.head = last_layer

# Print a per-layer overview for a single 3x224x224 input.
summary(model, (3, 224, 224))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [-1, 384, 14, 14]         295,296
          Identity-2             [-1, 196, 384]               0
        PatchEmbed-3             [-1, 196, 384]               0
           Dropout-4             [-1, 197, 384]               0
          Identity-5             [-1, 197, 384]               0
         LayerNorm-6             [-1, 197, 384]             768
            Linear-7            [-1, 197, 1152]         443,520
          Identity-8           [-1, 6, 197, 64]               0
          Identity-9           [-1, 6, 197, 64]               0
           Linear-10             [-1, 197, 384]         147,840
          Dropout-11             [-1, 197, 384]               0
        Attention-12             [-1, 197, 384]               0
         Identity-13             [-1, 197, 384]               0
         Identity-14             [-1, 1

In [4]:
# Smoke test: draw one random 3x64x64 tensor, upsample it to the 224x224
# resolution the ViT is fed with elsewhere in this notebook, and run a single
# forward pass to make sure the modified model executes end to end.
img = T.functional.resize(torch.randn(1, 3, 64, 64), size=(224, 224))
output = model(img)



In [8]:
# Choose the compute device first so the model and the data loader agree on it.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# The loader below is built with `dev=device`, but the model was previously
# left on the CPU, which breaks the forward pass as soon as CUDA is available.
# Move it before constructing the optimizer so the optimizer references the
# parameters on their final device.
# NOTE(review): assumes get_loader places its batches on `dev` — confirm.
model = model.to(device)

# Parameters with requires_grad=False never receive a gradient, so the
# optimizer leaves them untouched even though they are passed in here.
optimizer = torch.optim.AdamW(model.parameters(), lr=0.01)
loss_function = nn.CrossEntropyLoss()

# CIFAR-100 training loader, with augmentation enabled and mixup disabled.
loader = get_loader(
    "cifar100",
    bs=128,
    mode="train",
    augment=True,
    dev=device,
    mixup=0.0,
    data_path='/scratch/ffcv/',
    data_resolution=32,
    crop_resolution=32,
)

Loading /scratch/ffcv/cifar100/train_32.beton


In [None]:
# Single pass over the training loader.
#
# The evaluation helper `test` in this notebook calls model.eval(), so after
# any evaluation run the model would silently stay in eval mode. Set training
# mode explicitly instead of relying on whatever state the kernel is in.
model.train()
for ims, targs in tqdm(loader, desc="Training"):
    # Upsample each batch to the 224x224 resolution fed to the ViT.
    ims = T.functional.resize(ims, size=(224, 224))
    optimizer.zero_grad()
    outputs = model(ims)
    loss = loss_function(outputs, targs)
    loss.backward()
    optimizer.step()

In [10]:
# Persist the fine-tuned model. The module object itself is passed, so the
# whole model is pickled rather than just its weights.
# NOTE(review): saving `model.state_dict()` is the more portable convention;
# consider switching if nothing loads this file as a full module.
torch.save(obj=model, f='vit_small_patch16_224_cifar100.pth')

In [11]:
# Evaluation helper: accuracy of a model over every batch of a loader.
@torch.no_grad()
def test(model, loader, extractor = None):
    """Evaluate `model` on all batches of `loader` and return averaged accuracies.

    Args:
        model: Module called on each resized image batch. It is switched to
            eval mode here and left in that mode on return.
        loader: Iterable yielding `(images, targets)` pairs.
        extractor: Not used by this function.

    Returns:
        A 2-tuple with the results of `get_avg(percentage=True)` for the two
        values returned by `topk_acc(..., k=5, avg=True)` (presumably top-1
        and top-5 accuracy — confirm against `utils.metrics`).
    """
    model.eval()
    top1_meter = AverageMeter()
    top5_meter = AverageMeter()

    progress = tqdm(loader, desc="Evaluation")
    for images, labels in progress:
        # Upsample each batch to 224x224 before the forward pass.
        images = T.functional.resize(images, size=(224, 224))
        logits = model(images)

        batch_top1, batch_top5 = topk_acc(logits, labels, k=5, avg=True)

        # Weight each batch by its size so the running averages are per-sample.
        n_samples = images.shape[0]
        top1_meter.update(batch_top1, n_samples)
        top5_meter.update(batch_top5, n_samples)

    top1_avg = top1_meter.get_avg(percentage=True)
    top5_avg = top5_meter.get_avg(percentage=True)
    return (top1_avg, top5_avg)

In [8]:
# Evaluate the current model on the CIFAR-10 test split.
# NOTE(review): the classifier head built in this notebook has 100 outputs and
# is trained on the "cifar100" loader — confirm that scoring it against
# CIFAR-10 labels is intended and that the recorded output is not stale.
data_loader = get_loader(
    "cifar10",
    bs=128, mode="test", augment=False,
    dev=device, mixup=0.0,
    data_path='/scratch/ffcv/',
    data_resolution=32, crop_resolution=32,
)
test_acc, test_top5 = test(model, data_loader)

# Print all the stats
for label, value in (
    ("Test Accuracy        ", test_acc),
    ("Top 5 Test Accuracy          ", test_top5),
):
    print(label, "{:.4f}".format(value))

Loading /scratch/ffcv/cifar10/val_32.beton


Evaluation: 100%|██████████| 79/79 [27:33<00:00, 20.93s/it]

Test Accuracy         92.2800
Top 5 Test Accuracy           99.7600





In [12]:
# Evaluate the fine-tuned model on the CIFAR-100 test split (no augmentation).
data_loader = get_loader(
    "cifar100",
    bs=128, mode="test", augment=False,
    dev=device, mixup=0.0,
    data_path='/scratch/ffcv/',
    data_resolution=32, crop_resolution=32,
)
test_acc, test_top5 = test(model, data_loader)

# Print all the stats
for label, value in (
    ("Test Accuracy        ", test_acc),
    ("Top 5 Test Accuracy          ", test_top5),
):
    print(label, "{:.4f}".format(value))

Loading /scratch/ffcv/cifar100/test_32.beton


Evaluation: 100%|██████████| 79/79 [27:47<00:00, 21.11s/it]

Test Accuracy         63.1300
Top 5 Test Accuracy           85.2500



