In [None]:
# ================================================================
# 📌 PHASE 5 — Gradio Interface (LIME + Integrated Gradients)
# ================================================================

!pip install -q gradio lime

import os
import cv2
import numpy as np
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt
import gradio as gr

from lime import lime_image
from skimage.segmentation import mark_boundaries
from sklearn.preprocessing import LabelEncoder
from google.colab import drive

# ------------------------------------------------
# 1️⃣ Mount Drive & load model / metadata
# ------------------------------------------------
drive.mount("/content/drive")

BASE_DIR   = "/content/drive/MyDrive"
MODEL_PATH = os.path.join(BASE_DIR, "AML2_Project_Models",
                          "efficientnetB0_phase3_final.keras")
META_PATH  = os.path.join(BASE_DIR, "Skin Cancer Images", "metadata.csv")

# Side length (pixels) of the square images fed to the model.
IMG_SIZE = 224

print("Loading model from:", MODEL_PATH)
# compile=False: the model is only used for inference here, so no
# optimizer/loss state needs to be restored.
model = tf.keras.models.load_model(MODEL_PATH, compile=False)

# Rebuild the class-index -> class-name mapping from the "diagnostic" column.
# LabelEncoder.classes_ holds the sorted unique labels.
# NOTE(review): assumes the training notebook encoded labels the same way
# (LabelEncoder fitted on this column) — confirm, otherwise names are mismatched.
meta_df = pd.read_csv(META_PATH)
le = LabelEncoder()
le.fit(meta_df["diagnostic"])
class_names = list(le.classes_)
num_classes = len(class_names)
print("Classes:", class_names)

# Sanity check with one dummy forward pass at the app's input size: if the
# model's output width differs from the number of labels, every prediction
# would be shown under the wrong name (or index past class_names), so stop
# here with a clear message instead of launching a broken app.
_probe_out = model.predict(
    np.zeros((1, IMG_SIZE, IMG_SIZE, 3), dtype="float32"), verbose=0
)
if _probe_out.shape[-1] != num_classes:
    raise ValueError(
        f"Model outputs {_probe_out.shape[-1]} classes but the metadata lists "
        f"{num_classes}: {class_names}"
    )

# Global LIME explainer, reused by every explanation request.
lime_explainer = lime_image.LimeImageExplainer()


# ------------------------------------------------
# 2️⃣ Preprocessing utilities
# ------------------------------------------------
def preprocess_image(image: np.ndarray) -> np.ndarray:
    """
    Convert an uploaded image into the model's input format.

    Gradio gives RGB uint8 (H, W, 3). Grayscale (H, W) or (H, W, 1) and
    RGBA (H, W, 4) arrays are first converted to 3-channel RGB. The image is
    then resized to IMG_SIZE x IMG_SIZE and scaled to [0, 1] float32.

    Raises:
        ValueError: if ``image`` is None, or its shape is not one of
            (H, W), (H, W, 1), (H, W, 3), (H, W, 4).
    """
    if image is None:
        raise ValueError("No image supplied.")

    image = np.asarray(image)

    # A trailing singleton channel goes down the grayscale path; otherwise
    # cv2.resize would drop that axis and hand the model a 2-D array.
    if image.ndim == 3 and image.shape[2] == 1:
        image = np.ascontiguousarray(image[:, :, 0])

    if image.ndim == 2:  # grayscale → RGB
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
    elif image.ndim == 3 and image.shape[2] == 4:  # RGBA → RGB
        image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)

    if image.ndim != 3 or image.shape[2] != 3:
        raise ValueError(
            f"Unsupported image shape {image.shape}; expected (H, W), "
            "(H, W, 1), (H, W, 3) or (H, W, 4)."
        )

    img_resized = cv2.resize(image, (IMG_SIZE, IMG_SIZE))
    # NOTE(review): Keras' built-in EfficientNet models rescale internally and
    # expect [0, 255] inputs. Confirm the phase-3 training pipeline also fed
    # [0, 1] images; otherwise this scaling is a train/serve mismatch.
    img_norm = img_resized.astype("float32") / 255.0
    return img_norm


def denormalise_to_uint8(img_norm: np.ndarray) -> np.ndarray:
    """
    Map a float image in [0, 1] back to displayable uint8 pixels.

    Values are scaled by 255 and clipped to [0, 255], so out-of-range inputs
    saturate instead of wrapping. The cast to uint8 then truncates any
    fractional part.
    """
    scaled = img_norm * 255.0
    return np.clip(scaled, 0, 255).astype(np.uint8)


