### Anomaly Detection with ORACLE

In [1]:
# !pip install numpy polars tqdm astropy tf_keras tensorflow 
# uncomment line above if necessary
import os
import pickle
import numpy as np
import polars as pl
import pandas as pd
from tqdm import tqdm
import matplotlib.pyplot as plt
from astropy.table import Table
from astropy import units as u
from astropy.coordinates import SkyCoord
from collections import OrderedDict
from taxonomy import get_classification_labels, get_astrophysical_class, plot_colored_tree
import tf_keras as keras
import tensorflow as tf
import itertools
os.environ["TF_USE_LEGACY_KERAS"] = "1"




**Reading in Parquet files using Polars**

In [2]:
static_feature_list = ['MWEBV', 'MWEBV_ERR', 'REDSHIFT_HELIO', 'REDSHIFT_HELIO_ERR', 'HOSTGAL_PHOTOZ', 'HOSTGAL_PHOTOZ_ERR', 'HOSTGAL_SPECZ', 'HOSTGAL_SPECZ_ERR', 'HOSTGAL_RA', 'HOSTGAL_DEC', 'HOSTGAL_SNSEP', 'HOSTGAL_ELLIPTICITY', 'HOSTGAL_MAG_u', 'HOSTGAL_MAG_g', 'HOSTGAL_MAG_r', 'HOSTGAL_MAG_i', 'HOSTGAL_MAG_z', 'HOSTGAL_MAG_Y', 'MW_plane_flag', 'ELAIS_S1_flag', 'XMM-LSS_flag', 'Extended_Chandra_Deep_Field-South_flag', 'COSMOS_flag']

In [3]:
train_path = "./../../train_parquet-001.parquet"
test_path = "./../../test_parquet-002.parquet"
train_parquet = pl.read_parquet(train_path)
train_parquet.head()

