Importing the required packages

In [1]:
#!/usr/bin/env python

from __future__ import print_function, division

import numpy as np
import scipy.stats
import h5py

import tensorflow as tf
from tensorflow import keras
import tensorflow.keras.backend as K
from tensorflow.python.ops import math_ops
from sklearn.model_selection import train_test_split

import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
from glob import glob
from time import time
import json

In [2]:
def get_corr_matrix(cov):
    rho = cov.copy()
    sqrt_cov_diag = np.sqrt(cov[np.diag_indices(cov.shape[0])])
    rho /= sqrt_cov_diag[:,None]
    rho /= sqrt_cov_diag[None,:]
    rho[np.diag_indices(cov.shape[0])] = sqrt_cov_diag
    return rho

In [3]:
def get_inputs_outputs(d, pretrained_model=None,
                          recalc_reddening=False,
                          rchisq_max=None,
                          return_cov_components=False):
    n_bands = 110 # Gaia (G, BP, RP), PS1 (grizy), 2MASS (JHK), unWISE (W1,W2)
    n_atm_params = 3 # (T_eff, logg, [M/H])
    
    large_err = 999.

    # Stellar spectroscopic parameters
    print('Fill in stellar atmospheric parameters ...')
    x = np.empty((d.size,3), dtype='f4')
    x[:] = d['atm_param'][:]

    x_p = np.empty((d.size,3), dtype='f4')
    x_p = d['atm_param_p'][:]
    x_p = x_p[0]
    print(x_p.shape)

    # Magnitudes
    print('Fill in stellar magnitudes ...')
    y = np.empty((d.size,n_bands), dtype='f4')
    y[:] = d['bprp'][:]

    # Covariance of y
    print('Empty covariance matrix ...')
    cov_y = np.zeros((d.size,n_bands,n_bands), dtype='f4')

    # \delta m
    print('Covariance: \delta m ...')
    for i in range(n_bands):
        cov_y[:,i,i] = d['bprperr'][:,:,i]**2

    # Replace NaN magnitudes with median (in each band).
    # Also set corresponding variances to large number.
    print('Replace NaN magnitudes ...')
    for b in range(n_bands):
        idx = (
              ~np.isfinite(y[:,b])
            | ~np.isfinite(cov_y[:,b,b])
        )
        n_bad = np.count_nonzero(idx)
        n_tot = idx.size
        y0 = np.median(y[~idx,b])
        if np.isnan(y0):
            y0 = 0.
        print(f'Band {b}: {n_bad} of {n_tot} bad. Replacing with {y0:.5f}.')
        y[idx,b] = y0
        cov_y[idx,b,b] = large_err**2.

    # Transform both y and its covariance
    B = np.identity(n_bands, dtype='f4')
    B[1:,0] = -1.
    
    print('Transform y -> B y ...')
    y = np.einsum('ij,nj->ni', B, y) # y' = B y
    print('Transform C -> B C B^T ...')
    #cov_y = np.einsum('ik,nkl,jl->nij', B, cov_y, B) # C' = B C B^T
    cov_y = np.einsum('nik,jk->nij', cov_y, B)
    cov_y = np.einsum('ik,nkj->nij', B, cov_y)
    
    if return_cov_components:
        cov_comp = {
            'delta_m': cov_y.copy()
        }
    #########################################################################3
    # Add in dM/dtheta and dR/dtheta terms
    if pretrained_model is not None:
        print('Calculate J = dM/dtheta ...')
        J_M = calc_dmag_color_dtheta(pretrained_model, x_p)
        # cov_x = d['atm_param_cov_p']
        cov_x = d['atm_param_cov_p'][0]
        print('Covariance: J C_theta J^T ...')
        print('J_M.shape =',J_M.shape)
        print('cov_x.shape =',cov_x.shape)
        print('cov_y.shape =',cov_y.shape)
        cov_y += np.einsum('nik,nkl,njl->nij', J_M, cov_x, J_M)
        
        if return_cov_components:
            cov_comp['dM/dtheta'] = np.einsum('nik,nkl,njl->nij', J_M, cov_x, J_M)
    ##########################################################################
    # print('{:d} NaN parallaxes'.format(
    #     np.count_nonzero(np.isnan(d['parallax']))
    # ))
    # err_over_plx = d['parallax_err'] / d['parallax']
    # print('Covariance: DM uncertainty term ...')
    # cov_y[:,0,0] += (5./np.log(10.) * err_over_plx)**2.
    
    if return_cov_components:
        cov_comp['dm'] = np.zeros_like(cov_y)
        # cov_comp['dm'][:,0,0] = (5./np.log(10.) * err_over_plx)**2.

    # print('Estimate DM ...')
    # dm = 10. - 5.*np.log10(d['parallax'])# + 5./np.log(10.)*dm_corr
    # y[:,0] -= dm

    # # Don't attempt to predict M_G for poor plx/err or when plx < 0
    # print('Filter out M_G for poor parallax measurements ...')
    # idx = (
    #       (err_over_plx > 0.2)
    #     | (d['parallax'] < 1.e-8)
    #     | ~np.isfinite(d['parallax'])
    #     | ~np.isfinite(d['parallax_err'])
    # )
    # n_use = idx.size - np.count_nonzero(idx)
    # print(r'Using {:d} of {:d} ({:.3f}%) of stellar parallaxes.'.format(
    #     n_use, idx.size, n_use/idx.size*100.
    # ))
    # cov_y[idx,0,0] = large_err**2
    # y[idx,0] = np.nanmedian(y[:,0])
    
    if return_cov_components:
        cov_comp['dm'][idx,0,0] = large_err**2
