<a href="https://colab.research.google.com/github/ISE-CS4445-AI/challenge-7-tomjoyce1/blob/main/challenge-7.ipynb" target="_blank"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Challenge #7: Interpretability and explainable ML

In this assignment, you will explore explainability using LIME on image classification. Your tasks are:

1. **get_lime_explanation: (2 points)**  
   Given a pretrained image classifier and an input image, generate a LIME explanation object.

2. **display_lime_explanation: (1 point)**  
   From the explanation object, extract and return a visualization (image and mask overlay) that highlights important superpixels.

3. **extract_feature_importance: (1 point)**  
   Extract a list of important feature (superpixel) contributions from the explanation object, sorted by importance.

4. **get_resnet_lime_explanations: (2 points)**  
   Apply a pretrained ResNet model to an input image and use LIME to generate and display explanations for its top 5 predicted classes.

After coding, answer three brief reflection questions on explainability methods.

*Total points: 9 (6 points for code tasks and 3 points for reflection questions).*

---

## Background on LIME

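LIME (Local Interpretable Model-agnostic Explanations) is a popular method for explaining the predictions of any classifier. For image classification, LIME works by:
- Segmenting the input image into superpixels (contiguous patches of similar pixels).
- Perturbing the image by randomly turning superpixels on or off (an "off" superpixel is hidden, e.g. filled with a constant or average color).
- Evaluating how these perturbations change the classifier's predicted probabilities.
- Fitting a local, interpretable linear model, weighted by each perturbed sample's similarity to the original image, to approximate the classifier's behavior near that instance.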
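
The fitting step can be written as the optimization problem from the original LIME formulation: the explanation $\xi(x)$ is the interpretable model $g$ (here a linear model over binary superpixel on/off vectors $z'$) that best matches the classifier $f$ on perturbed samples $z$ near the instance $x$, while staying simple:

$$
\xi(x) = \underset{g \in G}{\arg\min}\; \mathcal{L}(f, g, \pi_x) + \Omega(g)
$$

$$
\mathcal{L}(f, g, \pi_x) = \sum_{(z, z') \in \mathcal{Z}} \pi_x(z)\,\bigl(f(z) - g(z')\bigr)^2,
\qquad
\pi_x(z) = \exp\!\left(-\frac{D(x, z)^2}{\sigma^2}\right)
$$

Here $\mathcal{Z}$ is the set of sampled perturbations, $\pi_x$ weights each sample by its proximity to $x$ (with distance $D$ and kernel width $\sigma$), and $\Omega(g)$ penalizes complexity, for example by limiting how many superpixels receive a non-zero weight.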

The result is an explanation object that can produce:
- A list of feature contributions.
- A visualization (image with a mask overlay) highlighting which superpixels had the greatest influence on the prediction.
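
To make the procedure in the bullet list above concrete, here is a toy re-implementation in plain NumPy. It is a sketch for intuition, not the `lime` library's code: the hypothetical `lime_sketch` function, the grid "segmentation", the zero hide color, the small ridge penalty `alpha` and the cosine-distance kernel are all assumptions (the library, for instance, hides superpixels with their mean color by default). The usage example builds a classifier that only looks at superpixel 2, so that superpixel should receive almost all of the weight.

```python
import numpy as np

def lime_sketch(image, segments, classifier_fn, label,
                num_samples=500, kernel_width=0.25, alpha=0.01, seed=0):
    """Toy LIME for images: returns one importance weight per superpixel."""
    rng = np.random.default_rng(seed)
    n_seg = int(segments.max()) + 1

    # 1. Random on/off vectors over superpixels; row 0 keeps everything "on".
    z = rng.integers(0, 2, size=(num_samples, n_seg))
    z[0, :] = 1

    # 2. Perturb: pixels of "off" superpixels are set to 0 (a constant hide color).
    perturbed = np.repeat(image[None].astype(float), num_samples, axis=0)
    for i in range(num_samples):
        off = np.isin(segments, np.flatnonzero(z[i] == 0))
        perturbed[i][off] = 0.0

    # 3. Query the black-box model for the probability of `label`.
    y = classifier_fn(perturbed)[:, label]

    # 4. Proximity weights: cosine distance to the all-ones vector, exponential kernel.
    on = z.sum(axis=1)
    cos_sim = on / (np.sqrt(on) * np.sqrt(n_seg) + 1e-12)
    w = np.exp(-((1.0 - cos_sim) ** 2) / kernel_width ** 2)

    # 5. Weighted ridge regression y ~ b0 + z @ beta, solved in closed form.
    X = np.hstack([np.ones((num_samples, 1)), z])
    A = X.T @ (w[:, None] * X) + alpha * np.eye(n_seg + 1)
    beta = np.linalg.solve(A, X.T @ (w * y))
    return beta[1:]  # drop the intercept, keep one weight per superpixel

# Usage: a 4x4 image split into four 2x2 superpixels with ids 0..3.
segments = np.repeat(np.repeat(np.arange(4).reshape(2, 2), 2, axis=0), 2, axis=1)
image = np.ones((4, 4, 3))

def toy_classifier(batch):
    # P(class 1) = mean brightness of superpixel 2, so only superpixel 2 matters.
    p = batch[:, segments == 2].mean(axis=(1, 2))
    return np.stack([1.0 - p, p], axis=1)

weights = lime_sketch(image, segments, toy_classifier, label=1)
print(np.round(weights, 3))  # superpixel 2 should dominate
```

The closed-form weighted least-squares solve keeps the sketch dependency-free; a real implementation would typically use a library regressor and a proper segmentation algorithm instead.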

**Additional resources:**  
[Official LIME blog post](https://www.oreilly.com/content/introduction-to-local-interpretable-model-agnostic-explanations-lime/)  
[Tutorial notebooks on their official GitHub repository](https://github.com/marcotcr/lime/tree/master/doc/notebooks)  
[General article on explainable machine learning (Medium)](https://medium.com/michelle-and-ryan-explain-ml/explainable-and-interpretable-machine-learning-7e7c28bba4f2)

## Imports & Setup

In [None]:
!pip install lime

In [None]:
from lime import lime_image
import torch
from torchvision import transforms
from torchvision.models import resnet18, ResNet18_Weights
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from skimage.segmentation import mark_boundaries
import cv2
import json

print("Torch version:", torch.__version__)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print("Using device:", device)

## Task 1: Get LIME Explanation <font color='green'>(2 points)</font>

Generate and return a LIME explanation object for the given image and model.
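
LIME calls the prediction function with a batch of perturbed images as a NumPy array of shape (N, H, W, 3) and expects an (N, num_classes) array of probabilities back. The sketch below covers only the NumPy side of such a function: scaling to [0, 1], normalizing with the ImageNet channel statistics used by torchvision's pretrained models (an assumption for the dummy model), reordering to channels-first, and a softmax for models that output raw logits. The helper names `to_model_batch` and `softmax` are hypothetical; resizing (the autograder's `DummyModel` expects 224 x 224 x 3 inputs) and the `torch` conversion and inference are left out.

```python
import numpy as np

# ImageNet channel statistics (the convention used by torchvision's pretrained weights).
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)

def to_model_batch(images):
    """(N, H, W, 3) values in [0, 255] -> (N, 3, H, W) normalized float32."""
    x = np.asarray(images, dtype=np.float32) / 255.0
    x = (x - IMAGENET_MEAN) / IMAGENET_STD  # broadcasts over the last (channel) axis
    return np.ascontiguousarray(x.transpose(0, 3, 1, 2))  # NHWC -> NCHW

def softmax(logits):
    """Row-wise softmax turning raw model outputs into probabilities."""
    shifted = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    e = np.exp(shifted)
    return e / e.sum(axis=1, keepdims=True)

batch = np.random.default_rng(0).integers(0, 256, size=(2, 224, 224, 3)).astype(np.uint8)
x = to_model_batch(batch)
print(x.shape, x.dtype)  # (2, 3, 224, 224) float32

probs = softmax(np.array([[2.0, 1.0, 0.1], [0.0, 0.0, 0.0]]))
print(probs.sum(axis=1))  # each row sums to 1
```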

In [None]:
def get_lime_explanation(model: torch.nn.Module, image: np.ndarray, top_labels: int = 5, num_samples: int = 1000):
    """
    Generate and return a LIME explanation object for the given image and model.

    Parameters:
        model (torch.nn.Module): A pretrained image classification model.
        image (np.ndarray): Input image in numpy array format (H x W x C).
        top_labels (int): Number of top labels to consider.
        num_samples (int): Number of perturbed samples to generate.

    Returns:
        explanation: A LIME explanation object (as returned by lime_image.LimeImageExplainer.explain_instance).
    """
    # TODO: Create an instance of LimeImageExplainer.
    explainer = None  # Replace with your code
    
    # TODO: Define a prediction function that accepts a batch of images as numpy arrays and returns prediction probabilities.
    def predict(images):
        # Convert images to torch tensors, preprocess them, and obtain predictions.
        # TODO: Implement the necessary transformation and model inference.
        return None  # Replace with your code.
    
    # TODO: Use the explainer to generate an explanation for the image.
    explanation = None  # Replace with your code.
    
    return explanation

## Task 2: Display LIME Explanation <font color='green'>(1 point)</font>

Implement a function that extracts and returns the visualization of the LIME explanation. This function should use the explanation object's method (such as `get_image_and_mask`) to generate an image with an overlay mask that highlights the most important superpixels.

The output should be a tuple: (explanation_image, mask), which you can then display using matplotlib.
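
Conceptually, `get_image_and_mask` picks the `num_features` superpixels with the largest weights for a label (only positive ones when `positive_only=True`) and marks them in a mask with the image's height and width. The NumPy sketch below mimics that selection on a hand-made segmentation; `mask_from_weights` is a hypothetical helper, not part of the `lime` API, and it ignores library options such as `hide_rest` and `min_weight`.

```python
import numpy as np

def mask_from_weights(segments, weights, num_features=5, positive_only=True):
    """Hypothetical helper: 1 marks positive superpixels, -1 negative ones."""
    # Rank features by absolute weight, the order LIME uses for its local explanation.
    ranked = sorted(weights, key=lambda fw: abs(fw[1]), reverse=True)
    if positive_only:
        ranked = [fw for fw in ranked if fw[1] > 0]
    mask = np.zeros(segments.shape, dtype=int)
    for feature, weight in ranked[:num_features]:
        mask[segments == feature] = 1 if weight > 0 else -1
    return mask

segments = np.array([[0, 0, 1, 1],
                     [0, 0, 1, 1],
                     [2, 2, 3, 3],
                     [2, 2, 3, 3]])
weights = [(0, 0.05), (1, -0.40), (2, 0.30), (3, 0.10)]  # made-up (superpixel, weight) pairs

pos_mask = mask_from_weights(segments, weights, num_features=2)
signed_mask = mask_from_weights(segments, weights, num_features=2, positive_only=False)
print(pos_mask)
print(signed_mask)
```

With `positive_only=True` the strongly negative superpixel 1 is skipped, so superpixels 2 and 3 are selected; without the filter, superpixel 1 is marked with -1.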

In [None]:
# Utility function to visualise a LIME explanation
def visualiseExplanation(explanation_image, mask):
    # Draw the mask's superpixel boundaries on the image (pixel values scaled to [0, 1])
    img_boundary = mark_boundaries(explanation_image / 255.0, mask)
    plt.axis('off')
    plt.imshow(img_boundary)

def display_lime_explanation(explanation, image: np.ndarray, positive_only: bool = True, num_features: int = 5):
    """
    From the LIME explanation object, generate and return the visualization image and mask overlay.

    Parameters:
        explanation: The LIME explanation object.
        image (np.ndarray): The original image (H x W x C).
        positive_only (bool): Whether to show only features that positively influence the prediction.
        num_features (int): Number of superpixels to display.

    Returns:
        tuple: (explanation_image, mask) as produced by explanation.get_image_and_mask.
    """
    # TODO: Call explanation.get_image_and_mask with appropriate parameters.
    explanation_image, mask = None, None  # Replace with your code.
    
    return explanation_image, mask


## Task 3: Extract Feature Importance <font color='green'>(1 point)</font>

Implement a function that extracts a sorted list of feature (superpixel) contributions from the LIME explanation object. The function should return a list of tuples (feature_index, importance) sorted in descending order by importance.
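
In the `lime` library, an image explanation stores its per-label weights in `explanation.local_exp`, a dict mapping each explained label to a list of `(superpixel_id, weight)` pairs, and `explanation.top_labels` lists the labels that were explained. The sketch below shows only the sorting step on a made-up list in that shape. Note that sorting by signed weight (as this task asks) differs from sorting by absolute weight, which is how LIME itself orders the pairs.

```python
# Made-up (superpixel_id, weight) pairs for one label.
pairs = [(4, -0.21), (7, 0.35), (1, 0.02), (9, 0.18)]

by_weight = sorted(pairs, key=lambda fw: fw[1], reverse=True)
by_magnitude = sorted(pairs, key=lambda fw: abs(fw[1]), reverse=True)

print(by_weight)     # [(7, 0.35), (9, 0.18), (1, 0.02), (4, -0.21)]
print(by_magnitude)  # [(7, 0.35), (4, -0.21), (9, 0.18), (1, 0.02)]
```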

In [None]:
def extract_feature_importance(explanation) -> list:
    """
    Extract and return a sorted list of feature contributions from the LIME explanation object.

    Returns:
        List[tuple]: Each tuple contains (feature_index, importance) sorted by importance (descending).
    """
    # TODO: Process the output to return a sorted list of tuples.
    feature_importances = None  # Replace with your code.
    
    return feature_importances

## Task 4: Top 5 explanations from a pretrained ResNet model <font color='green'>(2 points)</font>

Use a pretrained ResNet model to predict the image classes and generate LIME explanations for the top 5 classes.
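
The downloaded `imagenet_class_index.json` maps string indices to `[wnid, name]` pairs (e.g. `"0": ["n01440764", "tench"]`), which the skeleton below converts to `{int_index: name}`. Choosing the top-k classes from a probability vector then reduces to an `argsort`. The sketch uses a three-entry mapping in the same format and made-up probabilities purely to show both steps; in the notebook the probabilities would come from ResNet, or the labels from the explanation's `top_labels`.

```python
import json
import numpy as np

# Three entries in the same format as imagenet_class_index.json.
raw = json.loads('{"0": ["n01440764", "tench"], "1": ["n01443537", "goldfish"], '
                 '"2": ["n01484850", "great_white_shark"]}')
labels = {int(k): v[1] for k, v in raw.items()}  # same conversion as the skeleton

# Made-up class probabilities for one image.
probs = np.array([0.20, 0.70, 0.10])
k = 2
top_k = np.argsort(probs)[::-1][:k]  # indices of the k largest probabilities, best first
top_names = [labels.get(int(i), str(i)) for i in top_k]

for name, idx in zip(top_names, top_k):
    print(name, float(probs[idx]))
```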

In [None]:
# Download ImageNet class label mappings
!wget "https://github.com/marcotcr/lime/blob/master/doc/notebooks/data/imagenet_class_index.json?raw=true" -O imagenet_class_index.json

In [None]:
from torchvision import models
import cv2

def get_resnet_lime_explanations(image: np.ndarray):
    """
    Use a pretrained ResNet model to predict the image classes and generate LIME explanations for the top 5 classes.
    
    This function should:
      1. Load a pretrained ResNet (e.g., ResNet18) from torchvision.
      2. Apply necessary preprocessing to the image.
      3. Obtain predictions and determine the top 5 classes.
      4. For each of the top 5 classes, generate a LIME explanation using get_lime_explanation().
      5. Return a dictionary mapping each top class label to its corresponding explanation visualization (image and mask).
    
    Boilerplate (e.g., label list) is provided below.
    
    Returns:
        dict: {class_label: (explanation_image, mask), ...} for top 5 classes.
    """
    # Load the ImageNet class-index mapping and keep only the human-readable class names.
    with open("imagenet_class_index.json") as f:
        imagenet_labels = json.load(f)
    imagenet_labels = {int(key): value[1] for key, value in imagenet_labels.items()}
    
    # TODO: Load a pretrained ResNet18 model and set it to evaluation mode.
    model = None  # Replace with your code.
    model.eval()
    
    # TODO: Generate LIME explanation for the image with top_labels=5.
    explanation = None  # Replace with your code.
    
    explanations = {}

    # TODO: Extract the top 5 labels from the explanation.
    top5_labels = None # Replace with your code.
    for label in top5_labels:
        class_name = imagenet_labels.get(label, str(label))
        # TODO: Extract visualization (image and mask).
        expl_img, mask = None, None  # Replace with your code.
        explanations[class_name] = (expl_img, mask)
    
    return explanations


## Reflection Questions (answer in brief)  <font color='green'>(1 point each)</font>

**Question 1:**  
What are the main advantages of using LIME for explaining image classification models?  
*Your Answer:*  

> *(Type your answer here)*

**Question 2:**  
How does LIME generate local explanations for a model's prediction, and why is this approach considered model-agnostic?  
*Your Answer:*  

> *(Type your answer here)*

**Question 3:**  
Discuss the trade-offs between model complexity and interpretability. How do these trade-offs impact both the performance of a model and its deployment in real-world, sensitive applications?  
*Your Answer:*  

> *(Type your answer here)*

---
### Autograder

Run this code cell at the end and do not change any code here.

In [None]:
!mkdir -p datasets
!wget https://raw.githubusercontent.com/sprince0031/CS4445-AI-Practice/refs/heads/main/datasets/dog.jpg -O datasets/dog.jpg

# ================================
# Pytest Tests for Challenge #7: LIME Explainability
# ================================

import pytest
import numpy as np
import torch.nn as nn

def run_tests_and_accumulate_score():
    total_code_points = 6  # Tasks 1-4: 2 + 1 + 1 + 2 = 6 points
    score = 0

    # Prepare dummy setup for testing:
    # Load the test image (dog.jpg) as a numpy array (H x W x 3).
    dummy_image = Image.open("datasets/dog.jpg")
    dummy_image = np.array(dummy_image)
    
    # Dummy model for LIME explanation tests (for Task 1, 2, 3)
    class DummyModel(nn.Module):
        def __init__(self):
            super(DummyModel, self).__init__()
            self.flatten = nn.Flatten()
            self.fc = nn.Linear(224*224*3, 2)  # Assume 2 classes for simplicity.
        def forward(self, x):
            x = self.flatten(x)
            return torch.softmax(self.fc(x), dim=1)
    
    dummy_model = DummyModel()
    
    # ------------------------------
    # Task 1: get_lime_explanation (2 points)
    # ------------------------------
    try:
        explanation = get_lime_explanation(dummy_model, dummy_image, top_labels=2, num_samples=100)
        assert explanation is not None, "get_lime_explanation() returned None."
        # Check for at least one expected method (as_list or get_image_and_mask).
        has_method = hasattr(explanation, "as_list") or hasattr(explanation, "get_image_and_mask")
        assert has_method, "Explanation object lacks expected methods."
        score += 2
        print("Task 1 (get_lime_explanation): Passed (2 points)")
    except AssertionError as e:
        print("Task 1 (get_lime_explanation): Failed -", e)
    
    # ------------------------------
    # Task 2: display_lime_explanation (1 point)
    # ------------------------------
    try:
        expl_img, mask = display_lime_explanation(explanation, dummy_image, positive_only=True, num_features=5)
        assert expl_img is not None, "display_lime_explanation() returned None for explanation image."
        assert mask is not None, "display_lime_explanation() returned None for mask."
        assert isinstance(expl_img, np.ndarray), "Explanation image should be a numpy array."
        assert isinstance(mask, np.ndarray), "Mask should be a numpy array."
        score += 1
        print("Task 2 (display_lime_explanation): Passed (1 point)")
    except AssertionError as e:
        print("Task 2 (display_lime_explanation): Failed -", e)
    
    # ------------------------------
    # Task 3: extract_feature_importance (1 point)
    # ------------------------------
    try:
        feature_list = extract_feature_importance(explanation)
        assert isinstance(feature_list, list), "extract_feature_importance() should return a list."
        if feature_list:
            first_item = feature_list[0]
            assert isinstance(first_item, tuple) and len(first_item) == 2, "Each item should be a tuple (feature, importance)."
        score += 1
        print("Task 3 (extract_feature_importance): Passed (1 point)")
    except AssertionError as e:
        print("Task 3 (extract_feature_importance): Failed -", e)
    
    # ------------------------------
    # Task 4: Pretrained ResNet LIME Explanations (2 points)
    # ------------------------------
    try:
        explanations = get_resnet_lime_explanations(dummy_image)
        assert isinstance(explanations, dict), "get_resnet_lime_explanations() should return a dictionary."
        # Expect 5 keys corresponding to top 5 classes.
        assert len(explanations) == 5, "Expected explanations for top 5 classes."
        print('top 5 label predictions:')
        for i, (label, (expl_img, mask)) in enumerate(explanations.items()):
            assert isinstance(expl_img, np.ndarray), "Explanation image should be a numpy array."
            assert isinstance(mask, np.ndarray), "Mask should be a numpy array."
            print(f'{i+1}. {label}')
            # plt.subplot(1, 5, i+1)
        expl_img, mask = next(iter(explanations.values()))
        visualiseExplanation(expl_img, mask)
        score += 2
        print("Task 4 (ResNet LIME Explanations): Passed (2 points)")
    except AssertionError as e:
        print("Task 4 (ResNet LIME Explanations): Failed -", e)
    
    print(f"Total Code Score: {score} / {total_code_points}")
    
    # Reflection questions are graded manually.
    print("Reflection Questions: 3 points (graded manually)")
    
# Run the autograder.
run_tests_and_accumulate_score()