SNID,IAUC,FAKE,RA,DEC,PIXSIZE,NXPIX,NYPIX,SNTYPE,NOBS,PTROBS_MIN,PTROBS_MAX,MWEBV,MWEBV_ERR,REDSHIFT_HELIO,REDSHIFT_HELIO_ERR,REDSHIFT_FINAL,REDSHIFT_FINAL_ERR,VPEC,VPEC_ERR,HOSTGAL_NMATCH,HOSTGAL_NMATCH2,HOSTGAL_OBJID,HOSTGAL_FLAG,HOSTGAL_PHOTOZ,HOSTGAL_PHOTOZ_ERR,HOSTGAL_SPECZ,HOSTGAL_SPECZ_ERR,HOSTGAL_RA,HOSTGAL_DEC,HOSTGAL_SNSEP,HOSTGAL_DDLR,HOSTGAL_CONFUSION,HOSTGAL_LOGMASS,HOSTGAL_LOGMASS_ERR,HOSTGAL_LOGSFR,HOSTGAL_LOGSFR_ERR,…,PSF_SIG2,PSF_RATIO,SKY_SIG,SKY_SIG_T,RDNOISE,ZEROPT,ZEROPT_ERR,GAIN,XPIX,YPIX,SIM_FLUXCAL_HOSTERR,SIM_MAGOBS,ELASTICC_class,AGN_PARAM(M_BH),AGN_PARAM(Mi),AGN_PARAM(edd_ratio),AGN_PARAM(edd_ratio2),AGN_PARAM(t_transition),AGN_PARAM(cl_flag),SIM_TEMPLATEMAG_u,SIM_TEMPLATEMAG_g,SIM_TEMPLATEMAG_r,SIM_TEMPLATEMAG_i,SIM_TEMPLATEMAG_z,SIM_TEMPLATEMAG_Y,SIM_HOSTLIB(g_obs),SIM_HOSTLIB(r_obs),SIM_HOSTLIB(i_obs),SIM_HOSTLIB(LOGMASS_TRUE),SIM_HOSTLIB(LOG_SFR),SIM_SALT2x0,SIM_SALT2x1,SIM_SALT2c,SIM_SALT2mB,SIM_SALT2alpha,SIM_SALT2beta,SIM_SALT2gammaDM
i64,binary,i64,f64,f64,f64,i64,i64,i64,i64,i64,i64,f64,f64,f64,f64,f64,f64,f64,f64,i64,i64,i64,i64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,…,list[f32],list[f32],list[f32],list[f32],list[f32],list[f32],list[f32],list[f32],list[f32],list[f32],list[f32],list[f32],str,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64
73817174,"b""NULL\x20\x20\x20\x20\x20\x20\x20\x20\x20\x20\x20\x20""",2,194.194337,-16.671913,0.2,-9,-9,146,175,35222,35396,0.045439,0.002272,0.179155,0.1824,0.180447,0.1824,0.0,300.0,1,1,8062583700,0,0.179155,0.1824,-9.0,-9.0,194.193886,-16.671998,1.584567,2.227055,-99.0,10.4626,-9999.0,-9999.0,-9999.0,…,"[0.0, 0.0, … 0.0]","[0.0, 0.0, … 0.0]","[31.76, 33.43, … 51.349998]","[0.0, 0.0, … 0.0]","[0.25, 0.25, … 0.25]","[31.59, 31.33, … 29.959999]","[0.005, 0.005, … 0.005]","[1.0, 1.0, … 1.0]","[-9.0, -9.0, … -9.0]","[-9.0, -9.0, … -9.0]","[0.0, 0.0, … 0.0]","[99.0, 99.0, … 24.027132]","""CART""",,,,,,,,,,,,,,,,,,,,,,,,
10165293,"b""NULL\x20\x20\x20\x20\x20\x20\x20\x20\x20\x20\x20\x20""",2,88.425799,-29.108888,0.2,-9,-9,146,66,17198,17263,0.026972,0.001349,0.057581,0.05264,0.057864,0.05264,0.0,300.0,1,1,10437527333,0,0.057581,0.05264,-9.0,-9.0,88.425761,-29.108741,0.541073,0.776435,-99.0,8.0083,-9999.0,-9999.0,-9999.0,…,"[0.0, 0.0, … 0.0]","[0.0, 0.0, … 0.0]","[41.720001, 40.919998, … 31.98]","[0.0, 0.0, … 0.0]","[0.25, 0.25, … 0.25]","[31.01, 30.120001, … 31.309999]","[0.005, 0.005, … 0.005]","[1.0, 1.0, … 1.0]","[-9.0, -9.0, … -9.0]","[-9.0, -9.0, … -9.0]","[0.0, 0.0, … 0.0]","[99.0, 99.0, … 26.480591]","""CART""",,,,,,,,,,,,,,,,,,,,,,,,
37997692,"b""NULL\x20\x20\x20\x20\x20\x20\x20\x20\x20\x20\x20\x20""",2,211.202277,-5.134412,0.2,-9,-9,146,23,30944,30966,0.023278,0.001164,0.219521,0.21094,0.220621,0.21094,0.0,300.0,1,1,9562542552,0,0.219521,0.21094,-9.0,-9.0,211.202325,-5.134305,0.424066,0.495451,-99.0,9.2915,-9999.0,-9999.0,-9999.0,…,"[0.0, 0.0, … 0.0]","[0.0, 0.0, … 0.0]","[53.279999, 20.73, … 27.219999]","[0.0, 0.0, … 0.0]","[0.25, 0.25, … 0.25]","[30.120001, 31.610001, … 31.540001]","[0.005, 0.005, … 0.005]","[1.0, 1.0, … 1.0]","[-9.0, -9.0, … -9.0]","[-9.0, -9.0, … -9.0]","[0.0, 0.0, … 0.0]","[99.0, 99.0, … 27.329807]","""CART""",,,,,,,,,,,,,,,,,,,,,,,,
9722918,"b""NULL\x20\x20\x20\x20\x20\x20\x20\x20\x20\x20\x20\x20""",2,155.447505,10.601137,0.2,-9,-9,146,161,3374,3534,0.028591,0.00143,0.325141,0.33938,0.326663,0.33938,0.0,300.0,1,1,9562560222,0,0.325141,0.33938,-9.0,-9.0,155.447689,10.601505,1.474684,2.029634,-99.0,9.6228,-9999.0,-9999.0,-9999.0,…,"[0.0, 0.0, … 0.0]","[0.0, 0.0, … 0.0]","[51.740002, 73.540001, … 64.440002]","[0.0, 0.0, … 0.0]","[0.25, 0.25, … 0.25]","[31.299999, 30.950001, … 30.99]","[0.005, 0.005, … 0.005]","[1.0, 1.0, … 1.0]","[-9.0, -9.0, … -9.0]","[-9.0, -9.0, … -9.0]","[0.0, 0.0, … 0.0]","[99.0, 99.0, … 28.176512]","""CART""",,,,,,,,,,,,,,,,,,,,,,,,
12318140,"b""NULL\x20\x20\x20\x20\x20\x20\x20\x20\x20\x20\x20\x20""",2,220.748793,-42.183707,0.2,-9,-9,146,100,6493,6592,0.095204,0.00476,0.198641,0.21107,0.199419,0.21107,0.0,300.0,1,1,7875083702,0,0.198641,0.21107,-9.0,-9.0,220.748799,-42.183895,0.679666,0.952445,-99.0,8.9935,-9999.0,-9999.0,-9999.0,…,"[0.0, 0.0, … 0.0]","[0.0, 0.0, … 0.0]","[46.209999, 46.169998, … 52.509998]","[0.0, 0.0, … 0.0]","[0.25, 0.25, … 0.25]","[30.129999, 30.139999, … 29.84]","[0.005, 0.005, … 0.005]","[1.0, 1.0, … 1.0]","[-9.0, -9.0, … -9.0]","[-9.0, -9.0, … -9.0]","[0.0, 0.0, … 0.0]","[25.284834, 25.281155, … 27.903942]","""CART""",,,,,,,,,,,,,,,,,,,,,,,,


