In [1]:
from validphys.api import API
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from validphys.fkparser import load_fktable
from collections import defaultdict

In [2]:
seed = 1341351341

In [291]:
def generate_sequential_model(outputs=1, 
                   input_layer=None, 
                   nlayers=2, 
                   units=[100,100],
                   seed=seed,
                   **kwargs):
  """
  Create a tensorflow sequential model where all intermediate layers have the same size
  This function accepts an already constructed layer as the input.

  All hidden layers will have the same number of nodes for simplicity

  Arguments:
      outputs: int (default=1)
          number of output nodes (how many flavours are we training)
      input_layer: KerasTensor (default=None)
          if given, sets the input layer of the sequential model
      nlayers: int
          number of hidden layers of the network
      units: int
          number of nodes of every hidden layer in the network
      activation: str
          activation function to be used by the hidden layers (ex: 'tanh', 'sigmoid', 'linear')
  """
  if len(units) != nlayers:
      raise Exception("The length of units must match the number of layers.")
  
  if kwargs.get('kernel_initializer'):
      kernel_initializer = kwargs['kernel_initializer']
  else:
      kernel_initializer = tf.keras.initializers.HeNormal

  if kwargs.get('activation_list'):
      activation_list = kwargs['activation_list']
      if len(units) != len(activation_list):
          raise Exception("The length of the activation list must match the number of layers.")
  else:
      activation_list = ['tanh', 'tanh']

  if kwargs.get('output_func'):
      output_func = kwargs['output_func']
  else:
      output_func = 'linear'
  
  if kwargs.get('name'):
      name = kwargs['name']
  else:
      name = 'pdf'
  
  model = tf.keras.models.Sequential(name=name)
  if input_layer is not None:
      model.add(input_layer)
  for layer in range(nlayers):
      model.add(tf.keras.layers.Dense(units[layer], 
                                      activation=activation_list[layer],
                                      kernel_initializer=kernel_initializer(seed=seed - layer),
                                      ),
      )
  model.add(tf.keras.layers.Dense(outputs, 
                                  activation=output_func, 
                                  kernel_initializer=tf.keras.initializers.HeNormal(seed=seed - nlayers)
                                  ))

  return model

def compute_ntk(model, input):
  grad = []
  for x in tf.convert_to_tensor(input):
    with tf.GradientTape() as tape:
      x = tf.reshape(x, shape=(-1,1))
      #tape.watch(x)
      pred = model(x)

    # compute gradients df(x)/dtheta
    g = tape.gradient(pred, model.trainable_variables)
    # concatenate the gradients of all trainable variables,
    # not discriminating between weights and biases
    size_g = len(g)
    g_minus_out = tf.concat([tf.reshape(g[i], shape=(-1,1)) for i in range(size_g - 2)], axis=0)
    g = np.array([
      np.concatenate([g_minus_out, 
                      tf.reshape(g[-2][:,i],  shape=(-1,1)), 
                      tf.reshape(g[-1][i],  shape=(-1,1))],
                      axis=0
                      )
      for i in range(pred.shape[1])
    ])
    grad.append(g)

  grad = np.array(grad)
  ntk = tf.einsum('aikl,bjkl->ijab', grad, grad)
  return ntk


def produce_ntk_Y(ntk, start_grid_by_exp, grid_size_by_exp, fk_table_list, total_ndata, start_proc_by_exp, index):
  # Constructing ntk_Y
  sub_mats = defaultdict(list)

  for exp_name_1, alpha in start_grid_by_exp.items():
    for exp_name_2, beta in start_grid_by_exp.items():
      # Take the submatrix of the NTK in data space
      ntk_red = ntk[:, :, alpha : alpha + grid_size_by_exp[exp_name_1], beta : beta + grid_size_by_exp[exp_name_2]].numpy()
      fk_I = fk_table_list[exp_name_1]
      fk_J = fk_table_list[exp_name_2]
      start_locs = (start_proc_by_exp[exp_name_1], start_proc_by_exp[exp_name_2]) # Wrong, this should have that index in the data space, not in the grid spaces
      #print("-----------------------------")
      #print(f"{fk_I.shape} - {ntk_red.shape} - {fk_J.shape}")
      fk_ntk = np.tensordot(fk_I, ntk_red, axes=[[1,2],[0,2]])
      fk_ntk_fk = np.tensordot(fk_ntk, fk_J, axes=[[1,2],[1,2]])
      sub_mats[start_locs] = fk_ntk_fk

  result = np.zeros((total_ndata, total_ndata), dtype=np.float32)
  for locs, mat in sub_mats.items():
      xsize, ysize = mat.shape
      print(mat.shape)
      print("------------------------")
      print(f"x| {locs[0]} : {locs[0] + xsize}")
      print(f"y| {locs[1]} : {locs[1] + ysize}")
      print(locs)
      print(mat)
      result[locs[0] : locs[0] + xsize, locs[1] : locs[1] + ysize] = mat

  return pd.DataFrame(result, index=index, columns=index), sub_mats

