In [None]:
from __future__ import print_function
from astropy.io import fits
import numpy as np
import math
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
from numpy.polynomial.polynomial import polyval, polyder
import time
import csv
%matplotlib inline

In [None]:
from AlgorithmDev import PolygonClipping2
# `power` is handed to PolygonClipping2.load_csv_file() when the order-trace CSV
# is read, and it is also used as a column offset (power+1, power+2) when the
# comparison cell pulls the two width columns out of the same CSV.
# Presumably this is the degree of the trace polynomial -- TODO confirm against AlgorithmDev.
power = 3

In [None]:
def plot_two_fits_trace(spectral1, spectral2, total_rows, coeffs_rows, range_rows = None):
    """Show two images side by side and overlay the order traces on the first one.

    Parameters
    ----------
    spectral1 : dict
        Must provide 'data' (image shown in the left panel), 'xdim' and 'ydim'.
        The traces are drawn on top of this image.
    spectral2 : dict
        Must provide 'data' and 'ydim'. Shown in the right panel without traces.
    total_rows : int
        Number of traces to draw; coeffs_rows[0] .. coeffs_rows[total_rows-1] are used.
    coeffs_rows : sequence
        One coefficient set per trace. Each set is evaluated with np.polyval,
        which treats the first coefficient as the highest power.
    range_rows : 2-D array, optional
        range_rows[y, 0] is the first x and range_rows[y, 1] the (exclusive)
        last x of trace y. When omitted, every trace spans 0 .. spectral1['xdim'].

    Returns
    -------
    None. The figure is left as the current matplotlib figure.
    """
    plt.figure(figsize=(20,20))

    # Left panel: first image with the traces overlaid.
    plt.subplot(1, 2, 1)
    im1 = plt.imshow(spectral1['data'], cmap='gray', norm=LogNorm())

    for y in range(0, total_rows):
        if range_rows is not None:
            x_val = np.arange(range_rows[y, 0], range_rows[y, 1])
        else:
            # Use the width of the image that was passed in, not a notebook global.
            x_val = np.arange(0, spectral1['xdim'])
        y_val = np.polyval(coeffs_rows[y], x_val)
        plt.plot(x_val, y_val, 'r--')

    plt.ylim(0, spectral1['ydim'])
    plt.colorbar(im1, fraction=0.046, pad=0.04)

    # Right panel: second image only.
    plt.subplot(1, 2, 2)
    im2 = plt.imshow(spectral2['data'], cmap='gray', norm=LogNorm())

    plt.ylim(0, spectral2['ydim'])
    plt.colorbar(im2, fraction=0.046, pad=0.04)

In [None]:
def plot_output(out_data, total_rows):
    """Display `out_data` as a gray-scale image in a 12x12 figure.

    The y axis is limited to 0 .. total_rows. No colorbar is drawn.
    """
    figure = plt.figure(figsize=(12, 12))
    axes = figure.add_subplot(1, 1, 1)
    plt.imshow(out_data, cmap='gray')
    axes.set_ylim(0, total_rows)

In [None]:
def load_spectral_sample(fits_file, order_trace_csv, flatlamp_file):
    """Load the science image, the flat-lamp image and the order-trace table.

    A PolygonClipping2 handle is built for `fits_file` and reused for every
    load. The notebook-level constant `power` is forwarded to load_csv_file().

    Returns
    -------
    dict with keys
        'spectral'          - result of load_paras_spectral() for `fits_file`
        'flatlamp_spectral' - result of load_paras_spectral(flatlamp_file)
        'coeffs', 'widths', 'xrange' - the three values returned by load_csv_file()
        'poly_handle'       - the PolygonClipping2 instance itself
    """
    clipper = PolygonClipping2(fits_file, 6)

    science = clipper.load_paras_spectral()
    print('data size: ', science['xdim'], science['ydim'])

    trace_table = clipper.load_csv_file(order_trace_csv, science['xdim'], power)
    trace_coeffs, trace_widths, trace_xrange = trace_table

    flat = clipper.load_paras_spectral(flatlamp_file)

    sample = {'spectral': science}
    sample['flatlamp_spectral'] = flat
    sample['coeffs'] = trace_coeffs
    sample['poly_handle'] = clipper
    sample['widths'] = trace_widths
    sample['xrange'] = trace_xrange
    return sample

