# SD3 Prompt Switching Artifact Exploration

This notebook provides interactive tools for exploring the artifacts that appear in Stable Diffusion 3 (SD3) generations when the prompt is switched partway through sampling.

In [None]:
import sys
sys.path.append('..')

import torch
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import json
from IPython.display import display, Image as IPImage
import ipywidgets as widgets
from src.analysis.artifact_detector import ArtifactDetector

%matplotlib inline
# Notebook-wide matplotlib defaults; cells that pass their own figsize
# override the first setting for that figure.
plt.rcParams['figure.figsize'] = (12, 8)
plt.rcParams['font.size'] = 12

## 1. Load and Explore Single Experiment

In [None]:
# List available experiments.
# Only directories are kept: later cells build paths such as
# `artifact_dir / 'output.png'`, so stray files (e.g. .DS_Store) are not
# valid experiments.
artifact_dirs = sorted(
    entry for entry in Path('../data/artifacts').glob('*') if entry.is_dir()
)
print(f"Found {len(artifact_dirs)} experiments:")

# Show only the last 10, but label each one with its real position in
# `artifact_dirs` so the printed number can be used directly as `exp_idx`.
first_shown = max(len(artifact_dirs) - 10, 0)
for i, exp_dir in enumerate(artifact_dirs[first_shown:], start=first_shown):
    print(f"{i}: {exp_dir.name}")

In [None]:
# Pick the experiment to inspect; -1 selects the last entry of artifact_dirs.
exp_idx = -1
artifact_dir = artifact_dirs[exp_idx]

# The detector is constructed from the experiment directory path (as a string).
detector = ArtifactDetector(str(artifact_dir))

# Dump every metadata field, one per line.
print("Experiment metadata:")
for field_name, field_value in detector.metadata.items():
    print(f"  {field_name}: {field_value}")

# Show the image stored alongside the experiment data.
output_image = IPImage(str(artifact_dir / 'output.png'), width=400)
display(output_image)

## 2. Interactive Latent Visualization

In [None]:
def visualize_latent_evolution(detector, channel=0):
    """Plot one latent channel at every captured step as a grid of heatmaps.

    Every entry of ``detector.latents`` whose key contains ``"latents_step_"``
    becomes one panel. Panels are ordered by the integer step parsed from the
    end of the key, share a single colour scale, and the panel whose step
    equals ``detector.switch_step`` gets a red "(SWITCH)" title.

    Args:
        detector: Object exposing ``latents`` (mapping of key -> tensor) and
            ``switch_step``.
        channel: Index into the second dimension of each latent tensor; the
            first dimension is always indexed at 0.

    Returns:
        None. The figure is shown with ``plt.show()``. If no latent steps are
        present a message is printed and nothing is drawn.
    """
    # Sort on the parsed integer step: sorting the raw keys is lexicographic
    # and would place "latents_step_10" before "latents_step_2".
    step_keys = sorted(
        (int(key.split("_")[-1]), key)
        for key in detector.latents
        if "latents_step_" in key
    )
    if not step_keys:
        # Without this guard the grid computation below divides by zero.
        print("No latent steps found in this experiment")
        return

    steps = [step for step, _ in step_keys]
    # detach/cpu/float make the conversion work for tensors that track
    # gradients, live on an accelerator, or use a dtype numpy cannot hold.
    # NOTE(review): assumes the stored latents are torch tensors - confirm.
    latents_vis = [
        detector.latents[key][0, channel].detach().cpu().float().numpy()
        for _, key in step_keys
    ]

    # Create grid visualization: at most 6 columns, as many rows as needed.
    n_steps = len(steps)
    cols = min(6, n_steps)
    rows = (n_steps + cols - 1) // cols

    # squeeze=False always yields a 2-D array, so a single panel needs no
    # special casing.
    fig, axes = plt.subplots(
        rows, cols, figsize=(cols * 3, rows * 3), squeeze=False
    )
    axes = axes.ravel()

    # One colour range for all panels so magnitudes are comparable across steps.
    vmin = min(arr.min() for arr in latents_vis)
    vmax = max(arr.max() for arr in latents_vis)

    for ax, step, latent in zip(axes, steps, latents_vis):
        im = ax.imshow(latent, cmap='RdBu_r', vmin=vmin, vmax=vmax)
        ax.set_title(f'Step {step}')
        ax.axis('off')

        # Mark switch step
        if step == detector.switch_step:
            ax.set_title(f'Step {step} (SWITCH)', color='red')

    # Hide unused subplots in the last row.
    for ax in axes[n_steps:]:
        ax.axis('off')

    fig.suptitle(f'Latent Channel {channel} Evolution', fontsize=16)
    fig.colorbar(im, ax=axes, fraction=0.046, pad=0.04)
    fig.tight_layout()
    plt.show()

