In [1]:
%matplotlib inline
from astropy.table import Table
import keras
from keras import layers
import keras.backend as K
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
from tensorflow.compat.v1.keras.backend import set_session

In [2]:
# TF1-style session: grow GPU memory on demand instead of grabbing it all up
# front, and log the device every op is placed on.
config = tf.compat.v1.ConfigProto(
    gpu_options=tf.compat.v1.GPUOptions(allow_growth=True),
    log_device_placement=True,
)
sess = tf.compat.v1.Session(config=config)
set_session(sess)

Device mapping:
/job:localhost/replica:0/task:0/device:XLA_CPU:0 -> device: XLA_CPU device
/job:localhost/replica:0/task:0/device:GPU:0 -> device: 0, name: Quadro M2200, pci bus id: 0000:01:00.0, compute capability: 5.2
/job:localhost/replica:0/task:0/device:XLA_GPU:0 -> device: XLA_GPU device



In [3]:
# Photometric bands used throughout, in column order of the packed arrays.
bands = 'g r i z VIS Y J H'.split()

In [4]:
# Bands whose catalogue columns follow the *_EXT_LSST_APER / lsst/<band> naming;
# every other band uses the *_APER / euclid/<band> naming.
# A set (not the string 'ugriz') so membership is per band name, not substring.
LSST_BANDS = {'u', 'g', 'r', 'i', 'z'}

# Build, per band and in `bands` order, the target flux/error column names and
# the matching reference-catalogue flux column name.
mer_fluxes, mer_errors = [], []
ref_fluxes = []
for b in bands:
    if b in LSST_BANDS:
        mer_fluxes.append(f'FLUX_{b.upper()}_EXT_LSST_APER')
        mer_errors.append(f'FLUXERR_{b.upper()}_EXT_LSST_APER')
        ref_fluxes.append(f'lsst/{b}')
    else:
        mer_fluxes.append(f'FLUX_{b.upper()}_APER')
        mer_errors.append(f'FLUXERR_{b.upper()}_APER')
        ref_fluxes.append(f'euclid/{b}')

In [5]:
# Single place to point the notebook at the input FITS catalogues.
DATA_DIR = '/home/aalvarez/Work/Data/SC8/PHZ_Prod/data'

target = Table.read(f'{DATA_DIR}/MOCK_MER_WIDE_2740.fits')
# The reference photometry table is read from HDU 1.
ref = Table.read(f'{DATA_DIR}/EUC_PHZ_REFPHOT__20200910T050620.214530Z_00.00.fits', hdu=1)

In [6]:
def prepare_cat(catalog, fluxes, fluxes_err):
    """Pack per-band flux and flux-error columns into one float32 array.

    Parameters
    ----------
    catalog : table-like
        Anything supporting ``len(catalog)`` and ``catalog[column_name]``
        (e.g. an astropy ``Table``).
    fluxes : sequence of str
        Flux column names, one per band, in output order.
    fluxes_err : sequence of str
        Flux-error column names in the same band order. May be shorter than
        ``fluxes`` (e.g. empty), in which case the remaining errors stay 0.

    Returns
    -------
    numpy.ndarray
        Shape ``(len(catalog), 2, len(fluxes))``, dtype float32. Plane 0 holds
        the fluxes, plane 1 the flux errors.
    """
    packed = np.zeros((len(catalog), 2, len(fluxes)), dtype=np.float32)
    # Fill the flux plane first, then the error plane, band by band.
    for plane, column_names in enumerate((fluxes, fluxes_err)):
        for band_pos, column in enumerate(column_names):
            packed[:, plane, band_pos] = catalog[column]
    return packed

In [7]:
# Pack (flux, error) per band for the target objects.
target_data = prepare_cat(target, mer_fluxes, mer_errors)
# No error columns are given for the reference catalogue, so its error plane is all zeros.
ref_data = prepare_cat(ref, ref_fluxes, fluxes_err=[])

In [8]:
def chi2_ref(ref_data, target):
    """Chi^2 of a single target object against every reference object.

    Parameters
    ----------
    ref_data : numpy.ndarray, shape (n_ref, 2, n_bands)
        Reference photometry: ``[:, 0, :]`` fluxes, ``[:, 1, :]`` errors.
    target : numpy.ndarray, shape (2, n_bands)
        One object's fluxes (row 0) and errors (row 1).

    Returns
    -------
    numpy.ndarray, shape (n_ref,)
        ``sum_b (f_ref - f_obj)^2 / (s_ref^2 + s_obj^2)`` per reference row.
    """
    ref_flux, ref_err = ref_data[:, 0, :], ref_data[:, 1, :]
    obj_flux, obj_err = target[0, :], target[1, :]

    residual_sq = np.square(ref_flux - obj_flux)
    variance = np.square(ref_err) + np.square(obj_err)
    return (residual_sq / variance).sum(axis=1)

