In [1]:
import sys
import os
# Put ../src (resolved against the current working directory) at the front of the
# import path before the project imports below run.
# NOTE(review): presumably this is where oracle_resspect_classifier lives; the
# notebook must be launched from its own directory for the relative path to resolve.
sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), '../src')))

import resspect
from resspect.tom_client import TomClient
from oracle_resspect_classifier.elasticc2_oracle_feature_extractor import ELAsTiCC2_ORACLEFeatureExtractor

  from .autonotebook import tqdm as notebook_tqdm


Using cuda device


In [None]:
import itertools

# Grab a handful of objects from the TOM for testing. Not the most efficient way to do
# this, but a good check for whether the feature extractor is doing its job.
username = 'arjun15'
passwordfile = '../../oracle/passwordfile'

detected_in_last_days = 1
mjd_now = 60800
num_objects = 5


def _json_or_raise(response, what, require_ok=False):
    """Decode a TOM response, failing loudly instead of returning an empty dict.

    Parameters
    ----------
    response
        Object returned by ``tom.post``; only ``status_code`` and ``json()`` are used.
    what : str
        Short label for the request, used in error messages.
    require_ok : bool
        When True, additionally require the decoded payload to carry
        ``status == 'ok'`` (the field the SQL endpoint is checked for below).

    Returns
    -------
    The decoded JSON payload.

    Raises
    ------
    RuntimeError
        If the HTTP status is not 200, or ``require_ok`` is set and the payload
        status is not ``'ok'``.
    """
    if response.status_code != 200:
        raise RuntimeError(
            f'Failed to retrieve data from TOM! ({what}: HTTP status {response.status_code})'
        )
    payload = response.json()
    if require_ok and payload.get('status') != 'ok':
        raise RuntimeError(
            f"Failed to retrieve data from TOM! ({what}: status={payload.get('status')!r})"
        )
    return payload


tom = TomClient(url="https://desc-tom-2.lbl.gov", username=username, passwordfile=passwordfile)

dic = {
    'detected_in_last_days': detected_in_last_days,
    'mjd_now': mjd_now
}

res = tom.post('elasticc2/gethottransients', json=dic)
data = _json_or_raise(res, 'hot transients')
print('=> Fetched hot transients')

# Slicing already copes with lists shorter than num_objects.
ids = [obj['objectid'] for obj in data['diaobject']][:num_objects]
if not ids:
    # An empty list would otherwise be sent to the database as the invalid clause "IN ()".
    raise RuntimeError('TOM returned no hot transients for the requested window; nothing to query.')

# Comma-separated id list shared by both SQL queries below.
id_list_sql = ', '.join(str(object_id) for object_id in ids)

# using these object ids, load in static and time series data for ORACLE
static = tom.post('db/runsqlquery/',
                json={'query': '''SELECT diaobject_id, ra, decl, mwebv, mwebv_err, z_final, z_final_err, hostgal_zphot, hostgal_zphot_err,
                hostgal_zspec, hostgal_zspec_err, hostgal_ra, hostgal_dec, hostgal_snsep, hostgal_ellipticity, hostgal_mag_u,
                hostgal_mag_g, hostgal_mag_r, hostgal_mag_i, hostgal_mag_z, hostgal_mag_y FROM elasticc2_ppdbdiaobject WHERE diaobject_id IN (%s) ORDER BY diaobject_id;''' % id_list_sql,
                    'subdict': {}})
static_data = _json_or_raise(static, 'static data', require_ok=True)
print('=> Loaded static data...')

ts = tom.post('db/runsqlquery/',
                json={'query': 'SELECT diaobject_id, midpointtai, filtername, psflux, psfluxerr FROM elasticc2_ppdbdiaforcedsource WHERE diaobject_id IN (%s) ORDER BY diaobject_id;' % id_list_sql,
                    'subdict': {}})
ts_data = _json_or_raise(ts, 'time-series data', require_ok=True)
print('=> Loaded time-series data...') 

# itertools.groupby only merges adjacent rows, so make sure rows are ordered by object id
# before grouping them into one list of observations per object.
ts_data['rows'].sort(key=lambda obs: obs['diaobject_id'])
grouped_ts_data = {snid: list(obj) for snid, obj in itertools.groupby(ts_data['rows'], key=lambda obs: obs['diaobject_id'])}

