In [None]:
#
# Most of the code is taken from https://machinelearningmastery.com/how-to-develop-a-conditional-generative-adversarial-network-from-scratch/
#
# example of training a conditional GAN on the fashion mnist dataset
import os
os.environ["CUDA_VISIBLE_DEVICES"] = str(0)
import sys
# sys.path.append('..')
# from utils import foldername2class

from numpy import expand_dims
from numpy import zeros
from numpy import ones
from numpy.random import randn
from numpy.random import randint

# example of loading the generator model and generating images
from numpy import asarray
from numpy.random import randn
from numpy.random import randint
from tensorflow.keras.models import load_model
from matplotlib import pyplot


import tensorflow as tf


from tensorflow.keras.datasets.fashion_mnist import load_data
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input
from tensorflow.keras.layers import Dense
from tensorflow.keras.layers import Reshape
from tensorflow.keras.layers import Flatten
from tensorflow.keras.layers import Conv2D
from tensorflow.keras.layers import Conv2DTranspose
from tensorflow.keras.layers import LeakyReLU
from tensorflow.keras.layers import Dropout
from tensorflow.keras.layers import Embedding
from tensorflow.keras.layers import Concatenate
from tensorflow.keras.layers import BatchNormalization

from tensorflow.keras import initializers

from skimage import io
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import glob
from sklearn.utils import shuffle

In [None]:
# Show the GPU devices TensorFlow reports; the returned list is this cell's output.
tf.config.list_physical_devices('GPU')

In [None]:
# Image height / width / channels used by the input pipeline and the models,
# and the mini-batch size produced by the tf.data pipeline.
H, W, C = 336, 336, 3
batch_size = 10
# Root directory: every entry directly below it is globbed as a folder of images.
PATH_DATA = '../../expand_double_modes'

# Collect the path of every file found inside each sub-folder of PATH_DATA.
train_images_path = []
with tqdm(glob.glob(PATH_DATA + "/*")) as iterator:
    for single_folder in iterator:
        # Each folder's listing is shuffled before being appended, as before.
        train_images_path.extend(shuffle(glob.glob(single_folder + '/*')))

# Shuffle once more across folders so paths are not grouped by folder.
train_images_path = shuffle(train_images_path)

In [None]:
def preprocess_images(images):
    """Shift and scale pixel values so that 0 maps to -1 and 255 maps to 1.

    Args:
        images: array-like of pixel values supporting numpy arithmetic.

    Returns:
        The rescaled values as a float32 array.
    """
    half_range = 127.5
    centred = images - half_range
    return (centred / half_range).astype('float32')

def generator_img(path_list: list):
    """Endlessly yield preprocessed images read from ``path_list``.

    Each path is read with ``io.imread``, converted to a float32 array and
    passed through ``preprocess_images``. After the last path has been
    yielded, the list is reshuffled and iteration starts over, so the
    generator never terminates on its own.

    Args:
        path_list: image file paths; indexed by position, so an empty list
            raises IndexError on the first ``next()``.

    Yields:
        One preprocessed image array per step.
    """
    n_paths = len(path_list)
    position = 0
    while True:
        raw_pixels = io.imread(path_list[position])
        yield preprocess_images(np.asarray(raw_pixels, dtype=np.float32))

        position += 1
        if position != n_paths:
            continue

        # Completed a full pass: restart from the top with a new order.
        position = 0
        path_list = shuffle(path_list)

def train_gen():
    """Zero-argument factory returning a fresh image generator over ``train_images_path``.

    This is the callable handed to ``tf.data.Dataset.from_generator``.
    """
    image_stream = generator_img(train_images_path)
    return image_stream

In [None]:
# Wrap the Python image generator in a tf.data pipeline; every element is
# declared as a float32 tensor of shape (H, W, C).
dataset = tf.data.Dataset.from_generator(
    train_gen,
    output_signature=tf.TensorSpec(shape=(H, W, C), dtype=np.float32),
)
# Shuffle within a buffer of ten batches' worth of images, then group into batches.
dataset = dataset.shuffle(batch_size * 10)
dataset = dataset.batch(batch_size)


In [None]:
# Number of image paths collected for training.
train_size = len(train_images_path)

# Report the count so the size of the training set is visible in the cell output.
print(f'train: {train_size}')

In [None]:
#def init_weights(shape, dtype=tf.float32):
#    return initializers.Orthogonal(1 / np.sqrt(np.prod(shape[1:])))(shape=shape, dtype=dtype)

def init_weights():
    """Return a new RandomNormal kernel initializer with standard deviation 0.02."""
    normal_init = initializers.RandomNormal(stddev=0.02)
    return normal_init

