https://github.com/changhoonhahn/SEDflow/blob/1482d3a685aa73927f89eee965c52ec0a6ae64a9/nb/compile_nsa_sedflow.ipynb

# Compile SEDflow posteriors and MCMC posteriors for NSA galaxies

In [1]:
# Standard-library and third-party dependencies.
import os
import h5py
import numpy as np
from tqdm.notebook import tqdm, trange

# Make the local SEDflow checkout importable (machine-specific path); this
# must run before the `sedflow` imports below.
import sys
sys.path.append('/media/SSD/Doktori/Csillagtomegbecsles_cikk/SED/SEDflow/SEDflow/src/')
from sedflow import obs as Obs
from sedflow import train as Train

# SPS_HOME is set before provabgs is imported -- presumably because its
# stellar-population backend reads it at import time (TODO confirm).
# Keep this ordering.
os.environ["SPS_HOME"] = "/media/SSD/Doktori/Csillagtomegbecsles_cikk/SED/SEDflow/fsps"

from provabgs import infer as Infer
from provabgs import models as Models

In [2]:
# Identifiers of the trained ANPE model whose posterior samples are compiled;
# all four are interpolated into the sample file names read further down.
sample, itrain = 'toy', 2
nhidden, nblocks = 500, 15

In [3]:
dat_dir = '/media/SSD/Doktori/Csillagtomegbecsles_cikk/SED/SEDflow/data/'

# Read in NSA data with clean photometry. The context manager closes the HDF5
# file even if reading one of the datasets raises.
nsa = {}
with h5py.File(os.path.join(dat_dir, 'nsa.photometry.hdf5'), 'r') as f:
    for k in f.keys():
        nsa[k] = f[k][...]

# Stack the per-band columns in a fixed u, g, r, i, z order.
bands = ['u', 'g', 'r', 'i', 'z']
mags_nsa = np.array([nsa['mag_%s' % band] for band in bands])
sigs_nsa = np.array([nsa['sigma_%s' % band] for band in bands])
zred_nsa = nsa['redshift']

# From here on `nsa` is no longer the dict of raw datasets but a 2-D array
# with one row per galaxy and 11 columns: the 5 `mag_*` values, the 5
# `sigma_*` values, and the redshift.
nsa = np.concatenate([mags_nsa.T, sigs_nsa.T, zred_nsa.reshape(-1, 1)], axis=1)
print( 'Number of galaxies: ', len( nsa ) )

Number of galaxies:  33884


In [4]:
# Default prior object; its `transform` method is applied to the raw ANPE
# samples when the posteriors are compiled below.
prior = Train.prior_default()

# sps model (built with emulator=True); its avgSFR and Z_MW methods are used
# below to derive galaxy properties from the posterior samples
m_sps= Train.SPSmodel_default(emulator=True)

input parameters : logmstar, beta1_sfh, beta2_sfh, beta3_sfh, beta4_sfh, fburst, tburst, gamma1_zh, gamma2_zh, dust1, dust2, dust_index


In [5]:
# Compile, chunk by chunk, the ANPE posterior samples of the NSA galaxies
# (falling back to MCMC chains for galaxies where ANPE produced no samples)
# together with the galaxy properties derived from those samples.
#
# Results (one row per galaxy with finite magnitudes, in catalogue order):
#   posts   : (ngal, nsample, 12) posterior samples
#   fails   : (ngal,) True where the ANPE samples were all-zero
#   logsfrs : (ngal, nsample) log10 of m_sps.avgSFR(..., dt=1.)
#   logzmws : (ngal, nsample) log10 of m_sps.Z_MW(...)
n_chunks = 8        # number of sample files compiled (file names say "of34")
chunk_size = 1000   # galaxies per sample file
n_param = 12        # parameters per posterior sample after prior.transform

nsa_arr = np.array(nsa)
# keep only galaxies whose first 5 columns (the magnitudes) are all finite
finite = np.all(np.isfinite(nsa_arr[:,:5]), axis=1)
igal_all = np.arange(nsa_arr.shape[0])

