In [1]:
#nu, the EOS-dependent constants [O1, O2]: O1 scales the attraction parameter a, O2 the repulsion parameter b.
def nu_(EOS):
    """Return the EOS-specific constants [O1, O2].

    O1 is used as NU[0] in `a_i` (attraction) and O2 as NU[1] in `b_i`
    (repulsion).

    Parameters
    ----------
    EOS : str
        'SRK' or 'PR'.

    Returns
    -------
    list of float
        [O1, O2].

    Raises
    ------
    ValueError
        If `EOS` is not 'SRK' or 'PR'.
    """
    if EOS == 'SRK':
        O1 = 0.42748
        O2 = 0.08664
        return [O1, O2]
    if EOS == 'PR':
        O1 = 0.45724
        O2 = 0.07780
        return [O1, O2]
    # Fail loudly instead of returning None, which would only crash later inside a_i/b_i.
    raise ValueError(f"Unsupported EOS {EOS!r}; expected 'SRK' or 'PR'")

In [2]:
#D1, a parameter that defines the EOS.
def d_1(EOS):
    """Return D1, the larger root of delta**2 - u0*delta + w0 = 0 for the chosen EOS.

    (u0, w0) = (1, 0) for 'SRK' and (2, -1) for 'PR'.

    Parameters
    ----------
    EOS : str
        'SRK' or 'PR'.

    Returns
    -------
    jax scalar
        D1 = (u0 + sqrt(u0**2 - 4*w0)) / 2.

    Raises
    ------
    ValueError
        If `EOS` is not 'SRK' or 'PR'.
    """
    import jax.numpy as jnp
    eos_uw = {'SRK': (1, 0), 'PR': (2, -1)}
    if EOS not in eos_uw:
        # Previously fell through to an UnboundLocalError on u0/w0.
        raise ValueError(f"Unsupported EOS {EOS!r}; expected 'SRK' or 'PR'")
    u0, w0 = eos_uw[EOS]
    D1 = (u0 + jnp.sqrt(u0**2-4*w0))/2
    return D1

#D2, a parameter that defines the EOS.
def d_2(EOS):
    """Return D2, the smaller root of delta**2 - u0*delta + w0 = 0 for the chosen EOS.

    (u0, w0) = (1, 0) for 'SRK' and (2, -1) for 'PR'.

    Parameters
    ----------
    EOS : str
        'SRK' or 'PR'.

    Returns
    -------
    jax scalar
        D2 = (u0 - sqrt(u0**2 - 4*w0)) / 2.

    Raises
    ------
    ValueError
        If `EOS` is not 'SRK' or 'PR'.
    """
    import jax.numpy as jnp
    eos_uw = {'SRK': (1, 0), 'PR': (2, -1)}
    if EOS not in eos_uw:
        # Previously fell through to an UnboundLocalError on u0/w0.
        raise ValueError(f"Unsupported EOS {EOS!r}; expected 'SRK' or 'PR'")
    u0, w0 = eos_uw[EOS]
    D2 = (u0 - jnp.sqrt(u0**2-4*w0))/2
    return D2

def d_3(D1, D2):
    """Sum of the two EOS roots, D3 = D1 + D2."""
    total = D1 + D2
    return total
def d_4(D1, D2):
    """Product of the two EOS roots, D4 = D1 * D2."""
    product = D1*D2
    return product

In [3]:
#m, the acentric-factor polynomial m(w) used in the temperature-dependent part of a_i.
def m_(w, C, EOS):
    """Evaluate the EOS-specific quadratic in the acentric factor w.

    SRK: m = 0.48 + 1.574*w - 0.176*w**2
    PR:  m = 0.37464 + 1.54226*w - 0.26992*w**2

    Parameters
    ----------
    w : scalar or array of shape (C,)
        Acentric factor(s); a scalar is broadcast to all C components.
    C : int
        Number of components.
    EOS : str
        'SRK' or 'PR'.

    Returns
    -------
    jax array of shape (C,)

    Raises
    ------
    ValueError
        If `EOS` is not 'SRK' or 'PR'.
    """
    import jax.numpy as jnp

    m = jnp.zeros([C])
    if EOS == 'SRK':
        return m.at[:].set(0.48+1.574*w-0.176*w**2)
    if EOS == 'PR':
        return m.at[:].set(0.37464+1.54226*w-0.26992*w**2)
    # Fail loudly instead of returning None, which would only crash later in a_i.
    raise ValueError(f"Unsupported EOS {EOS!r}; expected 'SRK' or 'PR'")

