# load essential and useful packages

# load data, inspect and clean

In [None]:
# Columns of `Walkers` that hold our model parameters (the network features).
inds = [1, 2, 3, 4, 5, 6, 10, 11, 14]

# Each model parameter: (LaTeX label, (lower, upper) limit).
# `limits` is the range that MultiNest maps onto [0, 1].
_param_table = (
    (r"$\log_{10} f_{*,10}$", (-3, 0)),
    (r"$\alpha_*$", (-0.5, 1)),
    (r"$\log_{10} f_{\rm esc, 10}$", (-3, 0)),
    (r"$\alpha_{\rm esc}$", (-1, 0.5)),
    (r"$\log_{10}[M_{\rm turn}/{\rm M}_{\odot}]$", (8, 10)),
    (r"$t_*$", (0.01, 1)),
    (r"$\log_{10}\frac{L_{\rm X<2keV}/{\rm SFR}}{{\rm erg\ s^{-1}\ M_{\odot}^{-1}\ yr}}$", (38, 42)),
    (r"$E_0/{\rm keV}$", (0.1, 1.5)),
    (r"$\alpha_{\rm X}$", (-1, 3)),
)

# Split the table into the two arrays used throughout the notebook.
_labels, _bounds = zip(*_param_table)
parameters = np.array(_labels)  # LaTeX labels, one per feature
limits = np.array(_bounds)      # shape (9, 2): [lower, upper] per feature
del _param_table, _labels, _bounds

## load the posterior

Note again that the entire ~0.5M-sample database was built through a Bayesian inference run. Therefore, we have a posterior to compare with once the emulator is built.

MultiNest's output stores the parameters in normalized form, i.e. mapped from the ranges given by `limits` onto [0,1]. The last column gives lnL.

### let's rescale them back and visualize the distribution


Explore: we will train on the ~450k points shown above, which more or less follow the posterior. However, does it matter whether the network is trained on points drawn from a posterior or from a uniform distribution, as people normally do when they do not have a posterior?

## let's just visualize what the prediction looks like using the MAP model

Points in AllFile3.h5 are in order, i.e. the nth row in `Walkers` corresponds to the nth row in the other datasets.

### global signals

In [None]:
# Global history of the selected model (row `ML_index`, defined in an earlier cell):
# neutral fraction (top panel) and 21-cm brightness temperature (bottom panel),
# overplotted with published constraints.
# NOTE: legend entries appear in call order, so the order of the errorbar calls matters.
with h5py.File('AllFile3.h5','r') as f:
    zs = f['AveDatas'][ML_index,:,0] # redshift for global signals; this is the same for all models by construction
    xH = f['AveDatas'][ML_index,:,1] # neutral fraction
    Tb = f['AveDatas'][ML_index,:,2] # 21cm brightness temperature
    tau_e = f['TauDatas'][ML_index]  # CMB optical depth
    
fig, (axxH, axTb) = plt.subplots(2,1, figsize=(15,8), sharex=True)

# model
# (exercise: plot xH vs zs on axxH and Tb vs zs on axTb)
...

# observations
# Dark Pixels
# Each upper limit is drawn twice: once as a marker with its upper error bar,
# then again with uplims=True to add the downward upper-limit arrow.
axxH.errorbar([5.6,6.07], [0.04,0.38], yerr=[[0,0],[0.05,0.20]], fmt='o',color='k', label='McGreer+15',mfc='white',capsize=5, elinewidth=2, markeredgewidth=1,alpha=1)
axxH.errorbar([5.6,6.07], [0.04,0.38], yerr=[[0.03,0.1],[0,0]], mfc='white',uplims=True, fmt=' ',color='k',alpha=1)
axxH.errorbar([5.9], [0.06], yerr=[[0],[0.05]],mfc='white', fmt='o',color='k',capsize=5, elinewidth=2, markeredgewidth=1,alpha=1)
axxH.errorbar([5.9], [0.06], yerr=[[0.03],[0]], uplims=True, fmt=' ',color='k',alpha=1)
axxH.errorbar([5.61,5.8,5.99,6.21,6.35], [0.42, 0.53,0.67,0.53,0.69], yerr=[[0,0,0,0,0],[0.05,0.07,0.07,0.11,0.15]], markersize=10,fmt='o',color='k', label='Campo+in prep.',capsize=5, elinewidth=2, markeredgewidth=2,alpha=1)
axxH.errorbar([5.61,5.8,5.99,6.21,6.35], [0.42, 0.53,0.67,0.53,0.69], yerr=[[0.1,0.1,0.1,0.1,0.1],[0,0,0,0,0]], markersize=20,uplims=True, fmt=' ',color='k',alpha=1)

