# ISIC 2024 Skin Cancer Detection - MedGemma QLoRA Fine-Tuning

Fine-tune **MedGemma 1.5 4B IT** for binary skin lesion classification (benign/malignant).

**Data:** `train-metadata.csv` (401K rows) + `train-image.hdf5` (JPEG byte strings)

**3-Layer Class Imbalance Strategy** (from KerasCV starter):
- Layer 1: Sampling — downsample benign to 1%, oversample malignant 5x (with replacement)
- Layer 2: sklearn class weights — `compute_class_weight('balanced')`
- Layer 3: Focal loss — gamma=2.0, alpha=0.75, per-sample weighted

**Features:**
- 4-bit QLoRA on Kaggle T4 GPU (16GB VRAM)
- Tabular features as text in prompt (age, sex, site, size, DNN confidence)
- StratifiedGroupKFold by patient_id (no data leakage)
- Frozen SigLIP vision encoder
- Structured JSON output with confidence scores (the training targets use randomly sampled confidences, so they are not calibrated probabilities)

**Requirements:** Attach `isic-2024-challenge` dataset and enable GPU T4 accelerator.

## 1. Install Dependencies

In [None]:
%%capture
!pip install -U transformers accelerate peft bitsandbytes datasets huggingface_hub
!pip install trl scikit-learn matplotlib tqdm torchvision h5py

## 2. Setup & Authentication

In [None]:
import os, gc, io, json, re, random
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import h5py
from PIL import Image
from pathlib import Path
from torch.utils.data import Dataset
from sklearn.model_selection import StratifiedGroupKFold
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import roc_auc_score, roc_curve, confusion_matrix, balanced_accuracy_score
from tqdm.auto import tqdm
import matplotlib.pyplot as plt

# Ask the CUDA caching allocator for expandable segments to reduce fragmentation.
# Per the PyTorch docs, this only takes effect if set before CUDA is first initialised.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# Seed the Python, NumPy and torch global RNGs. Python's `random` also drives the
# synthetic confidence/reasoning chosen in make_response.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

def clear_memory():
    """Collect Python garbage, then drop PyTorch's cached CUDA blocks (no-op on CPU)."""
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()

# Auth: prefer the Kaggle secret store, then the HF_TOKEN env var, then prompt.
from huggingface_hub import login
try:
    from kaggle_secrets import UserSecretsClient
    HF_TOKEN = UserSecretsClient().get_secret("HF_TOKEN")
except Exception:  # not running on Kaggle, or the secret is not attached
    HF_TOKEN = os.getenv("HF_TOKEN") or input("HF token: ")
login(token=HF_TOKEN)

print(f"GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'}")
if torch.cuda.is_available():
    # The device-properties field is `total_memory` (bytes); `total_mem` does not
    # exist and would raise AttributeError.
    print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

## 3. Configuration

In [None]:
# === PATHS — auto-detect Kaggle vs local ===
# If the Kaggle dataset mount is present, run in Kaggle mode; otherwise expect a
# local copy of the competition data under ~/medagents/.
KAGGLE_BASE = "/kaggle/input/isic-2024-challenge"
LOCAL_BASE = os.path.join(os.path.expanduser("~"), "medagents", "isic-2024-challenge")

IS_KAGGLE = os.path.exists(KAGGLE_BASE)
DATA_DIR = KAGGLE_BASE if IS_KAGGLE else LOCAL_BASE

TRAIN_CSV = os.path.join(DATA_DIR, "train-metadata.csv")
TRAIN_HDF5 = os.path.join(DATA_DIR, "train-image.hdf5")
# NOTE(review): the test files are defined but not read in the cells shown here.
TEST_CSV = os.path.join(DATA_DIR, "test-metadata.csv")
TEST_HDF5 = os.path.join(DATA_DIR, "test-image.hdf5")

# Every artifact (extracted JPEGs, checkpoints, adapter, merged model, GGUF) lives under OUTPUT_DIR.
OUTPUT_DIR = "/kaggle/working/isic-medgemma" if IS_KAGGLE else "./output"
IMAGES_DIR = os.path.join(OUTPUT_DIR, "images")
os.makedirs(OUTPUT_DIR, exist_ok=True)
os.makedirs(IMAGES_DIR, exist_ok=True)

# === MODEL ===
MODEL_NAME = "google/medgemma-1.5-4b-it"
# Fixed pad/truncate length per example (image tokens + prompt + JSON answer).
# NOTE(review): confirm 512 leaves room for the processor's image tokens plus the
# prompt and answer — truncation would cut the supervised answer.
MAX_SEQ_LENGTH = 512

# === QLoRA ===
# LoRA update scale is LORA_ALPHA / LORA_R (= 2 here).
LORA_R = 8
LORA_ALPHA = 16
LORA_DROPOUT = 0.05

# === TRAINING ===
# Effective batch per optimizer step = BATCH_SIZE * GRADIENT_ACCUMULATION (= 16).
BATCH_SIZE = 1
GRADIENT_ACCUMULATION = 16
LEARNING_RATE = 2e-4
MAX_STEPS = 500
WARMUP_STEPS = 50

# === CLASS IMBALANCE (from KerasCV starter) ===
NEG_SAMPLE_FRAC = 0.01    # Layer 1: downsample benign to 1% (without replacement)
POS_SAMPLE_FRAC = 5.0     # Layer 1: oversample malignant 5x (with replacement)
FOCAL_GAMMA = 2.0         # Layer 3: focal loss gamma (down-weights easy tokens)
FOCAL_ALPHA = 0.75        # Layer 3: focal loss alpha (one scalar for all tokens)

# === SPLIT ===
N_FOLDS = 5
VAL_FOLD = 0              # which StratifiedGroupKFold split is used for validation

# === PROMPT TEMPLATE ===
# Filled by format_metadata(); missing values are rendered as 'unknown' / 'N/A'.
METADATA_TEMPLATE = (
    "Patient: {age} year old {sex}\n"
    "Lesion site: {site}, size: {size_mm} mm\n"
    "DNN confidence: {dnn_confidence}, Nevi confidence: {nevi_confidence}\n"
)