In [4]:
#ai, the attraction parameter of each component i.
def a_i(T, M, NU, Tc, Pc, R, C):
    """Pure-component attraction parameters for the first C components.

    a_i = (R*Tc_i)**2 * NU[0] / Pc_i * (1 + M_i*(1 - (T/Tc_i)**0.5))**2

    Parameters
    ----------
    T : scalar
        Temperature.
    M : array-like, length >= C
        Output of `m_`.
    NU : sequence
        Output of `nu_`; only NU[0] is used here.
    Tc, Pc : array-like, length >= C
        Critical temperatures and pressures.
    R : float
        Gas constant.
    C : int
        Number of components.

    Returns
    -------
    jax array of shape (C,)
    """
    import jax.numpy as jnp
    crit_T = jnp.asarray(Tc)[:C]
    crit_P = jnp.asarray(Pc)[:C]
    slope = jnp.asarray(M)[:C]
    alpha = (1 + slope*(1-(T/crit_T)**0.5))**2
    values = (R*crit_T)**2*NU[0]/crit_P*alpha
    # Writing into a fresh zeros array keeps the default float dtype of the original.
    return jnp.zeros([C]).at[:].set(values)

In [5]:
#aij, the binary attraction parameter of the component pair i-j.
def a_ij(ai, C, k):
    """Binary attraction matrix a_ij = sqrt(a_i * a_j) * (1 - k_ij).

    Parameters
    ----------
    ai : array-like, length >= C
        Pure-component attraction parameters (from `a_i`).
    C : int
        Number of components.
    k : array-like, shape (>= C, >= C)
        Binary interaction coefficients.

    Returns
    -------
    jax array of shape (C, C)
    """
    import jax.numpy as jnp
    a_vec = jnp.asarray(ai)[:C]
    kij = jnp.asarray(k)[:C, :C]
    pair = (a_vec[:, None]*a_vec[None, :])**0.5*(1-kij)
    return jnp.zeros([C, C]).at[:, :].set(pair)

In [6]:
#bi, the repulsion parameter of each component i.
def b_i(NU, Tc, Pc, R, C):
    """Pure-component repulsion parameters b_i = NU[1] * R * Tc_i / Pc_i.

    Parameters
    ----------
    NU : sequence
        Output of `nu_`; only NU[1] is used here.
    Tc, Pc : array-like, length >= C
        Critical temperatures and pressures.
    R : float
        Gas constant.
    C : int
        Number of components.

    Returns
    -------
    jax array of shape (C,)
    """
    import jax.numpy as jnp
    crit_T = jnp.asarray(Tc)[:C]
    crit_P = jnp.asarray(Pc)[:C]
    return jnp.zeros([C]).at[:].set(NU[1]*R*crit_T/crit_P)

In [7]:
#a_tot, the weighted sum of binary attraction interactions.
def a_t(y, aij, C):
    """Mixture attraction parameter a = sum_i sum_j yn_i * aij[i, j] * yn_j.

    The last mole fraction is replaced by 1 - sum(y[:-1]) (called yn here), so
    derivatives taken with respect to `y` see the closure constraint and are
    independent of y[-1].

    Parameters
    ----------
    y : array of shape (C,)
        Mole fractions; y[-1] is ignored in favour of the closure value.
    aij : array of shape (C, C)
        Binary attraction matrix (from `a_ij`).
    C : int
        Number of components.

    Returns
    -------
    jax scalar
    """
    import jax.numpy as jnp
    y_closed = jnp.append(y[:-1], (1-jnp.sum(y[:-1])))
    terms = (jnp.multiply(jnp.multiply(y_closed[row], aij[row, col]), y_closed[col])
             for row in range(C) for col in range(C))
    # Same row-major accumulation order as an explicit double loop.
    return sum(terms, jnp.array(0))

In [8]:
#b_tot, the weighted sum of repulsion interactions.
def b_t(y, bi, C):
    """Mixture repulsion parameter, written so that y[-1] drops out.

    b = sum_i y_i * (bi_i - bi_r) + bi_r with r = C - 1. This equals sum_i yn_i * bi_i
    when yn_r = 1 - sum of the other fractions, so the last fraction is implicit.

    Parameters
    ----------
    y : array of shape (C,)
        Mole fractions.
    bi : array of shape (C,)
        Pure-component repulsion parameters (from `b_i`).
    C : int
        Number of components.

    Returns
    -------
    jax scalar
    """
    import jax.numpy as jnp
    last = C-1
    contributions = (y[idx]*(bi[idx]-bi[last]) for idx in range(C))
    return sum(contributions, jnp.array(0)) + bi[last]

