# ISIC 2024 — Skin Cancer Detection (Inference Only)

**This notebook is for Kaggle submission only.** It loads pre-trained checkpoints
from attached Kaggle datasets, runs them on the test images and their tabular
metadata, and writes `submission.csv`. It needs no internet access.

In [None]:
# ============================================================
# CONFIGURATION — Edit these for your setup
# ============================================================
import os
import sys
import warnings
warnings.filterwarnings('ignore')

# Map of model_name -> Kaggle dataset slug ("<owner>/<dataset>") holding that
# model's checkpoints. This notebook (here and in the inference cell) looks for:
#   /kaggle/input/datasets/{owner}/{dataset}/checkpoints/{model_name}/fold_*/*.ckpt
# NOTE(review): Kaggle's mount layout for attached datasets has varied between
# setups; if the checks below print ❌, inspect /kaggle/input and adjust the
# base path both here and in the inference cell.
USERNAME = "rudrashivm"
MODEL_DATASETS = {
    "efficientnet_b0": f"{USERNAME}/efficientnet-b0",
    # Slugs must include the owner prefix, exactly like the entry above.
    # "convnext_tiny":   f"{USERNAME}/convnext-tiny",
    # "swin_large":      f"{USERNAME}/swin-large",
}

# "soft": mean probability across model-folds, thresholded at the mean threshold.
# "hard_weighted": AUROC-weighted vote of each model-fold's thresholded prediction.
ENSEMBLE_STRATEGY = "soft"
BATCH_SIZE = 64

# Competition data
DATA_DIR = "/kaggle/input/isic-2024-challenge"
TEST_HDF5 = os.path.join(DATA_DIR, "test-image.hdf5")
TEST_META = os.path.join(DATA_DIR, "test-metadata.csv")
SAMPLE_SUB = os.path.join(DATA_DIR, "sample_submission.csv")
OUTPUT_DIR = "/kaggle/working"

# Verify data sources
print("Data sources:")
for path, label in [(TEST_HDF5, "test HDF5"), (TEST_META, "test metadata"), (SAMPLE_SUB, "sample submission")]:
    status = "\u2705" if os.path.exists(path) else "\u274c"
    print(f"  {status} {label}: {path}")

print("\nCheckpoint datasets:")
for model_name, dataset_slug in MODEL_DATASETS.items():
    base = f"/kaggle/input/datasets/{dataset_slug}"
    ckpt_dir = os.path.join(base, "checkpoints", model_name)
    # isdir (not exists): we are about to list it, and a stray file would crash listdir.
    ckpt_dir_ok = os.path.isdir(ckpt_dir)
    status = "\u2705" if ckpt_dir_ok else "\u274c"
    print(f"  {status} {model_name}: {ckpt_dir}")
    if not ckpt_dir_ok:
        continue
    for fold_dir in sorted(os.listdir(ckpt_dir)):
        fold_path = os.path.join(ckpt_dir, fold_dir)
        if os.path.isdir(fold_path):
            # Sorted so the listing is stable across filesystems.
            ckpts = sorted(f for f in os.listdir(fold_path) if f.endswith('.ckpt'))
            print(f"      {fold_dir}: {ckpts}")

In [None]:
# ============================================================
# MODEL, TRANSFORMS & TABULAR ENCODING (self-contained)
# ============================================================
# Everything needed to load checkpoints and run inference.
# Mirrors src/models/isic_module.py, src/data/components/transforms.py,
# and tabular encoding from src/data/isic_datamodule.py.
# Fully inline so the notebook works without internet.
# ============================================================

import torch
import torch.nn as nn
import timm
import numpy as np
import pandas as pd
import albumentations as A
from albumentations.pytorch import ToTensorV2

try:
    from lightning import LightningModule
except ImportError:
    from pytorch_lightning import LightningModule

from torchmetrics import MaxMetric, MeanMetric
from torchmetrics.classification import BinaryAUROC, BinaryROC
from typing import Any, Dict, List, Tuple


def get_val_transforms(img_size: int = 384):
    """Build the deterministic evaluation pipeline used at test time.

    Resizes every image to ``img_size`` x ``img_size``, normalizes it with the
    standard ImageNet per-channel mean/std, and converts it to a torch tensor.
    """
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    steps = [
        A.Resize(img_size, img_size),
        A.Normalize(mean=channel_mean, std=channel_std),
        ToTensorV2(),
    ]
    return A.Compose(steps)


