In [None]:
import sys
sys.path.insert(0, '..')

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path

from src.environments.icu_sepsis_wrapper import create_sepsis_env
from src.data.replay_buffer import OfflineReplayBuffer
from src.data.data_collection import DataCollector, MixedBehaviorPolicy, RandomBehaviorPolicy

# Global plot styling applied to every figure in this notebook.
plt.style.use('seaborn-v0_8-whitegrid')
sns.set_palette('colorblind')

# Seed NumPy's global RNG so that stochastic steps are repeatable under
# Restart Kernel -> Run All.
# NOTE(review): this only makes data collection reproducible if the
# environment and behavior policies draw from np.random's global state —
# confirm, or pass SEED to them explicitly.
SEED = 42
np.random.seed(SEED)

## 1. Behavior Policies

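In offline RL, the agent learns from a fixed dataset that was collected beforehand by a *behavior policy*. This notebook defines two behavior policies, a uniform-random policy and a mixed (expert-like, ε-greedy) policy, and uses the mixed one to collect the dataset.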

In [None]:
# Build the sepsis environment with action masking switched on.
env = create_sepsis_env(use_action_masking=True)

# Size of the discrete action set handed to both behavior policies.
# NOTE(review): presumably matches the environment's action space — confirm.
N_ACTIONS = 25

# Two candidate behavior policies. Only `mixed_policy` is used for data
# collection in the cells visible here; `random_policy` is defined for
# comparison.
random_policy = RandomBehaviorPolicy(n_actions=N_ACTIONS)
mixed_policy = MixedBehaviorPolicy(n_actions=N_ACTIONS, epsilon=0.3)

print("Behavior Policies:")
print("  1. Random: Uniform random action selection")
# The epsilon glyph was mis-encoded ("Îµ") in the original message.
print("  2. Mixed: Expert-like with ε-greedy exploration")

## 2. Collect Offline Dataset

In [None]:
# Wire the mixed behavior policy and the environment into a collector,
# then run it to produce the offline buffer.
collector_config = dict(
    env=env,
    behavior_policy=mixed_policy,
    buffer_size=100_000,
)
collector = DataCollector(**collector_config)

print("Collecting offline dataset...")
buffer = collector.collect(verbose=True, n_episodes=1000)

# len(buffer) is reported as the number of stored transitions.
print(f"\nDataset size: {len(buffer)} transitions")

## 3. Dataset Statistics

In [None]:
# Ask the buffer for its summary statistics and print them as one report.
stats = buffer.compute_statistics()

report_lines = [
    "Dataset Statistics:",
    f"  Total transitions: {stats['n_transitions']}",
    f"  Number of episodes: {stats['n_episodes']}",
    f"  Mean episode length: {stats['mean_episode_length']:.1f}",
    f"  Mean return: {stats['mean_return']:.3f}",
    f"  Survival rate: {stats['survival_rate']:.1%}",
]
# One line per entry, same output as printing each line separately.
print(*report_lines, sep="\n")

In [None]:
# Slice the first len(buffer) rows out of each storage array.
# NOTE(review): assumes only the first len(buffer) rows hold valid data —
# confirm against OfflineReplayBuffer.
n_valid = len(buffer)
states = buffer.states[:n_valid]
actions = buffer.actions[:n_valid]
rewards = buffer.rewards[:n_valid]

# 2x2 figure; unpack the panels in row-major order for readable names.
fig, axes = plt.subplots(2, 2, figsize=(12, 10))
ax_states, ax_actions, ax_rewards, ax_grid = axes.ravel()
bar_style = dict(edgecolor='black', alpha=0.7)

# Top-left: histogram over every entry of the flattened state array.
# NOTE(review): the 'State Index' label assumes states are stored as discrete
# indices rather than feature vectors — confirm.
ax_states.hist(states.flatten(), bins=50, **bar_style)
ax_states.set(xlabel='State Index', ylabel='Count', title='State Distribution')

# Top-right: how often each of the 25 action indices occurs.
action_counts = np.bincount(actions.flatten().astype(int), minlength=25)
ax_actions.bar(range(25), action_counts, **bar_style)
ax_actions.set(xlabel='Action Index', ylabel='Count', title='Action Distribution')

# Bottom-left: rewards are shifted by +1 so that -1/0/+1 land in bins 0/1/2.
reward_counts = np.bincount((rewards.flatten() + 1).astype(int), minlength=3)
reward_labels = ['-1 (Death)', '0 (Ongoing)', '+1 (Survival)']
ax_rewards.bar(reward_labels, reward_counts,
               color=['red', 'gray', 'green'], **bar_style)
ax_rewards.set(xlabel='Reward', ylabel='Count', title='Reward Distribution')

# Bottom-right: the same action counts laid out as a 5x5 grid
# (row = action // 5, column = action % 5).
# NOTE(review): the axis labels assume action = 5 * vasopressor + iv_fluid —
# confirm against the environment's action encoding.
action_grid = action_counts.reshape(5, 5)
level_ticks = ['0', '1', '2', '3', '4']
sns.heatmap(action_grid, annot=True, fmt='d', cmap='Blues', ax=ax_grid,
            xticklabels=level_ticks, yticklabels=level_ticks)
ax_grid.set(xlabel='IV Fluid Level', ylabel='Vasopressor Level',
            title='Action Frequency Grid')

plt.tight_layout()
plt.show()

## 4. Save Dataset

In [None]:
# Make sure the target directory exists, then write the buffer to disk.
output_dir = Path('../data/offline_datasets')
output_dir.mkdir(parents=True, exist_ok=True)

# Build the destination path once and reuse it for saving and reporting.
dataset_path = output_dir / 'mixed_policy_1k.pkl'
buffer.save(dataset_path)
print(f"Dataset saved to: {dataset_path}")

## 5. Load and Verify Dataset

In [None]:
# Reload the dataset that was just written and check that it round-trips.
# NOTE(review): the .pkl extension suggests pickle — only load files you
# produced yourself, since unpickling untrusted data can execute code.
loaded_buffer = OfflineReplayBuffer.load(output_dir / 'mixed_policy_1k.pkl')

print(f"Loaded dataset size: {len(loaded_buffer)}")

# Actual verification: the reloaded buffer must hold as many transitions as
# the in-memory buffer it was saved from.
assert len(loaded_buffer) == len(buffer), (
    f"Round-trip size mismatch: saved {len(buffer)} transitions, "
    f"loaded {len(loaded_buffer)}"
)

# Draw one mini-batch and report the shape of each field.
batch = loaded_buffer.sample(batch_size=32)
print("\nSample batch shapes:")
batch_fields = [
    ("States", "states"),
    ("Actions", "actions"),
    ("Rewards", "rewards"),
    ("Next states", "next_states"),
    ("Dones", "dones"),
]
for field_label, field_key in batch_fields:
    print(f"  {field_label}: {batch[field_key].shape}")

## Key Takeaways

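1. **Behavior Policy Quality**: The quality of the data produced by the behavior policy directly affects CQL performance.
2. **Dataset Size**: More data generally helps, but CQL can still work with limited data.
3. **Action Coverage**: Offline RL needs diverse actions in the dataset; actions that never appear in the data cannot be evaluated reliably.
4. **Survival Rate**: The behavior policy's survival rate sets the baseline that a learned policy should improve on.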