# In [None]:
from keras.layers import Input, Dense, Reshape, Flatten
from keras.layers import BatchNormalization
from keras.layers.advanced_activations import LeakyReLU
from keras.models import Sequential,Model
from keras.optimizers import Adam
import matplotlib.pyplot as plt
import os
from PIL import Image
import numpy as np

# image shape
img_shape = (28,28,1)

def generator(latent_dim=100):
    """Build the GAN generator network.

    The network maps a 1-D latent noise vector to a fake image of shape
    ``img_shape`` (the module-level constant). It is three
    Dense -> LeakyReLU -> BatchNormalization stages (256, 512 and 1024 units),
    followed by a ``tanh`` Dense layer reshaped to ``img_shape``. Because of
    ``tanh``, output pixel values lie in [-1, 1].

    Args:
        latent_dim: Length of the latent noise vector fed to the generator.
            Defaults to 100, the noise size used elsewhere in this file.

    Returns:
        A Keras ``Model`` mapping a ``(batch, latent_dim)`` noise tensor to a
        ``(batch, *img_shape)`` image tensor.
    """
    noise_shape = (latent_dim,)  # 1-D "latent vector"
    model = Sequential()
    model.add(Dense(256, input_shape=noise_shape))
    model.add(LeakyReLU(alpha=0.2))
    model.add(BatchNormalization(momentum=0.8))
    model.add(Dense(512))
    model.add(LeakyReLU(alpha=0.2))
    model.add(BatchNormalization(momentum=0.8))
    model.add(Dense(1024))
    model.add(LeakyReLU(alpha=0.2))
    model.add(BatchNormalization(momentum=0.8))
    # One output unit per pixel, squashed to [-1, 1], then laid out as an image.
    model.add(Dense(np.prod(img_shape), activation='tanh'))
    model.add(Reshape(img_shape))
    model.summary()

    # Wrap the Sequential stack in a functional Model with an explicit Input.
    noise = Input(shape=noise_shape)
    generated_img = model(noise)
    return Model(noise, generated_img)

# the discriminator works as binary classification --> true / false
def discriminator():
    """Build the discriminator, a binary real-vs-fake image classifier.

    An image of shape ``img_shape`` is flattened and passed through two
    LeakyReLU-activated Dense layers (512 and 1024 units). A single sigmoid
    unit then produces the probability that the input image is real.

    Returns:
        A Keras ``Model`` mapping a ``(batch, *img_shape)`` image tensor to a
        ``(batch, 1)`` validity score.
    """
    classifier = Sequential([
        Flatten(input_shape=img_shape),
        Dense(512),
        LeakyReLU(alpha=0.2),
        Dense(1024),
        LeakyReLU(alpha=0.2),
        Dense(1, activation='sigmoid'),
    ])
    classifier.summary()

    image_in = Input(shape=img_shape)
    return Model(image_in, classifier(image_in))

# build a training method
def _load_images(folder):
    """Load the supported images in ``folder`` as a normalized training array.

    Each .jpg/.jpeg/.png file (extension matched case-insensitively) is
    converted to 8-bit grayscale ('L') and resized to 28x28 pixels. Pixel
    values are then rescaled from [0, 255] to [-1, 1], the range of the
    generator's ``tanh`` output. The 28x28x1 size must agree with
    ``img_shape``.

    Returns:
        float32 array of shape (num_images, 28, 28, 1).

    Raises:
        ValueError: If ``folder`` contains no supported image files.
    """
    imgs = []
    # sorted() gives a deterministic order, so seeded runs are reproducible.
    for fname in sorted(os.listdir(folder)):
        if not fname.lower().endswith(('.jpg', '.jpeg', '.png')):
            continue
        # Context manager closes the file handle that Image.open keeps open.
        with Image.open(os.path.join(folder, fname)) as im:
            imgs.append(np.array(im.convert('L').resize((28, 28))))
    if not imgs:
        raise ValueError(f"no .jpg/.jpeg/.png images found in {folder!r}")

    x_train = np.stack(imgs, axis=0).astype(np.float32)
    x_train = (x_train - 127.5) / 127.5
    # Add the trailing channel axis expected by the models.
    return np.expand_dims(x_train, axis=3)


