In [None]:
# Mount Google Drive; the data and checkpoint paths used below live under this mount point.
from google.colab import drive

MOUNT_POINT = '/content/drive'
drive.mount(MOUNT_POINT, force_remount=True)

In [None]:
! pip install torchattacks

In [73]:
import gc
import os
import sys

import matplotlib.pyplot as plt
import numpy as np
import torchattacks

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import torch.optim as optim
import torchvision
from torch.utils.data import DataLoader, Dataset
from torch.hub import load_state_dict_from_url
from torchvision import datasets, models, transforms
from torch.optim.lr_scheduler import ReduceLROnPlateau, ExponentialLR
from torch.autograd import Variable

from PIL import Image
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score
from tqdm.notebook import tqdm

In [74]:
# L-inf PGD radius. NOTE(review): not referenced anywhere else in the visible
# notebook; the train/test loops pass their own eps values (4/255 and 2/255).
PGD_ATTACK_EPS = 2/255

# Seed numpy and torch so that data shuffling, weight initialisation and other
# torch-based randomness are reproducible under Restart & Run All.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED);

### Base Classifier - ResNet

In [96]:
def _weights_init(m):
    classname = m.__class__.__name__
    if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d):
        init.kaiming_normal_(m.weight)

class LambdaLayer(nn.Module):
    """Wrap an arbitrary callable as an ``nn.Module``.

    The callable is stored unchanged on ``self.lambd`` and applied to the input
    in ``forward``; the layer itself has no parameters.
    NOTE(review): a module holding a lambda is likely not picklable, so prefer
    saving models that contain it via ``state_dict``.
    """

    def __init__(self, lambd):
        super(LambdaLayer, self).__init__()
        self.lambd = lambd

    def forward(self, x):
        fn = self.lambd
        return fn(x)

