# PROTOTYPE

## Cell 1: Environment Variables

In [None]:
# Minimal installs: only what we need
!pip install -q timm==0.9.12 sentence-transformers==2.7.0 wandb==0.17.0 torchmetrics==1.4.0.post0

# Import essentials
import os
import random
import numpy as np
import torch
import logging
from torch import nn
import wandb
# Environment tweaks
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Seed every random number generator used in this notebook (Python, NumPy,
# PyTorch CPU and all CUDA devices) with the same value.
SEED = 42
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(SEED)

# Request deterministic cuDNN kernels and switch off the cuDNN autotuner.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Logging: every record goes to the console and to a log file that is
# truncated (mode="w") each time this cell runs.
_log_handlers = [
    logging.StreamHandler(),
    logging.FileHandler("/kaggle/working/train.log", mode="w"),
]
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    handlers=_log_handlers,
    force=True,  # drop handlers left on the root logger by earlier runs
)

# Device check: use CUDA when available and report what was found.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
logging.info(f"PyTorch {torch.__version__} | Device: {device}")
if device.type != "cuda":
    logging.warning("No GPU detected. Training will be slow.")
else:
    logging.info(f"GPU: {torch.cuda.get_device_name(0)}")
    logging.info(f"Initial VRAM: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")

## Cell 2: Dataset Definition (Feature Engineering)

In [None]:
# ====================== FINAL DATASET & SPLIT CELL (VISION-ONLY) ======================
import os
import random
import numpy as np
import pandas as pd
from PIL import Image

import torch
from torch.utils.data import Dataset, DataLoader, Subset
import torchvision.transforms as T

# ---------------- CONFIG ----------------
SEED = 42
NUM_CLASSES = 200
IMG_SIZE = 224
BATCH_SIZE = 48

# Re-seed here so the validation split drawn further down in this cell does
# not depend on how much randomness earlier cells consumed.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(SEED)

# ---------------- PATHS ----------------
DATA_ROOT = "/kaggle/input/cub2002011/CUB_200_2011"
IMAGES_DIR, IMAGES_FILE, LABELS_FILE, BB_FILE, SPLIT_FILE = (
    os.path.join(DATA_ROOT, _leaf)
    for _leaf in (
        "images",
        "images.txt",
        "image_class_labels.txt",
        "bounding_boxes.txt",
        "train_test_split.txt",
    )
)


# ---------------- LOAD METADATA ----------------
def _read_table(path, columns):
    """Read one space-separated, header-less metadata file into a DataFrame."""
    return pd.read_csv(path, sep=" ", header=None, names=columns)


images_df = _read_table(IMAGES_FILE, ["img_id", "path"])
labels_df = _read_table(LABELS_FILE, ["img_id", "class_id"])
bb_df = _read_table(BB_FILE, ["img_id", "x", "y", "width", "height"])
split_df = _read_table(SPLIT_FILE, ["img_id", "is_train"])

# Lookup tables keyed by image id.
img_paths = dict(zip(images_df.img_id, images_df.path))
# Class ids are shifted down by one so labels start at 0.
labels = dict(zip(labels_df.img_id, labels_df.class_id - 1))
# Each box is stored as an (x, y, width, height) tuple.
bboxes = dict(zip(bb_df.img_id, zip(bb_df.x, bb_df.y, bb_df.width, bb_df.height)))

# Official split: is_train == 1 -> training ids, is_train == 0 -> test ids.
train_ids = split_df.loc[split_df["is_train"] == 1, "img_id"].tolist()
test_ids = split_df.loc[split_df["is_train"] == 0, "img_id"].tolist()

print(f"Official train samples: {len(train_ids)}")
print(f"Official test samples:  {len(test_ids)}")

