In [34]:
from tensorflow.keras.layers import Input,Dense, BatchNormalization, LeakyReLU
from tensorflow.keras.models import Model
from tensorflow.keras.datasets import mnist
from tensorflow.keras.optimizers import Adam
import numpy as np
import matplotlib.pyplot as plt

In [35]:
from tqdm import tqdm

In [3]:
def create_generator():
    """Build the fully connected generator network.

    A 100-dimensional noise vector is passed through three
    Dense -> BatchNormalization stages (256, 512 and 1024 units, each Dense
    using LeakyReLU with alpha=0.2) and finally projected to 784 values
    with a tanh activation.

    Returns:
        The uncompiled Keras ``Model`` mapping the noise input to the
        784-value output.
    """
    noise_in = Input(shape=(100,))
    hidden = noise_in
    # Progressively widen the representation; layers are created in the
    # same order as writing the three stages out by hand.
    for width in (256, 512, 1024):
        hidden = Dense(width, activation=LeakyReLU(alpha=0.2))(hidden)
        hidden = BatchNormalization()(hidden)
    flat_image = Dense(784, activation='tanh')(hidden)
    return Model(inputs=noise_in, outputs=flat_image)

In [4]:
# Instantiate a generator so its architecture can be inspected.
g = create_generator()

In [5]:
# Print the generator's layer-by-layer summary (output shapes and parameter counts).
g.summary()

Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 100)]             0         
_________________________________________________________________
dense (Dense)                (None, 256)               25856     
_________________________________________________________________
batch_normalization (BatchNo (None, 256)               1024      
_________________________________________________________________
dense_1 (Dense)              (None, 512)               131584    
_________________________________________________________________
batch_normalization_1 (Batch (None, 512)               2048      
_________________________________________________________________
dense_2 (Dense)              (None, 1024)              525312    
_________________________________________________________________
batch_normalization_2 (Batch (None, 1024)              4096  

In [31]:
def create_discriminator():
    """Build and compile the fully connected discriminator.

    A 784-value input is passed through three Dense -> BatchNormalization
    stages (1024, 512 and 256 units, each Dense using LeakyReLU with
    alpha=0.2) and reduced to a single sigmoid output.

    Returns:
        The Keras ``Model``, compiled with binary cross-entropy loss and an
        Adam optimizer (learning rate 0.0002, beta_1 0.5).
    """
    img_input = Input(shape=(784,))
    x = Dense(1024, activation=LeakyReLU(alpha=0.2))(img_input)
    x = BatchNormalization()(x)
    x = Dense(512, activation=LeakyReLU(alpha=0.2))(x)
    x = BatchNormalization()(x)
    x = Dense(256, activation=LeakyReLU(alpha=0.2))(x)
    x = BatchNormalization()(x)
    dis_output = Dense(1, activation='sigmoid')(x)
    discriminator = Model(inputs=img_input,outputs=dis_output)
    # Build the optimizer inline: the adam_optimizer() helper lives in a later
    # cell, so calling it here raises NameError when the notebook is run
    # top to bottom on a fresh kernel.
    discriminator.compile(loss='binary_crossentropy',
                          optimizer=Adam(learning_rate=0.0002, beta_1=0.5))
    return discriminator

In [7]:
# Instantiate the discriminator and print its layer-by-layer summary.
d = create_discriminator()
d.summary()

Model: "model_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_2 (InputLayer)         [(None, 784)]             0         
_________________________________________________________________
dense_4 (Dense)              (None, 1024)              803840    
_________________________________________________________________
batch_normalization_3 (Batch (None, 1024)              4096      
_________________________________________________________________
dense_5 (Dense)              (None, 512)               524800    
_________________________________________________________________
batch_normalization_4 (Batch (None, 512)               2048      
_________________________________________________________________
dense_6 (Dense)              (None, 256)               131328    
_________________________________________________________________
batch_normalization_5 (Batch (None, 256)               1024

In [15]:
def create_gan(discriminator, generator):
    """Chain the generator and discriminator into one compiled model.

    Side effect: sets ``discriminator.trainable`` to False on the object
    that was passed in, before the combined model is assembled.

    Args:
        discriminator: Model applied to the generator's output.
        generator: Model applied to the 100-dimensional noise input.

    Returns:
        A Keras ``Model`` from noise to discriminator output, compiled with
        binary cross-entropy loss and the default 'adam' optimizer.
    """
    # Freeze the discriminator before wiring it into the stacked model.
    discriminator.trainable = False
    noise_in = Input(shape=(100,))
    verdict = discriminator(generator(noise_in))
    combined = Model(inputs=noise_in, outputs=verdict)
    combined.compile(optimizer='adam', loss='binary_crossentropy')
    return combined

In [16]:
# Stack the generator and discriminator built above into the combined model.
gan = create_gan(d,g)

In [17]:
# Print the combined model's summary, including trainable vs. non-trainable parameter counts.
gan.summary()

Model: "model_3"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_4 (InputLayer)         [(None, 100)]             0         
_________________________________________________________________
model (Model)                (None, 784)               1493520   
_________________________________________________________________
model_1 (Model)              (None, 1)                 1467393   
Total params: 2,960,913
Trainable params: 1,489,936
Non-trainable params: 1,470,977
_________________________________________________________________


In [18]:
def plot_generated_images(epoch, generator, examples=100, dim=(10,10), figsize=(10,10)):
    """Sample images from the generator and save them as a grid PNG.

    Args:
        epoch: Value substituted into the output file name.
        generator: Model whose ``predict`` maps (examples, 100) noise to
            arrays that can be reshaped to (examples, 28, 28).
        examples: Number of images to generate; must fit in the
            ``dim[0]`` x ``dim[1]`` grid.
        dim: (rows, columns) of the subplot grid.
        figsize: Figure size passed to ``plt.figure``.
    """
    noise = np.random.normal(loc=0, scale=1, size=[examples, 100])
    generated_images = generator.predict(noise)
    # Reshape by the requested sample count rather than a hard-coded 100,
    # so any value of `examples` works.
    generated_images = generated_images.reshape(examples, 28, 28)
    plt.figure(figsize=figsize)
    for i in range(generated_images.shape[0]):
        plt.subplot(dim[0], dim[1], i+1)
        plt.imshow(generated_images[i], interpolation='nearest')
        plt.axis('off')
    plt.tight_layout()
    plt.savefig('gan_generated_image {}.png'.format(epoch))

In [19]:
def adam_optimizer():
    """Return a new Adam optimizer with learning rate 0.0002 and beta_1 0.5."""
    # `learning_rate` is the supported argument name; `lr` is a deprecated
    # alias that newer Keras releases reject.
    return Adam(learning_rate=0.0002, beta_1=0.5)

In [20]:
def load_x_train():
    """Load the MNIST training images, scaled and flattened.

    Pixel values are mapped from [0, 255] to [-1, 1] so the real images
    cover the same range as the generator's tanh output. The previous
    scaling, (x - 127.0) / 255.0, only spanned roughly [-0.5, 0.5].

    Returns:
        A float32 array of shape (num_images, 784).
    """
    # Only the training images are needed; labels and the test split are discarded.
    (x_train, _), _ = mnist.load_data()
    x_train = (x_train.astype(np.float32) - 127.5) / 127.5
    return x_train.reshape(-1, 784)

In [21]:
# Load the flattened training images and display the array's shape.
X_train = load_x_train()
X_train.shape

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz


(60000, 784)

In [36]:
def training(epochs=1,batch_size=128):
    """Train a freshly built GAN on the MNIST training images.

    Each step first updates the discriminator on a batch of real images
    (smoothed labels in [0.9, 1.0)) plus an equal number of generated images
    (label 0). It then updates the generator through the stacked model,
    asking the discriminator to output 1 for generated images.

    Args:
        epochs: Number of passes; each pass runs
            ``num_images // batch_size`` steps.
        batch_size: Number of real images (and of generated images) per step.

    A grid of generated samples is saved whenever the zero-based epoch index
    is a multiple of 10.
    """
    X_train = load_x_train()
    no_of_steps = X_train.shape[0]//batch_size
    g = create_generator()
    d = create_discriminator()
    gan = create_gan(d,g)

    for e in range(epochs):
        print('Epoch: {}'.format(e+1))
        for _ in tqdm(range(no_of_steps)):
            # --- Discriminator update ---
            gen_input = np.random.normal(size=(batch_size,100))
            gen_imgs = g.predict(gen_input)
            img_batch = X_train[np.random.randint(low=0,high=X_train.shape[0],size=batch_size)]

            # Real images go first so they line up with the smoothed "real"
            # labels in the first half of dis_y. Previously the generated
            # images occupied that half, which inverted the labels.
            dis_X = np.concatenate([img_batch, gen_imgs])

            dis_y = np.zeros(2*batch_size)
            # One-sided label smoothing: real targets are drawn from [0.9, 1.0).
            dis_y[:batch_size] = np.random.random(size=batch_size)/10 + 0.9

            # create_gan() froze the discriminator; re-enable it for its own update.
            d.trainable = True
            d.train_on_batch(dis_X, dis_y)

            # --- Generator update (through the stacked model) ---
            gan_X = np.random.normal(size=(batch_size,100))
            # Target 1: the generator is rewarded when its samples are judged real.
            gan_y = np.ones(batch_size)

            d.trainable = False

            gan.train_on_batch(gan_X,gan_y)
        if (e%10==0):
            plot_generated_images(e, g)

In [None]:
# Run training for 50 epochs with a batch size of 128.
training(50,128)

  0%|          | 0/468 [00:00<?, ?it/s]

Epoch: 1


100%|██████████| 468/468 [06:50<00:00,  1.14it/s]
  0%|          | 0/468 [00:00<?, ?it/s]

Epoch: 2


100%|██████████| 468/468 [06:00<00:00,  1.30it/s]
  0%|          | 0/468 [00:00<?, ?it/s]

Epoch: 3


100%|██████████| 468/468 [05:34<00:00,  1.40it/s]
  0%|          | 0/468 [00:00<?, ?it/s]

Epoch: 4


 21%|██        | 96/468 [01:08<04:22,  1.42it/s]