In [None]:
import os

import numpy as np
from astropy.table import Table
import matplotlib.pyplot as plt

In [None]:
def extract_kilonova_ids(file_path="/media/biswajit/drive/PLAsTiCC_data/training_set_metadata.csv"):
    """Return the ``object_id`` values of the metadata rows whose ``target`` is 64.

    Parameters
    ----------
    file_path : str
        Path of the comma-delimited metadata file handed to ``Table.read``.

    Returns
    -------
    numpy.ndarray
        The ``object_id`` column restricted to rows with ``target == 64``
        (presumably the kilonova class, going by the function name).
    """
    metadata = Table.read(file_path, delimiter=",")
    is_target_64 = metadata['target'] == 64
    selected_rows = metadata[is_target_64]
    return np.array(selected_rows['object_id'])

In [None]:
# Collect the object ids returned by extract_kilonova_ids() (default metadata path)
# and print them for inspection.
kilonova_ids=extract_kilonova_ids()
print(kilonova_ids)

In [None]:
def load_data(file_path="/media/biswajit/drive/PLAsTiCC_data/training_set.csv"):
    """Read the file at ``file_path`` with ``Table.read`` and return the resulting table.

    Parameters
    ----------
    file_path : str
        Location of the file to read; no explicit format or delimiter is passed.

    Returns
    -------
    astropy.table.Table
        The table exactly as produced by ``Table.read``.
    """
    return Table.read(file_path)

In [None]:
# Load the table returned by load_data() (default path) into the notebook-wide
# `df` and print it for inspection.
df = load_data()
print(df)

In [None]:
# Distinct values of the 'passband' column present in the loaded table.
pass_bands=np.unique(df['passband'])
# Passband number -> colour string; the plotting functions in this notebook
# carry the same mapping as the default of their `pass_band_dict` parameter.
pass_band_dict = {0:'C1' , 1:'C2' , 2:'C3' , 3:'C4' , 4:'k' , 5:'C5'}

In [None]:
from statistics import median


def get_max_flux_dates(df, band_num=None):
    '''
    Find, per passband, the observation with the largest flux.

    Parameters
    ----------
    df : table-like
        Must support ``df['passband']``, ``df['mjd']``, ``df['flux']`` and
        boolean-mask row selection (e.g. an astropy Table or a numpy
        structured array).
    band_num : int, optional
        When given, only this passband is examined; otherwise every distinct
        value found in ``df['passband']`` is examined.

    Returns
    -------
    max_flux_points : dict
        Maps passband -> ``[mjd_of_max_flux, max_flux]``.
    max_flux_dates : list
        The ``mjd_of_max_flux`` values, in the order the bands were examined.

    A band without any rows contributes nothing to either result.
    '''
    if band_num is None:
        bands_to_scan = np.unique(df['passband'])
    else:
        bands_to_scan = [band_num]

    # Both containers are created up front so the single-band path can fill
    # them as well (previously they only existed in the all-bands path).
    max_flux_dates = []
    max_flux_points = {}

    for band in bands_to_scan:
        current_band_data = df[df['passband'] == band]
        if len(current_band_data) == 0:
            # np.argmax raises on an empty sequence; skip bands with no data.
            continue

        current_max_index = np.argmax(current_band_data['flux'])
        current_max_date = current_band_data['mjd'][current_max_index]

        max_flux_dates.append(current_max_date)
        max_flux_points[band] = [current_max_date, current_band_data['flux'][current_max_index]]

    return max_flux_points, max_flux_dates

        

In [None]:
def _band_window_mask(df, band, start_date, end_date):
    """Boolean mask of the rows of ``df`` in ``band`` with ``start_date <= mjd <= end_date``."""
    return (df['passband'] == band) & (df['mjd'] >= start_date) & (df['mjd'] <= end_date)


def _mark_max_flux(ax, band, max_flux_points, pass_band_dict):
    """Draw the max-flux marker of ``band`` on ``ax``, or report that the band has no entry."""
    if band not in max_flux_points:
        print("could not find the band number "+str(band)+" in max_flux_points")
        return
    ax.plot(max_flux_points[band][0], max_flux_points[band][1],
            color=pass_band_dict[band], marker='o', markersize=10)


