In [None]:
'''
reference from https://github.com/sky4689524/Pytorch_AdversarialAttacks
paper : https://arxiv.org/abs/1703.08603
'''

In [32]:
import numpy as np
import torch
from torchvision import transforms
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.functional as F

import pickle
import random
import sys
import os
from typing import Any, Tuple
from PIL import Image
from torchvision.datasets import VOCSegmentation
from dag import DAG
from dag_utils import generate_target, generate_target_swap
from util import make_one_hot

from optparse import OptionParser

# Attack variant to run: 'DAG_A', 'DAG_B' or 'DAG_C'
# (selects how DAG_Attack builds the adversarial target).
dag = 'DAG_A'
# Number of segmentation classes used for the one-hot labels
# (PASCAL VOC: background at index 0 plus 20 object classes).
n_classes = 21
# Batch size for the DataLoader.
BATCH_SIZE = 8

In [33]:
class VOCSegDataset(VOCSegmentation):
    """PASCAL VOC segmentation dataset yielding ``(image, label)`` tensors.

    Unlike the parent class, ``transform`` is applied to the image and
    ``target_transform`` to the mask separately. The mask is then post-processed
    so that it holds integer class indices, with every value >= 255 mapped to 0.
    """

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """Load sample ``index`` and return ``(image, label)``.

        Args:
            index: position of the sample in ``self.images`` / ``self.targets``.

        Returns:
            ``(image, label)``, where ``image`` is ``self.transform`` applied to
            the RGB-converted image and ``label`` is a ``torch.long`` tensor of
            class indices in which every value >= 255 has been replaced by 0.
        """
        image = Image.open(self.images[index]).convert('RGB')
        label = Image.open(self.targets[index])
        image = self.transform(image)
        label = self.target_transform(label)
        # The target_transform is expected to end in ToTensor(), which divides
        # 8-bit masks by 255; undo that scaling. Round before the integer cast:
        # ``.long()`` truncates toward zero, so a float product such as
        # 14.999999 would otherwise silently become class 14 instead of 15.
        label = torch.round(label * 255)
        # Value 255 is presumably VOC's boundary/"void" marker; fold it (and
        # anything above) into class 0 so every pixel holds a usable index.
        label = torch.where(label >= 255, torch.zeros_like(label), label)
        return image, label.long()
        
# Image preprocessing: resize to 256x256, convert to a float tensor and
# normalize each channel with mean 0.5 / std 0.5.
image_transforms = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
# Mask preprocessing: nearest-neighbour resizing so class indices are never
# interpolated into in-between values, then conversion to a tensor.
target_transform = transforms.Compose([
    transforms.Resize((256, 256), interpolation=transforms.InterpolationMode.NEAREST),
    transforms.ToTensor(),
])

# NOTE(review): the VOC 2012 'trainval' split is used as the attack/test set;
# confirm that this is intended.
test_dataset = VOCSegDataset('./data', year='2012', download=False, image_set='trainval',
                             transform=image_transforms, target_transform=target_transform)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

# torch.load unpickles the whole model object, which can execute arbitrary code:
# only load checkpoints from a trusted source.
# NOTE(review): PyTorch >= 2.6 defaults to weights_only=True, which rejects fully
# pickled modules; pass weights_only=False there if this checkpoint is trusted.
# map_location='cpu' lets a GPU-saved checkpoint load on CPU-only machines; the
# model is moved to the attack device later.
model = torch.load('./UNet_with_dice.pth', map_location='cpu')
# Inference mode: BatchNorm layers use their running statistics instead of
# per-sample statistics, and stop updating them while the attack runs.
model.eval()
print(model)

