# StyleGAN

Dataset: [https://ai.stanford.edu/~jkrause/cars/car_dataset.html](https://ai.stanford.edu/~jkrause/cars/car_dataset.html). 

## Dependencies

In [None]:
!pip install -q git+https://github.com/tensorflow/docs &> /dev/null
!pip install imageio wandb &> /dev/null

In [None]:
import glob
import imageio
import time
import os
import re
import tarfile
import datetime

from IPython import display
from pathlib import Path
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from zipfile import ZipFile
from PIL import Image

# %tensorflow_version 2.x
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential
from keras.utils.vis_utils import plot_model
from keras.initializers import RandomNormal
from keras import backend

import wandb

In [None]:
# Log in to Weights & Biases before any run is created further down.
wandb.login()

## Data preparation

In [None]:
# Image size shared by the data loader and the critic's default input shape.
config = {
    'IMAGE_HEIGHT': 128,
    'IMAGE_WIDTH': 128,
}

In [None]:
# Folder that holds the downloaded archive and receives the extracted files.
data_path = 'data'
# File name of the training archive, expected inside `data_path`.
data_compressed_filename_train = 'cars_train.tgz'
# Sub-folder of `data_path` into which archives are extracted.
data_extracted_foldername = 'cars'

In [None]:
def extract(tar_url, dir):
    """Extract an archive from ``data_path`` into the extracted-data folder.

    Args:
        tar_url: Archive file name, relative to ``data_path``.
        dir: Sub-folder of ``{data_path}/{data_extracted_foldername}`` that
            receives the archive contents. The name shadows the ``dir``
            builtin; it is kept so existing callers keep working.

    Nothing is done when the target folder already exists, so re-running the
    cell does not unpack the archive a second time.
    """
    target = f"{data_path}/{data_extracted_foldername}/{dir}"
    if os.path.exists(target):
        return

    # The context manager closes the archive even when extraction fails.
    with tarfile.open(f"{data_path}/{tar_url}", 'r') as tar:
        for item in tar:
            # NOTE(review): members are extracted without path validation;
            # only use this with trusted archives.
            tar.extract(item, target)
            if item.name.find(".tgz") != -1 or item.name.find(".tar") != -1:
                # NOTE(review): the nested archive is re-opened relative to
                # data_path, not to the folder it was just extracted into —
                # confirm this is intended before relying on nested archives.
                extract(item.name, "./" + item.name[:item.name.rfind('/')])


extract(data_compressed_filename_train, 'train')

In [None]:
def load_images():
    """Load the training JPEGs as one array of resized RGB images.

    Every ``*.jpg`` file in the extracted ``train/cars_train`` folder is
    opened; images whose mode is not RGB (e.g. grayscale) are skipped so that
    all kept images have three channels and can be stacked.

    Returns:
        ``np.ndarray`` with one entry per kept image. Each image has shape
        ``(IMAGE_HEIGHT, IMAGE_WIDTH, 3)``.
    """
    # PIL's resize expects (width, height), not (height, width).
    target_size = (config['IMAGE_WIDTH'], config['IMAGE_HEIGHT'])
    pattern = f'{data_path}/{data_extracted_foldername}/train/cars_train/*.jpg'

    image_list = []
    # Sort the matches so the image order is the same on every run.
    for filename in sorted(glob.glob(pattern)):
        # The context manager releases the file handle right after use.
        with Image.open(filename) as image:
            if image.mode != 'RGB':  # take only rgb images
                continue
            image_list.append(np.asarray(image.resize(target_size)))

    return np.asarray(image_list)

In [None]:
# Load the resized training images and show the array shape as the cell output.
train_images = load_images()
train_images.shape

In [None]:
# Preview the first nine training images in a 3x3 grid.
for cell in range(9):
    plt.subplot(3, 3, cell + 1)
    plt.imshow(train_images[cell])
    plt.axis('off')

plt.show()

Convert the pixel values from unsigned integers to floats and rescale them from [0, 255] to [-1, 1], which matches the `tanh` output range of the generator.

In [None]:
# NOTE: this cell overwrites `train_images` with its own rescaled values, so
# running it twice rescales the data again. Re-run the loading cell first.
train_images = train_images.astype('float32')
# Map [0, 255] to [-1, 1].
train_images = (train_images - 127.5) / 127.5

## Model

In [None]:
def wasserstein_loss(y_true, y_pred):
    """Wasserstein loss: mean of the element-wise product of labels and scores.

    Args:
        y_true: Label tensor (this notebook uses -1 and +1 as labels).
        y_pred: Scores produced by the critic.

    Returns:
        The mean of ``y_true * y_pred`` as computed by the Keras backend.
    """
    weighted_scores = y_true * y_pred
    return backend.mean(weighted_scores)

class ClipConstraint(keras.constraints.Constraint):
    """Keras weight constraint that clips weights to ``[-clip_value, clip_value]``."""

    def __init__(self, clip_value):
        # Half-width of the interval the weights are clipped to.
        self.clip_value = clip_value

    def __call__(self, weights):
        """Return ``weights`` clipped to the configured interval."""
        bound = self.clip_value
        return backend.clip(weights, -bound, bound)

    def get_config(self):
        """Return the constructor arguments as a dictionary."""
        return dict(clip_value=self.clip_value)

In [None]:
# Extend the image-size settings with the training hyperparameters.
# Commented-out entries are the alternatives to the active WGAN values.
config = {
    **config,
    "EPOCHS": 200,
    "BATCH_SIZE": 64,
    "LEARNING_RATE": 0.00005, # wgan (used by the RMSprop optimizers)
    # "LEARNING_RATE": 0.0002,
    "BETA": 0.5,  # only referenced by the commented-out Adam optimizers
    "LOSS": wasserstein_loss,
    # "LOSS": 'binary_crossentropy',
    "LATENT_DIM":100,  # default input size of the generator
    'D_DROPOUT': 0.4,  # only referenced by the commented-out discriminator
    'D_OUTPUT_ACTIVATION': 'linear', # wgan
    # 'D_OUTPUT_ACTIVATION': 'sigmoid', # non-wgan alternative
    'N_CRITIC': 5,  # passed to train() as n_critic
}

### Discriminator

In [None]:
# def define_discriminator(in_shape=(config['IMAGE_HEIGHT'],config['IMAGE_WIDTH'],3)):
#     # weight initialization
#     init = RandomNormal(stddev=0.02)
    
#     # define model
#     model = tf.keras.Sequential()

#     # normal
#     model.add(layers.Conv2D(64, (3,3), padding='same', kernel_initializer=init, input_shape=in_shape))
#     model.add(layers.LeakyReLU(alpha=0.2))
    
#     # downsample to 64x64
#     model.add(layers.Conv2D(128, (3,3), strides=(2,2), padding='same', kernel_initializer=init))
#     model.add(layers.LeakyReLU(alpha=0.2))
    
#     # downsample to 32x32
#     model.add(layers.Conv2D(128, (3,3), strides=(2,2), padding='same', kernel_initializer=init))
#     model.add(layers.LeakyReLU(alpha=0.2))
    
#     # downsample to 16x16
#     model.add(layers.Conv2D(256, (3,3), strides=(2,2), padding='same', kernel_initializer=init))
#     model.add(layers.LeakyReLU(alpha=0.2))

#     # classifier
#     model.add(layers.Flatten())
#     model.add(layers.Dropout(config["D_DROPOUT"]))
#     model.add(layers.Dense(1, activation=config['D_OUTPUT_ACTIVATION']))

#     # compile model
#     opt = keras.optimizers.Adam(lr=config['LEARNING_RATE'], beta_1=config['BETA'])
#     model.compile(loss=config['LOSS'], optimizer=opt, metrics=['accuracy'])
    
#     return model

# discriminator = define_discriminator()
# plot_model(discriminator, to_file='discriminator_plot.png', show_shapes=True, show_layer_names=True)

### Critic

In [None]:
def define_critic(in_shape=(config['IMAGE_HEIGHT'], config['IMAGE_WIDTH'], 3), clip_value=0.01):
    """Build and compile the WGAN critic.

    The model is four stride-2 convolution stages (each followed by batch
    normalization and LeakyReLU) and a single-unit dense scoring layer.

    Args:
        in_shape: Shape of one input image as (height, width, channels).
        clip_value: Convolution kernels are clipped to
            ``[-clip_value, clip_value]`` after each update.

    Returns:
        The compiled ``keras.Sequential`` critic.
    """
    # weight initialization
    init = RandomNormal(stddev=0.02)

    # weight constraint
    # NOTE(review): only the convolution kernels are clipped; the final Dense
    # layer is unconstrained — confirm this is intended.
    const = ClipConstraint(clip_value)

    # define model
    model = keras.Sequential()

    # every stage halves the spatial size (stride 2, 'same' padding):
    # 128x128 -> 64x64 -> 32x32 -> 16x16 -> 8x8 for the default input shape
    for stage, filters in enumerate((64, 128, 128, 256)):
        # only the first layer declares the input shape
        first_layer_kwargs = {'input_shape': in_shape} if stage == 0 else {}
        model.add(layers.Conv2D(
            filters, (3, 3), strides=(2, 2), padding='same',
            kernel_initializer=init, kernel_constraint=const,
            **first_layer_kwargs,
        ))
        model.add(layers.BatchNormalization())
        model.add(layers.LeakyReLU(alpha=0.2))

    # scoring, linear activation
    model.add(layers.Flatten())
    model.add(layers.Dense(1, activation=config['D_OUTPUT_ACTIVATION']))

    # optimizer; `learning_rate` replaces the deprecated `lr` argument
    opt = keras.optimizers.RMSprop(learning_rate=config['LEARNING_RATE'])  # wgan

    # compile model
    model.compile(loss=config['LOSS'], optimizer=opt)

    return model

critic = define_critic()
plot_model(critic, to_file='critic_plot.png', show_shapes=True, show_layer_names=True)

### Generator

In [None]:
def define_generator(input_dim=config['LATENT_DIM']):
    """Build the generator that maps a latent vector to a 128x128x3 image.

    Args:
        input_dim: Length of the latent input vector.

    Returns:
        An uncompiled ``keras.Sequential`` model whose last layer uses a
        ``tanh`` activation.
    """
    # shared weight initializer
    init = RandomNormal(stddev=0.02)

    model = keras.Sequential()

    # dense foundation, reshaped into a 4x4 feature map with 256 channels
    n_nodes = 256 * 4 * 4
    model.add(layers.Dense(n_nodes, kernel_initializer=init, input_dim=input_dim))
    model.add(layers.LeakyReLU(alpha=0.2))
    model.add(layers.Reshape((4, 4, 256)))

    # five identical stride-2 stages: 4x4 -> 8x8 -> 16x16 -> 32x32 -> 64x64 -> 128x128
    for _ in range(5):
        model.add(layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding='same', kernel_initializer=init))
        model.add(layers.BatchNormalization())
        model.add(layers.LeakyReLU(alpha=0.2))

    # output 128x128x3
    model.add(layers.Conv2D(3, (4, 4), activation='tanh', padding='same', kernel_initializer=init))

    return model

