In [7]:
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv, VecFrameStack
from epidemic_simulation_environment import EpidemicSimulation
from settings import DATA_DIR, BASE_DIR
import stable_baselines3
from stable_baselines3.common.env_checker import check_env
import gymnasium as gym
import numpy as np
from gymnasium import spaces

To make Stable-Baselines3 work with Gymnasium, install the `feat/gymnasium-support` branches of Stable-Baselines3 and SB3-Contrib:

```
pip install git+https://github.com/DLR-RM/stable-baselines3@feat/gymnasium-support
pip install git+https://github.com/Stable-Baselines-Team/stable-baselines3-contrib@feat/gymnasium-support
```

In [8]:
class SB3Observation(gym.ObservationWrapper):
    """Expose the wrapped environment's observations as a flat float64 ``Box``.

    The wrapper overrides ``observation_space`` with an unbounded
    ``Box(shape=(4,), dtype=float64)`` and converts every observation to a
    NumPy array of that dtype, so the declared space and the emitted data agree.

    NOTE(review): the shape ``(4,)`` is hard-coded; this assumes the wrapped
    environment always emits exactly four values per observation -- confirm
    against ``EpidemicSimulation``.
    """

    def __init__(self, env):
        """Wrap ``env`` and declare the float64 observation space.

        Args:
            env: The environment whose observations are converted.
        """
        super().__init__(env)
        self.observation_space = spaces.Box(
            low=-np.inf, high=np.inf, shape=(4,), dtype=np.float64
        )

    def observation(self, obs):
        """Convert a raw observation into an array matching ``observation_space``.

        Args:
            obs: Raw observation from the wrapped environment (any array-like).

        Returns:
            A new ``numpy.ndarray`` with the dtype of ``self.observation_space``.
        """
        # Cast explicitly: a bare np.array(obs) infers the dtype from the data
        # (e.g. int64 for integer-valued observations), which would contradict
        # the float64 space declared in __init__.
        return np.array(obs, dtype=self.observation_space.dtype)


def create_env():
    """Build one monitored epidemic-simulation environment.

    Constructs an ``EpidemicSimulation`` with the New York settings below,
    wraps it in ``SB3Observation`` and then in ``Monitor``.

    Returns:
        The ``Monitor``-wrapped environment.
    """
    simulation = EpidemicSimulation(
        state_name="new_york",
        state_population=19_453_734,
        start_date="11/01/2021",
        data_path=f"{DATA_DIR}/Updated Data/epidemiological_model_data/",
    )
    # Innermost wrapper converts observations; Monitor is applied outermost.
    return Monitor(SB3Observation(simulation))

In [9]:
# Run Stable-Baselines3's environment checker on a freshly built environment
# before it is used for training.
check_env(create_env())

In [10]:
# Four environments, each built by its own create_env() call, run inside one
# DummyVecEnv; VecFrameStack is then applied with n_stack=2.
env = VecFrameStack(
    DummyVecEnv([create_env for _ in range(4)]),
    n_stack=2,
)

In [11]:
# Sanity check of the batched observation shape; the recorded output is (4, 8).
env.reset().shape

(4, 8)

In [12]:
# Train a PPO agent with an MLP policy on the vectorised environment and write
# TensorBoard logs under BASE_DIR.
model = stable_baselines3.PPO(
    'MlpPolicy',
    env,
    verbose=1,
    tensorboard_log=f'{BASE_DIR}/rl_algorithms/pytorch/logs',
)
# total_timesteps is documented as an int, so pass 10_000 rather than the
# float 1e4. The recorded log advances 8192 steps per iteration, so training
# stops at 16384 steps, the first iteration boundary past this budget.
model.learn(total_timesteps=10_000)

Using cpu device
Logging to /Users/akhildevarashetti/code/covid_research/nitin22/src/rl_algorithms/pytorch/logs/PPO_3
----------------------------------
| rollout/           |           |
|    ep_len_mean     | 181       |
|    ep_rew_mean     | -4.37e+05 |
| time/              |           |
|    fps             | 4680      |
|    iterations      | 1         |
|    time_elapsed    | 1         |
|    total_timesteps | 8192      |
----------------------------------
------------------------------------------
| rollout/                |              |
|    ep_len_mean          | 181          |
|    ep_rew_mean          | -4.5e+05     |
| time/                   |              |
|    fps                  | 3476         |
|    iterations           | 2            |
|    time_elapsed         | 4            |
|    total_timesteps      | 16384        |
| train/                  |              |
|    approx_kl            | 8.219104e-06 |
|    clip_fraction        | 0            |
|    clip_range 

<stable_baselines3.ppo.ppo.PPO at 0x14ccec550>