## Imports from standard packages

In [2]:
import numpy as np
from scipy.interpolate import interp1d
from astropy.io import fits

## Local imports

In [3]:
from extract.overlap import TrpzOverlap
from extract.throughput import ThroughputSOSS
from extract.convolution import WebbKer

# Simulations inputs

In [4]:
WebbKer.file_frame = "spectral_kernel_matrix_os_{}_width_{}pixels_cut.fits"
WebbKer.path = ''

In [5]:
scale_flux = 1e2

tilt = True

# Main kwargs for simulation
overlap_kwargs = {"n_os": 10,
                  "thresh": 1e-8}
# Convolution kwargs
# c_kwargs = {"thresh": 1e-6}
c_kwargs={'n_out':[5*10, 8*10], 'length':21*10+1}

# Output file
output_file = f"../Simulations/phoenix_teff_09000_scale_{scale_flux:.1e}_vsini_5_cutker.fits"

### Use interpolated PHOENIX spectrum

In [6]:
from simulation_utils import load_simu

In [7]:
# path = "/Users/antoinedb/Documents/Doctorat/SOSS/"
# model_file = "Z-0.0-lte09000-4.00-0.0.PHOENIX-ACES-AGSS-COND-2011-n_os-15.npz"
# spec_file = np.load(path+model_file)

spec_file = load_simu(f"../Simulations/phoenix_teff_09000_scale_{scale_flux:.1e}_vsini_5.fits")

In [8]:
spec_file

