In [1]:
# This code tries to reproduce the algorithm of Stock et al. (2018) (FastChem)

In [2]:
################################################################

In [3]:
import numpy as np
import jax.numpy as jnp
import jaxopt
from jax import jit
from exojax.chem.mass_action import logK_FC
from exojax.chem import fastchemlike as fcl

In [4]:
# Text summary of the H/C/O network: (2.1) reactions (rows of the formula matrix nuf),
# (2.2) law of mass action, (2.3) element conservation. Each (2.3) sum runs over every
# species that contains the element, weighted by its stoichiometric coefficient nu_ij.
cc='''
(2.1)
H2 <-> 2 H + 0 C + 0 O
CO <-> 0 H + 1 C + 1 O
CH4 <-> 4 H + 1 C + 0 O
H2O <-> 2 H + 0 C + 1 O
C2H2 <-> 2 H + 2 C + 0 O

(2.2)
n_H2   = K_H2* x_H^2* x_C^0* x_O^0
n_CO   = K_CO*  x_H^0* x_C^1* x_O^1
n_CH4  = K_CH4* x_H^4* x_C^1* x_O^0 
n_H2O  = K_H2O* x_H^2* x_C^0* x_O^1 
n_C2H2  = K_C2H2* x_H^2* x_C^2* x_O^0 

(2.3)
ep_H xbH = xH + 2 n_H2 + 4 n_CH4 + 2 n_H2O + 2 n_C2H2
ep_C xbH = xC + n_CO + n_CH4 + 2 n_C2H2
ep_O xbH = xO + n_CO + n_H2O
'''

In [5]:
#from FastChem
#logK = a1/T + a2 ln T + a3 + a4 T + a5 T^2 for FastChem:
# Each kp* list holds the five fit coefficients of one species (presumably a1..a5 in
# that order -- TODO confirm against logK_FC); they are stacked into ma_coeff_ below.
#H2 Hydrogen : H 2 # Chase, M. et al., JANAF thermochemical tables, 1998.
kpH2=[5.1909637142380554e+04,-1.8011701211306956e+00,8.7224583233705744e-02,2.5613890164973008e-04,-5.3540255367406060e-09]
#C1O1 Carbon_Monoxide : C 1 O 1 # Chase, M. et al., JANAF thermochemical tables, 1998.
kpCO=[1.2899777785630804e+05,-1.7549835812545211e+00,-3.1625806804795502e+00,4.1336204683783961e-04 ,-2.3579962985989574e-08] 
#C1H4 Methane : C 1 H 4 # Chase, M. et al., JANAF thermochemical tables, 1998.
kpCH4=[1.9784584536781305e+05,-8.8316803072239054e+00,5.2793066855988400e+00,2.7567674752936866e-03,-1.3966691995535711e-07]
#H2O1 Water : H 2 O 1 # Chase, M. et al., JANAF thermochemical tables, 1998.
kpH2O=[1.1033645388793820e+05,-4.1783597409582285e+00,3.1744691010633233e+00,9.4064684023068001e-04,-4.0482461482866891e-08]
#C2H2 Ethyne : C 2 H 2 # Chase, M. et al., JANAF thermochemical tables, 1998.
kpC2H2=[1.9646042660345597e+05,-4.5142919982618537e+00,-1.2427863494177018e+01,1.5177662181888874e-03,7.4542139989784635e-08]

In [6]:
# Stack the per-species coefficient lists; after .T each column is one species
# (H2, CO, CH4, H2O, C2H2) and each row one fit coefficient.
ma_coeff_=jnp.array([kpH2,kpCO,kpCH4,kpH2O,kpC2H2]).T
# Formula matrix nu_ij: rows = species in the same order as above, columns = elements (H, C, O).
nuf=np.array([[2,0,0],[0,1,1],[4,1,0],[2,0,1],[2,2,0]],dtype=np.float32) #(2.1) formula matrix

In [7]:
from exojax.atm.idealgas import number_density
T=1500.0
# NOTE(review): presumably the total number density at pressure 1.0 (units per exojax) -- confirm.
nh_=number_density(1.0,T)
# Element abundances epsilon_j for (H, C, O), same column order as nuf.
# NOTE(review): the origin of 0.84 and of the 12*4012/739000, 16*10400/739000 factors is
# not documented here -- confirm the source before relying on these values.
epsilonj_=jnp.array([0.84,12*4012/739000,16*10400/739000.])
# log equilibrium constants, one per species (row of nuf), from the FastChem fits.
logK_=logK_FC(T,nuf,ma_coeff_)

