# Preparation

In [3]:
seed_value= 42

import random
random.seed(seed_value)

import numpy as np
np.random.seed(seed_value)

import tensorflow as tf
tf.random.set_seed(seed_value)
tf.keras.utils.set_random_seed(seed_value)

In [4]:
import os

os.environ["CUDA_VISIBLE_DEVICES"]="0"
print(tf.__version__)
print(tf.config.list_physical_devices('GPU'))

2.9.0
[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]


In [5]:
from tensorflow import keras
from tensorflow.keras import layers
from gudhi.tensorflow import CubicalLayer

In [6]:
num_epochs  = 20
batch_size  = 32
num_classes = 10
shape       = (28, 28, 1)
lr          = 0.0003
opt         = keras.optimizers.Adam(learning_rate=lr)
los         = keras.losses.BinaryCrossentropy(from_logits=True)
latent_dim  = 28

# Dataset

In [7]:
def prepare_data(main_path):
    with np.load(main_path) as data:
        x_train, y_train = data['x_train'], data['y_train']
    
    x_train = x_train.astype("float32") / 255.0
    x_train = np.reshape(x_train, (-1, 28, 28, 1))
    y_train = keras.utils.to_categorical(y_train, 10)
    train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
    train_dataset = train_dataset.shuffle(buffer_size=1024).batch(batch_size)
    
    return train_dataset

generator_in_channels = latent_dim + num_classes
discriminator_in_channels = shape[2] + num_classes

train_dataset = prepare_data(main_path="../Dataset/mnist.npz")

# Model

In [8]:
def get_models():
    discriminator = keras.Sequential([keras.layers.InputLayer((shape[0], shape[1], discriminator_in_channels)),
                                      layers.Conv2D(64, (3, 3), strides=(2, 2), padding="same"),
                                      layers.LeakyReLU(alpha=0.2),
                                      layers.Conv2D(128, (3, 3), strides=(2, 2), padding="same"),
                                      layers.LeakyReLU(alpha=0.2),
                                      layers.GlobalMaxPooling2D(),
                                      layers.Dense(1)],
                                     name="discriminator")
    
    generator = keras.Sequential([keras.layers.InputLayer((generator_in_channels,)),
                                  layers.Dense(7 * 7 * generator_in_channels),
                                  layers.LeakyReLU(alpha=0.2),
                                  layers.Reshape((7, 7, generator_in_channels)),
                                  layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"),
                                  layers.LeakyReLU(alpha=0.2),
                                  layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"),
                                  layers.LeakyReLU(alpha=0.2),
                                  layers.Conv2D(1, (7, 7), padding="same", activation="sigmoid")],
                                 name="generator")
    
    return discriminator, generator

In [13]:
class PersistenceSetLoss(tf.keras.losses.Loss):
    def __init__(self, gamma=1.0, mu=0.1, max_value=2000000, scale=1e-1, **kwargs):
        """
        Custom loss function for comparing persistence diagrams of different sizes.
        
        Parameters:
        - gamma: Controls how much long-lived features are weighted.
        - mu: Controls how much we penalize differences in diagram sizes.
        - max_value: Maximum value for the final loss function.
        - scale: The scale that multiplies on the final loss value.
        """
        super().__init__(**kwargs)
        self.gamma = gamma
        self.mu = mu
        self.max_value = max_value
        self.scale = scale
        
        self.cubical_layer = CubicalLayer(homology_dimensions=[0, 1, 2])
        
    def call_cubical_layer(self, X):
        output = self.cubical_layer.call(X)
        
        flattened_output = []
        for tuple_ in output:
            flattened_output.extend(tuple_[0])
        
        return flattened_output
        
    @tf.custom_gradient
    def call(self, real_images, fake_images):
        """
        Computes the loss between two persistence diagrams P and Q.
        
        P: Tensor -> Persistence diagram with m points (birth, death)
        Q: Tensor -> Persistence diagram with n points (birth, death)
        """
        X_real = (real_images - tf.reduce_min(real_images)) / (tf.reduce_max(real_images) - tf.reduce_min(real_images))
        X_fake = (fake_images - tf.reduce_min(fake_images)) / (tf.reduce_max(fake_images) - tf.reduce_min(fake_images))
        
        real_dgms = tf.py_function(self.call_cubical_layer, [X_real], Tout=[tf.float32] * 6)
        fake_dgms = tf.py_function(self.call_cubical_layer, [X_fake], Tout=[tf.float32] * 6)
        
        real_dgms = tf.concat(real_dgms, axis=0)
        P = tf.reshape(real_dgms, (-1, 2))
        
        fake_dgms = tf.concat(fake_dgms, axis=0)
        Q = tf.reshape(fake_dgms, (-1, 2))
        
        P_persistence = P[:, 1] - P[:, 0]
        Q_persistence = Q[:, 1] - Q[:, 0]

        def feature_transform(persistence):
            exp_term = tf.exp(-self.gamma * persistence)
            return tf.stack([exp_term, persistence * exp_term], axis=-1)

        P_features = feature_transform(P_persistence)
        Q_features = feature_transform(Q_persistence)

        P_sum = tf.reduce_sum(P_features, axis=0)
        Q_sum = tf.reduce_sum(Q_features, axis=0)

        feature_distance = tf.reduce_sum(tf.square(P_sum - Q_sum))

        size_difference = tf.cast(tf.shape(P)[0] - tf.shape(Q)[0], tf.float32)
        size_penalty = self.mu * tf.square(size_difference)        

        ph_loss = tf.clip_by_value(feature_distance + size_penalty, 0.0, self.max_value) * self.scale
        
        def grad(dy):
            grad_fake_images = tf.gradients(ph_loss, fake_images, grad_ys=dy)[0]
            return None, grad_fake_images
    
        return ph_loss, grad

