-
Notifications
You must be signed in to change notification settings - Fork 30
/
CorrelationsTwoPoint.py
869 lines (718 loc) · 36.4 KB
/
CorrelationsTwoPoint.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
from __future__ import print_function, division, unicode_literals, absolute_import
import os
from collections import defaultdict
import re
import numpy as np
import scipy.special as scsp
import treecorr
import healpy as hp
from sklearn.cluster import k_means
from GCR import GCRQuery
from .base import BaseValidationTest, TestResult
from .plotting import plt
from .utils import (generate_uniform_random_ra_dec_footprint,
get_healpixel_footprint,
generate_uniform_random_dist)
__all__ = ['CorrelationsAngularTwoPoint', 'CorrelationsProjectedTwoPoint',
'DEEP2StellarMassTwoPoint']
def redshift2dist(z, cosmology):
    """Convert redshift(s) into comoving distance in units of Mpc/h.

    Parameters
    ----------
    z : float array like
        Redshift value(s) to convert.
    cosmology : astropy.cosmology instance
        Cosmology used to evaluate the comoving distance.

    Returns
    -------
    float array like of comoving distances in Mpc/h
    """
    # Evaluate the comoving distance in Mpc, then scale by h to obtain
    # the h=1 units conventionally used by the validation data sets.
    distance_mpc = cosmology.comoving_distance(z).to('Mpc').value
    return distance_mpc * cosmology.h
