In [1]:
import keras

from keras import layers
from keras import ops
import tensorflow as tf
import numpy as np
import imageio
from tensorflow.keras.optimizers.schedules import ExponentialDecay

from tensorflow.keras.preprocessing.image import ImageDataGenerator

In [24]:
batch_size = 4
num_channels = 1
num_classes = 5
image_size = 512
latent_dim = 256 # TUNE

data_dir = "/kaggle/input/lung-ds/Full_slice/train" ###

generator_in_channels = latent_dim + num_classes
discriminator_in_channels = num_channels + num_classes

datagen = ImageDataGenerator(
    rescale=1./255
)

image_generator = datagen.flow_from_directory(
    data_dir,
    target_size=(image_size, image_size),
    batch_size=batch_size,
    class_mode='categorical',
    color_mode="grayscale",
    seed=42            
)

images, labels = next(image_generator)

Found 1890 images belonging to 5 classes.


In [37]:
output_dir = "/kaggle/working"

class CustomCallback(keras.callbacks.Callback):
    def __init__(self, generator):
        super().__init__()
        self.generator = generator
       
    def on_epoch_end(self, epoch, logs=None):
        global latent_dim, num_classes
        sample_latent = keras.random.normal(shape=(1, latent_dim), seed=42)
        sample_label = tf.one_hot([2], depth=num_classes)  # assuming class 2
        sample_input = tf.concat([sample_latent, sample_label], axis=1)

        generated_image = self.generator(sample_input, training=False)
        generated_image = (generated_image + 1.0) * 127.5
        generated_image = tf.cast(generated_image, tf.uint8).numpy()[0, :, :, 0]

        imageio.imwrite(f"{output_dir}/output_epoch_{epoch}.png", generated_image)
    

In [None]:
# Create the discriminator.
discriminator = keras.Sequential(
    [
        keras.layers.InputLayer((512, 512, discriminator_in_channels)),
        #layers.Conv2D(16, (4, 4), strides=(2, 2), padding="same"), 
        #layers.LeakyReLU(negative_slope=0.2),
        #layers.Dropout(0.4),
        
        layers.Conv2D(64, kernel_size=4, strides=2, padding='same'),
        layers.LeakyReLU(alpha=0.2),
        layers.Dropout(0.3),

        layers.Conv2D(128, kernel_size=4, strides=2, padding='same'),
        layers.BatchNormalization(),
        layers.LeakyReLU(alpha=0.2),
        layers.Dropout(0.3),

        layers.Conv2D(256, kernel_size=4, strides=2, padding='same'),
        layers.BatchNormalization(),
        layers.LeakyReLU(alpha=0.2),
        layers.Dropout(0.3),
        
        layers.Conv2D(512, kernel_size=4, strides=2, padding='same'),
        layers.BatchNormalization(),
        layers.LeakyReLU(alpha=0.2),
        layers.Dropout(0.3),

        layers.Conv2D(1024, kernel_size=4, strides=2, padding='same'),
        layers.BatchNormalization(),
        layers.LeakyReLU(alpha=0.2),
        layers.Dropout(0.3),

        layers.Flatten(),
        #layers.Dense(128, activation='relu'), # NOW
        layers.Dense(1, activation='sigmoid')
    ],
    name="discriminator",
)

# Create the generator.
generator = keras.Sequential(
    [
        #keras.layers.InputLayer((generator_in_channels,)),
        
        # Increase feature map size
        layers.Dense(8 * 8 * 512, input_dim=generator_in_channels, use_bias=False),
        layers.BatchNormalization(),
        layers.ReLU(),
        layers.Reshape((8, 8, 512)),

        # Upsample progressively to 512x512
        layers.Conv2DTranspose(256, kernel_size=4, strides=2, padding='same', use_bias=False),
        layers.BatchNormalization(),
        layers.ReLU(),

        layers.Conv2DTranspose(128, kernel_size=4, strides=2, padding='same', use_bias=False),
        layers.BatchNormalization(),
        layers.ReLU(),
        
        layers.Conv2DTranspose(64, kernel_size=4, strides=2, padding='same', use_bias=False),
        layers.BatchNormalization(),
        layers.ReLU(),
        
        layers.Conv2DTranspose(32, kernel_size=4, strides=2, padding='same', use_bias=False),
        layers.BatchNormalization(),
        layers.ReLU(),
        
        layers.Conv2DTranspose(16, kernel_size=4, strides=2, padding='same', use_bias=False),
        layers.BatchNormalization(),
        layers.ReLU(),

        layers.Conv2DTranspose(8, kernel_size=4, strides=2, padding='same', use_bias=False),
        layers.BatchNormalization(),
        layers.ReLU(),
        
        layers.Conv2DTranspose(1, kernel_size=3, strides=1, padding='same', activation='tanh')
    ],
    name="generator",
)

