In [1]:
# Cell 1: Install & Imports
# Install any missing libraries if needed (run in terminal: pip install stable-baselines3 xgboost pandas numpy gymnasium)
import pandas as pd
import numpy as np
import json
import zipfile
import os
from stable_baselines3 import DQN
from sklearn.multioutput import MultiOutputRegressor
import xgboost as xgb
import gymnasium as gym
from gymnasium import spaces

print("Libraries imported successfully.")

Libraries imported successfully.


In [12]:
# Cell 2: Load Trained Artifacts (DQN Agent + World Model)
import joblib

model_path = 'dqn_mimic_patient_model_v1.zip'
world_model_path = 'world_model.pkl'

# Load the trained DQN agent
try:
    dqn_model = DQN.load(model_path)
    print("Trained DQN agent loaded successfully from zip.")
except FileNotFoundError:
    print(f"Error: File '{model_path}' not found. Ensure the zip file exists in the current directory.")
    raise
except Exception as e:
    print(f"Error loading model: {e}")
    raise

# Load the world model separately
try:
    world_model = joblib.load(world_model_path)
    print("World model loaded successfully from pickle.")
except FileNotFoundError:
    print(f"Error: File '{world_model_path}' not found. Ensure the pickle exists in the current directory.")
    world_model = None
except Exception as e:
    print(f"Error loading world model: {e}")
    world_model = None

Trained DQN agent loaded successfully from zip.
World model loaded successfully from pickle.


In [13]:
# Cell 3: Load Synthetic Patient JSON (Synthea)
# Load the FHIR Bundle JSON
json_file_path = 'output/fhir/Bethel526_Gerlach374_49578d6f-b690-6615-598e-3ee3719d1c69.json'
with open(json_file_path, 'r') as f:
    synthea_data = json.load(f)

print("Synthetic patient JSON loaded.")
print(f"Bundle type: {synthea_data.get('type')}")
print(f"Number of entries: {len(synthea_data.get('entry', []))}")

Synthetic patient JSON loaded.
Bundle type: transaction
Number of entries: 943


In [15]:
# Cell 4: Inspect Synthea Structure & Quick Validation
print("Inspecting FHIR Bundle structure...")
for i, entry in enumerate(synthea_data['entry'][:5]):  # Show first 5 entries
    resource = entry.get('resource', {})
    resource_type = resource.get('resourceType')
    print(f"Entry {i}: Resource Type = {resource_type}")
    if resource_type == 'Observation':
        print(f"  - Code: {resource.get('code', {}).get('coding', [{}])[0].get('code')}")
        print(f"  - Value: {resource.get('valueQuantity', {}).get('value')}")
    elif resource_type == 'Patient':
        name_obj = resource.get('name', [{}])[0]
        given_names = " ".join(name_obj.get('given', []))
        family_name = name_obj.get('family', '')
        print(f"  - Name: {given_names} {family_name}")

# Validation: Check for key resource types
resource_types = [entry['resource'].get('resourceType') for entry in synthea_data['entry']]
print(f"\nUnique resource types: {set(resource_types)}")
if 'Observation' not in resource_types:
    print("Warning: No Observations found in JSON. Model features may be missing.")

Inspecting FHIR Bundle structure...
Entry 0: Resource Type = Patient
  - Name: Bethel526 Gerlach374
Entry 1: Resource Type = Encounter
Entry 2: Resource Type = MedicationRequest
Entry 3: Resource Type = Claim
Entry 4: Resource Type = ExplanationOfBenefit

Unique resource types: {'Encounter', 'Immunization', 'Provenance', 'Condition', 'MedicationAdministration', 'SupplyDelivery', 'MedicationRequest', 'Device', 'CarePlan', 'Observation', 'DocumentReference', 'Procedure', 'ExplanationOfBenefit', 'ImagingStudy', 'CareTeam', 'DiagnosticReport', 'Patient', 'AllergyIntolerance', 'Claim', 'Medication'}


In [16]:
# Cell 5: Map Synthea Fields → Model Features (States & Actions)
# Define mappings from FHIR codes to model features
# FHIR uses LOINC codes for observations. Map common ones to your model's state columns.
# Note: Synthea may use specific codes; adjust based on your JSON. Actions are not in JSON (they're outputs), so skip.
fhir_to_state_map = {
    '8867-4': 'heart_rate',      # LOINC for Heart rate
    '8480-6': 'systolic_bp',     # Systolic blood pressure
    '8462-4': 'diastolic_bp',    # Diastolic blood pressure
    '8310-5': 'temperature',     # Body temperature
    '5902-2': 'lactate'          # Lactate (may not be present in Synthea)
}

