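Instructions for Colab:

1. Open the **Files** panel in the left sidebar.
2. Click **Mount Drive**.
3. Authorize the Google account whose Drive contains your `kaggle.json` (the setup cell below expects it at `My Drive/kaggle/kaggle.json`).
4. Remove the triple quotes surrounding the setup cell below so that it actually runs.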

In [None]:
# Colab-only setup, currently disabled by wrapping it in a string literal; remove
# the surrounding triple quotes to run it. When run inside Colab it:
#   1. copies kaggle.json from the mounted Drive into ~/.kaggle (Kaggle CLI auth),
#   2. clones the project repository and changes into it,
#   3. downloads the CelebA archive with the Kaggle CLI and unzips it into data/,
#   4. moves the images from the nested data/img_align_celeba/img_align_celeba/
#      folder up one level and deletes the now-empty nested folder.
# NOTE(review): the next cell also runs `%cd FaceGen-GAN/`, which is redundant
# once this cell has already changed into the repository.
"""
# Run the following cells only if using Colab
if 'google.colab' in str(get_ipython()):
    # Setup Kaggle (to download dataset)
    !mkdir -p ~/.kaggle
    !cp "/content/drive/My Drive/kaggle/kaggle.json" ~/.kaggle/
    !chmod 600 ~/.kaggle/kaggle.json
    # Setup project repository
    !git clone https://github.com/alexpod1000/FaceGen-GAN.git
    %cd FaceGen-GAN/
    !pwd
    # Download and unpack dataset
    !kaggle datasets download -d jessicali9530/celeba-dataset
    !unzip celeba-dataset.zip -d data
    # move all the images one folder upside (credits to: https://stackoverflow.com/a/11942775)
    !find data/img_align_celeba/img_align_celeba/ -name '*.*' -exec mv --target-directory=data/img_align_celeba '{}' +
    !rm -rf data/img_align_celeba/img_align_celeba/
"""

In [None]:
# Change into the cloned repository. The relative paths used below ("./data",
# "saved_models/", "gen/") and presumably the project imports (architectures,
# utils, callbacks) are expected to resolve against the repository root.
# NOTE(review): if the working directory is already the repository (e.g. after the
# Colab setup cell, which runs the same %cd), there is no FaceGen-GAN/ subfolder
# and this magic will just report an error — skip the cell in that case.
%cd FaceGen-GAN/

# Actual code

In [None]:
import os

# DATASET PATHS
# Root folder with the attribute CSV and the aligned image folder
# (the Colab setup cell unzips the Kaggle archive into ./data).
dataset_base_directory = "./data"
dataset_attr_file = os.path.join(dataset_base_directory, "list_attr_celeba.csv")
dataset_images_path = os.path.join(dataset_base_directory, "img_align_celeba")
# SPLIT RANGES
# Half-open [start, stop) row ranges into the attribute CSV, used as positional
# slices. A stop of None means "through the last row": a stop of -1 would
# silently drop the final image from the test split.
train_split_range = (0, 162770)
valid_split_range = (162770, 182637)
test_split_range = (182637, None)

In [None]:
# Global switch for mixed-precision (float16) training. It is read by the policy
# cell and by the optimizer setup in the experiments cell.
USE_AMP: bool = False

In [None]:
if USE_AMP:
    # Switch Keras to the "mixed_float16" policy (float16 compute, float32
    # variables). It only affects layers created afterwards, so this must run
    # before the models are built. Binds `mixed_precision` at module level for
    # the optimizer wrapping done later.
    from tensorflow.keras import mixed_precision
    mixed_precision.set_global_policy(mixed_precision.Policy('mixed_float16'))

In [None]:
"""
EXPERIMENTS FILE
"""
# Experiment setup: hyper-parameters, model construction, optimizers, compilation
# and checkpointing. Later cells read many of these module-level names
# (model_name, conditional_dim, latent_dim, n_epochs, load_epoch, gan_model,
# train_callbacks), so they must keep their names.
import tensorflow as tf

from architectures.base_dcgan import Generator, Discriminator, GAN_Wrapper

# TRAINING PARAMETERS
model_name = "base_dcgan"
# NOTE(review): the dataset cell reassigns batch_size (64 * 2) before building the
# data sequences, so this value never reaches training.
batch_size = 32
n_epochs = 100
# GENERATOR PARAMETERS
# Width of the attribute-condition vectors built later for image logging; it is
# not passed to Generator below. Presumably must equal the number of label
# columns in list_attr_celeba.csv — confirm.
conditional_dim = 40
latent_dim = 128
filters_gen = 64
kernel_size_gen = 4
# DISCRIMINATOR PARAMETERS
filters_disc = 64
kernel_size_disc = 5
# TRAINING OPTIMIZER PARAMETERS
init_learning_rate_gen = 0.0002
init_learning_rate_disc = 0.0002
beta_1 = 0.5
# RESUME PARAMS
# NOTE(review): load_model is not read anywhere in the visible notebook, so setting
# it to True does not restore any weights. load_epoch only shifts the epoch
# counter through fit(initial_epoch=load_epoch). Weight loading still has to be
# wired up for a real resume.
load_model = False
load_epoch = 0

# create models
generator_model = Generator(latent_dim, filters=filters_gen, kernel_size=kernel_size_gen)
discriminator_model = Discriminator(filters=filters_disc, kernel_size=kernel_size_disc)
# create gan wrapper model
# NOTE(review): the use_amp argument is commented out, presumably because
# GAN_Wrapper does not accept it — confirm in architectures.base_dcgan.
gan_model = GAN_Wrapper(discriminator_model, generator_model)#, use_amp=USE_AMP)

# optimizers
d_optimizer = tf.keras.optimizers.Adam(learning_rate=init_learning_rate_disc, beta_1=beta_1)
g_optimizer = tf.keras.optimizers.Adam(learning_rate=init_learning_rate_gen, beta_1=beta_1)
# Under mixed precision, wrap both optimizers in LossScaleOptimizer so float16
# gradients are loss-scaled. `mixed_precision` is only bound when the AMP policy
# cell ran with USE_AMP = True.
if USE_AMP:
    d_optimizer = mixed_precision.LossScaleOptimizer(d_optimizer)
    g_optimizer = mixed_precision.LossScaleOptimizer(g_optimizer)


# compile model
# GAN_Wrapper.compile receives both optimizers and a single loss_fn.
gan_model.compile(
    d_optimizer=d_optimizer,
    g_optimizer=g_optimizer,
    loss_fn=tf.keras.losses.BinaryCrossentropy(),
)
# callbacks
# Save only the weights, once per epoch by default; Keras fills in the {epoch}
# placeholder of the file name.
train_callbacks = [
    tf.keras.callbacks.ModelCheckpoint(
        filepath='saved_models/'+model_name+'/model_{epoch}.h5', 
        save_weights_only=True
    )
]

In [None]:
"""
MAIN FILE
"""
from utils.file_utils import makedir_if_not_exists
from utils.visualization_utils import save_plot_batch

from callbacks import ImagesLoggingCallback

import os
import pickle

import numpy as np
import pandas as pd

import tensorflow as tf

# FILE PARAMETERS
model_save_dir = "saved_models/{}/".format(model_name)
model_images_save_base_dir = "gen/{}".format(model_name)
model_gen_sample_dir = "gen/{}/sample/".format(model_name)
model_gen_real_dir = "gen/{}/real_cond/".format(model_name)

# make model directories if they no exist
makedir_if_not_exists(model_save_dir)
makedir_if_not_exists(model_gen_sample_dir)
makedir_if_not_exists(model_gen_real_dir)

In [None]:
import math
import os
import random

import numpy as np
from tensorflow.keras.utils import Sequence

from utils.image_utils import load_image

class DataSequence(Sequence):
    """
    Keras Sequence object to train a model on larger-than-memory data.

    Streams (image, attribute) pairs from disk one batch at a time. Each item is
    ``((images, labels), labels)``:
    - ``images`` is the stacked output of ``load_image`` for the batch, with an
      optional random flip.
    - ``labels`` holds the label columns cast to float32 and clipped to [0, 1],
      so a value of -1 becomes 0.

    Args:
        df: DataFrame with an ``Image_Name`` column. Every column after the first
            is treated as a label column, so ``Image_Name`` is expected to come
            first.
        data_root: directory that the ``Image_Name`` entries are relative to.
        batch_size: samples per batch. A trailing partial batch is dropped.
        resize_size: output size forwarded to ``load_image``.
        flip_augment: if True, each image is flipped along axis 1 with
            probability 0.5 (a horizontal flip, assuming ``load_image`` returns
            HWC arrays — confirm).
        mode: ``'train'`` reshuffles the sample order at every epoch end; any
            other value keeps the DataFrame order.
        use_amp: stored for API compatibility but currently unused. Images keep
            the dtype returned by ``load_image`` and labels are always float32.
        cache_images: if True, keep each loaded (un-flipped) image in memory and
            reuse it on later requests. Off by default because it can hold a
            whole split in RAM. With ``use_multiprocessing=True`` each worker
            process presumably fills its own copy of the cache.
    """
    def __init__(self, df, data_root, batch_size, resize_size=(64, 64), flip_augment=True, mode='train', use_amp=False,
                 cache_images=False):
        self.df = df
        self.batch_size = batch_size
        self.mode = mode
        self.resize_size = resize_size
        # Crop-box corners forwarded to load_image (coordinate convention is defined
        # there — presumably (row, col); confirm in utils.image_utils).
        self.crop_pt_1 = (45, 25)
        self.crop_pt_2 = (173, 153)
        self.flip_augment = flip_augment
        # Every column after the first one (Image_Name) is a label column.
        self.label_columns = self.df.columns[1:].tolist()
        self.use_amp = use_amp
        self.cache_images = cache_images

        # Keep only the labels and image paths in memory; pixels are loaded lazily.
        self.labels = self.df[self.label_columns].values
        self.im_list = [os.path.join(data_root, name) for name in self.df['Image_Name']]
        # path -> loaded image; only populated when cache_images is True.
        self.img_cache = {}
        # Build the initial (possibly shuffled) sample order.
        self.on_epoch_end()

    def __len__(self):
        # Number of full batches; the remainder is dropped.
        return len(self.df) // self.batch_size

    def on_epoch_end(self):
        # Shuffles indexes after each epoch if in training mode
        self.indexes = range(len(self.im_list))
        if self.mode == 'train':
            # Random permutation of all sample indices (a list); other modes keep the range.
            self.indexes = random.sample(self.indexes, k=len(self.indexes))

    def get_batch_labels(self, idx):
        """Return the raw label rows for the sample indices ``idx``."""
        return self.labels[idx]

    def get_batch_features(self, idx):
        """Load (optionally caching) and randomly flip the images for sample indices ``idx``."""
        images = []
        for im_idx in idx:
            im = self.im_list[im_idx]
            if im in self.img_cache:
                # Work on a copy so the cached image is never modified.
                loaded_image = np.copy(self.img_cache[im])
            else:
                loaded_image = load_image(im, self.resize_size, self.crop_pt_1, self.crop_pt_2)
                if self.cache_images:
                    # Cache the un-flipped image; the flip below is drawn per request.
                    self.img_cache[im] = loaded_image
            if self.flip_augment and random.random() < 0.5:
                loaded_image = np.flip(loaded_image, 1)
            images.append(loaded_image)
        # np.array stacks (and copies) the batch, so cached arrays are never aliased.
        return np.array(images)

    def __getitem__(self, index):
        idx = self.indexes[index*self.batch_size:(index+1)*self.batch_size]
        batch_x = self.get_batch_features(idx)
        # Clip maps -1 attribute values to 0 and keeps 1 as 1.
        batch_y = np.clip(self.get_batch_labels(idx).astype(np.float32), 0, 1)
        # Labels are returned both inside the input tuple and as the target.
        return (batch_x, batch_y), batch_y

In [None]:
#from datasets.celeba.dataloader import DataSequence

import pandas as pd

# Load the attribute table and rename the id column to what DataSequence expects.
data_df = pd.read_csv(dataset_attr_file).rename(
    columns={"image_id": "Image_Name"}
)
# Positional [start, stop) row slices of the attribute table.
train_df = data_df.iloc[train_split_range[0]:train_split_range[1]]
valid_df = data_df.iloc[valid_split_range[0]:valid_split_range[1]]
test_df = data_df.iloc[test_split_range[0]:test_split_range[1]]

"""
DATASEQ
"""

# Batch size actually used for training. It overrides the value assigned in the
# experiments cell.
batch_size = 64 * 2

training_generator = DataSequence(train_df, dataset_images_path, batch_size=batch_size, use_amp=USE_AMP)

# Real conditions for the visual report: the first 25 label rows of the first
# validation batch. mode="valid" disables shuffling, so these are always the
# same rows.
valid_cond_batch = DataSequence(valid_df, dataset_images_path, batch_size=batch_size, mode="valid", use_amp=USE_AMP)
_, real_view_conditions = valid_cond_batch[0]
real_view_conditions = real_view_conditions[:25]

# Synthetic conditions: 25 rows with only the "Smiling" attribute switched on.
# The column is looked up by name rather than hard-coded (it sits at position 31
# in the standard CelebA attribute order), so a reordered CSV cannot silently
# condition on the wrong attribute.
if len(training_generator.label_columns) != conditional_dim:
    raise ValueError(
        "conditional_dim={} but the attribute table has {} label columns".format(
            conditional_dim, len(training_generator.label_columns)
        )
    )
smiling_column = training_generator.label_columns.index("Smiling")
view_cond = np.zeros((25, conditional_dim), dtype=np.float32)
view_cond[:, smiling_column] = 1.0  # all smile

In [None]:
# Grab the first training batch for a quick sanity check. Indexing the Sequence
# directly yields the same item that iterating over it would produce first.
bbb = training_generator[0]

In [None]:
# Sanity check: dtype of the image batch. This is the first element of the
# (images, labels) input tuple returned by DataSequence.__getitem__.
bbb[0][0].dtype

In [None]:
class ImagesLoggingCallback(tf.keras.callbacks.Callback):
    """
    Save two grids of generated images at the start of every epoch.

    Both grids use the same batch of random latent vectors:
    - one is conditioned on ``view_cond`` and written to
      ``<images_dir>/sample/sample_<epoch>.png``;
    - the other is conditioned on ``real_view_conditions`` and written to
      ``<images_dir>/real_cond/sample_<epoch>.png``.

    This callback does not create those directories.

    Args:
        n_images: number of latent vectors drawn per epoch. Presumably has to
            match the number of rows in both condition arrays — confirm.
        latent_dim: size of each latent vector.
        view_cond: condition array for the ``sample`` grid.
        real_view_conditions: condition array for the ``real_cond`` grid.
        images_dir: base output directory (paths are built by string concatenation).
    """

    def __init__(self, n_images, latent_dim, view_cond, real_view_conditions, images_dir):
        super().__init__()
        self.n_images = n_images
        self.latent_dim = latent_dim
        self.images_dir = images_dir
        self.view_cond = view_cond
        self.real_view_conditions = real_view_conditions

    def _generate(self, latent_vectors, conditions):
        """Run ``self.model.generator`` in inference mode and rescale the output for plotting.

        The output is cast to float32, because the generator may produce float16
        under a mixed-precision policy. It is then mapped through ``(x + 1) / 2``,
        which turns a [-1, 1] range into [0, 1] — assuming the generator emits
        values in [-1, 1]; confirm in architectures.base_dcgan.
        """
        images = self.model.generator(latent_vectors, conditions, training=False)
        images = tf.cast(images, tf.float32)
        return (images + 1) / 2.0

    def on_epoch_begin(self, epoch, logs=None):
        random_latent_vectors = tf.random.normal(shape=(self.n_images, self.latent_dim))
        generated_images = self._generate(random_latent_vectors, self.view_cond)
        generated_images_real_cond = self._generate(random_latent_vectors, self.real_view_conditions)

        save_plot_batch(generated_images, self.images_dir+"/sample/sample_{}.png".format(epoch))
        save_plot_batch(generated_images_real_cond, self.images_dir+"/real_cond/sample_{}.png".format(epoch))

In [None]:
# Attach the image-logging callback. It draws 25 latent vectors, matching the
# 25 rows of view_cond and real_view_conditions.
train_callbacks += [
    ImagesLoggingCallback(
        25,
        latent_dim,
        view_cond,
        real_view_conditions,
        model_images_save_base_dir,
    )
]

In [None]:
# Launch training. Batches come from the DataSequence through 8 worker processes.
# initial_epoch=load_epoch offsets the epoch counter seen by the checkpoint file
# names and the image-logging callback.
# NOTE(review): `workers` / `use_multiprocessing` are tf.keras (Keras 2) fit
# arguments; Keras 3 removed them from Model.fit.
history = gan_model.fit(
    training_generator,
    epochs=n_epochs,
    initial_epoch=load_epoch,
    callbacks=train_callbacks,
    use_multiprocessing=True,
    workers=8,
)