# Init the model

In [4]:
import torch
import torch.nn as nn
from torch.func import functional_call, vmap, vjp, jvp, jacrev
from torch.autograd import Variable
import torch.optim
import math
import numpy as np
device = 'cuda'

import time

In [16]:
#Redefine Conv4 here :

def init_layer(L):
    # Initialization using fan-in
    if isinstance(L, nn.Conv2d):
        n = L.kernel_size[0]*L.kernel_size[1]*L.out_channels
        L.weight.data.normal_(0,math.sqrt(2.0/float(n)))
    elif isinstance(L, nn.BatchNorm2d):
        L.weight.data.fill_(1)
        L.bias.data.fill_(0)

class Flatten(nn.Module):
    def __init__(self):
        super(Flatten, self).__init__()

    def forward(self, x):
        return x.view(x.size(0), -1)    
    
class ConvNetNoBN(nn.Module):
    maml = False # Default

    def __init__(self, depth, n_way=-1, flatten=True, padding=1):
        super(ConvNetNoBN, self).__init__()
        layers = []
        for i in range(depth):
            indim = 3 if i == 0 else 64
            outdim = 64
            if self.maml:
                conv_layer = Conv2d_fw(indim, outdim, 3, padding=padding)
                # BN     = BatchNorm2d_fw(outdim)
            else:
                conv_layer = nn.Conv2d(indim, outdim, 3, stride=1, padding=padding, bias=False)
                # BN     = nn.BatchNorm2d(outdim)
            
            relu = nn.ReLU(inplace=True)
            layers.append(conv_layer)
            # layers.append(BN)
            layers.append(relu)

            if i < 4:  # Pooling only for the first 4 layers
                pool = nn.MaxPool2d(2)
                layers.append(pool)

            # Initialize the layers
            init_layer(conv_layer)
            # init_layer(BN)

        if flatten:
            layers.append(Flatten())
        
        if n_way>0:
            layers.append(nn.Linear(1600,n_way))
            self.final_feat_dim = n_way
        else:
            self.final_feat_dim = 1600
            
        self.trunk = nn.Sequential(*layers)
        

    def forward(self, x):
        # print(self.trunk[0].weight)
        for i, layer in enumerate(self.trunk):
            print(f"Input shape before layer {i}: {x.shape}")
            print(layer.weight)
            x = layer(x)
            print(f"Output shape after layer {i}: {x.shape}")
        return x

def Conv4NoBN():
    print("Conv4 No Batch Normalization")
    return ConvNetNoBN(4)

def Conv4NoBN_class(n_way=5):
    print("Conv4 No Batch Normalization with final classifier layer of 5 way")
    return ConvNetNoBN(4, n_way=n_way)

# Global variables

In [17]:
n_way = 5

n_support = 7
n_query = 10

n_inner_upd = 3
n_task = 4

eps = 1e-4
lr_in = 1e-2
lr_out = 1e-1

In [18]:
x_support = torch.randn(n_way * n_support, 3, 84, 84, device=device)
y_support = Variable(torch.from_numpy(np.repeat(range(n_way), n_support)).cuda())
x_query = torch.randn(n_way * n_query, 3, 84, 84, device=device)
y_query = Variable(torch.from_numpy(np.repeat(range(n_way), n_query)).cuda())


net = Conv4NoBN_class().to(device)

Conv4 No Batch Normalization with final classifier layer of 5 way


In [19]:
params = {k: v for k, v in net.named_parameters()}
s = {k: torch.ones_like(v) for (k, v) in params.items()}

print("params dict is equal to the net's params \n")
print(net.trunk[0].weight[0][0])
print(list(params.values())[0][0][0])

params dict is equal to the net's params 

tensor([[ 0.1177,  0.0025,  0.0485],
        [ 0.0203, -0.0412,  0.0553],
        [ 0.0211,  0.0470,  0.1086]], device='cuda:0',
       grad_fn=<SelectBackward0>)
