In [0]:
# Colon-separated list of source directories; it is split on ':' and prepended to sys.path further down.
PROJECT_PATH = '/home/dobos/project/pfs_isochrones/python:/home/dobos/project/pfsspec-all:/home/dobos/project/pysynphot'

# Directory containing the isochrone file `isochrones.h5` that is loaded below.
ISOCHRONES_PATH = '/datascope/subaru/data/isochrones/dartmouth/import/afep0_cfht_sdss_hsc'
# Directory of the spectrum grid (`spectra.h5`); `weights.h5` and `mask.h5` are written here as well.
GRID_PATH = '/datascope/subaru/data/pfsspec/models/stellar/rbf/phoenix/phoenix_HiRes_GK/norm'

In [0]:
# Show the individual directories encoded in PROJECT_PATH.
PROJECT_PATH.split(':')

In [0]:
%matplotlib inline

In [0]:
import os, sys
import matplotlib.pyplot as plt
import numpy as np
import h5py as h5
from scipy.ndimage import binary_dilation

In [0]:
# Prepend the project directories to the module search path, keeping the order
# in which they are listed in PROJECT_PATH.
sys.path[:0] = PROJECT_PATH.split(':')

In [0]:
# Display the module search path to check that the project directories come first.
sys.path

In [0]:
# Set before TensorFlow is imported in the next cell; the empty device list is
# meant to hide all GPUs so that the notebook runs on the CPU.
os.environ['CUDA_VISIBLE_DEVICES'] = ''

In [0]:
# TensorFlow is imported only after CUDA_VISIBLE_DEVICES has been set in the previous cell.
import tensorflow.compat.v2 as tf
# Switch on TF2 behavior explicitly via the compat module.
tf.enable_v2_behavior()

In [0]:
# List the GPUs TensorFlow can see; presumably expected to be empty given the
# CUDA_VISIBLE_DEVICES setting above.
tf.config.list_physical_devices('GPU') 

# Load isochrones

In [0]:
from pfs.ga.isochrones.isogrid import IsoGrid

In [0]:
# Create an empty isochrone grid and fill it from `isochrones.h5` under ISOCHRONES_PATH.
iso = IsoGrid()
iso.load(os.path.join(ISOCHRONES_PATH, 'isochrones.h5'))

In [0]:
# Keys of the axes the isochrone grid is defined on.
iso.axes.keys()

In [0]:
# Data sanity check: for every isochrone quantity print its key followed by
# the number of infinite and the number of NaN entries.
for k in iso.values.keys():
    values = iso.values[k]
    n_inf = tf.math.count_nonzero(tf.math.is_inf(values))
    n_nan = tf.math.count_nonzero(tf.math.is_nan(values))
    print(k, n_inf.numpy(), n_nan.numpy())

In [0]:
# Scatter plot of all isochrone points: log g against log T_eff, colored by Fe_H.
log_teff = iso.values['Log_T_eff']
X = log_teff.numpy().flatten()
Y = iso.values['log_g'].numpy().flatten()
# Fe_H varies along the first axis only; expand it to one color value per point.
C = tf.broadcast_to(iso.Fe_H[:, tf.newaxis, tf.newaxis], log_teff.shape)

fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(6, 6), dpi=120)
ax.scatter(X, Y, c=C, s=0.1, edgecolor='None', rasterized=True)

# Temperature decreases towards the right.
ax.invert_xaxis()
ax.set_xlabel('log T_eff')
ax.set_ylabel('log g')

# Load spectrum grid

In [0]:
from pfs.ga.pfsspec.core.grid import ArrayGrid
from pfs.ga.pfsspec.stellar.grid import ModelGrid
from pfs.ga.pfsspec.stellar.grid.bosz import Bosz
from pfs.ga.pfsspec.stellar.grid.phoenix import Phoenix

In [0]:
# Open the normalized Phoenix spectrum grid stored as `spectra.h5` under GRID_PATH.
fn = os.path.join(GRID_PATH, 'spectra.h5')
grid = ModelGrid(Phoenix(normalized=True), ArrayGrid)
# Set before load(); presumably keeps the large arrays on disk instead of
# reading them into memory -- TODO confirm against ModelGrid.
grid.preload_arrays = False
grid.load(fn, format='h5')

In [0]:
# Grid nodes along the log_g axis of the spectrum grid.
grid.array_grid.axes['log_g'].values

In [0]:
# Grid nodes along the T_eff axis of the spectrum grid.
grid.array_grid.axes['T_eff'].values

In [0]:
# NOTE(review): this repeats the log_g display from two cells above; possibly a
# different axis was meant to be shown here.
grid.array_grid.axes['log_g'].values

In [0]:
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(14, 6), dpi=120)

# Nodes of the spectrum grid: one open square per (T_eff, log g) combination.
grid_axes = grid.array_grid.axes
nodes_teff, nodes_logg = np.meshgrid(grid_axes['T_eff'].values, grid_axes['log_g'].values)
ax.plot(nodes_teff.flatten(), nodes_logg.flatten(), 's', fillstyle='none')