In [14]:
class ConditionalGAN(keras.Model):
    def __init__(self, discriminator, generator, latent_dim):
        super().__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.latent_dim = latent_dim
        self.seed_generator = tf.random.Generator.from_seed(seed_value)
        self.gen_loss_tracker = keras.metrics.Mean(name="generator_loss")
        self.disc_loss_tracker = keras.metrics.Mean(name="discriminator_loss")   
        self.pd_loss = PersistenceSetLoss(scale=100)


    @property
    def metrics(self):
        return [self.gen_loss_tracker, self.disc_loss_tracker]


    def compile(self, d_optimizer, g_optimizer, loss_fn):
        super().compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.loss_fn = loss_fn


    def train_step(self, data):
        real_images, one_hot_labels = data
        
        # Reshape and repeat one-hot labels to match the image dimensions
        image_one_hot_labels = one_hot_labels[:, :, None, None]
        image_one_hot_labels = tf.repeat(image_one_hot_labels, repeats=[real_images.shape[1] * real_images.shape[2]], axis=-1)
        image_one_hot_labels = tf.reshape(image_one_hot_labels, (-1, real_images.shape[1], real_images.shape[2], one_hot_labels.shape[-1]))
    
        # Batch size
        batch_size = tf.shape(real_images)[0]
        
        # Generate random latent vectors
        random_latent_vectors = self.seed_generator.normal(shape=(batch_size, self.latent_dim), dtype=tf.float32)
        random_vector_labels = tf.concat([random_latent_vectors, one_hot_labels], axis=1)
    
        # Generate fake images
        generated_images = self.generator(random_vector_labels)
    
        # Combine fake and real images with their labels
        fake_image_and_labels = tf.concat([generated_images, image_one_hot_labels], axis=-1)
        real_image_and_labels = tf.concat([real_images, image_one_hot_labels], axis=-1)
        combined_images = tf.concat([fake_image_and_labels, real_image_and_labels], axis=0)
    
        # Labels for discriminator
        labels = tf.concat([tf.ones((batch_size, 1)), tf.zeros((batch_size, 1))], axis=0)
    
        # Train discriminator
        with tf.GradientTape() as tape:
            predictions = self.discriminator(combined_images)
            d_loss = self.loss_fn(labels, predictions)
        grads = tape.gradient(d_loss, self.discriminator.trainable_weights)
        self.d_optimizer.apply_gradients(zip(grads, self.discriminator.trainable_weights))
    
        # Train generator
        misleading_labels = tf.zeros((batch_size, 1))
    
        with tf.GradientTape() as tape:
            fake_images = self.generator(random_vector_labels)
            fake_image_and_labels = tf.concat([fake_images, image_one_hot_labels], axis=-1)
            predictions = self.discriminator(fake_image_and_labels)
            g_loss = self.loss_fn(misleading_labels, predictions)
        
            ph_losses = self.pd_loss(real_images, fake_images)
            g_total_loss = g_loss + ph_losses
            
        grads = tape.gradient(g_total_loss, self.generator.trainable_weights)
        self.g_optimizer.apply_gradients(zip(grads, self.generator.trainable_weights))
    
        # Update loss trackers
        self.gen_loss_tracker.update_state(g_total_loss)
        self.disc_loss_tracker.update_state(d_loss)
        return {
            "loss": self.gen_loss_tracker.result() + self.disc_loss_tracker.result(),
            "g_loss": self.gen_loss_tracker.result(),
            "d_loss": self.disc_loss_tracker.result(),
        }

In [15]:
discriminator, generator = get_models()
cond_gan = ConditionalGAN(discriminator=discriminator, generator=generator, latent_dim=latent_dim)
cond_gan.compile(d_optimizer=opt, g_optimizer=opt, loss_fn=los)

# Train

In [None]:
cond_gan.fit(train_dataset, epochs=num_epochs, batch_size=batch_size, verbose=1)

Epoch 1/20
   2/1875 [..............................] - ETA: 2:40:40 - loss: 2.3849 - g_loss: 1.7010 - d_loss: 0.6838