In [5]:
## Original packages
import backbone
import torch
import torch.nn as nn
from torch.autograd import Variable
import numpy as np
import math
import torch.nn.functional as F
from torch.func import functional_call, vmap, vjp, jvp, jacrev

## Our packages
import gpytorch
from time import gmtime, strftime
import random
from statistics import mean
from data.qmul_loader import get_batch, train_people, test_people


class UnLiMiTDI(nn.Module):
    def __init__(self, conv_net, diff_net):
        super(UnLiMiTDI, self).__init__()
        ## GP parameters
        self.feature_extractor = conv_net
        self.diff_net = diff_net  #Differentiable network
        self.get_model_likelihood_mll() #Init model, likelihood, and mll

    def get_model_likelihood_mll(self, train_x=None, train_y=None):
        if(train_x is None): train_x=torch.ones(19, 2916).cuda()
        if(train_y is None): train_y=torch.ones(19).cuda()

        likelihood = gpytorch.likelihoods.GaussianLikelihood()
        model = ExactGPLayer(train_x=train_x, train_y=train_y, likelihood=likelihood, diff_net = self.diff_net, kernel='NTKcossim')

        self.model      = model.cuda()
        self.likelihood = likelihood.cuda()
        self.mll        = gpytorch.mlls.ExactMarginalLogLikelihood(self.likelihood, self.model).cuda()
        self.mse        = nn.MSELoss()

        return self.model, self.likelihood, self.mll

    def set_forward(self, x, is_feature=False):
        pass

    def set_forward_loss(self, x):
        pass

    def train_loop(self, epoch, optimizer):
        batch, batch_labels = get_batch(train_people)
        batch, batch_labels = batch.cuda(), batch_labels.cuda()
        for inputs, labels in zip(batch, batch_labels):
            optimizer.zero_grad()

            inputs_conv = self.feature_extractor(inputs)
            self.model.set_train_data(inputs=inputs_conv, targets=labels - self.diff_net(inputs_conv).reshape(-1))  
            predictions = self.model(inputs_conv)
            loss = -self.mll(predictions, self.model.train_targets)

            loss.backward()
            optimizer.step()
            mse = self.mse(predictions.mean, labels)

            if (epoch%10==0):
                print('[%d] - Loss: %.3f  MSE: %.3f noise: %.3f' % (
                    epoch, loss.item(), mse.item(),
                    self.model.likelihood.noise.item()
                ))

    def test_loop(self, n_support, optimizer=None): # no optimizer needed for GP
        inputs, targets = get_batch(test_people)

        support_ind = list(np.random.choice(list(range(19)), replace=False, size=n_support))
        query_ind   = [i for i in range(19) if i not in support_ind]

        x_all = inputs.cuda()
        y_all = targets.cuda()

        x_support = inputs[:,support_ind,:,:,:].cuda()
        y_support = targets[:,support_ind].cuda()

        # choose a random test person
        n = np.random.randint(0, len(test_people)-1)
    
        x_conv_support = self.feature_extractor(x_support[n]).detach()
        self.model.set_train_data(inputs=x_conv_support, targets=y_support[n] - self.diff_net(x_conv_support).reshape(-1), strict=False)

        self.model.eval()
        self.feature_extractor.eval()
        self.likelihood.eval()

        with torch.no_grad():
            x_conv_query = self.feature_extractor(x_all[n]).detach()
            pred    = self.likelihood(self.model(x_conv_query))
            lower, upper = pred.confidence_region() #2 standard deviations above and below the mean
            lower += self.diff_net(x_conv_query).reshape(-1)
            upper += self.diff_net(x_conv_query).reshape(-1)
        mse = self.mse(pred.mean + self.diff_net(self.feature_extractor(x_all[n])).reshape(-1), y_all[n])

        return mse

    def save_checkpoint(self, checkpoint):
        # save state
        gp_state_dict         = self.model.state_dict()
        likelihood_state_dict = self.likelihood.state_dict()
        conv_net_state_dict   = self.feature_extractor.state_dict()
        diff_net_state_dict   = self.diff_net.state_dict()
        torch.save({'gp': gp_state_dict, 'likelihood': likelihood_state_dict, 'conv_net':conv_net_state_dict, 'diff_net':diff_net_state_dict}, checkpoint)

    def load_checkpoint(self, checkpoint):
        ckpt = torch.load(checkpoint)
        self.model.load_state_dict(ckpt['gp'])
        self.likelihood.load_state_dict(ckpt['likelihood'])
        self.feature_extractor.load_state_dict(ckpt['conv_net'])
        self.diff_net.load_state_dict(ckpt['diff_net'])

        
