# CE + CL VIT-TINY-Patch SIZE 2

In [2]:
import os
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms
from transformers import (
    ViTConfig,
    ViTPreTrainedModel,
    ViTModel,
    get_cosine_schedule_with_warmup
)
from tqdm.auto import tqdm
import random

# Setup & Seeding for reproducibility
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
use_dp = torch.cuda.device_count() > 1
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)

mean = (0.4914, 0.4822, 0.4465)
std  = (0.2023, 0.1994, 0.2010)

# Initial transform to split dataset
init_tf = transforms.ToTensor()
dataset_full = datasets.CIFAR10('./data', train=True, download=True, transform=init_tf)
val_size = int(0.1 * len(dataset_full))
train_size = len(dataset_full) - val_size
train_ds, val_ds = torch.utils.data.random_split(
    dataset_full,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(42)
)

batch_size = 128

# Precompute DataLoaders for each resolution
stages = [(r) for r in [(12, 10), (16, 10), (20, 10), (24, 10), (28, 10), (32, 10)]]
dataloader_dict = {}

for resolution, _ in stages:
    train_tf = transforms.Compose([
        transforms.RandomCrop(resolution, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    val_tf = transforms.Compose([
        transforms.Resize(resolution),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    train_set = datasets.CIFAR10('./data', train=True, download=False, transform=train_tf)
    val_set   = datasets.CIFAR10('./data', train=True, download=False, transform=val_tf)

    train_loader = DataLoader(Subset(train_set, train_ds.indices), batch_size=batch_size,
                              shuffle=True, num_workers=0, pin_memory=True)
    val_loader = DataLoader(Subset(val_set, val_ds.indices), batch_size=batch_size,
                            shuffle=False, num_workers=0, pin_memory=True)

    dataloader_dict[resolution] = {
        'train': train_loader,
        'val': val_loader
    }

# Test loader (fixed resolution)
test_tf = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])
test_ds = datasets.CIFAR10('./data', train=False, transform=test_tf)
test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)

# # ViT-Tiny Classifier
class ViTTinyClassifier(ViTPreTrainedModel):
    def __init__(self, config: ViTConfig):
        super().__init__(config)
        self.vit = ViTModel(config)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.init_weights()

    def forward(self, pixel_values, labels=None):
        outputs = self.vit(
            pixel_values,
            interpolate_pos_encoding=True,
            return_dict=True
        )
        cls_output = outputs.last_hidden_state[:, 0]
        logits = self.classifier(cls_output)
        loss = None
        if labels is not None:
            loss = F.cross_entropy(logits, labels)
        return {"logits": logits, "loss": loss}

# # Instantiate model
tiny_config = ViTConfig(
    hidden_size=192,
    num_hidden_layers=12,
    num_attention_heads=3,
    intermediate_size=768,
    patch_size=2,
    image_size=32,
    num_labels=10,
    hidden_dropout_prob=0.1,
    classifier_dropout_prob=0.1
)
student = ViTTinyClassifier(tiny_config).to(device)
student = nn.DataParallel(student)

no_decay = ["bias", "LayerNorm.weight"]
decay_params = [p for n,p in student.named_parameters() if not any(nd in n for nd in no_decay)]
nodecay_params = [p for n,p in student.named_parameters() if     any(nd in n for nd in no_decay)]

optimizer = optim.AdamW([
    {"params": decay_params,   "weight_decay": 1e-4},
    {"params": nodecay_params, "weight_decay": 0.0},
], lr=3e-4)


# Evaluation function
def eval_model(model, loader):
    model.eval()
    correct = total = 0
    with torch.no_grad():
        for imgs, labels in loader:
            imgs, labels = imgs.to(device), labels.to(device)
            outputs = model(pixel_values=imgs)
            preds = outputs["logits"].argmax(1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    return 100 * correct / total

# Curriculum Training
for resolution, epochs in stages:
    print(f"\n=== Training at resolution {resolution}px ===")
    tr_loader = dataloader_dict[resolution]['train']
    vl_loader = dataloader_dict[resolution]['val']

    best_val_acc = 0
    no_improve = 0
    patience = 3

    for epoch in range(1, epochs + 1):
        student.train()
        train_correct = train_total = 0
        for imgs, labels in tqdm(tr_loader, desc=f"Epoch {epoch}/{epochs}"):
            imgs, labels = imgs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = student(pixel_values=imgs, labels=labels,)
            loss = outputs["loss"]
            if loss is not None and loss.dim() > 0:
                loss = loss.mean()
            loss.backward()
            optimizer.step()

            preds = outputs["logits"].argmax(1)
            train_correct += (preds == labels).sum().item()
            train_total += labels.size(0)

        train_acc = 100 * train_correct / train_total
        val_acc = eval_model(student, vl_loader)
        print(f"{resolution}px Ep{epoch}: Train {train_acc:.2f}%  Val {val_acc:.2f}%")

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            no_improve = 0
            torch.save(student.state_dict(), f'best_tiny_{resolution}px.pth')
        else:
            no_improve += 1
            if no_improve >= patience:
                print(f"→ Early stopping at resolution {resolution}px")
                student.load_state_dict(torch.load(f'best_tiny_{resolution}px.pth'))
                break

    print(f"=> Stage {resolution}px best Val = {best_val_acc:.2f}%")

# Final test evaluation
student.load_state_dict(torch.load('best_tiny_32px.pth'))
final_test_acc = eval_model(student, test_loader)
print(f"\n🌟 Final ViT-Tiny Test Acc: {final_test_acc:.2f}%")

Files already downloaded and verified

=== Training at resolution 12px ===


Epoch 1/10:   0%|          | 0/352 [00:00<?, ?it/s]

12px Ep1: Train 20.76%  Val 26.08%


Epoch 2/10:   0%|          | 0/352 [00:00<?, ?it/s]

12px Ep2: Train 24.38%  Val 30.32%


Epoch 3/10:   0%|          | 0/352 [00:00<?, ?it/s]

12px Ep3: Train 26.45%  Val 33.10%


Epoch 4/10:   0%|          | 0/352 [00:00<?, ?it/s]

12px Ep4: Train 27.92%  Val 30.68%


Epoch 5/10:   0%|          | 0/352 [00:00<?, ?it/s]

12px Ep5: Train 29.02%  Val 33.74%


Epoch 6/10:   0%|          | 0/352 [00:00<?, ?it/s]

12px Ep6: Train 29.65%  Val 33.66%


Epoch 7/10:   0%|          | 0/352 [00:00<?, ?it/s]

12px Ep7: Train 30.59%  Val 35.32%


Epoch 8/10:   0%|          | 0/352 [00:00<?, ?it/s]

12px Ep8: Train 30.79%  Val 36.14%


Epoch 9/10:   0%|          | 0/352 [00:00<?, ?it/s]

12px Ep9: Train 31.70%  Val 34.30%


Epoch 10/10:   0%|          | 0/352 [00:00<?, ?it/s]

12px Ep10: Train 32.14%  Val 36.10%
=> Stage 12px best Val = 36.14%

=== Training at resolution 16px ===


Epoch 1/10:   0%|          | 0/352 [00:00<?, ?it/s]

16px Ep1: Train 36.76%  Val 40.26%


Epoch 2/10:   0%|          | 0/352 [00:00<?, ?it/s]

16px Ep2: Train 38.17%  Val 40.64%


Epoch 3/10:   0%|          | 0/352 [00:00<?, ?it/s]

16px Ep3: Train 38.74%  Val 42.20%


Epoch 4/10:   0%|          | 0/352 [00:00<?, ?it/s]

16px Ep4: Train 39.77%  Val 42.76%


Epoch 5/10:   0%|          | 0/352 [00:00<?, ?it/s]

16px Ep5: Train 39.79%  Val 41.50%


Epoch 6/10:   0%|          | 0/352 [00:00<?, ?it/s]

16px Ep6: Train 40.56%  Val 43.24%


Epoch 7/10:   0%|          | 0/352 [00:00<?, ?it/s]

16px Ep7: Train 40.74%  Val 42.30%


Epoch 8/10:   0%|          | 0/352 [00:00<?, ?it/s]

16px Ep8: Train 41.55%  Val 43.98%


Epoch 9/10:   0%|          | 0/352 [00:00<?, ?it/s]

16px Ep9: Train 41.48%  Val 43.10%


Epoch 10/10:   0%|          | 0/352 [00:00<?, ?it/s]

16px Ep10: Train 42.06%  Val 44.74%
=> Stage 16px best Val = 44.74%

=== Training at resolution 20px ===


Epoch 1/10:   0%|          | 0/352 [00:00<?, ?it/s]

20px Ep1: Train 47.34%  Val 49.72%


Epoch 2/10:   0%|          | 0/352 [00:00<?, ?it/s]

20px Ep2: Train 47.79%  Val 48.58%


Epoch 3/10:   0%|          | 0/352 [00:00<?, ?it/s]

20px Ep3: Train 48.46%  Val 51.46%


Epoch 4/10:   0%|          | 0/352 [00:00<?, ?it/s]

20px Ep4: Train 48.85%  Val 50.04%


Epoch 5/10:   0%|          | 0/352 [00:00<?, ?it/s]

20px Ep5: Train 49.42%  Val 50.54%


Epoch 6/10:   0%|          | 0/352 [00:00<?, ?it/s]

20px Ep6: Train 49.96%  Val 51.62%


Epoch 7/10:   0%|          | 0/352 [00:00<?, ?it/s]

20px Ep7: Train 50.62%  Val 50.72%


Epoch 8/10:   0%|          | 0/352 [00:00<?, ?it/s]

20px Ep8: Train 50.83%  Val 52.74%


Epoch 9/10:   0%|          | 0/352 [00:00<?, ?it/s]

20px Ep9: Train 51.58%  Val 53.30%


Epoch 10/10:   0%|          | 0/352 [00:00<?, ?it/s]

20px Ep10: Train 51.99%  Val 54.84%
=> Stage 20px best Val = 54.84%

=== Training at resolution 24px ===


Epoch 1/10:   0%|          | 0/352 [00:00<?, ?it/s]

24px Ep1: Train 56.94%  Val 55.76%


Epoch 2/10:   0%|          | 0/352 [00:00<?, ?it/s]

24px Ep2: Train 57.81%  Val 58.18%


Epoch 3/10:   0%|          | 0/352 [00:00<?, ?it/s]

24px Ep3: Train 58.10%  Val 55.82%


Epoch 4/10:   0%|          | 0/352 [00:00<?, ?it/s]

24px Ep4: Train 59.03%  Val 57.92%


Epoch 5/10:   0%|          | 0/352 [00:00<?, ?it/s]

24px Ep5: Train 59.75%  Val 57.22%
→ Early stopping at resolution 24px
=> Stage 24px best Val = 58.18%

=== Training at resolution 28px ===


  student.load_state_dict(torch.load(f'best_tiny_{resolution}px.pth'))


Epoch 1/10:   0%|          | 0/352 [00:00<?, ?it/s]

28px Ep1: Train 62.01%  Val 60.92%


Epoch 2/10:   0%|          | 0/352 [00:00<?, ?it/s]

28px Ep2: Train 62.74%  Val 62.42%


Epoch 3/10:   0%|          | 0/352 [00:00<?, ?it/s]

28px Ep3: Train 63.77%  Val 61.96%


Epoch 4/10:   0%|          | 0/352 [00:00<?, ?it/s]

28px Ep4: Train 64.76%  Val 63.60%


Epoch 5/10:   0%|          | 0/352 [00:00<?, ?it/s]

28px Ep5: Train 65.08%  Val 64.08%


Epoch 6/10:   0%|          | 0/352 [00:00<?, ?it/s]

28px Ep6: Train 65.85%  Val 64.42%


Epoch 7/10:   0%|          | 0/352 [00:00<?, ?it/s]

28px Ep7: Train 66.26%  Val 64.90%


Epoch 8/10:   0%|          | 0/352 [00:00<?, ?it/s]

28px Ep8: Train 67.28%  Val 65.90%


Epoch 9/10:   0%|          | 0/352 [00:00<?, ?it/s]

28px Ep9: Train 67.68%  Val 64.32%


Epoch 10/10:   0%|          | 0/352 [00:00<?, ?it/s]

28px Ep10: Train 68.29%  Val 65.88%
=> Stage 28px best Val = 65.90%

=== Training at resolution 32px ===


Epoch 1/10:   0%|          | 0/352 [00:00<?, ?it/s]

32px Ep1: Train 71.29%  Val 71.56%


Epoch 2/10:   0%|          | 0/352 [00:00<?, ?it/s]

32px Ep2: Train 71.92%  Val 71.64%


Epoch 3/10:   0%|          | 0/352 [00:00<?, ?it/s]

32px Ep3: Train 72.91%  Val 71.18%


Epoch 4/10:   0%|          | 0/352 [00:00<?, ?it/s]

32px Ep4: Train 73.25%  Val 70.38%


Epoch 5/10:   0%|          | 0/352 [00:00<?, ?it/s]

32px Ep5: Train 74.30%  Val 71.88%


Epoch 6/10:   0%|          | 0/352 [00:00<?, ?it/s]

32px Ep6: Train 74.39%  Val 71.88%


Epoch 7/10:   0%|          | 0/352 [00:00<?, ?it/s]

32px Ep7: Train 75.53%  Val 71.36%


Epoch 8/10:   0%|          | 0/352 [00:00<?, ?it/s]

32px Ep8: Train 75.81%  Val 72.04%


Epoch 9/10:   0%|          | 0/352 [00:00<?, ?it/s]

32px Ep9: Train 76.63%  Val 71.88%


Epoch 10/10:   0%|          | 0/352 [00:00<?, ?it/s]

32px Ep10: Train 76.82%  Val 72.54%
=> Stage 32px best Val = 72.54%


  student.load_state_dict(torch.load('best_tiny_32px.pth'))



🌟 Final ViT-Tiny Test Acc: 71.11%
