In [1]:
from datasets import load_cifar10
from torchvision import transforms as T
from finetune.models.resnet import resnet50
import torch
import torch.nn.functional as F
# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = resnet50()
# map_location lets a checkpoint that was saved from a GPU load on a CPU-only
# machine as well.
# NOTE: torch.load unpickles the file, so only load checkpoints you trust.
state_dict = torch.load('weights/cifar10_resnet50_weights.pth', map_location=device)
# Drop the leading 'module.' prefix (presumably left by a DataParallel-style
# wrapper at save time). Only a true prefix is stripped: keys that merely
# contain the substring 'module' are left untouched instead of being rewritten
# or, worse, deleted. The dict is edited in place so any metadata attached to
# the loaded object is preserved.
wrapper_prefix = 'module.'
for key in list(state_dict.keys()):
    if key.startswith(wrapper_prefix):
        state_dict[key[len(wrapper_prefix):]] = state_dict.pop(key)
model.load_state_dict(state_dict)
model = model.eval().to(device)

In [2]:
import copy
# Independent copy of the loaded model: the parameter updates performed in the
# cells below are applied to `net`, while `model` keeps the original weights.
net = copy.deepcopy(model)

In [3]:
# load_cifar10 is a project helper that returns four loader objects. Only the
# two "*_all" loaders are consumed by the cells visible in this notebook.
# NOTE(review): the first two look like collections of loaders — confirm
# against the `datasets` module.
train_dataloaders, test_dataloaders, train_dataloader_all, test_dataloader_all = load_cifar10()

Files already downloaded and verified
Files already downloaded and verified


In [4]:
def forward_backward(net, train_loader, loss_fn):
    """Run one full pass over ``train_loader`` and accumulate gradients on ``net``.

    Gradients are accumulated, not applied: ``loss.backward()`` is called for
    every batch and nothing is zeroed here, so on return each ``param.grad``
    holds the sum of the per-batch gradients. Clearing them is the caller's job.

    Args:
        net: model to evaluate; batches are moved to the device of its first
            parameter (a model without parameters leaves batches where they are).
        train_loader: iterable yielding ``(inputs, targets)`` tensor pairs.
        loss_fn: callable ``loss_fn(outputs, targets)`` returning a scalar tensor.

    Returns:
        ``(total_loss, num)`` where ``total_loss`` is the plain sum of the
        per-batch loss values and ``num`` is the number of samples seen
        (``inputs.shape[0]`` summed over batches). Whether
        ``total_loss / num`` is a per-sample mean therefore depends on the
        reduction used by ``loss_fn``. An empty loader yields ``(0, 0)``.
    """
    # Follow the model's own device instead of relying on a notebook-global
    # `device` variable, so the function works for any model it is handed.
    first_param = next(net.parameters(), None)
    target_device = first_param.device if first_param is not None else None

    total_loss = 0
    num = 0
    for x, y in train_loader:
        if target_device is not None:
            x, y = x.to(target_device), y.to(target_device)
        outputs = net(x)
        loss = loss_fn(outputs, y)
        total_loss += loss.item()
        loss.backward()
        num += x.shape[0]
    return total_loss, num

In [5]:
def update_all_params(net,alpha,parameter_names,operation):
    """Nudge selected parameters by ``alpha * sign(grad)`` and return their gradients.

    For every parameter of ``net`` whose name is in ``parameter_names`` an
    element-wise mask is built from the signs of the weight and its gradient:

    * ``'add'``   – update only entries where weight and gradient are both
      strictly positive or both strictly negative.
    * ``'minus'`` – update every other entry (the complement of the above,
      which includes entries where either value is zero).

    Masked-in entries become ``weight + alpha * sign(grad)``; the rest are
    left unchanged.

    Args:
        net: module whose parameters are updated in place (via ``param.data``).
        alpha: step size.
        parameter_names: names (as given by ``net.named_parameters()``) to update.
        operation: ``'add'`` or ``'minus'``.

    Returns:
        dict mapping each updated parameter name to a *copy* of the gradient
        that was used for the update.

    Raises:
        ValueError: if ``operation`` is not ``'add'`` or ``'minus'``.
    """
    # Validate up front so an unknown mode fails before any parameter is
    # touched (previously it surfaced as an undefined/stale `mask`).
    if operation not in ('add', 'minus'):
        raise ValueError(f"operation must be 'add' or 'minus', got {operation!r}")

    all_grads = dict()
    for name, param in net.named_parameters():
        if name not in parameter_names:
            continue
        grads = param.grad.data
        weights = param.data
        # The two sign cases are mutually exclusive, so OR is equivalent to XOR.
        same_sign = ((weights < 0) & (grads < 0)) | ((weights > 0) & (grads > 0))
        mask = same_sign if operation == 'add' else ~same_sign
        # Cast the mask to the gradient dtype so the update does not promote
        # the parameter to a different floating-point type.
        param.data = weights + alpha * grads.sign() * mask.to(grads.dtype)
        # Clone: `grads` shares storage with `param.grad`, so a later in-place
        # `zero_grad()` or another `backward()` would silently overwrite the
        # values handed back to the caller.
        all_grads[name] = grads.clone()
    return all_grads


