In [1]:
import sys
from pathlib import Path
repo_root = Path.cwd().parent
sys.path.insert(0, str(repo_root / "src"))

import gymnasium as gym
from stable_baselines3 import PPO
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import VecMonitor, DummyVecEnv, VecNormalize
from stable_baselines3.common.callbacks import CheckpointCallback
import torch as th
import os
from GurobiParamEnv import InexactGBDEnv

In [None]:
# Output locations: TensorBoard / monitor CSV logs and model checkpoints.
# Both are relative to the notebook's working directory and are created
# up front so the wrappers and callbacks below can write into them.
log_dir, save_dir = "../tb_logs", "../checkpoints"
for output_dir in (log_dir, save_dir):
    os.makedirs(output_dir, exist_ok=True)

# 1. Wrap your env in Monitor before vectorizing
def make_env():
    """Construct one ``InexactGBDEnv`` wrapped in a ``Monitor``.

    Used as the env factory for ``DummyVecEnv``. The Monitor writes its
    episode log to ``monitor.csv`` inside ``log_dir``.
    """
    base_env = InexactGBDEnv()
    monitor_csv = os.path.join(log_dir, "monitor.csv")
    return Monitor(base_env, filename=monitor_csv)

# 2. Create a DummyVecEnv of monitored envs (a single env here).
env = DummyVecEnv([make_env])

# 3. Aggregate episode stats with VecMonitor BEFORE normalizing. VecNormalize
#    returns normalized rewards from step(), so a VecMonitor placed outside it
#    would record (and log as rollout/ep_rew_mean) normalized returns. VecMonitor
#    also overwrites the inner Monitor's info["episode"] entry; placed here, both
#    wrappers see the same raw environment rewards.
env = VecMonitor(env, filename=os.path.join(log_dir, "vecmonitor.csv"))

# 4. Normalize observations and rewards. VecNormalize is the outermost wrapper,
#    so env.save(...) and model.get_vec_normalize_env() reach it directly.
#    NOTE(review): reward normalization is ON (norm_reward=True); set it to
#    False if raw rewards were intended — confirm which the experiment wants.
env = VecNormalize(env, norm_obs=True, norm_reward=True)

policy_kwargs = dict(
    activation_fn=th.nn.ReLU,
    # Separate networks: actor with three 64-unit hidden layers, critic with two.
    net_arch=dict(pi=[64, 64, 64], vf=[64, 64]),
)

# 5. Initialize PPO with tensorboard logging pointed at the same directory.
#    batch_size (128) evenly divides the rollout size n_steps * n_envs
#    (512 * 1); SB3 warns when it does not.
model = PPO(
    "MlpPolicy",
    env,
    policy_kwargs=policy_kwargs,
    gamma=0.99,
    learning_rate=5e-4,
    n_steps=512,
    batch_size=128,
    verbose=1,
    tensorboard_log=log_dir,
)

# Periodic checkpointing is opt-in. When enabled, the VecNormalize statistics
# are saved next to each checkpoint, since a checkpoint without them cannot
# reproduce the observation/reward scaling used during training.
USE_CHECKPOINTS = False
checkpoint_callback = (
    CheckpointCallback(
        save_freq=4000,
        save_path=save_dir,
        name_prefix="ppo_gbd",
        save_vecnormalize=True,
    )
    if USE_CHECKPOINTS
    else None
)

# 6. Train.
TOTAL_TIMESTEPS = 20_000
model.learn(
    total_timesteps=TOTAL_TIMESTEPS,
    callback=checkpoint_callback,
    tb_log_name="benders_test",
)

In [None]:
# Save both the policy and the VecNormalize statistics; the statistics are
# needed to reproduce the observation/reward scaling when the model is
# reloaded for evaluation or further training.
MODEL_PATH = "ppo_benders_model_retrain"  # model.save appends ".zip"
VECNORM_PATH = "vecnormalize_benders.pkl"

model.save(MODEL_PATH)

# Ask the model for its VecNormalize wrapper instead of calling env.save on
# whichever wrapper happens to be outermost in the env stack.
vec_normalize = model.get_vec_normalize_env()
if vec_normalize is None:
    raise RuntimeError("Training env is not wrapped in VecNormalize; no statistics to save")
vec_normalize.save(VECNORM_PATH)

# Report the paths actually written (previously the message named a
# different model file than the one saved).
print(f"Training complete. Model saved to {MODEL_PATH}.zip and stats to {VECNORM_PATH}")