In [2]:
import tensorflow as tf
from tensorflow.keras import layers
import matplotlib.pyplot as plt

# Load MNIST dataset
(train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()
train_images = (train_images.astype('float32') - 127.5) / 127.5
train_images = train_images.reshape(train_images.shape[0], 28, 28, 1)

# Generator Model
def build_generator():
    model = tf.keras.Sequential([
        layers.Dense(256, input_shape=(110,)),
        layers.BatchNormalization(),
        layers.LeakyReLU(),
        layers.Dense(512),
        layers.BatchNormalization(),
        layers.LeakyReLU(),
        layers.Dense(1024),
        layers.BatchNormalization(),
        layers.LeakyReLU(),
        layers.Dense(28 * 28, activation="tanh"),
        layers.Reshape((28, 28, 1))
    ])
    return model

# Discriminator Model
def build_discriminator():
    model = tf.keras.Sequential([
        layers.Flatten(input_shape=(28, 28, 11)),
        layers.Dense(512),
        layers.LeakyReLU(),
        layers.Dense(256),
        layers.LeakyReLU(),
        layers.Dense(1, activation='sigmoid')
    ])
    return model

# Concatenate noise with label embeddings
def concatenate_label_and_noise(labels, noise):
    label_embeddings = tf.one_hot(labels, depth=10)
    noise_with_labels = tf.concat([noise, label_embeddings], axis=1)
    return noise_with_labels

# Train cGAN
def train_cgan(generator, discriminator, epochs=50, batch_size=64, noise_dim=100):
    generator_optimizer = tf.keras.optimizers.Adam(1e-4)
    discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)

    for epoch in range(epochs):
        for i in range(int(train_images.shape[0] / batch_size)):
            # Get real images and labels
            real_images = train_images[i*batch_size:(i+1)*batch_size]
            labels = train_labels[i*batch_size:(i+1)*batch_size]
            noise = tf.random.normal([batch_size, noise_dim])
            noise_with_labels = concatenate_label_and_noise(labels, noise)

            # Train Discriminator
            with tf.GradientTape() as tape:
                generated_images = generator(noise_with_labels)
                real_input = tf.concat([real_images, tf.one_hot(labels, 10)], axis=-1)
                fake_input = tf.concat([generated_images, tf.one_hot(labels, 10)], axis=-1)
                real_output = discriminator(real_input)
                fake_output = discriminator(fake_input)
                d_loss = tf.keras.losses.BinaryCrossentropy()(tf.ones_like(real_output), real_output) + \
                         tf.keras.losses.BinaryCrossentropy()(tf.zeros_like(fake_output), fake_output)
            gradients = tape.gradient(d_loss, discriminator.trainable_variables)
            discriminator_optimizer.apply_gradients(zip(gradients, discriminator.trainable_variables))

            # Train Generator
            noise_with_labels = concatenate_label_and_noise(labels, noise)
            with tf.GradientTape() as tape:
                generated_images = generator(noise_with_labels)
                fake_input = tf.concat([generated_images, tf.one_hot(labels, 10)], axis=-1)
                fake_output = discriminator(fake_input)
                g_loss = tf.keras.losses.BinaryCrossentropy()(tf.ones_like(fake_output), fake_output)
            gradients = tape.gradient(g_loss, generator.trainable_variables)
            generator_optimizer.apply_gradients(zip(gradients, generator.trainable_variables))

        print(f'Epoch: {epoch+1}, D Loss: {d_loss}, G Loss: {g_loss}')

    # Generate conditional images after training
    noise = tf.random.normal([16, noise_dim])
    labels = tf.constant([i for i in range(10)] * 2)
    noise_with_labels = concatenate_label_and_noise(labels, noise)
    generated_images = generator(noise_with_labels)

    for i in range(16):
        plt.subplot(4, 4, i+1)
        plt.imshow(generated_images[i, :, :, 0] * 127.5 + 127.5, cmap="gray")
        plt.axis('off')
    plt.show()

# Initialize and Train
generator = build_generator()
discriminator = build_discriminator()
train_cgan(generator, discriminator)


InvalidArgumentError: {{function_node __wrapped__ConcatV2_N_2_device_/job:localhost/replica:0/task:0/device:CPU:0}} ConcatOp : Ranks of all input tensors should match: shape[0] = [64,28,28,1] vs. shape[1] = [64,10] [Op:ConcatV2] name: concat