# Wasserstein Generative Adversarial Networks With Gradient Penalty (WGAN-GP)

In [None]:
import time
import functools as ft
from pathlib import Path
from tqdm.auto import tqdm

import mlx.optimizers as optim
import mlx.nn as nn
import mlx.core as mx

from wgan_gp.dataset import load_celeba
from wgan_gp.utils import grid_image_from_batch, ensure_exists

# Constants

In [None]:
SEED = 42
mx.random.seed(SEED)  # seed MLX's global PRNG so runs are reproducible

# Optimisation schedule.
BATCH_SIZE, NUM_EPOCHS = 128, 100
CRITIC_STEPS = 5  # the generator is updated once every CRITIC_STEPS batches
GP_WEIGHT = 10.0  # weight of the gradient-penalty term in the critic loss
Z_DIM = 100  # size of the latent vector fed to the generator

# (height, width, channels); height/width drive the resize step of the pipeline.
IMAGE_SHAPE = (64, 64, 3)

# Every run writes its artifacts into its own timestamped directory.
EXP_TIME = time.strftime("%I-%M%p_%B-%d-%Y")
SAVE_DIR = Path("./artifacts").joinpath(EXP_TIME)
SAVE_EVERY_EPOCH = 10

# Loading CelebA Dataset

In [None]:
data_buf = load_celeba(split="train")
total_samples = len(data_buf)

# Input pipeline: shuffle the buffer, stream it, resize each image to
# IMAGE_SHAPE's height/width, map pixel values 0..255 to [-1, 1] via
# (x - 127.5) / 127.5, batch, and prefetch with background threads.
# NOTE(review): the buffer is shuffled once here. If `data.reset()` replays the
# same order every epoch, consider reshuffling per epoch. Confirm against the
# mlx-data semantics.
data = (
    data_buf
    .shuffle()
    .to_stream()
    .image_resize("image", h=IMAGE_SHAPE[0], w=IMAGE_SHAPE[1])
    .key_transform("image", lambda x: (x.astype("float32") - 127.5) / 127.5)
    .batch(BATCH_SIZE)
    .prefetch(prefetch_size=8, num_threads=8)
)

# Preview one batch. `data_iter` was previously used without being defined
# (NameError). The training loop calls `data.reset()` before each epoch, so
# consuming a batch here does not skip data during training.
data_iter = iter(data)
batch = next(data_iter)
grid_image_from_batch(batch["image"], num_rows=8)

# Losses

In [None]:
def c_loss_fn(c_model, g_model, real_data):
    """Critic (discriminator) objective for WGAN-GP.

    A fresh batch of fakes is generated from standard-normal latents of size
    ``Z_DIM``. The critic is scored on both real and fake batches.

    Returns:
        ``(total, wasserstein, penalty)``, where
        ``wasserstein = mean(critic(fake)) - mean(critic(real))`` and
        ``total = wasserstein + GP_WEIGHT * penalty``.
    """
    n = real_data.shape[0]
    fakes = g_model(mx.random.normal(shape=(n, Z_DIM)))

    # Minimising (fake - real) pushes real scores up and fake scores down.
    mean_real_score = c_model(real_data).mean()
    mean_fake_score = c_model(fakes).mean()
    wasserstein = mean_fake_score - mean_real_score

    penalty = gradient_penalty(c_model, real_data, fakes)
    return wasserstein + GP_WEIGHT * penalty, wasserstein, penalty

def gradient_penalty(c_model, real_data, fake_data):
    """Return the WGAN-GP gradient penalty for ``c_model``.

    Each point is a random interpolate between a real sample and its paired
    fake sample. The critic's input gradient is taken at each point, and the
    squared deviation of each per-sample gradient L2 norm from 1 is penalised.

    Args:
        c_model: critic; its outputs on a batch are summed before differentiating.
        real_data: batch of real samples with the batch on axis 0.
        fake_data: generated batch with the same shape as ``real_data``.

    Returns:
        Scalar ``mx.array``: the batch mean of ``(||grad||_2 - 1) ** 2``.
    """
    batch_size = real_data.shape[0]
    # One mixing weight per sample, drawn from U[0, 1) so each point lies on the
    # segment between its real and fake sample. A standard-normal draw would
    # give weights outside [0, 1], i.e. extrapolations. Trailing singleton axes
    # let alpha broadcast over inputs of any rank.
    alpha = mx.random.uniform(shape=(batch_size,) + (1,) * (real_data.ndim - 1))
    interpo_data = (1 - alpha) * real_data + alpha * fake_data
    # Differentiating the summed outputs yields per-sample input gradients.
    # NOTE(review): this assumes the critic has no cross-sample layers
    # (e.g. BatchNorm). Confirm in wgan_gp.models.
    grads = mx.grad(lambda x: c_model(x).sum())(interpo_data)
    grad_norm = mx.linalg.norm(grads.reshape(batch_size, -1), axis=-1)
    c_gp = mx.mean((grad_norm - 1.0) ** 2)
    return c_gp

