In [0]:
# Configuration: paths of the grids compared in this notebook.
# Only the last (uncommented) block is active; the commented-out blocks are
# alternative setups kept for reference.

#PCA_GRID_PATH = '/datascope/subaru/data/pfsspec/models/stellar/rbf/phoenix/phoenix_HiRes_GK/pca'
#PCA_GRID_PATH = '/datascope/subaru/data/pfsspec/models/stellar/rbf/bosz/bosz_50000_GK/pca'

#GRID_PATH = '/datascope/subaru/data/pfsspec/models/stellar/rbf/phoenix/phoenix_HiRes_GK/norm'
#PCA_GRID_PATH = '/datascope/subaru/data/pfsspec/models/stellar/grid/phoenix/phoenix_HiRes_norm_pca_6000/'
#NORMALIZED = True

#GRID_PATH = '/datascope/subaru/data/pfsspec/models/stellar/grid/phoenix/phoenix_HiRes'
#PCA_GRID_PATH = '/datascope/subaru/data/pfsspec/models/stellar/grid/phoenix/phoenix_HiRes_pca_6000/'
#NORMALIZED = False

# Original model grid; its spectra.h5 is loaded as a non-normalized grid.
GRID_PATH = '/datascope/subaru/data/pfsspec/models/stellar/grid/phoenix/phoenix_HiRes'
# HDF5 file whose per-model 'weight' values are read in the "Error" section.
WEIGHT_GRID_PATH = '/datascope/subaru/data/pfsspec/models/stellar/rbf/phoenix/phoenix_HiRes_GK/norm/mask.h5'
# Normalized grid; loaded with normalized=True and compared against the PCA grid.
NORM_GRID_PATH = '/datascope/subaru/data/pfsspec/models/stellar/rbf/phoenix/phoenix_HiRes_GK/norm'
# PCA-compressed grid; loaded with pca=True. The error cache error.h5 is also stored here.
PCA_GRID_PATH = '/datascope/subaru/data/pfsspec/models/stellar/rbf/phoenix/phoenix_HiRes_GK/pca_none_weights_3'
# Passed as `normalized=` when the PCA grid is loaded.
NORMALIZED = True

In [0]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [0]:
import sys

# Make the directory four levels up importable so that the project package
# (`pfs.ga.pfsspec`, imported below) can be loaded as a module -- presumably
# the repository root; adjust if the notebook is moved.
sys.path.insert(0, '../../../..')

In [0]:
import os
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import h5py
from tqdm.notebook import tqdm

In [0]:
import matplotlib

# Small default font size for every figure in this notebook.
matplotlib.rcParams['font.size'] = 7

# Load stellar model grid

In [0]:
from pfs.ga.pfsspec.core.grid import ArrayGrid
from pfs.ga.pfsspec.stellar.grid import ModelGrid
from pfs.ga.pfsspec.stellar.grid.bosz import Bosz
from pfs.ga.pfsspec.stellar.grid.phoenix import Phoenix

In [0]:
fn = os.path.join(GRID_PATH, 'spectra.h5')
grid = ModelGrid(Phoenix(pca=False, normalized=False), ArrayGrid)
grid.preload_arrays = False
grid.load(fn, format='h5')

In [0]:
fn = WEIGHT_GRID_PATH
weight_grid = ModelGrid(Phoenix(pca=False, normalized=False), ArrayGrid)
weight_grid.preload_arrays = False
weight_grid.load(fn, format='h5')

In [0]:
fn = os.path.join(NORM_GRID_PATH, 'spectra.h5')
norm_grid = ModelGrid(Phoenix(pca=False, normalized=True), ArrayGrid)
norm_grid.preload_arrays = False
norm_grid.load(fn, format='h5')

In [0]:
fn = os.path.join(PCA_GRID_PATH, 'spectra.h5')
#pca_grid = ModelGrid(Bosz(pca=True, normalized=True), ArrayGrid)
pca_grid = ModelGrid(Phoenix(pca=True, normalized=NORMALIZED), ArrayGrid)
pca_grid.preload_arrays = False
pca_grid.load(fn, format='h5')

