In [None]:
# Setup cell: configure CRDS, then import the JWST stage-1 pipeline and the
# custom exotic_jedi stage-1 steps.
import os
# NOTE(review): numpy / matplotlib are not used in this cell; presumably they
# are used by later notebook cells — confirm before removing.
import numpy as np
import matplotlib.pyplot as plt

# CRDS server URL and local reference-file cache location.
# These are assigned before the jwst / crds imports below; presumably CRDS
# picks them up at import / first use, so keep this ordering.
# The cache path is a placeholder — point it at a real writable directory.
os.environ['CRDS_SERVER_URL'] = 'https://jwst-crds.stsci.edu'
os.environ['CRDS_PATH'] = '/path/to/crds_cache/jwst_ops'

from jwst import datamodels
from jwst.pipeline import calwebb_detector1
from crds import get_default_context
# Report which CRDS context this run resolves to, for reproducibility.
print('Using crds context={}'.format(get_default_context()))

# Custom stage-1 steps used alongside the STScI steps in the reduction below.
from exotic_jedi.stage_1 import CustomBiasStep, DestripingGroupsStep


# Data and version config.
# Segment basenames of the uncalibrated exposure; each is read from
# <data_dir>/raw/<name>_uncal.fits and written to
# <stage_1_dir>/<name>_stage_1.fits.
data_chunk_names = ['jw01366003001_04101_00001-seg001_nrs1',
                    'jw01366003001_04101_00001-seg002_nrs1',
                    'jw01366003001_04101_00001-seg003_nrs1']
data_dir = '/path/to/data/jwst/nirspec/g395h/wasp39/visit_01366'
reduction_dir = '/path/to/reduction/jwst/nirspec/g395h/wasp39/visit_01366'
version_dir = os.path.join(reduction_dir, 'reduction_v1')
stage_1_dir = os.path.join(version_dir, 'stage_1')
# makedirs creates any missing parent directories (reduction_dir and
# version_dir included) and exist_ok=True makes re-running the notebook safe
# without a racy exists-then-mkdir check.
os.makedirs(stage_1_dir, exist_ok=True)

# Instantiate STScI steps for NIRSpec stage 1.
stsci_group_scale = calwebb_detector1.group_scale_step.GroupScaleStep()
stsci_dq_init = calwebb_detector1.dq_init_step.DQInitStep()
stsci_saturation = calwebb_detector1.saturation_step.SaturationStep()
# Kept available for comparison runs; not called in the loop below because
# CustomBiasStep is used in its place.
stsci_superbias = calwebb_detector1.superbias_step.SuperBiasStep()
stsci_refpix = calwebb_detector1.refpix_step.RefPixStep()
stsci_linearity = calwebb_detector1.linearity_step.LinearityStep()
stsci_dark_current = calwebb_detector1.dark_current_step.DarkCurrentStep()
stsci_jump = calwebb_detector1.jump_step.JumpStep()
stsci_ramp_fit = calwebb_detector1.ramp_fit_step.RampFitStep()
stsci_gain_scale = calwebb_detector1.gain_scale_step.GainScaleStep()

# Instantiate Custom steps for NIRSpec stage 1.
custom_bias = CustomBiasStep()
custom_destriping_groups = DestripingGroupsStep()

# Iterate data chunks.
for data_chunk_name in data_chunk_names:
    print('\n========= Working on {} =========\n'.format(data_chunk_name))

    # Read in chunk. The context manager closes the raw model once this
    # chunk's stage-1 product has been saved, instead of leaving every
    # segment's model open until garbage collection.
    raw_data_chunk = os.path.join(
        data_dir, 'raw', '{}_uncal.fits'.format(data_chunk_name))
    with datamodels.RampModel(raw_data_chunk) as dm_raw:

        # Stage 1 reduction.
        proc = stsci_group_scale.call(dm_raw)
        proc = stsci_dq_init.call(proc)
        proc = stsci_saturation.call(proc, n_pix_grow_sat=1)
        # The STScI SuperBiasStep (stsci_superbias) is deliberately bypassed;
        # the custom bias step below is applied in its place.
        proc = custom_bias.call(proc)
        proc = stsci_refpix.call(proc, odd_even_columns=True)
        proc = stsci_linearity.call(proc)
        proc = stsci_dark_current.call(proc)
        proc = stsci_jump.call(
            proc, rejection_threshold=10.,
            flag_4_neighbors=True, min_jump_to_flag_neighbors=10.,
            skip=False)
        # Group-level destriping. Presumably the column range brackets the
        # spectral trace and dq_bits lists the DQ bit positions excluded from
        # the background fit — confirm against exotic_jedi's docs.
        proc = custom_destriping_groups.call(
            proc,
            start_trace_col=606, end_trace_col=2042,
            poly_order=2, n_sigma_trace_mask=15.,
            dq_bits=[0, 1, 2, 10, 11, 13, 19],
            keep_mean_bkd_level=False,
            draw_mask=False)
        # RampFitStep.call returns two models; the second is kept. Per the
        # jwst docs this is the per-integration ("rateints") product used for
        # time-series work — verify against the installed jwst version.
        _, proc = stsci_ramp_fit.call(proc)
        stage_1_output = stsci_gain_scale.call(proc)
        stage_1_output.save(path=os.path.join(
            stage_1_dir, '{}_stage_1.fits'.format(data_chunk_name)))