# Interactive widget
# Size the slider from the stored latents so every selectable channel exists
# (visualize_latent_evolution indexes each latent as latent[0, channel]).
# Falls back to 16 channels when no latent steps were captured.
_first_latent = next(
    (tensor for key, tensor in detector.latents.items() if "latents_step_" in key),
    None,
)
_n_channels = _first_latent.shape[1] if _first_latent is not None else 16
channel_slider = widgets.IntSlider(
    value=0, min=0, max=max(_n_channels - 1, 0), step=1,
    description='Channel:'
)
widgets.interactive(visualize_latent_evolution, detector=widgets.fixed(detector), channel=channel_slider)

## 3. Artifact Detection Results

In [None]:
# Detect discontinuities with a threshold multiplier of 2.0.
discontinuities = detector.detect_discontinuities(threshold_multiplier=2.0)

print("Detected discontinuities:")
for kind, hits in discontinuities.items():
    # Skip artifact types with no detections so no empty heading is printed.
    if not hits:
        continue
    print(f"\n{kind}:")
    for at_step, magnitude in hits:
        # Signed offset from the switch step (negative = before the switch).
        offset = at_step - detector.switch_step
        print(f"  Step {at_step} (switch{offset:+d}): {magnitude:.4f}")

In [None]:
# Ask the detector for its visualizations; the result maps a plot name to
# the path of the image shown below.
plots = detector.visualize_artifacts()
print(f"Generated {len(plots)} visualizations")

# Render each plot inline under its name.
for plot_name, plot_path in plots.items():
    print(f"\n{plot_name}:")
    rendered = IPImage(str(plot_path), width=800)
    display(rendered)

## 4. Attention Map Analysis

In [None]:
# Check if attention maps were captured
attention_path = artifact_dir / 'attention_maps.pt'
# map_location='cpu' lets maps that were saved from a GPU run load on a
# CPU-only machine.
# NOTE(review): torch.load unpickles the file - only open artifacts you
# produced yourself.
attention_maps = (
    torch.load(attention_path, map_location='cpu')
    if attention_path.exists()
    else {}
)

if attention_maps:
    print(f"Found {len(attention_maps)} attention maps")

    # Analyze attention pattern changes
    def plot_attention_entropy():
        """Plot the entropy of each captured attention map against its step.

        Reads the module-level ``attention_maps`` and ``detector``. For every
        key containing ``"_step_"`` the map is flattened, passed through a
        softmax, and reduced to a single entropy value; the values are plotted
        in step order with a dashed line at ``detector.switch_step``.
        Nothing is drawn when no key matches.
        """
        entropies = {}

        for key, attn in attention_maps.items():
            if "_step_" in key:
                step = int(key.split("_")[-1])
                # Compute entropy of attention distribution. Cast to float32
                # first so softmax/log do not run in reduced precision.
                # NOTE(review): if the stored maps are already normalised
                # probabilities, this softmax re-normalises them - confirm
                # that is intended.
                attn_probs = torch.softmax(attn.flatten().float(), dim=0)
                entropy = -torch.sum(attn_probs * torch.log(attn_probs + 1e-10))
                # NOTE(review): keys that share a step number overwrite each
                # other here; only the last one seen is plotted.
                entropies[step] = float(entropy)

        if entropies:
            steps = sorted(entropies)
            entropy_values = [entropies[s] for s in steps]

            plt.figure(figsize=(10, 6))
            plt.plot(steps, entropy_values, 'b-', marker='o')
            plt.axvline(x=detector.switch_step, color='r', linestyle='--', label='Prompt Switch')
            plt.xlabel('Step')
            plt.ylabel('Attention Entropy')
            plt.title('Attention Distribution Entropy Over Time')
            plt.legend()
            plt.grid(True, alpha=0.3)
            plt.show()

    plot_attention_entropy()
