In [None]:
import sys 
import os
import copy 
import numpy as np 
import matplotlib.pyplot as plt
from skimage import io
import seaborn as sns

import torch
import torchvision 
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data.sampler import SubsetRandomSampler
from torch.utils.data.dataset import Dataset

sys.path.insert(0, '../Utils')

import models
from train import *
from metrics import *  
from data_downloaders import *

# Log interpreter / framework versions alongside the results.
print("Python: %s" % (sys.version,))
print("Pytorch: %s" % (torch.__version__,))

# Pick the compute device: first CUDA GPU when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")


In [None]:
# Training schedule.
n_epochs = 2000             # outer iterations passed to adversarial_train
batch_size = 128            # mini-batch size shared by every DataLoader in this notebook

# Adam learning rates.
lr_classification = 1e-4    # defended and undefended classifiers
lr_inference = 1e-3         # inference network used during adversarial training
lr_attack = 1e-3            # attack networks trained after the classifiers

In [None]:
# Pre-processing applied to every image, in order: pad the borders by 2 pixels,
# convert to a tensor, then normalise each channel with the mean / std below.
# (Earlier alternative kept for reference: Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)).)
transform = torchvision.transforms.Compose([
    torchvision.transforms.Pad(2),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010)),
])

classes = [
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
]
num_classes = len(classes)

# CIFAR-10 train and test splits (downloaded into ../Datasets/ when missing).
trainset, testset = (
    torchvision.datasets.CIFAR10('../Datasets/', train=is_train, transform=transform, download=True)
    for is_train in (True, False)
)

total_size = len(trainset)
split = int(total_size * 0.8)  # only referenced by the disabled 80/20 split noted below
indices = list(range(total_size))

# Fixed index partitions of the training split.
D_idx = indices[:40000]              # D: data the classifiers are trained on
D_A_idx = indices[:20000]            # D_A: part of D used as "member" data when training attackers
D_prime_idx = indices[40000:45000]   # D': held-out data used as "non-member" data in adversarial training
D_prime_A_idx = indices[45000:]      # D'_A: held-out data used as "non-member" data when training attackers
eval_train = indices[20000:30000]    # members of D (outside D_A) used when evaluating attackers
# Non-members for attacker evaluation come from the CIFAR-10 test split (eval_out_loader).


def _subset_loader(subset_indices):
    """Return (sampler, loader) that draws shuffled batches from the given trainset indices."""
    sampler = SubsetRandomSampler(subset_indices)
    loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, sampler=sampler, num_workers=1)
    return sampler, loader


D_sampler, D_loader = _subset_loader(D_idx)
D_A_sampler, D_A_loader = _subset_loader(D_A_idx)
D_prime_sampler, D_prime_loader = _subset_loader(D_prime_idx)
D_prime_A_sampler, D_prime_A_loader = _subset_loader(D_prime_A_idx)
eval_train_sampler, eval_train_loader = _subset_loader(eval_train)
eval_out_loader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=1)

# Disabled alternatives, kept for reference:
#   80/20 split of the training set
#     train_idx = indices[:split]
#     out_idx = indices[split:]
#     train_sampler = SubsetRandomSampler(train_idx)
#     out_sampler = SubsetRandomSampler(out_idx)
#     trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, sampler=train_sampler, num_workers=2)
#     outloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, sampler=out_sampler, num_workers=2)
#   separate test loader
#     testloader = torch.utils.data.DataLoader(testset, batch_size=32, shuffle=False, num_workers=2)

