In [None]:
import numpy as np
import os, sys
import math
from scipy import ndimage, misc
from astropy.coordinates import Angle
from astropy.io import fits
import matplotlib.pyplot as plt
from scipy import linalg
import csv
import pickle
from matplotlib.colors import LogNorm
from matplotlib import colors
%matplotlib inline
power = 3
power_dir = "output_neid_test_cluster_form2"
#power_dir = "output_poly3_test"

In [None]:
from AlgorithmDev import OrderTrace

In [None]:
# input: spectral fits is from dropbox KPF-Pipeline-TestData/order_trace_test
spectral_fits='../test_data/order_trace_test/DATA/paras.flatA.fits'

# output
clusters_collection = '../test_data/order_trace_test/'+power_dir+'/clusters_all_y_collection.pkl'
cluster_xy_fits = '../test_data/order_trace_test/'+power_dir+'/cluster_xy.fits'
cluster_clean_fits = '../test_data/order_trace_test/'+power_dir+'/cluster_clean.fits'
cluster_info_clean_fits = '../test_data/order_trace_test/'+power_dir+'/cluster_info_clean.fits'
cluster_after_removal_fits = '../test_data/order_trace_test/'+power_dir+'/cluster_after_removal.fits'
cluster_info_after_removal_fits =  '../test_data/order_trace_test/'+power_dir+'/cluster_info_after_removal.fits'
cluster_border_fits = '../test_data/order_trace_test/'+power_dir+'/cluster_border.fits'
cluster_info_border_fits = '../test_data/order_trace_test/'+power_dir+'/cluster_info_border.fits'
cluster_merge_fitting = '../test_data/order_trace_test/'+power_dir+'/cluster_merge_fitting.fits'
cluster_info_merge_fitting = '../test_data/order_trace_test/'+power_dir+'/cluster_info_merging_fitting.fits'
result_csv = '../test_data/order_trace_test/'+power_dir+'/result_cluster/result_cluster'
result_poly_width_csv = '../test_data/order_trace_test/'+power_dir+'/output/paras_result_poly_2sigma_gaussian_pixel_'+str(power)+'.csv'

In [None]:
# input: spectral fits is from dropbox KPF-Pipeline-TestData/NEIData/FLAT
spectral_fits='../test_data/order_trace_test/DATA/stacked_2fiber_flat.fits'

# output
clusters_collection = '../test_data/order_trace_test/'+power_dir+'/clusters_all_y_collection.pkl'
remove_vertical_fits = '../test_data/order_trace_test/'+power_dir+'/data_remove_vertical.fits'
cluster_xy_fits = '../test_data/order_trace_test/'+power_dir+'/cluster_xy.fits'
cluster_clean_fits = '../test_data/order_trace_test/'+power_dir+'/cluster_clean.fits'
cluster_info_clean_fits = '../test_data/order_trace_test/'+power_dir+'/cluster_info_clean.fits'
cluster_after_removal_fits = '../test_data/order_trace_test/'+power_dir+'/cluster_after_removal.fits'
cluster_info_after_removal_fits =  '../test_data/order_trace_test/'+power_dir+'/cluster_info_after_removal.fits'
cluster_border_fits = '../test_data/order_trace_test/'+power_dir+'/cluster_border.fits'
cluster_info_border_fits = '../test_data/order_trace_test/'+power_dir+'/cluster_info_border.fits'
cluster_merge_fitting = '../test_data/order_trace_test/'+power_dir+'/cluster_merge_fitting.fits'
cluster_info_merge_fitting = '../test_data/order_trace_test/'+power_dir+'/cluster_info_merging_fitting.fits'
result_csv = '../test_data/order_trace_test/'+power_dir+'/result_cluster/result_cluster'
result_poly_width_csv = '../test_data/order_trace_test/'+power_dir+'/output/neid_result_poly_2sigma_gaussian_pixel_0221_'+str(power)+'.csv'

In [None]:
# plot image from fits by setting area xmin, xmax, ymin, ymax
def plot_img(img, ymin, ymax, p_w=20, p_h=20, xmin=None, xmax=None, title="", aspect=None):
    """Show a 2D image in log-scaled gray and zoom to the requested window.

    Parameters
    ----------
    img : 2D array
        Image to display (the whole array is drawn; the window is set with
        axis limits, not by slicing).
    ymin, ymax : number
        Row limits passed to ``plt.ylim``.
    p_w, p_h : number
        Figure width and height in inches.
    xmin, xmax : number, optional
        Column limits; default to 0 and the last column of ``img``.
    title : str
        Figure title.
    aspect : optional
        When given, forwarded to ``Axes.set_aspect`` of the image axes.
    """
    plt.figure(figsize=(p_w, p_h), frameon=False)
    plt.subplot(1, 1, 1)
    if xmin is None:
        xmin = 0
    if xmax is None:
        xmax = np.shape(img)[1] - 1
    plt.imshow(img, cmap='gray', norm=LogNorm())
    plt.ylim(ymin, ymax)
    plt.xlim(xmin, xmax)
    plt.title(title)
    if aspect is not None:
        # gca() is the axes holding the image; a bare plt.axes() would add a
        # new, empty axes on current Matplotlib and leave the image untouched.
        plt.gca().set_aspect(aspect)
    plt.show()

In [None]:
# plot image from fits by setting area xmin, xmax, ymin, ymax on top of another image
# the top image could be image of set of clusters by using make_2D_data_range 
# background image in grey, and top image in red with 0.5 alpha in default. aspect is settable
def plot_img_on_original(img, o_img, ymin, ymax, xmin=None, xmax=None, p_w=20, p_h=20, title="", alpha = 0.5, aspect=None):
    """Overlay a cluster mask on top of a background image.

    The background ``o_img`` is drawn in log-scaled gray; ``img`` is drawn
    over it with a two-colour (white/red) colormap and transparency ``alpha``.

    Parameters
    ----------
    img : 2D array
        Overlay image (e.g. the 0/1 mask built by ``make_2D_data_range``).
    o_img : 2D array
        Background image.
    ymin, ymax : number
        Row limits passed to ``plt.ylim``.
    xmin, xmax : number, optional
        Column limits; default to 0 and the last column of ``img``.
    p_w, p_h : number
        Figure width and height in inches.
    title : str
        Figure title.
    alpha : float
        Opacity of the overlay.
    aspect : optional
        When given, forwarded to ``Axes.set_aspect`` of the image axes.
    """
    plt.figure(figsize=(p_w, p_h), frameon=False)
    plt.subplot(1, 1, 1)
    if xmin is None:
        xmin = 0
    if xmax is None:
        xmax = np.shape(img)[1] - 1

    overlay_cmap = colors.ListedColormap(['white', 'red'])
    plt.imshow(o_img, cmap='gray', norm=LogNorm())
    plt.imshow(img, cmap=overlay_cmap, alpha=alpha)
    plt.ylim(ymin, ymax)
    plt.xlim(xmin, xmax)
    plt.title(title)
    if aspect is not None:
        # gca() is the axes holding both images; a bare plt.axes() would add
        # a new, empty axes on current Matplotlib.
        plt.gca().set_aspect(aspect)
    plt.show()

In [None]:
def make_fits(data, output_fits):
    """Write ``data`` as the primary HDU of ``output_fits``, overwriting any existing file."""
    fits.PrimaryHDU(data).writeto(output_fits, overwrite=True)

