# In [None]:
import os
import numpy as np
import xarray as xr
from functools import partial
import matplotlib.pyplot as plt

# Configure the CRDS server URL and the local CRDS cache directory through
# environment variables. They are set here, ahead of the jwst/crds imports
# that follow; presumably those packages read them at import or first use
# -- TODO confirm. The cache path is a placeholder to be replaced.
os.environ['CRDS_SERVER_URL'] = 'https://jwst-crds.stsci.edu'
os.environ['CRDS_PATH'] = '/path/to/crds_cache/jwst_ops'

from jwst import datamodels
from jwst.pipeline import calwebb_spec2
import jwst.assign_wcs.nirspec
from crds import get_default_context
print('Using crds context={}'.format(get_default_context()))

from ahoy.stage_2 import GainStep, ReadNoiseStep, FlatFieldStep, \
    WavelengthMapStep, IntegrationTimesStep, StitchChunksStep, \
    InspectDQFlagsStep, CleanOutliersStep, DestripingRateimagesStep, \
    Extract1DBoxStep, Extract1DOptimalStep, AlignSpectraStep


# Data and version config.
# data_name is the base name of the exposure; data_chunk_names lists its
# segments. Stage 1 products are read from stage_1_dir as
# '<chunk name>_stage_1.fits' and stage 2 outputs are written to stage_2_dir.
data_name = 'jw01366003001_04101_00001-nrs2'
data_chunk_names = ['jw01366003001_04101_00001-seg001_nrs2',
                    'jw01366003001_04101_00001-seg002_nrs2',
                    'jw01366003001_04101_00001-seg003_nrs2']
reduction_dir = '/path/to/reduction/jwst/nirspec/g395h/wasp39/visit_01366'
version_dir = os.path.join(reduction_dir, 'reduction_v1')
stage_1_dir = os.path.join(version_dir, 'stage_1')
stage_2_dir = os.path.join(version_dir, 'stage_2')

# Create the output directories. makedirs(exist_ok=True) replaces the previous
# exists()/mkdir() pair: it is safe to re-run, has no check-then-create race,
# and also creates any missing parent directories instead of raising.
for _dir in [stage_2_dir, ]:
    os.makedirs(_dir, exist_ok=True)

# Instantiate STScI steps for NIRSpec stage 2.
# These are the pipeline steps taken from jwst.pipeline.calwebb_spec2; they
# are run in part 1 below, in the order assign_wcs -> extract_2d -> srctype
# -> wavecorr.
stsci_assign_wcs = calwebb_spec2.assign_wcs_step.AssignWcsStep()
stsci_extract_2d = calwebb_spec2.extract_2d_step.Extract2dStep()
stsci_srctype = calwebb_spec2.srctype_step.SourceTypeStep()
stsci_wavecorr = calwebb_spec2.wavecorr_step.WavecorrStep()

# Mod extract_2d trimming by updating wcs_step slit info.
# The module attribute nrs_wcs_set_input is replaced by a functools.partial of
# itself with wavelength_range, slit_y_low and slit_y_high pre-bound, so any
# caller that looks the function up through the module afterwards gets these
# values as keyword defaults. Presumably extract_2d resolves it this way and
# wavelength_range is in metres -- TODO confirm against the installed jwst
# version. NOTE: re-running this statement in the same session wraps the
# already-wrapped function again (same keyword values each time).
jwst.assign_wcs.nirspec.nrs_wcs_set_input = partial(
    jwst.assign_wcs.nirspec.nrs_wcs_set_input,
    wavelength_range=[2.3e-06, 5.3e-06], slit_y_low=-1, slit_y_high=50)

# Instantiate Custom steps for NIRSpec stage 2.
# All of these classes are imported from ahoy.stage_2. Each instance is used
# through its .call(...) method in the reduction parts below, except
# extract_1d_optimal, which is only referenced in commented-out code.
custom_gain = GainStep()
custom_readnoise = ReadNoiseStep()
custom_flat = FlatFieldStep()
custom_wavelength_map = WavelengthMapStep()
custom_integration_times = IntegrationTimesStep()
stitch_chunks = StitchChunksStep()
inspect_dq_flags = InspectDQFlagsStep()
clean_outliers = CleanOutliersStep()
destripe_rateimages = DestripingRateimagesStep()
extract_1d_box = Extract1DBoxStep()
extract_1d_optimal = Extract1DOptimalStep()
align_spectra = AlignSpectraStep()