# compile
posts, fails = [], []
logsfrs, logzmws = [], []
for ichunk in range(n_chunks):
    chunk = slice(ichunk*chunk_size, (ichunk+1)*chunk_size)
    finite_chunk = finite[chunk]

    fpost = os.path.join(Train.data_dir(), 'anpe_thetaunt_magsigz.%s.%ix%i.%i.nsa%iof34.samples.npy' % (sample, nhidden, nblocks, itrain, ichunk))
    print(fpost)
    _post = np.load(fpost)[finite_chunk,:,:]

    # a galaxy whose samples sum to exactly zero is treated as an ANPE failure
    fail = (np.sum(np.sum(_post, axis=2), axis=1) == 0)

    igals = igal_all[chunk][finite_chunk][fail]   # row index into `nsa`
    iigals = np.arange(_post.shape[0])[fail]      # row index into `post`

    post = np.zeros((_post.shape[0], _post.shape[1], n_param))
    post[~fail,:] = prior.transform(_post[~fail])

    # replace failed galaxies with their MCMC posterior when a chain exists
    missing = []
    for iigal, igal in zip(iigals, igals):
        fgal = os.path.join(Train.data_dir(), 'nsa_fail', 'mcmc.nsa.%i.hdf5' % igal)
        if not os.path.isfile(fgal):
            missing.append(igal)
            continue
        # context manager so the chain file is always closed
        with h5py.File(fgal, 'r') as gal:
            chain = gal['mcmc_chain'][...]
        # drop the first 2000 entries along axis 0 (presumably burn-in --
        # TODO confirm), flatten, and keep the last 10000 samples
        post[iigal,:,:] = Train.flatten_chain(chain[2000:,:,:])[-10000:,:]
    if missing:
        # these rows stay all-zero in `post`; say so instead of failing silently
        print('WARNING: chunk %i: no MCMC chain file for %i failed galaxies; '
              'their posteriors are left as zeros' % (ichunk, len(missing)))
    posts.append(post)
    fails.append(fail)

    # derived galaxy properties
    z_chunk = nsa_arr[chunk, -1][finite_chunk]

    logsfr = np.zeros((post.shape[0], post.shape[1]))
    logzmw = np.zeros((post.shape[0], post.shape[1]))
    for ii in trange(post.shape[0]):
        thetas_sps = post[ii,:]
        zred = float(z_chunk[ii])
        logsfr[ii,:] = np.log10(np.array(m_sps.avgSFR(thetas_sps, zred=zred, dt=1.)))
        logzmw[ii,:] = np.log10(np.array(m_sps.Z_MW(thetas_sps, zred=zred))).flatten()
    logsfrs.append(logsfr)
    logzmws.append(logzmw)

posts = np.concatenate(posts)
fails = np.concatenate(fails)
logsfrs = np.concatenate(logsfrs)
logzmws = np.concatenate(logzmws)

/media/SSD/Doktori/Csillagtomegbecsles_cikk/SED/SEDflow/my_data/anpe_thetaunt_magsigz.toy.500x15.2.nsa0of34.samples.npy


  0%|          | 0/1000 [00:00<?, ?it/s]

/media/SSD/Doktori/Csillagtomegbecsles_cikk/SED/SEDflow/my_data/anpe_thetaunt_magsigz.toy.500x15.2.nsa1of34.samples.npy


  0%|          | 0/1000 [00:00<?, ?it/s]

/media/SSD/Doktori/Csillagtomegbecsles_cikk/SED/SEDflow/my_data/anpe_thetaunt_magsigz.toy.500x15.2.nsa2of34.samples.npy


  0%|          | 0/1000 [00:00<?, ?it/s]

/media/SSD/Doktori/Csillagtomegbecsles_cikk/SED/SEDflow/my_data/anpe_thetaunt_magsigz.toy.500x15.2.nsa3of34.samples.npy


  0%|          | 0/1000 [00:00<?, ?it/s]

/media/SSD/Doktori/Csillagtomegbecsles_cikk/SED/SEDflow/my_data/anpe_thetaunt_magsigz.toy.500x15.2.nsa4of34.samples.npy


  0%|          | 0/1000 [00:00<?, ?it/s]

