In [None]:
# -*- coding: utf-8 -*-

"""

@ author: Taehyeong Kim, Fusion Data Analytics and Artificial Intelligence Lab

"""


import numpy as np

import tensorflow as tf
from tensorflow.keras import layers, initializers

import matplotlib.pyplot as plt
%matplotlib inline

from IPython import display

import os
import time
import random

SEED = 1011  # single seed shared by every random number generator in this notebook


def set_seeds(seed=SEED):
    """Seed all RNGs used by the notebook with ``seed``.

    Seeds, in this order: the ``PYTHONHASHSEED`` environment variable, Python's
    ``random`` module, TensorFlow's global generator and NumPy's legacy global
    generator.

    NOTE: assigning ``PYTHONHASHSEED`` at runtime only affects child processes;
    the running interpreter's hash seed was fixed at start-up.
    """
    seed_text = str(seed)
    os.environ['PYTHONHASHSEED'] = seed_text
    random.seed(seed)
    tf.random.set_seed(seed)
    np.random.seed(seed)


set_seeds()

tf.__version__

#### 1. Data Preprocessing

In [None]:
%%time

from watermark import watermarking
from skimage import color

dataset = 'cifar10'
alpha = 0

if dataset=='cifar10':
    (train_images, train_labels), (_, _) = tf.keras.datasets.cifar10.load_data()
elif dataset=='cifar100':
    (train_images, train_labels), (_, _) = tf.keras.datasets.cifar100.load_data(label_mode='coarse')

if alpha != 0:
    train_images=watermarking(train_images, alpha)
    train_images=(train_images - 0.5) / 0.5
else:
    train_images=(train_images - 127.5) / 127.5

train_images=train_images.astype("float32")
print(train_images.shape, train_labels.max())

BUFFER_SIZE = train_images.shape[0]
BATCH_SIZE = 64
train_dataset = tf.data.Dataset.from_tensor_slices((train_images, train_labels)).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
train_dataset

#### 2. Modelling

In [None]:
latent_dim = 100                      # length of the generator's noise vector
n_classes = train_labels.max() + 1    # assumes labels are contiguous ints starting at 0

def make_generator_model():
    """Build the label-conditioned generator.

    Model inputs are ``[noise, label]``: ``noise`` has shape ``(batch, latent_dim)``
    and ``label`` has shape ``(batch, 1)`` with integer class ids in
    ``[0, n_classes)`` (the index range accepted by the Embedding layer).
    The output is a ``(batch, 32, 32, 3)`` image batch in ``[-1, 1]`` (tanh).

    Reads the module-level globals ``n_classes`` and ``latent_dim``.

    NOTE(review): Keras auto-names layers in creation order; reordering the
    layer construction below may stop existing checkpoints from restoring.
    """

    # Label branch: class id -> 50-d embedding -> 16 values -> one 4x4 map,
    # so it can be concatenated channel-wise with the 4x4 noise features.
    in_label = tf.keras.Input(shape=(1,))
    li = layers.Embedding(n_classes, 50)(in_label)
    n_nodes = 4 * 4
    li = layers.Dense(n_nodes,
                      kernel_initializer=initializers.RandomNormal(stddev=0.02))(li)
    li = layers.Reshape((4, 4, 1))(li)

    # Noise branch input.
    in_lat = tf.keras.Input(shape=(latent_dim,))

    # Project noise to 4*4*512 features; no bias because BatchNormalization follows.
    gen = layers.Dense(4*4*512,
                       use_bias=False, kernel_initializer=initializers.RandomNormal(stddev=0.02))(in_lat)
    gen = layers.BatchNormalization()(gen)
    gen = layers.LeakyReLU(0.2)(gen)

    # (4, 4, 512) noise features + (4, 4, 1) label map -> (4, 4, 513).
    gen = layers.Reshape((4, 4, 512))(gen)
    merge = layers.Concatenate()([gen, li])

    # Upsample 4x4 -> 8x8, 256 channels.
    gen = layers.Conv2DTranspose(256, 5, strides=(2, 2), padding='same',
                                 use_bias=False, kernel_initializer=initializers.RandomNormal(stddev=0.02))(merge)
    gen = layers.BatchNormalization()(gen)
    gen = layers.LeakyReLU(0.2)(gen)

    # Upsample 8x8 -> 16x16, 128 channels.
    gen = layers.Conv2DTranspose(128, 5, strides=(2, 2), padding='same',
                                 use_bias=False, kernel_initializer=initializers.RandomNormal(stddev=0.02))(gen)
    gen = layers.BatchNormalization()(gen)
    gen = layers.LeakyReLU(0.2)(gen)

    # Upsample 16x16 -> 32x32 RGB; tanh keeps pixels in [-1, 1], matching the
    # (x - 127.5) / 127.5 preprocessing of the real images.
    out_layer = layers.Conv2DTranspose(3, 5, strides=(2, 2), padding='same', activation='tanh',
                                       use_bias=False, kernel_initializer=initializers.RandomNormal(stddev=0.02))(gen)

    # Input order is [noise, label]; callers must pass them in this order.
    model = tf.keras.Model([in_lat, in_label], out_layer)

    return model

