### Preparing the data in the same way that Green+2020 did
The original functions from data-driven-stars.py are imported above; below you'll find my adaptations of them.

Importing the appropriate packages.

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import json
import tensorflow as tf
import tensorflow.keras as keras
import h5py
from glob import glob
from astropy.io import fits
from astropy.coordinates import SkyCoord

ModuleNotFoundError: No module named 'tensorflow'

In [None]:
def get_corr_matrix(cov):
    """Turn a covariance matrix into a correlation matrix with sigmas on the diagonal.

    Off-diagonal entries become ``cov[i, j] / (sigma_i * sigma_j)`` with
    ``sigma_i = sqrt(cov[i, i])``. The diagonal is overwritten with the
    standard deviations ``sigma_i`` themselves (not with 1), so the result
    carries both the correlations and the scales.

    Parameters
    ----------
    cov : np.ndarray
        Square covariance matrix. It is not modified.

    Returns
    -------
    np.ndarray
        A new matrix of the same shape and dtype as *cov*.
    """
    diag = np.diag_indices(cov.shape[0])
    sigma = np.sqrt(cov[diag])

    rho = cov.copy()
    rho /= sigma.reshape(-1, 1)   # scale row i by 1/sigma_i
    rho /= sigma.reshape(1, -1)   # scale column j by 1/sigma_j
    rho[diag] = sigma             # diagonal stores sigma_i, not 1
    return rho

In [None]:
def finalize_data(d, err_max=(200., 0.5, 0.5)):
    """Apply quality cuts to the atmospheric parameters and normalize them.

    A star is kept only if, for every parameter ``i`` covered by *err_max*,
    its value ``d['atm_param'][:, i]`` and variance
    ``d['atm_param_cov'][:, i, i]`` are both finite and the variance is below
    ``err_max[i]**2``. The surviving parameters are standardized with the
    per-parameter median and standard deviation. The covariance matrices are
    rescaled consistently: ``cov_p[i, j] = cov[i, j] / (std_i * std_j)``.

    Parameters
    ----------
    d : np.ndarray
        Structured array with the fields ``atm_param`` (N, n_param) and
        ``atm_param_cov`` (N, n_param, n_param), plus the output fields
        ``atm_param_p`` / ``atm_param_cov_p`` of the same shapes.
    err_max : sequence of float, optional
        Maximum allowed 1-sigma uncertainty per parameter, in the order
        (T_eff, logg, [M/H]). Defaults to ``(200., 0.5, 0.5)``.

    Returns
    -------
    d : np.ndarray
        The filtered rows (a new array; the input is not modified), with
        ``atm_param_p`` and ``atm_param_cov_p`` filled in.
    (atm_param_med, atm_param_std) : tuple of np.ndarray
        Per-parameter median and standard deviation used for normalization.
    """
    # Cuts on atmospheric parameters
    idx = np.ones(d.size, dtype='bool')
    for i, emax in enumerate(err_max):
        var = d['atm_param_cov'][:, i, i]
        idx &= (
              np.isfinite(d['atm_param'][:, i])
            & np.isfinite(var)
            & (var < emax*emax)
        )
    print(f'Filtered out {np.count_nonzero(~idx)} stars based on '
           'atmospheric parameters.')
    d = d[idx]  # boolean indexing returns a copy

    # Normalize atmospheric parameters
    atm_param_med = np.median(d['atm_param'], axis=0)
    atm_param_std = np.std(d['atm_param'], axis=0)
    d['atm_param_p'] = (
        (d['atm_param'] - atm_param_med[None,:]) / atm_param_std[None,:]
    )
    # Divide row i by std_i and column j by std_j. This works for any number
    # of parameters instead of assuming exactly three.
    d['atm_param_cov_p'] = (
        d['atm_param_cov']
        / atm_param_std[None, :, None]
        / atm_param_std[None, None, :]
    )

    return d, (atm_param_med, atm_param_std)

In [None]:
def print_stats(d):
    """Print a breakdown of where the rows of a prepared dataset come from.

    Each count is followed by its fraction of the total number of rows:

    * stars per atmospheric-parameter source (``d['atm_source']``),
    * stars per reddening source (``d['r_source']``),
    * stars with a finite parallax whose parallax / parallax_err exceeds 5, and
    * stars with a finite magnitude in each of the 13 passbands of ``d['mag']``.

    The source labels are byte strings and are decoded as UTF-8 for display.
    """
    total = d.size

    for heading, field in (('Atmospheric parameter source:', 'atm_source'),
                           ('Reddening source:', 'r_source')):
        print(heading)
        labels, counts = np.unique(d[field], return_counts=True)
        for label, count in zip(labels, counts):
            print(f'  * {label.decode("utf-8")} : {count} ({count/total:.3f})')

    print('Sources per band:')
    parallax = d['parallax']
    has_parallax = np.isfinite(parallax) & (parallax / d['parallax_err'] > 5.)
    n_parallax = np.count_nonzero(has_parallax)
    print(f'  *  pi : {n_parallax} ({n_parallax/total:.3f})')

    band_labels = ['G', 'BP', 'RP', 'g', 'r', 'i', 'z', 'y', 'J', 'H',
                   'K_s', 'W_1', 'W_2']
    finite_per_band = np.count_nonzero(np.isfinite(d['mag']), axis=0)
    for label, count in zip(band_labels, finite_per_band):
        print(f'  * {label: >3s} : {count} ({count/total:.3f})')