/media/SSD/Doktori/Csillagtomegbecsles_cikk/SED/SEDflow/my_data/anpe_thetaunt_magsigz.toy.500x15.2.nsa5of34.samples.npy


  0%|          | 0/1000 [00:00<?, ?it/s]

/media/SSD/Doktori/Csillagtomegbecsles_cikk/SED/SEDflow/my_data/anpe_thetaunt_magsigz.toy.500x15.2.nsa6of34.samples.npy


  0%|          | 0/1000 [00:00<?, ?it/s]

/media/SSD/Doktori/Csillagtomegbecsles_cikk/SED/SEDflow/my_data/anpe_thetaunt_magsigz.toy.500x15.2.nsa7of34.samples.npy


  0%|          | 0/1000 [00:00<?, ?it/s]

In [23]:
# write out
#
# Only the first `n_processed` catalogue rows were compiled above (8 chunks of
# 1000 galaxies), so every per-galaxy dataset is restricted to the finite
# galaxies among those rows. This keeps the observables row-aligned with
# NSAID, the posteriors and the derived properties.
n_processed = 8000
finite_done = finite[:n_processed]

# Check before the output file is opened (mode 'w' truncates it), so that a
# mismatch cannot destroy an existing file.
if posts.shape[0] != np.sum(finite_done):
    raise ValueError('compiled %i posteriors but %i finite galaxies among the first %i rows'
                     % (posts.shape[0], np.sum(finite_done), n_processed))

obs = ['mag_u', 'mag_g', 'mag_r', 'mag_i', 'mag_z', 'sigma_u', 'sigma_g', 'sigma_r', 'sigma_i', 'sigma_z', 'redshift']
params = ['log_mstar', 'beta1', 'beta2', 'beta3', 'beta4', 'fburst', 'tburst', 'log_gamma1', 'log_gamma2', 'tau_bc', 'tau_ism', 'n_dust']

# context manager so the file is flushed and closed even if a write fails
with h5py.File(os.path.join(Train.data_dir(), 'nsa.sedflow.v0.2.hdf5'), 'w') as fsedflow:
    # NOTE: 'NSAID' stores the row index into the `nsa` array, not an
    # identifier read from the catalogue.
    nsaid = np.arange(n_processed)
    fsedflow.create_dataset('NSAID', data=nsaid[finite_done])

    # observables: columns of `nsa` in the same order as `obs`
    for i, o in enumerate(obs):
        fsedflow.create_dataset(o, data=nsa[:n_processed][finite_done, i])

    # posterior samples, one dataset per parameter
    for i, param in enumerate(params):
        fsedflow.create_dataset(param, data=posts[:,:,i].astype(np.float32))

    # True where the posterior comes from ANPE rather than the fallback
    fsedflow.create_dataset('sedflow', data=~fails)

    # write out properties
    fsedflow.create_dataset('log_sfr_1gyr', data=logsfrs.astype(np.float32))
    fsedflow.create_dataset('log_z_mw', data=logzmws.astype(np.float32))

In [25]:
# Read the compiled file from Train.data_dir() back into memory, one array per
# dataset. The context manager closes the file even if a read raises.
nsa_comp = {}
with h5py.File(os.path.join(Train.data_dir(), 'nsa.sedflow.v0.2.hdf5'), 'r') as f:
    for k in f.keys():
        nsa_comp[k] = f[k][...]

In [27]:
# Read the file of the same name from `dat_dir` (presumably the original
# compilation used as reference -- TODO confirm), one array per dataset.
# The context manager closes the file even if a read raises.
nsa_ori = {}
with h5py.File(os.path.join(dat_dir, 'nsa.sedflow.v0.2.hdf5'), 'r') as f:
    for k in f.keys():
        nsa_ori[k] = f[k][...]

In [26]:
# Spot check on the newly compiled file. 'sigma_g' is written as a 1-D
# dataset, so `[2]` selects a single value and np.mean returns it unchanged.
np.mean(nsa_comp['sigma_g'][2])

0.02342823

In [28]:
# Same spot check on the file read from `dat_dir`; the value should match the
# one shown for `nsa_comp`.
# NOTE(review): this compares a single entry only -- a weak consistency test.
np.mean(nsa_ori['sigma_g'][2])

0.02342823