Skip to content

Commit

Permalink
black
Browse files Browse the repository at this point in the history
  • Loading branch information
jdehning committed May 20, 2021
1 parent 25af94a commit 5971c39
Show file tree
Hide file tree
Showing 10 changed files with 161 additions and 112 deletions.
8 changes: 4 additions & 4 deletions covid19_npis/model/deaths.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ def _calc_Phi_IFR(
shape_label=("country",),
)

n_batches = None if (len(beta.shape)==1) else beta.shape[0]
n_batches = None if (len(beta.shape) == 1) else beta.shape[0]

# for robustness, clip at about 5 sigmas
alpha = tf.clip_by_value(alpha, 0.1, 0.14)
Expand Down Expand Up @@ -266,8 +266,8 @@ def _calc_Phi_IFR(
for c, country in enumerate(modelParams.countries):
phi_c = []

if modelParams.data_summary['age_groups_summarized'][c]:
age_group_c = modelParams.data_summary['age_groups_ref']
if modelParams.data_summary["age_groups_summarized"][c]:
age_group_c = modelParams.data_summary["age_groups_ref"]
else:
age_group_c = country.age_groups

Expand All @@ -280,7 +280,7 @@ def _calc_Phi_IFR(
log.debug(f"phi_a\n{phi_a.shape}")
phi_c.append(phi_a)
# else:
# phi_c.append(tf.constant(np.NaN,shape=(n_batches,) if n_batches else ()))
# phi_c.append(tf.constant(np.NaN,shape=(n_batches,) if n_batches else ()))
phi.append(phi_c)
log.debug(f"phi\n{tf.convert_to_tensor(phi).shape}")

Expand Down
37 changes: 25 additions & 12 deletions covid19_npis/model/disease_spread.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def construct_E_0_t(

# # eigvals, _ = tf.linalg.eigh(R_t[..., i_data_begin, :, :])
# # largest_eigval = eigvals[-1]
R_t_rescaled = (R_t+1e-5) ** (1. / (modelParams._R_interval_time+1e-3))
R_t_rescaled = (R_t + 1e-5) ** (1.0 / (modelParams._R_interval_time + 1e-3))
R_inv = 1 / (R_t_rescaled + 1e-3)
R_inv = tf.clip_by_value(R_inv, clip_value_min=0.7, clip_value_max=1.2)
"""
Expand All @@ -81,7 +81,14 @@ def construct_E_0_t(
avg_cases_begin = []
for c in range(data.shape[1]):
avg_cases_begin.append(
tf.reduce_mean(data[i_data_begin_list[c] : i_data_begin_list[c] + modelParams._R_interval_time, c], axis=0)
tf.reduce_mean(
data[
i_data_begin_list[c] : i_data_begin_list[c]
+ modelParams._R_interval_time,
c,
],
axis=0,
)
)
# avg_cases_begin = np.array(avg_cases_begin)
E_t = tf.stack(avg_cases_begin)
Expand Down Expand Up @@ -422,7 +429,9 @@ def construct_C(
# len_batch_shape = len(pos_tests.shape) - 3
C_matrix = yield Deterministic(
name=f"{name}",
value=tf.eye(modelParams.num_age_groups,batch_shape=[modelParams.num_countries]),
value=tf.eye(
modelParams.num_age_groups, batch_shape=[modelParams.num_countries]
),
shape_label=("country", "age_group_i", "age_group_j"),
)
return C_matrix
Expand Down Expand Up @@ -461,10 +470,10 @@ def construct_C(
conditionally_independent=True,
event_stack=(
1,
modelParams.num_age_groups * (modelParams.num_age_groups - 1) // 2,
),
modelParams.num_age_groups * (modelParams.num_age_groups - 1) // 2,
),
shape_label=(None, "age groups cross terms"),
)
)
) * C_age_sigma

Base_C = (
Expand All @@ -480,12 +489,13 @@ def construct_C(
C_array = Base_C + Delta_C_age + Delta_C_country
C_array = tf.math.sigmoid(C_array)
C_array = tf.clip_by_value(
C_array, 0.001, 0.99
C_array, 0.001, 0.99
) # ensures off diagonal terms are smaller than diagonal terms

size = modelParams.num_age_groups
transf_array = lambda arr: normalize_matrix(
_subdiagonal_array_to_matrix(arr, size) + tf.linalg.eye(size, dtype=arr.dtype)
_subdiagonal_array_to_matrix(arr, size)
+ tf.linalg.eye(size, dtype=arr.dtype)
)

yield Deterministic(
Expand Down Expand Up @@ -570,9 +580,12 @@ def loop_body(params, elems):
# log.debug(f"E_t {E_t}")
# Calc "infectious" people, weighted by serial_p (country x age_group)

E_lastv_noNaN = tf.where(tf.math.is_nan(E_lastv),tf.zeros(E_lastv.shape),E_lastv) # replacing NaN values -> 0
infectious = tf.einsum("t...ca,...t->...ca", E_lastv_noNaN, gen_kernel) # Convolution

E_lastv_noNaN = tf.where(
tf.math.is_nan(E_lastv), tf.zeros(E_lastv.shape), E_lastv
) # replacing NaN values -> 0
infectious = tf.einsum(
"t...ca,...t->...ca", E_lastv_noNaN, gen_kernel
) # Convolution

# Calculate effective R_t [country,age_group] from Contact-Matrix C [country,age_group,age_group]
R_sqrt = tf.math.sqrt(R + 1e-7)
Expand Down Expand Up @@ -617,7 +630,7 @@ def loop_body(params, elems):
initial = (
tf.zeros(S_initial.shape, dtype=S_initial.dtype),
E_0_t[:len_gen_interv_kernel],
S_initial
S_initial,
)
out = tf.scan(fn=loop_body, elems=(R_t_for_loop, h_t_for_loop), initializer=initial)
daily_infections_final = out[0]
Expand Down
1 change: 0 additions & 1 deletion covid19_npis/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@ def main_model(modelParams):
C = yield construct_C(name="C", modelParams=modelParams)
log.debug(f"C:\n{C}")


""" # Create generation interval g:
"""
len_gen_interv_kernel = 12
Expand Down
3 changes: 2 additions & 1 deletion covid19_npis/model/number_of_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,8 @@ def weekly_modulation(name, modelParams, cases):

t = modelParams.get_weekdays() # get array with weekdays
f = (1 - weight) * (
1 - tf.math.abs(
1
- tf.math.abs(
tf.math.sin(tf.reshape(t, (-1, 1, 1)) / 7 * tf.constant(np.pi) + offset / 2)
)
)
Expand Down
2 changes: 1 addition & 1 deletion covid19_npis/model/reproduction_number.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,7 +386,7 @@ def gamma(d_i_c_p, l_i_sign):

# Get the sigmoid and multiply it with our gamma tensor
sigmoid = tf.math.sigmoid(
tf.einsum("...i,...icpt->...icpt", 4.0 / (l_i_sign + 1e-3), (t - d_i_c_p))
tf.einsum("...i,...icpt->...icpt", 4.0 / (l_i_sign + 1e-3), (t - d_i_c_p))
)
gamma_i_c_p = tf.einsum(
"...icpt,icp->...icpt", sigmoid, modelParams.gamma_data_tensor,
Expand Down
86 changes: 51 additions & 35 deletions covid19_npis/modelParams.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class ModelParams:
def __init__(
self,
countries,
const_contact=True, # set 'true' for a constant contact matrix (without age-group interaction)
const_contact=True, # set 'true' for a constant contact matrix (without age-group interaction)
R_interval_time=5, # time interval over which the reproduction number is calculated
offset_sim_data=20,
minimal_daily_cases=40,
Expand Down Expand Up @@ -114,8 +114,8 @@ def join_dataframes(key, check_dict, attribute_name):
else:
df = getattr(country, attribute_name)
# if accumulate:
# df_total[country.name] = getattr(country, attribute_name).sum(axis=1)
return df#, df_total
# df_total[country.name] = getattr(country, attribute_name).sum(axis=1)
return df # , df_total

# sort countries alphabetically to have consistent indexes
c_sort = np.argsort([c.name for c in countries])
Expand Down Expand Up @@ -156,7 +156,7 @@ def join_dataframes(key, check_dict, attribute_name):
self._dataframe_interventions = join_dataframes(
key="/interventions.csv",
check_dict=check,
attribute_name="data_interventions"
attribute_name="data_interventions",
)

""" Update data summary
Expand All @@ -166,7 +166,6 @@ def join_dataframes(key, check_dict, attribute_name):

# self._adjust_stratification(attributes=['_dataframe_new_cases'])


""" Calculate positive tests data tensor (tensorflow)
Set data tensor, replaces values smaller than 40 by nans.
"""
Expand All @@ -182,7 +181,7 @@ def join_dataframes(key, check_dict, attribute_name):
"""
self.deaths_data_tensor = self._dataframe_deaths # Uses setter below!

self._set_data_mask() # prepare data masks for use in likelihood computation
self._set_data_mask() # prepare data masks for use in likelihood computation

""" # Update intervetions data tensor
"""
Expand Down Expand Up @@ -213,19 +212,21 @@ def _update_data_summary(self):
country_name = country.name
# age_groups = count
data["countries"].append(country_name)
### added
### added
age_groups = self.pos_tests_dataframe[country_name].columns
if len(age_groups)>1:
if len(data['age_groups']):
assert len(data['age_groups'])==len(age_groups), "data with different number of age groups provided - please provide either similiar age stratification, or summarized data"
if len(age_groups) > 1:
if len(data["age_groups"]):
assert len(data["age_groups"]) == len(
age_groups
), "data with different number of age groups provided - please provide either similiar age stratification, or summarized data"
else:
data['age_groups'] = list(age_groups)
data['age_groups_ref'] = country.age_groups
data['age_groups_summarized'].append(False)
data["age_groups"] = list(age_groups)
data["age_groups_ref"] = country.age_groups
data["age_groups_summarized"].append(False)
else:
data['age_groups_summarized'].append(True)
data['age_groups_summarized'] = np.array(data['age_groups_summarized'])
# data["age_group_data"][country_name] = list(self.pos_tests_dataframe[country_name].columns)
data["age_groups_summarized"].append(True)
data["age_groups_summarized"] = np.array(data["age_groups_summarized"])
# data["age_group_data"][country_name] = list(self.pos_tests_dataframe[country_name].columns)
# Create age group list dynamic from data dataframe
# for age_group_name in self.pos_tests_dataframe.columns.get_level_values(
# level="age_group"
Expand Down Expand Up @@ -376,27 +377,36 @@ def pos_tests_data_tensor(self, df):
"""

# create tensor with stratified (provided or artificial) case data for all countries
new_cases = np.zeros((self.pos_tests_dataframe.shape[0],0,self.num_age_groups))
for i,c in enumerate(self.countries):
new_cases = np.zeros(
(self.pos_tests_dataframe.shape[0], 0, self.num_age_groups)
)
for i, c in enumerate(self.countries):
new_cases_tmp = self.pos_tests_dataframe[c.name].to_numpy()
if self.data_summary['age_groups_summarized'][i]: ## write existing data into array
new_cases_tmp = np.repeat(new_cases_tmp/self.num_age_groups,self.num_age_groups,axis=1) ## could be further refined according to demographics - but not suuuper important

new_cases = np.concatenate([new_cases,new_cases_tmp[:,np.newaxis,:]],axis=1)
if self.data_summary["age_groups_summarized"][
i
]: ## write existing data into array
new_cases_tmp = np.repeat(
new_cases_tmp / self.num_age_groups, self.num_age_groups, axis=1
) ## could be further refined according to demographics - but not suuuper important

new_cases = np.concatenate(
[new_cases, new_cases_tmp[:, np.newaxis, :]], axis=1
)

# prepend data with zeros for simulation offset
new_cases = np.concatenate(
[
np.zeros((self._offset_sim_data, self.num_countries, self.num_age_groups)),
np.zeros(
(self._offset_sim_data, self.num_countries, self.num_age_groups)
),
new_cases,
], axis=0
],
axis=0,
)

i_data_begin_list = []
for c in range(new_cases.shape[1]):
mask = (
np.nansum(new_cases[:, c, :], axis=-1) > self._minimal_daily_cases
)
mask = np.nansum(new_cases[:, c, :], axis=-1) > self._minimal_daily_cases
if mask.sum() == 0: # [False,False,False]
i_data_begin = len(mask) - 1
else:
Expand All @@ -409,8 +419,10 @@ def pos_tests_data_tensor(self, df):
for c, i in enumerate(self.indices_begin_data):
new_cases[:i, c, :] = np.nan

self._tensor_pos_tests = tf.constant(new_cases,dtype=self.dtype)
self._tensor_pos_tests_total = tf.constant(new_cases.sum(axis=-1),dtype=self.dtype)
self._tensor_pos_tests = tf.constant(new_cases, dtype=self.dtype)
self._tensor_pos_tests_total = tf.constant(
new_cases.sum(axis=-1), dtype=self.dtype
)

# self._tensor_pos_tests = tf.constant(new_cases_tensor, dtype=self.dtype)

Expand Down Expand Up @@ -548,18 +560,22 @@ def deaths_data_tensor(self, df):

self._tensor_deaths = tf.constant(deaths_tensor, dtype=self.dtype)



# ------------------------------------------------------------------------------ #
# masks
# ------------------------------------------------------------------------------ #
def _set_data_mask(self):

self.data_stratified_mask = np.argwhere(
(~np.isnan(self.pos_tests_data_tensor) & ~self.data_summary['age_groups_summarized'][np.newaxis,:,np.newaxis]).flatten()
(
~np.isnan(self.pos_tests_data_tensor)
& ~self.data_summary["age_groups_summarized"][np.newaxis, :, np.newaxis]
).flatten()
)
self.data_summarized_mask = np.argwhere(
(~np.isnan(self.pos_tests_total_data_tensor) & self.data_summary['age_groups_summarized'][np.newaxis,:]).flatten()
(
~np.isnan(self.pos_tests_total_data_tensor)
& self.data_summary["age_groups_summarized"][np.newaxis, :]
).flatten()
)

# ------------------------------------------------------------------------------ #
Expand All @@ -585,9 +601,9 @@ def N_data_tensor(self):
d_c = []

# Get age groups from country config
if self.data_summary['age_groups_summarized'][c]:
if self.data_summary["age_groups_summarized"][c]:
# If age group not present in this country data
age_groups_c = self.data_summary['age_groups_ref']
age_groups_c = self.data_summary["age_groups_ref"]
else:
age_groups_c = country.age_groups

Expand Down
22 changes: 8 additions & 14 deletions covid19_npis/plot/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,13 +99,7 @@ def helper_plot(posterior, prior, name_str):
if plot_age_groups_together and ("age_group" in df.index.names):
unq_age = df.index.get_level_values("age_group").unique()
fig, ax = plt.subplots(
len(unq_age),
1,
figsize=(
2.2,
2.2 * len(unq_age),
),
squeeze=False,
len(unq_age), 1, figsize=(2.2, 2.2 * len(unq_age),), squeeze=False,
)
ax = ax[:, 0]
for i, ag in enumerate(unq_age):
Expand All @@ -119,7 +113,7 @@ def helper_plot(posterior, prior, name_str):
else:
prior_t = None

ax_now = ax[i] if len(unq_age)>1 else ax
ax_now = ax[i] if len(unq_age) > 1 else ax
# Plot
_distribution(
array_posterior=posterior_t,
Expand Down Expand Up @@ -319,13 +313,13 @@ def helper_plot(posterior, prior, name_str):
squeeze=False,
)
for x, row in enumerate(rows):
ax_row = ax[x] if len(rows)>1 else ax
ax_row = ax[x] if len(rows) > 1 else ax
if posterior is not None:
posterior_x = posterior.xs(row, level=posterior.index.names[-1])
if prior is not None:
prior_x = prior.xs(row, level=prior.index.names[-1])
for y, col in enumerate(cols):
ax_col = ax_row[y] if len(cols)>1 else ax_row
ax_col = ax_row[y] if len(cols) > 1 else ax_row
if posterior is not None:
posterior_xy = posterior_x.xs(col).to_numpy().flatten()
else:
Expand All @@ -347,12 +341,12 @@ def helper_plot(posterior, prior, name_str):

# Create titles for the axes
for x, row in enumerate(rows):
ax_row = ax[x] if len(rows)>1 else ax
ax_col = ax_row[0] if len(cols)>1 else ax_row
ax_row = ax[x] if len(rows) > 1 else ax
ax_col = ax_row[0] if len(cols) > 1 else ax_row
ax_col.set_ylabel(row)
for y, col in enumerate(cols):
ax_col = ax[x] if len(cols)>1 else ax
ax_row = ax_col[0] if len(rows)>1 else ax_col
ax_col = ax[x] if len(cols) > 1 else ax
ax_row = ax_col[0] if len(rows) > 1 else ax_col
ax_row.set_title(col)

# Suptitle
Expand Down

0 comments on commit 5971c39

Please sign in to comment.