###################
#NTKernel
###################
        
class NTKernel(gpytorch.kernels.Kernel):
    def __init__(self, net, **kwargs):
        super(NTKernel, self).__init__(**kwargs)
        self.net = net

    def forward(self, x1, x2, diag=False, **params):
        jac1 = self.compute_jacobian(x1)
        jac2 = self.compute_jacobian(x2) if x1 is not x2 else jac1
        
        result = jac1@jac2.T
        
        if diag:
            return result.diag()
        return result
    
    def compute_jacobian(self, inputs):
        """
        Return the jacobian of a batch of inputs, thanks to the vmap functionality
        """
        self.zero_grad()
        params = {k: v for k, v in self.net.named_parameters()}
        def fnet_single(params, x):
            return functional_call(self.net, params, (x.unsqueeze(0),)).squeeze(0)
        
        jac = vmap(jacrev(fnet_single), (None, 0))(params, inputs)
        jac = jac.values()
        # jac1 of dimensions [Nb Layers, Nb input / Batch, dim(y), Nb param/layer left, Nb param/layer right]
        reshaped_tensors = [
            j.flatten(2)                # Flatten starting from the 3rd dimension to acount for weights and biases layers
                .permute(2, 0, 1)         # Permute to align dimensions correctly for reshaping
                .reshape(-1, j.shape[0] * j.shape[1])  # Reshape to (c, a*b) using dynamic sizing
            for j in jac
        ]
        return torch.cat(reshaped_tensors, dim=0).T
    
    
    
###################
#NTKernel CosSim
###################
class CosSimNTKernel(gpytorch.kernels.Kernel):
    def __init__(self, net, **kwargs):
        super(CosSimNTKernel, self).__init__(**kwargs)
        self.net = net
        
        self.alpha = nn.Parameter(torch.ones(1))

    def forward(self, x1, x2, diag=False, **params):
        jac1T = self.compute_jacobian(x1).T
        jac1T_norm = jac1T.norm(dim=0, keepdim=True)
        jac1T_normalized = jac1T/jac1T_norm
        #print(jac1.shape)
        #print(jac1.norm(dim=0, keepdim=True).shape)
        jac2T = self.compute_jacobian(x2).T if x1 is not x2 else jac1T
        jac2T_norm = jac2T.norm(dim=0, keepdim=True)
        jac2T_normalized = jac2T/jac2T_norm
        
        result = self.alpha * jac1T_normalized.T@jac2T_normalized
        
        if diag:
            return result.diag()
        return result
    
    def compute_jacobian(self, inputs):
        """
        Return the jacobian of a batch of inputs, thanks to the vmap functionality
        """
        self.zero_grad()
        params = {k: v for k, v in self.net.named_parameters()}
        def fnet_single(params, x):
            return functional_call(self.net, params, (x.unsqueeze(0),)).squeeze(0)
        
        jac = vmap(jacrev(fnet_single), (None, 0))(params, inputs)
        jac = jac.values()
        # jac1 of dimensions [Nb Layers, Nb input / Batch, dim(y), Nb param/layer left, Nb param/layer right]
        reshaped_tensors = [
            j.flatten(2)                # Flatten starting from the 3rd dimension to acount for weights and biases layers
                .permute(2, 0, 1)         # Permute to align dimensions correctly for reshaping
                .reshape(-1, j.shape[0] * j.shape[1])  # Reshape to (c, a*b) using dynamic sizing
            for j in jac
        ]
        return torch.cat(reshaped_tensors, dim=0).T
    
###################
#GP
###################    
class ExactGPLayer(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood, diff_net, kernel='NTK'):
        super(ExactGPLayer, self).__init__(train_x, train_y, likelihood)
        self.mean_module  = gpytorch.means.ConstantMean()

        ## NTKernel
        if(kernel=='NTK'):
            self.covar_module = NTKernel(diff_net)
        elif(kernel=='NTKcossim'):
            self.covar_module = CosSimNTKernel(diff_net)        
        else:
            raise ValueError("[ERROR] the kernel '" + str(kernel) + "' is not supported for regression, use 'rbf' or 'spectral'.")

    def forward(self, x):
        mean_x  = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)


In [6]:
import torch
import torch.nn as nn
import torch.optim as optim
import configs
from data.qmul_loader import get_batch, train_people, test_people
from io_utils import parse_args_regression, get_resume_file
from methods.DKT_regression import DKT
from methods.feature_transfer_regression import FeatureTransfer
import backbone
import os
import numpy as np

np.random.seed(1)
torch.manual_seed(1)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

