# Scheduler Experiment: Early Prompt Switching at Step 10

This notebook runs every scheduler listed in `SCHEDULERS` below with a prompt switch at step 10, then compares the change in high-frequency content across the switch to identify which schedulers handle early transitions best.

In [None]:
import datetime
import json
import math
import subprocess
from pathlib import Path
from typing import Dict, List, Optional

import matplotlib.pyplot as plt
import modal
import pandas as pd
import seaborn as sns

## Define Experiment Parameters

In [None]:
# Scheduler class names to test, one per line, in the order they will be run.
# str.split() with no arguments splits on any run of whitespace, so this
# produces a plain list of the names below.
SCHEDULERS = """
    FlowMatchEulerDiscreteScheduler
    FlowMatchHeunDiscreteScheduler
    DEISMultistepScheduler
    DPMSolverMultistepScheduler
    DPMSolverSinglestepScheduler
    KDPM2DiscreteScheduler
    KDPM2AncestralDiscreteScheduler
    EulerDiscreteScheduler
    EulerAncestralDiscreteScheduler
    HeunDiscreteScheduler
    PNDMScheduler
    DDIMScheduler
    DDPMScheduler
    LCMScheduler
""".split()

# Experiment parameters (held constant across schedulers)
PROMPT_1 = "A serene mountain landscape with snow peaks"      # sent as --prompt-1
PROMPT_2 = "A bustling city street at night with neon lights"  # sent as --prompt-2
SWITCH_STEP = 10  # sent as --switch-step
SEED = 42         # sent as --seed
# NOTE(review): NUM_STEPS is not referenced by any other cell visible in this
# notebook, so it is not forwarded to the experiment command -- confirm intent.
NUM_STEPS = 50

## Run Experiments

In [None]:
def _parse_hf_metrics(stdout: str) -> Dict:
    """Pull the high-frequency metrics out of the command's captured stdout.

    Each metric is located by its label, and the number is read from the text
    that follows the label. Splitting on the label (rather than on the first
    ':') keeps parsing correct when a line carries a prefix that itself
    contains ':' such as a timestamp or log level.

    Returns a dict with keys ``pre_hf``, ``post_hf`` and ``hf_change``. A metric
    that is absent or not numeric is None; if a label appears more than once,
    the last parseable value wins.
    """
    labels = (
        ("Pre-switch high freq:", "pre_hf"),
        ("Post-switch high freq:", "post_hf"),
        ("High freq change:", "hf_change"),
    )
    metrics = {key: None for _, key in labels}

    for line in stdout.splitlines():
        for label, key in labels:
            if label not in line:
                continue
            try:
                metrics[key] = float(line.split(label, 1)[1].strip())
            except ValueError:
                # Label present but the trailing text is not a number:
                # leave the metric as it was instead of aborting the whole run.
                pass
            break  # first matching label decides, like the original if/elif chain
    return metrics


def run_single_experiment(scheduler: str, timeout: float = 300) -> Dict:
    """Run the prompt-switch experiment once for ``scheduler`` via ``modal run``.

    Args:
        scheduler: Scheduler class name, passed to the ``--scheduler`` flag.
        timeout: Seconds to wait for the subprocess before giving up.

    Returns:
        A dict that always contains ``scheduler`` and ``status``:

        * ``status == "success"``: also ``pre_hf``, ``post_hf``, ``hf_change``
          and ``artifact_severity``. ``hf_change`` is always a float here;
          ``pre_hf`` / ``post_hf`` may be None if they were not printed.
        * ``status == "failed"``: the command exited non-zero (``error`` is its
          stderr), or exited zero without a parseable HF-change value.
        * ``status == "error"``: launching or waiting on the command raised
          (for example a timeout); ``error`` is the exception text.
    """
    # NOTE(review): NUM_STEPS is defined in the parameters cell but is not
    # forwarded here; confirm whether the script accepts a step-count flag.
    cmd = [
        "modal", "run", "src/modal/sd3_modal.py",
        "--prompt-1", PROMPT_1,
        "--prompt-2", PROMPT_2,
        "--switch-step", str(SWITCH_STEP),
        "--seed", str(SEED),
        "--scheduler", scheduler,
    ]

    try:
        result = subprocess.run(cmd, capture_output=True, text=True, timeout=timeout)
    except Exception as e:
        # Best effort: one broken scheduler must not abort the whole sweep.
        return {
            "scheduler": scheduler,
            "status": "error",
            "error": str(e),
        }

    if result.returncode != 0:
        return {
            "scheduler": scheduler,
            "status": "failed",
            "error": result.stderr,
        }

    metrics = _parse_hf_metrics(result.stdout)
    if metrics["hf_change"] is None:
        # Without this guard the run would be reported as a success carrying
        # hf_change=None, which breaks numeric formatting downstream.
        return {
            "scheduler": scheduler,
            "status": "failed",
            "error": "Command exited 0 but no 'High freq change:' value was found in its output",
        }

    return {
        "scheduler": scheduler,
        "status": "success",
        "pre_hf": metrics["pre_hf"],
        "post_hf": metrics["post_hf"],
        "hf_change": metrics["hf_change"],
        "artifact_severity": categorize_artifacts(metrics["hf_change"]),
    }

