In [7]:
import os
import math
import numpy as np
import tensorflow as tf


In [8]:
# --- Reproducibility ---------------------------------------------------------
# Seed NumPy's and TensorFlow's global random generators with the same value.
SEED = 1337
tf.random.set_seed(SEED)
np.random.seed(SEED)

# --- Output locations --------------------------------------------------------
# Sample grids and checkpoints live in sibling folders under one root.
OUTPUT_DIR = "gan_outputs"
SAMPLES_DIR, CKPT_DIR = (
    os.path.join(OUTPUT_DIR, "samples"),
    os.path.join(OUTPUT_DIR, "checkpoints"),
)
os.makedirs(CKPT_DIR, exist_ok=True)
os.makedirs(SAMPLES_DIR, exist_ok=True)

# --- Hyperparameters ---------------------------------------------------------
IMAGE_SHAPE = (28, 28, 1)    # MNIST
LATENT_DIM = 128             # length of the noise vector fed to the generator
BATCH_SIZE = 128
EPOCHS = 10                  # keep small for lab run
LEARNING_RATE = 2e-4         # Adam step size, shared by both optimizers
BETA_1 = 0.5                 # Adam beta_1, shared by both optimizers


In [9]:
# Load the MNIST train/test image arrays through Keras; the label arrays are
# bound to `_` and are not used anywhere in the visible notebook.
(x_train, _), (x_test, _) = tf.keras.datasets.mnist.load_data()

def preprocess_images(x):
    """Turn raw pixel arrays into float images with an explicit channel axis.

    Args:
        x: NumPy array of pixel intensities (assumed to lie in 0..255, as
            the 127.5 offset/scale implies). The input is not modified.

    Returns:
        A new ``float32`` array with one extra trailing axis, whose values
        are mapped linearly so that 0 -> -1.0 and 255 -> +1.0. This matches
        the ``tanh`` output range of the generator.
    """
    as_float = x.astype("float32")
    centred = (as_float - 127.5) / 127.5  # scale to [-1, 1]
    return centred[..., np.newaxis]

# Rescale both splits to [-1, 1] and append the channel axis.
x_train, x_test = preprocess_images(x_train), preprocess_images(x_test)

# Training pipeline: shuffle -> fixed-size batches -> prefetch.
train_ds = tf.data.Dataset.from_tensor_slices(x_train)
train_ds = train_ds.shuffle(1024, seed=SEED)
# drop_remainder=True discards the final short batch so every step sees
# exactly BATCH_SIZE images.
train_ds = train_ds.batch(BATCH_SIZE, drop_remainder=True)
train_ds = train_ds.prefetch(tf.data.AUTOTUNE)

# Evaluation pipeline: original order, the last batch may be smaller.
test_ds = tf.data.Dataset.from_tensor_slices(x_test)
test_ds = test_ds.batch(BATCH_SIZE)
test_ds = test_ds.prefetch(tf.data.AUTOTUNE)

print(f"Train: {x_train.shape}, Test: {x_test.shape}")

Train: (60000, 28, 28, 1), Test: (10000, 28, 28, 1)


In [10]:
def build_generator(latent_dim=LATENT_DIM):
    """Build the generator: latent vector -> single-channel image in [-1, 1].

    The noise vector is projected to a 7x7x256 feature map, passed through a
    stride-1 and a stride-2 transposed convolution (each followed by batch
    normalisation and ReLU), and finally upsampled by a stride-2 transposed
    convolution with ``tanh`` activation. With "same" padding the spatial
    size goes 7 -> 7 -> 14 -> 28.

    Args:
        latent_dim: Length of the input noise vector.

    Returns:
        A ``tf.keras.Model`` named ``"generator"``.
    """
    layers = tf.keras.layers

    def bn_relu(tensor):
        # Normalise, then apply the non-linearity (fresh layers on each call).
        tensor = layers.BatchNormalization()(tensor)
        return layers.ReLU()(tensor)

    noise = tf.keras.Input(shape=(latent_dim,))

    # Dense projection that is reshaped into the initial 7x7 feature map.
    features = layers.Dense(7 * 7 * 256, use_bias=False)(noise)
    features = bn_relu(features)
    features = layers.Reshape((7, 7, 256))(features)

    # Intermediate transposed-conv stages: (filters, stride).
    for filters, stride in ((128, 1), (64, 2)):
        features = layers.Conv2DTranspose(
            filters, 5, strides=stride, padding="same", use_bias=False
        )(features)
        features = bn_relu(features)

    # Output stage: one channel, tanh keeps pixels in [-1, 1].
    image = layers.Conv2DTranspose(
        1, 5, strides=2, padding="same", use_bias=False, activation="tanh"
    )(features)
    return tf.keras.Model(noise, image, name="generator")

