In [1]:
!pip install torch torchvision wandb transformers huggingface_hub datasets gradio




In [12]:
import os, random, zipfile, urllib.request
import torch, torch.nn as nn, torch.optim as optim
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms, models
import numpy as np
from collections import Counter
from tqdm import tqdm
from PIL import Image
print("PyTorch version:", torch.__version__)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)


PyTorch version: 2.8.0+cu126


In [11]:

if not os.path.exists("tiny-imagenet-200"):
    url = "http://cs231n.stanford.edu/tiny-imagenet-200.zip"
    urllib.request.urlretrieve(url, "tiny-imagenet-200.zip")
    with zipfile.ZipFile("tiny-imagenet-200.zip", "r") as zip_ref:
        zip_ref.extractall(".")
    print("Downloaded and extracted Tiny ImageNet.")
else:
    print("Dataset already available.")



Dataset already available.


In [33]:
import os
import random
import numpy as np
import pandas as pd
from torchvision import datasets, transforms
from torchvision.datasets.folder import default_loader
from torch.utils.data import DataLoader, Subset, Dataset


data_dir = "tiny-imagenet-200"
train_dir = os.path.join(data_dir, "train")
val_dir = os.path.join(data_dir, "val")


normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])

train_transforms = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3),
    transforms.ToTensor(),
    normalize
])

val_transforms = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    normalize
])

train_data_full = datasets.ImageFolder(train_dir, transform=train_transforms)


val_annotations = os.path.join(val_dir, "val_annotations.txt")
df_val = pd.read_csv(val_annotations, sep="\t", header=None,
                     names=["file", "class", "x1", "y1", "x2", "y2"])
val_class_map = dict(zip(df_val["file"], df_val["class"]))


val_images = []
for fname in os.listdir(os.path.join(val_dir, "images")):
    if fname in val_class_map:
        wnid = val_class_map[fname]
        if wnid in train_data_full.class_to_idx:
            label = train_data_full.class_to_idx[wnid]
            img_path = os.path.join(val_dir, "images", fname)
            val_images.append((img_path, label))

print(f" Validation images remapped: {len(val_images)} samples across {len(set(val_class_map.values()))} classes.")


class TinyImageNetValDataset(Dataset):
    def __init__(self, samples, transform=None, loader=default_loader):
        self.samples = samples
        self.transform = transform
        self.loader = loader
    def __len__(self):
        return len(self.samples)
    def __getitem__(self, idx):
        path, label = self.samples[idx]
        img = self.loader(path)
        if self.transform:
            img = self.transform(img)
        return img, label

val_data_full = TinyImageNetValDataset(val_images, transform=val_transforms)


def balanced_subset(dataset, n_per_class):
    """
    Returns a subset with up to n_per_class samples per class.
    Works for both train and val datasets.
    """
    if hasattr(dataset, "samples"):
        labels = [lbl for _, lbl in dataset.samples]
    elif hasattr(dataset, "samples_list"):
        labels = [lbl for _, lbl in dataset.samples_list]
    else:

        labels = [lbl for _, lbl in dataset.samples]

    unique_labels = sorted(set(labels))
    indices = []
    for c in unique_labels:
        c_indices = [i for i, l in enumerate(labels) if l == c]
        choose = min(n_per_class, len(c_indices))
        indices.extend(random.sample(c_indices, choose))
    return Subset(dataset, indices)



n_per_class_train = 100  # 200 * 200 = 40,000 training samples
n_per_class_val   = 7     # 7 * 200 = 1,400 validation samples

train_data = balanced_subset(train_data_full, n_per_class_train)
val_data   = balanced_subset(val_data_full, n_per_class_val)

train_loader = DataLoader(train_data, batch_size=64, shuffle=True, num_workers=2)
val_loader   = DataLoader(val_data, batch_size=64, shuffle=False, num_workers=2)

print(f"Using balanced subset: {len(train_data)} train samples, {len(val_data)} val samples.")


 Validation images remapped: 10000 samples across 200 classes.
Using balanced subset: 20000 train samples, 1400 val samples.


In [34]:
# Map wnid -> human name via words.txt
words_path = os.path.join(data_dir, "words.txt")
id_to_name = {}
with open(words_path) as f:
    for line in f:
        wnid, name = line.strip().split("\t")
        id_to_name[wnid] = name.split(",")[0]