# === Tabular feature encoding (must match isic_datamodule.py exactly) ===
# Numeric metadata columns. encode_tabular emits them first and in this exact
# order, so reordering or editing this list changes the feature layout the
# checkpoints were trained on.
TABULAR_NUM_COLS = [
    'age_approx',
    'clin_size_long_diam_mm',
    'tbp_lv_A',
    'tbp_lv_Aext',
    'tbp_lv_B',
    'tbp_lv_Bext',
    'tbp_lv_C',
    'tbp_lv_Cext',
    'tbp_lv_H',
    'tbp_lv_Hext',
    'tbp_lv_L',
    'tbp_lv_Lext',
    'tbp_lv_areaMM2',
    'tbp_lv_area_perim_ratio',
    'tbp_lv_color_std_mean',
    'tbp_lv_deltaA',
    'tbp_lv_deltaB',
    'tbp_lv_deltaL',
    'tbp_lv_deltaLB',
    'tbp_lv_deltaLBnorm',
    'tbp_lv_eccentricity',
    'tbp_lv_minorAxisMM',
    'tbp_lv_nevi_confidence',
    'tbp_lv_norm_border',
    'tbp_lv_norm_color',
    'tbp_lv_perimeterMM',
    'tbp_lv_radial_color_std_max',
    'tbp_lv_stdL',
    'tbp_lv_stdLExt',
    'tbp_lv_symm_2axis',
    'tbp_lv_symm_2axis_angle',
    'tbp_lv_x',
    'tbp_lv_y',
    'tbp_lv_z',
    'tbp_lv_dnn_lesion_confidence',
]


def encode_tabular(df: pd.DataFrame) -> np.ndarray:
    """Turn raw metadata rows into the tabular feature matrix the models expect.

    The column layout must stay identical to ISICDataModule._encode_tabular:
      1. every column in TABULAR_NUM_COLS, with NaN replaced by 0;
      2. one ``sex`` column: 'male' -> 1.0, 'female' -> 0.0, anything else
         (including missing) -> 0.5;
      3. a one-hot encoding of ``anatom_site_general`` over seven fixed site
         names (unlisted or missing sites give an all-zero row).

    Returns a float32 array with ``len(df)`` rows and
    ``len(TABULAR_NUM_COLS) + 8`` columns.
    """
    known_sites = (
        'head/neck', 'upper extremity', 'lower extremity',
        'anterior torso', 'posterior torso', 'lateral torso', 'palms/soles',
    )

    numeric_block = df[TABULAR_NUM_COLS].fillna(0).values.astype(np.float32)

    sex_series = df['sex'].map({'male': 1.0, 'female': 0.0}).fillna(0.5)
    sex_block = sex_series.values.astype(np.float32).reshape(-1, 1)

    site_block = np.zeros((len(df), len(known_sites)), dtype=np.float32)
    site_series = df['anatom_site_general']
    for col_idx, site_name in enumerate(known_sites):
        site_block[:, col_idx] = (site_series == site_name).astype(np.float32)

    return np.hstack([numeric_block, sex_block, site_block])


