In [1]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.models import Sequential
import numpy as np
from lhapdf import mkPDF, setVerbosity
from n3fit.model_gen import pdfNN_layer_generator
from validphys.api import API
from n3fit.layers.rotations import FkRotation

tf.keras.backend.clear_session()
setVerbosity(0)

2024-01-29 17:08:22.157862: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2024-01-29 17:08:22.197731: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-01-29 17:08:22.197762: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-01-29 17:08:22.198950: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-01-29 17:08:22.205480: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2024-01-29 17:08:22.205927: I tensorflow/core/platform/cpu_feature_guard.cc:1

Instructions for updating:
experimental_relax_shapes is deprecated, use reduce_retracing instead
Using Keras backend


In [2]:
pdf_set = "NNPDF40_nnlo_as_01180"
pdf_target = mkPDF(pdf_set)

def pid_to_latex(pid):
    """Function to translate the PID number into latex values, useful for plotting"""
    translate = {21: "g", 1: "d", 2: "u", 3: "s", 4: "c", 5: "b", 6: "t"}
    flav = translate[abs(pid)]
    if pid < 0:
        flav = rf"\bar{{{flav}}}"
    return flav

In [3]:
q0 = 1.65  # Reference scale
npoints = int(5e4)  # How many points to use for training
xgrid = np.concatenate([np.logspace(-5, -1, npoints // 2), np.linspace(0.1, 1, npoints // 2)])[::200]
pdf_grid_all = pdf_target.xfxQ2(xgrid, np.ones_like(xgrid) * q0**2)

In [4]:
fit_info = API.fit(fit="NNPDF40_nnlo_as_01180_1000").as_input()
basis_info = fit_info["fitting"]["basis"]

pdf_model = pdfNN_layer_generator(
    nodes=[25,20, 8],
    activations=['tanh','tanh','linear'],
    initializer_name="glorot_normal",
    layer_type="dense",
    flav_info=basis_info,
    fitbasis="EVOL",
    out=14,
    seed=np.random.randint(0, pow(2, 31)),
    dropout=0.0,
    regularizer=None,
    regularizer_args=None,
    impose_sumrule=False, # NOTE: imposing sumrules will break Gaussianity to a large extent
    scaler=None,
    num_replicas = 1,
    photons=None,
    replica_axis=True,
)

lossfn = tf.keras.losses.MeanSquaredError()
pdf_model.compile(optimizer_name='Nadam', learning_rate=2.621e-3, clipnorm=6.073e-6, loss=lossfn)


In [5]:
# rotate from flavor basis to evolution basis (9 flavor)
def flav_to_evol(flav_vector):
    cbar = flav_vector[0]
    sbar = flav_vector[1]
    ubar = flav_vector[2]
    dbar = flav_vector[3]
    gluon = flav_vector[4]
    d = flav_vector[5]
    u = flav_vector[6]
    s = flav_vector[7]
    c = flav_vector[8]

    cp = 2*c
    sigma = u + ubar + d + dbar + s + sbar + cp
    v = u - ubar + d - dbar + s - sbar + c - cbar
    v3 = u - ubar - d + dbar
    v8 = u - ubar + d - dbar - 2*s + 2*sbar
    t3 = u + ubar - d - dbar
    t8 = u + ubar + d + dbar - 2*s - 2*sbar
    g = gluon
    v15 = u-ubar+d-dbar+s-sbar-3*c+3*cbar

    evol_vector = np.array([sigma, g, v, v3, v8, t3, t8, cp ,v15])
    return evol_vector

In [6]:
# Create training output data in the 14 flavour FK table basis

output_basis = [-4, -3, -2, -1, 21, 1, 2, 3, 4]
output_size = len(output_basis)

output_data = np.zeros((len(pdf_grid_all), output_size))

for i, pdf_grid in enumerate(pdf_grid_all):
    for j, pid in enumerate(output_basis):
        output_data[i, j] = pdf_grid[pid]

output_data = np.array(output_data)

training_data = flav_to_evol(output_data.T)
training_data = training_data.T@FkRotation()._create_rotation_matrix() # 9 flav fitting basis to 14 flav fk basis

In [7]:
def compute_ntk(model, input):
    grad = []
    for x in tf.convert_to_tensor(input):
        with tf.GradientTape() as tape:
            # x = tf.reshape(x, shape=(-1,1))
            tape.watch(x)
            pred = model(x)

        # compute gradients df(x)/dtheta
        g = tape.gradient(pred, model.trainable_variables)
        # concatenate the gradients of all trainable variables,
        # not discriminating between weights and biases
        g = tf.concat([tf.reshape(i, shape=(-1,1)) for i in g], axis=0)
        grad.append(g)

    grad = tf.concat(grad,axis=1)
    ntk = tf.einsum('ij,ik->jk',grad,grad)
    return ntk


In [8]:
inputgrid = np.expand_dims(xgrid,axis=(0,2))
traininggrid = np.expand_dims(training_data,axis=(0,1))
ntks = []
for epochs in 10*[2]:
    pdf_model.fit(inputgrid, traininggrid, epochs=epochs)
    ntks.append(compute_ntk(pdf_model,np.swapaxes(inputgrid, axis1=1, axis2=2)))

Epoch 1/2
Epoch 2/2
Epoch 1/2
Epoch 2/2
Epoch 1/2
Epoch 2/2
Epoch 1/2
Epoch 2/2
Epoch 1/2
Epoch 2/2
Epoch 1/2
Epoch 2/2
Epoch 1/2
Epoch 2/2
Epoch 1/2
Epoch 2/2
Epoch 1/2
Epoch 2/2
Epoch 1/2
Epoch 2/2


In [9]:
rel_change = []
for ntk in ntks[1:]:
    rel_change.append((tf.norm(ntks[0] - ntk) / tf.norm(ntks[0])).numpy())

In [10]:
rel_change

[0.022952486,
 0.052123144,
 0.08709725,
 0.115537934,
 0.1352531,
 0.15246178,
 0.16642924,
 0.17687486,
 0.18540986]

In [11]:
tf.norm(ntks[0])

<tf.Tensor: shape=(), dtype=float32, numpy=27628380.0>

In [12]:
[tf.norm(ntk) for ntk in ntks]

[<tf.Tensor: shape=(), dtype=float32, numpy=27628380.0>,
 <tf.Tensor: shape=(), dtype=float32, numpy=26994240.0>,
 <tf.Tensor: shape=(), dtype=float32, numpy=26188302.0>,
 <tf.Tensor: shape=(), dtype=float32, numpy=25222024.0>,
 <tf.Tensor: shape=(), dtype=float32, numpy=24436254.0>,
 <tf.Tensor: shape=(), dtype=float32, numpy=23891556.0>,
 <tf.Tensor: shape=(), dtype=float32, numpy=23416108.0>,
 <tf.Tensor: shape=(), dtype=float32, numpy=23030210.0>,
 <tf.Tensor: shape=(), dtype=float32, numpy=22741614.0>,
 <tf.Tensor: shape=(), dtype=float32, numpy=22505806.0>]