# DV2607 Project Notebook
### Authors:
### Oliver Ljung (ollj19@student.bth.se)
### Phoebe Waters (phaa19@student.bth.se)

## Importing modules and dataset

In [None]:
# imports
from matplotlib import pyplot as plt
import numpy as np

import tensorflow as tf

from keras.layers import Conv2D, Conv2DTranspose, MaxPooling2D, Flatten, Dense, Input, Activation, BatchNormalization, LeakyReLU, Reshape, UpSampling2D, Dropout, ReLU
from keras import Sequential, Model
from keras.datasets import mnist
from tensorflow.keras.datasets import cifar10
from keras.utils import to_categorical
import keras.backend as KB

from keras.losses import BinaryCrossentropy, CategoricalCrossentropy, Hinge, SquaredHinge, MeanSquaredError, Loss, SparseCategoricalCrossentropy

import random
import time

import art
from art.attacks.evasion import FastGradientMethod, ProjectedGradientDescent, CarliniL2Method
from art.estimators.classification import KerasClassifier

# Report which GPUs TensorFlow can see; an empty list means it will run on the CPU.
gpu_devices = tf.config.list_physical_devices('GPU')
print(gpu_devices)

## Defining functions

In [None]:
# Defining functions

def display_attack(x_test, x_adv_test, model, n=9):
    """Plot adversarial images and print the model's prediction on each.

    For each of the first ``n`` sample pairs this prints the class predicted for
    the adversarial image (``p_fake``) and for the matching clean image
    (``p_real``). It then draws the adversarial images in a near-square grid,
    with both predictions in each subplot title.

    Args:
        x_test: clean input images; presumably (N, H, W, C) like the datasets
            loaded in this notebook (TODO confirm for other callers).
        x_adv_test: adversarial versions of ``x_test``, in the same order.
        model: object whose ``predict`` returns one row of class scores per
            image; ``np.argmax`` of a row is taken as the predicted class.
        n: maximum number of samples to show. Defaults to 9 (a 3x3 grid).
    """
    # Never index past the end of either array (the old fixed range(9) did).
    count = min(n, len(x_test), len(x_adv_test))
    if count == 0:
        return

    x_real = x_test[:count]
    x_fake = x_adv_test[:count]
    x_fake_labels = model.predict(x_fake)
    x_real_labels = model.predict(x_real)

    # Use a dedicated figure so we never draw over an unrelated current figure.
    cols = int(np.ceil(np.sqrt(count)))
    rows = int(np.ceil(count / cols))
    fig, axes = plt.subplots(rows, cols, squeeze=False)
    for i, ax in enumerate(axes.flat):
        ax.axis('off')
        if i >= count:
            continue  # unused grid cell
        p_fake = np.argmax(x_fake_labels[i])
        p_real = np.argmax(x_real_labels[i])
        # Drop a trailing single channel so grayscale images render with the cmap.
        ax.imshow(np.squeeze(x_fake[i]), cmap=plt.get_cmap('gray'))
        ax.set_title(f'fake={p_fake}, real={p_real}')
        print(f'p_fake = {p_fake}, p_real = {p_real}')
    plt.show()

In [None]:
# Defining models

def create_generator(img_shape, seed_shape, channels, verbose=True):
    """Build the GAN generator, which maps a latent seed vector to an image.

    Args:
        img_shape: output image shape as (height, width, channels).
        seed_shape: length of the 1-D latent seed vector.
        channels: number of channels produced by the final convolution. It should
            equal ``img_shape[2]`` so that the output matches ``img_shape``.
        verbose: if True, print the Keras summary of the inner Sequential model.

    Returns:
        A functional ``Model`` named "generator". It takes a ``(seed_shape,)``
        input and returns an image tensor of shape ``img_shape[:2] + (channels,)``.
    """
    model = Sequential(name="generator")

    # Project the seed to exactly one image's worth of values, then reshape.
    model.add(Dense(img_shape[0]*img_shape[1]*img_shape[2], activation="relu", input_dim=seed_shape))
    model.add(Reshape(img_shape))

    # Four conv stages. Each UpSampling2D (x2) is undone by the following
    # MaxPooling2D (/2), so the spatial size is back to img_shape[:2] at the end.
    # NOTE(review): Conv2D already applies relu before BatchNorm + LeakyReLU; the
    # usual DCGAN block omits the conv activation. Kept as-is to preserve the model.
    stages = ((UpSampling2D, 32), (MaxPooling2D, 64), (UpSampling2D, 128), (MaxPooling2D, 256))
    for resample_layer, filters in stages:
        model.add(resample_layer())
        model.add(Conv2D(filters, (3,3), activation='relu', padding="same"))
        model.add(BatchNormalization(momentum=0.8))
        model.add(LeakyReLU())

    # relu followed by tanh confines every output value to [0, 1), which matches
    # the [0, 1] pixel range of the datasets. NOTE(review): pixels where the relu
    # is inactive receive no gradient.
    model.add(Conv2D(channels, (3,3), activation='relu', padding="same"))
    model.add(Activation('tanh', name="generated_img"))

    # summary() prints itself and returns None, so it must not be wrapped in print().
    if verbose:
        model.summary()

    seed_input = Input(shape=(seed_shape,))
    img = model(seed_input)

    return Model(seed_input, img, name="generator")