def generate_latent_points(latent_dim, n_samples):
    """Draw ``n_samples`` standard-normal latent vectors of length ``latent_dim``.

    Returns:
        Array of shape (n_samples, latent_dim).
    """
    batch_shape = (n_samples, latent_dim)
    return randn(*batch_shape)

# create and save a plot of generated images
def save_figure(examples, n, prefix=0):
    """Draw the first ``n * n`` images of ``examples`` on an n-by-n grid and save it.

    The figure is written to ``'<prefix>_image.png'`` and all open matplotlib
    figures are closed afterwards.

    Args:
        examples: indexable collection of images accepted by ``imshow``;
            must hold at least ``n * n`` entries.
        n: number of rows and of columns in the grid.
        prefix: value placed at the start of the output file name.
    """
    fig = pyplot.figure(figsize=(12,12))
    for cell in range(1, n * n + 1):
        # one axes per grid cell, numbered from 1 in row-major order
        ax = fig.add_subplot(n, n, cell)
        ax.axis('off')
        ax.imshow(examples[cell - 1])
    fig.savefig(f'{prefix}_image.png')
    plt.close('all')

In [None]:
# define the standalone discriminator model
def define_discriminator(in_shape):
    """Build and compile the fully convolutional discriminator.

    Four stride-2 3x3 convolutions each halve the spatial size (336 -> 168 ->
    84 -> 42 -> 21 for a 336x336 input). A further 3x3 convolution and a final
    1x1 sigmoid convolution then produce one real/fake score per remaining
    spatial location, instead of a single score per image.

    Args:
        in_shape: shape of one input image, e.g. ``(H, W, C)``.

    Returns:
        A Keras ``Model`` compiled with binary cross-entropy, Adam
        (learning rate 0.0002, beta_1 0.5) and an accuracy metric.
    """
    in_image = Input(shape=in_shape)
    fe = in_image
    # downsampling stack: every stride-2 convolution halves height and width
    for n_filters in (64, 128, 128, 128):
        fe = Conv2D(n_filters, (3,3), strides=(2,2), padding='same', kernel_initializer=init_weights())(fe)
        fe = LeakyReLU(alpha=0.2)(fe)
    # feature mixing at the final resolution (no further downsampling)
    fe = Conv2D(64, (3,3), padding='same', kernel_initializer=init_weights())(fe)
    fe = LeakyReLU(alpha=0.2)(fe)
    # per-location probability map
    out_layer = Conv2D(1, (1,1), padding='same', activation='sigmoid', kernel_initializer=init_weights())(fe)
    model = Model([in_image], out_layer)
    # `learning_rate` is the supported keyword; the old `lr` alias is deprecated
    # and is rejected by newer Keras releases.
    opt = Adam(learning_rate=0.0002, beta_1=0.5)
    model.compile(loss='binary_crossentropy', optimizer=opt, metrics=['accuracy'])
    return model

# define the standalone generator model
def define_generator(latent_dim, h_low=21, w_low=21):
    """Build the (uncompiled) generator mapping a latent vector to an image.

    A dense layer expands the latent vector into an ``h_low x w_low x 64``
    feature map; four stride-2 transposed convolutions then double the
    spatial size each time (21 -> 42 -> 84 -> 168 -> 336 with the defaults),
    and a final tanh convolution emits ``C`` channels.

    Args:
        latent_dim: length of the latent input vector.
        h_low: height of the initial low-resolution feature map.
        w_low: width of the initial low-resolution feature map.

    Returns:
        A Keras ``Model`` taking the latent vector and returning the image.
    """
    base_channels = 64
    in_lat = Input(shape=(latent_dim,))
    # project the latent vector and fold it into a low-resolution feature map
    gen = Dense(base_channels * h_low * w_low)(in_lat)
    gen = LeakyReLU(alpha=0.2)(gen)
    gen = Reshape((h_low, w_low, base_channels))(gen)
    # upsampling stack; each stage doubles height and width
    for n_filters in (128, 256, 256, 128):
        gen = Conv2DTranspose(n_filters, (4,4), strides=(2,2), padding='same', kernel_initializer=init_weights())(gen)
        # NOTE(review): the batch-norm layers are created with trainable=False,
        # exactly as before -- confirm that freezing them is intended.
        gen = BatchNormalization(trainable=False)(gen)
        gen = LeakyReLU(alpha=0.2)(gen)
    # tanh output matches the [-1, 1] scaling applied to the real images
    out_layer = Conv2D(C, (4,4), activation='tanh', padding='same', kernel_initializer=init_weights())(gen)
    return Model([in_lat], out_layer)

