<a href="https://colab.research.google.com/github/LarsAmker/ExplainGAN/blob/master/ExplainGANcompact.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

##### Copyright 2019 The TensorFlow Authors.

In [0]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# ExplainGAN


This notebook tries to build ExplainGAN, starting from the [DCGAN tutorial](https://www.tensorflow.org/tutorials/generative/dcgan) from TensorFlow.

Original text: This tutorial demonstrates how to generate images of handwritten digits using a [Deep Convolutional Generative Adversarial Network](https://arxiv.org/pdf/1511.06434.pdf) (DCGAN). The code is written using the [Keras Sequential API](https://www.tensorflow.org/guide/keras) with a `tf.GradientTape` training loop.

### Import TensorFlow and other libraries

In [0]:
from __future__ import absolute_import, division, print_function, unicode_literals

try:
  # %tensorflow_version only exists in Colab.
  %tensorflow_version 2.x
except Exception:
  pass

In [0]:
import tensorflow as tf
tf.__version__

try:
  tpu = tf.distribute.cluster_resolver.TPUClusterResolver() # TPU detection
except ValueError: # If TPU not found
  tpu = None

# To generate GIFs. Let's see, maybe this can still be useful here?
!pip install imageio

In [0]:
import glob
import imageio
import matplotlib.pyplot as plt
import numpy as np
import os
import PIL
from tensorflow.keras import layers
import time

from IPython import display

In [0]:
# Clone the github repo to get access to the outsourced building blocks
!git clone https://github.com/LarsAmker/ExplainGAN

In [0]:
%run ExplainGAN/data_and_classifier
%run ExplainGAN/encoder
%run ExplainGAN/generators
%run ExplainGAN/discriminator

### Load and prepare the dataset

Let's try the **MNIST** dataset here too. It was also one of the examples in Silberman's paper. We need the images showing **4s and 9s only**, because ExplainGAN explains binary classifiers. (4,9) was one of the digit pairs used by Silberman too (the others were (3,8) and (5,6)).
Keep only these two digits in both the training part (60,000 images) and the test part (10,000 images) of the dataset.
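
The actual filtering happens in `ExplainGAN/data_and_classifier`, which is not shown in this notebook. As a rough sketch of what such a step could look like, here is a NumPy version run on a tiny synthetic stand-in for MNIST. The helper name `filter_two_digits` is made up for illustration. Scaling to [-1, 1] and the label mapping (4 becomes class 0, 9 becomes class 1) are assumptions inferred from the plotting and prediction code further down.

```python
import numpy as np

def filter_two_digits(images, labels, digit_a=4, digit_b=9):
    """Keep only images of digit_a / digit_b and relabel them as 0 / 1 (hypothetical helper)."""
    keep = (labels == digit_a) | (labels == digit_b)
    images = images[keep].astype("float32")
    # Add a channel axis and scale pixel values from [0, 255] to [-1, 1] (assumed normalisation).
    images = (images[..., np.newaxis] - 127.5) / 127.5
    binary_labels = (labels[keep] == digit_b).astype("int64")
    return images, binary_labels

# Tiny synthetic stand-in for MNIST: six all-white 28x28 "images" with known labels.
demo_images = np.full((6, 28, 28), 255, dtype=np.uint8)
demo_labels = np.array([4, 9, 3, 4, 8, 9])
demo_x, demo_y = filter_two_digits(demo_images, demo_labels)
print(demo_x.shape, demo_y)
```

With the real data, the same function would be applied once to the training split and once to the test split.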

## Our classifier

First we need a pre-trained **binary** classifier. We modify the one from the TensorFlow tutorial "Basic Image Classification", which classifies Fashion-MNIST.

This classifier is the AI that we aim to explain with ExplainGAN. It should not interact a lot with the ExplainGAN part (except for being used by it at the end of the process) and thus be fairly interchangeable.

## The actual ExplainGAN part

On top of that, there are two encoders (not present in DCGAN), three generators (sharing the first few layers) that produce three images (reconstruction, transformation and mask), and two discriminators used for training these generators. My interpretation of Silberman's very sparse paragraph about the ExplainGAN model architecture is that the encoders and discriminators are similar to the DCGAN discriminator and the generators are similar to DCGAN's generator.

### The encoders

There are two encoders, one for each predicted class. Each takes an image and produces a compressed encoding, the latent variable z, that the generator uses as input. In the TensorFlow DCGAN tutorial on MNIST, the generator's input was a vector of 100 standard normal random variables. Let's go for a latent variable of size 128 here (Silberman's ExplainGAN paper gives no information about the dimension of the encoding). I also looked at the Variational Autoencoder TensorFlow tutorial, but it is kind of complicated, so I went for the source below instead (which also uses MNIST).

In [0]:
encoder4 = make_encoder_model(activ_fct='relu')
encoder9 = make_encoder_model(activ_fct='relu')
#encoder4.summary()

### The generator and the mask function

The generator uses `tf.keras.layers.Conv2DTranspose` (upsampling) layers to produce an image from the latent variable z produced from one of the encoders. Start with a `Dense` layer that takes z as input, then upsample several times until you reach the desired image size of 28x28x1. Notice the `tf.keras.layers.LeakyReLU` activation for each layer, except the output layer which uses tanh.

Only change made compared to the DCGAN architecture: the input size is now 128 for compatibility with the encoders. As an alternative to the DCGAN generator, we could use the decoder from the autoencoder source (used for the encoders above).
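
The generator code itself lives in `ExplainGAN/generators`. As a quick sanity check of the upsampling arithmetic, the sketch below assumes the same layout as the DCGAN tutorial generator: a 7x7 feature map after the `Dense` layer, followed by transposed convolutions with strides 1, 2, 2 and `padding='same'`. How the repo splits these layers between the shared start and the three ends is not shown here, and the helper name is made up for illustration.

```python
def conv_transpose_same_size(size, stride):
    """Spatial output size of a transposed convolution with padding='same'."""
    return size * stride

size = 7  # assumed: the Dense output is reshaped to a 7x7 feature map
sizes = [size]
for stride in (1, 2, 2):  # assumed: strides as in the DCGAN tutorial generator
    size = conv_transpose_same_size(size, stride)
    sizes.append(size)
print(sizes)
```

The final entry is 28, which matches the desired 28x28 image size.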

In [0]:
generator_start = make_generator_model_start()
#generator_start.summary()
# Batch normalization has 4 parameters per channel. Two of them are trainable (gamma and beta mentioned in Szegedy)
# The other 2 are not trainable: the moving mean and moving variance used at inference time. I found the counts by experimentation with a toy model.

reconstructor = make_generator_model_end_tanh()
transformator = make_generator_model_end_tanh()
mask = make_generator_model_end_sigmoid()
#reconstructor.summary()

Use the (as yet untrained) generators to create some example images for tests

### Create composite images

Now that we have a reconstruction, a transformation and the mask, which will be trained to show the differences, we still need to combine these results with the original input images to get the composite images: our final product, which we feed into the discriminators (along with the reconstruction and the transformation).

In [0]:
# We will need to call this function twice to create a composite image of each class
# orig_image and created_image are always from the two different classes. 
# tf.math.multiply performs an element-wise multiplication
def create_composite(orig_image, created_image, mask):
  composite = tf.add(tf.math.multiply(1 - mask, orig_image), tf.math.multiply(mask, created_image))
  return composite
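
To see what the composite formula does, here is a small NumPy version on a 2x2 toy image. The function name `create_composite_np` and the toy values are made up for illustration. Where the mask is 0 the original pixel is kept, where it is 1 the created pixel is used, and values in between blend the two.

```python
import numpy as np

def create_composite_np(orig_image, created_image, mask_values):
    """NumPy counterpart of the composite: (1 - m) * original + m * created."""
    return (1 - mask_values) * orig_image + mask_values * created_image

orig = np.full((1, 2, 2, 1), -1.0)    # all-black original, pixel range [-1, 1]
created = np.full((1, 2, 2, 1), 1.0)  # all-white created image
mask_values = np.array([0.0, 1.0, 0.5, 0.0]).reshape(1, 2, 2, 1)

composite = create_composite_np(orig, created, mask_values)
print(composite[0, :, :, 0])
```

The pixel with mask value 0.5 ends up at 0.0, halfway between the black original and the white created image.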

### The discriminators

Each discriminator is a CNN-based image classifier. There is one per class (4 and 9).

In [0]:
discriminator4 = make_discriminator_model()
discriminator9 = make_discriminator_model()
#decision = discriminator9(trafo_image)
#print (decision)
#discriminator4.summary()

## Define the loss and optimizers

Define all loss functions and optimizers needed.


In [0]:
%run ExplainGAN/losses/loss_gan
%run ExplainGAN/losses/loss_classifier
%run ExplainGAN/losses/loss_reconstruction
%run ExplainGAN/losses/losses_prior
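
The loss functions themselves are defined in the repo files and are not shown in this notebook. To give an idea of what priors on the mask can look like, here is a NumPy sketch of one plausible form for three of them, loosely following the mask priors described in the ExplainGAN paper. The function names and exact formulas are assumptions for illustration; the implementations in `ExplainGAN/losses/losses_prior` may differ.

```python
import numpy as np

def count_prior(mask_values, kappa):
    """Penalise masks whose mean value per image exceeds kappa (assumed form)."""
    mean_per_image = np.mean(np.abs(mask_values), axis=(1, 2, 3))
    return np.maximum(mean_per_image - kappa, 0.0)

def smoothness_prior(mask_values):
    """Total variation: summed absolute differences between neighbouring pixels (assumed form)."""
    dh = np.abs(mask_values[:, 1:, :, :] - mask_values[:, :-1, :, :]).sum(axis=(1, 2, 3))
    dw = np.abs(mask_values[:, :, 1:, :] - mask_values[:, :, :-1, :]).sum(axis=(1, 2, 3))
    return dh + dw

def entropy_prior(mask_values):
    """Push mask values towards 0 or 1 by penalising min(m, 1 - m) (assumed form)."""
    return np.minimum(mask_values, 1.0 - mask_values).sum(axis=(1, 2, 3))

demo_mask = np.zeros((1, 4, 4, 1))
demo_mask[0, 1:3, 1:3, 0] = 1.0  # binary 2x2 blob inside a 4x4 mask

print(count_prior(demo_mask, kappa=0.03))
print(smoothness_prior(demo_mask))
print(entropy_prior(demo_mask))
```

Each function returns one value per image, so the result can be weighted per sample. For the binary blob the entropy term is zero, while the count term is positive because a quarter of the pixels are switched on.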

### Optimizers

DCGAN: Use one optimizer for discriminator and one for the generator

How about ExplainGAN? We have a bunch of networks. What will be trained separately and what not? Let's try one optimizer for each network

In [0]:
# Add an optimizer for each of the different parts: enc4, enc9, gen_start, recon, trafo, mask, disc4, disc9
enc4_optimizer = tf.keras.optimizers.Adam(1e-4)
enc9_optimizer = tf.keras.optimizers.Adam(1e-4)
gen_start_optimizer = tf.keras.optimizers.Adam(1e-4)
recon_optimizer = tf.keras.optimizers.Adam(1e-4)
trafo_optimizer = tf.keras.optimizers.Adam(1e-4)
mask_optimizer = tf.keras.optimizers.Adam(1e-4)
disc4_optimizer = tf.keras.optimizers.Adam(1e-4)
disc9_optimizer = tf.keras.optimizers.Adam(1e-4)
# I could change the Adam parameters. The DCGAN paper suggested a learning rate of 2e-4 instead of 1e-4. It also suggested lowering beta_1 from 0.9 to 0.5

## Define the training loop



One training step goes through the whole architecture. The input is a batch of real images; it is not important how many of each class it contains. At the start, the classifier assigns a predicted label to each image. Then all images go through the other parts of the network twice, once for each possible predicted class. When computing the losses at the end, only the stream matching the actual prediction is used.
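
The selection by predicted class can be illustrated with a few made-up per-image loss values. This is a NumPy sketch of the masking pattern only, not the training code: multiplying by `1 - predicted_classes` keeps the "predicted 4" losses, and multiplying by `predicted_classes` keeps the "predicted 9" losses.

```python
import numpy as np

# Predicted class per image: 0 means "4", 1 means "9" (column vector, as in the training step)
predicted_classes = np.array([0, 1, 1, 0], dtype=np.float32).reshape(-1, 1)

# Made-up per-image losses from the two streams
loss_pred4 = np.array([0.2, 0.9, 0.8, 0.1], dtype=np.float32).reshape(-1, 1)
loss_pred9 = np.array([0.7, 0.3, 0.4, 0.6], dtype=np.float32).reshape(-1, 1)

# Zero out the stream that does not match the prediction, then add the two
loss = (1 - predicted_classes) * loss_pred4 + predicted_classes * loss_pred9
print(loss.ravel())
```

Images 1 and 4 take their value from `loss_pred4`, images 2 and 3 from `loss_pred9`.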

In [0]:
# From DCGAN original: Notice the use of `tf.function`
# This annotation causes the function to be "compiled".
#@tf.function # produces an error in my code (possibly because classifier.predict and np.argmax need eager execution); without it there is no problem
def train_step(images, weight_g, weight_c, weight_r, weight_cs, weight_ct, weight_sm, weight_en, pretraining_flag): 
    # images is the whole batch of real images fed into the machine in one training step
    # pretraining_flag = 1: Only train the reconstruction part of the network. 0: Don't train encoders anymore
    predicted_classes = classifier.predict(images)
    predicted_classes = np.argmax(predicted_classes, axis=1) # now we have the actual predictions
    predicted_classes = tf.reshape(predicted_classes, [-1,1]) # make it compatible with tensors below
    predicted_classes = tf.cast(predicted_classes, tf.float32) # change type to float for multiplications
    
    # More components (8 to be precise) in this network than in DCGAN -> Needs more gradient tapes
    # List: enc4_tape, enc9_tape, gen_start_tape, recon_tape, trafo_tape, mask_tape, disc4_tape, disc9_tape
    with tf.GradientTape() as enc4_tape, tf.GradientTape() as enc9_tape, \
         tf.GradientTape() as gen_start_tape, tf.GradientTape() as recon_tape, \
         tf.GradientTape() as trafo_tape, tf.GradientTape() as mask_tape, \
         tf.GradientTape() as disc4_tape, tf.GradientTape() as disc9_tape:
      # Put all images through _both_ streams. Only use one of them for each image for loss computation
      z_as4 = encoder4(images)
      z_as9 = encoder9(images)
      # 'middle step', apply the part of the generator shared by trafo, recon and mask
      gen_from_pred4 = generator_start(z_as4)
      gen_from_pred9 = generator_start(z_as9)
      # reconstructions, transformations and masks, then create composites
      recon_from_pred4 = reconstructor(gen_from_pred4)
      recon_from_pred9 = reconstructor(gen_from_pred9)
      trafo_from_pred4 = transformator(gen_from_pred4)
      trafo_from_pred9 = transformator(gen_from_pred9)
      mask_from_pred4 = mask(gen_from_pred4)
      mask_from_pred9 = mask(gen_from_pred9) 
      comp_from_pred4 = create_composite(images, trafo_from_pred4, mask_from_pred4)
      comp_from_pred9 = create_composite(images, trafo_from_pred9, mask_from_pred9)

      # now we need to get the losses right
      # GAN loss - first create all necessary discriminator outputs
      real_output4 = discriminator4(images)
      real_output9 = discriminator9(images)
      recon_output4 = discriminator4(recon_from_pred4)
      recon_output9 = discriminator9(recon_from_pred9)
      trafo_output4 = discriminator9(trafo_from_pred4) # now use the opposite discriminators
      trafo_output9 = discriminator4(trafo_from_pred9)
      comp_output4 = discriminator9(comp_from_pred4)
      comp_output9 = discriminator4(comp_from_pred9)
      # now calculate the loss (split up by predicted class, not by produced class as in Silberman's paper)
      
      loss_gan_pred4 = loss_gan(real_output4, recon_output4, trafo_output4, comp_output4)
      loss_gan_pred9 = loss_gan(real_output9, recon_output9, trafo_output9, comp_output9)
      loss_g4 = tf.math.multiply(1-predicted_classes, loss_gan_pred4)
      loss_g9 = tf.math.multiply(predicted_classes, loss_gan_pred9)
      loss_g = (loss_g4 + loss_g9) * weight_g
      
      # classifier loss - first we need to create predictions for our composite images
      pred_comp_from_pred4 = classifier(comp_from_pred4)
      pred_comp_from_pred4 = pred_comp_from_pred4[:,1] # not argmax here, we need the probability of a 9
      pred_comp_from_pred4 = tf.reshape(pred_comp_from_pred4, [-1,1])
      pred_comp_from_pred4 = tf.cast(pred_comp_from_pred4, tf.float32)
      pred_comp_from_pred9 = classifier(comp_from_pred9)
      pred_comp_from_pred9 = pred_comp_from_pred9[:,1] # not argmax here, we need the probability
      pred_comp_from_pred9 = tf.reshape(pred_comp_from_pred9, [-1,1])
      pred_comp_from_pred9 = tf.cast(pred_comp_from_pred9, tf.float32)
      # now calculate the loss
      loss_class_pred4 = loss_classifier4(pred_comp_from_pred4) # put a composite 9 into the loss_c for predictions 4
      loss_class_pred9 = loss_classifier9(pred_comp_from_pred9)
      loss_c4 = tf.math.multiply(1-predicted_classes, loss_class_pred4)
      loss_c9 = tf.math.multiply(predicted_classes, loss_class_pred9)
      loss_c = (loss_c4 + loss_c9) * weight_c
      
      # reconstruction loss loss_r
      loss_recon4 = loss_recon(images, recon_from_pred4)
      loss_recon9 = loss_recon(images, recon_from_pred9)
      loss_r4 = tf.math.multiply(1-predicted_classes, loss_recon4) # set the loss for the wrong recons to 0
      loss_r9 = tf.math.multiply(predicted_classes, loss_recon9) # set the loss for the wrong recons to 0
      loss_r = (loss_r4 + loss_r9) * weight_r

      # Add up losses that are used together
      loss_summed = loss_g + loss_c + loss_r

      if pretraining_flag == 1:
        gradients_of_enc4 = enc4_tape.gradient(loss_summed, encoder4.trainable_variables)
        gradients_of_enc9 = enc9_tape.gradient(loss_summed, encoder9.trainable_variables)
        enc4_optimizer.apply_gradients(zip(gradients_of_enc4, encoder4.trainable_variables))
        enc9_optimizer.apply_gradients(zip(gradients_of_enc9, encoder9.trainable_variables))
      
      gradients_of_gen_start = gen_start_tape.gradient(loss_summed, generator_start.trainable_variables)
      gradients_of_recon = recon_tape.gradient(loss_summed, reconstructor.trainable_variables)
      gen_start_optimizer.apply_gradients(zip(gradients_of_gen_start, generator_start.trainable_variables))
      recon_optimizer.apply_gradients(zip(gradients_of_recon, reconstructor.trainable_variables))

      print('###################################################################')
      print('real_output4: ', tf.reduce_max(real_output4), tf.reduce_min(real_output4))
      print('recon_output4: ', tf.reduce_max(recon_output4), tf.reduce_min(recon_output4))
      print('trafo_output4: ', tf.reduce_max(trafo_output4), tf.reduce_min(trafo_output4))
      print('comp_output4: ', tf.reduce_max(comp_output4), tf.reduce_min(comp_output4), tf.reduce_mean(comp_output4))
      print('loss_g: ', tf.reduce_max(loss_g), tf.reduce_min(loss_g), tf.reduce_mean(loss_g))
      print('loss_c: ', tf.reduce_max(loss_c), tf.reduce_min(loss_c), tf.reduce_mean(loss_c))
      print('loss_r: ', tf.reduce_max(loss_r), tf.reduce_min(loss_r), tf.reduce_mean(loss_r))
  

      if pretraining_flag == 0:
        # the 4 prior losses, try kappa=0.03 in loss_count
        kappa = 0.03
        loss_const4 = loss_const(images, trafo_from_pred4, mask_from_pred4)
        loss_const9 = loss_const(images, trafo_from_pred9, mask_from_pred9)
        loss_cs4 = tf.math.multiply(1-predicted_classes, loss_const4)
        loss_cs9 = tf.math.multiply(predicted_classes, loss_const9)
        loss_cs = (loss_cs4 + loss_cs9) * weight_cs
        loss_count4 = loss_count(mask_from_pred4, kappa)
        loss_count9 = loss_count(mask_from_pred9, kappa)
        loss_ct4 = tf.math.multiply(1-predicted_classes, loss_count4)
        loss_ct9 = tf.math.multiply(predicted_classes, loss_count9)
        loss_ct = (loss_ct4 + loss_ct9) * weight_ct
        loss_smooth4 = loss_smoothness(mask_from_pred4)
        loss_smooth9 = loss_smoothness(mask_from_pred9)
        loss_sm4 = tf.math.multiply(1-predicted_classes, loss_smooth4)
        loss_sm9 = tf.math.multiply(predicted_classes, loss_smooth9)
        loss_sm = (loss_sm4 + loss_sm9) * weight_sm
        loss_entropy4 = loss_entropy(mask_from_pred4)
        loss_entropy9 = loss_entropy(mask_from_pred9)
        loss_en4 = tf.math.multiply(1-predicted_classes, loss_entropy4)
        loss_en9 = tf.math.multiply(predicted_classes, loss_entropy9)
        loss_en = (loss_en4 + loss_en9) * weight_en
        loss_prior = loss_cs + loss_ct + loss_sm + loss_en

        gradients_of_trafo = trafo_tape.gradient(loss_g + loss_c, transformator.trainable_variables)
        gradients_of_mask = mask_tape.gradient(loss_prior + loss_g + loss_c, mask.trainable_variables)
        gradients_of_disc4 = disc4_tape.gradient(-loss_g, discriminator4.trainable_variables)
        gradients_of_disc9 = disc9_tape.gradient(-loss_g, discriminator9.trainable_variables)

        trafo_optimizer.apply_gradients(zip(gradients_of_trafo, transformator.trainable_variables))
        mask_optimizer.apply_gradients(zip(gradients_of_mask, mask.trainable_variables))
        disc4_optimizer.apply_gradients(zip(gradients_of_disc4, discriminator4.trainable_variables))      
        disc9_optimizer.apply_gradients(zip(gradients_of_disc9, discriminator9.trainable_variables))

        print('loss_cs: ', tf.reduce_max(loss_cs), tf.reduce_min(loss_cs), tf.reduce_mean(loss_cs))
        print('loss_ct: ', tf.reduce_max(loss_ct), tf.reduce_min(loss_ct), tf.reduce_mean(loss_ct))
        print('loss_sm: ', tf.reduce_max(loss_sm), tf.reduce_min(loss_sm), tf.reduce_mean(loss_sm))
        print('loss_en: ', tf.reduce_max(loss_en), tf.reduce_min(loss_en), tf.reduce_mean(loss_en))


      #print('z_as4: ', tf.reduce_max(z_as4), tf.reduce_min(z_as4))
      #print('gen_from_pred4: ', tf.reduce_max(gen_from_pred4), tf.reduce_min(gen_from_pred4))
      #print('recon_from_pred4: ', tf.reduce_max(recon_from_pred4), tf.reduce_min(recon_from_pred4))
      #print('comp_output9: ', tf.reduce_max(comp_output9), tf.reduce_min(comp_output9), tf.reduce_mean(comp_output9))
      #print('loss_summed: ', tf.reduce_max(loss_summed), tf.reduce_min(loss_summed))
      #print(z_as4.shape, real_output4.shape, loss_g.shape, loss_summed.shape)

In [0]:
def train(dataset, epochs, weight_g, weight_c, weight_r, weight_cs, weight_ct, weight_sm, weight_en, pretraining_flag):
  for epoch in range(epochs):
    start = time.time()

    for image_batch in dataset:
      # Here we use the batching of the dataset below
      # call the function train_step defined in the box above this one
      train_step(image_batch, weight_g, weight_c, weight_r, weight_cs, weight_ct, weight_sm, weight_en, pretraining_flag)

    # Produce images for the GIF as we go (from DCGAN)
    #display.clear_output(wait=True)
    generate_and_save_images(epoch + 1, train_images, index=0) # still input train_images here. Would be nice to use dataset instead!!
    
    print ('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start))

  # Generate after the final epoch
  #display.clear_output(wait=True)
  generate_and_save_images(epochs, train_images, index=0)

**Generate and save images**



In [0]:
# Changes compared to the DCGAN version: test_input is the first real image(s) instead of a random seed
def generate_and_save_images(epoch, test_input, index):
  # Notice `training` is set to False.
  # This is so all layers run in inference mode (batchnorm).
  original = test_input[index:index+1,:,:,:]
  prediction = classifier.predict(original)
  prediction = np.argmax(prediction, axis=1)[0] # scalar class index: 0 = digit 4, 1 = digit 9
  # Pick the encoder that belongs to the predicted class
  if prediction == 0:
    z = encoder4(original, training=False)
  else:
    z = encoder9(original, training=False)
  gen_from_pred = generator_start(z, training=False)
  recon_from_pred = reconstructor(gen_from_pred, training=False)
  trafo_from_pred = transformator(gen_from_pred, training=False)
  mask_from_pred = mask(gen_from_pred, training=False)
  comp_from_pred = create_composite(original, trafo_from_pred, mask_from_pred)
  
  fig = plt.figure(figsize=(10,10))
  # Plot original, reconstruction, transformation, mask and composite side by side
  panels = [original, recon_from_pred, trafo_from_pred, mask_from_pred, comp_from_pred]
  for i, panel in enumerate(panels):
    plt.subplot(1, 5, i + 1)
    plt.imshow(panel[0, :, :, 0] * 127.5 + 127.5, cmap='gray')
    plt.axis('off')

  plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
  plt.show()

In [0]:
for index in range(20, 26):
  generate_and_save_images(1, train_images, index=index)

## Train the model
Call the `train()` function defined above to train the generators, the discriminators and the other networks simultaneously. Note that training GANs can be tricky: it's important that the generator and discriminator do not overpower each other (e.g., that they train at a similar rate).

DCGAN: At the beginning of the training, the generated images look like random noise. As training progresses, the generated digits will look increasingly real. After about 50 epochs, they resemble MNIST digits. This may take about one minute / epoch with the default settings on Colab.

In [0]:
# Batch and shuffle the data as in DCGAN - do I need to shuffle?
# We can play around with the batch size a bit. Smaller batch sizes make each step faster, but need more steps per epoch
BATCH_SIZE = 256
train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(len(train_labels)).batch(BATCH_SIZE)
test_dataset = tf.data.Dataset.from_tensor_slices(test_images).shuffle(len(test_labels)).batch(BATCH_SIZE)
EPOCHS = 100
num_examples_to_generate = 4 # the gif is not super important. Ignore it for now

In [0]:
%%time
# loss weights: Scale everything to be close to 3 (leave the only negative loss, GAN loss, which is around -2.8 untouched)
# I might need to change a lot here. No idea about the weights used by Silberman
weight_g = 0
weight_c = 0
weight_r = 1
weight_cs = 0
weight_ct = 0
weight_sm = 0
weight_en = 0
pretraining_flag = 1
train(test_dataset, EPOCHS, weight_g, weight_c, weight_r, weight_cs, weight_ct, weight_sm, weight_en, pretraining_flag)

In [0]:
%%time
#weight_g = 1
#weight_c = 0.33
#weight_r = 0.004
#weight_cs = 0.004
#weight_ct = 300
#weight_sm = 0.15
#weight_en = 8.5
weight_g = 1
weight_c = 1
weight_r = 1
weight_cs = 1
weight_ct = 1
weight_sm = 1
weight_en = 1
pretraining_flag = 0
train(test_dataset, 1, weight_g, weight_c, weight_r, weight_cs, weight_ct, weight_sm, weight_en, pretraining_flag)

## Create a GIF (ignored for now)


In [0]:
# Display a single image using the epoch number
#def display_image(epoch_no):
#  return PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no))

In [0]:
#display_image(EPOCHS)

Use `imageio` to create an animated gif using the images saved during training.

In [0]:
#anim_file = 'dcgan.gif'

#with imageio.get_writer(anim_file, mode='I') as writer:
#  filenames = glob.glob('image*.png')
#  filenames = sorted(filenames)
#  last = -1
#  for i,filename in enumerate(filenames):
#    frame = 2*(i**0.5)
#    if round(frame) > round(last):
#      last = frame
#    else:
#      continue
#    image = imageio.imread(filename)
#    writer.append_data(image)
#  image = imageio.imread(filename)
#  writer.append_data(image)

#import IPython
#if IPython.version_info > (6,2,0,''):
#  display.Image(filename=anim_file)

If you're working in Colab you can download the animation with the code below:

In [0]:
#try:
#  from google.colab import files
#except ImportError:
#   pass
#else:
#  files.download(anim_file)

## Next steps


This tutorial has shown the complete code necessary to write and train a GAN. As a next step, you might like to experiment with a different dataset, for example the Large-scale Celeb Faces Attributes (CelebA) dataset [available on Kaggle](https://www.kaggle.com/jessicali9530/celeba-dataset/home). To learn more about GANs we recommend the [NIPS 2016 Tutorial: Generative Adversarial Networks](https://arxiv.org/abs/1701.00160).
