# Init the model

In [2]:
import torch
import torch.nn as nn
from torch.func import functional_call, vmap, vjp, jvp, jacrev
from torch.autograd import Variable
import torch.optim
import math
import numpy as np
device = 'cuda'

import time

In [3]:
#Redefine Conv4 here :

def init_layer(L):
    """Re-initialise a convolution or batch-norm layer in place.

    * ``nn.Conv2d``: weights are drawn from a zero-mean normal distribution
      with standard deviation ``sqrt(2 / n)``, where
      ``n = kernel_height * kernel_width * out_channels``.
    * ``nn.BatchNorm2d``: the scale is set to 1 and the shift to 0.

    Layers of any other type are left untouched. Returns ``None``.
    """
    if isinstance(L, nn.BatchNorm2d):
        L.weight.data.fill_(1)
        L.bias.data.fill_(0)
        return

    if isinstance(L, nn.Conv2d):
        kernel_h, kernel_w = L.kernel_size[0], L.kernel_size[1]
        n = kernel_h * kernel_w * L.out_channels
        std = math.sqrt(2.0 / float(n))
        L.weight.data.normal_(0, std)

class Flatten(nn.Module):
    """Module that collapses every dimension after the first into a single one."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Keep the leading (batch) dimension, merge all remaining ones.
        batch_size = x.size(0)
        return x.view(batch_size, -1)
    
class ConvNetNoBN(nn.Module):
    """Convolutional backbone built from 3x3 conv -> ReLU blocks, without batch norm.

    Args:
        depth: number of conv blocks. The first block maps 3 -> 64 channels,
            every later block maps 64 -> 64 channels.
        n_way: when positive, a final ``nn.Linear(1600, n_way)`` classifier is
            appended and ``final_feat_dim`` is ``n_way``; otherwise no
            classifier is added and ``final_feat_dim`` is 1600.
        flatten: when true, a ``Flatten`` module follows the conv blocks.
        padding: padding passed to every convolution.

    Notes:
        * A 2x2 max-pool follows each of the first four blocks only.
        * 1600 is hard-coded. It equals 64 * 5 * 5, which is what four pooled
          blocks with ``padding=1`` produce for the 84x84 inputs used in this
          notebook.
        * When the class attribute ``maml`` is true the convolutions are built
          with ``Conv2d_fw``, which is not defined in the visible part of the
          notebook.
    """

    maml = False  # Default

    def __init__(self, depth, n_way=-1, flatten=True, padding=1):
        super().__init__()

        modules = []
        for block_idx in range(depth):
            in_channels = 64 if block_idx > 0 else 3
            out_channels = 64

            if not self.maml:
                conv = nn.Conv2d(in_channels, out_channels, 3, padding=padding, bias=False)
            else:
                conv = Conv2d_fw(in_channels, out_channels, 3, padding=padding)
            # Overwrite the default initialisation of the convolution.
            init_layer(conv)

            modules += [conv, nn.ReLU(inplace=True)]
            if block_idx < 4:  # Pooling only for the first 4 layers
                modules.append(nn.MaxPool2d(2))

        if flatten:
            modules.append(Flatten())

        has_classifier = n_way > 0
        if has_classifier:
            modules.append(nn.Linear(1600, n_way))
        self.final_feat_dim = n_way if has_classifier else 1600

        self.trunk = nn.Sequential(*modules)

    def forward(self, x):
        """Run ``x`` through the sequential trunk and return its output."""
        return self.trunk(x)

def Conv4NoBN():
    """Announce and build the 4-block backbone without a classifier head."""
    print("Conv4 No Batch Normalization")
    backbone = ConvNetNoBN(4)
    return backbone

def Conv4NoBN_class(n_way=5):
    """Announce and build the 4-block backbone with an ``n_way``-output linear classifier.

    Args:
        n_way: number of outputs of the final linear layer (default 5).

    Returns:
        A ``ConvNetNoBN`` of depth 4 with the classifier head attached.
    """
    # Report the requested number of ways rather than a hard-coded "5";
    # for the default argument the message is unchanged.
    print(f"Conv4 No Batch Normalization with final classifier layer of {n_way} way")
    return ConvNetNoBN(4, n_way=n_way)

# Global variables

In [4]:
# Number of classes per task (also the number of network outputs).
n_way = 5

# Support / query examples per class.
n_support = 7
n_query = 10

# Iterations of the inner loop, and number of tasks whose query losses are
# summed in the outer loop.
n_inner_upd = 3
n_task = 4

# eps: multiple of the identity added to each per-class NTK before solving.
# lr_in: step size of the inner parameter update.
# lr_out: learning rate given to the outer Adam optimizer.
eps = 1e-4
lr_in = 1e-2
lr_out = 1e-1

In [5]:
# Random stand-in images for one task: n_way classes, 3x84x84 each.
# Labels are 0..n_way-1, each repeated once per example of that class. They are
# built directly on `device` as int64 tensors (no deprecated `Variable`
# wrapper, no hard-coded `.cuda()`, no platform-dependent numpy integer dtype).
x_support = torch.randn(n_way * n_support, 3, 84, 84, device=device)
y_support = torch.arange(n_way, device=device).repeat_interleave(n_support)
x_query = torch.randn(n_way * n_query, 3, 84, 84, device=device)
y_query = torch.arange(n_way, device=device).repeat_interleave(n_query)


net = Conv4NoBN_class().to(device)

Conv4 No Batch Normalization with final classifier layer of 5 way


In [6]:
# Name -> parameter mapping that references the network's own parameter tensors.
params = dict(net.named_parameters())
# Per-parameter element-wise scaling, initialised to ones. It is handed to the
# outer Adam optimizer as its own parameter group, so it must require gradients;
# without `requires_grad=True` that group never receives a gradient and the
# scaling can never be learned.
s = {k: torch.ones_like(v, requires_grad=True) for (k, v) in params.items()}

# Sanity check: the dict entry and the module attribute show the same values.
print("params dict is equal to the net's params \n")
print(net.trunk[0].weight[0][0])
print(list(params.values())[0][0][0])

params dict is equal to the net's params 

tensor([[-0.0031,  0.0121, -0.0709],
        [ 0.0624,  0.0210, -0.0120],
        [ 0.0653, -0.0405, -0.0496]], device='cuda:0',
       grad_fn=<SelectBackward0>)
tensor([[-0.0031,  0.0121, -0.0709],
        [ 0.0624,  0.0210, -0.0120],
        [ 0.0653, -0.0405, -0.0496]], device='cuda:0',
       grad_fn=<SelectBackward0>)


In [7]:
# Sanity check: show the stride of the first module of the trunk (the first convolution).
print(net.trunk[0].stride)

(1, 1)


In [8]:
# Stateless forward pass of the support batch: the network is evaluated with the
# explicit `params` dict instead of its registered parameters.
output = functional_call(net.to(device), params, (x_support,))

## INNER LOOP

## NTK computation

Note:
Use fast weights, as was done for MAML (nothing is expected to cause problems, except batch-norm layers as usual).

In [9]:
def fnet_single(params, x):
    """Evaluate `net` on one un-batched sample `x` using the explicit `params`.

    A leading axis of size 1 is added before the stateless call and removed
    from the result again.
    """
    batched_input = x.unsqueeze(0)
    batched_output = functional_call(net, params, (batched_input,))
    return batched_output.squeeze(0)

# CUDA kernels are launched asynchronously: synchronise before starting and
# before stopping the clock so the measured time covers the actual computation.
if torch.cuda.is_available():
    torch.cuda.synchronize()
start_time = time.time()

# Compute J(x1): per-sample Jacobian of the network outputs w.r.t. every parameter.
# Each entry has shape (n_samples, n_way, *param_shape).
jac1 = vmap(jacrev(fnet_single), (None, 0))(params, x_support)

# Jacobians scaled element-wise by s; kept because the parameter update reuses them.
s_jac = {k : s[k]*j for (k, j) in jac1.items()}
# Same Jacobians with all parameter axes flattened into one, for the contraction below.
ntk_jac1 = [s_j.flatten(2) for s_j in s_jac.values()]

# Compute J(x1) @ J(x1).T separately for each output (class) and each parameter
# tensor, then sum the contributions of all parameter tensors.
ntk = torch.stack([torch.einsum('Naf,Maf->aNM', j, j) for j in ntk_jac1])
ntk = ntk.sum(0)

if torch.cuda.is_available():
    torch.cuda.synchronize()
print(f"Jacobian contraction time {time.time()-start_time}")
print(ntk.shape)

Jacobian contraction time 0.16933584213256836
torch.Size([5, 35, 35])


## Computation of $\text{Sol}_c = (NTK_c + \epsilon I_k )^{-1} (Y - \phi_{\theta, c} (X))$

In [10]:
# Creation of Y

# One target vector per class: +1 on the support samples of that class, -1 elsewhere.
# The samples are assumed to be ordered class by class, as y_support is.
target_list = list()
samples_per_model = len(y_support) // n_way  # 35 // 5 = 7 with the settings above
for way in range(n_way):
    # Allocate directly on `device` instead of building on the CPU and calling .cuda().
    target = torch.full((len(y_support),), -1.0, dtype=torch.float32, device=device)
    start_index = way * samples_per_model
    stop_index = start_index + samples_per_model
    target[start_index:stop_index] = 1.0
    target_list.append(target)

print(target_list)
print(target_list[0].shape)

[tensor([ 1.,  1.,  1.,  1.,  1.,  1.,  1., -1., -1., -1., -1., -1., -1., -1.,
        -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,
        -1., -1., -1., -1., -1., -1., -1.], device='cuda:0'), tensor([-1., -1., -1., -1., -1., -1., -1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,
        -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,
        -1., -1., -1., -1., -1., -1., -1.], device='cuda:0'), tensor([-1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,
         1.,  1.,  1.,  1.,  1.,  1.,  1., -1., -1., -1., -1., -1., -1., -1.,
        -1., -1., -1., -1., -1., -1., -1.], device='cuda:0'), tensor([-1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,
        -1., -1., -1., -1., -1., -1., -1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,
        -1., -1., -1., -1., -1., -1., -1.], device='cuda:0'), tensor([-1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,
        -1., -1., -1., -1., -1., -1., -1., -1., -

In [11]:
# Forward pass

# Evaluate the network on the whole support batch with the explicit parameter dict.
# The printed shape is (n_way * n_support, n_way), i.e. (35, 5) with the settings above.
phi = functional_call(net, params, (x_support,))
print(phi.shape)

torch.Size([35, 5])


In [12]:
# Do the actual computation

# Synchronise around the timed region so asynchronous CUDA work is included.
if torch.cuda.is_available():
    torch.cuda.synchronize()
start_time = time.time()

# eps * I does not depend on the class, so build it once outside the loop.
reg_identity = eps * torch.eye(n_way * n_support, device=device)

sols = []

for c in range(n_way):
    inverse_term = ntk[c] + reg_identity
    residual = target_list[c] - phi[:, c]  # phi is of shape [n_way*n_support, n_way]

    # Solve the system (NTK_c + epsilon I_k) * result = residual
    sols.append(torch.linalg.solve(inverse_term, residual))

if torch.cuda.is_available():
    torch.cuda.synchronize()
print(f"Total inversions time {time.time()-start_time}")

Total inversions time 0.13020730018615723


## Computation of $\theta + \eta_{in} \, s \cdot \sum_{c \leq C} (s \cdot \nabla_\theta \phi_c)^\top \text{Sol}_c$

In [13]:
# Rebuild the name -> parameter mapping from the network, so the update below
# starts from the network's own parameters.
params = {k: v for k, v in net.named_parameters()}

In [14]:
# We already have the first term computed as s_jac
start_time = time.time()
print(type(params))
# print(params["trunk.0.weight"])

# for c in range(n_way):
    # for (k, param) in params.items():
    #     params[k] = param - lr_in * torch.tensordot(s_jac[k], sols[c], dims=([0], [0]))


# Shape checks for the contraction used in the update:
# sols[c] has one entry per support sample, s_jac[k][:, c] has shape
# (n_samples, *param_shape); contracting the sample axis gives a param-shaped tensor.
print(sols[0].shape)
print(s_jac["trunk.0.weight"][:, 0].shape)
print(torch.tensordot(s_jac["trunk.0.weight"][:,0], sols[0], dims=([0], [0])).shape)
print(sum(torch.tensordot(s_jac["trunk.0.weight"][:,c], sols[c], dims=([0], [0])) for c in range(n_way)).shape)

# Per-parameter update direction: sum over classes of the contraction above,
# multiplied once more by s[k] (s_jac already contains one factor of s[k]).
tensor_update = {k : s[k] * sum(torch.tensordot(s_jac[k][:,c], sols[c], dims=([0], [0])) for c in range(n_way)) for k in params.keys()}
# Build a NEW dict of updated tensors; the update is added, scaled by lr_in.
# The network's registered parameters are not modified, which is what the print
# of net.trunk[0].weight below lets you check.
params = {k: param + lr_in * tensor_update[k] for k, param in params.items()}

print(net.trunk[0].weight[0][0])
# print(list(new_params.values())[0].shape)
print(list(params.values())[0].shape)
print(f"Inner update time time {time.time()-start_time}")
print(type(params))
# print(params["trunk.0.weight"])
# print(list( net.named_parameters() )[0])

<class 'dict'>
torch.Size([35])
torch.Size([35, 64, 3, 3, 3])
torch.Size([64, 3, 3, 3])
torch.Size([64, 3, 3, 3])
tensor([[-0.0031,  0.0121, -0.0709],
        [ 0.0624,  0.0210, -0.0120],
        [ 0.0653, -0.0405, -0.0496]], device='cuda:0',
       grad_fn=<SelectBackward0>)
torch.Size([64, 3, 3, 3])
Inner update time time 0.004975795745849609
<class 'dict'>


# Define inner_loop function

In [15]:
def contruct_target_list(N, n_way, device='cuda'):  # N = n_support or n_query
    """Build one +1/-1 target vector per class.

    Samples are assumed to be ordered class by class, ``N`` per class. Vector
    ``c`` has length ``N * n_way`` and holds ``+1`` at the positions of class
    ``c`` and ``-1`` everywhere else.

    Args:
        N: number of samples per class.
        n_way: number of classes.
        device: device on which the targets are allocated. Defaults to
            ``'cuda'``, matching the previous hard-coded ``.cuda()`` call.

    Returns:
        A list of ``n_way`` float32 tensors of shape ``(N * n_way,)``.
    """
    samples_per_model = int(N)
    total = samples_per_model * n_way

    target_list = []
    for c in range(n_way):
        # Allocate on the requested device directly rather than on the CPU first.
        target = torch.full((total,), -1.0, dtype=torch.float32, device=device)
        start_index = c * samples_per_model
        target[start_index:start_index + samples_per_model] = 1.0
        target_list.append(target)
    return target_list


def fnet_single(params, x):
    """Single-sample forward pass of `net` with explicit parameters.

    `x` is one sample without a batch axis: a size-1 batch axis is added for the
    stateless call and stripped from the output. Written this way so it can be
    differentiated with `jacrev` and mapped over a batch with `vmap`.
    """
    single_batch = x.unsqueeze(0)
    prediction = functional_call(net, params, (single_batch,))
    return prediction.squeeze(0)

def inner_loop(x_support, target_list_support):
    """Run `n_inner_upd` NTK-based updates starting from the network's parameters.

    Each iteration:
      1. evaluates the network on the support batch,
      2. builds the per-class NTK from the per-sample Jacobians scaled by `s`,
      3. solves ``(NTK_c + eps * I) sol_c = Y_c - phi_c`` for every class ``c``,
      4. adds ``lr_in * s * sum_c (s * J_c)^T sol_c`` to the parameters.

    Args:
        x_support: support inputs; the first axis indexes samples.
        target_list_support: one target vector per class, each with one entry
            per support sample.

    Returns:
        dict mapping parameter names to the updated tensors. The network's
        registered parameters are not modified.

    Reads the notebook-level names `net`, `s`, `n_inner_upd`, `n_way`, `eps`,
    `lr_in` and `fnet_single`.
    """
    # Create a param dict
    params = {k: v for k, v in net.named_parameters()}

    # eps * I is the same for every class and every iteration. Its size and
    # device come from the batch actually passed in rather than from the global
    # n_support / device, so the function also works for other support sizes.
    n_samples = x_support.shape[0]
    reg_identity = eps * torch.eye(n_samples, device=x_support.device)

    for inner_epoch in range(n_inner_upd):
        # Forward pass
        phi = functional_call(net, params, (x_support,))

        # Compute J(x1), scaled element-wise by s
        jac1 = vmap(jacrev(fnet_single), (None, 0))(params, x_support)
        s_jac = {k: s[k] * j for (k, j) in jac1.items()}   # reused in the update below
        ntk_jac1 = [s_j.flatten(2) for s_j in s_jac.values()]

        # Compute J(x1) @ J(x1).T per class, summed over all parameter tensors
        ntk = torch.stack([torch.einsum('Naf,Maf->aNM', j, j) for j in ntk_jac1])
        ntk = ntk.sum(0)

        # Solve (NTK_c + eps I) * sol_c = Y_c - phi_c for every class;
        # phi is of shape [n_samples, n_way]
        sols = [
            torch.linalg.solve(ntk[c] + reg_identity, target_list_support[c] - phi[:, c])
            for c in range(n_way)
        ]

        # Update parameters
        tensor_update = {
            k: s[k] * sum(torch.tensordot(s_jac[k][:, c], sols[c], dims=([0], [0])) for c in range(n_way))
            for k in params.keys()
        }
        params = {k: param + lr_in * tensor_update[k] for k, param in params.items()}

    return params

# OUTER LOOP

## Outer tools

In [16]:
# Those two are the same for every task and iteration, we compute them in advance

# Per-class +1/-1 target vectors for the support and query sets.
target_list_support = contruct_target_list(n_support, n_way)
target_list_query = contruct_target_list(n_query, n_way)

# Loss applied to the query scores in the outer loop.
loss_fn = nn.CrossEntropyLoss()

In [17]:
print_freq = 10
avg_loss = 0
task_count = 0
loss_all = []
# Two parameter groups: the network weights and the per-parameter scaling s.
optimizer = torch.optim.Adam([{'params': net.parameters(), 'lr': lr_out},
                              {'params': s.values(), 'lr': lr_out}])

start_time = time.time()

for i_task in range(n_task):
    # New batch corresponding to a task (labels y_query are reused: the class
    # layout is identical for every task)
    x_support = torch.randn(n_way * n_support, 3, 84, 84, device=device)
    x_query = torch.randn(n_way * n_query, 3, 84, 84, device=device)

    # Inner updates
    inner_params = inner_loop(x_support, target_list_support)

    # outer optimization: loss of the adapted parameters on the query set
    scores = functional_call(net, inner_params, (x_query,))
    loss = loss_fn( scores, y_query )
    loss_all.append(loss)

# One optimizer step on the sum of the query losses of all tasks.
loss_q = torch.stack(loss_all).sum(0)
loss_q.backward()
# `optimizer.step` without parentheses only referenced the method and never
# applied the update; it has to be called.
optimizer.step()
optimizer.zero_grad()

# Wait for asynchronous CUDA work before reading the clock.
if torch.cuda.is_available():
    torch.cuda.synchronize()
print(f"Total time for one iter over {n_task} tasks : {time.time() - start_time}")

Total time for one iter over 4 tasks : 0.9371540546417236


# TEST LOOP

## Test parameters

The test loop is directly copied from differentialDKTIX

/!\ We will not redefine n_support, n_query or the optimizer lr here.
These quantities will not be the same as in the training procedure, but they will be redefined in the test script. This works because we reconstruct a new neural net in the meta-test script, onto which we paste the new values of the above hyperparameters.

# Cholesky safe decomposition

In [18]:
class NanError(RuntimeError):
    """Raised by :func:`psd_safe_cholesky` when the input matrix contains NaNs."""


def _cholesky_single_or_batched(A, upper, out):
    """Factor a 2-D matrix directly, or a batch one matrix at a time along dim 0."""
    if A.dim() == 2:
        return torch.linalg.cholesky(A, upper=upper, out=out)
    factors = [torch.linalg.cholesky(A[idx], upper=upper, out=out) for idx in range(A.shape[0])]
    return torch.stack(factors, dim=0)


def psd_safe_cholesky(A, upper=False, out=None, jitter=None):
    """Compute the Cholesky decomposition of A. If A is only p.s.d, add a small jitter to the diagonal.

    The plain factorisation is tried first. If it fails, increasing amounts of
    jitter (``jitter * 10**i`` for ``i = 0..7``) are added to the diagonal of a
    copy of ``A`` until the factorisation succeeds; a ``RuntimeWarning`` then
    reports the jitter that was used.

    Args:
        :attr:`A` (Tensor):
            The tensor to compute the Cholesky decomposition of
        :attr:`upper` (bool, optional):
            See torch.linalg.cholesky
        :attr:`out` (Tensor, optional):
            See torch.linalg.cholesky
        :attr:`jitter` (float, optional):
            The jitter to add to the diagonal of A in case A is only p.s.d. If omitted, chosen
            as 1e-6 (float) or 1e-8 (double)

    Returns:
        The Cholesky factor (stacked along dim 0 for batched input).

    Raises:
        NanError: if the factorisation fails and ``A`` contains NaNs.
        RuntimeError: if the factorisation still fails with the largest jitter.
    """
    # Local import: `warnings` is not imported by the notebook's import cell.
    # Previously the missing name raised a NameError that the bare `except:`
    # swallowed, so the jitter path could never return a factor.
    import warnings

    # Only factorisation failures are handled (torch's linear-algebra errors
    # derive from RuntimeError); anything else propagates instead of being hidden.
    try:
        return _cholesky_single_or_batched(A, upper, out)
    except RuntimeError as err:
        last_error = err

    isnan = torch.isnan(A)
    if isnan.any():
        raise NanError(
            f"cholesky_cpu: {isnan.sum().item()} of {A.numel()} elements of the {A.shape} tensor are NaN."
        )

    if jitter is None:
        jitter = 1e-6 if A.dtype == torch.float32 else 1e-8
    Aprime = A.clone()
    jitter_prev = 0
    for i in range(8):
        jitter_new = jitter * (10 ** i)
        # Add only the difference so the diagonal carries exactly jitter_new in total.
        Aprime.diagonal(dim1=-2, dim2=-1).add_(jitter_new - jitter_prev)
        jitter_prev = jitter_new
        try:
            L = _cholesky_single_or_batched(Aprime, upper, out)
        except RuntimeError as err:
            last_error = err
            continue
        warnings.warn(
            f"A not p.d., added jitter of {jitter_new} to the diagonal",
            RuntimeWarning,
        )
        return L

    # Every jitter level failed: report it instead of silently returning None.
    raise RuntimeError(
        f"Matrix not positive definite after repeatedly adding jitter up to {jitter_prev}."
    ) from last_error


In [29]:
def solve_using_cholesky(A, X):
    """Solve ``A @ Y = X`` through a Cholesky factorisation of ``A``.

    Args:
        A: square 2-D matrix, factorised with ``psd_safe_cholesky``.
        X: 1-D right-hand side with one entry per row of ``A``.

    Returns:
        1-D tensor ``Y`` of the same length as ``X``.
    """
    # Step 1: Perform Cholesky decomposition to get L (lower triangular, A = L @ L^T)
    L = psd_safe_cholesky(A)

    # Debug output of the factor, kept so the notebook output stays the same.
    print(type(L))
    print(L.shape)
    print(L)

    # The triangular solvers expect a matrix right-hand side: make X a column.
    X = X.unsqueeze(1)

    # Step 2: Solve L * Z = X for Z using forward substitution
    Z = torch.linalg.solve_triangular(L, X, upper=False)  # 'upper=False' indicates L is lower triangular

    # Step 3: Solve L^T * Y = Z for Y using backward substitution
    Y = torch.linalg.solve_triangular(L.mT, Z, upper=True)  # 'upper=True' indicates L^T is upper triangular

    # Drop only the column axis added above. A bare .squeeze() would also
    # collapse a 1x1 system to a 0-d tensor.
    return Y.squeeze(1)

# Example usage
# Small symmetric 2x2 system used to check the Cholesky-based solver by hand.
A = torch.tensor([[4.0, 2.0], [2.0, 3.0]], dtype=torch.float32)
X = torch.tensor([1.0, 2.0], dtype=torch.float32)

# Solve A * Y = X using Cholesky decomposition
# (the function also prints the type, shape and values of the factor L)
Y = solve_using_cholesky(A, X)

print("Solution Y:", Y)

<class 'torch.Tensor'>
torch.Size([2, 2])
tensor([[2.0000, 0.0000],
        [1.0000, 1.4142]])
Solution Y: tensor([-0.1250,  0.7500])


In [30]:
# Solve AY = X using torch.linalg.solve
# Reference solution for the same A and X as the previous cell; both cells
# should print the same vector.
Y = torch.linalg.solve(A, X)

print("Solution Y:", Y)

Solution Y: tensor([-0.1250,  0.7500])


In [33]:
# Define a 3x3 positive semi-definite matrix A
A = torch.tensor([[4.0, 2.0, 2.0], 
                  [2.0, 3.0, 1.0], 
                  [2.0, 1.0, 2.0]], dtype=torch.float32)

# Ensure A is symmetric and positive semi-definite
A = A @ A.T  # This operation makes A symmetric and PSD

# Define a 3-dimensional vector X
X = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32)

# Solve A * Y = X using Cholesky decomposition
Y = solve_using_cholesky(A, X)
print("Solution Cholesky Y:", Y)

# Solve AY = X using torch.linalg.solve
# (reference result; the two printed solutions should agree up to float32 rounding)
Y = torch.linalg.solve(A, X)
print("Solution torch.linalg.solve Y:", Y)

<class 'torch.Tensor'>
torch.Size([3, 3])
tensor([[ 4.8990,  0.0000,  0.0000],
        [ 3.2660,  1.8257,  0.0000],
        [ 2.8577, -0.1826,  0.8944]])
Solution Cholesky Y: tensor([-2.2969,  0.7187,  3.1875])
Solution torch.linalg.solve Y: tensor([-2.2969,  0.7188,  3.1875])


In [35]:
# Define a 7x7 positive semi-definite matrix A
A = torch.tensor([
    [4.0, 2.0, 1.0, 3.0, 2.0, 1.0, 1.0], 
    [2.0, 5.0, 2.0, 2.0, 1.0, 0.5, 1.0], 
    [1.0, 2.0, 6.0, 1.0, 2.0, 1.0, 0.5], 
    [3.0, 2.0, 1.0, 7.0, 3.0, 2.0, 1.0], 
    [2.0, 1.0, 2.0, 3.0, 8.0, 2.0, 1.0], 
    [1.0, 0.5, 1.0, 2.0, 2.0, 9.0, 2.0],
    [1.0, 1.0, 0.5, 1.0, 1.0, 2.0, 10.0]
], dtype=torch.float32)

# Ensure A is symmetric and positive semi-definite
A = A @ A.T  # This operation makes A symmetric and PSD

# Define a 7-dimensional vector X
X = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0], dtype=torch.float32)

# Solve A * Y = X using Cholesky decomposition
Y = solve_using_cholesky(A, X)
print("Solution Cholesky Y:", Y)

# Solve AY = X using torch.linalg.solve
# (reference result for the same system; the two printed solutions should agree)
Y = torch.linalg.solve(A, X)
print("Solution linalg.solve Y:", Y)

<class 'torch.Tensor'>
torch.Size([7, 7])
tensor([[ 6.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 4.9167,  3.8828,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 3.7500,  2.7203,  5.0781,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 7.8333, -0.3899,  0.0366,  3.9352,  0.0000,  0.0000,  0.0000],
        [ 6.6667, -0.9729,  2.9827,  1.8524,  5.4112,  0.0000,  0.0000],
        [ 4.5000, -0.8048,  1.8342,  2.3809,  1.0785,  8.0098,  0.0000],
        [ 3.9167,  0.7065, -0.1200, -0.2289,  0.8054,  3.3509,  8.9703]])
Solution Cholesky Y: tensor([-0.3612,  0.0845,  0.0217,  0.1474,  0.0540,  0.0192,  0.0603])
Solution linalg.solve Y: tensor([-0.3612,  0.0845,  0.0217,  0.1474,  0.0540,  0.0192,  0.0603])
