In [1]:
import sys
import logging
import time
import itertools
import os
import datetime
import functools


"""
Some runtime optimizations for CPU (using tur nodes)
os.environ["OMP_NUM_THREADS"] = 32
tf.config.threading.set_inter_op_parallelism_threads(32)
tf.config.threading.set_intra_op_parallelism_threads(16)
tf.config.set_soft_device_placement(enabled)
tf.config.optimizer.set_jit(
    True
)
"""

import pymc4 as pm
import tensorflow as tf
import tensorflow_probability as tfp
import numpy as np
import time
import os
import matplotlib.pyplot as plt


sys.path.append("../../")
import covid19_npis


In [2]:
""" # Debugging and other snippets
"""

# --- Logging configuration -------------------------------------------------
# Root logger used by every cell of this notebook.
log = logging.getLogger()
# Optional: raise the verbosity (the level has to be set before other modules
# are imported).
# log.setLevel(logging.DEBUG)
covid19_npis.utils.setup_colored_logs()

# Switch off the "parso.python.diff" logger entirely.
parso_diff_logger = logging.getLogger("parso.python.diff")
parso_diff_logger.disabled = True

# Optional: mute Tensorflow warnings.
# logging.getLogger("tensorflow").setLevel(logging.ERROR)

# Emit a warning when Tensorflow executes eagerly.
eager_mode = tf.executing_eagerly()
if eager_mode:
    log.warning("Running in eager mode!")



In [3]:
""" # 1. Data Retrieval
    Load the data for the different countries/regions. For now every country
    has to be listed by hand; this could be automated at some point.

    TODO: maybe move this into a separate file at some point
"""

# Countries to include in this run. Entries that are commented out are
# available but currently disabled.
countries = [
    "Germany",
    "Belgium",
    # "Czechia",
    # "Denmark",
    # "Finland",
    # "Greece",
    # "Italy",
    # "Netherlands",
    "Portugal",
    # "Romania",
    # "Spain",
    # "Sweden",
    # "Switzerland",
]

# Read each country's folder into the project's own `Country` data class.
c = [
    covid19_npis.data.Country(f"../../data/coverage_db/{country_name}")
    for country_name in countries
]


2021-06-03 08:58:08 wollex covid19_npis.data[9753] INFO Loaded data for Germany.
2021-06-03 08:58:09 wollex covid19_npis.data[9753] INFO Loaded data for Belgium.
2021-06-03 08:58:09 wollex covid19_npis.data[9753] INFO Loaded data for Portugal.


In [4]:
# Construct our modelParams from the country data objects in `c`.
modelParams = covid19_npis.ModelParams(countries=c, minimal_daily_deaths=1)
# Alternative (disabled): build the parameters from a data folder instead.
# modelParams = covid19_npis.ModelParams.from_folder("../data/Germany_bundesländer/")

# Define our model; `this_model` is the object that is evaluated and sampled
# by the cells further down.
this_model = covid19_npis.model.model.main_model(modelParams)

# Test shapes, should be all 3:
def print_dist_shapes(st, expected_shape=(3,)):
    """Warn about every log-probability or potential with an unexpected shape.

    For each discrete and continuous distribution in ``st`` the log-probability
    of its stored value is evaluated and its ``shape`` is compared against
    ``expected_shape``. The same comparison is done for ``value.shape`` of every
    potential. Each mismatch is reported through ``log.warning``.

    Parameters
    ----------
    st :
        State object providing ``discrete_distributions`` and
        ``continuous_distributions`` (mappings name -> distribution with a
        ``log_prob`` method), ``all_values`` (mapping name -> value) and
        ``potentials`` (iterable of objects with ``value`` and ``name``).
    expected_shape : optional
        Shape every entry is compared against. Defaults to ``(3,)``, matching
        a model evaluated with ``sample_shape=(3,)``.

    Returns
    -------
    None
    """
    for name, dist in itertools.chain(
        st.discrete_distributions.items(), st.continuous_distributions.items()
    ):
        # Evaluate the log-probability only once and reuse its shape for both
        # the comparison and the warning message.
        log_prob_shape = dist.log_prob(st.all_values[name]).shape
        if log_prob_shape != expected_shape:
            log.warning(f"False shape: {log_prob_shape}, {name}")
    for p in st.potentials:
        if p.value.shape != expected_shape:
            log.warning(f"False shape: {p.value.shape} {p.name}")


# Evaluate the transformed model with a sample shape of (3,) and let
# `print_dist_shapes` warn about every entry whose shape is not (3,).
_, sample_state = pm.evaluate_model_transformed(this_model, sample_shape=(3,))
print_dist_shapes(sample_state)

