In [0]:
import os

# Root directory of the PHOENIX high-resolution model grid; the grid file
# `spectra.h5` is read from this directory below. Set the PFSSPEC_GRID_PATH
# environment variable to use another copy without editing this cell.
GRID_PATH = os.environ.get(
    'PFSSPEC_GRID_PATH',
    '/datascope/subaru/data/pfsspec/models/stellar/grid/phoenix/phoenix_HiRes_full')

In [0]:
# Render matplotlib figures inline, directly below the cell that draws them.
%matplotlib inline

In [0]:
import sys

# Make the project modules (e.g. `pfsspec`) importable from the local checkout.
# The path is relative to the kernel's working directory, presumably the
# notebook's own folder — TODO confirm if the notebook is moved.
PROJECT_ROOT = '../../../..'
# Drop any earlier copy first so re-running this cell leaves exactly one entry,
# and that entry is at the front of the module search path.
while PROJECT_ROOT in sys.path:
    sys.path.remove(PROJECT_ROOT)
sys.path.insert(0, PROJECT_ROOT)

In [0]:
import os
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import h5py

In [0]:
import matplotlib

# Larger default font so labels stay legible in the small figures drawn below.
matplotlib.rcParams['font.size'] = 14

# Load stellar model grid

In [0]:
from pfsspec.core.grid import ArrayGrid
from pfsspec.stellar.grid import ModelGrid
from pfsspec.stellar.grid.bosz import Bosz
from pfsspec.stellar.grid.phoenix import Phoenix

In [0]:
# Build a PHOENIX model grid backed by an ArrayGrid and read it from HDF5.
# preload_arrays is switched off before loading — presumably so the large
# arrays are not all read into memory up front (TODO confirm in ModelGrid).
grid = ModelGrid(Phoenix(), ArrayGrid)
grid.preload_arrays = False
fn = os.path.join(GRID_PATH, 'spectra.h5')
grid.load(fn, format='h5')

In [0]:
# Wavelength array of the loaded grid (presumably shared by every model
# spectrum — TODO confirm).
grid.wave

In [0]:
# List every grid axis by name together with its `values` array.
for axis_name in grid.grid.axes.keys():
    axis = grid.grid.axes[axis_name]
    print(axis_name, axis.values)

In [0]:
# Names of the value arrays stored in the grid.
grid.grid.values.keys()

## Count valid spectra

In [0]:
# For each value array, print its name, the shape of its validity index and the
# number of set entries. The name is printed so the lines can be told apart when
# the grid holds several arrays. np.sum counts the True entries, assuming the
# indexes are boolean masks — TODO confirm.
for k in grid.grid.value_indexes.keys():
    valid = grid.grid.value_indexes[k]
    print(k, valid.shape, np.sum(valid))

# Plot grid coverage

In [0]:
# Total number of grid points vs. number of set entries in the 'flux' index
# (assumes value_indexes['flux'] is a boolean mask — TODO confirm).
grid.grid.value_indexes['flux'].size, grid.grid.value_indexes['flux'].sum()

In [0]:
# Shape of each axis' `values` array, in the order enumerate_axes() yields them.
for axis_pos, axis_name, axis in grid.enumerate_axes():
    print(axis_name, axis.values.shape)

In [0]:
# Map axis name -> axis object so later cells can look axes up by name.
axes = dict((axis_name, axis) for axis_pos, axis_name, axis in grid.enumerate_axes())

In [0]:
# The metallicity axis may be called 'Fe_H' or 'M_H' depending on the grid;
# use whichever this grid provides ('Fe_H' is preferred if both exist).
m_h = next((name for name in ('Fe_H', 'M_H') if name in axes), None)
if m_h is None:
    raise NotImplementedError(
        'No metallicity axis found (expected Fe_H or M_H); grid axes: {}'.format(list(axes.keys())))

In [0]:
# Index grid over (metallicity, T_eff): g[0][a, b] == a and g[1][a, b] == b.
n_mh = axes[m_h].values.shape[0]
n_teff = axes['T_eff'].values.shape[0]
g = np.meshgrid(np.arange(n_mh), np.arange(n_teff), indexing='ij')
g[0].shape, g[1].shape

In [0]:
# Coverage maps: one figure per log g value showing, for every
# (metallicity, T_eff) grid point, the entry of value_indexes['flux'] as colour.
# NOTE(review): assumes value_indexes['flux'] is 4-D with its axes ordered like
# `axes` (m_h, T_eff, log_g, <fourth axis>) — confirm for other grids (e.g. a
# 5-axis grid would need an extra index here).
FOURTH_AXIS_INDEX = 2  # fixed position along the fourth grid axis
fourth_axis = list(axes.keys())[3]
mh_idx, teff_idx = g[0].flatten(), g[1].flatten()
for i, log_g_value in enumerate(axes['log_g'].values):
    f, ax = plt.subplots(1, 1, figsize=(4, 4), squeeze=True)
    valid = grid.grid.value_indexes['flux'][:, :, i, FOURTH_AXIS_INDEX]
    ax.scatter(axes[m_h].values[mh_idx], axes['T_eff'].values[teff_idx], c=valid[mh_idx, teff_idx])
    ax.set_xlabel(r'[{}]'.format(m_h))
    ax.set_ylabel(r'$T_\mathrm{eff}$')
    # State the fixed fourth-axis value so each panel is unambiguous.
    f.suptitle('log g = {}, {} = {}'.format(
        log_g_value, fourth_axis, axes[fourth_axis].values[FOURTH_AXIS_INDEX]))
    # Emit each figure as it is drawn instead of accumulating them until the cell ends.
    plt.show()

# Plot an example model spectrum

In [0]:
# Grid index of the point nearest to the example parameters. The metallicity
# keyword uses the detected axis name `m_h` rather than assuming 'M_H'.
# NOTE(review): C_M may need to be dropped for grids without a C_M axis — confirm.
grid.grid.get_nearest_index(**{m_h: -1.5, 'T_eff': 5000, 'log_g': 1.5, 'a_M': 0.0, 'C_M': 0.0})

In [0]:
# Model spectrum nearest to the example parameters. The metallicity keyword
# follows the detected axis name `m_h`, so it matches this grid's axis.
# NOTE(review): C_M may need to be dropped for grids without a C_M axis — confirm.
model = grid.get_nearest_model(**{m_h: -1.5, 'T_eff': 5000, 'log_g': 1.5, 'a_M': 0.0, 'C_M': 0.0})

In [0]:
# Overlay the model's `flux` and `cont` arrays on one labelled figure.
f, ax = plt.subplots(1, 1)
ax.plot(model.wave, model.flux, label='flux')
ax.plot(model.wave, model.cont, label='cont')
ax.set_xlabel('wavelength')
ax.set_ylabel('flux')
ax.legend();

In [0]:
# Wavelength step between adjacent pixels, plotted at each pixel-pair midpoint,
# to show how the sampling of the wavelength grid varies along the spectrum.
wave_mid = 0.5 * (model.wave[1:] + model.wave[:-1])
wave_step = np.diff(model.wave)  # == model.wave[1:] - model.wave[:-1]
f, ax = plt.subplots(1, 1)
ax.plot(wave_mid, wave_step)
ax.set_xlabel('wavelength')
ax.set_ylabel(r'$\Delta\lambda$ between adjacent pixels')
ax.set_title('Wavelength sampling of the model');