# 5-Layers Specialisation

## 1st way - 1 shot

In [1]:
## Original packages
import backbone
import torch
import torch.nn as nn
from torch.autograd import Variable
import numpy as np
import torch.nn.functional as F
from torch.func import functional_call, vmap, vjp, jvp, jacrev
from methods.meta_template import MetaTemplate
import math
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, CosineAnnealingLR
import warnings
from torch.distributions import MultivariateNormal
import warnings

In [2]:
class simple_netC_0hl(nn.Module):
    def __init__(self):
        super(simple_netC_0hl, self).__init__()
        self.layer1 = nn.Linear(1600, 5)
        
    def forward(self, x):
        out = self.layer1(x)
        return out

net = simple_netC_0hl()

In [3]:
c=0

def compute_jacobian(inputs):   # i is the class label, and corresponds to the output targeted
    """
    Return the jacobian of a batch of inputs, thanks to the vmap functionality
    """
    net.zero_grad()
    params = {k: v for k, v in net.named_parameters()}

    def fnet_single(params, x):
        # Make sure output has the right dimensions
        return functional_call(net, params, (x.unsqueeze(0),)).squeeze(0)[c]

    jac = vmap(jacrev(fnet_single), (None, 0))(params, inputs)
    jac_values = jac.values()

    reshaped_tensors = []
    for j in jac_values:
        if len(j.shape) == 3:  # For layers with weights
            # Flatten parameters dimensions and then reshape
            flattened = j.flatten(start_dim=1)  # Flattens to [batch, params]
            reshaped = flattened.T  # Transpose to align dimensions as [params, batch]
            reshaped_tensors.append(reshaped)
        elif len(j.shape) == 2:  # For biases or single parameter components
            reshaped_tensors.append(j.T)  # Simply transpose

    # Concatenate all the reshaped tensors into one large matrix
    return torch.cat(reshaped_tensors, dim=0).T

In [4]:
x = torch.empty(3, 1600)

# Fill each row of the tensor with the row index
for i in range(3):
    x[i] = i

print(x)

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [1., 1., 1.,  ..., 1., 1., 1.],
        [2., 2., 2.,  ..., 2., 2., 2.]])


In [5]:
print(compute_jacobian(x))

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [1., 1., 1.,  ..., 0., 0., 0.],
        [2., 2., 2.,  ..., 0., 0., 0.]])


In [6]:
# Constructing the specialisation matrix
with torch.no_grad():
    spe = net(x)
print(spe.shape)
print(spe)

torch.Size([3, 5])
tensor([[ 0.0219,  0.0210, -0.0219, -0.0019, -0.0085],
        [ 0.4146, -0.1063,  0.2000,  0.7586,  0.7553],
        [ 0.8073, -0.2337,  0.4219,  1.5190,  1.5191]])


In [7]:
# Generate a single random number between 0 and n_classes
random_class = torch.randint(low=0, high=5, size=(1,))

print(random_class)

tensor([4])


In [8]:
col = spe[:, random_class].flatten()
print(col)

tensor([-0.0085,  0.7553,  1.5191])


In [9]:
with torch.no_grad():
    softmax_col = F.softmax(col, dim=0)

print(softmax_col)

tensor([0.1290, 0.2768, 0.5942])


In [10]:
# Use multinomial to pick an index based on the weights
# The second argument 'num_samples' is the number of indices to sample
# 'replacement=True' allows picking the same index more than once if num_samples > 1
random_index = torch.multinomial(softmax_col, num_samples=1, replacement=True)

print(random_index)


tensor([2])


In [11]:
print(f"So input number {random_index[0]} will have class number {random_class[0]}")

So input number 2 will have class number 4


In [12]:
#Now updating the matrix :
i = random_index[0]
j = random_class[0]
# Remove the ith row
new_spe = torch.cat((spe[:i], spe[i+1:]))

# Remove the jth column
new_spe = torch.cat((new_spe[:, :j], new_spe[:, j+1:]), dim=1)

print("Modified Matrix:\n", new_spe)

Modified Matrix:
 tensor([[ 0.0219,  0.0210, -0.0219, -0.0019],
        [ 0.4146, -0.1063,  0.2000,  0.7586]])


In [13]:
#repeat the process

x = torch.empty(5, 1600)
for i in range(5):
    x[i] = i
    
# Constructing the specialisation matrix
with torch.no_grad():
    spe = net(x)

classes = torch.tensor([0, 1, 2, 3, 4])
for _ in range(5):
    # Pick a class randomly with equal probability
    random_class = classes[torch.randint(low=0, high=len(classes), size=(1,))]
    col = spe[:, random_class].flatten()
    with torch.no_grad():
        softmax_col = F.softmax(col, dim=0)
    random_index = torch.multinomial(softmax_col, num_samples=1, replacement=True)
    
    print(f"Input number {random_index[0]} will have class number {random_class[0]}")
    
    i = random_index[0]
    j = random_class[0]
    
    # Remove the ith row
    spe[i] = float('-inf')
    
    # can't pick the jth class anymore
    mask = classes != random_class
    classes = classes[mask]
    