In [None]:
def analyze_spectral(spectral, f_spectral, coeffs_rows, widths, xrange, poly_c, row_idx=None, method = 'sum_fraction'):
    """Run the extraction for every trace (or one trace) and stack the results.

    Parameters
    ----------
    spectral, f_spectral : dict
        Science and flat inputs; 'data' is read from both, 'xdim' from `spectral`.
    coeffs_rows, widths, xrange : sequences
        Per-trace inputs; element i of each is forwarded to the extraction call.
    poly_c : object
        Handle providing rectify_spectral_curve_by_fractional_sum(),
        rectify_spectral_curve_by_optimal2() and fill_2D_to_2D().
    row_idx : int, optional
        When given, only that trace is processed; otherwise all traces are.
    method : str
        Any value containing 'sum_fraction' selects the fractional-sum call;
        everything else selects the optimal2 call.

    Returns
    -------
    dict with 'out_data' (array of shape (number of traces, xdim), zero where
    nothing was filled in) and 'dim' ({'height': ..., 'width': ...}).
    """
    n_traces = len(coeffs_rows)
    science_img = np.array(spectral.get('data'), dtype=None)
    flat_img = np.array(f_spectral.get('data'), dtype=None)
    stacked = np.zeros((n_traces, spectral.get('xdim')))

    # Resolve the method name once; the bound method is still looked up per trace.
    if 'sum_fraction' in method:
        extractor_name = 'rectify_spectral_curve_by_fractional_sum'
    else:
        extractor_name = 'rectify_spectral_curve_by_optimal2'

    selected = range(n_traces) if row_idx is None else (row_idx,)
    for trace in selected:
        t0 = time.time()
        print(trace, 'widths: ', widths[trace], ' method: ', method)
        extraction = getattr(poly_c, extractor_name)(
            coeffs_rows[trace], widths[trace], xrange[trace],
            science_img, flat_img, verbose=False)
        t1 = time.time()
        print('extraction: ', trace, ' time: ', (t1-t0), ' widths: ', widths[trace])
        poly_c.fill_2D_to_2D(extraction.get('out_data'), stacked, 0, trace)

    print('done')
    dims = {'height': n_traces, 'width': spectral['xdim']}
    return {'out_data': stacked, 'dim': dims}
    

In [None]:
def make_fits(data, output_fits, MJD = None):
    """Write `data` to `output_fits` as the primary HDU of a FITS file.

    If `MJD` is supplied it is converted with str() and stored under the
    'MJD-OBS' header keyword. No overwrite flag is passed to writeto(), so
    astropy's default behaviour for an existing file applies.
    """
    primary = fits.PrimaryHDU(data)
    has_mjd = MJD is not None
    if has_mjd:
        primary.header['MJD-OBS'] = str(MJD)
    primary.writeto(output_fits)

In [None]:
def spectral_update(spectrall, correct_file, poly_handle=None):
    """Replace spectrall['data'] with a version corrected by `correct_file`.

    Parameters
    ----------
    spectrall : dict
        Spectral dict whose 'data' entry is read and then overwritten in place.
    correct_file : str
        Path handed to poly_handle.load_correct_data().
    poly_handle : object, optional
        Handle providing load_correct_data() and correct_data_by_sub().
        Defaults to the notebook-level `poly_c` so existing two-argument
        calls keep working.

    Returns
    -------
    None. `spectrall` is modified in place.

    Notes
    -----
    The correction itself is whatever correct_data_by_sub(correct_data, in_data)
    returns -- presumably a subtraction of the correction image; confirm
    against AlgorithmDev.
    """
    if poly_handle is None:
        # Fall back to the handle created by the loading cell.
        poly_handle = poly_c

    # Read the image from the dict that was passed in, not from the notebook
    # global `spectral`; otherwise the argument would be ignored on input.
    in_data = np.array(spectrall.get('data'), None)
    correct_data = poly_handle.load_correct_data(correct_file)
    new_data = poly_handle.correct_data_by_sub(correct_data, in_data)
    spectrall.update({'data': new_data})

In [None]:
def extract_optimal_trace(in_data, start_idx, total_order=None):
    """Return a contiguous band of rows from a 2-D array.

    Parameters
    ----------
    in_data : 2-D numpy array
        Rows are selected along the first axis; all columns are kept.
    start_idx : int
        First row to keep. Negative values are clamped to 0.
    total_order : int, optional
        Number of rows to keep. When omitted, every row from `start_idx`
        to the end is returned. The band is clamped to the number of rows.

    Returns
    -------
    The slice in_data[start:end, :] (a view when `in_data` is an ndarray).

    Raises
    ------
    ValueError
        If `in_data` is not 2-D.
    """
    # np.shape gives (rows, columns); the row count bounds the slice.
    n_rows, _ = np.shape(in_data)
    first_row = max(start_idx, 0)
    if total_order is None:
        last_row = n_rows
    else:
        last_row = min(first_row + total_order, n_rows)

    return in_data[first_row:last_row, :]

