/
redshift.py
146 lines (112 loc) · 3.93 KB
/
redshift.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
# Module to define redshift distributions we can differentiate through
from abc import ABC
from abc import abstractmethod
import jax.numpy as np
from jax.tree_util import register_pytree_node_class
from jax_cosmo.jax_utils import container
from jax_cosmo.scipy.integrate import simps
# Names exported by ``from <module> import *``.
# NOTE(review): ``systematic_shift`` and the ``redshift_distribution`` base
# class are not listed here -- confirm whether that is intentional.
__all__ = ["smail_nz", "kde_nz", "delta_nz"]

# Square arcminutes per steradian, i.e. (180 * 60 / pi) ** 2; multiplying a
# density in gals/arcmin^2 by this gives gals/steradian.
steradian_to_arcmin2 = 11818102.86004228
class redshift_distribution(container):
    """Base class for differentiable redshift distributions n(z).

    Sub-classes provide the un-normalized shape through ``pz_fn``. Calling an
    instance returns that shape divided by its integral over ``[0, zmax]``.

    Positional arguments are the distribution parameters (read back as
    ``self.params``). ``zmax`` and any extra keyword arguments form the
    configuration (read back as ``self.config``).
    """

    def __init__(self, *args, gals_per_arcmin2=1.0, zmax=10.0, **kwargs):
        """Record the galaxy density and hand parameters/config to ``container``.

        Args:
          *args: distribution parameters.
          gals_per_arcmin2: galaxy number density in galaxies per sq arcmin.
          zmax: upper redshift limit of the normalization integral.
          **kwargs: additional configuration entries.
        """
        # Normalization is computed lazily on the first call.
        self._norm = None
        self._gals_per_arcmin2 = gals_per_arcmin2
        super().__init__(*args, zmax=zmax, **kwargs)

    @abstractmethod
    def pz_fn(self, z):
        """Un-normalized n(z); must be implemented by sub-classes."""
        pass

    def __call__(self, z):
        """Evaluate the normalized n(z) at ``z``."""
        norm = self._norm
        if norm is None:
            # Simpson integration of the raw shape over [0, zmax] (N=256),
            # cached on the instance so later calls reuse it.
            # NOTE(review): if the first call happens inside a jax
            # transformation, the cached value may be a tracer -- confirm
            # instances are not reused across traces.
            norm = simps(self.pz_fn, 0.0, self.config["zmax"], 256)
            self._norm = norm
        return self.pz_fn(z) / norm

    @property
    def zmax(self):
        """Upper redshift limit of the normalization integral."""
        return self.config["zmax"]

    @property
    def gals_per_arcmin2(self):
        """Galaxy number density in galaxies per square arcminute.

        TODO: find a better name
        """
        return self._gals_per_arcmin2

    @property
    def gals_per_steradian(self):
        """Galaxy number density converted to galaxies per steradian."""
        return self._gals_per_arcmin2 * steradian_to_arcmin2

    # Pytree protocol: parameters and galaxy density are the children, the
    # config dict is the aux data. ``_norm`` is not carried over, so rebuilt
    # instances recompute it (unless a sub-class sets it in ``__init__``).
    def tree_flatten(self):
        """Split the instance into ``(children, aux_data)``."""
        return (self.params, self._gals_per_arcmin2), self.config

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """Rebuild an instance from the output of ``tree_flatten``."""
        params, density = children
        return cls(*params, gals_per_arcmin2=density, **aux_data)
@register_pytree_node_class
class smail_nz(redshift_distribution):
    """Smail-type distribution, n(z) proportional to z**a * exp(-(z / z0)**b).

    Parameters:
    -----------
    a: exponent of the power-law factor z**a
    b: exponent inside the exponential cut-off
    z0: redshift scale of the exponential cut-off
    gals_per_arcmin2: number of galaxies per sq arcmin
    """

    def pz_fn(self, z):
        """Return the un-normalized Smail shape evaluated at ``z``."""
        power, steepness, scale = self.params
        cutoff = np.exp(-((z / scale) ** steepness))
        return z**power * cutoff
@register_pytree_node_class
class delta_nz(redshift_distribution):
    """Single-plane redshift distribution: every galaxy sits at ``z0``.

    Parameters:
    -----------
    z0: redshift of the source plane (first positional argument)
    gals_per_arcmin2: number of galaxies per sq arcmin
    """

    def __init__(self, *args, **kwargs):
        """Initialize the parameters and fix the normalization to 1.

        The normalization integral is skipped, so calling the instance returns
        ``pz_fn(z)`` unchanged (presumably because a grid-based integral
        would almost never land exactly on ``z0``).
        """
        super().__init__(*args, **kwargs)
        self._norm = 1.0

    def pz_fn(self, z):
        """Return 1.0 where ``z == z0`` and 0 elsewhere."""
        # ``self.params`` is the tuple of positional arguments; the plane
        # redshift is its first entry. Comparing against the whole tuple made
        # a scalar ``z`` compare unequal (float == tuple is False) and turned
        # scalar results into length-1 arrays.
        z0 = self.params[0]
        return np.where(z == z0, 1.0, 0)
@register_pytree_node_class
class kde_nz(redshift_distribution):
    """Redshift distribution estimated from a catalog with a Gaussian KDE.

    TODO: add more kernels if necessary

    Parameters:
    -----------
    zcat: redshift catalog
    weights: weight for each galaxy between 0 and 1

    Configuration:
    --------------
    bw: Bandwidth for the KDE

    Example:
    nz = kde_nz(redshift_catalog, w, bw=0.1)
    """

    def _kernel(self, bw, X, x):
        """Unit-area Gaussian of standard deviation ``bw`` evaluated at ``X - x``."""
        prefactor = 1.0 / np.sqrt(2 * np.pi) / bw
        exponent = -((X - x) ** 2) / (bw**2 * 2.0)
        return prefactor * np.exp(exponent)

    def pz_fn(self, z):
        """Weighted KDE estimate of n(z) at ``z``, divided by the total weight."""
        catalog, raw_weights = self.params[:2]
        weights = np.atleast_1d(raw_weights)
        # Give the catalog a trailing axis so it broadcasts against ``z``.
        # Assuming 1-D catalog and 1-D z, the kernel matrix is (n_gal, n_z).
        columns = np.expand_dims(catalog, axis=-1)
        kernel_matrix = self._kernel(self.config["bw"], columns, z)
        total_weight = np.sum(weights)
        return np.dot(kernel_matrix.T, weights) / total_weight
@register_pytree_node_class
class systematic_shift(redshift_distribution):
    """Redshift distribution obtained by shifting a parent one by a mean bias.

    TODO: Find a better name for this

    Arguments:
      redshift_distribution: parent distribution whose un-normalized shape is used
      mean_bias: additive shift; the parent is evaluated at ``z - mean_bias``

    NOTE(review): ``zmax`` and ``gals_per_arcmin2`` are this object's own
    settings (base-class defaults unless passed), not copied from the parent.
    """

    def pz_fn(self, z):
        """Parent's un-normalized shape evaluated at ``z - bias`` clipped at 0."""
        parent, bias = self.params[:2]
        # Clip from below at zero so the parent never sees negative redshift.
        shifted = np.clip(z - bias, 0)
        return parent.pz_fn(shifted)