In [1]:
from dataset import get_user_loader, get_remaining_forget_loader, split_user_train_dataset_to_remaining_forget
from utils import params_to_device
from model import get_core_model_params, get_trained_linear, init_pretrained_model, split_model_to_feature_linear, freeze, thaw
from loss import L2Regularization, LossWrapper, MSELossDiv2
from torch.autograd.functional import hvp, jvp
from torch.func import functional_call

import torch
import torch.nn as nn
import random
import numpy as np

device = 'cuda:2' if torch.cuda.is_available() else 'cpu'
def set_deterministic(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed) 

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.enabled = False
seed = 13 # any number 
set_deterministic(seed=seed)

In [2]:
# loading core model and linearized model -- init in cpu
pretrained_model = init_pretrained_model('resnet50', 'cifar10')
_, linearized_head_core, __ = split_model_to_feature_linear(pretrained_model, 5, None, send_params_to_device=False)
core_model_state_dict = get_core_model_params('checkpoint/05042024-213334-train-user-data-resnet50-cifar10-last5/05042024_213334_train_user_data_resnet50_cifar10_last5_core_model.pth', 'cpu')
feature_backbone, mixed_linear = get_trained_linear('checkpoint/05042024-213334-train-user-data-resnet50-cifar10-last5/05042024_213334_train_user_data_resnet50_cifar10_last5.pth', 'resnet50', 'cifar10', 5)
del _
del __

feature_backbone = feature_backbone.to(device)
freeze(feature_backbone)

mixed_linear = mixed_linear.to(device)
freeze(mixed_linear)

linearized_head_core = linearized_head_core.to(device)
freeze(linearized_head_core)

core_model_state_dict = params_to_device(core_model_state_dict, device)

def calculate_gradient(feature_backbone, core_model_state_dict, model, loss_fnc, regularizor_hyperparameter, data_loader, device):
    grads = [torch.zeros_like(param) for param in model.parameters()]
    sample_count = 0
    thaw(model)
    for iter_idx, (inp, target) in enumerate(data_loader):
        model.zero_grad()
        inp = inp.to(device)
        target = 5 * target.to(device)
        curr_loss = loss_fnc(model(feature_backbone, core_model_state_dict, inp), target, model.parameters())
        curr_loss.backward()
        for idx, param in enumerate(model.parameters()):
            grads[idx] += ((param.grad - regularizor_hyperparameter*param.clone().detach()) * inp.shape[0])
        sample_count += inp.shape[0]

        if iter_idx == 0 or (iter_idx + 1) % 50 == 0 or (iter_idx + 1) == len(data_loader):
            print('iter: {}/{}'.format(iter_idx + 1, len(data_loader)))
    freeze(model)
    
    last = []
    for grad in grads:
        tmp = grad / sample_count
        tmp.requires_grad = False
        last.append(tmp)
    return last

whole_loader, test_loader = get_user_loader('cifar10', 'resnet50', 64, shuffle=False)
remain_dataset, forget_dataset = split_user_train_dataset_to_remaining_forget('cifar10', 'resnet50', 0.1, seed=13)
remain_loader, forget_loader = get_remaining_forget_loader(remain_dataset, forget_dataset, 64, shuffle=False)

main_criterion = LossWrapper([MSELossDiv2(), L2Regularization()], [1, 0.0005])

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


In [3]:
whole_grad = calculate_gradient(feature_backbone, core_model_state_dict, mixed_linear, main_criterion, 0.0005, whole_loader, device)

iter: 1/782
iter: 50/782
iter: 100/782
iter: 150/782
iter: 200/782
iter: 250/782
iter: 300/782
iter: 350/782
iter: 400/782
iter: 450/782
iter: 500/782
iter: 550/782
iter: 600/782
iter: 650/782
iter: 700/782
iter: 750/782
iter: 782/782


In [4]:
forget_grad = calculate_gradient(feature_backbone, core_model_state_dict, mixed_linear, main_criterion, 0.0005, forget_loader, device)

iter: 1/79
iter: 50/79
iter: 79/79


In [5]:
remain_grad = calculate_gradient(feature_backbone, core_model_state_dict, mixed_linear, main_criterion, 0.0005, remain_loader, device)

iter: 1/704
iter: 50/704
iter: 100/704
iter: 150/704
iter: 200/704
iter: 250/704
iter: 300/704
iter: 350/704
iter: 400/704
iter: 450/704
iter: 500/704
iter: 550/704
iter: 600/704
iter: 650/704
iter: 700/704
iter: 704/704


In [8]:
# remain_forget_grad = [(remain_grad_elem * len(remain_dataset) + forget_grad_elem * len(forget_dataset)) / (len(remain_dataset) + len(forget_dataset)) for remain_grad_elem, forget_grad_elem in zip(remain_grad, forget_grad)]
# torch.sum(torch.tensor([torch.norm(remain_forget_grad[i] - whole_grad[i]) for i in range(len(remain_forget_grad))]))

In [3]:
# APPROX. DIRECTLY WITH FORGET - DEBUG PURPOSES

