In [1]:
# !pip install --user numpy polars tqdm astropy tf_keras tensorflow dataframe_image==0.2.5
# uncomment line above if necessary
import os
import pickle
import numpy as np
import polars as pl
import pandas as pd
from tqdm import tqdm
import matplotlib.pyplot as plt
from astropy.table import Table
from astropy import units as u
from astropy.coordinates import SkyCoord
from collections import OrderedDict
from taxonomy import get_classification_labels, get_astrophysical_class, plot_colored_tree
# import tf_keras as keras
from tensorflow import keras
import tensorflow as tf
import itertools
os.environ["TF_USE_LEGACY_KERAS"] = "1"

2025-04-08 12:03:13.336467: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE4.1 SSE4.2 AVX AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
class LSST_Source:

    # List of time series features actually stored in the instance of the class.
    time_series_features = ['MJD', 'BAND', 'PHOTFLAG', 'FLUXCAL', 'FLUXCALERR']

    # List of other features actually stored in the instance of the class.
    other_features = ['RA', 'DEC', 'MWEBV', 'MWEBV_ERR', 'REDSHIFT_HELIO', 'REDSHIFT_HELIO_ERR', 'VPEC', 'VPEC_ERR', 'HOSTGAL_PHOTOZ', 'HOSTGAL_PHOTOZ_ERR', 'HOSTGAL_SPECZ', 'HOSTGAL_SPECZ_ERR', 'HOSTGAL_RA', 'HOSTGAL_DEC', 'HOSTGAL_SNSEP', 'HOSTGAL_DDLR', 'HOSTGAL_LOGMASS', 'HOSTGAL_LOGMASS_ERR', 'HOSTGAL_LOGSFR', 'HOSTGAL_LOGSFR_ERR', 'HOSTGAL_LOGsSFR', 'HOSTGAL_LOGsSFR_ERR', 'HOSTGAL_COLOR', 'HOSTGAL_COLOR_ERR', 'HOSTGAL_ELLIPTICITY', 'HOSTGAL_MAG_u', 'HOSTGAL_MAG_g', 'HOSTGAL_MAG_r', 'HOSTGAL_MAG_i', 'HOSTGAL_MAG_z', 'HOSTGAL_MAG_Y', 'HOSTGAL_MAGERR_u', 'HOSTGAL_MAGERR_g', 'HOSTGAL_MAGERR_r', 'HOSTGAL_MAGERR_i', 'HOSTGAL_MAGERR_z', 'HOSTGAL_MAGERR_Y']

    # Additional features computed based on time_series_features and other_features mentioned in SNANA fits.
    custom_engineered_features = ['MW_plane_flag', 'ELAIS_S1_flag', 'XMM-LSS_flag', 'Extended_Chandra_Deep_Field-South_flag', 'COSMOS_flag']

    # Get the mean wavelengths for each filter and then convert to micro meters
    pb_wavelengths = {
        'u': (320 + 400) / (2 * 1000),
        'g': (400 + 552) / (2 * 1000),
        'r': (552 + 691) / (2 * 1000),
        'i': (691 + 818) / (2 * 1000),
        'z': (818 + 922) / (2 * 1000),
        'Y': (950 + 1080) / (2 * 1000),
    }

    # Pass band to color dict
    colors = OrderedDict({
        'u': 'blue',
        'g': 'green',
        'r': 'red',
        'i': 'teal',
        'z': 'orange',
        'Y': 'purple',
    })

    # 6 broadband filters used in LSST.
    LSST_bands = list(colors.keys())

    # Coordinates for LSST's 4 selected deep drilling fields. (Reference: https://www.lsst.org/scientists/survey-design/ddf)
    LSST_DDF = {
        'ELAIS_S1': SkyCoord(l=311.30 * u.deg, b=-72.90 * u.deg, frame='galactic'),
        'XMM-LSS': SkyCoord(l=171.20 * u.deg, b=-58.77 * u.deg, frame='galactic'),
        'Extended_Chandra_Deep_Field-South': SkyCoord(l=224.07 * u.deg, b=-54.47 * u.deg, frame='galactic'),
        'COSMOS': SkyCoord(l=236.83 * u.deg, b=42.09 * u.deg, frame='galactic'),
    }


    # Threshold values

    # MW_plane_flag is set to one if |self.b| <= b_threshold. Indicative of weather the object is in the galactic plane.
    b_threshold = 15 

    # Flux scaling value
    flux_scaling_const = 1000

    # Radius of the deep drilling field for LSST, in degrees.
    ddf_separation_radius_threshold = 3.5 / 2


    def __init__(self, parquet_row) -> None:
        """Create an LSST_Source object to store both photometric and host galaxy data from the Elasticc simulations.

        Args:
            parquet_row (_type_): A row from the polars data frame that was generated from the Elasticc FITS files using fits_to_parquet.py
            class_label (str): The Elasticc class label for this LSST_Source object.
        """

        # Set all the class attributes
        setattr(self, 'ELASTICC_class', parquet_row['ELASTICC_class'].to_numpy()[0])
        setattr(self, 'SNID', parquet_row['SNID'].to_numpy()[0])
        setattr(self, 'astrophysical_class', get_astrophysical_class(self.ELASTICC_class))

        for key in parquet_row.columns:
            if key in self.other_features:
                setattr(self, key, parquet_row[key].to_numpy()[0])
            elif key in self.time_series_features:
                setattr(self, key, parquet_row[key][0].to_numpy())

        # Run processing code on the light curves
        self.process_lightcurve()

        # Computer additional features
        self.compute_custom_features()

    
    def process_lightcurve(self) -> None:
        """Process the flux information with phot flags. Processing is done using the following steps:
        1. Remove saturations.
        Finally, all the time series data is modified to conform to the steps mentioned above.
        """

        # Remove saturations from the light curves
        saturation_mask =  (self.PHOTFLAG & 1024) == 0 

        # Alter time series data to remove saturations
        for time_series_feature in self.time_series_features:
            setattr(self, time_series_feature, getattr(self, time_series_feature)[saturation_mask])
        
    def compute_custom_features(self) -> None:

        source_coord = SkyCoord(ra = self.RA * u.deg, dec=self.DEC * u.deg)

        # Check if the object is close to the galactic plane of the milky way
        if abs(source_coord.galactic.b.degree) < self.b_threshold: 
            self.MW_plane_flag = 1
        else:
            self.MW_plane_flag = 0
        
        # Check if the object is in one of 4 LSST DDF's and set flags appropriately
        for key in self.LSST_DDF:

            # Separation from field center
            separation = source_coord.separation(self.LSST_DDF[key]).degree

            if separation < self.ddf_separation_radius_threshold:
                setattr(self, f'{key}_flag', 1)
            else:
                setattr(self, f'{key}_flag', 0)

        pass


    def plot_flux_curve(self) -> None:
        """Plot the SNANA calibrated flux vs time plot for all the data in the processed time series. All detections are marked with a star while non detections are marked with dots. Observations are color codded by their passband. This function is fundamentally a visualization tool and is not intended for making plots for papers.
        """

        # Colorize the data
        c = [self.colors[band] for band in self.BAND]
        patches = [mpatches.Patch(color=self.colors[band], label=band, linewidth=1) for band in self.colors]
        fmts = np.where((self.PHOTFLAG & 4096) != 0, '*', '.')

        # Plot flux time series
        for i in range(len(self.MJD)):
            plt.errorbar(x=self.MJD[i], y=self.FLUXCAL[i], yerr=self.FLUXCALERR[i], color=c[i], fmt=fmts[i], markersize = '10')

        # Labels
        plt.title(f"SNID: {self.SNID} | CLASS: {self.ELASTICC_class}")
        plt.xlabel('Time (MJD)')
        plt.ylabel('Calibrated Flux')
        plt.legend(handles=patches)

        plt.show()

    def get_classification_labels(self):
        """Get the classification labels (hierarchical) for this LSST Source object in the Taxonomy tree.

        Returns:
            (tree_nodes, numerical_labels): A tuple containing two list like objects. The first object contains the ordering of the nodes. The second list contains the labels themselves (0 when the object does not belong to the class and 1 when it does). The labels in the second object correspond to the nodes in the first object.
        """
        return get_classification_labels(self.astrophysical_class)
    
    def plot_classification_tree(self):
        """Plot the classification tree (based on our taxonomy) for this LSST Source object.
        """

        node, labels = self.get_classification_labels()
        plot_colored_tree(labels)


    def get_event_table(self):

        # Dataframe for time series data
        table = Table()

        # Find time since last observation
        time_since_first_obs = self.MJD - self.MJD[0]
        table['scaled_time_since_first_obs'] = time_since_first_obs / 100

        # 1 if it was a detection, zero otherwise
        table['detection_flag'] = np.where((self.PHOTFLAG & 4096 != 0), 1, 0)

        # Transform flux cal and flux cal err to more manageable values (more consistent order of magnitude)
        table['scaled_FLUXCAL'] = self.FLUXCAL / self.flux_scaling_const
        table['scaled_FLUXCALERR'] = self.FLUXCALERR / self.flux_scaling_const

        # One hot encoding for the pass band
        table['band_label'] = [self.pb_wavelengths[pb] for pb in self.BAND]

        # Consistency check
        assert len(table) == len(self.MJD), "Length of time series tensor does not match the number of mjd values."

        # Array for static features
        feature_static = OrderedDict()
        for other_feature in self.other_features:
            feature_static[other_feature] = getattr(self, other_feature)

        for feature in self.custom_engineered_features:
            feature_static[feature] = getattr(self, feature)

        # Array for computed static features
        table.meta = feature_static
        return table

    def __str__(self) -> str:

        to_return = str(vars(self))
        return to_return