# ---------------- DATASET ----------------
class CUBVisionDataset(Dataset):
    """Image dataset that yields ``(image_tensor, label)`` pairs.

    Each image is opened from ``IMAGES_DIR``, cropped to its annotated
    bounding box (plus a margin) and converted to a normalized tensor of
    spatial size ``IMG_SIZE`` x ``IMG_SIZE``.

    Args:
        image_ids: Iterable of image ids; each must be a key of the
            module-level ``img_paths``, ``bboxes`` and ``labels`` tables.
        train: When true, random augmentation is applied instead of the
            deterministic resize + center crop.
    """

    def __init__(self, image_ids, train=False):
        self.ids = list(image_ids)
        self.train = train

        # Deterministic evaluation pipeline (PIL image -> normalized tensor).
        self.base_tf = T.Compose([
            T.Resize(256),
            T.CenterCrop(IMG_SIZE),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225]),
        ])

        # Random training augmentation (PIL image -> PIL image). The
        # RandomResizedCrop already produces an IMG_SIZE x IMG_SIZE image.
        self.aug_tf = T.Compose([
            T.RandomResizedCrop(IMG_SIZE, scale=(0.7, 1.0)),
            T.RandomHorizontalFlip(p=0.5),
            T.ColorJitter(0.2, 0.2, 0.2, 0.1),
            T.RandomRotation(15),
        ])

        # Tensor conversion used after ``aug_tf``. Augmented images must not
        # go through ``base_tf``: its Resize(256) + CenterCrop would resample
        # the 224-pixel crop a second time and cut off its border, so training
        # images would be zoomed in relative to validation/test images.
        self.tensor_tf = T.Compose([
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225]),
        ])

    def crop_bbox(self, img, bbox, margin=0.15):
        """Crop ``img`` to ``bbox`` enlarged by ``margin`` on every side.

        Args:
            img: Object exposing ``size`` as ``(width, height)`` and
                ``crop(box)`` (a PIL image in this notebook).
            bbox: ``(x, y, width, height)`` of the box; values are truncated
                to integers.
            margin: Fraction of the box width/height added on each side.

        Returns:
            The cropped image. The crop window is clamped to the image
            bounds so no padding is introduced; if the clamped window is
            empty the image is returned unchanged.
        """
        x, y, w, h = map(int, bbox)
        img_w, img_h = img.size

        # Every edge is derived from the original box, so clamping one side
        # never shifts the opposite side.
        left = max(0, int(x - w * margin))
        top = max(0, int(y - h * margin))
        right = min(img_w, int(x + w + w * margin))
        bottom = min(img_h, int(y + h + h * margin))

        if right <= left or bottom <= top:
            # Degenerate or out-of-image annotation: keep the full image
            # rather than hand an empty crop to the transforms.
            return img
        return img.crop((left, top, right, bottom))

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, idx):
        img_id = self.ids[idx]

        path = os.path.join(IMAGES_DIR, img_paths[img_id])
        # The context manager closes the file handle promptly; convert()
        # returns an image that no longer depends on the open file.
        with Image.open(path) as raw:
            img = raw.convert("RGB")
        img = self.crop_bbox(img, bboxes[img_id])

        if self.train:
            img = self.tensor_tf(self.aug_tf(img))
        else:
            img = self.base_tf(img)

        return img, labels[img_id]

# ---------------- DATASETS ----------------
# Two views over the same official training ids (augmented / not augmented),
# plus the official test ids without augmentation.
full_train_dataset = CUBVisionDataset(train_ids, train=True)
full_val_dataset = CUBVisionDataset(train_ids, train=False)
test_dataset = CUBVisionDataset(test_ids, train=False)

# ---------------- STRATIFIED VAL SPLIT ----------------
val_size_per_class = 3

# Group positions within train_ids by class label.
class_to_indices = {}
for position, image_id in enumerate(train_ids):
    bucket = class_to_indices.setdefault(labels[image_id], [])
    bucket.append(position)

# Shuffle each class in place; the first val_size_per_class positions go to
# validation and the remainder to training.
train_indices, val_indices = [], []
for members in class_to_indices.values():
    random.shuffle(members)
    val_indices += members[:val_size_per_class]
    train_indices += members[val_size_per_class:]

