# Dual-Mode Grad-CAM: Visualizing Positive & Negative Classes

This notebook lets you toggle the Grad-CAM heatmap between the **Parasitized** (positive) class and the **Uninfected** (negative) class of the malaria cell classifier.

### How to use:
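In the last cell, set `VISUALIZE_NEGATIVE`:
* `False`: shows why the model thinks the cell is **Parasitized** (typically highlights the parasite dot/ring).
* `True`: shows why the model thinks the cell is **Uninfected** (typically highlights the clean, empty cytoplasm).

The negative map is produced by negating the gradients of the model's output score before they are pooled. It therefore highlights regions that push the prediction *down*, towards Uninfected.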

In [None]:
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing import image
from tensorflow.keras.applications.vgg19 import preprocess_input

# Log the TensorFlow build in use; Keras/GradientTape behaviour differs between versions.
print("TensorFlow Version:", tf.__version__)

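### 1. Robust Helper Functions (with Gradient Inversion)
`make_gradcam_heatmap` takes an `invert_gradients` flag. When it is `True`, the gradients of the output score are negated before pooling, so the heatmap shows evidence for the negative (Uninfected) class instead of the positive one.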

In [None]:
def _find_target_layer(model, layer_name):
    """Return ``(layer, owner)`` where ``owner`` is ``model`` or the nested sub-model containing ``layer_name``."""
    try:
        return model.get_layer(layer_name), model
    except ValueError:
        pass
    # Not a direct layer: search one level of nested models (e.g. a vgg19 base inside a Sequential).
    for layer in model.layers:
        if isinstance(layer, tf.keras.Model):
            try:
                return layer.get_layer(layer_name), layer
            except ValueError:
                continue
    raise ValueError(f"Layer {layer_name} not found in model.")


def make_gradcam_heatmap(img_array, model, last_conv_layer_name, pred_index=None, invert_gradients=False):
    """Generate a Grad-CAM heatmap for the first image in ``img_array``.

    The target conv layer may sit directly in ``model`` or inside a nested
    sub-model. In the nested case the forward pass is split: the sub-model
    returns both the conv activations and its own output, and the remaining
    layers of ``model`` are applied to that output manually. This keeps the
    tape able to link the score back to the conv activations.

    Args:
        img_array: Preprocessed image batch. Only item 0 is explained
            (assumes a channels-last conv output of shape (batch, H, W, C)).
        model: Keras model to explain.
        last_conv_layer_name (str): Name of the conv layer whose activations
            are weighted by the pooled gradients.
        pred_index (int | None): Output channel to explain. Forced to 0 for
            single-unit outputs; when None on a multi-unit output the arg-max
            of the first prediction is used.
        invert_gradients (bool): If True, the gradients are negated before
            pooling. The map then highlights regions that push the selected
            score *down*, i.e. the 'Negative' class of a single-unit binary model.

    Returns:
        np.ndarray: 2-D heatmap (H, W) with values in [0, 1].

    Raises:
        ValueError: If the layer is not found, or if no gradient flows from the
            selected score back to that layer.
    """
    target_layer, containing_model = _find_target_layer(model, last_conv_layer_name)

    with tf.GradientTape() as tape:
        if containing_model is model:
            # CASE A: the conv layer is a direct layer of `model`.
            grad_model = tf.keras.models.Model(
                model.inputs, [target_layer.output, model.output]
            )
            conv_out, preds = grad_model(img_array)
        else:
            # CASE B: nested model. Run the sub-model, then feed its output through
            # every layer of `model` that comes after it.
            part1_model = tf.keras.models.Model(
                containing_model.inputs,
                [target_layer.output, containing_model.output],
            )
            conv_out, x = part1_model(img_array)
            start_forwarding = False
            for layer in model.layers:
                if layer is containing_model:
                    start_forwarding = True
                elif start_forwarding:
                    x = layer(x)
            preds = x

        # Single-unit (binary) output: the only channel is the score to explain.
        if model.output_shape[-1] == 1:
            pred_index = 0
        elif pred_index is None:
            pred_index = tf.argmax(preds[0])
        class_channel = preds[:, pred_index]

    grads = tape.gradient(class_channel, conv_out)
    if grads is None:
        raise ValueError(
            f"No gradient flows from the model output to layer '{last_conv_layer_name}'."
        )

    # Negated gradients explain the opposite (negative) class.
    if invert_gradients:
        grads = -grads

    # Pool gradients of image 0 only, so other batch items cannot leak into its map.
    pooled_grads = tf.reduce_mean(grads[0], axis=(0, 1))
    heatmap = conv_out[0] @ pooled_grads[..., tf.newaxis]
    # Drop only the channel axis; a plain squeeze would also drop 1-pixel spatial axes.
    heatmap = tf.squeeze(heatmap, axis=-1)

    # ReLU first, then normalise by the max of the ReLU'd map; epsilon avoids 0/0.
    heatmap = tf.maximum(heatmap, 0)
    heatmap = heatmap / (tf.math.reduce_max(heatmap) + 1e-8)

    return heatmap.numpy()

def save_and_display_gradcam(img_path, heatmap, alpha=0.4, title="Grad-CAM", cam_path=None):
    """Overlay a Grad-CAM heatmap on the original image and show both side by side.

    Args:
        img_path (str): Path of the original (un-preprocessed) image.
        heatmap (np.ndarray): 2-D heatmap with values in [0, 1] (e.g. from
            ``make_gradcam_heatmap``). It is resized to the image size.
        alpha (float): Weight of the colourised heatmap in the overlay.
        title (str): Title of the overlay panel.
        cam_path (str | None): If given, the overlay image is also saved to this
            path. Default None means display only.
    """
    original = image.load_img(img_path)
    img = image.img_to_array(original)

    # Guard against NaN / out-of-range values, which would wrap when cast to uint8.
    heatmap = np.uint8(255 * np.clip(np.nan_to_num(heatmap), 0.0, 1.0))
    # `matplotlib.cm.get_cmap` was removed in Matplotlib 3.9; pyplot's get_cmap remains.
    jet_colors = plt.get_cmap("jet")(np.arange(256))[:, :3]
    jet_heatmap = jet_colors[heatmap]

    # Upsample the coarse conv-resolution heatmap to the original image size.
    jet_heatmap = image.array_to_img(jet_heatmap)
    jet_heatmap = jet_heatmap.resize((img.shape[1], img.shape[0]))
    jet_heatmap = image.img_to_array(jet_heatmap)

    superimposed_img = image.array_to_img(jet_heatmap * alpha + img)
    if cam_path is not None:
        superimposed_img.save(cam_path)

    fig, (ax_orig, ax_cam) = plt.subplots(1, 2, figsize=(12, 6))
    ax_orig.imshow(original)
    ax_orig.set_title("Original")
    ax_orig.axis('off')

    ax_cam.imshow(superimposed_img)
    ax_cam.set_title(title)
    ax_cam.axis('off')
    plt.show()

### 2. Load Model & Auto-Detect Settings

In [None]:
# Load the trained classifier. compile=False: only forward passes and gradients
# are needed, not the optimizer/loss state. Load errors are deliberately NOT
# swallowed: later cells need `model` and `IMG_SIZE`, so fail loudly here.
model_path = 'malaria_final.h5'
model = load_model(model_path, compile=False)
print("Model loaded successfully!")

# Use the model's own spatial input size when fully specified; otherwise fall
# back to 224x224 (the standard VGG19 input size).
input_shape = model.input_shape
if isinstance(input_shape, list):  # multi-input model: use the first (image) input
    input_shape = input_shape[0]
if input_shape and len(input_shape) >= 3 and input_shape[1] and input_shape[2]:
    IMG_SIZE = (input_shape[1], input_shape[2])
else:
    IMG_SIZE = (224, 224)
print(f"Input Size: {IMG_SIZE}")

### 3. Execution Settings (TOGGLE HERE)

In [None]:
# --- CONFIGURATION ---
img_path = "path/to/your/image.png"  # <--- UPDATE THIS
last_conv_layer_name = "block5_conv4"  # last conv layer of the VGG19 backbone
USE_VGG_PREPROCESSING = True  # must match the preprocessing used in training

# --- TOGGLE CLASS ---
# False = Visualize Parasite (Positive)
# True  = Visualize Healthy/Uninfected (Negative)
VISUALIZE_NEGATIVE = False

# 1. Load & preprocess a single image into a (1, H, W, 3) float32 batch.
img = image.load_img(img_path, target_size=IMG_SIZE)
img_array = np.expand_dims(image.img_to_array(img), axis=0).astype(np.float32)
if USE_VGG_PREPROCESSING:
    img_array = preprocess_input(img_array)
else:
    img_array = img_array / 255.0

# 2. Score the image with the model exactly as saved (output activation intact).
#    An eager call is used instead of `model.predict`, whose cached traced
#    function can go stale when layer attributes are changed below.
raw_score = float(np.asarray(model(img_array, training=False))[0][0])

# 3. Grad-CAM on the pre-activation score: temporarily replace the output
#    layer's activation with identity and ALWAYS restore it. This keeps the
#    model unmodified, so re-running this cell gives the same result.
output_layer = model.layers[-1]
original_activation = getattr(output_layer, "activation", None)
if original_activation is not None:
    output_layer.activation = tf.keras.activations.linear
try:
    heatmap = make_gradcam_heatmap(
        img_array,
        model,
        last_conv_layer_name,
        pred_index=None,
        invert_gradients=VISUALIZE_NEGATIVE,
    )
finally:
    if original_activation is not None:
        output_layer.activation = original_activation

# 4. Display
title = "Grad-CAM: Evidence for Uninfected" if VISUALIZE_NEGATIVE else "Grad-CAM: Evidence for Parasite"
save_and_display_gradcam(img_path, heatmap, title=title)

print(f"Raw Model Score: {raw_score:.4f} (High=Parasite, Low=Uninfected)")