In [3]:
class LSSTSourceDataSet():


    def __init__(self, path):
        """
        Arguments:
            path (string): Parquet file.
            transform (callable, optional): Optional transform to be applied on a sample.
        """

        print(f'Loading parquet dataset: {path}', flush=True)

        self.path = path
        self.parquet = pl.read_parquet(path)
        self.num_sample = self.parquet.shape[0]

        print(f"Number of sources: {self.num_sample}")

    def get_len(self):

        return self.num_sample

    def get_item(self, idx):
        
        row = self.parquet[idx]
        source = LSST_Source(row)
        table = source.get_event_table()

        astrophysical_class = get_astrophysical_class(source.ELASTICC_class)
        _, class_labels = get_classification_labels(astrophysical_class)
        class_labels = np.array(class_labels)
        snid = source.SNID

        return source, class_labels, snid

    def get_item_from_snid(self, snid):
        row = self.parquet.filter(pl.col('SNID') == snid)
        # row = self.pandas.loc[self.pandas['SNID'] == snid]
        # print(row)
        source = LSST_Source(row)
        table = source.get_event_table()

        astrophysical_class = get_astrophysical_class(source.ELASTICC_class)
        _, class_labels = get_classification_labels(astrophysical_class)
        class_labels = np.array(class_labels)
        snid = source.SNID

        return source, class_labels, snid
    
    def get_dimensions(self):

        idx = 0
        source, class_labels = self.get_item(idx)
        table = source.get_event_table()

        ts_np = table.to_pandas().to_numpy()
        static_np = np.array(list(table.meta.values()))

        dims = {
            'ts': ts_np.shape[1],
            'static': static_np.shape[0],
            'labels': class_labels.shape[0]
        }

        return dims
    
    def get_labels(self):

        ELASTICC_labels = self.parquet['ELASTICC_class']
        astrophysical_labels = []

        for idx in range(self.num_sample):

            elasticc_class = ELASTICC_labels[idx]
            astrophysical_class = get_astrophysical_class(elasticc_class)
            astrophysical_labels.append(astrophysical_class)
        
        return astrophysical_labels

