In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
from models import *
from utils import *
from test import test
from train import train
import torch.utils.data as data
import matplotlib.pyplot as plt
from tqdm import tqdm_notebook
import copy
from torch.utils.data import TensorDataset
from torch.utils.data.sampler import SubsetRandomSampler
import numpy as np
import matplotlib.pyplot as plt
import time
# Explicit imports for names used below (os, random, F). They are placed after
# the star imports on purpose so `from models import *` / `from utils import *`
# cannot shadow them.
import os
import random
import torch.nn.functional as F

In [None]:
# Create new examples by pairing all the possible combinations of n images taken 2
def samples_pair(images, labels, p=0.5, rng=None):
    """Augment a small image set with pixel-wise averages of random image pairs.

    Every original image is kept. In addition, each unordered pair (i, j) with
    i < j (i.e. each of the C(n, 2) combinations) is averaged and added when the
    random draw exceeds ``p``.

    Args:
        images: sequence of tensors exposing ``.numpy()``.
        labels: sequence of labels aligned with ``images``.
        p: draw threshold; a pair is kept when ``rng.random() > p``.
        rng: object with a ``random()`` method returning floats in [0, 1).
            Defaults to numpy's global RNG (seeded elsewhere in the notebook).

    Returns:
        (paired, local_labels): list of numpy arrays and the matching labels.
        A mixed image gets the label of its first member. The notebook later
        relabels the whole set with the victim model's predictions.
    """
    if rng is None:
        rng = np.random
    paired = []
    local_labels = []
    n = len(images)
    # Convert once instead of calling .numpy() inside the O(n^2) inner loop.
    arrays = [img.numpy() for img in images]
    for i in range(n):
        paired.append(arrays[i])
        local_labels.append(labels[i])
        # Start at i + 1: pairing an image with itself would only add a duplicate.
        for j in range(i + 1, n):
            if rng.random() > p:
                paired.append((arrays[i] + arrays[j]) / 2)
                local_labels.append(labels[i])
    return paired, local_labels

In [None]:
# Seed every RNG the notebook touches (python, numpy, torch CPU and all GPUs)
# and force deterministic cuDNN kernels so reruns are reproducible.
SEED = 42
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
np.random.seed(SEED)
random.seed(SEED)
torch.backends.cudnn.deterministic = True
# Benchmark mode selects kernels by timing them, which can differ between runs.
torch.backends.cudnn.benchmark = False

In [None]:
# Load the provider's (victim's) pretrained MNIST model onto the GPU and set up
# the loss function used by every evaluation below.
path_victim_mnist = 'pretrained_models/victim_mnist_l5.pt'
victim_mnist_model = load_state(MNIST_L5().cuda(), path_victim_mnist)
criterion = nn.CrossEntropyLoss()

In [None]:
# Original MNIST datasets from the project helper get_mnist_dataset().
# NOTE(review): presumably returns the (train, test) splits -- confirm in utils/models.
mnist_trainset, mnist_testset = get_mnist_dataset()

In [None]:
# Randomly select 0.5% of the original training set (sampling without replacement).
n = len(mnist_trainset)
d = int(len(mnist_trainset)*0.5/100)
print(d)
sampled_indices = np.random.choice(n, d, replace=False)
# Index the dataset directly instead of iterating over all n items and testing
# `counter in sampled_indices` (an O(d) array scan per item). Sorting keeps the
# images in dataset order, which is the order the full scan produced.
images = []
labels = []
for idx in np.sort(sampled_indices):
    image, label = mnist_trainset[int(idx)]
    images.append(image)
    labels.append(label)

In [None]:
# Create the augmented data set: the sampled images plus averaged random pairs.
# NOTE: this rebinds `labels` to the augmented label list. Those labels are
# placeholders; the notebook later replaces them with query_labels() output.
paired, labels = samples_pair(images, labels)

In [None]:
# Number of examples in the augmented set (shown as the cell output).
len(paired)

