In [1]:
%matplotlib inline
import numpy as np
import pandas as pd

import warnings
# Silence every warning category for the remainder of the notebook.
# NOTE(review): this also hides pandas SettingWithCopyWarning and deprecation
# warnings; consider narrowing the filter to the specific noisy warnings.
warnings.simplefilter("ignore")

# Number of passbands per object; passbands are indexed 0..N_PASSBANDS-1.
N_PASSBANDS = 6

In [2]:
# Map passband index (0..5) -> log10 of that passband's effective wavelength.
# NOTE(review): wavelengths presumably in Angstrom (LSST-like bands) -- confirm.
passband2lam = {
    band: np.log10(wavelength)
    for band, wavelength in enumerate(
        (3751.36, 4741.64, 6173.23, 7501.62, 8679.19, 9711.53)
    )
}

In [3]:
def get_object(data, object_id):
    """Select every row of `data` belonging to one object.

    Parameters
    ----------
    data : pandas.DataFrame
        Frame with an ``object_id`` column.
    object_id
        Value to match against ``data['object_id']``.

    Returns
    -------
    pandas.DataFrame
        The matching rows with their original index (empty if none match).
    """
    matches = data['object_id'] == object_id
    return data.loc[matches]

In [4]:
def get_passband(anobject, passband):
    """Select the observations of `anobject` taken in a single passband.

    Parameters
    ----------
    anobject : pandas.DataFrame
        Frame with a ``passband`` column (typically one object's light curve).
    passband
        Passband index to keep.

    Returns
    -------
    pandas.DataFrame
        The matching rows with their original index (empty if none match).
    """
    in_band = anobject['passband'] == passband
    return anobject.loc[in_band]

In [5]:
def is_good(anobject, min_obs_per_passband=10, min_good_passbands=3, max_gap=50):
    """Decide whether one object's light curve passes the quality cuts.

    An object is kept only if all of the following hold:

    * no observation has a negative ``flux``;
    * at least `min_good_passbands` passbands each have at least
      `min_obs_per_passband` observations (passbands are counted as they
      appear in the ``passband`` column);
    * after sorting by ``mjd``, no two consecutive observations are more than
      `max_gap` apart (in the units of ``mjd``).

    Parameters
    ----------
    anobject : pandas.DataFrame
        Observations of a single object with ``flux``, ``passband`` and
        ``mjd`` columns; row order does not matter.
    min_obs_per_passband : int, default 10
        Observations required for a passband to count as well sampled.
    min_good_passbands : int, default 3
        Number of well-sampled passbands required.
    max_gap : float, default 50
        Largest allowed gap between consecutive ``mjd`` values.

    Returns
    -------
    bool
        True if the object passes every cut; False otherwise, including
        when `anobject` is empty.
    """
    if len(anobject) == 0:
        return False

    # Reject any negative flux. An element-wise test is used rather than
    # `.min() < 0`, because a single NaN would make the minimum NaN and
    # hide a negative value elsewhere.
    if (anobject['flux'].to_numpy() < 0).any():
        return False

    # Require enough well-sampled passbands: one value_counts pass instead
    # of re-filtering the frame once per passband.
    obs_per_passband = anobject['passband'].value_counts()
    if (obs_per_passband >= min_obs_per_passband).sum() < min_good_passbands:
        return False

    # Reject objects with a large break between consecutive observations.
    # The size guard avoids max() on an empty diff when only one row remains.
    mjd = np.sort(anobject['mjd'].to_numpy())
    if mjd.size > 1 and np.diff(mjd).max() > max_gap:
        return False

    return True

In [12]:
# Imported here so this cell runs on a fresh kernel; previously it relied on
# the test-set cell (which imports tqdm_notebook) having been executed first.
from tqdm import tqdm_notebook

# Targets labelled 1 in the "class" column; every other target gets 0.
# NOTE(review): presumably the SN Ia-like PLAsTiCC classes -- confirm.
POSITIVE_TARGETS = [90, 67, 52]

meta_file = '../data/plasticc/PLAsTiCC-2018/training_set_metadata.csv'
metadata = pd.read_csv(meta_file)

file = "../data/plasticc/PLAsTiCC-2018/training_set.csv"
data = pd.read_csv(file)

# Keep detected observations only and attach log10(wavelength) per passband.
# .assign builds a new frame, avoiding chained assignment on a slice (whose
# SettingWithCopyWarning the global warnings filter would hide).
data = data.loc[data.detected == 1].assign(
    log_lam=lambda df: df.passband.map(passband2lam))
object_ids = np.unique(data.object_id)
print(object_ids.shape)

# One groupby pass instead of re-scanning the whole frame for every object,
# and a single concat at the end instead of a quadratic concat in the loop.
good_object_ids = []
good_frames = []
for object_id, anobject in tqdm_notebook(data.groupby('object_id'), total=len(object_ids)):
    if is_good(anobject):
        good_object_ids.append(object_id)
        good_frames.append(anobject)

print(len(good_object_ids))
good_objects_df = pd.concat(good_frames) if good_frames else data.iloc[0:0].copy()

