In [0]:
# Both grids live under the same BOSZ 5000 GF model directory.
MODEL_ROOT = '/datascope/subaru/data/pfsspec/models/stellar/rbf/bosz/bosz_5000_GF'
FIT_GRID_PATH = MODEL_ROOT + '/fit'       # grid of fitted parameters (loaded as an ArrayGrid)
RBF_GRID_PATH = MODEL_ROOT + '/fit-rbf'   # RBF grid (loaded as an RbfGrid)

In [0]:
# Render matplotlib figures inline below each cell.
%matplotlib inline

In [0]:
import sys

# Make the project root importable so the `pfsspec` package can be loaded as a module.
# Only prepend it when it is not already first on sys.path, so re-running this cell
# does not keep stacking duplicate entries while still giving it top precedence.
PROJECT_ROOT = '../../../..'
if sys.path[:1] != [PROJECT_ROOT]:
    sys.path.insert(0, PROJECT_ROOT)

In [0]:
import os
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import h5py

In [0]:
import matplotlib

# Use a larger default font size for every figure in this notebook.
matplotlib.rcParams['font.size'] = 14

# Load stellar model grid

In [0]:
from pfsspec.core.grid import ArrayGrid, RbfGrid
from pfsspec.stellar.grid import ModelGrid
from pfsspec.stellar.grid.bosz import Bosz
from pfsspec.stellar.grid.phoenix import Phoenix

## Grid of fitted parameters

In [0]:
fit_grid = ModelGrid(Bosz(normalized=True), ArrayGrid)
fit_grid.preload_arrays = True  # NOTE(review): presumably loads the value arrays eagerly — confirm in pfsspec
fn = os.path.join(FIT_GRID_PATH, 'spectra.h5')
fit_grid.load(fn, format='h5')

In [0]:
# Inspect the `wave` attribute of the fitted-parameter grid (presumably the wavelength array).
fit_grid.wave

In [0]:
# List every axis of the fitted-parameter grid together with its node values.
for axis_name, axis in fit_grid.grid.axes.items():
    print(axis_name, axis.values)

## RBF grid

In [0]:
rbf_grid = ModelGrid(Bosz(normalized=True), RbfGrid)
rbf_grid.preload_arrays = False  # NOTE(review): presumably defers loading the arrays — confirm in pfsspec
fn = os.path.join(RBF_GRID_PATH, 'spectra.h5')
rbf_grid.load(fn, format='h5')

In [0]:
# Inspect the `wave` attribute of the RBF grid (presumably the wavelength array).
rbf_grid.wave

In [0]:
# List every axis of the RBF grid together with its node values.
for axis_name, axis in rbf_grid.grid.axes.items():
    print(axis_name, axis.values)

In [0]:
# Names of the value entries stored on the RBF grid.
rbf_grid.grid.values.keys()

# Plot fitted parameters and their RBF interpolation

In [0]:
# Reference point in parameter space; look up the nearest node of the RBF grid.
M_H, T_eff, log_g, a_M, C_M = 0, 4500, 1, 0, 0

rbf_idx = rbf_grid.grid.get_nearest_index(
    M_H=M_H, T_eff=T_eff, log_g=log_g, a_M=a_M, C_M=C_M,
)
rbf_idx

In [0]:
# Nearest node of the fitted-parameter grid to the same reference point.
# Its last two entries are used further down to pin the two trailing grid axes.
array_idx = fit_grid.grid.get_nearest_index(M_H=M_H, T_eff=T_eff, log_g=log_g, a_M=a_M, C_M=C_M)
array_idx