In [0]:
# Compare the wavelength arrays of the original and the PCA grid.
grid.wave, pca_grid.wave

In [0]:
# Set the original grid's wavelength limits to the PCA grid's wavelength
# range and show the resulting wavelength slice.
grid.wave_lim = [pca_grid.wave.min(), pca_grid.wave.max()]
grid.get_wave_slice()

In [0]:
# Wavelengths reported by both grids after setting wave_lim above.
grid.get_wave(), pca_grid.get_wave()

In [0]:
# Same wavelength limits for the normalized grid.
norm_grid.wave_lim = [pca_grid.wave.min(), pca_grid.wave.max()]
norm_grid.get_wave_slice()

In [0]:
# Parameter axes (name and values) of the original grid.
for _, k, axis in grid.enumerate_axes():
    print(k, axis.values)

In [0]:
# Parameter axes (name and values) of the PCA grid.
for _, k, axis in pca_grid.enumerate_axes():
    print(k, axis.values)

In [0]:
# Value names with stored eigs, and the stored shape of the 'flux' values
# (presumably the number of PCA coefficients per model -- verify).
pca_grid.grid.eigs.keys(), pca_grid.array_grid.value_shapes['flux']

## Count valid spectra

In [0]:
# Shape of the 'flux' value index of the original grid and the number of
# entries set in it (presumably the number of grid points with a spectrum).
grid.array_grid.value_indexes['flux'].shape, np.sum(grid.array_grid.value_indexes['flux'])

In [0]:
# Same count for the PCA grid.
pca_grid.array_grid.value_indexes['flux'].shape, np.sum(pca_grid.array_grid.value_indexes['flux'])

# Plot

In [0]:
# Atmospheric parameters of the model used in the plots below; the nearest
# grid point is looked up in each of the three grids.
M_H = -1.0
T_eff = 4000
log_g = 5
a_M = 0
C_M = 0

# Alternative test model:
# M_H = -2.5
# T_eff = 4000
# log_g = 1
# a_M = 0
# C_M = 0

# Index of the nearest model in the PCA grid.
pca_idx = pca_grid.grid.get_nearest_index(M_H=M_H, T_eff=T_eff, log_g=log_g, a_M=a_M, C_M=C_M)
pca_idx

In [0]:
# Index of the nearest model in the normalized grid.
norm_idx = norm_grid.grid.get_nearest_index(M_H=M_H, T_eff=T_eff, log_g=log_g, a_M=a_M, C_M=C_M)
norm_idx

In [0]:
# Index of the nearest model in the original grid.
idx = grid.grid.get_nearest_index(M_H=M_H, T_eff=T_eff, log_g=log_g, a_M=a_M, C_M=C_M)
idx

In [0]:
# Normalized model, fetched with the default `denormalize` setting.
norm_model = norm_grid.get_model_at(norm_idx)
plt.plot(norm_model.wave, norm_model.flux)

In [0]:
# Inspect the PCA grid's `transform` attribute.
pca_grid.pca_grid.transform

In [0]:
# PCA reconstruction of the same model, with the default `denormalize` setting.
pca_model = pca_grid.get_model_at(pca_idx)
plt.plot(pca_model.wave, pca_model.flux)

In [0]:
# Difference normalized - PCA flux, zoomed to +-0.005; assumes both models
# share the same wavelength grid.
plt.plot(pca_model.wave, norm_model.flux - pca_model.flux)
plt.ylim(-0.005, 0.005)

In [0]:
# Original model from the non-normalized grid, with denormalize=True.
model = grid.get_model_at(idx, denormalize=True)
plt.plot(model.wave, model.flux)

In [0]:
# Normalized model again, this time with denormalize=True (replaces norm_model).
norm_model = norm_grid.get_model_at(norm_idx, denormalize=True)
plt.plot(norm_model.wave, norm_model.flux)

