In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [1]:
import gc
import os
import sys

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision

from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torch.hub import load_state_dict_from_url
from torchvision import datasets, models, transforms
from torch.optim.lr_scheduler import ReduceLROnPlateau, ExponentialLR
from tqdm import tqdm

from sklearn.metrics import accuracy_score, f1_score

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init

from torch.autograd import Variable

In [2]:
N_ATTACKS = 3
N_SUBNET_LABELS = N_ATTACKS # + 1 # add label for non-adversarial

In [33]:
def _weights_init(m):
    classname = m.__class__.__name__
    if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d):
        init.kaiming_normal_(m.weight)

class LambdaLayer(nn.Module):
    def __init__(self, lambd):
      super(LambdaLayer, self).__init__()
      self.lambd = lambd

    def forward(self, x):
      return self.lambd(x)

class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1, option='A'):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != planes:
            if option == 'A':
                """
                For CIFAR10 ResNet paper uses option A.
                """
                self.shortcut = LambdaLayer(lambda x:
                                            F.pad(x[:, :, ::2, ::2], (0, 0, 0, 0, planes//4, planes//4), "constant", 0))
            elif option == 'B':
                self.shortcut = nn.Sequential(
                     nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                     nn.BatchNorm2d(self.expansion * planes)
                )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        self.in_planes = 16

        self.in_channels = 1 # 3 for CIFAR OR 1 for MNIST
        self.conv1 = nn.Conv2d(self.in_channels, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2)
        self.linear = nn.Linear(64, num_classes)

        self.apply(self._weights_init)

    def _weights_init(self, m):
        classname = m.__class__.__name__
        if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d):
          nn.init.kaiming_normal_(m.weight)
    
    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion

        return nn.Sequential(*layers)

    def forward(self, x, return_interm_layer=None):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        if return_interm_layer == 1:
            return out
        out = self.layer2(out)
        if return_interm_layer == 2:
            return out
        out = self.layer3(out)
        if return_interm_layer == 3:
            return out
        out = F.avg_pool2d(out, out.size()[3])
        print('embed = ', out.shape)
        out = out.view(out.size(0), -1)
        if return_interm_layer == -1:
            # print('embed = ', out.shape)
            return out
        out = self.linear(out)
        
        return out

def resnet20():
    return ResNet(BasicBlock, [3, 3, 3])

def resnet32():
    return ResNet(BasicBlock, [5, 5, 5])

def resnet44():
    return ResNet(BasicBlock, [7, 7, 7])

def resnet56():
    return ResNet(BasicBlock, [9, 9, 9])

def resnet110():
    return ResNet(BasicBlock, [18, 18, 18])

def resnet1202():
    return ResNet(BasicBlock, [200, 200, 200])

In [34]:
class MultiClassSubNet(nn.Module):
    def __init__(self, in_channels):
        super(MultiClassSubNet, self).__init__()
        conv1 = nn.Conv2d(in_channels, 64, kernel_size=2, stride=2, padding=0, bias=False)
        bn1 = nn.BatchNorm2d(64)
        conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=0, bias=False)
        bn2 = nn.BatchNorm2d(128)
        conv3 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=0, bias=False)
        bn3 = nn.BatchNorm2d(256)
        relu = nn.ReLU(inplace=True)
        flatten = nn.Flatten()
        linear = nn.Linear(256, N_SUBNET_LABELS)

        self.layers = nn.Sequential(
            conv1,
            bn1,
            relu,
            conv2,
            bn2,
            relu,
            conv3,
            bn3,
            relu,
            flatten,
            linear,
        )

        self.layers.apply(self.init_param)

    def forward(self, x, return_embedding=False):
        embedding = None
        for itr, layer in enumerate(self.layers):
          x = layer(x)
          if return_embedding and itr == len(self.layers)-1:
            embedding = x

        return x

    def init_param(self, param):
        if type(param) in [nn.Linear, nn.Conv2d]:
            nn.init.kaiming_uniform_(param.weight)

