In [30]:
import torch
import torch.nn as nn
from torch.func import functional_call, vmap, vjp, jvp, jacrev
# Prefer the GPU when one is present, but fall back to the CPU so the notebook
# still runs top-to-bottom on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

class CNN(nn.Module):
    """Small convolutional classifier: three 3x3 convolutions and one linear head.

    The first two convolutions are each followed by a ReLU; the third feeds
    straight into the flatten + linear layer, which produces 10 outputs per
    sample.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3)
        self.conv3 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3)
        # 32 * 26 * 26 == 21632: each unpadded 3x3 convolution trims two pixels
        # per spatial dimension, so a 32x32 input reaches the head as 26x26
        # with 32 channels.
        self.fc = nn.Linear(in_features=32 * 26 * 26, out_features=10)

    def forward(self, x):
        """Map a batch shaped (N, 3, H, W) to (N, 10); H = W = 32 fits the linear head."""
        hidden = torch.relu(self.conv1(x))
        hidden = torch.relu(self.conv2(hidden))
        # Keep the batch axis and collapse everything after it.
        features = torch.flatten(self.conv3(hidden), start_dim=1)
        return self.fc(features)

In [31]:
# Random stand-ins for image batches: 20 "train" and 5 "test" samples, each
# with 3 channels and 32x32 spatial size, created directly on `device`.
x_train = torch.randn((20, 3, 32, 32), device=device)
x_test = torch.randn((5, 3, 32, 32), device=device)

In [32]:
# Build the network on the target device.
net = CNN().to(device)
# Snapshot the parameters as a name -> tensor dict for `functional_call`.
# They are detached because nothing here calls `Tensor.backward()`: leaving
# `requires_grad=True` would make every Jacobian computed through torch.func
# also record an (unused) regular autograd graph, wasting memory.
params = {name: tensor.detach() for name, tensor in net.named_parameters()}

def fnet_single(params, x):
    """Evaluate the module-level `net` on one unbatched sample.

    Args:
        params: mapping of parameter name -> tensor, used in place of the
            parameters stored on `net` for this call.
        x: a single input without a batch axis.

    Returns:
        The output of `net` for `x`, with the size-one batch axis removed.
    """
    # `net` expects a batch, so wrap the sample in a batch of one ...
    batch_of_one = x.unsqueeze(0)
    output = functional_call(net, params, (batch_of_one,))
    # ... and strip that axis again so callers see a per-sample output.
    return output.squeeze(0)

In [33]:
def empirical_ntk_jacobian_contraction(fnet_single, params, x1, x2):
    """Empirical NTK between two batches, via explicit Jacobian contraction.

    Args:
        fnet_single: callable ``(params, x) -> output`` for a single sample.
            The flattening below assumes the output is a 1-D tensor.
        params: dict of name -> tensor; Jacobians are taken with respect to
            these (the first argument of `fnet_single`).
        x1: first batch of inputs, samples along dim 0.
        x2: second batch of inputs, samples along dim 0.

    Returns:
        Tensor indexed ``[N, M, a, b]``: for sample N of `x1` and sample M of
        `x2`, the product J(x1[N]) @ J(x2[M]).T accumulated over all
        parameter tensors, where a and b index the outputs.
    """
    # Per-sample Jacobian w.r.t. params: share `params`, map over dim 0 of the batch.
    per_sample_jacobian = vmap(jacrev(fnet_single), (None, 0))

    def flat_jacobians(batch):
        # One (batch, out, n_param_elements) tensor per parameter, in dict order.
        jac = per_sample_jacobian(params, batch)
        return [j.flatten(2) for j in jac.values()]

    # Contract over the flattened parameter axis, one parameter tensor at a time.
    blocks = [
        torch.einsum('Naf,Mbf->NMab', left, right)
        for left, right in zip(flat_jacobians(x1), flat_jacobians(x2))
    ]
    # Sum the per-parameter contributions to get the full kernel.
    return torch.stack(blocks).sum(0)

In [34]:
# NTK of the training batch with itself; the printed shape is
# (len(x_train), len(x_train), n_outputs, n_outputs).
result = empirical_ntk_jacobian_contraction(
    fnet_single,
    params,
    x1=x_train,
    x2=x_train,
)
print(result.shape)

torch.Size([20, 20, 10, 10])


In [38]:
# Show the kernel block for output index 0 against output index 0, across all
# sample pairs (both batches are x_train here).
print(result[:, :, 0, 0])

tensor([[179.8362,  72.3862,  70.1874,  74.3949,  78.1520,  74.7145,  70.2342,
          75.4002,  80.2883,  73.1794,  75.9609,  75.0629,  71.5015,  70.7151,
          76.9346,  76.7645,  76.9869,  74.9296,  72.3473,  76.0309],
        [ 72.3862, 167.8784,  73.1395,  73.4909,  77.0863,  76.3766,  72.3521,
          73.5727,  75.2864,  76.4589,  72.2477,  77.4665,  69.6422,  70.1698,
          74.4881,  71.4884,  71.9003,  74.5763,  74.6571,  72.6125],
        [ 70.1874,  73.1395, 168.1223,  71.7968,  72.9266,  72.2910,  72.7452,
          73.5812,  74.1065,  72.1098,  71.9331,  73.6503,  73.7546,  72.1332,
          75.8976,  70.5649,  70.9312,  71.9232,  70.1148,  73.4586],
        [ 74.3949,  73.4909,  71.7968, 175.6412,  72.5392,  73.9358,  72.4234,
          71.8310,  75.7885,  73.3708,  73.5974,  78.3094,  70.6159,  74.5057,
          72.3343,  75.8209,  74.1976,  74.9399,  73.1699,  72.0072],
        [ 78.1520,  77.0863,  72.9266,  72.5392, 179.3170,  75.4155,  72.6700,
         