print(f"Data: {DATA_DIR}")
print(f"Output: {OUTPUT_DIR}")
print(f"Kaggle: {IS_KAGGLE}")

## 4. Load Data & Class Balancing (Layer 1)

In [None]:
# Load train-metadata.csv (`target` == 1 is malignant, 0 is benign)
df = pd.read_csv(TRAIN_CSV)
print(f"Loaded: {len(df)} rows, {df.shape[1]} columns")
print(f"Benign: {(df['target']==0).sum()}, Malignant: {(df['target']==1).sum()}")
print(f"Malignant rate: {df['target'].mean()*100:.2f}%")

# Layer 1: Sampling (from KerasCV starter)
benign = df[df['target'] == 0]
malignant = df[df['target'] == 1]

# Benign: keep NEG_SAMPLE_FRAC of the rows, drawn without replacement.
benign_sampled = benign.sample(frac=NEG_SAMPLE_FRAC, random_state=SEED)
# Malignant: draw POS_SAMPLE_FRAC x as many rows WITH replacement, so a given lesion
# may appear several times (or not at all).
malignant_sampled = malignant.sample(frac=POS_SAMPLE_FRAC, replace=True, random_state=SEED)

# Concatenate and shuffle (frac=1.0 sample = permutation).
# NOTE(review): this sampling happens before the train/val split, so validation rows
# inherit the oversampled duplicates unless they are removed after splitting.
sampled = pd.concat([benign_sampled, malignant_sampled]).sample(
    frac=1.0, random_state=SEED
).reset_index(drop=True)

print(f"\nLayer 1 - Sampling:")
print(f"  Benign:    {len(benign)} -> {len(benign_sampled)} (frac={NEG_SAMPLE_FRAC})")
print(f"  Malignant: {len(malignant)} -> {len(malignant_sampled)} (frac={POS_SAMPLE_FRAC})")
print(f"  Total:     {len(sampled)}")

In [None]:
# Layer 2: Compute sklearn class weights
# 'balanced' weight for class c is n_samples / (n_classes * count_c) (sklearn definition).
# NOTE(review): computed on the whole sampled set, i.e. including rows that the next
# cell assigns to validation; the train-only distribution differs slightly.
y = sampled['target'].values
classes = np.unique(y)
weights = compute_class_weight('balanced', classes=classes, y=y)
# Plain int -> float dict so it is JSON-serialisable and indexable by label.
class_weights = {int(c): float(w) for c, w in zip(classes, weights)}

print(f"Layer 2 - Class weights:")
print(f"  Benign (0):    {class_weights[0]:.4f}")
print(f"  Malignant (1): {class_weights[1]:.4f}")

## 5. Train/Val Split (StratifiedGroupKFold by patient_id)

In [None]:
# Fill missing patient_id with isic_id (such lesions become their own group)
sampled['patient_id'] = sampled['patient_id'].fillna(sampled['isic_id'])

# Grouping by patient keeps each patient — and every oversampled copy of a lesion,
# since copies share patient_id — entirely on one side of the split.
sgkf = StratifiedGroupKFold(n_splits=N_FOLDS, shuffle=True, random_state=SEED)
splits = list(sgkf.split(sampled, sampled['target'], groups=sampled['patient_id']))

train_idx, val_idx = splits[VAL_FOLD]
train_df = sampled.iloc[train_idx].reset_index(drop=True)
val_df = sampled.iloc[val_idx].reset_index(drop=True)

# Malignant oversampling (with replacement) ran before the split, so validation
# would otherwise hold repeated copies of the same images: duplicated generation
# work and metrics that count one lesion several times. Keep one row per isic_id
# for validation only; training keeps its duplicates on purpose.
n_val_rows = len(val_df)
val_df = val_df.drop_duplicates(subset='isic_id').reset_index(drop=True)

# Verify no patient leakage
train_patients = set(train_df['patient_id'].unique())
val_patients = set(val_df['patient_id'].unique())
overlap = train_patients & val_patients

print(f"Train: {len(train_df)} (mal: {(train_df['target']==1).sum()}, ben: {(train_df['target']==0).sum()})")
print(f"Val:   {len(val_df)} (mal: {(val_df['target']==1).sum()}, ben: {(val_df['target']==0).sum()})")
print(f"Val deduplicated by isic_id: {n_val_rows} -> {len(val_df)} rows")
print(f"Patient leakage: {'YES - ' + str(len(overlap)) + ' patients!' if overlap else 'None (verified)'}")

## 6. Extract Images from HDF5

In [None]:
# Collect all unique isic_ids needed (oversampled duplicate rows share one id)
all_ids = set(train_df['isic_id'].unique()) | set(val_df['isic_id'].unique())
print(f"Need {len(all_ids)} unique images")

extracted = 0
failed = 0
failed_ids = []  # (isic_id, error repr) for the failure report below

# The context manager closes the HDF5 file even if the loop raises or is interrupted.
with h5py.File(TRAIN_HDF5, "r") as hdf5:
    for isic_id in tqdm(sorted(all_ids), desc="Extracting images"):
        out_path = os.path.join(IMAGES_DIR, f"{isic_id}.jpg")
        if os.path.exists(out_path):  # already extracted by a previous run
            extracted += 1
            continue
        tmp_path = out_path + ".tmp"
        try:
            byte_string = hdf5[isic_id][()]
            img = Image.open(io.BytesIO(byte_string)).convert("RGB")
            # Write to a temp file and rename, so an interrupted save never leaves a
            # truncated JPEG that the exists() check above would accept on re-run.
            img.save(tmp_path, "JPEG", quality=95)
            os.replace(tmp_path, out_path)
            extracted += 1
        except Exception as e:  # best effort: a bad/missing image is skipped, not fatal
            failed += 1
            failed_ids.append((isic_id, repr(e)))
            if os.path.exists(tmp_path):
                os.remove(tmp_path)