In [6]:
def update_attribution(parameter_names,weight_attribution,num,grads_before,grads,alpha,operation):
    """Accumulate one step's contribution into the per-parameter attribution.

    For each name the contribution is
    ``alpha * sign(g_prev) * (g_prev + g_curr) / num / 2``, where ``g_prev`` and
    ``g_curr`` come from ``grads_before`` and ``grads``. With
    ``operation == 'minus'`` the sign of ``alpha`` is flipped first; any other
    value of ``operation`` leaves ``alpha`` as given.

    ``weight_attribution`` is modified in place: a ``None`` entry is replaced
    by the contribution, an existing tensor is incremented in place. The same
    dict is returned.
    """
    signed_alpha = -alpha if operation == 'minus' else alpha
    for name in parameter_names:
        prev_grad = grads_before[name]
        grad_sum = prev_grad + grads[name]
        contribution = signed_alpha * prev_grad.sign() * grad_sum / num / 2
        if weight_attribution[name] is not None:
            weight_attribution[name] += contribution
        else:
            weight_attribution[name] = contribution
    return weight_attribution

In [8]:
loss_fn = torch.nn.CrossEntropyLoss()
parameter_names = [name for name, param in net.named_parameters() if param.requires_grad]
# One accumulator per trainable parameter; None means "nothing accumulated yet".
weight_attribution = dict.fromkeys(parameter_names)
start_loss = None
end_loss = None
grads_before = None
grads = None
num_steps = 2
alpha = 1e-8
operation = 'minus'
# Each iteration accumulates gradients over the whole training loader, then
# perturbs `net`. From the second iteration on, the gradients of the previous
# and current step are combined into the attribution, so `num_steps`
# contributions are accumulated over `num_steps + 1` passes.
for i in range(num_steps + 1):
    total_loss, num = forward_backward(net, train_dataloader_all, loss_fn)
    if i == 0:
        # Loss before any perturbation. total_loss is a sum of per-batch loss
        # values, so its scale depends on the loss reduction.
        start_loss = total_loss / num
        grads_before = update_all_params(net,alpha,parameter_names,operation)
    else:
        grads = update_all_params(net,alpha,parameter_names,operation)
        weight_attribution = update_attribution(parameter_names,weight_attribution,num,grads_before,grads,alpha,operation)
        grads_before = grads
        # Loss measured on this pass, i.e. before the update applied just above.
        end_loss = total_loss / num
    # Drop the gradient tensors instead of zeroing them in place: the dicts
    # returned by update_all_params may share storage with `param.grad`, and an
    # in-place zero (the default on some PyTorch versions) followed by the next
    # backward pass would overwrite the saved "previous" gradients.
    net.zero_grad(set_to_none=True)
        

In [12]:
import numpy as np
# Score every selected parameter entry by |attribution * original weight|.
model_state_dict = model.state_dict()
combine = {
    name: (weight_attribution[name] * model_state_dict[name]).abs()
    for name in parameter_names
}
# Pool all scores into one flat NumPy array to compute a single global cut-off.
combine_flatten = torch.cat([combine[name].flatten() for name in parameter_names])
combine_flatten = combine_flatten.cpu().detach().numpy()
threshold = np.percentile(combine_flatten, 60)
# Keep only entries whose score is strictly above the 60th percentile; the
# remaining entries are zeroed in the state dict (the live `model` tensors are
# not modified, the dict entries are replaced by new tensors).
mask = {name: combine[name] > threshold for name in parameter_names}
for name in parameter_names:
    model_state_dict[name] = mask[name] * model_state_dict[name]

In [13]:
# Build a separate model that carries the masked tensors from
# `model_state_dict`; `model` itself is left as loaded.
model_copy = copy.deepcopy(model)
model_copy.load_state_dict(model_state_dict)

<All keys matched successfully>

In [14]:
# Top-1 accuracy of `model_copy` over the full test loader.
prob_batches, label_batches = [], []
with torch.no_grad():
    for x, y in test_dataloader_all:
        x, y = x.to(device), y.to(device)
        probs = F.softmax(model_copy(x), dim=1)
        prob_batches.append(probs.cpu().numpy())
        label_batches.append(y.cpu().numpy())
test_labels = np.concatenate(label_batches)
# Predicted class = index of the largest softmax probability per sample.
test_preds = np.concatenate(prob_batches).argmax(axis=1)
acc = (test_preds == test_labels).mean()
print(acc)

0.9347
