In [1]:
%matplotlib inline
import numpy as np
import pandas as pd

import warnings
warnings.filterwarnings("ignore")

# Number of passbands; is_good checks passband ids 0..N_PASSBANDS-1 (the keys of passband2lam).
N_PASSBANDS = 6

In [2]:
# Map passband id -> log10 of that passband's wavelength
# (presumably the effective wavelength in Angstrom -- TODO confirm the source of these values).
passband2lam = {
    band: np.log10(wavelength)
    for band, wavelength in enumerate(
        (3751.36, 4741.64, 6173.23, 7501.62, 8679.19, 9711.53)
    )
}

In [3]:
def get_object(data, object_id):
    """Return the rows of `data` whose 'object_id' equals `object_id` (original index kept)."""
    return data.loc[data['object_id'] == object_id]

In [4]:
def get_passband(anobject, passband):
    """Return the rows of `anobject` whose 'passband' equals `passband` (original index kept)."""
    in_band = anobject['passband'] == passband
    return anobject.loc[in_band]

In [5]:
def is_good(anobject, min_obs=3, min_passbands=3, max_gap=50, passbands=None):
    """Decide whether one object's light curve passes the quality cuts.

    Parameters
    ----------
    anobject : pd.DataFrame
        Observations of a single object; needs 'flux', 'passband' and 'mjd' columns.
    min_obs : int
        Minimum number of observations for a passband to count as well sampled.
    min_passbands : int
        Minimum number of well-sampled passbands required.
    max_gap : float
        Largest allowed difference between consecutive sorted 'mjd' values.
    passbands : iterable of int, optional
        Passband ids to consider; defaults to range(N_PASSBANDS).

    Returns
    -------
    bool
        True if every cut passes. An empty frame returns False.
    """
    if anobject.empty:
        return False

    # remove all objects with any negative flux value; an elementwise test is used so
    # that a NaN flux cannot mask a negative one (np.min would return NaN)
    if (anobject['flux'] < 0).any():
        return False

    # keep only objects with at least `min_obs` observations in at least `min_passbands` passbands
    if passbands is None:
        passbands = range(N_PASSBANDS)
    counts = anobject['passband'].value_counts()
    n_well_sampled = sum(counts.get(pb, 0) >= min_obs for pb in passbands)
    if n_well_sampled < min_passbands:
        return False

    # keep only objects without large breaks in observations
    mjd = np.sort(anobject['mjd'].to_numpy())
    if len(mjd) > 1 and np.diff(mjd).max() > max_gap:
        return False

    return True

In [6]:
from tqdm import tqdm_notebook

In [7]:
# Training set: keep detected observations of objects that pass is_good, attach
# log-wavelength and a binary class label, and (over)write good_objects.csv.
meta_file = '../data/plasticc/PLAsTiCC-2018/plasticc_train_metadata.csv.gz'
metadata = pd.read_csv(meta_file)

file = "../data/plasticc/PLAsTiCC-2018/plasticc_train_lightcurves.csv.gz"
data = pd.read_csv(file)

# keep only detected observations; .copy() so the column added below goes into a real frame
data = data[data.detected_bool == 1].copy()
object_ids = np.unique(data.object_id)
print(object_ids.shape)

# vectorized lookup instead of a row-wise apply
data["log_lam"] = data["passband"].map(passband2lam)
assert data["log_lam"].notna().all(), "passband without an entry in passband2lam"

# groupby visits objects in sorted id order (as np.unique does) and keeps each object's
# rows in file order; collecting chunks and concatenating once avoids the quadratic
# cost of growing a DataFrame inside the loop
good_object_ids = []
good_chunks = []
for i, anobject in tqdm_notebook(data.groupby("object_id"), total=len(object_ids)):
    if is_good(anobject):
        good_object_ids.append(i)
        good_chunks.append(anobject)

print(len(good_object_ids))
good_objects_df = pd.concat(good_chunks) if good_chunks else data.iloc[:0].copy()

# class = 1 if the object's target is 90, 67 or 52, else 0 (first metadata row per object_id)
target_by_id = metadata.drop_duplicates("object_id").set_index("object_id")["target"]
good_objects_df["class"] = (
    good_objects_df["object_id"].map(target_by_id).astype(int).isin((90, 67, 52)).astype(int)
)