print(f"Extracted: {extracted}, Failed: {failed}")
for bad_id, err in failed_ids[:5]:
    print(f"  failed {bad_id}: {err}")

## 7. Build Prompts & Format Conversations

In [None]:
def format_metadata(row):
    """Render one row's tabular features through METADATA_TEMPLATE.

    Missing age/sex/site/size are shown as 'unknown' and missing TBP confidence
    scores as 'N/A', so the prompt never contains 'nan'. Age is shown as an
    integer, size with one decimal, confidences with three decimals.
    """
    def _text_or_unknown(value):
        return str(value) if pd.notna(value) else 'unknown'

    age = row.get('age_approx')
    sex = row.get('sex', 'unknown')
    site = row.get('anatom_site_general', 'unknown')
    size_mm = row.get('clin_size_long_diam_mm')
    dnn_val = row.get('tbp_lv_dnn_lesion_confidence')
    nevi_val = row.get('tbp_lv_nevi_confidence')

    fields = {
        'age': 'unknown' if pd.isna(age) else str(int(age)),
        'sex': _text_or_unknown(sex),
        'site': _text_or_unknown(site),
        'size_mm': 'unknown' if pd.isna(size_mm) else f'{size_mm:.1f}',
        'dnn_confidence': 'N/A' if pd.isna(dnn_val) else f'{dnn_val:.3f}',
        'nevi_confidence': 'N/A' if pd.isna(nevi_val) else f'{nevi_val:.3f}',
    }
    return METADATA_TEMPLATE.format(**fields)


def make_response(is_malignant):
    """Build the synthetic assistant target (a JSON string) for one example.

    Keys, in order: ``classification`` ("malignant"/"benign"), ``confidence``
    (uniform in [0.75, 0.95] for malignant or [0.80, 0.99] for benign, rounded
    to 2 decimals) and ``reasoning`` (one of five canned sentences per class).
    Both values come from the global ``random`` module, uniform draw first.
    """
    if is_malignant:
        label, low, high = "malignant", 0.75, 0.95
        options = (
            "Irregular morphology with asymmetric features suggesting potential malignancy.",
            "Atypical dermoscopic pattern with concerning structural features.",
            "Lesion shows irregular pigment distribution and border irregularity.",
            "Morphological features are atypical, warranting biopsy and further evaluation.",
            "Heterogeneous pattern with multiple dermoscopic criteria for malignancy.",
        )
    else:
        label, low, high = "benign", 0.80, 0.99
        options = (
            "Regular symmetric pattern consistent with benign melanocytic proliferation.",
            "Typical dermoscopic features of a benign nevus with uniform pigmentation.",
            "Homogeneous pattern with regular borders consistent with benign lesion.",
            "Symmetric structure with regular pigment network, no concerning features.",
            "Benign morphology with uniform color and well-defined borders.",
        )

    payload = {}
    payload["classification"] = label
    payload["confidence"] = round(random.uniform(low, high), 2)
    payload["reasoning"] = random.choice(options)
    return json.dumps(payload)


def row_to_conversation(row):
    """Turn one metadata row into a chat-style training record.

    Returns a dict with ``messages`` (user turn = image placeholder plus the
    instruction/metadata text; assistant turn = JSON target from
    make_response), ``image_path`` (JPEG under IMAGES_DIR), ``isic_id`` and an
    integer ``malignant`` label.
    """
    malignant_flag = row['target'] == 1
    prompt = "".join((
        "Analyze this dermoscopic skin lesion image for signs of malignancy.\n",
        format_metadata(row),
        'Provide your assessment as JSON: ',
        '{"classification": "benign" or "malignant", ',
        '"confidence": 0.0-1.0, ',
        '"reasoning": "brief explanation"}',
    ))

    user_turn = {
        "role": "user",
        "content": [{"type": "image"}, {"type": "text", "text": prompt}],
    }
    assistant_turn = {
        "role": "assistant",
        "content": [{"type": "text", "text": make_response(malignant_flag)}],
    }
    return {
        "messages": [user_turn, assistant_turn],
        "image_path": os.path.join(IMAGES_DIR, f"{row['isic_id']}.jpg"),
        "isic_id": row['isic_id'],
        "malignant": int(malignant_flag),
    }


def _convos_with_images(frame, desc):
    """Return row_to_conversation() records for the rows of `frame` whose JPEG exists."""
    convos = []
    for _, record in tqdm(frame.iterrows(), total=len(frame), desc=desc):
        jpeg = os.path.join(IMAGES_DIR, f"{record['isic_id']}.jpg")
        if not os.path.exists(jpeg):
            continue  # extraction failed for this image; skip the sample
        convos.append(row_to_conversation(record))
    return convos


# Build conversations (skip samples with missing images); train first, then val.
train_convos = _convos_with_images(train_df, "Train convos")
val_convos = _convos_with_images(val_df, "Val convos")

print(f"Conversations: train={len(train_convos)}, val={len(val_convos)}")

## 8. Load MedGemma with QLoRA

In [None]:
from transformers import AutoProcessor, AutoModelForImageTextToText, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

clear_memory()

# QLoRA base: 4-bit NF4 weights with double quantisation; compute runs in float16
# (the notebook targets a T4, see header).
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
)

print(f"Loading {MODEL_NAME}...")
model = AutoModelForImageTextToText.from_pretrained(
    MODEL_NAME,
    quantization_config=bnb_config,
    device_map="auto",
    token=HF_TOKEN,
    trust_remote_code=True,
    torch_dtype=torch.float16,
)
processor = AutoProcessor.from_pretrained(MODEL_NAME, token=HF_TOKEN, trust_remote_code=True)

# Fall back to EOS as padding if the tokenizer defines no pad token. ISICDataset masks
# pad positions out of the labels, so a shared pad/EOS id would also be masked there.
if processor.tokenizer.pad_token is None:
    processor.tokenizer.pad_token = processor.tokenizer.eos_token
    model.config.pad_token_id = processor.tokenizer.eos_token_id

print("Model loaded!")

