In [0]:
# Correction model fitted together with the templates:
#   'fluxcorr' -> FluxCorr with a PolynomialFluxCorrection, fit on the PHOENIX grid
#   'contnorm' -> ContNorm with a Spline continuum model, fit on gridie / grid7
CORRECTION_MODEL = 'fluxcorr'
# CORRECTION_MODEL = 'contnorm'
# Passed to the correction model as its *_per_arm / *_per_exp settings
CORRECTION_PER_ARM = True
CORRECTION_PER_EXP = False

# Root of the pfsspec data tree; every data path below is built from it so the
# notebook can be pointed at another installation by changing a single line.
DATA_ROOT = '/datascope/subaru/data/pfsspec'

# Stellar grid paths
GRID_PATH = {
    'phoenix': DATA_ROOT + '/models/stellar/grid/phoenix/phoenix_HiRes',
    'grid7': DATA_ROOT + '/models/stellar/grid/roman/grid7',
    'gridie': DATA_ROOT + '/models/stellar/grid/roman/gridie',
}

# Broadband filter to normalize spectra to
FILTER_PATH = DATA_ROOT + '/subaru/hsc/filters/fHSC-g.txt'

# Arms to simulate and fit
ARMS = [ 'b', 'mr' ]

# Grids used for simulation
SIM_GRID = { arm: 'phoenix' for arm in ARMS }

# Grids used for fitting, depending on the correction model
if CORRECTION_MODEL == 'fluxcorr':
    FIT_GRID = {
        'b': 'phoenix',
        'mr': 'phoenix',
    }
elif CORRECTION_MODEL == 'contnorm':
    FIT_GRID = {
        'b': 'gridie',
        'mr': 'grid7',
    }
else:
    # Fail here rather than with a NameError (or a stale FIT_GRID from an
    # earlier run) in a later cell.
    raise ValueError(f"Unknown CORRECTION_MODEL {CORRECTION_MODEL!r}; expected 'fluxcorr' or 'contnorm'")

# Instrument configuration; the '{}' placeholders are filled with the arm name
# (the detector map is keyed by the arm's first letter).
DETECTOR_PATH = DATA_ROOT + '/subaru/pfs/arms/{}.json'
DETECTORMAP_PATH = DATA_ROOT + '/drp_pfs_data/detectorMap/detectorMap-sim-{}1.fits'
PSF_PATH = DATA_ROOT + '/subaru/pfs/psf/import/{}.2'
SKY_PATH = DATA_ROOT + '/subaru/pfs/noise/import/sky.see/{}/sky.h5'
MOON_PATH = DATA_ROOT + '/subaru/pfs/noise/import/moon/{}/moon.h5'

In [0]:
import os
import sys
import logging
import numpy as np
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt

import cProfile, pstats


In [0]:
# Start a debugpy server on localhost:5683 so a debugger can attach to this kernel.
# The `import debugpy` below binds the name `debugpy` in the notebook globals,
# so re-running this cell skips a second `listen()` on the same port.
# NOTE(review): this opens a listening socket every time the notebook is run;
# consider making it opt-in before sharing the notebook.
if 'debugpy' not in globals():
    import debugpy
    debugpy.listen(("localhost", 5683))

In [0]:
# Load IPython's autoreload extension (its mode is set in the next cell).
%load_ext autoreload

In [0]:
# Mode 2: reload all (non-excluded) modules before executing each cell, so edits
# to library sources take effect without restarting the kernel.
%autoreload 2

In [0]:
from pfs.ga.pfsspec.stellar.grid import ModelGrid

from pfs.ga.pfsspec.core import Filter
from pfs.ga.pfsspec.sim.obsmod.pipelines import StellarModelPipeline
from pfs.ga.pfsspec.core import Physics
from pfs.ga.pfsspec.core.obsmod.psf import GaussPsf, PcaPsf
from pfs.ga.pfsspec.sim.obsmod.detectors import PfsDetector
from pfs.ga.pfsspec.sim.obsmod.detectormaps import PfsDetectorMap
from pfs.ga.pfsspec.sim.obsmod.background import Sky
from pfs.ga.pfsspec.sim.obsmod.background import Moon
from pfs.ga.pfsspec.core.obsmod.snr import QuantileSnr