In [8]:
# Element abundances scaled by nh_.
# NOTE(review): b_ is not referenced in any visible cell of this notebook.
b_=epsilonj_*nh_

In [9]:
# NOTE(review): duplicate of the fastchemlike import in the first code cell.
from exojax.chem import fastchemlike as fcl

### check consistency
# Build a synthetic test case: pick random element activities x_j, compute the species
# densities n_i from the law of mass action, then derive the element abundances epsilon_j
# that the conservation equations imply for them.
# NOTE(review): no random seed is set, so values (and the printed output) change per run.
xj_=np.random.rand(3)
#law of mass action 
K_=np.exp(np.array(logK_))
ni_=K_*np.prod(xj_[np.newaxis,:]**nuf,axis=1) 
# NOTE(review): this overwrites the global epsilonj_ from the setup cell; every later
# cell that reads epsilonj_ sees this synthetic value instead.
epsilonj_=(xj_+np.sum(nuf*ni_[:,np.newaxis],axis=0))/nh_

#znuf=zero_replaced_nuf(nuf)
nufmask=fcl.calc_nufmask(nuf)
epsiloni_=fcl.calc_epsiloni(nufmask,epsilonj_)
isamej,nufsamej=fcl.species_index_same_epsilonj(epsiloni_,epsilonj_,nuf)
Nj,Njmax=fcl.calc_Nj(nuf,epsiloni_,epsilonj_)
# Reference A matrix from the numpy implementation, with A_j0 = 0.
Apref=fcl.calc_Amatrix_np(nuf,xj_,K_,np.zeros_like(xj_),isamej,nufsamej,Njmax)

#numpy Ap check

kind=np.arange(0,np.max(Nj+1)) #power index in (2.26)
xjk=(xj_[:,np.newaxis])**kind[np.newaxis,:]
# Compare sum_k A_jk x_j^k with epsilon_j * nh_.
# NOTE(review): this tuple is not the cell's last expression, so Jupyter never displays it;
# wrap it in display() or turn it into an assertion.
np.sum(Apref*xjk,axis=1), epsilonj_*nh_

#jax.lax.scan run
# NOTE(review): nufmask, epsiloni_, Nj and Njmax are recomputed from the same inputs as above.
nufmask=fcl.calc_nufmask(nuf)
epsiloni_=fcl.calc_epsiloni(nufmask,epsilonj_)
Nj,Njmax=fcl.calc_Nj(nuf,epsiloni_,epsilonj_)
isamej_formatted, nufsamej_formatted=fcl.set_samej_formatted(isamej,nufsamej,Nj,len(epsiloni_))
Ap=fcl.calc_Amatrix(nuf,xj_,K_,np.zeros_like(xj_),isamej_formatted, nufsamej_formatted, Njmax)
#Ap=calc_Amatrix(nuf,xj_,K_,np.zeros_like(xj_),isamej_formatted, nufsamej_formatted, Njmax)

# jax scan Ap check
kind=np.arange(0,np.max(Nj+1)) #power index in (2.26)
xjk=(xj_[:,np.newaxis])**kind[np.newaxis,:]
# NOTE(review): also mid-cell, hence not displayed.
np.sum(Ap*xjk,axis=1), epsilonj_*nh_

# numpy - jax.lax.scan comparison
print((Ap-Apref))#/Apref)
Ap

[[0.e+00 0.e+00 0.e+00 0.e+00 1.e-45]
 [0.e+00 0.e+00 0.e+00 0.e+00 0.e+00]
 [0.e+00 0.e+00 0.e+00 0.e+00 0.e+00]]


DeviceArray([[0.00000000e+00, 1.00000000e+00, 1.33478384e-09,
              0.00000000e+00, 3.22859166e-42],
             [0.00000000e+00, 1.18345376e+11, 0.00000000e+00,
              0.00000000e+00, 0.00000000e+00],
             [0.00000000e+00, 2.63932232e+11, 0.00000000e+00,
              0.00000000e+00, 0.00000000e+00]], dtype=float32)

In [10]:
# NOTE(review): epsilonj_ here is the synthetic value reassigned in the consistency-check
# cell, not the abundance vector defined in the setup cell.
epsilonj_*nh_

array([8.19769290e-01, 3.39646536e+10, 3.39646536e+10])

In [12]:
#solve once

