In [11]:
# libraries
import torch
from model import get_core_model_params, get_trained_linear, freeze
from dataset import split_user_train_dataset_to_remaining_forget, get_remaining_forget_loader
from utils import params_to_device
from loss import MSELossDiv2
import os

# Target device for every tensor/module in this notebook: GPU 3 when CUDA is available,
# otherwise the CPU. (The previous value 'cpu:3' was a typo — torch only accepts CPU
# device index 0/-1, so every `.to(device)` failed on a CUDA machine.)
device = 'cuda:3' if torch.cuda.is_available() else 'cpu'

In [2]:
# load pretrained model
## checkpoint locations for this experiment run
exp_path = 'checkpoint/05152024-011132-train-user-data-resnet18-cifar10-last1/'
core_model_ckpt_path = os.path.join(exp_path, '05152024_011132_train_user_data_resnet18_cifar10_last1_core_model.pth')
trained_linear_ckpt_path = os.path.join(exp_path, '05152024_011132_train_user_data_resnet18_cifar10_last1.pth')

## core-model parameters are read onto the CPU here and moved to `device` further below
core_model_state_dict = get_core_model_params(core_model_ckpt_path, 'cpu')
## feature extractor + mixed linear head (the trailing 1 mirrors the 'last1' checkpoint name)
feature_backbone, mixed_linear = get_trained_linear(trained_linear_ckpt_path, 'resnet18', 'cifar10', 1)

## place each module on the target device, then freeze it with the project `freeze` helper
feature_backbone = feature_backbone.to(device)
freeze(feature_backbone)
mixed_linear = mixed_linear.to(device)
freeze(mixed_linear)

core_model_state_dict = params_to_device(core_model_state_dict, device)

## split the user training set into remaining and forget subsets, then wrap them in loaders
split_ratio = 0.1  # passed straight to the split helper (presumably the forget fraction — confirm)
loader_batch_size = 256
remaining_dataset, forget_dataset = split_user_train_dataset_to_remaining_forget('cifar10', 'resnet18', split_ratio)
remain_loader, forget_loader = get_remaining_forget_loader(remaining_dataset, forget_dataset, loader_batch_size)

Files already downloaded and verified


In [3]:
# calculate the hessian on the last linear layer for both remaining and forget splits
def calculate_hessian(backbone, loader, exp_path, mode='forget', data_device=None,
                      file_prefix='05152024_011132_train_user_data_resnet18_cifar10_last1'):
    """Average the feature outer products a a^T over every sample in `loader`.

    For a loss of the form 1/2 * ||W a - y||^2 this matrix is the Hessian w.r.t. each
    output row of the linear layer W, which is how it is used here.

    Args:
        backbone: callable mapping a batch `data` to features; any trailing singleton
            dims are flattened away (assumes output is (B, D) or (B, D, 1, 1) — confirm).
        loader: sized iterable of `(data, label)` batches; labels are ignored.
        exp_path: directory where the result is saved.
        mode: split name, used in the log line and the saved file name.
        data_device: device each batch is moved to; defaults to the notebook's global
            `device` (only looked up when this argument is None).
        file_prefix: prefix of the saved `{prefix}_{mode}_hessian.pth` file.

    Returns:
        CPU tensor of shape (D, D): sum over samples of a a^T divided by the sample count.

    Raises:
        ValueError: if `loader` yields no samples.
    """
    # NOTE(review): the backbone is not switched to eval() here; if `freeze` does not do
    # that, BatchNorm layers would run in training mode — confirm.
    target_device = device if data_device is None else data_device
    print('{} hessian'.format(mode))
    hessian = None
    sample_count = 0
    num_batches = len(loader)
    with torch.no_grad():
        for batch_idx, (data, _) in enumerate(loader):
            data = data.to(target_device)
            # flatten(1) keeps the batch axis even for a single-sample batch
            # (.squeeze() used to drop it and break the outer product).
            feats = backbone(data).flatten(start_dim=1)
            # sum_i a_i a_i^T == A^T A, without materialising a (B, D, D) tensor.
            batch_hessian = (feats.T @ feats).detach().to('cpu')
            hessian = batch_hessian if hessian is None else hessian + batch_hessian
            sample_count += data.shape[0]
            if (batch_idx + 1) % 50 == 0 or (batch_idx + 1) == num_batches:
                print('iter: {}/{}'.format(batch_idx + 1, num_batches))
    if sample_count == 0:
        raise ValueError('cannot compute the {} hessian from an empty loader'.format(mode))
    hessian = hessian / sample_count
    save_path = os.path.join(exp_path, '{}_{}_hessian.pth'.format(file_prefix, mode))
    torch.save({'hessian': hessian}, save_path)
    return hessian

# Compute (and save) the feature Hessian for the forget split first, then the remaining split.
forget_hessian, remain_hessian = (
    calculate_hessian(feature_backbone, split_loader, exp_path, mode=split_mode)
    for split_loader, split_mode in ((forget_loader, 'forget'), (remain_loader, 'remain'))
)

forget hessian
iter: 20/20
remain hessian
iter: 50/176
iter: 100/176
iter: 150/176
iter: 176/176


In [8]:
# sample perturbed parameters
## currently: isotropic Gaussian perturbations around the trained weights
## TODO: perturb along the gradient direction instead
## NOTE: we can analyze the effects of sampling different perturbations and its importance

