In [None]:
import sys
sys.path.insert(0, '..')

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

from src.environments.icu_sepsis_wrapper import create_sepsis_env, ICUSepsisWrapper

# Apply one shared look to every figure in this notebook: a whitegrid
# matplotlib style sheet plus seaborn's 'colorblind' colour palette.
# NOTE(review): the 'seaborn-v0_8-*' style names appear to be tied to newer
# matplotlib releases — confirm against the pinned matplotlib version.
plt.style.use('seaborn-v0_8-whitegrid')
sns.set_palette('colorblind')

## 1. Environment Overview

The ICU-Sepsis environment simulates sepsis treatment in an ICU setting:
- **States**: 716 discrete states representing patient condition
- **Actions**: 25 discrete actions (5 vasopressor levels × 5 IV fluid levels)
- **Rewards**: Sparse (+1 survival, -1 death, 0 otherwise)

In [None]:
# Instantiate the wrapped ICU-Sepsis environment with action masking switched on
env = create_sepsis_env(use_action_masking=True)

# Gather the headline properties once, then print them as an indented list
env_summary = {
    "Observation space": env.observation_space,
    "Action space": env.action_space,
    "Number of states": env.n_states,
    "Number of actions": env.n_actions,
}

print("Environment Details:")
for label, value in env_summary.items():
    print(f"  {label}: {value}")

## 2. Action Space Analysis

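Each action index encodes one combination of:
- Vasopressor dose level (0-4)
- IV fluid volume level (0-4)

The cell below decodes an index as `vasopressor = action // 5` and `iv_fluid = action % 5`.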

In [None]:
# Print a lookup table that splits each flat action index into its two levels
header_lines = [
    "Action Space (Vasopressor × IV Fluid):",
    "\nAction ID | Vasopressor | IV Fluid",
    "-" * 35,
]
print(*header_lines, sep="\n")

for action in range(25):
    # Quotient by 5 is the vasopressor level, remainder is the IV fluid level
    vaso, iv = divmod(action, 5)
    print(f"    {action:2d}    |      {vaso}      |    {iv}")

In [None]:
# Draw the 5x5 action grid as a heatmap coloured by flat action index
level_names = ['None', 'Low', 'Med', 'High', 'V.High']
level_ticks = range(5)

fig, ax = plt.subplots(figsize=(8, 6))

# Rows correspond to vasopressor level, columns to IV fluid level
action_grid = np.arange(5 * 5).reshape((5, 5))
im = ax.imshow(action_grid, cmap='Blues')

ax.set(
    xlabel='IV Fluid Level',
    ylabel='Vasopressor Level',
    title='Action Space (25 Discrete Actions)',
)
ax.set_xticks(level_ticks)
ax.set_yticks(level_ticks)
ax.set_xticklabels(level_names)
ax.set_yticklabels(level_names)

fig.colorbar(im, ax=ax, label='Action Index')
plt.show()

## 3. Running Episodes

Let's run some episodes with a random policy to understand the environment dynamics.

In [None]:
def run_episode(env, policy='random', seed=None):
    """Run a single episode and collect its trajectory.

    Args:
        env: Environment exposing the Gymnasium-style API used here:
            ``reset(seed=...) -> (state, info)``,
            ``step(action) -> (state, reward, terminated, truncated, info)``
            and an ``action_space`` providing ``sample()`` and ``seed()``.
        policy: Name of the action-selection rule. ``'no_treatment'`` always
            plays action 0 and ``'max_treatment'`` always plays action 24.
            ``'random'`` and any unrecognised name draw actions from
            ``env.action_space.sample()``.
        seed: Optional seed. When given, it is passed to ``env.reset`` and is
            also used to seed ``env.action_space`` so that sampled actions
            repeat for the same seed.

    Returns:
        tuple: ``(trajectory, total_reward)``. ``trajectory`` is a dict with
        keys ``'states'`` (initial state followed by one entry per step),
        ``'actions'`` and ``'rewards'`` (one entry per step);
        ``total_reward`` is the sum of the per-step rewards.
    """
    if seed is not None:
        state, _ = env.reset(seed=seed)
        # Resetting with a seed does not reseed the action space's own sampler;
        # without this the 'random' policy differs between identically seeded runs.
        env.action_space.seed(seed)
    else:
        state, _ = env.reset()

    trajectory = {
        'states': [state],
        'actions': [],
        'rewards': [],
    }

    # Policies that ignore the state and always play the same action index
    fixed_actions = {
        'no_treatment': 0,    # No vasopressor, no IV
        'max_treatment': 24,  # Max vasopressor, max IV
    }

    done = False
    total_reward = 0

    while not done:
        if policy in fixed_actions:
            action = fixed_actions[policy]
        else:
            # 'random' and unknown policy names both fall back to sampling.
            # NOTE(review): sample() is called without a mask — confirm how the
            # environment treats actions that are not admissible in a state.
            action = env.action_space.sample()

        next_state, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated

        trajectory['actions'].append(action)
        trajectory['rewards'].append(reward)
        trajectory['states'].append(next_state)

        total_reward += reward

    return trajectory, total_reward

In [None]:
# Roll out the random policy once per seed and record per-episode summaries
n_episodes = 100
returns, lengths = [], []
outcomes = []  # 1 for survival, 0 for death

for episode_seed in range(n_episodes):
    rollout, episode_return = run_episode(env, policy='random', seed=episode_seed)
    returns.append(episode_return)
    lengths.append(len(rollout['actions']))
    # A strictly positive return is counted as survival, anything else as death
    outcomes.append(int(episode_return > 0))

mean_return, std_return = np.mean(returns), np.std(returns)
mean_length, std_length = np.mean(lengths), np.std(lengths)

print(f"Random Policy Statistics ({n_episodes} episodes):")
print(f"  Survival rate: {np.mean(outcomes):.1%}")
print(f"  Mean return: {mean_return:.3f} ± {std_return:.3f}")
print(f"  Mean episode length: {mean_length:.1f} ± {std_length:.1f}")

In [None]:
# Three side-by-side panels: returns, episode lengths, and outcome split
fig, axes = plt.subplots(1, 3, figsize=(15, 4))
return_ax, length_ax, outcome_ax = axes

# Panel 1: histogram of episode returns
return_ax.hist(returns, bins=3, edgecolor='black', alpha=0.7)
return_ax.set(
    xlabel='Episode Return',
    ylabel='Count',
    title='Return Distribution (Random Policy)',
)

# Panel 2: histogram of episode lengths
length_ax.hist(lengths, bins=20, edgecolor='black', alpha=0.7, color='orange')
length_ax.set(
    xlabel='Episode Length',
    ylabel='Count',
    title='Episode Length Distribution',
)

# Panel 3: survived vs. died, with the survived wedge pulled out slightly
n_survived = sum(outcomes)
survival_counts = [n_survived, len(outcomes) - n_survived]
outcome_ax.pie(
    survival_counts,
    labels=['Survived', 'Died'],
    autopct='%1.1f%%',
    colors=['green', 'red'],
    explode=[0.05, 0],
)
outcome_ax.set_title('Patient Outcomes')

fig.tight_layout()
plt.show()

## 4. Comparing Different Policies

Let's compare different simple policies to understand the environment better.

In [None]:
def evaluate_policy(env, policy, n_episodes=100):
    """Summarise a policy's returns over ``n_episodes`` seeded episodes.

    Episode ``k`` is run with ``seed=k``. The result is a dict with keys
    ``'mean_return'``, ``'std_return'`` and ``'survival_rate'``, where the
    survival rate is the fraction of episodes with a strictly positive return.
    """
    episode_returns = [
        run_episode(env, policy=policy, seed=episode_seed)[1]
        for episode_seed in range(n_episodes)
    ]
    survived_flags = [episode_return > 0 for episode_return in episode_returns]

    return {
        'mean_return': np.mean(episode_returns),
        'std_return': np.std(episode_returns),
        'survival_rate': np.mean(survived_flags),
    }

# Evaluate each baseline policy and report its headline numbers as we go
policies = ['random', 'no_treatment', 'max_treatment']
results = {}

for policy in policies:
    policy_stats = evaluate_policy(env, policy, n_episodes=100)
    results[policy] = policy_stats
    print(f"{policy}: Survival={policy_stats['survival_rate']:.1%}, "
          f"Return={policy_stats['mean_return']:.3f}")

In [None]:
# Bar chart of survival rate per policy, annotated with the percentage
policy_names = list(results)
survival_rates = [results[name]['survival_rate'] for name in policy_names]

fig, ax = plt.subplots(figsize=(8, 5))
bars = ax.bar(
    policy_names,
    survival_rates,
    color=['blue', 'red', 'green'],
    alpha=0.7,
    edgecolor='black',
)

ax.set_xlabel('Policy', fontsize=12)
ax.set_ylabel('Survival Rate', fontsize=12)
ax.set_title('Policy Comparison: Survival Rates', fontsize=14)
ax.set_ylim(0, 1)

# Place each label centred just above the top of its bar
for bar, rate in zip(bars, survival_rates):
    label_x = bar.get_x() + bar.get_width()/2
    label_y = bar.get_height() + 0.02
    ax.text(label_x, label_y, f'{rate:.1%}',
            ha='center', va='bottom', fontsize=11)

fig.tight_layout()
plt.show()

## 5. Key Takeaways

From this exploration, we learned:

1. **Environment Structure**: The ICU-Sepsis environment has 716 discrete states and 25 actions
2. **Sparse Rewards**: Only terminal rewards (+1 survival, -1 death)
3. **Baseline Performance**: Random policy achieves ~50-60% survival (varies by dataset); check this against the survival rate printed in Section 3 when you run the notebook
4. **Challenge**: Need to learn effective treatment policies from offline data

Next steps:
- Collect offline data with behavior policies
- Train CQL agents on the offline data
- Compare against baselines