# 01 - Comprehensive EDA: ISIC 2018 Task 3 / HAM10000

**Objective:** Perform a thorough exploratory data analysis of the HAM10000 dermoscopic image dataset to inform modeling decisions for skin lesion classification.

---

## Table of Contents

1. **Setup & Data Loading**
2. **Dataset Overview & Schema Inspection**
3. **Class Distribution & Imbalance Analysis**
4. **Demographic Analysis** (Age, Sex)
5. **Lesion Localization Analysis**
6. **Missing Value Analysis**
7. **Lesion Grouping & Data Leakage Risk**
8. **Feature Correlations & Statistical Tests**
9. **Image Quality & Properties Analysis**
10. **Per-Class Sample Image Grid**
11. **Co-occurrence & Interaction Analysis**
12. **Class Imbalance Strategy Recommendations**
13. **Key Findings & Modeling Implications**

---

**Dataset:** ISIC 2018 Challenge Task 3 (HAM10000)
**Classes:** MEL (Melanoma), NV (Melanocytic Nevus), BCC (Basal Cell Carcinoma), AKIEC (Actinic Keratosis / Intraepithelial Carcinoma), BKL (Benign Keratosis-like Lesion), DF (Dermatofibroma), VASC (Vascular Lesion)

In [None]:
# =============================================================================
# 1. Setup & Data Loading
# =============================================================================
from __future__ import annotations

import warnings
from pathlib import Path
from collections import Counter

import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import numpy as np
import pandas as pd
import seaborn as sns
from PIL import Image
from scipy import stats

warnings.filterwarnings("ignore", category=FutureWarning)

# -- Plotting defaults --
sns.set_theme(style="whitegrid", font_scale=1.1)
plt.rcParams.update({
    "figure.dpi": 120,
    "savefig.dpi": 150,
    "axes.titlesize": 14,
    "axes.labelsize": 12,
    "figure.figsize": (12, 5),
})

# Clinical label mapping for readable names
LABEL_MAP = {
    "mel": "Melanoma",
    "nv": "Melanocytic Nevus",
    "bcc": "Basal Cell Carcinoma",
    "akiec": "Actinic Keratosis",
    "bkl": "Benign Keratosis",
    "df": "Dermatofibroma",
    "vasc": "Vascular Lesion",
}
CLASS_ORDER = ["nv", "mel", "bkl", "bcc", "akiec", "df", "vasc"]
PALETTE = sns.color_palette("Set2", n_colors=7)
CLASS_PALETTE = dict(zip(CLASS_ORDER, PALETTE))

# -- Resolve paths --
if Path("/content/DermaFusion").exists():
    PROJECT_ROOT = Path("/content/DermaFusion")
else:
    # Use cwd if it is the project root (contains src/); otherwise assume we are one level down (e.g. notebooks/)
    PROJECT_ROOT = Path.cwd() if (Path.cwd() / "src").exists() else Path.cwd().resolve().parent

RAW_DIR = PROJECT_ROOT / "data" / "raw"
META_DIR = RAW_DIR / "metadata"
MERGED_CSV = META_DIR / "metadata_merged.csv"
HAM_META_CSV = META_DIR / "HAM10000_metadata.csv"
TRAIN_GT = META_DIR / "ISIC2018_Task3_Training_GroundTruth.csv"
GROUPINGS_CSV = META_DIR / "ISIC2018_Task3_Training_LesionGroupings.csv"

# -- Load metadata --
if MERGED_CSV.exists():
    df = pd.read_csv(MERGED_CSV)
    print(f"Loaded merged metadata: {MERGED_CSV}")
elif HAM_META_CSV.exists():
    df = pd.read_csv(HAM_META_CSV)
    print(f"Loaded HAM10000 metadata: {HAM_META_CSV}")
elif TRAIN_GT.exists() and GROUPINGS_CSV.exists():
    gt = pd.read_csv(TRAIN_GT)
    grp = pd.read_csv(GROUPINGS_CSV)
    df = gt.merge(grp, on="image", how="left").rename(columns={"image": "image_id"})
    print("Built metadata from ISIC 2018 ground truth + lesion groupings")
else:
    raise FileNotFoundError(
        "No metadata found. Place metadata_merged.csv, HAM10000_metadata.csv, or "
        "ISIC2018_Task3_Training_GroundTruth.csv + ISIC2018_Task3_Training_LesionGroupings.csv "
        "in data/raw/metadata/."
    )

# -- Derive 'dx' column if only one-hot ground truth is available --
if "dx" not in df.columns:
    class_cols = [c for c in ["MEL", "NV", "BCC", "AKIEC", "BKL", "DF", "VASC"] if c in df.columns]
    if class_cols:
        df["dx"] = df[class_cols].idxmax(axis=1).str.lower()

# -- Resolve image directory --
raw_images = RAW_DIR / "images"
preprocessed = PROJECT_ROOT / "data" / "preprocessed_hair_removed" / "images"
if (raw_images / "train").exists():
    IMAGE_DIR = raw_images / "train"
elif raw_images.exists() and any(raw_images.glob("*.jpg")):
    IMAGE_DIR = raw_images
elif preprocessed.exists():
    IMAGE_DIR = preprocessed
else:
    IMAGE_DIR = raw_images

print(f"Project root : {PROJECT_ROOT}")
print(f"Image dir    : {IMAGE_DIR}")
print(f"Total samples: {len(df):,}")
print(f"Columns      : {list(df.columns)}")

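The `idxmax` fallback above assumes each row of the one-hot ground truth has exactly one positive column. For a row with no positive label, `idxmax` silently returns the first column (`MEL`), so a malformed file would quietly inflate the melanoma count. A minimal sketch of a stricter conversion; `one_hot_to_label` is a hypothetical helper, not part of this project, and the toy rows are made up.

```python
import pandas as pd

ONE_HOT_COLS = ["MEL", "NV", "BCC", "AKIEC", "BKL", "DF", "VASC"]


def one_hot_to_label(gt: pd.DataFrame, cols=ONE_HOT_COLS) -> pd.Series:
    """Hypothetical helper: one-hot ground-truth columns -> lowercase label.

    Raises instead of guessing when a row does not have exactly one positive.
    """
    present = [c for c in cols if c in gt.columns]
    onehot = gt[present]
    row_sums = onehot.sum(axis=1)
    bad = row_sums != 1
    if bad.any():
        raise ValueError(f"{int(bad.sum())} rows do not have exactly one positive label")
    return onehot.idxmax(axis=1).str.lower()


# Toy ground truth with two images (values are floats, as in the ISIC CSV)
toy = pd.DataFrame({
    "image": ["img_a", "img_b"],
    "MEL": [0.0, 1.0], "NV": [1.0, 0.0], "BCC": [0.0, 0.0], "AKIEC": [0.0, 0.0],
    "BKL": [0.0, 0.0], "DF": [0.0, 0.0], "VASC": [0.0, 0.0],
})
labels = one_hot_to_label(toy)
print(labels.tolist())  # ['nv', 'mel']
```

Raising rather than repairing keeps the decision about ambiguous rows explicit; for a clean ISIC 2018 ground-truth file the result should match the `idxmax` path exactly.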
In [None]:
# =============================================================================
# 2. Dataset Overview & Schema Inspection
# =============================================================================
print("=" * 60)
print("DATASET SCHEMA")
print("=" * 60)
print(f"\nShape: {df.shape[0]:,} rows x {df.shape[1]} columns\n")
print(df.dtypes.to_string())
print("\n" + "=" * 60)
print("FIRST 5 ROWS")
print("=" * 60)
display(df.head())

print("\n" + "=" * 60)
print("DESCRIPTIVE STATISTICS (Numeric)")
print("=" * 60)
display(df.describe().round(2))

print("\n" + "=" * 60)
print("DESCRIPTIVE STATISTICS (Categorical)")
print("=" * 60)
cat_cols = df.select_dtypes(include=["object"]).columns.tolist()
if cat_cols:
    display(df[cat_cols].describe())

# Unique value counts per column
print("\n" + "=" * 60)
print("UNIQUE VALUES PER COLUMN")
print("=" * 60)
unique_counts = pd.DataFrame({
    "column": df.columns,
    "dtype": df.dtypes.values,
    "unique": [df[c].nunique() for c in df.columns],
    "missing": df.isna().sum().values,
    "missing_pct": (df.isna().sum() / len(df) * 100).round(2).values,
    "sample_values": [str(df[c].dropna().unique()[:5].tolist()) for c in df.columns],
})
display(unique_counts)

In [None]:
# =============================================================================
# 3. Class Distribution & Imbalance Analysis
# =============================================================================
class_counts = df["dx"].value_counts().reindex(CLASS_ORDER)
class_pct = (class_counts / class_counts.sum() * 100).round(2)
majority_class = class_counts.max()
imbalance_ratio = (majority_class / class_counts).round(1)

summary = pd.DataFrame({
    "Class": CLASS_ORDER,
    "Full Name": [LABEL_MAP[c] for c in CLASS_ORDER],
    "Count": class_counts.values,
    "Percent (%)": class_pct.values,
    "Imbalance Ratio (vs majority)": imbalance_ratio.values,
})
display(summary)

fig, axes = plt.subplots(1, 3, figsize=(20, 5))

# Bar plot with counts and percentages
bars = axes[0].bar(CLASS_ORDER, class_counts.values, color=[CLASS_PALETTE[c] for c in CLASS_ORDER],
                   edgecolor="black", linewidth=0.5)
for bar, count, pct in zip(bars, class_counts.values, class_pct.values):
    axes[0].text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 30,
                 f"{count:,}\n({pct}%)", ha="center", va="bottom", fontsize=9, fontweight="bold")
axes[0].set_title("Class Distribution (Counts)", fontweight="bold")
axes[0].set_ylabel("Number of Samples")
axes[0].set_xlabel("Diagnosis")

# Pie chart
axes[1].pie(class_counts.values, labels=[LABEL_MAP[c] for c in CLASS_ORDER],
            autopct="%1.1f%%", colors=[CLASS_PALETTE[c] for c in CLASS_ORDER],
            startangle=140, pctdistance=0.8, textprops={"fontsize": 8})
axes[1].set_title("Class Proportions", fontweight="bold")

# Imbalance ratio bar (log scale)
bars2 = axes[2].bar(CLASS_ORDER, imbalance_ratio.values, color=[CLASS_PALETTE[c] for c in CLASS_ORDER],
                    edgecolor="black", linewidth=0.5)
for bar, ratio in zip(bars2, imbalance_ratio.values):
    axes[2].text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.2,
                 f"{ratio}x", ha="center", va="bottom", fontsize=9, fontweight="bold")
axes[2].set_yscale("log")
axes[2].set_title("Imbalance Ratio (Majority / Class)", fontweight="bold")
axes[2].set_ylabel("Ratio (log scale)")
axes[2].set_xlabel("Diagnosis")
axes[2].axhline(y=1, color="gray", linestyle="--", alpha=0.5)

plt.tight_layout()
plt.show()

# Effective number of samples (for class-balanced loss)
beta = 0.9999
effective_n = (1 - beta ** class_counts.values) / (1 - beta)
inv_freq_weights = (1.0 / class_counts.values)
inv_freq_weights = inv_freq_weights / inv_freq_weights.sum() * len(CLASS_ORDER)

weight_df = pd.DataFrame({
    "Class": CLASS_ORDER,
    "Count": class_counts.values,
    "Inverse-Freq Weight": inv_freq_weights.round(4),
    "Effective Samples (beta=0.9999)": effective_n.round(1),
})
print("\nClass weighting reference (for loss functions):")
display(weight_df)

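The cell above reports the effective number of samples, $E_n = (1 - \beta^n)/(1 - \beta)$, but does not turn it into weights. In the class-balanced loss of Cui et al. (CVPR 2019), each class is weighted by $1/E_n$ and the weights are rescaled to sum to the number of classes. A minimal sketch under that assumption; `class_balanced_weights` is a hypothetical helper and the counts in the example are illustrative, not HAM10000's.

```python
import numpy as np


def class_balanced_weights(counts, beta=0.9999):
    """Per-class weights proportional to 1 / effective number of samples.

    beta -> 1 approaches inverse-frequency weighting; beta = 0 gives uniform
    weights. Weights are rescaled to sum to the number of classes, matching
    the inverse-frequency normalization used in Section 3.
    """
    counts = np.asarray(counts, dtype=np.float64)
    effective_n = (1.0 - np.power(beta, counts)) / (1.0 - beta)
    weights = 1.0 / effective_n
    return weights / weights.sum() * len(counts)


# Illustrative counts for three classes
counts = [1000, 100, 10]
w_cb = class_balanced_weights(counts, beta=0.999)
print(w_cb.round(3))
```

With the notebook's own values this would be `class_balanced_weights(class_counts.values, beta)`. `beta` controls how far the weighting sits between uniform and inverse-frequency, so it is usually treated as a hyperparameter rather than fixed at 0.9999.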
In [None]:
# =============================================================================
# 4. Demographic Analysis (Age & Sex)
# =============================================================================

fig = plt.figure(figsize=(22, 18))
gs = gridspec.GridSpec(3, 3, hspace=0.35, wspace=0.3)

# --- 4a. Overall age distribution ---
if "age" in df.columns:
    ax1 = fig.add_subplot(gs[0, 0])
    df["age"].dropna().hist(bins=30, ax=ax1, color="steelblue", edgecolor="black", alpha=0.8)
    ax1.axvline(df["age"].median(), color="red", linestyle="--", linewidth=1.5, label=f"Median: {df['age'].median():.0f}")
    ax1.axvline(df["age"].mean(), color="orange", linestyle="--", linewidth=1.5, label=f"Mean: {df['age'].mean():.1f}")
    ax1.set_title("Overall Age Distribution", fontweight="bold")
    ax1.set_xlabel("Age")
    ax1.set_ylabel("Count")
    ax1.legend(fontsize=9)

    # --- 4b. Age by diagnosis (box + violin) ---
    ax2 = fig.add_subplot(gs[0, 1:])
    order = CLASS_ORDER
    sns.violinplot(data=df, x="dx", y="age", order=order, palette=CLASS_PALETTE,
                   inner=None, alpha=0.3, ax=ax2)
    sns.boxplot(data=df, x="dx", y="age", order=order, palette=CLASS_PALETTE,
                width=0.3, boxprops=dict(alpha=0.8), ax=ax2)
    ax2.set_title("Age Distribution by Diagnosis (Violin + Box)", fontweight="bold")
    ax2.set_xlabel("Diagnosis")
    ax2.set_ylabel("Age")

    # --- 4c. Age density per class (overlaid KDE) ---
    ax3 = fig.add_subplot(gs[1, 0:2])
    for cls in CLASS_ORDER:
        subset = df.loc[df["dx"] == cls, "age"].dropna()
        if len(subset) > 1:
            subset.plot.kde(ax=ax3, label=f"{cls.upper()} (n={len(subset)})",
                            color=CLASS_PALETTE[cls], linewidth=1.5)
    ax3.set_title("Age Density per Diagnosis (KDE)", fontweight="bold")
    ax3.set_xlabel("Age")
    ax3.set_ylabel("Density")
    ax3.legend(fontsize=8, ncol=2)
    ax3.set_xlim(0, 100)

    # --- 4d. Age statistics table ---
    ax4 = fig.add_subplot(gs[1, 2])
    ax4.axis("off")
    age_stats = df.groupby("dx")["age"].agg(["count", "mean", "median", "std", "min", "max"]).reindex(CLASS_ORDER).round(1)
    age_stats.columns = ["N", "Mean", "Median", "Std", "Min", "Max"]
    tbl = ax4.table(cellText=age_stats.values, rowLabels=age_stats.index,
                    colLabels=age_stats.columns, loc="center", cellLoc="center")
    tbl.auto_set_font_size(False)
    tbl.set_fontsize(9)
    tbl.scale(1.2, 1.4)
    ax4.set_title("Age Statistics by Class", fontweight="bold", pad=20)

# --- 4e. Sex distribution (overall) ---
if "sex" in df.columns:
    ax5 = fig.add_subplot(gs[2, 0])
    sex_counts = df["sex"].value_counts()
    ax5.pie(sex_counts.values, labels=sex_counts.index, autopct="%1.1f%%",
            colors=sns.color_palette("Pastel1", n_colors=len(sex_counts)),
            startangle=90, textprops={"fontsize": 10})
    ax5.set_title("Overall Sex Distribution", fontweight="bold")

    # --- 4f. Sex by diagnosis (stacked %) ---
    ax6 = fig.add_subplot(gs[2, 1])
    sex_dx = pd.crosstab(df["dx"], df["sex"], normalize="index").reindex(CLASS_ORDER) * 100
    sex_dx.plot(kind="bar", stacked=True, ax=ax6, colormap="Pastel1", edgecolor="black", linewidth=0.5)
    ax6.set_title("Sex Distribution by Diagnosis (%)", fontweight="bold")
    ax6.set_ylabel("Percent")
    ax6.set_xlabel("Diagnosis")
    ax6.legend(title="Sex", fontsize=8)
    ax6.tick_params(axis="x", rotation=30)

    # --- 4g. Sex by diagnosis (grouped counts) ---
    ax7 = fig.add_subplot(gs[2, 2])
    sex_dx_counts = pd.crosstab(df["dx"], df["sex"]).reindex(CLASS_ORDER)
    sex_dx_counts.plot(kind="bar", ax=ax7, colormap="Set2", edgecolor="black", linewidth=0.5)
    ax7.set_title("Sex by Diagnosis (Absolute Counts)", fontweight="bold")
    ax7.set_ylabel("Count")
    ax7.set_xlabel("Diagnosis")
    ax7.legend(title="Sex", fontsize=8)
    ax7.tick_params(axis="x", rotation=30)

plt.suptitle("Section 4: Demographic Analysis", fontsize=16, fontweight="bold", y=1.01)
plt.show()

# --- 4h. Age distribution comparison: Male vs Female per class ---
if "age" in df.columns and "sex" in df.columns:
    fig, axes = plt.subplots(2, 4, figsize=(20, 8))
    axes = axes.flatten()
    for idx, cls in enumerate(CLASS_ORDER):
        ax = axes[idx]
        for sex_val in ["male", "female"]:
            subset = df[(df["dx"] == cls) & (df["sex"] == sex_val)]["age"].dropna()
            if len(subset) > 1:
                subset.plot.kde(ax=ax, label=sex_val.capitalize(), linewidth=1.5)
        ax.set_title(f"{cls.upper()}", fontweight="bold", fontsize=11)
        ax.set_xlabel("Age")
        ax.legend(fontsize=8)
        ax.set_xlim(0, 100)
    # Hide unused panels (7 classes on a 2x4 grid leaves one empty)
    for ax in axes[len(CLASS_ORDER):]:
        ax.axis("off")
    plt.suptitle("Age Distribution: Male vs Female per Diagnosis", fontsize=14, fontweight="bold")
    plt.tight_layout()
    plt.show()

In [None]:
# =============================================================================
# 5. Lesion Localization Analysis
# =============================================================================
if "localization" in df.columns:
    fig, axes = plt.subplots(2, 2, figsize=(22, 14))

    # --- 5a. Overall localization distribution ---
    loc_counts = df["localization"].value_counts().head(15)
    sns.barplot(x=loc_counts.values, y=loc_counts.index, palette="viridis", ax=axes[0, 0],
                edgecolor="black", linewidth=0.5)
    for i, val in enumerate(loc_counts.values):
        axes[0, 0].text(val + 10, i, f"{val:,} ({val / len(df) * 100:.1f}%)", va="center", fontsize=9)
    axes[0, 0].set_title("Top 15 Lesion Localizations", fontweight="bold")
    axes[0, 0].set_xlabel("Count")

    # --- 5b. Localization heatmap by diagnosis ---
    loc_dx = pd.crosstab(df["dx"], df["localization"]).reindex(CLASS_ORDER)
    sns.heatmap(loc_dx, cmap="YlOrRd", annot=True, fmt="d", linewidths=0.5,
                ax=axes[0, 1], cbar_kws={"shrink": 0.8})
    axes[0, 1].set_title("Diagnosis vs Localization (Counts)", fontweight="bold")
    axes[0, 1].tick_params(axis="x", rotation=45)

    # --- 5c. Normalized heatmap (row-normalized to show % within each class) ---
    loc_dx_pct = pd.crosstab(df["dx"], df["localization"], normalize="index").reindex(CLASS_ORDER) * 100
    sns.heatmap(loc_dx_pct, cmap="YlGnBu", annot=True, fmt=".1f", linewidths=0.5,
                ax=axes[1, 0], cbar_kws={"shrink": 0.8, "label": "% within class"})
    axes[1, 0].set_title("Localization Distribution within Each Class (%)", fontweight="bold")
    axes[1, 0].tick_params(axis="x", rotation=45)

    # --- 5d. Top localization per class ---
    top_loc_per_class = []
    for cls in CLASS_ORDER:
        cls_locs = df[df["dx"] == cls]["localization"].value_counts()
        if len(cls_locs) > 0:
            top_loc_per_class.append({
                "Class": cls.upper(),
                "Top 1": f"{cls_locs.index[0]} ({cls_locs.iloc[0]:,})",
                "Top 2": f"{cls_locs.index[1]} ({cls_locs.iloc[1]:,})" if len(cls_locs) > 1 else "-",
                "Top 3": f"{cls_locs.index[2]} ({cls_locs.iloc[2]:,})" if len(cls_locs) > 2 else "-",
                "Unique Sites": cls_locs.shape[0],
            })
    axes[1, 1].axis("off")
    tbl = axes[1, 1].table(
        cellText=[list(d.values()) for d in top_loc_per_class],
        colLabels=list(top_loc_per_class[0].keys()),
        loc="center", cellLoc="center",
    )
    tbl.auto_set_font_size(False)
    tbl.set_fontsize(9)
    tbl.scale(1.3, 1.5)
    axes[1, 1].set_title("Top Localizations per Class", fontweight="bold", pad=20)

    plt.suptitle("Section 5: Localization Analysis", fontsize=16, fontweight="bold", y=1.01)
    plt.tight_layout()
    plt.show()
else:
    print("'localization' column not found - skipping localization analysis.")

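`isna()`, used throughout Section 6, only counts true `NaN`/`None` values. The HAM10000 metadata stores some unknown values as the literal string `unknown` (for example in `sex` and `localization`), so these show up as an ordinary category in the plots above rather than as missing data. A minimal sketch that reports both kinds per column; `sentinel_missing_report` is a hypothetical helper, and the sentinel set is an assumption to adjust for your copy of the metadata.

```python
import numpy as np
import pandas as pd

# Assumed sentinel strings that mean "missing" (compared case-insensitively)
SENTINELS = {"unknown", ""}


def sentinel_missing_report(frame: pd.DataFrame, cols, sentinels=SENTINELS) -> pd.DataFrame:
    """Count real NaNs and sentinel strings separately for each column."""
    rows = []
    for c in cols:
        col = frame[c]
        is_nan = col.isna()
        # Normalize to lowercase stripped strings before matching sentinels
        is_sentinel = col.astype(str).str.strip().str.lower().isin(sentinels) & ~is_nan
        rows.append({
            "column": c,
            "nan": int(is_nan.sum()),
            "sentinel": int(is_sentinel.sum()),
            "total_missing_pct": round(float((is_nan | is_sentinel).mean()) * 100, 2),
        })
    return pd.DataFrame(rows).set_index("column")


# Toy metadata mixing NaN and "unknown"
toy = pd.DataFrame({
    "sex": ["male", "unknown", "female", None],
    "localization": ["back", "Unknown", "face", "trunk"],
    "age": [45.0, np.nan, 60.0, 30.0],
})
report = sentinel_missing_report(toy, ["sex", "localization", "age"])
print(report)
```

On the real data this would be called as `sentinel_missing_report(df, ["age", "sex", "localization"])`. Whether to convert sentinels to `NaN` or keep them as an explicit "unknown" category for the metadata branch is a modeling choice; the report only makes the difference visible.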
In [None]:
# =============================================================================
# 6. Missing Value Analysis
# =============================================================================
fig, axes = plt.subplots(1, 3, figsize=(22, 6))

# --- 6a. Overall missing value bar chart ---
missing_abs = df.isna().sum().sort_values(ascending=False)
missing_pct = (missing_abs / len(df) * 100).round(2)
missing_df = pd.DataFrame({"count": missing_abs, "pct": missing_pct})
missing_present = missing_df[missing_df["count"] > 0]

if len(missing_present) > 0:
    bars = axes[0].barh(missing_present.index, missing_present["pct"], color="salmon", edgecolor="black")
    for bar, pct, cnt in zip(bars, missing_present["pct"], missing_present["count"]):
        axes[0].text(bar.get_width() + 0.3, bar.get_y() + bar.get_height() / 2,
                     f"{pct}% ({cnt:,})", va="center", fontsize=9)
    axes[0].set_title("Missing Values by Column", fontweight="bold")
    axes[0].set_xlabel("% Missing")
else:
    axes[0].text(0.5, 0.5, "No missing values!", ha="center", va="center", fontsize=14)
    axes[0].set_title("Missing Values by Column", fontweight="bold")

# --- 6b. Missing value heatmap (sample of rows) ---
sample_idx = np.random.RandomState(42).choice(len(df), min(200, len(df)), replace=False)
cols_of_interest = [c for c in ["age", "sex", "localization", "dx", "lesion_id", "dx_type"] if c in df.columns]
if cols_of_interest:
    sns.heatmap(df.iloc[sample_idx][cols_of_interest].isna().astype(int).T,
                cmap="YlOrRd", cbar_kws={"label": "Missing (1) / Present (0)"},
                ax=axes[1], yticklabels=True)
    axes[1].set_title("Missing Pattern (200-row sample)", fontweight="bold")
    axes[1].set_xlabel("Sample Index")

# --- 6c. Missingness by class ---
miss_cols = [c for c in cols_of_interest if c != "dx"]
if miss_cols and "dx" in df.columns:
    # Percent of rows with each column missing, computed within each diagnosis
    miss_by_class = (df[miss_cols].isna().groupby(df["dx"]).mean() * 100).round(1).reindex(CLASS_ORDER)
    sns.heatmap(miss_by_class, annot=True, fmt=".1f", cmap="Oranges", linewidths=0.5,
                ax=axes[2], cbar_kws={"label": "% Missing"})
    axes[2].set_title("Missing % by Class", fontweight="bold")

plt.suptitle("Section 6: Missing Value Analysis", fontsize=16, fontweight="bold", y=1.02)
plt.tight_layout()
plt.show()

# Summary table
print("\nMissing Value Summary:")
display(missing_df[missing_df["count"] > 0].rename(columns={"count": "Missing Count", "pct": "Missing %"}))

In [None]:
# =============================================================================
# 7. Lesion Grouping & Data Leakage Risk Analysis
# =============================================================================
if "lesion_id" in df.columns:
    lesion_counts = df["lesion_id"].value_counts()
    n_unique_lesions = lesion_counts.shape[0]
    n_images = len(df)
    multi_image_lesions = (lesion_counts > 1).sum()
    max_images_per_lesion = lesion_counts.max()

    print(f"Total images          : {n_images:,}")
    print(f"Unique lesions        : {n_unique_lesions:,}")
    print(f"Images / lesion ratio : {n_images / n_unique_lesions:.2f}")
    print(f"Lesions with >1 image : {multi_image_lesions:,} ({multi_image_lesions / n_unique_lesions * 100:.1f}%)")
    print(f"Max images per lesion : {max_images_per_lesion}")

    fig, axes = plt.subplots(1, 3, figsize=(22, 5))

    # --- 7a. Images per lesion histogram ---
    sns.histplot(lesion_counts, bins=range(1, lesion_counts.max() + 2), ax=axes[0],
                 color="steelblue", edgecolor="black")
    axes[0].set_title("Images per Lesion Distribution", fontweight="bold")
    axes[0].set_xlabel("Images per Lesion")
    axes[0].set_ylabel("Number of Lesions")

    # --- 7b. Duplicate lesion breakdown by class ---
    lesion_class = df.groupby("lesion_id")["dx"].first()
    multi_lesion_ids = lesion_counts[lesion_counts > 1].index
    multi_lesion_classes = lesion_class.loc[multi_lesion_ids].value_counts().reindex(CLASS_ORDER, fill_value=0)
    single_lesion_ids = lesion_counts[lesion_counts == 1].index
    single_lesion_classes = lesion_class.loc[single_lesion_ids].value_counts().reindex(CLASS_ORDER, fill_value=0)

    x = np.arange(len(CLASS_ORDER))
    width = 0.35
    axes[1].bar(x - width / 2, single_lesion_classes.values, width, label="Single-image lesions",
                color="skyblue", edgecolor="black", linewidth=0.5)
    axes[1].bar(x + width / 2, multi_lesion_classes.values, width, label="Multi-image lesions",
                color="coral", edgecolor="black", linewidth=0.5)
    axes[1].set_xticks(x)
    axes[1].set_xticklabels([c.upper() for c in CLASS_ORDER])
    axes[1].set_title("Single vs Multi-Image Lesions by Class", fontweight="bold")
    axes[1].set_ylabel("Number of Lesions")
    axes[1].legend()

    # --- 7c. Data leakage risk: what % of images share a lesion with another ---
    images_in_multi = df[df["lesion_id"].isin(multi_lesion_ids)].shape[0]
    images_in_single = n_images - images_in_multi
    axes[2].pie([images_in_single, images_in_multi],
                labels=["Unique lesion images", "Shared lesion images"],
                autopct="%1.1f%%", colors=["#66b3ff", "#ff6666"],
                startangle=90, textprops={"fontsize": 11})
    axes[2].set_title("Data Leakage Risk: Shared Lesion Images", fontweight="bold")

    plt.suptitle("Section 7: Lesion Grouping & Leakage Risk", fontsize=16, fontweight="bold", y=1.02)
    plt.tight_layout()
    plt.show()

    print("\nKey insight: lesion-level splitting is REQUIRED to prevent data leakage.")
    print("If images of the same lesion appear in both train and val/test, the model can match")
    print("near-duplicate views of lesions it has already seen, which inflates val/test metrics")
    print("instead of measuring how well it has learned diagnostic features.")
    print(f"Affected images: {images_in_multi:,} ({images_in_multi / n_images * 100:.1f}% of dataset)")
else:
    print("'lesion_id' column not found - skipping lesion grouping analysis.")

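Lesion-level splitting means every `lesion_id` goes entirely to one side of the split. A minimal group-shuffle sketch using only NumPy and pandas; `lesion_level_split` is a hypothetical helper and the toy lesions are made up. It is not stratified by diagnosis, which matters for rare classes such as DF and VASC; scikit-learn's `GroupShuffleSplit` and `StratifiedGroupKFold` (scikit-learn 1.0+) are the library equivalents, the latter also balancing class proportions.

```python
import numpy as np
import pandas as pd


def lesion_level_split(frame, group_col="lesion_id", val_frac=0.2, seed=42):
    """Split rows so that every lesion lands entirely in train or val.

    Shuffles the unique lesion IDs, then assigns whole lesions to validation
    until at least val_frac of the *images* are covered.
    """
    rng = np.random.default_rng(seed)
    lesions = rng.permutation(np.asarray(frame[group_col].unique()))
    sizes = frame[group_col].value_counts()
    target = val_frac * len(frame)
    val_lesions, n_val = [], 0
    for lid in lesions:
        if n_val >= target:
            break
        val_lesions.append(lid)
        n_val += sizes[lid]
    is_val = frame[group_col].isin(val_lesions)
    return frame[~is_val], frame[is_val]


# Toy data: 10 images from 6 lesions
toy = pd.DataFrame({
    "image_id": [f"img_{i}" for i in range(10)],
    "lesion_id": ["L1", "L1", "L2", "L3", "L3", "L3", "L4", "L5", "L6", "L6"],
    "dx": ["nv", "nv", "mel", "bkl", "bkl", "bkl", "nv", "df", "nv", "nv"],
})
train, val = lesion_level_split(toy, val_frac=0.3, seed=0)
print(len(train), len(val), set(train["lesion_id"]) & set(val["lesion_id"]))
```

The last value printed is the set of lesions shared between the two splits, which should always be empty; the split sizes only approximate `val_frac` because whole lesions are moved at once.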
In [None]:
# =============================================================================
# 8. Feature Correlations & Statistical Tests
# =============================================================================

# --- 8a. Numeric metadata correlation matrix ---
numeric_cols = [c for c in df.select_dtypes(include="number").columns
                if c not in ["MEL", "NV", "BCC", "AKIEC", "BKL", "DF", "VASC"]]

fig, axes = plt.subplots(1, 2, figsize=(18, 6))

if len(numeric_cols) >= 2:
    corr = df[numeric_cols].corr()
    mask = np.triu(np.ones_like(corr, dtype=bool))
    sns.heatmap(corr, mask=mask, annot=True, fmt=".2f", cmap="RdBu_r", vmin=-1, vmax=1,
                linewidths=0.5, ax=axes[0], square=True)
    axes[0].set_title("Numeric Feature Correlations", fontweight="bold")
else:
    axes[0].text(0.5, 0.5, "Not enough numeric columns for correlation", ha="center", va="center")
    axes[0].set_title("Numeric Feature Correlations", fontweight="bold")

# --- 8b. Cramér's V for categorical associations ---
def cramers_v(x, y):
    """Cramér's V statistic for categorical-categorical association (no bias correction)."""
    confusion_matrix = pd.crosstab(x, y)
    n = confusion_matrix.to_numpy().sum()
    min_dim = min(confusion_matrix.shape) - 1
    if min_dim == 0 or n == 0:
        return 0.0
    # Disable Yates' continuity correction (scipy's default for 2x2 tables), which would shrink V
    chi2 = stats.chi2_contingency(confusion_matrix, correction=False)[0]
    return np.sqrt(chi2 / (n * min_dim))

cat_features = [c for c in ["dx", "sex", "localization", "dx_type"] if c in df.columns]
if len(cat_features) >= 2:
    cv_matrix = pd.DataFrame(index=cat_features, columns=cat_features, dtype=float)
    for c1 in cat_features:
        for c2 in cat_features:
            valid = df[[c1, c2]].dropna()
            cv_matrix.loc[c1, c2] = cramers_v(valid[c1], valid[c2]) if len(valid) > 0 else 0.0
    cv_matrix = cv_matrix.astype(float)
    sns.heatmap(cv_matrix, annot=True, fmt=".3f", cmap="Purples", vmin=0, vmax=1,
                linewidths=0.5, ax=axes[1], square=True)
    axes[1].set_title("Cramér's V (Categorical Associations)", fontweight="bold")
else:
    axes[1].text(0.5, 0.5, "Not enough categorical columns", ha="center", va="center")
    axes[1].set_title("Cramér's V (Categorical Associations)", fontweight="bold")

plt.suptitle("Section 8: Feature Correlations & Statistical Tests", fontsize=16, fontweight="bold", y=1.02)
plt.tight_layout()
plt.show()

# --- 8c. ANOVA: Is age significantly different across diagnoses? ---
if "age" in df.columns:
    groups = [group["age"].dropna().values for _, group in df.groupby("dx")]
    groups = [g for g in groups if len(g) > 1]
    if len(groups) >= 2:
        f_stat, p_val = stats.f_oneway(*groups)
        print("\nOne-way ANOVA (Age ~ Diagnosis):")
        print(f"  F-statistic = {f_stat:.4f}")
        print(f"  p-value     = {p_val:.2e}")
        print(f"  Significant = {'YES' if p_val < 0.05 else 'NO'} (alpha=0.05)")

# --- 8d. Chi-Square test: Sex vs Diagnosis ---
if "sex" in df.columns:
    contingency = pd.crosstab(df["dx"], df["sex"])
    chi2, p_val, dof, expected = stats.chi2_contingency(contingency)
    print("\nChi-Square Test (Sex x Diagnosis):")
    print(f"  Chi2       = {chi2:.4f}")
    print(f"  p-value    = {p_val:.2e}")
    print(f"  DoF        = {dof}")
    print(f"  Significant = {'YES' if p_val < 0.05 else 'NO'} (alpha=0.05)")

# --- 8e. Chi-Square test: Localization vs Diagnosis ---
if "localization" in df.columns:
    contingency_loc = pd.crosstab(df["dx"], df["localization"])
    chi2_loc, p_val_loc, dof_loc, _ = stats.chi2_contingency(contingency_loc)
    print("\nChi-Square Test (Localization x Diagnosis):")
    print(f"  Chi2       = {chi2_loc:.4f}")
    print(f"  p-value    = {p_val_loc:.2e}")
    print(f"  DoF        = {dof_loc}")
    print(f"  Significant = {'YES' if p_val_loc < 0.05 else 'NO'} (alpha=0.05)")

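print("\nInterpretation: significant results show that metadata is statistically associated with diagnosis,"
      "\nwhich is consistent with a multi-modal approach (image + metadata). With ~10k samples even weak"
      "\nassociations reach significance, so use the Cramér's V values above to judge their strength.")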

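Two caveats apply to the tests above. With roughly 10,000 images, almost any difference reaches p < 0.05, so an effect size such as eta squared ($\eta^2 = SS_{between} / SS_{total}$, the share of age variance explained by diagnosis) says more than the p-value. The chi-square approximation is also commonly considered unreliable when many expected cell counts fall below 5 (a usual rule of thumb allows at most about 20% of such cells), which is likely for rare localizations crossed with rare classes. A minimal NumPy sketch of both checks; `eta_squared` and `low_expected_fraction` are hypothetical helpers.

```python
import numpy as np


def eta_squared(groups):
    """Effect size for one-way ANOVA: between-group SS / total SS."""
    groups = [np.asarray(g, dtype=float) for g in groups]
    all_vals = np.concatenate(groups)
    grand_mean = all_vals.mean()
    ss_between = sum(len(g) * (g.mean() - grand_mean) ** 2 for g in groups)
    ss_total = ((all_vals - grand_mean) ** 2).sum()
    return ss_between / ss_total


def low_expected_fraction(table, threshold=5.0):
    """Fraction of cells whose expected count under independence is below threshold."""
    observed = np.asarray(table, dtype=float)
    # Expected count = row total * column total / grand total
    expected = np.outer(observed.sum(axis=1), observed.sum(axis=0)) / observed.sum()
    return float((expected < threshold).mean())


# Toy inputs: two well-separated groups and a sparse 2x2 table
print(eta_squared([[1, 2, 3], [7, 8, 9]]))
print(low_expected_fraction([[50, 2], [40, 3]]))
```

In the notebook these would be applied as `eta_squared(groups)` after 8c and `low_expected_fraction(contingency_loc)` after 8e. If the sparse-cell fraction is high, merging rare localizations or using an exact or permutation test are common alternatives.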
In [None]:
# =============================================================================
# 9. Image Quality & Properties Analysis
# =============================================================================
image_files = sorted(IMAGE_DIR.glob("*.jpg"))
if not image_files:
    # Try subdirectories
    for sub in ["train", "val", "test"]:
        if (IMAGE_DIR / sub).exists():
            image_files.extend(sorted((IMAGE_DIR / sub).glob("*.jpg")))
    if not image_files:
        image_files = sorted(IMAGE_DIR.glob("**/*.jpg"))

n_sample = min(500, len(image_files))
rng = np.random.RandomState(42)
sampled_files = rng.choice(image_files, n_sample, replace=False) if len(image_files) > n_sample else image_files

print(f"Analyzing {len(sampled_files)} images (sampled from {len(image_files)} total)...")

img_stats = []
for path in sampled_files:
    with Image.open(path) as img:
        w, h = img.size
        arr = np.asarray(img.convert("RGB"), dtype=np.float32)
        img_stats.append({
            "file": path.stem,
            "width": w,
            "height": h,
            "aspect_ratio": round(w / h, 3),
            "mean_r": arr[:, :, 0].mean(),
            "mean_g": arr[:, :, 1].mean(),
            "mean_b": arr[:, :, 2].mean(),
            "std_r": arr[:, :, 0].std(),
            "std_g": arr[:, :, 1].std(),
            "std_b": arr[:, :, 2].std(),
            "brightness": arr.mean(),
            "contrast": arr.std(),
            "variance": arr.var(),
        })

img_df = pd.DataFrame(img_stats)
print(f"\nImage statistics summary (n={len(img_df)}):")
display(img_df[["width", "height", "aspect_ratio", "brightness", "contrast"]].describe().round(2))

fig, axes = plt.subplots(2, 3, figsize=(22, 12))

# --- 9a. Width/Height scatter ---
axes[0, 0].scatter(img_df["width"], img_df["height"], alpha=0.3, s=10, color="steelblue")
axes[0, 0].set_title("Image Dimensions (Width x Height)", fontweight="bold")
axes[0, 0].set_xlabel("Width (px)")
axes[0, 0].set_ylabel("Height (px)")
axes[0, 0].set_aspect("equal")

# --- 9b. Aspect ratio distribution ---
sns.histplot(img_df["aspect_ratio"], bins=50, ax=axes[0, 1], color="teal", edgecolor="black")
axes[0, 1].axvline(1.0, color="red", linestyle="--", label="Square (1:1)")
axes[0, 1].set_title("Aspect Ratio Distribution", fontweight="bold")
axes[0, 1].set_xlabel("Aspect Ratio (W/H)")
axes[0, 1].legend()

# --- 9c. Mean color channel distributions ---
for ch, col, name in zip(["mean_r", "mean_g", "mean_b"], ["red", "green", "blue"], ["R", "G", "B"]):
    sns.kdeplot(img_df[ch], ax=axes[0, 2], color=col, label=name, linewidth=1.5, fill=True, alpha=0.2)
axes[0, 2].set_title("Mean Color Channel Distribution", fontweight="bold")
axes[0, 2].set_xlabel("Mean Pixel Value (0-255)")
axes[0, 2].legend()

# --- 9d. Brightness distribution ---
sns.histplot(img_df["brightness"], bins=40, ax=axes[1, 0], color="gold", edgecolor="black")
axes[1, 0].axvline(img_df["brightness"].mean(), color="red", linestyle="--",
                    label=f"Mean: {img_df['brightness'].mean():.1f}")
axes[1, 0].set_title("Image Brightness Distribution", fontweight="bold")
axes[1, 0].set_xlabel("Mean Pixel Intensity")
axes[1, 0].legend()

# --- 9e. Contrast distribution ---
sns.histplot(img_df["contrast"], bins=40, ax=axes[1, 1], color="mediumpurple", edgecolor="black")
axes[1, 1].axvline(img_df["contrast"].mean(), color="red", linestyle="--",
                    label=f"Mean: {img_df['contrast'].mean():.1f}")
axes[1, 1].set_title("Image Contrast (Std Dev) Distribution", fontweight="bold")
axes[1, 1].set_xlabel("Pixel Std Dev")
axes[1, 1].legend()

# --- 9f. Brightness vs Contrast scatter ---
axes[1, 2].scatter(img_df["brightness"], img_df["contrast"], alpha=0.3, s=10, color="darkorange")
axes[1, 2].set_title("Brightness vs Contrast", fontweight="bold")
axes[1, 2].set_xlabel("Brightness (Mean Intensity)")
axes[1, 2].set_ylabel("Contrast (Std Dev)")

plt.suptitle("Section 9: Image Quality & Properties", fontsize=16, fontweight="bold", y=1.01)
plt.tight_layout()
plt.show()

# --- 9g. Per-class brightness/contrast comparison ---
if "dx" in df.columns and "image_id" in df.columns:
    img_df_merged = img_df.merge(df[["image_id", "dx"]].rename(columns={"image_id": "file"}),
                                  on="file", how="inner")
    if len(img_df_merged) > 20:
        fig, axes = plt.subplots(1, 2, figsize=(18, 5))
        sns.boxplot(data=img_df_merged, x="dx", y="brightness", order=CLASS_ORDER,
                    palette=CLASS_PALETTE, ax=axes[0])
        axes[0].set_title("Brightness by Diagnosis", fontweight="bold")
        axes[0].set_xlabel("Diagnosis")
        axes[0].set_ylabel("Mean Pixel Intensity")

        sns.boxplot(data=img_df_merged, x="dx", y="contrast", order=CLASS_ORDER,
                    palette=CLASS_PALETTE, ax=axes[1])
        axes[1].set_title("Contrast by Diagnosis", fontweight="bold")
        axes[1].set_xlabel("Diagnosis")
        axes[1].set_ylabel("Pixel Std Dev")

        plt.suptitle("Image Properties by Diagnosis", fontsize=14, fontweight="bold", y=1.02)
        plt.tight_layout()
        plt.show()

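The per-image `mean_*` and `std_*` columns above describe individual images. Averaging the per-image standard deviations does not give the dataset-level channel std used for input normalization, because it ignores variation between images. A minimal sketch that pools pixel sums and sums of squares instead; `channel_mean_std` is a hypothetical helper, and whether to normalize with dataset statistics or keep the ImageNet statistics of a pretrained backbone is a separate design choice.

```python
import numpy as np


def channel_mean_std(images):
    """Pooled per-channel mean and std over all pixels, scaled to [0, 1].

    `images` is any iterable of HxWx3 arrays with values in 0-255.
    """
    total = np.zeros(3, dtype=np.float64)
    total_sq = np.zeros(3, dtype=np.float64)
    n_pixels = 0
    for arr in images:
        pixels = np.asarray(arr, dtype=np.float64).reshape(-1, 3)
        total += pixels.sum(axis=0)
        total_sq += (pixels ** 2).sum(axis=0)
        n_pixels += pixels.shape[0]
    mean = total / n_pixels
    # Var = E[x^2] - E[x]^2; clip tiny negative values caused by rounding
    std = np.sqrt(np.maximum(total_sq / n_pixels - mean ** 2, 0.0))
    return mean / 255.0, std / 255.0


# Each image has zero internal variance, but the pooled std is not zero
black = np.zeros((2, 2, 3))
white = np.full((2, 2, 3), 255.0)
mean, std = channel_mean_std([black, white])
print(mean, std)  # [0.5 0.5 0.5] [0.5 0.5 0.5]
```

Applied to this notebook, the input could be a generator such as `(np.asarray(Image.open(p).convert("RGB")) for p in sampled_files)`, which keeps only one image in memory at a time.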
In [None]:
# =============================================================================
# 10. Per-Class Sample Image Grid
# =============================================================================

def find_image(image_id, search_dirs):
    """Search for an image file across multiple directories."""
    for d in search_dirs:
        for ext in [".jpg", ".jpeg", ".png"]:
            candidate = d / f"{image_id}{ext}"
            if candidate.exists():
                return candidate
    return None

search_dirs = [IMAGE_DIR]
for sub in ["train", "val", "test"]:
    d = IMAGE_DIR / sub
    if d.exists():
        search_dirs.append(d)
    d2 = IMAGE_DIR.parent / sub
    if d2.exists():
        search_dirs.append(d2)

n_samples_per_class = 5

if "dx" in df.columns and "image_id" in df.columns:
    fig, axes = plt.subplots(len(CLASS_ORDER), n_samples_per_class, figsize=(3 * n_samples_per_class, 3 * len(CLASS_ORDER)))

    for row_idx, cls in enumerate(CLASS_ORDER):
        class_rows = df[df["dx"] == cls].sample(
            n=min(n_samples_per_class, len(df[df["dx"] == cls])),
            random_state=42
        )
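        for col_idx in range(n_samples_per_class):
            ax = axes[row_idx, col_idx]
            if col_idx < len(class_rows):
                image_id = str(class_rows.iloc[col_idx]["image_id"])
                img_path = find_image(image_id, search_dirs)
                if img_path is not None:
                    # Context manager closes the file handle after the pixels are loaded
                    with Image.open(img_path) as im:
                        ax.imshow(im.convert("RGB"))
                else:
                    ax.text(0.5, 0.5, "Not found", ha="center", va="center", transform=ax.transAxes)
            # Hide ticks and frame but keep the y-label; ax.axis("off") would also hide the class label
            ax.set_xticks([])
            ax.set_yticks([])
            for spine in ax.spines.values():
                spine.set_visible(False)
            if col_idx == 0:
                ax.set_ylabel(f"{cls.upper()}\n({LABEL_MAP[cls]})", fontsize=10, fontweight="bold")
            if row_idx == 0:
                ax.set_title(f"Sample {col_idx + 1}", fontsize=10)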

    plt.suptitle("Section 10: Sample Images per Diagnosis Class", fontsize=16, fontweight="bold", y=1.01)
    plt.tight_layout()
    plt.show()

    # --- 10b. Mean image per class ---
    print("\nComputing mean image per class (up to 50 images each)...")
    fig, axes = plt.subplots(1, len(CLASS_ORDER), figsize=(3 * len(CLASS_ORDER), 3))
    target_size = (224, 224)

    for idx, cls in enumerate(CLASS_ORDER):
        class_ids = df[df["dx"] == cls]["image_id"].values
        rng_cls = np.random.RandomState(42)
        sample_ids = rng_cls.choice(class_ids, min(50, len(class_ids)), replace=False)

        pixel_acc = np.zeros((*target_size, 3), dtype=np.float64)
        count = 0
        for img_id in sample_ids:
            img_path = find_image(str(img_id), search_dirs)
            if img_path is not None:
                with Image.open(img_path) as im:
                    arr = np.asarray(im.convert("RGB").resize(target_size), dtype=np.float64)
                    pixel_acc += arr
                    count += 1

        if count > 0:
            mean_img = (pixel_acc / count).astype(np.uint8)
            axes[idx].imshow(mean_img)
        axes[idx].set_title(f"{cls.upper()}\n(n={count})", fontsize=10, fontweight="bold")
        axes[idx].axis("off")

    plt.suptitle("Mean Image per Class (average of up to 50 images)", fontsize=14, fontweight="bold", y=1.02)
    plt.tight_layout()
    plt.show()
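
For every lookup, `find_image` makes up to three `Path.exists()` calls per directory, and the mean-image loop above calls it hundreds of times. An alternative is to scan each directory once and build a dictionary from file stem to path. This sketch assumes files are named `<image_id>.<ext>`, as `find_image` already expects. `build_image_index` is a hypothetical helper, not part of the notebook.

```python
import tempfile
from pathlib import Path

# Same extension order that find_image tries
EXT_PRIORITY = (".jpg", ".jpeg", ".png")


def build_image_index(search_dirs):
    """Map image stem (e.g. 'ISIC_0024306') to its path, scanning each directory once.

    Earlier directories win, and within a directory earlier extensions win,
    which follows the order in which find_image checks candidates.
    Directories that do not exist are skipped.
    """
    index = {}
    for d in map(Path, search_dirs):
        if not d.is_dir():
            continue
        for ext in EXT_PRIORITY:
            for p in sorted(d.glob(f"*{ext}")):
                if p.is_file():
                    index.setdefault(p.stem, p)
    return index


# Demo on a throwaway directory tree
with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    (root / "train").mkdir()
    (root / "ISIC_0000001.jpg").touch()
    (root / "train" / "ISIC_0000001.png").touch()   # shadowed: root is searched first
    (root / "train" / "ISIC_0000002.jpeg").touch()
    idx = build_image_index([root, root / "train", root / "missing"])
    layout = {k: v.relative_to(root).as_posix() for k, v in idx.items()}

print(layout)
# {'ISIC_0000001': 'ISIC_0000001.jpg', 'ISIC_0000002': 'train/ISIC_0000002.jpeg'}
```

In the notebook, `image_index = build_image_index(search_dirs)` followed by `image_index.get(str(image_id))` would replace the repeated `find_image` calls. Like `find_image`, it returns `None` for an id that is not on disk.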

In [None]:
# =============================================================================
# 11. Co-occurrence & Interaction Analysis
# =============================================================================

fig = plt.figure(figsize=(22, 14))
gs = gridspec.GridSpec(2, 3, hspace=0.35, wspace=0.3)

# --- 11a. Age-Sex interaction per class (strip plot overlaid on box plot) ---
if "age" in df.columns and "sex" in df.columns:
    ax1 = fig.add_subplot(gs[0, :2])
    sns.stripplot(data=df.dropna(subset=["age", "sex"]), x="dx", y="age", hue="sex",
                  order=CLASS_ORDER, dodge=True, alpha=0.4, size=3, palette="Set1", ax=ax1)
    sns.boxplot(data=df.dropna(subset=["age", "sex"]), x="dx", y="age", hue="sex",
                order=CLASS_ORDER, dodge=True, palette="Set1", ax=ax1,
                boxprops=dict(alpha=0.5), showfliers=False, width=0.6)
    handles, labels = ax1.get_legend_handles_labels()
    # Keep only first set of labels (avoid duplication from box+strip)
    n_sex = df["sex"].nunique()
    ax1.legend(handles[:n_sex], labels[:n_sex], title="Sex", fontsize=9)
    ax1.set_title("Age x Sex Interaction by Diagnosis", fontweight="bold")
    ax1.set_xlabel("Diagnosis")
    ax1.set_ylabel("Age")

# --- 11b. Diagnosis type distribution ---
if "dx_type" in df.columns:
    ax2 = fig.add_subplot(gs[0, 2])
    dx_type_counts = df["dx_type"].value_counts()
    ax2.pie(dx_type_counts.values, labels=dx_type_counts.index, autopct="%1.1f%%",
            colors=sns.color_palette("Pastel2", n_colors=len(dx_type_counts)),
            startangle=90, textprops={"fontsize": 9})
    ax2.set_title("Diagnosis Confirmation Method", fontweight="bold")

# --- 11c. Diagnosis type breakdown per class ---
if "dx_type" in df.columns:
    ax3 = fig.add_subplot(gs[1, :2])
    dx_type_cross = pd.crosstab(df["dx"], df["dx_type"], normalize="index").reindex(CLASS_ORDER) * 100
    dx_type_cross.plot(kind="bar", stacked=True, ax=ax3, colormap="tab20", edgecolor="black", linewidth=0.3)
    ax3.set_title("Diagnosis Confirmation Method by Class (%)", fontweight="bold")
    ax3.set_ylabel("Percent")
    ax3.set_xlabel("Diagnosis")
    ax3.legend(title="dx_type", fontsize=8, bbox_to_anchor=(1.02, 1), loc="upper left")
    ax3.tick_params(axis="x", rotation=30)

# --- 11d. Mean age by localization (top 8 sites; bars = mean, error bars = 1 SD) ---
if "age" in df.columns and "localization" in df.columns:
    ax4 = fig.add_subplot(gs[1, 2])
    top_locs = df["localization"].value_counts().head(8).index
    age_loc = df[df["localization"].isin(top_locs)].groupby("localization")["age"].agg(["mean", "std", "count"])
    age_loc = age_loc.sort_values("mean", ascending=True)
    ax4.barh(age_loc.index, age_loc["mean"], xerr=age_loc["std"], color="steelblue",
             edgecolor="black", linewidth=0.5, capsize=3)
    ax4.set_title("Mean Age by Localization (top 8)", fontweight="bold")
    ax4.set_xlabel("Mean Age (+/- Std)")

plt.suptitle("Section 11: Co-occurrence & Interaction Analysis", fontsize=16, fontweight="bold", y=1.01)
plt.tight_layout()
plt.show()
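
Panel 11c shows that the confirmation method differs by class. Cramér's V summarizes that association as a single number, computed from the chi-square statistic of the contingency table (0 means independent, 1 means perfectly associated). The sketch below uses the plain, uncorrected form of the statistic and assumes two categorical columns such as `dx` and `dx_type`. `cramers_v` is an illustrative helper, and the toy series are synthetic.

```python
import numpy as np
import pandas as pd
from scipy.stats import chi2_contingency


def cramers_v(x: pd.Series, y: pd.Series) -> float:
    """Uncorrected Cramér's V between two categorical series.

    V = sqrt(chi2 / (n * (min(rows, cols) - 1))), computed on the rows where
    both values are present. Returns 0.0 if either variable has a single level.
    """
    mask = x.notna() & y.notna()
    table = pd.crosstab(x[mask], y[mask]).to_numpy()
    n_rows, n_cols = table.shape
    if min(n_rows, n_cols) < 2:
        return 0.0
    chi2 = chi2_contingency(table, correction=False)[0]
    n = table.sum()
    return float(np.sqrt(chi2 / (n * (min(n_rows, n_cols) - 1))))


# Toy checks: labels that determine each other -> 1, balanced independent labels -> 0
a = pd.Series(["mel", "mel", "nv", "nv"])
v_perfect = cramers_v(a, pd.Series(["histo", "histo", "follow_up", "follow_up"]))
v_indep = cramers_v(a, pd.Series(["histo", "follow_up", "histo", "follow_up"]))
print(f"perfectly associated: {v_perfect:.2f}, independent: {v_indep:.2f}")
```

On the notebook's frame, the calls would be `cramers_v(df["dx"], df["dx_type"])` or `cramers_v(df["dx"], df["localization"])`. When there are many sparse categories, the uncorrected statistic is biased upward, so a bias-corrected variant may be the better choice for `localization`.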

In [None]:
# =============================================================================
# 12. Class Imbalance Strategy Recommendations
# =============================================================================

class_counts_sorted = df["dx"].value_counts().reindex(CLASS_ORDER)
total = class_counts_sorted.sum()
n_classes = len(CLASS_ORDER)

# Compute various weighting schemes
inv_freq = total / (n_classes * class_counts_sorted)
inv_freq_norm = inv_freq / inv_freq.sum() * n_classes

sqrt_inv_freq = np.sqrt(total / (n_classes * class_counts_sorted))
sqrt_inv_freq_norm = sqrt_inv_freq / sqrt_inv_freq.sum() * n_classes

# Effective number of samples (Class-Balanced Loss, Cui et al. 2019)
beta_values = [0.99, 0.999, 0.9999]
effective_weights = {}
for beta in beta_values:
    eff = (1 - beta) / (1 - beta ** class_counts_sorted)
    eff_norm = eff / eff.sum() * n_classes
    effective_weights[f"CB (beta={beta})"] = eff_norm.values

weight_comparison = pd.DataFrame({
    "Class": CLASS_ORDER,
    "Count": class_counts_sorted.values,
    "Uniform": [1.0] * n_classes,
    "Inv Freq": inv_freq_norm.values.round(4),
    "Sqrt Inv Freq": sqrt_inv_freq_norm.values.round(4),
})
for k, v in effective_weights.items():
    weight_comparison[k] = v.round(4)

print("Class Weighting Strategies Comparison:")
display(weight_comparison)

# Visualize
fig, axes = plt.subplots(1, 2, figsize=(18, 6))

# Weight comparison
weight_cols = ["Uniform", "Inv Freq", "Sqrt Inv Freq"] + list(effective_weights.keys())
x = np.arange(n_classes)
width = 0.12
for i, col in enumerate(weight_cols):
    axes[0].bar(x + i * width, weight_comparison[col], width, label=col, edgecolor="black", linewidth=0.3)
# Center each tick under its group: bar i is centered at x + i * width, for i = 0..len-1
axes[0].set_xticks(x + width * (len(weight_cols) - 1) / 2)
axes[0].set_xticklabels([c.upper() for c in CLASS_ORDER])
axes[0].set_title("Class Weight Comparison (Normalized)", fontweight="bold")
axes[0].set_ylabel("Weight")
axes[0].legend(fontsize=8)
axes[0].set_xlabel("Diagnosis")

# Effective samples per class at different betas
for beta in beta_values:
    eff_samples = (1 - beta ** class_counts_sorted) / (1 - beta)
    axes[1].plot(CLASS_ORDER, eff_samples.values, "o-", label=f"beta={beta}", linewidth=1.5, markersize=5)
axes[1].plot(CLASS_ORDER, class_counts_sorted.values, "s--", color="black", label="Actual count", linewidth=1)
axes[1].set_yscale("log")
axes[1].set_title("Effective Number of Samples (Class-Balanced Loss)", fontweight="bold")
axes[1].set_ylabel("Effective Samples (log)")
axes[1].set_xlabel("Diagnosis")
axes[1].legend(fontsize=9)

plt.suptitle("Section 12: Class Imbalance Strategies", fontsize=16, fontweight="bold", y=1.02)
plt.tight_layout()
plt.show()

# Clinical cost analysis (heuristic false-negative multipliers, not derived from the data)
print("\nClinical Cost Analysis (for cost-sensitive loss):")
print("-" * 60)
print(f"  {'Class':<8} {'Clinical Risk':<28} {'Suggested FN Cost'}")
print("-" * 60)
clinical_costs = {
    "mel": ("HIGH - Potentially lethal", "10x"),
    "bcc": ("MODERATE - Locally invasive", "5x"),
    "akiec": ("MODERATE - Pre-cancer", "3x"),
    "bkl": ("LOW - Benign", "1x"),
    "nv": ("LOW - Benign", "1x"),
    "df": ("LOW - Benign", "1x"),
    "vasc": ("LOW - Benign", "1x"),
}
for cls in CLASS_ORDER:
    risk, cost = clinical_costs[cls]
    print(f"  {cls.upper():<8} {risk:<28} {cost}")
print("-" * 60)
print("\nRecommendation: Use a 2-stage loss strategy:")
print("  Stage 1: Focal Loss with class weights (learn general features)")
print("  Stage 2: Cost-sensitive loss (penalize dangerous misses heavily)")
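
Stage 1 above calls for focal loss with class weights. For reference, the multi-class focal loss for one sample (Lin et al., 2017) is $-\alpha_y (1 - p_y)^{\gamma} \log p_y$, where:

- $p_y$ is the softmax probability of the true class.
- $\alpha_y$ is a per-class weight, for instance a column of `weight_comparison`.
- $\gamma$ down-weights easy, confidently correct examples.

The NumPy sketch below only illustrates the arithmetic; a real training loop would use the deep-learning framework's own tensors and autograd. `focal_loss` is a hypothetical helper. $\gamma = 2$ is the setting reported to work well in the focal loss paper and has not been tuned for this dataset.

```python
import numpy as np


def log_softmax(logits: np.ndarray) -> np.ndarray:
    """Numerically stable log-softmax over the last axis."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))


def focal_loss(logits, targets, alpha=None, gamma=2.0):
    """Mean multi-class focal loss.

    logits:  (N, C) raw scores
    targets: (N,) integer class indices
    alpha:   optional (C,) per-class weights; None means all ones
    gamma:   focusing parameter; gamma = 0 gives (alpha-weighted) cross-entropy
    The result is averaged over samples, not divided by the sum of weights.
    """
    logits = np.asarray(logits, dtype=np.float64)
    targets = np.asarray(targets, dtype=np.int64)
    logp = log_softmax(logits)
    logp_t = logp[np.arange(len(targets)), targets]  # log-probability of the true class
    p_t = np.exp(logp_t)
    alpha_t = 1.0 if alpha is None else np.asarray(alpha, dtype=np.float64)[targets]
    loss = -alpha_t * (1.0 - p_t) ** gamma * logp_t
    return float(loss.mean())


logits = np.array([[2.0, 0.5, -1.0],
                   [0.1, 0.2, 3.0],
                   [1.0, 1.0, 1.0]])
targets = np.array([0, 2, 1])
ce = focal_loss(logits, targets, gamma=0.0)
fl = focal_loss(logits, targets, gamma=2.0)
print(f"cross-entropy: {ce:.4f}  focal (gamma=2): {fl:.4f}")
```

The first two samples are classified confidently, so the $(1 - p_y)^{\gamma}$ factor shrinks their contribution far more than that of the uncertain third sample. This is how focal loss shifts training effort toward hard examples, which in this dataset are often minority-class lesions.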

## 13. Key Findings & Modeling Implications

### Dataset Characteristics
- **Severe class imbalance**: NV (melanocytic nevus) accounts for about two-thirds of the 10,015 images (6,705). The rare classes DF (115) and VASC (142) each make up under 1.5%.
- **Multi-image lesions**: Multiple images can belong to the same lesion — **lesion-level splitting is mandatory** to prevent data leakage
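- **Missing metadata**: Age, sex, and localization are each incomplete to different degrees. `sex` and `localization` record missing entries as the string `unknown` rather than NaN, so both forms should be counted. Imputation with explicit missingness flags is appropriate.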

### Statistical Insights
- **Age is diagnostically informative**: ANOVA confirms significant age differences across diagnoses. Patients with melanoma, BCC, and AKIEC tend to be older than those with nevi.
- **Sex and localization carry signal**: Chi-square tests show significant associations with diagnosis — supports multi-modal fusion
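- **Diagnosis confirmation method varies**: Not all labels are histopathologically confirmed; `dx_type` also includes follow-up, expert consensus, and confocal microscopy. Labels should therefore be treated as potentially noisy.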
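
Feeding age, sex, and localization into a fusion model requires a numeric feature vector. The sketch below assumes HAM10000-style columns in which `sex` and `localization` may contain the placeholder `unknown`. It makes three changes:

- Age is median-imputed, with a missingness flag.
- `unknown` is treated as missing.
- Categorical columns are one-hot encoded, so a missing category becomes all zeros.

`build_metadata_features` is hypothetical. In practice, the median should come from the training split only, and validation/test one-hot columns must be aligned to the training columns; that step is omitted here for brevity.

```python
from typing import Optional

import numpy as np
import pandas as pd

CATEGORICAL_COLS = ("sex", "localization")


def build_metadata_features(frame: pd.DataFrame,
                            age_median: Optional[float] = None) -> pd.DataFrame:
    """Numeric metadata features: imputed age, missingness flags, one-hot categoricals.

    Pass `age_median` computed on the training split when transforming val/test.
    """
    out = pd.DataFrame(index=frame.index)
    age = pd.to_numeric(frame["age"], errors="coerce")
    if age_median is None:
        age_median = float(age.median())
    out["age_missing"] = age.isna().astype(np.int8)
    out["age"] = age.fillna(age_median)
    for col in CATEGORICAL_COLS:
        cleaned = frame[col].where(frame[col] != "unknown")  # placeholder -> NaN
        out[f"{col}_missing"] = cleaned.isna().astype(np.int8)
        dummies = pd.get_dummies(cleaned, prefix=col, dtype=np.int8)  # NaN -> all zeros
        out = out.join(dummies)
    return out


meta = pd.DataFrame({
    "age": [45.0, np.nan, 70.0, 60.0],
    "sex": ["male", "female", "unknown", "male"],
    "localization": ["back", "face", "back", "unknown"],
})
feats = build_metadata_features(meta)
print(feats)
```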

### Modeling Recommendations

| Aspect | Recommendation |
|--------|---------------|
| **Data splitting** | Group by `lesion_id` to prevent leakage |
| **Class imbalance** | Use focal loss + class-balanced sampling; 2-stage training with cost-sensitive loss |
| **Metadata** | Include age, sex, localization as auxiliary features via late fusion or FiLM |
| **Missing values** | Median imputation + explicit missingness flags (already implemented) |
| **Augmentation** | Aggressive augmentation for minority classes; hair removal preprocessing |
| **Clinical safety** | High false-negative penalty for MEL (10x), BCC (5x), and AKIEC (3x) in the cost-sensitive stage |
| **Evaluation** | Use balanced accuracy, per-class recall, and confusion matrix — not just overall accuracy |
| **Image preprocessing** | Color constancy (Shades of Gray) to reduce dermoscope variability |
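
Overall accuracy is misleading here. The sketch below scores a degenerate classifier that always predicts NV against a label vector built from the HAM10000 class counts. It uses scikit-learn's `accuracy_score`, `balanced_accuracy_score` (the mean of per-class recall), `recall_score(average=None)`, and `confusion_matrix`. Only the predictions are artificial.

```python
import numpy as np
from sklearn.metrics import (accuracy_score, balanced_accuracy_score,
                             confusion_matrix, recall_score)

# HAM10000 class counts (ISIC 2018 Task 3 training set)
counts = {"akiec": 327, "bcc": 514, "bkl": 1099, "df": 115,
          "mel": 1113, "nv": 6705, "vasc": 142}
labels = list(counts)
y_true = np.repeat(labels, list(counts.values()))
y_pred = np.full_like(y_true, "nv")  # degenerate "always nevus" classifier

acc = accuracy_score(y_true, y_pred)
bal_acc = balanced_accuracy_score(y_true, y_pred)
per_class_recall = recall_score(y_true, y_pred, labels=labels, average=None)
cm = confusion_matrix(y_true, y_pred, labels=labels)

print(f"accuracy          = {acc:.3f}")      # about 0.67
print(f"balanced accuracy = {bal_acc:.3f}")  # 1/7, about 0.143
print(dict(zip(labels, per_class_recall.round(2))))
```

The model misses every melanoma, yet it reaches about 67% accuracy. Balanced accuracy and the per-class recall vector expose this failure immediately.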