In [1]:
!pip install timm



In [10]:
import torch
import torchvision
import numpy as np
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split
from torch import nn
import torch.optim as optim
from torchvision.datasets import CIFAR100
import timm
from tqdm import tqdm
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.tensorboard import SummaryWriter
import os

In [3]:
# Configuration
# Seed PyTorch's global RNG up front so that everything drawing from it later
# (shuffling, random augmentations, random splits) is repeatable across
# Restart & Run All.
seed = 42
torch.manual_seed(seed)

# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 128       # samples per mini-batch for every DataLoader
epochs = 100           # upper bound on training epochs (also the cosine schedule length)
learning_rate = 1e-3   # initial learning rate handed to the optimizer

In [4]:
# Per-channel mean / std used to normalize every image (presumably the
# CIFAR-100 training-set statistics -- confirm if the dataset changes).
cifar100_mean = (0.5071, 0.4867, 0.4408)
cifar100_std = (0.2675, 0.2565, 0.2761)

# Training pipeline: upscale, apply light augmentation, then normalize.
# NOTE(review): padding=4 is applied after the resize, so the random crop only
# shifts the 224-pixel image by a few pixels -- confirm this is intended.
train_steps = [
    transforms.Resize(224),                  # shorter side -> 224 (square inputs become 224x224)
    transforms.RandomCrop(224, padding=4),   # pad by 4 px, then take a random 224x224 crop
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(cifar100_mean, cifar100_std),
]
transform_train = transforms.Compose(train_steps)

# Evaluation pipeline (validation and testing): deterministic resize + normalize only.
eval_steps = [
    transforms.Resize(224),                  # shorter side -> 224 (square inputs become 224x224)
    transforms.ToTensor(),
    transforms.Normalize(cifar100_mean, cifar100_std),
]
transform_test = transforms.Compose(eval_steps)

In [5]:
# Load datasets
# The official training split is opened twice over the same files: once with the
# augmenting transform (for training) and once with the deterministic transform
# (for validation). Splitting a single augmented dataset would make the
# validation loss depend on random crops/flips.
train_val_set = CIFAR100(root='./data', train=True, download=True, transform=transform_train)
val_source_set = CIFAR100(root='./data', train=True, download=False, transform=transform_test)
test_set = CIFAR100(root='./data', train=False, download=True, transform=transform_test)

# Split training and validation (last 10% of a shuffled index list is held out)
num_train = len(train_val_set)
split = int(np.floor(0.1 * num_train))

# A dedicated, seeded generator keeps the split identical on every run and
# independent of whatever else has consumed the global RNG.
split_generator = torch.Generator().manual_seed(42)
indices = torch.randperm(num_train, generator=split_generator).tolist()

# Both subsets index into the same underlying images, so they stay disjoint;
# only the transform applied on access differs.
train_set = torch.utils.data.Subset(train_val_set, indices[:num_train - split])
val_set = torch.utils.data.Subset(val_source_set, indices[num_train - split:])

# Data loaders (only the training loader shuffles)
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=0)
val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=0)
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=0)

Files already downloaded and verified
Files already downloaded and verified


In [6]:
# Define the model
class ViTClassifier(nn.Module):
    """Thin ``nn.Module`` wrapper around a timm ``vit_base_patch16_224`` model.

    Args:
        num_classes: Passed straight through to ``timm.create_model`` as its
            ``num_classes`` argument (defaults to 100).
    """

    def __init__(self, num_classes=100):
        super().__init__()
        backbone = timm.create_model(
            'vit_base_patch16_224',
            pretrained=True,
            num_classes=num_classes,
        )
        # The attribute name becomes the prefix of every state_dict key, so it
        # must stay ``vit`` for saved checkpoints to keep loading.
        self.vit = backbone

    def forward(self, x):
        """Feed ``x`` to the wrapped model and return its output unchanged."""
        output = self.vit(x)
        return output

# Build the classifier with 100 output classes, then move its parameters to
# the selected device.
model = ViTClassifier(num_classes=100)
model = model.to(device)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


In [7]:
# Objective, optimizer and learning-rate schedule.
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(params=model.parameters(), lr=learning_rate)
# Cosine annealing stretched over the full epoch budget.
scheduler = CosineAnnealingLR(optimizer=optimizer, T_max=epochs)

# TensorBoard event writer for this experiment.
writer = SummaryWriter(log_dir='runs/CIFAR100_ViT_experiment')