# for each object, sort all observations by MJD
for observation in grouped_ts_data.values():
    observation.sort(key=lambda obs: obs['midpointtai'])

print(ts_data)
print(static_data)

In [2]:
import polars as pl

# Location of the ELAsTiCC2 training parquet files on the NERSC filesystem.
nersc_parquet_files = '/global/cfs/cdirs/desc-td/ELASTICC2_TRAIN02_parquet'

# Load one file from that directory to use as example input for the feature extractor.
parquet_example = os.path.join(nersc_parquet_files, 'SNIa-SALT3.parquet')
parquet = pl.read_parquet(source=parquet_example)

In [3]:
# Preview the first five rows of the loaded table.
parquet.head(5)

SNID,MJD,BAND,PHOTFLAG,PHOTPROB,FLUXCAL,FLUXCALERR,PSF_SIG1,SKY_SIG,RDNOISE,ZEROPT,ZEROPT_ERR,GAIN,SIM_MAGOBS,RA,DEC,SNTYPE,NOBS,PTROBS_MIN,PTROBS_MAX,MWEBV,MWEBV_ERR,REDSHIFT_HELIO,REDSHIFT_HELIO_ERR,REDSHIFT_FINAL,REDSHIFT_FINAL_ERR,VPEC,VPEC_ERR,HOSTGAL_NMATCH,HOSTGAL_NMATCH2,HOSTGAL_OBJID,HOSTGAL_FLAG,HOSTGAL_PHOTOZ,HOSTGAL_PHOTOZ_ERR,HOSTGAL_SPECZ,HOSTGAL_SPECZ_ERR,HOSTGAL_RA,…,SIM_HOSTLIB(LOG_SFR),SIM_DLMU,SIM_LENSDMU,SIM_RA,SIM_DEC,SIM_MWEBV,SIM_PEAKMJD,SIM_MJD_EXPLODE,SIM_MAGSMEAR_COH,SIM_AV,SIM_RV,SIM_SALT2x0,SIM_SALT2x1,SIM_SALT2c,SIM_SALT2mB,SIM_SALT2alpha,SIM_SALT2beta,SIM_SALT2gammaDM,SIM_PEAKMAG_u,SIM_PEAKMAG_g,SIM_PEAKMAG_r,SIM_PEAKMAG_i,SIM_PEAKMAG_z,SIM_PEAKMAG_Y,SIM_EXPOSURE_u,SIM_EXPOSURE_g,SIM_EXPOSURE_r,SIM_EXPOSURE_i,SIM_EXPOSURE_z,SIM_EXPOSURE_Y,SIM_GALFRAC_u,SIM_GALFRAC_g,SIM_GALFRAC_r,SIM_GALFRAC_i,SIM_GALFRAC_z,SIM_GALFRAC_Y,SIM_SUBSAMPLE_INDEX
i64,list[f64],list[str],list[i32],list[f32],list[f32],list[f32],list[f32],list[f32],list[f32],list[f32],list[f32],list[f32],list[f32],f64,f64,i32,i32,i32,i32,f32,f32,f32,f32,f32,f32,f32,f32,i16,i16,i64,i16,f32,f32,f32,f32,f64,…,f32,f32,f32,f64,f64,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,i16
73,"[60423.4061, 60423.4145, … 60989.0133]","[""Y"", ""Y"", … ""z""]","[0, 0, … 0]","[-9.0, -9.0, … -9.0]","[2.781537, -42.948391, … -5.324702]","[41.806305, 39.684013, … 17.21748]","[2.69, 2.41, … 1.86]","[46.860001, 50.029999, … 62.650002]","[0.25, 0.25, … 0.25]","[30.09, 30.1, … 30.99]","[0.005, 0.005, … 0.005]","[1.0, 1.0, … 1.0]","[99.0, 99.0, … 29.884895]",304.243393,-4.291256,110,228,482437,482664,0.130539,0.006527,0.655564,0.65532,0.654122,0.65532,0.0,300.0,1,1,10500430702,0,0.655564,0.65532,-9.0,-9.0,304.243466,…,0.5699,42.48978,0.002786,304.243378,-4.291256,0.126506,60523.589844,-9.0,-0.073794,-9.0,-9.0,1.1e-05,-0.55056,-0.060676,23.013763,0.14,3.1,0.0,28.512768,24.174513,22.983221,22.947306,22.942657,23.064501,1.0,1.0,1.0,1.0,1.0,1.0,66.754341,3.128764,2.976415,5.285044,7.844293,10.192559,-9
352,"[60484.3786, 60484.3831, … 61039.0972]","[""Y"", ""Y"", … ""z""]","[0, 0, … 0]","[-9.0, -9.0, … -9.0]","[70.780701, -60.739841, … 1.772054]","[47.873493, 39.694359, … 12.591577]","[2.94, 2.48, … 1.47]","[48.77, 48.439999, … 54.07]","[0.25, 0.25, … 0.25]","[30.09, 30.110001, … 31.0]","[0.005, 0.005, … 0.005]","[1.0, 1.0, … 1.0]","[99.0, 99.0, … 31.830404]",12.713786,-39.216479,110,230,401267,401496,0.008067,0.000403,0.434414,0.42573,0.433314,0.42573,0.0,300.0,1,1,10562646832,0,0.434414,0.42573,-9.0,-9.0,12.713772,…,-2.0552,41.190102,-0.003643,12.713786,-39.21648,0.007851,60588.613281,-9.0,-0.108876,-9.0,-9.0,1.5e-05,-2.219902,0.181567,22.698746,0.14,3.1,0.0,26.361338,23.16703,22.45286,22.381245,22.364374,22.731253,1.0,1.0,1.0,1.0,1.0,1.0,6.782129,1.757226,3.554987,5.989055,7.657897,12.969928,-9
406,"[60299.138, 60303.1139, … 60717.048]","[""Y"", ""Y"", … ""z""]","[0, 0, … 0]","[-9.0, -9.0, … -9.0]","[104.413422, 77.31572, … -10.37617]","[21.652668, 23.016081, … 13.606386]","[1.44, 1.47, … 1.51]","[43.52, 46.77, … 59.810001]","[0.25, 0.25, … 0.25]","[30.120001, 30.139999, … 31.01]","[0.005, 0.005, … 0.005]","[1.0, 1.0, … 1.0]","[22.978334, 22.877457, … 28.925873]",40.319748,-33.143045,110,114,123550,123663,0.019233,0.000962,0.54048,0.53686,0.539642,0.53686,0.0,300.0,1,1,6125373823,0,0.54048,0.53686,-9.0,-9.0,40.319858,…,0.5091,42.473682,-0.01015,40.319748,-33.143044,0.019505,60307.359375,-9.0,-0.002364,-9.0,-9.0,1.2e-05,1.100645,-0.014268,22.910362,0.14,3.1,0.0,28.332739,23.913773,22.686335,22.658911,22.644289,22.828428,1.0,1.0,1.0,1.0,1.0,1.0,180.994919,5.168973,4.178105,6.507644,8.459147,11.505424,-9
695,"[60721.3707, 60769.3947, … 61252.1304]","[""Y"", ""Y"", … ""z""]","[0, 0, … 0]","[-9.0, -9.0, … -9.0]","[-22.408533, 7.243584, … -31.153187]","[36.396923, 23.026184, … 21.761795]","[2.35, 1.5, … 2.21]","[47.07, 42.310001, … 64.75]","[0.25, 0.25, … 0.25]","[30.129999, 30.1, … 30.969999]","[0.005, 0.005, … 0.005]","[1.0, 1.0, … 1.0]","[99.0, 99.0, … 30.792587]",228.555405,-28.468939,10,101,251751,251851,0.202867,0.010143,0.299424,0.001,0.300201,0.001,0.0,300.0,1,1,10437584688,0,0.210101,0.22199,0.299424,0.001,228.555363,…,0.2999,40.959812,0.001004,228.555405,-28.468939,0.219303,60852.234375,-9.0,-0.089124,-9.0,-9.0,2.3e-05,-1.414116,0.141963,22.232874,0.14,3.1,0.0,26.395741,23.286243,22.595175,22.406416,22.33827,22.800871,1.0,1.0,1.0,1.0,1.0,1.0,40.124321,4.346725,5.491041,6.867548,8.962564,17.037214,-9
871,"[60948.3459, 60951.3069, … 61106.1365]","[""Y"", ""Y"", … ""z""]","[0, 0, … 0]","[-9.0, -9.0, … -9.0]","[41.522675, 0.242583, … 26.881418]","[18.41955, 21.900927, … 21.948093]","[1.39, 1.59, … 2.5]","[40.700001, 42.630001, … 58.279999]","[0.25, 0.25, … 0.25]","[30.120001, 30.120001, … 30.940001]","[0.005, 0.005, … 0.005]","[1.0, 1.0, … 1.0]","[99.0, 99.0, … 24.059935]",75.94169,-49.172649,110,146,468016,468161,0.009739,0.000487,0.784721,0.78404,0.784873,0.78404,0.0,300.0,1,1,8876168928,0,0.784721,0.78404,-9.0,-9.0,75.941553,…,0.5071,43.506626,-0.030031,75.941689,-49.172649,0.009449,61079.53125,-9.0,-0.052141,-9.0,-9.0,6e-06,0.625115,-0.089637,23.776234,0.14,3.1,0.0,28.718786,25.386219,23.753199,23.154127,23.349894,23.381153,1.0,1.0,1.0,1.0,1.0,1.0,155.968948,7.543713,2.460058,2.868721,4.2371,4.980576,-9