In [None]:
def g_loss_fn(c_model, g_model, batch_size):
    """Generator objective: the negated mean critic score of freshly generated fakes.

    ``batch_size`` standard-normal latent vectors of size ``Z_DIM`` are drawn
    and decoded by ``g_model``. Minimising the result raises the critic's
    scores for the fakes.
    """
    latents = mx.random.normal(shape=(batch_size, Z_DIM))
    critic_scores = c_model(g_model(latents))
    return -(critic_scores.mean())

# Models

In [None]:
from wgan_gp.models import Critic, Generator

# Build the critic first, then the generator (order fixes which random draws
# each network's initial weights consume).
c_model = Critic(input_dim=3, output_dim=1)
g_model = Generator(input_dim=Z_DIM, output_dim=3)

# MLX is lazy: force both networks' initial parameters to be materialised now.
mx.eval(c_model.parameters(), g_model.parameters())

# Optimizers

In [None]:
def _make_adamw():
    """Return a new AdamW optimizer with the hyper-parameters shared by both networks."""
    return optim.AdamW(learning_rate=1e-4, betas=[0.9, 0.999], weight_decay=0.01)


# NOTE(review): the WGAN-GP paper's Algorithm 1 uses Adam with betas (0.0, 0.9)
# and no weight decay. The values here are a deliberate choice; revisit if
# training is unstable.
c_optim = _make_adamw()
g_optim = _make_adamw()

# Metrics

In [None]:
from dataclasses import dataclass, field, replace


def _history():
    """Dataclass field whose default is a fresh, per-instance empty list."""
    return field(default_factory=list)


@dataclass
class Metrics:
    """Per-epoch training history: one list per tracked quantity."""

    c_loss: list[float] = _history()  # total critic loss
    g_loss: list[float] = _history()  # generator loss
    c_wass_loss: list[float] = _history()  # critic Wasserstein term
    c_gp: list[float] = _history()  # critic gradient penalty
    thrp: list[float] = _history()  # throughput, images/second

metrics = Metrics()

# Training

In [None]:
def train_epoch(c_model, c_loss_fn, c_optim,
                g_model, g_loss_fn, g_optim,
                data, epoch=None):
    """Run one WGAN-GP training epoch over ``data`` and return averaged metrics.

    The critic is updated on every batch. The generator is updated on every
    ``CRITIC_STEPS``-th batch, starting with the first.

    Args:
        c_model, c_loss_fn, c_optim: critic network, its loss function (which
            returns ``(loss, wasserstein_term, gradient_penalty)``) and optimizer.
        g_model, g_loss_fn, g_optim: generator network, its loss function and
            optimizer.
        data: iterable of batches, each a mapping with an ``"image"`` entry.
        epoch: optional epoch index, used only to label the progress bar.

    Returns:
        dict mapping ``"c_loss"``, ``"g_loss"``, ``"c_wass_loss"``, ``"c_gp"``
        and ``"thrp"`` (images/second) to their epoch mean as Python floats.
        ``g_loss`` is averaged over generator updates only. A key with no
        recorded values maps to ``nan``.
    """
    # State read and written by the compiled critic step. The generator's state
    # is included because the critic loss samples fakes from it.
    c_state = [
        c_model.state, c_optim.state,
        g_model.state, mx.random.state
    ]

    @ft.partial(mx.compile, inputs=c_state, outputs=c_state)
    def train_critic(x):
        loss_and_grad_fn_c = nn.value_and_grad(c_model, c_loss_fn)
        (c_loss, c_wass_loss, c_gp), c_grads = loss_and_grad_fn_c(c_model, g_model, x)
        c_optim.update(c_model, c_grads)
        return c_loss, c_wass_loss, c_gp

    g_state = [
        c_model.state, c_optim.state,
        g_model.state, g_optim.state,
        mx.random.state
    ]

    @ft.partial(mx.compile, inputs=g_state, outputs=g_state)
    def train_generator(batch_size):
        loss_and_grad_fn_g = nn.value_and_grad(g_model, g_loss_fn)
        g_loss, g_grads = loss_and_grad_fn_g(c_model, g_model, batch_size)
        g_optim.update(g_model, g_grads)
        return g_loss

    history = {"c_loss": [], "g_loss": [], "c_wass_loss": [], "c_gp": [], "thrp": []}

    desc = None if epoch is None else f"epoch {epoch}"
    with tqdm(total=total_samples, desc=desc) as pbar:
        for batch_counter, batch in enumerate(data):
            tic = time.perf_counter()
            x = mx.array(batch["image"])
            n_images = x.shape[0]

            c_loss, c_wass_loss, c_gp = train_critic(x)
            # Evaluate the returned losses together with the updated state so
            # they are materialised now rather than kept as lazy graphs.
            mx.eval(c_loss, c_wass_loss, c_gp, c_state)

            # Only update the generator once every CRITIC_STEPS critic updates.
            g_loss = None
            if batch_counter % CRITIC_STEPS == 0:
                g_loss = train_generator(n_images)
                mx.eval(g_loss, g_state)
            toc = time.perf_counter()

            history["c_loss"].append(c_loss.item())
            history["c_wass_loss"].append(c_wass_loss.item())
            history["c_gp"].append(c_gp.item())
            # Record g_loss only on steps that actually trained the generator,
            # so stale values are not repeated in the average.
            if g_loss is not None:
                history["g_loss"].append(g_loss.item())
            history["thrp"].append(n_images / (toc - tic))

            pbar.update(n_images)

    return {
        key: (sum(values) / len(values) if values else float("nan"))
        for key, values in history.items()
    }

