Designing a Conditional GAN for CIFAR-10 Image Generation

In [15]:
import numpy as np
import matplotlib.pyplot as plt
import time
from tensorflow.keras.datasets import cifar10
from tensorflow.keras.layers import Input, Dense, Reshape, Flatten, Dropout, multiply, Embedding, Activation
from tensorflow.keras.layers import BatchNormalization, LeakyReLU, Conv2D, Conv2DTranspose
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.utils import to_categorical

# --- Hyper-parameters -------------------------------------------------
num_classes = 10                      # one conditioning slot per CIFAR-10 class
img_rows, img_cols, channels = 32, 32, 3
img_shape = (img_rows, img_cols, channels)
latent_dim = 100                      # length of the generator's noise vector

# Human-readable class names, indexed by integer class id
class_names = [
    "airplane",
    "automobile",
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck",
]

# --- Data -------------------------------------------------------------
# Fetch both splits; only the training split is transformed below.
(x_train, y_train), (x_test, y_test) = cifar10.load_data()

# Rescale training pixels to [-1, 1] (assumes raw values span 0-255),
# which is the same range the generator's tanh output covers.
x_train = (x_train.astype(np.float32) - 127.5) / 127.5

# Turn integer class ids into one-hot rows of length num_classes
y_train = to_categorical(y_train, num_classes)

def build_generator():
    """Build the label-conditioned generator.

    A Sequential stack maps a ``latent_dim`` vector to an 8x8x128 tensor,
    doubles the spatial resolution twice with stride-2 transposed
    convolutions (8 -> 16 -> 32), and finishes with a tanh-activated
    ``channels``-channel convolution.  The vector fed into that stack is the
    element-wise product of the noise and a Dense projection of the label.

    The stack's summary is printed as a side effect.

    Returns:
        A ``Model`` mapping ``[noise, label]`` to the generated image tensor.
    """
    # Dense foundation reshaped into 8x8 feature maps
    stack = [
        Dense(128 * 8 * 8, activation="relu", input_dim=latent_dim),
        Reshape((8, 8, 128)),
        BatchNormalization(momentum=0.8),
    ]

    # Two upsampling stages: 8x8 -> 16x16 (128 filters), 16x16 -> 32x32 (64 filters)
    for filters in (128, 64):
        stack.extend([
            Conv2DTranspose(filters, kernel_size=4, strides=2, padding="same"),
            LeakyReLU(alpha=0.2),
            BatchNormalization(momentum=0.8),
        ])

    # Project down to the image channels; tanh keeps outputs in [-1, 1]
    stack.append(Conv2D(channels, kernel_size=3, padding="same"))
    stack.append(Activation("tanh"))

    synthesis_net = Sequential(stack)
    synthesis_net.summary()

    z_in = Input(shape=(latent_dim,))
    label_in = Input(shape=(num_classes,), dtype='float32')

    # Project the label vector to latent_dim and gate the noise with it
    # (element-wise product, not a concatenation).
    label_projection = Dense(latent_dim)(label_in)
    conditioned_z = multiply([z_in, label_projection])

    generated = synthesis_net(conditioned_z)
    return Model([z_in, label_in], generated)

def build_discriminator():
    """Build the label-conditioned discriminator.

    The label is projected by a Dense layer to ``img_rows * img_cols`` values,
    reshaped into a single-channel map, and multiplied element-wise with the
    image.  The product runs through four convolution blocks, is flattened,
    and a final sigmoid unit scores it.

    The convolutional stack's summary is printed as a side effect.

    Returns:
        A ``Model`` mapping ``[img, label]`` to a single sigmoid output.
    """
    # (filters, stride, append BatchNormalization after the dropout?)
    conv_plan = (
        (32, 2, False),
        (64, 2, True),
        (128, 2, True),
        (256, 1, False),
    )

    feature_net = Sequential()
    for position, (filters, stride, with_batchnorm) in enumerate(conv_plan):
        # Only the first layer declares the input shape
        first_layer_kwargs = {"input_shape": img_shape} if position == 0 else {}
        feature_net.add(Conv2D(filters, kernel_size=3, strides=stride,
                               padding="same", **first_layer_kwargs))
        feature_net.add(LeakyReLU(alpha=0.2))
        feature_net.add(Dropout(0.25))
        if with_batchnorm:
            feature_net.add(BatchNormalization(momentum=0.8))

    feature_net.add(Flatten())
    feature_net.summary()

    img_in = Input(shape=img_shape)
    label_in = Input(shape=(num_classes,), dtype='float32')

    # Expand the label vector into an img_rows x img_cols single-channel map
    label_map = Dense(img_rows * img_cols)(label_in)
    label_map = Reshape((img_rows, img_cols, 1))(label_map)

    # Gate the image with the label map (element-wise product, not a concatenation)
    conditioned_img = multiply([img_in, label_map])

    features = feature_net(conditioned_img)
    score = Dense(1, activation='sigmoid')(features)

    return Model([img_in, label_in], score)

