In [11]:
import torch
import torch.nn as nn
from torch.func import functional_call, vmap, vjp, jvp, jacrev
import math
device = 'cuda'

class CNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, (3, 3))
        self.conv2 = nn.Conv2d(32, 32, (3, 3))
        self.conv3 = nn.Conv2d(32, 32, (3, 3))
        self.fc = nn.Linear(194688, 5)
        
    def forward(self, x):
        x = self.conv1(x)
        x = x.relu()
        x = self.conv2(x)
        x = x.relu()
        x = self.conv3(x)
        x = x.flatten(1)
        x = self.fc(x)
        return x

In [12]:
#Redefine Conv4 here :

def init_layer(L):
    # Initialization using fan-in
    if isinstance(L, nn.Conv2d):
        n = L.kernel_size[0]*L.kernel_size[1]*L.out_channels
        L.weight.data.normal_(0,math.sqrt(2.0/float(n)))
    elif isinstance(L, nn.BatchNorm2d):
        L.weight.data.fill_(1)
        L.bias.data.fill_(0)

class Flatten(nn.Module):
    def __init__(self):
        super(Flatten, self).__init__()

    def forward(self, x):
        return x.view(x.size(0), -1)    
    
class ConvNetNoBN(nn.Module):
    maml = False # Default

    def __init__(self, depth, n_way=-1, flatten=True, padding=1):
        super(ConvNetNoBN, self).__init__()
        layers = []
        for i in range(depth):
            indim = 3 if i == 0 else 64
            outdim = 64
            if self.maml:
                conv_layer = Conv2d_fw(indim, outdim, 3, padding=padding)
                # BN     = BatchNorm2d_fw(outdim)
            else:
                conv_layer = nn.Conv2d(indim, outdim, 3, stride=1, padding=padding, bias=False)
                # BN     = nn.BatchNorm2d(outdim)
            
            relu = nn.ReLU(inplace=True)
            layers.append(conv_layer)
            # layers.append(BN)
            layers.append(relu)

            if i < 4:  # Pooling only for the first 4 layers
                pool = nn.MaxPool2d(2)
                layers.append(pool)

            # Initialize the layers
            init_layer(conv_layer)
            # init_layer(BN)

        if flatten:
            layers.append(Flatten())
        
        if n_way>0:
            layers.append(nn.Linear(1600,n_way))
            self.final_feat_dim = n_way
        else:
            self.final_feat_dim = 1600
            
        self.trunk = nn.Sequential(*layers)
        

    def forward(self, x):
        out = self.trunk(x)
        return out

def Conv4NoBN():
    print("Conv4 No Batch Normalization")
    return ConvNetNoBN(4)

def Conv4NoBN_class(n_way=5):
    print("Conv4 No Batch Normalization with final classifier layer of 5 way")
    return ConvNetNoBN(4, n_way=n_way)

def combined_conv3():
    print("Conv3")
    return backbone.CombinedNetwork(backbone.Conv3(), nn.Linear(1764,5))


# Redefine Resnet here:

class SimpleBlockNoBN(nn.Module):
    maml = False #Default
    def __init__(self, indim, outdim, half_res):
        super(SimpleBlockNoBN, self).__init__()
        self.indim = indim
        self.outdim = outdim
        
        self.C1 = nn.Conv2d(indim, outdim, kernel_size=3, stride=2 if half_res else 1, padding=1, bias=False)
        # self.BN1 = nn.BatchNorm2d(outdim)
        self.C2 = nn.Conv2d(outdim, outdim,kernel_size=3, padding=1,bias=False)
        # self.BN2 = nn.BatchNorm2d(outdim)
        self.relu1 = nn.ReLU(inplace=True)
        self.relu2 = nn.ReLU(inplace=True)

        self.parametrized_layers = [self.C1, self.C2]

        self.half_res = half_res

        # if the input number of channels is not equal to the output, then need a 1x1 convolution
        if indim!=outdim:
            self.shortcut = nn.Conv2d(indim, outdim, 1, 2 if half_res else 1, bias=False)
            # self.BNshortcut = nn.BatchNorm2d(outdim)

            self.parametrized_layers.append(self.shortcut)
            # self.parametrized_layers.append(self.BNshortcut)
            self.shortcut_type = '1x1'
        else:
            self.shortcut_type = 'identity'

        for layer in self.parametrized_layers:
            init_layer(layer)

    def forward(self, x):
        out = self.C1(x)
        # out = self.BN1(out)
        out = self.relu1(out)
        out = self.C2(out)
        # out = self.BN2(out)
        short_out = x if self.shortcut_type == 'identity' else self.shortcut(x)
        out = out + short_out
        out = self.relu2(out)
        return out
    