In [4]:
!ls -alh pickles

total 5.9G
drwxrwx---  4 arjun15 arjun15  512 Mar 25 11:26 .
drwx------ 19 arjun15 arjun15 4.0K Apr  8 11:57 ..
-rw-rw----  1 arjun15 arjun15 2.1M Mar 25 11:26 a_labels.pkl
drwxrwx--- 14 arjun15 arjun15 4.0K Feb 16 14:57 augmented
-rw-rw----  1 arjun15 arjun15  14M Mar 25 11:26 e_label.pkl
drwxrwx---  2 arjun15 arjun15  512 Feb 14 22:16 .ipynb_checkpoints
-rw-rw----  1 arjun15 arjun15 2.3M Mar 25 11:26 lengths.pkl
-rw-rw----  1 arjun15 arjun15 1.8K Mar 20 09:41 phase.pkl
-rw-rw----  1 arjun15 arjun15 308M Mar 25 11:28 x_static.pkl
-rw-rw----  1 arjun15 arjun15 5.3G Mar 25 12:03 x_ts.pkl
-rw-rw----  1 arjun15 arjun15 246M Mar 25 11:28 y.pkl


In [5]:
def load(file_name):
    with open(file_name, 'rb') as f:
        return pickle.load(f)

In [None]:
from taxonomy import get_taxonomy_tree
class_count = 10
ts_length = 500
ts_flag_value = 0

