# Ax Bayesian Optimization with Prefect Human-in-the-Loop Tutorial

This notebook demonstrates how to integrate Ax (Adaptive Experimentation Platform) with Prefect for human-in-the-loop Bayesian optimization workflows.

## Overview

- **Objective**: Optimize the Hartmann6 benchmark function
- **Method**: Bayesian optimization with Ax
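- **Workflow**: Prefect for orchestration and human-in-the-loop (in this notebook the pause step is simulated with an automatic decision)
- **Persistence**: MongoDB for checkpointing and restart capability (in this notebook checkpoints are kept in memory)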

## Setup

First, install the required packages:

In [None]:
# Uncomment and run to install dependencies
# !pip install ax-platform prefect pymongo numpy matplotlib plotly

## Imports and Setup

In [None]:
import asyncio
import json
import logging
import os
from datetime import datetime
from typing import Dict, List, Optional, Tuple, Any

import numpy as np
import matplotlib.pyplot as plt

# Conditional imports with fallbacks
try:
    from ax.service.ax_client import AxClient
    from ax.utils.measurement.synthetic_functions import hartmann6
    AX_AVAILABLE = True
    print("✅ Ax available")
except ImportError:
    AX_AVAILABLE = False
    print("⚠️  Ax not available - using fallback implementation")

try:
    from prefect import flow, task, get_run_logger
    from prefect.input import RunInput
    PREFECT_AVAILABLE = True
    print("✅ Prefect available")
except ImportError:
    PREFECT_AVAILABLE = False
    print("⚠️  Prefect not available - the notebook runs without Prefect orchestration")

try:
    from pymongo import MongoClient
    MONGO_AVAILABLE = True
    print("✅ PyMongo available")
except ImportError:
    MONGO_AVAILABLE = False
    print("⚠️  PyMongo not available - MongoDB persistence is disabled")

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

## Configuration

In [None]:
# Optimization configuration
EXPERIMENT_NAME = "notebook_hartmann6_demo"
N_ITERATIONS = 10
MIN_CONFIDENCE_THRESHOLD = 0.7

# MongoDB configuration (optional; the in-memory checkpoint handler in this notebook does not use it)
MONGODB_CONNECTION_STRING = os.getenv("MONGODB_CONNECTION_STRING", "mongodb://localhost:27017/")
MONGODB_DATABASE = "ax_tutorial_notebook"
MONGODB_COLLECTION = "optimization_results"

print(f"Experiment: {EXPERIMENT_NAME}")
print(f"Iterations: {N_ITERATIONS}")
print(f"Confidence threshold: {MIN_CONFIDENCE_THRESHOLD}")

## Objective Function

We'll optimize the Hartmann6 function, a standard benchmark in Bayesian optimization:

In [None]:
def hartmann6_objective(parameters: Dict[str, float]) -> float:
    """
    Hartmann6 synthetic function for optimization.
    6-dimensional function with known global minimum ≈ -3.32.
    """
    x = np.array([parameters[f"x{i}"] for i in range(1, 7)])
    if AX_AVAILABLE:
        # Use the actual Ax implementation
        return float(hartmann6(x))
    # Stand-in for demonstration only. This is NOT Hartmann6: it is a smooth
    # bowl whose minimum is -3.0 at the origin, so the gap to the Hartmann6
    # minimum reported later is not meaningful in this mode.
    return float(np.sum(x**2) - 3.0 * np.exp(-np.sum(x**2)))

# Test the objective function
test_params = {f"x{i}": 0.2 for i in range(1, 7)}
test_result = hartmann6_objective(test_params)
print(f"Test evaluation: f({test_params}) = {test_result:.4f}")
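
If Ax is not installed but a faithful objective is still wanted, Hartmann6 can be written directly in NumPy. The sketch below uses the standard published constants for the function; the name `hartmann6_numpy` is hypothetical and is not part of Ax or of this notebook.

```python
import numpy as np

# Standard Hartmann6 constants: weights, curvature matrix, and centers.
ALPHA = np.array([1.0, 1.2, 3.0, 3.2])
A = np.array([
    [10.0, 3.0, 17.0, 3.5, 1.7, 8.0],
    [0.05, 10.0, 17.0, 0.1, 8.0, 14.0],
    [3.0, 3.5, 1.7, 10.0, 17.0, 8.0],
    [17.0, 8.0, 0.05, 10.0, 0.1, 14.0],
])
P = 1e-4 * np.array([
    [1312, 1696, 5569, 124, 8283, 5886],
    [2329, 4135, 8307, 3736, 1004, 9991],
    [2348, 1451, 3522, 2883, 3047, 6650],
    [4047, 8828, 8732, 5743, 1091, 381],
])


def hartmann6_numpy(x) -> float:
    """Evaluate Hartmann6 at a point x in [0, 1]^6 (hypothetical helper)."""
    x = np.asarray(x, dtype=float)
    # One exponent per row of A: sum_j A_ij * (x_j - P_ij)^2
    inner = np.sum(A * (x - P) ** 2, axis=1)
    return float(-np.sum(ALPHA * np.exp(-inner)))


# Commonly cited location of the global minimum.
x_star = np.array([0.20169, 0.150011, 0.476874, 0.275332, 0.311652, 0.6573])
print(f"f(x*) = {hartmann6_numpy(x_star):.4f}")
```

Evaluating at `x_star` should give a value close to the known minimum of about -3.32, which is a quick sanity check on the constants.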

## Checkpoint System

Implement an in-memory checkpoint system for saving and retrieving optimization state (checkpoints are lost when the kernel restarts):

In [None]:
class NotebookCheckpointHandler:
    """Simple checkpoint handler for notebook demonstrations."""
    
    def __init__(self):
        self.checkpoints = []
    
    def save_checkpoint(self, experiment_name: str, iteration: int, 
                       best_parameters: Dict, best_objective: float, 
                       metadata: Optional[Dict] = None) -> str:
        """Save a checkpoint."""
        checkpoint = {
            "experiment_name": experiment_name,
            "iteration": iteration,
            "best_parameters": best_parameters,
            "best_objective": best_objective,
            "metadata": metadata or {},
            "timestamp": datetime.utcnow().isoformat()
        }
        self.checkpoints.append(checkpoint)
        return f"checkpoint_{len(self.checkpoints)}"
    
    def get_latest_checkpoint(self, experiment_name: str) -> Optional[Dict]:
        """Get the latest checkpoint for an experiment."""
        matching = [cp for cp in self.checkpoints if cp["experiment_name"] == experiment_name]
        return matching[-1] if matching else None
    
    def list_checkpoints(self) -> List[Dict]:
        """List all checkpoints."""
        return self.checkpoints.copy()

# Initialize checkpoint handler
checkpoint_handler = NotebookCheckpointHandler()
print("✅ Checkpoint handler initialized")
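
The handler above keeps checkpoints only in memory. A minimal sketch of a file-backed variant, assuming an append-only JSON Lines file is acceptable, might look like this. The class name `JsonlCheckpointStore` and the file layout are assumptions for illustration, not part of the notebook or of any library.

```python
import json
from datetime import datetime, timezone
from pathlib import Path
from typing import Dict, Optional


class JsonlCheckpointStore:
    """Append-only checkpoint log, one JSON record per line (hypothetical)."""

    def __init__(self, path: str):
        self.path = Path(path)

    def save_checkpoint(self, experiment_name: str, iteration: int,
                        best_parameters: Dict, best_objective: float) -> None:
        record = {
            "experiment_name": experiment_name,
            "iteration": iteration,
            "best_parameters": best_parameters,
            "best_objective": best_objective,
            "timestamp": datetime.now(timezone.utc).isoformat(),
        }
        # Appending keeps earlier checkpoints intact if a write is interrupted.
        with self.path.open("a", encoding="utf-8") as f:
            f.write(json.dumps(record) + "\n")

    def get_latest_checkpoint(self, experiment_name: str) -> Optional[Dict]:
        if not self.path.exists():
            return None
        latest = None
        with self.path.open(encoding="utf-8") as f:
            for line in f:
                record = json.loads(line)
                if record["experiment_name"] == experiment_name:
                    latest = record  # later lines overwrite earlier ones
        return latest
```

Because the method names match `NotebookCheckpointHandler`, such a store could be swapped in with few changes; a MongoDB-backed handler would follow the same shape with inserts and queries in place of file reads and writes.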

## Mock Ax Client

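Create a simplified stand-in for Ax, used when Ax is not installed. The mock does not fit a surrogate model; it perturbs the best point found so far with Gaussian noise, which is a local random search rather than Bayesian optimization: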

In [None]:
class MockAxClient:
    """Simplified Ax client for demonstration purposes."""
    
    def __init__(self, experiment_name: str):
        self.experiment_name = experiment_name
        self.parameter_space = {f"x{i}": {"bounds": [0.0, 1.0]} for i in range(1, 7)}
        self.trials = []
        self.best_parameters = None
        self.best_objective = float('inf')
        
    def get_next_trial(self) -> Tuple[Dict[str, float], int]:
        """Generate next trial parameters."""
        trial_index = len(self.trials)
        
        if trial_index == 0:
            # Start with a random point
            parameters = {f"x{i}": np.random.uniform(0, 1) for i in range(1, 7)}
        else:
            # Simple acquisition: explore around best point with some randomness
            if self.best_parameters:
                parameters = {}
                for key, value in self.best_parameters.items():
                    # Add Gaussian noise around best point
                    noise = np.random.normal(0, 0.1)
                    new_value = np.clip(value + noise, 0.0, 1.0)
                    parameters[key] = new_value
            else:
                parameters = {f"x{i}": np.random.uniform(0, 1) for i in range(1, 7)}
        
        return parameters, trial_index
    
    def complete_trial(self, trial_index: int, parameters: Dict[str, float], 
                      objective_value: float):
        """Complete a trial with results."""
        trial = {
            "trial_index": trial_index,
            "parameters": parameters,
            "objective_value": objective_value
        }
        self.trials.append(trial)
        
        # Update best if this is better
        if objective_value < self.best_objective:
            self.best_objective = objective_value
            self.best_parameters = parameters.copy()
    
    def get_best_parameters(self) -> Tuple[Dict[str, float], float]:
        """Get current best parameters and objective."""
        if self.best_parameters is None:
            return {f"x{i}": 0.5 for i in range(1, 7)}, float('inf')
        return self.best_parameters.copy(), self.best_objective

# Initialize optimization client
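if AX_AVAILABLE:
    # Recent Ax releases take `objectives=` in place of the older
    # `objective_name=` / `minimize=` keyword arguments.
    from ax.service.utils.instantiation import ObjectiveProperties

    ax_client = AxClient()
    ax_client.create_experiment(
        name=EXPERIMENT_NAME,
        parameters=[
            {"name": f"x{i}", "type": "range", "bounds": [0.0, 1.0]}
            for i in range(1, 7)
        ],
        objectives={"hartmann6": ObjectiveProperties(minimize=True)},
    )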
    print("✅ Real Ax client initialized")
else:
    ax_client = MockAxClient(EXPERIMENT_NAME)
    print("✅ Mock Ax client initialized")

## Human-in-the-Loop Interface

In a notebook environment, we simulate the human decision with an automatic stand-in instead of using Prefect's pause functionality:

In [None]:
def get_human_input(iteration: int, confidence: float, 
                   best_parameters: Dict, best_objective: float,
                   current_parameters: Dict, current_objective: float) -> Dict:
    """Get human input for decision making."""
    
    print("\n" + "="*60)
    print(f"🧪 OPTIMIZATION UPDATE - Iteration {iteration}")
    print("="*60)
    print(f"📊 CURRENT BEST:")
    print(f"   • Objective Value: {best_objective:.4f}")
    print(f"   • Parameters: {best_parameters}")
    print()
    print(f"🎯 LATEST TRIAL:")
    print(f"   • Parameters: {current_parameters}")
    print(f"   • Objective: {current_objective:.4f}")
    print()
    print(f"🤖 ALGORITHM CONFIDENCE: {confidence:.1%}")
    print("="*60)
    
    # In a real Prefect workflow, this would be handled by pause_flow_run
    # For notebook demo, we'll auto-continue with some randomness
    import random
    
    if confidence > 0.8 or random.random() > 0.3:  # Usually continue
        decision = {
            "continue": True,
            "comments": "Auto-continuing for demo",
            "confidence_override": None
        }
        print("🤖 Auto-decision: CONTINUE (demo mode)")
    else:
        decision = {
            "continue": True,  # For demo, always continue
            "comments": "Low confidence but continuing for demo",
            "confidence_override": confidence + 0.1
        }
        print("🤖 Auto-decision: CONTINUE with confidence boost (demo mode)")
    
    print("\n" + "="*60 + "\n")
    return decision

print("✅ Human-in-the-loop interface ready")
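
The auto-decision above stands in for a real person. When the notebook is run interactively and an actual prompt is wanted, a minimal sketch could look like the following. The function name `prompt_human_decision` and the injectable `ask` parameter are assumptions for illustration; the returned keys mirror those of `get_human_input`.

```python
from typing import Callable, Dict


def prompt_human_decision(iteration: int, confidence: float,
                          ask: Callable[[str], str] = input) -> Dict:
    """Ask the operator whether to continue (hypothetical helper).

    `ask` defaults to the built-in input() and can be replaced in tests.
    """
    answer = ask(
        f"Iteration {iteration} (confidence {confidence:.0%}). Continue? [Y/n] "
    )
    # Anything other than an explicit "n" or "no" counts as "continue".
    keep_going = answer.strip().lower() not in {"n", "no"}
    comments = ask("Comments (optional): ").strip()
    return {
        "continue": keep_going,
        "comments": comments,
        "confidence_override": None,
    }
```

Passing `ask` as a parameter keeps the function testable without a keyboard and leaves room to replace the prompt with a Prefect pause step later.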

## Main Optimization Loop

Now let's run the complete optimization with human-in-the-loop integration:

In [None]:
def run_optimization_loop():
    """Run the main optimization loop."""
    
    results = {
        "experiment_name": EXPERIMENT_NAME,
        "iterations": [],
        "best_parameters": None,
        "best_objective": float('inf'),
        "human_decisions": []
    }
    
    print(f"🚀 Starting optimization of {EXPERIMENT_NAME}")
    print(f"📊 Target: Minimize Hartmann6 function (known minimum ≈ -3.32)")
    print(f"🎯 Iterations: {N_ITERATIONS}")
    print(f"🤖 Confidence threshold: {MIN_CONFIDENCE_THRESHOLD}")
    print("\n" + "="*80 + "\n")
    
    for iteration in range(N_ITERATIONS):
        try:
            # Get next trial parameters (same call for the real and mock clients)
            parameters, trial_index = ax_client.get_next_trial()
            
            print(f"🔬 Iteration {iteration + 1}/{N_ITERATIONS}")
            print(f"   Trial {trial_index}: {parameters}")
            
            # Evaluate objective function
            objective_value = hartmann6_objective(parameters)
            print(f"   Result: {objective_value:.4f}")
            
            # Complete the trial
            if AX_AVAILABLE:
                ax_client.complete_trial(trial_index=trial_index, raw_data=objective_value)
            else:
                ax_client.complete_trial(trial_index, parameters, objective_value)
            
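            # Get current best. The real AxClient returns
            # (parameters, (means, covariances)) or None rather than
            # (parameters, objective), so the two clients are unpacked differently.
            if AX_AVAILABLE:
                best = ax_client.get_best_parameters()
                best_parameters, prediction = best if best else (parameters, None)
                means = prediction[0] if prediction else {}
                best_objective = float(means.get("hartmann6", objective_value))
            else:
                best_parameters, best_objective = ax_client.get_best_parameters()
            
            # Placeholder confidence: grows linearly with the iteration count
            # and is capped at 0.95. It is not derived from the model.
            confidence = min(0.95, (iteration + 1) * 0.08 + 0.1)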
            
            # Store iteration results
            iteration_result = {
                "iteration": iteration + 1,
                "trial_index": trial_index,
                "parameters": parameters,
                "objective_value": objective_value,
                "best_parameters": best_parameters,
                "best_objective": best_objective,
                "confidence": confidence
            }
            results["iterations"].append(iteration_result)
            results["best_parameters"] = best_parameters
            results["best_objective"] = best_objective
            
            # Save checkpoint
            checkpoint_id = checkpoint_handler.save_checkpoint(
                EXPERIMENT_NAME, iteration + 1, best_parameters, best_objective,
                {"confidence": confidence, "trial_objective": objective_value}
            )
            
            print(f"   💾 Checkpoint saved: {checkpoint_id}")
            print(f"   🏆 Current best: {best_objective:.4f}")
            
            # Human-in-the-loop decision point
            if confidence < MIN_CONFIDENCE_THRESHOLD or (iteration + 1) % 3 == 0:
                print(f"   🤔 Requesting human input (confidence: {confidence:.1%})")
                
                human_decision = get_human_input(
                    iteration + 1, confidence, best_parameters, best_objective,
                    parameters, objective_value
                )
                
                results["human_decisions"].append({
                    "iteration": iteration + 1,
                    "confidence": confidence,
                    "decision": human_decision
                })
                
                if not human_decision["continue"]:
                    print("🛑 Human requested to stop optimization")
                    break
                    
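                # The override only changes the value printed here; the
                # confidence recorded for this iteration was stored above.
                if human_decision["confidence_override"] is not None:
                    confidence = human_decision["confidence_override"]
                    print(f"🎛️  Confidence overridden to {confidence:.1%}")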
            
            print(f"   ✅ Iteration {iteration + 1} completed\n")
            
        except Exception as e:
            print(f"❌ Error in iteration {iteration + 1}: {e}")
            # In a real implementation, we'd pause for human intervention
            print("   Continuing despite error for demo purposes...\n")
    
    return results

# Run the optimization
optimization_results = run_optimization_loop()
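
The decision rule in the loop can be checked by hand. The sketch below reproduces the confidence heuristic and the trigger condition from the loop for the default settings (10 iterations, threshold 0.7) and lists the iterations at which human input is requested. The helper names are hypothetical.

```python
def confidence_at(iteration: int) -> float:
    """Confidence heuristic from the loop; `iteration` is 1-based."""
    return min(0.95, iteration * 0.08 + 0.1)


def needs_human_input(iteration: int, threshold: float = 0.7) -> bool:
    """True when confidence is below the threshold or on every third iteration."""
    return confidence_at(iteration) < threshold or iteration % 3 == 0


review_points = [i for i in range(1, 11) if needs_human_input(i)]
print(review_points)  # → [1, 2, 3, 4, 5, 6, 7, 9]
```

With these defaults the confidence first exceeds the threshold at iteration 8 (0.74), so the rule asks for input on most iterations. Raising the slope or lowering the threshold would make interventions rarer.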

## Results Analysis and Visualization

In [None]:
# Analyze results
print("\n" + "="*80)
print("📊 OPTIMIZATION RESULTS SUMMARY")
print("="*80)

print(f"🎯 Experiment: {optimization_results['experiment_name']}")
print(f"🔢 Total iterations: {len(optimization_results['iterations'])}")
print(f"🏆 Best objective found: {optimization_results['best_objective']:.6f}")
print(f"📍 Best parameters: {optimization_results['best_parameters']}")
print(f"🤝 Human interventions: {len(optimization_results['human_decisions'])}")

# Known global minimum for Hartmann6 is approximately -3.32237
known_minimum = -3.32237
gap = optimization_results['best_objective'] - known_minimum
print(f"📏 Gap to known global minimum: {gap:.6f}")

print("\n📈 ITERATION HISTORY:")
for iteration in optimization_results['iterations']:
    marker = "🎯" if iteration['objective_value'] == optimization_results['best_objective'] else "  "
    print(f"{marker} Iter {iteration['iteration']:2d}: obj = {iteration['objective_value']:8.4f}, "
          f"best = {iteration['best_objective']:8.4f}, conf = {iteration['confidence']:.1%}")

print("\n🤝 HUMAN DECISIONS:")
for decision in optimization_results['human_decisions']:
    print(f"   Iter {decision['iteration']:2d}: {decision['decision']['comments']} "
          f"(confidence: {decision['confidence']:.1%})")

In [None]:
# Plot optimization progress

# Extract data for plotting
iterations = [r['iteration'] for r in optimization_results['iterations']]
objectives = [r['objective_value'] for r in optimization_results['iterations']]
best_objectives = [r['best_objective'] for r in optimization_results['iterations']]
confidences = [r['confidence'] for r in optimization_results['iterations']]

# Create subplots
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 10))

# Plot 1: Objective values
ax1.plot(iterations, objectives, 'bo-', label='Trial Objectives', alpha=0.7)
ax1.plot(iterations, best_objectives, 'ro-', label='Best So Far', linewidth=2)
ax1.axhline(y=known_minimum, color='g', linestyle='--', label=f'Known Global Min ({known_minimum:.3f})')
ax1.set_xlabel('Iteration')
ax1.set_ylabel('Objective Value')
ax1.set_title('Bayesian Optimization Progress')
ax1.legend()
ax1.grid(True, alpha=0.3)

# Plot 2: Confidence
ax2.plot(iterations, confidences, 'go-', label='Algorithm Confidence')
ax2.axhline(y=MIN_CONFIDENCE_THRESHOLD, color='r', linestyle='--', 
           label=f'HiTL Threshold ({MIN_CONFIDENCE_THRESHOLD:.1%})')
ax2.set_xlabel('Iteration')
ax2.set_ylabel('Confidence')
ax2.set_title('Algorithm Confidence Over Time')
ax2.legend()
ax2.grid(True, alpha=0.3)
ax2.set_ylim(0, 1)

# Mark human intervention points
hitl_iterations = [d['iteration'] for d in optimization_results['human_decisions']]
for iter_num in hitl_iterations:
    ax1.axvline(x=iter_num, color='purple', linestyle=':', alpha=0.7)
    ax2.axvline(x=iter_num, color='purple', linestyle=':', alpha=0.7)

plt.tight_layout()
plt.show()

print(f"\n📊 Purple dotted lines indicate human-in-the-loop intervention points")

## Checkpoint Analysis

Let's examine the checkpoints that were saved during optimization:

In [None]:
# Analyze checkpoints
checkpoints = checkpoint_handler.list_checkpoints()

print("💾 CHECKPOINT ANALYSIS")
print("="*50)
print(f"Total checkpoints saved: {len(checkpoints)}")

if checkpoints:
    latest = checkpoints[-1]
    print(f"\n📄 Latest checkpoint:")
    print(f"   Experiment: {latest['experiment_name']}")
    print(f"   Iteration: {latest['iteration']}")
    print(f"   Best objective: {latest['best_objective']:.6f}")
    print(f"   Timestamp: {latest['timestamp']}")
    
    print(f"\n📊 Checkpoint history:")
    for i, cp in enumerate(checkpoints):
        print(f"   {i+1:2d}. Iter {cp['iteration']:2d}: {cp['best_objective']:8.4f} "
              f"({cp['timestamp'][:19]})")

print("\n🔄 These checkpoints can be used to resume optimization after interruptions.")
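
Resuming from a checkpoint can be sketched as follows. Each checkpoint stores the 1-based number of the last completed iteration, so the remaining zero-based loop indices start at that number. The function name `remaining_iterations` is an assumption for illustration.

```python
from typing import Dict, Optional


def remaining_iterations(latest_checkpoint: Optional[Dict],
                         n_iterations: int) -> range:
    """Zero-based loop indices still to run after the latest checkpoint."""
    completed = latest_checkpoint["iteration"] if latest_checkpoint else 0
    # Clamp so that a checkpoint beyond the budget yields an empty range.
    return range(min(completed, n_iterations), n_iterations)


# Example: a checkpoint written after iteration 4 of 10 leaves indices 4..9.
print(list(remaining_iterations({"iteration": 4}, 10)))  # → [4, 5, 6, 7, 8, 9]
```

This restores only the loop position. The checkpoints in this notebook hold the best point but not the trial history, so a full restart would also need the optimizer state to be saved and reloaded; the Ax service API offers JSON save and load helpers on `AxClient` for that purpose, which are worth checking against the Ax version in use.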

## Summary and Next Steps

This notebook demonstrated a Bayesian optimization workflow with human-in-the-loop decision points.

### ✅ What We Accomplished
- **Bayesian Optimization**: Used Ax (or a mock random-search stand-in) to optimize Hartmann6
- **Human-in-the-Loop**: Integrated decision points for human oversight (simulated with automatic decisions)
- **Checkpointing**: Saved the best-so-far state in memory after each iteration
- **Monitoring**: Visualized progress and the heuristic confidence score
- **Error Handling**: Caught per-iteration errors and continued; a production flow would pause for human intervention

### 🚀 Next Steps for Production Use

1. **Setup Full Environment**:
   ```bash
   pip install ax-platform prefect pymongo
   ```

2. **Configure MongoDB**: Set up persistent storage for checkpoints

3. **Deploy to Prefect**: Use the full workflow script for production

4. **Customize for Your Problem**: Replace Hartmann6 with your objective function

5. **Add Constraints**: Implement parameter constraints and feasibility checks
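
For the constraints step, a feasibility check can be as small as the sketch below. The constraint `x1 + x2 <= 1.0` is a made-up example, and `is_feasible` is a hypothetical helper rather than an Ax function. Ax's service API can also take linear parameter constraints when the experiment is created, which is usually preferable because the optimizer then avoids infeasible points itself.

```python
from typing import Dict


def is_feasible(parameters: Dict[str, float], max_x1_x2_sum: float = 1.0) -> bool:
    """Check box bounds and one illustrative linear constraint (hypothetical)."""
    in_bounds = all(0.0 <= value <= 1.0 for value in parameters.values())
    return in_bounds and parameters["x1"] + parameters["x2"] <= max_x1_x2_sum


candidate = {f"x{i}": 0.4 for i in range(1, 7)}
print(is_feasible(candidate))  # → True
```

A check like this could run before each objective evaluation, so that infeasible suggestions are skipped or flagged for human review instead of being evaluated.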

### 📚 Additional Resources
- [Ax Documentation](https://ax.dev/)
- [Prefect Docs](https://docs.prefect.io/)
- [Full Tutorial Setup Guide](../docs/ax_tutorial_setup.md)
- [Production Script](ax_bayesian_optimization_hitl.py)

In [None]:
# Export results for further analysis
import tempfile

# Save results to the system temp directory (portable across platforms)
results_file = os.path.join(tempfile.gettempdir(), f"{EXPERIMENT_NAME}_results.json")
with open(results_file, 'w') as f:
    json.dump(optimization_results, f, indent=2, default=str)

print(f"📁 Results saved to: {results_file}")
print("\n🎉 Tutorial completed successfully!")
print("\nTo run the full production version:")
print("   python ax_bayesian_optimization_hitl.py")