###################################################################3
    if pretrained_model is not None:
        # Update reddenings, based on vector R and (y_obs - y_pred).
        # Use provided reddenings as a prior.

        # First, need to calculate inv_cov_y
        print('Invert C_y matrices ...')
        inv_cov_y = np.stack([np.linalg.inv(c) for c in cov_y])

        # Predict M & R for each star based on atm. params
        M_pred = predict_M(pretrained_model, x_p)
        
        # Calculate chi^2 for each star
        # chisq = calc_chisq(M_pred+r[:,None]*R-y, inv_cov_y)
        chisq = calc_chisq(M_pred-y, inv_cov_y)
        print('chisq =', chisq)

        # Calculate d.o.f. of each star
        print('Calculate d.o.f. of each star ...')
        n_dof = np.zeros(d.size, dtype='i4')
        for k in range(n_bands):
            n_dof += (cov_y[:,k,k] < (large_err-1.)**2).astype('i4')
        #print('n_dof =', n_dof)

        # Calculate reduced chi^2 for each star
        print('Calculate chi^2/d.o.f. for each star ...')
        rchisq = chisq / (n_dof - 1.00000001)
        pct = (0., 1., 10., 50., 90., 99., 100.)
        rchisq_pct = np.percentile(rchisq[np.isfinite(rchisq)], pct)
        print('chi^2/dof percentiles:')
        for p,rc in zip(pct,rchisq_pct):
            print(rf'  {p:.0f}% : {rc:.3g}')
        idx_rchisq = (rchisq < 10.)
        print(f'<chi^2/d.o.f.> = {np.mean(rchisq[idx_rchisq]):.3g}')
        
        # Filter on reduced chi^2
        if rchisq_max is not None:
            print('Filter on chi^2/d.o.f. ...')
            idx = np.isfinite(rchisq) & (rchisq > 0.) & (rchisq < rchisq_max)
            n_filt = np.count_nonzero(~idx)
            pct_filt = 100. * n_filt / idx.size
            print(
                rf'Filtering {n_filt:d} stars ({pct_filt:.3g}%) ' +
                rf'based on chi^2/dof > {rchisq_max:.1f}'
            )
            x = x[idx]
            x_p = x_p[idx]
            # r = r[idx]
            y = y[idx]
            cov_y = cov_y[idx]
            # r_var = r_var[idx]
            rchisq = rchisq[idx]
            
            if return_cov_components:
                for key in cov_comp:
                    cov_comp[key] = cov_comp[key][idx]
############################################################################3
#     # Cholesky transform of inverse covariance: L L^T = C^(-1).
#     print('Cholesky transform of each stellar covariance matrix ...')
#     LT = np.empty_like(cov_y)
    # inv_cov_y = np.empty_like(cov_y)
    # for k,c in enumerate(cov_y):
    #     try:
#             # Inflate diagonal of cov slightly, to ensure
#             # positive-definiteness
#             c_diag = c[np.diag_indices_from(c)]
#             c[np.diag_indices_from(c)] += 1.e-4 + 1.e-3 * c_diag
            
#             inv_cov_y[k] = np.linalg.inv(c)
#             LT[k] = np.linalg.cholesky(inv_cov_y[k]).T
#             #ic = np.linalg.inv(c)
#             #LT.append(np.linalg.cholesky(ic).T)
#             #inv_cov_y.append(ic)
#         except np.linalg.LinAlgError as e:
#             rho = get_corr_matrix(c)
#             print('Offending correlation matrix:')
#             print(np.array2string(
#                 rho[:6,:6],
#                 formatter={'float_kind':lambda z:'{: >7.4f}'.format(z)}
#             ))
#             print('Offending covariance matrix:')
#             print(np.array2string(
#                 c[:6,:6],
#                 formatter={'float_kind':lambda z:'{: >9.6f}'.format(z)}
#             ))
#             print('Covariance matrix of (normed) atmospheric parameters:')
#             print(d['atm_param_cov_p'][k])
#             # if pretrained_model is not None:
#             #     print(f'Variance of r: {r_var[k]:.8f}')
            
#             # Inflate errors along the diagonal and try again
#             c_diag = c[np.diag_indices_from(c)]
#             c[np.diag_indices_from(c)] += 0.02 + 0.02 * c_diag
#             rho = get_corr_matrix(c)
#             print('Inflated correlation matrix:')
#             print(np.array2string(
#                 rho[:6,:6],
#                 formatter={'float_kind':lambda z:'{: >7.4f}'.format(z)}
#             ))
            
#             inv_cov_y[k] = np.linalg.inv(c)
#             LT[k] = np.linalg.cholesky(inv_cov_y[k]).T
#             #raise e

#     #print('Stack L^T matrices ...')
#     #LT = np.stack(LT)
#     #print('Stack C^(-1) matrices ...')
#     #inv_cov_y = np.stack(inv_cov_y)

