# In [2]:
import matplotlib.pyplot as plt 
import numpy as np 
from scipy import interpolate
import h5py
import sys
import torch 
import torch.nn as nn 
from torch.autograd import Variable 
import pickle as pk
from multiprocessing import Pool
import multiprocessing

from sklearn import preprocessing
from sklearn.decomposition import PCA
#Please download the HIILines package from https://github.com/Sheng-Qi-Yang/HIILines
from HIILines import * 

# Report whether CUDA is available and the name of the current CUDA device.
# NOTE(review): get_device_name() presumably raises when no CUDA device is present — confirm before running on a CPU-only host.
print(torch.cuda.is_available())
print(torch.cuda.get_device_name())
# CPU count as reported by multiprocessing; not referenced again in the visible code.
PoolNumber            = multiprocessing.cpu_count()

#Define the MDN architecture
class MDN(nn.Module):
    """Mixture density network: one tanh hidden layer feeding a Gaussian mixture head.

    The head predicts, for each of ``n_gaussians`` mixture components, one
    mixing weight plus a mean and a standard deviation for each of four
    output dimensions.

    Parameters
    ----------
    n_hidden : int
        Width of the hidden layer.
    n_gaussians : int
        Number of mixture components.
    n_input : int, optional
        Number of input features. When omitted, the module-level
        ``input_dim`` is looked up at construction time, which is what the
        original two-argument constructor did.

    Submodule attribute names are kept unchanged so that previously saved
    models / state dicts remain loadable.
    """

    def __init__(self, n_hidden, n_gaussians, n_input=None):
        super().__init__()
        if n_input is None:
            # Backward-compatible default: rely on the global defined by the script.
            n_input = input_dim
        # Layers are created in the original order so that seeded
        # initialisation gives the same weights as before.
        self.z_h      = nn.Sequential(nn.Linear(n_input, n_hidden), nn.Tanh())
        self.z_pi     = nn.Linear(n_hidden, n_gaussians)
        self.z_sigma1 = nn.Linear(n_hidden, n_gaussians)
        self.z_sigma2 = nn.Linear(n_hidden, n_gaussians)
        self.z_sigma3 = nn.Linear(n_hidden, n_gaussians)
        self.z_sigma4 = nn.Linear(n_hidden, n_gaussians)
        self.z_mu1    = nn.Linear(n_hidden, n_gaussians)
        self.z_mu2    = nn.Linear(n_hidden, n_gaussians)
        self.z_mu3    = nn.Linear(n_hidden, n_gaussians)
        self.z_mu4    = nn.Linear(n_hidden, n_gaussians)

    def forward(self, x):
        """Return ``(pi, sigma1..sigma4, mu1..mu4)`` for the input batch ``x``.

        ``pi`` is soft-maxed over the last axis (component weights sum to 1),
        each ``sigma`` is made strictly positive with ``exp``, and each ``mu``
        is the raw linear output.
        """
        hidden = self.z_h(x)
        pi     = nn.functional.softmax(self.z_pi(hidden), dim=-1)
        sigma1 = torch.exp(self.z_sigma1(hidden))
        sigma2 = torch.exp(self.z_sigma2(hidden))
        sigma3 = torch.exp(self.z_sigma3(hidden))
        sigma4 = torch.exp(self.z_sigma4(hidden))
        mu1    = self.z_mu1(hidden)
        mu2    = self.z_mu2(hidden)
        mu3    = self.z_mu3(hidden)
        mu4    = self.z_mu4(hidden)
        return pi, sigma1, sigma2, sigma3, sigma4, mu1, mu2, mu3, mu4

#Please download the SSP_Spectra_Conroy-et-al_v2.5_imfChabrier.hdf5 lookup table from https://zenodo.org/record/6338462#.ZBENinbMKbh
# Each table is opened in a `with` block so the HDF5 handle is closed as soon as
# the datasets have been copied into memory by the `[:]` reads. After these
# blocks `fSEDTable` refers to a closed file.
with h5py.File('SSP_Spectra_Conroy-et-al_v2.5_imfChabrier.hdf5','r') as fSEDTable:
    logage_Grid   = np.log10(fSEDTable['ages'][:])+3 #Myr
    logZstar_Grid = fSEDTable['metallicities'][:]    #logZ/[Z_sun]
