In [38]:
from wmin.api import API
from validphys.fkparser import load_fktable
from super_net.theory_predictions import make_dis_prediction, make_had_prediction, OP

from super_net.covmats import sqrt_covmat_jax

import jax.scipy.linalg as jla
import jax
import jax.numpy as jnp
import itertools

import time
import timeit, functools

In [2]:
inp = {
    "dataset_inputs": [
{'dataset':'NMC'},
{'dataset': 'SLACP_dwsh'},
{'dataset': 'SLACD_dw_ite'},
{'dataset': 'BCDMSP_dwsh'},
{'dataset': 'BCDMSD_dw_ite'},
{'dataset': 'CHORUSNUPb_dw_ite'},
{'dataset': 'CHORUSNBPb_dw_ite'},
# {'dataset': 'NTVNUDMNFe_dw_ite', 'cfac': ['MAS']},
# {'dataset': 'NTVNBDMNFe_dw_ite', 'cfac': ['MAS']},
# {'dataset': 'HERACOMBNCEM'} ,
# {'dataset': 'HERACOMBNCEP575'},
# {'dataset': 'HERACOMBNCEP820'},
# {'dataset': 'HERACOMBNCEP920'},
# {'dataset': 'HERACOMBNCEP460'},
# {'dataset': 'HERACOMBCCEP'},
# {'dataset': 'HERACOMBCCEM'} ,
# {'dataset':'HERACOMB_SIGMARED_B'},
# {'dataset': 'HERACOMB_SIGMARED_C'}
                      ],

    "positivity":{
  "posdatasets":[
  {"dataset": "POSF2U", "maxlambda": 1e6},      # Positivity Lagrange Multiplier
  {"dataset": "POSF2DW", "maxlambda": 1e6},
  {"dataset": "POSF2S", "maxlambda": 1e6},
  {"dataset": "POSFLL", "maxlambda": 1e6},
  {"dataset": "POSF2C", "maxlambda": 1e6},
  {"dataset": "POSXUQ", "maxlambda": 1e6},       # Positivity of MSbar PDFs
  {"dataset": "POSXUB", "maxlambda": 1e6},
  {"dataset": "POSXDQ", "maxlambda": 1e6},
  {"dataset": "POSXDB", "maxlambda": 1e6},
  {"dataset": "POSXSQ", "maxlambda": 1e6},
  {"dataset": "POSXSB", "maxlambda": 1e6},
  {"dataset": "POSXGL", "maxlambda": 1e6},
  {"dataset": "POSDYU", "maxlambda": 1e10},      
  {"dataset": "POSDYD", "maxlambda": 1e10},
  {"dataset": "POSDYS", "maxlambda": 1e10},
    ]
    },
    "theoryid": 400,
    "use_cuts": "internal",
    
    # wmin basis specs
    "wminpdfset": "210623_mnc_disonly_linear_1000",
    "n_replicas_wmin": 100,
    
    # Level 0 closure test
    # "fakedata": True,
    # "pseudodata": False, 
    # "closure_test_pdf": "210623_mnc_disonly_linear_1000",
    
    # fit specs
    "use_t0": True,
    "t0pdfset": "210623_mnc_disonly_linear_1000",
    
    # "bayesian_fit": True,
    "wmin_grid_index": 1, # random seed used for random parametrisation of wmin pdf
#     "replica_index": 1,  # random seed used for random noise to central values
    "trval_index": 1
}

In [3]:
weight_minimization_grid = API.weight_minimization_grid(**inp)
data_values = API.make_data_values(**inp)
pred_data = API.make_pred_data(**inp)

LHAPDF 6.5.0 loading all 985 PDFs in set 210623_mnc_disonly_linear_1000
210623_mnc_disonly_linear_1000, version 1; 985 PDF members


2023-11-21 14:03:41.649634: E external/org_tensorflow/tensorflow/compiler/xla/python/pjit.cc:606] fastpath_data is none
2023-11-21 14:03:41.847687: E external/org_tensorflow/tensorflow/compiler/xla/python/pjit.cc:606] fastpath_data is none
2023-11-21 14:03:41.878020: E external/org_tensorflow/tensorflow/compiler/xla/python/pjit.cc:606] fastpath_data is none
2023-11-21 14:03:42.081876: E external/org_tensorflow/tensorflow/compiler/xla/python/pjit.cc:606] fastpath_data is none
2023-11-21 14:03:42.228008: E external/org_tensorflow/tensorflow/compiler/xla/python/pjit.cc:606] fastpath_data is none
2023-11-21 14:03:42.228566: E external/org_tensorflow/tensorflow/compiler/xla/python/pjit.cc:606] fastpath_data is none