In [4]:
def get_class_info(parquet):
    assert parquet.shape[0] > 1
    classes = []
    snids = []
    for row in range(parquet.shape[0]):
        classes.append(parquet[row]["ELASTICC_class"].to_numpy()[0])
        snids.append(parquet[row]["SNID"].to_numpy()[0])

    return classes, snids

classes, snids = get_class_info(train_parquet.head())
print(list(zip(classes, snids)))

[('CART', 73817174), ('CART', 10165293), ('CART', 37997692), ('CART', 9722918), ('CART', 12318140)]


The following `LSST_Source()` class creates an object for each individual lightcurve that computes and stores the flux values (and other time-series data), as well as the static features associated with the lightcurve.

The `LSSTSourceDataSet()` class then stores the collection of `LSST_Source` objects from all the lightcurves in the Polars DataFrame.

In [5]:
class LSST_Source:

    # List of time series features actually stored in the instance of the class.
    time_series_features = ['MJD', 'BAND', 'PHOTFLAG', 'FLUXCAL', 'FLUXCALERR']

    # List of other features actually stored in the instance of the class.
    other_features = ['RA', 'DEC', 'MWEBV', 'MWEBV_ERR', 'REDSHIFT_HELIO', 'REDSHIFT_HELIO_ERR', 'VPEC', 'VPEC_ERR', 'HOSTGAL_PHOTOZ', 'HOSTGAL_PHOTOZ_ERR', 'HOSTGAL_SPECZ', 'HOSTGAL_SPECZ_ERR', 'HOSTGAL_RA', 'HOSTGAL_DEC', 'HOSTGAL_SNSEP', 'HOSTGAL_DDLR', 'HOSTGAL_LOGMASS', 'HOSTGAL_LOGMASS_ERR', 'HOSTGAL_LOGSFR', 'HOSTGAL_LOGSFR_ERR', 'HOSTGAL_LOGsSFR', 'HOSTGAL_LOGsSFR_ERR', 'HOSTGAL_COLOR', 'HOSTGAL_COLOR_ERR', 'HOSTGAL_ELLIPTICITY', 'HOSTGAL_MAG_u', 'HOSTGAL_MAG_g', 'HOSTGAL_MAG_r', 'HOSTGAL_MAG_i', 'HOSTGAL_MAG_z', 'HOSTGAL_MAG_Y', 'HOSTGAL_MAGERR_u', 'HOSTGAL_MAGERR_g', 'HOSTGAL_MAGERR_r', 'HOSTGAL_MAGERR_i', 'HOSTGAL_MAGERR_z', 'HOSTGAL_MAGERR_Y']

    # Additional features computed based on time_series_features and other_features mentioned in SNANA fits.
    custom_engineered_features = ['MW_plane_flag', 'ELAIS_S1_flag', 'XMM-LSS_flag', 'Extended_Chandra_Deep_Field-South_flag', 'COSMOS_flag']

    # Get the mean wavelengths for each filter and then convert to micro meters
    pb_wavelengths = {
        'u': (320 + 400) / (2 * 1000),
        'g': (400 + 552) / (2 * 1000),
        'r': (552 + 691) / (2 * 1000),
        'i': (691 + 818) / (2 * 1000),
        'z': (818 + 922) / (2 * 1000),
        'Y': (950 + 1080) / (2 * 1000),
    }

    # Pass band to color dict
    colors = OrderedDict({
        'u': 'blue',
        'g': 'green',
        'r': 'red',
        'i': 'teal',
        'z': 'orange',
        'Y': 'purple',
    })

    # 6 broadband filters used in LSST.
    LSST_bands = list(colors.keys())

    # Coordinates for LSST's 4 selected deep drilling fields. (Reference: https://www.lsst.org/scientists/survey-design/ddf)
    LSST_DDF = {
        'ELAIS_S1': SkyCoord(l=311.30 * u.deg, b=-72.90 * u.deg, frame='galactic'),
        'XMM-LSS': SkyCoord(l=171.20 * u.deg, b=-58.77 * u.deg, frame='galactic'),
        'Extended_Chandra_Deep_Field-South': SkyCoord(l=224.07 * u.deg, b=-54.47 * u.deg, frame='galactic'),
        'COSMOS': SkyCoord(l=236.83 * u.deg, b=42.09 * u.deg, frame='galactic'),
    }


    # Threshold values

    # MW_plane_flag is set to one if |self.b| <= b_threshold. Indicative of weather the object is in the galactic plane.
    b_threshold = 15 

    # Flux scaling value
    flux_scaling_const = 1000

    # Radius of the deep drilling field for LSST, in degrees.
    ddf_separation_radius_threshold = 3.5 / 2


    def __init__(self, parquet_row) -> None:
        """Create an LSST_Source object to store both photometric and host galaxy data from the Elasticc simulations.

        Args:
            parquet_row (_type_): A row from the polars data frame that was generated from the Elasticc FITS files using fits_to_parquet.py
            class_label (str): The Elasticc class label for this LSST_Source object.
        """

        # Set all the class attributes
        setattr(self, 'ELASTICC_class', parquet_row['ELASTICC_class'].to_numpy()[0])
        setattr(self, 'SNID', parquet_row['SNID'].to_numpy()[0])
        setattr(self, 'astrophysical_class', get_astrophysical_class(self.ELASTICC_class))

        for key in parquet_row.columns:
            if key in self.other_features:
                setattr(self, key, parquet_row[key].to_numpy()[0])
            elif key in self.time_series_features:
                setattr(self, key, parquet_row[key][0].to_numpy())

        # Run processing code on the light curves
        self.process_lightcurve()

        # Computer additional features
        self.compute_custom_features()

    
    def process_lightcurve(self) -> None:
        """Process the flux information with phot flags. Processing is done using the following steps:
        1. Remove saturations.
        Finally, all the time series data is modified to conform to the steps mentioned above.
        """

        # Remove saturations from the light curves
        saturation_mask =  (self.PHOTFLAG & 1024) == 0 

        # Alter time series data to remove saturations
        for time_series_feature in self.time_series_features:
            setattr(self, time_series_feature, getattr(self, time_series_feature)[saturation_mask])
        
    def compute_custom_features(self) -> None:

        source_coord = SkyCoord(ra = self.RA * u.deg, dec=self.DEC * u.deg)

        # Check if the object is close to the galactic plane of the milky way
        if abs(source_coord.galactic.b.degree) < self.b_threshold: 
            self.MW_plane_flag = 1
        else:
            self.MW_plane_flag = 0
        
        # Check if the object is in one of 4 LSST DDF's and set flags appropriately
        for key in self.LSST_DDF:

            # Separation from field center
            separation = source_coord.separation(self.LSST_DDF[key]).degree

            if separation < self.ddf_separation_radius_threshold:
                setattr(self, f'{key}_flag', 1)
            else:
                setattr(self, f'{key}_flag', 0)

        pass


    def plot_flux_curve(self) -> None:
        """Plot the SNANA calibrated flux vs time plot for all the data in the processed time series. All detections are marked with a star while non detections are marked with dots. Observations are color codded by their passband. This function is fundamentally a visualization tool and is not intended for making plots for papers.
        """

        # Colorize the data
        c = [self.colors[band] for band in self.BAND]
        patches = [mpatches.Patch(color=self.colors[band], label=band, linewidth=1) for band in self.colors]
        fmts = np.where((self.PHOTFLAG & 4096) != 0, '*', '.')

        # Plot flux time series
        for i in range(len(self.MJD)):
            plt.errorbar(x=self.MJD[i], y=self.FLUXCAL[i], yerr=self.FLUXCALERR[i], color=c[i], fmt=fmts[i], markersize = '10')

        # Labels
        plt.title(f"SNID: {self.SNID} | CLASS: {self.ELASTICC_class}")
        plt.xlabel('Time (MJD)')
        plt.ylabel('Calibrated Flux')
        plt.legend(handles=patches)

        plt.show()

    def get_classification_labels(self):
        """Get the classification labels (hierarchical) for this LSST Source object in the Taxonomy tree.

        Returns:
            (tree_nodes, numerical_labels): A tuple containing two list like objects. The first object contains the ordering of the nodes. The second list contains the labels themselves (0 when the object does not belong to the class and 1 when it does). The labels in the second object correspond to the nodes in the first object.
        """
        return get_classification_labels(self.astrophysical_class)
    
    def plot_classification_tree(self):
        """Plot the classification tree (based on our taxonomy) for this LSST Source object.
        """

        node, labels = self.get_classification_labels()
        plot_colored_tree(labels)


    def get_event_table(self):

        # Dataframe for time series data
        table = Table()

        # Find time since last observation
        time_since_first_obs = self.MJD - self.MJD[0]
        table['scaled_time_since_first_obs'] = time_since_first_obs / 100

        # 1 if it was a detection, zero otherwise
        table['detection_flag'] = np.where((self.PHOTFLAG & 4096 != 0), 1, 0)

        # Transform flux cal and flux cal err to more manageable values (more consistent order of magnitude)
        table['scaled_FLUXCAL'] = self.FLUXCAL / self.flux_scaling_const
        table['scaled_FLUXCALERR'] = self.FLUXCALERR / self.flux_scaling_const

        # One hot encoding for the pass band
        table['band_label'] = [self.pb_wavelengths[pb] for pb in self.BAND]

        # Consistency check
        assert len(table) == len(self.MJD), "Length of time series tensor does not match the number of mjd values."

        # Array for static features
        feature_static = OrderedDict()
        for other_feature in self.other_features:
            feature_static[other_feature] = getattr(self, other_feature)

        for feature in self.custom_engineered_features:
            feature_static[feature] = getattr(self, feature)

        # Array for computed static features
        table.meta = feature_static
        return table

    def __str__(self) -> str:

        to_return = str(vars(self))
        return to_return

