# Evaluating a Trained Tikhonet

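In this notebook we evaluate the performance of trained [Tikhonet](https://arxiv.org/pdf/1911.00443.pdf) models (with and without a shape constraint) on a batch of CFHT galaxies, and compare them against sparsity and SCORE reconstructions of the same batch.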

## Required Libraries and Functions

In [1]:
%matplotlib inline
import sys

import os

# Root directories. The defaults point to the author's cluster workspace; set the
# TIKHONET_LIB_PATH / TIKHONET_DATA_PATH / TIKHONET_MODEL_DIR environment variables
# to run the notebook elsewhere without editing this cell.
# os.path.join(..., '') guarantees a trailing separator, because later cells build
# file paths by plain string concatenation (e.g. data_path + "cfht_batch.pkl").
lib_path = os.path.join(os.environ.get('TIKHONET_LIB_PATH', '/gpfswork/rech/xdy/uze68md/GitHub/'), '')
data_path = os.path.join(os.environ.get('TIKHONET_DATA_PATH', '/gpfswork/rech/xdy/uze68md/data/'), '')
model_dir = os.path.join(os.environ.get('TIKHONET_MODEL_DIR', '/gpfswork/rech/xdy/uze68md/trained_models/model_cfht/'), '')

# Add the local library checkouts to the import path so the imports below can find them.
# Insertion order is kept: `score` ends up first on sys.path, `alpha-transform` second.
path_alphatransform = os.path.join(lib_path, 'alpha-transform')
path_score = os.path.join(lib_path, 'score')
sys.path.insert(0, path_alphatransform)
sys.path.insert(0, path_score)

# Libraries
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from scipy import fft
import cadmos_lib as cl
import tensorflow as tf
import galsim
from galsim import Image
import galsim.hsm
import pickle

## Load The Comparison Batch

In [2]:
# Load the comparison batch: a mapping indexed by string keys such as
# 'inputs_tikho', 'targets', 'psf_hst', 'psf_cfht' and 'mag_auto'.
# NOTE: pickle.load can execute arbitrary code -- only open trusted files.
# The `with` block closes the file even if unpickling fails.
with open(data_path + "cfht_batch.pkl", "rb") as f:
    batch = pickle.load(f)

# Correct the normalisation factor of the Tikhonov inputs. The batch is reloaded
# just above, so re-running this cell does not apply the factor twice.
norm_factor = 4e3
batch['inputs_tikho'] *= norm_factor

## Load and Apply Trained Model on Batch

In [3]:
# Tikhonet trained without shape constraint.
model_name = 'tikhonet_None-constraint_scales-4_steps-625_epochs-10_growth_rate-12_batch_size-128_activationfunction-relu'
# Tikhonet trained with the multi-window shape constraint (labelled 'Tikhonet + MW' below).
# g: gamma (trade-off parameter of the shape constraint)
model_g05_name = 'tikhonet_multi-constraint_scales-4_gamma-0.5_shearlet-3_steps-625_epochs-10_growth_rate-12_batch_size-128_activationfunction-relu'
model = tf.keras.models.load_model(model_dir+model_name, compile=False)
model_g05 = tf.keras.models.load_model(model_dir+model_g05_name, compile=False)

# The networks take a trailing channel axis; build that input once and reuse it.
tikho_inputs = np.expand_dims(batch['inputs_tikho'], axis=-1)
res = model(tikho_inputs)
res_np = tf.keras.backend.eval(res)[..., 0]          # keep channel 0 as a numpy array
res_g05 = model_g05(tikho_inputs)
res_g05_np = tf.keras.backend.eval(res_g05)[..., 0]

# Precomputed reconstructions used for comparison ('sparsity' and 'score' in the
# saved results); g0 / g1 presumably mean gamma = 0 / 1 -- TODO confirm.
score_g0 = np.load(data_path+'score_g0.npy')
score_g1 = np.load(data_path+'score_g1.npy')

# Generate the PSFs in the spatial domain from their half-plane (rfft2) Fourier representation.
psf_hst = np.fft.ifftshift(np.fft.irfft2(batch['psf_hst'][0]))
psf_tile_cfht = np.array([np.fft.ifftshift(np.fft.irfft2(p)) for p in batch['psf_cfht']])
# Tile the first HST PSF once per entry of batch['psf_hst'] (assumes they are all identical -- TODO confirm).
psf_tile_hst = np.repeat(psf_hst[np.newaxis, :, :], batch['psf_hst'].shape[0], axis=0)

Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
Instructions for updating:
If using Keras pass *_constraint arguments to layers.


## Processing and Analyzing Results

### Define Error Metrics

In [4]:
im_size = 64   # default stamp size (pixels per side)
scale = 0.1    # pixel scale passed to the galsim images built below

def EllipticalGaussian(e1, e2, sig, xc=im_size//2, yc=im_size//2, stamp_size=(im_size,im_size)):
    """Return an unnormalised elliptical Gaussian stamp (value 1 at the centre).

    The pixel grid is centred on (xc, yc), where `xc` refers to the first (row)
    axis and `yc` to the second (column) axis. The centred coordinates are then
    transformed to first order by the ellipticity (e1, e2), and an isotropic
    Gaussian of width `sig` is evaluated on the transformed coordinates.

    Parameters
    ----------
    e1, e2 : float
        Ellipticity components.
    sig : float
        Gaussian standard deviation, in pixels.
    xc, yc : float, optional
        Centre along the first and second axis (default: stamp centre).
    stamp_size : tuple of int, optional
        Output shape (n_rows, n_cols); non-square stamps are supported.

    Returns
    -------
    numpy.ndarray
        Array of shape `stamp_size`.
    """
    # Centred grid built by broadcasting. Stacking the two aranges into one
    # array (as before) fails for non-square stamps, whose ranges differ in length.
    x, y = np.meshgrid(np.arange(stamp_size[0]) - xc,
                       np.arange(stamp_size[1]) - yc,
                       indexing='ij')
    # First-order ellipticity transform of the coordinates.
    xx = (1 - e1/2)*x - e2/2*y
    yy = (1 + e1/2)*y - e2/2*x
    return np.exp(-(xx ** 2 + yy ** 2) / (2 * sig ** 2))

def relative_mse(solution, ground_truth):
    """Mean squared error normalised by the mean power of the ground truth.

    Returns mean((solution - ground_truth)**2) / mean(ground_truth**2).
    """
    residual = solution - ground_truth
    return (residual ** 2).mean() / (ground_truth ** 2).mean()



def get_KSB_ell(image, psf, return_flag=False):
    """Estimate the PSF-corrected shape of one galaxy with GalSim's KSB method.

    Parameters
    ----------
    image, psf : 2-D numpy arrays
        Galaxy stamp and PSF stamp; both are wrapped as galsim images with
        pixel scale `scale`.
    return_flag : bool, optional
        If True, also return a success flag that is True when GalSim reported
        no error. The default (False) keeps the original single return value.

    Returns
    -------
    galsim.hsm.ShapeData, or (ShapeData, bool) when `return_flag` is True.
    """
    # Create galsim versions of the data.
    image_galsim = Image(image, scale=scale)
    psf_galsim = Image(psf, scale=scale)
    # strict=False: a failed fit is reported through `error_message` instead of raising.
    ell = galsim.hsm.EstimateShear(image_galsim,
                                   psf_galsim, shear_est='KSB',
                                   guess_centroid=galsim.PositionD(im_size//2, im_size//2),
                                   strict=False)
    if return_flag:
        return ell, ell.error_message == ''
    return ell

def get_KSB_g(images,psfs):
    """Compute KSB PSF-corrected reduced shears for a batch of galaxies.

    Parameters
    ----------
    images, psfs : iterables of 2-D numpy arrays
        Galaxy stamps and their PSFs, paired element-wise (zip truncates to
        the shorter one).

    Returns
    -------
    g : numpy.ndarray, shape (2, n)
        Row 0 holds `corrected_g1`, row 1 holds `corrected_g2`.
    ok : numpy.ndarray of bool, shape (n,)
        True where GalSim reported no error for that galaxy.
    """
    g_list, ok_list = [], []
    for image, psf in zip(images, psfs):
        # Single-galaxy KSB fit (strict=False: errors land in `error_message`).
        shape = get_KSB_ell(image, psf)
        g_list.append([shape.corrected_g1, shape.corrected_g2])
        ok_list.append(not shape.error_message)
    # reshape(-1, 2) keeps the (2, n) layout even when the batch is empty.
    return np.array(g_list).reshape(-1, 2).T, np.array(ok_list, dtype=bool)

def _moments_window(shape, stamp_shape, k_sigma):
    """Build an elliptical Gaussian window matched to the HSM moments in `shape`."""
    # Convention fixes between GalSim and EllipticalGaussian:
    #  - the sign of e1 is swapped;
    #  - x and y are swapped and the centroid is shifted by -1 so the origin is at (0, 0).
    return EllipticalGaussian(-1. * shape.observed_shape.e1, shape.observed_shape.e2,
                              shape.moments_sigma * k_sigma,
                              shape.moments_centroid.y - 1, shape.moments_centroid.x - 1,
                              stamp_shape)

def get_moments(images, bool_window=False, k_sigma=1.2):
    """Measure adaptive moments of each image with GalSim HSM.

    Parameters
    ----------
    images : iterable of 2-D numpy arrays
    bool_window : bool, optional
        If True, also build an elliptical Gaussian window matched to each
        image's measured moments.
    k_sigma : float, optional
        Factor applied to the measured `moments_sigma` when building the
        window. Values > 1 enlarge the window so it captures more of the useful
        signal. The default, 1.2, is the previously hard-coded value.

    Returns
    -------
    list
        ``[g, ok]``, where `g` has shape (2, n) (observed g1, g2) and `ok` is a
        boolean array that is True where GalSim reported no error. When
        `bool_window` is True, two more entries are appended: the stacked
        windows (one per image, each of that image's shape) and a boolean array
        derived from `moments_status` (False where the status equals -1).
    """
    g_list, ok_list = [], []
    window_list, window_flag_list = [], []
    for image in images:
        image_galsim = galsim.Image(image, scale=scale)
        # strict=False: a failed fit is reported through `error_message` instead of raising.
        shape = galsim.hsm.FindAdaptiveMom(image_galsim,
                                           guess_centroid=galsim.PositionD(im_size//2, im_size//2),
                                           strict=False)
        g_list.append([shape.observed_shape.g1, shape.observed_shape.g2])
        ok_list.append(not shape.error_message)
        if bool_window:
            window_list.append(_moments_window(shape, image.shape, k_sigma))
            window_flag_list.append(bool(shape.moments_status + 1))
    # reshape(-1, 2) keeps the (2, n) layout even when `images` is empty.
    output = [np.array(g_list).reshape(-1, 2).T, np.array(ok_list, dtype=bool)]
    if bool_window:
        output += [np.array(window_list), np.array(window_flag_list, dtype=bool)]
    return output

def g_to_e(g1,g2):
    """Convert a reduced shear (g1, g2) into an ellipticity tuple (e1, e2).

    GalSim's e1 is negated so the result follows this notebook's sign convention.
    """
    shear_obj = galsim.Shear(g1=g1, g2=g2)
    return (-shear_obj.e1, shear_obj.e2)

def MSE(X1,X2,norm=False):
    """Mean squared error between X1 and X2.

    With `norm=True` the result is divided by mean(X2**2), i.e. a relative MSE;
    otherwise the plain MSE is returned.
    """
    denominator = np.mean(X2**2) if norm else 1
    return np.mean((X1 - X2) ** 2) / denominator

def MSE_obj(obj1,obj2,norm=False):
    """Apply `MSE` to each pair of objects taken from obj1 and obj2.

    Pairs are formed with zip (truncating to the shorter input); the result is
    a 1-D numpy array with one MSE value per pair.
    """
    per_object = []
    for first, second in zip(obj1, obj2):
        per_object.append(MSE(first, second, norm))
    return np.array(per_object)

### Estimate Adaptive Moments

In [5]:
# Estimate adaptive moments (observed g1, g2; shape (2, n)) of every reconstruction
# and of the HST targets. The per-galaxy error flags are discarded here; windows
# and their flags are only built for the targets.
mom_g0,_ = get_moments(res_np)
mom_s0,_ = get_moments(score_g0)
mom_s1,_ = get_moments(score_g1)
mom_g05,_ = get_moments(res_g05_np)
mom_hst,_,windows, window_flags = get_moments(batch['targets'],bool_window=True)

def total_flux(images):
    """Return the sum of pixel values of each image as a 1-D numpy array."""
    return np.array([gal.sum() for gal in images])

# Estimate flux (total pixel sum) of each galaxy.
flux_g0 = total_flux(res_np)
flux_g05 = total_flux(res_g05_np)
flux_s0 = total_flux(score_g0)
flux_s1 = total_flux(score_g1)
flux_true = total_flux(batch['targets'])

### Estimate Relative Pixel, Moment and Flux Errors

In [6]:
def relative_mse_per_galaxy(estimates, weights=None):
    """Apply `relative_mse` to each (estimate, target) pair of the batch.

    If `weights` is given, both the estimate and the target are multiplied by
    the matching weight map first (windowed error). Returns a 1-D numpy array.
    """
    targets = batch['targets']
    if weights is None:
        pairs = ((est, true) for true, est in zip(targets, estimates))
    else:
        pairs = ((est * w, true * w) for true, est, w in zip(targets, estimates, weights))
    return np.array([relative_mse(est, true) for est, true in pairs])

def relative_flux_error(flux):
    """Absolute relative flux error |flux - flux_true| / flux_true, per galaxy."""
    return np.abs(flux - flux_true) / flux_true

# Compute relative pixel errors.
mse_g0 = relative_mse_per_galaxy(res_np)
mse_s0 = relative_mse_per_galaxy(score_g0)
mse_s1 = relative_mse_per_galaxy(score_g1)
mse_g05 = relative_mse_per_galaxy(res_g05_np)

# Compute windowed relative pixel errors, using the windows fitted to the targets.
mse_g0_w = relative_mse_per_galaxy(res_np, windows)
mse_s0_w = relative_mse_per_galaxy(score_g0, windows)
mse_s1_w = relative_mse_per_galaxy(score_g1, windows)
mse_g05_w = relative_mse_per_galaxy(res_g05_np, windows)

# Compute adaptive-moment errors (shape (2, n): g1 and g2 differences).
mom_err_g0 = mom_g0-mom_hst
mom_err_s0 = mom_s0-mom_hst
mom_err_s1 = mom_s1-mom_hst
mom_err_g05 = mom_g05-mom_hst

# Compute relative flux errors.
flux_err_g0 = relative_flux_error(flux_g0)
flux_err_g05 = relative_flux_error(flux_g05)
flux_err_s0 = relative_flux_error(flux_s0)
flux_err_s1 = relative_flux_error(flux_s1)

## Save measurements

In [7]:
# Each measure lists its per-method results in the same order as `methods`.
flux = [flux_s0, flux_s1, flux_g0, flux_g05]
mse = [mse_s0, mse_s1, mse_g0, mse_g05]
mse_w = [mse_s0_w, mse_s1_w, mse_g0_w, mse_g05_w]
mom = [mom_s0, mom_s1, mom_g0, mom_g05]
measures = [flux, mse, mse_w, mom]
measure_names = ['flux', 'mse', 'mse_w', 'mom']
methods = ['sparsity', 'score', 'tikhonet', 'tikhonet_sc']
# Guard against a method being added to one list but not the others.
assert all(len(measure) == len(methods) for measure in measures)

# Fill the dictionary: data[measure_name][method] -> per-galaxy results.
data = {name: dict(zip(methods, measure)) for name, measure in zip(measure_names, measures)}

# Add the remaining keys.
data['windows'] = windows
data['window_flags'] = window_flags
data['flux']['true'] = flux_true
data['mom']['true'] = mom_hst
data['mag_auto'] = batch['mag_auto']

# Save the dictionary; the `with` block closes the file even if pickling fails.
with open(data_path + "cfht_data.pkl", "wb") as f:
    pickle.dump(data, f)

### Compute Errors per Bin

In [8]:
# Plot labels and colours per method: blue shades for sparsity / SCORE,
# green shades for the two Tikhonets.
label_s0, label_s1 = r'Sparsity', r'SCORE'
label_g0, label_g05 = r'Tikhonet', r'Tikhonet + MW'

color_g0, color_g05 = 'green', 'darkgreen'
color_s0, color_s1 = 'blue', 'darkblue'