In [None]:
import sys
sys.path.insert(0, '..')

import numpy as np
import torch
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
from tqdm.notebook import tqdm
from scipy import stats

from src.algorithms.cql import CQL
from src.algorithms.bc import BehaviorCloning
from src.algorithms.dqn import DQN
from src.data.replay_buffer import OfflineReplayBuffer
from src.environments.icu_sepsis_wrapper import create_sepsis_env
from src.utils.evaluation import evaluate_policy, compute_confidence_intervals

# Global plotting style applied to every figure in this notebook:
# the seaborn "whitegrid" style sheet plus seaborn's colorblind palette.
plt.style.use('seaborn-v0_8-whitegrid')
sns.set_palette('colorblind')

## 1. Load Data and Setup

In [None]:
# --- Experiment configuration ------------------------------------------------
n_iterations = 10000               # gradient updates per agent; increase for better results
batch_size = 256                   # transitions sampled per update
n_eval_episodes = 100              # evaluation rollouts per (algorithm, seed)
seeds = [42, 123, 456, 789, 1000]  # 5 seeds for significance testing

# --- Offline dataset -----------------------------------------------------------
# NOTE(review): the dataset file is a .pkl; if OfflineReplayBuffer.load unpickles
# it, only load files from a trusted source.
dataset_path = '../data/offline_datasets/behavior_policy.pkl'
buffer = OfflineReplayBuffer.load(dataset_path)
print(f"Dataset size: {len(buffer)} transitions")

# --- Evaluation environment ------------------------------------------------------
eval_env = create_sepsis_env(use_action_masking=True)

## 2. Define Algorithms

### Algorithm Overview:

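1. **CQL**: Conservative Q-Learning — Q-learning plus a regularizer (weight `alpha`) that keeps Q-values of out-of-distribution actions low
2. **DQN (Offline)**: Standard DQN trained on the same fixed dataset, without the CQL penalty
3. **BC**: Behavior Cloning — supervised learning that imitates the actions recorded in the dataset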

In [None]:
# Network/optimizer settings shared by every learner, so the comparison differs
# only in the algorithm itself.
_STATE_DIM = 1         # NOTE(review): single state feature — confirm against the env's observation space
_ACTION_DIM = 25       # number of discrete actions
_HIDDEN_DIMS = [256, 256]
_LEARNING_RATE = 3e-4


def _seed_everything(seed):
    """Seed the global NumPy and PyTorch RNGs before an agent is built and trained."""
    np.random.seed(seed)
    torch.manual_seed(seed)


def _fit(agent, buffer, n_iterations):
    """Run `n_iterations` updates of `agent` on minibatches sampled from `buffer`.

    The minibatch size is the notebook-level `batch_size` from the configuration
    cell (read at call time, so editing that cell takes effect without
    re-running this one).

    Returns:
        The same `agent`, after training.
    """
    for _ in range(n_iterations):
        agent.update(buffer.sample(batch_size))
    return agent


def train_cql(buffer, n_iterations, seed=42, alpha=1.0):
    """Train a Conservative Q-Learning agent on the offline dataset.

    Args:
        buffer: offline replay buffer that training minibatches are sampled from.
        n_iterations: number of gradient updates.
        seed: seeds NumPy and PyTorch before the network is created.
        alpha: regularization weight forwarded to `CQL` (the strength of the
            conservative penalty).

    Returns:
        The trained `CQL` agent.
    """
    _seed_everything(seed)
    agent = CQL(
        state_dim=_STATE_DIM,
        action_dim=_ACTION_DIM,
        hidden_dims=list(_HIDDEN_DIMS),  # fresh list per agent, as before
        learning_rate=_LEARNING_RATE,
        alpha=alpha,
    )
    return _fit(agent, buffer, n_iterations)


def train_dqn(buffer, n_iterations, seed=42):
    """Train a standard DQN agent offline (no conservative penalty).

    This uses the separate `DQN` class, not `CQL` with alpha=0; the two are only
    equivalent if the implementations otherwise match.

    Args:
        buffer: offline replay buffer that training minibatches are sampled from.
        n_iterations: number of gradient updates.
        seed: seeds NumPy and PyTorch before the network is created.

    Returns:
        The trained `DQN` agent.
    """
    _seed_everything(seed)
    agent = DQN(
        state_dim=_STATE_DIM,
        action_dim=_ACTION_DIM,
        hidden_dims=list(_HIDDEN_DIMS),
        learning_rate=_LEARNING_RATE,
    )
    return _fit(agent, buffer, n_iterations)


def train_bc(buffer, n_iterations, seed=42):
    """Train a Behavior Cloning agent (supervised imitation of dataset actions).

    Args:
        buffer: offline replay buffer that training minibatches are sampled from.
        n_iterations: number of gradient updates.
        seed: seeds NumPy and PyTorch before the network is created.

    Returns:
        The trained `BehaviorCloning` agent.
    """
    _seed_everything(seed)
    agent = BehaviorCloning(
        state_dim=_STATE_DIM,
        action_dim=_ACTION_DIM,
        hidden_dims=list(_HIDDEN_DIMS),
        learning_rate=_LEARNING_RATE,
    )
    return _fit(agent, buffer, n_iterations)

## 3. Train All Algorithms

In [None]:
# Display name -> training routine. Insertion order is the order in which the
# algorithms are trained and reported.
algorithms = {
    'CQL': train_cql,
    'DQN': train_dqn,
    'BC': train_bc,
}

# One list of per-seed evaluation dicts per algorithm. Re-running this cell
# rebuilds `results` from scratch, so the Random baseline in the next cell must
# be re-evaluated afterwards.
results = dict((algo, []) for algo in algorithms)

for name, fit_agent in algorithms.items():
    print(f"\nTraining {name}...")
    seed_runs = results[name]

    for seed in tqdm(seeds, desc=name):
        # Train a fresh agent with this seed, then evaluate it on the same seed.
        agent = fit_agent(buffer, n_iterations, seed=seed)
        eval_results = evaluate_policy(env=eval_env, policy=agent,
                                       n_episodes=n_eval_episodes, seed=seed)
        seed_runs.append(eval_results)
        print(f"  Seed {seed}: Survival={eval_results['survival_rate']:.1%}")

In [None]:
# Add random baseline: ignores the state and acts uniformly at random, as a
# "no learning" reference point.
print("\nEvaluating Random baseline...")

def random_policy(state, admissible_actions=None):
    """Return a uniformly random action, restricted to admissible actions if given.

    Args:
        state: ignored.
        admissible_actions: optional set of allowed actions, accepted either as
            action indices or as a boolean mask over the 25 actions.
            NOTE(review): the exact format `evaluate_policy` passes is not
            visible here — confirm. When None (or empty), all 25 actions are
            candidates.

    Returns:
        An integer action index.
    """
    if admissible_actions is None:
        return np.random.randint(0, 25)
    candidates = np.asarray(admissible_actions)
    if candidates.dtype == bool:
        # Boolean mask -> indices of the allowed actions.
        candidates = np.flatnonzero(candidates)
    if candidates.size == 0:
        return np.random.randint(0, 25)
    return int(np.random.choice(candidates))

results['Random'] = []
for seed in seeds:
    # random_policy draws from the global NumPy RNG; seed it so the baseline is
    # reproducible even when this cell is re-run on its own.
    np.random.seed(seed)
    eval_results = evaluate_policy(
        env=eval_env,
        policy=random_policy,
        n_episodes=n_eval_episodes,
        seed=seed,
    )
    results['Random'].append(eval_results)

print(f"  Random: {np.mean([r['survival_rate'] for r in results['Random']]):.1%}")

## 4. Statistical Analysis

In [None]:
# Aggregate per-seed evaluation results into one summary entry per algorithm.
summary = {}

for name, algo_results in results.items():
    survival_rates = [r['survival_rate'] for r in algo_results]
    mean_returns = [r['mean_return'] for r in algo_results]

    # Only the CI bounds are used; the mean is recomputed below with np.mean.
    ci_mean, ci_low, ci_high = compute_confidence_intervals(np.array(survival_rates))

    summary[name] = {
        'survival_rates': survival_rates,
        'mean_survival': np.mean(survival_rates),
        # Sample standard deviation across seeds (ddof=1). The seeds are a sample
        # of possible training runs, and this matches the variance estimate used
        # by scipy's t-tests. A single seed has no spread, so report 0.
        'std_survival': (np.std(survival_rates, ddof=1)
                         if len(survival_rates) > 1 else 0.0),
        'ci_low': ci_low,
        'ci_high': ci_high,
        'mean_return': np.mean(mean_returns),
    }

# Print summary table
print("\nResults Summary:")
print("=" * 65)
print(f"{'Algorithm':<12} {'Survival Rate':>15} {'95% CI':>20} {'Return':>12}")
print("-" * 65)

for name in ['CQL', 'DQN', 'BC', 'Random']:
    s = summary[name]
    ci_text = f"[{s['ci_low']:.1%}, {s['ci_high']:.1%}]"
    # Field widths mirror the header (12 / 15 / 20 / 12) so the columns line up.
    print(f"{name:<12} {s['mean_survival']:>15.1%} {ci_text:>20} {s['mean_return']:>12.3f}")

In [None]:
# Statistical significance: Welch's t-test of CQL's per-seed survival rates
# against each baseline. These are 3 comparisons with no multiple-comparison
# correction, and only len(seeds) samples per group — interpret p-values loosely.
print("\nStatistical Significance (Welch's t-tests vs CQL):")
print("=" * 50)

cql_survival = summary['CQL']['survival_rates']


def significance_marker(p_value):
    """Map a p-value to the conventional star marker ("n/a" if undefined)."""
    if np.isnan(p_value):
        # e.g. both groups have zero variance and identical means.
        return "n/a"
    if p_value < 0.001:
        return "***"
    if p_value < 0.01:
        return "**"
    if p_value < 0.05:
        return "*"
    return "ns"


for name in ['DQN', 'BC', 'Random']:
    other_survival = summary[name]['survival_rates']
    # equal_var=False (Welch): across-seed variance can differ greatly between a
    # learned policy and a baseline, which Student's t-test assumes away.
    t_stat, p_value = stats.ttest_ind(cql_survival, other_survival, equal_var=False)
    print(f"CQL vs {name:<8}: t={t_stat:>6.2f}, p={p_value:.4f} {significance_marker(p_value)}")

## 5. Visualizations

In [None]:
# Bar plot of mean survival rate (error bars = std across seeds) next to a box
# plot of the per-seed survival rates.
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Display order and colors; the radar chart in the next cell reuses both lists.
algo_names = ['CQL', 'DQN', 'BC', 'Random']
colors = ['steelblue', 'orange', 'green', 'gray']

# Survival rate
means = [summary[n]['mean_survival'] for n in algo_names]
stds = [summary[n]['std_survival'] for n in algo_names]

bars = axes[0].bar(algo_names, means, yerr=stds, capsize=5,
                   color=colors, edgecolor='black', alpha=0.8)
axes[0].axhline(y=0.8, color='red', linestyle='--', alpha=0.7, label='Target (80%)')
axes[0].set_ylabel('Survival Rate', fontsize=12)
axes[0].set_title('Algorithm Comparison: Survival Rate', fontsize=14)
axes[0].set_ylim([0, 1])
axes[0].legend()

# Add value labels just above each error bar.
for bar, mean, std in zip(bars, means, stds):
    axes[0].text(bar.get_x() + bar.get_width()/2, bar.get_height() + std + 0.02,
                 f'{mean:.1%}', ha='center', va='bottom', fontsize=11, fontweight='bold')

# Box plot of per-seed results. Tick labels are set with set_xticks instead of
# boxplot(labels=...): that keyword is deprecated since Matplotlib 3.9 (renamed
# tick_labels=), while set_xticks(ticks, labels) is available since 3.5.
survival_data = [summary[n]['survival_rates'] for n in algo_names]
bp = axes[1].boxplot(survival_data, patch_artist=True)
# Boxes are placed at x = 1..n by default.
axes[1].set_xticks(range(1, len(algo_names) + 1), algo_names)

for patch, color in zip(bp['boxes'], colors):
    patch.set_facecolor(color)
    patch.set_alpha(0.7)

axes[1].axhline(y=0.8, color='red', linestyle='--', alpha=0.7, label='Target (80%)')
axes[1].set_ylabel('Survival Rate', fontsize=12)
axes[1].set_title('Distribution of Results', fontsize=14)
axes[1].set_ylim([0, 1])
axes[1].legend()

plt.tight_layout()
plt.show()

In [None]:
# Radar chart comparing each algorithm on three metrics scaled to [0, 1].
fig, ax = plt.subplots(figsize=(8, 8), subplot_kw=dict(projection='polar'))

metrics = ['Survival', 'Return', 'Stability']
n_metrics = len(metrics)

# Normalize metrics to [0, 1]
normalized = {}
for name in algo_names:
    s = summary[name]
    raw_scores = [
        s['mean_survival'],          # already a rate in [0, 1]
        (s['mean_return'] + 1) / 2,  # assumes returns lie in [-1, 1] — TODO confirm against the env's reward
        1 - s['std_survival'],       # lower across-seed std = higher stability
    ]
    # Clip so a score outside the assumed range cannot be drawn beyond the
    # chart's [0, 1] radial axis.
    normalized[name] = [float(np.clip(v, 0.0, 1.0)) for v in raw_scores]

# One spoke per metric; repeat the first angle to close the polygon.
angles = np.linspace(0, 2*np.pi, n_metrics, endpoint=False).tolist()
angles += angles[:1]

for name, color in zip(algo_names, colors):
    values = normalized[name] + [normalized[name][0]]
    ax.plot(angles, values, 'o-', linewidth=2, label=name, color=color)
    ax.fill(angles, values, alpha=0.1, color=color)

ax.set_xticks(angles[:-1])
ax.set_xticklabels(metrics)
ax.set_ylim([0, 1])
ax.legend(loc='upper right', bbox_to_anchor=(1.3, 1))
ax.set_title('Multi-Metric Comparison', fontsize=14)

plt.tight_layout()
plt.show()

## 6. Key Findings

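### Summary (expected outcomes — confirm against the numbers and plots above):

1. **CQL** is expected to outperform the baselines thanks to its conservative Q-value estimates
2. **DQN (Offline)** is prone to distribution shift and to overestimating Q-values of actions the dataset does not cover
3. **BC** can at best match the behavior policy, so it is limited by that policy's quality
4. **Random** is an uninformed reference point (not a strict lower bound — a badly trained policy can do worse)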

### Why CQL Works:

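- Pushes down Q-values of out-of-distribution actions while pushing up Q-values of actions seen in the dataset
- Keeps value estimates conservative, avoiding harmful extrapolation to actions the data does not support
- Is designed for offline RL, where the agent cannot explore to correct its own estimation errors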