In [None]:
import random
import time

import numpy as np
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
from torch.autograd.gradcheck import zero_gradients

from DeepFool.Python.deepfool import deepfool
def get_model(device):
    """Return an ImageNet-pretrained VGG16 in evaluation mode, placed on ``device``.

    NOTE(review): newer torchvision releases deprecate ``pretrained=True`` in
    favour of ``weights=...``; confirm against the installed version.
    """
    vgg = models.vgg16(pretrained=True)
    # Both eval() and to() return the module itself, so the calls can be chained.
    return vgg.eval().to(device)
def data_input_init(xi):
    """Build the ImageNet-style preprocessing pipeline.

    Args:
        xi: unused; kept so existing call sites keep working.

    Returns:
        A tuple ``(mean, std, transform)``. ``transform`` resizes to 256,
        center-crops to 224x224, converts to a tensor and normalizes with the
        returned per-channel ``mean`` / ``std``.
    """
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    steps = [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ]
    pipeline = transforms.Compose(steps)
    return mean, std, pipeline
def proj_lp(v,xi,p):
    """Project ``v`` onto the L_p ball of radius ``xi`` centred at the origin.

    For ``p == np.inf`` every entry is clipped into ``[-xi, xi]``; for any other
    ``p`` the whole tensor is rescaled so that its p-norm is at most ``xi``
    (it is left unchanged when already inside the ball).
    """
    if p != np.inf:
        # The small epsilon keeps the division finite when v is all zeros.
        scale = xi / (torch.norm(v, p) + 0.00001)
        return v * min(1, scale)
    return torch.clamp(v, -xi, xi)
def get_fooling_rate(data_list,v,model,device):
    """Measure how often adding ``v`` changes the model's predicted label.

    Args:
        data_list: sequence of image file paths.
        v: perturbation tensor added to each preprocessed image batch of shape
            ``(1, 3, 224, 224)``; it must be on ``device`` and broadcast to it.
        model: target network.
        device: device the images are moved to.

    Returns:
        ``(fooling_rate, model)``, where ``fooling_rate`` is the fraction of
        images whose arg-max prediction differs between ``image`` and
        ``image + v``. As a side effect, every parameter of ``model`` has
        ``requires_grad`` set to ``False``.

    Raises:
        ValueError: if ``data_list`` is empty (the rate would be undefined).
    """
    num_images = len(data_list)
    if num_images == 0:
        raise ValueError('data_list must contain at least one image path')
    tf = data_input_init(0)[2]
    fooled = 0.0
    # Pure inference: skipping autograd avoids building a graph through the
    # network (and through v) for every image.
    with torch.no_grad():
        for name in tqdm(data_list):
            # Close the file promptly, and force 3 channels so grayscale, RGBA
            # or palette images match the 3-channel normalization and model input.
            with Image.open(name) as pil_img:
                image = tf(pil_img.convert('RGB'))
            image = image.unsqueeze(0).to(device)
            _, pred = torch.max(model(image), 1)
            _, adv_pred = torch.max(model(image + v), 1)
            if pred != adv_pred:
                fooled += 1
    fooling_rate = fooled / num_images
    print('Fooling Rate = ',fooling_rate)
    # Freeze the weights (original behavior, kept for callers); presumably
    # DeepFool only needs gradients with respect to the input image.
    for param in model.parameters():
        param.requires_grad = False
    return fooling_rate, model
def universal_adversarial_perturbation(data_list,model,device,xi=10,
                                       delta=0.2,max_iter_uni=10,p=np.inf,
                                       num_classes=10,overshoot=0.02,
                                       max_iter_df=10,t_p=0.2):
    '''
    Compute a universal adversarial perturbation ``v`` for ``model``.

    Each pass visits the images in random order. For every image that
    ``img + v`` does not yet fool, DeepFool proposes an extra perturbation that
    is added to ``v``, and ``v`` is projected back onto the L_p ball of radius
    ``xi``. After each full pass the fooling rate over ``data_list`` is
    measured. Passes stop once that rate reaches ``1 - delta`` or after
    ``max_iter_uni`` passes.

    data_list: image file paths (not modified; a shuffled copy is used)
    model: target model
    device: compute device
    xi: radius of the L_p ball bounding the perturbation
    delta: the target fooling rate is 1 - delta (80% by default)
    max_iter_uni: maximum number of passes over data_list
    p: norm order used for the projection (np.inf clips element-wise)
    num_classes: forwarded to deepfool; limits the number of candidate classes (default 10)
    overshoot: forwarded to deepfool as its overshoot factor
    max_iter_df: maximum number of deepfool iterations per image
    t_p: currently unused (not forwarded to deepfool); kept for compatibility
    Returns the universal perturbation tensor of shape (1, 3, 224, 224).
    '''
    _, _, tf = data_input_init(xi)
    # v is updated only by plain arithmetic, so it does not track gradients.
    # With requires_grad, every update would chain onto one ever-growing
    # autograd graph.
    v = torch.zeros(1,3,224,224).to(device)
    # Shuffle a private copy so the caller's list is left untouched.
    images = list(data_list)
    fooling_rate = 0.0
    itr = 0
    while fooling_rate < 1-delta and itr < max_iter_uni:
        random.shuffle(images)
        # One pass of incremental updates over the whole dataset.
        pbar = tqdm(images)
        pbar.set_description('Starting pass number '+str(itr))
        for k, name in enumerate(pbar):
            # Close the file promptly and force 3 channels to match the model input.
            with Image.open(name) as pil_img:
                img = tf(pil_img.convert('RGB'))
            img = img.to(device).unsqueeze(0)
            with torch.no_grad():
                _, pred = torch.max(model(img), 1)
                _, adv_pred = torch.max(model(img+v), 1)
            if pred == adv_pred:
                # v does not fool this image yet. DeepFool needs gradients, so
                # it runs outside no_grad.
                dr, df_iters, _, _, _ = deepfool((img+v).detach()[0], model, device,
                                                 num_classes=num_classes, overshoot=overshoot,
                                                 max_iter=max_iter_df)
                # df_iters is presumably DeepFool's iteration count. The step is
                # kept only when DeepFool stopped before its iteration cap.
                if df_iters < max_iter_df-1:
                    # Match v's dtype so a float64 step cannot promote v and
                    # break the next float32 forward pass.
                    v = v + torch.from_numpy(dr).to(device=device, dtype=v.dtype)
                    v = proj_lp(v, xi, p)
            if k % 10 == 0:
                pbar.set_description('Norm of v:'+str(torch.norm(v).detach().cpu().numpy()))
        # Evaluate once per full pass (not after every image); itr counts passes.
        fooling_rate, model = get_fooling_rate(images, v, model, device)
        itr = itr + 1
    return v