In [None]:
# Build the user's data set from the augmented images and placeholder labels.
# Stack into one ndarray first: torch.Tensor(list_of_ndarrays) converts element
# by element, which PyTorch warns is extremely slow.
tensor_x = torch.from_numpy(np.stack(paired)).float()  # shape (N, *image_shape), float32
tensor_y = torch.Tensor(labels)  # placeholder labels; the victim's labels are obtained later

emnist_trainset = TensorDataset(tensor_x, tensor_y)  # create the user data set

In [None]:
# Wrap the original MNIST splits in deterministic (unshuffled) loaders of 100 images each.
mnist_test_loader = data.DataLoader(dataset=mnist_testset, batch_size=100, shuffle=False)
mnist_train_loader = data.DataLoader(dataset=mnist_trainset, batch_size=100, shuffle=False)

In [None]:
# The provider (victim) model's accuracy on the MNIST test set.
# `test` is a project helper; elsewhere its result is unpacked as `_, acc`.
test(victim_mnist_model, mnist_test_loader, criterion)

In [None]:
# User (attacker) side: an unshuffled loader over the augmented data, the
# surrogate model (swap in any MNIST model available from models.py here)
# and an SGD optimizer with momentum for it.
attacker_query_loader = data.DataLoader(
    dataset=emnist_trainset,
    batch_size=100,
    shuffle=False,
)
attacker_mnist_model = exp_MNIST_L5_2().cuda()
attacker_optimizer = optim.SGD(
    params=attacker_mnist_model.parameters(),
    lr=0.001,
    momentum=0.9,
)

In [None]:
# Query the provider model with the user's unlabelled data: the user labels
# his data with the provider model's predictions.
# time.perf_counter is monotonic (unlike time.time), so it is the right clock for durations.
start = time.perf_counter()
labels = query_labels(victim_mnist_model, attacker_query_loader)
time_elapsed = int(round((time.perf_counter() - start) * 1000))  # milliseconds
print('test time elapsed {}ms'.format(time_elapsed))
print('test time elapsed {}s'.format(time_elapsed/1000))

In [None]:
# Create the user data set by merging the augmented images with the labels
# obtained from the provider model (project helper get_attacker_dataset).
attacker_labled_data = get_attacker_dataset(emnist_trainset, labels)

In [None]:
# Split the victim-labelled user data into a training subset and a validation
# subset of about 5%, each served by a randomly-sampled loader.
validation_split = .05
random_seed = 42

dataset_size = len(attacker_labled_data)
split = int(np.floor(validation_split * dataset_size))

# Reseeding numpy's global RNG makes the shuffle, and so the split, reproducible.
indices = list(range(dataset_size))
np.random.seed(random_seed)
np.random.shuffle(indices)

val_indices = indices[:split]
train_indices = indices[split:]

# PyTorch samplers draw from the given index subsets in random order.
train_sampler, valid_sampler = SubsetRandomSampler(train_indices), SubsetRandomSampler(val_indices)

attacker_train_loader = torch.utils.data.DataLoader(
    attacker_labled_data,
    batch_size=100,
    sampler=train_sampler,
)
attacker_val_loader = torch.utils.data.DataLoader(
    attacker_labled_data,
    batch_size=100,
    sampler=valid_sampler,
)

In [None]:
# Test the user (surrogate) model on the MNIST test set before training.
# An untrained 10-class model should score close to chance (about 10%).
test(attacker_mnist_model, mnist_test_loader, criterion)

In [None]:
# Train the attacker (surrogate) model on the paired, victim-labelled data set,
# checkpointing the epoch with the best validation accuracy.
N_EPOCHS = 100
BEST_ATTACKER_CKPT = 'best_mnist_attacker_model_second_different_with_MNIST.pth'
attacker_main_task_acc = []  # validation accuracy after each epoch
watermark_acc = []  # unused: watermark evaluation is disabled in this experiment
best_acc = 0
start = time.perf_counter()  # monotonic clock for measuring durations
for epoch in tqdm_notebook(range(N_EPOCHS)):
    attacker_mnist_model, loss = train(model=attacker_mnist_model, train_loader=attacker_train_loader,
                                       criterion=criterion, optimizer=attacker_optimizer, local_epochs=1)
    print('Epoch: ', epoch+1)
    print('Attacker\'s model acc on the validation set')
    _, acc = test(attacker_mnist_model, attacker_val_loader, criterion)
    attacker_main_task_acc.append(acc)
    if acc > best_acc:
        best_acc = acc
        torch.save(attacker_mnist_model.state_dict(), BEST_ATTACKER_CKPT)