LHAPDF 6.5.0 loading /Users/luca/opt/miniconda3/envs/supernet/share/LHAPDF/210623_mnc_disonly_linear_1000/210623_mnc_disonly_linear_1000_0000.dat
210623_mnc_disonly_linear_1000 PDF set, member #0, version 1


2023-11-21 14:03:53.517587: E external/org_tensorflow/tensorflow/compiler/xla/python/pjit.cc:606] fastpath_data is none
2023-11-21 14:03:53.645712: E external/org_tensorflow/tensorflow/compiler/xla/python/pjit.cc:606] fastpath_data is none


In [4]:
def make_chi2_original(
    make_data_values,
    make_pred_data,
):
    """
    Returns a jax.jit compiled function that computes the chi2
    of a pdf grid on a dataset.

    Notes:
        - This function is designed for Bayesian like PDF fits.

    Parameters
    ----------
    make_data_values: training_validation.MakeDataValues
        dataclass containing data for training and validation.

    make_pred_data: theory_predictions.make_pred_data
        super_net provider for (fktable) theory predictions.

    make_posdata_split: training_validation.PosdataTrainValidationSplit
        dataclass inheriting from monte_carlo_utils.TrainValidationSplit

    make_penalty_posdata: theory_predictions.make_penalty_posdata
        super_net provider used to compute positivity penalty.

    alpha: float

    lambda_positivity: float

    Returns
    -------
    @jax.jit Callable
        function to compute chi2 of a pdf grid.

    """
    training_data = make_data_values.training_data
    central_values = training_data.central_values
    covmat = training_data.covmat
    central_values_idx = training_data.central_values_idx

    # sqrt_covmat = jnp.array(sqrt_covmat_jax(covmat))
    inv_covmat = jla.inv(covmat)

    @jax.jit
    def chi2(pdf):
        """ """
        diff = make_pred_data(pdf)[central_values_idx] - central_values

        # sqrt_covmat = jnp.array(sqrt_covmat_jax(covmat))

        # solve_triangular: solve the equation a x = b for x, assuming a is a triangular matrix.
        # chi2_vec = jla.solve_triangular(sqrt_covmat, diff, lower=True)
        # loss = jnp.sum(chi2_vec**2)
        loss = jnp.einsum("i,ij,j", diff, inv_covmat, diff)

        return loss

    return chi2

chi2_original = make_chi2_original(data_values, pred_data)

In [5]:
@jax.jit
def log_likelihood(weights):
    """
    TODO
    """
    wmin_weights = jnp.concatenate((jnp.array([1.0]), weights))
    pdf = jnp.einsum(
        "i,ijk", wmin_weights, weight_minimization_grid.wmin_INPUT_GRID
    )
    return -0.5 * chi2_original(pdf)

In [6]:
def make_chi2_opt(make_data_values, make_pred_data):
    """
    TODO
    """
    training_data = make_data_values.training_data
    central_values = training_data.central_values
    covmat = training_data.covmat
    central_values_idx = training_data.central_values_idx

    predictions = []
    for i in range(weight_minimization_grid.wmin_INPUT_GRID.shape[0]):
        predictions.append(make_pred_data(weight_minimization_grid.wmin_INPUT_GRID[i]))
    predictions = jnp.array(predictions)

    sqrt_covmat = jnp.array(sqrt_covmat_jax(covmat))

    inv_covmat = jla.inv(covmat)

    @jax.jit
    def chi2(weights):
        """
        TODO
        """
        wmin_weights = jnp.concatenate((jnp.array([1.0]), weights))
        theory = jnp.einsum(
            "i,ij", wmin_weights, predictions
        )
        diff = (
            theory[central_values_idx]
            - central_values
        )

        # solve_triangular: solve the equation a x = b for x, assuming a is a triangular matrix.
        # chi2_vec = jla.solve_triangular(sqrt_covmat, diff, lower=True)
        # loss = jnp.sum(chi2_vec**2)

        loss = jnp.einsum("i,ij,j", diff, inv_covmat, diff)
        
        return loss

    return chi2

chi2_opt = make_chi2_opt(data_values, pred_data)

In [7]:
@jax.jit
def log_likelihood_opt(weights):
    """
    TODO
    """
    return -0.5 * chi2_opt(weights)

In [163]:
weights = jax.random.uniform(jax.random.PRNGKey(758493), shape=(10000, 99))

In [165]:
chi2s_or = []
t0 = time.time()
for weight in weights:
    chi2s_or.append(log_likelihood(weight))

t1 = time.time()

total = t1-t0

print("Time for evalutation:", total)

Time for evalutation: 3.0119669437408447


In [14]:
chi2s_opt = []
t0 = time.time()
for weight in weights:
    chi2s_opt.append(log_likelihood_opt(weight))

t1 = time.time()

total = t1-t0

print("Time for evalutation:", total)

