In [1]:
!pip install timm torchvision
!pip install ray[tune]

Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch->timm)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch->timm)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch->timm)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch->timm)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch->timm)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch->timm)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.5.147 (from torch->tim

In [9]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torchvision.datasets import CIFAR100
from torch.utils.data import DataLoader, random_split
import timm
import data_preprocessing
from torch.cuda.amp import autocast, GradScaler
import torch.nn as nn
import torch.optim as optim

In [3]:
BATCH_SIZE = 64
EPOCHS = 20
VAL_SPLIT = 0.1
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
SEED = 42

In [4]:
# Reproducibility
torch.manual_seed(SEED)

<torch._C.Generator at 0x7f884c09c430>

In [5]:
pipeline = data_preprocessing.CIFAR100Pipeline(val_split=VAL_SPLIT, use_augment=True)
trainset, valset, testset = pipeline.run_pipeline()

100%|██████████| 169M/169M [00:06<00:00, 26.6MB/s]


In [6]:
trainloader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True)
valloader = DataLoader(valset, batch_size=BATCH_SIZE)
testloader = DataLoader(testset, batch_size=BATCH_SIZE)

In [7]:
# Create model
def create_dino_vit_s16_for_cifar100(freezing=True):
    model = timm.create_model("vit_small_patch16_224_dino", pretrained=True, num_classes=0)

    # Replace the head with CIFAR-100 classification head
    model.head = nn.Linear(model.num_features, 100)

    if freezing:
      # Freeze all parameters except head
      for param in model.parameters():
          param.requires_grad = False

      # Unfreeze only the head
      for param in model.head.parameters():
          param.requires_grad = True

    return model

model = create_dino_vit_s16_for_cifar100(False).to(DEVICE)

  model = create_fn(
The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


model.safetensors:   0%|          | 0.00/86.7M [00:00<?, ?B/s]

In [8]:
print(next(model.parameters()).device)
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"Trainable params: {trainable:,} / {total:,}")

torch.backends.cudnn.benchmark = True

cuda:0
Trainable params: 21,704,164 / 21,704,164


In [None]:
# Loss, optimizer, scheduler, and scaler
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.head.parameters(), lr=0.01, momentum=0.9, weight_decay=0.0005)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)
scaler = GradScaler()

# Early stopping parameters
patience = 2
best_val_acc = 0.0
epochs_no_improve = 0

# Training loop
for epoch in range(EPOCHS):
    model.train()
    correct, total, train_loss = 0, 0, 0.0

    for x, y in trainloader:
        x, y = x.to(DEVICE), y.to(DEVICE)
        optimizer.zero_grad()

        with autocast():
            outputs = model(x)
            loss = criterion(outputs, y)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        train_loss += loss.item() * y.size(0)
        _, pred = torch.max(outputs, 1)
        correct += (pred == y).sum().item()
        total += y.size(0)

    scheduler.step()

    train_acc = correct / total
    train_loss /= total

    # Validation
    model.eval()
    correct, total, val_loss = 0, 0, 0.0
    with torch.no_grad():
        for x, y in valloader:
            x, y = x.to(DEVICE), y.to(DEVICE)
            outputs = model(x)
            loss = criterion(outputs, y)

            val_loss += loss.item() * y.size(0)
            _, pred = torch.max(outputs, 1)
            correct += (pred == y).sum().item()
            total += y.size(0)

    val_acc = correct / total
    val_loss /= total

    print(f"Epoch {epoch+1:02d}/{EPOCHS} — Train Acc: {train_acc:.4f} | Val Acc: {val_acc:.4f}")

    # Early stopping logic
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        epochs_no_improve = 0
        best_model_state = model.state_dict()  # save best model
    else:
        epochs_no_improve += 1
        if epochs_no_improve >= patience:
            print(f"Early stopping triggered at epoch {epoch+1}")
            break

model.load_state_dict(best_model_state)

  scaler = GradScaler()
  with autocast():


Epoch 01/20 — Train Acc: 0.5744 | Val Acc: 0.6188
Epoch 02/20 — Train Acc: 0.6757 | Val Acc: 0.6490
Epoch 03/20 — Train Acc: 0.7086 | Val Acc: 0.6684
Epoch 04/20 — Train Acc: 0.7326 | Val Acc: 0.6772
Epoch 05/20 — Train Acc: 0.7441 | Val Acc: 0.6836
Epoch 06/20 — Train Acc: 0.7604 | Val Acc: 0.6832
Epoch 07/20 — Train Acc: 0.7741 | Val Acc: 0.6960
Epoch 08/20 — Train Acc: 0.7904 | Val Acc: 0.6936
Epoch 09/20 — Train Acc: 0.8032 | Val Acc: 0.6988
Epoch 10/20 — Train Acc: 0.8192 | Val Acc: 0.7054
Epoch 11/20 — Train Acc: 0.8332 | Val Acc: 0.7184
Epoch 12/20 — Train Acc: 0.8461 | Val Acc: 0.7222
Epoch 13/20 — Train Acc: 0.8620 | Val Acc: 0.7230
Epoch 14/20 — Train Acc: 0.8764 | Val Acc: 0.7254
Epoch 15/20 — Train Acc: 0.8918 | Val Acc: 0.7286
Epoch 16/20 — Train Acc: 0.9058 | Val Acc: 0.7316
Epoch 17/20 — Train Acc: 0.9164 | Val Acc: 0.7350
Epoch 18/20 — Train Acc: 0.9243 | Val Acc: 0.7300
Epoch 19/20 — Train Acc: 0.9308 | Val Acc: 0.7318
Early stopping triggered at epoch 19


