# Init the model

In [1]:
import torch
import torch.nn as nn
from torch.func import functional_call, vmap, vjp, jvp, jacrev
from torch.autograd import Variable
import torch.optim
import math
import numpy as np
import torch.nn.functional as F
device = 'cuda'

import time

In [70]:
#Redefine Conv4 here :

def init_layer(L):
    # Initialization using fan-in
    if isinstance(L, nn.Conv2d):
        n = L.kernel_size[0]*L.kernel_size[1]*L.out_channels
        L.weight.data.normal_(0,math.sqrt(2.0/float(n)))
    elif isinstance(L, nn.BatchNorm2d):
        L.weight.data.fill_(1)
        L.bias.data.fill_(0)

class Flatten(nn.Module):
    def __init__(self):
        super(Flatten, self).__init__()

    def forward(self, x):
        return x.view(x.size(0), -1)    
    
class Conv2d_fw(nn.Conv2d): #used in MAML to forward input with fast weight
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,padding=0, bias = True):
        super(Conv2d_fw, self).__init__(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=bias)
        self.weight.fast = None
        if not self.bias is None:
            self.bias.fast = None

    def forward(self, x):
        if self.bias is None:
            if self.weight.fast is not None:
                out = F.conv2d(x, self.weight.fast, None, stride= self.stride, padding=self.padding)
            else:
                out = super(Conv2d_fw, self).forward(x)
        else:
            print(self.weight)
            if self.weight.fast is not None and self.bias.fast is not None:
                out = F.conv2d(x, self.weight.fast, self.bias.fast, stride= self.stride, padding=self.padding)
            else:
                out = super(Conv2d_fw, self).forward(x)

        return out
    
class Linear_fw(nn.Linear): #used in MAML to forward input with fast weight
    def __init__(self, in_features, out_features):
        super(Linear_fw, self).__init__(in_features, out_features)
        self.weight.fast = None #Lazy hack to add fast weight link
        self.bias.fast = None

    def forward(self, x):
        if self.weight.fast is not None and self.bias.fast is not None:
            out = F.linear(x, self.weight.fast, self.bias.fast) #weight.fast (fast weight) is the temporaily adapted weight
        else:
            out = super(Linear_fw, self).forward(x)
        return out
    
class ConvNetNoBN(nn.Module):
    maml = True # Default

    def __init__(self, depth, n_way=-1, flatten=True, padding=1):
        super(ConvNetNoBN, self).__init__()
        layers = []
        for i in range(depth):
            indim = 3 if i == 0 else 64
            outdim = 64
            if self.maml:
                conv_layer = Conv2d_fw(indim, outdim, 3, padding=padding)
                # BN     = BatchNorm2d_fw(outdim)
            else:
                conv_layer = nn.Conv2d(indim, outdim, 3, stride=1, padding=padding, bias=False)
                # BN     = nn.BatchNorm2d(outdim)
            
            relu = nn.ReLU(inplace=True)
            layers.append(conv_layer)
            # layers.append(BN)
            layers.append(relu)

            if i < 4:  # Pooling only for the first 4 layers
                pool = nn.MaxPool2d(2)
                layers.append(pool)

            # Initialize the layers
            init_layer(conv_layer)
            # init_layer(BN)

        if flatten:
            layers.append(Flatten())
        
        if n_way>0:
            if self.maml:
                classifier = Linear_fw(1600, n_way)
                classifier.bias.data.fill_(0)
                layers.append(classifier)
            else:
                classifier = nn.linear(1600, n_way)
                layers.append(classifier)
            self.final_feat_dim = n_way
        else:
            self.final_feat_dim = 1600
          
        self.trunk = nn.Sequential(*layers)
        

    def forward(self, x):
        out = self.trunk(x)
        return out

def Conv4NoBN():
    print("Conv4 No Batch Normalization")
    return ConvNetNoBN(4)

def Conv4NoBN_class(n_way=5):
    print("Conv4 No Batch Normalization with final classifier layer of 5 way")
    return ConvNetNoBN(4, n_way=n_way)

# Global variables

In [71]:
n_way = 5

n_support = 7
n_query = 10

eps = 1e-4
lr_in = 1e-2
lr_out = 1e-1