generator = define_generator()
plot_model(generator, to_file='generator_plot.png', show_shapes=True, show_layer_names=True)

### GAN

In [None]:
def define_gan(generator, critic):
    """Stack the generator and the critic into the model that trains the generator.

    Args:
        generator: Model mapping latent vectors to images.
        critic: Critic model. Its non-BatchNormalization layers are marked
            non-trainable here.

    Returns:
        The compiled ``keras.Sequential`` model ``generator -> critic``.
    """
    # make weights in the critic not trainable, except for batch normalization.
    # The flag is set per layer: setting it on the whole critic would also
    # freeze the BatchNormalization layers this check is meant to skip.
    # NOTE(review): this relies on the critic having been compiled before this
    # call so that its own training still updates these weights — confirm
    # against the Keras documentation on `trainable` and `compile()`.
    for layer in critic.layers:
        if not isinstance(layer, layers.BatchNormalization):
            layer.trainable = False

    # define model
    model = keras.Sequential()

    # add generator
    model.add(generator)

    # add the critic
    model.add(critic)

    # optimizer; `learning_rate` replaces the deprecated `lr` argument
    opt = keras.optimizers.RMSprop(learning_rate=config['LEARNING_RATE'])  # wgan
    # opt = keras.optimizers.Adam(learning_rate=config['LEARNING_RATE'], beta_1=config['BETA'])

    # compile model
    model.compile(loss=config['LOSS'], optimizer=opt)

    return model