Time for evalutation: 0.020735979080200195


## Test positivity impact

In [166]:
def make_penalty_posdataset(posdataset):
    """
    Given a PositivitySetSpec compute the positivity penalty
    as a lagrange multiplier times elu of minus the theory prediction

    Parameters
    ----------
    posdataset : validphys.core.PositivitySetSpec

    Returns
    -------
    @jax.jit CompiledFunction
        Compiled function taking pdf grid and alpha parameter
        of jax.nn.elu function in input and returning
        elu function evaluated on minus the theory prediction

        Note: this is needed in order to compute the positivity
        loss function. Elu function is used to avoid a big discontinuity
        in the derivative at 0 when the lagrange multiplier is very big.

        In practice this function can produce results in the range (-alpha, inf)

        see also nnpdf.n3fit.src.layers.losses.LossPositivity

    """

    pred_funcs = []

    for fkspec in posdataset.fkspecs:
        fk = load_fktable(fkspec).with_cuts(posdataset.cuts)
        if fk.hadronic:
            pred = make_had_prediction(fk)
        else:
            pred = make_dis_prediction(fk)
        pred_funcs.append(pred)

    @jax.jit
    def pos_penalty(pdf, alpha, lambda_positivity):
        return lambda_positivity * jax.nn.elu(
            -OP[posdataset.op](*[f(pdf) for f in pred_funcs]), alpha
        )

    return pos_penalty


def make_penalty_posdata(posdatasets):
    """
    Compute positivity penalty for list of PositivitySetSpec

    Parameters
    ----------
    posdatasets: list
            list of PositivitySetSpec

    Returns
    -------
    @jax.jit CompiledFunction

    """

    predictions = []

    for posdataset in posdatasets:
        predictions.append(make_penalty_posdataset(posdataset))

    @jax.jit
    def pos_penalties(pdf, alpha, lambda_positivity):
        # return predictions[1](pdf, alpha, lambda_positivity)
        return jnp.array(
            list(
                itertools.chain(
                    *[f(pdf, alpha, lambda_positivity) for f in predictions[12:]]
                )
            )
        )

    return pos_penalties

In [167]:
posdata_split = API.make_posdata_split(**inp)
penalty_posdata = make_penalty_posdata(API.posdatasets(**inp))

In [168]:
def make_chi2_with_positivity(
    make_data_values,
    make_pred_data,
    make_posdata_split,
    make_penalty_posdata,
    alpha=1e-7,
    lambda_positivity=1000,
):
    """
    Returns a jax.jit compiled function that computes the chi2
    of a pdf grid on a dataset.

    Notes:
        - This function is designed for Bayesian like PDF fits.

    Parameters
    ----------
    make_data_values: training_validation.MakeDataValues
        dataclass containing data for training and validation.

    make_pred_data: theory_predictions.make_pred_data
        super_net provider for (fktable) theory predictions.

    make_posdata_split: training_validation.PosdataTrainValidationSplit
        dataclass inheriting from monte_carlo_utils.TrainValidationSplit

    make_penalty_posdata: theory_predictions.make_penalty_posdata
        super_net provider used to compute positivity penalty.

    alpha: float

    lambda_positivity: float

    Returns
    -------
    @jax.jit Callable
        function to compute chi2 of a pdf grid.

    """
    training_data = make_data_values.training_data
    central_values = training_data.central_values
    covmat = training_data.covmat
    central_values_idx = training_data.central_values_idx

    # Invert the covmat
    inv_covmat = jla.inv(covmat)

    posdata_training_idx = make_posdata_split.training

    @jax.jit
    def chi2(pdf):
        """ """
        diff = make_pred_data(pdf)[central_values_idx] - central_values

        loss = jnp.einsum("i,ij,j", diff, inv_covmat, diff)

        # add penalty term due to positivity
        pos_penalty = make_penalty_posdata(pdf, alpha, lambda_positivity)[
            posdata_training_idx
        ]
        loss += jnp.sum(pos_penalty)

        return loss

    return chi2

chi2_withpos = make_chi2_with_positivity(data_values, pred_data, posdata_split, penalty_posdata)

In [169]:
@jax.jit
def log_likelihood_withpos(weights):
    """
    TODO
    """
    wmin_weights = jnp.concatenate((jnp.array([1.0]), weights))
    pdf = jnp.einsum(
        "i,ijk", wmin_weights, weight_minimization_grid.wmin_INPUT_GRID
    )
    return -0.5 * chi2_withpos(pdf)

In [171]:
chi2s_withpos = []
t0 = time.time()
for weight in weights:
    chi2s_withpos.append(log_likelihood_withpos(weight))

t1 = time.time()

total = t1-t0

print("Time for evalutation:", total)

Time for evalutation: 13.286328792572021