<All keys matched successfully>

In [None]:
#  test
model.eval()
correct, total, test_loss = 0, 0, 0.0
with torch.no_grad():
    for x, y in testloader:
        x, y = x.to(DEVICE), y.to(DEVICE)
        outputs = model(x)
        loss = criterion(outputs, y)

        test_loss += loss.item() * y.size(0)
        _, pred = torch.max(outputs, 1)
        correct += (pred == y).sum().item()
        total += y.size(0)

test_acc = correct / total
test_loss /= total

print(f"\n Final Test Accuracy: {test_acc:.4f} | Test Loss: {test_loss:.4f}")


 Final Test Accuracy: 0.7395 | Test Loss: 1.6905


In [10]:
from ray import tune

search_space = {
    "lr": tune.loguniform(1e-5, 1e-2),
    "momentum": tune.uniform(0.7, 0.99),
    "weight_decay": tune.loguniform(1e-6, 1e-3),
    "batch_size": tune.choice([32, 64, 128]),
}

In [11]:
from torchvision import transforms

def get_cifar_transform() -> transforms.Compose:
    return transforms.Compose([
        transforms.Resize((224, 224)),  # Resize CIFAR images to 224x224
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],  # ImageNet means
            std=[0.229, 0.224, 0.225]    # ImageNet stds
        )
    ])


In [24]:
from ray.train import report
def train_vit(config):
    # Create model
    model = create_dino_vit_s16_for_cifar100().to(DEVICE)
    criterion = nn.CrossEntropyLoss()

    # Optimizer, scheduler, AMP
    optimizer = optim.SGD(
        model.parameters(),
        lr=config["lr"],
        momentum=config["momentum"],
        weight_decay=config["weight_decay"]
    )
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=20)
    scaler = torch.cuda.amp.GradScaler()

    # Load data using your custom pipeline (w/ or w/o augmentation)
    pipeline = data_preprocessing.CIFAR100Pipeline(val_split=VAL_SPLIT, use_augment=True)
    trainset, valset, testset = pipeline.run_pipeline()

    trainloader = DataLoader(trainset, batch_size=config["batch_size"], shuffle=True)
    valloader = DataLoader(valset, batch_size=config["batch_size"])
    testloader = DataLoader(testset, batch_size=config["batch_size"])

    best_val_acc = 0.0

    for epoch in range(20):
        model.train()
        correct, total, train_loss = 0, 0, 0.0

        for x, y in trainloader:
            x, y = x.to(DEVICE), y.to(DEVICE)
            optimizer.zero_grad()
            with torch.cuda.amp.autocast():
                outputs = model(x)
                loss = criterion(outputs, y)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            train_loss += loss.item() * y.size(0)
            _, pred = outputs.max(1)
            correct += (pred == y).sum().item()
            total += y.size(0)

        scheduler.step()
        train_acc = correct / total

        # Validation
        model.eval()
        correct, total, val_loss = 0, 0, 0.0
        with torch.no_grad():
            for x, y in valloader:
                x, y = x.to(DEVICE), y.to(DEVICE)
                outputs = model(x)
                loss = criterion(outputs, y)
                val_loss += loss.item() * y.size(0)
                _, pred = outputs.max(1)
                correct += (pred == y).sum().item()
                total += y.size(0)

        val_acc = correct / total

        if val_acc > best_val_acc:
            best_val_acc = val_acc

    report({"val_accuracy": val_acc, "train_accuracy": train_acc})

In [25]:
from ray.tune.schedulers import ASHAScheduler
from ray.tune.search.basic_variant import BasicVariantGenerator
import os

results_dir = os.path.abspath("ray_results")
storage_uri = f"file://{results_dir}"

analysis = tune.run(
    train_vit,
    config=search_space,
    storage_path=storage_uri,
    search_alg=BasicVariantGenerator(),
    num_samples=10,
    resources_per_trial={"cpu": 2, "gpu": 1},
    scheduler=ASHAScheduler(metric="val_accuracy", mode="max"),
    name="vit_hyperparam_search"
)