from pfs.ga.pfsspec.stellar import StellarSpectrum
from pfs.ga.pfsspec.sim.obsmod.observations import PfsObservation
from pfs.ga.pfsspec.sim.obsmod.noise import NormalNoise
from pfs.ga.pfsspec.sim.obsmod.calibration import FluxCalibrationBias
from pfs.ga.pfsspec.core.obsmod.resampling import FluxConservingResampler, Interp1dResampler

from pfs.ga.pfsspec.stellar.tempfit import TempFit, TempFitTrace

from pfs.ga.pfsspec.core.sampling import Parameter, ParameterSampler
from pfs.ga.pfsspec.sim.stellar import ModelGridSampler

from pfs.ga.pfsspec.stellar.continuum.models import PiecewiseChebyshev
from pfs.ga.pfsspec.stellar.continuum.finders import SigmaClipping

from pfs.ga.pfsspec.core.util import SmartParallel

### Load the model grids

In [0]:
# Open every configured stellar grid from its spectra.h5 file; arrays are not
# preloaded (preload_arrays=False).
grid = {}
for grid_name, grid_dir in GRID_PATH.items():
    spectra_file = os.path.join(grid_dir, 'spectra.h5')
    grid[grid_name] = ModelGrid.from_file(spectra_file, preload_arrays=False)
    print(grid_name, spectra_file)

In [0]:
# Summarize each grid: wavelength range, number of wavelength bins, grid shape.
for grid_name, model_grid in grid.items():
    grid_wave = model_grid.wave
    print(grid_name, grid_wave.min(), grid_wave.max(), grid_wave.shape, model_grid.get_shape())

### Load the detector config and configure the observation objects

In [0]:
# Set to True to attach the realistic per-arm detector map to each detector.
# The map may produce NaN wavelengths outside of its coverage.
USE_DETECTOR_MAP = False

detector = {}

for arm in ARMS:
    # Detector configuration of the arm, loaded from its JSON description
    detector[arm] = PfsDetector()
    detector[arm].load_json(DETECTOR_PATH.format(arm))

    if USE_DETECTOR_MAP:
        # Detector map files are keyed by the first letter of the arm name
        detector[arm].map = PfsDetectorMap()
        detector[arm].map.load(DETECTORMAP_PATH.format(arm[0]))

In [0]:
gauss_psf = {}
pca_psf = {}
template_psf = {}

for arm in ARMS:
    psf_dir = PSF_PATH.format(arm)

    # Gaussian and PCA PSF models stored in the arm's PSF directory
    gauss_psf[arm] = GaussPsf()
    gauss_psf[arm].load(os.path.join(psf_dir, 'gauss.h5'))

    pca_psf[arm] = PcaPsf()
    pca_psf[arm].load(os.path.join(psf_dir, 'pca.h5'))

    # Precompute the PSF used for fitting: the Gaussian PSF evaluated on the
    # wavelength grid of the arm's fitting grid, with an optimal kernel size.
    fit_wave = grid[FIT_GRID[arm]].wave
    kernel_size = gauss_psf[arm].get_optimal_size(fit_wave)
    print(f'optimal kernel size for arm {arm}:', kernel_size)
    template_psf[arm] = PcaPsf.from_psf(gauss_psf[arm], fit_wave, size=kernel_size, truncate=5)

### Precompute the detector PSFs and load the sky and moon background tables

In [0]:
sky = {}
moon = {}

