In [0]:
# Directory holding the stellar model grid; `spectra.h5` is read from here
# by the loading cell further down.
GRID_PATH = (
    '/datascope/subaru/data/pfsspec/models/stellar/grid'
    '/gk2025/gk2025_small_binned_compressed'
)

In [0]:
%matplotlib inline

In [0]:
import sys

# Put the relative `../../../../python` directory first on the module search
# path so that the project packages can be imported from this notebook.
sys.path.insert(0, '../../../../python')

In [0]:
import os
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import h5py

In [0]:
import matplotlib

# Use a small default font size for every figure drawn in this notebook.
matplotlib.rcParams['font.size'] = 6

# Load stellar model grid

In [0]:
from pfs.ga.pfsspec.stellar.grid import ModelGrid

In [0]:
# Open the model grid stored in `spectra.h5` under GRID_PATH.
fn = os.path.join(GRID_PATH, 'spectra.h5')
# preload_arrays=False presumably defers reading the value arrays -- TODO confirm.
grid = ModelGrid.from_file(fn, preload_arrays=False)
# NOTE(review): `from_file` is followed by an explicit `load` of the same file;
# confirm that both calls are required.
grid.load(fn, format='h5')

In [0]:
# Display the grid's `wave` attribute (presumably the wavelength vector -- confirm).
grid.wave

In [0]:
# Print each grid axis name together with the values along that axis.
grid_axes = grid.grid.axes
for axis_name in grid_axes.keys():
    print(axis_name, grid_axes[axis_name].values)

In [0]:
# Display the keys of the grid's `values` mapping.
grid.grid.values.keys()

## Count valid spectra

In [0]:
# For every entry of `value_indexes` print its name, the shape of the index
# array and the sum over that array.
value_indexes = grid.grid.value_indexes
for value_name in value_indexes.keys():
    index_array = value_indexes[value_name]
    print(value_name, index_array.shape, np.sum(index_array))

# Plot grid coverage

In [0]:
# Total number of entries in the 'flux' index array and the sum over it.
flux_index = grid.grid.value_indexes['flux']
(flux_index.size, flux_index.sum())

In [0]:
# Print the name of every grid axis and the shape of its value array.
for axis_pos, axis_name, axis in grid.enumerate_axes():
    print(axis_name, axis.values.shape)

In [0]:
# Name -> axis lookup built from the (index, name, axis) triples of the grid.
axes = dict((axis_name, axis) for axis_pos, axis_name, axis in grid.enumerate_axes())

In [0]:
# Decide which axis name this grid uses: 'Fe_H' if present, otherwise 'M_H'.
# `m_h` is used further down to look the axis up and to label plots.
for candidate in ('Fe_H', 'M_H'):
    if candidate in axes:
        m_h = candidate
        break
else:
    # Neither name exists: fail with a message that says what was found
    # instead of a bare, message-less exception.
    raise NotImplementedError(
        "Grid has neither an 'Fe_H' nor an 'M_H' axis; available axes: "
        + ', '.join(str(name) for name in axes))

In [0]:
# Print, per grid axis: name, number of values, smallest and largest value.
for axis_pos, axis_name, axis in grid.enumerate_axes():
    axis_values = axis.values
    print(axis_name, axis_values.size, axis_values.min(), axis_values.max())

In [0]:
# Set up the coverage plot of all models.
# Every subplot shows the T_eff / log_g plane (`plot_axes`); the remaining
# axes that have more than one value (`other_axes`) select the subplot and,
# beyond the first two of them, the figure.

# `all_axes` / `plot_shape`: axis names and sizes ordered as T_eff, log_g,
# followed by the other multi-valued axes in the order the grid enumerates them.
axes = { k: a for i, k, a in grid.enumerate_axes() }
plot_axes = ['T_eff', 'log_g']
other_axes = [k for k in axes.keys() if k not in plot_axes and grid.get_axis(k).values.size > 1]
all_axes = plot_axes + other_axes
plot_shape = tuple(axes[k].values.size for k in all_axes)

print(plot_axes, other_axes, all_axes, plot_shape)

In [0]:
# Display the axes drawn inside each subplot and the remaining multi-valued axes.
plot_axes, other_axes