def make_discriminator_model():
    """Build the label-conditioned critic (WGAN discriminator).

    Model inputs are ``[image, label]``: ``image`` has shape ``(batch, 32, 32, 3)``
    and ``label`` has shape ``(batch, 1)`` with integer class ids in
    ``[0, n_classes)``. The output is one unbounded score per sample
    (linear activation), as used by the Wasserstein losses in this notebook.

    Reads the module-level global ``n_classes``.

    NOTE(review): Keras auto-names layers in creation order; reordering the
    layer construction below may stop existing checkpoints from restoring.
    """

    # Label branch: class id -> 50-d embedding -> 32*32 values -> one 32x32 map
    # that is stacked onto the RGB image as a 4th channel.
    in_label = tf.keras.Input(shape=(1,))
    li = layers.Embedding(n_classes, 50)(in_label)
    n_nodes = 32 * 32
    li = layers.Dense(n_nodes,
                      kernel_initializer=initializers.RandomNormal(stddev=0.02))(li)
    li = layers.Reshape((32, 32, 1))(li)

    # (32, 32, 3) image + (32, 32, 1) label map -> (32, 32, 4).
    in_image = tf.keras.Input(shape=(32, 32, 3))
    merge = layers.Concatenate()([in_image, li])

    # Downsample 32x32 -> 16x16, 64 channels.
    fe = layers.Conv2D(64, 5, strides=(2, 2), padding='same',
                       kernel_initializer=initializers.RandomNormal(stddev=0.02))(merge)
    fe = layers.LeakyReLU(0.2)(fe)
    fe = layers.Dropout(0.3)(fe)

    # Downsample 16x16 -> 8x8, 128 channels. LayerNormalization rather than
    # BatchNormalization -- presumably because the per-sample gradient penalty
    # should not be coupled across the batch (as recommended for WGAN-GP).
    fe = layers.Conv2D(128, 5, strides=(2, 2), padding='same',
                       kernel_initializer=initializers.RandomNormal(stddev=0.02))(fe)
    fe = layers.LayerNormalization()(fe)
    fe = layers.LeakyReLU(0.2)(fe)
    fe = layers.Dropout(0.3)(fe)

    # Downsample 8x8 -> 4x4, 256 channels.
    fe = layers.Conv2D(256, 5, strides=(2, 2), padding='same',
                       kernel_initializer=initializers.RandomNormal(stddev=0.02))(fe)
    fe = layers.LayerNormalization()(fe)
    fe = layers.LeakyReLU(0.2)(fe)
    fe = layers.Dropout(0.3)(fe)

    # 4*4*256 features -> a single linear score (no sigmoid: Wasserstein critic).
    fe = layers.Flatten()(fe)
    out_layer = layers.Dense(1, activation='linear',
                             kernel_initializer=initializers.RandomNormal(stddev=0.02))(fe)

    # Input order is [image, label]; callers must pass them in this order.
    model = tf.keras.Model([in_image, in_label], out_layer)

    return model