In [35]:
def train_subnet(
    resnet_model,
    interm_layer,
    subnet_model,
    subnet_optimizer,
    subnet_criterion,
    subnet_scheduler,
    # unattacked_train_data,
    # unattacked_test_data,
    attacked_train_data_list,
    attacked_test_data_list,
    device,
    epochs=100,
    batch_size=64,
):
    subnet_model.train()
    resnet_model.eval()
    batches = []
    embeddings = []

    for epoch in range(epochs):

      avg_loss = 0.0
      for batch_itr in tqdm(range(0, len(attacked_train_data_list[0]), batch_size)):
          
          all_attacked_input, all_attacked_labels = [], []
          for label_id, attacked_train_data in enumerate(attacked_train_data_list):
            attacked_input = attacked_train_data[batch_itr:batch_itr+batch_size]
            all_attacked_input.append(attacked_input)
            attacked_labels = torch.full((attacked_input.shape[0], 1), label_id, dtype=torch.float32)
            all_attacked_labels.append(attacked_labels)

          input = torch.cat(all_attacked_input, axis=0)
          labels = torch.cat(all_attacked_labels, axis=0)

          assert input.shape[0] == labels.shape[0]
          shuffle_indices = np.arange(input.shape[0])
          np.random.shuffle(shuffle_indices)
          input, labels = input[shuffle_indices].squeeze(0), labels[shuffle_indices].squeeze(0)

          # unattacked_input = unattacked_train_data[batch_itr:batch_itr+batch_size]
          # unattacked_labels = torch.zeros((unattacked_input.shape[0], 1), dtype=torch.float32)

          # input = torch.cat((unattacked_input, all_attacked_input), axis=0)
          # labels = torch.cat((unattacked_labels, all_attacked_labels), axis=0)

          input, labels = input.to(device), labels.to(device)

          with torch.no_grad():
            embeddings.append(resnet_model(input.float(), return_interm_layer=-1))
          input = resnet_model(input.float(), return_interm_layer=interm_layer)
          output = subnet_model(input)

          loss = subnet_criterion(output.float(), labels.squeeze(1).long())
          # print("loss", loss)
          loss.backward()
          avg_loss += loss.item()

          subnet_optimizer.step()

          del input
          del labels
          del loss
          torch.cuda.empty_cache()

      val_loss, val_acc, val_f1, val_embeddings = test_subnet(
          resnet_model,
          interm_layer,
          subnet_model,
          subnet_criterion,
          # unattacked_test_data,
          attacked_test_data_list,
          device,
      )

      print('Val Loss: {:.4f} | Val Acc {:.4f} | Val F1: {:.4f}'.format(val_loss, val_acc, val_f1))
      torch.save({
          'epoch': epoch,
          'model_state_dict': resnet_model.state_dict(),
          'optimizer_state_dict': subnet_optimizer.state_dict(),
      }, './' + str(epoch) + 'model.pt')

      embeddings.extend(val_embeddings)

      # subnet_scheduler.step(val_loss)

    return torch.cat(embeddings, axis=0).detach().cpu().numpy()


