In [1]:
# This code is trying to reprodice the algorithm used in Stock et al. (FastChem)

In [2]:
import numpy as np
import jax.numpy as jnp
import jaxopt
from jax import jit
from exojax.chem.mass_action import logK_FC

In [3]:
cc='''
(2.1)
H2 <-> 2 H + 0 C + 0 O
CO <-> 0 H + 1 C + 1 O
CH4 <-> 4 H + 1 C + 0 O
H2O <-> 2 H + 0 C + 1 O
C2H2 <-> 2 H + 2 C + 0 O

(2.2)
n_H2   = K_H2* x_H^2* x_C^0* x_O^0
n_CO   = K_CO*  x_H^0* x_C^1* x_O^1
n_CH4  = K_CH4* x_H^4* x_C^1* x_O^0 
n_H2O  = K_H2O* x_H^2* x_C^0* x_O^1 
n_C2H2  = K_C2H2* x_H^2* x_C^2* x_O^0 

(2.3)
ep_H xbH = xH + 4 n_CH4 + 2 n_H2O
ep_C xbH = xC + n_CO + n_CH4
ep_O xbH = xO + n_CO + n_H2O
'''

In [4]:
#from FastChem
#logK = a1/T + a2 ln T + a3 + a4 T + a5 T^2 for FastChem:
#H2 Hydrogen : H 2 # Chase, M. et al., JANAF thermochemical tables, 1998.
kpH2=[5.1909637142380554e+04,-1.8011701211306956e+00,8.7224583233705744e-02,2.5613890164973008e-04,-5.3540255367406060e-09]
#C1O1 Carbon_Monoxide : C 1 O 1 # Chase, M. et al., JANAF thermochemical tables, 1998.
kpCO=[1.2899777785630804e+05,-1.7549835812545211e+00,-3.1625806804795502e+00,4.1336204683783961e-04 ,-2.3579962985989574e-08] 
#C1H4 Methane : C 1 H 4 # Chase, M. et al., JANAF thermochemical tables, 1998.
kpCH4=[1.9784584536781305e+05,-8.8316803072239054e+00,5.2793066855988400e+00,2.7567674752936866e-03,-1.3966691995535711e-07]
#H2O1 Water : H 2 O 1 # Chase, M. et al., JANAF thermochemical tables, 1998.
kpH2O=[1.1033645388793820e+05,-4.1783597409582285e+00,3.1744691010633233e+00,9.4064684023068001e-04,-4.0482461482866891e-08]
#C2H2 Ethyne : C 2 H 2 # Chase, M. et al., JANAF thermochemical tables, 1998.
kpC2H2=[1.9646042660345597e+05,-4.5142919982618537e+00,-1.2427863494177018e+01,1.5177662181888874e-03,7.4542139989784635e-08]

In [5]:
ma_coeff_=jnp.array([kpH2,kpCO,kpCH4,kpH2O,kpC2H2]).T
nuf=np.array([[2,0,0],[0,1,1],[4,1,0],[2,0,1],[2,2,0]],dtype=np.float32) #(2.1) formula matrix


In [6]:
from exojax.atm.idealgas import number_density
T=1500.0
nh_=number_density(1.0,T)

epsilonj_=jnp.array([0.84,12*4012/739000,16*10400/739000.])
logK_=logK_FC(T,nuf,ma_coeff_)

In [7]:
logK_

DeviceArray([-21.127644,  27.547245, -95.67498 , -38.54748 , -41.08718 ],            dtype=float32)

In [8]:
b_=epsilonj_*nh_

In [9]:
def zero_replaced_nuf(nuf):
    """calc zero-replaced formula matrix
    Args:
        nuf: formula matrix
        
    Returns:
        zero-replaced formula matrix (float32) 
    """
    znuf=np.copy(nuf)
    znuf[znuf==0]=np.nan
    return znuf

def calc_nufmask(nuf):
    """calc zero-replaced to nan formula matrix mask
    Args:
        nuf: formula matrix
        
    Returns:
        nufmask (float32) 
    """
    nufmask=np.copy(nuf)
    msk=nufmask==0
    nufmask[~msk]=1.0
    nufmask[msk]=np.nan
    return nufmask


