In [4]:
!pip install -q tensorflow-gpu==2.3.0

In [5]:
import tensorflow as tf
print(tf.__version__)

2.3.0


In [6]:
from tensorflow.keras.layers import Input, Dense, LeakyReLU, Dropout, BatchNormalization
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import SGD, Adam

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import sys, os

In [7]:
mnist = tf.keras.datasets.mnist

(x_train, y_train), (x_test, y_test) = mnist.load_data()

x_train, x_test = x_train / 255 * 2 - 1, x_test / 255 * 2 - 1
print("x_train.shape:", x_train.shape)

x_train.shape: (60000, 28, 28)


In [8]:
N, H, W = x_train.shape

D = H * W

x_train = x_train.reshape(-1, D)
x_test = x_test.reshape(-1, D)

In [9]:
latent_dim = 100


In [10]:
def build_generator(latent_dim):
  i = Input(shape=(latent_dim,))
  x = Dense(256, activation=LeakyReLU(alpha=0.2))(i)
  x = BatchNormalization(momentum=0.8)(x)
  x = Dense(512, activation=LeakyReLU(alpha=0.2))(x)
  x = BatchNormalization(momentum=0.8)(x)
  x = Dense(1024, activation=LeakyReLU(alpha=0.2))(x)
  x = BatchNormalization(momentum=0.8)(x)
  x = Dense(D, activation='tanh')(x)

  model = Model(i, x)
  return model

In [11]:
def build_discriminator(img_size):
  i = Input(shape=(img_size,))
  x = Dense(512, activation=LeakyReLU(alpha=0.2))(i)
  x = Dense(256, activation=LeakyReLU(alpha=0.2))(x)
  x = Dense(1, activation='sigmoid')(x)

  model = Model(i, x)
  return model

In [12]:
#Compilation and preparation of the two models for training

discriminator = build_discriminator(D)
discriminator.compile(
    loss='binary_crossentropy',
    optimizer=Adam(0.0002, 0.5),
    metrics=['accuracy']
)

generator = build_generator(latent_dim)

#Input representing the noise sample from latent space
z = Input(shape=(latent_dim,))

#Introduce noise through the generator to get an image
img = generator(z)

#Only the generator is trained
discriminator.trainable = False

#The actual output is false, but we label it as real!
fake_pred = discriminator(img)

#We create the combined model
combined_model = Model(z, fake_pred)

#We compile the combined model
combined_model.compile(loss='binary_crossentropy', optimizer=Adam(0.0002, 0.5))

In [13]:
#GAN training

batch_size = 32
epochs = 30000
sample_period = 200 #Each "sample period" times information will be saved

#The labels that will be used when they are called train_on_batch are created

ones = np.ones(batch_size)
zeros = np.zeros(batch_size)

#Losses are stored

d_losses = []
g_losses = []

#The generated images are stored
if not os.path.exists('gan_images'):
  os.makedirs('gan_images')

#A function to generate a mesh of rando values ​​from the generator and save them as a file

def sample_images(epoch):
  rows, cols = 5, 5
  noise = np.random.randn(rows * cols, latent_dim)
  imgs = generator.predict(noise)

  #Image rescaling 0-1
  imgs = 0.5 * imgs + 0.5

  fig, axs = plt.subplots(rows, cols)
  idx = 0
  for i in range(rows):
    for j in range(cols):
      axs[i,j].imshow(imgs[idx].reshape(H, W), cmap='gray')
      axs[i,j].axis('off')
      idx += 1

  fig.savefig("gan_images/%d.png" % epoch)
  plt.close()

In [14]:
#Main training cycle

for epoch in range(epochs):

  #Discriminator training

  #Select a random batch of images
  idx = np.random.randint(0, x_train.shape[0], batch_size)
  real_imgs = x_train[idx]

  #Generate the fake images
  noise = np.random.randn(batch_size, latent_dim)
  fake_imgs = generator.predict(noise)

  #Train the discriminator
  #loss and precision are returned
  d_loss_real, d_acc_real = discriminator.train_on_batch(real_imgs, ones)
  d_loss_fake, d_acc_fake = discriminator.train_on_batch(fake_imgs, zeros)
  d_loss = 0.5 * (d_loss_real + d_loss_fake)
  d_acc = 0.5 * (d_acc_real + d_acc_fake)

  #Generator training

  noise = np.random.rand(batch_size, latent_dim)
  g_loss = combined_model.train_on_batch(noise, ones)

  #Save the losses
  d_losses.append(d_loss)
  g_losses.append(g_losses)

  if epoch % 100 == 0:

    print(f"epoch: {epoch + 1 }/´{epochs}, d_loss: {d_loss:.2f}, d_acc: {d_acc:.2f}, g_loss: {g_loss:.2f}")

    if epoch % sample_period == 0:
      sample_images(epoch)

epoch: 1/´30000, d_loss: 0.64, d_acc: 0.55, g_loss: 0.60
epoch: 101/´30000, d_loss: 0.00, d_acc: 1.00, g_loss: 0.13
epoch: 201/´30000, d_loss: 0.00, d_acc: 1.00, g_loss: 0.39
epoch: 301/´30000, d_loss: 0.00, d_acc: 1.00, g_loss: 0.49
epoch: 401/´30000, d_loss: 0.03, d_acc: 1.00, g_loss: 0.95
epoch: 501/´30000, d_loss: 0.09, d_acc: 0.97, g_loss: 1.44
epoch: 601/´30000, d_loss: 0.08, d_acc: 0.98, g_loss: 0.80
epoch: 701/´30000, d_loss: 0.20, d_acc: 0.91, g_loss: 1.18
epoch: 801/´30000, d_loss: 0.15, d_acc: 0.98, g_loss: 0.68
epoch: 901/´30000, d_loss: 0.12, d_acc: 0.98, g_loss: 0.20
epoch: 1001/´30000, d_loss: 0.18, d_acc: 0.92, g_loss: 0.22
epoch: 1101/´30000, d_loss: 0.58, d_acc: 0.66, g_loss: 0.32
epoch: 1201/´30000, d_loss: 0.24, d_acc: 0.89, g_loss: 0.05
epoch: 1301/´30000, d_loss: 0.04, d_acc: 1.00, g_loss: 0.07
epoch: 1401/´30000, d_loss: 0.04, d_acc: 1.00, g_loss: 0.01
epoch: 1501/´30000, d_loss: 0.11, d_acc: 1.00, g_loss: 0.03
epoch: 1601/´30000, d_loss: 1.84, d_acc: 0.36, g_los

In [None]:
plt.plot(g_losses, label='g_losses')
plt.plot(d_losses, label='d_losses')
plt.legend()

In [None]:
!ls gan_images


In [None]:
from skimage.io import imread
a = imread('gan_images/0.png')
plt.imshow(a)

In [None]:
a = imread('gan_images/1000.png')
plt.imshow(a)

In [None]:
a = imread('gan_images/5000.png')
plt.imshow(a)

In [None]:
a = imread('gan_images/10000.png')
plt.imshow(a)

In [None]:
a = imread('gan_images/20000.png')
plt.imshow(a)

In [None]:
a = imread('gan_images/29800.png')
plt.imshow(a)