In [1]:
import sys

# Put the parent directory (presumably the repository root) on the import path so the
# `src.*` packages imported below resolve when this notebook runs from its own folder.
sys.path.extend([".."])

In [None]:
import ast
import json
import shutil
from collections import defaultdict
from pathlib import Path
from typing import List, Optional

import librosa
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import seaborn as sns
import torch
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm

from src.tcav.model import MusicGenreClassifier
from src.tcav.tcav import TCAV
from src.constants import GTZAN_PATH, METADATA_CSV_PATH, AUDIO_DATA_PATH, TCAV_RESULTS_PATH, GENRE_CLASSIFIER_MODEL_CHECKPOINT_PATH

# Lightning helper that seeds the Python / NumPy / torch RNGs so splits and sampling repeat.
pl.seed_everything(42)

# Global figure style: seaborn whitegrid plus large default fonts.
sns.set_style('whitegrid')
PLOT_RC_PARAMS = {
    'figure.figsize': (18, 8),
    'font.size': 32,
    'axes.titlesize': 28,
    'axes.labelsize': 25,
    'xtick.labelsize': 23,
    'ytick.labelsize': 23,
    'legend.fontsize': 21,
}
plt.rcParams.update(PLOT_RC_PARAMS)

# Workflow switches.
TRAIN_MODEL = False  # True: train a new classifier; False: load GENRE_CLASSIFIER_MODEL_CHECKPOINT_PATH
RUN_TCAV_ANALYSIS = False  # read by the TCAV analysis cell

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Device: {device}")

Seed set to 42


Device: cuda


## Load Concept-to-Tag Mapping

In [4]:
# Load the concept-category -> tag-list mapping that defines which dataset tags form each concept.
# A `with` block closes the file handle (the bare `json.load(open(...))` leaked it).
with open("../data/concepts_to_tags.json", "r") as f:
    CONCEPTS = json.load(f)

print("Available concept categories:")
for cat, tags in CONCEPTS.items():
    print(f"  {cat}: {len(tags)} tags (e.g., {tags[:3]})")

Available concept categories:
  tempo: 50 tags (e.g., ['medium tempo', 'slow tempo', 'fast tempo'])
  genre: 50 tags (e.g., ['rock', 'pop', 'electronic music'])
  mood: 50 tags (e.g., ['emotional', 'passionate', 'energetic'])
  instrument: 50 tags (e.g., ['acoustic drums', 'electric guitar', 'bass guitar'])


In [5]:
# Reverse lookup: tag -> the concept category that lists it (a tag listed under several
# categories maps to the last one, following CONCEPTS' iteration order).
TAG_TO_CATEGORY = {
    tag: category
    for category, category_tags in CONCEPTS.items()
    for tag in category_tags
}

In [6]:
def preprocess_audio(audio_array: np.ndarray, sr: int, target_sr: int = 16000, duration: float = 3.0) -> torch.Tensor:
    """Resample a waveform and force it to a fixed number of samples.

    Args:
        audio_array: waveform samples (assumed 1-D mono, as produced by ``librosa.load(mono=True)``).
        sr: sample rate of ``audio_array``.
        target_sr: sample rate to resample to; resampling is skipped when ``sr`` already matches.
        duration: output length in seconds.

    Returns:
        A float32 tensor of ``int(target_sr * duration)`` samples: longer inputs are truncated,
        shorter ones are zero-padded at the end.
    """
    waveform = audio_array if sr == target_sr else librosa.resample(audio_array, orig_sr=sr, target_sr=target_sr)

    n_target = int(target_sr * duration)
    shortfall = n_target - len(waveform)
    if shortfall < 0:
        waveform = waveform[:n_target]
    else:
        waveform = np.pad(waveform, (0, shortfall))

    return torch.from_numpy(waveform).float()

In [7]:
print("Loading GTZAN dataset from local files...")

