In [1]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import h5py
import tqdm

import gp_aug
import regressor
import utils

import warnings
# NOTE(review): this silences *all* warnings for the rest of the session
# (pandas, numpy, the GP fitting, ...), which can hide real problems;
# consider filtering only specific categories/modules instead.
warnings.filterwarnings("ignore")
# Default font size for every matplotlib figure drawn in this notebook.
plt.rcParams['font.size'] = 15

In [2]:
# Keep only object ids and true peak times for objects of the two target classes
# (42 and 62 — presumably the PLAsTiCC SNII / SNIbc classes; confirm).
train_metadata = (
    pd.read_csv('../data/plasticc/plasticc_train_metadata.csv.gz',
                usecols=['object_id', 'true_target', 'true_peakmjd'])
    .loc[lambda frame: frame['true_target'].isin([42, 62])]
    .drop(columns='true_target')
)
train_metadata.shape

(1677, 2)

In [3]:
# Same selection as for the training metadata: target classes 42 and 62 only,
# keeping the object id and its true peak MJD.
test_metadata = (
    pd.read_csv('../data/plasticc/plasticc_test_metadata.csv.gz',
                usecols=['object_id', 'true_target', 'true_peakmjd'])
    .loc[lambda frame: frame['true_target'].isin([42, 62])]
    .drop(columns='true_target')
)
test_metadata.shape

(1175244, 2)

In [4]:
# Stack train and test metadata row-wise; the original row labels are kept
# (no ignore_index), as the index column of the preview shows.
metadata = pd.concat(objs=[train_metadata, test_metadata], axis='index')
print(metadata.shape)
metadata.head(7)

(1176921, 2)


Unnamed: 0,object_id,true_peakmjd
2,730,60444.379
7,1632,59602.09
11,2103,60220.684
12,2300,59582.93
18,3285,60403.363
21,3910,60545.609
23,4132,59613.43


In [5]:
# Every selected object id (train + test), used to filter the light-curve files.
# Its size matches the number of metadata rows, i.e. the ids are unique.
objects_set = set(metadata['object_id'].to_numpy())
len(objects_set)

1176921

In [6]:
# Lookup tables keyed by the integer `passband` column (0..5).
passband2name = dict(enumerate('ugrizy'))
# log10 of one wavelength per band (presumably the effective wavelength in
# Angstrom — confirm); passed to GaussianProcessesAugmentation.
passband2lam = {
    band: np.log10(wavelength)
    for band, wavelength in enumerate((3751.36, 4741.64, 6173.23,
                                       7501.62, 8679.19, 9711.53))
}
# Plot colour per band (key order kept as originally written).
passband2color = dict(zip((0, 2, 4, 1, 3, 5),
                          ('blue', 'green', 'purple', 'orange', 'red', 'brown')))

In [7]:
def get_object(data, object_id):
    """Return the rows of ``data`` whose ``object_id`` column equals ``object_id``.

    The original row labels are preserved; an empty frame is returned when the
    id does not occur.
    """
    return data.loc[data['object_id'].eq(object_id)]

def get_passband(anobject, passband):
    """Return the rows of ``anobject`` observed in the given ``passband``.

    Row labels and order are kept; the result is empty if the band is absent.
    """
    return anobject.loc[anobject['passband'].eq(passband)]

def compile_obj(t, flux, flux_err, passband):
    """Assemble parallel per-observation sequences into a light-curve DataFrame.

    Returns a frame with columns ``mjd``, ``flux``, ``flux_err`` and ``passband``
    (in that order), filled from ``t``, ``flux``, ``flux_err`` and ``passband``.
    """
    columns = (
        ('mjd', t),
        ('flux', flux),
        ('flux_err', flux_err),
        ('passband', passband),
    )
    obj = pd.DataFrame()
    for name, values in columns:
        obj[name] = values
    return obj

def is_good(anobject, peak_mjd):
    """Decide whether an object's light curve passes the quality cuts.

    The object is rejected (``False``) if any of these holds:

    * ``peak_mjd`` lies outside ``[min(mjd), max(mjd)]`` of the observations;
    * some observation has a negative flux;
    * more than 3 of the passbands 0..5 have fewer than 5 observations
      (i.e. fewer than 3 bands have at least 5 points);
    * after sorting by time, two consecutive observations are more than
      50 days apart.

    Parameters
    ----------
    anobject : pandas.DataFrame
        Observations of one object with ``mjd``, ``flux`` and ``passband`` columns.
    peak_mjd : float
        Peak time that must fall within the observed time span.

    Returns
    -------
    bool
        ``True`` if the object passes every cut.
    """
    times = anobject['mjd']
    if peak_mjd < times.min() or peak_mjd > times.max():
        return False

    if anobject['flux'].to_numpy().min() < 0:
        return False

    # how many of the six passbands are sparsely sampled (< 5 points)
    n_sparse_bands = sum(
        int((anobject['passband'] == band).sum()) < 5 for band in range(6)
    )
    if n_sparse_bands > 3:
        return False

    # widest gap between consecutive observations (MJD, i.e. days)
    sorted_times = np.sort(times.to_numpy())
    return not np.diff(sorted_times).max() > 50

