In [None]:
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

In [None]:
! pip install torchattacks
! pip install adversarial-robustness-toolbox==1.8.1

In [4]:
import gc
import os
import sys

import matplotlib.pyplot as plt
import numpy as np
import torchattacks

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import torch.optim as optim
import torchvision
from torch.utils.data import DataLoader, Dataset
from torch.hub import load_state_dict_from_url
from torchvision import datasets, models, transforms
from torch.optim.lr_scheduler import ReduceLROnPlateau, ExponentialLR
from torch.autograd import Variable

from PIL import Image
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score
from tqdm.notebook import tqdm

from art.attacks.evasion import FastGradientMethod, ProjectedGradientDescentPyTorch, CarliniLInfMethod, SaliencyMapMethod, DeepFool
from art.estimators.classification import PyTorchClassifier
from art.utils import load_mnist, load_cifar10
from art.utils import load_dataset

from torchattacks import PGD, FGSM, DeepFool, CW

### Models

In [5]:
IMG_CHANNELS = 3

In [6]:
class LambdaLayer(nn.Module):
    def __init__(self, lambd):
      super(LambdaLayer, self).__init__()
      self.lambd = lambd

    def forward(self, x):
      return self.lambd(x)

class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1, option='A'):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != planes:
            if option == 'A':
                """
                For CIFAR10 ResNet paper uses option A.
                """
                self.shortcut = LambdaLayer(lambda x:
                                            F.pad(x[:, :, ::2, ::2], (0, 0, 0, 0, planes//4, planes//4), "constant", 0))
            elif option == 'B':
                self.shortcut = nn.Sequential(
                     nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                     nn.BatchNorm2d(self.expansion * planes)
                )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out

class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        self.in_planes = 16

        self.in_channels = IMG_CHANNELS
        self.conv1 = nn.Conv2d(self.in_channels, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2)
        self.linear = nn.Linear(64, num_classes)

        self.apply(self._weights_init)

    def _weights_init(self, m):
        classname = m.__class__.__name__
        if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d):
          nn.init.kaiming_normal_(m.weight)
    
    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion

        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = F.avg_pool2d(out, out.size()[3])
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        
        return out

class Detector(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10):
        super(Detector, self).__init__()
        self.in_planes = 16

        self.in_channels = IMG_CHANNELS
        self.freeze_till = 4
        self.conv1 = nn.Conv2d(self.in_channels, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2)
        conv2 = nn.Conv2d(32, 96, kernel_size=3, stride=1, padding=0, bias=False)
        bn2 = nn.BatchNorm2d(96)
        conv3 = nn.Conv2d(96, 192, kernel_size=3, stride=1, padding=0, bias=False)
        bn3 = nn.BatchNorm2d(192)
        conv4 = nn.Conv2d(192, 2, kernel_size=1, stride=1, padding=0, bias=False)
        bn4 = nn.BatchNorm2d(2)
        relu = nn.ReLU(inplace=True)
        avgpool = nn.AdaptiveAvgPool2d((1, 1))
        flatten = nn.Flatten()
        linear = nn.Linear(2, 1)
        sigmoid = nn.Sigmoid()

        self.layers = nn.Sequential(
            conv2,
            bn2,
            relu,
            conv3,
            bn3,
            relu,
            conv4,
            bn4,
            relu,
            avgpool,
            flatten,
            linear,
            sigmoid,
        )

        self.apply(self._weights_init)

    def _weights_init(self, m):
        classname = m.__class__.__name__
        if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d):
          nn.init.kaiming_normal_(m.weight)
    
    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion

        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layers(out)
        
        return out

def freeze_base_classifier(model, freeze_till=4):
    for itr, param in enumerate(model.parameters()):
        param.requires_grad = False
        if itr == freeze_till-1:
            break

def copy_base_classifier(detector, resnet_state_dict):
    with torch.no_grad():
        detector.conv1.weight.copy_(resnet_state_dict["conv1.weight"])
        detector.bn1.weight.copy_(resnet_state_dict["bn1.weight"])
        detector.layer1.weight.copy_(resnet_state_dict["layer1.weight"])
        detector.layer2.weight.copy_(resnet_state_dict["layer2.weight"])

    return detector

def resnet20():
    return ResNet(BasicBlock, [3, 3, 3])

def resnet32():
    return ResNet(BasicBlock, [5, 5, 5])

def resnet44():
    return ResNet(BasicBlock, [7, 7, 7])

def resnet56():
    return ResNet(BasicBlock, [9, 9, 9])

def resnet110():
    return ResNet(BasicBlock, [18, 18, 18])

def resnet1202():
    return ResNet(BasicBlock, [200, 200, 200])

In [7]:
def statically_train_detector(
    detector,
    detector_optimizer,
    detector_criterion,
    detector_scheduler,
    unattacked_train_data,
    unattacked_test_data,
    attacked_train_data,
    attacked_test_data,
    device,
    epochs=100,
    batch_size=64,
):
    detector.train()
    freeze_base_classifier(detector, freeze_till=4)
    batches = []

    for epoch in range(epochs):

      avg_loss = 0.0
      for batch_itr in tqdm(range(0, len(unattacked_train_data), batch_size)):
          attacked_input = attacked_train_data[batch_itr:batch_itr+batch_size]
          attacked_labels = torch.ones((attacked_input.shape[0], 1), dtype=torch.float32)
          unattacked_input = unattacked_train_data[batch_itr:batch_itr+batch_size]
          unattacked_labels = torch.zeros((unattacked_input.shape[0], 1), dtype=torch.float32)

          input = torch.cat((unattacked_input, attacked_input), axis=0)
          labels = torch.cat((unattacked_labels, attacked_labels), axis=0)

          assert input.shape[0] == labels.shape[0]
          shuffle_indices = np.arange(input.shape[0])
          np.random.shuffle(shuffle_indices)
          input, labels = input[shuffle_indices].squeeze(0), labels[shuffle_indices].squeeze(0)

          input, labels = input.to(device), labels.to(device)

          output = detector(input.float())

          loss = detector_criterion(output, labels)
          loss.backward()
          avg_loss += loss.item()

          detector_optimizer.step()

          del input
          del labels
          del loss
          torch.cuda.empty_cache()

      val_loss, val_acc, val_roc = statically_test_detector(
          detector,
          detector_criterion,
          unattacked_test_data,
          attacked_test_data,
          device,
      )

      print('Val Loss: {:.4f} | Val Accuracy: {:.4f} | Val ROC: {:.4f}'.format(val_loss, val_acc, val_roc))

      detector_scheduler.step(val_loss)

def statically_test_detector(
    detector,
    criterion,
    unattacked_test_data,
    attacked_test_data,
    device,
    batch_size=64,
):
    detector.eval()
    test_loss = []
    accuracies = []
    roc_scores = []

    for batch_itr in tqdm(range(0, len(unattacked_test_data), batch_size)):
        attacked_input = attacked_test_data[batch_itr:batch_itr+batch_size]
        attacked_labels = torch.ones((attacked_input.shape[0], 1), dtype=torch.float32)
        unattacked_input = unattacked_test_data[batch_itr:batch_itr+batch_size]
        unattacked_labels = torch.zeros((unattacked_input.shape[0], 1), dtype=torch.float32)

        input = torch.cat((unattacked_input, attacked_input), axis=0)
        labels = torch.cat((unattacked_labels, attacked_labels), axis=0)

        assert input.shape[0] == labels.shape[0]
        shuffle_indices = np.arange(input.shape[0])
        np.random.shuffle(shuffle_indices)
        input, labels = input[shuffle_indices].squeeze(0), labels[shuffle_indices].squeeze(0)

        input, labels = input.to(device), labels.to(device)

        with torch.no_grad():
            output = detector(input.float())

        pred_labels = (output > 0.5).float()
        loss = criterion(output, labels)

        accuracy = accuracy_score(labels.flatten().cpu(), pred_labels.flatten().cpu())
        roc_score = roc_auc_score(labels.flatten().cpu(), pred_labels.flatten().cpu())
        accuracies.append(accuracy)
        roc_scores.append(roc_score)
        test_loss.extend([loss.item()]*input.size()[0])
        
        del input
        del labels
        del loss
        torch.cuda.empty_cache()

    detector.train()
    freeze_base_classifier(detector, freeze_till=4)

    return np.mean(test_loss), np.mean(accuracies), np.mean(roc_scores)

In [26]:
def dynamically_train_detector(
    detector,
    detector_optimizer,
    detector_criterion,
    detector_scheduler,
    benign_train_imgs,
    benign_test_imgs,
    device,
    epochs=100,
    batch_size=64,
):
    detector.train()
    best_loss = float('inf')
    for epoch in range(epochs):
      
      avg_loss = 0.0
      for batch_itr in tqdm(range(0, len(benign_train_imgs), batch_size)):

        # get resnet outputs of benign imgs and labels as '0'
        benign_inputs = benign_train_imgs[batch_itr:batch_itr+batch_size].float().to(device)
        benign_train_labels = torch.zeros((benign_inputs.shape[0], 1), dtype=torch.float32)

        # create a attack instance using current state of the subnet
        train_attack = torchattacks.PGD(detector, eps=8/255, alpha=8/(255*40), steps=40)
        adv_inputs = train_attack(benign_inputs, benign_train_labels)
        adv_train_labels = torch.ones((adv_inputs.shape[0], 1), dtype=torch.float32)

        # create a 2x batch with adv and benign images
        input = torch.cat((benign_inputs, adv_inputs), axis=0)
        labels = torch.cat((benign_train_labels, adv_train_labels), axis=0)

        # shuffle the combined inputs and labels
        assert input.shape[0] == labels.shape[0]
        shuffle_indices = np.arange(input.shape[0])
        np.random.shuffle(shuffle_indices)
        input, labels = input[shuffle_indices].squeeze(0), labels[shuffle_indices].squeeze(0)

        # feed to subnet
        input, labels = input.to(device), labels.to(device)
        output = detector(input)

        # calculate loss
        loss = detector_criterion(output, labels)
        loss.backward()
        avg_loss += loss.item()

        # param update
        detector_optimizer.step()

        # cleanup
        del input
        del labels
        del loss
        torch.cuda.empty_cache()

      avg_loss /= len(benign_train_imgs)

      val_loss, val_acc, val_roc = dynamically_test_detector(
          detector,
          detector_criterion,
          benign_test_imgs,
          device,
      )

      print('Train Loss: {:.4f} | Val Loss: {:.4f} | Val Accuracy: {:.4f} | Val ROC: {:.4f}'.format(avg_loss, val_loss, val_acc, val_roc))
      detector_scheduler.step(val_loss)

def dynamically_test_detector(
    detector,
    detector_criterion,
    benign_test_imgs,
    device,
    batch_size=64,
    threshold=0.5,
):
    detector.eval()
    test_loss = []
    accuracies = []
    roc_scores = []

    val_attack = torchattacks.PGD(detector, eps=8/255, alpha=8/(255*40), steps=40)
    for batch_itr in tqdm(range(0, len(benign_test_imgs), batch_size)):

        # get resnet outputs of benign imgs and labels as '0'
        benign_inputs = benign_test_imgs[batch_itr:batch_itr+batch_size].float().to(device)
        benign_test_labels = torch.zeros((benign_inputs.shape[0], 1), dtype=torch.float32)
        
        # create a attack instance using current state of the subnet
        adv_inputs = val_attack(benign_inputs, benign_test_labels)
        adv_test_labels = torch.ones((adv_inputs.shape[0], 1), dtype=torch.float32)

        # create a 2x batch with adv and benign images
        input = torch.cat((benign_inputs, adv_inputs), axis=0)
        labels = torch.cat((benign_test_labels, adv_test_labels), axis=0)

        # shuffle the combined inputs and labels
        assert input.shape[0] == labels.shape[0]
        shuffle_indices = np.arange(input.shape[0])
        np.random.shuffle(shuffle_indices)
        input, labels = input[shuffle_indices].squeeze(0), labels[shuffle_indices].squeeze(0)

        # feed to subnet
        input, labels = input.to(device), labels.to(device)
        with torch.no_grad():
            output = detector(input)

        # get labels based on threshold and calculate loss
        pred_labels = (output > threshold).float()
        loss = detector_criterion(output, labels)

        # calculate metrics
        accuracy = accuracy_score(labels.flatten().cpu(), pred_labels.flatten().cpu())
        accuracies.append(accuracy)
        roc_score = roc_auc_score(labels.flatten().cpu(), pred_labels.flatten().cpu())
        roc_scores.append(roc_score)
        test_loss.extend([loss.item()]*input.size()[0])
        
        # cleanup
        del input
        del labels
        del loss
        torch.cuda.empty_cache()

    detector.train()

    return np.mean(test_loss), np.mean(accuracies), np.mean(roc_scores)

In [9]:
!ls '/content/drive/MyDrive/11785 - Project/data'

benign_cifar.npy	      fgsm_cifar_eps0.01.npy
benign_cifar_train.npy	      fgsm_mnist_eps0.01.npy
benign_mnist.npy	      fgsm_mnist_eps0.1.npy
benign_mnist_train.npy	      fgsm_mnist_eps0.5.npy
cwlinf_cifar_default_art.npy  pgd_cifar_eps4255_alpha4255_steps40.npy
cwlinf_default_art.npy	      pgd_mnist_default_art.npy
cwlinf_eps0.01.npy	      pgd_mnist_eps0.01.npy
cwlinf_eps0.1.npy	      pgd_mnist_eps0.1.npy
cwlinf_mnist_default_art.npy  pgd_mnist_eps0.3.npy
fgsm_cifar_eps0.007.npy       pgd_mnist_eps4255_alpha4255_steps40.npy


In [10]:
!ls '/content/drive/MyDrive/11785 - Project/'

AdversarialDetection.pdf
cifar10_model.pth
cifar-resnet-dynamic-adv-trained-model.pt
classifier-heatmap-2000.png
classifier-heatmap.png
data
detector-classifier-heatmap.png
dynamic-trained-detector.pt
Experiments.gsheet
Flowchart.drawio
imgs
Inference.drawio
InferenceHorizaontal.drawio
InferenceHorizaontal.drawio.png
mnist_model.pth
mnist-resnet-dynamic-adv-trained-model.pt
Presentation.gslides
static-detector.png
static-trained-detector.pt
static-vs-dynamic-detector.png
Traininng.drawio
Traininng.drawio.png
why-detector.png


In [11]:
cuda = torch.cuda.is_available()
device = torch.device("cuda" if cuda else "cpu")

resnet = ResNet(BasicBlock, [5, 5, 5])
# checkpoint = torch.load("/content/drive/MyDrive/11785 - Project/cifar-resnet-dynamic-adv-trained-model.pt")
checkpoint = torch.load("/content/drive/MyDrive/11785 - Project/cifar10_model.pth")
# resnet.load_state_dict(checkpoint["model_state_dict"])
resnet.load_state_dict(checkpoint)
resnet.to(device)

dynamic_detector = Detector(BasicBlock, [5, 5, 5])
unmatched_keys = dynamic_detector.load_state_dict(resnet.state_dict(), strict=False)
freeze_base_classifier(dynamic_detector, freeze_till=4)
# checkpoint = torch.load("/content/drive/MyDrive/11785 - Project/dynamic-trained-detector.pt")
# detector.load_state_dict(checkpoint)
dynamic_detector.to(device)

static_detector = Detector(BasicBlock, [5, 5, 5])
unmatched_keys = static_detector.load_state_dict(resnet.state_dict(), strict=False)
freeze_base_classifier(static_detector, freeze_till=4)
# checkpoint = torch.load("/content/drive/MyDrive/11785 - Project/dynamic-trained-detector.pt")
# detector.load_state_dict(checkpoint)
static_detector.to(device)

print(f"unmatched keys: {unmatched_keys}")

unmatched keys: _IncompatibleKeys(missing_keys=['layers.0.weight', 'layers.1.weight', 'layers.1.bias', 'layers.1.running_mean', 'layers.1.running_var', 'layers.3.weight', 'layers.4.weight', 'layers.4.bias', 'layers.4.running_mean', 'layers.4.running_var', 'layers.6.weight', 'layers.7.weight', 'layers.7.bias', 'layers.7.running_mean', 'layers.7.running_var', 'layers.11.weight', 'layers.11.bias'], unexpected_keys=['layer3.0.conv1.weight', 'layer3.0.bn1.weight', 'layer3.0.bn1.bias', 'layer3.0.bn1.running_mean', 'layer3.0.bn1.running_var', 'layer3.0.bn1.num_batches_tracked', 'layer3.0.conv2.weight', 'layer3.0.bn2.weight', 'layer3.0.bn2.bias', 'layer3.0.bn2.running_mean', 'layer3.0.bn2.running_var', 'layer3.0.bn2.num_batches_tracked', 'layer3.1.conv1.weight', 'layer3.1.bn1.weight', 'layer3.1.bn1.bias', 'layer3.1.bn1.running_mean', 'layer3.1.bn1.running_var', 'layer3.1.bn1.num_batches_tracked', 'layer3.1.conv2.weight', 'layer3.1.bn2.weight', 'layer3.1.bn2.bias', 'layer3.1.bn2.running_mean', 

In [12]:
batch_size = 64

unattacked_data_path = "/content/drive/MyDrive/11785 - Project/data/benign_cifar.npy"
unattacked_data = np.load(unattacked_data_path, allow_pickle=True).astype(float)

# unattacked_data = torch.from_numpy(unattacked_data.transpose(0, 3, 1, 2))
unattacked_data = torch.from_numpy(unattacked_data)
print(f"unattacked data shape: {unattacked_data.shape}")

# train-test split
unattacked_train_data = unattacked_data[:9000]
unattacked_test_data = unattacked_data[9000:]

attacked_data_path = "/content/drive/MyDrive/11785 - Project/data/pgd_cifar_eps4255_alpha4255_steps40.npy"
# attacked_data_path = "/content/drive/MyDrive/11785 - Project/data/cwlinf_cifar_default_art.npy"
# attacked_data_path = "/content/drive/MyDrive/11785 - Project/data/fgsm_cifar_default_art.npy"
# attacked_data_path = "/content/drive/MyDrive/11785 - Project/data/pgd_cifar_default_torchattacks_new.npy"
attacked_data = np.load(attacked_data_path, allow_pickle=True).astype(float)
# attacked_data = torch.from_numpy(attacked_data.transpose(1, 0, 2, 3))
attacked_data = torch.from_numpy(attacked_data)
print(f"attacked data shape: {attacked_data.shape}")

# train-test split
attacked_train_data = attacked_data[:9000]
attacked_test_data = attacked_data[9000:]

unattacked data shape: torch.Size([10000, 3, 32, 32])
attacked data shape: torch.Size([10000, 3, 32, 32])


In [13]:
resnet_optimizer = torch.optim.SGD(resnet.parameters(), lr=0.1, weight_decay=5e-5, momentum=0.9)
resnet_criterion = nn.CrossEntropyLoss()
resnet_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(resnet_optimizer, T_0=10, T_mult=2, eta_min=0.01, last_epoch=-1)

dynamic_detector_optimizer = torch.optim.Adam(dynamic_detector.parameters(), lr=0.0001, betas=(0.99, 0.999))
dynamic_detector_criterion = nn.BCELoss()
dynamic_detector_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(dynamic_detector_optimizer, 'min', factor=0.5, patience=1)

static_detector_optimizer = torch.optim.Adam(static_detector.parameters(), lr=0.0001, betas=(0.99, 0.999))
static_detector_criterion = nn.BCELoss()
static_detector_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(static_detector_optimizer, 'min', factor=0.5, patience=1)

In [None]:
statically_train_detector(
    static_detector,
    static_detector_optimizer,
    static_detector_criterion,
    static_detector_scheduler,
    unattacked_train_data,
    unattacked_test_data,
    attacked_train_data,
    attacked_test_data,
    device,
    epochs=10,
    batch_size=64
)

In [22]:
def dynamically_test_static_detector(
    detector,
    detector_criterion,
    benign_test_imgs,
    device,
    batch_size=64,
    threshold=0.5,
):
    detector.eval()
    test_loss = []
    accuracies = []
    roc_scores = []

    val_attack = torchattacks.PGD(detector, eps=8/255, alpha=8/(255*40), steps=40)
    for batch_itr in tqdm(range(0, len(benign_test_imgs), batch_size)):

        # get resnet outputs of benign imgs and labels as '0'
        benign_inputs = benign_test_imgs[batch_itr:batch_itr+batch_size].float().to(device)
        benign_test_labels = torch.zeros((benign_inputs.shape[0], 1), dtype=torch.float32)
        
        # create a attack instance using current state of the subnet
        adv_inputs = val_attack(benign_inputs, benign_test_labels)
        adv_test_labels = torch.ones((adv_inputs.shape[0], 1), dtype=torch.float32)

        # create a 2x batch with adv and benign images
        input = torch.cat((benign_inputs, adv_inputs), axis=0)
        labels = torch.cat((benign_test_labels, adv_test_labels), axis=0)

        # shuffle the combined inputs and labels
        assert input.shape[0] == labels.shape[0]
        shuffle_indices = np.arange(input.shape[0])
        np.random.shuffle(shuffle_indices)
        input, labels = input[shuffle_indices].squeeze(0), labels[shuffle_indices].squeeze(0)

        # feed to subnet
        input, labels = input.to(device), labels.to(device)
        with torch.no_grad():
            output = detector(input)

        # get labels based on threshold and calculate loss
        pred_labels = (output > threshold).float()
        loss = detector_criterion(output, labels)

        # calculate metrics
        accuracy = accuracy_score(labels.flatten().cpu(), pred_labels.flatten().cpu())
        accuracies.append(accuracy)
        roc_score = roc_auc_score(labels.flatten().cpu(), pred_labels.flatten().cpu())
        roc_scores.append(roc_score)
        test_loss.extend([loss.item()]*input.size()[0])
        
        # cleanup
        del input
        del labels
        del loss
        torch.cuda.empty_cache()

    detector.train()

    return np.mean(test_loss), np.mean(accuracies), np.mean(roc_scores)

In [23]:
dynamically_test_static_detector(
    static_detector,
    static_detector_criterion,
    unattacked_test_data,
    device,
    batch_size=64,
    threshold=0.5,
)

  0%|          | 0/16 [00:00<?, ?it/s]

(0.6694466099739075, 0.534375, 0.534375)

In [29]:
dynamically_train_detector(
    dynamic_detector,
    dynamic_detector_optimizer,
    dynamic_detector_criterion,
    dynamic_detector_scheduler,
    unattacked_train_data,
    unattacked_test_data,
    device,
    epochs=10,
    batch_size=128,
)

  0%|          | 0/71 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

Train Loss: 0.0045 | Val Loss: 0.5729 | Val Accuracy: 0.7628 | Val ROC: 0.7628


  0%|          | 0/71 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

Train Loss: 0.0045 | Val Loss: 0.5710 | Val Accuracy: 0.7581 | Val ROC: 0.7581


  0%|          | 0/71 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

Train Loss: 0.0045 | Val Loss: 0.5705 | Val Accuracy: 0.7683 | Val ROC: 0.7683


  0%|          | 0/71 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

Train Loss: 0.0045 | Val Loss: 0.5710 | Val Accuracy: 0.7697 | Val ROC: 0.7697


  0%|          | 0/71 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

Train Loss: 0.0045 | Val Loss: 0.5697 | Val Accuracy: 0.7656 | Val ROC: 0.7656


  0%|          | 0/71 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

Train Loss: 0.0045 | Val Loss: 0.5703 | Val Accuracy: 0.7670 | Val ROC: 0.7670


  0%|          | 0/71 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

Train Loss: 0.0045 | Val Loss: 0.5715 | Val Accuracy: 0.7666 | Val ROC: 0.7666


  0%|          | 0/71 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

Train Loss: 0.0045 | Val Loss: 0.5729 | Val Accuracy: 0.7694 | Val ROC: 0.7694


  0%|          | 0/71 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

Train Loss: 0.0045 | Val Loss: 0.5703 | Val Accuracy: 0.7677 | Val ROC: 0.7677


  0%|          | 0/71 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

Train Loss: 0.0045 | Val Loss: 0.5722 | Val Accuracy: 0.7610 | Val ROC: 0.7610


### Attacks

In [None]:
class CombinedAttack(object):
    r"""
    Base class for all attacks.
    .. note::
        It automatically set device to the device where given model is.
        It basically changes training mode to eval during attack process.
        To change this, please see `set_training_mode`.
    """
    def __init__(self, name, model, detector_model):
        r"""
        Initializes internal attack state.
        Arguments:
            name (str): name of attack.
            model (torch.nn.Module): model to attack.
        """

        self.attack = name
        self.model = model
        self.detector = detector_model
        self.model_name = str(model).split("(")[0]
        self.device = next(model.parameters()).device

        self._attack_mode = 'default'
        self._targeted = False
        self._return_type = 'float'
        self._supported_mode = ['default']

        self._model_training = False
        self._batchnorm_training = False
        self._dropout_training = False

    def forward(self, *input):
        r"""
        It defines the computation performed at every call.
        Should be overridden by all subclasses.
        """
        raise NotImplementedError

    def get_mode(self):
        r"""
        Get attack mode.
        """
        return self._attack_mode

    def set_mode_default(self):
        r"""
        Set attack mode as default mode.
        """
        self._attack_mode = 'default'
        self._targeted = False
        print("Attack mode is changed to 'default.'")

    def set_mode_targeted_by_function(self, target_map_function=None):
        r"""
        Set attack mode as targeted.
        Arguments:
            target_map_function (function): Label mapping function.
                e.g. lambda images, labels:(labels+1)%10.
                None for using input labels as targeted labels. (Default)
        """
        if "targeted" not in self._supported_mode:
            raise ValueError("Targeted mode is not supported.")

        self._attack_mode = 'targeted'
        self._targeted = True
        self._target_map_function = target_map_function
        print("Attack mode is changed to 'targeted.'")

    def set_mode_targeted_least_likely(self, kth_min=1):
        r"""
        Set attack mode as targeted with least likely labels.
        Arguments:
            kth_min (str): label with the k-th smallest probability used as target labels. (Default: 1)
        """
        if "targeted" not in self._supported_mode:
            raise ValueError("Targeted mode is not supported.")

        self._attack_mode = "targeted(least-likely)"
        self._targeted = True
        self._kth_min = kth_min
        self._target_map_function = self._get_least_likely_label
        print("Attack mode is changed to 'targeted(least-likely).'")

    def set_mode_targeted_random(self, n_classses=None):
        r"""
        Set attack mode as targeted with random labels.
        Arguments:
            num_classses (str): number of classes.
        """
        if "targeted" not in self._supported_mode:
            raise ValueError("Targeted mode is not supported.")

        self._attack_mode = "targeted(random)"
        self._targeted = True
        self._n_classses = n_classses
        self._target_map_function = self._get_random_target_label
        print("Attack mode is changed to 'targeted(random).'")

    def set_return_type(self, type):
        r"""
        Set the return type of adversarial images: `int` or `float`.
        Arguments:
            type (str): 'float' or 'int'. (Default: 'float')
        .. note::
            If 'int' is used for the return type, the file size of 
            adversarial images can be reduced (about 1/4 for CIFAR10).
            However, if the attack originally outputs float adversarial images
            (e.g. using small step-size than 1/255), it might reduce the attack
            success rate of the attack.
        """
        if type == 'float':
            self._return_type = 'float'
        elif type == 'int':
            self._return_type = 'int'
        else:
            raise ValueError(type + " is not a valid type. [Options: float, int]")

    def set_training_mode(self, model_training=False, batchnorm_training=False, dropout_training=False):
        r"""
        Set training mode during attack process.
        Arguments:
            model_training (bool): True for using training mode for the entire model during attack process.
            batchnorm_training (bool): True for using training mode for batchnorms during attack process.
            dropout_training (bool): True for using training mode for dropouts during attack process.
        .. note::
            For RNN-based models, we cannot calculate gradients with eval mode.
            Thus, it should be changed to the training mode during the attack.
        """
        self._model_training = model_training
        self._batchnorm_training = batchnorm_training
        self._dropout_training = dropout_training

    def save(self, data_loader, save_path=None, verbose=True, return_verbose=False):
        r"""
        Save adversarial images as torch.tensor from given torch.utils.data.DataLoader.
        Arguments:
            save_path (str): save_path.
            data_loader (torch.utils.data.DataLoader): data loader.
            verbose (bool): True for displaying detailed information. (Default: True)
            return_verbose (bool): True for returning detailed information. (Default: False)
        """
        if (verbose==False) and (return_verbose==True):
            raise ValueError("Verobse should be True if return_verbose==True.")
            
        if save_path is not None:
            image_list = []
            label_list = []

        correct = 0
        total = 0
        l2_distance = []

        total_batch = len(data_loader)

        given_training = self.model.training

        for step, (images, labels) in enumerate(data_loader):
            start = time.time()
            adv_images = self.__call__(images, labels)

            batch_size = len(images)

            if save_path is not None:
                image_list.append(adv_images.cpu())
                label_list.append(labels.cpu())

            if self._return_type == 'int':
                adv_images = adv_images.float()/255

            if verbose:
                with torch.no_grad():
                    if given_training:
                        self.model.eval()
                    outputs = self.model(adv_images)
                    _, predicted = torch.max(outputs.data, 1)
                    total += labels.size(0)
                    right_idx = (predicted == labels.to(self.device))
                    correct += right_idx.sum()
                    end = time.time()
                    delta = (adv_images - images.to(self.device)).view(batch_size, -1)
                    l2_distance.append(torch.norm(delta[~right_idx], p=2, dim=1))

                    rob_acc = 100 * float(correct) / total
                    l2 = torch.cat(l2_distance).mean().item()
                    progress = (step+1)/total_batch*100
                    elapsed_time = end-start
                    self._save_print(progress, rob_acc, l2, elapsed_time, end='\r')

        # To avoid erasing the printed information.
        if verbose:
            self._save_print(progress, rob_acc, l2, elapsed_time, end='\n')

        if save_path is not None:
            x = torch.cat(image_list, 0)
            y = torch.cat(label_list, 0)
            torch.save((x, y), save_path)
            print('- Save complete!')

        if given_training:
            self.model.train()

        if return_verbose:
            return rob_acc, l2, elapsed_time

    def _save_print(self, progress, rob_acc, l2, elapsed_time, end):
        print('- Save progress: %2.2f %% / Robust accuracy: %2.2f %% / L2: %1.5f (%2.3f it/s) \t' \
              % (progress, rob_acc, l2, elapsed_time), end=end)

    def _get_target_label(self, images, labels=None):
        r"""
        Function for changing the attack mode.
        Return input labels.
        """
        if self._target_map_function:
            return self._target_map_function(images, labels)
        raise ValueError('Please define target_map_function.')

    def _get_least_likely_label(self, images, labels=None):
        r"""
        Function for changing the attack mode.
        Return least likely labels.
        """
        outputs = self.model(images)
        if self._kth_min < 0:
            pos = outputs.shape[1] + self._kth_min + 1
        else:
            pos = self._kth_min
        _, target_labels = torch.kthvalue(outputs.data, pos)
        target_labels = target_labels.detach()
        return target_labels.long().to(self.device)

    def _get_random_target_label(self, images, labels=None):
        if self._n_classses is None:
            outputs = self.model(images)
            if labels is None:
                _, labels = torch.max(outputs, dim=1)
            n_classses = outputs.shape[-1]
        else:
            n_classses = self._n_classses

        target_labels = torch.zeros_like(labels)
        for counter in range(labels.shape[0]):
            l = list(range(n_classses))
            l.remove(labels[counter])
            t = self.random_int(0, len(l))
            target_labels[counter] = l[t]

        return target_labels.long().to(self.device)
    
    def random_int(self, low=0, high=1, shape=[1]):
        t = low + (high - low) * torch.rand(shape).to(self.device)
        return t.long()

    def _to_uint(self, images):
        r"""
        Function for changing the return type.
        Return images as int.
        """
        return (images*255).type(torch.uint8)

    def __str__(self):
        info = self.__dict__.copy()

        del_keys = ['model', 'attack']

        for key in info.keys():
            if key[0] == "_":
                del_keys.append(key)

        for key in del_keys:
            del info[key]

        info['attack_mode'] = self._attack_mode
        info['return_type'] = self._return_type

        return self.attack + "(" + ', '.join('{}={}'.format(key, val) for key, val in info.items()) + ")"

    def __call__(self, *input, **kwargs):
        given_training = self.model.training

        if self._model_training:
            self.model.train()
            for _, m in self.model.named_modules():
                if not self._batchnorm_training:
                    if 'BatchNorm' in m.__class__.__name__:
                        m = m.eval()
                if not self._dropout_training:
                    if 'Dropout' in m.__class__.__name__:
                        m = m.eval()

        else:
            self.model.eval()

        images = self.forward(*input, **kwargs)

        if given_training:
            self.model.train()

        if self._return_type == 'int':
            images = self._to_uint(images)

        return images

class CombinedPGD(CombinedAttack):
    r"""
    PGD in the paper 'Towards Deep Learning Models Resistant to Adversarial Attacks'
    [https://arxiv.org/abs/1706.06083]

    Distance Measure : Linf

    Arguments:
        model (nn.Module): model to attack.
        eps (float): maximum perturbation. (Default: 0.3)
        alpha (float): step size. (Default: 2/255)
        steps (int): number of steps. (Default: 40)
        random_start (bool): using random initialization of delta. (Default: True)

    Shape:
        - images: :math:`(N, C, H, W)` where `N = number of batches`, `C = number of channels`,        `H = height` and `W = width`. It must have a range [0, 1].
        - labels: :math:`(N)` where each value :math:`y_i` is :math:`0 \leq y_i \leq` `number of labels`.
        - output: :math:`(N, C, H, W)`.

    Examples::
        >>> attack = torchattacks.PGD(model, eps=8/255, alpha=1/255, steps=40, random_start=True)
        >>> adv_images = attack(images, labels)

    """
    def __init__(
        self,
        base_model,
        detector_model,
        eps=0.3,
        alpha=2/255,
        steps=40,
        random_start=True,
    ):
        super().__init__("PGD", base_model, detector_model)
        self.eps = eps
        self.alpha = alpha
        self.steps = steps
        self.random_start = random_start
        self._supported_mode = ['default', 'targeted']

    def forward(self, images, labels):
        r"""
        Overridden.
        """
        images = images.clone().detach().to(self.device)
        labels = labels.clone().detach().to(self.device)

        if self._targeted:
            target_labels = self._get_target_label(images, labels)

        loss1 = nn.CrossEntropyLoss()
        loss2 = nn.BCEWithLogitsLoss()

        adv_images = images.clone().detach()

        if self.random_start:
            # Starting at a uniformly random point
            adv_images = adv_images + torch.empty_like(adv_images).uniform_(-self.eps, self.eps)
            adv_images = torch.clamp(adv_images, min=0, max=1).detach()

        for _ in range(self.steps):
            adv_images.requires_grad = True
            outputs1 = self.model(adv_images)
            outputs2 = self.detector(adv_images)

            resnet_loss = loss1(outputs1, labels)
            subnet_loss = loss2(outputs2, torch.ones((labels.shape[0], 1)))

            # Calculate loss
            if self._targeted:
                cost = -resnet_loss + -subnet_loss
            else:
                cost = resnet_loss + subnet_loss

            print(f"ResNet Loss: {resnet_loss} | SubNet Loss: {subnet_loss} | Cost: {cost}")

            # Update adversarial images
            grad = torch.autograd.grad(cost, adv_images,
                                       retain_graph=False, create_graph=False)[0]

            adv_images = adv_images.detach() + self.alpha*grad.sign()
            delta = torch.clamp(adv_images - images, min=-self.eps, max=self.eps)
            adv_images = torch.clamp(images + delta, min=0, max=1).detach()

        return adv_images

class CombinedPGDL2(CombinedAttack):
    r"""
    PGD in the paper 'Towards Deep Learning Models Resistant to Adversarial Attacks'
    [https://arxiv.org/abs/1706.06083]
    Distance Measure : L2
    Arguments:
        model (nn.Module): model to attack.
        eps (float): maximum perturbation. (Default: 1.0)
        alpha (float): step size. (Default: 0.2)
        steps (int): number of steps. (Default: 40)
        random_start (bool): using random initialization of delta. (Default: True)
    Shape:
        - images: :math:`(N, C, H, W)` where `N = number of batches`, `C = number of channels`,        `H = height` and `W = width`. It must have a range [0, 1].
        - labels: :math:`(N)` where each value :math:`y_i` is :math:`0 \leq y_i \leq` `number of labels`.
        - output: :math:`(N, C, H, W)`.
    Examples::
        >>> attack = torchattacks.PGDL2(model, eps=1.0, alpha=0.2, steps=40, random_start=True)
        >>> adv_images = attack(images, labels)
    """
    def __init__(
        self,
        base_model,
        detector_model,
        eps=0.3,
        alpha=2/255,
        steps=40,
        random_start=True,
    ):
        super().__init__("PGDL2", base_model, detector_model)
        self.eps = eps
        self.alpha = alpha
        self.steps = steps
        self.random_start = random_start
        self.eps_for_division = eps_for_division
        self._supported_mode = ['default', 'targeted']

    def forward(self, images, labels):
        r"""
        Overridden.
        """
        images = images.clone().detach().to(self.device)
        labels = labels.clone().detach().to(self.device)

        if self._targeted:
            target_labels = self._get_target_label(images, labels)

        loss = nn.CrossEntropyLoss()

        adv_images = images.clone().detach()
        batch_size = len(images)

        if self.random_start:
            # Starting at a uniformly random point
            delta = torch.empty_like(adv_images).normal_()
            d_flat = delta.view(adv_images.size(0),-1)
            n = d_flat.norm(p=2,dim=1).view(adv_images.size(0),1,1,1)
            r = torch.zeros_like(n).uniform_(0, 1)
            delta *= r/n*self.eps

        for _ in range(self.steps):
            adv_images.requires_grad = True
            outputs1 = self.model(adv_images)
            outputs2 = self.detector(adv_images)

            resnet_loss = loss1(outputs1, labels)
            subnet_loss = loss2(outputs2, torch.ones((labels.shape[0], 1)))

            # Calculate loss
            if self._targeted:
                cost = -resnet_loss + -subnet_loss
            else:
                cost = resnet_loss + subnet_loss

            # Update adversarial images
            grad = torch.autograd.grad(cost, adv_images,
                                       retain_graph=False, create_graph=False)[0]
            grad_norms = torch.norm(grad.view(batch_size, -1), p=2, dim=1) + self.eps_for_division
            grad = grad / grad_norms.view(batch_size, 1, 1, 1)
            adv_images = adv_images.detach() + self.alpha * grad

            delta = adv_images - images
            delta_norms = torch.norm(delta.view(batch_size, -1), p=2, dim=1)
            factor = self.eps / delta_norms
            factor = torch.min(factor, torch.ones_like(delta_norms))
            delta = delta * factor.view(-1, 1, 1, 1)

            adv_images = torch.clamp(images + delta, min=0, max=1).detach()

        return adv_images

class CombinedCW(CombinedAttack):
    r"""
    CW in the paper 'Towards Evaluating the Robustness of Neural Networks'
    [https://arxiv.org/abs/1608.04644]
    Distance Measure : L2
    Arguments:
        model (nn.Module): model to attack.
        c (float): c in the paper. parameter for box-constraint. (Default: 1e-4)    
            :math:`minimize \Vert\frac{1}{2}(tanh(w)+1)-x\Vert^2_2+c\cdot f(\frac{1}{2}(tanh(w)+1))`
        kappa (float): kappa (also written as 'confidence') in the paper. (Default: 0)
            :math:`f(x')=max(max\{Z(x')_i:i\neq t\} -Z(x')_t, - \kappa)`
        steps (int): number of steps. (Default: 1000)
        lr (float): learning rate of the Adam optimizer. (Default: 0.01)
    .. warning:: With default c, you can't easily get adversarial images. Set higher c like 1.
    Shape:
        - images: :math:`(N, C, H, W)` where `N = number of batches`, `C = number of channels`,        `H = height` and `W = width`. It must have a range [0, 1].
        - labels: :math:`(N)` where each value :math:`y_i` is :math:`0 \leq y_i \leq` `number of labels`.
        - output: :math:`(N, C, H, W)`.
    Examples::
        >>> attack = torchattacks.CW(model, c=1e-4, kappa=0, steps=1000, lr=0.01)
        >>> adv_images = attack(images, labels)
    .. note:: Binary search for c is NOT IMPLEMENTED methods in the paper due to time consuming.
    """
    def __init__(
        self,
        base_model,
        detector_model,
        c=1e-4,
        kappa=0,
        steps=1000,
        lr=0.01,
    ):
        super().__init__("CW", base_model, detector_model)
        self.c = c
        self.kappa = kappa
        self.steps = steps
        self.lr = lr
        self._supported_mode = ['default', 'targeted']

    def forward(self, images, labels):
        r"""
        Overridden.
        """
        images = images.clone().detach().to(self.device)
        labels = labels.clone().detach().to(self.device)

        if self._targeted:
            target_labels = self._get_target_label(images, labels)

        # w = torch.zeros_like(images).detach() # Requires 2x times
        w = self.inverse_tanh_space(images).detach()
        w.requires_grad = True

        best_adv_images = images.clone().detach()
        best_L2 = 1e10*torch.ones((len(images))).to(self.device)
        prev_cost = 1e10
        dim = len(images.shape)

        MSELoss = nn.MSELoss(reduction='none')
        Flatten = nn.Flatten()

        optimizer = optim.Adam([w], lr=self.lr)

        for step in range(self.steps):
            # Get adversarial images
            adv_images = self.tanh_space(w)

            # Calculate loss
            current_L2 = MSELoss(Flatten(adv_images),
                                 Flatten(images)).sum(dim=1)
            L2_loss = current_L2.sum()

            outputs1 = self.model(adv_images.float())
            outputs2 = self.detector(adv_images.float())
            if self._targeted:
                f_loss = self.f(outputs1, target_labels).sum()
            else:
                f_loss = self.f(outputs1, labels).sum() #+ self.f(outputs2, torch.ones((labels.shape[0], 1))).sum()

            cost = L2_loss + self.c * f_loss

            optimizer.zero_grad()
            cost.backward()
            optimizer.step()

            # Update adversarial images
            _, pre = torch.max(outputs.detach(), 1)
            correct = (pre == labels).float()

            # filter out images that get either correct predictions or non-decreasing loss, 
            # i.e., only images that are both misclassified and loss-decreasing are left 
            mask = (1-correct)*(best_L2 > current_L2.detach())
            best_L2 = mask*current_L2.detach() + (1-mask)*best_L2

            mask = mask.view([-1]+[1]*(dim-1))
            best_adv_images = mask*adv_images.detach() + (1-mask)*best_adv_images

            # Early stop when loss does not converge.
            if step % (self.steps//10) == 0:
                if cost.item() > prev_cost:
                    return best_adv_images
                prev_cost = cost.item()

        return best_adv_images

    def tanh_space(self, x):
        return 1/2*(torch.tanh(x) + 1)

    def inverse_tanh_space(self, x):
        # torch.atanh is only for torch >= 1.7.0
        return self.atanh(x*2-1)

    def atanh(self, x):
        return 0.5*torch.log((1+x)/(1-x))

    # f-function in the paper
    def f(self, outputs, labels):
        one_hot_labels = torch.eye(len(outputs[0]))[labels].to(self.device)

        i, _ = torch.max((1-one_hot_labels)*outputs, dim=1) # get the second largest logit
        j = torch.masked_select(outputs, one_hot_labels.bool()) # get the largest logit

        if self._targeted:
            return torch.clamp((i-j), min=-self.kappa)
        else:
            return torch.clamp((j-i), min=-self.kappa)

In [None]:
class SingleAttack(object):
    r"""
    Base class for all attacks.
    .. note::
        It automatically set device to the device where given model is.
        It basically changes training mode to eval during attack process.
        To change this, please see `set_training_mode`.
    """
    def __init__(self, name, model):
        r"""
        Initializes internal attack state.
        Arguments:
            name (str): name of attack.
            model (torch.nn.Module): model to attack.
        """

        self.attack = name
        self.model = model
        self.model_name = str(model).split("(")[0]
        self.device = next(model.parameters()).device

        self._attack_mode = 'default'
        self._targeted = False
        self._return_type = 'float'
        self._supported_mode = ['default']

        self._model_training = False
        self._batchnorm_training = False
        self._dropout_training = False

    def forward(self, *input):
        r"""
        It defines the computation performed at every call.
        Should be overridden by all subclasses.
        """
        raise NotImplementedError

    def get_mode(self):
        r"""
        Get attack mode.
        """
        return self._attack_mode

    def set_mode_default(self):
        r"""
        Set attack mode as default mode.
        """
        self._attack_mode = 'default'
        self._targeted = False
        print("Attack mode is changed to 'default.'")

    def set_mode_targeted_by_function(self, target_map_function=None):
        r"""
        Set attack mode as targeted.
        Arguments:
            target_map_function (function): Label mapping function.
                e.g. lambda images, labels:(labels+1)%10.
                None for using input labels as targeted labels. (Default)
        """
        if "targeted" not in self._supported_mode:
            raise ValueError("Targeted mode is not supported.")

        self._attack_mode = 'targeted'
        self._targeted = True
        self._target_map_function = target_map_function
        print("Attack mode is changed to 'targeted.'")

    def set_mode_targeted_least_likely(self, kth_min=1):
        r"""
        Set attack mode as targeted with least likely labels.
        Arguments:
            kth_min (str): label with the k-th smallest probability used as target labels. (Default: 1)
        """
        if "targeted" not in self._supported_mode:
            raise ValueError("Targeted mode is not supported.")

        self._attack_mode = "targeted(least-likely)"
        self._targeted = True
        self._kth_min = kth_min
        self._target_map_function = self._get_least_likely_label
        print("Attack mode is changed to 'targeted(least-likely).'")

    def set_mode_targeted_random(self, n_classses=None):
        r"""
        Set attack mode as targeted with random labels.
        Arguments:
            num_classses (str): number of classes.
        """
        if "targeted" not in self._supported_mode:
            raise ValueError("Targeted mode is not supported.")

        self._attack_mode = "targeted(random)"
        self._targeted = True
        self._n_classses = n_classses
        self._target_map_function = self._get_random_target_label
        print("Attack mode is changed to 'targeted(random).'")

    def set_return_type(self, type):
        r"""
        Set the return type of adversarial images: `int` or `float`.
        Arguments:
            type (str): 'float' or 'int'. (Default: 'float')
        .. note::
            If 'int' is used for the return type, the file size of 
            adversarial images can be reduced (about 1/4 for CIFAR10).
            However, if the attack originally outputs float adversarial images
            (e.g. using small step-size than 1/255), it might reduce the attack
            success rate of the attack.
        """
        if type == 'float':
            self._return_type = 'float'
        elif type == 'int':
            self._return_type = 'int'
        else:
            raise ValueError(type + " is not a valid type. [Options: float, int]")

    def set_training_mode(self, model_training=False, batchnorm_training=False, dropout_training=False):
        r"""
        Set training mode during attack process.
        Arguments:
            model_training (bool): True for using training mode for the entire model during attack process.
            batchnorm_training (bool): True for using training mode for batchnorms during attack process.
            dropout_training (bool): True for using training mode for dropouts during attack process.
        .. note::
            For RNN-based models, we cannot calculate gradients with eval mode.
            Thus, it should be changed to the training mode during the attack.
        """
        self._model_training = model_training
        self._batchnorm_training = batchnorm_training
        self._dropout_training = dropout_training

    def save(self, data_loader, save_path=None, verbose=True, return_verbose=False):
        r"""
        Save adversarial images as torch.tensor from given torch.utils.data.DataLoader.
        Arguments:
            save_path (str): save_path.
            data_loader (torch.utils.data.DataLoader): data loader.
            verbose (bool): True for displaying detailed information. (Default: True)
            return_verbose (bool): True for returning detailed information. (Default: False)
        """
        if (verbose==False) and (return_verbose==True):
            raise ValueError("Verobse should be True if return_verbose==True.")
            
        if save_path is not None:
            image_list = []
            label_list = []

        correct = 0
        total = 0
        l2_distance = []

        total_batch = len(data_loader)

        given_training = self.model.training

        for step, (images, labels) in enumerate(data_loader):
            start = time.time()
            adv_images = self.__call__(images, labels)

            batch_size = len(images)

            if save_path is not None:
                image_list.append(adv_images.cpu())
                label_list.append(labels.cpu())

            if self._return_type == 'int':
                adv_images = adv_images.float()/255

            if verbose:
                with torch.no_grad():
                    if given_training:
                        self.model.eval()
                    outputs = self.model(adv_images)
                    _, predicted = torch.max(outputs.data, 1)
                    total += labels.size(0)
                    right_idx = (predicted == labels.to(self.device))
                    correct += right_idx.sum()
                    end = time.time()
                    delta = (adv_images - images.to(self.device)).view(batch_size, -1)
                    l2_distance.append(torch.norm(delta[~right_idx], p=2, dim=1))

                    rob_acc = 100 * float(correct) / total
                    l2 = torch.cat(l2_distance).mean().item()
                    progress = (step+1)/total_batch*100
                    elapsed_time = end-start
                    self._save_print(progress, rob_acc, l2, elapsed_time, end='\r')

        # To avoid erasing the printed information.
        if verbose:
            self._save_print(progress, rob_acc, l2, elapsed_time, end='\n')

        if save_path is not None:
            x = torch.cat(image_list, 0)
            y = torch.cat(label_list, 0)
            torch.save((x, y), save_path)
            print('- Save complete!')

        if given_training:
            self.model.train()

        if return_verbose:
            return rob_acc, l2, elapsed_time

    def _save_print(self, progress, rob_acc, l2, elapsed_time, end):
        print('- Save progress: %2.2f %% / Robust accuracy: %2.2f %% / L2: %1.5f (%2.3f it/s) \t' \
              % (progress, rob_acc, l2, elapsed_time), end=end)

    def _get_target_label(self, images, labels=None):
        r"""
        Function for changing the attack mode.
        Return input labels.
        """
        if self._target_map_function:
            return self._target_map_function(images, labels)
        raise ValueError('Please define target_map_function.')

    def _get_least_likely_label(self, images, labels=None):
        r"""
        Function for changing the attack mode.
        Return least likely labels.
        """
        outputs = self.model(images)
        if self._kth_min < 0:
            pos = outputs.shape[1] + self._kth_min + 1
        else:
            pos = self._kth_min
        _, target_labels = torch.kthvalue(outputs.data, pos)
        target_labels = target_labels.detach()
        return target_labels.long().to(self.device)

    def _get_random_target_label(self, images, labels=None):
        if self._n_classses is None:
            outputs = self.model(images)
            if labels is None:
                _, labels = torch.max(outputs, dim=1)
            n_classses = outputs.shape[-1]
        else:
            n_classses = self._n_classses

        target_labels = torch.zeros_like(labels)
        for counter in range(labels.shape[0]):
            l = list(range(n_classses))
            l.remove(labels[counter])
            t = self.random_int(0, len(l))
            target_labels[counter] = l[t]

        return target_labels.long().to(self.device)
    
    def random_int(self, low=0, high=1, shape=[1]):
        t = low + (high - low) * torch.rand(shape).to(self.device)
        return t.long()

    def _to_uint(self, images):
        r"""
        Function for changing the return type.
        Return images as int.
        """
        return (images*255).type(torch.uint8)

    def __str__(self):
        info = self.__dict__.copy()

        del_keys = ['model', 'attack']

        for key in info.keys():
            if key[0] == "_":
                del_keys.append(key)

        for key in del_keys:
            del info[key]

        info['attack_mode'] = self._attack_mode
        info['return_type'] = self._return_type

        return self.attack + "(" + ', '.join('{}={}'.format(key, val) for key, val in info.items()) + ")"

    def __call__(self, *input, **kwargs):
        given_training = self.model.training

        if self._model_training:
            self.model.train()
            for _, m in self.model.named_modules():
                if not self._batchnorm_training:
                    if 'BatchNorm' in m.__class__.__name__:
                        m = m.eval()
                if not self._dropout_training:
                    if 'Dropout' in m.__class__.__name__:
                        m = m.eval()

        else:
            self.model.eval()

        images = self.forward(*input, **kwargs)

        if given_training:
            self.model.train()

        if self._return_type == 'int':
            images = self._to_uint(images)

        return images

class SinglePGD(SingleAttack):
    r"""
    PGD in the paper 'Towards Deep Learning Models Resistant to Adversarial Attacks'
    [https://arxiv.org/abs/1706.06083]

    Distance Measure : Linf

    Arguments:
        model (nn.Module): model to attack.
        eps (float): maximum perturbation. (Default: 0.3)
        alpha (float): step size. (Default: 2/255)
        steps (int): number of steps. (Default: 40)
        random_start (bool): using random initialization of delta. (Default: True)

    Shape:
        - images: :math:`(N, C, H, W)` where `N = number of batches`, `C = number of channels`,        `H = height` and `W = width`. It must have a range [0, 1].
        - labels: :math:`(N)` where each value :math:`y_i` is :math:`0 \leq y_i \leq` `number of labels`.
        - output: :math:`(N, C, H, W)`.

    Examples::
        >>> attack = torchattacks.PGD(model, eps=8/255, alpha=1/255, steps=40, random_start=True)
        >>> adv_images = attack(images, labels)

    """
    def __init__(
        self,
        model,
        eps=0.3,
        alpha=2/255,
        steps=40,
        random_start=True,
    ):
        super().__init__("PGD", model)
        self.eps = eps
        self.alpha = alpha
        self.steps = steps
        self.random_start = random_start
        self._supported_mode = ['default', 'targeted']

    def forward(self, images, labels):
        r"""
        Overridden.
        """
        images = images.clone().detach().to(self.device)
        labels = labels.clone().detach().to(self.device)

        if self._targeted:
            target_labels = self._get_target_label(images, labels)

        loss1 = nn.CrossEntropyLoss()
        loss2 = nn.BCEWithLogitsLoss()

        adv_images = images.clone().detach()

        if self.random_start:
            # Starting at a uniformly random point
            adv_images = adv_images + torch.empty_like(adv_images).uniform_(-self.eps, self.eps)
            adv_images = torch.clamp(adv_images, min=0, max=1).detach()

        for _ in range(self.steps):
            adv_images.requires_grad = True
            outputs = self.model(adv_images)

            # Calculate loss
            if self._targeted:
                cost = -loss(outputs, target_labels)
            else:
                cost = loss(outputs, labels)

            print(f"Cost: {cost}")

            # Update adversarial images
            grad = torch.autograd.grad(cost, adv_images,
                                       retain_graph=False, create_graph=False)[0]

            adv_images = adv_images.detach() + self.alpha*grad.sign()
            delta = torch.clamp(adv_images - images, min=-self.eps, max=self.eps)
            adv_images = torch.clamp(images + delta, min=0, max=1).detach()

        return adv_images

### Combined Attack Test

In [None]:
benign = load_cifar10()
(x_train, y_train), (x_test, y_test), min_pixel_value, max_pixel_value = load_cifar10()

In [None]:
combined_pgd_attack = CombinedPGD(resnet_model.to('cpu'), subnet_model.to('cpu'), eps=4/255, alpha=4/(255*40), steps=40)
# combined_attack = SinglePGD(subnet_model.to('cpu'), eps=4/255, alpha=4/(255*40), steps=100)

In [None]:
pgd_adv_inputs = combined_pgd_attack(
    torch.from_numpy(x_test).permute(0, 3, 1, 2),
    torch.from_numpy(y_test)
)

In [None]:
def detector_test_with_attack(
    detector,
    attacked_test_imgs,
    device,
    batch_size=64,
    threshold=0.5,
):
    detector.eval()
    test_loss = []
    accuracies = []
    roc_scores = []

    for batch_itr in tqdm(range(0, len(attacked_test_imgs), batch_size)):
        inputs = attacked_test_imgs[batch_itr:batch_itr+batch_size].float().to(device)
        labels = torch.ones((inputs.shape[0], 1), dtype=torch.float32)

        # shuffle the combined inputs and labels
        assert input.shape[0] == labels.shape[0]
        shuffle_indices = np.arange(input.shape[0])
        np.random.shuffle(shuffle_indices)
        input, labels = input[shuffle_indices].squeeze(0), labels[shuffle_indices].squeeze(0)

        # feed to subnet
        input, labels = input.to(device), labels.to(device)
        with torch.no_grad():
            output = detector(input)

        # get labels based on threshold and calculate loss
        pred_labels = (output > threshold).float()
        loss = detector_criterion(output, labels)

        # calculate metrics
        accuracy = accuracy_score(labels.flatten().cpu(), pred_labels.flatten().cpu())
        accuracies.append(accuracy)
        roc_score = roc_auc_score(labels.flatten().cpu(), pred_labels.flatten().cpu())
        roc_scores.append(roc_score)
        test_loss.extend([loss.item()]*input.size()[0])
        
        # cleanup
        del input
        del labels
        del loss
        torch.cuda.empty_cache()

    detector.train()

    return np.mean(test_loss), np.mean(accuracies), np.mean(roc_scores)

In [None]:
detector_test_with_attack(
    detector,
    cw_adv_inputs,
    device,
    batch_size=64,
    threshold=0.5,
)

In [None]:
def clean_dataset(
    detector,
    test_data,
    device,
    batch_size=64,
):
    detector.eval()
    clean_idxs = []

    print("shape of data to be cleaned: "test_data.shape)
    for batch_itr in range(0, len(test_data), batch_size):
        input = test_data[batch_itr:batch_itr+batch_size]
        input = input.to(device)

        with torch.no_grad():
            output = detector(test_data)

        pred_labels = (output > 0.5).float()
        pred_labels = pred_labels.flatten().cpu()
        clean_idxs.extend([
            batch_itr+idx for idx in range(len(pred_labels)) \
              if pred_labels[idx] == 0
        ])
        
        del input
        torch.cuda.empty_cache()

    return clean_idxs