bb           = backbone.Conv3().cuda()
simple_net   = backbone.simple_net().cuda()

model = UnLiMiTDI(bb, simple_net).cuda()
optimizer = torch.optim.Adam([{'params': model.model.parameters(), 'lr': 0.001},
                                {'params': model.feature_extractor.parameters(), 'lr': 0.001}])

for epoch in range(100):
    model.train_loop(epoch, optimizer)

[0] - Loss: 0.828  MSE: 0.000 noise: 0.693
[0] - Loss: 0.824  MSE: 0.000 noise: 0.692
[0] - Loss: 0.824  MSE: 0.000 noise: 0.692
[0] - Loss: 0.823  MSE: 0.000 noise: 0.691
[0] - Loss: 0.823  MSE: 0.000 noise: 0.691
[0] - Loss: 0.822  MSE: 0.000 noise: 0.690
[0] - Loss: 0.822  MSE: 0.000 noise: 0.690
[0] - Loss: 0.821  MSE: 0.000 noise: 0.689
[0] - Loss: 0.821  MSE: 0.000 noise: 0.689
[0] - Loss: 0.821  MSE: 0.000 noise: 0.688
[0] - Loss: 0.820  MSE: 0.000 noise: 0.688
[0] - Loss: 0.820  MSE: 0.000 noise: 0.687
[0] - Loss: 0.819  MSE: 0.000 noise: 0.687
[0] - Loss: 0.819  MSE: 0.000 noise: 0.686
[0] - Loss: 0.819  MSE: 0.000 noise: 0.686
[0] - Loss: 0.818  MSE: 0.000 noise: 0.685
[0] - Loss: 0.818  MSE: 0.000 noise: 0.685
[0] - Loss: 0.818  MSE: 0.000 noise: 0.684
[0] - Loss: 0.817  MSE: 0.000 noise: 0.684
[0] - Loss: 0.817  MSE: 0.000 noise: 0.683
[0] - Loss: 0.816  MSE: 0.000 noise: 0.683
[0] - Loss: 0.816  MSE: 0.000 noise: 0.682
[0] - Loss: 0.816  MSE: 0.000 noise: 0.682
[0] - Loss:

In [7]:
np.random.seed(1)
torch.manual_seed(1)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

mse_list = []
for epoch in range(10):
    mse = float(model.test_loop(5, optimizer).cpu().detach().numpy())
    mse_list.append(mse)

print("-------------------")
print("Average MSE: " + str(np.mean(mse_list)) + " +- " + str(np.std(mse_list)))
print("-------------------")


-------------------
Average MSE: 0.05402122889645398 +- 0.03329951535135017
-------------------


In [14]:
## Original packages
import backbone
import torch
import torch.nn as nn
from torch.autograd import Variable
import numpy as np
import math
import torch.nn.functional as F
from torch.func import functional_call, vmap, vjp, jvp, jacrev

## Our packages
import gpytorch
from time import gmtime, strftime
import random
from statistics import mean
from data.qmul_loader import get_batch, train_people, test_people

