In [None]:
import numpy as np
from dataframe import Data
from SNANA_FITS_to_pd import read_fits
%matplotlib inline
import matplotlib.pyplot as plt
import matplotlib 
from io_utils import *
from random import random
%load_ext memory_profiler

In [None]:
# Dataset selector read by the loading cells: 0 loads PLAsTiCC, 1 loads ZTF.
dataset_val = 0

In [None]:
# PLAsTiCC configuration: only active when dataset_val is 0.
if dataset_val == 0:
    dataset = "PLAsTiCC"
    data_ob = load_PLAsTiCC_data()
    object_ids = data_ob.get_all_object_ids()
    # Objects of event type 64 are the ones this notebook treats as kilonovae.
    kilonova_ids = data_ob.get_ids_of_event_type(64)
    # Matplotlib colour used for each integer band key (0..5).
    color_band_dict = dict(zip(range(6), ('C4', 'C2', 'C3', 'C1', 'k', 'C5')))

In [None]:
# ZTF configuration: only active when dataset_val is 1.
if dataset_val == 1:
    dataset = "ZTF"
    data_ob = load_ztf_data()
    object_ids = data_ob.get_all_object_ids()
    # No separate selection here: every loaded object is used as a "kilonova" id.
    kilonova_ids = object_ids
    # Band keys are byte strings that carry a trailing space.
    color_band_dict = dict([(b'g ', 'C2'), (b'r ', 'C3')])

In [None]:
from LightCurve import LightCurve

In [None]:
def get_PCs(num_components, all_bands= False):
    """Load principal components (PCs) from disk and return them keyed by band.

    Parameters
    ----------
    num_components : int
        Number of leading components (rows) to keep for every band.
    all_bands : bool, optional
        If true, read the per-band dictionary stored in
        ``principal_components/PC_all_bands_diff_mid_pt_dict.npy``.
        Otherwise read the single array in ``principal_components/PCs.npy``
        and reuse it for every key of ``data_ob.band_map``.

    Returns
    -------
    dict
        Maps each band key to the first ``num_components`` rows of its PCs.
    """
    if all_bands:
        # The file holds a pickled dict wrapped in a 0-d object array. NumPy
        # (>= 1.16.3) refuses to load object arrays unless allow_pickle=True.
        # Unpickling can execute arbitrary code: only load trusted files.
        PC_dict = np.load(
            "principal_components/PC_all_bands_diff_mid_pt_dict.npy",
            allow_pickle=True,
        ).item()
        # NOTE(review): integer keys 1, 2, 3 are mapped to 'r', 'i', 'g'.
        # Kept as in the original; confirm this matches the band numbering
        # used by the dataset.
        band_letter_for_key = {0: 'u', 1: 'r', 2: 'i', 3: 'g', 4: 'z', 5: 'Y'}
        # Honour num_components here too (the slice used to be hard-coded to 3).
        PC_out = {
            key: PC_dict[letter][0:num_components]
            for key, letter in band_letter_for_key.items()
        }
    else:
        PCs = np.load("principal_components/PCs.npy")
        PC_out = {band: PCs[0:num_components] for band in data_ob.band_map.keys()}

    return PC_out

In [None]:
def calc_prediction(coeff, PCs):
    """Return the linear combination of principal components weighted by coeff.

    Parameters
    ----------
    coeff : sequence of float
        One weight per component. If its length differs from the number of
        components, the extra entries on either side are ignored.
    PCs : 2-D array-like
        One component per row.

    Returns
    -------
    numpy.ndarray
        1-D array with one entry per column of ``PCs``.
    """
    PCs = np.asarray(PCs)
    # Start from a real zero vector. The previous np.zeros_like(PCs.shape[1])
    # produced a 0-d zero, so the result was a scalar when nothing was added.
    predict_comb = np.zeros(PCs.shape[1])
    for component, weight in zip(PCs, coeff):
        predict_comb = predict_comb + weight * component
    return predict_comb

In [None]:
def calc_loss(coeff, PCs, light_curve_seg):
    """Sum of squared residuals between a light-curve segment and its PC fit.

    Only bins that hold an observation (non-zero entry in light_curve_seg) or
    where the reconstruction is negative contribute to the sum.
    """
    model = calc_prediction(coeff, PCs)
    contributes = (light_curve_seg != 0) | (model < 0)
    residual = (light_curve_seg - model)[contributes]
    # The second positional argument of np.square is the output buffer, so
    # this squares the (freshly copied) residuals in place.
    np.square(residual, residual)
    return np.sum(residual)