In [72]:
x_support = torch.randn(n_way * n_support, 3, 84, 84, device=device)
y_support = Variable(torch.from_numpy(np.repeat(range(n_way), n_support)).cuda())
x_query = torch.randn(n_way * n_query, 3, 84, 84, device=device)
y_query = Variable(torch.from_numpy(np.repeat(range(n_way), n_query)).cuda())


net = Conv4NoBN_class().to(device)

# def net_forward(x):
#     out  = net.forward(x)
#     scores  = classifier.forward(out)
#     return scores

Conv4 No Batch Normalization with final classifier layer of 5 way


In [73]:
params = {k: v for k, v in net.named_parameters()}
s = {k: torch.ones_like(v) for (k, v) in params.items()}

In [74]:
# Initiate fast parameters

fast_parameters = list(net.parameters()) #the first gradient calcuated in line 45 is based on original weight
fast_params_dict = {k: v for k, v in net.named_parameters()}
for weight in net.parameters():
    weight.fast = None
net.zero_grad()

## INNER LOOP

## NTK computation

Note :
USE FAST WEIGHT AS THEY DID FOR MAML (nothing is supposed to pose a problem except for batch norm layers as usual)

In [75]:
def fnet_single(params, x):
    return functional_call(net, params, (x.unsqueeze(0),)).squeeze(0)

start_time = time.time()

# Compute J(x1)
jac1 = vmap(jacrev(fnet_single), (None, 0))(s, x_support)
s_jac = {k : s[k]*j for (k, j) in jac1.items()}   # Useful for later
ntk_jac1 = [s_j.flatten(2) for s_j in s_jac.values()]   # Useful for the NTK computation
    
# Compute J(x1) @ J(x2).T
ntk = torch.stack([torch.einsum('Naf,Maf->aNM', j1, j2) for j1, j2 in zip(ntk_jac1, ntk_jac1)])
ntk = ntk.sum(0)

print(f"Jacobian contraction time {time.time()-start_time}")
print(ntk.shape)