In [4]:
# example_input = parquet[0:5]
# Use only row 0 of the table as the example input; the commented-out line above
# selects the first five rows instead.
example_input = parquet[0]
# Feature names reported by the extractor. They are iterated as column names in
# get_phot_from_parquet and later passed to fit() as `additional_info`.
additional_features = ELAsTiCC2_ORACLEFeatureExtractor._get_static_features()

def get_phot_from_parquet(parquet_rows, feature_names=None):
    """Convert rows of a training parquet table into per-object dictionaries.

    Parameters
    ----------
    parquet_rows
        Table-like object exposing ``iter_rows(named=True)``, which must yield one
        mapping per object containing at least the columns ``SNID``, ``SNTYPE``,
        ``REDSHIFT_FINAL``, ``RA``, ``DEC``, ``BAND``, ``MJD``, ``FLUXCAL``,
        ``FLUXCALERR``, ``PHOTFLAG`` and every name in ``feature_names``.
    feature_names : iterable of str, optional
        Columns copied verbatim to the top level of each output dictionary.
        Defaults to the notebook-level ``additional_features`` so existing calls
        keep working unchanged.

    Returns
    -------
    list of dict
        One dictionary per input row with the keys ``objectid``, ``sncode``,
        ``redshift``, ``RA``, ``DEC``, ``photometry`` (per-epoch columns),
        ``additional_info`` (left empty) and one key per requested feature.

    Raises
    ------
    KeyError
        If a row lacks one of the required columns.
    """
    # Materialise once so a one-shot iterable is not exhausted after the first row.
    feature_names = tuple(additional_features if feature_names is None else feature_names)

    # Per-epoch columns, in the order they are stored under 'photometry'.
    photometry_columns = ('BAND', 'MJD', 'FLUXCAL', 'FLUXCALERR', 'PHOTFLAG')

    data = []
    for obj in parquet_rows.iter_rows(named=True):
        phot_d = {
            'objectid': int(obj['SNID']),
            # The SNTYPE value is used directly as the code; no class-name lookup is applied.
            'sncode': obj['SNTYPE'],
            'redshift': obj['REDSHIFT_FINAL'],
            'RA': obj['RA'],
            'DEC': obj['DEC'],
            'photometry': {column: obj[column] for column in photometry_columns},
            # NOTE(review): intentionally left empty; the extra features are stored at the
            # top level below. Confirm that is what the downstream fit() expects.
            'additional_info': {},
        }

        for feature in feature_names:
            phot_d[feature] = obj[feature]

        data.append(phot_d)

    return data