# Genre name -> integer class label; the list order fixes label ids 0..9.
TARGET_GENRES = ['blues', 'classical', 'country', 'disco', 'hiphop',
                 'jazz', 'metal', 'pop', 'reggae', 'rock']
GENRE_MAP = {genre_name: label for label, genre_name in enumerate(TARGET_GENRES)}

print("Scanning audio files...")
# One record per .wav under GTZAN_PATH/<genre>/, file names sorted within each genre.
audio_files = [
    {'path': wav_path, 'genre': genre_name, 'label': GENRE_MAP[genre_name]}
    for genre_name in TARGET_GENRES
    for wav_path in sorted((GTZAN_PATH / genre_name).glob("*.wav"))
]

print(f"Loaded {len(audio_files)} audio files")
print(f"Genres: {TARGET_GENRES}")
print(f"Files per genre: ~{len(audio_files) // len(TARGET_GENRES)}")

Loading GTZAN dataset from local files...
Scanning audio files...
Loaded 1000 audio files
Genres: ['blues', 'classical', 'country', 'disco', 'hiphop', 'jazz', 'metal', 'pop', 'reggae', 'rock']
Files per genre: ~100


In [8]:
print("Loading and preprocessing GTZAN audio files...")
clip_tensors, clip_labels = [], []

for entry in tqdm(audio_files, desc="Loading audio"):
    clip_path = entry['path']
    # Unreadable files are reported and skipped rather than aborting the whole load.
    try:
        waveform, native_sr = librosa.load(clip_path, sr=None, mono=True)
    except Exception as e:
        print(f"Error loading {clip_path}: {e}")
        continue

    clip_tensors.append(preprocess_audio(waveform, native_sr))
    clip_labels.append(entry['label'])

# X_train holds every successfully loaded clip; it is split into train/val only when training.
X_train = torch.stack(clip_tensors)
y_train = torch.tensor(clip_labels)

print(f"Training data shape: {X_train.shape}")
print(f"Labels shape: {y_train.shape}")

Loading and preprocessing GTZAN audio files...


  audio, sr = librosa.load(file_info['path'], sr=None, mono=True)
	Deprecated as of librosa version 0.10.0.
	It will be removed in librosa version 1.0.
  y, sr_native = __audioread_load(path, offset, duration, dtype)
Loading audio:  56%|█████▋    | 564/1000 [00:18<00:06, 65.17it/s]

Error loading /media/bruno/B438-3BD6/datasets/GTZAN/Data/genres_original/jazz/jazz.00054.wav: 


Loading audio: 100%|██████████| 1000/1000 [00:23<00:00, 42.65it/s]


Training data shape: torch.Size([999, 48000])
Labels shape: torch.Size([999])


## Build and Train (or Load) the Genre Classifier

In [9]:
# Freshly initialised classifier sized for len(TARGET_GENRES) genres; its weights may be
# replaced by a trained checkpoint in the cell below.
model = MusicGenreClassifier(num_genres=len(TARGET_GENRES))
n_params = sum(param.numel() for param in model.parameters())
print(f"Model params: {n_params:,}")

Model params: 128,714


In [10]:
# Make sure the checkpoint directory exists before training writes to it.
checkpoint_dir = GENRE_CLASSIFIER_MODEL_CHECKPOINT_PATH.parent
checkpoint_dir.mkdir(parents=True, exist_ok=True)

training_status = 'ENABLED' if TRAIN_MODEL else 'DISABLED'
print(f"Training mode: {training_status}")
print(f"Model checkpoint path: {GENRE_CLASSIFIER_MODEL_CHECKPOINT_PATH}")

Training mode: DISABLED
Model checkpoint path: ../models/best-genre-classifier.ckpt


