In [3]:
from validphys.api import API
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt


from gen_dicts import generate_dicts
from utils import XGRID
from model_utils import *

Using Keras backend


In [4]:
seed = 14132124
DEBUG = False

In [5]:
# List of DIS dataset
dataset_inputs = [
  #{'dataset': 'NMC_NC_NOTFIXED_DW_EM-F2', 'frac': 0.75, 'variant': 'legacy'},
  {'dataset': 'NMC_NC_NOTFIXED_P_EM-SIGMARED', 'frac': 0.75, 'variant': 'legacy'},
  {'dataset': 'SLAC_NC_NOTFIXED_P_DW_EM-F2', 'frac': 0.75, 'variant': 'legacy'},
  {'dataset': 'SLAC_NC_NOTFIXED_D_DW_EM-F2', 'frac': 0.75, 'variant': 'legacy'},
  {'dataset': 'BCDMS_NC_NOTFIXED_P_DW_EM-F2', 'frac': 0.75, 'variant': 'legacy'},
  {'dataset': 'BCDMS_NC_NOTFIXED_D_DW_EM-F2', 'frac': 0.75, 'variant': 'legacy'},
  {'dataset': 'CHORUS_CC_NOTFIXED_PB_DW_NU-SIGMARED', 'frac': 0.75, 'variant': 'legacy'},
  {'dataset': 'CHORUS_CC_NOTFIXED_PB_DW_NB-SIGMARED', 'frac': 0.75, 'variant': 'legacy'},
  {'dataset': 'NUTEV_CC_NOTFIXED_FE_DW_NU-SIGMARED', 'cfac': ['MAS'], 'frac': 0.75, 'variant': 'legacy'},
  {'dataset': 'NUTEV_CC_NOTFIXED_FE_DW_NB-SIGMARED', 'cfac': ['MAS'], 'frac': 0.75, 'variant': 'legacy'},
  {'dataset': 'HERA_NC_318GEV_EM-SIGMARED', 'frac': 0.75, 'variant': 'legacy'},
  {'dataset': 'HERA_NC_225GEV_EP-SIGMARED', 'frac': 0.75, 'variant': 'legacy'},
  {'dataset': 'HERA_NC_251GEV_EP-SIGMARED', 'frac': 0.75, 'variant': 'legacy'},
  {'dataset': 'HERA_NC_300GEV_EP-SIGMARED', 'frac': 0.75, 'variant': 'legacy'},
  {'dataset': 'HERA_NC_318GEV_EP-SIGMARED', 'frac': 0.75, 'variant': 'legacy'},
  {'dataset': 'HERA_CC_318GEV_EM-SIGMARED', 'frac': 0.75, 'variant': 'legacy'},
  {'dataset': 'HERA_CC_318GEV_EP-SIGMARED', 'frac': 0.75, 'variant': 'legacy'},
  {'dataset': 'HERA_NC_318GEV_EAVG_CHARM-SIGMARED', 'frac': 0.75, 'variant': 'legacy'},
  {'dataset': 'HERA_NC_318GEV_EAVG_BOTTOM-SIGMARED', 'frac': 0.75, 'variant': 'legacy'},
]

# Dictionary for validphys API
common_dict = dict(
    dataset_inputs=dataset_inputs,
    metadata_group="nnpdf31_process",
    use_cuts='internal',
    datacuts={'q2min': 3.49, 'w2min': 12.5},
    theoryid=40000000,
    t0pdfset='NNPDF40_nnlo_as_01180',
    use_t0=True
)

In [6]:
# Retrieve data from NNPDF
groups_data = API.procs_data(**common_dict)
tuple_of_dicts = generate_dicts(groups_data)
fk_table_dict = tuple_of_dicts.fk_tables
central_data_dict = tuple_of_dicts.central_data

In [7]:
C_sys = API.dataset_inputs_t0_covmat_from_systematics(**common_dict)
C = API.groups_covmat_no_table(**common_dict)