In [0]:
# PCA reconstruction with denormalize=True (replaces pca_model).
pca_model = pca_grid.get_model_at(pca_idx, denormalize=True)
plt.plot(pca_model.wave, pca_model.flux)

In [0]:
# Error estimate stored with the PCA reconstruction, without denormalization.
pca_model = pca_grid.get_model_at(pca_idx, denormalize=False)
plt.plot(pca_model.wave, pca_model.flux_err)
plt.xlabel('wavelength')
plt.title('Relative error after PCA compression')

In [0]:
# flux_err of the denormalized PCA reconstruction.
# NOTE(review): with denormalize=True this may no longer be a relative error
# (the next cell divides it by flux) -- confirm the title.
f, ax = plt.subplots(1, 1, dpi=120)

pca_model = pca_grid.get_model_at(pca_idx, denormalize=True)
ax.plot(pca_model.wave, pca_model.flux_err, lw=0.5)
ax.set_title('Relative error after PCA compression')

In [0]:
# flux_err divided by the denormalized flux.
f, ax = plt.subplots(1, 1, dpi=120)

pca_model = pca_grid.get_model_at(pca_idx, denormalize=True)
ax.plot(pca_model.wave, pca_model.flux_err / pca_model.flux)
ax.set_title('Relative error after PCA compression and denormalization')

In [0]:
# Change the PCA grid's `k` (presumably the number of components used in the
# reconstruction). This persists for all following cells.
pca_grid.grid.k = 2000

In [0]:
# Denormalized reconstruction with the new k, scaled by the total flux of the
# original `model` fetched above.
pca_model = pca_grid.get_model_at(pca_idx, denormalize=True)
plt.plot(pca_model.wave, pca_model.flux * model.flux.sum())

In [0]:
# Relative error of the denormalized PCA reconstruction against the original
# model; the commented-out lines are alternative normalizations.
#plt.plot(model.wave, (pca_model.flux - model.flux / model.flux.sum()) / (model.flux / model.flux.sum()))
plt.plot(model.wave, (pca_model.flux - model.flux) / model.flux)
#plt.plot(model.wave, (pca_model.flux - model.flux))
#plt.plot(model.wave, (pca_model.flux * model.flux.sum() - model.flux) / model.flux)

plt.ylim(-0.005, 0.005)
#plt.ylim(-1, 1)

plt.xlabel('wavelength')
plt.ylabel('relative error')

In [0]:
# Absolute (not relative) difference between the denormalized PCA
# reconstruction and the original model.
plt.plot(model.wave, (pca_model.flux - model.flux))

plt.xlabel('wavelength')
plt.ylabel('absolute error')

## Reconstruction error assuming normalized models

In [0]:
# Re-fetch the normalized model with the default `denormalize` setting; the
# cells above replaced norm_model with a denormalized version.
norm_model = norm_grid.get_model_at(norm_idx)
plt.plot(norm_model.wave, norm_model.flux)

In [0]:
# Same for the PCA reconstruction; uses the `k` value set earlier.
pca_model = pca_grid.get_model_at(pca_idx)
plt.plot(pca_model.wave, pca_model.flux)

In [0]:
f, ax = plt.subplots(1, 1, figsize=(3.5, 2), dpi=200)

ax.plot(pca_model.wave, pca_model.flux - norm_model.flux, lw=0.3)

ax.set_xlabel(r'$\lambda$ [A]')
ax.set_ylabel(r'$F_\mathrm{PCA} - F_\mathrm{model}$')
ax.set_ylim(-0.005, 0.005)
ax.grid()
ax.set_title('Reconstruction error\n'
r'$[M/H] = {}$, $T_\mathrm{{eff}} = {}$ K, $\log\,g = {}$'.format(pca_model.M_H, pca_model.T_eff, pca_model.log_g))

In [0]:
f, ax = plt.subplots(1, 1, figsize=(3.5, 2), dpi=200)

ax.plot(pca_model.wave, (pca_model.flux - norm_model.flux) / norm_model.flux, lw=0.3)

