In [1]:
from validphys.api import API
import sys

# Add the path to the library folder
sys.path.append('./lib')

from utils import XGRID, build_fk_matrix, regularize_matrix
from model import PDFmodel, generate_mse_loss
from gen_dicts import generate_dicts
from plot_utils import plot_eigvals
from validphys.api import API

import numpy as np
import pandas as pd

Using Keras backend


# Utility functions for the null space

In [2]:
import scipy as sp
from typing import Any, Tuple
import numpy.typing as npt
def null_space_eig(eigvals: npt.ArrayLike, eigvecs: npt.NDArray[np.float64], tol: float = None) -> Tuple[npt.NDArray[np.float64],npt.NDArray[np.float64]]:
  """
  Compute the kernel and its orthogonal space given as set
  of eigenvalues and eigenvectors.

  The kernel space is constructed out of the eigenvectors whose eigenvalue
  is zero. The eigenvalues are compared to tolerance. If the value is greater
  than the tolerance, then it is considered non-zero.

  The tolerance is a parameter of this function. If `tol` is not provided,
  then it is defined as the product of the largest eigenvalue with the
  smallest precision number given the type the of the eigenvalues.

  Parameters
  ----------
  eigvals: array
    List of eigenvalues that are compared to the tolerance.
  eigvecs: NDArray
    Matrix where the second index select the i-th eigenvector relative
    to the i-th eigenvalue, and the first index runs over the components
    of each eigenvector.
  tol: float
    The tolerance for the zero-value veto.

  Return
  ------
  The two sets of basis vectors for the kernel and its orthogonal space. These
  are subspaces of the original eigenspace provided as an argument. The indexing
  follows the same as `eigvecs`.
  """
  if tol is None:
    tol = np.amax(eigvals, initial=0.) * np.finfo(eigvecs.dtype).eps
  num = np.sum(eigvals > tol, dtype=int) # Number of non-zero eigenvalues
  ker = eigvecs[:,num:]
  orth = eigvecs[:,:num]
  return ker, orth
  
def project_matrix(matrix, basis1, basis2=None):
  """
  Project the matrix into a given basis. If two bases are given,
  then the first basis specifies the projection of the matrix on
  the right space, while the second basis for the left space.

  In particular, the projection is computed as follows $\delta$

  ..math::

    M_{i_{B_1} j_{B_2}} = \sum_{i=1}^{dim(B_1)} \sum_{j=1}^{dim(B_2)}
    \mathbf{v}_{B_2}^{(j)T} \cdot M \cdot \mathbf{v}_{B_1}^{(i)}

  where :math:`\mathbf{v}_{B}^{(i)}` is the i-th vector of the basis B.

  If `basis2` is not provided, then basis2 is taken to be standard basis
  (i.e. the one specified by the identity matrix).

  Note that the matrix is not required to be squared.

  Parameters
  ----------
  matrix: NDArray
    The matrix that is projected.
  basis1: NDArray
    The (right) basis on which the matrix is projected.
  basis2: NDArray
    The (left) basis on which the matrix is projected.

  Returns
  -------
  The projection of the matrix into the bases `basis1` and, if 
  given `basis2`.
  """
  if basis2 is None:
    basis2 = np.eye(matrix.shape[0])

  emb_space1 = basis1.shape[0] # Embedding space base 1
  emb_space2 = basis2.shape[0] # Embedding space base 2

  # Check if the bases are compatible with the matrix
  if matrix.shape[0] != emb_space2 or matrix.shape[1] != emb_space1:
    raise ValueError ('The matrix cannot be projected into the two bases.')

  M_orth = basis2.T @ matrix @ basis1

  return M_orth

def project_vector(vector, basis):
  """
  Project a vector into a given basis.
  """
  basis_dim = basis.shape[1]
  space_dim = basis.shape[0]
  if space_dim != vector.shape[0]:
    raise ValueError ('The matrix cannot be projected into the basis')
  
  res = [np.dot(vector, basis[:,i]) for i in range(basis_dim)]
  return res

In [3]:
seed = 14132124