def categorize_artifacts(
    hf_change: Optional[float],
    moderate_threshold: float = 0.05,
    strong_threshold: float = 0.1,
) -> str:
    """Categorize artifact severity from the magnitude of the HF change.

    Args:
        hf_change: Signed high-frequency change, or None / NaN when no usable
            value is available.
        moderate_threshold: Magnitudes strictly above this are at least "moderate".
        strong_threshold: Magnitudes strictly above this are "strong".

    Returns:
        "unknown" for None or NaN, otherwise "strong", "moderate" or "clean".
    """
    # NaN compares False against every threshold, so without this check it
    # would fall through to "clean".
    if hf_change is None or math.isnan(hf_change):
        return "unknown"

    magnitude = abs(hf_change)
    if magnitude > strong_threshold:
        return "strong"
    if magnitude > moderate_threshold:
        return "moderate"
    return "clean"

In [None]:
# Run experiments for all schedulers.
# `results` is rebuilt from scratch so re-running this cell never appends to
# results left over from an earlier run.
results = []

for i, scheduler in enumerate(SCHEDULERS, 1):
    print(f"[{i}/{len(SCHEDULERS)}] Testing {scheduler}...")
    result = run_single_experiment(scheduler)
    results.append(result)

    if result["status"] != "success":
        print(f"  ✗ Failed: {result.get('error', 'Unknown error')}")
    elif result.get("hf_change") is None:
        # A run can exit cleanly without printing the metric; formatting None
        # with '+.4f' would raise TypeError and abort the remaining schedulers.
        print("  ⚠ Completed, but no HF change value was found in the output")
    else:
        print(f"  ✓ HF change: {result['hf_change']:+.4f} ({result['artifact_severity']})")

## Analyze Results

In [None]:
# Convert to DataFrame for analysis.
# The columns are named explicitly so the metric columns always exist (filled
# with NaN). Otherwise, when every run fails, no result dict carries
# 'hf_change' and the lookups below raise KeyError.
RESULT_COLUMNS = [
    "scheduler", "status",
    "pre_hf", "post_hf", "hf_change", "artifact_severity",
    "error",
]
df = pd.DataFrame(results, columns=RESULT_COLUMNS)

# 'error' is only populated for unsuccessful runs, so drop it from this view.
successful_df = df[df['status'] == 'success'].drop(columns='error').copy()

# Sort by absolute HF change (smallest first; NaN values sort last)
successful_df['abs_hf_change'] = pd.to_numeric(
    successful_df['hf_change'], errors='coerce'
).abs()
successful_df = successful_df.sort_values('abs_hf_change')

print("\nExperiment Summary:")
print(f"Total schedulers tested: {len(SCHEDULERS)}")
print(f"Successful runs: {len(successful_df)}")
print(f"Failed runs: {len(df) - len(successful_df)}")

## Visualization

In [None]:
# Create visualizations (2x2 grid) and save the successful results to CSV.
if len(successful_df) > 0:
    # Single severity -> color lookup shared by the bar chart and the pie
    # chart, so a label always gets the same color regardless of how many
    # categories are present or how they are ordered.
    severity_colors = {
        'clean': 'green',
        'moderate': 'orange',
        'strong': 'red',
        'unknown': 'gray',
    }
    fallback_color = 'gray'  # for any label not listed above (including NaN)

    positions = list(range(len(successful_df)))
    scheduler_labels = successful_df['scheduler']

    fig, axes = plt.subplots(2, 2, figsize=(15, 12))

    # 1. Bar chart of HF changes by scheduler
    ax1 = axes[0, 0]
    bar_colors = [
        severity_colors.get(label, fallback_color)
        for label in successful_df['artifact_severity']
    ]
    ax1.bar(positions, successful_df['hf_change'], color=bar_colors)
    ax1.set_xticks(positions)
    ax1.set_xticklabels(scheduler_labels, rotation=45, ha='right')
    ax1.axhline(y=0, color='black', linestyle='-', alpha=0.3)
    # Dashed guides at the +/- moderate (0.05) and strong (0.1) thresholds.
    for threshold, line_color in ((0.05, 'orange'), (0.1, 'red')):
        for y in (threshold, -threshold):
            ax1.axhline(y=y, color=line_color, linestyle='--', alpha=0.5)
    ax1.set_ylabel('High Frequency Change')
    ax1.set_title('HF Change by Scheduler (Step 10 Switch)')
    ax1.grid(True, alpha=0.3)

    # 2. Artifact severity distribution
    ax2 = axes[0, 1]
    severity_counts = successful_df['artifact_severity'].value_counts()
    # value_counts() orders by frequency, so colors must be looked up by label;
    # indexing a fixed color list by position would paint the most common
    # severity green whatever it is.
    ax2.pie(
        severity_counts.values,
        labels=severity_counts.index,
        autopct='%1.1f%%',
        colors=[severity_colors.get(label, fallback_color) for label in severity_counts.index],
        startangle=90,
    )
    ax2.set_title('Artifact Severity Distribution')

    # 3. Pre vs Post HF comparison
    ax3 = axes[1, 0]
    width = 0.35
    ax3.bar([p - width / 2 for p in positions], successful_df['pre_hf'], width,
            label='Pre-switch', alpha=0.7)
    ax3.bar([p + width / 2 for p in positions], successful_df['post_hf'], width,
            label='Post-switch', alpha=0.7)
    ax3.set_xticks(positions)
    ax3.set_xticklabels(scheduler_labels, rotation=45, ha='right')
    ax3.set_ylabel('High Frequency Power')
    ax3.set_title('Pre vs Post Switch HF Power')
    ax3.legend()
    ax3.grid(True, alpha=0.3)

    # 4. Top recommendations (text panel)
    ax4 = axes[1, 1]
    ax4.axis('off')

    # successful_df is sorted by ascending |HF change|: head = least artifacts.
    best = successful_df.head(3)
    # Reverse the tail so entry 1 of the "worst" list is the largest change.
    worst = successful_df.tail(3).iloc[::-1]

    def _ranked_lines(frame):
        """Format rows as '  N. <scheduler>: <signed change>' lines."""
        return [
            f"  {rank}. {row['scheduler']}: {row['hf_change']:+.4f}"
            for rank, (_, row) in enumerate(frame.iterrows(), 1)
        ]

    recommendations = "\n".join(
        ["Best Schedulers (Least Artifacts):", ""]
        + _ranked_lines(best)
        + ["", "Worst Schedulers (Most Artifacts):", ""]
        + _ranked_lines(worst)
    )

    ax4.text(0.1, 0.9, recommendations, transform=ax4.transAxes,
             verticalalignment='top', fontfamily='monospace', fontsize=10)
    ax4.set_title('Recommendations')

    plt.tight_layout()
    plt.show()

    # Save results (timestamped so repeated runs do not overwrite each other)
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    csv_path = f"scheduler_experiment_{timestamp}.csv"
    successful_df.to_csv(csv_path, index=False)
    print(f"\nResults saved to: {csv_path}")