#initial
# NOTE(review): the nan is presumably the 0/0 from dividing the full product by x_j^nu_ij
# inside calc_Amatrix when x_j == 0 -- confirm for the fcl version used here.
#xj_=np.zeros(3) #yielding nan
xjmin0=np.zeros(3)
# NOTE(review): unseeded random start; results differ per run.
xj_=np.random.rand(3)
#############

# Constant term A_j0: -epsilon_j * nh_ plus xjmin0 (all zeros here).
A0=-epsilonj_*nh_ + xjmin0
epsiloni_=fcl.calc_epsiloni(nufmask,epsilonj_)
Nj,Njmax=fcl.calc_Nj(nuf,epsiloni_,epsilonj_)
# NOTE(review): nufmask, isamej and nufsamej are reused from the consistency-check cell;
# they are only valid if epsilonj_ has not changed since they were computed.
isamej_formatted, nufsamej_formatted=fcl.set_samej_formatted(isamej,nufsamej,Nj,len(epsiloni_))
Ap=fcl.calc_Amatrix(nuf,xj_,K_,A0,isamej_formatted, nufsamej_formatted, Njmax)
#Ap=calc_Amatrix(nuf,xj_,K_,A0,isamej_formatted, nufsamej_formatted, Njmax)



In [13]:
# Inspect the A matrix from the previous cell together with the A0 vector it was built with.
np.shape(Ap),Ap,A0
#A0??
# NOTE(review): open question left by the author; in the displayed output, column 0 of Ap equals A0.

((3, 5),
 DeviceArray([[-8.1976926e-01,  1.0000000e+00,  1.3347838e-09,
                0.0000000e+00,  8.6656297e-42],
              [-3.3964655e+10,  3.9944897e+11,  0.0000000e+00,
                0.0000000e+00,  0.0000000e+00],
              [-3.3964655e+10,  7.0878377e+11,  0.0000000e+00,
                0.0000000e+00,  0.0000000e+00]], dtype=float32),
 array([-8.19769290e-01, -3.39646536e+10, -3.39646536e+10]))

In [54]:
from jax.lax import scan
def set_samej_formatted(isamej,nufsamej,Nj,numi):
    """make formatted isamej and nufsamej (fixed-width, zero-padded arrays)

    Row j holds the species indices isamej[j] (resp. the coefficients nufsamej[j])
    in its first len(isamej[j]) entries, followed by zeros. A padded coefficient of 0
    adds nothing in calc_Amatrix.

    Note:
       species_index_same_epsilonj generates the inputs of this function.

    Args:
       isamej: isamej (the species index i(j) for epsilon_i = epsilon_j)
       nufsamej: nufsamej (the formula matrix component nuf(j) for epsilon_i = epsilon_j)
       Nj: Nj computed by calc_Nj. Kept for backward compatibility only: it is the
           largest stoichiometric coefficient, not the number of matching species,
           so it must not be used as the row fill length.
       numi: number of the species (row width)

    Returns:
       isamej_formatted, nufsamej_formatted: float arrays of shape (len(isamej), numi)
    """
    numj=len(isamej)
    isamej_formatted=np.zeros((numj,numi))
    nufsamej_formatted=np.zeros((numj,numi))
    for j in range(numj):
        # Fill exactly as many slots as there are matching species. Using Nj[j] here
        # duplicated entries via broadcasting (double counting), or raised a shape error.
        count=len(isamej[j])
        isamej_formatted[j,:count]=isamej[j]
        nufsamej_formatted[j,:count]=nufsamej[j]
    return isamej_formatted, nufsamej_formatted
        