In [4]:
# List of DIS dataset
dataset_inputs = [
  #{'dataset': 'NMC_NC_NOTFIXED_DW_EM-F2', 'frac': 0.75, 'variant': 'legacy'},
  {'dataset': 'NMC_NC_NOTFIXED_P_EM-SIGMARED', 'frac': 0.75, 'variant': 'legacy'},
  {'dataset': 'SLAC_NC_NOTFIXED_P_DW_EM-F2', 'frac': 0.75, 'variant': 'legacy'},
  {'dataset': 'SLAC_NC_NOTFIXED_D_DW_EM-F2', 'frac': 0.75, 'variant': 'legacy'},
  {'dataset': 'BCDMS_NC_NOTFIXED_P_DW_EM-F2', 'frac': 0.75, 'variant': 'legacy'},
  {'dataset': 'BCDMS_NC_NOTFIXED_D_DW_EM-F2', 'frac': 0.75, 'variant': 'legacy'},
  {'dataset': 'CHORUS_CC_NOTFIXED_PB_DW_NU-SIGMARED', 'frac': 0.75, 'variant': 'legacy'},
  {'dataset': 'CHORUS_CC_NOTFIXED_PB_DW_NB-SIGMARED', 'frac': 0.75, 'variant': 'legacy'},
  {'dataset': 'NUTEV_CC_NOTFIXED_FE_DW_NU-SIGMARED', 'cfac': ['MAS'], 'frac': 0.75, 'variant': 'legacy'},
  {'dataset': 'NUTEV_CC_NOTFIXED_FE_DW_NB-SIGMARED', 'cfac': ['MAS'], 'frac': 0.75, 'variant': 'legacy'},
  {'dataset': 'HERA_NC_318GEV_EM-SIGMARED', 'frac': 0.75, 'variant': 'legacy'},
  {'dataset': 'HERA_NC_225GEV_EP-SIGMARED', 'frac': 0.75, 'variant': 'legacy'},
  {'dataset': 'HERA_NC_251GEV_EP-SIGMARED', 'frac': 0.75, 'variant': 'legacy'},
  {'dataset': 'HERA_NC_300GEV_EP-SIGMARED', 'frac': 0.75, 'variant': 'legacy'},
  {'dataset': 'HERA_NC_318GEV_EP-SIGMARED', 'frac': 0.75, 'variant': 'legacy'},
  {'dataset': 'HERA_CC_318GEV_EM-SIGMARED', 'frac': 0.75, 'variant': 'legacy'},
  {'dataset': 'HERA_CC_318GEV_EP-SIGMARED', 'frac': 0.75, 'variant': 'legacy'},
  {'dataset': 'HERA_NC_318GEV_EAVG_CHARM-SIGMARED', 'frac': 0.75, 'variant': 'legacy'},
  {'dataset': 'HERA_NC_318GEV_EAVG_BOTTOM-SIGMARED', 'frac': 0.75, 'variant': 'legacy'},
]

# Dictionary for validphys API
common_dict = dict(
    dataset_inputs=dataset_inputs,
    metadata_group="nnpdf31_process",
    use_cuts='internal',
    datacuts={'q2min': 3.49, 'w2min': 12.5},
    theoryid=708,
    t0pdfset='NNPDF40_nnlo_as_01180',
    use_t0=True
)

In [5]:
# Retrieve data from NNPDF
groups_data = API.procs_data(**common_dict)
tuple_of_dicts = generate_dicts(groups_data)
fk_table_dict = tuple_of_dicts.fk_tables
central_data_dict = tuple_of_dicts.central_data
FK = build_fk_matrix(fk_table_dict)

In [6]:
C = API.groups_covmat_no_table(**common_dict)

# Serialize covmat
C_index = C.index
C_col = C.columns
Cinv = np.linalg.inv(C)
Cinv = pd.DataFrame(Cinv, index=C_index, columns=C_col)

# Diagonalize covariance matric
eigvals_Cinv, R_Y = np.linalg.eigh(Cinv)
if eigvals_Cinv[-1] > eigvals_Cinv[0]:
    eigvals_Cinv = eigvals_Cinv[::-1]
    R_Y = R_Y[:,::-1]
D_Y = np.zeros_like(R_Y)
np.fill_diagonal(D_Y, eigvals_Cinv)

LHAPDF 6.5.4 loading /opt/homebrew/Caskroom/miniconda/base/envs/nnpdf/share/LHAPDF/NNPDF40_nnlo_as_01180/NNPDF40_nnlo_as_01180_0000.dat
NNPDF40_nnlo_as_01180 PDF set, member #0, version 1; LHAPDF ID = 331100


In [7]:
# Construct dataframe for predictions
Y = pd.DataFrame(np.zeros(Cinv.shape[0]), index=Cinv.index)
for exp_name, data in central_data_dict.items():
  if data.size == Y.loc[(slice(None), [exp_name], slice(None)), :].size:
    Y.loc[(slice(None), [exp_name], slice(None)), :] = data
  else:
    raise ValueError

In [8]:
nnpdf_model = PDFmodel(input=XGRID,
                       outputs=9,
                       architecture=[28,20],
                       activations=['tanh', 'tanh'],
                       kernel_initializer='RandomNormal',
                       user_ki_args={'mean': 0.0, 'stddev': 1.0},
                       seed=seed,
                       dtype='float64')
NTK = nnpdf_model.compute_ntk()

# Flatten NTK
prod = 1
oldshape = NTK.shape
for k in oldshape[2:]:
    prod *= k
NTK_flat = np.array(NTK).reshape(prod,-1)

# Compute predictions at initialization
f0 = nnpdf_model.predict(squeeze=True)

Load data from GD training

In [9]:
import pickle
with open('training.pkl', 'rb') as file:
    results = pickle.load(file)

pred_in_time = results[1]
pdfs_in_time = results[2]
learning_rate_gd = 0.00000001

# Computing matrices from notes
-------------------------------

# $M = (FK)^T C_Y^{-1} (FK) = RDR^T$

In [11]:
M = FK.T @ Cinv.to_numpy() @ FK
M, (eigvals_M, R) = regularize_matrix(M, tol=None)