# Serialize covmat
C.to_pickle(path='./serialised_data/covmat.pkl')
C_index = C.index
C_col = C.columns
#C = pd.DataFrame(C_sys, index=C_index, columns=C_col)
#C = pd.DataFrame(np.identity(C.shape[0]), index=C_index, columns=C_col)
Cinv = np.linalg.inv(C)
Cinv = pd.DataFrame(Cinv, index=C_index, columns=C_col)


L = np.linalg.cholesky(Cinv)
L = L.T
Linv = np.linalg.inv(L)

LHAPDF 6.5.4 loading /opt/homebrew/Caskroom/miniconda/base/envs/nnpdf/share/LHAPDF/NNPDF40_nnlo_as_01180/NNPDF40_nnlo_as_01180_0000.dat
NNPDF40_nnlo_as_01180 PDF set, member #0, version 1; LHAPDF ID = 331100


In [8]:
# Construct the big FK table matrix
ndata = 0
dataname_from_covmat = Cinv.index.get_level_values('dataset').unique()

vecs_for_stack = []
for i, (exp, fk) in enumerate(fk_table_dict.items()):
  ndata += fk.shape[0]
  if exp == dataname_from_covmat[i]:
    vecs_for_stack.append(fk.numpy().reshape((fk.shape[0], fk.shape[1] * fk.shape[2])))
  else:
    raise ValueError(f'Problem encountered {exp} != {dataname_from_covmat[i]}')

FK = np.vstack([fk.numpy().reshape((fk.shape[0], fk.shape[1] * fk.shape[2])) for fk in fk_table_dict.values()])
FK_plus = np.linalg.pinv(FK, rcond=1.e-3)
#FK = np.vstack(vecs_for_stack)

try:
  assert(FK.shape[0] == ndata)
except AssertionError:
  print('The number of points does not match.')

# Check that this FK is what we expect
try:
  test_matrix = np.random.rand(FK.shape[0], FK.shape[0])
  # Matrix product
  mat_prod = FK.T @ test_matrix

  # Tensor product
  shape = (*fk.shape[1:], *fk.shape[1:])
  result = np.zeros((fk.shape[1], fk.shape[2], FK.shape[0]))
  I = 0
  for fk in fk_table_dict.values():
    ndata = fk.shape[0]
    result += np.einsum('Iia, IJ -> iaJ',fk, test_matrix[I : I + ndata, :])
    I += ndata
  
  result_flatten = result.reshape((result.shape[0] * result.shape[1], result.shape[2]))
  assert(np.allclose(result_flatten, mat_prod))
  assert(np.allclose(result, mat_prod.reshape((result.shape[0], result.shape[1], result.shape[2]))))
except AssertionError:
  print('A problem occurred')

# Perform singular value decomposition
fk_l, fk_evals, fk_r = np.linalg.svd(FK)
print('Singular values of the FK tables')
print('--------------------------------')
for idx, val in enumerate(fk_evals):
  print(f'{idx} : {val}')

Singular values of the FK tables
--------------------------------
0 : 108.40599822998047
1 : 104.26300811767578
2 : 99.8062515258789
3 : 96.13958740234375
4 : 93.28633117675781
5 : 89.17032623291016
6 : 85.46747589111328
7 : 82.43325805664062
8 : 78.22171783447266
9 : 73.54110717773438
10 : 61.730709075927734
11 : 55.10063934326172
12 : 49.780094146728516
13 : 45.8709831237793
14 : 39.89619064331055
15 : 32.68745803833008
16 : 26.800718307495117
17 : 25.57478904724121
18 : 18.776426315307617
19 : 17.752408981323242
20 : 15.906761169433594
21 : 13.151390075683594
22 : 12.943547248840332
23 : 12.117106437683105
24 : 10.707345962524414
25 : 10.439668655395508
26 : 10.418147087097168
27 : 10.03699779510498
28 : 7.709445953369141
29 : 7.311756610870361
30 : 6.263463973999023
31 : 6.017333030700684
32 : 5.891967296600342
33 : 5.685206890106201
34 : 5.381591320037842
35 : 5.045845031738281
36 : 5.0239105224609375
37 : 4.767819881439209
38 : 4.578884124755859
39 : 4.268607139587402
40 : 4.1912

