# Guitar Transcription Pipeline - Component-by-Component Debugging

This notebook breaks the 4-stage embedding pipeline (plus its fusion and decoding steps) into individual components for detailed inspection and debugging.

**Pipeline Overview:**
1. **Basic Pitch** - Spotify pretrained model (440 dims)
2. **Meta Encodec** - Audio compression (128 dims)
3. **VQ-VAE** - Discrete representation (64 dims + tokens)
4. **CLAP** - Semantic understanding (768 dims)
5. **Fusion** - Combined embeddings (768 dims)
6. **Audio Decoder** - Transcription (88 piano keys)

## Setup and Imports

In [None]:
# Standard library
import warnings
from collections import Counter
from typing import Any, Dict

# Third-party
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
import torch.nn as nn
# Silence every warning category for cleaner debug output. Note this also
# hides NumPy/PyTorch runtime warnings, so narrow it if a stage misbehaves.
warnings.filterwarnings('ignore')

# Seed both torch and NumPy so the synthetic audio and the random fallbacks
# used by later stages are repeatable.
SEED = 42
for _seed_fn in (torch.manual_seed, np.random.seed):
    _seed_fn(SEED)

# Prefer the GPU whenever one is visible to torch.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

# Import our models
import sys
sys.path.append('..')

from models.basic_pitch_wrapper import BasicPitchFeatureExtractor
from models.huggingface_encodec import HuggingFaceEncodec
from models.vq_vae import KenaVQVAE
from models.huggingface_clap import HuggingFaceCLAP
from models.embedding_validation_decoder import EmbeddingValidationDecoder

## Generate Synthetic Input Audio

Create a synthetic guitar-like audio signal for testing.

In [None]:
def create_synthetic_guitar_audio(duration=3.0, sample_rate=22050, seed=42):
    """
    Create a synthetic, guitar-like E minor chord for pipeline debugging.

    Each chord tone (E2, G3, B3, E4) is a fundamental plus 2nd and 3rd
    harmonics. Each tone has its own exponential decay and a shared 5 Hz
    vibrato, and the tones are mixed at decreasing volume. Light Gaussian
    noise is added and the result is peak-normalised to 0.7.

    Args:
        duration: Length of the signal in seconds.
        sample_rate: Samples per second.
        seed: Applied to both torch's and NumPy's global RNGs. The noise is
            drawn from NumPy's global RNG, so equal seeds give identical audio.

    Returns:
        1-D float32 tensor of length ``int(duration * sample_rate)``.
        An empty tensor is returned when that length is 0.
    """
    torch.manual_seed(seed)
    np.random.seed(seed)

    n_samples = int(duration * sample_rate)
    # Sample k sits at exactly k / sample_rate seconds. np.linspace(0, duration, n)
    # would include the endpoint and stretch the grid by n / (n - 1), slightly
    # detuning every partial.
    t = np.arange(n_samples) / sample_rate

    frequencies = [82.41, 196.0, 246.94, 329.63]  # E2, G3, B3, E4
    vibrato = 1 + 0.02 * np.sin(2 * np.pi * 5 * t)  # 5 Hz, shared by all strings
    audio = np.zeros_like(t)

    for i, freq in enumerate(frequencies):
        # Fundamental + 2nd/3rd harmonics
        partials = (
            np.sin(2 * np.pi * freq * t)
            + 0.3 * np.sin(2 * np.pi * 2 * freq * t)
            + 0.1 * np.sin(2 * np.pi * 3 * freq * t)
        )
        # Higher-pitched strings decay faster; volume falls with string index.
        envelope = np.exp(-t * (0.5 + i * 0.2))
        audio += partials * envelope * vibrato * (0.8 - i * 0.15)

    # Add some noise for realism
    audio += 0.02 * np.random.randn(n_samples)

    # Peak-normalise to 0.7. Guard against an empty or all-zero signal, where
    # np.max would raise or the division would produce NaN.
    peak = np.max(np.abs(audio)) if n_samples > 0 else 0.0
    if peak > 0:
        audio = audio / peak * 0.7

    return torch.tensor(audio, dtype=torch.float32)

# Create the synthetic test signal shared by every stage below.
audio_length = 3.0  # seconds
sample_rate = 22050
synthetic_audio = create_synthetic_guitar_audio(duration=audio_length, sample_rate=sample_rate)

print("Generated synthetic audio:")
print(f"  Shape: {synthetic_audio.shape}")
print(f"  Duration: {len(synthetic_audio) / sample_rate:.2f} seconds")
print(f"  Sample rate: {sample_rate} Hz")
print(f"  Range: [{synthetic_audio.min():.3f}, {synthetic_audio.max():.3f}]")

# Plot the waveform. Sample k is drawn at k / sample_rate seconds (the true
# sampling instant); np.linspace(0, audio_length, n) would include the endpoint
# and stretch the axis by n / (n - 1).
time_axis = np.arange(len(synthetic_audio)) / sample_rate
fig, ax = plt.subplots(figsize=(12, 4))
ax.plot(time_axis, synthetic_audio.numpy())
ax.set(title='Synthetic Guitar Audio (Em Chord)', xlabel='Time (seconds)', ylabel='Amplitude')
ax.grid(True, alpha=0.3)
fig.tight_layout()
plt.show()

# Add batch dimension for processing: [1, time]
audio_batch = synthetic_audio.unsqueeze(0)
print(f"\nBatched audio shape: {audio_batch.shape}")

## Stage 1: Basic Pitch Feature Extraction

Extract features using the real Spotify Basic Pitch model.

In [None]:
print("=" * 60)
print("STAGE 1: BASIC PITCH FEATURE EXTRACTION")
print("=" * 60)

# Initialize Basic Pitch
try:
    basic_pitch = BasicPitchFeatureExtractor(
        sample_rate=sample_rate,
        device='cpu'  # Use CPU for debugging
    )
    print("‚úÖ Basic Pitch initialized successfully")
    print(f"   Output dimension: {basic_pitch.output_dim}")
except Exception as e:
    print(f"‚ùå Basic Pitch initialization failed: {e}")
    basic_pitch = None

if basic_pitch is not None:
    # Extract features
    print("\nüîÑ Extracting Basic Pitch features...")
    try:
        pitch_features = basic_pitch(audio_batch)
        print(f"‚úÖ Features extracted successfully")
        print(f"   Shape: {pitch_features.shape}")
        print(f"   Data type: {pitch_features.dtype}")
        print(f"   Range: [{pitch_features.min():.3f}, {pitch_features.max():.3f}]")
        print(f"   Mean: {pitch_features.mean():.3f}")
        print(f"   Std: {pitch_features.std():.3f}")
        
        # Check for reasonable values
        has_nan = torch.isnan(pitch_features).any()
        has_inf = torch.isinf(pitch_features).any()
        print(f"   Contains NaN: {has_nan}")
        print(f"   Contains Inf: {has_inf}")
        
        # Visualize features
        plt.figure(figsize=(15, 8))
        
        # Plot feature heatmap (first 100 features for visibility)
        plt.subplot(2, 2, 1)
        features_to_show = pitch_features[0, :, :100].detach().numpy()
        plt.imshow(features_to_show.T, aspect='auto', origin='lower', cmap='viridis')
        plt.title('Basic Pitch Features (First 100 dims)')
        plt.xlabel('Time Frames')
        plt.ylabel('Feature Dimension')
        plt.colorbar()
        
        # Plot feature statistics over time
        plt.subplot(2, 2, 2)
        frame_means = pitch_features[0].mean(dim=1).detach().numpy()
        frame_stds = pitch_features[0].std(dim=1).detach().numpy()
        frames = np.arange(len(frame_means))
        plt.plot(frames, frame_means, label='Mean', alpha=0.7)
        plt.plot(frames, frame_stds, label='Std', alpha=0.7)
        plt.title('Feature Statistics Over Time')
        plt.xlabel('Time Frame')
        plt.ylabel('Value')
        plt.legend()
        plt.grid(True, alpha=0.3)
        
        # Plot feature distribution
        plt.subplot(2, 2, 3)
        all_features = pitch_features[0].flatten().detach().numpy()
        plt.hist(all_features, bins=50, alpha=0.7, density=True)
        plt.title('Feature Value Distribution')
        plt.xlabel('Feature Value')
        plt.ylabel('Density')
        plt.grid(True, alpha=0.3)
        
        # Plot activation patterns
        plt.subplot(2, 2, 4)
        activation_sum = pitch_features[0].sum(dim=0).detach().numpy()
        plt.plot(activation_sum[:88], label='Onset (0-87)', alpha=0.7)
        plt.plot(activation_sum[88:88+88], label='Note (88-175)', alpha=0.7)
        plt.plot(activation_sum[88+88:88+88+64], label='Contour (176-239)', alpha=0.7)
        plt.title('Activation Patterns by Output Type')
        plt.xlabel('Pitch/Feature Index')
        plt.ylabel('Total Activation')
        plt.legend()
        plt.grid(True, alpha=0.3)
        
        plt.tight_layout()
        plt.show()
        
    except Exception as e:
        print(f"‚ùå Feature extraction failed: {e}")
        pitch_features = None
else:
    pitch_features = None
    print("‚ö†Ô∏è Skipping Basic Pitch - using random features for testing")
    pitch_features = torch.randn(1, 129, 440)  # Approximate expected shape
    print(f"   Random features shape: {pitch_features.shape}")

## Stage 2: Meta Encodec Audio Compression

Compress audio using Meta's Encodec model.

In [None]:
print("=" * 60)
print("STAGE 2: META ENCODEC COMPRESSION")
print("=" * 60)

# Initialize Encodec
try:
    encodec = MetaEncodecWrapper(
        bandwidth=6.0,  # kbps
        sample_rate=sample_rate
    )
    print("‚úÖ Encodec initialized successfully")
    print(f"   Output dimension: {encodec.output_dim}")
    print(f"   Bandwidth: 6.0 kbps")
except Exception as e:
    print(f"‚ùå Encodec initialization failed: {e}")
    encodec = None

if encodec is not None:
    # Compress audio
    print("\nüîÑ Compressing audio with Encodec...")
    try:
        encodec_codes, encodec_embeddings = encodec(audio_batch)
        print(f"‚úÖ Audio compressed successfully")
        print(f"   Codes shape: {encodec_codes.shape}")
        print(f"   Embeddings shape: {encodec_embeddings.shape}")
        print(f"   Embeddings range: [{encodec_embeddings.min():.3f}, {encodec_embeddings.max():.3f}]")
        print(f"   Embeddings mean: {encodec_embeddings.mean():.3f}")
        print(f"   Embeddings std: {encodec_embeddings.std():.3f}")
        
        # Check codes statistics
        print(f"   Codes range: [{encodec_codes.min()}, {encodec_codes.max()}]")
        print(f"   Unique codes: {torch.unique(encodec_codes).numel()}")
        
        # Visualize compression results
        plt.figure(figsize=(15, 10))
        
        # Plot embedding heatmap
        plt.subplot(3, 2, 1)
        embeddings_viz = encodec_embeddings[0].detach().numpy()
        plt.imshow(embeddings_viz.T, aspect='auto', origin='lower', cmap='viridis')
        plt.title('Encodec Embeddings')
        plt.xlabel('Time Frames')
        plt.ylabel('Embedding Dimension')
        plt.colorbar()
        
        # Plot codes for each RVQ stage
        plt.subplot(3, 2, 2)
        codes_viz = encodec_codes[0].detach().numpy()
        plt.imshow(codes_viz.T, aspect='auto', origin='lower', cmap='tab10')
        plt.title('Encodec Codes (8 RVQ Stages)')
        plt.xlabel('Time Frames')
        plt.ylabel('RVQ Stage')
        plt.colorbar()
        
        # Plot embedding statistics over time
        plt.subplot(3, 2, 3)
        emb_means = encodec_embeddings[0].mean(dim=1).detach().numpy()
        emb_stds = encodec_embeddings[0].std(dim=1).detach().numpy()
        frames = np.arange(len(emb_means))
        plt.plot(frames, emb_means, label='Mean', alpha=0.7)
        plt.plot(frames, emb_stds, label='Std', alpha=0.7)
        plt.title('Embedding Statistics Over Time')
        plt.xlabel('Time Frame')
        plt.ylabel('Value')
        plt.legend()
        plt.grid(True, alpha=0.3)
        
        # Plot code usage histogram
        plt.subplot(3, 2, 4)
        all_codes = encodec_codes.flatten().detach().numpy()
        plt.hist(all_codes, bins=50, alpha=0.7)
        plt.title('Codebook Usage Distribution')
        plt.xlabel('Code Index')
        plt.ylabel('Frequency')
        plt.grid(True, alpha=0.3)
        
        # Plot embedding distribution
        plt.subplot(3, 2, 5)
        all_embeddings = encodec_embeddings.flatten().detach().numpy()
        plt.hist(all_embeddings, bins=50, alpha=0.7, density=True)
        plt.title('Embedding Value Distribution')
        plt.xlabel('Embedding Value')
        plt.ylabel('Density')
        plt.grid(True, alpha=0.3)
        
        # Plot compression ratio info
        plt.subplot(3, 2, 6)
        original_size = audio_batch.numel() * 4  # 4 bytes per float32
        compressed_size = encodec_codes.numel() * 2  # Assume 2 bytes per code
        compression_ratio = original_size / compressed_size
        
        plt.bar(['Original', 'Compressed'], [original_size/1024, compressed_size/1024])
        plt.title(f'Compression Ratio: {compression_ratio:.1f}x')
        plt.ylabel('Size (KB)')
        plt.grid(True, alpha=0.3)
        
        plt.tight_layout()
        plt.show()
        
    except Exception as e:
        print(f"‚ùå Encodec compression failed: {e}")
        encodec_codes = None
        encodec_embeddings = None
else:
    encodec_codes = None
    encodec_embeddings = None
    print("‚ö†Ô∏è Skipping Encodec - using random codes/embeddings for testing")
    # Approximate expected shapes
    encodec_codes = torch.randint(0, 1024, (1, 129, 8))  # [batch, time, 8_stages]
    encodec_embeddings = torch.randn(1, 129, 128)  # [batch, time, 128]
    print(f"   Random codes shape: {encodec_codes.shape}")
    print(f"   Random embeddings shape: {encodec_embeddings.shape}")

## Stage 3: Kena VQ-VAE Discrete Representation

Process features through VQ-VAE to get discrete tokens and embeddings.

In [None]:
print("=" * 60)
print("STAGE 3: KENA VQ-VAE DISCRETE REPRESENTATION")
print("=" * 60)

# Initialize VQ-VAE
try:
    vq_vae = KenaVQVAE(
        input_dim=440,  # From Basic Pitch
        codebook_size=512,
        codebook_dim=64,
        hidden_dims=[512, 512, 512, 512]
    )
    print("‚úÖ VQ-VAE initialized successfully")
    print(f"   Input dimension: 440")
    print(f"   Codebook size: 512")
    print(f"   Codebook dimension: 64")
    print(f"   Hidden dimensions: [512, 512, 512, 512]")
except Exception as e:
    print(f"‚ùå VQ-VAE initialization failed: {e}")
    vq_vae = None

if vq_vae is not None and pitch_features is not None:
    # Process through VQ-VAE
    print("\nüîÑ Processing through VQ-VAE...")
    try:
        vq_vae.eval()  # Set to eval mode for debugging
        with torch.no_grad():
            vq_outputs = vq_vae(pitch_features)
        
        vq_embeddings = vq_outputs['z_q']
        vq_indices = vq_outputs['indices']
        commitment_loss = vq_outputs['commitment_loss']
        
        print(f"‚úÖ VQ-VAE processing successful")
        print(f"   Embeddings shape: {vq_embeddings.shape}")
        print(f"   Indices shape: {vq_indices.shape}")
        print(f"   Commitment loss: {commitment_loss.item():.4f}")
        print(f"   Embeddings range: [{vq_embeddings.min():.3f}, {vq_embeddings.max():.3f}]")
        print(f"   Embeddings mean: {vq_embeddings.mean():.3f}")
        print(f"   Embeddings std: {vq_embeddings.std():.3f}")
        print(f"   Unique tokens: {torch.unique(vq_indices).numel()}/{512}")
        print(f"   Token range: [{vq_indices.min()}, {vq_indices.max()}]")
        
        # Visualize VQ-VAE results
        plt.figure(figsize=(15, 12))
        
        # Plot VQ embeddings heatmap
        plt.subplot(3, 3, 1)
        vq_viz = vq_embeddings[0].detach().numpy()
        plt.imshow(vq_viz.T, aspect='auto', origin='lower', cmap='viridis')
        plt.title('VQ-VAE Embeddings (64 dims)')
        plt.xlabel('Time Frames')
        plt.ylabel('Embedding Dimension')
        plt.colorbar()
        
        # Plot discrete tokens
        plt.subplot(3, 3, 2)
        tokens_viz = vq_indices[0].detach().numpy()
        plt.plot(tokens_viz, 'o-', alpha=0.7, markersize=3)
        plt.title('Discrete Token Sequence')
        plt.xlabel('Time Frame')
        plt.ylabel('Token Index')
        plt.grid(True, alpha=0.3)
        
        # Plot token usage histogram
        plt.subplot(3, 3, 3)
        unique_tokens, token_counts = torch.unique(vq_indices, return_counts=True)
        plt.bar(unique_tokens.numpy(), token_counts.numpy(), alpha=0.7)
        plt.title('Token Usage Distribution')
        plt.xlabel('Token Index')
        plt.ylabel('Frequency')
        plt.grid(True, alpha=0.3)
        
        # Plot embedding statistics over time
        plt.subplot(3, 3, 4)
        vq_means = vq_embeddings[0].mean(dim=1).detach().numpy()
        vq_stds = vq_embeddings[0].std(dim=1).detach().numpy()
        frames = np.arange(len(vq_means))
        plt.plot(frames, vq_means, label='Mean', alpha=0.7)
        plt.plot(frames, vq_stds, label='Std', alpha=0.7)
        plt.title('VQ Embedding Stats Over Time')
        plt.xlabel('Time Frame')
        plt.ylabel('Value')
        plt.legend()
        plt.grid(True, alpha=0.3)
        
        # Plot embedding value distribution
        plt.subplot(3, 3, 5)
        all_vq_embeddings = vq_embeddings.flatten().detach().numpy()
        plt.hist(all_vq_embeddings, bins=50, alpha=0.7, density=True)
        plt.title('VQ Embedding Distribution')
        plt.xlabel('Embedding Value')
        plt.ylabel('Density')
        plt.grid(True, alpha=0.3)
        
        # Plot codebook utilization
        plt.subplot(3, 3, 6)
        codebook_usage = torch.zeros(512)
        for token in torch.unique(vq_indices):
            codebook_usage[token] = (vq_indices == token).sum()
        used_codes = (codebook_usage > 0).sum().item()
        plt.bar(range(512), codebook_usage.numpy(), alpha=0.7)
        plt.title(f'Codebook Utilization ({used_codes}/512 codes used)')
        plt.xlabel('Codebook Index')
        plt.ylabel('Usage Count')
        plt.grid(True, alpha=0.3)
        
        # Plot commitment loss info
        plt.subplot(3, 3, 7)
        plt.bar(['Commitment Loss'], [commitment_loss.item()], alpha=0.7)
        plt.title('VQ-VAE Training Signal')
        plt.ylabel('Loss Value')
        plt.grid(True, alpha=0.3)
        
        # Plot embedding dimension correlations (sample)
        plt.subplot(3, 3, 8)
        # Compute correlation matrix for first 16 dimensions
        sample_dims = vq_embeddings[0, :, :16].detach().numpy().T
        corr_matrix = np.corrcoef(sample_dims)
        plt.imshow(corr_matrix, cmap='coolwarm', vmin=-1, vmax=1)
        plt.title('Embedding Dim Correlations (First 16)')
        plt.xlabel('Dimension')
        plt.ylabel('Dimension')
        plt.colorbar()
        
        # Plot token transition patterns
        plt.subplot(3, 3, 9)
        token_seq = vq_indices[0].detach().numpy()
        transitions = [(token_seq[i], token_seq[i+1]) for i in range(len(token_seq)-1)]
        from collections import Counter
        transition_counts = Counter(transitions)
        most_common = transition_counts.most_common(10)
        if most_common:
            trans_labels = [f"{t[0][0]}‚Üí{t[0][1]}" for t in most_common]
            trans_counts = [t[1] for t in most_common]
            plt.barh(range(len(trans_labels)), trans_counts, alpha=0.7)
            plt.yticks(range(len(trans_labels)), trans_labels)
            plt.title('Most Common Token Transitions')
            plt.xlabel('Frequency')
        else:
            plt.text(0.5, 0.5, 'No transitions\n(single frame?)', 
                    ha='center', va='center', transform=plt.gca().transAxes)
            plt.title('Token Transitions')
        plt.grid(True, alpha=0.3)
        
        plt.tight_layout()
        plt.show()
        
    except Exception as e:
        print(f"‚ùå VQ-VAE processing failed: {e}")
        vq_embeddings = None
        vq_indices = None
        commitment_loss = None
else:
    vq_embeddings = None
    vq_indices = None
    commitment_loss = None
    print("‚ö†Ô∏è Skipping VQ-VAE - using random embeddings/indices for testing")
    vq_embeddings = torch.randn(1, 129, 64)
    vq_indices = torch.randint(0, 512, (1, 129))
    commitment_loss = torch.tensor(0.1)
    print(f"   Random VQ embeddings shape: {vq_embeddings.shape}")
    print(f"   Random VQ indices shape: {vq_indices.shape}")

## Stage 4: CLAP Semantic Understanding

Extract semantic embeddings using the HuggingFace CLAP music model (`laion/larger_clap_music`).

In [None]:
print("=" * 60)
print("STAGE 4: CLAP SEMANTIC UNDERSTANDING")
print("=" * 60)

# Initialize CLAP
try:
    clap_encoder = HuggingFaceCLAP(
        model_name="laion/larger_clap_music",
        output_dim=768,
        freeze_model=True
    )
    print("‚úÖ CLAP initialized successfully")
    print(f"   Model: laion/larger_clap_music")
    print(f"   Output dimension: 768")
    print(f"   Model frozen: True")
except Exception as e:
    print(f"‚ùå CLAP initialization failed: {e}")
    clap_encoder = None

if clap_encoder is not None:
    # Extract semantic embeddings
    print("\nüîÑ Extracting CLAP semantic embeddings...")
    try:
        with torch.no_grad():
            clap_embeddings = clap_encoder(audio_batch)
        
        print(f"‚úÖ CLAP embeddings extracted successfully")
        print(f"   Shape: {clap_embeddings.shape}")
        print(f"   Range: [{clap_embeddings.min():.3f}, {clap_embeddings.max():.3f}]")
        print(f"   Mean: {clap_embeddings.mean():.3f}")
        print(f"   Std: {clap_embeddings.std():.3f}")
        print(f"   L2 norm: {torch.norm(clap_embeddings, p=2, dim=-1).item():.3f}")
        
        # Test zero-shot classification
        print("\nüéµ Testing zero-shot audio classification...")
        guitar_labels = [
            "guitar chord",
            "piano music",
            "electric guitar",
            "acoustic guitar",
            "violin music",
            "guitar slide technique",
            "clean guitar playing",
            "distorted guitar"
        ]
        
        try:
            classification_results = clap_encoder.classify_audio(audio_batch[0], guitar_labels)
            print("   Classification results:")
            sorted_results = sorted(classification_results.items(), key=lambda x: x[1], reverse=True)
            for label, score in sorted_results:
                print(f"     {label}: {score:.3f}")
        except Exception as e:
            print(f"   ‚ö†Ô∏è Classification failed: {e}")
        
        # Visualize CLAP embeddings
        plt.figure(figsize=(15, 10))
        
        # Plot embedding values
        plt.subplot(2, 3, 1)
        emb_values = clap_embeddings[0].detach().numpy()
        plt.plot(emb_values, alpha=0.7)
        plt.title('CLAP Embedding Values')
        plt.xlabel('Dimension')
        plt.ylabel('Value')
        plt.grid(True, alpha=0.3)
        
        # Plot embedding distribution
        plt.subplot(2, 3, 2)
        plt.hist(emb_values, bins=50, alpha=0.7, density=True)
        plt.title('CLAP Embedding Distribution')
        plt.xlabel('Embedding Value')
        plt.ylabel('Density')
        plt.grid(True, alpha=0.3)
        
        # Plot top activated dimensions
        plt.subplot(2, 3, 3)
        abs_values = np.abs(emb_values)
        top_dims = np.argsort(abs_values)[-20:]  # Top 20 dimensions
        plt.bar(range(len(top_dims)), abs_values[top_dims], alpha=0.7)
        plt.title('Top 20 Activated Dimensions')
        plt.xlabel('Dimension Rank')
        plt.ylabel('Absolute Value')
        plt.grid(True, alpha=0.3)
        
        # Plot L2 norm info
        plt.subplot(2, 3, 4)
        l2_norm = torch.norm(clap_embeddings, p=2, dim=-1).item()
        plt.bar(['L2 Norm'], [l2_norm], alpha=0.7)
        plt.title('Embedding L2 Norm (Should be ~1.0)')
        plt.ylabel('Norm Value')
        plt.grid(True, alpha=0.3)
        
        # Plot classification scores if available
        plt.subplot(2, 3, 5)
        if 'classification_results' in locals() and classification_results:
            labels, scores = zip(*sorted_results[:6])  # Top 6 results
            plt.barh(range(len(labels)), scores, alpha=0.7)
            plt.yticks(range(len(labels)), [l[:15] + '...' if len(l) > 15 else l for l in labels])
            plt.title('Top Classification Scores')
            plt.xlabel('Confidence')
        else:
            plt.text(0.5, 0.5, 'Classification\nNot Available', 
                    ha='center', va='center', transform=plt.gca().transAxes)
            plt.title('Classification Scores')
        plt.grid(True, alpha=0.3)
        
        # Plot embedding sparsity
        plt.subplot(2, 3, 6)
        sparsity_levels = [0.01, 0.05, 0.1, 0.2]
        sparsity_counts = [(np.abs(emb_values) < level).sum() for level in sparsity_levels]
        sparsity_ratios = [count / len(emb_values) for count in sparsity_counts]
        plt.bar([f'<{level}' for level in sparsity_levels], sparsity_ratios, alpha=0.7)
        plt.title('Embedding Sparsity Analysis')
        plt.xlabel('Threshold')
        plt.ylabel('Fraction of Dimensions')
        plt.grid(True, alpha=0.3)
        
        plt.tight_layout()
        plt.show()
        
    except Exception as e:
        print(f"‚ùå CLAP processing failed: {e}")
        clap_embeddings = None
else:
    clap_embeddings = None
    print("‚ö†Ô∏è Skipping CLAP - using random embeddings for testing")
    clap_embeddings = torch.randn(1, 768)
    # Normalize to simulate CLAP's normalized output
    clap_embeddings = torch.nn.functional.normalize(clap_embeddings, p=2, dim=-1)
    print(f"   Random CLAP embeddings shape: {clap_embeddings.shape}")
    print(f"   Random embeddings L2 norm: {torch.norm(clap_embeddings, p=2, dim=-1).item():.3f}")

## Stage 5: Embedding Fusion

Combine all embeddings into a unified 768-dimensional representation.

In [None]:
print("=" * 60)
print("STAGE 5: EMBEDDING FUSION")
print("=" * 60)

# Fusion parameters (from pipeline)
embedding_dim = 768
fusion_weights = {
    'pitch': 0.4,
    'encodec': 0.2,
    'vq': 0.3,
    'semantic': 0.1
}

print(f"Target embedding dimension: {embedding_dim}")
print(f"Fusion weights: {fusion_weights}")

# Create projection layers (simplified version of pipeline logic)
pitch_proj = nn.Linear(440, embedding_dim) if pitch_features is not None else None
encodec_proj = nn.Linear(128, embedding_dim) if encodec_embeddings is not None else None
vq_proj = nn.Linear(64, embedding_dim) if vq_embeddings is not None else None
layer_norm = nn.LayerNorm(embedding_dim)

print("\nüîÑ Fusing embeddings...")

try:
    # Get the expected time dimension (use Basic Pitch as reference)
    if pitch_features is not None:
        target_time_frames = pitch_features.shape[1]
    elif vq_embeddings is not None:
        target_time_frames = vq_embeddings.shape[1]
    else:
        target_time_frames = 129  # Default
    
    print(f"   Target time frames: {target_time_frames}")
    
    # Project each component to embedding dimension
    projected_components = {}
    
    # 1. Basic Pitch projection
    if pitch_features is not None and pitch_proj is not None:
        pitch_embedded = pitch_proj(pitch_features)
        projected_components['pitch'] = pitch_embedded
        print(f"   ‚úÖ Pitch projected: {pitch_features.shape} ‚Üí {pitch_embedded.shape}")
    else:
        pitch_embedded = torch.zeros(1, target_time_frames, embedding_dim)
        projected_components['pitch'] = pitch_embedded
        print(f"   ‚ö†Ô∏è Pitch using zeros: {pitch_embedded.shape}")
    
    # 2. Encodec projection (with temporal alignment)
    if encodec_embeddings is not None and encodec_proj is not None:
        # Temporal alignment if needed
        if encodec_embeddings.shape[1] != target_time_frames:
            encodec_aligned = torch.nn.functional.interpolate(
                encodec_embeddings.transpose(1, 2),
                size=target_time_frames,
                mode='linear',
                align_corners=False
            ).transpose(1, 2)
            print(f"   üîÑ Encodec aligned: {encodec_embeddings.shape[1]} ‚Üí {target_time_frames} frames")
        else:
            encodec_aligned = encodec_embeddings
        
        encodec_embedded = encodec_proj(encodec_aligned)
        projected_components['encodec'] = encodec_embedded
        print(f"   ‚úÖ Encodec projected: {encodec_aligned.shape} ‚Üí {encodec_embedded.shape}")
    else:
        encodec_embedded = torch.zeros(1, target_time_frames, embedding_dim)
        projected_components['encodec'] = encodec_embedded
        print(f"   ‚ö†Ô∏è Encodec using zeros: {encodec_embedded.shape}")
    
    # 3. VQ-VAE projection (with temporal alignment)
    if vq_embeddings is not None and vq_proj is not None:
        # Temporal alignment if needed
        if vq_embeddings.shape[1] != target_time_frames:
            vq_aligned = torch.nn.functional.interpolate(
                vq_embeddings.transpose(1, 2),
                size=target_time_frames,
                mode='linear',
                align_corners=False
            ).transpose(1, 2)
            print(f"   üîÑ VQ-VAE aligned: {vq_embeddings.shape[1]} ‚Üí {target_time_frames} frames")
        else:
            vq_aligned = vq_embeddings
        
        vq_embedded = vq_proj(vq_aligned)
        projected_components['vq'] = vq_embedded
        print(f"   ‚úÖ VQ-VAE projected: {vq_aligned.shape} ‚Üí {vq_embedded.shape}")
    else:
        vq_embedded = torch.zeros(1, target_time_frames, embedding_dim)
        projected_components['vq'] = vq_embedded
        print(f"   ‚ö†Ô∏è VQ-VAE using zeros: {vq_embedded.shape}")
    
    # 4. CLAP expansion (broadcast to time dimension)
    if clap_embeddings is not None:
        clap_expanded = clap_embeddings.unsqueeze(1).repeat(1, target_time_frames, 1)
        projected_components['semantic'] = clap_expanded
        print(f"   ‚úÖ CLAP expanded: {clap_embeddings.shape} ‚Üí {clap_expanded.shape}")
    else:
        clap_expanded = torch.zeros(1, target_time_frames, embedding_dim)
        projected_components['semantic'] = clap_expanded
        print(f"   ‚ö†Ô∏è CLAP using zeros: {clap_expanded.shape}")
    
    # Weighted fusion
    fused_embeddings = (
        fusion_weights['pitch'] * projected_components['pitch'] +
        fusion_weights['encodec'] * projected_components['encodec'] +
        fusion_weights['vq'] * projected_components['vq'] +
        fusion_weights['semantic'] * projected_components['semantic']
    )
    
    # Apply layer normalization
    fused_embeddings = layer_norm(fused_embeddings)
    
    print(f"\n‚úÖ Fusion completed successfully")
    print(f"   Fused embeddings shape: {fused_embeddings.shape}")
    print(f"   Range: [{fused_embeddings.min():.3f}, {fused_embeddings.max():.3f}]")
    print(f"   Mean: {fused_embeddings.mean():.3f}")
    print(f"   Std: {fused_embeddings.std():.3f}")
    
    # Visualize fusion results
    plt.figure(figsize=(15, 12))
    
    # Plot individual component contributions
    plt.subplot(3, 3, 1)
    components_viz = {
        'Pitch': projected_components['pitch'][0].detach().numpy(),
        'Encodec': projected_components['encodec'][0].detach().numpy(),
        'VQ-VAE': projected_components['vq'][0].detach().numpy(),
        'CLAP': projected_components['semantic'][0].detach().numpy()
    }
    
    for i, (name, component) in enumerate(components_viz.items()):
        plt.subplot(3, 3, i+1)
        plt.imshow(component.T, aspect='auto', origin='lower', cmap='viridis')
        plt.title(f'{name} Component (Weight: {list(fusion_weights.values())[i]})')
        plt.xlabel('Time Frames')
        plt.ylabel('Embedding Dimension')
        plt.colorbar()
    
    # Plot fused embeddings
    plt.subplot(3, 3, 5)
    fused_viz = fused_embeddings[0].detach().numpy()
    plt.imshow(fused_viz.T, aspect='auto', origin='lower', cmap='viridis')
    plt.title('Fused Embeddings (768 dims)')
    plt.xlabel('Time Frames')
    plt.ylabel('Embedding Dimension')
    plt.colorbar()
    
    # Plot fusion statistics over time
    plt.subplot(3, 3, 6)
    fused_means = fused_embeddings[0].mean(dim=1).detach().numpy()
    fused_stds = fused_embeddings[0].std(dim=1).detach().numpy()
    frames = np.arange(len(fused_means))
    plt.plot(frames, fused_means, label='Mean', alpha=0.7)
    plt.plot(frames, fused_stds, label='Std', alpha=0.7)
    plt.title('Fused Embedding Stats Over Time')
    plt.xlabel('Time Frame')
    plt.ylabel('Value')
    plt.legend()
    plt.grid(True, alpha=0.3)
    
    # Plot component contribution magnitudes
    plt.subplot(3, 3, 7)
    component_norms = {}
    for name, component in components_viz.items():
        component_norms[name] = np.linalg.norm(component, axis=1).mean()
    
    plt.bar(component_norms.keys(), component_norms.values(), alpha=0.7)
    plt.title('Average Component Magnitudes')
    plt.ylabel('L2 Norm')
    plt.xticks(rotation=45)
    plt.grid(True, alpha=0.3)
    
    # Plot embedding distribution
    plt.subplot(3, 3, 8)
    all_fused = fused_embeddings.flatten().detach().numpy()
    plt.hist(all_fused, bins=50, alpha=0.7, density=True)
    plt.title('Fused Embedding Distribution')
    plt.xlabel('Embedding Value')
    plt.ylabel('Density')
    plt.grid(True, alpha=0.3)
    
    # Plot fusion weights
    plt.subplot(3, 3, 9)
    plt.pie(fusion_weights.values(), labels=fusion_weights.keys(), autopct='%1.1f%%')
    plt.title('Fusion Weight Distribution')
    
    plt.tight_layout()
    plt.show()
    
except Exception as e:
    print(f"‚ùå Fusion failed: {e}")
    fused_embeddings = torch.zeros(1, 129, 768)
    print(f"   Using zero embeddings: {fused_embeddings.shape}")

## Stage 6: Audio Transcription Decoder

Process the fused embeddings through the CRNN decoder to obtain onset, frame, and confidence predictions for the 88 piano keys.

In [None]:
print("=" * 60)
print("STAGE 6: AUDIO TRANSCRIPTION DECODER")
print("=" * 60)

# Detection thresholds used for the statistics and plots in this cell.
onset_threshold = 0.5
frame_threshold = 0.5
confidence_threshold = 0.7

# Guitar pitch range: E2 (MIDI 40) to E6 (MIDI 88), both inclusive.
# Piano key index = MIDI - 21 (key 0 is A0 = MIDI 21), so E2 -> key 19 and E6 -> key 67.
# guitar_range_end is EXCLUSIVE (Python slice convention): 68, so that E6 (key 67) is included.
guitar_range_start = 40 - 21
guitar_range_end = 88 - 21 + 1


def _plot_piano_roll(probs, title, key_start, key_end):
    """Draw a (n_frames, n_keys) probability array as a piano roll with guitar-range guides.

    Args:
        probs: numpy array of shape (n_frames, n_keys) with values in [0, 1].
        title: subplot title.
        key_start: first guitar key index (inclusive).
        key_end: one past the last guitar key index (exclusive).
    """
    plt.imshow(probs.T, aspect='auto', origin='lower', cmap='hot', vmin=0, vmax=1)
    plt.title(title)
    plt.xlabel('Time Frames')
    plt.ylabel('Piano Key (A0-C8)')
    # imshow row k spans [k - 0.5, k + 0.5]; put the guides on the outer edges of the guitar keys.
    for edge in (key_start - 0.5, key_end - 0.5):
        plt.axhline(edge, color='cyan', linestyle='--', alpha=0.7, linewidth=1)
    plt.colorbar()


# Initialize decoder
try:
    decoder = EmbeddingValidationDecoder(
        embedding_dim=768,
        hidden_dim=256,
        gru_hidden=128,
        n_pitches=88,  # Piano range
        dropout=0.2,
        bidirectional=True
    )
    print("‚úÖ Audio decoder initialized successfully")
    print(f"   Input dimension: 768")
    print(f"   Hidden dimension: 256")
    print(f"   GRU hidden: 128 (bidirectional)")
    print(f"   Output pitches: 88 (A0-C8)")
except Exception as e:
    print(f"‚ùå Decoder initialization failed: {e}")
    decoder = None

if decoder is not None and fused_embeddings is not None:
    print("\nüîÑ Processing through audio decoder...")
    try:
        # Keep the decoder on the same device as its input, and disable dropout for inference.
        decoder.to(fused_embeddings.device)
        decoder.eval()
        with torch.no_grad():
            decoder_outputs = decoder(fused_embeddings)

        # Logits are kept around for interactive inspection; the analysis uses the probabilities.
        onset_logits = decoder_outputs['onset_logits']
        frame_logits = decoder_outputs['frame_logits']
        confidence_logits = decoder_outputs['confidence_logits']
        onset_probs = decoder_outputs['onset_probs']
        frame_probs = decoder_outputs['frame_probs']
        confidence = decoder_outputs['confidence']

        print(f"‚úÖ Decoder processing successful")
        print(f"   Onset probabilities shape: {onset_probs.shape}")
        print(f"   Frame probabilities shape: {frame_probs.shape}")
        print(f"   Confidence shape: {confidence.shape}")
        print(f"   Onset range: [{onset_probs.min():.3f}, {onset_probs.max():.3f}]")
        print(f"   Frame range: [{frame_probs.min():.3f}, {frame_probs.max():.3f}]")
        print(f"   Confidence range: [{confidence.min():.3f}, {confidence.max():.3f}]")
        print(f"   Average onset activation: {onset_probs.mean():.3f}")
        print(f"   Average frame activation: {frame_probs.mean():.3f}")
        print(f"   Average confidence: {confidence.mean():.3f}")

        # Threshold counts as plain ints (so they print as numbers, not "tensor(N)").
        onset_detections = (onset_probs > onset_threshold).sum().item()
        frame_detections = (frame_probs > frame_threshold).sum().item()
        high_confidence = (confidence > confidence_threshold).sum().item()

        print(f"   Onset detections (>{onset_threshold}): {onset_detections}")
        print(f"   Frame detections (>{frame_threshold}): {frame_detections}")
        print(f"   High confidence frames (>{confidence_threshold}): {high_confidence}")

        # Host-side numpy views of the first batch item for analysis and plotting.
        # Indexing below assumes probs are (batch, frames, keys) and confidence is
        # (batch, frames) — TODO confirm against EmbeddingValidationDecoder.
        onset_np = onset_probs[0].detach().cpu().numpy()
        frame_np = frame_probs[0].detach().cpu().numpy()
        confidence_np = confidence[0].detach().cpu().numpy()

        guitar_onset = onset_np[:, guitar_range_start:guitar_range_end]
        guitar_frame = frame_np[:, guitar_range_start:guitar_range_end]
        guitar_onset_activity = guitar_onset.mean()
        guitar_frame_activity = guitar_frame.mean()

        print(f"   Guitar range onset activity: {guitar_onset_activity:.3f}")
        print(f"   Guitar range frame activity: {guitar_frame_activity:.3f}")

        # Time axes: onset/frame plots use the prediction length, the confidence plot its own.
        frames = np.arange(onset_np.shape[0])
        confidence_frames = np.arange(len(confidence_np))

        # Visualize decoder outputs
        plt.figure(figsize=(16, 14))

        plt.subplot(3, 3, 1)
        _plot_piano_roll(onset_np, 'Onset Probabilities (88 Piano Keys)',
                         guitar_range_start, guitar_range_end)

        plt.subplot(3, 3, 2)
        _plot_piano_roll(frame_np, 'Frame Probabilities (88 Piano Keys)',
                         guitar_range_start, guitar_range_end)

        # Confidence over time
        plt.subplot(3, 3, 3)
        plt.plot(confidence_frames, confidence_np, alpha=0.7)
        plt.axhline(confidence_threshold, color='red', linestyle='--', alpha=0.5, label='High Confidence')
        plt.title('Prediction Confidence Over Time')
        plt.xlabel('Time Frame')
        plt.ylabel('Confidence')
        plt.legend()
        plt.grid(True, alpha=0.3)

        # Guitar range activity
        plt.subplot(3, 3, 4)
        plt.plot(frames, guitar_onset.mean(axis=1), label='Onset', alpha=0.7)
        plt.plot(frames, guitar_frame.mean(axis=1), label='Frame', alpha=0.7)
        plt.title('Guitar Range Activity (E2-E6)')
        plt.xlabel('Time Frame')
        plt.ylabel('Average Probability')
        plt.legend()
        plt.grid(True, alpha=0.3)

        # Pitch activation distribution
        plt.subplot(3, 3, 5)
        piano_keys = np.arange(onset_np.shape[1])
        plt.plot(piano_keys, onset_np.mean(axis=0), alpha=0.7)
        # Shade the inclusive guitar keys [start, end - 1].
        plt.axvspan(guitar_range_start, guitar_range_end - 1, alpha=0.2, color='cyan', label='Guitar Range')
        plt.title('Average Pitch Activation')
        plt.xlabel('Piano Key Index')
        plt.ylabel('Average Onset Probability')
        plt.legend()
        plt.grid(True, alpha=0.3)

        # Detection statistics
        plt.subplot(3, 3, 6)
        detection_stats = {
            'Onset Det.': onset_detections,
            'Frame Det.': frame_detections,
            'High Conf.': high_confidence
        }
        plt.bar(list(detection_stats.keys()), list(detection_stats.values()), alpha=0.7)
        plt.title('Detection Statistics')
        plt.ylabel('Count')
        plt.grid(True, alpha=0.3)

        # Onset vs frame correlation (all batch items)
        plt.subplot(3, 3, 7)
        onset_flat = onset_probs.flatten().detach().cpu().numpy()
        frame_flat = frame_probs.flatten().detach().cpu().numpy()
        plt.scatter(onset_flat, frame_flat, alpha=0.1, s=1)
        plt.plot([0, 1], [0, 1], 'r--', alpha=0.5)
        plt.title('Onset vs Frame Predictions')
        plt.xlabel('Onset Probability')
        plt.ylabel('Frame Probability')
        plt.grid(True, alpha=0.3)

        # Prediction distributions
        plt.subplot(3, 3, 8)
        plt.hist(onset_flat, bins=50, alpha=0.5, label='Onset', density=True)
        plt.hist(frame_flat, bins=50, alpha=0.5, label='Frame', density=True)
        plt.axvline(onset_threshold, color='blue', linestyle='--', alpha=0.7, label='Thresholds')
        plt.axvline(frame_threshold, color='orange', linestyle='--', alpha=0.7)
        plt.title('Prediction Distributions')
        plt.xlabel('Probability')
        plt.ylabel('Density')
        plt.legend()
        plt.grid(True, alpha=0.3)

        # Strongest activation per frame
        plt.subplot(3, 3, 9)
        plt.plot(frames, onset_np.max(axis=1), label='Max Onset', alpha=0.7)
        plt.plot(frames, frame_np.max(axis=1), label='Max Frame', alpha=0.7)
        plt.title('Strongest Activations Over Time')
        plt.xlabel('Time Frame')
        plt.ylabel('Max Probability')
        plt.legend()
        plt.grid(True, alpha=0.3)

        plt.tight_layout()
        plt.show()

    except Exception as e:
        print(f"‚ùå Decoder processing failed: {e}")
        onset_probs = None
        frame_probs = None
        confidence = None
else:
    print("‚ö†Ô∏è Skipping decoder - using random predictions for testing")
    # Placeholder predictions sized to the fused embeddings (129 frames if there are none).
    # NOTE: these are not None, so the summary cell will count the decoder stage as present.
    placeholder_frames = fused_embeddings.shape[1] if fused_embeddings is not None else 129
    onset_probs = torch.sigmoid(torch.randn(1, placeholder_frames, 88))
    frame_probs = torch.sigmoid(torch.randn(1, placeholder_frames, 88))
    confidence = torch.sigmoid(torch.randn(1, placeholder_frames))
    print(f"   Random onset probs shape: {onset_probs.shape}")
    print(f"   Random frame probs shape: {frame_probs.shape}")
    print(f"   Random confidence shape: {confidence.shape}")

## Pipeline Summary and Analysis

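Summarize the whole pipeline:
- the status of each component and the overall pipeline health;
- the notes detected from the onset predictions;
- the memory used by the intermediate tensors;
- recommended next steps.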

In [None]:
print("=" * 60)
print("PIPELINE SUMMARY AND ANALYSIS")
print("=" * 60)


def _shape_or_na(value):
    """Return ``value.shape``, or 'N/A' when the stage produced no output."""
    return value.shape if value is not None else 'N/A'


def _stage_status(value):
    """Status glyph for a stage: success iff the stage produced an output."""
    return '‚úÖ' if value is not None else '‚ùå'


def _tensor_kb(tensor):
    """Memory held by a tensor's elements in KB, using its actual dtype width."""
    return tensor.numel() * tensor.element_size() / 1024


# Collect pipeline statistics.
# NOTE(review): a stage counts as successful when its output is not None. The fusion cell
# substitutes zero embeddings on failure, and Stage 6 substitutes random predictions when it
# is skipped, so those stages report success even when they did not really run — TODO track
# stage failures explicitly.
pipeline_stats = {
    'input_audio': {
        'shape': audio_batch.shape,
        'duration_sec': audio_length,
        'sample_rate': sample_rate,
        'status': '‚úÖ'
    },
    'basic_pitch': {
        'shape': _shape_or_na(pitch_features),
        'output_dim': 440 if pitch_features is not None else 'N/A',
        'status': _stage_status(pitch_features)
    },
    'encodec': {
        'embeddings_shape': _shape_or_na(encodec_embeddings),
        'codes_shape': _shape_or_na(encodec_codes),
        'status': _stage_status(encodec_embeddings)
    },
    'vq_vae': {
        'embeddings_shape': _shape_or_na(vq_embeddings),
        'indices_shape': _shape_or_na(vq_indices),
        # float() accepts both a one-element tensor and a plain Python number.
        'commitment_loss': float(commitment_loss) if commitment_loss is not None else 'N/A',
        'status': _stage_status(vq_embeddings)
    },
    'clap': {
        'shape': _shape_or_na(clap_embeddings),
        'model': 'laion/larger_clap_music' if clap_embeddings is not None else 'N/A',
        'status': _stage_status(clap_embeddings)
    },
    'fusion': {
        'shape': _shape_or_na(fused_embeddings),
        'target_dim': 768,
        'status': _stage_status(fused_embeddings)
    },
    'decoder': {
        'onset_shape': _shape_or_na(onset_probs),
        'frame_shape': _shape_or_na(frame_probs),
        'status': _stage_status(onset_probs)
    }
}

print("\nüìä COMPONENT STATUS:")
for component, stats in pipeline_stats.items():
    # Read (do not pop) the status so pipeline_stats stays intact for the health check and export.
    print(f"\n{stats['status']} {component.upper()}:")
    for key, value in stats.items():
        if key != 'status':
            print(f"   {key}: {value}")

# Overall pipeline health, derived from each component's explicit status field.
failed_components = [comp for comp, stats in pipeline_stats.items() if stats['status'] != '‚úÖ']
total_components = len(pipeline_stats)
successful_components = total_components - len(failed_components)
pipeline_health = successful_components / total_components

print(f"\nüè• PIPELINE HEALTH: {pipeline_health:.1%} ({successful_components}/{total_components} components working)")

# Always (re)define so later checks never see an undefined or stale list from a previous run.
detected_notes = []

# Performance analysis
if onset_probs is not None and frame_probs is not None:
    print("\nüéµ TRANSCRIPTION ANALYSIS:")

    # Deliberately looser than Stage 6's 0.5 so weak activations still surface. A separate
    # name keeps Stage 6's onset_threshold untouched.
    note_onset_threshold = 0.3
    piano_key_offset = 21  # MIDI number of piano key 0 (A0)

    # Seconds per decoder frame, assuming the frames evenly span the clip
    # (~23 ms for a 3 s clip with 129 frames).
    n_frames = onset_probs.shape[1]
    seconds_per_frame = audio_length / n_frames if n_frames else 0.0

    # Simple peak picking: every (frame, key) whose onset probability exceeds the threshold.
    # nonzero() returns the pairs in frame-major order, i.e. sorted by time.
    onset_mask = onset_probs[0] > note_onset_threshold
    for frame_idx, key in torch.nonzero(onset_mask).tolist():
        detected_notes.append({
            'frame': frame_idx,
            'time_sec': frame_idx * seconds_per_frame,
            'piano_key': key,
            'midi_note': key + piano_key_offset,
            'onset_prob': onset_probs[0, frame_idx, key].item(),
            'frame_prob': frame_probs[0, frame_idx, key].item()
        })

    print(f"   Notes detected (threshold {note_onset_threshold}): {len(detected_notes)}")

    if detected_notes:
        print(f"   Time range: {detected_notes[0]['time_sec']:.2f}s - {detected_notes[-1]['time_sec']:.2f}s")

        # Show top 10 strongest detections
        top_notes = sorted(detected_notes, key=lambda x: x['onset_prob'], reverse=True)[:10]
        print(f"\n   Top 10 strongest detections:")
        for i, note in enumerate(top_notes):
            print(f"     {i+1}. MIDI {note['midi_note']} at {note['time_sec']:.2f}s (onset: {note['onset_prob']:.3f})")

        # Guitar range analysis
        guitar_notes = [n for n in detected_notes if 40 <= n['midi_note'] <= 88]  # E2 to E6
        print(f"\n   Notes in guitar range (E2-E6): {len(guitar_notes)}/{len(detected_notes)}")
    else:
        print(f"   ‚ö†Ô∏è No notes detected - try lowering threshold or check embeddings")

# Memory usage estimation (element sizes come from each tensor's real dtype).
print("\nüíæ MEMORY USAGE ESTIMATION:")
memory_items = [
    ('Input audio', [audio_batch]),
    ('Basic Pitch features', [pitch_features]),
    ('Encodec embeddings', [encodec_embeddings]),
    ('VQ-VAE embeddings', [vq_embeddings]),
    ('CLAP embeddings', [clap_embeddings]),
    ('Fused embeddings', [fused_embeddings]),
    ('Decoder predictions', [onset_probs, frame_probs, confidence]),
]
total_memory_kb = 0
for label, tensors in memory_items:
    present = [t for t in tensors if t is not None]
    if not present:
        continue
    item_kb = sum(_tensor_kb(t) for t in present)
    total_memory_kb += item_kb
    print(f"   {label}: {item_kb:.1f} KB")

print(f"\n   Total memory usage: {total_memory_kb:.1f} KB ({total_memory_kb/1024:.2f} MB)")
if audio_length:
    print(f"   Memory per second: {total_memory_kb/audio_length:.1f} KB/sec")

# Next steps recommendations
print("\nüéØ NEXT STEPS RECOMMENDATIONS:")

if pipeline_health == 1.0:
    print("   ‚úÖ All components working! Ready for training or real audio testing.")
    print("   ‚Ä¢ Try with real guitar audio files")
    print("   ‚Ä¢ Train the decoder on GuitarSet dataset")
    print("   ‚Ä¢ Implement note post-processing (frame-to-note)")
    print("   ‚Ä¢ Add tab assignment and technique detection")
elif pipeline_health >= 0.7:
    print("   üü° Most components working. Fix remaining issues:")
    for comp in failed_components:
        print(f"   ‚Ä¢ Fix {comp} component")
else:
    print("   üî¥ Multiple components failing. Priority fixes:")
    print("   ‚Ä¢ Check pre-trained model downloads")
    print("   ‚Ä¢ Verify dependencies (transformers, basic-pitch, encodec)")
    print("   ‚Ä¢ Test components individually")

if onset_probs is not None and not detected_notes:
    print("\n   üéµ Transcription improvements:")
    print("   ‚Ä¢ Lower detection thresholds")
    print("   ‚Ä¢ Train decoder on labeled data")
    print("   ‚Ä¢ Check if embeddings contain musical information")

print("\n‚úÖ Pipeline debugging complete!")

## Export Results for Further Analysis

Save intermediate results for later analysis or debugging.

In [None]:
# Optionally save results
save_results = False  # Set to True if you want to save

if save_results:
    import os

    print("üíæ Saving pipeline results...")

    results_dict = {
        'pipeline_stats': pipeline_stats,
        'audio_shape': audio_batch.shape,
        # This cell runs at module scope, so look the name up in globals(); the summary cell
        # may not have defined it.
        'detected_notes': globals().get('detected_notes', []),
        'pipeline_health': pipeline_health
    }

    # Save embeddings and predictions if available. .cpu() lets GPU tensors convert to numpy.
    tensors_to_save = {
        'fused_embeddings': fused_embeddings,
        'onset_probs': onset_probs,
        'frame_probs': frame_probs,
        'confidence': confidence,
    }
    for name, tensor in tensors_to_save.items():
        if tensor is not None:
            results_dict[name] = tensor.detach().cpu().numpy()

    # pipeline_stats and detected_notes are Python objects, which np.savez stores as pickled
    # object arrays; read them back with np.load(path, allow_pickle=True).
    os.makedirs('../results', exist_ok=True)
    np.savez('../results/pipeline_debug_results.npz', **results_dict)
    print("   Results saved to ../results/pipeline_debug_results.npz")
else:
    print("üíæ Results not saved (set save_results=True to save)")

print("\nüéâ Debugging notebook complete!")
print("\nUse this notebook to:")
print("‚Ä¢ Test individual pipeline components")
print("‚Ä¢ Analyze intermediate representations")
print("‚Ä¢ Debug dimension mismatches")
print("‚Ä¢ Visualize embedding quality")
print("‚Ä¢ Monitor transcription performance")