good_objects_df.to_csv('../data/plasticc/good_objects.csv')

del data
del object_ids
del good_object_ids
del good_objects_df
del good_chunks
del target_by_id

(7848,)


HBox(children=(FloatProgress(value=0.0, max=7848.0), HTML(value='')))


1957


In [None]:
from tqdm import tqdm_notebook

# Test set: process each light-curve batch like the training cell and APPEND it to
# good_objects.csv. Run this once, right after the training cell; re-running it on its
# own appends duplicate rows.
meta_file_test = '../data/plasticc/PLAsTiCC-2018/plasticc_test_metadata.csv.gz'
metadata = pd.read_csv(meta_file_test)
# loop-invariant: object_id -> true_target (first metadata row per object_id)
target_by_id = metadata.drop_duplicates("object_id").set_index("object_id")["true_target"]

for batch_number in ['{:02d}'.format(k) for k in range(1, 12)]:
    file = "../data/plasticc/PLAsTiCC-2018/plasticc_test_lightcurves_{}.csv.gz".format(batch_number)
    print(file)
    data = pd.read_csv(file)

    # keep only detected observations; .copy() so the column added below goes into a real frame
    data = data[data.detected_bool == 1].copy()
    object_ids = np.unique(data.object_id)
    print(batch_number, object_ids.shape)

    data["log_lam"] = data["passband"].map(passband2lam)
    assert data["log_lam"].notna().all(), "passband without an entry in passband2lam"

    # one groupby pass; concatenate once at the end instead of inside the loop
    good_object_ids = []
    good_chunks = []
    for i, anobject in tqdm_notebook(data.groupby("object_id"), total=len(object_ids)):
        if is_good(anobject):
            good_object_ids.append(i)
            good_chunks.append(anobject)

    print(batch_number, len(good_object_ids))
    good_objects_df = pd.concat(good_chunks) if good_chunks else data.iloc[:0].copy()

    # class = 1 if the object's true target is 90, 67 or 52, else 0
    good_objects_df["class"] = (
        good_objects_df["object_id"].map(target_by_id).astype(int).isin((90, 67, 52)).astype(int)
    )

    good_objects_df.to_csv('../data/plasticc/good_objects.csv', mode='a', header=False)

    del data
    del object_ids
    del good_object_ids
    del good_objects_df
    del good_chunks

del target_by_id

../data/plasticc/PLAsTiCC-2018/plasticc_test_lightcurves_01.csv.gz
01 (32926,)


HBox(children=(FloatProgress(value=0.0, max=32926.0), HTML(value='')))


01 11886
../data/plasticc/PLAsTiCC-2018/plasticc_test_lightcurves_02.csv.gz
02 (345997,)


HBox(children=(FloatProgress(value=0.0, max=345997.0), HTML(value='')))


02 18669


In [None]:
# Reload the combined train + test file, report the number of distinct objects, and preview it.
good_objects = pd.read_csv('../data/plasticc/good_objects.csv', index_col=0)
print(np.unique(good_objects.object_id).shape)
# fixed random_state so the displayed sample is reproducible across runs
good_objects.sample(10, random_state=0)

In [11]:
# A second, different but reproducible preview of the combined data.
good_objects.sample(10, random_state=1)

Unnamed: 0,object_id,mjd,passband,flux,flux_err,detected_bool,log_lam,class
1568900,14416509,60400.0535,1,33.934093,5.306179,1,3.675929,1
8547057,270325,59679.0225,3,8.042412,1.939021,1,3.875155,0
7377496,233300,60403.0334,3,32.112743,2.306313,1,3.875155,1
24881055,99173179,60493.0856,2,51.211929,3.750323,1,3.790512,0
5912236,54672603,60029.2406,3,124.341766,3.898028,1,3.875155,0
16535645,70754240,60040.9894,4,110.979774,14.936048,1,3.938479,0
15183157,5443126,59788.4338,4,619.810852,16.085686,1,3.938479,0
34082543,10979587,59824.1221,3,107.776382,4.972433,1,3.875155,1
29843408,74676692,60591.078,3,28.29151,4.219119,1,3.875155,0
22095237,20454996,60593.2945,2,59.950615,2.13694,1,3.790512,0