#Lookup table SSP_Spectra_Conroy-et-al_v2.5_imfChabrier_Q.hdf5 stores the logQ_HI for all the FSPS single star radiation spectra
with h5py.File('SSP_Spectra_Conroy-et-al_v2.5_imfChabrier_Q.hdf5','r') as fSEDTable:
    logQHITable   = fSEDTable['logQHI'][:]
# Interpolator over (logZstar, logage); it is evaluated with points in that column order.
logQHIf       = interpolate.RegularGridInterpolator((logZstar_Grid,logage_Grid),logQHITable)

def MZR_FIRE(logMstar_tot):
    """Linear map from log10 total stellar mass to the log metallicity used by logLMAX.

    NOTE(review): the name suggests a FIRE mass-metallicity relation; the
    coefficients are taken as-is — confirm their provenance.
    """
    return 0.38840077*logMstar_tot-3.97300904

def logLMAX(X):
    """Upper limits of logLOII10, logLOIII32 and logLHbeta for stellar particles.

    Parameters
    ----------
    X : array-like, shape (N, 3)
        Columns are, in order: log age, log stellar metallicity and log total
        stellar mass of the host galaxy. Age and metallicity are clamped to
        the range covered by ``logage_Grid`` / ``logZstar_Grid`` before the
        ionising-photon table is interpolated, so they are assumed to be in
        the same units as those grids.

    Returns
    -------
    numpy.ndarray, shape (N, 3)
        Columns are logLOII10, logLOIII32 and logLHbeta.

    Notes
    -----
    ``X`` is left untouched. The previous implementation clamped through
    views of ``X`` and therefore overwrote the caller's array in place.
    """
    X            = np.asarray(X)
    # np.clip returns new arrays, so the caller's data is never modified.
    logage       = np.clip(X[:,0], logage_Grid[0], logage_Grid[-1])
    logZstar     = np.clip(X[:,1], logZstar_Grid[0], logZstar_Grid[-1])
    logMstar_tot = X[:,2]
    # The interpolator was built on (logZstar, logage), hence this column order.
    logQHI       = logQHIf(np.column_stack((logZstar, logage)))
    logZQ        = MZR_FIRE(logMstar_tot)

    # T4OIII is quadratic in logZQ up to logZQ = -0.06 and held at its
    # logZQ = -0.06 value above that.
    T4OIII_cap   = 0.824*(-0.06)**2+0.101*(-0.06)+1.08
    T4OIII       = np.where(logZQ <= -0.06, 0.824*logZQ**2+0.101*logZQ+1.08, T4OIII_cap)
    # presumably 12+log(O/H) with a solar log(O/H) of -3.31 — TODO confirm
    aTe          = -12030.22*(12+logZQ-3.31)+113720.75
    # T4OII is floored at 0.5.
    T4OII        = np.maximum(aTe**2/2/T4OIII/10000/10000, 0.5)

    logLmax      = np.zeros((len(logage),3))

    logLmax[:,0] = -3.31+logZQ+np.log10(k01_OII(T4OII)*h*nu10_OII/alphaB_HI(T4OIII)*Jps2Lsun)+logQHI        #logLOII10
    logLmax[:,1] = -3.31+logZQ+np.log10(3/4*k03_OIII(T4OIII)*h*nu32_OIII/alphaB_HI(T4OIII)*Jps2Lsun)+logQHI #logLOIII32
    logLmax[:,2] = np.log10(h)+np.log10(nu_Hbeta*alphaB_Hbeta(T4OIII)/alphaB_HI(T4OIII)*Jps2Lsun)+logQHI    #logLHbeta
    return logLmax
    
oneDivSqrtTwoPI = 1.0 / np.sqrt(2.0 * np.pi)  # 1/sqrt(2*pi), the Gaussian normalisation constant
def gaussian_distribution(y, mu, sigma):
    """Evaluate the normal density N(y; mu, sigma) element-wise.

    ``y`` is expanded to the shape of ``mu`` before the subtraction; ``mu``
    and ``sigma`` are tensors of matching shape. Returns a tensor shaped
    like ``mu``.
    """
    inv_sigma    = torch.reciprocal(sigma)
    standardized = (y.expand_as(mu) - mu) * inv_sigma
    unnormalized = torch.exp(-0.5 * (standardized * standardized))
    return unnormalized * inv_sigma * oneDivSqrtTwoPI