+----------------------------------------------------------+
| Configuration for experiment     vit_hyperparam_search   |
+----------------------------------------------------------+
| Search algorithm                 BasicVariantGenerator   |
| Scheduler                        AsyncHyperBandScheduler |
| Number of trials                 10                      |
+----------------------------------------------------------+

View detailed results here: /content/ray_results/vit_hyperparam_search
To visualize your results with TensorBoard, run: `tensorboard --logdir /tmp/ray/session_2025-04-17_08-25-35_121040_820/artifacts/2025-04-17_08-36-11/vit_hyperparam_search/driver_artifacts`

Trial status: 10 PENDING
Current time: 2025-04-17 08:36:11. Total running time: 0s
Logical resource usage: 0/2 CPUs, 0/1 GPUs (0.0/1.0 accelerator_type:T4)
+---------------------------------------------------------------------------------------------+
| Trial name              status              lr     moment

[36m(train_vit pid=5427)[0m   model = create_fn(
  0%|          | 0.00/169M [00:00<?, ?B/s]
  0%|          | 65.5k/169M [00:00<08:00, 352kB/s]
  0%|          | 229k/169M [00:00<04:15, 660kB/s] 
  0%|          | 754k/169M [00:00<01:20, 2.08MB/s]
  1%|          | 1.41M/169M [00:00<00:48, 3.45MB/s]
  2%|▏         | 3.24M/169M [00:00<00:20, 7.98MB/s]
  3%|▎         | 5.47M/169M [00:00<00:13, 12.3MB/s]
  5%|▌         | 8.68M/169M [00:00<00:08, 18.3MB/s]
  7%|▋         | 12.4M/169M [00:00<00:06, 23.8MB/s]
  9%|▉         | 15.7M/169M [00:01<00:06, 22.7MB/s]
 11%|█▏        | 19.1M/169M [00:01<00:05, 25.8MB/s]
 13%|█▎        | 22.6M/169M [00:01<00:05, 28.3MB/s]
 15%|█▌        | 26.1M/169M [00:01<00:04, 30.2MB/s]
 17%|█▋        | 29.6M/169M [00:01<00:04, 31.4MB/s]
 20%|█▉        | 33.1M/169M [00:01<00:04, 31.6MB/s]
 22%|██▏       | 36.5M/169M [00:01<00:04, 32.5MB/s]
 24%|██▎       | 40.0M/169M [00:01<00:03, 33.0MB/s]
 26%|██▌       | 43.5M/169M [00:01<00:03, 33.5MB/s]
 28%|██▊       | 47.0M/16


Trial status: 1 RUNNING | 9 PENDING
Current time: 2025-04-17 08:36:41. Total running time: 30s
Logical resource usage: 2.0/2 CPUs, 0.25/1 GPUs (0.0/1.0 accelerator_type:T4)
+---------------------------------------------------------------------------------------------+
| Trial name              status              lr     momentum     weight_decay     batch_size |
+---------------------------------------------------------------------------------------------+
| train_vit_045fe_00000   RUNNING    2.5067e-05      0.715299      0.000359066             64 |
| train_vit_045fe_00001   PENDING    0.00132097      0.704371      5.66508e-05             32 |
| train_vit_045fe_00002   PENDING    1.09955e-05     0.900448      7.24943e-06            128 |
| train_vit_045fe_00003   PENDING    5.79904e-05     0.865917      1.18338e-05            128 |
| train_vit_045fe_00004   PENDING    0.00585614      0.781448      0.000790279            128 |
| train_vit_045fe_00005   PENDING    0.00235607      0.755

2025-04-17 08:37:17,727	INFO tune.py:1009 -- Wrote the latest version of all result files and experiment state to '/content/ray_results/vit_hyperparam_search' in 0.0062s.


Trial status: 1 RUNNING | 9 PENDING
Current time: 2025-04-17 08:37:17. Total running time: 1min 5s
Logical resource usage: 2.0/2 CPUs, 0.25/1 GPUs (0.0/1.0 accelerator_type:T4)
+---------------------------------------------------------------------------------------------+
| Trial name              status              lr     momentum     weight_decay     batch_size |
+---------------------------------------------------------------------------------------------+
| train_vit_045fe_00000   RUNNING    2.5067e-05      0.715299      0.000359066             64 |
| train_vit_045fe_00001   PENDING    0.00132097      0.704371      5.66508e-05             32 |
| train_vit_045fe_00002   PENDING    1.09955e-05     0.900448      7.24943e-06            128 |
| train_vit_045fe_00003   PENDING    5.79904e-05     0.865917      1.18338e-05            128 |
| train_vit_045fe_00004   PENDING    0.00585614      0.781448      0.000790279            128 |
| train_vit_045fe_00005   PENDING    0.00235607      0.

KeyboardInterrupt: 