In [1]:
import sys
import os
sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), '../src')))

import resspect
from resspect.tom_client import TomClient
from oracle_resspect_classifier.elasticc2_oracle_feature_extractor import ELAsTiCC2_ORACLEFeatureExtractor

  from .autonotebook import tqdm as notebook_tqdm


Using cuda device


In [None]:
import itertools


def _json_or_raise(response, what, require_ok_status=False):
    """Decode a TOM response, failing loudly instead of continuing with empty data.

    Parameters
    ----------
    response
        Object returned by ``tom.post``; only ``status_code`` and ``json()`` are used.
    what : str
        Short description of the request, used in the error message.
    require_ok_status : bool
        When True, also require the decoded payload to be a dict with ``status == 'ok'``.

    Returns
    -------
    The decoded JSON payload.

    Raises
    ------
    RuntimeError
        If the HTTP status is not 200, or (when requested) the payload status is not 'ok'.
    """
    if response.status_code != 200:
        raise RuntimeError(f'Failed to retrieve {what} from TOM! (HTTP status {response.status_code})')
    payload = response.json()
    if require_ok_status and (not isinstance(payload, dict) or payload.get('status') != 'ok'):
        raise RuntimeError(f'Failed to retrieve {what} from TOM! Response: {payload}')
    return payload


# code to grab objects from TOM for testing - not the most optimal way to do this but a good check for whether the feature extractor is doing its job
username = 'arjun15'
passwordfile = '../../oracle/passwordfile'

detected_in_last_days = 1
mjd_now = 60800
num_objects = 5


tom = TomClient(url="https://desc-tom-2.lbl.gov", username=username, passwordfile=passwordfile)

dic = {
    'detected_in_last_days': detected_in_last_days,
    'mjd_now': mjd_now
}

res = tom.post('elasticc2/gethottransients', json=dic)
data = _json_or_raise(res, 'hot transients')
if 'diaobject' not in data:
    raise RuntimeError(f'Unexpected response from elasticc2/gethottransients (no "diaobject" key): {data}')
print('=> Fetched hot transients')

# keep at most num_objects ids (slicing already copes with shorter lists)
ids = [obj['objectid'] for obj in data['diaobject']][:num_objects]
if not ids:
    # an empty list would render as "IN ()", which is not valid SQL
    raise RuntimeError('TOM returned no hot transients for the requested window; nothing to query.')

# comma-separated id list shared by both SQL queries below
id_list = ', '.join(str(object_id) for object_id in ids)

# using these object ids, load in static and time series data for ORACLE
static_query = '''SELECT diaobject_id, ra, decl, mwebv, mwebv_err, z_final, z_final_err, hostgal_zphot, hostgal_zphot_err,
                hostgal_zspec, hostgal_zspec_err, hostgal_ra, hostgal_dec, hostgal_snsep, hostgal_ellipticity, hostgal_mag_u,
                hostgal_mag_g, hostgal_mag_r, hostgal_mag_i, hostgal_mag_z, hostgal_mag_y FROM elasticc2_ppdbdiaobject WHERE diaobject_id IN (%s) ORDER BY diaobject_id;''' % id_list
static = tom.post('db/runsqlquery/', json={'query': static_query, 'subdict': {}})
static_data = _json_or_raise(static, 'static data', require_ok_status=True)
print('=> Loaded static data...')

ts_query = 'SELECT diaobject_id, midpointtai, filtername, psflux, psfluxerr FROM elasticc2_ppdbdiaforcedsource WHERE diaobject_id IN (%s) ORDER BY diaobject_id;' % id_list
ts = tom.post('db/runsqlquery/', json={'query': ts_query, 'subdict': {}})
ts_data = _json_or_raise(ts, 'time-series data', require_ok_status=True)
print('=> Loaded time-series data...')

# group the observations per object; itertools.groupby only merges adjacent items,
# so the rows are sorted by the grouping key first
ts_data['rows'].sort(key=lambda obs: obs['diaobject_id'])
grouped_ts_data = {snid: list(obj) for snid, obj in itertools.groupby(ts_data['rows'], key=lambda obs: obs['diaobject_id'])}

# for each object, sort all observations by MJD
for observation in grouped_ts_data.values():
    observation.sort(key=lambda obs: obs['midpointtai'])

print(ts_data)
print(static_data)

In [2]:
import polars as pl