# ------------------------------------------------
# 3️⃣ LIME explanation
# ------------------------------------------------
def generate_lime(
    img_norm: np.ndarray,
    num_samples: int = 400,
    num_features: int = 8,
    batch_size: int = 50,
) -> np.ndarray:
    """
    Explain the model's top predicted class for one image with LIME.

    Args:
        img_norm: (224, 224, 3) image in [0, 1], as produced by preprocess_image.
        num_samples: number of perturbed images LIME scores; the default of
            400 balances speed and quality.
        num_features: number of superpixels kept in the positive-evidence mask.
        batch_size: perturbed images sent per ``model.predict`` call. LIME's
            own default (10) would mean num_samples / 10 separate predict
            calls, each paying Keras' per-call overhead.

    Returns:
        RGB uint8 image with the boundaries of the superpixels that support
        the top label drawn on it.
    """

    def classifier_fn(imgs):
        # LIME hands over a batch of perturbed images; return the model's
        # per-class scores for each of them.
        return model.predict(np.array(imgs), verbose=0)

    explanation = lime_explainer.explain_instance(
        image=img_norm,
        classifier_fn=classifier_fn,
        top_labels=1,            # explain only the highest-scoring class
        hide_color=0,            # "switched-off" superpixels are set to 0
        num_samples=num_samples,
        batch_size=batch_size,
        progress_bar=False,      # keep the server log free of progress output
    )

    top_label = explanation.top_labels[0]
    temp, mask = explanation.get_image_and_mask(
        label=top_label,
        positive_only=True,      # only superpixels that support the label
        num_features=num_features,
        hide_rest=False,         # keep the full image, just outline regions
    )

    # temp is in the float range LIME worked with; min-max stretch it to
    # [0, 1] for display (skipping the division for a perfectly flat image).
    temp_norm = temp - temp.min()
    if temp_norm.max() > 0:
        temp_norm = temp_norm / temp_norm.max()

    lime_vis = mark_boundaries(temp_norm, mask)
    return (lime_vis * 255).astype("uint8")


# ------------------------------------------------
# 4️⃣ Integrated Gradients
# ------------------------------------------------
def integrated_gradients(
    model,
    img_norm: np.ndarray,
    target_index: int,
    baseline: np.ndarray | None = None,
    steps: int = 40,
) -> np.ndarray:
    """
    Compute Integrated Gradients attributions for a single image.

    The path integral from ``baseline`` to ``img_norm`` is approximated with
    the trapezoidal rule over ``steps`` equal intervals. All ``steps + 1``
    interpolation points are evaluated in one batch.

    Args:
        model: Keras model called on a (N, H, W, C) batch; column
            ``target_index`` of its output is the score being attributed.
        img_norm: (224, 224, 3) float32 image in [0, 1].
        target_index: index of the output column to attribute.
        baseline: optional (H, W, C) reference image; defaults to all zeros.
        steps: number of integration intervals; must be >= 1.

    Returns:
        float32 attribution array with the same shape as ``img_norm``. It is
        all zeros if TensorFlow returns no gradient.

    Raises:
        ValueError: if ``steps`` < 1.
    """
    if steps < 1:
        raise ValueError(f"steps must be >= 1, got {steps}")

    img = tf.convert_to_tensor(img_norm[None, ...], dtype=tf.float32)

    if baseline is None:
        baseline = tf.zeros_like(img)
    else:
        baseline = tf.convert_to_tensor(baseline[None, ...], dtype=tf.float32)

    alphas = tf.linspace(0.0, 1.0, steps + 1)
    alphas_x = alphas[:, None, None, None]  # broadcast over (H, W, C)

    # Straight-line path: the baseline at alpha=0, the input at alpha=1.
    interpolated = baseline + alphas_x * (img - baseline)

    with tf.GradientTape() as tape:
        # `interpolated` is a plain tensor created outside the tape, so it
        # must be watched explicitly.
        tape.watch(interpolated)
        preds = model(interpolated, training=False)
        target = preds[:, target_index]

    # tape.gradient sums over the batch. Assuming the samples do not interact
    # in inference mode, each row of `grads` is the gradient at its own
    # interpolation point.
    grads = tape.gradient(target, interpolated)
    if grads is None:
        # Fallback to zeros if no gradient path exists.
        return np.zeros_like(img_norm, dtype="float32")

    # Trapezoidal rule: average each adjacent pair of gradients, then take
    # the mean over the `steps` intervals.
    avg_grads = tf.reduce_mean((grads[:-1] + grads[1:]) / 2.0, axis=0)  # (H, W, C)
    ig = (img[0] - baseline[0]) * avg_grads
    return ig.numpy()


