In [None]:
import lampe
import numpy as np
import os
import torch
import zuko

import matplotlib.pyplot as plt
import ipywidgets as widgets
import scipy.integrate as integrate

from IPython.display import display
import pickle
from typing import Tuple, Callable, List, Dict
from tqdm.notebook import tqdm_notebook as tqdm
from torch import Tensor

In [None]:
N = 100  # Number of simulations before taking a summary (sample size used by posterior_summaries)
theta_grid = np.linspace(-8, 8, 1000)  # Grid of theta values where the true posteriors are evaluated
N_simu = 50_000  # Number of simulations
M = 100_000  # Number of samples from the MCMC algorithm
warm_up_steps = 20_000  # Burn-in period (steps discarded per MCMC chain)
sigma = 0.01  # In the error model, std of the spike (Normal) component
tau = 0.25  # In the error model, scale of the slab (Cauchy) component
rho = 1 / 2  # In the error model, Bernoulli probability of using the slab component

In [None]:
def posterior_summaries(
    thetas: Tuple,
    x1: float,
    x2: float,
    sigma_2_y: float = 1,
    n_obs: int = None,
):
    r"""Compute the true posterior density :math:`p(\theta \mid x)` on a set of points.

    The observation :math:`x = (x_1, x_2)` summarises ``n_obs`` draws
    :math:`z_i \sim \mathcal{N}(\theta, \sigma_y^2)`: :math:`x_1` is their mean and
    :math:`x_2` their variance. The prior is :math:`\mathcal{N}(0, 5^2)`.
    Since :math:`x_1 \sim \mathcal{N}(\theta, \sigma_y^2 / n)`, the unnormalised
    log-posterior is :math:`-n(\theta - x_1)^2 / (2\sigma_y^2) - \theta^2 / 50`.
    The likelihood of :math:`x_2` does not depend on :math:`\theta`, so it cancels
    in the normalisation and is not evaluated. Evaluating it could only make the
    normalising constant underflow to 0.

    Args:
        thetas: Points where the density is evaluated.
        x1: First dimension of the observation (sample mean).
        x2: Second dimension of the observation (sample variance). Kept for
            interface compatibility; it does not change the normalised density.
        sigma_2_y (optional): Variance of the observed data. Defaults to 1 (no corruption).
        n_obs (optional): Number of draws summarised by the observation.
            Defaults to the module-level constant ``N``.

    Returns:
        List with the posterior density at each entry of ``thetas``.
    """
    n = N if n_obs is None else n_obs
    lik_precision = n / sigma_2_y  # precision of x1 given theta
    prior_precision = 1 / 25  # prior N(0, 5^2)

    def unnormalised_posterior(theta):
        log_lik_x1 = -lik_precision * np.square(theta - x1) / 2
        log_prior = -np.square(theta) / 50
        return np.exp(log_lik_x1 + log_prior)

    # The integrand is a narrow Gaussian bump when n is large: tell the adaptive
    # quadrature where its mode is so the peak cannot be missed.
    mode = x1 * lik_precision / (lik_precision + prior_precision)
    breakpoints = [mode] if -25 < mode < 25 else None
    z_summaries, _ = integrate.quad(unnormalised_posterior, -25, 25, points=breakpoints)

    densities = unnormalised_posterior(np.asarray(thetas, dtype=float)) / z_summaries
    return list(densities)

In [None]:
# Select the GPU when one is available...
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): ...but this line unconditionally overrides the choice above and forces CPU.
# Presumably deliberate: later cells hand tensors to NumPy (e.g. np.square in compute_metrics)
# without calling .cpu() first — confirm before removing the override.
device = "cpu"