# Construct diagonal matrix
D = np.zeros_like(R)
np.fill_diagonal(D, eigvals_M)

# $\tilde{H} = D^{1/2} R^T \Theta R D^{1/2}$

In [None]:
ntk, (eigvals_ntk, R_ntk) = regularize_matrix(NTK_flat)
H_tilde = np.sqrt(D) @ R.T @ ntk @ R @ np.sqrt(D)
H_tilde, (eigvals_H_tilde, eigvecs_H_tilde) = regularize_matrix(H_tilde, tol=None)

# Check if symmetric
print(f'Is symmetric: {np.allclose(H_tilde, H_tilde.T)}')

# Null space of $M$ and $\tilde{H}$

In this part, we compute the $\textrm{ker}(M)$ and its orthogonal space $R(M)$. We then project the matrix $M$ into the these two bases. In particular, we apply
$$
M_{i j} = (\mathbf{v^{(i)}_{B_2}})^T \cdot \mathbf{M} \cdot \mathbf{v^{(j)}_{B_1}} \,,
$$
where $\mathbf{v^{i}_{B_1}}$ and $\mathbf{v^{j}_{B_2}}$ are the i-th and j-th vectors of the respective bases.

When we compute the null space, we need to choose the threshold for the smallest distinguishable eigenvalue. If no tolerance is provided to the `null_space_eig` function, the default option is used (see above). Otherwise, we can specify a custom tolerance. We could choose such value by looking at the eigenvalues of the matrix $M$:

In [14]:
for i, val in enumerate(eigvals_M):
  print(f'{i+1} : {val}')

1 : 4501962.482500519
2 : 1611794.2231508063
3 : 656336.8239098041
4 : 529986.6542886588
5 : 160906.35937601666
6 : 152653.25231474565
7 : 137431.8778275906
8 : 111497.0201272436
9 : 103530.19569058708
10 : 77534.6677476217
11 : 74051.62283799764
12 : 51334.52972635849
13 : 46276.75356787624
14 : 40624.587537446314
15 : 36797.59404387643
16 : 34856.831448954945
17 : 31118.33117213023
18 : 24162.090154794583
19 : 23009.658665213006
20 : 21393.045643644942
21 : 20771.815897040124
22 : 18579.258785449296
23 : 16966.64697764954
24 : 16151.38296418803
25 : 15709.19291735395
26 : 15452.571087635839
27 : 14805.017362400145
28 : 14424.193238841144
29 : 13299.732276637764
30 : 13045.068878934711
31 : 12391.612302926651
32 : 11969.216349270544
33 : 10837.493680749745
34 : 10399.045437910148
35 : 9984.69135816926
36 : 9619.885504185886
37 : 9023.136069034628
38 : 8732.156929857694
39 : 8471.1654950175
40 : 8042.005406821709
41 : 7794.2930522539145
42 : 7151.795793641496
43 : 6854.480170365759
44 

In [43]:
ker_M, orth_M = null_space_eig(eigvals_M, R, tol=1.e-5)

In [44]:
M_pp = project_matrix(M, orth_M, orth_M)
M_kk = project_matrix(M, ker_M, ker_M)
M_pk = project_matrix(M, ker_M, orth_M)
M_kp = project_matrix(M, orth_M, ker_M)

In [45]:
print(f"det(M_pp) = {np.linalg.det(M_pp)}")
print(f"cond(M_pp) = {np.linalg.cond(M_pp)}")

inf

Now the matrix $M$ is decomposed as follows
$$
\mathbf{M} = \left(\begin{matrix}
\mathbf{M}_{KK} & \mathbf{M}_{K \bot}\\
\mathbf{M}_{\bot K} & \mathbf{M}_{\bot \bot}\\
\end{matrix} \right) \,.
$$
Note that, by definition of null space, only $M_{\bot\bot} \neq 0$. The next cell checks if that is effectively true.

In [None]:
print(f'M_pp ?= 0 : {np.allclose(np.zeros_like(M_pp), M_pp)}')
print(f'M_kk ?= 0 : {np.allclose(np.zeros_like(M_kk), M_kk)}')
print(f'M_kp ?= 0 : {np.allclose(np.zeros_like(M_kp), M_kp)}')
print(f'M_pk ?= 0 : {np.allclose(np.zeros_like(M_pk), M_pk)}')

By constructing the matrix $M_{\bot\bot}$, we have projected out the null modes. Hence, this matrix must be invertible. We check that in the following cell:

In [None]:
M_pp_inv = np.linalg.inv(M_pp)
print(f'M_pp_inv @ M_pp ?= Id: {np.allclose(M_pp_inv @ M_pp, np.eye(M_pp.shape[0]))}')
print(f'M_pp @ M_pp_inv ?= Id: {np.allclose(M_pp @ M_pp_inv, np.eye(M_pp.shape[0]))}')

We also need to compute the null space and its orthogonal space of the evolution matrix $\tilde{H}$

In [52]:
ker_H_tilde, orth_H_tilde = null_space_eig(eigvals_H_tilde, eigvecs_H_tilde)

