Adapted from: Deep Convolutional Generative Adversarial Network. https://www.tensorflow.org/tutorials/generative/dcgan. Accessed: 2020-10-23.

In [None]:
import tensorflow as tf

In [None]:
# Display the installed TensorFlow version as the cell output.
tf.__version__

In [None]:
# Need this for multiple gpus when running locally
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
        # Enable memory growth on every detected GPU so the network can use
        # the GPU VRAM uncapped. !!! NEED THIS FOR NETWORK TO RUN !!!
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        # Only the first detected GPU is made visible to TensorFlow.
        tf.config.experimental.set_visible_devices(gpus[0], 'GPU')
    except RuntimeError as e:
        # Best effort: report the configuration error and carry on.
        # NOTE(review): presumably raised when the devices were already
        # initialised before this cell ran -- confirm against the TF docs.
        print(e)

In [None]:
# To generate GIFs
# !pip install imageio
# !pip install git+https://github.com/tensorflow/docs

In [None]:
import glob
import imageio
import matplotlib.pyplot as plt
import numpy as np
import os
import PIL
from tensorflow.keras import layers
import time
from IPython import display

### Load and prepare the dataset

In [None]:
# Load MNIST through Keras; yields (images, labels) pairs for the train and test splits.
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()

In [None]:
def _prepare_images(raw_images):
    """Add a trailing channel axis, convert to float32 and normalize to [-1, 1]."""
    with_channel = raw_images.reshape(raw_images.shape[0], 28, 28, 1).astype('float32')
    return (with_channel - 127.5) / 127.5


# Both splits receive exactly the same treatment.
train_images = _prepare_images(train_images)
test_images = _prepare_images(test_images)

In [None]:
# Added some preprocessing for the "modesets": the test images are split into
# one array per digit label.


def _images_with_label(digit):
    """Return all test images whose label equals `digit`.

    The result has shape (count, 28, 28, 1). It is cast to float64 so the
    arrays keep the dtype they had when they were pre-allocated with
    `np.empty` and filled image by image.
    """
    return test_images[test_labels == digit].astype(np.float64)


def _show_first_images(image_sets):
    """Plot the first image of every array in `image_sets` on a 2x2 grid."""
    for position, images in enumerate(image_sets, start=1):
        plt.subplot(2, 2, position)
        # Undo the [-1, 1] normalization before displaying.
        plt.imshow(images[0, :, :, 0] * 127.5 + 127.5, cmap='gray')
        plt.axis('off')
    plt.show()


zeros = _images_with_label(0)
ones = _images_with_label(1)
twos = _images_with_label(2)
threes = _images_with_label(3)
fours = _images_with_label(4)
fives = _images_with_label(5)
# `sixs` used to be allocated but never filled (uninitialised memory);
# boolean-mask selection populates it like every other digit.
sixs = _images_with_label(6)
sevens = _images_with_label(7)
eights = _images_with_label(8)

all_modes = [zeros, ones, sevens, eights]
all_others = [twos, threes, fours, fives]

_show_first_images(all_modes)
_show_first_images(all_others)



In [None]:
# Shuffle buffer size passed to Dataset.shuffle below.
BUFFER_SIZE = 60000
# Number of images per training batch; also used as the noise batch size in the train steps.
BATCH_SIZE = 256

In [None]:
# Batch and shuffle the data: slice the training images into single examples,
# shuffle with a BUFFER_SIZE buffer, then group into batches of BATCH_SIZE.
train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)

## Create the models

### The Generator

In [None]:
def make_generator_model():
    """Build the generator: a 100-dim input upsampled to a 28x28x1 tanh image.

    A dense projection is reshaped to 7x7x256 and then passed through three
    transposed convolutions (7x7 -> 14x14 -> 28x28). The output shape is
    asserted after each stage (None is the batch size).
    """
    def _norm_and_activate(net):
        # Every hidden stage ends with batch normalization and LeakyReLU.
        net.add(layers.BatchNormalization())
        net.add(layers.LeakyReLU())

    net = tf.keras.Sequential()
    net.add(layers.Dense(7*7*256, use_bias=False, input_shape=(100,)))
    _norm_and_activate(net)

    net.add(layers.Reshape((7, 7, 256)))
    assert net.output_shape == (None, 7, 7, 256)

    # (filters, strides, expected output shape) for the hidden upsampling stages.
    hidden_stages = [
        (128, (1, 1), (None, 7, 7, 128)),
        (64, (2, 2), (None, 14, 14, 64)),
    ]
    for filters, strides, expected_shape in hidden_stages:
        net.add(layers.Conv2DTranspose(filters, (5, 5), strides=strides, padding='same', use_bias=False))
        assert net.output_shape == expected_shape
        _norm_and_activate(net)

    # Output stage: single channel, tanh activation, no normalization.
    net.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))
    assert net.output_shape == (None, 28, 28, 1)

    return net