# Training positions read the augmented view, validation positions the plain one.
train_dataset = Subset(full_train_dataset, train_indices)
val_dataset = Subset(full_val_dataset, val_indices)

print(f"Final splits → Train: {len(train_dataset)}, Val: {len(val_dataset)}, Test: {len(test_dataset)}")

# ---------------- LOADERS ----------------
# Settings shared by all three loaders; only shuffling differs.
_loader_opts = dict(
    batch_size=BATCH_SIZE,
    num_workers=4,
    pin_memory=True,
    persistent_workers=True,
)
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_opts)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_opts)
test_loader = DataLoader(test_dataset, shuffle=False, **_loader_opts)

print("✓ Vision-only dataset ready")


## Cell 3: Model Definition (backbone)

In [None]:
# ====================== FULL VISION-ONLY PIPELINE ======================
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import timm
import torch.nn.functional as F
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.cuda.amp import GradScaler, autocast

# ---------------- CONFIG ----------------
VISION_DIM = 1024        # Swin-B output dim (input width of the classifier head)
NUM_CLASSES = 200        # output width of the classifier head
# Use CUDA when it is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# ====================== MODEL ======================
class ViTMixSwin(nn.Module):
    """
    Vision-only classifier: a pretrained Swin-B encoder followed by a small
    MLP head that maps pooled features to NUM_CLASSES logits.

    Args:
        freeze_vision: When true, every encoder parameter has gradients
            disabled so that only the head is trainable.
    """

    def __init__(self, freeze_vision=True):
        super().__init__()

        # ---- Vision Encoder (Swin-B) ----
        # num_classes=0 removes timm's own classifier and global_pool="avg"
        # requests a pooled feature vector per image.
        encoder_options = dict(pretrained=True, num_classes=0, global_pool="avg")
        self.vision_encoder = timm.create_model(
            "swin_base_patch4_window7_224", **encoder_options
        )

        if freeze_vision:
            # Disables gradients on all encoder parameters in one call.
            self.vision_encoder.requires_grad_(False)

        # ---- Classification Head ----
        hidden_dim = 512
        self.classifier = nn.Sequential(
            nn.Linear(VISION_DIM, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.3),
            nn.Linear(hidden_dim, NUM_CLASSES),
        )

    def forward(self, images):
        """Encode ``images`` and return the classifier logits."""
        return self.classifier(self.vision_encoder(images))



# ====================== USAGE EXAMPLE ======================
# model = ViTMixSwin(freeze_vision=True).to(DEVICE)
# optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=3e-4, weight_decay=1e-4)
# scheduler = CosineAnnealingLR(optimizer, T_max=20)
# criterion = nn.CrossEntropyLoss()
# history, best_state = train_model(model, train_loader, val_loader, optimizer, criterion, DEVICE,
#                                   epochs=20, scheduler=scheduler, amp=True, unfreeze_after=5)


## Cell 4: Training Utilities

In [None]:
import os
import torch
import pandas as pd
from torch.cuda.amp import autocast, GradScaler
from tqdm import tqdm

# ====================== METRICS ======================
@torch.no_grad()
def accuracy(logits, targets, topk=(1,)):
    """Compute top-k accuracy percentages for one batch.

    Args:
        logits: Score tensor indexed as (batch, classes).
        targets: Class-index tensor with one entry per batch element.
        topk: Values of k to evaluate.

    Returns:
        A list with one 0-dim float tensor per entry of ``topk``, each the
        percentage (0-100) of samples whose target is among the k highest
        scores. A k larger than the number of classes is treated as "all
        classes". For an empty batch every entry is 0.
    """
    batch_size = targets.size(0)
    if batch_size == 0:
        # Nothing to score; avoid dividing by zero below.
        return [torch.zeros((), device=logits.device) for _ in topk]

    # topk() raises if asked for more entries than there are classes.
    maxk = min(max(topk), logits.size(1))
    _, preds = logits.topk(maxk, dim=1, largest=True, sorted=True)

    # hits[r, i] is True when the rank-r prediction for sample i is correct.
    hits = preds.t().eq(targets.view(1, -1))

    scale = 100.0 / batch_size
    # Slicing past the last row simply keeps every row, so k > maxk is safe.
    return [hits[:k].reshape(-1).float().sum(0).mul_(scale) for k in topk]