class UnLiMiTDproj(nn.Module):
    def __init__(self, conv_net, diff_net, P):
        super(UnLiMiTDproj, self).__init__()
        ## GP parameters
        self.feature_extractor = conv_net
        self.diff_net = diff_net  #Differentiable network
        
        input_dimension = sum(p.numel() for p in diff_net.parameters())
        self.P = P
        self.get_model_likelihood_mll() #Init model, likelihood, and mll

    def get_model_likelihood_mll(self, train_x=None, train_y=None):
        if(train_x is None): train_x=torch.ones(19, 2916).cuda()
        if(train_y is None): train_y=torch.ones(19).cuda()

        likelihood = gpytorch.likelihoods.GaussianLikelihood()
        model = ExactGPLayer(train_x=train_x, train_y=train_y, likelihood=likelihood, diff_net = self.diff_net, P = self.P, kernel='NTKcossim')

        self.model      = model.cuda()
        self.likelihood = likelihood.cuda()
        self.mll        = gpytorch.mlls.ExactMarginalLogLikelihood(self.likelihood, self.model).cuda()
        self.mse        = nn.MSELoss()

        return self.model, self.likelihood, self.mll

    def set_forward(self, x, is_feature=False):
        pass

    def set_forward_loss(self, x):
        pass

    def train_loop(self, epoch, optimizer):
        batch, batch_labels = get_batch(train_people)
        batch, batch_labels = batch.cuda(), batch_labels.cuda()
        for inputs, labels in zip(batch, batch_labels):
            optimizer.zero_grad()

            inputs_conv = self.feature_extractor(inputs)
            self.model.set_train_data(inputs=inputs_conv, targets=labels - self.diff_net(inputs_conv).reshape(-1))  
            predictions = self.model(inputs_conv)
            loss = -self.mll(predictions, self.model.train_targets)

            loss.backward()
            optimizer.step()
            mse = self.mse(predictions.mean, labels)

            if (epoch%10==0):
                print('[%d] - Loss: %.3f  MSE: %.3f noise: %.3f' % (
                    epoch, loss.item(), mse.item(),
                    self.model.likelihood.noise.item()
                ))

    def test_loop(self, n_support, optimizer=None): # no optimizer needed for GP
        inputs, targets = get_batch(test_people)

        support_ind = list(np.random.choice(list(range(19)), replace=False, size=n_support))
        query_ind   = [i for i in range(19) if i not in support_ind]

        x_all = inputs.cuda()
        y_all = targets.cuda()

        x_support = inputs[:,support_ind,:,:,:].cuda()
        y_support = targets[:,support_ind].cuda()

        # choose a random test person
        n = np.random.randint(0, len(test_people)-1)
    
        x_conv_support = self.feature_extractor(x_support[n]).detach()
        self.model.set_train_data(inputs=x_conv_support, targets=y_support[n] - self.diff_net(x_conv_support).reshape(-1), strict=False)

        self.model.eval()
        self.feature_extractor.eval()
        self.likelihood.eval()

        with torch.no_grad():
            x_conv_query = self.feature_extractor(x_all[n]).detach()
            pred    = self.likelihood(self.model(x_conv_query))
            lower, upper = pred.confidence_region() #2 standard deviations above and below the mean
            lower += self.diff_net(x_conv_query).reshape(-1)
            upper += self.diff_net(x_conv_query).reshape(-1)
        mse = self.mse(pred.mean + self.diff_net(self.feature_extractor(x_all[n])).reshape(-1), y_all[n])

        return mse

    def save_checkpoint(self, checkpoint):
        # save state
        gp_state_dict         = self.model.state_dict()
        likelihood_state_dict = self.likelihood.state_dict()
        conv_net_state_dict   = self.feature_extractor.state_dict()
        diff_net_state_dict   = self.diff_net.state_dict()
        torch.save({
            'gp': gp_state_dict,
            'likelihood': likelihood_state_dict,
            'conv_net': conv_net_state_dict,
            'diff_net': diff_net_state_dict,
            'proj_matrix': self.P  # Save the tensor directly
        }, checkpoint)

    def load_checkpoint(self, checkpoint):
        ckpt = torch.load(checkpoint)
        if 'covar_module.scaling_param' not in ckpt['gp'].keys():
            ckpt['gp']['covar_module.scaling_param'] = torch.ones(self.P.shape[0]).cuda()
        self.model.load_state_dict(ckpt['gp'])
        self.likelihood.load_state_dict(ckpt['likelihood'])
        self.feature_extractor.load_state_dict(ckpt['conv_net'])
        self.diff_net.load_state_dict(ckpt['diff_net'])
        if 'proj_matrix' in ckpt.keys():
            self.P = ckpt['proj_matrix']
        
        print(f"Total number of param that requires grad : {sum(p.numel() for p in self.model.parameters() if p.requires_grad)}")


# ##################
# NTKernel
# ##################

class NTKernel_proj(gpytorch.kernels.Kernel):
    def __init__(self, net, P, **kwargs):
        super(NTKernel_proj, self).__init__(**kwargs)
        self.net = net
        
        self.P = P # Projection matrix
        
        # Add subspace_dimension scaling parameters, initializing them as one
        self.scaling_param = nn.Parameter(torch.ones(P.shape[0]))
        
    def forward(self, x1, x2, diag=False, **params):
        jac1 = self.compute_jacobian(x1)
        jac2 = self.compute_jacobian(x2) if x1 is not x2 else jac1
        D = torch.diag(torch.pow(self.scaling_param, 2))
        
        result = torch.chain_matmul(jac1, self.P.T, D, self.P, jac2.T)
        
        if diag:
            return result.diag()
        return result
    
    def compute_jacobian(self, inputs):
        """
        Return the jacobian of a batch of inputs, thanks to the vmap functionality
        """
        self.zero_grad()
        params = {k: v for k, v in self.net.named_parameters()}
        def fnet_single(params, x):
            return functional_call(self.net, params, (x.unsqueeze(0),)).squeeze(0)
        
        jac = vmap(jacrev(fnet_single), (None, 0))(params, inputs)
        jac = jac.values()
        # jac1 of dimensions [Nb Layers, Nb input / Batch, dim(y), Nb param/layer left, Nb param/layer right]
        reshaped_tensors = [
            j.flatten(2)                # Flatten starting from the 3rd dimension to acount for weights and biases layers
                .permute(2, 0, 1)         # Permute to align dimensions correctly for reshaping
                .reshape(-1, j.shape[0] * j.shape[1])  # Reshape to (c, a*b) using dynamic sizing
            for j in jac
        ]
        return torch.cat(reshaped_tensors, dim=0).T


    