In [None]:
def get_mid_pt(event_df, bands, current_date=None, color_band_dict=None):
    """Estimate the date around which an event's light curve is centred.

    Parameters
    ----------
    event_df : table of observations for one event
        Indexed by the column names held in ``data_ob``.
    bands : iterable
        Band keys to consider when ``current_date`` is given.
    current_date : float, optional
        If given, only observations from the 50 days up to and including
        this date are used.
    color_band_dict : optional
        Unused; accepted so existing calls keep working.

    Returns
    -------
    float or None
        With ``current_date``: the median over bands of the time of maximum
        flux, or None if no band has data in the window.
        Without it: the median of the first region returned by
        ``LightCurve.find_region_priority``.
    """
    if current_date is None:
        lc = LightCurve(event_df, time_col_name=data_ob.time_col_name, brightness_col_name=data_ob.flux_col_name, brightness_err_col_name=data_ob.flux_err_col_name,band_col_name=data_ob.band_col_name, band_map=data_ob.band_map)
        priority_regions = lc.find_region_priority()
        return np.median(priority_regions[0])

    date_difference = event_df[data_ob.time_col_name] - current_date
    recent_df = event_df[(date_difference >= -50) & (date_difference <= 0)]

    band_mid_points = []
    for band in bands:
        band_df = recent_df[recent_df[data_ob.band_col_name] == band]
        if len(band_df) > 0:
            # Work on plain arrays so the argmax position is used positionally.
            # Indexing a filtered Series with that position is a label lookup
            # and can raise or return the wrong row.
            times = np.asarray(band_df[data_ob.time_col_name])
            fluxes = np.asarray(band_df[data_ob.flux_col_name])
            band_mid_points.append(times[np.argmax(fluxes)])

    if band_mid_points:
        return np.median(np.array(band_mid_points))
    return None

In [None]:
# Show how many object ids were loaded for the selected dataset.
print(len(object_ids))

def get_binned_time(df):
    """Return df's time column with every value reduced by its remainder modulo 2."""
    times = df[data_ob.time_col_name]
    return times - times % 2

In [None]:
from scipy.optimize import minimize

In [None]:
def get_time_segment(event_df, start_date, end_date, current_date=None):
    """Return the rows of event_df observed between start_date and end_date.

    Both bounds are inclusive. When current_date is given, rows observed
    after current_date are dropped as well.
    """
    times = event_df[data_ob.time_col_name]
    in_window = (times >= start_date) & (times <= end_date)
    if current_date is not None:
        in_window = in_window & (times <= current_date)
    return event_df[in_window]

In [None]:
def optimize_coeff(band_df, mid_point_date, current_date, PCs, no_of_predicted_days=51, time_step = 2):
    """Fit PC coefficients to one band's observations around mid_point_date.

    Parameters
    ----------
    band_df : table of observations for a single band
    mid_point_date : float
        Centre of the fitting window.
    current_date : float
        Observations later than this date are ignored.
    PCs : 2-D array
        Principal components, one per row.
    no_of_predicted_days : int, optional
        Number of bins in the fitted light-curve segment.
    time_step : int, optional
        Bin width in days, used for the window half-width.

    Returns
    -------
    numpy.ndarray or list
        The optimised coefficients (``result.x`` from ``scipy.optimize.minimize``),
        or an empty list when no observation falls inside the window.
    """
    if len(band_df) == 0:
        return []

    half_window = (no_of_predicted_days-1)*time_step/2
    times = band_df[data_ob.time_col_name]
    in_window = (times >= mid_point_date - half_window) & (times <= mid_point_date + half_window)
    fit_df = band_df[in_window & (times <= current_date)]
    if len(fit_df) == 0:
        return []

    # Map every observation onto its bin in the segment. The division by 2
    # matches the 2-day binning done by get_binned_time.
    binned_dates = get_binned_time(fit_df)
    b2 = ((binned_dates-mid_point_date+no_of_predicted_days-1)/2).astype(int)
    light_curve_seg = np.zeros((no_of_predicted_days))
    light_curve_seg[b2[:]] = fit_df[data_ob.flux_col_name]

    initial_guess = [.93,.03 ,.025]
    result = minimize(calc_loss, initial_guess, args=(PCs, light_curve_seg))
    # The statements that used to follow this return could never run and
    # have been removed.
    return result.x
            