for arm in ARMS:
    # Detector PSF: the Gaussian PSF evaluated on the detector wavelengths
    # selected by the mask returned from get_wave().
    det_wave, _, det_mask = detector[arm].get_wave()
    masked_wave = det_wave[det_mask]
    det_kernel_size = gauss_psf[arm].get_optimal_size(masked_wave)
    print(f'Optimal size of PSF kernel for arm {arm}', det_kernel_size)
    detector[arm].psf = PcaPsf.from_psf(gauss_psf[arm], masked_wave, size=det_kernel_size, truncate=5)

    # Sky and moon background tables of the arm
    sky[arm] = Sky()
    sky[arm].load(SKY_PATH.format(arm), format='h5')

    moon[arm] = Moon()
    moon[arm].load(MOON_PATH.format(arm), format='h5')

In [0]:
# One observation object per arm, combining the detector, the sky and moon
# backgrounds and a normal noise model.
obs = {}

for arm in ARMS:
    arm_obs = PfsObservation()
    obs[arm] = arm_obs
    arm_obs.detector = detector[arm]
    arm_obs.sky = sky[arm]
    arm_obs.moon = moon[arm]
    arm_obs.noise_model = NormalNoise()

### Create observation simulation pipeline

In [0]:
# Broadband filter (read from FILTER_PATH) used for normalization; create_pipeline
# assigns it to each simulation pipeline's `mag_filter`.
mag_filt = Filter()
mag_filt.read(FILTER_PATH)

In [0]:
def create_pipeline(arm, grid, calib_bias=False, *,
                    observation=None, mag_filter=None,
                    default_model_res=150000, noise_level=1.0,
                    bias_amplitude=0.02):
    """
    Configure the observation simulation pipeline for one arm.

    Parameters
    ----------
    arm : str
        Arm name. Used to look up the observation in the notebook-level
        ``obs`` dict when ``observation`` is not given.
    grid : ModelGrid
        Grid the spectra are simulated from. ``grid.resolution`` becomes the
        pipeline's model resolution; when it is falsy (e.g. None),
        ``default_model_res`` is used instead. The name shadows the
        notebook-level ``grid`` dict inside this function.
    calib_bias : bool
        If True, attach a ``FluxCalibrationBias`` (``reuse_bias=False``) with
        amplitude ``bias_amplitude`` as the pipeline's calibration step.
    observation : PfsObservation, optional
        Observation object for the arm. Defaults to ``obs[arm]``.
    mag_filter : Filter, optional
        Broadband normalization filter. Defaults to the notebook-level
        ``mag_filt``.
    default_model_res : float
        Fallback model resolution when the grid does not report one.
    noise_level : float
        Value assigned to the pipeline's ``noise_level``.
    bias_amplitude : float
        Amplitude of the flux calibration bias; only used if ``calib_bias``.

    Returns
    -------
    StellarModelPipeline
        The configured pipeline.
    """

    pp = StellarModelPipeline()
    pp.model_res = grid.resolution or default_model_res
    # Fall back to notebook globals only when no explicit object is passed
    pp.mag_filter = mag_filt if mag_filter is None else mag_filter
    pp.observation = obs[arm] if observation is None else observation
    pp.snr = QuantileSnr(binning=1.0)
    pp.resampler = Interp1dResampler()
    pp.noise_level = noise_level
    pp.noise_freeze = False
    if calib_bias:
        bias = FluxCalibrationBias(reuse_bias=False)
        bias.amplitude = bias_amplitude
        pp.calibration = bias

    return pp

In [0]:
# One simulation pipeline per arm, built on that arm's simulation grid
pipeline = {arm: create_pipeline(arm, grid=grid[SIM_GRID[arm]]) for arm in ARMS}

### Simulate a spectrum

In [0]:
# Observing conditions, passed as keyword arguments to pipeline.run()
obs_params = dict(
    seeing=0.5,
    exp_time=15 * 60,           # 15 min, assuming the pipeline expects seconds
    exp_count=1,
    target_zenith_angle=60,
    target_field_angle=0.6,
    moon_zenith_angle=45,
    moon_target_angle=60,
    moon_phase=0.0,
    sky_residual=0.0,
    mag=21.5,
)

# Stellar parameters of the simulated star, passed to grid.interpolate_model()
model_params = dict(
    M_H=-1.5,
    T_eff=4000,
    log_g=2.5,
    a_M=0.0,
)