class ConditionalGAN(keras.Model):
    def __init__(self, discriminator, generator, latent_dim):
        super().__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.latent_dim = latent_dim
        self.seed_generator = keras.random.SeedGenerator(42)
        self.gen_loss_tracker = keras.metrics.Mean(name="generator_loss")
        self.disc_loss_tracker = keras.metrics.Mean(name="discriminator_loss")

    @property
    def metrics(self):
        return [self.gen_loss_tracker, self.disc_loss_tracker]

    def compile(self, d_optimizer, g_optimizer, loss_fn):
        super().compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.loss_fn = loss_fn

    def train_step(self, data):
        global n_epoch, output_dir
        # Unpack the data.
        real_images, one_hot_labels = data

        # Add dummy dimensions to the labels so that they can be concatenated with
        # the images. This is for the discriminator.
        image_one_hot_labels = one_hot_labels[:, :, None, None]
        image_one_hot_labels = ops.repeat(
            image_one_hot_labels, repeats=[image_size * image_size]
        )
        image_one_hot_labels = ops.reshape(
            image_one_hot_labels, (-1, image_size, image_size, num_classes)
        )

        # Sample random points in the latent space and concatenate the labels.
        # This is for the generator.
        batch_size = ops.shape(real_images)[0]
        random_latent_vectors = keras.random.normal(
            shape=(batch_size, self.latent_dim), seed=self.seed_generator
        )
        random_vector_labels = ops.concatenate(
            [random_latent_vectors, one_hot_labels], axis=1
        )

        # Decode the noise (guided by labels) to fake images.
        generated_images = self.generator(random_vector_labels)

        # Combine them with real images. Note that we are concatenating the labels
        # with these images here.
        fake_image_and_labels = ops.concatenate(
            [generated_images, image_one_hot_labels], -1
        )
        real_image_and_labels = ops.concatenate([real_images, image_one_hot_labels], -1)
        combined_images = ops.concatenate(
            [fake_image_and_labels, real_image_and_labels], axis=0
        )

        # Assemble labels discriminating real from fake images.
        labels = ops.concatenate(
            [ops.ones((batch_size, 1)), ops.zeros((batch_size, 1))], axis=0
        )

        # Train the discriminator.
        with tf.GradientTape() as tape:
            predictions = self.discriminator(combined_images)
            d_loss = self.loss_fn(labels, predictions)
        grads = tape.gradient(d_loss, self.discriminator.trainable_weights)
        self.d_optimizer.apply_gradients(
            zip(grads, self.discriminator.trainable_weights)
        )

        # Sample random points in the latent space.
        random_latent_vectors = keras.random.normal(
            shape=(batch_size, self.latent_dim), seed=self.seed_generator
        )
        random_vector_labels = ops.concatenate(
            [random_latent_vectors, one_hot_labels], axis=1
        )

        # Assemble labels that say "all real images".
        misleading_labels = ops.zeros((batch_size, 1))

        # Train the generator (note that we should *not* update the weights
        # of the discriminator)!
        with tf.GradientTape() as tape:
            fake_images = self.generator(random_vector_labels)
            fake_image_and_labels = ops.concatenate(
                [fake_images, image_one_hot_labels], -1
            )
            predictions = self.discriminator(fake_image_and_labels)
            g_loss = self.loss_fn(misleading_labels, predictions)
        grads = tape.gradient(g_loss, self.generator.trainable_weights)
        self.g_optimizer.apply_gradients(zip(grads, self.generator.trainable_weights))
        
        # Monitor loss.
        self.gen_loss_tracker.update_state(g_loss)
        self.disc_loss_tracker.update_state(d_loss)
        return {
            "g_loss": self.gen_loss_tracker.result(),
            "d_loss": self.disc_loss_tracker.result(),
        }