def create_discriminator(img_shape, verbose=True):
    """Build the GAN discriminator, a small CNN that scores images as real or fake.

    Args:
        img_shape: input image shape as (height, width, channels).
        verbose: if True, print the Keras summary of the inner Sequential model.

    Returns:
        A functional ``Model`` named "discriminator". It maps an image to a single
        sigmoid output, which is trained toward 1 for real and 0 for generated images.
    """
    model = Sequential(name="discriminator")

    # Convolutional feature extractor (default 'valid' padding shrinks the maps).
    model.add(Conv2D(32, (3,3), activation='relu', input_shape=img_shape))
    model.add(MaxPooling2D((2, 2)))
    model.add(Conv2D(64, (3,3), activation='relu'))
    model.add(MaxPooling2D((2, 2)))
    model.add(Conv2D(64, (3,3), activation='relu'))
    model.add(Flatten())

    # Predictive head: a single sigmoid unit for the binary real/fake decision.
    model.add(Dense(64, activation='relu'))
    model.add(Dense(1, activation='sigmoid', name="validity_output"))

    # Honour `verbose` (it used to be ignored). summary() prints itself and returns
    # None, so it must not be wrapped in print().
    if verbose:
        model.summary()

    # Wrap the Sequential model in a functional model with an explicit input.
    image_input = Input(shape=img_shape)
    validity = model(image_input)

    return Model(image_input, validity, name="discriminator")


## Loading dataset

In [None]:
# load dataset
(train_X, train_y), (test_X, test_y) = mnist.load_data()

# Scale pixels from uint8 [0, 255] to float32 [0, 1].
train_X = train_X.astype("float32") / 255
test_X = test_X.astype("float32") / 255

# Add the trailing channel axis the Conv2D models expect: (N, 28, 28) -> (N, 28, 28, 1).
train_X = np.expand_dims(train_X, -1)
test_X = np.expand_dims(test_X, -1)

# Integer class ids in [0, NUM_CLASSES), shape (N,). This is a direct cast; the
# previous one-hot -> argmax round trip produced the same values via a Python loop.
train_y = train_y.astype(np.int64)
test_y = test_y.astype(np.int64)

IMG_SHAPE = (28,28,1)
CHANNELS = IMG_SHAPE[-1]
SEED_SHAPE = 100
NUM_CLASSES = 10
IMG_LOW_LIMIT = 0
IMG_HIGH_LIMIT = 1

# Sanity-check that the constants agree with the loaded data.
assert train_X.shape[1:] == IMG_SHAPE and test_X.shape[1:] == IMG_SHAPE
assert train_y.min() >= 0 and train_y.max() < NUM_CLASSES

In [None]:
# load dataset
# Alternative dataset: to train on CIFAR-10 instead of MNIST, uncomment this cell
# and run it in place of the MNIST cell. It defines the same variables
# (train_X/train_y/test_X/test_y) and the same constants, with CIFAR-10 values.
# NOTE(review): the ssl override below disables HTTPS certificate verification
# for the whole Python process. Use it only as a last resort if the CIFAR-10
# download fails on certificate errors; fixing the local CA store is preferable.
# import ssl
# ssl._create_default_https_context = ssl._create_unverified_context

# (train_X, train_y), (test_X, test_y) = cifar10.load_data()

# train_X = train_X.astype("float32")/255
# test_X = test_X.astype("float32")/255

# train_y = to_categorical(train_y)
# train_y = np.array([np.argmax(y) for y in train_y])
# test_y  = to_categorical(test_y)
# test_y = np.array([np.argmax(y) for y in test_y])

# IMG_SHAPE = (32,32,3)
# CHANNELS = 3
# SEED_SHAPE = 100
# NUM_CLASSES = 10
# IMG_LOW_LIMIT = 0
# IMG_HIGH_LIMIT = 1

