/
T2RunParsnip.py
517 lines (443 loc) · 19.1 KB
/
T2RunParsnip.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File : ampel/contrib/hu/t2/T2RunParsnip.py
# License : BSD-3-Clause
# Author : jnordin@physik.hu-berlin.de
# Date : 24.09.2021
# Last Modified Date: 06.04.2022
# Last Modified By : jnordin@physik.hu-berlin.de
import gc
import os
import warnings
from collections.abc import Sequence
from typing import Literal
import backoff
import matplotlib.pyplot as plt
import numpy as np
import packaging
import scipy
import timeout_decorator
from astropy.table import Table
from scipy.stats import chi2
from ampel.abstract.AbsTabulatedT2Unit import AbsTabulatedT2Unit
from ampel.abstract.AbsTiedStateT2Unit import AbsTiedStateT2Unit
from ampel.content.DataPoint import DataPoint
from ampel.content.T1Document import T1Document
from ampel.model.StateT2Dependency import StateT2Dependency
from ampel.struct.UnitResult import UnitResult
from ampel.types import UBson
from ampel.view.T2DocView import T2DocView
# Do not warn about scipy.stats.mode(keepdims=None)
if packaging.version.parse(scipy.__version__) < packaging.version.parse("1.11"):
warnings.filterwarnings(
"ignore", category=FutureWarning, module="parsnip.light_curve", lineno=31
)
import extinction # type: ignore[import]
import lcdata
import parsnip
import sncosmo # type: ignore[import]
# The following three only used if correcting for MW dust
from sfdmap2.sfdmap import SFDMap # type: ignore[import]
# Casters applied when serializing parsnip prediction columns for storage.
# All parsnip predictions that are not floats; any column absent from this
# map is coerced to float before being written to the T2 result.
dcast_pred = {
    "object_id": str,
    "type": str,
    "count": int,
    "count_s2n_3": int,
    "count_s2n_5": int,
    "count_s2n_3_pre": int,
    "count_s2n_3_rise": int,
    "count_s2n_3_post": int,
    "model_dof": int,
}
# Same idea for the classification table: only object_id is non-float.
dcast_class = {
    "object_id": str,
}
class T2RunParsnip(AbsTiedStateT2Unit, AbsTabulatedT2Unit):
    """
    Gathers information and runs the parsnip model and classifier.

    - Obtain model (read from file unless not in sncosmo registry)
    - Parse previous (chained) T2 results for redshift or phase limits.
    - Creates suitable photometry, using the converter provided and phase limits.
    - Defines model appropriately, including fit ranges and fixed parameters.
    - Run fit, potentially iterative in case of non-convergence.
    - Plot output if requested

    TODO:
    - Add option for redoing fits with disturbed initial conditions to avoid local minima
    - Add option for masking data?
    """

    # Name (in case standard) or path to parsnip model to load
    parsnip_model: str
    # Path to classifier to apply to lightcurve fit. If not set, no classification will be done.
    parsnip_classifier: None | str

    # Redshift usage options. Current options:
    # T2MatchBTS : Use the redshift published by BTS and synced by that T2.
    # T2DigestRedshifts : Use the best redshift as parsed by DigestRedshift.
    # T2ElasticcRedshiftSampler: Use a list of redshifts and weights from the sampler.
    # None : run sncosmo template fit with redshift as free parameter OR use backup_z if set
    redshift_kind: None | str
    # If loading redshifts from DigestRedshifts, provide the max ampel z group to make use of.
    # (note that filtering based on this can also be done for a potential t3)
    max_ampelz_group: int = 3
    # It is also possible to use a fixed redshift whenever a dynamic redshift kind is not possible.
    # This could be either a single value or a list.
    fixed_z: None | float | Sequence[float]
    # Finally, the provided redshift might be multiplied with a scale.
    # Useful for lensing studies, or when trying multiple values.
    scale_z: None | float
    # Upper redshift bound forwarded to parsnip when the redshift is fitted.
    max_fit_z: None | float = None

    # Remove MW dust absorption.
    # MWEBV is either derived from SFD maps using the position from light_curve
    # (assuming the SFD_DIR env var is set) OR retrieved from stock (ELASTICHOW?)
    # The default value of Rv will be used.
    apply_mwcorrection: bool = False

    # Further fit parameters
    # Bounds - not yet implemented
    # sncosmo_bounds : Dict[ str, List[float] ] = {}

    # Phase range usage. Current options:
    # T2PhaseLimit : use the jdmin jdmax provided in this unit output
    # None : use full datapoint range
    # (T2BayesianBlocks should be added)
    phaseselect_kind: None | str

    # Abort veto (if fulfilled, skip run)
    abort_map: None | dict[str, list]

    # Zeropoint parameters
    # These are separately set in the Parsnip model settings. The zeropoint
    # can vary between input data, training data and the model.
    # Try to adjust this relative to the 'zp' field of the input tabulated lc.
    training_zeropoint: float = 27.5  # Used in Elasticc training sample
    default_zeropoint: float = 25.0  # Default parsnip value

    # Save / plot parameters
    plot_suffix: None | str
    plot_dir: None | str

    # T2 units whose results this unit may parse (redshift, phase limits, abort vetoes)
    t2_dependency: Sequence[
        StateT2Dependency[
            Literal[
                "T2ElasticcRedshiftSampler",
                "T2DigestRedshifts",
                "T2MatchBTS",
                "T2PhaseLimit",
                "T2XgbClassifier",
            ]
        ]
    ]

    def post_init(self) -> None:
        """
        Retrieve models (and the MW dust map, if dereddening is requested).
        """
        # Load model and classifier
        self.model = parsnip.load_model(self.parsnip_model, threads=1)
        self.classifier = None
        if self.parsnip_classifier:
            self.classifier = parsnip.Classifier.load(self.parsnip_classifier)
        if self.apply_mwcorrection:
            self.dustmap = SFDMap()

    def _get_redshift(
        self, t2_views
    ) -> tuple[None | list[float], None | str, None | list[float]]:
        """
        Parse t2_views for redshift information according to `redshift_kind`.

        Returns
        -------
        (z, z_source, z_weights):
            z - list of redshift samples, or None if none could be determined.
            z_source - label describing where the redshift(s) came from.
            z_weights - per-sample weights (only set by T2ElasticcRedshiftSampler).

        Can potentially also be replaced with some sort of T2DigestRedshift tabulator?
        Assumes that only one instance of each redshift source exists.
        """
        # Examine T2s for eventual information
        z: None | list[float] = None
        z_source: None | str = None
        z_weights: None | list[float] = None

        if self.redshift_kind in [
            "T2MatchBTS",
            "T2DigestRedshifts",
            "T2ElasticcRedshiftSampler",
        ]:
            for t2_view in t2_views:
                if t2_view.unit != self.redshift_kind:
                    continue
                self.logger.debug(f"Parsing t2 results from {t2_view.unit}")
                # A list payload holds one entry per run; use the latest.
                t2_res = (
                    res[-1] if isinstance(res := t2_view.get_payload(), list) else res
                )
                # Parse this
                if self.redshift_kind == "T2MatchBTS":
                    if "bts_redshift" in t2_res and t2_res["bts_redshift"] != "-":
                        z = [float(t2_res["bts_redshift"])]
                        z_source = "BTS"
                elif self.redshift_kind == "T2DigestRedshifts":
                    if (
                        "ampel_z" in t2_res
                        and t2_res["ampel_z"] is not None
                        and t2_res["group_z_nbr"] <= self.max_ampelz_group
                    ):
                        z = [float(t2_res["ampel_z"])]
                        z_source = "AMPELz_group" + str(t2_res["group_z_nbr"])
                elif self.redshift_kind == "T2ElasticcRedshiftSampler":
                    z = t2_res["z_samples"]
                    z_source = t2_res["z_source"]
                    z_weights = t2_res["z_weights"]
        # Check if there is a fixed z set for this run, otherwise keep as free parameter
        elif self.fixed_z is not None:
            if isinstance(self.fixed_z, float):
                z = [self.fixed_z]
            else:
                z = list(self.fixed_z)
            z_source = "Fixed"
        else:
            z = None
            z_source = "Fitted"

        if (z is not None) and (z_source is not None) and self.scale_z:
            z = [onez * self.scale_z for onez in z]
            z_source += f" + scaled {self.scale_z}"

        return z, z_source, z_weights

    def _get_abort(self, t2_views) -> tuple[bool, dict]:
        """
        Check potential previous t2s for whether the run should be aborted
        (for performance reasons).
        The implemented case concerns T2XgbClassifier.

        Returns (abort, abort_maps) where abort_maps collects the payloads
        of all inspected abort-relevant T2 results.
        """
        if not self.abort_map or len(self.abort_map) == 0:
            # Not looking for any
            return (False, {})

        abort, abort_maps = False, {}
        for t2_view in t2_views:
            if t2_view.unit not in self.abort_map:
                continue
            self.logger.debug(f"Parsing t2 results from {t2_view.unit}")
            t2_res = res[-1] if isinstance(res := t2_view.get_payload(), list) else res
            abort_maps.update(t2_res)
            # Abort if any configured key/value map is fully matched.
            for abort_map in self.abort_map[t2_view.unit]:
                if all(t2_res.get(key, None) == val for key, val in abort_map.items()):
                    abort = True
        return (abort, abort_maps)

    def _get_phaselimit(self, t2_views) -> tuple[None | float, None | float]:
        """
        Determine the JD range to fit, parsed from T2PhaseLimit if configured.
        Can potentially also be replaced with some sort of tabulator?
        """
        # Examine T2s for eventual information
        jdstart: None | float = None
        jdend: None | float = None

        if self.phaseselect_kind is None:
            jdstart = -np.inf
            jdend = np.inf
        else:
            for t2_view in t2_views:
                # So far only knows how to parse phases from T2PhaseLimit
                if t2_view.unit != "T2PhaseLimit":
                    continue
                self.logger.debug(f"Parsing t2 results from {t2_view.unit}")
                t2_res = (
                    res[-1] if isinstance(res := t2_view.get_payload(), list) else res
                )
                jdstart = t2_res["t_start"]
                jdend = t2_res["t_end"]
        return jdstart, jdend

    def _deredden_mw_extinction(self, ebv, phot_tab, rv=3.1) -> Table:
        """
        For an input photometric table, try to correct for MW extinction.

        The original fluxes are preserved in flux_original / fluxerr_original;
        flux and fluxerr are scaled by the CCM89 dereddening factor per band.
        """
        # Find effective wavelength for all filters in phot_tab
        filterlist = set(phot_tab["band"])
        eff_wave = [sncosmo.get_bandpass(f).wave_eff for f in filterlist]
        # Determine flux correction (dereddening) factors
        flux_corr = 10 ** (0.4 * extinction.ccm89(np.array(eff_wave), ebv * rv, rv))
        # Assign this appropriately to Table
        phot_tab["flux_original"] = phot_tab["flux"]
        phot_tab["fluxerr_original"] = phot_tab["fluxerr"]
        for k, band in enumerate(filterlist):
            phot_tab["flux"][(phot_tab["band"] == band)] *= flux_corr[k]
            phot_tab["fluxerr"][(phot_tab["band"] == band)] *= flux_corr[k]
        return phot_tab

    @backoff.on_exception(
        backoff.constant,
        timeout_decorator.timeout_decorator.TimeoutError,
        max_tries=3,
    )
    @timeout_decorator.timeout(5, use_signals=True)
    def _classify_parsnip(self, predictions):
        """
        Carry out the parsnip classification (5s timeout, up to 3 tries).
        Raises RuntimeError if no classifier was configured.
        """
        if self.classifier is not None:
            return self.classifier.classify(predictions)
        raise RuntimeError("No classifier configured")

    @backoff.on_exception(
        backoff.constant,
        timeout_decorator.timeout_decorator.TimeoutError,
        max_tries=3,
    )
    @timeout_decorator.timeout(5, use_signals=True)
    def _predict_parsnip(self, dataset):
        """
        Carry out the parsnip predictions (5s timeout, up to 3 tries).
        """
        return self.model.predict_dataset(dataset)

    # ==================== #
    # AMPEL T2 MANDATORY #
    # ==================== #
    def process(
        self,
        compound: T1Document,
        datapoints: Sequence[DataPoint],
        t2_views: Sequence[T2DocView],
    ) -> UBson | UnitResult:
        """
        Fit the parsnip model to the tabulated light curve built from the
        datapoints. Depending on the configuration, the provided T2DocViews
        are used to look for redshift information and any phase (time)
        limits for the fit.

        Parameters
        -----------
        compound: T1Document describing the state being processed.
        datapoints: datapoints from which the photometry table is built.
        t2_views: views of results from the following units (if available):
            T2DigestRedshifts (redshift parsed from catalogs)
            T2MatchBTS (redshifts synced from BTS page)
            T2ElasticcRedshiftSampler (redshift samples and weights)
            T2PhaseLimit (fit time-limits as determined from lightcurve)
            T2XgbClassifier (abort vetoes)

        Returns
        -------
        dict
        """
        # Initialize output dict
        t2_output: dict[str, UBson] = {
            "model": self.parsnip_model,
            "classifier": self.parsnip_classifier,
        }

        # Check whether no computation should be done (due to previous fit)
        (abort_run, abort_info) = self._get_abort(t2_views)
        t2_output["abort_maps"] = abort_info
        if abort_run:
            t2_output["aborted"] = True
            return t2_output

        # Check for phase limits
        (jdstart, jdend) = self._get_phaselimit(t2_views)
        t2_output["jdstart"] = jdstart
        t2_output["jdend"] = jdend
        if t2_output["jdstart"] is None:
            return t2_output

        # Obtain photometric table, restricted to the requested phase range
        sncosmo_table = self.get_flux_table(datapoints)
        sncosmo_table = sncosmo_table[
            (sncosmo_table["time"] >= jdstart) & (sncosmo_table["time"] <= jdend)
        ]
        self.logger.debug(f"Sncosmo table {sncosmo_table}")

        # Adjust zeropoint - does this matter? and should we have changed it?
        run_zeropoints = set(sncosmo_table["zp"])
        if len(run_zeropoints) > 1:
            self.logger.info("Warning, multiple zeropoints, using avg.")
            run_zeropoint = np.mean(list(run_zeropoints))
        else:
            run_zeropoint = run_zeropoints.pop()
        self.model.settings["zeropoint"] = (
            self.default_zeropoint + run_zeropoint - self.training_zeropoint
        )

        # Potentially correct for dust absorption
        if self.apply_mwcorrection:
            # Get ebv from coordinates.
            # Here there should be some option to read it from journal/stock etc
            mwebv = self.dustmap.ebv(*self.get_pos(datapoints, which="mean"))
            t2_output["mwebv"] = mwebv
            sncosmo_table = self._deredden_mw_extinction(mwebv, sncosmo_table)

        ## Obtain redshift(s) from catalog fit or a RedshiftSample
        z, z_source, z_weights = self._get_redshift(t2_views)
        t2_output["z"] = z
        t2_output["z_source"] = z_source
        t2_output["z_weights"] = z_weights
        # A source class of None indicates that a redshift source was required, but not found.
        if z is None or z_source is None:
            return t2_output

        # If redshift should be fitted, we start with getting samples
        if z_source == "Fitted":
            if not hasattr(self.model, "predict_redshift_distribution"):
                # Logger.warn is a deprecated alias of warning
                self.logger.warning(
                    "Redshift fitting is not supported in that version of parsnip"
                )
                return t2_output
            z, z_probabilities = self.model.predict_redshift_distribution(
                sncosmo_table, max_redshift=self.max_fit_z
            )
        assert z is not None

        # Create a list of lightcurves, each at a discrete redshift
        lcs = []
        for redshift in z:
            use_lc = sncosmo_table.copy()
            use_lc.meta["object_id"] = f"parsnip_z{redshift:4f}"
            use_lc.meta["redshift"] = redshift
            lcs.append(use_lc)
        lc_dataset = lcdata.from_light_curves(lcs)

        lc_predictions = self._predict_parsnip(lc_dataset)
        lc_classifications = self._classify_parsnip(lc_predictions)

        # Cast result for storage and look at relative probabilities
        predictions = {}
        classifications = {}
        for i, redshift in enumerate(z):
            foo = dict(lc_predictions[i][lc_predictions.colnames[1:]])
            predictions[str(redshift)] = {
                k: dcast_pred[k](v) if k in dcast_pred and v is not None else float(v)
                for k, v in foo.items()
            }
            # Not sure whether the dof could change? Normalizing now
            if foo["model_dof"] > 0:
                predictions[str(redshift)]["chi2pdf"] = chi2.pdf(
                    foo["model_chisq"], foo["model_dof"]
                )
            else:
                # Not enough data - earlier check?
                predictions[str(redshift)]["chi2pdf"] = 0.0
            foo = dict(lc_classifications[i][lc_classifications.colnames[1:]])
            classifications[str(redshift)] = {
                k: dcast_class[k](v) if k in dcast_class and v is not None else float(v)
                for k, v in foo.items()
            }

        # Marginalize over the redshift
        # p(c|y) = Integral[p(c|z,y) p(z|y) dz]
        types = lc_classifications.colnames[1:]
        dtype = lc_classifications[types[0]].dtype
        probabilities = lc_classifications[types].as_array().view((dtype, len(types)))
        # Now we could normalize these z prob and normalize types over redshifts
        z_probabilities = np.array(
            [lcfit["chi2pdf"] for redshift, lcfit in predictions.items()]
        )

        t2_output["predictions"] = predictions
        t2_output["classifications"] = classifications

        if np.sum(z_probabilities) > 0:
            # Take redshift probabilities into account, if available
            if z_weights is not None:
                z_probabilities *= z_weights
            integrated_probabilities = z_probabilities.dot(probabilities)
            integrated_probabilities /= np.sum(integrated_probabilities)
            t2_output["marginal_lc_classifications"] = dict(
                zip(types, integrated_probabilities, strict=False)
            )
            # Find the best redshifts
            t2_output["z_at_minchi"] = z[np.argmax(z_probabilities)]
            # Map these to singular value predictions/lc_classifications
            # (wastes DB space, but possible to filter based on)
            t2_output["prediction"] = predictions[str(t2_output["z_at_minchi"])]
            t2_output["classification"] = classifications[str(t2_output["z_at_minchi"])]
        else:
            # Not enough data for a chi2 estimate
            t2_output["Failed"] = "NoDOF"
            return t2_output

        # Plot
        if self.plot_suffix and self.plot_dir:
            # How to construct name?
            tname = compound.get("stock")
            # Need plotting tools to define id mapper
            # tname = ZTFIdMapper.to_ext_id(light_curve.stock_id)
            fig = plt.figure()
            ax = plt.gca()
            # Set redshift to best value and plot this fit
            lc_dataset.light_curves[0].meta["redshift"] = t2_output["z_at_minchi"]
            parsnip.plot_light_curve(lc_dataset.light_curves[0], self.model, ax=ax)
            plt.tight_layout()
            plt.savefig(
                os.path.join(self.plot_dir, f"t2parsnip_{tname}.{self.plot_suffix}")
            )
            # Fix: close the Figure object itself. The original passed the
            # string "fig", which only closes a figure *labeled* "fig" and
            # left the created figure open (memory leak per plotted fit).
            plt.close(fig)
            plt.close("all")
            del fig
            gc.collect()

        return t2_output