In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.layers import Conv2D, DepthwiseConv2D
from tensorflow.keras.preprocessing import image

def load_and_preprocess_image(img_path, target_size=(224,224)):
    """Load an image file and turn it into a one-item input batch.

    Args:
        img_path: Path of the image file, handed to ``image.load_img``.
        target_size: Resize target forwarded to ``image.load_img``.

    Returns:
        ``(batch, pil_img)``. ``batch`` holds the image pixels divided by
        255.0 and carries an extra leading axis of length 1. ``pil_img`` is
        the resized image object returned by ``image.load_img``, which is
        kept for later display.
    """
    pil_img = image.load_img(img_path, target_size=target_size)
    pixels = image.img_to_array(pil_img)
    # Presumably the model was trained on inputs scaled to [0, 1]; confirm
    # against the training pipeline.
    scaled = pixels / 255.0
    batch = scaled[np.newaxis, ...]
    return batch, pil_img

def find_last_conv_layer_name(model):
    """Return the name of the last Conv2D/DepthwiseConv2D layer in ``model``.

    Only the entries of ``model.layers`` are examined. A sub-model that
    appears as a single entry is not searched inside.

    Raises:
        ValueError: If no entry of ``model.layers`` is a Conv2D or
            DepthwiseConv2D instance.
    """
    conv_types = (Conv2D, DepthwiseConv2D)
    conv_layers = [lyr for lyr in model.layers if isinstance(lyr, conv_types)]
    if not conv_layers:
        raise ValueError("No Conv2D/DepthwiseConv2D layer found.")
    return conv_layers[-1].name

def make_gradcam_heatmap(img_tensor, model, class_index=None, last_conv_layer_name=None):
    """Compute a Grad-CAM heatmap for the first image in ``img_tensor``.

    Args:
        img_tensor: Batched model input. Only item 0 is turned into the
            heatmap, but gradients are averaged over axis 0 as well, so a
            batch of size 1 is assumed.
        model: Keras model. Its output ``preds`` is indexed as
            ``preds[:, class_index]``.
        class_index: Output column to explain. Defaults to
            ``argmax(preds[0])``.
        last_conv_layer_name: Layer whose activations are weighted. Defaults
            to ``find_last_conv_layer_name(model)``.

    Returns:
        ``(heatmap, class_index, last_conv_layer_name)``. ``heatmap`` is a
        float32 numpy array of non-negative values divided by their max, so
        its peak is about 1. It is all zeros if no location contributes
        positively.

    Raises:
        ValueError: If the selected output has no gradient with respect to
            the chosen layer, for example because the layer is not on the
            path to the model output.
    """
    if last_conv_layer_name is None:
        last_conv_layer_name = find_last_conv_layer_name(model)

    # Pass ``model.inputs`` as-is. Wrapping it in another list creates a
    # nested input structure, which Keras 3 warns about or rejects.
    grad_model = tf.keras.models.Model(
        model.inputs,
        [model.get_layer(last_conv_layer_name).output, model.output],
    )

    with tf.GradientTape() as tape:
        conv_out, preds = grad_model(img_tensor, training=False)
        if class_index is None:
            class_index = int(tf.argmax(preds[0]))
        loss = preds[:, class_index]

    grads = tape.gradient(loss, conv_out)
    if grads is None:
        raise ValueError(
            f"Output {class_index} has no gradient with respect to layer "
            f"'{last_conv_layer_name}'."
        )

    # Assumes channels-last activations (batch, H, W, C). Average the
    # gradients over every axis except channels to get one weight per
    # channel.
    pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))
    heatmap = tf.reduce_sum(conv_out[0] * pooled_grads, axis=-1)

    # Work in float32 so the epsilon below does not underflow when the
    # activations are reduced-precision (for example float16).
    heatmap = tf.cast(heatmap, tf.float32)
    # Apply ReLU *before* taking the max. The denominator is then never
    # negative and cannot cancel the epsilon, which would give 0/0 = NaN.
    heatmap = tf.maximum(heatmap, 0.0)
    heatmap = heatmap / (tf.reduce_max(heatmap) + 1e-8)
    return heatmap.numpy(), class_index, last_conv_layer_name

def overlay_heatmap_on_image(pil_img, heatmap, alpha=0.4):
    """Blend a "jet" colour-mapped heatmap over an image.

    Args:
        pil_img: Image (or array) convertible with ``np.array``. 2-D and
            single-channel inputs are replicated to 3 channels. 4-channel
            inputs (presumably RGBA) keep only their first 3 channels.
        heatmap: 2-D array of intensities expected in [0, 1]. It is resized
            to the image height/width. NaNs are zeroed and values are
            clipped to [0, 1] before colour mapping.
        alpha: Weight of the heatmap colours. The image gets ``1 - alpha``.

    Returns:
        ``uint8`` array of shape (H, W, 3).
    """
    img = np.array(pil_img).astype(np.float32)
    # Normalise to 3 channels so the image broadcasts against the RGB colours.
    if img.ndim == 2:
        img = img[..., np.newaxis]
    if img.shape[-1] == 1:
        img = np.repeat(img, 3, axis=-1)
    elif img.shape[-1] == 4:
        img = img[..., :3]

    heat2d = np.asarray(heatmap, dtype=np.float32)
    # Index the channel axis rather than calling ``squeeze()``, which would
    # also drop a height or width of 1.
    heat = tf.image.resize(heat2d[..., np.newaxis], (img.shape[0], img.shape[1])).numpy()[..., 0]
    # Casting out-of-range or NaN floats to uint8 wraps or is undefined, so
    # sanitise the values first.
    heat = np.clip(np.nan_to_num(heat, nan=0.0), 0.0, 1.0)
    heat = np.uint8(255 * heat)

    cmap = plt.get_cmap("jet")
    # Integer input indexes the colormap lookup table directly; keep RGB, drop A.
    colored = cmap(heat)[:, :, :3] * 255.0
    out = (1 - alpha) * img + alpha * colored
    return np.clip(out, 0, 255).astype(np.uint8)

# Read an image path from stdin, run Grad-CAM on ``medinet_xg`` and save and
# show the overlay. ``img_size``, ``medinet_xg``, ``class_names`` and
# ``OUTPUT_DIR`` are expected to be defined by earlier cells.
img_path = input().strip()
# Fail early with a clear message instead of deep inside the image loader.
if not os.path.isfile(img_path):
    raise FileNotFoundError(f"Image not found: {img_path!r}")

x, pil_img = load_and_preprocess_image(img_path, target_size=(img_size, img_size))
heatmap, pred_idx, layer_name = make_gradcam_heatmap(x, medinet_xg)

overlay = overlay_heatmap_on_image(pil_img, heatmap, alpha=0.4)
pred_label = class_names[pred_idx]

fig, ax = plt.subplots(figsize=(6, 6))
ax.imshow(overlay)
ax.axis("off")
ax.set_title(f"{pred_label} | {layer_name}")
fig.tight_layout()

# Make sure the target directory exists before saving.
os.makedirs(OUTPUT_DIR, exist_ok=True)
out_path = os.path.join(OUTPUT_DIR, "gradcam_overlay_MediNet_XG.png")
fig.savefig(out_path, dpi=300, bbox_inches="tight")
plt.show()
plt.close(fig)