### Test the generator

In [None]:
# Build an (untrained) generator and push one random latent vector through it.
generator = make_generator_model()

# A single latent sample of length 100, matching the generator's input_shape=(100,).
noise = tf.random.normal([1, 100])
# training=False runs the layers in inference mode.
generated_image = generator(noise, training=False)

# Show the first (only) image's single channel in grayscale.
plt.imshow(generated_image[0, :, :, 0], cmap='gray')
#generated_image[0, :, :, 0]

### The Discriminator

The discriminator is a CNN-based image classifier.

In [None]:
def make_discriminator_model():
    """Build the discriminator: a CNN mapping a 28x28x1 image to two softmax probabilities.

    Two strided convolution stages (each followed by LeakyReLU and dropout)
    feed a flattened two-unit dense layer whose output is normalized with a
    softmax over the last axis.
    """
    return tf.keras.Sequential([
        # Stage 1: 5x5 kernels.
        layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same',
                      input_shape=[28, 28, 1]),
        layers.LeakyReLU(),
        layers.Dropout(0.3),
        # Stage 2: 10x10 kernels.
        layers.Conv2D(128, (10, 10), strides=(2, 2), padding='same'),
        layers.LeakyReLU(),
        layers.Dropout(0.3),
        # Head: two class scores (changed from a single logit) turned into probabilities.
        layers.Flatten(),
        layers.Dense(2),
        layers.Softmax(axis=-1),
    ])

Use the (as yet untrained) discriminator to classify the generated images as real or fake. Because the model ends in a two-unit softmax layer, it outputs a pair of class probabilities (one for "real", one for "fake") rather than a single positive or negative value.

In [None]:
# Build an (untrained) discriminator and run it on the image generated above.
discriminator = make_discriminator_model()
decision = discriminator(generated_image)
print(generated_image.shape)
# First row of the output: the two softmax values for the single input image.
print (decision[0].numpy())

In [None]:
# Print the discriminator's layer-by-layer summary.
discriminator.summary()

## Define the loss and optimizers

Define loss functions and optimizers for both models.


In [None]:
# Reusable loss objects.
# `cross_entropy` is binary cross entropy configured for raw logits (from_logits=True).
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)
# `cat_cross_entropy` is categorical cross entropy with default arguments; it is
# used with one-hot labels against the discriminator's softmax output.
cat_cross_entropy = tf.keras.losses.CategoricalCrossentropy()

### Discriminator loss

This method quantifies how well the discriminator is able to distinguish real images from fakes. It compares the discriminator's predictions on real images to one-hot "real" labels, and the discriminator's predictions on fake (generated) images to one-hot "fake" labels.

In [None]:
# Modified so the targets are one-hot encoded instead of a single class logit.
# Encoding actually produced below: [1, 0] = Real, [0, 1] = Fake.

def discriminator_loss(real_output, fake_output):
    """Return the summed categorical cross entropy on real and fake batches.

    Args:
        real_output: discriminator output for real images; indexed as
            (batch, class) with two classes.
        fake_output: discriminator output for fake images, same layout.

    Returns:
        `real_loss + fake_loss`, where real samples are labelled class 0
        and fake samples class 1.
    """
    depth = 2
    # Build integer class ids directly (one per sample) instead of creating
    # float zeros/ones and casting them afterwards.
    real_labels = tf.zeros_like(real_output[:, 0], dtype=tf.int32)
    fake_labels = tf.ones_like(fake_output[:, 0], dtype=tf.int32)

    real_loss = cat_cross_entropy(tf.one_hot(real_labels, depth), real_output)
    fake_loss = cat_cross_entropy(tf.one_hot(fake_labels, depth), fake_output)
    return real_loss + fake_loss

### Generator loss
The generator's loss quantifies how well it was able to trick the discriminator. Intuitively, if the generator is performing well, the discriminator will classify the fake images as real (or 1). Here, we will compare the discriminators decisions on the generated images to an array of 1s.