class CorrelationUtilities(BaseValidationTest):
    """
    Base class for Correlation classes that loads catalogs, cuts a catalog
    sample, plots the correlation results, and scores the the results of the
    correlation measurements by comparing them to test data.

    Init of the function takes in a loaded yaml file containing the settings
    for this tests. See the following file for an example:
    descqa/configs/tpcf_Zehavi2011_rSDSS.yaml
    """
    # pylint: disable=super-init-not-called,abstract-method

    def __init__(self, **kwargs):
        self.test_name = kwargs['test_name']
        self.requested_columns = kwargs['requested_columns']
        self.test_samples = kwargs['test_samples']
        self.test_sample_labels = kwargs['test_sample_labels']
        self.Mag_units = kwargs.get('Mag_units', None)
        self.output_filename_template = kwargs['output_filename_template']

        # Validation data file has a 2-line header.
        validation_filepath = os.path.join(self.data_dir, kwargs['data_filename'])
        self.validation_data = np.loadtxt(validation_filepath, skiprows=2)
        self.data_label = kwargs['data_label']
        self.test_data = kwargs['test_data']

        # Figure options.
        self.fig_xlabel = kwargs['fig_xlabel']
        self.fig_ylabel = kwargs['fig_ylabel']
        self.fig_ylim = kwargs.get('fig_ylim', None)
        self.fig_subplots_nrows, self.fig_subplots_ncols = kwargs.get('fig_subplots', (1, 1))
        self.fig_subplot_groups = kwargs.get('fig_subplot_groups', [None])
        self.fig_xlim = kwargs.get('fig_xlim', None)
        self.tick_size = kwargs.get('tick_size', 12)
        self.mask_large_errors = kwargs.get('mask_large_errors', False)

        # treecorr configuration shared by all correlation runs.
        self.treecorr_config = {
            'min_sep': kwargs['min_sep'],
            'max_sep': kwargs['max_sep'],
            'bin_size': kwargs['bin_size'],
        }
        if kwargs.get('var_method', None):
            self.treecorr_config['var_method'] = kwargs['var_method']
        self.npatch = kwargs.get('npatch', 1)
        self.random_nside = kwargs.get('random_nside', 1024)
        self.random_mult = kwargs.get('random_mult', 3)

        # jackknife errors
        self.jackknife = kwargs.get('jackknife', False)
        if self.jackknife:
            self.N_jack = kwargs.get('N_jack', 30)
            jackknife_quantities = kwargs.get('jackknife_quantities',
                                              {'ra': ['ra', 'ra_true'], 'dec': ['dec', 'dec_true']})
            # ra/dec are required to build jackknife regions; request them if absent.
            if 'ra' not in self.requested_columns or 'dec' not in self.requested_columns:
                self.requested_columns.update(jackknife_quantities)
            self.use_diagonal_only = kwargs.get('use_diagonal_only', True)

        # Radial range over which the comparison with validation data is scored.
        self.r_validation_min = kwargs.get('r_validation_min', 1)
        self.r_validation_max = kwargs.get('r_validation_max', 10)

        self.truncate_cat_name = kwargs.get('truncate_cat_name', False)
        self.title_in_legend = kwargs.get('title_in_legend', False)
        self.font_size = kwargs.get('font_size', 16)
        self.legend_size = kwargs.get('legend_size', 10)
        self.survey_label = kwargs.get('survey_label', '')
        self.no_title = kwargs.get('no_title', False)
        self.legend_title = kwargs.get('legend_title', '')

    @staticmethod
    def load_catalog_data(catalog_instance, requested_columns, test_samples, h=1):
        """ Load requested columns from a Generic Catalog Reader instance and
        trim to the min and max of the requested cuts in test_samples.

        Parameters
        ----------
        catalog_instance : a Generic Catalog object.
        requested_columns : dictionary of lists of strings
            A dictionary keyed on a simple column name (e.g. mag, z)
            with values of lists containing string names to try to load from
            the GCR catalog instance.
            Example:
                {Mag': ['Mag_true_r_sdss_z0', 'Mag_true_r_des_z0'], ...}
        test_samples : dictionary of dictionaries
            Dictionaries containing simple column names and min max values to
            cut on.
            Examples:
                {'Mr_-23_-22": {'Mag': {'min': -23, 'max': -22}
                                'z': {'min': 0.1031, 'max': 0.2452}}
        h : float
            Hubble parameter used to shift magnitudes to h=1 units.

        Returns
        -------
        GRC catalog instance containing simplified column names and cut to the
        min/max of all requested test samples, or None if any requested
        quantity is unavailable in the catalog.
        """
        colnames = dict()
        for col_key, possible_names in requested_columns.items():
            colnames[col_key] = catalog_instance.first_available(*possible_names)
        # Bail out if any requested quantity could not be resolved.
        if not all(v for v in colnames.values()):
            return None

        col_value_mins = defaultdict(list)
        col_value_maxs = defaultdict(list)
        Mag_shift = 5*np.log10(h)  # Magnitude shift to adjust for h=1 units in data (eg Zehavi et. al.)
        print('Magnitude shift for h={:.2f} = {:.2f}'.format(h, Mag_shift))

        # Collect the min/max over all test samples so a single (loose) cut
        # can be applied when reading the catalog.
        for conditions in test_samples.values():
            for col_key, condition in conditions.items():
                if not isinstance(condition, dict):
                    continue
                if 'min' in condition:
                    col_value_mins[col_key].append(condition['min'])
                if 'max' in condition:
                    col_value_maxs[col_key].append(condition['max'])

        filters = [(np.isfinite, c) for c in colnames.values()]
        if catalog_instance.has_quantity('extendedness'):
            filters.append('extendedness == 1')
        # can remove ultra-faint synthetics if present in catalog by cutting on negative halo_id
        for col_key, col_name in colnames.items():
            if col_key in col_value_mins and col_value_mins[col_key]:
                min_value = min(col_value_mins[col_key]) + Mag_shift if 'Mag' in col_key else min(col_value_mins[col_key])
                filters.append('{} >= {}'.format(col_name, min_value))
            if col_key in col_value_maxs and col_value_maxs[col_key]:
                max_value = max(col_value_maxs[col_key]) + Mag_shift if 'Mag' in col_key else max(col_value_maxs[col_key])
                filters.append('{} < {}'.format(col_name, max_value))
        print('Catalog filters:', filters)
        catalog_data = catalog_instance.get_quantities(list(colnames.values()), filters=filters)
        # Rekey on the simple column names for downstream use.
        catalog_data = {k: catalog_data[v] for k, v in colnames.items()}
        return catalog_data

    @staticmethod
    def create_test_sample(catalog_data, test_sample, h=1):
        """ Select a subset of the catalog data an input test sample.

        This function should be overloaded in inherited classes for more
        complex cuts (e.g. color cuts).

        Parameters
        ----------
        catalog_data : a GenericCatalogReader catalog instance
        test_sample : dictionary of dictionaries
            A dictionary specifying the columns to cut on and the min/max values of
            the cut.
            Example:
                {Mag: {min: -23, max: -22}
                 z: {min: 0.1031, max: 0.2452}}
        h : float
            Hubble parameter used to shift magnitudes to h=1 units.

        Returns
        -------
        A GenericCatalogReader catalog instance cut to the requested bounds.
        """
        filters = []
        Mag_shift = 5*np.log10(h)  # Magnitude shift to adjust for h=1 units in data (eg Zehavi et. al.)
        for key, condition in test_sample.items():
            if isinstance(condition, dict):
                if 'max' in condition:
                    max_value = condition['max'] + Mag_shift if 'Mag' in key else condition['max']
                    filters.append('{} < {}'.format(key, max_value))
                if 'min' in condition:
                    min_value = condition['min'] + Mag_shift if 'Mag' in key else condition['min']
                    filters.append('{} >= {}'.format(key, min_value))
            else:  # customized filter
                if 'Mag_shift' in condition:
                    condition = re.sub('Mag_shift', '{:0.2f}'.format(Mag_shift), condition)
                    print('Substituted filter to adjust for Mag shifts: {}'.format(condition))
                filters.append(condition)
        print('Test sample filters for {}'.format(test_sample), filters)
        return GCRQuery(*filters).filter(catalog_data)

    def plot_data_comparison(self, corr_data, catalog_name, output_dir):
        """ Plot measured correlation functions and compare them against test
        data.

        Parameters
        ----------
        corr_data : list of float array likes
            List containing resultant data from correlation functions computed
            in the test.
            Example:
                [[np.array([...]), np.array([...]), np.array([...])], ...]
        catalog_name : string
            Name of the catalog used in the test.
        output_dir : string
            Full path of the directory to write results to.
        """
        # pylint: disable=no-member
        fig_xsize = 5 if self.fig_subplots_ncols == 1 else 7  # widen figure for subplots
        fig_ysize = 5 if self.fig_subplots_ncols == 1 else 4  # narrow y-axis for subplots
        fig, ax_all = plt.subplots(self.fig_subplots_nrows, self.fig_subplots_ncols, squeeze=False,
                                   figsize=(min(2, self.fig_subplots_ncols)*fig_xsize,
                                            min(2, self.fig_subplots_nrows)*fig_ysize))
        for nx, (ax, this_group) in enumerate(zip(ax_all.flat, self.fig_subplot_groups)):
            if this_group is None:
                # No explicit grouping: plot all test samples on this axis.
                this_group = self.test_samples
            colors = plt.cm.plasma_r(np.linspace(0.1, 1, len(this_group)))
            if not this_group:
                ax.set_visible(False)
                continue
            for sample_name, color in zip(this_group, colors):
                cat_data = True
                try:
                    sample_corr = corr_data[sample_name]
                except KeyError:
                    # No catalog measurement for this sample; still plot the
                    # validation data below.
                    cat_data = False
                sample_data = self.test_data[sample_name]
                sample_label = self.test_sample_labels.get(sample_name)
                ax.loglog(self.validation_data[:, 0],
                          self.validation_data[:, sample_data['data_col']],
                          c=color,
                          label=' '.join([self.survey_label, sample_label]))
                if 'data_err_col' in sample_data:
                    y1 = (self.validation_data[:, sample_data['data_col']] +
                          self.validation_data[:, sample_data['data_err_col']])
                    y2 = (self.validation_data[:, sample_data['data_col']] -
                          self.validation_data[:, sample_data['data_err_col']])
                    if self.fig_ylim is not None:
                        # Keep the lower error band on the log axis.
                        y2[y2 <= 0] = self.fig_ylim[0]*0.9
                    ax.fill_between(self.validation_data[:, 0], y1, y2, lw=0, color=color, alpha=0.25)
                if cat_data:
                    if self.mask_large_errors and self.fig_ylim is not None:
                        # Hide points whose lower error bound falls off the plot.
                        mask = (sample_corr[1] - sample_corr[2]) > min(self.fig_ylim)
                    else:
                        mask = np.ones(len(sample_corr[1]), dtype=bool)
                    ax.errorbar(sample_corr[0][mask], sample_corr[1][mask], sample_corr[2][mask],
                                label=' '.join([catalog_name, sample_label]),
                                marker='o', ls='', c=color)
            self.decorate_plot(ax, catalog_name, n=nx)
        fig.tight_layout()
        fig.subplots_adjust(hspace=0, wspace=0)
        fig.savefig(os.path.join(output_dir, '{:s}.png'.format(self.test_name)), bbox_inches='tight')
        plt.close(fig)

    def get_legend_title(self, test_samples, exclude='mstellar'):
        """ Assemble a legend title summarizing the cuts shared by all test
        samples, skipping any filter whose name contains `exclude`.
        """
        legend_title = ''
        # Unique filter names across all samples (e.g. 'z', 'Mag').
        filter_ids = list(set([k for v in test_samples.values() for k in v.keys() if exclude not in k]))
        for filter_id in filter_ids:
            legend_title = self.get_legend_subtitle(test_samples, filter_id=filter_id, legend_title=legend_title)
        return legend_title

    @staticmethod
    def get_legend_subtitle(test_samples, filter_id='z', legend_title=''):
        """ Append a '<min> < filter_id < <max>' fragment for one filter to the
        running legend title and return the result.
        """
        legend_title = legend_title if len(legend_title) == 0 else '{}; '.format(legend_title)
        min_values = [test_samples[k][filter_id].get('min', None) for k in test_samples if test_samples[k].get(filter_id, None) is not None]
        max_values = [test_samples[k][filter_id].get('max', None) for k in test_samples if test_samples[k].get(filter_id, None) is not None]
        min_title = ''
        if len(min_values) > 0 and any([k is not None for k in min_values]):
            min_title = '{} < {}'.format(min([k for k in min_values if k is not None]), filter_id)
        max_title = ''
        if len(max_values) > 0 and any([k is not None for k in max_values]):
            max_values = [k for k in max_values if k is not None]
            max_title = '${} < {}$'.format(filter_id, max(max_values)) if len(min_title) == 0 else '${} < {}$'.format(min_title, max(max_values))
        return legend_title + max_title

    def decorate_plot(self, ax, catalog_name, n=0):
        """
        Decorates plot with axes labels, title, etc.
        """
        title = '{} vs. {}'.format(catalog_name, self.data_label)
        lgnd_title = None
        if self.title_in_legend:
            lgnd_title = self.get_legend_title(self.test_samples) if not self.legend_title else self.legend_title
        ax.legend(loc='lower left', fontsize=self.legend_size, title=lgnd_title)
        ax.tick_params(labelsize=self.tick_size)
        # check for multiple subplots and label
        if n+1 >= self.fig_subplots_ncols*(self.fig_subplots_nrows - 1):
            # Bottom row: show x tick labels and the x-axis label.
            ax.tick_params(labelbottom=True)
            for axlabel in ax.get_xticklabels():
                axlabel.set_visible(True)
            ax.set_xlabel(self.fig_xlabel, size=self.font_size)
        else:
            for axlabel in ax.get_xticklabels():
                axlabel.set_visible(False)
        if self.fig_ylim is not None:
            ax.set_ylim(*self.fig_ylim)
        if self.fig_xlim is not None:
            ax.set_xlim(*self.fig_xlim)
        # suppress labels for multiple subplots
        if n % self.fig_subplots_ncols == 0:  # 1st column
            ax.set_ylabel(self.fig_ylabel, size=self.font_size)
        else:
            for axlabel in ax.get_yticklabels():
                axlabel.set_visible(False)
        if not self.no_title:
            ax.set_title(title, fontsize='medium')

    @staticmethod
    def score_and_test(corr_data):  # pylint: disable=unused-argument
        """ Given the resultant correlations, compute the test score and return
        a TestResult

        Parameters
        ----------
        corr_data : list of float array likes
            List containing resultant data from correlation functions computed
            in the test.
            Example:
                [[np.array([...]), np.array([...]), np.array([...])], ...]

        Returns
        -------
        descqa.TestResult
        """
        # Base class performs no quantitative scoring; subclasses may override.
        return TestResult(inspect_only=True)

    @staticmethod
    def get_jackknife_randoms(N_jack, catalog_data, generate_randoms, ra='ra', dec='dec'):
        """
        Computes the jackknife regions and random catalogs for each region

        Parameters
        ----------
        N_jack : number of regions
        catalog_data : input catalog
        generate_randoms: function to generate randoms (eg self.generate_processed_randoms)

        Returns
        -------
        jack_labels: array of regions in catalog data
        randoms: dict of randoms labeled by region
        """
        # cluster the (ra, dec) positions into N_jack spatial regions
        nn = np.stack((catalog_data[ra], catalog_data[dec]), axis=1)
        _, jack_labels, _ = k_means(n_clusters=N_jack, random_state=0, X=nn)
        randoms = {}
        for nj in range(N_jack):
            # Leave-one-region-out sample for this jackknife realization.
            catalog_data_jk = dict(zip(catalog_data.keys(), [v[(jack_labels != nj)] for v in catalog_data.values()]))
            rand_cat, rr = generate_randoms(catalog_data_jk)  # get randoms for this footprint
            randoms[str(nj)] = {'ran': rand_cat, 'rr': rr}
        return jack_labels, randoms

    def get_jackknife_errors(self, N_jack, catalog_data, sample_conditions, r, xi, jack_labels, randoms,
                             run_treecorr, diagonal_errors=True):
        """
        Computes jacknife errors

        Parameters
        ----------
        N_jack : number of regions
        catalog_data : input catalog
        sample_conditions : sample selections
        r : r data for full region
        xi : correlation data for full region
        jack_labels: array of regions in catalog data
        randoms: dict of randoms labeled by region
        run_treecorr: method to run treecorr
        diagonal_errors : bool
            If True only compute the diagonal of the covariance matrix.

        Returns
        --------
        covariance : covariance matrix
        """
        # run treecorr for jackknife regions
        Nrbins = len(r)
        # NOTE: use the builtin float; the np.float alias was removed in NumPy 1.24.
        Njack_array = np.zeros((N_jack, Nrbins), dtype=float)
        print(sample_conditions)
        for nj in range(N_jack):
            catalog_data_jk = dict(zip(catalog_data.keys(),
                                       [v[(jack_labels != nj)] for v in catalog_data.values()]))
            tmp_catalog_data = self.create_test_sample(catalog_data_jk, sample_conditions)  # apply sample cut
            # run treecorr
            _, Njack_array[nj], _ = run_treecorr(catalog_data=tmp_catalog_data,
                                                 treecorr_rand_cat=randoms[str(nj)]['ran'],
                                                 rr=randoms[str(nj)]['rr'],
                                                 output_file_name=None)

        # Standard jackknife covariance estimate with (N-1)/N normalization.
        covariance = np.zeros((Nrbins, Nrbins))
        for i in range(Nrbins):
            if diagonal_errors:
                for njack in Njack_array:
                    covariance[i][i] += (N_jack - 1.)/N_jack * (xi[i] - njack[i]) ** 2
            else:
                for j in range(Nrbins):
                    for njack in Njack_array:
                        covariance[i][j] += (N_jack - 1.)/N_jack * (xi[i] - njack[i]) * (xi[j] - njack[j])
        return covariance

    def check_footprint(self, catalog_data):
        """ Estimate the catalog footprint area in square degrees from the set
        of occupied healpixels at `self.random_nside`.
        """
        pix_footprint = get_healpixel_footprint(catalog_data['ra'],
                                                catalog_data['dec'], self.random_nside)
        # full-sky area (sq. deg.) scaled by the occupied pixel fraction
        area_footprint = 4.*np.pi*(180./np.pi)**2*len(pix_footprint)/hp.nside2npix(self.random_nside)
        return area_footprint
