# Extend N3FIT to integrate nDIS/nDY

The following notebook replicates the main parts of the **n3fit** fitting code. Its main purpose is to serve as a playground for implementing various (subtle) features such as the inclusion of nuclear fits. As in the main fitting code, we rely on `n3fit.backends` to perform various operations. Importantly, this notebook also relies on a modified version of some parts of the main code; these changes mainly affect the `backends` and `layers` modules. We proceed incrementally, completing the **nDIS** part first and then the **nDY** part.

## 1. Short introduction

Most nuclear neutral-current (NC) DIS datasets are given as ratios of structure functions for different nuclei:
$$ \mathcal{O} (x, A_1, A_2, Q^2) = \frac{F_2 (x, A_1, Q^2)}{F_2 (x, A_2, Q^2)} \quad \mathrm{with} \quad F_2 (x, A, Q^2) = \sum^{n_f}_{i} \sum^{n_x}_{\alpha} \mathrm{FK}_{i} (x, x_\alpha, Q^2, Q^2_0) f_i^A (x_\alpha, Q^2_0) $$
where $A$ denotes the atomic mass number, $f^A$ denotes the bound-nucleon PDF for a nucleus with atomic mass number $A$, $x_\alpha$ runs over the points of the interpolation grid, and the rest carries the usual meaning. It is important to emphasize that the **FK** tables that appear in the numerator and in the denominator are the same. In turn, the bound-nucleon PDFs $f_i^A$ at a momentum fraction $x$ and scale $Q^2_0$ are expressed in terms of the bound-proton PDFs $f_i^{p/A}$ and bound-neutron PDFs $f_i^{n/A}$ as follows:
$$ f_i^A (x, Q^2_0) = Z f_i^{p/A} (x, Q^2_0) + (A-Z) f_i^{n/A} (x, Q^2_0). $$
Here $Z$ is the atomic number (number of protons) of the nucleus. The bound-proton and bound-neutron PDFs are related by **isospin symmetry** via the following relations:
$$ u^{p/A}(x, Q^2_0) = d^{n/A}(x, Q^2_0), \: d^{p/A}(x, Q^2_0) = u^{n/A}(x, Q^2_0), \: \bar{u}^{p/A}(x, Q^2_0) = \bar{d}^{n/A}(x, Q^2_0), \: \bar{d}^{p/A}(x, Q^2_0) = \bar{u}^{n/A}(x, Q^2_0) $$
and $f_i^{p/A} = f_i^{n/A}$ for the other PDF flavours. In practice, one fits the bound-proton PDFs, on which constraints such as **sum rules** can be imposed.
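
As a rough illustration of the two relations above, the sketch below builds $f^A$ from bound-proton PDFs in a flavour basis. The flavour ordering `(u, d, ubar, dbar, s, g)`, the function names and the toy numbers are assumptions made for this example only and are not part of n3fit.

```python
import numpy as np

# Hypothetical flavour ordering used only in this sketch.
FLAVOURS = ["u", "d", "ubar", "dbar", "s", "g"]
# The isospin relations swap u <-> d and ubar <-> dbar; other flavours are untouched.
ISOSPIN_SWAP = [1, 0, 3, 2, 4, 5]


def bound_neutron_from_proton(f_p):
    """Return bound-neutron PDFs from bound-proton PDFs of shape (n_x, n_fl)."""
    return f_p[:, ISOSPIN_SWAP]


def nuclear_pdf(f_p, A, Z):
    """Combine the two contributions: Z f^{p/A} + (A - Z) f^{n/A}."""
    f_n = bound_neutron_from_proton(f_p)
    return Z * f_p + (A - Z) * f_n


# Toy bound-proton PDFs on 3 x-points (made-up numbers).
f_p = np.array([
    [2.0, 1.0, 0.3, 0.4, 0.2, 5.0],
    [1.5, 0.8, 0.2, 0.3, 0.1, 3.0],
    [0.5, 0.2, 0.0, 0.1, 0.0, 1.0],
])
f_C = nuclear_pdf(f_p, A=12, Z=6)  # carbon: 6 protons and 6 neutrons
```

For an isoscalar nucleus such as carbon ($Z = A/2$) the resulting $u$ and $d$ distributions coincide, which is a quick sanity check of the swap.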

#### Evolution basis version:

## 2. Import modules

In [1]:
import random
import pickle
import numpy as np
import tensorflow as tf
import logging

from abc import abstractmethod, ABC
from dataclasses import dataclass
from rich.console import Console
from rich.table import Table

# Use n3fit backends which are wrappers around
# tf.keras backends.
from n3fit.backends import Input
from n3fit.backends import base_layer_selector
from n3fit.backends import MetaModel
from n3fit.backends import callbacks
from n3fit.backends import MetaLayer
from n3fit.backends import operations as op
from n3fit.backends import clear_backend_state

from n3fit.stopping import Stopping
from n3fit.msr import msr_impose

from n3fit.layers import DIS, DY
from n3fit.layers import ObsRotation
from n3fit.layers import losses
from n3fit.layers import Preprocessing
from n3fit.layers import FkRotation
from n3fit.layers import FlavourToEvolution
from n3fit.backends import Lambda
from n3fit.backends import regularizer_selector

# Define seeds
random.seed(123)
np.random.seed(456)
console = Console()

from n3fit.vpinterface import N3PDF
log = logging.getLogger(__name__)

Using Keras backend


## 3. Load toy-datasets and Add $A$-dependence

For the sake of simplicity, we use as inputs saved files (`toyexpinfo.pkl` for experimental datasets, `posdatasets.pkl` for positivity datasets, and `integdatasets.pkl` for integrability datasets) generated from an **n3fit** run using the `toy-runcard.yaml`.

In [2]:
# Use a context manager so that the file handle is closed after loading
with open("toyexpinfo.pkl", "rb") as exp_pkl_file:
    toyexpinfo = pickle.load(exp_pkl_file)

In [3]:
with open("posdatasets.pkl", "rb") as pos_pkl_file:
    toyposdatasets = pickle.load(pos_pkl_file)

In [4]:
with open("integdatasets.pkl", "rb") as integ_pkl_file:
    toyintegdatasets = pickle.load(integ_pkl_file)

Now, we need to **add** the $A$-dependence to the various datasets so that it is passed along with the FK tables for training. How to include this information in **n3fit/validphys** is still an open question.

In [5]:
A_dicts = {"PB": 208, "BE": 9, "C": 12}

In [6]:
def add_A_dependence(expinfo: list) -> list:
    """Takes the usual inputs for n3fit (raw datasets from validphys/NNPDF)
    and adds the A-values to the dictionaries. This information should
    eventually be provided by validphys/n3fit, for example through the
    input runcard.

    This is just for testing purposes, i.e. the assigned values carry no
    physical meaning.
    """
    for dataset_group in expinfo:
        for dataset in dataset_group["datasets"]:
            name_split = dataset["name"].split("_")
            if len(name_split) <= 2:
                A1, A2 = 1, 1
            elif len(name_split) == 3:
                A1 = A_dicts[name_split[1]]
                A2 = A_dicts[name_split[2]]
            else:
                raise ValueError(f"Unexpected dataset name: {dataset['name']}")
            dataset["A1"] = A1
            dataset["A2"] = A2
    return expinfo

In [7]:
def list_active_A(expinfo: list) -> list:
    """Takes the list returned by `add_A_dependence`, in which the information
    on the atomic mass number A is included, and returns a sorted list of
    the distinct A values included in the fit.
    """
    A_lists = []
    for dataset_group in expinfo:
        for dataset in dataset_group["datasets"]:
            A_lists.append(dataset["A1"])
            A_lists.append(dataset["A2"])
    return sorted(set(A_lists))

In [8]:
def summarize_expinfo(expinfo: list) -> None:
    """Summarize the information concerning the atomic mass number
    for the datasets included in the actual fitting playgrounds.
    """
    table = Table(show_header=True, header_style="bold magenta")
    table.add_column("Dataset", justify="left", width=24)
    table.add_column("A1", justify="left", width=24)
    table.add_column("A2", justify="left", width=24)
    for dataset_group in expinfo:
        for dataset in dataset_group["datasets"]:
            table.add_row(
                f"{dataset['name']}", 
                f"{dataset['A1']}", 
                f"{dataset['A2']}"
            )
    console.print(table)

In [9]:
new_toyexpinfo = add_A_dependence(toyexpinfo)
list_fitted_A = list_active_A(new_toyexpinfo)

In [10]:
summarize_expinfo(new_toyexpinfo)
console.print(f"List of fitted As: {list_fitted_A}", style="bold cyan")
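
To make the naming convention assumed by `add_A_dependence` concrete, here is a small self-contained sketch. The dataset names (`TOYEXP_PB_C`, `TOYPROTON`, ...), the helper `parse_nuclei` and the lookup table `A_VALUES` are hypothetical and only mimic the `<EXPERIMENT>_<NUCLEUS1>_<NUCLEUS2>` pattern used above.

```python
# Hypothetical lookup table and dataset names, used only for this sketch.
A_VALUES = {"PB": 208, "BE": 9, "C": 12}


def parse_nuclei(name):
    """Return (A1, A2) following the <EXPERIMENT>_<NUCLEUS1>_<NUCLEUS2> convention."""
    parts = name.split("_")
    if len(parts) <= 2:
        return 1, 1  # treated as a free-proton dataset
    if len(parts) == 3:
        return A_VALUES[parts[1]], A_VALUES[parts[2]]
    raise ValueError(f"Unexpected dataset name: {name}")


toy_expinfo = [
    {"datasets": [{"name": "TOYEXP_PB_C"}, {"name": "TOYEXP_BE_C"}]},
    {"datasets": [{"name": "TOYPROTON"}]},
]
active_A = set()
for group in toy_expinfo:
    for dataset in group["datasets"]:
        dataset["A1"], dataset["A2"] = parse_nuclei(dataset["name"])
        active_A.update((dataset["A1"], dataset["A2"]))
fitted_A = sorted(active_A)
```

Note that with this convention any name with one or two underscore-separated fields is silently treated as a proton dataset, so a nuclear dataset named with only one nucleus would be mislabelled.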

## 4. Create the NN architectures

### 4.1 Construct the hidden layers

In [11]:
def generate_dense_network(
    nodes_in,
    nodes,
    activations,
    initializer_name="glorot_normal",
    As_number=1,
    seed=0,
    dropout_rate=0.0,
    regularizer=None,
):
    """Generate the different `tf.keras.layers.Dense` layers and add them to a
    list. The number of layers is given by the length of `nodes`. This function
    mimics the behaviour of the `generate_dense_network` function in the
    `n3fit.model_gen` module.

    Modifications
    -------------
    For fits in the evolution basis, the NN outputs 8 PDFs, all of them proton
    PDFs. When nuclear PDFs are included, this number is multiplied by the
    number of active A's involved in the fit.

    Parameters
    ----------
    As_number: int
        Number of active A's (number of A values involved in the fit)

    Returns
    -------
    list:
        List of `tf.keras.layers.Dense` layers; the last one has
        `nodes[-1]*As_number` units
    """
    list_of_pdf_layers = []
    # Modification: multiply the number of nodes in the last layer by the
    # number of active A's involved in the fit. Work on a copy so that the
    # caller's list is not modified in place (the function may be called
    # once per parallel model).
    nodes = list(nodes)
    nodes[-1] *= As_number
    number_of_layers = len(nodes)
    if dropout_rate > 0:
        dropout_layer = number_of_layers - 2
    else:
        dropout_layer = -1
    for i, (nodes_out, activation) in enumerate(zip(nodes, activations)):
        if dropout_rate > 0 and i == dropout_layer:
            list_of_pdf_layers.append(base_layer_selector("dropout", rate=dropout_rate))
        init = MetaLayer.select_initializer(initializer_name, seed=seed + i)
        arguments = {
            "kernel_initializer": init,
            "units": int(nodes_out),
            "activation": activation,
            "input_shape": (nodes_in,),
            "kernel_regularizer": regularizer,
        }
        layer = base_layer_selector("dense", **arguments)
        list_of_pdf_layers.append(layer)
        nodes_in = int(nodes_out)

    return list_of_pdf_layers
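
Since the last dense layer now has `nodes[-1]*As_number` units, the downstream layers need a way to recover one block of PDFs per value of $A$. A minimal NumPy sketch of one possible slicing is given below; the block ordering (all flavours of the first $A$, then all flavours of the second, and so on) is an assumption of this example and not necessarily the layout used by the modified layers.

```python
import numpy as np

n_x, basis_size = 5, 8
fitted_A = [1, 9, 12, 208]  # example list of active A values
n_A = len(fitted_A)

# Stand-in for the network output: shape (batch, n_x, basis_size * n_A).
rng = np.random.default_rng(0)
nn_output = rng.normal(size=(1, n_x, basis_size * n_A))

# Assumed layout: the first `basis_size` entries belong to fitted_A[0],
# the next `basis_size` entries to fitted_A[1], and so on.
per_A = nn_output.reshape(1, n_x, n_A, basis_size)
pdf_by_A = {A: per_A[:, :, i, :] for i, A in enumerate(fitted_A)}
```

Each entry of `pdf_by_A` then has the shape of a standard proton-fit output, `(batch, n_x, basis_size)`, so the usual per-flavour operations can be applied to it.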

### 4.3 Construct the complete Model

In [12]:
def pdfNN_layer_generator(
    inp=2,
    nodes=None,
    activations=None,
    initializer_name="glorot_normal",
    As_number=1,
    layer_type="dense",
    flav_info=None,
    fitbasis="NN31IC",
    out=14,
    seed=None,
    dropout=0.0,
    regularizer=None,
    regularizer_args=None,
    impose_sumrule=None,
    scaler=None,
    parallel_models=1,
):
    """For a proton fit, this function acts in the standard way. For A != 1,
    further extensions had to be implemented. Recall that the output of
    `generate_dense_network` has dimension (FITTING_BASIS_SIZE*As_number). The
    same operations as in the proton fit therefore apply to each individual A
    involved in the fit. As a result, various custom layers had to be extended
    to take this into account.

        * Modified layers so far: FkRotation, FlavourToEvolution, msr_impose
        * Still needs to be modified: Preprocessing
    """
    if seed is None:
        seed = parallel_models * [None]
    elif isinstance(seed, int):
        seed = parallel_models * [seed]

    if nodes is None:
        nodes = [15, 8]
    ln = len(nodes)

    if impose_sumrule is None:
        impose_sumrule = "All"

    if scaler:
        inp = 1

    if activations is None:
        activations = ["tanh", "linear"]
    elif callable(activations):
        activations = activations(ln)

    if regularizer_args is None:
        regularizer_args = dict()

    number_of_layers = len(nodes)
    last_layer_nodes = nodes[-1]  # (== len(flav_info))

    # First prepare the input for the PDF model and any scaling if needed
    placeholder_input = Input(shape=(None, 1), batch_size=1)

    subtract_one = False
    process_input = Lambda(lambda x: x)
    input_x_eq_1 = [1.0]
    if scaler:
        # change the input domain [0,1] -> [-1,1]
        process_input = Lambda(lambda x: 2 * x - 1)
        subtract_one = True
        input_x_eq_1 = scaler([1.0])[0]
        placeholder_input = Input(shape=(None, 2), batch_size=1)
    elif inp == 2:
        # create a x --> (x, logx) layer to prepend to everything
        process_input = Lambda(lambda x: op.concatenate([x, op.op_log(x)], axis=-1))

    model_input = [placeholder_input]
    if subtract_one:
        layer_x_eq_1 = op.numpy_to_input(np.array(input_x_eq_1).reshape(1, 1))
        model_input.append(layer_x_eq_1)

    # Evolution layer
    layer_evln = FkRotation(input_shape=(last_layer_nodes,), output_dim=out)
    # Basis rotation
    basis_rotation = FlavourToEvolution(flav_info=flav_info, fitbasis=fitbasis)
    # Normalization and sum rules
    if impose_sumrule:
        sumrule_layer, integrator_input = msr_impose(mode=impose_sumrule, scaler=scaler)
        model_input.append(integrator_input)
    else:
        sumrule_layer = lambda x: x

    pdf_models = []
    for i, layer_seed in enumerate(seed):
        if layer_type == "dense":
            reg = regularizer_selector(regularizer, **regularizer_args)
            list_of_pdf_layers = generate_dense_network(
                inp,
                nodes,
                activations,
                initializer_name,
                As_number=As_number,
                seed=layer_seed,
                dropout_rate=dropout,
                regularizer=reg,
            )
        elif layer_type == "dense_per_flavour":
            list_of_pdf_layers = generate_dense_per_flavour_network(
                inp,
                nodes,
                activations,
                initializer_name,
                seed=layer_seed,
                basis_size=last_layer_nodes,
            )

        def dense_me(x):
            """Takes an input tensor `x` and applies all layers
            from the `list_of_pdf_layers` in order"""
            processed_x = process_input(x)
            curr_fun = list_of_pdf_layers[0](processed_x)

            for dense_layer in list_of_pdf_layers[1:]:
                curr_fun = dense_layer(curr_fun)
            return curr_fun

        preproseed = layer_seed + number_of_layers
        layer_preproc = Preprocessing(
            flav_info=flav_info,
            input_shape=(1,),
            name=f"pdf_prepro_{i}",
            seed=preproseed,
            large_x=not subtract_one,
        )

        # Apply preprocessing and basis
        def layer_fitbasis(x):
            x_scaled = op.op_gather_keep_dims(x, 0, axis=-1)
            x_original = op.op_gather_keep_dims(x, -1, axis=-1)

            nn_output = dense_me(x_scaled)
            if subtract_one:
                nn_at_one = dense_me(layer_x_eq_1)
                nn_output = op.op_subtract([nn_output, nn_at_one])

            # Ignore preprocessing for the time being. It is still to be decided
            # how best to take preprocessing into account, in particular whether
            # a given flavour should share the same exponents for different A's.
            # ret = op.op_multiply([nn_output, layer_preproc(x_original)])
            ret = nn_output
            if basis_rotation.is_identity():
                return ret
            return basis_rotation(ret)

        # Rotation layer, changes from the 8-basis to the 14-basis
        def layer_pdf(x):
            return layer_evln(layer_fitbasis(x))

        # Final PDF (apply normalization)
        final_pdf = sumrule_layer(layer_pdf)

        # Create the model
        pdf_model = MetaModel(
            model_input, final_pdf(placeholder_input), name=f"PDF_{i}", scaler=scaler
        )
        pdf_models.append(pdf_model)
        pdf_model.summary()
    return pdf_models
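
The sum-rule layer returned by `msr_impose` fixes the normalization of some of the PDFs from integrals over an integration grid. As a hedged illustration of the idea, the sketch below computes the gluon normalization from the momentum sum rule, $\int dx\, x\,[\Sigma(x) + N_g\, g(x)] = 1$, separately for each $A$. It assumes that the arrays hold $x f(x)$, uses made-up functional forms, and treats each $A$ independently; none of this is taken from the modified `msr_impose`.

```python
import numpy as np


def trapezoid(y, x):
    """Trapezoidal rule, written out to avoid depending on a NumPy version."""
    return np.sum(0.5 * (y[1:] + y[:-1]) * np.diff(x))


def gluon_normalization(x, x_sigma, x_gluon):
    """Return the factor N_g such that int dx [x Sigma + N_g x g] = 1."""
    return (1.0 - trapezoid(x_sigma, x)) / trapezoid(x_gluon, x)


# Toy shapes for x*Sigma(x) and x*g(x) for two values of A (made-up functions).
x = np.linspace(1e-4, 1.0, 2000)
toy = {
    1: (0.6 * x**0.5 * (1 - x) ** 3, x**0.2 * (1 - x) ** 5),
    208: (0.5 * x**0.5 * (1 - x) ** 3, x**0.3 * (1 - x) ** 5),
}
norms = {A: gluon_normalization(x, xs, xg) for A, (xs, xg) in toy.items()}
```

With these factors, the total momentum fraction carried by singlet and gluon integrates to one for every $A$ on the chosen grid.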

In [13]:
flav_info = [
    {'fl': 'sng', 'trainable': False, 'smallx': [1.094, 1.118], 'largex': [1.46, 3.003]}, 
    {'fl': 'g', 'trainable': False, 'smallx': [0.8189, 1.044], 'largex': [2.791, 5.697]}, 
    {'fl': 'v', 'trainable': False, 'smallx': [0.457, 0.7326], 'largex': [1.56, 3.431]}, 
    {'fl': 'v3', 'trainable': False, 'smallx': [0.1462, 0.4061], 'largex': [1.745, 3.452]}, 
    {'fl': 'v8', 'trainable': False, 'smallx': [0.5401, 0.7665], 'largex': [1.539, 3.393]}, 
    {'fl': 't3', 'trainable': False, 'smallx': [-0.4401, 0.9163], 'largex': [1.773, 3.333]}, 
    {'fl': 't8', 'trainable': False, 'smallx': [0.5852, 0.8537], 'largex': [1.533, 3.436]}, 
    {'fl': 't15', 'trainable': False, 'smallx':[1.082, 1.142], 'largex': [1.461, 3.1]}
]
fitbasis = 'EVOL'
seed = [1872583848]
As_number = 4
nodes_in = 2
nodes = [3, 6, 8]
activations = ['tanh', 'tanh', 'tanh']

In [14]:
pdf_gen = pdfNN_layer_generator(
    nodes=nodes, 
    activations=activations, 
    As_number=As_number,
    flav_info=flav_info,
    fitbasis=fitbasis,
    seed=seed,
    impose_sumrule=True
)

Model: "PDF_0"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(1, None, 1)]       0                                            
__________________________________________________________________________________________________
integration_grid (InputLayer)   [(1, 2000, 1)]       0                                            
__________________________________________________________________________________________________
lambda_6 (Lambda)               (1, None, 1)         0           input_1[0][0]                    
__________________________________________________________________________________________________
lambda_4 (Lambda)               (1, 2000, 1)         0           integration_grid[0][0]           
______________________________________________________________________________________________

### ⚠ Implementation below not working yet ⚠

### 4.4 Construct the Observable

Now, we can implement the part that computes the Observable ($\mathcal{O}^{\rm th}$) expressed in the equation above.
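
Before turning to the layer implementation, the following NumPy sketch spells out the ratio observable for toy inputs: the same FK table is contracted with the PDFs of the two nuclei and the results are divided. Shapes, random inputs and the helper name `structure_function` are assumptions made for this illustration.

```python
import numpy as np

rng = np.random.default_rng(1)
n_data, n_fl, n_x = 4, 3, 6

# One FK table shared by numerator and denominator: shape (n_data, n_fl, n_x).
fk = rng.uniform(0.1, 1.0, size=(n_data, n_fl, n_x))

# Toy nuclear PDFs on the FK x-grid for the two nuclei: shape (n_x, n_fl).
pdf_A1 = rng.uniform(0.5, 1.5, size=(n_x, n_fl))
pdf_A2 = rng.uniform(0.5, 1.5, size=(n_x, n_fl))


def structure_function(fk, pdf):
    """Contract the FK table with the PDF over flavour (i) and x-grid (a)."""
    return np.einsum("nia,ai->n", fk, pdf)


ratio = structure_function(fk, pdf_A1) / structure_function(fk, pdf_A2)
```

Because the FK table is common to both terms, any overall rescaling of it cancels in the ratio, while the PDFs of the two nuclei must be kept as separate inputs.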

In [None]:
def _is_unique(list_of_arrays):
    """ Check whether all the arrays in the list are identical """
    the_first = list_of_arrays[0]
    for i in list_of_arrays[1:]:
        if not np.array_equal(the_first, i):
            return False
    return True


class Observable(MetaLayer, ABC):
    """
        This class is the parent of the DIS and DY convolutions.
        All backend-dependent code necessary for the convolutions
        is (must be) concentrated here.

        The methods gen_mask and call must be overridden by the observables
        where
            - gen_mask: it is called by the initializer and generates the mask between
                        fktables and pdfs
            - call: this is what does the actual operation


        Parameters
        ----------
            fktable_dicts: list
                list of fktable_dicts which define basis and xgrid for the fktables in the list
            fktable_arr: list
                list of fktables for this observable
            operation_name: str
                string defining the name of the operation to be applied to the fktables
            nfl: int
                number of flavours in the pdf (default:14)
    """

    def __init__(self, fktable_dicts, fktable_arr, operation_name, nfl=14, **kwargs):

        """
        Descriptions
        ------------
        fktable_dicts:
            list of dictionaries where each element contains some specifications
            about a given dataset. It has the following keys:
            - ndata{N_dpts}: number of datapoints
            - nbasis:
            - nonzero: number of non-zero PDFs
            - basis: indices of the non-zero PDFs
            - nx{N_x}: size of the x-grid
            - xgrid: array of x-points
            - fktable: tensor of shape (N_dpts, nonzero_fl, N_x)

        fktable_arr:
        """
        super(MetaLayer, self).__init__(**kwargs)

        self.nfl = nfl

        basis = []
        xgrids = []
        self.fktables = []
        for fktable, fk in zip(fktable_dicts, fktable_arr):
            xgrids.append(fktable["xgrid"])
            basis.append(fktable["basis"])
            self.fktables.append(op.numpy_to_tensor(fk))

        # check how many xgrids this dataset needs
        if _is_unique(xgrids):
            self.splitting = None
        else:
            self.splitting = [i.shape[1] for i in xgrids]

        # check how many basis this dataset needs
        if _is_unique(basis) and _is_unique(xgrids):
            self.all_masks = [self.gen_mask(basis[0])]
            self.many_masks = False
        else:
            self.many_masks = True
            self.all_masks = [self.gen_mask(i) for i in basis]

        self.operation = op.c_to_py_fun(operation_name)
        self.output_dim = self.fktables[0].shape[0]

    def compute_output_shape(self, input_shape):
        return (self.output_dim, None)

    # Overridables
    @abstractmethod
    def gen_mask(self, basis):
        pass

In [None]:
class DIS(Observable):

    def gen_mask(self, basis):

        if basis is None:
            # No basis information: keep all flavours
            basis_mask = np.ones(self.nfl, dtype=bool)
        else:
            basis_mask = np.zeros(self.nfl, dtype=bool)
            for i in basis:
                basis_mask[i] = True
        return op.numpy_to_tensor(basis_mask, dtype=bool)

    def call(self, pdf):

        # DIS never needs splitting
        if self.splitting is not None:
            raise ValueError("DIS layer call with a dataset that needs more than one xgrid?")

        results = []
        # Separate the two possible paths this layer can take
        if self.many_masks:
            for mask, fktable in zip(self.all_masks, self.fktables):
                pdf_masked = op.boolean_mask(pdf, mask, axis=2)
                res = op.tensor_product(pdf_masked, fktable, axes=[(1, 2), (2, 1)])
                results.append(res)
        else:
            pdf_masked = op.boolean_mask(pdf, self.all_masks[0], axis=2)
            for fktable in self.fktables:
                res = op.tensor_product(pdf_masked, fktable, axes=[(1, 2), (2, 1)])
                results.append(res)
        return self.operation(results)
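
The masking and contraction performed in `DIS.call` can be reproduced with plain NumPy, which is convenient for checking shapes. In the sketch below `np.tensordot` stands in for `op.tensor_product` and boolean indexing for `op.boolean_mask`; that these backend operations behave in the same way is an assumption of this example, and all inputs are random toy arrays.

```python
import numpy as np

nfl = 14
basis = [1, 2, 5]  # example indices of the flavours with non-zero FK entries

# Same construction as in gen_mask, written with plain NumPy.
mask = np.zeros(nfl, dtype=bool)
mask[basis] = True

rng = np.random.default_rng(2)
n_data, n_x = 3, 7
pdf = rng.normal(size=(1, n_x, nfl))                  # (batch, x, flavour)
fktable = rng.normal(size=(n_data, len(basis), n_x))  # (data, non-zero flavour, x)

pdf_masked = pdf[:, :, mask]                          # (batch, x, non-zero flavour)
# Contract the x and flavour axes of the PDF with the x and flavour axes of the FK table.
prediction = np.tensordot(pdf_masked, fktable, axes=[(1, 2), (2, 1)])
```

The result has one entry per data point and per batch element, i.e. shape `(1, n_data)` here.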

In [None]:
from n3fit.layers.losses import LossPositivity
from n3fit.layers.losses import LossInvcovmat
from n3fit.layers.losses import LossIntegrability

In [None]:
@dataclass
class ObservableWrapper:

    name: str
    observables: list
    dataset_xsizes: list
    invcovmat: np.array = None
    covmat: np.array = None
    multiplier: float = 1.0
    integrability: bool = False
    positivity: bool = False
    data: np.array = None
    rotation: ObsRotation = None  # only used for diagonal covmat

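    def _generate_loss(self, mask=None):
        """Select the loss layer: chi2 if an inverse covariance matrix is
        given, otherwise a positivity or integrability penalty."""
        if self.invcovmat is not None:
            loss = losses.LossInvcovmat(
                self.invcovmat, self.data, mask, covmat=self.covmat, name=self.name
            )
        elif self.positivity:
            loss = losses.LossPositivity(name=self.name, c=self.multiplier)
        elif self.integrability:
            loss = losses.LossIntegrability(name=self.name, c=self.multiplier)
        else:
            raise ValueError(f"No loss can be generated for {self.name}")
        return loss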

    def _generate_experimental_layer(self, pdf):
        
        if len(self.dataset_xsizes) > 1:
            splitting_layer = op.as_layer(
                op.split,
                op_args=[self.dataset_xsizes],
                op_kwargs={"axis": 1},
                name=f"{self.name}_split",
            )
            split_pdf = splitting_layer(pdf)
        else:
            split_pdf = [pdf]
        output_layers = [obs(p_pdf) for p_pdf, obs in zip(split_pdf, self.observables)]
        ret = op.concatenate(output_layers, axis=2)
        if self.rotation is not None:
            ret = self.rotation(ret)
        return ret

    def __call__(self, pdf_layer, mask=None):
        loss_f = self._generate_loss(mask)
        experiment_prediction = self._generate_experimental_layer(pdf_layer)
        return loss_f(experiment_prediction)
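
For the experimental datasets the wrapper attaches a loss built from the inverse covariance matrix. The quantity being minimized is the usual $\chi^2 = (t - d)^T C^{-1} (t - d)$; the sketch below evaluates it with NumPy for made-up numbers. The function `chi2` and the way the mask is applied (zeroing the difference on excluded points) are assumptions of this example and do not describe the internals of `LossInvcovmat`.

```python
import numpy as np


def chi2(prediction, data, invcovmat, mask=None):
    """Compute (t - d)^T C^{-1} (t - d), optionally zeroing out masked points."""
    diff = prediction - data
    if mask is not None:
        diff = diff * mask
    return float(diff @ invcovmat @ diff)


# Toy data, predictions and a diagonal covariance matrix (made-up numbers).
data = np.array([1.0, 2.0, 3.0])
prediction = np.array([1.1, 1.9, 3.2])
covmat = np.diag([0.01, 0.04, 0.04])
invcovmat = np.linalg.inv(covmat)

full = chi2(prediction, data, invcovmat)
training_only = chi2(prediction, data, invcovmat, mask=np.array([1.0, 1.0, 0.0]))
```

Masking the last point removes its contribution, so `training_only` is smaller than `full` by exactly that point's term when the covariance matrix is diagonal.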

In [None]:
def observable_generator(
    spec_dict, positivity_initial=1.0, integrability=False
):
    
    spec_name = spec_dict["name"]
    dataset_xsizes = []
    model_obs_tr = []
    model_obs_vl = []
    model_obs_ex = []
    model_inputs = []

    for dataset_dict in spec_dict["datasets"]:
        dataset_name = dataset_dict["name"]

        if dataset_dict["hadronic"]:
            Obs_Layer = DY
        else:
            Obs_Layer = DIS

        operation_name = dataset_dict["operation"]

        if spec_dict["positivity"]:
            obs_layer_tr = Obs_Layer(
                dataset_dict["fktables"],
                dataset_dict["tr_fktables"],
                operation_name,
                name=f"dat_{dataset_name}",
            )
            obs_layer_ex = obs_layer_vl = None
        elif spec_dict.get("data_transformation_tr") is not None:
            obs_layer_ex = Obs_Layer(
                dataset_dict["fktables"],
                dataset_dict["ex_fktables"],
                operation_name,
                name=f"exp_{dataset_name}",
            )
            obs_layer_tr = obs_layer_vl = obs_layer_ex
        else:
            obs_layer_tr = Obs_Layer(
                dataset_dict["fktables"],
                dataset_dict["tr_fktables"],
                operation_name,
                name=f"dat_{dataset_name}",
            )
            obs_layer_ex = Obs_Layer(
                dataset_dict["fktables"],
                dataset_dict["ex_fktables"],
                operation_name,
                name=f"exp_{dataset_name}",
            )
            obs_layer_vl = Obs_Layer(
                dataset_dict["fktables"],
                dataset_dict["vl_fktables"],
                operation_name,
                name=f"val_{dataset_name}",
            )

        if obs_layer_tr.splitting is None:
            xgrid = dataset_dict["fktables"][0]["xgrid"]
            model_inputs.append(xgrid)
            dataset_xsizes.append(xgrid.shape[1])
        else:
            xgrids = [i["xgrid"] for i in dataset_dict["fktables"]]
            model_inputs += xgrids
            dataset_xsizes.append(sum([i.shape[1] for i in xgrids]))

        model_obs_tr.append(obs_layer_tr)
        model_obs_vl.append(obs_layer_vl)
        model_obs_ex.append(obs_layer_ex)

    full_nx = sum(dataset_xsizes)
    if spec_dict["positivity"]:
        out_positivity = ObservableWrapper(
            spec_name,
            model_obs_tr,
            dataset_xsizes,
            multiplier=positivity_initial,
            positivity=not integrability,
            integrability=integrability,
        )

        layer_info = {
            "inputs": model_inputs,
            "output_tr": out_positivity,
            "experiment_xsize": full_nx,
        }
        return layer_info

    if spec_dict.get("data_transformation_tr") is not None:
        obsrot_tr = ObsRotation(spec_dict.get("data_transformation_tr"))
        obsrot_vl = ObsRotation(spec_dict.get("data_transformation_vl"))
    else:
        obsrot_tr = None
        obsrot_vl = None

    out_tr = ObservableWrapper(
        spec_name,
        model_obs_tr,
        dataset_xsizes,
        invcovmat=spec_dict["invcovmat"],
        data=spec_dict["expdata"],
        rotation=obsrot_tr,
    )
    out_vl = ObservableWrapper(
        f"{spec_name}_val",
        model_obs_vl,
        dataset_xsizes,
        invcovmat=spec_dict["invcovmat_vl"],
        data=spec_dict["expdata_vl"],
        rotation=obsrot_vl,
    )
    out_exp = ObservableWrapper(
        f"{spec_name}_exp",
        model_obs_ex,
        dataset_xsizes,
        invcovmat=spec_dict["invcovmat_true"],
        covmat=spec_dict["covmat"],
        data=spec_dict["expdata_true"],
        rotation=None,
    )

    layer_info = {
        "inputs": model_inputs,
        "output": out_exp,
        "output_tr": out_tr,
        "output_vl": out_vl,
        "experiment_xsize": full_nx,
    }
    return layer_info

What remains to be done now is to combine everything and construct a class to perform a fit.

In [None]:
from itertools import zip_longest

def _pdf_injection(pdf_layers, observables, masks):
    """Takes as input a list of PDF layers and if needed applies masks."""
    return [f(x, mask=m) for f, x, m in zip_longest(observables, pdf_layers, masks)]
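
The use of `zip_longest` in `_pdf_injection` matters when fewer masks than observables are provided: the missing entries are padded with `None`, so the corresponding observables are called with `mask=None`. The strings below are placeholders standing in for the actual layers.

```python
from itertools import zip_longest

# Placeholder names standing in for observable wrappers, PDF layers and masks.
observables = ["obs_a", "obs_b", "obs_c"]
pdf_layers = ["pdf_a", "pdf_b", "pdf_c"]
masks = ["mask_a"]  # fewer masks than observables

# Shorter iterables are padded with None.
triples = list(zip_longest(observables, pdf_layers, masks))
```

A plain `zip` would instead stop at the shortest input and silently drop the observables that have no mask.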

In [None]:
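# The penalty and reward functions are looked up by name in ModelTrainer
import n3fit.hyper_optimization.penalties
import n3fit.hyper_optimization.rewards

PUSH_POSITIVITY_EACH = 100
PUSH_INTEGRABILITY_EACH = 100
CHI2_THRESHOLD = 10.0
# Used by ModelTrainer when k-folding is enabled; the value is assumed to
# match the default defined in n3fit.model_trainer.
HYPER_THRESHOLD = 50.0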

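def _LM_initial_and_multiplier(input_initial, input_multiplier, max_lambda, steps):
    """Return the initial value and the per-step multiplier of a Lagrange
    multiplier such that it reaches `max_lambda` after `steps` updates.
    Whichever of the two inputs is None is derived from the other
    (the initial value defaults to 1.0 if both are None).
    """
    initial, multiplier = input_initial, input_multiplier
    if multiplier is None:
        if initial is None:
            initial = 1.0
        multiplier = pow(max_lambda / initial, 1 / max(steps, 1))
    elif initial is None:
        initial = max_lambda / pow(multiplier, steps)
    return initial, multiplier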

class ModelTrainer:

    def __init__(
        self,
        exp_info,
        pos_info,
        integ_info,
        flavinfo,
        fitbasis,
        nnseeds,
        pass_status="ok",
        failed_status="fail",
        debug=False,
        kfold_parameters=None,
        max_cores=None,
        model_file=None,
        sum_rules=None,
        parallel_models=1,
    ):
        
        # Save all input information
        self.exp_info = exp_info
        self.pos_info = pos_info
        self.integ_info = integ_info
        if self.integ_info is not None:
            self.all_info = exp_info + pos_info + integ_info
        else:
            self.all_info = exp_info + pos_info
        self.flavinfo = flavinfo
        self.fitbasis = fitbasis
        self._nn_seeds = nnseeds
        self.pass_status = pass_status
        self.failed_status = failed_status
        self.debug = debug
        self.all_datasets = []
        self._scaler = None
        self._parallel_models = parallel_models

        # Initialise internal variables which define behaviour
        if debug:
            self.max_cores = 1
        else:
            self.max_cores = max_cores
        self.model_file = model_file
        self.print_summary = True
        self.mode_hyperopt = False
        self.impose_sumrule = sum_rules
        self._hyperkeys = None
        if kfold_parameters is None:
            self.kpartitions = [None]
            self.hyper_threshold = None
        else:
            self.kpartitions = kfold_parameters["partitions"]
            self.hyper_threshold = kfold_parameters.get("threshold", HYPER_THRESHOLD)
            # if there are penalties enabled, set them up
            penalties = kfold_parameters.get("penalties", [])
            self.hyper_penalties = []
            for penalty in penalties:
                pen_fun = getattr(n3fit.hyper_optimization.penalties, penalty)
                self.hyper_penalties.append(pen_fun)
                log.info("Adding penalty: %s", penalty)
            # Check what is the hyperoptimization target function
            hyper_loss = kfold_parameters.get("target", None)
            if hyper_loss is None:
                hyper_loss = "average"
                log.warning("No minimization target selected, defaulting to '%s'", hyper_loss)
            log.info("Using '%s' as the target for hyperoptimization", hyper_loss)
            self._hyper_loss = getattr(n3fit.hyper_optimization.rewards, hyper_loss)

        # Initialize the dictionaries which contain all fitting information
        self.input_list = []
        self.input_sizes = []
        self.training = {
            "output": [],
            "expdata": [],
            "ndata": 0,
            "model": None,
            "posdatasets": [],
            "posmultipliers": [],
            "posinitials": [],
            "integdatasets": [],
            "integmultipliers": [],
            "integinitials": [],
            "folds": [],
        }
        self.validation = {
            "output": [],
            "expdata": [],
            "ndata": 0,
            "model": None,
            "folds": [],
            "posdatasets": [],
        }
        self.experimental = {
            "output": [],
            "expdata": [],
            "ndata": 0,
            "model": None,
            "folds": [],
        }

        self._fill_the_dictionaries()

        if self.validation["ndata"] == 0:
            # If there is no validation, the validation chi2 = training chi2
            self.no_validation = True
            self.validation["expdata"] = self.training["expdata"]
        else:
            # Consider the validation only if there is validation (of course)
            self.no_validation = False

        self.callbacks = []
        if debug:
            self.callbacks.append(callbacks.TimerCallback())

    def set_hyperopt(self, hyperopt_on, keys=None, status_ok="ok"):
        """Set hyperopt options on and off (mostly suppresses some printing)"""
        self.pass_status = status_ok
        if keys is None:
            keys = []
        self._hyperkeys = keys
        if hyperopt_on:
            self.print_summary = False
            self.mode_hyperopt = True
        else:
            self.print_summary = True
            self.mode_hyperopt = False

            
    def _fill_the_dictionaries(self):
        """
        This function fills the following dictionaries
            -``training``: data for the fit
            -``validation``: data used for the stopping
            -``experimental``: 'true' data, only used for reporting purposes
        with fixed information.

        Fixed information: information which will not change between different runs of the code.
        This information does not depend on the parameters of the fit at any stage
        and so it will remain unchanged between different runs of the hyperoptimizer.

        The aforementioned information corresponds to:
            - ``expdata``: experimental data
            - ``name``: names of the experiment
            - ``ndata``: number of experimental points
        """
        for exp_dict in self.exp_info:
            self.training["expdata"].append(exp_dict["expdata"])
            self.validation["expdata"].append(exp_dict["expdata_vl"])
            self.experimental["expdata"].append(exp_dict["expdata_true"])

            self.training["folds"].append(exp_dict["folds"]["training"])
            self.validation["folds"].append(exp_dict["folds"]["validation"])
            self.experimental["folds"].append(exp_dict["folds"]["experimental"])

            nd_tr = exp_dict["ndata"]
            nd_vl = exp_dict["ndata_vl"]

            self.training["ndata"] += nd_tr
            self.validation["ndata"] += nd_vl
            self.experimental["ndata"] += nd_tr + nd_vl

            for dataset in exp_dict["datasets"]:
                self.all_datasets.append(dataset["name"])
        self.all_datasets = set(self.all_datasets)

        for pos_dict in self.pos_info:
            self.training["expdata"].append(pos_dict["expdata"])
            self.training["posdatasets"].append(pos_dict["name"])
            self.validation["expdata"].append(pos_dict["expdata"])
            self.validation["posdatasets"].append(pos_dict["name"])
        if self.integ_info is not None:
            for integ_dict in self.integ_info:
                self.training["expdata"].append(integ_dict["expdata"])
                self.training["integdatasets"].append(integ_dict["name"])

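    def _model_generation(self, pdf_models, partition, partition_idx):
        """
        Build the ``training``, ``validation`` and ``experimental`` models.

        All three share the same PDF models and input x-grid; they differ only
        in the observables (and k-fold masks) that are attached to the PDF.
        Returns a dictionary with the three ``MetaModel`` objects.
        """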
        log.info("Generating the Model")
        input_arr = np.concatenate(self.input_list, axis=1).T
        if self._scaler:
            input_arr = self._scaler(input_arr)
        input_layer = op.numpy_to_input(input_arr)

        all_replicas_pdf = []
        for pdf_model in pdf_models:
            full_model_input_dict, full_pdf = pdf_model.apply_as_layer([input_layer])

            all_replicas_pdf.append(full_pdf)

        full_pdf_per_replica = op.stack(all_replicas_pdf, axis=-1)

        sp_ar = [self.input_sizes]
        sp_kw = {"axis": 1}
        splitting_layer = op.as_layer(op.split, op_args=sp_ar, op_kwargs=sp_kw, name="pdf_split")
        splitted_pdf = splitting_layer(full_pdf_per_replica)

        training_mask = validation_mask = experimental_mask = [None]
        if partition and partition["datasets"]:
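            # When overfitting the fold, training and validation keep the [None] masks
            # (they see all data); otherwise they use the mask of the current fold.
            # The experimental model is always restricted to the fold.
            if not partition.get("overfit", False):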
                training_mask = [i[partition_idx] for i in self.training["folds"]]
                validation_mask = [i[partition_idx] for i in self.validation["folds"]]
            experimental_mask = [i[partition_idx] for i in self.experimental["folds"]]
        output_tr = _pdf_injection(splitted_pdf, self.training["output"], training_mask)
        training = MetaModel(full_model_input_dict, output_tr)

        val_pdfs = []
        exp_pdfs = []
        for partial_pdf, obs in zip(splitted_pdf, self.training["output"]):
            if not obs.positivity and not obs.integrability:
                val_pdfs.append(partial_pdf)
                exp_pdfs.append(partial_pdf)
            elif not obs.integrability and obs.positivity:
                val_pdfs.append(partial_pdf)

        output_vl = _pdf_injection(val_pdfs, self.validation["output"], validation_mask)
        validation = MetaModel(full_model_input_dict, output_vl)

        output_ex = _pdf_injection(exp_pdfs, self.experimental["output"], experimental_mask)
        experimental = MetaModel(full_model_input_dict, output_ex)

        if self.print_summary:
            training.summary()

        models = {
            "training": training,
            "validation": validation,
            "experimental": experimental,
        }

        return models

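    def _reset_observables(self):
        """
        Empty the input lists and the ``output``, ``posmultipliers`` and
        ``integmultipliers`` entries of the three dictionaries, so that the
        observables can be regenerated from scratch for a new set of parameters.
        """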
        self.input_list = []
        self.input_sizes = []
        for key in ["output", "posmultipliers", "integmultipliers"]:
            self.training[key] = []
            self.validation[key] = []
            self.experimental[key] = []

    def _generate_observables(
        self,
        all_pos_multiplier,
        all_pos_initial,
        all_integ_multiplier,
        all_integ_initial,
        epochs,
        interpolation_points,
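    ):
        """
        Generate the observable layers for every experiment, positivity and
        integrability dataset and store them in the ``output`` entries of the
        dictionaries. If ``interpolation_points`` is set, also build the
        feature-scaling function ``self._scaler`` applied to the input x-grid.
        """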
        self._reset_observables()
        log.info("Generating layers")
        for exp_dict in self.exp_info:
            if not self.mode_hyperopt:
                log.info("Generating layers for experiment %s", exp_dict["name"])

            exp_layer = observable_generator(exp_dict)

            self.input_list += exp_layer["inputs"]
            self.input_sizes.append(exp_layer["experiment_xsize"])

            self.training["output"].append(exp_layer["output_tr"])
            self.validation["output"].append(exp_layer["output_vl"])
            self.experimental["output"].append(exp_layer["output"])

        for pos_dict in self.pos_info:
            if not self.mode_hyperopt:
                log.info("Generating positivity penalty for %s", pos_dict["name"])

            positivity_steps = int(epochs / PUSH_POSITIVITY_EACH)
            max_lambda = pos_dict["lambda"]

            pos_initial, pos_multiplier = _LM_initial_and_multiplier(
                all_pos_initial, all_pos_multiplier, max_lambda, positivity_steps
            )

            pos_layer = observable_generator(pos_dict, positivity_initial=pos_initial)
            self.input_list += pos_layer["inputs"]
            self.input_sizes.append(pos_layer["experiment_xsize"])

            self.training["output"].append(pos_layer["output_tr"])
            self.validation["output"].append(pos_layer["output_tr"])

            self.training["posmultipliers"].append(pos_multiplier)
            self.training["posinitials"].append(pos_initial)

        # Finally generate the integrability penalty
        if self.integ_info is not None:
            for integ_dict in self.integ_info:
                if not self.mode_hyperopt:
                    log.info("Generating integrability penalty for %s", integ_dict["name"])

                integrability_steps = int(epochs / PUSH_INTEGRABILITY_EACH)
                max_lambda = integ_dict["lambda"]

                integ_initial, integ_multiplier = _LM_initial_and_multiplier(
                    all_integ_initial, all_integ_multiplier, max_lambda, integrability_steps
                )

                integ_layer = observable_generator(
                    integ_dict, positivity_initial=integ_initial, integrability=True
                )
                self.input_list += integ_layer["inputs"]
                self.input_sizes.append(integ_layer["experiment_xsize"])
                self.training["output"].append(integ_layer["output_tr"])
                self.training["integmultipliers"].append(integ_multiplier)
                self.training["integinitials"].append(integ_initial)

        if interpolation_points:
            input_arr = np.concatenate(self.input_list, axis=1)
            input_arr = np.sort(input_arr)
            input_arr_size = input_arr.size

            force_set_smallest = input_arr.min() > 1e-9
            if force_set_smallest:
                new_xgrid = np.linspace(
                    start=1/input_arr_size, stop=1.0, endpoint=False, num=input_arr_size
                )
            else:
                new_xgrid = np.linspace(start=0, stop=1.0, endpoint=False, num=input_arr_size)

            unique, counts = np.unique(input_arr, return_counts=True)
            map_to_complete = []
            for cumsum_ in np.cumsum(counts):
                map_to_complete.append(new_xgrid[cumsum_ - counts[0]])
            map_to_complete = np.array(map_to_complete)
            map_from_complete = unique

            if force_set_smallest:
                map_from_complete = np.insert(map_from_complete, 0, 1e-9)
                map_to_complete = np.insert(map_to_complete, 0, 0.0)

            onein = map_from_complete.size / (int(interpolation_points) - 1)
            selected_points = [round(i * onein - 1) for i in range(1, int(interpolation_points))]
            if selected_points[0] != 0:
                selected_points = [0] + selected_points
            map_from = map_from_complete[selected_points]
            map_from = np.log(map_from)
            map_to = map_to_complete[selected_points]

            try:
                scaler = PchipInterpolator(map_from, map_to)
            except ValueError as err:
                # PchipInterpolator requires strictly increasing nodes
                raise ValueError(
                    "interpolation_points is larger than the number of unique input x-values"
                ) from err
            self._scaler = lambda x: np.concatenate([scaler(np.log(x)), x], axis=-1)

    def _generate_pdf(
        self,
        nodes_per_layer,
        activation_per_layer,
        initializer,
        layer_type,
        dropout,
        regularizer,
        regularizer_args,
        seed,
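    ):
        """
        Build the PDF models (one per replica) by forwarding the
        hyperparameters to ``pdfNN_layer_generator``.
        """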
        log.info("Generating PDF models")

        pdf_models = pdfNN_layer_generator(
            nodes=nodes_per_layer,
            activations=activation_per_layer,
            layer_type=layer_type,
            flav_info=self.flavinfo,
            fitbasis=self.fitbasis,
            seed=seed,
            initializer_name=initializer,
            dropout=dropout,
            regularizer=regularizer,
            regularizer_args=regularizer_args,
            impose_sumrule=self.impose_sumrule,
            scaler=self._scaler,
            parallel_models=self._parallel_models,
        )
        return pdf_models

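    def _prepare_reporting(self, partition):
        """
        Collect, for every entry of ``self.all_info``, the keys that the stopping
        object needs in order to report the chi2. When a k-fold partition is given,
        the number of points is corrected for the datasets that belong to the fold.
        """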
        reported_keys = ["name", "count_chi2", "positivity", "integrability", "ndata", "ndata_vl"]
        reporting_list = []
        for exp_dict in self.all_info:
            reporting_dict = {k: exp_dict.get(k) for k in reported_keys}
            if partition:
                for dataset in exp_dict["datasets"]:
                    if dataset in partition["datasets"]:
                        ndata = dataset["ndata"]
                        frac = dataset["frac"]
                        reporting_dict["ndata"] -= int(ndata * frac)
                        reporting_dict["ndata_vl"] = int(ndata * (1 - frac))
            reporting_list.append(reporting_dict)
        return reporting_list

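    def _train_and_fit(self, training_model, stopping_object, epochs=100):
        """
        Train the model for ``epochs`` epochs with the stopping and
        Lagrange-multiplier callbacks attached.

        Returns ``self.pass_status`` if the stopping object recorded a best epoch
        for at least one replica, and ``self.failed_status`` otherwise.
        """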
        callback_st = callbacks.StoppingCallback(stopping_object)
        callback_pos = callbacks.LagrangeCallback(
            self.training["posdatasets"],
            self.training["posmultipliers"],
            update_freq=PUSH_POSITIVITY_EACH,
        )
        callback_integ = callbacks.LagrangeCallback(
            self.training["integdatasets"],
            self.training["integmultipliers"],
            update_freq=PUSH_INTEGRABILITY_EACH,
        )

        training_model.perform_fit(
            epochs=epochs,
            verbose=False,
            callbacks=self.callbacks + [callback_st, callback_pos, callback_integ],
        )

        if any(bool(i) for i in stopping_object.e_best_chi2):
            return self.pass_status
        return self.failed_status

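    def _hyperopt_override(self, params):
        """
        Unroll the hyperopt entries of ``params``: if a ``parameters`` key is
        present its value is returned directly; otherwise every dictionary-valued
        hyperkey is merged into the top level of ``params``.
        """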
        hyperparameters = params.get("parameters")
        if hyperparameters is not None:
            return hyperparameters
        for hyperkey in self._hyperkeys:
            item = params[hyperkey]
            if isinstance(item, dict):
                params.update(item)
        return params

    def enable_tensorboard(self, logdir, weight_freq=0, profiling=False):
        """Register a TensorBoard callback that writes its logs to ``logdir``."""
        callback_tb = callbacks.gen_tensorboard_callback(
            logdir, profiling=profiling, histogram_freq=weight_freq
        )
        self.callbacks.append(callback_tb)

    def evaluate(self, stopping_object):
        """Return the training, validation and experimental chi2 of the last fit."""
        if self.training["model"] is None:
            raise RuntimeError("Modeltrainer.evaluate was called before any training")
        train_chi2 = stopping_object.evaluate_training(self.training["model"])
        val_chi2 = stopping_object.vl_chi2
        exp_chi2 = self.experimental["model"].compute_losses()["loss"] / self.experimental["ndata"]
        return train_chi2, val_chi2, exp_chi2

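    def hyperparametrizable(self, params):
        """
        Build the observables and the PDF models from ``params`` and run the fit,
        once per k-fold partition.

        In hyperopt mode the returned dictionary contains the losses of the scan;
        otherwise it contains the fit status, the stopping object and the PDF models.
        """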
        clear_backend_state()

        if self.mode_hyperopt:
            log.info("Performing hyperparameter scan")
            for key in self._hyperkeys:
                log.info(" > > Testing %s = %s", key, params[key])
            params = self._hyperopt_override(params)

        epochs = int(params["epochs"])
        stopping_patience = params["stopping_patience"]
        stopping_epochs = int(epochs * stopping_patience)

        positivity_dict = params.get("positivity", {})
        integrability_dict = params.get("integrability", {})
        self._generate_observables(
            positivity_dict.get("multiplier"),
            positivity_dict.get("initial"),
            integrability_dict.get("multiplier"),
            integrability_dict.get("initial"),
            epochs,
            params.get("interpolation_points"),
        )
        threshold_pos = positivity_dict.get("threshold", 1e-6)
        threshold_chi2 = params.get("threshold_chi2", CHI2_THRESHOLD)

        # Initialize the lists that collect the chi2 of each fold
        l_valid = []
        l_exper = []
        l_hyper = []
        # And lists to save hyperopt utilities
        n3pdfs = []
        exp_models = []

        ### Training loop
        for k, partition in enumerate(self.kpartitions):
            seeds = self._nn_seeds
            if k > 0:
                seeds = [np.random.randint(0, pow(2, 31)) for _ in seeds]

            # Generate the pdf model
            pdf_models = self._generate_pdf(
                params["nodes_per_layer"],
                params["activation_per_layer"],
                params["initializer"],
                params["layer_type"],
                params["dropout"],
                params.get("regularizer", None),
                params.get("regularizer_args", None),
                seeds,
            )

            models = self._model_generation(pdf_models, partition, k)

            if self.model_file:
                log.info("Applying model file %s", self.model_file)
                for pdf_model in pdf_models:
                    pdf_model.load_weights(self.model_file)

            if k > 0:
                pos_and_int = self.training["posdatasets"] + self.training["integdatasets"]
                initial_values = self.training["posinitials"] + self.training["integinitials"]
                models["training"].reset_layer_weights_to(pos_and_int, initial_values)

            # Generate the list containing reporting info necessary for chi2
            reporting = self._prepare_reporting(partition)

            if self.no_validation:
                # Substitute the validation model with the training model
                models["validation"] = models["training"]
                validation_model = models["training"]
            else:
                validation_model = models["validation"]

            stopping_object = Stopping(
                validation_model,
                reporting,
                pdf_models,
                total_epochs=epochs,
                stopping_patience=stopping_epochs,
                threshold_positivity=threshold_pos,
                threshold_chi2=threshold_chi2,
            )

            for model in models.values():
                model.compile(**params["optimizer"])

            passed = self._train_and_fit(
                models["training"],
                stopping_object,
                epochs=epochs,
            )

            if self.mode_hyperopt:
                validation_loss = np.mean(stopping_object.vl_chi2)

                # Compute experimental loss
                exp_loss_raw = np.average(models["experimental"].compute_losses()["loss"])
                ndata = np.sum([np.count_nonzero(i[k]) for i in self.experimental["folds"]])
                if ndata == 0:
                    ndata = self.experimental["ndata"]
                experimental_loss = exp_loss_raw / ndata

                hyper_loss = experimental_loss
                if passed != self.pass_status:
                    log.info("Hyperparameter combination failed to find a good fit, breaking")
                    break
                for penalty in self.hyper_penalties:
                    hyper_loss += penalty(pdf_models=pdf_models, stopping_object=stopping_object)
                log.info("Fold %d finished, loss=%.1f, pass=%s", k + 1, hyper_loss, passed)

                l_hyper.append(hyper_loss)
                l_valid.append(validation_loss)
                l_exper.append(experimental_loss)
                n3pdfs.append(N3PDF(pdf_models, name=f"fold_{k}"))
                exp_models.append(models["experimental"])

                if hyper_loss > self.hyper_threshold:
                    log.info(
                        "Loss above threshold (%.1f > %.1f), breaking",
                        hyper_loss,
                        self.hyper_threshold,
                    )
                    pen_mul = len(self.kpartitions) - k
                    l_hyper = [i * pen_mul for i in l_hyper]
                    break

        if self.mode_hyperopt:
            dict_out = {
                "status": passed,
                "loss": self._hyper_loss(fold_losses=l_hyper, n3pdfs=n3pdfs, experimental_models=exp_models),
                "validation_loss": np.average(l_valid),
                "experimental_loss": np.average(l_exper),
                "kfold_meta": {
                    "validation_losses": l_valid,
                    "experimental_losses": l_exper,
                    "hyper_losses": l_hyper,
                },
            }
            return dict_out

        self.training["model"] = models["training"]
        self.experimental["model"] = models["experimental"]
        self.validation["model"] = models["validation"]
        dict_out = {"status": passed, "stopping_object": stopping_object, "pdf_models": pdf_models}
        return dict_out
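
The feature scaling built at the end of `_generate_observables` maps the input x-values onto a roughly uniform grid in $[0, 1)$ before they reach the network. A minimal standalone sketch of that idea is given below, under simplifying assumptions that differ from the class above: it uses piecewise-linear interpolation in $\log x$ through every unique point (instead of a `PchipInterpolator` through `interpolation_points` selected nodes), and it places each node at the fraction of points strictly below it rather than reproducing the exact index convention of the method. The name `build_cdf_scaler` is hypothetical.

```python
import bisect
import math
from collections import Counter


def build_cdf_scaler(x_values):
    """Return a function mapping x to an approximate empirical-CDF value.

    Simplified stand-in for the scaler of ``_generate_observables``:
    linear interpolation in log(x) through all unique input points.
    """
    counts = Counter(x_values)
    unique = sorted(counts)
    total = len(x_values)

    # Node i sits at the fraction of input points strictly below unique[i]
    nodes_y = []
    seen = 0
    for x in unique:
        nodes_y.append(seen / total)
        seen += counts[x]
    nodes_x = [math.log(x) for x in unique]

    def scaler(x):
        lx = math.log(x)
        # Clamp outside the range covered by the input grid
        if lx <= nodes_x[0]:
            return nodes_y[0]
        if lx >= nodes_x[-1]:
            return nodes_y[-1]
        hi = bisect.bisect_right(nodes_x, lx)
        lo = hi - 1
        t = (lx - nodes_x[lo]) / (nodes_x[hi] - nodes_x[lo])
        return nodes_y[lo] + t * (nodes_y[hi] - nodes_y[lo])

    return scaler


# Toy x-grid with repeated points, as happens when datasets share kinematics
xgrid = [1e-5, 1e-4, 1e-4, 1e-3, 1e-2, 1e-1, 1e-1, 0.5]
scaler = build_cdf_scaler(xgrid)
scaled = [scaler(x) for x in xgrid]
```

In the class above the scaled value is not used alone: `self._scaler` concatenates it with the original `x`, so the network receives both.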

In [None]:
def _summarizes_values(modeltrain, value_name):
    """Print a summary table for one of the ``ModelTrainer`` dictionaries."""
    
    value = getattr(modeltrain, value_name)
    table = Table(show_header=True, header_style="bold magenta")
    table.add_column(f"Keys ({value_name})", justify="left", width=24)
    table.add_column("Description", justify="left", width=24)
    table.add_column("Value", justify="left", width=24)
    
    table.add_row("model", "Model", f"{value['model'] if value['model'] is not None else 'None'}")
    table.add_row("expdata", "Nb. Experiments", f"{len(value['expdata'])}")
    table.add_row("output", "Nb. Outputs", f"{len(value['output'])}")
    table.add_row("ndata", "Nb. Datapoints", f"{value['ndata']}")
    
    console.print(table)

In [None]:
params = {
    'nodes_per_layer': [15, 10, 8],
    'activation_per_layer': ['sigmoid', 'sigmoid', 'linear'],
    'initializer': 'glorot_normal',
    'optimizer': {'optimizer_name': 'RMSprop', 'learning_rate': 0.01, 'clipnorm': 1.0},
    'epochs': 900,
    'positivity': {'multiplier': 1.05, 'initial': None, 'threshold': 1e-05},
    'stopping_patience': 0.3,
    'layer_type': 'dense',
    'dropout': 0.0,
    'threshold_chi2': 5.0,
}
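
As a quick sanity check, the quantities that `hyperparametrizable` and `_generate_observables` derive from these parameters can be reproduced by hand. The sketch below assumes `PUSH_POSITIVITY_EACH = 100` (the constant is defined outside this section, so the value is an assumption) and that the Lagrange callback multiplies the positivity multiplier by `multiplier` at every push.

```python
# Subset of the parameters defined above, copied so the snippet runs on its own
toy_params = {
    "epochs": 900,
    "stopping_patience": 0.3,
    "positivity": {"multiplier": 1.05, "initial": None, "threshold": 1e-05},
}

# Assumed value of the module-level constant used by the trainer
ASSUMED_PUSH_POSITIVITY_EACH = 100

epochs = int(toy_params["epochs"])

# Patience is given as a fraction of the total number of epochs
stopping_epochs = int(epochs * toy_params["stopping_patience"])

# Number of times the positivity multiplier is updated during the fit
positivity_steps = int(epochs / ASSUMED_PUSH_POSITIVITY_EACH)

# Overall growth of the multiplier, assuming one multiplication per push
total_growth = toy_params["positivity"]["multiplier"] ** positivity_steps

print(f"stopping patience: {stopping_epochs} epochs")
print(f"positivity pushes: {positivity_steps}")
print(f"multiplier growth: {total_growth:.3f}")
```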

In [None]:
flav_info = [
    {'fl': 'sng', 'trainable': False, 'smallx': [1.094, 1.118], 'largex': [1.46, 3.003]}, 
    {'fl': 'g', 'trainable': False, 'smallx': [0.8189, 1.044], 'largex': [2.791, 5.697]}, 
    {'fl': 'v', 'trainable': False, 'smallx': [0.457, 0.7326], 'largex': [1.56, 3.431]}, 
    {'fl': 'v3', 'trainable': False, 'smallx': [0.1462, 0.4061], 'largex': [1.745, 3.452]}, 
    {'fl': 'v8', 'trainable': False, 'smallx': [0.5401, 0.7665], 'largex': [1.539, 3.393]}, 
    {'fl': 't3', 'trainable': False, 'smallx': [-0.4401, 0.9163], 'largex': [1.773, 3.333]}, 
    {'fl': 't8', 'trainable': False, 'smallx': [0.5852, 0.8537], 'largex': [1.533, 3.436]}, 
    {'fl': 't15', 'trainable': False, 'smallx': [1.082, 1.142], 'largex': [1.461, 3.1]}
]

nnseed = [1872583848]
fitbasis = 'EVOL'
debug = False
max_cores = 8
model_file = None
sum_rules = False
parallel_models = 1
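
The `smallx` and `largex` entries of `flav_info` can be read as ranges for the preprocessing exponents of each flavour. The sketch below assumes the usual NNPDF-style prefactor $x^{1-\alpha}(1-x)^{\beta}$ and evaluates it at the midpoint of each range purely for illustration; the helper names are hypothetical, and how the fit actually picks the exponents within the ranges is not shown in this section.

```python
def midpoint(bounds):
    """Midpoint of a [low, high] range (illustrative choice of exponent)."""
    low, high = bounds
    return 0.5 * (low + high)


def preprocessing_factor(x, alpha, beta):
    """Assumed prefactor x^(1 - alpha) * (1 - x)^beta."""
    return x ** (1.0 - alpha) * (1.0 - x) ** beta


# Gluon entry copied from ``flav_info`` above
gluon = {"fl": "g", "smallx": [0.8189, 1.044], "largex": [2.791, 5.697]}

alpha = midpoint(gluon["smallx"])
beta = midpoint(gluon["largex"])

x_points = (1e-3, 0.1, 0.5, 0.9)
values = [preprocessing_factor(x, alpha, beta) for x in x_points]

for x, value in zip(x_points, values):
    print(f"x = {x:7.3f}  prefactor = {value:.4e}")
```

The large-x exponent drives the prefactor to zero as $x \to 1$, which is the behaviour the `largex` range is meant to control.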

In [None]:
ModelTraining = ModelTrainer(
    toyexpinfo,
    toyposdatasets,
    toyintegdatasets,
    flav_info,
    fitbasis,
    nnseed,
    debug=debug,
    kfold_parameters=None,
    max_cores=max_cores,
    sum_rules=sum_rules,
    parallel_models=1,
)

pdf_gen_and_train_function = ModelTraining.hyperparametrizable

ModelTraining.set_hyperopt(False)

pdf_gen_and_train_function(params)