class AverageMeter:
    """Track the latest value and the running weighted mean of a metric."""

    def __init__(self, name):
        self.name = name
        self.reset()

    def reset(self):
        """Set the last value, weighted sum, weight total and mean to 0.0."""
        for field in ("val", "sum", "count", "avg"):
            setattr(self, field, 0.0)

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running mean."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count

# ====================== TRAIN / VALIDATION ======================
def train_one_epoch(model, loader, optimizer, criterion, device, scaler=None):
    """Run one optimization pass over ``loader``.

    Args:
        model: Module to train; it is switched to training mode.
        loader: Iterable of ``(images, labels)`` batches.
        optimizer: Optimizer stepped once per batch.
        criterion: Loss callable applied to ``(outputs, labels)``.
        device: Device the batches are moved to.
        scaler: Optional gradient scaler. When truthy, the forward pass runs
            under autocast and the step goes through the scaler.

    Returns:
        Dict with the sample-weighted mean ``train_loss`` and ``train_acc1``.
    """
    model.train()
    epoch_loss = AverageMeter("train_loss")
    epoch_top1 = AverageMeter("train_acc1")

    progress = tqdm(loader, desc="Train", leave=False, colour="green")
    for batch_images, batch_labels in progress:
        batch_images = batch_images.to(device)
        batch_labels = batch_labels.to(device)
        optimizer.zero_grad(set_to_none=True)

        if not scaler:
            # Plain full-precision step.
            logits = model(batch_images)
            batch_loss = criterion(logits, batch_labels)
            batch_loss.backward()
            optimizer.step()
        else:
            # Mixed-precision step: scale the loss, then let the scaler
            # drive the optimizer and update its scale factor.
            with autocast():
                logits = model(batch_images)
                batch_loss = criterion(logits, batch_labels)
            scaler.scale(batch_loss).backward()
            scaler.step(optimizer)
            scaler.update()

        n_samples = batch_images.size(0)
        (top1,) = accuracy(logits, batch_labels, topk=(1,))
        epoch_loss.update(batch_loss.item(), n_samples)
        epoch_top1.update(top1.item(), n_samples)

    return {"train_loss": epoch_loss.avg, "train_acc1": epoch_top1.avg}

