In [43]:
#coding:utf-8
import sys
sys.path.append("..")
import os
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets,transforms
from torch.autograd import Variable
from models.resnet import *
from models.vggnet import *
from models.mynet import *
import torch.optim as optim
import numpy as np

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [44]:
def get_adv_path(attack, dataset, model_name):
    print(dataset + "_" + model_name)
    dir_path = "adv_image"
    dic = {
           "cifar_resnet18": _get_new_adv_path(dir_path, attack, dataset, model_name),
           "cifar_vgg16": _get_new_adv_path(dir_path, attack, dataset, model_name),
           }
    print(dic[dataset + "_" + model_name])
    return dic[dataset + "_" + model_name]

def _get_new_adv_path(dir_path, attack, dataset, model_name):
    # ./adv_image/fashionLeNet1_cw_image.npy
    i = './{}/{}_{}_{}.npy'.format(dir_path, dataset + model_name, attack, 'image')
    l = './{}/{}_{}_{}.npy'.format(dir_path, dataset + model_name, attack, 'label')
    return i, l

In [45]:
def get_testcase(model_name, dataset_name, model, val_dataloader,
                 attacks=('cw', 'fgsm', 'jsma', 'bim'), adv_dataset='cifar',
                 max_clean_batches=1):
    """Build, save and return a test-case set of clean plus adversarial images.

    Takes the first ``max_clean_batches`` batches of clean samples from
    ``val_dataloader``, appends the pre-generated adversarial images/labels of
    every attack in ``attacks`` (located via ``get_adv_path``), and saves the
    result to ``images_of_TestCaseSet_{model_name}_{dataset_name}.pt`` and
    ``labels_of_TestCaseSet_{model_name}_{dataset_name}.pt`` in the current
    working directory.

    Args:
        model_name: model identifier, used for the adversarial-file lookup and
            the output file names.
        dataset_name: dataset identifier used in the output file names.
        model: unused; kept so existing callers keep working.
        val_dataloader: iterable of ``(data, label)`` batches of clean samples.
        attacks: attack names whose adversarial files are appended, in order.
        adv_dataset: dataset prefix used in the adversarial file names (they
            use ``'cifar'``, not ``dataset_name``).
        max_clean_batches: number of clean batches to keep; ``None`` keeps the
            whole loader. Defaults to 1 (the previous hard-coded behavior).

    Returns:
        ``(images, labels)`` tensors: float32 images and int64 labels.

    Raises:
        ValueError: if ``val_dataloader`` yields no clean batches.
    """
    import itertools

    # Clean samples: islice(..., None) iterates the whole loader.
    clean_images, clean_labels = [], []
    for data, label in itertools.islice(val_dataloader, max_clean_batches):
        clean_images.append(np.asarray(data))
        clean_labels.append(np.asarray(label))
    if not clean_images:
        raise ValueError("val_dataloader yielded no clean batches")
    clean_images = np.concatenate(clean_images, axis=0)
    clean_labels = np.concatenate(clean_labels, axis=0)

    # Adversarial samples for *this* model (previously hard-coded to vgg16).
    adv_images, adv_labels = [], []
    for attack in attacks:
        image_path, label_path = get_adv_path(attack, adv_dataset, model_name)
        adv_images.append(np.load(image_path))
        adv_labels.append(np.load(label_path))
    # The adversarial files are channels-last; reorder to (N, C, H, W) like the clean batches.
    adv_images = np.transpose(np.concatenate(adv_images, axis=0), (0, 3, 1, 2))
    adv_labels = np.concatenate(adv_labels, axis=0)
    print("adv: ", adv_images.shape)
    print("clean: ", clean_images.shape)

    # Uniform dtypes: float32 images match ToTensor output and the model weights;
    # the adversarial labels may be stored as floats, so cast labels to int64 class indices.
    all_test_data = torch.from_numpy(
        np.concatenate([clean_images, adv_images], axis=0).astype(np.float32, copy=False))
    all_test_label = torch.from_numpy(
        np.concatenate([clean_labels, adv_labels], axis=0).astype(np.int64, copy=False))

    print(all_test_data.shape)
    print(all_test_label.shape)
    torch.save(all_test_data, 'images_of_TestCaseSet_{}_{}.pt'.format(model_name, dataset_name))
    torch.save(all_test_label, 'labels_of_TestCaseSet_{}_{}.pt'.format(model_name, dataset_name))
    return all_test_data, all_test_label

In [46]:
# Batch size for the test loader; later cells also split the saved test-case tensors with it.
batch_size = 128

# vgg16 + cifar10: clean samples come from the CIFAR-10 test split.
val_dataset = datasets.CIFAR10(root='../dataset/data', train=False, download=False, transform=transforms.ToTensor())
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

model = vgg16_bn().to(device)
# map_location lets a checkpoint saved on a GPU load on a CPU-only kernel too.
checkpoint = torch.load('../adv_train/model-vgg16-cifar10/Standard-cifar10-model-vgg16-epoch300.pt',
                        map_location=device)
