Implement Vision Transformers on the CIFAR-10 dataset using pretrained checkpoints.

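Exponential LR: $\eta_t = \eta_0 \cdot \gamma^t$, where $\eta_0$ is the initial learning rate, $\eta_t$ is the learning rate after $t$ scheduler steps (here the scheduler is stepped once per epoch), and $\gamma$ is the decay factor, e.g. $\gamma = 0.95$.  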

In [None]:
import time

import torch
import torch.nn.functional as F
from torch.optim.lr_scheduler import ExponentialLR
#import torchmetrics
from torchvision import transforms
from torchvision.models import vit_b_16
from torchvision.models import ViT_B_16_Weights

In [None]:
!pip install torchmetrics



In [None]:
import torchmetrics
import torchvision
import torchvision.transforms as transforms
from torchvision import transforms
from torch.utils.data import DataLoader
from torch.utils.data import random_split
from torch.utils.data import Subset
# Preprocessing: resize every image to 224x224, convert it to a tensor and
# normalize each channel with the given per-channel mean / std.
# NOTE(review): these look like CIFAR-10 channel statistics, whereas the
# pretrained ImageNet weights were presumably trained with ImageNet
# statistics -- confirm whether this mismatch matters for fine-tuning.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
])

# Load CIFAR-10 (downloaded into ./data when not already present).
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

# Only a fixed, unshuffled slice of the official training split is used:
# the first `subset_size` images for training and the next `val_size` images
# for validation. (A random 90% / 10% split of the whole training set via
# `random_split` would be the alternative.)
subset_size = 20000
val_size = 5000
train_subset = Subset(trainset, range(subset_size))
val_subset = Subset(trainset, range(subset_size, subset_size + val_size))

# Create DataLoaders. num_workers=0 means batches are loaded in the main
# process (no background workers); only the training data is shuffled.
train_loader = DataLoader(train_subset, batch_size=16, shuffle=True, num_workers=0)
val_loader = DataLoader(val_subset, batch_size=16, shuffle=False, num_workers=0)

test_loader = DataLoader(testset, batch_size=16, shuffle=False, num_workers=0)


In [None]:
def train(num_epochs, model, optimizer, train_loader, val_loader, device, scheduler):
    """Fine-tune ``model`` and print training / validation accuracy per epoch.

    Each epoch trains on every batch of ``train_loader`` with a cross-entropy
    loss, advances ``scheduler`` by one step, and then evaluates the model on
    ``val_loader``. Progress is reported with ``print``; nothing is returned.

    Args:
        num_epochs: Number of passes over ``train_loader``.
        model: Module mapping a batch of features to class logits. The
            accuracy metrics are configured for 10 classes.
        optimizer: Optimizer that updates the parameters of ``model``.
        train_loader: Iterable yielding ``(features, targets)`` training batches.
        val_loader: Iterable yielding ``(features, targets)`` validation batches.
        device: Device the batches and metrics are moved to. ``model`` is not
            moved here, so it is expected to already live on this device.
        scheduler: Learning-rate scheduler, stepped once at the end of each epoch.
    """
    for epoch in range(num_epochs):
        # A fresh metric per epoch keeps the reported accuracy epoch-local.
        train_acc = torchmetrics.Accuracy(task="multiclass", num_classes=10).to(device)

        model.train()
        # Discard gradients left over from code that ran earlier so they are
        # not accumulated into this epoch's first parameter update.
        optimizer.zero_grad()

        for batch_idx, (features, targets) in enumerate(train_loader):  # iterate over mini-batches
            ### FORWARD AND BACK PROP
            features, targets = features.to(device), targets.to(device)
            logits = model(features)
            loss = F.cross_entropy(logits, targets)
            loss.backward()

            ### UPDATE MODEL PARAMETERS
            optimizer.step()
            optimizer.zero_grad()  # reset gradients so the next batch starts clean

            ### LOGGING
            if batch_idx % 300 == 0:
                print(f"Epoch: {epoch+1:04d}/{num_epochs:04d} | Batch {batch_idx:04d}/{len(train_loader):04d} | Loss: {loss:.4f}")

            # Running training accuracy, taken from the logits of the forward
            # pass above. No second forward pass happens here, so the model
            # can simply stay in train mode.
            with torch.no_grad():
                train_acc.update(torch.argmax(logits, 1), targets)

        scheduler.step()  # one learning-rate decay step per epoch

        ### MORE LOGGING
        model.eval()
        val_acc = torchmetrics.Accuracy(task="multiclass", num_classes=10).to(device)
        with torch.no_grad():
            for features, targets in val_loader:
                features, targets = features.to(device), targets.to(device)
                outputs = model(features)
                val_acc.update(torch.argmax(outputs, 1), targets)

        print(f"Epoch: {epoch+1:04d}/{num_epochs:04d} | Train acc.: {train_acc.compute()*100:.2f}% | Val acc.: {val_acc.compute()*100:.2f}%")


if __name__ == "__main__":
    # Report which PyTorch build is in use.
    print("PyTorch:", torch.__version__)
    # Select the "medium" float32 matmul precision setting.
    torch.set_float32_matmul_precision("medium")

    # Use the GPU when CUDA is available, otherwise fall back to the CPU.
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")


PyTorch: 2.9.0+cu128


  _C._set_float32_matmul_precision(precision)


In [None]:
 #########################################
    ### 2 Initializing the Model
    #########################################

# Build ViT-B/16 initialised from the IMAGENET1K_V1 pretrained weights.
# The "16" is the patch size: the input image is divided into 16x16-pixel patches.
model = vit_b_16(
    weights=ViT_B_16_Weights.IMAGENET1K_V1,
)


In [None]:
   # replace output layer
# The new head maps the encoder output to 10 classes. Its input width is read
# from the head being replaced instead of hard-coding 768, so the cell keeps
# working if a different ViT variant is loaded above.
model.heads.head = torch.nn.Linear(
    in_features=model.heads.head.in_features,
    out_features=10,
)
model.to(device)

# Small learning rate: too large a value destroys the pretrained features.
optimizer = torch.optim.Adam(model.parameters(), lr=5e-5)
# Every scheduler.step() multiplies the learning rate by 0.9 (a 10% decay).
scheduler = ExponentialLR(optimizer, gamma=0.9)


In [None]:
    #########################################
    ### 3 Finetuning
    #########################################
# Release cached, unused GPU memory before the run starts.
torch.cuda.empty_cache()

# perf_counter() is a monotonic, high-resolution clock, so the measured
# interval cannot be distorted by system clock adjustments during training.
start = time.perf_counter()
train(
    num_epochs=3,
    model=model,
    optimizer=optimizer,
    train_loader=train_loader,
    val_loader=val_loader,
    device=device,
    scheduler=scheduler
)

end = time.perf_counter()
elapsed = end - start  # seconds
print(f"Time elapsed {elapsed/60:.2f} min")
# Peak GPU memory reserved by PyTorch's caching allocator, in GB (1e9 bytes).
print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")


Epoch: 0001/0003 | Batch 0000/1250 | Loss: 2.2799
Epoch: 0001/0003 | Batch 0300/1250 | Loss: 0.3184
Epoch: 0001/0003 | Batch 0600/1250 | Loss: 0.1237
Epoch: 0001/0003 | Batch 0900/1250 | Loss: 0.1196
Epoch: 0001/0003 | Batch 1200/1250 | Loss: 0.0695
Epoch: 0001/0003 | Train acc.: 91.65% | Val acc.: 92.60%
Epoch: 0002/0003 | Batch 0000/1250 | Loss: 0.0629
Epoch: 0002/0003 | Batch 0300/1250 | Loss: 0.0073
Epoch: 0002/0003 | Batch 0600/1250 | Loss: 0.7606
Epoch: 0002/0003 | Batch 0900/1250 | Loss: 0.2015
Epoch: 0002/0003 | Batch 1200/1250 | Loss: 0.1817
Epoch: 0002/0003 | Train acc.: 96.61% | Val acc.: 94.10%
Epoch: 0003/0003 | Batch 0000/1250 | Loss: 0.0058
Epoch: 0003/0003 | Batch 0300/1250 | Loss: 0.0166
Epoch: 0003/0003 | Batch 0600/1250 | Loss: 0.0373
Epoch: 0003/0003 | Batch 0900/1250 | Loss: 0.0544
Epoch: 0003/0003 | Batch 1200/1250 | Loss: 0.1013
Epoch: 0003/0003 | Train acc.: 97.97% | Val acc.: 95.16%
Time elapsed 40.23 min
Memory used: 3.29 GB


In [None]:
#########################################
    ### 4 Evaluation
    #########################################

# Measure the accuracy of the fine-tuned model on the CIFAR-10 test loader.
model.eval()  # switch the model to evaluation mode
test_acc = torchmetrics.Accuracy(task="multiclass", num_classes=10).to(device)

# Gradients are not needed for evaluation.
with torch.no_grad():
    for features, targets in test_loader:
        features = features.to(device)
        targets = targets.to(device)
        outputs = model(features)
        # The predicted class is the index of the largest logit per sample.
        predicted_labels = outputs.argmax(dim=1)
        test_acc.update(predicted_labels, targets)

print(f"Test accuracy {test_acc.compute()*100:.2f}%")

Test accuracy 95.03%