class ISICLitModule(LightningModule):
    """Inference-side copy of the training LightningModule.

    Attribute names mirror the training module so ``load_from_checkpoint`` can
    restore the weights. The architecture depends on ``n_tabular_features``:

    * ``0``: image-only. The timm backbone emits ``num_classes`` logits directly.
    * ``> 0``: image+tabular fusion. The backbone is built with
      ``num_classes=0`` (pooled features only). Those features are
      concatenated with the tabular vector and mapped by ``fusion_head`` to a
      single logit.

    The loss and metric objects are not used for inference here. They are
    presumably kept so the module's structure matches the training module.
    """

    def __init__(
        self,
        name: str = "",
        backbone: str = "tf_efficientnet_b0_ns",
        num_classes: int = 1,
        pretrained: bool = False,
        lr: float = 1e-4,
        weight_decay: float = 1e-2,
        max_epochs: int = 20,
        dropout: float = 0.0,
        pos_weight: float = 1.0,
        n_tabular_features: int = 0,
    ):
        super().__init__()
        # Stores constructor args in self.hparams (and in saved checkpoints),
        # which is how load_from_checkpoint rebuilds the same architecture.
        self.save_hyperparameters(logger=False)

        # `pretrained` is accepted for hparam compatibility but always overridden
        # to False: the weights come from the checkpoint and the notebook runs offline.
        if n_tabular_features > 0:
            # num_classes=0 -> timm returns pooled features instead of logits.
            self.model = timm.create_model(
                backbone, pretrained=False, num_classes=0, drop_rate=dropout,
            )
            img_feat_dim = self.model.num_features
            fusion_dim = img_feat_dim + n_tabular_features
            self.fusion_head = nn.Sequential(
                nn.Linear(fusion_dim, 128),
                nn.BatchNorm1d(128),
                nn.ReLU(inplace=True),
                nn.Dropout(dropout),
                nn.Linear(128, 1),
            )
        else:
            self.model = timm.create_model(
                backbone, pretrained=False, num_classes=num_classes, drop_rate=dropout,
            )
            self.fusion_head = None

        self.criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([pos_weight]))
        self.train_auroc = BinaryAUROC()
        self.val_auroc = BinaryAUROC()
        self.test_auroc = BinaryAUROC()
        self.val_roc = BinaryROC()
        self.train_loss = MeanMetric()
        self.val_loss = MeanMetric()
        self.test_loss = MeanMetric()
        self.val_auroc_best = MaxMetric()
        # Buffers are part of the state_dict, so checkpointed values overwrite
        # these defaults on load (when present in the checkpoint).
        self.register_buffer("best_threshold", torch.tensor(0.5))
        self.register_buffer("best_auroc", torch.tensor(0.0))
        self._val_preds = []
        self._val_targets = []

    def forward(self, x, tabular=None):
        """Return logits for a batch of images (plus tabular rows in fusion mode).

        Args:
            x: image batch fed to the timm backbone.
            tabular: per-sample tabular features. Required in fusion mode and
                ignored in image-only mode.

        Raises:
            ValueError: if the module was built with tabular fusion and
                ``tabular`` is None. The backbone alone would return pooled
                features, not logits, and the caller would silently get
                garbage "predictions".
        """
        if self.fusion_head is None:
            return self.model(x)
        if tabular is None:
            raise ValueError(
                "This model was trained with tabular fusion "
                f"(n_tabular_features={self.hparams.n_tabular_features}); "
                "forward() requires the `tabular` argument."
            )
        img_features = self.model(x)
        combined = torch.cat([img_features, tabular], dim=1)
        return self.fusion_head(combined)

    def configure_optimizers(self):
        """AdamW using the stored learning rate and weight decay hparams."""
        return torch.optim.AdamW(
            self.parameters(),
            lr=self.hparams.lr,
            weight_decay=self.hparams.weight_decay,
        )


def get_model_img_size(model, fallback: int = 384):
    """Return the square input size from the backbone's timm data config.

    Args:
        model: a module whose ``.model`` attribute is a timm backbone
            (e.g. ISICLitModule).
        fallback: size returned when the config cannot be resolved (timm
            missing or incompatible, or an unexpected model object).

    NOTE(review): this is timm's default input size for the architecture. It
    is not necessarily the size the checkpoint was trained at; confirm it
    against the training config.
    """
    try:
        import timm.data

        # Pass the model via `model=`: timm resolves the data config from the
        # model's own pretrained/default cfg. This form works across timm
        # versions, unlike passing the cfg dict positionally as `args`.
        data_config = timm.data.resolve_data_config({}, model=model.model)
        return int(data_config.get("input_size", (3, fallback, fallback))[-1])
    except Exception:
        # Best-effort: any failure here falls back to a fixed size rather than
        # aborting checkpoint loading.
        return fallback


# Report that the cell ran and how wide the encoded tabular vector is
# (numeric columns + 1 sex column + 7 one-hot site columns).
print(
    "\u2705 Model, transforms, and tabular encoding defined",
    "   Tabular features: {} numeric + 1 sex + 7 site = {} total".format(
        len(TABULAR_NUM_COLS), len(TABULAR_NUM_COLS) + 8
    ),
    sep="\n",
)

In [None]:
# ============================================================
# INFERENCE — Load checkpoints, predict on test-image.hdf5
# ============================================================
# submission.csv gets a continuous malignancy score per lesion, not a 0/1
# label. The ISIC 2024 metric (partial ROC AUC above 80% TPR) scores the
# ranking of `target`, so hard labels would collapse it to two levels.

import h5py, cv2, glob
from tqdm import tqdm