In [11]:
if TRAIN_MODEL:
    print("Training genre classifier on GTZAN...")

    # Random 80/20 split drawn from the global torch RNG (seeded by pl.seed_everything above).
    indices = torch.randperm(len(X_train))
    train_size = int(0.8 * len(X_train))
    train_indices = indices[:train_size]
    val_indices = indices[train_size:]

    X_train_split = X_train[train_indices]
    y_train_split = y_train[train_indices]
    X_val = X_train[val_indices]
    y_val = y_train[val_indices]

    train_dataset = TensorDataset(X_train_split, y_train_split)
    val_dataset = TensorDataset(X_val, y_val)

    train_loader = DataLoader(
        train_dataset,
        batch_size=16,
        shuffle=True,
        num_workers=4,
        pin_memory=True
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=16,
        shuffle=False,
        num_workers=4,
        pin_memory=True
    )

    # Keep the single best epoch by validation accuracy (plus last.ckpt for resuming).
    checkpoint_callback = pl.callbacks.ModelCheckpoint(
        dirpath=GENRE_CLASSIFIER_MODEL_CHECKPOINT_PATH.parent,
        filename='genre_classifier_best',
        monitor='val_acc',
        mode='max',
        save_top_k=1,
        save_last=True
    )

    trainer = pl.Trainer(
        max_epochs=100,
        accelerator='auto',
        devices=1,
        callbacks=[
            checkpoint_callback,
            pl.callbacks.EarlyStopping(
                monitor='val_loss',
                patience=20,
                mode='min'
            )
        ],
    )

    trainer.fit(model, train_loader, val_loader)

    best_path = checkpoint_callback.best_model_path
    if not best_path:
        raise RuntimeError("Training finished without saving a best checkpoint (was 'val_acc' logged?)")

    print("Model training complete!")
    print(f"Best model saved to: {best_path}")

    # ModelCheckpoint names the file 'genre_classifier_best*.ckpt', but the load branch below
    # reads GENRE_CLASSIFIER_MODEL_CHECKPOINT_PATH. Copy the best weights there so the next run
    # with TRAIN_MODEL=False loads this model instead of a stale or missing checkpoint.
    if Path(best_path).resolve() != GENRE_CLASSIFIER_MODEL_CHECKPOINT_PATH.resolve():
        shutil.copyfile(best_path, GENRE_CLASSIFIER_MODEL_CHECKPOINT_PATH)
        print(f"Best model copied to: {GENRE_CLASSIFIER_MODEL_CHECKPOINT_PATH}")

    model = MusicGenreClassifier.load_from_checkpoint(
        best_path,
        num_genres=len(TARGET_GENRES)
    )

else:
    print("Loading pre-trained model from checkpoint...")

    if not GENRE_CLASSIFIER_MODEL_CHECKPOINT_PATH.exists():
        print(f"Checkpoint not found at {GENRE_CLASSIFIER_MODEL_CHECKPOINT_PATH}")
        print("Please set TRAIN_MODEL=True to train a new model first.")
        raise FileNotFoundError(f"Model checkpoint not found: {GENRE_CLASSIFIER_MODEL_CHECKPOINT_PATH}")

    model = MusicGenreClassifier.load_from_checkpoint(
        GENRE_CLASSIFIER_MODEL_CHECKPOINT_PATH,
        num_genres=len(TARGET_GENRES)
    )
    print(f"Model loaded from: {GENRE_CLASSIFIER_MODEL_CHECKPOINT_PATH}")

model = model.to(device)
model.eval()

print(f"Model ready on {device}")

Loading pre-trained model from checkpoint...
Model loaded from: ../models/best-genre-classifier.ckpt
Model ready on cuda


In [13]:
# NOTE: the 80/20 split drawn in the training cell is not persisted (and is not drawn at all when
# TRAIN_MODEL=False), so this evaluates on EVERY loaded GTZAN clip, including the ones the model
# was trained on. Treat the number as an optimistic upper bound, not held-out validation accuracy.
print("Evaluating model accuracy on the full GTZAN set (train + validation clips)...")
eval_loader = DataLoader(
    TensorDataset(X_train, y_train),
    batch_size=16,
    shuffle=False,
    num_workers=4,
    pin_memory=True
)

