In [None]:
import tensorflow as tf
import os
from tensorflow.keras import layers
from matplotlib import pyplot as plt

In [None]:
from google.colab import drive

# Attach Google Drive at /content/drive so the project folder under MyDrive
# can be read from and written to by this Colab runtime.
drive.mount(mountpoint='/content/drive/')

## Setup and Configuration

In [None]:
def setup_environment(project_dir='/content/drive/MyDrive/MyProjects/ImageGenerationWithGANs'):
    """Report GPU availability and resolve where training checkpoints live.

    Args:
        project_dir: Project folder that holds the ``training_checkpoints``
            sub-directory. Defaults to the Drive project folder this notebook
            ``%cd``s into, so checkpoints are written under the Drive mount.

    Returns:
        ``(checkpoint_dir, checkpoint_prefix)``: the directory later searched
        with ``tf.train.latest_checkpoint`` and the file prefix passed to
        ``Checkpoint.save``.
    """
    if tf.config.list_physical_devices('GPU'):
        print("GPU is available and ready for use.")
    else:
        print("No GPU found. Please check your runtime settings.")

    # Build the path from an explicit project folder: a literal like
    # '/content/drive/../X' normalises to '/content/X', i.e. outside the
    # Drive mount point.
    checkpoint_dir = os.path.join(project_dir, 'training_checkpoints')
    checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")

    return checkpoint_dir, checkpoint_prefix

In [None]:
# Print GPU status and resolve the checkpoint directory and file prefix used
# by the checkpoint-saving and restore cells.
(checkpoint_dir,
 checkpoint_prefix) = setup_environment()

No GPU found. Please check your runtime settings.


## Data Preparation

In [None]:
def prepare_data(batch_size=1024, buffer_size=None):
    """Load the MNIST training images as a shuffled, batched ``tf.data.Dataset``.

    Args:
        batch_size: Number of images per batch (default 1024).
        buffer_size: Shuffle buffer size. ``None`` (default) uses the number of
            training images, so the shuffle covers the whole training set.

    Returns:
        A dataset of float32 image batches shaped ``(batch, 28, 28, 1)`` with
        pixel values scaled to [-1, 1], the same range as the generator's
        tanh output. The final batch may be smaller than ``batch_size``.
    """
    # Only the training images are used; labels and the test split are discarded.
    (x_train, _), _ = tf.keras.datasets.mnist.load_data()
    x_train = x_train.reshape(x_train.shape[0], 28, 28, 1).astype('float32')
    x_train = (x_train - 127.5) / 127.5  # Normalize to [-1, 1]

    if buffer_size is None:
        buffer_size = x_train.shape[0]

    return (tf.data.Dataset.from_tensor_slices(x_train)
            .shuffle(buffer_size)
            .batch(batch_size)
            # Overlap input preparation with the training step.
            .prefetch(tf.data.AUTOTUNE))

In [None]:
# Shuffled, batched MNIST training images scaled to [-1, 1]; consumed by train().
train_dataset = prepare_data(
)

## Model Building

In [None]:
def build_generator(noise_dim=100):
    """Build the generator mapping a latent vector to a 28x28x1 image.

    A bias-free Dense projection is reshaped to 7x7x256 and upsampled by
    transposed convolutions (7x7 -> 7x7 -> 14x14 -> 28x28). The final tanh
    bounds pixels to [-1, 1], the same range produced by ``prepare_data``.

    Args:
        noise_dim: Length of the input latent vector. Defaults to 100, the
            size sampled by ``train_step`` and ``generate_image``.

    Returns:
        An uncompiled ``tf.keras.Sequential`` model.
    """
    model = tf.keras.Sequential([
        layers.Dense(7*7*256, use_bias=False, input_shape=(noise_dim,)),
        layers.BatchNormalization(),
        layers.LeakyReLU(),
        layers.Reshape((7, 7, 256)),

        # 7x7x256 -> 7x7x128 (stride 1: refines features without upsampling)
        layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False),
        layers.BatchNormalization(),
        layers.LeakyReLU(),

        # 7x7x128 -> 14x14x64
        layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False),
        layers.BatchNormalization(),
        layers.LeakyReLU(),

        # 14x14x64 -> 28x28x1, squashed to [-1, 1]
        layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh')
    ])

    return model

