### Preparing the data in the same way that Green+2020 did
The original functions from data-driven-stars.py are imported above, and below you'll find my adaptations.

Importing the appropriate packages.

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import json
import tensorflow as tf
import tensorflow.keras as keras
import h5py
from glob import glob
from astropy.io import fits
from astropy.coordinates import SkyCoord

In [2]:
def get_corr_matrix(cov):
    """Return the correlation matrix of *cov*, with sigmas on the diagonal.

    Off-diagonal entries are ``cov[i, j] / (sigma_i * sigma_j)``. The diagonal
    (which would be identically 1) is replaced by ``sigma_i = sqrt(cov[i, i])``
    so that a single printed matrix shows both the correlations and the
    standard deviations. The input array is not modified.
    """
    diag = np.diag_indices(cov.shape[0])
    sigma = np.sqrt(cov[diag])

    corr = cov.copy()
    # Scale rows, then columns, by the standard deviations.
    np.divide(corr, sigma[:, np.newaxis], out=corr)
    np.divide(corr, sigma[np.newaxis, :], out=corr)
    # Overwrite the trivial unit diagonal with the sigmas themselves.
    corr[diag] = sigma
    return corr

In [3]:
def load_data(fnames):
    """Read and concatenate the catalog chunks stored in several HDF5 files.

    From every file in *fnames* the datasets ``stellar_phot_spec_ast``,
    ``reddening`` and ``reddening_err`` are read in full; the pieces are then
    joined (in file order) with ``np.hstack``.

    Returns the tuple ``(d, b19, b19_err)``.
    """
    keys = ('stellar_phot_spec_ast', 'reddening', 'reddening_err')
    chunks = {key: [] for key in keys}

    for fn in fnames:
        print('Loading {:s} ...'.format(fn))
        with h5py.File(fn, 'r') as f:
            for key in keys:
                chunks[key].append(f[key][:])

    d, b19, b19_err = (np.hstack(chunks[key]) for key in keys)
    return d, b19, b19_err