# ##################
# NTKernel CosSim
# ##################

class CosSimNTKernel_proj(gpytorch.kernels.Kernel):
    def __init__(self, net, P, **kwargs):
        super(CosSimNTKernel_proj, self).__init__(**kwargs)
        self.net = net
        self.alpha = nn.Parameter(torch.ones(1))
        
        self.P = P # Projection matrix
        
        # Add subspace_dimension scaling parameters, initializing them as one
        self.scaling_param = nn.Parameter(torch.ones(P.shape[0]))
        
    def forward(self, x1, x2, diag=False, **params):
        jac1 = self.compute_jacobian(x1)
        jac2 = self.compute_jacobian(x2) if x1 is not x2 else jac1
        
        D = torch.diag(self.scaling_param)
        
        result_1 = torch.chain_matmul(D, self.P, jac1.T)
        result_2 = torch.chain_matmul(D, self.P, jac2.T)
        
        result_1_norm = result_1.norm(dim=0, keepdim=True)
        result_1_normalized = result_1/result_1_norm
        #print(result_1.shape)
        #print(result_1.norm(dim=0, keepdim=True).shape)
        result_2_norm = result_2.norm(dim=0, keepdim=True)
        result_2_normalized = result_2/result_2_norm
        
        result = self.alpha * result_1_normalized.T@result_2_normalized
        
        if diag:
            return result.diag()
        return result
    
    def compute_jacobian(self, inputs):
        """
        Return the jacobian of a batch of inputs, thanks to the vmap functionality
        """
        self.zero_grad()
        params = {k: v for k, v in self.net.named_parameters()}
        def fnet_single(params, x):
            return functional_call(self.net, params, (x.unsqueeze(0),)).squeeze(0)
        
        jac = vmap(jacrev(fnet_single), (None, 0))(params, inputs)
        jac = jac.values()
        # jac1 of dimensions [Nb Layers, Nb input / Batch, dim(y), Nb param/layer left, Nb param/layer right]
        reshaped_tensors = [
            j.flatten(2)                # Flatten starting from the 3rd dimension to acount for weights and biases layers
                .permute(2, 0, 1)         # Permute to align dimensions correctly for reshaping
                .reshape(-1, j.shape[0] * j.shape[1])  # Reshape to (c, a*b) using dynamic sizing
            for j in jac
        ]
        return torch.cat(reshaped_tensors, dim=0).T

# ##################
# GP layer
# ##################
class ExactGPLayer(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood, diff_net, P, kernel='NTK'):
        super(ExactGPLayer, self).__init__(train_x, train_y, likelihood)
        self.mean_module  = gpytorch.means.ConstantMean()

        ## NTKernel
        if(kernel=='NTK'):
            self.covar_module = NTKernel_proj(diff_net, P)
        elif(kernel=='NTKcossim'):
            self.covar_module = CosSimNTKernel_proj(diff_net, P)  
        else:
            raise ValueError("[ERROR] the kernel '" + str(kernel) + "' is not supported for regression, use 'rbf' or 'spectral'.")

    def forward(self, x):
        mean_x  = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)



In [15]:
from projection import create_random_projection_matrix, proj_sketch

# FIM proj search needs no gradient
for param in model.model.parameters():
    param.requires_grad_(False)
for param in model.feature_extractor.parameters():
    param.requires_grad_(False)
optimizer = None
# Batch preparation
nb_batch_proj = 10
    
batches = []
for _ in range(nb_batch_proj):
    batch, batch_labels = get_batch(train_people)
    for person_task in batch :
        person_conv = model.feature_extractor(person_task.cuda()).detach()
        batches.append(person_conv)  
batches = torch.stack(batches)
# FIM projection computation
input_dimension = sum(p.numel() for p in simple_net.parameters())
P = proj_sketch(model.diff_net, batches, 100).cuda()
# Gradients back to training mode
for param in model.model.parameters():
    param.requires_grad_(True)
for param in model.feature_extractor.parameters():
    param.requires_grad_(True)
    