class ResNetNoBN(nn.Module):
    maml = False #Default
    def __init__(self,block,list_of_num_layers, list_of_out_dims, flatten = True, n_way = -1):
        # list_of_num_layers specifies number of layers in each stage
        # list_of_out_dims specifies number of output channel for each stage
        super(ResNetNoBN,self).__init__()
        assert len(list_of_num_layers)==4, 'Can have only four stages'
        conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        # bn1 = nn.BatchNorm2d(64)

        relu = nn.ReLU()
        pool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        init_layer(conv1)
        # init_layer(bn1)


        trunk = [conv1, relu, pool1]

        indim = 64
        for i in range(4):

            for j in range(list_of_num_layers[i]):
                half_res = (i>=1) and (j==0)  # If the start of a new type of layers
                B = block(indim, list_of_out_dims[i], half_res)
                trunk.append(B)
                indim = list_of_out_dims[i]

        if flatten:
            avgpool = nn.AvgPool2d(3) # originally avgpool = nn.AvgPool2d(7)
            trunk.append(avgpool)
            trunk.append(Flatten())
            self.final_feat_dim = indim
        else:
            self.final_feat_dim = [ indim, 7, 7]
            
        if n_way>0:
            trunk.append(nn.Linear(2048,n_way))
            self.final_feat_dim = n_way
        else:
            self.final_feat_dim = 2048

        self.trunk = nn.Sequential(*trunk)

    def forward(self,x):
        out = self.trunk(x)
        return out
    
    

def ResNetNoBN10( flatten = True):
    return ResNetNoBN(SimpleBlockNoBN, [1,1,1,1],[64,128,256,512], flatten)

def ResNetNoBN10_classifier( flatten = True):
    return ResNetNoBN(SimpleBlockNoBN, [1,1,1,1],[64,128,256,512], flatten, n_way = 5)

In [27]:
# x_train = torch.randn(85, 3, 224, 224, device=device)
# x_test = torch.randn(85, 3, 224, 224, device=device)

x_train = torch.randn(85, 3, 84, 84, device=device)
x_test = torch.randn(85, 3, 84, 84, device=device)

In [30]:
# net = ResNetNoBN10_classifier().to(device)
net = Conv4NoBN_class().to(device)
params = {k: v for k, v in net.named_parameters()}

s = {k: torch.ones_like(v) for (k, v) in params.items()}

def fnet_single(params, x):
    return functional_call(net, params, (x.unsqueeze(0),)).squeeze(0)

Conv4 No Batch Normalization with final classifier layer of 5 way


In [31]:
def empirical_ntk_s(fnet_single, params, x1, s):
    jac1 = vmap(jacrev(fnet_single), (None, 0))(params, x1)
    jac1 = [(s[k] * j).flatten(2) for k, j in jac1.items()] 
    # jac1 = [j.flatten(2) for j in jac1.values()]

    # jac2 = vmap(jacrev(fnet_single), (None, 0))(params, x2)
    # jac2 = [(s[k] * j).flatten(2) for k, j in jac2.items()] 
    # jac2 = [j.flatten(2) for j in jac2.values()]

    result = torch.stack([torch.einsum('Naf,Maf->NMa', j1, j2) for j1, j2 in zip(jac1, jac1)])
    result = result.sum(0)
    return result.permute(2,0,1)


In [32]:
empirical_ntk_s(net, params, x_train, s)

TypeError: forward() takes 2 positional arguments but 3 were given

In [6]:
s = {k: torch.ones_like(v).to(device) for k, v in params.items()}
ntk_result = empirical_ntk_s(fnet_single, params, x_train, x_test, s)