ax.set_xlabel(r'$ \lambda $ [A]')
ax.set_ylabel(r'$ \frac{ F_\mathrm{PCA} - F_\mathrm{model} }{ F_\mathrm{model} } $')
ax.set_ylim(-0.005, 0.005)
ax.grid()
ax.set_title('Relative reconstruction error\n'
r'$[M/H] = {}$, $T_\mathrm{{eff}} = {}$ K, $\log\,g = {}$'.format(pca_model.M_H, pca_model.T_eff, pca_model.log_g))

In [0]:
# Parameters of the original model used in the comparisons above.
model.get_params()

# Error

In [0]:
# Current value of pca_grid.grid.k, which was set to 2000 in the "Plot" section
# (presumably the number of components used for reconstruction).
pca_grid.grid.k

In [0]:
# Grid coordinates of every valid PCA model, one row per model: shape (N, ndim).
pca_idxs = np.stack(np.where(pca_grid.grid.grid.value_indexes['flux']), axis=-1)
pca_idxs.shape

In [0]:
from pfs.ga.pfsspec.core.util import SmartParallel

In [0]:
fn = os.path.join(PCA_GRID_PATH, 'error.h5')
if os.path.isfile(fn):
    with h5py.File(fn, 'r') as h:
        err = h['error'][()]
else:
    def process_item(v):
        i, idx = v

        pca_model = pca_grid.get_model_at(idx, denormalize=False)
        norm_model = norm_grid.get_model_at(idx, denormalize=False)
        #return  i, pca_model.flux / norm_model.flux - 1
        return  i, pca_model.flux - norm_model.flux

    err = {}

    N = pca_idxs.shape[0]
    t = tqdm(total=N)
    with SmartParallel(verbose=False, parallel=True, threads=24) as p:
        for zz in p.map(process_item, list(zip(range(N), pca_idxs[:N]))):
            i, e = zz
            err[i] = e
            t.update(1)
            
    err = np.stack([ err[i] for i in range(N) ], axis=0)

In [0]:
err.shape

In [0]:
if not os.path.isfile(fn):
    with h5py.File(fn, 'w') as h:
        h.create_dataset('error', data=err)

In [0]:
# Weight of every valid PCA model, read from the weight grid in the same order
# as the rows of `err`. A comprehension keeps the loop variable from leaking
# into (and overwriting) the global `idx`.
weight = np.array([weight_grid.array_grid.get_value_at('weight', pca_i) for pca_i in pca_idxs])
# Flatten to one value per model. A trailing length-1 axis would make
# `weight[:, np.newaxis] * err` broadcast to an N x N x W array.
weight = weight.reshape(-1)
# The weighting and masking below need exactly one weight per model.
assert weight.shape == (pca_idxs.shape[0],), weight.shape
weight.shape

In [0]:
f, ax = plt.subplots(1, 1, dpi=120)
ax.set_title('weighted mean')
# Weighted mean of |error| over all models, one weight per model (row of err).
# ravel() guards against a trailing length-1 axis on `weight`, which would
# otherwise broadcast against err into an N x N x W array.
w = np.ravel(weight)
ax.plot(pca_model.wave, np.sum(w[:, np.newaxis] * np.abs(err), axis=0) / np.sum(w), lw=0.3)
ax.set_xlabel('wavelength')
ax.set_ylabel('weighted mean |error|')
ax.semilogy()

In [0]:
# Per-model weights loaded above.
weight

In [0]:
# Models with positive weight; squeeze() drops any length-1 axes so the mask
# can select rows of err.
mask = (weight > 0).squeeze()
mask.shape, mask.sum()

In [0]:
err.shape

In [0]:
# Per-wavelength median / min / max of |error| over models with positive weight.
# |err[mask]| is a full copy of the selected rows, so build it once rather
# than once per statistic, and release it afterwards.
abs_err = np.abs(err[mask])
stat = {tt: ff(abs_err, axis=0)
        for tt, ff in zip(['median', 'min', 'max'], [np.median, np.min, np.max])}
del abs_err