In [133]:
from jax import jit
from jax import grad
from jax import random

import jax.numpy as jnp
from jax.nn import log_softmax
from jax.example_libraries import optimizers

import neural_tangents as nt
from neural_tangents import stax

In [164]:
k1, k2 = random.split(random.PRNGKey(1), 2)
x1 = random.normal(k1, (50, 224))

init_fn, f, _ = stax.serial(
    stax.Dense(224),
    stax.Dense(1000, 1., 0.05),
    stax.Erf(),
    stax.Dense(1000, 1., 0.05),
    stax.Erf(),
    stax.Dense(9, 1., 0.05))

key = random.PRNGKey(0)
_, params = init_fn(key, x1.shape)

In [165]:
ntk = nt.empirical_ntk_fn(f, vmap_axes=0, trace_axes=(), implementation=nt.NtkImplementation.NTK_VECTOR_PRODUCTS)

#g_dd = ntk(XGRID.reshape(50,1), None, params)
#g_td = ntk(test['image'], train['image'], params)

In [166]:
learning_rate = 0.0000001
opt_init, opt_apply, get_params = optimizers.sgd(learning_rate)
state = opt_init(params)

In [167]:
def loss(fx, y_dict):
  # Perform the contraction: sum over the second and third axes
  predictions = {}
  for exp, fk in fk_table_dict.items():
    fk = jnp.array(fk.numpy())
    predictions[exp] = jnp.einsum('Iia, ia -> I ', fk, fx)

  loss = 0

  for exp, pred in predictions.items():
    Cinv_exp = jnp.array(Cinv.xs(level="dataset", key=exp).T.xs(level="dataset", key=exp).to_numpy())
    R = jnp.array(pred) - jnp.array(y_dict[exp])
    Cinv_R = Cinv_exp @ R
    loss += 0.5 * jnp.mean(R @ Cinv_R)
  return loss

grad_loss = jit(grad(lambda params, x, y_dict: loss(f(params, x), y_dict)))

In [168]:
ntk_steps = []
for i in range(10000):
  params = get_params(state)
  state = opt_apply(i, grad_loss(params, x1, central_data_dict), state)
  ntk_steps.append(ntk(x1, None, params))
  if i % 100 == 0:
    exact_loss = loss(f(params, x1), central_data_dict)
    print('{}\t{:.4f}'.format(i, exact_loss))

0	687695.3125
100	187076.7656
200	150806.9688
300	128955.4922
400	112942.0156
500	100266.7422
600	89850.2422
700	81103.1250
800	73649.2734
900	67225.5078
1000	61638.6172
1100	56741.7148
1200	52421.1211
1300	48587.0508
1400	45167.6836
1500	42104.7148
1600	39350.2852
1700	36864.7773
1800	34614.9141
1900	32572.7324
2000	30714.2969
2100	29019.1836
2200	27469.7227
2300	26050.5723
2400	24748.4238
2500	23551.5605
2600	22449.6914
2700	21433.7129
2800	20495.5566
2900	19628.0938
3000	18824.8711
3100	18080.2051
3200	17388.9453
3300	16746.4922
3400	16148.6904
3500	15591.8037
3600	15072.4424
3700	14587.5518
3800	14134.3711
3900	13710.3457
4000	13313.2295
4100	12940.8779
4200	12591.4170
4300	12263.0957
4400	11954.3418
4500	11663.6758
4600	11389.7656
4700	11131.4473
4800	10887.5508
4900	10657.0762
5000	10439.0586
5100	10232.6455
5200	10037.0547
5300	9851.5322
5400	9675.4102
5500	9508.0352
5600	9348.8672
5700	9197.3555
5800	9053.0059
5900	8915.3662
6000	8784.0117
6100	8658.5557
6200	8538.6494
6300	842

KeyboardInterrupt: 

In [None]:
np.linalg.norm(ntk_steps[0] - ntk_steps[-1])

10.58436