# ISIC 2024 — Skin Cancer Detection (Inference Only)

**This notebook is for Kaggle submission only.** It loads pre-trained checkpoints
from attached datasets and generates `submission.csv`.

### Required Data Sources
1. **Competition data**: `isic-2024-challenge` (provides `test-image.hdf5`, `sample_submission.csv`)
2. **Checkpoint datasets**: Your training notebook outputs (e.g., `efficientnet-b0`)

### How checkpoint paths work
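When you add a Kaggle notebook output as a data source named, e.g., `efficientnet-b0`,
its files are mounted under `/kaggle/input/efficientnet-b0/`. The notebook then looks for
each fold's checkpoints in `checkpoints/<model_name>/fold_<k>/` below that directory, where
`<model_name>` is the key used in `MODEL_DATASETS` — for example:
`/kaggle/input/efficientnet-b0/checkpoints/efficientnet_b0/fold_0/...`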

In [None]:
# ============================================================
# CONFIGURATION — Edit these for your setup
# ============================================================
import os
import sys
import warnings

warnings.filterwarnings('ignore')

# Directory under which the attached data sources are expected to be mounted.
KAGGLE_INPUT_ROOT = "/kaggle/input"

# Map of model_name -> Kaggle dataset slug where its checkpoints live.
# When you add a notebook output as a dataset, Kaggle mounts it at:
#   /kaggle/input/{dataset-slug}/
#
# Example: If your training notebook output is named "efficientnet-b0",
# and checkpoints are at /kaggle/input/efficientnet-b0/checkpoints/efficientnet_b0/fold_0/
#
# NOTE(review): the value below carries an "owner/" prefix, whereas the mount
# path documented above uses the bare slug only. `resolve_dataset_slug` below
# accepts either form — confirm which one your environment actually mounts.
MODEL_DATASETS = {
    "efficientnet_b0": "rudrashivm/efficientnet-b0",
    # "convnext_tiny":   "convnext-tiny",
    # "swin_large":      "swin-large",
}


def resolve_dataset_slug(dataset_slug, input_root=KAGGLE_INPUT_ROOT):
    """Return the form of ``dataset_slug`` that exists under ``input_root``.

    The configured value is tried first; if that directory is missing, the
    last path component (the slug without any ``owner/`` prefix) is tried.
    When neither directory exists the configured value is returned unchanged,
    so the "not found" messages printed later still show what was configured.
    """
    bare_slug = dataset_slug.strip("/").rsplit("/", 1)[-1]
    for candidate in (dataset_slug, bare_slug):
        if candidate and os.path.isdir(os.path.join(input_root, candidate)):
            return candidate
    return dataset_slug


# Normalise once here so every later cell that builds
# f"/kaggle/input/{dataset_slug}/..." picks up a path that exists.
MODEL_DATASETS = {
    model_name: resolve_dataset_slug(dataset_slug)
    for model_name, dataset_slug in MODEL_DATASETS.items()
}

# Ensemble strategy: "soft" or "hard_weighted"
ENSEMBLE_STRATEGY = "soft"

# Inference batch size (increase if GPU memory allows)
BATCH_SIZE = 64

# Competition data
DATA_DIR = "/kaggle/input/isic-2024-challenge"
TEST_HDF5 = os.path.join(DATA_DIR, "test-image.hdf5")
SAMPLE_SUB = os.path.join(DATA_DIR, "sample_submission.csv")

# Output
OUTPUT_DIR = "/kaggle/working"

# Verify data sources
print("Data sources:")
for path, label in [(TEST_HDF5, "test HDF5"), (SAMPLE_SUB, "sample submission")]:
    status = "✅" if os.path.exists(path) else "❌"
    print(f"  {status} {label}: {path}")

# List the checkpoint files found per fold so a wrong slug or a missing
# attachment is visible before the (slow) inference cell runs.
print("\nCheckpoint datasets:")
for model_name, dataset_slug in MODEL_DATASETS.items():
    base = f"/kaggle/input/{dataset_slug}"
    ckpt_dir = os.path.join(base, "checkpoints", model_name)
    ckpt_dir_exists = os.path.exists(ckpt_dir)
    status = "✅" if ckpt_dir_exists else "❌"
    print(f"  {status} {model_name}: {ckpt_dir}")
    if not ckpt_dir_exists:
        continue
    for fold_dir in sorted(os.listdir(ckpt_dir)):
        fold_path = os.path.join(ckpt_dir, fold_dir)
        if os.path.isdir(fold_path):
            ckpts = [f for f in os.listdir(fold_path) if f.endswith('.ckpt')]
            print(f"      {fold_dir}: {ckpts}")

In [None]:
# ============================================================
# MODEL & TRANSFORMS (self-contained — no repo clone needed)
# ============================================================
# This cell defines everything needed to load checkpoints and
# run inference. It mirrors src/models/isic_module.py and
# src/data/components/transforms.py but is fully inline so the
# notebook works in Kaggle's no-internet submission environment.
# ============================================================

import torch
import torch.nn as nn
import timm
import numpy as np
import albumentations as A
from albumentations.pytorch import ToTensorV2
from lightning import LightningModule
from torchmetrics import MaxMetric, MeanMetric
from torchmetrics.classification import BinaryAUROC, BinaryROC
from typing import Any, Dict, List, Optional, Tuple


def get_val_transforms(img_size: int = 224):
    """Build the validation/test preprocessing pipeline.

    The pipeline resizes the image to ``img_size`` x ``img_size``, normalises
    each channel with the mean/std constants given below, and converts the
    result to a tensor via ``ToTensorV2``.

    Args:
        img_size: Side length, in pixels, of the square output image.

    Returns:
        An ``albumentations.Compose`` object; call it as
        ``transform(image=img)["image"]``.
    """
    return A.Compose([
        A.Resize(img_size, img_size),
        A.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
        ToTensorV2(),
    ])


class ISICLitModule(LightningModule):
    """Minimal ISICLitModule for inference — matches training checkpoint structure.

    The constructor creates a TIMM backbone, the loss, the metric objects and
    two buffers (``best_threshold``, ``best_auroc``) so that a checkpoint can
    be loaded into this class with ``load_from_checkpoint``. Only ``forward``
    is exercised by the inference code in this notebook.
    """

    def __init__(
        self,
        name: str = "",
        backbone: str = "tf_efficientnet_b0_ns",
        num_classes: int = 1,
        pretrained: bool = False,  # No need for pretrained weights at inference
        lr: float = 1e-4,
        weight_decay: float = 1e-2,
        max_epochs: int = 20,
        dropout: float = 0.0,
        pos_weight: float = 1.0,
    ):
        """Create the backbone, loss, metrics and checkpoint-restored buffers.

        Args:
            name: Stored as a hyperparameter only; not otherwise used here.
            backbone: TIMM model name passed to ``timm.create_model``.
            num_classes: Number of outputs of the classification head.
            pretrained: Accepted and stored as a hyperparameter but ignored
                when building the backbone, which is always created with
                ``pretrained=False``.
            lr: Learning rate used by ``configure_optimizers``.
            weight_decay: Stored as a hyperparameter only; not used here.
            max_epochs: Stored as a hyperparameter only; not used here.
            dropout: Forwarded to the backbone as ``drop_rate``.
            pos_weight: Positive-class weight for ``BCEWithLogitsLoss``.
        """
        super().__init__()
        # Record every constructor argument in self.hparams.
        self.save_hyperparameters(logger=False)

        self.model = timm.create_model(
            backbone,
            pretrained=False,  # Weights come from checkpoint
            num_classes=num_classes,
            drop_rate=dropout,
        )

        self.criterion = nn.BCEWithLogitsLoss(
            pos_weight=torch.tensor([pos_weight])
        )

        # Metric objects are never updated in this notebook.
        # NOTE(review): presumably kept so the module's attributes line up
        # with the training-time module — confirm against the training code.
        self.train_auroc = BinaryAUROC()
        self.val_auroc = BinaryAUROC()
        self.test_auroc = BinaryAUROC()
        self.val_roc = BinaryROC()
        self.train_loss = MeanMetric()
        self.val_loss = MeanMetric()
        self.test_loss = MeanMetric()
        self.val_auroc_best = MaxMetric()

        # Registered as buffers so they are part of the state dict; the
        # defaults (0.5 / 0.0) apply only if a checkpoint does not set them.
        self.register_buffer("best_threshold", torch.tensor(0.5))
        self.register_buffer("best_auroc", torch.tensor(0.0))

        self._val_preds: List[torch.Tensor] = []
        self._val_targets: List[torch.Tensor] = []

    def forward(self, x):
        """Return the backbone's raw outputs (logits) for the batch ``x``."""
        return self.model(x)

    def configure_optimizers(self):
        """Return an AdamW optimizer over all parameters with ``hparams.lr``."""
        return torch.optim.AdamW(self.parameters(), lr=self.hparams.lr)


def get_model_img_size(model):
    """Auto-detect image size from TIMM backbone config.

    Args:
        model: Object whose ``.model`` attribute is a TIMM model exposing
            ``pretrained_cfg``.

    Returns:
        The last element of the resolved ``input_size`` tuple, or ``224``
        when the configuration cannot be resolved for any reason.
    """
    try:
        data_config = timm.data.resolve_data_config(model.model.pretrained_cfg)
        return data_config.get("input_size", (3, 224, 224))[-1]
    except Exception:
        # Deliberate best-effort: any failure falls back to the 224 default
        # rather than aborting inference.
        return 224


print("✅ Model and transforms defined")

In [None]:
# ============================================================
# INFERENCE — Load checkpoints, predict on test-image.hdf5
# ============================================================

import h5py, cv2, glob
import pandas as pd
from tqdm import tqdm

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")

# Load test ISIC IDs from sample submission
test_df = pd.read_csv(SAMPLE_SUB)
isic_ids = test_df["isic_id"].tolist()
print(f"Test samples: {len(isic_ids)}")


def _auroc_from_filename(ckpt_path):
    """Return the AUROC encoded in a checkpoint filename, or 0.0 if unparsable."""
    encoded = os.path.basename(ckpt_path).split("auroc_")[-1].replace(".ckpt", "")
    try:
        return float(encoded)
    except ValueError:
        # e.g. a "-v1" suffix after the score; rank such files last.
        return 0.0


def _select_checkpoint(fold_path):
    """Pick one checkpoint in ``fold_path``; return None when there is none.

    Preference order: highest AUROC among ``epoch_*_auroc_*.ckpt`` files,
    then ``last.ckpt``, then the alphabetically first ``*.ckpt`` file.
    """
    auroc_ckpts = sorted(glob.glob(os.path.join(fold_path, "epoch_*_auroc_*.ckpt")))
    if auroc_ckpts:
        return max(auroc_ckpts, key=_auroc_from_filename)
    last_ckpt = os.path.join(fold_path, "last.ckpt")
    if os.path.exists(last_ckpt):
        return last_ckpt
    # Sorted so the fallback choice does not depend on directory listing order.
    any_ckpts = sorted(glob.glob(os.path.join(fold_path, "*.ckpt")))
    return any_ckpts[0] if any_ckpts else None


def _load_rgb_image(hdf5_file, isic_id):
    """Decode the encoded image stored under ``isic_id`` into an RGB array."""
    img_bytes = hdf5_file[isic_id][()]
    img = cv2.imdecode(np.frombuffer(img_bytes, np.uint8), cv2.IMREAD_COLOR)
    if img is None:
        # cv2.imdecode signals failure by returning None instead of raising.
        raise ValueError(f"Could not decode image for {isic_id}")
    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)


# ---- Discover and load all model checkpoints ----
model_configs = []  # List of (model, img_size, threshold, auroc)

for model_name, dataset_slug in MODEL_DATASETS.items():
    ckpt_base = f"/kaggle/input/{dataset_slug}/checkpoints/{model_name}"

    if not os.path.isdir(ckpt_base):
        print(f"⚠️  Skipping {model_name}: {ckpt_base} not found")
        continue

    # Discover fold directories
    fold_dirs = sorted(
        d for d in os.listdir(ckpt_base)
        if d.startswith("fold_") and os.path.isdir(os.path.join(ckpt_base, d))
    )

    print(f"\n{model_name}: {len(fold_dirs)} folds")

    for fold_dir in fold_dirs:
        ckpt_path = _select_checkpoint(os.path.join(ckpt_base, fold_dir))

        if ckpt_path is None:
            print(f"  {fold_dir}: no checkpoint found")
            continue

        print(f"  {fold_dir}: {os.path.basename(ckpt_path)}")

        # Keep every loaded model on the CPU; each one is moved to the
        # accelerator only while it is predicting, so GPU memory holds one
        # model at a time regardless of how many folds are ensembled.
        model = ISICLitModule.load_from_checkpoint(ckpt_path, map_location="cpu")
        model.eval()

        img_size = get_model_img_size(model)
        threshold = model.best_threshold.item()
        auroc = model.best_auroc.item() if hasattr(model, "best_auroc") else 0.5
        print(f"    img_size={img_size}, threshold={threshold:.4f}, auroc={auroc:.4f}")

        model_configs.append((model, img_size, threshold, auroc))

print(f"\nTotal model-folds loaded: {len(model_configs)}")
if not model_configs:
    # Explicit raise instead of `assert`, which is stripped under `python -O`.
    raise RuntimeError("No checkpoints loaded! Check MODEL_DATASETS paths.")

# ---- Run inference ----
all_probs = []
all_thresholds = []
all_aurocs = []

# The context manager closes the HDF5 file even if inference raises.
with h5py.File(TEST_HDF5, "r") as hdf5:
    for model_idx, (model, img_size, threshold, auroc) in enumerate(model_configs):
        print(f"\nModel {model_idx+1}/{len(model_configs)} (img_size={img_size})...")
        transform = get_val_transforms(img_size)
        model.to(device)
        probs = []

        for batch_start in tqdm(range(0, len(isic_ids), BATCH_SIZE), desc="Batches"):
            batch_ids = isic_ids[batch_start : batch_start + BATCH_SIZE]
            batch_images = [
                transform(image=_load_rgb_image(hdf5, isic_id))["image"]
                for isic_id in batch_ids
            ]

            batch_tensor = torch.stack(batch_images).to(device)
            with torch.no_grad():
                logits = model(batch_tensor).squeeze(1)
                batch_probs = torch.sigmoid(logits).cpu().numpy()
            probs.extend(batch_probs.tolist())

        all_probs.append(np.array(probs))
        all_thresholds.append(threshold)
        all_aurocs.append(auroc)

        # Free GPU memory: `model_configs` still references the model, so it
        # has to be moved off the device explicitly before clearing the cache.
        model.to("cpu")
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

# ---- Ensemble ----
all_probs = np.array(all_probs)

if ENSEMBLE_STRATEGY == "hard_weighted":
    weights = np.array(all_aurocs)
    weights = weights / weights.sum() if weights.sum() > 0 else np.ones_like(weights) / len(weights)
    hard = np.array([(p >= t).astype(float) for p, t in zip(all_probs, all_thresholds)])
    # AUROC-weighted share of model-folds voting "malignant", in [0, 1].
    final_probs = np.average(hard, axis=0, weights=weights)
    final_threshold = 0.5
    print(f"\nHard-weighted ensemble | weights: {[f'{w:.3f}' for w in weights]}")
else:
    final_probs = np.mean(all_probs, axis=0)
    final_threshold = np.mean(all_thresholds)
    print(f"\nSoft ensemble | threshold: {final_threshold:.4f}")

# Thresholded labels are used only for the summary printed below.
final_preds = (final_probs >= final_threshold).astype(int)

# ---- Generate submission.csv ----
# The competition metric (partial AUC) ranks samples by predicted score, so
# the continuous ensemble score is submitted. Writing the 0/1 labels instead
# would collapse the ranking to two levels.
submission = pd.DataFrame({
    "isic_id": isic_ids,
    "target": final_probs,
})

output_path = os.path.join(OUTPUT_DIR, "submission.csv")
submission.to_csv(output_path, index=False)

print(f"\n{'='*60}")
print(f"SUBMISSION GENERATED: {output_path}")
print(f"{'='*60}")
print(f"  Samples: {len(submission)}")
print(f"  Target range: [{final_probs.min():.4f}, {final_probs.max():.4f}]")
print(f"  Malignant: {final_preds.sum()} ({final_preds.mean()*100:.2f}%)")
print(f"  Benign: {len(final_preds) - final_preds.sum()}")
print(f"  Models×folds: {len(model_configs)}")
print(f"  Strategy: {ENSEMBLE_STRATEGY}")
print(f"\nFirst 10 rows:")
print(submission.head(10).to_string(index=False))
print(f"\n✅ Ready for submission!")