def plot_light_curves(anobject, title=""):
    """Plot the raw light curve of one object, one line+scatter series per passband.

    Draws on a new 9x4 matplotlib figure (passbands 0..5, labelled with
    ``passband2name``); the caller is responsible for showing/saving it.
    """
    ordered = anobject.sort_values('mjd')
    plt.figure(figsize=(9, 4))
    for band in range(6):
        band_curve = get_passband(ordered, band)
        times = band_curve['mjd'].values
        fluxes = band_curve['flux'].values
        plt.plot(times, fluxes, linewidth=0.5)
        plt.scatter(times, fluxes, label=passband2name[band], linewidth=1)
    plt.xlabel('Modified Julian Date', size=14)
    plt.xticks(size=14)
    plt.ylabel('Flux', size=14)
    plt.yticks(size=14)
    plt.legend(loc='best', fontsize=14)
    plt.title(title, size=14)

In [8]:
def residuals_histogram(all_objects):
    """Show a 50-bin histogram of ``true_peakmjd - pred_peakmjd``.

    ``all_objects`` must provide both columns; the figure is displayed
    immediately with ``plt.show()``.
    """
    plt.figure(figsize=(10, 7))
    true_peaks = all_objects['true_peakmjd'].values
    predicted_peaks = all_objects['pred_peakmjd'].values
    plt.hist(true_peaks - predicted_peaks, bins=50)
    plt.xlabel('mjd residuals', fontsize=15)
    plt.show()
    
    
def plot_light_curves_with_peak(anobject, true_peak_mjd=None, title="", n_obs=1000, save=None):
    """Fit the GP augmentation to one object and plot it with its predicted peak.

    The observations are fitted with ``gp_aug.GaussianProcessesAugmentation``
    and resampled with ``n_obs`` points between the first and last observed MJD.
    The predicted peak is the MJD at which the augmented flux, summed over all
    passbands sharing that timestamp, is largest.

    Parameters
    ----------
    anobject : pandas.DataFrame
        Observations with ``mjd``, ``flux``, ``flux_err`` and ``passband`` columns.
    true_peak_mjd : float, optional
        If given, drawn as a black vertical line for comparison.
    title : str
        Figure title.
    n_obs : int
        Value forwarded as ``n_obs`` to ``model.augmentation``.
    save : str or path-like, optional
        If given, the figure is also written to this path in PDF format.
    """
    model = gp_aug.GaussianProcessesAugmentation(passband2lam)
    model.fit(anobject['mjd'].values, anobject['flux'].values,
              anobject['flux_err'].values, anobject['passband'].values)
    t_aug, flux_aug, flux_err_aug, passband_aug = model.augmentation(
        anobject['mjd'].min(),
        anobject['mjd'].max(),
        n_obs=n_obs
    )
    anobject_aug = compile_obj(t_aug, flux_aug, flux_err_aug, passband_aug)
    # total (all-passband) augmented flux at each timestamp
    curve = anobject_aug[['mjd', 'flux']].groupby('mjd', as_index=False).sum()
    # argmax() returns a position, so index positionally; plain [] would be a
    # label lookup that only coincides with the position for a RangeIndex.
    pred_peak_mjd = curve['mjd'].iloc[curve['flux'].argmax()]

    plt.figure(figsize=(12, 7))
    for passband in range(6):
        # smooth GP curve ...
        light_curve = get_passband(anobject_aug, passband)
        plt.plot(light_curve['mjd'].values, light_curve['flux'].values, linewidth=1,
                 color=passband2color[passband])
        # ... overlaid with the real observations in the same colour
        light_curve = get_passband(anobject, passband)
        plt.scatter(light_curve['mjd'].values, light_curve['flux'].values,
                    label=passband2name[passband], color=passband2color[passband], linewidth=1)
    plt.plot(curve['mjd'].values, curve['flux'].values, label='Sum', linewidth=1, color='pink')

    plt.xlabel('Modified Julian Date', size=14)
    plt.xticks(size=14)
    plt.ylabel('Flux', size=14)
    plt.yticks(size=14)

    plt.axvline(pred_peak_mjd, label='Pred peak', color='gray')
    if true_peak_mjd is not None:
        plt.axvline(true_peak_mjd, label='True peak', color='black')
    plt.legend(loc='best', ncol=3, fontsize=14)
    plt.title(title, size=14)

    if save is not None:
        plt.savefig(save, format='pdf')

    plt.show()