Input number 1 will have class number 3
Input number 4 will have class number 4
Input number 2 will have class number 1
Input number 0 will have class number 2
Input number 3 will have class number 0


## Multiple shot

In [14]:
n_shot = 3
# For this example, all the inputs of the same class are next to one another
x = torch.empty(n_shot*5, 1600)

# Fill each row of the tensor with the row index
for i in range(n_shot*5):
    x[i] = i

print(x)

tensor([[ 0.,  0.,  0.,  ...,  0.,  0.,  0.],
        [ 1.,  1.,  1.,  ...,  1.,  1.,  1.],
        [ 2.,  2.,  2.,  ...,  2.,  2.,  2.],
        ...,
        [12., 12., 12.,  ..., 12., 12., 12.],
        [13., 13., 13.,  ..., 13., 13., 13.],
        [14., 14., 14.,  ..., 14., 14., 14.]])


In [15]:
# Constructing the specialisation matrix
with torch.no_grad():
    spe = net(x)
print(spe.shape)
print(spe)

torch.Size([15, 5])
tensor([[ 2.1855e-02,  2.1019e-02, -2.1882e-02, -1.8952e-03, -8.4974e-03],
        [ 4.1456e-01, -1.0635e-01,  2.0000e-01,  7.5858e-01,  7.5529e-01],
        [ 8.0727e-01, -2.3371e-01,  4.2189e-01,  1.5190e+00,  1.5191e+00],
        [ 1.2000e+00, -3.6108e-01,  6.4377e-01,  2.2795e+00,  2.2829e+00],
        [ 1.5927e+00, -4.8845e-01,  8.6565e-01,  3.0400e+00,  3.0466e+00],
        [ 1.9854e+00, -6.1581e-01,  1.0875e+00,  3.8005e+00,  3.8104e+00],
        [ 2.3781e+00, -7.4318e-01,  1.3094e+00,  4.5609e+00,  4.5742e+00],
        [ 2.7708e+00, -8.7054e-01,  1.5313e+00,  5.3214e+00,  5.3380e+00],
        [ 3.1635e+00, -9.9791e-01,  1.7532e+00,  6.0819e+00,  6.1018e+00],
        [ 3.5562e+00, -1.1253e+00,  1.9751e+00,  6.8423e+00,  6.8656e+00],
        [ 3.9489e+00, -1.2526e+00,  2.1970e+00,  7.6028e+00,  7.6294e+00],
        [ 4.3416e+00, -1.3800e+00,  2.4188e+00,  8.3633e+00,  8.3931e+00],
        [ 4.7343e+00, -1.5074e+00,  2.6407e+00,  9.1237e+00,  9.1569e+00],
     

In [16]:
# Reshape the matrix to (K/n_shot, n_shot, C)
reshape_spe = spe.view(5, n_shot, 5)

# Compute the mean along the middle dimension
spe = reshape_spe.mean(dim=1)

print(spe)

tensor([[ 0.4146, -0.1063,  0.2000,  0.7586,  0.7553],
        [ 1.5927, -0.4884,  0.8657,  3.0400,  3.0466],
        [ 2.7708, -0.8705,  1.5313,  5.3214,  5.3380],
        [ 3.9489, -1.2526,  2.1970,  7.6028,  7.6293],
        [ 5.1270, -1.6347,  2.8626,  9.8842,  9.9207]])


In [17]:
classes = torch.tensor([0, 1, 2, 3, 4])
for _ in range(5):
    # Pick a class randomly with equal probability
    random_class = classes[torch.randint(low=0, high=len(classes), size=(1,))]
    col = spe[:, random_class].flatten()
    with torch.no_grad():
        softmax_col = F.softmax(col, dim=0)
    random_index = torch.multinomial(softmax_col, num_samples=1, replacement=True)
    
    print(f"Input number {random_index[0]} will have class number {random_class[0]}")
    
    i = random_index[0]
    j = random_class[0]
    
    # Remove the ith row
    spe[i] = float('-inf')
    
    # can't pick the jth class anymore
    mask = classes != random_class
    classes = classes[mask]

Input number 3 will have class number 4
Input number 4 will have class number 3
Input number 2 will have class number 0
Input number 1 will have class number 2
Input number 0 will have class number 1


## 3rd option

In [18]:
#repeat the process

class CDKT_dummy_net(nn.Module):
    def __init__(self):
        super(simple_netC_0hl, self).__init__()
        self.layer1 = nn.Linear(1600, 1600)
        
    def forward(self, x):
        out = self.layer1(x)
        return out

net = CDKT_dummy_net()

x = torch.empty(5, 1600)
for i in range(5):
    x[i] = i
    
# Constructing the specialisation matrix
with torch.no_grad():
    spe = net(x)

flattened_spe = spe.flatten()
for _ in range(5):
    #Take the softmax of all the elements in the matrix
    with torch.no_grad():
        softmax_matrix = F.softmax(flattened_spe, dim=0)

    rd_element_idx = torch.multinomial(softmax_matrix, num_samples=1, replacement=True)
    rd_elemt = rd_element_idx // 5 # Indice of the row
    rd_class = rd_element_idx % 5 # Indice of the column
    indices_1 = torch.tensor([5 * i + rd_class for i in range(5)])
    indices_2 = torch.tensor([i + rd_elemt * 5 for i in range(5)])

    # Combine indices from both calculations, ensuring uniqueness if necessary
    all_indices = torch.cat((indices_1, indices_2)).unique()
    flattened_spe[all_indices] = float('-inf')
    print(f"Input number {rd_elemt[0]} will have class number {rd_class[0]}")
    

TypeError: super(type, obj): obj must be an instance or subtype of type

# Specialization on kernel Matrix for CDKT

CDKT matrix of model is size $[5, 85, 85]$. We want to measure the compatibility of the class label to the input. To do so, let's average out all the compatibility information we have from inputs of one class. Then, we end up with a matrix of the form $\left(k_{c'}(X_c)\right) = \sum_{1 \leq i \leq j \leq n/c} k_{c'}(x_{ci}, x_{cj})$. We don't use the information we have about inputs of different classes.

In [95]:
z_batch = torch.randn(85, 1600)
sorted_z_batch = torch.randn(85, 1600)

# model matrix
model = torch.randn(5, 85, 85)

# specialization matrix
spe = torch.empty(5, 5)
for c in range(5):
    for cp in range(5):
        matrix = model[cp][c*17:(c+1)*17]
        
        upper_triangular = torch.triu(matrix, diagonal=0)
        # Sum all the elements above the diagonal
        sum_upper_triangular = upper_triangular.sum()
        spe[c][cp]

transformation = dict()
flattened_spe = spe.flatten()
for _ in range(5):
    #Take the softmax of all the elements in the matrix
    print(flattened_spe.reshape(5,5))
    with torch.no_grad():
        softmax_matrix = F.softmax(flattened_spe, dim=0)
    
    print(softmax_matrix.reshape(5,5))
    rd_element_idx = torch.multinomial(softmax_matrix, num_samples=1, replacement=True)
    rd_class = rd_element_idx // 5 # Indice of the row
    rd_elemt = rd_element_idx % 5 # Indice of the column
    indices_1 = torch.tensor([5 * i + rd_elemt for i in range(5)])
    indices_2 = torch.tensor([i + rd_class * 5 for i in range(5)])

    # Combine indices from both calculations, ensuring uniqueness if necessary
    all_indices = torch.cat((indices_1, indices_2)).unique()
    flattened_spe[all_indices] = float('-inf')
    print(f"Input number {rd_elemt[0]} will have class number {rd_class[0]}")
    
    tranformations[rd_class] = rd_elemt
    sorted_z_batch[n_shot*rd_class:n_shot*(rd_class+1)] = z_batch[n_shot*rd_elemt:n_shot*(rd_elemt+1)]

tensor([[-3.0554e-30,  3.0632e-41,  0.0000e+00,  4.0149e-02,  4.0600e-02],
        [ 4.6261e-02,  4.1081e-02,  4.1433e-02,  4.1500e-02,  4.0357e-02],
        [ 4.1334e-02,  4.2125e-02,  4.5181e-02,  4.0624e-02,  4.1709e-02],
        [ 4.3186e-02,  4.0793e-02,  4.1136e-02,  4.1149e-02,  4.6053e-02],
        [ 4.0559e-02,  4.1544e-02,  4.2045e-02,  4.0732e-02,  4.0149e-02]])
tensor([[0.0386, 0.0386, 0.0386, 0.0401, 0.0401],
        [0.0404, 0.0402, 0.0402, 0.0402, 0.0401],
        [0.0402, 0.0402, 0.0403, 0.0401, 0.0402],
        [0.0403, 0.0402, 0.0402, 0.0402, 0.0404],
        [0.0401, 0.0402, 0.0402, 0.0402, 0.0401]])
Input number 4 will have class number 1


NameError: name 'tranformations' is not defined

In [26]:
matrix = torch.randn(85, 85)  # Example tensor, replace with your actual tensor

# Create an upper triangular matrix with diagonal elements set to zero
upper_triangular = torch.triu(matrix, diagonal=0)

# Sum all the elements above the diagonal
sum_upper_triangular = upper_triangular.sum()

print(upper_triangular)

tensor([[-0.6797, -0.4558, -0.5233,  ...,  0.5295, -0.3898,  0.1230],
        [ 0.0000, -0.1272, -1.7410,  ..., -0.2704,  0.5932, -1.1409],
        [ 0.0000,  0.0000, -1.9735,  ...,  0.5691,  2.3677,  0.3890],
        ...,
        [ 0.0000,  0.0000,  0.0000,  ..., -1.3393,  0.2762,  0.0635],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000, -0.6384,  0.5418],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000, -0.8604]])