model = keras.models.load_model(f"models/lsst_alpha_0.5/best_model.h5", compile=False)

tree = get_taxonomy_tree()
X_ts = load("pickles/x_ts.pkl")
X_static = load("pickles/x_static.pkl")
Y = load("pickles/y.pkl")
astrophysical_classes = load("pickles/a_labels.pkl")

In [None]:
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

def plot_confusion_matrix(y_true, y_pred, labels, title=None, img_file=None):
    
    n_class = len(labels)
    font = {'size'   : 25}
    plt.rc('font', **font)
    
    cm = np.round(confusion_matrix(y_true, y_pred, labels=labels, normalize='true'),2)
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=labels)
    disp.plot(cmap=plt.cm.Blues)
    disp.im_.colorbar.remove()
    
    fig = disp.figure_
    if n_class > 10:
        plt.xticks(rotation=90)
        plt.yticks(rotation=45)
    
    fig.set_figwidth(18)
    fig.set_figheight(18)
    
    for label in disp.text_.ravel():
        if n_class > 10:
            label.set_fontsize(12)
        elif n_class <= 10 and n_class > 3:
            disp.ax_.tick_params(axis='both', labelsize=40)
            label.set_fontsize('xx-large')
        else:
            disp.ax_.tick_params(axis='both', labelsize=40)
            label.set_fontsize('xx-large')
    
    if title:
        disp.ax_.set_xlabel("Predicted Label", fontsize=60)
        disp.ax_.set_ylabel("True Label", fontsize=60)
        disp.ax_.set_title(title, fontsize=45, wrap=True)
    
    plt.tight_layout()

    if img_file:
        plt.savefig(img_file)

    return cm

In [None]:
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

def plot_confusion_matrix_with_peaks(y_true, y_pred, phase, day, labels, title=None, img_file=None):
    n_class = len(labels)
    font = {'size'   : 35}
    plt.rc('font', **font)
    
    cm = np.round(confusion_matrix(y_true, y_pred, labels=labels, normalize='true'),2)
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=labels)
    disp.plot(cmap=plt.cm.Blues)
    disp.im_.colorbar.remove()
    
    fig = disp.figure_
    if n_class > 10:
        plt.xticks(rotation=90)
        plt.yticks(rotation=45)
    
    fig.set_figwidth(18)
    fig.set_figheight(18)
    
    for label in disp.text_.ravel():
        if n_class > 10:
            label.set_fontsize(12)
        elif n_class <= 10 and n_class > 3:
            disp.ax_.tick_params(axis='both', labelsize=40)
            label.set_fontsize('xx-large')
        else:
            disp.ax_.tick_params(axis='both', labelsize=40)
            label.set_fontsize('xx-large')
    
    if title:
        disp.ax_.set_xlabel("Predicted Label", fontsize=60)
        disp.ax_.set_ylabel("True Label", fontsize=60)
        disp.ax_.set_title(title, fontsize=45, wrap=True)

    handles = [plt.Line2D([0], [0], color='white', marker='o', linestyle='', label=f'{np.unique(astrophysical_classes)[i]} = {phase[day, i]}') for i in range(19)]
    plt.legend(handles=handles, title='Days Since Peak MJD', loc='upper left', bbox_to_anchor=(1.05, 1), fontsize='xx-small')

    plt.tight_layout()

    if img_file:
        plt.savefig(img_file)

    return cm

