From 7aebc5c512fce413deb274beb0f9b9e145a7aa4f Mon Sep 17 00:00:00 2001 From: Jakob Nordin <16082999+wombaugh@users.noreply.github.com> Date: Thu, 23 May 2024 16:22:44 +0200 Subject: [PATCH] Tns24 (#63) * add more info from t2 eval Updated TNS submit methods. New plot methods, merged from nuztf. Updated poetry files. remove unimplemented method ruff noqa flags added chore: ruff cleanup Remove old TNS units. clean test After removing out of day tests, the github ruff test fails (works locally). Also remove obsolute TNS submit routines. mypy edits chore: add ruff to pre-commit config ruff check ...and correcting ruff changes so mypy still happy. ...and correcting ruff changes so mypy still happy. type support for Pillow Plot options for Colibri. Pull wrt magdifflim as criteria. Fixed bug in TNS resonse query. * Allow digest redshift methods to be inherited. * Activate TNS submission and reply reading. * added T2KilonovaStats evaluation * Filter for looking for stellar flares. * Base unit for lightcurve fitters. * Two bugs caused by potentially empty catalog entries. * Units for demo linear fit and search for stellar outburts. * Updates to new API for TNS and AColibri uploads. * Remove debug prints. * Format updates. * Consistant unit naming and separated base and demo classes. * rm old names. * Ensure that a NamedSecret for Slack publishing is not required. * Shift dependence from TNSMirrorSearcher to complements through TNSNames. * PlotTransientLightcurves now works with T3 complement cutouts. * SubmitTNS using complemented TNS ids when possibl. * Ripped out TNS and cutout duplicated methods. * refine first-sentence detection --------- Co-authored-by: andimatter Co-authored-by: Jakob van Santen --- .pre-commit-config.yaml | 28 +- README.md | 7 +- ampel/contrib/hu/t0/StellarFilter.py | 330 +++++++++++ ampel/contrib/hu/t2/T2BaseLightcurveFitter.py | 191 +++++++ ampel/contrib/hu/t2/T2DemoLightcurveFitter.py | 87 +++ ampel/contrib/hu/t2/T2DigestRedshifts.py | 135 ++++- ampel/contrib/hu/t2/T2InfantCatalogEval.py | 20 +- ampel/contrib/hu/t2/T2PolynomialFit.py | 81 +++ ampel/contrib/hu/t2/T2TNSEval.py | 2 +- ampel/contrib/hu/t3/AstroColibriPublisher.py | 132 +++-- .../contrib/hu/t3/PlotTransientLightcurves.py | 535 +++++++++++++++++ ampel/contrib/hu/t3/SubmitTNS.py | 191 +++++++ ampel/contrib/hu/t3/TNSTalker.py | 539 ------------------ .../contrib/hu/t3/TransientTablePublisher.py | 126 +--- ampel/contrib/hu/t3/ampel_tns.py | 352 ------------ ampel/contrib/hu/t3/tns/TNSClient.py | 57 +- ampel/contrib/hu/t3/tns/__init__.py | 4 + ampel/contrib/hu/t3/tns/tns_ampel_util.py | 180 ++++++ ampel/contrib/hu/test/test_tnstalker.py | 70 --- .../process/TNSSubmitComplete.yml | 25 - conf/ampel-hu-astro/process/TNSSubmitNew.yml | 32 -- conf/ampel-hu-astro/unit.yml | 7 +- examples/healpix_linfit.yml | 206 +++++++ examples/infant_test.yml | 378 ++++++++++++ examples/stellar_outburst.yml | 231 ++++++++ poetry.lock | 33 +- pyproject.toml | 6 +- scripts/generate_unit_inventory.py | 4 +- 28 files changed, 2758 insertions(+), 1231 deletions(-) create mode 100755 ampel/contrib/hu/t0/StellarFilter.py create mode 100644 ampel/contrib/hu/t2/T2BaseLightcurveFitter.py create mode 100644 ampel/contrib/hu/t2/T2DemoLightcurveFitter.py create mode 100644 ampel/contrib/hu/t2/T2PolynomialFit.py create mode 100755 ampel/contrib/hu/t3/PlotTransientLightcurves.py create mode 100644 ampel/contrib/hu/t3/SubmitTNS.py delete mode 100755 ampel/contrib/hu/t3/TNSTalker.py delete mode 100755 ampel/contrib/hu/t3/ampel_tns.py create mode 100644 
ampel/contrib/hu/t3/tns/tns_ampel_util.py delete mode 100644 ampel/contrib/hu/test/test_tnstalker.py delete mode 100644 conf/ampel-hu-astro/process/TNSSubmitComplete.yml delete mode 100644 conf/ampel-hu-astro/process/TNSSubmitNew.yml create mode 100644 examples/healpix_linfit.yml create mode 100644 examples/infant_test.yml create mode 100644 examples/stellar_outburst.yml diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 80ad340a..08f40657 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,25 +1,11 @@ # See https://pre-commit.com for more information # See https://pre-commit.com/hooks.html for more hooks + repos: -- repo: https://github.com/pre-commit/pre-commit-hooks - rev: v2.4.0 - hooks: - - id: trailing-whitespace - - id: end-of-file-fixer - - id: check-added-large-files - - id: check-json - exclude: ampel/contrib/hu/test - - id: check-yaml - - id: pretty-format-json - args: [--no-sort-keys, --autofix, --indent=2] - exclude: ampel/contrib/hu/test -- repo: https://github.com/astral-sh/ruff-pre-commit - # Ruff version. - rev: v0.1.14 - hooks: - # Run the linter. - - id: ruff - args: [ --fix ] - # Run the formatter. - - id: ruff-format +- repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.1.14 + hooks: + - id: ruff + args: [--fix] + - id: ruff-format fail_fast: false diff --git a/README.md b/README.md index befd68a1..019f84fa 100644 --- a/README.md +++ b/README.md @@ -37,6 +37,7 @@ requires an access token if data is to be retrieved. - [RcfFilter](ampel/contrib/hu/t0/RcfFilter.py): Filter for the ZTF Redshift Completeness Factor program.. - [RedshiftCatalogFilter](ampel/contrib/hu/t0/RedshiftCatalogFilter.py): Filter derived from DecentFilter designed to only accept transients located close to a galaxy in a catalog, and within redshift bounds. - [SimpleDecentFilter](ampel/contrib/hu/t0/SimpleDecentFilter.py): General-purpose filter devloped alongside DecentFilter but without use of external catalogs. +- [StellarFilter](ampel/contrib/hu/t0/StellarFilter.py): a.k.a. as the IndecentFilter, i.e. an inversion of the DecentFilter mainly used for finding extragalactic objects. - [TransientInClusterFilter](ampel/contrib/hu/t0/TransientInClusterFilter.py): Filter derived from the DecentFilter, in addition selecting candidates with position compatible with that of nearby galaxy clusters.. - [XShooterFilter](ampel/contrib/hu/t0/XShooterFilter.py): Filter derived from the DecentFilter, in addition selecting very new transients which are visible from the South. @@ -44,6 +45,7 @@ requires an access token if data is to be retrieved. - [T2BayesianBlocks](ampel/contrib/hu/t2/T2BayesianBlocks.py): T2 unit for running a bayesian block search algorithm to highlight excess regions. - [T2BrightSNProb](ampel/contrib/hu/t2/T2BrightSNProb.py): Derive a number of simple metrics describing the rise, peak and decline of a lc. - [T2CatalogMatchLocal](ampel/contrib/hu/t2/T2CatalogMatchLocal.py): Cross matches the position of a transient to those of sources in a set of catalogs. +- [T2DemoLightcurveFitter](ampel/contrib/hu/t2/T2DemoLightcurveFitter.py): Demonstration class showing how methods of T2BaseLightcurveFitter can be used develop a specific classifier. - [T2DigestRedshifts](ampel/contrib/hu/t2/T2DigestRedshifts.py): Compare potential matches from different T2 units providing redshifts. 
- [T2DustEchoEval](ampel/contrib/hu/t2/T2DustEchoEval.py) - [T2ElasticcRedshiftSampler](ampel/contrib/hu/t2/T2ElasticcRedshiftSampler.py): Parse the elasticc diaSource host information and returns a list of redshifts and weights. @@ -59,7 +61,7 @@ requires an access token if data is to be retrieved. - [T2MatchBTS](ampel/contrib/hu/t2/T2MatchBTS.py): Add information from the BTS explorer page. - [T2MultiXgbClassifier](ampel/contrib/hu/t2/T2MultiXgbClassifier.py): For a range of xgboost classifier models, find a classification. - [T2NedSNCosmo](ampel/contrib/hu/t2/T2NedSNCosmo.py): Fits lightcurves using SNCOSMO (using SALT2 defaultwise) with redshift constrained by catalog matching results. -- [T2NedTap](ampel/contrib/hu/t2/T2NedTap.py): See also. +- [T2NedTap](ampel/contrib/hu/t2/T2NedTap.py): See also: https://ned.ipac.caltech.edu/tap/sync?QUERY=SELECT+*+FROM+TAP_SCHEMA.tables&REQUEST=doQuery&LANG=ADQL&FORMAT=text Export all NED: https://ned.ipac.caltech.edu/tap/sync?QUERY=SELECT+*+FROM+NEDTAP.objdir&REQUEST=doQuery&LANG=ADQL&FORMAT=text. - [T2PS1ThumbExtCat](ampel/contrib/hu/t2/T2PS1ThumbExtCat.py): Retrieve panstarrs images at datapoint location and for each tied extcat catalog matching result. - [T2PS1ThumbNedSNCosmo](ampel/contrib/hu/t2/T2PS1ThumbNedSNCosmo.py): This state t2 unit is tied with the state T2 unit T2NedSNCosmo. - [T2PS1ThumbNedTap](ampel/contrib/hu/t2/T2PS1ThumbNedTap.py): This point t2 unit is tied with the point T2 unit T2NedTap. @@ -80,6 +82,7 @@ requires an access token if data is to be retrieved. - [HealpixCorrPlotter](ampel/contrib/hu/t3/HealpixCorrPlotter.py): Compare healpix coordinate P-value with output from T2RunSncosmo.. - [HealpixTokenGenerator](ampel/contrib/hu/t3/HealpixTokenGenerator.py): Based on a URL to a Healpix map. - [PlotLightcurveSample](ampel/contrib/hu/t3/PlotLightcurveSample.py): Unit plots results from lightcurve fitters (RunSncosmo, RunParsnip). +- [PlotTransientLightcurves](ampel/contrib/hu/t3/PlotTransientLightcurves.py): Create a (pdf) plot summarizing lightcurves of candidates provided to the unit. - [RandomMapGenerator](ampel/contrib/hu/t3/RandomMapGenerator.py): Generate smoothed circular healpix probability values around a random coordinate.. - [RapidBase](ampel/contrib/hu/t3/RapidBase.py): Trigger rapid reactions. - [RapidLco](ampel/contrib/hu/t3/RapidLco.py): Submit LCO triggers for candidates passing criteria.. @@ -87,7 +90,7 @@ requires an access token if data is to be retrieved. - [ScoreSingleObject](ampel/contrib/hu/t3/ScoreSingleObject.py): Calculate score based on how early a specific SN is detected. - [ScoreTNSObjects](ampel/contrib/hu/t3/ScoreTNSObjects.py): Calculate score based on detection time reported to TNS, if any.. - [SlackSummaryPublisher](ampel/contrib/hu/t3/SlackSummaryPublisher.py) -- [TNSTalker](ampel/contrib/hu/t3/TNSTalker.py): Get TNS name if existing, and submit selected candidates. +- [SubmitTNS](ampel/contrib/hu/t3/SubmitTNS.py): Submit candidates to TNS (unless already submitted). - [TransientInfoPrinter](ampel/contrib/hu/t3/TransientInfoPrinter.py) - [TransientTablePublisher](ampel/contrib/hu/t3/TransientTablePublisher.py): Construct a table based on selected T2 output values. 
- [TransientViewDumper](ampel/contrib/hu/t3/TransientViewDumper.py) diff --git a/ampel/contrib/hu/t0/StellarFilter.py b/ampel/contrib/hu/t0/StellarFilter.py new file mode 100755 index 00000000..f07c5750 --- /dev/null +++ b/ampel/contrib/hu/t0/StellarFilter.py @@ -0,0 +1,330 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# File: Ampel-HU-astro/ampel/contrib/hu/t0/StellarFilter.py +# License: BSD-3-Clause +# Author: m. giomi +# Date: 06.06.2018 +# Last Modified Date: 26.03.2024 +# Last Modified By: jno + +from typing import Any + +import numpy as np +from astropy.coordinates import SkyCoord +from astropy.table import Table + +from ampel.abstract.AbsAlertFilter import AbsAlertFilter +from ampel.protocol.AmpelAlertProtocol import AmpelAlertProtocol +from ampel.ztf.base.CatalogMatchUnit import CatalogMatchUnit + + +class StellarFilter(CatalogMatchUnit, AbsAlertFilter): + """ + + a.k.a. as the IndecentFilter, i.e. an inversion of the DecentFilter + mainly used for finding extragalactic objects. + + + todo: + look for flare + min mag deviation + periodogram (from elasticc curve) + + """ + + # History + min_ndet: int # number of previous detections + max_ndet: int # number of previous detections + min_tspan: float # minimum duration of alert detection history [days] + max_tspan: float # maximum duration of alert detection history [days] + min_archive_tspan: float = 0.0 # minimum duration of alert detection history [days] + max_archive_tspan: float = ( + 10**5.0 + ) # maximum duration of alert detection history [days] + + # Brightness / Flare + max_mag: float = 30.0 + peak_time_limit: float = ( + 10.0 # Will divide lightcurve before / after this. Set to 0 to disably [days] + ) + min_peak_diff: float = ( + 1.0 # Min mag difference between peak mag before/after limit *in any band* + ) + # Todo: select band? fit linear curve? + + # Image quality + min_drb: float = 0.0 # deep learning real bogus score + min_rb: float # real bogus score + max_fwhm: float = 5.5 # sexctrator FWHM (assume Gaussian) [pix] + max_elong: float = 1.4 # Axis ratio of image: aimage / bimage + max_magdiff: float = 1.0 # Difference: magap - magpsf [mag] + max_nbad: int = 0 # number of bad pixels in a 5 x 5 pixel stamp + + # Astro + min_sso_dist: float = 20 # distance to nearest solar system object [arcsec] + min_gal_lat: float = ( + -1 + ) # minium distance from galactic plane. Set to negative to disable cut. + max_gal_lat: float = 999 # maximum distance from galactic plane. 
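# For orientation, a hedged sketch of parameter values such a filter might be run with.
# Field names mirror the attributes above and below; the ones without class defaults
# (min_ndet, max_ndet, min_tspan, max_tspan, min_rb, and the require_ps_star /
# require_gaia_star switches defined further down) must always be supplied. The numbers
# themselves are purely illustrative, not recommended cuts.
stellar_filter_config = {
    "min_ndet": 4,            # at least 4 prior detections
    "max_ndet": 1000,
    "min_tspan": 10.0,        # detection history spans at least 10 days
    "max_tspan": 10000.0,
    "min_rb": 0.3,            # real-bogus threshold
    "peak_time_limit": 10.0,  # split the lightcurve 10 days before the latest alert
    "min_peak_diff": 1.0,     # demand >= 1 mag brightening across that split
    "require_ps_star": True,     # see is_star_in_PS1 below
    "require_gaia_star": False,  # see is_star_in_gaia below
}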
+ + # PS1 + require_ps_star: bool + avoid_ps_confusion: bool = False # Discard event if multiple nearby PS sources + ps1_sgveto_rad: float = ( + 1.0 # maximum distance to closest PS1 source for SG score veto [arcsec] + ) + ps1_sgveto_th: float = ( + 0.8 # maximum allowed SG score for PS1 source within PS1_SGVETO_RAD + ) + ps1_confusion_rad: float = 1.0 # reject alerts if the three PS1 sources are all within this radius [arcsec] + ps1_confusion_sg_tol: float = 0.1 # and if the SG score of all of these 3 sources is within this tolerance to 0.5 + + # Gaia + require_gaia_star: bool + gaia_rs: float = 20.0 # search radius for GAIA DR2 matching [arcsec] + gaia_pm_signif: float = ( + 3.0 # significance of proper motion detection of GAIA counterpart [sigma] + ) + gaia_plx_signif: float = ( + 3.0 # significance of parallax detection of GAIA counterpart [sigma] + ) + gaia_veto_gmag_min: float = ( + 9.0 # min gmag for normalized distance cut of GAIA counterparts [mag] + ) + gaia_veto_gmag_max: float = ( + 20.0 # max gmag for normalized distance cut of GAIA counterparts [mag] + ) + gaia_excessnoise_sig_max: float = 999.0 # maximum allowed noise (expressed as significance) for Gaia match to be trusted. + + def get_galactic_latitude(self, transient): + """ + compute galactic latitude of the transient + """ + coordinates = SkyCoord(transient["ra"], transient["dec"], unit="deg") + return coordinates.galactic.b.deg + + def is_star_in_PS1(self, transient) -> bool: + """ + apply combined cut on sgscore1 and distpsnr1 to reject the transient if + there is a PS1 star-like object in it's immediate vicinity + """ + + # TODO: consider the case of alert moving wrt to the position of a star + # maybe cut on the minimum of the distance! + return ( + transient["distpsnr1"] < self.ps1_sgveto_rad + and transient["sgscore1"] > self.ps1_sgveto_th + ) + + def is_confused_in_PS1(self, transient) -> bool: + """ + check in PS1 for source confusion, which can induce subtraction artifatcs. + These cases are selected requiring that all three PS1 cps are in the imediate + vicinity of the transient and their sgscore to be close to 0.5 within given tolerance. + """ + very_close = ( + max(transient["distpsnr1"], transient["distpsnr2"], transient["distpsnr3"]) + < self.ps1_confusion_rad + ) + + # Update 31.10.19: avoid costly numpy cast + # Old: + # In: %timeit abs(array([sg1, sg2, sg3]) - 0.5 ).max() + # Out: 5.79 µs ± 80.5 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each) + # New: + # In: %timeit max(abs(sg1-0.5), abs(sg2-0.5), abs(sg3-0.5)) + # Out: 449 ns ± 7.01 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) + + sg_confused = ( + max( + abs(transient["sgscore1"] - 0.5), + abs(transient["sgscore2"] - 0.5), + abs(transient["sgscore3"] - 0.5), + ) + < self.ps1_confusion_sg_tol + ) + + return sg_confused and very_close + + def is_star_in_gaia(self, transient: dict[str, Any]) -> bool: + """ + match tranient position with GAIA DR2 and uses parallax + and proper motion to evaluate star-likeliness + returns: True (is a star) or False otehrwise. 
+ """ + + srcs = self.cone_search_all( + transient["ra"], + transient["dec"], + [ + { + "name": "GAIADR2", + "use": "catsHTM", + "rs_arcsec": self.gaia_rs, + "keys_to_append": [ + "Mag_G", + "PMRA", + "ErrPMRA", + "PMDec", + "ErrPMDec", + "Plx", + "ErrPlx", + "ExcessNoiseSig", + ], + } + ], + )[0] + + if srcs: + gaia_tab = Table( + [ + {k: np.nan if v is None else v for k, v in src["body"].items()} + for src in srcs + ] + ) + + # compute distance + gaia_tab["DISTANCE"] = [src["dist_arcsec"] for src in srcs] + gaia_tab["DISTANCE_NORM"] = ( + 1.8 + 0.6 * np.exp((20 - gaia_tab["Mag_G"]) / 2.05) + > gaia_tab["DISTANCE"] + ) + gaia_tab["FLAG_PROX"] = [ + x["DISTANCE_NORM"] + and self.gaia_veto_gmag_min <= x["Mag_G"] <= self.gaia_veto_gmag_max + for x in gaia_tab + ] + + # check for proper motion and parallax conditioned to distance + gaia_tab["FLAG_PMRA"] = ( + abs(gaia_tab["PMRA"] / gaia_tab["ErrPMRA"]) > self.gaia_pm_signif + ) + gaia_tab["FLAG_PMDec"] = ( + abs(gaia_tab["PMDec"] / gaia_tab["ErrPMDec"]) > self.gaia_pm_signif + ) + gaia_tab["FLAG_Plx"] = ( + abs(gaia_tab["Plx"] / gaia_tab["ErrPlx"]) > self.gaia_plx_signif + ) + + # take into account precison of the astrometric solution via the ExcessNoise key + gaia_tab["FLAG_Clean"] = ( + gaia_tab["ExcessNoiseSig"] < self.gaia_excessnoise_sig_max + ) + + # select just the sources that are close enough and that are not noisy + gaia_tab = gaia_tab[gaia_tab["FLAG_PROX"]] + gaia_tab = gaia_tab[gaia_tab["FLAG_Clean"]] + + # among the remaining sources there is anything with + # significant proper motion or parallax measurement + if ( + any(gaia_tab["FLAG_PMRA"] == True) # noqa + or any(gaia_tab["FLAG_PMDec"] == True) # noqa + or any(gaia_tab["FLAG_Plx"] == True) # noqa + ): + return True + + return False + + # Override + def process(self, alert: AmpelAlertProtocol) -> None | bool | int: + """ + Mandatory implementation. 
+ To exclude the alert, return *None* + To accept it, either return + * self.on_match_t2_units + * or a custom combination of T2 unit names + """ + + # CUT ON THE HISTORY OF THE ALERT + ################################# + + pps = [el for el in alert.datapoints if el.get("candid") is not None] + if len(pps) < self.min_ndet or len(pps) > self.max_ndet: + return None + + # cut on length of detection history + detections_jds = [el["jd"] for el in pps] + det_tspan = max(detections_jds) - min(detections_jds) + if not (self.min_tspan <= det_tspan <= self.max_tspan): + return None + + # IMAGE QUALITY CUTS + #################### + + latest = alert.datapoints[0] + + if latest["isdiffpos"] == "f" or latest["isdiffpos"] == "0": + return None + + if latest["rb"] < self.min_rb: + return None + + if "drb" in latest and self.min_drb > 0.0 and latest["drb"] < self.min_drb: + return None + + if latest["fwhm"] > self.max_fwhm: + return None + + if latest["elong"] > self.max_elong: + return None + + if abs(latest["magdiff"]) > self.max_magdiff: + return None + + # cut on archive length + if "jdendhist" in latest and "jdstarthist" in latest: + archive_tspan = latest["jdendhist"] - latest["jdstarthist"] + if not (self.min_archive_tspan < archive_tspan < self.max_archive_tspan): + return None + + # Recent lightcurve brightness + ########### + + if latest["magpsf"] > self.max_mag: + return None + + pre_pp = [ + dp + for dp in pps + if "magpsf" in dp and (latest["jd"] - dp["jd"]) > self.peak_time_limit + ] + post_pp = [ + dp + for dp in pps + if "magpsf" in dp and (latest["jd"] - dp["jd"]) <= self.peak_time_limit + ] + if len(pre_pp) == 0: + return None + # Could also sort these for filter + mdiff = np.mean([pp["magpsf"] for pp in pre_pp]) - np.mean( + [pp["magpsf"] for pp in post_pp] + ) + if mdiff < self.min_peak_diff: + return None + + # ASTRONOMY + ########### + + # check for closeby ss objects + if 0 <= latest["ssdistnr"] < self.min_sso_dist: + return None + + # cut on galactic latitude + b = self.get_galactic_latitude(latest) + if abs(b) < self.min_gal_lat: + return None + if abs(b) > self.max_gal_lat: + return None + + # check ps1 star-galaxy score + if self.require_ps_star and not self.is_star_in_PS1(latest): + return None + if self.avoid_ps_confusion and self.is_confused_in_PS1(latest): + return None + + # check with gaia + if self.require_gaia_star and self.is_star_in_gaia(latest): + return None + + return True diff --git a/ampel/contrib/hu/t2/T2BaseLightcurveFitter.py b/ampel/contrib/hu/t2/T2BaseLightcurveFitter.py new file mode 100644 index 00000000..4eb056c2 --- /dev/null +++ b/ampel/contrib/hu/t2/T2BaseLightcurveFitter.py @@ -0,0 +1,191 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# File : ampel/contrib/hu/t2/T2BaseLightcurveFitter.py +# License : BSD-3-Clause +# Author : jnordin@physik.hu-berlin.de +# Date : 24.09.2021 +# Last Modified Date: 22.04.2022 +# Last Modified By : jnordin@physik.hu-berlin.de + +from collections.abc import Sequence + +# The following three only used if correcting for MW dust +import extinction # type: ignore[import] +import numpy as np +import sncosmo # type: ignore[import] +from astropy.table import Table +from sfdmap2.sfdmap import SFDMap # type: ignore[import] + +from ampel.abstract.AbsTabulatedT2Unit import AbsTabulatedT2Unit +from ampel.base.decorator import abstractmethod +from ampel.content.DataPoint import DataPoint +from ampel.content.T1Document import T1Document +from ampel.contrib.hu.t2.T2DigestRedshifts import T2DigestRedshifts +from 
ampel.struct.UnitResult import UnitResult +from ampel.types import UBson +from ampel.view.T2DocView import T2DocView + + +class T2BaseLightcurveFitter(T2DigestRedshifts, AbsTabulatedT2Unit, abstract=True): + """ + + Base class for constructing lightcurve fitters. + Includes step common to most lightcurve fitters: + - Obtain a table of flux values + - Get a redshift (fixed or from catalogs through T2DigestRedshifts. + - Correct flux for MW reddening + - Restricting table to phase range as determined from other units + + + """ + + # Adding default redshift selection values, corresponding to usage of "good" redshifts from catalogs + redshift_kind = "T2DigestRedshifts" + max_redshift_category: int = 3 + + # Remove MW dust absorption. + # MWEBV is either derived from SFD maps using the position from light_curve + # (assuming the SFD_DIR env var is set) OR retrieved from stock (ELASTICHOW?) + # The default value of Rv will be used. + # Using this requires extinction, sfdmap and SNCOSMO to be installed. The latter is used to determine effective wavelengths + apply_mwcorrection: bool = False + + # Phase range usage. Current option: + # T2PhaseLimit : use the jdmin jdmax provided in this unit output + # None : use full datapoint range + phaseselect_kind: None | str + + def post_init(self) -> None: + """ + Retrieve models and potentially dustmaps. + """ + + if self.apply_mwcorrection: + self.dustmap = SFDMap() + # Load e.g. model files as needed + # e.g. self.model = parsnip.load_model(self.parsnip_model, threads=1) + + def _get_phaselimit(self, t2_views) -> tuple[None | float, None | float]: + """ + Can potentially also be replaced with some sort of tabulator? + + """ + + # Examine T2s for eventual information + jdstart: None | float = None + jdend: None | float = None + + if self.phaseselect_kind is None: + jdstart = -np.inf + jdend = np.inf + else: + for t2_view in t2_views: + # So far only knows how to parse phases from T2PhaseLimit + if t2_view.unit != "T2PhaseLimit": + continue + self.logger.debug(f"Parsing t2 results from {t2_view.unit}") + t2_res = ( + res[-1] if isinstance(res := t2_view.get_payload(), list) else res + ) + jdstart = t2_res["t_start"] + jdend = t2_res["t_end"] + + return jdstart, jdend + + def _deredden_mw_extinction(self, ebv, phot_tab, rv=3.1) -> Table: + """ + For an input photometric table, try to correct for mw extinction. + Resuires extinction & sncosmo to be loaded, and that sncosmo knows the band wavelength. + """ + + # Find effective wavelength for all filters in phot_tab + filterlist = set(phot_tab["band"]) + eff_wave = [sncosmo.get_bandpass(f).wave_eff for f in filterlist] + + # Determine flux correction (dereddening) factors + flux_corr = 10 ** (0.4 * extinction.ccm89(np.array(eff_wave), ebv * rv, rv)) + + # Assign this appropritately to Table + phot_tab["flux_original"] = phot_tab["flux"] + phot_tab["fluxerr_original"] = phot_tab["fluxerr"] + for k, band in enumerate(filterlist): + phot_tab["flux"][(phot_tab["band"] == band)] *= flux_corr[k] + phot_tab["fluxerr"][(phot_tab["band"] == band)] *= flux_corr[k] + + return phot_tab + + def get_fitdata( + self, + datapoints: Sequence[DataPoint], + t2_views: Sequence[T2DocView], + ) -> tuple[Table | None, dict]: + """ + + Obtain data necessary for fit: + - Flux table, possiby constrained in time. + (corrected for MW extinction if set) + - Redshift. 
Possibly multiple values, possibe None + + Returns + ------- + dict + """ + + # Fit data info + fitdatainfo: dict[str, UBson] = {} + + # Check for phase limits + (jdstart, jdend) = self._get_phaselimit(t2_views) + fitdatainfo["jdstart"] = jdstart + fitdatainfo["jdend"] = jdend + if fitdatainfo["jdstart"] is None: + return (None, fitdatainfo) + + # Obtain photometric table + sncosmo_table = self.get_flux_table(datapoints) + sncosmo_table = sncosmo_table[ + (sncosmo_table["time"] >= jdstart) & (sncosmo_table["time"] <= jdend) + ] + + # Potentially correct for dust absorption + # Requires filters to be known by sncosmo (and the latter installed) + if self.apply_mwcorrection: + # Get ebv from coordiantes. + # Here there should be some option to read it from journal/stock etc + mwebv = self.dustmap.ebv(*self.get_pos(datapoints, which="mean")) + fitdatainfo["mwebv"] = mwebv + sncosmo_table = self._deredden_mw_extinction(mwebv, sncosmo_table) + + ## Obtain redshift(s) from T2DigestRedshifts + zlist, z_source, z_weights = self.get_redshift(t2_views) + fitdatainfo["z"] = zlist + fitdatainfo["z_source"] = z_source + fitdatainfo["z_weights"] = z_weights + # A source class of None indicates that a redshift source was required, but not found. + if not isinstance(zlist, list) or z_source is None: + return (None, fitdatainfo) + + return (sncosmo_table, fitdatainfo) + + # ==================== # + # AMPEL T2 MANDATORY # + # ==================== # + @abstractmethod + def process( + self, + compound: T1Document, + datapoints: Sequence[DataPoint], + t2_views: Sequence[T2DocView], + ) -> UBson | UnitResult: + """ + + Fit a model to the lightcurve of this transient. + See T2DemoLightcurveFitter + + Returns + ------- + dict + """ + + raise NotImplementedError + return None diff --git a/ampel/contrib/hu/t2/T2DemoLightcurveFitter.py b/ampel/contrib/hu/t2/T2DemoLightcurveFitter.py new file mode 100644 index 00000000..317ad010 --- /dev/null +++ b/ampel/contrib/hu/t2/T2DemoLightcurveFitter.py @@ -0,0 +1,87 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# File : ampel/contrib/hu/t2/T2DemoLightcurveFitter.py +# License : BSD-3-Clause +# Author : jnordin@physik.hu-berlin.de +# Date : 24.09.2021 +# Last Modified Date: 06.04.2022 +# Last Modified By : jnordin@physik.hu-berlin.de + +from collections.abc import Sequence + +# The following three only used if correcting for MW dust +import numpy as np + +from ampel.content.DataPoint import DataPoint +from ampel.content.T1Document import T1Document +from ampel.contrib.hu.t2.T2BaseLightcurveFitter import T2BaseLightcurveFitter +from ampel.struct.UnitResult import UnitResult +from ampel.types import UBson +from ampel.view.T2DocView import T2DocView + + +class T2DemoLightcurveFitter(T2BaseLightcurveFitter): + """ + + Demonstration class showing how methods of T2BaseLightcurveFitter can be used + develop a specific classifier. + + Variables of T2BaseLightcurveFitter determines how to load data and whether to + try to load a redshift. + + """ + + fit_order: int = 3 + + def post_init(self) -> None: + """ + Retrieve models and potentially dustmaps. + """ + + super().post_init() + + def process( + self, + compound: T1Document, + datapoints: Sequence[DataPoint], + t2_views: Sequence[T2DocView], + ) -> UBson | UnitResult: + """ + + Fit a model to the lightcurve of this transient. + Called for each state of each transient. 
+ + Returns + ------- + dict + """ + + # Initialize output dict + t2_output: dict[str, UBson] = { + "model": "demoFitter", + "order": self.fit_order, + } + + # Load photometry as table + # fitdatainfo contains redshift, if requested + (sncosmo_table, fitdatainfo) = self.get_fitdata(datapoints, t2_views) + t2_output.update(fitdatainfo) + if sncosmo_table is None: + return t2_output + + # Lightcurve encoded in sncosmo_table can now be fit to model. + # Fit parameters stored in t2_output and returned. + + for band in set(sncosmo_table["band"]): + i = sncosmo_table["band"] == band + if sum(i) <= (self.fit_order + 1): # Not enough data in band + continue + t2_output["polyfit" + band] = list( + np.polyfit( + sncosmo_table[i]["time"], + sncosmo_table[i]["flux"], + deg=self.fit_order, + ) + ) + + return t2_output diff --git a/ampel/contrib/hu/t2/T2DigestRedshifts.py b/ampel/contrib/hu/t2/T2DigestRedshifts.py index d44ec7e9..436c5a4b 100644 --- a/ampel/contrib/hu/t2/T2DigestRedshifts.py +++ b/ampel/contrib/hu/t2/T2DigestRedshifts.py @@ -12,16 +12,17 @@ import numpy as np -from ampel.abstract.AbsTiedLightCurveT2Unit import AbsTiedLightCurveT2Unit +from ampel.abstract.AbsTiedStateT2Unit import AbsTiedStateT2Unit +from ampel.content.DataPoint import DataPoint +from ampel.content.T1Document import T1Document from ampel.enum.DocumentCode import DocumentCode from ampel.model.StateT2Dependency import StateT2Dependency from ampel.struct.UnitResult import UnitResult from ampel.types import UBson -from ampel.view.LightCurve import LightCurve from ampel.view.T2DocView import T2DocView -class T2DigestRedshifts(AbsTiedLightCurveT2Unit): +class T2DigestRedshifts(AbsTiedStateT2Unit): """ Compare potential matches from different T2 units providing redshifts. @@ -55,6 +56,23 @@ class T2DigestRedshifts(AbsTiedLightCurveT2Unit): # "z_group": "which redshift group to assign to" } catalogmatch_override: None | dict[str, Any] + # Options for the get_redshift option + # T2MatchBTS : Use the redshift published by BTS and synced by that T2. + # T2DigestRedshifts : Use the best redshift as parsed by DigestRedshift. + # AmpelZ: equal to T2DigestRedshifts + # T2ElasticcRedshiftSampler: Use a list of redshifts and weights from the sampler. + # None : Use the fixed z value + redshift_kind: None | Literal[ + "T2MatchBTS", "T2DigestRedshifts", "T2ElasticcRedshiftSampler", "AmpelZ" + ] = None + + # It is also possible to use fixed redshift whenever a dynamic redshift kind is not possible + # This could be either a single value or a list + fixed_z: None | float | Sequence[float] = None + # Finally, the provided lens redshift might be multiplied with a scale + # Useful for lensing studies, or when trying multiple values + scale_z: None | float = None + # These are the units through which we look for redshifts # Which units should this be changed to t2_dependency: Sequence[ @@ -324,22 +342,15 @@ def _get_matchbts_groupz(self, t2_res: dict[str, Any]) -> list[list[float]]: return group_z - # ==================== # - # AMPEL T2 MANDATORY # - # ==================== # - def process( - self, light_curve: LightCurve, t2_views: Sequence[T2DocView] - ) -> UBson | UnitResult: + def get_ampelZ(self, t2_views: Sequence[T2DocView]) -> UBson | UnitResult: """ Parse t2_views from catalogs that were part of the redshift studies. Return these together with a "best estimate" - ampel_z - """ + Main method, separated to be used externally. 
- if not t2_views: # Should not happen actually, T2Processor catches that case - self.logger.error("Missing tied t2 views") - return UnitResult(code=DocumentCode.T2_MISSING_INFO) + """ # Loop through all potential T2s with redshift information. # Each should return an array of arrays, corresponding to redshift maches @@ -392,5 +403,101 @@ def process( if self.catalogmatch_override: t2_output["AmpelZ-Warning"] = "Override catalog in use." - self.logger.debug("digest redshift: %s" % (t2_output)) return t2_output + + def get_redshift( + self, t2_views + ) -> tuple[None | list[float], None | str, None | list[float]]: + """ + + Return a single or list of redshifts to be used. Not called in T2DigestRedshift.process + but provides interface to e.g. fit units. + + """ + + # Examine T2s for eventual information + z: None | list[float] = None + z_source: None | str = None + z_weights: None | list[float] = None + + if self.redshift_kind in [ + "T2DigestRedshifts", + "AmpelZ", + ]: + t2_res = self.get_ampelZ(t2_views) + if ( + isinstance(t2_res, dict) + and "ampel_z" in t2_res + and t2_res["ampel_z"] is not None + and t2_res["group_z_nbr"] <= self.max_redshift_category + ): + z = [float(t2_res["ampel_z"])] + z_source = "AMPELz_group" + str(t2_res["group_z_nbr"]) + elif self.redshift_kind in [ + "T2MatchBTS", + "T2DigestRedshifts", + "T2ElasticcRedshiftSampler", + ]: + for t2_view in t2_views: + if t2_view.unit != self.redshift_kind: + continue + self.logger.debug(f"Parsing t2 results from {t2_view.unit}") + t2_res = ( + res[-1] if isinstance(res := t2_view.get_payload(), list) else res + ) + # Parse this + if self.redshift_kind == "T2MatchBTS": + if ( + isinstance(t2_res, dict) + and "bts_redshift" in t2_res + and t2_res["bts_redshift"] != "-" + ): + z = [float(t2_res["bts_redshift"])] + z_source = "BTS" + elif self.redshift_kind == "T2ElasticcRedshiftSampler" and isinstance( + t2_res, dict + ): + z = t2_res["z_samples"] + z_source = t2_res["z_source"] + z_weights = t2_res["z_weights"] + # Check if there is a fixed z set for this run, otherwise keep as free parameter + elif self.fixed_z is not None: + if isinstance(self.fixed_z, float): + z = [self.fixed_z] + else: + z = list(self.fixed_z) + z_source = "Fixed" + else: + z = None + z_source = "Fitted" + + if (z is not None) and (z_source is not None) and self.scale_z: + z = [onez * self.scale_z for onez in z] + z_source += f" + scaled {self.scale_z}" + + return z, z_source, z_weights + + # ==================== # + # AMPEL T2 MANDATORY # + # ==================== # + def process( + self, + compound: T1Document, + datapoints: Sequence[DataPoint], + t2_views: Sequence[T2DocView], + ) -> UBson | UnitResult: + # def process( + # self, light_curve: LightCurve, t2_views: Sequence[T2DocView] + # ) -> UBson | UnitResult: + """ + + Parse t2_views from catalogs that were part of the redshift studies. + Return these together with a "best estimate" - ampel_z + + """ + + if not t2_views: # Should not happen actually, T2Processor catches that case + self.logger.error("Missing tied t2 views") + return UnitResult(code=DocumentCode.T2_MISSING_INFO) + + return self.get_ampelZ(t2_views) diff --git a/ampel/contrib/hu/t2/T2InfantCatalogEval.py b/ampel/contrib/hu/t2/T2InfantCatalogEval.py index 1ea39b5c..d49085ad 100644 --- a/ampel/contrib/hu/t2/T2InfantCatalogEval.py +++ b/ampel/contrib/hu/t2/T2InfantCatalogEval.py @@ -84,6 +84,8 @@ class T2InfantCatalogEval(AbsTiedLightCurveT2Unit): rb_minmed: float = 0.3 # Minimal median RB. 
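# Referring back to the get_redshift helper added to T2DigestRedshifts above: a hedged
# sketch of how an inheriting state unit can consume it (the same pattern used by
# T2BaseLightcurveFitter.get_fitdata); the class name and the returned keys here are
# illustrative only.
from ampel.contrib.hu.t2.T2DigestRedshifts import T2DigestRedshifts


class T2UsesAmpelZ(T2DigestRedshifts):
    def process(self, compound, datapoints, t2_views):
        z, z_source, z_weights = self.get_redshift(t2_views)
        if z is None or z_source is None:
            # No catalog or fixed redshift available (no free-fit fallback chosen here)
            return {"z": None}
        return {"z": z, "z_source": z_source, "z_weights": z_weights}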
drb_minmed: float = 0.995 + # Minimal pull w.r.t to image magnitude limit (i.e. (diffmaglim-mag)/magerr)) + min_magpull: float = 0.0 # Limiting magnitude to consider upper limits as 'significant' maglim_min: float = 19.5 @@ -146,9 +148,9 @@ def inspect_catalog(self, cat_res: dict[str, Any]) -> None | dict[str, Any]: # Special catalog searches - mark transients close to AGNs milliquas = cat_res.get("milliquas", False) sdss_spec = cat_res.get("SDSS_spec", False) - if milliquas and milliquas["redshift"] > 0: + if milliquas and milliquas.get("redshift", -1) > 0: info["milliAGN"] = True - if sdss_spec and sdss_spec["bptclass"] in [4, 5]: + if sdss_spec and sdss_spec.get("bptclass", -99) in [4, 5]: info["sdssAGN"] = True # Return collected info @@ -188,6 +190,20 @@ def inspect_lc(self, lc: LightCurve) -> None | dict[str, Any]: return None info["age"] = age + # cut on pull compared with image diffmaglim + magpull = sum( + [ + (pp["body"]["diffmaglim"] - pp["body"]["magpsf"]) + / pp["body"]["sigmapsf"] + for pp in pps + ] + ) + info["mag_pull"] = magpull + # cut on which filters used + if magpull < self.min_magpull: + self.logger.debug("Rejected", extra={"mag_pull": magpull}) + return None + # cut on number of detection after last SIGNIFICANT UL ulims = lc.get_upperlimits( filters={ diff --git a/ampel/contrib/hu/t2/T2PolynomialFit.py b/ampel/contrib/hu/t2/T2PolynomialFit.py new file mode 100644 index 00000000..64436067 --- /dev/null +++ b/ampel/contrib/hu/t2/T2PolynomialFit.py @@ -0,0 +1,81 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# File : ampel/contrib/hu/t2/T2PolynomialFit.py +# License : BSD-3-Clause +# Author : jnordin@physik.hu-berlin.de +# Date : 29.03.2024 +# Last Modified Date: 29.03.2024 +# Last Modified By : jnordin@physik.hu-berlin.de + +from collections.abc import Iterable +from typing import Any + +import numpy as np +from astropy.table.column import Column + +from ampel.abstract.AbsStateT2Unit import AbsStateT2Unit +from ampel.abstract.AbsTabulatedT2Unit import AbsTabulatedT2Unit +from ampel.content.DataPoint import DataPoint +from ampel.content.T1Document import T1Document +from ampel.struct.UnitResult import UnitResult +from ampel.types import UBson + + +class T2PolynomialFit(AbsStateT2Unit, AbsTabulatedT2Unit): + """ + Fit polynomial to each band. + Return coefficients and chi2/dof + """ + + order: int = 1 + + def eval_polyfit( + self, time: Column, flux: Column, flux_err: Column + ) -> dict[str, Any]: + """ + Fit and evaluate a polynomical fit to input data. + """ + pfit = np.polynomial.Polynomial.fit(time, flux, self.order, w=1 / flux_err) + dof = len(time) - self.order - 1 + + return { + "chi2dof": sum((flux - pfit(time)) ** 2 / flux_err**2) / dof, + "dof": dof, + "coef": list(pfit.convert().coef), + } + + def process( + self, + compound: T1Document, + datapoints: Iterable[DataPoint], + ) -> UBson | UnitResult: + """ + Executed for each transient state. 
+ """ + + # Obtain photometric table + flux_table = self.get_flux_table(datapoints) + + # Fit polynmial to each band + fitinfo: dict[str, Any] = {} + for band in set(flux_table["band"]): + i = flux_table["band"] == band + if sum(i) <= (self.order + 1): # Not enough data in band + continue + fitinfo[band] = self.eval_polyfit( + flux_table[i]["time"], flux_table[i]["flux"], flux_table[i]["fluxerr"] + ) + + # Average over bands + if len(fitinfo) > 0: + for o in range(self.order + 1): + fitinfo[f"p{o}"] = np.mean( + [v["coef"][o] for k, v in fitinfo.items() if isinstance(v, dict)] + ) + fitinfo["chi2dof"] = np.mean( + [v["chi2dof"] for k, v in fitinfo.items() if isinstance(v, dict)] + ) + fitinfo["order"] = self.order + return fitinfo + + return {"order": self.order, "chi2dof": None} diff --git a/ampel/contrib/hu/t2/T2TNSEval.py b/ampel/contrib/hu/t2/T2TNSEval.py index 705eb2be..cc8bdd4e 100644 --- a/ampel/contrib/hu/t2/T2TNSEval.py +++ b/ampel/contrib/hu/t2/T2TNSEval.py @@ -15,7 +15,7 @@ # T2 importing info from T3. Restructure? from ampel.abstract.AbsTiedLightCurveT2Unit import AbsTiedLightCurveT2Unit -from ampel.contrib.hu.t3.ampel_tns import ( +from ampel.contrib.hu.t3.tns.tns_ampel_util import ( TNSFILTERID, ) from ampel.struct.UnitResult import UnitResult diff --git a/ampel/contrib/hu/t3/AstroColibriPublisher.py b/ampel/contrib/hu/t3/AstroColibriPublisher.py index cb1f5f65..2243e820 100644 --- a/ampel/contrib/hu/t3/AstroColibriPublisher.py +++ b/ampel/contrib/hu/t3/AstroColibriPublisher.py @@ -12,7 +12,6 @@ # from itertools import islice import os -import re from collections.abc import Generator import numpy as np @@ -78,6 +77,22 @@ class AstroColibriPublisher(AbsPhotoT3Unit): Will update if new obs was made after last posting. + AC requests: + + + - Can you make sure that the event time corresponds to the one submitted to TNS? + - We suggest a slight renaming of the events: + x * source_name: use the name of the event given by TNS and add "(Ampel)" to it. Example: TNS name "AT 2024edy" => "AT 2024edy (Ampel)" + x * trigger_id: keep as it is + x * we suggest you fill the ZTF identifier (e.g. ZTF24aahbwis) that you currently use a source_name into the field "discoverer_internal_name" + x Type: for events that are not submitted as classified to TNS (i.e. listed as AT in TNS), please change the value of the "type" parameter to "ot" (instead of "ot_sn") + * Classification: you can add the Ampel classification into the field "classification". E.g. if these are supernova candidates use "SN" or any classification that one could also find on TNS + + + + + + """ # Limits for attributes @@ -156,23 +171,6 @@ def submitted(self, view: "TransientView") -> bool: ] ) - def get_tnsname(self, view: "TransientView") -> bool: - # Was transient (successfully pushed) - if not view.stock: - return False - # Check whether the name is found in the name collection - if len(names := view.stock.get("name", [])) > 0: - # Will only be able to require TNS name through format - # dddd - for name in names: - if re.search(r"\d{4}\D{3}\D?", name): - return name - - # Should we look through the Journal for entries from the TNSTalker? - # It *should* also save these entries to name so should not be needed... 
- - return None - def process( self, tviews: Generator[TransientView, JournalAttributes, None], @@ -184,41 +182,16 @@ def process( """ for tview in tviews: - if self.randname: - import random - - # Generate random name (assuming publishing to dev AC) - tns_name = f"AmpelRand{random.randint(1, 999)}" - else: - # Find TNS name (required for AstroColibri posting) - # Currently assumes that this is stored either in the - # stock name list or can be found in the T3 journal - # (probably from the TNSTalker) - tns_name = self.get_tnsname(tview) - if not tns_name: - self.logger.info("No TNS.", extra={"tnsName": None}) - continue - - # Check if this was submitted - # TODO: How should the first submit differ from updates? - if self.submitted(tview): - # Check if it needs an update - if self.requires_update(tview): - post_update = True - else: - continue - else: - post_update = False - - # Gather general information + # Gather general information, including coordinate payload = { - "type": "ot_sn", # for optical? when to change to ot_sn? + "type": "ot", "observatory": "ztf", - "source_name": to_ztf_id(int(tview.id)), + # "source_name": to_ztf_id(int(tview.id)), + "discoverer_internal_name": to_ztf_id(int(tview.id)), # 'trigger_id': self.trigger_id+':'+str(tview.id), # How do these work? - "trigger_id": "TNS" + tns_name, + # "trigger_id": "TNS" + tns_name, # 'ivorn': self.trigger_id+':'+str(tview.id), # Need ivorn schema - "timestamp": Time.now().iso, + # "timestamp": Time.now().iso, } # Gather photometry based information @@ -235,6 +208,38 @@ def process( payload["dec"] = np.mean([pp["body"]["dec"] for pp in dps_det]) payload["err"] = 1.0 / 3600 # position err ~1 arcsec in dec + # Find TNS name + if self.randname: + import random + + # Generate random name (assuming publishing to dev AC) + tns_name = f"AmpelRand{random.randint(1, 999)}" + tns_submission_time = Time.now().iso + elif isinstance(tview.extra, dict) and "TNSReports" in tview.extra: + # A tns name is required, here obtained from the mirror DB through a T3 complement + tnsreport = next(iter(tview.extra["TNSReports"])) + tns_name = tnsreport["objname"] + tns_submission_time = tnsreport["discoverydate"] + print("FOUND TNS STUFF", tns_name, tns_submission_time) + else: + self.logger.debug("No TNS name", extra={"tnsName": None}) + continue + + payload["trigger_id"] = "TNS" + tns_name + payload["source_name"] = tns_name + " (AMPEL)" + payload["time"] = tns_submission_time + + # Check if this was submitted + # TODO: How should the first submit differ from updates? 
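# For reference, a hedged sketch of the payload assembled above for a single transient.
# The keys follow the assignments in this method; the ZTF and TNS names reuse the examples
# quoted in the class docstring, and all values are illustrative.
example_payload = {
    "type": "ot",
    "observatory": "ztf",
    "discoverer_internal_name": "ZTF24aahbwis",  # internal ZTF name
    "trigger_id": "TNS2024edy",                  # "TNS" + TNS object name
    "source_name": "2024edy (AMPEL)",
    "time": "2024-03-18 06:58:33",               # TNS discovery date
    "ra": 150.11,
    "dec": 2.21,
    "err": 1.0 / 3600,                           # ~1 arcsec positional uncertainty
}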
+ if self.submitted(tview): + # Check if it needs an update + if self.requires_update(tview): + post_update = True + else: + continue + else: + post_update = False + # If part of random testing, perturb coordinates if self.randname: payload["ra"] += random.randrange(-1, 1) @@ -242,16 +247,16 @@ def process( payload["source_name"] += payload["trigger_id"][12:] # Gather attributes - attributes = [] + attributes = {"classification": {}, "ampelProp": []} # Nearby attribute t2res = tview.get_t2_body(unit="T2DigestRedshifts") if isinstance(t2res, dict) and t2res.get("ampel_z", 999) < self.nearby_z: - attributes.append("Nearby") - attributes.append("AmpelZ{:.2f}".format(t2res["ampel_z"])) + attributes["ampelProp"].append("Nearby") + attributes["ampelProp"].append("AmpelZ{:.2f}".format(t2res["ampel_z"])) # Infant attribute t2res = tview.get_t2_body(unit="T2InfantCatalogEval") if isinstance(t2res, dict) and t2res.get("action", False): - attributes.append("Young") + attributes["ampelProp"].append("Young") # SNIa t2res = tview.get_t2_body(unit="T2RunParsnip") if ( @@ -259,27 +264,37 @@ def process( and "classification" in t2res and t2res["classification"]["SNIa"] > self.snia_minprob ): - attributes.append("ProbSNIa") + attributes["ampelProp"].append("ProbSNIa") + attributes["classification"] = { + "class": ["SNIa", "Other"], + "prob": [ + t2res["classification"]["SNIa"], + 1 - t2res["classification"]["SNIa"], + ], + } # Kilonovaness t2res = tview.get_t2_body(unit="T2KilonovaEval") if ( isinstance(t2res, dict) and t2res.get("kilonovaness", -99) > self.min_kilonovaness ): - attributes.append("Kilonovaness{}".format(t2res["kilonovaness"])) - attributes.append("LVKmap{}".format(t2res["map_name"])) + attributes["ampelProp"].append( + "Kilonovaness{}".format(t2res["kilonovaness"]) + ) + attributes["ampelProp"].append("LVKmap{}".format(t2res["map_name"])) # Check whether we have a figure to upload. 
# Assuming this exists locally under {stock}.png if self.image_path is not None: ipath = self.image_path.replace("ZTFNAME", to_ztf_id(int(tview.id))) + ipath = self.image_path.replace("STOCK", str(tview.id)) # Only upload if it actually exists: if not os.path.isfile(ipath): ipath = None else: ipath = None - payload["ampel_attributes"] = attributes + payload["broker_attributes"] = attributes self.logger.debug("reacting", extra={"payload": payload}) # Ok, so we have a transient to react to @@ -300,6 +315,9 @@ def process( } else: + print("colibri submitting") + print(payload) + print(ipath) jcontent = self.colibriclient.firestore_post(payload, image_path=ipath) if jcontent: diff --git a/ampel/contrib/hu/t3/PlotTransientLightcurves.py b/ampel/contrib/hu/t3/PlotTransientLightcurves.py new file mode 100755 index 00000000..9445589d --- /dev/null +++ b/ampel/contrib/hu/t3/PlotTransientLightcurves.py @@ -0,0 +1,535 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# File: Ampel-contrib-HU/ampel/contrib/hu/t3/PlotTransientLightcurves.py +# License: BSD-3-Clause +# Author: valery brinnel +# Date: 11.06.2018 +# Last Modified Date: 30.07.2021 +# Last Modified By: valery brinnel + +import gzip +import io +import os +from collections.abc import Generator, Iterable +from typing import Any + +import matplotlib.pyplot as plt +import numpy as np +from astropy import units as u +from astropy import visualization +from astropy.cosmology import FlatLambdaCDM +from astropy.io import fits +from astropy.table import Table +from astropy.time import Time +from matplotlib.backends.backend_pdf import PdfPages +from matplotlib.colors import Normalize +from matplotlib.ticker import MultipleLocator +from slack import WebClient # type: ignore +from ztfquery.utils.stamps import get_ps_stamp # type: ignore + +from ampel.abstract.AbsPhotoT3Unit import AbsPhotoT3Unit +from ampel.abstract.AbsTabulatedT2Unit import AbsTabulatedT2Unit +from ampel.content.DataPoint import DataPoint +from ampel.secret.NamedSecret import NamedSecret +from ampel.struct.T3Store import T3Store +from ampel.struct.UnitResult import UnitResult +from ampel.types import T3Send, UBson +from ampel.view.TransientView import TransientView + + +# Base structure from nuztf (Stein, Reusch) +def fig_from_fluxtable( + name: str, + ampelid: str, + ra: float, + dec: float, + fluxtable: Table, + ztfulims: Table | None = None, + figsize: tuple = (8, 5), + title: None | str = None, + tnsname: str | None = None, + fritzlink: bool = True, + attributes: list = [], # noqa: B006 + cutouts: dict | None = None, + mag_range: None | list = None, + z: float | None = None, + zp: float = 25.0, + legend: bool = False, + grid_interval: None | int = None, + t_0_jd: None | float = None, + cutout_cache_dir: str = ".", +): + """ + Create a lightcurve figure (in mag space) based on + flux tables from AMPEL + """ + + cosmo = FlatLambdaCDM(H0=70, Om0=0.3) + + # Transform to mag (option to plot flux?). Will fail for negative values... 
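# (The conversion below puts the tabulated fluxes on the AB magnitude scale of the given
#  zeropoint, mag = zp - 2.5*log10(flux) with zp defaulting to 25, and propagates the flux
#  uncertainty as magerr = (2.5/ln 10) * fluxerr/flux; non-detections and negative fluxes
#  are expected to be handled separately, e.g. through the upper-limit table.)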
+ fluxtable["mag"] = -2.5 * np.log10(fluxtable["flux"]) + zp + fluxtable["magerr"] = np.abs( + -2.5 * fluxtable["fluxerr"] / (fluxtable["flux"] * np.log(10)) + ) + + if z is not None and np.isnan(z): + z = None + + # Keys as provided by tabulators + BANDPASSES = { + "ztfg": {"label": "ZTF g", "c": "green"}, + "ztfr": {"label": "ZTF R", "c": "red"}, + "ztfi": {"label": "ZTF i", "c": "orange"}, + } + + fig = plt.figure(figsize=figsize) + + # Prepare plot sections + if cutouts: + lc_ax1 = fig.add_subplot(5, 4, (9, 19)) + cutoutsci = fig.add_subplot(5, 4, (1, 5)) + cutouttemp = fig.add_subplot(5, 4, (2, 6)) + cutoutdiff = fig.add_subplot(5, 4, (3, 7)) + cutoutps1 = fig.add_subplot(5, 4, (4, 8)) + else: + lc_ax1 = fig.add_subplot(1, 1, 1) + fig.subplots_adjust(top=0.8, bottom=0.15) + + plt.subplots_adjust(wspace=0.4, hspace=1.8) + + if cutouts: + for ax_, type_ in zip( + [cutoutsci, cutouttemp, cutoutdiff], + ["Science", "Template", "Difference"], + strict=False, + ): + create_stamp_plot(cutouts=cutouts, ax=ax_, cutout_type=type_) + + img_cache = os.path.join(cutout_cache_dir, f"{name}_PS1.png") + + if not os.path.isfile(img_cache): + img = get_ps_stamp(ra, dec, size=240, color=["y", "g", "i"]) + img.save(img_cache) + + else: + from PIL import Image + + img = Image.open(img_cache) + + cutoutps1.imshow(np.asarray(img)) + cutoutps1.set_title("PS1", fontdict={"fontsize": "small"}) + cutoutps1.set_xticks([]) + cutoutps1.set_yticks([]) + + # If redshift is given, calculate absolute magnitude via luminosity distance + # and plot as right axis + if z is not None: + dist_l = cosmo.luminosity_distance(z).to(u.pc).value + + def mag_to_absmag(mag): + return mag - 5 * (np.log10(dist_l) - 1) + + def absmag_to_mag(absmag): + return absmag + 5 * (np.log10(dist_l) - 1) + + lc_ax3 = lc_ax1.secondary_yaxis( + "right", functions=(mag_to_absmag, absmag_to_mag) + ) + + if not cutouts: + lc_ax3.set_ylabel("Absolute Magnitude [AB]") + + # Give the figure a title + if not cutouts: + if title is None: + fig.suptitle(f"{name}", fontweight="bold") + else: + fig.suptitle(title, fontweight="bold") + + if grid_interval is not None: + lc_ax1.xaxis.set_major_locator(MultipleLocator(grid_interval)) + + lc_ax1.grid(visible=True, axis="both", alpha=0.5) + lc_ax1.set_ylabel("Magnitude [AB]") + + if not cutouts: + lc_ax1.set_xlabel("JD") + + # Determine magnitude limits + if mag_range is None: + max_mag = np.max(fluxtable["mag"]) + 0.3 + min_mag = np.min(fluxtable["mag"]) - 0.3 + lc_ax1.set_ylim((max_mag, min_mag)) + else: + lc_ax1.set_ylim((np.max(mag_range), np.min(mag_range))) + + for fid in BANDPASSES: + # Plot older datapoints + tempTab = fluxtable[fluxtable["band"] == fid] + lc_ax1.errorbar( + tempTab["time"], + tempTab["mag"], + tempTab["magerr"], + color=BANDPASSES[fid]["c"], + fmt=".", + label=BANDPASSES[fid]["label"], + mec="black", + mew=0.5, + ) + + # Plot upper limits + if ztfulims is not None: + tempTab = ztfulims[ztfulims["band"] == fid] + lc_ax1.scatter( + tempTab["time"], + tempTab["diffmaglim"], + c=BANDPASSES[fid]["c"], + marker="v", + s=20.0, + alpha=0.5, + ) + + if legend: + plt.legend() + + # Now we create an infobox + if cutouts: + info = [] + + info.append(name) + info.append(f"RA: {ra:.8f}") + info.append(f"Dec: {dec:.8f}") + info.append("------------------------") + + fig.text(0.77, 0.55, "\n".join(info), va="top", fontsize="medium", alpha=0.5) + + # Add annotations + # Frits + if fritzlink: + lc_ax1.annotate( + "See On Fritz", + xy=(0.5, 1), + xytext=(0.78, 0.10), + xycoords="figure fraction", + 
verticalalignment="top", + color="royalblue", + url=f"https://fritz.science/source/{name}", + fontsize=12, + bbox=dict(boxstyle="round", fc="cornflowerblue", ec="royalblue", alpha=0.4), + ) + # TNS + if tnsname is not None: + lc_ax1.annotate( + "See On TNS", + xy=(0.5, 1), + xytext=(0.78, 0.05), + xycoords="figure fraction", + verticalalignment="top", + color="royalblue", + url=f"https://www.wis-tns.org/object/{tnsname}", + fontsize=12, + bbox=dict(boxstyle="round", fc="cornflowerblue", ec="royalblue", alpha=0.4), + ) + + # Catalog info through T2CatalogMatch results. Add to info above, or keep some of it? At least check TNS somehow + if len(attributes) > 0: + ypos = 0.975 if cutouts else 0.035 + fig.text( + 0.5, + ypos, + " - ".join(attributes), + va="top", + ha="center", + fontsize="medium", + alpha=0.5, + ) + + if t_0_jd is not None: + lc_ax1.axvline(t_0_jd, linestyle=":") + else: + t_0_jd = np.mean(fluxtable["time"]) + + # Ugly hack because secondary_axis does not work with astropy.time.Time datetime conversion + jd_min = min(np.min(fluxtable["time"]), t_0_jd) + if ztfulims is not None: + jd_min = min(np.min(ztfulims["time"]), jd_min) + jd_max = max(np.max(fluxtable["time"]), t_0_jd) + length = jd_max - jd_min + + lc_ax1.set_xlim((jd_min - (length / 20), jd_max + (length / 20))) + + lc_ax2 = lc_ax1.twiny() + + lc_ax2.scatter( # type: ignore + [Time(x, format="jd").datetime for x in [jd_min, jd_max]], [20, 20], alpha=0 + ) + lc_ax2.tick_params(axis="both", which="major", labelsize=6, rotation=45) + lc_ax1.tick_params(axis="x", which="major", labelsize=6, rotation=45) + lc_ax1.ticklabel_format(axis="x", style="plain") + lc_ax1.tick_params(axis="y", which="major", labelsize=9) + + if z is not None: + lc_ax3.tick_params(axis="both", which="major", labelsize=9) + + axes = [lc_ax1, lc_ax2, lc_ax3] if z is not None else [lc_ax1, lc_ax2] + + return fig, axes + + +def create_stamp_plot(cutouts: dict, ax, cutout_type: str): + """ + Helper function to create cutout subplot. + Cutouts assumed to be a dict of the type returned from + the ZTFCutoutImages complement: + {'candid': {'cutoutScience': b'..', 'cutoutTemplate': ...} } + + Grabbing images of the first candid available. + + cutout_type assumed to be one of Science, Template, Difference + """ + + data = next(iter(cutouts.values()))[f"cutout{cutout_type}"] + + with gzip.open(io.BytesIO(data), "rb") as f: + data = fits.open(io.BytesIO(f.read()), ignore_missing_simple=True)[0].data + vmin, vmax = np.percentile(data[data == data], [0, 100]) # noqa: PLR0124 + data_ = visualization.AsinhStretch()((data - vmin) / (vmax - vmin)) + ax.imshow( + data_, + norm=Normalize(*np.percentile(data_[data_ == data_], [0.5, 99.5])), # noqa: PLR0124 + aspect="auto", + ) + ax.set_xticks([]) + ax.set_yticks([]) + ax.set_title(cutout_type, fontdict={"fontsize": "small"}) + + +# Should be added to ZTFT2Tabulator? 
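# The helper below gathers ZTF non-detections (datapoints with a negative id that carry a
# "diffmaglim" in their body) into an astropy Table, also expressing each limit as a flux
# limit on the same zp = 25 scale used for detections: fluxlim = 10**(-(diffmaglim - 25) / 2.5).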
+def get_upperlimit_table( + # self, + dps: Iterable[DataPoint] | None, +) -> Table | None: + if not dps: + return None + + def filter_limits(dps: Iterable[DataPoint]) -> list[dict]: + return [ + dp["body"] + for dp in dps + if dp["id"] < 0 and "ZTF" in dp["tag"] and "diffmaglim" in dp["body"] + ] + + ZTF_BANDPASSES = { + 1: {"name": "ztfg"}, + 2: {"name": "ztfr"}, + 3: {"name": "ztfi"}, + } + + dps_subset = filter_limits(dps) + + if len(dps_subset) == 0: + return None + + filter_names = [ZTF_BANDPASSES[dp["fid"]]["name"] for dp in dps_subset] + # fluxlim = np.asarray([10 ** (-((dp["diffmaglim"]) - 25) / 2.5) for dp in dps_subset]) + + tab = Table( + { + "time": [dp["jd"] for dp in dps_subset], + "diffmaglim": [dp["diffmaglim"] for dp in dps_subset], + "band": [ZTF_BANDPASSES[dp["fid"]]["name"] for dp in dps_subset], + "fluxlim": np.asarray( + [10 ** (-((dp["diffmaglim"]) - 25) / 2.5) for dp in dps_subset] + ), + "zp": [25] * len(filter_names), + "zpsys": ["ab"] * len(filter_names), + }, + dtype=("float64", "float64", "str", "float64", "float", "str"), + ) + + return tab # noqa: RET504 + + +class PlotTransientLightcurves(AbsPhotoT3Unit, AbsTabulatedT2Unit): + """ + + Create a (pdf) plot summarizing lightcurves of candidates provided to the unit. + Features: + - Include thumbnails (if provided through the ZTFCutoutImages T3 complement. + - Include link to TNS (if match existing and provided through TNSNames T3 complement. + - Upload to slack channel (if token provided) + + + """ + + # Default path is to create a multi-page pdf + pdf_path: None | str = None # Will create random if not set + titleprefix: str = "AMPEL: " + # Optionally, save a {stock}.png image of each individual event + save_png: bool = False + # Dir for saving png (thumbnails + single event if chosen) + image_cache_dir: str = "." + + # Should ZTF cutouts be retrieved (requires remote archive access) + include_cutouts: bool = False + + # Add Fritz link to plot + fritzlink: bool = True + + # Will post result to Slack channel if a slack channel and a NamedSecret containig the corresponding token is given + slack_channel: str | None = None + slack_token: NamedSecret[str] = NamedSecret(label="slack/ztf_general/jno") + + def post_init(self) -> None: + # Create temporary path if not set + if not self.pdf_path: + import tempfile + + self.pdf_path = tempfile.mkstemp(".pdf", "PlotTransientLightcurves", ".")[1] + + # Possibly create a slack client + if self.slack_channel and self.slack_token is not None: + self.webclient = WebClient(self.slack_token.get()) + + def attributes_from_t2( + self, + tview: TransientView, + nearby_z: float = 0.02, + snia_minprob: float = 0.7, + min_kilonovaness=5, + ) -> tuple[list, Any]: + """ + Collect information from potential T2 documents, + return as list of str. + Partially copied from AstroColibriPublisher. TODO: Join as util function. + Redshift gets a special treatment, since its ... 
speical + """ + + attributes = [] + z = None + # Nearby attribute + t2res = tview.get_t2_body(unit="T2DigestRedshifts") + if isinstance(t2res, dict) and t2res.get("ampel_z", -10) > 0: + attributes.append( + "AmpelZ{:.3f} N{}".format(t2res["ampel_z"], t2res["group_z_nbr"]) + ) + z = t2res["ampel_z"] + if t2res.get("ampel_z", 999) < nearby_z: + attributes.append("Nearby") + # Infant attribute + t2res = tview.get_t2_body(unit="T2InfantCatalogEval") + if isinstance(t2res, dict) and t2res.get("action", False): + attributes.append("InfantEval") + # SNIa + t2res = tview.get_t2_body(unit="T2RunParsnip") + if ( + isinstance(t2res, dict) + and "classification" in t2res + and t2res["classification"]["SNIa"] > snia_minprob + ): + attributes.append("ProbSNIa") + # Kilonovaness + t2res = tview.get_t2_body(unit="T2KilonovaEval") + if ( + isinstance(t2res, dict) + and t2res.get("kilonovaness", -99) > min_kilonovaness + ): + attributes.append("Kilonovaness{}".format(t2res["kilonovaness"])) + attributes.append("LVKmap{}".format(t2res["map_name"])) + # Linearfit + t2res = tview.get_t2_body(unit="T2LineFit") + if isinstance(t2res, dict) and t2res.get("chi2dof", None) is not None: + attributes.append("LinearChi/dof{:.2}".format(t2res["chi2dof"])) + + t2res = tview.get_t2_body(unit="T2KilonovaStats") + if isinstance(t2res, dict): + attributes.append("PercentHigher{:.5f}".format(t2res["gaus_perc"])) + attributes.append( + "ExpectedCands{:.1f}+{:.1f}-{:.1f}".format( + t2res["exp_kn"], t2res["exp_kn_pls"], t2res["exp_kn_min"] + ) + ) + attributes.append("DistanceRange{}".format(t2res["dist_range"])) + return (attributes, z) + + def process( + self, gen: Generator[TransientView, T3Send, None], t3s: None | T3Store = None + ) -> UBson | UnitResult: + with PdfPages(self.pdf_path) as pdf: + for tran_view in gen: + if not tran_view.get_photopoints(): + self.logger.debug("No photopoints", extra={"stock": tran_view.id}) + continue + sncosmo_table = self.get_flux_table(tran_view.get_photopoints()) # type: ignore + + # Collect information + ampelid = tran_view.id + (ra, dec) = self.get_pos(tran_view.get_photopoints()) # type: ignore + name = " ".join( + map(str, self.get_stock_name(tran_view.get_photopoints())) # type: ignore + ) + + # Upper limits (from ZTF) + # Could immediately subselect to ZTF limits, but keep like this to shift into tabulators + ulim_table = get_upperlimit_table(tran_view.get_upperlimits()) + + # Title + title = f"{self.titleprefix}: {name!r}-{ampelid!r} @ RA {ra:.3f} Dec {dec:.3f}" + + # Gatter attributes from potential T2 documents + (attributes, z) = self.attributes_from_t2(tran_view) + + # Check if ZTF name exists in TNS mirror archive + tnsname = None + if ( + isinstance(tran_view.extra, dict) + and "TNSReports" in tran_view.extra + ): + tnsname = next(iter(tran_view.extra["TNSReports"])).get( + "objname", None + ) + + if ( + self.include_cutouts + and tran_view.extra + and (cutouts := tran_view.extra.get("ZTFCutoutImages", None)) + is not None + ): + # Complement cutouts worked + pass + else: + cutouts = None + + # Create plot + fig, axes = fig_from_fluxtable( + name, + str(ampelid), + ra, + dec, + sncosmo_table, + ulim_table, + title=title, + attributes=attributes, + fritzlink=self.fritzlink, + tnsname=tnsname, + z=z, + cutouts=cutouts, + cutout_cache_dir=self.image_cache_dir, + ) + pdf.savefig() + plt.savefig( + os.path.join(self.image_cache_dir, str(tran_view.id) + ".png") + ) + plt.close() + + # Post to slack + if self.slack_channel is not None and self.slack_token is not None: + with 
open(self.pdf_path, "rb") as file: # type: ignore + self.webclient.files_upload( + file=file, # type: ignore + filename=self.pdf_path, + channels=self.slack_channel, + # thread_ts=self.ts, + ) + + return None diff --git a/ampel/contrib/hu/t3/SubmitTNS.py b/ampel/contrib/hu/t3/SubmitTNS.py new file mode 100644 index 00000000..3a361db9 --- /dev/null +++ b/ampel/contrib/hu/t3/SubmitTNS.py @@ -0,0 +1,191 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# File: Ampel-contrib-HU/ampel/contrib/hu/t3/SubmitTNS.py +# License: BSD-3-Clause +# Author: jnordin@physik.hu-berlin.de +# Date: 1.03.2024 +# Last Modified Date: 1.03.2024 +# Last Modified By: jnordin@physik.hu-berlin.de + +import asyncio +import time +from collections.abc import Generator +from typing import Any + +from ampel.abstract.AbsPhotoT3Unit import AbsPhotoT3Unit +from ampel.contrib.hu.t3.tns.tns_ampel_util import get_tns_t2remarks, ztfdps_to_tnsdict +from ampel.contrib.hu.t3.tns.TNSClient import TNSClient +from ampel.contrib.hu.t3.tns.TNSToken import TNSToken +from ampel.secret.NamedSecret import NamedSecret +from ampel.struct.T3Store import T3Store +from ampel.struct.UnitResult import UnitResult +from ampel.types import StockId, T3Send, UBson +from ampel.view.TransientView import TransientView + + +class SubmitTNS(AbsPhotoT3Unit): + """ + Submit candidates to TNS (unless already submitted). + + Note that it is assumed that all selected transients are to be submitted. + """ + + # AT report config + base_at_dict: dict = { + "reporting_group_id": "82", # Should be ampel + "discovery_data_source_id": "48", + "reporter": "J. Nordin, V. Brinnel, J. van Santen (HU Berlin), A. Gal-Yam, O. Yaron (Weizmann) on behalf of ZTF", + "at_type": "1", + } + baseremark: str = "See arXiv:1904.05922 for selection criteria." + + # Connect information + tns_key: NamedSecret[dict] + timeout: float = 120.0 + max_parallel_requests: int = 8 + maxdist: float = 2.0 # max squared dist, in arcsec. + tns_doublecheck: bool = True # Also do a TNS name search - is this needed? + tns_submit: bool = False # Also do a TNS name search - is this needed? + + def post_init(self) -> None: + self.client = TNSClient( + TNSToken(**self.tns_key.get()), + self.timeout, + self.max_parallel_requests, + self.logger, + ) + + async def get_tns_names(self, ra, dec): + names = [] + async for doc in self.client.search( + ra=ra, dec=dec, radius=self.maxdist, units="arcsec" + ): + names.extend(doc["internal_names"].split(", ")) + return names + + def sendReports(self, reports: list[dict]) -> dict: + """ + Based on a lists of reportlists, send to TNS. + Return results for journal entries + """ + MAX_LOOP = 25 + SLEEP = 2 + + reportresult: dict = {"inserted": [], "existing": []} + for atreport in reports: + # Submit a report + for _ in range(MAX_LOOP): + reportid = asyncio.run(self.client.sendReport(atreport)) + if reportid: + break + time.sleep(SLEEP) + else: + self.logger.info("TNS Report sending failed") + continue + + # Try to read reply + for _ in range(MAX_LOOP): + time.sleep(SLEEP) + response = asyncio.run(self.client.reportReply(reportid)) + if isinstance(response, list) or ( + isinstance(response, dict) and "at_report" in response + ): + break + else: + self.logger.info("TNS Report reading failed") + continue + + # Check whether request was bad. In this case TNS looks to return a list with dicts + # of failed objects which does not correspond to the order of input atdicts. + # In any case, nothing in the submit is posted. 
+ # Hence only checking first element + if isinstance(response, list): # Assuming response is list iff submit fails + bad_request = {} + for key_atprop in ["ra", "decl", "discovery_datetime"]: + if key_atprop in response[0]: + bad_request[key_atprop] = response[0][key_atprop] + self.logger.info("bad TNS request", extra=bad_request) + continue + + # Parse reply for evaluation + for reportresponses in response.values(): + for reportresponse in reportresponses: + if "100" in reportresponse: + self.logger.info( + "TNS Inserted", + extra={"TNSName": reportresponse["100"]["objname"]}, + ) + reportresult["inserted"].append( + reportresponse["100"]["objname"] + ) + elif "101" in reportresponse: + reportresult["existing"].append( + reportresponse["101"]["objname"] + ) + self.logger.info( + "TNS Existed", + extra={"TNSName": reportresponse["101"]["objname"]}, + ) + + return reportresult + + def process( + self, gen: Generator[TransientView, T3Send, None], t3s: None | T3Store = None + ) -> UBson | UnitResult: + # Reports to be sent, indexed by the transient view IDs (so that we can check in the replies) + atreports: dict[StockId, dict[str, Any]] = {} + + for tran_view in gen: + # Base information + atdict = ztfdps_to_tnsdict(tran_view.get_photopoints()) + if atdict is None: + self.logger.debug("Not enough info for TNS submission") + continue + atdict.update(self.base_at_dict) + + # Check if ZTF name exists in TNS mirror archive + if isinstance(tran_view.extra, dict) and "TNSReports" in tran_view.extra: + intnames = [] + for tnsreport in tran_view.extra["TNSReports"]: + intnames.extend(tnsreport["internal_names"].split(", ")) + if atdict["internal_name"] in intnames: + self.logger.debug( + "already in tns", extra={"id": atdict["internal_name"]} + ) + continue + + # from T2s + catremarks = get_tns_t2remarks(tran_view) + if catremarks is not None: + atdict.update(catremarks) + + # directly check with TNS... unnecessary? 
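+            # The optional doublecheck below runs a live TNS cone search (radius maxdist, in arcsec)
+            # at the candidate position and skips submission if the ZTF internal name is already listed there.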
+ if self.tns_doublecheck: + tnsmatch = asyncio.run( + self.get_tns_names( + ra=atdict["ra"]["value"], dec=atdict["dec"]["value"] + ) + ) + if atdict["internal_name"] in tnsmatch: + continue + + # Collected necessary data, not already published - add to submission list + atreports[tran_view.id] = atdict + + if len(atreports) == 0: + # Nothing to submit + self.logger.info("Nothing to report.") + return None + + # atreports is now a dict with tran_id as keys and atreport as keys + # what we need is a list of dicts with form {'at_report':atreport } + # where an atreport is a dictionary with increasing integer as keys and atreports as values + atreportlist = [ + {"at_report": {i: report for i, report in enumerate(atreports.values())}} + ] + + if not self.tns_submit: + return None + + # Submit the reports and return results for db + return self.sendReports(atreportlist) diff --git a/ampel/contrib/hu/t3/TNSTalker.py b/ampel/contrib/hu/t3/TNSTalker.py deleted file mode 100755 index d45b07c1..00000000 --- a/ampel/contrib/hu/t3/TNSTalker.py +++ /dev/null @@ -1,539 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -# File: ampel/contrib/hu/t3/TNSTalker.py -# License: BSD-3-Clause -# Author: jnordin@physik.hu-berlin.de -# Date: 17.11.2018 -# Last Modified Date: 04.09.2019 -# Last Modified By: Jakob van Santen - -import re -from collections.abc import Generator, Iterable -from itertools import islice -from typing import TYPE_CHECKING, Any - -from ampel.abstract.AbsT3ReviewUnit import AbsT3ReviewUnit, T3Send -from ampel.contrib.hu.t3.ampel_tns import ( - TNS_BASE_URL_REAL, - TNS_BASE_URL_SANDBOX, - TNSClient, -) -from ampel.contrib.hu.t3.tns.TNSToken import TNSToken -from ampel.secret.NamedSecret import NamedSecret -from ampel.struct.JournalAttributes import JournalAttributes -from ampel.struct.StockAttributes import StockAttributes -from ampel.struct.T3Store import T3Store -from ampel.types import StockId -from ampel.view.TransientView import TransientView -from ampel.ztf.util.ZTFIdMapper import to_ztf_id - -if TYPE_CHECKING: - from ampel.content.JournalRecord import JournalRecord - - -def chunks(l: Iterable, n: int) -> Generator[list, None, None]: - source = iter(l) - while True: - chunk = list(islice(source, n)) - yield chunk - if len(chunk) < n: - break - - -class TNSTalker(AbsT3ReviewUnit): - """ - Get TNS name if existing, and submit selected candidates. - - All candidates loaded by T3 will be submitted - it is assumed that *selection* is done - by an appropriate T2, which also prepares the submit information. - T2TNSEval is one such example. - - If submit_tns is true, candidates fulfilling the criteria will be sent to the TNS if: - - They are not known to the TNS OR - - They are registered by TNS but under a non-ZTF internal name AND resubmit_tns_nonztf set to True OR - - They are registered by TNS under a ZTF name AND resubmit_tns_ztf is set to True - - if sandbox is set to True it will try to submit candidates to the TNS sandbox, but this API has been unstable - and might not work properly. - """ - - # TNS config - - # Bot api key frm TNS - tns_api_key: NamedSecret[dict] - # Check for TNS for names even if internal name is known - get_tns_force: bool = False - # Submit candidates passing criteria (False gives you a 'dry run') - submit_tns: bool = True - # Submit all candidates we have a note in the Journal that we submitted this. Overrides the resubmit entries!! 
- submit_unless_journal: bool = False - # Resubmit candidate submitted w/o the same ZTF internal ID - resubmit_tns_nonztf: bool = True - # Resubmit candidates even if they have been added with this name before - resubmit_tns_ztf: bool = False - - # Submit to TNS sandbox only - sandbox: bool = True - # weather journal will go to separate collection. - ext_journal: bool = True - - # AT report config - base_at_dict: dict = { - "reporting_group_id": "82", # Should be ampel - "discovery_data_source_id": "48", - "reporter": "J. Nordin, V. Brinnel, J. van Santen (HU Berlin), A. Gal-Yam, O. Yaron, S. Schulze (Weizmann) on behalf of ZTF", - "at_type": "1", - } - baseremark: str = "See arXiv:1904.05922 for selection criteria." - - slack_token: None | NamedSecret = None - slack_channel = "#ztf_tns" - slack_username = "AMPEL" - # if you have more than this # of reports, send different files - max_slackmsg_size = 200 - - def __init__(self, **kwargs): - super().__init__(**kwargs) - self.client = TNSClient( - TNS_BASE_URL_SANDBOX if self.sandbox else TNS_BASE_URL_REAL, - self.logger, - TNSToken(**self.tns_api_key.get()), - ) - # maintain a second client to check the real TNS if in sandbox mode - self.reference_client = ( - TNSClient( - TNS_BASE_URL_REAL, - self.logger, - TNSToken(**self.tns_api_key.get()), - ) - if self.sandbox - else self.client - ) - - def search_journal_tns( - self, tran_view: TransientView - ) -> tuple[None | str, list[str]]: - """ - Look through the journal for a TNS name. - Assumes journal entries came from this unit, that the TNS name is saved as "tnsName" - and internal names as "tnsInternal" - """ - tns_name, tns_internals = None, [] - - def select(entry: "JournalRecord") -> bool: - return bool( - (entry["extra"] is not None and ("tnsInternal" in entry["extra"])) - and entry["unit"] - and entry["unit"] == self.__class__.__name__ - ) - - if jentries := list(tran_view.get_journal_entries(tier=3, filter_func=select)): - if jentries[-1]["extra"] is not None: - tns_name = jentries[-1]["extra"].get("tnsName", None) - tns_internals = [ - entry["extra"].get("tnsInternal", None) - for entry in jentries - if entry["extra"] is not None - ] - - self.logger.info( - "Journal search", - extra={ - "tranId": tran_view.id, - "tnsName": tns_name, - "tnsInternals": tns_internals, - }, - ) - - return tns_name, tns_internals - - def search_journal_submitted(self, tran_view: TransientView) -> bool: - """ - Look through the journal for whether this sender submitted this to TNS. - Assumes journal entries came from this unit, that the TNS name is saved as "tnsName" - and tnsSender stores the api key used ('tnsSender': self.tns_api_key') - """ - - def select(entry: "JournalRecord") -> bool: - return bool( - ( - entry["extra"] is not None - and ( - entry["extra"].get("tnsSender") - == self.tns_api_key.get()["name"] - ) - and "tnsSubmitResult" in entry["extra"] - ) - and entry["unit"] - and entry["unit"] == self.__class__.__name__ - ) - - # Find the latest tns name (skipping previous) - if tran_view.get_journal_entries( - tier=3, - filter_func=select, - ): - self.logger.info( - "TNS submitted", extra={"tnsSender": self.tns_api_key.get()["name"]} - ) - return True - self.logger.info( - "Not TNS submitted", extra={"tnsSender": self.tns_api_key.get()["name"]} - ) - return False - - def _query_tns_names( - self, tran_view: TransientView, ra: float, dec: float - ) -> tuple[None | str, list]: - """ - query the TNS for names and internals at the position - of the transient. 
- """ - # query the TNS for transient at this position. Note that we check the real TNS for names for compatibility... - - tns_name, tns_internal = self.client.getNames(ra=ra, dec=dec) - - # Skip the AT SN prefix if present - if tns_name is not None: - tns_name = re.sub("^AT", "", tns_name) - tns_name = re.sub("^SN", "", tns_name) - - # be nice and then go - ztf_name = to_ztf_id(tran_view.id) - self.logger.info( - "looking for TNS name in the TNS.", - extra={ - "ZTFname": ztf_name, - "ra": ra, - "dec": dec, - "tnsName": tns_name, - "tnsInternals": [tns_internal], - }, - ) - return tns_name, [tns_internal] - - def _find_tns_tran_names( - self, tran_view: TransientView - ) -> tuple[None | str, list[str]]: - """ - search for TNS name in tran_view.tran_names. If found, - look in the TNS for internal names and return them - """ - - # First, look if we already registered a name - tns_name, tns_internals = None, [] - names: list[str] = ( - [str(name) for name in (tran_view.stock["name"] or [])] - if tran_view.stock - else [] - ) - for tname in names: - if "TNS" in tname and (not self.get_tns_force): - self.logger.info( - "found TNS name in tran_names.", - extra={"TNSname": tname, "TransNames": names}, - ) - # as TNS to give you the internal names. - # we remove the 'TNS' part of the name, since this has been - # added by the TNSMatcher T3, plus we skip the prefix - # We here assume that the AT/SN suffix is cut - tns_name = tname.replace("TNS", "") - # Not using sandbox (only checking wrt to full system). - tns_internals, runstatus = self.reference_client.getInternalName( - tns_name - ) - - # be nice with the logging - ztf_name = to_ztf_id(tran_view.id) - self.logger.info( - "looked for TNS name in self.tran_names", - extra={ - "ZTFname": ztf_name, - "tnsName": tns_name, - "tnsInternals": tns_internals, - "TransNames": names, - }, - ) - - return tns_name, tns_internals - - def find_tns_name( - self, tran_view: TransientView, ra: float, dec: float - ) -> tuple[None | str, list[str], None | JournalAttributes]: - """ - extensive search for TNS names in: - - tran_view.tran_names (if added by TNSMatcher) - - the journal of tran_view (if added by this T3) - - the TNS itself (if no name can be found with the above) - - Returns: - -------- - tns_name, tns_internals, jup: tns_name, tns_internal, and journal update - """ - - ztf_name = to_ztf_id(tran_view.id) - self.logger.info("looking for TNS name", extra={"ZTFname": ztf_name}) - - # first we look in the journal, this is the cheapest option. If we have - # a valid name from the journal and if you do not want to look again in - # the TNS, we are fine. NOTE: in this case you don't return a journal update. - tns_name, tns_internals = self.search_journal_tns(tran_view) - self.logger.debug("Found tns name in journal: %s" % (tns_name)) - if (tns_name is not None) and (not self.get_tns_force): - return tns_name, tns_internals, None - - # second option in case there is no TNS name in the journal: go and look in tran_names - # and if you don't find any, go and ask TNS again. 
- tns_name_new, tns_internals_new = self._find_tns_tran_names(tran_view) - self.logger.debug( - f"Find tns names added to the ampel name list: {tns_name_new} internal {tns_internals_new}" - ) - if tns_name_new is None: - tns_name_new, tns_internals_new = self._query_tns_names(tran_view, ra, dec) - self.logger.debug( - "Proper check of tns done, found name %s" % (tns_name_new) - ) - - # now, it is possible (if you set self.get_tns_force) that the - # new TNS name is different from the one we had in the journal. We always - # use the most recent one. In this case we also create a JournalUpdate - jup = None - if tns_name_new is not None: - # what happen if you have a new name that is different from the old one? - if tns_name is not None and tns_name != tns_name_new: - self.logger.info( - "Adding new TNS name", - extra={"tnsOld": tns_name, "tnsNew": tns_name_new}, - ) - - # create content of journal entry. Eventually - # update the list with the new internal names if any are found - jcontent = {"tnsName": tns_name_new} - if tns_internals_new is not None: - tns_internals.extend(tns_internals_new) - for tns_int in tns_internals_new: - jcontent.update({"tnsInternal": tns_int}) - - # create a journalUpdate and update the tns_name as well. TODO: check with JNo - jup = JournalAttributes(extra=jcontent) - tns_name = tns_name_new - - elif tns_name is None: - # Set the new name - self.logger.info( - "Adding first TNS name", extra={"tnsNew": tns_name_new} - ) - - # create content of journal entry. Eventually - # update the list with the new internal names if any are found - jcontent = {"tnsName": tns_name_new} - if tns_internals_new is not None: - tns_internals.extend(tns_internals_new) - for tns_int in tns_internals_new: - jcontent.update({"tnsInternal": tns_int}) - - # create a journalUpdate and update the tns_name as well. TODO: check with JNo - jup = JournalAttributes(extra=jcontent) - tns_name = tns_name_new - # tns_internals = tns_internals_new - - # bye! - return tns_name, tns_internals, jup - - def process( - self, gen: Generator[TransientView, T3Send, None], t3s: T3Store - ) -> None: - """ - Loop through transients and check for TNS names and/or candidates to submit - """ - - # Reports to be sent, indexed by the transient view IDs (so that we can check in the replies) - atreports: dict[StockId, dict[str, Any]] = {} - - for tran_view in gen: - ztf_name = to_ztf_id(tran_view.id) - - # Obtain atdict start from T2 result - t2result = tran_view.get_t2_body(unit="T2TNSEval") - if not isinstance(t2result, dict): - raise ValueError( - "Need to have a TNS atdict started from a suitable T2." - ) - # Create the submission dictionary - in case the transient is to be submitted - atdict = dict(t2result["atdict"]) - atdict.update(self.base_at_dict) - atdict["internal_name"] = ztf_name - - ra, dec = atdict["ra"]["value"], atdict["dec"]["value"] - - self.logger.info( - "TNS init dict found", - extra={"tranId": tran_view.id, "ztfName": ztf_name}, - ) - - # Three things we can find out: - # - Did this AMPEL channel submit the transient (according to Journal) - # - Did anyone submit a transient with this ZTF name? - # - Did anyone submit a transient at this position? - - # Simplest case to check. 
We wish to submit everything not noted as submitted - if self.submit_unless_journal: - if self.search_journal_submitted(tran_view): - # Note already submitted - self.logger.info("ztf submitted", extra={"ztfSubmitted": True}) - else: - # add AT report - self.logger.info("Add TNS report list", extra={"id": tran_view.id}) - atreports[tran_view.id] = atdict - continue - - # find the TNS name, either from the journal, from tran_names, or - # from TNS itself. If new names are found, create a new JournalUpdate - tns_name, tns_internals, jup = self.find_tns_name(tran_view, ra, dec) - self.logger.debug(f"TNS got {tns_name} internals {tns_internals}") - - if tns_name is not None: - # Chech whether this ID has been submitted (note that we do not check - # whether the same candidate was submitted as different ZTF name) and - # depending on what's already on the TNS we can chose to submit or not - is_ztfsubmitted = ztf_name in tns_internals - # Already registered under this name. Only submit if we explicitly configured to do this - if is_ztfsubmitted and not self.resubmit_tns_ztf: - self.logger.info( - "ztf submitted", - extra={ - "ztfSubmitted": is_ztfsubmitted, - "tnsInternals": tns_internals, - }, - ) - continue - - # Also allow for the option to not submit if someone (anyone) already did this. Not sure why this would be a good idea. - if not is_ztfsubmitted and not self.resubmit_tns_nonztf: - self.logger.info( - "already in tns, skipping", - extra={ - "ztfSubmitted": is_ztfsubmitted, - "tnsInternals": tns_internals, - }, - ) - continue - - # Passed all cuts, add to submit list - self.logger.info("Added to report list") - atreports[tran_view.id] = atdict - - self.logger.info("collected %d AT reports to post" % len(atreports)) - - # If we do not want to submit anything, or if there's nothing to submit - if len(atreports) == 0 or (not self.submit_tns): - self.logger.info( - "submit_tns config parameter is False or there's nothing to submit", - extra={ - "n_reports": len(atreports), - "submit_tns": self.submit_tns, - }, - ) - return - - # atreports is now a dict with tran_id as keys and atreport as keys - # what we need is a list of dicts with form {'at_report':atreport } - # where an atreport is a dictionary with increasing integer as keys and atreports as values - atreportlist = [ - { - "at_report": { - i: report - for chunk in chunks(atreports.values(), 1) - for i, report in enumerate(chunk) - } - } - ] - tnsreplies = self.client.sendReports(atreportlist) - - # Now go and check and create journal updates for the cases where SN was added - for tran_id in atreports: - ztf_name = to_ztf_id(tran_id) - if ztf_name not in tnsreplies: - self.logger.info("No TNS add reply", extra={"tranId": tran_id}) - continue - - # Create new journal entry assuming we submitted or found a name - if "TNSName" in tnsreplies[ztf_name][1]: - gen.send( - ( - tran_id, - StockAttributes( - journal=JournalAttributes( - extra={ - "tnsName": tnsreplies[ztf_name][1]["TNSName"], - "tnsInternal": ztf_name, - "tnsSubmitresult": tnsreplies[ztf_name][0], - "tnsSender": self.tns_api_key.get()["name"], - }, - ), - tag="TNS_SUBMITTED", - name=tnsreplies[ztf_name][1]["TNSName"], - ), - ) - ) - - def report_to_slack(self, atreports: dict[StockId, dict[str, Any]]) -> None: - self.logger.info("done running T3") - - if not atreports: - self.logger.info("No atreports collected.") - return - if self.slack_token is None: - return - - # TODO: to help debugging and verification, we post the collected atreports - # to the slack, so that we can compare 
them with what JNo script is doing - # ALL THE CONTENT OF THIS METHOD SHOULD GO AWAY AS SOON AS WE TRUST THIS T3 - self.logger.warn( - "Posting collected ATreports to Slack. I'm still running as a test!" - ) - - import datetime - import io - import json - - from slack_sdk import WebClient - from slack_sdk.errors import SlackClientError - from slack_sdk.web import SlackResponse - - sc = WebClient(token=self.slack_token.get()) - - tstamp = datetime.datetime.now(tz=datetime.timezone.utc).strftime("%Y-%m-%d-%X") - atlist = list(atreports.values()) - last = 0 - for ic, atrep in enumerate(chunks(atlist, self.max_slackmsg_size)): - # add the atreport to a file - self.logger.info("Posting chunk #%d" % ic) - filename = "TNSTalker_DEBUG_%s_chunk%d.json" % (tstamp, ic) - fbuffer = io.StringIO(filename) - json.dump(atrep, fbuffer, indent=2) - - # upload the file with the at reports - first = last - last += len(atrep) - msg = ( - "A total of %d atreports found by TNSTalker T3. Here's chunk #%d (reports from %d to %d)" - % (len(atreports), ic, first, last) - ) - api = sc.files_upload( - channels=[self.slack_channel], - title="TNSTalker_%s_chunk%d" % (tstamp, ic), - initial_comment=msg, - username=self.slack_username, - as_user=False, - filename=filename, - filetype="javascript", - file=fbuffer.getvalue(), - ) - assert isinstance(api, SlackResponse) - if not api["ok"]: - raise SlackClientError(api["error"]) - - self.logger.warn( - f"DONE DEBUG Slack posting. Look at {self.slack_channel} for the results" - ) diff --git a/ampel/contrib/hu/t3/TransientTablePublisher.py b/ampel/contrib/hu/t3/TransientTablePublisher.py index 24220238..6c33af47 100644 --- a/ampel/contrib/hu/t3/TransientTablePublisher.py +++ b/ampel/contrib/hu/t3/TransientTablePublisher.py @@ -4,8 +4,8 @@ # License: BSD-3-Clause # Author: jnordin@physik.hu-berlin.de # Date: 06.05.2021 -# Last Modified Date: 16.01.2024 -# Last Modified By: ernstand@physik.hu-berlin.de +# Last Modified Date: 15.12.2023 +# Last Modified By: alice.townsend@physik.hu-berlin.de import io import os @@ -20,9 +20,11 @@ from ampel.abstract.AbsPhotoT3Unit import AbsPhotoT3Unit from ampel.secret.NamedSecret import NamedSecret from ampel.struct.T3Store import T3Store -from ampel.types import T3Send +from ampel.struct.UnitResult import UnitResult +from ampel.types import T3Send, UBson from ampel.util.mappings import get_by_path from ampel.view.TransientView import TransientView +from ampel.ztf.util.ZTFIdMapper import ZTFIdMapper class TransientTablePublisher(AbsPhotoT3Unit): @@ -36,10 +38,6 @@ class TransientTablePublisher(AbsPhotoT3Unit): include_stock (bool) include_channels (bool) - If one wants to convert the AMPEL stock ID to external IDs, define - convert_stock_to (str|None) - For ZTF-IDs, pass 'convert_stock_to: ztf' - How to deal with names. Will search each transients names for entries containing "value", and return any output under "key" name_filter = { 'ZTF name' : 'ZTF', 'TNS ID' : 'TNS' } @@ -69,43 +67,33 @@ class TransientTablePublisher(AbsPhotoT3Unit): Todo: - save to desy webb? - - include format option for printing + - include format option for prointing """ # Two tables describing what information to save into the table. 
# Schema for state dependent T2s (one row for each) - table_schema: dict[str, Any] = {} + table_schema: dict[str, Any] # Schema for transient dependent T2s (added to each row together with base info) transient_table_schema: dict[str, Any] name_filter: dict[str, str] = {"ZTF name": "ZTF", "TNS ID": "TNS"} include_stock: bool = False - convert_stock_to: str | None = None - - sort_by_key: str | None = "kilonovaness" - sort_ascending: bool = False - include_pos: bool = True include_channels: bool = True # Add also transients lacking any T2 info save_base_info: bool = False fmt: str = "csv" - write_mode: str = "a" - rename_files: bool = False - dir_name: str = "TransientTable" - file_name: str = dir_name + file_name: str = "TransientTable.csv" slack_channel: None | str = None slack_token: None | NamedSecret[str] local_path: None | str = None - move_files: bool = False - def process( - self, gen: Generator[TransientView, T3Send, None], t3s: None | T3Store = None - ) -> None: + self, gen: Generator[TransientView, T3Send, None], t3s: T3Store | None = None + ) -> UBson | UnitResult: # def process(self, gen: Generator[SnapView, T3Send, None], t3s: T3Store) -> None: """ Loop through provided TransientViews and extract data according to the @@ -113,7 +101,7 @@ def process( """ table_rows: list[dict[str, Any]] = [] - for tran_view in gen: + for k, tran_view in enumerate(gen, 1): # noqa: B007 basetdict: dict[str, Any] = {} # Assemble t2 information bound to the transient (e.g. Point T2s) for t2unit, table_entries in self.transient_table_schema.items(): @@ -165,16 +153,12 @@ def process( if self.include_stock: basetdict["stock"] = tran_view.id - - if self.convert_stock_to is not None: - assert self.convert_stock_to in ["ztf"] - - if self.convert_stock_to == "ztf": - from ampel.ztf.util.ZTFIdMapper import ZTFIdMapper - - stock_id = tran_view.id - ztf_id = ZTFIdMapper.to_ext_id(stock_id) - basetdict["ztf_id"] = ztf_id + # Try to convert to external id + # Note: for multiple source classes beyond ZTF, could use a list of tabulator type entries? 
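+            # ZTFIdMapper.to_ext_id raises ValueError for stock ids that are not ZTF transients;
+            # in that case the row is kept but no ztfname entry is added (see except clause below).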
+ try: + basetdict["ztfname"] = ZTFIdMapper.to_ext_id(tran_view.id) + except ValueError: + self.logger.info("Coult not convert stock") if self.include_pos: lcurve = tran_view.get_lightcurves() @@ -203,53 +187,19 @@ def process( self.logger.info("", extra={"table_count": len(table_rows)}) if len(table_rows) == 0: - return + return None # Export assembled information # Convert df = pd.DataFrame.from_dict(table_rows) - # if "map_name" in df.columns and "map_seed" in df.columns: - # df["map_name"] = np.char.replace(np.array(df["map_name"], dtype=str), "random", "random"+df["map_seed"]) - - # print(df["map_name"].iloc[0]) - if "map_seed" in df or self.rename_files: - # print("transienttablepublisher:: ", df["map_seed"].iloc[0]) - tmp_seed_name = df["map_seed"].iloc[0] - if isinstance(tmp_seed_name, str): - self.file_name += "_" + tmp_seed_name - else: - self.file_name += "_" + str(int(tmp_seed_name)) - - # sort dataframe by key - if self.sort_by_key in df.keys(): # noqa: SIM118 - df = df.sort_values(by=self.sort_by_key, ascending=self.sort_ascending) - else: - self.logger.warn( - f"Cannot sort table by {self.sort_by_key} - legal keys: {df.keys()}" - ) - # Local save if self.local_path is not None: - path_name = os.path.join(self.local_path, self.dir_name) - # print("PATHNAME::", path_name) - if not os.path.exists(path_name): - os.makedirs(path_name, exist_ok=True) - full_path = os.path.join(path_name, self.file_name) - # print("FILE PATH::", full_path) - - with open(full_path + "." + self.fmt, "w") as tmp_file: - tmp_file.close() + full_path = os.path.join(self.local_path, self.file_name) if self.fmt == "csv": - # print(self.write_mode) - df.to_csv(full_path + ".csv", sep=";", mode=self.write_mode) + df.to_csv(full_path) elif self.fmt == "latex": - df.to_latex(full_path + ".tex") - elif self.fmt == "json": - json_dumps = df.to_json(indent=2) - with open(full_path + ".json", self.write_mode) as json_file: - json_file.write(json_dumps) - json_file.close() + df.to_latex(full_path) self.logger.info("Exported", extra={"path": full_path}) # Export to slack if requested @@ -257,37 +207,7 @@ def process( # Could potentially return a document to T3 collection detailing # what was done, as well as the table itself. 
- - # take everything local_path and put it into new folder named after skymap - # print(df.keys) - map_name_key = "map_name" - if map_name_key in df and self.move_files: - files_local_path = os.listdir(self.local_path) - skymap_name = df[map_name_key][ - 0 - ] # need to change if for some reason several maps get saved in same file - skymap_dir_name = skymap_name # [: skymap_name.find(".")] # bare name - if skymap_name[-1] != "z": # if non trivial rev version (hacky) - skymap_dir_name += ( - "_rev_" + skymap_name[skymap_name.find(",") + 1 :] - ) # find "," and add rev version after that - - print("TransientTablePublisher: TMP FILES MOVED TO " + skymap_dir_name) - - if self.local_path is not None: - skymap_directory = os.path.join( - self.local_path + "/../" + skymap_dir_name - ) - # print(skymap_directory) - os.makedirs(skymap_directory, exist_ok=True) - for file in files_local_path: - if file.find(".fits.gz") == -1: - tmp_file_path = os.path.join(self.local_path, file) - if not (os.path.isfile(tmp_file_path)): - continue - os.replace(tmp_file_path, os.path.join(skymap_directory, file)) - - return + return None @backoff.on_exception( backoff.expo, @@ -312,7 +232,7 @@ def _slack_export(self, df): # Slack summary buffer = io.StringIO(self.file_name) if self.fmt == "csv": - df.to_csv(buffer, sep=";") + df.to_csv(buffer) elif self.fmt == "latex": df.to_latex(buffer) diff --git a/ampel/contrib/hu/t3/ampel_tns.py b/ampel/contrib/hu/t3/ampel_tns.py deleted file mode 100755 index 6e7e36af..00000000 --- a/ampel/contrib/hu/t3/ampel_tns.py +++ /dev/null @@ -1,352 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -# File: Ampel-contrib-HU/ampel/contrib/hu/t3/ampel_tns.py -# License: BSD-3-Clause -# Author: Ken Smith -# Date: May 2016 -# Last Modified Date: Feb 2018 -# Last Modified By: Jakob Nordin - -# ----------------------------------------------------------------------------- -# A python sample code for sending a bulk report to the TNS. -# Original sample code by Ken Smith (May 2016) modified by Jakob Nordin (Feb 2018) -# ----------------------------------------------------------------------------- - -import json -import re -import time -from typing import Any - -from requests.models import Response -from requests_toolbelt.sessions import BaseUrlSession - -from ampel.contrib.hu.t3.tns.TNSToken import TNSToken -from ampel.protocol.LoggerProtocol import LoggerProtocol - -TNSFILTERID = {1: "110", 2: "111", 3: "112"} -AT_REPORT_FORM = "bulk-report" -AT_REPORT_REPLY = "bulk-report-reply" -TNS_ARCHIVE = {"OTHER": "0", "SDSS": "1", "DSS": "2"} -TNS_BASE_URL_SANDBOX = "https://sandbox.wis-tns.org/api/" -TNS_BASE_URL_REAL = "https://www.wis-tns.org/api/" - -httpErrors = { - 304: "Error 304: Not Modified: There was no new data to return.", - 400: "Error 400: Bad Request: The request was invalid. An accompanying error message will explain why.", - 403: "Error 403: Forbidden: The request is understood, but it has been refused. 
An accompanying error message will explain why", - 404: "Error 404: Not Found: The URI requested is invalid or the resource requested, such as a category, does not exists.", - 500: "Error 500: Internal Server Error: Something is broken.", - 503: "Error 503: Service Unavailable.", - 429: "Error 429: Rate Limit Exceeded.", -} - - -class TNSSession(BaseUrlSession): - def __init__(self, token: TNSToken, baseURL: str = TNS_BASE_URL_REAL) -> None: - self.token = token - super().__init__(baseURL) - self.headers["User-Agent"] = "tns_marker" + json.dumps( - {"tns_id": self.token.id, "name": self.token.name, "type": "bot"} - ) - - def post( - self, method: str, payload: str | dict[str, Any], payload_key="data", **kwargs - ) -> Response: - for _ in range(10): - if ( - response := super().post( - method, - files=[ - ("api_key", (None, self.token.api_key)), - (payload_key, (None, json.dumps(payload))), - ], - **kwargs, - ) - ).status_code != 429: - return response - # back off according to rate-limit headers (see https://www.wis-tns.org/content/tns-newsfeed#comment-wrapper-26286) - delay = response.headers[ - "x-cone-rate-limit-reset" - if response.url.endswith("search") - else "x-rate-limit-reset" - ] - time.sleep(int(delay)) - response.raise_for_status() - # unreachable - return None # type: ignore[return-value] - - -class TNSClient: - """Send Bulk TNS Request.""" - - def __init__(self, baseURL, logger: LoggerProtocol, token: TNSToken): - """ - :param baseURL: Base URL of the TNS API - :param options: (Default value = {}) - """ - - self.logger = logger - self.session = TNSSession(token, baseURL) - - def jsonResponse(self, r: Response) -> dict: - """ - Send JSON response given requests object. Should be a python dict. - - :param r: requests object - the response we got back from the server - :return d: json response converted to python dict - """ - - d: dict[str, Any] = {} - # What response did we get? - message = None - status = r.status_code - - if status != 200: - message = httpErrors.get(status, f"Error {status}: Undocumented error") - - if message is not None: - self.logger.warn("TNS bulk submit: " + message) - return d - - # Did we get a JSON object? - try: - d = r.json() - except ValueError as e: - self.logger.error("TNS bulk submit", exc_info=e) - return {} - - # If so, what error messages if any did we get? - self.logger.info(json.dumps(d, indent=4, sort_keys=True)) - - if "id_code" in d and "id_message" in d and d["id_code"] != 200: - self.logger.info( - "TNS bulk submit: Bad response: code = %d, error = '%s'" - % (d["id_code"], d["id_message"]) - ) - return d - - def sendBulkReport(self, options) -> dict: - """ - Send the JSON TNS request - :param options: the JSON TNS request - """ - # The requests.post method needs to receive a "files" entry, not "data". And the "files" - # entry needs to be a dictionary of tuples. The first value of the tuple is None. - self.logger.info("TNS bulk submit: " + "sending request") - r = self.session.post(AT_REPORT_FORM, options, timeout=300) - # Construct the JSON response and return it. 
- self.logger.info("TNS bulk submit: " + "got response (or timed out)") - return self.jsonResponse(r) - - def addBulkReport(self, report): - """ - Send the report to the TNS - - :return reportId: TNS report ID - - """ - logger = self.logger - reply = self.sendBulkReport(report) - - reportId = None - - if reply: - try: - reportId = reply["data"]["report_id"] - logger.info("TNS bulk submit: successful with ID %s" % (reportId)) - except ValueError: - logger.error("Empty response. Something went wrong. Is the API Key OK?") - except KeyError: - logger.error("Cannot find the data key. Something is wrong.") - - return reportId - - def sendReports(self, reports: list[dict]): - """ - Based on a lists of reportlists, send to TNS. - Return results for journal entries - """ - - # Submit to TNS - MAX_LOOP = 25 - SLEEP = 2 - - logger = self.logger - reportresult = {} - for atreport in reports: - # Submit a report - for _ in range(MAX_LOOP): - reportid = self.addBulkReport(atreport) - if reportid: - logger.info("TNS report ID %s" % (reportid)) - break - time.sleep(SLEEP) - else: - logger.info("TNS bulk report failed") - continue - - # Try to read reply - for _ in range(MAX_LOOP): - time.sleep(SLEEP) - response = self.getBulkReportReply(reportid) - if isinstance(response, list): - break - else: - logger.info("TNS Report reading failed") - continue - - # Check whether request was bad. In this case TNS looks to return a list with dicts - # of failed objects which does not correspond to the order of input atdicts. - # In any case, nothing in the submit is posted. - # Hence only checking first element - bad_request = None - for key_atprop in ["ra", "decl", "discovery_datetime"]: - if key_atprop in response[0]: - try: - bad_request = response[0][key_atprop]["value"]["5"]["message"] - break - except KeyError: - pass - if bad_request is not None: - logger.info(bad_request) - continue - - # Parse reply for evaluation - for k, v in atreport["at_report"].items(): - if "100" in response[k]: - logger.info( - "TNS Inserted with name %s" % (response[k]["100"]["objname"]) - ) - reportresult[v["internal_name"]] = [ - "TNS inserted", - {"TNSName": response[k]["100"]["objname"]}, - ] - elif "101" in response[k]: - logger.info( - "Already existing with name %s" - % (response[k]["101"]["objname"]) - ) - reportresult[v["internal_name"]] = [ - "TNS pre-existing", - {"TNSName": response[k]["101"]["objname"]}, - ] - - return reportresult - - def _bulkReportReply(self, report_id: str) -> dict[str, Any]: - """ - Get the report back from the TNS - - :param options: dict containing the report ID - :return: dict - - """ - self.logger.info("TNS bulk submit: " + "looking for reply report") - # every TNS endpoint wraps its arguments in `data`, except bulk-report-reply - r = self.session.post( - AT_REPORT_REPLY, report_id, payload_key="report_id", timeout=300 - ) - self.logger.info("TNS bulk submit: " + "got report (or timed out)") - return self.jsonResponse(r) - - def getBulkReportReply(self, reportId): - """ - Get the TNS response for the specified report ID - :param tnsApiKey: TNS API Key - :return request: The original request - :return response: The TNS response - """ - - logger = self.logger - reply = self._bulkReportReply(reportId) - - response = None - # reply should be a dict - if reply and "id_code" in reply and reply["id_code"] == 404: - logger.warn( - f"TNS bulk submit {reportId}: Unknown report. " - f"Perhaps the report has not yet been processed." 
- ) - - if reply and "id_code" in reply and reply["id_code"] == 200: - try: - response = reply["data"]["feedback"]["at_report"] - except KeyError: - logger.error( - "TNS bulk submit: cannot find the response feedback payload." - ) - - # This is a bad request. Still propagate the response for analysis. - if reply and "id_code" in reply and reply["id_code"] == 400: - try: - response = reply["data"]["feedback"]["at_report"] - except KeyError: - logger.error( - "TNS bulk submit: cannot find the response feedback payload." - ) - - logger.info(f"TNS bulk submit: got response {response}") - - return response - - def getInternalName(self, tns_name: str): - """ - formerly tnsInternal - """ - - response = self.session.post( - "get/object", {"objname": tns_name, "photometry": 0, "spectra": 0} - ) - parsed = self.jsonResponse(response) - - if parsed["data"]["reply"]["internal_names"] is None: - return [], "No internal TNS name" - - return parsed["data"]["reply"]["internal_names"], "Got internal name response" - - def search( - self, ra: float, dec: float, matchradius: float = 5.0 - ) -> tuple[list[str], str]: - """ - formerly tnsName - """ - r = self.session.post( - "get/search", - {"ra": ra, "dec": dec, "radius": matchradius, "units": "arcsec"}, - ) - parsed = self.jsonResponse(r) - - try: - tnsnames = [v["prefix"] + v["objname"] for v in parsed["data"]["reply"]] - except KeyError: - return [], "Error: No TNS names in response" - - return tnsnames, "Found TNS name(s)" - - def getNames(self, ra: float, dec: float, matchradius: float = 5.0): - """ - Get names of the first TNS object at location - - formerly get_tnsname - """ - logger = self.logger - # Look for TNS name at the coordinate of the transient - tnsnames, runstatus = self.search(ra, dec, matchradius) - if re.match("Error", runstatus): - logger.info("TNS get error", extra={"tns_request": runstatus}) - return None, [] - if len(tnsnames) >= 1: - tns_name = tnsnames[0] - if len(tnsnames) > 1: - logger.debug( - "Multipe TNS names, choosing first", extra={"tns_names": tnsnames} - ) - else: - # No TNS name, then no need to look for internals - return None, None - logger.info("TNS get cand id", extra={"tns_name": tns_name}) - - # Look for internal name (note that we skip the prefix) - internal_names, *_ = self.getInternalName(tns_name[2:]) - - return tns_name, internal_names diff --git a/ampel/contrib/hu/t3/tns/TNSClient.py b/ampel/contrib/hu/t3/tns/TNSClient.py index 049c7d23..7c2cee4c 100755 --- a/ampel/contrib/hu/t3/tns/TNSClient.py +++ b/ampel/contrib/hu/t3/tns/TNSClient.py @@ -31,7 +31,8 @@ async def tns_post( semaphore: asyncio.Semaphore, method: str, token: TNSToken, - data: dict, + data: dict | int, + payload_label: str = "data", max_retries: int = 10, ) -> dict: """ @@ -43,8 +44,11 @@ async def tns_post( p: aiohttp.Payload = aiohttp.StringPayload(token.api_key) p.set_content_disposition("form-data", name="api_key") mpwriter.append(p) - p = aiohttp.JsonPayload(data) - p.set_content_disposition("form-data", name="data") + if isinstance(data, dict): + p = aiohttp.JsonPayload(data) + else: + p = aiohttp.StringPayload(str(data)) + p.set_content_disposition("form-data", name=payload_label) mpwriter.append(p) resp = await session.post( "https://www.wis-tns.org/api/" + method, data=mpwriter @@ -61,7 +65,8 @@ async def tns_post( ) await asyncio.sleep(wait) continue - break + else: # noqa: RET507 + break resp.raise_for_status() return await resp.json() @@ -112,10 +117,10 @@ def _on_giveup(self, details): @staticmethod def is_permanent_error(exc): if 
isinstance(exc, ClientResponseError): - return exc.code not in {500, 429} + return exc.code not in {500, 429, 404} return False - async def search(self, *, exclude: None | set[str] = None, **params): + async def search(self, exclude=set(), **params): # noqa: B006 semaphore = asyncio.Semaphore(self.maxParallelRequests) async with ClientSession( headers={ @@ -147,3 +152,43 @@ async def get(self, session, semaphore, objname): return await self.tns_post( session, semaphore, "get/object", self.token, {"objname": objname} ) + + async def sendReport(self, report): + semaphore = asyncio.Semaphore(self.maxParallelRequests) + async with ClientSession( + headers={ + "User-Agent": "tns_marker" + + json.dumps( + {"tns_id": self.token.id, "name": self.token.name, "type": "bot"} + ) + }, + ) as session: + postreport = partial( + self.tns_post, session, semaphore, "bulk-report", self.token + ) + response = await postreport(report) + if response["id_code"] == 200: + return response["data"]["report_id"] + return False + + async def reportReply(self, report_id): + semaphore = asyncio.Semaphore(self.maxParallelRequests) + async with ClientSession( + headers={ + "User-Agent": "tns_marker" + + json.dumps( + {"tns_id": self.token.id, "name": self.token.name, "type": "bot"} + ) + }, + ) as session: + postreport = partial( + self.tns_post, + session, + semaphore, + "bulk-report-reply", + self.token, + payload_label="report_id", + ) + response = await postreport(report_id) + + return response["data"]["feedback"] diff --git a/ampel/contrib/hu/t3/tns/__init__.py b/ampel/contrib/hu/t3/tns/__init__.py index e69de29b..8414dcb1 100644 --- a/ampel/contrib/hu/t3/tns/__init__.py +++ b/ampel/contrib/hu/t3/tns/__init__.py @@ -0,0 +1,4 @@ +# ruff: noqa: F401 +from .TNSClient import TNSClient +from .TNSMirrorDB import TNSMirrorDB +from .TNSName import TNSName diff --git a/ampel/contrib/hu/t3/tns/tns_ampel_util.py b/ampel/contrib/hu/t3/tns/tns_ampel_util.py new file mode 100644 index 00000000..09584ba5 --- /dev/null +++ b/ampel/contrib/hu/t3/tns/tns_ampel_util.py @@ -0,0 +1,180 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# File: Ampel-contrib-HU/ampel/contrib/hu/t3/tns/tns_ampel_util.py +# License: BSD-3-Clause +# Author: jnordin@physik.hu-berlin.de +# Date: 1.03.2024 +# Last Modified Date: 1.03.2024 +# Last Modified By: jnordin@physik.hu-berlin.de + +# Methods for converting AMPEL content to TNS compatible data. + +from collections.abc import Sequence +from typing import Any + +import numpy as np + +from ampel.content.DataPoint import DataPoint +from ampel.view.TransientView import TransientView +from ampel.ztf.util.ZTFIdMapper import ZTFIdMapper + +TNSFILTERID = {1: "110", 2: "111", 3: "112"} +AT_REPORT_FORM = "bulk-report" +AT_REPORT_REPLY = "bulk-report-reply" +TNS_ARCHIVE = {"OTHER": "0", "SDSS": "1", "DSS": "2"} + +# Default atdict settings for ZTF +ZTF_TNS_AT: dict = { # Default values to tag ZTF detections / ulims + "flux_units": "1", + "instrument_value": "196", + "exptime": "30", + "Observer": "Robot", +} + + +def ztfdps_to_tnsdict( + dps: Sequence[DataPoint] | None, + max_maglim: float = 19.5, +) -> None | dict[str, Any]: + """ + Collect ZTF data needed for the atreport. Return None in case + you have to skip this transient for some reason. 
+ """ + + if not dps: + return None + + zdps = [dp for dp in dps if "ZTF" in dp["tag"] and "magpsf" in dp["body"]] + zuls = [ + dp + for dp in dps + if "ZTF" in dp["tag"] + and "magpsf" not in dp["body"] + and dp["body"]["diffmaglim"] >= max_maglim + ] + + if len(zdps) == 0: + return None + + ra = np.mean([dp["body"]["ra"] for dp in zdps]) + dec = np.mean([dp["body"]["dec"] for dp in zdps]) + + names: list[str] = [] + for dp in zdps: + if isinstance(dp["stock"], int): + names.append(ZTFIdMapper.to_ext_id(dp["stock"])) + elif isinstance(dp["stock"], Sequence): + names.extend([ZTFIdMapper.to_ext_id(stock) for stock in dp["stock"]]) + ztfnames = set(names) + + # Start defining AT dict: name and position + atdict: dict[str, Any] = {} + atdict["ra"] = {"value": ra, "error": 1.0, "units": "arcsec"} + atdict["dec"] = {"value": dec, "error": 1.0, "units": "arcsec"} + atdict["internal_name"] = next(iter(ztfnames)) + + # Add information on the latest SIGNIFICANT non detection. + last_non_obs = 0 + if len(zuls) > 0: + last_ulim = sorted(zuls, key=lambda x: x["body"]["jd"])[-1] + last_non_obs = last_ulim["body"]["jd"] + atdict["non_detection"] = { + "obsdate": last_ulim["body"]["jd"], + "limiting_flux": last_ulim["body"]["diffmaglim"], + "filter_value": TNSFILTERID.get(last_ulim["body"]["fid"]), + } + else: + atdict["non_detection"] = { + "archiveid": "0", + "archival_remarks": "ZTF non-detection limits not available", + } + + atdict["non_detection"].update(ZTF_TNS_AT) # Add the default ZTF values + + # now add info on photometric detections: consider only candidates which + # have some consecutive detection after the last ulim + atdict["photometry"] = {"photometry_group": {}} + atdict["discovery_datetime"] = 10**30 + for dp in zdps: + if dp["body"]["jd"] < last_non_obs: + continue + + # Lets create a few photometry points + # Note: previously had a cap on the number of dps that could be included. *should* be unnecessary. + photdict = { + "obsdate": dp["body"]["jd"], + "flux": float("{0:.2f}".format(dp["body"]["magpsf"])), # noqa: UP030 + "flux_error": float("{0:.2f}".format(dp["body"]["sigmapsf"])), # noqa: UP030 + "limiting_flux": float("{0:.2f}".format(dp["body"]["diffmaglim"])), # noqa: UP030 + "filter_value": TNSFILTERID.get(dp["body"]["fid"]), + } + if dp["body"]["jd"] < atdict["discovery_datetime"]: + atdict["discovery_datetime"] = dp["body"]["jd"] + photdict.update(ZTF_TNS_AT) + atdict["photometry"]["photometry_group"][ + len(atdict["photometry"]["photometry_group"]) + ] = photdict + + return atdict + + +def get_tns_t2remarks(tview: TransientView) -> None | dict[str, Any]: + """ + Inspect t2results, and extract TNS remarks when warranted. + """ + # Tag things close to SDSS nuclei + nuclear_dist = 1.0 + + # Start building dict with remarks + remarks: dict[str, Any] = {"remarks": ""} + + # Ampel Z + t2res = tview.get_t2_body(unit="T2DigestRedshifts") + if isinstance(t2res, dict) and t2res.get("ampel_z", -10) > 0: + remarks["remarks"] = remarks["remarks"] + "AmpelZ{:.3f} (N{}) ".format( + t2res["ampel_z"], t2res["group_z_nbr"] + ) + + # T2CatalogMatch + cat_res = tview.get_t2_body(unit="T2CatalogMatch") + if isinstance(cat_res, dict): + # Check redshift + nedz = cat_res.get("NEDz", False) + sdss_spec = cat_res.get("SDSS_spec", False) + if sdss_spec: + remarks["remarks"] = ( + remarks["remarks"] + "SDSS spec-z %.3f. " % (sdss_spec["z"]) + ) + elif nedz: + remarks["remarks"] = remarks["remarks"] + "NED z %.3f. 
" % (nedz["z"]) + + # tag AGNs + milliquas = cat_res.get("milliquas", False) + if ( + milliquas + and milliquas["redshift"] is not None + and milliquas["redshift"] > 0 + ) or (sdss_spec and sdss_spec["bptclass"] in [4, 5]): + remarks["remarks"] = ( + remarks["remarks"] + "Known SDSS and/or MILLIQUAS QSO/AGN. " + ) + remarks["at_type"] = 3 + + # tag nuclear + sdss_dr10 = cat_res.get("SDSSDR10", False) + if ( + sdss_dr10 + and sdss_dr10["type"] == 3 + and sdss_dr10["dist2transient"] < nuclear_dist + ): + remarks["remarks"] = ( + remarks["remarks"] + "Close to core of SDSS DR10 galaxy. " + ) + remarks["at_type"] = 4 + + # Note: removed the tag of noisy gaia data (check T2TNSEval) + + if len(remarks["remarks"]) == 0: + return None + + return remarks diff --git a/ampel/contrib/hu/test/test_tnstalker.py b/ampel/contrib/hu/test/test_tnstalker.py deleted file mode 100644 index d29acb05..00000000 --- a/ampel/contrib/hu/test/test_tnstalker.py +++ /dev/null @@ -1,70 +0,0 @@ -from os import environ - -import pytest - -from ampel.contrib.hu.t3.tns.TNSToken import TNSToken -from ampel.contrib.hu.t3.TNSTalker import TNS_BASE_URL_SANDBOX, TNSClient -from ampel.log.AmpelLogger import AmpelLogger - - -@pytest.fixture() -def tns_token(): - if not (api_key := environ.get("TNS_API_KEY")): - raise pytest.skip("Test requires env var TNS_API_KEY") - return TNSToken( - id=59228, - name="ZTF_AMPEL_COMPLETE", - api_key=api_key, - ) - - -@pytest.fixture() -def test_client(tns_token): - return TNSClient(TNS_BASE_URL_SANDBOX, AmpelLogger.get_logger(), tns_token) - - -def test_tnsclient(test_client): - assert test_client.getInternalName("2018cow") == ( - "ATLAS18qqn, ZTF18abcfcoo, Gaia18bqa", - "Got internal name response", - ) - assert test_client.search(244.000917, 22.268031) == ( - ["SN2018cow"], - "Found TNS name(s)", - ) - assert test_client.getNames(244.000917, 22.268031) == ( - "SN2018cow", - "ATLAS18qqn, ZTF18abcfcoo, Gaia18bqa", - ) - - -def test_tnsclient_backoff(test_client: TNSClient): - import logging - - logging.basicConfig() - for _ in range(12): - assert test_client.search(244.000917, 22.268031) == ( - ["SN2018cow"], - "Found TNS name(s)", - ) - - -@pytest.mark.asyncio() -async def test_tnsclient_backoff_async(tns_token): - from ampel.contrib.hu.t3.tns.TNSClient import TNSClient as TNSMirrorClient - - client = TNSMirrorClient( - tns_token, timeout=120, maxParallelRequests=1, logger=AmpelLogger.get_logger() - ) - ra, dec, matchradius = 244.000917, 22.268031, 5.0 - for _ in range(12): - hits = [ - hit - async for hit in client.search( - ra=ra, - dec=dec, - radius=matchradius, - units="arcsec", - ) - ] - assert hits diff --git a/conf/ampel-hu-astro/process/TNSSubmitComplete.yml b/conf/ampel-hu-astro/process/TNSSubmitComplete.yml deleted file mode 100644 index 1fcda585..00000000 --- a/conf/ampel-hu-astro/process/TNSSubmitComplete.yml +++ /dev/null @@ -1,25 +0,0 @@ -name: TNSSubmitComplete -tier: 3 -template: ztf_periodic_summary -schedule: every(60).minutes -channel: - any_of: - - HU_GP_10 - - HU_GP_59 -load: - - TRANSIENT - - col: t2 - query_complement: {unit: {$in: [T2TNSEval, T2LightCurveSummary]}} -complement: TNSNames -filter: - t2: - unit: T2TNSEval - match: - tns_candidate: true -run: - unit: TNSTalker - config: - tns_api_key: - label: tns/complete - submit_tns: true - sandbox: false diff --git a/conf/ampel-hu-astro/process/TNSSubmitNew.yml b/conf/ampel-hu-astro/process/TNSSubmitNew.yml deleted file mode 100644 index 12fba2fd..00000000 --- a/conf/ampel-hu-astro/process/TNSSubmitNew.yml +++ /dev/null 
@@ -1,32 +0,0 @@ -name: TNSSubmitNew -tier: 3 -template: ztf_periodic_summary -schedule: every(60).minutes -channel: - any_of: - - HU_TNS_MSIP -load: - - TRANSIENT - - col: t2 - query_complement: {unit: {$in: [T2TNSEval, T2LightCurveSummary]}} -complement: TNSNames -filter: - t2: - all_of: - - unit: T2TNSEval - match: - tns_candidate: true - - unit: T2BrightSNProb - match: - SNGuess: - $gt: 0.5 - ndet: - $lt: 7 -run: - unit: TNSTalker - config: - tns_api_key: - label: tns/new - submit_tns: true - sandbox: false - diff --git a/conf/ampel-hu-astro/unit.yml b/conf/ampel-hu-astro/unit.yml index 41c7f307..9db2b784 100644 --- a/conf/ampel-hu-astro/unit.yml +++ b/conf/ampel-hu-astro/unit.yml @@ -6,6 +6,7 @@ - ampel.contrib.hu.t0.RandFilter - ampel.contrib.hu.t0.RcfFilter - ampel.contrib.hu.t0.RedshiftCatalogFilter +- ampel.contrib.hu.t0.StellarFilter - ampel.contrib.hu.t2.T2PanStarrThumbPrint - ampel.contrib.hu.t2.T2PhaseLimit - ampel.contrib.hu.t2.T2PS1ThumbExtCat @@ -41,6 +42,9 @@ - ampel.contrib.hu.t2.T2KilonovaEval - ampel.contrib.hu.t2.T2KilonovaStats - ampel.contrib.hu.t2.T2MatchGRB +- ampel.contrib.hu.t2.T2BaseLightcurveFitter +- ampel.contrib.hu.t2.T2DemoLightcurveFitter +- ampel.contrib.hu.t2.T2PolynomialFit - ampel.contrib.hu.t3.TransientInfoPrinter - ampel.contrib.hu.t3.TransientViewDumper - ampel.contrib.hu.t3.ChannelSummaryPublisher @@ -48,7 +52,6 @@ - ampel.contrib.hu.t3.RapidBase - ampel.contrib.hu.t3.RapidSedm - ampel.contrib.hu.t3.RapidLco -- ampel.contrib.hu.t3.TNSTalker - ampel.contrib.hu.t3.TNSMirrorUpdater - ampel.contrib.hu.t3.TransientTablePublisher - ampel.contrib.hu.t3.HealpixCorrPlotter @@ -62,6 +65,8 @@ - ampel.contrib.hu.t3.CostCounter - ampel.contrib.hu.t3.ScoreSingleObject - ampel.contrib.hu.t3.ScoreTNSObjects +- ampel.contrib.hu.t3.PlotTransientLightcurves +- ampel.contrib.hu.t3.SubmitTNS - ampel.contrib.hu.alert.load.WiseFileAlertLoader - ampel.contrib.hu.alert.NeoWisePhotometryAlertSupplier - ampel.contrib.hu.alert.DynamicShaperAlertConsumer diff --git a/examples/healpix_linfit.yml b/examples/healpix_linfit.yml new file mode 100644 index 00000000..e1f3c648 --- /dev/null +++ b/examples/healpix_linfit.yml @@ -0,0 +1,206 @@ +name: healpix_linfit +parameters: +- name: map_url_var + value: https://gracedb.ligo.org/api/superevents/S231102w/files/bayestar.fits.gz,1 +- name: map_name_var + value: S231102w.fits.gz,1 +- name: map_token_var + value: S231102w.fits.gz,1_token +- name: trigger_jd_var + value: 2460250.5 +- name: export_fmt + value: csv +- name: transienttable_path + value: ./TransientTable_linfit.csv +- name: channelname + value: testI + +mongo: + prefix: evalLinfit + reset: true + +channel: +- name: testI + access: [ZTF, ZTF_PUB, ZTF_PRIV] + policy: [] + +task: + +- title: token + unit: T3Processor + config: + raise_exc: true + execute: + - unit: T3PlainUnitExecutor + config: + target: + unit: HealpixTokenGenerator + config: + pvalue_limit: 0.9 + chunk_size: 100 + map_name: "{{ job.parameters.map_name_var }}" + map_url: "{{ job.parameters.map_url_var }}" + map_dir: ./ + delta_time: 40 + archive_token: + label: ztf/archive/token + candidate: + ndethist: + $gte: 6. 
+ drb: + $gt: 0.995 + rb: + $gt: 0.3 + isdiffpos: + $in: + - "t" + - "1" + +- title: alerts + unit: DynamicShaperAlertConsumer + config: + shaper_map: + map_name: healpix_map_name + healpix_info: "{{ job.parameters.map_name_var }}" # <> + iter_max: 1000000 + supplier: + unit: ZiAlertSupplier + config: + deserialize: null + loader: + unit: ZTFArchiveAlertLoader + config: + with_history: false + resource_name: "{{ job.parameters.map_token_var }}" + shaper: ZiGWDataPointShaper + directives: + - channel: "{{ job.parameters.channelname }}" + filter: + config: + trigger_jd: "{{ job.parameters.trigger_jd_var }}" + min_ndet: 0 + min_tspan: -1 + max_tspan: 42 + max_archive_tspan: 42 + min_drb: 0.3 + min_gal_lat: 0 + min_rb: 0.0 + min_sso_dist: 20 + gaia_excessnoise_sig_max: 999 + gaia_plx_signif: 3 + gaia_pm_signif: 3 + gaia_rs: 0 + gaia_veto_gmag_max: 20 + gaia_veto_gmag_min: 9 + ps1_confusion_rad: 3 + ps1_confusion_sg_tol: 0.1 + ps1_sgveto_rad: 1 + ps1_sgveto_th: 0.8 + max_fwhm: 5.5 + max_elong: 2 + max_magdiff: 1 + max_nbad: 2 + on_stock_match: bypass + unit: PredetectionFilter + ingest: + mux: + combine: + - state_t2: + - unit: T2LineFit + config: + order: 1 + tabulator: + - unit: ZTFT2Tabulator + unit: ZiT1Combiner + unit: ZiArchiveMuxer + config: + future_days: 3 + history_days: 50 + + +- title: t2 + unit: T2Worker + config: + send_beacon: false + raise_exc: true + + +- title: PrintTable + unit: T3Processor + config: + raise_exc: true + execute: + - unit: T3ReviewUnitExecutor + config: + supply: + unit: T3DefaultBufferSupplier + config: + select: + unit: T3StockSelector + config: + channel: "{{ job.parameters.channelname }}" + load: + unit: T3SimpleDataLoader + config: + directives: + - STOCK + - T1 + - T2DOC + - DATAPOINT + channel: "{{ job.parameters.channelname }}" + stage: + unit: T3SimpleStager + config: + execute: + - unit: TransientTablePublisher + config: + include_stock: true + include_channels: true + local_path: ./ + table_schema: + T2LineFit: + 'chi2dof': + - chi2dof + 'slope': + - p1 + transient_table_schema: + T2HealpixProb: + 'map_area': + - map_area + + +- title: PlotLightcurves + unit: T3Processor + config: + raise_exc: true + execute: + - unit: T3ReviewUnitExecutor + config: + supply: + unit: T3DefaultBufferSupplier + config: + select: + unit: T3StockSelector + config: + channel: "{{ job.parameters.channelname }}" + load: + unit: T3SimpleDataLoader + config: + directives: + - STOCK + - T1 + - T2DOC + - DATAPOINT + channel: "{{ job.parameters.channelname }}" + stage: + unit: T3SimpleStager + config: + execute: + - unit: PlotTransientLightcurves + config: + pdf_path: candidates_linfit.pdf + save_png: false + include_cutouts: true + tabulator: + - unit: ZTFT2Tabulator + diff --git a/examples/infant_test.yml b/examples/infant_test.yml new file mode 100644 index 00000000..22c7f08e --- /dev/null +++ b/examples/infant_test.yml @@ -0,0 +1,378 @@ +name: infantmay10 +parameters: +- name: channelname + value: may10 +- name: date + value: "2024-05-10" +- name: deltat + value: 1 +- name: mindet + value: 1 +- name: maxdet + value: 8 + +mongo: + prefix: infantEval + reset: false + +channel: +- name: may10 + access: [ZTF, ZTF_PUB, ZTF_PRIV] + policy: [] + +task: + + +- title: token + unit: T3Processor + config: + raise_exc: true + execute: + - unit: T3PlainUnitExecutor + config: + target: + unit: T3ZTFArchiveTokenGenerator + config: + date_str: "{{ job.parameters.date }}" + delta_t: "{{ job.parameters.deltat }}" + debug: true + resource_name: ztf_stream_token + candidate: + ndethist: + 
$gte: "{{ job.parameters.mindet }}" + $lte: "{{ job.parameters.maxdet }}" + drb: + $gt: 0.995 + magpsf: + $gt: 18. + rb: + $gt: 0.5 +# ssdistnr: +# $lt: 0 + isdiffpos: + $in: + - "t" + - "1" + + +- title: NearbyInfantReact + unit: AlertConsumer + config: + iter_max: 100000 + supplier: + unit: ZiAlertSupplier + config: + deserialize: null + loader: + unit: ZTFArchiveAlertLoader + config: + resource_name: ztf_stream_token + + shaper: ZiDataPointShaper + directives: + - channel: "{{ job.parameters.channelname }}" + filter: + config: + gaia_excessnoise_sig_max: 999 + gaia_plx_signif: 3 + gaia_pm_signif: 3 + gaia_rs: 20 + gaia_veto_gmag_max: 20 + gaia_veto_gmag_min: 9 + min_ndet: 1 + min_tspan: -99 + max_tspan: 100 + min_archive_tspan: -99 + max_archive_tspan: 10000 + min_drb: 0.995 + min_gal_lat: 14 + min_rb: 0.3 + min_sso_dist: 20 + ps1_confusion_rad: 0 # Turns off PS1 confusion check. Maybe redundnant with drb? + ps1_confusion_sg_tol: 0.1 + ps1_sgveto_rad: 1 + ps1_sgveto_th: 0.8 + max_fwhm: 5.5 + max_elong: 2 + max_magdiff: 1 + max_nbad: 2 + on_stock_match: bypass + unit: DecentFilter + ingest: + mux: + combine: + - state_t2: + - unit: T2MatchBTS + - unit: T2LightCurveSummary + - config: + max_age: 10. + maglim_maxago: 10. + min_redshift: 0.0004 + min_magpull: 2 + max_absmag: -12 + lc_filters: + - attribute: sharpnr + operator: ">=" + value: -10.15 + - attribute: magfromlim + operator: ">" + value: 0 + - attribute: chipsf + operator: "<" + value: 4 + - attribute: sumrat + operator: ">" + value: 0.9 + det_filterids: + - 1 + - 2 + t2_dependency: + - config: &catalog_match_config + catalogs: + GLADEv23: + keys_to_append: + - z + - dist + - dist_err + - flag1 + - flag2 + - flag3 + rs_arcsec: 10 + use: extcats + NEDz_extcats: + keys_to_append: + - ObjType + - Velocity + - z + rs_arcsec: 30.0 + use: extcats + NEDz: + keys_to_append: + - ObjType + - Velocity + - z + rs_arcsec: 10.0 + use: catsHTM + NEDLVS: + keys_to_append: + - objname + - objtype + - dec + - z_unc + - z_tech + - z_qual + - z_qual_flag + - z + rs_arcsec: 10.0 + use: extcats + SDSS_spec: + keys_to_append: + - z + - bptclass + - subclass + rs_arcsec: 10.0 + use: extcats + milliquas: + use: extcats + rs_arcsec: 3 + keys_to_append: + - broad_type + - name + - redshift + - qso_prob + SDSSDR10: + use: catsHTM + rs_arcsec: 3 + keys_to_append: + - type + - flags + link_override: + filter: PPSFilter + select: first + sort: jd + unit: T2CatalogMatch + unit: T2InfantCatalogEval + - unit: T2DigestRedshifts + config: &digest_config + max_redshift_category: 7 + t2_dependency: + - config: *catalog_match_config + link_override: + filter: PPSFilter + select: first + sort: jd + unit: T2CatalogMatch + - unit: T2DemoLightcurveFitter + config: + max_redshift_category: 7 + tabulator: + - unit: ZTFT2Tabulator + t2_dependency: + - config: *catalog_match_config + link_override: + filter: PPSFilter + select: first + sort: jd + unit: T2CatalogMatch + unit: ZiT1Combiner + insert: + point_t2: + - config: *catalog_match_config + ingest: + filter: PPSFilter + select: first + sort: jd + unit: T2CatalogMatch + unit: ZiMongoMuxer + +- title: Run T2s + unit: T2Worker + config: + send_beacon: false + raise_exc: true + + + +- title: React + unit: T3Processor + config: + raise_exc: true + execute: + - unit: T3ReviewUnitExecutor + config: + supply: + unit: T3DefaultBufferSupplier + config: + select: + unit: T3FilteringStockSelector + config: + channel: "{{ job.parameters.channelname }}" + t2_filter: + unit: T2InfantCatalogEval + match: + action: true + + load: + unit: 
T3SimpleDataLoader + config: + directives: + - STOCK + - T1 + - T2DOC + - DATAPOINT + channel: "{{ job.parameters.channelname }}" + stage: + unit: T3SimpleStager + config: + execute: + - unit: TransientTablePublisher + config: + include_stock: true + include_channels: true + local_path: ./ + table_schema: + T2InfantCatalogEval: + 'ndet': + - detections + 'age': + - age + 'last_UL': + - last_UL + 'peak_mag': + - peak_mag + 'latest_mag': + - latest_mag + 'rb': + - rb + 'drb': + - drb + 'absmag': + - absmag + 'action': + - action + 'infantNEDz': + - NEDz_extcats_z + 'infantNEDdist': + - NEDz_extcats_dist2transient + 'infantNEDkpc': + - NEDz_extcats_kpcdist + 'infantGladez': + - GLADEv23_z + 'infantGladedist': + - GLADEv23_dist2transient + 'infantGladekpc': + - GLADEv23_kpcdist + 'infantNedLvsz': + - NEDLVS_z + 'infantNedLvsdist': + - NEDLVS_dist2transient + 'infantNedLvskpc': + - NEDLVS_kpcdist + transient_table_schema: + T2CatalogMatch: + 'Glade z': + - GLADEv23 + - z + 'NED z': + - NEDz_extcats + - z + 'NED offset': + - NEDz_extcats + - dist2transient + 'NEDLVS z': + - NEDLVS + - z + 'NEDLV offset': + - NEDLVS + - dist2transient + +- title: React + unit: T3Processor + config: + raise_exc: true + execute: + - unit: T3ReviewUnitExecutor + config: + supply: + unit: T3DefaultBufferSupplier + config: + select: + unit: T3FilteringStockSelector + config: + channel: "{{ job.parameters.channelname }}" + t2_filter: + unit: T2InfantCatalogEval + match: + action: true + + load: + unit: T3SimpleDataLoader + config: + directives: + - STOCK + - T1 + - T2DOC + - DATAPOINT + channel: "{{ job.parameters.channelname }}" + complement: + - unit: ZTFCutoutImages + config: + eligible: last + - unit: TNSNames + config: + include_report: true + stage: + unit: T3SimpleStager + config: + execute: + - unit: PlotTransientLightcurves + config: + pdf_path: candidates.pdf + save_png: true + include_cutouts: true + slack_channel: "#ztf_auto" + slack_token: + label: "slack/ztf_ia/jno" + tabulator: + - unit: ZTFT2Tabulator + diff --git a/examples/stellar_outburst.yml b/examples/stellar_outburst.yml new file mode 100644 index 00000000..a2a4c016 --- /dev/null +++ b/examples/stellar_outburst.yml @@ -0,0 +1,231 @@ +name: stellarapr17 +parameters: +- name: channelname + value: apr17 +- name: date + value: "2024-04-17" + +mongo: + prefix: evalOutburst + reset: false + +channel: +- name: apr17 + access: [ZTF, ZTF_PUB, ZTF_PRIV] + policy: [] + +task: + + +- title: token + unit: T3Processor + config: + raise_exc: true + execute: + - unit: T3PlainUnitExecutor + config: + target: + unit: T3ZTFArchiveTokenGenerator + config: + date_str: "{{ job.parameters.date }}" + delta_t: 1. + resource_name: ztf_stream_token + candidate: + ndethist: + $gte: 3. +# $lte: 10. + drb: + $gt: 0.995 + magpsf: + $lt: 18.5 + rb: + $gt: 0.3 + isdiffpos: + $in: + - "t" + - "1" + + +- title: filterStellarOUtburst + unit: AlertConsumer + config: + iter_max: 100000 + supplier: + unit: ZiAlertSupplier + config: + deserialize: null + loader: + unit: ZTFArchiveAlertLoader + config: + resource_name: ztf_stream_token + + shaper: ZiDataPointShaper + directives: + - channel: "{{ job.parameters.channelname }}" + filter: + config: + min_ndet: 3 + max_ndet: 10 + min_tspan: -99 + max_tspan: 100 + min_archive_tspan: -99 + max_archive_tspan: 10000 + max_mag: 18.5 + peak_time_limit: 5. + min_peak_diff: 1. 
+ min_drb: 0.995 + min_rb: 0.3 + require_ps_star: true + require_gaia_star: true + on_stock_match: bypass + unit: StellarFilter + ingest: + mux: + combine: + - state_t2: + - unit: T2LineFit + config: + tabulator: + - unit: ZTFT2Tabulator + unit: ZiT1Combiner + insert: + point_t2: + - config: + catalogs: + NEDz: + keys_to_append: + - ObjType + - Velocity + - z + rs_arcsec: 10.0 + use: catsHTM + NEDLVS: + keys_to_append: + - objname + - objtype + - dec + - z_unc + - z_tech + - z_qual + - z_qual_flag + - z + rs_arcsec: 10.0 + use: extcats + SDSS_spec: + keys_to_append: + - z + - bptclass + - subclass + rs_arcsec: 3 + use: extcats + SDSSDR10: + use: catsHTM + rs_arcsec: 3 + keys_to_append: + - type + - flags + ingest: + filter: PPSFilter + select: first + sort: jd + unit: T2CatalogMatch + unit: ZiMongoMuxer + +- title: RunT2s + unit: T2Worker + config: + send_beacon: false + raise_exc: true + + + +- title: PrintTable + unit: T3Processor + config: + raise_exc: true + execute: + - unit: T3ReviewUnitExecutor + config: + supply: + unit: T3DefaultBufferSupplier + config: + select: + unit: T3StockSelector + config: + channel: "{{ job.parameters.channelname }}" + load: + unit: T3SimpleDataLoader + config: + directives: + - STOCK + - T1 + - T2DOC + - DATAPOINT + channel: "{{ job.parameters.channelname }}" + stage: + unit: T3SimpleStager + config: + execute: + - unit: TransientTablePublisher + config: + include_stock: true + include_channels: true + local_path: ./ + table_schema: + T2InfantCatalogEval: + 'ndet': + - detections + transient_table_schema: + T2CatalogMatch: + 'NED z': + - NEDz_extcats + - z + 'NEDLVS z': + - NEDLVS + - z + 'SDSS spec z': + - SDSS_spec + - z + 'SDSS spec class': + - SDSS_spec + - bptclass + 'SDSS DR10 class': + - SDSSDR10 + - type + + +- title: PlotLightcurves + unit: T3Processor + config: + raise_exc: true + execute: + - unit: T3ReviewUnitExecutor + config: + supply: + unit: T3DefaultBufferSupplier + config: + select: + unit: T3StockSelector + config: + channel: "{{ job.parameters.channelname }}" + load: + unit: T3SimpleDataLoader + config: + directives: + - STOCK + - T1 + - T2DOC + - DATAPOINT + channel: "{{ job.parameters.channelname }}" + stage: + unit: T3SimpleStager + config: + execute: + - unit: PlotTransientLightcurves + config: + pdf_path: candidates.pdf + save_png: false + include_cutouts: true + tabulator: + - unit: ZTFT2Tabulator + diff --git a/poetry.lock b/poetry.lock index e76d7937..912b73bd 100644 --- a/poetry.lock +++ b/poetry.lock @@ -6218,6 +6218,17 @@ build = ["cmake (>=3.20)", "lit"] tests = ["autopep8", "flake8", "isort", "numpy", "pytest", "scipy (>=1.7.1)", "torch"] tutorials = ["matplotlib", "pandas", "tabulate", "torch"] +[[package]] +name = "types-pillow" +version = "10.2.0.20240511" +description = "Typing stubs for Pillow" +optional = false +python-versions = ">=3.8" +files = [ + {file = "types-Pillow-10.2.0.20240511.tar.gz", hash = "sha256:b2fcc27b8e15ae3741941e43b4f39eba6fce6bcb152af90bbb07b387d2585783"}, + {file = "types_Pillow-10.2.0.20240511-py3-none-any.whl", hash = "sha256:ef87a19ea0a02a89c784cbc1b99dfff6c00dd0d5796a8ac868cf7ec69c5f88ff"}, +] + [[package]] name = "types-python-dateutil" version = "2.9.0.20240316" @@ -6837,6 +6848,24 @@ files = [ docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", 
"pytest-ignore-flaky", "pytest-mypy", "pytest-ruff (>=0.2.1)"] +[[package]] +name = "ztfquery" +version = "1.27.1" +description = "Python package to access ZTF data" +optional = true +python-versions = ">=3.8" +files = [ + {file = "ztfquery-1.27.1-py3-none-any.whl", hash = "sha256:a717ff9bed70d31f3cee7b80baffd9945d6c8bbc669b736b2410251d92c9e6b2"}, + {file = "ztfquery-1.27.1.tar.gz", hash = "sha256:19801e66e8b33b67feb6bf05886c104c9b96fd5c03fe23e8b1d559f50e67cf0f"}, +] + +[package.dependencies] +astropy = ">=5.2" +matplotlib = ">=3.7" +numpy = ">=1.24" +pandas = ">=1.7" +requests = ">=2.28" + [extras] elasticc = ["ampel-lsst", "astro-parsnip", "timeout-decorator", "xgboost"] extcats = ["extcats"] @@ -6848,9 +6877,9 @@ slack = ["slack-sdk"] sncosmo = ["iminuit", "sfdmap2", "sncosmo"] snpy = ["snpy"] voevent = ["voevent-parse"] -ztf = ["ampel-ztf"] +ztf = ["ampel-ztf", "ztfquery"] [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.12" -content-hash = "23ea648d706d1fc1240e76bb735dd020623b45a292071eacfd4f26a23d5ae803" +content-hash = "5cca9f6404161d00f85c4cf0e284c8adb445f8ed7d9217b6d6e4133e7138a620" diff --git a/pyproject.toml b/pyproject.toml index e5f9d73a..baab9d75 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,7 +49,7 @@ requests = "^2.26.0" astropy = ">=5.0" # pymage never made it to pypi; distribute our own package pymage = {version = "^0.5", optional = true, source = "desy-gitlab"} -pandas = "^2.0.0" +pandas = ">=2.0" seaborn = "^0.12.0" adjustText = "^1.0.0" extcats = {version = "^2.4.2", optional = true, source = "pypi"} @@ -65,6 +65,7 @@ scikit-learn = "^1.1.3" healpy = {version = "^1.16.2", optional = true} light-curve = {version = "^0.7.3"} ampel-lsst = {version = ">=0.8.6,<0.9", optional = true} +ztfquery = {version = "^1.26.1", optional = true} ligo-gracedb = {version = "^2.12.0", optional = true} astro-datalab = {version = "^2", optional = true} # mainline snoopy can't be built as PEP 517 package; use our own distribution @@ -77,6 +78,7 @@ pytest-cov = "^5.0.0" pytest-mock = "^3.12.0" types-requests = "^2.25.9" types-pytz = "^2022.1.2" +types-pillow = "^10.2.0.20240213" # prevent poetry 1.3 from removing setuptools setuptools = "*" @@ -91,7 +93,7 @@ sncosmo = ["sncosmo", "iminuit", "sfdmap2"] snpy = ["snpy"] notebook = ["jupyter"] voevent = ["voevent-parse"] -ztf = ["ampel-ztf"] +ztf = ["ampel-ztf", "ztfquery"] [tool.poetry.group.dev.dependencies] ruff = "^0.1.13" diff --git a/scripts/generate_unit_inventory.py b/scripts/generate_unit_inventory.py index 3aefbb08..f9cf79b2 100755 --- a/scripts/generate_unit_inventory.py +++ b/scripts/generate_unit_inventory.py @@ -88,9 +88,9 @@ def open_target_file(): and unit.__doc__.strip() and not any(unit.__doc__ == base.__doc__ for base in unit.__bases__) ): - # extract the first sentence, unwrap, and add a period + # extract the first sentence (ends with punctuation followed by something other than a lowercase letter), unwrap, and add a period doc = re.split( - r"([\.!\?:]\s+|\n\n)", + r"([\.!\?:]\s+?([^a-z]|\n)|\n\n)", inspect.getdoc(unit).strip(), maxsplit=1, flags=re.MULTILINE,