In [None]:
import tensorflow as tf
from tensorflow.keras import layers, models
import numpy as np
import matplotlib.pyplot as plt

# Load CIFAR-10 dataset
# Only the training images are kept; the labels and the whole test split are
# discarded because nothing below uses them (the GAN is trained on images only).
(x_train, _), (_, _) = tf.keras.datasets.cifar10.load_data()
# Rescale pixels to [0, 1] so real images share the range of the generator's
# sigmoid output. Assumes 8-bit pixel values in [0, 255] as returned by Keras.
x_train = x_train.astype('float32') / 255.0

# Define generator model
def build_generator(latent_dim):
    """Create the (uncompiled) generator network.

    A latent noise vector of length ``latent_dim`` is projected by a dense
    layer to a 4x4x256 feature map. Three stride-2 transposed convolutions
    then double the spatial size each time (4 -> 8 -> 16 -> 32). The result
    is a 32x32 image with 3 channels and values in [0, 1] (sigmoid output).

    Args:
        latent_dim: Length of the input noise vector.

    Returns:
        A ``Sequential`` model mapping ``(batch, latent_dim)`` noise to images.
    """
    stack = [
        # Project the noise vector and reshape it into a small spatial map.
        layers.Dense(4 * 4 * 256, input_dim=latent_dim),
        layers.LeakyReLU(alpha=0.2),
        layers.Reshape((4, 4, 256)),
    ]

    # Two hidden upsampling stages with 128 filters each: 4x4 -> 8x8 -> 16x16.
    for _ in range(2):
        stack.append(layers.Conv2DTranspose(128, kernel_size=4, strides=2, padding='same'))
        stack.append(layers.LeakyReLU(alpha=0.2))

    # Output stage: 16x16 -> 32x32, 3 channels; sigmoid keeps pixels in [0, 1].
    stack.append(
        layers.Conv2DTranspose(3, kernel_size=4, strides=2, padding='same', activation='sigmoid')
    )
    return models.Sequential(stack)

# Define discriminator model
def build_discriminator(input_shape):
    """Create the (uncompiled) discriminator network.

    Two downsampling blocks are applied to an image of shape ``input_shape``.
    Each block is a stride-2 3x3 convolution (64, then 128 filters), followed
    by LeakyReLU and 40% dropout. The features are then flattened into a
    single sigmoid unit that produces one score in (0, 1) per image.

    Args:
        input_shape: Shape of one input image, e.g. ``(32, 32, 3)``.

    Returns:
        A ``Sequential`` model mapping a batch of images to ``(batch, 1)`` scores.
    """
    def conv_block(filters, **conv_kwargs):
        # Halve the spatial size, then apply LeakyReLU and dropout.
        return [
            layers.Conv2D(filters, kernel_size=3, strides=2, padding='same', **conv_kwargs),
            layers.LeakyReLU(alpha=0.2),
            layers.Dropout(0.4),
        ]

    # Only the very first layer declares the input shape.
    head = [layers.Flatten(), layers.Dense(1, activation='sigmoid')]
    return models.Sequential(
        conv_block(64, input_shape=input_shape) + conv_block(128) + head
    )

# Define GAN model
def build_gan(generator, discriminator):
    """Stack ``generator`` and ``discriminator`` into one (uncompiled) model.

    The returned model maps latent noise -> generated image -> discriminator
    score. It is used to train the generator.

    Side effect: sets ``discriminator.trainable = False`` on the model that is
    passed in, so the stacked model does not update the discriminator's
    weights. Keras applies ``trainable`` when a model is compiled. A
    discriminator compiled *before* this call therefore keeps training through
    its own ``train_on_batch``.

    Args:
        generator: Model producing images from latent vectors.
        discriminator: Model scoring images; frozen in place by this call.

    Returns:
        A ``Sequential`` model ``[generator, discriminator]``.
    """
    # Freeze before stacking so the combined model only trains the generator.
    discriminator.trainable = False
    return models.Sequential([generator, discriminator])