def calc_Amatrix(nuf,xj,Keq,Aj0,isamej_formatted, nufsamej_formatted,Njmax):
    """calc A matrix in Stock et al. (2018) (2.28, 2.29) jax version

    Row j holds the polynomial coefficients A_jk (k = 0..Njmax) in x_j. Column 0 is
    Aj0 and column 1 starts at 1. Each species i = isamej_formatted[j, m] adds
    k * K_i * prod_{l != j} x_l^{nu_il} to column k = nufsamej_formatted[j, m].
    Padding entries (k == 0) contribute nothing.

    Note:
        isamej_formatted and nufsamej_formatted can be computed using set_samej_formatted.
        Njmax fixes the output shape, so it must be a concrete integer under jit.

    Args:
        nuf: formula matrix, shape (numi species, numj elements)
        xj: elements activity x_j, shape (numj,)
        Keq: equilibrium constant K_i, shape (numi,)
        Aj0: Aj0 component defined by (2.27), shape (numj,)
        isamej_formatted: formatted isamej (species index i(j) for epsilon_i = epsilon_j), shape (numj, numi)
        nufsamej_formatted: formatted nufsamej (the matching nu_ij), shape (numj, numi)
        Njmax: Njmax

    Returns:
        A matrix, shape (numj, Njmax+1)
    """
    numi,numj=jnp.shape(nuf)

    # Exponents E[i, j, l] = nu_il for l != j and 0 for l == j, so that
    # prod_l x_l**E[i, j, l] = prod_{l != j} x_l**nu_il. Zeroing the exponent (x**0 == 1,
    # also for x == 0) replaces the former division of the full product by x_j**nu_ij,
    # which produced 0/0 = nan whenever some x_j == 0.
    offdiag=~np.eye(numj,dtype=bool)
    expo=nuf[:,np.newaxis,:]*offdiag[np.newaxis,:,:]
    lprod_ij=(xj**expo).prod(axis=2)
    Klprod_ij=jnp.asarray(Keq[:,jnp.newaxis]*lprod_ij)  # (numi, numj)

    kval=jnp.asarray(nufsamej_formatted)        # power k, (numj, numi)
    kidx=kval.astype(int)                         # column index k
    iidx=jnp.asarray(isamej_formatted).astype(int)  # species index i
    jidx=jnp.arange(numj)[:,jnp.newaxis]          # element index j, broadcast over m
    # Padding (k == 0) is excluded explicitly so an inf/nan K never leaks in via 0*inf.
    terms=jnp.where(kidx>0,kval*Klprod_ij[iidx,jidx],0.0)

    Ap=jnp.zeros((numj,Njmax+1))
    Ap=Ap.at[:,1].set(1.0)
    Ap=Ap.at[:,0].set(Aj0)
    # Scatter-add: repeated (j, k) pairs accumulate, like the former sequential loop.
    Ap=Ap.at[jidx,kidx].add(terms)
    return Ap



In [44]:
# Evaluate P_0(x) = sum_k Ap[0,k] x^k at x = 0, 1, 3 (polyval wants the highest degree
# first, hence the reversal).
# NOTE(review): the saved output [0, 1, 3] comes from an Ap whose row 0 is (0, 1, 0, ...),
# not from the Ap displayed in the earlier cell -- execution order is out of sequence.
jnp.polyval(Ap[0,::-1],jnp.array([0.0,1.0,3.0]))

DeviceArray([0., 1., 3.], dtype=float32)

In [45]:
# Constant coefficient (k = 0) of row 0 of the A matrix.
Ap[0,0]

DeviceArray(0., dtype=float32)

In [34]:
# Sum of the row-0 coefficients, i.e. the row-0 polynomial evaluated at x = 1.
jnp.sum(Ap[0,:])

DeviceArray(1., dtype=float32)

In [None]:
def f(xj,Ap):
    """Unfinished placeholder: it ignores its arguments and returns None.

    TODO: implement (presumably an update of x_j built from the A matrix Ap).
    """
    return None

In [1]:
################################################

In [2]:
#def zero_replaced_nuf(nuf):
#    """calc zero-replaced formula matrix
#    Args:
#        nuf: formula matrix
#        
#    Returns:
#        zero-replaced formula matrix (float32) 
#    """
#    znuf=np.copy(nuf)
#    znuf[znuf==0]=np.nan
#    return znuf

def calc_nufmask(nuf):
    """calc the formula matrix mask: 1 where nu_ij != 0, NaN where nu_ij == 0

    Args:
        nuf: formula matrix (array-like, any numeric dtype)

    Returns:
        nufmask with the same shape as nuf. Floating input keeps its dtype
        (e.g. float32); integer input is promoted to float64.
    """
    nuf=np.asarray(nuf)
    # NaN cannot be stored in an integer array; the former copy-and-assign raised
    # ValueError for an integer formula matrix.
    dtype=nuf.dtype if np.issubdtype(nuf.dtype,np.floating) else np.float64
    return np.where(nuf==0,np.nan,1.0).astype(dtype)


def calc_epsiloni(nufmask,epsilonj):
    """species abundance epsilon_i, (2.24) in Stock et al. (2018)

    For every species (row of nufmask), take the smallest element abundance
    epsilon_j among the elements that occur in it. NaN entries of the mask
    (elements absent from the species) are skipped by nanmin.

    Args:
        nufmask: formula matrix mask (1 where nu_ij != 0, NaN elsewhere)
        epsilonj: element abundance (epsilon_j)

    Returns:
        species abundance epsilon_i (one value per row of nufmask)
    """
    ones=np.ones_like(nufmask)
    candidates=nufmask*(ones*epsilonj)
    return np.nanmin(candidates,axis=1)

