In [1]:
from functools import partial
from pathlib import Path
from turtle import color
from setuptools_scm import meta
import torch
import numpy as np

from hnpe.misc import make_label
from hnpe.posterior import build_flow, IdentityJRNMM
from hnpe.simulator import prior_JRNMM, simulator_JRNMM
from hnpe.summary import summary_JRNMM
from hnpe.viz import get_posterior
from hnpe.eval_sim import plot_pairgrid_with_groundtruth

import matplotlib.pyplot as plt

In [2]:
# Directory (relative to the notebook's working directory) holding the saved
# JR-NMM flow experiments. It is used as a plain string prefix when building
# file names, so it must stay a `str` and keep its trailing slash.
PATH = "saved_experiments/JR-NMM/Flows/"

In [4]:
# --- Experiment configuration ---------------------------------------------
# Recording length in seconds (sampled at SAMPLING_FREQ_HZ, see below).
t_rec = 8
# Whether the simulator produces a single recording; encoded in the case label.
single_rec = False
# Ground-truth parameters (C, mu, sigma, gain) of the observed data.
theta = [135.0, 220.0, 2000.0, 0.0]
# Sampling frequency of the simulated signals, in Hz.
SAMPLING_FREQ_HZ = 128

# setup the parameters for the example.
# NOTE(review): keys are inserted in exactly the original order on purpose:
# make_label(meta_parameters) is built from this dict and may depend on its
# keys and their order; changing them could change the label (and therefore
# the names of saved files).
meta_parameters = {}
# how many extra observations to consider
meta_parameters["n_extra"] = 0
# what kind of summary features to use
meta_parameters["summary"] = 'Fourier'
# the parameters of the ground truth (observed data)
meta_parameters["theta"] = torch.tensor(theta)

# whether to do naive implementation
meta_parameters["naive"] = False

# which example case we are considering here, e.g.
#   PATH + "JRNMM_nextra_00_trec_8naive_False_single_rec_False_C_135.00_..."
# NOTE(review): there is no "_" between the t_rec value and "naive_". It is
# kept verbatim because experiments already saved under PATH were presumably
# written with this exact name; fixing it would require renaming them.
meta_parameters["case"] = PATH + (
    "JRNMM_nextra_{:02}_trec_{}"
    "naive_{}_"
    "single_rec_{}_"
    "C_{:.2f}_"
    "mu_{:.2f}_"
    "sigma_{:.2f}_"
    "gain_{:.2f}"
).format(
    meta_parameters["n_extra"],
    t_rec,
    meta_parameters["naive"],
    single_rec,
    *meta_parameters["theta"],  # C, mu, sigma, gain as 0-d tensors
)

# number of rounds to use in the SNPE procedure
meta_parameters["n_rd"] = 2
# number of simulations per round
meta_parameters["n_sr"] = 50_000
# number of summary features to consider
meta_parameters["n_sf"] = 33
# how many seconds the simulations should have
meta_parameters["t_recording"] = t_rec
# number of time samples per simulation (sampling frequency * duration)
meta_parameters["n_ss"] = int(SAMPLING_FREQ_HZ * meta_parameters["t_recording"])
# label to attach to the SNPE procedure and use for saving files
meta_parameters["label"] = make_label(meta_parameters)
# Device to run on; not referenced in this cell.
device = "cpu"

# set prior distribution for the parameters: (name, low, high) bounds of the
# uniform prior. This list is the single source of truth for the parameter
# names and their order, so `input_parameters` cannot drift from the prior.
prior_bounds = [('C', 10.0, 250.0),
                ('mu', 50.0, 500.0),
                ('sigma', 100.0, 5000.0),
                ('gain', -20.0, +20.0)]
input_parameters = [name for name, _, _ in prior_bounds]
prior = prior_JRNMM(parameters=prior_bounds)

# The ground truth must provide one value per inferred parameter and lie
# inside the prior support, otherwise inference on it is meaningless.
assert len(theta) == len(input_parameters), \
    "theta must have one value per inferred parameter"
assert all(low <= value <= high
           for value, (_, low, high) in zip(theta, prior_bounds)), \
    "ground-truth theta lies outside the prior support"

# choose how to setup the simulator for training. `single_rec` is passed
# through (instead of a hard-coded False) so the simulator always matches the
# "single_rec_*" part of the case label.
simulator = partial(simulator_JRNMM,
                    input_parameters=input_parameters,
                    t_recording=meta_parameters["t_recording"],
                    n_extra=meta_parameters["n_extra"],
                    p_gain=prior,
                    single_recording=single_rec)

# choose how to get the summary features
summary_extractor = summary_JRNMM(n_extra=meta_parameters["n_extra"],
                                  d_embedding=meta_parameters["n_sf"],
                                  n_time_samples=meta_parameters["n_ss"],
                                  type_embedding=meta_parameters["summary"])

# let's use the log power spectral density instead
# (set on the constructed extractor's embedding network after construction)
summary_extractor.embedding.net.logscale = True

# choose a function which creates a neural network density estimator
build_nn_posterior = partial(build_flow,
                             embedding_net=IdentityJRNMM(),
                             naive=meta_parameters["naive"],
                             aggregate=True,
                             z_score_theta=True,
                             z_score_x=True)

# get the posterior for round_=1 of the SNPE procedure
posterior = get_posterior(
    simulator, prior, summary_extractor, build_nn_posterior,
    meta_parameters, round_=1
)


  warn(f"In one-dimensional output space, this flow is limited to Gaussians")


In [5]:
# Display the posterior returned by get_posterior as a quick sanity check: its
# repr lists the density-estimator network, the prior, and the expected
# observation shape (x_shape), here [1, 33, 1], i.e. n_sf = 33 summary features.
posterior

DirectPosterior(
               method_family=snpe,
               net=<a JRNMMFlow_nflows_factorized, see `.net` for details>,
               prior=prior_JRNMM(Uniform(low: torch.Size([4]), high: torch.Size([4])), 1),
               x_shape=torch.Size([1, 33, 1]))
               