In [None]:
def _as_float(value):
    """Convert a scalar ``mx.array`` (or any real number) to a Python float."""
    return value.item() if isinstance(value, mx.array) else float(value)


for epoch in range(NUM_EPOCHS):

    # Rewind the stream so every epoch iterates the full dataset.
    data.reset()

    # `epoch` is not passed positionally: train_epoch's required parameters end
    # at `data`, and an extra positional argument raised TypeError.
    losses = train_epoch(
        c_model, c_loss_fn, c_optim,
        g_model, g_loss_fn, g_optim,
        data)
    losses = {k: _as_float(v) for k, v in losses.items()}

    for k, v in losses.items():
        getattr(metrics, k).append(v)

    # Single quotes inside the f-strings keep this valid before Python 3.12.
    print("-" * 120)
    print(
        " | ".join([
            f"Epoch: {epoch:02d}",
            f"avg. Critic loss: {losses['c_loss']:.3f}",
            f"avg. Generator loss: {losses['g_loss']:.3f}",
            f"avg. Throughput: {losses['thrp']:.2f} images/second",
            f"avg. Wasserstein loss: {losses['c_wass_loss']:.3f}",
            f"avg. Gradient penalty: {losses['c_gp']:.3f}",
        ])
    )
    print("-" * 120)

    # Save a grid of samples every SAVE_EVERY_EPOCH epochs and after the last one.
    if epoch % SAVE_EVERY_EPOCH == 0 or epoch == NUM_EPOCHS - 1:
        z = mx.random.normal(shape=(BATCH_SIZE, Z_DIM))
        fake_images = g_model(z)
        img = grid_image_from_batch(fake_images, num_rows=8)
        # Create the same directory that the image is written to.
        image_dir = SAVE_DIR / "gen_images"
        ensure_exists(image_dir)
        img.save(image_dir / f"image_{epoch}.png")

# Visualize the metrics

In [None]:
import matplotlib.pyplot as plt  # Visualization
from matplotlib.ticker import MaxNLocator

fig, ax = plt.subplots(nrows=2, ncols=2, figsize=(15, 9))

# (subplot, metric name, line colour) for each panel of the 2x2 grid.
panels = [
    (ax[0, 0], "c_loss", "r"),
    (ax[0, 1], "g_loss", "b"),
    (ax[1, 0], "c_wass_loss", "g"),
    (ax[1, 1], "c_gp", "m"),
]
for panel, name, color in panels:
    # Metric entries may be scalar mx.arrays or floats; plot plain floats.
    values = [
        v.item() if isinstance(v, mx.array) else float(v)
        for v in getattr(metrics, name)
    ]
    # Derive the x-range from the recorded history so the plot still works
    # when training stopped before NUM_EPOCHS.
    panel.plot(range(1, len(values) + 1), values, color)
    panel.title.set_text(name)
    panel.set_xlabel("epoch")
    panel.xaxis.set_major_locator(MaxNLocator(integer=True))  # integer x-axis ticks

ensure_exists(SAVE_DIR / "metrics")
fig.savefig(SAVE_DIR / "metrics" / "training.png")