In [None]:
"""  # 2. MCMC Sampling
"""
num_chains = 6
use_VI = True
jit_compile = True

if use_VI:
    # Variational pre-fit: train a flow-based surrogate posterior. Its bijector
    # (`bijector_list`) and `num_chains` samples drawn from it (`init_state`)
    # are the two results of this cell that the sampling cell passes on to
    # `pm.sample`.
    begin_time = time.time()
    from tensorflow_probability import distributions as tfd
    from tensorflow_probability import bijectors as tfb

    _, state = pm.evaluate_model_transformed(this_model)
    state, _ = state.as_sampling_state()

    # Names (and one value each) of all unobserved variables of the state.
    values_dict = dict(state.all_unobserved_values)
    transformed_names = list(values_dict.keys())

    # Base distribution of the surrogate: one block of independent standard
    # normals per variable, shaped like that variable's value.
    # Open questions kept from the original author:
    #   * Does this correspond to the variational parameters Phi in the
    #     "sticking the landing" paper?
    #   * Why a Normal base? Shouldn't that depend on the underlying
    #     distribution?
    normal_base = tfd.JointDistributionNamed(
        {
            name: tfd.Sample(tfd.Normal(loc=0.0, scale=1.0), sample_shape=tensor.shape)
            for name, tensor in values_dict.items()
        },
        validate_args=False,
        name="normal_base",
    )

    # One entry per flow layer that `build_iaf` is asked to create.
    # NOTE(review): presumably the autoregressive ordering of each layer;
    # confirm against `covid19_npis.model.build_iaf`.
    order_list = ["left-to-right", "right-to-left", "left-to-right"]

    bijector = covid19_npis.model.build_iaf(values_dict, order_list)

    # The surrogate posterior is the base distribution pushed through the
    # bijector created above.
    posterior_approx = tfd.TransformedDistribution(normal_base, bijector=bijector)

    sample_size = 4

    (
        logpfn,
        init_random,
        _deterministics_callback,
        deterministic_names,
        state_,
    ) = pm.mcmc.samplers.build_logp_and_deterministic_functions(
        this_model, num_chains=sample_size, collect_reduced_log_prob=False
    )

    def trace_loss(traceable_quantities):
        """Trace the loss of every step and fail as soon as it is not finite."""
        return tf.debugging.check_numerics(
            traceable_quantities.loss, f"loss not finite: {traceable_quantities.loss}"
        )

    # For eventual debugging:
    # tf.config.run_functions_eagerly(True)
    # tf.debugging.enable_check_numerics(stack_height_limit=50, path_length_limit=50)

    if jit_compile:
        # Run the entire minimization inside a jit-compiled function.
        @tf.function(autograph=False, experimental_compile=True)
        def fit_surrogate_posterior(*args, **kwargs):
            return tfp.vi.fit_surrogate_posterior(*args, **kwargs)

    else:
        fit_surrogate_posterior = tfp.vi.fit_surrogate_posterior

    lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
        initial_learning_rate=0.003,
        decay_steps=5_000,
        decay_rate=0.3,
        staircase=False,
    )

    loss_arr = fit_surrogate_posterior(
        logpfn,
        posterior_approx,
        tf.optimizers.Adam(
            learning_rate=lr_schedule,
            epsilon=1e-7,
            beta_1=0.9,
            beta_2=0.999,
            clipvalue=10.0,
        ),
        100_000,
        convergence_criterion=tfp.optimizer.convergence_criteria.LossNotDecreasing(
            atol=0.02, rtol=None, window_size=2_000, min_num_steps=10_000
        ),
        sample_size=sample_size,
        trainable_variables=None,
        # jit_compile=True,
        variational_loss_fn=functools.partial(
            tfp.vi.monte_carlo_variational_loss,
            discrepancy_fn=tfp.vi.kl_reverse,
            use_reparameterization=True,
        ),
        trace_fn=trace_loss,
    )

    # Strip the trailing run of repeated values from the loss trace.
    # NOTE(review): presumably the trace is padded with the final loss when the
    # convergence criterion stops the optimization early (see
    # `return_full_length_trace` of `tfp.math.minimize`) - confirm for the
    # installed TFP version.
    # Only the trailing run is removed and one copy of the final value is kept;
    # masking every entry equal to the last one would also drop the converged
    # value itself and any earlier entry that happens to match it.
    loss_arr = np.asarray(loss_arr)
    differs_from_last = np.flatnonzero(loss_arr != loss_arr[-1])
    num_valid = (differs_from_last[-1] + 2) if differs_from_last.size else 1
    loss_arr = loss_arr[:num_valid]
    # NOTE(review): the logged quantity is the mean of the last (up to) 1000
    # variational loss values, not a likelihood; the message text and the plot
    # label below are left unchanged.
    log.info(f"Likelihood: {np.mean(loss_arr[-1000:])}")

    _, st = pm.evaluate_model_posterior_predictive(
        this_model, values=posterior_approx.sample(100)
    )
    var_names = list(st.all_values.keys()) + list(st.deterministics_values.keys())
    # For every variable prefer the untransformed value, then the
    # deterministic value and finally the transformed value.
    samples = {
        k: (
            st.untransformed_values[k]
            if k in st.untransformed_values
            else (
                st.deterministics_values[k]
                if k in st.deterministics_values
                else st.transformed_values[k]
            )
        )
        for k in var_names
    }

    # Draw one sample per chain from the surrogate and bring it into list
    # form, ordered like `transformed_names`.
    init_state = posterior_approx.sample(num_chains)
    init_state = [init_state[name] for name in transformed_names]
    # Wrap the dict-based bijector so that it maps list -> dict -> list.
    bijector_to_list = tfb.Restructure(
        list(transformed_names), {name: name for name in transformed_names}
    )
    bijector_list = tfb.Chain(
        [bijector_to_list, bijector, tfb.Invert(bijector_to_list)]
    )
    end_time = time.time()
    log.info("running time: {:.1f}s".format(end_time - begin_time))

    plt.figure()
    plt.plot(loss_arr)
    plt.title("Tuning bijector")
    plt.ylabel('Likelihood')
    plt.xlabel('optimization step')
    plt.show()

