In [1]:
import numpy as np
import networkx as nx
import gymnasium as gym
from gymnasium import spaces

from stable_baselines3.common.env_checker import check_env
from stable_baselines3 import A2C, SAC, PPO, DQN
from stable_baselines3.common.env_util import make_vec_env

import random
import pickle

import os

In [2]:
class FluShotRollout(gym.Env):

    def __init__(self, graph, initial_i=1, infect_prob=0.1, cure_prob=.05, budget=2, seed=False,
                 file_path='fixed_india'):
    # === INPUTS ===
    # graph = networkx.Graph() object
    # initial_i = int, number of nodes that are initially infected
    # infect_prob = float, probability of infection transferring over an edge
    # budget = int, number of nodes that can be tested at each time step
    
        super(FluShotRollout, self).__init__()
        self.graph = graph
        self.n = len(self.graph)
        self.all_nodes = list(range(self.n))
        self.initial_i = initial_i
        self.infect_prob = infect_prob
        self.cure_prob = cure_prob
        self.A = nx.to_numpy_array(self.graph)

        self.file_path = file_path
        
        # define the observation vectors here #
        self.infection_vector = np.array([0]* self.n)
        self.morbidity_vector = np.random.default_rng().uniform(.001, .05, self.n)
        self.vax_vector = np.array([1.0] * self.n, dtype=np.float64) # doing this as 1s so that we can multiply the arrays together to
                                                 # be the likelihood of someone getting infected
        
        self.observation_space = spaces.Box(low=0, high=1, shape=(self.n, self.n+3), dtype=np.float64)
            #spaces.Box(low=0, high=1, shape=(self.n,), dtype=np.float64)
        
        #Defining action space for just budget 1 right now, can change to a box later
        #self.action_space = spaces.Discrete(self.n)
        self.action_space = spaces.MultiDiscrete(np.array([self.n] * budget))
        
        self.sum_deaths = 0
        

        self.seed = seed
    
    # TODO Need to do reset
    def reset(self, seed=None, options=None):
        super().reset(seed=seed, options=options)
        #if selected_cascade:
        #    self.fised_graph_num = selected_cascade
        #else:
        #    self.fixed_graph_num += 1
        
        # Need to reset this since we are changing the adjacency matrix when someone dies
        self.A = nx.to_numpy_array(self.graph)

        self.infection_vector = np.array([0]* self.n)
        self.vax_vector = np.array([1.0] * self.n, dtype=np.float64)
        
        initial_infections = random.sample(list(range(self.n)), self.initial_i)
        for i in initial_infections:
            self.infection_vector[i] = 1

        # Should this be implemented in a wrapper? Maybe? See tutorial Notebook 2
        self.t = 1

        return np.column_stack((self.infection_vector, self.morbidity_vector, self.vax_vector, self.A)), {} #empty info dict
    
    # TODO define step
    def step(self, action): 
        # action is a vector of length B, so choose those values and vax those individuals
        info = {}
        #vax = np.nonzero(action == 1)
        for i in action:
            self.vax_vector[i] = .02
        
        
        # - Progress Infection (matrix multiplication)
        indegree_prob = self.infect_prob * self.infection_vector #
        uninfected_Prob = np.array([np.prod([(1 - indegree_prob[u]) for u in self.A[:,v].nonzero()[0]]) \
                         for v in self.all_nodes]) 
        infect_prob = np.ones(self.n) - uninfected_Prob
        
        #Calculates a list of newly infected nodes
        new_infect = [v for v in self.all_nodes if random.uniform(0,1) < (infect_prob[v] * self.vax_vector[v])]
                
        # - Calculate Deaths (using morbidity vector)
        new_deaths = [v for v in self.all_nodes if self.infection_vector[v] == 1 and random.uniform(0,1) < self.morbidity_vector[v]]
        
        # - Calculate Cures
        new_cure = [v for v in self.all_nodes if self.infection_vector[v] == 1 and random.uniform(0,1) < self.cure_prob]

        # - reduce vax efficacy (multiply self.vax_vector * 1.02)
        new_vax = self.vax_vector * 1.02
        self.vax_vector = np.array([min(i, 1) for i in new_vax], dtype=np.float64)
        
        
        # - Remove dead people from the adjacency matrix - can't spread/be infected anymore
        for i in new_deaths:
            self.A[i,:] = 0
            self.A[:,i] = 0
        # - change infection status of dead people to 0 just avoid possible double counting
            self.infection_vector[i] = 0
        
        # - change self.infection_vector to account for new infections, cures.
        for i in new_infect:
            self.infection_vector[i] = 1
        for i in new_cure:
            self.infection_vector[i] = 0
        
        #REWARD IS THE NEGATIVE SUM DEATHS
        self.sum_deaths -= len(new_deaths)
        
        #Still to do:
        # - Figure out how to disallow incorrect actions
        # - Generate the full return
        self.t += 1
        if self.t >= 180:
            return (np.column_stack((self.infection_vector, self.morbidity_vector, self.vax_vector, self.A)),
                self.sum_deaths,
                True,
                True, #One of these booleans needs to be adjusted for terminating after a certain number of steps
                info)
        
        return (np.column_stack((self.infection_vector, self.morbidity_vector, self.vax_vector, self.A)),
                self.sum_deaths,
                False,
                False, #One of these booleans needs to be adjusted for terminating after a certain number of steps
                info)


Graph_List = ['test_graph','Hospital','India','Exhibition','Flu','irvine','Escorts','Epinions']
Graph_index = 2
Graph_name = Graph_List[Graph_index]
path = Graph_name + '.txt'
G = nx.read_edgelist(path, nodetype=int)
mapping = dict(zip(G.nodes(),range(len(G))))
g = nx.relabel_nodes(G,mapping)


env = FluShotRollout(g)
check_env(env, warn=True)


In [None]:
## NOTE - not implemented yet

Graph_List = ['test_graph','Hospital','India','Exhibition','Flu','irvine','Escorts','Epinions']
Graph_index = 2
Graph_name = Graph_List[Graph_index]
path = Graph_name + '.txt'
G = nx.read_edgelist(path, nodetype=int)
mapping = dict(zip(G.nodes(),range(len(G))))
g = nx.relabel_nodes(G,mapping)


env = FluShotRollout(g)
check_env(env, warn=True)

model = A2C('MlpPolicy', env, verbose=1).learn(400_000)

rews = []
env = FluShotRollout(g)
check_env(env, warn=True)
episodes = 300

for ep in range(episodes):
    run_reward = 0
    obs, _info = env.reset()
    done = False
    while not done:
        action, _state = model.predict(obs)
        obs, rewards, done, terminated, info = env.step(action)
    rews.append(rewards)

print(f'Average Reward: {np.mean(rews)}, STD: {np.std(rews)}')



Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
-------------------------------------
| rollout/              |           |
|    ep_len_mean        | 179       |
|    ep_rew_mean        | -9.01e+03 |
| time/                 |           |
|    fps                | 301       |
|    iterations         | 100       |
|    time_elapsed       | 1         |
|    total_timesteps    | 500       |
| train/                |           |
|    entropy_loss       | -9.38     |
|    explained_variance | 0         |
|    learning_rate      | 0.0007    |
|    n_updates          | 99        |
|    policy_loss        | -2.99e+03 |
|    value_loss         | 1.64e+05  |
-------------------------------------
-------------------------------------
| rollout/              |           |
|    ep_len_mean        | 179       |
|    ep_rew_mean        | -2.06e+04 |
| time/                 |           |
|    fps                | 309       |
|    iterations         | 200   

In [6]:
print(ep)

0