In [None]:
def predict_lc_coeff(event_df, PC_dict, current_date= None, no_of_predicted_days = 51, time_step=2, bands=None, make_plot=False):
    """Fit PC coefficients for every band of one event.

    The fit uses observations within a window centred on the date returned
    by get_mid_pt. Returns a dict mapping each band to its fitted
    coefficients, with [0, 0, 0] for bands that have no data in the window.
    The dict is empty when get_mid_pt returns None. make_plot is not used.
    """
    if bands is None:
        bands = data_ob.band_map.keys()

    fitted = {}
    mid_point_date = get_mid_pt(event_df, bands, current_date, color_band_dict)
    if mid_point_date is None:
        return fitted

    half_window = (no_of_predicted_days-1)*time_step/2
    window_df = get_time_segment(
        event_df,
        mid_point_date - half_window,
        mid_point_date + half_window,
        current_date,
    )

    for band in bands:
        band_df = window_df[window_df[data_ob.band_col_name] == band]
        PCs = PC_dict[band]
        if len(band_df) == 0:
            fitted[band] = [0, 0, 0]
            continue

        # Place each observed flux into its bin of the segment.
        bin_index = (get_binned_time(band_df)-mid_point_date+no_of_predicted_days-1)/2
        bin_index = bin_index.astype(int)
        light_curve_seg = np.zeros((no_of_predicted_days))
        light_curve_seg[bin_index[:]] = band_df[data_ob.flux_col_name]

        # Scale the starting point by the band's peak flux.
        peak_flux = np.amax(band_df[data_ob.flux_col_name])
        initial_guess = peak_flux*np.array([.93,.03 ,.025])
        fitted[band] = minimize(calc_loss, initial_guess, args=(PCs, light_curve_seg)).x

    return fitted
            

In [None]:
def plot_predicted_bands(all_band_coeff_dict, PC_dict, current_date=None, bands=None, num_buffer_days = None):
    """Plot an event's light curve and overlay the PC reconstruction per band.

    NOTE(review): `event_df` is not a parameter. It is read from the enclosing
    (notebook-global) scope, so this plots whichever event was loaded last.
    The function has no return statement, so it returns None.

    Parameters
    ----------
    all_band_coeff_dict : dict
        Maps a band key to its fitted coefficients; an empty sequence means
        there is nothing to draw for that band.
    PC_dict : dict
        Maps a band key to the principal components passed to calc_prediction.
    current_date : float, optional
        If given, highlighted data stop at this date and a dashed
        "current date" marker is drawn.
    bands : iterable, optional
        Band keys forwarded to get_mid_pt; defaults to data_ob.band_map.keys().
    num_buffer_days : float, optional
        If given, x-limits are set to the +/-50 day window around the
        mid-point, widened by this many days on each side.
    """
    if bands is None: 
        bands = data_ob.band_map.keys()
    mid_point_date = get_mid_pt(event_df, bands, current_date, color_band_dict)
    
    lc = LightCurve(event_df, time_col_name=data_ob.time_col_name, brightness_col_name=data_ob.flux_col_name, brightness_err_col_name=data_ob.flux_err_col_name,band_col_name=data_ob.band_col_name, band_map=data_ob.band_map)
    # First pass: the full light curve, drawn faintly (alpha=0.3).
    fig = lc.plot_light_curve(color_band_dict=color_band_dict, alpha=0.3, mark_maximum = False, mark_label= False, plot_points = True)
    
    if mid_point_date is not None:
        
        for band, coeff in all_band_coeff_dict.items():
            
            # Highlighted data end 50 days after the mid-point, or at current_date if given.
            if current_date is None:
                end_date = mid_point_date +50
            else:
                end_date = current_date

            # Second pass: this band's points inside the window, fully opaque, on the same figure.
            fig = lc.plot_light_curve(color_band_dict, fig = fig, start_date= mid_point_date -50, end_date=end_date, band = band, alpha=1, mark_maximum = False, plot_points = True)

            if len(coeff)!=0:
                predicted_lc= calc_prediction(coeff,PC_dict[band])
                #plt.plot(x_data, predicted_lc, color = color_band_dict[band])
                # 51 time stamps, 2 days apart, spanning mid_point_date - 50 .. mid_point_date + 50.
                time_data= np.arange(0,102,2) + mid_point_date - 50
            else: 
                predicted_lc=[]
                time_data=[]
                
            

            plt.plot(time_data, predicted_lc, color = color_band_dict[band])

            
        
            
        if num_buffer_days is not None: 
            plt.xlim([mid_point_date-50-num_buffer_days, mid_point_date+50+num_buffer_days])
            
            
        # Read the y-limits only after everything is drawn; the marker spans ymin/2 .. ymax/2.
        _, _, ymin, ymax = plt.axis()    
        plt.plot([mid_point_date,mid_point_date],[ymin/2,ymax/2],color = "slateblue", ls="dashed", label="median of max dates")
    
    if current_date is not None:
        _, _, ymin, ymax = plt.axis() 
        plt.plot([current_date,current_date],[ymin/2,ymax/2],color= "darkorange", ls="dashed", label="current date")
        
    plt.xlabel("mjd", fontsize=20)
    plt.ylabel("flux", fontsize=20)
    