# Build the generator first, then the critic (tuple elements evaluate left to
# right), and print both architectures in that same order.
generator, discriminator = make_generator_model(), make_discriminator_model()

generator.summary()
discriminator.summary()

In [None]:
def gradient_penalty(real, fake, epsilon, label):
    """WGAN-GP gradient penalty.

    Evaluates the critic on random interpolates between real and generated
    images and penalises each sample's input-gradient L2 norm for deviating
    from 1.

    Args:
        real: batch of real images, shape (batch, H, W, C).
        fake: batch of generated images, same shape as ``real``.
        epsilon: per-sample interpolation weights broadcastable against the
            images, e.g. shape (batch, 1, 1, 1); should be drawn from U[0, 1].
        label: class labels passed to the conditional critic.

    Returns:
        Scalar tensor: batch mean of (||grad_x D(x_hat, label)||_2 - 1)^2.

    Uses the module-level ``discriminator``.
    NOTE(review): the critic is called without ``training=True`` here, unlike
    the real/fake passes in ``train_step``, so its Dropout layers presumably run
    in inference mode for the interpolates -- confirm this is intended.
    """
    mixed_images = fake + epsilon * (real - fake)

    with tf.GradientTape() as tape:
        tape.watch(mixed_images)
        mixed_scores = discriminator([mixed_images, label])

    gradient = tape.gradient(mixed_scores, mixed_images)

    # One L2 norm per sample over every non-batch axis, so the penalty
    # constrains every interpolate in the batch.
    batch_size = tf.shape(gradient)[0]
    flat_gradient = tf.reshape(gradient, [batch_size, -1])
    # The tiny constant keeps sqrt differentiable when a gradient is exactly 0.
    gradient_norm = tf.sqrt(tf.reduce_sum(tf.square(flat_gradient), axis=1) + 1e-12)
    penalty = tf.math.reduce_mean((gradient_norm - 1.0) ** 2)

    return penalty

def generator_loss(fake_output):
    """Wasserstein generator loss: the negated mean critic score of fake samples."""
    return -tf.math.reduce_mean(fake_output)

def discriminator_loss(real_output, fake_output, gradient_penalty, c_lambda=10):
    """WGAN-GP critic loss.

    loss = mean(D(fake)) - mean(D(real)) + c_lambda * gradient_penalty

    Args:
        real_output: critic scores for real images.
        fake_output: critic scores for generated images.
        gradient_penalty: precomputed scalar penalty term. Inside this body the
            parameter shadows the module-level ``gradient_penalty`` function.
        c_lambda: weight of the penalty term; the default 10 matches the
            previously hard-coded value (and the WGAN-GP paper's lambda).

    Returns:
        Scalar loss tensor to be minimised by the critic.
    """
    return (tf.math.reduce_mean(fake_output)
            - tf.math.reduce_mean(real_output)
            + c_lambda * gradient_penalty)

# Two independent Adam instances (generator first) with the WGAN-GP settings:
# learning rate 1e-4, beta_1 = 0, beta_2 = 0.9.
generator_optimizer, discriminator_optimizer = (
    tf.keras.optimizers.Adam(learning_rate=1e-4, beta_1=0, beta_2=0.9)
    for _ in range(2)
)

In [None]:
checkpoint_dir = './load_model/generator'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
# The checkpoint tracks both networks and both optimizers.
checkpoint = tf.train.Checkpoint(
    generator_optimizer=generator_optimizer,
    discriminator_optimizer=discriminator_optimizer,
    generator=generator,
    discriminator=discriminator,
)

