# Module 12: Model Interpretation and Visualization

**Difficulty**: ⭐⭐⭐ (Advanced)

**Estimated Time**: 60-75 minutes

**Prerequisites**: 
- [Module 05: Feed-Forward Neural Networks with Keras](05_feedforward_neural_networks_keras.ipynb)
- [Module 10: Transfer Learning Concepts](10_transfer_learning_concepts.ipynb)

## Learning Objectives

By the end of this notebook, you will be able to:
1. Understand what different layers in a neural network learn
2. Visualize layer activations to see how networks process inputs
3. Generate feature visualizations to understand learned representations
4. Create saliency maps to identify important input regions
5. Interpret model predictions and assess confidence
6. Apply basic interpretability techniques for explainable AI

## 1. Why Model Interpretation Matters

Neural networks are often called "black boxes": they make predictions, but the reasons behind those predictions are not directly visible.

**Why we need interpretability**:
- **Trust**: Check whether the model's reasoning is valid
- **Debugging**: Find what the model learned incorrectly (e.g., spurious correlations)
- **Compliance**: Meet regulatory requirements in domains such as healthcare and finance
- **Improvement**: Identify areas for model enhancement
- **Fairness**: Detect biases and discriminatory patterns

**Levels of Interpretation**:
1. **Global**: What has the model learned overall?
2. **Local**: Why did the model make this specific prediction?
3. **Layer-wise**: What does each layer represent?

## 2. Setup and Imports

In [None]:
# Core libraries
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.patches import Rectangle

# Deep learning libraries
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.applications import VGG16
from tensorflow.keras.datasets import mnist, cifar10

# Image processing
from scipy.ndimage import zoom

# For reproducibility
np.random.seed(42)
tf.random.set_seed(42)

# Plotting configuration
%matplotlib inline
plt.style.use('seaborn-v0_8-darkgrid')

print(f"TensorFlow version: {tf.__version__}")
print(f"Keras version: {keras.__version__}")

## 3. Prepare Data and Train a Model

In [None]:
# Load MNIST for clear visualization
(X_train, y_train), (X_test, y_test) = mnist.load_data()

# Normalize and reshape
X_train = X_train.astype('float32') / 255.0
X_test = X_test.astype('float32') / 255.0
X_train = X_train.reshape(-1, 28, 28, 1)
X_test = X_test.reshape(-1, 28, 28, 1)

print(f"Training set: {X_train.shape}")
print(f"Test set: {X_test.shape}")

In [None]:
# Create a CNN for MNIST
def create_mnist_cnn():
    """Create CNN with multiple conv layers for visualization."""
    model = keras.Sequential([
        # First conv block
        keras.Input(shape=(28, 28, 1)),  # input spec; keras.Input works in both tf.keras 2.x and Keras 3
        layers.Conv2D(32, (3, 3), activation='relu', padding='same', name='conv1'),
        layers.MaxPooling2D((2, 2), name='pool1'),
        
        # Second conv block
        layers.Conv2D(64, (3, 3), activation='relu', padding='same', name='conv2'),
        layers.MaxPooling2D((2, 2), name='pool2'),
        
        # Third conv block
        layers.Conv2D(64, (3, 3), activation='relu', padding='same', name='conv3'),
        
        # Classification head
        layers.Flatten(name='flatten'),
        layers.Dense(128, activation='relu', name='dense1'),
        layers.Dropout(0.5, name='dropout'),
        layers.Dense(10, activation='softmax', name='output')
    ])
    
    model.compile(
        optimizer='adam',
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy']
    )
    
    return model

# Create and train model
model = create_mnist_cnn()
print("Model Architecture:")
model.summary()

In [None]:
# Train the model
print("\nTraining model...")
history = model.fit(
    X_train, y_train,
    validation_split=0.1,
    epochs=5,
    batch_size=128,
    verbose=1
)

# Evaluate
test_loss, test_acc = model.evaluate(X_test, y_test, verbose=0)
print(f"\nTest accuracy: {test_acc:.4f}")

## 4. Visualizing Layer Activations

**Layer activations** show how the network transforms input through successive layers.

**What we'll see**:
- **Early layers**: Detect simple features (edges, corners)
- **Middle layers**: Combine features into patterns
- **Late layers**: Abstract, high-level representations
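To make the "early layers detect edges" claim concrete, here is a minimal NumPy sketch of what a single first-layer filter computes: a 3×3 cross-correlation (the operation Keras `Conv2D` performs) followed by ReLU. The kernel is hand-made for illustration (an assumption, not a filter learned by the model in this notebook), and the helper `conv2d_valid` is hypothetical; it uses "valid" padding, whereas the model's layers use `padding='same'`.

```python
import numpy as np

def conv2d_valid(image, kernel):
    """Naive 'valid' 2-D cross-correlation (no kernel flip), as in a conv layer."""
    kh, kw = kernel.shape
    out_h = image.shape[0] - kh + 1
    out_w = image.shape[1] - kw + 1
    out = np.zeros((out_h, out_w))
    for i in range(out_h):
        for j in range(out_w):
            out[i, j] = np.sum(image[i:i + kh, j:j + kw] * kernel)
    return out

# Synthetic 8x8 image: dark left half, bright right half (a single vertical edge)
image = np.zeros((8, 8))
image[:, 4:] = 1.0

# Hand-made vertical-edge kernel: responds to dark-to-bright transitions (left to right)
kernel = np.array([[-1.0, 0.0, 1.0],
                   [-1.0, 0.0, 1.0],
                   [-1.0, 0.0, 1.0]])

# Convolution followed by ReLU, like a first conv layer with activation='relu'
activation = np.maximum(conv2d_valid(image, kernel), 0.0)
print(activation)
```

The resulting 6×6 activation map is non-zero only in the two columns whose 3×3 window straddles the edge; everywhere else the image is flat and the filter stays silent. A learned first-layer filter behaves the same way, just with weights found by training.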

In [None]:
def visualize_activations(model, image, layer_names=None):
    """
    Visualize activations from specified layers.
    
    Args:
        model: Trained Keras model
        image: Input image (single sample)
        layer_names: List of layer names to visualize (default: all conv layers)
    """
    # Get all convolutional layers if not specified
    if layer_names is None:
        layer_names = [layer.name for layer in model.layers 
                      if isinstance(layer, layers.Conv2D)]
    
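    # Build a model that returns the outputs of the requested layers
    layer_outputs = [model.get_layer(name).output for name in layer_names]
    activation_model = keras.Model(inputs=model.input, outputs=layer_outputs)
    
    # Get activations; with a single output, predict may return a bare array instead of a list
    activations = activation_model.predict(image[np.newaxis, ...], verbose=0)
    if not isinstance(activations, (list, tuple)):
        activations = [activations]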
    
    # Visualize
    for layer_name, layer_activation in zip(layer_names, activations):
        n_features = layer_activation.shape[-1]  # Number of filters
        size = layer_activation.shape[1]  # Feature map size
        
        # Display up to 16 filters per layer
        n_cols = 8
        n_rows = min(2, (n_features + n_cols - 1) // n_cols)
        n_display = min(16, n_features)
        
        fig, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols * 1.5, n_rows * 1.5),
                                 squeeze=False)
        fig.suptitle(f'Layer: {layer_name} ({n_features} filters, {size}x{size})', 
                     fontsize=12, fontweight='bold')
        
        axes = axes.flatten()  # squeeze=False always returns a 2-D array of axes
        
        for i in range(n_cols * n_rows):
            if i < n_display:
                # Display feature map
                axes[i].imshow(layer_activation[0, :, :, i], cmap='viridis')
                axes[i].set_title(f'Filter {i}', fontsize=9)
            axes[i].axis('off')
        
        plt.tight_layout()
        plt.show()

# Select a test image
test_image = X_test[0]
test_label = y_test[0]

# Display the input image
plt.figure(figsize=(4, 4))
plt.imshow(test_image.squeeze(), cmap='gray')
plt.title(f'Input Image (Label: {test_label})', fontsize=12, fontweight='bold')
plt.axis('off')
plt.show()

# Visualize activations
print("Visualizing layer activations...\n")
visualize_activations(model, test_image)

## 5. Filter Visualization

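**Filter Visualization** displays the learned weights of each convolutional filter, which hint at the patterns the filter responds to.

We'll visualize the actual learned weights of the first convolutional layer. This works best for the first layer, whose filters act directly on pixels; deeper filters operate on feature maps, so their raw weights are much harder to read.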
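Why are raw weights informative at all? For a first-layer filter (ignoring its bias), the response to an input patch $x$ is the dot product $\sum w \cdot x$, and by the Cauchy-Schwarz inequality, among patches of a fixed norm the response is largest when $x$ is proportional to $w$. The NumPy sketch below checks this numerically with a random 3×3 filter (an assumption standing in for a learned `conv1` filter).

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical first-layer 3x3 filter (bias ignored)
w = rng.normal(size=(3, 3))

# Among unit-norm patches, the one maximizing sum(w * x) is w / ||w||
best_patch = w / np.linalg.norm(w)
best_response = np.sum(w * best_patch)    # equals ||w||

# Compare against many random unit-norm patches (Frobenius norm over each 3x3 patch)
random_patches = rng.normal(size=(1000, 3, 3))
random_patches /= np.linalg.norm(random_patches, axis=(1, 2), keepdims=True)
random_responses = np.sum(random_patches * w, axis=(1, 2))

print(f"Response to the filter's own pattern: {best_response:.3f}")
print(f"Best of 1000 random patches:          {random_responses.max():.3f}")
```

No random patch beats the filter's own pattern, which is why plotting the weights shows "what the filter looks for". The argument relies on the filter acting linearly on pixels, so it does not carry over to deeper layers.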

In [None]:
def visualize_conv_filters(model, layer_name, max_filters=16):
    """
    Visualize convolutional filter weights.
    
    Args:
        model: Trained Keras model
        layer_name: Name of convolutional layer
        max_filters: Maximum number of filters to display
    """
    # Get layer weights
    layer = model.get_layer(layer_name)
    filters = layer.get_weights()[0]  # Shape: (height, width, input_channels, output_channels)
    
    n_filters = min(max_filters, filters.shape[-1])
    n_cols = 8
    n_rows = (n_filters + n_cols - 1) // n_cols
    
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols * 1.2, n_rows * 1.2),
                             squeeze=False)
    fig.suptitle(f'Filters from {layer_name}', fontsize=12, fontweight='bold')
    
    axes = axes.flatten()  # squeeze=False always returns a 2-D array of axes
    
    for i in range(n_cols * n_rows):
        if i < n_filters:
            # Get filter (average across input channels if multiple)
            filt = filters[:, :, :, i]
            if filt.shape[-1] > 1:
                filt = filt.mean(axis=-1)
            else:
                filt = filt.squeeze()
            
            # Normalize for visualization
            filt = (filt - filt.min()) / (filt.max() - filt.min() + 1e-8)
            
            axes[i].imshow(filt, cmap='gray')
            axes[i].set_title(f'{i}', fontsize=8)
        axes[i].axis('off')
    
    plt.tight_layout()
    plt.show()

# Visualize filters from first convolutional layer
print("First layer filters (detecting basic patterns like edges):")
visualize_conv_filters(model, 'conv1', max_filters=16)

## 6. Saliency Maps (Gradient-based Visualization)

**Saliency Maps** highlight which input pixels are most important for a prediction.

**Method**: Compute gradient of output with respect to input:
$$\text{Saliency}(x) = \left|\frac{\partial y_c}{\partial x}\right|$$

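where $y_c$ is the model's output for class $c$ and the absolute value is taken element-wise. Simonyan et al. (2013) differentiate the unnormalized (pre-softmax) class score; the code below differentiates the softmax probability, because this model's last layer applies softmax.

**Interpretation**: Pixels with large gradient magnitude are those where a small change would most affect the class score. The map measures local sensitivity, which is related to, but not the same as, how much a pixel contributed to the prediction.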
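For a model simple enough to differentiate by hand, the saliency formula can be checked directly. The sketch below assumes a toy linear-softmax classifier with random weights (not the CNN above) acting on a 16-pixel "image". With $p = \mathrm{softmax}(Wx + b)$, the gradient is $\partial p_c / \partial x = p_c \left(W_c - \sum_k p_k W_k\right)$, which is compared against central finite differences.

```python
import numpy as np

def softmax(z):
    z = z - z.max()                 # subtract max for numerical stability
    e = np.exp(z)
    return e / e.sum()

rng = np.random.default_rng(0)
n_pixels, n_classes = 16, 3         # toy 4x4 "image", flattened
W = rng.normal(size=(n_classes, n_pixels))
b = rng.normal(size=n_classes)
x = rng.uniform(size=n_pixels)

p = softmax(W @ x + b)
c = int(np.argmax(p))               # explain the predicted class

# Analytic gradient of the softmax probability p_c w.r.t. the input
grad = p[c] * (W[c] - p @ W)
saliency = np.abs(grad)             # saliency map = |dp_c / dx|

# Central finite differences, one pixel at a time, as an independent check
h = 1e-5
fd = np.array([
    (softmax(W @ (x + h * e) + b)[c] - softmax(W @ (x - h * e) + b)[c]) / (2 * h)
    for e in np.eye(n_pixels)
])

print(np.round(saliency.reshape(4, 4), 4))
```

The formula also explains a practical quirk: as $p_c$ approaches 1, $\sum_k p_k W_k$ approaches $W_c$ and the gradient shrinks toward zero, so saliency computed on a saturated softmax output can look faint.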

In [None]:
def generate_saliency_map(model, image, class_idx):
    """
    Generate saliency map showing important regions for prediction.
    
    Args:
        model: Trained Keras model
        image: Input image
        class_idx: Target class index
    
    Returns:
        Saliency map (same size as input)
    """
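    # Convert to a float32 tensor; the tape must be told to watch it explicitly
    image_tensor = tf.convert_to_tensor(image[np.newaxis, ...], dtype=tf.float32)
    
    # Compute gradients
    with tf.GradientTape() as tape:
        tape.watch(image_tensor)
        predictions = model(image_tensor, training=False)
        target_class = predictions[:, class_idx]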
    
    # Get gradient with respect to input
    gradients = tape.gradient(target_class, image_tensor)
    
    # Take absolute value and reduce to 2D (if multiple channels)
    saliency = tf.abs(gradients).numpy().squeeze()
    if len(saliency.shape) == 3:
        saliency = saliency.max(axis=-1)
    
    return saliency

def visualize_saliency(model, image, true_label):
    """
    Visualize saliency map for an image.
    """
    # Get model prediction
    prediction = model.predict(image[np.newaxis, ...], verbose=0)
    predicted_class = np.argmax(prediction)
    confidence = prediction[0][predicted_class]
    
    # Generate saliency map
    saliency = generate_saliency_map(model, image, predicted_class)
    
    # Visualize
    fig, axes = plt.subplots(1, 3, figsize=(12, 4))
    
    # Original image
    axes[0].imshow(image.squeeze(), cmap='gray')
    axes[0].set_title(f'Input Image\nTrue: {true_label}, Pred: {predicted_class}', 
                     fontweight='bold')
    axes[0].axis('off')
    
    # Saliency map
    axes[1].imshow(saliency, cmap='hot')
    axes[1].set_title('Saliency Map\n(Brighter = More Important)', fontweight='bold')
    axes[1].axis('off')
    
    # Overlay
    axes[2].imshow(image.squeeze(), cmap='gray', alpha=0.6)
    axes[2].imshow(saliency, cmap='hot', alpha=0.4)
    axes[2].set_title(f'Overlay\nConfidence: {confidence:.2%}', fontweight='bold')
    axes[2].axis('off')
    
    plt.tight_layout()
    plt.show()

# Generate saliency maps for several test images
print("Saliency Maps - What does the model focus on?\n")
for i in range(3):
    visualize_saliency(model, X_test[i], y_test[i])

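## 7. Class Activation Maps (Grad-CAM)

**Class Activation Maps** show which regions of the image are most relevant for a specific class. The original CAM (Zhou et al., 2016) requires a global-average-pooling layer feeding the classifier; the code below implements **Grad-CAM** (Selvaraju et al., 2017), a gradient-based generalization that also works with this model's `Flatten`/`Dense` head.

**How it works**:
1. Take the last convolutional layer's feature maps $A^k$ (here `conv3`, 7×7 with 64 channels)
2. Weight each feature map by its importance to the class, $\alpha_k^c$: the gradient of the class score $y_c$ with respect to $A^k$, averaged over all spatial positions
3. Sum the weighted feature maps and apply ReLU to keep only positive evidence: $L^c = \mathrm{ReLU}\left(\sum_k \alpha_k^c A^k\right)$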
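The combination step can be written out in NumPy on toy arrays (random activations and a hypothetical GAP → dense head, not the trained model; `grad_cam_from_arrays` is a hypothetical helper). The sketch also shows how Grad-CAM relates to the original CAM: if the class score is $y_c = \sum_k w_k \cdot \mathrm{mean}_{ij} A^k_{ij}$, then every gradient $\partial y_c / \partial A^k_{ij}$ equals $w_k / (HW)$, so the averaged gradients are proportional to the dense weights and the two heatmaps coincide after normalization.

```python
import numpy as np

def grad_cam_from_arrays(feature_maps, grads, eps=1e-8):
    """Grad-CAM combination step.

    feature_maps, grads: arrays of shape (H, W, K) holding the last conv layer's
    activations A^k and the gradients dy_c/dA^k for one image.
    """
    alpha = grads.mean(axis=(0, 1))                # one weight per channel (spatial mean)
    heatmap = np.maximum(feature_maps @ alpha, 0)  # weighted sum over channels, then ReLU
    return heatmap / (heatmap.max() + eps)         # scale to [0, 1]

rng = np.random.default_rng(0)
H, Wd, K = 7, 7, 4                 # 7x7 maps like conv3 (only 4 channels for the toy)
A = rng.uniform(size=(H, Wd, K))   # toy non-negative (post-ReLU) activations
w_c = rng.normal(size=K)           # toy GAP -> dense weights for class c

# For y_c = sum_k w_c[k] * mean_ij A[i, j, k], dy_c/dA[i, j, k] = w_c[k] / (H * Wd)
grads = np.broadcast_to(w_c / (H * Wd), (H, Wd, K))
gradcam = grad_cam_from_arrays(A, grads)

# CAM-style heatmap: weight feature maps directly by the dense weights
# (ReLU and max-normalization added so it is comparable with Grad-CAM)
cam = np.maximum(A @ w_c, 0)
cam = cam / (cam.max() + 1e-8)

print(np.allclose(gradcam, cam))
```

For a head that is not a single GAP → dense layer (like this notebook's `Flatten` → `Dense` → `Dense`), the gradients vary across positions and depend on the input, which is exactly the case Grad-CAM was designed to handle.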

In [None]:
def generate_class_activation_map(model, image, last_conv_layer_name, pred_index=None):
    """
    Generate class activation map (CAM) for an image.
    
    Args:
        model: Trained model
        image: Input image
        last_conv_layer_name: Name of last convolutional layer
        pred_index: Class index (if None, use predicted class)
    
    Returns:
        CAM heatmap
    """
    # Create model that outputs both predictions and last conv layer
    last_conv_layer = model.get_layer(last_conv_layer_name)
    
    grad_model = keras.Model(
        inputs=model.input,
        outputs=[last_conv_layer.output, model.output]
    )
    
    # Get gradients and predictions
    with tf.GradientTape() as tape:
        conv_outputs, predictions = grad_model(image[np.newaxis, ...])
        if pred_index is None:
            pred_index = tf.argmax(predictions[0])
        class_channel = predictions[:, pred_index]
    
    # Gradient of class score with respect to feature maps
    grads = tape.gradient(class_channel, conv_outputs)
    
    # Global average pooling of gradients
    pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))
    
    # Weight feature maps by gradients
    conv_outputs = conv_outputs[0]
    heatmap = conv_outputs @ pooled_grads[..., tf.newaxis]
    heatmap = tf.squeeze(heatmap)
    
    # Keep positive evidence (ReLU), then scale to [0, 1]; epsilon avoids division by zero
    heatmap = tf.maximum(heatmap, 0)
    heatmap = heatmap / (tf.math.reduce_max(heatmap) + 1e-8)
    
    return heatmap.numpy()

def visualize_cam(model, image, true_label, last_conv_layer_name='conv3'):
    """
    Visualize Class Activation Map.
    """
    # Get prediction
    prediction = model.predict(image[np.newaxis, ...], verbose=0)
    predicted_class = np.argmax(prediction)
    confidence = prediction[0][predicted_class]
    
    # Generate CAM
    heatmap = generate_class_activation_map(model, image, last_conv_layer_name)
    
    # Resize heatmap to match input image size
    img_size = image.shape[0]
    heatmap_resized = zoom(heatmap, img_size / heatmap.shape[0])
    
    # Visualize
    fig, axes = plt.subplots(1, 3, figsize=(12, 4))
    
    # Original
    axes[0].imshow(image.squeeze(), cmap='gray')
    axes[0].set_title(f'Input Image\nTrue: {true_label}, Pred: {predicted_class}',
                     fontweight='bold')
    axes[0].axis('off')
    
    # CAM heatmap
    axes[1].imshow(heatmap_resized, cmap='jet')
    axes[1].set_title('Class Activation Map\n(Red = Most Important)', fontweight='bold')
    axes[1].axis('off')
    
    # Overlay
    axes[2].imshow(image.squeeze(), cmap='gray', alpha=0.6)
    axes[2].imshow(heatmap_resized, cmap='jet', alpha=0.4)
    axes[2].set_title(f'Overlay\nConfidence: {confidence:.2%}', fontweight='bold')
    axes[2].axis('off')
    
    plt.tight_layout()
    plt.show()

# Visualize CAMs
print("Class Activation Maps - Where does the model look?\n")
for i in range(3):
    visualize_cam(model, X_test[i], y_test[i])

## 8. Prediction Confidence and Uncertainty

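Understanding model confidence is crucial for:
- Knowing when to trust predictions
- Identifying ambiguous cases
- Flagging possible out-of-distribution samples

Here "confidence" means the largest softmax probability. Treat it as a rough signal: softmax probabilities are often overconfident, and a network can assign high confidence to inputs unlike anything in its training data.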
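The code below uses the maximum softmax probability as "confidence". A common complementary measure (an addition here, not used elsewhere in this notebook) is the entropy of the whole predicted distribution, which also registers when probability is spread over several classes. A NumPy sketch on three hypothetical 10-class logit vectors:

```python
import numpy as np

def softmax(logits):
    z = logits - logits.max(axis=-1, keepdims=True)   # numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def confidence_and_entropy(probs, eps=1e-12):
    """Max-probability confidence and predictive entropy (in nats) per sample."""
    confidence = probs.max(axis=-1)
    entropy = -np.sum(probs * np.log(probs + eps), axis=-1)
    return confidence, entropy

# Hypothetical logits for three 10-class predictions
logits = np.array([
    [8.0, 0, 0, 0, 0, 0, 0, 0, 0, 0],    # one class dominates
    [3.0, 3.0, 0, 0, 0, 0, 0, 0, 0, 0],  # torn between two classes
    np.zeros(10),                         # no preference at all
])
conf, ent = confidence_and_entropy(softmax(logits))

for c, h in zip(conf, ent):
    print(f"confidence={c:.3f}  entropy={h:.3f} nats (max possible {np.log(10):.3f})")
```

Entropy ranges from 0 (all mass on one class) to $\ln 10$ (uniform over the 10 digits). Like max-probability, it is computed from the same softmax output, so it inherits the same calibration problems.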

In [None]:
def analyze_prediction_confidence(model, X_samples, y_true, n_samples=10):
    """
    Analyze and visualize prediction confidence.
    """
    # Get predictions
    predictions = model.predict(X_samples[:n_samples], verbose=0)
    
    # Plot each prediction
    fig, axes = plt.subplots(n_samples, 2, figsize=(10, n_samples * 1.5), squeeze=False)
    
    for i in range(n_samples):
        # Display image
        axes[i, 0].imshow(X_samples[i].squeeze(), cmap='gray')
        axes[i, 0].set_title(f'True Label: {y_true[i]}', fontsize=10)
        axes[i, 0].axis('off')
        
        # Display confidence distribution
        pred_probs = predictions[i]
        predicted_class = np.argmax(pred_probs)
        confidence = pred_probs[predicted_class]
        
        colors = ['green' if j == predicted_class else 'lightblue' for j in range(10)]
        axes[i, 1].bar(range(10), pred_probs, color=colors, alpha=0.7)
        axes[i, 1].set_ylim([0, 1])
        axes[i, 1].set_xlabel('Class', fontsize=9)
        axes[i, 1].set_ylabel('Probability', fontsize=9)
        
        # Title shows prediction and confidence
        correct = predicted_class == y_true[i]
        status = "✓" if correct else "✗"
        axes[i, 1].set_title(f'{status} Pred: {predicted_class} (Conf: {confidence:.2%})',
                            fontsize=10,
                            color='green' if correct else 'red')
        axes[i, 1].grid(True, alpha=0.3, axis='y')
    
    plt.suptitle('Prediction Confidence Analysis', fontsize=14, fontweight='bold', y=1.00)
    plt.tight_layout()
    plt.show()

# Analyze predictions
analyze_prediction_confidence(model, X_test, y_test, n_samples=5)

In [None]:
# Identify high and low confidence predictions
def find_confidence_extremes(model, X_test, y_test, n_show=3):
    """
    Display examples of confident-correct, confident-wrong, and low-confidence-correct
    predictions among the first 1000 test images.
    """
    # Get all predictions
    predictions = model.predict(X_test[:1000], verbose=0)
    
    # Calculate confidence (max probability)
    confidences = np.max(predictions, axis=1)
    predicted_classes = np.argmax(predictions, axis=1)
    
    # Check correctness
    correct = predicted_classes == y_test[:1000]
    
    # Group predictions: confident & correct, confident & wrong, unsure but correct
    high_conf_correct = np.where(correct & (confidences > 0.99))[0]
    high_conf_incorrect = np.where(~correct & (confidences > 0.90))[0]
    low_conf_correct = np.where(correct & (confidences < 0.80))[0]
    
    print(f"High confidence correct: {len(high_conf_correct)} samples")
    print(f"High confidence incorrect: {len(high_conf_incorrect)} samples (model is confidently wrong!)")
    print(f"Low confidence correct: {len(low_conf_correct)} samples (model unsure but right)")
    
    # Visualize examples
    fig, axes = plt.subplots(3, n_show, figsize=(n_show * 3, 9), squeeze=False)
    
    categories = [
        ('High Confidence Correct', high_conf_correct),
        ('High Confidence WRONG', high_conf_incorrect),
        ('Low Confidence Correct', low_conf_correct)
    ]
    
    for row, (title, indices) in enumerate(categories):
        for col in range(n_show):
            if col < len(indices):
                idx = indices[col]
                axes[row, col].imshow(X_test[idx].squeeze(), cmap='gray')
                axes[row, col].set_title(
                    f'True: {y_test[idx]}\n'
                    f'Pred: {predicted_classes[idx]}\n'
                    f'Conf: {confidences[idx]:.2%}',
                    fontsize=9
                )
            axes[row, col].axis('off')
        
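        # axis('off') also hides y-labels, so draw the row title as text beside the first column
        axes[row, 0].text(-0.15, 0.5, title, transform=axes[row, 0].transAxes,
                          rotation=90, ha='right', va='center',
                          fontsize=11, fontweight='bold')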
    
    plt.tight_layout()
    plt.show()

find_confidence_extremes(model, X_test, y_test)

## 9. Feature Space Visualization (t-SNE)

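**t-SNE** (t-Distributed Stochastic Neighbor Embedding) projects high-dimensional learned features to 2D while trying to keep points that are close in feature space close in the plot.

This shows how well the model separates different classes in its internal representation. t-SNE preserves local neighborhoods, not global geometry: cluster sizes and the distances between clusters in the plot are not reliable measures of how similar the classes are.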
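Because t-SNE can distort global geometry, it helps to pair the plot with a number computed in the original feature space. One simple option (an assumption of this sketch, not part of the original notebook) is leave-one-out 1-nearest-neighbour accuracy: for each sample, check whether its closest other sample has the same label. The hypothetical helper `knn1_loo_accuracy` is demonstrated on synthetic 2-D features; in practice you would pass the `features` array extracted from `dense1` together with the matching labels.

```python
import numpy as np

def knn1_loo_accuracy(features, labels):
    """Leave-one-out 1-nearest-neighbour accuracy in the original feature space."""
    # Pairwise squared Euclidean distances via ||a||^2 + ||b||^2 - 2 a.b
    sq = np.sum(features ** 2, axis=1)
    d2 = sq[:, None] + sq[None, :] - 2.0 * features @ features.T
    np.fill_diagonal(d2, np.inf)          # a sample may not be its own neighbour
    nearest = np.argmin(d2, axis=1)
    return np.mean(labels[nearest] == labels)

rng = np.random.default_rng(0)
labels = np.repeat(np.arange(3), 50)      # 3 toy classes, 50 samples each
centers = np.array([[0.0, 0.0], [10.0, 0.0], [0.0, 10.0]])

separated = centers[labels] + rng.normal(scale=0.5, size=(150, 2))
overlapping = centers[labels] * 0.05 + rng.normal(scale=0.5, size=(150, 2))

print(f"Well-separated features: {knn1_loo_accuracy(separated, labels):.2f}")
print(f"Overlapping features:    {knn1_loo_accuracy(overlapping, labels):.2f}")
```

The full distance matrix is n×n, which is fine for the 1000 samples used in this section but would need batching for the whole test set.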

In [None]:
from sklearn.manifold import TSNE

def visualize_feature_space(model, X_samples, y_samples, layer_name='dense1', n_samples=1000):
    """
    Visualize learned feature representations using t-SNE.
    
    Args:
        model: Trained model
        X_samples: Input samples
        y_samples: Labels
        layer_name: Layer to extract features from
        n_samples: Number of samples to visualize
    """
    # Extract features from specified layer
    feature_model = keras.Model(
        inputs=model.input,
        outputs=model.get_layer(layer_name).output
    )
    
    print(f"Extracting features from layer '{layer_name}'...")
    features = feature_model.predict(X_samples[:n_samples], verbose=0)
    
    # Apply t-SNE
    print("Applying t-SNE (this may take a minute)...")
    tsne = TSNE(n_components=2, random_state=42, perplexity=30)
    features_2d = tsne.fit_transform(features)
    
    # Visualize
    plt.figure(figsize=(12, 10))
    scatter = plt.scatter(
        features_2d[:, 0], 
        features_2d[:, 1],
        c=y_samples[:n_samples],
        cmap='tab10',
        alpha=0.6,
        s=20
    )
    plt.colorbar(scatter, label='Class')
    plt.title(f'Feature Space Visualization (t-SNE)\nLayer: {layer_name}',
              fontsize=14, fontweight='bold')
    plt.xlabel('t-SNE Component 1', fontsize=12)
    plt.ylabel('t-SNE Component 2', fontsize=12)
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()
    
    print("\nInterpretation:")
    print("- Well-separated clusters = Model learned distinct class representations")
    print("- Overlapping clusters = Model struggles to distinguish these classes")
    print("- Outliers = Potentially mislabeled or difficult samples")

# Visualize feature space
visualize_feature_space(model, X_test, y_test, layer_name='dense1', n_samples=1000)

## 10. Exercise 1: Activation Maximization

**Task**: Generate images that maximally activate specific neurons or classes.

**Concept**: Start with random noise and use gradient ascent to modify the image to maximize a neuron's activation.

**Requirements**:
1. Implement gradient ascent to maximize class scores
2. Generate images for at least 3 different classes
3. Apply regularization (total variation, L2 norm) for better visualizations
4. Display the generated "ideal" images for each class
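The core loop of this exercise is gradient ascent on "class score minus a regularization penalty". A minimal NumPy sketch on a toy linear score (an assumption: the real exercise differentiates the CNN with `tf.GradientTape`) shows the mechanics and why regularization matters. Without the L2 term, a linear score grows without bound as the image is scaled up; with it, ascent converges to a finite image whose closed-form optimum here is $x^* = w / (2\lambda)$.

```python
import numpy as np

rng = np.random.default_rng(0)
w = rng.normal(size=(8, 8))      # toy linear class score: score(x) = sum(w * x)
lam = 0.5                        # L2 regularization strength (lambda)
lr = 0.1                         # step size

x = rng.normal(scale=0.01, size=(8, 8))   # start from small random noise
for _ in range(200):
    # Objective: score(x) - lam * ||x||^2, so its gradient is w - 2 * lam * x
    grad = w - 2 * lam * x
    x = x + lr * grad                      # gradient *ascent* step

# For this toy objective the maximizer has a closed form
x_star = w / (2 * lam)
print(f"max |x - x*| after 200 steps: {np.abs(x - x_star).max():.2e}")
```

A total-variation term, which penalizes differences between neighbouring pixels and favours smoother images, would be subtracted from the objective in the same way; it has no simple closed-form optimum, so there you rely on the iterative loop alone.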

In [None]:
# YOUR CODE HERE
# Hint: Use tf.GradientTape to compute gradients w.r.t. input
# Update input image to maximize class score
# Apply regularization to get interpretable images

pass  # Replace with your implementation

## 11. Exercise 2: Adversarial Examples

**Task**: Create adversarial examples - images that look normal to humans but fool the network.

**Concept**: Add small perturbations to images to change predictions.

**Requirements**:
1. Implement Fast Gradient Sign Method (FGSM)
2. Generate adversarial examples with different perturbation strengths (epsilon)
3. Visualize original vs adversarial images and their predictions
4. Analyze how perturbation strength affects attack success
5. Discuss implications for model robustness
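To see the mechanics of FGSM without the CNN, here is a NumPy sketch on a toy logistic-regression model with random weights (an assumption), where the input gradient of the binary cross-entropy has the closed form $(p - y)\,w$. Each step moves every pixel by $\pm\epsilon$ in the direction that increases the loss, and clipping keeps pixels in $[0, 1]$, matching the normalization used for MNIST above.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def bce_loss(p, y, eps=1e-12):
    """Binary cross-entropy for one prediction p and a label y in {0, 1}."""
    return -(y * np.log(p + eps) + (1 - y) * np.log(1 - p + eps))

rng = np.random.default_rng(0)
w = rng.normal(scale=0.05, size=784)   # toy weights for a flattened 28x28 image
b = 0.0
x = rng.uniform(size=784)              # toy "image", pixels in [0, 1]
y = 1.0 if w @ x + b > 0 else 0.0      # attack the model's own prediction

p = sigmoid(w @ x + b)
grad_x = (p - y) * w                   # gradient of the loss w.r.t. the input

for epsilon in [0.01, 0.05, 0.1, 0.2]:
    # FGSM: one step of size epsilon along sign(grad), then clip to the valid pixel range
    x_adv = np.clip(x + epsilon * np.sign(grad_x), 0.0, 1.0)
    p_adv = sigmoid(w @ x_adv + b)
    print(f"eps={epsilon:.2f}  loss {bce_loss(p, y):.4f} -> {bce_loss(p_adv, y):.4f}")
```

Even though no pixel moves by more than $\epsilon$, the change in the logit accumulates over all 784 pixels, which is the intuition behind why small, spread-out perturbations can flip a prediction.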

In [None]:
# YOUR CODE HERE
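# Hint: FGSM formula: x_adv = clip(x + epsilon * sign(grad_x loss(model(x), y_true)), 0, 1)
#       i.e. the sign of the gradient of the *loss* with respect to the *input image*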
# Test with epsilon values: 0.01, 0.05, 0.1, 0.2

pass  # Replace with your implementation

## 12. Exercise 3: Layer-wise Relevance Propagation (LRP)

**Task**: Implement a simplified version of Layer-wise Relevance Propagation.

**Concept**: Backpropagate relevance scores from output to input to identify which features contributed to the prediction.

**Requirements**:
1. Implement relevance propagation for at least the final dense layers
2. Visualize relevance scores as heatmaps
3. Compare LRP with saliency maps
4. Analyze which method provides more interpretable results
5. Test on both correct and incorrect predictions
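The hint's propagation rule (often called LRP-0) can be checked on a single dense layer. Below is a minimal NumPy sketch on toy activations and weights (not the model's `dense1`/`output` weights; `lrp_dense` is a hypothetical helper). It adds a small $\epsilon$ stabilizer to the denominator, the "LRP-ε" variant, which goes beyond the hint's formula. A useful sanity check for any implementation is conservation: with no bias, the relevance entering a layer sums to the relevance leaving it.

```python
import numpy as np

def lrp_dense(a, W, R_out, eps=1e-9):
    """Propagate relevance through one dense layer z = a @ W (bias omitted).

    a: input activations, shape (n_in,)
    W: weights, shape (n_in, n_out)
    R_out: relevance of the layer's outputs, shape (n_out,)
    Computes R_i = sum_j (a_i * w_ij / z_j) * R_j with z_j = sum_k a_k * w_kj.
    """
    z = a @ W                                   # pre-activations z_j
    z = z + eps * np.where(z >= 0, 1.0, -1.0)   # stabilizer: avoids division by ~0
    s = R_out / z                               # relevance per unit of pre-activation
    return a * (W @ s)                          # R_i = a_i * sum_j w_ij * s_j

rng = np.random.default_rng(0)
a = np.maximum(rng.normal(size=6), 0)           # toy post-ReLU activations
W = rng.normal(size=(6, 3))
z = a @ W

# Start at the output: all relevance on the predicted unit, equal to its score
R_out = np.zeros(3)
R_out[np.argmax(z)] = z.max()

R_in = lrp_dense(a, W, R_out)
print("input relevances:", np.round(R_in, 4))
print("sum in:", R_in.sum(), " sum out:", R_out.sum())
```

Inputs whose activation is zero receive exactly zero relevance, and relevances can be negative when an input pushed against the predicted class. For the full exercise, the same function is applied layer by layer from `output` back toward `dense1`.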

In [None]:
# YOUR CODE HERE
# Hint: LRP propagates relevance backward through layers
# R_i = sum_j (a_i * w_ij / (sum_k a_k * w_kj)) * R_j   (LRP-0 rule, bias ignored)

pass  # Replace with your implementation

## 13. Summary

### Key Concepts Covered:

1. **Layer Activations**
   - Visualize how networks transform inputs through layers
   - Early layers detect simple features, late layers detect complex patterns
   - Helps understand what representations the model learns

2. **Filter Visualization**
   - Plots learned weights, which hint at the patterns filters respond to (directly readable mainly in the first layer)
   - First layer filters often resemble edge/blob detectors
   - Deeper layers learn more abstract patterns

3. **Saliency Maps**
   - Gradient-based method to find important input regions
   - Formula: $|\partial y_c / \partial x|$
   - Quick and simple but can be noisy

4. **Class Activation Maps (CAM / Grad-CAM)**
   - Show which image regions support a specific class
   - Smoother and less noisy than saliency maps, but limited to the coarse resolution of the last conv layer (7×7 here)
   - Useful for understanding spatial attention

5. **Prediction Confidence**
   - Softmax outputs give a rough, often overconfident, indication of uncertainty
   - High confidence doesn't always mean correct!
   - Important for knowing when to trust predictions

6. **Feature Space Visualization**
   - t-SNE shows how model organizes classes internally
   - Well-separated clusters indicate good learning
   - Reveals which classes model confuses

### Interpretation Techniques Summary:

| Technique | Purpose | Pros | Cons |
|-----------|---------|------|------|
| **Activations** | See layer outputs | Direct view of processing | Hard to interpret deep layers |
| **Saliency Maps** | Important input pixels | Fast, simple | Noisy, not always interpretable |
| **CAM / Grad-CAM** | Class-specific attention | Localized, intuitive | CAM needs a global-average-pooling head; Grad-CAM maps are coarse |
| **t-SNE** | Class separation | Global view | Slow, parameters matter |
| **LIME/SHAP** | Local explanations | Model-agnostic | Computationally expensive |

### Best Practices:

- Use multiple interpretation methods (they complement each other)
- Always validate interpretations with domain knowledge
- Be skeptical of overconfident predictions
- Test model on edge cases and adversarial examples
- Document model limitations and failure modes

### Ethical Considerations:

- **Fairness**: Check if model focuses on biased features
- **Transparency**: Explain model decisions to stakeholders
- **Safety**: Understand failure modes before deployment
- **Privacy**: Ensure visualizations don't reveal sensitive data

### What's Next?

- [Module 13: PyTorch Introduction](13_pytorch_introduction.ipynb)
- Advanced interpretability: LIME, SHAP, Integrated Gradients
- Attention mechanisms and Transformers

### Additional Resources:

1. "Visualizing and Understanding Convolutional Networks" (Zeiler & Fergus, 2014)
2. "Deep Inside Convolutional Networks" (Simonyan et al., 2013)
3. LIME paper: "Why Should I Trust You?" (Ribeiro et al., 2016)
4. Distill.pub: https://distill.pub/ (excellent interactive visualizations)
5. TensorFlow Lucid: https://github.com/tensorflow/lucid (no longer actively maintained; requires TensorFlow 1.x)