def test_subnet(
    resnet_model,
    interm_layer,
    subnet_model,
    criterion,
    # unattacked_test_data,
    attacked_test_data_list,
    device,
    batch_size=64,
):
    resnet_model.eval()
    subnet_model.eval()
    test_loss = []
    accuracies, f1_scores = [], []
    embeddings = []

    for batch_itr in tqdm(range(0, len(attacked_test_data_list[0]), batch_size)):
        all_attacked_input, all_attacked_labels = [], []
        for label_id, attacked_test_data in enumerate(attacked_test_data_list):
          attacked_input = attacked_test_data[batch_itr:batch_itr+batch_size]
          all_attacked_input.append(attacked_input)

          attacked_labels = torch.full((attacked_input.shape[0], 1), label_id, dtype=torch.float32)
          all_attacked_labels.append(attacked_labels)

        input = torch.cat(all_attacked_input, axis=0)
        labels = torch.cat(all_attacked_labels, axis=0)

        assert input.shape[0] == labels.shape[0]
        shuffle_indices = np.arange(input.shape[0])
        np.random.shuffle(shuffle_indices)
        input, labels = input[shuffle_indices].squeeze(0), labels[shuffle_indices].squeeze(0)
        
        # unattacked_input = unattacked_test_data[batch_itr:batch_itr+batch_size]
        # unattacked_labels = torch.zeros((unattacked_input.shape[0], 1), dtype=torch.float32)

        # input = torch.cat((unattacked_input, all_attacked_input), axis=0)
        # labels = torch.cat((unattacked_labels, all_attacked_labels), axis=0)

        input, labels = input.to(device), labels.to(device)

        with torch.no_grad():
            embeddings.append(resnet_model(input.float(), return_interm_layer=-1))
            input = resnet_model(input.float(), return_interm_layer=interm_layer)
            output = subnet_model(input)
            # output = torch.argmax(output, axis=1).unsqueeze(1)

        pred_labels = torch.argmax(output, axis=1).float()
        loss = criterion(output.float(), labels.squeeze(1).long())

        # print(pred_labels.flatten().cpu().shape)
        # print(labels.flatten().cpu().shape)
        accuracy = accuracy_score(pred_labels.flatten().cpu(), labels.flatten().cpu())
        accuracies.append(accuracy)
        f1 = f1_score(pred_labels.flatten().cpu(), labels.flatten().cpu(), average=None)
        f1_scores.append(f1)
        test_loss.extend([loss.item()]*input.size()[0])
        
        del input
        del labels
        del loss
        torch.cuda.empty_cache()

    subnet_model.train()

    return np.mean(test_loss), np.mean(accuracies), np.mean(f1_scores), embeddings

In [36]:
def Normalize0to1(AA):
    AA[:, :, :, 0] -= AA[:, :, :, 0].min().item()
    AA[:, :, :, 0] /= AA[:, :, :, 0].max().item()

    AA[:, :, :, 1] -= AA[:, :, :, 1].min().item()
    AA[:, :, :, 1] /= AA[:, :, :, 1].max().item()

    AA[:, :, :, 2] -= AA[:, :, :, 2].min().item()
    AA[:, :, :, 2] /= AA[:, :, :, 2].max().item()

    return AA / 225.

In [37]:
!ls '/content/drive/MyDrive/11785 - Project/data'

benign_cifar.npy	      pgd_cifar_default_art.npy
benign_mnist.npy	      pgd_cifar_default_torchattacks_new.npy
cwlinf_cifar_default_art.npy  pgd_cifar_eps0.1_torchattacks.npy
cwlinf_mnist_default_art.npy  pgd_cifar_eps0.3_alpha0.1_steps7.npy
fgsm_cifar_default_art.npy    pgd_mnist_default_art.npy
fgsm_mnist_eps0.5.npy	      pgd_mnist_eps0.3_alpha0.1_steps7.npy


In [38]:
batch_size = 64

# normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

# unattacked_data_path = "/content/drive/MyDrive/11785 - Project/data/benign_cifar.npy"
unattacked_data_path = "/content/drive/MyDrive/11785 - Project/data/benign_mnist.npy"
unattacked_data = np.load(unattacked_data_path, allow_pickle=True).astype(float)
unattacked_data = torch.from_numpy(unattacked_data)
# unattacked_data = Normalize0to1(unattacked_data)
# unattacked_data = normalize(unattacked_data)
unattacked_train_data = unattacked_data[:9000]
unattacked_test_data = unattacked_data[9000:]
print(unattacked_data.shape)