In [292]:
dataset_inputs = [
  #{'dataset': 'NMC_NC_NOTFIXED_DW_EM-F2', 'frac': 0.75, 'variant': 'legacy'},
  {'dataset': 'NMC_NC_NOTFIXED_P_EM-SIGMARED', 'frac': 0.75, 'variant': 'legacy'},
  {'dataset': 'SLAC_NC_NOTFIXED_P_DW_EM-F2', 'frac': 0.75, 'variant': 'legacy'},
  #{'dataset': 'SLAC_NC_NOTFIXED_D_DW_EM-F2', 'frac': 0.75, 'variant': 'legacy'},
  #{'dataset': 'BCDMS_NC_NOTFIXED_P_DW_EM-F2', 'frac': 0.75, 'variant': 'legacy'},
  #{'dataset': 'BCDMS_NC_NOTFIXED_D_DW_EM-F2', 'frac': 0.75, 'variant': 'legacy'},
  #{'dataset': 'CHORUS_CC_NOTFIXED_PB_DW_NU-SIGMARED', 'frac': 0.75, 'variant': 'legacy'},
  #{'dataset': 'CHORUS_CC_NOTFIXED_PB_DW_NB-SIGMARED', 'frac': 0.75, 'variant': 'legacy'},
  #{'dataset': 'NUTEV_CC_NOTFIXED_FE_DW_NU-SIGMARED', 'cfac': ['MAS'], 'frac': 0.75, 'variant': 'legacy'},
  #{'dataset': 'NUTEV_CC_NOTFIXED_FE_DW_NB-SIGMARED', 'cfac': ['MAS'], 'frac': 0.75, 'variant': 'legacy'},
  #{'dataset': 'HERA_NC_318GEV_EM-SIGMARED', 'frac': 0.75, 'variant': 'legacy'},
  #{'dataset': 'HERA_NC_225GEV_EP-SIGMARED', 'frac': 0.75, 'variant': 'legacy'},
  #{'dataset': 'HERA_NC_251GEV_EP-SIGMARED', 'frac': 0.75, 'variant': 'legacy'},
  #{'dataset': 'HERA_NC_300GEV_EP-SIGMARED', 'frac': 0.75, 'variant': 'legacy'},
  #{'dataset': 'HERA_NC_318GEV_EP-SIGMARED', 'frac': 0.75, 'variant': 'legacy'},
  #{'dataset': 'HERA_CC_318GEV_EM-SIGMARED', 'frac': 0.75, 'variant': 'legacy'},
  #{'dataset': 'HERA_CC_318GEV_EP-SIGMARED', 'frac': 0.75, 'variant': 'legacy'},
  #{'dataset': 'HERA_NC_318GEV_EAVG_CHARM-SIGMARED', 'frac': 0.75, 'variant': 'legacy'},
  #{'dataset': 'HERA_NC_318GEV_EAVG_BOTTOM-SIGMARED', 'frac': 0.75, 'variant': 'legacy'},
]

In [293]:
common_dict = dict(
    dataset_inputs=dataset_inputs,
    metadata_group="nnpdf31_process",
    use_cuts='internal',
    datacuts={'t0pdfset': '240701-02-rs-nnpdf40-baseline', 'q2min': 3.49, 'w2min': 12.5},
    theoryid=40000000,
)

In [294]:
groups_data = API.procs_data(**common_dict)
groups_index = API.groups_index(**common_dict)

In [295]:
fk_table_list = defaultdict(list)
x_grid_list = defaultdict(list)
Y = []
total_ndata_wc = 0
total_grid_size = 0
start_grid_by_exp = defaultdict(list)
grid_size_by_exp = defaultdict(list)
start_proc_by_exp = defaultdict(list)

for idx_proc, group_proc in enumerate(groups_data):
  for idx_exp, exp_set in enumerate(group_proc.datasets):

    fkspecs = exp_set.fkspecs
    cuts = exp_set.cuts
    ndata = exp_set.load_commondata().ndata
    fk_table = load_fktable(fkspecs[0])
    fk_table_wc = fk_table.with_cuts(cuts)
    x_grid = fk_table_wc.xgrid

    Y.append(exp_set.load_commondata().central_values.to_numpy())
    fk_table_list[exp_set.name] = fk_table_wc.get_np_fktable()
    x_grid_list[exp_set.name] = x_grid
    start_proc_by_exp[exp_set.name] = total_ndata_wc
    start_grid_by_exp[exp_set.name] = total_grid_size
    grid_size_by_exp[exp_set.name] = x_grid.shape[0]
    total_grid_size += x_grid.shape[0]
    total_ndata_wc += ndata
#print(f"Total number of points after cuts: {total_ndata_wc}")