In [21]:

def calc_Nj(nuf,epsiloni,epsilonj):
    """calc Nj defined by (2.25) in Stock et al. (2018)

    N_j is the largest stoichiometric coefficient nu_ij over the species i whose
    abundance epsilon_i equals the element abundance epsilon_j; species with a
    different epsilon_i are ignored (treated as 0).

    Args:
        nuf: formula matrix, shape (species, elements)
        epsiloni: species abundance epsilon_i
        epsilonj: element abundance epsilon_j

    Returns:
        Nj (int ndarray, one entry per element)
        Njmax (the maximum of Nj)
    """
    same=~mask_diff_epsilon(epsiloni,epsilonj)
    Nj=np.max(np.where(same,nuf,0.0),axis=0).astype(int)
    return Nj, np.max(Nj)

In [17]:
def mask_diff_epsilon(epsiloni,epsilonj,atol=1.e-18):
    """mask of (species i, element j) pairs with epsilon_i != epsilon_j

    Args:
        epsiloni: species abundance epsilon_i (array-like, length numi)
        epsilonj: element abundance epsilon_j (array-like, length numj)
        atol: absolute tolerance; |epsilon_i - epsilon_j| <= atol counts as equal.
            The default reproduces the previously hard-coded threshold. Because the
            tolerance is absolute, any two abundances both below ~atol compare equal.

    Returns:
        boolean ndarray of shape (numi, numj), True where |epsilon_i - epsilon_j| > atol
    """
    ei=np.asarray(epsiloni)
    ej=np.asarray(epsilonj)
    return np.abs(ei[:,np.newaxis]-ej[np.newaxis,:])>atol

In [18]:
def species_index_same_epsilonj(epsiloni,epsilonj,nuf):
    """for each element j, the species i whose abundance satisfies epsilon_i = epsilon_j

    Args:
        epsiloni: species abundance epsilon_i
        epsilonj: element abundance epsilon_j
        nuf: formula matrix, shape (species, elements)

    Returns:
        isamej: list over elements j of arrays of species indices i(j) with epsilon_i = epsilon_j
        nufsamej: list over elements j of int arrays holding the matching nu_ij
    """
    differs=mask_diff_epsilon(epsiloni,epsilonj)
    species=np.arange(0,len(epsiloni))
    selectors=[~differs[:,j] for j in range(len(epsilonj))]
    isamej=[species[sel] for sel in selectors]
    nufsamej=[np.array(nuf[:,j][sel],dtype=int) for j,sel in enumerate(selectors)]
    return isamej, nufsamej
    

In [23]:
def calc_Amatrix_np(nuf,xj,Keq,Aj0,isamej,nufsamej,Njmax):
    """calc A matrix in Stock et al. (2018) (2.28, 2.29) numpy reference version

    Row j holds the polynomial coefficients A_jk (k = 0..Njmax) in x_j. Column 0 is
    Aj0 and column 1 starts at 1. Each species i in isamej[j] adds
    k * K_i * prod_{l != j} x_l^{nu_il} to column k = nu_ij.

    Args:
        nuf: formula matrix, shape (species, elements)
        xj: elements activity x_j
        Keq: equilibrium constant K_i
        Aj0: Aj0 component defined by (2.27)
        isamej: per-element arrays of species indices i with epsilon_i = epsilon_j
        nufsamej: per-element arrays of the matching nu_ij
        Njmax: Njmax (the matrix has Njmax+1 columns)

    Returns:
        A matrix, ndarray of shape (elements, Njmax+1)
    """
    numi,numj=np.shape(nuf)
    Ap=np.zeros((numj,Njmax+1))
    Ap[:,0]=Aj0
    Ap[:,1]=1.0
    powers=xj**nuf  # x_l^{nu_il}, one row per species
    for j in range(numj):
        species=isamej[j]
        # product over every element except j, evaluated for all species (2.29)
        others=np.prod(np.delete(powers,j,axis=1),axis=1)
        weight=Keq[species]*others[species]
        for m,k in enumerate(nufsamej[j]):
            Ap[j,k]+=k*weight[m]
    return Ap

In [27]:
from jax.lax import scan