# Spectral Analysis of SD3 Prompt Switching

This notebook provides detailed spectral analysis tools for understanding frequency-domain artifacts during prompt switching.

In [None]:
import sys
sys.path.append('..')

import torch
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import json
from scipy import signal
from IPython.display import display, Image as IPImage
import ipywidgets as widgets
from src.analysis.spectral_analyzer import SpectralAnalyzer
import seaborn as sns

%matplotlib inline
# Notebook-wide plotting defaults: figure size, base font size, seaborn grid style.
plt.rcParams.update({
    'figure.figsize': (12, 8),
    'font.size': 12,
})
sns.set_style("whitegrid")

## 1. Load Experiment and Initialize Analyzer

In [None]:
def find_experiment_dirs(root):
    """Return the experiment directories directly under ``root``, sorted by name.

    Only non-hidden directories are kept, so stray files (``.gitkeep``,
    ``.DS_Store``) and hidden folders (``.ipynb_checkpoints``) are never
    treated as an experiment. If ``root`` has no matching entries the result
    is an empty list.
    """
    return sorted(
        entry for entry in Path(root).glob('*')
        if entry.is_dir() and not entry.name.startswith('.')
    )


# Find most recent experiment.
# NOTE(review): "most recent" is the last entry in name order, which assumes
# experiment directory names sort chronologically (e.g. timestamped) — confirm.
artifact_dirs = find_experiment_dirs('../data/artifacts')
if not artifact_dirs:
    # `analyzer` is left undefined in this case; the cells below need it.
    print("No experiments found. Run some experiments first!")
else:
    artifact_dir = artifact_dirs[-1]  # Most recent
    print(f"Loading experiment: {artifact_dir.name}")

    # Initialize spectral analyzer
    analyzer = SpectralAnalyzer(str(artifact_dir))

    # Display metadata (everything except the timestamp entry)
    print("\nExperiment details:")
    for k, v in analyzer.metadata.items():
        if k != 'timestamp':
            print(f"  {k}: {v}")

## 2. Spectral Evolution Analysis

In [None]:
# Band-wise spectral evolution over the recorded steps, shown as a 2x2 dashboard
# followed by pre-/post-switch summary statistics for each band.
evolution = analyzer.analyze_spectral_evolution()
steps = evolution['steps']
total_power = evolution['total_power']
switch = analyzer.switch_step
band_names = ('low', 'mid', 'high')

fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(15, 10))
(stack_ax, lines_ax), (power_ax, rate_ax) = axes

# Panel 1: stacked band shares (each band is drawn on top of the previous ones)
stack_labels = ('Low (0-10%)', 'Mid (10-30%)', 'High (30-50%)')
lower_edge = 0
for band, stack_label in zip(band_names, stack_labels):
    upper_edge = lower_edge + np.array(evolution[band])
    stack_ax.fill_between(steps, lower_edge, upper_edge, alpha=0.7, label=stack_label)
    lower_edge = upper_edge
stack_ax.axvline(x=switch, color='red', linestyle='--', linewidth=2, label='Switch')
stack_ax.set(xlabel='Step', ylabel='Relative Power',
             title='Frequency Band Distribution Over Time')
stack_ax.legend()
stack_ax.set_ylim(0, 1)

# Panel 2: the same three bands as individual line series
line_styles = (('b-', 'o', 'Low freq'), ('g-', 's', 'Mid freq'), ('r-', '^', 'High freq'))
for band, (fmt, marker, line_label) in zip(band_names, line_styles):
    lines_ax.plot(steps, evolution[band], fmt, marker=marker, markersize=4, label=line_label)
lines_ax.axvline(x=switch, color='red', linestyle='--', linewidth=2)
lines_ax.set(xlabel='Step', ylabel='Relative Power', title='Frequency Bands Evolution')
lines_ax.legend()
lines_ax.grid(True, alpha=0.3)

# Panel 3: total spectral power on a logarithmic y-axis
power_ax.semilogy(steps, total_power, 'k-', marker='o', markersize=4)
power_ax.axvline(x=switch, color='red', linestyle='--', linewidth=2)
power_ax.set(xlabel='Step', ylabel='Total Power (log scale)',
             title='Total Spectral Power Evolution')