def train(epochs, batch_size=128, save_interval=500, folder='your_data_folder'):
    """Train the GAN on the grayscale images found in ``folder``.

    Each "epoch" here is a single training iteration on one randomly sampled
    mini-batch, not a full pass over the dataset. Each iteration does two
    things:

    1. It trains the discriminator on a half batch of real images (label 1)
       and a half batch of generated images (label 0).
    2. It trains the generator through the combined model, using labels of 1
       so the generator learns to fool the discriminator.

    This function uses the module-level ``generator``, ``discriminator`` and
    ``combine`` models and the ``save_imgs`` helper.

    Args:
        epochs: Number of training iterations to run.
        batch_size: Samples per discriminator update; half are real and half
            generated. Must be at least 2.
        save_interval: Save a grid of generated samples every this many
            iterations, starting at iteration 0. A value of 0 disables saving.
        folder: Directory containing .jpg/.jpeg/.png training images.

    Raises:
        ValueError: If ``batch_size`` is smaller than 2, or ``folder`` holds no
            supported images.
    """
    half_batch = batch_size // 2
    if half_batch < 1:
        raise ValueError(f"batch_size must be at least 2, got {batch_size}")

    x_train = _load_images(folder)

    for epoch in range(epochs):
        # --- Discriminator step: real and fake half batches trained separately.
        index = np.random.randint(0, x_train.shape[0], half_batch)
        real_imgs = x_train[index]
        noise = np.random.normal(0, 1, (half_batch, 100))
        # verbose=0: avoid printing a progress bar on every iteration.
        fake_imgs = generator.predict(noise, verbose=0)

        d_loss_real = discriminator.train_on_batch(real_imgs, np.ones((half_batch, 1)))
        d_loss_fake = discriminator.train_on_batch(fake_imgs, np.zeros((half_batch, 1)))
        # train_on_batch returns [loss, accuracy] (compiled with metrics=['accuracy']).
        avg_dloss = 0.5 * (d_loss_real[0] + d_loss_fake[0])
        avg_dacc = 0.5 * (d_loss_real[1] + d_loss_fake[1])

        # --- Generator step: label fakes as real (1) to push G to fool D.
        noise = np.random.normal(0, 1, (half_batch, 100))
        valid_y = np.ones((half_batch, 1))
        generator_loss = combine.train_on_batch(noise, valid_y)
        print("%d [Discriminator_loss : %f , acc.:%.2f%%] [Generator_loss : %f ]" %
              (epoch, avg_dloss, 100 * avg_dacc, generator_loss))

        if save_interval and epoch % save_interval == 0:
            save_imgs(epoch)


def save_imgs(epoch, rows=5, columns=5, out_dir="valid"):
    """Save a grid of samples from the module-level ``generator`` to disk.

    The image is written to ``<out_dir>/image_at_epoch_<epoch>.png``, and
    ``out_dir`` is created if it does not exist.

    Args:
        epoch: Training iteration number, used only in the output file name.
        rows: Number of grid rows (default 5).
        columns: Number of grid columns (default 5).
        out_dir: Directory the PNG is written to (default "valid").
    """
    noise = np.random.normal(0, 1, (rows * columns, 100))
    generated = generator.predict(noise, verbose=0)
    # Map the generator's tanh range [-1, 1] to [0, 1] for display.
    gen_imgs = 0.5 * generated + 0.5

    # squeeze=False keeps `axs` 2-D even when rows or columns is 1.
    fig, axs = plt.subplots(rows, columns, squeeze=False)
    # axs.flat walks the grid row by row, matching the sample order.
    for count, ax in enumerate(axs.flat):
        ax.imshow(gen_imgs[count, :, :, 0], cmap='gray')
        ax.axis('off')

    os.makedirs(out_dir, exist_ok=True)
    fig.savefig(os.path.join(out_dir, f"image_at_epoch_{epoch}.png"))
    # Close this exact figure, not just the current one. Figures that stay
    # open would pile up in memory over thousands of training calls.
    plt.close(fig)


# Adam hyperparameters (learning rate, beta_1) used for all three models.
ADAM_LR, ADAM_BETA_1 = 0.0002, 0.5

# Each compiled model gets its own Adam instance. Sharing one optimizer object
# between separately compiled models can fail on newer Keras optimizers, which
# track the variables of the model they were first built for.
optimizer = Adam(ADAM_LR, ADAM_BETA_1)

# build and compile the discriminator
# (note: this rebinds the name `discriminator` from the builder function to the model)
discriminator = discriminator()
discriminator.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['accuracy'])

# build and compile the generator (its weights are updated through `combine`)
# (note: this rebinds the name `generator` from the builder function to the model)
generator = generator()
generator.compile(loss='binary_crossentropy', optimizer=Adam(ADAM_LR, ADAM_BETA_1))

# latent noise input for the combined model
z = Input(shape=(100,))
img = generator(z)

# Freeze the discriminator inside the combined model. This flag is set after
# discriminator.compile() on purpose. In Keras 2.x the trainable state is
# captured at compile time, so direct discriminator.train_on_batch calls still
# update its weights, while `combine` (compiled below) updates only the
# generator.
discriminator.trainable = False
# the discriminator's verdict on generated images
valid = discriminator(img)

# combined model: noise -> generator -> frozen discriminator; trained to fool D
combine = Model(z, valid)
combine.compile(loss='binary_crossentropy', optimizer=Adam(ADAM_LR, ADAM_BETA_1))


# call the training method
train(epochs=100000, batch_size=32, save_interval=2000, folder='your_data_folder')