In [None]:
import numpy as np
import tensorflow as tf
from tqdm import tqdm


# Restrict TensorFlow to the first GPU, if any GPU is present.
gpu_devices = tf.config.list_physical_devices('GPU')
if gpu_devices:
    print("Using GPU")
    try:
        # Visible devices can only be changed before the GPUs are initialized, so
        # re-running this cell in a live kernel raises RuntimeError; report it instead
        # of aborting the notebook run.
        tf.config.set_visible_devices(gpu_devices[0], 'GPU')
    except RuntimeError as err:
        print(f"Could not restrict visible GPUs (already initialized): {err}")

In [None]:
# Base image resolution (height, width), images per batch, and number of stages.
image_size, batch_size, num_of_stages = (32, 32), 32, 3


def custom_preprocess(x):
    """Linearly map pixel values from [0, 255] onto [-1, 1].

    0 maps to -1, 127.5 to 0 and 255 to 1. Works element-wise on scalars
    and numpy arrays alike.
    """
    centered = x - 127.5
    return centered / 127.5

image_generator = tf.keras.preprocessing.image.ImageDataGenerator(
    preprocessing_function=custom_preprocess,  # rescale [0.0, 255.0] to [-1.0, 1.0]
    validation_split=0.2,
    horizontal_flip=True,
)

# One batch iterator per stage, consumed with `.next()` by the training loop.
# Stage `s` (0-indexed) is loaded at `image_size * 2**s` (32, 64, 128, ...) so the
# real images match a generator whose output resolution doubles with each stage.
# NOTE(review): confirm this doubling matches the Generator's per-stage output size.
data_for_stage = []
for stage in range(num_of_stages):
    image_dir = f'../data/cats_stage_{stage}/'
    stage_image_size = (image_size[0] * 2 ** stage, image_size[1] * 2 ** stage)

    train_data = image_generator.flow_from_directory(
        image_dir,
        target_size=stage_image_size,
        batch_size=batch_size,
        class_mode=None,  # images only, no labels
        subset='training',
    )
    data_for_stage.append(train_data)


In [None]:
def get_latent_distribiution(latent_size=128, batch_size=64):
    """Draw a batch of latent vectors from a standard normal distribution.

    Sampling uses numpy's global random state, so results are reproducible
    only if ``np.random.seed`` is set beforehand.

    Args:
        latent_size: dimensionality of each latent vector.
        batch_size: number of latent vectors to draw.

    Returns:
        float32 array of shape (batch_size, latent_size) with N(0, 1) entries.
    """
    sample_shape = (batch_size, latent_size)
    per_dim_mean = np.zeros(latent_size)
    samples = np.random.normal(per_dim_mean, 1.0, sample_shape)
    return samples.astype(np.float32)

def mse_loss(y_true, y_pred):
    """Mean squared error averaged over every element of the inputs.

    Args:
        y_true: target values.
        y_pred: predicted values; must broadcast against ``y_true``.

    Returns:
        Scalar tensor holding the mean of the element-wise squared residuals.
    """
    residual = y_true - y_pred
    squared_residual = tf.square(residual)
    return tf.reduce_mean(squared_residual)

