In [None]:
import numpy as np
import os, sys
import math
from scipy import ndimage, misc
from astropy.coordinates import Angle
from astropy.io import fits
import matplotlib.pyplot as plt
from scipy import linalg
import csv
import pickle
from matplotlib.colors import LogNorm
from matplotlib import colors
import configparser
import json
%matplotlib inline
# Output sub-directory (joined as TEST_DIR + 'order_trace_test/' + power_dir)
# for this run. Other values used before: "output_neid_order_trace",
# "output_poly3_test".
power_dir = "output_paras_order_trace"
# Polynomial degree of the order-trace fit; step 1 below overwrites it with
# order_t.get_poly_degree().
power = 3

In [None]:
from order_trace import OrderTraceAlg
# Root of the local test-data checkout. Set the KPF_TEST_DIR environment
# variable instead of editing this hard-coded path. A trailing separator is
# always ensured because later cells build paths by plain string concatenation
# (TEST_DIR + 'order_trace_test/...').
TEST_DIR = os.path.join(
    os.environ.get('KPF_TEST_DIR',
                   '/Users/cwang/documents/KPF/KPF-Pipeline/AlgorithmDev/test_data/'),
    '')

In [None]:
# Input: PARAS flat, taken from dropbox KPF-Pipeline-TestData/order_trace_test.
spectral_fits = TEST_DIR + 'order_trace_test/DATA/paras.flatA.fits'

# Outputs: every product lives under <TEST_DIR>order_trace_test/<power_dir>.
# `_paths_root` is a scratch name and is deleted at the end of this cell.
# Note: the NEID path cell below assigns the same variable names, so running
# it after this one replaces these paths.
_paths_root = TEST_DIR + 'order_trace_test/' + power_dir
clusters_collection = _paths_root + '/clusters_all_y_collection.pkl'
cluster_xy_fits = _paths_root + '/cluster_xy.fits'
cluster_clean_fits = _paths_root + '/cluster_clean.fits'
cluster_info_clean_fits = _paths_root + '/cluster_info_clean.fits'
cluster_after_removal_fits = _paths_root + '/cluster_after_removal.fits'
cluster_info_after_removal_fits = _paths_root + '/cluster_info_after_removal.fits'
cluster_border_fits = _paths_root + '/cluster_border.fits'
cluster_info_border_fits = _paths_root + '/cluster_info_border.fits'
cluster_merge_fitting = _paths_root + '/cluster_merge_fitting.fits'
cluster_info_merge_fitting = _paths_root + '/cluster_info_merging_fitting.fits'
result_csv = _paths_root + '/result_cluster/result_cluster'
result_poly_width_csv = (_paths_root + '/output/paras_result_poly_2sigma_gaussian_pixel_'
                         + str(power) + '.csv')
del _paths_root

In [None]:
# Input: NEID stacked 2-fiber flat, taken from dropbox
# KPF-Pipeline-TestData/NEIData/FLAT.
spectral_fits = TEST_DIR + 'order_trace_test/DATA/stacked_2fiber_flat.fits'

# Outputs: every product lives under <TEST_DIR>order_trace_test/<power_dir>.
# `_paths_root` is a scratch name and is deleted at the end of this cell.
# Note: this cell assigns the same variable names as the PARAS path cell above,
# so whichever of the two runs last wins.
_paths_root = TEST_DIR + 'order_trace_test/' + power_dir
clusters_collection = _paths_root + '/clusters_all_y_collection.pkl'
remove_vertical_fits = _paths_root + '/data_remove_vertical.fits'
cluster_xy_fits = _paths_root + '/cluster_xy.fits'
cluster_clean_fits = _paths_root + '/cluster_clean.fits'
cluster_info_clean_fits = _paths_root + '/cluster_info_clean.fits'
cluster_after_removal_fits = _paths_root + '/cluster_after_removal.fits'
cluster_info_after_removal_fits = _paths_root + '/cluster_info_after_removal.fits'
cluster_border_fits = _paths_root + '/cluster_border.fits'
cluster_info_border_fits = _paths_root + '/cluster_info_border.fits'
cluster_merge_fitting = _paths_root + '/cluster_merge_fitting.fits'
cluster_info_merge_fitting = _paths_root + '/cluster_info_merging_fitting.fits'
result_csv = _paths_root + '/result_cluster/result_cluster'
result_poly_width_csv = (_paths_root + '/output/neid_result_poly_2sigma_gaussian_pixel_0221_'
                         + str(power) + '.csv')
