In [52]:
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns; sns.set(); sns.set_style('dark')

import os
import datetime

from sklearn.model_selection import train_test_split
import tensorflow as tf

In [53]:
# Only the CIFAR-10 training split is used; the test split is discarded.
(x_train, y_train), _ = tf.keras.datasets.cifar10.load_data()

In [54]:
cat_class = 3  # CIFAR-10 label id used here to select cat images

# Keep only the cat images, then map uint8 pixels [0, 255] to [-1, 1]
# so real images match the range of the generator's tanh output.
cat_images = (x_train[y_train.ravel() == cat_class].astype('float32') - 127.5) / 127.5

In [55]:
latent_dim = 100  # Length of the noise vector fed to the generator

# Seed Python, NumPy and TensorFlow RNGs so weight initialisation, batch
# sampling and noise are reproducible under Restart & Run All.
# (Some GPU kernels may still be nondeterministic.)
SEED = 42
tf.keras.utils.set_random_seed(SEED)

In [56]:
# Generator: latent noise vector -> 32x32x3 image in [-1, 1] (tanh output).
# An explicit Input replaces the `input_dim=` layer argument, which Keras 3
# deprecates inside Sequential models; this form works in tf.keras 2.x too.
generator = tf.keras.Sequential([
    tf.keras.Input(shape=(latent_dim,)),
    tf.keras.layers.Dense(8 * 8 * 256),
    tf.keras.layers.LeakyReLU(alpha=0.2),
    tf.keras.layers.Reshape((8, 8, 256)),  # Reshape to 8x8x256
    tf.keras.layers.BatchNormalization(momentum=0.8),
    tf.keras.layers.Conv2DTranspose(128, kernel_size=4, strides=2, padding='same'),  # Upsample to 16x16x128
    tf.keras.layers.LeakyReLU(alpha=0.2),
    tf.keras.layers.BatchNormalization(momentum=0.8),
    tf.keras.layers.Conv2DTranspose(64, kernel_size=4, strides=2, padding='same'),  # Upsample to 32x32x64
    tf.keras.layers.LeakyReLU(alpha=0.2),
    tf.keras.layers.BatchNormalization(momentum=0.8),
    tf.keras.layers.Conv2DTranspose(3, kernel_size=3, strides=1, padding='same', activation='tanh')  # Output 32x32x3 image
])

# Discriminator: 32x32x3 image -> probability that the image is real.
# An explicit Input replaces the `input_shape=` layer argument, which Keras 3
# deprecates inside Sequential models; this form works in tf.keras 2.x too.
discriminator = tf.keras.Sequential([
    tf.keras.Input(shape=(32, 32, 3)),
    tf.keras.layers.Conv2D(64, kernel_size=3, strides=2, padding='same'),  # -> 16x16x64
    tf.keras.layers.LeakyReLU(alpha=0.2),
    tf.keras.layers.Conv2D(128, kernel_size=3, strides=2, padding='same'),  # -> 8x8x128
    tf.keras.layers.LeakyReLU(alpha=0.2),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(1, activation='sigmoid')  # Binary classification (real/fake)
])


In [57]:
def compile_gan(generator, discriminator):
    """Compile ``discriminator`` on its own, then stack it after ``generator``.

    The discriminator is compiled first (binary cross-entropy, accuracy
    metric). Its ``trainable`` flag is then cleared before the stacked
    ``generator -> discriminator`` model is compiled, so that training the
    stacked model is meant to update only the generator.

    NOTE(review): this relies on Keras keeping the trainable state that was in
    effect when each model was compiled (tf.keras 2.x behaviour); confirm the
    discriminator still learns under Keras 3.

    Returns:
        The compiled stacked model, mapping a noise batch to real/fake
        probabilities.
    """
    def _adam():
        # Each model gets its own optimizer instance with identical settings.
        return tf.keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5)

    discriminator.compile(
        optimizer=_adam(),
        loss='binary_crossentropy',
        metrics=['accuracy'],
    )

    discriminator.trainable = False

    stacked = tf.keras.Sequential()
    stacked.add(generator)
    stacked.add(discriminator)
    stacked.compile(optimizer=_adam(), loss='binary_crossentropy')
    return stacked

In [58]:
gan = compile_gan(generator=generator, discriminator=discriminator)  # stacked G -> D model

