In [None]:
import numpy as np
import tensorflow as tf
from tqdm import tqdm
from matplotlib import pyplot as plt

# Restrict TensorFlow to the first GPU, if any is present.
gpu_devices = tf.config.list_physical_devices('GPU')
if gpu_devices:
    print("Using GPU")
    try:
        tf.config.set_visible_devices(gpu_devices[0], 'GPU')
    except RuntimeError as err:
        # Visible devices can only be changed before TensorFlow initializes them;
        # re-running this cell in a live kernel lands here instead of crashing.
        print(f"Could not change visible GPU devices: {err}")

### Load the training images

In [None]:
# Data-loading configuration.
batch_size = 256        # real images per training batch
image_size = (32, 32)   # (height, width) every image is resized to
image_dir = '../data'   # root directory handed to flow_from_directory

def custom_preprocess(x):
    """Linearly map pixel values from the [0, 255] range onto [-1, 1].

    Works element-wise on scalars or arrays; 0 -> -1, 127.5 -> 0, 255 -> 1.
    """
    half_range = 127.5
    centered = x - half_range
    return centered / half_range

# Loader that scales pixels with custom_preprocess, randomly flips images
# horizontally, and holds out 20% of the files as a validation subset.
image_generator = tf.keras.preprocessing.image.ImageDataGenerator(
    preprocessing_function=custom_preprocess,  # rescale [0.0 - 255.0] to [-1.0 to 1.0] range
    validation_split=0.2,
    horizontal_flip=True,
)

# Iterator over the training subset; every item it yields is already a batch of
# images (class_mode=None -> no labels).
train_data = image_generator.flow_from_directory(
    image_dir,
    target_size=image_size,
    batch_size=batch_size,
    class_mode=None,  # No labels needed for GAN training
    subset='training',
    shuffle=True,  # explicit: the iterator itself shuffles the file order
)

# tf.data view of the same iterator. Its elements are already batches, so the
# pipeline must NOT call .batch() again (that would stack batches into a rank-5
# tensor that contradicts output_signature), and a batch-level shuffle buffer is
# redundant with shuffle=True above (and would hold up to 1000 whole batches).
# NOTE(review): the training loop draws from `train_data` directly; this dataset
# is currently unused.
train_dataset = tf.data.Dataset.from_generator(
    lambda: train_data,
    output_signature=tf.TensorSpec(shape=(None, image_size[0], image_size[1], 3), dtype=tf.float32),  # RGB images so 3 channels, float32 for GPU acceleration
)
train_dataset = train_dataset.prefetch(buffer_size=tf.data.AUTOTUNE)

In [None]:
class Generator(tf.keras.Model):
    """Convolutional generator mapping a latent noise vector to a 32x32 RGB image.

    The network is built as a functional model stored in ``self.generator``:
    Dense(256*4*4) -> reshape to 4x4x256, then three stages of
    (UpSampling2D x2 -> 5x5 Conv2DTranspose) with 256, 128 and 32 filters
    (8x8 -> 16x16 -> 32x32), then a 1x1 Conv2DTranspose down to 3 channels.
    Hidden layers use LeakyReLU(0.2).

    Args:
        latent_size: length of each input noise vector.
        output_activation: activation of the final 3-channel layer. ``None``
            (default) keeps a linear, unbounded output; ``'tanh'`` bounds the
            output to [-1, 1], the same range custom_preprocess maps real images to.
    """

    def __init__(self, latent_size, output_activation=None):
        super().__init__()
        input_layer = tf.keras.layers.Input(shape=(latent_size, ), dtype=tf.float32)

        dense_layer = tf.keras.layers.Dense(256 * 4 * 4)(input_layer)
        leaky_relu_1 = tf.keras.layers.LeakyReLU(0.2)(dense_layer)
        reshape_1 = tf.keras.layers.Reshape((4, 4, 256))(leaky_relu_1)
        upsample_1 = tf.keras.layers.UpSampling2D(size=(2, 2))(reshape_1)  # 8x8

        conv_1 = tf.keras.layers.Conv2DTranspose(256, kernel_size=5, padding="same")(upsample_1)
        leaky_relu_2 = tf.keras.layers.LeakyReLU(0.2)(conv_1)
        upsample_2 = tf.keras.layers.UpSampling2D(size=(2, 2))(leaky_relu_2)  # 16x16

        conv_2 = tf.keras.layers.Conv2DTranspose(128, kernel_size=5, padding="same")(upsample_2)
        leaky_relu_3 = tf.keras.layers.LeakyReLU(0.2)(conv_2)
        upsample_3 = tf.keras.layers.UpSampling2D(size=(2, 2))(leaky_relu_3)  # 32x32

        conv_3 = tf.keras.layers.Conv2DTranspose(32, kernel_size=5, padding="same")(upsample_3)
        leaky_relu_4 = tf.keras.layers.LeakyReLU(0.2)(conv_3)

        # 1x1 projection to RGB.
        output_layer = tf.keras.layers.Conv2DTranspose(
            3, kernel_size=1, padding='same', activation=output_activation)(leaky_relu_4)

        self.generator = tf.keras.Model(input_layer, output_layer, name='generator')

    def call(self, inputs, training=None):
        """Run the wrapped functional generator so ``Generator(...)(z)`` works."""
        return self.generator(inputs, training=training)

In [None]:
class Discriminator(tf.keras.Model):
    """Convolutional discriminator scoring RGB images with a single real/fake logit.

    The network is built as a functional model stored in ``self.discriminator``:
    three 5x5, stride-2 Conv2D layers (32, 64, 128 filters) each followed by
    LeakyReLU(0.2), then Flatten and Dense(1). The output has no activation; it
    is an unnormalized logit meant for a sigmoid cross-entropy loss.

    Args:
        image_size: (height, width) of the input images; inputs have 3 channels.
    """

    def __init__(self, image_size):
        super().__init__()
        input_layer = tf.keras.layers.Input(shape=(image_size[0], image_size[1], 3), dtype=tf.float32)

        conv_1 = tf.keras.layers.Conv2D(32, (5, 5), strides=(2, 2), padding='same')(input_layer)
        leaky_relu_1 = tf.keras.layers.LeakyReLU(0.2)(conv_1)
        conv_2 = tf.keras.layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same')(leaky_relu_1)
        leaky_relu_2 = tf.keras.layers.LeakyReLU(0.2)(conv_2)
        conv_3 = tf.keras.layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same')(leaky_relu_2)
        leaky_relu_3 = tf.keras.layers.LeakyReLU(0.2)(conv_3)
        flatten = tf.keras.layers.Flatten()(leaky_relu_3)
        output_layer = tf.keras.layers.Dense(1)(flatten)  # raw logit

        self.discriminator = tf.keras.Model(input_layer, output_layer, name='discriminator')

    def call(self, inputs, training=None):
        """Run the wrapped functional discriminator so ``Discriminator(...)(x)`` works."""
        return self.discriminator(inputs, training=training)

In [None]:
def cross_entropy_loss(logits, labels):
    """Mean sigmoid cross-entropy between raw logits and 0/1 target labels.

    Args:
        logits: unnormalized scores (no sigmoid applied).
        labels: targets with the same shape as ``logits``.

    Returns:
        Scalar tensor: the element-wise loss averaged over all entries.
    """
    per_element = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits)
    return tf.reduce_mean(per_element)

def get_noise(n_batch, noise_var):
    """Draw a float32 matrix of i.i.d. standard-normal samples.

    Args:
        n_batch: number of noise vectors (rows).
        noise_var: length of each noise vector (columns). Despite the name, this
            is a dimensionality, not a variance.

    Returns:
        np.ndarray of shape (n_batch, noise_var), dtype float32, drawn from
        NumPy's global random state.
    """
    samples = np.random.standard_normal(size=(n_batch, noise_var))
    return samples.astype(np.float32)

def visualize_images(images, n_rows, n_cols, title=None):
    """Show images in an ``n_rows`` x ``n_cols`` grid.

    Args:
        images: array-like (NumPy array or eager tensor) of images with values in
            [-1, 1]. Images fill the grid in row-major order; cells beyond
            ``len(images)`` are left blank.
        n_rows, n_cols: grid dimensions (1 is allowed for either).
        title: optional figure title.
    """
    # Rescale [-1, 1] -> [0, 1] and clip, since imshow expects floats in [0, 1].
    images = np.clip(np.asarray(images) / 2 + 0.5, 0, 1)
    # squeeze=False keeps `axes` 2-D even when n_rows or n_cols is 1.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols, n_rows), squeeze=False)
    for idx, ax in enumerate(axes.flat):  # row-major: idx == i * n_cols + j
        if idx < len(images):
            ax.imshow(images[idx])
        ax.axis("off")
    fig.subplots_adjust(wspace=0, hspace=0)
    if title:
        fig.suptitle(title, fontsize=16)
    plt.show()

def plot_learning_curve(g_loss_history, d_g_z_loss_history, d_x_loss_history, d_loss_history, log_interval=1):
    """Plot generator and discriminator loss histories on one set of axes.

    Args:
        g_loss_history: generator losses.
        d_g_z_loss_history: discriminator loss on generated images.
        d_x_loss_history: discriminator loss on real images.
        d_loss_history: discriminator total loss.
        log_interval: number of training iterations between consecutive history
            entries. With the default of 1, entry k is plotted at x = k. Pass the
            real logging interval so the x-axis reads in actual iterations.
    """
    curves = {
        'Generator Loss': g_loss_history,
        'Discriminator (Generated) Loss': d_g_z_loss_history,
        'Discriminator (Real) Loss': d_x_loss_history,
        'Discriminator Total Loss': d_loss_history,
    }
    fig, ax = plt.subplots(figsize=(15, 8))
    for label, history in curves.items():
        steps = np.arange(len(history)) * log_interval
        # float() accepts Python numbers as well as scalar eager tensors.
        ax.plot(steps, [float(value) for value in history], label=label, alpha=0.7)
    ax.set_xlabel('Iterations')
    ax.set_ylabel('Loss')
    ax.set_title('Learning Curves')
    ax.legend()
    ax.grid(True)
    plt.show()

### Initialize parameters for GAN model training

In [None]:
latent_size = 128        # length of each generator input noise vector
n_batch = 100            # noise vectors (= generated images) per generator batch; fills the 10x10 preview grid
lr = 1e-5                # Adam learning rate used for both generator and discriminator
n_updates_total = 40000  # number of (discriminator, generator) update pairs

# Seed NumPy (noise) and TensorFlow (weight init) so runs are reproducible.
seed = 42
np.random.seed(seed)
tf.random.set_seed(seed)

# Fixed noise, reused throughout training so preview grids are comparable over time.
noise_vector = get_noise(n_batch, latent_size)

In [None]:
# Build both networks; the loop below calls the wrapped functional models directly.
generator = Generator(latent_size=latent_size)
discriminator = Discriminator(image_size=image_size)

# Fail fast if generated images cannot be fed to the discriminator.
g_out_shape = tuple(generator.generator.outputs[0].shape)[1:]
d_in_shape = tuple(discriminator.discriminator.inputs[0].shape)[1:]
assert g_out_shape == d_in_shape, f"generator output {g_out_shape} != discriminator input {d_in_shape}"

g_vars = generator.generator.trainable_variables
d_vars = discriminator.discriminator.trainable_variables

g_optimizer = tf.keras.optimizers.Adam(learning_rate=lr, beta_1=0.5)
d_optimizer = tf.keras.optimizers.Adam(learning_rate=lr, beta_1=0.5)

log_every = 100          # record losses every `log_every` updates
preview_every = 1000     # show samples + learning curves every `preview_every` updates
n_rows, n_cols = 10, 10  # preview grid for the n_batch images generated from noise_vector

g_loss_history = []
d_g_z_loss_history = []
d_x_loss_history = []
d_loss_history = []

for n_updates in tqdm(range(n_updates_total), ncols=80, leave=False):
    xs = next(train_data)  # one batch of real images
    noise_samples = get_noise(n_batch, latent_size)

    # Discriminator update: push D(x) toward 1 (real) and D(G(z)) toward 0 (fake).
    with tf.GradientTape() as tape:
        d_g_z = discriminator.discriminator(generator.generator(noise_samples))
        d_x = discriminator.discriminator(xs)

        d_g_z_loss = cross_entropy_loss(logits=d_g_z, labels=tf.zeros_like(d_g_z))
        d_x_loss = cross_entropy_loss(logits=d_x, labels=tf.ones_like(d_x))
        d_loss = (d_g_z_loss + d_x_loss) / 2

    d_gradients = tape.gradient(d_loss, d_vars)
    d_optimizer.apply_gradients(zip(d_gradients, d_vars))

    # Generator update (non-saturating loss): push D(G(z)) toward 1; only the
    # generator's variables are updated here.
    with tf.GradientTape() as tape:
        g_z = generator.generator(noise_samples)
        d_g_z = discriminator.discriminator(g_z)
        g_loss = cross_entropy_loss(logits=d_g_z, labels=tf.ones_like(d_g_z))

    g_gradients = tape.gradient(g_loss, g_vars)
    g_optimizer.apply_gradients(zip(g_gradients, g_vars))

    if n_updates % log_every == 0:
        # Store plain Python floats rather than eager tensors.
        g_loss_history.append(float(g_loss))
        d_g_z_loss_history.append(float(d_g_z_loss))
        d_x_loss_history.append(float(d_x_loss))
        d_loss_history.append(float(d_loss))

    if n_updates % preview_every == 0:
        generated_images = generator.generator(noise_vector)
        visualize_images(generated_images, n_rows, n_cols, title=f"Generated images after {n_updates} iterations")
        plot_learning_curve(g_loss_history, d_g_z_loss_history, d_x_loss_history, d_loss_history)

### Save Generator and Discriminator models

In [None]:
# Persist the trained weights of the two wrapped Keras models to the working directory.
# The `.weights.h5` suffix is what Keras 3 requires for save_weights; tf.keras (Keras 2)
# also accepts it and writes HDF5.
generator.generator.save_weights('generator.weights.h5')
discriminator.discriminator.save_weights('discriminator.weights.h5')