del _paths_root

In [None]:
def plot_imshow(img):
    """Draw `img` with plt.imshow on the current Axes and return the AxesImage.

    An image whose maximum is 1 and minimum is 0 is treated as black/white:
    it is negated and shown on a linear gray scale, so pixels equal to 1 come
    out dark. Any other image is shown on a logarithmic gray scale (LogNorm).
    """
    if (np.amax(img) == 1) and (np.amin(img) == 0):
        print('is bw image')
        # Negate in float: images from make_2D_data are uint8, and
        # `uint8_array * -1` raises OverflowError under NumPy 2's scalar
        # promotion rules (while `-uint8_array` would silently wrap 1 to 255).
        return plt.imshow(np.asarray(img, dtype=float) * -1, cmap='gray')
    return plt.imshow(img, cmap='gray', norm=LogNorm())

In [None]:
# plot image from fits by setting area xmin, xmax, ymin, ymax
def plot_img(img, ymin, ymax, p_w=20, p_h=20, xmin=None, xmax=None, title="", aspect=None):
    """Show the whole of `img` in a new figure, zoomed to the given window.

    Parameters
    ----------
    img : 2D array drawn via plot_imshow.
    ymin, ymax : y limits, passed to plt.ylim in that order.
    p_w, p_h : figure width and height in inches.
    xmin, xmax : x limits; default to 0 and the last column index of `img`.
    title : figure title.
    aspect : optional value passed to Axes.set_aspect.
    """
    plt.figure(figsize=(p_w, p_h), frameon=False)
    ax = plt.subplot(1, 1, 1)
    if xmin is None:
        xmin = 0
    if xmax is None:
        xmax = np.shape(img)[1] - 1
    plot_imshow(img)

    plt.ylim(ymin, ymax)
    plt.xlim(xmin, xmax)
    plt.title(title)
    if aspect is not None:
        # Apply the aspect to the Axes holding the image. plt.axes() would add
        # a new, empty Axes on top of it and set the aspect there instead.
        ax.set_aspect(aspect)
    plt.show()

In [None]:
def make_fits(data, output_fits):
    """Write `data` as the primary HDU of a FITS file, replacing any existing file."""
    fits.PrimaryHDU(data).writeto(output_fits, overwrite=True)

In [None]:
# make image data in 2D based on selected clusters
def make_2D_data(index, x, y, nx, ny, selected_clusters=None):
    """Rasterize cluster pixels into an (ny, nx) uint8 mask.

    Parameters
    ----------
    index : array of cluster ids, one per pixel (0 means unassigned).
    x, y : pixel column / row coordinates, same length as `index`.
    nx, ny : output image width and height.
    selected_clusters : optional iterable of cluster ids to draw. When None,
        every pixel with a non-zero index is drawn.

    Returns
    -------
    numpy.ndarray of shape (ny, nx) and dtype uint8: 1 where a drawn pixel
    falls, 0 elsewhere. Pixels whose row is outside [0, ny) are ignored, and an
    empty selection yields an all-zero image.
    """
    index = np.asarray(index)
    x = np.asarray(x)
    y = np.asarray(y)
    imm = np.zeros((ny, nx), dtype=np.uint8)

    if selected_clusters is None:
        keep = index != 0
    else:
        keep = np.isin(index, selected_clusters)
    # Only rows that exist in the image are drawn, matching a scan over range(ny).
    keep &= (y >= 0) & (y < ny)
    if not np.issubdtype(y.dtype, np.integer):
        # An exact `y == row` match never hits non-integral row values.
        keep &= (y == np.floor(y))

    # One vectorized scatter instead of an np.where(y == row) search per row,
    # which cost O(ny * number_of_pixels).
    imm[y[keep].astype(np.intp), x[keep]] = 1
    return imm

In [None]:
# make fits on 2D of all clusters
def make_cluster_fits(index, x, y, nx, ny, fits_path=None):
    """Rasterize all clusters (non-zero index) and optionally save them as FITS.

    Parameters
    ----------
    index, x, y : per-pixel cluster id and coordinates (see make_2D_data).
    nx, ny : image width and height.
    fits_path : output FITS path; when None the image is only returned.

    Returns
    -------
    The (ny, nx) image built by make_2D_data.
    """
    imm = make_2D_data(index, x, y, nx, ny)
    msg = 'there are ' + str(np.amax(index)) + ' clusters in total'
    if fits_path is not None:
        make_fits(imm, fits_path)
        # Mention the path only when there is one: concatenating None used to
        # raise TypeError even though fits_path=None is the default.
        msg += ' in fits, ' + str(fits_path)
    print(msg)
    return imm

