# Preparation

In [1]:
seed_value= 42

import random
random.seed(seed_value)

import numpy as np
np.random.seed(seed_value)

import tensorflow as tf
tf.config.run_functions_eagerly(True)
tf.random.set_seed(seed_value)
tf.keras.utils.set_random_seed(seed_value)

In [2]:
import os

os.environ["CUDA_VISIBLE_DEVICES"]="0"
print(tf.__version__)
print(tf.config.list_physical_devices('GPU'))

2.9.0
[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]


In [3]:
from tensorflow import keras
from tensorflow.keras import layers
from gudhi.tensorflow import CubicalLayer

In [4]:
num_epochs  = 20
batch_size  = 32
num_classes = 10
shape       = (28, 28, 1)
lr          = 0.0003
opt         = keras.optimizers.Adam(learning_rate=lr)
los         = keras.losses.BinaryCrossentropy(from_logits=True)
latent_dim  = 28

# Dataset

In [5]:
def prepare_data(main_path):
    with np.load(main_path) as data:
        x_train, y_train = data['x_train'], data['y_train']
    
    x_train = x_train.astype("float32") / 255.0
    x_train = np.reshape(x_train, (-1, 28, 28, 1))
    y_train = keras.utils.to_categorical(y_train, 10)
    train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
    train_dataset = train_dataset.shuffle(buffer_size=1024).batch(batch_size)
    
    return train_dataset

generator_in_channels = latent_dim + num_classes
discriminator_in_channels = shape[2] + num_classes

train_dataset = prepare_data(main_path="../Dataset/mnist.npz")

# Model

In [6]:
def get_models():
    discriminator = keras.Sequential([keras.layers.InputLayer((shape[0], shape[1], discriminator_in_channels)),
                                      layers.Conv2D(64, (3, 3), strides=(2, 2), padding="same"),
                                      layers.LeakyReLU(alpha=0.2),
                                      layers.Conv2D(128, (3, 3), strides=(2, 2), padding="same"),
                                      layers.LeakyReLU(alpha=0.2),
                                      layers.GlobalMaxPooling2D(),
                                      layers.Dense(1)],
                                     name="discriminator")
    
    generator = keras.Sequential([keras.layers.InputLayer((generator_in_channels,)),
                                  layers.Dense(7 * 7 * generator_in_channels),
                                  layers.LeakyReLU(alpha=0.2),
                                  layers.Reshape((7, 7, generator_in_channels)),
                                  layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"),
                                  layers.LeakyReLU(alpha=0.2),
                                  layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"),
                                  layers.LeakyReLU(alpha=0.2),
                                  layers.Conv2D(1, (7, 7), padding="same", activation="sigmoid")],
                                 name="generator")
    
    return discriminator, generator