In [0]:
# Median, min and max |error| against wavelength on a log scale.
f, ax = plt.subplots(1, 1, figsize=(3.5, 3), dpi=120)

for tt in stat:
    ax.plot(pca_model.wave, stat[tt], lw=0.1, label=tt)

ax.set_ylim(0.8e-15, 0.99)
ax.semilogy()
ax.legend()
ax.grid()

In [0]:
# Median |error| only.
f, ax = plt.subplots(1, 1, figsize=(3.5, 3), dpi=120)

ax.plot(pca_model.wave, stat['median'], lw=0.1, label='median')

#ax.set_ylim(0.8e-15, 0.99)
ax.semilogy()
ax.legend()
ax.grid()

# Error map

In [0]:
def find_bins(x):
    """Return bin edges for the bin centers ``x``.

    Interior edges are the midpoints between neighbouring centers. With three
    or more centers, the outer edges are extrapolated so that the first/last
    bin is as wide as its neighbour. With two centers there is only one
    interior edge, so the outer edges mirror it around the outer centers.
    A single center gets a bin of unit width.

    Parameters
    ----------
    x : array_like, shape (n,)
        Bin centers, assumed monotonic; n >= 1.

    Returns
    -------
    numpy.ndarray, shape (n + 1,)
        Bin edges.

    Raises
    ------
    ValueError
        If ``x`` is empty.
    """
    x = np.asarray(x)
    n = x.shape[0]
    if n == 0:
        raise ValueError('find_bins needs at least one bin center')
    if n == 1:
        return np.array([x[0] - 0.5, x[0] + 0.5], dtype=float)

    y = np.empty(n + 1)
    y[1:-1] = 0.5 * (x[1:] + x[:-1])
    if n == 2:
        # Only one interior edge exists, so y[2] cannot be used (it is unset).
        y[0] = 2 * x[0] - y[1]
        y[-1] = 2 * x[-1] - y[-2]
    else:
        y[0] = 2 * y[1] - y[2]
        y[-1] = 2 * y[-2] - y[-3]

    return y

T_eff = find_bins(pca_grid.array_grid.axes['T_eff'].values)
log_g = find_bins(pca_grid.array_grid.axes['log_g'].values)

T_eff, log_g

In [0]:
np.median(err, axis=-1).shape

In [0]:
pca_idxs = np.where(pca_grid.grid.grid.value_indexes['flux'])
pca_idxs

In [0]:
for i, k, ax in pca_grid.enumerate_axes():
    print(k, ax.values.size)

In [0]:
errmap = np.full(pca_grid.get_shape(), np.nan)
errmap[pca_idxs] = np.median(np.abs(err), axis=-1)

In [0]:
# shp = (pca_grid.array_grid.axes['M_H'].values.size, pca_grid.array_grid.axes['T_eff'].values.size)
# errmap = np.full(shp, np.nan)

# errmap[pca_idxs[..., 0], pca_idxs[..., 1]] = np.median(np.abs(err), axis=-1)

errmap.shape

In [0]:
f, ax = plt.subplots(1, 1, figsize=(4, 3), dpi=120)

l = ax.imshow(np.nanmax(errmap, axis=(0, 3)).T,
              extent=[T_eff.min(), T_eff.max(), log_g.min(), log_g.max()], origin='lower', aspect=200,
              vmin=0, vmax=0.0003)
ax.invert_xaxis()
ax.invert_yaxis()

ax.set_title('Median error')

f.colorbar(l)

# Plot eigenvectors

In [0]:
# Mean 'flux' vector stored with the PCA grid.
mean = pca_grid.grid.mean['flux']
mean.shape

In [0]:
# Plot the mean against the PCA grid's wavelengths.
f, ax = plt.subplots(1, 1, figsize=(6, 2), dpi=96)

ax.plot(pca_grid.wave, mean, lw=0.1)
ax.grid()

In [0]:
# Eigenvectors for 'flux'; column i is plotted against wavelength below.
eigv = pca_grid.grid.eigv['flux']
eigv.shape