#     # L^T y
#     print('Calculate L^T y ...')
#     LTy = np.einsum('nij,nj->ni', LT, y)

    print('Gather inputs and outputs and return ...')
    # inputs_outputs = {
    #     'x':x, 'x_p':x_p, 'r':r, 'y':y,
    #     'LT':LT, 'LTy':LTy,
    #     'cov_y':cov_y, 'inv_cov_y':inv_cov_y,
    # }
    inputs_outputs = {
        'x':x, 'x_p':x_p, 'y':y,
        'cov_y':cov_y,# 'inv_cov_y':inv_cov_y,
    }
    
    if return_cov_components:
        inputs_outputs['cov_comp'] = cov_comp
    
    # if pretrained_model is not None:
        # inputs_outputs['r_var'] = r_var
        inputs_outputs['rchisq'] = rchisq

    # Check that there are no NaNs or Infs in results
    for key in inputs_outputs:
        if isinstance(inputs_outputs[key], dict):
            continue
        if key == 'rchisq': # Infs appear when d.o.f. = 1
            continue
        if np.any(~np.isfinite(inputs_outputs[key])):
            raise ValueError(f'NaNs or Infs detected in {key}.')

    return inputs_outputs

In [4]:
def predict_M(nn_model, x_p):
    """
    Predicts (absmag0,color1,color2,...) for input
    normalized stellar parameters.

    Inputs:
        nn_model (keras.Model): Neural network model.
        x_p (np.ndarray): Normalized stellar parameters.
            Shape = (n_stars, 3).
    
    Outputs:
        M (np.ndarray): Shape = (n_stars, n_bands).
    """
    inputs = nn_model.get_layer(name='theta').input
    outputs = nn_model.get_layer(name='B_BPRP').output
    mag_color_model = keras.Model(inputs, outputs)
    M = mag_color_model.predict(x_p)
    return M

In [5]:
def save_predictions(fname, nn_model, d_test, io_test):
    M_pred = predict_M(nn_model, io_test['x_p'])
    
    with h5py.File(fname, 'w') as f:
        f.create_dataset('/data', data=d_test, chunks=True,
                         compression='gzip', compression_opts=3)
        f.create_dataset('/y_obs', data=io_test['y'], chunks=True,
                         compression='gzip', compression_opts=3)
        f.create_dataset('/cov_y', data=io_test['cov_y'], chunks=True,
                         compression='gzip', compression_opts=3)
        # f.create_dataset('/r_fit', data=io_test['r'], chunks=True,
        #                  compression='gzip', compression_opts=3)
        f.create_dataset('/M_pred', data=M_pred, chunks=True,
                         compression='gzip', compression_opts=3)
        # f.create_dataset('/R_pred', data=R_pred, chunks=True,
        #                 compression='gzip', compression_opts=3)
        f.attrs['R0'] = R0
        
        if 'cov_comp' in io_test:
            for key in io_test['cov_comp']:
                f.create_dataset(
                    f'/cov_comp/{key.replace(r"/","_")}',
                    data=io_test['cov_comp'][key],
                    chunks=True,
                    compression='gzip',
                    compression_opts=3
                )

In [6]:
def calc_chisq(dy, inv_cov_y):
    """
    Returns the chi^2 for each observation, given
    an array of residuals and inverse covariance matrices.
    
        chi^2 = dy^T C^{-1} dy.
    
    Inputs:
        dy (np.ndarray): Residual values. Shape = (n_obs, n_dim),
            where n_obs is the number of observations, and n_dim is
            the dimensionality of the vector space.
        inv_cov_y (np.ndarray): Inverse covariance matrices.
            Shape = (n_obs, n_dim, n_dim).
    
    Returns:
        chisq (np.ndarray): Chi^2 for each observation. Shape=(n_obs,).
    """
    C_inv_dy = np.einsum('nij,nj->ni', inv_cov_y, dy)
    chisq = np.einsum('ni,ni->n', dy, C_inv_dy)
    return chisq

##### Creates NN model

In [7]:
def get_nn_model():
    # Stellar model: B M(theta)
    atm = keras.Input(shape=(3,), name='theta')
    x = atm
    x = keras.layers.Dense(
        32,
        activation='gelu',
        kernel_regularizer=keras.regularizers.l2(l=1.e-4),
        name=f'stellar_bprp_model_hidden_0'
        )(x)
    x = keras.layers.Dense(
        64,
        activation='gelu',
        kernel_regularizer=keras.regularizers.l2(l=1.e-4),
        name=f'stellar_bprp_model_hidden_1'
        )(x)
    x = keras.layers.Dense(
        128,
        activation='gelu',
        kernel_regularizer=keras.regularizers.l2(l=1.e-4),
        name=f'stellar_bprp_model_hidden_2',
        )(x)
    spec_xp = keras.layers.Dense(110, name='B_BPRP')(x)

    # Compile model
    model = keras.Model(
        inputs=atm,
        outputs=spec_xp,
        name='stellar_BPRP_model'
    )
    model.compile(
        loss='mse',
        # loss='mae',
        optimizer='Adam',
        metrics=['mse']
        # metrics=['mae']
    )

    return model