time_elapsed = int(round((time.perf_counter() - start) * 1000))  # milliseconds
print('test time elapsed {}ms'.format(time_elapsed))
print('test time elapsed {}s'.format(time_elapsed/1000))

In [None]:
# Restore the best-validation checkpoint written by the training loop above,
# so the evaluation and explanations below use the model this notebook just
# trained. The filename must match the torch.save call in that loop.
attacker_mnist_model.load_state_dict(torch.load('best_mnist_attacker_model_second_different_with_MNIST.pth'))

In [None]:
# Test the user model on the MNIST test set after training. Its accuracy
# should be close to the provider model's accuracy.
start = time.perf_counter()  # monotonic clock for measuring durations
test(attacker_mnist_model, mnist_test_loader, criterion)
time_elapsed = int(round((time.perf_counter() - start) * 1000))  # milliseconds
print('test time elapsed {}ms'.format(time_elapsed))
print('test time elapsed {}s'.format(time_elapsed/1000))

In [None]:
# This function checks the provider model for any wrongly classified image and
# returns a list with the images that were wrongly classified.

def our_test(local_model, device, local_test_loader, max_wrong=300):
    """Collect examples that `local_model` misclassifies.

    Expects a loader with batch size 1: prediction and target are compared with
    ``.item()``. Iteration stops early once more than ``max_wrong`` misclassified
    examples have been gathered.

    Args:
        local_model: callable mapping an input batch to per-class scores.
        device: device the inputs and targets are moved to (e.g. "cuda").
        local_test_loader: iterable of ``(data, target)`` pairs of batch size 1.
        max_wrong: iteration stops after ``max_wrong + 1`` wrong examples.

    Returns:
        (final_acc, wrong_examples, labels, logits): the accuracy over the
        examples actually visited, plus the misclassified inputs, their targets
        and the model outputs for them (all on ``device``).
    """
    correct = 0
    seen = 0
    wrong_examples = []
    logits = []
    labels = []

    # `inputs` rather than `data` so the torch.utils.data alias is not shadowed.
    for inputs, target in local_test_loader:
        seen += 1
        inputs, target = inputs.to(device), target.to(device)

        with torch.no_grad():
            output = local_model(inputs)

        # Index of the highest score is the predicted class.
        final_pred = output.max(1, keepdim=True)[1]
        if final_pred.item() == target.item():
            correct += 1
        else:
            wrong_examples.append(inputs)
            labels.append(target)
            logits.append(output)
            if len(labels) > max_wrong:
                break

    # Divide by the number of examples actually evaluated. After an early break,
    # len(local_test_loader) would overcount and underestimate the accuracy.
    final_acc = correct / seen if seen else 0.0

    return final_acc, wrong_examples, labels, logits

In [None]:
# Loader over the MNIST test split with batch size 1, so images are checked one by one.
new_batch_size = 1
our_test_loader = torch.utils.data.DataLoader(
    dataset=mnist_testset,
    batch_size=new_batch_size,
    num_workers=0,
    pin_memory=True,
)

In [None]:
# Get the victim model's wrong predictions on the test set.
# NOTE: rebinds the module-level `labels` (true targets of the misclassified
# images) and `logits` (the victim's outputs for them).
accuracy, wrong_examples, labels, logits = our_test(victim_mnist_model, "cuda", our_test_loader)

In [None]:
# Class the victim predicted (arg-max of its output) for each misclassified example.
wrong_labels = [torch.max(logits[k], dim=1).indices.item() for k in range(len(labels))]