## 1. define and load files: spectrum file, flat file, cure file, coeffs/width file

In [None]:
# Input locations.
#   fits_base / flats_base: from Dropbox, KPF-Pipeline-TestData/polygon_clipping_test/paras_data
#   csv_base: output of order_trace_width_test_neid
#
# PARAS alternative, kept for reference (swap in to process PARAS data):
#   fits_base  = '../test_data/paras_data/a00'
#   flats_base = '../test_data/paras_data/paras.flat'
#   csv_base   = '../test_data/paras_data/order_trace_'
#   csv_base   = '../test_data/order_trace_test/for_optimal_extraction/'
#   fits_list  = ['18', '19']
#   fiber_name = fiber_list[c]   # a0018.fits + paras.flatA.fits + csv created by order_trace_test
#   csv_file   = '../test_data/paras_data/order_trace_'+fiber_name+'.csv'
#   csv_file   = csv_base + 'paras_result_poly_2sigma_gaussian_pixel_3.csv'
#
# Earlier NEID trace table:
#   csv_file   = csv_base + 'neid_result_poly_2sigma_gaussian_pixel_0221_3.csv'

# NEID data set currently in use.
fits_base = '../test_data_02242020/NEIDdata/TAUCETI_20191217/L0/neidTemp_2D20191217T'
flats_base = '../test_data_02242020/NEIDdata/FLAT/stacked_2fiber_flat'
csv_base = '../test_data_02242020/order_trace_test/for_optimal_extraction/'
output_base = '../test_data_02242020/order_trace_test/for_optimal_extraction/output/'

fits_list = ['023129', '023815', '024240', '024704',
             '030057', '030724', '031210', '031636']
fiber_list = ['A']

# f selects the exposure from fits_list; c would select the fiber for PARAS data.
f = 0
c = 0
fiber_name = ''
fits_name = fits_list[f]

fits_file = ''.join([fits_base, fits_name, '.fits'])
csv_file = ''.join([csv_base, 'neid_poly_2sigma_gaussian_pixel_0303_3.csv'])
flatlamp_file = ''.join([flats_base, fiber_name, '.fits'])

# Extraction method: 'sum_fraction' or 'optimal'.
method = 'sum_fraction'

# Output names, e.g. <fiber>_<exposure>_cure_optimal.fits for the bleeding-corrected run.
output_cure_fits = ''.join([output_base, fiber_name, '_', fits_name, '_cure_optimal.fits'])
output_original_fits = ''.join([output_base, fiber_name, 'NEID_', fits_name,
                                '_original_', method, '.fits'])
cure_fits = "../test_data/order_trace_test/for_optimal_extraction/14feb2015/bleeding_cure_14feb2015_1800.fits"
output_fits = output_original_fits

print('fits_file:', fits_file, '\ncsv_file', csv_file, '\noutput file: ', output_fits)


In [None]:
# Load the science image, flat-lamp image and trace table, then preview both
# images with the traces drawn over the science image.
sample_info = load_spectral_sample(fits_file, csv_file, flatlamp_file)

(spectral,
 flatlamp_spectral,
 coeffs_rows,
 poly_c,
 range_rows) = [sample_info.get(key) for key in
                ('spectral', 'flatlamp_spectral', 'coeffs', 'poly_handle', 'xrange')]

n_traces = np.shape(coeffs_rows)[0]
plot_two_fits_trace(spectral, flatlamp_spectral, n_traces, coeffs_rows, range_rows)

## 1.1 (optional) correct the data with the bleeding-cure file (for PARAS data)

In [None]:
# Optional step: switch the output name to the "cure" variant and replace
# spectral['data'] in place with the corrected image produced by spectral_update().
# NOTE(review): the section header marks this as PARAS-only and cure_fits points
# at a 14feb2015 bleeding file; with the NEID configuration selected above this
# cell should presumably be skipped -- confirm before using "Run All".
output_fits = output_cure_fits
print('output_fits:', output_fits)
spectral_update(spectral, cure_fits)

## 2. optimal extraction analysis

In [None]:
# Run the selected extraction method over every trace and show the stacked result.
widths, xrange = sample_info.get('widths'), sample_info.get('xrange')