In [0]:
# Plot the first few eigenvectors (columns of eigv) against wavelength, one
# panel each. `n_eigv` avoids clobbering the global N from the "Error" section.
n_eigv = min(5, eigv.shape[-1])

# squeeze=False keeps `axs` 2-D even when only one panel is drawn; otherwise
# plt.subplots returns a bare Axes that cannot be iterated.
f, axs = plt.subplots(n_eigv, 1, figsize=(6, 2 * n_eigv), dpi=96, squeeze=False)
axs = axs[:, 0]

for i, ax in enumerate(axs):
    ax.plot(pca_grid.wave, eigv[:, i], lw=0.1)
    ax.grid()

# Only the bottom panel keeps its wavelength tick labels.
for ax in axs[:-1]:
    ax.set_xticklabels([])
axs[-1].set_xlabel('wavelength')

f.tight_layout()

## Scree plot

In [0]:
S = pca_grid.grid.eigs['flux']
S

In [0]:
Y = np.log10(1 - np.cumsum(S**2) / np.sum(S**2))
X = np.log10(np.arange(Y.shape[0]) + 1)

f, ax = plt.subplots(1, 1, figsize=(3.5, 2.5), dpi=120)

ax.plot(X, Y)

ax.set_xlabel(r'$\log_{10} \, k$')
ax.set_ylabel(r'$\log_{10} \left( 1 - \frac{\sum_{i=1}^k S_i^2}{\sum_{i=1}^N S_i^2} \right)$')

#ax.set_ylim(-6, 0)
ax.grid()

## Basis error

We need the full basis first; then we sum the contributions of all components beyond the truncation $k$ to obtain the residual error as a function of wavelength, $\sigma^2_k(\lambda) = \sum_{i \ge k} S_i^2 U_{\lambda i}^2$.

In [0]:
S = pca_grid.grid.eigs['flux']
U = pca_grid.grid.eigv['flux']
S.shape, U.shape

In [0]:
sigma_wk = np.cumsum((S**2 * U**2)[:, ::-1], axis=1)[:, ::-1]
sigma_wk.shape

In [0]:
# Propagation of error

# SQRT -> ()^2 -> 
# def err_tr(sigma_2):
    

In [0]:
f, ax = plt.subplots(1, 1, figsize=(3.5, 2), dpi=200)

ax.plot(pca_grid.wave, np.sqrt(sigma_wk[:, 1000]), lw=0.3, label='k=1000')
ax.plot(pca_grid.wave, np.sqrt(sigma_wk[:, 2000]), lw=0.3, label='k=2000')
ax.plot(pca_grid.wave, np.sqrt(sigma_wk[:, 3000]), lw=0.3, label='k=3000')

ax.set_xlabel(r'$\lambda$ [A]')
ax.set_ylabel(r'$\sigma^2_k(\lambda)$')
ax.set_ylim(None, 0.001)
ax.grid()
ax.set_title('Residual error')
ax.legend()

## Spectral information

In [0]:
f, ax = plt.subplots(figsize=(7, 4), dpi=120)

# Contribution of a single component k (negative: counted from the end) at
# each wavelength: U[:, k]**2 weighted by its eigs entry. S is
# pca_grid.grid.eigs['flux'] (loaded in "Basis error"); the `L` alias used in
# the next section is not defined yet at this point.
k = -100
ax.plot(pca_grid.wave, U[:, k]**2 * S[k], lw=0.3, label=str(k))

ax.semilogy()
ax.legend()

## Leftover noise

In [0]:
U = pca_grid.grid.eigv['flux']
L = pca_grid.grid.eigs['flux']
U.shape, L.shape

In [0]:
(U**2 * L).shape

In [0]:
UU = np.cumsum(U[:, ::-1]**2 * L[::-1], axis=1, )[:, ::-1]
UU.shape

In [0]:
f, ax = plt.subplots(figsize=(7, 4), dpi=120)

