# DRL vs EA Quantum Architecture Search: Classification Comparison

This notebook provides scaffolding for comparing the DRL approach from arXiv:2407.20147
with the coevolutionary (EA) agents on **classification tasks**.

## Prerequisites

1. Run classification experiments using configs in `comparison/experiments/configs/`:
   - `drl_classification.yaml` for DRL method
   - `ea_classification.yaml` for EA method
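2. Place the resulting `.jsonl` log files in `comparison/logs/drl_classification/` (DRL) and `comparison/logs/ea_classification/` (EA); the notebook loads every `*.jsonl` file from these two folders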
3. Install dependencies: `pip install -r comparison/requirements.txt`

## Reference

- **Paper**: "Quantum Machine Learning Architecture Search via Deep Reinforcement Learning"
- **arXiv**: [2407.20147](https://arxiv.org/abs/2407.20147)
- **Venue**: IEEE QCE 2024

## 1. Setup and Imports

In [None]:
import sys
import json
from pathlib import Path

# Put the repository root (two levels above the working directory) at the
# front of the import path, unless it is already listed.
repo_root = Path().resolve().parent.parent
_repo_root_entry = str(repo_root)
if _repo_root_entry not in sys.path:
    sys.path.insert(0, _repo_root_entry)

# Import comparison modules
from comparison.analysis.compute_classif_metrics import (
    load_logs,
    validate_classification_logs,
    aggregate_classification_metrics,
    save_classification_summary,
)

# Optional: import visualization libraries. HAS_PLOTTING records whether both
# imports succeeded; the plotting helpers in this notebook check it before
# touching `plt` or `np`.
try:
    import matplotlib.pyplot as plt
    import numpy as np
    HAS_PLOTTING = True
except ImportError:
    HAS_PLOTTING = False
    print("matplotlib/numpy not available. Install with: pip install matplotlib numpy")

# Set plotting style
if HAS_PLOTTING:
    # The seaborn style sheets were renamed to 'seaborn-v0_8-*' in
    # matplotlib 3.6; older releases only ship 'seaborn-whitegrid'. Use
    # whichever name this installation provides instead of raising OSError
    # on the missing one (the default style is kept if neither exists).
    for _style_name in ('seaborn-v0_8-whitegrid', 'seaborn-whitegrid'):
        if _style_name in plt.style.available:
            plt.style.use(_style_name)
            break
    plt.rcParams['figure.figsize'] = (10, 6)
    plt.rcParams['font.size'] = 12

## 2. Load Paper Metadata

Load the extracted hyperparameters and experimental details from the paper.

In [None]:
# Load paper metadata
METADATA_PATH = Path("../paper_metadata/quantum_ml_arch_search_2407.20147.json")

# Read explicitly as UTF-8 so non-ASCII text (e.g. author names) decodes the
# same way on every platform instead of depending on the locale's default
# encoding.
with open(METADATA_PATH, 'r', encoding='utf-8') as f:
    paper_metadata = json.load(f)

# Print a short summary of the fields the comparison relies on. The keys are
# indexed directly, so a metadata file missing one of them fails loudly here.
print(f"Paper: {paper_metadata['paper_title']}")
print(f"arXiv: {paper_metadata['arxiv_id']}")
print(f"Authors: {', '.join(paper_metadata['authors'])}")
print(f"\nDRL Algorithm: {paper_metadata['drl_algorithm']['name']}")
print(f"Gate Set: {paper_metadata['action_space']['gate_set']}")
print(f"Max Gates: {paper_metadata['max_depth_termination']['max_gates']['L_values_used']}")
print("\nDatasets:")
for task in paper_metadata['tasks']:
    print(f"  - {task['name']}: {task['description']}")

## 3. Load Experiment Logs

In [None]:
# Root folder that holds one sub-folder of logs per method
LOGS_DIR = Path("../logs")

# Glob patterns matching every .jsonl file in each method's folder
drl_log_pattern = str(LOGS_DIR.joinpath("drl_classification", "*.jsonl"))
ea_log_pattern = str(LOGS_DIR.joinpath("ea_classification", "*.jsonl"))

# DRL entries first ...
drl_logs = load_logs(drl_log_pattern)
print(f"Loaded {len(drl_logs)} DRL classification log entries")

# ... then the EA entries
ea_logs = load_logs(ea_log_pattern)
print(f"Loaded {len(ea_logs)} EA classification log entries")

## 4. Validate Logs

In [None]:
# Validate each method's entries. Every call yields the entries that passed
# plus a list of (entry index, error) pairs for the ones that did not.
valid_drl, drl_errors = validate_classification_logs(drl_logs)
print(f"DRL: {len(valid_drl)} valid, {len(drl_errors)} errors")

valid_ea, ea_errors = validate_classification_logs(ea_logs)
print(f"EA: {len(valid_ea)} valid, {len(ea_errors)} errors")

# Preview at most three validation problems per method
for _label, _problems in (("DRL", drl_errors), ("EA", ea_errors)):
    if not _problems:
        continue
    print(f"\n{_label} validation errors (first 3):")
    for idx, err in _problems[:3]:
        print(f"  Entry {idx}: {err}")

## 5. Compute Classification Metrics

In [None]:
# Pool the validated entries of both methods so they are aggregated together
all_logs = valid_drl + valid_ea

# Aggregate metrics over the pooled entries
metrics = aggregate_classification_metrics(all_logs)

# Overall counts
print(f"Total runs: {metrics['total_runs']}")
print(f"Total log entries: {metrics['total_logs']}")

_rule = "="*60
print("\n" + _rule)
print("CLASSIFICATION METRICS BY METHOD")
print(_rule)

# One section per method; a statistic is printed only when it is present,
# and a missing standard deviation is shown as 0.
for method, stats in metrics.get('by_method', {}).items():
    print(f"\n--- {method.upper()} ---")
    print(f"  Runs: {stats['n_runs']}")

    best_mean = stats.get('mean_best_test_accuracy')
    if best_mean is not None:
        best_std = stats.get('std_best_test_accuracy') or 0
        print(f"  Best test accuracy: {best_mean:.4f} ± {best_std:.4f}")

    final_mean = stats.get('mean_final_test_accuracy')
    if final_mean is not None:
        final_std = stats.get('std_final_test_accuracy') or 0
        print(f"  Final test accuracy: {final_mean:.4f} ± {final_std:.4f}")

    success_rate = stats.get('success_rate_90')
    if success_rate is not None:
        print(f"  Success rate (≥90%): {success_rate*100:.1f}%")

    gates_mean = stats.get('mean_final_gate_count')
    if gates_mean is not None:
        gates_std = stats.get('std_final_gate_count') or 0
        print(f"  Final gate count: {gates_mean:.1f} ± {gates_std:.1f}")

## 6. Plot: Accuracy vs Evaluations

Compare learning curves showing classification accuracy improvement over training.

In [None]:
def _first_present(entry, keys, default=None):
    """
    Return the first non-None value stored in *entry* under one of *keys*.

    Unlike an ``a or b`` chain, legitimate falsy values such as ``0`` or
    ``0.0`` are returned rather than skipped.

    Args:
        entry: mapping to look the keys up in (a single log entry)
        keys: candidate keys, in priority order
        default: value returned when no key holds a non-None value
    """
    for key in keys:
        value = entry.get(key)
        if value is not None:
            return value
    return default


def plot_accuracy_vs_evals(logs_by_method, accuracy_field='test_accuracy',
                           title="Classification Accuracy vs Evaluations"):
    """
    Plot accuracy learning curves for each method.

    Each seed is drawn as a faint line and the mean over seeds as a bold
    line. The mean at an evaluation index averages only the seeds that
    logged that exact index (no interpolation is performed).

    Args:
        logs_by_method: dict mapping method name to list of log entries
        accuracy_field: which accuracy field to plot; falls back to
            'train_accuracy' and then 'accuracy' when it is absent.
            Entries carrying none of the three are skipped.
        title: Plot title

    Returns:
        The matplotlib figure, or None when plotting is unavailable.
    """
    if not HAS_PLOTTING:
        print("Plotting not available. Install matplotlib.")
        return None

    fig, ax = plt.subplots(figsize=(12, 7))

    colors = {'drl': '#1f77b4', 'ea': '#ff7f0e'}

    for method, logs in logs_by_method.items():
        # Group (evaluation index, accuracy) points by seed
        curves = {}
        for log in logs:
            acc = _first_present(log, (accuracy_field, 'train_accuracy', 'accuracy'))
            if acc is None:
                # No accuracy recorded: plotting it as 0 would distort the mean
                continue
            eval_idx = _first_present(log, ('eval_id', 'episode', 'generation'), default=0)
            curves.setdefault(log.get('seed', 0), []).append((eval_idx, acc))

        if not curves:
            continue

        color = colors.get(method, 'gray')

        # Plot each seed with light color, ordered by evaluation index
        for points in curves.values():
            points.sort()
            evals, accs = zip(*points)
            ax.plot(evals, accs, color=color, alpha=0.2, linewidth=1)

        # Mean curve: collect accuracies per evaluation index in one pass
        accs_by_eval = {}
        for points in curves.values():
            for eval_idx, acc in points:
                accs_by_eval.setdefault(eval_idx, []).append(acc)

        all_evals = sorted(accs_by_eval)
        mean_accs = [np.mean(accs_by_eval[eval_idx]) for eval_idx in all_evals]
        ax.plot(all_evals, mean_accs, color=color, linewidth=2.5,
                label=f"{method.upper()} (n={len(curves)})")

    # Add threshold lines
    for thresh in [0.7, 0.8, 0.9]:
        ax.axhline(y=thresh, color='gray', linestyle='--', alpha=0.3, linewidth=1)
        ax.text(ax.get_xlim()[1] * 0.98, thresh + 0.01, f'{int(thresh*100)}%',
                ha='right', va='bottom', color='gray', fontsize=10)

    ax.set_xlabel('Evaluations (episodes/generations)', fontsize=12)
    ax.set_ylabel('Classification Accuracy', fontsize=12)
    ax.set_title(title, fontsize=14)
    # Only build a legend when a mean curve was drawn; an empty legend
    # triggers a matplotlib warning.
    if ax.get_legend_handles_labels()[0]:
        ax.legend(loc='lower right', fontsize=11)
    ax.set_ylim(0.4, 1.02)
    ax.grid(True, alpha=0.3)

    plt.tight_layout()
    return fig

# Example usage (uncomment when logs are available)
# Expects the validated entry lists `valid_drl` and `valid_ea` to exist.
# fig = plot_accuracy_vs_evals({'drl': valid_drl, 'ea': valid_ea})
# plt.show()
# Placeholder reminder shown until the example above is enabled.
print("TODO: Uncomment plotting code when logs are available")

## 7. Plot: ECDF of Final Accuracies

Empirical cumulative distribution function comparing final accuracies across seeds.

In [None]:
def plot_accuracy_ecdf(metrics, title="ECDF of Final Classification Accuracies"):
    """
    Plot empirical CDF of final accuracies for each method.

    For every run in ``metrics['per_run']`` the final test accuracy is used,
    falling back to the best test accuracy only when the final one is
    missing (a recorded accuracy of 0.0 is kept, not treated as missing).

    Args:
        metrics: Aggregated metrics from aggregate_classification_metrics
        title: Plot title

    Returns:
        The matplotlib figure, or None when plotting is unavailable.
    """
    if not HAS_PLOTTING:
        print("Plotting not available. Install matplotlib.")
        return None

    fig, ax = plt.subplots(figsize=(10, 7))

    colors = {'drl': '#1f77b4', 'ea': '#ff7f0e'}
    per_run = metrics.get('per_run', {})

    # Extract final accuracies from per_run metrics
    for method in ['drl', 'ea']:
        accuracies = []
        for run_data in per_run.values():
            # .get: a run without a method tag is ignored rather than fatal
            if run_data.get('method') != method:
                continue
            acc = run_data.get('final_test_accuracy')
            if acc is None:
                acc = run_data.get('best_test_accuracy')
            if acc is not None:
                accuracies.append(acc)

        if accuracies:
            sorted_accs = np.sort(accuracies)
            # i-th smallest value reaches cumulative probability i/n
            ecdf = np.arange(1, len(sorted_accs) + 1) / len(sorted_accs)

            color = colors.get(method, 'gray')
            ax.step(sorted_accs, ecdf, where='post', color=color, linewidth=2.5,
                    label=f"{method.upper()} (n={len(accuracies)})")

            # Add scatter points
            ax.scatter(sorted_accs, ecdf, color=color, s=50, zorder=5)

    # Add threshold lines
    for thresh in [0.7, 0.8, 0.9, 0.95]:
        ax.axvline(x=thresh, color='red', linestyle='--', alpha=0.3, linewidth=1)

    ax.set_xlabel('Final Classification Accuracy', fontsize=12)
    ax.set_ylabel('Cumulative Probability', fontsize=12)
    ax.set_title(title, fontsize=14)
    # Skip the legend when no method contributed data; an empty legend
    # triggers a matplotlib warning.
    if ax.get_legend_handles_labels()[0]:
        ax.legend(loc='lower right', fontsize=11)
    ax.set_xlim(0.5, 1.02)
    ax.set_ylim(0, 1.05)
    ax.grid(True, alpha=0.3)

    plt.tight_layout()
    return fig

# Example usage (uncomment when metrics are computed)
# Expects `metrics` to hold the result of aggregate_classification_metrics.
# fig = plot_accuracy_ecdf(metrics)
# plt.show()
# Placeholder reminder shown until the example above is enabled.
print("TODO: Uncomment plotting code when metrics are available")

## 8. Plot: Pareto Front (Accuracy vs Circuit Depth)

Trade-off between classification accuracy and circuit complexity, measured here by gate count (the x-axis of the plot).

In [None]:
def _pareto_front(points):
    """
    Return the non-dominated (gate_count, accuracy) pairs, sorted by gate count.

    A pair is kept when its accuracy is strictly higher than that of every
    pair with the same or a smaller gate count.
    """
    front = []
    best_acc = float('-inf')
    # Ascending gate count; among equal counts the most accurate comes first
    for depth, acc in sorted(points, key=lambda p: (p[0], -p[1])):
        if acc > best_acc:
            front.append((depth, acc))
            best_acc = acc
    return front


def plot_pareto_accuracy_vs_depth(metrics, title="Pareto: Accuracy vs Circuit Depth"):
    """
    Plot Pareto frontier of accuracy vs circuit depth/gate count.

    Each run in ``metrics['per_run']`` becomes one scatter point (gate count,
    best test accuracy); per method, the non-dominated points are joined by
    a step line. The gate count is the best model's count, falling back to
    the final count only when the former is missing (a count of 0 is kept).

    Args:
        metrics: Aggregated metrics from aggregate_classification_metrics
        title: Plot title

    Returns:
        The matplotlib figure, or None when plotting is unavailable.
    """
    if not HAS_PLOTTING:
        print("Plotting not available. Install matplotlib.")
        return None

    fig, ax = plt.subplots(figsize=(10, 7))

    colors = {'drl': '#1f77b4', 'ea': '#ff7f0e'}
    markers = {'drl': 'o', 'ea': 's'}
    per_run = metrics.get('per_run', {})

    for method in ['drl', 'ea']:
        points = []  # (gate_count, accuracy) per run of this method

        for run_data in per_run.values():
            # .get: a run without a method tag is ignored rather than fatal
            if run_data.get('method') != method:
                continue
            acc = run_data.get('best_test_accuracy')
            depth = run_data.get('best_model_gate_count')
            if depth is None:
                depth = run_data.get('final_gate_count')

            if acc is not None and depth is not None:
                points.append((depth, acc))

        if points:
            depths, accuracies = zip(*points)
            color = colors.get(method, 'gray')
            marker = markers.get(method, 'o')
            ax.scatter(depths, accuracies, c=color, marker=marker, s=120, alpha=0.7,
                       label=f"{method.upper()} (n={len(accuracies)})", edgecolors='white', linewidth=1)

            # Frontier: best accuracy reachable with at most a given gate count.
            # Left unlabeled so the legend keeps one entry per method.
            front_depths, front_accs = zip(*_pareto_front(points))
            ax.plot(front_depths, front_accs, color=color, linewidth=1.5, alpha=0.8,
                    drawstyle='steps-post')

    # Add accuracy threshold
    ax.axhline(y=0.9, color='green', linestyle='--', alpha=0.5, linewidth=1.5, label='90% threshold')

    ax.set_xlabel('Gate Count', fontsize=12)
    ax.set_ylabel('Best Classification Accuracy', fontsize=12)
    ax.set_title(title, fontsize=14)
    ax.legend(loc='lower right', fontsize=11)
    ax.set_ylim(0.5, 1.02)
    ax.grid(True, alpha=0.3)

    plt.tight_layout()
    return fig

# Example usage
# Expects `metrics` to hold the result of aggregate_classification_metrics.
# fig = plot_pareto_accuracy_vs_depth(metrics)
# plt.show()
# Placeholder reminder shown until the example above is enabled.
print("TODO: Uncomment plotting code when metrics are available")

## 9. Plot: Box/Bar Plot of Final Accuracies

Statistical comparison of final accuracies across methods and seeds.

In [None]:
def plot_accuracy_boxplot(metrics, title="Final Classification Accuracies by Method"):
    """
    Create box plot comparing final accuracies across methods.

    Draws two panels side by side: a box plot of the per-run accuracies and
    a bar plot of their mean with sample standard deviation error bars. For
    each run the final test accuracy is used, falling back to the best test
    accuracy only when the final one is missing (0.0 is a valid value).

    Args:
        metrics: Aggregated metrics from aggregate_classification_metrics
        title: Plot title

    Returns:
        The matplotlib figure, or None when plotting is unavailable.
    """
    if not HAS_PLOTTING:
        print("Plotting not available. Install matplotlib.")
        return None

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))

    colors = {'drl': '#1f77b4', 'ea': '#ff7f0e'}

    # Extract accuracies by method
    data_by_method = {'drl': [], 'ea': []}

    for run_data in metrics.get('per_run', {}).values():
        # .get: a run without a method tag is ignored rather than fatal
        method = run_data.get('method')
        acc = run_data.get('final_test_accuracy')
        if acc is None:
            acc = run_data.get('best_test_accuracy')
        if acc is not None and method in data_by_method:
            data_by_method[method].append(acc)

    # Box plot. Tick labels are set separately because the `labels` keyword
    # of boxplot is deprecated in recent matplotlib (renamed `tick_labels`),
    # while set_xticklabels works on old and new releases alike.
    bp = ax1.boxplot([data_by_method['drl'], data_by_method['ea']],
                     patch_artist=True)
    ax1.set_xticklabels(['DRL', 'EA'])

    for patch, color in zip(bp['boxes'], [colors['drl'], colors['ea']]):
        patch.set_facecolor(color)
        patch.set_alpha(0.7)

    ax1.set_ylabel('Final Classification Accuracy', fontsize=12)
    ax1.set_title('Box Plot', fontsize=14)
    ax1.set_ylim(0.5, 1.02)
    ax1.grid(True, alpha=0.3, axis='y')
    ax1.axhline(y=0.9, color='green', linestyle='--', alpha=0.5)

    # Bar plot with error bars
    methods = ['DRL', 'EA']
    means = [
        np.mean(data_by_method['drl']) if data_by_method['drl'] else 0,
        np.mean(data_by_method['ea']) if data_by_method['ea'] else 0
    ]
    # Sample standard deviation (ddof=1) needs at least two runs
    stds = [
        np.std(data_by_method['drl'], ddof=1) if len(data_by_method['drl']) > 1 else 0,
        np.std(data_by_method['ea'], ddof=1) if len(data_by_method['ea']) > 1 else 0
    ]

    x = np.arange(len(methods))
    bars = ax2.bar(x, means, yerr=stds, capsize=5, color=[colors['drl'], colors['ea']], alpha=0.7)

    # Add value labels just above the top of each error bar
    for bar, mean, std in zip(bars, means, stds):
        ax2.text(bar.get_x() + bar.get_width()/2, bar.get_height() + std + 0.01,
                 f'{mean:.3f}±{std:.3f}', ha='center', va='bottom', fontsize=11)

    ax2.set_xticks(x)
    ax2.set_xticklabels(methods)
    ax2.set_ylabel('Final Classification Accuracy', fontsize=12)
    ax2.set_title('Bar Plot (mean ± std)', fontsize=14)
    ax2.set_ylim(0.5, 1.1)
    ax2.grid(True, alpha=0.3, axis='y')
    ax2.axhline(y=0.9, color='green', linestyle='--', alpha=0.5)

    fig.suptitle(title, fontsize=14, y=1.02)
    plt.tight_layout()
    return fig

# Example usage
# Expects `metrics` to hold the result of aggregate_classification_metrics.
# fig = plot_accuracy_boxplot(metrics)
# plt.show()
# Placeholder reminder shown until the example above is enabled.
print("TODO: Uncomment plotting code when metrics are available")

## 10. Comparison Checklist

Verify that DRL and EA experiments use matched settings for fair comparison.

In [None]:
# Print the settings that must agree between the DRL and EA experiments
_rule = "=" * 60
print(_rule)
print("FAIR COMPARISON CHECKLIST")
print(_rule)

# One row per setting: (setting name, expected value, reason it matters)
checklist = [
    ("Gate set", "RX, RY, RZ, CNOT", "Match paper's allowed gates"),
    ("Max circuit depth/gates", "L=20", "Match paper's L parameter"),
    ("Evaluation budget", "~800 episodes/evaluations", "Match paper's episode count"),
    ("Inner-loop optimization", "15 epochs max, Adam", "Match paper's VQC training"),
    ("Dataset", "make_classification or make_moons", "Match paper's datasets"),
    ("Train/test split", "80/20 (assumed)", "Use consistent split"),
    ("Data encoding", "arctan embedding", "Match paper's encoding"),
    ("Number of seeds", "≥5", "Statistical validity"),
    ("Gate penalty coefficient", "0.01", "Match paper's complexity penalty"),
    ("Number of qubits", "4 (matches features)", "Match data dimensionality"),
]

# Each row is rendered as an unchecked box followed by its value and note
for _setting, _expected, _reason in checklist:
    print(f"\n☐ {_setting}")
    print(f"   Value: {_expected}")
    print(f"   Note: {_reason}")

print("\n" + _rule)
print("Verify configs in comparison/experiments/configs/ match these settings")
print(_rule)

## 11. Save Results

In [None]:
# Save metrics summary
# Output folder for the summary files and figures. It is created right away
# (including missing parents), even while the save calls below are disabled.
output_dir = Path("../logs/classification_analysis")
output_dir.mkdir(parents=True, exist_ok=True)

# Uncomment when logs are available:
# (the result of save_classification_summary is unpacked as JSON path, CSV path)
# json_path, csv_path = save_classification_summary(metrics, output_dir)
# print(f"Saved JSON: {json_path}")
# print(f"Saved CSV: {csv_path}")

# Save plots
# Each figure is written as a PNG into output_dir and then closed.
# if HAS_PLOTTING:
#     fig = plot_accuracy_vs_evals({'drl': valid_drl, 'ea': valid_ea})
#     fig.savefig(output_dir / 'accuracy_vs_evals.png', dpi=150, bbox_inches='tight')
#     plt.close(fig)
#     
#     fig = plot_accuracy_ecdf(metrics)
#     fig.savefig(output_dir / 'accuracy_ecdf.png', dpi=150, bbox_inches='tight')
#     plt.close(fig)
#     
#     fig = plot_pareto_accuracy_vs_depth(metrics)
#     fig.savefig(output_dir / 'pareto_accuracy_depth.png', dpi=150, bbox_inches='tight')
#     plt.close(fig)
#     
#     fig = plot_accuracy_boxplot(metrics)
#     fig.savefig(output_dir / 'accuracy_boxplot.png', dpi=150, bbox_inches='tight')
#     plt.close(fig)
#     
#     print(f"Saved plots to: {output_dir}")

# Placeholder reminder shown until the save code above is enabled.
print("TODO: Uncomment save code when analysis is complete")

## 12. Key Insights from Paper Metadata

In [None]:
# Summarise the reproduction-relevant parts of the paper metadata. Every
# lookup goes through .get, so absent sections print as None / empty lists.
_rule = "="*60
print(_rule)
print("KEY INSIGHTS FROM PAPER (arXiv:2407.20147)")
print(_rule)

print("\n1. REWARD FUNCTION:")
reward = paper_metadata.get('reward_function', {})
for _label, _key in (
    ("Type", 'type'),
    ("Gate penalty", 'gate_penalty_coefficient'),
    ("Reward clipping", 'reward_clip_range'),
):
    print(f"   {_label}: {reward.get(_key)}")

print("\n2. ADAPTIVE SEARCH:")
adaptive = paper_metadata.get('adaptive_search', {})
for _label, _key in (
    ("ytarget increment", 'ytarget_increment'),
    ("Epsilon decay on increase", 'epsilon_decay_on_increase'),
):
    print(f"   {_label}: {adaptive.get(_key)}")

print("\n3. OMITTED HYPERPARAMETERS (need assumptions):")
_notes = paper_metadata.get('notes_and_omissions', {})
omissions = _notes.get('omitted_hyperparameters', [])
# Only the first five omissions are listed
for _entry in omissions[:5]:
    print(f"   - {_entry}")

print("\n4. KEY INSIGHTS FOR REPRODUCTION:")
insights = _notes.get('key_insights', [])
for _entry in insights:
    print(f"   - {_entry}")

## Next Steps

1. **Implement classification task** in EA pipeline (adapt existing VQE environment)
2. **Implement DRL agent** following the DDQN specification from the paper
3. **Run experiments** with matched settings using the YAML configs
4. **Generate logs** following the schema in `comparison/logs/schema.json`
5. **Analyze results** by uncommenting the code in this notebook

### Example Commands

```bash
# Run classification metrics analysis
python -m comparison.analysis.compute_classif_metrics \
    --input "comparison/logs/**/*.jsonl" \
    --out comparison/logs/classification_analysis

# Run tests
pytest comparison/tests/ -v
```