In [None]:
# show cluster width and height histogram information on selected clusters or all clusters
def show_cluster_info(index, x, y, nx, ny, selected_clusters=None):
    """Print and return width/height histogram statistics of clusters.

    Parameters
    ----------
    index : 1D int array
        Cluster number of every pixel.
    x, y : 1D arrays
        Pixel coordinates, same length as ``index``.
    nx, ny : int
        Image size; kept for interface compatibility, not used here.
    selected_clusters : iterable of int, optional
        Clusters to examine. ``None`` means every cluster from 1 to
        ``max(index)``.

    Returns
    -------
    (h_stats, w_stats)
        Whatever ``order_t.find_cluster_stats_from_histogram`` returns for the
        heights and for the widths (10 passed as its second argument).
    """
    if selected_clusters is None:
        selected_clusters = np.arange(1, np.amax(index) + 1, dtype=int)

    # Both lists start with a 0 entry, so the histograms always include 0.
    widths = [0]
    heights = [0]

    for sc in selected_clusters:
        idx = np.where(index == sc)[0]
        if idx.size == 0:
            # nothing to measure for a cluster number with no pixels
            continue
        sel_x = x[idx]
        sel_y = y[idx]
        x1 = np.amin(sel_x)
        x2 = np.amax(sel_x)
        y1 = np.amin(sel_y)
        y2 = np.amax(sel_y)
        w = x2 - x1 + 1
        h = y2 - y1 + 1
        # progress sample: only every 100th cluster number is reported
        if sc%100 == 0:
            print('cluster: ', sc, ' total pixels: ', np.size(sel_x), ' x1, x2, y1, y2: ',
                 x1, x2, y1, y2, ' w, h: ', w, h)
        widths.append(w)
        heights.append(h)

    h_stats = order_t.find_cluster_stats_from_histogram(np.array(heights), 10)
    w_stats = order_t.find_cluster_stats_from_histogram(np.array(widths), 10)

    for h_s in h_stats:
        print('height stats: ', h_s)
    print('\n')
    for w_s in w_stats:
        print('widthw stats: ', w_s)

    return h_stats, w_stats

In [None]:
# show cluster size information for each cluster from selected cluster
def show_cluster_size(index, x, y, nx, ny, selected_clusters=None):
    """Print pixel count, width and height of each selected cluster.

    Parameters
    ----------
    index : 1D int array
        Cluster number of every pixel.
    x, y : 1D arrays
        Pixel coordinates, same length as ``index``.
    nx, ny : int
        Image size; kept for interface compatibility, not used here.
    selected_clusters : iterable of int, optional
        Clusters to report. ``None`` means every cluster from 1 to
        ``max(index)``.
    """
    if selected_clusters is None:
        selected_clusters = np.arange(1, np.amax(index)+1, dtype=int)

    for i in selected_clusters:
        cluster_idx = np.where(index == i)[0]
        if cluster_idx.size == 0:
            # amin/amax are undefined on an empty selection
            print('cluster ', i, ' has no pixels')
            continue
        cx = x[cluster_idx]
        cy = y[cluster_idx]
        x1, x2 = np.amin(cx), np.amax(cx)
        y1, y2 = np.amin(cy), np.amax(cy)

        print('cluster ', i, ' total: ', np.size(cluster_idx), " width: ", (x2-x1+1), " height: ", (y2-y1+1))
           

In [None]:
# make image data in 2D based on selected clusters
def make_2D_data(index, x, y, nx, ny, selected_clusters=None):
    """Build a 0/1 image marking the pixels of the selected clusters.

    Parameters
    ----------
    index : 1D int array
        Cluster number of every pixel; 0 marks an unassigned pixel.
    x, y : 1D int arrays
        Pixel column and row, same length as ``index``.
    nx, ny : int
        Width and height of the output image.
    selected_clusters : iterable of int, optional
        Clusters to draw. ``None`` draws every pixel whose index is non-zero.

    Returns
    -------
    numpy.ndarray
        ``(ny, nx)`` array of dtype ``uint8`` with 1 at every selected pixel.
        A selection that matches no pixel yields an all-zero image.
    """
    index = np.asarray(index)
    x = np.asarray(x)
    y = np.asarray(y)

    imm = np.zeros((ny, nx), dtype=np.uint8)
    if selected_clusters is None:
        keep = index != 0
    else:
        keep = np.isin(index, selected_clusters)

    # Only rows inside the image are drawn, as a row-by-row scan would do.
    keep = keep & (y >= 0) & (y < ny)
    imm[y[keep], x[keep]] = 1
    return imm

In [None]:
# make image data in 2D based on selected clusters and return image and range
def make_2D_data_range(index, x, y, nx, ny, selected_clusters=None):
    """Build a 0/1 image of the selected clusters and report their extent.

    Parameters
    ----------
    index : 1D int array
        Cluster number of every pixel; 0 marks an unassigned pixel.
    x, y : 1D int arrays
        Pixel column and row, same length as ``index``.
    nx, ny : int
        Width and height of the output image.
    selected_clusters : iterable of int, optional
        Clusters to draw. ``None`` draws every pixel whose index is non-zero
        and reports the full image as the range.

    Returns
    -------
    (imm, xmin, xmax, ymin, ymax)
        ``imm`` is a ``(ny, nx)`` ``uint8`` mask; the four bounds are the
        extent of the selected pixels, or ``0, nx-1, 0, ny-1`` when nothing
        is selected explicitly.

    Raises
    ------
    ValueError
        From ``np.amin`` when ``selected_clusters`` matches no pixel.
    """
    index = np.asarray(index)
    x = np.asarray(x)
    y = np.asarray(y)

    imm = np.zeros((ny, nx), dtype=np.uint8)
    if selected_clusters is None:
        keep = index != 0
        # the whole image, as a properly ordered (min <= max) range
        xmin, xmax = 0, nx-1
        ymin, ymax = 0, ny-1
    else:
        keep = np.isin(index, selected_clusters)
        ymin = np.amin(y[keep])
        ymax = np.amax(y[keep])
        xmin = np.amin(x[keep])
        xmax = np.amax(x[keep])

    print('y: ', ymin, ymax)

    # Only rows inside the image are drawn, as a row-by-row scan would do.
    inside = keep & (y >= 0) & (y < ny)
    imm[y[inside], x[inside]] = 1

    return imm, xmin, xmax, ymin, ymax

In [None]:
# make fits on 2D of all clusters
def make_cluster_fits(index, x, y, nx, ny, fits_path):
    """Render all clusters into a 2D mask, save it to ``fits_path`` and return it.

    The mask comes from ``make_2D_data`` (no cluster selection) and is written
    with ``make_fits``; the largest cluster number is printed afterwards.
    """
    cluster_image = make_2D_data(index, x, y,  nx, ny)
    make_fits(cluster_image, fits_path)
    print('there are '+str(np.amax(index))+' clusters in total in fits, '+fits_path)
    return cluster_image

In [None]:
# make fits on cluster info (index, x, y)
def make_cluster_info_fits(index, x, y, cluster_info_filepath):
    """Save per-pixel cluster info as a 3-row table in a FITS file.

    Row 0 holds ``index``, row 1 holds ``x`` and row 2 holds ``y``; the table
    has ``index.size`` columns and numpy's default float dtype.
    """
    table = np.zeros((3, index.size))
    for row, values in enumerate((index, x, y)):
        table[row, :] = values
    make_fits(table, cluster_info_filepath)

In [None]:
# opt filter by given y data of a column 
def opt_filter(y_data, par=20, weight=None):
    """Smooth a 1D profile by solving a tridiagonal system ``B f = w * y``.

    ``B`` has ``w[i] + k*|par|`` on the diagonal (``k`` = number of
    neighbours of sample ``i``: 1 at the ends, 2 inside, 0 for a single
    sample) and ``-|par|`` on the two off-diagonals.

    Parameters
    ----------
    y_data : 1D numpy array
        Data to filter.
    par : number
        Smoothing strength. A negative value disables filtering.
    weight : array-like, optional
        One weight per sample (any shape with ``y_data.size`` elements; it is
        flattened). Defaults to all ones.

    Returns
    -------
    (f, B, BINV)
        Filtered data, the dense system matrix and its inverse.

        NOTE: when ``y_data`` is not one dimensional (a message is printed)
        or ``par < 0``, ``y_data`` alone is returned instead of a 3-tuple.
        That legacy return shape is kept so existing callers are unaffected.
    """
    n = y_data.size
    if y_data.ndim != 1:
        print('opt_filter handles one dimensional y data only')
        return y_data

    if par < 0:
        return y_data

    if weight is None:
        wgt = np.ones(n, dtype=np.float64)
    else:
        # flatten so that element i is the scalar weight of sample i
        wgt = np.asarray(weight, dtype=np.float64).reshape(-1)

    r = y_data*wgt
    strength = abs(par)

    # Main diagonal: start from the interior value, then remove one
    # neighbour term at each end (for n == 1 both removals hit element 0).
    b = wgt + 2.0*strength
    b[0] -= strength
    b[-1] -= strength

    # Upper (a) and lower (c) diagonals in solve_banded layout: the first
    # entry of the upper row and the last entry of the lower row are unused.
    a = np.full(n, -strength, dtype=np.float64)
    c = a.copy()
    a[0] = c[-1] = 0

    B = np.diag(b) + np.diag(a[1:], k=1) + np.diag(c[:-1], k=-1)
    BINV = linalg.inv(B)
    f = linalg.solve_banded((1, 1), np.vstack([a, b, c]), r)
    return f, B, BINV