In [None]:
# Prepare the 4-bit model for k-bit training (PEFT helper).
model = prepare_model_for_kbit_training(model)

# Freeze vision encoder base weights and report how many parameters that covers.
# This alone does NOT keep the vision tower out of training: get_peft_model would
# still inject trainable LoRA adapters into any matching module there, which is why
# target_modules below is restricted to the language model.
frozen = 0
for name, param in model.named_parameters():
    if "vision" in name.lower() or "siglip" in name.lower() or "image_encoder" in name.lower():
        param.requires_grad = False
        frozen += param.numel()
print(f"Froze vision encoder: {frozen:,} params")

# Apply LoRA to the decoder projections only. A plain list such as ["q_proj", ...]
# suffix-matches the SigLIP attention projections as well; PEFT treats a string as
# a full-match regex over module names, and requiring "language_model" in the path
# keeps adapters out of the vision tower.
LORA_TARGET_REGEX = r".*language_model.*\.(q_proj|k_proj|v_proj|o_proj|gate_proj|up_proj|down_proj)"
lora_config = LoraConfig(
    r=LORA_R, lora_alpha=LORA_ALPHA, lora_dropout=LORA_DROPOUT,
    bias="none", task_type="CAUSAL_LM",
    target_modules=LORA_TARGET_REGEX,
)
model = get_peft_model(model, lora_config)

# Sanity check: no trainable LoRA weight may live in the vision encoder.
vision_lora = [
    n for n, p in model.named_parameters()
    if "lora_" in n and p.requires_grad and "vision" in n.lower()
]
if vision_lora:
    raise RuntimeError(f"LoRA adapters were injected into the vision encoder: {vision_lora[:3]}")

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"LoRA: {trainable:,} trainable / {total:,} total ({100*trainable/total:.2f}%)")

## 9. Dataset & Training Classes

In [None]:
class ISICDataset(Dataset):
    """Tokenise chat-formatted ISIC conversations for supervised fine-tuning.

    Each item is rendered with the processor's chat template, processed together
    with its image, and padded/truncated to ``max_length``. ``labels`` copies
    ``input_ids`` with padding and everything up to and including the
    ``<start_of_turn>model\\n`` marker set to -100, so only the assistant's JSON
    answer is supervised.

    Args:
        conversations: records from ``row_to_conversation`` (keys ``messages``,
            ``image_path``, ``malignant``).
        processor: the model's multimodal processor (must expose ``tokenizer``).
        max_length: fixed sequence length used for padding and truncation.
    """

    def __init__(self, conversations, processor, max_length=MAX_SEQ_LENGTH):
        self.data = conversations
        self.processor = processor
        self.max_length = max_length
        # The assistant-turn marker is constant: encode it once, not once per item.
        self._marker = processor.tokenizer.encode("<start_of_turn>model\n", add_special_tokens=False)

    def __len__(self):
        return len(self.data)

    def _mask_prompt(self, input_ids, labels):
        """Set ``labels`` to -100 up to and including the first model-turn marker.

        If the marker is absent (e.g. truncation removed the assistant turn), the
        whole sequence is masked instead of training on prompt/image tokens.
        """
        marker = self._marker
        width = len(marker)
        token_list = input_ids.tolist()
        for i in range(len(token_list) - width + 1):
            if token_list[i:i + width] == marker:
                labels[:i + width] = -100
                return labels
        labels[:] = -100
        return labels

    def __getitem__(self, idx):
        sample = self.data[idx]
        image = Image.open(sample["image_path"]).convert("RGB")

        # Render the full two-turn conversation (user prompt + assistant answer).
        text = self.processor.apply_chat_template(
            sample["messages"], tokenize=False, add_generation_prompt=False,
        )
        inputs = self.processor(
            text=text, images=[image], return_tensors="pt",
            padding="max_length", max_length=self.max_length, truncation=True,
        )

        # Processor outputs carry a leading batch dim of 1; drop it so the collator can stack.
        input_ids = inputs["input_ids"].squeeze(0)
        attention_mask = inputs["attention_mask"].squeeze(0)
        pixel_values = inputs.get("pixel_values")
        if pixel_values is not None:
            pixel_values = pixel_values.squeeze(0)
        # The evaluation cell passes every processor output (including this one, when
        # present) to the model via **inputs; forward it in training too so the model
        # sees the same inputs. NOTE(review): presumably marks image-token positions.
        token_type_ids = inputs.get("token_type_ids")

        # Labels: ignore padding, then everything before the assistant answer.
        labels = input_ids.clone()
        pad_id = self.processor.tokenizer.pad_token_id
        if pad_id is not None:
            labels[input_ids == pad_id] = -100
        labels = self._mask_prompt(input_ids, labels)

        result = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
            "is_malignant": sample["malignant"],
        }
        if pixel_values is not None:
            result["pixel_values"] = pixel_values
        if token_type_ids is not None:
            result["token_type_ids"] = token_type_ids.squeeze(0)
        return result


class ISICCollator:
    """Stack per-sample tensors from ISICDataset into one batch dict.

    All samples must already share a sequence length (the dataset pads to a
    fixed ``max_length``). ``is_malignant`` becomes a float32 tensor. Optional
    keys (``pixel_values``, ``token_type_ids``) are stacked when the first
    sample provides a non-None value for them.
    """

    _OPTIONAL_KEYS = ("pixel_values", "token_type_ids")

    def __call__(self, features):
        batch = {
            "input_ids": torch.stack([f["input_ids"] for f in features]),
            "attention_mask": torch.stack([f["attention_mask"] for f in features]),
            "labels": torch.stack([f["labels"] for f in features]),
            "is_malignant": torch.tensor([f["is_malignant"] for f in features], dtype=torch.float32),
        }
        for key in self._OPTIONAL_KEYS:
            if features[0].get(key) is not None:
                batch[key] = torch.stack([f[key] for f in features])
        return batch


