# In [None]:  (notebook cell prompt left over from export)
"""
train_multilabel_efficientnetv2.py

Requirements:
- Python 3.8+
- torch
- torchvision
- timm
- numpy
- pandas
- Pillow
- scikit-learn
- tqdm
Optional (but recommended):
- tensorboard (for logging)
Install example:
 pip install torch torchvision timm numpy pandas Pillow scikit-learn tqdm tensorboard
"""

import os
import math
import time
from typing import List, Tuple, Dict

import numpy as np
import pandas as pd
from PIL import Image
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from torchvision import transforms
import timm  # for EfficientNetV2

from sklearn.metrics import roc_auc_score, f1_score

# ---------------------------
# USER CONFIG: modify these
# ---------------------------
# Training CSV. It must contain a 'filename' column plus one 0/1 column for
# every entry of LABEL_COLUMNS below.
TRAIN_CSV ="/home/sutirtha/anaconda3/sutirtha_research_operations/OCT_Data/Adi_test/output_90.csv"
# Validation CSV, same schema as TRAIN_CSV.
VAL_CSV   = "/home/sutirtha/anaconda3/sutirtha_research_operations/OCT_Data/Adi_test/output_10.csv" 
# Directory that the 'filename' values are joined onto when loading images.
IMAGE_FOLDER = "/home/sutirtha/anaconda3/sutirtha_research_operations/OCT_Data/OCT_layerwise_classification_dataset_15k/oct_data_15k/data"
IMAGE_SIZE = 384                   # images are resized to IMAGE_SIZE x IMAGE_SIZE before entering the model
BATCH_SIZE = 16
NUM_EPOCHS = 50                    # upper bound; early stopping may end training sooner
NUM_WORKERS = 6                    # DataLoader worker processes
LEARNING_RATE = 1e-4               # initial AdamW learning rate
WEIGHT_DECAY = 1e-4                # AdamW weight decay
PATIENCE_ES = 8    # early stopping patience (epochs without validation AUC improvement)
PATIENCE_LR = 3    # ReduceLROnPlateau patience (epochs before the LR is halved)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# timm model identifier handed to build_model().
# NOTE(review): the file name and some comments mention EfficientNetV2, but this
# identifier names an EfficientNet-B3 variant - confirm which one is intended.
MODEL_NAME = "tf_efficientnet_b3_ns"
NUM_GPUS = torch.cuda.device_count()   # more than one GPU enables nn.DataParallel in main()
USE_WEIGHTED_SAMPLER = True   # set False to disable oversampling (plain shuffling is used instead)
MIXED_PRECISION = True       # use torch.cuda.amp when True and CUDA available

# CSV label columns: must match the names in your CSV that correspond to each class.
# Order matters: this will be the order for model outputs.
LABEL_COLUMNS = [
    # vitreomacular_layer (5)
    'Vitreomacular Traction(#D95030)',
    'Epiretinal Membrane(ERM)(#EA899A)',
    'Full Thickness Macular Hole(FTMH)(#F54021)',
    'Lamellar Macular Hole(LMH)(#F3A505)',
    'Pseudo Macular Hole(#79553D)',

    # intraretinal_layer (4)
    'Intraretinal Fluid/Spongiform Edema(#EA899A)',
    'Subretinal Fluid(IRL)(#B44C43)',
    'Cystoid Macular Edema(CME)(#00BB2D)',
    'Hyperreflective Intraretinal Foci(#EFA94A)',

    # subretinal_layer (5)
    'Subretinal Fluid(SRL)(#8673A1)',
    'Subretinal Hyperreflective Material(SHRM)(#6A5D4D)',
    'Drusen(#FAD201)',
    'CNVM(#316650)',
    'PED(#0E294B)',
]
# ---------------------------


