In [None]:
import torch
import numpy as np
from sbi_particle_physics.objects.model import Model
from sbi_particle_physics.managers.plotter import Plotter
from sbi_particle_physics.managers.backup import Backup
from sbi_particle_physics.managers.predictions import Predictions
from sbi_particle_physics.config import MODELS_DIR

In [None]:
# Which saved training run to restore and where to run inference. These are kept as
# named constants so switching runs or devices is a one-line change.
TRAINING_RUN = "training_11"
DEVICE = torch.device("cpu")

# Restore the trained model from MODELS_DIR / TRAINING_RUN for inference on DEVICE.
model = Backup.load_model_for_inference_basic(directory=MODELS_DIR / TRAINING_RUN, device=DEVICE)

In [None]:
# Fix the RNG state so the simulated observation and the posterior/prior draws in this
# notebook are reproducible under Restart & Run All. Both torch and numpy are seeded
# because it is not visible here which library the sampling helpers draw from.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Number of posterior draws per observation (also reused for the prior draws further down).
n_sampled_parameters = 100
# Points per simulated dataset, taken from the model (presumably the size it expects).
n_points = model.n_points
print("n_points:", n_points)

# One true parameter, its simulated observation, and posterior draws conditioned on it.
true_parameter, observed_sample, sampled_parameters = (
    model.get_true_parameters_simulations_and_sampled_parameters(
        n_true=1,
        n_points=n_points,
        n_sampled_parameters=n_sampled_parameters,
    )
)
Plotter.plot_a_posterior(sampled_parameters, true_parameter)

In [None]:
# Point estimate and its uncertainty for the single observation. .item() converts each
# (single-element) result to a Python number for printing.
estimator, uncertainty = Predictions.calculate_estimator(
    sampled_parameters
)
print(
    f"Estimator {estimator.item()}, uncertainty {uncertainty.item()}"
)

In [None]:
# Summary statistics of the posterior draws for the single observation.
# Each statistic is printed as a Python scalar, so each must hold exactly one element.
mean, median, q5, q16, q84, q95, std, width_68 = Predictions.calculate_estimator_summary(
    sampled_parameters
)
for _stat_label, _stat_value in (
    ("Mean", mean),
    ("Median", median),
    ("q5", q5),
    ("q16", q16),
    ("q84", q84),
    ("q95", q95),
    ("std", std),
    ("width_68", width_68),
):
    print(_stat_label, _stat_value.item())

In [None]:
# Reference draws from the prior, as many as there are posterior draws. They are compared
# against the posterior draws by log_contraction / information_gain in the cells below.
prior_samples = model.draw_parameters_from_prior(n_sampled_parameters) # normalized and formatted

In [None]:
# Log contraction of the posterior relative to the prior draws for the single
# observation, reduced to a Python scalar.
log_contraction = Predictions.log_contraction(
    prior_samples, sampled_parameters
).item()
print("Log contraction", log_contraction)

In [None]:
# Information gain of the posterior over the prior draws for the single observation,
# reduced to a Python scalar.
information_gain = Predictions.information_gain(
    prior_samples, sampled_parameters
).item()
print("Information gain", information_gain)

In [None]:
# Repeat the simulate-and-infer step for several true parameters at once.
n_true = 2
true_parameters, observed_samples, many_sampled_parameters = (
    model.get_true_parameters_simulations_and_sampled_parameters(
        n_true=n_true,
        n_points=n_points,
        n_sampled_parameters=n_sampled_parameters,
    )
)

In [None]:
# Point estimates and uncertainties for every observation. They are printed as-is,
# without .item(), since they may hold several elements.
estimator, uncertainty = Predictions.calculate_estimator(
    many_sampled_parameters
)
print(
    f"Estimator {estimator}, uncertainty {uncertainty}"
)

In [None]:
# Summary statistics of the posterior draws for every observation, printed as-is
# (no .item(), the values may hold several elements).
mean, median, q5, q16, q84, q95, std, width_68 = Predictions.calculate_estimator_summary(
    many_sampled_parameters
)
for _stat_label, _stat_value in (
    ("Mean", mean),
    ("Median", median),
    ("q5", q5),
    ("q16", q16),
    ("q84", q84),
    ("q95", q95),
    ("std", std),
    ("width_68", width_68),
):
    print(_stat_label, _stat_value)

In [None]:
# Log contraction against the same prior draws, now for the batch of observations;
# the raw result is printed without conversion to a scalar.
log_contraction = Predictions.log_contraction(
    prior_samples,
    many_sampled_parameters,
)
print("Log contraction", log_contraction)

In [None]:
# Information gain against the same prior draws, now for the batch of observations;
# the raw result is printed without conversion to a scalar.
information_gain = Predictions.information_gain(
    prior_samples,
    many_sampled_parameters,
)
print("Information gain", information_gain)

In [None]:
# Average uncertainty over the batch of posterior draws (presumably averaged across the
# n_true observations — confirm against Predictions.average_uncertainty).
average_uncertainty = Predictions.average_uncertainty(
    many_sampled_parameters
)
print("Average uncertainty", average_uncertainty)