In [1]:
import time
import pandas as pd
import numpy as np

from astropy.io import ascii
from astropy.coordinates import SkyCoord
from SPLASH.pipeline import Splash_Pipeline

# Load in data

In [2]:
# Load in 100 bts host grizy observations
bts_hosts = ascii.read('demo_bts_hosts.ecsv')
bts_df = bts_hosts.to_pandas()
bts_df.head()

Unnamed: 0,objID_3pi,raStack_3pi,decStack_3pi,primaryDetection_3pi,gKronMag_3pi,rKronMag_3pi,iKronMag_3pi,zKronMag_3pi,yKronMag_3pi,gKronMagErr_3pi,rKronMagErr_3pi,iKronMagErr_3pi,zKronMagErr_3pi,yKronMagErr_3pi,ps_score_3pi,SN_ra,SN_dec,sn_class,sn_redshift,ZTFID
0,126681757766820931,175.776621,15.567018,1,16.0256,15.2881,14.9533,14.7114,14.5005,0.002978,0.001497,0.001272,0.001197,0.002235,0.11103,175.776542,15.567139,,,
1,148231465149294617,146.51494,33.528361,1,18.2387,17.9746,17.875,17.8532,17.7821,0.004157,0.004871,0.004281,0.014407,0.019814,0.021125,146.514792,33.52825,SN II,0.038,ZTF18aacemcn
2,153731349546771375,134.954717,38.10897,1,18.7478,19.0585,18.6466,18.8301,18.6221,0.006392,0.009475,0.004919,0.017235,0.038667,0.017411,134.954667,38.109056,SN II,0.07247,ZTF18aacnlxz
3,169101206490177628,120.649031,50.922461,1,16.5325,15.6522,15.2551,15.0064,14.7582,0.002335,0.002128,0.000867,0.001623,0.003512,0.075905,120.648958,50.922528,SN Ia,0.05295,ZTF18aadlaxo
4,159661026075384557,102.607536,43.053266,1,,14.2309,13.166,12.9816,12.8047,,0.000836,0.000408,0.001003,0.001687,0.154649,102.607542,43.053222,SN IIn,0.01885,ZTF18aadmssd


In [3]:
def get_angular_separation(ra1, dec1, ra2, dec2, unit1='deg', unit2='deg'):
    """Function to calculate angular separation in arcseconds"""
    coord1 = SkyCoord(ra1, dec1, unit=unit1)
    coord2 = SkyCoord(ra2, dec2, unit=unit2)
    return coord1.separation(coord2).arcsec

def ab_mag_to_flux(AB_mag: np.ndarray) -> np.ndarray:
    """Convert AB magnitude to flux in mJy"""
    return 10**((AB_mag - 8.9) / -2.5) * 1000

ab_magerr_to_ferr = lambda sigma_m, f: np.abs(f * np.log(10) * (sigma_m / 2.5))  # transformation on the error of a magnitude turned into flux

# Add angular separation and grib data
bts_df['angular_separation_arcsec'] = get_angular_separation(bts_df['raStack_3pi'], bts_df['decStack_3pi'], bts_df['SN_ra'], bts_df['SN_dec'], unit1='deg', unit2='deg')
grizy = bts_df[['gKronMag_3pi', 'rKronMag_3pi', 'iKronMag_3pi', 'zKronMag_3pi', 'yKronMag_3pi']].to_numpy().astype(float)
grizy_err = bts_df[['gKronMagErr_3pi', 'rKronMagErr_3pi', 'iKronMagErr_3pi', 'zKronMagErr_3pi', 'yKronMagErr_3pi']].to_numpy().astype(float)
angular_seps = bts_df['angular_separation_arcsec'].to_numpy().astype(float)

# Convert the grizy data to mJy
grizy = ab_mag_to_flux(grizy)
grizy_err = ab_magerr_to_ferr(grizy_err, grizy)

grizy, angular_seps