In [None]:
import time

In [None]:
# Show how many object ids the fitting loops will iterate over.
print(len(object_ids))

In [None]:
def add_to_coeff_arr(coeff_arr, coeff_dict):
    """Flatten coeff_dict into one vector and append it to coeff_arr.

    The per-band values are concatenated in dictionary order. If the result
    is non-empty it is appended to coeff_arr, and the type number of the
    notebook-global `object_id` is appended to the notebook-global
    `event_type` list. Returns coeff_arr (the same list object).
    """
    flat = np.array([])
    for band_coeffs in coeff_dict.values():
        if flat.size:
            flat = np.concatenate((flat, band_coeffs), axis = 0)
        else:
            flat = np.asarray(band_coeffs)

    if flat.size:
        coeff_arr.append(flat)
        event_type.append(data_ob.get_object_type_number(object_id))
    return coeff_arr

In [None]:
def plot_coeff_dict(fig, coeff_dict, object_id):
    """Scatter the first three coefficients of every band on fig's current axes.

    Points are black when object_id has type number 64 and yellow otherwise.
    Returns fig.
    """
    axes = fig.gca()
    for band_coeffs in coeff_dict.values():
        is_type_64 = data_ob.get_object_type_number(object_id) == 64
        axes.scatter(
            band_coeffs[0], band_coeffs[1], band_coeffs[2],
            color="black" if is_type_64 else "yellow",
            alpha=.4,
        )
    return fig
        

In [None]:
# Fit PC coefficients for every object at a randomly drawn "current date",
# scatter them in 3-D and collect them in coeff_arr / event_type.
PC_dict = get_PCs(3,all_bands=False)
start = time.time()
coeff_arr = []
event_type = []
fig = plt.figure()
# Figure.gca(projection=...) was deprecated in Matplotlib 3.4 and later
# removed; add_subplot creates the 3-D axes directly.
ax = fig.add_subplot(projection='3d')

# NOTE(review): ids in kilonova_ids that are also in object_ids are fitted
# twice (each time with a fresh random date), so they appear twice in
# coeff_arr. Kept as in the original; confirm this is intended.
for ids in (object_ids, kilonova_ids):
    # The loop variable must be named `object_id`: add_to_coeff_arr reads it
    # as a global.
    for object_id in ids:
        event_df = data_ob.get_data_of_event(object_id)
        mid_point_date = get_mid_pt(event_df, data_ob.band_map.keys())
        # Pretend "today" lies within +/-44 days of the mid-point.
        current_date = mid_point_date + random()*88 - 44
        coeff_dict = predict_lc_coeff(event_df, PC_dict, current_date=current_date)

        # plot_coeff_dict is the helper defined in this notebook; the name
        # plot_coeffs used before does not exist and raised NameError.
        fig = plot_coeff_dict(fig, coeff_dict, object_id)
        coeff_arr = add_to_coeff_arr(coeff_arr, coeff_dict)

plt.show()

In [None]:
# Persist the collected coefficient vectors (np.save appends ".npy" to the name).
np.save("coeff_arr", coeff_arr)

In [None]:
# NOTE(review): this loop has no effect; `coeffs` is reassigned on every
# step and never used. The slice element[k:k+3] advances by one entry per
# step, so consecutive windows overlap. If one window of three coefficients
# per band was intended, that would be element[3*k:3*k+3]. Confirm intent.
for element in coeff_arr:
    for k in range(6):
        coeffs = element[k:k+3]
        

In [None]:
def plot_coeff_arr(fig, coeff_arr, object_id):
    """Scatter per-band coefficient triples on fig's current axes; return fig.

    Points are black when object_id has type number 64 and yellow otherwise.

    NOTE(review): the body iterates the notebook-global `coeff_dict`; the
    `coeff_arr` argument is accepted but never read. This mirrors
    plot_coeff_dict and looks unfinished. Confirm the intended data source
    before relying on this function.
    """
    # The signature previously read "o object_id", which is a SyntaxError.
    ax = fig.gca()
    for key, value in coeff_dict.items():
        if data_ob.get_object_type_number(object_id) == 64:
            color = "black"
        else:
            color = "yellow"
        ax.scatter(value[0], value[1], value[2], color = color,alpha=.4)

    return fig

