## Imports from standard packages

In [4]:
import numpy as np
from scipy.interpolate import interp1d
from astropy.io import fits

## Local imports

In [9]:
from extract.overlap import TrpzOverlap

# Simulation inputs

In [105]:
# Fudge factor applied to the model flux so that the simulated detector
# images contain a realistic number of counts.
scale_flux = 1e-5

# When False, the wavelength maps are flattened (tilt removed) before simulating.
tilt = True

# Main keyword arguments for the simulation object (TrpzOverlap)
overlap_kwargs = dict(n_os=15, thresh=1e-8)

# Keyword arguments for the convolution kernels
c_kwargs = dict(thresh=1e-6)

# Output file: the name encodes the PHOENIX model temperature and the flux scale
output_file = "Simulations/phoenix_teff_02300_scale_{:.1e}.fits".format(scale_flux)

### Use interpolated PHOENIX spectrum

In [106]:
import os

# Directory holding the interpolated PHOENIX models. Defaults to the original
# author's machine; set the SOSS_MODEL_DIR environment variable to use another
# location without editing the notebook.
path = os.environ.get("SOSS_MODEL_DIR", "/Users/antoinedb/Documents/Doctorat/SOSS/")
# PHOENIX-ACES model file. Per its name: Teff = 2300 K (matches the output file
# name), presumably resampled with n_os = 15 -- confirm.
model_file = "Z-0.0-lte02300-0.00-0.0.PHOENIX-ACES-AGSS-COND-2011-n_os-15.npz"
# np.load on an .npz returns an NpzFile. Each key access ("wave", "flux", ...)
# reads a fresh array, so later in-place scaling of "flux" does not alter the file contents.
spec_file = np.load(os.path.join(path, model_file))

In [107]:
# Model wavelength grid and flux from the PHOENIX file
wv, flux = (spec_file[key] for key in ("wave", "flux"))

# Scale the model flux in place by a fudge factor, so the simulated
# detector images end up with a realistic number of counts
flux *= scale_flux

# Cubic interpolator of the scaled model; wavelengths outside the model
# grid evaluate to 0 instead of raising
flux_interp = interp1d(wv, flux, kind="cubic",
                       fill_value=0., bounds_error=False)

# Simulations

In [108]:
# Reference files: wavelength maps and spatial profiles of orders 1 and 2
ref_dir = "../jwst-mtl/SOSS/extract/Ref_files/"


def _read_ref_map(filename, squeeze=False):
    """Read the primary-HDU image of ``ref_dir + filename`` as a float64 array.

    The FITS file is closed before returning. ``astype`` always makes an
    in-memory copy, so the returned array does not depend on the open file.
    """
    with fits.open(ref_dir + filename) as hdul:
        data = hdul[0].data
        if squeeze:
            # Drop length-1 axes (done for the spatial profiles only)
            data = data.squeeze()
        # Convert to Python float (float64); original note: fits precision is 1e-8
        return data.astype(float)


wv_1 = _read_ref_map("wavelengths_m1.fits")
wv_2 = _read_ref_map("wavelengths_m2.fits")
P1 = _read_ref_map("spat_profile_m1.fits", squeeze=True)
P2 = _read_ref_map("spat_profile_m2.fits", squeeze=True)

if not tilt:
    # Remove the tilt from the wavelength maps: every row becomes a copy of
    # row `ref_row`. The row count comes from the map itself rather than a
    # hard-coded 256.
    ref_row = 50
    wv_1 = np.tile(wv_1[ref_row, :], (wv_1.shape[0], 1))
    wv_2 = np.tile(wv_2[ref_row, :], (wv_2.shape[0], 1))


#### Initiate a simulation ####

simu = TrpzOverlap([P1, P2], [wv_1, wv_2], c_kwargs=c_kwargs, **overlap_kwargs)

### Inject spectrum

# Evaluate the interpolated model on the simulation's wavelength grid
flux = flux_interp(simu.lam_grid)

# Init outputs: one dict per order, plus one for both orders together
out_ord = [{} for _ in range(simu.n_ord)]
out_full = {}

# Inject order 1 and 2 separately (we don't want any contamination here).
# NOTE(review): `orders` receives the order *index* (0, 1, ...), while
# simu.orders holds the order labels used in output names -- confirm that
# rebuild expects indices.
for i_ord in range(simu.n_ord):
    out_ord[i_ord]["data"] = simu.rebuild(flux, orders=[i_ord])

# Inject both orders (full)
out_full["data"] = simu.rebuild(flux)

## Add noise

In [109]:
from simulation_utils import add_noise

# Add a noisy version of every simulated image. Each single-order image comes
# first, then the full (both orders) image. This keeps the original call order,
# which matters if add_noise draws from a shared random state.
for out_dict in [*out_ord, out_full]:
    out_dict["noisy"] = add_noise(out_dict["data"])

# Save the full simulation and each order separately

In [110]:
import os


def _flux_table(lam, f_lam, name):
    """Return a binary-table HDU ``name`` with float64 columns lam_grid and f_lam."""
    columns = [fits.Column(name="lam_grid", array=lam, format="D"),
               fits.Column(name="f_lam", array=f_lam, format="D")]
    return fits.BinTableHDU.from_columns(columns, name=name)


# Primary header: record the simulation parameters
hdr = fits.Header()
for key, value in overlap_kwargs.items():
    hdr[key.upper()] = value
# Convolution parameters get a "C_" prefix (both dicts define "thresh")
for key, value in c_kwargs.items():
    hdr["C_" + key.upper()] = value
hdr["TILTED"] = tilt

hdul = fits.HDUList([fits.PrimaryHDU(header=hdr)])

# Injected flux on the simulation wavelength grid
hdul.append(_flux_table(simu.lam_grid, flux, "FLUX"))

# Injected flux convolved with each order's kernel, on that order's grid
for i_ord in range(simu.n_ord):
    hdul.append(_flux_table(simu.lam_grid_c(i_ord),
                            simu.c_list[i_ord].dot(flux),
                            f"FLUX_C{simu.orders[i_ord]}"))

# Detector simulations: both orders together, then each order alone
hdul.append(fits.ImageHDU(out_full["data"], name="FULL"))
hdul.append(fits.ImageHDU(out_full["noisy"], name="FULL NOISY"))
for i_ord, out in enumerate(out_ord):
    name = f"ORD {simu.orders[i_ord]}"
    hdul.append(fits.ImageHDU(out["data"], name=name))
    hdul.append(fits.ImageHDU(out["noisy"], name=name + " NOISY"))

# writeto does not create missing directories, so make sure the output
# directory (e.g. "Simulations/") exists first
out_dir = os.path.dirname(output_file)
if out_dir:
    os.makedirs(out_dir, exist_ok=True)
# overwrite is left at its default (False): an existing file is never clobbered
hdul.writeto(output_file)