def calc_epsiloni(nufmask,epsilonj):
    """calc species abundaunce=epsilon_i (2.24) in Stock et al.(2018)
    
    Args:
        nufmask: formula matrix mask
        epsilonj: element abundance (epsilon_j)
        
    Returns:
        species abundaunce= epsilon_i
    
    """
    emat=(np.full_like(nufmask,1)*epsilonj)
    return np.nanmin(emat*nufmask,axis=1)

In [10]:
znuf=zero_replaced_nuf(nuf)
nufmask=calc_nufmask(nuf)
epsiloni_=calc_epsiloni(nufmask,epsilonj_)

In [11]:

def calc_Nj(nuf,epsiloni,epsilonj):
    """calc Nj defined by (2.25) in Stock et al. (2018)
    
    Args:
        nuf: formula matrix
        epsiloni: elements abundance
        epsilonj: species abundance
        
    Returns:
        Nj (ndarray)
        Njmax 
    """
    mse=mask_diff_epsilon(epsiloni,epsilonj)
    masked_nuf=np.copy(nuf)
    masked_nuf[mse]=0.0
    Nj=np.array(np.max(masked_nuf,axis=0),dtype=int)
    return Nj, np.max(Nj)

In [12]:
def mask_diff_epsilon(epsiloni,epsilonj):
    """epsilon_i = epsilon_j
    
    Args:
        epsiloni: elements abundance
        epsilonj: species abundance
        
    Returns:
        mask for epsilon_i > epsilon_j
    """
    de=np.abs(np.array(epsiloni_[:,np.newaxis]-epsilonj_[np.newaxis,:]))
    mse=de>1.e-18 #should be refactored
    return np.array(mse)

In [13]:
def species_index_same_epsilonj(epsiloni,epsilonj,nuf):
    """species index of i for epsilon_i = epsilon_j for given element index j
    
    Args:
        epsiloni: elements abundance
        epsilonj: species abundance
        
    Returns:
        i(j) for epsilon_i = epsilon_j
    """
    mm=mask_diff_epsilon(epsiloni,epsilonj)
    si=np.arange(0,len(epsiloni))
    isamej=[]
    nufsamej=[]
    for j in range(0,len(epsilonj)):
        isamej.append(si[~mm[:,j]])
        nufsamej.append(np.array(nuf[:,j][~mm[:,j]],dtype=int))
    return isamej, nufsamej
    

In [14]:
isamej,nufsamej=species_index_same_epsilonj(epsiloni_,epsilonj_,nuf)

In [15]:
isamej

[array([0]), array([1, 2, 4]), array([3])]

In [16]:
nufsamej

[array([2]), array([1, 1, 2]), array([1])]

In [17]:
Nj=calc_Nj(nuf,epsiloni_,epsilonj_)

In [18]:
Nj

(array([2, 2, 1]), 2)

In [19]:
#emulate (2.23)
ni=np.ones(5)
de=np.array(epsiloni_[:,np.newaxis]-epsilonj_[np.newaxis,:])
np.nansum(-de*nufmask*ni[:,np.newaxis],axis=0)

array([2.16453576, 0.        , 0.16002166])

In [20]:
### check consistency
xj_=np.random.rand(3)
#law of mass action 
K_=np.exp(np.array(logK_))
ni_=K_*np.prod(xj_[np.newaxis,:]**nuf,axis=1) 
epsilonj_=(xj_+np.sum(nuf*ni_[:,np.newaxis],axis=0))/nh_

In [21]:

nufmask=calc_nufmask(nuf)
epsiloni_=calc_epsiloni(nufmask,epsilonj_)
Nj,Njmax=calc_Nj(nuf,epsiloni_,epsilonj_)
isamej,nufsamej=species_index_same_epsilonj(epsiloni_,epsilonj_,nuf)

In [22]:
Nj

array([4, 1, 1])