gan = define_gan(generator,critic)
plot_model(gan, to_file='gan_plot.png', show_shapes=True, show_layer_names=True)

## Training the model

In [None]:
def check_if_dir_exists(filepath):
    """Create the parent directory of ``filepath`` if needed and return ``filepath``.

    Despite its name, this function creates the directory (including missing
    parents); it does not only check for it. The file itself is not created.
    """
    Path(os.path.dirname(filepath)).mkdir(parents=True, exist_ok=True)
    return filepath

def generate_real_samples(dataset, n_samples):
    """Draw ``n_samples`` random rows (with replacement) from ``dataset``.

    Returns:
        Tuple ``(X, y)``: the selected samples and an ``(n_samples, 1)`` array
        of ``-1`` labels, the label this notebook uses for real images (wgan).
    """
    sample_count = dataset.shape[0]
    chosen = np.random.randint(0, sample_count, n_samples)

    # wgan: real images are labelled -1 (a standard GAN would use +1)
    labels = np.full((n_samples, 1), -1.0)

    return dataset[chosen], labels


def generate_latent_points(latent_dim, n_samples):
    """Sample generator inputs from a standard normal distribution.

    Returns:
        Array of shape ``(n_samples, latent_dim)``.
    """
    return np.random.standard_normal(size=(n_samples, latent_dim))