In [4]:
def extract_data(d, b19, b19_err, *, verbose=True):
    """Repackage one survey's cross-matched catalog into the model I/O format.

    Parameters
    ----------
    d : structured ndarray
        Catalog for a single spectroscopic survey. The survey is recognised
        from the field names: ``sdss_aspcap_param`` (APOGEE), ``ddpayne_teff``
        (LAMOST DD-Payne) or ``snr_c1`` (GALAH). The array must also carry the
        astrometry (``gal_b``, ``parallax``, ``parallax_err``), ``SFD`` and the
        Gaia / PS1 / 2MASS / unWISE photometry fields read below.
    b19, b19_err : ndarray
        Reddening estimate and its uncertainty for every row of *d*
        (non-finite where unavailable).
    verbose : bool, keyword-only
        Print the survey offsets, up to ten example correlation matrices and
        the reddening-source fractions (default, matches the old behaviour).

    Returns
    -------
    io_data : structured ndarray
        One row per input star. ``atm_param_p`` and ``atm_param_cov_p`` are
        left uninitialised here; ``finalize_data`` fills them in.

    Raises
    ------
    ValueError
        If *d* has none of the recognised survey columns.

    Notes
    -----
    *d* is not modified.
    """
    # Gather into one dataset
    dtype = [
        ('atm_param', '3f4'),
        ('atm_param_cov', 'f4', (3,3)),
        ('atm_param_p', '3f4'),           # normalized
        ('atm_param_cov_p', 'f4', (3,3)), # normalized
        ('r', 'f4'),
        ('r_err', 'f4'),
        ('mag', '13f4'),
        ('mag_err', '13f4'),
        ('parallax', 'f4'),
        ('parallax_err', 'f4'),
        ('atm_source', 'S6'),
        ('r_source', 'S7')
    ]
    io_data = np.empty(d.size, dtype=dtype)

    # Offsets to bring spectroscopic labels from different
    # surveys into alignment
    offsets = {
        'apogee': np.array([23.04715, 0.01189, 0.05019]),
        'lamost': np.array([0., 0., 0.]),
        'galah': np.array([-3.60096, -0.01396, 0.06770])
    }
    # Fix offsets to GALAH
    offsets['apogee'] -= offsets['galah']
    offsets['lamost'] -= offsets['galah']
    offsets['galah'][:] = 0.

    if verbose:
        print('offsets:')
        for key in offsets:
            print(f'  * {key}: {offsets[key]}')

    # How to load data depends on survey
    field_names = d.dtype.names or ()
    if 'sdss_aspcap_param' in field_names: # APOGEE
        io_data['atm_source'] = 'apogee'

        param_idx = [0, 1, 3] # (T_eff, logg, [M/H])
        param_name = ['teff', 'logg', 'm_h']

        # Copy in parameters and corresponding covariance entries.
        # The covariance is stored flattened, so element (i, j) of the 9x9
        # matrix lives in column 9*i+j. It must be addressed with the SOURCE
        # indices (i, j), not the output positions (k, l).
        # NOTE(review): assumes a row-major 9x9 layout -- confirm against the
        # APOGEE data model.
        for k,i in enumerate(param_idx):
            io_data['atm_param'][:,k] = d['sdss_aspcap_param'][:,i]
            for l,j in enumerate(param_idx):
                io_data['atm_param_cov'][:,k,l] = d['sdss_aspcap_fparam_cov'][:,9*i+j]

        io_data['atm_param'] -= offsets['apogee'][None,:]

        # Copy calibrated errors into diagonals of covariance matrices.
        #   - Keep uncalibrated errors if larger.
        for k,n in enumerate(param_name):
            io_data['atm_param_cov'][:,k,k] = np.maximum(
                d[f'sdss_aspcap_{n}_err']**2,
                io_data['atm_param_cov'][:,k,k]
            )

    elif 'ddpayne_teff' in field_names: # LAMOST DDPAYNE
        io_data['atm_source'] = 'lamost'

        # Copy in parameters
        io_data['atm_param'][:,0] = d['ddpayne_teff'][:]
        io_data['atm_param'][:,1] = d['ddpayne_logg'][:]
        io_data['atm_param'][:,2] = d['ddpayne_feh'][:]
        io_data['atm_param'] -= offsets['lamost'][None,:]

        # Diagonal covariance matrix
        io_data['atm_param_cov'][:] = 0.
        io_data['atm_param_cov'][:,0,0] = d['ddpayne_teff_err']**2.
        io_data['atm_param_cov'][:,1,1] = d['ddpayne_logg_err']**2.
        io_data['atm_param_cov'][:,2,2] = d['ddpayne_feh_err']**2.
    elif 'snr_c1' in field_names: # GALAH
        io_data['atm_source'] = 'galah'

        # Copy in parameters
        io_data['atm_param'][:,0] = d['teff'][:]
        io_data['atm_param'][:,1] = d['logg'][:]
        io_data['atm_param'][:,2] = d['feh'][:]
        io_data['atm_param'] -= offsets['galah'][None,:]

        # Diagonal covariance matrix
        io_data['atm_param_cov'][:] = 0.
        io_data['atm_param_cov'][:,0,0] = d['teff_err']**2.
        io_data['atm_param_cov'][:,1,1] = d['logg_err']**2.
        io_data['atm_param_cov'][:,2,2] = d['feh_err']**2.
    else:
        # Without this, the np.empty() garbage in atm_param / atm_param_cov
        # would be returned as if it were data.
        raise ValueError(
            'Unrecognized survey: expected one of the fields '
            "'sdss_aspcap_param' (APOGEE), 'ddpayne_teff' (LAMOST) or "
            "'snr_c1' (GALAH)."
        )

    # Add in error floor to atmospheric parameters
    sigma_atm_param_floor = [10., 0.05, 0.03] # (T_eff, logg, [M/H])
    for i,sig in enumerate(sigma_atm_param_floor):
        io_data['atm_param_cov'][:,i,i] += sig**2

    # Print correlation matrices, for fun (at most ten, fewer for tiny inputs)
    if verbose:
        for i in range(min(10, io_data.size)):
            rho = get_corr_matrix(io_data['atm_param_cov'][i])
            print('Correlation matrices:')
            print(np.array2string(
                rho,
                formatter={'float_kind':lambda z:'{: >7.4f}'.format(z)}
            ))

    # Reddening sources, in order of priority:
    #   1. If |z| > 400 pc: Use SFD with 10% uncertainty
    #   2. If parallax/error > 5: Use Bayestar19
    #   3. Otherwise: Use 0 +- SFD

    z_0 = 0.4 # kpc
    sin_b_over_z = np.abs(np.sin(np.radians(d['gal_b']))) / z_0
    # Even the parallax pushed up by 5 sigma is below |sin b| / z_0.
    idx_sfd = (d['parallax'] + 5*d['parallax_err'] < sin_b_over_z)
    idx_plx_over_err = (d['parallax'] / d['parallax_err'] > 5.)
    idx_b19 = ~idx_sfd & idx_plx_over_err & np.isfinite(b19)
    idx_default = ~idx_sfd & ~idx_b19

    if verbose:
        n_total = max(idx_sfd.size, 1) # avoid 0/0 on empty input
        print(r'Reddening sources:')
        print(r' * SFD: {:.4f}'.format(np.count_nonzero(idx_sfd)/n_total))
        print(r' * B19: {:.4f}'.format(np.count_nonzero(idx_b19)/n_total))
        print(r' * ---: {:.4f}'.format(np.count_nonzero(idx_default)/n_total))

    r_err_scale = 0.1 # fractional reddening uncertainty

    io_data['r'][idx_default] = 0.
    io_data['r_err'][idx_default] = d['SFD'][idx_default]
    io_data['r_source'][idx_default] = 'default'

    b19_val = b19[idx_b19]
    b19_sigma = np.sqrt(b19_err[idx_b19]**2 + r_err_scale**2*b19_val**2)
    io_data['r'][idx_b19] = b19_val
    io_data['r_err'][idx_b19] = b19_sigma
    io_data['r_source'][idx_b19] = 'b19'

    io_data['r'][idx_sfd] = d['SFD'][idx_sfd]
    io_data['r_err'][idx_sfd] = r_err_scale * d['SFD'][idx_sfd]
    io_data['r_source'][idx_sfd] = 'sfd'

    # Add in reddening error floor
    r_err_floor = 0.02
    io_data['r_err'] = np.sqrt(io_data['r_err']**2 + r_err_floor**2)

    # Stricter fracflux cut on WISE passbands.
    # Work on copies so the caller's catalog is left untouched.
    idx = (d['unwise_fracflux'] < 0.5)
    unwise_mag = d['unwise_mag'].copy()
    unwise_mag_err = d['unwise_mag_err'].copy()
    unwise_mag[idx] = np.nan
    unwise_mag_err[idx] = np.nan

    # Copy in magnitudes
    # (columns: G, BP, RP, 5x PS1, J, H, K, 2x unWISE)
    io_data['mag'][:,0] = d['gaia_g_mag']
    io_data['mag_err'][:,0] = d['gaia_g_mag_err']
    io_data['mag'][:,1] = d['gaia_bp_mag']
    io_data['mag_err'][:,1] = d['gaia_bp_mag_err']
    io_data['mag'][:,2] = d['gaia_rp_mag']
    io_data['mag_err'][:,2] = d['gaia_rp_mag_err']
    io_data['mag'][:,3:8] = d['ps1_mag']
    io_data['mag_err'][:,3:8] = d['ps1_mag_err']
    for i,b in enumerate('JHK'):
        io_data['mag'][:,8+i] = d[f'tmass_{b}_mag']
        io_data['mag_err'][:,8+i] = d[f'tmass_{b}_mag_err']
    io_data['mag'][:,11:13] = unwise_mag
    io_data['mag_err'][:,11:13] = unwise_mag_err

    io_data['parallax'][:] = d['parallax']
    io_data['parallax_err'][:] = d['parallax_err']

    # Add in photometric error floors
    mag_err_floor = 0.02 * np.ones(13)
    io_data['mag_err'] = np.sqrt(
          io_data['mag_err']**2
        + mag_err_floor[None,:]**2
    )

    # Filter out magnitudes with err > 0.2
    idx = (io_data['mag_err'] > 0.2)
    io_data['mag'][idx] = np.nan
    io_data['mag_err'][idx] = np.nan

    return io_data

