In [None]:
import os

os.environ["KERAS_BACKEND"] = "tensorflow"

import keras
from keras import datasets, layers, models, ops

import matplotlib.pyplot as plt

import tensorflow as tf


In [None]:
class ResNetBlock(layers.Layer):
    """Two-branch convolutional block whose branch outputs are summed.

    Main branch:     Conv2D -> BatchNormalization -> Activation
    Shortcut branch: Conv2D -> BatchNormalization, applied to the block *input*

    Both convolutions emit ``filters`` channels with "same" padding, so the two
    branches always have matching shapes and the block may change the number
    of channels of its input.

    Args:
        filters: Output channels of both convolutions.
        kernel_size: Kernel size of both convolutions.
        activation: Activation identifier applied to the main branch only
            (default "linear").
        **kwargs: Forwarded to ``layers.Layer``.
    """

    def __init__(self, filters, kernel_size, activation="linear", **kwargs):
        super().__init__(**kwargs)
        self.filters = filters
        self.kernel_size = kernel_size
        self.activation = activation

    def build(self, input_shape):
        # Main-branch layers.
        self.conv1 = layers.Conv2D(self.filters, self.kernel_size, padding='same')
        self.bn = layers.BatchNormalization()
        self.act = layers.Activation(activation=self.activation)
        # Shortcut-branch layers.
        self.conv2 = layers.Conv2D(self.filters, self.kernel_size, padding='same')
        self.bn2 = layers.BatchNormalization()

    def call(self, x):
        main = self.act(self.bn(self.conv1(x)))
        # The shortcut convolves the original input (not `main`), which is what
        # allows the block to change the channel count; it is not activated.
        shortcut = self.bn2(self.conv2(x))
        return shortcut + main

    def get_config(self):
        return {
            **super().get_config(),
            "filters": self.filters,
            "kernel_size": self.kernel_size,
            "activation": self.activation,
        }


class UpBlock(layers.Layer):
    """Decoder stage that doubles the spatial resolution of its input.

    Layout: ``first_resnet`` -> ``UpSampling2D(size=(2, 2))`` -> ``depth - 1``
    further ResNetBlocks (none when ``depth <= 1``), i.e. ``max(depth, 1)``
    ResNetBlocks in total, each with ``filters`` channels.

    Args:
        filters: Output channels of every ResNetBlock in the stage.
        kernel_size: Convolution kernel size for the ResNetBlocks.
        activation: Activation identifier passed to each ResNetBlock
            (default "linear").
        depth: Number of ResNetBlocks in the stage (default 2).
        **kwargs: Forwarded to ``layers.Layer``.
    """

    def __init__(self, filters, kernel_size, activation="linear", depth=2, **kwargs):
        super().__init__(**kwargs)
        self.filters = filters
        self.kernel_size = kernel_size
        self.activation = activation
        self.depth = depth

    def build(self, input_shape):
        self.resnet_layers = []
        self.UpSampling = layers.UpSampling2D(size=(2, 2))

        self.first_resnet = ResNetBlock(self.filters, self.kernel_size, self.activation)

        # range() is empty when depth <= 1, so no separate guard is needed.
        for _ in range(self.depth - 1):
            self.resnet_layers.append(
                ResNetBlock(self.filters, self.kernel_size, self.activation)
            )

    def call(self, x):
        x = self.UpSampling(self.first_resnet(x))
        for block in self.resnet_layers:
            x = block(x)
        return x

    def get_config(self):
        return {
            **super().get_config(),
            "filters": self.filters,
            "kernel_size": self.kernel_size,
            "activation": self.activation,
            "depth": self.depth,
        }