In [None]:
# Number of misclassified examples collected (shown as the cell output).
len(wrong_labels)

In [None]:
# Add perturbations to create counterfactual examples for the misclassified
# images. Each image is stepped along the surrogate (attacker) model's loss
# gradient toward its true label, with growing epsilon, until the victim
# model predicts the true label.
EPS_STEP = 0.05
EPS_MAX = 0.99
corrected_examples = np.zeros([len(wrong_examples)])
perturbed_examples = []
new_labels = []
eps_all = []
start = time.perf_counter()  # monotonic clock for measuring durations
for i in range(len(wrong_examples)):
    x, y = wrong_examples[i], labels[i]
    # The gradient depends only on x and y, not on epsilon, so compute it once
    # per image instead of once per epsilon step.
    # NOTE(review): assumes the surrogate's forward pass is deterministic
    # (e.g. eval mode); with active dropout the gradient used to vary per step.
    x_var = x.clone().requires_grad_(True)
    loss = F.nll_loss(attacker_mnist_model(x_var), y)
    img_grad, = torch.autograd.grad(loss, x_var)
    # An integer step count avoids float drift in eps; the values tried are
    # 0.0, 0.05, ..., 0.95 (every multiple of EPS_STEP not above EPS_MAX).
    for step in range(int(EPS_MAX / EPS_STEP) + 1):
        eps = round(step * EPS_STEP, 10)
        with torch.no_grad():
            perturbed_image = x - eps * img_grad
            new_label = victim_mnist_model(perturbed_image).max(1, keepdim=True)[1]
        if new_label.item() == y.item():
            perturbed_examples.append(perturbed_image.squeeze().cpu().numpy())
            new_labels.append(new_label)
            eps_all.append(eps)
            corrected_examples[i] = 1
            print("Image {} has been modified with epsilon {}".format(i, eps))
            break
time_elapsed = int(round((time.perf_counter() - start) * 1000))  # milliseconds
print('test time elapsed {}ms'.format(time_elapsed))
print('test time elapsed {}s'.format(time_elapsed/1000))

In [None]:
# For every misclassified image whose counterfactual was found, keep the
# original image (as a numpy array), its true label and the victim's wrong prediction.
corrected_idx = np.where(corrected_examples == 1)
real_examples = [wrong_examples[k].squeeze().data.cpu().numpy() for k in corrected_idx[0]]
real_labels = [labels[k].item() for k in corrected_idx[0]]
wrong_predictions = [wrong_labels[k] for k in corrected_idx[0]]

In [None]:
def get_quartlies(samples):
    """Return ``(q3, iqr)`` for ``samples``.

    ``q3`` is the 75th percentile and ``iqr`` the interquartile range
    ``q3 - q1``, both with numpy's default (linear) interpolation.
    Multi-dimensional input is treated as flattened.
    """
    lower, upper = np.percentile(samples, [25, 75])
    return upper, upper - lower

In [None]:
# Keep only the outlier pixels of the squared perturbation: per image, every
# entry below Q3 + tau * IQR is zeroed and the rest keep their magnitude.
diff = []
tau = 8
for i in range(len(eps_all)):
    squared = (real_examples[i] - perturbed_examples[i])**2
    q3, iqr = get_quartlies(squared)
    squared[squared < q3 + iqr*tau] = 0
    diff.append(squared)

In [None]:
# Binarize the thresholded perturbation maps: any surviving (positive) pixel
# becomes 1, so unimportant pixels stay 0 and important ones are emphasized.
# NOTE: `diff2 = diff` is an alias, not a copy, so this rewrites `diff` in place.
# The saving cell reads these binarized maps through `diff`.
# Assumes each map is 28x28 (squeezed MNIST images).
diff2 = diff
for im in range(len(diff)):
    positive = diff[im] > 0.0
    diff2[im][positive] = 1
                                


In [None]:
# real_examples[0].shape

In [None]:
# len(real_examples)

In [None]:
# Output directory for the explanations of the misclassified images.
path_wrong = 'surrogate_model_3/wrong'