def gumbel_sample(x, axis=1):
    """Draw category indices from probabilities ``x`` with the Gumbel-max trick.

    Standard Gumbel noise is added to ``log(x)`` and the arg-max along
    ``axis`` is returned, so entries with zero probability are never chosen.

    Parameters
    ----------
    x : torch.Tensor
        Non-negative probabilities (any rank, any device).
    axis : int, optional
        Axis that indexes the categories. Defaults to 1.

    Returns
    -------
    torch.Tensor
        Integer indices with ``axis`` reduced away.
    """
    gs       = torch.distributions.gumbel.Gumbel(torch.tensor([0.0]), torch.tensor([1.0]))
    # sample(x.shape) appends the length-1 batch axis of `gs`; `[..., 0]` drops
    # it for inputs of any rank. The noise is still drawn on the CPU (same
    # random stream as before) and then moved to wherever `x` lives instead of
    # unconditionally to CUDA, so CPU tensors work as well.
    z        = gs.sample(x.shape)[..., 0].to(x.device)
    w        = torch.log(x)+z
    return w.argmax(dim=axis)

input_dim             = 3   # number of MDN input features (columns of X)
output_dim            = 13  # number of emission lines produced per particle
# The scalers are rebuilt from saved statistics instead of being re-fitted.
scaler_X              = preprocessing.StandardScaler()
scaler_X.mean_        = np.loadtxt('scaler_X_mean.txt')
scaler_X.scale_       = np.loadtxt('scaler_X_scale.txt')
scaler_Y              = preprocessing.StandardScaler()
scaler_Y.mean_        = np.loadtxt('scaler_Y_mean.txt')
scaler_Y.scale_       = np.loadtxt('scaler_Y_scale.txt')
# NOTE: unpickling executes code from the file; only load pca.pkl from a trusted source.
# The `with` block closes the file handle, which the bare open() call left dangling.
with open('pca.pkl','rb') as pca_file:
    pca               = pk.load(pca_file)
print('finish data pre-processing')
# float32 GPU copies of the PCA / output-scaler parameters, used to undo both
# transforms on the device. torch.autograd.Variable is a deprecated no-op
# wrapper, so plain tensors are used here.
pca_mean_cuda         = torch.from_numpy(np.float32(pca.mean_)).cuda()
pca_components_cuda   = torch.from_numpy(np.float32(pca.components_)).cuda()
scalerY_mean_cuda     = torch.from_numpy(np.float32(scaler_Y.mean_)).cuda()
scalerY_scale_cuda    = torch.from_numpy(np.float32(scaler_Y.scale_)).cuda()

# NOTE(review): torch.load unpickles the file, so it must come from a trusted
# source. Presumably 'MDN' holds a fully pickled MDN module (it is called
# directly below); recent PyTorch releases may require weights_only=False for
# that — confirm against the installed version.
network               = torch.load('MDN')

# ndmin=2 keeps `f` two-dimensional even when the file holds a single particle,
# so the column indexing below does not fail on one-row input.
f            = np.loadtxt('example_stars.txt', ndmin=2) #10 example star particles from FIRE high-z galaxy z5m12b
logages      = f[:,0]                          #Myr
idx          = np.where(logages<=2)[0]         #MDN is only trained to process stellar particles younger than 100 Myr
logages      = logages[idx]
Zstar        = 10**f[:,1][idx]                 #Z_sun
mass         = 10**f[:,2][idx]                 #M_sun, mass of every star particles
Mstar_tot    = 10**10.05                       #M_sun, stellar mass of the entire galaxy 
Mstar        = np.ones(len(logages))*Mstar_tot
# Weight every particle by the inverse of the log-age histogram, linearly
# interpolated between bin centres and held constant outside them, then
# rescale so that the weights sum to the number of particles.
hist,edges   = np.histogram(logages,bins=20)
logage_PDF   = interpolate.interp1d((edges[1:]+edges[:-1])/2,hist,bounds_error=False, fill_value=(hist[0],hist[-1]))
weight_temp  = 1/logage_PDF(logages)
weight       = weight_temp*len(logages)/np.sum(weight_temp)

# Output buffer: log(Luminosity/L_sun) of the 13 lines sampled by the MDN, one row per selected particle.
logL_MDN         = np.zeros((len(idx),13))