class_ids = train_data_full.classes
# Count occurrences in balanced subset
subset_labels = [train_data_full.samples[i][1] for i in train_data.indices]
from collections import Counter
cnt = Counter(subset_labels)
top5 = cnt.most_common(5)
print("Top 5 classes in subset (readable name, wnid, count):")
for idx, c in top5:
    wnid = class_ids[idx]
    print(f"{id_to_name.get(wnid, wnid)} ({wnid}): {c}")


Top 5 classes in subset (readable name, wnid, count):
goldfish (n01443537): 100
European fire salamander (n01629819): 100
bullfrog (n01641577): 100
tailed frog (n01644900): 100
American alligator (n01698640): 100


In [53]:
import wandb

wandb.init(project="TinyImageNet-ResNet", name="resnet34_finetune_balanced_augmented", reinit=True)
wandb.config.update({
    "model": "resnet34-pretrained",
    "dataset": "tiny-imagenet-200 (balanced subset, strong aug)",
    "train_samples": len(train_data),
    "val_samples": len(val_data),
    "batch_size": 64,
    "n_per_class_train": n_per_class_train,
    "n_per_class_val": n_per_class_val,
    "optimizer": "AdamW",
    "lr_warmup": 1e-3,
    "lr_finetune": 5e-5,
    "weight_decay": 1e-4,
    "epochs": 7
})


0,1
baseline_accuracy,▁
drifted_accuracy,▁

0,1
baseline_accuracy,0.69357
drifted_accuracy,0.2147