In [6]:
class LSSTSourceDataSet():


    def __init__(self, path):
        """
        Arguments:
            path (string): Parquet file.
            transform (callable, optional): Optional transform to be applied on a sample.
        """

        print(f'Loading parquet dataset: {path}', flush=True)

        self.path = path
        self.parquet = pl.read_parquet(path)
        self.num_sample = self.parquet.shape[0]

        print(f"Number of sources: {self.num_sample}")

    def get_len(self):

        return self.num_sample

    def get_item(self, idx):
        
        row = self.parquet[idx]
        source = LSST_Source(row)
        table = source.get_event_table()

        astrophysical_class = get_astrophysical_class(source.ELASTICC_class)
        _, class_labels = get_classification_labels(astrophysical_class)
        class_labels = np.array(class_labels)

        return source, class_labels
    
    def get_dimensions(self):

        idx = 0
        source, class_labels = self.get_item(idx)
        table = source.get_event_table()

        ts_np = table.to_pandas().to_numpy()
        static_np = np.array(list(table.meta.values()))

        dims = {
            'ts': ts_np.shape[1],
            'static': static_np.shape[0],
            'labels': class_labels.shape[0]
        }

        return dims
    
    def get_labels(self):

        ELASTICC_labels = self.parquet['ELASTICC_class']
        astrophysical_labels = []

        for idx in range(self.num_sample):

            elasticc_class = ELASTICC_labels[idx]
            astrophysical_class = get_astrophysical_class(elasticc_class)
            astrophysical_labels.append(astrophysical_class)
        
        return astrophysical_labels

