In [None]:
#import necessary libraries
import matplotlib.pyplot as plt
import numpy as np
import random
import io
import imageio

from tensorflow import keras
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import*
from keras.optimizers import Adam

In [None]:
#load mnist (train and test splits)
(x_train, y_train), (x_test, y_test) = mnist.load_data()

#number of training digits to preview
n_samples = 25
#show the first n_samples training digits on a 5x5 grid
for i, digit in enumerate(x_train[:n_samples]):
    plt.subplot(5, 5, i + 1)
    plt.axis('off')
    plt.imshow(digit, cmap='gray')
plt.show()

# Models

In [None]:
#define discriminator model
def define_discriminator():
    """Build and compile the discriminator.

    The network takes 28x28x1 images, applies two strided 3x3 convolution
    blocks (each followed by LeakyReLU and dropout), flattens the result and
    ends in a single sigmoid unit.

    Returns:
        The Sequential model, compiled with binary cross-entropy, Adam
        (learning_rate=0.0002, beta_1=0.5) and an accuracy metric.
    """
    discriminator = Sequential()

    #first convolution block (also declares the input shape)
    discriminator.add(Conv2D(128, (3,3), strides=(2, 2), padding='same', input_shape=(28,28,1)))
    discriminator.add(LeakyReLU(alpha=0.2))
    discriminator.add(Dropout(0.4))

    #second convolution block
    discriminator.add(Conv2D(128, (3,3), strides=(2, 2), padding='same'))
    discriminator.add(LeakyReLU(alpha=0.2))
    discriminator.add(Dropout(0.4))

    #classification head
    discriminator.add(Flatten())
    discriminator.add(Dense(64))
    discriminator.add(Dense(1, activation='sigmoid'))

    optimizer = Adam(learning_rate=0.0002, beta_1=0.5)
    discriminator.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['accuracy'])
    return discriminator

#build the discriminator (it is compiled inside define_discriminator)
d_model = define_discriminator()

#look at the architecture of the discriminator
d_model.summary()

In [None]:
#define generator model
def define_generator(latent_dim):
    """Build the generator (returned uncompiled).

    Args:
        latent_dim: size of the latent vector the first Dense layer accepts.

    Returns:
        A Sequential model whose dense layers feed a (7, 7, 128) reshape,
        followed by two strided transposed convolutions and a final
        single-channel sigmoid convolution.
    """
    generator = Sequential()

    #dense stem that is reshaped into a 7x7 map with 128 channels
    generator.add(Dense(128 * 7, input_dim=latent_dim))
    generator.add(Dense(128 * 7 * 7))
    generator.add(LeakyReLU(alpha=0.2))
    generator.add(Reshape((7, 7, 128)))

    #first transposed-convolution block
    generator.add(Conv2DTranspose(128, (4,4), strides=(2,2), padding='same'))
    generator.add(LeakyReLU(alpha=0.2))

    #second transposed-convolution block
    generator.add(Conv2DTranspose(128, (4,4), strides=(2,2), padding='same'))
    generator.add(LeakyReLU(alpha=0.2))

    #output layer: one channel squashed by a sigmoid
    generator.add(Conv2D(1, (7,7), activation='sigmoid', padding='same'))

    return generator

#size of the latent vector fed to the generator
latent_dim = 100
#build the generator for that latent size
g_model = define_generator(latent_dim)
#look at the architecture of the generator
g_model.summary()

In [None]:
#define GAN model
def define_gan(g_model, d_model):
    """Stack the generator and the discriminator into one compiled model.

    Note that this sets d_model.trainable to False on the object passed in.

    Args:
        g_model: generator model, used as the first stage.
        d_model: discriminator model, used as the second stage.

    Returns:
        The stacked Sequential model, compiled with binary cross-entropy and
        Adam (learning_rate=0.0002, beta_1=0.5).
    """
    #mark the discriminator as non-trainable before it is stacked
    d_model.trainable = False

    gan = Sequential([g_model, d_model])

    optimizer = Adam(learning_rate=0.0002, beta_1=0.5)
    gan.compile(loss='binary_crossentropy', optimizer=optimizer)
    return gan

#create the combined GAN from the generator and the discriminator built earlier
gan_model = define_gan(g_model, d_model)
#look at the GAN architecture
gan_model.summary()

# Helper Functions

In [None]:
#function which loads and preprocesses the train samples
def load_real_samples():
    """Load the MNIST training images ready for the discriminator.

    Returns:
        The training images divided by 255, with a trailing channel axis
        added, as a float32 array. Labels and the test split are discarded.
    """
    (train_images, _), _ = mnist.load_data()
    scaled = train_images / 255
    return np.expand_dims(scaled, axis=-1).astype('float32')

