# GANs (Generative Adversarial Networks) From Scratch

#### Understand the Basic Concept:

A GAN consists of two neural networks: a Generator and a Discriminator.
- The Generator creates fake data (e.g., images).
- The Discriminator tries to distinguish between real and fake data.
- The two networks are trained together as adversaries: the Generator gets better at creating realistic data, while the Discriminator gets better at spotting fakes.

In [2]:
import tensorflow as tf
from tensorflow.keras import layers
import numpy as np
import matplotlib.pyplot as plt

#### Design the Generator and Discriminator:



In [3]:
# Training configuration.
# latent_dim is the length of the noise vector read by build_generator.
# batch_size and epochs are presumably meant to be passed to train_gan --
# confirm at the call site.
epochs = 10
batch_size = 64
latent_dim = 100

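- The Generator network usually uses transposed-convolution (often called "deconvolutional") layers to generate images from random noise. The simple version below uses fully connected (Dense) layers instead.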


In [4]:
# Generator Model
def build_generator():
    """Build the fully connected generator network.

    The input is a vector of length ``latent_dim`` (read from module scope).
    It passes through three Dense -> LeakyReLU -> BatchNormalization stages
    of width 256, 512 and 1024, then through a 784-unit tanh layer whose
    output is reshaped to (28, 28, 1).

    Returns:
        An uncompiled ``tf.keras.Sequential`` model.

    NOTE(review): a later cell defines ``build_generator`` again with a
    sigmoid output; if the cells run top to bottom, that later definition is
    the one in effect. Confirm which version is intended.
    """
    net = tf.keras.Sequential()

    # Only the first Dense layer declares the input size.
    net.add(layers.Dense(256, input_dim=latent_dim))
    net.add(layers.LeakyReLU(alpha=0.2))
    net.add(layers.BatchNormalization(momentum=0.8))

    # The remaining hidden stages differ only in their width.
    for width in (512, 1024):
        net.add(layers.Dense(width))
        net.add(layers.LeakyReLU(alpha=0.2))
        net.add(layers.BatchNormalization(momentum=0.8))

    # One tanh unit per pixel, then fold the flat vector into an image.
    net.add(layers.Dense(28 * 28 * 1, activation='tanh'))
    net.add(layers.Reshape((28, 28, 1)))
    return net

- The Discriminator is typically a convolutional neural network (CNN) that tries to classify input images as real or fake. The simple version below flattens the image and uses fully connected (Dense) layers instead.

In [5]:
# Discriminator Model
def build_discriminator():
    """Build the fully connected discriminator network.

    A (28, 28, 1) input is flattened, passed through Dense -> LeakyReLU
    stages of width 512 and 256, and reduced to a single sigmoid unit.

    Returns:
        An uncompiled ``tf.keras.Sequential`` model.

    NOTE(review): a later cell defines ``build_discriminator`` again; if the
    cells run top to bottom, that later definition is the one in effect.
    """
    net = tf.keras.Sequential()
    net.add(layers.Flatten(input_shape=(28, 28, 1)))

    # The hidden stages differ only in their width.
    for width in (512, 256):
        net.add(layers.Dense(width))
        net.add(layers.LeakyReLU(alpha=0.2))

    net.add(layers.Dense(1, activation='sigmoid'))
    return net

In [3]:
import tensorflow as tf
from tensorflow.keras import layers
import numpy as np
import matplotlib.pyplot as plt


# Keep only the training images; the labels and the test split are discarded.
(train_images, _), (_, _) = tf.keras.datasets.mnist.load_data()

# Divide by 255 to normalize to [0, 1] and append a trailing channel axis.
train_images = (train_images / 255.0)[..., np.newaxis]

train_images.shape

(60000, 28, 28, 1)

In [4]:
def build_generator():
    """Build a small two-layer generator.

    A 100-element input vector goes through a 128-unit ReLU layer and a
    784-unit sigmoid layer; the result is reshaped to (28, 28, 1).

    Returns:
        An uncompiled ``tf.keras.Sequential`` model.
    """
    hidden = layers.Dense(128, activation='relu', input_shape=(100,))
    pixels = layers.Dense(784, activation='sigmoid')
    to_image = layers.Reshape((28, 28, 1))
    return tf.keras.Sequential([hidden, pixels, to_image])

In [5]:
def build_discriminator():
    """Build a small two-layer discriminator.

    A (28, 28, 1) input is flattened, passed through a 128-unit ReLU layer,
    and reduced to a single sigmoid unit.

    Returns:
        An uncompiled ``tf.keras.Sequential`` model.
    """
    flatten = layers.Flatten(input_shape=(28, 28, 1))
    hidden = layers.Dense(128, activation='relu')
    score = layers.Dense(1, activation='sigmoid')
    return tf.keras.Sequential([flatten, hidden, score])

In [6]:
generator = build_generator()
discriminator = build_discriminator()

discriminator.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

# Combined model used for the generator update: noise -> generator ->
# discriminator. train_gan calls gan.train_on_batch, so `gan` must exist.
# The discriminator is frozen while the combined model is compiled so that
# this model is meant to update only the generator's weights; the flag is
# restored afterwards so the standalone discriminator remains trainable.
discriminator.trainable = False
gan = tf.keras.Sequential([generator, discriminator])
gan.compile(optimizer='adam', loss='binary_crossentropy')
discriminator.trainable = True

In [7]:
def train_gan(epochs, batch_size, latent_dim=100):
    """Alternate discriminator and generator updates for ``epochs`` rounds.

    Uses the module-level ``generator``, ``discriminator`` and
    ``train_images``, plus a compiled combined model named ``gan``
    (generator followed by discriminator). ``gan`` must be defined before
    this function is called.

    Args:
        epochs: Number of outer training rounds.
        batch_size: Number of real images and of generated images in each
            discriminator batch. It is also used as the number of
            discriminator updates performed per round.
        latent_dim: Length of each noise vector; must match the generator's
            input size. Defaults to 100.
    """
    num_real = train_images.shape[0]

    for epoch in range(epochs):

        # Train Discriminator
        for _ in range(batch_size):
            noise = np.random.normal(0, 1, size=[batch_size, latent_dim])
            # verbose=0 keeps predict() from printing progress output on
            # every inner-loop call.
            generated_images = generator.predict(noise, verbose=0)
            real_images = train_images[np.random.randint(0, num_real, size=batch_size)]

            # Real images come first, so the first half of the labels is 1
            # and the generated half stays 0.
            X = np.concatenate([real_images, generated_images])
            y_dis = np.zeros(2 * batch_size)
            y_dis[:batch_size] = 1  # Label real images as 1

            discriminator.train_on_batch(X, y_dis)

        # Train Generator
        noise = np.random.normal(0, 1, size=[batch_size, latent_dim])
        y_gen = np.ones(batch_size)  # Label generated images as 1
        discriminator.trainable = False  # Freeze the Discriminator
        try:
            gan_loss = gan.train_on_batch(noise, y_gen)
        finally:
            # Always unfreeze, even if the generator step raises, so the
            # discriminator is not left frozen for later calls.
            discriminator.trainable = True

        if epoch % 100 == 0:
            print(f'Epoch {epoch}, GAN Loss: {gan_loss}')


In [8]:
def generate_images(num_images):
    """Return generator output for ``num_images`` random noise vectors.

    Draws standard-normal noise of shape (num_images, 100) and passes it
    through the module-level ``generator`` with ``predict``.
    """
    latent_batch = np.random.normal(loc=0, scale=1, size=(num_images, 100))
    return generator.predict(latent_batch)



In [9]:
# Display the first training image using the 'gray' colormap.
plt.imshow(train_images[0], cmap='gray')

<matplotlib.image.AxesImage at 0x1a94b11bb80>

: 