In [0]:
# Display the axis sizes in `all_axes` order.
plot_shape

In [0]:
# List every index tuple over plot_shape[4:], i.e. over the axes that come after
# the two plot axes and the first two entries of `other_axes`. The plotting
# cell further down iterates over the same tuples in its outer loop.
for ix_page in np.ndindex(plot_shape[4:]):
    print(ix_page)

In [0]:
# Iterate over all possible indexes based on shape
g = np.meshgrid(np.arange(axes['T_eff'].values.shape[0]), np.arange(axes['log_g'].values.shape[0]), indexing='ij')
print(g[0].shape, g[1].shape)

step = 4

nrows = plot_shape[2] // step
ncols = plot_shape[3] // step if len(plot_shape) > 3 else 1
# ncols = nrows = 10
print(nrows, ncols)
print(ncols * 0.25, nrows * 0.25)

# Figure out the outer loop shape
plot_slice = []
for i, k, _ in grid.enumerate_axes():
    if k in plot_axes:              # subplot axes
        pass
    elif k in other_axes and other_axes.index(k) < 2:   # index of subplot
        pass
    elif k in other_axes:                           # plots
        plot_slice.append(i)

print(plot_axes, other_axes, plot_slice, [plot_shape[s] for s in plot_slice], plot_shape[4:])

In [0]:
# Draw the grid coverage: one figure ("page") per index combination of the third
# and later `other_axes`, one subplot per stepped value of the first two
# `other_axes` (rows follow other_axes[0], columns follow other_axes[1]).
# Each subplot shows the 'flux' value index over the T_eff / log_g plane.
for ix_page in np.ndindex(plot_shape[4:]):     # iterate over the pages

    f, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=(ncols * 1, nrows * 1), sharex=True, sharey=True, dpi=120, squeeze=False)
    f.subplots_adjust(wspace=0.05, hspace=0.05)

    for ix in np.ndindex((nrows, ncols)):
        # One index entry per grid axis; axes that are neither plotted nor
        # listed in `other_axes` stay at index 0.
        ii = [ 0 for _ in grid.enumerate_axes() ]
        for i, k, _ in grid.enumerate_axes():
            if k in plot_axes:              # subplot axes: keep the full range
                ii[i] = slice(None)
            elif k in other_axes:
                pos = other_axes.index(k)
                if pos < 2:                 # row / column of the subplot
                    ii[i] = ix[pos] * step
                else:                       # page
                    # ix_page already enumerates every value of the axis, so it
                    # must not be scaled by `step` (that would run out of range).
                    ii[i] = ix_page[pos - 2]

        ii = tuple(ii)
        idx = grid.grid.value_indexes['flux'][ii]

        ax = axs[ix]
        # The extent puts values[0] of the second plot axis at the top and
        # values[-1] at the bottom (inverted axis). origin='upper' places array
        # row 0 at the top as well, so each row lines up with its tick value;
        # origin='lower' would draw the image upside down relative to the ticks.
        # NOTE(review): idx.T assumes plot_axes[0] precedes plot_axes[1] in the
        # grid's axis order -- confirm.
        ax.imshow(
            idx.T,
            origin='upper',
            aspect='auto',
            extent=(axes[plot_axes[0]].values[0],
                    axes[plot_axes[0]].values[-1],
                    axes[plot_axes[1]].values[-1],
                    axes[plot_axes[1]].values[0]),
            clim=(0, 1))

    # sharex / sharey are on, so only the outer subplots need axis labels
    for ax in axs[-1, :]:
        ax.set_xlabel(plot_axes[0])

    for ax in axs[:, 0]:
        ax.set_ylabel(plot_axes[1])

    # Annotate each row with the value of the first other axis it shows
    if other_axes:
        for j, ax in enumerate(axs[:, -1]):
            ax.text(1.05, 0.5, f'{other_axes[0]} = {grid.get_axis(other_axes[0]).values[j * step]}', transform=ax.transAxes, rotation=-90, va='center')

    # Only the first page is drawn
    break

print('done')