In [5]:
def extract_data_multiple(fname_lists):
    """Load and extract several groups of files, then stack the results.

    *fname_lists* is an iterable of filename lists. Each inner list is passed
    to ``load_data`` and the result to ``extract_data``; the per-group outputs
    are concatenated with ``np.hstack`` in the order given.
    """
    extracted = []

    for fnames in fname_lists:
        raw, reddening, reddening_err = load_data(fnames)
        stars = extract_data(raw, reddening, reddening_err)
        print(f'Extracted {stars.size} stars.')
        extracted.append(stars)

    return np.hstack(extracted)

In [6]:
def finalize_data(d):
    """Apply quality cuts and compute normalized atmospheric parameters.

    Stars are dropped when any of the three atmospheric parameters, or its
    variance, is non-finite, or when the variance is not below the square of
    the corresponding maximum error (200, 0.5, 0.5). For the survivors,
    ``atm_param_p`` is set to the parameters shifted by their median and
    scaled by their standard deviation, and ``atm_param_cov_p`` to the
    covariance scaled by the same standard deviations.

    Returns ``(d, (atm_param_med, atm_param_std))`` where *d* is the filtered
    copy (the input array is not modified).
    """
    # Cuts on atmospheric parameters
    err_max = (200., 0.5, 0.5) # (T_eff, logg, [M/H])
    keep = np.ones(d.size, dtype='bool')
    for k, emax in enumerate(err_max):
        value = d['atm_param'][:, k]
        variance = d['atm_param_cov'][:, k, k]
        keep &= np.isfinite(value) & np.isfinite(variance) & (variance < emax*emax)
    n_rejected = np.count_nonzero(~keep)
    print(f'Filtered out {n_rejected} stars based on atmospheric parameters.')
    d = d[keep]  # boolean indexing returns a copy

    # Normalize atmospheric parameters
    atm_param_med = np.median(d['atm_param'], axis=0)
    atm_param_std = np.std(d['atm_param'], axis=0)
    d['atm_param_p'] = (d['atm_param'] - atm_param_med) / atm_param_std

    cov_p = d['atm_param_cov_p']  # view into d
    cov_p[:] = d['atm_param_cov']
    for k in range(3):
        scale = atm_param_std[k]
        cov_p[:, k, :] /= scale
        cov_p[:, :, k] /= scale

    return d, (atm_param_med, atm_param_std)