In [None]:
# Load the small Green+2020 test dataset. Each `[:]` reads the full HDF5
# dataset into memory, so the arrays remain usable after the file closes.
# NOTE: `d` is overwritten by the FITS photometry table in a later cell.
with h5py.File('/arc/home/aydanmckay/green2020-stellar-model/green2020_test_data_small.h5', 'r') as f:
    d = f['data'][:]       # All the data needed to train or test the model
    r_fit = f['r_fit'][:]  # The reddening inferred using the trained model
    r_var = f['r_var'][:]  # The variance of the inferred reddening
    print(f.keys())

In [None]:
# Show the field layout of the Green+2020 data loaded above
d.dtype

In [None]:
# dtype of the model-inferred reddening array
r_fit.dtype

In [None]:
# dtype of the variance of the model-inferred reddening
r_var.dtype

In [None]:
# Source breakdown of the Green+2020 test data, for comparison with ours
print_stats(d)

In [None]:
# NOTE: `f` was closed when the `with h5py.File(...)` block exited, so this
# only displays the closed h5py File object.
f

In [None]:
# Load the photometry table (first extension HDU) that is converted into the
# Green+2020 input format in the next cell. The context manager closes the
# file handle; the `.data` array read inside it remains usable afterwards
# (astropy keeps a memory-mapped array alive while it is referenced).
# hdu = fits.open('/arc/home/aydanmckay/phottable-x-lamost-final.fits')
with fits.open('/arc/home/aydanmckay/union-photometry-table-cuts.fits') as hdu:
    d = hdu[1].data

In [None]:
# Gather into one dataset
dtype = [
    ('atm_param', '3f4'),
    ('atm_param_cov', 'f4', (3,3)),
    ('atm_param_p', '3f4'),           # normalized
    ('atm_param_cov_p', 'f4', (3,3)), # normalized
    ('r', 'f4'),
    ('r_err', 'f4'),
    ('mag', '13f4'),
    ('mag_err', '13f4'),
    ('parallax', 'f4'),
    ('parallax_err', 'f4'),
    ('atm_source', 'S6'),
    ('r_source', 'S7')
]

# mask = (d['r_ps1'] > 13.500888) & (d['E_BV'] < 0.2)
# d = d[mask]
# this is for the datav3 only, where the mask is applied
# just remove everywhere where [mask] exists

# np.empty leaves every field uninitialized, so each column that is used
# later must be written explicitly (atm_param_p / atm_param_cov_p are filled
# by finalize_data).
io_data = np.empty(d.size, dtype=dtype)

io_data['atm_source'] = 'lamost'

# Copy in atmospheric parameters (T_eff, logg, [M/H]); each column `<key>`
# has a matching uncertainty column `<key>_err`.
# (For the PASTEL-only table the columns were TEFF_PASTEL / LOGG_PASTEL /
# FEH_PASTEL with errors err_teff_pastel / err_logg_pastel / err_feh_pastel.)
atm_param_keys = ['teff_past_lam', 'logg_past_lam', 'feh_past_lam']

# Diagonal covariance matrix (no cross-parameter correlations are available)
io_data['atm_param_cov'][:] = 0.
for i, key in enumerate(atm_param_keys):
    io_data['atm_param'][:,i] = d[key][:]
    io_data['atm_param_cov'][:,i,i] = d[key + '_err']**2.

# Add in error floor to atmospheric parameters
sigma_atm_param_floor = [10., 0.05, 0.03] # (T_eff, logg, [M/H])
for i, sig in enumerate(sigma_atm_param_floor):
    io_data['atm_param_cov'][:,i,i] += sig**2

# Print the first few correlation matrices as a sanity check (header once;
# min() avoids an IndexError on tables with fewer than 10 rows).
print('Correlation matrices:')
for i in range(min(10, io_data.size)):
    rho = get_corr_matrix(io_data['atm_param_cov'][i])
    print(np.array2string(
        rho,
        formatter={'float_kind':lambda z:'{: >7.4f}'.format(z)}
    ))

# Assign a reddening source to each star, mirroring Green+2020:
#   * 'sfd'     : even at parallax + 5 sigma, the star lies more than z_0 from
#                 the plane (distance * sin|b| > z_0; assumes parallax in mas,
#                 so 1/parallax is a distance in kpc);
#   * 'b19'     : otherwise, when parallax/err > 5 and E_BV is finite;
#   * 'default' : everything else -> r = 0 with E_BV as its uncertainty.
# NOTE(review): every branch reads E_BV, and e_CaHK is used as the 'b19'
# uncertainty -- confirm these are the intended columns for this table.
z_0 = 0.4 # kpc
galb = SkyCoord(d['ra'], d['dec'], frame='icrs', unit='deg').galactic.b.value
sin_b_over_z = np.abs(np.sin(np.radians(galb))) / z_0
idx_z = (d['para_gaia'] + 5*d['para_gaia_err'] < sin_b_over_z)
idx_plx_over_err = (d['para_gaia'] / d['para_gaia_err'] > 5.)
idx_b19 = np.isfinite(d['E_BV'])