In [9]:
def chi2_np(ref_data, target, out=None):
    """Vectorised chi^2 between every target object and every reference object.

    Parameters
    ----------
    ref_data : numpy.ndarray, shape (n_ref, 2, n_bands)
        Reference photometry: ``[:, 0, :]`` fluxes, ``[:, 1, :]`` errors.
    target : numpy.ndarray, reshapeable to (n_target, 2, n_bands)
        Target photometry in the same layout; a single ``(2, n_bands)``
        object is treated as a batch of one.
    out : numpy.ndarray, optional, shape (n_target, n_ref)
        Destination for the result (passed to ``np.sum``).

    Returns
    -------
    numpy.ndarray, shape (n_target, n_ref)
        ``sum_b (f_ref - f_obj)^2 / (s_ref^2 + s_obj^2)``; ``out`` if given.
    """
    # Band count comes from the data, not from the notebook-global `bands`.
    n_bands = ref_data.shape[-1]
    ref_size = ref_data.shape[0]

    # Insert singleton axes so the arithmetic broadcasts to
    # (n_target, n_ref, n_bands); memory grows with n_target * n_ref * n_bands.
    ref_b = ref_data.reshape(1, ref_size, 2, n_bands)
    target_b = target.reshape(-1, 1, 2, n_bands)

    nom = ref_b[:, :, 0, :] - target_b[:, :, 0, :]
    nom = nom * nom
    den = ref_b[:, :, 1, :] * ref_b[:, :, 1, :] + target_b[:, :, 1, :] * target_b[:, :, 1, :]
    return np.sum(nom / den, axis=-1, out=out)

In [30]:
# Upload the packed reference photometry once as a constant tensor; the
# Chi2Layer Lambda function reads it as a module-level name.
ref_photo = K.constant(value=ref_data)
ref_photo.shape

TensorShape([497533, 2, 8])

In [31]:
# Inspect the concrete tensor class created by K.constant.
ref_photo.__class__

tensorflow.python.framework.ops.EagerTensor

In [11]:
def Chi2Layer(target_photo):
    """Chi^2 of each input object against every row of the global ``ref_photo``.

    Used as the function of a ``layers.Lambda``. ``ref_photo`` is read from
    module scope, so it must already exist when the layer is built.

    Parameters
    ----------
    target_photo : tensor, reshapeable to (batch, 2, n_bands)
        Per-object fluxes (plane 0) and flux errors (plane 1).

    Returns
    -------
    tensor, shape (batch, n_ref)
        ``sum_b (f_ref - f_obj)^2 / (s_ref^2 + s_obj^2)`` per reference row.
    """
    # Take the band count from the reference tensor itself instead of the
    # global `bands` list, so the two cannot silently disagree.
    n_bands = K.int_shape(ref_photo)[-1]
    ref_size = K.shape(ref_photo)[0]

    # Insert singleton axes so the arithmetic broadcasts to
    # (batch, n_ref, n_bands) without building explicit tiles.
    tiled_ref = K.reshape(ref_photo, (1, ref_size, 2, n_bands))
    tiled_target = K.reshape(target_photo, (-1, 1, 2, n_bands))

    nom = tiled_ref[:, :, 0, :] - tiled_target[:, :, 0, :]
    nom = nom * nom
    den = (tiled_ref[:, :, 1, :] * tiled_ref[:, :, 1, :]
           + tiled_target[:, :, 1, :] * tiled_target[:, :, 1, :])

    # Sum over the band axis.
    return K.sum(nom / den, axis=-1)

In [12]:
# Each sample is one object's (flux/error x band) matrix; the batch axis is implicit.
input_layer = layers.Input(shape=target_data.shape[1:])
chi2_layer = layers.Lambda(function=Chi2Layer)(input_layer)

In [13]:
# Weight-free model: input -> chi^2 against every reference object.
model = keras.Model(input_layer, chi2_layer)

In [14]:
# Sanity check of the layer stack and output shape.
model.summary();

Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 2, 8)]            0         
_________________________________________________________________
lambda (Lambda)              (None, 497533)            0         
Total params: 0
Trainable params: 0
Non-trainable params: 0
_________________________________________________________________


In [15]:
# 'rmsprop' is compile()'s default optimizer argument, written out explicitly.
# NOTE(review): only predict() is used below and there are no trainable
# weights; this compile() may be unnecessary — confirm before removing.
model.compile(optimizer='rmsprop')

In [16]:
# Objects to benchmark (row indices into target_data) and the batch size shared by all methods.
test_idx, batch_size = np.arange(512), 8

In [17]:
%%time
gpu_D = model.predict(target_data[test_idx], batch_size=batch_size)

CPU times: user 2.53 s, sys: 373 ms, total: 2.9 s
Wall time: 2.87 s


In [18]:
%%time
np_D = np.zeros((len(test_idx), len(ref_data)))
for i in range(len(test_idx)//batch_size):
    chi2_np(ref_data, target_data[i*batch_size:(i+1)*batch_size], out=np_D[i*batch_size:(i+1)*batch_size])

CPU times: user 12.9 s, sys: 2.37 s, total: 15.3 s
Wall time: 15.3 s


In [19]:
%%time
np_I = np.zeros((len(test_idx), len(ref_data)))
for i, j in enumerate(test_idx):
    np_I[i] = chi2_ref(ref_data, target_data[i])

CPU times: user 14.8 s, sys: 2.55 s, total: 17.3 s
Wall time: 17.4 s


In [20]:
# The batched NumPy result must agree element-wise with both the naive loop and the Keras model.
print(np.allclose(np_D, np_I))
print(np.allclose(np_D, gpu_D))

True
True


In [33]:
# Speed-up: naive-loop wall time over Keras-model wall time (seconds, copied from the timings above).
(17.4) / (2.87)

6.062717770034842