# Read in any chunk.
# The first chunk's stage 1 product ('<chunk name>_stage_1.fits') is opened as
# a CubeModel and used as the input for the auxiliary products below.
stage_1_any_data_chunk = os.path.join(
    stage_1_dir, '{}_stage_1.fits'.format(data_chunk_names[0]))
dm_any_stage_1 = datamodels.CubeModel(stage_1_any_data_chunk)

# Stage 2 reduction, part 1: auxiliary data.
# Run the four STScI steps in sequence on the single chunk; each step takes
# the previous step's output.
proc = stsci_assign_wcs.call(dm_any_stage_1)
proc = stsci_extract_2d.call(proc)
proc = stsci_srctype.call(proc)
proc = stsci_wavecorr.call(proc)
# Gain, read noise, flat and wavelength map are all derived from this one
# processed chunk, with the same column trimming (5, -5) passed to each.
gain = custom_gain.call(
    proc, data_base_name=data_name, data_chunk_name=data_chunk_names[0],
    stage_1_dir=stage_1_dir, stage_2_dir=stage_2_dir,
    trim_col_start=5, trim_col_end=-5, median_value=True)
# The read-noise step is given the gain computed just above.
readnoise = custom_readnoise.call(
    proc, gain_value=gain, data_base_name=data_name,
    data_chunk_name=data_chunk_names[0],
    stage_1_dir=stage_1_dir, stage_2_dir=stage_2_dir,
    trim_col_start=5, trim_col_end=-5, median_value=True)
# Called with apply=False and skip=True; `flat` is not referenced again in
# this script.
flat = custom_flat.call(
    proc, data_base_name=data_name, data_chunk_name=data_chunk_names[0],
    stage_2_dir=stage_2_dir, trim_col_start=5, trim_col_end=-5,
    apply=False, skip=True)
wavelength_map = custom_wavelength_map.call(
    proc, data_base_name=data_name, stage_2_dir=stage_2_dir,
    trim_col_start=5, trim_col_end=-5)
# Unlike the steps above, this one receives the full list of chunk names.
# It returns two values: the times used later as the time coordinate, and
# int_time_s, which multiplies the data in part 4 (presumably the integration
# duration in seconds -- TODO confirm).
integration_times, int_time_s = custom_integration_times.call(
    data_chunk_names, data_base_name=data_name,
    stage_1_dir=stage_1_dir, stage_2_dir=stage_2_dir)

# Stage 2 reduction, part 2: combining.
# `proc` is rebound here: the stitch step is given the chunk names and
# stage_1_dir rather than the single-chunk result from part 1, with the same
# column trimming (5, -5).
proc = stitch_chunks.call(
    data_chunk_names, stage_1_dir=stage_1_dir,
    trim_col_start=5, trim_col_end=-5)

# Stage 2 reduction, part 3: cleaning.
# DQ-flag inspection is invoked with skip=True and its return value is unused.
inspect_dq_flags.call(proc, draw_dq_flags=False, skip=True)
# First outlier-cleaning pass (before destriping); the second return value is
# discarded.
proc, _ = clean_outliers.call(
    proc, window_width=100, poly_order=4, outlier_threshold=4.0,
    draw_cleaning_grid=False, draw_cleaning_col=False)
proc = destripe_rateimages.call(
    proc,
    start_trace_col=5, end_trace_col=2042,
    poly_order=2, n_sigma_trace_mask=15.,
    draw_psf_fits=False, draw_trace_position=False, draw_mask=False)
# Second outlier-cleaning pass (after destriping) with identical settings.
# This time the second return value is kept as P; it is only referenced by
# the commented-out optimal extraction call.
proc, P = clean_outliers.call(
    proc, window_width=100, poly_order=4, outlier_threshold=4.0,
    draw_cleaning_grid=False, draw_cleaning_col=False)