# The three selections are mutually exclusive and cover every star
idx_sfd = idx_z
idx_b19 = ~idx_sfd & idx_plx_over_err & idx_b19
idx_default = ~idx_sfd & ~idx_b19

print(r'Reddening sources:')
print(r' * SFD: {:.4f}'.format(np.count_nonzero(idx_sfd)/idx_sfd.size))
print(r' * B19: {:.4f}'.format(np.count_nonzero(idx_b19)/idx_b19.size))
print(r' * ---: {:.4f}'.format(np.count_nonzero(idx_default)/idx_default.size))

r_err_scale = 0.1

io_data['r'][idx_default] = 0.
io_data['r_err'][idx_default] = d['E_BV'][idx_default]
io_data['r_source'][idx_default] = 'default'

# Fractional uncertainty of r_err_scale added in quadrature to the B19 error
b19_val = d['E_BV'][idx_b19]
b19_err = d['e_CaHK'][idx_b19]
b19_err = np.sqrt(b19_err**2 + r_err_scale**2*b19_val**2)
io_data['r'][idx_b19] = b19_val
io_data['r_err'][idx_b19] = b19_err
io_data['r_source'][idx_b19] = 'b19'

io_data['r'][idx_sfd] = d['E_BV'][idx_sfd]
io_data['r_err'][idx_sfd] = 0.1 * d['E_BV'][idx_sfd]
io_data['r_source'][idx_sfd] = 'sfd'

# Add in reddening error floor
r_err_floor = 0.02
io_data['r_err'] = np.sqrt(
      io_data['r_err']**2
    + r_err_floor**2
    #+ (r_err_scale*io_data['r'])**2
)

# The per-source r / r_err above, including the error floor, are the final
# values. Do not re-assign r / r_err wholesale from E_BV / e_CaHK afterwards.
# Doing so would discard the floor and make r_source disagree with r. A
# "fallback" for non-finite E_BV would also only copy NaN onto NaN.

###################################################################################################################

# # Stricter fracflux cut on WISE passbands
# idx = (d['unwise_fracflux'] < 0.5)
# d['unwise_mag'][idx] = np.nan
# d['unwise_mag_err'][idx] = np.nan

# Copy in magnitudes. Column i of mag / mag_err holds band i, in the order
# that print_stats labels them: G, BP, RP, g, r, i, z, y, J, H, K_s, W_1, W_2.
# Every band must land in its own column: io_data was created with np.empty,
# so any column left unwritten would contain garbage.
# Each input column `<key>` has a matching uncertainty column `<key>_err`.
mag_keys = [
    'g_gaia', 'bp_gaia', 'rp_gaia',                    # G, BP, RP
    'g_pan1', 'r_pan1', 'i_pan1', 'z_pan1', 'y_pan1',  # g, r, i, z, y
    'j_mass', 'h_mass', 'k_mass',                      # J, H, K_s
    'w1_desi', 'w2_desi',                              # W_1, W_2
]
for i, key in enumerate(mag_keys):
    io_data['mag'][:,i] = d[key]
    io_data['mag_err'][:,i] = d[key + '_err']

# Copy in parallaxes
io_data['parallax'][:] = d['para_gaia']
io_data['parallax_err'][:] = d['para_gaia_err']

# Non-positive parallaxes are treated as missing
idx = (io_data['parallax'] <= 0)
io_data['parallax'][idx] = np.nan
io_data['parallax_err'][idx] = np.nan

# Add in photometric error floors
mag_err_floor = 0.02 * np.ones(len(mag_keys))
io_data['mag_err'] = np.sqrt(
      io_data['mag_err']**2
    + mag_err_floor[None,:]**2
)

# Filter out magnitudes with err > 0.2
idx = (io_data['mag_err'] > 0.2)
io_data['mag'][idx] = np.nan
io_data['mag_err'][idx] = np.nan

In [None]:
# Apply the atmospheric-parameter quality cuts and normalization, then print
# the source breakdown of the resulting dataset.
io_data,(atm_param_med,atm_param_std) = finalize_data(io_data)
print_stats(io_data)

# Uncomment to write the prepared dataset to HDF5, storing the normalization
# constants (median / std) as dataset attributes:
# with h5py.File('/arc/home/aydanmckay/ml/network/datasets/datav5.h5', 'w') as f:
#     dset = f.create_dataset(
#         'io_data',
#         data=io_data,
#         chunks=True,
#         compression='gzip',
#         compression_opts=3
#     )
#     dset.attrs['atm_param_med'] = atm_param_med
#     dset.attrs['atm_param_std'] = atm_param_std