# QSO damping
axxH.errorbar([7.0], [0.70], yerr=[[0.23],[0.20]], fmt='p',color='b', label='Wang+20',markersize=8,capsize=5, elinewidth=2,mfc='white', markeredgewidth=2,alpha=1)
axxH.errorbar([7.5413], [0.56], yerr=[[0.18],[0.21]], fmt='h',color='b', label='Bañados+18',markersize=8,capsize=5, elinewidth=2,mfc='white', markeredgewidth=2,alpha=1)
axxH.errorbar([7.0851,7.5413], [0.40,0.21], yerr=[[0.19,0.19],[0.21,0.17]], fmt='*',color='b', label='Greig+17/19',markersize=10,capsize=5, elinewidth=2, mfc='white',markeredgewidth=2,alpha=1)
axxH.errorbar([7.0851,7.5413], [0.48,0.60], yerr=[[0.26,0.23],[0.26,0.20]], fmt='s',color='b', label='Davies+18',markersize=8,capsize=5, elinewidth=2,mfc='white', markeredgewidth=2,alpha=1)
axxH.errorbar([7.29], [0.49], yerr=[0.11], xerr=[0.20], fmt='8',color='b', label='Greig+in prep.',markersize=20,capsize=5, elinewidth=2, markeredgewidth=2,alpha=1)

# LAE fraction
## LF
axxH.errorbar([6.9], [0.33], yerr=[[0.1],[0.]], uplims=True, fmt='<',color='purple', label='Wold+21',markersize=8,capsize=5, elinewidth=2,  mfc='white',markeredgewidth=2,alpha=1)
axxH.errorbar([7.3], [0.5], yerr=[[0.3],[0.1]], fmt='>',color='purple', label='Inoue+18',markersize=8,capsize=5, elinewidth=2,  mfc='white',markeredgewidth=2,alpha=1)
#axxH.errorbar([5.7,6.6,7.0], [0.4,0.4,0.4], yerr=[[0.1,0.1,0.1],[0,0,0]], uplims=True, fmt='s',color='black',alpha=0.6)
axxH.errorbar([6.6,7.0,7.3], [0.08,0.28,0.83], yerr=[[0.05,0.05,0.07],[0.08,0.05,0.06]], fmt='^',color='purple', label='Morales+21',markersize=8,capsize=5, elinewidth=2,  mfc='white',markeredgewidth=2,alpha=1)
## clustering
axxH.errorbar([6.6], [0.15], yerr=[0.15], fmt='v',color='purple', label='Ouchi+18',markersize=10, capsize=5,mfc='white', elinewidth=2, markeredgewidth=2,alpha=1)

## EW
axxH.errorbar([7.0], [0.55], yerr=[[0.13],[0.11]], fmt='v',color='g', label='Whitler+20',markersize=10,capsize=5, elinewidth=2, mfc='white',markeredgewidth=2,alpha=1) #EW
# lolims=True: this point is a lower limit on xHI
axxH.errorbar([7.9], [0.76], xerr=[0.6], yerr=[[0.],[0.1]], lolims=True,fmt='<',label='Mason+19',color='g', markersize=10,capsize=5, elinewidth=2, mfc='white',markeredgewidth=2,alpha=1) #EW
axxH.errorbar([7.6], [0.88], yerr=[[0.1],[0.05]], xerr=[0.6],fmt='>',color='g', label='Hoag+19',markersize=10, capsize=5,mfc='white', elinewidth=2, markeredgewidth=2,alpha=1) #EW
axxH.errorbar([7.6], [0.36], yerr=[[0.14],[0.10]], fmt='^',color='g', label='Jung+21',markersize=10, capsize=5,mfc='white', elinewidth=2, markeredgewidth=2,alpha=1)