def save_theta_norm(d_attrs, fname):
    d = {
        'theta_med': d_attrs['atm_param_med'].tolist(),
        'theta_std': d_attrs['atm_param_std'].tolist()
    }
    with open(fname, 'w') as f:
        json.dump(d, f)

In [8]:
def train_model(nn_model, io_train, k, n_iterations, epochs=100,
                checkpoint_fn='checkpoint', batch_size=32,
                suff='_'):
    checkpoint_fn = (
          'checkpoints/'
        + checkpoint_fn
        + '.e{epoch:03d}_vl{val_loss:.3f}.h5'
    )
    callbacks = [
        keras.callbacks.ModelCheckpoint(
            filepath=checkpoint_fn,
            save_best_only=True,
            monitor='val_loss',
            verbose=1
        )#, PlotLearning()
    ]
    inputs = io_train['x_p']
    outputs = io_train['y']
    # print(inputs.shape)
    # print(outputs.shape)
    nn_model.fit(
        inputs, outputs,
        epochs=epochs,
        validation_split=0.25/0.9,
        callbacks=callbacks,
        batch_size=batch_size,
        verbose=True
    )
    plt.title('Loss: Iteration {} of {}.'.format(k+1, n_iterations))
    plt.plot(range(1,epochs+1),nn_model.history.history['loss'],label='loss')
    plt.plot(range(1,epochs+1),nn_model.history.history['val_loss'],label='val_loss')
    plt.xlabel('epoch')
    plt.ylabel('loss')
    plt.legend()
    plt.grid()
    plt.savefig('/arc/home/aydanmckay/networkplots/train_val_loss'+suff+'.svg', dpi=150)
    plt.close()

In [9]:
def evaluate_model(nn_model, io_eval, batch_size=32, rchisq_max=None):
    """
    Runs the model on the given inputs and outputs, and returns the
    MSE and loss.
    
    Inputs:
        nn_model (keras.Model): The neural network model.
        io_eval (dict): A dictionary containing, among other things,
            x_p, r, LT and LTy. If rchisq_max is provided, then the
            dictionary must also contain rchisq.
        batch_size (int): Defaults to 32.
        rchisq_max (float): Stars with greater than this reduced chi^2
            will not be included in the calculation. Defaults to None.
    
    Returns:
        A list containing the MSE and loss.
    """
    inputs = io_eval['x_p']
    outputs = io_eval['y']
    
    if rchisq_max is not None:
        idx = (io_eval['rchisq'] < rchisq_max)
        inputs = inputs[idx]
        outputs = outputs[idx]
    
    loss = nn_model.evaluate(
        inputs,
        outputs,
        batch_size=batch_size,
        verbose=0
    )
    
    loss = [float(x) for x in loss] # Make JSON serializable
    return loss