def build_discriminator(input_shape=IMAGE_SHAPE):
    """Build the discriminator: image -> probability that the image is real.

    Two stride-2 convolution stages (64 then 128 filters), each followed by
    LeakyReLU(0.2) and Dropout(0.3), feed a flattened vector into a single
    sigmoid unit.

    Args:
        input_shape: Shape of one input image, without the batch axis.

    Returns:
        A ``tf.keras.Model`` named ``"discriminator"`` whose output has
        shape ``(batch, 1)`` and lies in (0, 1).
    """
    layers = tf.keras.layers

    image = tf.keras.Input(shape=input_shape)
    hidden = image
    for filters in (64, 128):
        hidden = layers.Conv2D(filters, 5, strides=2, padding="same")(hidden)
        hidden = layers.LeakyReLU(0.2)(hidden)
        hidden = layers.Dropout(0.3)(hidden)

    flat = layers.Flatten()(hidden)
    # Sigmoid output: the loss must therefore be used with probabilities,
    # not logits.
    probability = layers.Dense(1, activation="sigmoid")(flat)
    return tf.keras.Model(image, probability, name="discriminator")

# Instantiate each network with its default size and print its layer table.
generator = build_generator()
generator.summary()

discriminator = build_discriminator()
discriminator.summary()

Model: "generator"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_1 (InputLayer)        [(None, 128)]             0         
                                                                 
 dense_2 (Dense)             (None, 12544)             1605632   
                                                                 
 batch_normalization_3 (Batc  (None, 12544)            50176     
 hNormalization)                                                 
                                                                 
 re_lu_3 (ReLU)              (None, 12544)             0         
                                                                 
 reshape_1 (Reshape)         (None, 7, 7, 256)         0         
                                                                 
 conv2d_transpose_3 (Conv2DT  (None, 7, 7, 128)        819200    
 ranspose)                                               

In [11]:
class GAN(tf.keras.Model):
    """Couples a generator and a discriminator behind the Keras ``fit`` API.

    The class overrides ``train_step`` so that one call to ``fit`` alternates
    a discriminator update and a generator update on every batch, and tracks
    running means of both losses plus the discriminator's accuracy.

    Args:
        generator: Model mapping ``(batch, latent_dim)`` noise to images.
        discriminator: Model mapping images to a ``(batch, 1)`` real/fake
            probability.
        latent_dim: Length of the noise vectors sampled for the generator.
    """

    def __init__(self, generator, discriminator, latent_dim=LATENT_DIM):
        super().__init__()
        self.generator = generator
        self.discriminator = discriminator
        self.latent_dim = latent_dim

        # Running aggregates reported in the training logs.
        self.d_loss_tracker = tf.keras.metrics.Mean(name="d_loss")
        self.g_loss_tracker = tf.keras.metrics.Mean(name="g_loss")
        self.d_acc_tracker = tf.keras.metrics.BinaryAccuracy(name="d_acc", threshold=0.5)

    @property
    def metrics(self):
        """Metric objects owned by this model, exposed to the Keras loop."""
        return [self.d_loss_tracker, self.g_loss_tracker, self.d_acc_tracker]

    def compile(self, d_optimizer, g_optimizer, loss_fn):
        """Store the two optimizers and the shared loss function.

        Args:
            d_optimizer: Optimizer applied to the discriminator's weights.
            g_optimizer: Optimizer applied to the generator's weights.
            loss_fn: Callable ``loss_fn(y_true, y_pred)`` used for both
                networks.
        """
        super().compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.loss_fn = loss_fn

    def train_step(self, real_images):
        """Run one discriminator update followed by one generator update.

        Args:
            real_images: A batch of real images, or a tuple whose first
                element is that batch (e.g. ``(images, labels)`` from a
                labelled dataset); any further elements are ignored.

        Returns:
            Dict with the current values of ``d_loss``, ``g_loss`` and
            ``d_acc``.
        """
        # A dataset yielding (images, labels, ...) hands a tuple to train_step;
        # only the images are relevant for unconditional GAN training.
        if isinstance(real_images, tuple):
            real_images = real_images[0]

        batch_size = tf.shape(real_images)[0]

        # Train discriminator
        # Fakes are produced outside the tape: the generator is not updated
        # in this phase, so its operations do not need to be recorded.
        z = tf.random.normal((batch_size, self.latent_dim))
        fake_images = self.generator(z, training=True)
        real_labels = tf.ones((batch_size, 1))
        fake_labels = tf.zeros((batch_size, 1))

        with tf.GradientTape() as tape:
            pred_real = self.discriminator(real_images, training=True)
            pred_fake = self.discriminator(fake_images, training=True)
            d_loss = self.loss_fn(real_labels, pred_real) + self.loss_fn(fake_labels, pred_fake)

        grads = tape.gradient(d_loss, self.discriminator.trainable_weights)
        self.d_optimizer.apply_gradients(zip(grads, self.discriminator.trainable_weights))

        # Train generator
        # Fresh noise is drawn, and the fakes are labelled "real" so that the
        # loss is low when the discriminator is fooled. Only the generator's
        # weights receive gradients here.
        z = tf.random.normal((batch_size, self.latent_dim))
        misleading_labels = tf.ones((batch_size, 1))
        with tf.GradientTape() as tape:
            generated = self.generator(z, training=True)
            pred = self.discriminator(generated, training=True)
            g_loss = self.loss_fn(misleading_labels, pred)

        grads = tape.gradient(g_loss, self.generator.trainable_weights)
        self.g_optimizer.apply_gradients(zip(grads, self.generator.trainable_weights))

        # Metrics
        # Accuracy is computed from the discriminator predictions made during
        # its own update (before the generator step), real and fake stacked.
        y_true = tf.concat([tf.ones_like(pred_real), tf.zeros_like(pred_fake)], axis=0)
        y_pred = tf.concat([pred_real, pred_fake], axis=0)
        self.d_loss_tracker.update_state(d_loss)
        self.g_loss_tracker.update_state(g_loss)
        self.d_acc_tracker.update_state(y_true, y_pred)

        return {"d_loss": self.d_loss_tracker.result(),
                "g_loss": self.g_loss_tracker.result(),
                "d_acc": self.d_acc_tracker.result()}