In [53]:
H_tilde_pp = project_matrix(H_tilde, orth_H_tilde, orth_H_tilde)
H_tilde_kk = project_matrix(H_tilde, ker_H_tilde, ker_H_tilde)
H_tilde_pk = project_matrix(H_tilde, ker_H_tilde, orth_H_tilde)
H_tilde_kp = project_matrix(H_tilde, orth_H_tilde, ker_H_tilde)

In [None]:
print(f'H_tilde_pp ?= 0 : {np.allclose(np.zeros_like(H_tilde_pp), H_tilde_pp)}')
print(f'H_tilde_kk ?= 0 : {np.allclose(np.zeros_like(H_tilde_kk), H_tilde_kk)}')
print(f'H_tilde_kp ?= 0 : {np.allclose(np.zeros_like(H_tilde_kp), H_tilde_kp)}')
print(f'H_tilde_pk ?= 0 : {np.allclose(np.zeros_like(H_tilde_pk), H_tilde_pk)}')

Note that the null space of $\tilde{H}$ and that of $M$ are not the same in general:

In [None]:
print(f'{ker_H_tilde.shape} != {ker_M.shape}')

In particular, the null space of the evolution matrix $\tilde{H}$ is larger that that of $M$. This is kind of expected, as the evolution matrix accounts for the contribution of the NTK which is known to provide many small eigenvalues.

# When is $\tilde{H}$ invertible?
The matrix $\tilde{H}$ is not invertible on its own. We can check that directly in the following cell:

In [None]:
try:
  np.linalg.inv(H_tilde)
except np.linalg.LinAlgError:
  print('Error detected. The matrix is not ivertible.')

This was expected, as $\tilde{H}$ is constructed from the full matrix $M$. We can try to construct $\tilde{H}$ using $M_{\bot\bot}$, which should be equivalent to projecting $\tilde{H}$ in $M_{\bot\bot}$. Let's see:

In [57]:
H_tilde_p = project_matrix(H_tilde, orth_M, orth_M)
try:
  H_tilde_p_inv = np.linalg.inv(H_tilde_p)
except np.linalg.LinAlgError:
  print('Error detected. The matrix is not ivertible.')

It seems to work. However, if we try to compute $\tilde{H}_{\bot} \cdot \tilde{H}_{\bot}^{-1}$ we obtain

In [None]:
inv = np.linalg.inv(H_tilde_p)
print(f'Printing the first 10 diagonal entries: \n{(inv @ H_tilde_p).diagonal()[:10]}')

We see that this is not what we thought we computed, namely the inverse matrix of $\tilde{H}_{\bot}$. The reason is the high condition number of the matrix, which makes it highly unstable.

We then resort to $\tilde{H}_{\bot\bot}$ computed before. This takes in account the null modes of the NTK:

In [None]:
try:
  H_tilde_pp_inv = np.linalg.inv(H_tilde_pp)
except np.linalg.LinAlgError:
  print('Error detected. The matrix is not ivertible.')

print(f'Printing the first 10 diagonal entries: \n{(H_tilde_pp_inv @ H_tilde_pp).diagonal()[:10]}')

# Projection of $(FK)$: $(FK)_{\bot}$ and $(FK)_{K}$

We also need to project $(FK)$ into these two spaces. Remember that $(FK)$ is a linear map from the space of PDF to the space of the data
$$
(FK) : \mathbb{R}^{\textrm{PDF}} \longrightarrow \mathbb{R}^{\textrm{Data}} \,.
$$
The projection is then applied to the right-space only, which is the PDF space. Hence, after projection, the $(FK)$ table can be written as
$$
(FK) = \left( \, (FK)_{K} \hspace{5mm}  (FK)_{\bot} \,\right) \,,
$$
where each $(FK)_{B}$ is a linear map from PDF to data space. Note that $\textrm{ker}(M) = \textrm{ker}(FK)$. Thus, also in this case $(FK)_{K}$. We can check that in the following cell:

In [60]:
#FK, (s_FK, vh_FK) = regularize_matrix(FK, tol=np.finfo(FK.dtype).eps * np.amax(eigvals_M, initial=0.)/ np.amax(s_FK, initial=0.))
FK_p = project_matrix(FK, orth_M, np.eye(FK.shape[0]))
FK_k = project_matrix(FK, ker_M, np.eye(FK.shape[0]))

A naive comparison of the two components $(FK)_K$ and $(FK)_{\bot}$ would not lead to the desired result, as shown in the following cell:

In [None]:
print(f'FK_k ?= 0 : {np.allclose(np.zeros_like(FK_k), FK_k)}')
print(f'FK_p ?= 0 : {np.allclose(np.zeros_like(FK_p), FK_p)}')

The reason being that the tolerance used to construct $\textrm{Ker}(M)$ and $R(M)$ was defined using the largest eigenvalue of the the matrix $M$. When we move to $(FK)$, we then need to scale the tolerance appropriately so that we have consistent results. In that case, we have $M \sim (FK)^2$, and we could expect that $\textrm{tol}_{FK} = \sqrt{\textrm{tol}_M}$. Let's try and see if we get the expected result:


In [None]:
tol =  np.sqrt(np.amax(eigvals_M, initial=0.) * np.finfo(eigvals_M.dtype).eps)
print(f'Comparing with tolerance = {tol}.')
print(f'FK_k ?= 0 : {np.allclose(np.zeros_like(FK_k), FK_k, atol=tol)}')
print(f'FK_p ?= 0 : {np.allclose(np.zeros_like(FK_p), FK_p, atol=tol)}')

We can then set to zeros all those entries that are lower that `tol`.

In [63]:
# @TODO
# Should I do the same FK_p
FK_k[FK_k <= tol] = 0.0

We can reconstruct the full $(FK)$ using these `regularized` components. In particular, we can write
$$
  (FK) = \biggl( (FK)_{\bot} \;, (FK)_K \biggr)
$$

In [64]:
FK_proj = np.hstack((FK_p, FK_k))

As a consistency check, we can verify that $(FK) \Theta (FK)^T = (FK_{\bot}) \Theta_{\bot\bot} (FK_{\bot})^T$:

In [None]:
ntk_pp = project_matrix(ntk, orth_M, orth_M)
ntk_kp = project_matrix(ntk, orth_M, ker_M)
ntk_pk = project_matrix(ntk, ker_M, orth_M)
ntk_kk = project_matrix(ntk, ker_M, ker_M)
ntk_proj = np.block([[ntk_pp, ntk_pk],
                     [ntk_kp, ntk_kk]])
test1 = FK_p @ ntk_pp @ FK_p.T
test2 = FK_proj @ ntk_proj @ FK_proj.T
np.allclose(test1, test2)

Another consistency check is the following
$$
M_{\bot \bot} = (FK)_{\bot}^T C_Y^{-1} (FK)_{\bot}
$$

In [None]:
M_pp_reconstructed = FK_p.T @ Cinv.to_numpy(dtype='float64') @ FK_p
M_pp_recon_inv = np.linalg.inv(M_pp_reconstructed)
#M_pp_reconstructed, _ = regularize_matrix(M_pp_reconstructed, tol=tol)
print(f"M_pp == M_pp_recon: {np.allclose(M_pp, M_pp_reconstructed)}")
print(f"M_pp_recon_inv @ M_pp_recon == Id: {np.allclose(M_pp_reconstructed @ M_pp_recon_inv, np.eye(M_pp.shape[0]))}")
print(f"M_pp_recon_inv == M_pp_inv: {np.allclose(M_pp_recon_inv, M_pp_inv)}")
print(f"M_pp_inv @ M_pp_recon == Id: {np.allclose(M_pp_inv @ M_pp_reconstructed, np.eye(M_pp.shape[0]))}")
print(f"M_pp_recon @ M_pp_inv == Id: {np.allclose(M_pp_reconstructed @ M_pp_inv, np.eye(M_pp.shape[0]))}")

In [None]:
np.linalg.cond(M_pp_reconstructed)

In [None]:
np.linalg.cond(M_pp)

In [None]:
import matplotlib.pyplot as plt
plt.matshow((M_pp_inv - M_pp_recon_inv)) 
plt.colorbar()

However, if we try to invert the reconstructed matrix, and compare the inversion against $M_{\bot\bot}$ previously computed, we get different results:

In [None]:
# I'm not completely sure about that given that we get ~1.36 in
# the diagonal.
M_pp_recon_inv = np.linalg.inv(M_pp_reconstructed)
test = M_pp_recon_inv @ M_pp
print(test.diagonal())

We can to the same with all the other combinations.

In [None]:
M_kp_reconstructed = FK_k.T @ Cinv.to_numpy(dtype='float64') @ FK_p
np.allclose(M_kp, M_kp_reconstructed)

# $\tilde{H_{\epsilon}} = D_Y^{1/2} R_Y^T (FK) \Theta (FK)^T R_Y D_Y^{1/2}$

In [None]:
H_eps_tilde = np.sqrt(D_Y) @ R_Y.T @ FK_proj @ ntk_proj @ FK_proj.T @ R_Y @ np.sqrt(D_Y)
H_eps_tilde, (eigvals_H_eps_tilde, eigvecs_H_eps_tilde) = regularize_matrix(H_eps_tilde, tol=tol)

# Check if symmetric
print(f'Is symmetric: {np.allclose(H_eps_tilde, H_eps_tilde.T)}')

# $b = \Theta (FK)^T C_Y^{-1} y \hspace{5mm} \textrm{and} \hspace{5mm} \tilde{b} = D^{1/2} R^T b$

In [94]:
b = ntk @ FK.T @ Cinv.to_numpy(dtype='float64') @ Y.to_numpy('float64')
b_tilde = np.sqrt(D) @ R.T @ b

# Plots of the eigenvalues

In [None]:
fig, axs = plot_eigvals(eigvals_H_eps_tilde, figsize=(10,8), title=r'$H_{\epsilon} = D^{1/2}_Y R^T_Y (FK) \Theta (FK)^T R_Y D^{1/2}_Y$')

In [None]:
fig, axs = plot_eigvals(eigvals_H_tilde, 
                        figsize=(10,8), 
                        title=r'$\tilde{H}= D^{1/2} R^T \Theta R D^{1/2}$,  $M = RDR^T$')