In [9]:
def A_dim(a, P, R, T):
    """Dimensionless attraction parameter A = a * P / (R * T)**2."""
    rt_squared = (R*T)**2
    return a*P/rt_squared

In [10]:
def ALPHA_(y, aij, a, C):
    """Normalised attraction contributions ALPHA_i = sum_k yn_k * aij[k, i] / a.

    As in `a_t`, the last mole fraction is replaced by 1 - sum(y[:-1]).

    Parameters
    ----------
    y : array of shape (C,)
        Mole fractions.
    aij : array of shape (C, C)
        Binary attraction matrix.
    a : scalar
        Mixture attraction parameter (from `a_t`).
    C : int
        Number of components.

    Returns
    -------
    jax array of shape (C,)
    """
    import jax.numpy as jnp

    y_closed = jnp.append(y[:-1], (1-jnp.sum(y[:-1])))
    col_totals = jnp.zeros(C)
    for col in range(C):
        for row in range(C):
            col_totals = col_totals.at[col].add(y_closed[row]*aij[row, col])
    return col_totals.at[:].set(col_totals/a)

In [11]:
def BETA_(bi, b, C):
    """Ratio of each pure repulsion parameter to the mixture value, bi / b (C is unused)."""
    return bi / b

In [12]:
def B_dim(b, P, R, T):
    """Dimensionless repulsion parameter B = b * P / (R * T)."""
    rt = R*T
    return b*P/rt