In [7]:
train = LSSTSourceDataSet("./../../train_parquet-001.parquet")
test = LSSTSourceDataSet("./../../test_parquet-002.parquet")

Loading parquet dataset: ./../../train_parquet-001.parquet
Number of sources: 1081614
Loading parquet dataset: ./../../test_parquet-002.parquet
Number of sources: 463528


In [8]:
# printing out important lightcurve features/information from the LSST_Source objects
source, class_labels = train.get_item(10)
table = source.get_event_table()
print(table)
print(table.meta)

scaled_time_since_first_obs detection_flag ... scaled_FLUXCALERR band_label
--------------------------- -------------- ... ----------------- ----------
                        0.0              0 ...       0.023611953      1.015
      0.0002350000000296859              0 ...       0.032928236      1.015
        0.20086999999999533              0 ...      0.0039335443     0.6215
         0.2107680000000255              0 ...        0.00878829     0.7545
        0.21090800000005402              0 ...      0.0066291466     0.6215
         0.2208310000000347              0 ...      0.0087686395     0.7545
        0.22089600000006612              0 ...       0.010617377     0.7545
         0.2806570000000647              0 ...      0.0052202134     0.7545
         0.2807670000000508              0 ...       0.009135532       0.87
        0.31042600000000675              0 ...       0.012291644       0.87
                        ...            ... ...               ...        ...
         0.9

