Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Nsevilla #241

Open
wants to merge 9 commits into
base: SRV
Choose a base branch
from
220 changes: 157 additions & 63 deletions descqa/srv_shear.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,12 @@
from scipy.stats import norm, binned_statistic
import time
import sys
import treecorr

from .base import BaseValidationTest, TestResult
from .plotting import plt
import matplotlib.transforms as mtrans


if 'mpi4py' in sys.modules:
from mpi4py import MPI
Expand Down Expand Up @@ -42,6 +45,9 @@ def shear_from_moments(Ixx,Ixy,Iyy,kind='eps'):
denom = Ixx + Iyy
return (Ixx-Iyy)/denom, 2*Ixy/denom

def size_from_moments(Ixx, Iyy):
return Ixx + Iyy


class CheckEllipticity(BaseValidationTest):
"""
Expand All @@ -66,6 +72,16 @@ def __init__(self, **kwargs):
self.IyyPSF= kwargs.get('IyyPSF')
self.psf_fwhm = kwargs.get('psf_fwhm')
self.bands = kwargs.get('bands')
self.compute_rowe_flag = kwargs.get('rowe', True)
self.treecorr_config = {
"min_sep": 0.5,
"max_sep": 250.0,
"nbins": 20,
"bin_slop": 0.01,
"sep_units": "arcmin",
"psf_size_units": "sigma",
}


if not any((
self.catalog_filters,
Expand Down Expand Up @@ -162,7 +178,7 @@ def plot_ellipticities_band(self,e1,e2,band,output_dir):

def plot_psf(self,fwhm,band,output_dir):
'''
Plot elliptiticies for each band
Plot PSF for each band
'''
plt.figure()
bins = np.linspace(0.,1.5,201)
Expand All @@ -178,8 +194,101 @@ def plot_psf(self,fwhm,band,output_dir):
plt.close()
return

def plot_e1e2_residuals(self,e1,e2,epsf1,epsf2,band,output_dir):
'''
Plot e1,e2 residuals with respect to model
'''
plt.figure()
bins = np.linspace(-1.,1.,201)
bins_mid = (bins[1:]+bins[:-1])/2.
e1residual_out, bin_edges = np.histogram(e1-epsf1, bins=bins)
e2residual_out, bin_edges = np.histogram(e2-epsf2, bins=bins)
self.record_result((0,'e1e2_residuals_'+band),'e1e2_residuals_'+band,'e1e2_residuals_'+band+'.png')
plt.plot(bins_mid,e1residual_out,'b',label='e1-e1psf')
plt.plot(bins_mid,e2residual_out,'r--',label='e2-e2psf')
plt.title('e1/e2 residuals vs model, band '+band)
plt.legend()
plt.savefig(os.path.join(output_dir, 'e1e2_residuals_'+band+'.png'))
plt.close()
return

def plot_Tfrac_residuals(self,T,Tpsf,band,output_dir):
'''
Plot T fractional residuals with respect to model
'''
plt.figure()
bins = np.linspace(-0.1,0.1,201)
bins_mid = (bins[1:]+bins[:-1])/2.
tresidual_out, bin_edges = np.histogram((T-Tpsf)/Tpsf, bins=bins)
self.record_result((0,'tfrac_residuals_'+band),'tfrac_residuals_'+band,'tfrac_residuals_'+band+'.png')
plt.plot(bins_mid,tresidual_out,'b',label='(T-Tpsf)/Tpsf')
plt.title('T fractional residuals vs model, band '+band)
plt.legend()
plt.savefig(os.path.join(output_dir, 'tfrac_residuals_'+band+'.png'))
plt.close()
return

def plot_rowe_stats(self,rowe_stats,output_dir):
'''
Plot Rowe coefficients
'''
filename = output_dir + "rowe134.png"
ax = plt.subplot(1, 1, 1)
fig = plt.figure()
for j, i in enumerate([1, 3, 4]):
theta, xi, err = rowe_stats[i]
tr = mtrans.offset_copy(
ax.transData, fig, 0.05 * (j - 1), 0, units="inches"
)
plt.errorbar(
theta,
abs(xi),
err,
fmt=".",
label=rf"$\rho_{i}$",
capsize=3,
)
# transform=tr,
#)
plt.bar(0.0, 2e-05, width=5, align="edge", color="gray", alpha=0.2)
plt.bar(5, 1e-07, width=245, align="edge", color="gray", alpha=0.2)
plt.xscale("log")
plt.yscale("log")
plt.xlabel(r"$\theta$")
plt.ylabel(r"$\xi_+(\theta)$")
plt.legend()
plt.savefig(filename)
plt.close()

filename = output_dir + "rowe25.png"
ax = plt.subplot(1, 1, 1)
fig = plt.figure()
for j, i in enumerate([2, 5]):
theta, xi, err = rowe_stats[i]
tr = mtrans.offset_copy(
ax.transData, fig, 0.05 * j - 0.025, 0, units="inches"
)
plt.errorbar(
theta,
abs(xi),
err,
fmt=".",
label=rf"$\rho_{i}$",
capsize=3,
)
# transform=tr,
#)
plt.bar(0.0, 2e-05, width=5, align="edge", color="gray", alpha=0.2)
plt.bar(5, 1e-07, width=245, align="edge", color="gray", alpha=0.2)
plt.xscale("log")
plt.yscale("log")
plt.xlabel(r"$\theta$")
plt.ylabel(r"$\xi_+(\theta)$")
plt.legend()
plt.savefig(filename)
plt.close()


return

def generate_summary(self, output_dir, aggregated=False):
if aggregated:
Expand Down Expand Up @@ -214,6 +323,16 @@ def generate_summary(self, output_dir, aggregated=False):
self._individual_header.clear()
self._individual_table.clear()

def compute_rowe(self, i, ra, dec, q1, q2):
n = len(ra)
print(f"Computing Rowe statistic rho_{i} from {n} objects")

corr = treecorr.GGCorrelation(self.treecorr_config)
cat1 = treecorr.Catalog(ra=np.array(ra), dec=np.array(dec), g1=np.nan_to_num(q1[0], copy=True, nan=0.0, posinf=None, neginf=None), g2=np.nan_to_num(q1[1], copy=True, nan=0.0, posinf=None, neginf=None), ra_units="deg", dec_units="deg")
cat2 = treecorr.Catalog(ra=np.array(ra), dec=np.array(dec), g1=np.nan_to_num(q2[0], copy=True, nan=0.0, posinf=None, neginf=None), g2=np.nan_to_num(q2[1], copy=True, nan=0.0, posinf=None, neginf=None), ra_units="deg", dec_units="deg")
corr.process(cat1, cat2) #error is here
return corr.meanr, corr.xip, corr.varxip**0.5

def run_on_single_catalog(self, catalog_instance, catalog_name, output_dir):

all_quantities = sorted(map(str, catalog_instance.list_all_quantities(True)))
Expand Down Expand Up @@ -247,15 +366,11 @@ def run_on_single_catalog(self, catalog_instance, catalog_name, output_dir):
label_tot=[]
plots_tot=[]

#quantities=[self.ra,self.dec,self.Ixx,self.Iyy,self.Ixy,self.IxxPSF, self.IyyPSF, self.IxyPSF]

# doing everything per-band first of all
for band in self.bands:
quantities=[self.Ixx+'_'+band,self.Iyy+'_'+band,self.Ixy+'_'+band,self.IxxPSF+'_'+band, self.IyyPSF+'_'+band, self.IxyPSF+'_'+band, self.psf_fwhm+'_'+band]
quantities=[self.Ixx+'_'+band,self.Iyy+'_'+band,self.Ixy+'_'+band,self.IxxPSF+'_'+band, self.IyyPSF+'_'+band, self.IxyPSF+'_'+band, self.psf_fwhm+'_'+band, self.ra, self.dec]
quantities = tuple(quantities)



# reading in the data
if len(filters) > 0:
catalog_data = catalog_instance.get_quantities(quantities,filters=filters,return_iterator=False)
Expand All @@ -280,73 +395,52 @@ def run_on_single_catalog(self, catalog_instance, catalog_name, output_dir):

e1,e2 = shear_from_moments(recvbuf[self.Ixx+'_'+band],recvbuf[self.Ixy+'_'+band],recvbuf[self.Iyy+'_'+band])
e1psf,e2psf = shear_from_moments(recvbuf[self.IxxPSF+'_'+band],recvbuf[self.IxyPSF+'_'+band],recvbuf[self.IyyPSF+'_'+band])
T = size_from_moments(recvbuf[self.Ixx+'_'+band],recvbuf[self.Iyy+'_'+band])
Tpsf = size_from_moments(recvbuf[self.IxxPSF+'_'+band],recvbuf[self.IyyPSF+'_'+band])
T_f = (T-Tpsf)/Tpsf
de1 = e1-e1psf
de2 = e2-e2psf

Ixx = recvbuf[self.Ixx+'_'+band]
Iyy = recvbuf[self.Iyy+'_'+band]
Ixy = recvbuf[self.Ixy+'_'+band]
fwhm = recvbuf[self.psf_fwhm+'_'+band]


self.plot_moments_band(Ixx,Ixy,Iyy,band,output_dir)
self.plot_ellipticities_band(e1,e2,band,output_dir)
self.plot_psf(fwhm,band,output_dir)
self.plot_e1e2_residuals(e1,e2,e1psf,e2psf,band,output_dir)
self.plot_Tfrac_residuals(T,Tpsf,band,output_dir)

#rowe_stats = np.empty([6])
#rowe_stats[0] = 0. #dummy value, so that in the rest of the array the index
#coincides with the usual Rowe number
rowe_stats = {}

print(type(recvbuf))
print(recvbuf[self.ra])

if self.compute_rowe_flag:
print("Computing Rowe_1")
rowe_stats[1] = self.compute_rowe(1, recvbuf[self.ra], recvbuf[self.dec], (de1,de2), (de1,de2))
print(rowe_stats[1])
print("Computing Rowe_2")
rowe_stats[2] = self.compute_rowe(2, recvbuf[self.ra], recvbuf[self.dec], (e1,e2), (de1,de2))
print(rowe_stats[2])
print("Computing Rowe_3")
rowe_stats[3] = self.compute_rowe(3, recvbuf[self.ra], recvbuf[self.dec], (e1,e2) * T_f, (e1,e2) * T_f)
print(rowe_stats[3])
print("Computing Rowe_4")
rowe_stats[4] = self.compute_rowe(4, recvbuf[self.ra], recvbuf[self.dec], (de1,de2), (e1,e2) * T_f)
print(rowe_stats[4])
print("Computing Rowe_5")
rowe_stats[5] = self.compute_rowe(5, recvbuf[self.ra], recvbuf[self.dec], (e1,e2), (e1,e2) * T_f)
print(rowe_stats[5])

print("Now plotting")
self.plot_rowe_stats(rowe_stats,output_dir)


# plot moments directly per filter. For good, star, galaxy
# FWHM of the psf
# calculate ellpiticities and make sure they're alright
# look at different bands
# note that we want to look by magnitude or SNR to understand the longer tail in moments
# PSF ellipticity whisker plot?
# look at what validate_drp is

# s1/s2 plots

# look at full ellipticity distribution test as well
#https://github.com/LSSTDESC/descqa/blob/master/descqa/EllipticityDistribution.py
#DC2 validation github - PSF ellipticity
# https://github.com/LSSTDESC/DC2-analysis/blob/master/validation/Run_1.2p_PSF_tests.ipynb

#https://github.com/LSSTDESC/DC2-analysis/blob/master/validation/DC2_calexp_src_validation_1p2.ipynb

# Look at notes here:
#https://github.com/LSSTDESC/DC2-production/issues/340

# get PSF FWHM directly from data, note comments on here:
# https://github.com/LSSTDESC/DC2-analysis/blob/u/wmwv/DR6_dask_refactor/validation/validate_dc2_run2.2i_object_table_dask.ipynb about focussing of the "telescope"


'''mask_finite = np.isfinite(e1)&np.isfinite(e2)
bs_out = bs(e1[mask_finite],values = e2[mask_finite],bins=100,statistic='mean')
plt.figure()
quantity_hashes[0].add('s1s2')
self.record_result((0,'s1s2'),'s1s2','p_s1s2.png')
plt.plot(bs_out[1][1:],bs_out[0])
plt.savefig(os.path.join(output_dir, 'p_s1s2.png'))
plt.close()


plt.figure()
quantity_hashes[0].add('s1')
self.record_result((0,'s1'),'s1','p_s1.png')
#plt.hist(e1,bins=np.linspace(-1.,1.,100))
plt.hist(e1psf,bins=100)#np.linspace(-1.,1.,100))
plt.savefig(os.path.join(output_dir, 'p_s1.png'))
plt.close()
plt.figure()
quantity_hashes[0].add('s2')
self.record_result((0,'s2'),'s2','p_s2.png')
#plt.hist(e2,bins=np.linspace(-1.,1.,100))
plt.hist(e2psf,bins=100)#np.linspace(-1.,1.,100))
plt.savefig(os.path.join(output_dir, 'p_s2.png'))
plt.close()'''
'''plt.figure()
quantity_hashes[0].add('s12')
self.record_result((0,'s12'),'s12','p_s12.png')
plt.hist2d(e1,e2,bins=100)
plt.savefig(os.path.join(output_dir, 'p_s12.png'))
plt.close()'''

if rank==0:
self.generate_summary(output_dir)
else:
Expand Down