In [1]:
import os
import sys
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import wandb

In [2]:
project_root = os.path.abspath(os.path.join(os.getcwd(), ".."))
if project_root not in sys.path:
    sys.path.append(project_root)

In [3]:
config_defaults = {
    "epochs": 50,
    "batch_size": 128,
    "latent_dim": 100,
    "learning_rate": 0.0002,
    "beta_1": 0.5,
    "patience": 20
}

wandb.init(project="dcgan", config=config_defaults)
config = wandb.config

print(config)

wandb: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
wandb: Currently logged in as: joshuacox924007 (joshuacox924007-atlas-school) to https://api.wandb.ai. Use `wandb login --relogin` to force relogin


{'epochs': 50, 'batch_size': 128, 'latent_dim': 100, 'learning_rate': 0.0002, 'beta_1': 0.5, 'patience': 20}


In [4]:
from models.dcgan_baseline import build_generator, build_discriminator

data_path = os.path.abspath(os.path.join('..', 'data', 'mnist_preprocessed.npz'))
data = np.load(data_path)
x_train = data['x_train']
img_shape = x_train.shape[1:]

In [5]:
generator = build_generator(config.latent_dim, k_size=5, filter_size=32, s1=2, s2=2)
discriminator = build_discriminator(img_shape, k_size=5, alpha=0.1, s=1)

optimizer = tf.keras.optimizers.Adam(learning_rate=config.learning_rate, beta_1=config.beta_1)
discriminator.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['accuracy'])

discriminator.trainable = False
noise_input = tf.keras.Input(shape=(config.latent_dim,))
generated_image = generator(noise_input)
validity = discriminator(generated_image)
combined = tf.keras.Model(noise_input, validity)
combined.compile(loss='binary_crossentropy', optimizer=optimizer)

In [6]:
def log_generated_images(generator, epoch, examples=16, dim=(4, 4), figsize=(4, 4)):
    noise = np.random.normal(0, 1, (examples, config.latent_dim))
    generated_images = generator.predict(noise)
    generated_images = (generated_images + 1) / 2.0

    fig, axs = plt.subplots(dim[0], dim[1], figsize=figsize, sharex=True, sharey=True)
    cnt = 0
    for i in range(dim[0]):
        for j in range(dim[1]):
            axs[i, j].imshow(generated_images[cnt, :, :, 0], cmap='gray')
            axs[i, j].axis('off')
            cnt += 1
    plt.suptitle(f"Epoch {epoch}")
    wandb.log({"generated_images": wandb.Image(fig, caption=f"Epoch {epoch}")})
    plt.close(fig)

In [7]:
num_batches = x_train.shape[0] // config.batch_size

best_g_loss = float('inf')
patience_counter = 0

In [8]:
for epoch in range(1, config.epochs + 1):
    d_loss_epoch = 0.0
    g_loss_epoch = 0.0
    d_acc_epoch = 0.0

    for batch in range(num_batches):
        # ---------------------
        #  Train Discriminator
        # ---------------------
        idx = np.random.randint(0, x_train.shape[0], config.batch_size)
        real_imgs = x_train[idx]

        noise = np.random.normal(0, 1, (config.batch_size, config.latent_dim))
        fake_imgs = generator.predict(noise)

        valid = np.ones((config.batch_size, 1))
        fake = np.zeros((config.batch_size, 1))

        d_loss_real = discriminator.train_on_batch(real_imgs, valid)
        d_loss_fake = discriminator.train_on_batch(fake_imgs, fake)
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

        # ---------------------
        #  Train Generator
        # ---------------------
        noise = np.random.normal(0, 1, (config.batch_size, config.latent_dim))
        g_loss = combined.train_on_batch(noise, valid)

        d_loss_epoch += d_loss[0]
        d_acc_epoch += d_loss[1]
        g_loss_epoch += g_loss

    d_loss_epoch /= num_batches
    d_acc_epoch /= num_batches
    g_loss_epoch /= num_batches

    print(f"Epoch {epoch}/{config.epochs} [D loss: {d_loss_epoch:.4f}, D acc: {d_acc_epoch*100:.2f}%] [G loss: {g_loss_epoch:.4f}]")
    wandb.log({
        "epoch": epoch,
        "d_loss": d_loss_epoch,
        "d_accuracy": d_acc_epoch,
        "g_loss": g_loss_epoch
    })

    if epoch == 1 or epoch % 5 == 0:
        log_generated_images(generator, epoch)

    if g_loss_epoch < best_g_loss:
        best_g_loss = g_loss_epoch
        patience_counter = 0
    else:
        patience_counter += 1
        if patience_counter >= config.patience:
            print(f"Early stopping triggered at epoch {epoch}")
            wandb.log({"early_stopping": True, "stopped_epoch": epoch})
            break

Epoch 1/50 [D loss: 0.3580, D acc: 87.67%] [G loss: 0.7935]
Epoch 2/50 [D loss: 0.6226, D acc: 65.99%] [G loss: 1.2668]
Epoch 3/50 [D loss: 0.5952, D acc: 67.99%] [G loss: 1.2514]
Epoch 4/50 [D loss: 0.6289, D acc: 64.36%] [G loss: 1.0857]
Epoch 5/50 [D loss: 0.6133, D acc: 66.89%] [G loss: 1.0784]
Epoch 6/50 [D loss: 0.5908, D acc: 69.59%] [G loss: 1.1474]
Epoch 7/50 [D loss: 0.5740, D acc: 70.96%] [G loss: 1.2309]
Epoch 8/50 [D loss: 0.5571, D acc: 72.31%] [G loss: 1.3114]
Epoch 9/50 [D loss: 0.5444, D acc: 73.03%] [G loss: 1.3672]
Epoch 10/50 [D loss: 0.5356, D acc: 73.56%] [G loss: 1.4260]
Epoch 11/50 [D loss: 0.5287, D acc: 73.97%] [G loss: 1.4770]
Epoch 12/50 [D loss: 0.5233, D acc: 74.12%] [G loss: 1.5036]
Epoch 13/50 [D loss: 0.5131, D acc: 74.91%] [G loss: 1.5518]
Epoch 14/50 [D loss: 0.5069, D acc: 75.34%] [G loss: 1.5938]
Epoch 15/50 [D loss: 0.5021, D acc: 75.52%] [G loss: 1.6186]
Epoch 16/50 [D loss: 0.5010, D acc: 75.61%] [G loss: 1.6460]
Epoch 17/50 [D loss: 0.4949, D ac