In [None]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.models import Sequential
import numpy as np
from lhapdf import mkPDF, setVerbosity
from n3fit.model_gen import pdfNN_layer_generator
from validphys.api import API
from n3fit.layers.rotations import FkRotation
from matplotlib import pyplot as plt

# Reset Keras' global state so that re-running this cell starts from a clean backend session.
tf.keras.backend.clear_session()
# Set LHAPDF's verbosity level to 0 (presumably this silences its loading messages -- confirm).
setVerbosity(0)

In [None]:
# Name of the LHAPDF set that is loaded below.
pdf_set = "NNPDF40_nnlo_as_01180"
# Load the set through LHAPDF's mkPDF, passing the set name only
# (presumably this selects the central member -- TODO confirm).
pdf_target = mkPDF(pdf_set)

def pid_to_latex(pid):
    """Return the LaTeX symbol of a parton given its PID number (useful for plot labels).

    The absolute value of ``pid`` selects the symbol (1-6 for the quarks
    d, u, s, c, b, t and 21 for the gluon); a negative ``pid`` wraps that
    symbol in ``\\bar{...}``.

    Raises
    ------
    KeyError
        If ``abs(pid)`` is not one of the known ids.
    """
    symbols = {1: "d", 2: "u", 3: "s", 4: "c", 5: "b", 6: "t", 21: "g"}
    symbol = symbols[abs(pid)]
    # Non-negative ids are returned as-is, negative ones get the bar
    return symbol if pid >= 0 else "\\bar{" + symbol + "}"

