# DRL vs EA Classification Comparison

This notebook is a scaffold for comparing the deep reinforcement learning (DRL) approach from arXiv:2407.20147
with this repository's coevolutionary (EA) agents on **classification tasks**.

## Prerequisites

1. Run experiments and generate logs using the configs in `comparison/experiments/configs/`
   - `drl_classification.yaml` for DRL experiments
   - `ea_classification.yaml` for EA experiments
2. Place log files under `comparison/logs/`, in the `drl/` and `ea/` subdirectories that Section 3 reads from, following the naming convention:
   - `drl/drl_classif_run_{seed}.jsonl` for DRL results
   - `ea/ea_classif_run_{seed}.jsonl` for EA results
3. Install dependencies: `pip install -r comparison/requirements.txt`

## Alignment Checklist

Before running the comparison, verify that the following settings are aligned across both methods:

### Gate Set
- [ ] Both methods use the same gate set: RX, RY, RZ, CNOT
- [ ] CNOT connectivity matches (nearest-neighbor cyclic)

### Inner-Loop Optimization
- [ ] Same loss function (binary cross-entropy)
- [ ] Same number of epochs per step (15 for make_classification, 25 for make_moons)
- [ ] Same data encoding (arctan embedding)

### Circuit Constraints
- [ ] Same max_depth/max_gates (20 for make_classification, 25 for make_moons)
- [ ] Same number of qubits (4 for make_classification, 2 for make_moons)

### Evaluation Budget
- [ ] Comparable total circuit evaluations (~1200 episodes/evaluations)
- [ ] Same seeds used for reproducibility

### Dataset
- [ ] Same dataset (make_classification or make_moons)
- [ ] Same train/test split (if specified)
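
The checklist above can also be checked mechanically once both configs have been loaded into dictionaries. The sketch below is a minimal illustration and not part of the repository: `find_mismatches` and the key names (`gate_set`, `n_qubits`, `max_depth`, `epochs_per_step`) are hypothetical and should be replaced with whatever keys `drl_classification.yaml` and `ea_classification.yaml` actually use.

```python
def find_mismatches(drl_cfg, ea_cfg, keys):
    """Return {key: (drl_value, ea_value)} for every key whose values differ.

    A key missing from one config is reported with None on that side.
    """
    mismatches = {}
    for key in keys:
        drl_value = drl_cfg.get(key)
        ea_value = ea_cfg.get(key)
        if drl_value != ea_value:
            mismatches[key] = (drl_value, ea_value)
    return mismatches


# Hypothetical settings, loosely based on the make_classification values above
drl_cfg = {"gate_set": ["RX", "RY", "RZ", "CNOT"], "n_qubits": 4,
           "max_depth": 20, "epochs_per_step": 15}
ea_cfg = {"gate_set": ["RX", "RY", "RZ", "CNOT"], "n_qubits": 4,
          "max_depth": 25, "epochs_per_step": 15}

mismatches = find_mismatches(drl_cfg, ea_cfg, keys=list(drl_cfg))
print(mismatches)  # {'max_depth': (20, 25)}
```

An empty result means the listed settings agree; anything else names the setting to fix before the runs are comparable.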

## 1. Setup and Imports

In [None]:
import sys
import json
from pathlib import Path

# Add repository root to path
repo_root = Path().resolve().parent.parent
if str(repo_root) not in sys.path:
    sys.path.insert(0, str(repo_root))

# Import classification metrics module
from comparison.analysis.compute_classif_metrics import (
    load_logs,
    aggregate_classification_metrics,
    save_summary,
)

# Optional: Import visualization libraries
try:
    import matplotlib.pyplot as plt
    import numpy as np
    HAS_PLOTTING = True
except ImportError:
    HAS_PLOTTING = False
    print("matplotlib/numpy not available. Install with: pip install matplotlib numpy")

## 2. Load Paper Metadata

In [None]:
# Load paper metadata
METADATA_PATH = Path("../paper_metadata/quantum_ml_arch_search_2407.20147.json")

if METADATA_PATH.exists():
    with open(METADATA_PATH) as f:
        paper_metadata = json.load(f)
    print(f"Paper: {paper_metadata['paper_title']}")
    print(f"Authors: {', '.join(paper_metadata['authors'])}")
    print(f"arXiv: {paper_metadata['arxiv_id']}")
    print(f"\nDatasets: {[d['name'] for d in paper_metadata['tasks']['datasets']]}")
else:
    print(f"Metadata file not found at {METADATA_PATH}")
    paper_metadata = None

## 3. Load Experiment Logs

In [None]:
# Configure log paths
LOGS_DIR = Path("../logs")

# Load DRL classification logs
drl_log_pattern = str(LOGS_DIR / "drl" / "*classif*.jsonl")
drl_logs = load_logs(drl_log_pattern)
print(f"Loaded {len(drl_logs)} DRL classification log entries")

# Load EA classification logs
ea_log_pattern = str(LOGS_DIR / "ea" / "*classif*.jsonl")
ea_logs = load_logs(ea_log_pattern)
print(f"Loaded {len(ea_logs)} EA classification log entries")
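
Since the checklist asks for the same seeds in both methods, it may help to confirm this right after loading. The sketch below assumes each log entry is a dict with a `seed` field, as the plotting code in this notebook does; `seeds_by_method` and `unmatched_seeds` are hypothetical helpers, and the example entries are made up.

```python
def seeds_by_method(logs_by_method):
    """Map each method name to the set of seeds found in its log entries."""
    return {
        method: {entry["seed"] for entry in logs if "seed" in entry}
        for method, logs in logs_by_method.items()
    }


def unmatched_seeds(logs_by_method):
    """Return the seeds that are missing from at least one method."""
    seed_sets = list(seeds_by_method(logs_by_method).values())
    if not seed_sets:
        return set()
    all_seeds = set().union(*seed_sets)
    common_seeds = set.intersection(*seed_sets)
    return all_seeds - common_seeds


# Made-up entries: seed 1 exists only for DRL, seed 2 only for EA
example_logs = {
    "drl": [{"seed": 0}, {"seed": 1}],
    "ea": [{"seed": 0}, {"seed": 2}],
}
print(sorted(unmatched_seeds(example_logs)))  # [1, 2]
```

With real data the call would be `unmatched_seeds({'drl': drl_logs, 'ea': ea_logs})`; an empty set indicates that every seed was run for both methods.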

## 4. Compute Classification Metrics

In [None]:
# Combine logs for comparison
all_logs = drl_logs + ea_logs

# Compute classification metrics with custom thresholds
thresholds = [0.70, 0.80, 0.90]
metrics = aggregate_classification_metrics(all_logs, thresholds=thresholds)

# Display summary
print(f"Total runs: {metrics['total_runs']}")
print(f"Total log entries: {metrics['total_logs']}")
print(f"Thresholds: {metrics['thresholds_used']}")

print("\n--- By Method ---")
for method, stats in metrics.get('by_method', {}).items():
    print(f"\n{method.upper()}:")
    print(f"  Runs: {stats['n_runs']}")
    if stats['mean_max_val_accuracy'] is not None:
        std = stats.get('std_max_val_accuracy', 0) or 0
        print(f"  Max val accuracy: {stats['mean_max_val_accuracy']:.4f} ± {std:.4f}")
    if stats['mean_final_val_accuracy'] is not None:
        std = stats.get('std_final_val_accuracy', 0) or 0
        print(f"  Final val accuracy: {stats['mean_final_val_accuracy']:.4f} ± {std:.4f}")
    if stats['mean_final_test_accuracy'] is not None:
        std = stats.get('std_final_test_accuracy', 0) or 0
        print(f"  Final test accuracy: {stats['mean_final_test_accuracy']:.4f} ± {std:.4f}")

## 5. Validation/Test Accuracy vs Evaluations

Plot learning curves showing accuracy against cumulative evaluations. The function below currently draws one faint line per seed; the median ± CI overlay is left as a TODO in the code.

In [None]:
def plot_accuracy_vs_evals(logs_by_method, acc_key='best_val_accuracy', 
                           title="Validation Accuracy vs Evaluations"):
    """
    Plot accuracy learning curves for each method with median and confidence interval.
    
    Args:
        logs_by_method: dict mapping method name to list of log entries
        acc_key: key for accuracy value in log entries
        title: Plot title
    """
    if not HAS_PLOTTING:
        print("Plotting not available. Install matplotlib.")
        return None
    
    fig, ax = plt.subplots(figsize=(10, 6))
    colors = {'drl': 'blue', 'ea': 'orange'}
    
    for method, logs in logs_by_method.items():
        # Group (eval_count, accuracy) points by seed
        seeds = {}
        for log in logs:
            seed = log.get('seed', 0)

            # Explicit None checks so a legitimate 0 or 0.0 value is not skipped
            eval_count = log.get('cum_eval_count')
            if eval_count is None:
                eval_count = log.get('eval_id', 0)
            acc = log.get(acc_key)
            if acc is None:
                acc = log.get('best_fidelity', 0)

            seeds.setdefault(seed, []).append((eval_count, acc))

        color = colors.get(method, 'gray')

        # Plot individual runs with light color; label only one line per
        # method so the legend has exactly one entry per method
        for i, points in enumerate(seeds.values()):
            points.sort(key=lambda p: p[0])
            evals, accs = zip(*points)
            ax.plot(evals, accs, color=color, alpha=0.2, linewidth=1,
                    label=method if i == 0 else None)

        # TODO: Compute and plot median with CI
        # This requires interpolating accuracies to common eval points

    ax.set_xlabel('Cumulative Evaluations')
    ax.set_ylabel('Accuracy')
    ax.set_title(title)
    ax.legend()
    ax.grid(True, alpha=0.3)
    ax.set_ylim(0, 1.05)
    
    plt.tight_layout()
    return fig

# Plot (uncomment when logs are available)
# fig = plot_accuracy_vs_evals({'drl': drl_logs, 'ea': ea_logs})
# plt.show()
print("TODO: Uncomment plotting code when logs are available")
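
The median overlay left as a TODO above needs every run resampled onto a shared evaluation grid. The sketch below is one possible approach, not the notebook's implementation: `median_with_band` is a hypothetical helper, it uses linear interpolation via `np.interp`, and it reports a percentile band (25th to 75th by default) as a simple stand-in for a confidence interval. The grid is restricted to the range covered by every run so that no curve is extrapolated.

```python
import numpy as np


def median_with_band(curves, n_points=50, lower=25.0, upper=75.0):
    """Resample per-seed curves onto a common grid and summarize them.

    curves: list of (evals, accs) pairs, one pair per seed.
    Returns (grid, median, lower_band, upper_band) as NumPy arrays.
    """
    # Use only the evaluation range that every run covers
    start = max(min(evals) for evals, _ in curves)
    stop = min(max(evals) for evals, _ in curves)
    grid = np.linspace(start, stop, n_points)

    rows = []
    for evals, accs in curves:
        # np.interp expects increasing x values, so sort each run first
        order = np.argsort(evals)
        x = np.asarray(evals, dtype=float)[order]
        y = np.asarray(accs, dtype=float)[order]
        rows.append(np.interp(grid, x, y))

    stacked = np.vstack(rows)  # shape: (n_seeds, n_points)
    return (
        grid,
        np.median(stacked, axis=0),
        np.percentile(stacked, lower, axis=0),
        np.percentile(stacked, upper, axis=0),
    )


# Made-up curves for three seeds; the second one is deliberately unsorted
curves = [([0, 10], [0.5, 0.7]), ([10, 0], [0.8, 0.6]), ([0, 10], [0.7, 0.9])]
grid, median, lower_band, upper_band = median_with_band(curves, n_points=3)
print(grid, median)
```

Inside `plot_accuracy_vs_evals`, the result could be drawn with `ax.plot(grid, median, color=color)` and `ax.fill_between(grid, lower_band, upper_band, color=color, alpha=0.2)`. If accuracies are logged as best-so-far values, a step-wise (forward-fill) resampling may represent them more faithfully than linear interpolation.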

## 6. ECDF of Final Accuracies

Plot the empirical cumulative distribution function of final validation/test accuracies.

In [None]:
def plot_ecdf(accs_by_method, title="ECDF of Final Accuracies"):
    """
    Plot empirical CDF of final accuracies for each method.
    
    Args:
        accs_by_method: dict mapping method name to list of final accuracies
        title: Plot title
    """
    if not HAS_PLOTTING:
        print("Plotting not available. Install matplotlib.")
        return None
    
    fig, ax = plt.subplots(figsize=(8, 6))
    colors = {'drl': 'blue', 'ea': 'orange'}
    
    for method, accs in accs_by_method.items():
        if not accs:
            continue
        
        sorted_accs = np.sort(accs)
        ecdf = np.arange(1, len(sorted_accs) + 1) / len(sorted_accs)
        
        color = colors.get(method, 'gray')
        ax.step(sorted_accs, ecdf, where='post', 
               color=color, linewidth=2, label=f"{method} (n={len(accs)})")
    
    ax.set_xlabel('Final Accuracy')
    ax.set_ylabel('Cumulative Probability')
    ax.set_title(title)
    ax.legend()
    ax.grid(True, alpha=0.3)
    ax.set_xlim(0.5, 1.05)
    ax.set_ylim(0, 1.05)
    
    # Add threshold lines
    for thresh in [0.70, 0.80, 0.90]:
        ax.axvline(x=thresh, color='red', linestyle='--', alpha=0.3)
    
    plt.tight_layout()
    return fig

# Extract final accuracies from per_run metrics (uncomment when available)
# drl_accs = [r['final_val_accuracy'] for r in metrics['per_run'].values() 
#             if r['method'] == 'drl' and r['final_val_accuracy'] is not None]
# ea_accs = [r['final_val_accuracy'] for r in metrics['per_run'].values() 
#            if r['method'] == 'ea' and r['final_val_accuracy'] is not None]
# fig = plot_ecdf({'drl': drl_accs, 'ea': ea_accs})
# plt.show()
print("TODO: Uncomment plotting code when metrics are available")
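
The threshold lines on the ECDF can be read off numerically as well: the share of runs at or above a threshold is one minus the share of runs strictly below it. The sketch below is a minimal illustration with made-up accuracies; `success_rates` is a hypothetical helper and is not part of `compute_classif_metrics`.

```python
def success_rates(final_accs, thresholds=(0.70, 0.80, 0.90)):
    """Return {threshold: fraction of runs with accuracy >= threshold}.

    Returns None for every threshold when there are no runs.
    """
    n_runs = len(final_accs)
    if n_runs == 0:
        return {t: None for t in thresholds}
    return {
        t: sum(acc >= t for acc in final_accs) / n_runs
        for t in thresholds
    }


# Made-up final accuracies for four runs
example_accs = [0.65, 0.72, 0.85, 0.91]
rates = success_rates(example_accs)
print(rates)  # {0.7: 0.75, 0.8: 0.5, 0.9: 0.25}
```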

## 7. Pareto Plot: Accuracy vs Depth/Gate Count

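Scatter each run's accuracy against its circuit complexity. The Pareto frontier is the set of runs for which no other run is at least as accurate with an equal or shallower circuit and strictly better on one of the two.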

In [None]:
def plot_pareto_accuracy_vs_depth(runs_by_method, title="Pareto: Accuracy vs Circuit Depth"):
    """
    Plot Pareto frontier of accuracy vs circuit depth.
    
    Args:
        runs_by_method: dict mapping method name to list of run dicts with
                       'max_val_accuracy' and 'final_depth'/'min_depth' keys
        title: Plot title
    """
    if not HAS_PLOTTING:
        print("Plotting not available. Install matplotlib.")
        return None
    
    fig, ax = plt.subplots(figsize=(8, 6))
    colors = {'drl': 'blue', 'ea': 'orange'}
    markers = {'drl': 'o', 'ea': 's'}
    
    for method, runs in runs_by_method.items():
        accs = []
        depths = []
        
        for run in runs:
            acc = run.get('max_val_accuracy') or run.get('final_val_accuracy')
            depth = run.get('final_depth') or run.get('min_depth') or run.get('final_gate_count')
            if acc is not None and depth is not None:
                accs.append(acc)
                depths.append(depth)
        
        if accs:
            color = colors.get(method, 'gray')
            marker = markers.get(method, 'o')
            ax.scatter(depths, accs, 
                      c=color, marker=marker, s=100, alpha=0.7,
                      label=f"{method} (n={len(accs)})")
    
    # Add the accuracy threshold line before building the legend,
    # otherwise its label does not appear in the legend
    ax.axhline(y=0.90, color='green', linestyle='--', alpha=0.5, label='90% threshold')
    
    ax.set_xlabel('Circuit Depth / Gate Count')
    ax.set_ylabel('Max Validation Accuracy')
    ax.set_title(title)
    ax.legend()
    ax.grid(True, alpha=0.3)
    ax.set_ylim(0.5, 1.05)
    
    plt.tight_layout()
    return fig

# Plot (uncomment when metrics are available)
# drl_runs = [r for r in metrics['per_run'].values() if r['method'] == 'drl']
# ea_runs = [r for r in metrics['per_run'].values() if r['method'] == 'ea']
# fig = plot_pareto_accuracy_vs_depth({'drl': drl_runs, 'ea': ea_runs})
# plt.show()
print("TODO: Uncomment plotting code when metrics are available")
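
The scatter plot above shows every run, but it does not mark which runs are Pareto-optimal. The sketch below is one way to extract them, assuming lower depth and higher accuracy are both preferred; `pareto_front` is a hypothetical helper and the example points are made up.

```python
def pareto_front(points):
    """Return the non-dominated (depth, accuracy) points, sorted by depth.

    A point is kept only if its accuracy is strictly higher than that of
    every point with equal or smaller depth. Exact duplicates are reported once.
    """
    front = []
    best_acc = float("-inf")
    # Sort by depth ascending; within equal depth, highest accuracy first
    for depth, acc in sorted(points, key=lambda p: (p[0], -p[1])):
        if acc > best_acc:
            front.append((depth, acc))
            best_acc = acc
    return front


# Made-up (depth, accuracy) pairs
example_points = [(5, 0.80), (8, 0.85), (8, 0.90), (12, 0.88), (15, 0.95)]
front = pareto_front(example_points)
print(front)  # [(5, 0.8), (8, 0.9), (15, 0.95)]
```

The resulting points could be overlaid on the scatter with `ax.step(depths, accs, where='post')` after unpacking them with `depths, accs = zip(*front)`.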

## 8. Boxplots of Final Accuracies

Compare distributions of final accuracies across methods.

In [None]:
def plot_boxplots(accs_by_method, title="Final Accuracy Distribution"):
    """
    Plot boxplots of final accuracies for each method.
    
    Args:
        accs_by_method: dict mapping method name to list of final accuracies
        title: Plot title
    """
    if not HAS_PLOTTING:
        print("Plotting not available. Install matplotlib.")
        return None
    
    fig, ax = plt.subplots(figsize=(6, 6))
    
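    methods = list(accs_by_method.keys())
    data = [accs_by_method[m] for m in methods]
    colors = ['blue', 'orange']
    
    # Set tick labels separately: the boxplot `labels` keyword is deprecated
    # in newer Matplotlib releases in favor of `tick_labels`
    bp = ax.boxplot(data, patch_artist=True)
    ax.set_xticklabels(methods)
    for patch, color in zip(bp['boxes'], colors[:len(methods)]):
        patch.set_facecolor(color)
        patch.set_alpha(0.5)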
    
    ax.set_ylabel('Final Accuracy')
    ax.set_title(title)
    ax.grid(True, alpha=0.3, axis='y')
    ax.set_ylim(0.5, 1.05)
    
    # Add threshold lines
    for thresh in [0.70, 0.80, 0.90]:
        ax.axhline(y=thresh, color='red', linestyle='--', alpha=0.3)
    
    plt.tight_layout()
    return fig

# Plot (uncomment when metrics are available)
# drl_accs = [r['final_val_accuracy'] for r in metrics['per_run'].values() 
#             if r['method'] == 'drl' and r['final_val_accuracy'] is not None]
# ea_accs = [r['final_val_accuracy'] for r in metrics['per_run'].values() 
#            if r['method'] == 'ea' and r['final_val_accuracy'] is not None]
# fig = plot_boxplots({'DRL': drl_accs, 'EA': ea_accs})
# plt.show()
print("TODO: Uncomment plotting code when metrics are available")

## 9. Save Results

In [None]:
# Save metrics summary
output_dir = Path("../logs/classif_analysis")
output_dir.mkdir(parents=True, exist_ok=True)

# Uncomment when logs are available:
# json_path, csv_path = save_summary(metrics, output_dir)
# print(f"Saved JSON: {json_path}")
# print(f"Saved CSV: {csv_path}")

# Save plots
# if HAS_PLOTTING:
#     fig = plot_accuracy_vs_evals({'drl': drl_logs, 'ea': ea_logs})
#     fig.savefig(output_dir / 'accuracy_vs_evals.png', dpi=150)
#     plt.close(fig)
#     print(f"Saved: {output_dir / 'accuracy_vs_evals.png'}")

print("TODO: Uncomment save code when analysis is complete")

## Next Steps

1. **Generate logs**: Run experiments using the configs in `comparison/experiments/configs/`
   - For DRL: Configure `drl_classification.yaml` and run your DRL agent
   - For EA: Configure `ea_classification.yaml` and run the repository's EA pipeline

2. **Update paths**: Modify the log paths in this notebook to point to your results

3. **Analyze results**: Uncomment the plotting and analysis code

4. **Compare methods**: Look at the metrics summary to compare DRL vs EA

### Example Commands

```bash
# Run EA classification experiments (from repo root)
# TODO: Replace with actual EA runner command
python run_experiments.py --preset quick --n-qubits 4 --seed 42

# Compute classification metrics from logs
python -m comparison.analysis.compute_classif_metrics \
    --input "comparison/logs/**/*classif*.jsonl" \
    --out comparison/logs/classif_analysis \
    --thresholds 0.70 0.80 0.90
```