In [7]:
class ConditionalGAN(keras.Model):
    def __init__(self, discriminator, generator, latent_dim, accumulation=False):
        super().__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.latent_dim = latent_dim
        self.accumulation = accumulation
        self.seed_generator = tf.random.Generator.from_seed(seed_value)
        self.gen_loss_tracker = keras.metrics.Mean(name="generator_loss")
        self.disc_loss_tracker = keras.metrics.Mean(name="discriminator_loss")
        self.cubical_layer = CubicalLayer(homology_dimensions=[0, 1, 2])
        
        self.digits = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
        
        self.ph_real = {str(i): tf.zeros((0, 2), dtype=tf.float32) for i in range(10)}
        self.ph_fake = {str(i): tf.zeros((0, 2), dtype=tf.float32) for i in range(10)}


    def SWD(self, PD1, PD2, num_projections=50):
        angles = tf.random.uniform([num_projections, 2], minval=-1, maxval=1)
        angles /= tf.norm(angles, axis=-1, keepdims=True)
        
        proj1 = tf.linalg.matmul(PD1, tf.transpose(angles))
        proj2 = tf.linalg.matmul(PD2, tf.transpose(angles))
        
        proj1 = tf.sort(proj1, axis=0)
        proj2 = tf.sort(proj2, axis=0)
        
        target_size = tf.maximum(tf.shape(proj1)[0], tf.shape(proj2)[0])
        proj1 = tf.image.resize(proj1[None, :, :], [target_size, num_projections])[0]
        proj2 = tf.image.resize(proj2[None, :, :], [target_size, num_projections])[0]
    
        return tf.reduce_mean(tf.abs(proj1 - proj2))
    
    
    def process_digit(self, indices, X_real, X_fake):
        real_subset = tf.gather(X_real, indices)
        fake_subset = tf.gather(X_fake, indices)
    
        real_dgms = self.cubical_layer.call(real_subset)
        fake_dgms = self.cubical_layer.call(fake_subset)
        
        real_concat = tf.concat([real_dgms[0][0], real_dgms[1][0], real_dgms[2][0]], axis=0)
        fake_concat = tf.concat([fake_dgms[0][0], fake_dgms[1][0], fake_dgms[2][0]], axis=0)
    
        return real_concat, fake_concat
    
    
    def skip_digit(self):
        return tf.zeros((0, 2), dtype=tf.float32), tf.zeros((0, 2), dtype=tf.float32)


    @property
    def metrics(self):
        return [self.gen_loss_tracker, self.disc_loss_tracker]


    def compile(self, d_optimizer, g_optimizer, loss_fn):
        super().compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.loss_fn = loss_fn


    def train_step(self, data):
        real_images, one_hot_labels = data
        
        # Reshape and repeat one-hot labels to match the image dimensions
        image_one_hot_labels = one_hot_labels[:, :, None, None]
        image_one_hot_labels = tf.repeat(image_one_hot_labels, repeats=[real_images.shape[1] * real_images.shape[2]], axis=-1)
        image_one_hot_labels = tf.reshape(image_one_hot_labels, (-1, real_images.shape[1], real_images.shape[2], one_hot_labels.shape[-1]))
    
        # Batch size
        batch_size = tf.shape(real_images)[0]
        
        # Generate random latent vectors
        random_latent_vectors = self.seed_generator.normal(shape=(batch_size, self.latent_dim), dtype=tf.float32)
        random_vector_labels = tf.concat([random_latent_vectors, one_hot_labels], axis=1)
    
        # Generate fake images
        generated_images = self.generator(random_vector_labels)
    
        # Combine fake and real images with their labels
        fake_image_and_labels = tf.concat([generated_images, image_one_hot_labels], axis=-1)
        real_image_and_labels = tf.concat([real_images, image_one_hot_labels], axis=-1)
        combined_images = tf.concat([fake_image_and_labels, real_image_and_labels], axis=0)
    
        # Labels for discriminator
        labels = tf.concat([tf.ones((batch_size, 1)), tf.zeros((batch_size, 1))], axis=0)
    
        # Train discriminator
        with tf.GradientTape() as tape:
            predictions = self.discriminator(combined_images)
            d_loss = self.loss_fn(labels, predictions)
        grads = tape.gradient(d_loss, self.discriminator.trainable_weights)
        self.d_optimizer.apply_gradients(zip(grads, self.discriminator.trainable_weights))
    
        # Train generator
        misleading_labels = tf.zeros((batch_size, 1))
    
        with tf.GradientTape() as tape:
            fake_images = self.generator(random_vector_labels)
            fake_image_and_labels = tf.concat([fake_images, image_one_hot_labels], axis=-1)
            predictions = self.discriminator(fake_image_and_labels)
            g_loss = self.loss_fn(misleading_labels, predictions)
        
            # Calculate Persistent Homology
            y_train = tf.argmax(one_hot_labels, axis=1)
            X_real = (real_images - tf.reduce_min(real_images)) / (tf.reduce_max(real_images) - tf.reduce_min(real_images))
            X_fake = (fake_images - tf.reduce_min(fake_images)) / (tf.reduce_max(fake_images) - tf.reduce_min(fake_images))
            
            ph_losses = 0
            valid_digits = 0
            for digit in self.digits:
                indices = tf.where(y_train == digit)[:, 0]
                
                if tf.size(indices) == 0:
                    continue
                else:
                    real_concat, fake_concat = self.process_digit(indices, X_real, X_fake)
                
                if self.accumulation:
                    self.ph_real[str(digit)] = tf.concat([self.ph_real[str(digit)], real_concat], axis=0)
                    self.ph_fake[str(digit)] = tf.concat([self.ph_fake[str(digit)], fake_concat], axis=0)
                    ph_losses += self.SWD(self.ph_real[str(digit)], self.ph_fake[str(digit)])
                else:
                    ph_losses += self.SWD(real_concat, fake_concat)
                    
                valid_digits += 1
                
            ph_losses = ph_losses / valid_digits
            
            g_total_loss = g_loss + ph_losses
            
        grads = tape.gradient(g_total_loss, self.generator.trainable_weights)
        self.g_optimizer.apply_gradients(zip(grads, self.generator.trainable_weights))
    
        # Update loss trackers
        self.gen_loss_tracker.update_state(g_total_loss)
        self.disc_loss_tracker.update_state(d_loss)
        return {
            "loss": self.gen_loss_tracker.result() + self.disc_loss_tracker.result(),
            "g_loss": self.gen_loss_tracker.result(),
            "d_loss": self.disc_loss_tracker.result(),
        }

In [8]:
discriminator, generator = get_models()
cond_gan = ConditionalGAN(discriminator=discriminator, generator=generator, latent_dim=latent_dim)
cond_gan.compile(d_optimizer=opt, g_optimizer=opt, loss_fn=los)

# Train

In [None]:
cond_gan.fit(train_dataset, epochs=num_epochs, batch_size=batch_size, verbose=1)

Epoch 1/20
   5/1875 [..............................] - ETA: 1:27:38 - loss: 1.4477 - g_loss: 0.7674 - d_loss: 0.6803