# This is resnet NTK trial

In [1]:
from backbone import ResNet101, ResNet34
import torch
import torch.nn as nn
from torch.func import functional_call, vmap, vjp, jvp, jacrev
import time
device = 'cuda'
if device=='cuda':
    torch.cuda.empty_cache()
    
    
class simple_netC_0hl(nn.Module):
    def __init__(self):
        super(simple_netC_0hl, self).__init__()
        self.layer1 = nn.Linear(512, 5)
        
    def forward(self, x):
        out = self.layer1(x)
        return out

print(type(simple_netC_0hl))
print(type(simple_netC_0hl()))

<class 'type'>
<class '__main__.simple_netC_0hl'>


In [10]:
import backbone

combined_Conv3 = backbone.CombinedNetwork(backbone.Conv3(), nn.Linear(1764,1))
print([p.numel() for p in combined_Conv3.parameters()])

x1 = torch.randn(5, 3, 84, 84)

def compute_jacobian(net, inputs):   # i is the class label, and corresponds to the output targeted
    """
    Return the jacobian of a batch of inputs, thanks to the vmap functionality
    """
    net.zero_grad()
    params = {k: v for k, v in net.named_parameters()}

    def fnet_single(params, x):
        # Make sure output has the right dimensions
        return functional_call(self.net, params, (x.unsqueeze(0),)).squeeze(0)

    jac = vmap(jacrev(fnet_single), (None, 0))(params, inputs)
    jac_values = jac.values()

    reshaped_tensors = []
    for j in jac_values:
        if len(j.shape) == 3:  # For layers with weights
            # Flatten parameters dimensions and then reshape
            flattened = j.flatten(start_dim=1)  # Flattens to [batch, params]
            reshaped = flattened.T  # Transpose to align dimensions as [params, batch]
            reshaped_tensors.append(reshaped)
        elif len(j.shape) == 2:  # For biases or single parameter components
            reshaped_tensors.append(j.T)  # Simply transpose

    # Concatenate all the reshaped tensors into one large matrix
    return torch.cat(reshaped_tensors, dim=0).T




[972, 36, 11664, 36, 11664, 36, 8820, 5]


In [5]:
resnet = ResNet34()

print(f"Number of params in network : {sum(p.numel() for p in resnet.parameters() if p.requires_grad_)}")
resnet.eval()
print(f"Number of params in network : {sum(p.numel() for p in resnet.parameters() if p.requires_grad_)}")
resnet.train()
print(f"Final dimension of the network : {resnet.final_feat_dim}")
# print(resnet.trunk)

Number of params in network : 21284672
Number of params in network : 21284672
Final dimension of the network : 512


In [3]:
class CombinedNetwork(nn.Module):
    def __init__(self, net1, net2):
        super(CombinedNetwork, self).__init__()
        self.networks = nn.Sequential(
            net1,
            net2
        )
    
    def forward(self, x):
        return self.networks(x)

combined_net = CombinedNetwork(resnet, simple_netC_0hl())
print(type(combined_net) == CombinedNetwork)
print(f"Number of params in network : {sum(p.numel() for p in combined_net.parameters())}")

True
Number of params in network : 21287237


In [4]:
start_time = time.time()

def compute_ntk(model, x1, x2):
    # Ensure gradients are enabled
    x1.requires_grad_(True)
    x2.requires_grad_(True)

    # Forward pass
    y1 = model(x1)
    y2 = model(x2)

        # Initialize NTK value
    ntk = 0.0

    # Compute the NTK for each output element
    # print(y1.shape)
    for j in range(y1.shape[1]):  # Assuming y1 and y2 are of shape (batch_size, num_classes)
        grad_y1 = torch.autograd.grad(y1[:, j].sum(), model.parameters(), retain_graph=True, create_graph=True)
        grad_y2 = torch.autograd.grad(y2[:, j].sum(), model.parameters(), retain_graph=True, create_graph=True)
        # print(sum([g1.numel() for g1 in grad_y1]))
        # print(grad_y1[0].shape)
        # print(grad_y2[0].shape)
        # print(grad_y1[1].shape)
        # print(grad_y2[1].shape)
        # Compute the NTK (dot product of gradients) for the current output element
        ntk += sum((g1 * g2).sum() for g1, g2 in zip(grad_y1, grad_y2))
    print(ntk, ntk.item())
    return ntk.item()

# Example usage
input1 = torch.randn([5*17, 3, 224, 224])
input2 = torch.randn([5*17, 3, 224, 224])

ntk_value = compute_ntk(combined_net, input1, input2)
print("NTK value:", ntk_value)
print(f"Total time : {time.time()-start_time}")

RuntimeError: mat1 and mat2 shapes cannot be multiplied (85x2048 and 512x5)