# Build the list of per-object dictionaries that is handed to fit() later on.
data_dic = get_phot_from_parquet(parquet_rows=example_input)

In [5]:
# testing out the feature extractor
from resspect.fit_lightcurves import fit, fit_TOM

# Dotted path of the feature extractor class, passed to fit() as a string.
feature_extraction_method = 'oracle_resspect_classifier.elasticc2_oracle_feature_extractor.ELAsTiCC2_ORACLEFeatureExtractor'

# Options for the fit call, collected in one place so they are easy to tweak.
fit_options = dict(
    output_features_file='../intermediate_TOM_training_features.parquet',
    feature_extractor=feature_extraction_method,
    filters='LSST',
    additional_info=additional_features,
    # one_code=gentypes
)
fit(data_dic, **fit_options)

INFO:root:Starting oracle_resspect_classifier.elasticc2_oracle_feature_extractor.ELAsTiCC2_ORACLEFeatureExtractor fit...
INFO:root:Features have been saved to: ../intermediate_TOM_training_features.parquet


In [6]:
# viewing the output
import numpy as np
import pandas as pd

pd.set_option('display.max_columns', 40)

# Read back the features written by the fit step.
data = pd.read_parquet('../intermediate_TOM_training_features.parquet')
# data = pd.read_csv('../TOM_training_features.csv', index_col=0)