model.eval()  # re-assert inference mode so re-running this cell alone is safe
correct = 0
total = 0
with torch.no_grad():
    for inputs, labels in tqdm(eval_loader, desc="Evaluating"):
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs, return_bottleneck=False)
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
accuracy = correct / total
print(f"Accuracy on full GTZAN (includes training clips): {accuracy * 100:.2f}%")

Evaluating model accuracy on GTZAN validation set...


Evaluating: 100%|██████████| 63/63 [00:00<00:00, 122.91it/s]

Validation Accuracy on GTZAN: 86.19%





## Load Tagged Audio for Concept Datasets

In [None]:
print("Loading generated audio dataset from local files...")

metadata_df = pd.read_csv(METADATA_CSV_PATH)
# `aspect_list` holds the string repr of a Python list; ast.literal_eval parses literals only
# (it does not execute code), so this is safe on file contents.
metadata_df['aspect_list_parsed'] = metadata_df['aspect_list'].apply(
    lambda x: ast.literal_eval(x) if pd.notna(x) else []
)

print(f"Loaded {len(metadata_df)} audio samples")
print(f"Audio files location: {AUDIO_DATA_PATH}")
print(f"\nSample aspects: {metadata_df['aspect_list_parsed'].iloc[0][:5]}...")


def _load_clip(audio_path) -> torch.Tensor:
    """Load one audio file as mono at its native rate and preprocess it to a fixed-length tensor."""
    audio, sr = librosa.load(audio_path, sr=None, mono=True)
    return preprocess_audio(audio, sr)


def get_audio_by_tags(tag: str, num_samples: int) -> Optional[List[torch.Tensor]]:
    """Load up to `num_samples` preprocessed clips whose aspect list contains `tag`.

    Matching rows are tried in metadata order; rows whose file is missing or fails to load are
    skipped and the next match is tried, so a few bad files do not cause a spurious skip.

    Returns:
        Exactly `num_samples` tensors, or None (after printing a message) if fewer could be loaded.
    """
    samples = []
    has_tag = [tag in aspects for aspects in metadata_df['aspect_list_parsed']]

    for filename in metadata_df.loc[has_tag, 'filename']:
        if len(samples) >= num_samples:
            break
        audio_path = AUDIO_DATA_PATH / filename
        if not audio_path.exists():
            continue
        try:
            samples.append(_load_clip(audio_path))
        except Exception as e:
            print(f"Error loading {audio_path}: {e}")

    if num_samples > len(samples):
        print(f"Skipping concept {tag}: Found {len(samples)}/{num_samples} samples.")
        return None

    return samples[:num_samples]


def create_concept_dataset_from_audio(concept: str, num_samples: int = 20) -> Optional[torch.Tensor]:
    """Stack `num_samples` clips tagged with `concept` into one tensor, or return None if too few exist."""
    samples = get_audio_by_tags(concept, num_samples)
    if samples is None:
        return None
    return torch.stack(samples)


def create_random_audio_dataset(num_samples: int = 30) -> torch.Tensor:
    """Stack up to `num_samples` clips drawn uniformly without replacement from the metadata.

    Uses the global NumPy RNG. Missing or unreadable files are skipped (with a message), so the
    result may hold fewer clips than requested.

    Raises:
        RuntimeError: if no clip at all could be loaded.
    """
    samples = []
    indices = np.random.choice(len(metadata_df), min(num_samples, len(metadata_df)), replace=False)

    for idx in indices:
        # `idx` is positional (from np.random.choice), hence iloc.
        audio_path = AUDIO_DATA_PATH / metadata_df['filename'].iloc[idx]
        if not audio_path.exists():
            continue
        try:
            samples.append(_load_clip(audio_path))
        except Exception as e:
            print(f"Error loading {audio_path}: {e}")

    if not samples:
        raise RuntimeError(f"No random audio samples could be loaded from {AUDIO_DATA_PATH}")
    if len(samples) < num_samples:
        print(f"Warning: loaded only {len(samples)}/{num_samples} random audio samples.")

    return torch.stack(samples)

