In [1]:
import pandas as pd
import numpy as np
import xgboost as xgb
from functools import reduce
from sklearn.model_selection import train_test_split
from sklearn.multioutput import MultiOutputRegressor
from sklearn.metrics import mean_squared_error

# --- 1. Define Constants ---
# File paths based on your directory structure
MIMIC_ICU_DIR = 'data/mimic-iv-clinical-database-demo-2.2/icu/'
MIMIC_HOSP_DIR = 'data/mimic-iv-clinical-database-demo-2.2/hosp/'

# Define the time window for bucketing data (in hours)
TIME_BUCKET_SIZE = 4

# --- Define itemids for the features you want ---
# States (Vitals & Labs)
ID_HEART_RATE = 220045
ID_SYSTOLIC_BP = 220179
ID_DIASTOLIC_BP = 220180
ID_TEMPERATURE = 223761
ID_LACTATE = 50813 # from d_labitems

# Actions (Medications)
ID_NOREPINEPHRINE = 221906 # Vasopressor
ID_FLUID_BOLUS = 225158     # IV Fluid

print("Cell 1: Setup and constants defined successfully.")

Cell 1: Setup and constants defined successfully.


In [2]:
# --- 2. Load and Merge Data Upfront ---
print("Loading and filtering raw data...")

# ICU Stays (the source of truth for stay_id and intime)
df_icu_stays = pd.read_csv(f'{MIMIC_ICU_DIR}icustays.csv.gz', usecols=['stay_id', 'subject_id', 'intime'])
df_icu_stays['intime'] = pd.to_datetime(df_icu_stays['intime'])

# Vitals from chartevents
df_chartevents = pd.read_csv(f'{MIMIC_ICU_DIR}chartevents.csv.gz', usecols=['stay_id', 'charttime', 'itemid', 'valuenum'])
state_vitals_ids = [ID_HEART_RATE, ID_SYSTOLIC_BP, ID_DIASTOLIC_BP, ID_TEMPERATURE]
df_vitals_raw = df_chartevents[df_chartevents['itemid'].isin(state_vitals_ids)].copy()

# Labs from labevents
df_labevents = pd.read_csv(f'{MIMIC_HOSP_DIR}labevents.csv.gz', usecols=['subject_id', 'charttime', 'itemid', 'valuenum'])
state_labs_ids = [ID_LACTATE]
df_labs_raw = df_labevents[df_labevents['itemid'].isin(state_labs_ids)].copy()

# Actions from inputevents
df_inputevents = pd.read_csv(f'{MIMIC_ICU_DIR}inputevents.csv.gz', usecols=['stay_id', 'starttime', 'itemid', 'amount'])
action_ids = [ID_NOREPINEPHRINE, ID_FLUID_BOLUS]
df_actions_raw = df_inputevents[df_inputevents['itemid'].isin(action_ids)].copy()

# --- Centralize all merges here ---
# Add 'intime' to vitals and actions using their 'stay_id'
df_vitals = pd.merge(df_vitals_raw, df_icu_stays, on='stay_id', how='inner')
df_actions = pd.merge(df_actions_raw, df_icu_stays, on='stay_id', how='inner')

# Add 'stay_id' and 'intime' to labs using their 'subject_id'
df_labs = pd.merge(df_labs_raw, df_icu_stays, on='subject_id', how='inner')

print("Cell 2: Raw data loaded and merged successfully.")

Loading and filtering raw data...
Cell 2: Raw data loaded and merged successfully.


