In [1]:
!pip install tensorflow==2.12.0



In [14]:
import tensorflow as tf
from tensorflow.keras.layers import Input, Dense, Reshape, Flatten, Dropout
from tensorflow.keras.layers import LeakyReLU, Activation, ZeroPadding2D
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.optimizers.legacy import Adam
from tensorflow.keras import initializers
import numpy as np
import matplotlib.pyplot as plt
import os

In [16]:
# Random seeds to replicate results.
# NumPy drives the noise vectors and batch sampling; TensorFlow drives the Keras
# weight initialisers and Dropout masks, so both must be seeded.
# NOTE: bit-exact reproducibility on GPU may additionally require
# tf.config.experimental.enable_op_determinism() (not enabled here).

np.random.seed(1000)
tf.random.set_seed(1000)

In [18]:
# Load MNIST data (hand-written digits)

def load_mnist():
    """Load the MNIST training images, scaled to [-1, 1] and flattened.

    Returns:
        float32 array of shape (n_images, 28 * 28). Pixel values are mapped by
        (x - 127.5) / 127.5, i.e. from 0..255 to [-1, 1], matching the range of
        the generator's tanh output.
    """
    (X_train, _), (_, _) = tf.keras.datasets.mnist.load_data()
    X_train = (X_train.astype(np.float32) - 127.5) / 127.5
    # Flatten each image; derive the row count from the data instead of
    # hard-coding 60000.
    X_train = X_train.reshape(X_train.shape[0], -1)
    return X_train

X_train = load_mnist()
assert X_train.ndim == 2 and X_train.shape[1] == 28 * 28, X_train.shape
X_train.shape

(60000, 784)

### GAN (Generative Adversarial Network)

In [21]:
# Optimizer: Adam with learning rate 2e-4 and beta_1 = 0.5.
# This instance is passed to the compile() calls below.

adam = Adam(beta_1=0.5, learning_rate=2e-4)

In [23]:
# Function to build the generator

def build_generator(latent_dim=10):
    """Build the fully-connected generator: noise vector -> flattened 28x28 image.

    Args:
        latent_dim: Length of the input noise vector. The default of 10 matches
            the noise sampled elsewhere in this notebook.

    Returns:
        An uncompiled ``Sequential`` model mapping (batch, latent_dim) to
        (batch, 784). The ``tanh`` output lies in [-1, 1], the same range the
        training images are scaled to.
    """
    model = Sequential([
        Dense(256, input_dim=latent_dim),
        LeakyReLU(alpha=0.2),
        Dense(512),
        LeakyReLU(alpha=0.2),
        Dense(1024),
        LeakyReLU(alpha=0.2),
        Dense(784, activation='tanh')
    ])

    return model

generator = build_generator()
# The generator is only trained through the combined `gan` model; compiling it
# here just allows it to be trained/evaluated directly if ever needed.
generator.compile(loss='binary_crossentropy', optimizer=adam)

In [27]:
# Function to build the discriminator

def build_discriminator(input_dim=784):
    """Build the fully-connected discriminator: flattened image -> real/fake score.

    Args:
        input_dim: Length of a flattened input image (default 784 = 28 * 28).

    Returns:
        An uncompiled ``Sequential`` model mapping (batch, input_dim) to
        (batch, 1) with a sigmoid output. Only the first Dense layer uses a
        RandomNormal(stddev=0.02) initializer; the others use the Keras default.
    """
    model = Sequential([
        Dense(1024, input_dim=input_dim, kernel_initializer=initializers.RandomNormal(stddev=0.02)),
        LeakyReLU(alpha=0.2),
        Dropout(0.3),
        Dense(512),
        LeakyReLU(alpha=0.2),
        Dropout(0.3),
        Dense(256),
        LeakyReLU(alpha=0.2),
        Dropout(0.3),
        Dense(1, activation='sigmoid')
    ])

    return model

discriminator = build_discriminator()
discriminator.compile(loss='binary_crossentropy', optimizer=adam)

In [29]:
# Combined network (GAN): noise -> generator -> discriminator score.

# Freeze the discriminator *before* compiling `gan`: the intent is that
# generator updates made through `gan` leave the discriminator's weights alone.
discriminator.trainable = False
ganInput = Input(shape=(10,))
x = generator(ganInput)
ganOutput = discriminator(x)
gan = Model(inputs=ganInput, outputs=ganOutput)
# Give `gan` its own Adam instance (same hyperparameters as `adam`). Sharing a
# single optimizer with `discriminator` makes both models advance one shared
# step counter, which skews Adam's bias correction. Newer (non-legacy) Keras
# optimizers also refuse to update variable sets they were not built with.
gan.compile(loss='binary_crossentropy', optimizer=Adam(learning_rate=0.0002, beta_1=0.5))

# Per-epoch loss history, appended to by train().
dLosses = []
gLosses = []

In [71]:
# Plot the loss recorded at the end of each epoch