class FocalLoss(nn.Module):
    """Layer 3: token-level focal loss with optional per-sample weights.

    For every supervised next-token position (label != -100) computes
    ``alpha * (1 - p_t) ** gamma * CE`` and returns the mean. ``alpha`` is a
    single scalar applied to every token (not the class-dependent alpha of the
    original focal-loss formulation), so it only rescales the loss; class
    balancing comes from ``sample_weights``.
    """
    def __init__(self, gamma=FOCAL_GAMMA, alpha=FOCAL_ALPHA):
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha

    def forward(self, logits, labels, sample_weights=None):
        """Return the mean focal loss over supervised tokens.

        Args:
            logits: (batch, seq, vocab) scores; position t predicts labels[:, t + 1].
            labels: (batch, seq) token ids, -100 where the token is not supervised.
            sample_weights: optional (batch,) weights, applied to every token of a sample.
        """
        # Standard causal-LM shift: logits at t are scored against the token at t + 1.
        shift_logits = logits[..., :-1, :]
        shift_labels = labels[..., 1:]
        flat_logits = shift_logits.reshape(-1, shift_logits.shape[-1])
        flat_labels = shift_labels.reshape(-1)

        valid = flat_labels != -100
        if not valid.any():
            # Exact zero that stays attached to the graph (sum over an empty selection),
            # so backward() still runs through the model.
            return flat_logits[valid].float().sum()

        # Cross-entropy over the large vocabulary in float32 for numerical stability,
        # regardless of the dtype the logits arrive in.
        ce = F.cross_entropy(flat_logits[valid].float(), flat_labels[valid], reduction="none")
        pt = torch.exp(-ce)  # model probability of the correct token
        focal = self.alpha * (1 - pt) ** self.gamma * ce

        if sample_weights is not None:
            # Broadcast each sample's weight over its token positions, then keep valid ones.
            seq_len = shift_labels.shape[1]
            w = sample_weights.unsqueeze(1).expand(-1, seq_len).reshape(-1)
            focal = focal * w[valid].to(focal.dtype)

        return focal.mean()


from transformers import Trainer, TrainingArguments

class ISICTrainer(Trainer):
    """Trainer that replaces the model's built-in LM loss with FocalLoss.

    Layer 2 (sklearn class weights) enters as a per-sample weight picked from
    the batch's ``is_malignant`` flags; Layer 3 is the focal loss itself.
    """
    def __init__(self, *args, class_weights=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.focal_loss = FocalLoss()
        self.class_weights = class_weights or {0: 1.0, 1: 1.0}
        # compute_loss below returns a per-micro-batch mean and ignores
        # `num_items_in_batch`. Recent Trainer versions skip dividing the loss by
        # gradient_accumulation_steps when they believe the model normalises via
        # loss kwargs; force the classic division so gradients are not inflated by
        # the accumulation factor.
        self.model_accepts_loss_kwargs = False

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        # `is_malignant` is not a model argument; take it out before the forward pass.
        is_mal = inputs.pop("is_malignant", None)
        outputs = model(**inputs)

        # Layer 2: Per-sample weights from sklearn
        weights = None
        if is_mal is not None:
            device = outputs.logits.device
            w_pos = torch.tensor(self.class_weights.get(1, 1.0), device=device)
            w_neg = torch.tensor(self.class_weights.get(0, 1.0), device=device)
            weights = torch.where(is_mal.to(device).bool(), w_pos, w_neg)

        # Layer 3: Focal loss
        loss = self.focal_loss(outputs.logits, inputs["labels"], weights)
        return (loss, outputs) if return_outputs else loss

print("Classes defined!")

## 10. Train

In [None]:
train_ds = ISICDataset(train_convos, processor)
val_ds = ISICDataset(val_convos, processor)
collator = ISICCollator()

training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION,
    max_steps=MAX_STEPS,
    warmup_steps=WARMUP_STEPS,
    learning_rate=LEARNING_RATE,
    fp16=True,
    gradient_checkpointing=True,
    optim="paged_adamw_8bit",
    logging_steps=10,
    # Evaluate and checkpoint on the same 100-step cadence so the checkpoint with the
    # lowest eval_loss can be restored at the end.
    eval_steps=100,
    eval_strategy="steps",
    save_steps=100,
    save_total_limit=2,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    report_to="none",
    seed=SEED,
    dataloader_pin_memory=False,
    # Keep `is_malignant` in each batch even though it is not a model forward
    # argument; ISICTrainer.compute_loss pops it for the class weights.
    remove_unused_columns=False,
)

trainer = ISICTrainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    data_collator=collator,
    class_weights=class_weights,
)

print(f"Training: {MAX_STEPS} steps, effective batch={BATCH_SIZE*GRADIENT_ACCUMULATION}")
print(f"Layer 2 weights: benign={class_weights[0]:.4f}, malignant={class_weights[1]:.4f}")
print(f"Layer 3 focal: gamma={FOCAL_GAMMA}, alpha={FOCAL_ALPHA}")

In [None]:
# Release cached GPU memory, then train. Because load_best_model_at_end=True, the
# trainer restores the checkpoint with the lowest eval_loss when training finishes.
clear_memory()
print("Training started...")
trainer.train()
print("Training complete!")

## 11. Evaluate on Validation Set

In [None]:
def parse_output(text):
    """Parse the model's JSON answer, falling back to regex extraction.

    Returns a dict with at least ``classification`` ("benign" or "malignant",
    lower-cased) and ``confidence`` (a float clamped to [0, 1]). When the first
    flat ``{...}`` object in ``text`` is valid JSON with a recognised label and
    a numeric confidence, that dict is returned with those two fields
    normalised (extra keys such as ``reasoning`` are kept). Otherwise the
    fields are pulled out with regexes, defaulting to benign / 0.5.
    """
    import math

    m = re.search(r'\{[^{}]*\}', text)
    if m:
        try:
            parsed = json.loads(m.group())
        except ValueError:  # json.JSONDecodeError subclasses ValueError
            parsed = None
        if isinstance(parsed, dict) and 'classification' in parsed and 'confidence' in parsed:
            label = str(parsed['classification']).strip().lower()
            try:
                conf = float(parsed['confidence'])
            except (TypeError, ValueError):
                conf = float('nan')
            if label in ('benign', 'malignant') and math.isfinite(conf):
                parsed['classification'] = label
                parsed['confidence'] = min(max(conf, 0.0), 1.0)
                return parsed

    # Regex fallback for malformed / partial JSON.
    result = {"classification": "benign", "confidence": 0.5}
    cm = re.search(r'"classification"\s*:\s*"(malignant|benign)"', text, re.I)
    if cm:
        result["classification"] = cm.group(1).lower()
    cf = re.search(r'"confidence"\s*:\s*(0?\.\d+|1\.0|1)', text)
    if cf:
        result["confidence"] = float(cf.group(1))
    return result