# CMB optical depth
# Planck18 constraint printed next to tau_e of the selected model.
# NOTE(review): labelled "MAP" although the row index is named ML_index -- confirm which is meant.
axxH.text(0.98, 0.95, r'${\rm Planck18:} \tau_e = 0.0569^{+0.0081}_{-0.0066}$'+'\n'+
          'MAP: %.4f'%tau_e,ha='right',va='top',fontsize = 15,transform=axxH.transAxes) 


## EDGES
# shade the redshift window z = 15-20 on the brightness-temperature panel
axTb.axvspan(15,20, alpha=0.3, color='k')
axTb.text(15, 15, 'EDGES (Bowman+18)',ha='left',va='top',fontsize = 15) 

axxH.legend(loc='lower right',fontsize=12,ncol=3,frameon=False)
axxH.set_ylabel(r"$\bar{x}_{\rm HI}$",fontsize=15)
axTb.set_ylabel(r"$\bar{T}_{21}/{\rm mK}$",fontsize=15)
axTb.set_xlabel('redshift', fontsize=15)
axTb.set_xlim(5,30)
axTb.set_ylim(-110,20)

### 21cm power spectra

In [None]:
# 21-cm power spectra of the selected model (row `ML_index`) at 12 redshifts,
# overplotted with the HERA Phase I upper limits.
with h5py.File('AllFile3.h5', 'r') as f:
    # wavenumbers of the 21-cm power spectra; the same for all models by construction
    ks = f['TotalPSDatas'][ML_index, :, 0]
    # remaining columns: the power spectra, one column per redshift snapshot
    # (column 1+snapshot, cf. the database cell)
    ps = f['TotalPSDatas'][ML_index, :, 1:]

# define the redshifts to plot (snapshot indices into zs[::-1])
snapshots = [61,56,51,47,39,32,25,19,13,8, 4,0]

# define the axes limits
...

fig, axsPS = plt.subplots(3,4, figsize= (15,9), sharex=True, sharey=True)
axsPS = axsPS.flatten()

# the MAP model
...

# observations: HERA Phase I upper limits. z=8 limits go on panel 8 (snapshots[8] is z=8);
# z=10 limits go on panel 6. Each limit is drawn as a marker with a +sqrt(var) error bar,
# then redrawn with uplims=True to add a downward upper-limit arrow.
for hera_z, panel in ((8, 8), (10, 6)):
    PS_limit_ks_z = np.fromfile('HERA_Phase1_Limits/PS_limit_ks_z%d.bin' % hera_z)
    PS_limit_vals_z = np.fromfile('HERA_Phase1_Limits/PS_limit_vals_z%d.bin' % hera_z)
    PS_limit_vars_z = np.fromfile('HERA_Phase1_Limits/PS_limit_vars_z%d.bin' % hera_z)
    no_err = np.zeros_like(PS_limit_vars_z)
    axsPS[panel].errorbar(PS_limit_ks_z, PS_limit_vals_z, c='black', ls='', marker='s', ms=10,
                          yerr=[no_err, PS_limit_vars_z**0.5], alpha=0.7)
    axsPS[panel].errorbar(PS_limit_ks_z, PS_limit_vals_z, c='black', ls='', marker='s', ms=10,
                          yerr=[PS_limit_vals_z*0.5, no_err], alpha=0.7, uplims=True)

# cosmetics
for ii, snapshot in enumerate(snapshots):
    ax = axsPS[ii]
    # `zs` (global-signal redshifts, previous cell) reversed is used to label the PS snapshots
    ax.text(0.01, 0.99, r'$z=%.1f$' % zs[::-1][snapshot], horizontalalignment='left', verticalalignment='top',
            transform=ax.transAxes, fontsize=15)
    # hatch the k-ranges below 0.1 and above 1 Mpc^-1
    ax.axvspan(xlim[0], 0.1, color='#e5c494', hatch='x', alpha=1)
    ax.axvspan(1, xlim[1], color='#e5c494', hatch='x', alpha=1)

    ax.grid(False)
    ax.set_xscale('log')
    ax.set_yscale('log')
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)

plt.tight_layout()
fig.subplots_adjust(hspace=0.05,wspace=0.05)

# invisible full-figure axes, used only to carry the shared x/y labels
fig.add_subplot(111, frameon=False)
# hide tick and tick label of the big axes (tick_params expects booleans; the
# string 'off' is truthy and would switch the ticks ON)
plt.tick_params(labelcolor='none', top=False, bottom=False, left=False, right=False)
plt.xticks([])
plt.yticks([])
plt.grid(False)