# Stage 2 reduction, part 4: extraction.
proc.data = proc.data * int_time_s * gain  # Convert data from DN/s to electrons.
# Recompute uncertainties as sqrt(counts + readnoise^2). Technically this
# should also include the background counts.
proc.err = np.sqrt(proc.data + readnoise**2)
# Non-finite uncertainties (e.g. from the square root of a negative total)
# are replaced by zero. NOTE(review): a zero uncertainty can act as infinite
# weight in inverse-variance weighting -- confirm the downstream steps treat
# these pixels as intended.
proc.err[~np.isfinite(proc.err)] = 0.
# Box extraction of the 1D spectra; returns the wavelengths, the spectra and
# their uncertainties.
wv, spec, spec_unc = extract_1d_box.call(
    proc, wavelength_map,
    start_trace_col=5, end_trace_col=2042,
    poly_order=2, n_sigma_trace_mask=6.,
    draw_spectra=False)
# Alternative: optimal extraction, which additionally takes P (from the second
# outlier-cleaning pass) and the read noise. Disabled in favour of the box
# extraction above.
# wv, spec, spec_unc = extract_1d_optimal.call(
#     proc, wavelength_map, P, readnoise,
#     start_trace_col=5, end_trace_col=2042, poly_order=2,
#     median_spatial_profile=False,
#     spatial_profile_windows=[0, 270, 465],
#     draw_spectra=False)
# Align the spectra; spec and spec_unc are replaced by the step's outputs and
# the per-integration x/y shifts are returned alongside them.
spec, spec_unc, x_shifts, y_shifts = align_spectra.call(
    proc, spec, spec_unc,
    align_spectra=True,
    draw_cross_correlation_fits=False, draw_trace_positions=False)

# White light curve: collapse every spectrum along the wavelength axis,
# leaving out the first and last wavelength columns.
lc = spec[:, 1:-1].sum(axis=1)
# Matching per-integration uncertainty: the column uncertainties summed in
# quadrature over the same wavelength range.
lc_pp = np.sqrt(np.square(spec_unc[:, 1:-1]).sum(axis=1))

# Normalise both series by the median flux of integrations 15..134, the same
# window reported as "pre-transit" below.
norm = np.median(lc[15:135])
lc /= norm
lc_pp /= norm

# Report precisions scaled by 1e6: the median propagated uncertainty, then
# the scatter of the normalised light curve in the pre-transit (15..134) and
# post-transit (360..464) windows.
print('Photon precision={}'.format(np.median(lc_pp) * 1e6))
print('Pre-transit wlc precision={}'.format(np.std(lc[15:135]) * 1e6))
print('Post-transit wlc precision={}'.format(np.std(lc[360:465]) * 1e6))
# fig, ax1 = plt.subplots(1, 1, figsize=(10, 5))
# ax1.scatter(integration_times[:len(lc)], lc, c='#bc5090', s=5, alpha=0.95)
# ax1.set_ylabel('Relative flux')
# ax1.set_xlabel('Time / BJD')
# plt.tight_layout()
# plt.show()

# Make standardised ERS xarray.
# Shifts are stored relative to their own medians, and every flux sample is
# flagged True in quality_flag.
x_shift = x_shifts - np.median(x_shifts)
y_shift = y_shifts - np.median(y_shifts)
quality_flag = np.full(np.shape(spec), True, dtype=bool)

# Flux-like variables are indexed by (time_flux, wavelength); the shifts are
# indexed by time_flux only.
ds = xr.Dataset(
    data_vars={
        "flux": (
            ["time_flux", "wavelength"], spec, {"units": "electron"}),
        "flux_error": (
            ["time_flux", "wavelength"], spec_unc, {"units": "electron"}),
        "quality_flag": (
            ["time_flux", "wavelength"], quality_flag, {"units": ""}),
        "x_shift": (["time_flux"], x_shift, {"units": ""}),
        "y_shift": (["time_flux"], y_shift, {"units": ""}),
    },
    coords={
        "wavelength": (["wavelength"], wv, {"units": "micron"}),
        "time_flux": (["time_flux"], integration_times, {"units": "bjd"}),
    },
    attrs={
        "author": "D Grant",
        "contact": "email",
        "code": "code",
        "notes": "notes",
        "normalised": "No",
        "doi": "none",
    },
)

# Write the dataset as netCDF into the stage 2 output directory.
# NOTE(review): the file name carries "v2" while version_dir is
# 'reduction_v1' -- confirm which version label is intended.
res_path = os.path.join(
    stage_2_dir, "stellar-spec-W39-G395H-NRS2-Grant-v2.nc")
ds.to_netcdf(res_path)