# Extract observations into a dict
patient_features = {}
for entry in synthea_data['entry']:
    resource = entry['resource']
    if resource.get('resourceType') == 'Observation':
        code = resource.get('code', {}).get('coding', [{}])[0].get('code')
        value = resource.get('valueQuantity', {}).get('value')
        if code in fhir_to_state_map and value is not None:
            patient_features[fhir_to_state_map[code]] = float(value)

# Check extracted features
state_cols = ['heart_rate', 'systolic_bp', 'diastolic_bp', 'temperature', 'lactate']
print("Extracted patient features:")
for col in state_cols:
    if col in patient_features:
        print(f"  {col}: {patient_features[col]}")
    else:
        print(f"  {col}: Missing (will impute to 0)")

# Impute missing features with 0 (as in training)
for col in state_cols:
    if col not in patient_features:
        patient_features[col] = 0.0

print("Mapping complete. Patient features ready.")

Extracted patient features:
  heart_rate: 85.0
  systolic_bp: Missing (will impute to 0)
  diastolic_bp: Missing (will impute to 0)
  temperature: 37.841
  lactate: Missing (will impute to 0)
Mapping complete. Patient features ready.


In [17]:
# Cell 6: Preprocess / Build Time-Bucketed Features or Single Initial State
# For testing, we use a single initial state (not time-bucketed, as the model simulates sequentially).
# Convert to numpy array for the env.
initial_state = np.array([patient_features[col] for col in state_cols], dtype=np.float32)
print(f"Initial state vector: {initial_state}")

# Optional: If you want to simulate multiple time steps, you could create a sequence, but for now, start with one.
# Note: The model expects states in the same units as training (e.g., HR in bpm, BP in mmHg, Temp in C, Lactate in mmol/L).
# If Synthea uses different units, convert here (e.g., if Temp is in F, convert to C).

Initial state vector: [85.     0.     0.    37.841  0.   ]


In [18]:
# Cell 7: Impute Missing Features / Scaling / Unit Checks
# Already handled imputation in Cell 5. Add unit checks/scaling if needed.
# Example: Ensure Temp is in Celsius (Synthea often uses F; convert if necessary).
if 'temperature' in patient_features and patient_features['temperature'] > 50:  # Likely F
    patient_features['temperature'] = (patient_features['temperature'] - 32) * 5/9
    initial_state[3] = patient_features['temperature']  # Update array
    print("Temperature converted from F to C.")

# Scaling: If your training data was scaled, apply the same scaler here (e.g., StandardScaler).
# For simplicity, assume no scaling in training; if yes, load and apply.
print("Unit checks and imputation complete.")

Unit checks and imputation complete.


In [19]:
# Cell 8: Create Initial State Matrix for the Simulator
# The env expects a DataFrame-like initial_states for reset. Create a dummy DF with this state.
initial_states_df = pd.DataFrame([initial_state], columns=state_cols)
print(f"Initial states DataFrame:\n{initial_states_df}")

# Recreate the env with this initial state (modify PatientSimulatorEnv to accept custom initial).
# For testing, we'll override the reset to use our state.
class TestPatientSimulatorEnv(gym.Env):
    def __init__(self, world_model, custom_initial_state, state_cols, action_cols):
        super().__init__()
        self.world_model = world_model
        self.custom_initial_state = custom_initial_state
        self.state_cols = state_cols
        self.action_cols = action_cols
        self.action_space = spaces.Discrete(25)
        self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(len(state_cols),), dtype=np.float32)
        self.current_state = None
        self.episode_length = 0
    
    def reset(self, seed=None, options=None):
        self.current_state = self.custom_initial_state.copy()
        self.episode_length = 0
        return self.current_state, {}
    
    def step(self, action):
        # Same as original, but use custom initial
        action_vector = np.zeros(len(self.action_cols))  # Placeholder; map action to dosages if needed
        model_input = np.concatenate([self.current_state, action_vector]).reshape(1, -1)
        predicted_next_state = self.world_model.predict(model_input)[0]
        previous_lactate = self.current_state[self.state_cols.index('lactate')]
        self.current_state = predicted_next_state
        self.episode_length += 1
        new_lactate = self.current_state[self.state_cols.index('lactate')]
        reward = (previous_lactate - new_lactate) * 10
        terminated = new_lactate > 4.0 or (new_lactate < 1.0 and self.episode_length > 5) or self.episode_length >= 50
        if terminated and new_lactate > 4.0:
            reward -= 100
        elif terminated and new_lactate < 1.0:
            reward += 50
        return self.current_state, reward, terminated, False, {}

test_env = TestPatientSimulatorEnv(world_model, initial_state, state_cols, ['norepinephrine', 'fluid_bolus'])
print("Test environment created with synthetic patient initial state.")

Initial states DataFrame:
   heart_rate  systolic_bp  diastolic_bp  temperature  lactate