# ---------------------------
# Dataset
# ---------------------------
class OCTMultiLabelDataset(Dataset):
    """Multi-label image dataset backed by a DataFrame of file names and 0/1 labels.

    Each item is an ``(image, labels)`` pair. The image is read from
    ``image_folder``, converted to RGB and passed through ``transform`` when
    one is given. ``labels`` is a float32 tensor holding one value per entry
    of ``label_cols``, in that order.
    """

    def __init__(self, df: pd.DataFrame, image_folder: str, label_cols: List[str],
                 transform=None):
        """
        df: DataFrame must contain a column 'filename' with file names relative to image_folder,
            and label columns with 0/1 values.
        image_folder: directory the file names are joined onto.
        label_cols: label column names; their order fixes the order of the label vector.
        transform: optional callable applied to the loaded RGB PIL image.

        Raises KeyError when 'filename' or one of label_cols is missing from df.
        """
        self.df = df.reset_index(drop=True)
        self.image_folder = image_folder
        self.label_cols = label_cols
        self.transform = transform
        # Materialise file names and the label matrix once. Per-item access then
        # is a plain list/array lookup instead of building a pandas row Series
        # and converting it on every __getitem__ call.
        self._filenames = [str(name) for name in self.df['filename']]
        self._labels = self.df[list(label_cols)].to_numpy().astype(np.float32)

    def __len__(self):
        """Return the number of samples (rows of the DataFrame)."""
        return len(self._filenames)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, labels)``."""
        img_path = os.path.join(self.image_folder, self._filenames[idx])
        # The context manager releases the file handle as soon as the pixel
        # data has been converted; convert() returns an independent image.
        with Image.open(img_path) as raw:
            img = raw.convert('RGB')
        if self.transform:
            img = self.transform(img)
        # torch.tensor copies, so the cached label matrix is never aliased.
        labels = torch.tensor(self._labels[idx])
        return img, labels


# ---------------------------
# Utilities for balancing
# ---------------------------
def compute_pos_weight(df: pd.DataFrame, label_cols: List[str]) -> torch.Tensor:
    """
    Build the per-class ``pos_weight`` vector for torch.nn.BCEWithLogitsLoss.

    For every column in ``label_cols`` the weight is (#negatives / #positives),
    where #negatives is the row count minus the column sum. The positive count
    is floored at 1 so a class without positives does not divide by zero.
    Returns a float32 tensor ordered like ``label_cols``.
    """
    n_rows = len(df)
    positives = [df[col].sum() for col in label_cols]
    ratios = [(n_rows - n_pos) / max(n_pos, 1.0) for n_pos in positives]
    return torch.tensor(ratios, dtype=torch.float32)


def make_weighted_sampler(df: pd.DataFrame, label_cols: List[str]) -> WeightedRandomSampler:
    """
    Build a WeightedRandomSampler that favours samples carrying rare labels.

    A sample's weight is the mean inverse frequency of the labels it has.
    Samples without any label are treated as one extra group and weighted by
    the inverse of that group's size. A fixed constant would not be on the
    same scale as the inverse frequencies (which are at most 1 and usually
    far smaller) and would over-sample the label-free rows.

    Weights are normalised to mean 1. The sampler draws ``len(df)`` indices
    per epoch, with replacement.

    Raises ValueError when ``df`` has no rows.
    """
    label_matrix = df[label_cols].values.astype(int).astype(bool)
    n_samples = label_matrix.shape[0]
    if n_samples == 0:
        raise ValueError("Cannot build a weighted sampler from an empty DataFrame.")

    # Per-label positive counts, floored at 1 so absent labels do not divide by zero.
    freqs = np.clip(df[label_cols].sum(axis=0).values.astype(float), 1.0, None)
    inv_freq = 1.0 / freqs

    n_present = label_matrix.sum(axis=1)
    has_label = n_present > 0
    n_unlabelled = int(n_samples - has_label.sum())

    sample_weights = np.empty(n_samples, dtype=np.float64)
    # Row-wise sum of inverse frequencies over the labels that are present.
    summed = label_matrix.astype(np.float64) @ inv_freq
    sample_weights[has_label] = summed[has_label] / n_present[has_label]
    if n_unlabelled:
        sample_weights[~has_label] = 1.0 / n_unlabelled

    sample_weights = sample_weights / sample_weights.mean()  # normalize to mean 1
    return WeightedRandomSampler(weights=sample_weights, num_samples=n_samples, replacement=True)


# ---------------------------
# Model builder
# ---------------------------
def build_model(model_name: str, num_classes: int, pretrained=True):
    """
    Create a timm model whose classification head has ``num_classes`` outputs.

    model_name: timm model identifier passed straight to ``timm.create_model``.
    num_classes: size of the output layer; one logit per label for multi-label use.
    pretrained: forwarded to timm; when True timm is asked for pretrained weights.

    Returns the model as created by timm (no sigmoid is added here; the
    training loop applies it to the outputs).
    """
    # Forward the caller's choice instead of always requesting pretrained weights.
    return timm.create_model(model_name, pretrained=pretrained, num_classes=num_classes)


# ---------------------------
# Metrics
# ---------------------------
def compute_metrics(y_true: np.ndarray, y_prob: np.ndarray, threshold=0.5) -> Dict:
    """
    Compute multi-label evaluation metrics.

    y_true: (N, C) array of {0,1} ground-truth labels.
    y_prob: (N, C) array of probabilities (sigmoid already applied).
    threshold: probabilities >= threshold count as positive predictions for F1.

    Returns a dict with:
      per_class_auc: list of C ROC-AUC values; NaN where a class has only one
                     label value present in y_true.
      avg_auc:       mean of the non-NaN per-class AUCs, NaN if none is defined.
      micro_f1:      F1 from TP/FP/FN pooled over all classes.
      macro_f1:      unweighted mean of the per-class F1 scores.
    """
    y_true_bin = np.asarray(y_true).astype(int)
    y_prob = np.asarray(y_prob)

    aucs = []
    for c in range(y_true_bin.shape[1]):
        try:
            auc = roc_auc_score(y_true_bin[:, c], y_prob[:, c])
        except ValueError:
            auc = float('nan')  # if only one class present in y_true
        aucs.append(auc)

    # Average only the defined AUCs; avoids the all-NaN warning of np.nanmean.
    defined = [a for a in aucs if not np.isnan(a)]
    avg_auc = float(np.mean(defined)) if defined else float('nan')

    y_pred = (y_prob >= threshold).astype(int)
    # Keep the (N, C) indicator layout. On flattened 1-D binary vectors
    # average='micro' pools label values 0 and 1 alike, which makes the score
    # equal to plain accuracy instead of the multilabel micro-F1.
    micro_f1 = f1_score(y_true_bin, y_pred, average='micro', zero_division=0)
    macro_f1 = f1_score(y_true_bin, y_pred, average='macro', zero_division=0)
    return {
        "per_class_auc": aucs,
        "avg_auc": avg_auc,
        "micro_f1": micro_f1,
        "macro_f1": macro_f1
    }


# ---------------------------
# Early stopping
# ---------------------------
class EarlyStopping:
    """Track a monitored metric and signal when it has stopped improving."""

    def __init__(self, patience=10, mode='max', delta=0.0):
        """
        patience: epochs to wait after last improvement
        mode: 'max' if higher metric is better, any other value means lower is better
        delta: minimum change to qualify as improvement
        """
        self.patience = patience
        self.mode = mode
        self.delta = delta
        self.best = None            # best metric seen so far (None until a valid value arrives)
        self.num_bad_epochs = 0     # consecutive steps without improvement
        self.is_done = False        # set once patience has been exhausted

    def step(self, metric):
        """
        Record one epoch's metric and return True when training should stop.

        A NaN metric never becomes ``best``; it counts as an epoch without
        improvement. If NaN were stored as ``best``, every later comparison
        against it would be False and no value could ever count as an
        improvement.
        """
        if math.isnan(metric):
            improved = False
        elif self.best is None:
            improved = True  # first valid value establishes the baseline
        elif self.mode == 'max':
            improved = metric > (self.best + self.delta)
        else:
            improved = metric < (self.best - self.delta)

        if improved:
            self.best = metric
            self.num_bad_epochs = 0
            return False

        self.num_bad_epochs += 1
        if self.num_bad_epochs >= self.patience:
            self.is_done = True
            return True
        return False


# ---------------------------
# Training / Validation loops
# ---------------------------
# def train_one_epoch(model, loader, criterion, optimizer, device, scaler=None):
#     model.train()
#     running_loss = 0.0
#     all_targets = []
#     all_outputs = []
#     pbar = tqdm(loader, desc="Train", leave=False)
#     for images, targets in pbar:
#         images = images.to(device)
#         targets = targets.to(device)
#         optimizer.zero_grad()

#         if scaler is not None:
#             with torch.cuda.amp.autocast():
#                 outputs = model(images)
#                 loss = criterion(outputs, targets)
#             scaler.scale(loss).backward()
#             scaler.step(optimizer)
#             scaler.update()
#         else:
#             outputs = model(images)
#             loss = criterion(outputs, targets)
#             loss.backward()
#             optimizer.step()

#         running_loss += loss.item() * images.size(0)
#         all_targets.append(targets.detach().cpu().numpy())
#         # store probabilities (sigmoid)
#         all_targets.append(targets.detach().cpu().numpy())
#         all_outputs.append(torch.sigmoid(outputs.detach()).cpu().numpy())

#         pbar.set_postfix(loss=loss.item())

#     epoch_loss = running_loss / len(loader.dataset)
#     y_true = np.vstack(all_targets)
#     y_prob = np.vstack(all_outputs)
#     metrics = compute_metrics(y_true, y_prob)
#     return epoch_loss, metrics


# def validate_one_epoch(model, loader, criterion, device):
#     model.eval()
#     running_loss = 0.0
#     all_targets = []
#     all_outputs = []
#     with torch.no_grad():
#         pbar = tqdm(loader, desc="Valid", leave=False)
#         for images, targets in pbar:
#             images = images.to(device)
#             targets = targets.to(device)
#             outputs = model(images)
#             loss = criterion(outputs, targets)
#             running_loss += loss.item() * images.size(0)
#             all_targets.append(targets.cpu().numpy())
#             all_targets.append(targets.detach().cpu().numpy())
#             all_outputs.append(torch.sigmoid(outputs.detach()).cpu().numpy())


        
#     epoch_loss = running_loss / len(loader.dataset)
#     # y_true = np.vstack(all_targets)
#     # y_prob = np.vstack(all_outputs)
#     # Safely concatenate and ensure same shape
#     y_true = np.concatenate(all_targets, axis=0)
#     y_prob = np.concatenate(all_outputs, axis=0)

# # Defensive check
#     assert y_true.shape == y_prob.shape, f"Shape mismatch: {y_true.shape} vs {y_prob.shape}" 
#     metrics = compute_metrics(y_true, y_prob)
#     return epoch_loss, metrics
def train_one_epoch(model, loader, criterion, optimizer, device, scaler=None):
    """
    Run one optimisation pass over ``loader`` and report loss and metrics.

    model: network producing one logit per class.
    loader: iterable yielding (images, targets) batches.
    criterion: loss taking (logits, targets).
    optimizer: optimizer stepped once per batch.
    device: device the batches are moved to.
    scaler: optional GradScaler; when given, the forward pass runs under
            torch.cuda.amp.autocast and the scaler drives backward/step.

    Returns (epoch_loss, metrics): the sample-weighted mean loss and the dict
    produced by compute_metrics on the sigmoid probabilities.

    Raises ValueError when the loader yields no batches.
    """
    model.train()
    running_loss = 0.0
    n_seen = 0
    all_targets, all_outputs = [], []
    pbar = tqdm(loader, desc="Train", leave=False)

    for images, targets in pbar:
        images = images.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)
        optimizer.zero_grad(set_to_none=True)

        if scaler is not None:
            with torch.cuda.amp.autocast():
                outputs = model(images)
                loss = criterion(outputs, targets)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            outputs = model(images)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

        # Read the loss once; each .item() call forces a device synchronisation.
        batch_loss = loss.item()
        batch_size = images.size(0)
        running_loss += batch_loss * batch_size
        n_seen += batch_size

        # Under autocast the logits may be half precision; cast to float32 so
        # the probabilities used for ranking metrics are not quantised.
        outputs = torch.sigmoid(outputs.detach().float()).cpu().numpy()
        targets = targets.detach().cpu().numpy()

        # Defensive shape check
        if outputs.shape != targets.shape:
            print(f"[Shape Mismatch] outputs: {outputs.shape}, targets: {targets.shape}")
            # Align only up to min batch length if needed
            min_len = min(outputs.shape[0], targets.shape[0])
            outputs, targets = outputs[:min_len], targets[:min_len]

        all_outputs.append(outputs)
        all_targets.append(targets)
        pbar.set_postfix(loss=batch_loss)

    if n_seen == 0:
        raise ValueError("train_one_epoch received a loader that produced no batches.")

    y_true = np.concatenate(all_targets, axis=0)
    y_prob = np.concatenate(all_outputs, axis=0)

    # Final shape assertion
    assert y_true.shape == y_prob.shape, f"Final shape mismatch: {y_true.shape} vs {y_prob.shape}"

    # Average over the samples actually processed, which stays correct when a
    # sampler or drop_last makes that differ from len(loader.dataset).
    epoch_loss = running_loss / n_seen
    metrics = compute_metrics(y_true, y_prob)
    return epoch_loss, metrics


def validate_one_epoch(model, loader, criterion, device):
    """
    Evaluate ``model`` on ``loader`` without gradient tracking.

    model: network producing one logit per class; switched to eval mode here.
    loader: iterable yielding (images, targets) batches.
    criterion: loss taking (logits, targets).
    device: device the batches are moved to.

    Returns (epoch_loss, metrics): the sample-weighted mean loss and the dict
    produced by compute_metrics on the sigmoid probabilities.

    Raises ValueError when the loader yields no batches.
    """
    model.eval()
    running_loss = 0.0
    n_seen = 0
    all_targets, all_outputs = [], []

    with torch.no_grad():
        pbar = tqdm(loader, desc="Valid", leave=False)
        for images, targets in pbar:
            images = images.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)

            outputs = model(images)
            loss = criterion(outputs, targets)

            batch_size = images.size(0)
            running_loss += loss.item() * batch_size
            n_seen += batch_size

            # float() keeps the probabilities in float32 whatever dtype the model emits.
            outputs = torch.sigmoid(outputs.detach().float()).cpu().numpy()
            targets = targets.detach().cpu().numpy()

            if outputs.shape != targets.shape:
                print(f"[Shape Mismatch] outputs: {outputs.shape}, targets: {targets.shape}")
                min_len = min(outputs.shape[0], targets.shape[0])
                outputs, targets = outputs[:min_len], targets[:min_len]

            all_outputs.append(outputs)
            all_targets.append(targets)

    if n_seen == 0:
        raise ValueError("validate_one_epoch received a loader that produced no batches.")

    y_true = np.concatenate(all_targets, axis=0)
    y_prob = np.concatenate(all_outputs, axis=0)

    assert y_true.shape == y_prob.shape, f"Final shape mismatch: {y_true.shape} vs {y_prob.shape}"

    # Average over the samples actually processed rather than len(loader.dataset).
    epoch_loss = running_loss / n_seen
    metrics = compute_metrics(y_true, y_prob)
    return epoch_loss, metrics


# ---------------------------
# Main training function
# ---------------------------
def main():
    """
    Train the multi-label classifier configured by the module-level constants.

    Loads the train/validation CSVs, builds datasets, loaders, model, loss,
    optimizer and scheduler, then trains for up to NUM_EPOCHS epochs. The
    checkpoint with the best validation average AUC is written to
    ``best_{MODEL_NAME}_multilabel.pth`` in the working directory.

    Returns (model, history) where history holds per-epoch train/val loss and AUC.

    Raises ValueError when a CSV lacks the 'filename' column or a label column.
    """
    print(f"Device: {DEVICE}, GPUs available: {NUM_GPUS}")
    use_cuda = DEVICE == "cuda"

    # --- load csvs
    df_train = pd.read_csv(TRAIN_CSV)
    df_val = pd.read_csv(VAL_CSV)

    # ensure required columns exist and label columns are integers
    for df in (df_train, df_val):
        if 'filename' not in df.columns:
            raise ValueError("Column 'filename' not found in CSV.")
        for c in LABEL_COLUMNS:
            if c not in df.columns:
                raise ValueError(f"Label column {c} not found in CSV.")
            df[c] = df[c].astype(int)

    num_classes = len(LABEL_COLUMNS)

    # --- transforms (augmentations)
    train_transform = transforms.Compose([
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.RandomResizedCrop(IMAGE_SIZE, scale=(0.8, 1.0), ratio=(0.9, 1.1)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(degrees=10),
        transforms.ColorJitter(brightness=0.05, contrast=0.05, saturation=0.02, hue=0.01),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],  # ImageNet stats
                             std=[0.229, 0.224, 0.225]),
    ])
    valid_transform = transforms.Compose([
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    # --- datasets
    train_dataset = OCTMultiLabelDataset(df_train, IMAGE_FOLDER, LABEL_COLUMNS, transform=train_transform)
    valid_dataset = OCTMultiLabelDataset(df_val, IMAGE_FOLDER, LABEL_COLUMNS, transform=valid_transform)

    # --- sampler or simple shuffle (DataLoader requires shuffle=False when a sampler is given)
    sampler = make_weighted_sampler(df_train, LABEL_COLUMNS) if USE_WEIGHTED_SAMPLER else None
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, sampler=sampler,
                              shuffle=sampler is None,
                              num_workers=NUM_WORKERS, pin_memory=use_cuda)
    valid_loader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=False,
                              num_workers=NUM_WORKERS, pin_memory=use_cuda)

    # --- model
    model = build_model(MODEL_NAME, num_classes=num_classes, pretrained=True)
    model = model.to(DEVICE)
    if use_cuda and NUM_GPUS > 1:
        print(f"Using DataParallel across {NUM_GPUS} GPUs")
        model = nn.DataParallel(model)

    # --- loss with pos_weight
    pos_weight = compute_pos_weight(df_train, LABEL_COLUMNS).to(DEVICE)
    criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

    # --- optimizer and scheduler
    optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
    # scheduler: Reduce LR on plateau - monitor val avg_auc (higher is better) so we pass mode='max'.
    # The `verbose` flag is deprecated in recent PyTorch releases, so LR
    # changes are reported explicitly inside the loop instead.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5,
                                                           patience=PATIENCE_LR)

    # --- AMP scaler
    scaler = torch.cuda.amp.GradScaler() if (MIXED_PRECISION and use_cuda) else None

    # --- early stopping
    early_stopper = EarlyStopping(patience=PATIENCE_ES, mode='max')

    # bookkeeping
    best_val_auc = -np.inf
    history = {"train_loss": [], "val_loss": [], "train_auc": [], "val_auc": []}

    for epoch in range(1, NUM_EPOCHS + 1):
        print(f"Epoch {epoch}/{NUM_EPOCHS}")
        start_time = time.time()

        train_loss, train_metrics = train_one_epoch(model, train_loader, criterion, optimizer, DEVICE, scaler=scaler)
        val_loss, val_metrics = validate_one_epoch(model, valid_loader, criterion, DEVICE)

        val_avg_auc = val_metrics["avg_auc"]
        # If avg_auc is NaN (no class had both label values), monitor -inf instead.
        # The same sanitised value feeds the scheduler and the early stopper,
        # so a NaN epoch can never become the early stopper's baseline.
        monitored_auc = -np.inf if np.isnan(val_avg_auc) else val_avg_auc

        lr_before = optimizer.param_groups[0]["lr"]
        scheduler.step(monitored_auc)
        lr_after = optimizer.param_groups[0]["lr"]
        if lr_after != lr_before:
            print(f"Learning rate reduced from {lr_before:.2e} to {lr_after:.2e}")

        history["train_loss"].append(train_loss)
        history["val_loss"].append(val_loss)
        history["train_auc"].append(train_metrics["avg_auc"])
        history["val_auc"].append(val_avg_auc)

        print(f"Train loss: {train_loss:.4f} | Train AUC: {train_metrics['avg_auc']:.4f} | "
              f"Val loss: {val_loss:.4f} | Val AUC: {val_avg_auc:.4f} | Time: {time.time()-start_time:.1f}s")

        # save best
        if not np.isnan(val_avg_auc) and val_avg_auc > best_val_auc:
            best_val_auc = val_avg_auc
            # un-wrap DataParallel if used
            model_to_save = model.module if isinstance(model, nn.DataParallel) else model
            save_path = f"best_{MODEL_NAME}_multilabel.pth"
            torch.save({
                "model_state_dict": model_to_save.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "epoch": epoch,
                "best_val_auc": best_val_auc,
                "label_cols": LABEL_COLUMNS
            }, save_path)
            print(f"Saved best model to {save_path}")

        # early stopping
        if early_stopper.step(monitored_auc):
            print("Early stopping triggered.")
            break

    print("Training finished.")
    # optionally return model and history
    return model, history


if __name__ == "__main__":
    main()