class DownBlock(layers.Layer):
    """Encoder stage: ``depth`` ResNetBlocks, then MaxPooling2D, then Dropout.

    Args:
        filters: Output channels of every ResNetBlock in the stage.
        kernel_size: Convolution kernel size for the ResNetBlocks.
        activation: Activation identifier passed to each ResNetBlock
            (default "linear").
        depth: Number of ResNetBlocks before pooling (default 2).
        dropout: Dropout *rate* applied after pooling (default 0.1).
        **kwargs: Forwarded to ``layers.Layer``.
    """

    def __init__(self, filters, kernel_size, activation="linear", depth=2, dropout=.1, **kwargs):
        super().__init__(**kwargs)
        self.filters = filters
        self.kernel_size = kernel_size
        self.activation = activation
        self.depth = depth
        # The float rate; get_config() serialises this value.
        self.dropout = dropout

    def build(self, input_shape):
        self.resnet_layers = []
        self.pooling = layers.MaxPooling2D()

        for _ in range(self.depth):
            self.resnet_layers.append(
                ResNetBlock(self.filters, self.kernel_size, self.activation)
            )

        # Stored under its own name: assigning the layer to `self.dropout` would
        # replace the float rate, making get_config() return a Layer object that
        # cannot be used to rebuild the block.
        self.dropout_layer = layers.Dropout(self.dropout)

    def call(self, x):
        for layer in self.resnet_layers:
            x = layer(x)

        x = self.pooling(x)
        return self.dropout_layer(x)

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "filters": self.filters,
                "kernel_size": self.kernel_size,
                "activation": self.activation,
                "depth": self.depth,
                "dropout": self.dropout,
            }
        )
        return config


In [None]:
# Generator: a Dense projection to a start_res x start_res x channels[0] map,
# followed by one 2x-upsampling UpBlock per remaining entry of `channels`.
generator_config = {
    "latent_dim": 256,
    "start_res": 4,
    "channels": [512, 512, 256, 128, 64, 32],
    "image_channels": 3,
    "kernel_size": 4,
    "activation": "swish",
    "depth": 2,
}

# Discriminator: one pooling DownBlock per entry of `channels`, then a Dense head.
discriminator_config = {
    "image_size": 128,
    "image_channels": 3,
    "channels": [64, 128, 256, 512],
    "kernel_size": 3,
    "activation": "gelu",
    "depth": 1,
    "dropout": 0.1,
}

# The generator doubles its resolution once per UpBlock (len(channels) - 1 of
# them), so its images must come out at the size the discriminator and the data
# pipeline use. Fail fast here rather than with a shape error during training.
assert (
    generator_config["start_res"] * 2 ** (len(generator_config["channels"]) - 1)
    == discriminator_config["image_size"]
), "generator output resolution must equal discriminator_config['image_size']"
assert generator_config["image_channels"] == discriminator_config["image_channels"], (
    "generator and discriminator must agree on image_channels"
)