Preparing the NumPy arrays with the stored lightcurve features (and then saving them in Pickle files so that we don't need to reprocess them every time we want to run the model, because the processing is super time-consuming).

In [None]:
def save(save_path , obj):
    with open(save_path, 'wb') as f:
        pickle.dump(obj, f)

X_ts = []
X_static = []
Y = []
astrophysical_classes = []
elasticc_classes = []
lengths = []

train = LSSTSourceDataSet("./../../train_parquet-001.parquet")

for i in tqdm(range(train.get_len())):

    source, labels = train.get_item(i)
    table = source.get_event_table()

    meta_data = table.meta
    ts_data = pd.DataFrame(np.array(table)) # astropy table to pandas dataframe

    # Append data for ML
    X_ts.append(ts_data)
    X_static.append(meta_data)
    Y.append(labels)

    # Append other useful data
    astrophysical_classes.append(source.astrophysical_class)
    elasticc_classes.append(source.ELASTICC_class)
    lengths.append(ts_data.shape[0])

print("\nDumping to pickle...")

# Make a directory and save the data
os.makedirs(f"{"pickles"}", exist_ok=True)       

save(f"{"pickles"}/y.pkl", Y)
save(f"{"pickles"}/x_ts.pkl", X_ts)
save(f"{"pickles"}/x_static.pkl", X_static)
save(f"{"pickles"}/a_labels.pkl", astrophysical_classes)
save(f"{"pickles"}/e_label.pkl", elasticc_classes)
save(f"{"pickles"}/lengths.pkl", lengths)

In [9]:
!ls pickles

a_labels.pkl
e_label.pkl
lengths.pkl
x_static.pkl
x_ts.pkl
y.pkl


Loading the saved pickle files and unpacking them into arrays

In [10]:
def load(file_name):
    with open(file_name, 'rb') as f:
        return pickle.load(f)

In [11]:
from taxonomy import get_taxonomy_tree
class_count = 10
ts_length = 500
ts_flag_value = 0.

model = keras.models.load_model(f"models/lsst_alpha_0.5/best_model.h5", compile=False)

tree = get_taxonomy_tree()
X_ts = load(f"{"pickles"}/x_ts.pkl")
X_static = load(f"{"pickles"}/x_static.pkl")
Y = load(f"{"pickles"}/y.pkl")
astrophysical_classes = load(f"{"pickles"}/a_labels.pkl")




**Running the model**

Selecting the correct static and time-series features from the dataset and running the model with these feature values as input

In [12]:
from interpret_results import get_conditional_probabilites
from taxonomy import get_taxonomy_tree

static_features = ['MWEBV', 'MWEBV_ERR', 'REDSHIFT_HELIO', 'REDSHIFT_HELIO_ERR', 'HOSTGAL_PHOTOZ', 'HOSTGAL_PHOTOZ_ERR', 'HOSTGAL_SPECZ', 'HOSTGAL_SPECZ_ERR', 'HOSTGAL_RA', 'HOSTGAL_DEC', 'HOSTGAL_SNSEP', 'HOSTGAL_ELLIPTICITY', 'HOSTGAL_MAG_u', 'HOSTGAL_MAG_g', 'HOSTGAL_MAG_r', 'HOSTGAL_MAG_i', 'HOSTGAL_MAG_z', 'HOSTGAL_MAG_Y', 'MW_plane_flag', 'ELAIS_S1_flag', 'XMM-LSS_flag', 'Extended_Chandra_Deep_Field-South_flag', 'COSMOS_flag']
tree = get_taxonomy_tree()

class_probs = []

for j, c in enumerate(np.unique(astrophysical_classes)):

    idx = list(np.where(np.array(astrophysical_classes) == c)[0])[:class_count]
    X_ts_class = [X_ts[i] for i in idx]
    X_static_class = [X_static[i] for i in idx]
    Y_class = [Y[i] for i in idx]

    for i in range(5):

        table = X_ts_class[i]
        static = X_static_class[i]
        target = Y_class[i]

        tables = []
        statics = []
        statics_info = []
        for k in range(1, table.to_numpy().shape[0] + 1):
            tables.append(table.to_numpy()[:k, :])
            statics_info.append(static)
            # statics.append(list(static.values()))
            statics.append([static[feature] for feature in static_features])
        
        tables = tf.keras.utils.pad_sequences(tables, maxlen=ts_length,  dtype='float32', padding='post', value=ts_flag_value)
        
        statics = np.squeeze(statics)

        true_class_idx = np.argmax(target[-19:])

        tables = np.asarray(tables).astype(np.float32)
        statics = np.asarray(statics).astype(np.float32)
        
        logits = model.predict([tables, statics], verbose=0)
        _, pseudo_conditional_probabilities = get_conditional_probabilites(logits, tree)
        class_probs.append(pseudo_conditional_probabilities)

  pseudo_probabilities[:, mask] = np.exp(y_pred[:, mask]) / np.sum(np.exp(y_pred[:, mask]), axis=1, keepdims=True)
  pseudo_probabilities[:, mask] = np.exp(y_pred[:, mask]) / np.sum(np.exp(y_pred[:, mask]), axis=1, keepdims=True)


In [13]:
# saving model outputs so that they can be used again without having to re-run the model every time
!touch logits.npy

def save_logits():
    with open('logits.npy', 'wb') as f:
        np.save(f, logits)

def load_logits():
    logits = np.load('logits.npy')

# The function calls can be uncommented as needed and then the cell can be run to either save the current predictions or load the previous ones
# save_logits()
# load_logits()

**Visualizing model output**

In [14]:
import networkx as nx
from taxonomy import source_node_label

pd.set_option('display.max_columns', None)
pd.set_option('display.max_rows', 10)

level_order_nodes = list(nx.bfs_tree(tree, source=source_node_label).nodes())
class_probs_df = [pd.DataFrame(probs, columns=level_order_nodes) for probs in class_probs]
class_probs_df[2]

Unnamed: 0,Alert,Transient,Variable,SN,Fast,Long,Periodic,AGN,SNIa,SNIb/c,SNIax,SNI91bg,SNII,KN,Dwarf Novae,uLens,M-dwarf Flare,SLSN,TDE,ILOT,CART,PISN,Cepheid,RR Lyrae,Delta Scuti,EB
0,1.0,9.492863e-01,0.050714,5.255018e-01,3.044619e-05,4.237541e-01,1.125154e-10,0.050714,9.741241e-02,8.249030e-02,1.817628e-02,2.177706e-03,3.252451e-01,3.044619e-05,1.663330e-13,4.467344e-18,1.054422e-19,2.076728e-01,1.222073e-01,1.427024e-02,1.674449e-02,6.285933e-02,5.152515e-20,9.212954e-12,4.277795e-16,1.033020e-10
1,1.0,8.737988e-01,0.126201,4.229231e-01,9.195883e-05,4.507838e-01,2.965710e-10,0.126201,1.268132e-01,6.351855e-02,1.653022e-02,2.396070e-03,2.136651e-01,9.195883e-05,1.756972e-12,7.314679e-17,2.194750e-18,2.522672e-01,9.913383e-02,1.562168e-02,1.441621e-02,6.934489e-02,3.457890e-19,2.837364e-11,1.619008e-15,2.681958e-10
2,1.0,8.446445e-01,0.155355,3.855445e-01,7.660373e-05,4.590234e-01,3.763419e-10,0.155355,1.094077e-01,5.906171e-02,1.598611e-02,2.482642e-03,1.986064e-01,7.660373e-05,2.276183e-12,9.874146e-17,3.789239e-18,2.869453e-01,9.479002e-02,1.172109e-02,1.370367e-02,5.186338e-02,7.171663e-19,4.114870e-11,2.292783e-15,3.351909e-10
3,1.0,8.881329e-01,0.111867,3.660498e-01,6.981644e-05,5.220132e-01,3.592396e-10,0.111867,1.031108e-01,5.997543e-02,1.568683e-02,2.140960e-03,1.851357e-01,6.981644e-05,2.672783e-12,8.447400e-17,3.632511e-18,3.195297e-01,9.921215e-02,1.307193e-02,1.559653e-02,7.460287e-02,1.337679e-18,5.530968e-11,3.265118e-15,3.039266e-10
4,1.0,8.558778e-01,0.144122,3.612638e-01,5.637744e-05,4.945576e-01,4.420610e-10,0.144122,1.116091e-01,6.209595e-02,1.557552e-02,1.777979e-03,1.702053e-01,5.637744e-05,2.208926e-12,7.493315e-17,3.421810e-18,3.060817e-01,8.849110e-02,1.087719e-02,1.255346e-02,7.655422e-02,1.490485e-18,5.708938e-11,3.900701e-15,3.849677e-10
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
305,1.0,3.965065e-07,1.000000,9.207915e-09,4.768777e-15,3.872986e-07,1.523322e-13,1.000000,4.817261e-10,3.721579e-12,4.273221e-11,1.528671e-12,8.678207e-09,4.768777e-15,4.749560e-29,7.646734e-34,2.278529e-34,5.305718e-10,3.742687e-07,1.249557e-08,3.548121e-12,1.955965e-13,3.042238e-31,3.790043e-14,2.962446e-21,1.144317e-13
306,1.0,4.633435e-07,1.000000,1.023802e-08,5.826421e-15,4.531055e-07,1.769744e-13,1.000000,5.358534e-10,4.136640e-12,5.156810e-11,1.782748e-12,9.644682e-09,5.826421e-15,6.289811e-29,1.143747e-33,3.858881e-34,6.185162e-10,4.382083e-07,1.427469e-08,3.715898e-12,2.460546e-13,3.695735e-31,4.305775e-14,3.625926e-21,1.339166e-13
307,1.0,4.347304e-07,1.000000,1.152920e-08,7.437763e-15,4.232012e-07,1.736830e-13,1.000000,5.009803e-10,4.956675e-12,6.014322e-11,1.192176e-12,1.096192e-08,7.437763e-15,7.839098e-29,1.278614e-33,5.468802e-34,4.693218e-10,4.046905e-07,1.803646e-08,4.644508e-12,2.875832e-13,4.100376e-31,3.429794e-14,3.441453e-21,1.393851e-13
308,1.0,5.949485e-07,0.999999,4.886248e-09,2.086188e-15,5.900622e-07,1.302466e-13,0.999999,3.122700e-10,1.610330e-12,2.982843e-11,2.191967e-12,4.540347e-09,2.086188e-15,1.908911e-29,9.127715e-34,2.871385e-34,7.003444e-10,5.714445e-07,1.791442e-08,2.667729e-12,2.589461e-13,1.412286e-31,5.160482e-14,2.420184e-21,7.864182e-14


In [15]:
class_probs_df[2].head(2)

Unnamed: 0,Alert,Transient,Variable,SN,Fast,Long,Periodic,AGN,SNIa,SNIb/c,SNIax,SNI91bg,SNII,KN,Dwarf Novae,uLens,M-dwarf Flare,SLSN,TDE,ILOT,CART,PISN,Cepheid,RR Lyrae,Delta Scuti,EB
0,1.0,0.949286,0.050714,0.525502,3e-05,0.423754,1.125154e-10,0.050714,0.097412,0.08249,0.018176,0.002178,0.325245,3e-05,1.66333e-13,4.467344e-18,1.0544219999999999e-19,0.207673,0.122207,0.01427,0.016744,0.062859,5.1525149999999995e-20,9.212954e-12,4.277795e-16,1.03302e-10
1,1.0,0.873799,0.126201,0.422923,9.2e-05,0.450784,2.96571e-10,0.126201,0.126813,0.063519,0.01653,0.002396,0.213665,9.2e-05,1.756972e-12,7.314679e-17,2.1947500000000002e-18,0.252267,0.099134,0.015622,0.014416,0.069345,3.4578899999999995e-19,2.837364e-11,1.619008e-15,2.681958e-10


In [16]:
# this output helps visualize the categories that would be classified as anomalous activity
anomaly_labels = class_probs_df[2][['SNIb/c', 'SNIax', 'SNI91bg', 'SNII', 'KN', 'Dwarf Novae', 'uLens', 'M-dwarf Flare', 'SLSN', 'TDE', 'ILOT', 'CART', 'PISN']]
anomaly_labels

Unnamed: 0,SNIb/c,SNIax,SNI91bg,SNII,KN,Dwarf Novae,uLens,M-dwarf Flare,SLSN,TDE,ILOT,CART,PISN
0,8.249030e-02,1.817628e-02,2.177706e-03,3.252451e-01,3.044619e-05,1.663330e-13,4.467344e-18,1.054422e-19,2.076728e-01,1.222073e-01,1.427024e-02,1.674449e-02,6.285933e-02
1,6.351855e-02,1.653022e-02,2.396070e-03,2.136651e-01,9.195883e-05,1.756972e-12,7.314679e-17,2.194750e-18,2.522672e-01,9.913383e-02,1.562168e-02,1.441621e-02,6.934489e-02
2,5.906171e-02,1.598611e-02,2.482642e-03,1.986064e-01,7.660373e-05,2.276183e-12,9.874146e-17,3.789239e-18,2.869453e-01,9.479002e-02,1.172109e-02,1.370367e-02,5.186338e-02
3,5.997543e-02,1.568683e-02,2.140960e-03,1.851357e-01,6.981644e-05,2.672783e-12,8.447400e-17,3.632511e-18,3.195297e-01,9.921215e-02,1.307193e-02,1.559653e-02,7.460287e-02
4,6.209595e-02,1.557552e-02,1.777979e-03,1.702053e-01,5.637744e-05,2.208926e-12,7.493315e-17,3.421810e-18,3.060817e-01,8.849110e-02,1.087719e-02,1.255346e-02,7.655422e-02
...,...,...,...,...,...,...,...,...,...,...,...,...,...
305,3.721579e-12,4.273221e-11,1.528671e-12,8.678207e-09,4.768777e-15,4.749560e-29,7.646734e-34,2.278529e-34,5.305718e-10,3.742687e-07,1.249557e-08,3.548121e-12,1.955965e-13
306,4.136640e-12,5.156810e-11,1.782748e-12,9.644682e-09,5.826421e-15,6.289811e-29,1.143747e-33,3.858881e-34,6.185162e-10,4.382083e-07,1.427469e-08,3.715898e-12,2.460546e-13
307,4.956675e-12,6.014322e-11,1.192176e-12,1.096192e-08,7.437763e-15,7.839098e-29,1.278614e-33,5.468802e-34,4.693218e-10,4.046905e-07,1.803646e-08,4.644508e-12,2.875832e-13
308,1.610330e-12,2.982843e-11,2.191967e-12,4.540347e-09,2.086188e-15,1.908911e-29,9.127715e-34,2.871385e-34,7.003444e-10,5.714445e-07,1.791442e-08,2.667729e-12,2.589461e-13


**Anomaly Detection Calculations**

Calculating a binary anomaly vs. non-anomaly prediction from the model output. Note, in the output list, the first index of each tuple corresponds to the specific lightcurve in the set, while the second and third elements in the tuple represent the probability of it being an anomaly vs. not being an anomaly).