def generate_fake_samples(generator, latent_dim, n_samples):
    """Generate ``n_samples`` images with ``generator`` and label them as fake.

    Returns:
        Tuple ``(X, y)``: the generator's predictions for freshly sampled
        latent points and an ``(n_samples, 1)`` array of ``+1`` labels, the
        label this notebook uses for generated images (wgan).
    """
    latent_batch = generate_latent_points(latent_dim, n_samples)
    images = generator.predict(latent_batch)

    # wgan: generated images are labelled +1 (a standard GAN would use 0)
    labels = np.ones((n_samples, 1))

    return images, labels


def save_model(epoch, g_model, model_path):
    """Save the generator to ``model_path`` and register the file with wandb.

    Args:
        epoch: Zero-based epoch index; the file name uses ``epoch + 1``.
        g_model: Generator model to save.
        model_path: Directory for the ``.h5`` file (created if missing).
    """
    epoch_number = epoch + 1
    checkpoint = check_if_dir_exists(f'{model_path}/generator_model_{epoch_number:04d}.h5')

    # local copy first, then hand the same file to the wandb run
    g_model.save(checkpoint)
    wandb_run.save(checkpoint)


def save_plot(epoch, images_path, g_model, latent_dim, n_samples=150, n=7):
    """Plot an ``n`` x ``n`` grid of generated images, save it and log it to wandb.

    Args:
        epoch: Zero-based epoch index; the file name uses ``epoch + 1``.
        images_path: Directory for the PNG file (created if missing).
        g_model: Generator used to produce the images.
        latent_dim: Size of the generator's latent input.
        n_samples: Number of images to generate; only the first ``n * n``
            are plotted.
        n: Number of rows and columns in the grid.
    """
    generated, _ = generate_fake_samples(g_model, latent_dim, n_samples)
    # shift from [-1, 1] to [0, 1] for imshow (assumes outputs in [-1, 1])
    generated = (generated + 1) / 2.0

    cell_count = n * n
    for cell in range(cell_count):
        plt.subplot(n, n, cell + 1)
        plt.axis('off')
        plt.imshow(generated[cell])

    # write the figure to disk
    image_name = f"generated_plot_e{(epoch + 1):04d}.png"
    filename = check_if_dir_exists(f'{images_path}/{image_name}')
    plt.savefig(filename)

    # log the same figure to wandb, then release it
    wandb_run.log({"images": wandb.Image(plt, caption=image_name)})
    plt.close()


