In [None]:
import sys
sys.path.insert(0, '..')

import numpy as np
import torch
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
from tqdm.notebook import tqdm
import json

from src.algorithms.cql import CQL
from src.data.replay_buffer import OfflineReplayBuffer
from src.environments.icu_sepsis_wrapper import create_sepsis_env
from src.utils.evaluation import evaluate_policy

# Plot styling for every figure in this notebook. The style sheet is applied
# first and the seaborn palette second, so the palette is set last.
# NOTE(review): the 'seaborn-v0_8-*' style names presumably require a recent
# matplotlib release — confirm against the pinned version.
plt.style.use('seaborn-v0_8-whitegrid')
sns.set_palette('colorblind')

## 1. The Role of α in CQL

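The parameter α scales the conservative penalty that CQL adds to the standard TD loss, i.e. how strongly Q-values are pushed down for actions that are not supported by the dataset: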

- **α = 0**: Standard DQN (no conservative penalty)
- **Small α (0.1–0.5)**: Light regularization; Q-values for out-of-distribution actions may still be overestimated
- **Medium α (1.0–5.0)**: Balanced; typically gives the best performance
- **Large α (10.0+)**: Very conservative; the learned policy may underperform

The optimal α depends on:
1. Dataset quality
2. Coverage of the behavior policy
3. Task difficulty

In [None]:
# Load the offline dataset that every run in the sweep trains on.
# The path is relative to the notebook's working directory.
# NOTE(review): the .pkl extension suggests a pickle file — only load
# datasets that come from a trusted source.
dataset_path = Path('../data/offline_datasets/behavior_policy.pkl')
buffer = OfflineReplayBuffer.load(dataset_path)
print(f"Loaded dataset with {len(buffer)} transitions")

# Create the evaluation environment (action masking enabled). This single
# instance is passed to every training/evaluation run in the sweep.
eval_env = create_sepsis_env(use_action_masking=True)

## 2. Define Alpha Values to Test

In [None]:
# Alpha values to sweep (0.0 means no conservative penalty; see section 1)
alpha_values = [0.0, 0.1, 0.5, 1.0, 2.0, 5.0, 10.0]

# Training configuration (reduced for notebook)
n_iterations = 5000  # Number of agent.update calls per run; increase for better results
batch_size = 256  # Transitions sampled from the buffer per update
n_eval_episodes = 50  # Evaluation episodes per trained agent
seeds = [42, 123, 456]  # One run per seed for each alpha; three seeds give only a rough estimate of the spread

## 3. Run Alpha Sweep

In [None]:
def train_and_evaluate_cql(alpha, buffer, eval_env, n_iterations, seed=42,
                           batch_size=None, n_eval_episodes=None):
    """Train a CQL agent with the given alpha and evaluate it.

    Args:
        alpha: Conservative-penalty weight forwarded to ``CQL``.
        buffer: Offline dataset; must provide ``sample(batch_size)``.
        eval_env: Environment handed to ``evaluate_policy``.
        n_iterations: Number of ``agent.update`` calls to perform.
        seed: Seed applied to the Python, NumPy and PyTorch RNGs and
            forwarded to ``evaluate_policy``.
        batch_size: Transitions sampled per update. When ``None`` (default)
            the notebook-level ``batch_size`` variable is used, which is
            what this function always did before the parameter existed.
        n_eval_episodes: Number of evaluation episodes. When ``None``
            (default) the notebook-level ``n_eval_episodes`` is used.

    Returns:
        Whatever ``evaluate_policy`` returns for the trained agent.

    Raises:
        KeyError: If ``batch_size`` or ``n_eval_episodes`` is omitted and the
            corresponding notebook-level variable has not been defined.
    """
    import random  # local import: only needed for seeding

    # Fall back to the notebook configuration for anything not passed in, so
    # existing calls keep working while new callers can be explicit.
    notebook_config = globals()
    if batch_size is None:
        batch_size = notebook_config['batch_size']
    if n_eval_episodes is None:
        n_eval_episodes = notebook_config['n_eval_episodes']

    # Seed every RNG that may be involved so runs are repeatable.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    # Create agent
    agent = CQL(
        state_dim=1,
        action_dim=25,
        hidden_dims=[256, 256],
        learning_rate=3e-4,
        gamma=0.99,
        tau=0.005,
        alpha=alpha,
    )

    # Training: one sampled batch per update step.
    for _ in range(n_iterations):
        agent.update(buffer.sample(batch_size))

    # Evaluation
    return evaluate_policy(
        env=eval_env,
        policy=agent,
        n_episodes=n_eval_episodes,
        seed=seed,
    )

In [None]:
# Sweep: train and evaluate one agent per (alpha, seed) pair, then summarise
# the per-seed metrics for each alpha.
sweep_results = {}

for alpha in tqdm(alpha_values, desc="Alpha sweep"):
    per_seed_results = []
    for seed in seeds:
        print(f"  Training α={alpha}, seed={seed}")
        per_seed_results.append(
            train_and_evaluate_cql(
                alpha=alpha,
                buffer=buffer,
                eval_env=eval_env,
                n_iterations=n_iterations,
                seed=seed,
            )
        )

    # Per-seed metrics for this alpha, in the same order as `seeds`.
    survival_rates = [run['survival_rate'] for run in per_seed_results]
    seed_returns = [run['mean_return'] for run in per_seed_results]

    # np.std is called with its default ddof=0 (population standard deviation).
    sweep_results[alpha] = dict(
        survival_rates=survival_rates,
        mean_survival=np.mean(survival_rates),
        std_survival=np.std(survival_rates),
        mean_return=np.mean(seed_returns),
        std_return=np.std(seed_returns),
    )

    print(f"  α={alpha}: Survival={np.mean(survival_rates):.1%} ± {np.std(survival_rates):.1%}")

## 4. Visualise Results

In [None]:
# Two side-by-side panels over the same alpha values:
# survival rate on the left, mean return on the right.
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
survival_ax, return_ax = axes

# Pull the per-alpha summary statistics out of the sweep, in sweep order.
alphas = list(sweep_results)
mean_survivals, std_survivals, mean_returns, std_returns = (
    [sweep_results[a][stat] for a in alphas]
    for stat in ('mean_survival', 'std_survival', 'mean_return', 'std_return')
)

# Line/marker settings common to both error-bar plots.
shared_errorbar_style = dict(linewidth=2, markersize=10, capsize=5)

# Survival rate vs alpha, with the 80% target as a reference line.
survival_ax.errorbar(alphas, mean_survivals, yerr=std_survivals,
                     fmt='o-', **shared_errorbar_style)
survival_ax.axhline(y=0.8, color='g', linestyle='--', alpha=0.7, label='Target (80%)')
survival_ax.set_ylabel('Survival Rate', fontsize=12)
survival_ax.set_title('Effect of α on Survival Rate', fontsize=14)
survival_ax.set_ylim([0, 1])
survival_ax.legend()

# Return vs alpha.
return_ax.errorbar(alphas, mean_returns, yerr=std_returns,
                   fmt='s-', color='orange', **shared_errorbar_style)
return_ax.set_ylabel('Mean Return', fontsize=12)
return_ax.set_title('Effect of α on Return', fontsize=14)

# Settings shared by both panels. The symlog scale keeps α=0 on the axis.
for panel in axes:
    panel.set_xlabel('α (CQL Conservatism)', fontsize=12)
    panel.set_xscale('symlog', linthresh=0.1)
    panel.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

In [None]:
# Bar chart of mean survival rate per alpha, with the best alpha highlighted.
fig, ax = plt.subplots(figsize=(10, 6))

bar_positions = np.arange(len(alphas))
bars = ax.bar(
    bar_positions,
    mean_survivals,
    0.6,
    yerr=std_survivals,
    capsize=5,
    color='steelblue',
    edgecolor='black',
    alpha=0.8,
)

# Recolour the bar with the highest mean survival rate.
winner = bars[np.argmax(mean_survivals)]
winner.set_color('green')
winner.set_alpha(1.0)

# Axis labelling, tick labels and the 80% reference line.
ax.set_xlabel('α (CQL Conservatism)', fontsize=12)
ax.set_ylabel('Survival Rate', fontsize=12)
ax.set_title('Alpha Comparison (Best in Green)', fontsize=14)
ax.set_xticks(bar_positions)
ax.set_xticklabels([f'α={a}' for a in alphas])
ax.axhline(y=0.8, color='red', linestyle='--', alpha=0.7, label='Target (80%)')
ax.set_ylim([0, 1])
ax.legend()

# Write each mean just above the top of its error bar.
for patch, mean, std in zip(bars, mean_survivals, std_survivals):
    label_x = patch.get_x() + patch.get_width() / 2
    label_y = patch.get_height() + std + 0.02
    ax.text(label_x, label_y, f'{mean:.1%}', ha='center', va='bottom', fontsize=10)

plt.tight_layout()
plt.show()

## 5. Find Optimal Alpha

In [None]:
# Find the alpha with the highest mean survival rate across seeds.
# If several alphas tie, max() returns the first one in sweep order.
best_alpha = max(sweep_results, key=lambda a: sweep_results[a]['mean_survival'])

print("Results Summary:")
print("=" * 50)
# Iterate over sweep_results itself so this cell only depends on the sweep,
# not on variables created by the plotting cells.
for alpha, r in sweep_results.items():
    marker = " <-- BEST" if alpha == best_alpha else ""
    print(f"α={alpha:5.1f}: {r['mean_survival']:.1%} ± {r['std_survival']:.1%}{marker}")

print("\n" + "=" * 50)
print(f"Optimal α: {best_alpha}")
print(f"Best Survival Rate: {sweep_results[best_alpha]['mean_survival']:.1%}")

## 6. Save Results

In [None]:
# Save sweep results
output_dir = Path('../results/alpha_sweep')
output_dir.mkdir(parents=True, exist_ok=True)
output_path = output_dir / 'alpha_sweep_results.json'

# Convert to a JSON-serializable format: alpha keys become strings and every
# statistic is cast to a plain Python float, because json.dump rejects NumPy
# scalar types that do not subclass float (e.g. np.float32).
json_results = {
    str(k): {
        'mean_survival': float(v['mean_survival']),
        'std_survival': float(v['std_survival']),
        'mean_return': float(v['mean_return']),
        'std_return': float(v['std_return']),
        'survival_rates': [float(rate) for rate in v['survival_rates']],
    }
    for k, v in sweep_results.items()
}

# Explicit encoding so the file is written the same way on every platform.
with open(output_path, 'w', encoding='utf-8') as f:
    json.dump(json_results, f, indent=2)

print(f"Results saved to: {output_path}")

## Key Takeaways

1. **α = 0 (No CQL)**: Often overestimates Q-values, leading to poor policies
2. **Optimal α**: Typically in the range [0.5, 5.0] for this environment; confirm against the sweep results above
3. **High α**: Too conservative; may not improve over the behavior policy
4. **Variance**: Multiple seeds are essential for reliable comparisons