plt.xlabel('\n\n'+r'$k[{\rm Mpc}^{-1}]$', fontsize=15)
plt.ylabel(r'$\Delta_{21}^2[{\rm mK}^{2}]$'+'\n\n', fontsize=15)

### galaxy luminosity functions

In [None]:
# predefined LF redshifts and UV magnitudes
# (`Muvs` is the common magnitude grid every model's LF is interpolated onto)
LF_redshifts = [6,7,8,10,12,15]
Muvs = np.linspace(-30,-5,100)
fig, axsLF = plt.subplots(2,3, figsize=(15,8), sharex=True, sharey=True)
axsLF = axsLF.flatten()

with h5py.File('AllFile3.h5','r') as f:
    for qq, redshift in enumerate(LF_redshifts):

        ...
        # varing models output the number density at different UV magnitudes.
        # to eliminate having UV magnitudes also as an output,
        # we use interpolation to force all models to output number density at the same magnitudes, i.e., `Muvs`.
        ...

        axsLF[qq].plot(...,..., color='r',alpha=0.7, lw=5)

        # Bouwens observations, if available at this redshift: columns are
        # magnitude, number density, and its error bar
        fLF = 'LFs/LF_obs_Bouwens_%.6f.txt'%redshift
        if os.path.exists(fLF):
            datainput = np.loadtxt(fLF)
            axsLF[qq].errorbar(datainput[:,0], (datainput[:,1]),yerr=datainput[:,2], fmt='s',color='black',zorder=2)
        axsLF[qq].text(0.95, 0.98, "z=%d"%redshift, horizontalalignment='right',
                       verticalalignment='top', transform=axsLF[qq].transAxes, fontsize=15)

        # inverted x-axis: brighter (more negative) magnitudes on the right
        axsLF[qq].set_xlim(-8,-22)
        axsLF[qq].set_ylim(2e-5,10)
        axsLF[qq].set_yscale('log')
        # hatch magnitudes brighter than -20 (these are dropped from the emulated outputs)
        axsLF[qq].axvspan(-20, -22, color='#e5c494',hatch='x', alpha=1)


plt.tight_layout()
fig.subplots_adjust(hspace=0.05,wspace=0.05)

# invisible full-figure axes, used only to carry the shared x/y labels
fig.add_subplot(111, frameon=False)
# hide tick and tick label of the big axes (tick_params expects booleans; the
# string 'off' is truthy and would switch the ticks ON)
plt.tick_params(labelcolor='none', top=False, bottom=False, left=False, right=False)
plt.xticks([])
plt.yticks([])
plt.grid(False)

plt.xlabel('\n\n'+r'$M_{\rm 1500}$', fontsize=15)
plt.ylabel(r'$\phi$'+'\n\n', fontsize=15)

# Having inspected and understood the database, we now begin training the network

## a few more data preparation

We have defined the features of our network. Let's normalize them from the ranges given by `limits` to [0,1], so that all features share the same dynamic range.

In [None]:
# Map every feature column from its [lower, upper] range in `limits` onto [0, 1].
features = (features - limits[:, 0]) / np.diff(limits, axis=1)[:, 0]

Pick the outputs that we would like to emulate with the network.

Since the emulator is meant for quick Bayesian inference, we emulate only the quantities that enter the likelihoods.

These are xHI at z=5.9 (vs McGreer+15, 1D); 

CMB optical depth (vs Planck+18, 1D); 

21-cm PS at z=8 (19D) and z=10 (18D); and 

UV LFs at z=6 (15D); let's ignore z=7(5D), 8(4D) and 10(3D) for speed.

In total, the outputs are 1 + 1 + 19 + 18 + 15 = 54D.

In [None]:
if os.path.exists('database.npy'):
    # reuse the outputs cached by a previous run
    outputs = np.load('database.npy')