In [7]:
def print_stats(d):
    """Print a summary of where the labels and photometry in *d* come from.

    Reports, as counts and fractions of all stars: the values found in
    ``atm_source`` and ``r_source``, the number of stars with a finite
    parallax measured at better than 5 sigma, and the number of finite
    magnitudes in each of the 13 bands.
    """
    n_d = d.size

    def report_sources(title, field):
        # One line per distinct (byte-string) value of the given field.
        print(title)
        for key in np.unique(d[field]):
            n = np.count_nonzero(d[field] == key)
            print(f'  * {key.decode("utf-8")} : {n} ({n/n_d:.3f})')

    report_sources('Atmospheric parameter source:', 'atm_source')
    report_sources('Reddening source:', 'r_source')

    print('Sources per band:')
    good_plx = (
          np.isfinite(d['parallax'])
        & (d['parallax'] / d['parallax_err'] > 5.)
    )
    n_pi = np.count_nonzero(good_plx)
    print(f'  *  pi : {n_pi} ({n_pi/n_d:.3f})')

    band_names = ['G', 'BP', 'RP', 'g', 'r', 'i', 'z', 'y',
                  'J', 'H', 'K_s', 'W_1', 'W_2']
    band_counts = np.count_nonzero(np.isfinite(d['mag']), axis=0)
    for b, n in zip(band_names, band_counts):
        print(f'  * {b: >3s} : {n} ({n/n_d:.3f})')