In [5]:
def print_memory_usage():
    print(f"Allocated: {torch.cuda.memory_allocated() / 1024 ** 2:.2f} MB")
    print(f"Cached: {torch.cuda.memory_reserved() / 1024 ** 2:.2f} MB\n")

In [None]:
start_time = time.time()

def compute_jacobian_autodiff(net, inputs, c):
    """
    Return the jacobian of a batch of inputs, using autodifferentiation
    """
    inputs.requires_grad_(True)
    outputs = net(inputs)
    N = sum(p.numel() for p in net.parameters())
    jac = torch.empty(outputs.size(0), N)
    for j in range(outputs.size(0)):
        # print(j)
        grad_y1 = torch.autograd.grad(outputs[j, c], net.parameters(), retain_graph=True, create_graph=True) # We need to retain every single graph for the gradient to be able to run through
        # print_memory_usage()
        flattened_tensors = [t.flatten() for t in grad_y1]
        jac[j] = torch.cat(flattened_tensors)
        # print_memory_usage()
        # if device == "cuda":
        #     torch.cuda.empty_cache()
        #     print_memory_usage()
    return jac

for_loop_jac = compute_jacobian_autodiff(combined_net.to(device), input1.to(device), 2)
print(for_loop_jac.shape)
print(f"Total time : {time.time()-start_time}")

In [6]:
# Define a function to get parameters as a single tensor
def get_params_tensor(net):
    params = []
    for param in net.parameters():
        params.append(param.view(-1))
    return torch.cat(params)

# Define the function for which we want the Jacobian
def net_with_params(params, input_tensor):
    # Set the parameters to the network
    start_idx = 0
    for param in combined_net.parameters():
        param_length = param.numel()
        param.data.copy_(params[start_idx:start_idx + param_length].view(param.size()))
        start_idx += param_length
    
    return combined_net(input_tensor)

# Example input
input_tensor = torch.randn(25,3,84,84, device = device, requires_grad=True)

# Get the current parameters as a single tensor
params = get_params_tensor(combined_net.to(device))

# Compute the Jacobian
jacobian = torch.autograd.functional.jacobian(lambda p: net_with_params(p, input_tensor), params)
print(jacobian)

tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        ...,

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0.,

In [8]:
start_time = time.time()

def compute_jacobian_autodiff(net, inputs):
    """
    Return the jacobian of a batch of inputs, using autodifferentiation
    Useful for when dealing with models using batch normalization or other kind of running statistics
    """
    inputs.requires_grad_(True)
    outputs = net(inputs)
    N = sum(p.numel() for p in net.parameters())
    jac = torch.empty(outputs.size(0), N)
    for j in range(outputs.size(0)):
        # print(j)
        grad_y1 = torch.autograd.grad(outputs[j, 2], net.parameters(), retain_graph=True, create_graph=True) # We need to create and retain every single graph for the gradient to be able to run through during backprop
        # print_memory_usage()
        flattened_tensors = [t.flatten() for t in grad_y1]
        jac[j] = torch.cat(flattened_tensors)
        # print_memory_usage()
        # if device == "cuda":
        #     torch.cuda.empty_cache()
        #     print_memory_usage()
    return jac

input1 = torch.randn(80, 3, 84, 84)

compute_jacobian_autodiff(combined_net.to(device), input1.to(device))
print(time.time() - start_time)

4.115832805633545


In [9]:
print_memory_usage()
print(for_loop_jac)

Allocated: 502.18 MB
Cached: 29484.00 MB



NameError: name 'for_loop_jac' is not defined

In [None]:
start_time = time.time()

def compute_jacobian_vmap_autodiff(net, inputs, c):
    """
    Return the jacobian of a batch of inputs, thanks to the vmap functionality
    """
    params_that_need_grad = []
    for param in net.parameters():
        if param.requires_grad:
            params_that_need_grad.append(param.requires_grad)
    
    inputs = inputs.to(device, non_blocking=True)
    inputs.requires_grad_(True)
    outputs = net(inputs)
    basis_vectors = torch.eye(len(inputs),device=device,dtype=torch.bool)
    J_layer = []
    for i,z in enumerate(net.named_parameters()):
        if not(params_that_need_grad[i]): #if it didnt need a grad, we can skip it.
            continue
        name, param = z
        outputsc = outputs[:, c]   
        #Seems like for retain_graph=False, you might need to do multiple forward passes.
        
        def torch_row_Jacobian(v): #y would have to be a single piece of the batch
            return torch.autograd.grad(outputsc,param,v, retain_graph=True, create_graph=True)[0].reshape(-1)
        J_layer.append(vmap(torch_row_Jacobian)(basis_vectors).detach())
        
        del outputsc
        if device=='cuda':
            torch.cuda.empty_cache()
        #print(name)
    #for layer in J_layer:
    #    print(layer.shape)
    del params_that_need_grad
    del inputs
    del outputs
    del basis_vectors
    if device=='cuda':
        torch.cuda.empty_cache()
    J_layer = torch.cat(J_layer, axis=1)
    return J_layer
    