EPOCHS = 300

# One fixed noise vector per class, reused every epoch so the progress grids
# are directly comparable across epochs.
noise_dim = latent_dim
num_examples_to_generate = n_classes

seed = tf.random.normal([num_examples_to_generate, noise_dim])
label = np.arange(n_classes)
len(label)

In [None]:
@tf.function
def train_step(images, n_critic=5):
    """One WGAN-GP iteration: ``n_critic`` critic updates, then one generator update.

    Args:
        images: a ``(real_images, labels)`` batch as yielded by ``train_dataset``.
        n_critic: number of critic updates per generator update (default 5,
            the value previously hard-coded).

    Uses the module-level ``generator``, ``discriminator``, their optimizers,
    ``noise_dim`` and the loss / penalty helpers.
    """
    real_images, labels = images[0], images[1]
    # Dynamic batch size: the last batch of an epoch is smaller than BATCH_SIZE
    # (the dataset is batched without drop_remainder), and tf.shape also works
    # when the static batch dimension is unknown.
    batch_size = tf.shape(real_images)[0]

    for _ in range(n_critic):
        # Fresh noise and interpolation weights for every critic step.
        noise = tf.random.normal([batch_size, noise_dim])
        # WGAN-GP samples points on straight lines between real and fake images,
        # so the mixing weight must be uniform on [0, 1].
        epsilon = tf.random.uniform([batch_size, 1, 1, 1], 0.0, 1.0)

        # Generator output is a constant for the critic update; keeping it out
        # of the tape avoids recording the generator's forward pass.
        generated_images = generator([noise, labels], training=True)

        with tf.GradientTape() as disc_tape:
            real_output = discriminator([real_images, labels], training=True)
            fake_output = discriminator([generated_images, labels], training=True)

            # Must be computed under disc_tape: the penalty itself depends on
            # the critic's weights through its input gradient.
            gp = gradient_penalty(real_images, generated_images, epsilon, labels)
            disc_loss = discriminator_loss(real_output, fake_output, gp)

        gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
        discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))

    noise = tf.random.normal([batch_size, noise_dim])
    with tf.GradientTape() as gen_tape:
        generated_images = generator([noise, labels], training=True)
        fake_output = discriminator([generated_images, labels], training=True)
        gen_loss = generator_loss(fake_output)

    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))