{'grid': array([0.55028933, 0.5503206 , 0.55035187, ..., 2.99896924, 2.99903714,
        2.99910505]),
 'f_k': array([4.72029320e+17, 4.71761943e+17, 4.71523937e+17, ...,
        0.00000000e+00, 0.00000000e+00, 0.00000000e+00]),
 'grid_c1': array([0.83615988, 0.83619105, 0.83622223, ..., 2.83632606, 2.83639357,
        2.83646109]),
 'f_c1': array([1.27513409e+17, 1.27493226e+17, 1.27473109e+17, ...,
        2.24926315e+15, 2.24904694e+15, 2.24883060e+15]),
 'grid_c2': array([0.5529112 , 0.55294192, 0.55297263, ..., 1.41168005, 1.4117111 ,
        1.41176256]),
 'f_c2': array([4.64368361e+17, 4.64324848e+17, 4.64286324e+17, ...,
        2.77874001e+16, 2.77850474e+16, 2.77812351e+16]),
 'data': array([[ 14.10512256,  -3.95967257,  10.52447484, ..., -14.13736113,
         -20.94488784,  15.26676589],
        [-17.11189848, -14.48966492,  -2.6145901 , ...,  -1.57250597,
           6.15776183,  -0.287649  ],
        [-23.88106705, -26.1715474 ,   0.77425164, ..., -12.09461067,
         -1

In [9]:
# wv, flux = spec_file["wave"], spec_file["flux"]
wv, flux = spec_file["grid"], spec_file["f_k"]
# # Multiplication by a fudge factor to get
# # a realistic number of counts on the detector
# flux *= scale_flux

flux_interp = interp1d(wv, flux, kind="cubic", bounds_error=False, fill_value=0.)

# Simulations

In [10]:
from extract.convolution import fwhm2sigma, gaussians

class GaussKer:

    def __init__(self, grid, res, bounds_error=False,
                 fill_value="extrapolate", **kwargs):
        """
        Parameters
        ----------
        grid : 1d array
            Grid used to define the kernels
        fwhm: ...
            ...
        bounds_error, fill_value and kwargs:
            `interp1d` kwargs used to get FWHM as a function of the grid.
        """
        fwhm = grid / res
        
        # What we really want is sigma, not FWHM
        sig = fwhm2sigma(fwhm)

        # Now put sigma as a function of the grid
        sig = interp1d(grid, sig, bounds_error=bounds_error,
                       fill_value=fill_value, **kwargs)

        self.fct_sig = sig

    def __call__(self, x, x0):
        """
        Parameters
        ----------
        x: 1d array
            position where the kernel is evaluated
        x0: 1d array (same shape as x)
            position of the kernel center for each x.

        Returns
        -------
        Value of the gaussian kernel for each sets of (x, x0)
        """

        # Get the sigma of each gaussians
        sig = self.fct_sig(x0)

        return gaussians(x, x0, sig)

In [11]:
# List of orders to consider in the extraction
order_list = [1, 2]

path = "../jwst-mtl/SOSS/extract/Ref_files/"

#### Wavelength solution ####
wave_maps = []
wave_maps.append(fits.getdata(path + "wavelengths_m1.fits"))
wave_maps.append(fits.getdata(path + "wavelengths_m2.fits"))

if not tilt:
    # Remove the tilt from wv maps
    wave_maps[0] = np.tile(wave_maps[0][50,:], (256, 1))
    wave_maps[1] = np.tile(wave_maps[1][50,:], (256, 1))

#### Spatial profiles ####
spat_pros = []
spat_pros.append(fits.getdata(path + "spat_profile_m1.fits").squeeze())
spat_pros.append(fits.getdata(path + "spat_profile_m2.fits").squeeze())

# Convert data from fits files to float (fits precision is 1e-8)
wave_maps = [wv.astype('float64') for wv in wave_maps]
spat_pros = [p_ord.astype('float64') for p_ord in spat_pros]

#### Throughputs ####
thrpt_list = [ThroughputSOSS(order) for order in order_list]

#### Convolution kernels ####
ker_list = [WebbKer(wv_map) for wv_map in wave_maps]
# ker_list = [GaussKer(wv, res) for res in [2000, 900]]

# Put all inputs from reference files in a list
ref_files_args = [spat_pros, wave_maps, thrpt_list, ker_list]

In [12]:
# # Read relevant files
# wv_1 = fits.open("../jwst-mtl/SOSS/extract/Ref_files/wavelengths_m1.fits")[0].data
# wv_2 = fits.open("../jwst-mtl/SOSS/extract/Ref_files/wavelengths_m2.fits")[0].data
# P1 = fits.open("../jwst-mtl/SOSS/extract/Ref_files/spat_profile_m1.fits")[0].data.squeeze()
# P2 = fits.open("../jwst-mtl/SOSS/extract/Ref_files/spat_profile_m2.fits")[0].data.squeeze()

# # Convert to float (fits precision is 1e-8)
# wv_1 = wv_1.astype(float)
# wv_2 = wv_2.astype(float)
# P1 = P1.astype(float)
# P2 = P2.astype(float)

# if not tilt:
#     # Remove the tilt from wv maps
#     wv_1 = np.tile(wv_1[50,:], (256, 1))
#     wv_2 = np.tile(wv_2[50,:], (256, 1))
    

#### Initiate a simulation ####
simu = TrpzOverlap(*ref_files_args, c_kwargs=c_kwargs, **overlap_kwargs)
# simu = TrpzOverlap([P1,P2], [wv_1,wv_2], c_kwargs=c_kwargs, **overlap_kwargs)

### Inject spectrum

# Generate flux to inject
flux = flux_interp(simu.lam_grid)

# Init outputs
out_ord = [{} for i in range(simu.n_ord)]
out_full = {}

# Inject order 1 and 2 separately (we don't want any contamination here)
for i_ord in range(simu.n_ord):
    out_ord[i_ord]["data"] = simu.rebuild(flux, orders=[i_ord])

# Inject both orders (full)
out_full["data"] = simu.rebuild(flux)

## Add noise

In [13]:
from simulation_utils import add_noise

for out_dict in out_ord:
    out_dict["noisy"] = add_noise(out_dict["data"])

out_full["noisy"] = add_noise(out_full["data"])

# Save full simu and for each orders separately

In [14]:
hdr = fits.Header()
for key in overlap_kwargs:
    hdr[key.upper()] = overlap_kwargs[key]
    
for key in c_kwargs:
    hkey = "C_" + key.upper()
    hdr[hkey] = str(c_kwargs[key])

hdr["TILTED"] = tilt

# Save headers
primary_hdu = fits.PrimaryHDU(header=hdr)
hdul = fits.HDUList([primary_hdu])

# Save flux
col_list = []
col_list.append(fits.Column(name="lam_grid", array=simu.lam_grid, format="D"))
col_list.append(fits.Column(name='f_lam', array=flux, format='D'))
table_hdu = fits.BinTableHDU.from_columns(col_list, name='FLUX')
hdul.append(table_hdu)

for i_ord, out in enumerate(out_ord):
    name = f"FLUX_C{simu.orders[i_ord]}"
    x = simu.lam_grid_c(i_ord)
    y = simu.c_list[i_ord].dot(flux)
    col_list = []
    col_list.append(fits.Column(name='lam_grid', array=x, format='D'))
    col_list.append(fits.Column(name='f_lam', array=y, format='D'))
    table_hdu = fits.BinTableHDU.from_columns(col_list, name=name)
    hdul.append(table_hdu)

# Save detector simu
hdul.append(fits.ImageHDU(out_full["data"], name="FULL"))
hdul.append(fits.ImageHDU(out_full["noisy"], name="FULL NOISY"))
for i_ord, out in enumerate(out_ord):
    name = f"ORD {simu.orders[i_ord]}"
    hdul.append(fits.ImageHDU(out["data"], name=name))
    hdul.append(fits.ImageHDU(out["noisy"], name=name + " NOISY"))
    
# Write to file
hdul.writeto(output_file)