In [13]:
def CompFact(A, B, EOS):
    """Largest real root Z of the generalized two-parameter cubic EOS.

    Solves  c1*Z**3 + c2*Z**2 + c3*Z + c4 = 0  with c1 = 1 and
        c2 = (D3 - 1)*B - 1
        c3 = A - D3*B - (D3 - D4)*B**2
        c4 = -D4*(B**3 + B**2) - A*B
    where D3 = D1 + D2 and D4 = D1*D2 come from `d_1`/`d_2`/`d_3`/`d_4`.
    The branch is selected with `jax.lax.cond` on the cubic discriminant so the
    function remains traceable under `jax.jit` and `jax.jacfwd`.

    Parameters
    ----------
    A, B : scalar
        Dimensionless attraction / repulsion parameters (see `A_dim`, `B_dim`).
    EOS : str
        'SRK' or 'PR'.

    Returns
    -------
    scalar
        The largest real root when disc >= 0 (three real roots, possibly
        repeated), otherwise the single real root.
    """
    import jax.numpy as jnp
    import jax
    #EOS cubic compressbility coefficients
    D1 = d_1(EOS)
    D2 = d_2(EOS)
    D3 = d_3(D1, D2)
    D4 = d_4(D1, D2)
    
    c1 = jnp.array(1)
    c2 = jnp.array((D3-1)*B-1)
    c3 = jnp.array(A-D3*B-(D3-D4)*B**2)
    c4 = jnp.array(-D4*(B**3+B**2)-A*B)
    
    
    #Cubic discriminant: > 0 three distinct real roots, == 0 repeated root, < 0 one real root
    disc = 18*c1*c2*c3*c4 - 4*c2**3*c4 + c2**2*c3**2 - 4*c1*c3**3 - 27*c1**2*c4**2
    # NOTE(review): the maximum root is always taken; confirm this is the intended phase root.
    def threeroot(c1, c2, c3, c4, disc):
        #Three Real Roots
        def three_distinct(c1, c2, c3, c4):
            # Trigonometric solution of the depressed cubic t**3 + p1*t + p2 = 0
            p1 = (3*c1*c3 - c2**2)/(3*c1**2)
            p2 = (2*c2**3 - 9*c1*c2*c3 + 27*c1**2*c4)/(27*c1**3)
            arg = 3*p2/(2*p1)*jnp.sqrt(-3/p1)
            # Presumably promoted to complex so arccos stays finite if rounding pushes |arg| past 1
            arg = jax.lax.complex(arg, jnp.array(0.0))

            Z1 = 2*(-p1/3)**(1/2)*jnp.cos(jnp.arccos(arg)/3)
            Z2 = 2*(-p1/3)**(1/2)*jnp.cos(jnp.arccos(arg)/3-2*jnp.pi/3)
            Z3 = 2*(-p1/3)**(1/2)*jnp.cos(jnp.arccos(arg)/3-4*jnp.pi/3)

            # Undo the depression shift Z = t - c2/(3*c1)
            Z1 = jnp.real(Z1 - c2/(3*c1))
            Z2 = jnp.real(Z2 - c2/(3*c1))
            Z3 = jnp.real(Z3 - c2/(3*c1))
            return jnp.maximum(jnp.maximum(Z1, Z2), Z3)
        
        def repeated_root(c1, c2, c3, c4):
            def one_repeated(c1, c2, c3, c4):
                #One simple root (Z1) and one double root (Z2)
                Z1 = (4*c1*c2*c3-9*c1**2*c4-c2**3)/(c1*(c2**2-3*c1*c3))
                Z2 = (9*c1*c4-c2*c3)/(2*(c2**2-3*c1*c3))
                return jnp.maximum(Z1, Z2)
            def two_repeated(c1, c2, c3, c4):
                #Triple multiplicity root Z = -c2/(3*c1); parenthesised so c1 divides
                Z3 = -c2/(3*c1)
                return Z3
            return jax.lax.cond(c2**2 == 3*c1*c3, two_repeated, one_repeated, c1, c2, c3, c4)
    
        return jax.lax.cond(disc != 0, three_distinct, repeated_root, c1, c2, c3, c4)
    
    def oneroot(c1, c2, c3, c4, disc):
        #One Real Root, Two Complex Conjugates (Cardano's formula)
        
        d0 = c2**2 - 3*c1*c3
        d1 = 2*c2**3 - 9*c1*c2*c3 + 27*c1**2*c4
        
        C0 = jnp.cbrt((d1 + jnp.sqrt(d1**2-4*d0**3))/2)
        #Select other root if needed (avoids division by C == 0 below)
        def proot(d1, d0):
            C = jnp.cbrt((d1 + jnp.sqrt(d1**2-4*d0**3))/2)
            return C
        def nroot(d1, d0):
            C = jnp.cbrt((d1 - jnp.sqrt(d1**2-4*d0**3))/2)
            return C
        
        C = jax.lax.cond(C0 == 0, nroot, proot, d1, d0)

        Z1 = -1/(3*c1)*(c2+C+d0/C)
        
        return Z1
    
    return jax.lax.cond(disc >= 0, threeroot, oneroot, c1, c2, c3, c4, disc)

In [14]:
def EOS_Params(y, T, P, Tc, Pc, w, k, R, C, EOS):
    """Evaluate every cubic-EOS quantity needed by `lnfc` for one state.

    Parameters
    ----------
    y : array of shape (C,)
        Mole fractions.
    T, P : scalar
        Temperature and pressure.
    Tc, Pc, w : array of shape (C,)
        Critical temperatures, critical pressures, acentric factors.
    k : array of shape (C, C)
        Binary interaction coefficients.
    R : float
        Gas constant.
    C : int
        Number of components.
    EOS : str
        'SRK' or 'PR'.

    Returns
    -------
    tuple
        (ai, aij, a, bi, b, A, B, d1, d2, Z, ALPHA, BETA)
    """
    # EOS-specific constants
    nu = nu_(EOS)
    m = m_(w, C, EOS)
    d1, d2 = d_1(EOS), d_2(EOS)

    # Attraction side
    ai = a_i(T, m, nu, Tc, Pc, R, C)
    aij = a_ij(ai, C, k)
    a = a_t(y, aij, C)
    A = A_dim(a, P, R, T)

    # Repulsion side
    bi = b_i(nu, Tc, Pc, R, C)
    b = b_t(y, bi, C)
    B = B_dim(b, P, R, T)

    # Compressibility factor and per-component ratios
    Z = CompFact(A, B, EOS)
    ALPHA = ALPHA_(y, aij, a, C)
    BETA = BETA_(bi, b, C)
    return ai, aij, a, bi, b, A, B, d1, d2, Z, ALPHA, BETA