# helper function to unnormalize and plot image
def imshow(img, mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010)):
    """Undo the per-channel normalisation of a channel-first image and plot it.

    The previous implementation inverted a (0.5, 0.5, 0.5) normalisation, which
    does not match the mean / std used by this notebook's ``transform``, so the
    displayed colours were distorted.

    Args:
        img: channel-first image (C, H, W) convertible with ``np.array``,
            e.g. a CPU tensor returned by ``torchvision.utils.make_grid``.
        mean: per-channel mean that was subtracted during normalisation.
        std: per-channel standard deviation the image was divided by.
            ``mean`` and ``std`` must have one entry per channel of ``img``.
    """
    img = np.array(img)
    # Reshape the statistics to (C, 1, 1) so they broadcast over height and width.
    per_channel = (-1, 1, 1)
    img = img * np.reshape(std, per_channel) + np.reshape(mean, per_channel)
    # Keep values inside the [0, 1] range matplotlib expects for float images.
    img = np.clip(img, 0.0, 1.0)
    img = np.moveaxis(img, 0, -1)
    plt.imshow(img)
    
# Display one batch from the classifier training data as an image grid.
# The built-in next() is used because DataLoader iterators follow the standard
# iterator protocol; the Python-2-style `.next()` alias is not available on
# newer PyTorch releases.
imgs, labels = next(iter(D_loader))
imshow(torchvision.utils.make_grid(imgs))


In [None]:
def label_to_onehot(labels, num_classes=10):
    """Convert integer class labels into one-hot float vectors.

    The identity matrix is created on the same device as ``labels`` so that
    labels which already live on the GPU can be used as an index directly,
    instead of indexing a CPU tensor with a CUDA tensor.

    Args:
        labels: integer class indices in ``[0, num_classes)``; a tensor or
            anything ``torch.as_tensor`` accepts (e.g. a list of ints).
        num_classes: length of each one-hot vector.

    Returns:
        Float tensor of shape ``labels.shape + (num_classes,)`` located on the
        device of ``labels``.
    """
    labels = torch.as_tensor(labels)
    identity = torch.eye(num_classes, device=labels.device)
    # Row i of the identity matrix is the one-hot encoding of class i.
    return identity[labels]


class inference_attack(nn.Module):
    """Membership-inference model scoring (prediction vector, one-hot label) pairs.

    The prediction vector and the one-hot label are embedded by two separate
    fully connected branches (64 features each); the embeddings are
    concatenated and mapped to a single sigmoid output in (0, 1). The training
    code in this notebook uses target 1 for samples from the classifier's
    training data and 0 for held-out samples.

    Args:
        n_classes: length of the prediction vector and of the one-hot label.
    """

    def __init__(self, n_classes):
        super(inference_attack, self).__init__()

        self.n_classes = n_classes

        # Branch embedding the classifier's posterior vector.
        self.prediction_vector_block = nn.Sequential(
            nn.Linear(n_classes, 1024),
            nn.ReLU(),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Linear(512, 64),
            nn.ReLU()
        )

        # Branch embedding the one-hot encoded label.
        self.label_block = nn.Sequential(
            nn.Linear(n_classes, 512),
            nn.ReLU(),
            nn.Linear(512, 64),
            nn.ReLU()
        )

        # Head applied to the concatenated branch outputs (64 + 64 = 128 features).
        self.common_block = nn.Sequential(
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Linear(64, 1)
        )

    def forward(self, prediction_vector, one_hot_label):
        """Score a batch of samples.

        Args:
            prediction_vector: tensor of shape (batch, n_classes).
            one_hot_label: tensor of shape (batch, n_classes).

        Returns:
            Tensor of shape (batch, 1) with values in (0, 1).
        """
        prediction_block_out = self.prediction_vector_block(prediction_vector)
        label_block_out = self.label_block(one_hot_label)

        merged = torch.cat((prediction_block_out, label_block_out), dim=1)
        # torch.sigmoid replaces the deprecated torch.nn.functional.sigmoid.
        return torch.sigmoid(self.common_block(merged))
    

    

In [None]:
# determine device to run network on (runs on gpu if available)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# --- Target classifier: VGG16 adapted to the small (padded) CIFAR-10 images ---
vgg16 = torchvision.models.vgg16(num_classes=10)
# Recent torchvision VGG models adaptively pool the convolutional features to 7x7
# before the classifier, which would produce 512 * 7 * 7 features and not fit the
# 512-input head below. Pooling to 1x1 keeps the 512-feature layout this head
# expects; on torchvision versions whose VGG.forward does not use `avgpool`, this
# assignment has no effect on the forward pass.
vgg16.avgpool = nn.AdaptiveAvgPool2d((1, 1))
# Compact classification head replacing the default one.
vgg16.classifier = nn.Sequential(
            nn.Linear(512, 64),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(64, 64),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(64, 10),
        )

