In [None]:
import copy
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn

from train import prepare_cifar
from model import ResNet18

print(torch.cuda.is_available())

In [None]:
# Test-split loader and the device every model and batch is placed on.
_, testloader = prepare_cifar()
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Let cuDNN autotune convolution algorithms. Referenced through `torch` so no
# separate `cudnn` import is required.
torch.backends.cudnn.benchmark = True

# Load the five trained models from ./checkpoint/ckpt_0.pth ... ckpt_4.pth.
net_set = []
for i in range(5):
    # map_location lets checkpoints saved from a GPU run load on a CPU-only machine.
    # NOTE(review): torch.load unpickles the file -- only load checkpoints you trust.
    checkpoint = torch.load("./checkpoint/ckpt_{}.pth".format(i), map_location=device)
    # Wrapped in DataParallel *before* loading, as originally done -- presumably the
    # state_dict was saved from a DataParallel model ("module."-prefixed keys); verify.
    net = torch.nn.DataParallel(ResNet18().to(device))
    net.load_state_dict(checkpoint["net"])
    net.eval()
    net_set.append(net)
    # The checkpoint's recorded "acc" and "epoch" values.
    print("{} {}".format(checkpoint["acc"], checkpoint["epoch"]))

In [None]:
import copy
import numpy as np
def _error_overlap(pred1, pred2, gt):
    """Overlap statistics of two models' mistakes on the samples either gets wrong.

    Returns ``(agree, err_iou)``: among samples misclassified by at least one model,
    the fraction where both predict the same label, and the fraction where both are
    wrong (IoU of the two error sets). Both are NaN when neither model errs.
    """
    wrong1 = pred1 != gt
    wrong2 = pred2 != gt
    either_wrong = wrong1 | wrong2
    n_union = int(np.count_nonzero(either_wrong))
    if n_union == 0:
        # No mistakes at all: the ratios are undefined (was a ZeroDivisionError).
        return float("nan"), float("nan")
    same_label = int(np.count_nonzero(either_wrong & (pred1 == pred2)))
    both_wrong = int(np.count_nonzero(wrong1 & wrong2))
    return same_label / n_union, both_wrong / n_union


def bad_case(net_set, loader=None, run_device=None):
    """Print and return pairwise error-overlap statistics for every pair of models.

    All models are run on the same batches in a single pass over ``loader``, so the
    predictions stay aligned with the labels even if the loader shuffles.

    For a pair ``(i, j)`` and the samples at least one of them misclassifies:
      * ``agree``   -- fraction where both predict the same label;
      * ``err_iou`` -- fraction where both are wrong (IoU of the error sets).
    Both are NaN when neither model makes a mistake.

    Args:
        net_set: sequence of classifiers whose outputs have the class scores on dim 1;
            each is switched to eval mode.
        loader: iterable of ``(inputs, targets)`` batches; defaults to the notebook's
            ``testloader``.
        run_device: device inputs are moved to; defaults to the notebook's ``device``.

    Returns:
        List of ``(i, j, agree, err_iou)`` tuples with ``i < j``; empty if the
        loader yields no batches.
    """
    if loader is None:
        loader = testloader
    if run_device is None:
        run_device = device

    for net in net_set:
        net.eval()

    # Collect per-batch arrays and concatenate once (repeated concatenate is quadratic).
    labels = []
    preds = [[] for _ in net_set]
    with torch.no_grad():
        for inputs, targets in loader:
            labels.append(targets.cpu().numpy())
            inputs = inputs.to(run_device)
            for k, net in enumerate(net_set):
                _, pred = net(inputs).max(1)
                preds[k].append(pred.cpu().numpy())

    if not labels:
        return []
    gt = np.concatenate(labels)
    pred_results = [np.concatenate(p) for p in preds]

    scores = []
    for i in range(len(net_set)):
        for j in range(i + 1, len(net_set)):
            agree, err_iou = _error_overlap(pred_results[i], pred_results[j], gt)
            print(agree, err_iou)
            scores.append((i, j, agree, err_iou))
    return scores
            
