## (1) Packages and Settings

In [None]:
import sys, os, math, importlib
import pickle
import numpy as np
import torch
import matplotlib
import matplotlib.pyplot as plt
from scipy.stats import norm 
from scipy import linalg
from scipy.interpolate import interp1d
from sklearn.covariance import GraphicalLassoCV, LedoitWolf, EmpiricalCovariance, MinCovDet, ShrunkCovariance


import torch.nn as nn
from torch.functional import F
import torch.distributions as dist
from torch.utils.data import DataLoader, TensorDataset, random_split
from torch.distributions.multivariate_normal import MultivariateNormal as Normal
from torch import tensor, as_tensor, Tensor, eye, zeros, ones, float32


import pyccl as ccl
import powerbox as pbox

from sbi.analysis import run_sbc, sbc_rank_plot

sys.path.append('../')
from utils_modules.models import SummaryNet, Expander, vector_to_Cov
from utils_modules.vicreg import vicreg_loss
import utils_modules.data as utils_data

from tqdm.notebook import tqdm
from nflows import distributions as distributions_
from nflows import flows, transforms
from nflows.nn import nets

In [None]:
# Plot settings: 10 pt serif text, Computer Modern math text, explicit
# tick/label/title/legend sizes and the 'tableau-colorblind10' colour cycle.
font = dict(family='serif', weight='normal', size=10)
matplotlib.rc('font', **font)

rcnew = {
    # math text rendered with Computer Modern
    "mathtext.fontset": "cm",
    # tick label sizes and tick mark lengths
    "xtick.labelsize": 10,
    "ytick.labelsize": 10,
    "xtick.major.size": 8,
    "xtick.minor.size": 4,
    "ytick.major.size": 8,
    "ytick.minor.size": 4,
    # axes, legend and figure text sizes
    "axes.titlesize": 26,
    "axes.labelsize": 14,
    "legend.fontsize": 22,
    "figure.titlesize": 30,
    # error-bar caps and automatic data margins
    "errorbar.capsize": 4,
    "axes.xmargin": 0.05,
    "axes.ymargin": 0.05,
}
plt.rcParams.update(rcnew)
plt.style.use('tableau-colorblind10')

In [None]:
# Run on CUDA when PyTorch reports it is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print('Device: %s' % (device,))

## (2) Define functions to compute coverage

Code adapted from https://github.com/mackelab/tsnpe_neurips

In [None]:
def compute_rank(val, vec):
    """Return the ascending-sort position of ``val`` among ``vec``.

    ``val`` is appended after ``vec``, the concatenation is argsorted, and the
    sorted position of the last element of the concatenation is returned as a
    1-D integer tensor. ``val`` is expected to be a 1-element, 1-D tensor.
    Ties are ordered however ``torch.argsort`` orders them.
    """
    combined = torch.cat([vec, val])
    sorted_order = torch.argsort(combined)
    val_index = combined.shape[0] - 1  # val sits at the very end of ``combined``
    (positions,) = torch.nonzero(sorted_order == val_index, as_tuple=True)
    return positions

# Compute coverage for an arbitrary posterior
def compute_coverage(posterior, theta_proposal, x_proposal, num_monte_carlo= 1_000, alpha=torch.linspace(0, 1, 20)):
    """Estimate the empirical coverage of an sbi-style posterior.

    For each (parameter, summary) pair the posterior is conditioned on the summary
    via ``posterior.set_default_x``. ``num_monte_carlo`` samples are then drawn, and
    the normalized rank of the true parameter's log-probability among the samples'
    log-probabilities is compared against every threshold in ``alpha``.

    Args:
        posterior: object exposing ``set_default_x``, ``sample(shape, show_progress_bars=...)``
            and ``log_prob`` (e.g. an sbi posterior). It is left conditioned on the
            last summary processed.
        theta_proposal: iterable of true parameter vectors.
        x_proposal: iterable of matching summary vectors (zipped with ``theta_proposal``).
        num_monte_carlo: number of posterior samples drawn per pair.
        alpha: 1-D tensor of thresholds on the normalized rank.

    Returns:
        Tensor shaped like ``alpha`` on ``alpha``'s device. Before flipping, entry i
        is the fraction of pairs whose normalized rank exceeds ``alpha[i]``. The
        result is returned reversed (``torch.flip``), so for a symmetric grid such
        as ``linspace(0, 1, n)`` entry i corresponds to threshold ``1 - alpha[i]``.

    Raises:
        ValueError: if no (parameter, summary) pairs are supplied.
    """
    gt_is_covered = zeros(alpha.shape, device=alpha.device)
    counter = 0
    for params, summstats in zip(theta_proposal, x_proposal):
        xo = as_tensor(np.asarray([summstats]), dtype=float32)
        posterior.set_default_x(xo)
        lprobs = posterior.log_prob(
            posterior.sample((num_monte_carlo,), show_progress_bars=False)
        )
        gt_log_prob = posterior.log_prob(
            as_tensor(np.asarray([params]), dtype=float32)
        )
        rank_of_gt = compute_rank(gt_log_prob, lprobs)
        norm_rank = rank_of_gt / lprobs.shape[0]
        # The posterior may live on a different device than ``alpha``; compare on alpha's.
        covered_in_alpha_quantile = norm_rank.to(alpha.device) > alpha
        gt_is_covered += covered_in_alpha_quantile.float()
        counter += 1
        if counter % 100 == 0:
            print('# of samples examined: ', counter)
    if counter == 0:
        raise ValueError('compute_coverage needs at least one (theta, x) pair')
    # Normalize by the number of pairs actually processed (zip stops at the shorter input).
    gt_is_covered /= counter
    return torch.flip(gt_is_covered, dims=[0])
    
