In [37]:
# @title Simple Example: Conditional GAN
'''
We'll use a simplified version of a Conditional GAN that generates handwritten digits based on labels (0–9).
'''

"\nWe'll use a simplified version of a Conditional GAN that generates handwritten digits based on labels (0–9).\n"

In [38]:
# @title Step 1: Import required Libraries
import numpy as np
import matplotlib.pyplot as plt

# Import TensorFlow and Keras modules
import tensorflow as tf
from tensorflow.keras.datasets import mnist # Preloaded dataset of handwritten digits (0-9)
from tensorflow.keras.layers import Input, Dense, Flatten, Reshape, Concatenate
from tensorflow.keras.models import Model

In [39]:
# @title Step 2: Load and Normalize the Dataset

# Load MNIST dataset (images and labels)
(X_train, y_train), (_, _) = mnist.load_data()

# Normalize pixel values to range [-1,1] for better GAN performance
X_train = (X_train.astype(np.float32) - 127.5) / 127.5

# Flatten images from 28x28 to 784 vectors
X_train = X_train.reshape(X_train.shape[0], -1)

print("Training samples", X_train.shape[0])

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
[1m11490434/11490434[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 0us/step
Training samples 60000


In [40]:
# @title Step 3: Define Dimensions and Hyperparamters
latent_dim = 100 # Size of random noise vector
num_classes = 10 # Digits 0-9
img_shape = (784,) # Flattened image shap

In [41]:
# @title Step 4: Build the Generator

# Generator: Takes noise + label as input, ouput fake image

# Noise input
noise_input = Input(shape = (latent_dim,))
# Label Input
label_input = Input(shape=(1,), dtype = 'int32')

# Convert label to one-hot vector
label_embedding = tf.keras.layers.Embedding(num_classes, 50)(label_input)
label_embedding = Flatten()(label_embedding)

# Concatenate noise and label vector
merged_input = Concatenate()([noise_input,label_embedding])

# Dense layers to generate image
x = Dense(128, activation='relu')(merged_input)
x = Dense(256, activation='relu')(x)
x = Dense(np.prod(img_shape), activation='tanh')(x)

# Reshape back to image
generated_image = Reshape(img_shape)(x)

# Define generator model
generator = Model([noise_input, label_input], generated_image)
generator.summary()

In [42]:
# @title Step 5: Build the Discriminator

# Discriminator: Takes image + label and outputs real/fake image

# Image input
img_input = Input(shape = img_shape)
# Label Input
label_input = Input(shape=(1,), dtype = 'int32') # a single integer per sample

# Convert label to one-hot embedding
label_embedding = tf.keras.layers.Embedding(num_classes, 50)(label_input)
label_embedding = Flatten()(label_embedding)

# Concatenate Image and Label
merged_input = Concatenate()([img_input,label_embedding])

# Desnse Layers for Classification
x = Dense(256, activation = 'relu')(merged_input)
x = Dense(128, activation = 'relu')(x)
x = Dense(1, activation = 'sigmoid')(x) # Output probablity (real or fake)

# Define discriminator model
discriminator = Model([img_input, label_input], x)

# Compile Discriminator
discriminator.compile(loss = 'binary_crossentropy', optimizer = 'adam', metrics = ['accuracy'])
discriminator.summary()

In [43]:
# @title Step 6: Discriminator during generator training
discriminator.trainable = False

# Connect generator + discriminator
noise = Input(shape =(latent_dim,))
label = Input(shape=(1,), dtype = 'int32')

# Generate Image
generated_img = generator([noise,label])

# Discriminate Image
validity = discriminator([generated_img, label])

# Define combined CGAN model
# CGAN model is used to train the generator using discriminator feedback:
cgan = Model([noise, label], validity)
cgan.compile(loss = 'binary_crossentropy', optimizer = 'adam')
cgan.summary()

In [44]:
# @title Step 7 - Training Loop
epochs = 5000
batch_size = 128

for epoch in range(epochs):

  # Training Discriminiator

  # Select the random batch of real images
  idx = np.random.randint(0, X_train.shape[0], batch_size)
  real_imgs = X_train[idx]
  labels = y_train[idx]

  # Generate fake images
  noise = np.random.normal(0, 1, (batch_size, latent_dim))
  gen_imgs = generator.predict([noise, labels])


[1m4/4[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 15ms/step
[1m4/4[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 11ms/step
[1m4/4[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 10ms/step
[1m4/4[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 10ms/step
[1m4/4[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 10ms/step
[1m4/4[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 10ms/step
[1m4/4[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 11ms/step
[1m4/4[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 10ms/step
[1m4/4[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 8ms/step 
[1m4/4[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 8ms/step 
[1m4/4[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 8ms/step 
[1m4/4[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 8ms/step 
[1m4/4[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 9ms/step 
[1m4/4[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 8m

KeyboardInterrupt: 

In [None]:
# Generate sample images for a specific label (e.g., '7')
label = np.array([[7]])
noise = np.random.normal(0, 1, (1, latent_dim))

# Generate fake image for label '7'
gen_img = generator.predict([noise, label])[0]

# Reshape and display
plt.imshow(gen_img.reshape(28, 28), cmap='gray')
plt.title("Generated Digit: 7")
plt.axis('off')
plt.show()