else:
    print("No attention maps found in this experiment")

## 5. Cross-Experiment Analysis

In [None]:
# Side-by-side comparison of the five most recent experiments.
from src.analysis.artifact_detector import compare_experiments

compare_dirs = list(map(str, artifact_dirs[-5:]))
comparisons = compare_experiments(compare_dirs)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

# Bars are labelled with the last six characters of each directory name.
exp_names = [Path(d).name[-6:] for d in compare_dirs]


def _draw_bar_panel(ax, heights, ylabel, title, **bar_style):
    """Draw one bar chart of per-experiment values on the given axes."""
    ax.bar(exp_names, heights, alpha=0.7, **bar_style)
    ax.set_xlabel('Experiment')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.tick_params(axis='x', rotation=45)


# Left panel: discontinuity counts (0 for experiments missing from the result).
counts = [comparisons['discontinuity_counts'].get(d, 0) for d in compare_dirs]
_draw_bar_panel(ax1, counts, 'Total Discontinuities', 'Artifact Count Comparison')

# Right panel: average discontinuity magnitude, same missing-value default.
avg_disc = [comparisons['avg_switch_discontinuity'].get(d, 0) for d in compare_dirs]
_draw_bar_panel(
    ax2, avg_disc, 'Avg Discontinuity', 'Average Discontinuity Magnitude',
    color='orange',
)

plt.tight_layout()
plt.show()

## 6. Pattern Identification

In [None]:
# Temporal pattern analysis: where do detections fall relative to the switch?
def analyze_temporal_patterns(detector):
    """Histogram detected artifacts by their step offset from the prompt switch.

    Calls ``detector.detect_discontinuities()`` with its defaults, converts
    every detection's step into an offset from ``detector.switch_step``
    (all artifact types pooled), then shows a histogram and prints summary
    statistics. Does nothing when there are no detections.
    """
    found = detector.detect_discontinuities()

    # Pool the offsets of every detection, regardless of artifact type.
    offsets = [
        step - detector.switch_step
        for detections in found.values()
        for step, _ in detections
    ]

    if not offsets:
        return

    plt.figure(figsize=(10, 6))
    plt.hist(offsets, bins=20, alpha=0.7, edgecolor='black')
    plt.axvline(x=0, color='r', linestyle='--', label='Switch Point')
    plt.xlabel('Steps Relative to Switch')
    plt.ylabel('Artifact Count')
    plt.title('Temporal Distribution of Artifacts')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.show()

    # Summary statistics; the last line is the 90th percentile of |offset|.
    mean_offset = np.mean(offsets)
    spread = np.std(offsets)
    within = np.percentile(np.abs(offsets), 90)
    print(f"Mean timing: {mean_offset:.2f} steps from switch")
    print(f"Std dev: {spread:.2f} steps")
    print(f"Most artifacts within {within:.0f} steps of switch")

# Run the temporal analysis on the experiment currently loaded into `detector`.
analyze_temporal_patterns(detector)

## 7. Generate Analysis Report

In [None]:
# Generate comprehensive report
report = detector.generate_report()
print(report)

# Save report next to the experiment's other artifacts.
# An explicit UTF-8 encoding keeps the write independent of the platform's
# default locale, which may not be able to encode every character in the
# report.
report_path = artifact_dir / "analysis_report.md"
report_path.write_text(report, encoding="utf-8")
print(f"\nReport saved to: {report_path}")