In [10]:
def diagnostic_plots(nn_model, io_test, d_test, suffix=None):
    if suffix is None:
        suff = ''
    else:
        suff = '_' + suffix
    
    inputs = [
        nn_model.get_layer(name='theta').input,
        nn_model.get_layer(name='E').input
    ]
    outputs = nn_model.get_layer(name='B_M_plus_A').output
    absmag_model = keras.Model(inputs, outputs)

    # Predict y for the test dataset
    test_pred = {
        'y': absmag_model.predict([
            io_test['x_p']
        ]),
    }
    test_pred['y_resid'] = io_test['y'] - test_pred['y']

    # Read out colors, magnitudes
    g = io_test['y'][:,3] + io_test['y'][:,0]
    ri = io_test['y'][:,4] - io_test['y'][:,5]
    gr = io_test['y'][:,3] - io_test['y'][:,4]
    g_pred = test_pred['y'][:,3] + test_pred['y'][:,0]
    ri_pred = test_pred['y'][:,4] - test_pred['y'][:,5]
    gr_pred = test_pred['y'][:,3] - test_pred['y'][:,4]

    gaia_g = io_test['y'][:,0]
    bp_rp = io_test['y'][:,1] - io_test['y'][:,2]
    gaia_g_pred = test_pred['y'][:,0]
    bp_rp_pred = test_pred['y'][:,1] - test_pred['y'][:,2]

    print('g =', g)
    print('ri =', ri)
    print('gr =', gr)
    print('gaia_g =', gaia_g)
    print('bp_rp =', bp_rp)

    # Plot HRD
    params = {
        'density': (None, r'$N$', (None, None)),
        'teff': (d_test['atm_param'][:,0], r'$T_{\mathrm{eff}}$', (4000., 8000.)),
        'logg': (d_test['atm_param'][:,1], r'$\log \left( g \right)$', (0., 5.)),
        'mh': (d_test['atm_param'][:,2], r'$\left[ \mathrm{M} / \mathrm{H} \right]$', (-2.5, 0.5))
    }

    plot_spec = [
        {
            'colors': [(1,2), (4,5)],
            'mag': 0
        },
        {
            'colors': [(3,4), (4,5)],
            'mag': 0
        }
    ]

    idx_goodobs = np.isfinite(d_test['mag_err'])
    idx_goodobs &= (np.abs(io_test['cov_y'][:,0,0]) < 90.)[:,None]
    idx_goodobs = idx_goodobs.T

    def scatter_or_hexbin(ax, x, y, c, vmin, vmax, extent):
        if p == 'density':
            im = ax.hexbin(
                x, y,
                extent=extent,
                bins='log',
                rasterized=True
            )
        else:
            im = ax.scatter(
                x,
                y,
                c=c,
                edgecolors='none',
                alpha=0.1,
                vmin=vmin,
                vmax=vmax,
                rasterized=True
            )
        return im

    def get_lim(*args, **kwargs):
        expand = kwargs.get('expand', 0.4)
        expand_low = kwargs.get('expand_low', expand)
        expand_high = kwargs.get('expand_high', expand)
        pct = kwargs.get('pct', 1.)
        lim = [np.inf, -np.inf]
        for a in args:
            a0,a1 = np.nanpercentile(a, [pct, 100.-pct])
            lim[0] = min(a0, lim[0])
            lim[1] = max(a1, lim[1])
        w = lim[1] - lim[0]
        lim[0] -= expand_low * w
        lim[1] += expand_high * w
        return lim

    labels = ['G', 'BP', 'RP', 'g', 'r', 'i', 'z', 'y']

    for ps in plot_spec:
        mag_label = r'$M_{{ {} }}$'.format(labels[ps['mag']])
        mag_obs = io_test['y'][:,ps['mag']]
        mag_pred = test_pred['y'][:,ps['mag']]
        print('mag_pred:',mag_pred)

        if ps['mag'] != 0:
            mag_obs += io_test['y'][:,0]
            mag_pred += io_test['y'][:,0]

        color_labels = []
        colors_obs = []
        colors_pred = []
        idx_colors_obs = []
        E_vec = []
        for i1,i2 in ps['colors']:
            color_labels.append(r'${} - {}$'.format(labels[i1], labels[i2]))
            colors_obs.append(io_test['y'][:,i1] - io_test['y'][:,i2])
            colors_pred.append(test_pred['y'][:,i1] - test_pred['y'][:,i2])
            idx_colors_obs.append(idx_goodobs[i1] & idx_goodobs[i2])

        mag_lim = get_lim(
            mag_obs[idx_goodobs[ps['mag']]],
            pct=2.
        )[::-1]
        color_lim = [
            get_lim(c[idx_colors_obs[k]], expand_low=0.5, expand_high=0.4)
            for k,c in enumerate(colors_obs)
        ]
        
        for p in params.keys():
            c, label, (vmin,vmax) = params[p]
            
            fig = plt.figure(figsize=(14,4.5), dpi=150)
            fig.patch.set_alpha(0.)
            gs = GridSpec(
                1,4,
                width_ratios=[1,1,1,0.1],
                left=0.07, right=0.93,
                bottom=0.10, top=0.92
            )
            ax_obs = fig.add_subplot(gs[0,0], facecolor='gray')
            ax_pred = fig.add_subplot(gs[0,1], facecolor='gray')
            cax = fig.add_subplot(gs[0,3], facecolor='w')

            extent = color_lim[0] + mag_lim

            idx = (
                  idx_goodobs[ps['mag']]
                & idx_goodobs[ps['colors'][0][0]]
                & idx_goodobs[ps['colors'][0][1]]
            )
            im = scatter_or_hexbin(
                ax_obs,
                colors_obs[0][idx],
                mag_obs[idx],
                c if c is None else c[idx],
                vmin, vmax,
                extent
                #(-0.3,1.0,11.5,-2.0)
            )

            ax_obs.set_xlim(color_lim[0])
            ax_obs.set_ylim(mag_lim)
            ax_obs.set_xlabel(color_labels[0])
            ax_obs.set_ylabel(mag_label)
            ax_obs.grid('on', alpha=0.3)
            ax_obs.set_title(r'$\mathrm{Observed}$')

            im = scatter_or_hexbin(
                ax_pred,
                colors_pred[0],
                mag_pred,
                c,
                vmin, vmax,
                extent
            )

            ax_pred.set_xlim(color_lim[0])
            ax_pred.set_ylim(mag_lim)
            ax_pred.set_xlabel(color_labels[0])
            ax_pred.grid('on', alpha=0.3)
            ax_pred.set_title(r'$\mathrm{Predicted}$')

            cb = fig.colorbar(im, cax=cax)
            cb.set_label(label)
            cb.set_alpha(1.)
            cb.draw_all()

            cm_desc = '{}_vs_{}{}'.format(
                labels[ps['mag']],
                labels[ps['colors'][0][0]],
                labels[ps['colors'][0][1]]
            )
            fig.savefig(
                '/arc/home/aydanmckay/networkplots/nn_predictions_'+cm_desc+'_'+p+suff+'.svg',
                dpi=150,
                facecolor=fig.get_facecolor(),
                edgecolor='none'
            )
            plt.close(fig)

            # Color-color diagrams
            fig = plt.figure(figsize=(14,4.5), dpi=150)
            fig.patch.set_alpha(0.)
            gs = GridSpec(
                1,4,
                width_ratios=[1,1,1,0.1],
                left=0.07, right=0.93,
                bottom=0.10, top=0.92
            )
            ax_obs = fig.add_subplot(gs[0,0], facecolor='gray')
            ax_pred = fig.add_subplot(gs[0,1], facecolor='gray')
            cax = fig.add_subplot(gs[0,3], facecolor='w')

            extent = color_lim[0] + color_lim[1]

            idx = (
                  idx_goodobs[ps['colors'][0][0]]
                & idx_goodobs[ps['colors'][0][1]]
                & idx_goodobs[ps['colors'][1][0]]
                & idx_goodobs[ps['colors'][1][1]]
            )

            im = scatter_or_hexbin(
                ax_obs,
                colors_obs[0][idx],
                colors_obs[1][idx],
                c if c is None else c[idx],
                vmin, vmax,
                extent
                #(-0.2,1.5,-0.15,0.8)
            )
            ax_obs.set_xlim(color_lim[0])
            ax_obs.set_ylim(color_lim[1])
            ax_obs.set_xlabel(color_labels[0], fontsize=14)
            ax_obs.set_ylabel(color_labels[1], fontsize=14)
            ax_obs.grid('on', alpha=0.3)
            ax_obs.set_title(r'$\mathrm{Observed}$')

            im = scatter_or_hexbin(
                ax_pred,
                colors_pred[0][idx],
                colors_pred[1][idx],
                c if c is None else c[idx],
                #c,
                vmin, vmax,
                extent
            )
            ax_pred.set_xlim(color_lim[0])
            ax_pred.set_ylim(color_lim[1])
            ax_pred.set_xlabel(color_labels[0], fontsize=14)
            ax_pred.grid('on', alpha=0.3)
            ax_pred.set_title(r'$\mathrm{Predicted}$')

            cb = fig.colorbar(im, cax=cax)
            cb.set_label(label, fontsize=14)
            cb.set_alpha(1.)
            cb.draw_all()

            cc_desc = '{}{}_vs_{}{}'.format(
                labels[ps['colors'][0][0]],
                labels[ps['colors'][0][1]],
                labels[ps['colors'][1][0]],
                labels[ps['colors'][1][1]]
            )
            fig.savefig(
                '/arc/home/aydanmckay/networkplots/test_'+cc_desc+'_'+p+suff+'.svg',
                dpi=150,
                facecolor=fig.get_facecolor(),
                edgecolor='none'
            )
            plt.close(fig)
    
    # # Plot histograms of residuals
    # dr = (io_test['r'] - d_test['r'])/np.hypot(np.nanstd(d_test['r']),.01)
    # # dmag = (io_test['LTy'] - d_test['mag'])
    # # dm1,dm2,dm3,dm4,dm5,dm6,dm7,dm8,dm9,dm10,dm11,dm12,dm13 = dmag.T
    # names = ['G','(BP-G)','(RP-G)','(g-G)','(r-G)','(i-G)','(z-G)','(y-G)','(J-G)','(H-G)','(K_s-G)','(W_1-G)','(W_2-G)']
    # # ds = [dm1,dm2,dm3,dm4,dm5,dm6,dm7,dm8,dm9,dm10,dm11,dm12,dm13]
    # fig = plt.figure(figsize=(12,18))
    # ax = fig.add_subplot(5,3,1)
    # dr_mean = np.nanmean(dr)
    # dr_std = np.nanstd(dr)
    # ax.hist(dr, bins=50)
    # dr_skew = scipy.stats.moment(dr, moment=3, nan_policy='omit')
    # dr_txt = r'$\Delta E = {:+.3f} \pm {:.3f}$'.format(dr_mean, dr_std)
    # dr_skew /= (dr_std**1.5 + 1.e-5)
    # dr_txt += '\n' + r'$\tilde{{\mu}}_3 = {:+.3f}$'.format(dr_skew)
    # ax.text(0.05, 0.95, dr_txt, ha='left', va='top', transform=ax.transAxes)
    # ax.set_xlabel(r'$\Delta E \ \left( \mathrm{estimated} - \mathrm{Bayestar19} \right)$',fontsize=10)
    # for it,(io,dm,name) in enumerate(zip(io_test['LTy'].T,d_test['mag'].T,names)):
    #     dd = (io - dm)/np.hypot(np.nanstd(dm),.01)
    #     ax = fig.add_subplot(5,3,it+2)
    #     dd_mean = np.nanmean(dd)
    #     dd_std = np.nanstd(dd)
    #     ax.hist(dd, bins=50)
    #     dd_skew = scipy.stats.moment(dd, moment=3, nan_policy='omit')
    #     dd_txt = r'$\Delta '+name+r' = {:+.3f} \pm {:.3f}$'.format(dd_mean, dd_std)
    #     dd_skew /= (dd_std**1.5 + 1.e-5)
    #     dd_txt += '\n' + r'$\tilde{{\mu}}_3 = {:+.3f}$'.format(dd_skew)
    #     ax.text(0.05, 0.95, dd_txt, ha='left', va='top', transform=ax.transAxes)
    #     ax.set_xlabel(r'$\Delta '+name+r'\ \left( \mathrm{estimated} - \mathrm{observed} \right)$',fontsize=10)
    # fig.savefig('/arc/home/aydanmckay/networkplots/test_z-score_dE'+suff+'.svg', dpi=150)
    # plt.close(fig)

