In [1]:
import torch
from model import get_core_model_params, get_trained_linear, init_pretrained_model, split_model_to_feature_linear, freeze, thaw
import torch.nn as nn
from train import test_mixed_linear
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
from dataset import get_user_loader

from torch.func import functional_call
import torch.autograd.forward_ad as fwAD

from loss import L2Regularization, LossWrapper
from utils import params_to_device

from torch.optim import SGD
from dataset import split_user_train_dataset_to_remaining_forget, get_remaining_forget_loader

In [2]:
# loading core model and linearized model -- init in cpu
pretrained_model = init_pretrained_model('resnet50', 'cifar10')
_, linearized_head_core, __ = split_model_to_feature_linear(pretrained_model, 5, None, send_params_to_device=False)
core_model_state_dict = get_core_model_params('checkpoint/05042024-184850-train-user-data-resnet50-cifar10-last5/05042024_184850_train_user_data_resnet50_cifar10_last5_core_model.pth', 'cpu')
feature_backbone, mixed_linear = get_trained_linear('checkpoint/05042024-184850-train-user-data-resnet50-cifar10-last5/05042024_184850_train_user_data_resnet50_cifar10_last5.pth', 'resnet50', 'cifar10', 5)
del _
del __
# _, test_loader = get_user_loader('cifar10', 'resnet50', 256)
# test_mixed_linear(mixed_linear, test_loader, feature_backbone, core_model_state_dict, None, None, 0, device, None, None, None, save_param=False)

In [3]:
v_param = {key: torch.randn_like(value, device='cpu') for key, value in core_model_state_dict.items()} ## init in cpu

class JVPNormLoss(nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, feature_backbone, arch, primals, tangents, inp):
        with torch.no_grad():
            inp = feature_backbone(inp)

        dual_params = {}
        with fwAD.dual_level():
            for name, p in primals.items():
                dual_params[name] = fwAD.make_dual(p, tangents[name])
            out = functional_call(arch, dual_params, inp)
            jvp = fwAD.unpack_dual(out).tangent
        return 2 * (torch.norm(jvp) ** 2 / inp.shape[0])

def calculate_gradient(feature_backbone, core_model_state_dict, model, loss_fnc, data_loader, device):
    grads = [torch.zeros_like(param) for param in model.parameters()]
    sample_count = 0
    thaw(model)
    for iter_idx, (inp, target) in enumerate(data_loader):
        model.zero_grad()
        inp = inp.to(device)
        target = target.to(device)
        curr_loss = loss_fnc(model(feature_backbone, core_model_state_dict, inp), target, model.parameters())
        curr_loss.backward()
        for idx, param in enumerate(model.parameters()):
            grads[idx] += (param.grad * inp.shape[0])
        sample_count += inp.shape[0]
        if iter_idx == 0 or (iter_idx + 1) % 50 == 0 or (iter_idx + 1) == len(data_loader):
            print('iter: {}/{}'.format(iter_idx + 1, len(data_loader)))
    freeze(model)
    
    last = []
    for grad in grads:
        tmp = grad / sample_count
        tmp.requires_grad = False
        last.append(tmp)
    return last   
    
class GradientVectorInnerProduct(nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, grads, vector_values):
        grad_vector_inner_product_sum = None
        for param, vector_value in zip(grads, vector_values):
            if grad_vector_inner_product_sum is None:
                grad_vector_inner_product_sum = torch.sum(param * vector_value)
            else:
                grad_vector_inner_product_sum += torch.sum(param * vector_value)       
        return grad_vector_inner_product_sum

In [4]:
feature_backbone = feature_backbone.to(device)
freeze(feature_backbone)

mixed_linear = mixed_linear.to(device)
freeze(mixed_linear)

linearized_head_core = linearized_head_core.to(device)
freeze(linearized_head_core)

core_model_state_dict = params_to_device(core_model_state_dict, device)
v_param = params_to_device(v_param, device)
for param in v_param.values():
    param.requires_grad = True

In [5]:
remaining_dataset, forget_dataset = split_user_train_dataset_to_remaining_forget('cifar10', 'resnet50', 0.001)
remain_loader, forget_loader = get_remaining_forget_loader(remaining_dataset, forget_dataset, 256)

Files already downloaded and verified


In [6]:
main_criterion = LossWrapper([nn.MSELoss(), L2Regularization()], [1, 0.0005])
grads = calculate_gradient(feature_backbone, core_model_state_dict, mixed_linear, main_criterion, remain_loader, device)

iter: 1/196


In [None]:
jvp_norm_criterion = JVPNormLoss()
gradient_vector_inner_product_criterion = GradientVectorInnerProduct()
regularizor_criterion = L2Regularization()