tensor([[ 0.1177,  0.0025,  0.0485],
        [ 0.0203, -0.0412,  0.0553],
        [ 0.0211,  0.0470,  0.1086]], device='cuda:0',
       grad_fn=<SelectBackward0>)


## INNER LOOP

## NTK computation

Note :
USE FAST WEIGHT AS THEY DID FOR MAML (nothing is supposed to pose a problem except for batch norm layers as usual)

In [20]:
new_params = {k : nn.Parameter(v) for k, v in new_params.items()}

output = functional_call(net.to(device), new_params, (x_support,))

Input shape before layer 0: torch.Size([35, 3, 84, 84])
Parameter containing:
tensor([[[[[ 0.0405, -0.0636,  0.0121],
           [ 0.0634, -0.0317,  0.0556],
           [-0.0851, -0.0781,  0.0151]],

          [[-0.0455,  0.0024, -0.0177],
           [-0.0468,  0.0569, -0.0198],
           [ 0.1005,  0.0518, -0.0946]],

          [[-0.0023,  0.0403,  0.0110],
           [-0.0639,  0.1388,  0.0300],
           [ 0.1657,  0.0036, -0.0303]]],


         [[[-0.0527, -0.0552, -0.0272],
           [ 0.0181,  0.0323, -0.1041],
           [-0.0070,  0.0066, -0.1254]],

          [[-0.1010, -0.0516, -0.0507],
           [-0.0045, -0.0910,  0.1314],
           [-0.0130,  0.0905, -0.1148]],

          [[ 0.0230,  0.0835,  0.0607],
           [ 0.1031,  0.0074,  0.0239],
           [-0.1148, -0.0089, -0.0562]]],


         [[[ 0.0020, -0.0318,  0.0139],
           [-0.0036,  0.0868, -0.0372],
           [ 0.0691, -0.0390,  0.0578]],

          [[ 0.0474, -0.0188, -0.0784],
           [ 0.0341,  0.

RuntimeError: expected stride to be a single integer value or a list of 3 values to match the convolution dimensions, but got stride=[1, 1]

In [9]:
def fnet_single(params, x):
    return functional_call(net, params, (x.unsqueeze(0),)).squeeze(0)

start_time = time.time()

# Compute J(x1)
jac1 = vmap(jacrev(fnet_single), (None, 0))(params, x_support)

# print(net.trunk[0].weight)
s_jac = {k : s[k]*j for (k, j) in jac1.items()}   # Useful for later
ntk_jac1 = [s_j.flatten(2) for s_j in s_jac.values()]   # Useful for the NTK computation
    
# Compute J(x1) @ J(x2).T
ntk = torch.stack([torch.einsum('Naf,Maf->aNM', j1, j2) for j1, j2 in zip(ntk_jac1, ntk_jac1)])
ntk = ntk.sum(0)

print(f"Jacobian contraction time {time.time()-start_time}")
print(ntk.shape)

Input shape before layer 0: torch.Size([1, 3, 84, 84])
Output shape after layer 0: torch.Size([1, 64, 84, 84])
Input shape before layer 1: torch.Size([1, 64, 84, 84])
Output shape after layer 1: torch.Size([1, 64, 84, 84])
Input shape before layer 2: torch.Size([1, 64, 84, 84])
Output shape after layer 2: torch.Size([1, 64, 42, 42])
Input shape before layer 3: torch.Size([1, 64, 42, 42])
Output shape after layer 3: torch.Size([1, 64, 42, 42])
Input shape before layer 4: torch.Size([1, 64, 42, 42])
Output shape after layer 4: torch.Size([1, 64, 42, 42])
Input shape before layer 5: torch.Size([1, 64, 42, 42])
Output shape after layer 5: torch.Size([1, 64, 21, 21])
Input shape before layer 6: torch.Size([1, 64, 21, 21])
Output shape after layer 6: torch.Size([1, 64, 21, 21])
Input shape before layer 7: torch.Size([1, 64, 21, 21])
Output shape after layer 7: torch.Size([1, 64, 21, 21])
Input shape before layer 8: torch.Size([1, 64, 21, 21])
Output shape after layer 8: torch.Size([1, 64, 10

## Computation of $\text{Sol}_c = (NTK_c + \epsilon I_k )^{-1} (Y - \phi_{\theta, c} (X))$

In [10]:
# Creation of Y

target_list = list()
samples_per_model = int(len(y_support) / n_way) #25 / 5 = 5
for way in range(n_way):
    target = torch.ones(len(y_support), dtype=torch.float32) * -1.0
    start_index = way * samples_per_model
    stop_index = start_index+samples_per_model
    target[start_index:stop_index] = 1.0
    target_list.append(target.cuda())
    
print(target_list)
print(target_list[0].shape)

[tensor([ 1.,  1.,  1.,  1.,  1.,  1.,  1., -1., -1., -1., -1., -1., -1., -1.,
        -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,
        -1., -1., -1., -1., -1., -1., -1.], device='cuda:0'), tensor([-1., -1., -1., -1., -1., -1., -1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,
        -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,
        -1., -1., -1., -1., -1., -1., -1.], device='cuda:0'), tensor([-1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,
         1.,  1.,  1.,  1.,  1.,  1.,  1., -1., -1., -1., -1., -1., -1., -1.,
        -1., -1., -1., -1., -1., -1., -1.], device='cuda:0'), tensor([-1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,
        -1., -1., -1., -1., -1., -1., -1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,
        -1., -1., -1., -1., -1., -1., -1.], device='cuda:0'), tensor([-1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,
        -1., -1., -1., -1., -1., -1., -1., -1., -

In [11]:
# Forward pass

phi = net.forward(x_support)
print(phi.shape)

Input shape before layer 0: torch.Size([35, 3, 84, 84])
Output shape after layer 0: torch.Size([35, 64, 84, 84])
Input shape before layer 1: torch.Size([35, 64, 84, 84])
Output shape after layer 1: torch.Size([35, 64, 84, 84])
Input shape before layer 2: torch.Size([35, 64, 84, 84])
Output shape after layer 2: torch.Size([35, 64, 42, 42])
Input shape before layer 3: torch.Size([35, 64, 42, 42])
Output shape after layer 3: torch.Size([35, 64, 42, 42])
Input shape before layer 4: torch.Size([35, 64, 42, 42])
Output shape after layer 4: torch.Size([35, 64, 42, 42])
Input shape before layer 5: torch.Size([35, 64, 42, 42])
Output shape after layer 5: torch.Size([35, 64, 21, 21])
Input shape before layer 6: torch.Size([35, 64, 21, 21])
Output shape after layer 6: torch.Size([35, 64, 21, 21])
Input shape before layer 7: torch.Size([35, 64, 21, 21])
Output shape after layer 7: torch.Size([35, 64, 21, 21])
Input shape before layer 8: torch.Size([35, 64, 21, 21])
Output shape after layer 8: torc

In [12]:
# Do the actual computation

start_time = time.time()

sols = []

for c in range(n_way):
    inverse_term = ntk[c] + eps * torch.eye(n_way * n_support, device=device)
    residual = target_list[c] - phi[:, c]  # phi is of shape [n_way*n_support, n_way]

    # Solve the system (NTK_c + epsilon I_k) * result = residual
    sols.append(torch.linalg.solve(inverse_term, residual))
    
print(f"Total inversions time {time.time()-start_time}")

Total inversions time 0.09073948860168457


## Computation of $\theta - \eta_{in} \sum_{c \leq C} (s \cdot \nabla_\theta \phi_c) \times \text{Sol}_c$

In [13]:
params = {k: v for k, v in net.named_parameters()}

In [14]:
# We already have the first term computed as s_jac
start_time = time.time()
print(type(params))
# print(params["trunk.0.weight"])

# for c in range(n_way):
    # for (k, param) in params.items():
    #     params[k] = param - lr_in * torch.tensordot(s_jac[k], sols[c], dims=([0], [0]))

tensor_update = {k : sum(torch.tensordot(s_jac[k], sols[c], dims=([0], [0])) for c in range(n_way)) for k in params.keys()}
new_params = {k: param - lr_in * tensor_update[k] for k, param in params.items()}
    
print(net.trunk[0].weight[0][0])
print(list(params.values())[0][0][0][0])
print(f"Inner update time time {time.time()-start_time}")
print(type(params))
# print(params["trunk.0.weight"])
# print(list( net.named_parameters() )[0])

<class 'dict'>
tensor([[ 0.0405, -0.0636,  0.0120],
        [ 0.0635, -0.0317,  0.0555],
        [-0.0851, -0.0781,  0.0151]], device='cuda:0',
       grad_fn=<SelectBackward0>)
tensor([ 0.0405, -0.0636,  0.0120], device='cuda:0', grad_fn=<SelectBackward0>)
Inner update time time 0.004956483840942383
<class 'dict'>


# Define inner_loop function

In [14]:
def contruct_target_list(N, n_way):  # N = n_support or n_query
    target_list = list()
    samples_per_model = int(N)
    for c in range(n_way):
        target = torch.ones(N * n_way, dtype=torch.float32) * -1.0
        start_index = c * samples_per_model
        stop_index = start_index+samples_per_model
        target[start_index:stop_index] = 1.0
        target_list.append(target.cuda())
    return target_list


def fnet_single(params, x):
    return functional_call(net, params, (x.unsqueeze(0),)).squeeze(0)

def inner_loop(x_support, target_list_support):
    
    # Create a param dict
    params = {k: v for k, v in net.named_parameters()}

    for inner_epoch in range(n_inner_upd):
        # Forward pass
        phi = net.forward(x_support)

        # Compute J(x1)
        jac1 = vmap(jacrev(fnet_single), (None, 0))(params, x_support)
        s_jac = {k : s[k]*j for (k, j) in jac1.items()}   # Useful for later
        ntk_jac1 = [s_j.flatten(2) for s_j in s_jac.values()]   # Useful for the NTK computation

        # Compute J(x1) @ J(x2).T
        ntk = torch.stack([torch.einsum('Naf,Maf->aNM', j1, j2) for j1, j2 in zip(ntk_jac1, ntk_jac1)])
        ntk = ntk.sum(0)

        # Compute solutions to (NTK_c + eps I)^-1 (Y_c - phi_c)
        sols = []
        for c in range(n_way):
            inverse_term = ntk[c] + eps * torch.eye(n_way * n_support, device=device)
            residual = target_list_support[c] - phi[:, c]  # phi is of shape [n_way*n_support, n_way]

            # Solve the system (NTK_c + epsilon I_k) * result = residual
            sols.append(torch.linalg.solve(inverse_term, residual))

            # Update parameters 
            for (k, param) in params.items():
                    params[k] = param - lr_in * torch.tensordot(s_jac[k], sols[c], dims=([0], [0]))
                    
    return params

# OUTER LOOP

## Outer tools

In [12]:
# Those two are the same for every task and iteration, we compute them in advance

target_list_support = contruct_target_list(n_support, n_way)
target_list_query = contruct_target_list(n_query, n_way)

loss = nn.CrossEntropyLoss()

In [15]:
print_freq = 10
avg_loss = 0
task_count = 0
loss_all = []
optimizer = torch.optim.Adam([{'params': net.parameters(), 'lr': lr_out},
                              {'params': s.values(), 'lr': lr_out}])

for i_task in range(n_task):
    # New batch corresponding to a task
    x_support = torch.randn(n_way * n_support, 3, 84, 84, device=device)
    x_query = torch.randn(n_way * n_query, 3, 84, 84, device=device)
    
    # Inner updates
    inner_params = inner_loop(x_support, target_list_support)
    
    

RuntimeError: expected stride to be a single integer value or a list of 3 values to match the convolution dimensions, but got stride=[1, 1]