cond_gan = ConditionalGAN(
    discriminator=discriminator, generator=generator, latent_dim=latent_dim
)

print_callback = CustomCallback(generator=generator)

initial_lr_generator = 0.0002
lr_schedule_generator = ExponentialDecay(
    initial_learning_rate=initial_lr_generator,
    decay_steps=750,    # Adjust based on your training iterations
    decay_rate=0.96,
    staircase=True
)
generator_optimizer = keras.optimizers.Adam(learning_rate=initial_lr_generator, beta_1=0.5)

initial_lr_discriminator = 0.0001
lr_schedule_discriminator = ExponentialDecay(
    initial_learning_rate=initial_lr_discriminator,
    decay_steps=750,    # Adjust as needed
    decay_rate=0.96,
    staircase=True
)
discriminator_optimizer = keras.optimizers.Adam(learning_rate=initial_lr_discriminator, beta_1=0.5)


cond_gan.compile(
    d_optimizer=discriminator_optimizer,
    g_optimizer=generator_optimizer,
    loss_fn=keras.losses.BinaryCrossentropy(from_logits=True),
)

cond_gan.fit(image_generator, epochs=200, callbacks=[print_callback])

trained_gen = cond_gan.generator

# Choose the number of intermediate images that would be generated in
# between the interpolation + 2 (start and last images).
num_interpolation = 9  # @param {type:"integer"}

# Sample noise for the interpolation.
interpolation_noise = keras.random.normal(shape=(1, latent_dim))
interpolation_noise = ops.repeat(interpolation_noise, repeats=num_interpolation)
interpolation_noise = ops.reshape(interpolation_noise, (num_interpolation, latent_dim))

Epoch 1/200
[1m473/473[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m70s[0m 126ms/step - d_loss: 0.3491 - g_loss: 3.2003
Epoch 2/200
[1m473/473[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m55s[0m 117ms/step - d_loss: 0.3492 - g_loss: 3.0416
Epoch 3/200
[1m473/473[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m55s[0m 117ms/step - d_loss: 0.1704 - g_loss: 3.6422
Epoch 4/200
[1m473/473[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m55s[0m 117ms/step - d_loss: 0.1722 - g_loss: 4.1231
Epoch 5/200
[1m473/473[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m55s[0m 117ms/step - d_loss: 0.1297 - g_loss: 4.4670
Epoch 6/200
[1m473/473[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m55s[0m 117ms/step - d_loss: 0.1469 - g_loss: 3.6616
Epoch 7/200
[1m473/473[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m55s[0m 117ms/step - d_loss: 0.1818 - g_loss: 3.8173
Epoch 8/200
[1m473/473[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m55s[0m 117ms/step - d_loss: 0.6251 - g_loss: 4.1780
Epoch 9/

In [4]:
def interpolate_class(first_number, second_number):
    # Convert the start and end labels to one-hot encoded vectors.
    first_label = keras.utils.to_categorical([first_number], num_classes)
    second_label = keras.utils.to_categorical([second_number], num_classes)
    first_label = ops.cast(first_label, "float32")
    second_label = ops.cast(second_label, "float32")

    # Calculate the interpolation vector between the two labels.
    percent_second_label = ops.linspace(0, 1, num_interpolation)[:, None]
    percent_second_label = ops.cast(percent_second_label, "float32")
    interpolation_labels = (
        first_label * (1 - percent_second_label) + second_label * percent_second_label
    )

    # Combine the noise and the labels and run inference with the generator.
    noise_and_labels = ops.concatenate([interpolation_noise, interpolation_labels], 1)
    fake = trained_gen.predict(noise_and_labels)
    return fake


start_class = 0  # @param {type:"slider", min:0, max:9, step:1}
end_class = 4  # @param {type:"slider", min:0, max:9, step:1}

fake_images = interpolate_class(start_class, end_class)

fake_images *= 255.0
converted_images = fake_images.astype(np.uint8)
converted_images = ops.image.resize(converted_images, (512, 512)).numpy().astype(np.uint8)
imageio.mimsave("animation.gif", converted_images[:, :, :, 0], fps=1)

[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 738ms/step