def plot_light_curve(df,band_num = None, start_date=None, end_date=None, max_flux_points=None,_pbnames = ['u','g','r','i','z','y'], pass_band_dict = {0:'C1' , 1:'C2' , 2:'C3' , 3:'C4' , 4:'k' , 5:'C5'}): 
    """Plot flux (with error bars) against mjd for one passband or for all of them.

    Parameters
    ----------
    df : table-like
        Must provide the columns 'passband', 'mjd', 'flux' and 'flux_err' and
        support boolean-mask row selection.
    band_num : int, optional
        Passband to plot. When omitted, every passband present in ``df`` is
        plotted and the x-limits are set to ``[start_date, end_date]``.
    start_date, end_date : float, optional
        Inclusive mjd window. Default to the smallest / largest mjd in ``df``.
    max_flux_points : dict, optional
        Maps passband -> ``[mjd, flux]``; when given, a marker is drawn at that
        point for every plotted band that has an entry. Bands without an entry
        are reported with a printed message.
    _pbnames : list
        Legend label per passband, indexed by passband number.
    pass_band_dict : dict
        Maps passband number -> matplotlib colour. Neither default container
        is modified by this function.

    Returns
    -------
    matplotlib Axes
        The axes that was drawn on. It has been removed from the figure created
        here, so a caller (``plot_max_flux_region`` does this) can attach it to
        another figure.
    """
    fig = plt.figure(figsize=(15,15))
    ax = fig.add_subplot(1,1,1)
    pass_band_nos=np.unique(df['passband'])

    if start_date is None:
        start_date = np.amin(df['mjd'])
    if end_date is None:
        end_date = np.amax(df['mjd'])

    if band_num is not None:
        # `is not None` rather than truthiness: passband 0 is a valid request.
        if np.any(pass_band_nos == band_num):
            index = _band_window_mask(df, band_num, start_date, end_date)

            if np.any(index):
                df_plot_data = df[index]
                ax.errorbar(df_plot_data['mjd'],df_plot_data['flux'],df_plot_data['flux_err'],
                            color=pass_band_dict[band_num],label = _pbnames[band_num])
            else:
                print("the band requested has no data points in the given date range")

            ax.plot([start_date,end_date],[0,0],label='y=0')

            if max_flux_points is not None:
                _mark_max_flux(ax, band_num, max_flux_points, pass_band_dict)

            ax.legend()
        else:
            print("the band requested is not present")

    else:
        data_points_found = False
        for band in pass_band_nos:
            index = _band_window_mask(df, band, start_date, end_date)

            if np.any(index):
                data_points_found = True
                df_plot_data = df[index]
                ax.errorbar(df_plot_data['mjd'],df_plot_data['flux'],df_plot_data['flux_err'],
                            color=pass_band_dict[band],label = _pbnames[band])

            # The marker is drawn even when the band has no points inside the
            # window; the x-limits set below decide whether it is visible.
            if max_flux_points is not None:
                _mark_max_flux(ax, band, max_flux_points, pass_band_dict)

        if not data_points_found:
            print("There are no data points in the given date range")

        ax.plot([start_date,end_date],[0,0],label='y=0')
        ax.set_xlim([start_date,end_date])
        ax.legend()

    ax.set_xlabel("mjd",fontsize=20)
    ax.set_ylabel("flux",fontsize=20)
    # Detach the axes from the figure created above so the caller can re-attach
    # it to a different figure.
    ax.remove()

    return ax
            
    