# Evaluate the standard (ordinary) model -- clean-sample accuracy.
model.load_state_dict(checkpoint)
model.eval()

get_testcase("vgg16", "cifar10", model, val_dataloader)
print("vgg16+cifar10 testcaseset ok!!!")

cifar_vgg16
('./adv_image/cifarvgg16_cw_image.npy', './adv_image/cifarvgg16_cw_label.npy')
cifar_vgg16
('./adv_image/cifarvgg16_fgsm_image.npy', './adv_image/cifarvgg16_fgsm_label.npy')
cifar_vgg16
('./adv_image/cifarvgg16_jsma_image.npy', './adv_image/cifarvgg16_jsma_label.npy')
cifar_vgg16
('./adv_image/cifarvgg16_bim_image.npy', './adv_image/cifarvgg16_bim_label.npy')
adv:  (8000, 3, 32, 32)
clean:  (128, 3, 32, 32)
torch.Size([8128, 3, 32, 32])
torch.Size([8128])
vgg16+cifar10 testcaseset ok!!!


In [66]:
# Reload the test-case set saved by get_testcase("vgg16", "cifar10", ...) from the working directory.
images, labels = (
    torch.load(path)
    for path in ('images_of_TestCaseSet_vgg16_cifar10.pt',
                 'labels_of_TestCaseSet_vgg16_cifar10.pt')
)

In [75]:
# Top-1 accuracy of `model` on the saved test-case set (clean + adversarial samples).
datalist = torch.split(images, batch_size, dim=0)
labellist = torch.split(labels, batch_size, dim=0)
correct = 0
model.eval()  # make sure dropout / batch-norm are in inference mode
with torch.no_grad():
    for data, target in zip(datalist, labellist):
        data = data.to(device)
        target = target.to(device)
        pred = model(data).argmax(dim=1)
        correct += pred.eq(target.view_as(pred)).sum().item()

test_accuracy = correct / len(images)
print('准确率:', test_accuracy)  # label reads "accuracy"

准确率: 0.8987450787401575


In [59]:
# Softmax class probabilities for every test case, computed batch by batch.
datalist = torch.split(images, batch_size, dim=0)
labellist = torch.split(labels, batch_size, dim=0)

pred_test_prob = []
# Inference only: no_grad avoids building (and holding) an autograd graph per batch.
with torch.no_grad():
    for data_batch in datalist:
        output = model(data_batch.to(device))
        # Explicit dim=1 (class axis). Calling softmax without `dim` is deprecated and warns;
        # for 2-D logits the implicit choice was dim=1, so results are unchanged.
        pred_test_prob.append(torch.softmax(output, dim=1).cpu())
pred_test_prob = torch.cat(pred_test_prob, dim=0).numpy()

pred_test = np.argmax(pred_test_prob, axis=1)

  import sys


In [60]:
# (number of test cases, number of classes)
np.shape(pred_test_prob)

(8128, 10)

In [65]:
# One line per test case: "label:<stored label>:<p_0>,<p_1>,...", probabilities to 8 decimals.
# NOTE(review): this prints every row of the test-case set; consider slicing for a readable notebook.
for idx in range(pred_test_prob.shape[0]):
    probs = ','.join('{:.8f}'.format(p) for p in pred_test_prob[idx])
    line = 'label:{}:{}'.format(labels[idx].item(), probs)
    print(line)

label:3.0:0.000,0.000,0.000,1.000,0.000,0.000,0.000,0.000,0.000,0.000
label:8.0:0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,1.000,0.000
label:8.0:0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,1.000,0.000
label:0.0:1.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000
label:6.0:0.000,0.000,0.000,0.000,0.000,0.000,1.000,0.000,0.000,0.000
label:6.0:0.000,0.000,0.000,0.000,0.000,0.000,1.000,0.000,0.000,0.000
label:1.0:0.000,0.999,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000
label:6.0:0.000,0.000,0.000,0.000,0.000,0.000,1.000,0.000,0.000,0.000
label:3.0:0.000,0.000,0.000,1.000,0.000,0.000,0.000,0.000,0.000,0.000
label:1.0:0.000,0.999,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.001
label:0.0:1.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000
label:9.0:0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,1.000
label:5.0:0.000,0.000,0.000,0.000,0.000,1.000,0.000,0.000,0.000,0.000
label:7.0:0.000,0.000,0.000,0.000,0.000,0.000,0.000,1.000,0.000,0.000
label:9.0:0.000,0.00