# For every pair of loaded models, print statistics about how their mistakes on the test set overlap.
bad_case(net_set)

In [None]:
# Per-channel (R, G, B) normalisation statistics used to invert the input transform.
# NOTE(review): these must match the Normalize() used by prepare_cifar -- confirm.
mean_channel = [0.4914, 0.4822, 0.4465]
std_channel = [0.2023, 0.1994, 0.2010]


def restore_fig(img):
    """Undo per-channel normalisation of a (3, H, W) image and return (H, W, 3) uint8.

    Args:
        img: normalised image, channels first; a numpy array or a torch tensor
            (any device). The input is not modified.

    Returns:
        ``uint8`` array in HWC layout for ``imshow``. Pixel values are rounded to the
        nearest integer and clipped to [0, 255], so perturbed images that leave the
        valid range saturate instead of wrapping around on the uint8 cast.
    """
    if isinstance(img, torch.Tensor):
        img = img.detach().cpu().numpy()
    img = np.asarray(img, dtype=np.float64)  # new array: caller's data untouched
    mean = np.asarray(mean_channel).reshape(-1, 1, 1)
    std = np.asarray(std_channel).reshape(-1, 1, 1)
    pixels = 255.0 * (img * std + mean)
    pixels = np.clip(np.rint(pixels), 0, 255).astype(np.uint8)
    return pixels.transpose(1, 2, 0)  # CHW -> HWC

def plot_img(img, title=None):
    """Draw one image in its own small (2x2 inch) figure.

    Args:
        img: array accepted by ``imshow``, e.g. the (H, W, 3) uint8 output of
            ``restore_fig``.
        title: optional axes title.

    Returns:
        The matplotlib ``Axes`` the image was drawn on.
    """
    fig, ax = plt.subplots(figsize=(2, 2))
    ax.imshow(img)
    ax.axis("off")  # pixel-coordinate ticks carry no information here
    if title is not None:
        ax.set_title(title)
    return ax

def _per_sample_unit(t):
    """Return ``t`` with every sample (slice along dim 0) scaled to unit L2 norm.

    All-zero samples stay zero instead of becoming NaN (the norm is clamped at 1e-12).
    """
    norms = t.reshape(t.shape[0], -1).norm(dim=1).clamp_min(1e-12)
    return t / norms.view(-1, *([1] * (t.dim() - 1)))


def pertube(net, inps, targets, eps, iter_time=10, targeted=False):
    """Build an L2-normalised, gradient-based perturbation of ``inps`` against ``net``.

    Each of ``iter_time`` steps takes the gradient of the cross-entropy loss w.r.t.
    the perturbed input and adds a per-sample step of L2 norm ``eps / 3 * length``
    along it. It then rescales the accumulated perturbation to L2 norm
    ``eps * length`` per sample. ``length`` is the number of elements in one sample
    (3072 for 3x32x32 inputs).

    Gradients are taken w.r.t. the input only, so the model's parameter ``.grad``
    buffers are not touched. The model's train/eval mode is left as the caller set it.

    Args:
        net: model mapping ``inps`` to logits.
        inps: input batch, sample dimension first.
        targets: class labels the loss is computed against.
        eps: perturbation size parameter.
        iter_time: number of gradient steps (default 10).
        targeted: False (default) steps *up* the loss, pushing predictions away
            from ``targets``. True steps *down*, pulling predictions toward ``targets``.

    Returns:
        Detached tensor shaped like ``inps`` on the same device and dtype.

    NOTE(review): ``experiment_attack`` passes random labels; if a targeted attack
    toward them was intended, ``targeted=True`` is the matching mode. Also, the radius
    ``eps * length`` grows linearly with input size. If a per-element RMS of ``eps``
    was intended, it would be ``eps * sqrt(length)`` -- confirm.
    """
    length = float(inps.shape[1:].numel())  # elements per sample; 32*32*3 for CIFAR
    step_norm = eps / 3 * length
    radius = eps * length
    criterion = torch.nn.CrossEntropyLoss()
    pert = torch.zeros_like(inps)

    for _ in range(iter_time):
        adv = (inps + pert).detach().requires_grad_(True)
        loss = criterion(net(adv), targets)
        (grad,) = torch.autograd.grad(loss, adv)
        if targeted:
            grad = -grad
        pert = pert + step_norm * _per_sample_unit(grad)
        # Project back onto the L2 sphere of radius `radius` (per sample).
        pert = radius * _per_sample_unit(pert)

    return pert

    