power_ax.grid(True, alpha=0.3)

# Panel 4: absolute step-to-step change in total power, relative to the previous step
if len(total_power) > 1:
    previous_power = np.array(total_power[:-1])
    power_change = np.abs(np.diff(total_power)) / previous_power
    rate_ax.plot(steps[1:], power_change, 'm-', marker='o', markersize=4)
    rate_ax.axvline(x=switch, color='red', linestyle='--', linewidth=2)
    rate_ax.set(xlabel='Step', ylabel='Relative Change',
                title='Rate of Spectral Power Change')
    rate_ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Summary statistics: split every band's series at the switch step
print("\nFrequency distribution statistics:")
for band in band_names:
    before, after = [], []
    for step_value, share in zip(steps, evolution[band]):
        (before if step_value < switch else after).append(share)

    # A band is only reported when both sides of the switch have samples
    if not (before and after):
        continue

    print(f"\n{band.title()} frequency band:")
    print(f"  Pre-switch mean: {np.mean(before):.4f} ± {np.std(before):.4f}")
    print(f"  Post-switch mean: {np.mean(after):.4f} ± {np.std(after):.4f}")
    print(f"  Change: {np.mean(after) - np.mean(before):+.4f}")

## 3. Spectral Coherence Analysis

In [None]:
# Spectral coherence: time series with its minimum highlighted, step-wise deltas,
# and summary statistics (overall and within 3 steps of the switch).
coherence = analyzer.compute_spectral_coherence()

if coherence['steps']:
    steps = coherence['steps']
    mean_coh = coherence['mean_coherence']
    switch = analyzer.switch_step
    overall_mean = np.mean(mean_coh)

    fig, (ax1, ax2) = plt.subplots(nrows=2, ncols=1, figsize=(12, 10))

    # Top panel: coherence per step, with the switch and the overall mean as guides
    ax1.plot(steps, mean_coh, 'b-', marker='o', markersize=6, linewidth=2)
    ax1.axvline(x=switch, color='red', linestyle='--', linewidth=2, label='Switch')
    ax1.axhline(y=overall_mean, color='gray', linestyle=':', label=f'Mean: {overall_mean:.3f}')
    ax1.set(xlabel='Step', ylabel='Mean Spectral Coherence',
            title='Step-to-Step Spectral Coherence Evolution')
    ax1.set_ylim(0, 1)
    ax1.legend()
    ax1.grid(True, alpha=0.3)

    # Mark and annotate the lowest-coherence sample
    min_idx = np.argmin(mean_coh)
    lowest_step = steps[min_idx]
    lowest_coh = mean_coh[min_idx]
    ax1.scatter([lowest_step], [lowest_coh], color='red', s=100, zorder=5)
    ax1.annotate(
        f'Min: {lowest_coh:.3f}\nStep {lowest_step}',
        xy=(lowest_step, lowest_coh),
        xytext=(10, 20),
        textcoords='offset points',
        bbox=dict(boxstyle='round,pad=0.5', fc='yellow', alpha=0.5),
        arrowprops=dict(arrowstyle='->', connectionstyle='arc3,rad=0'),
    )

    # Bottom panel: signed change between consecutive samples (red = decrease)
    if len(mean_coh) > 1:
        coh_change = np.diff(mean_coh)
        bar_colors = ['red' if delta < 0 else 'green' for delta in coh_change]
        ax2.bar(steps[1:], coh_change, alpha=0.7, color=bar_colors)
        ax2.axvline(x=switch, color='red', linestyle='--', linewidth=2)
        ax2.set(xlabel='Step', ylabel='Coherence Change',
                title='Step-wise Coherence Changes')
        ax2.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

    # Overall statistics
    print("\nCoherence Statistics:")
    print(f"Mean coherence: {overall_mean:.4f}")
    print(f"Std deviation: {np.std(mean_coh):.4f}")
    print(f"Minimum at step {lowest_step}: {lowest_coh:.4f}")

    # Statistics restricted to samples whose step is within 3 of the switch
    nearby = [i for i, step_value in enumerate(steps) if abs(step_value - switch) <= 3]
    if nearby:
        window_coh = [mean_coh[i] for i in nearby]
        print(f"\nCoherence around switch (±3 steps):")
        print(f"  Mean: {np.mean(window_coh):.4f}")
        print(f"  Drop: {max(window_coh) - min(window_coh):.4f}")