In [None]:
class Generator(tf.keras.layers.Layer):
    """Multi-stage image generator whose output resolution doubles per stage.

    A shared trunk maps a latent batch through a dense layer to a 4x4x256 map.
    Three (upsample x2 -> 5x5 transposed conv -> leaky ReLU) steps then bring it
    to a 32x32x32 feature map. Every stage after the first appends one more
    growth step of the same form. A shared 1x1 transposed conv projects the
    result to ``n_c`` channels. Stage ``s`` (1-indexed) therefore produces
    ``32 * 2 ** (s - 1)`` pixels per side: 32 for stage 1, 64 for stage 2,
    128 for stage 3.

    Args:
        n_c: number of output channels (e.g. 3 for RGB).
        num_stages: number of stages accepted by ``generate_images_at_stage``.
            Defaults to 3.

    Raises:
        ValueError: if ``num_stages`` is smaller than 1.
    """

    def __init__(self, n_c, num_stages=3):
        super(Generator, self).__init__()
        if num_stages < 1:
            raise ValueError(f"num_stages must be >= 1, got {num_stages}")
        self.n_c = n_c
        self.num_stages = num_stages

        # Shared trunk: latent -> 4x4x256 -> 8x8x128 -> 16x16x64 -> 32x32x32.
        self.dense1 = tf.keras.layers.Dense(256 * 4 * 4)
        self.reshape = tf.keras.layers.Reshape((4, 4, 256))
        self.upsample1 = tf.keras.layers.UpSampling2D(size=(2, 2))
        self.conv1 = tf.keras.layers.Conv2DTranspose(128, kernel_size=5, padding='same')
        self.upsample2 = tf.keras.layers.UpSampling2D(size=(2, 2))
        self.conv2 = tf.keras.layers.Conv2DTranspose(64, kernel_size=5, padding='same')
        self.upsample3 = tf.keras.layers.UpSampling2D(size=(2, 2))
        self.conv3 = tf.keras.layers.Conv2DTranspose(32, kernel_size=5, padding='same')

        # One growth step (x2 upsample + 32-filter conv) per stage after the first.
        # Each step needs its own conv: conv3's kernel is built for 64 input
        # channels and cannot be applied to its own 32-channel output. Keeping
        # 32 output channels lets the 1x1 projection below serve every stage.
        self.grow_upsamples = [
            tf.keras.layers.UpSampling2D(size=(2, 2)) for _ in range(num_stages - 1)
        ]
        self.grow_convs = [
            tf.keras.layers.Conv2DTranspose(32, kernel_size=5, padding='same')
            for _ in range(num_stages - 1)
        ]

        # Shared 1x1 projection to n_c output channels; no output activation.
        self.conv4 = tf.keras.layers.Conv2DTranspose(n_c, kernel_size=1, padding='same')

    def generate_images_at_stage(self, Z, stage):
        """Generate a batch of images at the resolution of ``stage``.

        Args:
            Z: latent batch fed to ``dense1``; presumably shape
                (batch, latent_size). TODO: confirm with callers.
            stage: 1-indexed stage number in ``[1, num_stages]``.

        Returns:
            Tensor of shape (batch, 32 * 2**(stage - 1), 32 * 2**(stage - 1), n_c).
            No output activation is applied.

        Raises:
            ValueError: if ``stage`` is outside ``[1, num_stages]``.
        """
        if not 1 <= stage <= self.num_stages:
            raise ValueError(f"stage must be in [1, {self.num_stages}], got {stage}")

        h = tf.nn.leaky_relu(self.dense1(Z), 0.2)
        h = self.reshape(h)
        h = self.upsample1(h)
        h = tf.nn.leaky_relu(self.conv1(h), 0.2)
        h = self.upsample2(h)
        h = tf.nn.leaky_relu(self.conv2(h), 0.2)
        h = self.upsample3(h)
        h = tf.nn.leaky_relu(self.conv3(h), 0.2)

        # Stage s applies the first s - 1 growth steps, doubling resolution each time.
        for upsample, conv in zip(self.grow_upsamples[:stage - 1],
                                  self.grow_convs[:stage - 1]):
            h = upsample(h)
            h = tf.nn.leaky_relu(conv(h), 0.2)

        return self.conv4(h)
        return x

In [None]:
n_updates_total = 100
lr = 0.0002
log_every = 100  # record the generator loss every `log_every` updates

generator = Generator(n_c=3)
g_optimizer = tf.optimizers.Adam(learning_rate=lr, beta_1=0.5)

g_loss_history = []

for n_updates in tqdm(range(n_updates_total), ncols=80, leave=False):
    # One real batch per stage; data_for_stage[0] pairs with generator stage 1.
    true_batches = [stage_iterator.next() for stage_iterator in data_for_stage]

    # The last batch of an epoch can be short, so trim every stage to the smallest
    # batch. A single latent batch of that size then pairs with every stage.
    n_samples = min(len(batch) for batch in true_batches)
    noise_samples = get_latent_distribiution(batch_size=n_samples)

    # The forward pass must run under the tape, otherwise no gradients are recorded.
    with tf.GradientTape() as tape:
        stage_losses = []
        for stage, true_batch in enumerate(true_batches, start=1):
            generated = generator.generate_images_at_stage(noise_samples, stage=stage)
            real = tf.convert_to_tensor(true_batch[:n_samples], dtype=tf.float32)
            stage_losses.append(mse_loss(real, generated))
        g_loss = tf.add_n(stage_losses)

    # Read the variables only after the first forward pass has built every layer.
    g_vars = generator.trainable_variables
    g_gradients = tape.gradient(g_loss, g_vars)
    g_optimizer.apply_gradients(zip(g_gradients, g_vars))

    if n_updates % log_every == 0:
        g_loss_history.append(float(g_loss))