# Imports

In [1]:
import os
import sys
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import wandb

# Set path

In [2]:
project_root = os.path.abspath(os.path.join(os.getcwd(), ".."))
if project_root not in sys.path:
    sys.path.append(project_root)

# Hyperparameters

In [4]:
config_defaults = {
    "epochs": 50,
    "batch_size": 128,
    "latent_dim": 100,
    "learning_rate": 0.0002,
    "beta_1": 0.5,
    "patience": 10
}

wandb.init(project="dcgan", config=config_defaults)
config = wandb.config

print(config)

wandb: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
wandb: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: No netrc file found, creating one.
wandb: Appending key for api.wandb.ai to your netrc file: C:\Users\Joshua\_netrc
wandb: Currently logged in as: joshuacox924007 (joshuacox924007-atlas-school) to https://api.wandb.ai. Use `wandb login --relogin` to force relogin


{'epochs': 50, 'batch_size': 128, 'latent_dim': 100, 'learning_rate': 0.0002, 'beta_1': 0.5, 'patience': 10}


# Load Model

In [5]:
from models.dcgan_baseline import build_generator, build_discriminator

data_path = os.path.abspath(os.path.join('..', 'data', 'mnist_preprocessed.npz'))
data = np.load(data_path)
x_train = data['x_train']
img_shape = x_train.shape[1:]

# Generate the Combined Model

In [8]:
generator = build_generator(config.latent_dim)
discriminator = build_discriminator(img_shape, k_size=10)

optimizer = tf.keras.optimizers.Adam(learning_rate=config.learning_rate, beta_1=config.beta_1)
discriminator.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['accuracy'])

discriminator.trainable = False
noise_input = tf.keras.Input(shape=(config.latent_dim,))
generated_image = generator(noise_input)
validity = discriminator(generated_image)
combined = tf.keras.Model(noise_input, validity)
combined.compile(loss='binary_crossentropy', optimizer=optimizer)

# Function to generate image log

In [9]:
def log_generated_images(generator, epoch, examples=16, dim=(4, 4), figsize=(4, 4)):
    noise = np.random.normal(0, 1, (examples, config.latent_dim))
    generated_images = generator.predict(noise)
    generated_images = (generated_images + 1) / 2.0

    fig, axs = plt.subplots(dim[0], dim[1], figsize=figsize, sharex=True, sharey=True)
    cnt = 0
    for i in range(dim[0]):
        for j in range(dim[1]):
            axs[i, j].imshow(generated_images[cnt, :, :, 0], cmap='gray')
            axs[i, j].axis('off')
            cnt += 1
    plt.suptitle(f"Epoch {epoch}")
    wandb.log({"generated_images": wandb.Image(fig, caption=f"Epoch {epoch}")})
    plt.close(fig)

# Calculate number of epochs per batch

In [10]:
num_batches = x_train.shape[0] // config.batch_size

best_g_loss = float('inf')
patience_counter = 0

# Training loop

In [11]:
for epoch in range(1, config.epochs + 1):
    d_loss_epoch = 0.0
    g_loss_epoch = 0.0

    for batch in range(num_batches):
        # ---------------------
        #  Train Discriminator
        # ---------------------
        idx = np.random.randint(0, x_train.shape[0], config.batch_size)
        real_imgs = x_train[idx]

        noise = np.random.normal(0, 1, (config.batch_size, config.latent_dim))
        fake_imgs = generator.predict(noise)

        valid = np.ones((config.batch_size, 1))
        fake = np.zeros((config.batch_size, 1))

        d_loss_real = discriminator.train_on_batch(real_imgs, valid)
        d_loss_fake = discriminator.train_on_batch(fake_imgs, fake)
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

        # ---------------------
        #  Train Generator
        # ---------------------
        noise = np.random.normal(0, 1, (config.batch_size, config.latent_dim))
        g_loss = combined.train_on_batch(noise, valid)

        d_loss_epoch += d_loss[0]
        g_loss_epoch += g_loss

    d_loss_epoch /= num_batches
    g_loss_epoch /= num_batches

    print(f"Epoch {epoch}/{config.epochs} [D loss: {d_loss_epoch:.4f}] [G loss: {g_loss_epoch:.4f}]")
    wandb.log({
        "epoch": epoch,
        "d_loss": d_loss_epoch,
        "g_loss": g_loss_epoch
    })

    if epoch == 1 or epoch % 5 == 0:
        log_generated_images(generator, epoch)

    if g_loss_epoch < best_g_loss:
        best_g_loss = g_loss_epoch
        patience_counter = 0
    else:
        patience_counter += 1
        if patience_counter >= config.patience:
            print(f"Early stopping triggered at epoch {epoch}")
            wandb.log({"early_stopping": True, "stopped_epoch": epoch})
            break

Epoch 1/50 [D loss: 0.5169] [G loss: 0.5940]
Epoch 2/50 [D loss: 0.7126] [G loss: 0.7786]
Epoch 3/50 [D loss: 0.6677] [G loss: 0.7958]
Epoch 4/50 [D loss: 0.6539] [G loss: 0.8311]
Epoch 5/50 [D loss: 0.6774] [G loss: 0.7987]
Epoch 6/50 [D loss: 0.6781] [G loss: 0.7820]
Epoch 7/50 [D loss: 0.6767] [G loss: 0.7773]
Epoch 8/50 [D loss: 0.6764] [G loss: 0.7762]
Epoch 9/50 [D loss: 0.6748] [G loss: 0.7788]
Epoch 10/50 [D loss: 0.6717] [G loss: 0.7852]
Epoch 11/50 [D loss: 0.6679] [G loss: 0.7941]
Early stopping triggered at epoch 11