GradTrackingTensor(lvl=2, value=
    tensor([[[[1., 1., 1.],
              [1., 1., 1.],
              [1., 1., 1.]],

             [[1., 1., 1.],
              [1., 1., 1.],
              [1., 1., 1.]],

             [[1., 1., 1.],
              [1., 1., 1.],
              [1., 1., 1.]]],


            [[[1., 1., 1.],
              [1., 1., 1.],
              [1., 1., 1.]],

             [[1., 1., 1.],
              [1., 1., 1.],
              [1., 1., 1.]],

             [[1., 1., 1.],
              [1., 1., 1.],
              [1., 1., 1.]]],


            [[[1., 1., 1.],
              [1., 1., 1.],
              [1., 1., 1.]],

             [[1., 1., 1.],
              [1., 1., 1.],
              [1., 1., 1.]],

             [[1., 1., 1.],
              [1., 1., 1.],
              [1., 1., 1.]]],


            ...,


            [[[1., 1., 1.],
              [1., 1., 1.],
              [1., 1., 1.]],

             [[1., 1., 1.],
              [1., 1., 1.],
              [1., 1., 1.]

AttributeError: 'Tensor' object has no attribute 'fast'

## Computation of $\text{Sol}_c = (NTK_c + \epsilon I_k )^{-1} (Y - \phi_{\theta, c} (X))$

In [7]:
# Creation of Y

target_list = list()
samples_per_model = int(len(y_support) / n_way) #25 / 5 = 5
for way in range(n_way):
    target = torch.ones(len(y_support), dtype=torch.float32) * -1.0
    start_index = way * samples_per_model
    stop_index = start_index+samples_per_model
    target[start_index:stop_index] = 1.0
    target_list.append(target.cuda())
    
print(target_list)
print(target_list[0].shape)

[tensor([ 1.,  1.,  1.,  1.,  1.,  1.,  1., -1., -1., -1., -1., -1., -1., -1.,
        -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,
        -1., -1., -1., -1., -1., -1., -1.], device='cuda:0'), tensor([-1., -1., -1., -1., -1., -1., -1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,
        -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,
        -1., -1., -1., -1., -1., -1., -1.], device='cuda:0'), tensor([-1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,
         1.,  1.,  1.,  1.,  1.,  1.,  1., -1., -1., -1., -1., -1., -1., -1.,
        -1., -1., -1., -1., -1., -1., -1.], device='cuda:0'), tensor([-1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,
        -1., -1., -1., -1., -1., -1., -1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,
        -1., -1., -1., -1., -1., -1., -1.], device='cuda:0'), tensor([-1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,
        -1., -1., -1., -1., -1., -1., -1., -1., -

In [8]:
# Forward pass

phi = net.forward(x_support)
print(phi.shape)

torch.Size([35, 5])


In [9]:
# Do the actual computation

start_time = time.time()

sols = []

for c in range(n_way):
    inverse_term = ntk[c] + eps * torch.eye(n_way * n_support, device=device)
    residual = target_list[c] - phi[:, c]  # phi is of shape [n_way*n_support, n_way]

    # Solve the system (NTK_c + epsilon I_k) * result = residual
    sols.append(torch.linalg.solve(inverse_term, residual))
    
print(f"Total inversions time {time.time()-start_time}")

Total inversions time 0.08805727958679199


## Computation of $\theta - \eta_{in} \sum_{c \leq C} (s \cdot \nabla_\theta \phi_c) \times \text{Sol}_c$

In [10]:
# We already have the first term computed as s_jac
start_time = time.time()

for c in range(n_way):
    for (k, param) in net.named_parameters():
        param.data = param - lr_in * torch.tensordot(s_jac[k], sols[c], dims=([0], [0]))

print(f"Inner update time time {time.time()-start_time}")
# print(list( net.named_parameters() )[0])

Inner update time time 0.002782583236694336


# Define inner_loop function

In [11]:
def contruct_target_list(N, n_way):  # N = n_support or n_query
    target_list = list()
    samples_per_model = int(N)
    for c in range(n_way):
        target = torch.ones(N * n_way, dtype=torch.float32) * -1.0
        start_index = c * samples_per_model
        stop_index = start_index+samples_per_model
        target[start_index:stop_index] = 1.0
        target_list.append(target.cuda())
    return target_list


def fnet_single(params, x):
    return functional_call(net, params, (x.unsqueeze(0),)).squeeze(0)

def inner_loop(x_support, target_list_support):
    # Forward pass
    phi = net.forward(x_support)
    
    # Create a param dict
    params = {k: v for k, v in net.named_parameters()}
    
    # Compute J(x1)
    jac1 = vmap(jacrev(fnet_single), (None, 0))(params, x_support)
    s_jac = {k : s[k]*j for (k, j) in jac1.items()}   # Useful for later
    ntk_jac1 = [s_j.flatten(2) for s_j in s_jac.values()]   # Useful for the NTK computation

    # Compute J(x1) @ J(x2).T
    ntk = torch.stack([torch.einsum('Naf,Maf->aNM', j1, j2) for j1, j2 in zip(ntk_jac1, ntk_jac1)])
    ntk = ntk.sum(0)
    
    # Compute solutions to (NTK_c + eps I)^-1 (Y_c - phi_c)
    sols = []
    for c in range(n_way):
        inverse_term = ntk[c] + eps * torch.eye(n_way * n_support, device=device)
        residual = target_list_support[c] - phi[:, c]  # phi is of shape [n_way*n_support, n_way]

        # Solve the system (NTK_c + epsilon I_k) * result = residual
        sols.append(torch.linalg.solve(inverse_term, residual))
        
        # Update parameters 
        for (k, param) in net.named_parameters():
            param.data = params[k] - lr_in * torch.tensordot(s_jac[k], sols[c], dims=([0], [0]))

# OUTER LOOP

## Outer tools

In [12]:
n_inner_upd = 3
n_task = 4

In [13]:
# Those two are the same for every task and iteration, we compute them in advance

target_list_support = contruct_target_list(n_support, n_way)
target_list_query = contruct_target_list(n_query, n_way)

loss = nn.CrossEntropyLoss()

In [14]:
print_freq = 10
avg_loss = 0
task_count = 0
loss_all = []
optimizer = torch.optim.Adam([{'params': net.parameters(), 'lr': lr_out},
                              {'params': s.values(), 'lr': lr_out}])

for i_task in range(n_task):
    # New batch corresponding to a task
    x_support = torch.randn(n_way * n_support, 3, 84, 84, device=device)
    x_query = torch.randn(n_way * n_query, 3, 84, 84, device=device)
    
    # Inner updates
    for i_inner_upd in range(n_inner_upd):
        print(i_inner_upd)
        inner_loop(x_support, target_list_support)
    
    

0


RuntimeError: expected stride to be a single integer value or a list of 3 values to match the convolution dimensions, but got stride=[1, 1]