# Run inference on validation set (greedy decoding from the user turn only)
model.eval()
predictions = []
device = "cuda" if torch.cuda.is_available() else "cpu"

for sample in tqdm(val_convos, desc="Evaluating"):
    image = Image.open(sample["image_path"]).convert("RGB")
    msgs = [sample["messages"][0]]  # user only: the model must produce the answer

    # add_generation_prompt=True appends the model-turn header so decoding starts at the answer.
    text = processor.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
    inputs = processor(text=text, images=[image], return_tensors="pt",
                       padding=True, max_length=MAX_SEQ_LENGTH, truncation=True).to(device)

    # Greedy decoding. `temperature` has no effect when do_sample=False (it only
    # triggers a generation-config warning), so it is not passed.
    with torch.no_grad():
        out = model.generate(**inputs, max_new_tokens=150, do_sample=False,
                             pad_token_id=processor.tokenizer.eos_token_id)

    # Keep only the newly generated tokens after the prompt.
    gen = out[0][inputs["input_ids"].shape[1]:]
    resp = processor.tokenizer.decode(gen, skip_special_tokens=True)
    parsed = parse_output(resp)

    predictions.append({
        "isic_id": sample["isic_id"],
        "true_label": sample["malignant"],
        "pred_class": parsed["classification"],
        "pred_conf": parsed["confidence"],
        # Raw text, so responses that fell back to the benign/0.5 default can be audited.
        "raw_response": resp,
    })

pred_df = pd.DataFrame(predictions)
print(f"Predictions: {len(pred_df)}")

In [None]:
# Compute metrics
y_true = pred_df["true_label"].values.astype(int)
y_pred = (pred_df["pred_class"] == "malignant").astype(int).values
# Malignancy score: the stated confidence if "malignant", otherwise 1 - confidence.
y_scores = np.where(pred_df["pred_class"] == "malignant", pred_df["pred_conf"], 1 - pred_df["pred_conf"])
# ROC-based metrics are undefined when only one class is present.
has_both_classes = len(np.unique(y_true)) > 1


def isic_pauc(labels, scores, min_tpr=0.80):
    """ISIC 2024 competition metric: partial ROC AUC over the region TPR >= min_tpr.

    Mirrors the official scorer: labels and scores are flipped so the high-TPR
    band becomes the low-FPR band [0, 1 - min_tpr]; the flipped ROC curve is cut
    there (interpolating the end point) and integrated with the trapezoid rule.
    The result lies in [0, 1 - min_tpr] (i.e. [0, 0.2] for the default), not [0, 1].
    Requires both classes to be present in ``labels``.
    """
    v_gt = np.abs(np.asarray(labels) - 1)
    v_pred = -1.0 * np.asarray(scores, dtype=float)
    max_fpr = abs(1 - min_tpr)
    fpr_f, tpr_f, _ = roc_curve(v_gt, v_pred)
    stop = np.searchsorted(fpr_f, max_fpr, "right")
    if stop < len(fpr_f):
        x_interp = [fpr_f[stop - 1], fpr_f[stop]]
        y_interp = [tpr_f[stop - 1], tpr_f[stop]]
        tpr_f = np.append(tpr_f[:stop], np.interp(max_fpr, x_interp, y_interp))
        fpr_f = np.append(fpr_f[:stop], max_fpr)
    return float(np.sum(np.diff(fpr_f) * (tpr_f[1:] + tpr_f[:-1]) / 2.0))


# ROC in the original orientation (the plotting cell also uses fpr, tpr and mask).
fpr, tpr, _ = roc_curve(y_true, y_scores)
mask = tpr >= 0.80

# pAUC (ISIC 2024 competition metric)
pauc = isic_pauc(y_true, y_scores, min_tpr=0.80) if has_both_classes else 0.0

auc = roc_auc_score(y_true, y_scores) if has_both_classes else 0.0
tp = ((y_pred == 1) & (y_true == 1)).sum()
fn = ((y_pred == 0) & (y_true == 1)).sum()
tn = ((y_pred == 0) & (y_true == 0)).sum()
fp = ((y_pred == 1) & (y_true == 0)).sum()

sens = tp / max(tp + fn, 1)
spec = tn / max(tn + fp, 1)
bal_acc = balanced_accuracy_score(y_true, y_pred)

print("=" * 50)
print("EVALUATION RESULTS")
print("=" * 50)
print(f"AUC-ROC:          {auc:.4f}")
print(f"pAUC (TPR>=0.80): {pauc:.4f}  (max 0.2000)")
print(f"Sensitivity:      {sens:.4f}")
print(f"Specificity:      {spec:.4f}")
print(f"Balanced Acc:     {bal_acc:.4f}")
print(f"\nConfusion Matrix:")
cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
print(f"  True Benign:    {cm[0][0]:>5d} correct, {cm[0][1]:>5d} false alarm")
print(f"  True Malignant: {cm[1][1]:>5d} detected, {cm[1][0]:>5d} missed")

In [None]:
# ROC curve (left) and confusion-matrix heatmap (right), saved to OUTPUT_DIR.
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
ax_roc, ax_cm = axes