In [None]:
# Min-max normalize every perturbed image to [0, 1] before saving.
# Take min/max once per image: overwriting pixels in place while re-reading
# img.min()/img.max() would shift the bounds after the first pixel.
for i in range(len(perturbed_examples)):
    img = perturbed_examples[i]
    lo, hi = img.min(), img.max()
    span = hi - lo
    # A constant image has no range; map it to zeros instead of dividing by 0.
    perturbed_examples[i] = (img - lo) / span if span > 0 else np.zeros_like(img)


In [None]:
# Save, for each explained misclassified image:
#   <path_wrong>/edited/<i>.jpg          the counterfactual (perturbed) image
#   <path_wrong>/image_with_exp/<i>.jpg  the original image with every pixel
#                                        that is positive after adding the
#                                        perturbation map set to 1
os.makedirs(path_wrong + '/edited', exist_ok=True)
os.makedirs(path_wrong + '/image_with_exp', exist_ok=True)
for i in range(len(real_examples)):
    # The counterfactual for image i. perturbed_examples is aligned with
    # real_examples: both follow the order of the corrected images.
    plt.imsave(path_wrong + '/edited/' + str(i) + ".jpg", perturbed_examples[i])
    # Original image plus the perturbation map, clipped to 1 where positive.
    z = real_examples[i] + diff[i]
    z[z > 0.0] = 1
    plt.imsave(path_wrong + '/image_with_exp/' + str(i) + ".jpg", z)

In [None]:
# This function checks the provider model for any correctly classified image and
# returns a list with all the images that were correctly classified.
def our_test_true_classified(local_model, device, local_test_loader):
    """Collect the examples that `local_model` classifies correctly.

    Expects batch size 1: prediction and target are compared with ``.item()``.

    Returns:
        (final_acc, examples, labels, logits, second_label), where ``final_acc``
        is ``correct / len(local_test_loader)``. For each correctly classified
        example, the lists hold its input, its target, the model output, and
        the index of the runner-up (second highest scoring) class as a tensor.
        NOTE: callers name the second value ``wrong_examples``, but it holds
        correctly classified inputs.
    """
    hits = []
    hit_targets = []
    hit_outputs = []
    runner_up = []

    for batch, target in local_test_loader:
        batch = batch.to(device)
        target = target.to(device)

        with torch.no_grad():
            output = local_model(batch)

        # Highest score is the prediction; the top-2 gives the runner-up class.
        predicted = output.max(1, keepdim=True)[1]
        top_two = torch.topk(output, 2)
        if predicted.item() != target.item():
            continue
        hits.append(batch)
        hit_targets.append(target)
        hit_outputs.append(output)
        runner_up.append(top_two.indices[0][1])

    final_acc = len(hits) / float(len(local_test_loader))

    return final_acc, hits, hit_targets, hit_outputs, runner_up

In [None]:
# Get the victim model's correct predictions on the test set.
# NOTE: despite its name, `wrong_examples` here holds the CORRECTLY classified
# inputs; `second` holds each one's runner-up class index.
accuracy, wrong_examples, labels, logits, second = our_test_true_classified(victim_mnist_model, "cuda", our_test_loader)

