# In [None]:
import tensorflow as tf
from tensorflow.keras import layers
import numpy as np
import matplotlib.pyplot as plt

# Parameters
# Number of digit classes (0-9); also the length of a one-hot label vector.
NUM_CLASSES = 10
# Length of the random noise vector fed to the generator.
NOISE_DIM = 100
# Height and width of the square images (28x28).
IMAGE_SIZE = 28
# MNIST images are grayscale, i.e. a single channel.
IMAGE_CHANNELS = 1

# Generator Model
def build_generator(label_dim, noise_dim, image_size, image_channels):
    """Build the conditional generator network.

    The model takes one flat vector of length ``label_dim + noise_dim`` and
    produces an image of shape ``(image_size, image_size, image_channels)``
    whose values lie in ``[-1, 1]`` (tanh output).

    Args:
        label_dim: Length of the label encoding part of the input vector.
        noise_dim: Length of the noise part of the input vector.
        image_size: Height/width of the square output image. Must be a
            positive multiple of 4 because the network upsamples twice by a
            factor of 2.
        image_channels: Number of channels of the output image.

    Returns:
        A ``tf.keras.Sequential`` model named ``"Generator"``.

    Raises:
        ValueError: If ``image_size`` is not a positive multiple of 4.
    """
    if image_size <= 0 or image_size % 4 != 0:
        raise ValueError(
            f"image_size must be a positive multiple of 4, got {image_size}"
        )
    # Two stride-2 transposed convolutions each double the resolution, so the
    # dense projection starts at a quarter of the target size (7 for 28x28).
    base_size = image_size // 4

    model = tf.keras.Sequential(name="Generator")
    model.add(layers.Input(shape=(label_dim + noise_dim,)))
    model.add(layers.Dense(base_size * base_size * 256, activation="relu"))
    model.add(layers.Reshape((base_size, base_size, 256)))
    model.add(layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same", activation="relu"))
    model.add(layers.BatchNormalization())
    model.add(layers.Conv2DTranspose(64, (4, 4), strides=(2, 2), padding="same", activation="relu"))
    model.add(layers.BatchNormalization())
    # Final stride-1 layer keeps the resolution and maps to the image channels.
    model.add(layers.Conv2DTranspose(image_channels, (4, 4), strides=(1, 1), padding="same", activation="tanh"))
    return model

# Discriminator Model
def build_discriminator(image_size, image_channels, label_dim):
    """Build the conditional discriminator network.

    The model has two inputs: an image of shape
    ``(image_size, image_size, image_channels)`` and a label vector of length
    ``label_dim``. Its single output unit uses a sigmoid activation, so the
    model returns a value in ``(0, 1)`` per sample rather than a raw logit.

    Returns:
        A functional ``tf.keras.Model`` named ``"Discriminator"``.
    """
    image_in = layers.Input(shape=(image_size, image_size, image_channels))
    label_in = layers.Input(shape=(label_dim,))

    # Image branch: two stride-2 convolutions, then flatten to a vector.
    conv_options = dict(strides=(2, 2), padding="same", activation="relu")
    image_features = layers.Conv2D(64, (4, 4), **conv_options)(image_in)
    image_features = layers.Conv2D(128, (4, 4), **conv_options)(image_features)
    image_features = layers.Flatten()(image_features)

    # Label branch: project the label vector to 128 features.
    label_features = layers.Dense(128, activation="relu")(label_in)

    # Joint head over the concatenated image and label features.
    merged = layers.Concatenate()([image_features, label_features])
    hidden = layers.Dense(128, activation="relu")(merged)
    validity = layers.Dense(1, activation="sigmoid")(hidden)

    return tf.keras.Model(
        inputs=[image_in, label_in],
        outputs=validity,
        name="Discriminator",
    )

# Loss Functions
# The discriminator's last layer applies a sigmoid, so its outputs are already
# probabilities. from_logits must therefore be False; with True the loss would
# apply a second sigmoid to those probabilities and distort the gradients.
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=False)


def generator_loss(fake_output):
    """Return the generator loss for a batch of discriminator outputs.

    Args:
        fake_output: Discriminator outputs (probabilities) for generated
            images.

    Returns:
        Binary cross-entropy of ``fake_output`` against an all-ones target,
        i.e. the loss is low when the discriminator rates fakes as real.
    """
    return cross_entropy(tf.ones_like(fake_output), fake_output)


def discriminator_loss(real_output, fake_output):
    """Return the discriminator loss for a batch of real and fake outputs.

    Args:
        real_output: Discriminator outputs (probabilities) for real images.
        fake_output: Discriminator outputs (probabilities) for generated
            images.

    Returns:
        Sum of the cross-entropy of ``real_output`` against ones and the
        cross-entropy of ``fake_output`` against zeros.
    """
    real_loss = cross_entropy(tf.ones_like(real_output), real_output)
    fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
    return real_loss + fake_loss

# Optimizers
# Separate Adam instances so each network keeps its own optimizer state.
generator_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)