else:
    # one row per model (same row order as `features`), 54 emulated quantities per row
    outputs = np.zeros([len(features), 54])

    # running column offset into `outputs`
    current_index = 0
    with h5py.File('AllFile3.h5','r') as f:

        # neutral fraction
        outputs[:,current_index] = f['AveDatas'][:,-1,1] # by design, the last entry is for z=5.9
        current_index+=1

        # CMB optical depth
        outputs[:,current_index] = ...
        current_index+=...

        #z=8 21cm PS, still need to interpolate to get the power at the observed wavenumbers
        PS_limit_ks_z = np.fromfile('HERA_Phase1_Limits/PS_limit_ks_z8.bin')
        # log10 of zero / negative power yields -inf / NaN; these are handled right below
        with np.errstate(divide='ignore', invalid='ignore'):
            ps = np.log10(f['TotalPSDatas'][:,:,1+snapshots[8]]) # zs[::-1][snapshots[8]] = 8
        # network normally cannot deal with NaN/inf, the simplest way is to replace them with zeros.
        # You can also mask those points out or resample those invalid points from some distribution.
        # NOTE: with default arguments np.nan_to_num only maps NaN to 0 and turns -inf into
        # -1.8e308, so the infinities must be mapped to zero explicitly.
        ps = np.nan_to_num(ps, nan=0.0, posinf=0.0, neginf=0.0)
        # interpolate all models in one call along the wavenumber axis; `ks` (power-spectrum
        # cell above) is the same for every model by construction
        outputs[:, current_index:current_index+len(PS_limit_ks_z)] = interp1d(
            ks, ps, axis=1, fill_value="extrapolate")(PS_limit_ks_z)
        current_index+=len(PS_limit_ks_z)

        #z=10 21cm PS, still need to interpolate to get the power at the observed wavenumbers
        # (replace invalid values with the same nan/posinf/neginf arguments as for z=8)
        PS_limit_ks_z = ...
        ps = ...
        ps = ...
        outputs[:, current_index:current_index+len(PS_limit_ks_z)] = ...
        current_index+=...

        # UV LF, still need to interpolate to get the number density at the observed UV magnitudes
        redshift=6
        fLF = 'LFs/LF_obs_Bouwens_%.6f.txt'%redshift
        observed_muv = np.loadtxt(fLF, usecols=0)
        observed_muv = observed_muv[observed_muv>-20]

        # each model has its own magnitude grid, so this interpolation stays per model
        results = f['LFDatas_%d'%redshift]
        for ii in range(len(features)):
            outputs[ii, current_index:current_index+len(observed_muv)] = ...

        current_index+=...

    # every output column must have been filled exactly once
    if current_index != outputs.shape[1]:
        raise ValueError('filled %s of %d output columns' % (current_index, outputs.shape[1]))

    # cache the result so the branch above can skip this expensive step next time
    np.save('database.npy', outputs)

## Split into training set, validation set, test set


In [None]:
# fractions of the database used for training / validation; the rest is the test set
f_train = ...
f_valid = ...
f_test = 1 - f_valid - f_train

...

N_train = ...
N_valid = ...
N_test = ...

print('Training set size: ', N_train)
print('Validation set size: ', N_valid)
print('Test set size: ', N_test)

# Contiguous, non-overlapping slices: [train | valid | test].
# The test set is indexed from the end of the validation set rather than with
# `[-N_test:]`, which returns the WHOLE array when N_test == 0 and overlaps the
# validation set whenever N_train + N_valid + N_test exceeds the number of rows.
X_train = features[:N_train]
X_valid = features[N_train:N_train+N_valid]
X_test = features[N_train+N_valid:N_train+N_valid+N_test]

Y_train = outputs[:N_train]
Y_valid = outputs[N_train:N_train+N_valid]
Y_test = outputs[N_train+N_valid:N_train+N_valid+N_test]


# This is the size of a single training batch (training is not done on the entire training set at once, but on batches of data)
# You can play with this number but it shouldn't make a very big difference as long as it is ~ 100
# If it is too big / small, the network will not learn as well
batch_size = 64

# convert the database into tensorflow format
x_train = tf.data.Dataset.from_tensor_slices(X_train)
x_val = tf.data.Dataset.from_tensor_slices(X_valid)
y_train = tf.data.Dataset.from_tensor_slices(Y_train)
y_val = tf.data.Dataset.from_tensor_slices(Y_valid)
# pair features with targets, shuffle with a buffer as large as the set, then batch
training_data = tf.data.Dataset.zip((x_train, y_train)).shuffle(X_train.shape[0]).batch(batch_size)
validation_data = tf.data.Dataset.zip((x_val, y_val)).shuffle(X_valid.shape[0]).batch(batch_size)

