In [None]:
import tensorflow as tf
from tensorflow.keras import layers
import matplotlib.pyplot as plt
import os

# --- Data ---
DATA_PATH = './data'   # root folder handed to image_dataset_from_directory
IMG_SIZE = 64          # images are resized to IMG_SIZE x IMG_SIZE pixels
CHANNELS = 3           # colour channels per image (used by the discriminator input)
BATCH_SIZE = 64        # images per training batch

# --- Model / training ---
NOISE_DIM = 100        # length of the latent vector sampled for the generator
EPOCHS = 100           # full passes over the dataset


def load_custom_dataset(path, image_size=IMG_SIZE, batch_size=BATCH_SIZE):
    """Load an unlabeled image folder as a shuffled, batched tf.data pipeline.

    Images are decoded with CHANNELS colour channels, resized to
    (image_size, image_size), and rescaled from [0, 255] to [-1, 1] so they
    share the range of the generator's tanh output.

    Args:
        path: directory passed to ``tf.keras.utils.image_dataset_from_directory``.
        image_size: side length images are resized to (defaults to IMG_SIZE).
        batch_size: images per batch (defaults to BATCH_SIZE). The final batch
            may be smaller.

    Returns:
        A prefetched ``tf.data.Dataset`` of float32 image batches.

    Raises:
        ValueError: if CHANNELS is not 1, 3 or 4.
    """
    # Keep the decoded channel count in sync with the CHANNELS constant that
    # the discriminator's input_shape is built from.
    color_modes = {1: 'grayscale', 3: 'rgb', 4: 'rgba'}
    if CHANNELS not in color_modes:
        raise ValueError(f'CHANNELS must be 1, 3 or 4, got {CHANNELS}')

    dataset = tf.keras.utils.image_dataset_from_directory(
        path,
        label_mode=None,
        color_mode=color_modes[CHANNELS],
        image_size=(image_size, image_size),
        batch_size=batch_size,
        shuffle=True,
    )
    # Parallel rescale + prefetch lets input preparation overlap training steps.
    dataset = dataset.map(
        lambda x: (tf.cast(x, tf.float32) / 127.5) - 1.0,
        num_parallel_calls=tf.data.AUTOTUNE,
    )
    return dataset.prefetch(tf.data.AUTOTUNE)

train_dataset = load_custom_dataset(DATA_PATH)

In [None]:
def make_generator():
    """Build the generator: latent noise vector -> image with values in [-1, 1].

    A bias-free Dense layer projects the (NOISE_DIM,) input to a
    (IMG_SIZE // 8, IMG_SIZE // 8, 256) feature map. Three stride-2
    Conv2DTranspose layers with 'same' padding each double the spatial size,
    so the output is IMG_SIZE x IMG_SIZE x CHANNELS. The final tanh matches the
    [-1, 1] pixel scaling applied in load_custom_dataset.

    Returns:
        An uncompiled ``tf.keras.Sequential`` model.

    Raises:
        ValueError: if IMG_SIZE is not a multiple of 8 (three x2 upsamplings).
    """
    if IMG_SIZE % 8:
        raise ValueError(f'IMG_SIZE must be a multiple of 8, got {IMG_SIZE}')
    base = IMG_SIZE // 8

    model = tf.keras.Sequential([
        # Input width tied to NOISE_DIM so the latent size lives in one place.
        layers.Dense(base * base * 256, use_bias=False, input_shape=(NOISE_DIM,)),
        layers.BatchNormalization(),
        layers.LeakyReLU(),
        layers.Reshape((base, base, 256)),

        layers.Conv2DTranspose(128, 5, strides=2, padding='same', use_bias=False),
        layers.BatchNormalization(),
        layers.LeakyReLU(),

        layers.Conv2DTranspose(64, 5, strides=2, padding='same', use_bias=False),
        layers.BatchNormalization(),
        layers.LeakyReLU(),

        layers.Conv2DTranspose(CHANNELS, 5, strides=2, padding='same', use_bias=False, activation='tanh'),
    ])
    return model


In [None]:
def make_discriminator():
    """Build the discriminator: image -> single real/fake score.

    Three 5x5 stride-2 convolutions (64, 128, 256 filters), each followed by
    LeakyReLU and Dropout(0.3), then Flatten and a Dense(1) with no activation,
    so the score is an unbounded logit.

    Returns:
        An uncompiled ``tf.keras.Sequential`` model expecting
        (IMG_SIZE, IMG_SIZE, CHANNELS) inputs.
    """
    stack = []
    for filters in (64, 128, 256):
        # Only the first convolution declares the input shape.
        extra = {} if stack else {'input_shape': (IMG_SIZE, IMG_SIZE, CHANNELS)}
        stack.append(layers.Conv2D(filters, 5, strides=2, padding='same', **extra))
        stack.append(layers.LeakyReLU())
        stack.append(layers.Dropout(0.3))

    stack.append(layers.Flatten())
    stack.append(layers.Dense(1))
    return tf.keras.Sequential(stack)

In [None]:

# Shared binary cross-entropy. from_logits=True because the discriminator's
# final Dense(1) has no activation, so its scores are raw logits.
cross_entropy = tf.keras.losses.BinaryCrossentropy(
    from_logits=True,
)

def generator_loss(fake_output):
    """Generator loss: cross-entropy of the discriminator's scores on fakes vs. label 1.

    The generator is rewarded when the discriminator scores its samples as real.
    """
    want_real = tf.ones_like(fake_output)
    return cross_entropy(want_real, fake_output)

def discriminator_loss(real_output, fake_output):
    """Discriminator loss: cross-entropy on real scores vs. 1 plus fake scores vs. 0."""
    real_targets = tf.ones_like(real_output)
    fake_targets = tf.zeros_like(fake_output)
    loss_on_real = cross_entropy(real_targets, real_output)
    loss_on_fake = cross_entropy(fake_targets, fake_output)
    return loss_on_real + loss_on_fake

def discriminator_accuracy(real_output, fake_output, threshold=0.0):
    """Fraction of real and fake samples the discriminator classifies correctly.

    The discriminator ends in Dense(1) with no activation and is trained with
    ``from_logits=True``, so its outputs are logits. sigmoid(logit) >= 0.5 is
    equivalent to logit >= 0, hence the default threshold of 0.0.

    Args:
        real_output: discriminator logits for real images.
        fake_output: discriminator logits for generated images.
        threshold: logit at or above which a sample counts as "real".

    Returns:
        Scalar float32 tensor: mean of real-side and fake-side accuracy.
    """
    real_accuracy = tf.reduce_mean(
        tf.cast(tf.math.greater_equal(real_output, threshold), tf.float32))
    fake_accuracy = tf.reduce_mean(
        tf.cast(tf.math.less(fake_output, threshold), tf.float32))
    return (real_accuracy + fake_accuracy) / 2

In [None]:
# Seed TensorFlow's global RNG before any weights or noise are created so
# initialisation and sampled noise repeat across runs (some GPU kernels may
# still be nondeterministic).
RANDOM_SEED = 42
tf.random.set_seed(RANDOM_SEED)

generator = make_generator()
discriminator = make_discriminator()

# Separate optimizers: each network keeps its own Adam state.
gen_opt = tf.keras.optimizers.Adam(1e-4)
disc_opt = tf.keras.optimizers.Adam(1e-4)

# Fixed latent vectors reused after every epoch so saved samples are comparable.
seed = tf.random.normal([16, NOISE_DIM])

@tf.function
def train_step(images):
    """Run one simultaneous generator/discriminator update on a batch of real images.

    Args:
        images: batch of real images scaled to [-1, 1] (as produced by
            load_custom_dataset). NOTE(review): assumed shape is
            (batch, IMG_SIZE, IMG_SIZE, CHANNELS).

    Returns:
        Tuple ``(gen_loss, disc_loss, disc_acc)`` of scalar tensors for this batch.
    """
    # One latent vector per real image: the final dataset batch can be smaller
    # than BATCH_SIZE, and sizing from the actual batch keeps real/fake counts equal.
    noise = tf.random.normal([tf.shape(images)[0], NOISE_DIM])

    # Two tapes because each network is differentiated w.r.t. its own loss.
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated = generator(noise, training=True)

        real_output = discriminator(images, training=True)
        fake_output = discriminator(generated, training=True)

        gen_loss = generator_loss(fake_output)
        disc_loss = discriminator_loss(real_output, fake_output)

    # Accuracy is only monitored, never differentiated, so compute it outside the tapes.
    disc_acc = discriminator_accuracy(real_output, fake_output)

    gradients_gen = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_disc = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    gen_opt.apply_gradients(zip(gradients_gen, generator.trainable_variables))
    disc_opt.apply_gradients(zip(gradients_disc, discriminator.trainable_variables))

    return gen_loss, disc_loss, disc_acc

def generate_and_save_images(model, epoch, test_input, output_dir='gen_images'):
    """Generate one image per latent vector in ``test_input`` and save them as a grid.

    Args:
        model: generator, called with ``training=False``.
        epoch: epoch number used in the filename (zero-padded to 3 digits).
        test_input: batch of latent vectors; every vector is rendered.
        output_dir: directory for the PNG; created if it does not exist.
    """
    import math

    predictions = model(test_input, training=False)
    n_images = int(predictions.shape[0])
    if n_images == 0:
        return

    # Near-square grid so e.g. 16 samples render as 4x4.
    n_cols = math.ceil(math.sqrt(n_images))
    n_rows = math.ceil(n_images / n_cols)
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(2 * n_cols, 2 * n_rows), squeeze=False)

    for i, ax in enumerate(axes.flat):
        ax.axis('off')
        if i >= n_images:
            continue
        # tanh output lies in [-1, 1]; map to [0, 1] for imshow.
        img = (predictions[i] + 1) / 2
        if img.shape[-1] == 1:
            # imshow rejects (H, W, 1); drop the channel axis for grayscale.
            ax.imshow(img[..., 0], cmap='gray', vmin=0.0, vmax=1.0)
        else:
            ax.imshow(img)

    os.makedirs(output_dir, exist_ok=True)
    fig.savefig(os.path.join(output_dir, f'epoch_{epoch:03}.png'))
    plt.close(fig)

def train(dataset, epochs):
    """Train the GAN for ``epochs`` full passes over ``dataset``.

    After each epoch, prints the mean (over that epoch's batches) of the
    generator loss, discriminator loss and discriminator accuracy, then saves
    a sample grid generated from the fixed ``seed`` noise.

    Raises:
        ValueError: if ``dataset`` yields no batches.
    """
    os.makedirs('gen_images', exist_ok=True)
    for epoch in range(epochs):
        print(f'Starting epoch {epoch+1}/{epochs}')
        # Running means are updated as tensors, without converting each step
        # to a Python float.
        g_mean = tf.keras.metrics.Mean()
        d_mean = tf.keras.metrics.Mean()
        acc_mean = tf.keras.metrics.Mean()
        n_batches = 0
        for image_batch in dataset:
            g_loss, d_loss, d_acc = train_step(image_batch)
            g_mean.update_state(g_loss)
            d_mean.update_state(d_loss)
            acc_mean.update_state(d_acc)
            n_batches += 1

        if n_batches == 0:
            raise ValueError('dataset yielded no batches; nothing to train on')

        print(f"Gen Loss: {float(g_mean.result()):.4f}, "
              f"Disc Loss: {float(d_mean.result()):.4f}, "
              f"Disc Acc: {float(acc_mean.result()):.4f}")
        generate_and_save_images(generator, epoch+1, seed)

train(train_dataset, EPOCHS)