In [27]:
def read_test_csv(filename, objects_set):
    """Load the detected observations of the selected objects from a light-curve CSV.

    The key columns (``object_id``, ``detected_bool``) are read first to decide
    which rows to keep; the full file is then parsed with every other data row
    skipped, so only the kept rows are materialised in the returned frame.

    Parameters
    ----------
    filename : str
        Path to a light-curve CSV (may be gzip-compressed) with a header row and
        at least ``object_id`` and ``detected_bool`` columns.
    objects_set : set
        Object ids whose rows should be kept.

    Returns
    -------
    pandas.DataFrame
        All columns of the file, restricted to rows whose object id is in
        ``objects_set`` and whose ``detected_bool`` is truthy, in file order.
    """
    # Read both key columns in a single pass instead of decompressing the file
    # twice, and use vectorised isin instead of a per-row Python lambda.
    keys = pd.read_csv(filename, usecols=['object_id', 'detected_bool'])
    keep = keys['object_id'].isin(objects_set) & keys['detected_bool'].astype(bool)
    # Data row i is file line i + 1 because line 0 holds the column names.
    skiprows = np.flatnonzero(~keep.to_numpy()) + 1
    return pd.read_csv(filename, skiprows=skiprows)


def GP_prepare_picture(anobject, n_obs=1000):
    """Turn one object's light curve into a passband x time "picture" of GP fluxes.

    The observations are fitted with ``gp_aug.GaussianProcessesAugmentation``
    and resampled between the first and last observed MJD (``n_obs`` is
    forwarded to ``augmentation``). The augmented fluxes are pivoted to one
    column per passband (sorted by passband, rows sorted by MJD) and transposed.

    Returns
    -------
    tuple
        ``(picture, timestamps)``: the transposed flux matrix (one row per
        passband) and the array of sorted MJD values forming its columns.
        Assumes the augmentation uses the same timestamps for every passband
        (otherwise the pivot contains NaNs) — TODO confirm in gp_aug.
    """
    gp = gp_aug.GaussianProcessesAugmentation(passband2lam)
    gp.fit(*(anobject[column].values for column in ('mjd', 'flux', 'flux_err', 'passband')))
    first_mjd = anobject['mjd'].min()
    last_mjd = anobject['mjd'].max()
    times, fluxes, flux_errs, bands = gp.augmentation(first_mjd, last_mjd, n_obs=n_obs)

    wide = (
        compile_obj(times, fluxes, flux_errs, bands)
        .drop(columns='flux_err')
        .set_index(['mjd', 'passband'])
        .unstack(level=1)
    )
    return wide.values.T, np.array(wide.index)

In [32]:
# Build GP "pictures" for every good object across the train file and the
# 11 test light-curve files. Objects are visited file by file, in order of
# first appearance, exactly as before.

# object_id -> true_peakmjd, built once: boolean-filtering the ~1.2M-row
# metadata frame for every object made the original loop quadratic.
peak_by_id = metadata.set_index('object_id')['true_peakmjd'].to_dict()

lightcurve_files = ['../data/plasticc/plasticc_train_lightcurves.csv.gz'] + [
    f'../data/plasticc/plasticc_test_lightcurves_{number:02}.csv.gz'
    for number in range(1, 12)
]

pictures = []
ts = []
true_peaks = []

for filename in lightcurve_files:
    data = read_test_csv(filename, objects_set)
    # One groupby pass instead of rescanning `data` per object; sort=False keeps
    # first-appearance order, matching the previous data.object_id.unique() loop.
    grouped = data.groupby('object_id', sort=False)
    for object_id, anobject in tqdm.notebook.tqdm(grouped, total=grouped.ngroups):
        true_peak = peak_by_id[object_id]
        if not is_good(anobject, true_peak):
            continue
        picture, t = GP_prepare_picture(anobject)
        pictures.append(picture)
        ts.append(t)
        true_peaks.append(true_peak)

HBox(children=(FloatProgress(value=0.0, max=1677.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=18180.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=115377.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=115523.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=115573.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=115670.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=115838.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=115480.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=115629.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=115609.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=116378.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=115987.0), HTML(value='')))




In [37]:
# Output directory for the saved arrays; exist_ok makes the cell safe to
# re-run (plain `!mkdir` errors once the directory exists) and portable.
import os
os.makedirs('cnn_cls_data', exist_ok=True)

In [38]:
# Persist the extracted dataset: GP pictures, their timestamps and true peaks.
outputs = {
    'cnn_cls_data/pictures.npy': pictures,
    'cnn_cls_data/timestamps.npy': ts,
    'cnn_cls_data/true_peaks.npy': true_peaks,
}
for out_path, values in outputs.items():
    np.save(out_path, np.array(values))

In [39]:
tuple(np.array(collected).shape for collected in (pictures, ts, true_peaks))

((11078, 6, 1000), (11078, 1000), (11078,))

In [44]:
# Every true peak lies strictly between the first and last augmented timestamp.
np.logical_and(np.array(ts)[:, 0] < np.array(true_peaks),
               np.array(true_peaks) < np.array(ts)[:, -1]).all()

True