In [0]:
import os

# Locations of the model grids. Each can be overridden with an environment
# variable so the notebook also runs where the data live elsewhere; the
# defaults are the original hard-coded paths.
# Alternative (BOSZ) fit grid:
#GRID_PATH = '/datascope/subaru/data/pfsspec/models/stellar/rbf/bosz/bosz_50000_GK/fit'

# PHOENIX model grid.
GRID_PATH = os.environ.get(
    'PFSSPEC_GRID_PATH',
    '/datascope/subaru/data/pfsspec/models/stellar/grid/phoenix/phoenix_HiRes')
# PHOENIX fit grid (directory ending in /fit).
FIT_GRID_PATH = os.environ.get(
    'PFSSPEC_FIT_GRID_PATH',
    '/datascope/subaru/data/pfsspec/models/stellar/rbf/phoenix/phoenix_HiRes_GK/fit')

In [0]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [0]:
import sys

# Prepend the directory four levels up to the module search path so the
# project packages (presumably pfs.ga.pfsspec, imported below) resolve.
sys.path[:0] = ['../../../..']

In [0]:
import os
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import h5py

In [0]:
import matplotlib

# Larger default font for all subsequent figures.
matplotlib.rcParams['font.size'] = 14

# Load stellar model grid

In [0]:
from pfs.ga.pfsspec.core.grid import ArrayGrid
from pfs.ga.pfsspec.stellar.grid import ModelGrid
from pfs.ga.pfsspec.stellar.grid.bosz import Bosz
from pfs.ga.pfsspec.stellar.grid.phoenix import Phoenix

In [0]:
# Open the PHOENIX model grid stored in GRID_PATH/spectra.h5.
# preload_arrays is switched off before load(); presumably this keeps the
# large arrays on disk instead of reading them eagerly -- TODO confirm.
#grid = ModelGrid(Bosz(normalized=False), ArrayGrid)
grid = ModelGrid(Phoenix(normalized=False), ArrayGrid)
grid.preload_arrays = False
fn = os.path.join(GRID_PATH, 'spectra.h5')
grid.load(fn, format='h5')

In [0]:
# Open the fit grid stored in FIT_GRID_PATH/spectra.h5.
# NOTE(review): this grid is built with Phoenix(normalized=True) while the
# model grid above uses normalized=False -- confirm this is intended.
#fit_grid = ModelGrid(Bosz(normalized=False), ArrayGrid)
fit_grid = ModelGrid(Phoenix(normalized=True), ArrayGrid)
fit_grid.preload_arrays = False
fn = os.path.join(FIT_GRID_PATH, 'spectra.h5')
fit_grid.load(fn, format='h5')

In [0]:
# Wavelength arrays of the model grid and of the fit grid, side by side.
tuple(g.wave for g in (grid, fit_grid))

In [0]:
# Axis names and their values for the model grid.
for axis_pos, axis_name, axis in grid.enumerate_axes():
    print(axis_name, axis.values)

In [0]:
# Axis names and their values for the fit grid.
for axis_pos, axis_name, axis in fit_grid.enumerate_axes():
    print(axis_name, axis.values)

In [0]:
# Names of the value arrays stored in each grid.
tuple(g.grid.values.keys() for g in (grid, fit_grid))

## Count valid spectra

In [0]:
# For every value array of the model grid: index-array shape and the sum of
# its entries (the number of valid grid points if the index is boolean).
for value_name in grid.grid.value_indexes.keys():
    value_index = grid.grid.value_indexes[value_name]
    print(value_name, value_index.shape, np.sum(value_index))

In [0]:
# Same summary for the fit grid: index-array shape and sum of its entries.
for value_name in fit_grid.grid.value_indexes.keys():
    value_index = fit_grid.grid.value_indexes[value_name]
    print(value_name, value_index.shape, np.sum(value_index))

# Plot

In [0]:
# Fit-grid axes listed again (same output as earlier) so the available
# values are in view when choosing the target parameters in the next cell.
for axis_pos, axis_name, axis in fit_grid.enumerate_axes():
    print(axis_name, axis.values)

In [0]:
# Target parameters used to pick the nearest grid points below.
# Alternative set: M_H, T_eff, log_g, a_M, C_M = -0.5, 4500, 1, 0, 0
M_H, T_eff, log_g, a_M, C_M = 0, 4500, 1, 0, 0

In [0]:
# Index of the model-grid point nearest to the target parameters.
idx = grid.grid.get_nearest_index(
    M_H=M_H,
    T_eff=T_eff,
    log_g=log_g,
    a_M=a_M,
    C_M=C_M,
)
idx

In [0]:
# Index of the fit-grid point nearest to the same target parameters.
fit_idx = fit_grid.grid.get_nearest_index(
    M_H=M_H,
    T_eff=T_eff,
    log_g=log_g,
    a_M=a_M,
    C_M=C_M,
)
fit_idx

In [0]:
# Continuum parameters at the nearest fit-grid point, evaluated with the fit
# grid's continuum model into a (wavelength, continuum) pair.
fit_params = fit_grid.get_continuum_parameters_at(fit_idx)
fit_wave, fit_cont = fit_grid.continuum_model.eval(fit_params)

# Model spectrum at the nearest model-grid point.
model = grid.get_model_at(idx)