In [None]:
from taxonomy import source_node_label, get_taxonomy_tree
from interpret_results import save_leaf_cf_and_rocs

def determine_anomalies(class_probs, tree, purity_threshold=0.7):
    level_order_nodes = list(nx.bfs_tree(tree, source=source_node_label).nodes())
    leaf_nodes = level_order_nodes[-19:]
    
    indiv_probs = [np.squeeze(class_probs[index][0]) for index in range(len(class_probs))]
    true_classes = [class_probs[i][1] for i in range(len(class_probs))]

    true_class_ids = [np.argmax(true_classes[index][-19:]) for index in range(len(true_classes))]
    true_class_names = []
    for class_idx in tqdm(true_class_ids, desc='Calculating Labels: '):
        true_class_names.append(leaf_nodes[class_idx])

    pd.set_option('display.max_columns', None)
    pd.set_option('display.max_rows', None)
    
    class_df = pd.DataFrame(true_class_names, columns=['True Class:'])
    class_probs_df = pd.DataFrame(indiv_probs)
    df = pd.concat([class_df, class_probs_df], axis=1, ignore_index=True)
    columns = level_order_nodes.copy()
    columns.insert(0, 'True Class:')
    df.columns = columns
    display(df)

    results = df
    anomaly_labels = results[['KN', 'uLens', 'SLSN', 'PISN', 'TDE', 'CART', 'ILOT']]
    non_anomaly_labels = results[['Cepheid', 'RR Lyrae', 'Delta Scuti', 'EB', 'SNIa', 'SNIax', 'SNIb/c', 'SNI91bg', 'SNII', 'M-dwarf Flare', 'Dwarf Novae', 'AGN']]
    anomaly_preds = [np.sum(lightcurve[1][::]) for lightcurve in anomaly_labels.iterrows()]
    non_anomaly_preds = [np.sum(lightcurve[1][::]) for lightcurve in non_anomaly_labels.iterrows()]
    anomaly_detections = list(enumerate(zip(anomaly_preds, non_anomaly_preds)))

    leaf_labels = np.array(level_order_nodes)[-19:]
    pred_labels = [leaf_labels[i] for i in np.argmax(np.array(indiv_probs)[:, -19:], axis=1)]

    anomaly = pd.DataFrame({'Anomaly Probability': [anomaly_detections[row][1][0] for row in range(len(anomaly_detections))]})
    not_anomaly = pd.DataFrame({'Not Anomaly Probability': [anomaly_detections[row][1][1] for row in range(len(anomaly_detections))]})
    
    anomaly.reset_index(drop=True, inplace=True)
    not_anomaly.reset_index(drop=True, inplace=True)
    preds_df = pd.concat([class_df, anomaly, not_anomaly.set_axis(anomaly.index)], axis=1)
    display(preds_df)

    labels = [class_probs[i][1][-19:] for i in range(len(class_probs))]
    anomaly_map = [
        0.0, # AGN - not anomaly
        0.0, # SNIa - not anomaly
        0.0, # SNIb/c - not anomaly
        0.0, # SNIax - not anomaly
        1.0, # SNI91bg - anomaly
        0.0, # SNII - not anomaly
        1.0, # KN - anomaly
        0.0, # Dwarf Novae - not anomaly
        1.0, # uLens - anomaly
        0.0, # M-dwarf Flare - not anomaly
        1.0, # SLSN - anomaly
        1.0, # TDE - anomaly
        1.0, # ILOT - anomaly
        1.0, # CART - anomaly
        1.0, # PISN - anomaly
        0.0, # Cepheid - not anomaly
        0.0, # RR Lyrae - not anomaly
        0.0, # Delta Scuti - not anomaly
        0.0  # EB - not anomaly
    ]
    
    y_true = [[anomaly_map[np.argmax(label)], 1.0 - anomaly_map[np.argmax(label)]] for label in labels]
    anomaly = np.array([anomaly_detections[row][1][0] for row in range(len(anomaly_detections))])
    non_anomaly = np.array([anomaly_detections[row][1][1] for row in range(len(anomaly_detections))])
    y_pred = np.stack((anomaly, non_anomaly), axis=1)
    print(y_true)

    true = [np.argmax(y_true[i]) for i in range(len(y_true))]
    pred = [0 if element > purity_threshold else 1 for element in np.transpose(y_pred)[0]]

    return true, pred