In [None]:
# make fits on cluster info (index, x, y)
def make_cluster_info_fits(index, x, y, cluster_info_filepath):
    """Save per-pixel cluster info as a 3 x N float64 FITS image.

    Row 0 holds the cluster index, row 1 the x coordinate and row 2 the y
    coordinate; N is index.size.
    """
    columns = (index, x, y)
    table = np.zeros((len(columns), index.size))
    for row_no, values in enumerate(columns):
        table[row_no, :] = values
    make_fits(table, cluster_info_filepath)

In [None]:
# plot polynomial fitting curve on top of given 2D image
# the cluster orders is settable by order_set
def plot_poly_trace(imm, total_order, coeffs_orders, max_x, max_y, size=20, order_set=None,
                    title=None, background=False, widths=None, aspect=None, xmin=None, xmax=None, ymin=None, ymax=None,
                    poly_degree=None):
    """Overlay fitted order-trace polynomials on the image `imm`.

    Each row `coeffs_orders[order]` is read as: polynomial coefficients in
    columns 0..degree (passed to np.polyval, so highest power first), followed
    by the x start and x end of the trace in columns degree+1 and degree+2.

    Parameters
    ----------
    imm : 2D image drawn underneath via plot_imshow.
    total_order : when `order_set` is None, orders 1..total_order are drawn.
    coeffs_orders : 2D array indexed by order number, laid out as above.
    max_x, max_y : default upper x / y plot limits. max_x also sets the
        background-curve range and the label spacing.
    size : figure width and height in inches.
    order_set : explicit list of order numbers to draw.
    title : optional figure title.
    background : when not False, also draw each polynomial over [0, max_x)
        as a blue dashed line.
    widths : optional sequence aligned with the drawn orders. widths[i][0] is
        subtracted from, and widths[i][1] added to, the trace to draw the
        lower and upper edges in green.
    aspect : optional value passed to Axes.set_aspect.
    xmin, xmax, ymin, ymax : optional plot limits (defaults 0, max_x, 0, max_y).
    poly_degree : polynomial degree of the fit. Defaults to the notebook-level
        `power` for backward compatibility.
    """
    degree = power if poly_degree is None else poly_degree
    plt.figure(figsize=(size, size))
    ax = plt.subplot(1, 1, 1)
    plot_imshow(imm)

    orders = list(range(1, total_order + 1)) if order_set is None else order_set
    x_dist = max_x // 20

    for o_idx, order in enumerate(orders):
        coeffs = coeffs_orders[order, 0:degree + 1]
        x_start = coeffs_orders[order, degree + 1]
        x_end = coeffs_orders[order, degree + 2]

        if background is not False:
            x_bg = np.arange(0, max_x)
            plt.plot(x_bg, np.polyval(coeffs, x_bg), 'b--')

        # Fitted trace over the order's own x range.
        x_val = np.arange(x_start, x_end + 1)
        y_val = np.polyval(coeffs, x_val)
        plt.plot(x_val, y_val, 'r--')

        if widths is not None:
            plt.plot(x_val, y_val - widths[o_idx][0], 'g--')
            plt.plot(x_val, y_val + widths[o_idx][1], 'g--')

        if x_val.size == 0:
            # Empty x range: there is no point to attach a label to.
            continue

        # Order-number label. The position cycles with the order number
        # (presumably so labels of neighbouring orders don't overlap), and
        # falls back to a fraction of the trace length if it runs off the end.
        s = ((order % 15) + 1) * x_dist
        if s >= x_val.size:
            dem = int((x_end - x_start) // 5)
            s = dem * ((order % 4) + 1)
        plt.text(x_val[s], y_val[s], str(order), fontsize=12, color='b', fontweight='bold',
                 horizontalalignment='center', verticalalignment='center')

    if title is not None:
        plt.title(title, fontsize=12)
    plt.ylim(0 if ymin is None else ymin, max_y if ymax is None else ymax)
    plt.xlim(0 if xmin is None else xmin, max_x if xmax is None else xmax)
    if aspect is not None:
        # Apply the aspect to the Axes holding the plot. plt.axes() would add a
        # new, empty Axes on top of it instead.
        ax.set_aspect(aspect)

    plt.show()

In [None]:
def to_str(afloat):
    """Format a number with exactly four digits after the decimal point."""
    return format(afloat, '.4f')

In [None]:
# pickle save and load (the files written here are pickle, not JSON)
def save_obj(obj, filename):
    """Serialize `obj` to `filename` with pickle (highest protocol), overwriting the file."""
    with open(filename, 'wb') as f:
        pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL)

