In [None]:
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import os
import time
from IPython import display
from tensorflow import keras
from tensorflow.keras import layers

# Preparação dos dados

In [None]:
(x_treinamento, y_treinamento), (_, _) = tf.keras.datasets.mnist.load_data()

In [None]:
x_treinamento.shape

In [None]:
x_treinamento.shape[1] * x_treinamento.shape[2]

In [None]:
i = np.random.randint(0, x_treinamento.shape[0])
plt.imshow(x_treinamento[i], cmap = 'gray')
plt.title('É o numero ' + str(y_treinamento[i]))

In [None]:
x_treinamento = x_treinamento.reshape((x_treinamento.shape[0], x_treinamento.shape[1], x_treinamento.shape[2], 1)).astype('float32')
x_treinamento.shape

In [None]:
x_treinamento[0].min(), x_treinamento[0].max()

In [None]:
meio_escala = x_treinamento[0].max() / 2
x_treinamento = (x_treinamento - meio_escala) / meio_escala

In [None]:
buffer_size = x_treinamento.shape[0]
batch_size = 256

In [None]:
buffer_size / batch_size

In [None]:
type(x_treinamento)

In [None]:
x_treinamento = tf.data.Dataset.from_tensor_slices(x_treinamento).shuffle(buffer_size).batch(batch_size)

In [None]:
type(x_treinamento)

In [None]:
x_treinamento

# Gerador

In [None]:
# Number of units in a 7x7x256 feature map (width x height x channels).
neuronios_ocultos = int(np.prod((7, 7, 256)))

In [None]:
def cria_gerador(dim_ruido = 100):
    """Build the generator: a noise vector of length `dim_ruido` -> a 28x28x1 image.

    A Dense projection is reshaped to a 7x7x256 map, followed by three Conv2DTranspose
    stages (7x7x128 -> 14x14x64 -> 28x28x1). The last layer uses tanh, so pixel values
    lie in [-1, 1]. The model summary is printed as a side effect.

    Args:
        dim_ruido: length of the input noise vector. Defaults to 100, the length the rest
            of the notebook samples noise with.

    Returns:
        The uncompiled tf.keras.Sequential generator.
    """
    # Shape of the first feature map. The Dense width is derived from it so the Dense
    # layer and the Reshape below cannot drift apart.
    forma_inicial = (7, 7, 256)

    network = tf.keras.Sequential()
    network.add(layers.Dense(units = int(np.prod(forma_inicial)), use_bias = False, input_shape = (dim_ruido, )))
    network.add(layers.BatchNormalization())
    network.add(layers.LeakyReLU())

    network.add(layers.Reshape(forma_inicial))

    # 7x7x128 (default stride 1 keeps the spatial size)
    network.add(layers.Conv2DTranspose(filters = 128, kernel_size = 5, padding = 'same', use_bias = False))
    network.add(layers.BatchNormalization())
    network.add(layers.LeakyReLU())

    # 14x14x64
    network.add(layers.Conv2DTranspose(filters = 64, kernel_size = 5, strides = 2, padding = 'same', use_bias = False))
    network.add(layers.BatchNormalization())
    network.add(layers.LeakyReLU())

    # 28x28x1, tanh -> values in [-1, 1] to match the normalized training images
    network.add(layers.Conv2DTranspose(filters = 1, kernel_size = 5, strides = 2, padding = 'same', use_bias = False, activation = 'tanh'))
    network.summary()

    return network

In [None]:
gerador = cria_gerador()
gerador.input

In [None]:
ruido = tf.random.normal([1, 100])
ruido

In [None]:
imagem_gerada = gerador(ruido, training = False)
imagem_gerada.shape

In [None]:
plt.imshow(imagem_gerada[0, :, :, 0], cmap = 'gray')
plt.title('Imagem gerada pelo gerador')

# Crítico

In [None]:
def cria_critico():
    """Build the WGAN critic: a 28x28x1 image -> one unbounded score (no sigmoid).

    Two stride-2 convolutions (28 -> 14 -> 7) with LeakyReLU and dropout, then a
    linear Dense output. Prints the model summary and returns the uncompiled model.
    """
    camadas = [
        # 28x28x1 -> 14x14x64
        layers.Conv2D(filters = 64, kernel_size = 5, padding = 'same', input_shape = [28, 28, 1], strides = 2),
        layers.LeakyReLU(),
        layers.Dropout(0.3),
        # 14x14x64 -> 7x7x128
        layers.Conv2D(filters = 128, kernel_size = 5, padding = 'same', strides = 2),
        layers.LeakyReLU(),
        layers.Dropout(0.3),
        # flattened features -> single critic score
        layers.Flatten(),
        layers.Dense(units = 1),
    ]
    modelo = tf.keras.Sequential(camadas)
    modelo.summary()
    return modelo