In [None]:
# Same fit as for the full object list, restricted to kilonova_ids.
# coeff_arr and event_type are reset here, so earlier results are discarded.
PC_dict = get_PCs(3,all_bands=False)
start = time.time()
coeff_arr = []
event_type = []
fig = plt.figure()
# Figure.gca(projection=...) was deprecated in Matplotlib 3.4 and later
# removed; add_subplot creates the 3-D axes directly.
ax = fig.add_subplot(projection='3d')

# The loop variable must be named `object_id`: add_to_coeff_arr reads it as
# a global.
for object_id in kilonova_ids:
    event_df = data_ob.get_data_of_event(object_id)
    mid_point_date = get_mid_pt(event_df, data_ob.band_map.keys())
    # Pretend "today" lies within +/-44 days of the mid-point.
    current_date = mid_point_date + random()*88 - 44
    coeff_dict = predict_lc_coeff(event_df, PC_dict, current_date=current_date)

    # plot_coeff_dict is the helper defined in this notebook; the name
    # plot_coeffs used before does not exist and raised NameError.
    fig = plot_coeff_dict(fig, coeff_dict, object_id)
    coeff_arr = add_to_coeff_arr(coeff_arr, coeff_dict)

plt.show()

In [None]:
# Turn the list of coefficient vectors into an ndarray.
# NOTE(review): assumes every vector has the same length, so the result is 2-D. TODO confirm.
coeff_arr = np.asarray(coeff_arr)

In [None]:
# Inspect the third entry of coeff_arr.
coeff_arr[2]

In [None]:
from sklearn.decomposition import PCA

In [None]:
# Fit a 2-component PCA on the transposed coefficient array and keep its
# component matrix (2 rows, one column per column of coeff_arr.T).
pca = PCA(n_components=2).fit(coeff_arr.T)
PC_coeff = pca.components_

In [None]:
# Transpose so that each row holds the two PCA values of one column.
# NOTE: this cell is not idempotent; running it a second time flips
# PC_coeff back to its previous orientation.
PC_coeff = PC_coeff.T

In [None]:
# Check the orientation of PC_coeff after the transpose.
PC_coeff.shape

In [None]:
# np.save needs both a file name and the array to write; the previous call
# passed only "PC_" and raised TypeError.
# NOTE(review): the file name "PC_coeff" is a guess at the unfinished "PC_". Confirm.
np.save("PC_coeff", PC_coeff)

In [None]:
from mpl_toolkits import mplot3d

In [None]:
# Scatter the two PCA values of every object: all types other than 64 first
# (yellow), then type 64 (black) so that those points end up on top.
fig = plt.figure()

for object_type in np.unique(data_ob.df_metadata[data_ob.target_col_name]):
    if object_type == 64:
        # Type 64 is drawn after this loop.
        continue
    rows_of_type = PC_coeff[np.where(np.asarray(event_type)==object_type)]
    for point in rows_of_type:
        plt.scatter(point[0], point[1], color="yellow", alpha=.7)

index_current_type = np.where(np.asarray(event_type)==64)
current_type_all_coeffs = PC_coeff[index_current_type]
for current_type_coeffs in current_type_all_coeffs:
    plt.scatter(current_type_coeffs[0], current_type_coeffs[1], color="black", alpha=.7)

plt.xlim((-0.005,.005))
plt.ylim(-.005,0.015)
plt.show()

In [None]:
# Print and scatter only the rows of PC_coeff that belong to type 64.
index_current_type = np.where(np.asarray(event_type)==64)
current_type_all_coeffs = PC_coeff[index_current_type]
print(current_type_all_coeffs)
for point in current_type_all_coeffs:
    plt.scatter(point[0], point[1], color="black", alpha=.5)
plt.show()

In [None]:
# Debug: positions in event_type that were selected as type 64.
print(index_current_type)

In [None]:
# Debug: recompute the type-64 positions directly from event_type.
print(np.where(np.asarray(event_type)==64))

In [None]:
# Debug: show the collected type numbers (one per entry appended to coeff_arr).
event_type

In [None]:
# Debug: number of collected type numbers.
len(event_type)