# Isochrone points drawn over the grid nodes, colored by Fe_H. The isochrones
# store log T_eff, hence the 10**X to get the same units as the grid axis.
log_teff = iso.values['Log_T_eff']
X = log_teff.numpy().flatten()
Y = iso.values['log_g'].numpy().flatten()
C = tf.broadcast_to(iso.Fe_H[:, tf.newaxis, tf.newaxis], log_teff.shape)
ax.scatter(10**X, Y, c=C, s=0.1, edgecolor='None', rasterized=True)

# Logarithmic temperature axis, with both axes running from high to low values.
ax.set_xscale('log')
ax.invert_xaxis()
ax.invert_yaxis()

ax.set_xlabel('T_eff')
ax.set_ylabel('log g')

# Generate the grid limits

In [0]:
def find_bins(x):
    """
    Convert the positions of grid points into bin edges.

    Interior edges are the midpoints between neighboring grid points. The two
    outermost edges are extrapolated by continuing the spacing of the two
    nearest interior edges. With only two grid points there is a single
    interior edge, so it is mirrored about each grid point instead.

    Parameters
    ----------
    x : array_like, shape (n,), n >= 2
        Grid point positions along one axis.

    Returns
    -------
    numpy.ndarray, shape (n + 1,)
        Bin edges as floats; each grid point lies inside its own bin.

    Raises
    ------
    ValueError
        If `x` is not one-dimensional or has fewer than two points, in which
        case no bin width can be derived.
    """
    x = np.asarray(x, dtype=float)
    if x.ndim != 1 or x.shape[0] < 2:
        raise ValueError('find_bins requires a one-dimensional array of at least two grid points')

    y = np.empty(x.shape[0] + 1)
    y[1:-1] = 0.5 * (x[1:] + x[:-1])

    if x.shape[0] == 2:
        # y[2] is the last edge itself here, so the general formula below
        # would read uninitialized memory; mirror the only midpoint instead.
        y[0] = 2 * x[0] - y[1]
        y[-1] = 2 * x[-1] - y[1]
    else:
        y[0] = 2 * y[1] - y[2]
        y[-1] = 2 * y[-2] - y[-3]

    return y

# Bin edges along the T_eff and log_g axes, derived from the spectrum grid nodes.
T_eff, log_g = (
    find_bins(grid.array_grid.axes[name].values) for name in ('T_eff', 'log_g')
)

T_eff, log_g

In [0]:
# Flattened isochrone coordinates; 10**X below converts log T_eff to T_eff so
# that it matches the T_eff bin edges.
X = iso.values['Log_T_eff'].numpy().flatten()
Y = iso.values['log_g'].numpy().flatten()

# Number of isochrone points per (T_eff, log g) bin. Only the counts are
# needed, so the result is indexed rather than unpacked into `_`, which would
# override IPython's `_` (last output) variable for the rest of the session.
hist = np.histogram2d(10**X, Y, bins=(T_eff, log_g))[0]

In [0]:
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(4, 3), dpi=120)

# The image corners follow the outermost bin edges; hist is transposed so that
# log g runs along the vertical axis.
bounds = [T_eff.min(), T_eff.max(), log_g.min(), log_g.max()]
l = ax.imshow(hist.T, origin='lower', extent=bounds, aspect=200)
ax.invert_xaxis()
ax.invert_yaxis()

fig.colorbar(l)

In [0]:
fig, axs = plt.subplots(nrows=2, ncols=1, figsize=(4, 5), dpi=120)

# Top panel: bins containing at least one isochrone point.
# Bottom panel: the same after dilation with a 1x3 structure, which grows the
# region by one bin along the second axis of hist.T (the T_eff direction).
occupied = hist.T > 0
dilated = binary_dilation(occupied, structure=np.array([[1, 1, 1]]))
bounds = [T_eff.min(), T_eff.max(), log_g.min(), log_g.max()]

for ax, image in zip(axs, (occupied, dilated)):
    ax.imshow(image, extent=bounds, origin='lower', aspect=200)
    ax.invert_xaxis()
    ax.invert_yaxis()

In [0]:
# Sanity check: number of NaN entries in the 2-D histogram.
np.isnan(hist).sum()

In [0]:
# Logarithmic weight per (T_eff, log g) bin, scaled to a maximum of 1 and
# floored at 0.1. Counts below 1 are raised to 1 first, so empty bins get a
# log weight of 0 before the floor is applied.
w = np.log(np.maximum(hist, 1))
w /= w.max()
w = np.maximum(w, 0.1)

# Distribution of the resulting weights, drawn at the bin centers.
counts, edges = np.histogram(w.flatten(), bins=20)
plt.step(0.5 * (edges[1:] + edges[:-1]), counts, where='mid')
plt.grid()

In [0]:
# Shape of the weight array; it has one entry per histogram bin.
w.shape

In [0]:
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(4, 3), dpi=120)

# Weight map on the same extent and orientation as the histogram image.
bounds = [T_eff.min(), T_eff.max(), log_g.min(), log_g.max()]
l = ax.imshow(w.T, origin='lower', extent=bounds, aspect=200)
ax.invert_xaxis()
ax.invert_yaxis()

fig.colorbar(l)