class BasicBlock(nn.Module):
    """Residual block of two 3x3 conv + BatchNorm layers with a shortcut connection.

    Args:
        in_planes: number of input channels.
        planes: number of output channels.
        stride: stride of the first convolution (2 halves the spatial size).
        option: shortcut used when the output shape differs from the input:
            'A' - parameter-free: subsample every other pixel and zero-pad the
                  channel dimension (the variant the ResNet paper uses for CIFAR-10);
            'B' - 1x1 convolution + BatchNorm projection.

    Raises:
        ValueError: if a projection shortcut is needed and ``option`` is not 'A' or 'B'.
    """
    expansion = 1

    def __init__(self, in_planes, planes, stride=1, option='A'):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        # Identity shortcut unless the spatial size or channel count changes.
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != planes:
            if option == 'A':
                # Halve H and W with [::2, ::2] and pad planes//4 zero channels on each
                # side. This only matches the residual's shape when stride == 2 and
                # planes == 2 * in_planes.
                self.shortcut = LambdaLayer(lambda x:
                                            F.pad(x[:, :, ::2, ::2], (0, 0, 0, 0, planes//4, planes//4), "constant", 0))
            elif option == 'B':
                self.shortcut = nn.Sequential(
                     nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                     nn.BatchNorm2d(self.expansion * planes)
                )
            else:
                # Previously an unknown option silently kept the identity shortcut,
                # which then failed later with an opaque shape mismatch in forward().
                raise ValueError(f"Unknown shortcut option {option!r}; expected 'A' or 'B'")

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class ResNet(nn.Module):
    """CIFAR-style ResNet: a 3x3 stem followed by three stages of 16/32/64 channels.

    Args:
        block: residual block class; must expose ``expansion`` and be constructible
            as ``block(in_planes, planes, stride)``.
        num_blocks: number of blocks in each of the three stages.
        num_classes: output size of the final linear layer.
        in_channels: channels of the input image (3 for CIFAR, 1 for MNIST).
    """

    def __init__(self, block, num_blocks, num_classes=10, in_channels=3):
        super(ResNet, self).__init__()
        self.in_planes = 16

        self.in_channels = in_channels
        self.conv1 = nn.Conv2d(self.in_channels, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2)
        # Width of the pooled feature vector depends on the block's expansion factor.
        self.linear = nn.Linear(64 * block.expansion, num_classes)

        self.apply(self._weights_init)

    def _weights_init(self, m):
        """Kaiming-normal init for Linear/Conv2d weights; other modules are untouched."""
        if isinstance(m, (nn.Linear, nn.Conv2d)):
            nn.init.kaiming_normal_(m.weight)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build one stage; only its first block uses ``stride``, the rest use 1."""
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            # Next block's input width is this block's output width.
            self.in_planes = planes * block.expansion

        return nn.Sequential(*layers)

    def forward(self, x, return_interm_layer=None):
        """Run the network, optionally stopping early.

        Args:
            x: input image batch.
            return_interm_layer: 1, 2 or 3 returns the output feature map of that
                stage; -1 returns the globally pooled feature vector; any other
                value (default None) returns the class logits.
        """
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        if return_interm_layer == 1:
            return out
        out = self.layer2(out)
        if return_interm_layer == 2:
            return out
        out = self.layer3(out)
        if return_interm_layer == 3:
            return out
        # Global average pool; adaptive pooling also handles non-square feature maps,
        # which a square kernel of size W would mis-pool.
        out = F.adaptive_avg_pool2d(out, 1).flatten(1)
        if return_interm_layer == -1:
            return out
        return self.linear(out)

def _resnet(blocks_per_stage, **kwargs):
    """Build a BasicBlock ResNet with ``blocks_per_stage`` blocks in each of its 3 stages.

    Depth is 6 * blocks_per_stage + 2. Extra keyword arguments (e.g. ``num_classes``)
    are forwarded to ``ResNet``.
    """
    return ResNet(BasicBlock, [blocks_per_stage] * 3, **kwargs)


def resnet20(**kwargs):
    """ResNet-20 (3 blocks per stage)."""
    return _resnet(3, **kwargs)


def resnet32(**kwargs):
    """ResNet-32 (5 blocks per stage)."""
    return _resnet(5, **kwargs)


def resnet44(**kwargs):
    """ResNet-44 (7 blocks per stage)."""
    return _resnet(7, **kwargs)


def resnet56(**kwargs):
    """ResNet-56 (9 blocks per stage)."""
    return _resnet(9, **kwargs)


def resnet110(**kwargs):
    """ResNet-110 (18 blocks per stage)."""
    return _resnet(18, **kwargs)


def resnet1202(**kwargs):
    """ResNet-1202 (200 blocks per stage)."""
    return _resnet(200, **kwargs)

### SubNetwork

In [97]:
class SubNet(nn.Module):
    """Small convolutional detector head applied to intermediate feature maps.

    Three unpadded 3x3 conv+BN stages (96, 192, 192 channels) are followed by a 1x1
    conv to 2 channels, global average pooling, ``Linear(2, 1)`` and a sigmoid. The
    output is one probability per sample, shape (batch, 1). In this notebook it is
    trained with label 0 for benign and 1 for adversarial features.
    """

    def __init__(self, in_channels):
        super(SubNet, self).__init__()

        def conv_bn(cin, cout, k):
            # Conv is created before its BatchNorm, preserving the original creation
            # order (and hence seeded RNG consumption).
            return (nn.Conv2d(cin, cout, kernel_size=k, stride=1, padding=0, bias=False),
                    nn.BatchNorm2d(cout))

        c1, b1 = conv_bn(in_channels, 96, 3)
        c2, b2 = conv_bn(96, 192, 3)
        c3, b3 = conv_bn(192, 192, 3)
        c4, b4 = conv_bn(192, 2, 1)
        act = nn.ReLU(inplace=True)  # stateless, so a single instance is reused
        head = [nn.AdaptiveAvgPool2d((1, 1)), nn.Flatten(), nn.Linear(2, 1), nn.Sigmoid()]

        # Sequential indices (checkpoint keys "layers.<i>.*") must stay exactly as-is.
        self.layers = nn.Sequential(
            c1, b1, act,
            c2, b2, act,
            # NOTE(review): unlike the other stages there is no ReLU after b3; confirm
            # this is intended (adding one would shift indices and break checkpoints).
            c3, b3,
            c4, b4, act,
            *head,
        )

        self.layers.apply(self.init_param)

    def forward(self, x):
        return self.layers(x)

    def init_param(self, param):
        """Kaiming-uniform init for modules whose exact type is Linear or Conv2d."""
        if type(param) in (nn.Linear, nn.Conv2d):
            nn.init.kaiming_uniform_(param.weight)

### Train and Test Loops

In [98]:
def dynamically_train_subnet(
    resnet_model,
    interm_layer,
    subnet_model,
    subnet_optimizer,
    subnet_criterion,
    subnet_scheduler,
    benign_train_imgs,
    benign_test_imgs,
    device,
    epochs=100,
    batch_size=64,
    attack_eps=4/255,
    attack_alpha=1/255,
    attack_steps=40,
):
    """Train the SubNet detector to separate benign from PGD-perturbed ResNet features.

    Each batch of benign images is passed through ``resnet_model`` up to
    ``interm_layer``; those features are labelled 0. A PGD attack built against the
    *current* subnet perturbs them into samples labelled 1. The subnet is then trained
    on the shuffled 2x batch. After every epoch the subnet is validated with
    ``dynamically_test_subnet``. A checkpoint ``./<epoch>model.pt`` is written whenever
    the validation loss improves, and ``subnet_scheduler.step`` receives the validation loss.

    Args:
        resnet_model: base classifier accepting ``return_interm_layer=``; not trained.
        interm_layer: ResNet stage whose output is fed to the subnet (1, 2 or 3).
        subnet_model: detector producing one probability per sample.
        subnet_optimizer: optimiser over the subnet's parameters.
        subnet_criterion: loss on probabilities, e.g. ``nn.BCELoss``.
        subnet_scheduler: scheduler whose ``step`` takes the validation loss.
        benign_train_imgs: training images, sliced along dim 0.
        benign_test_imgs: validation images, sliced along dim 0.
        device: device the models live on.
        epochs: number of epochs.
        batch_size: benign images per step (each step trains on up to 2*batch_size samples).
        attack_eps, attack_alpha, attack_steps: PGD settings for the training attack.
    """
    subnet_model.train()
    resnet_model.eval()

    best_loss = float('inf')
    for epoch in range(epochs):
        running_loss = 0.0
        num_batches = 0
        for start in tqdm(range(0, len(benign_train_imgs), batch_size)):
            # Benign ResNet features, labelled 0. The ResNet is only a feature
            # extractor, so no autograd graph is built through it.
            with torch.no_grad():
                benign_inputs = benign_train_imgs[start:start + batch_size].float().to(device)
                benign_inputs = resnet_model(benign_inputs, return_interm_layer=interm_layer)
            benign_labels = torch.zeros((benign_inputs.shape[0], 1), dtype=torch.float32)

            # Dynamic adversary: a fresh PGD attack against the subnet's current weights.
            # NOTE(review): torchattacks.PGD optimises nn.CrossEntropyLoss on the model
            # output and clamps results to [0, 1]. With a single sigmoid output per
            # sample that loss may be constant (zero gradient, leaving only PGD's random
            # start), and ResNet features are not guaranteed to lie in [0, 1]. Verify
            # this against the installed torchattacks version.
            train_attack = torchattacks.PGD(subnet_model, eps=attack_eps, alpha=attack_alpha, steps=attack_steps)
            adv_inputs = train_attack(benign_inputs, benign_labels)
            adv_labels = torch.ones((adv_inputs.shape[0], 1), dtype=torch.float32)

            # 2x batch of benign + adversarial samples, shuffled together.
            inputs = torch.cat((benign_inputs, adv_inputs.to(device)), dim=0)
            labels = torch.cat((benign_labels, adv_labels), dim=0).to(device)
            assert inputs.shape[0] == labels.shape[0]
            perm = torch.randperm(inputs.shape[0], device=device)
            inputs, labels = inputs[perm], labels[perm]

            # Clear gradients every step; otherwise they accumulate across batches.
            subnet_optimizer.zero_grad()
            outputs = subnet_model(inputs)
            loss = subnet_criterion(outputs, labels)
            loss.backward()
            subnet_optimizer.step()

            running_loss += loss.item()
            num_batches += 1

        # Average over the batches actually run (the last one may be partial).
        avg_loss = running_loss / max(num_batches, 1)

        val_loss, val_acc, val_roc = dynamically_test_subnet(
            resnet_model,
            interm_layer,
            subnet_model,
            subnet_criterion,
            benign_test_imgs,
            device,
        )
        # Validation may leave the subnet in a different mode; resume training mode.
        subnet_model.train()

        print('Train Loss: {:.4f} | Val Loss: {:.4f} | Val Accuracy: {:.4f} | Val ROC: {:.4f}'.format(
            avg_loss, val_loss, val_acc, val_roc))

        # Checkpoint only when the validation loss improves (lower is better).
        if val_loss < best_loss:
            best_loss = val_loss
            torch.save({
                'epoch': epoch,
                'model_state_dict': subnet_model.state_dict(),
                'optimizer_state_dict': subnet_optimizer.state_dict(),
            }, './' + str(epoch) + 'model.pt')

        subnet_scheduler.step(val_loss)

In [99]:
def dynamically_test_subnet(
    resnet_model,
    interm_layer,
    subnet_model,
    subnet_criterion,
    benign_test_imgs,
    device,
    batch_size=64,
    threshold=0.5,
    attack_eps=2/255,
    attack_alpha=1/255,
    attack_steps=40,
):
    """Evaluate the SubNet detector on benign vs. PGD-perturbed ResNet features.

    For each batch, benign ResNet features (label 0) and the same features attacked by
    a PGD instance built against the current subnet (label 1) are scored by the subnet.

    Args:
        resnet_model: base classifier accepting ``return_interm_layer=``.
        interm_layer: ResNet stage whose output is fed to the subnet.
        subnet_model: detector producing one probability per sample.
        subnet_criterion: loss on probabilities, e.g. ``nn.BCELoss``.
        benign_test_imgs: evaluation images, sliced along dim 0.
        device: device the models live on.
        batch_size: benign images per evaluation step.
        threshold: probability above which a sample is predicted adversarial.
        attack_eps, attack_alpha, attack_steps: PGD settings for the evaluation attack.

    Returns:
        ``(mean_loss, accuracy, roc_auc)`` over all evaluated samples. Accuracy uses
        ``output > threshold``; ROC AUC uses the raw subnet probabilities. All three
        are NaN when ``benign_test_imgs`` is empty.

    ``resnet_model`` is left in eval mode. ``subnet_model`` is returned to the
    train/eval mode it had on entry.
    """
    was_training = subnet_model.training
    resnet_model.eval()
    subnet_model.eval()

    loss_sum = 0.0
    all_labels = []
    all_scores = []

    try:
        for start in tqdm(range(0, len(benign_test_imgs), batch_size)):
            # Benign ResNet features, labelled 0; no graph through the ResNet.
            with torch.no_grad():
                benign_inputs = benign_test_imgs[start:start + batch_size].float().to(device)
                benign_inputs = resnet_model(benign_inputs, return_interm_layer=interm_layer)
            benign_labels = torch.zeros((benign_inputs.shape[0], 1), dtype=torch.float32)

            # Attack needs input gradients, so it runs outside no_grad.
            # NOTE(review): see the torchattacks caveat in dynamically_train_subnet's
            # documentation — CrossEntropyLoss on a single sigmoid output and [0, 1]
            # clamping; verify against the installed version.
            val_attack = torchattacks.PGD(subnet_model, eps=attack_eps, alpha=attack_alpha, steps=attack_steps)
            adv_inputs = val_attack(benign_inputs, benign_labels)
            adv_labels = torch.ones((adv_inputs.shape[0], 1), dtype=torch.float32)

            # Order does not matter in eval mode, so no shuffle is needed here.
            inputs = torch.cat((benign_inputs, adv_inputs.to(device)), dim=0)
            labels = torch.cat((benign_labels, adv_labels), dim=0).to(device)

            with torch.no_grad():
                outputs = subnet_model(inputs)
                loss = subnet_criterion(outputs, labels)

            # Assumes the criterion averages over the batch (nn.BCELoss's default);
            # re-weight by batch size so the final mean is exact across uneven batches.
            loss_sum += loss.item() * inputs.shape[0]
            all_labels.append(labels.flatten().cpu())
            all_scores.append(outputs.flatten().cpu())
    finally:
        subnet_model.train(was_training)

    if not all_labels:
        nan = float('nan')
        return nan, nan, nan

    labels_np = torch.cat(all_labels).numpy()
    scores_np = torch.cat(all_scores).numpy()
    preds_np = (scores_np > threshold).astype(np.float32)

    mean_loss = loss_sum / labels_np.shape[0]
    accuracy = accuracy_score(labels_np, preds_np)
    # ROC AUC must be computed from continuous scores; thresholded 0/1 predictions
    # collapse the curve to a single operating point.
    roc_score = roc_auc_score(labels_np, scores_np)
    return mean_loss, accuracy, roc_score

### Read Benign Data

In [100]:
!ls '/content/drive/MyDrive/11785 - Project/data'

benign_cifar.npy	      fgsm_mnist_eps0.5.npy
benign_cifar_train.npy	      pgd_cifar_default_art.npy
benign_mnist.npy	      pgd_cifar_default_torchattacks_new.npy
benign_mnist_train.npy	      pgd_cifar_eps0.1_torchattacks.npy
cwlinf_cifar_default_art.npy  pgd_cifar_eps0.3_alpha0.1_steps7.npy
cwlinf_mnist_default_art.npy  pgd_mnist_default_art.npy
fgsm_cifar_default_art.npy    pgd_mnist_eps0.3_alpha0.1_steps7.npy


In [101]:
batch_size = 64
N_TRAIN = 45000   # the first N_TRAIN shuffled images train the SubNet; the rest validate it
SPLIT_SEED = 0    # fixed so the train/val split is identical on every run

# benign data
benign_imgs_path = "/content/drive/MyDrive/11785 - Project/data/benign_cifar_train.npy"
# NOTE(review): allow_pickle=True lets np.load unpickle (and execute) arbitrary
# objects; keep it only if this file is trusted and genuinely needs it.
# Load as float32 directly: the training loops call .float() anyway, and float64
# would double memory for 50k images.
benign_imgs_np = np.load(benign_imgs_path, allow_pickle=True).astype(np.float32)
benign_imgs = torch.from_numpy(benign_imgs_np)

# Seeded shuffle so the train/val split is reproducible across runs.
rng = np.random.default_rng(SPLIT_SEED)
shuffle_indices = rng.permutation(benign_imgs.shape[0])
benign_imgs = benign_imgs[shuffle_indices]
print(f"benign data shape: {benign_imgs.shape}")
assert benign_imgs.ndim == 4 and benign_imgs.shape[0] > N_TRAIN, \
    "expected an (N, C, H, W) image array with more than N_TRAIN images"

# train-test split
benign_train_imgs = benign_imgs[:N_TRAIN]
benign_test_imgs = benign_imgs[N_TRAIN:]

benign data shape: torch.Size([50000, 3, 32, 32])


### Load Base Model - ResNet

In [None]:
! ls "/content/drive/MyDrive/11785 - Project/"

In [None]:
cuda = torch.cuda.is_available()
device = torch.device("cuda" if cuda else "cpu")

# The pretrained ResNet-32 is only used as a frozen feature extractor for the
# detector: keep BatchNorm in inference mode and stop gradients from flowing into
# (and accumulating in) its parameters during SubNet training.
resnet_model = resnet32()
resnet_model.to(device)
resnet_model.eval()
resnet_model.requires_grad_(False)

# map_location lets a checkpoint saved from a GPU session load on a CPU-only runtime.
checkpoint = torch.load("/content/drive/MyDrive/11785 - Project/cifar10_model.pth", map_location=device)
resnet_model.load_state_dict(checkpoint)

### Create SubNet

In [104]:
# output size at diff intermediate layers of resnet
interm_layer2dim = {1: 16, 2: 32, 3: 64}
interm_layer = 3

subnet_model = SubNet(interm_layer2dim[interm_layer])
subnet_model.to(device)
subnet_optimizer = torch.optim.Adam(subnet_model.parameters(), lr=0.0001, betas=(0.99, 0.999))
subnet_criterion = nn.BCELoss()
subnet_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(subnet_optimizer, 'min', factor=0.75, patience=1)

### Run Training

In [None]:
# Train the detector on layer-`interm_layer` features for up to 50 epochs.
dynamically_train_subnet(
    resnet_model=resnet_model,
    interm_layer=interm_layer,
    subnet_model=subnet_model,
    subnet_optimizer=subnet_optimizer,
    subnet_criterion=subnet_criterion,
    subnet_scheduler=subnet_scheduler,
    benign_train_imgs=benign_train_imgs,
    benign_test_imgs=benign_test_imgs,
    device=device,
    epochs=50,
    batch_size=64,
)