# Length of the MJD entry stored under index label 0.
first_mjd_length = len(data['MJD'][0])
print(first_mjd_length)

# Rows whose sncode is 10 or 110 are labelled 'Ia'; every other row is 'non-Ia'.
is_ia = data['sncode'].isin([10, 110])
data['orig_sample'] = 'train'
data['type'] = np.where(is_ia, 'Ia', 'non-Ia')

data.to_parquet('../final_TOM_training_features.parquet', index=False)

data

228


Unnamed: 0,diaobject_id,redshift,type,sncode,MJD,FLUXCAL,FLUXCALERR,BAND,PHOTFLAG,RA,DEC,MWEBV,MWEBV_ERR,REDSHIFT_FINAL,REDSHIFT_FINAL_ERR,HOSTGAL_PHOTOZ,HOSTGAL_PHOTOZ_ERR,HOSTGAL_SPECZ,HOSTGAL_SPECZ_ERR,HOSTGAL_RA,HOSTGAL_DEC,HOSTGAL_SNSEP,HOSTGAL_ELLIPTICITY,HOSTGAL_MAG_u,HOSTGAL_MAG_g,HOSTGAL_MAG_r,HOSTGAL_MAG_i,HOSTGAL_MAG_z,HOSTGAL_MAG_Y,orig_sample
0,73,0.654122,Ia,110,"[60423.4061, 60423.4145, 60428.4206, 60448.362...","[2.7815372943878174, -42.94839096069336, -14.2...","[41.806304931640625, 39.68401336669922, 36.072...","[Y, Y, Y, Y, Y, Y, Y, Y, Y, Y, Y, Y, Y, Y, Y, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",304.243393,-4.291256,0.130539,0.006527,0.654122,0.65532,0.655564,0.65532,-9.0,-9.0,304.243466,-4.291261,0.262877,0.3411,23.957273,22.941786,21.804691,21.145388,20.711977,20.549498,train


In [None]:
# basic classifier testing before trying the full active learning loop
from oracle_resspect_classifier.oracle_classifier import OracleClassifier

# NOTE(review): weights_dir is a user-specific absolute path; adjust it when running elsewhere.
classifier_test = OracleClassifier(
    dir='../',
    weights_dir='/pscratch/sd/a/arjun15/',
)

# Run the classifier on the feature table; the returned value is shown as the cell output.
classifier_test.predict(data)

In [None]:
# lightcurve plots 
from oracle.pretrained.ELAsTiCC import time_dependent_feature_list
import torch
import matplotlib.pyplot as plt

# One-row frame holding the time-dependent columns of the first row of `data`.
x_ts = pd.DataFrame(data[time_dependent_feature_list].iloc[0]).T

img = ELAsTiCC2_ORACLEFeatureExtractor._plot_sample_lc(x_ts)
if isinstance(img, torch.Tensor):
    # Tensor.numpy() raises for tensors that live on a GPU or require grad, so detach and
    # move to host memory first; both calls are no-ops for a plain CPU tensor.
    # NOTE(review): permute(1, 2, 0) moves the first axis last - presumably the tensor is
    # channel-first (C, H, W) and imshow needs (H, W, C); confirm against _plot_sample_lc.
    img_np = img.detach().cpu().permute(1, 2, 0).numpy().astype(int)
else:
    img_np = np.asarray(img)

fig, ax = plt.subplots(figsize=(4, 4))
ax.imshow(img_np)
ax.axis('off')
plt.show()