In [None]:
def generator_loss(fake_output):
    """Return how far the discriminator is from calling the fakes "real".

    The discriminator ends in a two-class softmax, so the target is the
    one-hot "real" label (class 0, i.e. [1, 0]) and the loss is categorical
    cross entropy. The earlier binary cross entropy with `from_logits=True`
    and all-ones targets treated the softmax probabilities as logits and
    pushed both class columns towards 1 at once.

    Args:
        fake_output: discriminator output for generated images, indexed as
            (batch, class) with two classes.

    Returns:
        The categorical cross entropy between the "real" labels and `fake_output`.
    """
    real_labels = tf.zeros_like(fake_output[:, 0], dtype=tf.int32)
    return cat_cross_entropy(tf.one_hot(real_labels, 2), fake_output)

The discriminator and the generator optimizers are different since we will train two networks separately.

In [None]:
# One Adam optimizer per network, both with learning rate 1e-4.
generator_optimizer = tf.keras.optimizers.Adam(1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)

### Save checkpoints
This notebook also demonstrates how to save and restore models, which can be helpful in case a long running training task is interrupted.

In [None]:
# Checkpoints are written under ./training_checkpoints with the file prefix "ckpt".
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
# Track both optimizers and both models so they are saved and restored together.
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
                                 discriminator_optimizer=discriminator_optimizer,
                                 generator=generator,
                                 discriminator=discriminator)

## Define the training loop


In [None]:
# Default number of training epochs (later cells reassign this).
EPOCHS = 50
# Length of the latent noise vector fed to the generator.
noise_dim = 100
num_examples_to_generate = 16

# We will reuse this seed over time (so it's easier
# to visualize progress in the animated GIF).
seed = tf.random.normal([num_examples_to_generate, noise_dim])

The training loop begins with generator receiving a random seed as input. That seed is used to produce an image. The discriminator is then used to classify real images (drawn from the training set) and fakes images (produced by the generator). The loss is calculated for each of these models, and the gradients are used to update the generator and discriminator.