In [59]:
def save_generated_images(epoch, generator, examples=16, dim=(4, 4), figsize=(6, 6),
                          output_dir="generated", noise_dim=None):
    """Sample images from ``generator`` and save them as one grid PNG.

    The file is written to ``<output_dir>/generated_image_epoch_<epoch>.png``.

    Args:
        epoch: training step number; used only in the output filename.
        generator: model whose ``predict`` maps an ``(examples, noise_dim)``
            noise batch to images. Images are assumed to be in [-1, 1]; the
            notebook's generator ends in a tanh.
        examples: number of images to sample.
        dim: ``(rows, cols)`` of the subplot grid; must hold ``examples`` cells.
        figsize: matplotlib figure size in inches.
        output_dir: directory for the PNG; created if it does not exist.
        noise_dim: length of each noise vector; defaults to the notebook's
            ``latent_dim``.

    Raises:
        ValueError: if the grid ``dim`` has fewer cells than ``examples``.
    """
    if noise_dim is None:
        noise_dim = latent_dim
    rows, cols = dim
    if examples > rows * cols:
        raise ValueError(f"dim={dim} has only {rows * cols} cells for {examples} examples")

    noise = np.random.normal(0, 1, (examples, noise_dim))
    # verbose=0 keeps a progress bar out of the training log on every call.
    generated_images = generator.predict(noise, verbose=0)
    generated_images = 0.5 * generated_images + 0.5  # Rescale [-1, 1] to [0, 1]

    # savefig does not create missing directories; without this the first
    # call fails with FileNotFoundError on a fresh checkout.
    os.makedirs(output_dir, exist_ok=True)

    fig = plt.figure(figsize=figsize)
    for i in range(examples):
        ax = fig.add_subplot(rows, cols, i + 1)
        ax.imshow(generated_images[i])
        ax.axis('off')
    fig.tight_layout()
    fig.savefig(os.path.join(output_dir, f"generated_image_epoch_{epoch}.png"))
    plt.close(fig)  # called repeatedly during training; release the figure

In [60]:
def train_gan(epochs, batch_size, interval):
    """Train the DCGAN on ``cat_images`` with alternating D / G updates.

    Each "epoch" here is ONE training step on a single mini-batch sampled
    with replacement. It is not a full pass over the data.

    Uses the notebook globals ``generator``, ``discriminator``, ``gan``,
    ``cat_images``, ``latent_dim`` and ``save_generated_images``.

    Args:
        epochs: number of training steps to run.
        batch_size: number of real images (and of fake images) per step.
        interval: every ``interval`` steps, and on the last step, print the
            losses and save a sample grid.

    Raises:
        ValueError: if ``interval`` is less than 1.
    """
    if interval < 1:
        raise ValueError(f"interval must be >= 1, got {interval}")

    real = np.ones((batch_size, 1))  # Labels for real images
    fake = np.zeros((batch_size, 1))  # Labels for fake images

    for epoch in range(epochs):
        # --- Discriminator step: one batch of real cats, one batch of fakes ---
        idx = np.random.randint(0, cat_images.shape[0], batch_size)
        real_images = cat_images[idx]
        noise = np.random.normal(0, 1, (batch_size, latent_dim))
        # verbose=0: otherwise Keras prints a progress bar on every step.
        fake_images = generator.predict(noise, verbose=0)

        # Keras 3 reads `trainable` when the train step runs instead of
        # restoring the compile-time state (tf.keras 2.x restores it). The flag
        # cleared in compile_gan would then freeze D for its own updates too.
        # Setting it explicitly before each kind of update works in both.
        # NOTE(review): confirm with your TF/Keras version.
        discriminator.trainable = True
        d_loss_real = discriminator.train_on_batch(real_images, real)
        d_loss_fake = discriminator.train_on_batch(fake_images, fake)
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)  # averaged [loss, accuracy]

        # --- Generator step: update G through the frozen D towards "real" ---
        discriminator.trainable = False
        noise = np.random.normal(0, 1, (batch_size, latent_dim))
        g_loss = gan.train_on_batch(noise, real)

        # Print progress; also report the final step, which the
        # interval check alone would usually skip.
        if epoch % interval == 0 or epoch == epochs - 1:
            print(f"Epoch {epoch}/{epochs}, D Loss: {d_loss[0]:.4f}, D Acc: {d_loss[1]:.4f}, G Loss: {g_loss:.4f}")
            save_generated_images(epoch, generator)

In [None]:
# Train the DCGAN: each "epoch" is one mini-batch step, so this runs
# 10,000 steps and logs and saves a sample grid every 1,000 steps.
train_gan(epochs=10_000, batch_size=64, interval=1_000)

In [63]:
# Sample a final 4x4 grid from the trained generator and save it to disk.
examples = 16

noise = np.random.normal(0, 1, (examples, latent_dim))
generated_images = generator.predict(noise)
generated_images = 0.5 * generated_images + 0.5  # Rescale tanh output [-1, 1] to [0, 1]

fig = plt.figure(figsize=(6, 6))
for i in range(examples):
    ax = fig.add_subplot(4, 4, i + 1)
    ax.imshow(generated_images[i])
    ax.axis('off')

# Layout, save and close must run once, AFTER all subplots are drawn. Inside
# the loop, the figure was closed after the first image and the PNG was
# overwritten 16 times, leaving only the last image.
fig.tight_layout()
fig.savefig('generated_images.png')
plt.close(fig)