def build_discriminator():
    """Return the convolutional discriminator for 28x28x1 images.

    Two downsampling stages (5x5 convs with stride 2, 64 then 128 filters),
    each followed by LeakyReLU and 30% dropout. The result is flattened into a
    single Dense unit with no activation, i.e. a raw logit.
    """
    stack = tf.keras.Sequential()
    for stage, n_filters in enumerate((64, 128)):
        # Only the first convolution declares the input image shape.
        shape_kwargs = {'input_shape': [28, 28, 1]} if stage == 0 else {}
        stack.add(layers.Conv2D(n_filters, (5, 5), strides=(2, 2), padding='same', **shape_kwargs))
        stack.add(layers.LeakyReLU())
        stack.add(layers.Dropout(0.3))

    stack.add(layers.Flatten())
    stack.add(layers.Dense(1))
    return stack

In [None]:
# Build both networks once (generator first); the optimizers, checkpoint and
# training step below refer to these module-level objects.
generator, discriminator = build_generator(), build_discriminator()

## Loss and Optimization Functions

In [None]:
cross_entropy = tf.keras.losses.BinaryCrossentropy(
    from_logits=True,  # the discriminator's final Dense(1) has no activation
)


def discriminator_loss(real_output, fake_output):
    """Discriminator objective: score real images as 1 and generated ones as 0.

    Args:
        real_output: Discriminator logits for a batch of real images.
        fake_output: Discriminator logits for a batch of generated images.

    Returns:
        The binary cross-entropy on the real batch plus the one on the fake batch.
    """
    loss_on_real = cross_entropy(tf.ones_like(real_output), real_output)
    loss_on_fake = cross_entropy(tf.zeros_like(fake_output), fake_output)
    return loss_on_real + loss_on_fake


def generator_loss(fake_output):
    """Generator objective: make the discriminator score generated images as real.

    Args:
        fake_output: Discriminator logits for a batch of generated images.

    Returns:
        Binary cross-entropy of ``fake_output`` against an all-ones target.
    """
    all_real_targets = tf.ones_like(fake_output)
    return cross_entropy(all_real_targets, fake_output)

In [None]:
# One Adam instance per network so their optimizer state stays separate;
# both use a learning rate of 1e-4.
generator_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)

In [None]:
# Switch the working directory to the project folder on Drive, so relative
# output paths (e.g. the PNGs saved by generate_and_save_images) land there.
%cd /content/drive/MyDrive/MyProjects/ImageGenerationWithGANs

/content/drive/MyDrive/MyProjects/ImageGenerationWithGANs


## Checkpointing

In [None]:
# Track both networks and both optimizers so a restore resumes training exactly.
# The keyword names identify each object inside the checkpoint, so keep them
# stable between saving and restoring.
checkpoint = tf.train.Checkpoint(
    generator_optimizer=generator_optimizer,
    discriminator_optimizer=discriminator_optimizer,
    generator=generator,
    discriminator=discriminator,
)

## Training

In [None]:
@tf.function
def train_step(images, noise_dim=100):
    """Perform one simultaneous update of the generator and the discriminator.

    Compiled with ``tf.function`` so each step runs as a graph instead of
    op-by-op in eager mode.

    Args:
        images: One batch of real images from the training dataset (assumed
            ``(batch, 28, 28, 1)`` in [-1, 1], as produced by ``prepare_data``).
        noise_dim: Length of each latent vector; must equal the generator's
            input size (100 for the default ``build_generator()``).

    Returns:
        ``(gen_loss, disc_loss)`` as scalar tensors, for optional logging.
    """
    # Draw one latent vector per real image, whatever the batch size
    # (including a smaller final batch). The real and fake loss terms are then
    # estimated from the same number of samples.
    batch_size = tf.shape(images)[0]
    noise = tf.random.normal([batch_size, noise_dim])

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_images = generator(noise, training=True)
        real_output = discriminator(images, training=True)
        fake_output = discriminator(generated_images, training=True)

        gen_loss = generator_loss(fake_output)
        disc_loss = discriminator_loss(real_output, fake_output)

    # Each loss is differentiated only with respect to its own network's weights.
    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))

    return gen_loss, disc_loss