## 4. Spectral Entropy Analysis

In [None]:
# Spectral entropy: evolution plot, step-to-step change with outliers highlighted,
# and summary statistics (overall and around the prompt switch).
entropy = analyzer.compute_spectral_entropy()

# Materialize the steps as a list: the code below relies on truthiness,
# `in`, and `.index()`, which a plain list supports.
steps = list(entropy['steps'])
norm_entropy = entropy['normalized_entropy']

if not steps:
    # Same guard as the coherence cell; without it np.min/np.max raise on empty input.
    print("No spectral entropy data available.")
else:
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 10))

    # Entropy evolution
    ax1.plot(steps, norm_entropy, 'purple', marker='o', markersize=6, linewidth=2)
    ax1.axvline(x=analyzer.switch_step, color='red', linestyle='--', linewidth=2, label='Switch')
    ax1.fill_between(steps, norm_entropy, alpha=0.3, color='purple')
    ax1.set_xlabel('Step')
    ax1.set_ylabel('Normalized Spectral Entropy')
    ax1.set_title('Spectral Entropy Evolution (0=concentrated, 1=uniform)')
    ax1.set_ylim(0, 1)
    ax1.legend()
    ax1.grid(True, alpha=0.3)

    # Entropy derivative (rate of change); needs at least two samples
    if len(norm_entropy) > 1:
        entropy_diff = np.diff(norm_entropy)
        ax2.plot(steps[1:], entropy_diff, 'darkviolet', marker='s', markersize=4)
        ax2.axvline(x=analyzer.switch_step, color='red', linestyle='--', linewidth=2)
        ax2.axhline(y=0, color='gray', linestyle='-', alpha=0.5)
        ax2.set_xlabel('Step')
        ax2.set_ylabel('Entropy Change')
        ax2.set_title('Rate of Entropy Change')
        ax2.grid(True, alpha=0.3)

        # Highlight changes larger than twice the std of the absolute differences
        threshold = 2 * np.std(np.abs(entropy_diff))
        large_changes = [(s, d) for s, d in zip(steps[1:], entropy_diff) if abs(d) > threshold]
        for s, d in large_changes:
            ax2.scatter([s], [d], color='red', s=100, zorder=5)

    plt.tight_layout()
    plt.show()

    # Statistics
    print("\nSpectral Entropy Statistics:")
    print(f"Mean entropy: {np.mean(norm_entropy):.4f}")
    print(f"Std deviation: {np.std(norm_entropy):.4f}")
    print(f"Range: [{np.min(norm_entropy):.4f}, {np.max(norm_entropy):.4f}]")

    # Entropy at switch (only when the switch step itself was recorded)
    if analyzer.switch_step in steps:
        switch_idx = steps.index(analyzer.switch_step)
        print(f"\nEntropy at switch: {norm_entropy[switch_idx]:.4f}")

        # Compare to the local average. The window is measured in step distance
        # (as in the coherence cell), so the "±N steps" label stays accurate
        # even if the recorded steps are not consecutive.
        window = 3
        local_indices = [
            i for i, s in enumerate(steps)
            if i != switch_idx and abs(s - analyzer.switch_step) <= window
        ]
        if local_indices:
            local_mean = np.mean([norm_entropy[i] for i in local_indices])
            print(f"Local mean (±{window} steps): {local_mean:.4f}")
            print(f"Deviation: {norm_entropy[switch_idx] - local_mean:+.4f}")

## 5. Detailed Frequency Analysis at Key Steps

In [None]:
# Analyze frequency content at three steps: 5 before the switch, the switch, 5 after.
# NOTE(review): switch_step - 5 can be negative for an early switch; how the
# analyzer treats out-of-range steps is not visible here — confirm.
focus_steps = [
    analyzer.switch_step - 5,
    analyzer.switch_step,
    analyzer.switch_step + 5
]