#net = resnet18.to(device)
net = vgg16.to(device)

net.apply(models.weights_init)

# The undefended baseline starts from exactly the same initial weights as `net`.
undefended_net = copy.deepcopy(net)

undefended_loss = nn.CrossEntropyLoss()
undefended_optim = optim.Adam(undefended_net.parameters(), lr=lr_classification)

class_loss = nn.CrossEntropyLoss()
#optimizer = optim.SGD(net.parameters(), lr=lr, momentum=0.9)
class_optim = optim.Adam(net.parameters(), lr=lr_classification)

# --- Inference network used as the adversary during adversarial training ---
infer_net = inference_attack(num_classes).to(device)
infer_net.apply(models.weights_init)

infer_loss = nn.BCELoss()
infer_optim = optim.Adam(infer_net.parameters(), lr=lr_inference)

# --- Attack networks trained afterwards, one per classifier ---
attack_net = inference_attack(num_classes).to(device)
attack_net.apply(models.weights_init)

# attack_net2 is a copy of attack_net, so both attackers start from identical weights.
attack_net2 = copy.deepcopy(attack_net)
attack2_loss = nn.BCELoss()
attack2_optim = optim.Adam(attack_net2.parameters(), lr=lr_attack)

attack_loss = nn.BCELoss()
attack_optim = optim.Adam(attack_net.parameters(), lr=lr_attack)

In [None]:
def _cycle_batches(loader):
    """Yield batches from ``loader`` forever, starting a new pass each time one ends."""
    while True:
        yielded_any = False
        for batch in loader:
            yielded_any = True
            yield batch
        if not yielded_any:
            # An empty loader would otherwise make this loop spin forever.
            raise ValueError("data loader produced no batches")