In [15]:
#ln(f_i/(x_i*P)), the log fugacity coefficient of component i (the chemical potential departure from the ideal-gas mixture, divided by RT).
def lnfc(ai, aij, a, bi, b, A, B, d1, d2, Z, ALPHA, BETA, C):
    """Log fugacity coefficients of all components for the generalized cubic EOS.

    ln(phi_i) = BETA_i*(Z - 1) - ln(Z - B)
                - (A/B) / (d1 - d2) * (2*ALPHA_i - BETA_i) * ln((Z + d1*B)/(Z + d2*B))

    Only A, B, d1, d2, Z, ALPHA, BETA and C are used; ai, aij, a, bi, b are
    accepted so the full `EOS_Params` tuple can be splatted in.

    Returns
    -------
    jax array of shape (C,)
    """
    import jax.numpy as jnp
    log_ratio = jnp.log((Z+d1*B)/(Z+d2*B))
    repulsive = BETA*(Z-1)
    free_volume = jnp.log(Z-B)
    attractive = 2*(A/B)*ALPHA/(d1-d2)*log_ratio
    covolume_correction = (A/B)*BETA/(d1-d2)*log_ratio
    values = repulsive - free_volume - attractive + covolume_correction
    return jnp.zeros(C).at[:].set(values)

In [16]:
def ChemicalPotential(y, T, P, Tc, Pc, w, k, R, C, EOS):
    """Composition derivatives of the chemical potentials, dmu[i, j] = d mu_i / d y_j.

    mu_i / RT is taken as ln(phi_i) + ln(y_i) (+ terms independent of y). The
    last fraction y_r (r = C - 1) is treated as dependent, y_r = 1 - sum of the
    others, matching the closure used in `a_t` / `ALPHA_`.

    Parameters
    ----------
    y : array of shape (C,)
        Full mole-fraction vector.
    T, P, Tc, Pc, w, k, R, C, EOS :
        As for `EOS_Params`.

    Returns
    -------
    jax array of shape (C, C)
    """
    import jax
    import jax.numpy as jnp
    
    def lnf_(y, T, P, Tc, Pc, w, k, R, C, EOS):
        ai, aij, a, bi, b, A, B, d1, d2, Z, ALPHA, BETA = EOS_Params(y, T, P, Tc, Pc, w, k, R, C, EOS)
        return lnfc(ai, aij, a, bi, b, A, B, d1, d2, Z, ALPHA, BETA, C)
    
    # Only the Jacobian is needed; the plain forward evaluation of lnf_ was unused.
    dlnf = jax.jacfwd(lnf_, argnums = 0)(y, T, P, Tc, Pc, w, k, R, C, EOS)

    r = C-1
    dmu = jnp.zeros([C, C])
    for i in range(C):
        for j in range(C):
            if i == j:
                # d ln(y_j) / d y_j = 1/y_j
                dmuij = (dlnf[j, j] + 1/y[j])*R*T
            elif i == r:
                # y_r = 1 - sum(others)  =>  d ln(y_r) / d y_j = -1/y_r
                dmuij = (dlnf[r, j] - 1/y[r])*R*T
            else:
                dmuij = dlnf[i, j]*R*T
            dmu = dmu.at[i, j].set(dmuij)
    return dmu

In [17]:
def GeneralizedCubicHessian(y, yr, T, P, Tc, Pc, w, k, R, C, EOS):
    """Reduced (C-1)x(C-1) matrix H[i, j] = dmu[i, j] - dmu[r, j], with r = C - 1.

    Parameters
    ----------
    y : array of shape (C-1,)
        Independent mole fractions.
    yr : scalar
        Last mole fraction; appended to `y` before calling `ChemicalPotential`.
    T, P, Tc, Pc, w, k, R, C, EOS :
        As for `EOS_Params`.

    Returns
    -------
    jax array of shape (C-1, C-1)
    """
    import jax.numpy as jnp
    r = C-1
    full_y = jnp.append(y, yr)
    dmu = ChemicalPotential(full_y, T, P, Tc, Pc, w, k, R, C, EOS)
    # Subtract row r from each of the first r rows, restricted to the first r columns
    reduced = dmu[:r, :r] - dmu[r, :r][None, :]
    return jnp.zeros([r, r]).at[:, :].set(reduced)