def load_obj(filename):
    """Load and return the object pickled in `filename`.

    Security: unpickling can execute arbitrary code, so only load files that
    this notebook (or another trusted source) wrote.
    """
    with open(filename, 'rb') as f:
        return pickle.load(f)

## Usage: Using OrderTrace to extract the order trace from the given spectral FITS file (for NEID)

In [None]:
# NOTE(review): `OrderTrace` is not imported anywhere in the visible notebook
# (only `OrderTraceAlg` is imported from order_trace), so this cell raises
# NameError unless it is imported elsewhere -- TODO confirm the import.
order_t = OrderTrace(spectral_fits)
spe_info = order_t.load_spectral()
ny = spe_info['ny']
# Column ranges (v_reset) and row ranges (r_reset) handed to
# extract_order_trace as cols_to_reset / rows_to_reset.
v_reset = [[434, 451], [1930, 1945]]
r_reset = [[0, 1000], [ny-1000, ny]]
cluster_info = order_t.extract_order_trace(power, cols_to_reset=v_reset, rows_to_reset=r_reset, power_for_width = 2, show_time=True)
# Overrides the result_poly_width_csv set in the path cells above.
result_poly_width_csv  = TEST_DIR + 'order_trace_test/'+power_dir+'/output/neid_poly_3sigma_gaussian_pixel_'+str(power)+'_width_2.1.csv'
#result_poly_width_csv = TEST_DIR + 'order_trace_test/'+power_dir+'/result_poly_2sigma_gaussian_peak_'+str(power)+'.csv'
order_t.write_cluster_info_to_csv(cluster_info['widths'], cluster_info['coeffs'], power, result_poly_width_csv)

## Usage: Using OrderTrace to extract the order trace from the given spectral FITS file (for PARAS)

In [None]:
# NOTE(review): `OrderTrace` is not imported anywhere in the visible notebook
# (only `OrderTraceAlg` is imported from order_trace), so this cell raises
# NameError unless it is imported elsewhere -- TODO confirm the import.
order_t = OrderTrace(spectral_fits)
spe_info = order_t.load_spectral()
ny = spe_info['ny']
# No column/row reset ranges are passed to extract_order_trace here.
cluster_info = order_t.extract_order_trace(power, power_for_width = 3, show_time=True)
# Overrides the result_poly_width_csv set in the path cells above.
result_poly_width_csv = TEST_DIR + 'order_trace_test/'+power_dir+'/output/paras_result_poly_3sigma_gaussian_pixel_'+str(power)+'_width_3.csv'
order_t.write_cluster_info_to_csv(cluster_info['widths'], cluster_info['coeffs'], power, result_poly_width_csv)

## Extracting order trace step by step
### Execute the cells from step 1 to step 10 to get visual output for each step

## 1. load spectral file

In [None]:
# Open the flat-field FITS file and build the algorithm object from its
# primary HDU plus the [PARAM] section of the instrument config.
# NOTE(review): the HDUList from fits.open is never closed. It may need to stay
# open if order_t reads fits_header[0].data lazily in later cells -- confirm
# before wrapping this in a `with` block.
fits_header = fits.open(spectral_fits)
config = configparser.ConfigParser()
# ConfigParser.read silently skips missing files and returns the list of files
# it parsed, so fail loudly here instead of with a bare KeyError on 'PARAM'.
if not config.read('order_trace/PARAS.cfg'):
    raise FileNotFoundError("config file 'order_trace/PARAS.cfg' not found (relative to the current working directory)")