fig.savefig('../../../doc/figs/Htilde_eigvals.pdf')

In [None]:
fig, axs = plot_eigvals(eigvals_ntk, figsize=(10,8), title='')
axs.set_title(r'Eigenvalues of $\Theta$', fontsize=20)
fig.savefig('../../../doc/figs/ntk_eigvals.pdf')

In [None]:
fig, axs = plot_eigvals(eigvals_M, figsize=(10,8), title=r'Eigenvalues of $M = (FK)^T C_Y^{-1} (FK)$')
#axs.set_title(r'Eigenvalues of $\Theta$', fontsize=20)
#fig.savefig('../../../doc/figs/m_eigvals.pdf')

# Calculation of $f_{\infty}$, $\varepsilon_{\infty}$ and evolution
We now compute the limiting solution $f_{\infty}$, which is the value that minimizes the loss function.
$$
f_{\infty} = \mathbf{M}_{\bot\bot}^{-1} (FK)_{\bot}^T C_Y^{-1} y \,.
$$
We also compute
$$
\varepsilon_{\infty} = y - (FK)_{\bot}f_{\infty} = \biggl(1 - (FK)_{\bot} \mathbf{M}_{\bot\bot}^{-1} (FK)_{\bot}^T C_Y^{-1}\biggr) y\,,
$$
together with the minimum of the loss-function
$$
\mathcal{L}^{*} = \frac{1}{2} y^T C_Y^{-1} \varepsilon_{\infty}

In [99]:
f_inf = M_pp_inv @ FK_p.T @ Cinv.to_numpy(dtype='float64') @ Y.to_numpy(dtype='float64')[:,0]
eps_inf = Y.to_numpy(dtype='float64')[:,0] - FK_p @ f_inf
L_inf = 0.5 * Y.to_numpy(dtype='float64')[:,0].T @ Cinv.to_numpy(dtype='float64') @ eps_inf

We also compute the following quantity, which will be useful the description of the evolution
$$
\begin{align}
& \tilde{\varepsilon}_{\infty} = D_{Y}^{1/2} \, R_Y^T \varepsilon_{\infty} \hspace{5mm} \textrm{where} \hspace{5mm} C_Y^{-1} = R_Y D_Y R_Y^T \\
\end{align}
$$

In [100]:
eps_inf_tilde = np.sqrt(D_Y) @ R_Y.T @ eps_inf

# Check eq.(67)
First we verify that
$$
   (FK)^T C_Y^{-1} (FK)_{\bot} M_{\bot\bot}^{-1} = 1 \,.
$$
Note that the right-most $(FK)$ is the full matrix, while $(FK)_{\bot}$ is the one projected into the orthogonal space. To make things compatible, we need to use the same set of bases for both. If we don't do the same, we get:

In [None]:
wrong_prod = FK_p.T @ Cinv.to_numpy(dtype='float64') @ FK_p @ M_pp_inv
wrong_prod[:wrong_prod.shape[1], :].diagonal()

where we show only the square part of the product. If we instead project the $(FK)$ first, we get

In [None]:
correct_prod = FK_proj.T @ Cinv.to_numpy(dtype='float64') @ FK_p @ M_pp_inv
np.allclose(np.eye(correct_prod.shape[1]), correct_prod[:correct_prod.shape[1], :])

In [None]:
correct_prod[:correct_prod.shape[1], :]

In [None]:
correct_prod.shape[1]

In [None]:
correct_prod.min()

We can then check
$$
(FK)^T C_Y^{-1} - (FK)^T C_Y^{-1} (FK)_{\bot} M_{\bot\bot}^{-1} (FK)_{\bot}^T C_Y^{-1} = 0
$$

In [None]:
first_bit = FK_proj.T @ Cinv.to_numpy(dtype='float64')
second_bit = FK_proj.T @ Cinv.to_numpy(dtype='float64') @ FK_p @ M_pp_inv @ FK_p.T @ Cinv.to_numpy(dtype='float64')
test67 = second_bit - first_bit
np.allclose(np.zeros_like(test67), test67)

We now check eq.(67) of the paper, that is
$$
\tilde{H}_{\varepsilon} \, \tilde{\varepsilon}_{\infty} = 0
$$

In [115]:
res = H_eps_tilde @ eps_inf_tilde

In [None]:
res.min()

# DEPRECATED
-------------
This part needs to be updated...

# Evolution of the data

In [None]:
from functools import lru_cache

# Construct dataframe for predictions
Y = pd.DataFrame(np.zeros(Cinv.shape[0]), index=Cinv.index)
for exp_name, data in central_data_dict.items():
  if data.size == Y.loc[(slice(None), [exp_name], slice(None)), :].size:
    Y.loc[(slice(None), [exp_name], slice(None)), :] = data
  else:
    raise ValueError
  
eps_0 = Y.to_numpy()[:,0] - FK @ f0.flatten()
Ly = (L @ Y).to_numpy()[:,0]
L_eps0 = L @ eps_0

