In [1]:
from dataset import get_user_loader, get_remaining_forget_loader, split_user_train_dataset_to_remaining_forget
from model import get_trained_linear, get_core_model_params
from train import train_accuracy_mixed_linear
from utils import params_to_device

import torch

# Preferred GPU index for this experiment. A hard-coded 'cuda:3' breaks on machines
# with fewer than 4 GPUs (the first `.to(device)` fails), so fall back to the first
# GPU in that case, and to the CPU when CUDA is unavailable altogether.
PREFERRED_CUDA_INDEX = 3
if not torch.cuda.is_available():
    device = 'cpu'
elif torch.cuda.device_count() > PREFERRED_CUDA_INDEX:
    device = 'cuda:{}'.format(PREFERRED_CUDA_INDEX)
else:
    device = 'cuda:0'

In [2]:
def add_random_noise(model, device, sigma=1e-5):
    """Perturb the tangent weights of ``model`` with Gaussian noise, in place.

    For every name in ``model.tangents``, the state-dict entry
    ``tangent_model.<name>`` is replaced by ``value + sigma**2 * N(0, 1)``. The
    result is written back with ``model.load_state_dict``.

    Args:
        model: module exposing a ``tangents`` mapping. Its keys name entries stored
            under the ``tangent_model.`` prefix in ``model.state_dict()``.
        device: kept for backward compatibility only. The noise is now drawn on
            the device (and in the dtype) of each target tensor, so a mismatch
            between this argument and the model's device no longer raises.
        sigma: noise scale. NOTE(review): the standard-normal noise is multiplied
            by ``sigma**2``, so its standard deviation is ``sigma**2`` rather than
            ``sigma``. Confirm this is intended before relying on the name.
    """
    with torch.no_grad():
        state_dict = model.state_dict()
        for name in model.tangents.keys():
            key = 'tangent_model.{}'.format(name)
            weights = state_dict[key]
            # randn_like matches the tensor's shape, device and dtype.
            state_dict[key] = weights + (sigma ** 2) * torch.randn_like(weights)
        model.load_state_dict(state_dict)

In [2]:
# Loaders over the full user train/test data, iterated in a fixed order (shuffle=False).
whole_loader, test_loader = get_user_loader(
    'cifar10', 'resnet50', 64, shuffle=False,
)

# Split the user train set into "remain" and "forget" subsets.
# NOTE(review): 0.1 is presumably the forget fraction (the checkpoints use a
# "split0.1" tag), and seed=13 must match the split used for those runs. Confirm in dataset.py.
remain_dataset, forget_dataset = split_user_train_dataset_to_remaining_forget(
    'cifar10', 'resnet50', 0.1, seed=13,
)
remain_loader, forget_loader = get_remaining_forget_loader(
    remain_dataset, forget_dataset, 64, shuffle=False,
)

In [3]:
def _run_checkpoint(run_name, suffix=''):
    """Return the path of a checkpoint file saved by a training run.

    Files live at ``checkpoint/<run_name>/<run_file><suffix>.pth``, where
    ``<run_file>`` is ``run_name`` with every '-' replaced by '_'. Deriving both
    parts from one string keeps the directory and the file name in sync.
    """
    return 'checkpoint/{}/{}{}.pth'.format(run_name, run_name.replace('-', '_'), suffix)


def _load_linear(path):
    """Load ``(feature_backbone, linear_model)`` from ``path``; the linear model is moved to ``device``."""
    backbone, linear_model = get_trained_linear(path, 'resnet50', 'cifar10', 5)
    return backbone, linear_model.to(device)


# Run directories of the checkpoints compared in this notebook.
WHOLE_DATA_RUN = '05042024-213334-train-user-data-resnet50-cifar10-last5'
REMAIN_RUN = '05092024-141054-train-user-data-resnet50-cifar10-last5-split0.1'
FORGETTING_RUN = '05092024-132921-forgetting-resnet50-cifar10-last5-split0.1'

# Which model is evaluated as `ml_scrubbing`. By default it is
# 'hess_diag_model_100_iter_fixed.pth'. Set this to True to use the FORGETTING_RUN
# checkpoint perturbed by add_random_noise(sigma=0.17) instead.
USE_NOISY_FORGETTING_RUN = False

# Core-model parameters are loaded on the CPU first, then moved to `device`.
core_model_state_dict = params_to_device(
    get_core_model_params(_run_checkpoint(WHOLE_DATA_RUN, '_core_model'), 'cpu'), device)
core_model_state_dict_remain = params_to_device(
    get_core_model_params(_run_checkpoint(REMAIN_RUN, '_core_model'), 'cpu'), device)

ml_whole_data = _load_linear(_run_checkpoint(WHOLE_DATA_RUN))[1]

if USE_NOISY_FORGETTING_RUN:
    ml_scrubbing = _load_linear(_run_checkpoint(FORGETTING_RUN))[1]
    add_random_noise(ml_scrubbing, device, sigma=0.17)
else:
    ml_scrubbing = _load_linear('hess_diag_model_100_iter_fixed.pth')[1]

# Only the backbone stored with the remain run is kept; it is used by every
# evaluation below. The backbones returned by the other loads are discarded.
# NOTE(review): confirm the backbones are identical across these checkpoints.
feature_backbone, ml_remain = _load_linear(_run_checkpoint(REMAIN_RUN))
feature_backbone = feature_backbone.to(device)

In [5]:
# Evaluate the scrubbed model on the remaining (retained) split with the
# whole-data core model. NOTE(review): the helper's name suggests it reports
# accuracy. The positional None/0/None arguments are passed exactly as before.
remain_eval = train_accuracy_mixed_linear(
    ml_scrubbing,
    remain_loader,
    feature_backbone,
    core_model_state_dict,
    None,
    0,
    device,
    None,
    save_param=False,
)
remain_eval

In [4]:
# Evaluate the scrubbed model on the forget split, with the same arguments as the
# remain-split evaluation except for the loader.
forget_eval = train_accuracy_mixed_linear(
    ml_scrubbing,
    forget_loader,
    feature_backbone,
    core_model_state_dict,
    None,
    0,
    device,
    None,
    save_param=False,
)
forget_eval

In [7]:
# Inspect the 1-iteration checkpoint (the 100-iteration one is what gets evaluated
# above). map_location='cpu' lets it load even if it was saved from a GPU that is
# not present on this machine.
hess_diag_1_iter_checkpoint = torch.load('hess_diag_model_1_iter_fixed.pth', map_location='cpu')
hess_diag_1_iter_checkpoint

In [8]:
# Inspect the whole-data run's model state dict. map_location='cpu' lets it load
# even if it was saved from a GPU that is not present on this machine. Only the
# shape of each entry is shown, to avoid dumping every weight of the network.
whole_data_checkpoint = torch.load(
    'checkpoint/05042024-213334-train-user-data-resnet50-cifar10-last5/05042024_213334_train_user_data_resnet50_cifar10_last5.pth',
    map_location='cpu',
)
{
    name: tuple(value.shape) if torch.is_tensor(value) else value
    for name, value in whole_data_checkpoint['model_state_dict'].items()
}