# Training Step
@tf.function
def train_step(generator, discriminator, class_labels, real_images):
    """Run one simultaneous update of the generator and the discriminator.

    Args:
        generator: Model mapping a ``[labels, noise]`` vector to images.
        discriminator: Model scoring ``[images, labels]`` pairs.
        class_labels: Batch of label vectors; cast to float32 here.
        real_images: Batch of real images paired with ``class_labels``.

    Returns:
        A ``(gen_loss, disc_loss)`` tuple for this batch.
    """
    # The batch size is read from the labels at run time, so a smaller final
    # batch is handled without a fixed batch dimension.
    n_samples = tf.shape(class_labels)[0]

    # One noise vector per sample in the batch.
    latent = tf.random.normal([n_samples, NOISE_DIM], dtype=tf.float32)
    labels_f32 = tf.cast(class_labels, tf.float32)

    # Two tapes: each network's gradients are taken from its own loss.
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        # Generator input is the label vector followed by the noise vector.
        generator_input = tf.concat([labels_f32, latent], axis=1)
        fake_images = generator(generator_input, training=True)

        # Score real and generated images under the same labels.
        real_output = discriminator([real_images, labels_f32], training=True)
        fake_output = discriminator([fake_images, labels_f32], training=True)

        gen_loss = generator_loss(fake_output)
        disc_loss = discriminator_loss(real_output, fake_output)

    # Differentiate each loss with respect to its own network's variables.
    gen_gradients = gen_tape.gradient(gen_loss, generator.trainable_variables)
    disc_gradients = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(
        zip(gen_gradients, generator.trainable_variables)
    )
    discriminator_optimizer.apply_gradients(
        zip(disc_gradients, discriminator.trainable_variables)
    )

    return gen_loss, disc_loss


# Training Loop
def train(generator, discriminator, dataset, epochs):
    """Train both networks for ``epochs`` passes over ``dataset``.

    After each epoch the losses of the last processed batch are printed.

    Args:
        generator: Generator model passed through to ``train_step``.
        discriminator: Discriminator model passed through to ``train_step``.
        dataset: Iterable yielding ``(class_labels, real_images)`` batches.
        epochs: Number of passes over ``dataset``.

    Returns:
        None.
    """
    for epoch in range(epochs):
        # Reset per epoch so an empty (or already exhausted) dataset cannot
        # leave the names unbound or report losses from an earlier epoch.
        gen_loss = disc_loss = None
        for class_labels, real_images in dataset:
            gen_loss, disc_loss = train_step(generator, discriminator, class_labels, real_images)
        if gen_loss is None:
            print(f"Epoch {epoch+1}: dataset produced no batches; nothing was trained.")
            continue
        print(f"Epoch {epoch+1}, Generator Loss: {gen_loss.numpy()}, Discriminator Loss: {disc_loss.numpy()}")


# Generate and Display an Image for a Given Digit
def generate_image(generator, digit):
    """Generate one image conditioned on ``digit`` and show it with matplotlib.

    Args:
        generator: Model mapping a ``[one-hot label, noise]`` vector to images.
        digit: Class index used to build the one-hot label.
    """
    latent = tf.random.normal([1, NOISE_DIM])
    condition = tf.one_hot([digit], depth=NUM_CLASSES)

    # Same input layout as in training: label first, then noise.
    generator_input = tf.concat([condition, latent], axis=1)
    sample = generator(generator_input, training=False)

    # Rescale to [0, 1] for display (assumes generator output is in [-1, 1]).
    pixels = (sample + 1) / 2.0

    # Plot the first (only) sample's first channel.
    plt.imshow(pixels[0, :, :, 0], cmap="gray")
    plt.axis("off")
    plt.title(f"Generated Image for Digit {digit}")
    plt.show()

# Load MNIST Dataset
from tensorflow.keras.datasets import mnist
from tensorflow.keras.utils import to_categorical
(train_images, train_labels), (_, _) = mnist.load_data()
# Scale pixel values from [0, 255] to [-1, 1]. Casting to float32 first keeps
# the arithmetic from promoting the array to float64, which would double its
# memory footprint and require a cast back to float32 on every batch.
train_images = (train_images.astype(np.float32) - 127.5) / 127.5
# Add a trailing channel axis: (N, H, W) -> (N, H, W, 1).
train_images = np.expand_dims(train_images, axis=-1)
# One-hot encode the integer labels.
train_labels = to_categorical(train_labels, num_classes=NUM_CLASSES)

# Create Dataset
batch_size = 64
# A shuffle buffer as large as the dataset gives a full shuffle each epoch.
# Elements are (labels, images), the order the training loop unpacks.
dataset = (
    tf.data.Dataset.from_tensor_slices((train_labels, train_images))
    .shuffle(len(train_images))
    .batch(batch_size)
)

# Instantiate Models
generator = build_generator(
    label_dim=NUM_CLASSES,
    noise_dim=NOISE_DIM,
    image_size=IMAGE_SIZE,
    image_channels=IMAGE_CHANNELS,
)
discriminator = build_discriminator(
    image_size=IMAGE_SIZE,
    image_channels=IMAGE_CHANNELS,
    label_dim=NUM_CLASSES,
)

# Train Models
# The batch size is set when the dataset is built, so it is not passed here.
train(generator, discriminator, dataset, epochs=10)

# Generate Images for User Input
# Prompt repeatedly until the user types "exit".
while True:
    # Strip surrounding whitespace so inputs like " 3" or "exit " are accepted.
    user_input = input("Enter a digit (0-9) to generate its image, or 'exit' to quit: ").strip()
    if user_input.lower() == "exit":
        break
    # str.isdecimal() only accepts characters that int() can parse, whereas
    # str.isdigit() also accepts e.g. superscript digits ("²"), which would
    # make int() raise ValueError.
    elif user_input.isdecimal() and 0 <= int(user_input) <= 9:
        generate_image(generator, int(user_input))
    else:
        print("Please enter a valid digit between 0 and 9.")