In [0]:
# Fetch a single model at the given parameter values.
# NOTE(review): this call passes `C=` whereas the later get_nearest_model /
# interpolate_model_linear calls pass `C_M=`, and it hard-codes `M_H` although
# `m_h` may have been detected as 'Fe_H' -- confirm the axis names of this grid.
spec = grid.get_model(T_eff=3500, log_g=1.5, M_H=-4.5, C=0.0, a_M=-0.0)

In [0]:
# Plot the model's `flux` and `cont` arrays against `wave` on the same axes
# (`cont` is presumably the continuum -- confirm).
plt.plot(spec.wave, spec.flux)
plt.plot(spec.wave, spec.cont)

In [0]:
# Plot the model's `line` array against `wave`.
plt.plot(spec.wave, spec.line)

In [0]:
# Index mesh over the (m_h, T_eff) plane, used by the scatter plots below.
n_m_h = axes[m_h].values.shape[0]
n_T_eff = axes['T_eff'].values.shape[0]
g = np.meshgrid(np.arange(n_m_h), np.arange(n_T_eff), indexing='ij')
(g[0].shape, g[1].shape)

In [0]:
# One scatter plot per log_g value: every (m_h, T_eff) grid point is coloured by
# the 'flux' value index at that point.
# NOTE(review): the slice below assumes the index array has exactly four axes
# ordered (m_h, T_eff, log_g, <fourth axis>) and always takes index 2 along the
# fourth one. The model lookups elsewhere in this notebook pass five parameters,
# so confirm the array layout before relying on these plots.
for i in range(axes['log_g'].values.shape[0]):
    f, ax = plt.subplots(1, 1, figsize=(4, 4), squeeze=True)
    #idx = grid.grid.value_indexes['flux'][:, :, i, 0, 0]
    idx = grid.grid.value_indexes['flux'][:, :, i, 2]
    #ax.plot(idx[g[0]].flatten())
    # `g` is the (m_h, T_eff) index mesh built in the previous cell
    ax.scatter(axes[m_h].values[g[0].flatten()], axes['T_eff'].values[g[1].flatten()], c=idx[g[0].flatten(), g[1].flatten()])
    ax.set_xlabel(r'[{}]'.format(m_h))
    ax.set_ylabel(r'$T_\mathrm{eff}$')
    f.suptitle('log g = {}'.format(axes['log_g'].values[i]))

# Plot some nice spectra

In [0]:
# Display the grid index nearest to the requested parameter values.
grid.grid.get_nearest_index(M_H=-1.5, T_eff=5000, log_g=1.5, a_M=0.0, C_M=0.0)

In [0]:
# Fetch the model nearest to the requested parameter values.
model = grid.get_nearest_model(M_H=-1.5, T_eff=5000, log_g=1.5, a_M=0.0, C_M=0.0)
# Variant of the same call for grids whose axis is named Fe_H instead of M_H:
#model = grid.get_nearest_model(Fe_H=-1.5, T_eff=5000, log_g=1.5, a_M=0.0)

In [0]:
# Plot the model flux restricted to 8200 <= wave <= 8500.
wave = model.wave
mask = (wave >= 8200) & (wave <= 8500)
plt.plot(wave[mask], model.flux[mask], lw=0.3)
#plt.plot(model.wave, model.cont)

In [0]:
# Plot the spacing between consecutive `wave` samples against their midpoints.
upper = model.wave[1:]
lower = model.wave[:-1]
plt.plot(0.5 * (upper + lower), upper - lower)

In [0]:
# NOTE(review): assigning `wave_lim` modifies the shared `grid` object. It
# presumably restricts the wavelength range of models fetched afterwards, so
# re-running earlier cells after this one may give different spectra -- confirm.
grid.wave_lim = [4000, 6000]

# Linearly interpolate a model at off-grid parameter values and show the shapes
# of its `wave` and `flux` arrays.
modelip = grid.interpolate_model_linear(M_H=-1.52, T_eff=5050, log_g=1.55, a_M=0.0, C_M=0.0)
modelip.wave.shape, modelip.flux.shape

In [0]:
# Overlay the nearest-grid-point model and the interpolated model.
# NOTE(review): `model` was fetched before `wave_lim` was set, so the two
# curves may span different ranges -- confirm this is intended.
plt.plot(model.wave, model.flux)
plt.plot(modelip.wave, modelip.flux)