# Compute coverage for the multivariate normal distribution 
# which is predicted by the inference network trained on VICReg summaries
def compute_coverage_normal(theta_true, theta_pred, Cov_pred, 
                            num_monte_carlo= 1_000, 
                            alpha=torch.linspace(0, 1, 20)):
    """Estimate the empirical coverage of per-sample multivariate-normal posteriors.

    For each triple (true parameters, predicted mean, predicted covariance) a
    ``MultivariateNormal`` is built. ``num_monte_carlo`` samples are drawn from it,
    and the normalized rank of the true parameters' log-probability among the
    samples' log-probabilities is compared against every threshold in ``alpha``.

    Args:
        theta_true: iterable of true parameter vectors (converted to the dtype and
            device of the matching predicted mean).
        theta_pred: iterable of predicted mean vectors (tensors).
        Cov_pred: iterable of predicted covariance matrices (tensors).
        num_monte_carlo: number of samples drawn from each predicted Gaussian.
        alpha: 1-D tensor of thresholds on the normalized rank.

    Returns:
        Tensor shaped like ``alpha`` on ``alpha``'s device. Before flipping, entry i
        is the fraction of triples whose normalized rank exceeds ``alpha[i]``; the
        result is returned reversed (``torch.flip``).

    Raises:
        ValueError: if no triples are supplied.
    """
    # Accumulate on alpha's device so the default CPU ``alpha`` also works with GPU predictions.
    gt_is_covered = zeros(alpha.shape, device=alpha.device)
    counter = 0
    for params_true, params_pred, cov_pred in zip(theta_true, theta_pred, Cov_pred):
        # get predicted posterior (multivariate normal)
        posterior = Normal(loc=params_pred, covariance_matrix=cov_pred)

        # sample the posterior and get log-probabilities of the samples
        posterior_samples = posterior.sample((num_monte_carlo,))
        lprobs = posterior.log_prob(posterior_samples)

        # log_prob needs the same dtype/device as the mean (e.g. float64 truths would fail);
        # atleast_1d because compute_rank concatenates 1-D tensors
        params_true = torch.as_tensor(params_true, dtype=params_pred.dtype, device=params_pred.device)
        gt_log_prob = torch.atleast_1d(posterior.log_prob(params_true))

        rank_of_gt = compute_rank(gt_log_prob, lprobs)
        norm_rank = rank_of_gt / lprobs.shape[0]
        covered_in_alpha_quantile = norm_rank.to(alpha.device) > alpha
        gt_is_covered += covered_in_alpha_quantile.float()
        counter += 1
        if counter % 100 == 0:
            print('# of samples examined: ', counter)
    if counter == 0:
        raise ValueError('compute_coverage_normal needs at least one (theta, mean, cov) triple')
    # Normalize by the number of triples actually processed (zip stops at the shortest input).
    gt_is_covered /= counter
    return torch.flip(gt_is_covered, dims=[0])
    

## (2) VICReg 
### (2.1) Load VICReg model

In [None]:
# path to the trained VICReg encoder weights and its output file (placeholders)
fmodel = ...
fout   = ...

hidden     = 8
last_layer = 2*hidden  # passed to SummaryNet and used as the Expander input size below

n_params   = 2
n_tril     = n_params * (n_params + 1) // 2  # entries in the lower triangle (incl. diagonal) of an n_params x n_params symmetric matrix
n_out      = n_params + n_tril  # means followed by the n_tril entries later fed to vector_to_Cov

# load the encoder model; map_location lets GPU-saved weights load on the CPU fallback device
model = SummaryNet(hidden = hidden, last_layer = last_layer).to(device)
model.load_state_dict(torch.load(fmodel, map_location=device))
model.eval()

# output files (placeholders)
fmodel_lr = ...
fout_lr   = ...