print("Audio loading functions defined.")

## Run TCAV Analysis

In [None]:
# Wrap the classifier (moved to `device` and switched to eval mode in the model cell) in the
# project's TCAV helper; its get_activations / train_cav / compute_tcav_score methods are used below.
tcav = TCAV(model, device)
print("TCAV analyzer initialized with improved implementation.")

In [None]:
# Every aspect tag that appears anywhere in the metadata (useful when choosing concepts).
all_aspects = set().union(*metadata_df['aspect_list_parsed'])

print(f"Total unique aspects in dataset: {len(all_aspects)}")
print(f"Sample aspects: {sorted(all_aspects)[:20]}")

# Concept categories (keys of concepts_to_tags.json) whose tags are tested with TCAV.
# Fail here with a clear message instead of letting a missing key become None and crash
# later as an opaque TypeError when the sweep slices the concept list.
ANALYSIS_CATEGORIES = ['tempo', 'instrument']
missing_categories = [c for c in ANALYSIS_CATEGORIES if c not in CONCEPTS]
if missing_categories:
    raise KeyError(f"Concept categories missing from concepts_to_tags.json: {missing_categories}")

ANALYSIS_CONCEPTS = {category: CONCEPTS[category] for category in ANALYSIS_CATEGORIES}

In [None]:
# Sweep knobs (previously inline magic numbers).
MAX_CONCEPTS_PER_CATEGORY = 10  # only the first N tags of each category are tested
NUM_CONCEPT_SAMPLES = 50        # positive clips per concept
NUM_RANDOM_SAMPLES = 100        # shared negative (random) clips
NUM_CAV_RUNS = 40               # passed to tcav.train_cav as num_runs

# Run the (expensive) sweep only when asked; otherwise reuse the cached JSON.
# (The flag was previously inverted, so the default False re-ran and overwrote the cache.)
if RUN_TCAV_ANALYSIS:
    random_data = create_random_audio_dataset(num_samples=NUM_RANDOM_SAMPLES)
    random_acts = tcav.get_activations(random_data)

    # Genre activations do not depend on the concept, so compute them once instead of once per
    # (concept, genre) pair. Assumes compute_tcav_score does not modify its activation
    # argument (TODO confirm in src/tcav/tcav.py).
    genre_acts = {
        genre_name: tcav.get_activations(X_train[y_train == GENRE_MAP[genre_name]])
        for genre_name in TARGET_GENRES
    }

    results = defaultdict(dict)
    for category, concept_list in ANALYSIS_CONCEPTS.items():
        print(f"Category: {category.upper()}")
        print("-" * 40)

        for concept in concept_list[:MAX_CONCEPTS_PER_CATEGORY]:
            print(f"  Loading audio for '{concept}'...")
            concept_data = create_concept_dataset_from_audio(concept, num_samples=NUM_CONCEPT_SAMPLES)

            if concept_data is None:
                print(f"  Skipping concept '{concept}' due to insufficient samples.")
                continue

            concept_acts = tcav.get_activations(concept_data)

            cav_result = tcav.train_cav(concept_acts, random_acts, num_runs=NUM_CAV_RUNS)

            if cav_result['cav'] is None:
                print(f"  Skipping concept '{concept}' due to low CAV accuracy ({cav_result['accuracy']:.3f}).")
                continue

            # float() keeps json.dump working even if TCAV returns numpy/torch scalars.
            genre_scores = {
                genre_name: float(tcav.compute_tcav_score(acts, cav_result['cav'], method='cosine'))
                for genre_name, acts in genre_acts.items()
            }

            results[category][concept] = {
                'cav_accuracy': float(cav_result['accuracy']),
                'genre_scores': genre_scores
            }

            print(f"  {concept}: CAV acc={cav_result['accuracy']:.3f}")

    Path(TCAV_RESULTS_PATH).parent.mkdir(parents=True, exist_ok=True)
    with open(TCAV_RESULTS_PATH, "w") as f:
        json.dump(results, f, indent=4)
    print(f"TCAV results saved to: {TCAV_RESULTS_PATH}")
