In [1]:
%matplotlib inline
import numpy as np
import pandas as pd

import warnings
warnings.filterwarnings("ignore")

# Number of passbands; is_good() checks passband indices 0 .. N_PASSBANDS - 1.
N_PASSBANDS = 6

In [2]:
# Lookup table: passband index -> log10 of the constant attached to that band.
# NOTE(review): the constants are presumably the effective wavelengths of the
# six filters in Angstroms -- confirm against the data set documentation.
_PASSBAND_WAVELENGTHS = (3751.36, 4741.64, 6173.23, 7501.62, 8679.19, 9711.53)
passband2lam = {band: np.log10(lam) for band, lam in enumerate(_PASSBAND_WAVELENGTHS)}

In [3]:
def get_object(data, object_id):
    """Return the rows of ``data`` whose ``object_id`` column equals ``object_id``.

    The original index and row order of the matching rows are kept; an id that
    does not occur yields an empty frame with the same columns.
    """
    belongs_to_object = data["object_id"] == object_id
    return data.loc[belongs_to_object]

In [4]:
def get_passband(anobject, passband):
    """Return the rows of ``anobject`` whose ``passband`` column equals ``passband``.

    The original index and row order of the matching rows are kept; a passband
    that does not occur yields an empty frame with the same columns.
    """
    in_passband = anobject["passband"] == passband
    return anobject.loc[in_passband]

In [5]:
def is_good(anobject, min_obs=3, max_sparse_passbands=3, max_gap=50, n_passbands=None):
    """Decide whether a single object's light curve passes the quality cuts.

    An object is rejected when any of the following holds:

    * it has no observations at all;
    * its minimum ``flux`` value is negative;
    * more than ``max_sparse_passbands`` of the passbands
      ``0 .. n_passbands - 1`` have fewer than ``min_obs`` observations
      (with 6 passbands and the defaults this means "at least 3 observations
      in at least 3 passbands");
    * two consecutive observations (after sorting by ``mjd``) are more than
      ``max_gap`` apart.

    Parameters
    ----------
    anobject : pandas.DataFrame
        Rows of one object; must provide ``flux``, ``passband`` and ``mjd``
        columns.
    min_obs : int, optional
        Minimum number of observations for a passband to count as populated.
    max_sparse_passbands : int, optional
        Largest tolerated number of passbands with fewer than ``min_obs``
        observations.
    max_gap : float, optional
        Largest tolerated difference between consecutive sorted ``mjd`` values.
    n_passbands : int, optional
        Number of passbands to check; defaults to the module-level
        ``N_PASSBANDS``.

    Returns
    -------
    bool
        ``True`` if the object passes every cut, ``False`` otherwise.
    """
    if n_passbands is None:
        n_passbands = N_PASSBANDS

    # an empty selection cannot satisfy any cut (and .min()/.max() would raise)
    if len(anobject) == 0:
        return False

    # remove all objects with negative flux values
    if anobject['flux'].values.min() < 0:
        return False

    # count observations per passband in one pass instead of one mask per band
    obs_per_passband = anobject['passband'].value_counts()
    n_sparse = sum(obs_per_passband.get(passband, 0) < min_obs
                   for passband in range(n_passbands))
    if n_sparse > max_sparse_passbands:
        return False

    # keep only objects without large breaks in observations
    mjd = np.sort(anobject['mjd'].values)
    # a single observation has no gaps; np.diff would be empty and .max() raise
    if mjd.size > 1 and np.diff(mjd).max() > max_gap:
        return False

    return True

In [6]:
from tqdm import tqdm_notebook

In [7]:
# Build ../data/plasticc/good_objects_peak.csv from the training light curves:
# keep detected points of objects accepted by is_good() and attach per-object
# labels looked up in the training metadata. This cell (re)creates the file.
meta_file = '../data/plasticc/PLAsTiCC-2018/plasticc_train_metadata.csv.gz'
metadata = pd.read_csv(meta_file)

file = "../data/plasticc/PLAsTiCC-2018/plasticc_train_lightcurves.csv.gz"
data = pd.read_csv(file)

# .copy() so that adding a column below does not write into a filtered view
data = data[data.detected_bool == 1].copy()
object_ids = np.unique(data.object_id)
print(object_ids.shape)

# vectorised dictionary lookup instead of a Python call per row
data["log_lam"] = data["passband"].map(passband2lam)
assert data["log_lam"].notna().all(), "passband value missing from passband2lam"

# groupby yields the objects in sorted object_id order with their rows in the
# original order, i.e. the same slices get_object() would return, but without
# rescanning the whole frame for every object id.
good_object_ids = []
good_frames = []
for object_id, anobject in tqdm_notebook(data.groupby("object_id"), total=len(object_ids)):
    if is_good(anobject):
        good_object_ids.append(object_id)
        good_frames.append(anobject)

# concatenate once; growing a frame inside the loop is quadratic
if good_frames:
    good_objects_df = pd.concat(good_frames)
else:
    good_objects_df = pd.DataFrame(columns=data.columns)

print(len(good_object_ids))

# One metadata row per object (first match wins, as with `[0]` before), indexed
# for direct lookup. A missing object_id maps to NaN and astype(int) then fails
# loudly instead of writing a wrong label.
meta_by_id = metadata.drop_duplicates("object_id").set_index("object_id")
good_ids = good_objects_df["object_id"]

good_objects_df["class"] = (good_ids.map(meta_by_id["target"]).astype(int)
                            .isin([90, 67, 52]).astype(int))

# astype(int) truncates towards zero exactly like the previous int(...)
good_objects_df["true_peakmjd"] = good_ids.map(meta_by_id["true_peakmjd"]).astype(int)