In [None]:
critico = cria_critico()
critico.input


In [None]:
critico(imagem_gerada, training = False) # logits

In [None]:
tf.sigmoid(critico(imagem_gerada, training = False)).numpy() # probabilidade

# Wasserstein Loss (Cálculo do erro)

In [None]:
def loss_gerador(fake_saida):
    """Wasserstein generator loss: the negated mean critic score of generated images.

    Minimizing this pushes the generator toward samples the critic scores higher.
    """
    return -tf.reduce_mean(fake_saida)

In [None]:
# Generator loss for the single generated image: the quantity loss_gerador minimizes.
loss_gerador(critico(imagem_gerada, training = False))

In [None]:
def loss_critico(real_saida, fake_saida, gradiente_penalidade, c_lambda = 10):
    """WGAN-GP critic loss: mean(D(fake)) - mean(D(real)) + c_lambda * gradient penalty.

    Args:
        real_saida: critic scores for real images.
        fake_saida: critic scores for generated images.
        gradiente_penalidade: scalar gradient-penalty term.
        c_lambda: weight of the gradient penalty (default 10).

    Returns:
        Scalar loss that the critic optimizer minimizes.
    """
    d_loss = tf.math.reduce_mean(fake_saida) - tf.math.reduce_mean(real_saida) + c_lambda * gradiente_penalidade
    return d_loss

## Gradient Penalty

In [None]:
@tf.function
def gradient_penalty(real, fake, epsilon):
  """WGAN-GP gradient penalty.

  Scores random points between real and generated images with the critic and penalizes
  how far each sample's input-gradient L2 norm is from 1.

  Args:
    real: batch of real images (same shape as `fake`).
    fake: batch of generated images.
    epsilon: interpolation weights broadcastable to `real`, e.g. shape (batch, 1, 1, 1);
      expected to lie in [0, 1].

  Returns:
    Scalar tensor: the batch mean of (||grad_i||_2 - 1)^2.
  """
  imgs_interpoladas = real * epsilon + fake * (1 - epsilon)
  with tf.GradientTape() as tape:
    tape.watch(imgs_interpoladas)
    # training=True matches how the critic is called in the training step.
    scores = critico(imgs_interpoladas, training = True)

  gradiente = tape.gradient(scores, imgs_interpoladas)
  # Flatten all non-batch axes so the norm is computed per sample, not over the whole batch.
  gradiente = tf.reshape(gradiente, [tf.shape(gradiente)[0], -1])
  # The small constant keeps sqrt differentiable when a gradient is exactly zero.
  grad_norm = tf.sqrt(tf.reduce_sum(tf.square(gradiente), axis = 1) + 1e-12)
  gp = tf.math.reduce_mean((grad_norm - 1.) ** 2)
  return gp


In [None]:
gerador_otimizador = tf.keras.optimizers.Adam(learning_rate = 0.0002, beta_1 = 0.5, beta_2 = 0.9)
critico_otimizador = tf.keras.optimizers.Adam(learning_rate = 0.0002, beta_1 = 0.5, beta_2 = 0.9)

In [None]:
checkpoint_dir = './treinamento_gan_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, 'ckpt')
checkpoint = tf.train.Checkpoint(generator_optimizer = gerador_otimizador,
                                  discrimanator_optimizer = critico_otimizador,
                                  generator = gerador,
                                  discriminator = critico)

# Treinamento

In [None]:
epochs = 1                      # passes over the training set
noise_dim = 100                 # latent vector length; the generator expects 100 inputs
num_examples_to_generate = 16   # images drawn in each progress grid
# Fixed latent vectors so the progress images are comparable from epoch to epoch.
seed = tf.random.normal(shape = (num_examples_to_generate, noise_dim))

