# Mini-Project 1: Deep Q-Learning for Epidemic Mitigation

**Goal**: Train an artificial agent using deep Q-learning to find a decision-making policy regarding the mitigation of an epidemic process. Compare the performance of different methods with various action and observation spaces.

**Context**: An epidemic of a new virus named MARVIN23 has just started propagating in Switzerland's neighbor Listenburg. We'll use a predictive model designed by epidemiologists to train a reinforcement learning agent for epidemic mitigation. The model takes the form of a few Python classes that together form a reinforcement learning environment.

**Environment**: The environment simulates epidemic dynamics on a simplified map of Switzerland. In each city, the state is described by 5 simulation variables: susceptible (s), exposed (e), infected (i), recovered (r), and dead (d) individuals, which evolve in time according to a set of stochastic differential equations. Time in the simulation is measured in days, with the onset of the epidemic on day $d_0 = 0$; we denote time by the variable $d \in \mathbb{N}$. Country-wide quantities are obtained by summing over all cities, e.g. for the susceptible population:

$$s^{(d)}_{total} = \sum_{city\in map} s^{(d)}_{city}$$

**Agent**: The agent observes the number of infected (i) and dead (d) people in each city. Its action space consists of 4 binary decisions: confinement, isolation, adding hospital beds, and vaccination. The weekly action $a^{[w]}$ combines these decisions:

$$a^{[w]} = a^{[w]}_{conf} \cup a^{[w]}_{isol} \cup a^{[w]}_{hosp} \cup a^{[w]}_{vacc} = \bigcup_{\mathfrak{d}\in actions}a^{[w]}_{\mathfrak{d}}$$

**Reward**: The agent receives a reward, which is a combination of a constant reward term, action cost, and death cost. The agent's goal is to maximize the total reward over the course of an episode.

$$R^{[w]} = R_c - \mathcal{C}(a^{[w]}) - D \cdot \Delta d^{[w]}_{total}$$

The action cost $\mathcal{C}$ and the announcement cost $\mathcal{A}$ are:

$$\mathcal{C}(a^{[w]}) = \mathcal{A}(a^{[w]}) + \mathbf{1}_{vac}\cdot V + \mathbf{1}_{hosp}\cdot H + \mathbf{1}_{conf}\cdot C + \mathbf{1}_{isol}\cdot I$$

$$\mathcal{A}(a^{[w]}) = A\cdot( \mathbf{1}^{+}_{vac} +  \mathbf{1}^{+}_{conf} +  \mathbf{1}^{+}_{isol}) $$
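To make the reward concrete, here is a minimal sketch of one week's reward in plain Python. The constants in `RewardParams` are hypothetical placeholders, not the values used by the environment. The sketch also assumes that $\mathbf{1}_{x}$ is 1 whenever action $x$ is active during week $w$, while $\mathbf{1}^{+}_{x}$ is 1 only when $x$ is switched on in week $w$ after being off in week $w-1$, so that the announcement cost is paid once per announcement.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class RewardParams:
    # Hypothetical constants for illustration only, not the environment's values
    R_c: float = 7.0   # constant weekly reward
    A: float = 5.0     # announcement cost
    V: float = 0.5     # weekly cost of vaccination
    H: float = 1.0     # weekly cost of additional hospital beds
    C: float = 6.0     # weekly cost of confinement
    I: float = 2.0     # weekly cost of isolation
    D: float = 0.01    # cost per new death

# Actions that carry an announcement cost (hospital beds do not, per the formula above)
ANNOUNCED = ("vaccinate", "confinement", "isolation")

def action_cost(action: dict, prev_action: dict, p: RewardParams) -> float:
    """C(a) = A(a) + weekly costs of the active actions."""
    # Assumption: 1^+ is 1 only if the action is on this week and was off last week
    newly_on = sum(action[k] and not prev_action[k] for k in ANNOUNCED)
    running = (p.V * action["vaccinate"] + p.H * action["hospital"]
               + p.C * action["confinement"] + p.I * action["isolation"])
    return p.A * newly_on + running

def weekly_reward(action: dict, prev_action: dict, delta_dead: float,
                  p: RewardParams = RewardParams()) -> float:
    """R = R_c - C(a) - D * (deaths during the week)."""
    return p.R_c - action_cost(action, prev_action, p) - p.D * delta_dead

off = dict(confinement=False, isolation=False, hospital=False, vaccinate=False)
conf = {**off, "confinement": True}
print(weekly_reward(conf, off, delta_dead=100))   # -5.0: 7 - (5 + 6) - 1, confinement announced
print(weekly_reward(conf, conf, delta_dead=100))  # 0.0: 7 - 6 - 1, confinement continues
```

Passing `prev_action` explicitly is what makes the announcement term computable: the same action dictionary costs more in the week it is introduced than in the weeks it is merely continued.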

This notebook will provide a framework for implementing the environment and training a deep Q-learning agent to learn an optimal policy for epidemic mitigation.


In [1]:
import matplotlib.pyplot as plt
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from gym import spaces

"""Environment imports"""
from epidemic_env.env       import Env, Log
from epidemic_env.dynamics  import ModelDynamics, Observation
from epidemic_env.visualize import Visualize
from epidemic_env.agent     import Agent

"""Pytorch and numpy imports"""
import numpy as np
import torch
from torch import nn

%matplotlib inline

In [2]:
dyn = ModelDynamics('config/switzerland.yaml')   # load the switzerland map

"""Loading the environment"""
env = Env(dyn,                      # Pass the dynamical model to the environment  
          action_space=None,        # Can be passed an openai gym action space
          observation_space=None,   # Can be passed an openai gym observation space
          )

""" Resetting the environment """
obs, info = env.reset(
    seed=1                          # Seed for the random number generator (ensuring reproducibility)
    ) 

Next, we define the action space and the observation space of the agent:

* The action space is a set of binary decisions: confinement, isolation, adding hospital beds, and vaccination.

* The observation space is the number of infected (i) and dead (d) people in each city.

At each step, the environment returns four variables:
1. An observation `obs` (to be used by the agent for decision making)
2. A reward `rew` (to be used for training)
3. A boolean variable `done` (indicates when an episode is finished)
4. An information object `info` (to be used for policy interpretation and debugging)
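Since `reset` returns `(obs, info)` and `step` returns `(obs, reward, done, info)`, a full episode can be driven by a generic loop. The sketch below is a minimal version under that assumption; `CountdownEnv` is a hypothetical stand-in with the same call signatures, used only so the loop runs on its own, and `run_episode` is likewise a name introduced here for illustration.

```python
from typing import Any, Callable, Dict, List, Tuple

class CountdownEnv:
    """Hypothetical stand-in mimicking the call signatures of `Env`:
    reset(seed) -> (obs, info) and step(action) -> (obs, reward, done, info)."""

    def __init__(self, n_weeks: int = 30):
        self.n_weeks = n_weeks
        self.week = 0

    def reset(self, seed=None) -> Tuple[int, Dict[str, int]]:
        self.week = 0
        return self.week, {"week": self.week}

    def step(self, action) -> Tuple[int, float, bool, Dict[str, int]]:
        self.week += 1
        done = self.week >= self.n_weeks   # episode ends after n_weeks steps
        return self.week, -1.0, done, {"week": self.week}

def run_episode(env, policy: Callable[[Any], Any], seed: int = 1) -> Tuple[float, List[Any]]:
    """Roll out one episode; return the summed reward and the per-step `info` objects."""
    obs, info = env.reset(seed=seed)
    total_reward, infos, done = 0.0, [], False
    while not done:
        action = policy(obs)                        # decide from the latest observation
        obs, reward, done, info = env.step(action)  # advance the simulation by one step
        total_reward += reward
        infos.append(info)
    return total_reward, infos

ret, infos = run_episode(CountdownEnv(30), policy=lambda obs: None)
print(ret, len(infos))  # -30.0 30
```

Relying on `done` rather than a fixed week count keeps the loop correct if the episode length is changed in the environment configuration.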


In [3]:
# Define an action: do nothing (all four decisions set to False)
action = {
        'confinement': False, 
        'isolation': False, 
        'hospital': False, 
        'vaccinate': False,
    }

# One step amounts to one week in the simulation environment
obs, reward, done, info = env.step(action)


The `obs` variable is an instance of the `Observation` class, a container that holds information about an environment observation. It has three attributes:

1. `pop`: a dictionary that maps each city to its initial population.
2. `city`: a dictionary that maps city names to an `Observables` object containing the observables for that city.
3. `total`: an `Observables` object containing the observables for the entire country.

The `Observables` class is a container that holds the observable variables of a city (or of the whole country). It has two attributes:

1. `infected`: a list of integers with the number of infected individuals on each day of the past week.
2. `dead`: a list of integers with the number of dead individuals on each day of the past week.
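The country-wide observables should be the per-city observables summed day by day, matching the sum over cities defined at the top of the notebook. A minimal sketch, using a hypothetical `Observables` stand-in and the Geneva values from the output below (only three of the nine cities are included, and all cities other than Geneva are still at zero):

```python
from dataclasses import dataclass
from typing import Dict, List

@dataclass
class Observables:
    """Hypothetical stand-in with the two attributes described above."""
    infected: List[int]
    dead: List[int]

def daily_totals(city: Dict[str, Observables]) -> Observables:
    """Sum each day's values over all cities."""
    infected = [sum(day) for day in zip(*(o.infected for o in city.values()))]
    dead = [sum(day) for day in zip(*(o.dead for o in city.values()))]
    return Observables(infected, dead)

zeros = [0] * 7
city = {
    "Lausanne": Observables(zeros, zeros),
    "Geneva": Observables([8714, 8174, 19746, 14579, 16690, 15234, 38251],
                          [261, 389, 662, 1225, 1500, 1921, 2523]),
    "Zürich": Observables(zeros, zeros),
}
print(daily_totals(city).infected)  # [8714, 8174, 19746, 14579, 16690, 15234, 38251]
```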


In [4]:
obs

Observation(pop={'Lausanne': 295000, 'Geneva': 900000, 'Sion': 34978, 'Neuchâtel': 44531, 'Basel': 830000, 'Bern': 133115, 'Lücern': 82000, 'St-Gallen': 76213, 'Zürich': 1354000}, city={'Lausanne': Observables(infected=[0, 0, 0, 0, 0, 0, 0], dead=[0, 0, 0, 0, 0, 0, 0]), 'Geneva': Observables(infected=[8714, 8174, 19746, 14579, 16690, 15234, 38251], dead=[261, 389, 662, 1225, 1500, 1921, 2523]), 'Sion': Observables(infected=[0, 0, 0, 0, 0, 0, 0], dead=[0, 0, 0, 0, 0, 0, 0]), 'Neuchâtel': Observables(infected=[0, 0, 0, 0, 0, 0, 0], dead=[0, 0, 0, 0, 0, 0, 0]), 'Basel': Observables(infected=[0, 0, 0, 0, 0, 0, 0], dead=[0, 0, 0, 0, 0, 0, 0]), 'Bern': Observables(infected=[0, 0, 0, 0, 0, 0, 0], dead=[0, 0, 0, 0, 0, 0, 0]), 'Lücern': Observables(infected=[0, 0, 0, 0, 0, 0, 0], dead=[0, 0, 0, 0, 0, 0, 0]), 'St-Gallen': Observables(infected=[0, 0, 0, 0, 0, 0, 0], dead=[0, 0, 0, 0, 0, 0, 0]), 'Zürich': Observables(infected=[0, 0, 0, 0, 0, 0, 0], dead=[0, 0, 0, 0, 0, 0, 0])}, total=Observables(i

In [5]:
# Populations
print("Population (Initial):")
for city, pop in obs.pop.items():
    print(f"{city}: {pop}")
print()

# City observables
print("City observables:")
for city, obs_city in obs.city.items():
    print(city)
    print(f"  - Infected: {obs_city.infected}")
    print(f"  - Dead: {obs_city.dead}")
    print()

# Total observables
print("Total observables:")
print(f"  - Infected: {obs.total.infected}")
print(f"  - Dead: {obs.total.dead}")


Population (Initial):
Lausanne: 295000
Geneva: 900000
Sion: 34978
Neuchâtel: 44531
Basel: 830000
Bern: 133115
Lücern: 82000
St-Gallen: 76213
Zürich: 1354000

City observables:
Lausanne
  - Infected: [0, 0, 0, 0, 0, 0, 0]
  - Dead: [0, 0, 0, 0, 0, 0, 0]

Geneva
  - Infected: [8714, 8174, 19746, 14579, 16690, 15234, 38251]
  - Dead: [261, 389, 662, 1225, 1500, 1921, 2523]

Sion
  - Infected: [0, 0, 0, 0, 0, 0, 0]
  - Dead: [0, 0, 0, 0, 0, 0, 0]

Neuchâtel
  - Infected: [0, 0, 0, 0, 0, 0, 0]
  - Dead: [0, 0, 0, 0, 0, 0, 0]

Basel
  - Infected: [0, 0, 0, 0, 0, 0, 0]
  - Dead: [0, 0, 0, 0, 0, 0, 0]

Bern
  - Infected: [0, 0, 0, 0, 0, 0, 0]
  - Dead: [0, 0, 0, 0, 0, 0, 0]

Lücern
  - Infected: [0, 0, 0, 0, 0, 0, 0]
  - Dead: [0, 0, 0, 0, 0, 0, 0]

St-Gallen
  - Infected: [0, 0, 0, 0, 0, 0, 0]
  - Dead: [0, 0, 0, 0, 0, 0, 0]

Zürich
  - Infected: [0, 0, 0, 0, 0, 0, 0]
  - Dead: [0, 0, 0, 0, 0, 0, 0]

Total observables:
  - Infected: [8714, 8174, 19746, 14579, 16690, 15234, 38251]
  - Dead: [2

The `info` variable is a `Log` object that contains the following fields:

1. `total`: a `Parameters` object holding the country-wide epidemic variables for the current day.
2. `city`: a dictionary that maps each city to a `Parameters` object holding the city-level epidemic variables for the current day.
3. `action`: a dictionary that maps each possible action (confinement, isolation, hospital beds, and vaccination) to a boolean indicating whether that action was taken during the current step.

In [6]:
print(f"Day {env.day}:\n")
print("Total epidemic parameters:")
print(f"{info.total}\n")
print("Epidemic parameters by city:\n")

for city in info.city:
    print(f"{city}:")
    print(f"  Suceptible      : {info.city[city].suceptible}")
    print(f"  Exposed         : {info.city[city].exposed}")
    print(f"  Infected        : {info.city[city].infected}")
    print(f"  Recovered       : {info.city[city].recovered}")
    print(f"  Dead            : {info.city[city].dead}")
    print(f"  Initial Pop.    : {info.city[city].initial_population}\n")

print("Actions taken:")
for action, status in info.action.items():
    print(f"  {action.capitalize():<12} : {status}")


Day 7:

Total epidemic parameters:
Parameters(day=7, suceptible=17566379, exposed=315629, infected=38251, recovered=155643, dead=2523, initial_population=3749837)

Epidemic parameters by city:

Lausanne:
  Suceptible      : 14305193
  Exposed         : 0
  Infected        : 0
  Recovered       : 0
  Dead            : 0
  Initial Pop.    : 295000

Geneva:
  Suceptible      : 387952
  Exposed         : 315629
  Infected        : 38251
  Recovered       : 155643
  Dead            : 2523
  Initial Pop.    : 900000

Sion:
  Suceptible      : 34978
  Exposed         : 0
  Infected        : 0
  Recovered       : 0
  Dead            : 0
  Initial Pop.    : 34978

Neuchâtel:
  Suceptible      : 362928
  Exposed         : 0
  Infected        : 0
  Recovered       : 0
  Dead            : 0
  Initial Pop.    : 44531

Basel:
  Suceptible      : 830000
  Exposed         : 0
  Infected        : 0
  Recovered       : 0
  Dead            : 0
  Initial Pop.    : 830000

Bern:
  Suceptible      : 133115


In [7]:
# Define the action space
action_space = spaces.Discrete(5)  # 5 possible actions

# Define the observation space
observation_space = spaces.Box(
    low=0,                                          # Minimum value of the observation
    high=1,                                         # Maximum value of the observation
    shape=(2, dyn.n_cities, dyn.env_step_length),   # Shape of the observation (n_observables, n_cities, n_steps)
    dtype=np.float32                                # Data type of the observation
)

In [8]:
# Sample action
sampled_action = action_space.sample()
print(f"Sampled action: {sampled_action}")

# Sample observation
sampled_observation = observation_space.sample()

# Create a subplot with 1 row and 2 columns
fig = make_subplots(rows=1, cols=2, subplot_titles=("Infected", "Dead"))

# Add a heatmap for the first dimension of the sampled observation (Infected)
fig.add_trace(go.Heatmap(
    z=sampled_observation[0], colorscale="Viridis", showscale=False), row=1, col=1)

# Add a heatmap for the second dimension of the sampled observation (Dead)
fig.add_trace(go.Heatmap(
    z=sampled_observation[1], colorscale="Viridis", showscale=False), row=1, col=2)

# Update the layout
fig.update_layout(title="Sampled Observation Space")
fig.update_xaxes(title_text="City", row=1, col=1)
fig.update_yaxes(title_text="Time Step", row=1, col=1)
fig.update_xaxes(title_text="City", row=1, col=2)
fig.update_yaxes(title_text="Time Step", row=1, col=2)

# Show the plot
fig.show()


Sampled action: 1


In [9]:
SCALE = 100
ACTION_NULL = 0
ACTION_CONFINE = 1
ACTION_ISOLATE = 2
ACTION_HOSPITAL = 3
ACTION_VACCINATE = 4


def action_preprocessor(a: torch.Tensor, dyn: ModelDynamics):
    # Default action: do nothing
    action = { 
        'confinement': False,
        'isolation': False,
        'hospital': False,
        'vaccinate': False,
    }

    if a == ACTION_CONFINE:
        action['confinement'] = True
    elif a == ACTION_ISOLATE:
        action['isolation'] = True
    elif a == ACTION_VACCINATE:
        action['vaccinate'] = True
    elif a == ACTION_HOSPITAL:
        action['hospital'] = True

    return action
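The `if`/`elif` chain above can also be written as a lookup table, which keeps the index-to-decision mapping in one place. This is a sketch of an alternative with the same mapping, not the project's implementation; `ACTION_KEYS` and `discrete_to_action_dict` are hypothetical names. Unlike the chain, an out-of-range index raises a `KeyError` instead of silently doing nothing, while `None` is still mapped to the do-nothing action.

```python
from typing import Dict, Optional

ACTION_NULL, ACTION_CONFINE, ACTION_ISOLATE, ACTION_HOSPITAL, ACTION_VACCINATE = range(5)

# Key of the action dictionary switched on by each discrete index (None = do nothing)
ACTION_KEYS: Dict[int, Optional[str]] = {
    ACTION_NULL: None,
    ACTION_CONFINE: "confinement",
    ACTION_ISOLATE: "isolation",
    ACTION_HOSPITAL: "hospital",
    ACTION_VACCINATE: "vaccinate",
}

def discrete_to_action_dict(a) -> Dict[str, bool]:
    """Map a discrete index (int, or anything int() accepts such as a 0-d tensor) to an action dict."""
    action = {"confinement": False, "isolation": False, "hospital": False, "vaccinate": False}
    key = None if a is None else ACTION_KEYS[int(a)]  # unknown index -> KeyError
    if key is not None:
        action[key] = True
    return action

print(discrete_to_action_dict(ACTION_CONFINE))
# {'confinement': True, 'isolation': False, 'hospital': False, 'vaccinate': False}
```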


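def observation_preprocessor(obs: Observation, dyn: ModelDynamics, scale: float = SCALE):
    # Infected and dead counts as a percentage of each city's initial population,
    # each of shape (n_cities, n_days)
    infected = scale * \
        np.array([np.array(obs.city[c].infected)/obs.pop[c]
                 for c in dyn.cities])
    dead = scale * \
        np.array([np.array(obs.city[c].dead)/obs.pop[c]
                 for c in dyn.cities])
    # Third channel: constant plane equal to 1 while confinement is active, else 0
    confined = np.ones_like(dead)*int((dyn.get_action()['confinement']))
    # Stack to (3, n_cities, n_days) and add a leading batch dimension
    return torch.Tensor(np.stack((infected, dead, confined))).unsqueeze(0)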
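To check the scaling, the sketch below redoes the preprocessing with plain NumPy on the Geneva values printed earlier (the tensor conversion is left out so it runs without PyTorch, and `preprocess` is a hypothetical helper name). With `SCALE = 100`, each entry is a percentage of the city's initial population, e.g. $100 \cdot 8714 / 900000 \approx 0.9682$, the first value of the Geneva row in the observation printed further down.

```python
import numpy as np

SCALE = 100  # counts become percentages of each city's initial population

def preprocess(infected, dead, pop, cities, confined: bool, scale: float = SCALE) -> np.ndarray:
    """Return an array of shape (1, 3, n_cities, n_days): infected %, dead %, confinement flag."""
    inf = scale * np.array([np.asarray(infected[c]) / pop[c] for c in cities])
    dth = scale * np.array([np.asarray(dead[c]) / pop[c] for c in cities])
    conf = np.full_like(dth, float(confined))      # constant plane: 1.0 while confined
    return np.stack((inf, dth, conf))[np.newaxis]  # leading batch dimension

cities = ["Lausanne", "Geneva"]
pop = {"Lausanne": 295000, "Geneva": 900000}
infected = {"Lausanne": [0] * 7, "Geneva": [8714, 8174, 19746, 14579, 16690, 15234, 38251]}
dead = {"Lausanne": [0] * 7, "Geneva": [261, 389, 662, 1225, 1500, 1921, 2523]}

x = preprocess(infected, dead, pop, cities, confined=False)
print(x.shape)                  # (1, 3, 2, 7)
print(np.round(x[0, 0, 1], 4))  # Geneva infected %, one value per day
```

Dividing by the city's population before scaling puts small and large cities on a comparable footing, which is the point of the normalization.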


In [10]:
env = Env(  
    dyn,
    action_space=action_space,
    observation_space=observation_space,
    action_preprocessor=action_preprocessor,
    observation_preprocessor=observation_preprocessor
)

In [11]:
class ExampleAgent(Agent):
    def __init__(self,  env: Env,
                 # Additional parameters to be added here
                 ):
        """
        Example agent implementation. Just picks a random action at each time step.
        """
        self.env = env

    def load_model(self, savepath):
        # This is where one would define the routine for loading a pre-trained model
        pass

    def save_model(self, savepath):
        # This is where one would define the routine for saving the weights for a trained model
        pass

    def optimize_model(self):
        # This is where one would define the optimization step of an RL algorithm
        return 0

    def reset(self):
        # This should be called when the environment is reset
        pass

    def act(self, obs):
        # this takes an observation and returns an action
        # the action space can be directly sampled from the env
        return self.env.action_space.sample()


**Setup:**

In [12]:
SCALE = 100         # Scale factor: observations become percentages of each city's initial population
ACTION_NULL = 0
ACTION_CONFINE = 1
ACTION_ISOLATE = 2
ACTION_HOSPITAL = 3
ACTION_VACCINATE = 4


def action_preprocessor(a: torch.Tensor, dyn: ModelDynamics):
    # Default action: do nothing
    action = {
        'confinement': False,
        'isolation': False,
        'hospital': False,
        'vaccinate': False,
    }

    if a == ACTION_CONFINE:
        action['confinement'] = True
    elif a == ACTION_ISOLATE:
        action['isolation'] = True
    elif a == ACTION_VACCINATE:
        action['vaccinate'] = True
    elif a == ACTION_HOSPITAL:
        action['hospital'] = True

    return action


def observation_preprocessor(obs: Observation, dyn: ModelDynamics, scale: float = SCALE):
    # Infected and dead counts as a percentage of each city's initial population
    infected = scale * \
        np.array([np.array(obs.city[c].infected)/obs.pop[c]
                 for c in dyn.cities])
    dead = scale * \
        np.array([np.array(obs.city[c].dead)/obs.pop[c]
                 for c in dyn.cities])
    # Constant plane flagging whether confinement is currently active
    confined = np.ones_like(dead)*int((dyn.get_action()['confinement']))
    # Shape (1, 3, n_cities, n_days): batch, channel (infected, dead, confined), city, day
    return torch.Tensor(np.stack((infected, dead, confined))).unsqueeze(0)


In [13]:
dyn = ModelDynamics('config/switzerland.yaml')

# Define the action space
action_space = spaces.Discrete(5)  # 5 possible actions

# Define the observation space
observation_space = spaces.Box(
    low=0,                                          # Minimum value of the observation
    high=1,                                         # Maximum value of the observation
    # Shape of the observation (n_observables, n_cities, n_steps)
    shape=(2, dyn.n_cities, dyn.env_step_length),
    dtype=np.float32                                # Data type of the observation
)
    
env = Env(
    dyn,
    action_space=action_space,
    observation_space=observation_space,
    action_preprocessor=action_preprocessor,
    observation_preprocessor=observation_preprocessor
)


_________
### **Question 1.a** Study the behavior of the model when epidemics are unmitigated
Run the epidemic simulation for one episode (30 weeks) without epidemic mitigation (meaning no action is taken, i.e. all values in the action dictionary are set to False) and produce three plots:
1. A plot of variables $\quad s^{[w]}_{total} \quad e^{[w]}_{total} \quad i^{[w]}_{total} \quad r^{[w]}_{total} \quad d^{[w]}_{total} \quad$ over time, where time is measured in weeks and all the variables share the y-axis scaling.
2. A plot of variables $\quad i^{[w]}_{total} \quad d^{[w]}_{total} \quad$ over time, where time is measured in weeks and all the variables share the y-axis scaling.
3. A set of plots of variables $\quad i^{[w]}_{city} \quad d^{[w]}_{city} \quad$ over time, where time is measured in weeks (one subplot per city, variables share the y-axis scaling per city).

**Discuss the evolution of the variables over time**

In [14]:
num_weeks = 30
history = []

env.reset(seed=1)
for _ in range(num_weeks):
    # action=None: the action preprocessor maps it to "do nothing" (all decisions False)
    _, _, _, info = env.step(action=None)
    history.append(info)

In [15]:
# Arrange the data into lists
suceptible_total    = [data.total.suceptible for data in history]
exposed_total       = [data.total.exposed for data in history]
infected_total      = [data.total.infected for data in history]
recovered_total     = [data.total.recovered for data in history]
dead_total          = [data.total.dead for data in history]
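The five list comprehensions above follow one pattern: read the same field from every weekly log, either from `total` or from one entry of `city`. A small helper can produce both the country-wide and the per-city series needed for the three plots. It is sketched here with `types.SimpleNamespace` objects standing in for `Log` and `Parameters`; the field names (including the `suceptible` spelling) are the ones printed earlier, while `series`, `make_log` and the demo numbers are hypothetical.

```python
from types import SimpleNamespace
from typing import Any, List, Optional

def series(history: List[Any], field: str, city: Optional[str] = None) -> List[Any]:
    """Weekly values of `field`, country-wide if `city` is None, otherwise for that city."""
    return [getattr(h.total if city is None else h.city[city], field) for h in history]

# Two hypothetical weekly logs with made-up numbers, only to exercise the helper
def make_log(total_inf: int, geneva_inf: int) -> SimpleNamespace:
    return SimpleNamespace(
        total=SimpleNamespace(infected=total_inf, dead=0),
        city={"Geneva": SimpleNamespace(infected=geneva_inf, dead=0)},
    )

history_demo = [make_log(100, 80), make_log(250, 200)]
print(series(history_demo, "infected"))            # [100, 250]
print(series(history_demo, "infected", "Geneva"))  # [80, 200]
```

In the notebook this would read, for example, `infected_total = series(history, "infected")` or `series(history, "dead", city="Geneva")`.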


In [16]:
weeks = list(range(1, num_weeks + 1))

fig1 = go.Figure()
fig1.add_trace(go.Scatter(x=weeks, y=suceptible_total, mode='lines', name='Susceptible (Total)'))
fig1.add_trace(go.Scatter(x=weeks, y=exposed_total, mode='lines', name='Exposed (Total)'))
fig1.add_trace(go.Scatter(x=weeks, y=infected_total, mode='lines', name='Infected (Total)'))
fig1.add_trace(go.Scatter(x=weeks, y=recovered_total, mode='lines', name='Recovered (Total)'))
fig1.add_trace(go.Scatter(x=weeks, y=dead_total, mode='lines', name='Dead (Total)'))
fig1.update_layout(title='Plot 1: SEIRD variables over time (Total)', xaxis_title='Week', yaxis_title='Individuals')
fig1.show()

In [17]:
fig2 = go.Figure()
fig2.add_trace(go.Scatter(x=weeks, y=infected_total, mode='lines', name='Infected (Total)'))
fig2.add_trace(go.Scatter(x=weeks, y=dead_total, mode='lines', name='Dead (Total)'))
fig2.update_layout(title='Plot 2: Infected and Dead over time (Total)', xaxis_title='Week', yaxis_title='Individuals')
fig2.show()

In [18]:
num_cities = len(history[0].city)
cities = list(history[0].city.keys())

fig3 = make_subplots(rows=num_cities, cols=1, shared_xaxes=True, subplot_titles=[f'{cities[i]}' for i in range(num_cities)])

for city_index in range(num_cities):
    city_infected = [data.city[cities[city_index]].infected for data in history]
    city_dead = [data.city[cities[city_index]].dead for data in history]

    fig3.add_trace(go.Scatter(x=weeks, y=city_infected, mode='lines', name=f'{cities[city_index]} Infected'), row=city_index+1, col=1)
    fig3.add_trace(go.Scatter(x=weeks, y=city_dead, mode='lines', name=f'{cities[city_index]} Dead'), row=city_index+1, col=1)

fig3.update_layout(height=200 * num_cities, width=800, title_text='Plot 3: Infected and Deceased per City')
fig3.show()

____
### Question 2: Professor Russo’s Policy

Since the epidemic hit Listenburg before Switzerland, the Listenburgish medical community has had time to study the behavior of the disease, and one of the Listenburgish experts, Professor Russo, suggests the following mitigation policy:

**Algorithm 1:** Pr. Russo’s Policy ($\pi_{Russo}$)

_Input:_ x ← $i^{[w]}_{total}$, the total number of infected people at the end of week $w$

```python
if x > 200000:
    # Confine the entire country for 4 weeks
    pass
```

_The number of infected cases is not evaluated during a confinement period, i.e. if the policy declares a 4-week confinement starting at week w and the number of infected people is still > 200000 at week $w + 2$, a new 4-week confinement does not start at $w + 2$._
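Stripped of the environment, the policy is a small state machine: a countdown that starts when the threshold is exceeded and that blocks any re-evaluation until it reaches zero. The sketch below is one possible reading of Algorithm 1; the class name `RussoPolicy` and the weekly input sequence are hypothetical. For each week it returns whether the country should be confined during the following week.

```python
class RussoPolicy:
    """Confine for `duration` weeks whenever total infections exceed `threshold`;
    infections are not re-evaluated while a confinement is running."""

    def __init__(self, threshold: int = 200_000, duration: int = 4):
        self.threshold = threshold
        self.duration = duration
        self.weeks_left = 0  # confinement weeks still to be applied

    def reset(self) -> None:
        self.weeks_left = 0

    def act(self, infected_total: int) -> bool:
        """Given i_total at the end of week w, return True to confine during the next week."""
        if self.weeks_left == 0 and infected_total > self.threshold:
            self.weeks_left = self.duration  # declare a new confinement
        if self.weeks_left > 0:
            self.weeks_left -= 1
            return True
        return False

policy = RussoPolicy()
infected_demo = [50_000, 250_000, 300_000, 260_000, 150_000, 210_000, 90_000]  # hypothetical
print([policy.act(x) for x in infected_demo])
# [False, True, True, True, True, True, True]
```

The third input (300 000) does not restart the countdown, and the confinement triggered by the second input lasts exactly four weeks before the policy evaluates the infection count again.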


#### Question 2.a: Implement Pr. Russo’s Policy
Implement Pr. Russo’s Policy as a Python class (we recommend that you subclass the `Agent` abstract class provided with the project files, as demonstrated in the tutorial notebook). Run the epidemic simulation for one episode using Pr. Russo’s Policy to pick actions and produce four plots:

1. A plot of variables 
$\quad s^{[w]}_{total} \quad e^{[w]}_{total} \quad i^{[w]}_{total} \quad r^{[w]}_{total} \quad d^{[w]}_{total} \quad$
over time, where time is measured in weeks and all the variables share the y axis scaling.

2. A plot of variables 
$\quad i^{[w]}_{total} \quad d^{[w]}_{total} \quad$
over time, where time is measured in weeks and all the variables share the y axis scaling.

3. A set of plots of variables 
$\quad i^{[w]}_{city} \quad d^{[w]}_{city} \quad$
over time, where time is measured in weeks (one subplot per-city, variables share the y-scaling per-city).

4. A plot of the action taken by the policy over time (whether the policy chooses to confine or not).
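For the fourth plot, the weekly confinement status can also be read back from the logged `info.action` dictionaries instead of the agent's return values, so the plot reflects the action the environment logged. A sketch, assuming `info.action` uses the same `confinement` key as the action dictionary and using stand-in log objects; `confinement_flags` and `confinement_periods` are hypothetical helpers. Grouping the flags into periods makes it easy to check that each confinement lasts four weeks; note that two confinements that follow each other immediately merge into one period.

```python
from types import SimpleNamespace
from typing import Any, List, Tuple

def confinement_flags(history: List[Any]) -> List[int]:
    """1 for each week whose logged action dict reports confinement, else 0."""
    return [int(bool(h.action["confinement"])) for h in history]

def confinement_periods(flags: List[int]) -> List[Tuple[int, int]]:
    """Group consecutive confined weeks into (first_week, n_weeks) pairs, weeks counted from 1."""
    periods, start = [], None
    for week, flag in enumerate(flags, start=1):
        if flag and start is None:
            start = week                     # a confinement begins
        elif not flag and start is not None:
            periods.append((start, week - start))
            start = None                     # the confinement ended last week
    if start is not None:                    # still confined at the end of the episode
        periods.append((start, len(flags) + 1 - start))
    return periods

# Hypothetical logs: confinement in weeks 2-5 and 8-9
demo = [SimpleNamespace(action={"confinement": c}) for c in
        [False, True, True, True, True, False, False, True, True]]
flags = confinement_flags(demo)
print(flags)                       # [0, 1, 1, 1, 1, 0, 0, 1, 1]
print(confinement_periods(flags))  # [(2, 4), (8, 2)]
```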

**Discuss how the epidemic simulation responds to Pr. Russo’s Policy (focus on how it differs from the unmitigated scenario).**


In [19]:
# Implement Pr. Russo's policy as a custom agent
class RussoPolicyAgent(Agent):
    def __init__(self, env, threshold=200_000, confinement_duration=4, verbose=False):
        super().__init__(env)
        self.threshold = threshold                        # infected people (country-wide)
        self.confinement_duration = confinement_duration  # weeks
        self.confinement_countdown = 0                    # confinement weeks left
        self.verbose = verbose

    def load_model(self, loadpath):
        pass

    def save_model(self, savepath):
        pass

    def optimize_model(self):
        return 0

    def act(self, obs: torch.Tensor, info: Log) -> int:
        # Use the raw count from the log: the observation is scaled per city by the preprocessor
        total_infected = info.total.infected
        # Infections are only evaluated when no confinement is running
        if self.confinement_countdown == 0 and total_infected > self.threshold:
            self.confinement_countdown = self.confinement_duration
            if self.verbose:
                print(f'Confinement triggered at {total_infected} infected')
        if self.confinement_countdown > 0:
            # Keep confining every week until the countdown reaches zero
            self.confinement_countdown -= 1
            if self.verbose:
                print(f'Confining: {self.confinement_countdown} more week(s) after this one')
            return ACTION_CONFINE
        if self.verbose:
            print(f'No confinement ({total_infected} infected)')
        return ACTION_NULL

    def reset(self):
        self.confinement_countdown = 0

In [20]:
verbose = True

env = Env(
    dyn,
    action_space=None,
    observation_space=None,
    action_preprocessor=action_preprocessor,
    observation_preprocessor=observation_preprocessor,
)

# Run the epidemic simulation with the Russo policy agent
obs, info = env.reset(seed=1)
agent = RussoPolicyAgent(env, verbose=verbose)
agent.reset()

num_weeks = 30
history = []
actions = []

for week in range(num_weeks):
    # Decide from the state at the end of the previous week, then simulate one week
    action = agent.act(obs, info)
    obs, _, _, info = env.step(action=action)
    if verbose:
        obs_3d = obs.squeeze(0)  # (channel, city, day); channels: infected, dead, confined
        print(f'Infected (% of pop.) per city (rows) and day (columns): \n{obs_3d[0, :, :]}')
        print(f'Infected, dead, confined for the first city, per day: \n{obs_3d[:, 0, :]}')
        print(f'Infected, dead, confined per city, first day of the week: \n{obs_3d[:, :, 0]}')
        print(f'Week {week+1}: {action}\n-------------------------------------------------')

    actions.append(action)
    history.append(info)


Confinement countdown: 0 weeks left
Infected: 9309
Shape of the observation (infected ppl for each city and day of the week): 
tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.9682, 0.9082, 2.1940, 1.6199, 1.8544, 1.6927, 4.2501],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]])
Shape of the observation (infected, dead, exposed for the first city for every day of the week): 
tensor([[0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0.]])
Shape of the observation (infected, dead, exposed for each city f

In [21]:
# Arrange the data into lists
suceptible_total = [data.total.suceptible for data in history]
exposed_total = [data.total.exposed for data in history]
infected_total = [data.total.infected for data in history]
recovered_total = [data.total.recovered for data in history]
dead_total = [data.total.dead for data in history]

weeks = list(range(1, num_weeks + 1))

fig1 = go.Figure()
fig1.add_trace(go.Scatter(x=weeks, y=suceptible_total, mode='lines', name='Susceptible (Total)'))
fig1.add_trace(go.Scatter(x=weeks, y=exposed_total, mode='lines', name='Exposed (Total)'))
fig1.add_trace(go.Scatter(x=weeks, y=infected_total, mode='lines', name='Infected (Total)'))
fig1.add_trace(go.Scatter(x=weeks, y=recovered_total, mode='lines', name='Recovered (Total)'))
fig1.add_trace(go.Scatter(x=weeks, y=dead_total, mode='lines', name='Dead (Total)'))
fig1.update_layout(title='Plot 1: SEIRD variables over time (Total)', xaxis_title='Week', yaxis_title='Individuals')
fig1.show()

In [22]:
fig2 = go.Figure()
fig2.add_trace(go.Scatter(x=weeks, y=infected_total, mode='lines', name='Infected (Total)'))
fig2.add_trace(go.Scatter(x=weeks, y=dead_total, mode='lines', name='Dead (Total)'))
fig2.update_layout(title='Plot 2: Infected and Dead over time (Total)', xaxis_title='Week', yaxis_title='Individuals')
fig2.show()

In [23]:
num_cities = len(history[0].city)
cities = list(history[0].city.keys())

fig3 = make_subplots(rows=num_cities, cols=1, shared_xaxes=True, subplot_titles=[
                     f'{cities[i]}' for i in range(num_cities)])

for city_index in range(num_cities):
    city_infected = [
        data.city[cities[city_index]].infected for data in history]
    city_dead = [data.city[cities[city_index]].dead for data in history]

    fig3.add_trace(go.Scatter(x=weeks, y=city_infected, mode='lines',
                   name=f'{cities[city_index]} Infected'), row=city_index+1, col=1)
    fig3.add_trace(go.Scatter(x=weeks, y=city_dead, mode='lines',
                   name=f'{cities[city_index]} Dead'), row=city_index+1, col=1)

fig3.update_layout(height=200 * num_cities, width=800,
                   title_text='Plot 3: Infected and Deceased per City')
fig3.show()


In [24]:
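# Plot the actions taken by the agent (1 = confinement, 0 = no action) as a step function
fig4 = go.Figure()
fig4.add_trace(go.Scatter(x=weeks, y=actions, mode='lines', line_shape='hv', name='Actions'))
fig4.update_layout(title='Plot 4: Actions taken by the agent',
                   xaxis_title='Week', yaxis_title='Action (1 = confine)')
fig4.show()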