In [None]:
import glob
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from scipy import fftpack,signal
import matplotlib
from astropy.visualization import (MinMaxInterval, SqrtStretch, ImageNormalize, ZScaleInterval)
import scipy
import json
import subprocess
import os

In [None]:
class crires_simu:
    """Run the CRIRES ETC over a range of aperture sizes and turn its output into synthetic focal planes.

    The ETC script (``path_etc_local``, invoked as ``python <path_etc_local> CRIRES ...``) is run once
    per aperture width in ``aperture_list`` = 1, 3, 5, ... <= ``aperture_size`` pixels. Differencing
    the counts of successive apertures gives the counts falling on each extra pair of pixel rows;
    ``plot_signal`` and ``combine`` paint those rows symmetrically around the central row of a 2-D
    ``focal_plane`` array.

    Parameters
    ----------
    path_etc_local : str
        ETC command-line script.
    path_etc_input : str
        ETC JSON input template.
    path_input_modi : str
        Scratch path where the per-aperture copy of the input is written.
    path_etc_output : str
        Output root. It must end with a path separator, because paths are built by concatenation.
    aperture_size : int
        Largest aperture width in pixels; odd values are expected.
    date : str
        Tag embedded in output file names.
    order_number, order_window : list, optional
        Order numbers and matching ``[min, max]`` wavelength windows, in nm.
    target : {'star', 'planet'}
    model_name : str, optional
        Spectrum file passed to the ETC with ``-u`` for planet runs.
    planet_p : int, optional
        Planet offset from the central row, in pixel rows.
    """

    # Header line written at the top of each converted .txt table, keyed by ETC product name.
    # Any other product gets _TXT_DEFAULT_HEADER.
    _TXT_HEADERS = {
        'psf': "angle (arcsec)\tPSF (norm)\n",
        'snr_snr': 'wavelength (m)\tsnr\n',
        'sed_target': 'wavelength (m)\tflux (Jsm-2m)\n',
    }
    _TXT_DEFAULT_HEADER = "wavelength (m)\tcounts\n"

    #------------------

    def __init__(self, path_etc_local, path_etc_input, path_input_modi,
                 path_etc_output, aperture_size, date, order_number=None, order_window=None,
                 target=None, model_name=None, planet_p=None):

        if target not in ['star', 'planet']:
            print ('Target not valid, please choose one from "star" or "planet" ')

        self.target = target
        self.path_etc_local = path_etc_local
        self.path_input = path_etc_input
        self.path_input_modi = path_input_modi
        self.path_output = path_etc_output
        self.model_name = model_name
        self.date = date

        # Fresh lists per instance; a shared mutable default would leak between instances.
        self.order_num = [] if order_number is None else order_number
        self.order_wave = [] if order_window is None else order_window

        self.planet_p = planet_p

        if aperture_size % 2 == 0:
            print ("Warning: It's better to use an odd number, the final size is %s" % (aperture_size - 1))

        self.aperture_size = aperture_size
        # Aperture widths 1, 3, 5, ... <= aperture_size; each step adds one row on either side.
        self.aperture_list = np.arange(1, self.aperture_size + 1, 2).tolist()

    #------------------

    def _ascii_dir(self, aperture):
        """Directory holding one aperture's products: ``<path_output>ascii_<aperture>/``."""
        return self.path_output + 'ascii_%s/' % aperture

    def _json_path(self, aperture):
        """ETC JSON output path for one aperture; planet runs use an ``output_p_`` prefix."""
        prefix = 'output_p_' if self.target == 'planet' else 'output_'
        return self._ascii_dir(aperture) + prefix + '%s_%s_AO.json' % (self.date, aperture)

    @staticmethod
    def _select_window(table, lo, hi):
        """Rows of ``table`` whose ``wavelength(nm)`` lies in the closed interval [lo, hi]."""
        wave = table['wavelength(nm)']
        return table.loc[(wave <= hi) & (wave >= lo)]

    @staticmethod
    def _column(table, x):
        """The ``dat_diff_<x>`` column of ``table`` as a float array."""
        return table['dat_diff_%s' % x].to_numpy(dtype=float)

    #------------------

    def output_json(self):
        """Run the ETC once per aperture width and save each result in that aperture's directory.

        For every aperture, the input template is written to ``path_input_modi`` with
        ``seeingiqao.aperturepix`` set to the aperture. The ETC is then called on it; planet runs
        also pass ``-u model_name``. No ETC call is made for an invalid target.
        """
        with open(self.path_input, 'r') as f:
            data = json.load(f)

        for i in self.aperture_list:
            data['seeingiqao']['aperturepix'] = float(i)
            with open(self.path_input_modi, 'w') as f:
                json.dump(data, f)

            # Unlike a bare `mkdir`, makedirs also creates path_output itself and tolerates re-runs.
            os.makedirs(self._ascii_dir(i), exist_ok=True)

            if self.target not in ('star', 'planet'):
                continue
            cmd = ['python', self.path_etc_local, 'CRIRES', self.path_input_modi]
            if self.target == 'planet':
                cmd += ['-u', self.model_name]
            subprocess.call(cmd + ['-o', self._json_path(i)])

        print ('the ETC calculation is done, the output file is saved at %s' % self.path_output)

    #------------------

    def json2ascii(self, path_json2ascii):
        """Copy the JSON-to-ascii converter into every aperture directory and run it on that
        aperture's ETC output. Does nothing for an invalid target."""
        if self.target not in ('star', 'planet'):
            return
        for i in self.aperture_list:
            out_dir = self._ascii_dir(i)
            subprocess.call(['cp', '-r', path_json2ascii, out_dir])
            subprocess.call(['python', out_dir + 'etc_json2ascii.py', self._json_path(i)])

    #------------------

    def ascii2txt(self, output, order_number=None, detectors=(1, 2, 3)):
        """Convert each ``<json>_order:<x>_det:<j>_<m>.ascii`` file into a tab-separated table.

        The first two lines of the ascii file are dropped and a header chosen by product name
        is written instead. Output goes to ``ascii_<n>/converted_data_<x>_<m>_<j>.txt``.
        Missing inputs are reported and skipped.

        Parameters
        ----------
        output : iterable of str
            Product names (``m``).
        order_number : list, optional
            Orders to convert; defaults to ``self.order_num``.
        detectors : iterable of int
            Detector numbers; defaults to 1, 2, 3.

        Raises
        ------
        ValueError
            If ``target`` is neither 'star' nor 'planet'.
        """
        if order_number is None:
            order_number = self.order_num
        if self.target not in ('star', 'planet'):
            raise ValueError('target must be "star" or "planet", got %r' % (self.target,))

        for i in self.aperture_list:
            stem = self._json_path(i)
            for j in detectors:
                for x in order_number:
                    for m in output:
                        src = stem + '_order:%s_det:%s_%s.ascii' % (x, j, m)
                        try:
                            with open(src, 'r') as f:
                                rows = [line.split() for line in f.readlines()[2:]]
                        except FileNotFoundError:
                            print('File Not Found: %s-%s-%s-%s' % (i, x, j, m))
                            continue

                        dst = self._ascii_dir(i) + 'converted_data_%s_%s_%s.txt' % (x, m, j)
                        with open(dst, 'w') as f:
                            f.write(self._TXT_HEADERS.get(m, self._TXT_DEFAULT_HEADER))
                            for row in rows:
                                f.write("\t".join(row) + "\n")

                    print('txt_done_%s_%s_%s' % (i, x, j))

    #------------------

    def signal(self, detectors, contribution=None, nor=None, focal_plane=None):
        """Per-row counts for every aperture step, stacked over detectors and orders.

        For each aperture, this reads ``converted_data_<order>_signals_obs<contribution>_<det>.txt``
        for every detector (1..detectors) and every order in ``self.order_num``, converting
        wavelength from m to nm. ``dat_diff_0`` is the 1-pixel aperture's counts.
        ``dat_diff_n = (counts[n] - counts[n-1]) / 2`` is the share of each of the two rows added
        at step n.

        Parameters
        ----------
        detectors : int
        contribution : {'sky', 'target', 'tot'}
        nor : unused; kept for backward compatibility.
        focal_plane : tuple
            Shape of the zeroed array stored on ``self.focal_plane``.

        Returns
        -------
        pandas.DataFrame
            One row per pixel, sorted by ``wavelength(nm)``. Columns are ``dat_diff_0 ...`` plus
            the ``index`` and ``level_0`` columns added by ``reset_index``. ``index`` is the
            segment's stacking position, detector-major: ``(det-1)*len(order_num) + order position``.
        """
        self.focal_plane = np.zeros(focal_plane)
        self.detectors = detectors

        if contribution not in ['sky', 'target', 'tot']:
            print('Stupid boy, you need to choose one from [sky, target, tot]')

        # One {'wavelength(nm)': [...], 'counts': [...]} per aperture; one list entry per segment.
        frames = []
        for i in self.aperture_list:
            waves, counts = [], []
            for nu in range(1, detectors + 1):
                for order in self.order_num:
                    data_sig = pd.read_table(self._ascii_dir(i) + 'converted_data_%s_signals_obs%s_%s.txt' % (order, contribution, nu))
                    waves.append(data_sig['wavelength (m)'].values * 1e9)
                    counts.append(data_sig['counts'].values)
            frames.append(pd.DataFrame(data={'wavelength(nm)': waves, 'counts': counts}))

        wave_count_0 = frames[0]
        for n in range(1, len(frames)):
            cur, prev = frames[n], frames[n - 1]
            # Differencing only makes sense if both apertures share the same wavelength grid.
            grids_match = all(
                np.shape(a) == np.shape(b) and np.allclose(a, b)
                for a, b in zip(cur['wavelength(nm)'], prev['wavelength(nm)'])
            )
            if not grids_match:
                print('Wrong result! The wavelength range is not matched')
            wave_count_0['dat_diff_%s' % n] = (cur['counts'] - prev['counts']) / 2

        # Rename, explode the per-segment arrays into rows, then sort by ascending wavelength.
        wave_count_0 = wave_count_0.rename(columns={'counts': 'dat_diff_0'})
        name_list = wave_count_0.columns.to_list()
        wave_count_ex = wave_count_0.explode(name_list).reset_index()
        wave_count_sort = wave_count_ex.sort_values(by='wavelength(nm)').reset_index()

        return wave_count_sort

    #------------------

    def plot_signal(self, wave_count_sort, max_percentile, save_path, interv, minimum=None, Nor=None, plot_combination=None):
        """Paint each order's per-row signal onto ``self.focal_plane`` and save one panel per order.

        Parameters
        ----------
        wave_count_sort : pandas.DataFrame
            Output of ``signal``.
        max_percentile : float
            Upper display limit, as a percentile of the painted plane.
        save_path : str
            Where the figure is saved.
        interv : astropy interval
            Used by ``ImageNormalize`` when ``Nor`` is None or True.
        minimum : float, optional
            Lower display limit; defaults to each order's own minimum.
        Nor : {None, True, False}
            None or True applies ``ImageNormalize``; False uses plain linear vmin/vmax.
        plot_combination : bool, optional
            If True and ``planet_p`` is set, draw a dashed line at the planet row.
        """
        font = {'size': 4}
        plt.rcParams.update({'font.size': font['size']})

        plt.ion()  # enable the interactive mode
        row = len(self.order_num)
        # squeeze=False keeps `axs` 2-D, so a single order still indexes as axs[0, 0].
        fig, axs = plt.subplots(row, 1, dpi=500, sharex=True, sharey=True, squeeze=False)
        focal_plane = self.focal_plane
        central_pix = int(np.ceil(focal_plane.shape[0] / 2))
        pixel_posi = len(self.aperture_list)
        max_order = np.max(self.order_num)

        for n_order in range(row):
            lo, hi = self.order_wave[n_order]
            one_order = self._select_window(wave_count_sort, lo, hi)

            # Clear the previous order so a shorter order does not inherit its trailing columns.
            focal_plane[:] = 0
            for x in range(pixel_posi):
                profile = self._column(one_order, x)
                focal_plane[central_pix + x, :len(one_order)] = profile
                focal_plane[central_pix - x, :len(one_order)] = profile

            ff = axs[n_order, 0]
            ceil = np.percentile(focal_plane, max_percentile)
            # Use a per-order lower limit unless the caller fixed one.
            vmin = np.min(focal_plane) if minimum is None else minimum
            norm = ImageNormalize(focal_plane, interval=interv, vmax=ceil, vmin=vmin)

            ff.set_title('order%s: %s nm - %s nm' % (max_order - n_order, lo, hi), fontsize=5)
            if Nor is None or Nor == True:
                im = ff.imshow(focal_plane, norm=norm, cmap='Greys_r')
                fig.colorbar(im, ax=ff, shrink=0.5, aspect=60, location='bottom', pad=0.2)
            elif Nor == False:
                im = ff.imshow(focal_plane, vmin=vmin, vmax=ceil, cmap='Greys_r')
                fig.colorbar(im, ax=ff, shrink=0.3, aspect=50, location='bottom', pad=0.2)
            else:
                print('Stupid sweet, you just passed a wrong value to Nor. It should be on from True, Flase or None')

            ff.set_aspect(8)
            ff.axis('off')
            if plot_combination == True and self.planet_p is not None:
                ff.hlines(y=central_pix + self.planet_p, xmin=0, xmax=focal_plane.shape[1], ls='--', colors='gray', linewidth=0.8)

        fig.subplots_adjust(hspace=1)
        reach = pixel_posi - 1  # offset of the outermost painted row
        axs[-1, 0].set_ylim(central_pix - reach - 1, central_pix + reach + 1)
        fig.suptitle('focal plane %s \n$e^-$/pix/exposure' % str(focal_plane.shape), x=0.5, fontsize=9)

        fig.savefig(save_path)
        plt.close(fig)

    #------------------

    def combine(self, data_star, data_p, plot_combination=None, d=None, sky=True, noise=True,
                data_sky=None, ron=None, dark=None, focal_plane=None, plot_save_path=None):
        """Build one focal plane per order from star, optional sky, optional noise, and planet.

        The star, sky and noise profiles are added to rows ``central ± x``. The planet profile is
        added to rows ``central + d ± x``, for x with ``d + x <= len(aperture_list)``. Each row is
        painted once per x, including the central row. Each order starts from a zero plane.

        Parameters
        ----------
        data_star, data_p, data_sky : pandas.DataFrame
            Outputs of ``signal``. ``data_sky`` is required when ``sky`` is True.
        plot_combination : bool, optional
            If True, draw one panel per order.
        d : int, optional
            Planet offset in rows. It is ignored (with a message) when ``planet_p`` is already set.
        sky, noise : bool
        ron, dark : float
            Added to every painted row as ``ron + dark`` when ``noise`` is True.
        focal_plane : tuple, optional
            New plane shape; otherwise ``self.focal_plane``'s shape is used.
        plot_save_path : str, optional
            If given together with ``plot_combination``, the figure is saved there.

        Returns
        -------
        (numpy.ndarray, Figure or None)
            Array of shape (n_orders, rows, cols), and the figure (None if not plotted).

        Raises
        ------
        ValueError
            If no planet position is available, or sky/noise inputs are missing.
        """
        if noise not in [True, False]:
            print('Stupid, you just passed a nonsense value to noise. It should be one from [None, True, False]')

        if sky not in [True, False]:
            print('Stupid, you just passed a nonsense value to sky contribution. It should be one from [None, True, False]')

        if self.planet_p is None:
            self.planet_p = d
        else:
            if d is not None and d != 0 and d != self.planet_p:
                print ('Stupid! you already choose one position in the initial function. I will not change the value!')
            d = self.planet_p
        if d is None:
            raise ValueError('combine() needs a planet position: pass d or set planet_p.')
        if sky == True and data_sky is None:
            raise ValueError('sky=True requires data_sky.')
        if noise == True and (ron is None or dark is None):
            raise ValueError('noise=True requires both ron and dark.')

        if focal_plane is not None:
            self.focal_plane = np.zeros(focal_plane)

        n_rows, n_cols = self.focal_plane.shape
        row = len(self.order_num)
        pixel_posi = len(self.aperture_list)
        central_pix = int(np.ceil(n_rows / 2))
        noi = ron + dark if noise == True else 0.0

        signal_tot = np.zeros((row, n_rows, n_cols))
        for n_order in range(row):
            lo, hi = self.order_wave[n_order]
            one_order_s = self._select_window(data_star, lo, hi)
            one_order_p = self._select_window(data_p, lo, hi)
            one_order_k = self._select_window(data_sky, lo, hi) if sky == True else None
            # Noise spans the sky columns when sky is present, otherwise the star columns.
            n_noise = len(one_order_k) if one_order_k is not None else len(one_order_s)

            plane = signal_tot[n_order]  # view: writes land directly in signal_tot
            for x in range(pixel_posi):
                # A set, so the central row (x == 0) is painted only once.
                for r in {central_pix + x, central_pix - x}:
                    plane[r, :len(one_order_s)] += self._column(one_order_s, x)
                    if one_order_k is not None:
                        plane[r, :len(one_order_k)] += self._column(one_order_k, x)
                    if noise == True:
                        plane[r, :n_noise] += noi

                # Adding (not assigning) means later star rows cannot overwrite planet flux.
                if d + x <= pixel_posi:
                    for r in {central_pix + d + x, central_pix + d - x}:
                        plane[r, :len(one_order_p)] += self._column(one_order_p, x)

        fig = None
        if plot_combination == True:
            width = pixel_posi
            max_order = np.max(self.order_num)
            fig, axs = plt.subplots(row, 1, dpi=200, sharex=True, sharey=True, squeeze=False)

            for n_order in range(row):
                plane = signal_tot[n_order]
                lo, hi = self.order_wave[n_order]
                ceil = np.percentile(plane, 99.)
                nor = ImageNormalize(plane, interval=ZScaleInterval(contrast=0.1), vmax=ceil, vmin=np.min(plane))

                ff = axs[n_order, 0]
                im = ff.imshow(plane, norm=nor, cmap='Greys_r')
                ff.set_title('order%s: %s nm - %s nm' % (max_order - n_order, lo, hi), fontsize=5)
                fig.colorbar(im, ax=ff, shrink=0.3, aspect=50, location='bottom', pad=0.2)

                ff.hlines(y=central_pix + d, xmin=0, xmax=n_cols, ls='--', colors='gray', linewidth=0.8)
                ff.axis('off')
                ff.set_aspect(8)
                ff.set_ylim(central_pix - width, central_pix + width)

            fig.subplots_adjust(hspace=1.3)
            fig.suptitle('focal plane (%s*%s)\n$e^-$/pix/exposure' % (n_rows, n_cols), x=0.5, fontsize=9)
            if plot_save_path is not None:
                fig.savefig(plot_save_path)

        return (signal_tot, fig)
    
