In [1]:
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D, Conv2DTranspose, LeakyReLU, Dropout, Flatten, Dense, Reshape
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam

In [2]:
# Fix every RNG (Python, NumPy, TensorFlow) so batch sampling and weight init are reproducible.
SEED = 42
tf.keras.utils.set_random_seed(SEED)

(X_train, _), (_, _) = tf.keras.datasets.mnist.load_data()
X_train = (X_train.astype(np.float32) / 127.5) - 1.0  # uint8 [0, 255] -> [-1, 1], the range of the generators' tanh output
X_train = np.expand_dims(X_train, axis=-1)  # add channel axis: (N, 28, 28) -> (N, 28, 28, 1)

# Two domains: A (original digits), B (intensity-inverted digits).
# Negation in [-1, 1] space is equivalent to 255 - pixel in uint8 space.
X_A = X_train
X_B = -X_train

img_shape = (28, 28, 1)
assert X_train.shape[1:] == img_shape, X_train.shape

latent_dim = 100  # NOTE(review): not referenced by the CycleGAN cells shown here (generators take images, not noise)
optimizer = Adam(0.0002, 0.5)  # learning_rate=2e-4, beta_1=0.5 (common GAN setting)

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
[1m11490434/11490434[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 0us/step


In [3]:
def build_generator(img_shape):
    """Build an encoder/decoder image-to-image generator.

    Two stride-2 ``Conv2D`` stages (64, 128 filters) downsample the input;
    two stride-2 ``Conv2DTranspose`` stages (128, 64 filters) upsample it back.
    A final 3x3 ``Conv2D`` with ``tanh`` produces a single-channel image in [-1, 1].
    Every hidden stage is followed by ``LeakyReLU(0.2)``.

    Args:
        img_shape: input shape tuple, e.g. ``(28, 28, 1)``.

    Returns:
        A Keras ``Model``. With "same" padding, spatial dims divisible by 4
        come out the same size as they went in.
    """
    net_in = Input(shape=img_shape)
    h = net_in

    # Downsampling path.
    for n_filters in (64, 128):
        h = Conv2D(n_filters, kernel_size=3, strides=2, padding="same")(h)
        h = LeakyReLU(0.2)(h)

    # Upsampling path (mirror of the encoder).
    for n_filters in (128, 64):
        h = Conv2DTranspose(n_filters, kernel_size=3, strides=2, padding="same")(h)
        h = LeakyReLU(0.2)(h)

    net_out = Conv2D(1, kernel_size=3, padding="same", activation="tanh")(h)
    return Model(net_in, net_out)


In [4]:
def build_discriminator(img_shape):
    """Build a small convolutional real/fake classifier.

    Two stride-2 ``Conv2D`` + ``LeakyReLU(0.2)`` stages (64, 128 filters),
    ``Dropout(0.25)``, then a single sigmoid unit on the flattened features.

    Args:
        img_shape: input shape tuple, e.g. ``(28, 28, 1)``.

    Returns:
        A Keras ``Model`` mapping an image batch to one score in (0, 1) per image.
    """
    img = Input(shape=img_shape)
    features = img

    for n_filters in (64, 128):
        features = Conv2D(n_filters, kernel_size=3, strides=2, padding="same")(features)
        features = LeakyReLU(0.2)(features)

    features = Dropout(0.25)(features)
    features = Flatten()(features)

    # NOTE(review): LSGAN-style (mse) discriminators usually use a linear output;
    # the sigmoid is kept here as in the original design.
    validity = Dense(1, activation="sigmoid")(features)
    return Model(img, validity)

In [5]:
G_AB = build_generator(img_shape)
G_BA = build_generator(img_shape)

D_A = build_discriminator(img_shape)
D_B = build_discriminator(img_shape)

D_A.compile(loss="mse", optimizer=optimizer, metrics=["accuracy"])
D_B.compile(loss="mse", optimizer=optimizer, metrics=["accuracy"])

In [6]:
img_A = Input(shape=img_shape)
img_B = Input(shape=img_shape)

fake_B = G_AB(img_A)
fake_A = G_BA(img_B)

reconstr_A = G_BA(fake_B)
reconstr_B = G_AB(fake_A)

img_A_id = G_BA(img_A)
img_B_id = G_AB(img_B)

D_A.trainable = False
D_B.trainable = False

valid_A = D_A(fake_A)
valid_B = D_B(fake_B)

combined = Model(inputs=[img_A, img_B],
                 outputs=[valid_A, valid_B,
                          reconstr_A, reconstr_B,
                          img_A_id, img_B_id])

combined.compile(loss=["mse", "mse", "mae", "mae", "mae", "mae"],
                 loss_weights=[1,1,10,10,1,1],
                 optimizer=optimizer)

In [7]:
def train(epochs=10000, batch_size=64, save_interval=1000):
    valid = np.ones((batch_size, 1))
    fake = np.zeros((batch_size, 1))

    d_losses, g_losses = [], []

    for epoch in range(1, epochs+1):
        idx = np.random.randint(0, X_A.shape[0], batch_size)
        imgs_A = X_A[idx]
        imgs_B = X_B[idx]

        fake_B = G_AB.predict(imgs_A, verbose=0)
        fake_A = G_BA.predict(imgs_B, verbose=0)

        dA_loss_real = D_A.train_on_batch(imgs_A, valid)
        dA_loss_fake = D_A.train_on_batch(fake_A, fake)
        dA_loss = 0.5 * np.add(dA_loss_real, dA_loss_fake)

        dB_loss_real = D_B.train_on_batch(imgs_B, valid)
        dB_loss_fake = D_B.train_on_batch(fake_B, fake)
        dB_loss = 0.5 * np.add(dB_loss_real, dB_loss_fake)

        d_loss = 0.5 * np.add(dA_loss, dB_loss)

        g_loss = combined.train_on_batch([imgs_A, imgs_B],
                                         [valid, valid, imgs_A, imgs_B, imgs_A, imgs_B])

        d_losses.append(d_loss[0])
        g_losses.append(g_loss[0])

        if epoch % 100 == 0:
            print(f"{epoch} [D loss: {d_loss[0]:.4f}, acc: {100*d_loss[1]:.2f}] [G loss: {g_loss[0]:.4f}]")

        if epoch % save_interval == 0:
            save_imgs(epoch)

    return d_losses, g_losses


In [8]:
def save_imgs(epoch, examples=5):
    """Display a grid of sample translations for visual inspection.

    Each row is one randomly drawn training example; the columns are
    ``A``, ``G_AB(A)``, ``B``, ``G_BA(B)``. Despite the name, the figure is
    shown inline, not written to disk.

    Args:
        epoch: training iteration number, used only in the figure title.
        examples: number of rows (random examples) to draw.
    """
    idx = np.random.randint(0, X_A.shape[0], examples)
    imgs_A = X_A[idx]
    imgs_B = X_B[idx]

    fake_B = G_AB.predict(imgs_A, verbose=0)
    fake_A = G_BA.predict(imgs_B, verbose=0)

    # One batch per column, rescaled from the tanh range [-1, 1] to [0, 1].
    # (Previously the four batches were concatenated and laid out row-major,
    # so each row mixed images from one batch under four different titles.)
    columns = [0.5 * batch + 0.5 for batch in (imgs_A, fake_B, imgs_B, fake_A)]
    titles = ["A", "A→B", "B", "B→A"]

    fig, axes = plt.subplots(examples, len(columns),
                             figsize=(2 * len(columns), 2 * examples),
                             squeeze=False)
    for row in range(examples):
        for col, (batch, title) in enumerate(zip(columns, titles)):
            ax = axes[row, col]
            # Fixed vmin/vmax: don't let imshow stretch each image's contrast independently.
            ax.imshow(batch[row, :, :, 0], cmap="gray", vmin=0, vmax=1)
            ax.set_title(title)
            ax.axis("off")
    fig.suptitle(f"Epoch {epoch}")
    plt.show()
    plt.close(fig)


In [None]:
d_losses, g_losses = train(epochs=2000, batch_size=64, save_interval=500)

# train() counts single mini-batch updates as "epochs", so the x-axis is iterations.
# The generator curve is the total weighted combined loss (adversarial + 10x cycle
# + identity), so its scale is not directly comparable to the discriminator's.
fig, ax = plt.subplots(figsize=(8, 6))
ax.plot(d_losses, label="Discriminator Loss")
ax.plot(g_losses, label="Generator Loss")
ax.set_xlabel("Iterations")
ax.set_ylabel("Loss")
ax.set_title("CycleGAN Training Losses")
ax.legend()
plt.show()



100 [D loss: 0.2555, acc: 37.65] [G loss: 8.5332]
