Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 31 additions & 2 deletions clt_toolkit/base_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -1152,6 +1152,8 @@ def __init__(self,

for model in self.subpop_models.values():
model.metapop_model = self

self.run_input_checks()

def __getattr__(self, name):
"""
Expand All @@ -1163,6 +1165,16 @@ def __getattr__(self, name):
return self.subpop_models[name]
else:
raise AttributeError(f"{type(self).__name__!r} object has no attribute {name!r}")

def run_input_checks(self) -> None:
"""
Run input checks to ensure that the provided inputs are valid.
Subclasses can override this method to add additional checks.
If inputs don't make sense we raise a MetapopModelError, and
in some cases only a warning is issued.
"""

pass

def modify_simulation_settings(self,
updates_dict: dict):
Expand Down Expand Up @@ -1241,7 +1253,6 @@ def simulate_until_day(self,
self.apply_inter_subpop_updates()

for subpop_model in self.subpop_models.values():

save_daily_history = subpop_model.simulation_settings.save_daily_history
timesteps_per_day = subpop_model.simulation_settings.timesteps_per_day

Expand Down Expand Up @@ -1433,6 +1444,8 @@ def __init__(self,
self.state.schedules = self.schedules

self.params = updated_dataclass(self.params, {"total_pop_age_risk": self.compute_total_pop_age_risk()})

self.run_input_checks()

def __getattr__(self, name):
"""
Expand Down Expand Up @@ -1512,6 +1525,19 @@ def get_start_real_date(self):
print("Error: The date format should be YYYY-MM-DD.")

return start_real_date

def run_input_checks(self) -> None:
"""
Run input checks to ensure that the provided inputs are valid.
Subclasses can override this method to add additional checks.
If inputs don't make sense we raise a SubpopModelError, and
in some cases only a warning is issued.
"""

# Check that all compartments have non-negative initial values
for compartment_name, compartment in self.compartments.items():
if np.any(compartment.init_val < 0):
raise SubpopModelError(f"Compartment '{compartment_name}' has negative initial values.")

@abstractmethod
def create_compartments(self) -> sc.objdict[str, Compartment]:
Expand Down Expand Up @@ -1805,6 +1831,8 @@ def update_compartments(self) -> None:

# By construction (using binomial/multinomial with or without taylor expansion),
# more individuals cannot leave the compartment than are in the compartment
## TODO check whether the following reason is still valid: a flooring function
# was added to transition variables when using Poisson distributed transitions
# However, for Poisson any for ANY deterministic version, it is possible
# to have more individuals leaving the compartment than are in the compartment,
# and hence negative-valued compartments
Expand All @@ -1813,7 +1841,8 @@ def update_compartments(self) -> None:
# allows us to take derivatives in the torch implementation)
# The syntax is janky here -- we want everything as an array, but
# we need to pass a tensor to the torch functional
if "deterministic" in self.simulation_settings.transition_type:
if ("deterministic" in self.simulation_settings.transition_type) and \
(self.simulation_settings.use_deterministic_softplus):
compartment.current_val = \
np.array(torch.nn.functional.softplus(torch.tensor(compartment.current_val)))

Expand Down
6 changes: 6 additions & 0 deletions clt_toolkit/base_data_structures.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,11 @@ class SimulationSettings:
valid value must be from `TransitionTypes`, specifying
the probability distribution of transitions between
compartments.
use_deterministic_softplus (bool):
If the transition type used is deterministic this determines
whether we use a softplus function once compartment values are
updated. If true this matches the behavior of the torch
implementation of the model, if false true zeros are used.
start_real_date (str):
actual date in string format "YYYY-MM-DD" that aligns with the
beginning of the simulation.
Expand All @@ -58,6 +63,7 @@ class SimulationSettings:

timesteps_per_day: int = 7
transition_type: str = TransitionTypes.BINOM
use_deterministic_softplus: bool = False
start_real_date: str = "2024-10-31"
save_daily_history: bool = True
transition_variables_to_save: tuple = ()
Expand Down
138 changes: 138 additions & 0 deletions clt_toolkit/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,48 @@ def plot_metapop_epi_metrics(metapop_model: MetapopModel,
for ix, (subpop_name, subpop_model) in enumerate(metapop_model.subpop_models.items()):
plot_subpop_epi_metrics(subpop_model, axes[ix])

@plot_subpop_decorator
def plot_subpop_epi_metrics_justM(subpop_model: SubpopModel,
ax: matplotlib.axes.Axes = None):
"""
Plots EpiMetric history for a single subpopulation model on the given axis.

Args:
subpop_model (SubpopModel):
Subpopulation model containing compartments.
ax (matplotlib.axes.Axes):
Matplotlib axis to plot on.
"""

# Compute summed history values for each age-risk group
history_vals_list = [np.average(age_risk_group_entry) for
age_risk_group_entry in subpop_model.epi_metrics.M.history_vals_list]

# Plot data with a label
ax.plot(history_vals_list, label="M", alpha=0.6)

# Set axis title and labels
ax.set_title(f"M")
ax.set_xlabel("Days")
ax.set_ylabel("Epi Metric Value")
ax.legend()


@plot_metapop_decorator
def plot_metapop_epi_metrics_justM(metapop_model: MetapopModel,
axes: matplotlib.axes.Axes):
"""
Plots the EpiMetric data for a metapopulation model.

Args:
metapop_model (MetapopModel):
Metapopulation model containing compartments.
axes (matplotlib.axes.Axes):
Matplotlib axes to plot on.
"""

for ix, (subpop_name, subpop_model) in enumerate(metapop_model.subpop_models.items()):
plot_subpop_epi_metrics_justM(subpop_model, axes[ix])

@plot_subpop_decorator
def plot_subpop_total_infected_deaths(subpop_model: SubpopModel,
Expand Down Expand Up @@ -175,6 +217,51 @@ def plot_metapop_total_infected_deaths(metapop_model: MetapopModel,
plot_subpop_total_infected_deaths(subpop_model, axes[ix])


@plot_subpop_decorator
def plot_subpop_total_infected(subpop_model: SubpopModel,
ax: matplotlib.axes.Axes = None):
"""
Plots data for a single subpopulation model on the given axis.

Args:
subpop_model (SubpopModel):
Subpopulation model containing compartments.
ax (matplotlib.axes.Axes):
Matplotlib axis to plot on.
"""

infected_compartment_names = [name for name in subpop_model.compartments.keys() if
"I" in name or "H" in name]

infected_compartments_history = [subpop_model.compartments[compartment_name].history_vals_list
for compartment_name in infected_compartment_names]

total_infected = np.sum(np.asarray(infected_compartments_history), axis=(0, 2, 3))

ax.plot(total_infected, label="Total infected", alpha=0.6)

ax.set_title(f"{subpop_model.name}")
ax.set_xlabel("Days")
ax.set_ylabel("Number of individuals")
ax.legend()

@plot_metapop_decorator
def plot_metapop_total_infected(metapop_model: MetapopModel,
axes: matplotlib.axes.Axes):
"""
Plots the total infected (IP+IS+IA) data for a metapopulation model.

Args:
metapop_model (MetapopModel):
Metapopulation model containing compartments.
axes (matplotlib.axes.Axes):
Matplotlib axes to plot on.
"""

# Iterate over subpop models and plot
for ix, (subpop_name, subpop_model) in enumerate(metapop_model.subpop_models.items()):
plot_subpop_total_infected(subpop_model, axes[ix])

@plot_subpop_decorator
def plot_subpop_basic_compartment_history(subpop_model: SubpopModel,
ax: matplotlib.axes.Axes = None):
Expand Down Expand Up @@ -219,3 +306,54 @@ def plot_metapop_basic_compartment_history(metapop_model: MetapopModel,
# Iterate over subpop models and plot
for ix, (subpop_name, subpop_model) in enumerate(metapop_model.subpop_models.items()):
plot_subpop_basic_compartment_history(subpop_model, axes[ix])

@plot_subpop_decorator
def plot_subpop_TransitionVariable(subpop_model: SubpopModel,
ax: matplotlib.axes.Axes = None):
"""
Plots the values for a given transition variable for a subpopulation model.

Args:
subpop_model (SubpopModel):
Subpopulation model containing transition variables.
axes (matplotlib.axes.Axes):
Matplotlib axes to plot on.
"""
#transition_history = subpop_model.transition_variables.R_to_S.history_vals_list
transition_history = np.array(subpop_model.transition_variables.ISH_to_HR.history_vals_list) + \
np.array(subpop_model.transition_variables.ISH_to_HD.history_vals_list)

#transition_history is AxR matrix, so need to sum over all entries
#total_infected = np.sum(np.asarray(infected_compartments_history), axis=(0, 2, 3))
total = [np.sum(age_risk_group_entry)
for age_risk_group_entry
in transition_history]

# Aggregate to daily values if needed
timesteps_per_day = subpop_model.simulation_settings.timesteps_per_day
if timesteps_per_day > 1:
total = np.array(total).reshape(-1, timesteps_per_day).sum(axis=1)

#ax.plot(total, label="R to S", alpha=0.6)
ax.plot(total, label="ISH to HR and HD", alpha=0.6)

ax.set_title(f"{subpop_model.name}")
ax.set_xlabel("Days")
ax.set_ylabel("Number of individuals")
ax.legend()

@plot_metapop_decorator
def plot_metapop_TransitionVariable(metapop_model: MetapopModel,
axes: matplotlib.axes.Axes):
"""
Plots the TransitionVariable for a metapopulation model.

Args:
metapop_model (MetapopModel):
Metapopulation model containing compartments.
axes (matplotlib.axes.Axes):
Matplotlib axes to plot on.
"""

for ix, (subpop_name, subpop_model) in enumerate(metapop_model.subpop_models.items()):
plot_subpop_TransitionVariable(subpop_model, axes[ix])
13 changes: 13 additions & 0 deletions developer_notes.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,20 @@
# Journal

## 2025 12 11
Added input checks for subpop and metapop models.
We check that humidity, vaccination, contact matrix, and initial compartment values are non-negative. All values at zero are possible but wouldn't make sense.
For transition rates we need strictly positive values.
For vaccination rates we check whether cumulative vaccination rates in each age-risk group are not exceeding 100% in any 365-day period. This only issues a warning.
The mobility matrix (or travel_proportions) should have rows that sum to 1: this ensures people either travel to another subpopulation or stay in their home location.

A new parameter called use_deterministic_softplus is added to the simulation settings. If the object oriented model is run with deterministic transitions this can be used to prevent softplus values instead of zeros in compartments, which leads to strange behaviors when epidemics occur in populations without any exposure.

Small fixes were made to the travel model equations in the file `flu_travel_functions.py`.

## 2025 11 17 - Adding ghost compartments
Updated website notation and made code updates in a lot of places.


# For future developers (from LP)

Technical notes
Expand Down
Loading