vmap_jac = compute_jacobian_vmap_autodiff(combined_net.to(device), input1.to(device), 2)
print(vmap_jac.shape)
#print(vmap_jac.cpu() == for_loop_jac)
print(f"Total time : {time.time()-start_time}")

In [9]:
# Print total memory allocated and cached on the GPU
print(f"Allocated: {torch.cuda.memory_allocated() / 1024 ** 2} MB")
print(f"Cached: {torch.cuda.memory_reserved() / 1024 ** 2} MB")

# For more detailed information, use:
print(torch.cuda.memory_summary(device=None, abbreviated=False))

Allocated: 520.861328125 MB
Cached: 30972.0 MB
|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 8            |        cudaMalloc retries: 8         |
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      | 533362 KiB |  29307 MiB | 551117 MiB | 550596 MiB |
|       from large pool | 527032 KiB |  29288 MiB | 551040 MiB | 550526 MiB |
|       from small pool |   6330 KiB |     19 MiB |     76 MiB |     70 MiB |
|---------------------------------------------------------------------------|
| Active memory         | 533362 KiB |  29307 MiB | 551117 MiB | 550596 MiB |
|       from large pool | 527032 KiB |  29288 MiB | 551040 MiB | 550526 MiB |
|       from small pool |   6330 KiB |     19 MiB |     76 MiB |     70 MiB |
|----------------

In [None]:
x = torch.randn(2, 3)
print(x)

In [12]:
def compute_jacobian(net, inputs, c):
    """
    Return the jacobian of a batch of inputs, thanks to the vmap functionality
    """
    print(sum(p.numel() for p in net.parameters() if p.requires_grad_))
    params = {k: v for k, v in net.named_parameters() if v.requires_grad_}

    def fnet_single(params, x):
        # Make sure output has the right dimensions
        return functional_call(net, params, (x.unsqueeze(0),)).squeeze(0)[c]

    jac = vmap(jacrev(fnet_single), (None, 0))(params, inputs)
    jac_values = jac.values()

    reshaped_tensors = []
    for j in jac_values:
        if len(j.shape) >= 3:  # For layers with weights
            # Flatten parameters dimensions and then reshape
            flattened = j.flatten(start_dim=1)  # Flattens to [batch, params]
            reshaped = flattened.T  # Transpose to align dimensions as [params, batch]
            reshaped_tensors.append(reshaped)
        elif len(j.shape) == 2:  # For biases or single parameter components
            reshaped_tensors.append(j.T)  # Simply transpose

    # Concatenate all the reshaped tensors into one large matrix
    return torch.cat(reshaped_tensors, dim=0).T


def compute_ntk2(model, x1, x2):
    # Forward pass
    model.eval()
    j1 = []
    j2 = []
    for c in range(5):
        j1c = compute_jacobian(model, x1, c)
        j2c = compute_jacobian(model, x2, c)
        print(j1c.shape)
        j1.append(j1c)
        j2.append(j2c)
    j1 = torch.cat(j1, dim=0)
    j2 = torch.cat(j2, dim=0)
    model.train()
    
    print(j1.shape)
    return j1@j2.T

# Example usage
input1 = torch.ones([5*17, 3, 84, 84])
input2 = torch.ones([5*17, 3, 84, 84])

ntk_value = compute_ntk2(combined_net, input1, input2)
print("NTK value:", ntk_value)
print("NTK sum value:", sum(sum(ntk_value)))

21287237
21287237
torch.Size([85, 21287237])
21287237
21287237
torch.Size([85, 21287237])
21287237
21287237
torch.Size([85, 21287237])
21287237
21287237
torch.Size([85, 21287237])
21287237
21287237
torch.Size([85, 21287237])
torch.Size([425, 21287237])
NTK value: tensor([[311973.9375, 311973.9375, 311973.9375,  ...,  -6109.9062,
          -6109.9062,  -6109.9062],
        [311973.9375, 311973.9375, 311973.9375,  ...,  -6109.9062,
          -6109.9062,  -6109.9062],
        [311973.9375, 311973.9375, 311973.9375,  ...,  -6109.9062,
          -6109.9062,  -6109.9062],
        ...,
        [ -6109.9062,  -6109.9062,  -6109.9062,  ..., 327817.5000,
         327817.5000, 327817.5000],
        [ -6109.9062,  -6109.9062,  -6109.9062,  ..., 327817.5000,
         327817.5000, 327817.5000],
        [ -6109.9062,  -6109.9062,  -6109.9062,  ..., 327817.5000,
         327817.5000, 327817.5000]], grad_fn=<MmBackward0>)
NTK sum value: tensor(1.3333e+10, grad_fn=<AddBackward0>)
