# In [None]:
# ================================================================
#  SETUP — Kaggle Notebook Version (NO COLAB, NO kaggle.json)
# ================================================================

import os, json, random
from pathlib import Path
import copy  # <-- NEW: for safe rollback on bad pruning

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms, models
from sklearn.model_selection import train_test_split
import torch.nn.utils.prune as prune

# -------------------------
#  Dataset paths (Kaggle)
# -------------------------
# Kaggle dataset input directory for "paribahan-bd". The class folders sit inside
# a doubly-nested "Local-Vehicles/Local-Vehicles" directory (presumably how the
# archive was packaged — confirm against the dataset page if the layout changes).
DATA_ROOT = "/kaggle/input/paribahan-bd"
LOCAL_VEHICLES_DIR = os.path.join(DATA_ROOT, "Local-Vehicles", "Local-Vehicles")

# Sanity check: os.listdir raises FileNotFoundError right here if the dataset is
# not attached to the notebook or the folder layout differs.
print("DATA ROOT CONTENTS:", os.listdir(DATA_ROOT))
print("VEHICLE CLASS FOLDERS:", os.listdir(LOCAL_VEHICLES_DIR))

# -------------------------
#  Reproducibility
# -------------------------
# Seed Python, NumPy and PyTorch (CPU and every CUDA device). NumPy's global RNG
# drives the per-expert bootstrap sampling further down. cuDNN determinism is not
# configured, so GPU runs may still differ slightly between executions.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

# -------------------------
#  Device
# -------------------------
# Every model built below is moved to this device, and the training/evaluation
# loops move each batch here.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# ================================================================
#  TRANSFORMS
# ================================================================
IMG_SIZE = 224
BATCH_SIZE = 32
NUM_WORKERS = 2

