In [1]:
from ampel_notebook_utils import api_get_lightcurve, api_get_lightcurves
import numpy as np
import matplotlib.pyplot as plt
from ampel.log.AmpelLogger import AmpelLogger
from ampel.abstract.AbsLightCurveT2Unit import AbsLightCurveT2Unit
from typing import Optional
from ampel.view.LightCurve import LightCurve
from ampel.types import UBson
from ampel.struct.UnitResult import UnitResult

archivetoken = '' 

In [2]:
class T2OddSNFinder(AbsLightCurveT2Unit):
    """
    Find SN that do not follow expected behaviour
    Built for Ia only (for now)
    
    Parameters:
     *use_filters* lists the filter ids, as encoded in the datapoint "fid" field
    (e.g. present in ZTF alerts).
    
    Optionally plot debug lightcurves
     *plot*
     *plot_dir* 
    
    """

    # These parameters can be provided by the user when specifying the channel
    # If no (default) value exists this has to be provided
    use_filters: list[int] = [1, 2, 3]
        
    plot: bool = False
    plot_dir: Optional[str]

    def process(self, light_curve: LightCurve) -> UBson | UnitResult:
        # Extract residuals in required bands, remove those where isdiffposs != True
        data = np.asarray([[float(p[0]), float(p[1]), float(p[2]), float(p[4])] for p in light_curve.get_ntuples(['jd', 'magpsf', 'sigmapsf', 'isdiffpos', 'fid'])
                           if ((int(p[-1]) in self.use_filters)&(p[-2]=='t'))])
        
        #If the cuts result in no data being left, return None for all values
        if len(data) == 0:
            return {'first_det_date':None, 'peak_date':None, 'peak_mag':None, 'last_det_date_pred':None,
                    'last_det_date':None, 'maxgap_size':None, 'maxgap_end':None, 'ndet':0}
        
        # Find peak
        peak_date = data[np.where(data[:,1]==data[:,1].min())[0][0], 0]
        peak_mag = data[np.where(data[:,1]==data[:,1].min())[0][0], 1]
        peak_fid = data[np.where(data[:,1]==data[:,1].min())[0][0], -1]
        
        # Estimate when object will fade below the detection treshold
        # Current estimation uses a broken line approximation based on SN 2011fe
        # --> Try for each band, take max of found times
        vis_times = []
        for i in self.use_filters:
            banddata = data[np.where(data[:,-1]==i)]
            if len(banddata) == 0:
                continue
            band_peak = banddata[np.where(banddata[:,1]==banddata[:,1].min())[0][0], 1]
            vis_times.append(self.vis_time(band_peak, i))
        last_det_pred = peak_date + max(vis_times)
        
        #Find the biggest gap between 2 detections
        maxgap_size = max(data[:-1,0]-data[1:,0])
        maxgap_end = data[np.where(data[:-1,0]-data[1:,0] == maxgap_size)[0][0]+1,0]
                         
        # Construct output dictionary
        t2_output = {'first_det_date':min(data[:,0]), 'peak_date':peak_date, 'peak_mag':peak_mag,
                     'peak_fid':peak_fid, 'last_det_date_pred':last_det_pred, 'last_det_date':max(data[:,0]),
                     'delta_last_det':max(data[:,0])-last_det_pred, 'maxgap_size':maxgap_size,
                     'maxgap_end': maxgap_end, 'ndet':len(data)}
        
        # Optionally make debug plot 
        if self.plot and self.plot_dir is not None:
            plt.figure()
            for filt in self.use_filters:
                if filt == 1:
                    color = 'g'
                    band = 'g'
                elif filt == 2:
                    color = 'r'
                    band = 'r'
                else:
                    color = 'k'
                    band = 'i'
                plt.errorbar([p[0] for p in data if p[-1]==filt], [p[1] for p in data if p[-1]==filt],
                             yerr=[p[2] for p in data if p[-1]==filt], color=color, ls='none', marker='.',
                             label=band)
            plt.xlabel('JD')
            plt.ylabel('mag')
            plt.gca().invert_yaxis()
            plt.legend()
            plt.savefig( os.path.join(self.plot_dir, 'plot_lc.png') )
            plt.close()

        return t2_output
    
    #3 short functions to aproximate the visible time after peak per band & a function to coordinate which to use
    #These are functions based on SN 2011fe
    #Give in peak value, return visible time
    def vis_time(self, mag, band):
        if band == 1:
            return self.gfunc(mag)
        elif band == 2:
            return self.rfunc(mag)
        elif band == 3:
            return self.ifunc(mag)
        else:# band not recognised
            return 0

    def gfunc(sefl, mag):
        if mag > 17.8:
            return -10*mag+210
        else:
            return -58*mag+1070

    def rfunc(self, mag):
        if mag > 17.8:
            return -20*mag+410
        else:
            return -40*mag+770

    def ifunc(self, mag):
        if mag > 17.8:
            return -18*mag+386
        else:
            return -35*mag+690

In [3]:
#Quick test
snname = "ZTF20abzetdf" #test object, sibling

lightCurve = api_get_lightcurve(snname, archivetoken)

t2instance = T2OddSNFinder(logger=AmpelLogger.get_logger())
t2instance.process(lightCurve)

{'first_det_date': 2459104.9511921,
 'peak_date': 2459114.0022454,
 'peak_mag': 18.262699127197266,
 'peak_fid': 2.0,
 'last_det_date_pred': 2459158.748262856,
 'last_det_date': 2459903.8366551,
 'delta_last_det': 745.0883922437206,
 'maxgap_size': 700.9756481000222,
 'maxgap_end': 2459166.9252778,
 'ndet': 40}