In [13]:
# The discriminator ends in a sigmoid, so the loss consumes probabilities
# (from_logits=False is the default, spelled out here for clarity).
bce = tf.keras.losses.BinaryCrossentropy(from_logits=False)

# Separate Adam instances so each network keeps its own optimizer state.
opt_d = tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE, beta_1=BETA_1)
opt_g = tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE, beta_1=BETA_1)

gan = GAN(generator=generator, discriminator=discriminator, latent_dim=LATENT_DIM)
gan.compile(opt_d, opt_g, bce)

# -------------------------
# IMAGE GRID UTILS
# -------------------------
def save_image_grid(images, grid_rows, grid_cols, path):
    """Tile a batch of images into one grid image and write it to ``path``.

    Images are placed row by row. Cells without a corresponding image stay
    black, and images beyond ``grid_rows * grid_cols`` are ignored.

    Args:
        images: Array-like of shape ``(N, H, W, C)`` or ``(N, H, W)`` with
            values nominally in [-1, 1]; out-of-range values are clipped.
        grid_rows: Number of rows in the grid.
        grid_cols: Number of columns in the grid.
        path: Destination file path; the extension selects the format.

    Raises:
        ValueError: If ``images`` is neither 3-D nor 4-D.
        OSError: If OpenCV reports that the file could not be written.
    """
    import cv2  # local import: this helper is the only user of OpenCV here

    images = np.asarray(images)
    if images.ndim == 3:
        # Channel-less batch: treat it as single-channel.
        images = images[..., np.newaxis]
    if images.ndim != 4:
        raise ValueError(
            f"expected images of shape (N, H, W, C) or (N, H, W), got {images.shape}"
        )

    images = (images + 1.0) / 2.0  # [-1,1] → [0,1]
    images = np.clip(images, 0, 1)
    images = (images * 255).astype("uint8")

    count, H, W, C = images.shape
    grid = np.zeros((grid_rows * H, grid_cols * W, C), dtype=np.uint8)

    # Fill cells in row-major order, stopping at whichever runs out first:
    # the images or the grid cells.
    for idx in range(min(count, grid_rows * grid_cols)):
        r, c = divmod(idx, grid_cols)
        grid[r*H:(r+1)*H, c*W:(c+1)*W, :] = images[idx]

    if C == 1:
        to_write = grid.squeeze(-1)
    else:
        # OpenCV expects BGR channel order on disk.
        to_write = cv2.cvtColor(grid, cv2.COLOR_RGB2BGR)

    # cv2.imwrite signals failure through its return value rather than an
    # exception; surface it so callers never report a file that was not saved.
    if not cv2.imwrite(path, to_write):
        raise OSError(f"cv2.imwrite failed to write image grid to {path!r}")

