In [None]:
import os
import numpy as np
import pandas as pd
import xarray as xr
from scipy import interpolate
import matplotlib.pyplot as plt
from exotic_ld import StellarLimbDarkening
from scipy.ndimage import gaussian_filter1d

from exotic_jedi.stage_3.leastsqs_light_curve_fits import fit_white_light_curve, \
    fit_spec_light_curve
from exotic_jedi.stage_3.map_light_curve_fits import min_white_light_curve, \
    min_spec_light_curve
from exotic_jedi.stage_3.mcmc_light_curve_fits import MCMCWhiteLightCurve, \
    MCMCSpecLightCurve


# Input and output locations.
# Directory passed to StellarLimbDarkening as ld_data_path.
ld_data_path = '/path/to/data/limb_darkening'

# Stellar-spectra files, one per detector, read from stage_2_dir.
stellar_specs = [
    'stellar-spec-W39-G395H-NRS1-Grant.nc',  # NRS1 stellar spec.
    'stellar-spec-W39-G395H-NRS2-Grant.nc',  # NRS2 stellar spec.
]

reduction_dir = '/path/to/reduction/jwst/nirspec/g395h/wasp39/visit_01366'
version_dir = os.path.join(reduction_dir, 'reduction_v1')
stage_2_dir, stage_3_dir = (
    os.path.join(version_dir, stage) for stage in ('stage_2', 'stage_3'))

# CSV describing the spectroscopic wavelength bins.
bin_scheme = os.path.join(stage_3_dir, 'w39b_ers_g395h_10pix_bins_v3.csv')

# Create the output directory when it is missing (its parent must exist).
for _dir in (stage_3_dir,):
    if os.path.exists(_dir):
        continue
    os.mkdir(_dir)

# Run configuration.
# Integrations before this index are dropped from every light curve.
start_cut = 15

# White light curve wavelength window per detector [um].
wlc_wv_bins = [
    (2.779, 3.717),  # NRS1.
    (3.824, 5.173),  # NRS2.
]

# Spectroscopic bins [um]: first three CSV columns. Columns 0 and 1 are used
# below as bin start/end, column 2 as the stored central wavelength.
spec_wv_bins = pd.read_csv(bin_scheme).values[:, :3]

# Integrations with index in [tilt_start, tilt_end) are excluded; data either
# side are normalised and de-trended separately.
tilt_start = 269
tilt_end = 272

outlier_threshold = 4.0  # Passed to the spectroscopic fitters.
err_inflation = 1.48  # Multiplies flux errors handed to every fitter.

# One of 'least-sqs', 'map', 'mcmc'; anything else raises ValueError below.
fitting_method = 'least-sqs'
n_cpus = 8  # Only used when fitting_method == 'mcmc'.
fitting_version = 'v2'  # Suffix of the output file names.

# Diagnostics and outputs.
check_cols = False  # Plot per-column light curves before fitting.
check_bad_cols = False  # With check_cols: only plot columns in bad_cols.
draw_fits = False  # Forwarded to the fitting routines.
draw_results = True  # Plot spectrum, noise and parameter summaries.
save_transmission_spectrum = True  # Write the netCDF result files.

# Bad columns mask.
# Detector column indexes that are masked out when columns are summed into
# light curves below; one array per entry of stellar_specs.
bad_cols = [
    np.array([789, 804, 805, 976, 1063, 1064, 1074, 1214, 1215,
              1250, 1308, 1395, 1416, 1442, 1446, 1447], dtype=int),
    np.array([62, 64, 164, 618, 772, 1092, 1096, 1517, 1918,
              1969, 1970, 1971, 1088], dtype=int)]
