# TCAV Analysis on CLAP Model for Genre Concepts
## Using MusicCaps Dataset to Interpret Genre-Specific Concept Importance

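This notebook performs Testing with Concept Activation Vectors (TCAV) analysis on the pre-trained CLAP model to understand how the model's internal representations respond to different music genres (pop, rock, classical). We use the captions of the MusicCaps dataset, grouped by genre tag, together with hand-written genre prompts to create text-based concepts, and analyze how strongly those concepts are expressed across the layers of CLAP's text encoder. Only text is passed through the model in this notebook; no audio is encoded.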

## 1. Import Required Libraries

In [None]:
import json
import os
import sys
import warnings
from pathlib import Path
from typing import Dict, List, Tuple, Optional

warnings.filterwarnings('ignore')

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F

# Visualization
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.cm import get_cmap

# Model and data loading
from transformers import ClapModel, ClapProcessor, AutoTokenizer
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score, roc_auc_score
from scipy import stats

# Audio processing
import librosa
import soundfile as sf
from tqdm import tqdm

# Compute device: use the GPU whenever CUDA reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Reproducibility: a single seed shared by NumPy, PyTorch and (later) the
# LogisticRegression probes, which receive it through `random_state`.
random_state = 42
np.random.seed(random_state)
torch.manual_seed(random_state)

# Notebook-wide plotting defaults.
sns.set_style("whitegrid")
plt.rcParams.update({
    'figure.figsize': (12, 6),
    'font.size': 10,
})

## 2. Load and Explore MusicCaps Dataset

In [None]:
# Load MusicCaps dataset with descriptions
musiccaps_path = Path("../data/musiccaps_tags_to_description_dataset.csv")
df = pd.read_csv(musiccaps_path)

# Fail early, with a readable message, if the columns that the following cells
# index directly ('genre_tags' for genre assignment, 'caption' for the text
# fed to the model) are not present in the CSV.
required_columns = {'genre_tags', 'caption'}
missing_columns = required_columns - set(df.columns)
if missing_columns:
    raise KeyError(
        f"MusicCaps CSV is missing required column(s): {sorted(missing_columns)}"
    )

print(f"MusicCaps dataset shape: {df.shape}")
print(f"\nDataset columns: {df.columns.tolist()}")
print("\nFirst few rows:")
print(df.head())
print("\nDataset info:")
# DataFrame.info() writes its report to stdout itself and returns None, so
# wrapping it in print() would append a stray "None" line to the output.
df.info()

In [None]:
# Genres analysed throughout the notebook; the list order is also the
# priority order used when a single tag mentions more than one genre.
genres_of_interest = ['pop', 'rock', 'classical']

def extract_genre_from_tags(tags_str):
    """Return the first genre of interest mentioned in a tag string.

    The value is lower-cased and split on ", ". Tags are scanned left to
    right; within a tag the genres are tried in ``genres_of_interest`` order.
    A genre matches when its name occurs anywhere inside the tag (substring
    match), so a tag such as "soft rock" maps to "rock".

    Args:
        tags_str: Raw value from the ``genre_tags`` column; may be missing
            (``None`` / NaN).

    Returns:
        The matching genre name, or ``None`` when the value is missing or no
        tag mentions a genre of interest.
    """
    if pd.isna(tags_str):
        return None

    lowered_tags = str(tags_str).lower().split(', ')
    matches = (
        candidate
        for tag in lowered_tags
        for candidate in genres_of_interest
        if candidate in tag
    )
    return next(matches, None)

# Tag every row with the first genre of interest found in its genre tags
# (None when no tag matches).
df['primary_genre'] = df['genre_tags'].apply(extract_genre_from_tags)

# Create separate dataframes for each genre
genre_dfs = {}
for genre in genres_of_interest:
    genre_dfs[genre] = df[df['primary_genre'] == genre].head(50)  # Limit to 50 samples per genre
    print(f"\n{genre.upper()} samples: {len(genre_dfs[genre])}")
    print("Sample descriptions:")
    for idx, desc in enumerate(genre_dfs[genre]['caption'].head(3)):
        # str() guards against missing captions: a NaN is a float and cannot
        # be sliced, which would otherwise abort the whole cell.
        print(f"  {idx+1}. {str(desc)[:100]}...")

# Combine all samples
all_genre_samples = pd.concat([genre_dfs[genre] for genre in genres_of_interest], ignore_index=True)
print(f"\nTotal samples for analysis: {len(all_genre_samples)}")

In [None]:
# Two-panel overview of the analysis subset: genre counts (left) and
# caption-length histograms per genre (right).
fig, axes = plt.subplots(1, 2, figsize=(12, 4))
counts_ax, lengths_ax = axes

# Left panel: how many samples were kept for each genre.
genre_counts = all_genre_samples['primary_genre'].value_counts()
counts_ax.bar(genre_counts.index, genre_counts.values, color=['#FF6B6B', '#4ECDC4', '#45B7D1'])
counts_ax.set_title('Distribution of Genres in Analysis Dataset')
counts_ax.set_xlabel('Genre')
counts_ax.set_ylabel('Number of Samples')

# Right panel: caption lengths in characters. The length column is also
# stored on the combined frame so it can be summarised below.
all_genre_samples['caption_length'] = all_genre_samples['caption'].str.len()
for genre in genres_of_interest:
    lengths_ax.hist(genre_dfs[genre]['caption'].str.len(), alpha=0.5, label=genre, bins=20)
lengths_ax.set_title('Distribution of Caption Lengths by Genre')
lengths_ax.set_xlabel('Caption Length')
lengths_ax.set_ylabel('Frequency')
lengths_ax.legend()

plt.tight_layout()
plt.show()

print("\nCaption length statistics:")
caption_length_stats = all_genre_samples.groupby('primary_genre')['caption_length'].describe()
print(caption_length_stats)

## 3. Prepare Genre-Specific Text Concepts

In [None]:
# Hand-written text prompts that define each genre concept. Every genre gets
# the same number of prompts, and each prompt names the genre explicitly.
genre_concepts = {
    'pop': [
        'pop music with catchy melodies',
        'upbeat pop song with pop production',
        'pop genre with pop rhythms',
        'commercial pop music',
        'pop style music',
        'contemporary pop',
        'popular pop music',
    ],
    'rock': [
        'rock music with electric guitars',
        'rock genre with rock drums',
        'heavy rock sound',
        'rock and roll style',
        'rock guitar riffs',
        'loud rock song',
        'rock music style',
    ],
    'classical': [
        'classical music composition',
        'classical orchestra arrangement',
        'classical music with strings',
        'classical symphony style',
        'classical genre music',
        'classical instrumental',
        'classical orchestral piece',
    ]
}

# Echo the prompts as a numbered list per genre so the concept definitions are
# visible in the notebook output.
print("Genre Concepts Prepared:")
for genre_name in genre_concepts:
    print(f"\n{genre_name.upper()}:")
    for number, prompt in enumerate(genre_concepts[genre_name], start=1):
        print(f"  {number}. {prompt}")

## 4. Load Pre-trained CLAP Model

In [None]:
# Load the pre-trained CLAP checkpoint and its processor.
model_name = "laion/clap-htsat-unfused"
print("Loading CLAP model...")

model = ClapModel.from_pretrained(model_name)
model = model.to(device)
processor = ClapProcessor.from_pretrained(model_name)
# Inference only: switch the model to evaluation mode.
model.eval()

print(f"Model loaded: {model_name}")
# The printed module tree is the reference for the layer names hooked later.
print("\nModel architecture:")
print(model)

In [None]:
# Shared store filled by the forward hooks: {layer name: detached output tensor
# of the most recent forward pass}.
layer_activations = {}

def get_activation_hook(name):
    """Build a forward hook that records a module's output under ``name``.

    Args:
        name: Key under which the hooked module's output is stored in
            ``layer_activations``.

    Returns:
        A callable with the ``(module, input, output)`` forward-hook signature.
        On every forward pass it overwrites ``layer_activations[name]`` with
        the detached output tensor.
    """
    def hook(model, input, output):
        # Transformer encoder blocks frequently return a tuple whose first
        # element is the hidden-state tensor (extra items such as attention
        # weights may follow). Calling .detach() on the tuple itself would
        # raise AttributeError, so unwrap it first; plain tensors pass through.
        hidden_states = output[0] if isinstance(output, (tuple, list)) else output
        layer_activations[name] = hidden_states.detach()
    return hook

# Select layers for analysis (using text encoder layers as TCAV concepts are text-based).
# NOTE: in the Hugging Face CLAP implementation the text encoder keeps its
# transformer blocks in a ModuleList attribute named `layer` (singular), so
# the module paths are `text_model.encoder.layer.<i>`. Cross-check against the
# architecture printed when the model was loaded if a different checkpoint or
# library version is used.
selected_layers = [
    'text_model.encoder.layer.0',
    'text_model.encoder.layer.2',
    'text_model.encoder.layer.4',
    'text_model.encoder.layer.6',
    'text_model.encoder.layer.8',
    'text_model.encoder.layer.10',
    'text_model.encoder.layer.11',  # Final layer
]

# Build the name -> module lookup once instead of on every loop iteration.
available_modules = dict(model.named_modules())

# Register hooks for selected layers
handles = []
for layer_name in selected_layers:
    layer = available_modules.get(layer_name)
    if layer is None:
        # Make a wrong layer name visible here rather than as an obscure
        # failure several cells later.
        print(f"⚠ Layer not found in model, no hook registered: {layer_name}")
        continue
    handle = layer.register_forward_hook(get_activation_hook(layer_name))
    handles.append(handle)
    print(f"Hook registered for layer: {layer_name}")

print(f"\nTotal hooks registered: {len(handles)}")

# Without any hook no activations can be captured, so every later cell would
# fail; stop immediately with an actionable message.
if not handles:
    raise RuntimeError(
        "None of the selected layers exist in the model; "
        "check the names in `selected_layers` against model.named_modules()."
    )

## 5. Extract Text Embeddings for Genre Concepts

In [None]:
# Extract text embeddings for genre concepts
genre_embeddings = {}     # {genre: mean text embedding over the genre's prompts}
concept_activations = {}  # {genre: {layer: one activation row per prompt}}

print("Extracting text embeddings for genre concepts...\n")

with torch.no_grad():
    for genre, concepts in genre_concepts.items():
        print(f"Processing {genre.upper()} concepts...")
        prompt_embeddings = []
        concept_activations[genre] = {layer: [] for layer in selected_layers}

        for concept in concepts:
            # Tokenize and process text
            inputs = processor(text=[concept], return_tensors="pt", padding=True).to(device)

            # Clear previous activations so stale tensors from an earlier
            # forward pass can never be attributed to this prompt.
            layer_activations.clear()

            # Forward pass; the registered hooks fill `layer_activations`.
            outputs = model.get_text_features(**inputs)
            prompt_embeddings.append(outputs.cpu())

            # Store activations from each layer
            for layer_name in selected_layers:
                if layer_name in layer_activations:
                    activation = layer_activations[layer_name]
                    # Keep only the first token's vector as the sequence summary.
                    cls_activation = activation[:, 0, :]
                    concept_activations[genre][layer_name].append(cls_activation.cpu())

        # Average embeddings for each genre
        genre_embeddings[genre] = torch.cat(prompt_embeddings, dim=0).mean(dim=0, keepdim=True)

        # Stack the per-prompt activations WITHOUT averaging them. These rows
        # are the positive examples of the linear probe trained per layer;
        # collapsing them to their mean would leave that probe with a single
        # positive example.
        for layer_name in selected_layers:
            if concept_activations[genre][layer_name]:
                concept_activations[genre][layer_name] = torch.cat(
                    concept_activations[genre][layer_name], dim=0
                )

        print(f"  ✓ Extracted embeddings, shape: {genre_embeddings[genre].shape}\n")

print("Text embeddings extraction complete!")

## 6. Extract Text-Encoder Activations from Sample Descriptions

In [None]:
# Extract activations from MusicCaps descriptions for each genre
sample_activations = {}  # {genre: {layer: activations}}

# Cap on captions encoded per genre, kept small for efficiency.
max_samples_per_genre = 30

print("Extracting activations from MusicCaps descriptions...\n")

with torch.no_grad():
    for genre in genres_of_interest:
        print(f"Processing {genre.upper()} samples...")
        sample_activations[genre] = {layer: [] for layer in selected_layers}

        descriptions = genre_dfs[genre]['caption'].values[:max_samples_per_genre]
        processed_count = 0

        for desc in tqdm(descriptions, desc=f"  {genre}"):
            try:
                # Tokenize and process text description
                inputs = processor(text=[str(desc)], return_tensors="pt", padding=True).to(device)

                # Clear previous activations
                layer_activations.clear()

                # Forward pass; only the hook side effects are needed here.
                model.get_text_features(**inputs)

                # Store activations from each layer
                for layer_name in selected_layers:
                    if layer_name in layer_activations:
                        activation = layer_activations[layer_name]
                        # Keep only the first token's vector as the sequence summary.
                        cls_activation = activation[:, 0, :]
                        sample_activations[genre][layer_name].append(cls_activation.cpu())
                processed_count += 1
            except Exception as e:
                # Best effort: one bad caption should not abort the whole run,
                # but the failure is reported and excluded from the count.
                print(f"    Warning: Failed to process sample: {str(e)[:50]}")
                continue

        # Convert to tensors and stack
        for layer_name in selected_layers:
            if sample_activations[genre][layer_name]:
                sample_activations[genre][layer_name] = torch.cat(
                    sample_activations[genre][layer_name], dim=0
                )

        # Report what actually succeeded, not merely how many were attempted.
        print(f"  ✓ Processed {processed_count}/{len(descriptions)} samples\n")

print("Sample activations extraction complete!")

## 7. Compute TCAV Vectors for Genre Concepts

In [None]:
# TCAV computation: learn, per genre and layer, a linear decision boundary
# between the genre's concept prompts (positives) and captions of the other
# genres (negatives). The boundary normal is kept as the concept direction.
tcav_vectors = {}  # {genre: {layer: direction_vector}}
tcav_accuracies = {}  # {genre: {layer: accuracy}}

print("Computing TCAV vectors...\n")

for genre in genres_of_interest:
    print(f"Computing TCAV for {genre.upper()}...")
    tcav_vectors[genre] = {}
    tcav_accuracies[genre] = {}

    for layer_name in selected_layers:
        # Positive class: this genre (concept activations). A layer whose hook
        # never fired is still an empty list here, which has no .numpy().
        positive_acts = concept_activations[genre].get(layer_name)
        if not torch.is_tensor(positive_acts) or len(positive_acts) == 0:
            print(f"  ⚠ Skipping {layer_name}: no concept activations captured")
            continue
        positive_samples = positive_acts.numpy()

        # Negative class: other genres (sampled from other genre descriptions)
        negative_samples_list = []
        for other_genre in genres_of_interest:
            if other_genre == genre:
                continue
            other_acts = sample_activations[other_genre].get(layer_name)
            if torch.is_tensor(other_acts) and len(other_acts) > 0:
                negative_samples_list.append(other_acts.numpy())

        if not negative_samples_list:
            print(f"  ⚠ Skipping {layer_name}: insufficient negative samples")
            continue

        negative_samples = np.vstack(negative_samples_list)

        # Create binary classification dataset
        X = np.vstack([positive_samples, negative_samples])
        y = np.hstack([np.ones(len(positive_samples)), np.zeros(len(negative_samples))])

        # Normalize features
        scaler = StandardScaler()
        X_scaled = scaler.fit_transform(X)

        # Train linear classifier (TCAV)
        clf = LogisticRegression(max_iter=1000, random_state=random_state)
        clf.fit(X_scaled, y)

        # The classifier was fit on standardized features, so clf.coef_ lives
        # in z-score space: w · (x - mean) / scale == (w / scale) · x + const.
        # Dividing by the per-feature scale maps the boundary normal back to
        # raw activation space, which is the space of the unscaled activations
        # this direction is compared against when scores are computed.
        raw_direction = clf.coef_[0] / scaler.scale_
        tcav_vectors[genre][layer_name] = torch.tensor(raw_direction, dtype=torch.float32)

        # Accuracy on the training data itself (there is no held-out split),
        # so treat it as an optimistic separability indicator.
        y_pred = clf.predict(X_scaled)
        accuracy = accuracy_score(y, y_pred)
        tcav_accuracies[genre][layer_name] = accuracy

    print(f"  ✓ TCAV vectors computed for {len(tcav_vectors[genre])} layers\n")

print("TCAV vector computation complete!")

## 8. Calculate TCAV Scores

In [None]:
# Score each (genre, layer) pair by how well the genre's caption activations
# align with the learned concept direction: cosine similarity per sample,
# negative values clipped to zero, then averaged over the samples.

tcav_scores = {}  # {genre: {layer: score}}

print("Computing TCAV scores...\n")

with torch.no_grad():
    for genre in genres_of_interest:
        print(f"Computing scores for {genre.upper()}...")
        scores_for_genre = {}
        tcav_scores[genre] = scores_for_genre

        directions = tcav_vectors[genre]
        genre_activations = sample_activations[genre]

        for layer_name in selected_layers:
            # Layers without a learned direction get no entry at all.
            if layer_name not in directions:
                continue

            # A direction without any sample activations scores zero.
            layer_acts = genre_activations.get(layer_name)
            if layer_acts is None or len(layer_acts) == 0:
                scores_for_genre[layer_name] = 0.0
                continue

            direction = directions[layer_name].to(device)
            layer_acts = layer_acts.to(device)

            # Unit-length direction and unit-length activation rows, so the
            # matrix-vector product below yields cosine similarities.
            unit_direction = F.normalize(direction.unsqueeze(0), p=2, dim=1).squeeze(0)
            unit_acts = F.normalize(layer_acts, p=2, dim=1)
            alignment = unit_acts @ unit_direction

            # Mean of the positive part of the alignment.
            scores_for_genre[layer_name] = F.relu(alignment).mean().item()

        print("  ✓ Scores computed\n")

print("TCAV score computation complete!")

# Summary table: one row per genre, one column per layer.
tcav_df = pd.DataFrame(tcav_scores).T
print("\nTCAV Scores Summary:")
print(tcav_df.round(4))

## 9. Visualize TCAV Results

In [None]:
# Flatten the nested score / accuracy dictionaries into one record per
# (genre, layer) pair; layers without a score are left out.
tcav_data_for_plot = [
    {
        'genre': genre,
        'layer': int(layer_name.split('.')[-1]),
        'layer_name': layer_name,
        'tcav_score': tcav_scores[genre][layer_name],
        'accuracy': tcav_accuracies[genre].get(layer_name, 0.0),
    }
    for genre in genres_of_interest
    for layer_name in selected_layers
    if layer_name in tcav_scores[genre]
]

tcav_plot_df = pd.DataFrame(tcav_data_for_plot)

# 2x2 dashboard: score heatmap, score lines, accuracy heatmap, mean score bars.
fig, axes = plt.subplots(2, 2, figsize=(16, 12))
score_map_ax, score_line_ax = axes[0, 0], axes[0, 1]
accuracy_ax, average_ax = axes[1, 0], axes[1, 1]
title_style = {'fontsize': 12, 'fontweight': 'bold'}

# 1. TCAV Scores Heatmap (genre x layer)
heatmap_data = tcav_plot_df.pivot_table(values='tcav_score', index='genre', columns='layer')
sns.heatmap(heatmap_data, annot=True, fmt='.3f', cmap='RdYlGn', ax=score_map_ax, cbar_kws={'label': 'TCAV Score'})
score_map_ax.set_title('TCAV Scores Across Model Layers', **title_style)
score_map_ax.set_xlabel('Layer Index')
score_map_ax.set_ylabel('Genre')

# 2. TCAV Scores Line Plot (one line per genre, ordered by layer index)
for genre in genres_of_interest:
    genre_data = tcav_plot_df[tcav_plot_df['genre'] == genre].sort_values('layer')
    score_line_ax.plot(genre_data['layer'], genre_data['tcav_score'], marker='o', label=genre, linewidth=2)

score_line_ax.set_title('Genre Sensitivity Across Layers', **title_style)
score_line_ax.set_xlabel('Layer Index')
score_line_ax.set_ylabel('TCAV Score')
score_line_ax.legend()
score_line_ax.grid(True, alpha=0.3)

# 3. Classifier Accuracy (genre x layer)
accuracy_data = tcav_plot_df.pivot_table(values='accuracy', index='genre', columns='layer')
sns.heatmap(accuracy_data, annot=True, fmt='.3f', cmap='Blues', ax=accuracy_ax, cbar_kws={'label': 'Accuracy'})
accuracy_ax.set_title('Binary Classification Accuracy by Genre and Layer', **title_style)
accuracy_ax.set_xlabel('Layer Index')
accuracy_ax.set_ylabel('Genre')

# 4. Average TCAV Score by Genre, highest first, with the value above each bar
avg_scores = tcav_plot_df.groupby('genre')['tcav_score'].mean().sort_values(ascending=False)
colors = ['#FF6B6B', '#4ECDC4', '#45B7D1']
average_ax.bar(avg_scores.index, avg_scores.values, color=colors[:len(avg_scores)])
average_ax.set_title('Average TCAV Score by Genre', **title_style)
average_ax.set_ylabel('Average TCAV Score')
average_ax.set_xlabel('Genre')

for bar_position, bar_value in enumerate(avg_scores.values):
    average_ax.text(bar_position, bar_value + 0.01, f'{bar_value:.3f}', ha='center', va='bottom', fontweight='bold')

plt.tight_layout()
plt.show()

banner = "=" * 80
print(f"\n{banner}")
print("TCAV ANALYSIS SUMMARY")
print(banner)
print("\nAverage TCAV Scores by Genre:")
print(avg_scores.round(4))

## 10. Layer-wise Analysis and Interpretation

In [None]:
# Rank the layers of every genre by TCAV score and remember the best one.
banner = "=" * 80
print(banner)
print("LAYER-WISE IMPORTANCE ANALYSIS")
print(banner)

layer_importance = {}  # {genre: (best layer name, its score)}

for genre in genres_of_interest:
    print(f"\n{genre.upper()} GENRE:")
    print("-" * 40)

    # Highest score first; ties keep their original (layer) order.
    genre_scores = sorted(tcav_scores[genre].items(), key=lambda pair: pair[1], reverse=True)

    print("Top layers by TCAV score:")
    for rank, (layer, score) in enumerate(genre_scores[:3], start=1):
        layer_idx = int(layer.split('.')[-1])
        accuracy = tcav_accuracies[genre].get(layer, 0.0)
        print(f"  {rank}. Layer {layer_idx}: TCAV Score = {score:.4f}, Accuracy = {accuracy:.4f}")

    # Genres without any scored layer get no entry.
    if genre_scores:
        best_layer, best_score = genre_scores[0]
        layer_importance[genre] = (best_layer, best_score)

# One-line summary of the top layer per genre.
print(f"\n{banner}")
print("CRITICAL LAYERS PER GENRE")
print(banner)

for genre, (layer, score) in layer_importance.items():
    layer_idx = int(layer.split('.')[-1])
    print(f"{genre.upper():10} -> Layer {layer_idx} (TCAV Score: {score:.4f})")

## 11. Concept Vector Analysis

In [None]:
# Analyze the similarity between genre concept vectors
print("\n" + "="*80)
print("GENRE CONCEPT VECTOR SIMILARITY")
print("="*80)

# Select a middle layer for analysis
analysis_layer = selected_layers[len(selected_layers) // 2]
layer_idx = int(analysis_layer.split('.')[-1])

print(f"\nAnalyzing concept vectors at Layer {layer_idx}:\n")

# Normalize each genre's concept vector once. The small epsilon keeps an
# all-zero vector at zero instead of producing NaN; float64 keeps the
# self-similarity on the diagonal at 1 up to rounding.
unit_vectors = {}
for genre in genres_of_interest:
    if analysis_layer in tcav_vectors[genre]:
        vec = tcav_vectors[genre][analysis_layer].numpy().astype(np.float64)
        unit_vectors[genre] = vec / (np.linalg.norm(vec) + 1e-8)

# Pairwise cosine similarity = dot product of unit vectors. Pairs where either
# genre has no vector at this layer keep the initial value 0.
similarity_matrix = np.zeros((len(genres_of_interest), len(genres_of_interest)))

for i, genre1 in enumerate(genres_of_interest):
    for j, genre2 in enumerate(genres_of_interest):
        if genre1 in unit_vectors and genre2 in unit_vectors:
            similarity = float(np.dot(unit_vectors[genre1], unit_vectors[genre2]))
            # Guard the colour scale against tiny floating-point overshoot.
            similarity_matrix[i, j] = np.clip(similarity, -1.0, 1.0)

# Visualize similarity matrix
fig, ax = plt.subplots(figsize=(8, 7))
sns.heatmap(similarity_matrix, annot=True, fmt='.3f', cmap='coolwarm',
            xticklabels=genres_of_interest, yticklabels=genres_of_interest,
            ax=ax, cbar_kws={'label': 'Cosine Similarity'}, vmin=-1, vmax=1)
ax.set_title(f'Genre TCAV Vector Similarity at Layer {layer_idx}', fontsize=12, fontweight='bold')
plt.tight_layout()
plt.show()

print("Pairwise Similarity Matrix:")
for i, genre1 in enumerate(genres_of_interest):
    for j, genre2 in enumerate(genres_of_interest):
        print(f"  {genre1} <-> {genre2}: {similarity_matrix[i, j]:.4f}")

## 12. Key Findings and Interpretation

In [None]:
# Assemble the plain-text report: fixed methodology header first, then the
# computed sections, collected as fragments and appended in one step.
report = """
================================================================================
                    TCAV ANALYSIS REPORT: CLAP MODEL GENRES
================================================================================

METHODOLOGY:
-----------
Testing with Concept Activation Vectors (TCAV) is used to interpret how neural 
networks respond to human-understandable concepts (genres). For each genre, we:

1. Create text-based concept embeddings using genre-specific descriptions
2. Extract intermediate layer activations from sample descriptions
3. Train linear classifiers to learn the direction of each concept
4. Compute TCAV scores measuring layer sensitivity to genre concepts

RESULTS SUMMARY:
----------------
"""

report_fragments = []

# Genre-level statistics; layers without a score count as 0.
report_fragments.append("\nGenre-Level Statistics:\n")
for genre in genres_of_interest:
    genre_scores = [tcav_scores[genre].get(layer, 0) for layer in selected_layers]
    report_fragments.append(
        f"\n  {genre.upper()}:\n"
        f"    - Average TCAV Score: {np.mean(genre_scores):.4f}\n"
        f"    - Max TCAV Score: {np.max(genre_scores):.4f}\n"
        f"    - Min TCAV Score: {np.min(genre_scores):.4f}\n"
    )

# Layer-level analysis: score and probe accuracy of every genre per layer.
report_fragments.append("\n\nLayer-Level Analysis:\n")
report_fragments.append("-" * 40 + "\n")

for layer_name in selected_layers:
    layer_num = int(layer_name.split('.')[-1])
    report_fragments.append(f"\n  Layer {layer_num}:\n")
    for genre in genres_of_interest:
        score = tcav_scores[genre].get(layer_name, 0)
        accuracy = tcav_accuracies[genre].get(layer_name, 0)
        report_fragments.append(f"    {genre:10} - TCAV: {score:.4f}, Accuracy: {accuracy:.4f}\n")

# Key insights
report_fragments.append("\n\nKEY INSIGHTS:\n")
report_fragments.append("=" * 80 + "\n")

# Insight 1: the top-scoring layer recorded for each genre.
report_fragments.append("\n1. Most Discriminative Layers per Genre:\n")
for genre, (layer, score) in layer_importance.items():
    layer_idx = int(layer.split('.')[-1])
    report_fragments.append(
        f"   - {genre.upper():10} most responsive at Layer {layer_idx} (Score: {score:.4f})\n"
    )

# Insight 2: every unordered genre pair once (upper triangle of the matrix).
report_fragments.append("\n2. Genre Concept Distinctiveness:\n")
report_fragments.append("   (Higher similarity indicates more overlapping representations)\n")
for i, genre1 in enumerate(genres_of_interest):
    for j, genre2 in enumerate(genres_of_interest[i + 1:], start=i + 1):
        sim = similarity_matrix[i, j]
        if sim > 0.5:
            distinctiveness = "overlapping"
        else:
            distinctiveness = "distinct"
        report_fragments.append(
            f"   - {genre1.upper()} vs {genre2.upper()}: {sim:.4f} ({distinctiveness})\n"
        )

report_fragments.append("\n" + "=" * 80)

report += "".join(report_fragments)

print(report)

## 13. Detailed Concept Vector Magnitude Analysis

In [None]:
# Analyze the magnitude of the concept vectors across layers, one panel per genre.
# squeeze=False keeps `axes` indexable even when only one genre is analysed.
fig, axes = plt.subplots(1, len(genres_of_interest), figsize=(16, 5), squeeze=False)
axes = axes[0]

palette = ['#FF6B6B', '#4ECDC4', '#45B7D1']

concept_magnitudes = {}  # {genre: [vector norm per available layer]}
magnitude_layers = {}    # {genre: [layer index per entry of concept_magnitudes]}

for idx, genre in enumerate(genres_of_interest):
    magnitudes = []
    layer_indices = []

    # Only layers that actually have a concept vector contribute, so the two
    # lists are kept in lock-step to remember which layer each value is from.
    for layer_name in selected_layers:
        if layer_name in tcav_vectors[genre]:
            vec = tcav_vectors[genre][layer_name].numpy()
            magnitudes.append(np.linalg.norm(vec))
            layer_indices.append(int(layer_name.split('.')[-1]))

    concept_magnitudes[genre] = magnitudes
    magnitude_layers[genre] = layer_indices

    # Cycle through the palette so more than three genres do not raise IndexError.
    color = palette[idx % len(palette)]

    # Plot magnitude across layers
    axes[idx].plot(layer_indices, magnitudes, marker='o', linewidth=2, markersize=8, color=color)
    axes[idx].fill_between(layer_indices, magnitudes, alpha=0.3, color=color)
    axes[idx].set_title(f'{genre.upper()} Concept Vector Magnitude', fontsize=11, fontweight='bold')
    axes[idx].set_xlabel('Layer Index')
    axes[idx].set_ylabel('Vector Magnitude')
    axes[idx].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("\nConcept Vector Magnitudes by Genre and Layer:")
print("=" * 60)
for genre in genres_of_interest:
    print(f"\n{genre.upper()}:")
    mags = concept_magnitudes[genre]
    if not mags:
        # np.max / np.min raise on an empty sequence.
        print("  No concept vectors available")
        continue
    layers = magnitude_layers[genre]
    # Look the layer up in the list aligned with `mags`; indexing
    # `selected_layers` by position would name the wrong layer whenever a
    # layer was skipped during vector computation.
    print(f"  Mean magnitude: {np.mean(mags):.4f}")
    print(f"  Max magnitude:  {np.max(mags):.4f} (Layer {layers[int(np.argmax(mags))]})")
    print(f"  Min magnitude:  {np.min(mags):.4f} (Layer {layers[int(np.argmin(mags))]})")

## 14. Export Results

In [None]:
# Export results to CSV and JSON
output_dir = Path("../outputs/tcav_analysis")
output_dir.mkdir(parents=True, exist_ok=True)

# Export TCAV scores (one row per genre/layer pair)
tcav_export_df = tcav_plot_df.copy()
tcav_scores_path = output_dir / "tcav_scores.csv"
tcav_export_df.to_csv(tcav_scores_path, index=False)
print(f"✓ TCAV scores exported to: {tcav_scores_path}")

# Per-genre score lists over all selected layers, computed once and shared by
# the average and max entries below; layers without a score count as 0.
scores_by_genre = {
    genre: [tcav_scores[genre].get(layer, 0) for layer in selected_layers]
    for genre in genres_of_interest
}

# Export summary statistics. Values are cast to built-in float because NumPy
# scalars are not JSON serializable.
summary_stats = {
    'genre_average_scores': {genre: float(np.mean(scores)) for genre, scores in scores_by_genre.items()},
    'genre_max_scores': {genre: float(np.max(scores)) for genre, scores in scores_by_genre.items()},
    'critical_layers': {genre: {'layer': layer, 'score': float(score)}
                        for genre, (layer, score) in layer_importance.items()},
    'concept_similarity': {
        f"{g1}_vs_{g2}": float(similarity_matrix[i, j])
        for i, g1 in enumerate(genres_of_interest)
        for j, g2 in enumerate(genres_of_interest)
    }
}

# Explicit UTF-8 so the written files do not depend on the platform's default encoding.
summary_path = output_dir / "tcav_summary.json"
with open(summary_path, 'w', encoding='utf-8') as f:
    json.dump(summary_stats, f, indent=2)
print(f"✓ Summary statistics exported to: {summary_path}")

# Create a detailed report file
report_path = output_dir / "tcav_analysis_report.txt"
with open(report_path, 'w', encoding='utf-8') as f:
    f.write(report)
print(f"✓ Detailed report exported to: {report_path}")

print(f"\nAll results saved to: {output_dir}")
print("\nFiles created:")
print(f"  - {tcav_scores_path.name}")
print(f"  - {summary_path.name}")
print(f"  - {report_path.name}")

## 15. Cleanup and Final Notes

In [None]:
# Detach every forward hook that was registered on the model so later forward
# passes no longer write into the activation store.
for hook_handle in handles:
    hook_handle.remove()

print("✓ Hooks removed and memory cleaned up")

banner = "=" * 80
print(f"\n{banner}")
print("TCAV ANALYSIS COMPLETED SUCCESSFULLY")
print(banner)

# Closing guidance for the reader: follow-up steps and how to read each metric.
closing_notes = """
NEXT STEPS:
-----------
1. Review the generated visualizations and heatmaps
2. Examine the exported CSV and JSON files in outputs/tcav_analysis/
3. Consider the following questions:
   - Which layers are most important for genre distinction?
   - Are genre concepts well-separated or do they overlap?
   - How do different genres activate different parts of the model?
   
INTERPRETATION GUIDE:
--------------------
- TCAV Score: Higher scores indicate stronger concept presence at that layer
- Accuracy: Higher accuracy means the concept is more linearly separable
- Similarity: Higher similarity between genres suggests overlapping representations
- Vector Magnitude: Larger magnitudes indicate stronger concept expression

For more details, see the full report in outputs/tcav_analysis/tcav_analysis_report.txt
"""
print(closing_notes)