In [0]:
# Model flux and evaluated continuum fit over the full wavelength range.
fig, ax = plt.subplots()
ax.plot(model.wave, model.flux, label='model flux')
ax.plot(fit_wave, fit_cont, label='continuum fit')
ax.set_xlabel('wavelength')
ax.set_ylabel('flux')
ax.legend();

In [0]:
# Same comparison, zoomed to wavelengths 3000-4000.
fig, ax = plt.subplots()
ax.plot(model.wave, model.flux, label='model flux')
ax.plot(fit_wave, fit_cont, label='continuum fit')
ax.set_xlim(3000, 4000)
ax.set_xlabel('wavelength')
ax.set_ylabel('flux')
ax.legend();

In [0]:
# Same comparison, zoomed further to wavelengths 3600-3800.
fig, ax = plt.subplots()
ax.plot(model.wave, model.flux, label='model flux')
ax.plot(fit_wave, fit_cont, label='continuum fit')
ax.set_xlim(3600, 3800)
ax.set_xlabel('wavelength')
ax.set_ylabel('flux')
ax.legend();

In [0]:
# Slice for the parameter plots: every entry of the first three grid axes,
# fixed at position idx[3] along axis 3 (same tuple as np.s_[:, :, :, idx[3]]).
# NOTE(review): idx comes from `grid`; if the plotted values are read from
# `fit_grid`, fit_idx[3] may be the matching index -- confirm both grids share axis 3.
s = (slice(None),) * 3 + (idx[3],)
# 5-axis variant: s = np.s_[:, :, :, idx[3], idx[4]]

In [0]:
def load_params(name, model_grid=None):
    """Load a gridded value array together with its value index.

    Parameters
    ----------
    name : str
        Key of the value array, e.g. a continuum-fit parameter such as
        ``'blended_1'``.
    model_grid : ModelGrid, optional
        Grid to read from. Defaults to ``fit_grid``, the grid loaded from
        ``FIT_GRID_PATH``. Continuum-fit parameters are read from the fit grid,
        not from the model grid in ``grid``; presumably only the fit grid stores
        them -- confirm if a different source is intended.

    Returns
    -------
    fit_params
        ``model_grid.grid.get_value(name)``.
    masks
        ``model_grid.grid.value_indexes[name]``. The caller uses it as a boolean
        mask of valid grid points (presumably -- confirm).
    """
    if model_grid is None:
        model_grid = fit_grid
    fit_params = model_grid.grid.get_value(name)
    masks = model_grid.grid.value_indexes[name]
    return fit_params, masks

In [0]:
def plot_params(fit_params, idx=2, param_idx=0, xlim=(-0.5, 10.5)):
    """Draw one heat map per position along axis ``idx`` for a single parameter.

    The panels are laid out four per row. Each panel shows
    ``fit_params[..., param_idx]`` with axis ``idx`` fixed to one position and
    has its own colorbar.

    Parameters
    ----------
    fit_params : ndarray
        Array whose last axis enumerates parameters. After fixing axis ``idx``
        and selecting ``param_idx``, each panel must be 2D for ``imshow``.
        NaN entries are ignored when choosing the colour range.
    idx : int
        Axis of ``fit_params`` that is split into panels.
    param_idx : int
        Index along the last axis selecting the parameter to show.
    xlim : tuple of float or None
        Horizontal limits applied to every panel; ``None`` keeps matplotlib's
        automatic limits.
    """
    ncols = 4
    npanels = fit_params.shape[idx]
    # One row per four panels (at least one, so plt.subplots gets a valid shape).
    nrows = max(1, int(np.ceil(npanels / ncols)))
    # squeeze=False keeps `axs` two-dimensional even for a single row, so
    # axs[row, col] indexing works for any number of panels.
    f, axs = plt.subplots(nrows, ncols, figsize=(16, 4 * nrows), squeeze=False)

    for p in range(npanels):
        ax = axs[p // ncols, p % ncols]

        # Fix position p along axis `idx` and keep every other grid axis.
        sl = [slice(None)] * (fit_params.ndim - 1)
        sl[idx] = p
        data = fit_params[tuple(sl)][..., param_idx]

        # Take the colour range from finite values only. Callers may NaN-mask
        # invalid grid points, and plain min()/max() would then return NaN.
        finite = data[np.isfinite(data)]
        if finite.size > 0:
            vmin, vmax = finite.min(), finite.max()
        else:
            vmin, vmax = None, None

        im = ax.imshow(data, aspect='auto', vmin=vmin, vmax=vmax)
        f.colorbar(im, ax=ax)
        ax.set_xlabel('param: {} | slice: {}'.format(param_idx, p))
        if xlim is not None:
            ax.set_xlim(*xlim)

    # Hide the unused panels in the last row.
    for ax in axs.flat[npanels:]:
        ax.set_axis_off()

In [0]:
# Plot every parameter of the 'blended_1' value, with grid points outside its
# value index set to NaN.
fit_params, masks = load_params('blended_1')
# Mask a copy so NaNs are never written into an array the grid object may
# still reference.
fit_params = np.array(fit_params, copy=True)
fit_params[~masks] = np.nan
for pi in range(fit_params.shape[-1]):
    plot_params(fit_params[s], param_idx=pi)