if check_cols:
    # Diagnostic: plot the light curve of individual detector columns, either
    # all of them or (check_bad_cols=True) only those listed in bad_cols.
    for s_idx, ss in enumerate(stellar_specs):
        spec_path = os.path.join(stage_2_dir, ss)
        # Context manager releases the file handle once the plots are done.
        with xr.open_dataset(spec_path) as ds_:
            col_times = ds_['time_flux'].values
            for col_idx, pixel_lc in enumerate(ds_['flux'].values.T):
                if check_bad_cols and col_idx not in bad_cols[s_idx]:
                    continue
                print('Pixel light curve NRS{} at column idx={}.'
                      .format(s_idx + 1, col_idx))
                fig, ax1 = plt.subplots(1, 1, figsize=(11, 6))
                ax1.scatter(col_times, pixel_lc, s=15, c='#000000')
                ax1.set_xlabel('Time / BJD')
                ax1.set_ylabel('Relative flux')
                plt.show()


if __name__ == "__main__":
    # Stage 3 driver. For each detector file in stellar_specs this builds and
    # fits a white light curve, then fits every spectroscopic bin of
    # spec_wv_bins that has pixels on that detector, passing the white-light
    # t0, a and inc results to the spectroscopic fitter. Results are
    # optionally plotted and written to three netCDF files in stage_3_dir.

    # Per-bin results, appended in the order the bins are fitted.
    mid_wvs = []
    wvs_widths = []
    tds = []
    tds_errs = []
    res_ps = []
    us = []
    all_popt = []
    photon_nf = []
    # Row of spec_wv_bins behind each fitted bin, so the saved wavelength
    # coordinates line up with the fitted light curves when bins are skipped.
    fitted_bin_idxs = []

    # Per-detector white light curve products.
    wlc_raw_fluxes = []
    wlc_raw_flux_errors = []
    wlc_light_curve_models = []
    wlc_corrected_fluxes = []
    wlc_corrected_flux_errors = []
    wlc_systematic_models = []
    wlc_residuals = []

    # Per-bin spectroscopic light curve products.
    raw_fluxes = []
    raw_flux_errors = []
    light_curve_models = []
    corrected_fluxes = []
    corrected_flux_errors = []
    systematic_models = []
    residuals = []

    for s_idx, ss in enumerate(stellar_specs):

        spec_path = os.path.join(stage_2_dir, ss)
        print("Stellar spectra ={}".format(spec_path))
        ds_ = xr.open_dataset(spec_path)

        # Flux and flux error, indexed below as [integration, column].
        all_flux = ds_['flux'].values
        all_flux_err = ds_['flux_error'].values

        # Get integration times.
        pre_tilt_times = ds_['time_flux'].values[start_cut:tilt_start]
        pst_tilt_times = ds_['time_flux'].values[tilt_end:]
        # Index of the first post-tilt point in the concatenated series.
        tilt_idx = pre_tilt_times.shape[0]

        # Get wavelength bin in terms of pixel indexes.
        wlc_wv_bin_start = wlc_wv_bins[s_idx][0]
        wlc_wv_bin_end = wlc_wv_bins[s_idx][1]
        wvs = ds_['wavelength'].values
        wlc_wv_bin_idxs = np.where(np.logical_and(
            wvs >= wlc_wv_bin_start, wvs < wlc_wv_bin_end))[0]
        print('\nWhite light curve wavelength bins {} --{} um'.format(
            wlc_wv_bin_start, wlc_wv_bin_end))
        print('Pixel bin={}pxs w/ cols {} -- {}'.format(
            wlc_wv_bin_idxs.shape[0], wlc_wv_bin_idxs[0], wlc_wv_bin_idxs[-1]))

        # Make bad column mask. Only bad columns inside the window are kept
        # before shifting to window-relative indexes; an out-of-window column
        # would otherwise give a negative index (masking the wrong column)
        # or an out-of-range one.
        bad_cols_win = bad_cols[s_idx][
            (wlc_wv_bin_idxs[0] <= bad_cols[s_idx])
            & (bad_cols[s_idx] <= wlc_wv_bin_idxs[-1])]
        bad_cols_win_shift = bad_cols_win - wlc_wv_bin_idxs[0]
        bad_cols_mask = np.ones(wlc_wv_bin_idxs.shape[0], dtype=bool)
        bad_cols_mask[bad_cols_win_shift] = False

        # Fluxes: sum the unmasked columns of the window per integration.
        pre_tilt_f = all_flux[
            start_cut:tilt_start, wlc_wv_bin_idxs[0]:wlc_wv_bin_idxs[-1] + 1]
        pst_tilt_f = all_flux[
            tilt_end:, wlc_wv_bin_idxs[0]:wlc_wv_bin_idxs[-1] + 1]
        pre_tilt_wlc = np.sum(pre_tilt_f[:, bad_cols_mask], axis=1)
        pst_tilt_wlc = np.sum(pst_tilt_f[:, bad_cols_mask], axis=1)

        # Flux uncertainties: column errors added in quadrature.
        pre_tilt_f_err = all_flux_err[
            start_cut:tilt_start, wlc_wv_bin_idxs[0]:wlc_wv_bin_idxs[-1] + 1]
        pst_tilt_f_err = all_flux_err[
            tilt_end:, wlc_wv_bin_idxs[0]:wlc_wv_bin_idxs[-1] + 1]
        pre_tilt_wlc_err = np.sqrt(np.sum(np.square(
            pre_tilt_f_err[:, bad_cols_mask]), axis=1))
        pst_tilt_wlc_err = np.sqrt(np.sum(np.square(
            pst_tilt_f_err[:, bad_cols_mask]), axis=1))

        # Normalise light curve pre- and post-tilt, using the median of the
        # first 135 pre-tilt and last 115 post-tilt points (presumably the
        # out-of-transit baselines -- confirm against the observation).
        pre_tilt_norm = np.median(pre_tilt_wlc[:135])
        pre_tilt_wlc /= pre_tilt_norm
        pre_tilt_wlc_err /= pre_tilt_norm
        pst_tilt_norm = np.median(pst_tilt_wlc[-115:])
        pst_tilt_wlc /= pst_tilt_norm
        pst_tilt_wlc_err /= pst_tilt_norm

        # Trace positions.
        pre_tilt_x_shits = ds_['x_shift'].values[start_cut:tilt_start]
        pst_tilt_x_shits = ds_['x_shift'].values[tilt_end:]
        pre_tilt_y_shits = ds_['y_shift'].values[start_cut:tilt_start]
        pst_tilt_y_shits = ds_['y_shift'].values[tilt_end:]

        # Standardise de-trending parameters (median-subtract, divide by the
        # standard deviation), separately either side of the tilt.
        pre_tilt_x_shits -= np.median(pre_tilt_x_shits)
        pre_tilt_x_shits /= np.std(pre_tilt_x_shits)
        pre_tilt_y_shits -= np.median(pre_tilt_y_shits)
        pre_tilt_y_shits /= np.std(pre_tilt_y_shits)
        pst_tilt_x_shits -= np.median(pst_tilt_x_shits)
        pst_tilt_x_shits /= np.std(pst_tilt_x_shits)
        pst_tilt_y_shits -= np.median(pst_tilt_y_shits)
        pst_tilt_y_shits /= np.std(pst_tilt_y_shits)

        # Combine data pre- and post-tilt.
        times = np.concatenate([pre_tilt_times, pst_tilt_times])
        wlc = np.concatenate([pre_tilt_wlc, pst_tilt_wlc])
        wlc_err = np.concatenate([pre_tilt_wlc_err, pst_tilt_wlc_err])

        # Custom throughput: time-median observed spectrum divided by the
        # model stellar spectrum interpolated onto the data wavelengths, then
        # smoothed with a Gaussian of sigma=100 samples.
        sld = StellarLimbDarkening(
            M_H=0.0, Teff=5512., logg=4.47, ld_model='3D',
            ld_data_path=ld_data_path)
        # The 1e-4 factor presumably converts Angstrom to micron -- confirm
        # against the exotic_ld version in use.
        ss_wavelengths = sld._stellar_wavelengths * 1e-4
        ss_intensity = np.sum(sld._stellar_fluxes, axis=0)
        interp_function = interpolate.interp1d(
            ss_wavelengths, ss_intensity, kind="linear")
        ss_intensity = interp_function(wvs)
        custom_throughput = gaussian_filter1d(
            np.median(all_flux, axis=0) / ss_intensity, 100)

        # Theoretical limb-darkening parameters over the full wavelength
        # range of this detector's spectra.
        u1, u2, u3, u4 = sld.compute_4_parameter_non_linear_ld_coeffs(
            wavelength_range=(np.min(wvs * 1.e4),
                              np.max(wvs * 1.e4)),
            mode='custom',
            custom_wavelengths=wvs * 1.e4,
            custom_throughput=custom_throughput)

        if fitting_method == 'least-sqs':
            # Fit white light curve transit: least-squares lm optimisation.
            t0_wlc_fit, a_wlc_fit, inc_wlc_fit, fit_dict = fit_white_light_curve(
                times, wlc, wlc_err * err_inflation, pre_tilt_x_shits,
                pre_tilt_y_shits, pst_tilt_x_shits, pst_tilt_y_shits,
                u1, u2, u3, u4, tilt_idx=tilt_idx, draw_fits=draw_fits)

        elif fitting_method == 'map':
            # Fit white light curve transit: map optimisation.
            t0_wlc_fit, a_wlc_fit, inc_wlc_fit, fit_dict = min_white_light_curve(
                times, wlc, wlc_err * err_inflation, pre_tilt_x_shits,
                pre_tilt_y_shits, pst_tilt_x_shits, pst_tilt_y_shits,
                u1, u2, u3, u4, tilt_idx=tilt_idx, draw_fits=draw_fits)

        elif fitting_method == 'mcmc':
            # Fit white light curve transit: bayesian sampling.
            mcmc_wlc = MCMCWhiteLightCurve(
                times, wlc, wlc_err * err_inflation, pre_tilt_x_shits,
                pre_tilt_y_shits, pst_tilt_x_shits, pst_tilt_y_shits,
                tilt_idx, u1, u2, u3, u4)
            mcmc_wlc.n_cpus = n_cpus
            t0_wlc_fit, a_wlc_fit, inc_wlc_fit, fit_dict = \
                mcmc_wlc.sample_white_light_curve(draw_fits=draw_fits)

        else:
            raise ValueError('Fitting method not recognised.')

        # For white light curve xarray.
        wlc_raw_fluxes.append(wlc)
        wlc_raw_flux_errors.append(wlc_err)
        wlc_light_curve_models.append(fit_dict['light_curve_model'])
        wlc_corrected_fluxes.append(fit_dict['corrected_flux'])
        wlc_corrected_flux_errors.append(fit_dict['corrected_flux_error'])
        wlc_systematic_models.append(fit_dict['systematic_model'])
        wlc_residuals.append(fit_dict['residual'])

        # Iterate spectral bins.
        for bin_idx, (lc_wv_bin_start, lc_wv_bin_end) in enumerate(
                zip(spec_wv_bins[:, 0], spec_wv_bins[:, 1])):

            # Get wavelength bin in terms of pixel indexes.
            lc_wv_bin_idxs = np.where(np.logical_and(
                wvs >= lc_wv_bin_start, wvs < lc_wv_bin_end))[0]
            if lc_wv_bin_idxs.shape[0] == 0:
                print('No pixels on NRS{} in wavelength range {} -- {}.'
                      .format(s_idx + 1, lc_wv_bin_start, lc_wv_bin_end))
                continue

            # Set wavelength bin. The limb-darkening range is centred on the
            # median pixel wavelength with the nominal bin width; the 1.e4
            # factor presumably converts micron to Angstrom.
            mid_wv = np.median(wvs[lc_wv_bin_idxs[0]:lc_wv_bin_idxs[-1] + 1])
            wv_width = lc_wv_bin_end - lc_wv_bin_start
            wv_left_ang = (mid_wv - wv_width/2) * 1.e4
            wv_right_ang = (mid_wv + wv_width/2) * 1.e4
            print('\nWavelength bin={} um'.format(np.round(mid_wv, 4)))
            print('Pixel bin={}pxs w/ cols {} -- {}'.format(
                lc_wv_bin_idxs.shape[0], lc_wv_bin_idxs[0], lc_wv_bin_idxs[-1]))

            # temp: investigate specific fits.
            # if abs(mid_wv - 3.921) > 0.005:
            #     continue

            # Make bad column mask.
            bad_cols_win = bad_cols[s_idx][
                (lc_wv_bin_idxs[0] <= bad_cols[s_idx])
                & (bad_cols[s_idx] <= lc_wv_bin_idxs[-1])]
            bad_cols_win_shift = bad_cols_win - lc_wv_bin_idxs[0]
            bad_cols_mask = np.ones(lc_wv_bin_idxs.shape[0], dtype=bool)
            bad_cols_mask[bad_cols_win_shift] = False
            if np.sum(~bad_cols_mask) == lc_wv_bin_idxs.shape[0]:
                print('No good columns in wavelength bin, skipping.')
                continue

            # Fluxes.
            pre_tilt_f = all_flux[
                start_cut:tilt_start, lc_wv_bin_idxs[0]:lc_wv_bin_idxs[-1] + 1]
            pst_tilt_f = all_flux[
                tilt_end:, lc_wv_bin_idxs[0]:lc_wv_bin_idxs[-1] + 1]
            pre_tilt_lc = np.sum(pre_tilt_f[:, bad_cols_mask], axis=1)
            pst_tilt_lc = np.sum(pst_tilt_f[:, bad_cols_mask], axis=1)

            # Flux uncertainties.
            pre_tilt_f_err = all_flux_err[
                start_cut:tilt_start, lc_wv_bin_idxs[0]:lc_wv_bin_idxs[-1] + 1]
            pst_tilt_f_err = all_flux_err[
                tilt_end:, lc_wv_bin_idxs[0]:lc_wv_bin_idxs[-1] + 1]
            pre_tilt_lc_err = np.sqrt(np.sum(np.square(
                pre_tilt_f_err[:, bad_cols_mask]), axis=1))
            pst_tilt_lc_err = np.sqrt(np.sum(np.square(
                pst_tilt_f_err[:, bad_cols_mask]), axis=1))

            # Normalise light curve pre- and post-tilt.
            pre_tilt_norm = np.median(pre_tilt_lc[:135])
            pre_tilt_lc /= pre_tilt_norm
            pre_tilt_lc_err /= pre_tilt_norm
            pst_tilt_norm = np.median(pst_tilt_lc[-115:])
            pst_tilt_lc /= pst_tilt_norm
            pst_tilt_lc_err /= pst_tilt_norm

            # Combine data pre- and post-tilt.
            lc = np.concatenate([pre_tilt_lc, pst_tilt_lc])
            lc_err = np.concatenate([pre_tilt_lc_err, pst_tilt_lc_err])

            # Theoretical limb-darkening parameters.
            u1, u2, u3, u4 = sld.compute_4_parameter_non_linear_ld_coeffs(
                wavelength_range=(wv_left_ang, wv_right_ang),
                mode='custom',
                custom_wavelengths=wvs * 1.e4,
                custom_throughput=custom_throughput)

            if fitting_method == 'least-sqs':
                # Fit spectroscopic light curve transit: least-squares lm optimisation.
                transit_depth, transit_depth_err, res_p, popt, fit_dict = \
                    fit_spec_light_curve(
                        times, lc, lc_err * err_inflation, pre_tilt_x_shits,
                        pre_tilt_y_shits, pst_tilt_x_shits, pst_tilt_y_shits,
                        u1, u2, u3, u4, t0_wlc_fit, a_wlc_fit, inc_wlc_fit,
                        tilt_idx=tilt_idx, outlier_threshold=outlier_threshold,
                        draw_fits=draw_fits)

            elif fitting_method == 'map':
                # Fit spectroscopic light curve transit: map optimisation.
                transit_depth, transit_depth_err, res_p, popt, fit_dict = \
                    min_spec_light_curve(
                        times, lc, lc_err * err_inflation, pre_tilt_x_shits,
                        pre_tilt_y_shits, pst_tilt_x_shits, pst_tilt_y_shits,
                        u1, u2, u3, u4, t0_wlc_fit, a_wlc_fit, inc_wlc_fit,
                        tilt_idx=tilt_idx, outlier_threshold=outlier_threshold,
                        draw_fits=draw_fits)

            elif fitting_method == 'mcmc':
                # Fit spectroscopic light curve transit: bayesian sampling.
                mcmc_lc = MCMCSpecLightCurve(
                    times, lc, lc_err * err_inflation, pre_tilt_x_shits,
                    pre_tilt_y_shits, pst_tilt_x_shits, pst_tilt_y_shits,
                    tilt_idx, u1, u2, u3, u4, t0_wlc_fit, a_wlc_fit, inc_wlc_fit)
                mcmc_lc.n_cpus = n_cpus
                transit_depth, transit_depth_err, res_p, popt, fit_dict = \
                    mcmc_lc.sample_spec_light_curve(draw_fits=draw_fits)

            else:
                raise ValueError('Fitting method not recognised.')

            fitted_bin_idxs.append(bin_idx)
            mid_wvs.append(mid_wv)
            wvs_widths.append(wv_width)
            tds.append(transit_depth)
            tds_errs.append(transit_depth_err)
            us.append([u1, u2, u3, u4])
            res_ps.append(res_p)
            all_popt.append(popt)

            # Compute photon limit for bin: median relative error of the
            # bin-summed flux over all integrations and columns (no masking).
            no_mask_flux = np.sum(all_flux[
                :, lc_wv_bin_idxs[0]:lc_wv_bin_idxs[-1] + 1], axis=1)
            no_mask_wlc_err_ = np.sqrt(np.sum(np.square(
                all_flux_err[:, lc_wv_bin_idxs[0]:
                                lc_wv_bin_idxs[-1] + 1]), axis=1)) \
                / no_mask_flux
            photon_nf.append(np.median(no_mask_wlc_err_))

            # For spec light curve xarray.
            raw_fluxes.append(lc)
            raw_flux_errors.append(lc_err)
            light_curve_models.append(fit_dict['light_curve_model'])
            corrected_fluxes.append(fit_dict['corrected_flux'])
            corrected_flux_errors.append(fit_dict['corrected_flux_error'])
            systematic_models.append(fit_dict['systematic_model'])
            residuals.append(fit_dict['residual'])

        # Everything needed from this file is now held in numpy arrays, so
        # release the file handle before moving to the next detector.
        ds_.close()

    mid_wvs = np.array(mid_wvs)
    wvs_widths = np.array(wvs_widths)
    tds = np.array(tds)
    tds_errs = np.array(tds_errs)
    us = np.array(us)
    res_ps = np.array(res_ps)
    all_popt = np.array(all_popt)
    photon_nf = np.array(photon_nf)

    wlc_raw_fluxes = np.array(wlc_raw_fluxes)
    wlc_raw_flux_errors = np.array(wlc_raw_flux_errors)
    wlc_light_curve_models = np.array(wlc_light_curve_models)
    wlc_corrected_fluxes = np.array(wlc_corrected_fluxes)
    wlc_corrected_flux_errors = np.array(wlc_corrected_flux_errors)
    wlc_systematic_models = np.array(wlc_systematic_models)
    wlc_residuals = np.array(wlc_residuals)

    raw_fluxes = np.array(raw_fluxes)
    raw_flux_errors = np.array(raw_flux_errors)
    light_curve_models = np.array(light_curve_models)
    corrected_fluxes = np.array(corrected_fluxes)
    corrected_flux_errors = np.array(corrected_flux_errors)
    systematic_models = np.array(systematic_models)
    residuals = np.array(residuals)

    if draw_results:
        # Transmission spectrum.
        fig, ax1 = plt.subplots(1, 1, figsize=(11, 6))
        ax1.errorbar(mid_wvs, tds * 1.e6, xerr=wvs_widths / 2, yerr=tds_errs * 1.e6,
                     fmt='.', color='#bc5090')
        ax1.set_xlabel('Wavelength / um')
        ax1.set_ylabel('Transit depth / ppm')
        ax1.set_ylim(19900, 23600)
        plt.tight_layout()
        plt.show()

        # Noise limit comparison: fitted residual scatter against 1x, 2x and
        # 3x the photon-noise estimate (scaled by 1.e6 to ppm).
        print('Median spec lc precision={} xphoton-noise'.format(
            np.median(res_ps / (photon_nf * 1.e6))))
        fig, ax1 = plt.subplots(1, 1, figsize=(11, 6))
        ax1.scatter(mid_wvs, res_ps, s=20, color='#58508d')
        ax1.plot(mid_wvs, photon_nf * 1.e6,
                 linestyle='--', color='#000000', alpha=0.4)
        ax1.plot(mid_wvs, photon_nf * 1.e6 * 2,
                 linestyle='--', color='#000000', alpha=0.4)
        ax1.plot(mid_wvs, photon_nf * 1.e6 * 3,
                 linestyle='--', color='#000000', alpha=0.4)
        ax1.text(mid_wvs[-1] + 0.04, photon_nf[-1] * 1.e6, 'photon noise')
        ax1.text(mid_wvs[-1] + 0.04, photon_nf[-1] * 2.e6, 'x2 photon noise')
        ax1.text(mid_wvs[-1] + 0.04, photon_nf[-1] * 3.e6, 'x3 photon noise')
        ax1.set_xlim(2.6, 5.6)
        ax1.set_xlabel('Wavelength / um')
        # Raw string: '\s' is not a valid escape in a normal string literal.
        ax1.set_ylabel(r'Spectroscopic light curve residuals $\sigma$ / ppm')
        plt.tight_layout()
        plt.show()

        # Limb-darkening parameters
        fig, ax1 = plt.subplots(1, 1, figsize=(11, 6))
        ax1.plot(mid_wvs, us[:, 0])
        ax1.plot(mid_wvs, us[:, 1])
        ax1.plot(mid_wvs, us[:, 2])
        ax1.plot(mid_wvs, us[:, 3])
        ax1.set_xlabel('Wavelength / um')
        ax1.set_ylabel('Limb-darkening coeffs')
        plt.tight_layout()
        plt.show()

        # All fitted parameters quick look.
        for po in range(all_popt.shape[1]):
            print('popt idx={}'.format(po))
            plt.plot(all_popt[:, po])
            plt.show()

    if save_transmission_spectrum:
        # Quantities shared by the light curve files.
        wlc_quality_flag = np.ones(wlc_raw_fluxes.shape).astype(bool)
        quality_flag = np.ones(raw_fluxes.shape).astype(bool)
        # NOTE(review): these are the standardised shifts left over from the
        # last detector iterated, and are stored in both light curve files --
        # confirm this is intended.
        shifts_x = np.concatenate([pre_tilt_x_shits, pst_tilt_x_shits])
        shifts_y = np.concatenate([pre_tilt_y_shits, pst_tilt_y_shits])
        # Bin definitions restricted to the bins that were actually fitted,
        # in fitting order, so they match the rows of raw_fluxes etc.
        fitted_bins = spec_wv_bins[np.array(fitted_bin_idxs, dtype=int)]

        # Make standardised ERS xarray for fitted white light curves.
        ds = xr.Dataset(
            coords=dict(
                central_wavelength=(["central_wavelength"], np.mean(np.array(wlc_wv_bins), axis=1), {'units': 'micron'}),
                start_wavelength=(["start_wavelength"], np.array(wlc_wv_bins)[:, 0], {'units': 'micron'}),
                end_wavelength=(["end_wavelength"], np.array(wlc_wv_bins)[:, 1], {'units': 'micron'}),
                time_flux=(["time_flux"], times, {'units': 'BJD TDB'}),
            ),
            data_vars=dict(
                raw_flux=(["central_wavelength", "time_flux"], wlc_raw_fluxes, {'units': ''}),
                raw_flux_error=(["central_wavelength", "time_flux"], wlc_raw_flux_errors, {'units': ''}),
                light_curve_model=(["central_wavelength", "time_flux"], wlc_light_curve_models, {'units': ''}),
                corrected_flux=(["central_wavelength", "time_flux"], wlc_corrected_fluxes, {'units': ''}),
                corrected_flux_error=(["central_wavelength", "time_flux"], wlc_corrected_flux_errors, {'units': ''}),
                systematic_model=(["central_wavelength", "time_flux"], wlc_systematic_models, {'units': ''}),
                residuals=(["central_wavelength", "time_flux"], wlc_residuals, {'units': ''}),
                quality_flag=(["central_wavelength", "time_flux"], wlc_quality_flag, {'units': ''}),
                shift_x=(["time_flux"], shifts_x, {'units': ''}),
                shift_y=(["time_flux"], shifts_y, {'units': ''}),
            ),
            attrs=dict(author="D Grant",
                       contact="email",
                       code="code",
                       data_origin="path",
                       notes="v2",
            )
        )

        res_path = os.path.join(
            stage_3_dir, "fitted-white-light-curve-W39-G395H-10pix-custom-Grant-{}-{}.nc"
            .format(fitting_method, fitting_version))
        ds.to_netcdf(res_path)

        # Make standardised ERS xarray for fitted light curves.
        ds = xr.Dataset(
            coords=dict(
                central_wavelength=(["central_wavelength"], fitted_bins[:, 2], {'units': 'micron'}),
                start_wavelength=(["start_wavelength"], fitted_bins[:, 0], {'units': 'micron'}),
                end_wavelength=(["end_wavelength"], fitted_bins[:, 1], {'units': 'micron'}),
                time_flux=(["time_flux"], times, {'units': 'BJD TDB'}),
            ),
            data_vars=dict(
                raw_flux=(["central_wavelength", "time_flux"], raw_fluxes, {'units': ''}),
                raw_flux_error=(["central_wavelength", "time_flux"], raw_flux_errors, {'units': ''}),
                light_curve_model=(["central_wavelength", "time_flux"], light_curve_models, {'units': ''}),
                corrected_flux=(["central_wavelength", "time_flux"], corrected_fluxes, {'units': ''}),
                corrected_flux_error=(["central_wavelength", "time_flux"], corrected_flux_errors, {'units': ''}),
                systematic_model=(["central_wavelength", "time_flux"], systematic_models, {'units': ''}),
                residuals=(["central_wavelength", "time_flux"], residuals, {'units': ''}),
                quality_flag=(["central_wavelength", "time_flux"], quality_flag, {'units': ''}),
                shift_x=(["time_flux"], shifts_x, {'units': ''}),
                shift_y=(["time_flux"], shifts_y, {'units': ''}),
            ),
            attrs=dict(author="D Grant",
                       contact="email",
                       code="code",
                       data_origin="path",
                       notes="v2",
            )
        )

        res_path = os.path.join(
            stage_3_dir, "fitted-binned-light-curve-W39-G395H-10pix-custom-Grant-{}-{}.nc"
            .format(fitting_method, fitting_version))
        ds.to_netcdf(res_path)

        # Make standardised ERS xarray for transmission spectrum.
        ds = xr.Dataset(
            data_vars=dict(
                transit_depth=(["central_wavelength"], tds, {'units': '(R_p/R_s)**2'}),
                transit_depth_error=(["central_wavelength"], tds_errs, {'units': '(R_p/R_s)**2'}),
            ),
            coords=dict(
                central_wavelength=(["central_wavelength"], mid_wvs, {'units': 'micron'}),
                bin_half_width=(["bin_half_width"], wvs_widths / 2., {'units': 'micron'})
            ),
            attrs=dict(author="D Grant",
                       contact="email",
                       code="code",
                       doi="none",
            )
        )

        res_path = os.path.join(
            stage_3_dir, "transit-spectrum-W39-G395H-10pix-custom-Grant-{}-{}.nc"
            .format(fitting_method, fitting_version))
        ds.to_netcdf(res_path)