#-----------------------------------------    

def load_json(path, permit):
    """Open ``path`` with mode ``permit`` (converted to str) and return the parsed JSON content."""
    mode = '%s' % permit
    with open(path, mode) as handle:
        return json.load(handle)

In [None]:

class planet_simu:
    """Post-process star and planet spectra and the dispersed tables produced by ``crires_simu.signal``.

    Parameters
    ----------
    data_planet, data_star : pandas.DataFrame, optional
        Spectra for ``color_index``. Only ``data_planet`` is read, through its ``wave`` and
        ``flux`` attributes (presumably columns).
    disper_planet, disper_star : pandas.DataFrame, optional
        Tables as returned by ``crires_simu.signal`` (or re-read from its CSV), used by ``p_2_s``.
    """

    def __init__(self, data_planet=None, data_star=None, disper_planet=None, disper_star=None):

        self.flux_planet = data_planet
        self.flux_star = data_star
        self.disper_planet = disper_planet
        self.disper_star = disper_star

    def color_index(self, band1_w, band2_w):
        """Colour index ``-2.5 log10(F1 / F2)`` of the planet spectrum.

        Each band is an open ``(min, max)`` wavelength interval. F is the trapezoidal integral of
        ``flux`` over ``wave`` inside that interval.
        """
        flux = self.flux_planet
        # np.trapz was renamed np.trapezoid in NumPy 2.0; use whichever name this NumPy provides.
        integrate = getattr(np, 'trapezoid', None) or np.trapz

        def band_flux(band):
            # Parenthesised masks: `&` binds tighter than the comparisons.
            in_band = flux[(flux.wave > band[0]) & (flux.wave < band[1])]
            return integrate(in_band.flux.to_numpy(dtype=float), in_band.wave.to_numpy(dtype=float))

        c_index = -2.5 * np.log10(band_flux(band1_w) / band_flux(band2_w))
        return c_index

    def _plot_ratio_grid(self, ratio, order_num, detectors, title, label):
        """Plot ``ratio`` on an (orders x detectors) grid of log-scale panels and return the figure."""
        n_orders = len(order_num)
        fig, axs = plt.subplots(n_orders, detectors, squeeze=False)
        for i in range(n_orders):
            for j in range(detectors):
                ax = axs[i, j]
                # signal() stacks segments detector-major, so order i / detector j is stored
                # under index j*n_orders + i. This assumes order_num is the list signal() used.
                sel = self.disper_planet['index'] == j * n_orders + i
                # x and y come from the same rows, so they always pair up.
                ax.plot(self.disper_planet.loc[sel, 'wavelength(nm)'], ratio[sel],
                        linewidth=0.6, alpha=0.7, label=label)
                if i == n_orders - 1:
                    ax.set_xlabel('wavelength (nm)')
                if j == 0:
                    ax.set_ylabel('$f_p/f_s$')
                ax.set_title('order:%s, detector:%s' % (order_num[i], j + 1))
                ax.set_yscale('log')
                ax.xaxis.set_major_locator(plt.MaxNLocator(nbins=3))
        fig.suptitle(title)
        fig.subplots_adjust(hspace=1.3, wspace=0.3)
        return fig

    def p_2_s(self, order_num, plot=True, planet_posi=None, detectors=3):
        """Planet-to-star flux ratios, optionally plotted per order and detector.

        ``p_2_s_intr`` is planet ``dat_diff_0`` over star ``dat_diff_0``. If ``planet_posi`` is
        given, ``p_2_s_planet`` is planet ``dat_diff_0`` over star ``dat_diff_<planet_posi>``, i.e.
        the stellar signal in the row at the planet offset; otherwise it is an empty list.

        Returns
        -------
        tuple
            ``(p_2_s_intr, p_2_s_planet)`` if ``plot`` is not True. Otherwise the figure is
            appended, plus a second figure when ``planet_posi`` is given.
        """
        p_2_s_intr = self.disper_planet['dat_diff_0'] / self.disper_star['dat_diff_0']

        if planet_posi is not None:
            p_2_s_planet = self.disper_planet['dat_diff_0'] / self.disper_star['dat_diff_%s' % planet_posi]
        else:
            p_2_s_planet = []

        if plot != True:
            return (p_2_s_intr, p_2_s_planet)

        font = {'size': 4}
        plt.rcParams.update({'font.size': font['size']})

        fig = self._plot_ratio_grid(p_2_s_intr, order_num, detectors,
                                    'planet-to-star flux ratio', 'intrinsic flux ratio')
        if planet_posi is None:
            return (p_2_s_intr, p_2_s_planet, fig)

        fig1 = self._plot_ratio_grid(p_2_s_planet, order_num, detectors,
                                     'planet-to-star flux ratio at planet position',
                                     'flux ratio at planet position')
        return (p_2_s_intr, p_2_s_planet, fig, fig1)
    