# Vectorised object_id -> target lookup instead of scanning the metadata once
# per row; Series.map requires object_id to be unique in the metadata.
target_by_id = metadata.set_index('object_id')['target']
targets = good_objects_df['object_id'].map(target_by_id)
assert targets.notna().all(), "some good objects are missing from the training metadata"
good_objects_df['class'] = targets.isin(POSITIVE_TARGETS).astype(int)

# Written fresh, with a header row; the test-set cell appends to this file.
good_objects_df.to_csv('../data/plasticc/good_objects.csv')

del data, object_ids, good_object_ids, good_frames, good_objects_df, target_by_id, targets

(7848,)


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=7848.0), HTML(value='')))


516


In [6]:
from tqdm import tqdm_notebook

# Test-set counterpart of the training-set cell: the same quality cuts and
# labelling, applied batch by batch.
# NOTE(review): this cell APPENDS to good_objects.csv without a header, so it
# must run exactly once, after the training-set cell has rewritten that file.
# Running it again appends duplicate rows.

# Targets labelled 1 in the "class" column; every other target gets 0.
# NOTE(review): presumably the SN Ia-like PLAsTiCC classes -- confirm.
POSITIVE_TARGETS = [90, 67, 52]
N_TEST_BATCHES = 11  # test_set_batch1.csv .. test_set_batch11.csv

meta_file_test = '../data/plasticc/plasticc_test_metadata.csv.gz'
metadata = pd.read_csv(meta_file_test)
# Loop-invariant object_id -> true_target lookup, built once for all batches.
# Series.map requires object_id to be unique in the metadata.
target_by_id = metadata.set_index('object_id')['true_target']

for batch_number in range(1, N_TEST_BATCHES + 1):
    file = "../data/plasticc/PLAsTiCC-2018/test_set_batch{}.csv".format(batch_number)
    data = pd.read_csv(file)

    # Detected observations only, plus log10(wavelength); .assign avoids
    # chained assignment on a slice.
    data = data.loc[data.detected == 1].assign(
        log_lam=lambda df: df.passband.map(passband2lam))
    object_ids = np.unique(data.object_id)
    print(batch_number, object_ids.shape)

    # One groupby pass and a single concat instead of per-object re-filtering
    # and a quadratic concat inside the loop.
    good_object_ids = []
    good_frames = []
    for object_id, anobject in tqdm_notebook(data.groupby('object_id'), total=len(object_ids)):
        if is_good(anobject):
            good_object_ids.append(object_id)
            good_frames.append(anobject)

    print(batch_number, len(good_object_ids))
    good_objects_df = pd.concat(good_frames) if good_frames else data.iloc[0:0].copy()

    targets = good_objects_df['object_id'].map(target_by_id)
    assert targets.notna().all(), "some good objects are missing from the test metadata"
    good_objects_df['class'] = targets.isin(POSITIVE_TARGETS).astype(int)

    good_objects_df.to_csv('../data/plasticc/good_objects.csv', mode='a', header=False)

    del data, object_ids, good_object_ids, good_frames, good_objects_df, targets

1 (32926,)


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=32926.0), HTML(value='')))


1 1854
2 (345997,)


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=345997.0), HTML(value='')))


2 19
3 (345997,)


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=345997.0), HTML(value='')))


3 7
4 (345997,)


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=345997.0), HTML(value='')))


4 20
5 (345997,)


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=345997.0), HTML(value='')))


5 10
6 (345996,)


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=345996.0), HTML(value='')))


6 18
7 (345996,)


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=345996.0), HTML(value='')))


7 18
8 (345996,)


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=345996.0), HTML(value='')))


8 9
9 (345996,)


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=345996.0), HTML(value='')))


9 13
10 (345996,)


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=345996.0), HTML(value='')))


10 9
11 (345996,)


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=345996.0), HTML(value='')))


11 15


In [8]:
good_objects = pd.read_csv('../data/plasticc/good_objects.csv', index_col=0)
# Fully duplicated rows most likely mean the append-mode test-set cell was run
# more than once against the same file.
assert not good_objects.duplicated().any(), "duplicate rows in good_objects.csv"
print(good_objects.object_id.nunique())
# Fixed seed so the preview is identical on every re-run.
good_objects.sample(10, random_state=0)

(1992,)


Unnamed: 0,object_id,mjd,passband,flux,flux_err,detected,log_lam,class
537281,16917,60593.1287,1,46.91198,1.370294,1,3.675929,0
6629072,209931,59870.0459,4,107.189636,2.627374,1,3.938479,1
6947512,219870,60434.0115,4,32.411114,1.982084,1,3.938479,0
7719805,244335,60250.1708,2,12.785584,2.124709,1,3.790512,1
3505470,112444,59866.0438,4,36.240757,2.020582,1,3.938479,0
4976018,158318,60588.1666,0,19.051743,1.75884,1,3.574189,0
7429819,235172,60640.0656,3,586.982239,3.227352,1,3.875155,1
1348807,43398,60168.3157,4,188.441711,2.334626,1,3.938479,1
7961196,251965,59857.0931,2,146.236984,1.928173,1,3.790512,1
1342813,43208,59945.1032,4,32.0905,2.006571,1,3.938479,0