#function which selects random samples from the train samples
def select_real_samples(dataset, quantity):
    """Draw `quantity` random samples (with replacement) from `dataset`.

    Args:
        dataset: array indexable by a list of integer positions.
        quantity: number of samples to draw.

    Returns:
        (samples, labels): the selected samples and a (quantity, 1) array of
        ones marking them as real.
    """
    last_index = len(dataset) - 1
    picked = []
    for _ in range(quantity):
        picked.append(random.randint(0, last_index))
    return dataset[picked], np.ones((quantity, 1))

#function which generates points in latent space
def generate_latent_points(latent_dim, n_samples):
    """Sample standard-normal latent vectors.

    Args:
        latent_dim: length of each latent vector.
        n_samples: number of vectors to draw.

    Returns:
        Array of shape (n_samples, latent_dim) drawn with np.random.randn.
    """
    #drawing directly in the target shape consumes the random stream in the
    #same order as drawing a flat vector and reshaping it
    return np.random.randn(n_samples, latent_dim)

#function which generates fake samples
def generate_fake_samples(g_model, latent_dim, n_samples):
    """Generate `n_samples` images with the generator, labelled as fake.

    Args:
        g_model: generator model exposing predict().
        latent_dim: length of each latent vector fed to the generator.
        n_samples: number of samples to generate.

    Returns:
        (samples, labels): the generator output for freshly drawn latent
        points and a (n_samples, 1) array of zeros marking them as fake.
    """
    points = generate_latent_points(latent_dim, n_samples)
    #verbose=0: train() calls this once per batch, so the default progress
    #bar would flood the notebook output
    samples = g_model.predict(points, verbose=0)
    fake_labels = np.zeros((n_samples, 1))
    return samples, fake_labels

#function to display generated images
def show_images(examples, n):
    """Display the first n*n examples on an n x n grid.

    Args:
        examples: 4-D array; channel 0 of each example is drawn.
        n: number of rows and columns of the grid.
    """
    cell_count = n * n
    for position in range(1, cell_count + 1):
        plt.subplot(n, n, position)
        plt.axis('off')
        #inverted gray colormap
        plt.imshow(examples[position - 1, :, :, 0], cmap='gray_r')
    plt.show()

#function to save 'subplot' images
def save_images(examples, n):
    """Render the first n*n examples on an n x n grid and return the picture.

    The grid is drawn on a new figure, written as PNG into an in-memory
    buffer, and read back with plt.imread; the figure is closed without
    being shown.

    Args:
        examples: 4-D array; channel 0 of each example is drawn.
        n: number of rows and columns of the grid.

    Returns:
        The rendered grid as a numpy array.
    """
    plt.figure()
    for slot in range(1, n * n + 1):
        plt.subplot(n, n, slot)
        plt.axis('off')
        plt.imshow(examples[slot - 1, :, :, 0], cmap='gray_r')

    png_buffer = io.BytesIO()
    plt.savefig(png_buffer, format='png')
    plt.close()

    #rewind so the PNG can be read back from the start of the buffer
    png_buffer.seek(0)
    return np.asarray(plt.imread(png_buffer))

# Training functions