def train(g_model, c_model, gan_model, dataset, latent_dim, n_epochs=200, n_batch=128, run=0, n_critic=5):
    """Train the generator and the critic.

    For every batch the critic is updated ``n_critic`` times (each time on
    half a batch of real and half a batch of generated images), then the
    generator is updated once through ``gan_model``.

    Args:
        g_model: Generator model.
        c_model: Compiled critic model.
        gan_model: Compiled generator-plus-critic model.
        dataset: Array of training images; the first axis indexes the images.
        latent_dim: Size of the generator's latent input.
        n_epochs: Number of epochs.
        n_batch: Batch size for the generator update; the critic uses half.
        run: Run identifier used for the ``runs/{run}/...`` output folders.
        n_critic: Critic updates per generator update. Use 1 for a single
            critic update per batch.

    Raises:
        ValueError: If ``n_critic`` is smaller than 1.

    Per epoch the mean losses are printed and logged to wandb and a sample
    plot is saved; the generator is saved every 10 epochs.
    """
    if n_critic < 1:
        raise ValueError(f'n_critic must be at least 1, got {n_critic}')

    model_path = f'runs/{run}/models'
    images_path = f'runs/{run}/images'

    time_train_start = time.time()
    bat_per_epo = dataset.shape[0] // n_batch
    half_batch = n_batch // 2

    # manually enumerate epochs
    for i in range(n_epochs):
        time_epoch_start = time.time()

        c1_epoch, c2_epoch, gan_epoch = [], [], []
        for _ in range(bat_per_epo):

            # update the critic more than the generator
            c1_batch, c2_batch = [], []
            for _ in range(n_critic):
                # real
                X_real, y_real = generate_real_samples(dataset, half_batch)
                c1_batch.append(c_model.train_on_batch(X_real, y_real))

                # fake
                X_fake, y_fake = generate_fake_samples(g_model, latent_dim, half_batch)
                c2_batch.append(c_model.train_on_batch(X_fake, y_fake))

            # store the critic loss averaged over this batch's critic updates
            c1_epoch.append(np.mean(c1_batch))
            c2_epoch.append(np.mean(c2_batch))

            # update the generator: generated images get the "real" label
            X_gan = generate_latent_points(latent_dim, n_batch)
            y_gan = -np.ones((n_batch, 1))  # wgan
            # y_gan = np.ones((n_batch, 1))
            gan_epoch.append(gan_model.train_on_batch(X_gan, y_gan))

        c1_loss = np.mean(c1_epoch)
        c2_loss = np.mean(c2_epoch)
        g_loss = np.mean(gan_epoch)

        # measure once so the printed and the logged duration agree
        epoch_seconds = time.time() - time_epoch_start
        time_since_start = str(datetime.timedelta(seconds=(time.time() - time_train_start)))
        # report against n_epochs (the loop bound), not the global config
        print(f'({time_since_start}) [{i + 1}/{n_epochs}]: c1={c1_loss:.3f}, c2={c2_loss:.3f}, g={g_loss:.3f}, took {epoch_seconds} seconds')

        wandb_run.log({
            'c_real_loss': c1_loss,
            'c_fake_loss': c2_loss,
            'gan_loss': g_loss,
            'epoch_time': epoch_seconds
        }, step=i + 1)

        # save image
        save_plot(i, images_path, g_model, latent_dim)

        # save the model every 10 epochs
        if (i + 1) % 10 == 0:
            save_model(i, g_model, model_path)


In [None]:
# Identifier of this training run: used in the wandb run name and in the
# runs/<run>/ output folders written by train().
run = 9
notes ='WGAN training 2'

wandb_run = wandb.init(project="styleGAN", entity="nn2021",name=f'gcp_run_{run}', notes=notes)
# NOTE(review): `config` holds the wasserstein_loss function object under
# "LOSS" — confirm wandb records it in a useful form.
wandb_run.config.update(config)

