In [1]:
from gym import Env
import numpy as np
from gym.spaces import MultiDiscrete,Box
from graph_tool.all import *
from makegraph import *

In [22]:
def simulatepandemic(self,actions):
    action = np.divide(actions,np.sum(actions)) #relative availability of vaccine is calculated
    update_state(self.g,action)
    self.state = graph_to_matrix(self.g)
    self.timestep += 1
    return self.state

def initializepandemic(self):
    self.g = make_graph(self.size,self.distribution)
    self.state = graph_to_matrix(self.g)
    return self.state

In [104]:
class PanEnv(Env):
    def __init__(self,size,distribution):
        self.size = size #population size
        
        self.action_space = MultiDiscrete(nvec=[10,10,10,10,10,10,10,10,10,10,10,10,10,10,10,10,10,10,10,10])
        self.observation_space = Box(low=-np.inf,high=np.inf,shape=[self.size,6])
        
        
        self.distribution = distribution #age distribution, country name as string
        self.g = make_graph(self.size,self.distribution)
                
        #state observation as matrix
        self.state = graph_to_matrix(self.g)
        
        self.timestep = 1
    def step(self,actions):
        state = simulatepandemic(self,actions)
        #observation (object): agent's observation of the current environment
        reward = -np.sum(state[:,5])
        #reward (float) : amount of reward returned after previous action
        #negative reward: punishment > change weights a lot, push away from causing weights, positive rewards pull
        #do reward compared to reward from previous step
        #naive example: reward = -sum(infected) > we want a reward where the cumulative sum of infections until end
        #is minimized
        #exp. solution: store information in self, summed infections, normalized by time
        done = False
        if((self.timestep > 19) or np.sum(self.state[:,1])==0):
            done = True
        #done (bool): episode is done after 20 timesteps or when there are no longer infected agents
        info = {}
        #info (dict): contains auxiliary diagnostic information (helpful for debugging, and sometimes learning)'''
        return state, reward, done, info
    def reset(self):
        #returns initial state
        self.state = initializepandemic(self)
        return self.state

In [107]:
from stable_baselines3.common.env_checker import check_env
check_env(env)

In [4]:
#matrix representation:
#[S, I, R, Sv, Iv, D]

In [106]:
env = PanEnv(size=1000,distribution='Japan')
env.reset()

array([[1., 0., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0.],
       ...,
       [1., 0., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0.]])

In [81]:
actions = np.full(20,1)

In [102]:
env.step(actions)

(array([[0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 1.],
        ...,
        [0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 1., 0., 0.]]),
 -657.0,
 True,
 {})

In [8]:
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv
# Parallel environments

env = DummyVecEnv([lambda: PanEnv(size=1000,distribution='Japan')])

model = PPO("MlpPolicy", env, verbose=1) #multilayer
model.learn(total_timesteps=25000) #training loop
#model.save("ppo_cartpole")

#del model # remove to demonstrate saving and loading

#model = PPO.load("ppo_cartpole")

Using cpu device
-----------------------------
| time/              |      |
|    fps             | 9    |
|    iterations      | 1    |
|    time_elapsed    | 205  |
|    total_timesteps | 2048 |
-----------------------------
------------------------------------------
| time/                   |              |
|    fps                  | 10           |
|    iterations           | 2            |
|    time_elapsed         | 389          |
|    total_timesteps      | 4096         |
| train/                  |              |
|    approx_kl            | 0.0002608819 |
|    clip_fraction        | 0            |
|    clip_range           | 0.2          |
|    entropy_loss         | -46.1        |
|    explained_variance   | -5.25e-06    |
|    learning_rate        | 0.0003       |
|    loss                 | 6.39e+07     |
|    n_updates            | 10           |
|    policy_gradient_loss | -0.00242     |
|    value_loss           | 1.27e+08     |
------------------------------------------

-------------------------------------------
| time/                   |               |
|    fps                  | 10            |
|    iterations           | 13            |
|    time_elapsed         | 2594          |
|    total_timesteps      | 26624         |
| train/                  |               |
|    approx_kl            | 0.00029641594 |
|    clip_fraction        | 0             |
|    clip_range           | 0.2           |
|    entropy_loss         | -46           |
|    explained_variance   | 0             |
|    learning_rate        | 0.0003        |
|    loss                 | 6.04e+07      |
|    n_updates            | 120           |
|    policy_gradient_loss | -0.00273      |
|    value_loss           | 1.26e+08      |
-------------------------------------------


<stable_baselines3.ppo.ppo.PPO at 0x7fed78827c70>

In [10]:
model.save("ppo_1")

In [18]:
gym.spaces??

Object `gym.spaces` not found.


In [27]:
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.monitor import Monitor
log_dir = "/gym/"

env = Monitor(env, log_dir)

mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=100)

print(f"mean_reward:{mean_reward:.2f} +/- {std_reward:.2f}")

AttributeError: 'DummyVecEnv' object has no attribute 'reward_range'

In [9]:
obs = env.reset()
dones = False
while not dones:
    action, _states = model.predict(obs)
    obs, rewards, dones, info = env.step(action)

KeyboardInterrupt: 

In [None]:
#split into learning and testing
model.learn(total_timesteps = 5000)
# store/accumulate rewards