In [None]:
# Reference scale; its square is what gets passed to xfxQ2 below
q0 = 1.65
# Size of the dense x-grid *before* thinning (half log-spaced, half linearly spaced)
npoints = int(5e4)
# Log-spaced points between 1e-5 and 0.1 followed by linearly spaced points between 0.1 and 1;
# only every 200th point of the concatenated grid is kept.
xgrid = np.concatenate(
    (
        np.logspace(-5, -1, npoints // 2),
        np.linspace(0.1, 1, npoints // 2),
    )
)[::200]
# Evaluate the target PDF on the grid, with the same Q^2 = q0^2 for every x
pdf_grid_all = pdf_target.xfxQ2(xgrid, q0**2 * np.ones_like(xgrid))

In [None]:
# Load the configuration of an existing fit through the validphys API and
# take the flavour/basis settings of its "fitting" section.
fit_info = API.fit(fit="NNPDF40_nnlo_as_01180_1000").as_input()
basis_info = fit_info["fitting"]["basis"]

# Build the PDF network with n3fit's layer generator.
# NOTE(review): the seed is drawn at random on every execution of this cell, so the
# initialisation (and everything downstream) changes between runs -- consider a fixed seed.
pdf_model = pdfNN_layer_generator(
    nodes=[25,20, 8],
    activations=['tanh','tanh','linear'],
    initializer_name="glorot_normal",
    layer_type="dense",
    flav_info=basis_info,
    fitbasis="EVOL",
    out=14,
    seed=np.random.randint(0, pow(2, 31)),
    dropout=0.0,
    regularizer=None,
    regularizer_args=None,
    impose_sumrule=False, # NOTE: imposing sumrules will break Gaussianity to a large extent
    scaler=None,
    num_replicas = 1,
    photons=None,
    replica_axis=True,
)

# Mean squared error between the network output and the training targets.
lossfn = tf.keras.losses.MeanSquaredError()
# NOTE(review): `clipnorm` is assigned here but `clipnorm=None` is what is passed to
# compile() below, so this value is unused in this cell -- confirm that gradient
# clipping is intentionally disabled.
clipnorm = 6.073e-6
pdf_model.compile(optimizer_name='Nadam', learning_rate=2.621e-3, clipnorm=None, loss=lossfn)


In [None]:
# rotate from flavor basis to evolution basis (9 flavor)
def flav_to_evol(flav_vector):
    """Rotate PDFs from the flavour basis to the 9-component evolution basis.

    Parameters
    ----------
    flav_vector : sequence or np.ndarray
        Flavour-basis values indexed along the first axis in the order
        ``(cbar, sbar, ubar, dbar, gluon, d, u, s, c)``. Each entry may be a
        scalar or an array (e.g. one value per x point).

    Returns
    -------
    np.ndarray
        Array whose first axis is ordered as
        ``(sigma, g, v, v3, v8, t3, t8, cp, v15)``, with ``cp = c + cbar``.
    """
    cbar, sbar, ubar, dbar, gluon, d, u, s, c = (flav_vector[k] for k in range(9))

    # c+ is the sum of charm and anti-charm. The previous `2*c` silently assumed
    # c == cbar and ignored the cbar entry that is used for v and v15 below.
    cp = c + cbar
    sigma = u + ubar + d + dbar + s + sbar + cp
    v = u - ubar + d - dbar + s - sbar + c - cbar
    v3 = u - ubar - d + dbar
    v8 = u - ubar + d - dbar - 2 * s + 2 * sbar
    t3 = u + ubar - d - dbar
    t8 = u + ubar + d + dbar - 2 * s - 2 * sbar
    g = gluon
    v15 = u - ubar + d - dbar + s - sbar - 3 * c + 3 * cbar

    return np.array([sigma, g, v, v3, v8, t3, t8, cp, v15])

In [None]:
# Build the training targets in the 14 flavour FK table basis

# PIDs in the order in which flav_to_evol reads its first axis
output_basis = [-4, -3, -2, -1, 21, 1, 2, 3, 4]
output_size = len(output_basis)

# One row per x point, one column per entry of output_basis
output_data = np.array(
    [[pdf_grid[pid] for pid in output_basis] for pdf_grid in pdf_grid_all],
    dtype=float,
).reshape(len(pdf_grid_all), output_size)

# flav_to_evol works with the flavour index first, hence the transposes; the result
# is then rotated from the 9 flav fitting basis to the 14 flav fk basis.
training_data = flav_to_evol(output_data.T).T @ FkRotation()._create_rotation_matrix()

In [None]:
def compute_ntk(model, input):
    """Compute the empirical neural tangent kernel of ``model`` on ``input``.

    For every point along axis 1 of ``input`` the gradient of the model output
    with respect to all trainable variables is computed and flattened into a
    single column (weights and biases are not distinguished). The kernel is the
    matrix of inner products between these columns.

    Note that ``tape.gradient`` differentiates the *sum* of the elements of a
    non-scalar target, so if the model returns several values per point the
    columns hold the gradient of their sum.

    Parameters
    ----------
    model : callable with a ``trainable_variables`` attribute
        The network to differentiate.
    input : array-like of rank 3
        Input grid; the points are enumerated along axis 1 and each one is fed
        to the model as ``input[:, k:k+1, :]``.

    Returns
    -------
    tf.Tensor
        Square tensor of shape ``(n_points, n_points)`` with
        ``n_points = input.shape[1]``.
    """
    # Iterate over the point axis explicitly. The total element count only
    # coincides with it when the other two axes have size 1.
    n_points = input.shape[1]

    columns = []
    for data_index in range(n_points):
        # A slice (rather than list indexing) keeps the axis and works for both
        # numpy arrays and tensors.
        x = tf.convert_to_tensor(input[:, data_index:data_index + 1, :])
        # Trainable variables are watched automatically; the input itself does
        # not need to be watched because no gradient w.r.t. x is taken.
        with tf.GradientTape() as tape:
            pred = model(x)

        # compute gradients df(x)/dtheta; variables that do not enter the
        # output contribute zeros instead of None
        var_grads = tape.gradient(
            pred,
            model.trainable_variables,
            unconnected_gradients=tf.UnconnectedGradients.ZERO,
        )
        # concatenate the gradients of all trainable variables into one column
        columns.append(
            tf.concat([tf.reshape(vg, shape=(-1, 1)) for vg in var_grads], axis=0)
        )

    # grad has shape (n_parameters, n_points)
    grad = tf.concat(columns, axis=1)
    ntk = tf.einsum('ij,ik->jk', grad, grad)
    return ntk


In [None]:
# Add the extra axes used when feeding the model: the 1-d x grid gets a leading
# and a trailing axis, the 2-d training targets get two leading axes.
inputgrid = xgrid[np.newaxis, :, np.newaxis]
traininggrid = training_data[np.newaxis, np.newaxis, ...]

# Train in stages and store the NTK after each stage. With a single stage the
# list ends up with exactly one entry.
ntks = []
for epochs in [1000]:
    pdf_model.fit(inputgrid, traininggrid, epochs=epochs)
    ntks.append(compute_ntk(pdf_model, inputgrid))

In [None]:
# Norm of the difference between each later NTK and the first stored one,
# relative to the norm of the first. Empty when only one NTK was stored.
rel_change = [
    (tf.norm(ntks[0] - ntk) / tf.norm(ntks[0])).numpy()
    for ntk in ntks[1:]
]

In [None]:
# Display the list of relative NTK changes (one entry per stored NTK after the first)
rel_change

In [None]:
def compare_pdfs(xgrid, traininggrid, pdf_model):
    """Plot the model prediction against the training data for eight output columns.

    Draws a 4x2 grid of panels sharing a logarithmic x axis; each panel shows
    one column of the model output ("NN") and of the training data ("NNPDF4.0").

    Parameters
    ----------
    xgrid : array of rank 3
        Model input; the x values are read from ``xgrid[0, :, 0]``.
    traininggrid : array
        Training targets; ``traininggrid[0, 0]`` is compared with the prediction.
    pdf_model : callable
        Called as ``pdf_model(xgrid)``; the result is indexed with ``[0, 0]``
        (presumably the batch and replica axes -- confirm).
    """
    preds = pdf_model(xgrid)[0, 0]
    training_data = traininggrid[0, 0, :]
    x = xgrid[0, :, 0]

    # Draw on the axes returned by plt.subplots directly instead of re-selecting
    # them through plt.subplot, so the panels are guaranteed to be the ones
    # created with sharex=True.
    fig, ax = plt.subplots(4, 2, figsize=(12, 6), sharex=True)

    flavs = ["sigma", 'g', 'v', 'v3', 'v8', 't3', 't8', 't15']
    # Output columns matching the labels above: sigma, g, v, v3, v8, t3, t8, t15
    columns = [1, 2, 3, 4, 5, 9, 10, 11]
    for panel, label, fl in zip(ax.flat, flavs, columns):
        panel.set_ylabel(fr'$x${label}')
        panel.plot(x, preds[:, fl], label="NN")
        panel.plot(x, training_data[:, fl], label="NNPDF4.0")
        panel.legend()
        panel.set_xscale('log')

In [None]:
# Plot the trained network against the training data on the training grid
compare_pdfs(inputgrid, traininggrid, pdf_model)