In [None]:
# List the ETC output products ("psf" and "<section>_<product>", lower-cased) that ascii2txt()
# should convert. The original planet input is read, not modi_input_0607_p_AO_2.json: that copy
# is only written later by output_json(), and it differs from the original only in
# seeingiqao.aperturepix, so its 'output' section is identical.
etc_input = load_json('./input_0607_p_AO.json', 'r')
out_put = []
for section, products in etc_input['output'].items():
    if section == 'psf':
        out_put.append('%s' % section)
    else:
        out_put.extend('%s_%s' % (section, product) for product in products)
out_put = [name.lower() for name in out_put]

In [None]:
# Stellar run: ETC for every aperture -> ascii tables -> txt tables -> per-row signal table.
simu_star = crires_simu(
    path_etc_local='./etc_cli.py',
    path_etc_input='./input_0607_AO.json',
    path_input_modi='./modi_input_0607_AO_2.json',
    path_etc_output='./0408/star/',
    target='star',
    aperture_size=35,
    date='0408',
    order_window=[
        [1921.318, 1961.128],
        [1989.978, 2031.165],
        [2063.711, 2106.392],
        [2143.087, 2187.386],
        [2228.786, 2274.835],
        [2321.596, 2369.534],
        [2422.415, 2472.388],
    ],
    order_number=[29, 28, 27, 26, 25, 24, 23],
)