In [None]:
def generate_and_save_images(model, epoch, test_input, test_label):
    """Show one generated image per label in a grid; save the grid every 10th epoch.

    Args:
        model: conditional generator, called as ``model([noise, labels])``;
            its output is assumed to be in [-1, 1].
        epoch: epoch number, used for the save schedule and the file name.
        test_input: noise batch with one row per entry of ``test_label``.
        test_label: class labels to condition on.
    """
    n_images = len(test_label)
    # Five columns per row reproduces the 5x2 (10 classes) and 5x4 (20 classes)
    # layouts; other label counts get enough rows instead of an UnboundLocalError.
    width = max(1, min(5, n_images))
    height = max(1, -(-n_images // width))

    predictions = model([test_input, test_label], training=False)
    # Map [-1, 1] to integer pixel values in [0, 255] for all RGB channels at once.
    samples = (np.asarray(predictions)[..., :3] * 127.5 + 127.5).round().astype("int")

    fig = plt.figure(figsize=(width, height))

    for i in range(samples.shape[0]):
        plt.subplot(height, width, i + 1)
        plt.imshow(samples[i])
        plt.axis('off')

    if epoch % 10 == 0:
        # savefig fails if the target directory does not exist yet.
        os.makedirs('./figure/forgery', exist_ok=True)
        plt.savefig('./figure/forgery/image_at_epoch_{:04d}.png'.format(epoch))

    plt.show()
    # Called once per epoch: release the figure explicitly.
    plt.close(fig)

In [None]:
def train(dataset, epochs, checkpoint_every=50):
    """Train the conditional WGAN-GP for ``epochs`` passes over ``dataset``.

    After every epoch the progress grid for the fixed ``seed``/``label`` inputs
    is redrawn. A checkpoint is written every ``checkpoint_every`` epochs AND
    after the final epoch, so ``tf.train.latest_checkpoint(checkpoint_dir)``
    always refers to the final weights (previously, any epochs after the last
    multiple of 50 were lost on restore). Pass ``checkpoint_every=0`` to save
    only at the end.

    Args:
        dataset: iterable of ``(images, labels)`` batches.
        epochs: number of epochs to run.
        checkpoint_every: periodic checkpoint interval in epochs.
    """
    for epoch in range(epochs):
        start = time.time()

        for image_batch in dataset:
            train_step(image_batch)

        display.clear_output(wait=True)
        generate_and_save_images(generator,
                                 epoch + 1,
                                 seed, label)

        epochs_done = epoch + 1
        periodic_save = bool(checkpoint_every) and epochs_done % checkpoint_every == 0
        if periodic_save or epochs_done == epochs:
            checkpoint.save(file_prefix=checkpoint_prefix)

        print('Time for epoch {} is {} sec'.format(epoch + 1, time.time() - start))

    # Final redraw of the progress grid after clearing the per-epoch output.
    display.clear_output(wait=True)
    generate_and_save_images(generator,
                             epochs,
                             seed, label)

#### 3. Training

In [None]:
%%time
train(train_dataset, EPOCHS)

#### 4. Restore the checkpoint and save generated samples

In [None]:
# Reload the newest checkpoint in checkpoint_dir (written by `train`).
latest_ckpt = tf.train.latest_checkpoint(checkpoint_dir)
checkpoint.restore(latest_ckpt)

# Re-seed all RNGs so the latent points drawn below are repeatable.
set_seeds()

def generate_latent_points(latent_dim, n_samples, classes):
    """Draw generator inputs for ``n_samples`` images of one class.

    Returns ``[z, labels]``: ``z`` is an ``(n_samples, latent_dim)`` array of
    standard-normal draws from NumPy's global RNG (filled row by row), and
    ``labels`` is a length-``n_samples`` array filled with ``classes``.
    """
    z_input = np.random.randn(n_samples, latent_dim)
    labels = np.full(shape=n_samples, fill_value=classes)
    return [z_input, labels]
 
def save_plot(examples, n, out_dir='./data/forgery'):
    """Write the first ``n`` images of ``examples`` to ``out_dir`` as PNG files.

    Files are named ``fig1.png`` ... ``fig{n}.png``. Each ``examples[i]`` is
    passed to ``plt.imsave``; presumably an (H, W, 3) float image in [0, 1]
    -- confirm at the call site. The last saved image is also shown inline.

    Args:
        examples: indexable batch of images.
        n: number of images to save; clamped to ``len(examples)``.
        out_dir: output directory, created if missing (default is the
            previously hard-coded path).
    """
    os.makedirs(out_dir, exist_ok=True)
    n = min(n, len(examples))

    for i in range(n):
        plt.imsave(os.path.join(out_dir, 'fig{:d}.png'.format(i + 1)), examples[i])

    # Show only the last image instead of stacking n imshow calls on one axes
    # (where only the topmost one was ever visible).
    if n > 0:
        plt.axis('off')
        plt.imshow(examples[n - 1])
    plt.show()

# 1000 latent vectors, all conditioned on class 0.
latent_points, labels = generate_latent_points(latent_dim=latent_dim, n_samples=1000, classes=0)

# The generator ends in tanh: rescale its [-1, 1] output to [0, 1] for plt.imsave.
X = 0.5 * (generator.predict([latent_points, labels]) + 1)

# Only the first 500 generated images are written to disk.
save_plot(X, 500)