In [None]:
# show the structure (element spec) of the batched training dataset
print(f"{training_data}")

## some callback functions to improve the training

In [None]:
# Callbacks for `model.fit` (exercise: complete both entries).
# Possible choices, to be tuned: tf.keras.callbacks.ReduceLROnPlateau(monitor=..., patience=20, ...)
# and tf.keras.callbacks.EarlyStopping(monitor=..., patience=50, ...).
callbacks = [
    tf.keras.callbacks..., # If loss does not improve for 20 epochs, reduce learning rate
    tf.keras.callbacks... # If loss does not improve for 50 epochs, stop the learning
]

## build the network architecture

In [None]:
# Input size = number of model parameters (the 9 feature columns selected by `inds`)
input_layer = ...

# Hidden fully-connected layers
# Number of nodes and hidden layers is arbitrary
output = ...
# Batch normalization normalizes the previous layer's activations using statistics
# of each training batch (it does not normalize the weights).
# It usually stabilises training and can let it converge faster.
# You can try removing it and compare; the notebook author found the results a bit worse without it.
output = ...
...
# Last layer output shape = number of emulated outputs = 54 (= Y_train.shape[1])
output = ...
model = ...
model.summary()

# define optimizer
opt = ...
# Use mean squared error for the loss function (compile the model with `opt` and an MSE loss)
...

## starting training

In [None]:
# Train on `training_data`, monitor `validation_data`, and pass `callbacks`;
# the returned History object is plotted in the next cell.
history = model.fit(...)

## check how the training goes

In [None]:
# Training diagnostics: loss curves (top) and learning rate (bottom) versus epoch.
fig, axs = plt.subplots(2,1, figsize=(10,8), sharex=True)

# presumably training and validation loss from history.history; give each a `label=`
# so that the legend below has entries
axs[0].plot(...)
axs[0].plot(...)
axs[0].set_ylabel('MSE Loss', fontsize=15)
axs[0].legend(loc='upper right')
# NOTE(review): the learning-rate key in history.history depends on the Keras version
# (e.g. 'lr' or 'learning_rate') -- check history.history.keys()
axs[1].plot(...)
axs[1].set_ylabel('Learning Rate', fontsize=15)
axs[1].set_xlabel('Epoch', fontsize=15)

plt.tight_layout()
fig.subplots_adjust(hspace=0.05,wspace=0.05)

## Test the emulator on the test set, i.e. data that the network has never seen


In [None]:
# Emulator output for the held-out features X_test; presumably the same shape as Y_test.
prediction = ...

#### let's first take a few random test samples and look at their 21cm PS and galaxy UV LFs

In [None]:
# number of test samples to show (at most len(colors))
Ntest_show = 5
current_index = 2 # skip the first two 1D datasets (xHI at z=5.9 and tau_e)
colors         = ['#984ea3','#ff7f00','#fec44f','#a65628','#f781bf']

# 21-cm PS of a few test samples at the two HERA redshifts; presumably solid = true
# values and dash-dot = emulator prediction (fill in the `...`).
fig, axs = plt.subplots(2,1, figsize=(10,8), sharex=True)
for ii in range(Ntest_show):

    #z=8 21cm PS: output columns current_index .. current_index+len(PS_limit_ks_z)-1
    PS_limit_ks_z = np.fromfile('HERA_Phase1_Limits/PS_limit_ks_z8.bin')
    axs[0].plot(PS_limit_ks_z, ..., color=colors[ii])
    axs[0].plot(PS_limit_ks_z, ..., color=colors[ii], ls='-.')
    current_index+=len(PS_limit_ks_z)

    #z=10 21cm PS: the columns right after the z=8 block
    PS_limit_ks_z = np.fromfile('HERA_Phase1_Limits/PS_limit_ks_z10.bin')
    axs[1].plot(PS_limit_ks_z, ..., color=colors[ii])
    axs[1].plot(PS_limit_ks_z, ..., color=colors[ii], ls='-.')
    if ii < Ntest_show-1:
        # rewind to the first PS column for the next test sample
        current_index=2
    else:
        # after the last sample, leave current_index at the first UV-LF column
        # (2 + both PS blocks), presumably used by the LF panels below
        current_index+=len(PS_limit_ks_z)