In [18]:
def l_min(H):
    """Smallest-magnitude eigenvalue of H and its unit eigenvector.

    Runs 50 steps of inverse power iteration starting from the normalised
    all-ones vector. The eigenvalue is returned as the quotient u . (H u).
    Assumes H is invertible. It also assumes the smallest-|lambda| eigenvalue is
    well separated and that the start vector has a component along its
    eigenvector; neither is checked.

    Returns
    -------
    (scalar, array of shape (n,))
    """
    import jax.numpy as jnp

    def _unit(vec):
        return vec/jnp.linalg.norm(vec)

    H_inv = jnp.linalg.inv(H)
    vec = _unit(jnp.ones(len(H)))
    for _ in range(50):
        vec = _unit(jnp.dot(H_inv, vec))

    eig = jnp.dot(jnp.dot(H, vec), vec)
    return eig, vec

In [19]:
def CostFunction(TP, y, Tc, Pc, w, k, R, C, EOS):
    """Objective driven to zero by the (T, P) search: lambda_min**2 + cubic**2.

    lambda_min is the smallest-magnitude eigenvalue of the reduced Hessian H
    (see `l_min`). u is its eigenvector. cubic is the derivative of H with
    respect to the independent mole fractions, contracted three times with u.
    Both terms presumably vanish at a critical point; confirm against the
    criterion this was derived from.

    Parameters
    ----------
    TP : sequence of 2
        [T, P].
    y : array of shape (C,)
        Full mole-fraction vector; the last entry is split off as yr.
    Tc, Pc, w, k, R, C, EOS :
        As for `EOS_Params`.

    Returns
    -------
    jax scalar
    """
    import jax
    import jax.numpy as jnp
    T = TP[0]
    P = TP[1]
    yr = y[-1]
    y = y[0:-1]
    # NOTE(review): these jit wrappers are rebuilt on every objective evaluation;
    # hoisting them out of the function may reduce dispatch overhead — profile to confirm.
    Hf = jax.jit(GeneralizedCubicHessian, static_argnames=['C', 'R', 'EOS'])
    H = Hf(y, yr, T, P, Tc, Pc, w, k, R, C, EOS)
    lmin = jax.jit(l_min)
    l, u = lmin(H)
    
    Q = l**2
    
    # dH[i, j, m] = d H[i, j] / d y_m; contract every axis with u.
    dHf = jax.jacfwd(Hf, argnums = 0)
    dH = dHf(y, yr, T, P, Tc, Pc, w, k, R, C, EOS)
    # Named `cubic` so it no longer shadows the component-count parameter C.
    cubic = jnp.dot(jnp.dot(jnp.dot(dH, u), u), u)**2
    
    return cubic+Q

In [20]:
def minimizer(MxN, DataSet, EOS, y_given = None, TP_bound = None):
    """Search (T, P) for the minimum of `CostFunction` for one mixture.

    Parameters
    ----------
    MxN, DataSet :
        Passed to `KeyFunctions.LookUpMix` to select the mixture.
    EOS : str
        'SRK' or 'PR'.
    y_given : array-like, optional
        Overrides the mole fractions returned by `LookUpMix`.
    TP_bound : list of (low, high), optional
        Bounds for [T, P]. Defaults to T in [0.75, 1.25] x and P in [1, 2] x
        the mole-fraction-weighted Tc and Pc.

    Returns
    -------
    (OptimizeResult, list)
        The differential-evolution result and the bounds that were used.
    """
    import jax.numpy as jnp
    import KeyFunctions as me
    import jax
    import scipy as sp
    # A bare `import scipy` does not expose `sp.optimize` on every SciPy version.
    import scipy.optimize

    # NOTE(review): process-wide switch; arrays created before this call keep their 32-bit dtype.
    jax.config.update("jax_enable_x64", True)
    [y, Tc, Pc, w, C, R, Vc, k, mxNames] = me.LookUpMix(MxN, DataSet, EOS)
    
    if y_given is not None:
        y = y_given
     
    # Mole-fraction-weighted pseudo-critical temperature and pressure
    TP_guess = [jnp.dot(y, Tc), jnp.dot(y, Pc)]
    if TP_bound is None:
        TP_bound = [(TP_guess[0]*0.75, TP_guess[0]*1.25), (TP_guess[1], TP_guess[1]*2)]
    
    minima = sp.optimize.differential_evolution(CostFunction, bounds = TP_bound, args = (y, Tc, Pc, w, k, R, C, EOS), mutation = 0.95, atol = 1e-14, init = 'sobol')
    
    # Objective not near zero: retry with different DE settings (popsize=40, recombination=0.4)
    if minima.fun > 1:
        #TP_bound = [(TP_guess[0]-50, TP_guess[0]+200), (TP_guess[1]*0.4, TP_guess[1]*4)]
        minima = sp.optimize.differential_evolution(CostFunction, bounds = TP_bound, args = (y, Tc, Pc, w, k, R, C, EOS),\
                                                mutation = 0.95, popsize = 40, tol = 1e-14, recombination = 0.4)
        if minima.fun > 1:
            display('Erroneous Minima')
    
    return minima, TP_bound