In [0]:
# Run the simulation: interpolate a model spectrum for each arm from its
# simulation grid, then pass it through that arm's observation pipeline.
spec = {}
for arm in ARMS:
    sim_grid = grid[SIM_GRID[arm]]
    spec[arm] = sim_grid.interpolate_model(**model_params)
    # run()'s return value is not used; the pipeline presumably updates
    # spec[arm] in place (its flux_err is plotted afterwards) -- confirm.
    pipeline[arm].run(spec[arm], **obs_params)

In [0]:
# Simulated flux and flux error of every arm on one set of axes
fig, ax = plt.subplots(1, 1, dpi=120)

for arm in ARMS:
    ax.plot(spec[arm].wave, spec[arm].flux, lw=0.1, label=f'{arm} flux')
    ax.plot(spec[arm].wave, spec[arm].flux_err, lw=0.1, label=f'{arm} flux_err')

ax.set_xlabel('wavelength')
ax.set_ylabel('flux / flux error')
ax.set_title('Simulated spectrum')
ax.legend();

In [0]:
# Quick look at the simulated flux and flux error arrays of the 'mr' arm
spec['mr'].flux, spec['mr'].flux_err

In [0]:
# Multiple exposures of the same object: each exposure is a copy of the arm's
# simulated spectrum with a new NormalNoise instance applied to it.
# NOTE(review): spec[arm] already went through pipeline.run() with a NormalNoise
# noise model on obs[arm]; confirm that adding noise again here is intended.
exp_count = 2
spectra = {}
for arm in ARMS:
    exposures = []
    for iexp in range(exp_count):
        noisy = spec[arm].copy()
        noisy.apply_noise(NormalNoise())
        exposures.append(noisy)
    spectra[arm] = exposures

In [0]:
# Show the per-arm lists of simulated exposures
spectra['b'], spectra['mr']

In [0]:
# Flux of every simulated exposure of every arm, labeled by arm and exposure
fig, ax = plt.subplots(1, 1, dpi=120)

for arm in ARMS:
    for iexp, exposure in enumerate(spectra[arm]):
        ax.plot(exposure.wave, exposure.flux, lw=0.1, label=f'{arm}, exp {iexp}')

ax.set_xlabel('wavelength')
ax.set_ylabel('flux')
ax.set_title('Simulated exposures')
ax.legend();

### Run the model fitting

In [0]:
from pfs.ga.pfsspec.core import Physics
from pfs.ga.pfsspec.core.obsmod.resampling import FluxConservingResampler, Interp1dResampler
from pfs.ga.pfsspec.stellar.tempfit import ModelGridTempFit, ModelGridTempFitTrace
from pfs.ga.pfsspec.stellar.tempfit import FluxCorr, ContNorm
from pfs.ga.pfsspec.stellar.fluxcorr import PolynomialFluxCorrection
from pfs.ga.pfsspec.stellar.continuum.models import PiecewiseChebyshev, Spline
from pfs.ga.pfsspec.core.sampling import NormalDistribution

In [0]:
# Configure the correction model fitted together with the templates
if CORRECTION_MODEL == 'fluxcorr':
    # Polynomial flux correction of degree 10
    correction_model = FluxCorr()
    correction_model.use_flux_corr = True
    correction_model.flux_corr_type = PolynomialFluxCorrection
    correction_model.flux_corr_degree = 10
    correction_model.flux_corr_per_arm = CORRECTION_PER_ARM
    correction_model.flux_corr_per_exp = CORRECTION_PER_EXP
elif CORRECTION_MODEL == 'contnorm':
    # Continuum normalization with a spline continuum model
    correction_model = ContNorm()
    correction_model.use_cont_norm = True
    correction_model.cont_model_type = Spline
    correction_model.cont_per_arm = CORRECTION_PER_ARM
    correction_model.cont_per_exp = CORRECTION_PER_EXP