0        85.0          0.0           0.0       37.841      0.0
Test environment created with synthetic patient initial state.


In [20]:
# Cell 9: Sanity-Check World Model Predictions on Sample Inputs
# Test the world model with the initial state + sample actions.
sample_action = np.array([0.0, 0.0])  # No meds
sample_input = np.concatenate([initial_state, sample_action]).reshape(1, -1)
predicted_next = world_model.predict(sample_input)[0]
print(f"Sample prediction (no action): Current state {initial_state} -> Next state {predicted_next}")

# Check shapes and ranges
print(f"Input shape: {sample_input.shape}, Output shape: {predicted_next.shape}")
print(f"Ranges: States {np.min(initial_state)}-{np.max(initial_state)}, Predicted {np.min(predicted_next)}-{np.max(predicted_next)}")

Sample prediction (no action): Current state [85.     0.     0.    37.841  0.   ] -> Next state [8.1925209e+01 1.4667047e+02 8.5988884e+01 1.0246951e+02 2.9216770e-02]
Input shape: (1, 7), Output shape: (5,)
Ranges: States 0.0-85.0, Predicted 0.029216770082712173-146.67047119140625


In [21]:
# Cell 10: Run the DQN Agent in PatientSimulatorEnv with Synthetic Patient
# Test the DQN agent in the test env.
episodes = 5  # Test with a few episodes
trajectories = []

for ep in range(episodes):
    obs, info = test_env.reset()
    done = False
    ep_trajectory = {'states': [obs], 'actions': [], 'rewards': []}
    while not done:
        action, _ = dqn_model.predict(obs, deterministic=True)
        obs, reward, done, truncated, info = test_env.step(action)
        ep_trajectory['states'].append(obs)
        ep_trajectory['actions'].append(action)
        ep_trajectory['rewards'].append(reward)
    trajectories.append(ep_trajectory)
    print(f"Episode {ep+1}: Length {len(ep_trajectory['states'])-1}, Total Reward {sum(ep_trajectory['rewards'])}")

print("Testing complete.")

Episode 1: Length 50, Total Reward -10.591890335083008
Episode 2: Length 50, Total Reward -10.591890335083008
Episode 3: Length 50, Total Reward -10.591890335083008
Episode 4: Length 50, Total Reward -10.591890335083008
Episode 5: Length 50, Total Reward -10.591890335083008
Testing complete.


In [23]:
# Cell 11: Record Trajectories, Metrics and Save Outputs
# Convert NumPy arrays to lists for JSON serialization
def convert_trajectory(traj):
    return {
        'states': [state.tolist() if hasattr(state, 'tolist') else state for state in traj['states']],
        'actions': [int(a) for a in traj['actions']],
        'rewards': [float(r) for r in traj['rewards']]
    }

trajectories_serializable = [convert_trajectory(traj) for traj in trajectories]

# Save trajectories to a JSON file.
with open('test_trajectories.json', 'w') as f:
    json.dump(trajectories_serializable, f, indent=4)

# Compute metrics (e.g., average reward, lactate changes).
avg_reward = np.mean([sum(traj['rewards']) for traj in trajectories])
lactate_changes = [traj['states'][-1][4] - traj['states'][0][4] for traj in trajectories]  # Lactate index 4
print(f"Average reward: {avg_reward}")
print(f"Average lactate change: {np.mean(lactate_changes)}")

print("Outputs saved to test_trajectories.json.")

Average reward: -10.591890335083008
Average lactate change: 1.0591890811920166
Outputs saved to test_trajectories.json.


In [24]:
# Cell 12: Debugging / Validation Utilities (Column Checks, Shapes, Ranges)
# Utility functions for debugging.
def check_shapes():
    print(f"Initial state shape: {initial_state.shape}")
    print(f"World model input shape: {sample_input.shape}")
    print(f"Predicted output shape: {predicted_next.shape}")

def check_ranges(df=None):
    if df is not None:
        print(f"DataFrame ranges:\n{df.describe()}")
    print(f"State ranges: {np.min(initial_state)} - {np.max(initial_state)}")

def validate_fhir_mapping():
    print("Mapped features:", patient_features)
    missing = [col for col in state_cols if patient_features.get(col, 0) == 0]
    if missing:
        print(f"Warning: Missing/imputed features: {missing}")

check_shapes()
check_ranges()
validate_fhir_mapping()
print("Debugging complete.")

Initial state shape: (5,)
World model input shape: (1, 7)
Predicted output shape: (5,)
State ranges: 0.0 - 85.0
Mapped features: {'heart_rate': 85.0, 'temperature': 37.841, 'systolic_bp': 0.0, 'diastolic_bp': 0.0, 'lactate': 0.0}
Debugging complete.
