In [9]:
import os
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import pkg_resources

from debvader.training.train import train_deblender

### Download and format data

Load the dataset that will be used for training. It was generated using the code in https://github.com/BastienArcelin/dc2_img_generation, and the stamps are generated using the XXX function (see notebook XXX). The stamp size is fixed to 59x59 pixels.

In [10]:
data_folder_path = pkg_resources.resource_filename('debvader', "data/")
image_path = os.path.join(data_folder_path, 'dc2_imgs', 'imgs_dc2.npy')
images = np.load(image_path, mmap_mode = 'c')

In [11]:
images.shape

(10, 59, 59, 6)

Separate the training data and labels to feed the VAE and the deblender. Make sure that the number of filters in the training data corresponds to the number of bands the network expects (the default is six).

In [12]:
training_data_vae = np.array((images[:5], images[:5]))
validation_data_vae = np.array((images[5:], images[5:]))

training_data_deblender = np.array((images[:5], images[:5]))
validation_data_deblender = np.array((images[5:], images[5:]))
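
The layout of these arrays can be checked before training. The sketch below uses random data in place of the real stamps and assumes, as the cell above suggests, that index 0 along the first axis holds the inputs and index 1 the targets. `check_band_count` is a hypothetical helper written for this illustration, not part of debvader.

```python
import numpy as np

# Stand-in for the real stamps: 10 images, 59x59 pixels, 6 bands (channels last).
rng = np.random.default_rng(0)
fake_images = rng.normal(size=(10, 59, 59, 6)).astype(np.float32)

# Same construction as above: stack (inputs, targets) along a new first axis.
training_pair = np.array((fake_images[:5], fake_images[:5]))
validation_pair = np.array((fake_images[5:], fake_images[5:]))


def check_band_count(data, nb_of_bands, channel_last=True):
    """Hypothetical helper: raise if the band axis does not match nb_of_bands.

    `data` is expected to have shape (2, n_stamps, ...), i.e. inputs and
    targets stacked along the first axis.
    """
    # Channels last: (2, n, H, W, C). Channels first: (2, n, C, H, W).
    band_axis = -1 if channel_last else 2
    found = data.shape[band_axis]
    if found != nb_of_bands:
        raise ValueError(
            f"data has {found} bands but the network expects {nb_of_bands}"
        )
    return found


print(training_pair.shape)  # (2, 5, 59, 59, 6)
check_band_count(training_pair, nb_of_bands=6)
```

Running such a check up front surfaces a band mismatch before any model is built.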

### Train the VAE and deblender with architecture from the debvader paper (https://arxiv.org/abs/2005.12039)

In [None]:
hist_vae, hist_deblender, net = train_deblender("lsst",
                                                      from_survey = "dc2", 
                                                      epochs = 2, 
                                                      training_data_vae = training_data_vae, 
                                                      validation_data_vae = validation_data_vae, 
                                                      training_data_deblender = training_data_deblender, 
                                                      validation_data_deblender = validation_data_deblender)

2023-09-13 18:58:04.809342: I metal_plugin/src/device/metal_device.cc:1154] Metal device set to: Apple M2 Max
2023-09-13 18:58:04.809381: I metal_plugin/src/device/metal_device.cc:296] systemMemory: 64.00 GB
2023-09-13 18:58:04.809391: I metal_plugin/src/device/metal_device.cc:313] maxCacheSize: 24.00 GB
2023-09-13 18:58:04.809451: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:303] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
2023-09-13 18:58:04.809479: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:269] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)


in cropping
VAE model
Model: "model_2"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_3 (InputLayer)        [(None, 59, 59, 6)]       0         
                                                                 
 model (Functional)          (None, 560)               3741224   
                                                                 
 multivariate_normal_tri_l   ((None, 32),              0         
 (MultivariateNormalTriL)     (None, 32))                        
                                                                 
 model_1 (Functional)        (None, 59, 59, 6)         4577228   
                                                                 
Total params: 8318452 (31.73 MB)
Trainable params: 8318440 (31.73 MB)
Non-trainable params: 12 (48.00 Byte)
_________________________________________________________________
/Users/aubourg/Documents/Astro/LSST/debvader/src/debvader/dat

2023-09-13 18:58:06.748823: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.


### Dataset with a different number of filters

The number of available filters differs from survey to survey. For example, five filters are available for DES data. To train on such data, specify the number of bands in the train_deblender function (nb_of_bands), and indicate whether the channels appear last or first in the data array (channel_last).

In [6]:
training_data_vae.shape

(2, 5, 59, 59, 6)

Here the channels appear last in the data array, so the channel_last option in the train_deblender function is set to True (default setting).
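
If the stamps were stored with channels first instead, one option is to reorder the axes so that the channels come last. This is a minimal sketch with NumPy on synthetic data; the array names and the channels-first shape are made up for illustration.

```python
import numpy as np

# Synthetic channels-first stamps: (n_stamps, bands, height, width).
stamps_channel_first = np.arange(
    10 * 6 * 59 * 59, dtype=np.float32
).reshape(10, 6, 59, 59)

# Move the band axis (axis 1) to the end: (n_stamps, height, width, bands).
stamps_channel_last = np.moveaxis(stamps_channel_first, 1, -1)

print(stamps_channel_last.shape)  # (10, 59, 59, 6)
```

`np.moveaxis` returns a view of the input, so no pixel data is copied.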

In [7]:
hist_vae, hist_deblender, net = train_deblender("des",
                                                      from_survey = "dc2", 
                                                      epochs = 2, 
                                                      training_data_vae = training_data_vae, 
                                                      validation_data_vae = validation_data_vae, 
                                                      training_data_deblender = training_data_deblender, 
                                                      validation_data_deblender = validation_data_deblender,
                                                      nb_of_bands = 5,
                                                      channel_last = True)

in cropping
VAE model
Model: "model_6"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_6 (InputLayer)         [(None, 59, 59, 5)]       0         
_________________________________________________________________
model_4 (Model)              (None, 560)               3740932   
_________________________________________________________________
multivariate_normal_tri_l_1  ((None, 32), (None, 32))  0         
_________________________________________________________________
model_5 (Model)              (None, 59, 59, 5)         4576650   
Total params: 8,317,582
Trainable params: 8,317,572
Non-trainable params: 10
_________________________________________________________________
The number of bands in the data does not correspond to the number of filters in the network. Correct this before starting again.


ValueError: 

We need to change the training data format. For example, train only on the first five channels.

In [8]:
training_data_vae_deslike = np.array((images[:5,:,:,:5], images[:5,:,:,:5]))
validation_data_vae_deslike = np.array((images[5:,:,:,:5], images[5:,:,:,:5]))

training_data_deblender_deslike = np.array((images[:5,:,:,:5], images[:5,:,:,:5]))
validation_data_deblender_deslike = np.array((images[5:,:,:,:5], images[5:,:,:,:5]))
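
Taking the first five channels is only one choice. Which channels best match the target survey depends on the band order of the stamps, which is not stated here. The sketch below selects an arbitrary subset of bands by index on synthetic data; `select_bands` and the index list are illustrative assumptions, not part of debvader.

```python
import numpy as np


def select_bands(stamps, band_indices):
    """Hypothetical helper: keep only the given bands of channels-last stamps."""
    return np.take(stamps, band_indices, axis=-1)


# Synthetic stand-in for the (10, 59, 59, 6) image array.
fake_images = np.random.default_rng(1).normal(size=(10, 59, 59, 6))

# Example: drop the first band and keep the remaining five (an arbitrary choice).
five_band = select_bands(fake_images, [1, 2, 3, 4, 5])
training_pair = np.array((five_band[:5], five_band[:5]))

print(training_pair.shape)  # (2, 5, 59, 59, 5)
```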

Now try the training:

In [9]:
data_folder_path = pkg_resources.resource_filename('debvader', "data/")
path_output = os.path.join(data_folder_path, 'weights/dc2/not_normalised/')
latest = tf.train.latest_checkpoint(path_output)

Warning: we cannot use the weights from the network trained on DC2, because that network expects data with six channels. We need to set the from_survey option to None.

In [10]:
hist_vae, hist_deblender, net = train_deblender("des",
                                                      from_survey = None, 
                                                      epochs = 2, 
                                                      training_data_vae = training_data_vae_deslike, 
                                                      validation_data_vae = validation_data_vae_deslike, 
                                                      training_data_deblender = training_data_deblender_deslike, 
                                                      validation_data_deblender = validation_data_deblender_deslike,
                                                      nb_of_bands = 5,
                                                      channel_last = True)

in cropping
VAE model
Model: "model_10"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_9 (InputLayer)         [(None, 59, 59, 5)]       0         
_________________________________________________________________
model_8 (Model)              (None, 560)               3740932   
_________________________________________________________________
multivariate_normal_tri_l_2  ((None, 32), (None, 32))  0         
_________________________________________________________________
model_9 (Model)              (None, 59, 59, 5)         4576650   
Total params: 8,317,582
Trainable params: 8,317,572
Non-trainable params: 10
_________________________________________________________________

Start the training
Train on 5 samples, validate on 5 samples
Epoch 1/2

Epoch 00001: val_mse improved from inf to 0.09295, saving model to /pbs/throng/lsst/users/bbiswas/debvader/debvader/data/weights/des/vae/val_mse/weigh