# label each panel once, outside the sample loop
axs[0].text(0.01, 0.99, r'$z=8$',ha='left',va='top',transform=axs[0].transAxes,fontsize = 15)
axs[1].text(0.01, 0.99, r'$z=10$',ha='left',va='top',transform=axs[1].transAxes,fontsize = 15)

plt.tight_layout()
fig.subplots_adjust(hspace=0.05,wspace=0.05)

# invisible full-figure axes, used only to carry the shared x/y labels
fig.add_subplot(111, frameon=False)
# hide tick and tick label of the big axes (tick_params expects booleans; the
# string 'off' is truthy and would switch the ticks ON)
plt.tick_params(labelcolor='none', top=False, bottom=False, left=False, right=False)
plt.xticks([])
plt.yticks([])
plt.grid(False)

# NOTE(review): the emulated PS outputs are log10(Delta^2) -- adjust this label if you plot them directly
plt.xlabel('\n\n'+r'$k[{\rm Mpc}^{-1}]$', fontsize=15)
plt.ylabel(r'$\Delta_{21}^2[{\rm mK}^{2}]$'+'\n\n', fontsize=15)

# one panel per test sample
fig, axs = plt.subplots(Ntest_show,1, figsize=(5,15),sharex=True, sharey=True)
# UV LF @ z=6
redshift=6
fLF = 'LFs/LF_obs_Bouwens_%.6f.txt'%redshift
# observed magnitudes, with the same M > -20 cut used when building `outputs`;
# loop-invariant, so read once
observed_muv = np.loadtxt(fLF, usecols=0)
observed_muv = observed_muv[observed_muv>-20]
for ii in range(Ntest_show):
    axs[ii].plot(observed_muv, ..., color=colors[ii])
    axs[ii].plot(observed_muv, ..., color=colors[ii], ls='-.')
    axs[ii].text(0.01, 0.99, r'$z=6$',ha='left',va='top',transform=axs[ii].transAxes,fontsize = 15)

plt.tight_layout()
fig.subplots_adjust(hspace=0.05,wspace=0.05)

# invisible full-figure axes, used only to carry the shared x/y labels
fig.add_subplot(111, frameon=False)
# hide tick and tick label of the big axes
plt.tick_params(labelcolor='none', top=False, bottom=False, left=False, right=False)
plt.xticks([])
plt.yticks([])
plt.grid(False)

plt.xlabel('\n\n'+r'$M_{\rm 1500}$', fontsize=15)
plt.ylabel(r'$\phi$'+'\n\n', fontsize=15)

### let's check all 54D outputs

In [None]:
# one histogram panel per emulated output (9 x 6 = 54 outputs)
fig, axs = plt.subplots(9,6, figsize=(15,10), sharex=True)
axs = axs.flatten()

# fractional error (Y_pred - Y_true)/Y_true (see the x-label), indexed as
# frac_err[output] -> errors over the test set.
# NOTE(review): Y_true can be exactly 0 (invalid log10 PS values were replaced by zeros),
# where this ratio diverges -- consider masking those entries.
frac_err = ...
for ii in range(Y_test.shape[1]):
    axs[ii].hist(frac_err[ii], bins=np.linspace(-1,1,100), color='k')
    axs[ii].yaxis.set_tick_params(labelsize=0)

plt.tight_layout()
fig.subplots_adjust(hspace=0.0,wspace=0.0)

# invisible full-figure axes, used only to carry the shared x/y labels
fig.add_subplot(111, frameon=False)
# hide tick and tick label of the big axes (tick_params expects booleans; the
# string 'off' is truthy and would switch the ticks ON)
plt.tick_params(labelcolor='none', top=False, bottom=False, left=False, right=False)
plt.xticks([])
plt.yticks([])
plt.grid(False)

plt.xlabel('\n\n'+r'$({Y_{pred} - Y_{true}})/{Y_{true}}$', fontsize = 30)
plt.ylabel('Histogram', fontsize = 20)

# Caution / thoughts

How do we improve the network in regions of parameter space where its performance is poor?

# save the model for inference

In [None]:
# Exercise: save the trained network (e.g. with model.save(...)) so it can be reloaded for inference.
...