In [None]:
@keras.saving.register_keras_serializable()
class Generator(keras.Model):
    """Maps latent vectors to images.

    Pipeline: Dense -> BatchNormalization -> Activation -> Reshape to
    ``(start_res, start_res, channels[0])`` -> one UpBlock per entry of
    ``channels[1:]`` (each contains ``UpSampling2D(size=(2, 2))``) -> Conv2D with
    ``image_channels`` filters. The output side length is therefore
    ``start_res * 2 ** (len(channels) - 1)``.

    Args:
        latent_dim: Length of the input latent vectors (kept for config/callers).
        start_res: Side length of the first feature map.
        channels: Channel counts; ``channels[0]`` for the projected map, the
            rest for the successive UpBlocks.
        image_channels: Channels of the generated image.
        kernel_size: Kernel size for the UpBlocks and the output convolution.
        activation: Activation identifier (e.g. ``"swish"``) used after the
            projection and inside every UpBlock.
        depth: Number of ResNetBlocks per UpBlock.
        **kwargs: Forwarded to ``keras.Model``.
    """

    def __init__(self, latent_dim, start_res, channels, image_channels, kernel_size, activation, depth, **kwargs):
        super().__init__(**kwargs)
        self.latent_dim = latent_dim
        self.start_res = start_res
        self.channels = channels
        self.image_channels = image_channels
        self.kernel_size = kernel_size
        self.activation = activation
        self.depth = depth

    def build(self, input_shape):
        self.ff = layers.Dense(self.start_res * self.start_res * self.channels[0])
        self.bn = layers.BatchNormalization()
        # Kept under `act` so that `self.activation` stays the identifier string:
        # get_config() serialises it and the UpBlocks below are built from it.
        self.act = layers.Activation(activation=self.activation)
        self.reshape = layers.Reshape((self.start_res, self.start_res, self.channels[0]))

        self.upblocks = []
        for filters in self.channels[1:]:
            self.upblocks.append(
                UpBlock(filters, self.kernel_size, self.activation, self.depth)
            )

        # Named `to_image` rather than `outputs`, an attribute name keras uses on
        # functional models.
        self.to_image = layers.Conv2D(self.image_channels, kernel_size=self.kernel_size, padding='same')

    def call(self, x, training=None):
        # `training` is accepted so callers (e.g. a GAN train_step) can put the
        # BatchNormalization layers into training mode. It is not referenced
        # directly: Keras forwards it to nested layers through its call context.
        x = self.act(self.bn(self.ff(x)))
        x = self.reshape(x)
        for upblock in self.upblocks:
            x = upblock(x)
        # NOTE(review): there is no output activation, so generated pixels are
        # unbounded while the preprocessing cell divides real images by 255 —
        # consider a bounded output (e.g. sigmoid) if real data lies in [0, 1].
        return self.to_image(x)

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "latent_dim": self.latent_dim,
                "start_res": self.start_res,
                "channels": self.channels,
                "image_channels": self.image_channels,
                "kernel_size": self.kernel_size,
                "activation": self.activation,
                "depth": self.depth,
            }
        )
        return config

@keras.saving.register_keras_serializable()
class Discriminator(keras.Model):
    """Scores a batch of images, returning one raw logit per image.

    Pipeline: one DownBlock per entry of ``channels`` (each ending in
    MaxPooling2D + Dropout) -> Flatten -> Dense(1).

    The final Dense has no activation: the notebook compiles the GAN with
    ``BinaryCrossentropy(from_logits=True)``, which applies the sigmoid itself.
    A sigmoid here as well would squash the scores twice and distort the loss
    and its gradients.

    Args:
        image_size: Expected input side length (stored for config; not used to
            build layers).
        image_channels: Expected input channels (stored for config).
        channels: Filters of the successive DownBlocks.
        kernel_size: Convolution kernel size in the DownBlocks.
        activation: Activation identifier used inside the DownBlocks.
        depth: ResNetBlocks per DownBlock.
        dropout: Dropout rate used by every DownBlock.
        **kwargs: Forwarded to ``keras.Model``.
    """

    def __init__(self, image_size, image_channels, channels, kernel_size, activation, depth, dropout, **kwargs):
        super().__init__(**kwargs)
        self.image_size = image_size
        self.image_channels = image_channels
        self.channels = channels
        self.kernel_size = kernel_size
        self.activation = activation
        self.depth = depth
        self.dropout = dropout

    def build(self, input_shape):
        self.downblocks = []
        for filters in self.channels:
            self.downblocks.append(
                DownBlock(filters, self.kernel_size, self.activation, depth=self.depth, dropout=self.dropout)
            )

        self.flatten = layers.Flatten()
        # Logit output (no sigmoid); see the class docstring. Named `classifier`
        # rather than `outputs`, an attribute name keras uses on functional models.
        self.classifier = layers.Dense(1)

    def call(self, x, training=None):
        # `training` is accepted so callers (e.g. a GAN train_step) can enable
        # BatchNormalization/Dropout training behaviour; Keras forwards it to
        # nested layers through its call context.
        for downblock in self.downblocks:
            x = downblock(x)
        return self.classifier(self.flatten(x))

    def get_config(self):
        # Include hyperparameters in the model configuration.
        config = super().get_config()
        config.update(
            {
                "image_size": self.image_size,
                "image_channels": self.image_channels,
                "channels": self.channels,
                "kernel_size": self.kernel_size,
                "activation": self.activation,
                "depth": self.depth,
                "dropout": self.dropout,
            }
        )
        return config