UNet(
  (inc): double_conv(
    (conv): Sequential(
      (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
  )
  (down1): contraction_path(
    (contract): Sequential(
      (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (1): double_conv(
        (conv): Sequential(
          (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
          (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (4): BatchNorm2d(128, eps=1e-05, momentum=0

In [34]:

def DAG_Attack(model, test_dataset, num_iterations=15, gamma=0.2, dag_type=None, device=None):
    """Run the DAG attack on every sample of ``test_dataset``.

    Each sample is attacked on its own (a batch of one). The adversarial target
    depends on ``dag_type``:

    * ``'DAG_A'``: an all-zero tensor shaped like the one-hot label;
    * ``'DAG_B'``: the target produced by ``generate_target_swap``;
    * ``'DAG_C'``: the target produced by ``generate_target`` for one randomly
      chosen non-background class present in the label. Samples whose label
      contains only background (class 0) are skipped, since there is no class
      to target.

    Args:
        model: segmentation network forwarded to ``DAG``. It is moved to
            ``device`` and switched to eval mode (both in place).
        test_dataset: indexable dataset yielding ``(image, label)`` tensor pairs;
            ``image`` has no batch dimension (one is added here).
        num_iterations: maximum number of DAG iterations per sample.
        gamma: DAG step-size hyperparameter, forwarded to ``DAG``.
        dag_type: attack variant; ``None`` uses the module-level ``dag`` setting.
        device: torch device to run on; ``None`` picks CUDA when available and
            otherwise falls back to the CPU.

    Returns:
        list of ``[adversarial_image, pure_label]`` pairs. ``adversarial_image``
        is the last entry of the ``image_iteration`` sequence returned by
        ``DAG``, and ``pure_label`` is ``label.squeeze(0)`` as a numpy array.
        Samples for which ``DAG`` returned no iterations are omitted.

    Raises:
        ValueError: if ``dag_type`` is not 'DAG_A', 'DAG_B' or 'DAG_C'.
    """
    if dag_type is None:
        dag_type = dag
    # Validate once up front instead of failing midway through the dataset.
    if dag_type not in ('DAG_A', 'DAG_B', 'DAG_C'):
        raise ValueError("wrong adversarial attack types : must be DAG_A, DAG_B, or DAG_C")

    if device is None:
        if torch.cuda.is_available():
            device = torch.device('cuda')
        else:
            print('NOCUDAAVAILABLE')
            device = torch.device('cpu')

    model = model.to(device)
    # Attack the model in inference mode so BatchNorm uses running statistics
    # and the attack's forward passes do not mutate them.
    model.eval()

    adversarial_examples = []

    for sample_idx in range(len(test_dataset)):
        image, label = test_dataset[sample_idx]
        pure_label = label.squeeze(0).numpy()

        # Add a batch dimension and make the image differentiable for DAG.
        image = image.unsqueeze(0).clone().detach().requires_grad_(True).float().to(device)
        label = label.clone().detach().float().to(device)

        # Change labels from [batch_size, height, width] to
        # [batch_size, num_classes, height, width].
        label_oh = make_one_hot(label.long(), n_classes, device)

        if dag_type == 'DAG_A':
            adv_target = torch.zeros_like(label_oh)
        elif dag_type == 'DAG_B':
            adv_target = torch.from_numpy(generate_target_swap(label_oh.cpu().numpy())).float()
        else:  # 'DAG_C'
            # Pick one random class present in the label, excluding background
            # (0) explicitly rather than assuming it is the first unique value.
            present_classes = torch.unique(label)
            candidates = present_classes[present_classes != 0]
            if candidates.numel() == 0:
                continue
            target_class = int(random.choice(candidates.tolist()))
            adv_target = torch.from_numpy(
                generate_target(label_oh.cpu().numpy(), target_class=target_class)
            ).float()

        adv_target = adv_target.to(device)

        _, _, _, _, _, image_iteration = DAG(model=model,
                                             image=image,
                                             ground_truth=label_oh,
                                             adv_target=adv_target,
                                             num_iterations=num_iterations,
                                             gamma=gamma,
                                             no_background=True,
                                             background_class=0,
                                             device=device,
                                             verbose=False)

        if len(image_iteration) >= 1:
            adversarial_examples.append([image_iteration[-1], pure_label])

        del image_iteration

    print('total {} images are generated'.format(len(adversarial_examples)))

    return adversarial_examples

In [35]:
adversarial_examples = DAG_Attack(model, test_dataset)

# open(..., 'wb') does not create missing parent folders, so make sure the
# output directory exists before writing the pickle.
adv_output_dir = './Adversarial_img'
os.makedirs(adv_output_dir, exist_ok=True)
with open(os.path.join(adv_output_dir, dag + '.pickle'), 'wb') as fp:
    pickle.dump(adversarial_examples, fp)


Condition Reached, no gradient
Condition Reached, no gradient
Condition Reached, no gradient
Condition Reached, no gradient
Condition Reached, no gradient
Condition Reached, no gradient
Condition Reached, no gradient
Condition Reached, no gradient
Condition Reached, no gradient
Condition Reached, no gradient
Condition Reached, no gradient
Condition Reached, no gradient
Condition Reached
Condition Reached, no gradient
Condition Reached, no gradient
Condition Reached, no gradient
Condition Reached, no gradient
Condition Reached, no gradient
Condition Reached, no gradient
Condition Reached, no gradient
Condition Reached, no gradient
Condition Reached, no gradient
Condition Reached, no gradient
Condition Reached, no gradient
Condition Reached, no gradient
Condition Reached, no gradient
Condition Reached, no gradient
Condition Reached, no gradient
Condition Reached, no gradient
Condition Reached, no gradient
Condition Reached, no gradient
Condition Reached, no gradient
Condition Reached, no