def compute_ig_heatmap(img_norm: np.ndarray, target_index: int) -> np.ndarray:
    """
    Render Integrated Gradients attributions as a JET colour-mapped image.

    Args:
        img_norm: preprocessed float image in [0, 1].
        target_index: class index whose attributions are visualised.

    Returns:
        (H, W, 3) uint8 RGB heatmap. If the input's mean intensity is below
        0.01 or above 0.99, IG is skipped and an all-black
        IMG_SIZE x IMG_SIZE image is returned instead.
    """
    brightness = img_norm.mean()
    # Near-uniformly dark or bright inputs are treated as pathological.
    if brightness < 0.01 or brightness > 0.99:
        return np.zeros((IMG_SIZE, IMG_SIZE, 3), dtype="uint8")

    attributions = integrated_gradients(model, img_norm, target_index=target_index)

    # Per-pixel attribution magnitude, averaged across colour channels.
    saliency = np.abs(attributions).mean(axis=-1)

    # Min-max scale into [0, 1) so it fits the 8-bit colour map.
    saliency -= saliency.min()
    peak = saliency.max()
    if peak > 0:
        saliency /= peak + 1e-8

    pixels = (saliency * 255).astype("uint8")
    coloured_bgr = cv2.applyColorMap(pixels, cv2.COLORMAP_JET)
    # OpenCV colour maps are BGR; Gradio displays RGB.
    return cv2.cvtColor(coloured_bgr, cv2.COLOR_BGR2RGB)


# ------------------------------------------------
# 5️⃣ Prediction + outputs
# ------------------------------------------------
# Radio choices (see the UI below) that request each explanation method.
_LIME_MODES = ("LIME", "LIME + IG")
_IG_MODES = ("Integrated Gradients", "LIME + IG")


def _empty_result(message: str):
    """Return the six analyse() outputs for a request that yields no prediction."""
    return (
        message,
        pd.DataFrame({"Class": [], "Probability": []}),
        None,
        None,
        None,
        "No explanation generated.",
    )


def _summary_markdown(top_class: str, top_prob: float, prob_df: pd.DataFrame) -> str:
    """Build the prediction summary: top class, its confidence, and the top-3 classes."""
    lines = [
        f"🔍 **Predicted class:** `{top_class}`",
        f"🎯 **Confidence:** `{top_prob:.2%}`",
        "",
        "📊 Top probabilities:",
    ]
    # prob_df is already sorted by descending probability.
    for row in prob_df.head(3).itertuples(index=False):
        lines.append(f"- **{row.Class}** — {row.Probability:.2%}")
    return "\n".join(lines)


def _explanation_markdown(want_lime: bool, want_ig: bool, lime_img, ig_img) -> str:
    """Describe the requested explanations, noting any that failed to render."""
    if not (want_lime or want_ig):
        return (
            "Explanations are disabled. Select **LIME**, **Integrated Gradients**, "
            "or **LIME + IG** to see why the model made this decision."
        )

    bits = []
    if want_lime:
        bits.append(
            "LIME highlights superpixels that *locally* support the prediction."
        )
        if lime_img is None:
            bits.append("(LIME could not be generated for this image.)")
    if want_ig:
        bits.append(
            "Integrated Gradients shows pixels that most influence the model's "
            "output along a smooth path from a blank baseline."
        )
        if ig_img is None:
            bits.append("(Integrated Gradients could not be generated for this image.)")
    return " ".join(bits)


def analyse(image, explain_mode, theme):
    """
    Main Gradio callback.

    Args:
        image: uploaded image as a numpy array, or None.
        explain_mode: one of "LIME", "Integrated Gradients", "LIME + IG",
            "None (fast prediction)".
        theme: "Light" or "Dark". It only adds a note to the summary text.

    Returns:
      - summary text (Markdown)
      - confidence bar data (DataFrame with "Class" / "Probability")
      - LIME image (or None)
      - IG image (or None)
      - preprocessing image (or None)
      - explanation summary text (Markdown)
    """
    if image is None:
        return _empty_result("Please upload an image.")

    want_lime = explain_mode in _LIME_MODES
    want_ig = explain_mode in _IG_MODES

    # 1. Preprocess. preprocess_image signals bad input with ValueError;
    # report that in the UI rather than as a raw error.
    try:
        img_norm = preprocess_image(image)        # float32 [0,1], 224x224
    except ValueError as err:
        return _empty_result(f"Could not process this image: {err}")
    preproc_vis = denormalise_to_uint8(img_norm)

    # 2. Predict
    preds = model.predict(img_norm[None, ...], verbose=0)[0]
    top_idx = int(np.argmax(preds))
    top_class = class_names[top_idx]
    top_prob = float(preds[top_idx])

    # Confidence bars DataFrame, highest probability first.
    prob_df = pd.DataFrame(
        {"Class": class_names, "Probability": preds.astype(float)}
    ).sort_values("Probability", ascending=False)

    # 3. Explanations (best-effort: a failure must not hide the prediction)
    lime_img = None
    if want_lime:
        try:
            lime_img = generate_lime(img_norm)
        except Exception as e:
            print("LIME failed:", e)

    ig_img = None
    if want_ig:
        try:
            ig_img = compute_ig_heatmap(img_norm, target_index=top_idx)
        except Exception as e:
            print("IG failed:", e)

    # 4. Text summaries
    summary_text = _summary_markdown(top_class, top_prob, prob_df)
    # Theme currently only affects tone; could be extended for styling.
    if theme == "Dark":
        summary_text = "🕶️ **Dark mode** enabled.\n\n" + summary_text

    explain_text = _explanation_markdown(want_lime, want_ig, lime_img, ig_img)

    return summary_text, prob_df, lime_img, ig_img, preproc_vis, explain_text