In [36]:
import KeyFunctions as me
import scipy as sp
# `sp.constants` (below) and `sp.optimize` (later cells) are submodules; import them
# explicitly because a bare `import scipy` does not expose them on every SciPy version.
import scipy.constants
import scipy.optimize
import jax.numpy as jnp
MxN = 5
DataSet = "DimJiaLi Unique"
EOS = "SRK"
# NOTE(review): overwritten by the TP_bound assignment in the next cell before it is used.
TP_bound = [(300, 400),(6e6, 8e6)]
#minima, TP_bound = minimizer(MxN, DataSet, EOS, y_given = None, TP_bound = TP_bound)

# NOTE(review): `minimizer` (which enables jax_enable_x64) is not called on this path,
# so the arrays below are created with JAX's default precision unless x64 was enabled earlier.
R = sp.constants.R
# Hard-coded 4-component mixture (presumably copied from a LookUpMix result — confirm).
y= jnp.array([0.32013065, 0.4021396 , 0.03397377, 0.243756  ])
Tc= jnp.array([1444.7587, 1571.4685,  895.5158, 1412.933 ])
Pc= jnp.array([ 891032.75,  471466.  , 7106475.5 , 1263401.6 ])
w= jnp.array([2.0821772 , 2.159455  , 0.24360026, 1.509761  ])
k= jnp.array([[0.        , 0.13556027, 0.12379241, 0.13632333],\
       [0.13556027, 0.        , 0.11582702, 0.13421768],\
       [0.12379241, 0.11582702, 0.        , 0.12719214],\
       [0.13632333, 0.13421768, 0.12719214, 0.        ]])
C = len(y)


In [None]:
# Differential-evolution search over T in [1000, 2000] and P in [1e7, 1.5e7]
# (units follow Tc / Pc — presumably K and Pa).
TP_bound = [(1000, 2000),(10e6, 15e6)]
de_settings = {'mutation': 0.95, 'atol': 1e-14, 'init': 'sobol', 'popsize': 40, 'tol': 1e-14}
cost_args = (y, Tc, Pc, w, k, R, C, EOS)
minima = sp.optimize.differential_evolution(CostFunction, bounds = TP_bound, args = cost_args, **de_settings)

In [33]:
# Display the differential-evolution result.
# NOTE(review): the saved output below is stale — this cell ran as In [33] before the
# setup cell (In [36]) and the DE cell (In [None]). Its x[1] ≈ 5.0e6 also lies outside
# the current P bounds (1e7, 1.5e7). fun ≈ 5e16 is far from zero, so this is presumably
# not a converged critical point. Re-run top to bottom to refresh.
minima

             message: Optimization terminated successfully.
             success: True
                 fun: 5.236524200021197e+16
                   x: [ 1.000e+03  5.016e+06]
                 nit: 88
                nfev: 11395
          population: [[ 1.000e+03  5.016e+06]
                       [ 1.000e+03  5.015e+06]
                       ...
                       [ 1.000e+03  5.016e+06]
                       [ 1.000e+03  5.016e+06]]
 population_energies: [ 5.237e+16  5.237e+16 ...  5.237e+16  5.237e+16]

In [None]:
# Reload the mixture from the data set.
# NOTE(review): this overwrites the hard-coded y, Tc, Pc, w, C, R, k from the setup cell.
[y, Tc, Pc, w, C, R, Vc, k, mxNames] = me.LookUpMix(MxN, DataSet, EOS)
# Grid search of CostFunction over TP_bound; the last two full_output values are discarded.
# finish=None presumably skips any local refinement of the best grid point — confirm in the scipy docs.
x, f, _, _ = sp.optimize.brute(CostFunction, ranges = TP_bound, args = (y, Tc, Pc, w, k, R, C, EOS), full_output =True, finish = None)
display(x)
display(f)