class SampleImagesCallback(tf.keras.callbacks.Callback):
    """Save a grid of generator samples at the end of every epoch.

    The callback expects the model being trained to expose a ``generator``
    attribute. Files are written to ``SAMPLES_DIR`` as ``epoch_NNN.png``
    with a 1-based, zero-padded epoch number.

    Args:
        num_samples: Number of latent vectors (and images) to sample.
        rows: Rows in the saved grid.
        cols: Columns in the saved grid.
        latent_dim: Length of each sampled latent vector.
    """

    def __init__(self, num_samples=25, rows=5, cols=5, latent_dim=LATENT_DIM):
        # Let the base Callback set up its own state before adding ours.
        super().__init__()
        self.num_samples = num_samples
        self.rows = rows
        self.cols = cols
        self.latent_dim = latent_dim

    def on_epoch_end(self, epoch, logs=None):
        """Sample fresh noise, generate images, and write the grid to disk."""
        z = tf.random.normal((self.num_samples, self.latent_dim))
        # training=False: run the generator in inference mode for sampling.
        gen = self.model.generator(z, training=False).numpy()
        path = os.path.join(SAMPLES_DIR, f"epoch_{epoch+1:03d}.png")
        save_image_grid(gen, self.rows, self.cols, path)
        print(f"[Saved] {path}")

In [None]:
# Train for EPOCHS passes over train_ds, writing a sample grid after each
# epoch; verbose=2 prints one summary line per epoch.
history = gan.fit(x=train_ds, epochs=EPOCHS, verbose=2, callbacks=[SampleImagesCallback()])

Epoch 1/10
[Saved] gan_outputs\samples\epoch_001.png
468/468 - 222s - d_loss: 1.3565 - g_loss: 0.8276 - d_acc: 0.6514 - 222s/epoch - 475ms/step
Epoch 2/10
[Saved] gan_outputs\samples\epoch_002.png
468/468 - 230s - d_loss: 1.3607 - g_loss: 0.8033 - d_acc: 0.6383 - 230s/epoch - 492ms/step
Epoch 3/10
[Saved] gan_outputs\samples\epoch_003.png
468/468 - 224s - d_loss: 1.3729 - g_loss: 0.7806 - d_acc: 0.6117 - 224s/epoch - 478ms/step
Epoch 4/10
[Saved] gan_outputs\samples\epoch_004.png
468/468 - 222s - d_loss: 1.3690 - g_loss: 0.7837 - d_acc: 0.6172 - 222s/epoch - 474ms/step
Epoch 5/10


In [None]:
# Write one last 5x5 grid of samples from the generator.
z = tf.random.normal((25, LATENT_DIM))
gen_images = generator(z, training=False).numpy()
final_path = os.path.join(SAMPLES_DIR, "final.png")
save_image_grid(gen_images, 5, 5, final_path)

# Evaluate discriminator accuracy on test set
# Real images should score above 0.5, so the target label is all ones.
metric = tf.keras.metrics.BinaryAccuracy()
for real in test_ds.take(20):  # limit batches
    pred = discriminator(real, training=False)
    metric.update_state(tf.ones_like(pred), pred)
real_accuracy = metric.result().numpy()
print(f"Discriminator accuracy on test real: {real_accuracy:.4f}")

# Fake accuracy
# Generated images should score below 0.5, so the target label is all zeros.
z = tf.random.normal((500, LATENT_DIM))
fakes = generator(z, training=False)
pred = discriminator(fakes, training=False)
metric2 = tf.keras.metrics.BinaryAccuracy()
metric2.update_state(tf.zeros_like(pred), pred)
fake_accuracy = metric2.result().numpy()
print(f"Discriminator accuracy on generated fakes: {fake_accuracy:.4f}")

# Last-epoch values recorded by fit().
final_d_loss = history.history['d_loss'][-1]
final_g_loss = history.history['g_loss'][-1]
final_d_acc = history.history['d_acc'][-1]
print(f"Final D Loss: {final_d_loss:.4f}")
print(f"Final G Loss: {final_g_loss:.4f}")
print(f"Final D Accuracy (train): {final_d_acc:.4f}")