In [8]:
# fnames = [
#     glob('data/dr16_data_*to*.h5'),
#     glob('data/ddpayne_data_*to*.h5'),
#     glob('data/galah_data_*to*.h5')
# ]
# d = extract_data_multiple(fnames)
# d,(atm_param_med,atm_param_std) = finalize_data(d)
# print_stats(d)

# with h5py.File('data/apogee_lamost_galah_data.h5', 'w') as f:
#     dset = f.create_dataset(
#         'io_data',
#         data=d,
#         chunks=True,
#         compression='gzip',
#         compression_opts=3
#     )
#     dset.attrs['atm_param_med'] = atm_param_med
#     dset.attrs['atm_param_std'] = atm_param_std

In [9]:
# Read the three datasets of the small Green+2020 test file fully into memory
# (the [:] slices copy the data out before the file is closed), then list the
# keys the file provides.
with h5py.File('/arc/home/aydanmckay/green2020-stellar-model/green2020_test_data_small.h5', 'r') as f:
    d = f['data'][:]       # All the data needed to train or test the model
    r_fit = f['r_fit'][:]  # The reddening inferred using the trained model
    r_var = f['r_var'][:]  # The variance of the inferred reddening
    print(f.keys())

<KeysViewHDF5 ['data', 'r_fit', 'r_var']>


In [10]:
# Show the structured dtype of the test data (same fields as built in extract_data).
d.dtype

dtype([('atm_param', '<f4', (3,)), ('atm_param_cov', '<f4', (3, 3)), ('atm_param_p', '<f4', (3,)), ('atm_param_cov_p', '<f4', (3, 3)), ('r', '<f4'), ('r_err', '<f4'), ('mag', '<f4', (13,)), ('mag_err', '<f4', (13,)), ('parallax', '<f4'), ('parallax_err', '<f4'), ('atm_source', 'S6'), ('r_source', 'S7')])

In [11]:
# Show the element type of the fitted reddening array.
r_fit.dtype

dtype('float32')

In [12]:
# Show the element type of the reddening-variance array.
r_var.dtype

dtype('float32')

In [13]:
# Summarize label sources and per-band coverage of the Green+2020 test data.
print_stats(d)

Atmospheric parameter source:
  * apogee : 570 (0.057)
  * galah : 466 (0.047)
  * lamost : 8964 (0.896)
Reddening source:
  * b19 : 5086 (0.509)
  * default : 949 (0.095)
  * sfd : 3965 (0.397)
Sources per band:
  *  pi : 8462 (0.846)
  *   G : 10000 (1.000)
  *  BP : 9989 (0.999)
  *  RP : 9990 (0.999)
  *   g : 8181 (0.818)
  *   r : 7046 (0.705)
  *   i : 6438 (0.644)
  *   z : 7381 (0.738)
  *   y : 8419 (0.842)
  *   J : 9482 (0.948)
  *   H : 9299 (0.930)
  * K_s : 8881 (0.888)
  * W_1 : 2909 (0.291)
  * W_2 : 6498 (0.650)


In [14]:
# Load the photometry table from the first extension of the FITS file.
# NOTE: this rebinds `d`, replacing the HDF5 test data loaded earlier.
# NOTE(review): the HDUList in `hdu` is never closed; confirm whether it must
# stay open for as long as `d` is in use.
hdu = fits.open('/arc/home/aydanmckay/union-photometry-table-cuts.fits')
d = hdu[1].data