L_eps0_tilde = [np.dot(L_eps0, eigvecs[:,k]) for k in range(eigvecs.shape[1])]
pre_computed_coefficients = [Linv @ eigvecs[:,k] * L_eps0_tilde[k] for k in range(eigvals_reg.size)] 

@lru_cache(maxsize=None)
def preds_t(t, learning_rate = 0.00001, eig_range=None):
  if eig_range is None:
    eig_range = eigvals_reg.size
  predictions = [pre_computed_coefficients[k] * np.exp(-eigvals_reg[k] * learning_rate* t) for k in range(eig_range)] 
  predictions = np.sum(predictions, axis=0)

  predictions = pd.DataFrame(predictions, index=Y.index)
  predictions = Y - predictions
  return predictions

In [None]:
experiments = ['NMC_NC_NOTFIXED_P_EM-SIGMARED', 'SLAC_NC_NOTFIXED_P_EM-F2', 'BCDMS_NC_NOTFIXED_D_EM-F2', 'HERA_NC_318GEV_EM-SIGMARED']
exp_titles = ['NMC', 'SLAC NC P', 'BCDMS NC D', 'HERA NC 318GEV']
y_labels = [r'$\sigma$', r'$F_2$', r'$F_2$', r'$\sigma$']
t = 0.
fig_pred, axes_pred = plt.subplots(2, 2, figsize=(25, 25))  # Adjust figsize for desired plot size
preds = preds_t(t, learning_rate=learning_rate_gd)

scat_gf = []
scat_gd = []
text = []
for i, ax in enumerate(axes_pred.flat):
    y = Y.xs(level='dataset', key=experiments[i]).to_numpy()
    p = preds.xs(level='dataset', key=experiments[i]).to_numpy()
    trained_pred = pred_in_time[0][experiments[i]]
    ax.scatter(np.arange(y.size), y, color='green', label='Central data', marker='o', s=100, alpha=0.4)
    gf = ax.scatter(np.arange(y.size), p, color='orange', label='Analytical solution', marker='^', s=100)
    gd = ax.scatter(np.arange(y.size), trained_pred, color='red', label='Gradient descent', marker='v', s=100)
    scat_gf.append(gf)
    scat_gd.append(gd)
    #ax.set_xlabel(r'$x$')
    ax.set_ylabel(y_labels[i], fontsize=20)
    #ax.set_xscale('log')
    ax.set_title(exp_titles[i], x=0.8,fontsize=20, fontweight='bold')
    ax.legend(fontsize=20)
    text_t = ax.text(0.05, 1.01, f't = {t}, learning rate = {learning_rate_gd}', fontsize=20, transform=ax.transAxes)
    text.append(text_t)


plt.tight_layout()
#fig.savefig('data_evolution.pdf')

In [None]:
experiments = ['NMC_NC_NOTFIXED_P_EM-SIGMARED', 'SLAC_NC_NOTFIXED_P_EM-F2', 'BCDMS_NC_NOTFIXED_D_EM-F2', 'HERA_NC_318GEV_EM-SIGMARED']
exp_titles = ['NMC', 'SLAC NC P', 'BCDMS NC D', 'HERA NC 318GEV']
y_labels = [r'$\sigma$', r'$F_2$', r'$F_2$', r'$\sigma$']
t = 0.
fig_eps, axes_eps = plt.subplots(2, 2, figsize=(25, 25))  # Adjust figsize for desired plot size
preds = preds_t(t, learning_rate=learning_rate_gd)

scat_gf_eps = []
scat_gd_eps = []
text_eps = []
for i, ax in enumerate(axes_eps.flat):
    y = Y.xs(level='dataset', key=experiments[i]).to_numpy()
    p = y - preds.xs(level='dataset', key=experiments[i]).to_numpy()
    trained_pred = y[:,0] - pred_in_time[int(t)][experiments[i]].numpy()
    ax.scatter(np.arange(y.size), y, color='green', label='Central data', marker='o', s=100, alpha=0.4)
    gf = ax.scatter(np.arange(y.size), p, color='orange', label='Analytical solution', marker='^', s=100)
    gd = ax.scatter(np.arange(y.size), trained_pred, color='red', label='Gradient descent', marker='v', s=100)
    scat_gf_eps.append(gf)
    scat_gd_eps.append(gd)
    #ax.set_xlabel(r'$x$')
    ax.set_ylabel(r'$\epsilon$', fontsize=20)
    #ax.set_xscale('log')
    ax.set_title(exp_titles[i], x=0.8,fontsize=20, fontweight='bold')
    ax.legend(fontsize=20)
    text_t = ax.text(0.05, 1.01, f't = {t}, learning rate = {learning_rate_gd}', fontsize=20, transform=ax.transAxes)
    text_eps.append(text_t)


plt.tight_layout()
#fig.savefig('data_evolution.pdf')

