In [9]:
from stable_baselines3.common.vec_env import DummyVecEnv, VecFrameStack
from epidemic_simulation_environment import EpidemicSimulation
from settings import DATA_DIR, BASE_DIR
import stable_baselines3
from stable_baselines3.common.env_checker import check_env
import gymnasium as gym
import numpy as np
from gymnasium import spaces

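To make Stable-Baselines3 (SB3) work with Gymnasium, install the `feat/gymnasium-support` branches of both SB3 and SB3-Contrib. Run the commands below in a terminal, or prefix them with `%pip` to install into this notebook's kernel. SB3 2.0 and later support Gymnasium natively, so these branch installs are only needed for older versions: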

```
pip install git+https://github.com/DLR-RM/stable-baselines3@feat/gymnasium-support
pip install git+https://github.com/Stable-Baselines-Team/stable-baselines3-contrib@feat/gymnasium-support
```

In [7]:
class SB3Observation(gym.ObservationWrapper):
    """Expose the simulator's observation as a flat float64 NumPy vector.

    Stable-Baselines3 (``check_env`` and the VecEnv wrappers) requires every
    observation to be a member of ``observation_space``. This wrapper declares an
    unbounded ``Box`` of shape ``(4,)`` and converts each raw observation to match.

    Defined at module level (rather than inside ``create_env``) so that every
    wrapped environment shares one class object.
    """

    def __init__(self, env):
        super().__init__(env)
        # NOTE(review): assumes EpidemicSimulation emits exactly 4 values per
        # observation — confirm against the simulator if its state changes.
        self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(4,), dtype=np.float64)

    def observation(self, obs):
        """Return ``obs`` as a new float64 array matching ``observation_space``."""
        # Without an explicit dtype, np.array() on a list of Python ints yields an
        # integer array, which is not contained in the declared float64 Box.
        return np.array(obs, dtype=np.float64)


def create_env(
    data_path=None,
    state_name="new_york",
    state_population=19_453_734,
    start_date="11/01/2021",
):
    """Build one SB3-compatible epidemic simulation environment.

    Called with no arguments (e.g. as a ``DummyVecEnv`` factory), it reproduces the
    original New York configuration.

    Args:
        data_path: Directory with the epidemiological model data. Defaults to
            ``{DATA_DIR}/Updated Data/epidemiological_model_data/``.
        state_name: State identifier passed to ``EpidemicSimulation``.
        state_population: Population of that state.
        start_date: Simulation start date, in the string format the simulator expects.

    Returns:
        The ``EpidemicSimulation`` wrapped in ``SB3Observation``.
    """
    if data_path is None:
        data_path = f"{DATA_DIR}/Updated Data/epidemiological_model_data/"

    simulation = EpidemicSimulation(
        data_path=data_path,
        state_name=state_name,
        state_population=state_population,
        start_date=start_date,
    )
    return SB3Observation(simulation)

In [22]:
# Validate a single fresh (non-vectorised) environment against the Gymnasium /
# SB3 API before training; warnings are emitted for non-fatal issues.
check_env(env=create_env(), warn=True)

In [23]:
# Vectorise the environment and stack consecutive observations so the policy sees
# a short history. Each env observation has 4 values, so a stacked observation
# has 4 * N_STACK entries.
N_ENVS = 4   # number of environment copies in the DummyVecEnv
N_STACK = 2  # number of consecutive observations concatenated per step

env = VecFrameStack(
    DummyVecEnv([create_env for _ in range(N_ENVS)]),
    n_stack=N_STACK,
)

In [24]:
# Sanity-check the batched, frame-stacked observation: one row per environment,
# each row as wide as the stacked observation space.
initial_obs = env.reset()
assert initial_obs.shape == (env.num_envs, *env.observation_space.shape), initial_obs.shape
initial_obs.shape

(4, 8)

In [None]:
# Fixed seed so repeated Restart & Run All executions train comparably.
PPO_SEED = 0

model = stable_baselines3.PPO(
    "MlpPolicy",
    env,
    seed=PPO_SEED,
    verbose=1,
    tensorboard_log=f"{BASE_DIR}/rl_algorithms/pytorch/logs",
)
# NOTE(review): the logged update shows loss ~1.5e9 and explained_variance ~0,
# which suggests unscaled rewards/observations — consider wrapping `env` in
# VecNormalize before training.
model.learn(total_timesteps=100_000);  # trailing ';' hides the returned model repr

Using cpu device
Logging to /Users/akhildevarashetti/code/covid_research/nitin22/src/rl_algorithms/pytorch/logs/PPO_1
-----------------------------
| time/              |      |
|    fps             | 4701 |
|    iterations      | 1    |
|    time_elapsed    | 1    |
|    total_timesteps | 8192 |
-----------------------------
-------------------------------------------
| time/                   |               |
|    fps                  | 3627          |
|    iterations           | 2             |
|    time_elapsed         | 4             |
|    total_timesteps      | 16384         |
| train/                  |               |
|    approx_kl            | 6.7793662e-06 |
|    clip_fraction        | 0             |
|    clip_range           | 0.2           |
|    entropy_loss         | -2.48         |
|    explained_variance   | -4.77e-07     |
|    learning_rate        | 0.0003        |
|    loss                 | 1.46e+09      |
|    n_updates            | 10            |
|    policy_