# Early-stopping state: best validation loss seen so far, the number of
# consecutive epochs without improvement, and how many of those are tolerated.
best_val_loss = float('inf')
counter = 0
early_stopping_patience = 10

In [11]:
def train(epoch):
    """Run one optimisation pass over ``train_loader``.

    Relies on the notebook-level ``model``, ``optimizer``, ``criterion``,
    ``device``, ``train_loader`` and ``epochs``.

    Args:
        epoch: Zero-based epoch index; only used to label the progress bar.

    Returns:
        float: Mean training loss per sample for this pass, or ``nan`` when the
        loader produced no samples.
    """
    model.train()
    loss_sum = 0.0      # sum of per-sample losses seen so far
    num_samples = 0     # number of samples seen so far
    # The description does not change within an epoch, so set it once here
    # instead of on every batch.
    loop = tqdm(train_loader, total=len(train_loader), leave=True,
                desc=f'Epoch [{epoch+1}/{epochs}]')
    for inputs, labels in loop:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # criterion yields a batch mean (nn.CrossEntropyLoss default), so weight
        # it by the batch size; otherwise a short final batch counts as much as
        # a full one in the reported average.
        batch_count = labels.size(0)
        loss_sum += loss.item() * batch_count
        num_samples += batch_count
        loop.set_postfix(loss=loss_sum / max(num_samples, 1))

    if num_samples == 0:
        return float('nan')
    return loss_sum / num_samples

def validate(epoch):
    """Compute the mean loss of ``model`` over ``val_loader`` without gradients.

    Relies on the notebook-level ``model``, ``criterion``, ``device`` and
    ``val_loader``.

    Args:
        epoch: Accepted for symmetry with ``train``; not used here.

    Returns:
        float: Mean validation loss per sample.

    Raises:
        ValueError: If ``val_loader`` produced no samples.
    """
    model.eval()
    loss_sum = 0.0      # sum of per-sample losses seen so far
    num_samples = 0     # number of samples seen so far
    loop = tqdm(val_loader, total=len(val_loader), leave=False)
    with torch.no_grad():
        for inputs, labels in loop:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            # criterion yields a batch mean (nn.CrossEntropyLoss default); weight
            # by batch size so a short final batch is not over-represented.
            batch_count = labels.size(0)
            loss_sum += loss.item() * batch_count
            num_samples += batch_count
            loop.set_postfix(val_loss=loss_sum / max(num_samples, 1))

    if num_samples == 0:
        raise ValueError("val_loader produced no samples; cannot compute a validation loss")
    return loss_sum / num_samples

In [None]:
# Main training loop
for epoch in range(epochs):
    train_loss = train(epoch)
    current_val_loss = validate(epoch)

    # Write this epoch's metrics to TensorBoard. The learning rate is read
    # before scheduler.step() so the logged value is the one used this epoch.
    if train_loss is not None:  # only logged when train() reports a loss
        writer.add_scalar('Loss/train', train_loss, epoch)
    writer.add_scalar('Loss/val', current_val_loss, epoch)
    writer.add_scalar('LearningRate', optimizer.param_groups[0]['lr'], epoch)

    scheduler.step()
    # Check for early stopping: checkpoint on improvement, otherwise count the
    # epoch towards the patience budget.
    if current_val_loss < best_val_loss:
        best_val_loss = current_val_loss
        counter = 0
        torch.save(model.state_dict(), 'best_model.pth')
    else:
        counter += 1
    if counter >= early_stopping_patience:
        print("Early stopping triggered")
        break

Epoch [1/10]:  35%|███▍      | 245/704 [09:09<17:30,  2.29s/it, loss=4.26]

In [None]:
# Testing
# Restore the best checkpoint. map_location remaps the saved tensors onto the
# current device, so a checkpoint written on GPU still loads on a CPU-only machine.
model.load_state_dict(torch.load('best_model.pth', map_location=device))
model.eval()
correct = 0   # number of correctly classified test samples
total = 0     # number of test samples seen
with torch.no_grad():
    for inputs, labels in test_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        # Index of the largest output along dim 1 is the predicted class.
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

test_accuracy = 100 * correct / total
print(f'Accuracy of the network on the 10000 test images: {test_accuracy}%')

# Record the final metric alongside the training curves, then flush and close.
writer.add_scalar('Accuracy/test', test_accuracy)
writer.close()