# define the network model for the downstream task
mlp_lr_units = [16*last_layer, 16*last_layer, n_out]
lr_net = Expander(mlp_lr_units, last_layer, bn = True).to(device)
# load the inference network model
lr_net.load_state_dict(torch.load(fmodel_lr, map_location=device))
lr_net.eval(); 

### (2.2) Load VICReg test data

In [None]:
# load maps and parameters used for training (paths are placeholders)
# the loaded map array is 4-D; a singleton channel axis is inserted at position 2
maps      = np.load(...)[:, :, None, :, :]
dset_size = maps.shape[0] # data set size
splits    = maps.shape[1] # number of realizations per parameter set

params  = np.load(...)[:, None, :]
params  = np.repeat(params, splits, axis = 1) # reshape the parameters to match the shape of the maps

# pre-process the maps data set
rescale     = True
standardize = True
verbose     = True  # print a short summary of the loaded data below

if rescale:
    maps = np.log(maps+1)
if standardize:
    # global (scalar) mean/std, accumulated in float64; reused by get_maps_arr for new maps
    maps_mean, maps_std = np.mean(maps, dtype=np.float64), np.std(maps, dtype=np.float64)
    maps = (maps - maps_mean)/maps_std

if verbose:
    print('Shape of parameters and maps:', params.shape, maps.shape)
    print('Parameter 1 range of values: [{:.3f}, {:.3f}]'.format(params[:, :, 0].min(), params[:, :, 0].max()))
    print('Parameter 2 range of values: [{:.3f}, {:.3f}]'.format(params[:, :, 1].min(), params[:, :, 1].max()))

    if rescale: print('Rescale: ', rescale)
    if standardize: print('Standardize: ', standardize)

In [None]:
def get_maps_arr(params, splits, 
                 maps_mean, maps_std,
                 BoxSize = 1000.0, Npixel = 100):
    """Simulate 2D lognormal maps for one cosmology and preprocess them.

    ``params[0]`` is Omega_m and ``params[1]`` is sigma_8. Omega_b = 0.05, h = 0.7
    and n_s = 0.96 are fixed. The linear matter power spectrum comes from CCL
    (Eisenstein-Hu transfer function, scale factor 1.0) and is divided by
    ``BoxSize`` before being handed to powerbox. NOTE(review): the reason for this
    division is not shown here; it must match how the training maps were made.

    For each seed 0..splits-1, a Gaussian field of ``Npixel`` x ``Npixel`` pixels is
    drawn, turned into a lognormal overdensity exp(g - var(g)/2) - 1, then
    log(1 + .)-transformed and standardized with ``maps_mean`` / ``maps_std``.

    Returns:
        numpy array of shape (splits, 1, Npixel, Npixel).
    """
    OmegaM, sigma8 = params[0], params[1]
    OmegaB = 0.05

    cosmo_ccl = ccl.Cosmology(Omega_c=OmegaM - OmegaB, Omega_b=OmegaB,
                              h=0.7, sigma8=sigma8, n_s=0.96,
                              transfer_function='eisenstein_hu')

    def power_spectrum(k_val):
        return ccl.linear_matter_power(cosmo_ccl, k_val, 1.0)/BoxSize

    def lognormal_overdensity(seed):
        gaussian = pbox.PowerBox(
            N=Npixel,
            dim=2,
            pk=power_spectrum,
            boxlength=BoxSize,
            seed=seed,
        ).delta_x()
        # exp(g - var/2) has unit mean for a Gaussian g; subtract 1 for an overdensity
        return np.exp(gaussian - np.var(gaussian)/2) - 1

    fields = np.array([lognormal_overdensity(seed) for seed in range(splits)])
    fields = fields[:, None, :, :]
    # same preprocessing as the training maps: log(1 + delta), then standardize
    return (np.log(fields+1) - maps_mean)/maps_std

In [None]:
# set fiducial cosmological parameters (Omega_m, sigma_8); note this rebinds ``params``
params = np.array([0.3, 0.8])

# generate maps to estimate covariance matrix for VICReg summaries; get_maps_arr applies
# log(1 + .) and standardization with the statistics of the loaded data set
dfs_2D = get_maps_arr(params, splits = 1_000, 
                      maps_mean = maps_mean, maps_std = maps_std)

# compute inferred parameters (means and covariance) from the maps using the
# encoder (``model``) and inference network (``lr_net``) loaded in section 2.1
model.eval()
lr_net.eval()
with torch.no_grad(): 
    x = torch.as_tensor(dfs_2D, dtype=torch.float32, device=device)
    representations = model(x)
    inferred_params = lr_net(representations)

    # first n_params outputs are the means, the rest parameterize the covariance
    y_pred, cov_pred = inferred_params[:, :n_params], inferred_params[:, n_params:]
    Cov = vector_to_Cov(cov_pred.cpu()).to(device=device)
   

