# Laboratorio 3: SHAP MedMNIST

## ========= 1) Setup =========

In [None]:
# Install required libraries if needed
# !pip install medmnist torch torchvision shap matplotlib

import torch
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import shap
import numpy as np
import random
import os

import medmnist
from medmnist import INFO


import torch.nn as nn
import torch.nn.functional as F

## ========= 2) Dataset Info =========

In [None]:
# Which MedMNIST subset to load, and whether to download it if it is missing.
data_flag, download = 'pathmnist', True

# Metadata record for the chosen subset (description, task, label names, ...).
info = INFO[data_flag]
# Label keys are stored as strings ("0", "1", ...); index them by int instead.
id2label = {int(k): label_name for k, label_name in info['label'].items()}
n_classes = len(info['label'])

print(f"Dataset: {info['description']}")
print(f"Task: {info['task']}, Classes: {n_classes}")
print("Classes:", id2label)

## ========= 3) Load dataset =========

In [None]:
DataClass = getattr(medmnist, info['python_class'])

# ToTensor() maps pixels to [0, 1]; Normalize(mean=0.5, std=0.5) then maps them
# to [-1, 1]. Anything that displays these tensors must undo this (x * 0.5 + 0.5).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5])
])

train_dataset = DataClass(split='train', transform=transform, download=download)
test_dataset  = DataClass(split='test', transform=transform, download=download)

print("Train size:", len(train_dataset), " Test size:", len(test_dataset))

# Show 5 sample images
fig, axs = plt.subplots(1, 5, figsize=(10, 2))
for i, ax in enumerate(axs):
    img, label = train_dataset[i]

    # The label may be a torch tensor, a 1-element numpy array or a plain int;
    # np.asarray(...).item() turns any of these into a Python scalar.
    lbl = int(np.asarray(label).item())

    # (C, H, W) -> (H, W, C); squeeze drops the channel axis for 1-channel images.
    img = img.numpy().transpose(1, 2, 0).squeeze()
    # Undo the normalization so pixel values are back in [0, 1]. Without this,
    # imshow clips the negative RGB values and the samples render too dark.
    img = np.clip(img * 0.5 + 0.5, 0.0, 1.0)

    # cmap/vmin/vmax only affect single-channel images; RGB images ignore them.
    ax.imshow(img, cmap="gray", vmin=0.0, vmax=1.0)
    ax.set_title(id2label[lbl])
    ax.axis("off")
plt.tight_layout()
plt.show()

## ========= 4) Define Simple CNN Model (already provided) =========

In [None]:
class SimpleCNN(nn.Module):
    """Minimal CNN: 3x3 conv (16 filters) -> ReLU -> 2x2 max-pool -> linear classifier.

    Args:
        n_channels: number of input image channels.
        n_classes: number of output logits.
        img_size: side length of the square input images. The default of 28
            reproduces the original ``16*14*14`` classifier input, so parameter
            names and shapes (and therefore saved checkpoints) are unchanged.
    """

    def __init__(self, n_channels, n_classes, img_size=28):
        super().__init__()
        # padding=1 with a 3x3 kernel keeps H and W unchanged.
        self.conv1 = nn.Conv2d(n_channels, 16, kernel_size=3, padding=1)
        # 2x2 pooling with stride 2 halves H and W (rounding down).
        self.pool = nn.MaxPool2d(2, 2)
        pooled = img_size // 2
        self.fc1 = nn.Linear(16 * pooled * pooled, n_classes)

    def forward(self, x):
        """Return raw logits of shape (N, n_classes) for an (N, C, img_size, img_size) batch."""
        x = self.pool(F.relu(self.conv1(x)))
        x = x.reshape(x.size(0), -1)  # flatten each sample's feature maps
        return self.fc1(x)

# Pick the GPU when one is available, then build the model directly on that device
# (nn.Module.to moves parameters in place and returns the same module).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SimpleCNN(info['n_channels'], n_classes).to(device)

print("Model class defined and ready.")

## ========= 5) Load Pretrained Model =========

In [None]:
MODEL_PATH = "pathmnist_simplecnn.pth"

if not os.path.isfile(MODEL_PATH):
    raise FileNotFoundError(
        f"Pretrained weights not found at '{MODEL_PATH}'. "
        "Place the checkpoint next to this notebook or update MODEL_PATH."
    )

# A state_dict only contains tensors and plain containers, so weights_only=True
# is sufficient. It also refuses to unpickle arbitrary objects from the file.
# (The argument requires PyTorch >= 1.13; recent versions warn when it is omitted.)
state_dict = torch.load(MODEL_PATH, map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model.eval()  # inference mode for the prediction and SHAP cells below

print("Pretrained model loaded successfully!")

## ========= 6) Predictions + SHAP Explanations =========

**In this section, you will:**
1. Write a helper function so SHAP can call the model
2. Pick one test image and predict its class
3. Compare prediction vs. true label
4. Use SHAP to explain WHY the model made that prediction
5. Visualize the results

## ========= 6) Predictions + SHAP Explanations =========

In [None]:
# In this section (the cells below), you will:
#   1. Write a helper function so SHAP can call the model
#   2. Pick one test image and predict its class
#   3. Compare prediction vs. true label
#   4. Use SHAP to explain WHY the model made that prediction
#   5. Visualize the results

### --- Step 1: Helper function for SHAP ---

In [None]:
# HINT: x will come as a numpy array with shape (N, H, W, C).
#       Convert it to a torch tensor (N, C, H, W), run through the model,
#       return probabilities as numpy.
def model_forward(x, *, normalize=True):
    """Return softmax class probabilities for channels-last images.

    This is the prediction function handed to SHAP, which calls it with a single
    positional argument.

    Args:
        x: images shaped (N, H, W, C). Accepts a numpy array (or anything
            ``np.asarray`` accepts), a list/tuple of (H, W, C) images, or a single
            (H, W, C) image, which is treated as a batch of one.
        normalize: when True (default), ``x`` is assumed to hold raw pixel values
            in [0, 1] or [0, 255]. It is rescaled to [0, 1] if needed, then
            normalized with mean=0.5 / std=0.5 to match the training transform.
            Pass False if ``x`` is already normalized (e.g. an image taken straight
            from ``test_dataset`` and transposed to channels-last); otherwise it
            would be normalized twice.

    Returns:
        numpy array of shape (N, n_classes) whose rows are softmax probabilities.

    Raises:
        ValueError: if ``x`` is neither 3-D nor 4-D after conversion.

    Note:
        The [0, 255] vs [0, 1] decision is made once per *batch* from ``x.max()``,
        so every image in one call is scaled the same way.
    """
    # Accept lists/tuples of images as well as a ready-made batch array.
    if isinstance(x, (list, tuple)):
        x = np.stack(x, axis=0)
    x = np.asarray(x, dtype=np.float32)

    if x.ndim == 3:
        x = x[np.newaxis]  # single (H, W, C) image -> batch of one
    if x.ndim != 4:
        raise ValueError(
            f"model_forward expects (N, H, W, C) images, got shape {x.shape}"
        )

    if normalize:
        # Treat the batch as 0-255 data if any value exceeds 1.
        # x.size guards against calling max() on an empty batch.
        if x.size and x.max() > 1.0:
            x = x / 255.0
        # Match training normalization: Normalize(mean=0.5, std=0.5) per channel -> [-1, 1]
        x = (x - 0.5) / 0.5

    # Channels-last numpy -> channels-first torch tensor (N, C, H, W) on the model's device.
    xt = torch.from_numpy(x).permute(0, 3, 1, 2).to(device)

    # Forward pass -> probabilities (no autograd graph needed for explanation)
    with torch.no_grad():
        logits = model(xt)
        probs = torch.softmax(logits, dim=1).cpu().numpy()

    return probs

### --- Step 2: Pick one test image ---

In [None]:
# TODO: Select an image from test_dataset
# HINT: take sample_img, sample_label = test_dataset[0] (or a random index)
# TODO: Predict class probabilities using the model
# TODO: Print predicted class (with probability) and true label
# (use id2label to show class names)

### --- Step 3: Prepare image for SHAP ---

In [None]:
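# TODO: Convert the image into numpy format (H, W, C)
# HINT: remember test_dataset gives (C, H, W), so you might need np.transpose
# HINT: test_dataset images are also normalized to [-1, 1]. Undo it with
#       img * 0.5 + 0.5 so the values are back in [0, 1], because model_forward
#       applies the (x - 0.5) / 0.5 normalization itself.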

### --- Step 4: Create SHAP explainer ---

In [None]:
# TODO: Create a masker for images
# HINT: shap.maskers.Image("blur(28,28)", img_np.shape)
# TODO: Create an Explainer with (model_forward, masker)
# TODO: Run explainer on your selected image
# HINT: the explainer expects a batch, so pass img_np[np.newaxis] (shape (1, H, W, C))

### --- Step 5: Visualize ---

In [None]:
# TODO: Plot the original image and the SHAP heatmap side by side
# HINT: use matplotlib subplots

## ========= 7) Extension: Multiple Images =========

In [None]:
# TODO: Loop over 5 random test images
# For each:
#   - Show original image with true label
#   - Predict with the model and show predicted label + probability
#   - Plot SHAP heatmap for predicted class
# HINT: Use matplotlib with 2 rows and 5 columns

## ========= 8) Reflection =========

In [None]:
# Answer in text (Markdown or comments):
# 1. Why did the model predict this class?
# 2. Are the SHAP heatmaps focusing on meaningful regions?
# 3. What differences do you see between correct and incorrect predictions?
# 4. How could interpretability help improve this model?