N                = len(logages)
# X is the 3D MDN input: log age, log stellar metallicity, log galaxy stellar mass.
X                = np.column_stack((logages, np.log10(Zstar), np.log10(Mstar)))
# Z holds the same three columns in a separate array; it is the argument
# handed to logLMAX for the LOII10 and LOIII32 upper bounds.
Z                = np.column_stack((logages, np.log10(Zstar), np.log10(Mstar)))
W                = weight

# Standardise the inputs and move them to the GPU as float32.
X_scaled         = scaler_X.transform(X)
x_tensor         = torch.from_numpy(X_scaled.astype(np.float32))
x_variable       = Variable(x_tensor).cuda()

# Every particle starts out as "still needs a valid sample".
idx_invalid      = np.arange(N)
logLmax          = logLMAX(Z)
# 100 is a placeholder value that is overwritten by the first sampling pass.
sample2          = np.full((N,13), 100.0)

# Rejection sampling: draw MDN samples for every particle, then redraw the
# particles flagged by the upper-bound test until none are left. After every
# 10 passes the tolerance `eps` is relaxed by 0.1 so the loop terminates.
counter          = 0
eps              = 0
# Pure inference: disable autograd so no graph is built for the network calls.
with torch.no_grad():
    while len(idx_invalid)>0:
        x_pending          = x_variable[idx_invalid,:]
        pi, sigma1, sigma2, sigma3, sigma4, mu1, mu2, mu3, mu4 = network(x_pending)
        N_test             = x_pending.shape[0]
        # Pick one mixture component per particle according to pi.
        k                  = gumbel_sample(pi)
        indices            = (torch.arange(N_test), k)
        # NOTE(review): the same standard-normal draw `rn` is reused for all four
        # dimensions, so the four sampled coordinates of a particle are perfectly
        # correlated within the chosen component — confirm this is intended
        # rather than four independent draws.
        rn                 = torch.randn(N_test).cuda()
        sampled1           = rn * sigma1[indices] + mu1[indices]
        sampled2           = rn * sigma2[indices] + mu2[indices]
        sampled3           = rn * sigma3[indices] + mu3[indices]
        sampled4           = rn * sigma4[indices] + mu4[indices]
        sampled            = torch.stack((sampled1, sampled2, sampled3, sampled4), dim=1)
        # Undo the PCA projection, then the output standardisation.
        sample_mdn1        = torch.matmul(sampled,pca_components_cuda)+pca_mean_cuda
        sample_mdn         = sample_mdn1*scalerY_scale_cuda+scalerY_mean_cuda
        sample2[idx_invalid,:] = sample_mdn.cpu().numpy()
        #If some random sampled results are greater than the upper bounds, abandon those results and redo the sampling.
        # Columns 0 and 9 of sample2 are logLOII10 and logLOIII32.
        # NOTE(review): the test uses the maximum bound over all particles
        # (np.max of each logLmax column), not each particle's own bound, and the
        # Hbeta bound in logLmax[:,2] is not used — confirm this is intended.
        idx_invalid            = np.where((sample2[:,0]-np.max(logLmax[:,0])>eps)|(sample2[:,9]-np.max(logLmax[:,1])>eps))[0]
        counter                = counter + 1
        if counter == 10:
            counter = 0
            eps     = eps + 0.1

# Add log10 of each particle's mass to every sampled line (sample2 is
# presumably per unit stellar mass — TODO confirm), filling logL_MDN in place.
# Column legend of logL_MDN:
#    0  log(LOII10/L_sun)  [OII]3730A
#    1  log(LOII20/L_sun)  [OII]3727A
#    2  log(LOII30/L_sun)  [OII]2471.1A
#    3  log(LOII31/L_sun)  [OII]7322A
#    4  log(LOII32/L_sun)  [OII]7332A
#    5  log(LOII40/L_sun)  [OII]2471.0A
#    6  log(LOIII10/L_sun) [OIII]88 micron
#    7  log(LOIII21/L_sun) [OIII]52 micron
#    8  log(LOIII31/L_sun) [OIII]4960A
#    9  log(LOIII32/L_sun) [OIII]5007A
#   10  log(LOIII43/L_sun) [OIII]4364A
#   11  log(LHalpha/L_sun) Halpha line
#   12  log(LHbeta/L_sun)  Hbeta line
logL_MDN[:,:13] = sample2[:,:13]+np.log10(mass)[:,None]

print('finish')

# Output:
# True
# Tesla P100-PCIE-16GB
# finish data pre-processing
# finish