# Upload the architecture diagrams written by the plot_model calls above.
for image in ['generator_plot.png', 'critic_plot.png', 'gan_plot.png']:
    wandb_run.save(image)

train(
    g_model=generator,
    c_model=critic,
    gan_model=gan, 
    dataset=train_images, 
    latent_dim=config['LATENT_DIM'],
    n_epochs=config['EPOCHS'],
    n_batch=config['BATCH_SIZE'],
    run=run,
    n_critic=config['N_CRITIC']
)

wandb_run.finish()

## Visualize the results

In [None]:
def display_image_for_epoch(epoch_no, run):
    """Open the sample plot saved for epoch ``epoch_no`` (1-based) of ``run``."""
    plot_path = f'runs/{run}/images/generated_plot_e{epoch_no:04d}.png'
    return Image.open(plot_path)

In [None]:
display_image_for_epoch(1, run)

In [None]:
display_image_for_epoch(100,run)

In [None]:
display_image_for_epoch(200,run)

Combine the images saved after each epoch into an animated GIF and display it.

In [None]:
def generate_gif(run, frame_repeat=2):
    """Combine the per-epoch sample plots of ``run`` into an animated GIF.

    Args:
        run: Run identifier; plots are read from ``runs/{run}/images``.
        frame_repeat: How many times each plot is appended to the GIF. The
            default of 2 keeps the previous output, where every plot appeared
            twice.

    Returns:
        Path of the written GIF file.
    """
    anim_file = f'runs/{run}/dcgan.gif'
    # file names carry a zero-padded epoch number, so sorting gives epoch order
    filenames = sorted(glob.glob(f'runs/{run}/images/generated_plot_e*.png'))

    with imageio.get_writer(anim_file, mode='I') as writer:
        for filename in filenames:
            # read each file once and reuse the decoded frame
            image = imageio.imread(filename)
            for _ in range(frame_repeat):
                writer.append_data(image)

    return anim_file

In [None]:
import tensorflow_docs.vis.embed as embed
embed.embed_file(generate_gif(run))

In [None]:
# Look the finished run up through the wandb API to attach the GIF to it.
api = wandb.Api()
last_run = api.run(f"{wandb_run.entity}/{wandb_run.project}/{wandb_run.id}")
# Same path that generate_gif() writes the animation to.
last_run.upload_file(f'runs/{run}/dcgan.gif')
last_run.save()

## Using the model

In [None]:
def create_plot(examples, n):
    """Show the first ``n * n`` entries of ``examples`` in an ``n`` x ``n`` grid."""
    total = n * n
    for position in range(total):
        plt.subplot(n, n, position + 1)
        plt.axis('off')
        plt.imshow(examples[position, :, :])
    plt.show()


def generate(model,latent_points,amount=4):
    """Generate images for ``latent_points`` and plot an ``amount`` x ``amount`` grid.

    ``latent_points`` must contain at least ``amount * amount`` rows.
    """
    images = model.predict(latent_points)
    # shift from [-1, 1] to [0, 1] for plotting (assumes outputs in [-1, 1])
    images = (images + 1) / 2.0
    create_plot(images, amount)
    
def load_model(run, epoch):
    """Load the generator saved for ``epoch`` (1-based) of ``run``."""
    checkpoint = f'runs/{run}/models/generator_model_{epoch:04d}.h5'
    return keras.models.load_model(checkpoint)

In [None]:
# Load the generator saved after epoch 200 and plot a 7x7 grid of samples.
model = load_model(run,200)
# Take the latent size from the config so it always matches the generator.
latent_points = generate_latent_points(config['LATENT_DIM'], 100)
generate(model,latent_points,amount=7)

In [None]:
# Generate one image from a constant latent vector (every entry 0.75).
# The vector length comes from the config so it matches the generator input.
vector = np.full((1, config['LATENT_DIM']), 0.75)
X = model.predict(vector)
# shift from [-1, 1] to [0, 1] for plotting
X = (X + 1) / 2.0
plt.imshow(X[0, :, :])
plt.show()