# Evaluating a Trained Tikhonet

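In this notebook we evaluate the performance of a trained [Tikhonet](https://arxiv.org/pdf/1911.00443.pdf) on a MeerKAT evaluation batch and compare it with CLEAN and SCORE reconstructions in terms of pixel errors (MSE), shape errors (adaptive moments) and flux errors.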

## Required Libraries and Functions

In [1]:
%matplotlib inline
import sys

# Add library path to PYTHONPATH
lib_path = '/gpfswork/rech/xdy/uze68md/GitHub/'
path_alphatransform = lib_path+'alpha-transform'
path_score = lib_path+'score'
sys.path.insert(0, path_alphatransform)
sys.path.insert(0, path_score)
data_path = '/gpfswork/rech/xdy/uze68md/data/'
model_dir = '/gpfswork/rech/xdy/uze68md/trained_models/model_meerkat_64/'

# Helper functions
def crop_center(img, cropx, cropy):
    """Return the central ``cropy`` x ``cropx`` window of ``img``.

    The crop is taken over the last two axes (rows = y, columns = x), so both a
    single 2-D image and a stack of images with leading batch axes are accepted.
    For odd leftovers the window is shifted by ``x//2 - cropx//2`` exactly as
    before.

    Parameters
    ----------
    img : array-like with ``.shape`` and numpy-style slicing
        Image (or stack of images) to crop.
    cropx : int
        Width of the window (number of columns kept).
    cropy : int
        Height of the window (number of rows kept).

    Returns
    -------
    array
        A view of ``img`` with the last two axes of size ``(cropy, cropx)``.

    Raises
    ------
    ValueError
        If ``img`` has fewer than two dimensions, or if the requested window is
        negative or larger than the image. Without this check the start index
        goes negative and slicing silently returns a wrongly placed window.
    """
    if len(img.shape) < 2:
        raise ValueError('crop_center expects at least a 2-D array, got shape {}'.format(img.shape))
    y, x = img.shape[-2:]
    if not (0 <= cropx <= x and 0 <= cropy <= y):
        raise ValueError('cannot crop a {}x{} window from an image of size {}x{}'.format(cropy, cropx, y, x))
    startx = x // 2 - cropx // 2
    starty = y // 2 - cropy // 2
    return img[..., starty:starty + cropy, startx:startx + cropx]

# Libraries
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from scipy import fft
import cadmos_lib as cl
import tensorflow as tf
import galsim
from galsim import Image
import galsim.hsm
import pickle

## Load the Evaluation Dataset

In [2]:
# Load the evaluation batch. The context manager closes the file even if
# unpickling fails. NOTE: pickle.load can execute arbitrary code, so only load
# files from a trusted source.
with open(data_path + "meerkat_batch.pkl", "rb") as f:
    batch = pickle.load(f)

## Extract Batch Dimensions

In [3]:
# Number of galaxies and stamp dimensions of the evaluation batch; the
# three-way unpacking also fails loudly if the targets are not a 3-D stack.
n_batch, Nx, Ny = batch['targets'].shape

## Load CLEAN Results, Apply the Trained Tikhonets and Load SCORE Outputs

In [4]:
# Load the CLEAN reconstructions. The context manager closes the file even if
# unpickling fails. NOTE: pickle.load can execute arbitrary code, so only load
# files from a trusted source.
with open(data_path + "clean_results.pkl", "rb") as g:
    clean = pickle.load(g)

clean.keys()

dict_keys(['restored_residual', 'restored', 'restored_isotropic', 'residual', 'skymodel', 'skymodel_list', 'sigma_flags'])

In [5]:
# Load the two trained Tikhonet variants: one trained without an extra
# constraint ("None-constraint") and one trained with the multi-constraint,
# gamma=0.5 loss.
model_name_g0 = 'tikhonet_None-constraint_scales-4_steps-3125_epochs-10_growth_rate-12_batch_size-32_activationfunction-relu'
model_name_g05 = 'tikhonet_multi-constraint_scales-4_gamma-0.5_shearlet-3_steps-3125_epochs-10_growth_rate-12_batch_size-32_activationfunction-relu'

model_g0 = tf.keras.models.load_model(model_dir + model_name_g0, compile=False)
model_g05 = tf.keras.models.load_model(model_dir + model_name_g05, compile=False)


def _to_padded_numpy(tensor, pad=32):
    """Evaluate a model output, drop its trailing channel axis and zero-pad each image.

    Each 2-D image gets ``pad`` zero pixels on every side (presumably to match the
    uncropped stamp size of the other reconstructions -- TODO confirm).
    """
    images = tf.keras.backend.eval(tensor)[..., 0]
    return np.array([np.pad(image, pad, constant_values=0) for image in images])


# The networks take 64x64 center crops with an explicit trailing channel axis.
inputs_model = np.array([crop_center(im, 64, 64) for im in batch['inputs_tikho']])[..., np.newaxis]

res_g0 = model_g0(inputs_model)
res_g05 = model_g05(inputs_model)

res_g0 = _to_padded_numpy(res_g0)
res_g05 = _to_padded_numpy(res_g05)

# load CLEAN reconstructions: the 'restored' images and their isotropic variant
res_cl, res_cl_iso = clean['restored'], clean['restored_isotropic']

# load SCORE reconstructions for the g0 and g05 runs. Each run is stored in four
# chunk files (suffixes _1 to _4). Read all chunks and concatenate them once
# along axis 0, rather than re-allocating a growing array on every iteration.
score_chunk_ids = range(1, 5)
res_s0 = np.concatenate(
    [np.load(data_path + 'score_radio_tikho_g0_{}.npy'.format(k)) for k in score_chunk_ids],
    axis=0)
res_s2 = np.concatenate(
    [np.load(data_path + 'score_radio_tikho_g05_{}.npy'.format(k)) for k in score_chunk_ids],
    axis=0)

Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
Instructions for updating:
If using Keras pass *_constraint arguments to layers.


## Crop Images to 64x64 and Select Galaxies Using CLEAN's `sigma_flags`

In [6]:
def _crop_stack(images, size=64):
    """Center-crop every 2-D image of ``images`` to ``size`` x ``size`` and restack them."""
    return np.array([crop_center(im, size, size) for im in images])


# Bring every stamp (inputs, targets and all reconstructions) to a common 64x64 size.
for _key in ('inputs_tikho', 'inputs', 'targets'):
    batch[_key] = _crop_stack(batch[_key])
res_g0 = _crop_stack(res_g0)
res_g05 = _crop_stack(res_g05)
res_cl = _crop_stack(res_cl)
res_cl_iso = _crop_stack(res_cl_iso)
res_s0 = _crop_stack(res_s0)
res_s2 = _crop_stack(res_s2)

# Keep only the galaxies selected by CLEAN's `sigma_flags`. The original note says
# these flag galaxies where CLEAN did not detect signal -- TODO confirm whether a
# True entry means "keep" or "discard".
# NOTE: `batch` is modified in place, so re-running this cell without reloading the
# data re-applies the selection to already-filtered arrays.
_keep = clean['sigma_flags']
for _key in ('inputs_tikho', 'inputs', 'targets'):
    batch[_key] = batch[_key][_keep]
res_g0 = res_g0[_keep]
res_g05 = res_g05[_keep]
res_cl = res_cl[_keep]
res_cl_iso = res_cl_iso[_keep]
res_s0 = res_s0[_keep]
res_s2 = res_s2[_keep]

In [7]:
# cadmos_lib helper built from the stamp height and width (not used in the cells
# shown here -- presumably needed elsewhere; TODO confirm).
U = cl.makeUi(*batch['inputs_tikho'].shape[1:3])
# Stamp side in pixels (default center/stamp size for EllipticalGaussian and the
# HSM centroid guess in get_moments) and the pixel scale handed to galsim.Image.
im_size, scale = 64, 1.5

def relative_mse(solution, ground_truth):
    """Mean squared error of ``solution`` normalized by the mean squared ``ground_truth``.

    Both arguments are numpy arrays of the same shape. An all-zero
    ``ground_truth`` makes the denominator zero (inf/nan result).
    """
    residual_power = ((solution - ground_truth) ** 2).mean()
    signal_power = (ground_truth ** 2).mean()
    return residual_power / signal_power
def MSE(X1, X2, norm=False):
    """Mean squared error between arrays ``X1`` and ``X2``.

    By default this is the plain (absolute) MSE. With ``norm=True`` it is divided
    by the mean of ``X2**2``, i.e. the MSE relative to ``X2``.
    """
    error = np.mean((X1 - X2) ** 2)
    if not norm:
        return error
    return error / np.mean(X2 ** 2)

def MSE_obj(obj1, obj2, norm=False):
    """Apply :func:`MSE` pairwise to two sequences and return the results as an array.

    Pairs are formed with ``zip``, so the output length is that of the shorter input.
    """
    errors = []
    for first, second in zip(obj1, obj2):
        errors.append(MSE(first, second, norm))
    return np.array(errors)

def EllipticalGaussian(e1, e2, sig, xc=im_size//2, yc=im_size//2, stamp_size=(im_size,im_size)):
    """Evaluate an un-normalized elliptical Gaussian on a pixel grid.

    Array axis 0 is treated as ``x`` and axis 1 as ``y``. The centered
    coordinates are linearly distorted by the ellipticity before the Gaussian
    profile is applied.

    Parameters
    ----------
    e1, e2 : float
        Ellipticity components used to distort the coordinates (the caller in
        get_moments passes galsim ``observed_shape.e1/e2`` after a sign/axis
        convention fix).
    sig : float
        Gaussian width, in pixels of the grid.
    xc, yc : float, optional
        Center along axis 0 and axis 1 respectively; default to ``im_size//2``
        (evaluated once, when this function is defined).
    stamp_size : sequence of int, optional
        Output shape ``(n0, n1)``; only the first two entries are used.

    Returns
    -------
    numpy.ndarray
        Float array of shape ``(stamp_size[0], stamp_size[1])`` whose value is 1
        at ``(xc, yc)``.
    """
    n0, n1 = stamp_size[0], stamp_size[1]
    # Build the centered coordinate grids axis by axis: x[i, j] = i - xc and
    # y[i, j] = j - yc. The former np.array([np.arange(i) for i in stamp_size])
    # was ragged (and rejected by recent NumPy) whenever the stamp was not square.
    x, y = np.meshgrid(np.arange(n0) - xc, np.arange(n1) - yc, indexing='ij')
    # Linearly distort the coordinates with the ellipticity (stretch + shear;
    # the centering was already done above).
    xx = (1 - e1 / 2) * x - e2 / 2 * y
    yy = (1 + e1 / 2) * y - e2 / 2 * x
    # compute elliptical gaussian
    return np.exp(-(xx ** 2 + yy ** 2) / (2 * sig ** 2))

def get_moments(images, bool_window=False):
    """Measure HSM adaptive-moment shears for a sequence of images.

    Parameters
    ----------
    images : iterable of 2-D numpy arrays
        Images to measure. Each is wrapped in a ``galsim.Image`` with the
        module-level pixel ``scale``, and the centroid guess is
        ``(im_size//2, im_size//2)``.
    bool_window : bool, optional
        If True, also build one elliptical Gaussian window per image from the
        measured moments (via ``EllipticalGaussian``).

    Returns
    -------
    list
        ``[g, ok]``, where ``g`` is ``np.array`` of ``(g1, g2)`` observed shears
        transposed (row 0 = g1, row 1 = g2) and ``ok`` is a boolean array that is
        True when HSM reported no error message. When ``bool_window`` is True,
        two more entries follow: the stacked windows and a boolean array that is
        False only where ``moments_status == -1`` (presumably an HSM failure --
        confirm against the galsim docs).
    """
    shears, ok_flags = [], []
    windows, window_ok_flags = [], []
    for image in images:
        # Wrap the raw array so HSM knows the pixel scale.
        gs_image = galsim.Image(image, scale=scale)
        # With strict=False, galsim reports failures through `error_message`
        # instead of raising, so one bad stamp does not abort the whole loop.
        hsm = galsim.hsm.FindAdaptiveMom(
            gs_image,
            guess_centroid=galsim.PositionD(im_size // 2, im_size // 2),
            strict=False,
        )
        if bool_window:
            # Widen the measured size so the window captures more useful signal.
            k_sigma = 1.2
            # Convention fixes between galsim and EllipticalGaussian: flip the sign
            # of e1, swap x and y for the centroid, and subtract 1 so the centroid
            # is measured from a (0, 0) origin like numpy indices.
            window = EllipticalGaussian(
                -hsm.observed_shape.e1,
                hsm.observed_shape.e2,
                hsm.moments_sigma * k_sigma,
                hsm.moments_centroid.y - 1,
                hsm.moments_centroid.x - 1,
                image.shape,
            )
            windows.append(window)
            window_ok_flags.append(bool(hsm.moments_status + 1))
        shears.append(np.array([hsm.observed_shape.g1, hsm.observed_shape.g2]))
        ok_flags.append(not hsm.error_message)
    output = [np.array(shears).T, np.array(ok_flags)]
    if bool_window:
        output += [np.array(windows), np.array(window_ok_flags)]
    return output

In [8]:
# compute pixel errors: one MSE per galaxy over the whole (cropped) stamp.
# MSE_obj must receive the two image stacks so that it pairs galaxies. Calling it
# on a single galaxy pair, MSE_obj(est, true), zips over image rows and returns one
# MSE per row instead of one per galaxy.
# NOTE: data['mse_abs'] is therefore one value per galaxy; it equals the former
# per-row values averaged along axis 1.
mse_g0 = MSE_obj(res_g0, batch['targets'])
mse_g05 = MSE_obj(res_g05, batch['targets'])
mse_cl_iso = MSE_obj(res_cl_iso, batch['targets'])
mse_cl = MSE_obj(res_cl, batch['targets'])
mse_s0 = MSE_obj(res_s0, batch['targets'])
mse_s2 = MSE_obj(res_s2, batch['targets'])

# compute relative pixel errors: per-galaxy MSE divided by the mean squared target
mse_g0_rel = np.array([relative_mse(estimate, truth) for estimate, truth in zip(res_g0, batch['targets'])])
mse_g05_rel = np.array([relative_mse(estimate, truth) for estimate, truth in zip(res_g05, batch['targets'])])
mse_cl_iso_rel = np.array([relative_mse(estimate, truth) for estimate, truth in zip(res_cl_iso, batch['targets'])])
mse_cl_rel = np.array([relative_mse(estimate, truth) for estimate, truth in zip(res_cl, batch['targets'])])
mse_s0_rel = np.array([relative_mse(estimate, truth) for estimate, truth in zip(res_s0, batch['targets'])])
mse_s2_rel = np.array([relative_mse(estimate, truth) for estimate, truth in zip(res_s2, batch['targets'])])

# estimate adaptive moments: only the shears are kept for the reconstructions,
# while the targets also yield the Gaussian windows and their flags
mom_g0 = get_moments(res_g0)[0]
mom_g05 = get_moments(res_g05)[0]
mom_cl_iso = get_moments(res_cl_iso)[0]
mom_cl = get_moments(res_cl)[0]
mom_s0 = get_moments(res_s0)[0]
mom_s2 = get_moments(res_s2)[0]
mom_true, _, windows, window_flags = get_moments(batch['targets'], bool_window=True)

# estimate flux as the sum of all pixel values of each stamp (one value per galaxy)
flux_g0 = np.array([stamp.sum() for stamp in res_g0])
flux_g05 = np.array([stamp.sum() for stamp in res_g05])
flux_cl_iso = np.array([stamp.sum() for stamp in res_cl_iso])
flux_cl = np.array([stamp.sum() for stamp in res_cl])
flux_s0 = np.array([stamp.sum() for stamp in res_s0])
flux_s2 = np.array([stamp.sum() for stamp in res_s2])
flux_true = np.array([stamp.sum() for stamp in batch['targets']])

In [9]:
# compute adaptive-moment errors: estimated minus true shears, shape (2, n_galaxies)
mom_err_g0 = np.subtract(mom_g0, mom_true)
mom_err_g05 = np.subtract(mom_g05, mom_true)
mom_err_cl_iso = np.subtract(mom_cl_iso, mom_true)
mom_err_cl = np.subtract(mom_cl, mom_true)
mom_err_s0 = np.subtract(mom_s0, mom_true)
mom_err_s2 = np.subtract(mom_s2, mom_true)

# compute relative flux errors |F_est - F_true| / F_true, one value per galaxy
flux_err_g0 = abs(flux_g0 - flux_true) / flux_true
flux_err_g05 = abs(flux_g05 - flux_true) / flux_true
flux_err_cl = abs(flux_cl - flux_true) / flux_true
flux_err_cl_iso = abs(flux_cl_iso - flux_true) / flux_true
flux_err_s0 = abs(flux_s0 - flux_true) / flux_true
flux_err_s2 = abs(flux_s2 - flux_true) / flux_true

## Save Measurements

In [10]:
# Every measurement list follows the same method order, so it can be zipped
# against `methods` below.
methods = ['sparsity', 'score', 'tikhonet', 'tikhonet_sc', 'clean', 'clean_iso']
measure_names = ['flux', 'mse_abs', 'mse_rel', 'mom']
flux = [flux_s0, flux_s2, flux_g0, flux_g05, flux_cl, flux_cl_iso]
mse_abs = [mse_s0, mse_s2, mse_g0, mse_g05, mse_cl, mse_cl_iso]
mse_rel = [mse_s0_rel, mse_s2_rel, mse_g0_rel, mse_g05_rel, mse_cl_rel, mse_cl_iso_rel]
mom = [mom_s0, mom_s2, mom_g0, mom_g05, mom_cl, mom_cl_iso]
measures = [flux, mse_abs, mse_rel, mom]
# Peak value over cadmos_lib's sigma_mad noise estimate, computed on batch['inputs'].
snr = np.array([np.max(gal) / cl.sigma_mad(gal) for gal in batch['inputs']])

# data[measure_name][method] -> per-galaxy results
data = {name: dict(zip(methods, values)) for name, values in zip(measure_names, measures)}

# add remaining keys (same insertion order as the top-level keys above)
data.update(
    windows=windows,
    window_flags=window_flags,
    sigma_flags=clean['sigma_flags'],
    snr=snr,
)
data['flux']['true'] = flux_true
data['mom']['true'] = mom_true

# save dictionary; the context manager closes the file even if pickling fails
with open(data_path + "meerkat3600_data.pkl", "wb") as f:
    pickle.dump(data, f)