ax_roc.plot(fpr, tpr, 'b-', lw=2, label=f'ROC (AUC={auc:.4f})')
ax_roc.plot([0, 1], [0, 1], 'k--', alpha=0.5)
if mask.any():
    # Shade the band above the 0.8 TPR line that the competition pAUC is about.
    ax_roc.fill_between(fpr[mask], 0.8, tpr[mask], alpha=0.2, color='green', label='pAUC region')
ax_roc.axhline(y=0.8, color='r', linestyle=':', alpha=0.5)
ax_roc.set_xlabel('FPR')
ax_roc.set_ylabel('TPR')
ax_roc.set_title('ROC Curve')
ax_roc.legend()
ax_roc.grid(True, alpha=0.3)

im = ax_cm.imshow(cm, cmap='Blues')
plt.colorbar(im, ax=ax_cm)
half_max = cm.max() / 2
for (i, j), count in np.ndenumerate(cm):
    # White text on dark cells, black on light ones.
    ax_cm.text(j, i, str(count), ha='center', va='center',
               color='white' if count > half_max else 'black', fontsize=14)
tick_positions = [0, 1]
tick_names = ['Benign', 'Malignant']
ax_cm.set_xticks(tick_positions)
ax_cm.set_yticks(tick_positions)
ax_cm.set_xticklabels(tick_names)
ax_cm.set_yticklabels(tick_names)
ax_cm.set_xlabel('Predicted')
ax_cm.set_ylabel('True')
ax_cm.set_title('Confusion Matrix')

plt.tight_layout()
plt.savefig(os.path.join(OUTPUT_DIR, 'eval_plots.png'), dpi=150)
plt.show()

## 12. Save Model

In [None]:
import shutil

# Save LoRA adapter (and the processor, so the directory is self-contained)
adapter_dir = os.path.join(OUTPUT_DIR, "lora_adapter")
model.save_pretrained(adapter_dir)
processor.save_pretrained(adapter_dir)

# Save config + metrics
config = {
    "model": MODEL_NAME, "lora_r": LORA_R, "lora_alpha": LORA_ALPHA,
    "focal_gamma": FOCAL_GAMMA, "focal_alpha": FOCAL_ALPHA,
    "class_weights": class_weights,
    "neg_sample_frac": NEG_SAMPLE_FRAC, "pos_sample_frac": POS_SAMPLE_FRAC,
    "max_steps": MAX_STEPS, "n_folds": N_FOLDS, "val_fold": VAL_FOLD,
    "auc_roc": float(auc), "pauc": float(pauc),
    "sensitivity": float(sens), "specificity": float(spec),
}
with open(os.path.join(adapter_dir, "training_config.json"), "w") as f:
    json.dump(config, f, indent=2)

# Save predictions
pred_df.to_csv(os.path.join(OUTPUT_DIR, "predictions.csv"), index=False)

# Zip for download. shutil.make_archive raises on failure (the previous `zip`
# subprocess call discarded its exit status) and stores paths relative to
# OUTPUT_DIR, so the archive unpacks to lora_adapter/ instead of the full
# absolute output path.
adapter_zip = shutil.make_archive(
    os.path.join(OUTPUT_DIR, "lora_adapter"), "zip",
    root_dir=OUTPUT_DIR, base_dir="lora_adapter",
)

print(f"Saved to: {adapter_dir}")
print(f"Download: {OUTPUT_DIR}/lora_adapter.zip")

## 13. Merge LoRA into Base Model (for GGUF Export)

In [None]:
# Free the 4-bit training model to make room for float16 merge
del model, trainer
clear_memory()

from transformers import AutoModelForImageTextToText, AutoProcessor
from peft import PeftModel

# Load base model in float16 (NOT 4-bit — need full precision for clean merge)
# NOTE(review): the adapter was trained against the 4-bit NF4 base, so the merged
# fp16 model may behave slightly differently from the evaluated 4-bit + LoRA model.
print("Loading base model in float16 for LoRA merge...")
base_model = AutoModelForImageTextToText.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float16,
    device_map="auto",
    token=HF_TOKEN,
    trust_remote_code=True,
)
merge_processor = AutoProcessor.from_pretrained(MODEL_NAME, token=HF_TOKEN, trust_remote_code=True)

# Load and merge LoRA adapter; merge_and_unload folds the LoRA deltas into the base
# weights and returns the plain model without PEFT wrappers.
adapter_dir = os.path.join(OUTPUT_DIR, "lora_adapter")
print(f"Loading LoRA adapter from {adapter_dir}...")
merged_model = PeftModel.from_pretrained(base_model, adapter_dir)
merged_model = merged_model.merge_and_unload()
print("LoRA merged into base model!")

# Save merged model as safetensors (input for the GGUF conversion cells below)
merged_dir = os.path.join(OUTPUT_DIR, "merged_model")
os.makedirs(merged_dir, exist_ok=True)
merged_model.save_pretrained(merged_dir, safe_serialization=True)
merge_processor.save_pretrained(merged_dir)

# Free GPU memory
del merged_model, base_model
clear_memory()

print(f"Merged model saved to {merged_dir}")
# IPython shell escape: report the on-disk size of the merged checkpoint.
!du -sh {merged_dir}

## 14. Convert to GGUF (for Ollama / llama.cpp CPU Serving)

In [None]:
%%bash -e
# Clone llama.cpp (shallow clone for speed), install conversion deps, build llama-quantize (CPU only).
# With plain `-e`, a pipeline's status is that of its LAST command (`tail`), so a
# failed clone/pip/cmake piped through `tail -3` would go unnoticed; pipefail
# makes those failures stop the cell.
set -o pipefail

LLAMA_DIR=/kaggle/working/llama.cpp

echo "=== Cloning llama.cpp ==="
# Skip the clone on re-runs: git refuses to clone into an existing non-empty
# directory, which would now (correctly) abort the cell under pipefail.
if [ -d "$LLAMA_DIR/.git" ]; then
    echo "llama.cpp already cloned at $LLAMA_DIR"
else
    git clone --depth 1 https://github.com/ggml-org/llama.cpp "$LLAMA_DIR" 2>&1 | tail -3
fi

