In [None]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import matplotlib.pyplot as plt
import corner
import random
from sbi.analysis import plot_summary
from sbi.inference import NPE
from sbi.utils.user_input_checks import (
    check_sbi_inputs,
    process_prior,
    process_simulator,
)
from sbi.neural_nets import posterior_nn
from sbi.utils.diagnostics_utils import get_posterior_samples_on_batch
from sbi.diagnostics.tarp import _run_tarp, get_tarp_references
from sbi.analysis import plot_tarp
from Custom_prior import *
from embedding_net import *
import gc
import pickle

In [None]:
def set_random_seed(seed):
    """Seed every RNG the notebook relies on so reruns are repeatable.

    Seeds Python's ``random`` module, NumPy's legacy global generator and
    PyTorch (``torch.manual_seed`` plus ``torch.cuda.manual_seed``), then
    forces cuDNN onto deterministic, non-autotuned kernels.

    Args:
        seed: integer seed applied to all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # Autotuned / non-deterministic cuDNN algorithms would defeat the seeding.
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True


set_random_seed(42)

In [None]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(device)

In [None]:
# Parse the prior specification file; get_prior (from Custom_prior) returns the
# parameter dict, the ordered parameter names and per-parameter lower/upper bounds.
prior_path = "/home/jupyter/datasphere/project/NN-Clusters/data/efeds_data/prior.txt"
(params_dict, param_names,
 lower_bound, upper_bound) = get_prior(prior_path)
print(param_names)

In [None]:
# Parameters that are marginalised over (not reported as inference targets).
nuisance_params = ['ob', 'h0', 'ns', 'r_cr_l', 'cr50', 'scr', 'gz', 'a_b', 'b_b', 'delta_b', 'gamma_b']

# Positions (in prior order) of every parameter that is NOT a nuisance parameter;
# names and bounds are then gathered in that same order.
keep_idx = [idx for idx, name in enumerate(param_names) if name not in nuisance_params]
interest_params = [param_names[idx] for idx in keep_idx]
interest_lower = [lower_bound[idx] for idx in keep_idx]
interest_upper = [upper_bound[idx] for idx in keep_idx]
print(interest_params)

targets = interest_params
prior = CustomPrior(params_dict, interest_params, return_numpy=False, device=device)

In [None]:
# Prior mean / standard deviation over the parameters of interest.
mean_prior = prior.mean
std_prior = prior.variance.sqrt()

# Box bounds of the parameters of interest as tensors.
low_bound = torch.tensor(interest_lower)
high_bound = torch.tensor(interest_upper)

In [None]:
# Load the pre-simulated training set, stored as n_chunks pairs of files:
#   x_{i}.npz      -> simulated observations (default array key 'arr_0')
#   thetas_{i}.npy -> full parameter vectors (all columns in param_names order)
path = "/home/jupyter/datasphere/project/NN-Clusters/data/efeds_data/samples/"
targets = interest_params
n_chunks = 10

X = []
y = []
for i in range(n_chunks):
    # np.load on an .npz returns an NpzFile that keeps the file open until
    # closed; the context manager releases the handle after the array is read.
    with np.load(path + f"x_{i}.npz") as npz:
        X.append(npz['arr_0'])
    y.append(np.load(path + f"thetas_{i}.npy"))

X = np.vstack(X)
y = np.vstack(y)
# Observations and parameters must stay row-aligned for the train/test split.
assert X.shape[0] == y.shape[0], (X.shape, y.shape)

In [None]:
# Label the parameter columns, then keep only the parameters of interest
# (in interest_params order) as a plain NumPy array.
y = pd.DataFrame(y, columns=param_names).loc[:, interest_params].to_numpy()

In [None]:
# Hold out 10% of the simulations (fixed random_state) for the calibration checks.
split = train_test_split(X, y, test_size=0.1, random_state=42)
X_train, X_test, y_train, y_test = split
# Drop the full arrays to free memory; the split halves remain referenced.
del X, y, split
gc.collect();

In [None]:
# Convert all four splits to float32 torch tensors (copies of the NumPy data).
X_train, y_train, X_test, y_test = (
    torch.tensor(arr, dtype=torch.float32)
    for arr in (X_train, y_train, X_test, y_test)
)

In [None]:
# Display the training-observation tensor size.
X_train.size()

In [None]:
# Wrap the custom prior so sbi can use it; the wrapper needs explicit support
# bounds, placed on the same device as the prior.
prior_wrapper_kwargs = {
    "lower_bound": torch.tensor(interest_lower).to(device),
    "upper_bound": torch.tensor(interest_upper).to(device),
}
prior, num_parameters, prior_returns_numpy = process_prior(
    prior,
    custom_prior_wrapper_kwargs=prior_wrapper_kwargs,
)

In [None]:
# Re-seed right before building the networks so their initialisation is reproducible.
set_random_seed(42)
embedding = EMBEDDING_NET(len(targets))
# Neural spline flow conditioned on the embedded observation; x is not z-scored.
neural_posterior = posterior_nn(
    model="nsf",
    embedding_net=embedding,
    z_score_x='none',
    device=device,
)
inference = NPE(prior=prior, density_estimator=neural_posterior, device=device)

theta_sims = y_train.float()
x_sims = X_train.float()
# Training data stay on the CPU (data_device); sbi moves batches as needed.
inference = inference.append_simulations(theta_sims, x_sims, data_device="cpu")

In [None]:
# Training hyper-parameters: large batches, small learning rate, and early
# stopping after 50 epochs without validation improvement (cap of 1800 epochs).
train_config = dict(
    show_train_summary=True,
    training_batch_size=4096,
    stop_after_epochs=50,
    max_num_epochs=1800,
    learning_rate=1e-4,
)
density_estimator = inference.train(**train_config)

In [None]:
# Training vs. validation loss curves for the run above.
loss_tags = ['training_loss', 'validation_loss']
plot_summary(inference, tags=loss_tags, disable_tensorboard_prompt=True)

In [None]:
# Build a posterior object from the trained NPE network; it is sampled below
# via posterior.sample(..., x=...).
posterior = inference.build_posterior()

In [None]:
# Draw 2000 posterior samples for every held-out observation, one at a time.
n_post_draws = 2000
per_obs_draws = []
for x_obs in tqdm(X_test):
    draws = posterior.sample(
        (n_post_draws,), x=x_obs.unsqueeze(0).to(device), show_progress_bars=False
    )
    per_obs_draws.append(draws.cpu())

# Stack to (num_test, n_post_draws, dim), then swap the first two axes to get
# (n_post_draws, num_test, dim) for the TARP cell.
posterior_samples = torch.stack(per_obs_draws).transpose(0, 1)

In [None]:
# TARP calibration check on the held-out test set.
# The true parameters for the test observations are y_test (the original cell
# referenced an undefined `y_stack` and raised NameError on a fresh kernel).
# posterior_samples is (n_post_draws, num_test, dim) from the sampling cell.
thetas_test = y_test.cpu()
assert posterior_samples.shape[1:] == thetas_test.shape, (
    posterior_samples.shape, thetas_test.shape
)

# Reference points derived from the true test parameters; kept on the CPU
# because _run_tarp is fed CPU tensors anyway.
references = get_tarp_references(thetas_test)

# NOTE(review): _run_tarp is a private sbi helper; its signature may change
# between sbi releases.
expected_coverage, ideal_coverage = _run_tarp(
    posterior_samples.cpu(),
    thetas_test,
    references.cpu(),
    z_score_theta=True,
)
fig, axes = plot_tarp(expected_coverage, ideal_coverage)
plt.show()

In [None]:
# Observed eFEDS histogram that the posterior is conditioned on.
efeds_hist = np.load("/home/jupyter/datasphere/project/NN-Clusters/data/efeds_data/efeds_hist.npy")
print(efeds_hist.sum())
# float32 tensor with a leading batch axis of size 1, on the model's device.
x_id = torch.tensor(efeds_hist, dtype=torch.float32)[None].to(device)

In [None]:
# 50k posterior draws for the observed data, moved to the CPU for plotting.
n_obs_draws = 50_000
samples = posterior.sample((n_obs_draws,), x=x_id).cpu()

In [None]:
# Corner plot of the observed-data posterior: 16/50/84% quantiles in the
# titles and 68%/95% contour levels.
corner_kwargs = dict(
    quantiles=[0.16, 0.5, 0.84],
    labels=targets,
    show_titles=True,
    levels=(0.68, 0.95),
    smooth=True,
)
figure = corner.corner(samples.numpy(), **corner_kwargs);

In [None]:
# Persist the trained posterior object. Reload it only from trusted storage:
# pickle.load can execute arbitrary code.
posterior_out = "/home/jupyter/datasphere/project/NN-Clusters/models/efeds_posterior.pkl"
with open(posterior_out, mode="wb") as fh:
    pickle.dump(posterior, fh)