# attacked_data_path1 = "/content/drive/MyDrive/11785 - Project/data/cwlinf_cifar_default_art.npy"
attacked_data_path1 = "/content/drive/MyDrive/11785 - Project/data/cwlinf_mnist_default_art.npy"
attacked_data1 = np.load(attacked_data_path1, allow_pickle=True).astype(float)
attacked_data1 = torch.from_numpy(attacked_data1)
# attacked_data1 = Normalize0to1(attacked_data1)
# attacked_data1 = normalize(attacked_data1)
attacked_train_data1 = attacked_data1[:9000]
attacked_test_data1 = attacked_data1[9000:]
print(attacked_data1.shape)

# attacked_data_path2 = "/content/drive/MyDrive/11785 - Project/data/fgsm_cifar_default_art.npy"
attacked_data_path2 = "/content/drive/MyDrive/11785 - Project/data/fgsm_mnist_eps0.5.npy"
attacked_data2 = np.load(attacked_data_path2, allow_pickle=True).astype(float)
attacked_data2 = torch.from_numpy(attacked_data2.transpose(1, 0, 2, 3))
# attacked_data2 = Normalize0to1(attacked_data2)
# attacked_data2 = normalize(attacked_data2)
attacked_train_data2 = attacked_data2[:9000]
attacked_test_data2 = attacked_data2[9000:]
print(attacked_data2.shape)

# attacked_data_path3 = "/content/drive/MyDrive/11785 - Project/data/pgd_cifar_default_art.npy"
attacked_data_path3 = "/content/drive/MyDrive/11785 - Project/data/pgd_mnist_default_art.npy"
attacked_data3 = np.load(attacked_data_path3, allow_pickle=True).astype(float)
attacked_data3 = torch.from_numpy(attacked_data3)
# attacked_data3 = Normalize0to1(attacked_data3)
# attacked_data3 = normalize(attacked_data3)
attacked_train_data3 = attacked_data3[:9000]
attacked_test_data3 = attacked_data3[9000:]
print(attacked_data3.shape)

attacked_train_data_list = [attacked_train_data1, attacked_train_data2, attacked_train_data3]
attacked_test_data_list = [attacked_test_data1, attacked_test_data2, attacked_test_data3]

torch.Size([10000, 1, 28, 28])
torch.Size([10000, 1, 28, 28])
torch.Size([10000, 1, 28, 28])
torch.Size([10000, 1, 28, 28])


In [39]:
! ls "/content/drive/MyDrive/11785 - Project/"

AdversarialDetection.pdf  data		      mnist_model.pth
cifar10_model.pth	  Experiments.gsheet  Presentation.gslides


In [40]:
cuda = torch.cuda.is_available()
device = torch.device("cuda" if cuda else "cpu")

resnet_model = resnet32()
resnet_model.to(device)
# checkpoint = torch.load("/content/drive/MyDrive/11785 - Project/cifar10_model.pth")
checkpoint = torch.load("/content/drive/MyDrive/11785 - Project/mnist_model.pth")
# mod_checkpoint = {k.replace("module.", ""): v for k, v in checkpoint['state_dict'].items()}
resnet_model.load_state_dict(checkpoint)
resnet_optimizer = torch.optim.SGD(resnet_model.parameters(), lr=0.1, weight_decay=5e-5, momentum=0.9)
resnet_criterion = nn.CrossEntropyLoss()
resnet_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(resnet_optimizer, T_0=10, T_mult=2, eta_min=0.01, last_epoch=-1)

# output size at diff intermediate layers of resnet
interm_layer2dim = {1: 16, 2: 32, 3: 64}
interm_layer = 2

subnet_model = MultiClassSubNet(interm_layer2dim[interm_layer])
subnet_model.to(device)
subnet_optimizer = torch.optim.Adam(subnet_model.parameters(), lr=0.0001, betas=(0.99, 0.999))
subnet_criterion = nn.CrossEntropyLoss()
subnet_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(subnet_optimizer, T_0=10, T_mult=2, eta_min=0.01, last_epoch=-1)