def experiment_attack(net_set, eps_values=(5e-4, 1e-3, 2e-3, 5e-3, 1e-2, 2e-2, 3e-2, 5e-2),
                      num_models=2, loader=None, seed=None, num_classes=10):
    """Measure how perturbations crafted on one model affect it and another model.

    For every ``eps`` and every pair ``i < j`` among the first ``num_models`` models,
    each batch is perturbed with ``pertube(net_set[i], ...)`` using uniformly random
    labels in ``[0, num_classes)``. The accuracy against the true labels of model
    ``i`` (the model the perturbation was built on) and model ``j`` (transfer) is
    printed. The first image of the first batch is plotted clean and perturbed.

    Args:
        net_set: list of classifiers; all are switched to eval mode.
        eps_values: perturbation sizes to sweep (default: the original sweep).
        num_models: attack only pairs among the first ``num_models`` models
            (default 2, i.e. the single pair (0, 1)).
        loader: iterable of ``(inputs, targets)``; defaults to the notebook's
            ``testloader``.
        seed: seed for the random labels; None gives a non-reproducible run.
        num_classes: number of classes the random labels are drawn from.

    Returns:
        List of ``(eps, i, j, acc_i, acc_j)`` tuples.
    """
    if loader is None:
        loader = testloader
    rng = np.random.default_rng(seed)
    for net in net_set:
        net.eval()
    n_models = min(num_models, len(net_set))

    results = []
    for eps in eps_values:
        for i in range(n_models):
            for j in range(i + 1, n_models):
                correct_i = 0
                correct_j = 0
                total = 0
                for batch_idx, (inputs, targets) in enumerate(loader):
                    clean = inputs.to(device)
                    labels = targets.to(device)
                    # NOTE(review): a random label can coincide with the true label.
                    fake = torch.as_tensor(
                        rng.integers(0, num_classes, size=targets.shape[0]),
                        dtype=torch.int64, device=device)
                    pert = pertube(net_set[i], clean, fake, eps)
                    # Out-of-place add: on CPU `.to(device)` returns `inputs` itself,
                    # so `+=` would also overwrite the clean batch plotted below.
                    adv = clean + pert
                    if batch_idx == 0:
                        plot_img(restore_fig(inputs[0].detach().cpu().numpy()))
                        plot_img(restore_fig(adv[0].detach().cpu().numpy()))

                    # Evaluation only: no autograd graph needed.
                    with torch.no_grad():
                        _, pred_i = net_set[i](adv).max(1)
                        _, pred_j = net_set[j](adv).max(1)
                    correct_i += pred_i.eq(labels).sum().item()
                    correct_j += pred_j.eq(labels).sum().item()
                    total += targets.shape[0]

                # Normalise by the samples actually seen (was a hard-coded 10000).
                acc_i = correct_i / total if total else float("nan")
                acc_j = correct_j / total if total else float("nan")
                print(eps, i, j, acc_i, acc_j)
                results.append((eps, i, j, acc_i, acc_j))
    return results

                    
                
# Craft perturbations on one model with random labels, then print the accuracy of that model and of a second model on the perturbed test images, for each eps.
experiment_attack(net_set)