In [None]:
# Instantiate both sub-models from the hyper-parameter dictionaries.
generator, discriminator = Generator(**generator_config), Discriminator(**discriminator_config)

In [None]:
class GenerativeAdversarialNetwork(keras.Model):
    """Adversarial training wrapper around a generator and a discriminator.

    ``fit`` drives ``train_step``, which performs one discriminator update and
    one generator update per batch; ``evaluate`` drives ``test_step``, which
    computes the same two losses without updating weights. Label convention:
    generated images -> 1, real images -> 0, and the generator is trained to
    make the discriminator output 0 ("real") on its samples.

    Args:
        generator: Model mapping ``(batch, latent_dim)`` noise to images.
        discriminator: Model mapping a batch of images to one score per image.
        latent_dim: Length of the noise vectors fed to the generator.
        **kwargs: Forwarded to ``keras.Model``.
    """

    def __init__(self, generator, discriminator, latent_dim, **kwargs):
        super().__init__(**kwargs)
        self.generator = generator
        self.discriminator = discriminator
        self.latent_dim = latent_dim
        # Seeded RNG state for the latent vectors drawn in train/test steps.
        self.seed_generator = keras.random.SeedGenerator(1337)
        # Fixed batch of 25 latent vectors (seed 42). Not used in this class;
        # presumably kept for visualising progress on constant noise — TODO confirm.
        self.train_prog_noise = tf.random.normal(shape=(25, latent_dim), seed=42)

    def compile(self, d_optimizer, g_optimizer, loss_fn):
        """Attach one optimizer per network and a shared loss; create loss trackers."""
        super().compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.loss_fn = loss_fn
        self.d_loss_metric = keras.metrics.Mean(name="d_loss")
        self.g_loss_metric = keras.metrics.Mean(name="g_loss")

    @property
    def metrics(self):
        # Keras resets the metrics listed here at the start of each epoch and
        # each evaluation. Only valid after compile().
        return [self.d_loss_metric, self.g_loss_metric]

    def _sample_latent(self, batch_size):
        """Draw ``(batch_size, latent_dim)`` standard-normal vectors from the seeded RNG."""
        return keras.random.normal(
            shape=(batch_size, self.latent_dim), seed=self.seed_generator
        )

    @staticmethod
    def _discriminator_labels(batch_size):
        """Targets for ``concatenate([generated, real])``: 1 for generated, 0 for real."""
        return ops.concatenate(
            [ops.ones((batch_size, 1)), ops.zeros((batch_size, 1))], axis=0
        )

    def _update_metrics(self, d_loss, g_loss):
        """Record both losses and return the running means in Keras' log format."""
        self.d_loss_metric.update_state(d_loss)
        self.g_loss_metric.update_state(g_loss)
        return {
            "d_loss": self.d_loss_metric.result(),
            "g_loss": self.g_loss_metric.result(),
        }

    def train_step(self, real_images):
        batch_size = ops.shape(real_images)[0]

        # ---- Discriminator update -------------------------------------------
        # training=True is passed explicitly (as Keras' built-in train_step does)
        # so BatchNormalization/Dropout in the sub-models run in training mode; a
        # bare call from a custom train_step does not enable it.
        # NOTE(review): Keras only passes this flag on when the sub-model's call()
        # accepts a `training` argument — verify Generator/Discriminator do.
        generated_images = self.generator(self._sample_latent(batch_size), training=True)
        combined_images = ops.concatenate([generated_images, real_images], axis=0)

        labels = self._discriminator_labels(batch_size)
        # Small uniform label noise, a common GAN stabilisation trick.
        labels += 0.05 * tf.random.uniform(tf.shape(labels))

        with tf.GradientTape() as tape:
            predictions = self.discriminator(combined_images, training=True)
            d_loss = self.loss_fn(labels, predictions)
        grads = tape.gradient(d_loss, self.discriminator.trainable_weights)
        self.d_optimizer.apply_gradients(
            zip(grads, self.discriminator.trainable_weights)
        )

        # ---- Generator update -----------------------------------------------
        # Fresh noise; the generator is rewarded when its samples are scored as
        # "real" (label 0). Only generator weights are updated here.
        random_latent_vectors = self._sample_latent(batch_size)
        misleading_labels = ops.zeros((batch_size, 1))

        with tf.GradientTape() as tape:
            predictions = self.discriminator(
                self.generator(random_latent_vectors, training=True), training=True
            )
            g_loss = self.loss_fn(misleading_labels, predictions)
        grads = tape.gradient(g_loss, self.generator.trainable_weights)
        self.g_optimizer.apply_gradients(zip(grads, self.generator.trainable_weights))

        return self._update_metrics(d_loss, g_loss)

    def test_step(self, real_images):
        """Compute both losses on one batch in inference mode, without weight updates."""
        batch_size = ops.shape(real_images)[0]

        generated_images = self.generator(self._sample_latent(batch_size), training=False)
        combined_images = ops.concatenate([generated_images, real_images], axis=0)

        # No label noise at evaluation time so the losses are comparable across runs.
        d_loss = self.loss_fn(
            self._discriminator_labels(batch_size),
            self.discriminator(combined_images, training=False),
        )
        g_loss = self.loss_fn(
            ops.zeros((batch_size, 1)),
            self.discriminator(generated_images, training=False),
        )

        return self._update_metrics(d_loss, g_loss)

    def generate(self, num_images):
        """Generate ``num_images`` images from fresh, unseeded standard-normal noise."""
        eps = tf.random.normal(shape=(num_images, self.latent_dim))
        return self.generator(eps, training=False)