In [41]:
embeddings = train_subnet(
  resnet_model,
  interm_layer,
  subnet_model,
  subnet_optimizer,
  subnet_criterion,
  subnet_scheduler,
  # unattacked_train_data,
  # unattacked_test_data,
  attacked_train_data_list,
  attacked_test_data_list,
  device,
  epochs=3,
)

 91%|█████████ | 128/141 [00:09<00:00, 14.11it/s]

embed =  torch.Size([192, 64, 1, 1])
embed =  torch.Size([192, 64, 1, 1])
embed =  torch.Size([192, 64, 1, 1])


 94%|█████████▎| 132/141 [00:09<00:00, 14.19it/s]

embed =  torch.Size([192, 64, 1, 1])
embed =  torch.Size([192, 64, 1, 1])
embed =  torch.Size([192, 64, 1, 1])


 95%|█████████▌| 134/141 [00:09<00:00, 14.07it/s]

embed =  torch.Size([192, 64, 1, 1])
embed =  torch.Size([192, 64, 1, 1])
embed =  torch.Size([192, 64, 1, 1])


 98%|█████████▊| 138/141 [00:09<00:00, 14.12it/s]

embed =  torch.Size([192, 64, 1, 1])
embed =  torch.Size([192, 64, 1, 1])
embed =  torch.Size([192, 64, 1, 1])


100%|██████████| 141/141 [00:09<00:00, 14.13it/s]


embed =  torch.Size([192, 64, 1, 1])
embed =  torch.Size([192, 64, 1, 1])
embed =  torch.Size([120, 64, 1, 1])


 31%|███▏      | 5/16 [00:00<00:00, 46.17it/s]

embed =  torch.Size([192, 64, 1, 1])
embed =  torch.Size([192, 64, 1, 1])
embed =  torch.Size([192, 64, 1, 1])
embed =  torch.Size([192, 64, 1, 1])
embed =  torch.Size([192, 64, 1, 1])
embed =  torch.Size([192, 64, 1, 1])
embed =  torch.Size([192, 64, 1, 1])
embed =  torch.Size([192, 64, 1, 1])
embed =  torch.Size([192, 64, 1, 1])
embed =  torch.Size([192, 64, 1, 1])


100%|██████████| 16/16 [00:00<00:00, 46.61it/s]


embed =  torch.Size([192, 64, 1, 1])
embed =  torch.Size([192, 64, 1, 1])
embed =  torch.Size([192, 64, 1, 1])
embed =  torch.Size([192, 64, 1, 1])
embed =  torch.Size([192, 64, 1, 1])
embed =  torch.Size([120, 64, 1, 1])
Val Loss: 0.1015 | Val Acc 0.9585 | Val F1: 0.9582


  0%|          | 0/141 [00:00<?, ?it/s]

embed =  torch.Size([192, 64, 1, 1])


  1%|▏         | 2/141 [00:00<00:09, 13.99it/s]

embed =  torch.Size([192, 64, 1, 1])
embed =  torch.Size([192, 64, 1, 1])
embed =  torch.Size([192, 64, 1, 1])


  3%|▎         | 4/141 [00:00<00:09, 13.92it/s]

embed =  torch.Size([192, 64, 1, 1])
embed =  torch.Size([192, 64, 1, 1])


  4%|▍         | 6/141 [00:00<00:09, 13.90it/s]

embed =  torch.Size([192, 64, 1, 1])


  6%|▌         | 8/141 [00:00<00:09, 13.89it/s]

embed =  torch.Size([192, 64, 1, 1])
embed =  torch.Size([192, 64, 1, 1])
embed =  torch.Size([192, 64, 1, 1])


  8%|▊         | 11/141 [00:00<00:09, 13.78it/s]

embed =  torch.Size([192, 64, 1, 1])





KeyboardInterrupt: ignored

In [14]:
embeddings.shape

(150000, 64)

In [None]:
from sklearn.manifold import TSNE

X_embed