print(ntk_result.shape)  

(0, 0)
(0, 1)
(0, 2)
(0, 3)
(0, 4)
(0, 5)
(0, 6)
(0, 7)
(0, 8)
(0, 9)
(0, 10)
(0, 11)
(0, 12)
(0, 13)
(0, 14)
(0, 15)
(0, 16)
(0, 17)
(0, 18)
(0, 19)
(0, 20)
(0, 21)
(0, 22)
(0, 23)
(0, 24)
(0, 25)
(0, 26)
(0, 27)
(0, 28)
(0, 29)
(0, 30)
(0, 31)
(0, 32)
(0, 33)
(0, 34)
(0, 35)
(0, 36)
(0, 37)
(0, 38)
(0, 39)
(0, 40)
(0, 41)
(0, 42)
(0, 43)
(0, 44)
(0, 45)
(0, 46)
(0, 47)
(0, 48)
(0, 49)
(0, 50)
(0, 51)
(0, 52)
(0, 53)
(0, 54)
(0, 55)
(0, 56)
(0, 57)
(0, 58)
(0, 59)
(0, 60)
(0, 61)
(0, 62)
(0, 63)
(0, 64)
(0, 65)
(0, 66)
(0, 67)
(0, 68)
(0, 69)
(0, 70)
(0, 71)
(0, 72)
(0, 73)
(0, 74)
(0, 75)
(0, 76)
(0, 77)
(0, 78)
(0, 79)
(0, 80)
(0, 81)
(0, 82)
(0, 83)
(0, 84)
(1, 0)
(1, 1)
(1, 2)
(1, 3)
(1, 4)
(1, 5)
(1, 6)
(1, 7)
(1, 8)
(1, 9)
(1, 10)
(1, 11)
(1, 12)
(1, 13)
(1, 14)
(1, 15)
(1, 16)
(1, 17)
(1, 18)
(1, 19)
(1, 20)
(1, 21)
(1, 22)
(1, 23)


OutOfMemoryError: CUDA out of memory. Tried to allocate 16.00 MiB (GPU 0; 31.74 GiB total capacity; 30.66 GiB already allocated; 15.12 MiB free; 31.11 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [None]:
def empirical_ntk_ntk_vps(func, params, x1, x2, compute='full'):
    def get_ntk(x1, x2):
        def func_x1(params):
            return func(params, x1)

        def func_x2(params):
            return func(params, x2)

        output, vjp_fn = vjp(func_x1, params)

        def get_ntk_slice(vec):
            # This computes vec @ J(x2).T
            # `vec` is some unit vector (a single slice of the Identity matrix)
            vjps = vjp_fn(vec)
            # This computes J(X1) @ vjps
            _, jvps = jvp(func_x2, (params,), vjps)
            return jvps

        # Here's our identity matrix
        basis = torch.eye(output.numel(), dtype=output.dtype, device=output.device).view(output.numel(), -1)
        return vmap(get_ntk_slice)(basis)
        
    # get_ntk(x1, x2) computes the NTK for a single data point x1, x2
    # Since the x1, x2 inputs to empirical_ntk_ntk_vps are batched,
    # we actually wish to compute the NTK between every pair of data points
    # between {x1} and {x2}. That's what the vmaps here do.
    result = vmap(vmap(get_ntk, (None, 0)), (0, None))(x1, x2)
    
    if compute == 'full':
        return result
    if compute == 'trace':
        return torch.einsum('NMKK->NM', result)
    if compute == 'diagonal':
        return torch.einsum('NMKK->NMK', result)

In [21]:
# result_from_jacobian_contraction = empirical_ntk_s(fnet_single, params, x_test, x_train, s)
result_from_ntk_vps = empirical_ntk_ntk_vps(fnet_single, params, x_test, x_train)
assert torch.allclose(result_from_jacobian_contraction, result_from_ntk_vps, atol=1e-5)

OutOfMemoryError: CUDA out of memory. Tried to allocate 28.96 GiB (GPU 0; 31.74 GiB total capacity; 19.15 GiB already allocated; 11.66 GiB free; 19.47 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF