# In [None]:
import os
os.environ["KERAS_BACKEND"] = "tensorflow"

import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_datasets as tfds
from datetime import datetime
import keras
from keras import layers, ops
import numpy as np

# Hyperparameters

# --- Data ---
image_size = 128          # side length (pixels) of the square images fed to the models
batch_size = 64           # images per training batch
dataset_repetitions = 5   # repeat count applied by prepare_dataset()

# --- Model / schedule ---
latent_dim = 100          # length of the noise vector fed to the generator
num_epochs = 50           # passes over train_dataset in the training loop

# --- WGAN specific ---
critic_iterations = 5     # critic updates performed per generator update
gp_lambda = 10.0          # weight of the gradient-penalty term in the critic loss
learning_rate = 1e-4      # shared by the generator and critic Adam optimizers


def preprocess_image(data):
    """Center-crop a dataset example to a square and resize it.

    Args:
        data: Mapping with an ``"image"`` entry whose first two axes are
            height and width.

    Returns:
        The image resized to ``(image_size, image_size)``, divided by 255
        and clipped to ``[0, 1]``.
        NOTE(review): the division assumes 0-255 pixel values -- confirm
        against the dataset actually used.
    """
    raw = data["image"]
    dims = ops.shape(raw)
    height, width = dims[0], dims[1]

    # The largest centered square that fits inside the image.
    side = ops.minimum(height, width)
    top = (height - side) // 2
    left = (width - side) // 2

    cropped = tf.image.crop_to_bounding_box(raw, top, left, side, side)
    resized = tf.image.resize(
        cropped, size=[image_size, image_size], antialias=True
    )
    return ops.clip(resized / 255.0, 0.0, 1.0)


def prepare_dataset(split):
    """Build a batched input pipeline from a TensorFlow Datasets split.

    NOTE(review): this reads a module-level ``dataset_name`` that is not
    defined in the visible part of this file; it must exist before the
    function is called.

    Args:
        split: Split specifier forwarded to ``tfds.load``.

    Returns:
        A ``tf.data.Dataset`` of preprocessed image batches.
    """
    raw = tfds.load(dataset_name, split=split, shuffle_files=True)

    # Preprocess once and cache the result so later repetitions reuse it.
    images = raw.map(preprocess_image, num_parallel_calls=tf.data.AUTOTUNE)
    images = images.cache()

    images = images.repeat(dataset_repetitions)
    images = images.shuffle(10 * batch_size)
    batches = images.batch(batch_size, drop_remainder=True)
    return batches.prefetch(buffer_size=tf.data.AUTOTUNE)

def load_custom_image(path, size=128):
    """Read one image file and return it as a 3-channel float image.

    Args:
        path: Path (string tensor) of the image file to read.
        size: Side length in pixels of the square output. Defaults to 128,
            the value that was previously hard-coded.

    Returns:
        A float32 tensor of shape ``(size, size, 3)`` with values in
        ``[0, 1]``; the three channels are copies of the grayscale image.
    """
    raw = tf.io.read_file(path)

    # expand_animations=False makes decode_image always return a 3-D
    # (height, width, channels) tensor. Without it the result has no static
    # rank when traced inside Dataset.map, and tf.image.resize rejects it
    # ("'images' contains no shape").
    img = tf.image.decode_image(raw, channels=1, expand_animations=False)  # 1 = grayscale
    img = tf.image.convert_image_dtype(img, tf.float32)  # [0,1]

    # Resizing also fixes the static spatial shape needed for batching.
    img = tf.image.resize(img, [size, size], antialias=True)

    # Replicate the single channel so the critic sees 3-channel input.
    img = tf.image.grayscale_to_rgb(img)  # shape becomes (size, size, 3)

    return img

def make_image_dataset(folder, batch_size=64):
    """Create a shuffled, batched dataset from every file in a folder.

    Args:
        folder: Directory whose direct children (``folder/*``) are loaded
            with ``load_custom_image``.
        batch_size: Images per batch; an incomplete final batch is dropped.

    Returns:
        A prefetching ``tf.data.Dataset`` of image batches.
    """
    pattern = folder + "/*"
    paths = tf.data.Dataset.list_files(pattern, shuffle=True)

    return (
        paths.map(load_custom_image, num_parallel_calls=tf.data.AUTOTUNE)
        .shuffle(1000)
        .batch(batch_size, drop_remainder=True)
        .prefetch(tf.data.AUTOTUNE)
    )

# Folder holding the training images; every file directly inside it is loaded.
image_folder = "/Images/Processed_128x128_Grayscale"
train_dataset = make_image_dataset(image_folder, batch_size=batch_size)

# Sanity check: print the shape of the first batch and display its first image.
# Nothing is shown if the dataset yields no full batch (drop_remainder=True).
for img_batch in train_dataset.take(1):
    print(img_batch.shape)  # Expected: (batch_size, 128, 128, 3)
    plt.imshow(img_batch[0])
    plt.show()


# Visualize sample
# This iterates the dataset a second time, so the image shown may differ from
# the one above because the pipeline shuffles.
for image in train_dataset.take(1):
    plt.imshow(image[0].numpy())
    plt.show()
    break


def build_generator(latent_dim, image_size):
    """Build the generator mapping a latent vector to an RGB image.

    The network projects the latent vector to a 4x4x256 feature map and then
    doubles the resolution with stride-2 transposed convolutions until it
    reaches ``image_size``. The previous version always used four upsampling
    stages and therefore produced 64x64 images regardless of ``image_size``,
    which did not match the critic's input when ``image_size`` was 128.
    For ``image_size == 64`` the layer stack is the same as before.

    Args:
        latent_dim: Length of the input noise vector.
        image_size: Side length of the square output image. Must be a power
            of two and at least 8.

    Returns:
        A ``keras.Sequential`` model whose output has shape
        ``(image_size, image_size, 3)`` with values in ``[-1, 1]`` (tanh).

    Raises:
        ValueError: If ``image_size`` is not a power of two >= 8.
    """
    base_size = 4       # spatial size of the first feature map
    base_filters = 256  # channels of the first feature map
    min_filters = 32    # intermediate stages never go narrower than this

    size = int(image_size)
    # Each stride-2 stage doubles the resolution, so only 4 * 2**k (k >= 1)
    # is reachable from the 4x4 starting map.
    if size != image_size or size < 2 * base_size or size & (size - 1):
        raise ValueError(
            f"image_size must be a power of two >= {2 * base_size}, "
            f"got {image_size!r}"
        )
    num_upsamples = (size // base_size).bit_length() - 1

    stack = [
        keras.Input(shape=(latent_dim,)),
        layers.Dense(base_size * base_size * base_filters, use_bias=False),
        layers.BatchNormalization(),
        layers.LeakyReLU(0.2),
        layers.Reshape((base_size, base_size, base_filters)),
    ]

    # All stages except the last: halve the channel count (down to
    # min_filters) while doubling the resolution.
    for stage in range(1, num_upsamples):
        filters = max(base_filters >> stage, min_filters)
        stack += [
            layers.Conv2DTranspose(filters, 4, 2, padding="same", use_bias=False),
            layers.BatchNormalization(),
            layers.LeakyReLU(0.2),
        ]

    # Final stage reaches image_size and maps to 3 channels in [-1, 1].
    stack += [
        layers.Conv2DTranspose(3, 4, 2, padding="same", use_bias=False),
        layers.Activation("tanh"),
    ]

    return keras.Sequential(stack)


def build_critic(image_size):
    """Build the critic that scores ``(image_size, image_size, 3)`` images.

    Four stride-2 convolutions (64, 128, 256, 512 filters), each followed by
    LeakyReLU(0.2), feed a flatten and a single linear output unit.

    Args:
        image_size: Side length of the square input image.

    Returns:
        A ``keras.Model`` producing one unbounded score per image.
    """
    base_filters = 64
    inputs = layers.Input(shape=(image_size, image_size, 3))

    features = inputs
    for multiplier in (1, 2, 4, 8):
        features = layers.Conv2D(
            base_filters * multiplier, 4, 2, padding="same"
        )(features)
        features = layers.LeakyReLU(0.2)(features)

    flat = layers.Flatten()(features)
    score = layers.Dense(1)(flat)

    return keras.Model(inputs, score)


def gradient_penalty(critic, real, fake):
    """Compute the gradient penalty on random real/fake interpolates.

    Args:
        critic: Callable model that scores a batch of images.
        real: Batch of real images; the ``[batch, 1, 1, 1]`` mixing weights
            imply 4-D image tensors.
        fake: Batch of generated images with the same shape as ``real``.

    Returns:
        Scalar tensor: the batch mean of ``(||grad|| - 1)^2``, where the
        gradient of the critic score is taken w.r.t. the interpolated images.
    """
    n = tf.shape(real)[0]

    # One mixing coefficient per sample, broadcast over H, W and C.
    mix = tf.random.uniform([n, 1, 1, 1], 0.0, 1.0)
    interpolated = mix * real + (1 - mix) * fake

    with tf.GradientTape() as inner_tape:
        # The interpolate is not a variable, so it must be watched explicitly.
        inner_tape.watch(interpolated)
        scores = critic(interpolated, training=True)

    per_sample = tf.reshape(inner_tape.gradient(scores, interpolated), [n, -1])
    norms = tf.norm(per_sample, axis=1)
    return tf.reduce_mean((norms - 1.0) ** 2)


def update_ema(G_src, G_tgt, decay=0.999):
    """Move the weights of ``G_tgt`` toward those of ``G_src``.

    Every weight of the target becomes
    ``decay * target + (1 - decay) * source``.

    Args:
        G_src: Model providing the current weights (``get_weights``).
        G_tgt: Model holding the running average; updated in place through
            ``set_weights``.
        decay: Fraction of the previous average kept on each update.
    """
    live = G_src.get_weights()
    averaged = G_tgt.get_weights()

    blended = []
    for ema_value, live_value in zip(averaged, live):
        blended.append(decay * ema_value + (1.0 - decay) * live_value)

    G_tgt.set_weights(blended)


# Instantiate the two networks.
generator = build_generator(latent_dim, image_size)
critic = build_critic(image_size)

# Second generator with the same architecture, starting from the same
# weights; update_ema() later blends the live generator's weights into it.
G_ema = keras.models.clone_model(generator)
G_ema.set_weights(generator.get_weights())

# Separate Adam optimizers with identical settings (beta_1=0, beta_2=0.9).
generator_optimizer = keras.optimizers.Adam(
    learning_rate=learning_rate, beta_1=0.0, beta_2=0.9
)
critic_optimizer = keras.optimizers.Adam(
    learning_rate=learning_rate, beta_1=0.0, beta_2=0.9
)

print("Generator:")
generator.summary()
print("\nCritic:")
critic.summary()


@tf.function
def train_step(real_images):
    """Run ``critic_iterations`` critic updates, then one generator update.

    Args:
        real_images: Batch of real images. They are rescaled with
            ``x * 2 - 1``, i.e. expected in ``[0, 1]`` so that they match the
            generator's tanh output range.

    Returns:
        Tuple ``(critic_loss, generator_loss, wasserstein, gp)``. The critic
        loss, Wasserstein estimate and gradient penalty are those of the last
        critic iteration.
    """
    n = tf.shape(real_images)[0]
    real_images = real_images * 2.0 - 1.0

    # ----- Critic -----
    for _ in range(critic_iterations):
        z = tf.random.normal((n, latent_dim))
        # Generated outside the tape: the critic update needs no generator
        # gradients.
        fake = generator(z, training=True)

        with tf.GradientTape() as critic_tape:
            real_scores = critic(real_images, training=True)
            fake_scores = critic(fake, training=True)

            wasserstein = tf.reduce_mean(real_scores) - tf.reduce_mean(fake_scores)
            gp = gradient_penalty(critic, real_images, fake)
            # Maximize the Wasserstein estimate, penalize gradient norms != 1.
            critic_loss = -(wasserstein) + gp_lambda * gp

        critic_vars = critic.trainable_weights
        critic_grads = critic_tape.gradient(critic_loss, critic_vars)
        critic_optimizer.apply_gradients(zip(critic_grads, critic_vars))

    # ----- Generator -----
    z = tf.random.normal((n, latent_dim))
    with tf.GradientTape() as gen_tape:
        fake = generator(z, training=True)
        fake_scores = critic(fake, training=True)
        # The generator tries to raise the critic's score on its samples.
        generator_loss = -tf.reduce_mean(fake_scores)

    gen_vars = generator.trainable_weights
    gen_grads = gen_tape.gradient(generator_loss, gen_vars)
    generator_optimizer.apply_gradients(zip(gen_grads, gen_vars))

    return critic_loss, generator_loss, wasserstein, gp


def generate_images(generator, epoch, num=16):
    """Sample images from ``generator`` and save them as one PNG grid.

    The grid is sized from ``num`` (near-square, e.g. 4x4 for the default 16)
    instead of being fixed at 4x4, so any positive ``num`` works: smaller
    values no longer index past the batch and larger values are no longer
    silently dropped.

    Args:
        generator: Model mapping ``(num, latent_dim)`` noise to images in
            ``[-1, 1]``.
        epoch: Epoch number used in the output file name.
        num: Number of images to sample and plot. Must be >= 1.

    Raises:
        ValueError: If ``num`` is smaller than 1.

    Side effects:
        Creates ``generated_images/`` if needed, writes
        ``generated_images/epoch_XXX.png`` and prints the file name.
    """
    import math  # local import keeps this block self-contained

    if num < 1:
        raise ValueError(f"num must be >= 1, got {num!r}")

    noise = tf.random.normal((num, latent_dim))
    imgs = generator(noise, training=False)
    # Map the tanh range [-1, 1] back to [0, 1] for display.
    imgs = (imgs + 1) / 2.0
    imgs = tf.clip_by_value(imgs, 0.0, 1.0).numpy()

    os.makedirs("generated_images", exist_ok=True)

    # Smallest near-square grid holding `num` cells: cols = ceil(sqrt(num)).
    cols = math.isqrt(num - 1) + 1
    rows = -(-num // cols)
    # squeeze=False keeps `axes` 2-D even for a 1x1 grid.
    fig, axes = plt.subplots(
        rows, cols, figsize=(2.5 * cols, 2.5 * rows), squeeze=False
    )
    for i, ax in enumerate(axes.flat):
        ax.axis("off")
        if i < num:  # trailing cells of a non-full grid stay blank
            ax.imshow(imgs[i])

    fname = f"generated_images/epoch_{epoch:03d}.png"
    fig.savefig(fname)
    plt.close(fig)
    print(f"Saved {fname}")


# ------------ TRAINING LOOP (NO KID) ------------
# Each epoch iterates train_dataset once. Every batch runs one train_step
# (several critic updates plus one generator update) followed by an EMA update
# of G_ema. After each epoch a sample grid is rendered from G_ema.
print("\nStarting training...")

for epoch in range(num_epochs):
    print(f"\nEpoch {epoch+1}/{num_epochs}")

    d_losses, g_losses = [], []

    for step, batch_images in enumerate(train_dataset):
        d_loss, g_loss, W, gp_val = train_step(batch_images)
        update_ema(generator, G_ema, 0.999)

        # float() converts the scalar tensors to plain Python numbers.
        d_losses.append(float(d_loss))
        g_losses.append(float(g_loss))

        if step % 100 == 0:
            print(
                f"  Step {step}: D={d_losses[-1]:.4f} G={g_losses[-1]:.4f} "
                f"W={float(W):.4f} GP={float(gp_val):.4f}"
            )

    # With drop_remainder=True a folder holding fewer images than one batch
    # yields no batches at all; fail loudly instead of reporting NaN averages
    # and "training" for num_epochs without a single update.
    if not d_losses:
        raise RuntimeError(
            "train_dataset produced no batches; check the image folder and batch_size"
        )

    print(f"Epoch {epoch+1} — Avg D={np.mean(d_losses):.4f}, Avg G={np.mean(g_losses):.4f}")
    generate_images(G_ema, epoch+1)

# Persist the trained weights. Keras 3 `save_weights` only accepts file names
# ending in ".weights.h5"; the former plain ".h5" names raise a ValueError,
# which would discard the whole training run at the very last step.
print("\nSaving models...")
os.makedirs("checkpoints", exist_ok=True)
generator.save_weights("checkpoints/wgan_generator.weights.h5")
critic.save_weights("checkpoints/wgan_critic.weights.h5")
# The EMA generator is the one used for the per-epoch sample grids, so keep
# its weights as well.
G_ema.save_weights("checkpoints/wgan_generator_ema.weights.h5")
print("Done!")
