Skip to content

Commit

Permalink
Modify model to allow for simpler models (SIR, no age groups)
Browse files Browse the repository at this point in the history
  • Loading branch information
jdehning committed Apr 21, 2021
1 parent 6570541 commit f6108a7
Show file tree
Hide file tree
Showing 8 changed files with 344 additions and 112 deletions.
3 changes: 2 additions & 1 deletion covid19_npis/model/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .disease_spread import (
InfectionModel,
InfectionModel_SIR,
construct_generation_interval,
InfectionModel_unrolled,
construct_E_0_t,
Expand All @@ -9,7 +10,7 @@

from .likelihood import studentT_likelihood

from .reproduction_number import construct_R_t, construct_R_0
from .reproduction_number import construct_R_t, construct_R_0, construct_lambda_0

from .utils import (
convolution_with_fixed_kernel,
Expand Down
246 changes: 188 additions & 58 deletions covid19_npis/model/disease_spread.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,66 +417,66 @@ def construct_C(
name, modelParams, mean_C=-2.0, sigma_C=1, sigma_country=0.5, sigma_age=0.5
):

# C_country_sigma = yield HalfStudentT(
# df=4,
# name=f"{name}_country_sigma",
# scale=sigma_country,
# conditionally_independent=True,
# event_stack=(1, 1),
# )
# C_age_sigma = yield HalfStudentT(
# df=4,
# name=f"{name}_age_sigma",
# scale=sigma_age,
# conditionally_independent=True,
# event_stack=(1, 1),
# )
#
# Delta_C_country = (
# yield Normal(
# name=f"Delta_{name}_country",
# loc=0,
# scale=1,
# conditionally_independent=True,
# event_stack=(modelParams.num_countries, 1),
# shape_label=("country", None),
# )
# ) * C_country_sigma
#
# Delta_C_age = (
# yield Normal(
# name=f"Delta_{name}_age",
# loc=0,
# scale=1,
# conditionally_independent=True,
# event_stack=(
# 1,
# modelParams.num_age_groups * (modelParams.num_age_groups - 1) // 2,
# ),
# shape_label=(None, "age groups cross terms"),
# )
# ) * C_age_sigma
#
# Base_C = (
# yield Normal(
# name=f"Base_{name}",
# loc=0,
# scale=sigma_C,
# conditionally_independent=True,
# event_stack=(1, 1),
# )
# ) + mean_C
#
# C_array = Base_C + Delta_C_age + Delta_C_country

C_array = mean_C * np.ones(
(
modelParams.num_countries,
modelParams.num_age_groups * (modelParams.num_age_groups - 1) // 2,
),
dtype="float32",
C_country_sigma = yield HalfStudentT(
df=4,
name=f"{name}_country_sigma",
scale=sigma_country,
conditionally_independent=True,
event_stack=(1, 1),
)
C_age_sigma = yield HalfStudentT(
df=4,
name=f"{name}_age_sigma",
scale=sigma_age,
conditionally_independent=True,
event_stack=(1, 1),
)

Delta_C_country = (
yield Normal(
name=f"Delta_{name}_country",
loc=0,
scale=1,
conditionally_independent=True,
event_stack=(modelParams.num_countries, 1),
shape_label=("country", None),
)
) * C_country_sigma

Delta_C_age = (
yield Normal(
name=f"Delta_{name}_age",
loc=0,
scale=1,
conditionally_independent=True,
event_stack=(
1,
modelParams.num_age_groups * (modelParams.num_age_groups - 1) // 2,
),
shape_label=(None, "age groups cross terms"),
)
) * C_age_sigma

Base_C = (
yield Normal(
name=f"Base_{name}",
loc=0,
scale=sigma_C,
conditionally_independent=True,
event_stack=(1, 1),
)
) + mean_C

C_array = Base_C + Delta_C_age + Delta_C_country

# C_array = mean_C * np.ones(
# (
# modelParams.num_countries,
# modelParams.num_age_groups * (modelParams.num_age_groups - 1) // 2,
# ),
# dtype="float32",
# )

C_array = tf.math.sigmoid(C_array)
C_array = tf.clip_by_value(
C_array, 0.001, 0.99
Expand Down Expand Up @@ -632,6 +632,136 @@ def loop_body(params, elems):
return daily_infections_final # batch_dims x time x country x age


def InfectionModel_SIR(N, I_0, lambda_t, C, recov_rate=1 / 8, h_t=None):
r"""
This function combines a variety of different steps:
#. Converts the given :math:`E_0` values to an exponential distributed initial :math:`E_{0_t}` with an
length of :math:`l` this can be seen in :py:func:`_construct_E_0_t`.
#. Calculates :math:`R_{eff}` for each time step using the given contact matrix :math:`C`:
.. math::
R_{diag} &= \text{diag}(\sqrt{R}) \\
R_{eff} &= R_{diag} \cdot C \cdot R_{diag}
#. Calculates the :math:`\tilde{I}` arrays i.e. new infectious for each age group and
country, with the efficient reproduction matrix :math:`R_{eff}`, the susceptible pool
:math:`S`, the population size :math:`N` and the generation interval :math:`g(\tau)`.
This is done recursive for every time step.
.. math::
\tilde{I}(t) &= \frac{S(t)}{N} \cdot R_{eff} \cdot \sum_{\tau=0}^{t} \tilde{I}(t-1-\tau) g(\tau) \\
S(t) &= S(t-1) - \tilde{I}(t-1)
Parameters
----------
N:
Total population per country
|shape| country, age_group
I_0_t:
Initial number of infectious.
|shape| batch_dims, country, age_group
lambda_t:
spreading rate matrix.
|shape| time, batch_dims, country, age_group
C:
inter-age-group Contact-Matrix (see 8)
|shape| country, age_group, age_group
recov_rate:
recovery rate
|shape| batch_dims, country, age_group
h_t:
eventual external input
|shape| time, batch_dims, country, age_group
Returns
-------
:
Sample from distribution of new, daily cases
"""

# @tf.function(autograph=False)

# For robustness of inference
# R_t = tf.clip_by_value(R_t, 0.5, 7)
# R_t = tf.clip_by_norm(R_t, 100, axes=0)

def loop_body(params, elems):
# Unpack a:
# Old E_next is E_lastv now
lambda_, h = elems
_, I_last, S_t = params # Internal state

# Internal state
f = S_t / N

# log.debug(f"I_t {I_t}")

# Calculate effective lambda_t [country,age_group] from Contact-Matrix C [country,age_group,age_group]
lambda_sqrt = tf.math.sqrt(lambda_ + 1e-7)
lambda_diag = tf.linalg.diag(lambda_sqrt)
lambda_eff = tf.einsum(
"...ij,...ik,...kl->...il", lambda_diag, C, lambda_diag
) # Effective growth number

# log.debug(f"infectious: {infectious}")
# log.debug(f"R_eff:\n{R_eff}")
# log.debug(f"f:\n{f}")
# log.debug(f"h:\n{h}")

# Calculate new infections
infections = tf.einsum("...ci,...cij,...cj->...cj", I_last, lambda_eff, f) + h
infections = tf.clip_by_value(infections, 0.01, 1e9) # for robustness

# Calculate new infectious pool
I_new = infections + I_last - recov_rate * I_last

return infections, I_new, S_t

# Number of days that we look into the past for our convolution

S_initial = N - I_0

len_added = 1

lambda_t_for_loop = lambda_t[len_added:]
if not h_t:
h_t_for_loop = tf.zeros_like(lambda_t_for_loop)
else:
h_t_for_loop = h_t[len_added:]
# Initial susceptible population = total - infected

""" Calculate time evolution of new, daily infections
as well as time evolution of susceptibles
as lists
"""

initial = (
tf.zeros_like(S_initial),
I_0,
S_initial,
)
out = tf.scan(
fn=loop_body, elems=(lambda_t_for_loop, h_t_for_loop), initializer=initial
)
daily_infections_final = out[0]
daily_infections_final = tf.concat(
[I_0[tf.newaxis, ...], daily_infections_final], axis=0
)

# Transpose tensor in order to have batch dim before time dim
daily_infections_final = tf.einsum("t...ca->...tca", daily_infections_final)

log.debug(f"daily_infections_final:\n{daily_infections_final}")
log.debug(
f"daily_infections_final sum:\n{tf.reduce_sum(daily_infections_final, axis=-3)}"
)
daily_infections_final = tf.clip_by_value(daily_infections_final, 1e-6, 1e6)

return daily_infections_final # batch_dims x time x country x age


def InfectionModel_unrolled(N, E_0, R_t, C, g_p):
r"""
This function unrolls the loop. It compiling time is slower (about 10 minutes)
Expand Down
2 changes: 1 addition & 1 deletion covid19_npis/model/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def other_init(self, *args, **kwargs):

def init_with_softplus_transform(self, *args, **kwargs):
kwargs["validate_args"] = False
if "transform" not in kwargs.keys() or True:
if "transform" not in kwargs.keys():
kwargs["transform"] = transformations.Exp_SinhArcsinh()
super(self.__class__, self).__init__(*args, **kwargs)

Expand Down
17 changes: 10 additions & 7 deletions covid19_npis/model/number_of_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -807,13 +807,16 @@ def construct_testing_state(

""" Correlate with cholsky and multivariant normal distribution
"""
Sigma = yield LKJCholesky(
name="Sigma_cholesky",
dimension=4,
concentration=2.0, # eta
# validate_args=True,
transform=transformations.CorrelationCholesky(),
conditionally_independent=True,
Sigma = (
yield LKJCholesky(
name="Sigma_cholesky",
dimension=4,
concentration=2.0, # eta
# validate_args=True,
transform=transformations.CorrelationCholesky(),
conditionally_independent=True,
)
+ 1e-5
)
Sigma = tf.einsum(
"...ij,...i->...ij", # todo look at i,j
Expand Down

0 comments on commit f6108a7

Please sign in to comment.