def adversarial_train(inference_net, classification_net, train_set, out_set, test_set, 
                      infer_optim, infer_loss, class_optim, class_loss, n_epochs, k, privacy_theta):
    """Alternately train an inference network and a classifier against each other.

    Each of the ``n_epochs`` outer iterations performs ``k`` updates of
    ``inference_net`` (telling batches of ``train_set`` apart from batches of
    ``out_set`` using the classifier's softmax output and the true label),
    followed by one update of ``classification_net`` that minimises
    ``class_loss - privacy_theta * infer_loss`` on a batch of ``train_set``.

    Every 20th iteration (except iteration 0) the top-4 sorted posteriors seen
    during that iteration's inference steps are plotted, and the inference
    accuracy of that iteration plus the result of ``eval_target_net`` on
    ``test_set`` are printed.

    Args:
        inference_net: module called as ``inference_net(posteriors, one_hot_labels)``.
        classification_net: classifier being trained.
        train_set: DataLoader over the classifier's training data ("members").
        out_set: DataLoader over held-out data ("non-members").
        test_set: DataLoader passed to ``eval_target_net`` for reporting.
        infer_optim: optimizer for ``inference_net``.
        infer_loss: loss applied to the inference outputs and 0/1 targets.
        class_optim: optimizer for ``classification_net``.
        class_loss: classification loss applied to the raw classifier outputs.
        n_epochs: number of outer iterations.
        k: inference-network updates per outer iteration.
        privacy_theta: weight of the inference term in both objectives.

    Raises:
        ValueError: if ``train_set`` or ``out_set`` yields no batches.
    """
    inference_net.train()
    classification_net.train()

    # Persistent iterators: batches are pulled with next() and the loaders are
    # only restarted when exhausted, rather than building a new DataLoader
    # iterator (and its worker process) for every single batch.
    train_batches = _cycle_batches(train_set)
    out_batches = _cycle_batches(out_set)

    for epoch in range(n_epochs):
        will_report = epoch % 20 == 0 and epoch != 0

        # Top-4 posteriors of this iteration, collected only when they will be plotted.
        train_p = []
        out_p = []

        total_inference = 0
        total_correct_inference = 0

        for _ in range(k):
            # train inference network
            train_imgs, train_lbls = next(train_batches)
            train_imgs, train_lbls = train_imgs.to(device), train_lbls.to(device)
            out_imgs, out_lbls = next(out_batches)
            out_imgs, out_lbls = out_imgs.to(device), out_lbls.to(device)

            # The two batches may differ in size (last batch of a pass), so the
            # targets are sized per batch.
            member_targets = torch.ones(train_imgs.shape[0], device=device)
            nonmember_targets = torch.zeros(out_imgs.shape[0], device=device)

            # Only the inference network is updated in this step, so the
            # classifier's outputs are computed without tracking gradients.
            with torch.no_grad():
                train_posteriors = F.softmax(classification_net(train_imgs), dim=1)
                out_posteriors = F.softmax(classification_net(out_imgs), dim=1)

            if will_report:
                train_sort, _ = torch.sort(train_posteriors, descending=True)
                out_sort, _ = torch.sort(out_posteriors, descending=True)
                train_p.append(train_sort[:, :4].cpu().numpy().flatten())
                out_p.append(out_sort[:, :4].cpu().numpy().flatten())

            infer_optim.zero_grad()

            # reshape(-1) keeps a 1-D output even for a batch of size 1.
            train_inference = inference_net(train_posteriors, label_to_onehot(train_lbls).to(device)).reshape(-1)
            out_inference = inference_net(out_posteriors, label_to_onehot(out_lbls).to(device)).reshape(-1)

            total_inference += train_inference.numel() + out_inference.numel()
            total_correct_inference += torch.sum(train_inference > 0.5).item() + torch.sum(out_inference < 0.5).item()

            loss_train = infer_loss(train_inference, member_targets)
            loss_out = infer_loss(out_inference, nonmember_targets)

            loss = privacy_theta * (loss_train + loss_out) / 2
            loss.backward()

            infer_optim.step()

        # train classification network
        train_imgs, train_lbls = next(train_batches)
        train_imgs, train_lbls = train_imgs.to(device), train_lbls.to(device)

        class_optim.zero_grad()

        outputs = classification_net(train_imgs)
        train_posteriors = F.softmax(outputs, dim=1)

        loss_classification = class_loss(outputs, train_lbls)
        # Targets are sized from the current batch (not from the inner loop).
        member_targets = torch.ones(train_imgs.shape[0], device=device)

        train_inference = inference_net(train_posteriors, label_to_onehot(train_lbls).to(device)).reshape(-1)
        loss_infer = infer_loss(train_inference, member_targets)
        # Subtracting the inference loss rewards the classifier for outputs on
        # which the inference network fails to recognise training members.
        loss = loss_classification - privacy_theta * loss_infer

        loss.backward()
        class_optim.step()

        if will_report:
            if train_p:
                plt.figure()
                sns.distplot(np.concatenate(train_p), label='top-4 train posteriors')
                sns.distplot(np.concatenate(out_p), label='top-4 out posteriors')
                plt.legend()
                plt.show()

            if total_inference:
                inference_accuracy = 100 * (total_correct_inference / total_inference)
            else:
                # k == 0: no inference step was run in this iteration.
                inference_accuracy = float('nan')
            classification_accuracy = eval_target_net(classification_net, test_set, classes=classes)
            print("[%d/%d] Inference accuracy = %.2f%%, Classification accuracy = %.2f%%" % (epoch, n_epochs, inference_accuracy, classification_accuracy))

            # eval_target_net is defined outside this notebook and may switch the
            # model to eval mode (TODO confirm); restore training mode so
            # dropout stays active for the remaining iterations.
            classification_net.train()
                  
        