In [None]:
def plot_max_flux_region(df,max_flux_dates,total_days_range=100, max_flux_points=None, priority =None,pass_band_dict = {0:'C1' , 1:'C2' , 2:'C3' , 3:'C4' , 4:'k' , 5:'C5'}):
    """Draw one light-curve panel per date group, stacked vertically in a single figure.

    Each panel covers ``total_days_range`` days centred on the median of the
    group's dates and is produced by ``plot_light_curve``.

    Parameters
    ----------
    df : table-like
        Observations forwarded to ``plot_light_curve``.
    max_flux_dates : sequence of sequences
        Groups of dates, e.g. the output of ``find_region_priority``. Empty
        groups are skipped.
    total_days_range : number
        Width of each panel's date window.
    max_flux_points : dict, optional
        Forwarded to ``plot_light_curve``.
    priority : int, optional
        When given, the first ``priority`` groups are always plotted; a later
        group is plotted only if it has as many dates as the group before it,
        and plotting stops at the first group that does not.
    pass_band_dict : dict
        Passband -> colour mapping forwarded to ``plot_light_curve``.

    Returns
    -------
    matplotlib Figure
        The figure holding all panels.

    Raises
    ------
    ValueError
        If ``priority`` is given and is not greater than 0.
    """
    if priority is not None and priority <= 0:
        raise ValueError("Error in priority value, priority number must be greater than 0")

    fig = plt.figure(figsize=(16,16))
    panels = []

    for i,ranges in enumerate(max_flux_dates):
        if len(ranges) == 0:
            # median() of an empty group is undefined; there is nothing to centre on.
            continue

        if priority is not None:
            same_size_as_previous = len(ranges) == len(max_flux_dates[i-1])
            if not (i < priority or same_size_as_previous):
                break

        mid_pt = median(ranges)
        start_date = mid_pt - total_days_range/2
        end_date = mid_pt + total_days_range/2

        ax = plot_light_curve(df,start_date=start_date,end_date=end_date,
                              max_flux_points=max_flux_points,pass_band_dict=pass_band_dict)

        # plot_light_curve returns an axes already removed from its own figure;
        # adopt it into this one.
        ax.figure = fig
        fig.add_axes(ax)
        panels.append(ax)

    # Lay the adopted axes out as an (n x 1) column. A throw-away subplot is
    # used only to obtain the slot's position, which avoids the
    # change_geometry() call that newer matplotlib releases no longer provide.
    for row, ax in enumerate(panels, start=1):
        dummy = fig.add_subplot(len(panels), 1, row)
        ax.set_position(dummy.get_position())
        dummy.remove()

    return fig

In [None]:
import copy
def find_region_priority(max_flux_dates,total_days_range=100):
    """Group dates into candidate regions and return the groups, largest first.

    The dates are processed in ascending order (the input itself is left
    untouched). The first date seeds the first region. Every later date is
    appended to each existing region for which, against at least one date
    already in that region, either the gap is at most 14 or the date lies
    within ``total_days_range/2`` of the median of the region with the date
    included. A date that joins no region starts a new one.

    Parameters
    ----------
    max_flux_dates : list
        Dates to group.
    total_days_range : number
        Twice the allowed distance between a date and the region median.

    Returns
    -------
    list of lists
        The regions ordered by decreasing length; regions of equal length keep
        their creation order. An empty input yields ``[[]]``.
    """
    ordered_dates = copy.copy(max_flux_dates)
    ordered_dates.sort()
    regions = [[]]

    for date in ordered_dates:
        # The very first date fills the initially empty region.
        if not regions[0]:
            regions[0].append(date)
            continue

        joined_some_region = False
        for members in regions:
            median_with_date = median(members + [date])
            fits = any(
                (date - member) <= 14 or (date - median_with_date) <= total_days_range/2
                for member in members
            )
            if fits:
                members.append(date)
                joined_some_region = True

        if not joined_some_region:
            regions.append([date])

    regions.sort(key=len, reverse=True)
    return regions
    

In [None]:
# For every id in kilonova_ids: locate the per-band flux maxima, group their
# dates into candidate regions and save a figure of the top-priority region(s).
output_dir = "./kilonova_curves/kilonova_segments"
# savefig does not create missing directories, so make sure the target exists.
os.makedirs(output_dir, exist_ok=True)

for obj_id in kilonova_ids:
    index = df['object_id'] == obj_id
    object_data = df[index]  # select this object's rows once and reuse them

    max_flux_points,max_flux_dates = get_max_flux_dates(object_data)
    probable_regions=find_region_priority(max_flux_dates=max_flux_dates,total_days_range=100)

    fig= plot_max_flux_region(object_data,probable_regions,max_flux_points=max_flux_points,priority=1)
    fig.savefig(output_dir+"/train"+str(obj_id))
    # plot_light_curve opens a new figure on every call; close them all so
    # figures do not accumulate across iterations.
    plt.close('all')