Skip to content

Commit

Permalink
Added some comments
Browse files Browse the repository at this point in the history
  • Loading branch information
semohr committed May 19, 2021
1 parent f6108a7 commit 0179ce1
Showing 1 changed file with 58 additions and 8 deletions.
66 changes: 58 additions & 8 deletions covid19_npis/model/approximate_posterior.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ def build_iaf(values_iaf_dict, order_list, values_exclude_dict=None):
size_iaf = sum([int(np.prod(tensor.shape)) for tensor in values_iaf_dict.values()])

size_splits = [int(np.prod(v.shape)) for v in values_iaf_dict.values()]

# Create a list of tensors
iaf_split = tfb.Split(size_splits, axis=-1)

iaf_restructure = tfb.Restructure(init_iaf_struct)
Expand All @@ -82,6 +84,9 @@ def build_iaf(values_iaf_dict, order_list, values_exclude_dict=None):
hidden_units=[size_iaf, size_iaf],
input_order=order,
activation="elu",
kernel_initializer=tf.keras.initializers.GlorotNormal(
seed=None
),
)
)
)
Expand Down Expand Up @@ -131,34 +136,75 @@ def build_iaf(values_iaf_dict, order_list, values_exclude_dict=None):


def build_approximate_posterior(model):
order_list = ["left-to-right", "right-to-left", "left-to-right"]
order_list_short = ["right-to-left", "left-to-right"]
"""
Parameters
----------
model : pymc4.model
TODO
----
Description
"""

"""
Get sampling state and the names of all variables i.e. all distributions from our model
"""
_, state = pm.evaluate_model_transformed(model)
state, _ = state.as_sampling_state()

"""
Retrieve the name of all transformed distributions
"""
values_dict = dict(state.all_unobserved_values)
transformed_names = list(values_dict.keys())

noise_vars = ("main_model|noise_R", "main_model|noise_R_age")
values_without_noise = {k: v for k, v in values_dict.items() if k not in noise_vars}
values_with_noise = {k: v for k, v in values_dict.items() if k in noise_vars}
"""
Filter all noise distributions we filter by name:
At the moment these 4 names get filtered:
main_model|noise_R,
main_model|noise_R_age,
main_model|__Exp-SinhTanh_noise_R_sigma_age,
main_model|__Exp-SinhTanh_noise_R_sigma",
"""
values_without_noise = {
key: value for key, value in values_dict.items() if "noise" not in key
}
values_with_noise = {
key: value for key, value in values_dict.items() if "noise" in key
}

# Note: Why are we doing this split here? Corresponds to the note below
values_except_noise_age = {
k: v for k, v in values_dict.items() if k not in ("main_model|noise_R_age",)
k: v
for k, v in values_dict.items()
if k
not in ("main_model|noise_R_age", "main_model|__Exp-SinhTanh_noise_R_sigma_age")
}
values_noise_age = {
k: v for k, v in values_dict.items() if k in ("main_model|noise_R_age",)
k: v
for k, v in values_dict.items()
if k
in ("main_model|noise_R_age", "main_model|__Exp-SinhTanh_noise_R_sigma_age")
}

"""
Construct joined distribution from a sample of all prior distributions.
(not taking noise into respect)
# Note: Does this correspond to the variational parameters Phi, in the sticking the landing paper?
"""
# Note: Why Normal distribution as base? Shouldn't that depend on the underlying distribution?
normal_base = tfd.JointDistributionNamed(
{
name: tfd.Sample(tfd.Normal(loc=0.0, scale=1.0), sample_shape=tensor.shape)
for name, tensor in values_dict.items()
},
validate_args=False,
name=None,
name="normal_base",
)

# Note: What is this abomination? Can we apply some make-up please?
order_list = ["left-to-right", "right-to-left", "left-to-right"]
order_list_short = ["right-to-left", "left-to-right"]
bijectors_list = []
for vals, orders, vals_exclude in [
(values_without_noise, order_list, values_with_noise),
Expand All @@ -171,5 +217,9 @@ def build_approximate_posterior(model):
bijectors_list.append(build_iaf(vals, orders, vals_exclude))

bijector = tfb.Chain(bijectors_list)

"""We transform our joined distribution with the previously created bijector.
"""
posterior_approx = tfd.TransformedDistribution(normal_base, bijector=bijector)

return posterior_approx, bijector, transformed_names

0 comments on commit 0179ce1

Please sign in to comment.