# define the combined generator and discriminator model, for updating the generator
def define_gan(g_model, d_model):
    """Chain generator and discriminator into the model used to update the generator.

    Args:
        g_model: generator model; its input (latent noise) becomes the
            combined model's input.
        d_model: discriminator model; marked non-trainable here and applied
            to the generator's output.

    Returns:
        A Keras ``Model`` from latent noise to the discriminator's output,
        compiled with binary cross-entropy and Adam (learning rate 0.0002,
        beta_1 0.5).
    """
    # make weights in the discriminator not trainable within the combined model
    d_model.trainable = False
    # the generator takes only latent noise as input
    gen_noise = g_model.input
    gen_output = g_model.output
    # feed the generated image straight into the discriminator
    gan_output = d_model([gen_output])
    model = Model([gen_noise], gan_output)
    # `learning_rate` is the supported keyword; the old `lr` alias is deprecated
    # and is rejected by newer Keras releases.
    opt = Adam(learning_rate=0.0002, beta_1=0.5)
    model.compile(loss='binary_crossentropy', optimizer=opt)
    return model

# load fashion mnist images
def load_real_samples():
    """Load the Fashion-MNIST training split scaled so 0 maps to -1 and 255 to 1.

    The training call visible in this notebook streams images from the
    ``dataset`` pipeline instead; the one visible use of this loader is
    commented out.

    Returns:
        ``[X, trainy]``: float32 images with a trailing channel axis added,
        and the training labels returned by ``load_data``.
    """
    (trainX, trainy), (_, _) = load_data()
    # add a channel axis and switch to floats in one step
    X = expand_dims(trainX, axis=-1).astype('float32')
    X = (X - 127.5) / 127.5
    return [X, trainy]

# # select real samples
def generate_real_samples(dataset, n_samples):
    """Fetch the next batch of real images with smoothed "real" targets.

    One iterator is kept per dataset object and advanced on every call.
    The previous ``list(dataset.take(1))[0]`` built a brand-new iterator each
    time, which restarts the underlying generator (and refills any shuffle
    buffer), so every call drew from the head of the data only.

    Args:
        dataset: iterable of image batches (e.g. a batched ``tf.data.Dataset``).
        n_samples: kept for interface compatibility; the number of images is
            whatever batch size ``dataset`` yields, and the targets are sized
            to match that batch.

    Returns:
        ``(X, y)`` where ``X`` is the next batch from ``dataset`` and ``y`` is
        an array of shape ``(len(X), 21, 21, 1)`` filled with 0.9
        (one-sided label smoothing).

    Raises:
        StopIteration: if ``dataset`` yields no batches at all.
    """
    cache = generate_real_samples.__dict__
    if cache.get('_source') is not dataset:
        # first call, or a different dataset: start iterating it
        cache['_source'] = dataset
        cache['_batches'] = iter(dataset)
    try:
        X = next(cache['_batches'])
    except StopIteration:
        # finite dataset exhausted: begin another pass
        cache['_batches'] = iter(dataset)
        X = next(cache['_batches'])
    # one target per discriminator output location, sized from the real batch
    y = ones((len(X), 21, 21, 1)) * 0.9
    return X, y

# use the generator to generate n fake examples, with class labels
def generate_fake_samples(generator, latent_dim, n_samples):
    """Generate ``n_samples`` images with the generator, paired with "fake" targets.

    Args:
        generator: model exposing ``predict``; called with a one-element list
            holding the latent batch.
        latent_dim: length of each latent vector.
        n_samples: number of images to generate.

    Returns:
        ``(images, y)`` where ``y`` is an all-zero array of shape
        ``(n_samples, 21, 21, 1)``.
    """
    noise = generate_latent_points(latent_dim, n_samples)
    fake_images = generator.predict([noise])
    # zero target for every discriminator output location
    fake_targets = zeros((n_samples, 21, 21, 1))
    return fake_images, fake_targets