class CorrelationsAngularTwoPoint(CorrelationUtilities):
    """
    Validation test for an angular 2pt correlation function.
    """
    def __init__(self, **kwargs):
        super(CorrelationsAngularTwoPoint, self).__init__(**kwargs)
        # Angular correlations: great-circle separations, measured in degrees.
        self.treecorr_config['metric'] = 'Arc'
        self.treecorr_config['sep_units'] = 'deg'
        print(self.legend_title)

    def generate_processed_randoms(self, catalog_data):
        """ Create and process random data for the 2pt correlation function.

        Parameters
        ----------
        catalog_data : dict

        Returns
        -------
        tuple of (random catalog treecorr.Catalog instance,
                  processed treecorr.NNCorrelation on the random catalog)
        """
        # Randoms are drawn uniformly over the catalog's healpixel footprint,
        # oversampled by self.random_mult relative to the data.
        rand_ra, rand_dec = generate_uniform_random_ra_dec_footprint(
            catalog_data['ra'].size * self.random_mult,
            get_healpixel_footprint(catalog_data['ra'], catalog_data['dec'], self.random_nside),
            self.random_nside,
        )
        rand_cat = treecorr.Catalog(ra=rand_ra, dec=rand_dec, ra_units='deg', dec_units='deg',
                                    npatch=self.npatch,
                                    )
        # Pre-compute the RR pair counts once; reused for every sample.
        rr = treecorr.NNCorrelation(**self.treecorr_config)
        rr.process(rand_cat)
        return rand_cat, rr

    def run_treecorr(self, catalog_data, treecorr_rand_cat, rr, output_file_name):
        """ Run treecorr on input catalog data and randoms.

        Produce measured correlation functions using the Landy-Szalay
        estimator.

        Parameters
        ----------
        catalog_data : a GCR catalog instance
        treecorr_rand_cat : treecorr.Catalog
            Catalog of random positions over the same portion of sky as the
            input catalog_data.
        rr : treecorr.NNCorrelation
            A processed NNCorrelation of the input random catalog.
        output_file_name : string
            Full path name of the file to write the resultant correlation to.

        Returns
        -------
        tuple of array likes
            Resultant correlation function. (separation, amplitude, amp_err).
        """
        cat = treecorr.Catalog(
            ra=catalog_data['ra'],
            dec=catalog_data['dec'],
            ra_units='deg',
            dec_units='deg',
            npatch=self.npatch,
        )
        dd = treecorr.NNCorrelation(**self.treecorr_config)
        dr = treecorr.NNCorrelation(**self.treecorr_config)
        rd = treecorr.NNCorrelation(**self.treecorr_config)
        dd.process(cat)
        dr.process(treecorr_rand_cat, cat)
        rd.process(cat, treecorr_rand_cat)
        # output_file_name is None for jackknife realizations; skip writing then.
        if output_file_name is not None:
            dd.write(output_file_name, rr, dr, rd)
        # Landy-Szalay estimator: (DD - DR - RD + RR) / RR.
        xi, var_xi = dd.calculateXi(rr, dr, rd)
        xi_rad = np.exp(dd.meanlogr)  # mean separation per log bin
        xi_sig = np.sqrt(var_xi)
        return xi_rad, xi, xi_sig

    def run_on_single_catalog(self, catalog_instance, catalog_name, output_dir):
        catalog_data = self.load_catalog_data(catalog_instance=catalog_instance,
                                              requested_columns=self.requested_columns,
                                              test_samples=self.test_samples)
        if not catalog_data:
            cols = [i for c in self.requested_columns.values() for i in c]
            return TestResult(skipped=True,
                              summary='Missing requested quantities {}'.format(', '.join(cols)))
        if self.truncate_cat_name:
            # Keep only the part of the catalog name before the first underscore.
            catalog_name = re.split('_', catalog_name)[0]
        rand_cat, rr = self.generate_processed_randoms(catalog_data)  # assumes ra and dec exist

        with open(os.path.join(output_dir, 'galaxy_count.dat'), 'a') as f:
            f.write('Total (= catalog) Area = {:.1f} sq. deg.\n'.format(self.check_footprint(catalog_data)))
            f.write('NOTE: 1) assuming catalog is of equal depth over the full area\n')
            f.write(' 2) assuming sample contains enough galaxies to measure area\n')

        if self.jackknife:  # evaluate randoms for jackknife footprints
            jack_labels, randoms = self.get_jackknife_randoms(self.N_jack, catalog_data,
                                                              self.generate_processed_randoms)

        correlation_data = dict()
        for sample_name, sample_conditions in self.test_samples.items():
            tmp_catalog_data = self.create_test_sample(
                catalog_data, sample_conditions)
            if not len(tmp_catalog_data['ra']):
                # Empty sample after cuts; nothing to correlate.
                continue
            output_treecorr_filepath = os.path.join(
                output_dir, self.output_filename_template.format(sample_name))
            xi_rad, xi, xi_sig = self.run_treecorr(
                catalog_data=tmp_catalog_data,
                treecorr_rand_cat=rand_cat,
                rr=rr,
                output_file_name=output_treecorr_filepath)
            # jackknife errors
            if self.jackknife:
                covariance = self.get_jackknife_errors(self.N_jack, catalog_data, sample_conditions,
                                                       xi_rad, xi, jack_labels, randoms,
                                                       self.run_treecorr,
                                                       diagonal_errors=self.use_diagonal_only)
                # Replace the treecorr errors with the jackknife estimate.
                xi_sig = np.sqrt(np.diag(covariance))
            correlation_data[sample_name] = (xi_rad, xi, xi_sig)
        self.plot_data_comparison(corr_data=correlation_data,
                                  catalog_name=catalog_name,
                                  output_dir=output_dir)
        return self.score_and_test(correlation_data)