label:7.0:0.000,0.000,0.000,0.000,0.000,0.000,0.000,1.000,0.000,0.000
label:7.0:0.000,0.000,0.000,0.000,0.000,0.000,0.000,1.000,0.000,0.000
label:7.0:0.000,0.000,0.000,0.000,0.000,0.000,0.000,1.000,0.000,0.000
label:7.0:0.000,0.000,0.000,0.000,0.000,0.000,0.000,1.000,0.000,0.000
label:7.0:0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.999,0.000,0.000
label:7.0:0.000,0.000,0.000,0.000,0.000,0.000,0.000,1.000,0.000,0.000
label:7.0:0.000,0.000,0.000,0.000,0.000,0.000,0.000,1.000,0.000,0.000
label:7.0:0.000,0.000,0.000,0.006,0.000,0.002,0.000,0.991,0.000,0.000
label:7.0:0.000,0.000,0.000,0.000,0.000,0.000,0.000,1.000,0.000,0.000
label:7.0:0.000,0.000,0.000,0.000,0.000,0.001,0.000,0.999,0.000,0.000
label:7.0:0.000,0.000,0.000,0.000,0.000,0.000,0.000,1.000,0.000,0.000
label:7.0:0.000,0.000,0.000,0.000,0.000,0.000,0.000,1.000,0.000,0.000
label:7.0:0.000,0.000,0.000,0.000,0.000,0.001,0.000,0.998,0.000,0.000
label:7.0:0.000,0.000,0.000,0.001,0.000,0.954,0.000,0.045,0.000,0.000
label:7.0:0.000,0.00

label:5.0:0.000,0.000,0.000,0.000,0.000,1.000,0.000,0.000,0.000,0.000
label:5.0:0.000,0.000,0.000,0.000,0.000,1.000,0.000,0.000,0.000,0.000
label:5.0:0.000,0.000,0.000,0.000,0.000,0.999,0.000,0.000,0.000,0.000
label:5.0:0.000,0.000,0.000,0.000,0.000,1.000,0.000,0.000,0.000,0.000
label:5.0:0.001,0.000,0.040,0.908,0.006,0.031,0.010,0.002,0.001,0.001
label:5.0:0.000,0.000,0.000,0.000,0.000,1.000,0.000,0.000,0.000,0.000
label:5.0:0.000,0.000,0.000,0.000,0.000,1.000,0.000,0.000,0.000,0.000
label:5.0:0.000,0.000,0.000,0.001,0.000,0.999,0.000,0.000,0.000,0.000
label:5.0:0.000,0.000,0.000,0.000,0.000,1.000,0.000,0.000,0.000,0.000
label:5.0:0.000,0.000,0.000,0.000,0.000,1.000,0.000,0.000,0.000,0.000
label:5.0:0.000,0.000,0.000,0.000,0.000,1.000,0.000,0.000,0.000,0.000
label:5.0:0.000,0.000,0.000,0.000,0.000,1.000,0.000,0.000,0.000,0.000
label:5.0:0.000,0.000,0.000,0.000,0.000,1.000,0.000,0.000,0.000,0.000
label:5.0:0.002,0.000,0.054,0.923,0.001,0.014,0.002,0.002,0.001,0.001
label:5.0:0.000,0.00

label:2.0:0.000,0.000,1.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000
label:2.0:0.002,0.003,0.114,0.157,0.007,0.682,0.006,0.027,0.002,0.002
label:2.0:0.000,0.000,0.003,0.000,0.001,0.000,0.000,0.995,0.000,0.001
label:2.0:0.007,0.000,0.980,0.004,0.002,0.002,0.002,0.001,0.002,0.000
label:2.0:0.000,0.000,1.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000
label:2.0:0.000,0.000,0.038,0.000,0.000,0.000,0.000,0.959,0.000,0.001
label:2.0:0.000,0.000,1.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000
label:2.0:0.000,0.000,1.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000
label:2.0:0.242,0.001,0.409,0.003,0.001,0.000,0.339,0.001,0.002,0.001
label:2.0:0.000,0.000,1.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000
label:2.0:0.000,0.000,1.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000
label:2.0:0.000,0.000,1.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000
label:2.0:0.000,0.000,0.999,0.000,0.000,0.000,0.000,0.000,0.000,0.000
label:2.0:0.000,0.000,1.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000
label:2.0:0.000,0.00

label:0.0:0.031,0.003,0.005,0.002,0.003,0.001,0.000,0.001,0.947,0.006
label:0.0:1.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000
label:0.0:1.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000
label:0.0:0.997,0.000,0.000,0.000,0.001,0.000,0.000,0.000,0.001,0.000
label:0.0:1.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000
label:0.0:1.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000
label:0.0:1.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000
label:0.0:1.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000
label:0.0:0.381,0.001,0.001,0.000,0.004,0.000,0.000,0.000,0.604,0.008
label:0.0:1.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000
label:0.0:1.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000
label:0.0:1.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000
label:0.0:1.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000
label:0.0:1.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000
label:0.0:1.000,0.00

label:8.0:0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,1.000,0.000
label:8.0:0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,1.000,0.000
label:8.0:0.003,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.996,0.001
label:8.0:0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,1.000,0.000
label:8.0:0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,1.000,0.000
label:8.0:0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,1.000,0.000
label:8.0:0.011,0.018,0.000,0.001,0.001,0.001,0.000,0.000,0.940,0.028
label:8.0:0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,1.000,0.000
label:8.0:0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,1.000,0.000
label:8.0:0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,1.000,0.000
label:8.0:0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,1.000,0.000
label:8.0:0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,1.000,0.000
label:8.0:0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,1.000,0.000
label:8.0:0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,1.000,0.000
label:8.0:0.000,0.00