else:
    try:
        with open(TCAV_RESULTS_PATH, "r") as f:
            results = json.load(f)
    except FileNotFoundError as err:
        raise FileNotFoundError(
            f"TCAV results not found at {TCAV_RESULTS_PATH}; set RUN_TCAV_ANALYSIS=True to compute them."
        ) from err
    print(f"TCAV results loaded from: {TCAV_RESULTS_PATH}")

## Visualize Results

In [None]:
# Concept x genre matrix of TCAV scores. Columns follow TARGET_GENRES; a genre missing from a
# concept's scores is shown as 0.0. (The unused `all_concepts` list was removed.)
score_rows = []
concept_names = []
for category, concept_dict in results.items():
    for concept, data in concept_dict.items():
        score_rows.append([data['genre_scores'].get(g, 0.0) for g in TARGET_GENRES])
        concept_names.append(f"{concept}\n({category})")

if not score_rows:
    raise ValueError("No TCAV results to plot: `results` is empty or every concept was skipped.")
score_matrix = np.array(score_rows)

fig, ax = plt.subplots(figsize=(22, 22))
sns.heatmap(
    score_matrix,
    xticklabels=TARGET_GENRES,
    yticklabels=concept_names,
    annot=True,
    fmt='.2f',
    vmin=0, vmax=1,
    ax=ax,
)
ax.set_title('TCAV Scores: Concept Importance per Genre', fontweight='bold')
ax.set_xlabel('Genre')
ax.set_ylabel('Concept')
plt.tight_layout()
plt.show()

In [None]:
fig, ax = plt.subplots(figsize=(18, 14))

# One bubble per (concept, genre). Columns are looked up by genre NAME in TARGET_GENRES order,
# matching the heatmap, instead of relying on the scores dict's iteration order.
concept_names = []
bubble_x, bubble_y, bubble_scores = [], [], []
for category, concept_dict in results.items():
    for concept, data in concept_dict.items():
        row = len(concept_names)
        for col, genre in enumerate(TARGET_GENRES):
            bubble_x.append(col)
            bubble_y.append(row)
            bubble_scores.append(data['genre_scores'].get(genre, 0.0))
        concept_names.append(f"{concept} ({category})")

bubble_scores = np.asarray(bubble_scores, dtype=float)
ax.scatter(
    bubble_x, bubble_y,
    s=np.clip(bubble_scores, 0.0, None) * 500,  # scale for visibility; negative sizes are invalid
    c=bubble_scores, cmap=plt.cm.RdYlGn, vmin=0, vmax=1,  # red (low) to green (high)
    alpha=0.8, edgecolors='black', linewidth=0.5,
)

ax.set_xticks(range(len(TARGET_GENRES)))
ax.set_xticklabels(TARGET_GENRES, rotation=45, ha='right')
ax.set_yticks(range(len(concept_names)))
ax.set_yticklabels(concept_names)
ax.set_xlabel('Genre', fontsize=12, fontweight='bold')
ax.set_ylabel('Concept', fontsize=12, fontweight='bold')
ax.set_title('TCAV Scores: Concept Importance per Genre\n(size & color indicate strength)',
             fontsize=14, fontweight='bold')
ax.grid(axis='x', alpha=0.3, linestyle='--')
ax.set_xlim(-0.5, len(TARGET_GENRES) - 0.5)

