# Preprocess Dataset

In this notebook we preprocess the light curves as described in [Link to paper in arXiv](https://arxiv.org/abs/2107.07531).

#### Index<a name="index"></a>
1. [Import Packages](#imports)
2. [Load the Original Dataset](#loadData)
3. [Preprocess Light Curves](#preprocess)
4. [Save Processed SnanaData](#saveData)
5. [Light Curve Comparison](#comparison)

## 1. Import Packages<a name="imports"></a>

In [None]:
!pip install ../snmachine/

In [None]:
import collections
import os
import pickle
import sys
import time

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

In [None]:
from snmachine import sndata, analysis
from utils.plasticc_pipeline import get_directories, load_dataset

In [None]:
%config Completer.use_jedi = False  # enable autocomplete

#### Aesthetic settings

In [None]:
%matplotlib inline

sns.set(font_scale=1.3, style="ticks")

## 2. Load Original Dataset<a name="loadData"></a>

First, **write** the path to the dataset folder `folder_path`.

In [None]:
# os_name = 'baseline_v2_0_paper'
# os_name = 'noroll_v2_0_paper'
os_name = 'presto_v2_0_paper'

folder_path = f'/folder/path'

Then, **write** in `data_file_name` the name of the file where your dataset is saved.

In this notebook we use the dataset previously created.

In [None]:
# extra_name_to_save = 'ddf'
extra_name_to_save = 'wfd'

# name = 'train'
name = 'test'

# file_id = '000'
file_id = '006' # until 012

data_file_name = f'{name}_{extra_name_to_save}_{file_id}.pckl'

Load the dataset. For DDF, this takes about 15 s for the training set and 4 min for the test set. For WFD, each 1/13 batch takes ~14-25 min.

In [None]:
data_path = os.path.join(folder_path, data_file_name)
ini_time = time.time()
dataset = load_dataset(data_path)
print(time.time() - ini_time)
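
The load-and-time pattern above recurs throughout this notebook. As a minimal sketch, it could be wrapped in a small context manager. Here `timed` and `timings` are hypothetical names used for illustration, not part of `snmachine`, and the toy workload stands in for `load_dataset(data_path)`.

```python
import time
from contextlib import contextmanager


@contextmanager
def timed(label, results=None):
    """Print (and optionally record) the wall-clock time of a block.

    `timed` and `results` are hypothetical names used for illustration.
    """
    start = time.perf_counter()
    try:
        yield
    finally:
        elapsed = time.perf_counter() - start
        if results is not None:
            results[label] = elapsed
        print(f'{label}: {elapsed:.2f} s')


timings = {}
with timed('toy workload', timings):
    total = sum(range(100_000))  # stand-in for `load_dataset(data_path)`
```

Recording the elapsed times in a dictionary makes it easier to compare batches afterwards than reading them from printed output.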

In [None]:
dataset.get_max_length()

In [None]:
dataset.remove_gaps(max_gap_length=50)

In [None]:
dataset.get_max_length()

In [None]:
print(collections.Counter(dataset.metadata['target']), len(dataset.metadata))

Save the data of one event to compare later in Section [Light Curve Comparison](#comparison). **Choose** the event by modifying `obj_show`.

In [None]:
dataset.object_names

In [None]:
print(f'The longest original light curve has {dataset.get_max_length():.2f} days.') 

In [None]:
# DDF
#obj_show = '416872' # base train
#obj_show = '529406' # foot8 train
#obj_show = '681000' # base test
#obj_show = '209066' # foot8 test

# WFD
#obj_show = '122420516' # base train
#obj_show = '2912729' # base test 000
#obj_show = '60282775' # base test 001
#obj_show = '34495546' # base test 002
###obj_show = '64767080' # base test 003
#obj_show = '11702392' # base test 004
# Uncomment one `obj_show` line above and the line below; Section 5 needs both.
#obs_before = dataset.data[obj_show].to_pandas()

## 3. Preprocess Light Curves<a name="preprocess"></a>

In [None]:
is_run_everything = 0

In [None]:
if is_run_everything:
    batch_ids = ['000', '001', '002', '003', '004', '005', '006', 
                 '007', '008', '009', '010', '011', '012']
    
    extra_name_to_save = 'wfd'
    
    name = 'test'
    
    max_distance = 50
    max_gap_length = 50
    
    lc_length_s = []
    
    for i, batch_id in enumerate(batch_ids):
        print(f'Batch {batch_id}')
    
        # Load data
        data_file_name = f'{name}_{extra_name_to_save}_{batch_id}.pckl'
        data_path = os.path.join(folder_path, data_file_name)
        ini_time = time.time()
        dataset = load_dataset(data_path)
        print(time.time() - ini_time)
        print('max', dataset.get_max_length())
        
        # Select window (`is_only_roll` is set in Section 3.1; run that cell first)
        if is_only_roll:
            ini_time = time.time()
            dataset.select_window(window=[60768, None], verbose=True)
            time_taken = time.time() - ini_time
            print(time_taken)
            print('max', dataset.get_max_length())
        
        # Select transient
        ini_time = time.time() # other
        dataset.select_transients(max_distance=max_distance, verbose=True)
        time_taken = time.time() - ini_time
        print(time_taken)
        
        # Remove gaps
        ini_time = time.time() 
        dataset.remove_gaps(max_gap_length*2, verbose=True)
        dataset.remove_gaps(max_gap_length*2, verbose=True)
        dataset.remove_gaps(max_gap_length, verbose=True)
        dataset.remove_gaps(max_gap_length, verbose=True)
        dataset.remove_gaps(max_gap_length, verbose=True)
        time_taken = time.time() - ini_time
        print(time_taken)
        
        # Keep only events with at least one detection whose observations
        # span more than 0.5 days (a single loop, so neither cut is lost)
        ini_time = time.time()
        good_objs = []
        for obj in dataset.object_names:
            obj_data = dataset.data[obj]
            has_detection = np.sum(obj_data['detected']) > 0
            span = np.max(obj_data['mjd']) - np.min(obj_data['mjd'])
            if has_detection and span > 0.5:
                good_objs.append(obj)
        time_taken = time.time() - ini_time
        print(time_taken)
        
        if len(dataset.object_names) != len(good_objs):
            print('Some events fail the cuts; unexpected unless this is a 1.5-year dataset')
            ini_time = time.time()
            dataset.update_dataset(good_objs)
            dataset.update_dataset(list(dataset.metadata.index))
            time_taken = time.time() - ini_time
            print(time_taken)
        
        # Calculate LC length
        ini_time = time.time()
        lc_length_s.append(analysis.compute_lc_length(dataset))
        print(time.time() - ini_time)
        
        # Save file
        folder_path_to_save = folder_path[:-14]
        file_name = data_file_name[:-5]+'_gapless50_updated.pckl'
        if is_only_roll:
            file_name = data_file_name[:-5]+'_roll_gapless50_updated.pckl'
        
        ini_time = time.time()
        with open(os.path.join(folder_path_to_save, file_name), 'wb') as f:
            pickle.dump(dataset, f, pickle.HIGHEST_PROTOCOL)
        time_taken = time.time() - ini_time
        print(time_taken)
        
        print('')
        print('')

### 3.1. Only rolling part<a name="roll"></a> <font color=salmon>(Optional)</font>

We generated the events between days 60220 and 61325. Since the rolling cadence starts at year 1.5, we cut all the light curves so that they start after day 60220 + 548 = 60768.
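
The arithmetic and the cut can be sketched on a toy array of observation times. This is only an illustration: the 365.25-day year, the toy times, and keeping the boundary day itself are assumptions, and `select_window` may treat the boundary differently.

```python
import numpy as np

SURVEY_START = 60220      # first simulated day, from the text above
DAYS_PER_YEAR = 365.25    # assumption: Julian year

# 1.5 years is about 548 days, so the rolling part starts at day 60768.
roll_start = SURVEY_START + round(1.5 * DAYS_PER_YEAR)

# Toy observation times (hypothetical values), in the same day units.
mjd = np.array([60230.0, 60500.0, 60767.9, 60768.0, 61000.0, 61325.0])

# Keep the observations at or after the start of the rolling cadence.
# Whether the boundary itself is kept is an assumption of this sketch.
mjd_roll = mjd[mjd >= roll_start]
```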

In [None]:
is_only_roll = 0

In [None]:
if is_only_roll:
    ini_time = time.time()
    dataset.select_window(window=[60768, None], verbose=True)
    time_taken = time.time() - ini_time
    print(time_taken)
    print('max', dataset.get_max_length())

### 3.2. Select transient part<a name="trans"></a>

Keep only the observations between the first and last detections, plus those within `max_distance` = 50 days before the first detection or after the last one.
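
The selection rule can be sketched as a boolean mask over one light curve. The helper `transient_mask` is hypothetical and only illustrates the rule stated above; it is not the `snmachine` implementation, and how events without detections are handled is an assumption.

```python
import numpy as np


def transient_mask(mjd, detected, max_distance=50):
    """Boolean mask of the observations near the detections.

    Hypothetical helper: keeps observations from `max_distance` days before
    the first detection to `max_distance` days after the last one.
    """
    mjd = np.asarray(mjd, dtype=float)
    detected = np.asarray(detected, dtype=bool)
    if not detected.any():
        # Assumption: with no detection there is nothing to anchor the window on.
        return np.zeros(mjd.shape, dtype=bool)
    first = mjd[detected].min()
    last = mjd[detected].max()
    return (mjd >= first - max_distance) & (mjd <= last + max_distance)


# Toy light curve (hypothetical values): detections at days 100 and 150.
mjd = np.array([0., 40., 100., 120., 150., 190., 260.])
detected = np.array([0, 0, 1, 0, 1, 0, 0])
mask = transient_mask(mjd, detected, max_distance=50)
```

In this toy example the kept window runs from day 50 to day 200, so the observations at days 0, 40 and 260 are dropped.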

In [None]:
max_distance = 50

In [None]:
ini_time = time.time() # 7min
dataset.select_transients(max_distance=max_distance, verbose=True)
time_taken = time.time() - ini_time
print(time_taken)

### 3.3. Remove gaps<a name="removeGaps"></a>

**Write** the maximum gap duration allowed in the light curves, `max_gap_length`.

In [None]:
max_gap_length = 50

To remove all the gaps longer than `max_gap_length`, the `remove_gaps` function must be called a few times; it only removes the first gap longer than `max_gap_length`.

To introduce uniformity in the dataset, the resulting light curves are translated so their first observation is at time zero.

This takes ~30min for the WFD data.
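
The need for repeated calls can be illustrated on a toy list of observation times. This is a sketch under a loudly stated assumption: here, "removing" a gap means keeping the side of the gap with more observations. The actual behaviour of `remove_gaps` in `snmachine` may differ, and both helper names are hypothetical.

```python
import numpy as np


def remove_first_long_gap(mjd, max_gap_length):
    """Remove the first gap longer than `max_gap_length`, if any.

    Assumption for illustration: the side of the gap with more observations
    is kept (ties keep the earlier side). Returns the kept times and a flag
    saying whether a gap was removed.
    """
    mjd = np.sort(np.asarray(mjd, dtype=float))
    long_gaps = np.flatnonzero(np.diff(mjd) > max_gap_length)
    if long_gaps.size == 0:
        return mjd, False
    split = long_gaps[0] + 1  # index of the first observation after the gap
    before, after = mjd[:split], mjd[split:]
    kept = before if before.size >= after.size else after
    return kept, True


def remove_all_long_gaps(mjd, max_gap_length):
    """Repeat the single-gap step until no long gap remains, then shift."""
    kept, removed = remove_first_long_gap(mjd, max_gap_length)
    n_calls = 1
    while removed:
        kept, removed = remove_first_long_gap(kept, max_gap_length)
        n_calls += 1
    return kept - kept[0], n_calls  # first observation at time zero


# Toy light curve (hypothetical values) with two gaps longer than 50 days.
mjd = [0., 10., 70., 80., 90., 100., 200., 210.]
gapless, n_calls = remove_all_long_gaps(mjd, max_gap_length=50)
```

With two long gaps, a single call leaves one of them in place. A loop like this one stops as soon as a call finds no gap, whereas a fixed number of calls has to be large enough for the worst light curve in the dataset.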

In [None]:
ini_time = time.time() 
dataset.remove_gaps(max_gap_length*2, verbose=True)
dataset.remove_gaps(max_gap_length*2, verbose=True)
dataset.remove_gaps(max_gap_length, verbose=True)
dataset.remove_gaps(max_gap_length, verbose=True)
dataset.remove_gaps(max_gap_length, verbose=True)
#dataset.remove_gaps(max_gap_length, verbose=True)
#dataset.remove_gaps(max_gap_length, verbose=True)
#dataset.remove_gaps(max_gap_length, verbose=True)
time_taken = time.time() - ini_time
print(time_taken)

In [None]:
# Keep only events with at least one detection that span more than 0.5 days
ini_time = time.time()
good_objs = []
for obj in dataset.object_names:
    obj_data = dataset.data[obj]
    if (np.sum(obj_data['detected']) > 0) and (np.max(obj_data['mjd'])-np.min(obj_data['mjd'])>.5):
        good_objs.append(obj)
time_taken = time.time() - ini_time
print(time_taken)

if len(dataset.object_names) != len(good_objs):
    print('Something bad: some events do not pass the cuts')

In [None]:
if len(dataset.object_names) != len(good_objs):
    print('Something bad: some events do not pass the cuts')
    ini_time = time.time()
    dataset.update_dataset(good_objs)
    dataset.update_dataset(list(dataset.metadata.index))
    time_taken = time.time() - ini_time
    print(time_taken)
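
The two cuts applied in the cells above can be checked on a toy dictionary of light curves. This is a minimal sketch: `toy_data` and `is_good` are hypothetical names, and plain dictionaries of arrays stand in for the tables stored in `dataset.data`.

```python
import numpy as np

# Toy light curves (hypothetical values): object name -> columns.
toy_data = {
    'a': {'mjd': np.array([0., 3., 9.]), 'detected': np.array([0, 1, 1])},
    # 'b' has no detection.
    'b': {'mjd': np.array([0., 5.]), 'detected': np.array([0, 0])},
    # 'c' spans less than half a day.
    'c': {'mjd': np.array([2.0, 2.3]), 'detected': np.array([1, 1])},
}


def is_good(obj_data, min_span=0.5):
    """True if the event has a detection and spans more than `min_span` days."""
    has_detection = np.sum(obj_data['detected']) > 0
    span = np.max(obj_data['mjd']) - np.min(obj_data['mjd'])
    return bool(has_detection and span > min_span)


good_objs = [obj for obj, obj_data in toy_data.items() if is_good(obj_data)]
```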

## 4. Save Processed SnanaData<a name="saveData"></a>

Now, **choose** a path to save the SnanaData instance created (`folder_path_to_save`) and the name of the file (`file_name`). It takes ~6-8min to save 1/13 of the WFD test set for baseline v2.0.

In [None]:
folder_path_to_save = folder_path[:-14]
file_name = data_file_name[:-5]+'_gapless50_updated.pckl'
if is_only_roll:
    file_name = data_file_name[:-5]+'_roll_gapless50_updated.pckl'
file_name

In [None]:
ini_time = time.time()
with open(os.path.join(folder_path_to_save, file_name), 'wb') as f:
    pickle.dump(dataset, f, pickle.HIGHEST_PROTOCOL)
time_taken = time.time() - ini_time
print(time_taken)
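
To check that a saved file can be read back, a round trip like the following could be used. This is a minimal sketch: a toy dictionary stands in for the dataset and a temporary directory stands in for `folder_path_to_save`, both hypothetical.

```python
import os
import pickle
import tempfile

# Toy stand-in for the processed dataset (hypothetical content).
toy_dataset = {'object_names': ['1', '2'], 'max_gap_length': 50}

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, 'test_wfd_006_gapless50_updated.pckl')

    # Save with the same protocol as in the cell above.
    with open(path, 'wb') as f:
        pickle.dump(toy_dataset, f, pickle.HIGHEST_PROTOCOL)

    # Read the file back and compare it with what was saved.
    with open(path, 'rb') as f:
        reloaded = pickle.load(f)

round_trip_ok = reloaded == toy_dataset
```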

[Go back to top.](#index)

## 5. Light Curve Comparison<a name="comparison"></a>

Here we show the difference between one original light curve and the transformed one.

In [None]:
obs_after = dataset.data[obj_show]

In [None]:
sndata.plot_lc(obs_before)
plt.axvspan(xmin=729, xmax=849, color='gray', alpha=.3)
plt.title('Before')

In [None]:
sndata.plot_lc(obs_after, False)
plt.title('After')

[Go back to top.](#index)