def train_attacker(attack_net, target_net, attack_train, attack_out, optimizer, criterion, n_epochs):
    """Train a membership-inference attack model against a fixed target classifier.

    For each pair of batches drawn in lock-step from ``attack_train`` (target 1)
    and ``attack_out`` (target 0), the target classifier's softmax outputs and
    the one-hot true labels are fed to ``attack_net``, which is updated on the
    mean of the two losses. Pairs whose batch sizes differ are skipped, and
    iteration stops when the shorter loader is exhausted. Loss and running
    accuracy are printed after every processed pair.

    Args:
        attack_net: module called as ``attack_net(posteriors, one_hot_labels)``.
        target_net: classifier under attack; put in eval mode and not updated.
        attack_train: DataLoader of samples labelled as members.
        attack_out: DataLoader of samples labelled as non-members.
        optimizer: optimizer for ``attack_net``.
        criterion: loss applied to the attack outputs and 0/1 targets.
        n_epochs: number of passes over the zipped loaders.
    """
    target_net.eval()
    attack_net.train()
    for epoch in range(n_epochs):

        total = 0
        correct = 0

        for i, ((train_imgs, train_lbls), (out_imgs, out_lbls)) in enumerate(zip(attack_train, attack_out)):

            if train_imgs.shape[0] != out_imgs.shape[0]:
                continue
            mini_batch_size = train_imgs.shape[0]
            train_imgs, train_lbls = train_imgs.to(device), train_lbls.to(device)
            out_imgs, out_lbls = out_imgs.to(device), out_lbls.to(device)

            # The target network is frozen here: skip gradient tracking so that
            # backward() neither traverses it nor accumulates gradients on its
            # parameters.
            with torch.no_grad():
                train_posteriors = F.softmax(target_net(train_imgs), dim=1)
                out_posteriors = F.softmax(target_net(out_imgs), dim=1)

            optimizer.zero_grad()

            train_lbl = torch.ones(mini_batch_size, device=device)
            out_lbl = torch.zeros(mini_batch_size, device=device)

            # reshape(-1) keeps a 1-D output even for a batch of size 1.
            train_inference = attack_net(train_posteriors, label_to_onehot(train_lbls).to(device)).reshape(-1)
            out_inference = attack_net(out_posteriors, label_to_onehot(out_lbls).to(device)).reshape(-1)

            loss_train = criterion(train_inference, train_lbl)
            loss_out = criterion(out_inference, out_lbl)
            loss = (loss_train + loss_out) / 2
            loss.backward()
            optimizer.step()

            # Running accuracy over the current epoch, thresholding at 0.5.
            correct += (train_inference>=0.5).sum().item()
            correct += (out_inference<0.5).sum().item()
            total += train_inference.size(0) + out_inference.size(0)

            print("[%d/%d][%d/%d] loss = %.2f, accuracy = %.2f" % (epoch, n_epochs, i, len(attack_train), loss.item(), 100 * correct / total))
        
        
def eval_attacker(attack_net, target_net, attack_train, attack_out):
    """Evaluate a membership-inference attack model and print its metrics.

    Batches are drawn in lock-step from ``attack_train`` (members, positive
    class) and ``attack_out`` (non-members, negative class); iteration stops
    when the shorter loader is exhausted. An attack output >= 0.5 counts as a
    "member" prediction. Accuracy (in percent), precision and recall are
    printed; nothing is returned.

    Args:
        attack_net: module called as ``attack_net(posteriors, one_hot_labels)``.
        target_net: classifier under attack.
        attack_train: DataLoader of member samples.
        attack_out: DataLoader of non-member samples.
    """
    target_net.eval()
    attack_net.eval()

    total = 0
    correct = 0

    true_positives = 0
    false_positives = 0
    false_negatives = 0

    # Pure evaluation: no gradients are needed, so avoid building autograd graphs.
    with torch.no_grad():
        for (train_imgs, train_lbls), (out_imgs, out_lbls) in zip(attack_train, attack_out):
            train_imgs, train_lbls = train_imgs.to(device), train_lbls.to(device)
            out_imgs, out_lbls = out_imgs.to(device), out_lbls.to(device)

            train_posteriors = F.softmax(target_net(train_imgs), dim=1)
            out_posteriors = F.softmax(target_net(out_imgs), dim=1)

            # reshape(-1) keeps a 1-D output even for a batch of size 1.
            train_inference = attack_net(train_posteriors, label_to_onehot(train_lbls).to(device)).reshape(-1)
            out_inference = attack_net(out_posteriors, label_to_onehot(out_lbls).to(device)).reshape(-1)

            members_detected = (train_inference >= 0.5).sum().item()
            nonmembers_rejected = (out_inference < 0.5).sum().item()

            true_positives += members_detected
            false_positives += (out_inference >= 0.5).sum().item()
            false_negatives += (train_inference < 0.5).sum().item()

            correct += members_detected + nonmembers_rejected
            total += train_inference.size(0) + out_inference.size(0)

    # Guard every ratio against an empty denominator (e.g. empty loaders).
    accuracy = 100 * correct / total if total != 0 else 0
    precision = true_positives / (true_positives + false_positives) if true_positives + false_positives != 0 else 0
    recall = true_positives / (true_positives + false_negatives) if true_positives + false_negatives !=0 else 0
    print("accuracy = %.2f, precision = %.2f, recall = %.2f" % (accuracy, precision, recall))



                  