In [11]:
def calc_dmag_color_dtheta(nn_model, x_p):
    m = keras.Model(
        inputs=nn_model.get_layer(name='theta').input,
        outputs=nn_model.get_layer(name='B_BPRP').output
    )
    with tf.GradientTape() as g:
        x_p = tf.constant(x_p)
        g.watch(x_p)
        mag_color = m(x_p)
    J = g.batch_jacobian(mag_color, x_p).numpy()
    return J

##### start of main()

In [12]:
# Load/create neural network
nn_name = 'xp_l2_mse_lr02'
nn_model = get_nn_model()
nn_model.summary()

2022-07-05 16:36:17.597121: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE4.1 SSE4.2 AVX AVX2 AVX512F FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


Model: "stellar_BPRP_model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 theta (InputLayer)          [(None, 3)]               0         
                                                                 
 stellar_bprp_model_hidden_0  (None, 32)               128       
  (Dense)                                                        
                                                                 
 stellar_bprp_model_hidden_1  (None, 64)               2112      
  (Dense)                                                        
                                                                 
 stellar_bprp_model_hidden_2  (None, 128)              8320      
  (Dense)                                                        
                                                                 
 B_BPRP (Dense)              (None, 110)               14190     
                                                

In [13]:
# Load stellar data
print('Loading data ...')
fname = '/arc/home/aydanmckay/ml/network/datasets/dr3datasmall.h5'

