<a href="https://colab.research.google.com/github/EiffL/Tutorials/blob/master/PhotozCNN/photoz_inference_training_exercise.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

##### Copyright 2019-2021 Francois Lanusse.

Licensed under the Apache License, Version 2.0 (the "License");

# Photometric Redshift Estimation with TensorFlow - Part III: Training a CNN with Keras

Author: [@EiffL](https://github.com/EiffL) (Francois Lanusse)

### Overview

In this tutorial, we learn how to combine Keras, TensorFlow Probability, and Google Colab to train a model for photo-z inference in the cloud.

We will be using data from the HSC Survey, and more specifically from the Public Data Release 2, which can be found here: https://hsc-release.mtk.nao.ac.jp/doc/


The dataset contains postage stamps of galaxies in 5 HSC bands, along with corresponding spectroscopic redshifts.

Our goal will be to estimate redshift just by looking at a picture of a galaxy.

### Learning objectives

In this notebook, we will learn how to:
*   Build a tf.data.Dataset input pipeline.
*   Build a simple convolutional neural network with Keras.
*   Train a model on GPUs in the cloud.
*   (Stretch Goal) Use TensorFlow Probability to make a probabilistic model.

Note: this tutorial was originally presented at [Astro Hack Week 2019](https://github.com/AstroHackWeek/AstroHackWeek2019/tree/master/day4_bayesiandeep).

### Instructions for enabling GPU access

By default, notebooks are started without acceleration. To make sure that the runtime is configured for using GPUs, go to `Runtime > Change runtime type`, and select GPU in `Hardware Accelerator`.



### Installs and Imports

In [None]:
# Retrieving utilities
!git clone https://github.com/EiffL/Tutorials.git
%cd Tutorials/PhotozCNN

# Retrieve the pre-prepared data (this takes about 2 minutes)
!gsutil -m -q cp -r gs://ahw2019/hsc_photoz/tensorflow_datasets /root/ 

### Checking for GPU access

In [None]:
import tensorflow as tf

#Checking for GPU access
if tf.test.gpu_device_name() != '/device:GPU:0':
  print('WARNING: GPU device not found.')
else:
  print('SUCCESS: Found GPU: {}'.format(tf.test.gpu_device_name()))

## Building a tf.data.Dataset Input Pipeline

The first step is to read the data from disk into a tf.data.Dataset. The tf.data API is the canonical way to supply data to a model during training: it is optimized for throughput and supports distributed training!

Converting some input data into a tf.data.Dataset is made completely trivial by the [TensorFlow Datasets library](https://www.tensorflow.org/datasets) that we explored in the [previous notebook](https://github.com/EiffL/Tutorials/blob/master/PhotozCNN/photoz_inference_tfdatasets.ipynb) in this series. 
We will be reusing the dataset created there:

In [3]:
import tensorflow_datasets as tfds
import hsc_photoz

train_dset = tfds.load('hsc_photoz', split='train')

In [None]:
# Let's extract one example of our data
for i in train_dset.take(1):
  print(i)

Nice :-) It's working, and we had nothing to do. We did cheat a bit though: the pre-built dataset is already cached in this instance. On your own computer, it might take about one hour to download and prepare the full dataset.

### Dataset preprocessing

An important step of any input pipeline is to make sure the data is reasonably well behaved before feeding it to the neural network. Here are some common strategies:


*   Apply log() to values with large dynamic range
*   Subtract the mean and rescale to unit standard deviation
*   etc...
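
To make the first two strategies concrete, here is a minimal sketch on synthetic data (a log-normal sample standing in for values with a large dynamic range; it is not the HSC data, and the variable names are only illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
# Synthetic, strictly positive values spanning several orders of magnitude
x = rng.lognormal(mean=0.0, sigma=2.0, size=10_000)

# Strategy 1: compress the dynamic range with a logarithm
log_x = np.log(x)

# Strategy 2: subtract the mean and rescale to unit standard deviation
standardized = (log_x - log_x.mean()) / log_x.std()

print('raw max / median      :', x.max() / np.median(x))
print('standardized mean, std:', standardized.mean(), standardized.std())
```

Note that a plain log() only works for strictly positive values, which is not the case for background-subtracted pixels.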


So, we begin by looking at our data:


In [None]:
# What's in our dataset:
train_dset

OK, we see that each element of this dataset is a dictionary. The `image` field contains HSC cutouts in 5 bands (g, r, i, z, y), which will be the inputs to our CNN. We also see a `specz_redshift` entry under `attrs`, which will be our prediction target. Let's have a look at these.

In [None]:
from astropy.visualization import make_lupton_rgb
%pylab inline 

# The data is in 5 bands GRIZY, but for visualisation we use only the
# first 3 bands (i, r, g mapped to red, green, blue) and luptonize them
def luptonize(img):
  return make_lupton_rgb(img[:,:,2], img[:,:,1], img[:,:,0],
                         Q=15, stretch=0.5, minimum=0)

plt.figure(figsize=(10,10))
for i, entry in enumerate(train_dset.take(25)):
  plt.subplot(5,5,i+1)
  plt.imshow(luptonize(entry['image']))
  plt.title('z = %0.02f'%entry['attrs']['specz_redshift'])
  plt.axis('off')

How nice is that :-) We can extract postage stamps and the corresponding spectroscopic redshift for these objects. 

Before doing anything else, we should take a closer look at the data and check that it's well behaved.

In [7]:
# Let's collect a few examples to check their distributions
cutouts=[]
specz = []
for entry in train_dset.take(1000):
  specz.append(entry['attrs']['specz_redshift'])
  cutouts.append(entry['image'])

cutouts = np.stack(cutouts)
specz = np.stack(specz)

In [None]:
for i,b in enumerate(['g', 'r', 'i', 'z', 'y']):
  plt.hist(cutouts[...,i].flatten(),100, label=b);
plt.legend()

# Problem ?

Do you see a problem in this histogram?

Let's have a look at a few images:

In [None]:
plt.figure(figsize=(15,3))
for i,b in enumerate(['g', 'r', 'i', 'z', 'y']):
  plt.subplot(1,5,i+1)
  plt.imshow(cutouts[0,:,:,i],vmin=-1,vmax=50)
  plt.title(b)
  plt.axis('off')

plt.figure(figsize=(15,3))
for i,b in enumerate(['g', 'r', 'i', 'z', 'y']):
  plt.subplot(1,5,i+1)
  plt.imshow(cutouts[1,:,:,i],vmin=-1,vmax=50)
  plt.title(b)
  plt.axis('off')

In the plot above, each row is a different galaxy, and each column is a different band.

In [None]:
# Let's look at it in log scale
for i,b in enumerate(['g', 'r', 'i', 'z', 'y']):
  plt.hist(cutouts[...,i].flatten(),100, label=b,alpha=0.5);
plt.legend()
plt.yscale('log')

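This is terrible: the long tail of the pixel intensity distribution is going to kill our neural network. Input values far above the noise level can produce very large activations and gradients, which can make training unstable. We need to standardize the data.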

In [None]:
# Let's evaluate the noise standard deviation in each band, and apply range 
# compression accordingly
from astropy.stats import mad_std
scaling = []

for i,b in enumerate(['g', 'r', 'i', 'z', 'y']):
  plt.hist(cutouts[...,i].flatten(),100, label=b,alpha=0.5,range=[-1,1]);
  sigma = mad_std(cutouts[...,i].flatten())
  scaling.append(sigma)
  plt.axvline(sigma, color='C%d'%i,alpha=0.5)
  plt.axvline(-sigma, color='C%d'%i,alpha=0.5)
plt.legend()
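
Why `mad_std` rather than a plain standard deviation? The sketch below, on synthetic pixels (Gaussian noise plus a small fraction of bright values, all made up for illustration), re-implements the MAD-based estimate in NumPy and compares it with `np.std`. The helper name `mad_std_numpy` is hypothetical; it is meant to mirror what `astropy.stats.mad_std` computes.

```python
import numpy as np

def mad_std_numpy(x):
    """Robust sigma estimate: 1.4826 * median(|x - median(x)|)."""
    x = np.asarray(x).ravel()
    return 1.4826 * np.median(np.abs(x - np.median(x)))

rng = np.random.default_rng(0)
# Synthetic background noise with a true sigma of 0.05
pixels = rng.normal(0.0, 0.05, size=100_000)
# Make 1% of the pixels bright, mimicking the galaxy light
pixels[:1000] += rng.uniform(1.0, 50.0, size=1000)

print('np.std       :', pixels.std())
print('MAD-based std:', mad_std_numpy(pixels))
```

The plain standard deviation is dominated by the few bright pixels, while the MAD-based estimate stays close to the noise level, which is the quantity we want to scale by.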

In [None]:
# Let's have a look at this distribution if we rescale each band by the standard
# deviation
for i,b in enumerate(['g', 'r', 'i', 'z', 'y']):
  plt.hist(cutouts[...,i].flatten()/scaling[i],100, label=b,alpha=0.5,
           range=[-10,10]);
plt.legend()

Sweet! There is still an unsightly tail towards very large values, though. We are going to apply range compression to get rid of it.

In [None]:
# a common approach for range compression is to apply arcsinh to suppress the
# high amplitude values
for i,b in enumerate(['g', 'r', 'i', 'z', 'y']):
  plt.hist(np.arcsinh(cutouts[...,i].flatten()/scaling[i]/3),100,
           label=b, alpha=0.5);
plt.legend()
plt.yscale('log')
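
To see why arcsinh is a convenient choice here, this small sketch (plain NumPy, arbitrary test values) checks its two regimes: it is close to the identity for small values and close to log(2x) for large ones.

```python
import numpy as np

x = np.array([0.01, 0.1, 1.0, 10.0, 100.0, 1000.0])
y = np.arcsinh(x)

# For |x| << 1, arcsinh(x) ~ x; for x >> 1, arcsinh(x) ~ log(2x)
for xi, yi in zip(x, y):
    print(f'x = {xi:8.2f}   arcsinh(x) = {yi:7.4f}   log(2x) = {np.log(2 * xi):7.4f}')

# Unlike log(), arcsinh is defined for negative (noise) pixels and is odd
assert np.allclose(np.arcsinh(-x), -y)
```

So the noise-dominated pixels are left almost untouched, while the bright tail is compressed logarithmically.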

![Perfection](https://i.kym-cdn.com/entries/icons/original/000/022/900/704.jpg)

In [None]:
# we can have a look at individual postage stamps with or without this scaling
subplot(121)
imshow(cutouts[0,:,:,1]/scaling[1])
title('Before')
subplot(122)
imshow(np.arcsinh(cutouts[0,:,:,1]/scaling[1]/3))
title('After');

In [None]:
# Let's just check the specz values
plt.hist(specz,100);
# Should be ok

Now that we have defined a scaling for the data that should be appropriate, we can build a scaling function and apply it to the dataset:

In [16]:
# Using a mapping function to apply preprocessing to our data
def preprocessing(example):
  def range_compression(img):
    return tf.math.asinh(img / tf.constant(scaling) / 3. )
  # Our preprocessing function only returns the postage stamps, and the specz
  return range_compression(example['image']), example['attrs']['specz_redshift']

In [17]:
dset = train_dset.map(preprocessing)

In [18]:
# Let's draw some examples from this now
cutouts=[]
specz = []
for entry in dset.take(1000):
  specz.append(entry[1])
  cutouts.append(entry[0])

cutouts = np.stack(cutouts)
specz = np.stack(specz)

In [None]:
for i,b in enumerate(['g', 'r', 'i', 'z', 'y']):
  plt.hist(cutouts[...,i].flatten(),100, label=b,alpha=0.5, range=[-1,6]);
plt.legend()

In [None]:
plt.figure(figsize=(15,3))
for i,b in enumerate(['g', 'r', 'i', 'z', 'y']):
  plt.subplot(1,5,i+1)
  plt.imshow(cutouts[0,:,:,i],vmin=-1,vmax=6)
  plt.title(b)
  plt.axis('off')

plt.figure(figsize=(15,3))
for i,b in enumerate(['g', 'r', 'i', 'z', 'y']):
  plt.subplot(1,5,i+1)
  plt.imshow(cutouts[1,:,:,i],vmin=-1,vmax=6)
  plt.title(b)
  plt.axis('off')

Sweeeeeet

### Create the input pipeline

Now that we know how to preprocess the data, we can build the input pipeline. Below is a function that loads the dataset from the TFRecord files and decodes them (via `tfds.load`), repeats and shuffles the training split, creates batches, applies the preprocessing, and prefetches the next batches. The function returns a Dataset that Keras models can ingest directly.

More information about the tf.data.Dataset API can be found here: 

https://www.tensorflow.org/guide/datasets


In [21]:
# Using a mapping function to apply preprocessing to our data
def preprocessing(example):
  img = tf.math.asinh(example['image'] / tf.constant(scaling) / 3. )
  return img, example['attrs']['specz_redshift']

def input_fn(mode='train', batch_size=64):
  """
  mode: 'train' or 'test'
  """
  if mode == 'train':
    dataset = tfds.load('hsc_photoz', split='train[:80%]')
    dataset = dataset.repeat()
    dataset = dataset.shuffle(10000)
  else:
    dataset = tfds.load('hsc_photoz', split='train[80%:]')
    
  dataset = dataset.batch(batch_size, drop_remainder=True)
  dataset = dataset.map(preprocessing) # Apply data preprocessing
  dataset = dataset.prefetch(-1)       # fetch next batches while training current one (-1 for autotune)
  return dataset
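
If you want to play with these pipeline steps without loading the HSC data, here is a toy sketch on random arrays. Everything in it is made up for illustration (array sizes, the unit `toy_scaling`, the helper name `toy_preprocessing`); only the tf.data calls are the same as above.

```python
import numpy as np
import tensorflow as tf

# Synthetic stand-in: 100 random 5-band "images" and redshifts
images = np.random.rand(100, 8, 8, 5).astype('float32')
redshifts = np.random.rand(100).astype('float32')
toy = tf.data.Dataset.from_tensor_slices(
    {'image': images, 'attrs': {'specz_redshift': redshifts}})

toy_scaling = [1., 1., 1., 1., 1.]  # hypothetical per-band noise levels

def toy_preprocessing(example):
    # Same structure as the preprocessing above; broadcasts over the batch
    img = tf.math.asinh(example['image'] / tf.constant(toy_scaling) / 3.)
    return img, example['attrs']['specz_redshift']

toy = toy.shuffle(100)                       # shuffle individual examples
toy = toy.batch(16, drop_remainder=True)     # 100 // 16 full batches
toy = toy.map(toy_preprocessing)             # preprocess whole batches
toy = toy.prefetch(tf.data.AUTOTUNE)         # overlap loading and training

shapes = [(x.shape.as_list(), z.shape.as_list()) for x, z in toy]
print(len(shapes), shapes[0])
```

Because the mapping is applied after batching, the preprocessing function sees tensors with a leading batch dimension, and the per-band scaling broadcasts along the last axis.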

## Building a simple regression model with Keras

In [22]:
import tensorflow.keras as tfk

def create_model():
  model = tfk.models.Sequential()

  # .... Add your CNN model here
  
  model.compile(optimizer='adam', # learning rate will be set by LearningRateScheduler
                loss=# .... Add a loss here
                )
  return model

In [None]:
model = create_model()

# print model layers
model.summary()

In [25]:
# Create the training and testing datasets with our input function
dataset_training = input_fn('train')
dataset_testing = input_fn('test')

In [None]:
# And we start TensorBoard to track our training
%load_ext tensorboard
%tensorboard --logdir='./logs'

In [None]:
# Learning rate schedule
LEARNING_RATE=0.001
LEARNING_RATE_EXP_DECAY=0.9
lr_decay = tfk.callbacks.LearningRateScheduler(
    lambda epoch: LEARNING_RATE * LEARNING_RATE_EXP_DECAY**epoch,
    verbose=True)

# TensorBoard tracking
tb_callback = tf.keras.callbacks.TensorBoard('./logs/run1', update_freq='batch')

# We are ready to train our model
model.fit(dataset_training,
          validation_data=dataset_testing,
          steps_per_epoch=20000//64,
          epochs=10,
          callbacks=[lr_decay, tb_callback])
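
To get a feeling for the exponential decay used above, this short sketch prints the learning rate at each of the 10 epochs, using the same two constants (the helper name `schedule` is just for illustration):

```python
LEARNING_RATE = 0.001
LEARNING_RATE_EXP_DECAY = 0.9

def schedule(epoch):
    # Same formula as the lambda passed to LearningRateScheduler
    return LEARNING_RATE * LEARNING_RATE_EXP_DECAY ** epoch

rates = [schedule(epoch) for epoch in range(10)]
for epoch, lr in enumerate(rates):
    print(f'epoch {epoch}: lr = {lr:.6f}')
```

With a decay factor of 0.9, the learning rate shrinks by 10% at every epoch, so it is still a bit more than a third of its initial value during the last epoch.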

In [76]:
# Now that the model is 'trained', we can apply it
BATCH_SIZE = 64  # same value as the default batch size of input_fn
dataset_eval = input_fn('test', BATCH_SIZE)
preds = model.predict(dataset_eval)

In [77]:
# Retrieving specz for testing set
ground_truth = []
for im, z in input_fn('test', BATCH_SIZE):
  ground_truth.append(z)
ground_truth = np.concatenate(ground_truth, axis=0)

In [None]:
hist2d(ground_truth, preds.squeeze(), 64,
       range=[[0,1],[0,1]], cmap='gist_stern'); 
gca().set_aspect('equal');
plot([0,1],[0,1],color='r')
xlabel('Spectroscopic Redshift')
ylabel('Predicted Redshift');
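
Beyond the 2D histogram, it can help to put numbers on the quality of the predictions. The sketch below computes a few summary statistics of dz = (z_pred - z_spec) / (1 + z_spec). The function name `photoz_metrics` and the 0.15 outlier threshold are assumptions made for this example (a common convention, not something defined in this notebook), and the check at the bottom runs on synthetic redshifts.

```python
import numpy as np

def photoz_metrics(z_spec, z_pred, outlier_threshold=0.15):
    """Bias, robust scatter and outlier fraction of dz = (z_pred - z_spec) / (1 + z_spec)."""
    z_spec = np.asarray(z_spec, dtype=float)
    z_pred = np.asarray(z_pred, dtype=float)
    dz = (z_pred - z_spec) / (1.0 + z_spec)
    bias = np.mean(dz)
    # MAD-based scatter, robust to catastrophic outliers
    sigma_mad = 1.4826 * np.median(np.abs(dz - np.median(dz)))
    outlier_fraction = np.mean(np.abs(dz) > outlier_threshold)
    return bias, sigma_mad, outlier_fraction

# Synthetic check: unbiased predictions with a 3% scatter in dz
rng = np.random.default_rng(0)
z_true = rng.uniform(0.0, 1.0, size=5000)
z_fake = z_true + 0.03 * (1.0 + z_true) * rng.normal(size=5000)
bias, sigma_mad, outlier_fraction = photoz_metrics(z_true, z_fake)
print(bias, sigma_mad, outlier_fraction)
```

In this notebook, the same function could be called as `photoz_metrics(ground_truth, preds.squeeze())`.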

Yes! It's learning something! But the results may not be perfect yet. Here is a list of things to do to improve the model:
*   Train for longer ^^' Deep Learning requires a lot of patience
*   Include additional information, like the extinction in each band, which is included in the dataset. 
    *   You would need to change the input_fn so that the returned dataset contains additional fields
    *   See how to use multiple inputs in a Keras model here: https://keras.io/getting-started/functional-api-guide/
*   Use a better CNN model. A simple CNN is typically far from optimal. Much better results would be achieved with a ResNet model, for instance. See here for a Keras example: https://keras.io/examples/cifar10_resnet/
*   Move away from an MSE loss and embrace TensorFlow Probability \o/ Replacing the output of the model with a mixture density would be a much better idea. https://www.tensorflow.org/probability/api_docs/python/tfp/layers/MixtureNormal
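
To illustrate the last point, here is a NumPy-only sketch of the quantity a mixture density model would minimize: the negative log-likelihood of the true redshift under a 1D Gaussian mixture. This is not the TensorFlow Probability implementation; the function name `mixture_nll` and the toy parameters are assumptions made for this example.

```python
import numpy as np

def mixture_nll(y, logits, means, sigmas):
    """Negative log-likelihood of y under a 1D Gaussian mixture.

    y: shape (batch,); logits, means, sigmas: shape (batch, n_components).
    """
    y = np.asarray(y, dtype=float)[:, None]
    # Log of the normalized mixture weights (log-softmax of the logits)
    log_w = logits - np.logaddexp.reduce(logits, axis=1, keepdims=True)
    # Log-density of each Gaussian component evaluated at y
    log_norm = (-0.5 * ((y - means) / sigmas) ** 2
                - np.log(sigmas) - 0.5 * np.log(2 * np.pi))
    # Log-sum-exp over components gives the mixture log-density
    return -np.logaddexp.reduce(log_w + log_norm, axis=1)

# Toy bimodal prediction: two equally weighted modes at z = 0.2 and z = 0.8
y = np.array([0.2, 0.8])
logits = np.zeros((2, 2))
means = np.array([[0.2, 0.8], [0.2, 0.8]])
sigmas = np.full((2, 2), 0.05)
nll = mixture_nll(y, logits, means, sigmas)
print(nll)
```

With a single component and a fixed sigma, this loss reduces to the MSE up to a scale and an additive constant. The mixture version lets the model express multi-modal and heteroscedastic uncertainty on the redshift.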

