<a href="https://colab.research.google.com/github/FabianoVeglianti/temporaryCodeRepository/blob/main/kernel_fn_stacking.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import matplotlib.pyplot as plt
#import seaborn as sns
from scipy.stats import norm
import numpy as np
from scipy import stats

In [None]:
!pip install -q --upgrade pip
!pip install -q --upgrade jax[cuda11_cudnn82] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
!pip install -q git+https://www.github.com/google/neural-tangents

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.1/2.1 MB[0m [31m24.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m149.6/149.6 MB[0m [31m6.8 MB/s[0m eta [36m0:00:00[0m
[0m  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m86.2/86.2 kB[0m [31m3.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m17.3/17.3 MB[0m [31m80.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m17.1/17.1 MB[0m [31m68.7 MB/s[0m eta [36m0:00:00[0m
[?25h  Building wheel for neural-tangents (setup.py) ... [?25l[?25hdone
[0m

In [None]:
import jax.numpy as jnp
import math
from jax import random
from jax.example_libraries import optimizers
from jax import jit, grad, vmap
import itertools

import functools

import neural_tangents as nt
from neural_tangents import stax

In [None]:
# Initialization of PRNG's keys
seed=1
rootKey = random.PRNGKey(seed)
keys = random.split(rootKey, 3)
rootKey = keys[0]
initKey = keys[1]
dataKey = keys[2]

In [None]:
def mlp_neural_tangents(
    Cw_sqrt,
    Cb_sqrt,
    n_list,
    L,
    parameterization
):
    """Construct fully connected NN model and infinite width NTK & NNGP kernel
       function.

    Args:
        Cw_sqrt è la std dei pesi W
        Cb_sqrt è la std dei bias b
        L è il numero di hidden layers
        n_list è una lista di L interi. n_list[i] è il numero di neuroni allo strato (i+1)-esimo - i neuroni allo strato 0 non vanno specificati
        parameterization (string): Parameterization string, 'ntk' or 'standard'.
        Dettagli riguardo la parametrizzazione in: https://neural-tangents.readthedocs.io/en/latest/_autosummary/neural_tangents.stax.Dense.html#neural_tangents.stax.Dense

    Returns:
        `(init_fn, apply_fn, kernel_fn)`
    """
    print(n_list)
    layers_list = [stax.Dense(n_list[0], Cw_sqrt, Cb_sqrt, parameterization=parameterization)]

    def layer_block(width):
        return stax.serial(stax.Relu(), stax.Dense(width, Cw_sqrt, Cb_sqrt, parameterization=parameterization))
 
    for i in range(1, L+1):
        layers_list += [stax.Relu(), stax.Dense(n_list[i], Cw_sqrt, Cb_sqrt, parameterization=parameterization)]

    # for i in range(L-1):
    #     print(i)
    #     print(n_list)
    #     layers_list += [layer_block(n_list[i+1])]

    # layers_list += [stax.Relu(), stax.Dense(n_list[L], Cw_sqrt, Cb_sqrt, parameterization=parameterization)]

    # print (f"---- layer list is {layers_list} ------")

    init_fn, apply_fn, kernel_fn = stax.serial(*layers_list)

    apply_fn = jit(apply_fn)

    return init_fn, apply_fn, kernel_fn, layers_list  #apply_fn is f_hat

In [None]:
keys = random.split(rootKey, 2)
rootKey = keys[0]
[x1, x2] = random.uniform(keys[1], minval=-1, maxval = 1, shape=(2,16), dtype=jnp.float64)

In [None]:
def get_nn_layers_functions(layer_list):
  layers_functions = {}
  t = list(zip(*layer_list))
  layers_functions["init_fns"] = t[0]
  layers_functions["apply_fns"] = t[1]
  layers_functions["kernel_fns"] = t[2]
  return layers_functions

In [None]:
def split_keys(rootKey, n):
  if(n < 1 or n%1!=0):
    print("n deve essere >= 1 e intero")
  keys = random.split(rootKey, n+1)
  return keys[0], keys[1:]

In [None]:
def reinitialize_MLP_layer_l(rootKey, l, layers_functions, params_dict, x_alpha, width):
  rootKey, key = split_keys(rootKey,2)
  shape = 0
  if l == 0:
    shape = x_alpha.shape
  else:
    shape = width
  _, params_layerl = layers_functions["init_fns"][2*l+0](key[0], shape)
  _, params_layerl_plus_1 = layers_functions["init_fns"][2*l+1](key[1], shape)
  params_dict[2*l] = params_layerl
  params_dict[2*l+1] = params_layerl_plus_1

  return params_layers

In [None]:
def create_k_0_list(layers_functions, x_alpha):
  k_0_list = []
  for l in range(len(layers_functions)):
    if l == 0:
      k_0_list.append((layers_functions["kernel_fns"][l](x_alpha.reshape(1,16), x_alpha.reshape(1,16)).nngp)[0])
    else:
      k_0_list

In [24]:
### Test Cell
L = 3
rootKey, key = split_keys(rootKey,3)
n_list = [8, 4, 2, 1]
init_fn, apply_fn, kernel_fn, layer_list = mlp_neural_tangents(Cw_sqrt = 1, Cb_sqrt = 1, n_list = n_list, L = L, parameterization = "standard")
layers_functions = get_nn_layers_functions(layer_list)
print((kernel_fn(x1.reshape((1,16)), x1.reshape((1,16))).nngp)[0])
k_0_list = []
kernel_0_ouptut = layers_functions["kernel_fns"][0](x1.reshape(1,16), x1.reshape(1,16)).nngp
print(kernel_0_ouptut[0])

kernel_1_ouptut = layers_functions["kernel_fns"][1](kernel_0_ouptut, kernel_0_ouptut).nngp
print(kernel_1_ouptut[0])




[8, 4, 2, 1]
[1.9153962]
[1.3231703]


ValueError: ignored