Skip to content

Commit

Permalink
prepare for pull request
Browse files Browse the repository at this point in the history
  • Loading branch information
woxxel committed May 5, 2021
1 parent eaaef95 commit 78be8f0
Show file tree
Hide file tree
Showing 5 changed files with 78 additions and 79 deletions.
44 changes: 1 addition & 43 deletions covid19_npis/model/disease_spread.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,6 @@ def construct_E_0_t(
"""
E_0_t_mean = [None for _ in range(len_gen_interv_kernel - 1, -1, -1)]
R_inv_transposed = tf.transpose(R_inv, perm=perm_forw)
# log.info(f'mean_test_delay: {mean_test_delay}')
for i in range(len_gen_interv_kernel - 1, -1, -1):
# R = tf.gather(R_t_rescaled, i_sim_begin_list + i, axis=-3, batch_dims=1,))
# E_t = tf.linalg.matvec(R_eff_inv, E_t)
Expand All @@ -134,16 +133,12 @@ def construct_E_0_t(
0
] # A little complicated expression, because tensorflow doesn't allow advanced numpy indexing

# log.info(f'R_current: {1/R_current}')
E_t = R_current * E_t
log.debug(f"i, E_t:{i}\n{E_t}")
E_0_t_mean[i] = E_t
E_0_t_mean = tf.stack(E_0_t_mean, axis=-3)
E_0_t_mean = tf.clip_by_value(E_0_t_mean, 1e-5, 1e6)
log.debug(f"E_0_t_mean:\n{E_0_t_mean}")
# log.info(f"E_0_t_mean:\n{E_0_t_mean.shape}")
# log.info(f"E_0_t_mean (young):\n{E_0_t_mean[:30,0,0]}")
# log.info(f"E_0_t_mean (old):\n{E_0_t_mean[:30,0,3]}")

E_0_diff_base = yield Normal(
name="E_0_diff_base",
Expand All @@ -156,7 +151,6 @@ def construct_E_0_t(
E_0_base = E_0_t_mean[..., 0:1, :, :] * tf.exp(E_0_diff_base)
E_0_mean_diff = E_0_t_mean[..., 1:, :, :] - E_0_t_mean[..., :-1, :, :]

# log.info(f"E_0_mean_diff shape:\n{E_0_mean_diff.shape}")
E_0_diff_add = yield Normal(
name="E_0_diff_add",
loc=0.0,
Expand All @@ -165,7 +159,6 @@ def construct_E_0_t(
event_stack=tuple(E_0_mean_diff.shape[-3:]),
)
E_0_base_add = E_0_mean_diff * tf.exp(E_0_diff_add)
# log.info(f"E_0_base_add shape:\n{E_0_base_add.shape}")
log.debug(f"E_0_base:\n{E_0_base}")
log.debug(f"E_0_base_add:\n{E_0_base_add}")
log.debug(f"R_t:\n{R_t.shape}")
Expand All @@ -178,8 +171,6 @@ def construct_E_0_t(
"...kca->k...ca", E_0_t_rand
) # Now: shape: len_gen_interv_kernel x batch_dims x countries x age_groups

# log.info(f"E_0_t_rand:\n{E_0_t_rand}")
# log.info(f"E_0_t_rand shape:\n{E_0_t_rand.shape}")
log.debug(f"E_0_t_rand:\n{E_0_t_rand}")
E_0_t = []
batch_shape = R_t.shape[1:-2]
Expand All @@ -202,10 +193,7 @@ def construct_E_0_t(
axis=0,
)
)
# log.info(f"E_0_t_rand shape:\n{E_0_t_rand.shape}")
E_0_t = tf.concat(E_0_t, axis=-2)
# log.info(f'E_0_t:\n{tf.reduce_sum(E_0_t,axis=0)}')
# log.info(f'E_0_t shape:\n{E_0_t.shape}')

return E_0_t

Expand Down Expand Up @@ -432,13 +420,8 @@ def construct_C(
value=tf.eye(modelParams.num_age_groups,batch_shape=[modelParams.num_countries]),
shape_label=("country", "age_group_i", "age_group_j"),
)
log.info(f'C_matrix (shape): \n{C_matrix.shape}')
return C_matrix
# log.info('only one age group identified - not using contact matrix')
else:
log.info(sigma_C)
log.info(sigma_country)
log.info(sigma_age)
C_country_sigma = yield HalfNormal(
name=f"{name}_country_sigma",
scale=sigma_country,
Expand Down Expand Up @@ -486,9 +469,7 @@ def construct_C(
event_stack=(1, 1),
)
) + mean_C
log.info(f'C country sigma(shape): \n{C_country_sigma.shape}')
log.info(f'Delta C country(shape): \n{Delta_C_country.shape}')
log.info(f'Base C (shape): \n{Base_C.shape}')

C_array = Base_C + Delta_C_age + Delta_C_country
C_array = tf.math.sigmoid(C_array)
C_array = tf.clip_by_value(
Expand All @@ -499,8 +480,6 @@ def construct_C(
transf_array = lambda arr: normalize_matrix(
_subdiagonal_array_to_matrix(arr, size) + tf.linalg.eye(size, dtype=arr.dtype)
)
# log.info(f'C array: \n{C_array}')
log.info(f'C array (shape): \n{C_array.shape}')

yield Deterministic(
name=f"{name}_mean",
Expand All @@ -515,8 +494,6 @@ def construct_C(
value=C_matrix,
shape_label=("country", "age_group_i", "age_group_j"),
)
# log.info(f'C_matrix: \n{C_matrix}')
log.info(f'C_matrix (shape): \n{C_matrix.shape}')
return C_matrix


Expand Down Expand Up @@ -566,10 +543,6 @@ def InfectionModel(N, E_0_t, R_t, C, gen_kernel):
Sample from distribution of new, daily cases
"""

# log.info(f'E_0_t\n{E_0_t}')
# log.info(f'R_t\n{R_t}')
# log.info(f'C\n{C}')

# @tf.function(autograph=False)

# For robustness of inference
Expand All @@ -585,16 +558,13 @@ def loop_body(params, elems):
# Internal state
f = S_t / N

# log.info(f'i:\n{i}')
# Convolution:

# log.debug(f"E_t {E_t}")
# Calc "infectious" people, weighted by serial_p (country x age_group)

# log.info(f'E_lastv:\n{E_lastv}')
E_lastv_noNaN = tf.where(tf.math.is_nan(E_lastv),tf.zeros(E_lastv.shape),E_lastv) # replacing NaN values -> 0
infectious = tf.einsum("t...ca,...t->...ca", E_lastv_noNaN, gen_kernel) # Convolution
# log.info(f'infectious:\n{infectious}')


# Calculate effective R_t [country,age_group] from Contact-Matrix C [country,age_group,age_group]
Expand All @@ -610,7 +580,6 @@ def loop_body(params, elems):
# log.debug(f"h:\n{h}")

# Calculate new infections
# log.info(f'R_eff:\n{R_eff}')
new = tf.einsum("...ci,...cij,...cj->...cj", infectious, R_eff, f) + h
new = tf.clip_by_value(new, 0, 1e9)
# log.debug(f"new:\n{new}") # kernel_time,batch,country,age_group
Expand All @@ -624,20 +593,12 @@ def loop_body(params, elems):

# Number of days that we look into the past for our convolution
len_gen_interv_kernel = gen_kernel.shape[-1]
# log.info(f'sum E\n{tf.reduce_sum(E_0_t, axis=0)}')
# log.info(f'N\n{N}')
# log.info(f'len gen interval\n{len_gen_interv_kernel}')
# log.info(f'isnan\n{tf.math.is_nan(E_0_t)}')
# log.info(f'isnan shape\n{tf.math.is_nan(E_0_t).shape}')
# log.info(f'zeros\n{tf.zeros(E_0_t).shape}')
# E_0_t_noNaN = tf.where(tf.math.is_nan(E_0_t),tf.zeros(E_0_t),E_0_t)
# S_initial = N - tf.reduce_sum(E_0_t_noNaN, axis=0)
S_initial = N - tf.reduce_sum(E_0_t, axis=0)
# log.info(f'S_initial\n{S_initial}')

R_t_for_loop = R_t[len_gen_interv_kernel:]
h_t_for_loop = E_0_t[len_gen_interv_kernel:]
# log.info(f'h_t_for_loop\n{h_t_for_loop[:30,...]}')
# Initial susceptible population = total - infected

""" Calculate time evolution of new, daily infections
Expand Down Expand Up @@ -665,9 +626,6 @@ def loop_body(params, elems):
)
daily_infections_final = tf.clip_by_value(daily_infections_final, 1e-6, 1e6)

# log.info(f'daily infections\n{daily_infections_final}')
# log.info(f'daily infections [shape]\n{daily_infections_final.shape}')

return daily_infections_final # batch_dims x time x country x age


Expand Down
12 changes: 0 additions & 12 deletions covid19_npis/model/likelihood.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,25 +40,17 @@ def studentT_likelihood(modelParams, pos_tests, total_tests, deaths):
"""

# log.info(f'pos_tests:\n{pos_tests}')
# log.info(f'total_tests:\n{total_tests}')
# log.info(f'deaths_tests:\n{deaths}')

likelihood = yield _studentT_positive_tests(modelParams, pos_tests)
# log.info(f'likelihood\n{likelihood.shape}')

if modelParams.data_summary["files"]["/tests.csv"]:
likelihood_total_tests = yield _studentT_total_tests(modelParams, total_tests)
# log.info(f'likelihood total tests\n{likelihood_total_tests.shape}')
likelihood = tf.concat([likelihood, likelihood_total_tests], axis=-1)

if modelParams.data_summary["files"]["/deaths.csv"]:
likelihood_deaths = yield _studentT_deaths(modelParams, deaths)
# log.info(f'likelihood deaths\n{likelihood_total_tests.shape}')
likelihood = tf.concat([likelihood, likelihood_total_tests], axis=-1)

log.debug(f"likelihood:\n{likelihood}")
# log.info(f"likelihood (all):\n{likelihood}")
return likelihood


Expand Down Expand Up @@ -112,8 +104,6 @@ def _studentT_positive_tests(modelParams, pos_tests):
observed_str = index_mask(modelParams.pos_tests_data_tensor, modelParams.data_stratified_mask) # could also be done in mP
observed_sum = index_mask(modelParams.pos_tests_total_data_tensor, modelParams.data_summarized_mask) # could also be done in mP
observed_masked = tf.concat([observed_str,observed_sum],axis=-1)
# log.info(f'loc masked\n{loc_masked}')
# log.info(f'observed masked\n{observed_masked}')

likelihood = yield StudentT(
name="likelihood_pos_tests",
Expand Down Expand Up @@ -160,9 +150,7 @@ def _studentT_total_tests(modelParams, total_tests):
# Sadly we do not have age strata for the total performed test. We sum over the
# age groups to get a value for all ages. We can add an exception later if we find
# data for that.
# log.info(f'total tests\n{total_tests.shape}')
total_tests_without_age = tf.reduce_sum(total_tests, axis=-1)
# log.info(f"total tests w/o age\n{total_tests_without_age.shape}")

# Scale of the likelihood sigma for each country
sigma = yield HalfCauchy(
Expand Down
4 changes: 0 additions & 4 deletions covid19_npis/model/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,9 +507,6 @@ def convolution_with_map(data, kernel, modelParams):
)
# Transpose to get the time dimension to the back

log.info(data_shift.shape)
log.info(kernel.shape)

if len(data.shape) == 4:
output_shape = (
data.shape[0],
Expand All @@ -532,7 +529,6 @@ def convolution_with_map(data, kernel, modelParams):
dtype="float32",
fn_output_signature=tf.TensorSpec(shape=output_shape, dtype="float32"),
)
log.info(convolution.shape)
convolution = tf.einsum("t...ca->...tca", convolution)
return convolution

Expand Down
4 changes: 2 additions & 2 deletions covid19_npis/modelParams.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ class ModelParams:
def __init__(
self,
countries,
const_contact=True,
R_interval_time=5,
const_contact=True, # set 'true' for a constant contact matrix (without age-group interaction)
R_interval_time=5, # time interval over which the reproduction number is calculated
offset_sim_data=20,
minimal_daily_cases=40,
min_offset_sim_death_data=40,
Expand Down

0 comments on commit 78be8f0

Please sign in to comment.