## Detailed Analysis

In [None]:
# Group by severity, then print summary statistics of the HF change.
if len(successful_df) > 0:
    print("\nDetailed Breakdown by Artifact Severity:\n")

    # 'unknown' is included so runs without a usable metric are not silently
    # left out of the breakdown.
    for severity in ['clean', 'moderate', 'strong', 'unknown']:
        group = successful_df[successful_df['artifact_severity'] == severity]
        if len(group) > 0:
            print(f"{severity.upper()} ({len(group)} schedulers):")
            for _, row in group.iterrows():
                print(f"  - {row['scheduler']}: {row['hf_change']:+.4f}")
            print()

    # Statistical summary
    hf_changes = successful_df['hf_change']
    print("\nStatistical Summary:")
    print(f"Mean HF change: {hf_changes.mean():+.4f}")
    print(f"Std HF change: {hf_changes.std():.4f}")

    if hf_changes.notna().any():
        # idxmin()/idxmax() return index *labels*. successful_df was filtered
        # and re-sorted, so labels no longer match row positions and the
        # lookup must use .loc, not .iloc.
        min_row = successful_df.loc[hf_changes.idxmin()]
        max_row = successful_df.loc[hf_changes.idxmax()]
        print(f"Min HF change: {min_row['hf_change']:+.4f} ({min_row['scheduler']})")
        print(f"Max HF change: {max_row['hf_change']:+.4f} ({max_row['scheduler']})")

## Conclusions and Recommendations

In [None]:
# Print the closing summary: a banner, then counts per severity and the list
# of schedulers whose transition was categorized as 'clean'.
separator = "=" * 60
print("\n" + separator)
print("CONCLUSIONS")
print(separator)

if len(successful_df):
    severity_labels = successful_df['artifact_severity']
    clean_schedulers = successful_df[severity_labels == 'clean']
    moderate_count = int((severity_labels == 'moderate').sum())
    strong_count = int((severity_labels == 'strong').sum())

    print(f"\n1. Early prompt switching (step {SWITCH_STEP}) is challenging because:")
    for reason in (
        "   - High noise levels mean large structural changes",
        "   - The image hasn't formed coherent features yet",
        "   - Schedulers must handle dramatic latent space transitions",
    ):
        print(reason)

    print(f"\n2. Out of {len(successful_df)} schedulers tested:")
    print(f"   - {len(clean_schedulers)} produced clean transitions")
    print(f"   - {moderate_count} had moderate artifacts")
    print(f"   - {strong_count} had strong artifacts")

    # Section 3 is only printed when at least one scheduler was 'clean'.
    if len(clean_schedulers):
        print("\n3. RECOMMENDED schedulers for early switching:")
        for scheduler_name in clean_schedulers['scheduler']:
            print(f"   ✅ {scheduler_name}")

    # These insights are fixed text; they are not computed from the results
    # above.
    print("\n4. Key insights:")
    for insight in (
        "   - Flow-based schedulers tend to handle transitions better",
        "   - Deterministic schedulers produce more consistent results",
        "   - Ancestral sampling methods may introduce more artifacts",
    ):
        print(insight)