good_objects_df["true_target"] = good_ids.map(meta_by_id["true_target"]).astype(int)

good_objects_df.to_csv('../data/plasticc/good_objects_peak.csv')

# free the large frames before the next cell loads the test batches
del data
del object_ids
del good_object_ids
del good_frames
del good_objects_df

(7848,)


HBox(children=(FloatProgress(value=0.0, max=7848.0), HTML(value='')))


1957


In [None]:
from tqdm import tqdm_notebook

# Append the accepted objects of every test batch to
# ../data/plasticc/good_objects_peak.csv (created by the training cell).
# NOTE: the file is opened in append mode without a header, so re-running this
# cell without first re-running the training cell duplicates rows.
meta_file_test = '../data/plasticc/PLAsTiCC-2018/plasticc_test_metadata.csv.gz'
metadata = pd.read_csv(meta_file_test)

# One metadata row per object (first match wins, as with `[0]` before), indexed
# once for all batches instead of rescanning the metadata for every row.
meta_by_id = metadata.drop_duplicates("object_id").set_index("object_id")

for batch_number in ['01', '02', '03', '04', '05', '06', '07', '08', '09', '10', '11']:
    file = "../data/plasticc/PLAsTiCC-2018/plasticc_test_lightcurves_{}.csv.gz".format(batch_number)
    print(file)
    data = pd.read_csv(file)

    # .copy() so that adding a column below does not write into a filtered view
    data = data[data.detected_bool == 1].copy()
    object_ids = np.unique(data.object_id)
    print(batch_number, object_ids.shape)

    # vectorised dictionary lookup instead of a Python call per row
    data["log_lam"] = data["passband"].map(passband2lam)
    assert data["log_lam"].notna().all(), "passband value missing from passband2lam"

    # groupby yields the objects in sorted object_id order with their rows in
    # the original order, i.e. the same slices get_object() would return, but
    # without rescanning the whole frame for every object id.
    good_object_ids = []
    good_frames = []
    for object_id, anobject in tqdm_notebook(data.groupby("object_id"), total=len(object_ids)):
        if is_good(anobject):
            good_object_ids.append(object_id)
            good_frames.append(anobject)

    # concatenate once; growing a frame inside the loop is quadratic
    if good_frames:
        good_objects_df = pd.concat(good_frames)
    else:
        good_objects_df = pd.DataFrame(columns=data.columns)

    print(batch_number, len(good_object_ids))

    # A missing object_id maps to NaN and astype(int) then fails loudly
    # instead of writing a wrong label.
    good_ids = good_objects_df["object_id"]

    good_objects_df["class"] = (good_ids.map(meta_by_id["true_target"]).astype(int)
                                .isin([90, 67, 52]).astype(int))

    # astype(int) truncates towards zero exactly like the previous int(...)
    good_objects_df["true_peakmjd"] = good_ids.map(meta_by_id["true_peakmjd"]).astype(int)

    good_objects_df["true_target"] = good_ids.map(meta_by_id["true_target"]).astype(int)

    good_objects_df.to_csv('../data/plasticc/good_objects_peak.csv', mode='a', header=False)

    # free the batch before loading the next one
    del data
    del object_ids
    del good_object_ids
    del good_frames
    del good_objects_df

../data/plasticc/PLAsTiCC-2018/plasticc_test_lightcurves_01.csv.gz
01 (32926,)


HBox(children=(FloatProgress(value=0.0, max=32926.0), HTML(value='')))


01 11886
../data/plasticc/PLAsTiCC-2018/plasticc_test_lightcurves_02.csv.gz
02 (345997,)


HBox(children=(FloatProgress(value=0.0, max=345997.0), HTML(value='')))

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)




02 18669
../data/plasticc/PLAsTiCC-2018/plasticc_test_lightcurves_03.csv.gz
03 (345997,)


HBox(children=(FloatProgress(value=0.0, max=345997.0), HTML(value='')))

In [None]:
# Reload the combined CSV written by the cells above; its first column holds
# the saved row index.
good_objects_path = '../data/plasticc/good_objects_peak.csv'
good_objects = pd.read_csv(good_objects_path, index_col=0)

# number of distinct objects that passed the selection
unique_object_ids = np.unique(good_objects["object_id"])
print(unique_object_ids.shape)

good_objects.sample(10)

In [12]:
# Show ten randomly drawn rows of the selected light-curve points.
good_objects.sample(n=10)

Unnamed: 0,object_id,mjd,passband,flux,flux_err,detected_bool,log_lam,class,true_peakmjd,true_target
1311643,42151,59770.3817,3,33.75238,1.093348,1,3.875155,1,59784,90
7229566,16072528,60141.4117,1,71.908791,2.538478,1,3.675929,0,60109,62
12470802,69561529,60556.1907,3,283.323975,4.345737,1,3.875155,0,60564,62
40100879,51701058,60019.4032,4,106.552505,16.839195,1,3.938479,1,60005,90
13264726,95761922,59869.2393,2,255.29953,4.166808,1,3.790512,1,59846,90
4951854,93336136,59724.4279,4,215.34581,11.226151,1,3.938479,0,59636,95
8025800,42291986,60633.3445,4,135.532608,25.922781,1,3.938479,1,60594,90
2244713,79537686,60502.3461,2,37.39922,2.990577,1,3.790512,0,60369,42
36085417,11565132,60371.0089,5,452.764557,45.20118,1,3.987288,1,60326,90
19827028,71727822,60545.261,3,73.606499,4.480795,1,3.875155,0,60512,42
