In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import pkg_resources

from debvader.training.train import train_deblender

### Load and format data

Load the dataset used for training. It was generated with the code in https://github.com/BastienArcelin/dc2_img_generation, and the stamps were produced with the XXX function (see notebook XXX). The stamp size is then fixed to 59x59 pixels.
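How the stamps were brought to 59x59 pixels is left to the XXX function above. Purely as an illustration, and assuming channels-last stamps larger than 59x59, a centre crop could look like the sketch below. The helper `center_crop` and the 64x64 input are hypothetical, not part of debvader.

```python
import numpy as np

STAMP_SIZE = 59  # stamp size quoted in the text


def center_crop(stamps, size=STAMP_SIZE):
    """Hypothetical helper: crop channels-last stamps (N, H, W, C) to (N, size, size, C) around the centre."""
    n, h, w, c = stamps.shape
    if h < size or w < size:
        raise ValueError(f"stamps of {h}x{w} pixels are smaller than {size}x{size}")
    top = (h - size) // 2   # rows removed above the crop
    left = (w - size) // 2  # columns removed left of the crop
    return stamps[:, top:top + size, left:left + size, :]


# Made-up batch: 3 stamps of 64x64 pixels in 6 bands
demo_big = np.random.default_rng(0).normal(size=(3, 64, 64, 6))
demo_small = center_crop(demo_big)
print(demo_small.shape)  # (3, 59, 59, 6)
```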

In [None]:
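data_folder_path = pkg_resources.resource_filename('debvader', "data/")
# Path to the DC2 stamps shipped with the package
image_path = os.path.join(data_folder_path, 'dc2_imgs', 'imgs_dc2.npy')
# mmap_mode='c': copy-on-write memory map, read lazily and never modified on disk
images = np.load(image_path, mmap_mode='c')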
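The `mmap_mode='c'` argument of `np.load` opens the file as a copy-on-write memory map: pixels are read from disk only when accessed, and in-place changes stay in memory instead of being written back to the `.npy` file. A small self-contained check on a throwaway file (the file name and the array are made up for the demonstration):

```python
import os
import tempfile

import numpy as np

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "stamps.npy")  # throwaway file for the demonstration
    np.save(path, np.zeros((2, 59, 59, 6), dtype=np.float32))

    stamps = np.load(path, mmap_mode="c")  # copy-on-write memory map
    stamps[0] += 1.0                       # changes the in-memory copy only

    modified_max = float(stamps.max())     # value seen through the memory map
    disk_max = float(np.load(path).max())  # value still stored in the file
    del stamps                             # release the map before the directory is removed

print(modified_max, disk_max)  # 1.0 0.0
```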

In [None]:
images.shape

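Separate the training data and labels that are fed to the VAE and to the deblender. Each array stacks the inputs and the labels; in this short example the same images are used for both, with the first five stamps for training and the remaining ones for validation. Make sure that the number of filters in the training data corresponds to the number of bands the network expects (six by default).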

In [None]:
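# Stack (input, label) pairs: the first five stamps for training, the remaining ones for validation
training_data_vae = np.array((images[:5], images[:5]))
validation_data_vae = np.array((images[5:], images[5:]))

# The deblender uses the same stamps here, for demonstration only
training_data_deblender = np.array((images[:5], images[:5]))
validation_data_deblender = np.array((images[5:], images[5:]))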
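The arrays above all follow one pattern: two arrays stacked along a new first axis, which this sketch assumes to be (network input, label). A hypothetical helper `make_pairs` generalises it to any split point and to separate inputs and labels, as a deblender would need once blended images are used as inputs. debvader does not provide this function; it is only an illustration.

```python
import numpy as np


def make_pairs(inputs, labels, n_train):
    """Hypothetical helper: split into training and validation sets and stack
    each set as (input, label) along a new first axis, like np.array((x, x)) above."""
    if inputs.shape != labels.shape:
        raise ValueError("inputs and labels must have the same shape")
    training = np.stack((inputs[:n_train], labels[:n_train]))
    validation = np.stack((inputs[n_train:], labels[n_train:]))
    return training, validation


# Stand-in stamps (8 galaxies, 59x59 pixels, 6 bands), not the DC2 data
demo_stamps = np.zeros((8, 59, 59, 6), dtype=np.float32)
demo_train, demo_val = make_pairs(demo_stamps, demo_stamps, n_train=5)
demo_inputs, demo_labels = demo_train  # unpacking splits the first axis
print(demo_train.shape, demo_val.shape)  # (2, 5, 59, 59, 6) (2, 3, 59, 59, 6)
```

For arrays of equal shape, `np.stack((a, b))` and `np.array((a, b))` give the same result; `np.stack` simply fails loudly if the shapes differ.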

### Train the VAE and the deblender with the architecture from the debvader paper (https://arxiv.org/abs/2005.12039)

In [None]:
hist_vae, hist_deblender, net = train_deblender(
    "lsst",
    from_survey="dc2",
    epochs=2,
    training_data_vae=training_data_vae,
    validation_data_vae=validation_data_vae,
    training_data_deblender=training_data_deblender,
    validation_data_deblender=validation_data_deblender,
)
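The objects returned by `train_deblender` are not described in this notebook. If `hist_vae` and `hist_deblender` are Keras `History` objects (an assumption), their `.history` attribute is a dict mapping metric names such as `loss` and `val_loss` to per-epoch lists, and the best epoch can be read off as in this sketch. The helper `best_epoch` and the toy numbers are hypothetical.

```python
def best_epoch(history):
    """Hypothetical helper: return (epoch index, loss, val_loss) for the epoch
    with the lowest validation loss.

    `history` is a dict of per-epoch lists, the format of Keras' History.history.
    """
    val_loss = history["val_loss"]
    best = min(range(len(val_loss)), key=val_loss.__getitem__)
    return best, history["loss"][best], val_loss[best]


# Made-up values, only to show the expected format
toy_history = {"loss": [3.0, 2.0, 2.5], "val_loss": [3.5, 2.2, 2.4]}
epoch, loss, val_loss = best_epoch(toy_history)
print(f"best epoch: {epoch + 1}, loss={loss}, val_loss={val_loss}")
# best epoch: 2, loss=2.0, val_loss=2.2
```

With the notebook's objects this would read `best_epoch(hist_vae.history)`, provided the assumption above holds.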

### Dataset with a different number of filters

The number of available filters differs between surveys: for example, DES data come in five filters, whereas the network above uses six. To train on such data, specify the number of bands in the train_deblender function (nb_of_bands), as well as whether the channels are the last or the first axis of the data array (channel_last).
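A quick consistency check before calling `train_deblender` can catch a mismatch between the data and `nb_of_bands`. The sketch assumes six bands for LSST (the network default stated above) and five for DES, and reads the band axis according to the channels-last or channels-first convention; `NB_OF_BANDS` and `check_bands` are hypothetical names, not debvader API.

```python
import numpy as np

# Hypothetical lookup table: six bands for LSST (network default), five for DES
NB_OF_BANDS = {"lsst": 6, "des": 5}


def check_bands(data, survey, channel_last=True):
    """Return the number of bands in `data`, or raise if it does not match `survey`.

    Works for single arrays (N, H, W, C) and for (input, label) stacks (2, N, H, W, C);
    in channels-first layout the band axis is third from the end.
    """
    n_bands = data.shape[-1] if channel_last else data.shape[-3]
    expected = NB_OF_BANDS[survey]
    if n_bands != expected:
        raise ValueError(f"{survey!r} expects {expected} bands but the data have {n_bands}")
    return n_bands


demo = np.zeros((2, 5, 59, 59, 6))  # stand-in for a six-band (input, label) stack
print(check_bands(demo, "lsst"))    # 6
try:
    check_bands(demo, "des")
except ValueError as err:
    print(err)  # 'des' expects 5 bands but the data have 6
```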

In [None]:
training_data_vae.shape

Here the channels are the last axis of the data array, so the channel_last option of the train_deblender function is set to True (the default setting).
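If the data were stored channels first instead, another option is to reorder the axes with `np.moveaxis` before training. A sketch on a made-up array laid out as (input/label, stamp, band, y, x):

```python
import numpy as np

# Made-up channels-first stack: (input/label, stamp, band, y, x)
demo_first = np.arange(2 * 3 * 6 * 59 * 59, dtype=np.float32).reshape(2, 3, 6, 59, 59)

# Move the band axis (third from the end) to the last position
demo_last = np.moveaxis(demo_first, -3, -1)
print(demo_last.shape)  # (2, 3, 59, 59, 6)
```

`np.moveaxis` returns a view, so no pixels are copied; wrap it in `np.ascontiguousarray` if a contiguous array is needed.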

In [None]:
# These arrays still contain the six DC2 channels
hist_vae, hist_deblender, net = train_deblender(
    "des",
    from_survey="dc2",
    epochs=2,
    training_data_vae=training_data_vae,
    validation_data_vae=validation_data_vae,
    training_data_deblender=training_data_deblender,
    validation_data_deblender=validation_data_deblender,
    nb_of_bands=5,
    channel_last=True,
)

The arrays used above still contain six channels, so they do not match a five-band network: the training data format must be changed first. For example, keep only the first five channels:

In [None]:
# Keep only the first five channels (last axis) to build five-band, DES-like data
training_data_vae_deslike = np.array((images[:5,:,:,:5], images[:5,:,:,:5]))
validation_data_vae_deslike = np.array((images[5:,:,:,:5], images[5:,:,:,:5]))

training_data_deblender_deslike = np.array((images[:5,:,:,:5], images[:5,:,:,:5]))
validation_data_deblender_deslike = np.array((images[5:,:,:,:5], images[5:,:,:,:5]))
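Slicing with `[..., :5]` keeps the first five channels. If the six DC2 channels follow the LSST order ugrizy (an assumption this notebook does not state), that selection is ugriz, while DES observes in grizY. A hypothetical variant that picks bands by name is sketched below; since the LSST and DES filters are not identical, either choice only approximates DES data.

```python
import numpy as np

# Assumed band order of the six-channel DC2 stamps (not confirmed by the notebook)
LSST_BANDS = "ugrizy"


def select_bands(data, bands, order=LSST_BANDS):
    """Hypothetical helper: keep the named bands of channels-last data, in the requested order."""
    idx = [order.index(b) for b in bands]
    return data[..., idx]  # fancy indexing on the last axis returns a copy


demo_six = np.zeros((2, 5, 59, 59, 6))
demo_six[..., 1] = 1.0  # mark the g band (index 1 in "ugrizy")
demo_grizy = select_bands(demo_six, "grizy")
print(demo_grizy.shape)  # (2, 5, 59, 59, 5)
```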

Now try the training:

Warning: the weights of the network trained on DC2 cannot be reused here, because that network was trained on six-channel data. Set the from_survey option to None so that no pre-trained weights are loaded.
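The reason shows up in the weight shapes. Assuming the encoder starts with a 2-D convolution, whose kernel in Keras' `Conv2D` has shape (kernel_h, kernel_w, in_channels, filters), the first-layer weights depend on the number of input bands. The kernel size 3 and the 32 filters below are arbitrary placeholders, not debvader's actual architecture.

```python
def conv2d_kernel_shape(nb_of_bands, kernel_size=3, filters=32):
    """Kernel shape of a Keras-style Conv2D layer: (kernel_h, kernel_w, in_channels, filters)."""
    return (kernel_size, kernel_size, nb_of_bands, filters)


pretrained = conv2d_kernel_shape(6)  # first layer of a network trained on six-band DC2 data
target = conv2d_kernel_shape(5)      # first layer of a five-band, DES-like network

print(pretrained, target)    # (3, 3, 6, 32) (3, 3, 5, 32)
print(pretrained == target)  # False: the pre-trained weights do not fit
```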

In [None]:
hist_vae, hist_deblender, net = train_deblender(
    "des",
    from_survey=None,
    epochs=2,
    training_data_vae=training_data_vae_deslike,
    validation_data_vae=validation_data_vae_deslike,
    training_data_deblender=training_data_deblender_deslike,
    validation_data_deblender=validation_data_deblender_deslike,
    nb_of_bands=5,
    channel_last=True,
)