Skip to content

Commit

Permalink
Added forgotten delta d_p
Browse files Browse the repository at this point in the history
Fixed small plot bugs
  • Loading branch information
semohr committed May 21, 2021
1 parent afddc90 commit de9a2f7
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 34 deletions.
29 changes: 26 additions & 3 deletions covid19_npis/model/reproduction_number.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,13 @@ def _create_distributions(modelParams):
transform=transformations.SoftPlus(scale=3),
conditionally_independent=True,
)
d_sigma_change_point = HalfStudentT(
df=4,
name="d_sigma_change_point",
scale=0.3,
transform=transformations.SoftPlus(scale=3),
conditionally_independent=True,
)
delta_d_i = Normal(
name="delta_d_i",
loc=0.0,
Expand All @@ -209,6 +216,14 @@ def _create_distributions(modelParams):
shape_label=(None, "country", None),
conditionally_independent=True,
)
delta_d_p = Normal(
name="delta_d_p",
loc=0.0,
scale=1.0,
event_stack=(1, 1, modelParams.gamma_data_tensor.shape[-1]),
shape_label=(None, None, "change_point"),
conditionally_independent=True,
)

# We create a dict here to pass all distributions to another function
distributions = {}
Expand All @@ -221,8 +236,10 @@ def _create_distributions(modelParams):
distributions["delta_l_cross_i"] = delta_l_cross_i
distributions["d_sigma_interv"] = d_sigma_interv
distributions["d_sigma_country"] = d_sigma_country
distributions["d_sigma_change_point"] = d_sigma_change_point
distributions["delta_d_i"] = delta_d_i
distributions["delta_d_c"] = delta_d_c
distributions["delta_d_p"] = delta_d_p
if modelParams.num_age_groups > 1:
distributions["delta_alpha_cross_a"] = delta_alpha_cross_a
distributions["alpha_sigma_a"] = alpha_sigma_a
Expand Down Expand Up @@ -352,20 +369,26 @@ def date():
delta_d_i = yield distributions["delta_d_i"]
d_sigma_interv = yield distributions["d_sigma_interv"]
delta_d_i = tf.einsum( # Multiply distribution by hyperprior
"...ica,...->...ica", delta_d_i, d_sigma_interv
"...icp,...->...icp", delta_d_i, d_sigma_interv
)

delta_d_c = yield distributions["delta_d_c"]
d_sigma_country = yield distributions["d_sigma_country"]
delta_d_c = tf.einsum( # Multiply distribution by hyperprior
"...ica,...->...ica", delta_d_c, d_sigma_country
"...icp,...->...icp", delta_d_c, d_sigma_country
)

delta_d_p = yield distributions["delta_d_p"]
d_sigma_change_point = yield distributions["d_sigma_change_point"]
delta_d_p = tf.einsum( # Multiply distribution by hyperprior
"...icp,...->...icp", delta_d_p, d_sigma_change_point
)
# Get data tensor padded with 0 if the cp does not exist for an intervention/country combo
d_data = (
modelParams.date_data_tensor
) # shape intervention, country, change_points

d_return = d_data + delta_d_i + delta_d_c
d_return = d_data + delta_d_i + delta_d_c + delta_d_p
# Clip by value should be in range of our simulation
d_return = tf.clip_by_value(
d_return, -modelParams.length_sim, modelParams.length_sim
Expand Down
10 changes: 2 additions & 8 deletions covid19_npis/plot/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,10 +99,7 @@ def helper_plot(posterior, prior, name_str):
if plot_age_groups_together and ("age_group" in df.index.names):
unq_age = df.index.get_level_values("age_group").unique()
fig, ax = plt.subplots(
len(unq_age),
1,
figsize=(2.2, 2.2 * len(unq_age),),
squeeze=False
len(unq_age), 1, figsize=(2.2, 2.2 * len(unq_age),), squeeze=False
)
for i, ag in enumerate(unq_age):
# Create pivot table i.e. time on index and draw on columns
Expand Down Expand Up @@ -143,10 +140,7 @@ def helper_plot(posterior, prior, name_str):
return

fig, ax = plt.subplots(
num_rows,
1,
figsize=(2.2, 2.2 * num_rows,),
squeeze=False,
num_rows, 1, figsize=(2.2, 2.2 * num_rows,), squeeze=False,
)
for i, ag in enumerate(
df.index.get_level_values(df.index.names[0]).unique()
Expand Down
39 changes: 19 additions & 20 deletions covid19_npis/plot/time_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,39 +147,38 @@ def recursive_plot(df, name_str, observed=None):

# Remove "_" from name
name_str = name_str[1:]

if plot_age_groups_together and "age_group" in df.index.names:
unq_age = df.index.get_level_values("age_group").unique()

if observed is None or len(observed.columns) > 1:
if (observed is None) or (len(observed.columns) > 1):
fig, a_axes = plt.subplots(
len(unq_age),
1,
figsize=(4, 1.5 + 1.5 * len(unq_age)),
squeeze=False,
)
a_axes = a_axes[:, 0]
for i, ag in enumerate(unq_age):
temp = df.xs(ag, level="age_group")
a_axes = a_axes[:, 0]
for i, ag in enumerate(unq_age):
temp = df.xs(ag, level="age_group")

# Create pivot table i.e. time on index and draw on columns
temp = temp.reset_index().pivot_table(index="time", columns="draw")
# Create pivot table i.e. time on index and draw on columns
temp = temp.reset_index().pivot_table(index="time", columns="draw")

ax_now = a_axes[i] if len(unq_age) > 1 else a_axes
# Plot data
_timeseries(temp.index, temp.to_numpy(), what="model", ax=ax_now)
ax_now = a_axes[i] if len(unq_age) > 1 else a_axes
# Plot data
_timeseries(temp.index, temp.to_numpy(), what="model", ax=ax_now)

# Plot observed
if observed is not None:
_timeseries(
observed[ag].index,
observed[ag].to_numpy(),
what="data",
ax=ax_now,
)
# Plot observed
if observed is not None:
_timeseries(
observed[ag].index,
observed[ag].to_numpy(),
what="data",
ax=ax_now,
)

# Set title for axis
ax_now.set_title(ag)
# Set title for axis
ax_now.set_title(ag)
else:
# plot summarized data
fig, a_axes = plt.subplots(2, 1, figsize=(4, 1.5 * 2),)
Expand Down
9 changes: 6 additions & 3 deletions scripts/plot_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
#
# @Author: Sebastian B. Mohr
# @Created: 2020-12-18 14:40:45
# @Last Modified: 2021-01-27 12:54:39
# @Last Modified: 2021-05-21 16:01:36
# ------------------------------------------------------------------------------ #

# Get trace fp
Expand Down Expand Up @@ -101,8 +101,11 @@ def dir_path(string):
# ------------------------------------------------------------------------------ #
# Load pickled trace
# ------------------------------------------------------------------------------ #
# modelParams, trace = covid19_npis.utils.load_trace_zarr(args.file)
modelParams, trace = covid19_npis.utils.load_trace(args.file)
try:
modelParams, trace = covid19_npis.utils.load_trace(args.file)
except Exception as e:
modelParams, trace = covid19_npis.utils.load_trace_zarr(args.file)

modelParams._R_interval_time = 5
modelParams._const_contact = False

Expand Down

0 comments on commit de9a2f7

Please sign in to comment.