In [40]:
def calc_Amatrix_np(nuf,xj,Aj0):
    """calc A matrix in Stock et al. (2018) (2.28, 2.29) numpy version
    
    Args:
        nuf: formula matrix
        xj: elements activity
        Aj0: Aj0 component defined by (2.27)
        
    Returns:
        A matrix
    """
    numi,numj=np.shape(nuf)
    Ap=np.zeros((numj,Njmax+1))
    Ap[:,0]=Aj0
    Ap[:,1]=1.0
    xnuf=xj**nuf # 
    for j in range(0,numj):
        i=isamej[j]
        klist=nufsamej[j]
        Ki=K_[isamej[j]]
        #print("i_same_j",i,"K_i",Ki,"k=nu_ij",klist)
        lprod_i=np.prod(np.delete(xnuf,j,axis=1),axis=1) # Prod n_l^nu_{ij}(2.29) for all i
        kprodi=Ki*lprod_i[i]
        for ik,k in enumerate(klist):
            Ap[j,k]=Ap[j,k]+k*kprodi[ik]
    return Ap

In [41]:
kind=np.arange(0,np.max(Nj+1)) #power index in (2.26)
xjk=(xj_[:,np.newaxis])**kind[np.newaxis,:]
Ap=calc_Amatrix_np(nuf,xj_,np.zeros_like(xj_))
np.sum(Ap*xjk,axis=1), epsilonj_*nh_

(array([4.30067317e-01, 2.90333424e+11, 2.90333424e+11]),
 array([4.30067317e-01, 2.90333424e+11, 2.90333424e+11]))

In [24]:
from jax.lax import scan

In [25]:
isamej

[array([0, 2, 3, 4]), array([1]), array([1])]

In [26]:
nufsamej

[array([2, 4, 2, 2]), array([1]), array([1])]

In [27]:
#Njmax=np.max(Nj)##

#xs=np.zeros((len(isamej),4,Njmax))
#for j in range(0,len(isamej)):
#    xs[j,0,0:Nj[j]]=isamej[j]
#    xs[j,1,0:Nj[j]]=nufsamej[j]
#    xs[j,2,0:Nj[j]]=K_[isamej[j]] #should be modified later when you wanna derive the derivative by T

In [28]:
xnuf=xj_**nuf

In [29]:
lprod_ij=jnp.prod(xnuf,axis=1)
lprod_ij=lprod_ij[:,np.newaxis]/xnuf

#for j in range(0,3):
#    lprod_i=jnp.prod(jnp.delete(xnuf,j,axis=1),axis=1)
#    print(lprod_i-lprod_ij[:,j])
    
Klprod_ij=K_[:,np.newaxis]*lprod_ij

#Njmax=np.max(Nj)
#Kisame=np.zeros((len(isamej),Njmax))
#for j in range(0,len(isamej)):
#    print(isamej[j])
#    print(K_[isamej[j]]*lprod_ij[isamej[j],j])
#    print(Klprod_ij[:,j]) 
#    print("---")

In [30]:
# OK only for # of j + 1 < # of i need to be refined
isamej_formatted=np.zeros((len(isamej),len(epsiloni_)))
nufsamej_formatted=np.zeros((len(isamej),len(epsiloni_)))
for j in range(0,len(isamej)):
    isamej_formatted[j,0:Nj[j]]=isamej[j]
    nufsamej_formatted[j,0:Nj[j]]=nufsamej[j]

In [31]:
np.shape(isamej_formatted),np.shape(Klprod_ij.T),np.shape(nufsamej_formatted)

((3, 5), (3, 5), (3, 5))

In [32]:
        klist=nufsamej[j]
        Ki=K_[isamej[j]]
        #print("i_same_j",i,"K_i",Ki,"k=nu_ij",klist)
        lprod_i=np.prod(np.delete(xnuf,j,axis=1),axis=1) # Prod n_l^nu_{ij}(2.29) for all i
        kprodi=Ki*lprod_i[i]
        for ik,k in enumerate(klist):
            k=klist[ik]
            Ap[j,k]=Ap[j,k]+k*kprodi[ik]

NameError: name 'i' is not defined

In [110]:
numi,numj=np.shape(nuf)

xs=jnp.hstack([isamej_formatted,nufsamej_formatted,Klprod_ij.T])


def f(Apj,x):
    j=Apj[0]
    Ap=Apj[1]
    j=j+1
    arr=x 
    isamej_each=arr[:numi]
    klist=arr[numi:2*numi]
    Klprod_each=arr[2*numi:]
    y=klist
    
    xt=jnp.vstack([klist,Klprod_each]).T
    
    def g(Ap,x):
        k=x[0]#[:numi]
        Klx=x[1]#[numi:]
        Ap=Ap.at[j.astype(int),k.astype(int)].add(k*Klx)
        #Ap=Ap.at[0,0].add(2.0)
        #Ap[j,k]=Ap[j,k]+k*kprodi[ik]
        return Ap, _
    
    Ap,w=scan(g,Ap,xt)
    Apj=[j,Ap]
    return Apj, _ 

