In [37]:
import numpy as np
from triangular_transport.flows.methods.utils import UnitGaussianNormalizer
from scipy.stats import norm, chi2
import h5py

In [8]:
def read_data_h5(path="data.h5"):
    """Read the ``/target`` and ``/data`` datasets from an HDF5 file.

    Parameters
    ----------
    path : str, optional
        Location of the HDF5 file, opened read-only. Defaults to ``"data.h5"``.

    Returns
    -------
    tuple
        ``(targets, data)``: the full contents of the ``/target`` and
        ``/data`` datasets, in that order.
    """
    with h5py.File(path, "r") as h5_file:
        # ``[...]`` reads each dataset in full while the file is still open.
        loaded = tuple(h5_file[key][...] for key in ("/target", "/data"))
    return loaded

In [None]:
# Sample matrix used by the cells below; rows are treated as observations
# (reductions use axis=0 and np.cov is called with rowvar=False).
ys = np.load("training_dataset/solutions_grid_delta.npy")
# Contents of the /target and /data datasets of the default HDF5 file;
# the /data array is the observation compared against `ys` below.
targets, yobs = read_data_h5()

In [51]:
# Per-column sample mean and standard deviation of ys
# (std uses numpy's default normalisation, ddof=0).
ys_mean = ys.mean(axis=0)
ys_std = ys.std(axis=0)

# Marginal Gaussian quantiles: the 0.98 and 0.68 points of a normal
# distribution with the column's mean and std, one value per column.
ys_98, ys_68 = (
    norm.ppf(level, loc=ys_mean, scale=ys_std) for level in (0.98, 0.68)
)

# Sample covariance between columns (rows are observations) and its inverse,
# which defines the Mahalanobis metric used further down.
cov = np.cov(ys, rowvar=False)
inv_cov = np.linalg.inv(cov)

def mahal_square(y):
    """Squared Mahalanobis distance of ``y`` from the sample mean.

    Uses the module-level ``ys_mean`` (centre) and ``inv_cov`` (inverse
    covariance) and evaluates ``(y - ys_mean)^T inv_cov (y - ys_mean)``.

    Parameters
    ----------
    y : array_like
        Either a single point of shape ``(d,)`` or a batch of points of
        shape ``(..., d)``; the last axis must match ``ys_mean``.

    Returns
    -------
    numpy scalar or ndarray
        A scalar for a single point, otherwise an array of shape
        ``y.shape[:-1]`` holding one squared distance per point.
    """
    diff = y - ys_mean
    # Row-wise quadratic form. Reducing over the last axis lets the same
    # expression serve a single vector and a batch of vectors; the previous
    # ``diff.T @ inv_cov @ diff`` was only valid for a 1-D input.
    return np.sum(diff @ inv_cov * diff, axis=-1)

# Squared Mahalanobis distance of the observation yobs from the mean of ys.
mahal = mahal_square(yobs)

In [42]:
# Chi-square CDF, with as many degrees of freedom as ys has columns,
# evaluated at the squared Mahalanobis distance of yobs.
p = chi2(df=ys.shape[1]).cdf(mahal)
p

np.float64(0.8501599448129048)

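The value $p$ says that $y_{\mathrm{obs}}$ lies on the boundary of an ellipsoid (a level set of the fitted Gaussian) that contains a fraction $p$ of the Gaussian probability mass. It is therefore similar to a quantile, though not quite the same: $p$ is the $\chi^2_d$ CDF evaluated at the squared Mahalanobis distance of $y_{\mathrm{obs}}$, where $d$ is the dimension of $y$, i.e. the exact quantile level of that distance under the Gaussian assumption (treating the sample mean and covariance as exact). This is harder to interpret than a scalar quantile because the $y$'s are multivariate, and in more than one dimension there is no unique notion of a quantile.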

In [54]:
# Squared Mahalanobis distance of every row of ys from ys_mean:
# the quadratic form diff_i^T inv_cov diff_i, evaluated row by row.
diff = ys - ys_mean
mahal_all = ((diff @ inv_cov) * diff).sum(axis=1)

In [60]:
# Empirical 0.98 quantile of the squared Mahalanobis distances of the samples.
q98 = np.quantile(mahal_all, 0.98)
# Take the actual row of ys whose distance lies nearest to that quantile.
# Note: y_98 is a real sample, unlike ys_98 (the per-column Gaussian quantile).
idx_closest = np.abs(mahal_all - q98).argmin()
y_98 = ys[idx_closest]
y_98

array([0.07550486, 0.07434839, 0.0733999 , 0.07435156, 0.0746158 ,
       0.07471459, 0.07420093, 0.07299112, 0.07272377, 0.07315236,
       0.21279691, 0.21187083, 0.21333539, 0.21486062, 0.2150751 ,
       0.21465419, 0.21412714, 0.21373279, 0.21415745, 0.2150088 ,
       0.33897259, 0.33998645, 0.3415815 , 0.34159089, 0.34174503,
       0.34092612, 0.34043084, 0.34131264, 0.34407856, 0.34519817,
       0.4563399 , 0.4571698 , 0.45755351, 0.45774189, 0.45874082,
       0.45815771, 0.45658872, 0.45693057, 0.45728721, 0.45833377,
       0.56420405, 0.56426266, 0.56385503, 0.5647712 , 0.56634108,
       0.56732625, 0.56586707, 0.56274715, 0.56005272, 0.55874132,
       0.65798134, 0.65760785, 0.65849768, 0.66102324, 0.66294856,
       0.66460024, 0.66412648, 0.66073191, 0.65670552, 0.65421733,
       0.74410571, 0.74362536, 0.74444961, 0.7467645 , 0.75009951,
       0.75290926, 0.75394374, 0.75240504, 0.74982908, 0.74827922,
       0.82486656, 0.82528317, 0.82619393, 0.82782884, 0.83066

Repeat the same procedure below for the median (the 0.5 quantile of the Mahalanobis distances), following exactly the steps above.