# Explaining WavLM Speech Models with GradientSHAP

This notebook demonstrates how to implement and use GradientSHAP to explain predictions from WavLM speech models, using audio loaded with librosa. We'll walk through the entire process step by step:

1. Setting up the environment
2. Understanding how GradientSHAP works
3. Loading and preparing WavLM models
4. Implementing a dataset class for audio files
5. Building the GradientSHAP explainer
6. Generating and visualizing attributions

By the end of this tutorial, you'll be able to explain which parts of an audio input contribute most to a model's prediction for tasks like speech emotion recognition.

## 1. Setup Environment and Import Libraries

Let's start by installing the necessary libraries for our implementation. We'll need:

- PyTorch: For deep learning and gradient computation
- Transformers: To access the WavLM model and feature extractors
- Librosa: For audio processing and manipulation
- Matplotlib/Seaborn: For visualization
- NumPy: For numerical operations
- tqdm: For progress bars

In [None]:
# Install required packages
!pip install torch torchaudio transformers librosa matplotlib seaborn numpy tqdm

# Import libraries
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import librosa
import librosa.display
from typing import Dict, List, Tuple, Union, Optional, Any, Callable
from tqdm import tqdm
import random
from transformers import AutoModel, AutoFeatureExtractor
from torch.utils.data import Dataset, DataLoader

# Set up seaborn style for better visualizations
sns.set_style("whitegrid")

# Check if CUDA is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Set random seed for reproducibility
RANDOM_SEED = 42
torch.manual_seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
random.seed(RANDOM_SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(RANDOM_SEED)

## 2. Understanding GradientSHAP

GradientSHAP is an efficient approximation method that combines ideas from Integrated Gradients and SHAP (SHapley Additive exPlanations). It helps us understand which parts of an input most influence a model's predictions.

### How GradientSHAP Works:

1. **Reference Distribution**: Draw baseline samples from a reference distribution (here, random noise), then form inputs that interpolate between each baseline and the actual input.

2. **Gradient Computation**: Calculate gradients of the model output with respect to the inputs for each sample.

3. **Weighted Integration**: Multiply each gradient by the difference between the input and its baseline, then average over all samples to obtain the attribution scores.
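
The three steps above can be sketched on a toy problem. This is a minimal illustration rather than the explainer built later in this notebook: it assumes the gradient is available as a plain function (`grad_fn`, a hypothetical name), uses all-zero baselines, and draws one random interpolation coefficient per baseline.

```python
import numpy as np

def gradient_shap_estimate(grad_fn, x, baselines, rng):
    """Average (x - b_k) * grad f(b_k + alpha_k * (x - b_k)) over the baselines b_k."""
    alphas = rng.uniform(0.0, 1.0, size=(baselines.shape[0], 1))  # one alpha per baseline
    points = baselines + alphas * (x - baselines)                  # interpolated inputs
    grads = np.stack([grad_fn(p) for p in points])                 # gradient at each point
    return ((x - baselines) * grads).mean(axis=0)

# Toy linear model f(z) = w . z, whose gradient is w everywhere.
w = np.array([1.0, -2.0, 0.5])
x = np.array([1.0, 1.0, 1.0])
baselines = np.zeros((4, 3))  # all-zero baselines, purely for illustration
attr = gradient_shap_estimate(lambda z: w, x, baselines, np.random.default_rng(0))

# For a linear model the attributions sum to f(x) - mean f(baseline).
print(attr, attr.sum(), w @ x)
```

For a linear model the estimate is exact regardless of the random coefficients, which makes it a convenient sanity check before moving to a deep network where the gradients come from autograd.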

For speech models like WavLM, GradientSHAP can tell us:
- Which frames (time segments) of audio most influence predictions
- Which neurons in the model are most responsive to important audio features
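
To relate waveform samples to the model's frames, it helps to know how many frames WavLM produces for a given input length. The sketch below assumes the default WavLM convolutional front end (kernel sizes and strides as hard-coded here); the values for a loaded model can be checked via `config.conv_kernel` and `config.conv_stride`. The helper names are hypothetical.

```python
# Default WavLM convolutional front end (assumed; check model.config.conv_kernel / conv_stride).
CONV_KERNELS = (10, 3, 3, 3, 3, 2, 2)
CONV_STRIDES = (5, 2, 2, 2, 2, 2, 2)

def num_frames(num_samples: int) -> int:
    """Number of encoder frames produced for a waveform of num_samples samples."""
    length = num_samples
    for kernel, stride in zip(CONV_KERNELS, CONV_STRIDES):
        length = (length - kernel) // stride + 1  # unpadded 1-D convolution
    return length

def frame_to_seconds(frame_idx: int, sample_rate: int = 16000) -> float:
    """Approximate start time of a frame, using the total stride of 320 samples."""
    total_stride = 1
    for stride in CONV_STRIDES:
        total_stride *= stride
    return frame_idx * total_stride / sample_rate

print(num_frames(16000))      # 49 frames for one second of 16 kHz audio
print(frame_to_seconds(50))   # 1.0
```

Under these assumptions each frame advances by 320 samples (20 ms at 16 kHz), so a 3-second clip yields 149 frames.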

### Key Advantages:

- More computationally efficient than kernel SHAP methods
- Provides local explanations for specific predictions
- Works well with deep neural networks that have gradient information
- Can handle high-dimensional inputs like raw speech waveforms

Let's see how to implement this for WavLM speech models.

## 3. Load WavLM Model and Feature Extractor

Now let's load a pre-trained WavLM model and its feature extractor from the Hugging Face Hub. WavLM is a self-supervised speech model that can be fine-tuned for downstream tasks such as speech recognition, speaker verification, or emotion detection.

For our example, we'll use WavLM-Base-Plus, which has 12 transformer layers and is a good balance between performance and resource usage. We'll also create a simple classifier head that can be used for emotion recognition (with 8 emotion classes).

In [None]:
# Load the WavLM model and feature extractor
model_name = 'microsoft/wavlm-base-plus'
print(f"Loading {model_name}...")

# Load feature extractor
feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)

# Load model
wavlm_model = AutoModel.from_pretrained(model_name)

# Move model to device (GPU if available)
wavlm_model.to(device)

# Put model in evaluation mode
wavlm_model.eval()

print("Model loaded successfully!")
print(f"Model architecture: {type(wavlm_model).__name__}")
print(f"Hidden size: {wavlm_model.config.hidden_size}")
print(f"Number of layers: {len(wavlm_model.encoder.layers)}")

# Create a simple classifier head for emotion recognition (8 classes)
# In a real scenario, you would train this classifier on labeled data
hidden_size = wavlm_model.config.hidden_size
num_classes = 8  # 8 emotion classes (e.g., neutral, happy, sad, angry, etc.)

classifier = torch.nn.Sequential(
    torch.nn.Linear(hidden_size, 256),
    torch.nn.ReLU(),
    torch.nn.Dropout(0.1),
    torch.nn.Linear(256, num_classes)
).to(device)

print("Classifier architecture:")
print(classifier)

## 4. Create LibrosaAudioDataset Class

Next, we'll implement a PyTorch dataset class for loading audio files using librosa. This class will:

1. Load audio files from a directory
2. Resample them to the required sample rate for WavLM (16kHz)
3. Process them with the WavLM feature extractor
4. Return tensors ready for model input

This dataset class will help us process multiple audio files efficiently, whether for training a classifier or for batch explanation.
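
Because audio clips usually differ in length, batching them with a `DataLoader` typically needs a custom collate function. The sketch below is one possible approach, not part of the dataset class itself: `pad_collate` is a hypothetical helper, and it assumes the items carry the keys `input_values`, `label`, and `filename` with integer labels.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

def pad_collate(batch):
    """Pad variable-length waveforms to the longest one and build an attention mask."""
    waveforms = [item["input_values"] for item in batch]
    lengths = torch.tensor([w.shape[0] for w in waveforms])
    padded = pad_sequence(waveforms, batch_first=True)  # [B, T_max], zero-padded
    mask = (torch.arange(padded.shape[1])[None, :] < lengths[:, None]).long()
    return {
        "input_values": padded,
        "attention_mask": mask,
        "label": torch.tensor([item["label"] for item in batch]),
        "filename": [item["filename"] for item in batch],
    }

# Stand-in items with the same keys as the dataset's samples.
fake_batch = [
    {"input_values": torch.ones(3), "label": 0, "filename": "a.wav"},
    {"input_values": torch.ones(5), "label": -1, "filename": "b.wav"},
]
out = pad_collate(fake_batch)
print(out["input_values"].shape)  # torch.Size([2, 5])
```

It could then be passed as `DataLoader(dataset, batch_size=4, collate_fn=pad_collate)`.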

In [None]:
class LibrosaAudioDataset(Dataset):
    """
    Dataset class for loading audio files using librosa.
    
    This dataset can load audio files from a directory and prepare them for
    use with WavLM and similar speech models.
    """
    
    def __init__(self, data_path, feature_extractor, labels=None, sample_rate=16000, max_samples=None, transform=None):
        """
        Initialize the dataset.
        
        Args:
            data_path (str): Path to audio files directory
            feature_extractor: WavLM feature extractor for audio preprocessing
            labels (dict, optional): Dictionary mapping filenames to labels
            sample_rate (int): Target sample rate for audio
            max_samples (int, optional): Maximum number of samples to use
            transform (callable, optional): Optional transform to apply to audio
        """
        self.data_path = data_path
        self.feature_extractor = feature_extractor
        self.sample_rate = sample_rate
        self.transform = transform
        self.labels = labels or {}
        
        # List all audio files
        self.audio_files = []
        
        # Walk through the directory
        for root, _, files in os.walk(data_path):
            for file in files:
                if file.endswith(('.wav', '.mp3', '.flac')):
                    self.audio_files.append(os.path.join(root, file))
        
        # Limit dataset size if specified
        if max_samples and max_samples < len(self.audio_files):
            self.audio_files = random.sample(self.audio_files, max_samples)
            
        print(f"Loaded {len(self.audio_files)} audio files from {data_path}")
    
    def __len__(self):
        return len(self.audio_files)
    
    def __getitem__(self, idx):
        """
        Get a sample from the dataset.
        
        Args:
            idx (int): Index
            
        Returns:
            dict: A dictionary with input_values, label (if available), and filename
        """
        audio_path = self.audio_files[idx]
        filename = os.path.basename(audio_path)
        
        # Load and resample audio
        waveform, sr = librosa.load(audio_path, sr=self.sample_rate)
        
        # Apply transform if specified
        if self.transform:
            waveform = self.transform(waveform)
        
        # Convert to float32 tensor
        waveform = torch.tensor(waveform, dtype=torch.float32)
        
        # Process with feature extractor
        inputs = self.feature_extractor(waveform, sampling_rate=self.sample_rate, return_tensors="pt")
        
        # Get label if available
        label = self.labels.get(filename, -1)
        
        return {
            "input_values": inputs.input_values.squeeze(0),
            "label": label,
            "filename": filename
        }

# Test the dataset class with a small example
# Uncomment this if you have audio files available
"""
# Create a small test dataset (replace with your audio directory)
test_audio_path = "./sample_audio"
if os.path.exists(test_audio_path):
    # Create dataset
    test_dataset = LibrosaAudioDataset(
        data_path=test_audio_path,
        feature_extractor=feature_extractor,
        max_samples=2
    )
    
    # Get a sample
    if len(test_dataset) > 0:
        sample = test_dataset[0]
        print(f"Sample filename: {sample['filename']}")
        print(f"Input shape: {sample['input_values'].shape}")
else:
    print("No test audio directory available. Skipping dataset test.")
"""

## 5. Implement GradientSHAP Explainer

Now we'll implement the core of our explainer: the GradientSHAP class. This class will:

1. Register hooks to capture activations from different layers
2. Generate reference samples by interpolating between random noise and the input
3. Compute gradients for each sample
4. Calculate SHAP values and aggregate them
5. Provide visualization utilities

This implementation supports frame-level explanations and a simplified, approximate form of neuron-level explanations.
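
The hook mechanism from step 1 can be tried in isolation on a tiny stand-in network. This is a minimal sketch: `net`, `store`, and `save_output` are hypothetical names, and the two-layer network merely takes the place of the WavLM encoder layers.

```python
import torch

# A tiny stand-in network; in this notebook the hooked modules are WavLM encoder layers.
net = torch.nn.Sequential(torch.nn.Linear(4, 3), torch.nn.ReLU(), torch.nn.Linear(3, 2))
store = {}

def save_output(name):
    def hook(module, inputs, output):
        store[name] = output.detach()  # keep a copy that is cut off from autograd
    return hook

handle = net[0].register_forward_hook(save_output("first_linear"))
_ = net(torch.randn(5, 4))
captured_shape = tuple(store["first_linear"].shape)
print(captured_shape)    # (5, 3)

handle.remove()          # after removal the hook no longer fires
store.clear()
_ = net(torch.randn(5, 4))
print(len(store))        # 0
```

Note that a detached activation cannot receive gradients, so hooks of this kind are suited to inspecting activations rather than to computing gradients with respect to them.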

In [None]:
class GradientSHAP:
    """
    GradientSHAP implementation for WavLM speech models.
    
    This class implements GradientSHAP for speech models by:
    1. Creating reference samples by mixing the target input with random noise
    2. Computing gradients for each reference sample
    3. Weighting the gradients according to SHAP principles
    """
    
    def __init__(
        self,
        model,
        feature_extractor,
        device=None,
        num_samples=50,
        feature_layer=None,
        classifier=None,
        batch_size=8
    ):
        """
        Initialize the GradientSHAP explainer.
        
        Args:
            model: Pre-loaded WavLM model
            feature_extractor: WavLM feature extractor
            device: Device to use ('cuda' or 'cpu')
            num_samples: Number of reference samples for SHAP
            feature_layer: Layer to extract features from (None for last layer)
            classifier: Classification head
            batch_size: Batch size for processing reference samples
        """
        # Set device
        self.device = device if device else ('cuda' if torch.cuda.is_available() else 'cpu')
        print(f"GradientSHAP using device: {self.device}")
        
        # Store model and feature extractor
        self.model = model
        self.feature_extractor = feature_extractor
        
        # Ensure model is in evaluation mode
        self.model.eval()
        
        # Parameters
        self.num_samples = num_samples
        self.batch_size = batch_size
        self.feature_layer = feature_layer
        
        # Set up classifier if not provided
        if classifier is None:
            hidden_size = self.model.config.hidden_size
            self.classifier = torch.nn.Linear(hidden_size, 8)  # Default: 8 emotions
        else:
            self.classifier = classifier
            
        self.classifier.to(self.device)
        # Evaluation mode disables dropout so attributions are deterministic
        self.classifier.eval()
        
        # For storing activations
        self.activation_store = {}
        self.hooks = []
        
        # Register hooks for specified feature layer
        self._register_hooks()
    
    def _register_hooks(self):
        """Register forward hooks to capture activations from the model."""
        try:
            def get_activation(name):
                def hook(module, input, output):
                    # Store activations - supports both tuple and tensor outputs
                    if isinstance(output, tuple):
                        self.activation_store[name] = output[0].detach()
                    else:
                        self.activation_store[name] = output.detach()
                return hook
            
            # Register hooks for specific or all encoder layers
            if self.feature_layer is not None:
                # Register hook for just the specified layer
                if 0 <= self.feature_layer < len(self.model.encoder.layers):
                    layer = self.model.encoder.layers[self.feature_layer]
                    h = layer.register_forward_hook(get_activation(self.feature_layer))
                    self.hooks.append(h)
                    print(f"Registered hook for layer {self.feature_layer}")
            else:
                # Register hooks for all encoder layers
                for i, layer in enumerate(self.model.encoder.layers):
                    h = layer.register_forward_hook(get_activation(i))
                    self.hooks.append(h)
                print(f"Registered hooks for all {len(self.hooks)} encoder layers")
                
        except Exception as e:
            print(f"Error registering hooks: {e}")
            raise
    
    def remove_hooks(self):
        """Remove all registered hooks"""
        for h in self.hooks:
            h.remove()
        self.hooks = []
        print("Removed all hooks")
        
    def __del__(self):
        """Clean up by removing hooks when object is deleted"""
        if hasattr(self, 'hooks') and self.hooks:
            self.remove_hooks()
    
    def generate_reference_samples(self, input_values, num_samples=None):
        """
        Generate reference samples by interpolating between input and random noise.
        
        Args:
            input_values: Input audio values [T]
            num_samples: Number of samples to generate
            
        Returns:
            Reference samples [N, T]
            Interpolation coefficients [N, 1]
        """
        if num_samples is None:
            num_samples = self.num_samples
            
        # Create random reference noise with same shape as input
        # Use a normal distribution with same mean and std as input
        input_mean = input_values.mean()
        input_std = input_values.std()
        
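        # Create reference distribution (random noise). torch.normal does not
        # accept a size argument together with tensor mean/std, so scale
        # standard-normal noise instead.
        reference = torch.randn(
            num_samples, input_values.shape[0], device=self.device
        ) * input_std + input_mean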
        
        # Create alphas for interpolation (0 = reference, 1 = input)
        alphas = torch.linspace(0, 1, num_samples).view(-1, 1).to(self.device)
        
        # Generate samples by interpolating between reference and input
        samples = reference * (1 - alphas) + input_values * alphas
        
        return samples, alphas
    
    def _compute_sample_gradients(self, samples, target_class=None):
        """
        Compute gradients for sample inputs with respect to the target class.
        
        Args:
            samples: Batch of input samples [B, T]
            target_class: Target class for explanation
            
        Returns:
            Gradients for each sample
        """
        batch_size = samples.shape[0]
        gradients = []
        
        for i in range(batch_size):
            # Detach into a fresh leaf tensor so autograd populates sample.grad
            sample = samples[i:i+1].detach().clone().requires_grad_(True)  # [1, T]
            
            # Feed the waveform to the model directly: the feature extractor
            # converts to NumPy, which would cut the autograd graph. Its
            # optional zero-mean/unit-variance normalization is redone in torch.
            model_input = sample
            if getattr(self.feature_extractor, 'do_normalize', False):
                model_input = (sample - sample.mean()) / torch.sqrt(sample.var(unbiased=False) + 1e-7)
            
            # Forward pass
            self.model.zero_grad()
            if hasattr(self.classifier, 'zero_grad'):
                self.classifier.zero_grad()
            
            outputs = self.model(input_values=model_input)
            
            # Get predictions from classifier
            if hasattr(outputs, 'last_hidden_state'):
                pooled = outputs.last_hidden_state.mean(dim=1)
                logits = self.classifier(pooled)
                
                # Get target class if not specified
                if target_class is None:
                    target_class = logits.argmax(dim=1).item()
                
                # Get target class logit
                target_logit = logits[0, target_class]
                
                # Backward pass to get gradients
                target_logit.backward()
                
                # Get gradients with respect to the input waveform [1, T]
                grad = sample.grad.detach()
                gradients.append(grad)
        
        # Combine gradients
        return torch.cat(gradients, dim=0)
    
    def _compute_shap_values(self, gradients, alphas, input_values, mode='frame', aggregate_frames='mean'):
        """
        Compute final SHAP values from gradients.
        
        Args:
            gradients: Gradients from all samples [N, T]
            alphas: Interpolation coefficients [N, 1]
            input_values: Original input values [T]
            mode: 'frame' or 'neuron' level explanations
            aggregate_frames: Method to aggregate frame attributions
            
        Returns:
            SHAP values for the input
        """
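        # Gradient x input approximation. Full GradientSHAP multiplies each
        # gradient by (input - baseline); the baselines are not passed to this
        # method, so the input itself is used here as a simplification.
        shap_values = gradients * input_values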
        
        # For neuron-level explanations, we need to access the activation store
        if mode == 'neuron' and self.activation_store:
            # Get activation from the latest forward pass
            # We need to do one more forward pass to capture the activations
            with torch.no_grad():
                inputs = self.feature_extractor(input_values, sampling_rate=16000, return_tensors="pt")
                inputs = {k: v.to(self.device) for k, v in inputs.items()}
                _ = self.model(**inputs)
            
            # Extract neuron-level explanations
            neuron_attributions = {}
            
            for layer_idx, activations in self.activation_store.items():
                # If we are only interested in specific layer
                if self.feature_layer is not None and layer_idx != self.feature_layer:
                    continue
                    
                # Get shape information
                batch_size, seq_len, hidden_dim = activations.shape
                
                # For each neuron, compute importance
                layer_attributions = torch.zeros((seq_len, hidden_dim)).to(self.device)
                
                # Placeholder estimate: input gradients are indexed by waveform
                # sample, not by encoder frame, so pool them into seq_len chunks
                # and give every neuron in a frame the same value. True
                # neuron-level attribution would need gradients w.r.t. each
                # neuron activation.
                per_sample = gradients.sum(dim=0)  # [T]
                frame_attr = torch.stack(
                    [chunk.sum() for chunk in torch.tensor_split(per_sample, seq_len)]
                )  # [seq_len]
                layer_attributions = frame_attr.unsqueeze(1).repeat(1, hidden_dim)
                
                neuron_attributions[f"layer_{layer_idx}"] = layer_attributions.cpu().numpy()
            
            return neuron_attributions
        
        # For frame-level explanations, aggregate across reference samples (axis 0)
        else:
            # Convert to numpy for easier handling
            frame_attributions = shap_values.cpu().numpy()
            
            # Aggregate if requested
            if aggregate_frames == 'mean':
                return frame_attributions.mean(axis=0)
            elif aggregate_frames == 'sum':
                return frame_attributions.sum(axis=0)
            elif aggregate_frames == 'max':
                return frame_attributions.max(axis=0)
            else:
                return frame_attributions
    
    def explain(
        self, 
        input_values, 
        target_class=None,
        mode='frame',
        aggregate_frames='mean'
    ):
        """
        Generate GradientSHAP explanations for a given input.
        
        Args:
            input_values: Input audio values
            target_class: Target class for explanation
            mode: 'frame' or 'neuron' level explanations
            aggregate_frames: Method to aggregate frame attributions
            
        Returns:
            Dict: Explanation results including attributions
        """
        print(f"Generating GradientSHAP explanations in {mode} mode")
        
        # Move input to device
        if not isinstance(input_values, torch.Tensor):
            input_values = torch.tensor(input_values, dtype=torch.float32)
        
        input_values = input_values.to(self.device)
        
        # Generate reference samples
        samples, alphas = self.generate_reference_samples(input_values)
        num_samples = samples.shape[0]
        
        # Process in batches
        all_gradients = []
        
        for batch_idx in tqdm(range(0, num_samples, self.batch_size), desc="Computing gradients"):
            batch_end = min(batch_idx + self.batch_size, num_samples)
            batch_samples = samples[batch_idx:batch_end]
            batch_alphas = alphas[batch_idx:batch_end]
            
            # Get gradients for this batch
            batch_gradients = self._compute_sample_gradients(batch_samples, target_class)
            all_gradients.append(batch_gradients)
        
        # Combine gradients from all batches
        all_gradients = torch.cat(all_gradients, dim=0)
        
        # Compute integrated gradients by weighting according to SHAP formulation
        shap_values = self._compute_shap_values(all_gradients, alphas, input_values, mode, aggregate_frames)
        
        # Run forward pass to get predictions
        with torch.no_grad():
            inputs = self.feature_extractor(input_values, sampling_rate=16000, return_tensors="pt")
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            outputs = self.model(**inputs)
            
            # Get class predictions using classifier
            if hasattr(outputs, 'last_hidden_state'):
                pooled = outputs.last_hidden_state.mean(dim=1)
                logits = self.classifier(pooled)
                probs = torch.softmax(logits, dim=1)
                
                # Get predicted class if target not specified
                if target_class is None:
                    target_class = torch.argmax(probs, dim=1).item()
                
                prediction = {
                    'class': target_class,
                    'probability': probs[0, target_class].item()
                }
            else:
                prediction = {'class': -1, 'probability': 0.0}
            
        return {
            'attributions': shap_values,
            'prediction': prediction,
            'mode': mode,
            'aggregate_method': aggregate_frames,
        }

# Create an instance of GradientSHAP
explainer = GradientSHAP(
    model=wavlm_model,
    feature_extractor=feature_extractor,
    device=device,
    num_samples=20,  # Using a small number for demonstration
    classifier=classifier,
    batch_size=4
)

print("GradientSHAP explainer initialized successfully!")

## 6. Working with Audio Files

Now that we have our GradientSHAP explainer set up, let's load a sample audio file and prepare it for explanation. We'll use librosa to load and preprocess the audio, then feed it into our explainer.

For this demonstration, we'll create a synthetic audio sample if you don't have one available.

In [None]:
# Function to create a synthetic audio sample for demonstration
def create_synthetic_audio(duration=3, sr=16000):
    """Create a synthetic audio sample with some speech-like characteristics."""
    # Create time array
    t = np.linspace(0, duration, int(sr * duration), endpoint=False)
    
    # Generate a mixture of frequencies that roughly mimics speech formants
    freqs = [120, 240, 500, 1000, 2000]
    amplitudes = [1.0, 0.5, 0.3, 0.2, 0.1]
    
    # Create base signal
    signal = np.zeros_like(t)
    for freq, amp in zip(freqs, amplitudes):
        signal += amp * np.sin(2 * np.pi * freq * t)
    
    # Add some amplitude modulation to simulate speech syllables
    syllable_rate = 4  # 4 syllables per second
    modulation = 0.5 + 0.5 * np.sin(2 * np.pi * syllable_rate * t)
    signal *= modulation
    
    # Add some noise
    noise = np.random.normal(0, 0.05, size=len(t))
    signal += noise
    
    # Normalize
    signal = signal / np.max(np.abs(signal))
    
    return signal, sr

# Create a synthetic audio sample
audio_signal, sample_rate = create_synthetic_audio(duration=3, sr=16000)

# Visualize the audio
plt.figure(figsize=(12, 4))
plt.plot(np.arange(len(audio_signal)) / sample_rate, audio_signal)
plt.title("Synthetic Audio Sample")
plt.xlabel("Time (s)")
plt.ylabel("Amplitude")
plt.grid(alpha=0.3)
plt.show()

# Display spectrogram
plt.figure(figsize=(12, 4))
D = librosa.amplitude_to_db(np.abs(librosa.stft(audio_signal)), ref=np.max)
librosa.display.specshow(D, sr=sample_rate, x_axis='time', y_axis='hz')
plt.colorbar(format='%+2.0f dB')
plt.title("Spectrogram of Synthetic Audio")
plt.show()

# Convert to tensor for the model
audio_tensor = torch.tensor(audio_signal, dtype=torch.float32)

print(f"Audio shape: {audio_signal.shape}")
print(f"Sample rate: {sample_rate} Hz")
print(f"Duration: {len(audio_signal) / sample_rate:.2f} seconds")

## 7. Generating Explanations

Now let's use our GradientSHAP implementation to generate explanations for the audio sample. We'll demonstrate:

1. Frame-level explanations (computed here as one attribution per waveform sample) that show which parts of the audio most influence the model's prediction
2. How to interpret the attribution scores

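We'll target emotion class 2 (nominally "happy") for this example. Because the classifier head defined earlier is randomly initialized rather than trained, the class label and the resulting attributions only illustrate the mechanics. In a real application, you would explain the predicted class of a trained model or compare explanations across multiple classes.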
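
Since the attributions are computed with respect to the raw waveform, they arrive as one score per audio sample. One way to read them at the frame level is to pool them into fixed windows. The sketch below is an assumption-laden illustration: `pool_to_frames` is a hypothetical helper, and the 320-sample hop (20 ms at 16 kHz) is chosen to match the default WavLM frame stride, which only approximates each frame's actual receptive field.

```python
import numpy as np

def pool_to_frames(sample_attr, hop=320):
    """Sum absolute per-sample attributions inside consecutive windows of `hop` samples."""
    sample_attr = np.abs(np.asarray(sample_attr, dtype=np.float64))
    n_frames = int(np.ceil(len(sample_attr) / hop))
    padded = np.zeros(n_frames * hop)
    padded[: len(sample_attr)] = sample_attr  # zero-pad the final partial window
    return padded.reshape(n_frames, hop).sum(axis=1)

# Toy attribution vector: 1 second at 16 kHz with one influential region.
toy_attr = np.zeros(16000)
toy_attr[3200:3520] = -0.5  # the sign is dropped by the abs()
frame_scores = pool_to_frames(toy_attr)
print(frame_scores.shape, int(frame_scores.argmax()))  # (50,) 10
```

Summing absolute values discards the direction of each contribution; pooling the signed values instead is a reasonable alternative when the sign matters.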

In [None]:
# Generate frame-level explanations using GradientSHAP
target_class = 2  # Target "happy" emotion (adjust as needed)

# Generate explanations
explanation_results = explainer.explain(
    input_values=audio_tensor,
    target_class=target_class,
    mode='frame',
    aggregate_frames=None  # Don't aggregate, keep all attributions
)

# Extract attributions and prediction info
attributions = explanation_results['attributions']
prediction = explanation_results['prediction']

print(f"Prediction: Class {prediction['class']} with probability {prediction['probability']:.4f}")
print(f"Attribution shape: {attributions.shape}")

# Calculate summary statistics for attributions
if attributions.ndim > 1:
    # If we have multiple attribution values (e.g., for each sample)
    # Compute absolute values and mean across samples
    abs_attr = np.abs(attributions)
    mean_attr = abs_attr.mean(axis=0)
    std_attr = abs_attr.std(axis=0)
    
    print(f"Mean absolute attribution: {mean_attr.mean():.6f}")
    print(f"Max absolute attribution: {mean_attr.max():.6f}")
else:
    # If we have a single attribution vector
    abs_attr = np.abs(attributions)
    
    print(f"Mean absolute attribution: {abs_attr.mean():.6f}")
    print(f"Max absolute attribution: {abs_attr.max():.6f}")
    
    # Find top 5 frames with highest attribution
    top_indices = np.argsort(abs_attr)[-5:]
    top_times = top_indices / sample_rate
    
    print("\nTop 5 most influential frames:")
    for i, (idx, time) in enumerate(zip(top_indices, top_times), 1):
        print(f"{i}. Frame {idx} at time {time:.3f}s - Attribution: {abs_attr[idx]:.6f}")

# Generate neuron-level explanations for a specific layer
# This can be very resource-intensive, so we'll just show the code
"""
# Generate neuron-level explanations
neuron_explanations = explainer.explain(
    input_values=audio_tensor,
    target_class=target_class,
    mode='neuron',
)

# Extract attributions for a specific layer
layer_idx = 6  # Middle layer
if f"layer_{layer_idx}" in neuron_explanations['attributions']:
    layer_attrs = neuron_explanations['attributions'][f"layer_{layer_idx}"]
    print(f"Layer {layer_idx} neuron attributions shape: {layer_attrs.shape}")
"""

## 8. Visualizing Attributions

Finally, let's create a comprehensive visualization that shows the attributions alongside the audio waveform and spectrogram. This will help us understand which parts of the audio are most important for the model's prediction.

We'll create a function to visualize frame-level attributions and highlight the most important frames.
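
Per-sample attributions tend to be noisy, so a light smoothing pass before plotting can make influential regions easier to see. The following is a minimal sketch, with `smooth_attributions` as a hypothetical helper and the 400-sample window (25 ms at 16 kHz) picked arbitrarily for illustration.

```python
import numpy as np

def smooth_attributions(attr, window=400):
    """Moving average of absolute attributions, rescaled to [0, 1] for plotting."""
    kernel = np.ones(window) / window
    smoothed = np.convolve(np.abs(attr), kernel, mode="same")
    return smoothed / (smoothed.max() + 1e-10)

rng = np.random.default_rng(0)
noisy = rng.normal(0.0, 0.1, size=16000)
noisy[8000:8400] += 1.0  # a region that should stand out after smoothing
smooth = smooth_attributions(noisy)
print(smooth.shape, int(smooth.argmax()))
```

The smoothed curve could be passed to the plotting code in place of the raw normalized attributions; a wider window gives a cleaner curve at the cost of temporal precision.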

In [None]:
def visualize_audio_attributions(audio, attributions, sample_rate, target_class, prediction_prob, top_k=20):
    """
    Visualize attributions for audio with waveform and spectrogram.
    
    Args:
        audio: Audio waveform
        attributions: Attribution scores from GradientSHAP
        sample_rate: Audio sample rate
        target_class: Target class for explanation
        prediction_prob: Prediction probability for the target class
        top_k: Number of top frames to highlight
    """
    # Create time array
    times = np.arange(len(audio)) / sample_rate
    
    # Process attributions
    if attributions.ndim > 1:
        # Take absolute value and mean across batch dimension
        attr_agg = np.abs(attributions).mean(axis=0)
    else:
        attr_agg = np.abs(attributions)
    
    # Normalize for visualization
    attr_norm = attr_agg / (attr_agg.max() + 1e-10)
    
    # Create figure with 3 subplots
    fig, axes = plt.subplots(3, 1, figsize=(15, 12), sharex=True)
    
    # Plot 1: Audio waveform
    axes[0].plot(times, audio, color='blue', alpha=0.7)
    axes[0].set_title(f"Audio Waveform - Target: Class {target_class} (Probability: {prediction_prob:.4f})")
    axes[0].set_ylabel("Amplitude")
    axes[0].grid(alpha=0.3)
    
    # Plot 2: Spectrogram
    D = librosa.amplitude_to_db(np.abs(librosa.stft(audio)), ref=np.max)
    img = librosa.display.specshow(D, sr=sample_rate, x_axis='time', y_axis='hz', ax=axes[1])
    fig.colorbar(img, ax=axes[1], format='%+2.0f dB')
    axes[1].set_title("Spectrogram")
    
    # Plot 3: Attributions
    axes[2].plot(times, attr_norm, color='red', alpha=0.8)
    axes[2].set_title("Frame-wise GradientSHAP Attribution")
    axes[2].set_ylabel("Normalized Attribution")
    axes[2].set_xlabel("Time (s)")
    axes[2].grid(alpha=0.3)
    
    # Highlight top k frames
    if top_k > 0 and top_k < len(attr_norm):
        top_indices = np.argsort(attr_norm)[-top_k:]
        axes[2].scatter(times[top_indices], attr_norm[top_indices], 
                     color='darkred', s=50, zorder=10, label=f"Top {top_k} frames")
        
        # Mark these points on the waveform and spectrogram too
        for idx in top_indices:
            t = times[idx]
            # Vertical lines across all subplots
            for ax in axes:
                ax.axvline(x=t, color='darkred', alpha=0.3, linestyle='--')
    
    # Add legend to the last plot (only if the top-k markers were drawn)
    if axes[2].get_legend_handles_labels()[0]:
        axes[2].legend()
    
    plt.tight_layout()
    plt.show()
    
    return fig

# Visualize the attributions
fig = visualize_audio_attributions(
    audio=audio_signal, 
    attributions=attributions,
    sample_rate=sample_rate,
    target_class=target_class,
    prediction_prob=prediction['probability'],
    top_k=10
)

# Optional: Save the figure
# fig.savefig('gradientshap_explanation.png', dpi=300, bbox_inches='tight')

## Conclusion

In this tutorial, we've implemented and demonstrated GradientSHAP for explaining WavLM speech model predictions. We've covered:

1. The theory behind GradientSHAP and how it combines ideas from Integrated Gradients and SHAP
2. How to load and prepare WavLM models and audio data
3. Implementation of a complete GradientSHAP explainer for speech models
4. How to generate and visualize frame-level attributions, and how to request approximate neuron-level attributions

### Next Steps

To further explore this topic, you might want to:

1. Apply this to real audio datasets like RAVDESS for emotion recognition
2. Compare attributions across different emotion classes to understand what makes certain emotions distinct
3. Extend the implementation to support other speech models like Wav2Vec2 or HuBERT
4. Use neuron-level attributions to understand which features the model learns at different layers
5. Compare GradientSHAP with other explainability methods like LIME or Integrated Gradients

### Additional Resources

- [SHAP Library](https://github.com/slundberg/shap) - Includes implementations of many SHAP variants
- [WavLM Paper](https://arxiv.org/abs/2110.13900) - Details on the WavLM model architecture
- [Captum](https://captum.ai/) - PyTorch model interpretability library with many explainability algorithms