@torch.no_grad()
def validate_one_epoch(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without computing gradients.

    Args:
        model: Module to evaluate; it is switched to eval mode.
        loader: Iterable of ``(images, labels)`` batches.
        criterion: Loss callable applied to ``(outputs, labels)``.
        device: Device the batches are moved to.

    Returns:
        Dict with the sample-weighted mean ``val_loss``, ``val_acc1`` and
        ``val_acc5``.
    """
    model.eval()
    meters = {key: AverageMeter(key) for key in ("val_loss", "val_acc1", "val_acc5")}

    progress = tqdm(loader, desc="Validate", leave=False, colour="red")
    for batch_images, batch_labels in progress:
        batch_images = batch_images.to(device)
        batch_labels = batch_labels.to(device)

        logits = model(batch_images)
        batch_loss = criterion(logits, batch_labels)
        top1, top5 = accuracy(logits, batch_labels, topk=(1,5))

        n_samples = batch_images.size(0)
        meters["val_loss"].update(batch_loss.item(), n_samples)
        meters["val_acc1"].update(top1.item(), n_samples)
        meters["val_acc5"].update(top5.item(), n_samples)

    return {key: meter.avg for key, meter in meters.items()}

# ====================== FULL TRAINING LOOP ======================
def _cpu_snapshot(obj):
    """Return a copy of ``obj`` with every tensor cloned onto the CPU."""
    if torch.is_tensor(obj):
        return obj.detach().cpu().clone()
    if isinstance(obj, dict):
        return {key: _cpu_snapshot(value) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        return type(obj)(_cpu_snapshot(value) for value in obj)
    return obj


def _unfreeze_last_stages(model, optimizer, scheduler=None):
    """Enable gradients on encoder stages 2 and 3 and register them with the optimizer.

    Returns the number of parameters newly added to the optimizer.
    """
    unfrozen = []
    for name, param in model.vision_encoder.named_parameters():
        if "layers.2" in name or "layers.3" in name:
            param.requires_grad = True
            unfrozen.append(param)

    # An optimizer built only from the parameters that were trainable at
    # construction time does not know about the ones just unfrozen; without
    # registering them they would receive gradients but never be updated.
    tracked = {id(p) for group in optimizer.param_groups for p in group["params"]}
    missing = [p for p in unfrozen if id(p) not in tracked]
    if not missing:
        return 0

    # Reuse the first group's current hyper-parameters (including its
    # already-scheduled learning rate) for the new group.
    reference = optimizer.param_groups[0]
    new_group = {key: value for key, value in reference.items() if key != "params"}
    new_group["params"] = missing
    optimizer.add_param_group(new_group)

    # Schedulers record one base LR per group at construction; keep that
    # list aligned with the optimizer's groups.
    base_lrs = getattr(scheduler, "base_lrs", None)
    if isinstance(base_lrs, list) and len(base_lrs) < len(optimizer.param_groups):
        base_lrs.append(reference.get("initial_lr", reference["lr"]))
    return len(missing)


def train_model(model, train_loader, val_loader, optimizer, criterion,
                device, epochs=20, scheduler=None, amp=True, unfreeze_after=5,
                save_dir=".", best_model_name="best_model.pth", log_csv_name="training_log.csv"):
    """
    Train ``model`` for ``epochs`` epochs, logging metrics to CSV and
    checkpointing the epoch with the highest validation top-1 accuracy.

    Args:
        model: Module exposing a ``vision_encoder`` attribute.
        train_loader: Batches passed to ``train_one_epoch``.
        val_loader: Batches passed to ``validate_one_epoch``.
        optimizer: Optimizer to step. Parameters unfrozen at
            ``unfreeze_after`` that it does not already hold are added as
            a new parameter group.
        criterion: Loss callable.
        device: Device passed through to the per-epoch helpers.
        epochs: Number of epochs to run.
        scheduler: Optional LR scheduler, stepped once per epoch.
        amp: When true, a ``GradScaler`` is created and passed to training.
        unfreeze_after: Epoch (1-based) at whose start encoder parameters
            whose names contain ``layers.2`` or ``layers.3`` are unfrozen.
        save_dir: Directory for the checkpoint and CSV (created if missing).
        best_model_name: File name of the best checkpoint.
        log_csv_name: File name of the per-epoch metrics CSV.

    Returns:
        ``(history, best_state)``: ``history`` is a list of per-epoch metric
        dicts; ``best_state`` is the checkpoint dict of the best epoch
        (``epoch``, ``model_state``, ``optimizer_state``, ``val_acc1``) with
        tensors copied to the CPU, or ``None`` when no epoch was run.
    """
    os.makedirs(save_dir, exist_ok=True)
    scaler = GradScaler() if amp else None
    # Start below any reachable accuracy so the first epoch is always
    # checkpointed, even when its validation accuracy is 0.
    best_val_acc = float("-inf")
    best_state = None
    history = []

    # Initialize CSV file with the header row only.
    log_path = os.path.join(save_dir, log_csv_name)
    df_columns = ["epoch","train_loss","train_acc1","val_loss","val_acc1","val_acc5"]
    pd.DataFrame(columns=df_columns).to_csv(log_path, index=False)

    for epoch in range(1, epochs+1):
        # Partial unfreeze
        if epoch == unfreeze_after:
            print(">> Unfreezing last 2 Swin blocks for fine-tuning...")
            _unfreeze_last_stages(model, optimizer, scheduler)

        train_metrics = train_one_epoch(model, train_loader, optimizer, criterion, device, scaler)
        val_metrics = validate_one_epoch(model, val_loader, criterion, device)

        if scheduler:
            scheduler.step()

        metrics = {
            "epoch": epoch,
            "train_loss": train_metrics["train_loss"],
            "train_acc1": train_metrics["train_acc1"],
            "val_loss": val_metrics["val_loss"],
            "val_acc1": val_metrics["val_acc1"],
            "val_acc5": val_metrics["val_acc5"]
        }
        history.append(metrics)

        # Append metrics to CSV; columns= pins the order to the header row.
        pd.DataFrame([metrics], columns=df_columns).to_csv(
            log_path, mode='a', header=False, index=False
        )

        # Save best model explicitly
        if val_metrics["val_acc1"] > best_val_acc:
            best_val_acc = val_metrics["val_acc1"]
            # state_dict() returns references to the live tensors, which
            # later epochs overwrite; snapshot them so the returned
            # best_state really holds the best epoch's weights.
            best_state = {
                "epoch": epoch,
                "model_state": _cpu_snapshot(model.state_dict()),
                "optimizer_state": _cpu_snapshot(optimizer.state_dict()),
                "val_acc1": best_val_acc
            }
            torch.save(best_state, os.path.join(save_dir, best_model_name))

        print(f"Epoch [{epoch}/{epochs}] | "
              f"Train Loss: {train_metrics['train_loss']:.4f}, Train Acc: {train_metrics['train_acc1']:.2f}% | "
              f"Val Loss: {val_metrics['val_loss']:.4f}, Val Acc: {val_metrics['val_acc1']:.2f}%")

    return history, best_state

# STARTING TRAINING

In [None]:
# Where train_model writes the best checkpoint; named here so the same path
# can be reloaded for the final test evaluation.
SAVE_DIR = "."
BEST_MODEL_NAME = "best_model.pth"

model = ViTMixSwin(freeze_vision=True).to(DEVICE)

# Give the optimizer every parameter, not only the currently trainable ones.
# Frozen parameters have no gradient and are skipped by the optimizer, while
# the encoder layers unfrozen at `unfreeze_after` start being updated as soon
# as they receive gradients. Filtering on requires_grad here would leave those
# layers out of the optimizer for the whole run.
optimizer = AdamW(model.parameters(), lr=3e-4, weight_decay=1e-4)
scheduler = CosineAnnealingLR(optimizer, T_max=20)
criterion = nn.CrossEntropyLoss()

history, best_state = train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    optimizer=optimizer,
    criterion=criterion,
    device=DEVICE,
    epochs=20,
    scheduler=scheduler,
    amp=True,
    unfreeze_after=5,
    save_dir=SAVE_DIR,
    best_model_name=BEST_MODEL_NAME,
)

# -------------------- Report best validation accuracy --------------------
if best_state is None:
    print("\n! No best checkpoint was recorded; evaluating last-epoch weights.")
else:
    print(f"\n✓ Best validation accuracy: {best_state['val_acc1']:.2f}% at epoch {best_state['epoch']}")
    # Evaluate the checkpoint selected on validation data, not whatever
    # weights the final epoch happened to leave in `model`. The file on disk
    # was written at the best epoch, so it is reloaded from there.
    checkpoint = torch.load(os.path.join(SAVE_DIR, BEST_MODEL_NAME), map_location=DEVICE)
    model.load_state_dict(checkpoint["model_state"])

test_metrics = validate_one_epoch(
    model=model,
    loader=test_loader,
    criterion=criterion,
    device=DEVICE
)

print("\n=== TEST RESULTS ===")
print(f"Top-1 Accuracy: {test_metrics['val_acc1']:.2f}%")
print(f"Top-5 Accuracy: {test_metrics['val_acc5']:.2f}%")