In [18]:
# Flat-script adaptation of extract_data() for the FITS table loaded into `d`.
# It fills `io_data` with atmospheric parameters, reddening, photometry and
# parallaxes, using this table's column names.
# NOTE(review): the recorded run of this cell stopped with
# KeyError: "Key 'w1_desi' does not exist." -- see the W1/W2 lines below.

# Gather into one dataset
dtype = [
    ('atm_param', '3f4'),
    ('atm_param_cov', 'f4', (3,3)),
    ('atm_param_p', '3f4'),           # normalized
    ('atm_param_cov_p', 'f4', (3,3)), # normalized
    ('r', 'f4'),
    ('r_err', 'f4'),
    ('mag', '13f4'),
    ('mag_err', '13f4'),
    ('parallax', 'f4'),
    ('parallax_err', 'f4'),
    ('atm_source', 'S6'),
    ('r_source', 'S7')
]
io_data = np.empty(d.size, dtype=dtype)

# Every row is labelled as LAMOST.
# NOTE(review): unlike extract_data(), no survey offset is subtracted from
# atm_param here -- confirm this is intended.
io_data['atm_source'] = 'lamost'

# Copy in parameters
io_data['atm_param'][:,0] = d['teff_past_lam'][:]
io_data['atm_param'][:,1] = d['logg_past_lam'][:]
io_data['atm_param'][:,2] = d['feh_past_lam'][:]

# Diagonal covariance matrix
io_data['atm_param_cov'][:] = 0.
io_data['atm_param_cov'][:,0,0] = d['teff_past_lam_err']**2.
io_data['atm_param_cov'][:,1,1] = d['logg_past_lam_err']**2.
io_data['atm_param_cov'][:,2,2] = d['feh_past_lam_err']**2.

# Add in error floor to atmospheric parameters
sigma_atm_param_floor = [10., 0.05, 0.03] # (T_eff, logg, [M/H])
for i,sig in enumerate(sigma_atm_param_floor):
    io_data['atm_param_cov'][:,i,i] += sig**2

# Print correlation matrices, for fun
# NOTE(review): range(10) raises IndexError if the table has fewer than 10 rows.
for i in range(10):
    rho = get_corr_matrix(io_data['atm_param_cov'][i])
    print('Correlation matrices:')
    print(np.array2string(
        rho,
        formatter={'float_kind':lambda z:'{: >7.4f}'.format(z)}
    ))

# Same reddening-source selection as in extract_data(); Galactic latitude is
# derived from (ra, dec) because the table is not read for a 'gal_b' column.
z_0 = 0.4 # kpc
galb = SkyCoord(d['ra'], d['dec'], frame='icrs', unit='deg').galactic.b.value
sin_b_over_z = np.abs(np.sin(np.radians(galb))) / z_0
idx_z = (d['para_gaia'] + 5*d['para_gaia_err'] < sin_b_over_z)
idx_plx_over_err = (d['para_gaia'] / d['para_gaia_err'] > 5.)
idx_b19 = np.isfinite(d['E_BV'])

idx_sfd = idx_z
idx_b19 = ~idx_sfd & idx_plx_over_err & idx_b19
idx_default = ~idx_sfd & ~idx_b19

print(r'Reddening sources:')
print(r' * SFD: {:.4f}'.format(np.count_nonzero(idx_sfd)/idx_sfd.size))
print(r' * B19: {:.4f}'.format(np.count_nonzero(idx_b19)/idx_b19.size))
print(r' * ---: {:.4f}'.format(np.count_nonzero(idx_default)/idx_default.size))

r_err_scale = 0.1

# NOTE(review): the single column 'E_BV' stands in for both the 'SFD' and the
# Bayestar19 values used in extract_data(), and 'e_CaHK' is used as the
# reddening uncertainty -- confirm these column choices.
io_data['r'][idx_default] = 0.
io_data['r_err'][idx_default] = d['E_BV'][idx_default]
io_data['r_source'][idx_default] = 'default'