class CorrelationsProjectedTwoPoint(CorrelationUtilities):
    """
    Validation test for an radial 2pt correlation function.
    """
    def __init__(self, **kwargs):
        super(CorrelationsProjectedTwoPoint, self).__init__(**kwargs)
        # Line-of-sight integration limit (pi_max) per test sample.
        self.pi_maxes = kwargs['pi_maxes']
        # Rperp: separations perpendicular to the line of sight (3D catalogs).
        self.treecorr_config['metric'] = 'Rperp'

    def run_on_single_catalog(self, catalog_instance, catalog_name, output_dir):
        # h=1 adjustment only when magnitudes are specified in h=1 units.
        h = catalog_instance.cosmology.H(0).value/100 if self.Mag_units == 'h1' else 1
        catalog_data = self.load_catalog_data(catalog_instance=catalog_instance,
                                              requested_columns=self.requested_columns,
                                              test_samples=self.test_samples, h=h)
        if not catalog_data:
            return TestResult(skipped=True, summary='Missing requested quantities')

        if self.truncate_cat_name:
            catalog_name = re.split('_', catalog_name)[0]

        # Random angular positions over the catalog footprint, shared by all samples.
        rand_ra, rand_dec = generate_uniform_random_ra_dec_footprint(
            catalog_data['ra'].size*self.random_mult,
            get_healpixel_footprint(catalog_data['ra'], catalog_data['dec'], self.random_nside),
            self.random_nside,
        )

        correlation_data = dict()
        for sample_name, sample_conditions in self.test_samples.items():
            output_treecorr_filepath = os.path.join(
                output_dir, self.output_filename_template.format(sample_name))
            tmp_catalog_data = self.create_test_sample(
                catalog_data, sample_conditions, h=h)
            with open(os.path.join(output_dir, 'galaxy_count.dat'), 'a') as f:
                f.write('{} {}\n'.format(sample_name, len(tmp_catalog_data['ra'])))
            if not len(tmp_catalog_data['ra']):
                # Empty sample after cuts; skip it.
                continue
            xi_rad, xi, xi_sig = self.run_treecorr_projected(
                catalog_data=tmp_catalog_data,
                rand_ra=rand_ra,
                rand_dec=rand_dec,
                cosmology=catalog_instance.cosmology,
                pi_max=self.pi_maxes[sample_name],
                output_file_name=output_treecorr_filepath)
            correlation_data[sample_name] = (xi_rad, xi, xi_sig)
        self.plot_data_comparison(corr_data=correlation_data,
                                  catalog_name=catalog_name,
                                  output_dir=output_dir)
        return self.score_and_test(correlation_data)

    def run_treecorr_projected(self, catalog_data, rand_ra, rand_dec,
                               cosmology, pi_max, output_file_name):
        """ Run treecorr on input catalog data and randoms.

        Produce measured correlation functions using the Landy-Szalay
        estimator.

        Parameters
        ----------
        catalog_data : a GCR catalog instance
        rand_ra : float array like
            Random RA positions on the same sky as covered by catalog data.
        rand_dec : float array like
            Random DEC positions on the same sky as covered by catalog data.
        cosmology : astropy.cosmology
            An astropy.cosmology instance specifying the catalog cosmology.
        pi_max : float
            Maximum comoving distance along the line of sight to correlate.
        output_file_name : string
            Full path name of the file to write the resultant correlation to.
            May be None, in which case no file is written.

        Returns
        -------
        tuple of array likes
            Resultant correlation function. (separation, amplitude, amp_err).
        """
        # Copy so the per-sample pi_max does not leak into the shared config.
        treecorr_config = self.treecorr_config.copy()
        treecorr_config['min_rpar'] = -pi_max
        treecorr_config['max_rpar'] = pi_max

        cat = treecorr.Catalog(
            ra=catalog_data['ra'],
            dec=catalog_data['dec'],
            ra_units='deg',
            dec_units='deg',
            npatch=self.npatch,
            r=redshift2dist(catalog_data['z'], cosmology),
        )

        # Randoms span the data's redshift (distance) range uniformly.
        z_min = catalog_data['z'].min()
        z_max = catalog_data['z'].max()
        rand_cat = treecorr.Catalog(
            ra=rand_ra,
            dec=rand_dec,
            ra_units='deg',
            dec_units='deg',
            npatch=self.npatch,
            r=generate_uniform_random_dist(
                rand_ra.size, *redshift2dist(np.array([z_min, z_max]), cosmology)),
        )

        dd = treecorr.NNCorrelation(treecorr_config)
        dr = treecorr.NNCorrelation(treecorr_config)
        rd = treecorr.NNCorrelation(treecorr_config)
        rr = treecorr.NNCorrelation(treecorr_config)
        dd.process(cat)
        dr.process(rand_cat, cat)
        rd.process(cat, rand_cat)
        rr.process(rand_cat)

        # Guard against output_file_name=None (e.g. jackknife realizations),
        # consistent with CorrelationsAngularTwoPoint.run_treecorr.
        if output_file_name is not None:
            dd.write(output_file_name, rr, dr, rd)
        # Landy-Szalay estimator: (DD - DR - RD + RR) / RR.
        xi, var_xi = dd.calculateXi(rr, dr, rd)
        xi_rad = np.exp(dd.meanlogr)
        xi_sig = np.sqrt(var_xi)

        # Convert xi into wp(rp) by multiplying by the LOS integration length.
        return xi_rad, xi * 2. * pi_max, xi_sig * 2. * pi_max