In [None]:
def etapa_treinamento(imgs, critico_etapas_extras = 3):
  """One WGAN-GP step: `critico_etapas_extras` critic updates, then one generator update.

  Args:
    imgs: batch of real images, normalized like the training data.
    critico_etapas_extras: number of critic updates per generator update (default 3).

  Returns:
    (d_loss, g_loss): the last critic loss (None if no critic step ran) and the generator loss.
  """
  # Size noise and interpolation weights from the actual batch, so real and generated batches
  # always have matching shapes in the gradient penalty.
  n = tf.shape(imgs)[0]
  d_loss = None

  for _ in range(critico_etapas_extras):
    # Fresh latent vectors for every critic update.
    noise = tf.random.normal([n, noise_dim])
    # WGAN-GP samples the interpolation weight uniformly in [0, 1] so the interpolated points
    # lie on the segment between real and generated images.
    epsilon = tf.random.uniform([n, 1, 1, 1], 0.0, 1.0)
    with tf.GradientTape() as c_tape:
      imgs_geradas = gerador(noise, training = True)
      real_output = critico(imgs, training = True)
      fake_output = critico(imgs_geradas, training = True)
      gp = gradient_penalty(imgs, imgs_geradas, epsilon)
      d_loss = loss_critico(real_output, fake_output, gp)

    critic_gradients = c_tape.gradient(d_loss, critico.trainable_variables)
    critico_otimizador.apply_gradients(zip(critic_gradients, critico.trainable_variables))

  noise = tf.random.normal([n, noise_dim])
  with tf.GradientTape() as g_tape:
    imgs_geradas = gerador(noise, training = True)
    fake_output = critico(imgs_geradas, training = True)
    g_loss = loss_gerador(fake_output)

  gerador_gradients = g_tape.gradient(g_loss, gerador.trainable_variables)
  gerador_otimizador.apply_gradients(zip(gerador_gradients, gerador.trainable_variables))
  return d_loss, g_loss
      

In [None]:
def gerar_e_salvar_imgs(model, epoch, test_input):
  """Generate images from `test_input`, draw them in a square grid, save the grid and show it.

  Args:
    model: the generator; called as model(test_input, training=False).
    epoch: number used in the file name 'image_at_epoch_{epoch:04d}.png'.
    test_input: batch of latent vectors; one image is drawn per vector.
  """
  preds = model(test_input, training = False)
  num_imgs = preds.shape[0]
  # Smallest square grid that holds every image (4x4 for the usual 16 samples).
  lado = max(1, int(np.ceil(np.sqrt(num_imgs))))

  fig = plt.figure(figsize = (lado, lado))
  for i in range(num_imgs):
    plt.subplot(lado, lado, i + 1)
    # Undo the [-1, 1] normalization; fixed vmin/vmax stop imshow from re-stretching each image.
    plt.imshow(preds[i, :, :, 0] * meio_escala + meio_escala, cmap = 'gray', vmin = 0, vmax = 2 * meio_escala)
    plt.axis('off')
  plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
  plt.show()
  # Release the figure so non-inline backends do not accumulate one figure per epoch.
  plt.close(fig)

In [None]:
def treinar(dataset, epochs, intervalo_checkpoint = 15):
  """Train the GAN for `epochs` passes over `dataset`.

  After each epoch the sample grid for the fixed `seed` is regenerated and saved. A checkpoint
  is written every `intervalo_checkpoint` epochs and once more after the last epoch, so the
  final weights can always be restored. The generator is finally saved to 'gerador.h5'.

  Args:
    dataset: iterable of image batches (the tf.data pipeline built earlier).
    epochs: number of passes over `dataset`.
    intervalo_checkpoint: save a checkpoint every this many epochs (must be >= 1).

  Raises:
    ValueError: if `intervalo_checkpoint` is smaller than 1.
  """
  if intervalo_checkpoint < 1:
    raise ValueError('intervalo_checkpoint must be >= 1')

  for epoch in range(epochs):
    inicio = time.time()
    for imgs in dataset:
      # Only full batches are used for training; a shorter final batch is skipped.
      if len(imgs) == batch_size:
        etapa_treinamento(imgs)

    display.clear_output(wait = True)
    gerar_e_salvar_imgs(gerador, epoch + 1, seed)
    if (epoch + 1) % intervalo_checkpoint == 0:
      checkpoint.save(file_prefix = checkpoint_prefix)
    print('Tempo para a época {} é {} sec'.format(epoch + 1, time.time() - inicio))

  # Persist the final state unless the last epoch already triggered a periodic save.
  if epochs > 0 and epochs % intervalo_checkpoint != 0:
    checkpoint.save(file_prefix = checkpoint_prefix)

  display.clear_output(wait = True)
  gerar_e_salvar_imgs(gerador, epochs, seed)
  gerador.save('gerador.h5')

In [None]:
treinar(x_treinamento, epochs)

In [None]:
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))

In [None]:
seed_input = tf.random.normal([num_examples_to_generate, noise_dim])
preds = gerador(seed_input, training = False)
fig = plt.figure(figsize = (4, 4))
for i in range(preds.shape[0]):
  plt.subplot(4, 4, i + 1)
  plt.imshow(preds[i, :, :, 0] * meio_escala + meio_escala, cmap = 'gray')
  plt.axis('off')
plt.show()