# NTK Derivation and Analysis

In this notebook, we derive a closed form for the NTK of 1 hidden layer ReLU networks.  We then present experiments showing that the NTK describes the behavior of large-width neural networks.  

We begin with a derivation of the NTK below (this is essentially the solution to Section 2 Problem 2 of the worksheet shared on GitHub).  

## Derivation of the NTK 

Suppose we are given a dataset $\{(x^{(i)}, y^{(i)})\}_{i=1}^{n} \subset \mathbb{R}^{d} \times \mathbb{R}$ (also written as $X \in \mathbb{R}^{d \times n}, y \in \mathbb{R}^{1 \times n}$).  Let $f$ denote a 1 hidden layer neural network with parameters $\mathbf{W}$.  To train the neural network $f$ to fit the data $(X, y)$, we typically use gradient descent to minimize the following loss: 

\begin{align*}
\mathcal{L}(\mathbf{W}) = \sum_{i=1}^{n} (y^{(i)} - f(\mathbf{W} ; x^{(i)}))^2 
\end{align*}

**Important:** Note that the network $f$ is written as a function of parameters $\mathbf{W}$ and data $x^{(i)}$, as opposed to just data.  For the neural tangent kernel derivation, we consider the cross section of $f$ given by fixing the data component and writing the neural network as a function of parameters, i.e. consider $f_x(\mathbf{W}): \mathbb{R}^{dk + k} \to \mathbb{R}$.  

### Linearization around Initialization
Before training the network as usual, let us consider the following alternative.  Viewing the neural network as only a function of parameters, we train the linear approximation for $f_x(\mathbf{W})$, which is given as follows: 

\begin{align*}
\tilde{f_x}(\mathbf{W}) = f_x(\mathbf{W}^{(0)}) + \nabla f_x(\mathbf{W}^{(0)})^T (\mathbf{W} - \mathbf{W}^{(0)}) ~;
\end{align*}
where $\mathbf{W}^{(0)} \in \mathbb{R}^{dk + k}$ denotes the parameters at initialization and $\nabla f_x(\mathbf{W}^{(0)})^T \in \mathbb{R}^{1 \times (dk + k)}$ denotes the transposed gradient of $f_x$ evaluated at $\mathbf{W}^{(0)}$.  Rather than minimizing the loss for the model $f(\mathbf{W} ; x^{(i)})$ given above, we minimize the following loss: 

\begin{align*}
    \tilde{\mathcal{L}}(\mathbf{W}) = \sum_{i=1}^{n} (y^{(i)} - \tilde{f}_{x^{(i)}}(\mathbf{W}))^2 = \sum_{i=1}^{n} (y^{(i)} - f_{x^{(i)}}(\mathbf{W}^{(0)}) - \nabla f_{x^{(i)}}(\mathbf{W}^{(0)})^T (\mathbf{W} - \mathbf{W}^{(0)}))^2
\end{align*}
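
As a quick numerical illustration of the linearization, the sketch below compares a toy network with its first-order expansion around $\mathbf{W}^{(0)}$. It assumes the 1 hidden layer ReLU architecture introduced later in this notebook with $c = 2$; the helper names `f`, `grad_f`, and `f_lin` and the sizes `k`, `d` are illustrative choices, not part of the original notebook.

```python
import numpy as np

def f(W, x, k, d):
    """Toy 1 hidden layer ReLU network; W stacks a (k entries), then B (k*d entries)."""
    a, B = W[:k], W[k:].reshape(k, d)
    return np.sqrt(2.0 / k) * a @ np.maximum(B @ x, 0.0)

def grad_f(W, x, k, d):
    """Gradient of f with respect to W, written out by hand for ReLU."""
    a, B = W[:k], W[k:].reshape(k, d)
    pre = B @ x
    scale = np.sqrt(2.0 / k)
    grad_a = scale * np.maximum(pre, 0.0)
    grad_B = scale * (a * (pre > 0))[:, None] * x[None, :]
    return np.concatenate([grad_a, grad_B.ravel()])

def f_lin(W, W0, x, k, d):
    """First-order Taylor expansion of f around W0."""
    return f(W0, x, k, d) + grad_f(W0, x, k, d) @ (W - W0)

rng = np.random.default_rng(0)
k, d = 512, 5
x = rng.standard_normal(d)
x /= np.linalg.norm(x)
W0 = rng.standard_normal(k + k * d)
direction = rng.standard_normal(k + k * d)
direction /= np.linalg.norm(direction)

# The gap between the network and its linearization shrinks as W approaches W0.
gaps = []
for step in [1e-1, 1e-2, 1e-3]:
    W = W0 + step * direction
    gap = abs(f(W, x, k, d) - f_lin(W, W0, x, k, d))
    gaps.append(gap)
    print(f"step={step:g}  |f - f_lin|={gap:.3e}")
```

The linearized model is exact at $\mathbf{W}^{(0)}$ and only approximate elsewhere, which is why the rest of the derivation works with $\tilde{f}_x$ rather than $f_x$.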

Minimizing this loss naively can be computationally expensive since the vector $\mathbf{W} \in \mathbb{R}^{kd + k}$  depends on $k$, which can be arbitrarily large.  To remedy this, we let $\mathbf{W} = \mathbf{W}^{(0)} + \sum_{i=1}^{n} \nabla f_{x^{(i)}}(\mathbf{W}^{(0)})\alpha_i$. 


**Remark:** At this point, you should be asking why this is a reasonable step to take.  The rationale is that this substitution lets us find the minimum norm minimizer, which lies in the span of the gradients $\nabla f_{x^{(i)}}(\mathbf{W}^{(0)})$ at the training points (these play the role of the training features).  If you haven't seen this trick before, I encourage you to review the Representer theorem.  
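
The claim in the remark can be checked numerically. In the sketch below, the rows of a random matrix `G` stand in for the gradients $\nabla f_{x^{(i)}}(\mathbf{W}^{(0)})^T$ and `r` stands in for the residuals $y^{(i)} - f_{x^{(i)}}(\mathbf{W}^{(0)})$; both are synthetic placeholders rather than quantities from a real network.

```python
import numpy as np

rng = np.random.default_rng(0)
n, p = 5, 50                       # n samples, p parameters (p > n)
G = rng.standard_normal((n, p))    # row i plays the role of grad f_{x^(i)}(W0)^T
r = rng.standard_normal(n)         # plays the role of y - f_X(W0)

# Minimum norm solution of G @ delta = r via the pseudoinverse.
delta_min_norm = np.linalg.pinv(G) @ r

# Solution restricted to the span of the gradients: delta = G^T alpha.
alpha = np.linalg.solve(G @ G.T, r)
delta_span = G.T @ alpha

# Another interpolating solution: add a component from the null space of G.
z = rng.standard_normal(p)
null_component = z - np.linalg.pinv(G) @ (G @ z)
delta_other = delta_min_norm + null_component

print(np.allclose(delta_min_norm, delta_span))   # the two constructions agree
print(np.allclose(G @ delta_other, r))           # delta_other also interpolates
print(np.linalg.norm(delta_min_norm), np.linalg.norm(delta_other))
```

Any interpolating solution differs from the minimum norm one by a null space component, which can only increase the norm.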

Using the new form for $\mathbf{W}$, we can simplify our loss $\tilde{\mathcal{L}}(\mathbf{W})$ as follows: 
\begin{align*}
\tilde{\mathcal{L}}(\mathbf{W}) = \sum_{i=1}^{n} (y^{(i)} - f_{x^{(i)}}(\mathbf{W}^{(0)}) - \alpha k(x^{(i)}) )^2 ~;
\end{align*}
where $\alpha \in \mathbb{R}^{1 \times n}$ and $$k(x) = \begin{bmatrix} \langle \nabla f_{x}(\mathbf{W}^{(0)}),  \nabla f_{x^{(1)}}(\mathbf{W}^{(0)}) \rangle \\ \langle \nabla f_{x}(\mathbf{W}^{(0)}),  \nabla f_{x^{(2)}}(\mathbf{W}^{(0)}) \rangle \\ \vdots \\ \langle \nabla f_{x}(\mathbf{W}^{(0)}),  \nabla f_{x^{(n)}}(\mathbf{W}^{(0)}) \rangle  \end{bmatrix} \in \mathbb{R}^{n}$$

We can now recognize minimizing the loss $\tilde{\mathcal{L}}(\mathbf{W})$ as solving the following system of equations: 
\begin{align*}
 \alpha K = y - f_X(\mathbf{W}^{(0)}) ~;
\end{align*}
where $K \in \mathbb{R}^{n \times n}$ with $K_{i,j} = \langle \nabla f_{x^{(i)}}(\mathbf{W}^{(0)}),  \nabla f_{x^{(j)}}(\mathbf{W}^{(0)}) \rangle$ and $f_X(\mathbf{W}^{(0)}) \in \mathbb{R}^{1 \times n}$ with $f_X(\mathbf{W}^{(0)})_i = f_{x^{(i)}}(\mathbf{W}^{(0)})$.  
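
A minimal sketch of this kernel regression step is given below. It assumes synthetic stand-ins: the rows of `G` play the role of the gradients at the training points, and `f0_train` plays the role of $f_X(\mathbf{W}^{(0)})$. The helper `predict` is a hypothetical name used only for illustration.

```python
import numpy as np

rng = np.random.default_rng(1)
n, p = 8, 200
G = rng.standard_normal((n, p)) / np.sqrt(p)  # row i stands in for grad f_{x^(i)}(W0)^T
f0_train = rng.standard_normal((1, n))        # stands in for f_X(W0)
y_train = rng.standard_normal((1, n))         # synthetic labels

K = G @ G.T                                   # K[i, j] = <grad_i, grad_j>
residual = y_train - f0_train
# K is symmetric, so alpha K = residual is equivalent to K alpha^T = residual^T.
alpha = np.linalg.solve(K, residual.T).T      # shape (1, n)

def predict(grad_x, f0_x):
    """Linearized prediction f_x(W0) + alpha k(x) for a point with gradient grad_x."""
    k_x = G @ grad_x                          # k(x) in R^n
    return f0_x + (alpha @ k_x).item()

# On the training points the linearized model interpolates the labels.
train_preds = np.array([predict(G[i], f0_train[0, i]) for i in range(n)])
print(np.max(np.abs(train_preds - y_train.ravel())))
```

Only the $n \times n$ matrix $K$ is ever inverted, so the cost of the solve does not grow with the number of parameters.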

**Definition [NTK]:** The entries $K_{i,j}$ above are evaluations of the following function, called the Neural Tangent Kernel:
$$ K(x, x') = \langle \nabla f_{x}(\mathbf{W}^{(0)}), \nabla f_{x'}(\mathbf{W}^{(0)}) \rangle. $$

**Remarks:** This kernel can of course be evaluated using any auto-differentiation software (e.g. PyTorch, TensorFlow, JAX, etc.).  Doing so is generally expensive in both memory and runtime, since neural networks can have millions or billions of parameters.  On the other hand, we can compute the kernel $K$ analytically when the width of the neural network approaches infinity.  We do this below. 


### Analytical Evaluation of the NTK (1 Hidden Layer)
Thus far, we have defined the NTK without explicitly computing it for a given architecture.  We now write a closed form for the NTK given a specific architecture.  In particular, let $f$ denote a 1 hidden layer network defined as follows: 
\begin{align*}
    f(\mathbf{W} ; x) = a \frac{\sqrt{c}}{\sqrt{k}} \phi(Bx) ~;
\end{align*}
where $a \in \mathbb{R}^{1 \times k}, B \in \mathbb{R}^{k \times d}$ are the trainable parameters ($\mathbf{W} = [a_1, a_2, \ldots a_k, B_{1,1}, B_{1,2}, \ldots B_{k, d}]^T \in \mathbb{R}^{k + dk}$ denotes the vector containing all trainable parameters), $c \in \mathbb{R}$ is an absolute constant, and $\phi: \mathbb{R} \to \mathbb{R}$ is an elementwise nonlinearity.  

Let us now compute the NTK $K(x, x') = \langle \nabla f_{x}(\mathbf{W}^{(0)}), \nabla f_{x'}(\mathbf{W}^{(0)}) \rangle$ as $k \to \infty$ assuming that $\mathbf{W}_j^{(0)} \overset{i.i.d.}{\sim} \mathcal{N}(0, 1)$.  Letting $\mathbf{W} = [a_1, a_2, \ldots a_k, B_{1,1}, B_{1,2}, \ldots B_{k, d}] $, we compute $\nabla f_{x}(\mathbf{W}^{(0)})$ as follows: 

\begin{align*}
    \nabla f_{x}(\mathbf{W}) = \begin{bmatrix}\frac{\partial f_{x}}{\partial a_1}  \\ \frac{\partial f_{x}}{\partial a_2} \\ \vdots \\ \frac{\partial f_{x}}{\partial a_k} \\ \frac{\partial f_{x}}{\partial B_{1,1}} \\ \vdots \\ \frac{\partial f_{x}}{\partial B_{k, d}}
   \end{bmatrix}
\end{align*}

We thus first calculate $\frac{\partial f_{x}}{\partial a_j}$ and $\frac{\partial f_{x}}{\partial B_{j, \ell}}$: 
\begin{align*}
    \frac{\partial f_{x}}{\partial a_j} = \frac{\sqrt{c}}{\sqrt{k}} \phi(B_{j, :}  x) \\
    \frac{\partial f_{x}}{\partial B_{j, \ell}} = a_j \frac{\sqrt{c}}{\sqrt{k}} \phi'(B_{j,:}x) x_{\ell}    
\end{align*}
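
These two partial derivatives can be sanity checked with central finite differences. The sketch below assumes $\phi = \tanh$ (a smooth activation chosen only so that finite differences behave well) and $c = 2$; the function names and the probed indices `j`, `l` are illustrative.

```python
import numpy as np

def f(a, B, x, c=2.0):
    """f(W; x) = a * sqrt(c / k) * phi(B x) with phi = tanh (assumed for this check)."""
    k = a.shape[0]
    return np.sqrt(c / k) * a @ np.tanh(B @ x)

def analytic_partials(a, B, x, c=2.0):
    """Partial derivatives from the formulas above, specialized to phi = tanh."""
    k = a.shape[0]
    pre = B @ x
    scale = np.sqrt(c / k)
    df_da = scale * np.tanh(pre)                       # d f / d a_j
    dphi = 1.0 - np.tanh(pre) ** 2                     # phi'(B_{j,:} x)
    df_dB = scale * (a * dphi)[:, None] * x[None, :]   # d f / d B_{j,l}
    return df_da, df_dB

rng = np.random.default_rng(0)
k, d = 6, 4
a = rng.standard_normal(k)
B = rng.standard_normal((k, d))
x = rng.standard_normal(d)
df_da, df_dB = analytic_partials(a, B, x)

# Perturb a single coordinate of a and of B, then compare with the formulas.
h, j, l = 1e-6, 2, 1
e_a = np.zeros(k)
e_a[j] = h
e_B = np.zeros((k, d))
e_B[j, l] = h
fd_a = (f(a + e_a, B, x) - f(a - e_a, B, x)) / (2 * h)
fd_B = (f(a, B + e_B, x) - f(a, B - e_B, x)) / (2 * h)
print(abs(fd_a - df_da[j]), abs(fd_B - df_dB[j, l]))
```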

Now that we have all the relevant terms to compute $\nabla f_x(\mathbf{W}^{(0)})$, we can compute $K(x, x')$ as follows: 
\begin{align*}
    K(x, x') &= \langle \nabla f_{x}(\mathbf{W}^{(0)}), \nabla f_{x'}(\mathbf{W}^{(0)}) \rangle \\
    &= \sum_{j=1}^{k} \frac{\partial f_x(\mathbf{W}^{(0)})}{\partial a_j} \frac{\partial f_{x'}(\mathbf{W}^{(0)})}{\partial a_j} + \sum_{j=1}^{k} \sum_{\ell = 1}^{d} \frac{\partial f_x(\mathbf{W}^{(0)})}{\partial B_{j, \ell}} \frac{\partial f_{x'}(\mathbf{W}^{(0)})}{\partial B_{j, \ell}} \\
    &= \color{red}{\text{$\frac{c}{k} \sum_{j=1}^{k}  \phi(B_{j, :}  x) \phi(B_{j, :}  x')$}} ~ + ~ \color{blue}{\text{$\frac{c}{k} \sum_{j=1}^{k} \sum_{\ell=1}^{d} a_j^2  \phi'(B_{j, :}  x) \phi'(B_{j, :}  x') x_{\ell} x'_{\ell}$}}  \\
    &= \color{red}{\text{$\frac{c}{k} \sum_{j=1}^{k}  \phi(B_{j, :}  x) \phi(B_{j, :}  x')$}} ~ + ~ \color{blue}{\text{$\frac{c}{k} \sum_{j=1}^{k}  a_j^2 \phi'(B_{j, :}  x) \phi'(B_{j, :}  x')  \sum_{\ell=1}^{d} x_{\ell} x'_{\ell}$}}  \\
    &= \color{red}{\text{$\frac{c}{k} \sum_{j=1}^{k}  \phi(B_{j, :}  x) \phi(B_{j, :}  x')$}} ~ + ~ \langle x, x' \rangle \color{blue}{\text{$\frac{c}{k} \sum_{j=1}^{k}  a_j^2 \phi'(B_{j, :}  x) \phi'(B_{j, :}  x') $}} 
\end{align*}

**Remark:** Do the red and blue terms look familiar? If you worked through the notebook *DoubleDescentTutorial*, they should.  Indeed, as $k \to \infty$, the red and blue terms correspond to the NNGP kernel for a network with activation $\phi$ and $\phi'$ respectively (for the blue term, the factor $a_j^2$ averages out because $a_j$ is independent of $B$ and $\mathbb{E}[a_j^2] = 1$).  We know how to evaluate these using dual activations.  Namely, we have: 
\begin{align*}
    \color{red}{\text{$\frac{c}{k} \sum_{j=1}^{k}  \phi(B_{j, :}  x) \phi(B_{j, :}  x')$}} &\to c \mathbb{E}_{(u, v) \sim \mathcal{N}(\mathbf{0}, \Lambda)} [\phi(u) \phi(v) ] \\
    \color{blue}{\text{$\frac{c}{k} \sum_{j=1}^{k}  a_j^2 \phi'(B_{j, :}  x) \phi'(B_{j, :}  x') $}} &\to c \mathbb{E}_{(u, v) \sim \mathcal{N}(\mathbf{0}, \Lambda)} [\phi'(u) \phi'(v)] \\
    \Lambda &= \begin{bmatrix} \|x\|_2^2  & x^T x' \\ x^T x' & \|x'\|_2^2 \end{bmatrix}
\end{align*}
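
Two facts used in these limits can be checked by simulation: the pair $(B_{j,:} x, B_{j,:} x')$ has covariance $\Lambda$, and the factor $a_j^2$ does not change the limit of the blue term. The sketch below assumes ReLU (so $\phi'$ is a step function) and uses illustrative inputs `x` and `x_prime`.

```python
import numpy as np

def relu_grad(t):
    """Derivative of ReLU, taken to be 0 at t = 0."""
    return (t > 0).astype(float)

rng = np.random.default_rng(0)
k, d = 200_000, 3
x = np.array([1.0, 0.0, 0.0])
x_prime = np.array([0.6, 0.8, 0.0])

B = rng.standard_normal((k, d))   # first-layer weights, i.i.d. N(0, 1)
a = rng.standard_normal(k)        # second-layer weights, i.i.d. N(0, 1)
u, v = B @ x, B @ x_prime         # pre-activations for the two inputs

Lambda = np.array([[x @ x, x @ x_prime],
                   [x @ x_prime, x_prime @ x_prime]])
Lambda_hat = np.cov(np.stack([u, v]))   # empirical covariance of (u_j, v_j)

# Blue term (without the constant c), with and without the a_j^2 factor.
with_a = np.mean(a ** 2 * relu_grad(u) * relu_grad(v))
without_a = np.mean(relu_grad(u) * relu_grad(v))

print(Lambda_hat)
print(with_a, without_a)
```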

Let $\xi = x^T x'$ and let $\check{\phi}(\xi) = \mathbb{E}_{(u, v) \sim \mathcal{N}(\mathbf{0}, \Lambda)} [\phi(u) \phi(v)]$ denote the dual of $\phi$.  Assuming $\phi$ is homogeneous of degree 1 and that $\|x\|_2 = \|x'\|_2 = 1$, we conclude: 
\begin{align*}
    K(x, x') = c \left( \check{\phi}(\xi) + \xi \check{\phi'}(\xi) \right)
\end{align*}
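
For ReLU, the duals of $\phi$ and $\phi'$ are the well-known arc-cosine expressions, so the limit can be compared directly with a wide random network. The sketch below assumes $c = 2$, unit-norm inputs, no bias, and the dual taken to be the plain expectation $\mathbb{E}[\phi(u)\phi(v)]$; the function names are illustrative.

```python
import numpy as np

def relu_ntk_closed_form(xi):
    """c * (dual(phi)(xi) + xi * dual(phi')(xi)) for ReLU with c = 2 and unit-norm inputs."""
    xi = np.clip(xi, -1.0, 1.0)
    theta = np.arccos(xi)
    dual_phi = (xi * (np.pi - theta) + np.sqrt(1.0 - xi ** 2)) / (2 * np.pi)
    dual_dphi = (np.pi - theta) / (2 * np.pi)
    return 2.0 * (dual_phi + xi * dual_dphi)

def relu_ntk_finite_width(x, x_prime, k, rng):
    """Red + <x, x'> * blue for one random width-k network with c = 2."""
    a = rng.standard_normal(k)
    B = rng.standard_normal((k, x.shape[0]))
    u, v = B @ x, B @ x_prime
    red = (2.0 / k) * np.sum(np.maximum(u, 0.0) * np.maximum(v, 0.0))
    blue = (2.0 / k) * np.sum(a ** 2 * (u > 0) * (v > 0))
    return red + (x @ x_prime) * blue

rng = np.random.default_rng(0)
x = np.array([1.0, 0.0, 0.0])
x_prime = np.array([0.6, 0.8, 0.0])

exact = relu_ntk_closed_form(x @ x_prime)
approx = relu_ntk_finite_width(x, x_prime, k=200_000, rng=rng)
print(exact, approx)
```

The finite-width value fluctuates around the closed form, with fluctuations that shrink as the width `k` grows.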

Recalling that the dual activation is computed in closed form for a number of nonlinearities including ReLU, we now have a closed form for the NTK.  Next, let's try training some simple neural networks to verify that the NTK does describe the training dynamics of large neural networks. 

## Training Neural Nets vs. Using the NTK

In [170]:
# Loading high dimensional linear data
import dataloader as dl
import numpy as np
from numpy.linalg import norm
import matplotlib.pyplot as plt
%matplotlib inline

SEED = 2134

np.random.seed(SEED)
d = 100
n = 32
n_test = 100

X = np.random.randn(n, d)
X = X / norm(X, axis=-1).reshape(-1, 1)
X_test = np.random.randn(n_test, d)
X_test = X_test / norm(X_test, axis=-1).reshape(-1, 1)
w = np.random.randn(1, d)
y = (w @ X.T).T
y_test = (w @ X_test.T).T
print(X.shape, y.shape, X_test.shape, y_test.shape)

(32, 100) (32, 1) (100, 100) (100, 1)


In [171]:
## We now need to define and train a neural network to map x^{(i)} to y^{(i)}
import torch
import torch.nn as nn
import torch.nn.functional as F

# Abstraction for nonlinearity 
class Nonlinearity(torch.nn.Module):
    
    def __init__(self):
        super(Nonlinearity, self).__init__()

    def forward(self, x):
        # return F.leaky_relu(x)
        return F.relu(x)
    
class Net(nn.Module):

    def __init__(self, width, f_in):
        super(Net, self).__init__()

        self.k = width
        self.first = nn.Sequential(nn.Linear(f_in, self.k, bias=True), 
                                   Nonlinearity())
        self.sec = nn.Linear(self.k, 1, bias=False)

    def forward(self, x):
        #C = np.sqrt(2/(.01**2 + 1)) * 1/np.sqrt(self.k)
        C = np.sqrt(2/self.k)
        o = self.first(x) * C
        return self.sec(o)

In [185]:
import torch.optim as optim
from copy import deepcopy
from auto_tqdm import tqdm

torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)

eps = 1e-10  # Threshold for stopping training

# Moving the data to GPU
X_t = torch.from_numpy(X).cuda()
y_t = torch.from_numpy(y).cuda()

X_test_t = torch.from_numpy(X_test).cuda()
y_test_t = torch.from_numpy(y_test).cuda()

n, d = X.shape

widths = [16000]
test_errors = []
best_test_errors = []
networks = []

for width in widths:

    # Create our network
    net = Net(width, d)
    
    # Initialize the parameters i.i.d. from a standard normal 
    for idx, param in enumerate(net.parameters()):
        print(param.size())
        init = torch.Tensor(param.size()).normal_()
        param.data = init
    net.double()
    net_0 = deepcopy(net)
    net.cuda()

    # Training neural network with GD
    optimizer = optim.SGD(filter(lambda p: p.requires_grad, net.parameters()), lr=1e-2)

    epochs = 100000

    best_test_error = float("inf")
    for i in tqdm(range(epochs), total=epochs):
        net.zero_grad()
        preds = net(X_t)
        loss = torch.mean(torch.pow(preds - y_t, 2))
        loss.backward()
        optimizer.step()
        loss = loss.cpu().data.numpy()
        if loss < eps:  # stop once the training loss falls below the threshold
            break 
        if i % 1000 == 0:
            print("Epoch ", i, " Loss: ", loss)
        test_error = torch.mean(torch.pow(net(X_test_t) - y_test_t, 2)).cpu().data.numpy()
        best_test_error = min(best_test_error, test_error)

    print("Finished Width: ", width, "End Train Error: ", loss, "End Test Error: ", test_error, 
          "Best Test Error: ", best_test_error)
    best_test_errors.append(best_test_error)
    test_errors.append(test_error)
    networks.append(deepcopy(net.cpu()))
print(test_errors)

torch.Size([16000, 100])
torch.Size([16000])
torch.Size([1, 16000])


Epoch  0  Loss:  2.73050878549734
Epoch  1000  Loss:  0.09766805617471178
Epoch  2000  Loss:  0.012308375464602797
Epoch  3000  Loss:  0.0021917035476749203
Epoch  4000  Loss:  0.0004582773470279824
Epoch  5000  Loss:  0.00010388931426187457
Epoch  6000  Loss:  2.4736101571105667e-05
Epoch  7000  Loss:  6.091727017858584e-06
Epoch  8000  Loss:  1.5372611678632235e-06
Epoch  9000  Loss:  3.950678705827788e-07
Epoch  10000  Loss:  1.0296731246989766e-07
Epoch  11000  Loss:  2.7142207903909124e-08
Epoch  12000  Loss:  7.2215891427776104e-09
Epoch  13000  Loss:  1.9368270254195775e-09
Epoch  14000  Loss:  5.231585017011714e-10
Epoch  15000  Loss:  1.422313385276015e-10
Finished Width:  16000 End Train Error:  9.990888168618902e-11 End Test Error:  1.0644498390149257 Best Test Error:  1.0644498390149257
[array(1.06444984)]


In [186]:
## Compute Empirical NTK
from auto_tqdm import tqdm

preds = net_0(X_t.cpu())
n, d = X.shape

l = sum(p.numel() for p in net.parameters())

K = np.zeros((n, n))
all_grads = np.zeros((n, l))

for idx in tqdm(range(n), total=n):
    net_0.zero_grad()
    p = net_0(X_t.cpu())[idx]
    p.backward()
    grads = [q.grad.flatten() for q in net_0.parameters()]
    grads = torch.cat(grads)
    all_grads[idx,:] = deepcopy(grads.numpy())


for i in tqdm(range(n), total=n):
    for j in range(n):
        K[i, j] = np.sum(all_grads[i, :] * all_grads[j, :])
K_empirical = K


In [187]:
from numpy.linalg import pinv

K_inv = pinv(K)
sol = y.T @ K_inv

n_test, d = X_test.shape
K_test = np.zeros((n, n_test))
for j in tqdm(range(n_test), total=n_test):
    net_0.zero_grad()
    p = net_0(X_test_t.cpu())[j]
    p.backward()
    grads = [q.grad.flatten() for q in net_0.parameters()]
    grads = torch.cat(grads)
    for i, g in enumerate(all_grads):
        K_test[i, j] = np.sum(g * grads.numpy())

f_X = net_0(X_t.cpu()).data.numpy()
f_x = net_0(X_test_t.cpu()).data.numpy()

empirical_pred_test_corrected = (y.T - f_X.T) @ (K_inv @ K_test) + f_x.T
empirical_pred_test = y.T @ K_inv @ K_test

In [189]:
from numpy.linalg import solve

def mse(preds, labels): 
    return np.mean(np.abs(np.power(preds - labels, 2)))

# Closed-form infinite-width NTK for a 1 hidden layer ReLU network
# (the +1 terms account for the first-layer bias used in Net)
def ntk(pair1, pair2):

    out = pair1 @ pair2.transpose(1, 0) + 1
    N1 = np.sum(np.power(pair1, 2), axis=-1).reshape(-1, 1) + 1
    N2 = np.sum(np.power(pair2, 2), axis=-1).reshape(-1, 1) + 1

    XX = np.sqrt(N1 @ N2.transpose(1, 0))
    out = out / XX

    out = np.clip(out, a_min=-1, a_max=1)

    first = 1/np.pi * (out * (np.pi - np.arccos(out)) \
                           + np.sqrt(1. - np.power(out, 2))) * XX
    sec = 1/np.pi * out * (np.pi - np.arccos(out)) * XX
    out = first + sec
    return out 

# Build kernel matrix for train & test data
K_train = ntk(X, X) #+ np.eye(X.shape[0])*.0001
K_test = ntk(X, X_test)

# Solve kernel regression
K_inv = pinv(K_train) 
a_hat = y.T @ K_inv 

print("Theoretical NTK: \n", K_train)

print("Empirical NTK: \n", K_empirical)

theoretical_pred_test = a_hat @ K_test

# Get error on train & test data
train_error = mse(a_hat @ K_train, y.T)
test_error = mse(a_hat @ K_test, y_test.T)
print(train_error, test_error)

Theoretical NTK: 
 [[4.         1.79555843 1.64980051 ... 1.94070056 1.84797068 2.06689513]
 [1.79555843 4.         1.81279459 ... 1.90982142 2.0879102  1.97895885]
 [1.64980051 1.81279459 4.         ... 2.05568052 1.9516431  2.14718829]
 ...
 [1.94070056 1.90982142 2.05568052 ... 4.         1.87065453 2.24250602]
 [1.84797068 2.0879102  1.9516431  ... 1.87065453 4.         1.65249391]
 [2.06689513 1.97895885 2.14718829 ... 2.24250602 1.65249391 4.        ]]
Empirical NTK: 
 [[4.0561158  1.84060637 1.68579264 ... 1.98192256 1.8430595  2.13493269]
 [1.84060637 4.06273434 1.86152534 ... 1.94122551 2.10153732 2.05116363]
 [1.68579264 1.86152534 4.03538226 ... 2.10823912 1.9714356  2.21346735]
 ...
 [1.98192256 1.94122551 2.10823912 ... 4.07341524 1.86327991 2.31460805]
 [1.8430595  2.10153732 1.9714356  ... 1.86327991 3.97119779 1.67430362]
 [2.13493269 2.05116363 2.21346735 ... 2.31460805 1.67430362 4.12117518]]
1.2598217952903903e-29 0.48437464815921794


In [190]:
num = 5
print(theoretical_pred_test[0,:num])
print(empirical_pred_test[0, :num])
print(mse(empirical_pred_test, theoretical_pred_test))

[-0.1682423  -0.05881645  0.02397388  0.1342176  -0.34087719]
[-0.1874467  -0.06801235 -0.01552141  0.18810081 -0.3450094 ]
0.0005990180311231504


In [191]:
gt_preds = net(X_test_t.cpu()).data.numpy().T
f_X = net_0(X_t.cpu()).data.numpy()


def theoretical_pred(x):
    k_x = ntk(X, x)
    f_x = net_0(torch.from_numpy(x)).data.numpy().T
    return (y.T  - f_X.T) @ (K_inv @ k_x) + f_x

theoretical_pred_test_corrected = theoretical_pred(X_test)
num = 5

print("Infinite width: \n", theoretical_pred_test_corrected[0, :num])
print("Empirical NTK Corrected: \n", empirical_pred_test_corrected[0, :num])
print("Ground Truth: \n", gt_preds[0, :num])
print("Error between trained net and NTK: ", 
      mse(theoretical_pred_test_corrected, gt_preds))
print("Error between trained net and empirical NTK", 
      mse(empirical_pred_test_corrected,gt_preds))

Infinite width: 
 [-0.81182529  0.79999013  0.11211074 -0.46587294  0.49494158]
Empirical NTK Corrected: 
 [-0.80058685  0.79097463  0.03352433 -0.40223674  0.48549821]
Ground Truth: 
 [-0.80659439  0.78833474  0.03661104 -0.41004607  0.48631202]
Error between trained net and NTK:  0.0011547295683387625
Error between trained net and empirical NTK 1.4939568621191908e-05