## Defining models and input tensors

In [None]:
LEARNING_RATE = 0.0001

# The discriminator and the combined GAN are compiled and trained separately, so
# each needs its own Adam instance. A single shared optimizer would mix their step
# counters and, in newer TF/Keras optimizers, raises an error: it is built for the
# discriminator's variables and is then asked to update the generator's.
optimizer = tf.optimizers.Adam(learning_rate=LEARNING_RATE)      # discriminator
gan_optimizer = tf.optimizers.Adam(learning_rate=LEARNING_RATE)  # generator, via GAN_model

# Latent seed input of the combined model. NOTE: this global shadows the builtin input().
input = Input(shape=(SEED_SHAPE,), name="seed")

generator = create_generator(IMG_SHAPE, SEED_SHAPE, CHANNELS)
fake_image = generator(input)

# Keep generated pixel values in [IMG_LOW_LIMIT, IMG_HIGH_LIMIT], i.e. [0, 1].
fake_image = tf.clip_by_value(fake_image, IMG_LOW_LIMIT, IMG_HIGH_LIMIT)

# We want the generator and discriminator to be trained in a combined model, but as
# separate entities. The discriminator is compiled while trainable (for its own
# train_on_batch calls) and is then frozen before GAN_model is built and compiled,
# so GAN_model only updates the generator.
# NOTE(review): this relies on tf.keras (Keras 2), where compile() snapshots
# `trainable` and tf ops may be applied to symbolic tensors. Under Keras 3, verify
# that the discriminator still trains and that tf.clip_by_value is accepted here.
discriminator = create_discriminator(IMG_SHAPE)
discriminator.compile(optimizer=optimizer, loss=BinaryCrossentropy(), metrics=["accuracy"])
discriminator.trainable = False
validity_fake = discriminator(fake_image)

GAN_model = Model(inputs=input, outputs=validity_fake, name="GAN-net")

GAN_model.compile(optimizer=gan_optimizer, loss=BinaryCrossentropy(), metrics=["accuracy"])
GAN_model.summary()

## GAN Training

In [None]:
BATCH_SIZE = 32
EPOCHS = 1_000_000                 # upper bound on steps; the time limit normally stops first
TIME_LIMIT_SECONDS = 60*60*3       # stop training after 3 hours
LOG_EVERY = 500                    # print losses and sample images every LOG_EVERY steps

# Discriminator targets: real images -> 1, generated images -> 0. The generator is
# trained through GAN_model to make its fakes score 1.
y_real = np.ones((BATCH_SIZE, 1))
y_fake = np.zeros((BATCH_SIZE, 1))

# monotonic() is unaffected by wall-clock adjustments during a multi-hour run.
start_time = time.monotonic()
for epoch in range(EPOCHS):  # each "epoch" here is one batch / training step
    if time.monotonic() - start_time > TIME_LIMIT_SECONDS:
        break

    # Sample a batch of real images with replacement, without building a Python list.
    x_real = train_X[np.random.randint(0, len(train_X), size=BATCH_SIZE)]
    seed = np.random.normal(0, 1, (BATCH_SIZE, SEED_SHAPE)).astype("float32")

    # For a single batch, calling the model directly avoids predict()'s per-call
    # overhead. training=False keeps BatchNorm in inference mode, as predict() does.
    x_fake = np.asarray(generator(seed, training=False))
    x_fake = np.clip(x_fake, IMG_LOW_LIMIT, IMG_HIGH_LIMIT)

    discriminator_loss_real = discriminator.train_on_batch(x=x_real, y=y_real)
    discriminator_loss_fake = discriminator.train_on_batch(x=x_fake, y=y_fake)
    # Element-wise mean of [loss, accuracy] over the real and fake half-batches.
    discriminator_loss = 0.5 * np.add(discriminator_loss_real, discriminator_loss_fake)

    generator_loss = GAN_model.train_on_batch(x=seed, y=y_real)

    if epoch % LOG_EVERY == 0:
        print(f"EPOCH: {epoch}")
        print(f"    Generator Loss:     {round(generator_loss[0], 3)},  Discriminator Loss:     {round(discriminator_loss[0], 3)}")
        print(f"    Generator Accuracy: {round(generator_loss[1], 3)},  Discriminator Accuracy: {round(discriminator_loss[1], 3)}")

        # Show a few of this step's generated images on a fresh figure.
        fig, axes = plt.subplots(1, 3)
        for ax, img in zip(axes, x_fake[:3]):
            ax.imshow(img, cmap=plt.get_cmap('gray'))
        plt.show()