else:
    # Fail loudly instead of silently reusing a `correction_model` left over
    # from an earlier run of this cell (or hitting a NameError further down).
    raise ValueError(f"Unknown CORRECTION_MODEL {CORRECTION_MODEL!r}; expected 'fluxcorr' or 'contnorm'")

# Set up tracing to get some performance statistics
trace = ModelGridTempFitTrace()

tempfit = ModelGridTempFit(correction_model=correction_model, trace=trace)
# One template grid per arm, selected by FIT_GRID
tempfit.template_grids = { arm: grid[FIT_GRID[arm]] for arm in ARMS }
tempfit.cache_templates = True

tempfit.template_resampler = Interp1dResampler()

# RV search centered on the velocity of z = 0: a +/-100 window, a normal prior
# (50 presumably its sigma) and a step of 5, all in Physics.z_to_vel units.
rv_0 = Physics.z_to_vel(0)
tempfit.rv_0 = rv_0
tempfit.rv_bounds = [rv_0 - 100.0, rv_0 + 100.0]
tempfit.rv_prior = NormalDistribution(rv_0, 50)
tempfit.rv_step = 5.0

# Start the parameter fit from the true simulated values
tempfit.params_0 = { p: model_params[p] for p in [ 'M_H', 'T_eff', 'log_g' ] }

# a_M is not fitted; it is held at a value that depends on the correction model.
# NOTE(review): the contnorm value (-0.5) differs from the simulated a_M (0.0);
# presumably a constraint of the gridie/grid7 grids -- confirm.
if CORRECTION_MODEL == 'fluxcorr':
    tempfit.params_fixed = { 'a_M': 0.0 }
elif CORRECTION_MODEL == 'contnorm':
    tempfit.params_fixed = { 'a_M': -0.5 }

# Normal priors centered on the true values. The two extra arguments of the
# T_eff prior look like lower/upper bounds (+/-50 K) -- TODO confirm.
tempfit.params_priors = {
    'M_H': NormalDistribution(model_params['M_H'], 0.5),
    'T_eff': NormalDistribution(model_params['T_eff'], 50, model_params['T_eff'] - 50, model_params['T_eff'] + 50),
    'log_g': NormalDistribution(model_params['log_g'], 0.5)
}
tempfit.params_steps = {
    'M_H': 0.01,
    'T_eff': 1,
    'log_g': 0.01
}

# tempfit.mcmc_burnin = 500
# tempfit.mcmc_samples = 1000
# tempfit.mcmc_walkers = 3

In [0]:
# Normalize flux of spectra and templates to about unity.
# The correction models are initialized over the RV search range before the
# normalization is computed; presumably get_normalization() relies on them,
# so keep this order -- confirm.
tempfit.init_correction_models(spectra, tempfit.rv_bounds, force=True)
tempfit.spec_norm, tempfit.temp_norm = tempfit.get_normalization(spectra)
print(tempfit.spec_norm, tempfit.temp_norm)

In [0]:
# Start from a clean fitter and trace so the counters and profile only cover
# the likelihood scan below.
tempfit.reset()
trace.reset()

# Uniform grid of 100 RVs spanning the search bounds
rv = np.linspace(*tempfit.rv_bounds, 100)

# Profile only the likelihood evaluation. Profile.runcall() disables the
# profiler even if the call raises, so a failure here does not leave the
# profiler running while later cells execute.
profiler = cProfile.Profile()
log_L = profiler.runcall(tempfit.calculate_log_L, spectra, None, rv)

In [0]:
# Counters accumulated by the fit trace since trace.reset() in the previous cell
trace.counters

In [0]:
# Print the 30 most expensive entries of the most recent profiling run, ordered
# by cumulative time. The trailing ';' keeps Jupyter from echoing the returned
# Stats object.
stats = pstats.Stats(profiler)
stats.sort_stats('cumtime')
stats.print_stats(30);

In [0]:
# Log-likelihood evaluated on the RV grid of the previous cells
fig, ax = plt.subplots(1, 1, dpi=120)