def train(dataset, epochs, save_every=15):
    """Train the GAN for ``epochs`` full passes over ``dataset``.

    Args:
        dataset: Iterable of real-image batches; each batch is passed to
            ``train_step``.
        epochs: Number of passes over ``dataset``.
        save_every: Write a checkpoint every this many epochs (default 15).
            A checkpoint is also written after the final epoch. Runs shorter
            than ``save_every``, or whose length is not a multiple of it,
            therefore still leave their final weights for
            ``tf.train.latest_checkpoint`` to find.

    Raises:
        ValueError: If ``save_every`` is smaller than 1.
    """
    if save_every < 1:
        raise ValueError(f"save_every must be >= 1, got {save_every}")

    for epoch in range(1, epochs + 1):
        for image_batch in dataset:
            train_step(image_batch)

        if epoch % save_every == 0 or epoch == epochs:
            checkpoint.save(file_prefix=checkpoint_prefix)

        print(f'Epoch {epoch} completed')

In [None]:
# Training is launched from the "Main Execution" section below, once EPOCHS is
# configured there. EPOCHS does not exist yet at this point in a top-to-bottom
# run (Restart & Run All), and calling train() here as well would train the
# model twice.

Epoch 1 completed
Epoch 2 completed
Epoch 3 completed
Epoch 4 completed
Epoch 5 completed
Epoch 6 completed
Epoch 7 completed
Epoch 8 completed
Epoch 9 completed
Epoch 10 completed
Epoch 11 completed
Epoch 12 completed
Epoch 13 completed
Epoch 14 completed
Epoch 15 completed
Epoch 16 completed
Epoch 17 completed
Epoch 18 completed
Epoch 19 completed
Epoch 20 completed
Epoch 21 completed
Epoch 22 completed
Epoch 23 completed
Epoch 24 completed
Epoch 25 completed
Epoch 26 completed
Epoch 27 completed
Epoch 28 completed
Epoch 29 completed
Epoch 30 completed
Epoch 31 completed
Epoch 32 completed
Epoch 33 completed
Epoch 34 completed
Epoch 35 completed
Epoch 36 completed
Epoch 37 completed
Epoch 38 completed
Epoch 39 completed
Epoch 40 completed
Epoch 41 completed
Epoch 42 completed
Epoch 43 completed
Epoch 44 completed
Epoch 45 completed
Epoch 46 completed
Epoch 47 completed
Epoch 48 completed
Epoch 49 completed
Epoch 50 completed
Epoch 51 completed
Epoch 52 completed
Epoch 53 completed
Ep

<Figure size 640x480 with 0 Axes>

## Testing/Inference

In [None]:
def generate_and_save_images(model, epoch, test_input):
    """Render every image generated from ``test_input`` into one grid and save it.

    Args:
        model: Generator, called as ``model(test_input, training=False)``;
            assumed to return images shaped ``(n, H, W, 1)`` in [-1, 1].
        epoch: Integer embedded in the file name ``image_at_epoch_{epoch:04d}.png``,
            which is written relative to the current working directory.
        test_input: Batch of latent vectors, e.g. the fixed ``seed`` batch,
            so that grids from different epochs are directly comparable.

    Raises:
        ValueError: If ``test_input`` yields no images.
    """
    import math  # local import keeps the notebook's top-level imports unchanged

    predictions = model(test_input, training=False)
    n_images = int(predictions.shape[0])
    if n_images == 0:
        raise ValueError("test_input must contain at least one latent vector")

    # Smallest near-square grid that fits all images.
    n_cols = math.ceil(math.sqrt(n_images))
    n_rows = math.ceil(n_images / n_cols)

    fig, axes = plt.subplots(n_rows, n_cols, figsize=(4, 4), squeeze=False)
    for idx, ax in enumerate(axes.flat):
        ax.axis('off')
        if idx < n_images:
            # Fixed vmin/vmax: otherwise imshow rescales each image to its own
            # min/max, which cancels the [-1, 1] -> [0, 255] mapping.
            ax.imshow(predictions[idx, :, :, 0] * 127.5 + 127.5,
                      cmap='gray', vmin=0, vmax=255)

    fig.savefig(f'image_at_epoch_{epoch:04d}.png')
    # Close explicitly so repeated calls do not accumulate open figures.
    plt.close(fig)