# CPU snapshot of the trained linear-head tensors; every perturbation is added to these.
trained_mixed_linear_weights = [tangent.clone().detach().to('cpu') for tangent in mixed_linear.tangents.values()]
num_of_perturbations = 500
scale_random = 0.01     # std of the Gaussian noise added to each weight entry
perturbation_seed = 0   # fixed so the saved perturbations are reproducible

# Dedicated generator: reproducible sampling without reseeding torch's global RNG.
perturb_generator = torch.Generator().manual_seed(perturbation_seed)

# using default random perturbation
perturbations = []
perturbed_weights = []
for _ in range(num_of_perturbations):
    curr_perturb = [torch.randn(weight.shape, generator=perturb_generator) * scale_random
                    for weight in trained_mixed_linear_weights]
    curr_perturbed_weight = [weight + perturb for weight, perturb in zip(trained_mixed_linear_weights, curr_perturb)]
    perturbations.append(curr_perturb)
    perturbed_weights.append(curr_perturbed_weight)
torch.save({'perturbations': perturbations}, os.path.join(exp_path, '05152024_011132_train_user_data_resnet18_cifar10_last1_perturbations.pth'))
torch.save({'perturbed_weights': perturbed_weights}, os.path.join(exp_path, '05152024_011132_train_user_data_resnet18_cifar10_last1_perturbed_weights.pth'))

In [17]:
# find out loss differences (L_forget)
# For each sampled perturbation, accumulate 2 * (L(w + delta) - L(w)) per forget batch,
# weighted by batch size, then divide by the total number of forget samples.
criterion = MSELossDiv2()
forget_loss_differences = torch.zeros(num_of_perturbations).to(device)
sample_count = 0
mixed_linear.eval()

# Snapshot of the trained (unperturbed) head. The perturbation loop overwrites the module's
# weights, so they are restored before each batch's reference loss and after the loop.
# Previously every batch after the first measured `actual_loss` with the LAST perturbation,
# and the module stayed perturbed for later cells.
trained_state_dict = mixed_linear.state_dict()
for key in trained_state_dict:
    trained_state_dict[key] = trained_state_dict[key].detach().clone()

num_batches = len(forget_loader)
with torch.no_grad():
    for batch_idx, (data, label) in enumerate(forget_loader):
        data, label = data.to(device), label.to(device)
        label = label * 5  # NOTE(review): target scaling constant — confirm it matches training
        mixed_linear.load_state_dict(trained_state_dict)
        preds = mixed_linear(feature_backbone, core_model_state_dict, data)
        actual_loss = criterion(preds, label)
        sample_count += data.shape[0]
        for perturb_idx, perturbed_weight in enumerate(perturbed_weights):
            state_dict = mixed_linear.state_dict()
            # NOTE(review): assumes the first state_dict entries are the `tangents` tensors,
            # in the same order; zip silently truncates if the counts differ.
            for key, perturbed in zip(state_dict.keys(), perturbed_weight):
                state_dict[key] = perturbed
            # load_state_dict copies into the existing (on-device) parameters,
            # so no CPU<->device round-trip of the module is needed.
            mixed_linear.load_state_dict(state_dict)
            perturbed_preds = mixed_linear(feature_backbone, core_model_state_dict, data)
            perturbed_loss = criterion(perturbed_preds, label)
            forget_loss_differences[perturb_idx] += (perturbed_loss - actual_loss) * 2 * data.shape[0]
            if (perturb_idx + 1) % 10 == 0 or (perturb_idx + 1) == len(perturbed_weights):
                print('iter: {}/{} perturb: {}/{}'.format(batch_idx + 1, num_batches, perturb_idx + 1, len(perturbed_weights)))
    # Leave the head unperturbed for the cells that follow.
    mixed_linear.load_state_dict(trained_state_dict)
    forget_loss_differences = forget_loss_differences / sample_count
    forget_loss_differences = forget_loss_differences.to('cpu')

iter: 1/20 perturb: 10/500
iter: 1/20 perturb: 20/500
iter: 1/20 perturb: 30/500
iter: 1/20 perturb: 40/500
iter: 1/20 perturb: 50/500
iter: 1/20 perturb: 60/500
iter: 1/20 perturb: 70/500
iter: 1/20 perturb: 80/500
iter: 1/20 perturb: 90/500
iter: 1/20 perturb: 100/500
iter: 1/20 perturb: 110/500
iter: 1/20 perturb: 120/500
iter: 1/20 perturb: 130/500
iter: 1/20 perturb: 140/500
iter: 1/20 perturb: 150/500
iter: 1/20 perturb: 160/500
iter: 1/20 perturb: 170/500
iter: 1/20 perturb: 180/500
iter: 1/20 perturb: 190/500
iter: 1/20 perturb: 200/500
iter: 1/20 perturb: 210/500
iter: 1/20 perturb: 220/500
iter: 1/20 perturb: 230/500
iter: 1/20 perturb: 240/500
iter: 1/20 perturb: 250/500
iter: 1/20 perturb: 260/500
iter: 1/20 perturb: 270/500
iter: 1/20 perturb: 280/500
iter: 1/20 perturb: 290/500
iter: 1/20 perturb: 300/500
iter: 1/20 perturb: 310/500
iter: 1/20 perturb: 320/500
iter: 1/20 perturb: 330/500
iter: 1/20 perturb: 340/500
iter: 1/20 perturb: 350/500
iter: 1/20 perturb: 360/500
i

In [None]:
# set optimization problem