simu_star.output_json()
simu_star.json2ascii(path_json2ascii='./etc_json2ascii.py')
simu_star.ascii2txt(output=out_put)

data_log = simu_star.signal(detectors=3, contribution='target', focal_plane=(2048, 2048 * 3))
data_log.to_csv('./0408/star/stellar_simu.csv')

In [None]:
# Render the stellar focal plane (z-scale interval, upper limit at the 99th percentile) to PNG.
simu_star.plot_signal(
    wave_count_sort=data_log,
    max_percentile=99,
    save_path='./0408/star/star_simu.png',
    interv=ZScaleInterval(),
)

In [None]:
# Planet run: same pipeline as the star, with the BT-Settl model spectrum passed to the ETC.
simu_planet = crires_simu(
    path_etc_local='./etc_cli.py',
    path_etc_input='./input_0607_p_AO.json',
    path_input_modi='./modi_input_0607_p_AO_2.json',
    path_etc_output='./0408/planet/',
    target='planet',
    aperture_size=35,
    date='0408',
    order_window=[
        [1921.318, 1961.128],
        [1989.978, 2031.165],
        [2063.711, 2106.392],
        [2143.087, 2187.386],
        [2228.786, 2274.835],
        [2321.596, 2369.534],
        [2422.415, 2472.388],
    ],
    order_number=[29, 28, 27, 26, 25, 24, 23],
    model_name='./models_1683734698/bt-settl_m/lte013-3.5-0.0.BT-Settl.7.dat.txt.dat',
)