# Fail fast on a typo instead of silently falling back to soft voting.
if ENSEMBLE_STRATEGY not in ("soft", "hard_weighted"):
    raise ValueError(
        f"ENSEMBLE_STRATEGY must be 'soft' or 'hard_weighted', got {ENSEMBLE_STRATEGY!r}"
    )

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")

# Load test metadata
test_meta_df = pd.read_csv(TEST_META, low_memory=False)
sample_sub = pd.read_csv(SAMPLE_SUB)
isic_ids = sample_sub["isic_id"].tolist()
print(f"Test samples: {len(isic_ids)}")

# Reorder metadata rows to follow sample_submission, so row i of the tabular
# matrix describes isic_ids[i]. Image and tabular batches below are sliced
# with the same offsets.
test_meta_df = test_meta_df.set_index('isic_id').loc[isic_ids].reset_index()
test_tabular_raw = encode_tabular(test_meta_df)
print(f"Tabular features shape: {test_tabular_raw.shape}")

# NOTE(review): standardization uses statistics of the test set itself, because
# training statistics are not loaded here. This only approximates the
# training-time scaling and is unreliable when the test set has very few rows.
# Saving the training mean/std during training and loading them here is the
# proper fix.
tab_mean = test_tabular_raw.mean(axis=0)
tab_std = test_tabular_raw.std(axis=0)
tab_std[tab_std < 1e-7] = 1.0  # constant columns: avoid dividing by ~0
test_tabular = (test_tabular_raw - tab_mean) / tab_std


def _checkpoint_auroc(path):
    """Parse the AUROC from an 'epoch_*_auroc_<value>.ckpt' filename, or 0.0 if unparsable."""
    try:
        return float(os.path.basename(path).split("auroc_")[-1].replace(".ckpt", ""))
    except ValueError:
        return 0.0


def _select_checkpoint(fold_path):
    """Pick one checkpoint for a fold directory.

    Preference order: the highest-AUROC 'epoch_*_auroc_*.ckpt', then
    'last.ckpt', then the first '*.ckpt' by name. Returns None if the
    directory has no .ckpt files.
    """
    auroc_ckpts = sorted(glob.glob(os.path.join(fold_path, "epoch_*_auroc_*.ckpt")))
    if auroc_ckpts:
        return max(auroc_ckpts, key=_checkpoint_auroc)
    last_ckpt = os.path.join(fold_path, "last.ckpt")
    if os.path.exists(last_ckpt):
        return last_ckpt
    # Sorted so the fallback is deterministic; raw glob order depends on the filesystem.
    any_ckpts = sorted(glob.glob(os.path.join(fold_path, "*.ckpt")))
    return any_ckpts[0] if any_ckpts else None


def _load_rgb_image(hdf5_file, isic_id):
    """Decode the encoded image bytes stored under `isic_id` into an RGB array."""
    img_bytes = hdf5_file[isic_id][()]
    img = cv2.imdecode(np.frombuffer(img_bytes, np.uint8), cv2.IMREAD_COLOR)
    if img is None:
        # cv2.imdecode reports failure by returning None rather than raising.
        raise ValueError(f"Could not decode test image {isic_id!r} from {TEST_HDF5}")
    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)


# ---- Discover and load all model checkpoints ----
# Models are loaded onto the CPU here and moved to `device` one at a time during
# inference, so the GPU holds a single model regardless of how many folds exist.
model_configs = []  # List of (model, img_size, threshold, auroc, uses_tabular)

for model_name, dataset_slug in MODEL_DATASETS.items():
    ckpt_base = f"/kaggle/input/datasets/{dataset_slug}/checkpoints/{model_name}"
    if not os.path.isdir(ckpt_base):
        print(f"\u26a0\ufe0f  Skipping {model_name}: {ckpt_base} not found")
        continue

    fold_dirs = sorted(
        d for d in os.listdir(ckpt_base)
        if d.startswith("fold_") and os.path.isdir(os.path.join(ckpt_base, d))
    )
    print(f"\n{model_name}: {len(fold_dirs)} folds")

    for fold_dir in fold_dirs:
        ckpt_path = _select_checkpoint(os.path.join(ckpt_base, fold_dir))
        if ckpt_path is None:
            print(f"  {fold_dir}: no checkpoint found")
            continue

        print(f"  {fold_dir}: {os.path.basename(ckpt_path)}")
        # NOTE(review): strict=False tolerates key mismatches with the training
        # module, but it also hides genuinely missing weights.
        model = ISICLitModule.load_from_checkpoint(ckpt_path, map_location="cpu", strict=False)
        model.eval()

        img_size = get_model_img_size(model)
        threshold = model.best_threshold.item()
        auroc = model.best_auroc.item() if hasattr(model, "best_auroc") else 0.5
        uses_tabular = model.fusion_head is not None
        print(f"    img_size={img_size}, threshold={threshold:.4f}, auroc={auroc:.4f}, tabular={uses_tabular}")

        model_configs.append((model, img_size, threshold, auroc, uses_tabular))

