In [7]:
import os

import matplotlib.pyplot as plt
import numpy as np
from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import BatchNormalization, Dense, Flatten, Input, LeakyReLU, Reshape
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.optimizers import Adam

In [4]:
img_shape = (28, 28, 1)  # MNIST digits: 28x28 pixels with a single (grayscale) channel

In [5]:
## building Generator

def generator():
    """Build the MLP generator that maps a noise vector to an image of ``img_shape``.

    Architecture: three hidden stages of Dense -> LeakyReLU(0.2) ->
    BatchNormalization(momentum=0.8) with 256, 512 and 1024 units, then a tanh
    Dense layer with ``np.prod(img_shape)`` units reshaped to ``img_shape``.

    Returns:
        A functional ``Model`` taking a ``(batch, 100)`` noise tensor and returning
        images of shape ``(batch, *img_shape)`` with values in [-1, 1] (tanh output).
    """
    noise_shape = (100,)  # dimensionality of the latent noise vector

    model = Sequential()

    model.add(Dense(256, input_shape=noise_shape))
    model.add(LeakyReLU(alpha=0.2))
    model.add(BatchNormalization(momentum=0.8))

    model.add(Dense(512))
    model.add(LeakyReLU(alpha=0.2))
    model.add(BatchNormalization(momentum=0.8))

    # 1024 units: continues the 256 -> 512 -> 1024 widening (was a 1023 typo).
    model.add(Dense(1024))
    model.add(LeakyReLU(alpha=0.2))
    model.add(BatchNormalization(momentum=0.8))

    # np.prod(img_shape) == 28 * 28 * 1 == 784 flat pixels, then reshaped to an image.
    model.add(Dense(np.prod(img_shape), activation='tanh'))
    model.add(Reshape(img_shape))

    noise = Input(shape=noise_shape)
    output = model(noise)

    return Model(noise, output)

In [6]:
### building Discriminator

def discriminator():
    """Build the MLP discriminator that scores how "real" an image looks.

    The input image (shape ``img_shape``) is flattened and passed through
    Dense(512) -> LeakyReLU(0.2) -> Dense(256) -> LeakyReLU(0.2) ->
    Dense(1, sigmoid).

    Returns:
        A functional ``Model`` mapping images of shape ``(batch, *img_shape)`` to a
        ``(batch, 1)`` sigmoid score; training labels real images with 1.
    """
    classifier = Sequential([
        Flatten(input_shape=img_shape),
        Dense(512),
        LeakyReLU(alpha=0.2),
        Dense(256),
        LeakyReLU(alpha=0.2),
        Dense(1, activation='sigmoid'),
    ])

    image_in = Input(shape=img_shape)
    # The discriminator's per-image guess that the input is real rather than generated.
    score = classifier(image_in)
    return Model(image_in, score)

784

In [None]:
## defining training function

def train(epochs, batch_size=128, save_interval=50):
    """Train the GAN on MNIST with alternating discriminator / generator updates.

    Expects module-level, already-compiled Keras models bound in cells not shown
    here: ``generator`` (noise -> image), ``discriminator`` (image -> real score)
    and ``combined`` (noise -> discriminator(generator(noise))). In the visible
    code ``generator``/``discriminator`` are the builder functions, so a later
    cell presumably rebinds them to model instances — TODO confirm.

    Args:
        epochs: number of training iterations. Each iteration is one discriminator
            step on ``batch_size // 2`` real and ``batch_size // 2`` fake images,
            followed by one generator step on ``batch_size`` noise vectors.
        batch_size: generator batch size; the discriminator sees half per class.
        save_interval: call ``save_imgs`` every this many iterations; a falsy value
            (0 or None) disables saving.
    """
    noise_dim = 100  # must match the generator's noise_shape

    (X_train, _), (_, _) = mnist.load_data()

    # Scale pixels from [0, 255] to [-1, 1] so real images share the range of the
    # generator's tanh output (save_imgs undoes this with 0.5 * x + 0.5).
    X_train = (X_train.astype(np.float32) - 127.5) / 127.5

    X_train = np.expand_dims(X_train, axis=3)  # (N, 28, 28) -> (N, 28, 28, 1)

    half_batch = batch_size // 2

    # Targets are constant across iterations, so build them once.
    real_labels = np.ones((half_batch, 1))
    fake_labels = np.zeros((half_batch, 1))
    valid_y = np.ones((batch_size, 1))  # generator wants its fakes scored as real

    for epoch in range(epochs):

        # ---------------------------------
        # Train the discriminator
        # ---------------------------------

        # half_batch random integer indices in [0, N) into the training set.
        idx = np.random.randint(0, X_train.shape[0], half_batch)
        imgs = X_train[idx]

        # Standard-normal latent noise (same distribution as save_imgs uses).
        noise = np.random.normal(0, 1, (half_batch, noise_dim))
        gen_imgs = generator.predict(noise)

        # Real images are labelled 1 and generated images 0, so the discriminator
        # learns to tell them apart.
        d_loss_real = discriminator.train_on_batch(imgs, real_labels)
        d_loss_fake = discriminator.train_on_batch(gen_imgs, fake_labels)

        # Element-wise mean of the two results (works for a scalar loss or a
        # [loss, metric, ...] list).
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

        # ---------------------------------
        # Train the generator
        # ---------------------------------

        noise = np.random.normal(0, 1, (batch_size, noise_dim))
        g_loss = combined.train_on_batch(noise, valid_y)

        print(f"D loss {d_loss}, G loss {g_loss}")

        if save_interval and epoch % save_interval == 0:
            save_imgs(epoch)
        

        
def save_imgs(epoch):
    """Save a 5x5 grid of generator samples to ``images/mnist_<epoch>.png``.

    Calls ``generator.predict``, so the module-level name ``generator`` must refer
    to a built Keras model at call time (not the builder function) — presumably
    rebound in a cell not shown here; TODO confirm.

    Args:
        epoch: training step number; used only to name the output file.
    """
    r, c = 5, 5
    noise = np.random.normal(0, 1, (r * c, 100))
    gen_imgs = generator.predict(noise)

    # Rescale generator output from the tanh range [-1, 1] to [0, 1] for display.
    gen_imgs = 0.5 * gen_imgs + 0.5

    fig, axs = plt.subplots(r, c)
    # axs.flat walks the grid row by row, pairing each axis with the next sample.
    for ax, img in zip(axs.flat, gen_imgs):
        ax.imshow(img[:, :, 0], cmap='gray')
        ax.axis('off')

    # savefig does not create missing directories, so ensure the target exists.
    os.makedirs("images", exist_ok=True)
    fig.savefig("images/mnist_%d.png" % epoch)
    # Close this specific figure so repeated calls during training don't accumulate figures.
    plt.close(fig)
# save_imgs (above) writes a 5x5 grid of generated samples to images/mnist_<epoch>.png for visual inspection.
        