freq_analysis = analyzer.analyze_frequency_domain_artifacts(focus_steps)

# Visualize radial power spectra: one panel per returned step, so the layout
# still works if the analyzer returns fewer (or more) entries than requested.
n_panels = len(freq_analysis)
if n_panels == 0:
    print("No frequency-domain results returned for the requested steps.")
else:
    fig, axes = plt.subplots(1, n_panels, figsize=(6 * n_panels, 6), squeeze=False)

    for ax, (step_key, data) in zip(axes[0], freq_analysis.items()):
        step = int(step_key.split('_')[1])
        n_bins = None  # length of the last plotted radial profile, if any

        # Plot radial profiles for the first few channels (at most 4)
        for c in range(min(4, len(data['spatial_frequencies']))):
            channel_data = data['spatial_frequencies'][f'channel_{c}']
            radial_prof = np.array(channel_data['radial_profile'])
            radial_std = np.array(channel_data['radial_std'])

            frequencies = np.arange(len(radial_prof))
            ax.semilogy(frequencies, radial_prof, label=f'Channel {c}', linewidth=2)

            # Add ±1 std error bands.
            # NOTE(review): profile - std can be <= 0, which a log axis cannot
            # show faithfully; matplotlib's handling of non-positive values applies.
            ax.fill_between(frequencies,
                            radial_prof - radial_std,
                            radial_prof + radial_std,
                            alpha=0.2)
            n_bins = len(radial_prof)

        ax.set_xlabel('Spatial Frequency (cycles/image)')
        ax.set_ylabel('Power (log scale)')
        ax.set_title(f'Step {step}' + (' (SWITCH)' if step == analyzer.switch_step else ''))
        ax.grid(True, alpha=0.3)

        # Mark the half-way point of the radial profile, labelled as Nyquist.
        # Only drawn when a profile was plotted in this panel, so no value
        # leaks in from a previous panel.
        if n_bins is not None:
            ax.axvline(x=n_bins // 2, color='red', linestyle=':', alpha=0.5, label='Nyquist')

        # Build the legend last so it includes the Nyquist marker; skip it when
        # nothing labelled was drawn.
        if ax.get_legend_handles_labels()[0]:
            ax.legend()

    plt.suptitle('Radial Power Spectra at Key Steps', fontsize=16)
    plt.tight_layout()
    plt.show()

# Analyze dominant frequencies
print("\nDominant Frequencies Analysis:")
for step_key, data in freq_analysis.items():
    step = int(step_key.split('_')[1])
    print(f"\nStep {step}:")

    # Collect the top 3 dominant frequencies of each of the first (up to 4) channels
    all_freqs = []
    for c in range(min(4, len(data['spatial_frequencies']))):
        channel_data = data['spatial_frequencies'][f'channel_{c}']
        for freq_info in channel_data['dominant_freqs'][:3]:  # Top 3
            all_freqs.append((c, freq_info['freq_r'], freq_info['magnitude']))

    # Sort by magnitude, strongest first
    all_freqs.sort(key=lambda x: x[2], reverse=True)

    print("  Top frequencies:")
    for c, freq_r, mag in all_freqs[:5]:
        print(f"    Channel {c}: r={freq_r:.3f}, magnitude={mag:.2e}")

## 6. Spectral Discontinuity Detection

In [None]:
# Detect spectral discontinuities and show them on a per-type timeline.
# `discontinuities` maps a discontinuity type to a list of (step, value) pairs.
discontinuities = analyzer.detect_frequency_discontinuities(threshold_multiplier=2.0)

# Collect every step that has at least one discontinuity
all_steps = {step for disc_list in discontinuities.values() for step, _ in disc_list}

if all_steps:
    # The figure is created only when there is something to draw, so the
    # "nothing detected" branch does not leave an empty figure behind.
    fig, ax = plt.subplots(figsize=(14, 8))

    # One row and one colour per discontinuity type
    disc_types = list(discontinuities.keys())
    colors = plt.cm.tab10(np.linspace(0, 1, len(disc_types)))

    for row, (disc_type, disc_list) in enumerate(discontinuities.items()):
        if not disc_list:
            continue

        event_steps = [step for step, _ in disc_list]
        magnitudes = [abs(value) for _, value in disc_list]

        # Marker area scales with magnitude relative to the largest one (max 100).
        # abs() keeps sizes non-negative; a zero peak falls back to a uniform
        # size instead of dividing by zero.
        peak = max(magnitudes)
        if peak > 0:
            sizes = [m / peak * 100 for m in magnitudes]
        else:
            sizes = [100] * len(magnitudes)

        ax.scatter(event_steps, [row] * len(event_steps),
                   s=sizes,
                   c=[colors[row]] * len(event_steps),
                   alpha=0.6,
                   label=disc_type.replace('_', ' ').title())

    # Switch marker and a ±2-step shaded window around it
    ax.axvline(x=analyzer.switch_step, color='red', linestyle='--',
               linewidth=2, label='Prompt Switch')
    ax.axvspan(analyzer.switch_step - 2, analyzer.switch_step + 2,
               alpha=0.1, color='red', label='Switch window')

    # Formatting
    ax.set_xlabel('Step', fontsize=12)
    ax.set_yticks(list(range(len(disc_types))))
    ax.set_yticklabels([d.replace('_', ' ').title() for d in disc_types])
    ax.set_title('Spectral Discontinuities Timeline', fontsize=14)
    ax.grid(True, alpha=0.3, axis='x')

    # The legend is built after all labelled artists exist, so it also lists
    # the switch window.
    ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')

    plt.tight_layout()
    plt.show()

    # Print summary
    print("\nSpectral Discontinuity Summary:")
    for disc_type, disc_list in discontinuities.items():
        if disc_list:
            print(f"\n{disc_type.replace('_', ' ').title()}:")
            print(f"  Count: {len(disc_list)}")

            # Distance from switch statistics
            distances = [abs(step - analyzer.switch_step) for step, _ in disc_list]
            print(f"  Mean distance from switch: {np.mean(distances):.1f} steps")
            print(f"  Closest to switch: {min(distances)} steps")
else:
    print("No spectral discontinuities detected.")

## 7. Channel-wise Spectral Analysis

In [None]:
# Interactive channel-wise spectral analysis
def plot_channel_spectrum(channel=0):
    """Draw a 2x2 spectral dashboard for one channel.

    Panels: frequency-band shares, total power (log scale), high/low band
    ratio, and an approximate spectral centroid. Each panel marks the switch
    step with a dashed red line.

    Args:
        channel: Channel index forwarded to ``analyzer.analyze_spectral_evolution``.
    """
    evolution = analyzer.analyze_spectral_evolution(channel=channel)
    steps = evolution['steps']
    low, mid, high = (np.array(evolution[band]) for band in ('low', 'mid', 'high'))

    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    (bands_ax, power_ax), (ratio_ax, centroid_ax) = axes

    # 1. Frequency band evolution
    for band, fmt, marker, label in (
        ('low', 'b-', 'o', 'Low'),
        ('mid', 'g-', 's', 'Mid'),
        ('high', 'r-', '^', 'High'),
    ):
        bands_ax.plot(steps, evolution[band], fmt, marker=marker, label=label)
    bands_ax.legend()

    # 2. Total power
    power_ax.semilogy(steps, evolution['total_power'], 'k-', marker='o')

    # 3. High/Low frequency ratio (small epsilon keeps the division finite)
    high_low_ratio = high / (low + 1e-10)
    ratio_ax.plot(steps, high_low_ratio, 'm-', marker='o')

    # 4. Approximate spectral centroid: band shares weighted by 0.05 / 0.2 / 0.4
    centroid = 0.05 * low + 0.2 * mid + 0.4 * high
    centroid_ax.plot(steps, centroid, 'orange', marker='o')

    # Decoration shared by all four panels
    panel_text = (
        (bands_ax, 'Frequency Bands', 'Relative Power'),
        (power_ax, 'Total Power', 'Power (log scale)'),
        (ratio_ax, 'High/Low Frequency Ratio', 'Ratio'),
        (centroid_ax, 'Spectral Centroid', 'Normalized Frequency'),
    )
    for ax, heading, ylabel in panel_text:
        ax.axvline(x=analyzer.switch_step, color='red', linestyle='--', alpha=0.7)
        ax.set(title=f'Channel {channel}: {heading}', xlabel='Step', ylabel=ylabel)
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

# Slider over channel indices 0..15 (SD3 has 16 channels); the plot is redrawn
# only when the slider is released, not while it is being dragged.
slider_options = dict(
    description='Channel:',
    value=0,
    min=0,
    max=15,
    step=1,
    continuous_update=False,
)
channel_slider = widgets.IntSlider(**slider_options)

# Last expression of the cell: displays the slider bound to the plotting function
widgets.interactive(plot_channel_spectrum, channel=channel_slider)

## 8. Generate Spectral Analysis Report

In [None]:
# Generate the spectral report and echo it in the notebook
report = analyzer.generate_spectral_report()
print(report)

# Save the report next to the experiment artifacts. The encoding is pinned to
# UTF-8 so writing does not depend on the platform's default text encoding
# (which can fail on non-ASCII characters in the report).
report_path = Path(artifact_dir) / "spectral_analysis_report.md"
report_path.write_text(report, encoding="utf-8")
print(f"\nReport saved to: {report_path}")

# Generate and save visualizations; `plots` maps a plot name to its output path
print("\nGenerating spectral visualizations...")
plots = analyzer.visualize_spectral_analysis()
print(f"Generated {len(plots)} visualizations:")
for name, path in plots.items():
    print(f"  - {name}: {path.name}")

## 9. Cross-Experiment Spectral Comparison

In [None]:
# Compare spectral properties across multiple experiments
from src.analysis.spectral_analyzer import compare_spectral_properties

# Compare the last (up to) 5 experiments on four spectral metrics
recent_dirs = [str(d) for d in artifact_dirs[-5:]]

if len(recent_dirs) <= 1:
    print("Need at least 2 experiments for comparison.")
else:
    comparisons = compare_spectral_properties(recent_dirs)

    # Short x-axis labels: the last 6 characters of each directory name
    exp_names = [Path(d).name[-6:] for d in recent_dirs]

    # One value per experiment for each metric; missing entries count as 0
    metric_keys = ('entropy_at_switch', 'coherence_drop',
                   'high_freq_increase', 'spectral_discontinuities')
    entropy_values, coh_drops, hf_changes, disc_counts = (
        [comparisons[key].get(d, 0) for d in recent_dirs] for key in metric_keys
    )

    # High-frequency bars are coloured by sign: red for decreases, blue otherwise
    colors = ['red' if c < 0 else 'blue' for c in hf_changes]

    # (values, bar colour, y-label, title), one entry per panel in row-major order
    panels = (
        (entropy_values, 'purple', 'Normalized Entropy', 'Spectral Entropy at Switch'),
        (coh_drops, 'green', 'Coherence Drop', 'Maximum Coherence Drop'),
        (hf_changes, colors, 'Relative Change', 'High Frequency Content Change'),
        (disc_counts, 'orange', 'Count', 'Total Spectral Discontinuities'),
    )

    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    for ax, (bar_values, bar_color, ylabel, title) in zip(axes.flat, panels):
        ax.bar(exp_names, bar_values, alpha=0.7, color=bar_color)
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.tick_params(axis='x', rotation=45)

    # Panel-specific extras: fixed 0..1 range for entropy, zero line for signed changes
    axes[0, 0].set_ylim(0, 1)
    axes[1, 0].axhline(y=0, color='gray', linestyle='-', alpha=0.5)

    plt.tight_layout()
    plt.show()

    # Mean / std across experiments for the first three metrics
    print("\nCross-Experiment Statistics:")
    for heading, series in (
        ("Entropy at switch", entropy_values),
        ("Coherence drop", coh_drops),
        ("High frequency change", hf_changes),
    ):
        print(f"\n{heading}:")
        print(f"  Mean: {np.mean(series):.4f}")
        print(f"  Std: {np.std(series):.4f}")