In [1]:

import argparse
from models.resnet import resnet50
from datasets import load_cifar10, load_cifar100, load_caltech_101
import torch
import torch.nn as nn
import numpy as np
import os
import torchvision
import matplotlib.pyplot as plt
import pickle as pkl
from foolbox.attacks import LinfPGD, L2PGD, L2BasicIterativeAttack, LinfBasicIterativeAttack, L2CarliniWagnerAttack
from foolbox import PyTorchModel
from tqdm import tqdm
from sigmoid_method import exp as sigmoid_exp
from combine_method import exp as combine_exp
from utils import setup_seed,AverageMeter
setup_seed(3407)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
train_dataloader, test_dataloader, data_min, data_max = load_cifar10()

Files already downloaded and verified
Files already downloaded and verified


In [3]:
model = resnet50("cifar10")

In [4]:
attack = LinfPGD()

In [5]:
pt_model = PyTorchModel(model, bounds=(data_min-1e-8, data_max+1e-8))

In [6]:
def caculate_total(model, x, delta_x, y, steps=5, alpha=0.0025, lambda_r=0.01, op="add"):
    total = None
    delta_x = delta_x.clone()
    delta_x = torch.nn.Parameter(data=delta_x)
    loss_func = torch.nn.CrossEntropyLoss(reduction='sum')
    x.requires_grad_(False)
    delta_x.requires_grad_(True)
    model.zero_grad()
    for _ in range(steps):
        outputs = model(x+delta_x)
        model.zero_grad()
        loss = loss_func(outputs, y)
        regularization_loss = torch.norm(delta_x)
        if op == "minus":
            loss = loss + lambda_r * regularization_loss
            loss.backward(retain_graph=True)
            grad = delta_x.grad
            delta_x = delta_x - alpha * torch.sign(grad)
        elif op == "add":
            loss = loss - lambda_r * regularization_loss
            loss.backward(retain_graph=True)
            grad = delta_x.grad
            delta_x = delta_x + alpha * torch.sign(grad)
        delta_x = delta_x.detach().requires_grad_(True)
        if total == None:
            total = ((alpha * torch.sign(grad)) * grad).detach()
        else:
            total += ((alpha * torch.sign(grad)) * grad).detach()
        model.zero_grad()
    return total


def caculate_combine(total, delta, use_total=True, use_delta=True):
    if use_total and use_delta:
        combine = torch.abs(torch.abs(total) * delta)
    elif use_total:
        combine = torch.abs(total)
    elif use_delta:  # taylor
        combine = torch.abs(delta)
    combine = combine.unsqueeze(0).cpu().detach().numpy()
    combine_flatten = combine.flatten()
    return combine, combine_flatten


def get_result(model, x, pos, combine, combine_flatten, delta):
    threshold = np.sort(combine_flatten)[pos]
    delta_ = delta.clone()
    delta_[combine < threshold] = 0
    result = model(x+delta_).argmax(-1)
    return result, torch.norm(delta_).item(), delta_


def binary_search(model, x, delta, y, r, combine, combine_flatten, search_times=10):
    l = 0
    r = r
    pos = int((l + r) / 2)
    y_label = y.argmax(-1)
    for _ in range(search_times):
        if l == r:
            break
        result, norm_delta, delta_ = get_result(
            model, x, pos, combine, combine_flatten, delta)
        if result == y_label:
            l = pos
            pos = int((pos + r) / 2)
            flag = pos
            flag_norm = norm_delta
            flag_delta = delta_
        else:
            r = pos
            pos = int((pos + l) / 2)
    return flag, flag_norm, flag_delta


def exp(model, x, delta_x, y, add_steps=5, minus_steps=0, alpha=0.0025, lambda_r=0.01, method="total*delta"):
    if add_steps != 0:
        total_add = caculate_total(
            model, x.unsqueeze(0), delta_x.unsqueeze(0), y.argmax(-1).unsqueeze(0), steps=add_steps, op="add", alpha=alpha, lambda_r=lambda_r)
    if minus_steps != 0:
        total_minus = caculate_total(
            model, x.unsqueeze(0), delta_x.unsqueeze(0), y.argmax(-1).unsqueeze(0), steps=minus_steps, op="minus", alpha=alpha, lambda_r=lambda_r)
    if add_steps != 0 and minus_steps != 0:
        total = total_add + total_minus
    elif add_steps != 0:
        total = total_add
    elif minus_steps != 0:
        total = total_minus
    delta_x = delta_x.unsqueeze(0)
    if method == "total*delta":
        combine, combine_flatten = caculate_combine(
            total, delta_x, use_total=True, use_delta=True)
    elif method == "total":
        combine, combine_flatten = caculate_combine(
            total, delta_x, use_total=True, use_delta=False)
    elif method == "taylor":
        combine, combine_flatten = caculate_combine(
            total, delta_x, use_total=False, use_delta=True)
    _, flag_norm, _ = binary_search(model, x.unsqueeze(0), delta_x, y, len(
        combine_flatten) - 1, combine, combine_flatten, search_times=10)
    return flag_norm

def print_adv_pred_bar(adv_pred):
    plt.bar(list(range(10)), nn.Softmax()(adv_pred[0]).detach().numpy())

In [7]:
count = 0
total = 1
pbar = tqdm(total=total)
avg = AverageMeter()
for x, label in test_dataloader:
    x, label = x.to(device), label.to(device)
    pred = model(x)
    correct = pred.argmax(-1) == label
    x = x[correct]
    label = label[correct]
    _, adv_data, success = attack(pt_model, x, label, epsilons=0.5)
    adv_pred = model(adv_data)
    success_x = x[success]
    success_adv_data = adv_data[success]
    success_label = label[success]
    # print('success_adv_data', success_adv_data.shape)
    success_y = adv_pred[success]
    print_adv_pred_bar(pred[correct])
    print_adv_pred_bar(adv_pred)
    for x, label, delta_x, y in zip(success_x, success_label, success_adv_data-success_x, success_y):
        delta_x_norm = torch.norm(delta_x)
        try:
            norm = exp(x=x, delta_x=delta_x, y=y, model=model, add_steps=5, minus_steps=0, alpha=0.1, lambda_r=0.01, method="total*delta")
            avg.calculate(norm, delta_x_norm)
        except:
            continue
        pbar.update(1)
        count += 1
        if count == total:
            break
    if count == total:
        break
print(avg.avg)

100%|██████████| 10/10 [00:09<00:00,  2.26it/s]

tensor(0.5233, device='cuda:0')


: 