In [None]:
import sys
sys.path.insert(0, '..')

import numpy as np
import torch
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
from tqdm.notebook import tqdm

from src.algorithms.cql import CQL, QNetwork
from src.data.replay_buffer import OfflineReplayBuffer
from src.environments.icu_sepsis_wrapper import create_sepsis_env
from src.utils.evaluation import evaluate_policy

# Plot styling shared by every figure in this notebook.
plt.style.use('seaborn-v0_8-whitegrid')
sns.set_palette('colorblind')

# Seed the global NumPy and PyTorch RNGs before any network is created so that
# Restart Kernel -> Run All gives repeatable weight initialisation.
# NOTE(review): whether the replay buffer and the evaluation environment draw
# from these global RNGs is not visible from this notebook -- confirm, and
# seed them explicitly if they keep their own generators.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

## 1. Load Offline Dataset

In [None]:
# Load dataset
dataset_path = Path('../data/offline_datasets/behavior_policy.pkl')

# Fail fast: every later cell needs `buffer`. Merely printing a hint here would
# let execution continue and surface as an unrelated NameError in the training
# loop, so stop with an actionable error instead.
if not dataset_path.exists():
    raise FileNotFoundError(
        f"Dataset not found at {dataset_path}. "
        "Please run 02_data_collection.ipynb first. "
        "Or run: python scripts/02_collect_offline_data.py"
    )

# NOTE(review): the .pkl extension suggests pickle-based deserialisation, which
# can execute arbitrary code -- only load dataset files from a trusted source.
buffer = OfflineReplayBuffer.load(dataset_path)
print(f"Loaded dataset with {len(buffer)} transitions")

## 2. CQL Algorithm Overview

CQL adds a conservative regularization term to standard Q-learning:

$$\mathcal{L}_{CQL} = \alpha \left( \log \sum_a \exp(Q(s,a)) - Q(s, a_{data}) \right) + \mathcal{L}_{TD}$$

Where:
- $\alpha$ controls the conservatism level
- First term penalizes high Q-values for all actions
- Second term rewards high Q-values for dataset actions

In [None]:
# Hyperparameters for the CQL agent, kept in one mapping so they can be both
# passed to the constructor and echoed back to the reader below.
cql_config = dict(
    state_dim=1,  # Discrete state index
    action_dim=25,
    hidden_dims=[256, 256],
    learning_rate=3e-4,
    gamma=0.99,
    tau=0.005,
    alpha=1.0,  # CQL conservatism coefficient
    target_update_frequency=1,
)

# Build the agent directly from the mapping above.
cql_agent = CQL(**cql_config)

# Report the settings that matter most for interpreting the results.
print("CQL Agent Created")
print(f"  Alpha (conservatism): {cql_config['alpha']}")
print(f"  Network architecture: {cql_config['hidden_dims']}")

## 3. Training Loop

In [None]:
# Knobs for the training run; all of them are read by the training-loop cell.
n_iterations = 10000  # Increase for better results
batch_size = 256
eval_frequency = 1000  # evaluate once every this many gradient updates
n_eval_episodes = 50   # episodes rolled out per evaluation

# Environment used only for periodic policy evaluation.
eval_env = create_sepsis_env(use_action_masking=True)

# One empty list per tracked metric: the three loss entries grow on every
# update, the two eval_* entries grow once per evaluation.
history = {
    metric_name: []
    for metric_name in (
        'q_loss',
        'cql_loss',
        'td_loss',
        'eval_survival_rate',
        'eval_return',
    )
}

In [None]:
# Run the offline training loop: one gradient update per iteration, with a
# policy evaluation every `eval_frequency` iterations.
print(f"Training CQL for {n_iterations} iterations...")
print(f"Evaluating every {eval_frequency} iterations\n")

# Loss components reported by the agent that are recorded on every update.
loss_keys = ('q_loss', 'cql_loss', 'td_loss')

for iteration in tqdm(range(n_iterations)):
    # Draw a minibatch from the offline buffer and take one CQL step.
    batch = buffer.sample(batch_size)
    metrics = cql_agent.update(batch)

    # Append each loss component to its running history.
    for key in loss_keys:
        history[key].append(metrics[key])

    # Skip the (comparatively expensive) evaluation except on scheduled steps.
    completed = iteration + 1
    if completed % eval_frequency != 0:
        continue

    eval_results = evaluate_policy(
        env=eval_env,
        policy=cql_agent,
        n_episodes=n_eval_episodes,
    )
    survival_rate = eval_results['survival_rate']
    mean_return = eval_results['mean_return']

    history['eval_survival_rate'].append(survival_rate)
    history['eval_return'].append(mean_return)

    # Average the total Q-loss over the last 100 updates for a stabler readout.
    recent_q_loss = np.mean(history['q_loss'][-100:])

    print(f"\nIteration {completed}:")
    print(f"  Survival Rate: {survival_rate:.1%}")
    print(f"  Mean Return: {mean_return:.3f}")
    print(f"  Q-Loss: {recent_q_loss:.4f}")

## 4. Visualise Training