In [None]:
# Train the defended classifier `net` against the inference network:
# 7 inference-network updates per classifier update, inference term weighted by 1.
adversarial_train(
    infer_net, net, D_loader, D_prime_loader, eval_out_loader,
    infer_optim, infer_loss, class_optim, class_loss, n_epochs,
    k=7, privacy_theta=1,
)

In [None]:
# Train the attack model against the adversarially regularised classifier.
train_attacker(
    attack_net, net, D_A_loader, D_prime_A_loader,
    attack_optim, attack_loss, n_epochs=100,
)

In [None]:
# Train the undefended baseline classifier on D for comparison.
# `train` comes from the star-import of the local `train` module; its exact
# behaviour is not visible in this notebook.
train(undefended_net, D_loader, eval_out_loader, undefended_optim, undefended_loss, n_epochs=100, classes=None, verbose=True)

In [None]:
# Train the second attack model against the undefended baseline classifier.
train_attacker(
    attack_net2, undefended_net, D_A_loader, D_prime_A_loader,
    attack2_optim, attack2_loss, n_epochs=100,
)

In [None]:
# Evaluate each attacker on the same member / non-member data:
# members come from eval_train_loader, non-members from the CIFAR-10 test split.
print("\nAttack performance on Adversarial Regularization Defense Network: ")
eval_attacker(attack_net, net, eval_train_loader, eval_out_loader)

print("\nAttack performance on normal network: ")
eval_attacker(attack_net2, undefended_net, eval_train_loader, eval_out_loader)

In [None]:
# Compare train / test classification accuracy of the defended and the
# undefended classifier. `eval_target_net` is defined outside this notebook.
# NOTE: train_accuracy / test_accuracy are reused, so after this cell they hold
# the values of the undefended ("normal") network only.
print("\nAdversarial Regularization network classification accuracy on training set: ")
train_accuracy = eval_target_net(net, D_loader, classes=None)

print("\nAdversarial Regularization network classification accuracy on test set: ")
test_accuracy = eval_target_net(net, eval_out_loader, classes=None)

print("\nNormal network classification accuracy on training set: ")
train_accuracy = eval_target_net(undefended_net, D_loader, classes=None)

print("\nNormal network classification accuracy on test set: ")
test_accuracy = eval_target_net(undefended_net, eval_out_loader, classes=None)

## Results



Attack performance on Adversarial Regularization Defense Network: 
accuracy = 54.68, precision = 0.53, recall = 0.78

Attack performance on normal network: 
accuracy = 65.97, precision = 0.60, recall = 0.95


Adversarial Regularization network classification accuracy on training set: 

Accuracy = 81.12 %



Adversarial Regularization network classification accuracy on test set: 

Accuracy = 68.42 %



Normal network classification accuracy on training set: 

Accuracy = 98.91 %



Normal network classification accuracy on test set: 

Accuracy = 74.12 %