### (2.3) Compute and plot coverage

In [None]:
alpha = torch.linspace(0, 1, 20).to(device=device)

# ground truth: the fiducial parameters, one copy per prediction. Built as float32 because
# the networks run in float32 and MultivariateNormal.log_prob needs matching dtypes
# (torch.tensor([numpy_array]) would give float64).
thetas = torch.as_tensor(params, dtype=torch.float32, device=device).repeat(y_pred.shape[0], 1)
coverage_arr = compute_coverage_normal(thetas, y_pred, Cov, num_monte_carlo= 1_000, alpha=alpha)

# plt.figure (not plt.plot) is what accepts figsize
plt.figure(figsize=(10, 10))
plt.gca().set_aspect('equal')
plt.plot(alpha.cpu(), alpha.cpu(), c = 'k', ls = '--')  # perfect-calibration diagonal
plt.plot(alpha.cpu(), coverage_arr.cpu(),  
         c = 'coral', label = 'VICReg', lw=2)
plt.xlabel('Confidence level')
plt.ylabel('Empirical coverage')
plt.legend(loc = 'best', fontsize=10)

## (3) Emulator + SBI

### (3.1) Load the emulator model

In [None]:
def build_maf(dim=1, num_transforms=8, context_features=None, hidden_features=128):
    """Build a (optionally conditional) masked autoregressive flow with nflows.

    The flow stacks ``num_transforms`` steps. Each step is a
    MaskedAffineAutoregressiveTransform (2 blocks of width ``hidden_features``, tanh
    activations, no residual blocks, dropout, batch norm or random masks) followed
    by a random permutation of the ``dim`` features. The base distribution is a
    ``dim``-dimensional standard normal.

    Args:
        dim: number of features modelled by the flow.
        num_transforms: number of autoregressive + permutation steps.
        context_features: size of the conditioning vector, or None for an unconditional flow.
        hidden_features: hidden width of each autoregressive network.

    Returns:
        nflows ``flows.Flow`` instance.
    """
    def _autoregressive_step():
        # one flow step: affine autoregressive transform, then shuffle the feature order
        made = transforms.MaskedAffineAutoregressiveTransform(
            features=dim,
            hidden_features=hidden_features,
            context_features=context_features,
            num_blocks=2,
            use_residual_blocks=False,
            random_mask=False,
            activation=torch.tanh,
            dropout_probability=0.0,
            use_batch_norm=False,
        )
        shuffle = transforms.RandomPermutation(features=dim)
        return transforms.CompositeTransform([made, shuffle])

    steps = [_autoregressive_step() for _ in range(num_transforms)]
    full_transform = transforms.CompositeTransform(steps)

    base_distribution = distributions_.StandardNormal((dim,))
    return flows.Flow(full_transform, base_distribution)


In [None]:
# load the emulator model: a MAF over summary vectors, conditioned on 2 parameters
fmodel = ...  # placeholder path to the trained flow weights

last_layer = 16  # flow dimension, i.e. size of the emulated summary vectors
flow_net = build_maf(last_layer, context_features=2).to(device)
flow_net.load_state_dict(
    torch.load(fmodel, map_location=device)
)
flow_net.eval(); 

### (3.2) Load posterior

In [None]:
# fiducial cosmology (Omega_m, sigma_8) at which coverage is evaluated
params = torch.tensor([0.3, 0.8])
OmegaM, sigma8 = params[0], params[1]

# load the trained posterior (path is a placeholder); the context manager closes the file.
# NOTE: pickle.load can execute arbitrary code - only load posterior files you created or trust.
with open(..., "rb") as fposterior:
    posterior = pickle.load(fposterior)

### (3.3) Simulate test summaries with the emulator, compute and plot coverage

In [None]:
# ground truth: the fiducial parameters repeated once per emulated observation, shape (1000, 2)
thetas = torch.stack([OmegaM, sigma8]).unsqueeze(0).repeat(1000, 1)
# one emulated summary vector per parameter row; the flow returns a sample axis (size 1 here)
with torch.no_grad():
    xs = flow_net.sample(num_samples=1, context=thetas.to(device=device)).cpu()[:, 0, :]

alpha = torch.linspace(0, 1, 20)
coverage_arr = compute_coverage(posterior, thetas.numpy(), xs.numpy(), alpha=alpha)

# plt.figure (not plt.plot) is what accepts figsize
plt.figure(figsize=(10, 10))
plt.gca().set_aspect('equal')
plt.plot(alpha, alpha, c = 'k', ls = '--')  # perfect-calibration diagonal
plt.plot(alpha, coverage_arr,  
         c = 'teal', label = 'Emulator + SBI')
plt.xlabel('Confidence level')
plt.ylabel('Empirical coverage')
plt.legend(loc='best', fontsize=10)