print(f'Loading {fname} ...')
attrs = {}
with h5py.File(fname, 'r') as f:
    dset = f['io_data']
    for key in dset.attrs.keys():
        attrs[key] = dset.attrs[key]
    d = dset[:]
save_theta_norm(attrs, 'thetanorms/xp_theta_normalization.json')

Loading data ...
Loading /arc/home/aydanmckay/ml/network/datasets/dr3datasmall.h5 ...


In [14]:
# (training+validation) / test split
# Fix random seed (same split every run)
np.random.seed(42)
split = train_test_split(d, test_size=.1)
d_train = np.array(split[::2])
d_test = np.array(split[1::2])
np.random.shuffle(d_train)
print(f'{d_train.size: >10d} training/validation stars.')
print(f'{d_test.size: >10d} test stars.')

     90000 training/validation stars.
     10000 test stars.


In [15]:
# Iteratively update dM/dtheta contribution to uncertainties,
# reddening estimates and reduced chi^2 cut, and retrain.
n_iterations = 20

# On GPU, use large batch sizes for memory transfer efficiency
# batch_size = 1024
batch_size = 256

rchisq_max_init = 100.
rchisq_max_final = 5.
rchisq_max = np.exp(np.linspace(
    np.log(rchisq_max_init),
    np.log(rchisq_max_final),
    n_iterations-6
))
rchisq_max = [None] + rchisq_max.tolist() + 5*[rchisq_max_final]
print('chi^2/dof = {}'.format(rchisq_max))

chi^2/dof = [None, 100.00000000000004, 79.41833348134496, 63.07271692954115, 50.09130066684769, 39.78167620874025, 31.593944275926187, 25.091384024965357, 19.927159040031896, 15.825817619770502, 12.56860061341878, 9.981773149103292, 7.927357886906197, 6.295775522882865, 4.999999999999999, 5.0, 5.0, 5.0, 5.0, 5.0]


In [16]:
for k in range(0, n_iterations):
    
    '''Transform data to inputs and outputs
    On subsequent iterations, inflate errors using
    gradients dM/dtheta from trained model, and derive new
    estimates of the reddenings of the stars.'''
    t0 = time()
    io_train = get_inputs_outputs(
        d_train,
        pretrained_model=None if k == 0 else nn_model,
        recalc_reddening=True,
        rchisq_max=rchisq_max[k]
    )                                
    io_test = get_inputs_outputs(
        d_test,
        pretrained_model=None if k == 0 else nn_model,
        recalc_reddening=True,
    )                                                                        
    t1 = time()
    print(f'Time elapsed to prepare data: {t1-t0:.2f} s')
    # Set learning rate based on the iteration
    lr = 0.01 * np.exp(-0.2*k)
    # lr = 0.01*np.exp(-0.02*k)
    print('learning rate = {}'.format(K.get_value(nn_model.optimizer.lr)))
    print('setting learning rate to {}'.format(lr))
    K.set_value(nn_model.optimizer.lr, lr)

    n_hidden = 4
    # Train the model
    print('Iteration {} of {}.'.format(k+1, n_iterations))
    t0 = time()
    train_model(
        nn_model,
        io_train,
        k,
        n_iterations,
        epochs=25,
        checkpoint_fn='{:s}_{:d}hidden_it{:d}'.format(
            nn_name, n_hidden, k
        ),
        batch_size=batch_size,
        suff='{:s}_{:d}hidden_it{:d}'.format(nn_name, n_hidden, k)
    )
    t1 = time()
    print(f'Time elapsed to train: {t1-t0:.2f} s')
    nn_model.save(
        'models/{:s}_{:d}hidden_it{:d}.h5'.format(
            nn_name, n_hidden, k
        )
    )
    nn_model = keras.models.load_model(
       'models/{:s}_{:d}hidden_it{:d}.h5'.format(nn_name, n_hidden, k)
    )

    # # Plot results on test set
    # print('Diagnostic plots ...')
    # t0 = time()
    # diagnostic_plots(
    #    nn_model,
    #    io_test,
    #    d_test,
    #    #io_train,
    #    #d_train,
    #    suffix='{:s}_{:d}hidden_it{:d}'.format(nn_name, n_hidden, k)
    # )
    # t1 = time()
    # print(f'Time elapsed to make plots: {t1-t0:.2f} s')