In [0]:
def n_subplot_rows(n_panels, ncols=4):
    """Number of subplot rows needed to lay out `n_panels` panels `ncols` per row.

    This is ceil(n_panels / ncols), but always at least 1, so the result can be
    passed straight to `plt.subplots`.
    """
    return max(1, -(-n_panels // ncols))


def plot_params(rbf_grid, array_grid, rbf, array, mask, idx=2, param_idx=0,
                fixed_idx=None, mode='array', use_mask=False, ncols=4):
    """Plot 2-D slices of one fitted parameter, with one subplot per node along axis `idx`.

    Parameters
    ----------
    rbf_grid : grid whose ``axes`` mapping defines the integer index lattice the
        RBF is evaluated on. Its axis `idx` also sets the number of subplots.
    array_grid : unused; kept so existing call sites keep working.
    rbf : callable that takes one coordinate array per grid axis (grid *index*
        coordinates) and returns values whose last dimension indexes parameters.
        Only called when `mode` is ``'rbf'`` or ``'diff'``.
    array : fitted values. The last dimension indexes parameters, and the leading
        ``array.ndim - 1`` dimensions are sliced along `idx`.
    mask : array with the same leading dimensions as `array`. It is only used when
        `use_mask` is true; cells where ``~mask`` holds are blanked out.
        NOTE(review): assumes a boolean validity mask — confirm.
    idx : axis to iterate over, one subplot per node of this axis.
    param_idx : which parameter (last dimension) to plot.
    fixed_idx : index sequence whose last two entries pin the last two RBF axes.
        Defaults to the notebook-level ``array_idx`` (the previous implicit behavior).
    mode : one of
        - ``'array'`` (default): the fitted values;
        - ``'rbf'``: the RBF evaluated on the grid nodes;
        - ``'diff'``: fitted minus RBF. The array slice and the RBF slice must
          then have the same shape.
    use_mask : if true, set cells where `mask` is false to NaN.
    ncols : number of subplot columns.

    Raises
    ------
    ValueError
        If `mode` is not one of the supported values.
    """
    if mode not in ('array', 'rbf', 'diff'):
        raise ValueError("mode must be 'array', 'rbf' or 'diff', got {!r}".format(mode))

    axes = [rbf_grid.axes[k] for k in rbf_grid.axes]
    n_panels = axes[idx].values.size
    nrows = n_subplot_rows(n_panels, ncols)
    # squeeze=False keeps `axs` 2-D even for a single row, so axs[row, col] always works.
    fig, axs = plt.subplots(nrows, ncols, figsize=(4 * ncols, 4 * nrows), squeeze=False)

    need_rbf = mode != 'array'
    if need_rbf:
        if fixed_idx is None:
            fixed_idx = array_idx
        # Integer index lattice over every axis. It is loop-invariant, so build it once.
        base_coords = [np.arange(a.values.size) for a in axes]

    for p in range(n_panels):
        ax = axs[p // ncols, p % ncols]

        # Select node p along `idx` in the leading (non-parameter) dims of `array`.
        sel = [slice(None)] * (array.ndim - 1)
        sel[idx] = p
        sel = tuple(sel)
        # Copy as float so NaN-masking never writes into the caller's `array`.
        array_values = np.array(array[sel][..., param_idx], dtype=float)

        if need_rbf:
            coords = list(base_coords)
            coords[idx] = np.array([p])
            coords[-2] = np.array([fixed_idx[-2]])
            coords[-1] = np.array([fixed_idx[-1]])
            mesh = np.meshgrid(*coords, indexing='ij')
            rbf_values = rbf(*[m.flatten() for m in mesh])
            rbf_values = rbf_values[..., param_idx].reshape(mesh[0].shape).squeeze()

        if mode == 'array':
            image = array_values
        elif mode == 'rbf':
            image = np.array(rbf_values, dtype=float)
        else:
            image = array_values - rbf_values

        if use_mask:
            image[~mask[sel]] = np.nan

        im = ax.imshow(image, aspect='auto')
        fig.colorbar(im, ax=ax)
        ax.set_xlabel('param: {} | slice: {}'.format(param_idx, p))

    # Hide the unused panels in the last row.
    for ax in axs.flat[n_panels:]:
        ax.set_axis_off()

In [0]:
# Shape of the 'blended_0' value array of the fitted-parameter grid.
fit_grid.array_grid.values['blended_0'].shape

In [0]:
# Number of node values along each axis of the RBF grid, keyed by axis name.
{ k: a.values.size for k, a in rbf_grid.rbf_grid.axes.items() }

In [0]:
# Inspect the RBF objects stored on the RBF grid, keyed by value name (e.g. 'blended_1').
rbf_grid.rbf_grid.values

In [0]:
rbf = rbf_grid.rbf_grid.values['blended_1']

# Keep all of axes 0 and 2 and the first 11 nodes of axis 1, and pin the last two
# axes at the nearest-node indices in `array_idx`.
# NOTE(review): the reason for the 11-node cut on axis 1 is undocumented — confirm.
fixed_slice = (slice(None), slice(None, 11), slice(None), array_idx[-2], array_idx[-1])
array = fit_grid.array_grid.values['blended_1'][fixed_slice]
mask = fit_grid.array_grid.value_indexes['blended_1'][fixed_slice]

array.shape, mask.shape

In [0]:
# One figure per fitted parameter; the last dimension of the RBF nodes indexes the parameters.
n_params = rbf.nodes.shape[-1]
for param in range(n_params):
    plot_params(rbf_grid.rbf_grid, fit_grid.array_grid, rbf, array, mask, param_idx=param)

In [0]:
def plot_params_hires(rbf_grid, rbf, idx=2, param_idx=0, resolution=50,
                      fixed_idx=None, ncols=4):
    """Plot one RBF-interpolated parameter on a dense lattice, one subplot per node along `idx`.

    Each free axis is sampled at `resolution` evenly spaced grid-index positions
    between its first and last node (inclusive), so the images show how the RBF
    behaves between grid nodes.

    Parameters
    ----------
    rbf_grid : grid whose ``axes`` mapping defines the index range of every axis.
        Its axis `idx` also sets the number of subplots.
    rbf : callable that takes one coordinate array per grid axis (grid *index*
        coordinates) and returns values whose last dimension indexes parameters.
    idx : axis to iterate over, one subplot per node of this axis.
    param_idx : which parameter (last dimension) to plot.
    resolution : number of samples per free axis (default 50, as before).
    fixed_idx : index sequence whose last two entries pin the last two axes.
        Defaults to the notebook-level ``array_idx`` (the previous implicit behavior).
    ncols : number of subplot columns.
    """
    if fixed_idx is None:
        fixed_idx = array_idx

    axes = [rbf_grid.axes[k] for k in rbf_grid.axes]
    n_panels = axes[idx].values.size
    # ceil(n_panels / ncols), but at least one row so plt.subplots stays valid.
    nrows = max(1, -(-n_panels // ncols))
    # squeeze=False keeps `axs` 2-D even for a single row, so axs[row, col] always works.
    fig, axs = plt.subplots(nrows, ncols, figsize=(4 * ncols, 4 * nrows), squeeze=False)

    # Sample each axis from its first to its LAST node index (size - 1). Going up to
    # `size` would extrapolate one step past the grid edge. This is loop-invariant.
    base_coords = [np.linspace(0, a.values.size - 1, resolution) for a in axes]

    for p in range(n_panels):
        ax = axs[p // ncols, p % ncols]

        coords = list(base_coords)
        coords[idx] = np.array([p])
        coords[-2] = np.array([fixed_idx[-2]])
        coords[-1] = np.array([fixed_idx[-1]])
        mesh = np.meshgrid(*coords, indexing='ij')

        rbf_values = rbf(*[m.flatten() for m in mesh])
        image = rbf_values[..., param_idx].reshape(mesh[0].shape).squeeze()

        im = ax.imshow(image, aspect='auto')
        fig.colorbar(im, ax=ax)
        ax.set_xlabel('param: {} | slice: {}'.format(param_idx, p))

    # Hide the unused panels in the last row.
    for ax in axs.flat[n_panels:]:
        ax.set_axis_off()

In [0]:
# Re-select the 'blended_1' RBF (the same lookup as in the slicing cell above),
# presumably so the hi-res plots below can be re-run on their own.
rbf = rbf_grid.rbf_grid.values['blended_1']

In [0]:
# One hi-res figure per parameter; the last dimension of the RBF nodes indexes the parameters.
n_params = rbf.nodes.shape[-1]
for param in range(n_params):
    plot_params_hires(rbf_grid.rbf_grid, rbf, param_idx=param)