result_optimal = analyze_spectral(spectral, flatlamp_spectral, coeffs_rows,
                                  widths, xrange, poly_c, method=method)

rectified = result_optimal.get('out_data')
n_rows = result_optimal.get('dim').get('height')
plot_output(rectified, n_rows)

## 3. extract row from analysis result 

In [None]:
# Select the rows to write. To write a partial result instead, use e.g.
#   output_file = output_base + fiber_name + '_' + fits_name + '_sum_fraction_from_2.fits'
#   out_data = extract_optimal_trace(in_data, 2)
in_data = result_optimal['out_data']
output_file = output_original_fits
out_data = extract_optimal_trace(in_data, 0)

# Read the primary header of the flat-lamp file; the context manager closes
# the file handle, which the bare fits.open() call never did.
# NOTE(review): the observation time is taken from the flat-lamp header, not
# from the science exposure -- confirm this is intended.
with fits.open(flatlamp_file) as flat_header:
    header = flat_header[0].header
header_keys = list(header.keys())

# Observation date as MJD: convert from a Julian date if 'OBSJD' is present,
# otherwise read 'OBS MJD', otherwise fall back to the notebook's fixed value.
# The elif must test the same keyword it reads ('OBS MJD'); testing the
# misspelt 'OBS MJS' made this branch unreachable or a KeyError.
if 'OBSJD' in header_keys:
    mjd = header['OBSJD'] - 2400000.5
elif 'OBS MJD' in header_keys:
    mjd = header['OBS MJD']
else:
    mjd = 58831.009653

# Exposure time, with the notebook's fixed fallback of 600.0.
exptime = header['EXPTIME'] if 'EXPTIME' in header_keys else 600.0

hdu = fits.PrimaryHDU(out_data)
hdu.header['MJD-OBS'] = mjd
hdu.header['EXPTIME'] = exptime
hdu.writeto(output_file, overwrite=True)

## 4. Comparison between the NEID L1 spectrum and the result of fractional-sum or optimal extraction

In [None]:
# Compare each NEID L1 order with the matching row of the extraction output.
# Alternative input:
#   my_file = '../test_data/order_trace_test/for_optimal_extraction/output/_023129_original_optimal.fits'
neid_file = '../test_data_02242020/NEIDdata/TAUCETI_20191217/L1/neidL1_20191217T023129.fits'
my_file = '../test_data_02242020/order_trace_test/for_optimal_extraction/output/NEID_023129_original_sum_fraction.fits'
my_csv = '../test_data_02242020/order_trace_test/for_optimal_extraction/neid_poly_2sigma_gaussian_pixel_0303_3.csv'
neid_fits, neid_header = fits.getdata(neid_file, header=True)
my_fits, my_header = fits.getdata(my_file, header=True)

# The two width columns sit right after the first power+1 columns of each CSV row.
with open(my_csv) as order_csv:
    order_widths = [[row[power+1], row[power+2]] for row in csv.reader(order_csv)]

# NEID order i is compared with extraction row (i-d)*2.
# NOTE(review): the offset d and the factor 2 are taken from the original
# notebook; their origin (presumably two fibers per order) should be confirmed.
d = 7
neid_size = np.shape(neid_fits)
my_size = np.shape(my_fits)

# Last NEID order that exists in both files. The NEID bound must be the last
# valid row index (rows - 1), because the loop below includes last_order itself.
last_order = min(neid_size[0] - 1, (my_size[0]//2) + d - 1)
print('neid: ', np.shape(neid_fits))
print('my: ', np.shape(my_fits))
print('size_y: ', last_order)

# First column plotted; columns before it are left out of every comparison.
s_y = 450
for i in range(d, last_order+1):
    my_row = (i-d)*2

    neid_order = neid_fits[i, s_y:]
    my_order = my_fits[my_row, s_y:]

    plt.figure(figsize=(12,12))
    plt.plot(neid_order, 'b--', label='order of neid L1: '+str(i))
    plt.plot(my_order, 'r--', label = 'order of OrderTrace: ' + str(my_row))
    plt.title('neid L1 vs. fractional sum from results of order_trace_width_test_neid '+\
              '  widths: ('+ str(order_widths[my_row][0]) + ',' + str(order_widths[my_row][1]) +')'+\
              '  x: ['+str(s_y)+','+str(neid_size[1])+']' )
    plt.legend(loc="upper right", prop={'size': 12})
    plt.show()