<a href="https://colab.research.google.com/github/EiffL/Tutorials/blob/master/GenerativeModels/GalaxyMorphologyGAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

##### Copyright 2019-2020 Francois Lanusse.

Licensed under the Apache License, Version 2.0 (the "License");

# Generative Modeling of Galaxy Images

Author: [@EiffL](https://github.com/EiffL) (Francois Lanusse)

### Overview

In this tutorial, we learn how to combine TensorFlow, TF-GAN, and Google Colab to train a Generative Adversarial Network (GAN) in the cloud.

We will be using data from the HSC Survey, and more specifically from the Public Data Release 2, which can be found here: https://hsc-release.mtk.nao.ac.jp/doc/


The dataset contains postage stamps of galaxies in 5 HSC bands.


### Learning objectives

In this notebook, we will learn how to:
*   Build a tf.data.Dataset input pipeline.
*   Build a simple convolutional generator/discriminator.
*   Train a model with TF-GAN.
*   Generate new pretty galaxies.

### Instructions for enabling GPU access

By default, notebooks are started without acceleration. To make sure that the runtime is configured for using GPUs, go to `Runtime > Change runtime type`, and select GPU in `Hardware Accelerator`.



### Installs and Imports

In [None]:
!pip install --quiet tensorflow-gan

In [None]:
import os
import re
import time
import json
import tensorflow as tf
import tensorflow_gan as tfgan

### Checking for GPU access

In [None]:
#Checking for GPU access
if tf.test.gpu_device_name() != '/device:GPU:0':
  print('WARNING: GPU device not found.')
else:
  print('SUCCESS: Found GPU: {}'.format(tf.test.gpu_device_name()))

## Downloading the data

In [None]:
# Google Cloud Storage bucket hosting the catalogs and the
# training dataset.
bucket = 'ahw2019' # Bucket setup for this AHW2019 tutorial
print('Using bucket: {}'.format(bucket))

In [None]:
# Retrieve the FITS catalogs and the TFRecord image files
!gsutil -m cp gs://{bucket}/hsc_photoz/cat2/catalog_*.fits .
!gsutil -m cp gs://{bucket}/hsc_photoz/tfrecords2/* .

In [None]:
from astropy.table import Table
cat_train = Table.read('catalog_train.fits')
cat_test = Table.read('catalog_test.fits')

## Building a tf.data.Dataset Input Pipeline

The first step is to read the data from the TFRecord files on disk into a `tf.data.Dataset`. This TensorFlow API is the canonical way to supply data to a model during training: it can read and transform data in parallel, prefetch batches while the model trains, and it supports distributed training.


In [None]:
# The data is stored as TFRecord files, which need to be parsed and turned into a dataset
dset = tf.data.TFRecordDataset(['training-%05d-of-00010'%i for i in range(10)])

In [None]:
# To extract one example from the TFRecord, we can use the following syntax:
for i in dset.take(1):
  print(i)

Each record is currently a serialized `tf.train.Example` protocol buffer, stored as a byte string. We need to decode it.

In [None]:
img_len = 64
num_bands = 5

# This function defines the operations to apply to a serialized example to
# turn it back into a dictionary object
def parse_example(example):

  # First, let's define what fields we are expecting
  data_fields = {
      "image/encoded": tf.io.FixedLenFeature((), tf.string),
      "image/format": tf.io.FixedLenFeature((), tf.string),
      "id": tf.io.FixedLenFeature((), tf.int64)
  }
  for k in cat_train.colnames[5:]:
    data_fields['attrs/'+k] = tf.io.FixedLenFeature([], tf.float32)

  parsed_example = tf.io.parse_single_example(example, data_fields)

  # Decode the image from string format
  cutout = tf.io.decode_raw(parsed_example['image/encoded'], out_type=tf.float32) 
  cutout = tf.reshape(cutout, [img_len, img_len, num_bands])

  # Outputs results as a dictionary
  output_dict = {"cutout": cutout}
  for k in cat_train.colnames[5:]:
    output_dict[k] = parsed_example['attrs/'+k]

  return output_dict
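`tf.io.decode_raw` reinterprets the bytes of `image/encoded` as a flat float32 vector, which `tf.reshape` then turns back into a 64x64x5 cutout. Assuming the images were written as raw little-endian float32 buffers (for instance with NumPy's `tobytes()`), which is what `decode_raw(..., out_type=tf.float32)` expects by default, the same round trip can be sketched in NumPy on a hypothetical random cutout:

```python
import numpy as np

img_len, num_bands = 64, 5

# Hypothetical cutout standing in for one HSC postage stamp (float32, little-endian)
cutout = np.random.default_rng(0).normal(size=(img_len, img_len, num_bands)).astype('<f4')

# Serialization side (assumed): flatten to raw float32 bytes in C order
encoded = cutout.tobytes()

# Decoding side: NumPy analogue of tf.io.decode_raw followed by tf.reshape
decoded = np.frombuffer(encoded, dtype='<f4').reshape(img_len, img_len, num_bands)

print(len(encoded))                   # → 81920  (64 * 64 * 5 values * 4 bytes)
print(np.array_equal(decoded, cutout))  # → True
```

If `img_len` or `num_bands` did not match how the data was written, the reshape would fail, which is a quick sanity check on the parsing code.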

With this decoding function defined, we can apply it to the dataset by using the dataset.map() method:

In [None]:
train_dset = dset.map(parse_example)

Let's have a look at the content of this new dataset:

In [None]:
for i in train_dset.take(1):
  print(i)

Now our dataset is decoded into numbers and arrays.

### Dataset preprocessing

An important step of any input pipeline is to make sure the data is reasonably well behaved before feeding it to the neural network. Here are some common strategies:


*   Apply log() to values with a large dynamic range
*   Subtract the mean and rescale to unit standard deviation
*   etc.
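As a minimal NumPy sketch of the first two strategies, here applied to hypothetical positive, heavy-tailed data (log-normal draws standing in for fluxes):

```python
import numpy as np

rng = np.random.default_rng(42)
# Hypothetical data spanning several orders of magnitude
x = rng.lognormal(mean=0.0, sigma=2.0, size=10_000)

# 1. Compress the dynamic range with a log (only valid for strictly positive values)
x_log = np.log(x)

# 2. Subtract the mean and rescale to unit standard deviation
x_std = (x_log - x_log.mean()) / x_log.std()

print(x.max() / x.min())          # huge ratio before compression
print(x_std.mean(), x_std.std())  # ~0 and ~1 after standardization
```

Sky-subtracted image pixels can be negative, so a plain log cannot be applied to them directly; this is one reason the pipeline below relies on arcsinh instead.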


We begin by looking at our data:


In [None]:
# What's in our dataset:
train_dset

OK, we see that each element of this dataset is a dictionary: the `cutout` field holds HSC cutouts in 5 bands (g, r, i, z, y), which are the images our GAN will learn to generate. We also see a `specz_redshift` entry, the spectroscopic redshift of each object, which we will use to label the plots below. Let's have a look at these.

In [None]:
from astropy.visualization import make_lupton_rgb
%pylab inline 

# The data has 5 bands (g, r, i, z, y); for visualisation we use only the
# first 3 bands, mapped as (i, r, g) -> RGB, and luptonize them
def luptonize(img):
  return make_lupton_rgb(img[:,:,2], img[:,:,1], img[:,:,0],
                         Q=15, stretch=0.5, minimum=0)

plt.figure(figsize=(10,10))
for i, entry in enumerate(train_dset.take(25)):
  plt.subplot(5,5,i+1)
  plt.imshow(luptonize(entry['cutout']))
  plt.title('z = %0.02f'%entry['specz_redshift'])
  plt.axis('off')

How nice is that :-) We can extract postage stamps and the corresponding spectroscopic redshift for these objects. 

In [None]:
# Let's collect a few examples to check their distributions
cutouts=[]
specz = []
for (batch, entry) in enumerate(train_dset.take(1000)):
  specz.append(entry['specz_redshift'])
  cutouts.append(entry['cutout'])

cutouts = np.stack(cutouts)
specz = np.stack(specz)

In [None]:
# Let's robustly estimate the noise standard deviation in each band; we will
# use it to scale the range compression below
from astropy.stats import mad_std
scaling = []

for i,b in enumerate(['g', 'r', 'i', 'z', 'y']):
  plt.hist(cutouts[...,i].flatten(),100, label=b,alpha=0.5,range=[-1,1]);
  sigma = mad_std(cutouts[...,i].flatten())
  scaling.append(sigma)
  plt.axvline(sigma, color='C%d'%i,alpha=0.5)
  plt.axvline(-sigma, color='C%d'%i,alpha=0.5)
plt.legend()
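`mad_std` estimates the standard deviation from the median absolute deviation (MAD), multiplied by ≈1.4826 so that it matches σ for Gaussian data. Because it is built from medians, the bright galaxy pixels in each cutout barely affect it, whereas a plain standard deviation is inflated by them. A NumPy sketch of the same estimator on hypothetical data (Gaussian sky noise plus a small fraction of bright "source" pixels):

```python
import numpy as np

def mad_std_np(x):
  """Robust sigma: 1.4826 * median(|x - median(x)|), the estimator behind astropy's mad_std."""
  x = np.asarray(x, dtype=float)
  return 1.482602218505602 * np.median(np.abs(x - np.median(x)))

rng = np.random.default_rng(0)
# Hypothetical "image": Gaussian sky noise with sigma = 0.05 ...
pixels = rng.normal(0.0, 0.05, size=100_000)
# ... plus 2% of bright source pixels
pixels[:2000] += rng.exponential(5.0, size=2000)

print(np.std(pixels))      # strongly inflated by the bright pixels
print(mad_std_np(pixels))  # stays close to the noise level of 0.05
```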

In [None]:
# a common approach for range compression is to apply arcsinh to suppress the
# high amplitude values
for i,b in enumerate(['g', 'r', 'i', 'z', 'y']):
  plt.hist(np.arcsinh(cutouts[...,i].flatten()/scaling[i]/3.),100,
           label=b, alpha=0.5);
plt.legend()
plt.yscale('log')
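The transform used here is y = arcsinh(x / (3σ_b)), with σ_b the per-band noise estimate from the previous cell. It is approximately linear for |x| ≪ 3σ_b, so the noise is left nearly untouched, and approximately sign(x) log(2|x| / 3σ_b) for bright pixels. Unlike a log, it is defined for negative pixels and is exactly invertible, x = 3σ_b sinh(y), which is what is needed to map model outputs back to the original pixel units. A NumPy sketch with a hypothetical noise level:

```python
import numpy as np

sigma = 0.05  # hypothetical per-band noise level (plays the role of scaling[i])

def compress(x, sigma):
  """Range compression as in the input pipeline: arcsinh(x / (3 sigma))."""
  return np.arcsinh(x / sigma / 3.)

def decompress(y, sigma):
  """Inverse transform, back to the original pixel units."""
  return 3. * sigma * np.sinh(y)

x = np.array([-0.1, 0.0, 0.01, 1.0, 100.0])
y = compress(x, sigma)
print(y)                                     # ~x / (3 sigma) near zero, logarithmic for bright pixels
print(np.allclose(decompress(y, sigma), x))  # → True
```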

### Create the input pipeline

Now that we know how to preprocess the data, we can build the input pipeline. Below is a function that creates a Dataset object from the TFRecord files, decodes the examples, shuffles them (in training mode), groups them into batches, and applies the preprocessing, which also draws the random latent vectors for the generator. The function returns a dataset of `(z, img)` pairs that the TF-GAN estimator can ingest directly.

More information about the tf.data API can be found here: 

https://www.tensorflow.org/guide/datasets


In [None]:
def input_fn(mode, batch_size):
  """
  mode: tf.estimator.ModeKeys.TRAIN or tf.estimator.ModeKeys.EVAL
  """

  # Using a mapping function to apply preprocessing to our data
  def preprocessing(example):
    # Cast the per-band noise levels to float32 so they match the image dtype
    img = tf.math.asinh(example['cutout'] / tf.constant(scaling, dtype=tf.float32) / 3.)
    # Return (z, img): TF-GAN feeds z to the generator and uses img as the real data
    # z is a batch of random latent codes drawn from N(0, 1)
    return z, img

  if mode == tf.estimator.ModeKeys.TRAIN:
    dataset = tf.data.Dataset.list_files('training-*')
    dataset = dataset.interleave(tf.data.TFRecordDataset, 
                                 cycle_length=10,
                                 num_parallel_calls=tf.data.experimental.AUTOTUNE)
    dataset = dataset.repeat()
  else:
    dataset = tf.data.TFRecordDataset('testing-00000-of-00001')
    
  dataset = dataset.map(parse_example)
  if mode == tf.estimator.ModeKeys.TRAIN:
    dataset = dataset.shuffle(10000)
  dataset = dataset.batch(batch_size, drop_remainder=True)
  dataset = dataset.map(preprocessing) # Apply data preprocessing
  dataset = dataset.prefetch(-1) # fetch next batches while training current one (-1 for autotune)
  return dataset
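Note that `shuffle(10000)` does not shuffle the whole dataset at once: tf.data keeps a buffer of 10000 elements, emits one chosen uniformly at random, and refills its slot with the next input element. The larger the buffer relative to the dataset, the closer this gets to a full shuffle (interleaving the 10 files also helps mix the data). A pure-Python sketch of that buffering behaviour, as an illustration rather than tf.data's actual implementation:

```python
import random

def buffered_shuffle(iterable, buffer_size, seed=None):
  """Illustrative sketch of the buffering semantics of tf.data.Dataset.shuffle."""
  rng = random.Random(seed)
  buffer = []
  for item in iterable:
    buffer.append(item)
    if len(buffer) == buffer_size:
      # Buffer is full: emit a uniformly chosen element; the next input refills the slot
      i = rng.randrange(len(buffer))
      buffer[i], buffer[-1] = buffer[-1], buffer[i]
      yield buffer.pop()
  # Input exhausted: drain what is left in random order
  rng.shuffle(buffer)
  yield from buffer

data = list(range(20))
print(list(buffered_shuffle(data, buffer_size=1)))  # a buffer of 1 keeps the original order
print(list(buffered_shuffle(data, buffer_size=5, seed=0)))
```

With a buffer of 5, the first element emitted can only be one of the first 5 inputs, which shows why a small buffer gives a poorly mixed stream.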

## Building a GAN architecture

### Defining a generator model


In [None]:
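# TF1-style API: GANEstimator is Estimator-based, and tf.train.AdamOptimizer only exists in tf.compat.v1
import tensorflow.compat.v1 as tf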

def generator_fn(noise):
  """ Generator function, taking random noise as input and returning an image
  """
  # TODO: Create a generator
  
  return net
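If you want a reference point for the TODO above, here is one hypothetical DCGAN-style generator, not the author's solution. It uses the functional `tf.layers` API from `tensorflow.compat.v1` (as the cell above does), whose layers create their variables through variable scopes, which is how TF-GAN's Estimator wraps the generator and discriminator. The layer sizes are assumptions chosen to map a 100-dimensional code to a 64x64x5 image:

```python
import tensorflow.compat.v1 as tf

def generator_fn(noise):
  """Hypothetical DCGAN-style generator: [batch, 100] -> [batch, 64, 64, 5]."""
  # Project the latent code onto a 4x4x256 feature map
  net = tf.layers.dense(noise, 4 * 4 * 256, activation=tf.nn.leaky_relu)
  net = tf.reshape(net, [-1, 4, 4, 256])
  # Four stride-2 transposed convolutions: 4 -> 8 -> 16 -> 32 -> 64 pixels
  for filters in [128, 64, 32, 16]:
    net = tf.layers.conv2d_transpose(net, filters, kernel_size=4, strides=2,
                                     padding='same', activation=tf.nn.leaky_relu)
  # Linear output with one channel per HSC band: no tanh, since the
  # asinh-compressed images are not bounded to [-1, 1]
  net = tf.layers.conv2d(net, 5, kernel_size=3, padding='same')
  return net
```

Leaving the last layer linear is a design choice tied to the preprocessing: a bounded activation would prevent the generator from reproducing the bright, high-amplitude pixels of real galaxies.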

In [None]:
code = tf.random.normal([1, 100])
im = generator_fn(code)
print(im.shape)
imshow(im[0,:,:,0], cmap='gray'); colorbar();

### Implementing a convolutional discriminator


In [None]:
def discriminator_fn(x, unused_condition):

  # TODO: Create a discriminator

  return net
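A hypothetical convolutional critic for the TODO above, again a sketch rather than the author's solution. TF-GAN applies the discriminator to both real and generated images, and with functional `tf.layers` calls the second application can reuse the variables of the first through the enclosing variable scope. The output is an unbounded score (no sigmoid), as Wasserstein losses expect, and batch normalization is left out, as the WGAN-GP paper recommends for the critic, because the gradient penalty is computed per example:

```python
import tensorflow.compat.v1 as tf

def discriminator_fn(x, unused_condition):
  """Hypothetical convolutional critic: [batch, 64, 64, 5] -> [batch, 1]."""
  net = x
  # Four stride-2 convolutions: 64 -> 32 -> 16 -> 8 -> 4 pixels
  for filters in [32, 64, 128, 256]:
    net = tf.layers.conv2d(net, filters, kernel_size=4, strides=2,
                           padding='same', activation=tf.nn.leaky_relu)
  net = tf.layers.flatten(net)
  # Unbounded scalar score per image (no sigmoid) for the Wasserstein loss
  net = tf.layers.dense(net, 1)
  return net
```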

In [None]:
# Let's see what the discriminator thinks of our fake image:
discriminator_fn(im, None)

## Building the GAN with TF GAN

In [None]:
# Build an estimator
gan_estimator = tfgan.estimator.GANEstimator(
    generator_fn=generator_fn,         # function implementing the generator
    discriminator_fn=discriminator_fn, # function implementing the discriminator
    # Wasserstein losses; with the gradient penalty set in params this is a WGAN-GP
    generator_loss_fn=tfgan.losses.wasserstein_generator_loss,
    discriminator_loss_fn=tfgan.losses.wasserstein_discriminator_loss,
    # Optimizers for both models
    generator_optimizer=tf.train.AdamOptimizer(0.001, 0.5),
    discriminator_optimizer=tf.train.AdamOptimizer(0.0002, 0.5),
    # Additional TF-GAN parameters
    params={'gradient_penalty_weight':1.0},
    # Standard Estimator configuration
    config=tf.estimator.RunConfig(model_dir="models/hsc") # Saves checkpoints and logs in model_dir
    )
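To make the loss configuration above concrete, here is a NumPy sketch of the WGAN-GP objective. The sign conventions follow TF-GAN's `wasserstein_generator_loss` (minus the mean critic score on fakes) and `wasserstein_discriminator_loss` (mean score on fakes minus mean score on reals); the `gradient_penalty_weight` entry adds a penalty (‖∇D(x̂)‖₂ − 1)² evaluated at random interpolates x̂ between real and fake samples. To keep the gradient analytic, the critic here is a hypothetical linear function D(x) = x·w + b, whose gradient is w everywhere, and the data are random stand-ins:

```python
import numpy as np

rng = np.random.default_rng(1)
dim, n = 8, 64
real = rng.normal(1.0, 1.0, size=(n, dim))  # hypothetical "real" samples
fake = rng.normal(0.0, 1.0, size=(n, dim))  # hypothetical generator output

# Hypothetical linear critic D(x) = x . w + b
w = rng.normal(size=dim)
b = 0.1
def critic(x):
  return x @ w + b

# Wasserstein losses
generator_loss = -critic(fake).mean()
discriminator_loss = critic(fake).mean() - critic(real).mean()

# Gradient penalty on random interpolates between real and fake samples
eps = rng.uniform(size=(n, 1))
interpolates = real + eps * (fake - real)
# For a linear critic, grad_x D(x) = w at every interpolate
grads = np.tile(w, (len(interpolates), 1))
gp = np.mean((np.linalg.norm(grads, axis=1) - 1.0) ** 2)

gradient_penalty_weight = 1.0  # same value as in params above
total_discriminator_loss = discriminator_loss + gradient_penalty_weight * gp
print(generator_loss, total_discriminator_loss)
```

With a real convolutional critic the gradients come from automatic differentiation instead, and the penalty pushes the critic towards being 1-Lipschitz, which is what the Wasserstein formulation requires.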

In [None]:
def train_input_fn():
  return input_fn(tf.estimator.ModeKeys.TRAIN, 64)

In [None]:
gan_estimator.train(train_input_fn, 
                    max_steps=5000) # Let's train for 5000 steps

In [None]:
# Create an input pipeline for inference
def predict_input_fn(batch_size=36):
  def pre_process(example):
    """ draws a random normal.
    """
    z = tf.random.normal([1, 100])
    return z

  # We build an input pipeline using this preprocessing function
  dset = tf.data.Dataset.from_tensor_slices(tf.range(0, batch_size))
  dset = dset.map(pre_process)         # Apply the pre-processing function
  return dset 

In [None]:
# Runs the input pipeline through the trained estimator
prediction_iterable = gan_estimator.predict(predict_input_fn)

In [None]:
predictions = np.array([next(prediction_iterable) for _ in range(36)])

In [None]:
# And let's take a look:
tiled_image = tfgan.eval.python_image_grid(predictions, grid_shape=(6, 6))

figure(figsize=(10,10))
# Show bands (i, r, g) as RGB, normalised by their maximum (negative values clip to 0)
rgb = tiled_image[:, :, 2::-1]
imshow(np.clip(rgb / rgb.max(), 0, 1))
axis('off');

- Does this look OK? If so, fantastic! If not, what could be wrong? (Hint: compare the range of values in real and fake images.)


In [None]:
# Load the TensorBoard notebook extension
%load_ext tensorboard

In [None]:
# Start TensorBoard in notebook
%tensorboard --logdir models