In [None]:
#training history, filled in place by train()
loss_disc_per_epoch, loss_gan_per_epoch = [], []
loss_disc_per_batch, loss_gan_per_batch = [], []
#one rendered preview image per epoch
predictions = []
#function to train GAN
def train(g_model, d_model, gan_model, dataset, latent_dim, epochs=150, batches=256):
    """Train the discriminator and the generator in alternation.

    Every step trains d_model on half a batch of real samples (label 1)
    stacked with half a batch of generated samples (label 0), then trains
    gan_model on freshly drawn latent points labelled 1.

    The losses of every step are appended to the module-level lists
    loss_disc_per_batch / loss_gan_per_batch, and the losses of the last step
    of each epoch to loss_disc_per_epoch / loss_gan_per_epoch. After each
    epoch a 5 x 5 grid generated from a fixed set of latent points is shown
    and its rendered image is appended to predictions.

    Args:
        g_model: generator model.
        d_model: compiled discriminator model.
        gan_model: compiled stacked generator + discriminator model.
        dataset: array of real training samples.
        latent_dim: length of the latent vectors fed to the generator.
        epochs: number of passes over the data.
        batches: full batch size; each half (real / fake) gets batches // 2.

    Raises:
        ValueError: if batches is smaller than 2 or the dataset holds fewer
            samples than one batch, since no training step could run.
    """
    if batches < 2:
        raise ValueError(f'batches must be at least 2, got {batches}')
    number_batches = int(dataset.shape[0] / batches)
    if number_batches < 1:
        raise ValueError(
            f'dataset has {dataset.shape[0]} samples, fewer than one batch of {batches}'
        )
    half_batch = batches // 2

    #fixed latent points so the previews of different epochs are comparable;
    #they must use latent_dim, otherwise the generator rejects them
    preview_grid = 5
    latent_points = generate_latent_points(latent_dim, preview_grid * preview_grid)

    #training by epochs
    for i in range(epochs):
        #training by batches in epoch
        for j in range(number_batches):
            #discriminator step on half real, half fake samples
            X_real, y_real = select_real_samples(dataset, half_batch)
            X_fake, y_fake = generate_fake_samples(g_model, latent_dim, half_batch)
            X, y = np.vstack((X_real, X_fake)), np.vstack((y_real, y_fake))
            d_loss, _ = d_model.train_on_batch(X, y)

            #generator step through the stacked model, with labels set to 1
            X_gan = generate_latent_points(latent_dim, half_batch)
            y_gan = np.ones((half_batch, 1))
            gan_loss = gan_model.train_on_batch(X_gan, y_gan)

            #print
            print(f'>{i+1}, {j+1}/{number_batches}, \n         disc loss={d_loss}, \n         gan loss={gan_loss}')

            #loss
            loss_disc_per_batch.append(d_loss)
            loss_gan_per_batch.append(gan_loss)

        #loss of the last batch of the epoch
        loss_disc_per_epoch.append(d_loss)
        loss_gan_per_epoch.append(gan_loss)

        #show epoch result
        preview = g_model.predict(latent_points)
        show_images(preview, preview_grid)

        #save epoch result
        predictions.append(save_images(preview, preview_grid))

# Train models / Save weights / Load weights

In [None]:
#load and preprocess the real training images
dataset = load_real_samples()
#train the GAN with the default number of epochs and batch size of train()
train(g_model, d_model, gan_model, dataset, latent_dim)

In [None]:
#paths of the weight files; base_path is prepended as-is to every file name
base_path = ''
gan_weights = f'{base_path}gan_model.h5'
g_weights = f'{base_path}g_model.h5'
d_weights = f'{base_path}d_model.h5'

In [None]:
#save the weights of the GAN, the generator and the discriminator to their files
gan_model.save_weights(gan_weights)
g_model.save_weights(g_weights)
d_model.save_weights(d_weights)

In [None]:
#load the weights from the same files back into the existing model objects
gan_model.load_weights(gan_weights)
g_model.load_weights(g_weights)
d_model.load_weights(d_weights)

# Look at results

In [None]:
latent_points = generate_latent_points(100, 36)
X = g_model.predict(latent_points)
show_images(X, 6)

In [None]:
#losses per batch: discriminator on the left, GAN (generator) on the right
fig, (ax_disc, ax_gan) = plt.subplots(1, 2, figsize=(12, 4))
ax_disc.scatter(np.arange(len(loss_disc_per_batch)), loss_disc_per_batch)
ax_disc.set(title='Discriminator loss per batch', xlabel='training step', ylabel='loss')
ax_gan.scatter(np.arange(len(loss_gan_per_batch)), loss_gan_per_batch)
ax_gan.set(title='GAN loss per batch', xlabel='training step', ylabel='loss')
fig.tight_layout()
plt.show()

In [None]:
#losses per epoch (value of the last batch of each epoch):
#discriminator on the left, GAN (generator) on the right
fig, (ax_disc, ax_gan) = plt.subplots(1, 2, figsize=(12, 4))
ax_disc.scatter(np.arange(len(loss_disc_per_epoch)), loss_disc_per_epoch)
ax_disc.set(title='Discriminator loss per epoch', xlabel='epoch', ylabel='loss')
ax_gan.scatter(np.arange(len(loss_gan_per_epoch)), loss_gan_per_epoch)
ax_gan.set(title='GAN loss per epoch', xlabel='epoch', ylabel='loss')
fig.tight_layout()
plt.show()

In [None]:
#make a gif
output_file = base_path + 'animation.gif'
#save_images() reads the frames back with plt.imread, which yields float
#arrays in [0, 1] for PNG data; convert them to uint8 so the GIF writer gets
#the integer pixel format it expects instead of converting lossily or failing
gif_frames = [
    (np.clip(frame, 0, 1) * 255).round().astype(np.uint8)
    if np.issubdtype(np.asarray(frame).dtype, np.floating) else frame
    for frame in predictions
]
imageio.mimsave(output_file, gif_frames)