# Build and compile the discriminator.
# It is compiled here, before `trainable` is switched off below, so that the
# stand-alone discriminator is the model used for its own training steps.
discriminator = build_discriminator()
discriminator.compile(loss=['binary_crossentropy'],
                      optimizer=Adam(0.0002, 0.5),
                      metrics=['accuracy'])

# Build the generator (not compiled on its own; it is trained via `combined`)
generator = build_generator()

# The generator takes noise and the target label as input
# and generates the corresponding image
noise = Input(shape=(latent_dim,))
label = Input(shape=(num_classes,))
img = generator([noise, label])

# For the combined model we will only train the generator
# NOTE(review): this relies on Keras capturing the `trainable` state at
# compile time, so the already-compiled `discriminator` keeps learning while
# its weights stay frozen inside `combined` -- confirm this holds for the
# installed Keras version.
discriminator.trainable = False

# The discriminator takes the generated image and the same label as input
# and determines validity
valid = discriminator([img, label])

# The combined model (stacked generator and discriminator)
# Trains the generator to fool the discriminator
combined = Model([noise, label], valid)
combined.compile(loss=['binary_crossentropy'],
                optimizer=Adam(0.0002, 0.5))

def sample_images(generator, epoch):
    """Generate a 2x5 grid of class-conditioned samples and save it as a PNG.

    One image is generated per grid cell; the class ids cycle through
    ``0 .. num_classes - 1`` in row-major order, so with 10 classes every
    class appears exactly once.  The figure is written to
    ``cifar10_cgan_epoch_{epoch}.png`` in the current working directory.

    Args:
        generator: Model accepting ``[noise, one_hot_labels]`` and returning
            images scaled to [-1, 1].
        epoch: Value interpolated into the output file name.
    """
    r, c = 2, 5
    n_samples = r * c
    noise = np.random.normal(0, 1, (n_samples, latent_dim))

    # Cover every class instead of repeating 0..c-1 on each row, which left
    # the last five classes out of the preview entirely.
    class_ids = np.arange(n_samples) % num_classes
    sampled_labels = to_categorical(class_ids, num_classes)

    gen_imgs = generator.predict([noise, sampled_labels])

    # Rescale images from [-1, 1] to [0, 1] for imshow
    gen_imgs = 0.5 * gen_imgs + 0.5

    fig, axs = plt.subplots(r, c)
    # axs.flat walks the grid row by row, matching the order of class_ids
    for ax, image, class_id in zip(axs.flat, gen_imgs, class_ids):
        ax.imshow(image)
        ax.set_title(f"{class_names[class_id]}")
        ax.axis('off')
    fig.savefig(f"cifar10_cgan_epoch_{epoch}.png")
    # Close this specific figure so repeated calls do not accumulate figures
    plt.close(fig)

def train(generator, discriminator, combined, epochs, batch_size=128, sample_interval=50):
    # Adversarial ground truths
    valid = np.ones((batch_size, 1))
    fake = np.zeros((batch_size, 1))

    for epoch in range(epochs):
        start_time = time.time()
        
        # ---------------------
        #  Train Discriminator
        # ---------------------

        # Select a random batch of images
        idx = np.random.randint(0, x_train.shape[0], batch_size)
        imgs = x_train[idx]
        labels = y_train[idx]

        # Sample noise and generate a batch of new images
        noise = np.random.normal(0, 1, (batch_size, latent_dim))
        gen_imgs = generator.predict([noise, labels])

        # Train the discriminator
        d_loss_real = discriminator.train_on_batch([imgs, labels], valid)
        d_loss_fake = discriminator.train_on_batch([gen_imgs, labels], fake)
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

        # ---------------------
        #  Train Generator
        # ---------------------

        # Sample noise and labels
        noise = np.random.normal(0, 1, (batch_size, latent_dim))
        sampled_labels = to_

  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


  super().__init__(activity_regularizer=activity_regularizer, **kwargs)