# Unlimitd-F training
model = UnLiMiTDproj(bb, simple_net, P).cuda()
optimizer = torch.optim.Adam([{'params': model.model.parameters(), 'lr': 0.001},
                            {'params': model.feature_extractor.parameters(), 'lr': 0.001}])
for epoch in range(100):
    model.train_loop(epoch, optimizer)

118361
U shape: torch.Size([118361, 402])
Index tensor: tensor([401, 400, 399, 398, 397, 396, 395, 394, 393, 392, 391, 390, 389, 388,
        387, 386, 385, 384, 383, 382, 381, 380, 379, 378, 377, 376, 375, 374,
        373, 372, 371, 370, 369, 368, 367, 366, 365, 364, 363, 362, 361, 360,
        359, 358, 357, 356, 355, 354, 353, 352, 351, 350, 349, 348, 347, 346,
        345, 344, 343, 342, 341, 340, 339, 338, 337, 336, 335, 334, 333, 332,
        331, 330, 329, 328, 327, 326, 325, 324, 323, 322, 321, 320, 319, 318,
        317, 316, 315, 314, 313, 312, 311, 310, 309, 308, 307, 306, 305, 304,
        303, 302, 301, 300, 299, 298, 297, 296, 295, 294, 293, 292, 291, 290,
        289, 288, 287, 286, 285, 284, 283, 282, 281, 280, 279, 278, 277, 276,
        275, 274, 273, 272, 271, 270, 269, 268, 267, 266, 265, 264, 263, 262,
        261, 260, 259, 258, 257, 256, 255, 254, 253, 252, 251, 250, 249, 248,
        247, 246, 245, 244, 243, 242, 241, 240, 239, 238, 237, 236, 235, 234,
        

  return _VF.chain_matmul(matrices)  # type: ignore[attr-defined]


[0] - Loss: 0.830  MSE: 0.000 noise: 0.693
[0] - Loss: 0.828  MSE: 0.000 noise: 0.692
[0] - Loss: 0.830  MSE: 0.000 noise: 0.692
[0] - Loss: 0.825  MSE: 0.000 noise: 0.691
[0] - Loss: 0.827  MSE: 0.000 noise: 0.691
[0] - Loss: 0.825  MSE: 0.000 noise: 0.690
[0] - Loss: 0.824  MSE: 0.000 noise: 0.690
[0] - Loss: 0.823  MSE: 0.000 noise: 0.689
[0] - Loss: 0.823  MSE: 0.000 noise: 0.689
[0] - Loss: 0.824  MSE: 0.000 noise: 0.688
[0] - Loss: 0.823  MSE: 0.000 noise: 0.688
[0] - Loss: 0.822  MSE: 0.000 noise: 0.687
[0] - Loss: 0.824  MSE: 0.000 noise: 0.687
[0] - Loss: 0.821  MSE: 0.000 noise: 0.686
[0] - Loss: 0.820  MSE: 0.000 noise: 0.686
[0] - Loss: 0.819  MSE: 0.000 noise: 0.685
[0] - Loss: 0.819  MSE: 0.000 noise: 0.685
[0] - Loss: 0.820  MSE: 0.000 noise: 0.684
[0] - Loss: 0.820  MSE: 0.000 noise: 0.684
[0] - Loss: 0.818  MSE: 0.000 noise: 0.683
[0] - Loss: 0.819  MSE: 0.000 noise: 0.683
[0] - Loss: 0.817  MSE: 0.000 noise: 0.682
[0] - Loss: 0.816  MSE: 0.000 noise: 0.682
[0] - Loss:

In [16]:
np.random.seed(1)
torch.manual_seed(1)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

mse_list = []
for epoch in range(10):
    mse = float(model.test_loop(5, optimizer).cpu().detach().numpy())
    mse_list.append(mse)

print("-------------------")
print("Average MSE: " + str(np.mean(mse_list)) + " +- " + str(np.std(mse_list)))
print("-------------------")


-------------------
Average MSE: 0.0853831883519888 +- 0.06335430772426338
-------------------


In [17]:
## Original packages
import backbone
import torch
import torch.nn as nn
from torch.autograd import Variable
import numpy as np
import math
import torch.nn.functional as F
from torch.func import functional_call, vmap, vjp, jvp, jacrev

## Our packages
import gpytorch
from time import gmtime, strftime
import random
from statistics import mean
from data.qmul_loader import get_batch, train_people, test_people


class UnLiMiTDcov(nn.Module):
    def __init__(self, conv_net, diff_net):
        super(UnLiMiTDI, self).__init__()
        ## GP parameters
        self.feature_extractor = conv_net
        self.diff_net = diff_net  #Differentiable network
        self.get_model_likelihood_mll() #Init model, likelihood, and mll

    def get_model_likelihood_mll(self, train_x=None, train_y=None):
        if(train_x is None): train_x=torch.ones(19, 2916).cuda()
        if(train_y is None): train_y=torch.ones(19).cuda()

        likelihood = gpytorch.likelihoods.GaussianLikelihood()
        model = ExactGPLayer(train_x=train_x, train_y=train_y, likelihood=likelihood, diff_net = self.diff_net, kernel='NTK')

        self.model      = model.cuda()
        self.likelihood = likelihood.cuda()
        self.mll        = gpytorch.mlls.ExactMarginalLogLikelihood(self.likelihood, self.model).cuda()
        self.mse        = nn.MSELoss()

        return self.model, self.likelihood, self.mll

    def set_forward(self, x, is_feature=False):
        pass

    def set_forward_loss(self, x):
        pass

    def train_loop(self, epoch, optimizer):
        batch, batch_labels = get_batch(train_people)
        batch, batch_labels = batch.cuda(), batch_labels.cuda()
        for inputs, labels in zip(batch, batch_labels):
            optimizer.zero_grad()

            inputs_conv = self.feature_extractor(inputs)
            self.model.set_train_data(inputs=inputs_conv, targets=labels - self.diff_net(inputs_conv).reshape(-1))  
            predictions = self.model(inputs_conv)
            loss = -self.mll(predictions, self.model.train_targets)

            loss.backward()
            optimizer.step()
            mse = self.mse(predictions.mean, labels)

            if (epoch%10==0):
                print('[%d] - Loss: %.3f  MSE: %.3f noise: %.3f' % (
                    epoch, loss.item(), mse.item(),
                    self.model.likelihood.noise.item()
                ))

    def test_loop(self, n_support, optimizer=None): # no optimizer needed for GP
        inputs, targets = get_batch(test_people)

        support_ind = list(np.random.choice(list(range(19)), replace=False, size=n_support))
        query_ind   = [i for i in range(19) if i not in support_ind]

        x_all = inputs.cuda()
        y_all = targets.cuda()

        x_support = inputs[:,support_ind,:,:,:].cuda()
        y_support = targets[:,support_ind].cuda()

        # choose a random test person
        n = np.random.randint(0, len(test_people)-1)
    
        x_conv_support = self.feature_extractor(x_support[n]).detach()
        self.model.set_train_data(inputs=x_conv_support, targets=y_support[n] - self.diff_net(x_conv_support).reshape(-1), strict=False)

        self.model.eval()
        self.feature_extractor.eval()
        self.likelihood.eval()

        with torch.no_grad():
            x_conv_query = self.feature_extractor(x_all[n]).detach()
            pred    = self.likelihood(self.model(x_conv_query))
            lower, upper = pred.confidence_region() #2 standard deviations above and below the mean
            lower += self.diff_net(x_conv_query).reshape(-1)
            upper += self.diff_net(x_conv_query).reshape(-1)
        mse = self.mse(pred.mean + self.diff_net(self.feature_extractor(x_all[n])).reshape(-1), y_all[n])

        return mse

    def save_checkpoint(self, checkpoint):
        # save state
        gp_state_dict         = self.model.state_dict()
        likelihood_state_dict = self.likelihood.state_dict()
        conv_net_state_dict   = self.feature_extractor.state_dict()
        diff_net_state_dict   = self.diff_net.state_dict()
        torch.save({'gp': gp_state_dict, 'likelihood': likelihood_state_dict, 'conv_net':conv_net_state_dict, 'diff_net':diff_net_state_dict}, checkpoint)

    def load_checkpoint(self, checkpoint):
        ckpt = torch.load(checkpoint)
        self.model.load_state_dict(ckpt['gp'])
        self.likelihood.load_state_dict(ckpt['likelihood'])
        self.feature_extractor.load_state_dict(ckpt['conv_net'])
        self.diff_net.load_state_dict(ckpt['diff_net'])

        
# ##################
# NTKernel
# ##################

class NTKernelcov(gpytorch.kernels.Kernel):
    def __init__(self, net, **kwargs):
        super(NTKernelcov, self).__init__(**kwargs)
        self.net = net
        
        # Add number of params scaling parameters, initializing them as one
        self.scaling_param = nn.Parameter(torch.ones(sum(p.numel() for p in simple_net.parameters())))
        
    def forward(self, x1, x2, diag=False, **params):
        jac1 = self.compute_jacobian(x1)
        jac2 = self.compute_jacobian(x2) if x1 is not x2 else jac1
        D = torch.diag(torch.pow(self.scaling_param, 2))
        
        result = torch.chain_matmul(jac1, D, jac2.T)
        
        if diag:
            return result.diag()
        return result
    
    def compute_jacobian(self, inputs):
        """
        Return the jacobian of a batch of inputs, thanks to the vmap functionality
        """
        self.zero_grad()
        params = {k: v for k, v in self.net.named_parameters()}
        def fnet_single(params, x):
            return functional_call(self.net, params, (x.unsqueeze(0),)).squeeze(0)
        
        jac = vmap(jacrev(fnet_single), (None, 0))(params, inputs)
        jac = jac.values()
        # jac1 of dimensions [Nb Layers, Nb input / Batch, dim(y), Nb param/layer left, Nb param/layer right]
        reshaped_tensors = [
            j.flatten(2)                # Flatten starting from the 3rd dimension to acount for weights and biases layers
                .permute(2, 0, 1)         # Permute to align dimensions correctly for reshaping
                .reshape(-1, j.shape[0] * j.shape[1])  # Reshape to (c, a*b) using dynamic sizing
            for j in jac
        ]
        return torch.cat(reshaped_tensors, dim=0).T


    
# ##################
# NTKernel CosSim
# ##################

class CosSimNTKernelcov(gpytorch.kernels.Kernel):
    def __init__(self, net, **kwargs):
        super(CosSimNTKernelcov, self).__init__(**kwargs)
        self.net = net
        self.alpha = nn.Parameter(torch.ones(1))
        
        # Add subspace_dimension scaling parameters, initializing them as one
        self.scaling_param = nn.Parameter(torch.ones(sum(p.numel() for p in simple_net.parameters())))
        
    def forward(self, x1, x2, diag=False, **params):
        jac1 = self.compute_jacobian(x1)
        jac2 = self.compute_jacobian(x2) if x1 is not x2 else jac1
        
        D = torch.diag(self.scaling_param)
        
        result_1 = torch.chain_matmul(D, jac1.T)
        result_2 = torch.chain_matmul(D, jac2.T)
        
        result_1_norm = result_1.norm(dim=0, keepdim=True)
        result_1_normalized = result_1/result_1_norm
        #print(result_1.shape)
        #print(result_1.norm(dim=0, keepdim=True).shape)
        result_2_norm = result_2.norm(dim=0, keepdim=True)
        result_2_normalized = result_2/result_2_norm
        
        result = self.alpha * result_1_normalized.T@result_2_normalized
        
        if diag:
            return result.diag()
        return result
    
    def compute_jacobian(self, inputs):
        """
        Return the jacobian of a batch of inputs, thanks to the vmap functionality
        """
        self.zero_grad()
        params = {k: v for k, v in self.net.named_parameters()}
        def fnet_single(params, x):
            return functional_call(self.net, params, (x.unsqueeze(0),)).squeeze(0)
        
        jac = vmap(jacrev(fnet_single), (None, 0))(params, inputs)
        jac = jac.values()
        # jac1 of dimensions [Nb Layers, Nb input / Batch, dim(y), Nb param/layer left, Nb param/layer right]
        reshaped_tensors = [
            j.flatten(2)                # Flatten starting from the 3rd dimension to acount for weights and biases layers
                .permute(2, 0, 1)         # Permute to align dimensions correctly for reshaping
                .reshape(-1, j.shape[0] * j.shape[1])  # Reshape to (c, a*b) using dynamic sizing
            for j in jac
        ]
        return torch.cat(reshaped_tensors, dim=0).T

    
###################
#GP
###################    
class ExactGPLayer(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood, diff_net, kernel='NTK'):
        super(ExactGPLayer, self).__init__(train_x, train_y, likelihood)
        self.mean_module  = gpytorch.means.ConstantMean()

        ## NTKernel
        if(kernel=='NTK'):
            self.covar_module = NTKernelcov(diff_net)
        elif(kernel=='NTKcossim'):
            self.covar_module = CosSimNTKernelcov(diff_net)        
        else:
            raise ValueError("[ERROR] the kernel '" + str(kernel) + "' is not supported for regression, use 'rbf' or 'spectral'.")

    def forward(self, x):
        mean_x  = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)


In [None]:
np.random.seed(1)
torch.manual_seed(1)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

bb           = backbone.Conv3().cuda()
simple_net   = backbone.simple_net().cuda()

model = UnLiMiTDI(bb, simple_net).cuda()
optimizer = torch.optim.Adam([{'params': model.model.parameters(), 'lr': 0.001},
                                {'params': model.feature_extractor.parameters(), 'lr': 0.001}])

for epoch in range(100):
    model.train_loop(epoch, optimizer)