simu_planet.output_json()
simu_planet.json2ascii(path_json2ascii='./etc_json2ascii.py')
simu_planet.ascii2txt(output=out_put)

data_log_p = simu_planet.signal(detectors=3, focal_plane=(2048, 2048 * 3), contribution='target')
data_log_p.to_csv('./0408/planet/planet_simu.csv')

In [None]:
# Render the planet-only focal plane to PNG.
simu_planet.plot_signal(
    wave_count_sort=data_log_p,
    max_percentile=99,
    save_path='./0408/planet/planet_simu.png',
    interv=ZScaleInterval(),
)

In [None]:
# Sky-only contribution, taken from the planet ETC run, on the same 2048 x 6144 plane.
# The ETC steps only create ascii_<n>/ under ./0408/star/ and ./0408/planet/, so make the sky dir here.
os.makedirs('./0408/sky_noi/', exist_ok=True)

data_log_k = simu_planet.signal(detectors=3, focal_plane=(2048, 2048 * 3), contribution='sky')

data_log_k.to_csv('./0408/sky_noi/sky_simu.csv')

simu_planet.plot_signal(data_log_k, max_percentile=99.99, save_path='./0408/sky_noi/sky_simu.png', interv=ZScaleInterval(), Nor=None)

In [None]:
# Fresh instance, so planet_p starts unset and combine() takes the planet offset from `d`.
simu_planet = crires_simu(
    path_etc_local='./etc_cli.py',
    path_etc_input='./input_0607_p_AO.json',
    path_input_modi='./modi_input_0607_p_AO_2.json',
    path_etc_output='./0408/planet/',
    target='planet',
    aperture_size=35,
    date='0408',
    order_window=[
        [1921.318, 1961.128],
        [1989.978, 2031.165],
        [2063.711, 2106.392],
        [2143.087, 2187.386],
        [2228.786, 2274.835],
        [2321.596, 2369.534],
        [2422.415, 2472.388],
    ],
    order_number=[29, 28, 27, 26, 25, 24, 23],
    model_name='./models_1683734698/bt-settl_m/lte013-3.5-0.0.BT-Settl.7.dat.txt.dat',
)