In [None]:
# Add perturbations to create counterfactual examples for the correctly
# classified images. Each image is pushed toward its runner-up class (the
# victim's second highest score) by stepping along the surrogate (attacker)
# model's loss gradient with growing epsilon, until the victim predicts that class.
EPS_STEP = 0.05
EPS_MAX = 0.99
corrected_examples = np.zeros([len(wrong_examples)])
perturbed_examples = []
new_labels = []
eps_all = []
new_counter = 0
start = time.perf_counter()  # monotonic clock for measuring durations
for i in range(len(wrong_examples)):
    x = wrong_examples[i]
    y = torch.tensor([int(second[i])]).to(x.device)  # runner-up class as the target
    # The gradient depends only on x and y, not on epsilon, so compute it once
    # per image instead of once per epsilon step.
    # NOTE(review): assumes the surrogate's forward pass is deterministic
    # (e.g. eval mode); with active dropout the gradient used to vary per step.
    x_var = x.clone().requires_grad_(True)
    loss = F.nll_loss(attacker_mnist_model(x_var), y)
    img_grad, = torch.autograd.grad(loss, x_var)
    # An integer step count avoids float drift in eps; the values tried are
    # 0.0, 0.05, ..., 0.95 (every multiple of EPS_STEP not above EPS_MAX).
    for step in range(int(EPS_MAX / EPS_STEP) + 1):
        eps = round(step * EPS_STEP, 10)
        with torch.no_grad():
            perturbed_image = x - eps * img_grad
            new_label = victim_mnist_model(perturbed_image).max(1, keepdim=True)[1]
        if new_label.item() == y.item():
            new_counter += 1
            perturbed_examples.append(perturbed_image.squeeze().cpu().numpy())
            new_labels.append(new_label)
            eps_all.append(eps)
            corrected_examples[i] = 1
            print("Image {} has been modified with epsilon {}".format(i, eps))
            break
time_elapsed = int(round((time.perf_counter() - start) * 1000))  # milliseconds
print('test time elapsed {}ms'.format(time_elapsed))
print('test time elapsed {}s'.format(time_elapsed/1000))

In [None]:
# For every correctly classified image that was flipped to its runner-up class,
# keep the original image (as a numpy array), its true label and the target class.
real_examples = []
real_labels = []
wrong_predictions = []
corrected_idx = np.where(corrected_examples == 1)
# np.where returns a tuple of index arrays; report how many images were
# flipped, not the length of the tuple (which is always 1).
print(len(corrected_idx[0]))
for idx in corrected_idx[0]:
    real_examples.append(wrong_examples[idx].squeeze().data.cpu().numpy())
    real_labels.append(labels[idx].item())
    wrong_predictions.append(second[idx])

In [None]:
# Keep only the outlier pixels of the squared perturbation: per image, every
# entry below Q3 + tau * IQR is zeroed and the rest keep their magnitude.
# A larger tau than for the misclassified set keeps fewer pixels.
diff = []
tau = 12
for i in range(len(eps_all)):
    squared = (real_examples[i] - perturbed_examples[i])**2
    q3, iqr = get_quartlies(squared)
    squared[squared < q3 + iqr*tau] = 0
    diff.append(squared)

In [None]:
# Binarize the thresholded perturbation maps: any surviving (positive) pixel
# becomes 1, so unimportant pixels stay 0 and important ones are emphasized.
# NOTE: `diff2 = diff` is an alias, not a copy, so this rewrites `diff` in place.
# The saving cell reads these binarized maps through `diff`.
# Assumes each map is 28x28 (squeezed MNIST images).
diff2 = diff
for im in range(len(diff)):
    positive = diff[im] > 0.0
    diff2[im][positive] = 1
                                


In [None]:
# Output directory for the explanations of the correctly classified images.
path_right = 'surrogate_model_3/right'

In [None]:
# Save, for each explained correctly classified image:
#   <path_right>/edited/<i>.jpg          the counterfactual (perturbed) image
#   <path_right>/image_with_exp/<i>.jpg  the original image with every pixel
#                                        that is positive after adding the
#                                        perturbation map set to 1
# Both outputs belong under path_right, since these are the correctly
# classified images.
os.makedirs(path_right + '/edited', exist_ok=True)
os.makedirs(path_right + '/image_with_exp', exist_ok=True)
for i in range(len(real_examples)):
    # The counterfactual for image i. plt.imsave scales a 2-D array to its own
    # min/max by default, so these un-normalized images still render fully.
    plt.imsave(path_right + '/edited/' + str(i) + ".jpg", perturbed_examples[i])
    # Original image plus the perturbation map, clipped to 1 where positive.
    z = real_examples[i] + diff[i]
    z[z > 0.0] = 1
    plt.imsave(path_right + '/image_with_exp/' + str(i) + ".jpg", z)