In [3]:
# --- 3. Simplified Time-Bucket and Aggregate ---
def simplified_process_and_aggregate(df, time_col, value_col, item_id, feature_name):
    """A simplified function that assumes df already has stay_id and intime."""
    df_filtered = df[df['itemid'] == item_id].copy()
    df_filtered[time_col] = pd.to_datetime(df_filtered[time_col])
    
    # Calculate hours from admission (intime is already present)
    df_filtered['hours_in'] = (df_filtered[time_col] - df_filtered['intime']).dt.total_seconds() / 3600
    df_filtered['time_bucket'] = (df_filtered['hours_in'] // TIME_BUCKET_SIZE).astype(int)
    
    # Aggregate by taking the mean value in each bucket
    df_agg = df_filtered.groupby(['stay_id', 'time_bucket'])[value_col].mean().reset_index()
    df_agg = df_agg.rename(columns={value_col: feature_name})
    
    return df_agg

print("Aggregating data into time buckets...")
# Process States
df_hr = simplified_process_and_aggregate(df_vitals, 'charttime', 'valuenum', ID_HEART_RATE, 'heart_rate')
df_sbp = simplified_process_and_aggregate(df_vitals, 'charttime', 'valuenum', ID_SYSTOLIC_BP, 'systolic_bp')
df_dbp = simplified_process_and_aggregate(df_vitals, 'charttime', 'valuenum', ID_DIASTOLIC_BP, 'diastolic_bp')
df_temp = simplified_process_and_aggregate(df_vitals, 'charttime', 'valuenum', ID_TEMPERATURE, 'temperature')
df_lactate = simplified_process_and_aggregate(df_labs, 'charttime', 'valuenum', ID_LACTATE, 'lactate')

# Process Actions
df_norepi = simplified_process_and_aggregate(df_actions, 'starttime', 'amount', ID_NOREPINEPHRINE, 'norepinephrine')
df_fluids = simplified_process_and_aggregate(df_actions, 'starttime', 'amount', ID_FLUID_BOLUS, 'fluid_bolus')

print("Cell 3: Data aggregation complete.")

Aggregating data into time buckets...
Cell 3: Data aggregation complete.


In [4]:
# --- 4. Merge into a Master DataFrame ---
print("Merging features into a master DataFrame...")
dfs_to_merge = [df_hr, df_sbp, df_dbp, df_temp, df_lactate, df_norepi, df_fluids]

# Start with the first DataFrame in the list as our base
df_master = dfs_to_merge[0]

# Loop through the rest of the DataFrames and merge them one by one
for df_to_merge in dfs_to_merge[1:]:
    df_master = pd.merge(df_master, df_to_merge, on=['stay_id', 'time_bucket'], how='outer')

# Sort for chronological order
df_master = df_master.sort_values(by=['stay_id', 'time_bucket']).reset_index(drop=True)

print("Cell 4: Master DataFrame created.")
df_master.info()

Merging features into a master DataFrame...
Cell 4: Master DataFrame created.
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 3970 entries, 0 to 3969
Data columns (total 9 columns):
 #   Column          Non-Null Count  Dtype  
---  ------          --------------  -----  
 0   stay_id         3970 non-null   int64  
 1   time_bucket     3970 non-null   int64  
 2   heart_rate      3122 non-null   float64
 3   systolic_bp     2095 non-null   float64
 4   diastolic_bp    2095 non-null   float64
 5   temperature     2725 non-null   float64
 6   lactate         1307 non-null   float64
 7   norepinephrine  376 non-null    float64
 8   fluid_bolus     1300 non-null   float64
dtypes: float64(7), int64(2)
memory usage: 279.3 KB


In [5]:
# --- FINAL CORRECTED Cell 5: Handle Missing Data ---
print("Handling missing data...")

# This .apply() method correctly forward-fills within each group...
# ...and .reset_index(drop=True) removes the ambiguous multi-index.
df_master = df_master.groupby('stay_id').apply(lambda group: group.ffill()).reset_index(drop=True)

# Fill any remaining NaNs (especially at the beginning of a stay) with 0
df_master = df_master.fillna(0)

print("Cell 5: Missing data handled, and index has been reset.")

Handling missing data...
Cell 5: Missing data handled, and index has been reset.


  df_master = df_master.groupby('stay_id').apply(lambda group: group.ffill()).reset_index(drop=True)


In [6]:
# --- DEBUGGING CELL ---
# Insert this cell right before the original Cell 6 to inspect the DataFrame

print("--- Inspecting df_master just before the error ---")
print("\nFirst 5 rows:")
print(df_master.head())

print("\n\nDataFrame Info:")
df_master.info()

print(f"\n\nIs 'stay_id' a column? -> {'stay_id' in df_master.columns}")

--- Inspecting df_master just before the error ---

First 5 rows:
    stay_id  time_bucket  heart_rate  systolic_bp  diastolic_bp  temperature  \
0  30057454           -1         0.0          0.0           0.0         0.00   
1  30057454            0       105.8          0.0           0.0        98.25   
2  30057454            1       108.0          0.0           0.0        98.70   
3  30057454            2       109.5          0.0           0.0        97.70   
4  30057454            3       113.0          0.0           0.0        97.70   

   lactate  norepinephrine  fluid_bolus  
0      0.9             0.0     0.000000  
1      0.9             0.0    17.916667  
2      0.9             0.0    17.916667  
3      0.9             0.0     7.000000  
4      0.9             0.0    31.416668  


DataFrame Info:
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 3970 entries, 0 to 3969
Data columns (total 9 columns):
 #   Column          Non-Null Count  Dtype  
---  ------          -----------

In [7]:
# --- 6. Create Final S_t, A_t, S_{t+1} Table ---
print("Creating final S_t, A_t, S_{t+1} table...")
state_cols = ['heart_rate', 'systolic_bp', 'diastolic_bp', 'temperature', 'lactate']
action_cols = ['norepinephrine', 'fluid_bolus']

# Create the "next state" columns by shifting the state columns up by one
for col in state_cols:
    df_master[f'{col}_next'] = df_master.groupby('stay_id')[col].shift(-1)

# Drop the last row for each patient, as it has no "next state"
df_final = df_master.dropna().reset_index(drop=True)

print("Cell 6: Preprocessing complete!")
print(f"Created a final dataset with {len(df_final)} samples.")
df_final.head()

Creating final S_t, A_t, S_{t+1} table...
Cell 6: Preprocessing complete!
Created a final dataset with 3830 samples.


Unnamed: 0,stay_id,time_bucket,heart_rate,systolic_bp,diastolic_bp,temperature,lactate,norepinephrine,fluid_bolus,heart_rate_next,systolic_bp_next,diastolic_bp_next,temperature_next,lactate_next
0,30057454,-1,0.0,0.0,0.0,0.0,0.9,0.0,0.0,105.8,0.0,0.0,98.25,0.9
1,30057454,0,105.8,0.0,0.0,98.25,0.9,0.0,17.916667,108.0,0.0,0.0,98.7,0.9
2,30057454,1,108.0,0.0,0.0,98.7,0.9,0.0,17.916667,109.5,0.0,0.0,97.7,0.9
3,30057454,2,109.5,0.0,0.0,97.7,0.9,0.0,7.0,113.0,0.0,0.0,97.7,0.9
4,30057454,3,113.0,0.0,0.0,97.7,0.9,0.0,31.416668,110.75,82.5,56.0,97.7,0.9


In [8]:
# --- 7. Build and Train the World Model ---
# X includes the state and action at time 't'
X = df_final[state_cols + action_cols]
# Y is the state at time 't+1'
Y = df_final[[f'{col}_next' for col in state_cols]]

# Split data into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size=0.2, random_state=42)

# Initialize the XGBoost regressor
xgbr = xgb.XGBRegressor(objective='reg:squarederror', n_estimators=100, learning_rate=0.1, max_depth=5, random_state=42)

# Wrap it with MultiOutputRegressor to handle multiple targets
multi_output_model = MultiOutputRegressor(estimator=xgbr)

print("Training the world model...")
multi_output_model.fit(X_train, y_train)
print("Cell 7: Training complete.")

Training the world model...
Cell 7: Training complete.


In [9]:
# --- 8. Evaluate the World Model ---
print("Evaluating the world model...")
y_pred = multi_output_model.predict(X_test)

# Calculate overall Mean Squared Error
mse = mean_squared_error(y_test, y_pred)
print(f"\nOverall World Model MSE on Test Set: {mse:.4f}")

# You can also look at the error for each variable individually
print("\n--- MSE for each predicted variable ---")
for i, col in enumerate(state_cols):
    var_mse = mean_squared_error(y_test.iloc[:, i], y_pred[:, i])
    print(f"  - MSE for {col}_next: {var_mse:.4f}")

Evaluating the world model...

Overall World Model MSE on Test Set: 219.7467

--- MSE for each predicted variable ---
  - MSE for heart_rate_next: 289.2407
  - MSE for systolic_bp_next: 403.8036
  - MSE for diastolic_bp_next: 153.8443
  - MSE for temperature_next: 251.3340
  - MSE for lactate_next: 0.5111


In [10]:
import gymnasium as gym
from gymnasium import spaces
import numpy as np

class PatientSimulatorEnv(gym.Env):
    """A custom Gym environment for simulating patient treatment."""
    
    def __init__(self, world_model, initial_states_df, state_cols, action_cols):
        super(PatientSimulatorEnv, self).__init__()
        
        self.world_model = world_model
        self.initial_states = initial_states_df[state_cols].values
        self.state_cols = state_cols
        self.action_cols = action_cols
        
        # Define action and observation space
        # They must be gym.spaces objects
        # Example: 5 discrete levels for norepinephrine, 5 for fluids = 25 actions
        self.action_space = spaces.Discrete(25) 
        
        # The state space is continuous
        self.observation_space = spaces.Box(low=-np.inf, high=np.inf, 
                                            shape=(len(state_cols),), dtype=np.float32)
        
        self.current_state = None
        self.episode_length = 0

    def reset(self, seed=None, options=None):
        # Randomly select an initial state from the real data
        initial_state_idx = np.random.randint(0, len(self.initial_states))
        self.current_state = self.initial_states[initial_state_idx]
        self.episode_length = 0
        
        return self.current_state, {} # Return state and an empty info dict

    def step(self, action):
        # 1. Decode the discrete action into medication dosages
        # Example: action #15 -> level 3 norepi, level 3 fluids
        # This requires you to define a mapping.
        # For now, let's assume a placeholder for the action vector.
        action_vector = np.zeros(len(self.action_cols)) # Placeholder
        
        # 2. Prepare the input for the world model
        model_input = np.concatenate([self.current_state, action_vector]).reshape(1, -1)
        
        # 3. Use the world model to predict the next state
        predicted_next_state = self.world_model.predict(model_input)[0]
        self.current_state = predicted_next_state
        self.episode_length += 1
        
        # 4. Calculate the reward and determine if the episode is done
        lactate_level = self.current_state[self.state_cols.index('lactate')]
        
        reward = 0
        terminated = False
        
        # Simple reward function: Penalize high lactate
        if lactate_level > 4.0: # High lactate is a sign of severe sepsis
            reward = -100
            terminated = True
        elif lactate_level < 1.0: # Healthy lactate level
            reward = 10
        else:
            reward = -1 # Small penalty for each time step in the ICU
            
        if self.episode_length >= 50: # End episode after 50 steps (e.g., 200 hours)
            terminated = True
            
        return self.current_state, reward, terminated, False, {} # next_state, reward, terminated, truncated, info

In [11]:
from stable_baselines3 import DQN

# 1. Create an instance of your environment
# We use the 'X' DataFrame from the previous step to get initial states
env = PatientSimulatorEnv(multi_output_model, X, state_cols, action_cols)

# 2. Instantiate the DQN agent
model = DQN("MlpPolicy", env, verbose=1)

# 3. Train the agent
# This will take some time. The agent will play out thousands of episodes in the simulator.
print("Training RL agent...")
model.learn(total_timesteps=10000, progress_bar=True)
print("Training complete.")

Using cuda device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.


Output()

Training RL agent...


----------------------------------
| rollout/            |          |
|    ep_len_mean      | 50       |
|    ep_rew_mean      | -50      |
|    exploration_rate | 0.81     |
| time/               |          |
|    episodes         | 4        |
|    fps              | 136      |
|    time_elapsed     | 1        |
|    total_timesteps  | 200      |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 8.7      |
|    n_updates        | 24       |
----------------------------------


----------------------------------
| rollout/            |          |
|    ep_len_mean      | 50       |
|    ep_rew_mean      | -50      |
|    exploration_rate | 0.62     |
| time/               |          |
|    episodes         | 8        |
|    fps              | 148      |
|    time_elapsed     | 2        |
|    total_timesteps  | 400      |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 5        |
|    n_updates        | 74       |
----------------------------------


----------------------------------
| rollout/            |          |
|    ep_len_mean      | 50       |
|    ep_rew_mean      | -4.17    |
|    exploration_rate | 0.43     |
| time/               |          |
|    episodes         | 12       |
|    fps              | 151      |
|    time_elapsed     | 3        |
|    total_timesteps  | 600      |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 4.48     |
|    n_updates        | 124      |
----------------------------------


----------------------------------
| rollout/            |          |
|    ep_len_mean      | 50       |
|    ep_rew_mean      | -15.6    |
|    exploration_rate | 0.24     |
| time/               |          |
|    episodes         | 16       |
|    fps              | 150      |
|    time_elapsed     | 5        |
|    total_timesteps  | 800      |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 2.71     |
|    n_updates        | 174      |
----------------------------------


----------------------------------
| rollout/            |          |
|    ep_len_mean      | 47.5     |
|    ep_rew_mean      | -25      |
|    exploration_rate | 0.0966   |
| time/               |          |
|    episodes         | 20       |
|    fps              | 149      |
|    time_elapsed     | 6        |
|    total_timesteps  | 951      |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 2.34     |
|    n_updates        | 212      |
----------------------------------


----------------------------------
| rollout/            |          |
|    ep_len_mean      | 48       |
|    ep_rew_mean      | -29.2    |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 24       |
|    fps              | 147      |
|    time_elapsed     | 7        |
|    total_timesteps  | 1151     |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 1.03     |
|    n_updates        | 262      |
----------------------------------


----------------------------------
| rollout/            |          |
|    ep_len_mean      | 46.6     |
|    ep_rew_mean      | -34.1    |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 28       |
|    fps              | 146      |
|    time_elapsed     | 8        |
|    total_timesteps  | 1306     |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.75     |
|    n_updates        | 301      |
----------------------------------


----------------------------------
| rollout/            |          |
|    ep_len_mean      | 47.1     |
|    ep_rew_mean      | -18.9    |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 32       |
|    fps              | 144      |
|    time_elapsed     | 10       |
|    total_timesteps  | 1506     |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.769    |
|    n_updates        | 351      |
----------------------------------


----------------------------------
| rollout/            |          |
|    ep_len_mean      | 47.4     |
|    ep_rew_mean      | -7.06    |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 36       |
|    fps              | 144      |
|    time_elapsed     | 11       |
|    total_timesteps  | 1706     |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.803    |
|    n_updates        | 401      |
----------------------------------


----------------------------------
| rollout/            |          |
|    ep_len_mean      | 46.5     |
|    ep_rew_mean      | -12.6    |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 40       |
|    fps              | 144      |
|    time_elapsed     | 12       |
|    total_timesteps  | 1858     |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.697    |
|    n_updates        | 439      |
----------------------------------


----------------------------------
| rollout/            |          |
|    ep_len_mean      | 46.8     |
|    ep_rew_mean      | 9.23     |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 44       |
|    fps              | 145      |
|    time_elapsed     | 14       |
|    total_timesteps  | 2058     |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 2.16     |
|    n_updates        | 489      |
----------------------------------


----------------------------------
| rollout/            |          |
|    ep_len_mean      | 47       |
|    ep_rew_mean      | 4.52     |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 48       |
|    fps              | 145      |
|    time_elapsed     | 15       |
|    total_timesteps  | 2258     |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 1.3      |
|    n_updates        | 539      |
----------------------------------


----------------------------------
| rollout/            |          |
|    ep_len_mean      | 47.3     |
|    ep_rew_mean      | 10.9     |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 52       |
|    fps              | 145      |
|    time_elapsed     | 16       |
|    total_timesteps  | 2458     |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.841    |
|    n_updates        | 589      |
----------------------------------


----------------------------------
| rollout/            |          |
|    ep_len_mean      | 47.5     |
|    ep_rew_mean      | 16.8     |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 56       |
|    fps              | 145      |
|    time_elapsed     | 18       |
|    total_timesteps  | 2658     |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 1.92     |
|    n_updates        | 639      |
----------------------------------


----------------------------------
| rollout/            |          |
|    ep_len_mean      | 46.8     |
|    ep_rew_mean      | 11.5     |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 60       |
|    fps              | 145      |
|    time_elapsed     | 19       |
|    total_timesteps  | 2809     |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 2.38     |
|    n_updates        | 677      |
----------------------------------


----------------------------------
| rollout/            |          |
|    ep_len_mean      | 47       |
|    ep_rew_mean      | 7.64     |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 64       |
|    fps              | 145      |
|    time_elapsed     | 20       |
|    total_timesteps  | 3009     |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 2.91     |
|    n_updates        | 727      |
----------------------------------


----------------------------------
| rollout/            |          |
|    ep_len_mean      | 47.2     |
|    ep_rew_mean      | 12.3     |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 68       |
|    fps              | 144      |
|    time_elapsed     | 22       |
|    total_timesteps  | 3209     |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 3.79     |
|    n_updates        | 777      |
----------------------------------


----------------------------------
| rollout/            |          |
|    ep_len_mean      | 47.3     |
|    ep_rew_mean      | 8.88     |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 72       |
|    fps              | 144      |
|    time_elapsed     | 23       |
|    total_timesteps  | 3409     |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 1.03     |
|    n_updates        | 827      |
----------------------------------


----------------------------------
| rollout/            |          |
|    ep_len_mean      | 47.5     |
|    ep_rew_mean      | 13.2     |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 76       |
|    fps              | 144      |
|    time_elapsed     | 24       |
|    total_timesteps  | 3609     |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 1.8      |
|    n_updates        | 877      |
----------------------------------


----------------------------------
| rollout/            |          |
|    ep_len_mean      | 47.6     |
|    ep_rew_mean      | 10       |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 80       |
|    fps              | 145      |
|    time_elapsed     | 26       |
|    total_timesteps  | 3809     |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 1.8      |
|    n_updates        | 927      |
----------------------------------


----------------------------------
| rollout/            |          |
|    ep_len_mean      | 47.7     |
|    ep_rew_mean      | 26.9     |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 84       |
|    fps              | 145      |
|    time_elapsed     | 27       |
|    total_timesteps  | 4009     |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 1.55     |
|    n_updates        | 977      |
----------------------------------


----------------------------------
| rollout/            |          |
|    ep_len_mean      | 47.8     |
|    ep_rew_mean      | 23.4     |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 88       |
|    fps              | 145      |
|    time_elapsed     | 28       |
|    total_timesteps  | 4209     |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 3.11     |
|    n_updates        | 1027     |
----------------------------------


----------------------------------
| rollout/            |          |
|    ep_len_mean      | 47.9     |
|    ep_rew_mean      | 26.2     |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 92       |
|    fps              | 145      |
|    time_elapsed     | 30       |
|    total_timesteps  | 4409     |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 2.18     |
|    n_updates        | 1077     |
----------------------------------


----------------------------------
| rollout/            |          |
|    ep_len_mean      | 48       |
|    ep_rew_mean      | 23       |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 96       |
|    fps              | 145      |
|    time_elapsed     | 31       |
|    total_timesteps  | 4609     |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 1.59     |
|    n_updates        | 1127     |
----------------------------------


----------------------------------
| rollout/            |          |
|    ep_len_mean      | 48.1     |
|    ep_rew_mean      | 20.1     |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 100      |
|    fps              | 146      |
|    time_elapsed     | 32       |
|    total_timesteps  | 4809     |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.942    |
|    n_updates        | 1177     |
----------------------------------


----------------------------------
| rollout/            |          |
|    ep_len_mean      | 48.1     |
|    ep_rew_mean      | 20.1     |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 104      |
|    fps              | 146      |
|    time_elapsed     | 34       |
|    total_timesteps  | 5009     |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 1.96     |
|    n_updates        | 1227     |
----------------------------------


----------------------------------
| rollout/            |          |
|    ep_len_mean      | 48.1     |
|    ep_rew_mean      | 25.6     |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 108      |
|    fps              | 145      |
|    time_elapsed     | 35       |
|    total_timesteps  | 5209     |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 1.07     |
|    n_updates        | 1277     |
----------------------------------


----------------------------------
| rollout/            |          |
|    ep_len_mean      | 48.1     |
|    ep_rew_mean      | 25.6     |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 112      |
|    fps              | 145      |
|    time_elapsed     | 37       |
|    total_timesteps  | 5409     |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 2.39     |
|    n_updates        | 1327     |
----------------------------------


----------------------------------
| rollout/            |          |
|    ep_len_mean      | 47.6     |
|    ep_rew_mean      | 25.1     |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 116      |
|    fps              | 145      |
|    time_elapsed     | 38       |
|    total_timesteps  | 5560     |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 1.28     |
|    n_updates        | 1364     |
----------------------------------


----------------------------------
| rollout/            |          |
|    ep_len_mean      | 48.1     |
|    ep_rew_mean      | 25.6     |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 120      |
|    fps              | 145      |
|    time_elapsed     | 39       |
|    total_timesteps  | 5760     |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 2.53     |
|    n_updates        | 1414     |
----------------------------------


----------------------------------
| rollout/            |          |
|    ep_len_mean      | 48.1     |
|    ep_rew_mean      | 25.6     |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 124      |
|    fps              | 144      |
|    time_elapsed     | 41       |
|    total_timesteps  | 5960     |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.947    |
|    n_updates        | 1464     |
----------------------------------


----------------------------------
| rollout/            |          |
|    ep_len_mean      | 48.5     |
|    ep_rew_mean      | 26.1     |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 128      |
|    fps              | 144      |
|    time_elapsed     | 42       |
|    total_timesteps  | 6160     |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.874    |
|    n_updates        | 1514     |
----------------------------------


----------------------------------
| rollout/            |          |
|    ep_len_mean      | 48.5     |
|    ep_rew_mean      | 20.6     |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 132      |
|    fps              | 144      |
|    time_elapsed     | 44       |
|    total_timesteps  | 6360     |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 2.47     |
|    n_updates        | 1564     |
----------------------------------


----------------------------------
| rollout/            |          |
|    ep_len_mean      | 48.5     |
|    ep_rew_mean      | 15.2     |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 136      |
|    fps              | 144      |
|    time_elapsed     | 45       |
|    total_timesteps  | 6560     |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 1.48     |
|    n_updates        | 1614     |
----------------------------------


----------------------------------
| rollout/            |          |
|    ep_len_mean      | 49       |
|    ep_rew_mean      | 15.7     |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 140      |
|    fps              | 144      |
|    time_elapsed     | 46       |
|    total_timesteps  | 6760     |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 2.06     |
|    n_updates        | 1664     |
----------------------------------


----------------------------------
| rollout/            |          |
|    ep_len_mean      | 49       |
|    ep_rew_mean      | 4.55     |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 144      |
|    fps              | 144      |
|    time_elapsed     | 48       |
|    total_timesteps  | 6960     |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.979    |
|    n_updates        | 1714     |
----------------------------------


----------------------------------
| rollout/            |          |
|    ep_len_mean      | 49       |
|    ep_rew_mean      | 4.44     |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 148      |
|    fps              | 145      |
|    time_elapsed     | 49       |
|    total_timesteps  | 7160     |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 1.34     |
|    n_updates        | 1764     |
----------------------------------


----------------------------------
| rollout/            |          |
|    ep_len_mean      | 49       |
|    ep_rew_mean      | 4.55     |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 152      |
|    fps              | 145      |
|    time_elapsed     | 50       |
|    total_timesteps  | 7360     |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 2.44     |
|    n_updates        | 1814     |
----------------------------------


----------------------------------
| rollout/            |          |
|    ep_len_mean      | 49       |
|    ep_rew_mean      | -1.17    |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 156      |
|    fps              | 145      |
|    time_elapsed     | 52       |
|    total_timesteps  | 7560     |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 1.36     |
|    n_updates        | 1864     |
----------------------------------


----------------------------------
| rollout/            |          |
|    ep_len_mean      | 49.5     |
|    ep_rew_mean      | -0.67    |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 160      |
|    fps              | 145      |
|    time_elapsed     | 53       |
|    total_timesteps  | 7760     |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 1.14     |
|    n_updates        | 1914     |
----------------------------------


----------------------------------
| rollout/            |          |
|    ep_len_mean      | 49.5     |
|    ep_rew_mean      | 10.3     |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 164      |
|    fps              | 145      |
|    time_elapsed     | 54       |
|    total_timesteps  | 7960     |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 1.27     |
|    n_updates        | 1964     |
----------------------------------


----------------------------------
| rollout/            |          |
|    ep_len_mean      | 49.5     |
|    ep_rew_mean      | 4.83     |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 168      |
|    fps              | 145      |
|    time_elapsed     | 55       |
|    total_timesteps  | 8160     |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 1.75     |
|    n_updates        | 2014     |
----------------------------------


----------------------------------
| rollout/            |          |
|    ep_len_mean      | 49.5     |
|    ep_rew_mean      | 4.94     |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 172      |
|    fps              | 145      |
|    time_elapsed     | 57       |
|    total_timesteps  | 8360     |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.972    |
|    n_updates        | 2064     |
----------------------------------


----------------------------------
| rollout/            |          |
|    ep_len_mean      | 49.5     |
|    ep_rew_mean      | -0.67    |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 176      |
|    fps              | 146      |
|    time_elapsed     | 58       |
|    total_timesteps  | 8560     |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.968    |
|    n_updates        | 2114     |
----------------------------------


----------------------------------
| rollout/            |          |
|    ep_len_mean      | 49.5     |
|    ep_rew_mean      | -0.67    |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 180      |
|    fps              | 146      |
|    time_elapsed     | 59       |
|    total_timesteps  | 8760     |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 1.3      |
|    n_updates        | 2164     |
----------------------------------


----------------------------------
| rollout/            |          |
|    ep_len_mean      | 49       |
|    ep_rew_mean      | -17.8    |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 184      |
|    fps              | 146      |
|    time_elapsed     | 60       |
|    total_timesteps  | 8911     |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 1.24     |
|    n_updates        | 2202     |
----------------------------------


----------------------------------
| rollout/            |          |
|    ep_len_mean      | 49       |
|    ep_rew_mean      | -17.8    |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 188      |
|    fps              | 146      |
|    time_elapsed     | 62       |
|    total_timesteps  | 9111     |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 2.51     |
|    n_updates        | 2252     |
----------------------------------


----------------------------------
| rollout/            |          |
|    ep_len_mean      | 49       |
|    ep_rew_mean      | -23.3    |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 192      |
|    fps              | 146      |
|    time_elapsed     | 63       |
|    total_timesteps  | 9311     |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.688    |
|    n_updates        | 2302     |
----------------------------------


----------------------------------
| rollout/            |          |
|    ep_len_mean      | 49       |
|    ep_rew_mean      | -17.8    |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 196      |
|    fps              | 146      |
|    time_elapsed     | 64       |
|    total_timesteps  | 9511     |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 1.07     |
|    n_updates        | 2352     |
----------------------------------


----------------------------------
| rollout/            |          |
|    ep_len_mean      | 48.5     |
|    ep_rew_mean      | -18.3    |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 200      |
|    fps              | 146      |
|    time_elapsed     | 65       |
|    total_timesteps  | 9662     |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 2.16     |
|    n_updates        | 2390     |
----------------------------------


----------------------------------
| rollout/            |          |
|    ep_len_mean      | 48       |
|    ep_rew_mean      | -13.1    |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 204      |
|    fps              | 146      |
|    time_elapsed     | 66       |
|    total_timesteps  | 9813     |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.626    |
|    n_updates        | 2428     |
----------------------------------


Training complete.


In [12]:
# --- Evaluate the trained agent ---
episodes = 10
total_reward = 0

for ep in range(episodes):
    obs, info = env.reset()
    done = False
    ep_reward = 0
    while not done:
        action, _states = model.predict(obs, deterministic=True)
        obs, reward, done, truncated, info = env.step(action)
        ep_reward += reward
    
    total_reward += ep_reward
    print(f"Episode {ep + 1}: Total Reward = {ep_reward}")

avg_reward = total_reward / episodes
print(f"\nAverage reward over {episodes} episodes: {avg_reward:.2f}")

Episode 1: Total Reward = -39
Episode 2: Total Reward = -50
Episode 3: Total Reward = -50
Episode 4: Total Reward = -50
Episode 5: Total Reward = -50
Episode 6: Total Reward = -50
Episode 7: Total Reward = -50
Episode 8: Total Reward = -50
Episode 9: Total Reward = -50
Episode 10: Total Reward = -50

Average reward over 10 episodes: -48.90