In [18]:
results = class_probs_df[2]
anomaly_labels = results[['SNIb/c', 'SNIax', 'SNI91bg', 'KN', 'Dwarf Novae', 'uLens', 'M-dwarf Flare', 'SLSN', 'TDE', 'ILOT', 'CART', 'PISN']]
non_anomaly_labels = results[['SNIa', 'SNII', 'Cepheid', 'RR Lyrae', 'Delta Scuti', 'EB', 'AGN']]
anomaly_preds = [np.sum(lightcurve[1][::]) for lightcurve in anomaly_labels.iterrows()]
non_anomaly_preds = [np.sum(lightcurve[1][::]) for lightcurve in non_anomaly_labels.iterrows()]
anomaly_detections = list(enumerate(zip(anomaly_preds, non_anomaly_preds)))
anomaly_detections[:10]

[(0, (0.52662885, 0.47337115)),
 (1, (0.53332067, 0.4666795)),
 (2, (0.53663045, 0.4633696)),
 (3, (0.5998863, 0.4001137)),
 (4, (0.5740635, 0.42593664)),
 (5, (0.5758239, 0.42417595)),
 (6, (0.5489708, 0.45102912)),
 (7, (0.3848641, 0.6151359)),
 (8, (0.32920232, 0.67079765)),
 (9, (0.33745205, 0.662548))]