In [0]:
# Print the name and the number of nodes of every axis of the spectrum grid.
# The loop variable is called `axis` (as in the other axis loops of this
# notebook) so that it does not overwrite `ax`, the matplotlib Axes variable.
for i, k, axis in grid.enumerate_axes():
    print(k, axis.values.shape)

In [0]:
# Replicate the 2-D weight map along a leading M_H axis and a trailing a_M axis.
# Assumes the grid axes are ordered (M_H, T_eff, log_g, a_M) -- TODO confirm.
m_h_shape = grid.grid.axes['M_H'].values.shape
a_m_shape = grid.grid.axes['a_M'].values.shape
ww = np.broadcast_to(w[None, :, :, None], m_h_shape + w.shape + a_m_shape)
ww.shape

# Save weights into the grid

In [0]:
# Target file and a new, non-normalized Phoenix model grid that will hold the weights.
fn = os.path.join(GRID_PATH, 'weights.h5')
weights_grid = ModelGrid(Phoenix(normalized=False), ArrayGrid)
weights_grid.preload_arrays = False

# Axes of the spectrum grid keyed by the second item of each yielded tuple
# (the axis name); the axis objects are reused as they are, not copied.
weights_axes = dict((name, axis) for _, name, axis in grid.enumerate_axes())

In [0]:
# Print the node values of every axis that will be copied to the weights grid.
for k, axis in weights_axes.items():
    print(k, axis.values)

In [0]:
# Assign the axes taken from the spectrum grid, build their indexes and write
# the grid to `fn`; the "weight" value is only added in a later cell.
weights_grid.set_axes(weights_axes)
weights_grid.build_axis_indexes()
weights_grid.save(fn, format='h5')

In [0]:
# Print every axis of the weights grid to verify that the axes were set.
for i, p, axis in weights_grid.enumerate_axes():
    print(p, axis.values)

In [0]:
# Register a value named "weight" with shape (1,) and store the broadcast
# weights in it; ww gets a trailing length-1 axis to match that shape
# (presumably the per-grid-point value shape -- TODO confirm against ArrayGrid).
weights_grid.array_grid.init_value("weight", shape=(1,), )
weights_grid.array_grid.set_value("weight", ww[..., np.newaxis])

# Write the grid, now including the weights, to `fn` again.
weights_grid.save(fn, format='h5')

In [0]:
!h5ls -r "$fn"

In [0]:
# Inspect the values now held by the underlying array grid.
weights_grid.array_grid.values

# Save mask into a grid

In [0]:
# Boolean mask of bins containing isochrone points, grown with a 1x3 structure
# along the second axis of hist.T (the T_eff direction). The final transpose
# restores the (T_eff, log_g) axis order of hist.
mask = binary_dilation(hist.T > 0, structure=np.array([[1, 1, 1]])).T

In [0]:
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(4, 3), dpi=120)

# Mask image on the same extent and orientation as the histogram image.
bounds = [T_eff.min(), T_eff.max(), log_g.min(), log_g.max()]
l = ax.imshow(mask.T, origin='lower', extent=bounds, aspect=200)
ax.invert_xaxis()
ax.invert_yaxis()

fig.colorbar(l)

In [0]:
# Replicate the 2-D mask along a leading M_H axis and a trailing a_M axis.
# Assumes the grid axes are ordered (M_H, T_eff, log_g, a_M) -- TODO confirm.
m_h_shape = grid.grid.axes['M_H'].values.shape
a_m_shape = grid.grid.axes['a_M'].values.shape
mm = np.broadcast_to(mask[None, :, :, None], m_h_shape + mask.shape + a_m_shape)
mm.shape

In [0]:
# Target file for the mask and a new, non-normalized Phoenix model grid to hold it.
# Note that the names `fn`, `weights_grid` and `weights_axes` are reused for the mask.
fn = os.path.join(GRID_PATH, 'mask.h5')
weights_grid = ModelGrid(Phoenix(normalized=False), ArrayGrid)
weights_grid.preload_arrays = False

# Axes of the spectrum grid keyed by the second item of each yielded tuple
# (the axis name); the axis objects are reused as they are, not copied.
weights_axes = dict((name, axis) for _, name, axis in grid.enumerate_axes())

In [0]:
# Assign the axes taken from the spectrum grid, build their indexes and write
# the grid to `fn` (the mask file at this point); the value is added in the next cell.
weights_grid.set_axes(weights_axes)
weights_grid.build_axis_indexes()
weights_grid.save(fn, format='h5')

In [0]:
# Store the mask as 1.0 / 0.0 under the value name "weight", with a trailing
# length-1 axis matching shape=(1,). Every grid point is flagged valid
# explicitly, including those where the mask value is 0.
weights_grid.array_grid.init_value("weight", shape=(1,))
mask_values = np.where(mm[..., np.newaxis], 1.0, 0.0)
all_valid = np.full_like(mm, True, dtype=bool)
weights_grid.array_grid.set_value("weight", mask_values, valid=all_valid)

weights_grid.save(fn, format='h5')

In [0]:
!ls -lah "$fn"

In [0]:
!h5ls -r "$fn"