[34m[1mwandb[0m: Detected [huggingface_hub.inference, mcp] in use.
[34m[1mwandb[0m: Use W&B Weave for improved LLM call tracing. Install Weave with `pip install weave` then add `import weave` to the top of your script.
[34m[1mwandb[0m: For more information, check out the docs at: https://weave-docs.wandb.ai/


In [54]:
config_text = "\n".join([f"{k}: {v}" for k, v in wandb.config.items()])
wandb.log({"training_config": wandb.Html(f"<pre>{config_text}</pre>")})


In [23]:
from torchvision import models

model = models.resnet34(weights=models.ResNet34_Weights.IMAGENET1K_V1)


in_feats = model.fc.in_features
model.fc = nn.Sequential(
    nn.Dropout(0.4),
    nn.Linear(in_feats, 200)
)
model = model.to(device)


criterion = nn.CrossEntropyLoss(label_smoothing=0.1)


Downloading: "https://download.pytorch.org/models/resnet34-b627a593.pth" to /root/.cache/torch/hub/checkpoints/resnet34-b627a593.pth


100%|██████████| 83.3M/83.3M [00:00<00:00, 174MB/s]


In [24]:

for name, param in model.named_parameters():
    if "fc" not in name:
        param.requires_grad = False

optimizer = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3)
warmup_epochs = 3


In [26]:
print("Warmup training (classifier head only)...")

for epoch in range(1, warmup_epochs + 1):
    model.train()
    running_loss, correct, total = 0.0, 0, 0
    for imgs, labels in tqdm(train_loader, desc=f"Warmup Epoch {epoch}/{warmup_epochs}"):
        imgs, labels = imgs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(imgs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * imgs.size(0)
        _, preds = outputs.max(1)
        correct += preds.eq(labels).sum().item()
        total += labels.size(0)

    train_acc = correct / total
    train_loss = running_loss / total
    wandb.log({"warmup_epoch": epoch, "train_acc": train_acc, "train_loss": train_loss})
    print(f"Warmup Epoch {epoch}: Train Acc {train_acc:.3f}, Train Loss {train_loss:.3f}")


Warmup training (classifier head only)...


Warmup Epoch 1/3: 100%|██████████| 625/625 [01:46<00:00,  5.87it/s]


Warmup Epoch 1: Train Acc 0.143, Train Loss 4.429


Warmup Epoch 2/3: 100%|██████████| 625/625 [01:38<00:00,  6.34it/s]


Warmup Epoch 2: Train Acc 0.231, Train Loss 3.958


Warmup Epoch 3/3: 100%|██████████| 625/625 [01:54<00:00,  5.44it/s]

Warmup Epoch 3: Train Acc 0.250, Train Loss 3.902





In [36]:

for param in model.parameters():
    param.requires_grad = True


optimizer = optim.AdamW(model.parameters(), lr=5e-5, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=5e-5,
    steps_per_epoch=len(train_loader),
    epochs=12
)

print("Starting full fine-tuning of entire ResNet34...")


Starting full fine-tuning of entire ResNet34...


In [37]:
best_val_acc = 0.0
epochs = 7

for epoch in range(1, epochs + 1):
    model.train()
    train_loss, train_correct, train_total = 0.0, 0, 0
    for imgs, labels in tqdm(train_loader, desc=f"Fine-tune Epoch {epoch}/{epochs}"):
        imgs, labels = imgs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(imgs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        scheduler.step()

        train_loss += loss.item() * imgs.size(0)
        _, preds = outputs.max(1)
        train_correct += preds.eq(labels).sum().item()
        train_total += labels.size(0)

    train_acc = train_correct / train_total
    train_loss /= train_total

    # Validation
    model.eval()
    val_loss, val_correct, val_total = 0.0, 0, 0
    with torch.no_grad():
        for imgs, labels in val_loader:
            imgs, labels = imgs.to(device), labels.to(device)
            outputs = model(imgs)
            loss = criterion(outputs, labels)
            val_loss += loss.item() * imgs.size(0)
            _, preds = outputs.max(1)
            val_correct += preds.eq(labels).sum().item()
            val_total += labels.size(0)

    val_acc = val_correct / val_total
    val_loss /= val_total

    print(f"Epoch {epoch}: Train Acc {train_acc:.3f} | Val Acc {val_acc:.3f}")
    wandb.log({
        "epoch": epoch + warmup_epochs,
        "train_acc": train_acc,
        "train_loss": train_loss,
        "val_acc": val_acc,
        "val_loss": val_loss,
        "lr": scheduler.get_last_lr()[0]
    })

    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), "resnet34_tinyimagenet_best.pth")
        artifact = wandb.Artifact("resnet34_tinyimagenet", type="model")
        artifact.add_file("resnet34_tinyimagenet_best.pth")
        wandb.log_artifact(artifact)


Fine-tune Epoch 1/7: 100%|██████████| 313/313 [01:54<00:00,  2.73it/s]


Epoch 1: Train Acc 0.272 | Val Acc 0.570


Fine-tune Epoch 2/7: 100%|██████████| 313/313 [01:53<00:00,  2.75it/s]


Epoch 2: Train Acc 0.353 | Val Acc 0.638


Fine-tune Epoch 3/7: 100%|██████████| 313/313 [01:55<00:00,  2.72it/s]


Epoch 3: Train Acc 0.415 | Val Acc 0.651


Fine-tune Epoch 4/7: 100%|██████████| 313/313 [01:57<00:00,  2.66it/s]


Epoch 4: Train Acc 0.455 | Val Acc 0.661


Fine-tune Epoch 5/7: 100%|██████████| 313/313 [01:52<00:00,  2.79it/s]


Epoch 5: Train Acc 0.498 | Val Acc 0.682


Fine-tune Epoch 6/7: 100%|██████████| 313/313 [01:53<00:00,  2.77it/s]


Epoch 6: Train Acc 0.527 | Val Acc 0.682


Fine-tune Epoch 7/7: 100%|██████████| 313/313 [01:53<00:00,  2.77it/s]


Epoch 7: Train Acc 0.554 | Val Acc 0.694


In [38]:
print("Training complete. Best Validation Accuracy:", best_val_acc)
wandb.finish()


Training complete. Best Validation Accuracy: 0.6935714285714286


0,1
epoch,▁▂▃▅▆▇█
lr,▁▅▇█▇▆▅
train_acc,▁▃▅▆▇▇█
train_loss,█▆▄▃▂▂▁
val_acc,▁▅▆▆▇▇█
val_loss,█▄▃▂▂▂▁

0,1
epoch,10.0
lr,3e-05
train_acc,0.55385
train_loss,2.57283
val_acc,0.69357
val_loss,2.01699


In [41]:
from torchvision import models
import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = models.resnet34(weights=None)
in_feats = model.fc.in_features
model.fc = nn.Sequential(
    nn.Dropout(0.4),
    nn.Linear(in_feats, 200)
)
model = model.to(device)

model.load_state_dict(torch.load("resnet34_tinyimagenet_best.pth", map_location=device))
model.eval()
print("Loaded best trained model for evaluation.")


Loaded best trained model for evaluation.


In [45]:
import wandb

wandb.init(project="TinyImageNet-ResNet", name="resnet34_drift_eval", reinit=True)
wandb.config.update({
    "phase": "drift_evaluation",
    "model": "resnet34_finetuned",
    "dataset": "tiny-imagenet-200",
})
print("W&B reinitialized for drift evaluation logging.")


W&B reinitialized for drift evaluation logging.


In [46]:
from torchvision import transforms

drift_transforms = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ColorJitter(brightness=0.8, contrast=0.8),
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x + 0.05 * torch.randn_like(x)),
    normalize
])

val_data_drifted = TinyImageNetValDataset(val_images, transform=drift_transforms)
val_loader_drifted = DataLoader(val_data_drifted, batch_size=64, shuffle=False, num_workers=2)


In [47]:
def evaluate_model(model, dataloader):
    model.eval()
    total, correct = 0, 0
    with torch.no_grad():
        for imgs, labels in dataloader:
            imgs, labels = imgs.to(device), labels.to(device)
            outputs = model(imgs)
            _, preds = outputs.max(1)
            correct += preds.eq(labels).sum().item()
            total += labels.size(0)
    return correct / total

baseline_acc = evaluate_model(model, val_loader)
drifted_acc = evaluate_model(model, val_loader_drifted)

print(f"Baseline Accuracy: {baseline_acc*100:.2f}%")
print(f"Drifted Accuracy:  {drifted_acc*100:.2f}%")

wandb.log({"baseline_accuracy": baseline_acc, "drifted_accuracy": drifted_acc})

if drifted_acc < 0.8 * baseline_acc:
    wandb.alert(
        title="Model Accuracy Drift Detected",
        text=f"Accuracy dropped from {baseline_acc*100:.2f}% to {drifted_acc*100:.2f}%"
    )


Baseline Accuracy: 69.36%
Drifted Accuracy:  21.47%


In [58]:
%%writefile app.py
import gradio as gr
import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image
import os
import sys

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def load_model():
    model = models.resnet34(weights=None)
    in_feats = model.fc.in_features
    model.fc = nn.Sequential(nn.Dropout(0.4), nn.Linear(in_feats, 200))
    if os.path.exists("resnet34_tinyimagenet_best.pth"):
        model.load_state_dict(torch.load("resnet34_tinyimagenet_best.pth", map_location=device))
    else:
        print(" Model weights not found!", file=sys.stderr)
    model.to(device)
    model.eval()
    return model

model = load_model()

normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
val_transforms = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    normalize
])


id_to_name = {}
if os.path.exists("words.txt"):
    with open("words.txt", "r") as f:
        for line in f:
            parts = line.strip().split("\t")
            if len(parts) == 2:
                wnid, name = parts
                id_to_name[wnid] = name.split(",")[0].strip()


if os.path.exists("wnids.txt"):
    with open("wnids.txt", "r") as f:
        class_ids = [line.strip() for line in f if line.strip()]
else:
    print("wnids.txt not found! Using alphabetical order of words.txt.", file=sys.stderr)
    class_ids = sorted(list(id_to_name.keys()))


idx_to_label = [id_to_name.get(wnid, wnid) for wnid in class_ids]

def predict(image):
    image = val_transforms(image).unsqueeze(0).to(device)
    with torch.no_grad():
        outputs = model(image)
        probs = torch.nn.functional.softmax(outputs[0], dim=0)
        top5 = torch.topk(probs, 5)

    results = {}
    for idx, prob in zip(top5.indices, top5.values):
        wnid = class_ids[idx.item()] if idx.item() < len(class_ids) else f"class_{idx.item()}"
        readable = id_to_name.get(wnid, wnid)
        results[f"{readable} ({wnid})"] = round(float(prob.item()), 4)
    return results

demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload Image"),
    outputs=gr.Label(num_top_classes=5, label="Top-5 Predictions"),
    title="Tiny ImageNet Classifier (ResNet-34)",
    description="Upload an image to see Top-5 predicted Tiny ImageNet classes with WNIDs."
)

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)


Overwriting app.py


In [51]:
%%writefile requirements.txt
torch
torchvision
gradio
Pillow
numpy


Writing requirements.txt