#idx = idx_plx_over_err & idx_b19
b19_val = d['E_BV'][idx_b19]
b19_err = d['e_CaHK'][idx_b19]
b19_err = np.sqrt(b19_err**2 + r_err_scale**2*b19_val**2)
io_data['r'][idx_b19] = b19_val
io_data['r_err'][idx_b19] = b19_err
io_data['r_source'][idx_b19] = 'b19'

io_data['r'][idx_sfd] = d['E_BV'][idx_sfd]
io_data['r_err'][idx_sfd] = 0.1 * d['E_BV'][idx_sfd]
io_data['r_source'][idx_sfd] = 'sfd'

# Add in reddening error floor
r_err_floor = 0.02
io_data['r_err'] = np.sqrt(
      io_data['r_err']**2
    + r_err_floor**2
    #+ (r_err_scale*io_data['r'])**2
)

# NOTE(review): the next two assignments overwrite every 'r' and 'r_err'
# value computed above (including the error floor), while 'r_source' keeps
# the labels of the prioritized scheme. In extract_data() the equivalent
# lines are commented out -- confirm which behaviour is wanted.
# Use Bayestar19 reddening by default
io_data['r'] = d['E_BV'][:]
io_data['r_err'] = d['e_CaHK'][:]

# NOTE(review): this "fallback" copies the non-finite E_BV values back into
# the rows where E_BV is non-finite, so those rows stay non-finite.
# Use SFD reddening as fallback
idx = ~np.isfinite(d['E_BV'])
io_data['r'][idx] = d['E_BV'][idx]
io_data['r_err'][idx] = d['E_BV'][idx]

###################################################################################################################

# # Stricter fracflux cut on WISE passbands
# idx = (d['unwise_fracflux'] < 0.5)
# d['unwise_mag'][idx] = np.nan
# d['unwise_mag_err'][idx] = np.nan

# Copy in magnitudes
# Column order: 0-2 Gaia (G, BP, RP), 3-7 Pan-STARRS1 (g, r, i, z, y),
# 8-10 2MASS (J, H, K), 11-12 WISE (W1, W2).
io_data['mag'][:,0] = d['g_gaia']
io_data['mag_err'][:,0] = d['g_gaia_err']
io_data['mag'][:,1] = d['bp_gaia']
io_data['mag_err'][:,1] = d['bp_gaia_err']
io_data['mag'][:,2] = d['rp_gaia']
io_data['mag_err'][:,2] = d['rp_gaia_err']
io_data['mag'][:,3] = d['g_pan1']
io_data['mag_err'][:,3] = d['g_pan1_err']
io_data['mag'][:,4] = d['r_pan1']
io_data['mag_err'][:,4] = d['r_pan1_err']
io_data['mag'][:,5] = d['i_pan1']
io_data['mag_err'][:,5] = d['i_pan1_err']
io_data['mag'][:,6] = d['z_pan1']
io_data['mag_err'][:,6] = d['z_pan1_err']
io_data['mag'][:,7] = d['y_pan1']
io_data['mag_err'][:,7] = d['y_pan1_err']
io_data['mag'][:,8] = d['j_mass']
io_data['mag_err'][:,8] = d['j_mass_err']
io_data['mag'][:,9] = d['h_mass']
io_data['mag_err'][:,9] = d['h_mass_err']
io_data['mag'][:,10] = d['k_mass']
io_data['mag_err'][:,10] = d['k_mass_err']
# 
# idx = (np.isinf(d['w1_desi']))
# d['w1_desi'][idx] = np.nan
# plt.hist(d['w1_desi'],bins=50)
# NOTE(review): the recorded KeyError says the table has no 'w1_desi' column,
# so the cell stops here and everything below is never executed. The scratch
# cells further down mention 'w1flux_desi' -- possibly only fluxes are stored;
# verify the column names (and whether a flux-to-magnitude conversion is
# needed).
io_data['mag'][:,11] = d['w1_desi']
io_data['mag_err'][:,11] = d['w1_desi_err']
# 
# idx = (np.isinf(d['w2_desi']))
# d['w2_desi'][idx] = np.nan
io_data['mag'][:,12] = d['w2_desi']
io_data['mag_err'][:,12] = d['w2_desi_err']