def generate_image(generator, noise_dim=100):
    """Sample one latent vector and display the generator's output for it.

    Args:
        generator: Model called as ``generator(noise, training=False)``.
        noise_dim: Length of the latent vector (default 100).
    """
    latent = tf.random.normal([1, noise_dim])
    output = generator(latent, training=False)
    # First (only) image, first channel, mapped from the tanh range back to pixel scale.
    pixels = output[0, :, :, 0] * 127.5 + 127.5
    plt.imshow(pixels, cmap='gray')
    plt.show()

def generate_label_image(generator, label, noise_dim=100):
    """Generate and display one image conditioned on ``label``.

    Requires a conditional generator called as ``generator([noise, label])``.
    The model returned by ``build_generator()`` has a single noise input, so
    that case is rejected up front instead of failing inside Keras.

    Args:
        generator: Conditional generator taking a noise input and a label input.
        label: Class label to condition on (e.g. a digit for MNIST).
        noise_dim: Length of the latent vector (default 100).

    Raises:
        ValueError: If ``generator`` reports fewer than two inputs.
    """
    # Keras models expose their symbolic inputs as ``.inputs``. The attribute
    # may be unavailable (e.g. on an unbuilt model); in that case just try the call.
    model_inputs = getattr(generator, 'inputs', None)
    if model_inputs is not None and len(model_inputs) < 2:
        raise ValueError(
            "generate_label_image() needs a conditional generator taking "
            f"[noise, label], but this model has {len(model_inputs)} input(s); "
            "use generate_image() for an unconditional generator."
        )

    label = tf.convert_to_tensor([label])
    noise = tf.random.normal([1, noise_dim])
    generated_image = generator([noise, label], training=False)
    plt.imshow(generated_image[0, :, :, 0] * 127.5 + 127.5, cmap='gray')
    plt.show()

## Main Execution

In [None]:
# Run configuration; EPOCHS must exist before any train(train_dataset, EPOCHS) call.
EPOCHS = 1000                    # passes over the training dataset
noise_dim = 100                  # latent size; matches build_generator's (100,) input
num_examples_to_generate = 10    # images rendered from the fixed seed batch

# Fixed batch of latent vectors, drawn once so the same points can be rendered
# at different stages of training.
seed = tf.random.normal(shape=[num_examples_to_generate, noise_dim])

In [None]:
# Full training run with the configuration from the cell above.
train(dataset=train_dataset, epochs=EPOCHS)

In [None]:
# Restore the most recent checkpoint written by train(), if there is one.
# tf.train.latest_checkpoint returns None when the directory holds no
# checkpoint, and restoring from None loads nothing, so report that case.
latest_ckpt = tf.train.latest_checkpoint(checkpoint_dir)
if latest_ckpt is None:
    print(f"No checkpoint found in {checkpoint_dir}; using the current in-memory weights.")
else:
    checkpoint.restore(latest_ckpt)
    print(f"Restored weights from {latest_ckpt}")

# Generate and display a sample image
generate_image(generator)

# Label-conditioned sampling (e.g. the digit 5) needs a generator with a
# [noise, label] input pair; the generator built in this notebook takes noise only.
if len(getattr(generator, 'inputs', None) or []) >= 2:
    generate_label_image(generator, label=5)
else:
    print("Generator is unconditional; skipping the label-conditioned sample.")