# Filter Plots 

In [None]:
import numpy as np
from dataframe import Data
%matplotlib inline
import matplotlib.pyplot as plt
import matplotlib 
from io_utils import *
from random import random
from filter_transients import *
from matplotlib.patches import Rectangle

## Filter Region

In [None]:
from LightCurve import LightCurve

In [None]:
# Load the PLAsTiCC sample and set up the six bands (0-5) with one plot colour each.
dataset = "PLAsTiCC"
data_ob = load_PLAsTiCC_data()
object_ids = data_ob.get_all_object_ids()
# Event type 64 is the type this notebook collects as kilonovae.
kilonova_ids = data_ob.get_ids_of_event_type(64)
bands = list(range(6))
color_band_dict = dict(zip(bands, ('C4', 'C2', 'C3', 'C1', 'k', 'C5')))
# transient_filter_load_saved(data_ob=data_ob)

In [None]:
# Visualise the transient filter's excluded regions for a single PLAsTiCC object:
# a vertical band of +/- `exclusion_half_width` days around the median date of
# the first priority region, and a horizontal band |flux| < `flux_threshold`.
object_id = 4220
exclusion_half_width = 10  # days on each side of the median date
flux_threshold = 100  # in the plot's y-axis (flux) units

lc = LightCurve(data_ob, object_id)
# Single colour for every band so the shaded regions stand out.
# NOTE: this rebinds the notebook-level `color_band_dict`, which later cells also use.
color_band_dict = {band: 'C4' for band in range(6)}
fig = lc.plot_max_flux_regions(color_band_dict=color_band_dict, plot_points=True,
                               priority=1, mark_maximum=False)
ax = fig.gca()

# Look the type up once (for THIS object_id) and reuse it for the annotation and printout.
event_type_label = data_ob.get_object_type_for_PLAsTiCC(object_id=object_id)
ax.annotate("Event ID: " + str(object_id), xy=(.14, .81), xycoords='figure fraction', fontsize=15)
ax.annotate("Type: " + event_type_label, xy=(.14, .76), xycoords='figure fraction', fontsize=15)

# NOTE(review): pr[0] presumably holds the dates of the highest-priority region — confirm in LightCurve.
pr = lc.find_region_priority()
median = np.median(pr[0])
window_lo = median - exclusion_half_width
window_hi = median + exclusion_half_width
ax.axvline(median, ymin=0, ymax=1, label="median day", color='royalblue')
ax.axvline(window_lo, ymin=0, ymax=1, color='red')
ax.axvline(window_hi, ymin=0, ymax=1, color='red')
ax.axhline(flux_threshold, xmin=0, xmax=1, color='red')
ax.axhline(-flux_threshold, xmin=0, xmax=1, color='red')

# Shade both excluded regions across the current axis limits.
xlim = ax.get_xlim()
ylim = ax.get_ylim()
rect1 = Rectangle((window_lo, ylim[0]), height=ylim[1] - ylim[0],
                  width=2 * exclusion_half_width,
                  color='lightgrey', alpha=.5, label="excluded region")
rect2 = Rectangle((xlim[0], -flux_threshold), height=2 * flux_threshold,
                  width=xlim[1] - xlim[0],
                  color='lightgrey', alpha=.5)
ax.add_patch(rect1)
ax.add_patch(rect2)
ax.legend()
# fig.savefig("important_plots/exclude4")

print(event_type_label)
print(object_id)

## Penalty Distribution

In [None]:
# Dataset selector for the loader cells below: 0 -> PLAsTiCC, 1 -> ZTF.
dataset_val = 1

In [None]:
# Loader for dataset_val == 0: the PLAsTiCC sample with six bands (0-5).
if dataset_val == 0:
    dataset = "PLAsTiCC"
    data_ob = load_PLAsTiCC_data()
    object_ids = data_ob.get_all_object_ids()
    # Event type 64 is the type this notebook collects as kilonovae.
    kilonova_ids = data_ob.get_ids_of_event_type(64)
    bands = list(range(6))
    color_band_dict = dict(zip(bands, ('C4', 'C2', 'C3', 'C1', 'k', 'C5')))

In [None]:
# Loader for dataset_val == 1: the ZTF training sample with g and r bands.
if dataset_val == 1:
    dataset = "ZTF"
    data_ob = load_ztf_train_data()
    # NOTE(review): this relies on `.sort(keys)` reordering each table in place
    # (as astropy's Table.sort does); its return value is discarded. Confirm the
    # table type — current pandas DataFrames have no `.sort` method.
    for table, sort_keys in ((data_ob.df_data, ['SNID', 'MJD']),
                             (data_ob.df_metadata, ['SNID'])):
        table.sort(sort_keys)
    # load_ztf_mixed() was tried here instead and raised an error.
    object_ids = data_ob.get_all_object_ids()
    bands = ['g', 'r']
    color_band_dict = {band: colour for band, colour in zip(bands, ('C2', 'C3'))}

In [None]:
# Display the per-observation table of the currently loaded dataset for inspection.
data_ob.df_data

In [None]:
# Periodicity penalty over each object's full light curve, bucketed twice:
# transient vs. periodic, and kilonova-type vs. everything else.
transient_events_penalty, periodic_events_penalty = [], []
kn_events_penalty, nonkn_events_penalty = [], []

kn_type_nums = [150, 151]  # object type numbers counted as kilonovae in this cell

for object_id in object_ids:
    id_mask = data_ob.df_data[data_ob.object_id_col_name] == object_id
    object_df = data_ob.df_data[id_mask]
    penalty = calc_priodic_penalty(data_ob, object_df)

    # Any truthy is_transient() result counts as transient; everything else as periodic.
    periodicity_bucket = (transient_events_penalty if data_ob.is_transient(object_id)
                          else periodic_events_penalty)
    periodicity_bucket.append(penalty)

    is_kn = data_ob.get_object_type_number(object_id) in kn_type_nums
    (kn_events_penalty if is_kn else nonkn_events_penalty).append(penalty)
    
    

