In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.func import functional_call, vmap, vjp, jvp, jacrev

class MyNetworkTester:
    """Compute the (scaled) empirical neural tangent kernel of a small CNN.

    The network maps an image batch of shape
    ``(N, in_channels, image_size, image_size)`` to one scalar per sample.
    ``test_function`` builds per-sample Jacobians of that output with respect
    to every network parameter, multiplies each by a per-parameter scaling
    tensor, and contracts them into a kernel matrix.

    Args:
        in_channels: Number of input channels (default 3).
        image_size: Height and width of the square input images (default 100).
            Must be at least 5, because each of the two unpadded 3x3
            convolutions shrinks the spatial size by 2.
    """

    def __init__(self, in_channels=3, image_size=100):
        if image_size < 5:
            raise ValueError(f"image_size must be >= 5, got {image_size}")
        self.in_channels = in_channels
        self.image_size = image_size
        # Two valid (unpadded) 3x3 convolutions: spatial size shrinks by 4 in total.
        feat_size = image_size - 4

        # A simple CNN for testing purposes
        self.net = nn.Sequential(
            nn.Conv2d(in_channels, 16, 3),
            nn.ReLU(),
            nn.Conv2d(16, 32, 3),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(32 * feat_size * feat_size, 1),
        )

        # Per-parameter scaling factors, initialised to ones (identity scaling).
        self.scaling_params = {
            name: torch.ones_like(p) for name, p in self.net.named_parameters()
        }
        self.jac_func = self._create_jac_func(self.net)

    @staticmethod
    def _create_jac_func(net):
        """Return ``f(params, x)`` giving per-sample Jacobians of ``net`` w.r.t. ``params``.

        ``jacrev`` differentiates a single-sample call and ``vmap`` maps it over
        the leading batch dimension of ``x`` while sharing ``params``
        (``in_dims=(None, 0)``). The result is a dict keyed like ``params``
        whose values have shape ``(N, out_dim, *param.shape)``.
        """
        def fnet_single(params, x):
            # Add the batch dim the network expects, then strip it so the
            # Jacobian is taken of one sample's output vector.
            return functional_call(net, params, (x.unsqueeze(0),)).squeeze(0)

        return vmap(jacrev(fnet_single), (None, 0))

    def _scaled_jacobians(self, params, x):
        """Per-sample Jacobians times their scaling factors, flattened to ``(N, out_dim, P)``."""
        jac = self.jac_func(params, x)
        return [(self.scaling_params[k] * j).flatten(2) for k, j in jac.items()]

    def test_function(self, x1, x2, diag=False):
        """Compute the scaled empirical NTK between two input batches.

        Args:
            x1: Batch of N inputs, reshaped to
                ``(N, in_channels, image_size, image_size)``.
            x2: Batch of M inputs, reshaped the same way.
            diag: If True, return only the diagonal entries ``K[i, i]`` for
                ``i < min(N, M)`` (what ``Tensor.diag()`` of the full kernel
                would give) without materialising the full kernel.

        Returns:
            The ``(N, M)`` kernel ``sum_p <s_p * J1_p, s_p * J2_p>`` over all
            parameter tensors ``p`` (the network has a single output, whose
            size-1 axis is squeezed away), or its ``(min(N, M),)`` diagonal.
        """
        shape = (self.in_channels, self.image_size, self.image_size)
        x1 = x1.reshape(x1.size(0), *shape)
        x2 = x2.reshape(x2.size(0), *shape)

        params = dict(self.net.named_parameters())

        sp_jac1 = self._scaled_jacobians(params, x1)
        # Reuse the Jacobians when both batches are identical (self-kernel).
        sp_jac2 = sp_jac1 if torch.equal(x1, x2) else self._scaled_jacobians(params, x2)

        if diag:
            # Contract matching rows only instead of building the N x M kernel.
            n = min(x1.size(0), x2.size(0))
            terms = [torch.einsum('Naf,Naf->aN', j1[:n], j2[:n])
                     for j1, j2 in zip(sp_jac1, sp_jac2)]
        else:
            terms = [torch.einsum('Naf,Maf->aNM', j1, j2)
                     for j1, j2 in zip(sp_jac1, sp_jac2)]

        # Sum the contributions of all parameter tensors, then drop the
        # size-1 output axis.
        return torch.stack(terms).sum(dim=0).squeeze(0)


# Instantiate the tester and compute the NTK on random image batches.
# Seed first so both the network initialisation and the inputs are reproducible.
torch.manual_seed(0)
BATCH_SIZE = 19

tester = MyNetworkTester()

x1 = torch.randn(BATCH_SIZE, 3, 100, 100)
x2 = torch.randn(BATCH_SIZE, 3, 100, 100)

# Self-kernel K(x1, x1): full BATCH_SIZE x BATCH_SIZE matrix
ntk_self = tester.test_function(x1, x1, diag=False)
print(ntk_self.shape)
assert ntk_self.shape == (BATCH_SIZE, BATCH_SIZE)

# Cross-kernel K(x1, x2): full BATCH_SIZE x BATCH_SIZE matrix
ntk = tester.test_function(x1, x2, diag=False)
print(ntk.shape)
assert ntk.shape == (BATCH_SIZE, BATCH_SIZE)

# Diagonal of the cross-kernel only: K(x1[i], x2[i])
ntk_diag = tester.test_function(x1, x2, diag=True)
print(ntk_diag.shape)
assert ntk_diag.shape == (BATCH_SIZE,)


torch.Size([19, 19])
torch.Size([19, 19])
torch.Size([19])


In [3]:
class MyModel(nn.Module):
    """Toy module holding learnable, element-wise scaling factors.

    ``scaling_params`` mirrors the keys and shapes of ``params``. It is stored
    in an ``nn.ParameterDict`` so the factors are registered with the module:
    they appear in ``named_parameters()``, are returned by ``parameters()``
    (and are therefore updated by optimizers), and follow ``.to()`` and
    ``state_dict()``.

    Args:
        device: Device on which to allocate the scaling factors. ``None``
            (default) selects ``"cuda:0"`` when CUDA is available and falls
            back to ``"cpu"`` otherwise.
    """

    def __init__(self, device=None):
        super().__init__()
        if device is None:
            device = "cuda:0" if torch.cuda.is_available() else "cpu"

        # Reference tensors whose shapes the scaling factors mirror. This is a
        # plain dict, so these tensors are not registered: they are neither
        # trained nor moved by .to().
        self.params = {'param1': torch.randn(3, 3), 'param2': torch.randn(3, 3)}

        # A plain dict of nn.Parameter would be invisible to named_parameters()
        # and hence to any optimizer; nn.ParameterDict registers each entry.
        self.scaling_params = nn.ParameterDict({
            k: nn.Parameter(torch.ones_like(v, device=device))
            for k, v in self.params.items()
        })

    def forward(self, x):
        """Identity placeholder; returns ``x`` unchanged.

        TODO: apply ``scaling_params`` here.
        """
        return x


In [4]:
# List the names of all parameters registered on the model.
model = MyModel()
param_names = [param_name for param_name, _ in model.named_parameters()]
for param_name in param_names:
    print(param_name)