In [296]:
# Generate NNPDF model
nnpdf = generate_sequential_model(outputs=9, nlayers=2, units=[28, 20],seed=seed, name='NNPDF', kernel_initializer=tf.keras.initializers.GlorotNormal)

In [297]:
# Create index for NTK
grid_index_array = []
grid_index_id = []
counter = 0
for set, x_grid in x_grid_list.items():
  counter += x_grid.size
  for id, x in enumerate(x_grid):
    grid_index_array.append(set)
    grid_index_id.append(id)

grid_multi_index = pd.MultiIndex.from_arrays([grid_index_array, grid_index_id], names=('dataset','id'))

x_grid_total = np.concatenate([grid for grid in x_grid_list.values()])

In [298]:
ntk = compute_ntk(nnpdf, x_grid_total)

In [299]:
ntk_y, submat = produce_ntk_Y(ntk, start_grid_by_exp, grid_size_by_exp, fk_table_list, total_ndata_wc, start_proc_by_exp, groups_index)

(204, 204)
------------------------
x| 204 : 408
y| 204 : 408
(204, 204)
[[7.32335075 7.47860843 7.5413762  ... 6.48483865 6.4516079  6.3897897 ]
 [7.47860843 7.63940526 7.70472913 ... 6.67542857 6.64252079 6.58051269]
 [7.5413762  7.70472913 7.77125908 ... 6.75995643 6.72732461 6.66539844]
 ...
 [6.48483865 6.67542857 6.75995643 ... 6.996914   6.9917984  6.96357986]
 [6.4516079  6.64252079 6.72732461 ... 6.9917984  6.98730612 6.95988385]
 [6.3897897  6.58051269 6.66539844 ... 6.96357986 6.95988385 6.93354636]]


ValueError: could not broadcast input array from shape (204,204) into shape (33,33)

In [234]:
ntk_y.xs(level='dataset', key='SLAC_NC_NOTFIXED_P_DW_EM-F2').T.xs(level='dataset', key='SLAC_NC_NOTFIXED_P_DW_EM-F2').shape

(33, 33)

In [256]:
submat[(0,0)]

array([[7.32335075, 7.47860843, 7.5413762 , ..., 6.48483865, 6.4516079 ,
        6.3897897 ],
       [7.47860843, 7.63940526, 7.70472913, ..., 6.67542857, 6.64252079,
        6.58051269],
       [7.5413762 , 7.70472913, 7.77125908, ..., 6.75995643, 6.72732461,
        6.66539844],
       ...,
       [6.48483865, 6.67542857, 6.75995643, ..., 6.996914  , 6.9917984 ,
        6.96357986],
       [6.4516079 , 6.64252079, 6.72732461, ..., 6.9917984 , 6.98730612,
        6.95988385],
       [6.3897897 , 6.58051269, 6.66539844, ..., 6.96357986, 6.95988385,
        6.93354636]])

In [259]:
result = np.zeros((total_ndata_wc, total_ndata_wc), dtype=np.float32)
for locs, mat in submat.items():
    xsize, ysize = mat.shape
    print("------------------------")
    print(f"x| {locs[0]} : {locs[0] + xsize}")
    print(f"y| {locs[1]} : {locs[1] + ysize}")
    print(locs)
    print(mat)
    result[locs[0] : locs[0] + xsize, locs[1] : locs[1] + ysize] = mat

------------------------
x| 0 : 204
y| 0 : 204
(0, 0)
[[7.32335075 7.47860843 7.5413762  ... 6.48483865 6.4516079  6.3897897 ]
 [7.47860843 7.63940526 7.70472913 ... 6.67542857 6.64252079 6.58051269]
 [7.5413762  7.70472913 7.77125908 ... 6.75995643 6.72732461 6.66539844]
 ...
 [6.48483865 6.67542857 6.75995643 ... 6.996914   6.9917984  6.96357986]
 [6.4516079  6.64252079 6.72732461 ... 6.9917984  6.98730612 6.95988385]
 [6.3897897  6.58051269 6.66539844 ... 6.96357986 6.95988385 6.93354636]]
------------------------
x| 0 : 204
y| 34 : 67
(0, 34)
[[7.29964799 7.35032745 7.18358689 ... 6.13826102 6.1390513  6.13896548]
 [7.46805088 7.52045945 7.3539357  ... 6.32310934 6.32420883 6.32446495]
 [7.53805164 7.59125137 7.42534217 ... 6.40554322 6.40680911 6.40725213]
 ...
 [6.78608867 6.84643024 6.78736996 ... 6.72801395 6.73563438 6.74369334]
 [6.75920143 6.81961271 6.76302421 ... 6.72520995 6.73296133 6.74117884]
 [6.70438522 6.76469565 6.71138565 ... 6.70072358 6.70861513 6.717006  ]]
---

In [240]:
result[0,0]

7.323351

In [245]:
total_ndata_wc

237