In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import tensorflow as tf
import tensorflow_datasets as tfds
import numpy as np
from sfh.datasets.mergers import kinetic

In [5]:
# Using a mapping function to apply preprocessing to our data
def preprocessing(example):
    """Convert one raw dataset example into an ``(image, label)`` pair.

    The three planes ``example['image'][0]``, ``[1]`` and ``[2]`` are
    rescaled independently, cleaned of non-finite values, and stacked
    along a new trailing axis.

    * Plane 0 is compressed with ``asinh(x / (scale * a)) / a``.
    * Planes 1 and 2 are standardised as ``(x - mean) / scale``.

    Args:
        example: mapping with at least the keys ``'image'`` (indexed here as
            ``[k, :, :]`` for ``k`` in 0..2) and ``'last_major_merger'``.

    Returns:
        Tuple ``(img, lbt)`` where ``img`` is the stack of the three
        processed planes on the last axis and ``lbt`` is
        ``example['last_major_merger'] / 13.6``.
    """
    asinh_softening = 3.
    # Per-plane normalisation constants. The mean of plane 0 is listed but
    # not used: plane 0 goes through the asinh transform instead.
    # NOTE(review): presumably dataset-wide statistics -- confirm their origin.
    channel_scale = [33205.43648882585, 43.65749262641268, 37.643232661037565]
    channel_mean = [193696.55, 2.9481435, 95.614685]

    def _zero_non_finite(plane):
        # NaN and +/-Inf entries become 0; finite entries pass through.
        return tf.where(tf.math.is_finite(plane), plane, tf.zeros_like(plane))

    raw = example['image']

    # Plane 0: asinh compression.
    inv_softening = tf.constant(1/asinh_softening, dtype=tf.float32)
    asinh_denominator = tf.constant(channel_scale[0]*asinh_softening, dtype=tf.float32)
    planes = [inv_softening * tf.math.asinh(raw[0, :, :] / asinh_denominator)]

    # Planes 1 and 2: subtract the mean, divide by the scale.
    for idx in (1, 2):
        offset = tf.constant(channel_mean[idx], dtype=tf.float32)
        spread = tf.constant(channel_scale[idx], dtype=tf.float32)
        planes.append((raw[idx, :, :] - offset) / spread)

    img = tf.stack([_zero_non_finite(plane) for plane in planes], axis=-1)

    # NOTE(review): 13.6 looks like a normalisation by a maximum lookback
    # time -- confirm the unit of 'last_major_merger'.
    lbt = example['last_major_merger'] / tf.constant(13.6)

    return (img, lbt)

In [6]:
def input_fn(mode='train', batch_size=1,
             data_dir='/gpfsscratch/rech/qrc/commun/tensorflow_datasets'):
    """Build the batched ``tf.data`` pipeline for the ``mergers_kinetic`` dataset.

    The TFDS ``'train'`` split is sliced by ``mode``:

    * ``'train'``      -> ``train[:80%]``, repeated indefinitely and shuffled.
    * ``'validation'`` -> ``train[80%:90%]``, shuffled, single pass.
    * anything else    -> ``train[90%:]`` (test), no shuffle, single pass.

    Every example goes through ``preprocessing``. The number of examples in
    the selected slice is printed before any repeat or batching is applied.

    Args:
        mode: ``'train'``, ``'validation'`` or ``'test'``. Any value other
            than the first two selects the test slice.
        batch_size: number of examples per batch; incomplete final batches
            are dropped.
        data_dir: directory in which TFDS looks for the prepared dataset.
            Defaults to the shared cluster location used originally.

    Returns:
        A ``tf.data.Dataset`` yielding batched ``(image, label)`` pairs.
    """
    # Pick the slice and the pipeline options for the requested mode.
    if mode == 'train':
        split, repeat_forever, shuffle = 'train[:80%]', True, True
    elif mode == 'validation':
        split, repeat_forever, shuffle = 'train[80%:90%]', False, True
    else:
        split, repeat_forever, shuffle = 'train[90%:]', False, False

    dataset = tfds.load('mergers_kinetic', split=split, data_dir=data_dir)
    dataset = dataset.map(preprocessing)  # Apply data preprocessing
    print(len(dataset))

    # Repeat is applied before shuffle, so the shuffle buffer may mix
    # examples from consecutive passes over the training data.
    if repeat_forever:
        dataset = dataset.repeat()
    if shuffle:
        dataset = dataset.shuffle(1000)

    dataset = dataset.batch(batch_size, drop_remainder=True)
    dataset = dataset.prefetch(-1)  # fetch next batches while training current one (-1 for autotune)

    return dataset