In [None]:
# `tf.function` is applied here:
# the decorator causes the function to be "compiled".
@tf.function
def train_step(images):
    """Run one adversarial update of both networks on a batch of real images.

    A batch of BATCH_SIZE latent vectors is drawn, the generator and the
    discriminator are run in training mode, and each network's optimizer
    applies the gradients of its own loss.
    """
    latent = tf.random.normal([BATCH_SIZE, noise_dim])

    # One tape per network so each loss is differentiated independently.
    with tf.GradientTape() as g_tape, tf.GradientTape() as d_tape:
        fakes = generator(latent, training=True)
        real_output = discriminator(images, training=True)
        fake_output = discriminator(fakes, training=True)

        g_loss = generator_loss(fake_output)
        d_loss = discriminator_loss(real_output, fake_output)

    g_grads = g_tape.gradient(g_loss, generator.trainable_variables)
    d_grads = d_tape.gradient(d_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(g_grads, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(d_grads, discriminator.trainable_variables))

In [None]:
def train(dataset, epochs, mode = False, modeset = None):
    """Train for `epochs` passes over `dataset`, saving preview images each epoch.

    Args:
        dataset: iterable of image batches.
        epochs: number of passes over `dataset`.
        mode: when truthy, every batch goes through `train_step_with_mode`
            instead of `train_step`.
        modeset: forwarded to `train_step_with_mode` when `mode` is truthy.
    """
    def _snapshot(epoch_number):
        # Produce images for the GIF, replacing the previous cell output.
        display.clear_output(wait=True)
        generate_and_save_images(generator,
                                 discriminator,
                                 epoch_number,
                                 seed)

    for epoch_number in range(1, epochs + 1):
        started = time.time()

        for image_batch in dataset:
            if not mode:
                train_step(image_batch)
            else:
                train_step_with_mode(image_batch, modeset)

        _snapshot(epoch_number)

        # Save the model every 15 epochs
        if epoch_number % 15 == 0:
            checkpoint.save(file_prefix = checkpoint_prefix)

        print ('Time for epoch {} is {} sec'.format(epoch_number, time.time()-started))

    # Generate once more after the final epoch
    _snapshot(epochs)

**Generate and save images**


**Added GradCAM heatmaps to generated images**


In [None]:
# Grad-CAM helper functions
from tensorflow.keras.preprocessing.image import img_to_array
import cv2

def apply_cmap(heatmap, cmap):
    '''
    Colorizes a heatmap with the OpenCV colormap named by `cmap`.

    `cmap` is the name of a cv2 attribute (e.g. "COLORMAP_JET"). The heatmap
    is multiplied by 255 and cast to uint8 before the colormap is applied.
    '''
    colormap_id = getattr(cv2, cmap)
    as_bytes = (heatmap * 255).astype("uint8")
    return cv2.applyColorMap(as_bytes, colormap_id)

def resize_heatmap(img, heatmap_lower_dim):
    '''
    Resizes the low-resolution heatmap to the height and width of `img`,
    inverts it (1 - value) and clips the result to [0, 1].
    '''
    target = img_to_array(img)
    height, width = target.shape[0], target.shape[1]

    # cv2.resize takes the destination size as (width, height).
    resized = cv2.resize(np.squeeze(heatmap_lower_dim), (width, height))

    return np.clip(-1*resized +1, 0., 1.)

In [None]:
def generate_and_save_images(model, d_model, epoch, test_input):
    """Plot a 3x3 panel of images and two 3x3 Grad-CAM grids for the discriminator.

    The nine panels are: three images generated from the first three rows of
    `test_input`, one image each from the globals `twos`, `threes` and
    `fours`, two from the global `similar_set`, and one from the global
    `modeset` (the last panel, titled "(mode)" in the heatmap grids).

    Writes 'image_at_epoch_NNNN.png' and 'heatmap_at_epoch_NNNN.png' to the
    working directory; the second heatmap grid is only shown, not saved.

    Args:
        model: generator used to produce the first three panels.
        d_model: discriminator whose output is explained and used for the titles.
        epoch: integer used in the output file names.
        test_input: latent vectors; only the first three rows are used.
    """
    from explanation_models import gradcam

    # `training` is False so all layers run in inference mode (batchnorm).
    generated = model(test_input[:3, :], training=False)
    panel_sources = [
        generated,
        twos[:1, :, :, :],
        threes[:1, :, :, :],
        fours[:1, :, :, :],
        similar_set[:2, :, :, :],
        modeset[:1, :, :, :],
    ]
    predictions = tf.concat(panel_sources, axis = 0)
    print(predictions.shape)
    panel_count = predictions.shape[0]

    # Plain image grid
    plt.figure(figsize=(3, 3))
    for panel in range(panel_count):
        plt.subplot(3, 3, panel + 1)
        plt.imshow(predictions[panel, :, :, 0] * 127.5 + 127.5, cmap='gray')
        plt.axis('off')
    plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
    plt.show()

    # Grad-CAM explainer on the discriminator; the layer name sometimes changes.
    g = gradcam(d_model, 'conv2d_5', (28, 28))

    def _draw_heatmap_grid(class_index, save_path=None):
        # One heatmap per panel for the given `index` argument of get_heatmap;
        # each title is the discriminator's first output value for that panel.
        plt.figure(figsize=(3, 3))
        for panel in range(panel_count):
            plt.subplot(3, 3, panel + 1)
            single = predictions[panel:panel+1, :, :, 0]
            heatmap = g.get_heatmap(single, index = class_index, CounterfactualExp = False)
            heatmap = resize_heatmap(predictions[panel, :, :, 0], heatmap)
            plt.imshow(apply_cmap(heatmap, cmap='COLORMAP_JET'))
            title = "{:.2f}".format(d_model(single).numpy()[0][0])
            if panel == 8:
                title += " (mode)"
            plt.title(title, fontsize=10)
            plt.axis('off')
        # The figure is saved before the layout adjustment.
        if save_path is not None:
            plt.savefig(save_path)
        plt.subplots_adjust(top=1)
        plt.show()

    _draw_heatmap_grid(0, 'heatmap_at_epoch_{:04d}.png'.format(epoch))
    _draw_heatmap_grid(1)

In [None]:
# Discriminator-only training step that presents modeset images as "fake".

@tf.function
def train_step_with_mode(images, modeset, split = 256):
    """Update only the discriminator, using samples from `modeset` as fakes.

    A batch of BATCH_SIZE images is drawn (with replacement) from `modeset`
    and fed to the discriminator as the fake batch, so the discriminator
    learns to see the mode as fake. The generator is neither run nor updated.

    Args:
        images: batch of real training images.
        modeset: array/tensor of mode images, indexed along axis 0.
        split: currently unused; kept so existing callers keep working.
    """
    mode_size = BATCH_SIZE

    # Sample indices with a TensorFlow op. A NumPy RNG call inside a
    # `tf.function` runs only while the function is being traced, so the same
    # "random" indices would be baked into the graph and reused on every call.
    idx = tf.random.uniform([mode_size], maxval=tf.shape(modeset)[0], dtype=tf.int32)
    modes = tf.cast(tf.gather(modeset, idx), dtype = tf.float32)

    with tf.GradientTape() as disc_tape:
        real_output = discriminator(images, training=True)

        # The modeset samples play the role of the fake input.
        fake_output = discriminator(modes, training=True)

        # Loss stays the same as in the regular training step.
        disc_loss = discriminator_loss(real_output, fake_output)

    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))

