# Grad-CAM Analysis for Music Genre Classification

This notebook performs explainability analysis using Grad-CAM (Gradient-weighted Class Activation Mapping) to understand:

1. **Baseline (Correct Predictions)**: What frequency/time regions the model focuses on for correct classifications
2. **Natural Failures**: Where the model looks when it makes mistakes (without adversarial attacks)
3. **Adversarial Failures**: How model attention shifts after adversarial perturbations

## 1. Setup and Imports

In [None]:
import sys
from pathlib import Path
import json
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn.functional as F

# Add src to path
PROJECT_ROOT = Path('../').resolve()
sys.path.insert(0, str(PROJECT_ROOT / 'src'))

# Import custom modules
from models.cnn import GenreClassifierCNN
from attacks.adversarial import prepare_classifier, generate_fgsm_attack, generate_pgd_attack
from explainability.gradcam import GradCAMExplainer, visualize_gradcam, compare_gradcam_side_by_side

# Define paths
DATA_PATH = PROJECT_ROOT / 'data'
PROCESSED_PATH = DATA_PATH / 'processed'
RESULTS_PATH = PROJECT_ROOT / 'results'
MODEL_PATH = RESULTS_PATH / 'models'
FIGURES_PATH = RESULTS_PATH / 'figures'

# Set style
plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette("husl")

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print(f"Project root: {PROJECT_ROOT}")
print(f"Using device: {device}")

## 2. Load Metadata and Test Data

In [None]:
# Load metadata
metadata_path = PROCESSED_PATH / 'metadata.json'
with open(metadata_path, 'r') as f:
    metadata = json.load(f)

genre_to_id = metadata['genre_to_id']
id_to_genre = {int(k): v for k, v in metadata['id_to_genre'].items()}
genres = metadata['genres']

print(f"Number of classes: {metadata['n_classes']}")
print(f"Genres: {genres}")
print(f"Input shape: {metadata['input_shape']}")

In [None]:
# Load test data
X_test = np.load(PROCESSED_PATH / 'test' / 'X.npy')
y_test = np.load(PROCESSED_PATH / 'test' / 'y.npy')
track_ids_test = np.load(PROCESSED_PATH / 'test' / 'track_ids.npy')

print(f"\nTest set shape: {X_test.shape}")
print(f"Test labels shape: {y_test.shape}")
print(f"Track IDs shape: {track_ids_test.shape}")
print(f"Input range: [{X_test.min():.3f}, {X_test.max():.3f}]")

## 3. Load Trained Model

In [None]:
# Load trained model
model = GenreClassifierCNN(num_classes=10)
model.load_state_dict(torch.load(MODEL_PATH / 'genre_cnn_pytorch_best.pth', map_location=device))
model.to(device)
model.eval()

print("Model loaded successfully!")
print(f"Model architecture: {model.__class__.__name__}")
print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")

## 4. Get Model Predictions on Test Set

In [None]:
# Get predictions on test set
X_test_tensor = torch.from_numpy(X_test).permute(0, 3, 1, 2).float().to(device)

with torch.no_grad():
    logits = model(X_test_tensor)
    predictions = torch.argmax(logits, dim=1).cpu().numpy()
    probabilities = F.softmax(logits, dim=1).cpu().numpy()

# Calculate accuracy
accuracy = np.mean(predictions == y_test)
print(f"Test accuracy: {accuracy:.4f} ({accuracy*100:.2f}%)")

# Identify correct and incorrect predictions
correct_mask = predictions == y_test
incorrect_mask = ~correct_mask

print(f"\nCorrect predictions: {correct_mask.sum()}/{len(y_test)}")
print(f"Incorrect predictions: {incorrect_mask.sum()}/{len(y_test)}")
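
The cell above sends the entire test set through the model in a single forward pass. That works for a small test set, but on a GPU it can run out of memory as the set grows. Below is a hedged sketch of a batched alternative. `predict_in_batches` is a hypothetical helper, and a plain NumPy linear map stands in for the CNN so the pattern runs on its own. With the real model, `predict_fn` would convert each batch to a tensor (including the NHWC to NCHW permute), call `model(...)` under `torch.no_grad()`, and return `.cpu().numpy()`.

```python
import numpy as np


def predict_in_batches(predict_fn, X, batch_size=64):
    """Apply predict_fn to X in chunks of batch_size and concatenate the logits.

    predict_fn maps an array of shape (B, ...) to logits of shape (B, n_classes).
    """
    outputs = []
    for start in range(0, len(X), batch_size):
        outputs.append(predict_fn(X[start:start + batch_size]))
    return np.concatenate(outputs, axis=0)


def softmax(logits):
    # Subtract the row max for numerical stability before exponentiating
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)


# Stand-in "model": a fixed random linear map over flattened inputs (illustration only)
rng = np.random.default_rng(0)
X_demo = rng.normal(size=(10, 4, 5, 1))   # NHWC, like X_test
W_demo = rng.normal(size=(4 * 5, 3))      # 3 hypothetical classes


def linear_model(xb):
    # Flatten each sample and apply the fixed linear map
    return xb.reshape(len(xb), -1) @ W_demo


logits_demo = predict_in_batches(linear_model, X_demo, batch_size=4)
probs_demo = softmax(logits_demo)
preds_demo = logits_demo.argmax(axis=1)
```

Because each batch is processed independently, the concatenated logits are identical to those from a single full pass. Only peak memory changes.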

## 5. Initialize Grad-CAM Explainer

In [None]:
# Initialize Grad-CAM explainer with the last convolutional layer
target_layer = model.get_last_conv_layer()
explainer = GradCAMExplainer(model=model, target_layer=target_layer, device=device)

print("Grad-CAM explainer initialized")
print(f"Target layer: {target_layer.__class__.__name__}")
print(f"Target layer output channels: {target_layer.out_channels}")
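
For reference, Grad-CAM computes its heatmap from two quantities at the target layer. The first is the activations $A^k$ of each channel $k$. The second is the gradient $\partial y^c / \partial A^k$ of the class score $y^c$. Each channel weight $\alpha_k^c$ is the spatial mean of that channel's gradient, and the heatmap is $\mathrm{ReLU}(\sum_k \alpha_k^c A^k)$. The `GradCAMExplainer` source is not shown in this notebook, so the sketch below is a generic NumPy version and not necessarily the project's implementation. It assumes the activations and gradients for one sample have already been captured as arrays of shape `(K, H, W)`, for example with PyTorch forward/backward hooks. It upsamples the map to the spectrogram size with nearest-neighbour indexing, which may differ from the interpolation the project's helper uses.

```python
import numpy as np


def gradcam_from_arrays(activations, gradients, output_shape=None):
    """Grad-CAM heatmap from captured activations and gradients.

    activations, gradients: arrays of shape (K, H, W) for one sample at the
    target layer (K channels). Returns a heatmap scaled to [0, 1].
    """
    # alpha_k: global-average-pooled gradient for each channel
    weights = gradients.mean(axis=(1, 2))                          # (K,)
    # Weighted sum of the feature maps; ReLU keeps only positive evidence for the class
    cam = np.maximum(np.tensordot(weights, activations, axes=1), 0.0)  # (H, W)

    # Scale to [0, 1]; an all-zero map stays all zero
    peak = cam.max()
    if peak > 0:
        cam = cam / peak

    if output_shape is not None:
        # Nearest-neighbour upsampling to the input size (e.g. a 128 x 130 spectrogram)
        out_h, out_w = output_shape
        rows = np.arange(out_h) * cam.shape[0] // out_h
        cols = np.arange(out_w) * cam.shape[1] // out_w
        cam = cam[np.ix_(rows, cols)]
    return cam


# Usage with random stand-in arrays (8 channels on a 16 x 17 feature grid)
rng = np.random.default_rng(0)
acts = rng.random((8, 16, 17))
grads = rng.normal(size=(8, 16, 17))
heatmap = gradcam_from_arrays(acts, grads, output_shape=(128, 130))
```

The ReLU is what makes Grad-CAM class-discriminative. Regions whose features push the class score down are zeroed out rather than shown as negative attention.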

## 6. Baseline Analysis - Correct Predictions

Select and analyze correctly classified samples to understand what frequency/time regions the model focuses on for correct decisions.
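
The selection loop in the next cell rescans every correct index once per genre and sorts Python lists. Below is a more compact NumPy sketch of the same selection. `top_k_correct_per_class` is a hypothetical helper, not part of the project's `src` modules. Like the loop, it ranks each genre's correct predictions by the probability of the true class and visits genres in ID order. A stable sort keeps ties in ascending index order, matching Python's `sorted`.

```python
import numpy as np


def top_k_correct_per_class(probs, y_true, y_pred, k=2):
    """Indices of the k most confident correct predictions for each class.

    probs: (N, C) softmax probabilities; y_true, y_pred: (N,) integer labels.
    """
    selected = []
    correct = y_true == y_pred
    for c in range(probs.shape[1]):
        idx = np.flatnonzero(correct & (y_true == c))
        # Stable sort on negated confidence: highest first, ties keep index order
        order = np.argsort(-probs[idx, c], kind="stable")
        selected.extend(idx[order[:k]].tolist())
    return selected


# Toy example: 6 samples, 2 classes (sample 5 is misclassified)
probs_demo = np.array([[0.9, 0.1], [0.6, 0.4], [0.2, 0.8],
                       [0.7, 0.3], [0.3, 0.7], [0.55, 0.45]])
y_true_demo = np.array([0, 0, 1, 0, 1, 1])
y_pred_demo = probs_demo.argmax(axis=1)
picked = top_k_correct_per_class(probs_demo, y_true_demo, y_pred_demo, k=2)
```

In the notebook, the equivalent call would be `top_k_correct_per_class(probabilities, y_test, predictions, k=2)`.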

In [None]:
# Select diverse correctly classified samples (two high-confidence samples per genre)
correct_indices = np.where(correct_mask)[0]
correct_samples = []

print("Selecting correctly classified samples (two per genre with highest confidence)...")

# Get two high-confidence correct predictions per genre
for genre_id in range(10):
    genre_correct = [i for i in correct_indices if y_test[i] == genre_id]
    if genre_correct:
        # Sort by confidence and take the top 2
        confidences = [probabilities[i, genre_id] for i in genre_correct]
        sorted_indices = sorted(zip(genre_correct, confidences), key=lambda x: x[1], reverse=True)
        # Add top 2 samples for this genre
        for idx, conf in sorted_indices[:2]:
            correct_samples.append(idx)

print(f"\nSelected {len(correct_samples)} correctly classified samples:")
for idx in correct_samples:
    true_genre = id_to_genre[y_test[idx]]
    pred_genre = id_to_genre[predictions[idx]]
    conf = probabilities[idx, predictions[idx]]
    print(f"  Sample {idx}: {true_genre} (confidence: {conf:.4f})")

In [None]:
# Generate Grad-CAM visualizations for correct predictions
print("Generating Grad-CAM visualizations for correct predictions...\n")
print("=" * 80)

# Create subdirectory for baseline analysis
baseline_path = FIGURES_PATH / 'gradcam' / 'baseline_correct'
baseline_path.mkdir(parents=True, exist_ok=True)

for i, idx in enumerate(correct_samples):
    print(f"\nProcessing correct sample {i+1}/{len(correct_samples)}...")
    
    # Get sample data
    input_spec = X_test[idx]
    true_label = y_test[idx]
    
    # Analyze with Grad-CAM
    analysis = explainer.analyze_sample(
        spectrogram=input_spec,
        true_label=true_label,
        label_names=id_to_genre,
        get_prediction=True
    )
    
    # Create visualization
    title = f"Correct Prediction - Sample {i+1}"
    save_path = baseline_path / f'correct_{i+1:02d}_{analysis["true_label_name"]}.png'
    
    visualize_gradcam(
        spectrogram=input_spec,
        heatmap=analysis['heatmap'],
        title=title,
        prediction=analysis['predicted_label'],
        true_label=analysis['true_label_name'],
        confidence=analysis['confidence'],
        save_path=str(save_path),
        show_plot=True
    )
    print("-" * 80)

print(f"\n✓ Baseline analysis complete! Visualizations saved to {baseline_path}")
print("=" * 80)

**Observations: Attention Patterns in Correct Predictions**

- The model shows **genre-specific attention patterns**: different genres activate different time-frequency regions of the spectrograms
- Attention is **localized**, concentrating on specific regions rather than spreading uniformly over the whole spectrogram
- The two samples selected per genre show **similar attention patterns**, which suggests that the model relies on stable, learned features for each genre; with only two samples per genre, this is an indication rather than a confirmed property
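
These consistency statements come from visual inspection of two samples per genre. One way to put a number on them is to correlate the Grad-CAM heatmaps of each same-genre pair and compare the result against pairs from different genres. Below is a hedged sketch using Pearson correlation over flattened heatmaps. The metric and the helper name `heatmap_correlation` are our own choice, not the notebook's, and the helper assumes both heatmaps have the same shape.

```python
import numpy as np


def heatmap_correlation(h1, h2):
    """Pearson correlation between two equally shaped heatmaps.

    Returns a value in [-1, 1], or NaN if either heatmap is constant.
    """
    a = np.asarray(h1, dtype=float)
    b = np.asarray(h2, dtype=float)
    if a.shape != b.shape:
        raise ValueError("heatmaps must have the same shape")
    # Center both flattened maps, then take the normalized dot product
    a = a.ravel() - a.mean()
    b = b.ravel() - b.mean()
    denom = np.sqrt((a * a).sum() * (b * b).sum())
    return float((a * b).sum() / denom) if denom > 0 else float("nan")


# Toy heatmaps: identical maps correlate perfectly, inverted maps anti-correlate
h = np.array([[0.0, 0.2], [0.8, 1.0]])
same = heatmap_correlation(h, h)
inverted = heatmap_correlation(h, 1.0 - h)
```

Assuming every genre contributed two samples, consecutive entries of `correct_samples` form the same-genre pairs. A mean within-genre correlation clearly above the mean between-genre correlation would support the consistency claim.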

## 7. Natural Failures - Misclassifications Without Attacks

Analyze naturally misclassified examples to understand where the model fails without adversarial perturbations.
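
Before inspecting individual failures, it can help to see which genre pairs the model confuses most often across the whole test set, because a handful of hand-picked examples may not be representative. Below is a minimal NumPy sketch. `confusion_matrix` and `most_confused_pairs` are hypothetical helpers written here for illustration. In the notebook they would be called with `y_test`, `predictions` and `metadata['n_classes']`, mapping the returned IDs through `id_to_genre`.

```python
import numpy as np


def confusion_matrix(y_true, y_pred, n_classes):
    """counts[i, j] = number of samples with true class i predicted as class j."""
    counts = np.zeros((n_classes, n_classes), dtype=int)
    np.add.at(counts, (np.asarray(y_true), np.asarray(y_pred)), 1)
    return counts


def most_confused_pairs(y_true, y_pred, n_classes, top=5):
    """(true, predicted, count) for the most frequent off-diagonal cells."""
    counts = confusion_matrix(y_true, y_pred, n_classes)
    np.fill_diagonal(counts, 0)  # ignore correct predictions
    flat_order = np.argsort(counts, axis=None, kind="stable")[::-1][:top]
    pairs = []
    for f in flat_order:
        t, p = np.unravel_index(f, counts.shape)
        if counts[t, p] > 0:
            pairs.append((int(t), int(p), int(counts[t, p])))
    return pairs


# Toy labels: class 0 is predicted as class 1 twice
yt = np.array([0, 0, 0, 1, 1, 2, 2, 2])
yp = np.array([1, 1, 0, 0, 1, 1, 2, 0])
pairs = most_confused_pairs(yt, yp, 3)
```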

In [None]:
# Select diverse misclassified samples (5 high-confidence mistakes)
incorrect_indices = np.where(incorrect_mask)[0]
misclassified_samples = []
comparison_correct_samples = []

print("Selecting misclassified samples (5 high-confidence mistakes from different genres)...")

# Most confident mistake for each true genre (confidence assigned to the wrong predicted class)
for genre_id in range(10):
    genre_incorrect = [i for i in incorrect_indices if y_test[i] == genre_id]
    if genre_incorrect:
        confidences = [probabilities[i, predictions[i]] for i in genre_incorrect]
        best_idx = genre_incorrect[np.argmax(confidences)]
        misclassified_samples.append(best_idx)

# Keep the 5 most confident of these mistakes (one per true genre), rather than
# whichever 5 genres happen to come first by ID
misclassified_samples = sorted(
    misclassified_samples,
    key=lambda i: probabilities[i, predictions[i]],
    reverse=True,
)[:5]

# For each misclassification, find a correctly classified sample of the predicted genre
print("\nFinding correctly classified samples of the predicted genres for comparison...")
for idx in misclassified_samples:
    predicted_genre_id = predictions[idx]
    # Find correct predictions for this genre
    genre_correct = [i for i in correct_indices if y_test[i] == predicted_genre_id]
    if genre_correct:
        # Take the highest confidence correct prediction
        confidences = [probabilities[i, predicted_genre_id] for i in genre_correct]
        best_correct_idx = genre_correct[np.argmax(confidences)]
        comparison_correct_samples.append(best_correct_idx)
    else:
        comparison_correct_samples.append(None)

print(f"\nSelected {len(misclassified_samples)} misclassified samples:")
for i, idx in enumerate(misclassified_samples):
    true_genre = id_to_genre[y_test[idx]]
    pred_genre = id_to_genre[predictions[idx]]
    conf = probabilities[idx, predictions[idx]]
    comp_idx = comparison_correct_samples[i]
    comp_info = f" | Comparison: sample {comp_idx}" if comp_idx is not None else " | No comparison available"
    print(f"  Sample {idx}: True={true_genre}, Predicted={pred_genre} (confidence: {conf:.4f}){comp_info}")

In [None]:
# Generate Grad-CAM visualizations for misclassifications with comparisons
print("\nGenerating Grad-CAM visualizations for misclassifications...\n")
print("=" * 80)

# Create subdirectory for misclassification analysis
misclass_path = FIGURES_PATH / 'gradcam' / 'natural_failures'
misclass_path.mkdir(parents=True, exist_ok=True)

for i, idx in enumerate(misclassified_samples):
    print(f"\nProcessing misclassification {i+1}/{len(misclassified_samples)}...")
    
    # Get misclassified sample data
    input_spec = X_test[idx]
    true_label = y_test[idx]
    
    # Analyze misclassified sample with Grad-CAM
    analysis = explainer.analyze_sample(
        spectrogram=input_spec,
        true_label=true_label,
        label_names=id_to_genre,
        get_prediction=True
    )
    
    # Create visualization for misclassified sample
    title = f"Misclassification - Sample {i+1}"
    save_path = misclass_path / f'misclass_{i+1:02d}_{analysis["true_label_name"]}_as_{analysis["predicted_label"]}.png'
    
    visualize_gradcam(
        spectrogram=input_spec,
        heatmap=analysis['heatmap'],
        title=title,
        prediction=analysis['predicted_label'],
        true_label=analysis['true_label_name'],
        confidence=analysis['confidence'],
        save_path=str(save_path),
        show_plot=True
    )
    
    # Generate comparison with correctly classified sample of the predicted genre
    comp_idx = comparison_correct_samples[i]
    if comp_idx is not None:
        print(f"  Generating comparison with correctly classified {id_to_genre[predictions[idx]]} sample...")
        
        comp_spec = X_test[comp_idx]
        comp_label = y_test[comp_idx]
        
        comp_analysis = explainer.analyze_sample(
            spectrogram=comp_spec,
            true_label=comp_label,
            label_names=id_to_genre,
            get_prediction=True
        )
        
        comp_title = f"Correct {id_to_genre[predictions[idx]]} - For Comparison"
        comp_save_path = misclass_path / f'compare_{i+1:02d}_correct_{comp_analysis["true_label_name"]}.png'
        
        visualize_gradcam(
            spectrogram=comp_spec,
            heatmap=comp_analysis['heatmap'],
            title=comp_title,
            prediction=comp_analysis['predicted_label'],
            true_label=comp_analysis['true_label_name'],
            confidence=comp_analysis['confidence'],
            save_path=str(comp_save_path),
            show_plot=True
        )
    
    print("-" * 80)

print(f"\n✓ Natural-failure analysis complete! Visualizations saved to {misclass_path}")
print("=" * 80)

**Observations: Attention Patterns in Natural Failures**

- **Attention sometimes resembles the predicted (wrong) genre**: for some misclassified samples, the heatmap focuses on regions similar to those highlighted for a correctly classified sample of the predicted genre; for others the resemblance is weak, suggesting that not all misclassifications share the same cause
- **Genre confusion from overlapping features**: misclassified samples appear to contain patterns characteristic of more than one genre, and the model attends to the features that resemble the wrong genre
- **Attention stays localized**: even in misclassifications, the heatmaps concentrate on specific regions rather than spreading diffusely. These samples were selected as the *most confident* mistakes, so their high softmax confidence is partly a consequence of the selection
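
The alignment between a misclassified sample's heatmap and that of a correctly classified sample of the predicted genre is judged visually above. One hedged way to score it is the intersection-over-union (IoU) of the most strongly activated regions, such as the top 10% of time-frequency bins in each heatmap. The threshold `q` and the helper name `top_region_iou` are illustrative assumptions, and the helper assumes both heatmaps cover the same grid. Because the two heatmaps come from different recordings, a high IoU means attention on the same time-frequency locations, not necessarily on the same musical content.

```python
import numpy as np


def top_region_iou(h1, h2, q=0.10):
    """IoU of the top-q fraction of bins in two equally shaped heatmaps."""
    h1 = np.asarray(h1, dtype=float)
    h2 = np.asarray(h2, dtype=float)
    if h1.shape != h2.shape:
        raise ValueError("heatmaps must have the same shape")
    # Bins at or above the (1 - q) quantile form each "attended" region
    m1 = h1 >= np.quantile(h1, 1.0 - q)
    m2 = h2 >= np.quantile(h2, 1.0 - q)
    union = np.logical_or(m1, m2).sum()
    return float(np.logical_and(m1, m2).sum() / union) if union else 0.0


# Toy heatmap whose top 10% of bins lie in the last row
h = np.arange(100, dtype=float).reshape(10, 10)
iou_same = top_region_iou(h, h)
iou_flipped = top_region_iou(h, h[::-1])  # top bins moved to the first row
```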

## 8. Adversarial Failures - Attack Analysis

Compare clean samples with their adversarial versions to understand how attacks shift model attention.

In [None]:
# Prepare adversarial attack framework
art_classifier = prepare_classifier(
    model=model,
    device=device,
    input_shape=(1, 128, 130),
    num_classes=10
)

print("ART classifier prepared for adversarial attacks")

### 8.1: FGSM Attack Comparison
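
FGSM takes one step of size ε in the direction of the sign of the loss gradient with respect to the input, $x_{adv} = x + \epsilon \cdot \mathrm{sign}(\nabla_x L(x, y))$, then clips the result to the valid input range. The project's `generate_fgsm_attack` wraps ART, and its internals are not shown here. The sketch below applies the same formula to a toy softmax-regression model whose input gradient has a closed form, so it runs without PyTorch or ART. The model, its random weights and the `[0, 1]` clip range are assumptions for illustration only.

```python
import numpy as np


def softmax(z):
    z = z - z.max()  # numerical stability
    e = np.exp(z)
    return e / e.sum()


def cross_entropy_and_grad(x, y, W, b):
    """Loss and input gradient for a linear softmax classifier p = softmax(W x + b)."""
    p = softmax(W @ x + b)
    loss = -np.log(p[y])
    one_hot = np.zeros_like(p)
    one_hot[y] = 1.0
    grad_x = W.T @ (p - one_hot)  # dL/dx for softmax cross-entropy
    return loss, grad_x


def fgsm(x, y, W, b, eps, clip=(0.0, 1.0)):
    """Single-step FGSM: move every input value by eps along the gradient sign."""
    _, grad_x = cross_entropy_and_grad(x, y, W, b)
    return np.clip(x + eps * np.sign(grad_x), *clip)


rng = np.random.default_rng(0)
W = rng.normal(size=(3, 20))        # 3 hypothetical classes, 20 input features
b = np.zeros(3)
x = rng.uniform(0.2, 0.8, size=20)  # inside the clip range, so eps=0.1 never clips
y = int(np.argmax(W @ x + b))       # attack the model's own (clean) prediction

x_adv = fgsm(x, y, W, b, eps=0.1)
loss_clean, _ = cross_entropy_and_grad(x, y, W, b)
loss_adv, _ = cross_entropy_and_grad(x_adv, y, W, b)
```

Every input value moves by exactly ε (before clipping), so the perturbation has L∞ norm ε. For this convex toy loss, the step is guaranteed to increase the loss. For a deep network, one step only follows a local linear approximation, which is why iterative attacks such as PGD are usually stronger at the same ε.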

In [None]:
# Use the (up to) 20 baseline samples selected in Section 6 for the FGSM attack
attack_samples_fgsm = correct_samples[:20]
epsilon = 0.1  # FGSM perturbation strength

print(f"Generating FGSM adversarial examples with epsilon={epsilon}...")
print(f"Selected {len(attack_samples_fgsm)} samples for FGSM attack\n")

# Prepare batch for attack
X_clean_batch = X_test[attack_samples_fgsm]
y_clean_batch = y_test[attack_samples_fgsm]

# Generate FGSM adversarial examples
X_fgsm_batch = generate_fgsm_attack(
    classifier=art_classifier,
    X=X_clean_batch,
    y=y_clean_batch,
    eps=epsilon
)

print(f"✓ Generated {len(X_fgsm_batch)} FGSM adversarial examples")

In [None]:
# Evaluate adversarial examples
X_fgsm_tensor = torch.from_numpy(X_fgsm_batch).permute(0, 3, 1, 2).float().to(device)

with torch.no_grad():
    logits_fgsm = model(X_fgsm_tensor)
    predictions_fgsm = torch.argmax(logits_fgsm, dim=1).cpu().numpy()
    probabilities_fgsm = F.softmax(logits_fgsm, dim=1).cpu().numpy()

# Compare clean vs adversarial predictions
print("\nFGSM Attack Results:")
print("=" * 80)
successful_attacks = 0

for i, idx in enumerate(attack_samples_fgsm):
    true_label = id_to_genre[y_clean_batch[i]]
    pred_clean = id_to_genre[predictions[idx]]
    pred_adv = id_to_genre[predictions_fgsm[i]]
    conf_clean = probabilities[idx, predictions[idx]]
    conf_adv = probabilities_fgsm[i, predictions_fgsm[i]]
    
    is_fooled = predictions[idx] != predictions_fgsm[i]
    if is_fooled:
        successful_attacks += 1
    status = "✗ FOOLED" if is_fooled else "✓ ROBUST"
    
    print(f"\nSample {i+1} ({status}):")
    print(f"  True label:  {true_label}")
    print(f"  Clean pred:  {pred_clean} (conf: {conf_clean:.4f})")
    print(f"  FGSM pred:   {pred_adv} (conf: {conf_adv:.4f})")
    print("-" * 80)

print(f"\nAttack success rate: {successful_attacks}/{len(attack_samples_fgsm)} ({successful_attacks/len(attack_samples_fgsm)*100:.1f}%)")
print("=" * 80)

In [None]:
# Generate side-by-side Grad-CAM comparisons for FGSM
print("\nGenerating FGSM Grad-CAM comparisons...\n")
print("=" * 80)

# Create subdirectory for FGSM attack analysis
fgsm_path = FIGURES_PATH / 'gradcam' / 'adversarial_fgsm'
fgsm_path.mkdir(parents=True, exist_ok=True)

for i, idx in enumerate(attack_samples_fgsm):
    print(f"\nProcessing FGSM comparison {i+1}/{len(attack_samples_fgsm)}...")
    
    # Clean sample
    clean_spec = X_test[idx]
    clean_analysis = explainer.analyze_sample(
        spectrogram=clean_spec,
        true_label=y_test[idx],
        label_names=id_to_genre,
        get_prediction=True
    )
    
    # Adversarial sample
    adv_spec = X_fgsm_batch[i]
    adv_analysis = explainer.analyze_sample(
        spectrogram=adv_spec,
        true_label=y_test[idx],
        label_names=id_to_genre,
        get_prediction=True
    )
    
    # Create side-by-side comparison
    title = f"FGSM Attack Comparison - Sample {i+1}"
    save_path = fgsm_path / f'fgsm_{i+1:02d}_{clean_analysis["true_label_name"]}.png'
    
    compare_gradcam_side_by_side(
        spec1=clean_spec,
        heatmap1=clean_analysis['heatmap'],
        spec2=adv_spec,
        heatmap2=adv_analysis['heatmap'],
        title1="Clean",
        title2=f"FGSM (ε={epsilon})",
        overall_title=title,
        pred1=clean_analysis['predicted_label'],
        pred2=adv_analysis['predicted_label'],
        true_label=clean_analysis['true_label_name'],
        conf1=clean_analysis['confidence'],
        conf2=adv_analysis['confidence'],
        save_path=str(save_path),
        show_plot=True
    )
    print("-" * 80)

print(f"\n✓ FGSM comparisons complete! Visualizations saved to {fgsm_path}")
print("=" * 80)

### 8.2: PGD Attack Comparison
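
PGD repeats a smaller signed-gradient step (`eps_step`) for `max_iter` iterations. After each step, it projects the result back into the L∞ ball of radius ε around the clean input and into the valid value range. The sketch below implements that loop on a toy softmax-regression model with a closed-form input gradient, so it runs without PyTorch or ART. It starts from the clean input rather than a random point. That is one of several PGD variants and may not match the configuration inside `generate_pgd_attack`. The toy model and the `[0, 1]` range are assumptions.

```python
import numpy as np


def loss_and_input_grad(x, y, W, b):
    """Cross-entropy loss and dL/dx for a linear softmax classifier."""
    z = W @ x + b
    e = np.exp(z - z.max())  # numerical stability
    p = e / e.sum()
    one_hot = np.zeros_like(p)
    one_hot[y] = 1.0
    return -np.log(p[y]), W.T @ (p - one_hot)


def pgd_linf(x, y, W, b, eps, eps_step, max_iter, clip=(0.0, 1.0)):
    """L-infinity PGD from the clean input; returns the adversarial input and loss history."""
    x_adv = x.copy()
    losses = []
    for _ in range(max_iter):
        loss, grad = loss_and_input_grad(x_adv, y, W, b)
        losses.append(loss)
        x_adv = x_adv + eps_step * np.sign(grad)
        # Project back into the eps-ball around x, then into the valid range
        x_adv = np.clip(x_adv, x - eps, x + eps)
        x_adv = np.clip(x_adv, *clip)
    losses.append(loss_and_input_grad(x_adv, y, W, b)[0])
    return x_adv, losses


rng = np.random.default_rng(1)
W = rng.normal(size=(3, 20))   # 3 hypothetical classes, 20 input features
b = np.zeros(3)
x = rng.uniform(0.0, 1.0, size=20)
y = int(np.argmax(W @ x + b))

# Same settings as the notebook cell below: eps=0.1, eps_step=0.01, max_iter=40
x_adv, losses = pgd_linf(x, y, W, b, eps=0.1, eps_step=0.01, max_iter=40)
```

With `eps_step=0.01` and `max_iter=40`, the iterates can travel up to 0.4 in total, well beyond ε=0.1. The projection is what keeps the final perturbation within the same budget as FGSM.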

In [None]:
# Use the same 20 samples for PGD attack (for direct comparison with FGSM)
attack_samples_pgd = correct_samples[:20]
epsilon_pgd = 0.1
eps_step = 0.01
max_iter = 40

print("Generating PGD adversarial examples...")
print(f"Parameters: epsilon={epsilon_pgd}, eps_step={eps_step}, max_iter={max_iter}")
print(f"Selected {len(attack_samples_pgd)} samples for PGD attack (same as FGSM)\n")

# Prepare batch for attack
X_clean_batch_pgd = X_test[attack_samples_pgd]
y_clean_batch_pgd = y_test[attack_samples_pgd]

# Generate PGD adversarial examples
X_pgd_batch = generate_pgd_attack(
    classifier=art_classifier,
    X=X_clean_batch_pgd,
    y=y_clean_batch_pgd,
    eps=epsilon_pgd,
    eps_step=eps_step,
    max_iter=max_iter
)

print(f"✓ Generated {len(X_pgd_batch)} PGD adversarial examples")

In [None]:
# Evaluate PGD adversarial examples
X_pgd_tensor = torch.from_numpy(X_pgd_batch).permute(0, 3, 1, 2).float().to(device)

with torch.no_grad():
    logits_pgd = model(X_pgd_tensor)
    predictions_pgd = torch.argmax(logits_pgd, dim=1).cpu().numpy()
    probabilities_pgd = F.softmax(logits_pgd, dim=1).cpu().numpy()

# Compare clean vs adversarial predictions
print("\nPGD Attack Results:")
print("=" * 80)
successful_attacks_pgd = 0

for i, idx in enumerate(attack_samples_pgd):
    true_label = id_to_genre[y_clean_batch_pgd[i]]
    pred_clean = id_to_genre[predictions[idx]]
    pred_adv = id_to_genre[predictions_pgd[i]]
    conf_clean = probabilities[idx, predictions[idx]]
    conf_adv = probabilities_pgd[i, predictions_pgd[i]]
    
    is_fooled = predictions[idx] != predictions_pgd[i]
    if is_fooled:
        successful_attacks_pgd += 1
    status = "✗ FOOLED" if is_fooled else "✓ ROBUST"
    
    print(f"\nSample {i+1} ({status}):")
    print(f"  True label:  {true_label}")
    print(f"  Clean pred:  {pred_clean} (conf: {conf_clean:.4f})")
    print(f"  PGD pred:    {pred_adv} (conf: {conf_adv:.4f})")
    print("-" * 80)

print(f"\nAttack success rate: {successful_attacks_pgd}/{len(attack_samples_pgd)} ({successful_attacks_pgd/len(attack_samples_pgd)*100:.1f}%)")
print("=" * 80)

In [None]:
# Generate side-by-side Grad-CAM comparisons for PGD
print("\nGenerating PGD Grad-CAM comparisons...\n")
print("=" * 80)

# Create subdirectory for PGD attack analysis
pgd_path = FIGURES_PATH / 'gradcam' / 'adversarial_pgd'
pgd_path.mkdir(parents=True, exist_ok=True)

for i, idx in enumerate(attack_samples_pgd):
    print(f"\nProcessing PGD comparison {i+1}/{len(attack_samples_pgd)}...")
    
    # Clean sample
    clean_spec = X_test[idx]
    clean_analysis = explainer.analyze_sample(
        spectrogram=clean_spec,
        true_label=y_test[idx],
        label_names=id_to_genre,
        get_prediction=True
    )
    
    # Adversarial sample
    adv_spec = X_pgd_batch[i]
    adv_analysis = explainer.analyze_sample(
        spectrogram=adv_spec,
        true_label=y_test[idx],
        label_names=id_to_genre,
        get_prediction=True
    )
    
    # Create side-by-side comparison
    title = f"PGD Attack Comparison - Sample {i+1}"
    save_path = pgd_path / f'pgd_{i+1:02d}_{clean_analysis["true_label_name"]}.png'
    
    compare_gradcam_side_by_side(
        spec1=clean_spec,
        heatmap1=clean_analysis['heatmap'],
        spec2=adv_spec,
        heatmap2=adv_analysis['heatmap'],
        title1="Clean",
        title2=f"PGD (ε={epsilon_pgd}, iter={max_iter})",
        overall_title=title,
        pred1=clean_analysis['predicted_label'],
        pred2=adv_analysis['predicted_label'],
        true_label=clean_analysis['true_label_name'],
        conf1=clean_analysis['confidence'],
        conf2=adv_analysis['confidence'],
        save_path=str(save_path),
        show_plot=True
    )
    print("-" * 80)

print(f"\n✓ PGD comparisons complete! Visualizations saved to {pgd_path}")
print("=" * 80)

**Observations: Adversarial Attacks**

- **Adversarial perturbations shift model attention**: both FGSM and PGD cause the model to focus on different spectrogram regions than it does for the corresponding clean samples
- **PGD appears to shift attention more than FGSM**: at the same ε, PGD's iterative refinement typically finds stronger perturbations than a single FGSM step (compare the two attack success rates printed above); the attention-shift comparison itself is based on visual inspection of the heatmaps
- **Adversarial failures differ from natural failures**: natural misclassifications tend to show attention resembling the confused genre, whereas adversarial examples produce attention patterns that do not necessarily match any specific genre
- **Some consistent confusion patterns across attack types**: for certain genres, FGSM and PGD lead to the same wrong prediction, although with only 20 samples this observation is not conclusive
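
Whether PGD shifts attention more than FGSM can be checked numerically instead of by eye. For example, one could measure how far each heatmap's centre of mass moves between the clean and adversarial versions of a sample, then average per attack. Below is a hedged sketch. The centroid metric and the helper names are our own choice, and the sketch assumes heatmaps are indexed as (frequency bin, time frame), matching a 128 x 130 spectrogram layout. A centroid can stay put while the attended region changes shape, so this captures only one aspect of an attention shift.

```python
import numpy as np


def heatmap_centroid(heatmap):
    """Activation-weighted centre of mass (row, col) of a non-negative heatmap."""
    h = np.asarray(heatmap, dtype=float)
    total = h.sum()
    if total == 0:
        return np.array([np.nan, np.nan])  # empty heatmap has no centroid
    rows, cols = np.indices(h.shape)
    return np.array([(rows * h).sum() / total, (cols * h).sum() / total])


def centroid_shift(clean_heatmap, adv_heatmap):
    """Euclidean distance, in bins, between the two heatmaps' centroids."""
    return float(np.linalg.norm(heatmap_centroid(adv_heatmap) - heatmap_centroid(clean_heatmap)))


# Toy check: a single hot spot moved 3 bins in frequency and 4 frames in time
clean = np.zeros((128, 130))
clean[10, 20] = 1.0
adv = np.zeros((128, 130))
adv[13, 24] = 1.0
shift = centroid_shift(clean, adv)
```

Inside the FGSM and PGD comparison loops, `centroid_shift(clean_analysis['heatmap'], adv_analysis['heatmap'])` could be collected per sample. The two resulting means would turn the qualitative bullet above into a measured comparison.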

## 9. Summary

**Grad-CAM Analysis Complete!**

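This notebook applied Grad-CAM to three groups of test samples: correctly classified samples, natural misclassifications, and FGSM/PGD adversarial versions of the correctly classified samples.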

### Output Files
All visualizations saved to: `results/figures/gradcam/`
- `baseline_correct/` - Correct predictions analysis (20 samples - 2 per genre)
- `natural_failures/` - Misclassification analysis with comparisons (10 samples: 5 misclassified + 5 correct for comparison)
- `adversarial_fgsm/` - FGSM attack comparisons (20 samples)
- `adversarial_pgd/` - PGD attack comparisons (20 samples)

**Total: 70 visualizations organized by analysis type**

### Key Insights

**1. Correct Predictions - Stable Genre-Specific Patterns**
- The model has learned **distinct attention patterns for each genre**, focusing on specific frequency and time regions characteristic of each music style
- **Similar heatmaps for the two samples of each genre** suggest that the model relies on stable, discriminative features rather than incidental patterns, although two samples per genre is a small basis for this conclusion
- Attention is **localized**, indicating that the model keys on specific spectral regions; the adversarial results show that these regions are not robust to small perturbations

**2. Natural Failures - Feature Ambiguity and Genre Overlap**
- Misclassifications appear to occur when samples contain **overlapping features** from multiple genres, causing the model to focus on characteristics resembling the wrong genre
- The attention patterns of misclassified samples show **varying degrees of similarity** to the predicted genre: some closely match, while others are less aligned
- Even when wrong, the model produces **focused attention and high confidence**: rather than signalling uncertainty, it confidently matches ambiguous spectral patterns to the wrong genre (keeping in mind that the analyzed mistakes were selected for high confidence)
- These failures likely reflect **genuine classification ambiguity**, where genre boundaries are blurred in the audio itself

**3. Adversarial Attacks - Artificial Manipulation of Decision Boundaries**
- Both FGSM and PGD **shift model attention** to different regions, but the mechanism differs from natural failures: the perturbations are optimized to change the prediction rather than arising from genre ambiguity in the audio
- Adversarial failures show **artificial attention patterns** that do not necessarily correspond to any specific genre's characteristics, unlike natural failures, which tend to align with the features of the confused genre

**Overall Conclusion**: On correctly classified samples, the model shows consistent, genre-specific attention, but it remains vulnerable both to natural genre ambiguity and to small adversarial perturbations.