In [None]:
def compute_loss_analytical(t, eig_range=None):
  preds = preds_t(t, learning_rate=learning_rate_gd, eig_range=eig_range)
  loss = 0
  ndata = 0
  for exp in Y.index.get_level_values('dataset').unique():
    y = Y.xs(level='dataset', key=exp).to_numpy()
    Cinv_exp = Cinv.xs(level="dataset", key=exp).T.xs(level="dataset", key=exp).to_numpy()
    p = preds.xs(level='dataset', key=exp).to_numpy()
    R = y[:,0] - p[:,0]
    loss += 0.5 * R.T @ Cinv_exp @ R
    ndata += Cinv_exp.shape[0]
  return float(loss) / ndata

def compute_loss_gd(t):
  preds = pred_in_time[int(t)]
  loss = 0
  ndata = 0
  for exp, pred in preds.items():
    y = Y.xs(level='dataset', key=exp).to_numpy()
    Cinv_exp = tf.convert_to_tensor(Cinv.xs(level="dataset", key=exp).T.xs(level="dataset", key=exp).to_numpy(), name=f'Cinv_{exp}', dtype='float32')
    R = tf.convert_to_tensor(y[:,0] - pred, name=f'residue_{exp}', dtype='float32')
    Cinv_R = tf.linalg.matvec(Cinv_exp, R)
    loss += 0.5 * tf.reduce_sum(tf.multiply(R, Cinv_R))
    ndata += Cinv_exp.shape[0]
  return float(loss) / ndata


In [None]:
time_steps_high = np.arange(1000,len(pred_in_time),1000)
time_steps_low = np.arange(0,1000,2)
time_steps = np.concatenate([time_steps_low, time_steps_high])
aloss = [compute_loss_analytical(t, eig_range=100) for t in time_steps]
gd_loss = [compute_loss_gd(t) for t in time_steps]

In [None]:
fig_loss, ax_loss = plt.subplots(figsize=(10, 7))  # Adjust figsize for desired plot size

ax_loss.scatter(time_steps, aloss, label='Analytical solution')
ax_loss.scatter(time_steps, gd_loss, label='Gradient descent')
ax_loss.set_xlabel(r'$t$')
ax_loss.set_ylabel(r'Loss function', fontsize=20)
ax_loss.set_xscale('symlog')
ax_loss.set_title('MSE in function of training time', x=0.5, fontsize=20, fontweight='bold')
ax_loss.legend(fontsize=20)
#text_t = ax.text(0.05, 1.01, f't = {t}, learning rate = {learning_rate_gd}', fontsize=20, transform=ax.transAxes)
#text.append(text_t)


plt.tight_layout()
fig_loss.savefig('Loss_function_time.pdf')

In [None]:
from matplotlib.animation import FuncAnimation
from IPython.display import HTML
import matplotlib
matplotlib.rcParams['animation.embed_limit'] = 2**128

# Animation function
# Update function for predicitons
def update_preds(t):
    preds = preds_t(t, learning_rate=learning_rate_gd)
    for i, (gf, gd, text_t) in enumerate(zip(scat_gf, scat_gd, text)):
        # Update the y-data for each subplot's line
        y = Y.xs(level='dataset', key=experiments[i]).to_numpy()
        p = preds.xs(level='dataset', key=experiments[i]).to_numpy()
        trained_pred = pred_in_time[int(t)][experiments[i]]
        data_gf = np.hstack(( np.arange(y.size)[:, np.newaxis] , p))
        data_gd = np.hstack(( np.arange(y.size)[:, np.newaxis] , trained_pred[:,np.newaxis]))
        gf.set_offsets(data_gf)  # Example: Add phase shift based on t and subplot index
        gd.set_offsets(data_gd)  # Example: Add phase shift based on t and subplot index
        text_t.set_text(f't = {t}, learning rate = {learning_rate_gd}')
    return scat_gf + scat_gd + text

# Update function for epsilon
def update_eps(t):
    preds = preds_t(t, learning_rate=learning_rate_gd)
    for i, (gf, gd, text_t) in enumerate(zip(scat_gf_eps, scat_gd_eps, text_eps)):
        # Update the y-data for each subplot's line
        y = Y.xs(level='dataset', key=experiments[i]).to_numpy()
        p = preds.xs(level='dataset', key=experiments[i]).to_numpy()
        trained_pred = pred_in_time[int(t)][experiments[i]]
        data_gf = np.hstack(( np.arange(y.size)[:, np.newaxis] , p))
        data_gd = np.hstack(( np.arange(y.size)[:, np.newaxis] , trained_pred[:,np.newaxis]))
        gf.set_offsets(data_gf)  # Example: Add phase shift based on t and subplot index
        gd.set_offsets(data_gd)  # Example: Add phase shift based on t and subplot index
        text_t.set_text(f't = {t}, learning rate = {learning_rate_gd}')
    return scat_gf + scat_gd + text

In [None]:
ani_pred = FuncAnimation(fig_pred, update_preds, frames=np.arange(0, len(pred_in_time), 1000), interval=10, blit=True, cache_frame_data=False)
ani_eps = FuncAnimation(fig_eps, update_eps, frames=np.arange(0, len(pred_in_time), 1000), interval=10, blit=True, cache_frame_data=False)

# Save the animation in the background
ani_pred.save('prediction_evolution.mp4', writer='ffmpeg', fps=20)
ani_eps.save('epsilon_evolution.mp4', writer='ffmpeg', fps=20)