print(f"\nTotal model-folds loaded: {len(model_configs)}")
if not model_configs:
    # Explicit raise: `assert` is stripped when Python runs with -O.
    raise RuntimeError("No checkpoints loaded! Check MODEL_DATASETS paths.")

# ---- Run inference ----
all_probs = []
all_thresholds = []
all_aurocs = []

# The context manager closes the HDF5 file even if inference raises.
with h5py.File(TEST_HDF5, "r") as hdf5:
    for model_idx, (model, img_size, threshold, auroc, uses_tabular) in enumerate(model_configs):
        print(f"\nModel {model_idx+1}/{len(model_configs)} (img_size={img_size}, tabular={uses_tabular})...")
        transform = get_val_transforms(img_size)
        model.to(device)
        probs = []

        for batch_start in tqdm(range(0, len(isic_ids), BATCH_SIZE), desc="Batches"):
            batch_ids = isic_ids[batch_start : batch_start + BATCH_SIZE]
            batch_tensor = torch.stack([
                transform(image=_load_rgb_image(hdf5, isic_id))["image"]
                for isic_id in batch_ids
            ]).to(device)

            # Tabular rows use the same offsets as the image batch (metadata was reordered above).
            batch_tab = None
            if uses_tabular:
                batch_tab = torch.tensor(
                    test_tabular[batch_start : batch_start + len(batch_ids)],
                    dtype=torch.float32,
                ).to(device)

            with torch.inference_mode():
                logits = model(batch_tensor, batch_tab).squeeze(1)
                probs.extend(torch.sigmoid(logits).cpu().numpy().tolist())

        all_probs.append(np.array(probs))
        all_thresholds.append(threshold)
        all_aurocs.append(auroc)

        # Move the weights back to the CPU to actually release GPU memory.
        # `del model` alone frees nothing, because model_configs still holds the module.
        model.to("cpu")
        del model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

# ---- Ensemble ----
all_probs = np.array(all_probs)  # one row per model-fold, one column per test sample

if ENSEMBLE_STRATEGY == "hard_weighted":
    weights = np.array(all_aurocs)
    weights = weights / weights.sum() if weights.sum() > 0 else np.ones_like(weights) / len(weights)
    hard = np.array([(p >= t).astype(float) for p, t in zip(all_probs, all_thresholds)])
    # AUROC-weighted fraction of model-folds voting "malignant", in [0, 1].
    final_probs = np.average(hard, axis=0, weights=weights)
    final_threshold = 0.5
    print(f"\nHard-weighted ensemble | weights: {[f'{w:.3f}' for w in weights]}")
else:
    final_probs = np.mean(all_probs, axis=0)
    final_threshold = np.mean(all_thresholds)
    print(f"\nSoft ensemble | threshold: {final_threshold:.4f}")

# Hard labels are used only for the summary printout; the submission keeps the scores.
final_preds = (final_probs >= final_threshold).astype(int)

# ---- Generate submission.csv ----
submission = pd.DataFrame({"isic_id": isic_ids, "target": final_probs})
output_path = os.path.join(OUTPUT_DIR, "submission.csv")
submission.to_csv(output_path, index=False)

print(f"\n{'='*60}")
print(f"SUBMISSION GENERATED: {output_path}")
print(f"{'='*60}")
print(f"  Samples: {len(submission)}")
print(f"  Predicted malignant at threshold {final_threshold:.4f}: {final_preds.sum()} ({final_preds.mean()*100:.2f}%)")
print(f"  Predicted benign: {len(final_preds) - final_preds.sum()}")
print(f"  Models\u00d7folds: {len(model_configs)}")
print(f"  Strategy: {ENSEMBLE_STRATEGY}")
print(f"\nFirst 10 rows:")
print(submission.head(10).to_string(index=False))
print(f"\n\u2705 Ready for submission!")