def plotLoss(epoch):
    """Plot the recorded discriminator/generator losses and save the figure.

    ``train`` appends one value per epoch (the loss of that epoch's last
    batch), so the k-th point of each curve is the k-th recorded epoch.

    Args:
        epoch: Epoch number, used only in the output file name
            ``GAN_images/gan_loss_epoch_<epoch>.png``.
    """
    # savefig does not create missing directories.
    os.makedirs('GAN_images', exist_ok=True)
    plt.figure(figsize=(10, 8))
    # 1-based x values so the axis matches the 'Epoch' label.
    plt.plot(range(1, len(dLosses) + 1), dLosses, label='Discriminative loss')
    plt.plot(range(1, len(gLosses) + 1), gLosses, label='Generative loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.savefig(os.path.join('GAN_images', 'gan_loss_epoch_%d.png' % epoch))

# Create a wall of generated MNIST images

def saveGeneratedImages(epoch, examples=100, dim=(10, 10), figsize=(10, 10)):
    """Sample digits from the generator and save them as a grid image.

    Args:
        epoch: Epoch number, used only in the output file name
            ``GAN_images/digits_epoch_<epoch>.png``.
        examples: Number of images to generate; must not exceed
            ``dim[0] * dim[1]``.
        dim: (rows, cols) of the subplot grid.
        figsize: Figure size in inches.
    """
    noise = np.random.normal(0, 1, size=[examples, 10])
    generatedImages = generator.predict(noise, verbose=0)
    generatedImages = generatedImages.reshape(examples, 28, 28)

    # savefig does not create missing directories.
    os.makedirs('GAN_images', exist_ok=True)
    fig = plt.figure(figsize=figsize)
    for i in range(generatedImages.shape[0]):
        plt.subplot(dim[0], dim[1], i + 1)
        plt.imshow(generatedImages[i], interpolation='nearest', cmap='gray_r')
        plt.axis('off')
    plt.tight_layout()
    # os.path.join instead of a hard-coded backslash: the old literal only
    # worked on Windows (elsewhere it wrote 'GAN_images\digits_...' into the cwd)
    # and relied on '\d' being an unrecognised escape sequence.
    fig.savefig(os.path.join('GAN_images', 'digits_epoch_%d.png' % epoch))
    # Close so repeated calls during training don't accumulate open figures.
    plt.close(fig)

In [79]:
def train(epochs=1, batchSize=128):
    """Train the GAN, alternating one discriminator and one generator update per batch.

    Per batch:
      * the discriminator is trained on ``batchSize`` real images (sampled
        with replacement, labelled 0.9 for one-sided label smoothing) plus
        ``batchSize`` generated images (labelled 0);
      * the generator is trained through the combined ``gan`` model on fresh
        noise with target label 1.

    After each epoch, the losses of its last batch are appended to the global
    ``dLosses`` / ``gLosses`` lists and printed. Sample grids are saved on
    epoch 1 and every 20th epoch, and the loss curve is saved once at the end.

    Args:
        epochs: Number of epochs to run; must be >= 1.
        batchSize: Real images per batch; must be between 1 and len(X_train).

    Raises:
        ValueError: if ``epochs`` or ``batchSize`` is out of range (these
            previously surfaced as ZeroDivisionError/UnboundLocalError).
    """
    nSamples = X_train.shape[0]
    if epochs < 1:
        raise ValueError(f'epochs must be >= 1, got {epochs}')
    if not 1 <= batchSize <= nSamples:
        raise ValueError(f'batchSize must be between 1 and {nSamples}, got {batchSize}')

    batchCount = nSamples // batchSize
    print('Epochs:', epochs)
    print('Batch size:', batchSize)
    print('Batches per epoch:', batchCount)

    # Labels are identical for every batch, so build them once.
    # Real images come first in X, generated images second.
    yDis = np.zeros(2 * batchSize)
    yDis[:batchSize] = 0.9  # One-sided label smoothing
    yGen = np.ones(batchSize)

    for e in range(1, epochs + 1):
        print(f" ======= Epoch {e}")

        for i in range(batchCount):
            # Get a random set of input noise and images
            noise = np.random.normal(0, 1, size=[batchSize, 10])
            imageBatch = X_train[np.random.randint(0, nSamples, size=batchSize)]

            # Generate fake MNIST images. Calling the model directly is much
            # cheaper than predict() for a single in-memory batch.
            generatedImages = generator(noise, training=False).numpy()
            X = np.concatenate([imageBatch, generatedImages])

            # NOTE(review): Keras may only honour `trainable` changes at compile
            # time, in which case these toggles are no-ops and the freeze set
            # before `gan.compile` is what matters — confirm before removing.
            discriminator.trainable = True
            dloss = discriminator.train_on_batch(X, yDis)

            # Train generator (through the frozen discriminator)
            noise = np.random.normal(0, 1, size=[batchSize, 10])
            discriminator.trainable = False
            gloss = gan.train_on_batch(noise, yGen)

            # Print batch progress every 10 batches
            if (i + 1) % 10 == 0:
                print(f"\rBatch {i + 1}", end='', flush=True)

        # Clear the batch-progress line at the end of the epoch
        print('\r' + ' ' * 20 + '\r', end='', flush=True)

        # Store and report the loss of the most recent batch of THIS epoch
        dLosses.append(dloss)
        gLosses.append(gloss)
        print(f" ======= Epoch {e} | Discriminator Loss: {dloss:.4f} | Generator Loss: {gloss:.4f} ======= ")

        if e == 1 or e % 20 == 0:
            saveGeneratedImages(e)

    # Plot losses from every epoch
    plotLoss(epochs)

In [None]:
# Full training run: 200 epochs with 128 real images per batch.
train(epochs=200, batchSize=128)

Epochs: 200
Batch size: 128
Batches per epoch: 468
Batch 320