In [None]:
def augment_ts_length_to_days_since_trigger(X_ts, X_static, Y, a_classes, days):

    # Augment the length of the ts data
    X_ts = get_ts_upto_days_since_trigger(X_ts, days=days)

    # Squeeze data into homogeneously shaped numpy arrays
    X_ts = np.squeeze(X_ts)
    X_static = np.squeeze(X_static)
    Y = np.squeeze(Y).astype(np.float32)
    astrophysical_classes = np.squeeze(a_classes)

    return X_ts, X_static, Y, astrophysical_classes

def get_ts_upto_days_since_trigger(X_ts, days, add_padding=True):

    augmented_list = []

    # Loop through all the data
    for ind in tqdm(range(len(X_ts)), desc ="TS Augmentation: "):

        times = X_ts[ind]['scaled_time_since_first_obs'].to_numpy()

        # Get the idx of the first detection
        first_detection_idx = np.where(X_ts[ind]['detection_flag'].to_numpy() == 1)[0][0]
        first_detection_t = times[first_detection_idx]

        if len(np.where((times - first_detection_t) * 100 <= days)[0]) == 0:
            augmented_list.append(np.zeros_like(X_ts[ind].to_numpy()))
        else:
            # Get the index of the the last observation between the mjd(first detection) and  mjd(first detection)
            last_observation_idx = np.where((times - first_detection_t) * 100 <= days)[0][-1]
            
            # Slice the data appropriately, Keep the first new_length number of observations and all columns
            augmented_list.append(X_ts[ind].to_numpy()[:(last_observation_idx + 1), :])

    # Optionally - Pad for TF masking layer
    if add_padding:
        augmented_list = pad_sequences(augmented_list, maxlen=ts_length,  dtype='float32', padding='post', value=ts_flag_value)

    return augmented_list


In [None]:
from dataloader import augment_ts_length_to_days_since_trigger, get_ts_upto_days_since_trigger, get_augmented_data
from interpret_results import save_all_cf_and_rocs, save_leaf_cf_and_rocs, get_conditional_probabilites
import networkx as nx
import numpy as np
import pandas as pd
import gc

days = 2 ** np.array(range(11))
default_batch_size = 1024

def save(save_path , obj):
    with open(save_path, 'wb') as f:
        pickle.dump(obj, f)

def run_day_wise_anomaly_detection_analysis(model, tree, model_dir, X_ts, X_static, Y, astrophysical_classes, purity_threshold=0.7, path="plots/daywise/"):
    all_predictions = []
    all_trues = []

    precision_values = []

    for i, d in enumerate(days):
        print(f'Running inference for trigger + {d} days...')

        gc.collect()

        # print('Augmenting Data...')
        # x1, x2, y_true, _ = augment_ts_length_to_days_since_trigger(X_ts, X_static, Y, astrophysical_classes, d)
        # save(f"pickles/augmented/day{d}/y.pkl", y_true)

        # x1 = load(f"pickles/augmented/day{d}/x1.pkl")[:2]
        # x2 = load(f"pickles/augmented/day{d}/x2.pkl")[:2]
        # y_true = load(f"pickles/augmented/day{d}/y.pkl")
        # print('Loaded!')

        if not os.path.exists(f'pickles/augmented/day{d}/pred.pkl'):
            print('Passing through model...')
            y_pred = model.predict([x1, x2], batch_size=default_batch_size)
            cm = save_leaf_cf_and_rocs(y_true, y_pred, tree, path, "test")
            save(f"pickles/augmented/day{d}/pred.pkl", y_pred)
            print('Saved to disk!')
        else:
            print(f'Loading y_pred values from disk for day {d}')
            y_pred = load(f'pickles/augmented/day{d}/pred.pkl')
            y_true = load(f'pickles/augmented/day{d}/y.pkl')
            cm = save_leaf_cf_and_rocs(y_true, y_pred, tree, path, "test")
            print('Loaded!')

        _, pseudo_conditional_probabilities = get_conditional_probabilites(y_pred, tree)
        probs_with_labels = list(zip(pseudo_conditional_probabilities, y_true))

        print('Determining anomalies...')
        true, pred = determine_anomalies(probs_with_labels, tree, purity_threshold)

        print(f'For trigger + {d} days, these are the statistics:')

        plot_title = f"Trigger + {d} days, {purity_threshold * 100}% confidence threshold"
        
        # Print all the stats and make plots...
        labels = ['Anomaly', 'Not Anomaly']
        for index, (label, prediction) in tqdm(enumerate(zip(true, pred))):
            true[index] = labels[label]
            pred[index] = labels[prediction]

        current_cm = plot_confusion_matrix(true, pred, labels, title=plot_title, img_file=os.path.join(path, f'AD_{d}.png'))
        # current_cm = plot_confusion_matrix_with_peaks(true, pred, phase, i, labels, title=plot_title, img_file=os.path.join(path, f'AD_{d}.png'))

        # calculates the precision at the specific day and adds to data set
        precision_values.append(current_cm[0][0] / (current_cm[0][0] + current_cm[1][0]))
        
    return precision_values