else:
    # Without the variational pre-fit neither a bijector nor an initial state
    # is handed to the sampler.
    bijector_list = None
    init_state = None

Instructions for updating:
The `validate_indices` argument has no effect. Indices are always validated on CPU and never validated on GPU.
Instructions for updating:
The `validate_indices` argument has no effect. Indices are always validated on CPU and never validated on GPU.


In [None]:
begin_time = time.time()
log.info("start")

# All sampler settings collected in one place. `bijector_list` and
# `init_state` are defined by the preceding cell (both are None when the
# variational pre-fit is disabled).
sampler_options = dict(
    num_samples=1000,
    num_samples_binning=10,
    burn_in_min=10,
    burn_in=2000,
    use_auto_batching=False,
    num_chains=num_chains,
    xla=False,
    initial_step_size=0.001,
    ratio_tuning_epochs=1.3,
    max_tree_depth=4,
    decay_rate=0.75,
    target_accept_prob=0.75,
    step_size_adaption_per_chain=False,
    bijector=bijector_list,
    init_state=init_state,
    # Further options, currently disabled:
    # num_steps_between_results = 9,
    # state=pm.evaluate_model_transformed(this_model)[1]
    # sampler_type="nuts",
)

trace_tuning, trace = pm.sample(this_model, **sampler_options)

end_time = time.time()
elapsed_seconds = end_time - begin_time
log.info("running time: {:.1f}s".format(elapsed_seconds))

In [None]:
# Diagnostics of the tuning phase, one figure each:
#   1. "step_size" (first entry only)
#   2. "lp" (transposed)
tuning_stats = trace_tuning.sample_stats
for tuning_series in (tuning_stats["step_size"][0], tuning_stats["lp"].T):
    plt.figure()
    plt.plot(tuning_series)
plt.show(block=False)

In [None]:
# Optional: also sample the prior; it is stored alongside the trace and used
# for the kde in the plots.
prior_sample_shape = (500,)
trace_prior = pm.sample_prior_predictive(
    this_model,
    sample_shape=prior_sample_shape,
    use_auto_batching=False,
)

In [None]:
# Target location of the stored trace, named after the current date and hour.
timestamp = datetime.datetime.now().strftime("%y_%m_%d_%H")
fpath = f"./../traces/{timestamp}"

# Save our traces for the plotting script.
# Alternative writer (disabled): covid19_npis.utils.save_trace, called with
# the same arguments.
store = covid19_npis.utils.save_trace_zarr(
    trace,
    modelParams,
    store=fpath,
    trace_prior=trace_prior,
)

In [23]:
# Manual override: point the plotting step at one specific, previously stored
# trace. This replaces the `fpath` assigned in the save cell.
# NOTE(review): this cell was executed out of order (its counter, In [23], is
# higher than the In [22] of the cell after it) and the path is hard-coded -
# confirm that the override is intended before using "Run All".
fpath='./traces/21_05_27_09/21_05_27_09_21'

In [22]:
cd notebooks/

/home/wollex/Data/Science/Corona/covid19_npis_europe/scripts/notebooks


In [None]:
# Run the plotting script on the trace stored at `fpath`; the cell displays
# the value returned by `os.system`.
plot_command = f"python ./../plot_trace.py {fpath}"
os.system(plot_command)