# Standard ImageNet channel statistics, matching the IMAGENET1K-pretrained
# backbones loaded further down.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Training pipeline: squash to IMG_SIZE x IMG_SIZE, then light augmentation
# (horizontal flip, +/-10 degree rotation, and a random crop covering 80-100% of
# the area, with a random aspect ratio, resized back to IMG_SIZE).
train_transform = transforms.Compose(
    [
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.RandomResizedCrop(IMG_SIZE, scale=(0.8, 1.0)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)

# Evaluation pipeline: deterministic resize and normalisation. The CenterCrop is
# a no-op after the square Resize; it is kept so the pipeline stays explicit.
eval_transform = transforms.Compose(
    [
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.CenterCrop(IMG_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)

# ================================================================
#  LOAD DATA
# ================================================================
# ImageFolder with transform=None returns untransformed images plus class
# indices; each split applies its own transform via IndexedImageFolder below.
base_dataset = datasets.ImageFolder(root=LOCAL_VEHICLES_DIR, transform=None)
print("Found classes:", base_dataset.classes)
NUM_CLASSES = len(base_dataset.classes)

# NOTE: the global `labels` is later overwritten by the top-level training loops
# that iterate `for images, labels in ...`.
indices = np.arange(len(base_dataset))
labels = np.array(base_dataset.targets)

# Stratified 70 / 15 / 15 split: hold out 30%, then halve the held-out part into
# validation and test. Both calls are seeded, so the split is reproducible.
# NOTE(review): y_train / y_val / y_test are not used in the visible notebook.
train_idx, tmp_idx, y_train, y_tmp = train_test_split(
    indices, labels, test_size=0.3, random_state=SEED, stratify=labels
)
val_idx, test_idx, y_val, y_test = train_test_split(
    tmp_idx, y_tmp, test_size=0.5, random_state=SEED, stratify=y_tmp
)

print(f"TOTAL SAMPLES: {len(base_dataset)}")
print(f"TRAIN: {len(train_idx)} | VAL: {len(val_idx)} | TEST: {len(test_idx)}")

# ================================================================
#  CUSTOM DATASET
# ================================================================
class IndexedImageFolder(Dataset):
    """Subset view over an image dataset that applies its own transform.

    Several splits can share one underlying dataset (loaded without a
    transform) while each split uses different preprocessing.

    Args:
        base_ds: Indexable dataset returning ``(image, label)`` pairs, where
            ``image`` supports ``.convert("RGB")`` (e.g. a PIL image).
        indices: Positions in ``base_ds`` that make up this subset, in order.
        transform: Optional callable applied to the RGB image.
    """

    def __init__(self, base_ds, indices, transform=None):
        self.base = base_ds
        self.indices = list(indices)
        self.transform = transform

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, i):
        source_position = self.indices[i]
        image, target = self.base[source_position]
        # Force three channels: the Normalize step uses 3-channel statistics.
        image = image.convert("RGB")
        if self.transform is None:
            return image, target
        return self.transform(image), target

def _make_loader(dataset, shuffle):
    """Wrap *dataset* in a DataLoader with the notebook-wide batch/worker settings."""
    return DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=shuffle,
        num_workers=NUM_WORKERS,
        pin_memory=True,
    )


# All three splits read from the same base_dataset; only the training split
# uses the augmenting transform.
train_dataset_global = IndexedImageFolder(base_dataset, train_idx, transform=train_transform)
val_dataset_global = IndexedImageFolder(base_dataset, val_idx, transform=eval_transform)
test_dataset_global = IndexedImageFolder(base_dataset, test_idx, transform=eval_transform)

# Only the training loader shuffles; validation/test keep a fixed order.
train_loader_global = _make_loader(train_dataset_global, shuffle=True)
val_loader_global = _make_loader(val_dataset_global, shuffle=False)
test_loader_global = _make_loader(test_dataset_global, shuffle=False)

# ================================================================
#  EXPERT MODELS (EfficientNet B0/B1/B2)
# ================================================================
# Torchvision EfficientNet variants used as ensemble members (one expert each);
# names must be accepted by build_efficientnet_expert.
EXPERT_BACKBONES = ["efficientnet_b0", "efficientnet_b1", "efficientnet_b2"]
NUM_EXPERTS = len(EXPERT_BACKBONES)
# Each expert trains on int(BOOTSTRAP_RATIO * len(train split)) positions drawn
# with replacement from the training split.
BOOTSTRAP_RATIO = 0.7

# Per-expert epochs. The "trial" and "full" phases run back to back with the same
# optimizer and data, so each expert trains for TRIAL_EPOCHS + FULL_EPOCHS epochs.
TRIAL_EPOCHS = 1
FULL_EPOCHS  = 2

# Weight of the group-lasso (sum of per-filter L2 norms) penalty in expert training.
LAMBDA_GL = 1e-5
# Fraction of output channels zeroed in every Conv2d by structured pruning.
PRUNE_AMOUNT = 0.3
# Intended number of fine-tuning epochs after pruning (review: confirm the
# pruning loop actually iterates over this value).
PRUNE_FT_EPOCHS = 1

# Distillation: softmax temperature, weight of the hard-label CE term (the KD
# term is weighted by 1 - KD_ALPHA), and number of student epochs.
KD_TEMPERATURE = 4.0
KD_ALPHA = 0.3
KD_EPOCHS = 3

# AdamW learning rates: experts (training and pruning fine-tune) and student.
LR_EXPERT = 1e-4
LR_STUDENT = 1e-3

# Shared classification loss (default mean reduction) for training, evaluation and KD.
criterion_ce = nn.CrossEntropyLoss()

def count_params(model):
    """Return the number of trainable scalar parameters in *model*.

    Parameters with ``requires_grad=False`` are excluded. Entries that are
    exactly zero (e.g. zeroed by pruning) are still counted.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

def group_lasso_penalty(model):
    """Group-lasso (L2,1) penalty over the output filters of every Conv2d.

    For each ``nn.Conv2d`` in *model*, every output filter ``weight[c]`` forms
    one group. The penalty is the sum of the filters' L2 norms, which pushes
    whole filters toward zero.

    Returns:
        A scalar tensor connected to the autograd graph, or ``0.0`` when the
        model contains no ``Conv2d`` layers.
    """
    gl = 0.0
    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            # vector_norm uses a zero subgradient at an all-zero filter. The
            # previous pow(2).sum().sqrt() back-propagated 0/0 = NaN there, which
            # corrupted pruned (all-zero) filters even when the penalty was
            # scaled by lambda = 0.
            gl += torch.linalg.vector_norm(m.weight.flatten(1), ord=2, dim=1).sum()
    return gl

def build_efficientnet_expert(name, num_classes):
    """Build an ImageNet-pretrained EfficientNet with a fresh classification head.

    Args:
        name: One of ``"efficientnet_b0"``, ``"efficientnet_b1"`` or
            ``"efficientnet_b2"``.
        num_classes: Output size of the new final ``nn.Linear`` layer.

    Returns:
        The model, moved to the module-level ``device``.

    Raises:
        ValueError: If *name* is not one of the supported variants.
    """
    # Lazily-evaluated constructors, so only the requested variant is built.
    variants = (
        (
            "efficientnet_b0",
            lambda: models.efficientnet_b0(
                weights=models.EfficientNet_B0_Weights.IMAGENET1K_V1
            ),
        ),
        (
            "efficientnet_b1",
            lambda: models.efficientnet_b1(
                weights=models.EfficientNet_B1_Weights.IMAGENET1K_V1
            ),
        ),
        (
            "efficientnet_b2",
            lambda: models.efficientnet_b2(
                weights=models.EfficientNet_B2_Weights.IMAGENET1K_V1
            ),
        ),
    )
    for variant, construct in variants:
        if name == variant:
            net = construct()
            break
    else:
        raise ValueError("Unknown EfficientNet variant")

    # Swap only the final Linear of the classifier for a num_classes-way head.
    old_head = net.classifier[-1]
    net.classifier[-1] = nn.Linear(old_head.in_features, num_classes)
    return net.to(device)


# One ImageNet-pretrained EfficientNet expert per entry of EXPERT_BACKBONES, in
# the same order (experts[i] corresponds to EXPERT_BACKBONES[i]).
# NOTE: the loop variables `name` and `model` remain as module-level globals.
experts = []
for name in EXPERT_BACKBONES:
    model = build_efficientnet_expert(name, NUM_CLASSES)
    # Trainable parameter count, printed in millions.
    print(f"Built expert {name} with {count_params(model)/1e6:.2f}M params")
    experts.append(model)

# ================================================================
#  TRAINING FUNCTIONS
# ================================================================
def train_one_epoch_single_model(model, dataloader, optimizer, criterion, device, lambda_gl=0.0):
    """Train *model* for one pass over *dataloader*.

    The optimised objective is ``criterion(out, labels)`` plus
    ``lambda_gl * group_lasso_penalty(model)``. The reported loss is the
    criterion term alone, so it stays comparable with evaluation losses.

    Args:
        model: Network to train; it is put in train mode.
        dataloader: Iterable of ``(images, labels)`` batches.
        optimizer: Optimizer over ``model``'s parameters.
        criterion: Classification loss returning a batch-mean scalar.
        device: Device every batch is moved to.
        lambda_gl: Weight of the group-lasso penalty. When it is 0 the penalty
            is not computed at all.

    Returns:
        ``(mean_loss, accuracy)`` over all samples seen in the epoch.
    """
    model.train()
    total_loss = 0.0
    total_correct = 0
    total = 0

    for images, labels in dataloader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        out = model(images)
        ce_loss = criterion(out, labels)

        loss = ce_loss
        if lambda_gl != 0:
            # Only build the penalty when it contributes. Beyond saving a pass
            # over every conv weight, a sqrt-based penalty multiplied by 0 still
            # back-propagates NaN through all-zero (pruned) filters.
            loss = loss + lambda_gl * group_lasso_penalty(model)
        loss.backward()
        optimizer.step()

        batch_size = labels.size(0)
        total_loss += ce_loss.item() * batch_size
        total_correct += (out.argmax(1) == labels).sum().item()
        total += batch_size

    return total_loss / total, total_correct / total


def evaluate_single_model(model, dataloader, criterion, device):
    """Evaluate *model* on *dataloader* without tracking gradients.

    The model is switched to eval mode and left there.

    Returns:
        ``(mean_loss, accuracy)``: the per-batch criterion value weighted by
        batch size and averaged over all samples, and the fraction of
        samples whose arg-max prediction matches the label.
    """
    model.eval()
    loss_sum, hits, seen = 0, 0, 0
    with torch.no_grad():
        for batch_x, batch_y in dataloader:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)
            logits = model(batch_x)

            loss_sum += criterion(logits, batch_y).item() * batch_x.size(0)
            hits += (logits.argmax(1) == batch_y).sum().item()
            seen += batch_y.size(0)

    return loss_sum / seen, hits / seen

def structured_channel_prune(model, amount):
    """Zero the lowest-L2-norm output channels of every Conv2d in *model*.

    For each ``nn.Conv2d``, ``prune.ln_structured`` zeroes ``amount`` of its
    output filters (dim 0), choosing those with the smallest L2 norm. ``amount``
    is a fraction in [0, 1] or an absolute count. The pruning
    re-parametrisation is removed immediately, so ``weight`` becomes a plain
    parameter again. The zeros are not protected: later optimizer steps can
    make them non-zero. Layer shapes are unchanged.
    """
    convs = [layer for layer in model.modules() if isinstance(layer, nn.Conv2d)]
    for conv in convs:
        prune.ln_structured(conv, name="weight", amount=amount, n=2, dim=0)
        prune.remove(conv, "weight")

# ================================================================
#  TRAIN EXPERT MODELS
# ================================================================
# Each expert trains on its own bootstrap sample of the training split and is
# validated on the shared validation split; its final validation accuracy is
# recorded in expert_val_scores (same order as `experts`).
expert_val_scores = []

for i, expert in enumerate(experts):
    print(f"\n=== Training Expert {i+1}/{NUM_EXPERTS} ===")

    # Bootstrap sample: int(BOOTSTRAP_RATIO * n_train) positions drawn with
    # replacement via NumPy's global (seeded) RNG, presumably to diversify the
    # ensemble members. Positions index train_dataset_global, not base_dataset.
    n_train = len(train_dataset_global)
    sub_size = int(BOOTSTRAP_RATIO * n_train)
    sub_idx = np.random.choice(np.arange(n_train), size=sub_size, replace=True)
    train_subset = torch.utils.data.Subset(train_dataset_global, sub_idx)

    train_loader_expert = DataLoader(
        train_subset, batch_size=BATCH_SIZE, shuffle=True,
        num_workers=NUM_WORKERS, pin_memory=True
    )

    optimizer = torch.optim.AdamW(expert.parameters(), lr=LR_EXPERT, weight_decay=1e-4)

    # Trial training. NOTE: both phases use the same optimizer, data and
    # group-lasso weight, and nothing is selected or reset between them, so this
    # is simply the first TRIAL_EPOCHS of TRIAL_EPOCHS + FULL_EPOCHS epochs.
    for epoch in range(1, TRIAL_EPOCHS + 1):
        t_loss, t_acc = train_one_epoch_single_model(expert, train_loader_expert, optimizer, criterion_ce, device, LAMBDA_GL)
        v_loss, v_acc = evaluate_single_model(expert, val_loader_global, criterion_ce, device)
        print(f"[Trial {epoch}/{TRIAL_EPOCHS}] Train={t_acc:.3f}, Val={v_acc:.3f}")

    # Full training
    for epoch in range(1, FULL_EPOCHS + 1):
        t_loss, t_acc = train_one_epoch_single_model(expert, train_loader_expert, optimizer, criterion_ce, device, LAMBDA_GL)
        v_loss, v_acc = evaluate_single_model(expert, val_loader_global, criterion_ce, device)
        print(f"[Full {epoch}/{FULL_EPOCHS}] Train={t_acc:.3f}, Val={v_acc:.3f}")

    # Score = validation accuracy after the last epoch run (not the best epoch).
    # It later feeds the forest weights and the pruning-rollback threshold.
    # NOTE(review): v_acc is undefined (or stale) if both epoch counts are 0.
    expert_val_scores.append(v_acc)

print("\nValidation scores:", expert_val_scores)

# ================================================================
#  PRUNE EXPERTS (with rollback if accuracy collapses)
# ================================================================
def _nonzero_param_count(model):
    """Count trainable parameter entries that are not exactly zero.

    Structured pruning zeroes channels in place without shrinking tensors, so
    a dense count (count_params) is identical before and after pruning.
    """
    return sum(int(torch.count_nonzero(p)) for p in model.parameters() if p.requires_grad)


for i, expert in enumerate(experts):
    print(f"\n=== Pruning Expert {i+1} ===")
    before_params = _nonzero_param_count(expert)
    pre_val = expert_val_scores[i]
    before_state = copy.deepcopy(expert.state_dict())  # save pre-prune weights

    structured_channel_prune(expert, PRUNE_AMOUNT)
    after_params = _nonzero_param_count(expert)
    print(f"Non-zero params: {before_params/1e6:.2f}M → {after_params/1e6:.2f}M")

    # structured_channel_prune() leaves the zeros as plain parameter values, so an
    # unmasked fine-tune would re-grow the pruned channels and undo the pruning.
    # Re-attach a fixed per-output-channel mask for the fine-tune, then bake it in.
    pruned_convs = [m for m in expert.modules() if isinstance(m, nn.Conv2d)]
    for conv in pruned_convs:
        w = conv.weight.detach()
        alive = w.flatten(1).ne(0).any(dim=1)  # True for output channels that survived
        mask = torch.zeros_like(w)
        mask[alive] = 1.0
        prune.custom_from_mask(conv, name="weight", mask=mask)

    # Built after masking so it optimises the new `weight_orig` parameters.
    optimizer = torch.optim.AdamW(expert.parameters(), lr=LR_EXPERT)
    t_loss = t_acc = float("nan")
    for _ in range(PRUNE_FT_EPOCHS):
        t_loss, t_acc = train_one_epoch_single_model(
            expert, train_loader_global, optimizer, criterion_ce, device, lambda_gl=0.0
        )

    # Make the masks permanent: weight := weight_orig * mask as a plain parameter,
    # which also restores the original state_dict keys needed for a rollback.
    for conv in pruned_convs:
        prune.remove(conv, "weight")

    v_loss, v_acc = evaluate_single_model(expert, val_loader_global, criterion_ce, device)
    print(f"After prune FT: Train={t_acc:.3f}, Val={v_acc:.3f}")

    # If pruning costs more than 5% of the pre-prune validation accuracy, revert.
    if v_acc < pre_val * 0.95:
        print(f"Pruning degraded Expert {i+1} too much (pre_val={pre_val:.3f}), reverting weights.")
        expert.load_state_dict(before_state)
        # keep original expert_val_scores[i]
    else:
        print(f"Pruning accepted for Expert {i+1}.")
        expert_val_scores[i] = v_acc  # update to post-prune val if good

# ================================================================
#  FOREST WEIGHTS
# ================================================================
val_scores = np.array(expert_val_scores)
if val_scores.sum() > 0:
    # Weight each expert in proportion to its validation accuracy.
    forest_weights = val_scores / val_scores.sum()
else:
    # Degenerate total (all zeros, or NaN): fall back to a uniform average.
    forest_weights = np.ones_like(val_scores) / len(val_scores)
forest_weights = torch.tensor(forest_weights, dtype=torch.float32, device=device)
print("\nForest Weights:", forest_weights.cpu().tolist())

def forest_logits(experts, images, weights=None):
    """Combine the experts' logits on *images* into one ensemble prediction.

    Every expert is switched to eval mode (and left there) and run without
    gradient tracking.

    Args:
        experts: Sequence of models whose outputs have identical shape.
        images: Input batch passed unchanged to each expert.
        weights: Optional 1-D tensor with one weight per expert. When given,
            the result is the weighted sum of the experts' logits, and expert
            outputs are expected to be 2-D, e.g. (batch, classes). Otherwise
            the result is the plain mean.

    Returns:
        Tensor with the shape of a single expert's output.
    """
    for net in experts:
        net.eval()
    with torch.no_grad():
        per_expert = [net(images) for net in experts]
    stacked = torch.stack(per_expert)  # (num_experts, *output_shape)
    if weights is not None:
        return (stacked * weights.view(-1, 1, 1)).sum(0)
    return stacked.mean(0)

# ================================================================
#  STUDENT MODEL (MobileNetV3)
# ================================================================
def build_student_model(num_classes):
    """Build the MobileNetV3-Small student with a *num_classes*-way head.

    Starts from ImageNet-pretrained weights, replaces only the last layer of
    ``classifier``, and moves the model to the module-level ``device``.
    """
    pretrained = models.MobileNet_V3_Small_Weights.IMAGENET1K_V1
    net = models.mobilenet_v3_small(weights=pretrained)
    old_head = net.classifier[-1]
    net.classifier[-1] = nn.Linear(old_head.in_features, num_classes)
    return net.to(device)

# Single global student instance, trained by the KD loop below.
student = build_student_model(NUM_CLASSES)
# Trainable parameters, printed in millions.
print("Student params:", count_params(student)/1e6, "M")

def kd_loss_fn(s_logits, t_logits, labels, T, alpha):
    """Knowledge-distillation loss mixing hard labels and softened teacher targets.

        loss = alpha * CE(s_logits, labels)
               + (1 - alpha) * T**2 * KL(softmax(t_logits / T) || softmax(s_logits / T))

    The KL term uses ``batchmean`` reduction. The ``T**2`` factor is the
    conventional rescaling that keeps the soft-target term's gradient magnitude
    comparable as ``T`` changes.

    Returns:
        ``(loss, ce, kd)``: the combined loss, the CE term, and the
        T**2-scaled KL term (the latter two before alpha weighting).
    """
    hard = criterion_ce(s_logits, labels)
    student_log_probs = F.log_softmax(s_logits / T, dim=1)
    teacher_probs = F.softmax(t_logits / T, dim=1)
    soft = F.kl_div(student_log_probs, teacher_probs, reduction="batchmean") * (T**2)
    combined = alpha * hard + (1 - alpha) * soft
    return combined, hard, soft

# AdamW over all student parameters (default weight decay).
optimizer_student = torch.optim.AdamW(student.parameters(), lr=LR_STUDENT)

print("\n=== TRAINING STUDENT WITH KD ===")
for epoch in range(1, KD_EPOCHS + 1):
    student.train()
    total_correct = 0
    total = 0
    # Accumulates the combined KD objective per sample; it is not printed below.
    total_loss = 0

    for images, labels in train_loader_global:
        images, labels = images.to(device), labels.to(device)

        # Teacher signal: the forest's validation-weighted logits on the same
        # (augmented) batch the student sees. forest_logits already switches the
        # experts to eval mode and runs under no_grad, so this outer no_grad is
        # redundant but harmless.
        with torch.no_grad():
            t_logits = forest_logits(experts, images, weights=forest_weights)

        optimizer_student.zero_grad()
        s_logits = student(images)
        loss, ce, kd = kd_loss_fn(s_logits, t_logits, labels, KD_TEMPERATURE, KD_ALPHA)
        loss.backward()
        optimizer_student.step()

        total_loss += loss.item() * images.size(0)
        total_correct += (s_logits.argmax(1) == labels).sum().item()
        total += labels.size(0)

    # Train accuracy is accumulated on augmented batches while the weights are
    # still changing, so it is only a rough progress signal.
    train_acc = total_correct / total
    # evaluate_single_model leaves the student in eval mode; the next epoch's
    # student.train() switches it back.
    val_acc = evaluate_single_model(student, val_loader_global, criterion_ce, device)[1]

    print(f"[KD {epoch}/{KD_EPOCHS}] Train Acc={train_acc:.3f}, Val Acc={val_acc:.3f}")

# ================================================================
#  TEST ACCURACY
# ================================================================
# Held-out test accuracy of the student as it stands after the last KD epoch
# (no best-epoch checkpoint is restored).
test_acc = evaluate_single_model(student, test_loader_global, criterion_ce, device)[1]
print("\nFinal Student Test Accuracy:", test_acc)
# count_params counts dense trainable parameters: channels zeroed by pruning are
# still included, so the forest figure does not reflect pruning savings.
print("Forest params:", sum(count_params(e) for e in experts)/1e6, "M")
print("Student params:", count_params(student)/1e6, "M")