(array([[1.41175716e+00, 2.78458193e+00, 3.79035598e+00, 4.73630869e+00,
         5.75174998e+00],
        [1.83873863e-01, 2.34509262e-01, 2.57039578e-01, 2.62252719e-01,
         2.80001269e-01],
        [1.15048245e-01, 8.64171621e-02, 1.26287394e-01, 1.06649789e-01,
         1.29169506e-01],
        [8.85115610e-01, 1.99122346e+00, 2.87051619e+00, 3.60944144e+00,
         4.53649044e+00],
        [           nan, 7.37292812e+00, 1.96607463e+01, 2.33002188e+01,
         2.74233180e+01],
        [4.48993444e-01, 5.91725111e-01, 6.50309353e-01, 6.77204757e-01,
         7.15418226e-01],
        [2.98483274e-01, 6.53792599e-01, 1.09415790e+00, 1.42902557e+00,
         1.69855651e+00],
        [6.01893981e-01, 1.23720025e+00, 1.92415476e+00, 2.25029782e+00,
         2.59680895e+00],
        [1.73988282e+00,            nan,            nan,            nan,
                    nan],
        [3.22760218e+00, 5.96815371e+00, 8.96107269e+00, 1.00110585e+01,
         1.09496440e+01],
        [4

In [4]:
nan_mask = ~(np.sum(np.isnan(grizy), axis=1)==0)

# Use the SPLASH pipeline

In [5]:
# Load pipeline object
pipeline = Splash_Pipeline(pipeline_version='weighted_full_band',   # the default version of the pipeline
                           pre_transformed=False,                   # whether the given data is pre-logged and nnormalized
                           within_4sigma=True,                      # whether we only want to classify objects with properties within 4-sigma of the training set
                           nan_thresh_ratio=1.0)                    # to keep this notebook concise, we are allowing the pipeline to imput any num of nans

https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations


In [6]:
# Predict the classes. n_resamples is the number of boostraps for getting the median predicted host properties.
start_time = time.time()
classes = pipeline.predict_classes(grizy, angular_seps, grizy_err, n_resamples=50)
duration = time.time() - start_time
print(f'{grizy.shape[0]} classifications produced in {duration} seconds (~{duration / grizy.shape[0]} seconds per classificaiton).')

100 classifications produced in 0.11946892738342285 seconds (~0.0011946892738342286 seconds per classificaiton).


In [7]:
print(f'Number of each class:\n{pd.Series(classes).value_counts()}\n\nThe class labels are 0=Ia 1=Ib/c 2=SLSN 3=IIn 4=II (P/L) -1=Outside train properties 4 sigma.')

Number of each class:
0    100
Name: count, dtype: int64

The class labels are 0=Ia 1=Ib/c 2=SLSN 3=IIn 4=II (P/L) -1=Outside train properties 4 sigma.


In [8]:
# We can also take a look at the predicted host properties of each supernova
props, props_err = pipeline.predict_host_properties(grizy, grizy_err, n_resamples=50, return_normalized=False)  # return_normalized will return the properties normalized with the train mean and std
props, props_err  # in order (log(mass), log(sfr), redshift)

(array([[ 1.10310050e+01, -3.75957005e-01,  9.51032211e-02],
        [ 9.56965151e+00, -7.51535242e-01,  1.00051838e-01],
        [ 8.95170806e+00, -8.72420166e-01,  8.89827766e-02],
        [ 1.11151520e+01, -5.22053711e-01,  1.16926910e-01],
        [ 1.12975540e+01, -3.57286946e-01,  3.38397102e-02],
        [ 1.02009132e+01,  7.46003631e-02,  1.18718800e-01],
        [ 1.08498859e+01, -1.34875948e+00,  1.52605821e-01],
        [ 1.10526220e+01, -1.17702138e+00,  1.45487447e-01],
        [ 1.09303317e+01, -3.37367030e-01,  8.58731277e-02],
        [ 1.11726915e+01, -5.51517356e-01,  7.42392363e-02],
        [ 1.09473116e+01, -1.03067116e+00,  1.44434641e-01],
        [ 1.10929133e+01, -3.47182047e-01,  7.25027968e-02],
        [ 1.08686682e+01, -1.25816889e+00,  1.61185909e-01],
        [ 1.02663476e+01,  5.42050084e-02,  1.69353483e-01],
        [ 1.11624087e+01, -2.82333086e-01,  4.44655782e-02],
        [ 1.10840349e+01, -9.81644131e-01,  1.44244492e-01],
        [ 1.11190457e+01

# OVR Classification

In [9]:
ovr_class_probs = pipeline.predict_probs(grizy, angular_seps, grizy_err, n_resamples=50, ovr=True)
ovr_class_probs

{'Ia': array([0.758191  , 0.76777103, 0.70059497, 0.74500489, 0.59772951,
        0.644007  , 0.75493848, 0.65377798, 0.75587078, 0.73481009,
        0.74691726, 0.75306668, 0.74404515, 0.77135018, 0.64689373,
        0.73530142, 0.74163545, 0.84916874, 0.76095538, 0.78949962,
        0.76822055, 0.73184991, 0.78442419, 0.75540493, 0.78356971,
        0.66003255, 0.758191  , 0.68719983, 0.7595759 , 0.77489034,
        0.75400378, 0.60257279, 0.74500489, 0.81697577, 0.66454669,
        0.71571327, 0.7084723 , 0.76003633, 0.77045905, 0.75726472,
        0.81960167, 0.73481009, 0.62883065, 0.79776273, 0.74548388,
        0.81202061, 0.8429259 , 0.66398404, 0.79407463, 0.69634427,
        0.76911778, 0.77664578, 0.79653887, 0.77045905, 0.74882013,
        0.73774931, 0.71313968, 0.76049616, 0.75165654, 0.74929436,
        0.69259715, 0.69474152, 0.80538613, 0.78949962, 0.68284505,
        0.7655143 , 0.83716345, 0.77533012, 0.70481139, 0.75353553,
        0.76911778, 0.78228341, 0.73036199