# Directory holding the training parquet files. This is an absolute path, so the cell
# only works on a machine where that directory exists (presumably NERSC, per the name).
nersc_parquet_files = '/global/cfs/cdirs/desc-td/ELASTICC2_TRAIN02_parquet'
# A single file from that directory is used as the example input.
parquet_example = os.path.join(nersc_parquet_files, 'SNIa-SALT3.parquet')
# Read the whole file into memory with polars.
parquet = pl.read_parquet(parquet_example)

In [3]:
# Show the column names of the loaded table; get_phot_from_parquet below reads rows by these names.
parquet.columns

['SNID',
 'MJD',
 'BAND',
 'PHOTFLAG',
 'PHOTPROB',
 'FLUXCAL',
 'FLUXCALERR',
 'PSF_SIG1',
 'SKY_SIG',
 'RDNOISE',
 'ZEROPT',
 'ZEROPT_ERR',
 'GAIN',
 'SIM_MAGOBS',
 'RA',
 'DEC',
 'SNTYPE',
 'NOBS',
 'PTROBS_MIN',
 'PTROBS_MAX',
 'MWEBV',
 'MWEBV_ERR',
 'REDSHIFT_HELIO',
 'REDSHIFT_HELIO_ERR',
 'REDSHIFT_FINAL',
 'REDSHIFT_FINAL_ERR',
 'VPEC',
 'VPEC_ERR',
 'HOSTGAL_NMATCH',
 'HOSTGAL_NMATCH2',
 'HOSTGAL_OBJID',
 'HOSTGAL_FLAG',
 'HOSTGAL_PHOTOZ',
 'HOSTGAL_PHOTOZ_ERR',
 'HOSTGAL_SPECZ',
 'HOSTGAL_SPECZ_ERR',
 'HOSTGAL_RA',
 'HOSTGAL_DEC',
 'HOSTGAL_SNSEP',
 'HOSTGAL_DDLR',
 'HOSTGAL_CONFUSION',
 'HOSTGAL_LOGMASS',
 'HOSTGAL_LOGMASS_ERR',
 'HOSTGAL_LOGSFR',
 'HOSTGAL_LOGSFR_ERR',
 'HOSTGAL_LOGsSFR',
 'HOSTGAL_LOGsSFR_ERR',
 'HOSTGAL_COLOR',
 'HOSTGAL_COLOR_ERR',
 'HOSTGAL_ELLIPTICITY',
 'HOSTGAL_OBJID2',
 'HOSTGAL_SQRADIUS',
 'HOSTGAL_OBJID_UNIQUE',
 'HOSTGAL_ZPHOT_Q000',
 'HOSTGAL_ZPHOT_Q010',
 'HOSTGAL_ZPHOT_Q020',
 'HOSTGAL_ZPHOT_Q030',
 'HOSTGAL_ZPHOT_Q040',
 'HOSTGAL_ZPHOT_Q050'

In [4]:
# Select index 0 of the loaded table as a small example input
# (get_phot_from_parquet calls .iter_rows on it, so it is treated as a frame of rows).
example_input = parquet[0]
# Feature names supplied by the extractor class (presumably the static, non-time-series
# features, per the method name). They are used below as keys into each row and passed
# to fit() as additional_info. NOTE(review): this is a leading-underscore (private) method.
additional_features = ELAsTiCC2_ORACLEFeatureExtractor._get_static_features()

def get_phot_from_parquet(parquet_rows, static_features=None):
    """Convert parquet rows into the list of per-object dicts used as input to ``fit``.

    Parameters
    ----------
    parquet_rows
        Object exposing ``iter_rows(named=True)`` that yields one mapping per
        object (e.g. a polars DataFrame). Each row must provide 'SNID',
        'SNTYPE', 'REDSHIFT_FINAL', 'RA', 'DEC', 'BAND', 'MJD', 'FLUXCAL',
        'FLUXCALERR', 'PHOTFLAG' and every name in ``static_features``.
    static_features : iterable of str, optional
        Names of extra per-object columns to copy into each dict. When omitted,
        the notebook-level ``additional_features`` list is used, so existing
        one-argument calls behave as before.

    Returns
    -------
    list of dict
        One dict per row with keys 'objectid', 'sncode', 'redshift', 'RA',
        'DEC', 'photometry' (dict of the five light-curve columns),
        'additional_info' (empty dict) and one top-level key per static feature.

    Raises
    ------
    KeyError
        If a row lacks one of the required columns.
    """
    if static_features is None:
        # fall back to the notebook global so the original call signature keeps working
        static_features = additional_features

    data = []
    for obj in parquet_rows.iter_rows(named=True):
        phot_d = {
            'objectid': int(obj['SNID']),
            # SNTYPE is stored unchanged as the sncode; no class-name mapping is applied
            'sncode': obj['SNTYPE'],
            'redshift': obj['REDSHIFT_FINAL'],
            'RA': obj['RA'],
            'DEC': obj['DEC'],
            'photometry': {
                'BAND': obj['BAND'],
                'MJD': obj['MJD'],
                'FLUXCAL': obj['FLUXCAL'],
                'FLUXCALERR': obj['FLUXCALERR'],
                'PHOTFLAG': obj['PHOTFLAG'],
            },
            # NOTE(review): left empty on purpose to match the existing behaviour; the
            # static features are stored as top-level keys below. Confirm against what
            # fit() expects before moving them under 'additional_info'.
            'additional_info': {},
        }

        for feature in static_features:
            phot_d[feature] = obj[feature]

        data.append(phot_d)

    return data

# Convert the example input into the list-of-dicts structure that is handed to fit() in the next cell.
data_dic = get_phot_from_parquet(example_input)

In [5]:
# testing out the feature extractor
# (fit_TOM is imported here but not used in this cell)
from resspect.fit_lightcurves import fit, fit_TOM
# Dotted import path of the extractor class; fit() receives it as a string.
feature_extraction_method = 'oracle_resspect_classifier.elasticc2_oracle_feature_extractor.ELAsTiCC2_ORACLEFeatureExtractor'
# NOTE(review): the recorded run of this cell ended with
# "ValueError: 16 columns passed, passed data had 30 columns". The cause is not visible
# in this notebook; presumably the header built inside fit()/the extractor does not
# match the number of values per object. Confirm before relying on the CSV.
fit(
    data_dic,
    # features CSV written by fit(); the same path is read back in the last cell
    output_features_file = 'TOM_days_storage/TOM_training_features.csv',
    feature_extractor = feature_extraction_method,
    filters = 'LSST',
    additional_info = additional_features,
    # one_code = gentypes
)

INFO:root:Starting oracle_resspect_classifier.elasticc2_oracle_feature_extractor.ELAsTiCC2_ORACLEFeatureExtractor fit...


                                                 MJD  \
0  [60423.4061, 60423.4145, 60428.4206, 60448.362...   

                                             FLUXCAL  \
0  [2.7815372943878174, -42.94839096069336, -14.2...   

                                          FLUXCALERR  \
0  [41.806304931640625, 39.68401336669922, 36.072...   

                                                BAND  \
0  [Y, Y, Y, Y, Y, Y, Y, Y, Y, Y, Y, Y, Y, Y, Y, ...   

                                            PHOTFLAG          RA       DEC  \
0  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...  304.243393 -4.291256   

      MWEBV  MWEBV_ERR  REDSHIFT_FINAL  ...  HOSTGAL_RA  HOSTGAL_DEC  \
0  0.130539   0.006527        0.654122  ...  304.243466    -4.291261   

   HOSTGAL_SNSEP  HOSTGAL_ELLIPTICITY  HOSTGAL_MAG_u  HOSTGAL_MAG_g  \
0       0.262877               0.3411      23.957273      22.941786   

   HOSTGAL_MAG_r  HOSTGAL_MAG_i  HOSTGAL_MAG_z  HOSTGAL_MAG_Y  
0      21.804691      21.145388      20

ValueError: 16 columns passed, passed data had 30 columns

In [None]:
# viewing the output
import numpy as np
# The calls below (index_col=, column assignment, to_csv) are pandas API; polars'
# read_csv has no index_col argument, its DataFrame does not support item assignment,
# and it has no to_csv method, so the file is read with pandas here.
import pandas as pd

# features CSV written by fit() in the previous cell; it is updated in place below
features_path = 'TOM_days_storage/TOM_training_features.csv'

data = pd.read_csv(features_path, index_col=False)
# tag every row as part of the training sample
data['orig_sample'] = 'train'
# binary label: sncode 10 is treated as 'Ia', everything else as 'other'
data['type'] = np.where(data['sncode'] == 10, 'Ia', 'other')
data.to_csv(features_path, index=False)

data