# Principal kernel recursion

We want to explicitly check the validity of Eq. 5.4 of arXiv:2106.10165. To this end we write down a NN with 5 hidden layers, each with nL neurons, with user-specified activation function. Then we extract the value of $K_{00}$ at the 3-rd and 4-th layer and check if they are indeed related by Eq. 5.4 (within error-bars)

In [195]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import numpy as np
import matplotlib.pyplot as plt
from scipy import stats
from scipy.optimize import curve_fit

Hyperparameters of the network. n_inputs is the number of inputs, not to be confused with n_samples, which is the number of times we initialize the network, i.e. the dimension of the ensemble of networks we use to measure statistics (like means and variances)

In [196]:
nL = 200
cw = 10
cb = 0.
n_inputs = 100
n_samples = 100
in_size = 200

Network class. Here we have 1 input layer with "in_size" inputs and "nL" outputs, 5 hidden layers with "nL" inputs and "nL" outputs, and one output layer with "nL" inputs and 1 output. We pick the preactivations z3 and z4 from the 3rd and 4th hidden layers

In [197]:
class Simple_NN(nn.Module):
    def __init__(self):
        super(Simple_NN, self).__init__()
        self.firstlayer = nn.Linear(in_features=in_size, out_features=nL)
        self.layer1 = nn.Linear(nL, nL)
        self.layer2 = nn.Linear(nL, nL)
        self.layer3 = nn.Linear(nL, nL)
        self.layer4 = nn.Linear(nL, nL)
        self.layer5 = nn.Linear(nL, nL)
        self.finallayer = nn.Linear(in_features=nL, out_features=1)
        
        # Initialize weights from a Gaussian distribution
        init.normal_(self.firstlayer.weight, mean=0.0, std=np.sqrt(cw/n_inputs))
        init.normal_(self.layer1.weight, mean=0.0, std=np.sqrt(cw/nL))
        init.normal_(self.layer2.weight, mean=0.0, std=np.sqrt(cw/nL))
        init.normal_(self.layer3.weight, mean=0.0, std=np.sqrt(cw/nL))
        init.normal_(self.layer4.weight, mean=0.0, std=np.sqrt(cw/nL))
        init.normal_(self.layer5.weight, mean=0.0, std=np.sqrt(cw/nL))
        init.normal_(self.finallayer.weight, mean=0.0, std=np.sqrt(cw/nL))
        
        # Initialize biases from a Gaussian distribution
        init.normal_(self.firstlayer.bias, mean=0.0, std=cb)
        init.normal_(self.layer1.bias, mean=0.0, std=cb)
        init.normal_(self.layer2.bias, mean=0.0, std=cb)
        init.normal_(self.layer3.bias, mean=0.0, std=cb)
        init.normal_(self.layer4.bias, mean=0.0, std=cb)
        init.normal_(self.layer5.bias, mean=0.0, std=cb)
        init.normal_(self.finallayer.bias, mean=0.0, std=cb)
        
    def forward(self, x):
        z0 = x
        z1 = self.firstlayer(x)
        x = F.relu(z1)
        z2 = self.layer1(x)
        x = F.relu(z2)
        z3 = self.layer2(x)
        x = F.relu(z3)
        z4 = self.layer3(x)
        x = F.relu(z4)
        z5 = self.layer4(x)
        x = F.relu(z5)
        z6 = self.layer5(x)
        x = F.relu(z6)
        x = self.finallayer(x) #the final layer has no Activation function applied to it
        return x, z3, z4

In [198]:
x = torch.randn(n_inputs,in_size)

In [199]:
def g(k):
    x = torch.linspace(-100,100,1000)
    norm = torch.sqrt(torch.tensor(2*np.pi*k))
    return ((sum((F.relu(x)**2)*np.exp(-x**2/(2*k)))/norm)/1000)*200

Since nL is pretty large, as a first approximation we take the kernel K, i.e. the first order approximation of the metric, to be the full metric. This is correct up to O(1/nL) corrections

Without loss of generality we consider the 0-th input, compute the average of the kernel at the 3rd and 4th layer over n_samples initializations of the network, and check if it they are related by eq 5.4

In [200]:
k3vec = []
k4vec = []
for _ in range(n_samples):
    model = Simple_NN()
    xout, z3, z4 = model(x)
    k3vec.append(sum((z3[0].detach())**2)/nL)
    k4vec.append(sum((z4[0].detach())**2)/nL)

k3 = np.mean(k3vec)
k4 = np.mean(k4vec)
k4exp = cb + cw * g(k3)
print(k4/k4exp)

tensor(0.9603, dtype=torch.float64)