Ap=jnp.zeros((numj,Njmax+1))
Apj=[j,Ap]
Ap,_=scan(f,Apj,xs)
Ap

[DeviceArray(3, dtype=int32, weak_type=True),
 DeviceArray([[0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
               0.0000000e+00],
              [0.0000000e+00, 0.0000000e+00, 1.3347838e-09, 0.0000000e+00,
               1.1613337e+12],
              [0.0000000e+00, 1.2343940e-10, 0.0000000e+00, 0.0000000e+00,
               0.0000000e+00]], dtype=float32)]

In [51]:
nufsamej_formatted,Klprod_ij.T

(array([[2., 4., 2., 2., 0.],
        [1., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0.]]),
 DeviceArray([[6.6739192e-10, 2.9033343e+11, 1.0509738e-42, 1.5339135e-17,
               2.0003512e-19],
              [1.2343940e-10, 7.7691932e+11, 9.6689594e-44, 2.8370941e-18,
               2.6493346e-19],
              [1.2343940e-10, 3.4366688e+11, 3.6433760e-44, 3.3582602e-18,
               3.6998073e-20]], dtype=float32))

In [42]:
Ap=jnp.zeros((2,3))
Ap.at[0,1].set(1.0)

DeviceArray([[0., 1., 0.],
             [0., 0., 0.]], dtype=float32)

In [34]:
def calc_Amatrix(nuf,xj,Aj0):
    """calc A matrix in Stock et al. (2018) (2.28, 2.29)
    
    Args:
        nuf: formula matrix
        xj: elements activity
        Aj0: Aj0 component defined by (2.27)
        
    Returns:
        A matrix
    """
    numi,numj=jnp.shape(nuf)
    Apt=jnp.zeros((Njmax+1,numj))
    Apt.at[0].set(Aj0)
    Apt.at[1].set(jnp.ones(numj))
    Ap=Apt.T
    xnuf=xj**nuf # 
    #scan(f, Ap0,xs)
    def f(Ap,):
        
    for j in range(0,numj):
        i=isamej[j]
        klist=nufsamej[j]
        Ki=K_[isamej[j]]
        #print("i_same_j",i,"K_i",Ki,"k=nu_ij",klist)
        lprod_i=np.prod(np.delete(xnuf,j,axis=1),axis=1) # Prod n_l^nu_{ij}(2.29) for all i
        kprodi=Ki*lprod_i[i]
        for ik,k in enumerate(klist):
            k=klist[ik]
            
            #Ap[j,k]=Ap[j,k]+k*kprodi[ik]
    return Ap

IndentationError: expected an indented block (295735272.py, line 21)

In [40]:
Ap=calc_Amatrix(nuf,xj_,np.zeros_like(xj_))
print(Ap)

TypeError: '<class 'jaxlib.xla_extension.DeviceArray'>' object does not support item assignment. JAX arrays are immutable; perhaps you want jax.ops.index_update or jax.ops.index_add instead?

In [29]:
kind=np.arange(0,np.max(Nj+1)) #power index in (2.26)
xjk=(xj_[:,np.newaxis])**kind[np.newaxis,:]
Ap=calc_Amatrix_np(nuf,xj_,np.zeros_like(xj_))
print(Ap)


np.sum(Ap*xjk,axis=1)

[[0.00000000e+00 1.00000000e+00 1.33478896e-09 0.00000000e+00
  6.11955780e-42]
 [0.00000000e+00 3.56111117e+11 0.00000000e+00 0.00000000e+00
  0.00000000e+00]
 [0.00000000e+00 5.00512587e+11 0.00000000e+00 0.00000000e+00
  0.00000000e+00]]


array([2.07648441e-01, 1.93813492e+11, 1.93813492e+11])

In [25]:
epsilonj_*nh_

array([2.07648441e-01, 1.93813492e+11, 1.93813492e+11])