In [None]:
class CS:
    r"""Cancer and stromal cell development in 2D space (marked point process).

    The total number of cells :math:`N^c`, the number of unobserved parents
    :math:`N^p` and the number of daughters of each parent :math:`N_i^d` follow

    - :math:`N^c \sim \mathrm{Poisson}(\lambda^c)`
    - :math:`N^p \sim \mathrm{Poisson}(\lambda^p)`
    - :math:`N^d_i \sim \mathrm{Poisson}(\lambda^d), \; i = 1, \ldots, N^p`

    where :math:`\theta = (\lambda^c, \lambda^p, \lambda^d)` are the parameters.
    Cells and parents are uniform on the unit square. The affected radius
    :math:`r_i` is the Euclidean distance from parent :math:`p_i` to its
    :math:`N_i^d`-th nearest cell; cells within this radius become cancer cells,
    all others stay stromal.

    With probability ``miss_specification_param`` the observed tissue is
    misspecified: cancer cells within :math:`0.8\, r_i` of some parent :math:`p_i`
    are removed.

    Data are reduced to the summary statistic
    :math:`x = (N_\text{cancer}, N_\text{stromal}, \text{mean min dist}, \text{max min dist})`,
    the distances being from each stromal cell to its nearest cancer cell.
    """

    def __init__(
        self, N_simu: int, miss_specification_param: float = 0.75, N_obs: int = M
    ) -> None:
        """Instantiate the task: draw a reference theta and simulate train/test data.

        Args:
            N_simu: Number of simulations.
            miss_specification_param (optional): Probability that an observed
                tissue is misspecified. Defaults to 0.75.
            N_obs (optional): Currently unused, kept for interface compatibility.
                Defaults to M.
        """
        self.name = "CS"
        self.prior_c = torch.distributions.Uniform(low=200, high=1500)
        self.prior_p = torch.distributions.Uniform(low=3, high=20)
        self.prior_d = torch.distributions.Uniform(low=10, high=20)
        self.has_post = False
        rate_c_0, rate_p_0, rate_d_0 = self.prior_sample()
        # float() turns each (1,)-shaped prior draw into a scalar -> flat (3,) vector.
        self.theta0 = torch.tensor(
            [float(rate_c_0), float(rate_p_0), float(rate_d_0)]
        ).to(device)
        self.miss_specification_param = miss_specification_param

        self.features_dim = 4
        self.parameters_dim = 3
        self.data, self.true_obs_data = self.simulate_data(N_simu)
        self.data_test, _ = self.simulate_data(N_simu // 10)

        # Observations are standardised with the training-set statistics.
        scale_mean, scale_std = self.data["scale_parameters"]
        self.y0 = self.true_obs_data["y"][0]
        self.y0_l = self.true_obs_data["y"].squeeze()
        self.y0_scaled = (self.y0 - scale_mean) / scale_std
        self.y0_l_scaled = (self.y0_l - scale_mean) / scale_std

    def prior_sample(self, n: int = 1) -> Tuple[Tensor, Tensor, Tensor]:
        """Samples from the priors.

        Args:
            n (optional): Number of samples. Defaults to 1.

        Returns:
            Tuple[Tensor, Tensor, Tensor]: Samples from the 3 priors, each of shape (n,).
        """
        return (
            self.prior_c.sample((n,)).to(device),
            self.prior_p.sample((n,)).to(device),
            self.prior_d.sample((n,)).to(device),
        )

    @staticmethod
    def _simulate_tissue(rate_c, rate_p, rate_d) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        """Simulate one tissue: cell positions, parent positions, cancer mask, necrosis mask."""
        n_cells = int(torch.distributions.Poisson(rate=rate_c).sample().item())
        n_parents = int(torch.distributions.Poisson(rate=rate_p).sample().item())
        # One daughter count per parent.
        n_daughters = (
            torch.distributions.Poisson(rate=rate_d).sample((n_parents,)).to(device)
        )
        unit_square = torch.distributions.Uniform(low=0, high=1)
        cells = unit_square.sample((n_cells, 2)).to(device)
        parents = unit_square.sample((n_parents, 2)).to(device)

        is_cancer = torch.zeros(n_cells, dtype=torch.bool, device=device)
        necrotic = torch.zeros(n_cells, dtype=torch.bool, device=device)
        for parent, n_d in zip(parents, n_daughters):
            k = min(int(n_d.item()), n_cells)
            if k == 0:
                continue  # radius 0: no cell is affected by this parent
            distances = torch.cdist(cells, parent.unsqueeze(0), p=2).squeeze(-1)
            radius = distances.kthvalue(k).values  # distance to the k-th nearest cell
            is_cancer |= distances <= radius
            necrotic |= distances <= 0.8 * radius
        return cells, parents, is_cancer, necrotic

    def _observe(
        self, cells: Tensor, is_cancer: Tensor, necrotic: Tensor
    ) -> Tuple[Tensor, Tensor]:
        """Observed tissue: with prob. ``miss_specification_param`` necrotic cancer cells are removed."""
        if torch.distributions.Bernoulli(self.miss_specification_param).sample():
            keep = ~(is_cancer & necrotic)
            return cells[keep], is_cancer[keep]
        return cells, is_cancer.clone()

    @staticmethod
    def _summarize(cells: Tensor, is_cancer: Tensor):
        """(N_cancer, N_stromal, mean min dist, max min dist), or None without both cell types."""
        stromal = cells[~is_cancer]
        cancer = cells[is_cancer]
        if len(stromal) == 0 or len(cancer) == 0:
            return None
        min_distances = torch.cdist(stromal, cancer, p=2).min(dim=1).values
        return torch.tensor(
            [
                len(cancer),
                len(stromal),
                min_distances.mean().item(),
                min_distances.max().item(),
            ]
        )

    def simulate_data(self, N_simu: int, theta0: Tensor = None) -> Tuple[Dict, Dict]:
        r"""Simulate ``N_simu`` accepted draws of (theta, x) and their observed y.

        ``x`` summarises the full simulated tissue; ``y`` summarises the observed
        tissue, which may be misspecified (see the class docstring). Draws where
        either tissue lacks stromal or cancer cells are rejected and redrawn,
        because the min-distance summaries are undefined there.

        Args:
            N_simu: Number of accepted simulations.
            theta0 (optional): Fixed parameters :math:`(\lambda^c, \lambda^p, \lambda^d)`
                used for every draw. Defaults to None (draw from the prior each time).

        Returns:
            Tuple[Dict, Dict]: ``(simu_data, true_obs)``. ``simu_data`` holds the
            stacked ``theta`` and ``x``, per-draw positions/types, ``scale_parameters``
            (mean, std of x) and ``scaled_x``. ``true_obs`` holds the stacked
            ``theta`` and ``y``, plus the positions/types of the first observed tissue only.
        """
        simu_data = {
            "theta": [],
            "x": [],
            "cell_positions": [],
            "parent_positions": [],
            "cell_types": [],
        }
        true_obs = {
            "theta": [],
            "y": [],
            "cell_positions_obs": [],
            "parent_positions_obs": [],
            "cell_types_obs": [],
        }

        pbar = tqdm(total=N_simu, desc="Simulating data", unit="data")  # Progress bar
        n_simulated = 0
        while n_simulated < N_simu:
            if theta0 is not None:
                rate_c, rate_p, rate_d = theta0
            else:
                rate_c, rate_p, rate_d = self.prior_sample()  # already on `device`

            cells, parents, is_cancer, necrotic = self._simulate_tissue(
                rate_c, rate_p, rate_d
            )
            cells_obs, is_cancer_obs = self._observe(cells, is_cancer, necrotic)

            x = self._summarize(cells, is_cancer)
            y = self._summarize(cells_obs, is_cancer_obs)
            if x is None or y is None:
                continue  # summaries undefined: reject and redraw

            theta = torch.tensor([float(rate_c), float(rate_p), float(rate_d)])
            simu_data["theta"].append(theta)
            simu_data["x"].append(x)
            simu_data["cell_positions"].append(cells)
            simu_data["parent_positions"].append(parents)
            simu_data["cell_types"].append(is_cancer)

            true_obs["theta"].append(theta)
            true_obs["y"].append(y)
            true_obs["cell_positions_obs"].append(cells_obs)
            true_obs["parent_positions_obs"].append(parents)
            true_obs["cell_types_obs"].append(is_cancer_obs)

            n_simulated += 1
            pbar.update(1)
        pbar.close()

        # Only the first observed tissue is kept in full.
        true_obs["cell_positions_obs"] = true_obs["cell_positions_obs"][0]
        true_obs["parent_positions_obs"] = true_obs["parent_positions_obs"][0]
        true_obs["cell_types_obs"] = true_obs["cell_types_obs"][0]
        true_obs["theta"] = torch.stack(true_obs["theta"])
        true_obs["y"] = torch.stack(true_obs["y"])
        simu_data["theta"] = torch.stack(simu_data["theta"])
        simu_data["x"] = torch.stack(simu_data["x"])
        scale_mean, scale_std = simu_data["x"].mean(0), simu_data["x"].std(0)
        simu_data["scale_parameters"] = scale_mean, scale_std
        simu_data["scaled_x"] = (simu_data["x"] - scale_mean) / scale_std
        return simu_data, true_obs

In [None]:
class Gaussian:
    r"""Gaussian task.

    - Simulation: :math:`z_i \sim \mathcal{N}(\theta, 1), i = 1, \ldots, n_z`.
    - TDGP: :math:`z_i \sim \mathcal{N}(\theta, \sigma^2), i = 1, \ldots, n_z`.
    - Feature space: summary statistic
      :math:`x = \left(\mathrm{mean}(z_1, \ldots, z_{n_z}), \mathrm{var}(z_1, \ldots, z_{n_z})\right)`.
    - Prior: :math:`\theta \sim \mathcal{N}(0, 5^2)`.

    NB: observed data are generated as :math:`y = x + \epsilon` with
    :math:`x \sim \mathcal{M}(\theta)` and :math:`\epsilon \sim \mathcal{N}(0, (\sigma_y^2 - 1) I)`.
    """

    def __init__(
        self, N_simu: int, sigma_y_2: float, true_theta: Tensor = torch.zeros(1)
    ) -> None:
        r"""Init the object, including training, test and observed datasets.

        Args:
            N_simu: Total number of simulations.
            sigma_y_2: Corruption, variance of the observed data (must be >= 1).
            true_theta (optional): Parameter generating the observed data. Defaults to 0.
        """
        self.name = "Gaussian"
        self.prior = torch.distributions.Normal(0, 5)
        self.sigma_y_2 = sigma_y_2
        self.features_dim = 2
        self.parameters_dim = 1
        self.true_theta = true_theta
        self.has_post = True
        data = self.generate_data(N_simu, sigma_y_2)
        data_test = self.generate_data(N_simu // 10, sigma_y_2)
        true_obs_data = self.generate_data(10_000, sigma_y_2, true_theta)

        # Observations are standardised with the training-set statistics.
        scale_mean, scale_std = data["scale_parameters"]
        x0 = true_obs_data["x"][0]
        y0 = x0 + true_obs_data["eps"][0]
        self.y0_l = true_obs_data["x"] + true_obs_data["eps"]
        self.data = data
        self.data_test = data_test
        self.x0 = x0.squeeze()
        self.y0 = y0.squeeze()
        self.x0_scaled = (x0 - scale_mean) / scale_std
        self.y0_scaled = ((y0 - scale_mean) / scale_std).squeeze()
        self.y0_l_scaled = (self.y0_l - scale_mean) / scale_std

    def prior_sample(self, n):
        return self.prior.sample((n, 1))

    def generate_data(
        self, N: int, sigma_y_2: float = 1, thetas: Tensor = None, n_z: int = 100
    ):
        r"""Generate ``N`` simulations and their corruptions.

        For each parameter :math:`\theta_i`, draw :math:`n_z` values
        :math:`z \sim \mathcal{N}(\theta_i, 1)` and summarise them as
        :math:`x_i = (\mathrm{mean}(z), \mathrm{var}(z))` (unbiased variance).
        The corruption is :math:`\epsilon_i \sim \mathcal{N}(0, (\sigma_y^2 - 1) I_2)`.

        Args:
            N: Total number of simulations.
            sigma_y_2 (optional): Corruption (variance of the observed data), must
                be >= 1. Defaults to 1 (no corruption, eps is exactly 0).
            thetas (optional): Parameters, either N rows or one reference parameter
                reused for all N simulations. Defaults to N draws from the prior.
            n_z (optional): Number of draws summarised per simulation. Defaults to 100.

        Returns:
            res: dictionary with parameters ``theta`` (N, 1), raw simulations ``x``
            (N, 2), corruptions ``eps`` (N, 2), ``scale_parameters`` (mean, std of x)
            and ``scaled_x``.

        Raises:
            ValueError: if ``sigma_y_2 < 1`` or ``thetas`` has neither 1 nor N rows.
        """
        sigma_eps_2 = sigma_y_2 - 1  # variance of the corruption layer
        if sigma_eps_2 < 0:
            raise ValueError(f"sigma_y_2 must be >= 1, got {sigma_y_2}.")

        if thetas is None:
            thetas = self.prior_sample(N)
        thetas = torch.as_tensor(thetas, dtype=torch.float32).reshape(-1, 1)
        if thetas.shape[0] == 1:
            # A single reference parameter is reused for every simulation.
            thetas = thetas.expand(N, 1)
        if thetas.shape[0] != N:
            raise ValueError(f"Expected 1 or {N} parameters, got {thetas.shape[0]}.")

        res = {"theta": thetas}
        # All simulations at once: z has shape (n_z, N, 1).
        z = torch.distributions.Normal(thetas, 1).sample((n_z,)).to(device)
        res["x"] = torch.cat([z.mean(0), z.var(0)], dim=1)  # (N, 2)

        if sigma_eps_2 == 0:
            # Normal(0, 0) is invalid in torch; no corruption means eps == 0.
            res["eps"] = torch.zeros_like(res["x"])
        else:
            res["eps"] = (
                torch.distributions.Normal(
                    torch.zeros_like(res["x"]), float(np.sqrt(sigma_eps_2))
                )
                .sample()
                .to(device)
            )

        scale_mean, scale_std = res["x"].mean(0), res["x"].std(0)
        res["scale_parameters"] = scale_mean, scale_std
        res["scaled_x"] = (res["x"] - scale_mean) / scale_std
        return res

    def true_post(self, x_star: Tensor, sigma_2_y: float = None) -> Tuple:
        """Get the true posterior given an observation x_star.

        Args:
            x_star: An observation (unscaled mean and variance).
            sigma_2_y (optional): Corruption. Defaults to None (use ``self.sigma_y_2``).

        Returns:
            post: Posterior density evaluated on 1000 points from -8 to 8.
        """
        if sigma_2_y is None:
            sigma_2_y = self.sigma_y_2
        theta_grid_t = np.linspace(-8, 8, 1000)
        return posterior_summaries(theta_grid_t, x_star[0], x_star[1], sigma_2_y)

In [None]:
# Corruption level for the Gaussian task: passed as sigma_y_2, the variance of the
# observed data (the simulator itself draws z with unit variance).
sigma_2_y = 3

In [None]:
def build_nsf(features, context):
    """Builder passed to lampe's NPE: a neural spline flow on `device`.

    Uses 5 transforms with 10 spline bins each.
    """
    nsf = zuko.flows.NSF(features, context, transforms=5, bins=10)
    return nsf.to(device)


def train_flow(
    flow: lampe.inference.NPE,
    loss: Callable[[Tensor, Tensor], float],
    theta: Tensor,
    x: Tensor,
    theta_test: Tensor,
    x_test: Tensor,
    max_epochs: int = 50,
    patience: int = 5,
) -> lampe.inference.NPE:
    """Train an NPE with Adam (lr 5e-3, batch 256) and early stopping.

    Args:
        flow (lampe.inference.NPE): Flow object to be trained.
        loss (Callable[[Tensor, Tensor], float]): Loss function used for training.
        theta (Tensor): Parameters.
        x (Tensor): Observations.
        theta_test (Tensor): Parameters used for early stopping.
        x_test (Tensor): Observations used for early stopping.
        max_epochs (optional): Maximum number of passes over the data. Defaults to 50.
        patience (optional): Stop after this many epochs without improvement of
            the held-out loss. Defaults to 5.

    Returns:
        lampe.inference.NPE: The flow in eval mode, holding the weights with the
        lowest held-out loss seen (the initial weights included).
    """
    import copy

    optimizer = torch.optim.Adam(flow.parameters(), lr=5e-3)
    dataset = lampe.data.JointDataset(theta.to(device), x.to(device))
    loader = lampe.data.DataLoader(dataset, batch_size=256)
    theta_test, x_test = theta_test.to(device), x_test.to(device)

    with torch.no_grad():
        best_loss = loss(theta_test, x_test).item()
    # Early stopping alone would return the last epoch's weights, up to `patience`
    # epochs past the optimum: keep a copy of the best weights instead.
    best_state = copy.deepcopy(flow.state_dict())
    epochs_without_improvement = 0

    flow.train()
    for _ in range(max_epochs):
        for theta_batch, x_batch in loader:
            batch_loss = loss(theta_batch, x_batch.to(device))
            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()

        with torch.no_grad():
            loss_test = loss(theta_test, x_test).item()
        if loss_test < best_loss:
            best_loss = loss_test
            best_state = copy.deepcopy(flow.state_dict())
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                break

    flow.load_state_dict(best_state)
    flow.eval()
    return flow


def create_train_flow(task) -> lampe.inference.NPE:
    """Build a conditional NSF estimator for the task and train it.

    The estimator models theta given the scaled summaries; the task's test
    split drives early stopping.

    Args:
        task: A task object exposing ``data``, ``data_test``, ``parameters_dim``
            and ``features_dim``.

    Returns:
        lampe.inference.NPE: The trained flow.
    """
    estimator = lampe.inference.NPE(
        theta_dim=task.parameters_dim, x_dim=task.features_dim, build=build_nsf
    ).to(device)
    npe_loss = lampe.inference.NPELoss(estimator).to(device)
    train_set, test_set = task.data, task.data_test
    return train_flow(
        estimator,
        npe_loss,
        train_set["theta"],
        train_set["scaled_x"].to(device),
        test_set["theta"],
        test_set["scaled_x"].to(device),
    )

In [None]:
def train_unconditional_flow(
    flow: zuko.flows,
    loss: Callable[[Tensor], float],
    x: Tensor,
    x_test: Tensor,
    max_epochs: int = 50,
    patience: int = 5,
) -> zuko.flows:
    """Train an unconditional flow on x with Adam (lr 1e-2, batch 256) and early stopping.

    Args:
        flow: Flow object to be trained.
        loss: Method computing a scalar loss from a batch of x.
        x: Training data.
        x_test: Test data, used for early stopping.
        max_epochs (optional): Maximum number of passes over the data. Defaults to 50.
        patience (optional): Stop after this many epochs without improvement of
            the held-out loss. Defaults to 5.

    Returns:
        flow: The flow in eval mode, holding the weights with the lowest held-out
        loss seen (the initial weights included).
    """
    import copy

    optimizer = torch.optim.Adam(flow.parameters(), lr=1e-2)
    loader = torch.utils.data.DataLoader(x.to(device), 256)
    x_test = x_test.to(device)

    with torch.no_grad():
        best_loss = loss(x_test).item()
    # Keep the best weights rather than the last epoch's ones.
    best_state = copy.deepcopy(flow.state_dict())
    epochs_without_improvement = 0

    flow.train()
    for _ in range(max_epochs):
        for x_batch in loader:
            batch_loss = loss(x_batch.to(device))
            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()

        with torch.no_grad():
            loss_test = loss(x_test).item()
        if loss_test < best_loss:
            best_loss = loss_test
            best_state = copy.deepcopy(flow.state_dict())
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                break

    flow.load_state_dict(best_state)
    flow.eval()
    return flow


def create_train_unconditional_flow(task) -> zuko.flows:
    """Instantiate an unconditional NAF (neural autoregressive flow) and train it.

    The flow is fit to the task's scaled summaries by maximum likelihood; the
    test split drives early stopping.

    Args:
        task: A task, containing train and test data.

    Returns:
        flow: The trained flow.
    """
    naf = zuko.flows.NAF(features=task.features_dim, context=0).to(device)

    def negative_log_likelihood(batch):
        # Mean negative log-density of the batch under the (context-free) flow.
        return -naf().log_prob(batch).mean()

    return train_unconditional_flow(
        naf,
        negative_log_likelihood,
        task.data["scaled_x"].to(device),
        task.data_test["scaled_x"].to(device),
    )

In [None]:
class MyMCMC:
    r"""Independence Metropolis-Hastings sampler for the denoised summaries :math:`x \mid y_0`.

    Target: :math:`p(x \mid y_0) \propto f(y_0 \mid x)\, q(x)`. Here :math:`q` is the
    unconditional flow ``q_x_NF`` and :math:`f` is the spike-and-slab error model,
    independent across dimensions:

    .. math:: f(y_j \mid x_j) = (1 - \rho)\, \mathcal{N}(y_j; x_j, \sigma^2) + \rho\, \mathrm{Cauchy}(y_j; x_j, \tau)

    Proposals are drawn from a pool of samples of :math:`q`. :math:`q` therefore
    cancels, and the acceptance ratio is :math:`f(y_0 \mid x^*) / f(y_0 \mid x)`.
    """

    def __init__(
        self, y0: Tensor, tau: float, sigma: float, rho: float, q_x_NF: zuko.flows
    ) -> None:
        """Init

        Args:
            y0: Observation, one entry per summary dimension.
            tau: Scale of the slab (Cauchy) distribution.
            sigma: Standard deviation of the spike (Normal) distribution.
            rho: Probability that a dimension follows the slab (is misspecified).
            q_x_NF: Unconditional flow; ``q_x_NF()`` must return a distribution.
        """
        self.y0 = y0
        self.tau = tau
        self.sigma = sigma
        self.rho = rho
        self.q_x_NF = q_x_NF

    def _log_f_per_dim(self, y0: Tensor, x: Tensor) -> Tensor:
        """Log error-model density of each ``y0[j]`` given ``x[:, j]``; shape (batch, D)."""
        x = x.detach()
        y = y0.to(x.device).reshape(1, -1)
        rho = torch.as_tensor(self.rho, dtype=x.dtype, device=x.device)
        sigma_t = torch.as_tensor(self.sigma, dtype=x.dtype, device=x.device)
        tau_t = torch.as_tensor(self.tau, dtype=x.dtype, device=x.device)
        log_spike = torch.distributions.Normal(x, sigma_t).log_prob(y)
        log_slab = torch.distributions.Cauchy(x, tau_t).log_prob(y)
        # log((1 - rho) * spike + rho * slab) in log space: the spike density
        # (sigma = 0.01) underflows to 0 as soon as |y - x| is moderately large.
        return torch.logaddexp(torch.log1p(-rho) + log_spike, torch.log(rho) + log_slab)

    def _log_f(self, x: Tensor) -> Tensor:
        """Joint log error-model density of ``self.y0`` given each row of x; shape (batch,)."""
        return self._log_f_per_dim(self.y0, x).sum(dim=-1)

    def f_density(self, y0: Tensor, x: Tensor) -> Tensor:
        """Error model: marginal spike-and-slab density of y0 given x, per dimension.

        The Bernoulli spike/slab indicator is marginalised out, so the density is
        deterministic.

        Args:
            y0: Observation.
            x: x proposed for the sampler, shape (batch, D).

        Returns:
            res: Density of each ``y0[j]`` given ``x[:, j]``, shape (batch, D).
        """
        return torch.exp(self._log_f_per_dim(y0, x))

    def proposal(
        self, M: int, warm_up_steps: int, proposal_data: Tensor, n_chains
    ) -> Tensor:
        """Independence proposal: draw rows uniformly (with replacement) from a pool of flow samples.

        Args:
            M: Total number of samples (unused; the pool size is read from ``proposal_data``).
            warm_up_steps: Burn-in period (unused, see ``M``).
            proposal_data: Pool of samples from the unconditional flow.
            n_chains: Number of proposals, one per chain.

        Returns:
            Proposed sample, shape (n_chains, D).
        """
        idx = torch.randint(proposal_data.shape[0], (n_chains,))
        return proposal_data[idx]

    def sample(
        self, M: int, warm_up_steps: int, features_dim: int, n_chains: int = 8
    ) -> Tensor:
        """Run ``n_chains`` chains and return the pooled post-burn-in samples.

        Args:
            M: Total number of samples to return (split evenly across chains).
            warm_up_steps: Burn-in steps discarded at the start of each chain.
            features_dim: Dimension of x.
            n_chains (optional): Number of parallel chains. Defaults to 8.

        Returns:
            Samples of shape (n_chains * (M // n_chains), features_dim), squeezed.
        """
        n_kept = M // n_chains
        n_steps = warm_up_steps + n_kept
        with torch.no_grad():
            proposal_data = self.q_x_NF().sample((M + warm_up_steps,))
            x_curr = torch.empty(n_chains, features_dim).uniform_(-1, 1).to(device)
            log_f_curr = self._log_f(x_curr)
            chain = torch.empty((n_chains, n_steps, features_dim), device=device)
            for i in range(n_steps):
                x_star = self.proposal(M, warm_up_steps, proposal_data, n_chains)
                log_f_star = self._log_f(x_star)
                # Accept or reject whole vectors: mixing coordinates of different
                # proposals would not target f(y0 | x) q(x).
                log_u = torch.log(torch.rand(n_chains, device=log_f_star.device))
                accept = log_u < log_f_star - log_f_curr
                x_curr = torch.where(accept[:, None], x_star, x_curr)
                log_f_curr = torch.where(accept, log_f_star, log_f_curr)
                chain[:, i] = x_curr
            kept = chain[:, warm_up_steps:, :]
            return kept.reshape(-1, features_dim).squeeze()

In [None]:
# If you want to check if it works, this is the way I intended this to be used:
# 1. Build the Gaussian task (training/test simulations and the corrupted observation y0).
task = Gaussian(N_simu=N_simu, sigma_y_2=sigma_2_y)
# 2. Train the conditional flow (NPE) and the unconditional flow over the summaries.
q_NPE = create_train_flow(task)
q_x_NF = create_train_unconditional_flow(task)
# 3. Denoise the scaled observation: MCMC samples of x given y0 under the error model.
sampler = MyMCMC(task.y0_scaled.squeeze(), tau, sigma, rho, q_x_NF)
xm = sampler.sample(M, warm_up_steps, task.features_dim)

# RNPE: the NPE conditioned on the denoised summaries xm (presumably one draw per row of xm).
rnpe_samples = q_NPE.flow(xm).sample()
# Reference posterior on theta_grid, evaluated at the unscaled observation y0.
true_post_y = posterior_summaries(
    theta_grid,
    task.y0.cpu().squeeze()[0].item(),
    task.y0.cpu().squeeze()[1].item(),
    sigma_2_y,
)
# Plain NPE, conditioned directly on the scaled observation.
npe_samples = q_NPE.flow(task.y0_scaled).sample((20_000,))
# NOTE(review): task.true_post calls posterior_summaries on the same grid with the same
# arguments, so this looks redundant with true_post_y above.
true_post = task.true_post(task.y0.squeeze().cpu().numpy(), sigma_2_y)

In [None]:
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(8, 8))

# Top panel: simulated summaries, MCMC-denoised summaries and the observation (all scaled).
ax1.scatter(
    x=task.data["scaled_x"][:, 0].cpu(),
    y=task.data["scaled_x"][:, 1].cpu(),
    color="blue",
    label="Simulator outputs",
    s=1,
)
ax1.scatter(
    x=xm[:, 0].cpu(),
    y=xm[:, 1].cpu(),
    label="Denoised outputs from the MCMC sampler",
    color="yellow",
    s=1,
)
ax1.scatter(
    x=task.y0_scaled.cpu()[0],
    y=task.y0_scaled.cpu()[1],
    label="y0 scaled",
    color="green",
    s=20,
)
ax1.set_title("Scaled summary statistics")
ax1.set_xlabel(r"$mean (z_1,...,z_n)$")
ax1.set_ylabel(r"$Var(z_1,...,z_n)$")
ax1.legend(bbox_to_anchor=(1.04, 1), loc="upper left")

# Bottom panel: posterior estimates given y0 against the reference posterior.
ax2.hist(
    npe_samples.squeeze().cpu(),
    bins=200,
    density=True,
    color="green",
    alpha=0.5,
    label=r"NPE samples $q_{\Phi}(\theta|y_0)$",
)
ax2.hist(
    rnpe_samples.squeeze().cpu(),
    bins=200,
    density=True,
    color="red",
    alpha=0.5,
    label=r"RNPE samples $\int q_{\Phi}(\theta|x)\,p(x|y_0)\,dx$",
)
ax2.plot(
    theta_grid,
    true_post_y,
    color="yellow",
    label=r" $p(\theta|y_0)$ under the true DGP",
)
ax2.set_title(r"Posterior estimates given $y_0$")
ax2.set_xlabel(r"$\theta$")
ax2.set_ylabel("Density")
ax2.legend()

ax2.set_xlim(-5, 5)
plt.tight_layout()
plt.show()

In [None]:
# Code adapted from Schmitt et al.
from functools import partial


def kl_latent_space(z, log_det_J):
    """Maximum-likelihood (KL) loss of a flow with a standard-normal latent space.

    Args:
        z: Latent codes; the last dimension is the latent dimension.
        log_det_J: Log-determinant of the Jacobian, one value per sample.

    Returns:
        Mean over samples of ``0.5 * ||z||^2 - log_det_J``.
    """
    half_squared_norm = 0.5 * torch.norm(z, dim=-1) ** 2
    return (half_squared_norm - log_det_J).mean()


def maximum_mean_discrepancy(
    source_samples,
    target_samples,
    kernel="gaussian",
    minimum=0.0,
    unbiased=False,
    squared=True,
    sigmas=None,
):
    """Maximum Mean Discrepancy (MMD) between two samples, using a sum of RBF kernels.

    Args:
        source_samples: Samples of the first distribution, shape (m, d).
        target_samples: Samples of the second distribution, shape (n, d).
        kernel (optional): ``"gaussian"`` or ``"inverse_multiquadratic"``; any
            other value falls back to Gaussian with a message. Defaults to Gaussian.
        minimum (optional): Lower clamp of the squared MMD, guarding against
            small negative values. Defaults to 0.
        unbiased (optional): Use the unbiased estimator. Defaults to False.
        squared (optional): Return MMD^2 rather than MMD. Defaults to True.
        sigmas (optional): Kernel bandwidths, one kernel per entry. Defaults to [1e-3].

    Returns:
        Scalar tensor with the (squared) MMD.
    """
    if sigmas is None:
        sigmas = [1e-3]
    sigmas = torch.as_tensor(sigmas, dtype=torch.float32, device=source_samples.device)

    if kernel == "gaussian":
        kernel = partial(_gaussian_kernel_matrix, sigmas=sigmas)
    elif kernel == "inverse_multiquadratic":
        kernel = partial(_inverse_multiquadratic_kernel_matrix, sigmas=sigmas)
    else:
        print("Invalid kernel specified. Falling back to default Gaussian.")
        kernel = partial(_gaussian_kernel_matrix, sigmas=sigmas)

    if unbiased:
        loss_value = _mmd_kernel_unbiased(source_samples, target_samples, kernel=kernel)
    else:
        loss_value = _mmd_kernel(source_samples, target_samples, kernel=kernel)

    # Clamp as a tensor: Python's max() could return the plain float `minimum`,
    # which torch.sqrt below rejects.
    loss_value = torch.clamp(loss_value, min=minimum)

    if squared:
        return loss_value
    return torch.sqrt(loss_value)


def _gaussian_kernel_matrix(x, y, sigmas):
    norm = lambda v: torch.sum(v**2, dim=2)  # Update dimension to 2
    beta = 1.0 / (2.0 * (sigmas[:, None]))
    dist = norm(x[:, None, :] - y[None, :, :])  # Remove transpose
    s = torch.matmul(beta, torch.reshape(dist, (1, -1)))
    kernel = torch.reshape(torch.sum(torch.exp(-s), 0), dist.shape)
    return kernel


def _inverse_multiquadratic_kernel_matrix(x, y, sigmas):
    """Computes an inverse multiquadratic RBF between the samples of x and y.
    We create a sum of multiple IM-RBF kernels each having a width sigma_i.
    """
    dist = torch.unsqueeze(
        torch.sum((x[:, None, :] - y[None, :, :]) ** 2, dim=-1), dim=-1
    )
    sigmas = torch.unsqueeze(sigmas, dim=0)
    return torch.sum(sigmas / (dist + sigmas), dim=-1)


def _mmd_kernel(x, y, kernel=None):
    """Computes the Maximum Mean Discrepancy (MMD) of two samples: x and y.
    Maximum Mean Discrepancy (MMD) is a distance-measure between the samples of the distributions of x and y.
    """
    loss = torch.mean(kernel(x, x))
    loss += torch.mean(kernel(y, y))
    loss -= 2 * torch.mean(kernel(x, y))
    return loss


def _mmd_kernel_unbiased(x, y, kernel=None):
    """Computes the Maximum Mean Discrepancy (MMD) of two samples: x and y.
    Maximum Mean Discrepancy (MMD) is a distance-measure between the samples of the distributions of x and y.
    """
    m, n = x.shape[0], y.shape[0]
    print(kernel(x, x))
    loss = (1.0 / (m * (m - 1))) * torch.sum(kernel(x, x))
    loss += (1.0 / (n * (n - 1))) * torch.sum(kernel(y, y))
    loss -= (2.0 / (m * n)) * torch.sum(kernel(x, y))
    return loss


def MMD_bootstrap(
    x, x_o, N_BOOTSTRAP_ITERATIONS=10, n_samples_x=500, n_samples_x_o=500
):
    """Bootstrap distribution of the (non-squared) MMD between x and x_o.

    Each iteration resamples ``n_samples_x`` rows of x and ``n_samples_x_o`` rows
    of x_o with replacement (``np.random.randint``) and records their MMD.

    Returns:
        np.ndarray of shape (N_BOOTSTRAP_ITERATIONS,).
    """
    estimates = np.empty(N_BOOTSTRAP_ITERATIONS)
    for it in tqdm(range(N_BOOTSTRAP_ITERATIONS)):
        rows_x = np.random.randint(0, x.shape[0], size=n_samples_x)
        rows_x_o = np.random.randint(0, x_o.shape[0], size=n_samples_x_o)
        mmd = maximum_mean_discrepancy(x[rows_x], x_o[rows_x_o], squared=False)
        estimates[it] = float(mmd)
    return estimates


def calculate_CI(x, ci_area=0.95):
    """Equal-tailed interval holding ``ci_area`` of the empirical distribution of x.

    Returns:
        (lower, upper) quantiles; the tail probabilities are rounded to 5 decimals.
    """
    tail = round((1.0 - ci_area) / 2, 5)
    lower_q, upper_q = tail, round(1.0 - tail, 5)
    return np.quantile(x, lower_q), np.quantile(x, upper_q)

As a first experiment, I computed the estimated density of $\theta$ with NPE and RNPE over 10 runs (see the two functions below).

In [None]:
def compute_rnpe_npe(task_name: str, sigma_2: float, n_runs: int = 10) -> Dict:
    """Run the full NPE / RNPE pipeline ``n_runs`` times for one task and corruption level.

    Each run does the following:
    - builds a fresh task: ``"Gaussian"``; any other name builds a ``CS`` task
      with ``miss_specification_param=sigma_2``;
    - trains the conditional flow (NPE) and the unconditional flow;
    - denoises the observation with ``MyMCMC``;
    - samples both posteriors.

    Args:
        task_name: Name of the task.
        sigma_2: Corruption (Gaussian) or misspecification probability (CS).
        n_runs (optional): Number of independent runs. Defaults to 10.

    Returns:
        Dict of lists with one entry per run:
        - ``task_l``: the tasks;
        - ``npe_samples_y_l``: NPE samples given y0;
        - ``theoritical_post_y_l``: true posterior on the theta grid given y0,
          or ``[]`` when the task has none;
        - ``rnpe_samples_l``: RNPE samples;
        - ``xm_l``: denoised summaries from the MCMC sampler.
    """
    task_l = []
    xm_l = []
    npe_samples_y_l = []
    theoritical_post_y_l = []
    rnpe_samples_l = []

    for _ in tqdm(range(n_runs), leave=True):
        task = (
            Gaussian(N_simu=N_simu, sigma_y_2=sigma_2)
            if task_name == "Gaussian"
            else CS(N_simu=N_simu, miss_specification_param=sigma_2)
        )
        q_NPE = create_train_flow(task)
        q_x_NF = create_train_unconditional_flow(task)
        with torch.no_grad():
            mcmc_sampler = MyMCMC(task.y0_scaled.squeeze(), tau, sigma, rho, q_x_NF)
            xm = mcmc_sampler.sample(M, warm_up_steps, task.features_dim)
            npe_samples_y = q_NPE.flow(task.y0_scaled).sample((M,))
            # RNPE: the NPE conditioned on the denoised summaries instead of y0.
            rnpe_samples = q_NPE.flow(xm).sample()
        theoritical_post_y = (
            task.true_post(task.y0.squeeze().cpu().numpy(), sigma_2)
            if task.has_post
            else []
        )

        task_l.append(task)
        xm_l.append(xm)
        npe_samples_y_l.append(npe_samples_y)
        theoritical_post_y_l.append(theoritical_post_y)
        rnpe_samples_l.append(rnpe_samples)
    return {
        "task_l": task_l,
        "npe_samples_y_l": npe_samples_y_l,
        "theoritical_post_y_l": theoritical_post_y_l,
        "rnpe_samples_l": rnpe_samples_l,
        "xm_l": xm_l,
    }

In [None]:
def compute_metrics(
    task_l: List,
    npe_samples_y_l: List,
    theoritical_post_y_l: List,
    rnpe_samples_l: List,
    xm_l: List,
) -> Dict:
    """Compute per-run MMDs and select the run where NPE and RNPE disagree most.

    The inputs are the parallel per-run lists produced by ``compute_rnpe_npe``.

    For every run ``i`` two MMDs (``squared=False``) are computed against the
    scaled observation ``task_l[i].y0_scaled`` repeated 10,000 times:

    * ``MMD_npe``: against 1,000 rows drawn with replacement from
      ``task_l[i].data["scaled_x"]``;
    * ``MMD_rnpe``: against 1,000 rows drawn with replacement from ``xm_l[i]``.

    The run returned in detail is the one maximising
    ``(mean_npe - mean_rnpe)**2 + log(std_npe / std_rnpe)**2`` computed on the
    NPE and RNPE posterior samples.

    Args:
        task_l: Task of each run.
        npe_samples_y_l: NPE posterior samples given the observation, per run.
        theoritical_post_y_l: Theoretical posterior (or ``[]``), per run.
        rnpe_samples_l: RNPE posterior samples, per run.
        xm_l: MCMC draws, per run.

    Returns:
        Dict with the selected run's ``"npe_samples_y"``, ``"theoritical_post_y"``,
        ``"rnpe_samples"`` and ``"task"``, plus ``"MMD_npe"`` and ``"MMD_rnpe"``:
        arrays of shape ``(1, len(xm_l))`` with the MMD of every run.

    Raises:
        ValueError: If no runs are given.
    """
    n_runs = len(xm_l)
    if n_runs == 0:
        raise ValueError("compute_metrics requires at least one run.")

    MMD_npe = np.zeros((1, n_runs))
    MMD_rnpe = np.zeros((1, n_runs))
    discrepancy = np.zeros(n_runs)
    for i in range(n_runs):
        task = task_l[i]
        y0_repeated = task.y0_scaled.expand(10_000, -1)

        # Computing the MMD of each run
        idx_npe = np.random.randint(0, task.data["scaled_x"].shape[0], size=1_000)
        MMD_npe[0, i] = float(
            maximum_mean_discrepancy(
                y0_repeated,
                task.data["scaled_x"][idx_npe],
                squared=False,
            )
        )
        idx_rnpe = np.random.randint(0, xm_l[i].shape[0], size=1_000)
        MMD_rnpe[0, i] = float(
            maximum_mean_discrepancy(y0_repeated, xm_l[i][idx_rnpe], squared=False)
        )

        # Squared mean shift plus squared log-ratio of the standard deviations,
        # converted to Python floats so the argmax runs on a plain float array.
        npe, rnpe = npe_samples_y_l[i], rnpe_samples_l[i]
        mean_shift = float(npe.mean()) - float(rnpe.mean())
        log_std_ratio = np.log(float(npe.std()) / float(rnpe.std()))
        discrepancy[i] = mean_shift**2 + float(log_std_ratio) ** 2

    imax = int(np.argmax(discrepancy))

    return {
        "npe_samples_y": npe_samples_y_l[imax],
        "theoritical_post_y": theoritical_post_y_l[imax],
        "rnpe_samples": rnpe_samples_l[imax],
        "task": task_l[imax],
        "MMD_npe": MMD_npe,
        "MMD_rnpe": MMD_rnpe,
    }

And we can plot the estimated densities

In [None]:
# Posterior samples and MMD metrics for every corruption level of the Gaussian task.
results = {}

sigmas_2 = [1, 1.25, 1.5, 1.75, 2, 2.25, 2.5, 2.75, 3]
with tqdm(sigmas_2, unit="Sigma²", desc="Corruption") as tq:
    for sigma_2 in tq:
        # Report the current level on the progress bar rather than printing it.
        tq.set_postfix(sigma_2=sigma_2)
        run_outputs = compute_rnpe_npe("Gaussian", sigma_2)
        results[sigma_2] = compute_metrics(**run_outputs)

In [None]:
# Keep a reference to the posterior results: the name `results` is reassigned
# later in the notebook (to cached MMD values), which would break this widget.
posterior_results = results


@widgets.interact(sigma_2=list(posterior_results))  # Only for the Gaussian task...
def plot_posteriors_widget(sigma_2=1):
    """Plot NPE and RNPE posterior histograms for one corruption level.

    Args:
        sigma_2: Corruption level; must be a key of ``posterior_results``.
    """
    plt.close()
    fig, ax = plt.subplots()
    run = posterior_results[sigma_2]
    ax.hist(
        run["npe_samples_y"].squeeze(),
        bins=400,
        color="green",
        density=True,
        label="Posterior samples from NPE",
        alpha=0.5,
    )
    ax.hist(
        run["rnpe_samples"].squeeze(),
        bins=400,
        color="red",
        density=True,
        label="Posterior samples from RNPE",
        alpha=0.5,
    )
    theoretical = run["theoritical_post_y"]
    # Tasks without an analytical posterior store an empty list.
    if len(theoretical) > 0:
        ax.plot(theta_grid, theoretical, color="black", label="Theoretical posterior")
    ax.set_xlim(-5, 5)
    ax.set_xlabel(r"$\theta$")
    ax.set_ylabel("Density")
    ax.legend()
    ax.set_title(f"NPE and RNPE samples, sigma_y² = {sigma_2}")
    plt.show()

In [None]:
def create_compute(filename, task_name, miss_specification_params, n_runs_per_pi=10):
    """Repeat the RNPE pipeline for each misspecification level and cache the MMDs.

    For each level ``pi`` this runs ``n_runs_per_pi`` independent trainings and
    records two MMDs (``squared=False``) against the scaled observation
    ``task.y0_scaled`` repeated 10,000 times:

    * ``MMD_npe``: against 1,000 rows drawn with replacement from ``task.data["scaled_x"]``;
    * ``MMD_rnpe``: against 1,000 rows drawn with replacement from the MCMC draws.

    After each level, ``{"mmd_npe": ..., "mmd_rnpe": ...}`` is pickled to
    ``filename + task.name + str(pi)``. The arrays hold the ``n_runs_per_pi``
    values of that level only, so ``.mean()`` on them gives the per-level mean.

    Args:
        filename: Path prefix of the cache files (e.g. ``"./cache/"``); its
            directory is created if needed.
        task_name: ``"Gaussian"`` selects the Gaussian task; anything else selects ``CS``.
        miss_specification_params: Sequence of misspecification levels.
        n_runs_per_pi: Number of independent runs per level. Defaults to 10.

    Returns:
        Tuple ``(MMD_npe, MMD_rnpe)`` of arrays with shape
        ``(len(miss_specification_params), n_runs_per_pi)``.

    Raises:
        ValueError: If ``n_runs_per_pi`` is smaller than 1.
    """
    if n_runs_per_pi < 1:
        raise ValueError("n_runs_per_pi must be at least 1.")
    n_levels = len(miss_specification_params)
    MMD_npe = np.zeros((n_levels, n_runs_per_pi))
    MMD_rnpe = np.zeros((n_levels, n_runs_per_pi))
    cache_dir = os.path.dirname(filename)
    # os.makedirs("") raises, so only create a directory when the prefix has one.
    if cache_dir:
        os.makedirs(cache_dir, exist_ok=True)
    for i, pi in tqdm(enumerate(miss_specification_params), total=n_levels):
        for j in range(n_runs_per_pi):
            task = (
                Gaussian(N_simu=N_simu, sigma_y_2=pi)
                if task_name == "Gaussian"
                else CS(N_simu=N_simu, miss_specification_param=pi)
            )
            # The NPE flow is not used for the MMDs below; the call is kept in case
            # training prepares state on `task`. NOTE(review): confirm and drop if not.
            create_train_flow(task)
            q_x_NF = create_train_unconditional_flow(task)
            with torch.no_grad():
                mcmc_sampler = MyMCMC(task.y0_scaled.squeeze(), tau, sigma, rho, q_x_NF)
                xm = mcmc_sampler.sample(M, warm_up_steps, task.features_dim)

            y0_repeated = task.y0_scaled.expand(10_000, -1)
            idx_npe = np.random.randint(0, task.data["scaled_x"].shape[0], size=1_000)
            idx_rnpe = np.random.randint(0, xm.shape[0], size=1_000)
            MMD_npe[i, j] = float(
                maximum_mean_discrepancy(
                    y0_repeated,
                    task.data["scaled_x"][idx_npe],
                    squared=False,
                )
            )
            MMD_rnpe[i, j] = float(
                maximum_mean_discrepancy(y0_repeated, xm[idx_rnpe], squared=False)
            )
        # Cache only this level's runs. Saving the whole matrices would mix in
        # other levels (and still-zero rows) when the file is averaged later.
        level_results = {"mmd_npe": MMD_npe[i], "mmd_rnpe": MMD_rnpe[i]}
        with open(filename + task.name + str(pi), "wb") as f:
            pickle.dump(level_results, f)

    return MMD_npe, MMD_rnpe


create_compute("./cache/", "Gaussian", sigmas_2)

In [None]:
sigmas_2 = [1, 1.25, 1.5, 1.75, 2, 2.25, 2.5, 2.75, 3]
results = {}
# Reload the per-corruption MMD results pickled under ./cache/ by the
# `create_compute` call above (one file per corruption level).
for sigma_2 in sigmas_2:
    cache_path = "./cache/Gaussian" + str(sigma_2)
    with open(cache_path, "rb") as cache_file:
        results[sigma_2] = pickle.load(cache_file)

In [None]:
# Mean MMD to the observation for NPE simulations vs. RNPE MCMC draws, per corruption level.
mean_mmd = {"mmd_npe": [], "mmd_rnpe": []}
for row, sigma_2 in enumerate(sigmas_2):
    for key in mean_mmd:
        runs = np.asarray(results[sigma_2][key])
        # `create_compute` may have stored the whole (n_levels, n_runs) matrix in
        # each cache file; keep only this level's row (assumes `sigmas_2` is the
        # same list that was passed to `create_compute`).
        if runs.ndim == 2:
            runs = runs[row]
        mean_mmd[key].append(runs.mean())

fig, ax = plt.subplots()
ax.scatter(sigmas_2, mean_mmd["mmd_rnpe"], c="r", label="RNPE (MCMC draws)")
ax.scatter(sigmas_2, mean_mmd["mmd_npe"], c="b", label="NPE (simulations)")
ax.set_xlabel(r"Corruption $\sigma_y^2$")
ax.set_ylabel(r"Mean MMD to the observation $y_0$")
ax.set_title("MMD to the observation, averaged over runs")
ax.legend()
plt.show()