class DEEP2StellarMassTwoPoint(CorrelationsProjectedTwoPoint):
    """ Test simulated data against the power laws fits to Stellar Mass
    selected samples in DEEP2. This class also serves as an example of creating
    a specific test from the two correlation classes in the test suite.

    In the future this could also include a color cut, however absolute U and B
    band magnitudes are not stored in the simulated catalogs currently and
    converting the current fluxes to those is currently out of scope.
    """
    @staticmethod
    def power_law(r, r0, g):
        """ Compute the power law of a simple 2 parameter projected correlation
        function.

        Parameters
        ---------
        r : float array like
            Comoving positions to compute the power law at.
        r0 : float
            Amplitude of the correlation function
        g : float
            Power law of the correlation function.

        Returns
        -------
        float array like
        """
        # wp(rp) = rp * (r0/rp)^g * Gamma(1/2) Gamma((g-1)/2) / Gamma(g/2),
        # the analytic projection of xi(r) = (r/r0)^-g.
        gamma_func_ratio = scsp.gamma(1/2.) * scsp.gamma((g - 1) / 2) / scsp.gamma(g / 2)
        return r * (r0 / r) ** g * gamma_func_ratio

    @staticmethod
    def power_law_err(r, r0, g, r0_err, g_err):
        """ Compute the error on the power law model given errors on r0 and g.
        function.

        Parameters
        ---------
        r : float array like
            Comoving positions to compute the power law at.
        r0 : float
            Amplitude of the correlation function
        g : float
            Power law of the correlation function.
        r0_err : float
            Error on r0
        g_err : float
            Error on the power law slope.

        Returns
        -------
        float array like
        """
        gamma_func_ratio = scsp.gamma(1/2.) * scsp.gamma((g - 1) / 2) / scsp.gamma(g / 2)
        p_law = r * (r0 / r) ** g * gamma_func_ratio
        # Partial derivatives with respect to r0 and g, propagated in quadrature.
        dev_r0 = r ** (1 - g) * r0 ** (g - 1) * g * gamma_func_ratio * r0_err
        dev_g = (p_law * np.log(r) +
                 2 * p_law * scsp.polygamma(0, (g - 1) / 2) +
                 -2 * p_law * scsp.polygamma(0, g / 2)) * g_err
        return np.sqrt(dev_r0 ** 2 + dev_g ** 2)

    def plot_data_comparison(self, corr_data, catalog_name, output_dir):
        """ Plot the measured wp(rp) against the DEEP2 power-law fits with
        propagated fit errors, and shade the validation (scoring) region.
        """
        fig, ax = plt.subplots()
        colors = plt.cm.plasma_r(np.linspace(0.1, 1, len(self.test_samples)))  # pylint: disable=no-member
        for sample_name, color in zip(self.test_samples, colors):
            sample_corr = corr_data[sample_name]
            sample_data = self.test_data[sample_name]
            sample_label = self.test_sample_labels.get(sample_name)
            # Evaluate the published power-law fit at the measured separations.
            p_law = self.power_law(sample_corr[0],
                                   self.validation_data[sample_data['row'],
                                                        sample_data['r0']],
                                   self.validation_data[sample_data['row'],
                                                        sample_data['g']])
            p_law_err = self.power_law_err(sample_corr[0],
                                           self.validation_data[sample_data['row'],
                                                                sample_data['r0']],
                                           self.validation_data[sample_data['row'],
                                                                sample_data['g']],
                                           self.validation_data[sample_data['row'],
                                                                sample_data['r0_err']],
                                           self.validation_data[sample_data['row'],
                                                                sample_data['g_err']])
            ax.loglog(sample_corr[0],
                      p_law,
                      c=color,
                      label=' '.join([self.survey_label, sample_label]))
            ax.fill_between(sample_corr[0],
                            p_law - p_law_err,
                            p_law + p_law_err,
                            lw=0, color=color, alpha=0.2)
            ax.errorbar(sample_corr[0], sample_corr[1], sample_corr[2], marker='o', ls='', c=color,
                        label=' '.join([catalog_name, sample_label]))
        ax.fill_between([self.r_validation_min, self.r_validation_max], [0, 0], [10**4, 10**4],
                        alpha=0.15, color='grey')  # validation region
        self.decorate_plot(ax, catalog_name)
        fig.tight_layout()
        fig.savefig(os.path.join(output_dir, '{:s}.png'.format(self.test_name)), bbox_inches='tight')
        plt.close(fig)

    def score_and_test(self, corr_data):
        """ Test the average chi^2 per degree of freedom against power law fits
        to the DEEP2 dataset.
        """
        chi_per_nu = 0
        total_sample = 0
        # Use the binning of one of the measurements to locate the validation range.
        rbins = list(corr_data.values()).pop()[0]
        r_idx_min = np.searchsorted(rbins, self.r_validation_min)
        r_idx_max = np.searchsorted(rbins, self.r_validation_max, side='right')
        for sample_name in self.test_samples:
            sample_corr = corr_data[sample_name]
            sample_data = self.test_data[sample_name]
            r_data = sample_corr[0][r_idx_min:r_idx_max]
            p_law = self.power_law(r_data,
                                   self.validation_data[sample_data['row'],
                                                        sample_data['r0']],
                                   self.validation_data[sample_data['row'],
                                                        sample_data['g']])
            p_law_err = self.power_law_err(r_data,
                                           self.validation_data[sample_data['row'],
                                                                sample_data['r0']],
                                           self.validation_data[sample_data['row'],
                                                                sample_data['g']],
                                           self.validation_data[sample_data['row'],
                                                                sample_data['r0_err']],
                                           self.validation_data[sample_data['row'],
                                                                sample_data['g_err']])
            # BUGFIX: accumulate each sample's chi^2/nu (the original assigned
            # with '=', so only the final sample contributed to the average).
            chi_per_nu += np.sum(((sample_corr[1][r_idx_min:r_idx_max] - p_law) / p_law_err) ** 2) / len(r_data)
            total_sample += 1
        # Average chi^2/nu over all test samples.
        score = chi_per_nu / total_sample

        # Made up value. Assert that average chi^2/nu is less than 2.
        test_pass = score < 2

        return TestResult(score=score,
                          passed=test_pass,
                          summary="Ave chi^2/nu value comparing to power law fits to stellar mass threshold "
                                  "DEEP2 data. Test threshold set to 2.")