# Define constants
# Length of the random noise vector fed to the generator.
latent_dim = 100
# Shape of one image seen by the discriminator. It must agree with the data
# and with the generator's output: a 4x4 map upsampled three times by
# stride-2 transposed convs gives 32x32, with 3 output channels.
input_shape = (32, 32, 3)

# Build and compile discriminator
# ORDER MATTERS: the discriminator is compiled here, before build_gan() sets
# `discriminator.trainable = False` below. The script relies on Keras honoring
# the trainable state captured at compile time. That lets
# discriminator.train_on_batch(...) keep updating its weights, while the `gan`
# model (compiled after the freeze) updates only the generator.
# NOTE(review): confirm this still holds if migrating to Keras 3.
discriminator = build_discriminator(input_shape)
discriminator.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

# Build generator
# Not compiled on its own in this script. It is trained through the stacked
# `gan` model and otherwise only used via predict().
generator = build_generator(latent_dim)

# Build and compile GAN model
# generator -> frozen discriminator; training it updates only the generator.
gan = build_gan(generator, discriminator)
gan.compile(optimizer='adam', loss='binary_crossentropy')

# Training loop
# Each step does two things. First, it trains the discriminator on one batch
# of generated images (label 1) concatenated with one batch of real images
# (label 0). Second, it trains the generator through the stacked `gan` model,
# using target 0 ("real") for freshly generated images.
epochs = 10
batch_size = 64
steps_per_epoch = x_train.shape[0] // batch_size
sample_interval = 1  # plot generated samples every `sample_interval` epochs
samples = 10  # number of sample images to generate per plot

# The targets are identical on every step, so build them once.
# Label convention used here: generated = 1, real = 0.
base_labels = np.concatenate([np.ones((batch_size, 1)), np.zeros((batch_size, 1))])
# Misleading targets for the generator: ask the discriminator to call fakes "real" (0).
misleading_targets = np.zeros((batch_size, 1))

for epoch in range(epochs):
    for step in range(steps_per_epoch):
        # Sample random points in latent space and generate fake images.
        # verbose=0: predict() otherwise prints a progress bar on every call,
        # i.e. once per training step.
        random_latent_vectors = np.random.normal(size=(batch_size, latent_dim))
        generated_images = generator.predict(random_latent_vectors, verbose=0)

        # Fake images first, then real ones, matching the order of base_labels.
        real_images = x_train[np.random.randint(0, x_train.shape[0], batch_size)]
        combined_images = np.concatenate([generated_images, real_images])

        # Fresh uniform label noise in [0, 0.05) on every step.
        # NOTE(review): this puts fake targets in [1, 1.05), slightly outside the
        # [0, 1] range binary cross-entropy assumes; kept to preserve the recipe.
        labels = base_labels + 0.05 * np.random.random(base_labels.shape)

        # Train discriminator
        d_loss = discriminator.train_on_batch(combined_images, labels)

        # Train generator (via GAN model) on a fresh batch of latent points.
        random_latent_vectors = np.random.normal(size=(batch_size, latent_dim))
        a_loss = gan.train_on_batch(random_latent_vectors, misleading_targets)

        # Print progress. d_loss is [loss, accuracy] because the discriminator
        # was compiled with an accuracy metric; a_loss is a scalar.
        if step % 100 == 0:
            print(f'Epoch {epoch+1}/{epochs}, Step {step+1}/{steps_per_epoch}, D Loss: {d_loss[0]}, G Loss: {a_loss}')

    # Generate and show sample images every `sample_interval` epochs.
    if epoch % sample_interval == 0:
        latent_points = np.random.normal(size=(samples, latent_dim))
        # Generator output is sigmoid, i.e. in (0, 1); scale to 0-255 for uint8 display.
        generated_images = generator.predict(latent_points, verbose=0) * 255

        # A single row of images, so size the figure for one row rather than a square canvas.
        plt.figure(figsize=(samples, 1.5))
        for i in range(samples):
            plt.subplot(1, samples, i + 1)
            plt.imshow(generated_images[i].astype('uint8'))  # Convert to uint8 for display
            plt.axis('off')
        plt.show()