In [None]:
# The GAN draws noise of the same length the generator was configured with.
gan = GenerativeAdversarialNetwork(
    generator=generator,
    discriminator=discriminator,
    latent_dim=generator.latent_dim,
)

In [None]:
# Both networks use Adam(lr=2e-4, beta_1=0.5, beta_2=0.9). The loss is configured
# for raw logits (from_logits=True) with light label smoothing, so the
# discriminator's final layer must not apply a sigmoid of its own.
gan.compile(
    g_optimizer=keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5, beta_2=0.9),
    d_optimizer=keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5, beta_2=0.9),
    loss_fn=keras.losses.BinaryCrossentropy(from_logits=True, label_smoothing=0.04),
)

In [None]:
def preprocessing(x):
    """Resize images to the discriminator's input size and divide by 255.

    Assumes 8-bit pixel values, so the result lies in [0, 1] — TODO confirm
    for the dataset in use.
    """
    side = discriminator_config["image_size"]
    return tf.image.resize(x, (side, side)) / 255

# `labels=None` makes the dataset yield image batches only, which is what the
# GAN's train_step consumes. "celeba_hq" is resolved relative to the notebook's
# working directory.
dataset = (
    keras.preprocessing.image_dataset_from_directory(
        "celeba_hq",
        batch_size=32,
        labels=None,
    )
    # Run resize/rescale in parallel (order is preserved) and overlap input
    # loading with training so the accelerator is not left waiting on I/O.
    .map(preprocessing, num_parallel_calls=tf.data.AUTOTUNE)
    .prefetch(tf.data.AUTOTUNE)
)

In [None]:
# Train both networks for 10 passes over the image dataset.
gan.fit(x=dataset, epochs=10)