sm = plt.cm.ScalarMappable(cmap=plt.cm.RdYlGn, norm=plt.Normalize(vmin=0, vmax=1))
sm.set_array([])
cbar = plt.colorbar(sm, ax=ax, pad=0.02)
cbar.set_label('TCAV Score', fontsize=10)

plt.tight_layout()
plt.show()

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import LinearSegmentedColormap

PLOT_GENRES = ['blues', 'country', 'jazz', 'pop', 'reggae', 'metal', 'rock']

PLOT_CONCEPTS = {
    'tempo': ['slow tempo', 'moderate tempo', 'fast tempo'],
    'instrument': ['bass guitar', 'piano', 'acoustic guitar', 'punchy kick']
}

# (concept, category, score per PLOT_GENRES entry) for each selected concept present in
# `results`, in results order.
selected_concepts = [
    (concept, category, [data['genre_scores'].get(g, 0.0) for g in PLOT_GENRES])
    for category, concept_dict in results.items()
    if category in PLOT_CONCEPTS
    for concept, data in concept_dict.items()
    if concept in PLOT_CONCEPTS[category]
]

# Concept groups ordered by mean score across the plotted genres, highest first.
rank_order = np.argsort([np.mean(scores) for _, _, scores in selected_concepts])[::-1]
ranked_concepts = [selected_concepts[k] for k in rank_order]

fig, ax = plt.subplots(figsize=(17, 7))
group_centers = np.arange(len(ranked_concepts))
bar_width = 0.12
n_genres = len(PLOT_GENRES)

for center, (_, _, scores) in zip(group_centers, ranked_concepts):
    # Within a group, bars run from lowest to highest score (stable for ties), each labelled by genre.
    ascending = sorted(zip(PLOT_GENRES, scores), key=lambda pair: pair[1])
    for slot, (genre_name, score) in enumerate(ascending):
        position = center + (slot - n_genres / 2) * bar_width + bar_width / 2
        ax.bar(position, score, bar_width,
               color='coral', edgecolor='black',
               linewidth=1.2, alpha=0.7)
        ax.text(position, 0.02, genre_name.upper(),
                rotation=90, ha='center', va='bottom',
                fontsize=16, fontweight='bold',
                color='black')

ax.set_xlabel('', fontweight='bold')
ax.set_ylabel('TCAV Score', fontweight='bold')
ax.set_title('Concept Importance Across Genres\n(Higher → more important)',
             fontweight='bold', pad=20)
ax.set_xticks(group_centers)
ax.set_xticklabels([name for name, _, _ in ranked_concepts],
                   rotation=45, ha='right')
ax.grid(axis='y', alpha=0.3)

plt.savefig("../docs/assets/concept_importance_across_genres.pdf", bbox_inches='tight')
plt.show()

In [None]:
import matplotlib.patches as mpatches

PLOT_GENRES = ['blues', 'country', 'jazz', 'pop', 'reggae', 'metal', 'rock']

PLOT_CONCEPTS = {
    'tempo': ['slow tempo', 'moderate tempo', 'fast tempo'],
    'instrument': ['bass guitar', 'acoustic guitar', 'punchy kick']
}

all_plot_concepts = [c for cat_list in PLOT_CONCEPTS.values() for c in cat_list]
unique_concepts = sorted(set(all_plot_concepts))
petroff_colors = ["#3f90da", "#ffa90e", "#bd1f01", "#94a4a2", "#832db6", "#a96b59", "#e76300", "#b9ac70", "#717581", "#92dadd"]
# Colours come from the full selection so each concept keeps a stable colour; the palette cycles
# so more than len(petroff_colors) concepts cannot raise KeyError (zip used to truncate).
concept_colors = {
    concept: petroff_colors[k % len(petroff_colors)] for k, concept in enumerate(unique_concepts)
}

