In [1]:
from matplotlib import pyplot as plt
import pickle
import numpy as np
import os
import random
import tensorflow as tf
from tqdm import tqdm
import random
import torchvision
import torchvision.transforms as transforms
from PIL import Image
from matplotlib import cm
from matplotlib import pyplot as plt
import torch
import torch.nn as nn

In [2]:
device = torch.device("cpu")

In [3]:
#Path of the generated samples by beta poisoning attack
CIFAR_GEN_SAMPLES_PATH = ""
MNIST_GEN_SAMPLES_PATH = ""

### Load MNIST real/adversarial samples

In [4]:
mnist_gen_samples = []

for class_folder in os.scandir(MNIST_GEN_SAMPLES_PATH):
    for sample in os.scandir(class_folder):
        with open(sample, "rb") as f:
            sample = pickle.load(f)
            mnist_gen_samples.append(sample)

In [5]:
mnist_gen_samples = np.vstack(mnist_gen_samples)

mnist_gen_samples = mnist_gen_samples.reshape(mnist_gen_samples.shape[0], 28, 28, 1).astype('float32')
mnist_gen_samples = (mnist_gen_samples - 127.5) / 127.5  # Normalize the images to [-1, 1]

In [7]:
(_, _), (mnist_test_images, mnist_test_labels) = tf.keras.datasets.mnist.load_data()

In [8]:
mnist_test_images = mnist_test_images.reshape(mnist_test_images.shape[0], 28, 28, 1).astype('float32')
mnist_test_images = (mnist_test_images - 127.5) / 127.5  # Normalize the images to [-1, 1]

In [9]:
mnist_real_samples = mnist_test_images[0:len(mnist_gen_samples), :, :, :]

### Load CIFAR-10 real/adversarial samples

In [58]:
cifar_gen_samples = []

for class_folder in os.scandir(CIFAR_GEN_SAMPLES_PATH):
    for sample in os.scandir(class_folder):
        with open(sample, "rb") as f:
            sample = pickle.load(f)
            cifar_gen_samples.append(transforms.ToPILImage()(sample.reshape(3, 32, 32)))

In [59]:
#It is important to normalize the images for the discriminator to work properly.

cifar_transform =  transforms.Compose(
                    [transforms.ToTensor(),
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

In [71]:
cifar_gen_samples = [torch.unsqueeze(cifar_transform(x), dim=0) for x in cifar_gen_samples]
cifar_gen_samples = torch.cat(cifar_gen_samples)

In [60]:
# We can use an image folder dataset the way we have it setup.
# Create the dataset
cifar_test_dataset = torchvision.datasets.CIFAR10("./cifar10", transform=cifar_transform, download=True)

Files already downloaded and verified


In [61]:
cifar_dataloader = torch.utils.data.DataLoader(cifar_test_dataset, batch_size=len(cifar_gen_samples))

In [62]:
cifar_real_samples = next(iter(cifar_dataloader))[0].to(device)

### Discriminator model for MNIST

In [23]:
mnist_discriminator = tf.keras.models.load_model('../discriminator_models/mnist-discriminator.h5')



In [30]:
threshold = 0.1 #Refer to the paper for the threshold value

with tf.device('/cpu:0'):
    mnist_real_outputs = mnist_discriminator(mnist_real_samples, training=False)
    mnist_fake_outputs = mnist_discriminator(mnist_gen_samples, training=False)
    
mnist_precision_hist = []
mnist_recall_hist = []
mnist_f1_hist = []
mnist_accuracy_hist = []


tp = (mnist_fake_outputs < threshold).numpy().sum()
fp = (mnist_real_outputs < threshold).numpy().sum()
tn = (mnist_real_outputs > threshold).numpy().sum()
fn = (mnist_fake_outputs > threshold).numpy().sum()


precision = (tp) / (tp + fp)
recall = (tp) / (tp + fn)
f1 = (2*precision*recall) / (precision + recall)
accuracy = (tp + tn) / (tp + tn + fp + fn)

print("Results for the MNIST discriminator")
print("Precision: ", precision)
print("Recall: ", recall)
print("F1: ", f1)
print("Accuracy: ", accuracy)

Results for the MNIST discriminator
Precision:  0.9982993197278912
Recall:  1.0
F1:  0.9991489361702128
Accuracy:  0.9991482112436116


### Discriminator model for CIFAR-10

In [85]:
image_size = 32
# Number of channels in the training images. For color images this is 3
nc = 3

# Size of z latent vector (i.e. size of generator input)
nz = 100

# Size of feature maps in generator
ngf = 64

# Size of feature maps in discriminator
ndf = 64


# Discriminator
class Discriminator(nn.Module):
    def __init__(self, ngpu):
        super(Discriminator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 4, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, input):
        return self.main(input)
    
# cifar_discriminator = torch.load("../discriminator_models/cifar_discriminator.pt")
# cifar_discriminator.to(device)
# cifar_discriminator.eval()

cifar_discriminator = Discriminator(0).to(device)
checkpoint = torch.load(f"../discriminator_models/cifar_discriminator.pt")
cifar_discriminator.load_state_dict(checkpoint['state_dict'])
cifar_discriminator.eval()

Discriminator(
  (main): Sequential(
    (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
    (8): Conv2d(256, 1, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (9): Sigmoid()
  )
)

In [88]:
cifar_real_outputs = cifar_discriminator(cifar_real_samples.reshape(-1, 3, 32, 32))
cifar_fake_outputs = cifar_discriminator(cifar_gen_samples.reshape(-1, 3, 32, 32))


In [89]:
threshold = 0.36 #Refer to the paper for the threshold value

tp = (cifar_fake_outputs < threshold).numpy().sum()
fp = (cifar_real_outputs < threshold).numpy().sum()
tn = (cifar_real_outputs > threshold).numpy().sum()
fn = (cifar_fake_outputs > threshold).numpy().sum()


precision = (tp) / (tp + fp)
recall = (tp) / (tp + fn)
f1 = (2*precision*recall) / (precision + recall)
accuracy = (tp + tn) / (tp + tn + fp + fn)

print("Results for the CIFAR-10 discriminator")
print("Precision: ", precision)
print("Recall: ", recall)
print("F1: ", f1)
print("Accuracy: ", accuracy)

Results for the CIFAR-10 discriminator
Precision:  0.9014522821576764
Recall:  0.9720357941834452
F1:  0.9354144241119484
Accuracy:  0.9328859060402684