ax.plot(rv, log_L)
ax.set_xlabel('RV')
ax.set_ylabel('log L')
ax.set_title('Log-likelihood over the RV grid')
ax.grid()

In [0]:
# Start from a clean fitter and trace so the counters and profile only cover
# the RV fit below.
tempfit.reset()
trace.reset()

# Profile the full RV fit. Profile.runcall() disables the profiler even if
# fit_rv() raises. This replaces `profiler`, so printing pstats afterwards
# reports this fit rather than the likelihood scan.
profiler = cProfile.Profile()
res = profiler.runcall(tempfit.fit_rv, spectra)
# res = profiler.runcall(tempfit.fit_rv, spectra, calculate_error=False, calculate_cov=False)

In [0]:
# Fitted template parameters (`params_fit`) and their errors (`params_err`) from fit_rv
res.params_fit, res.params_err

In [0]:
# Covariance from fit_rv; presumably not computed when fit_rv is called with
# calculate_cov=False -- confirm.
res.cov

In [0]:
# Fitted correction-model coefficients; passed as `a=` to eval_correction/eval_model below
res.a_fit

In [0]:
# Fitted RV (`rv_fit`) and its error (`rv_err`) from fit_rv
res.rv_fit, res.rv_err

In [0]:
# Evaluate the best fit: templates at the fitted parameters, then the correction
# and the full model at the fitted RV and correction coefficients.
best_params, best_rv, best_a = res.params_fit, res.rv_fit, res.a_fit
templates, _ = tempfit.get_templates(spectra, best_params)
corrections = tempfit.eval_correction(spectra, templates, best_rv, a=best_a)
models = tempfit.eval_model(spectra, templates, best_rv, a=best_a)

In [0]:
# One figure per exposure showing the correction evaluated at the best fit,
# restricted to the pixels selected by the model's mask.
exp_count = max(len(spectra[arm]) for arm in ARMS)

for ie in range(exp_count):
    fig, ax = plt.subplots(1, 1, dpi=120)

    for arm in ARMS:
        # An arm may have fewer exposures than the longest one
        if ie >= len(spectra[arm]):
            continue
        mask = models[arm][ie].mask_as_bool()
        ax.plot(spectra[arm][ie].wave[mask], corrections[arm][ie][mask], lw=0.3, label=arm)

    ax.set_xlabel('wavelength')
    ax.set_ylabel('correction')
    ax.set_title(f'Best-fit correction, exposure {ie}')
    ax.legend()

In [0]:
# One figure per exposure comparing the observed flux with the best-fit model
exp_count = max(len(spectra[arm]) for arm in ARMS)

for ie in range(exp_count):
    fig, ax = plt.subplots(1, 1, dpi=120)

    for arm in ARMS:
        # An arm may have fewer exposures than the longest one
        if ie >= len(spectra[arm]):
            continue
        ax.plot(spectra[arm][ie].wave, spectra[arm][ie].flux, lw=0.1, label=arm)
        ax.plot(models[arm][ie].wave, models[arm][ie].flux, lw=0.1, c='k', label=f'{arm} model')

    ax.set_xlabel('wavelength')
    ax.set_ylabel('flux')
    ax.set_title(f'Observed flux and best-fit model, exposure {ie}')
    ax.legend()

In [0]:
# One figure per exposure with the residuals (observed flux minus best-fit model)
exp_count = max(len(spectra[arm]) for arm in ARMS)

for ie in range(exp_count):
    fig, ax = plt.subplots(1, 1, dpi=120)

    for arm in ARMS:
        # An arm may have fewer exposures than the longest one
        if ie >= len(spectra[arm]):
            continue
        ax.plot(spectra[arm][ie].wave, spectra[arm][ie].flux - models[arm][ie].flux, lw=0.1, label=arm)

    # Zero reference line, drawn once per figure
    ax.axhline(0, c='k', lw=0.1)
    ax.grid()
    ax.set_xlabel('wavelength')
    ax.set_ylabel('flux - model')
    ax.set_title(f'Residuals, exposure {ie}')
    ax.legend()