In [8]:

import os
# NOTE(review): presumably a workaround for the "duplicate OpenMP runtime"
# abort (OMP Error #15); it is set here before torch is imported.
# TODO confirm it is still required in this environment.
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torchvision import transforms
from datasets import load_dataset
from tqdm import tqdm

In [9]:
# Pick the compute device (Apple MPS first, then CUDA, otherwise CPU) and the
# number of DataLoader worker processes that goes with it.
_mps_ready = torch.backends.mps.is_available() and torch.backends.mps.is_built()

if _mps_ready:
    device, num_workers = torch.device("mps"), 0  # Apple Silicon
    print("Using Apple MPS (Metal GPU)")
elif torch.cuda.is_available():
    device, num_workers = torch.device("cuda"), 0  # NVIDIA GPU
    print("Using CUDA GPU")
else:
    device, num_workers = torch.device("cpu"), 2  # fallback
    print("Using CPU")

Using CUDA GPU


In [10]:
# Load the ImageNet-100 dataset from the Hugging Face Hub.
dataset = load_dataset(path="clane9/imagenet-100")

In [11]:
# Preprocessing: resize + center-crop to 224, convert to a tensor, then
# normalize each channel with the ImageNet-style mean / std below.
_normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)

transform = transforms.Compose(
    [
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        _normalize,
    ]
)

# --------------------------
# Custom PyTorch Dataset
# --------------------------
class ImageNet100Dataset(Dataset):
    """Expose an indexable collection of ``{"image", "label"}`` records as a
    PyTorch ``Dataset`` yielding ``(image, label)`` pairs.

    Args:
        hf_dataset: Indexable, sized collection whose items are mappings with
            an ``"image"`` entry (an object with ``.mode`` and ``.convert``,
            presumably a PIL image) and a ``"label"`` entry.
        transform: Optional callable applied to the RGB image before it is
            returned.
    """

    def __init__(self, hf_dataset, transform=None):
        self.dataset = hf_dataset
        self.transform = transform

    def __len__(self):
        """Number of records in the wrapped collection."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the record at position ``idx``."""
        record = self.dataset[idx]

        # Force three channels so every sample has the same layout.
        image = record["image"]
        image = image if image.mode == "RGB" else image.convert("RGB")

        if self.transform:
            image = self.transform(image)

        return image, record["label"]


# --------------------------
# Datasets and DataLoaders
# --------------------------
# Both splits go through the same preprocessing pipeline.
train_dataset = ImageNet100Dataset(dataset["train"], transform=transform)
val_dataset = ImageNet100Dataset(dataset["validation"], transform=transform)

# Settings shared by both loaders; only the training split is shuffled.
_loader_kwargs = dict(batch_size=16, num_workers=num_workers)

train_loader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_kwargs)


In [12]:
# Build ViT-B/16 (patch 16, 224x224 input) from timm with pretrained weights
# and a 100-class head, then move it to the selected device.

import timm

model = timm.create_model("vit_base_patch16_224", pretrained=True, num_classes=100)
model = model.to(device)

In [13]:
# Loss & Optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=1e-4)

# Use AMP for NVIDIA; MPS does not fully support AMP.
# When AMP is off, no gradient scaler is created and `scaler` is None.
use_amp = device.type == "cuda"

# `torch.cuda.amp.GradScaler()` is deprecated and emits a FutureWarning on
# current PyTorch; `torch.amp.GradScaler` takes the device type explicitly.
scaler = torch.amp.GradScaler("cuda") if use_amp else None

  scaler = torch.cuda.amp.GradScaler() if use_amp else None


In [14]:
# Training loop
# Trains `model` for `epochs` passes over `train_loader` and writes one
# checkpoint (model + optimizer state, epoch number, mean loss) per epoch.
#
# NOTE(review): the recorded output of this cell shows 10 epochs ("[1/10]"),
# while the code trains for 3. Re-run the cell, or restore the value used for
# the saved results, so that code and output agree.
epochs = 3
os.makedirs("checkpoints", exist_ok=True)