echo "=== Installing conversion requirements ==="
pip install -q -r "$LLAMA_DIR/requirements/requirements-convert_hf_to_gguf.txt" 2>&1 | tail -3

echo "=== Building llama-quantize ==="
cd "$LLAMA_DIR"
cmake -B build -DCMAKE_BUILD_TYPE=Release -DGGML_CUDA=OFF 2>&1 | tail -3
cmake --build build --config Release -j$(nproc) --target llama-quantize 2>&1 | tail -3

echo "=== Done ==="
ls -la build/bin/llama-quantize

In [None]:
import subprocess

# Inputs/outputs of the three conversion steps.
merged_dir = os.path.join(OUTPUT_DIR, "merged_model")
gguf_dir = os.path.join(OUTPUT_DIR, "gguf")
os.makedirs(gguf_dir, exist_ok=True)

f16_gguf = os.path.join(gguf_dir, "medgemma-isic-f16.gguf")
mmproj_gguf = os.path.join(gguf_dir, "mmproj-medgemma-isic-f16.gguf")
q4_gguf = os.path.join(gguf_dir, "medgemma-isic-Q4_K_M.gguf")

LLAMA_CPP_DIR = "/kaggle/working/llama.cpp"
CONVERT_SCRIPT = f"{LLAMA_CPP_DIR}/convert_hf_to_gguf.py"


def _print_banner(title, leading_newline):
    """Print a '=' * 50 framed section title (optionally preceded by a blank line)."""
    print(("\n" if leading_newline else "") + "=" * 50)
    print(title)
    print("=" * 50)


# (title, command) for: text model -> F16 GGUF, vision encoder/projector -> mmproj
# F16 GGUF, then Q4_K_M quantisation of the text model. check=True aborts on failure.
conversion_steps = [
    ("Step 1/3: Converting text model to GGUF F16...",
     ["python", CONVERT_SCRIPT, merged_dir, "--outfile", f16_gguf, "--outtype", "f16"]),
    ("Step 2/3: Converting vision encoder (mmproj) to GGUF F16...",
     ["python", CONVERT_SCRIPT, merged_dir, "--mmproj", "--outfile", mmproj_gguf, "--outtype", "f16"]),
    ("Step 3/3: Quantizing text model to Q4_K_M...",
     [f"{LLAMA_CPP_DIR}/build/bin/llama-quantize", f16_gguf, q4_gguf, "Q4_K_M"]),
]
for step_no, (title, command) in enumerate(conversion_steps):
    _print_banner(title, leading_newline=step_no > 0)
    subprocess.run(command, check=True)

# The F16 text model was only an intermediate for quantisation; drop it to save disk.
os.remove(f16_gguf)
print("\nDeleted intermediate F16 model to save space.")

_print_banner("GGUF FILES READY", leading_newline=True)
for artifact in (q4_gguf, mmproj_gguf):
    print(f"  {os.path.basename(artifact)}: {os.path.getsize(artifact) / (1024 * 1024):.0f} MB")

## 15. Package GGUF for Download & Ollama Setup

In [None]:
import subprocess

gguf_dir = os.path.join(OUTPUT_DIR, "gguf")
q4_gguf = os.path.join(gguf_dir, "medgemma-isic-Q4_K_M.gguf")
mmproj_gguf = os.path.join(gguf_dir, "mmproj-medgemma-isic-f16.gguf")

# Create Ollama Modelfile. Paths are relative because the zip below is flat (-j):
# Modelfile and both GGUF files end up in the same directory.
# NOTE(review): `PROJECTOR` is not among the instructions in Ollama's Modelfile
# reference (FROM, PARAMETER, TEMPLATE, SYSTEM, ADAPTER, LICENSE, MESSAGE); confirm
# `ollama create` accepts it, or reference the mmproj file with a second FROM line.
modelfile_content = """FROM ./medgemma-isic-Q4_K_M.gguf
PROJECTOR ./mmproj-medgemma-isic-f16.gguf

SYSTEM \"\"\"You are a dermatology AI assistant specialized in skin lesion analysis. Analyze dermoscopic images and classify lesions as benign or malignant with a confidence score.\"\"\"

PARAMETER temperature 0.1
PARAMETER top_p 0.9
PARAMETER num_predict 200
PARAMETER stop <end_of_turn>
"""

modelfile_path = os.path.join(gguf_dir, "Modelfile")
with open(modelfile_path, "w") as f:
    f.write(modelfile_content)
print(f"Modelfile written to {modelfile_path}")

# Zip GGUF files + Modelfile for download (-j: store without directory paths;
# check=True raises if zip fails)
zip_path = os.path.join(OUTPUT_DIR, "medgemma-isic-gguf.zip")
subprocess.run([
    "zip", "-j", zip_path,
    q4_gguf,
    mmproj_gguf,
    modelfile_path,
], check=True)

zip_size_mb = os.path.getsize(zip_path) / (1024 * 1024)
print(f"\n{'=' * 50}")
print(f"GGUF PACKAGE READY FOR DOWNLOAD")
print(f"{'=' * 50}")
print(f"  Archive: {zip_path} ({zip_size_mb:.0f} MB)")
print(f"\n  Contents:")
print(f"    - medgemma-isic-Q4_K_M.gguf  (quantized text model)")
print(f"    - mmproj-medgemma-isic-f16.gguf  (vision encoder)")
print(f"    - Modelfile  (Ollama model definition)")
print(f"\n  LOCAL SETUP INSTRUCTIONS:")
print(f"  1. Download medgemma-isic-gguf.zip from Kaggle output")
print(f"  2. unzip medgemma-isic-gguf.zip -d ~/models/isic-medgemma/")
print(f"  3. cd ~/models/isic-medgemma/")
print(f"  4. ollama create isic-medgemma -f Modelfile")
print(f"  5. ollama run isic-medgemma 'Analyze this skin lesion'")
# NOTE(review): isic_ollama_serve.py is not created by this notebook; it must be
# supplied separately.
print(f"\n  Python API:")
print(f"    python isic_ollama_serve.py --image lesion.jpg")