## Train the model

In [None]:
# Replace the in-memory models with the ones previously saved to disk
# (see the commented-out `save` calls further down).
# NOTE(review): the `checkpoint` object created earlier was built with the
# previous model objects, so it presumably does not track these -- confirm.
discriminator = tf.keras.models.load_model('d_trained_no_mode_1') 
generator = tf.keras.models.load_model('g_trained_no_mode_1')

In [None]:
# First we train normally (no mode) for one epoch.
EPOCHS=1
# `modeset` and `similar_set` are globals read by generate_and_save_images
# for the preview panels; here they come from the label-7 test images.
modeset = sevens[2:3,:,:,:]
similar_set = sevens
train(train_dataset, EPOCHS)

In [None]:
# Saving the model that has not seen the mode
# discriminator.save('d_trained_no_mode_1') 
# generator.save('g_trained_no_mode_1')

In [None]:
# Lowering the learning rate to easier observe behaviour (need to rerun @tf above before using these)
# generator_optimizer = tf.keras.optimizers.Adam(1e-5)
# discriminator_optimizer = tf.keras.optimizers.Adam(1e-5)

In [None]:
# Then we train with the synthetic mode: a single label-7 test image.
modeset = sevens[3:4,:,:,:]
# Other label-7 images are shown next to the mode in the preview panels.
similar_set = sevens
EPOCHS = 1
train(train_dataset, EPOCHS, mode = True, modeset = modeset)

In [None]:
# Then we train with the synthetic mode: a single label-0 test image.
modeset = zeros[3:4,:,:,:]
# Other label-0 images are shown next to the mode in the preview panels.
similar_set = zeros
EPOCHS = 1
train(train_dataset, EPOCHS, mode = True, modeset = modeset)

In [None]:
# Then we train with the synthetic mode: a single label-8 test image.
modeset = eights[3:4,:,:,:]
# Other label-8 images are shown next to the mode in the preview panels.
similar_set = eights
EPOCHS = 1
train(train_dataset, EPOCHS, mode = True, modeset = modeset)

In [None]:
# Then we train with the synthetic mode: a single label-1 test image.
modeset = ones[3:4,:,:,:]
# Other label-1 images are shown next to the mode in the preview panels.
similar_set = ones
EPOCHS = 1
train(train_dataset, EPOCHS, mode = True, modeset = modeset)

Restore the latest checkpoint.

In [None]:
# Restore from the most recent checkpoint found in `checkpoint_dir`.
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))

## Create a GIF


In [None]:
# Open the saved preview image that belongs to a given epoch number
def display_image(epoch_no):
    """Return the PIL image stored as 'image_at_epoch_NNNN.png' for `epoch_no`."""
    image_path = 'image_at_epoch_{:04d}.png'.format(epoch_no)
    return PIL.Image.open(image_path)

In [None]:
# Show the preview image saved for epoch number EPOCHS (its current value).
display_image(EPOCHS)

Use `imageio` to create an animated gif using the images saved during training.

In [None]:
def _write_gif(pattern, output_path):
    """Assemble the image files matching `pattern` into an animated GIF.

    Frames are appended in sorted file-name order and the last frame is
    appended a second time. When nothing matches, no file is written;
    previously this case raised a NameError or silently reused a frame left
    over from an earlier loop.

    Args:
        pattern: glob pattern for the frame files.
        output_path: path of the GIF to write.
    """
    frame_paths = sorted(glob.glob(pattern))
    if not frame_paths:
        print('No files match {!r}; {} was not written'.format(pattern, output_path))
        return

    with imageio.get_writer(output_path, mode='I') as writer:
        for frame_path in frame_paths:
            writer.append_data(imageio.imread(frame_path))
        # Repeat the final frame once more, as before.
        writer.append_data(imageio.imread(frame_paths[-1]))


_write_gif('image*.png', 'dcgan.gif')

# `anim_file` ends up naming the heatmap GIF, as it did before.
anim_file = 'dcgan_heatmap.gif'
_write_gif('heatmap*.png', anim_file)

In [None]:
# Embed the GIF named by `anim_file` in the notebook output
# (requires the tensorflow_docs package, see the pip install cell).
import tensorflow_docs.vis.embed as embed
embed.embed_file(anim_file)

In [None]:
# Shell command: delete every file in the working directory whose name starts
# with "dc" -- this includes dcgan.gif and dcgan_heatmap.gif written above.
!rm dc*