# for NEID use: config.read('order_trace/NEID.cfg')
order_t = OrderTraceAlg(fits_header[0], config['PARAM'])
imm_spec, nx, ny = order_t.get_spectral_data()
spe_info = {'data': imm_spec, 'nx': nx, 'ny': ny}
print('row: ', spe_info['ny'], ' column: ', spe_info['nx'])
plot_img(imm_spec, 0, ny-1)
# The polynomial degree from the config replaces the value set at the top.
power = order_t.get_poly_degree()
print('power: ', power)
#plot_img(imm_spec, 2000, 2100, xmin=4500, xmax=4600)
#plot_img(imm_spec, 2000, 2100, xmin=4500, xmax=4600)

## 2. Find cluster pixels and make fits

In [None]:
# Column ranges (v_reset) and row ranges (r_reset) to reset. They are kept as
# notebook globals; locate_clusters() below does not take them.
v_reset = [[434, 451], [1930, 1945]]
r_reset = [[0, 900], [ny - 900, ny]]
# True when spectral_fits is the stacked_2fiber_flat (NEID) input.
r_v = 'stacked_2fiber_flat' in spectral_fits

cluster_xy = order_t.locate_clusters()

# order_t.make_fits(cluster_xy['im_map'], cluster_xy_fits)
yy = len(cluster_xy['cluster_image'])
plot_img(cluster_xy['cluster_image'], 0, yy - 1)

## 3. Form clusters, do basic cleaning (based on size and total pixels), and make fits

In [None]:
# cluster_info is a dict, built from the cluster pixel coordinates (x, y)
# found by locate_clusters() in step 2.
cluster_info = order_t.collect_clusters(cluster_xy['x'], cluster_xy['y'])

In [None]:
# (optional) cache the collected clusters as a pickle so the following cells can
# be re-run without repeating collect_clusters(). The path is the same as the
# one set in the path cells above.
clusters_collection =  TEST_DIR + 'order_trace_test/'+power_dir+'/clusters_all_y_collection.pkl'
save_obj(cluster_info, clusters_collection)   # optional, save cluster info into .pkl file

In [None]:
# (optional) load cluster collection
# NOTE: load_obj uses pickle.load, which can execute arbitrary code; only load
# a .pkl written by this notebook.
cluster_info = load_obj(clusters_collection)   # optional, load cluster info from .pkl file

In [None]:
# assign index value to cluster_info['index'], where cluster_info['index'] is the same size as cluster_xy['x']
# remove_cluster_by_size presumably drops clusters that fail the size /
# total-pixel checks named in the step-3 heading -- TODO confirm in OrderTraceAlg.
cluster_info = order_t.remove_cluster_by_size(cluster_info, cluster_xy['x'], cluster_xy['y'])

In [None]:
# remove unassigned index
# reorganize_index returns new pixel coordinates x, y and their cluster ids (index_t).
x, y, index_t = order_t.reorganize_index(cluster_info['index'], cluster_xy['x'], cluster_xy['y'])
nx = spe_info['nx']
ny = spe_info['ny']
# NOTE: this is OrderTraceAlg.make_2d_data, not the notebook-level make_2D_data helper.
imm = order_t.make_2d_data(index_t, x, y)   # show image; the fits and info fits are written in the next cell
plot_img(imm, 0, np.shape(imm)[0]-1)

## Save data from the result of step 3

In [None]:
# Write the cleaned cluster image (rebuilt by the notebook's make_cluster_fits,
# which replaces `imm`) and the 3 x N (index, x, y) table to FITS.
imm=make_cluster_fits(index_t, x, y, nx, ny, cluster_clean_fits)
make_cluster_info_fits(index_t, x, y, cluster_info_clean_fits)

## Reload the clean fits and info fits from step 3 (optional)

In [None]:
# (optional) reload saved fits
# Restores imm, nx, ny, index_t, x and y from the files written after step 3.
imm, hdr = fits.getdata(cluster_clean_fits, header=True)
ny, nx = np.shape(imm)
# NOTE(review): `cluster_info` is rebound here to the 3 x N array written by
# make_cluster_info_fits (rows: index, x, y). Earlier cells used the same name
# for the dict returned by collect_clusters().
cluster_info, c_hdr = fits.getdata(cluster_info_clean_fits, header=True)
# The table is stored as floats; cast back to integer ids / pixel coordinates.
index_t = cluster_info[0].astype(int)
x = cluster_info[1].astype(int)
y = cluster_info[2].astype(int)

plot_img(imm, 0, np.shape(imm)[0]-1)
ind_max = np.amax(index_t)