Fill in stellar atmospheric parameters ...
(90000, 3)
Fill in stellar magnitudes ...
Empty covariance matrix ...
Covariance: \delta m ...
Replace NaN magnitudes ...
Band 0: 0 of 90000 bad. Replacing with 2612.85376.
Band 1: 0 of 90000 bad. Replacing with -188.65933.
Band 2: 0 of 90000 bad. Replacing with -18.31512.
Band 3: 0 of 90000 bad. Replacing with 6.18737.
Band 4: 0 of 90000 bad. Replacing with -10.64803.
Band 5: 0 of 90000 bad. Replacing with -2.49388.
Band 6: 0 of 90000 bad. Replacing with -4.17958.
Band 7: 0 of 90000 bad. Replacing with 1.43300.
Band 8: 0 of 90000 bad. Replacing with -2.10850.
Band 9: 0 of 90000 bad. Replacing with 0.57402.
Band 10: 0 of 90000 bad. Replacing with 0.63645.
Band 11: 0 of 90000 bad. Replacing with 0.40025.
Band 12: 0 of 90000 bad. Replacing with -0.54446.
Band 13: 0 of 90000 bad. Replacing with -0.95632.
Band 14: 0 of 90000 bad. Replacing with 0.17720.
Band 15: 0 of 90000 bad. Replacing with -0.45189.
Band 16: 0 of 90000 bad. Replacing with 0.293

In [17]:
# print('Updating covariances and reddening estimates of test dataset ...')
# t0 = time()
# io_test = get_inputs_outputs(
#     d_test,
#     pretrained_model=nn_model,
#     recalc_reddening=True
# )
# t1 = time()
# print(f'Time elapsed to update covariances and reddenings: {t1-t0:.2f} s')

In [20]:
# Evaluate performance on (train, validation and test sets)
loss = {}
for n,io_eval in (('test',io_test), ('train',io_train)):
    loss[n] = evaluate_model(
        nn_model,
        io_eval,
        batch_size=batch_size,
        rchisq_max=rchisq_max[-1]
    )
    print(f'{n} loss: {loss[n]}')
fname = 'loss/loss_{:s}_{:d}hidden_it{:d}.json'.format(
    nn_name, n_hidden, n_iterations-1
)
with open(fname, 'w') as f:
    json.dump(loss, f, indent=2, sort_keys=True)

# n_hidden = 2
fname = 'predictions/predictions_{:s}_{:d}hidden_it{:d}.h5'.format(
    nn_name, n_hidden, n_iterations-1
)
save_predictions(fname, nn_model, d_test, io_test)

KeyError: 'rchisq'

In [21]:
print('Saving covariance components for small subset of test dataset ...')
# Fix random seed (same subset every run)
np.random.seed(3)
idx = np.arange(d_test.size)
np.random.shuffle(idx)
idx = idx[:1000]
d_comp = d_test[idx]
io_comp = get_inputs_outputs(
    d_comp,
    pretrained_model=nn_model,
    recalc_reddening=True,
    return_cov_components=True
)
fname = 'predictions/predictions_{:s}_{:d}hidden_it{:d}_comp.h5'.format(
    nn_name, n_hidden, n_iterations-1
)
save_predictions(fname, nn_model, d_comp, io_comp)

Saving covariance components for small subset of test dataset ...


IndexError: index 5876 is out of bounds for axis 0 with size 1

In [22]:
print('Saving data and reddening estimates of subset of test dataset ...')
np.random.seed(5)
idx = np.arange(d_test.size)
np.random.shuffle(idx)
idx = idx[:10000]
d_small = d_test[idx]
r_fit_small = io_test['r'][idx]
r_var_small = io_test['r_var'][idx]
fname = 'test_data_small/test_data_small_{:s}_{:d}hidden.h5'.format(
    nn_name, n_hidden
)

Saving data and reddening estimates of subset of test dataset ...


IndexError: index 7054 is out of bounds for axis 0 with size 1

In [23]:
print(f'Saving subset to {fname} ...')
with h5py.File(fname, 'w') as f:
    dset = f.create_dataset(
        'data',
        data=d_small,
        chunks=True,
        compression='gzip',
        compression_opts=3
    )
    for key in d_attrs:
        dset.attrs[key] = d_attrs[key]

    # # Store updated reddening estimates
    # dset = f.create_dataset(
    #     'r_fit',
    #     data=r_fit_small,
    #     chunks=True,
    #     compression='gzip',
    #     compression_opts=3
    # )
    # dset = f.create_dataset(
    #     'r_var',
    #     data=r_var_small,
    #     chunks=True,
    #     compression='gzip',
    #     compression_opts=3
    # )

Saving subset to /arc/home/aydanmckay/ml/network/datasets/dr3datasmall.h5 ...


NameError: name 'd_small' is not defined