In [None]:
# State shared by blockprint()/enable_print(): the stream that was active
# before output was silenced, and the devnull handle that replaced it.
_saved_stdout = None
_devnull_stdout = None

def blockprint():
    """Silence ``print`` by pointing ``sys.stdout`` at ``os.devnull``.

    The stream that was active is remembered so ``enable_print`` can put it
    back. Calling this again while output is already blocked does nothing.
    """
    global _saved_stdout, _devnull_stdout
    if _saved_stdout is None:
        _saved_stdout = sys.stdout
        _devnull_stdout = open(os.devnull, 'w')
        sys.stdout = _devnull_stdout

def enable_print():
    """Undo ``blockprint``: restore the previous stream and close devnull.

    The stream saved by ``blockprint`` is restored rather than
    ``sys.__stdout__``, because in a notebook the active ``sys.stdout`` is
    not the process-level one. Does nothing when output is not blocked.
    """
    global _saved_stdout, _devnull_stdout
    if _saved_stdout is not None:
        sys.stdout = _saved_stdout
        _saved_stdout = None
        _devnull_stdout.close()
        _devnull_stdout = None

In [None]:
# plot polynomial fitting curve on top of given 2D image
# the cluster orders is settable by order_set
def plot_poly_trace(imm, total_order, coeffs_orders, max_x, max_y, size=20, order_set=None,
                    title=None, background=False, widths=None, aspect=None, xmin=None, xmax=None, ymin=None, ymax=None):
    """Draw fitted polynomial traces of clusters over a 2D image.

    Reads the notebook-level ``power`` (polynomial degree) to locate columns
    of ``coeffs_orders``.

    Parameters
    ----------
    imm : 2D array
        Background image, shown in log-scaled gray.
    total_order : int
        Number of clusters; used when ``order_set`` is None to plot
        clusters ``1..total_order``.
    coeffs_orders : 2D array
        One row per cluster. Columns ``0..power`` are polynomial coefficients
        as accepted by ``np.polyval``; columns ``power+1`` and ``power+2``
        are the first and last x of the trace.
    max_x, max_y : number
        Default upper axis limits (and x extent of the background curves).
    size : number
        Side of the square figure in inches.
    order_set : sequence of int, optional
        Clusters to plot instead of all of them.
    title : str, optional
        Figure title.
    background : bool
        When not False, each polynomial is also drawn across ``0..max_x``
        as a blue dashed line.
    widths : sequence, optional
        Per plotted cluster (same order as the plotted clusters), a pair
        ``(below, above)``; green dashed lines are drawn at ``y - below``
        and ``y + above``.
    aspect : optional
        When given, forwarded to ``Axes.set_aspect``.
    xmin, xmax, ymin, ymax : number, optional
        Axis limits; default to ``0..max_x`` and ``0..max_y``.
    """
    plt.figure(figsize=(size,size))
    plt.subplot(1, 1, 1)
    plt.imshow(imm, cmap='gray', norm=LogNorm())

    if order_set is None:
        orders = list(range(1, total_order+1))
    else:
        orders = order_set

    # label positions are staggered in steps of 1/20 of the image width
    x_dist = max_x//20

    for o_idx, order in enumerate(orders):
        poly = coeffs_orders[order, 0:power+1]
        x_first = coeffs_orders[order, power+1]
        x_last = coeffs_orders[order, power+2]

        if (background is not False):
            full_x = np.arange(0, max_x)
            plt.plot(full_x, np.polyval(poly, full_x), 'b--')

        # the trace over the cluster's own x range
        x_val = np.arange(x_first, x_last+1)
        if x_val.size == 0:
            # nothing to draw or label for an empty x range
            continue
        y_val = np.polyval(poly, x_val)
        plt.plot(x_val, y_val, 'r--')

        if widths is not None:
            plt.plot(x_val, y_val-widths[o_idx][0], 'g--')
            plt.plot(x_val, y_val+widths[o_idx][1], 'g--')

        # cluster number label; fall back to a position inside short traces
        s = ((order%15)+1)*x_dist
        if s >= x_val.size:
            dem = int((x_last - x_first)//5)
            s = dem*((order%4)+1)
        plt.text(x_val[s], y_val[s], str(order), fontsize=12, color='b', fontweight='bold', horizontalalignment='center', verticalalignment='center')

    if title is not None:
        plt.title(title, fontsize=12)
    x1 = 0 if xmin is None else xmin
    x2 = max_x if xmax is None else xmax
    y1 = 0 if ymin is None else ymin
    y2 = max_y if ymax is None else ymax

    plt.ylim(y1, y2)
    plt.xlim(x1, x2)
    if aspect is not None:
        # gca() is the axes holding the image and traces; a bare plt.axes()
        # would add a new, empty axes on current Matplotlib.
        plt.gca().set_aspect(aspect)

    plt.show()
    #plt.colorbar(im, fraction=0.046, pad=0.04)

In [None]:
# plot y value from min_y to max_y at each x location from min_x to max_x
def plot_data_at_x(data, min_x, max_x, min_y, max_y, cluster_points, size=20, title=None, cluster_no=None, w1=None, w2=None, slope_set=None, slope=None, slope_double=None):
    """Plot column profiles of ``data`` for every x in ``min_x..max_x``.

    The horizontal axis of the figure is the row position (``min_y`` plus the
    row number of ``data``); the vertical axis is the data value.

    Parameters
    ----------
    data : 2D array
        Column ``x`` of this array is plotted for every x in the range.
    min_x, max_x : int
        Inclusive range of columns to plot.
    min_y, max_y : number
        Offset added to the row numbers, and the displayed axis range.
    cluster_points : 2D array
        ``cluster_points[c, x]`` is the position of cluster ``c`` at column
        ``x``; positions within ``min_y..max_y`` are marked on the zero line.
    size : number
        Side of the square figure in inches.
    title : str, optional
        Figure title.
    cluster_no, w1, w2 : optional
        When all three are given, the position of cluster ``cluster_no`` is
        marked in black and ``position - w1`` / ``position + w2`` in blue.
    slope, slope_double : 2D arrays, optional
        Extra curves; column ``x - min_x`` is plotted for each x.
    slope_set : sequence of dict, optional
        Extra fitted curves. An entry with ``'coeffs'`` and ``'bound'`` is
        evaluated with ``np.polyval`` over ``bound[0]..bound[1]``; otherwise
        an entry with a non-None ``'gaussian'`` callable is evaluated on its
        ``'x_set'``.
    """
    plt.figure(figsize=(size,size))
    display_color = ['r', 'b', 'k', 'g', 'y', 'm', 'c']
    n_rows = np.shape(data)[0]

    # min_y to max_y set to be value on x axis.
    plt_x_val = np.arange(0, n_rows) + min_y

    mark_width = cluster_no is not None and w1 is not None and w2 is not None

    for x in range(min_x, max_x+1):
        color = display_color[(x - min_x)%(len(display_color))]
        plt.plot(plt_x_val, data[:, x], linestyle='--', marker='o', color=color)

        cluster_data = cluster_points[:, x]
        cluster_data = cluster_data[np.where(np.logical_and(cluster_data <= max_y, cluster_data >= min_y))[0]]

        # put dots on peak and the width around the peak of the cluster with cluster_no
        if cluster_data.size > 0:
            plt.scatter(cluster_data, np.zeros(cluster_data.size), c=color)
            if mark_width:
                cluster_y = cluster_points[cluster_no][x]
                plt.scatter(np.array([cluster_y]), np.zeros(1), c='k', alpha=0.5)
                plt.scatter(np.array([cluster_y - w1, cluster_y + w2]), np.zeros(2), c='b', alpha=0.5)

        if slope is not None:
            plt.plot(plt_x_val, slope[:, x - min_x], 'bo')
        if slope_double is not None:
            plt.plot(plt_x_val, slope_double[:, x - min_x], 'g-')
        if slope_set is not None:
            for one_slope in slope_set:
                y_val = None
                if 'coeffs' in one_slope:
                    bound = one_slope['bound']
                    x_set = np.arange(bound[0], bound[1]+1)
                    y_val = np.polyval(one_slope['coeffs'], x_set)
                elif 'gaussian' in one_slope and one_slope['gaussian'] is not None:
                    x_set = one_slope['x_set']
                    y_val = one_slope['gaussian'](x_set)

                if y_val is not None:
                    plt.plot(x_set, y_val, 'k-')

    if title is not None:
        plt.title(title, fontsize=14)
    plt.xlim(min_y, max_y)
    plt.show()
                     

In [None]:
# show curve fitting on given x & y data by using polynomial
def curve_fitting_plot(x_data, y_data, pow = 3, size=12, extend=1500):
    """Fit a polynomial to the points and plot points and fitted curve.

    Parameters
    ----------
    x_data, y_data : array-like
        Points to fit; drawn as red ``+`` markers.
    pow : int
        Degree of the polynomial passed to ``np.polyfit``.
    size : number
        Side of the square figure in inches.
    extend : int
        How far beyond the min/max of ``x_data`` the blue fitted curve is
        drawn on each side.

    The title shows the degree and the real parts of the roots of the
    derivative (the x positions of the extrema of the fitted curve).
    """
    plt.figure(figsize=(size,size))
    curve_coeffs = np.polyfit(x_data, y_data, pow)

    x_min = int(np.amin(x_data)) - extend
    x_max = int(np.amax(x_data)) + extend

    plt_x_val = np.arange(x_min, x_max)
    plt_y_val = np.polyval(curve_coeffs, plt_x_val)

    plt.plot(x_data, y_data, 'r+')
    plt.plot(plt_x_val, plt_y_val, c='b')

    roots_der = np.roots(np.polyder(curve_coeffs))

    plt.title('order: '+str(pow)+ ' extreme: '+str(roots_der.real),
               fontsize=14)

In [None]:
def to_str(afloat):
    """Format a number as text with four digits after the decimal point."""
    return format(afloat, ".4f")

In [None]:
# pickle save and load
def save_obj(obj, filename):
    """Pickle ``obj`` into ``filename`` using the highest pickle protocol."""
    with open(filename, 'wb') as out_file:
        pickle.Pickler(out_file, pickle.HIGHEST_PROTOCOL).dump(obj)

def load_obj(filename):
    """Return the object pickled in ``filename``."""
    # Unpickling can execute arbitrary code: only load files you created.
    with open(filename, 'rb') as in_file:
        restored = pickle.Unpickler(in_file).load()
    return restored

In [None]:
# show the following set of two plots along the entire x axis range (~ every 2000 pixels for each set on NEID data):
#  1. plot the trace of selected clusters plus the trace at top width and bottom width location (optional)
#  2. plot the image of selected clusters on top of the original image
def plot_selected_clusters_on_original_img(s_clusters, n_x, n_y, n_index, n_coeffs, show_width=False, alpha=0.5):
    """Plot selected clusters against the original spectral image, panel by panel.

    For each horizontal panel two figures are produced: the fitted traces of
    the selected clusters (via ``plot_poly_trace``) and the cluster pixels
    overlaid on the original image (via ``plot_img_on_original``).

    Relies on notebook-level variables: ``nx``, ``ny`` (image size),
    ``spe_info`` (its ``'data'`` entry is used as the background image),
    ``order_t`` and ``power``.

    Parameters
    ----------
    s_clusters : sequence of int
        Cluster numbers to show.
    n_x, n_y, n_index : 1D arrays
        Pixel columns, rows and cluster numbers.
    n_coeffs : 2D array
        Per-cluster polynomial data as expected by ``plot_poly_trace``.
    show_width : bool
        When exactly ``True``, widths are computed with
        ``order_t.find_all_cluster_widths`` and drawn around each trace.
    alpha : float
        Opacity of the cluster overlay.
    """
    print(s_clusters)
    
    # mask of the selected clusters and their bounding box
    sel_imm, xmin, xmax, ymin, ymax = make_2D_data_range(n_index, n_x, n_y, nx, ny, selected_clusters=s_clusters)
    imm_spec = spe_info['data']
    print('range: ', xmin, xmax, ymin, ymax)
    
    max_index = np.amax(n_index)
    a = 3         # aspect
    y_off = 20    # vertical offset added to the range
    step = (nx+5)//4   # split the display into parts
    
    if show_width is True:
        cluster_points = order_t.get_cluster_points(n_coeffs, power)
        all_widths = order_t.find_all_cluster_widths(n_index, n_x, n_y, \
                                    n_coeffs, cluster_points, power, cluster_set=s_clusters)
        # plot_poly_trace subtracts the first value of each pair from the trace
        # and adds the second one
        c_widths = [[one_width['avg_pwidth'], one_width['avg_nwidth']] for one_width in all_widths]
        print('widths: ', c_widths)
    else:
        c_widths = None
        
    for x in range(0, nx, step):
        x1 = x
        # NOTE(review): the right edge is capped by the clusters' xmax, not by
        # the image width; for a panel starting beyond xmax this gives x2 < x1
        # (a reversed x axis) -- confirm whether that is intended.
        x2 = min(x+step, xmax)
        
        # print trace on top of original image
        plot_poly_trace(imm_spec, max_index, n_coeffs, nx, ny, widths = c_widths, 
                        order_set=s_clusters, xmin=x1, xmax=x2, ymin=ymin-y_off, ymax=ymax+y_off, aspect=a)
        # plot image of selected clusters on top of original image
        plot_img_on_original(sel_imm, imm_spec, ymin-y_off, ymax+y_off, xmin=x1, xmax=x2, \
                             p_w=20, p_h=20, alpha=alpha, aspect=a, title="")

In [None]:
# menu to select cluster and plot the selected clusters
def _parse_cluster_selection(text, max_index):
    """Turn a menu reply ("10", "6-10" or "1, 2, 3") into a list of cluster numbers.

    Blanks around the numbers are ignored. In a comma list, entries that are
    not numbers are reported and left out; an invalid range or single number
    is reported and yields an empty list.
    """
    text = text.strip()
    tokens = [tok.strip() for tok in text.split(",")]
    if len(tokens) > 1:
        picked = []
        for tok in tokens:
            if tok.isdigit():
                picked.append(int(tok))
            else:
                print(tok, " is not a number")
        return picked

    bounds = [tok.strip() for tok in text.split("-")]
    if len(bounds) == 2:
        if bounds[0].isdigit() and bounds[1].isdigit():
            low, high = sorted((int(bounds[0]), int(bounds[1])))
            return list(range(low, high+1))
        print(text, " is not valid number range")
        return []

    if not text.isdigit() or int(text) > max_index:
        print(text, " is not a valid number")
        return []
    return [int(text)]


def show_one_cluster_per_plot_menu(index_t, x, y, nx, ny):
    """Ask for clusters on the console and plot each selected cluster separately.

    For every selected cluster that has pixels, its size is printed, a
    close-up of its mask is plotted and the result of
    ``order_t.extract_order_from_cluster`` is printed. Reads the
    notebook-level ``order_t`` and ``power``.

    Parameters
    ----------
    index_t : 1D int array
        Cluster number of every pixel.
    x, y : 1D int arrays
        Pixel columns and rows.
    nx, ny : int
        Image width and height.
    """
    # upper bound taken from the data passed in, not from a notebook global
    max_index = int(np.amax(index_t))
    select_idx = input("select a cluster to view: (ex: 10, from 1 to "+ str(max_index)+ "):\n"+
                    "or a range of clusters to view: (ex: 6-10)\n" +
                    "or a set of clusters to view: (ex: 1, 2, 3)\n")
    num_set = _parse_cluster_selection(select_idx, max_index)
    print(num_set)

    for num_idx in sorted(num_set):
        sel = np.where(np.isin(index_t, [num_idx]))[0]
        if sel.size == 0:
            # min/max below are undefined for a cluster without pixels
            print('cluster: ', num_idx, ' has no pixels')
            continue

        show_cluster_size(index_t, x, y, nx, ny, [num_idx])
        new_imm = make_2D_data(index_t, x, y, nx, ny, [num_idx])

        ymin = np.amin(y[sel])
        ymax = np.amax(y[sel])
        xmin = np.amin(x[sel])
        xmax = np.amax(x[sel])

        print('cluster: ', num_idx, 'x1, x2, y1, y2: ', xmin, xmax, ymin, ymax)

        plot_img(new_imm, ymin, ymax, 20, 20, xmin, xmax)
        p_info, errors, area = order_t.extract_order_from_cluster(num_idx, index_t, x, y, power)
        print(p_info, errors)

In [None]:
def _parse_selection_or_empty(text, max_index):
    """Turn a menu reply into a sorted list of cluster numbers, or [] if invalid.

    Accepts a single number, a "a-b" range or a comma list; blanks around the
    numbers are ignored. Any invalid entry is reported and makes the whole
    reply invalid.
    """
    text = text.strip()
    tokens = [tok.strip() for tok in text.split(",")]
    if len(tokens) > 1:
        picked = []
        for tok in tokens:
            if not tok.isdigit():
                print(tok, " is not a number")
                return []
            picked.append(int(tok))
        print(picked)
        return sorted(picked)

    bounds = [tok.strip() for tok in text.split("-")]
    if len(bounds) == 2:
        if bounds[0].isdigit() and bounds[1].isdigit():
            low, high = sorted((int(bounds[0]), int(bounds[1])))
            return list(range(low, high+1))
        print(text, " is not valid number range")
        return []

    if not text.isdigit() or int(text) > max_index:
        print(text, " is not a valid number")
        return []
    return [int(text)]


def show_clusters_per_plot_menu(index_t, x, y, nx, ny):
    """Ask for clusters on the console and plot all selected clusters together.

    The prompt repeats until a valid choice is made: "A" plots every cluster,
    "E" exits, anything else is parsed as one number, a range or a comma
    list. A valid selection is plotted twice: a close-up of its bounding box
    and the full image.

    Parameters
    ----------
    index_t : 1D int array
        Cluster number of every pixel.
    x, y : 1D int arrays
        Pixel columns and rows.
    nx, ny : int
        Image width and height.
    """
    # upper bound taken from the data passed in, not from a notebook global
    max_index = int(np.amax(index_t))

    while True:
        select_idx = input("select clusters to view: all(A) \n" +
                            "exit(E) \n" +
                            "<one cluster>(cluster number)\n"+
                            "<multiple cluster>(No. 1, No. 2...)\n" +
                            "<cluster number range>(No. 1 - No. 2): ")
        select_idx = select_idx.strip()

        print(select_idx)

        if select_idx in ("A", "a"):
            # build the image from the arguments so it always matches index_t
            plot_img(make_2D_data(index_t, x, y, nx, ny), 0, ny-1)
            break
        if select_idx in ("E", "e"):
            break

        requested = _parse_selection_or_empty(select_idx, max_index)
        # keep only cluster numbers that own at least one pixel
        num_set = [num for num in requested if np.any(index_t == num)]
        if len(num_set) == 0:
            if len(requested) > 0:
                print(select_idx, " does not match any cluster pixels")
            continue

        show_cluster_size(index_t, x, y, nx, ny, num_set)
        new_imm = make_2D_data(index_t, x, y, nx, ny, num_set)
        sel = np.where(np.isin(index_t, num_set))[0]

        ymin = np.amin(y[sel])
        ymax = np.amax(y[sel])
        xmin = np.amin(x[sel])
        xmax = np.amax(x[sel])

        print('ymin: ', ymin, ' ymax: ', ymax, ' xmin: ', xmin, ' xmax: ', xmax)

        # close up
        plot_img(new_imm, ymin, ymax, 20, 20, xmin, xmax)
        # full size
        plot_img(new_imm, 0, ny-1, 20, 20)
        break


## Usage: Using OrderTrace to extract order trace from the given spectral fits

In [None]:
# One-shot usage: run the whole extraction on `spectral_fits` and write the
# per-order widths and polynomial coefficients to `result_poly_width_csv`.
order_t = OrderTrace(spectral_fits)
order_t.load_spectral()
# `power` is the polynomial degree set in the configuration cell
cluster_info = order_t.extract_order_trace(power)
#result_poly_width_csv = '../test_data/order_trace_test/'+power_dir+'/result_poly_2sigma_gaussian_peak_'+str(power)+'.csv'
order_t.write_cluster_info_to_csv(cluster_info['widths'], cluster_info['coeffs'], power, result_poly_width_csv)

## Extracting order trace step by step
### execute the cells from step 1 to step 10 and get visual output for each step

## 1. load spectral file

In [None]:
order_t = OrderTrace(spectral_fits)
spe_info = order_t.load_spectral()
print('row: ', spe_info['ny'], ' column: ', spe_info['nx'])

imm_spec = spe_info['data']
nx = spe_info['nx']
ny = spe_info['ny']
plot_img(imm_spec, 0, ny-1)
#plot_img(imm_spec, 2000, 2100, xmin=4500, xmax=4600)

## 2. find cluster pixels  and make fits

In [None]:
# vertical-line removal is requested only for the stacked 2-fiber flat input
r_v = 'stacked_2fiber_flat' in spectral_fits

# locate cluster pixels, save the pixel map and show it over its full height
cluster_xy = order_t.locate_clusters(remove_vertical=r_v)
order_t.make_fits(cluster_xy['im_map'], cluster_xy_fits)
yy = np.shape(cluster_xy['im_map'])[0]
plot_img(cluster_xy['im_map'], 0, yy - 1)

## 3. form clusters, basic cleaning (based on size and total pixel), make fits

In [None]:
#cluster_info, dict
cluster_info = order_t.collect_clusters(cluster_xy['x'], cluster_xy['y'])

In [None]:
clusters_collection = '../test_data/order_trace_test/'+power_dir+'/clusters_all_y_collection.pkl'
save_obj(cluster_info, clusters_collection)   # optional, save cluster info into .pkl file

In [None]:
# (optional) load cluster collection 
cluster_info = load_obj(clusters_collection)   # optional, load cluster info from .pkl file

In [None]:
# assign index value to cluster_info['index'], where cluster_info['index'] is the same size as cluster_xy['x']
cluster_info = order_t.remove_cluster_noise(cluster_info, cluster_xy['x'], cluster_xy['y'])

In [None]:
# remove unassigned index
x, y, index_t = order_t.reorganize_index(cluster_info['index'], cluster_xy['x'], cluster_xy['y'])
nx = spe_info['nx']
ny = spe_info['ny']
imm = order_t.make_2D_data(index_t, x, y)   # show image  and make fits and info fits 
plot_img(imm, 0, np.shape(imm)[0]-1)

## save data from the result of 3.

In [None]:
order_t.make_fits(imm, cluster_clean_fits)

cluster_data = np.zeros((3, index_t.size))
cluster_data[0, :] = index_t
cluster_data[1, :] = x
cluster_data[2, :] = y
order_t.make_fits(cluster_data, cluster_info_clean_fits)

## reload clean fits and info fits of step 3. (optional)

In [None]:
cluster_clean_fits = '../test_data_backup/order_trace_test/'+power_dir+'/cluster_clean.fits'
cluster_info_clean_fits = '../test_data_backup/order_trace_test/'+power_dir+'/cluster_info_clean.fits'

In [None]:
# (optional) reload saved fits 
imm, hdr = fits.getdata(cluster_clean_fits, header=True)
ny, nx = np.shape(imm)
cluster_info, c_hdr = fits.getdata(cluster_info_clean_fits, header=True)
index_t = cluster_info[0].astype(int)
x = cluster_info[1].astype(int)
y = cluster_info[2].astype(int)

plot_img(imm, 0, np.shape(imm)[0]-1)
ind_max = np.amax(index_t)

# 4. Advanced cluster cleaning to remove noisy clusters

In [None]:
index_new, all_status = order_t.advanced_cluster_cleaning_handler(index_t, x, y, power)

# prepare for saving
x_p = x.copy()
y_p = y.copy()

## save advanced cleaning result of step 4 to fits and info fits

In [None]:
# save advanced cleaning result to fits and info fits
#cluster_after_removal_fits = '../test_data/order_trace_test/'+power_dir+'/cluster_after_removal5.fits'
#cluster_info_after_removal_fits =  '../test_data/order_trace_test/'+power_dir+'/cluster_info_after_removal5.fits'

new_imm_after_removal = make_2D_data(index_new, x_p, y_p, nx, ny)
order_t.make_fits(new_imm_after_removal, cluster_after_removal_fits)
new_x, new_y, new_index, convert_map = order_t.reorganize_index(index_new, x_p, y_p, True)
cluster_data = np.zeros((3, new_index.size))
cluster_data[0, :] = new_index
cluster_data[1, :] = new_x
cluster_data[2, :] = new_y
order_t.make_fits(cluster_data, cluster_info_after_removal_fits)

## (optional) load advanced cleaning results, result of step 4

In [None]:
cluster_after_removal_fits = '../test_data_backup/order_trace_test/'+power_dir+'/cluster_after_removal5.fits'
cluster_info_after_removal_fits =  '../test_data_backup/order_trace_test/'+power_dir+'/cluster_info_after_removal5_r.fits'
new_imm_after_removal, hdr = fits.getdata(cluster_after_removal_fits, header=True)
#plot_img(new_imm_after_removal, 0, np.shape(new_imm_after_removal)[0]-1)

cluster_info_tmp, c_hdr = fits.getdata(cluster_info_after_removal_fits, header=True)
new_index = cluster_info_tmp[0].astype(int)
new_x = cluster_info_tmp[1].astype(int)
new_y = cluster_info_tmp[2].astype(int)


## show result and curve fitting

In [None]:
new_coeffs, errors = order_t.curve_fitting_on_all_clusters(new_index, new_x, new_y, power)
max_index = np.amax(new_index)

print(np.shape(new_coeffs))
print(max_index)
plot_poly_trace(new_imm_after_removal, max_index, new_coeffs, nx, ny)

## 5. clean the clusters along the top and bottom border (optional), new_x, new_y, new_index

In [None]:
# Work on copies so the step-4 results (new_index, new_x, new_y) stay intact
# until they are reassigned below.
index_r = new_index.copy()
x = new_x.copy()
y = new_y.copy()

# first pass: border at row 0
index_b = order_t.clean_clusters_on_border(x, y, index_r, 0)
# lengths of the index arrays before and after the first pass
print(len(index_r))
print(len(index_b))

# second pass: border at the last row (ny-1), then renumber the clusters
# NOTE(review): the 4th argument (True) presumably requests the extra
# convert_map return value -- confirm against OrderTrace.reorganize_index
index_t = order_t.clean_clusters_on_border(x, y, index_b, ny-1)
new_x, new_y, new_index, convert_map = order_t.reorganize_index(index_t, x, y, True)
# lengths before the second pass and after reorganizing
print(len(index_b))
print(len(new_index))

## saving result of step 5 to fits and info fits

In [None]:
cluster_data = np.zeros((3, new_index.size))
cluster_data[0, :] = new_index
cluster_data[1, :] = new_x
cluster_data[2, :] = new_y
order_t.make_fits(cluster_data, cluster_info_border_fits)

imm = order_t.make_2D_data(new_index, new_x, new_y)
order_t.make_fits(imm, cluster_border_fits)
ind_max = np.amax(new_index)
print('there are '+str(ind_max)+' clusters in total')

yy = np.shape(imm)[0]
plot_img(imm, 0, yy-1)

## (optional) load result of step 5 fits and info fits

In [None]:
imm, hdr = fits.getdata(cluster_border_fits, header=True)
plot_img(imm, 0, np.shape(imm)[0]-1)

cluster_info, c_hdr = fits.getdata(cluster_info_border_fits, header=True)
new_index = cluster_info[0].astype(int)
new_x = cluster_info[1].astype(int)
new_y = cluster_info[2].astype(int)

## prepare data for step 6

In [None]:
# prepare for step 6
new_coeffs, errors = order_t.curve_fitting_on_all_clusters(new_index, new_x, new_y, power)
max_index = np.amax(new_index)

### for debug usage: to find the curve by giving the range 

In [None]:
def find_target_curve(all_curves):
    """Return the row numbers of ``all_curves`` that match a hard-coded target.

    A row matches when its first four values equal one of the three target
    quadruples listed below. Row 0 is never examined.
    """
    target_clusters =  [[   0.,  438., 7731., 7799.], [ 448., 1934., 7795., 7967.], [1941., 7174., 7963., 8091.]]

    def is_target(row):
        # compare the four leading values against each target in turn
        return any(
            all_curves[row, 0] == t[0] and all_curves[row, 1] == t[1] and
            all_curves[row, 2] == t[2] and all_curves[row, 3] == t[3]
            for t in target_clusters
        )

    return [row for row in range(1, len(all_curves)) if is_target(row)]

## 6. Merging clusters, input parameters:  index_t, x, y, imm

In [None]:
def merge_clusters(index, x, y, power, times=None):
    """Repeatedly merge clusters until nothing changes or ``times`` rounds ran.

    Each round calls ``order_t.one_step_merge_cluster`` on the current
    clusters and prints the round number and the ``'log'`` entry of the
    returned status. Uses the notebook-level ``order_t``.

    Parameters
    ----------
    index, x, y : 1D arrays
        Cluster number, column and row of every pixel; the inputs are copied
        before use.
    power : int
        Polynomial degree used for curve fitting.
    times : int, optional
        Stop after exactly this many rounds; ``None`` runs until a round
        reports status ``'nochange'``.

    Returns
    -------
    (m_x, m_y, m_index, m_coeffs)
        Pixels and cluster numbers after ``order_t.reorganize_index``, and
        the coefficients fitted on the reorganized clusters.
    """
    cur_index = index.copy()
    cur_x = x.copy()
    cur_y = y.copy()
    cur_coeffs, _ = order_t.curve_fitting_on_all_clusters(cur_index, cur_x, cur_y, power)
    total = 0

    while True:
        total += 1
        print('time: ', total)

        n_index, n_x, n_y, n_coeffs, merge_status = order_t.one_step_merge_cluster(
            cur_coeffs, power, cur_index, cur_x, cur_y, print_result=False)

        print("  "+merge_status['log'])

        # the result of this round is the input of the next one
        cur_index = n_index.copy()
        cur_x = n_x.copy()
        cur_y = n_y.copy()
        cur_coeffs = n_coeffs.copy()

        if (times is not None) and (total == times):
            break

        if merge_status['status'] == 'nochange':
            break

    m_x, m_y, m_index = order_t.reorganize_index(cur_index, cur_x, cur_y)
    m_coeffs, _ = order_t.curve_fitting_on_all_clusters(m_index, m_x, m_y, power)
    return m_x, m_y, m_index, m_coeffs


In [None]:
# The OrderTrace method of the same name is kept below for reference; the
# notebook-local merge_clusters() is the one actually run.
#m_x, m_y, m_index, m_coeffs = order_t.merge_clusters(new_index, new_x, new_y, power)
m_x, m_y, m_index, m_coeffs = merge_clusters(new_index, new_x, new_y, power)
# Rebind the working arrays to copies of the merged result; re-running this
# cell therefore starts from already-merged clusters.
new_x = m_x.copy()
new_y = m_y.copy()
new_index = m_index.copy()
new_coeffs = m_coeffs.copy()

max_index = np.amax(new_index)
# NOTE(review): imm_spec, nx and ny are not defined in this section -- they
# must come from earlier cells; confirm before running on a fresh kernel.
plot_poly_trace(imm_spec, max_index, new_coeffs, nx, ny, title='after merging cluster', size=20)

## Save the merge results of step 6 to the cluster FITS file and the cluster-info FITS file

In [None]:
# (optional) override the output filenames used when saving the merge result.
# Note: these point into '../test_data_backup/', not the '../test_data/'
# tree used by the path definitions at the top of the notebook.
cluster_merge_fitting = '../test_data_backup/order_trace_test/'+power_dir+'/cluster_merge_fitting_0303_sorted.fits'
cluster_info_merge_fitting = '../test_data_backup/order_trace_test/'+power_dir+'/cluster_info_merging_fitting_0303_sorted.fits'

In [None]:
# store after merge fits and fits info
# make_cluster_fits writes to cluster_merge_fitting; its return value rebinds
# the global `imm` that the later plotting cells pass to plot_poly_trace.
imm = make_cluster_fits(new_index, new_x, new_y, nx, ny, cluster_merge_fitting)
# Companion file with the per-pixel index/x/y arrays (the load cell below
# reads them back from rows 0, 1 and 2).
make_cluster_info_fits(new_index, new_x, new_y, cluster_info_merge_fitting)

In [None]:
# Publish the merged fit under the names used by the later width-finding
# steps.  cluster_coeffs is an alias of new_coeffs here, not a copy.
cluster_coeffs = new_coeffs
cluster_points = order_t.get_cluster_points(new_coeffs, power)

## (optional) load merge result of step 6 

In [None]:
# Reload the cluster image saved after merging and show all of its rows.
imm, hdr = fits.getdata(cluster_merge_fitting, header=True)
plot_img(imm, 0, np.shape(imm)[0]-1)

# The info file stores three rows: cluster index, x and y of every cluster
# pixel.  All three are cast to int.
cluster_info, c_hdr = fits.getdata(cluster_info_merge_fitting, header=True)
new_index = cluster_info[0].astype(int)
new_x = cluster_info[1].astype(int)
new_y = cluster_info[2].astype(int)

In [None]:
#sort cluster based on y position (optional)
# Same output paths as the optional filename-override cell above.
cluster_merge_fitting = '../test_data_backup/order_trace_test/'+power_dir+'/cluster_merge_fitting_0303_sorted.fits'
cluster_info_merge_fitting = '../test_data_backup/order_trace_test/'+power_dir+'/cluster_info_merging_fitting_0303_sorted.fits'

# NOTE(review): sorted_index[i] is used as "the old label that becomes new
# label i"; entry 0 is skipped -- presumably a placeholder, confirm in
# OrderTrace.sort_cluster_in_y.
sorted_index = order_t.sort_cluster_in_y(cluster_coeffs, power)
# Pixels whose old label is not found in sorted_index[1:] keep label 0.
new_index_sort = np.zeros(np.size(new_index), dtype=int)
for i in range(1, len(sorted_index)):
    idx = np.where(new_index == sorted_index[i])[0]
    new_index_sort[idx] = i
# new_index is relabelled here, but cluster_coeffs / cluster_points are NOT
# reordered in this cell; they stay in the old order until they are refitted.
new_index = new_index_sort.copy()

imm = make_cluster_fits(new_index_sort, new_x, new_y, nx, ny, cluster_merge_fitting)
make_cluster_info_fits(new_index_sort, new_x, new_y, cluster_info_merge_fitting)

## plot on result of step 6

In [None]:
# Refit every cluster from the current index/x/y arrays so the coefficients
# match the current labelling, then draw the fitted traces over `imm`.
new_coeffs, errors = order_t.curve_fitting_on_all_clusters(new_index, new_x, new_y, power)
max_index = np.amax(new_index)    
plot_poly_trace(imm, max_index, new_coeffs, nx, ny)
# Refresh the names consumed by the width-finding steps.
cluster_points = order_t.get_cluster_points(new_coeffs, power)
cluster_coeffs = new_coeffs

In [None]:
# plot on specific trace 
# Cluster label to display; edit this value to inspect another trace.
sel_idx = 1

plot_selected_clusters_on_original_img([sel_idx], new_x, new_y, new_index, new_coeffs, show_width=True, alpha=0.5)

## 7. (optional) Remove broken clusters that have a large gap in the center

In [None]:
# Results are kept under separate names (new_x1, new_y1, new_index1,
# next_coeffs) so the inputs are untouched until the next cell adopts them.
new_x1, new_y1, new_index1 = order_t.remove_broken_cluster(new_index, new_x, new_y, new_coeffs)
next_coeffs, errors = order_t.curve_fitting_on_all_clusters(new_index1, new_x1, new_y1, power)
# Rebinds the global `imm` to an image built from the remaining clusters.
imm =  make_2D_data(new_index1, new_x1, new_y1, nx, ny)
max_index = np.amax(new_index1)

plot_poly_trace(imm, max_index, next_coeffs, nx, ny, title='after removing broken cluster', size=20)

In [None]:
# Adopt the result of the removal step as the current working set.
new_x = new_x1.copy()
new_y = new_y1.copy()
new_index = new_index1.copy()
cluster_points = order_t.get_cluster_points(next_coeffs, power)
# Note: new_coeffs is not updated here; only cluster_coeffs follows next_coeffs.
cluster_coeffs = next_coeffs

## (optional)  get data from pre-stored fits

In [None]:
# Preview the input spectral image over its full row range (0 .. rows-1).
data, hdr = fits.getdata(spectral_fits, header=True)
yy = np.shape(data)[0]
plot_img(data, 0, yy-1)

In [None]:
# Preview the stored cluster_xy image; rebinds `data`, `hdr` and `yy`.
data, hdr = fits.getdata(cluster_xy_fits, header=True)
yy = np.shape(data)[0]
plot_img(data, 0, yy-1)

In [None]:
# Preview the stored cluster_clean image; rebinds `data`, `hdr` and `yy`.
data, hdr = fits.getdata(cluster_clean_fits, header=True)
yy = np.shape(data)[0]
plot_img(data, 0, yy-1)

In [None]:
# Preview the stored cluster_border image; kept under its own name
# (imm_border) instead of overwriting `data`.
imm_border, hdr = fits.getdata(cluster_border_fits, header=True)
yy = np.shape(imm_border)[0]
plot_img(imm_border, 0, yy-1)

## 8. (optional) fitting clusters on peaks

In [None]:
max_cluster_no = np.amax(new_index)
# Keep the current coefficients so they can be compared with the peak-based
# fit later.  This cell is not idempotent: cluster_coeffs is rebound at the
# end, so on a re-run original_coeffs would hold the peak-based coefficients.
original_coeffs = cluster_coeffs.copy()
peak_info = order_t.curve_fitting_on_peaks(cluster_coeffs, power)
plot_poly_trace(imm, max_cluster_no, peak_info['coeffs'], nx, ny, title="fitting on cluster peaks", size=20)
# From here on cluster_points / cluster_coeffs refer to the peak-based fit.
cluster_points = peak_info['peak_pixels']
cluster_coeffs =  peak_info['coeffs']

In [None]:
# Show the errors reported by the peak fit and their mean.
print(peak_info['errors'], np.mean(peak_info['errors']))

## (optional) difference (RMS) between peak fitting and cluster pixel fitting

In [None]:
# Compare the coefficients saved before peak fitting with the peak-based
# ones; requires the peak-fitting cell to have been run (original_coeffs, peak_info).
rms = order_t.rms_of_polys(original_coeffs, peak_info['coeffs'], power)
print(rms, np.mean(rms))    

## (optional) before step 9: select cluster no

In [None]:
# Ask which cluster should get the detailed per-column plots in the optional
# width cell below.
max_cluster_no = np.amax(new_index)
cluster_no = input("select clusters no. " +"(1-" + str(max_cluster_no)+") \n" + 
                        "exit(E): ").strip()
# The prompt offers 'E' to exit, but int('E') used to raise ValueError.
# Map it to 0, which the prompt's own range (1-max) excludes, so no
# per-cluster plot is selected.  Any other reply must still be an integer.
cluster_no = 0 if cluster_no.upper() == 'E' else int(cluster_no)

## 9. call API to find width

In [None]:
# Width information for all clusters in one API call; the result is the
# `cluster_widths` argument of the CSV-writing cells further down.
cluster_widths =  order_t.find_all_cluster_widths(new_index, new_x, new_y, cluster_coeffs,  cluster_points, power)

## (optional) find width of each order, same as step 9 with print and plot

In [None]:
# Rebuild cluster_widths one cluster at a time, printing the average widths of
# each cluster and plotting every sampled column of the cluster picked in
# `cluster_no` (set by the optional selection cell above).
cluster_widths = []
for n in range(1, max_cluster_no+1):
#for n in range(cluster_no, cluster_no+1):
    print('cluster: ', n)
    ext_spectrum = order_t.get_spectrum_around_cluster(n, new_index, new_x, new_y, cluster_coeffs, power)
    if ext_spectrum is None:
        # No spectrum could be extracted around this cluster: nothing to measure.
        continue

    x_s = int(nx//2)
    x_e = int(x_s+0)
    y_s, y_e = ext_spectrum['min_y'], ext_spectrum['max_y']

    cluster_width_info = order_t.width_of_cluster_by_gaussian(n, cluster_coeffs, cluster_points, power)
    cluster_widths.append(cluster_width_info)

    print('top width: ', cluster_width_info['avg_nwidth'], ' bottom width: ', cluster_width_info['avg_pwidth'])

    # Detailed plots are drawn for the selected cluster only.
    if n != cluster_no:
        continue

    c_widths = cluster_width_info['width_info_all_x']
    for width_at_x in c_widths:
        xs = width_at_x['x']
        info_at_x = width_at_x['width_info']
        slope_coeffs_bound = width_at_x['slope_coeffs']
        prev_width, next_width = info_at_x['width0'], info_at_x['width1']

        # The widths are concatenated into the title as-is and converted with
        # float() for the w1/w2 arguments.
        plot_data_at_x(ext_spectrum['data'], xs, xs, y_s, y_e, cluster_points, cluster_no=cluster_no, w1=float(prev_width), \
                       w2=float(next_width),  title=' width ' + prev_width + ' and '+ next_width  + ' at '+str(xs), size=12)

## (optional) Write the width results returned by the API to per-cluster CSV files

In [None]:
# Write one CSV per cluster (result_csv + <cluster no> + '.csv').  For every
# sampled x it holds the width-info row, then the 'next' and 'prev' slope
# rows; the file ends with one summary row of the cluster's average widths.
fieldname=['no', 'x', 'y', 'data','p_mid','backgd0','n_mid','backgd1','p_slope_1','p_slope_2','n_slope_1','n_slope_2','width0', 'width1' ]

for c_widths in cluster_widths:
    c_no = c_widths['cluster_no']
    width_all_x = c_widths['width_info_all_x']
    # newline='' is what the csv module asks for on files given to a writer;
    # without it extra blank lines can appear on platforms that translate '\n'.
    with open(result_csv+str(c_no)+'.csv', mode='w', newline='') as result_file:
        result_writer = csv.DictWriter(result_file, fieldnames=fieldname)
        result_writer.writeheader()

        for cluster_x in width_all_x:
            # Write a copy so the dicts stored in cluster_widths are not
            # modified, and tag the row with THIS cluster's number (the
            # original used the interactively selected `cluster_no` for
            # every cluster).
            info_x = dict(cluster_x['width_info'])
            info_x['no']=str(c_no)
            info_x['x']=str(cluster_x['x'])
            result_writer.writerow(info_x)

            # Slope rows are numbered from 1 within each group.
            for i, one_slope in enumerate(cluster_x['slopes_next'], start=1):
                result_writer.writerow({'no': str(i),'y': str(one_slope[0]), 'n_slope_1': to_str(one_slope[1]), 
                                        'n_slope_2': to_str(one_slope[2]), 'data': to_str(one_slope[3])})
            for i, one_slope in enumerate(cluster_x['slopes_prev'], start=1):
                result_writer.writerow({'no': str(i), 'y': str(one_slope[0]), 'p_slope_1': to_str(one_slope[1]), 
                                    'p_slope_2': to_str(one_slope[2]), 'data': to_str(one_slope[3])})    

        # Summary row: average widths of the whole cluster.
        previous_width = c_widths['avg_pwidth']
        next_width = c_widths['avg_nwidth']
        result_writer.writerow({'no': c_no, "width0": previous_width, 'width1': next_width})

## 10. write widths result to csv file

In [None]:
# Output path for the final coefficient + width table (overrides the value set
# at the top of the notebook and points into '../test_data_backup/').
result_poly_width_csv = '../test_data_backup/order_trace_test/'+power_dir+'/output/neid_poly_2sigma_gaussian_pixel_0303_'+str(power)+'.csv'
# Quick check of the coefficient array before it is written out.  The leftover
# pdb.set_trace() that used to follow stopped "Run All" at an interactive
# prompt and has been removed, together with the duplicate print after it.
print(np.shape(cluster_coeffs))
order_t.write_cluster_info_to_csv(cluster_widths, cluster_coeffs, power, result_poly_width_csv)

## (optional) test code on single column data and filtering

In [None]:
# Plot the raw flux along image column 2000 (hard-coded) against row number.
ori_spec = spe_info['data']
one_column_data = ori_spec[:, 2000]
# Rebinds the globals ny and nx from the shape of the original spectrum.
ny, nx = np.shape(ori_spec)
# Here `x` is the row number (0 .. ny-1) and `y` the flux; both overwrite any
# earlier globals of the same name.
x = np.arange(ny)
y = one_column_data
plt.figure(figsize=(12,12))
plt.plot(x, y)


In [None]:
# Filter a single column (2000) and threshold the residual.
ori_spec = spe_info['data']
one_column_data = ori_spec[:, 2000]
# NOTE(review): two different callables are used here -- a bare opt_filter()
# unpacked into three values, and order_t.opt_filter() whose result is
# subtracted directly.  Confirm both exist and that the second one returns an
# array rather than a tuple.
f, A, AINV = opt_filter(one_column_data, 20)
data= one_column_data - order_t.opt_filter(one_column_data, 20)
# `ny` is not set in this cell; it must come from an earlier cell.
x = np.arange(ny)
plt.figure(figsize=(12,12))
plt.plot(x, data)

# Rebinds the global `imm` to an all-zero uint8 mask; only column 2000 is
# filled in below.
imm = np.zeros((ny, nx), dtype=np.uint8)
# Negative (and zero) residuals are clipped to 0.
mm_pos = np.where(data>0, data, 0)
# h is half of the middle element of the sorted clipped residual.
h = 0.5*np.sort(mm_pos)[mm_pos.size//2]
print('h: ', h)
# Mark the rows whose clipped residual exceeds h + 1.
imm[:, 2000][mm_pos>(h+1)] = 1
plt.figure(figsize=(12,12))
plt.plot(x, imm[:, 2000])

In [None]:
# Plot `f`, the first value returned by opt_filter() in the previous cell,
# against the row numbers in `x`.
plt.figure(figsize=(12,12))
plt.plot(x, f)