<a href="https://colab.research.google.com/github/EiffL/Tutorials/blob/master/GenerativeModels/GalaxyMorphologyVAE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

##### Copyright 2019-2020 Francois Lanusse.

Licensed under the Apache License, Version 2.0 (the "License");

# Generative Modeling of Galaxy Images

Author: [@EiffL](https://github.com/EiffL) (Francois Lanusse)

### Overview

In this tutorial, we learn how to combine Keras, TensorFlow Probability, and Google Colab to train a generative model of galaxy images in the cloud.

We will be using data from the HSC Survey, and more specifically from the Public Data Release 2, which can be found here: https://hsc-release.mtk.nao.ac.jp/doc/


The dataset contains postage stamps of galaxies in 5 HSC bands, along with corresponding spectroscopic redshifts.

Our goal will be to learn a generative model of these galaxy images; the redshifts are only used to annotate the data while we explore it.

### Learning objectives

In this notebook, we will learn how to:
*   Build a tf.data.Dataset input pipeline.
*   Build a simple convolutional neural network with Keras.
*   Train a model on GPUs in the cloud.
*   (Stretch Goal) Use TensorFlow Probability to make a probabilistic model.

Note: this tutorial was originally presented at [Astro Hack Week 2019](https://github.com/AstroHackWeek/AstroHackWeek2019/tree/master/day4_bayesiandeep).

### Instructions for enabling GPU access

By default, notebooks are started without acceleration. To make sure that the runtime is configured for using GPUs, go to `Runtime > Change runtime type`, and select GPU in `Hardware Accelerator`.



### Installs and Imports

In [None]:
import os
import re
import time
import json
import tensorflow as tf

### Checking for GPU access

In [None]:
#Checking for GPU access
if tf.test.gpu_device_name() != '/device:GPU:0':
  print('WARNING: GPU device not found.')
else:
  print('SUCCESS: Found GPU: {}'.format(tf.test.gpu_device_name()))

## Downloading the data

In [None]:
# Google Cloud Storage bucket for Estimator logs and storing
# the training dataset.
bucket = 'ahw2019' # Bucket setup for this AHW2019 tutorial
print('Using bucket: {}'.format(bucket))

In [None]:
# Retrieve the catalogs
!gsutil -m cp gs://{bucket}/hsc_photoz/cat2/catalog_*.fits .
!gsutil -m cp gs://{bucket}/hsc_photoz/tfrecords2/* .

In [None]:
from astropy.table import Table
cat_train = Table.read('catalog_train.fits')
cat_test = Table.read('catalog_test.fits')

## Building a tf.data.Dataset Input Pipeline

The first step is to read the data from the tfrecords format on disk into a tf.data.Dataset. This TensorFlow API is the canonical way to supply data to a model during training. It is fast and optimized, and supports distributed training!


In [None]:
# The data is saved as TFRecord files, which need to be parsed and turned into a dataset
dset = tf.data.TFRecordDataset(['training-%05d-of-00010'%i for i in range(10)])

In [None]:
# To extract one example from the TFRecord, we can use the following syntax:
for i in dset.take(1):
  print(i)

The data is currently stored in a serialized format, as strings. We need to decode it.

In [None]:
img_len = 64
num_bands = 5

# This function defines the operations to apply to a serialized example to
# turn it back into a dictionary object
def parse_example(example):

  # First, let's define what fields we are expecting
  data_fields = {
      "image/encoded": tf.io.FixedLenFeature((), tf.string),
      "image/format": tf.io.FixedLenFeature((), tf.string),
      "id": tf.io.FixedLenFeature((), tf.int64)
  }
  for k in cat_train.colnames[5:]:
    data_fields['attrs/'+k] = tf.io.FixedLenFeature([], tf.float32)

  parsed_example = tf.io.parse_single_example(example, data_fields)

  # Decode the image from string format
  cutout = tf.io.decode_raw(parsed_example['image/encoded'], out_type=tf.float32) 
  cutout = tf.reshape(cutout, [img_len, img_len, num_bands])

  # Outputs results as a dictionary
  output_dict = {"cutout": cutout}
  for k in cat_train.colnames[5:]:
    output_dict[k] = parsed_example['attrs/'+k]

  return output_dict

With this decoding function defined, we can apply it to the dataset by using the dataset.map() method:

In [None]:
train_dset = dset.map(parse_example)

Let's have a look at the content of this new dataset:

In [None]:
for i in train_dset.take(1):
  print(i)

Now our dataset is decoded into numbers and arrays.
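
For intuition about what the decoding step does, here is a minimal NumPy sketch of the same round trip. It assumes the cutout was serialized as raw float32 bytes, which is what the `tf.io.decode_raw(..., out_type=tf.float32)` call in `parse_example` implies; the array `fake_cutout` is made up for illustration.

```python
import numpy as np

img_len, num_bands = 64, 5

# Hypothetical stand-in for one galaxy cutout (random values, not real data)
rng = np.random.default_rng(0)
fake_cutout = rng.normal(size=(img_len, img_len, num_bands)).astype(np.float32)

# Serialization flattens the array into a byte string...
serialized = fake_cutout.tobytes()

# ...and decoding reinterprets those bytes as float32 and restores the shape
decoded = np.frombuffer(serialized, dtype=np.float32)
decoded = decoded.reshape(img_len, img_len, num_bands)

print(len(serialized), decoded.shape)
```

The reshape only works because the image size and number of bands are known in advance, which is why `img_len` and `num_bands` are hard-coded next to the parsing function.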

### Dataset preprocessing

An important step of any input pipeline is to make sure the data is reasonably well behaved before feeding it to the neural network. Here are some common strategies:


*   Apply log() to values with large dynamic range
*   Subtract the mean, and rescale to unit standard deviation
*   etc...
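
As a toy illustration of the first two strategies, the sketch below applies them to synthetic log-normal data (made up here, not the HSC pixels).

```python
import numpy as np

rng = np.random.default_rng(0)
# Synthetic positive data spanning several orders of magnitude
x = rng.lognormal(mean=0.0, sigma=2.0, size=10_000)

# 1) Compress the dynamic range with a logarithm
x_log = np.log(x)

# 2) Remove the mean and rescale to unit standard deviation
x_std = (x_log - x_log.mean()) / x_log.std()

print('raw range: %.3g to %.3g' % (x.min(), x.max()))
print('standardized mean/std: %.2f / %.2f' % (x_std.mean(), x_std.std()))
```

Note that a plain `log()` only works for strictly positive values. Background-subtracted images contain negative noise fluctuations, which is one reason an arcsinh transform can be a more convenient choice for pixel data.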


So, we begin by looking at our data.


In [None]:
# What's in our dataset:
train_dset

OK, we see that each element of this dataset is a dictionary. The `cutout` field holds the HSC cutouts in 5 bands (g, r, i, z, y), which will be the inputs to our model. We also see a `specz_redshift` entry, the spectroscopic redshift of each galaxy. Let's have a look at these.

In [None]:
from astropy.visualization import make_lupton_rgb
%pylab inline 

# The data is in 5 bands (g, r, i, z, y), but for visualisation we use only
# the first 3 bands and luptonize them
def luptonize(img):
  return make_lupton_rgb(img[:,:,2], img[:,:,1], img[:,:,0],
                         Q=15, stretch=0.5, minimum=0)

plt.figure(figsize=(10,10))
for i, entry in enumerate(train_dset.take(25)):
  plt.subplot(5,5,i+1)
  plt.imshow(luptonize(entry['cutout']))
  plt.title('z = %0.02f'%entry['specz_redshift'])
  plt.axis('off')

How nice is that :-) We can extract postage stamps and the corresponding spectroscopic redshift for these objects. 

Before doing anything else, we should take a closer look at the data and check that it's well behaved.

In [None]:
# Let's collect a few examples to check their distributions
cutouts=[]
specz = []
for (batch, entry) in enumerate(train_dset.take(1000)):
  specz.append(entry['specz_redshift'])
  cutouts.append(entry['cutout'])

cutouts = np.stack(cutouts)
specz = np.stack(specz)

In [None]:
for i,b in enumerate(['g', 'r', 'i', 'z', 'y']):
  plt.hist(cutouts[...,i].flatten(),100, label=b);
plt.legend()

# Problem ?

Do you see a problem in this histogram?

Let's have a look at a few images:

In [None]:
plt.figure(figsize=(15,3))
for i,b in enumerate(['g', 'r', 'i', 'z', 'y']):
  plt.subplot(1,5,i+1)
  plt.imshow(cutouts[0,:,:,i],vmin=-1,vmax=50)
  plt.title(b)
  plt.axis('off')

plt.figure(figsize=(15,3))
for i,b in enumerate(['g', 'r', 'i', 'z', 'y']):
  plt.subplot(1,5,i+1)
  plt.imshow(cutouts[1,:,:,i],vmin=-1,vmax=50)
  plt.title(b)
  plt.axis('off')

In the plots above, each row is a different galaxy, and each column is a different band.

In [None]:
# Let's look at it in log scale
for i,b in enumerate(['g', 'r', 'i', 'z', 'y']):
  plt.hist(cutouts[...,i].flatten(),100, label=b,alpha=0.5);
plt.legend()
plt.yscale('log')

This is terrible: the long tail of the pixel-intensity distribution is going to kill our neural networks. We need to standardize the data.

In [None]:
# Let's evaluate the noise standard deviation in each band, and apply range 
# compression accordingly
from astropy.stats import mad_std
scaling = []

for i,b in enumerate(['g', 'r', 'i', 'z', 'y']):
  plt.hist(cutouts[...,i].flatten(),100, label=b,alpha=0.5,range=[-1,1]);
  sigma = mad_std(cutouts[...,i].flatten())
  scaling.append(sigma)
  plt.axvline(sigma, color='C%d'%i,alpha=0.5)
  plt.axvline(-sigma, color='C%d'%i,alpha=0.5)
plt.legend()
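
`mad_std` estimates the standard deviation from the median absolute deviation, which makes it insensitive to the bright galaxy pixels in the tail of the distribution. The NumPy sketch below reimplements the idea on synthetic data (Gaussian noise plus a few bright "source" pixels, all made up) and compares it with the ordinary standard deviation. The function name `robust_std` is ours, not part of astropy.

```python
import numpy as np

def robust_std(x):
    """Standard deviation estimated from the median absolute deviation.

    The factor 1.4826 makes the estimate consistent with the true sigma
    for Gaussian-distributed data.
    """
    x = np.asarray(x)
    return 1.4826 * np.median(np.abs(x - np.median(x)))

rng = np.random.default_rng(0)
pixels = rng.normal(loc=0.0, scale=0.5, size=100_000)  # background noise
pixels[:1000] = 50.0                                   # 1% bright outliers

print('np.std     : %.3f' % pixels.std())
print('robust_std : %.3f' % robust_std(pixels))
```

The ordinary standard deviation is dominated by the handful of bright pixels, while the robust estimate stays close to the noise level of 0.5 used to generate the data.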

In [None]:
# Let's have a look at this distribution if we rescale each band by the standard
# deviation
for i,b in enumerate(['g', 'r', 'i', 'z', 'y']):
  plt.hist(cutouts[...,i].flatten()/scaling[i],100, label=b,alpha=0.5,
           range=[-10,10]);
plt.legend()

Sweet! Now there is still an unsightly tail towards very large values. We are going to apply range compression to get rid of it.

In [None]:
# a common approach for range compression is to apply arcsinh to suppress the
# high amplitude values
for i,b in enumerate(['g', 'r', 'i', 'z', 'y']):
  plt.hist(np.arcsinh(cutouts[...,i].flatten()/scaling[i]/3),100,
           label=b, alpha=0.5);
plt.legend()
plt.yscale('log')

![Perfection](https://i.kym-cdn.com/entries/icons/original/000/022/900/704.jpg)

In [None]:
# we can have a look at individual postage stamps with or without this scaling
subplot(121)
imshow(cutouts[0,:,:,1]/scaling[1])
title('Before')
subplot(122)
imshow(np.arcsinh(cutouts[0,:,:,1]/scaling[1]/3))
title('After');
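
To make the effect of the arcsinh compression more concrete, here is a small NumPy sketch. The function names, the `softening` argument (the extra factor of 3 used above), and the noise level `sigma = 0.1` are illustrative assumptions, not values measured from the data.

```python
import numpy as np

def range_compress(img, sigma, softening=3.0):
    """arcsinh range compression of an image with noise level sigma."""
    return np.arcsinh(img / sigma / softening)

def range_decompress(y, sigma, softening=3.0):
    """Inverse transform, back to the original pixel units."""
    return np.sinh(y) * sigma * softening

sigma = 0.1  # hypothetical noise level, for illustration only
faint = np.array([-0.1, 0.0, 0.1])        # pixels at the noise level
bright = np.array([10.0, 100.0, 1000.0])  # pixels far in the tail

print(range_compress(faint, sigma))   # nearly linear around zero
print(range_compress(bright, sigma))  # grows like log(2 * x / (3 * sigma))
```

Because the transform is close to linear near zero and logarithmic for large values, the noise keeps its Gaussian shape while the bright tail is pulled in. It is also invertible, so generated images can be mapped back to physical units with `range_decompress`.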

In [None]:
# Let's just check the specz values
plt.hist(specz,100);
# Should be ok

Now that we have defined a scaling for the data that should be appropriate, we can build a scaling function and apply it to the dataset:

In [None]:
# Using a mapping function to apply preprocessing to our data
def preprocessing(example):
  def range_compression(img):
    return tf.math.asinh(img / tf.constant(scaling) / 3. )
  # Our preprocessing function only returns the postage stamps, and the specz
  return range_compression(example['cutout']), example['specz_redshift']

In [None]:
dset = train_dset.map(preprocessing)

In [None]:
# Let's draw some examples from this new dataset
cutouts=[]
specz = []
for (batch, entry) in enumerate(dset.take(1000)):
  specz.append(entry[1])
  cutouts.append(entry[0])

cutouts = np.stack(cutouts)
specz = np.stack(specz)

In [None]:
for i,b in enumerate(['g', 'r', 'i', 'z', 'y']):
  plt.hist(cutouts[...,i].flatten(),100, label=b,alpha=0.5, range=[-1,6]);
plt.legend()

In [None]:
plt.figure(figsize=(15,3))
for i,b in enumerate(['g', 'r', 'i', 'z', 'y']):
  plt.subplot(1,5,i+1)
  plt.imshow(cutouts[0,:,:,i],vmin=-1,vmax=6)
  plt.title(b)
  plt.axis('off')

plt.figure(figsize=(15,3))
for i,b in enumerate(['g', 'r', 'i', 'z', 'y']):
  plt.subplot(1,5,i+1)
  plt.imshow(cutouts[1,:,:,i],vmin=-1,vmax=6)
  plt.title(b)
  plt.axis('off')

In [None]:
mad_std(cutouts[0,:,:,0])

Sweeeeeet

### Create the input pipeline

Now that we know how to preprocess the data, we can build the input pipeline. Below is a function that creates a Dataset object from the tfrecords files, decodes them, applies preprocessing, shuffles the dataset, and creates batches of data. Finally, the function returns the dataset, which Keras models can directly ingest.

More information about the `tf.data.Dataset` API can be found here:

https://www.tensorflow.org/guide/datasets


In [None]:
# Using a mapping function to apply preprocessing to our data
def preprocessing(example):
  img = tf.math.asinh(example['cutout'] / tf.constant(scaling) / 3. )
  # We return the image as our input and output for a generative model
  return img, img

def input_fn(mode, batch_size):
  """
  mode: tf.estimator.ModeKeys.TRAIN or tf.estimator.ModeKeys.EVAL
  """
  if mode == tf.estimator.ModeKeys.TRAIN:
    dataset = tf.data.Dataset.list_files('training-*')
    dataset = dataset.interleave(tf.data.TFRecordDataset, 
                                 cycle_length=10,
                                 num_parallel_calls=tf.data.experimental.AUTOTUNE)
    dataset = dataset.repeat()
  else:
    dataset = tf.data.TFRecordDataset('testing-00000-of-00001')
  
  # At this point, dataset contains raw tfrecords

  # TODO: add a dataset.map() to parse the example
  dataset = #....
  if mode == tf.estimator.ModeKeys.TRAIN:
    # TODO: shuffle the dataset with a buffer size of 10000
    dataset = #....
  # TODO: batch the dataset to make batches of size `batch_size`.
  # Use the `drop_remainder=True` option of tf.data.Dataset.batch() to avoid problems
  # when the size of the dataset is not a multiple of batch_size
  dataset = #.... 
  # TODO: Apply the `preprocessing` function to the dataset
  dataset = #.... 
  dataset = dataset.prefetch(-1) # fetch next batches while training current one (-1 for autotune)
  return dataset
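
If you get stuck on the TODOs, the sketch below shows one possible ordering of the operations (parse, shuffle, batch, preprocess, prefetch). To stay runnable without the TFRecord files, it uses random in-memory arrays and hypothetical stand-ins: `toy_parse` plays the role of `parse_example`, `scaling` is set to made-up values, and the shuffle buffer is reduced to 100 to keep memory use small (the exercise asks for 10000).

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-in values; in the notebook `scaling` is measured with mad_std
scaling = [0.1, 0.1, 0.1, 0.1, 0.1]

def toy_parse(cutout):
  # Stand-in for parse_example: returns the same kind of dictionary
  return {'cutout': cutout}

def preprocessing(example):
  img = tf.math.asinh(example['cutout'] / tf.constant(scaling) / 3.)
  return img, img

def toy_input_fn(training, batch_size):
  raw = np.random.randn(100, 64, 64, 5).astype('float32')  # fake cutouts
  dataset = tf.data.Dataset.from_tensor_slices(raw)
  if training:
    dataset = dataset.repeat()
  dataset = dataset.map(toy_parse)               # parse each example
  if training:
    dataset = dataset.shuffle(buffer_size=100)   # shuffle individual examples
  dataset = dataset.batch(batch_size, drop_remainder=True)
  dataset = dataset.map(preprocessing)           # preprocess whole batches
  return dataset.prefetch(tf.data.experimental.AUTOTUNE)

inputs, targets = next(iter(toy_input_fn(True, 25)))
print(inputs.shape, targets.shape)
```

Parsing has to happen before batching because `parse_example` handles one serialized example at a time, whereas `preprocessing` broadcasts over the batch dimension and can therefore run after batching.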

To make sure that your input function works, try sampling a batch from it:

In [None]:
dset = input_fn(tf.estimator.ModeKeys.TRAIN, 25)

In [None]:
plt.figure(figsize=(10,10))
for batch in dset.take(1):
  # batch is an (input, target) tuple; plot the g band of each of the 25 images
  for i in range(25):
    plt.subplot(5,5,i+1)
    plt.imshow(batch[0][i,:,:,0])
    plt.axis('off')

## Building a VAE with Keras

Now that we have access to training data, let's build a small Variational Auto-Encoder to try to learn how to sample galaxy images.

### Defining a recognition model


In [None]:
import tensorflow_probability as tfp
tfpl = tfp.layers
tfkl = tf.keras.layers
tfd = tfp.distributions

def get_probabilistic_encoder(latent_dim=32):
  """ Creates a small convolutional encoder for the requested latent dimension
  """
  # We choose a prior distribution for the latent codes
  prior = tfd.MultivariateNormalDiag(loc=tf.zeros(latent_dim))

  return tf.keras.Sequential([ 
            # TODO: Write a recognition model.
            # Remember, the output needs to be a distribution, and needs to include 
            # the KL regularisation term
      ])
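
To see what the KL regularisation term measures, here is a NumPy sketch of its closed form. It assumes the encoder outputs a diagonal Gaussian posterior, which is only one possible choice since the exercise leaves the posterior family open. In TensorFlow Probability this term is typically attached to the output distribution layer through an activity regularizer such as `tfpl.KLDivergenceRegularizer(prior)`; the function below only illustrates the quantity being computed.

```python
import numpy as np

def kl_to_standard_normal(mu, sigma):
    """KL( N(mu, diag(sigma^2)) || N(0, I) ), summed over latent dimensions."""
    mu, sigma = np.asarray(mu), np.asarray(sigma)
    return 0.5 * np.sum(sigma**2 + mu**2 - 1.0 - 2.0 * np.log(sigma), axis=-1)

latent_dim = 32
# A posterior identical to the prior costs nothing...
print(kl_to_standard_normal(np.zeros(latent_dim), np.ones(latent_dim)))
# ...while moving the mean or shrinking the width is penalized
print(kl_to_standard_normal(np.ones(latent_dim), np.ones(latent_dim)))
print(kl_to_standard_normal(np.zeros(latent_dim), 0.1 * np.ones(latent_dim)))
```

This penalty is what keeps the encoded distributions close to the prior, so that codes drawn from the prior decode into plausible images.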

In [None]:
prob_encoder = get_probabilistic_encoder(latent_dim=32)

In [None]:
prob_encoder.summary()

In [None]:
for batch_im, batch_target in dset.take(1): # Sample only one batch of images
  batch_encoded = prob_encoder(batch_im)    # Apply the encoder on images 

In [None]:
batch_encoded

### Implementing a generator

One of the important considerations for the generator is which likelihood to use.

Unlike in an MNIST example, we are dealing here with continuous-valued images with Gaussian noise, so we will use a Gaussian likelihood at the output of the generator.

In [None]:
def get_probabilistic_decoder(latent_dim=32):
  """ Creates a small convolutional decoder for the requested latent dimension
  """
  return tf.keras.Sequential([
      # TODO: Write a decoder 
      
      # .....

      # This will be the output distribution layer that defines the likelihood
      # Note that we set sigma=0.3 which is around the standard deviation of the 
      # noise in the images after our preprocessing
      tfpl.DistributionLambda(lambda t: tfd.MultivariateNormalDiag(loc=t,
                                              scale_identity_multiplier=0.3))
  ])
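
With a fixed sigma, the Gaussian negative log-likelihood reduces to a scaled squared error plus a constant. The NumPy sketch below writes it out for independent pixels with sigma = 0.3; the function name and the random "image" are made up for illustration, and the sum over all pixels is a simplification of how Keras aggregates the loss.

```python
import numpy as np

def gaussian_negloglik(x, mean, sigma=0.3):
    """Negative log-likelihood of x under independent N(mean, sigma^2) pixels."""
    x, mean = np.asarray(x), np.asarray(mean)
    n = x.size
    chi2 = np.sum((x - mean) ** 2) / sigma**2
    return 0.5 * chi2 + n * np.log(sigma * np.sqrt(2.0 * np.pi))

rng = np.random.default_rng(0)
truth = rng.normal(size=(64, 64, 5))                 # made-up "image"
noisy = truth + 0.3 * rng.normal(size=truth.shape)   # add sigma = 0.3 noise

print(gaussian_negloglik(noisy, truth))                 # good reconstruction
print(gaussian_negloglik(noisy, np.zeros_like(truth)))  # poor reconstruction
```

Setting sigma close to the actual noise level means the model is not rewarded for reproducing the noise itself, only the underlying galaxy.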

In [None]:
# Let's instantiate the decoder
prob_decoder = get_probabilistic_decoder(latent_dim=32)
# And check its summary
prob_decoder.summary()

In [None]:
# Draw a random sample of the code
code_sample = batch_encoded.sample()
# And decode that sample
decoded_im = prob_decoder(code_sample)

In [None]:
decoded_im

In [None]:
figure(figsize=(9,3))
subplot(131)
imshow(batch_im[0,:,:,0],cmap='gray'); axis('off')
title("Input Image")
subplot(132)
imshow(decoded_im.sample()[0,:,:,0],cmap='gray'); axis('off')
title("Sample from generator output")
subplot(133)
imshow(decoded_im.mean()[0,:,:,0],cmap='gray'); axis('off')
title("Mean of generator output");

### Building the VAE

In [None]:
vae = tf.keras.Sequential([
          tfkl.InputLayer([64,64,5]),
          prob_encoder,
          prob_decoder])

In [None]:
vae.summary()

In [None]:
# We define the reconstruction loss as the negative log likelihood
negloglik = lambda x, rv_x: -rv_x.log_prob(x)
# And use it to compile the VAE
vae.compile(optimizer=tf.optimizers.Adam(learning_rate=1e-3),
            loss=negloglik)

In [None]:
# We define the batch size
BATCH_SIZE = 64

# Learning rate schedule
LEARNING_RATE=0.001
LEARNING_RATE_EXP_DECAY=0.9

lr_decay = tf.keras.callbacks.LearningRateScheduler(
    lambda epoch: LEARNING_RATE * LEARNING_RATE_EXP_DECAY**epoch,
    verbose=True)
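
To get a feel for this schedule, the pure-Python sketch below evaluates the same rule for a few epochs, using the constants defined above. Keras passes a zero-based epoch index to the scheduler, so the first epoch runs at the initial learning rate.

```python
LEARNING_RATE = 0.001
LEARNING_RATE_EXP_DECAY = 0.9
BATCH_SIZE = 64

def schedule(epoch):
    """Same rule as the lambda passed to LearningRateScheduler."""
    return LEARNING_RATE * LEARNING_RATE_EXP_DECAY ** epoch

for epoch in (0, 1, 5, 14):
    print('epoch %2d: lr = %.2e' % (epoch, schedule(epoch)))

# Number of gradient steps in one epoch when using 20000 // BATCH_SIZE
print('steps per epoch:', 20000 // BATCH_SIZE)
```

Over 15 epochs the learning rate drops by a bit more than a factor of four, which lets training start with large steps and finish with finer ones.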

In [None]:
# We actually create our training dataset with our input function
dataset_training = input_fn(tf.estimator.ModeKeys.TRAIN, BATCH_SIZE)

In [None]:
# We are ready to train our model
history = vae.fit(dataset_training,
            steps_per_epoch=20000//BATCH_SIZE, 
            epochs=15,
            callbacks=[lr_decay])

### Testing VAE auto-encoding

In [None]:
# Now that the model is 'trained', we can apply it
dataset_eval = input_fn(tf.estimator.ModeKeys.EVAL, BATCH_SIZE)

In [None]:
for batch_im, batch_targets in dataset_eval.take(1):
  # This extracts one batch of the test dset and shows one example (index 3)
  imshow((batch_im[3,:,:,::-1][:,:,-3:]/batch_im[3,:,:,:3].numpy().max()))

In [None]:
autoencoded_im = vae(batch_im) # Run the input batch through the model

In [None]:
subplot(121)
# Plot the mean of the output Gaussian distribution
imshow(autoencoded_im.mean()[3,:,:,-3:]/autoencoded_im.mean()[3,:,:,-3:].numpy().max(),cmap='gray'); axis('off'); 
title('Mean output')
subplot(122)
# Plot a random sample of the output Gaussian distribution
imshow(autoencoded_im.sample()[3,:,:,0],cmap='gray'); axis('off');
title('Sample output');

### Sampling from the model

In [None]:
latent_samples = tfd.MultivariateNormalDiag(loc=tf.zeros(32)).sample(16)

In [None]:
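# Decode the prior samples; the 0.5 factor shrinks the codes towards the prior
# mean (they are then distributed as N(0, 0.25 I))
image_samples = prob_decoder(0.5*latent_samples)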

In [None]:
figure(figsize=(10,10))
fig = image_samples.mean().numpy().reshape((4,4,64,64,5))
fig = fig.transpose((0,2,1,3,4)).reshape((4*64,4*64,5))
# Let's see the result
imshow(fig[:,:,::-1][:,:,-3:]/(fig[:,:,::-1][:,:,-3:]).max()); axis('off');
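
The reshape and transpose used above to tile 16 images into a 4x4 mosaic can be hard to read. The NumPy sketch below checks the same sequence of operations on tiny made-up images, where image `k` is filled with the value `k` so that each tile is easy to recognize.

```python
import numpy as np

n, h, w, c = 4, 2, 3, 1  # 4x4 grid of tiny 2x3 single-channel "images"
# Image k is filled with the value k, so tiles are easy to recognize
images = np.arange(n * n).reshape(n * n, 1, 1, 1) * np.ones((1, h, w, c))

grid = images.reshape((n, n, h, w, c))   # (row, col, y, x, channel)
grid = grid.transpose((0, 2, 1, 3, 4))   # (row, y, col, x, channel)
grid = grid.reshape((n * h, n * w, c))   # merge (row, y) and (col, x)

print(grid[..., 0])
```

The transpose is the key step: without it, the final reshape would interleave rows of different images instead of placing each image in its own tile.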