def make_gif(files, gif_file=None):

    # Load the images
    images = []
    for filename in files:
        images.append(imageio.imread(filename))

    # Create the figure and axes
    fig, ax = plt.subplots(figsize=(18, 18))

    # Create the animation
    def animate(i):
        ax.clear()
        ax.axis('off')
        ax.get_xaxis().set_ticks([])
        ax.get_yaxis().set_ticks([])
        ax.imshow(images[i])

    fig.tight_layout()

    anim = animation.FuncAnimation(fig, animate, frames=len(images), interval=500)

    if gif_file:
        # Save the animation as a GIF
        anim.save(gif_file)

In [None]:
model = keras.models.load_model(f"models/lsst_alpha_0.5/best_model.h5", compile=False)
tree = get_taxonomy_tree()

precision = run_day_wise_anomaly_detection_analysis(model, tree, None, X_ts, X_static, Y, astrophysical_classes, purity_threshold=0.7, path='plots/daywise/testing/')
print(precision)

Running inference for trigger + 1 days...
Loading y_pred values from disk for day 1


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


               precision    recall  f1-score   support

          AGN       0.43      0.37      0.40     76258
         CART       0.00      0.00      0.00      8207
      Cepheid       0.00      0.00      0.00     13771
  Delta Scuti       0.00      0.00      0.00     20650
  Dwarf Novae       0.00      0.00      0.00      8025
           EB       0.88      0.00      0.00     66454
         ILOT       0.00      0.00      0.00      7461
           KN       0.01      1.00      0.01      4426
M-dwarf Flare       0.00      0.00      0.00      1859
         PISN       0.00      0.00      0.00     63586
     RR Lyrae       0.00      0.00      0.00     14033
         SLSN       0.02      0.03      0.03     66088
      SNI91bg       0.00      0.00      0.00     28637
         SNII       0.00      0.00      0.00    301544
         SNIa       0.00      0.00      0.00    120739
        SNIax       0.00      0.00      0.00     28030
       SNIb/c       0.00      0.00      0.00    168254
         

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Loaded!
Determining anomalies...


Calculating Labels: 100%|██████████| 1081614/1081614 [00:00<00:00, 4911051.20it/s]


In [10]:
import networkx as nx
from taxonomy import get_taxonomy_tree, source_node_label

tree = get_taxonomy_tree()

level_order_nodes = list(nx.bfs_tree(tree, source=source_node_label).nodes())
print(level_order_nodes[-19:])

['AGN', 'SNIa', 'SNIb/c', 'SNIax', 'SNI91bg', 'SNII', 'KN', 'Dwarf Novae', 'uLens', 'M-dwarf Flare', 'SLSN', 'TDE', 'ILOT', 'CART', 'PISN', 'Cepheid', 'RR Lyrae', 'Delta Scuti', 'EB']