# Detector noise parameters
read = 6              # read noise, e-/pix
dark = 0.003          # dark current, e-/pix/s
dit = 120             # presumably the integration time in s (dark is scaled by it below)
ndit = 18             # NOTE(review): not used below
nspec = 1
nspat = 119
npix = nspec * nspat  # NOTE(review): not used below
ndark = dit * dark    # dark counts, e-/pix

# Planet offset from the star in pixel rows: 0.319 / 0.059, rounded up
# (presumably an arcsec separation over an arcsec-per-pixel plate scale).
po = int(np.ceil(0.319 / 0.059))

# NOTE(review): read**2 (a variance) is passed as `ron` and added directly to the counts.
tot_data_com, fig = simu_planet.combine(
    data_star=data_log,
    data_sky=data_log_k,
    data_p=data_log_p,
    ron=read ** 2,
    dark=ndark,
    d=po,
    focal_plane=(2048, 2048 * 3),
    plot_combination=True,
)
fig.savefig('./0408/combine.png')

In [None]:
# Persist the per-order combined focal planes (shape: orders x rows x columns).
np.save(file='./0408/combine_focal_plane.npy', arr=tot_data_com)

In [None]:
# Reload the saved dispersed tables and compute the planet-to-star flux ratios.
data_log_p = pd.read_csv('./0408/planet/planet_simu.csv')
data_log_s = pd.read_csv('./0408/star/stellar_simu.csv')

# Planet offset in pixel rows, computed the same way as in the combine cell.
po = int(np.ceil(0.319 / 0.059))

p = planet_simu(disper_planet=data_log_p, disper_star=data_log_s)
p2s_in, p2s_pl, plot_in, plot_pl = p.p_2_s(
    order_num=[29, 28, 27, 26, 25, 24, 23],
    planet_posi=po,
)