io_data['parallax'][:] = d['para_gaia']
io_data['parallax_err'][:] = d['para_gaia_err']

# Add in photometric error floors
mag_err_floor = 0.02 * np.ones(13)
io_data['mag_err'] = np.sqrt(
      io_data['mag_err']**2
    + mag_err_floor[None,:]**2
)

# Filter out magnitudes with err > 0.2
idx = (io_data['mag_err'] > 0.2)
io_data['mag'][idx] = np.nan
io_data['mag_err'][idx] = np.nan

Correlation matrices:
[[168.9941  0.0000  0.0000]
 [ 0.0000  0.3836  0.0000]
 [ 0.0000  0.0000  0.2525]]
Correlation matrices:
[[153.8716  0.0000  0.0000]
 [ 0.0000  0.3462  0.0000]
 [ 0.0000  0.0000  0.2263]]
Correlation matrices:
[[222.7133  0.0000  0.0000]
 [ 0.0000  0.5099  0.0000]
 [ 0.0000  0.0000  0.3415]]
Correlation matrices:
[[129.7830  0.0000  0.0000]
 [ 0.0000  0.2836  0.0000]
 [ 0.0000  0.0000  0.1825]]
Correlation matrices:
[[215.3058  0.0000  0.0000]
 [ 0.0000  0.4929  0.0000]
 [ 0.0000  0.0000  0.3295]]
Correlation matrices:
[[95.2142  0.0000  0.0000]
 [ 0.0000  0.1806  0.0000]
 [ 0.0000  0.0000  0.1108]]
Correlation matrices:
[[253.4684  0.0000  0.0000]
 [ 0.0000  0.5792  0.0000]
 [ 0.0000  0.0000  0.3905]]
Correlation matrices:
[[228.5512  0.0000  0.0000]
 [ 0.0000  0.5232  0.0000]
 [ 0.0000  0.0000  0.3509]]
Correlation matrices:
[[226.5648  0.0000  0.0000]
 [ 0.0000  0.5187  0.0000]
 [ 0.0000  0.0000  0.3477]]
Correlation matrices:
[[249.8151  0.0000  0.0000]
 [ 0.0

KeyError: "Key 'w1_desi' does not exist."

In [None]:
# Apply the quality cuts / normalization to the assembled table, print a
# summary, and write the result to HDF5. The normalization constants are
# stored as attributes of the dataset so they can be recovered later.
# NOTE(review): the dataset is written under the key 'io_data', whereas the
# Green+2020 test file read earlier uses 'data' -- make sure downstream
# readers use the matching key.
io_data,(atm_param_med,atm_param_std) = finalize_data(io_data)
print_stats(io_data)

with h5py.File('/arc/home/aydanmckay/ml/network/datalargev1.h5', 'w') as f:
    dset = f.create_dataset(
        'io_data',
        data=io_data,
        chunks=True,
        compression='gzip',
        compression_opts=3
    )
    dset.attrs['atm_param_med'] = atm_param_med
    dset.attrs['atm_param_std'] = atm_param_std

In [None]:
# d['w1flux_desi']

In [None]:
# print(np.unique(d['w1flux_desi']))
# plt.hist(d['w1flux_desi'],bins=50)
# plt.show()

In [None]:
# idx = (np.isinf(d['w1_desi']))
# d['w1_desi'][idx] = np.nan
# print(np.unique(d['w1_desi']))
# plt.hist(d['w1_desi'],bins=50)
# plt.show()