In [1]:
import numpy as np
import networkx as nx
import gymnasium as gym
from gymnasium import spaces

from stable_baselines3.common.env_checker import check_env
from stable_baselines3 import A2C, SAC, PPO, DQN
from stable_baselines3.common.env_util import make_vec_env

import random
import pickle

import os

In [2]:
class PeriodicSensor(gym.Env):

    def __init__(self, graph, initial_i=1, infect_prob=0.1, cure_prob=0, budget=2, seed=False,
                 file_path='fixed_india'):
    # === INPUTS ===
    # graph = networkx.Graph() object
    # initial_i = int, number of nodes that are initially infected
    # infect_prob = float, probability of infection transferring over an edge
    # budget = int, number of nodes that can be tested at each time step
    
        super(PeriodicSensor, self).__init__()
        self.graph = graph
        self.n = len(self.graph)
        self.all_nodes = list(range(self.n))
        self.initial_i = initial_i
        self.infect_prob = infect_prob
        self.cure_prob = cure_prob
        self.A = nx.to_numpy_array(graph)

        self.file_path = file_path
        
        # define the observation vectors here #
        self.infection_vector = np.array([0]* self.n)
        self.morbidity_vector = np.random.rand(self.n)
        self.vax_vector = np.array([1] * self.n)
        
        self.observation_space = spaces.Box(low=0, high=1, shape=(self.n+3, self.n), dtype=np.float64)
            #spaces.Box(low=0, high=1, shape=(self.n,), dtype=np.float64)
        
        #Defining action space for just budget 1 right now, can change to a box later
        #self.action_space = spaces.Discrete(self.n)
        self.action_space = spaces.MultiBinary(np.array([self.n]))
        

        self.seed = seed
    
    # TODO Need to do reset
    def reset(self, seed=None, options=None):
        super().reset(seed=seed, options=options)
        #if selected_cascade:
        #    self.fised_graph_num = selected_cascade
        #else:
        #    self.fixed_graph_num += 1


        #self.cascade = nx.Graph()
        #self.cascade.add_nodes_from(self.graph)
        
        self.infection_vector = np.array([0]* self.n)
        self.vax_vector = np.array([1] * self.n)
        
        initial_infections = random.sample(list(range(self.n)), self.initial_i)
        for i in initial_infections:
            self.infection_vector[i] = 1

        # Should this be implemented in a wrapper? Maybe? See tutorial Notebook 2
        self.t = 1

        return np.stack(self.infection_vector, self.morbidity_vector, self.vax_vector, self.A), {} #empty info dict
    
    # TODO define step
    def step(self, action):
        info = {}
        
        
        
        

SyntaxError: '(' was never closed (2539595165.py, line 28)

In [None]:
## NOTE - not implemented yet

Graph_List = ['test_graph','Hospital','India','Exhibition','Flu','irvine','Escorts','Epinions']
Graph_index = 2
Graph_name = Graph_List[Graph_index]
path = 'graph_data/' + Graph_name + '.txt'
G = nx.read_edgelist(path, nodetype=int)
mapping = dict(zip(G.nodes(),range(len(G))))
g = nx.relabel_nodes(G,mapping)


env = PeriodicSensor(g)
check_env(env, warn=True)

model = A2C('MlpPolicy', env, verbose=1).learn(400_000)

rews = []
env = PeriodicSensor(g)
check_env(env, warn=True)
episodes = 300

for ep in range(10000, 10000+episodes):
    run_reward = 0
    obs, _info = env.reset(selected_cascade=ep)
    done = False
    while not done:
        action, _state = model.predict(obs)
        obs, rewards, done, terminated, info = env.step(action)
        run_reward += rewards
    rews.append(run_reward)

print(f'Average Reward: {np.mean(rews)}, STD: {np.std(rews)}')