# train the generator and discriminator
def train(g_model, d_model, gan_model, dataset, latent_dim, n_epochs=1, n_batch=128, dataset_size=30000):
    """Alternate discriminator and generator updates for ``n_epochs`` epochs.

    Each step trains the discriminator on half a batch of real images and half
    a batch of generated ones, then trains the generator through the combined
    model on a full batch of latent points labelled as real. Every 50 steps
    the generator is saved to ``cgan_generator.h5``, reloaded, and a 10x10
    grid of samples is written with ``save_figure``.

    Args:
        g_model: generator model.
        d_model: compiled discriminator; ``train_on_batch`` must return a
            (loss, metric) pair.
        gan_model: compiled combined model.
        dataset: source of real image batches for ``generate_real_samples``.
        latent_dim: length of each latent vector.
        n_epochs: number of passes.
        n_batch: batch size for the generator update; half of it is used for
            each discriminator update.
        dataset_size: sample count used to derive the steps per epoch.
    """
    checkpoint_path = 'cgan_generator.h5'
    steps_per_epoch = int(dataset_size / n_batch)
    n_half = int(n_batch / 2)
    for i in range(n_epochs):
        for j in range(steps_per_epoch):
            # discriminator update on real images
            real_images, real_targets = generate_real_samples(dataset, n_half)
            d_loss1, _ = d_model.train_on_batch(real_images, real_targets)
            # discriminator update on generated images
            fake_images, fake_targets = generate_fake_samples(g_model, latent_dim, n_half)
            d_loss2, _ = d_model.train_on_batch(fake_images, fake_targets)
            # generator update: targets are "real" (no label smoothing here)
            noise = generate_latent_points(latent_dim, n_batch)
            gan_targets = ones((n_batch, 21, 21, 1))
            g_loss = gan_model.train_on_batch(noise, gan_targets)
            print('>%d, %d/%d, d1=%.3f, d2=%.3f g=%.3f' %
                (i+1, j+1, steps_per_epoch, d_loss1, d_loss2, g_loss))

            if j % 50 != 0:
                continue
            # periodic checkpoint plus a preview grid drawn from the reloaded file
            g_model.save(checkpoint_path)
            reloaded = load_model(checkpoint_path)
            preview_noise = generate_latent_points(latent_dim, 100)
            preview = reloaded.predict([preview_noise])
            # scale from [-1,1] to [0,1] for plotting
            preview = (preview + 1) / 2.0
            save_figure(preview, 10, prefix=f'e_{i}_be_{j}')
    # final checkpoint once all epochs are done
    g_model.save(checkpoint_path)

In [None]:
# size of the latent space: length of each noise vector fed to the generator
latent_dim = 100

In [None]:
# create (and compile) the discriminator for images of shape (H, W, C)
d_model = define_discriminator((H, W, C))

In [None]:
# create the generator; it is not compiled on its own and is trained through the combined model
g_model = define_generator(latent_dim)

In [None]:
# create the gan: generator followed by the discriminator
# (define_gan also sets d_model.trainable = False)
gan_model = define_gan(g_model, d_model)

In [None]:
# load image data
#dataset = load_real_samples()
#dataset = (train_images, train_labels)

In [None]:
# train model
# Pass the real number of training images so the steps per epoch are derived
# from the data actually collected rather than from train()'s default of 30000.
train(g_model, d_model, gan_model, dataset, latent_dim, n_epochs=20, n_batch=batch_size * 2, dataset_size=train_size)

In [None]:
# example of loading the generator model and generating images
from numpy import asarray
from numpy.random import randn
from numpy.random import randint
from tensorflow.keras.models import load_model
from matplotlib import pyplot

# generate points in latent space as input for the generator
def generate_latent_points(latent_dim, n_samples, n_classes=21):
    """Draw latent vectors together with random integer class labels.

    NOTE: this redefines the helper of the same name declared earlier in the
    notebook, which returns only the latent array; after this cell runs, the
    name refers to this two-value version.

    Args:
        latent_dim: length of each latent vector.
        n_samples: number of latent vectors and labels to draw.
        n_classes: labels are drawn uniformly from ``[0, n_classes)``.

    Returns:
        ``[z_input, labels]`` with ``z_input`` of shape
        ``(n_samples, latent_dim)`` and ``labels`` of shape ``(n_samples,)``.
    """
    # draw one flat vector, then fold it into a (n_samples, latent_dim) batch
    z_input = randn(latent_dim * n_samples).reshape(n_samples, latent_dim)
    labels = randint(0, n_classes, n_samples)
    return [z_input, labels]

# create and save a plot of generated images
def save_plot(examples, n):
    """Display the first ``n * n`` images of ``examples`` on an n-by-n grid.

    Despite its name, the function only shows the figure; nothing is written
    to disk.

    Args:
        examples: indexable collection of images accepted by ``imshow``;
            must hold at least ``n * n`` entries.
        n: number of rows and of columns in the grid.
    """
    fig = pyplot.figure(figsize=(12,12))
    for cell in range(1, n * n + 1):
        # one axes per grid cell, numbered from 1 in row-major order
        ax = fig.add_subplot(n, n, cell)
        ax.axis('off')
        ax.imshow(examples[cell - 1])
    pyplot.show()

# load model
model = load_model('cgan_generator.h5')
# generate latent points; the helper in this cell also returns random class
# labels, which are discarded because the saved generator was built with the
# latent vector as its only input (passing labels as a second input fails).
latent_points, _labels = generate_latent_points(latent_dim, 100)
# generate images
X = model.predict([latent_points])
# scale from [-1,1] to [0,1]
X = (X + 1) / 2.0
# plot the result
save_plot(X, 10)