for epoch in range(epochs):
    model.train()
    running_loss = 0.0
    loop = tqdm(train_loader, desc=f"Epoch [{epoch+1}/{epochs}]")

    # Count batches explicitly: tqdm only refreshes `loop.n` at its display
    # intervals, so `loop.n + 1` is not a reliable divisor for the running
    # average shown in the progress bar.
    for step, (imgs, labels) in enumerate(loop, start=1):
        imgs, labels = imgs.to(device), labels.to(device)

        optimizer.zero_grad()

        if use_amp:
            # `torch.cuda.amp.autocast()` is deprecated (FutureWarning);
            # `torch.autocast` is the device-agnostic replacement.
            with torch.autocast(device_type="cuda"):
                outputs = model(imgs)
                loss = criterion(outputs, labels)
            # Scale the loss before backward, then step/update via the scaler.
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            outputs = model(imgs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

        running_loss += loss.item()
        loop.set_postfix(loss=running_loss / step)

    avg_loss = running_loss / len(train_loader)
    print(f"Epoch {epoch+1} finished, Avg Loss: {avg_loss:.4f}")

    # ✅ SAVE CHECKPOINT
    checkpoint = {
        "epoch": epoch + 1,
        "model_state": model.state_dict(),
        "optimizer_state": optimizer.state_dict(),
        "loss": avg_loss,
    }

    torch.save(
        checkpoint,
        f"checkpoints/vit_b16_epoch_{epoch+1}.pth"
    )


  with torch.cuda.amp.autocast():
Epoch [1/10]: 100%|██████████| 7919/7919 [33:45<00:00,  3.91it/s, loss=0.888]


Epoch 1 finished, Avg Loss: 0.8883


Epoch [2/10]: 100%|██████████| 7919/7919 [31:41<00:00,  4.17it/s, loss=0.595]


Epoch 2 finished, Avg Loss: 0.5954


Epoch [3/10]: 100%|██████████| 7919/7919 [30:59<00:00,  4.26it/s, loss=0.482]


Epoch 3 finished, Avg Loss: 0.4819


Epoch [4/10]: 100%|██████████| 7919/7919 [30:57<00:00,  4.26it/s, loss=0.407]


Epoch 4 finished, Avg Loss: 0.4074


Epoch [5/10]: 100%|██████████| 7919/7919 [30:59<00:00,  4.26it/s, loss=0.343]


Epoch 5 finished, Avg Loss: 0.3427


Epoch [6/10]: 100%|██████████| 7919/7919 [33:07<00:00,  3.98it/s, loss=0.302]


Epoch 6 finished, Avg Loss: 0.3020


Epoch [7/10]: 100%|██████████| 7919/7919 [33:52<00:00,  3.90it/s, loss=0.266]


Epoch 7 finished, Avg Loss: 0.2664


Epoch [8/10]: 100%|██████████| 7919/7919 [33:43<00:00,  3.91it/s, loss=0.238] 


Epoch 8 finished, Avg Loss: 0.2383


Epoch [9/10]: 100%|██████████| 7919/7919 [32:44<00:00,  4.03it/s, loss=0.214]


Epoch 9 finished, Avg Loss: 0.2139


Epoch [10/10]: 100%|██████████| 7919/7919 [33:00<00:00,  4.00it/s, loss=0.194]


Epoch 10 finished, Avg Loss: 0.1937


In [15]:
# Persist the fine-tuned weights (state_dict only) next to the notebook.

_final_state = model.state_dict()
torch.save(_final_state, "vit_b16_imagenet100.pth")
print("Model saved as vit_b16_imagenet100.pth")


Model saved as vit_b16_imagenet100.pth


In [16]:
# Evaluation Function

def evaluate(model, loader, device=None):
    """Return the top-1 accuracy of ``model`` over ``loader``, in percent.

    The model is switched to eval mode (and left in it) and gradients are
    disabled while iterating.

    Args:
        model: Module called as ``model(imgs)``; the prediction for each
            sample is the argmax over dimension 1 of its output.
        loader: Iterable of ``(imgs, labels)`` tensor batches.
        device: Device the batches are moved to. When ``None`` it is taken
            from the model's first parameter (CPU if the model has no
            parameters), so the function does not rely on a global ``device``.

    Returns:
        ``100 * correct / total`` as a float, or ``0.0`` when the loader
        yields no samples (instead of raising ``ZeroDivisionError``).
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    correct = total = 0
    with torch.no_grad():
        for imgs, labels in loader:
            imgs, labels = imgs.to(device), labels.to(device)
            preds = model(imgs).argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

    # Accuracy is undefined for an empty loader; report 0.0 rather than crash.
    if total == 0:
        return 0.0
    return 100 * correct / total

# Top-1 accuracy of the fine-tuned model on the validation split.
baseline_acc = evaluate(model=model, loader=val_loader)
print(f"Baseline Accuracy on ImageNet-100: {baseline_acc:.2f}%")

Baseline Accuracy on ImageNet-100: 77.86%