# (concept, per-genre score dict) for each selected concept present in `results`, in results order.
# Concepts skipped during the TCAV sweep are simply absent. Collected once, not once per genre.
plotted = [
    (concept, data['genre_scores'])
    for category, concept_dict in results.items()
    if category in PLOT_CONCEPTS
    for concept, data in concept_dict.items()
    if concept in PLOT_CONCEPTS[category]
]

fig, ax = plt.subplots()
bar_width = 0.15
x_indices = np.arange(len(PLOT_GENRES))

for i, genre in enumerate(PLOT_GENRES):
    # Within each genre group, bars run from lowest to highest score (stable for ties).
    genre_data = sorted(
        ((concept, scores.get(genre, 0.0)) for concept, scores in plotted),
        key=lambda item: item[1],
    )
    # Centre the group of bars on tick position i.
    total_group_width = len(genre_data) * bar_width
    start_x = i - (total_group_width / 2) + (bar_width / 2)

    for j, (concept, score) in enumerate(genre_data):
        ax.bar(start_x + j * bar_width, score, width=bar_width, edgecolor='black',
               linewidth=1.2, color=concept_colors[concept], alpha=0.8)

ax.set_xticks(x_indices)
ax.set_xticklabels([g.upper() for g in PLOT_GENRES], fontweight='bold')
ax.set_ylabel('TCAV Score', fontweight='bold')
ax.set_title('Concept Influence on Genres\n(Higher → more important)', fontweight='bold', pad=20)
ax.grid(axis='y', alpha=0.3, linestyle='--')

# Legend only for concepts that actually have bars (skipped concepts would be misleading).
plotted_concepts = {concept for concept, _ in plotted}
legend_patches = [
    mpatches.Patch(color=concept_colors[c], label=c) for c in unique_concepts if c in plotted_concepts
]
ax.legend(handles=legend_patches,
          title='Concepts',
          bbox_to_anchor=(1.01, 1),
          loc='upper left',
          frameon=True)

plt.tight_layout()
plt.savefig("../docs/assets/concept_influence_across_genres.pdf", bbox_inches='tight')
plt.show()

In [None]:
from matplotlib.patches import Patch

fig, ax = plt.subplots(figsize=(20, 15))

# Bar colour per concept category; unknown categories fall back to gray.
category_colors = {'tempo': 'steelblue', 'instrument': 'coral', 'mood': 'forestgreen'}

accuracies, labels, colors = [], [], []
plotted_categories = []  # categories with at least one bar, in plot order (drives the legend)
for category, concept_dict in results.items():
    for concept, data in concept_dict.items():
        accuracies.append(data['cav_accuracy'])
        labels.append(concept)
        colors.append(category_colors.get(category, 'gray'))
    if concept_dict:
        plotted_categories.append(category)

ax.bar(range(len(accuracies)), accuracies, color=colors, alpha=0.8)
ax.set_xticks(range(len(labels)))
ax.set_xticklabels(labels, rotation=45, ha='right')
ax.set_ylabel('CAV Classifier Accuracy')
ax.set_title('CAV Training Accuracy per Concept', fontweight='bold')
ax.set_ylim([0, 1])
ax.grid(axis='y', alpha=0.3)

# NOTE(review): 0.5 is chance level only for a balanced concept-vs-random split. The TCAV cell
# uses 50 concept clips vs 100 random clips, so unless TCAV.train_cav rebalances, a majority-class
# classifier already scores about 0.67. Confirm in src/tcav/tcav.py.
baseline = ax.axhline(y=0.5, color='r', linestyle='--', alpha=0.5, label='Random baseline')

# A single legend holding both the category colours and the baseline: calling ax.legend() twice
# makes the second call replace the first, which used to drop the baseline entry.
legend_handles = [
    Patch(facecolor=category_colors.get(cat, 'gray'), label=cat) for cat in plotted_categories
]
legend_handles.append(baseline)
ax.legend(handles=legend_handles, loc='upper right')

plt.tight_layout()
plt.show()