# AL-FEP: Complete Tutorial

This notebook demonstrates the complete AL-FEP workflow for molecular discovery using active learning and reinforcement learning with Free Energy Perturbation (FEP) and docking oracles.

## Overview

The AL-FEP framework combines:
1. **Multiple Oracles**: Docking, FEP, and ML-FEP for molecular evaluation
2. **Active Learning**: Choosing which molecules to evaluate next, e.g. by prediction uncertainty or expected improvement
3. **Reinforcement Learning**: Agent-based molecular generation
4. **Target Optimization**: Focus on specific protein targets (e.g., 7JVR SARS-CoV-2 Main Protease)

## Setup and Imports

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import warnings
warnings.filterwarnings('ignore')

# AL-FEP imports
from al_fep.utils.config import load_config
from al_fep.molecular import MolecularDataset, MolecularFeaturizer
from al_fep.oracles import DockingOracle, MLFEPOracle
from al_fep.active_learning import (
    UncertaintySampling, 
    QueryByCommittee, 
    ExpectedImprovement,
    ActiveLearningPipeline
)
from al_fep.reinforcement import PPOAgent, MolecularEnvironment

# Set up plotting
plt.style.use('seaborn-v0_8')
sns.set_palette("husl")

print("AL-FEP Tutorial - All imports successful!")

## 1. Configuration and Target Setup

In [None]:
# Load configuration for 7JVR target
config = load_config('config/targets/7jvr.yaml')
print("Target Configuration:")
print(f"Target: {config['target']['name']}")
print(f"PDB ID: {config['target']['pdb_id']}")
print(f"Binding site center: {config['target']['binding_site']['center']}")
print(f"Binding site size: {config['target']['binding_site']['size']}")

# Display known active compounds
print(f"\nKnown active compounds: {len(config['target']['known_actives'])}")
for i, compound in enumerate(config['target']['known_actives'][:3]):
    print(f"  {i+1}. {compound['name']}: {compound['smiles']}")
print("  ...")
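The cell above assumes `config` contains nested keys such as `target.binding_site.center`. A minimal sketch, assuming the loaded configuration behaves like a plain nested `dict`, of failing early with a clear message when a required key is missing. The helper `missing_keys` and the placeholder values are hypothetical and not part of AL-FEP.

```python
from typing import Any, Iterable, List

def missing_keys(config: dict, dotted_paths: Iterable[str]) -> List[str]:
    """Return the dotted paths that cannot be resolved in a nested dict."""
    missing = []
    for path in dotted_paths:
        node: Any = config
        for part in path.split("."):
            if isinstance(node, dict) and part in node:
                node = node[part]
            else:
                missing.append(path)
                break
    return missing

# Keys read by the configuration cell above
REQUIRED = [
    "target.name",
    "target.pdb_id",
    "target.binding_site.center",
    "target.binding_site.size",
    "target.known_actives",
]

# Hypothetical placeholder config; these are NOT the real 7JVR settings.
example_config = {
    "target": {
        "name": "example-target",
        "pdb_id": "XXXX",
        "binding_site": {"center": [0.0, 0.0, 0.0]},
        "known_actives": [],
    }
}

print(missing_keys(example_config, REQUIRED))  # ['target.binding_site.size']
```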

## 2. Molecular Dataset Creation and Featurization

In [None]:
# Create sample molecular dataset
sample_smiles = [
    "CCO",  # Ethanol
    "CC(=O)O",  # Acetic acid
    "c1ccccc1",  # Benzene
    "CC(C)O",  # Isopropanol
    "CCN(CC)CC",  # Triethylamine
    "CC(=O)Nc1ccc(O)cc1",  # Acetaminophen
    "CC(C)(C)c1ccc(O)cc1",  # 4-tert-butylphenol
    "COc1ccc(CC(=O)O)cc1",  # 4-methoxyphenylacetic acid
    "Nc1ccc(C(=O)O)cc1",  # 4-aminobenzoic acid
    "CC1=CC(=O)C=CC1=O",  # 2-methyl-1,4-benzoquinone
    # Add some drug-like molecules
    "CC(C)Cc1ccc(C(C)C(=O)O)cc1",  # Ibuprofen
    "CC(=O)Oc1ccccc1C(=O)O",  # Aspirin
    "CN1CCC[C@H]1c2cccnc2",  # Nicotine
    "CCN(CC)CCNC(=O)c1cc(Cl)ccc1N",  # 2-Amino-5-chloro-N-[2-(diethylamino)ethyl]benzamide (procainamide analogue)
    "Cc1ccc(C)c(S(=O)(=O)Nc2nccs2)c1"  # 2,5-Dimethyl-N-(thiazol-2-yl)benzenesulfonamide (sulfathiazole analogue)
]

# Generate synthetic binding affinities for demonstration
np.random.seed(42)
synthetic_targets = np.random.normal(6.0, 1.5, len(sample_smiles))  # pIC50 values

# Create dataset
dataset = MolecularDataset()
for smiles, target in zip(sample_smiles, synthetic_targets):
    dataset.add_molecule(smiles, target_value=target)

print(f"Created dataset with {len(dataset)} molecules")
print(f"Dataset shape: {dataset.data.shape}")
print("\nDataset preview:")
print(dataset.data.head())
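The synthetic targets are drawn as pIC50 values, where pIC50 = -log10(IC50) with IC50 in mol/L, so the synthetic mean of 6.0 corresponds to roughly 1 µM. When real assay data arrive as IC50 in nanomolar, a conversion such as the following (a plain-Python sketch, not an AL-FEP function) puts them on the same scale:

```python
import math

def ic50_nm_to_pic50(ic50_nm: float) -> float:
    """Convert an IC50 given in nanomolar to pIC50 = -log10(IC50 in mol/L)."""
    if ic50_nm <= 0:
        raise ValueError("IC50 must be positive")
    return -math.log10(ic50_nm * 1e-9)  # nM -> M, then negative log10

# 1 uM (1000 nM) -> 6.0, 10 nM -> 8.0
for ic50 in (1000.0, 10.0):
    print(ic50, "nM ->", round(ic50_nm_to_pic50(ic50), 2))
```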

In [None]:
# Initialize molecular featurizer
featurizer = MolecularFeaturizer(
    fingerprint_type="morgan",
    fingerprint_radius=2,
    fingerprint_bits=1024,
    include_descriptors=True,
    include_fragments=False
)

print(f"Featurizer configuration:")
print(f"  Fingerprint: {featurizer.fingerprint_type}")
print(f"  Bits: {featurizer.fingerprint_bits}")
print(f"  Include descriptors: {featurizer.include_descriptors}")
print(f"  Feature dimension: {featurizer.get_feature_dim()}")

# Featurize molecules
features, valid_smiles = featurizer.featurize_molecules(sample_smiles)
print(f"\nFeaturized {len(valid_smiles)} molecules")
print(f"Feature matrix shape: {features.shape}")

# Show feature distribution
plt.figure(figsize=(12, 4))

plt.subplot(1, 3, 1)
plt.hist(features.mean(axis=0), bins=30, alpha=0.7)
plt.title('Feature Means Distribution')
plt.xlabel('Mean Value')
plt.ylabel('Count')

plt.subplot(1, 3, 2)
plt.hist(features.std(axis=0), bins=30, alpha=0.7)
plt.title('Feature Std Distribution')
plt.xlabel('Standard Deviation')
plt.ylabel('Count')

plt.subplot(1, 3, 3)
plt.hist(synthetic_targets, bins=10, alpha=0.7)
plt.title('Target Values Distribution')
plt.xlabel('pIC50')
plt.ylabel('Count')

plt.tight_layout()
plt.show()
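Morgan fingerprints (`fingerprint_type="morgan"`, radius 2, 1024 bits above) are bit vectors, and the usual way to compare two of them is Tanimoto (Jaccard) similarity: the number of shared on-bits divided by the number of bits set in either. A dependency-free sketch that represents each fingerprint as the set of its on-bit indices; that representation is an assumption for illustration, and RDKit provides its own `DataStructs.TanimotoSimilarity` for real fingerprint objects.

```python
from typing import AbstractSet

def tanimoto(a: AbstractSet[int], b: AbstractSet[int]) -> float:
    """Tanimoto similarity between two fingerprints given as sets of on-bit indices."""
    if not a and not b:
        return 0.0  # convention chosen here for two empty fingerprints
    shared = len(a & b)
    return shared / (len(a) + len(b) - shared)

fp1 = {1, 5, 42, 300}
fp2 = {5, 42, 512}
print(round(tanimoto(fp1, fp2), 3))  # 2 shared bits / 5 bits in the union = 0.4
```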

## 3. Oracle Setup and Evaluation

In [None]:
# Initialize ML-FEP Oracle (faster for demonstration)
ml_fep_oracle = MLFEPOracle()

# Train oracle with initial data
training_indices = np.random.choice(len(valid_smiles), size=10, replace=False)
training_smiles = [valid_smiles[i] for i in training_indices]
training_features = features[training_indices]
training_targets = synthetic_targets[training_indices]

ml_fep_oracle.train(training_features, training_targets)
print(f"Trained ML-FEP oracle with {len(training_smiles)} molecules")

# Evaluate on test set
test_indices = [i for i in range(len(valid_smiles)) if i not in training_indices]
test_smiles = [valid_smiles[i] for i in test_indices]
test_features = features[test_indices]
test_targets = synthetic_targets[test_indices]

# Get predictions with uncertainty
predictions = []
uncertainties = []

for smiles in test_smiles:
    pred, unc = ml_fep_oracle.evaluate_with_uncertainty(smiles)
    predictions.append(pred)
    uncertainties.append(unc)

predictions = np.array(predictions)
uncertainties = np.array(uncertainties)

print(f"\nOracle evaluation statistics:")
print(f"  Mean prediction: {predictions.mean():.2f} ± {predictions.std():.2f}")
print(f"  Mean uncertainty: {uncertainties.mean():.3f} ± {uncertainties.std():.3f}")
print(f"  Correlation with true values: {np.corrcoef(predictions, test_targets)[0,1]:.3f}")
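Pearson correlation alone can look good even when predictions are systematically offset from the true values, so it is common to report RMSE and R² (coefficient of determination) alongside it. A plain-Python sketch of all three; `np.corrcoef` already covers the first, and this is only meant to make the definitions explicit.

```python
import math
from typing import Sequence

def regression_metrics(y_true: Sequence[float], y_pred: Sequence[float]) -> dict:
    """Pearson r, RMSE and R^2 = 1 - SS_res / SS_tot (undefined if an input is constant)."""
    n = len(y_true)
    if n != len(y_pred) or n < 2:
        raise ValueError("need at least two paired values")
    mean_t = sum(y_true) / n
    mean_p = sum(y_pred) / n
    cov = sum((t - mean_t) * (p - mean_p) for t, p in zip(y_true, y_pred))
    var_t = sum((t - mean_t) ** 2 for t in y_true)   # SS_tot
    var_p = sum((p - mean_p) ** 2 for p in y_pred)
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    return {
        "pearson_r": cov / math.sqrt(var_t * var_p),
        "rmse": math.sqrt(ss_res / n),
        "r2": 1.0 - ss_res / var_t,
    }

# Predictions shifted by +1 pIC50: perfect correlation, yet R^2 is negative.
print(regression_metrics([5.0, 6.0, 7.0], [6.0, 7.0, 8.0]))
```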

In [None]:
# Visualize oracle performance
plt.figure(figsize=(15, 5))

# Prediction vs True values
plt.subplot(1, 3, 1)
plt.scatter(test_targets, predictions, alpha=0.7, s=60)
plt.plot([test_targets.min(), test_targets.max()], [test_targets.min(), test_targets.max()], 'r--', alpha=0.8)
plt.xlabel('True pIC50')
plt.ylabel('Predicted pIC50')
plt.title('Oracle Predictions vs True Values')
plt.grid(True, alpha=0.3)

# Uncertainty distribution
plt.subplot(1, 3, 2)
plt.hist(uncertainties, bins=15, alpha=0.7, edgecolor='black')
plt.xlabel('Prediction Uncertainty')
plt.ylabel('Count')
plt.title('Uncertainty Distribution')
plt.grid(True, alpha=0.3)

# Uncertainty vs Error
plt.subplot(1, 3, 3)
errors = np.abs(predictions - test_targets)
plt.scatter(uncertainties, errors, alpha=0.7, s=60)
plt.xlabel('Prediction Uncertainty')
plt.ylabel('Absolute Error')
plt.title('Uncertainty vs Prediction Error')
plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Print oracle statistics
print(f"Oracle call statistics:")
print(f"  Total calls: {ml_fep_oracle.call_count}")
print(f"  Cache hits: {ml_fep_oracle.cache_hits}")
print(f"  Failed evaluations: {ml_fep_oracle.failed_evaluations}")
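The oracle exposes `call_count`, `cache_hits` and `failed_evaluations`, but how AL-FEP updates them internally is not shown in this notebook. The following hypothetical wrapper is one plausible way such bookkeeping can work: scores are memoized by SMILES so repeated queries skip the expensive calculation. Counting cache hits inside `call_count` is an assumption here, chosen because the notebook later reports `cache_hits / call_count` as a cache efficiency.

```python
from typing import Callable, Dict, Optional

class CachedOracle:
    """Hypothetical memoizing wrapper; counter names mirror the notebook's oracle."""

    def __init__(self, score_fn: Callable[[str], float]):
        self._score_fn = score_fn
        self._cache: Dict[str, float] = {}
        self.call_count = 0          # every request, including cache hits (assumption)
        self.cache_hits = 0
        self.failed_evaluations = 0

    def evaluate(self, smiles: str) -> Optional[float]:
        self.call_count += 1
        if smiles in self._cache:
            self.cache_hits += 1
            return self._cache[smiles]
        try:
            score = self._score_fn(smiles)
        except Exception:
            self.failed_evaluations += 1
            return None  # failures are not cached, so they can be retried later
        self._cache[smiles] = score
        return score

def toy_score(smiles: str) -> float:
    if not smiles:
        raise ValueError("empty SMILES")
    return float(len(smiles))  # stand-in for an expensive calculation

oracle = CachedOracle(toy_score)
for smi in ["CCO", "c1ccccc1", "CCO", ""]:
    oracle.evaluate(smi)
print(oracle.call_count, oracle.cache_hits, oracle.failed_evaluations)  # 4 1 1
```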

## 4. Active Learning Strategies Comparison

In [None]:
# Initialize different active learning strategies
strategies = {
    'Uncertainty Sampling': UncertaintySampling(uncertainty_method='variance'),
    'Query by Committee': QueryByCommittee(ensemble_size=5, diversity_weight=0.1),
    'Expected Improvement': ExpectedImprovement(xi=0.01, use_gp=False)  # Use RF for speed
}

# Simulate active learning selection
n_queries = 3
candidate_features = test_features
candidate_ids = [f"mol_{i}" for i in range(len(test_smiles))]

print(f"Comparing active learning strategies for selecting {n_queries} molecules:")
print(f"Candidate pool size: {len(candidate_features)}")
print("\n" + "="*60)

strategy_results = {}

for strategy_name, strategy in strategies.items():
    print(f"\n{strategy_name}:")
    
    # Prepare strategy-specific requirements
    if isinstance(strategy, QueryByCommittee):
        strategy.train_committee(training_features, training_targets)
    elif isinstance(strategy, ExpectedImprovement):
        strategy.train_model(training_features, training_targets)
    
    # Select queries
    selected_features, selected_ids = strategy.select_queries(
        candidate_features,
        candidate_ids,
        n_queries=n_queries,
        labeled_features=training_features,
        uncertainties=uncertainties
    )
    
    # Get corresponding SMILES and true values
    selected_indices = [int(mol_id.split('_')[1]) for mol_id in selected_ids]
    selected_smiles = [test_smiles[i] for i in selected_indices]
    selected_true_values = [test_targets[i] for i in selected_indices]
    
    strategy_results[strategy_name] = {
        'smiles': selected_smiles,
        'true_values': selected_true_values,
        'indices': selected_indices
    }
    
    print(f"  Selected molecules:")
    for i, (smiles, value) in enumerate(zip(selected_smiles, selected_true_values)):
        print(f"    {i+1}. {smiles} (pIC50: {value:.2f})")
    
    print(f"  Mean true value: {np.mean(selected_true_values):.2f}")
    print(f"  Max true value: {np.max(selected_true_values):.2f}")
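`ExpectedImprovement(xi=0.01, ...)` names the standard acquisition function for maximization. For a Gaussian predictive distribution with mean μ and standard deviation σ, and best observed value f*, EI = (μ - f* - ξ)Φ(z) + σφ(z) with z = (μ - f* - ξ)/σ, where Φ and φ are the standard normal CDF and PDF. The sketch below implements that textbook formula with the standard library; it is not AL-FEP's implementation, which (judging by the "Use RF for speed" comment) uses a random-forest surrogate when `use_gp=False`. The candidate values are made up.

```python
import math

def expected_improvement(mu: float, sigma: float, best: float, xi: float = 0.01) -> float:
    """Textbook EI for maximization under a Gaussian predictive distribution."""
    improvement = mu - best - xi
    if sigma <= 0.0:
        return max(improvement, 0.0)  # no uncertainty: EI reduces to the plain improvement
    z = improvement / sigma
    cdf = 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))          # Phi(z)
    pdf = math.exp(-0.5 * z * z) / math.sqrt(2.0 * math.pi)   # phi(z)
    return improvement * cdf + sigma * pdf

# Hypothetical candidates (predicted mean, std) against a best observed pIC50 of 7.0
candidates = {"mol_a": (7.2, 0.1), "mol_b": (6.5, 1.5), "mol_c": (6.0, 0.05)}
scores = {k: expected_improvement(m, s, best=7.0) for k, (m, s) in candidates.items()}
ranked = sorted(scores, key=scores.get, reverse=True)
print(ranked)  # ['mol_b', 'mol_a', 'mol_c']
```

The uncertain `mol_b` outranks `mol_a` despite its lower mean, which is the exploration behaviour EI is designed to provide.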

In [None]:
# Visualize strategy comparison
plt.figure(figsize=(12, 8))

# Plot 1: Selected molecules by strategy
plt.subplot(2, 2, 1)
strategy_names = list(strategy_results.keys())
mean_values = [np.mean(strategy_results[name]['true_values']) for name in strategy_names]
max_values = [np.max(strategy_results[name]['true_values']) for name in strategy_names]

x = np.arange(len(strategy_names))
width = 0.35

plt.bar(x - width/2, mean_values, width, label='Mean pIC50', alpha=0.7)
plt.bar(x + width/2, max_values, width, label='Max pIC50', alpha=0.7)

plt.xlabel('Strategy')
plt.ylabel('pIC50')
plt.title('Selected Molecules Quality by Strategy')
plt.xticks(x, strategy_names, rotation=45)
plt.legend()
plt.grid(True, alpha=0.3)

# Plot 2: All candidates with selections highlighted
plt.subplot(2, 2, 2)
plt.scatter(range(len(test_targets)), test_targets, alpha=0.5, s=30, label='All candidates')

colors = ['red', 'blue', 'green']
for i, (strategy_name, color) in enumerate(zip(strategy_names, colors)):
    indices = strategy_results[strategy_name]['indices']
    values = strategy_results[strategy_name]['true_values']
    plt.scatter(indices, values, s=100, alpha=0.8, color=color, 
               label=f'{strategy_name}', marker='s')

plt.xlabel('Molecule Index')
plt.ylabel('True pIC50')
plt.title('Selected Molecules Highlighted')
plt.legend()
plt.grid(True, alpha=0.3)

# Plot 3: Uncertainty vs True Value
plt.subplot(2, 2, 3)
plt.scatter(test_targets, uncertainties, alpha=0.6, s=40)
plt.xlabel('True pIC50')
plt.ylabel('Prediction Uncertainty')
plt.title('Uncertainty vs True Value')
plt.grid(True, alpha=0.3)

# Plot 4: Strategy selection overlap
plt.subplot(2, 2, 4)
all_selected = []
for strategy_name in strategy_names:
    all_selected.extend(strategy_results[strategy_name]['indices'])

unique_selected = list(set(all_selected))
selection_counts = [all_selected.count(idx) for idx in unique_selected]

plt.bar(range(len(unique_selected)), selection_counts, alpha=0.7)
plt.xlabel('Selected Molecule Index')
plt.ylabel('Number of Strategies Selecting')
plt.title('Strategy Selection Overlap')
plt.xticks(range(len(unique_selected)), unique_selected)
plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()
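The overlap panel counts how many strategies picked each molecule. A complementary pairwise view is the Jaccard index between the index sets chosen by two strategies (1.0 for identical picks, 0.0 for disjoint ones). A sketch, assuming a dictionary shaped like `strategy_results` with an `'indices'` list per strategy; the example selections are made up.

```python
from itertools import combinations
from typing import Dict, List, Tuple

def pairwise_jaccard(results: Dict[str, Dict[str, List[int]]]) -> Dict[Tuple[str, str], float]:
    """Jaccard index of the selected-index sets for every pair of strategies."""
    overlap = {}
    for a, b in combinations(results, 2):
        sa, sb = set(results[a]["indices"]), set(results[b]["indices"])
        union = sa | sb
        # Two empty selections are treated as identical
        overlap[(a, b)] = len(sa & sb) / len(union) if union else 1.0
    return overlap

# Made-up selections for illustration only
example = {
    "Uncertainty Sampling": {"indices": [0, 2, 4]},
    "Query by Committee": {"indices": [2, 4, 1]},
    "Expected Improvement": {"indices": [3]},
}
for pair, score in pairwise_jaccard(example).items():
    print(pair, round(score, 2))
```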

## 5. Active Learning Pipeline

In [None]:
# Set up active learning pipeline
al_pipeline = ActiveLearningPipeline(
    oracle=ml_fep_oracle,
    strategy=UncertaintySampling(uncertainty_method='variance'),
    featurizer=featurizer,
    batch_size=2,
    max_iterations=3
)

print("Active Learning Pipeline Configuration:")
print(f"  Oracle: {type(al_pipeline.oracle).__name__}")
print(f"  Strategy: {type(al_pipeline.strategy).__name__}")
print(f"  Batch size: {al_pipeline.batch_size}")
print(f"  Max iterations: {al_pipeline.max_iterations}")

# Initialize with training data
initial_dataset = MolecularDataset()
for smiles, target in zip(training_smiles, training_targets):
    initial_dataset.add_molecule(smiles, target_value=target)

# Define candidate pool (remaining molecules)
candidate_pool = test_smiles.copy()

print(f"\nStarting active learning with:")
print(f"  Initial dataset: {len(initial_dataset)} molecules")
print(f"  Candidate pool: {len(candidate_pool)} molecules")
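`al_pipeline.run` is called next with a batch size of 2 and at most 3 iterations. Its internals are not shown in this notebook; the following is a generic pool-based loop under stated assumptions (an `acquire` function supplies the acquisition score, a `label` function plays the oracle), producing history records with the same keys the later cells read (`selected_molecules`, `oracle_values`, `dataset_size`). `run_pool_loop`, `acquire` and `label` are hypothetical names.

```python
from typing import Callable, Dict, List

def run_pool_loop(
    labeled: Dict[str, float],
    pool: List[str],
    acquire: Callable[[str, Dict[str, float]], float],
    label: Callable[[str], float],
    batch_size: int = 2,
    max_iterations: int = 3,
) -> List[dict]:
    """Generic pool-based active learning: score, pick the top batch, label, repeat."""
    labeled = dict(labeled)  # copy so the caller's data are not mutated
    pool = list(pool)
    history = []
    for _ in range(max_iterations):
        if not pool:
            break
        # Highest acquisition score first; ties keep pool order (sorted is stable)
        ranked = sorted(pool, key=lambda s: acquire(s, labeled), reverse=True)
        batch = ranked[:batch_size]
        values = [label(s) for s in batch]
        labeled.update(zip(batch, values))
        pool = [s for s in pool if s not in batch]
        history.append({
            "selected_molecules": batch,
            "oracle_values": values,
            "dataset_size": len(labeled),
        })
    return history

toy_history = run_pool_loop(
    labeled={"CCO": 5.1},
    pool=["CC(=O)O", "c1ccccc1", "CCN(CC)CC", "CC(C)O", "CN"],
    acquire=lambda smi, _labeled: len(smi),  # toy score: prefer longer SMILES
    label=lambda smi: len(smi) / 2.0,        # toy oracle
)
for rec in toy_history:
    print(rec["selected_molecules"], rec["dataset_size"])
```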

In [None]:
# Run active learning iterations
history = al_pipeline.run(initial_dataset, candidate_pool)

print(f"\nActive Learning completed in {len(history)} iterations")
print("\nIteration Summary:")
for i, iteration_data in enumerate(history):
    print(f"  Iteration {i+1}:")
    print(f"    Selected: {iteration_data['selected_molecules']}")
    print(f"    Values: {[f'{v:.2f}' for v in iteration_data['oracle_values']]}")
    print(f"    Dataset size: {iteration_data['dataset_size']}")
    if 'model_performance' in iteration_data:
        r2 = iteration_data['model_performance'].get('r2')
        print(f"    Model R²: {r2:.3f}" if isinstance(r2, (int, float)) else "    Model R²: N/A")
    print()

In [None]:
# Visualize active learning progress
plt.figure(figsize=(15, 5))

# Plot 1: Dataset growth
plt.subplot(1, 3, 1)
dataset_sizes = [len(initial_dataset)] + [item['dataset_size'] for item in history]
iterations = range(len(dataset_sizes))
plt.plot(iterations, dataset_sizes, 'o-', linewidth=2, markersize=8)
plt.xlabel('Iteration')
plt.ylabel('Dataset Size')
plt.title('Dataset Growth')
plt.grid(True, alpha=0.3)

# Plot 2: Selected molecule values
plt.subplot(1, 3, 2)
all_selected_values = []
iteration_labels = []
for i, item in enumerate(history):
    values = item['oracle_values']
    all_selected_values.extend(values)
    iteration_labels.extend([f'Iter {i+1}'] * len(values))

# Create box plot by iteration
iteration_nums = []
values_by_iter = []
for i, item in enumerate(history):
    iteration_nums.append(i+1)
    values_by_iter.append(item['oracle_values'])

plt.boxplot(values_by_iter, positions=iteration_nums)
plt.xlabel('Iteration')
plt.ylabel('Selected Molecule pIC50')
plt.title('Quality of Selected Molecules')
plt.grid(True, alpha=0.3)

# Plot 3: Model performance over time
plt.subplot(1, 3, 3)
r2_scores = []
for item in history:
    r2 = item.get('model_performance', {}).get('r2')
    # Use NaN (not 0) for missing scores so gaps are not mistaken for R² = 0
    r2_scores.append(r2 if isinstance(r2, (int, float)) else np.nan)

if np.any(np.isfinite(r2_scores)):
    plt.plot(range(1, len(r2_scores)+1), r2_scores, 'o-', linewidth=2, markersize=8)
    plt.xlabel('Iteration')
    plt.ylabel('Model R² Score')
    plt.title('Model Performance Over Time')
    plt.grid(True, alpha=0.3)
else:
    plt.text(0.5, 0.5, 'No R² scores available', ha='center', va='center', transform=plt.gca().transAxes)
    plt.title('Model Performance Over Time')

plt.tight_layout()
plt.show()

# Summary statistics
print("\nFinal Active Learning Statistics:")
print(f"  Total oracle calls: {ml_fep_oracle.call_count}")
print(f"  Final dataset size: {dataset_sizes[-1]}")
print(f"  Best molecule found: {max(all_selected_values):.2f} pIC50")
print(f"  Mean selected value: {np.mean(all_selected_values):.2f} pIC50")
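The "best molecule found" statistic is the endpoint of a best-so-far curve, which is also what the "Discovery Progress" panel in Section 8 plots. `itertools.accumulate` with `max` gives the running maximum directly; the batches below are illustrative, not outputs of this notebook, and `running_best` is a hypothetical helper.

```python
from itertools import accumulate
from typing import List, Sequence

def running_best(values_per_iteration: Sequence[Sequence[float]], start: float) -> List[float]:
    """Running maximum of oracle values, one entry per iteration, seeded with `start`."""
    batch_best = [max(batch) for batch in values_per_iteration]
    return list(accumulate(batch_best, max, initial=start))  # `initial` needs Python 3.8+

# `start` would be the best value already present in the initial training set
curve = running_best([[5.8, 6.4], [6.1, 5.2], [7.3, 6.9]], start=6.2)
print(curve)  # [6.2, 6.4, 6.4, 7.3]
```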

## 6. Reinforcement Learning Setup

In [None]:
# Create molecular environment for RL
molecular_env = MolecularEnvironment(
    oracle=ml_fep_oracle,
    max_length=50,
    reward_threshold=7.0  # Target pIC50 threshold
)

print("Molecular Environment Configuration:")
print(f"  Oracle: {type(molecular_env.oracle).__name__}")
print(f"  Max molecule length: {molecular_env.max_length}")
print(f"  Reward threshold: {molecular_env.reward_threshold}")
print(f"  Action space size: {molecular_env.action_space.n}")
print(f"  Observation space shape: {molecular_env.observation_space.shape}")

# Initialize PPO agent
state_dim = molecular_env.observation_space.shape[0]
action_dim = molecular_env.action_space.n

ppo_agent = PPOAgent(
    state_dim=state_dim,
    action_dim=action_dim,
    lr=3e-4,
    gamma=0.99,
    clip_epsilon=0.2,
    hidden_dims=[256, 128]
)

print(f"\nPPO Agent Configuration:")
print(f"  State dimension: {state_dim}")
print(f"  Action dimension: {action_dim}")
print(f"  Learning rate: {ppo_agent.optimizer.param_groups[0]['lr']}")
print(f"  Network parameters: {sum(p.numel() for p in ppo_agent.network.parameters())}")
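`clip_epsilon=0.2` is the clipping parameter of PPO's surrogate objective, L = E[min(r·A, clip(r, 1 - ε, 1 + ε)·A)], where r is the ratio of new to old action probabilities and A is the advantage estimate. A dependency-free sketch of that objective for a batch of samples, to make the role of ε concrete; it is written from the published PPO formula, not taken from `PPOAgent`.

```python
import math
from typing import Sequence

def ppo_clip_objective(
    new_log_probs: Sequence[float],
    old_log_probs: Sequence[float],
    advantages: Sequence[float],
    clip_epsilon: float = 0.2,
) -> float:
    """Mean clipped surrogate objective (to be maximized) from PPO."""
    total = 0.0
    for new_lp, old_lp, adv in zip(new_log_probs, old_log_probs, advantages):
        ratio = math.exp(new_lp - old_lp)  # pi_new(a|s) / pi_old(a|s)
        clipped = min(max(ratio, 1.0 - clip_epsilon), 1.0 + clip_epsilon)
        # Taking the minimum makes the objective pessimistic in both directions
        total += min(ratio * adv, clipped * adv)
    return total / len(advantages)

# A ratio of 1.5 with positive advantage is capped at 1.2 * A
print(round(ppo_clip_objective([math.log(1.5)], [0.0], [2.0]), 3))  # 2.4
```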

In [None]:
# Demonstrate RL episode
print("Running sample RL episode...")

state = molecular_env.reset()
total_reward = 0
episode_length = 0
actions_taken = []

print(f"Initial state shape: {state.shape}")
print(f"Initial state (first 10 elements): {state[:10]}")

for step in range(20):  # Limit to 20 steps for demonstration
    action, log_prob, value = ppo_agent.get_action(state)
    next_state, reward, done, info = molecular_env.step(action)
    
    actions_taken.append(action)
    total_reward += reward
    episode_length += 1
    
    print(f"Step {step+1}: Action={action}, Reward={reward:.3f}, Done={done}")
    
    if done:
        break
    
    state = next_state

print(f"\nEpisode Summary:")
print(f"  Total reward: {total_reward:.3f}")
print(f"  Episode length: {episode_length}")
print(f"  Actions taken: {actions_taken}")
print(f"  Final molecule: {info.get('molecule', 'N/A')}")
print(f"  Oracle evaluation: {info.get('oracle_score', 'N/A')}")
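With `gamma=0.99`, the agent's learning target is the discounted return G_t = r_t + γ·r_{t+1} + γ²·r_{t+2} + …. Computing it backwards over an episode's rewards takes a single pass; the rewards below are illustrative. PPO implementations often layer generalized advantage estimation on top of this, and whether `PPOAgent` does so is not shown here.

```python
from typing import List, Sequence

def discounted_returns(rewards: Sequence[float], gamma: float = 0.99) -> List[float]:
    """G_t = r_t + gamma * G_{t+1}, computed from the last step backwards."""
    returns = [0.0] * len(rewards)
    running = 0.0
    for t in range(len(rewards) - 1, -1, -1):
        running = rewards[t] + gamma * running
        returns[t] = running
    return returns

# Illustrative sparse reward, where only the final step receives a score
print([round(g, 4) for g in discounted_returns([0.0, 0.0, 1.0])])  # [0.9801, 0.99, 1.0]
```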

## 7. Integration: AL + RL Workflow

In [None]:
# Demonstration of integrated AL-RL workflow
print("Integrated AL-RL Molecular Discovery Workflow")
print("=" * 50)

# Step 1: Use active learning to identify promising regions
print("\n1. Active Learning Phase:")
best_molecules = []
for item in history:
    for mol, val in zip(item['selected_molecules'], item['oracle_values']):
        if val > 6.0:  # Threshold for "good" molecules
            best_molecules.append((mol, val))

print(f"   Found {len(best_molecules)} promising molecules with pIC50 > 6.0")
for mol, val in best_molecules:
    print(f"   - {mol}: {val:.2f}")

# Step 2: Extract features from best molecules for RL guidance
print("\n2. Feature Analysis:")
if best_molecules:
    best_smiles = [mol for mol, _ in best_molecules]
    best_features, _ = featurizer.featurize_molecules(best_smiles)
    
    if len(best_features) > 0:
        feature_profile = np.mean(best_features, axis=0)
        print(f"   Computed average feature profile from {len(best_features)} molecules")
        print(f"   Profile stats: mean={feature_profile.mean():.3f}, std={feature_profile.std():.3f}")
    else:
        print("   No valid features extracted")
else:
    print("   No promising molecules found")

# Step 3: Summary of oracle usage
print("\n3. Oracle Usage Summary:")
print(f"   Total oracle calls: {ml_fep_oracle.call_count}")
print(f"   Cache hits: {ml_fep_oracle.cache_hits}")
print(f"   Cache efficiency: {ml_fep_oracle.cache_hits / max(ml_fep_oracle.call_count, 1) * 100:.1f}%")
print(f"   Failed evaluations: {ml_fep_oracle.failed_evaluations}")

# Step 4: Recommendations for next steps
print("\n4. Next Steps Recommendations:")
print("   - Use identified molecular features to guide RL agent training")
print("   - Implement reward shaping based on similarity to best molecules")
print("   - Continue AL iterations with RL-generated candidates")
print("   - Validate top candidates with higher-fidelity oracles (FEP/Docking)")
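The recommendation to shape rewards by similarity to the best molecules could take many forms. One simple, hypothetical version adds a bonus proportional to the highest Tanimoto similarity between a generated molecule's fingerprint and the fingerprints of known good molecules. Fingerprints are modeled as sets of on-bit indices, and `shaped_reward` and `weight` are illustrative names, not AL-FEP settings.

```python
from typing import AbstractSet, Sequence

def tanimoto(a: AbstractSet[int], b: AbstractSet[int]) -> float:
    """Shared on-bits divided by on-bits in either fingerprint (0.0 if both empty)."""
    union = len(a | b)
    return len(a & b) / union if union else 0.0

def shaped_reward(
    oracle_score: float,
    candidate_fp: AbstractSet[int],
    reference_fps: Sequence[AbstractSet[int]],
    weight: float = 0.5,
) -> float:
    """Oracle score plus weight * (max Tanimoto similarity to any reference)."""
    if not reference_fps:
        return oracle_score
    bonus = max(tanimoto(candidate_fp, ref) for ref in reference_fps)
    return oracle_score + weight * bonus

refs = [{1, 2, 3, 4}, {10, 11}]
print(round(shaped_reward(6.0, {1, 2, 3, 9}, refs), 3))  # 6.0 + 0.5 * 0.6 = 6.3
```

A large `weight` pulls the generator toward close analogues of the reference molecules, trading novelty for exploitation.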

## 8. Results Analysis and Visualization

In [None]:
# Create comprehensive results visualization
fig, axes = plt.subplots(2, 3, figsize=(18, 12))

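# 1. Molecular property space
ax = axes[0, 0]
from rdkit import Chem  # needed for MolFromSmiles; not imported in the setup cell
from al_fep.molecular.featurizer import DescriptorCalculator

# Keep each parsed molecule paired with its own pIC50 so the colors stay aligned
# even if some SMILES fail to parse
mol_target_pairs = [(Chem.MolFromSmiles(smi), target)
                    for smi, target in zip(sample_smiles, synthetic_targets)]
mol_target_pairs = [(mol, target) for mol, target in mol_target_pairs if mol is not None]

if mol_target_pairs:
    # Compute the descriptors once per molecule
    descriptors = [DescriptorCalculator.lipinski_descriptors(mol) for mol, _ in mol_target_pairs]
    mw_values = [d['MW'] for d in descriptors]
    logp_values = [d['LogP'] for d in descriptors]
    pic50_values = [target for _, target in mol_target_pairs]

    scatter = ax.scatter(mw_values, logp_values, c=pic50_values,
                         cmap='viridis', alpha=0.7, s=60)
    ax.set_xlabel('Molecular Weight')
    ax.set_ylabel('LogP')
    ax.set_title('Molecular Property Space')
    plt.colorbar(scatter, ax=ax, label='pIC50')
    ax.grid(True, alpha=0.3)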

# 2. Active learning strategy comparison
ax = axes[0, 1]
strategy_names = list(strategy_results.keys())
mean_values = [np.mean(strategy_results[name]['true_values']) for name in strategy_names]
std_values = [np.std(strategy_results[name]['true_values']) for name in strategy_names]

bars = ax.bar(range(len(strategy_names)), mean_values, yerr=std_values, 
             capsize=5, alpha=0.7, color=['skyblue', 'lightcoral', 'lightgreen'])
ax.set_xlabel('Strategy')
ax.set_ylabel('Mean pIC50 of Selected Molecules')
ax.set_title('Active Learning Strategy Performance')
ax.set_xticks(range(len(strategy_names)))
ax.set_xticklabels([name.replace(' ', '\n') for name in strategy_names])
ax.grid(True, alpha=0.3)

# 3. Oracle performance over time
ax = axes[0, 2]
cumulative_calls = np.cumsum([len(item['selected_molecules']) for item in history])
# Start the curve from the best value already in the initial training set,
# rather than from an arbitrary pIC50 of 0
best_so_far = float(np.max(training_targets))
cumulative_best = [best_so_far]
for item in history:
    best_so_far = max(best_so_far, max(item['oracle_values']))
    cumulative_best.append(best_so_far)

ax.plot([0] + list(cumulative_calls), cumulative_best, 'o-', linewidth=2, markersize=8)
ax.set_xlabel('Cumulative Oracle Calls')
ax.set_ylabel('Best pIC50 Found')
ax.set_title('Discovery Progress')
ax.grid(True, alpha=0.3)

# 4. Feature importance (if available)
ax = axes[1, 0]
if hasattr(ml_fep_oracle, 'model') and hasattr(ml_fep_oracle.model, 'feature_importances_'):
    importances = ml_fep_oracle.model.feature_importances_
    top_indices = np.argsort(importances)[-20:]  # Top 20 features
    
    ax.barh(range(len(top_indices)), importances[top_indices], alpha=0.7)
    ax.set_xlabel('Feature Importance')
    ax.set_ylabel('Feature Index')
    ax.set_title('Top 20 Feature Importances')
    ax.set_yticks(range(len(top_indices)))
    ax.set_yticklabels(top_indices)
else:
    ax.text(0.5, 0.5, 'Feature importance\nnot available', 
           ha='center', va='center', transform=ax.transAxes)
    ax.set_title('Feature Importances')

# 5. Uncertainty calibration
ax = axes[1, 1]
if len(predictions) > 0 and len(uncertainties) > 0:
    errors = np.abs(predictions - test_targets)
    
    # Bin uncertainties and compute mean error in each bin
    n_bins = 5
    unc_bins = np.linspace(uncertainties.min(), uncertainties.max(), n_bins + 1)
    bin_centers = (unc_bins[:-1] + unc_bins[1:]) / 2
    bin_errors = []
    
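    for i in range(n_bins):
        # Make the last bin inclusive so the largest uncertainty is not dropped
        if i == n_bins - 1:
            upper_ok = uncertainties <= unc_bins[i+1]
        else:
            upper_ok = uncertainties < unc_bins[i+1]
        mask = (uncertainties >= unc_bins[i]) & upper_ok
        if np.any(mask):
            bin_errors.append(np.mean(errors[mask]))
        else:
            bin_errors.append(np.nan)  # empty bin: leave a gap instead of plotting 0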
    
    ax.plot(bin_centers, bin_errors, 'o-', linewidth=2, markersize=8)
    ax.set_xlabel('Prediction Uncertainty')
    ax.set_ylabel('Mean Absolute Error')
    ax.set_title('Uncertainty Calibration')
    ax.grid(True, alpha=0.3)
else:
    ax.text(0.5, 0.5, 'No prediction data\navailable', 
           ha='center', va='center', transform=ax.transAxes)
    ax.set_title('Uncertainty Calibration')

# 6. Summary statistics
ax = axes[1, 2]
ax.axis('off')

summary_text = f"""
AL-FEP Results Summary
=====================

Dataset:
• Total molecules: {len(valid_smiles)}
• Training set: {len(training_smiles)}
• Test set: {len(test_smiles)}

Active Learning:
• Iterations: {len(history)}
• Molecules selected: {sum(len(item['selected_molecules']) for item in history)}
• Best pIC50 found: {max(all_selected_values):.2f}

Oracle Usage:
• Total calls: {ml_fep_oracle.call_count}
• Cache hits: {ml_fep_oracle.cache_hits}
• Failed evaluations: {ml_fep_oracle.failed_evaluations}

Model Performance:
• Prediction correlation: {np.corrcoef(predictions, test_targets)[0,1]:.3f}
• Mean uncertainty: {uncertainties.mean():.3f}
"""

ax.text(0.05, 0.95, summary_text, transform=ax.transAxes, 
        verticalalignment='top', fontfamily='monospace', fontsize=10)

plt.tight_layout()
plt.show()
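The property-space panel plots molecular weight against LogP, two of the quantities in Lipinski's rule of five (MW ≤ 500 Da, LogP ≤ 5, at most 5 hydrogen-bond donors, at most 10 hydrogen-bond acceptors; a common reading allows at most one violation). A sketch of a rule-of-five check over precomputed descriptor dictionaries. The keys `HBD` and `HBA` are assumptions (the notebook only confirms `MW` and `LogP` from `DescriptorCalculator.lipinski_descriptors`), and the values are illustrative.

```python
from typing import Dict

# Upper bounds (inclusive) from Lipinski's rule of five
LIMITS = {"MW": 500.0, "LogP": 5.0, "HBD": 5, "HBA": 10}

def lipinski_violations(desc: Dict[str, float]) -> int:
    """Count rule-of-five criteria exceeded; missing keys are treated as not violated."""
    return sum(1 for key, limit in LIMITS.items() if desc.get(key, 0) > limit)

def passes_ro5(desc: Dict[str, float], max_violations: int = 1) -> bool:
    return lipinski_violations(desc) <= max_violations

# Illustrative descriptor values, not computed from the notebook's molecules
small = {"MW": 206.3, "LogP": 3.5, "HBD": 1, "HBA": 2}
large = {"MW": 812.0, "LogP": 6.1, "HBD": 4, "HBA": 12}
print(lipinski_violations(small), passes_ro5(small))  # 0 True
print(lipinski_violations(large), passes_ro5(large))  # 3 False
```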

## 9. Conclusions and Next Steps

This tutorial demonstrated the complete AL-FEP framework for molecular discovery:

### Key Achievements:
1. **Oracle Integration**: Trained the ML-FEP oracle and queried it with uncertainty estimates (the docking and FEP oracles are not run in this tutorial)
2. **Active Learning Strategies**: Compared uncertainty sampling, query by committee, and expected improvement for selecting molecules
3. **Reinforcement Learning**: Set up an RL environment and PPO agent and ran one sample episode with the untrained agent
4. **Comprehensive Analysis**: Analyzed molecular properties, strategy performance, and discovery progress

### Next Steps for Production Use:
1. **Scale to Real Datasets**: Use ChEMBL, ZINC, or custom molecular libraries
2. **Integrate FEP Calculations**: Add OpenMM-based FEP calculations for accurate binding affinity prediction
3. **Docking Integration**: Set up AutoDock Vina or other docking software
4. **RL Training**: Train PPO agents on molecular generation tasks
5. **Target-Specific Optimization**: Focus on specific protein targets with known binding sites

### Framework Strengths:
- **Modular Design**: Easy to swap oracles, strategies, and algorithms
- **Uncertainty Quantification**: Built-in uncertainty estimation for active learning
- **Scalable**: Designed for large-scale molecular discovery campaigns
- **Configurable**: YAML-based configuration for different targets and experiments

In [None]:
# Final summary and recommendations
print("🎯 AL-FEP Tutorial Complete!")
print("\n" + "="*60)
print("WORKFLOW SUMMARY:")
print("="*60)
print(f"1. Dataset: Created and featurized {len(valid_smiles)} molecules")
print(f"2. Oracle: Trained ML-FEP with {len(training_smiles)} molecules")
print(f"3. Active Learning: Ran {len(history)} iterations, selected {sum(len(item['selected_molecules']) for item in history)} molecules")
print(f"4. Best Discovery: Found molecule with pIC50 = {max(all_selected_values):.2f}")
print(f"5. Efficiency: {ml_fep_oracle.cache_hits / max(ml_fep_oracle.call_count, 1) * 100:.1f}% cache hit rate")

print("\n" + "="*60)
print("NEXT STEPS:")
print("="*60)
print("• Run on larger molecular databases (ChEMBL, ZINC)")
print("• Integrate real FEP calculations with OpenMM")
print("• Set up AutoDock Vina for docking oracle")
print("• Train RL agents for molecular generation")
print("• Focus on specific targets (7JVR, etc.)")
print("• Scale to HPC environments for large campaigns")

print("\n🧬 Happy molecular discovery! 🧬")