In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

In [None]:
from tensorflow.keras.datasets import mnist

In [None]:
from tensorflow.keras.layers import Dense, Reshape, Flatten
from tensorflow.keras import Input
from tensorflow.keras.models import Sequential

In [None]:
import tensorflow as tf

## Import Data

In [None]:
# mnist.load_data() returns a pair of (images, labels) tuples: train first, then test.
mnist_train, mnist_test = mnist.load_data()
X_train, y_train = mnist_train
X_test, y_test = mnist_test
del mnist_train, mnist_test  # keep only the four named arrays in the namespace

In [None]:
# Rescale pixel intensities from the 0-255 range into [0, 1], matching the
# generator's sigmoid output range.
# NOTE: this rebinds X_train / X_test in place, so re-running the cell divides again.
X_train, X_test = X_train / 255.0, X_test / 255.0

In [None]:
# Show the first training image. imshow's default colormap false-colors a 2-D
# array, so use gray for a digit; put its label in the title so the figure
# stands on its own.
plt.imshow(X_train[0], cmap="gray")
plt.title(f"First training image (label: {y_train[0]})")
plt.axis("off");

In [None]:
# Label of the image shown above. int() displays it as a plain Python int
# rather than a NumPy scalar repr (which includes the dtype under NumPy 2).
int(y_train[0])

## Filtering Data

In [None]:
# Keep only the training images whose label is 4; this subset is what the GAN learns from.
only_fours = X_train[np.equal(y_train, 4)]

In [None]:
# (number of fours, height, width)
np.shape(only_fours)

In [None]:
# Spot-check one image from the filtered subset. Index 4000 is arbitrary and
# must stay below len(only_fours).
plt.imshow(only_fours[4000], cmap="gray")
plt.title("Sample from only_fours (index 4000)")
plt.axis("off");

In [None]:
# Length of the random noise vector ("codings") fed to the generator.
codings_size = 100

# Seed NumPy's and TensorFlow's global RNGs so TensorFlow random ops
# (e.g. tf.random.normal, dataset shuffling) are repeatable across runs.
np.random.seed(42)
tf.random.set_seed(42)

In [None]:
# Generator: maps a codings_size-long noise vector to a 28x28 image.
# The sigmoid output keeps pixels in [0, 1], the same range as the scaled
# training images.
generator = Sequential([
    Input(shape=(codings_size,)),
    Dense(100, activation='relu'),
    Dense(200, activation='relu'),
    Dense(28 * 28, activation='sigmoid'),
    Reshape((28, 28)),
])

In [None]:
# Discriminator: flattens a 28x28 image and outputs one sigmoid score,
# trained toward 1 for real images and 0 for generated ones.
discriminator = Sequential([
    Input(shape=(28, 28)),
    Flatten(),
    Dense(200, activation='relu'),
    Dense(100, activation='relu'),
    Dense(1, activation='sigmoid'),
])

In [None]:
# Compile the discriminator on its own while it is still trainable, so that
# discriminator.train_on_batch updates its weights.
discriminator.compile(optimizer='adam', loss='binary_crossentropy')

In [None]:
# Chain generator -> discriminator: noise in, "is this real?" score out.
GAN = Sequential(layers=[generator, discriminator])

In [None]:
# Freeze the discriminator before GAN.compile (next cell) so that the combined
# GAN model only updates the generator's weights. The discriminator was
# compiled on its own earlier while still trainable; the training loop
# toggles this flag around each training step.
discriminator.trainable = False

In [None]:
# Compile the stacked model with the discriminator frozen; training it on
# "real" labels pushes the generator toward fooling the discriminator.
GAN.compile(optimizer='adam', loss='binary_crossentropy')

In [None]:
# The two sub-models stacked inside GAN: [generator, discriminator]
GAN.layers[:]

In [None]:
# Architecture of the first stacked sub-model (the generator).
GAN.get_layer(index=0).summary()

In [None]:
# Architecture of the second stacked sub-model (the discriminator).
GAN.get_layer(index=1).summary()

## Batch Size

In [None]:
# Real images per batch; each discriminator step sees 2 * batch_size images
# (batch_size generated + batch_size real).
batch_size: int = 32

