# PROTOTYPE

## ü§ç Cell 1: Environment Variables

In [1]:
# Minimal installs: only what we need
!pip install -q timm==0.9.12 sentence-transformers==2.7.0 wandb==0.17.0 torchmetrics==1.4.0.post0

# Import essentials
import os
import random
import numpy as np
import torch
import logging
from torch import nn
import wandb
# Environment tweaks
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Seeds for reproducibility (enough for CUB)
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
torch.backends.cudnn.benchmark = False  # Deterministic
torch.backends.cudnn.deterministic = True

# Logging: only to file + console, no spam
logging.basicConfig(
    level=logging.INFO,
    force=True,
    format="%(asctime)s - %(levelname)s - %(message)s",
    handlers=[logging.StreamHandler(), logging.FileHandler("/kaggle/working/train.log", mode="w")]
)

# Device check
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logging.info(f"PyTorch {torch.__version__} | Device: {device}")
if device.type == "cuda":
    logging.info(f"GPU: {torch.cuda.get_device_name(0)}")
    logging.info(f"Initial VRAM: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")
else:
    logging.warning("No GPU detected. Training will be slow.")

[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m664.8/664.8 MB[0m [31m792.4 kB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m211.5/211.5 MB[0m [31m6.1 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m56.3/56.3 MB[0m [31m6.5 MB/s[0m eta [36m0:00:00[0m0:00:01[0m00:01[0m
[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m127.9/127.9 MB[0m [31m14.2 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚

2026-02-15 13:35:26,449 - INFO - PyTorch 2.6.0+cu124 | Device: cuda
2026-02-15 13:35:26,487 - INFO - GPU: Tesla T4
2026-02-15 13:35:26,488 - INFO - Initial VRAM: 0.00 MB


## ü§ç Cell 2: Dataset Definition (Feature Engineering)

In [2]:
# ====================== FINAL DATASET & SPLIT CELL (VISION-ONLY) ======================
import os
import random
import numpy as np
import pandas as pd
from PIL import Image

import torch
from torch.utils.data import Dataset, DataLoader, Subset
import torchvision.transforms as T

# ---------------- CONFIG ----------------
SEED = 42
NUM_CLASSES = 200
IMG_SIZE = 224
BATCH_SIZE = 48

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# ---------------- PATHS ----------------
DATA_ROOT = "/kaggle/input/cub2002011/CUB_200_2011"
IMAGES_DIR = os.path.join(DATA_ROOT, "images")
IMAGES_FILE = os.path.join(DATA_ROOT, "images.txt")
LABELS_FILE = os.path.join(DATA_ROOT, "image_class_labels.txt")
BB_FILE = os.path.join(DATA_ROOT, "bounding_boxes.txt")
SPLIT_FILE = os.path.join(DATA_ROOT, "train_test_split.txt")

# ---------------- LOAD METADATA ----------------
images_df = pd.read_csv(IMAGES_FILE, sep=" ", header=None, names=["img_id", "path"])
labels_df = pd.read_csv(LABELS_FILE, sep=" ", header=None, names=["img_id", "class_id"])
bb_df = pd.read_csv(BB_FILE, sep=" ", header=None, names=["img_id", "x", "y", "width", "height"])
split_df = pd.read_csv(SPLIT_FILE, sep=" ", header=None, names=["img_id", "is_train"])

img_paths = dict(zip(images_df.img_id, images_df.path))
labels = {row.img_id: row.class_id - 1 for row in labels_df.itertuples()}  # 0‚Äì199
bboxes = {row.img_id: (row.x, row.y, row.width, row.height) for row in bb_df.itertuples()}

train_ids = split_df[split_df["is_train"] == 1]["img_id"].tolist()
test_ids  = split_df[split_df["is_train"] == 0]["img_id"].tolist()

print(f"Official train samples: {len(train_ids)}")
print(f"Official test samples:  {len(test_ids)}")

# ---------------- DATASET ----------------
class CUBVisionDataset(Dataset):
    def __init__(self, image_ids, train=False):
        self.ids = list(image_ids)
        self.train = train

        self.base_tf = T.Compose([
            T.Resize(256),
            T.CenterCrop(IMG_SIZE),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225]),
        ])

        self.aug_tf = T.Compose([
            T.RandomResizedCrop(IMG_SIZE, scale=(0.7, 1.0)),
            T.RandomHorizontalFlip(p=0.5),
            T.ColorJitter(0.2, 0.2, 0.2, 0.1),
            T.RandomRotation(15),
        ])

    def crop_bbox(self, img, bbox):
        x, y, w, h = map(int, bbox)
        m = 0.15
        x = max(0, int(x - w * m))
        y = max(0, int(y - h * m))
        w = int(w * (1 + 2 * m))
        h = int(h * (1 + 2 * m))
        return img.crop((x, y, x + w, y + h))

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, idx):
        img_id = self.ids[idx]

        path = os.path.join(IMAGES_DIR, img_paths[img_id])
        img = Image.open(path).convert("RGB")
        img = self.crop_bbox(img, bboxes[img_id])

        if self.train:
            img = self.aug_tf(img)
        img = self.base_tf(img)

        label = labels[img_id]

        return img, label

# ---------------- DATASETS ----------------
full_train_dataset = CUBVisionDataset(train_ids, train=True)
full_val_dataset   = CUBVisionDataset(train_ids, train=False)
test_dataset       = CUBVisionDataset(test_ids,  train=False)

# ---------------- STRATIFIED VAL SPLIT ----------------
val_size_per_class = 3

class_to_indices = {}
for idx, img_id in enumerate(train_ids):
    class_to_indices.setdefault(labels[img_id], []).append(idx)

train_indices, val_indices = [], []
for cls, indices in class_to_indices.items():
    random.shuffle(indices)
    val_indices.extend(indices[:val_size_per_class])
    train_indices.extend(indices[val_size_per_class:])

train_dataset = Subset(full_train_dataset, train_indices)
val_dataset   = Subset(full_val_dataset,   val_indices)

print(f"Final splits ‚Üí Train: {len(train_dataset)}, Val: {len(val_dataset)}, Test: {len(test_dataset)}")

# ---------------- LOADERS ----------------
train_loader = DataLoader(train_dataset, BATCH_SIZE, shuffle=True,
                          num_workers=4, pin_memory=True, persistent_workers=True)
val_loader   = DataLoader(val_dataset,   BATCH_SIZE, shuffle=False,
                          num_workers=4, pin_memory=True, persistent_workers=True)
test_loader  = DataLoader(test_dataset,  BATCH_SIZE, shuffle=False,
                          num_workers=4, pin_memory=True, persistent_workers=True)

print("‚úì Vision-only dataset ready")


2026-02-15 13:35:26,734 - INFO - NumExpr defaulting to 4 threads.


Official train samples: 5994
Official test samples:  5794
Final splits ‚Üí Train: 5394, Val: 600, Test: 5794
‚úì Vision-only dataset ready


## ü§ç Cell 3: Model Definition (backbone)

In [3]:
# ====================== FULL VISION-ONLY PIPELINE ======================
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import timm
import torch.nn.functional as F
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.cuda.amp import GradScaler, autocast

# ---------------- CONFIG ----------------
NUM_CLASSES = 200
VISION_DIM = 1024        # Swin-B output dim
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ====================== MODEL ======================
class ViTMixSwin(nn.Module):
    """
    Vision-only fine-grained classifier using Swin-B.
    """

    def __init__(self, freeze_vision=True):
        super().__init__()

        # ---- Vision Encoder (Swin-B) ----
        self.vision_encoder = timm.create_model(
            "swin_base_patch4_window7_224",
            pretrained=True,
            num_classes=0,       # remove classifier
            global_pool="avg"    # output (B, 1024)
        )

        if freeze_vision:
            for p in self.vision_encoder.parameters():
                p.requires_grad = False

        # ---- Classification Head ----
        self.classifier = nn.Sequential(
            nn.Linear(VISION_DIM, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.3),
            nn.Linear(512, NUM_CLASSES),
        )

    def forward(self, images):
        v_feat = self.vision_encoder(images)
        logits = self.classifier(v_feat)
        return logits



# ====================== USAGE EXAMPLE ======================
# model = ViTMixSwin(freeze_vision=True).to(DEVICE)
# optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=3e-4, weight_decay=1e-4)
# scheduler = CosineAnnealingLR(optimizer, T_max=20)
# criterion = nn.CrossEntropyLoss()
# history, best_state = train_model(model, train_loader, val_loader, optimizer, criterion, DEVICE,
#                                   epochs=20, scheduler=scheduler, amp=True, unfreeze_after=5)


## ## ü§ç Cell 4 Training Utilities

In [4]:
# ====================== METRICS ======================
@torch.no_grad()
def accuracy(logits, targets, topk=(1,)):
    maxk = max(topk)
    batch_size = targets.size(0)
    _, preds = logits.topk(maxk, dim=1, largest=True, sorted=True)
    preds = preds.t()
    correct = preds.eq(targets.view(1, -1).expand_as(preds))
    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res

class AverageMeter:
    def __init__(self, name):
        self.name = name
        self.reset()
    def reset(self):
        self.val = self.sum = self.count = self.avg = 0.0
    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

# ====================== TRAIN / VALIDATION ======================
def train_one_epoch(model, loader, optimizer, criterion, device, scaler=None):
    model.train()
    loss_meter = AverageMeter("train_loss")
    acc_meter = AverageMeter("train_acc1")

    for images, labels in tqdm(loader, desc="Train", leave=False, colour="green"):
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad(set_to_none=True)

        if scaler:
            with autocast():
                outputs = model(images)
                loss = criterion(outputs, labels)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

        acc1 = accuracy(outputs, labels, topk=(1,))[0]
        bs = images.size(0)
        loss_meter.update(loss.item(), bs)
        acc_meter.update(acc1.item(), bs)

    return {"train_loss": loss_meter.avg, "train_acc1": acc_meter.avg}

@torch.no_grad()
def validate_one_epoch(model, loader, criterion, device):
    model.eval()
    loss_meter = AverageMeter("val_loss")
    acc1_meter = AverageMeter("val_acc1")
    acc5_meter = AverageMeter("val_acc5")

    for images, labels in tqdm(loader, desc="Validate", leave=False, colour="red"):
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        loss = criterion(outputs, labels)
        acc1, acc5 = accuracy(outputs, labels, topk=(1,5))
        bs = images.size(0)
        loss_meter.update(loss.item(), bs)
        acc1_meter.update(acc1.item(), bs)
        acc5_meter.update(acc5.item(), bs)

    return {"val_loss": loss_meter.avg, "val_acc1": acc1_meter.avg, "val_acc5": acc5_meter.avg}

# ====================== FULL TRAINING LOOP ======================
def train_model(model, train_loader, val_loader, optimizer, criterion,
                device, epochs=20, scheduler=None, amp=True, unfreeze_after=5):
    """
    - amp: Use mixed precision
    - unfreeze_after: epoch number to unfreeze last 2-3 Swin blocks for fine-tuning
    """
    scaler = GradScaler() if amp else None
    best_val_acc = 0.0
    best_state = None
    history = []

    for epoch in range(1, epochs+1):
        # Partial unfreeze
        if epoch == unfreeze_after:
            print(">> Unfreezing last 2 Swin blocks for fine-tuning...")
            # Unfreeze last 2 layers
            for name, param in model.vision_encoder.named_parameters():
                if "layers.2" in name or "layers.3" in name:
                    param.requires_grad = True

        train_metrics = train_one_epoch(model, train_loader, optimizer, criterion, device, scaler)
        val_metrics = validate_one_epoch(model, val_loader, criterion, device)

        if scheduler:
            scheduler.step()

        metrics = {**train_metrics, **val_metrics}
        history.append(metrics)

        if val_metrics["val_acc1"] > best_val_acc:
            best_val_acc = val_metrics["val_acc1"]
            best_state = {"model": model.state_dict(), "optimizer": optimizer.state_dict(),
                          "epoch": epoch, "val_acc1": best_val_acc}

        print(f"Epoch [{epoch}/{epochs}] | "
              f"Train Loss: {train_metrics['train_loss']:.4f}, Train Acc: {train_metrics['train_acc1']:.2f}% | "
              f"Val Loss: {val_metrics['val_loss']:.4f}, Val Acc: {val_metrics['val_acc1']:.2f}%")

    return history, best_state

# STARTING TRAINING

In [5]:
model = ViTMixSwin(freeze_vision=True).to(DEVICE)

optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=3e-4, weight_decay=1e-4)
scheduler = CosineAnnealingLR(optimizer, T_max=20)
criterion = nn.CrossEntropyLoss()

history, best_state = train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    optimizer=optimizer,
    criterion=criterion,
    device=DEVICE,
    epochs=20,
    scheduler=scheduler,
    amp=True,
    unfreeze_after=5
)

# -------------------- 4Ô∏è‚É£ Print best validation accuracy --------------------
print(f"\n‚úì Best validation accuracy: {best_state['val_acc1']:.2f}% at epoch {best_state['epoch']}")
test_metrics = validate_one_epoch(
    model=model, 
    loader=test_loader,
    criterion=criterion,
    device=DEVICE
)

print("\n=== TEST RESULTS ===")
print(f"Top-1 Accuracy: {test_metrics['val_acc1']:.2f}%")
print(f"Top-5 Accuracy: {test_metrics['val_acc5']:.2f}%")


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
2026-02-15 13:35:33,788 - INFO - Loading pretrained weights from Hugging Face hub (timm/swin_base_patch4_window7_224.ms_in22k_ft_in1k)


model.safetensors:   0%|          | 0.00/353M [00:00<?, ?B/s]

2026-02-15 13:35:36,078 - INFO - [timm/swin_base_patch4_window7_224.ms_in22k_ft_in1k] Safe alternative available for 'pytorch_model.bin' (as 'model.safetensors'). Loading weights using safetensors.
  scaler = GradScaler() if amp else None


Train:   0%|          | 0/113 [00:00<?, ?it/s]

  with autocast():


Validate:   0%|          | 0/13 [00:00<?, ?it/s]

Epoch [1/20] | Train Loss: 4.5727, Train Acc: 15.29% | Val Loss: 3.2585, Val Acc: 47.50%


Train:   0%|          | 0/113 [00:00<?, ?it/s]

Validate:   0%|          | 0/13 [00:00<?, ?it/s]

Epoch [2/20] | Train Loss: 2.6757, Train Acc: 43.29% | Val Loss: 1.8326, Val Acc: 59.33%


Train:   0%|          | 0/113 [00:00<?, ?it/s]

Validate:   0%|          | 0/13 [00:00<?, ?it/s]

Epoch [3/20] | Train Loss: 1.7754, Train Acc: 57.12% | Val Loss: 1.3305, Val Acc: 66.00%


Train:   0%|          | 0/113 [00:00<?, ?it/s]

Validate:   0%|          | 0/13 [00:00<?, ?it/s]

Epoch [4/20] | Train Loss: 1.3967, Train Acc: 63.44% | Val Loss: 1.0900, Val Acc: 72.33%
>> Unfreezing last 2 Swin blocks for fine-tuning...


Train:   0%|          | 0/113 [00:00<?, ?it/s]

Validate:   0%|          | 0/13 [00:00<?, ?it/s]

Epoch [5/20] | Train Loss: 1.1722, Train Acc: 68.58% | Val Loss: 0.9543, Val Acc: 74.50%


Train:   0%|          | 0/113 [00:00<?, ?it/s]

Validate:   0%|          | 0/13 [00:00<?, ?it/s]

Epoch [6/20] | Train Loss: 1.0218, Train Acc: 71.82% | Val Loss: 0.8629, Val Acc: 76.00%


Train:   0%|          | 0/113 [00:00<?, ?it/s]

Validate:   0%|          | 0/13 [00:00<?, ?it/s]

Epoch [7/20] | Train Loss: 0.9128, Train Acc: 74.10% | Val Loss: 0.8010, Val Acc: 78.50%


Train:   0%|          | 0/113 [00:00<?, ?it/s]

Validate:   0%|          | 0/13 [00:00<?, ?it/s]

Epoch [8/20] | Train Loss: 0.8333, Train Acc: 77.51% | Val Loss: 0.7477, Val Acc: 79.00%


Train:   0%|          | 0/113 [00:00<?, ?it/s]

Validate:   0%|          | 0/13 [00:00<?, ?it/s]

Epoch [9/20] | Train Loss: 0.7752, Train Acc: 78.70% | Val Loss: 0.7120, Val Acc: 81.00%


Train:   0%|          | 0/113 [00:00<?, ?it/s]

Validate:   0%|          | 0/13 [00:00<?, ?it/s]

Epoch [10/20] | Train Loss: 0.7326, Train Acc: 79.37% | Val Loss: 0.7011, Val Acc: 80.67%


Train:   0%|          | 0/113 [00:00<?, ?it/s]

Validate:   0%|          | 0/13 [00:00<?, ?it/s]

Epoch [11/20] | Train Loss: 0.6942, Train Acc: 81.48% | Val Loss: 0.6741, Val Acc: 82.33%


Train:   0%|          | 0/113 [00:00<?, ?it/s]

Validate:   0%|          | 0/13 [00:00<?, ?it/s]

Epoch [12/20] | Train Loss: 0.6595, Train Acc: 82.00% | Val Loss: 0.6683, Val Acc: 81.83%


Train:   0%|          | 0/113 [00:00<?, ?it/s]

Validate:   0%|          | 0/13 [00:00<?, ?it/s]

Epoch [13/20] | Train Loss: 0.6310, Train Acc: 82.93% | Val Loss: 0.6454, Val Acc: 82.67%


Train:   0%|          | 0/113 [00:00<?, ?it/s]

Validate:   0%|          | 0/13 [00:00<?, ?it/s]

Epoch [14/20] | Train Loss: 0.6199, Train Acc: 83.50% | Val Loss: 0.6335, Val Acc: 82.83%


Train:   0%|          | 0/113 [00:00<?, ?it/s]

Validate:   0%|          | 0/13 [00:00<?, ?it/s]

Epoch [15/20] | Train Loss: 0.6060, Train Acc: 83.93% | Val Loss: 0.6264, Val Acc: 82.33%


Train:   0%|          | 0/113 [00:00<?, ?it/s]

Validate:   0%|          | 0/13 [00:00<?, ?it/s]

Epoch [16/20] | Train Loss: 0.5878, Train Acc: 83.83% | Val Loss: 0.6223, Val Acc: 83.17%


Train:   0%|          | 0/113 [00:00<?, ?it/s]

Validate:   0%|          | 0/13 [00:00<?, ?it/s]

Epoch [17/20] | Train Loss: 0.5861, Train Acc: 84.15% | Val Loss: 0.6226, Val Acc: 83.50%


Train:   0%|          | 0/113 [00:00<?, ?it/s]

Validate:   0%|          | 0/13 [00:00<?, ?it/s]

Epoch [18/20] | Train Loss: 0.5855, Train Acc: 84.39% | Val Loss: 0.6200, Val Acc: 83.67%


Train:   0%|          | 0/113 [00:00<?, ?it/s]

Validate:   0%|          | 0/13 [00:00<?, ?it/s]

Epoch [19/20] | Train Loss: 0.5710, Train Acc: 84.56% | Val Loss: 0.6195, Val Acc: 83.67%


Train:   0%|          | 0/113 [00:00<?, ?it/s]

Validate:   0%|          | 0/13 [00:00<?, ?it/s]

Epoch [20/20] | Train Loss: 0.5688, Train Acc: 84.72% | Val Loss: 0.6195, Val Acc: 83.50%

‚úì Best validation accuracy: 83.67% at epoch 19


Validate:   0%|          | 0/121 [00:00<?, ?it/s]


=== TEST RESULTS ===
Top-1 Accuracy: 84.48%
Top-5 Accuracy: 98.38%