# ------------------------------------------------
# 6️⃣ Build Gradio UI with tabs & confidence bars
# ------------------------------------------------
with gr.Blocks(title="AI-Powered Skin Lesion Classifier") as demo:
    gr.Markdown(
        "## 🩺 AI-Powered Skin Lesion Classifier\n"
        "Deep learning–based **EfficientNetB0** model for classifying dermatoscopic "
        "images into six diagnostic categories, with explainable AI overlays "
        "(**LIME** & **Integrated Gradients**)."
    )

    # Tab 1: upload + options on the left, all analysis outputs on the right.
    with gr.Tab("Classifier"):
        with gr.Row():
            with gr.Column(scale=1):
                image_input = gr.Image(
                    type="numpy",
                    label="Upload Skin Lesion Image",
                )

                gr.Markdown("### Explanation method")
                # These choice strings are compared verbatim inside analyse().
                explain_mode = gr.Radio(
                    choices=[
                        "LIME",
                        "Integrated Gradients",
                        "LIME + IG",
                        "None (fast prediction)",
                    ],
                    value="LIME + IG",
                    label="",
                )

                gr.Markdown("### Theme")
                # Only passed to analyse() as an input; nothing here restyles the UI.
                theme = gr.Radio(
                    choices=["Light", "Dark"],
                    value="Light",
                    label="",
                )

                run_btn = gr.Button("🔎 Run analysis", variant="primary")

            with gr.Column(scale=2):
                summary_out = gr.Markdown(label="Prediction summary")

                gr.Markdown("### Confidence by class")
                # Fed by the "Class"/"Probability" DataFrame returned by analyse().
                # NOTE(review): `vertical` is a Gradio 4.x BarPlot argument; the
                # notebook pip-installs the latest Gradio, so confirm it is still
                # supported there.
                confidence_bar = gr.BarPlot(
                    value=None,
                    x="Class",
                    y="Probability",
                    x_title="Class",
                    y_title="Probability",
                    vertical=False,
                    interactive=False,
                    color="Class",
                    label="Class probabilities",
                )

                gr.Markdown("### Explanation visualisations")
                with gr.Row():
                    lime_out = gr.Image(label="LIME explanation", type="numpy")
                    ig_out = gr.Image(label="Integrated Gradients heatmap",
                                      type="numpy")

                gr.Markdown("### Preprocessing view")
                preproc_out = gr.Image(label="Model input (224×224 normalised)",
                                       type="numpy")

                explain_summary = gr.Markdown(
                    label="Explanation summary",
                )

        # Output order must match the 6-tuple returned by analyse().
        run_btn.click(
            fn=analyse,
            inputs=[image_input, explain_mode, theme],
            outputs=[
                summary_out,
                confidence_bar,
                lime_out,
                ig_out,
                preproc_out,
                explain_summary,
            ],
        )

    # Tab 2: static description and medical disclaimer.
    with gr.Tab("About & Disclaimer"):
        gr.Markdown(
            """
### ℹ️ About this demo

- **Model:** EfficientNetB0 fine-tuned on dermatoscopic skin-lesion images.
- **Task:** Multi-class classification into six diagnostic categories.
- **Explainability:**
  - **LIME** – local superpixel-based explanation.
  - **Integrated Gradients** – pixel-level attribution along a path from a blank baseline.

### ⚠️ Medical disclaimer

This tool is a **research prototype** designed for educational purposes only.

It **must not** be used for clinical diagnosis, treatment decisions, or as a
replacement for professional medical advice. Always consult a qualified
dermatologist or medical professional for any concerns regarding skin lesions.

By using this demo you agree that the authors and maintainers are **not
responsible** for any decisions made based on its output.
"""
        )

# ------------------------------------------------
# 7️⃣ Launch
# ------------------------------------------------
# Start the Gradio app built in the `demo` Blocks above.
# NOTE(review): debug=True is understood to keep the notebook cell running and
# surface callback errors in the cell output — confirm for the installed Gradio.
demo.launch(debug=True)