def calculate_hess_diag(feature_backbone, core_model_state_dict, model, loss_fnc, regularizor_hyperparameter, data_loader, device):
    # grads = None
    hess_diags = [torch.zeros_like(p) for p in model.parameters()]
    sample_count = 0
    v = [np.random.uniform(0, 1, size = p.shape) for p in model.parameters()]
    for vi in v:
        vi[ vi < 0.5] =  -1 
        vi[ vi >= 0.5] =  1 
    v = [torch.tensor(vi) for vi in v]
    v = {key: param for (key, _), param in zip(model.named_parameters(), v)}
    v = params_to_device(v, device)
    
    thaw(model)
    for iter_idx, (inp, target) in enumerate(data_loader):
        # model.zero_grad()
        inp = inp.to(device)
        target = 5 * target.to(device)
        # curr_loss = loss_fnc(model(feature_backbone, core_model_state_dict, inp), target, model.parameters())
        curr_loss = loss_fnc(model(feature_backbone, core_model_state_dict, inp), target)
        curr_grad = torch.autograd.grad(curr_loss, model.parameters(), create_graph=True)
        
        vprod = None
        for vi, grad in zip(v.values(), curr_grad):
            if vprod is None:
                vprod = torch.sum(vi * grad)
            else:
                vprod += torch.sum(vi * grad)

        hvp_val = torch.autograd.grad(vprod, model.parameters())

        # Hd = [torch.abs(Hvi * vi) for Hvi, vi in zip(Hv, v)]
        
        # def func(params):
        #     model_dict = {key: param for (key, _), param in zip(model.named_parameters(), params)}
        #     out = functional_call(model, model_dict, (feature_backbone, core_model_state_dict, inp))
        #     return out
        # 
        # def loss(*params):
        #     out = func(params)
        #     return loss_fnc(out, target)
        # 
        # _, hvp_val = hvp(loss, tuple(model.parameters()), tuple(v.values()))

        for idx, (vi, hvp_val_i) in enumerate(zip(v.values(), hvp_val)):
            hess_diags[idx] = hess_diags[idx] + (torch.abs(vi * hvp_val_i) * inp.shape[0])
        
        sample_count += inp.shape[0]
        # vprod = None
        # for vi, grad in zip(v, curr_grad):
        #     if vprod is None:
        #         vprod = torch.sum(vi * grad)
        #     else:
        #         vprod += torch.sum(vi * grad)

        # Hv = torch.autograd.grad(vprod, model.parameters())

        # Hd = [torch.abs(Hvi * vi) for Hvi, vi in zip(Hv, v)]

        
    #     curr_loss.backward()
    #     for idx, param in enumerate(model.parameters()):
    #         grads[idx] += ((param.grad - regularizor_hyperparameter*param.clone().detach()) * inp.shape[0])
    #     sample_count += inp.shape[0]

        if iter_idx == 0 or (iter_idx + 1) % 50 == 0 or (iter_idx + 1) == len(data_loader):
            print('iter: {}/{}'.format(iter_idx + 1, len(data_loader)))
    freeze(model)
    
    hess_diags = [(diags / sample_count) + (regularizor_hyperparameter * torch.norm(param) ** 2) for diags, param in zip(hess_diags, v.values())]
    # hess_diags = [(diags / sample_count) for diags in hess_diags]
    return hess_diags


In [4]:
hess_diags = calculate_hess_diag(feature_backbone, core_model_state_dict, mixed_linear, MSELossDiv2(), 0.0005, forget_loader, device)

iter: 1/79
iter: 50/79
iter: 79/79


In [5]:
hess_diags

[tensor([[[[262.1453]],
 
          [[262.1459]],
 
          [[262.1447]],
 
          ...,
 
          [[262.1452]],
 
          [[262.1453]],
 
          [[262.1453]]],
 
 
         [[[262.1452]],
 
          [[262.1457]],
 
          [[262.1448]],
 
          ...,
 
          [[262.1448]],
 
          [[262.1454]],
 
          [[262.1452]]],
 
 
         [[[262.1453]],
 
          [[262.1467]],
 
          [[262.1453]],
 
          ...,
 
          [[262.1450]],
 
          [[262.1454]],
 
          [[262.1455]]],
 
 
         ...,
 
 
         [[[262.1453]],
 
          [[262.1457]],
 
          [[262.1451]],
 
          ...,
 
          [[262.1449]],
 
          [[262.1452]],
 
          [[262.1455]]],
 
 
         [[[262.1457]],
 
          [[262.1459]],
 
          [[262.1448]],
 
          ...,
 
          [[262.1450]],
 
          [[262.1450]],
 
          [[262.1457]]],
 
 
         [[[262.1451]],
 
          [[262.1459]],
 
          [[262.1446]],
 
          ...,
 
       

In [6]:
# APPROX. DIRECTLY WITH REMAINING
remain_grad

[tensor([[[[-6.5634e-09]],
 
          [[ 9.1132e-07]],
 
          [[-1.9040e-06]],
 
          ...,
 
          [[ 2.2687e-06]],
 
          [[-2.3203e-06]],
 
          [[-6.1204e-06]]],
 
 
         [[[-2.7520e-06]],
 
          [[ 3.0717e-06]],
 
          [[ 1.7497e-06]],
 
          ...,
 
          [[-8.7384e-07]],
 
          [[ 9.3652e-07]],
 
          [[ 8.2978e-06]]],
 
 
         [[[-1.0194e-06]],
 
          [[-1.2976e-05]],
 
          [[ 2.7846e-06]],
 
          ...,
 
          [[-7.1667e-07]],
 
          [[-1.0624e-06]],
 
          [[-1.8895e-06]]],
 
 
         ...,
 
 
         [[[-2.4403e-06]],
 
          [[ 1.9693e-06]],
 
          [[-4.1486e-06]],
 
          ...,
 
          [[-1.4912e-06]],
 
          [[-2.3654e-06]],
 
          [[ 1.8859e-06]]],
 
 
         [[[ 2.1935e-06]],
 
          [[ 7.9350e-06]],
 
          [[-5.3859e-07]],
 
          ...,
 
          [[ 6.1263e-07]],
 
          [[-2.1926e-07]],
 
          [[ 6.2182e-06]]],
 
 
         [[[