In [None]:
# Four-panel summary of training: total loss, loss components, and the two
# periodic evaluation metrics.
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# Q-Loss
q_loss_history = history['q_loss']
axes[0, 0].plot(q_loss_history, alpha=0.3)
# Smoothed
# Clamp the window to the amount of data available: with mode='valid' a window
# longer than the series yields an output whose length no longer matches the
# x-range below, and an empty series cannot be convolved at all.
window = min(100, len(q_loss_history))
if window > 0:
    smoothed = np.convolve(q_loss_history, np.ones(window) / window, mode='valid')
    axes[0, 0].plot(range(window - 1, len(q_loss_history)), smoothed, linewidth=2)
axes[0, 0].set_xlabel('Iteration')
axes[0, 0].set_ylabel('Q-Loss')
axes[0, 0].set_title('Total Q-Loss')

# CQL vs TD Loss
axes[0, 1].plot(history['cql_loss'], label='CQL Loss', alpha=0.5)
axes[0, 1].plot(history['td_loss'], label='TD Loss', alpha=0.5)
axes[0, 1].set_xlabel('Iteration')
axes[0, 1].set_ylabel('Loss')
axes[0, 1].set_title('CQL vs TD Loss Components')
axes[0, 1].legend()

# Survival Rate
# Derive the x-axis from the evaluations actually recorded rather than from
# n_iterations, so an interrupted or repeated training run still plots instead
# of raising a shape-mismatch error. For a single complete run this equals
# np.arange(eval_frequency, n_iterations + 1, eval_frequency).
n_evals = len(history['eval_survival_rate'])
eval_iters = eval_frequency * np.arange(1, n_evals + 1)
axes[1, 0].plot(eval_iters, history['eval_survival_rate'], 'o-', linewidth=2, markersize=8)
axes[1, 0].axhline(y=0.8, color='g', linestyle='--', alpha=0.7, label='Target (80%)')
axes[1, 0].set_xlabel('Iteration')
axes[1, 0].set_ylabel('Survival Rate')
axes[1, 0].set_title('Evaluation Survival Rate')
axes[1, 0].set_ylim([0, 1])
axes[1, 0].legend()

# Mean Return
# Both eval_* lists are appended together in the training loop, so they share
# the same x-axis.
axes[1, 1].plot(eval_iters, history['eval_return'], 'o-', linewidth=2, markersize=8, color='orange')
axes[1, 1].set_xlabel('Iteration')
axes[1, 1].set_ylabel('Mean Return')
axes[1, 1].set_title('Evaluation Mean Return')

plt.tight_layout()
plt.show()

## 5. Analyse Learned Q-Values

In [None]:
# Probe the learned Q-function on a random draw of dataset states.
n_probe_states = 1000
sample_batch = buffer.sample(n_probe_states)

# Move the sampled states onto the agent's device as a float tensor.
states = torch.FloatTensor(sample_batch['states'])
states = states.to(cql_agent.device)

# Inference only: no gradients are needed for the forward pass.
with torch.no_grad():
    q_tensor = cql_agent.q_network(states)

# Bring the result back to host memory as a NumPy array for analysis/plotting.
q_values = q_tensor.cpu().numpy()

# Summary statistics taken over every entry of the Q-value array.
print(f"Q-value statistics:")
print(f"  Mean: {np.mean(q_values):.4f}")
print(f"  Std: {np.std(q_values):.4f}")
print(f"  Min: {np.min(q_values):.4f}")
print(f"  Max: {np.max(q_values):.4f}")

In [None]:
# Q-value distribution
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# Overall distribution
axes[0].hist(q_values.flatten(), bins=50, edgecolor='black', alpha=0.7)
axes[0].set_xlabel('Q-Value')
axes[0].set_ylabel('Count')
axes[0].set_title('Q-Value Distribution')
# Reference line at zero to make the sign of the learned values easy to read.
axes[0].axvline(x=0, color='red', linestyle='--', alpha=0.7)

# Mean Q per action
# Averaging over axis 0 leaves one value per column of q_values.
# NOTE(review): assumes q_values is laid out as (n_states, n_actions) -- confirm
# against the Q-network's output.
mean_q_per_action = np.mean(q_values, axis=0)
# Take the number of bars from the data instead of hard-coding 25, so this cell
# keeps working if the configured action_dim changes.
n_actions = len(mean_q_per_action)
axes[1].bar(range(n_actions), mean_q_per_action, edgecolor='black', alpha=0.7)
axes[1].set_xlabel('Action Index')
axes[1].set_ylabel('Mean Q-Value')
axes[1].set_title('Mean Q-Value per Action')

plt.tight_layout()
plt.show()

## 6. Save Model

In [None]:
# Persist the trained agent under the notebook's results directory.
checkpoint_dir = Path('../results/cql_notebook')
checkpoint_dir.mkdir(parents=True, exist_ok=True)  # safe to re-run

# Build the target path once and reuse it for saving and reporting.
checkpoint_path = checkpoint_dir / 'cql_model.pt'
cql_agent.save(checkpoint_path)
print(f"Model saved to: {checkpoint_path}")

## Key Takeaways

1. **CQL Regularization**: The conservative penalty prevents overestimation
2. **Alpha Parameter**: Higher $\alpha$ = more conservative (closer to behavior policy)
3. **Training Stability**: CQL is generally stable but may need tuning
4. **Q-Value Analysis**: Conservative Q-values indicate successful regularization