# ACBO Trajectory Visualization Example

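This notebook runs a small ACBO experiment and visualizes its learning trajectory. It shows convergence to the true parent set, optimization of the target variable, and a structure-learning dashboard. It also compares random and oracle intervention strategies and saves the resulting plots.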

In [None]:
# Setup
import sys
from pathlib import Path

# Add project root to path
project_root = Path().absolute().parent
sys.path.insert(0, str(project_root))

# Import required modules
from examples.demo_scms import create_easy_scm
from examples.demo_learning import DemoConfig
from examples.demo_evaluation import extract_trajectory_metrics_from_demo
from examples.complete_workflow_demo import run_progressive_learning_demo_with_scm
from src.causal_bayes_opt.visualization.plots import (
    plot_convergence, plot_target_optimization, 
    plot_structure_learning_dashboard, plot_baseline_comparison
)

import matplotlib.pyplot as plt
%matplotlib inline

## 1. Run a Single Experiment

In [None]:
# Configure a small experiment with a fixed random seed
config = DemoConfig(
    n_observational_samples=20,
    n_intervention_steps=15,
    learning_rate=1e-3,
    random_seed=42
)

# Create SCM and run experiment
scm = create_easy_scm()
results = run_progressive_learning_demo_with_scm(scm, config)

# Extract trajectory metrics; the plotting and analysis cells below all read from this dict
trajectory_metrics = extract_trajectory_metrics_from_demo(results)

print("Experiment completed!")
print(f"Target: {trajectory_metrics['target_variable']}")
print(f"True parents: {trajectory_metrics['true_parents']}")
print(f"Converged: {trajectory_metrics['converged']}")

# Guard with len() rather than truthiness: if the series is a numpy array,
# `if array:` would raise for more than one element.
run_f1_scores = trajectory_metrics['f1_scores']
if len(run_f1_scores) > 0:
    print(f"Final F1: {run_f1_scores[-1]:.3f}")
else:
    print("Final F1: n/a (no F1 scores recorded)")

## 2. Convergence Visualization

In [None]:
# Plot convergence toward the true parent set, with the F1 and uncertainty overlays switched on
convergence_options = dict(
    title="Convergence to True Parent Set",
    show_f1=True,
    show_uncertainty=True,
)
fig = plot_convergence(trajectory_metrics, **convergence_options)
plt.show()

## 3. Target Optimization

In [None]:
# Plot the target-variable trajectory recorded in `trajectory_metrics`
target_plot_title = "Target Variable Optimization"
fig = plot_target_optimization(trajectory_metrics, title=target_plot_title)
plt.show()

## 4. Structure Learning Dashboard

In [None]:
# Render the combined structure-learning dashboard for this run
dashboard_title = "Structure Learning Dashboard"
fig = plot_structure_learning_dashboard(trajectory_metrics, title=dashboard_title)
plt.show()

## 5. Compare Intervention Strategies (Random vs Oracle)

In [None]:
# Run experiments with different intervention strategies
from examples.complete_workflow_demo import run_progressive_learning_demo_with_oracle_interventions


def to_comparison_entry(metrics, n_runs=1):
    """Map one run's trajectory metrics onto the per-method dict passed to `plot_baseline_comparison`.

    Args:
        metrics: dict returned by `extract_trajectory_metrics_from_demo`.
        n_runs: number of runs the series summarize. With a single run the
            `*_mean` series are simply that run's own values.

    Returns:
        dict with 'steps', 'shd_mean', 'f1_mean', 'target_mean' and 'n_runs'.
    """
    return {
        'steps': metrics['steps'],
        'shd_mean': metrics['shd_values'],
        'f1_mean': metrics['f1_scores'],
        'target_mean': metrics['target_values'],
        'n_runs': n_runs,
    }


# Random interventions
# NOTE(review): this repeats the section-1 call with the same `scm` and `config`;
# if the run is deterministic under `random_seed`, `trajectory_metrics` could be
# reused instead of re-running — TODO confirm.
random_results = run_progressive_learning_demo_with_scm(scm, config)
random_metrics = extract_trajectory_metrics_from_demo(random_results)

# Oracle interventions
oracle_results = run_progressive_learning_demo_with_oracle_interventions(scm, config)
oracle_metrics = extract_trajectory_metrics_from_demo(oracle_results)

# Prepare data for comparison (one entry per method, built by the same helper)
results_by_method = {
    "Random Interventions": to_comparison_entry(random_metrics),
    "Oracle Interventions": to_comparison_entry(oracle_metrics),
}

# Create comparison plot
fig = plot_baseline_comparison(
    results_by_method,
    title="Random vs Oracle Intervention Comparison"
)
plt.show()

## 6. Analyze Learning Progress

In [None]:
# Analyze key points in the learning trajectory
import numpy as np

# NOTE(review): both reports below assume `steps` is aligned index-for-index with
# `f1_scores` — presumably guaranteed by extract_trajectory_metrics_from_demo; confirm.
f1_scores = trajectory_metrics['f1_scores']
steps = trajectory_metrics['steps']

# Report the first step at which F1 reaches each threshold
thresholds = [0.5, 0.7, 0.9]
for threshold in thresholds:
    crossing_idx = next((i for i, f1 in enumerate(f1_scores) if f1 >= threshold), None)
    if crossing_idx is not None:
        print(f"F1 score crossed {threshold} at step {steps[crossing_idx]}")
    else:
        print(f"F1 score never reached {threshold}")

# Analyze convergence rate: average change in true-parent likelihood between
# consecutive recorded points, from the first value to the last
parent_likelihood = trajectory_metrics['true_parent_likelihood']
if len(parent_likelihood) > 1:
    improvement_rate = (parent_likelihood[-1] - parent_likelihood[0]) / (len(parent_likelihood) - 1)
    print(f"\nAverage improvement rate: {improvement_rate:.4f} per step")

# Find fastest learning period: the span of `window_size` points with the largest F1 gain.
# A window needs window_size + 1 points; on ties, max() keeps the earliest window.
window_size = 5
if len(f1_scores) > window_size:
    best_start, best_gain = max(
        ((i, f1_scores[i + window_size] - f1_scores[i])
         for i in range(len(f1_scores) - window_size)),
        key=lambda window: window[1],
    )
    # Label with the recorded step values, as the threshold report does, instead of
    # assuming steps are numbered 1..N (index + 1).
    print(f"\nFastest learning: steps {steps[best_start]} to {steps[best_start + window_size]}")
    print(f"F1 improvement: {best_gain:.3f}")

## 7. Save Plots

In [None]:
# Write every available plot for this run to disk and list the resulting file names
from src.causal_bayes_opt.visualization.plots import save_all_plots

output_dir = "notebook_plots"
plot_inputs = {'trajectory_metrics': trajectory_metrics}
saved_files = save_all_plots(plot_inputs, output_dir=output_dir, prefix="demo")

print(f"Saved {len(saved_files)} plots to {output_dir}/")
for saved_path in saved_files:
    print(f"  - {Path(saved_path).name}")