In [None]:
# The image array the training pipeline is built from (currently the digit-4
# subset). The dataset and the per-epoch batch count below both read my_data,
# so pointing this at another array changes what the GAN trains on.
my_data = only_fours

In [None]:
# Build a shuffled tf.data pipeline over the training images.
# - Convert to float32 once here: the training loop concatenates real images
#   with float32 generator output, and tf.concat requires matching dtypes.
# - Use a shuffle buffer as large as the dataset: any buffer smaller than the
#   dataset (previously 1000) only gives a partial, locally-biased shuffle.
dataset = tf.data.Dataset.from_tensor_slices(
    np.asarray(my_data, dtype=np.float32)
).shuffle(buffer_size=len(my_data))

In [None]:
# Fixed-size batches: the training loop builds its label tensors from
# batch_size, so a short final batch would not line up with them, and it is dropped.
# NOTE: rebinds `dataset`, so re-running this cell batches an already-batched dataset.
dataset = dataset.batch(batch_size, drop_remainder=True)
# Prepare the next batch while the current one is being trained on.
dataset = dataset.prefetch(1)

## Training

**NOTE:** The generator never actually sees any real images. It learns from the gradients flowing back through the discriminator. The better the discriminator gets through training, the more information its gradients carry, which means the generator can begin to make progress in learning how to generate fake images — in our case, fake fours.

In [None]:
epochs = 20

# Grab the separate components (the same generator / discriminator objects
# that were stacked into GAN)
generator, discriminator = GAN.layers

# Batches per epoch; the dataset is batched with drop_remainder=True, so no
# trailing partial batch is produced.
num_batches = len(my_data) // batch_size

# The label tensors depend only on batch_size, so build them once rather than
# on every batch.
# Discriminator targets: 0 for the fake half, 1 for the real half of each batch.
y1 = tf.constant([[0.]] * batch_size + [[1.]] * batch_size)
# Generator targets: we want the discriminator to believe fake images are real.
y2 = tf.constant([[1.]] * batch_size)

# For every epoch
for epoch in range(epochs):
    print(f"Currently on Epoch {epoch+1}")
    # For every batch in the dataset (batch numbers start at 1)
    for i, X_batch in enumerate(dataset, start=1):
        if i % 10 == 0:
            print(f"\tCurrently on batch number {i} of {num_batches}")

        ##### TRAINING THE DISCRIMINATOR #####

        # Generate fake images from pure noise
        noise = tf.random.normal(shape=[batch_size, codings_size])
        gen_images = generator(noise)

        # Stack generated images on top of the real ones (rows match y1).
        # tf.concat requires matching dtypes, hence the cast.
        X_fake_vs_real = tf.concat([gen_images, tf.cast(X_batch, tf.float32)], axis=0)

        # Toggle the flag to match the model being trained in this step (the
        # original notes this avoids a Keras warning).
        # NOTE(review): Keras generally fixes which weights a model updates when
        # it is compiled / its train function is first built, so confirm whether
        # this toggle has any effect beyond that for your Keras version.
        discriminator.trainable = True
        discriminator.train_on_batch(X_fake_vs_real, y1)

        ##### TRAINING THE GENERATOR #####

        # Fresh noise; the frozen discriminator's gradients update only the generator
        noise = tf.random.normal(shape=[batch_size, codings_size])
        discriminator.trainable = False
        GAN.train_on_batch(noise, y2)

In [None]:
# Ten fresh noise vectors to feed the trained generator.
noise = tf.random.normal([10, codings_size])

In [None]:
# (number of samples, codings_size)
noise.get_shape()

In [None]:
# Visualize the latent noise: one row per sample, one column per coding
# dimension. aspect="auto" stretches the 10 x codings_size matrix to fill the
# axes instead of rendering it as a thin strip.
plt.imshow(noise, aspect="auto")
plt.xlabel("coding dimension")
plt.ylabel("sample")
plt.title("Random codings fed to the generator");

In [None]:
# Run the trained generator in inference mode: one 28x28 image per noise row.
image = generator(noise, training=False)

In [None]:
# Display one generated digit (sample 5 of the 10) in grayscale, like the
# real training images.
plt.imshow(image[5], cmap="gray")
plt.title("Generated image (sample 5)")
plt.axis("off");