optimizer = SGD(v_param.values(), lr=0.001, momentum=0.999)

for epoch in range(50):
    if (epoch + 1) in [10, 20, 30, 40]:
        optimizer.param_groups[0]['lr'] = optimizer.param_groups[0]['lr'] * 0.5
    for iter_idx, (data, label) in enumerate(remain_loader):
        data, label = data.to(device), label.to(device)
        optimizer.zero_grad()
        jvp_norm_loss = 0.5 * jvp_norm_criterion(feature_backbone, linearized_head_core, core_model_state_dict, v_param, data)
        gradient_vector_inner_product_loss = gradient_vector_inner_product_criterion(grads, v_param.values())
        regularizor_loss = 0.0005 * regularizor_criterion(v_param.values())
        loss = jvp_norm_loss + regularizor_loss - gradient_vector_inner_product_loss
        loss.backward()
        optimizer.step()
        if iter_idx == 0 or (iter_idx + 1) % 50 == 0 or (iter_idx + 1) == len(remain_loader):
            print('epoch: {}/{}, iter: {}/{}, loss: {}'.format(epoch + 1, 3, iter_idx + 1, len(remain_loader), loss.item()))


epoch: 1/3, iter: 1/781, loss: 6417.1826171875
epoch: 1/3, iter: 50/781, loss: 5631.30322265625
epoch: 1/3, iter: 100/781, loss: 5201.34228515625
epoch: 1/3, iter: 150/781, loss: 5716.8359375
epoch: 1/3, iter: 200/781, loss: 5611.5625
epoch: 1/3, iter: 250/781, loss: 5057.02880859375
epoch: 1/3, iter: 300/781, loss: 4747.9970703125
epoch: 1/3, iter: 350/781, loss: 5443.9248046875
epoch: 1/3, iter: 400/781, loss: 5663.1572265625
epoch: 1/3, iter: 450/781, loss: 4738.43505859375
epoch: 1/3, iter: 500/781, loss: 4685.9921875
epoch: 1/3, iter: 550/781, loss: 5159.494140625
epoch: 1/3, iter: 600/781, loss: 5132.50146484375
epoch: 1/3, iter: 650/781, loss: 4724.9716796875
epoch: 1/3, iter: 700/781, loss: 4728.07763671875
epoch: 1/3, iter: 750/781, loss: 4944.23681640625
epoch: 1/3, iter: 781/781, loss: 4838.33837890625
epoch: 2/3, iter: 1/781, loss: 4797.712890625
epoch: 2/3, iter: 50/781, loss: 4491.76611328125
epoch: 2/3, iter: 100/781, loss: 4450.587890625
epoch: 2/3, iter: 150/781, loss:

KeyboardInterrupt: 

In [12]:
v_param

{'0.conv1.weight': tensor([[[[-0.7885]],
 
          [[-1.1451]],
 
          [[-2.0049]],
 
          ...,
 
          [[-0.1389]],
 
          [[ 0.3950]],
 
          [[-1.6776]]],
 
 
         [[[-0.3147]],
 
          [[ 0.5587]],
 
          [[-1.0021]],
 
          ...,
 
          [[-0.7139]],
 
          [[-0.8996]],
 
          [[-0.0274]]],
 
 
         [[[ 0.3187]],
 
          [[ 0.3746]],
 
          [[ 0.1337]],
 
          ...,
 
          [[ 0.2187]],
 
          [[-0.3412]],
 
          [[-1.1021]]],
 
 
         ...,
 
 
         [[[-0.7907]],
 
          [[ 0.1959]],
 
          [[-0.3407]],
 
          ...,
 
          [[ 1.3172]],
 
          [[ 0.0259]],
 
          [[-0.3442]]],
 
 
         [[[ 0.4694]],
 
          [[ 1.0574]],
 
          [[-0.8297]],
 
          ...,
 
          [[ 0.4758]],
 
          [[ 0.4577]],
 
          [[ 0.4569]]],
 
 
         [[[-0.9990]],
 
          [[ 0.3037]],
 
          [[ 1.5692]],
 
          ...,
 
          [[ 0.6154]],

In [13]:
forgetted = {name: first - second for name, first, second in zip(mixed_linear.tangents.keys(), mixed_linear.tangents.values(), v_param.values())}
mixed_linear.tangents = forgetted
_, test_loader = get_user_loader('cifar10', 'resnet50', 256)
test_mixed_linear(mixed_linear, test_loader, feature_backbone, core_model_state_dict, None, None, 0, device, None, None, None, save_param=False)

Files already downloaded and verified
Files already downloaded and verified
test iter - processing: 1/40
test iter - processing: 25/40
test iter - processing: 40/40
epoch: 1/50, test accuracy: 0.1583


(None, None)