In [38]:
from tensorflow.keras.models import Sequential,Model
from tensorflow.keras.layers import Dense,BatchNormalization,LeakyReLU,Reshape,Flatten,Input
from tensorflow.keras.optimizers.legacy import Adam
from tensorflow.keras.datasets import mnist
import matplotlib.pyplot as plt
import numpy as np

In [39]:
img_shape = (28,28,1)  # mnist have black and white images of size 28,28  

In [40]:
## building Generator

def build_generator():
    
    noise_shape = (100,) # using which generator will generate images
    
    model = Sequential()
    
    model.add(Dense(256,input_shape=noise_shape))
    model.add(LeakyReLU(alpha=0.2))
    model.add(BatchNormalization(momentum=0.8))
    
    model.add(Dense(512))
    model.add(LeakyReLU(alpha=0.2))
    model.add(BatchNormalization(momentum=0.8))
    
    model.add(Dense(1023))
    model.add(LeakyReLU(alpha=0.2))
    model.add(BatchNormalization(momentum=0.8))
    
    model.add(Dense(np.prod(img_shape),activation='tanh')) ## np.prod multiplies 28,28 and 1 -> 784
    model.add(Reshape(img_shape))
    
    noise = Input(shape=noise_shape)
    output = model(noise)
    
    return Model(noise,output)

In [41]:
### building Discriminator

def build_discriminator():
    
    model = Sequential()
    model.add(Flatten(input_shape = img_shape))
    model.add(Dense(512))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dense(256))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dense(1,activation='sigmoid'))
    
    img = Input(shape=img_shape)
    validity = model(img)   # its a guess of a discriminator that image is real or fake
    
    return Model(img,validity)

In [42]:
## defining training function

def train(epochs,batch_size=128,save_interval=100):
    
    (X_train,_),(_,_) = mnist.load_data()
    
    ## scaling images
    X_train = X_train/255.
    
    X_train = np.expand_dims(X_train,axis=3)  # 28,28 -> 28,28,1
    
    half_batch = batch_size//2
    
    for epoch in range(epochs):
        
        # ---------------------------------
        ## First Training the Discriminator
        # ---------------------------------
        
        # getting random images from X_train
        idx = np.random.randint(0,X_train.shape[0],half_batch) # getting random half_bach indexes from 0 to 60k
        imgs = X_train[idx]
        
        
        ## generating noise 
        noise = np.random.normal(0,1,(half_batch,100))  # It will generate half_batch,100 values between 0 and 1
        
        ## generating fake images
        gen_imgs = generator(noise)
        
        ## Training the discrimainator on real and fake images separately
        d_loss_real = discriminator.train_on_batch(imgs,np.ones((half_batch,1))) # paasing real images and telling discriminator that it is real by passing ones with it
        d_loss_fake = discriminator.train_on_batch(gen_imgs,np.ones((half_batch,1))) # passing generated fake images and passing ones with it fool geenrator by saying it is real
        
        # averaging loss
        d_loss = np.add(d_loss_real, d_loss_fake) * 0.5
        
        
        #----------------------
        ## Training Generator
        #----------------------
        
        noise = np.random.normal(0,1,(batch_size,100))
        
        valid_y = np.array([1]*batch_size) # to fool the discriminator
        
        g_loss = combined.train_on_batch(noise,valid_y)
        
        print(f"D loss {d_loss}, G loss {g_loss}")
        
        if epoch % save_interval == 0:
            save_imgs(epoch)
        

        
#This function saves our images for us to view
def save_imgs(epoch):
    r, c = 5, 5
    noise = np.random.normal(0, 1, (r * c, 100))
    gen_imgs = generator.predict(noise)

    # Rescale images 0 - 1
    gen_imgs = 0.5 * gen_imgs + 0.5

    fig, axs = plt.subplots(r, c)
    cnt = 0
    for i in range(r):
        for j in range(c):
            axs[i,j].imshow(gen_imgs[cnt, :,:,0], cmap='gray')
            axs[i,j].axis('off')
            cnt += 1
    fig.savefig("images/mnist_%d.png" % epoch)
    plt.close()

        

In [43]:
## defining optimizer
optimizer = Adam(0.0002)

## discriminator
discriminator = build_discriminator()
discriminator.compile(loss='binary_crossentropy',optimizer=optimizer,metrics=['accuracy',])

## Generator
generator = build_generator()
generator.compile(loss='binary_crossentropy',optimizer=optimizer) ## we are generating fake images, no need to track metrics


# input to generator 
z = Input(shape=(100,))
img = generator(z)

# disabling discriminator training
discriminator.trainable = False

## validating generator output
valid = discriminator(img)

## stacking z and valid to give feedback to generate better images
combined = Model(z,valid)
combined.compile(loss='binary_crossentropy',optimizer=optimizer)

# training 
train(100,128,10)

generator.save('gen_100_epoch.h5')

D loss [0.65544951 0.703125  ], G loss 0.543929398059845
D loss [0.47398478 1.        ], G loss 0.41101470589637756
D loss [0.36653055 1.        ], G loss 0.28648287057876587
D loss [0.28746971 1.        ], G loss 0.199311763048172
D loss [0.22117603 1.        ], G loss 0.13194140791893005
D loss [0.14443546 1.        ], G loss 0.0849492996931076
D loss [0.09104678 1.        ], G loss 0.055805571377277374
D loss [0.06528289 1.        ], G loss 0.03326776623725891
D loss [0.03294604 1.        ], G loss 0.02242463082075119
D loss [0.02711387 1.        ], G loss 0.015557728707790375
D loss [0.01801869 1.        ], G loss 0.009629139676690102
D loss [0.01025968 1.        ], G loss 0.007471442688256502
D loss [0.00842034 1.        ], G loss 0.005322286393493414
D loss [0.00624797 1.        ], G loss 0.0044684880413115025
D loss [0.005892 1.      ], G loss 0.00324054853990674
D loss [0.00304573 1.        ], G loss 0.0024852966889739037
D loss [0.00324333 1.        ], G loss 0.002259219530969

  saving_api.save_model(