for k in [500, 1000, 2000, 5000]:
    l = ax.plot(pca_grid.wave, UU[:, k], lw=0.3, label="k = {}".format(k))
    ax.axhline(UU[:, k].min(), c=l[0].get_color())
    ax.axhline(UU[:, k].max(), c=l[0].get_color())

    ax.semilogy()
    
ax.set_title('"Leftover noise"')
ax.legend()

# Leverage score

In [0]:
U = pca_grid.grid.eigv['flux']
U.shape

In [0]:
UU = np.cumsum(U**2, axis=1)
UU.shape

In [0]:
f, ax = plt.subplots(figsize=(7, 4), dpi=120)

for k in [100, 200, 300, 500, 1000, 2000]:
    ax.plot(pca_grid.wave, UU[:, k], lw=0.3, label=str(k))
ax.semilogy()

ax.set_ylim(1e-7, 0)
ax.legend()

In [0]:
f, ax = plt.subplots(figsize=(7, 4), dpi=120)

for k, l in zip([100, 200, 300, 500, 1000], [0, 100, 200, 300, 500]):
    ax.plot(pca_grid.wave, UU[:, k] - UU[:, l], lw=0.3, label='{} - {}'.format(k, l))
ax.semilogy()

ax.set_ylim(1e-7, 0)
ax.legend()

# Plot principal components

In [0]:
# Slice of the PCA parameter array at the selected model's index on the last
# grid axis. Uses pca_idx (the index into pca_grid, which the params below come
# from) rather than `idx`, which indexes `grid` and is also used as a loop
# variable elsewhere in this notebook.
s = np.s_[:, :, :, pca_idx[-1]]

In [0]:
def load_params(name, grid=None):
    """Return the stored values of `name` and their validity mask.

    Parameters
    ----------
    name : str
        Value name, e.g. 'flux'.
    grid : ModelGrid, optional
        Grid to read from; defaults to the notebook's ``pca_grid``.

    Returns
    -------
    (params, masks)
        ``grid.array_grid.get_value(name)`` and
        ``grid.array_grid.value_indexes[name]``.
    """
    if grid is None:
        grid = pca_grid
    params = grid.array_grid.get_value(name)
    masks = grid.array_grid.value_indexes[name]
    return params, masks

In [0]:
def plot_params(params, idx=2, param_idx=0):
    """Show one parameter of `params` as images, one panel per slice along axis `idx`.

    Panels are laid out four per row; unused panels in the last row are hidden.

    Parameters
    ----------
    params : numpy.ndarray
        Array whose last axis indexes the parameters. After fixing axis `idx`
        and selecting `param_idx`, each panel must be a 2-D image.
    idx : int
        Axis (other than the last) along which the panels are taken.
    param_idx : int
        Index into the last axis of the parameter to display.
    """
    ncols = 4
    pp = params.shape[idx]
    rr = max(1, int(np.ceil(pp / ncols)))
    # squeeze=False keeps `axs` 2-D even with a single row, so axs[i, j] works.
    f, axs = plt.subplots(rr, ncols, figsize=(16, 4 * rr), squeeze=False)
    for p in range(pp):
        i, j = divmod(p, ncols)

        # Select slice p along `idx`, keep every other non-parameter axis.
        s = [slice(None)] * (params.ndim - 1)
        s[idx] = p
        img = params[tuple(s)][..., param_idx]

        # Each panel uses its own color scale.
        l = axs[i, j].imshow(img, aspect='auto')
        f.colorbar(l, ax=axs[i, j])
        axs[i, j].set_xlabel('param: {} | slice: {}'.format(param_idx, p))

    for p in range(pp, rr * ncols):
        axs.flat[p].set_visible(False)

In [0]:
# Stored 'flux' values of the PCA grid and their validity mask; the last axis
# of `params` indexes the parameters.
params, masks = load_params('flux')
params.shape[-1]

In [0]:
# One figure for each of the first (at most) 15 parameters, each showing the
# slice `s` chosen above, split into panels along axis 2.
for pi in range(min(15, params.shape[-1])):
    plot_params(params[s], param_idx=pi)