In [None]:
# Compare full-light-curve penalty distributions: transient vs. periodic objects.
plot_distribution(transient_events_penalty, periodic_events_penalty, label1="Transient", label2 ="Periodic")
plt.show()

In [None]:
# Compare full-light-curve penalty distributions: kilonova-type vs. other objects.
# NOTE(review): cut=8 is forwarded to plot_distribution; presumably a penalty threshold to mark — confirm.
plot_distribution(kn_events_penalty, nonkn_events_penalty, label1="Kilonovae", label2 ="Non-kilonovae", cut=8)
plt.show()

In [None]:
# Same penalty study as the full-light-curve cell, but each object's light curve
# is first truncated to its final `lookback_days` days of observations.
transient_events_penalty = []
periodic_events_penalty = []
kn_events_penalty = []
nonkn_events_penalty = []

lookback_days = 365
# Kilonova type numbers depend on the loaded dataset: 64 is the PLAsTiCC code
# (cf. get_ids_of_event_type(64) in the PLAsTiCC loader), while the
# full-light-curve cell uses 150/151 for the other (ZTF) sample.
# NOTE(review): confirm 150/151 are the ZTF kilonova type numbers.
kn_type_nums = [64] if dataset == "PLAsTiCC" else [150, 151]

for object_id in object_ids:
    object_df = data_ob.df_data[data_ob.df_data[data_ob.object_id_col_name] == object_id]
    times = object_df[data_ob.time_col_name]
    # Keep only observations within `lookback_days` of this object's latest one.
    object_df = object_df[times >= np.amax(times) - lookback_days]
    penalty = calc_priodic_penalty(data_ob, object_df)

    transient_flag = data_ob.is_transient(object_id)
    # Only explicit 1/0 flags are bucketed; any other value goes into neither list.
    if transient_flag == 1:
        transient_events_penalty.append(penalty)
    elif transient_flag == 0:
        periodic_events_penalty.append(penalty)

    if data_ob.get_object_type_number(object_id) in kn_type_nums:
        kn_events_penalty.append(penalty)
    else:
        nonkn_events_penalty.append(penalty)
    

In [None]:
# Compare penalty distributions computed on the truncated (last-year) light curves:
# transient vs. periodic objects.
plot_distribution(transient_events_penalty, periodic_events_penalty, label1="Transient", label2 ="Periodic")
plt.show()

In [None]:
# Compare penalty distributions computed on the truncated (last-year) light curves:
# kilonova-type vs. other objects.
# NOTE(review): cut=8 is forwarded to plot_distribution; presumably a penalty threshold to mark — confirm.
plot_distribution(kn_events_penalty, nonkn_events_penalty, label1="Kilonovae", label2 ="Non-kilonovae", cut=8)
plt.show()

# Light Curve Predictions

In [None]:
# Build a light-curve predictor for every object.
# `Predict_lc` is only the module name and is never bound, so import the class
# it provides and construct it the same way the plotting loop below does.
from Predict_lc import PredictLightCurve

for ob_id in object_ids:
    pc = PredictLightCurve(data_ob, object_id=ob_id)

In [None]:
from Predict_lc import PredictLightCurve

In [None]:
# Predictor for the currently bound `object_id`; `Predict_lc` itself is never
# bound (only `PredictLightCurve` is imported from that module).
pc = PredictLightCurve(data_ob, object_id=object_id)

In [None]:
# Settings for per-object light-curve prediction and plotting.
decouple_prediction_bands = True
decouple_pc_bands = True
mark_maximum = False
use_filter = False
min_flux_threshold = 20
num_pc_components = 3
use_random_current_date = False
# NOTE(review): forces the PLAsTiCC type-name lookup below regardless of which
# loader cell ran — confirm this matches the loaded `data_ob`.
dataset_val = 0

arr103 = []  # collector for hand-picked object ids (only filled by commented-out code)
for object_id in object_ids:
    event_df = data_ob.get_data_of_event(object_id)
    pc = PredictLightCurve(data_ob, object_id=object_id)

    current_date = None
    if use_random_current_date:
        # Draw a "current date" uniformly within THIS object's observed time span.
        # Use pc.lc (this object's light curve), not the notebook-level `lc`,
        # which belongs to the single object plotted in the filter-region cell.
        obs_times = pc.lc.df[data_ob.time_col_name]
        current_min = np.amin(obs_times)
        current_max = np.amax(obs_times)
        current_date = int(random() * (current_max - current_min) + current_min)

    coeff_dict, num_pts_dict = pc.predict_lc_coeff(current_date=current_date,
                                                   num_pc_components=num_pc_components,
                                                   decouple_pc_bands=decouple_pc_bands,
                                                   decouple_prediction_bands=decouple_prediction_bands,
                                                   min_flux_threshold=min_flux_threshold,
                                                   bands=bands)
    object_type_num = data_ob.get_object_type_number(object_id)

    if dataset_val == 0:
        object_type = data_ob.get_object_type_for_PLAsTiCC(object_id)
    else:
        # No name lookup for other datasets: label with the numeric type so the
        # annotation never raises NameError or shows a previous object's type.
        object_type = str(object_type_num)

    fig = pc.plot_predicted_bands(all_band_coeff_dict=coeff_dict,
                                  color_band_dict=color_band_dict,
                                  mark_maximum=mark_maximum,
                                  axes_lims=True)
    fig.gca().annotate("Type: " + object_type, xy=(.09, .86), xycoords='figure fraction', fontsize=15)
    plt.show()
    plt.close('all')