In [4]:
import sys
sys.path.append("..")

import numpy as np
import jax.numpy as jnp
import jax
import scipy.linalg
import scipy.optimize
import matplotlib.pyplot as plt

import pyequion2
import tmcdiff.builder
import matplotlib.pyplot as plt

In [5]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [3]:
eqsys = pyequion2.EquilibriumBackend(["Na", "Cl", "Ca", "C"], from_elements=True, backend="jax", logbase="e")



In [10]:
builder = tmcdiff.builder.TransportBuilder(eqsys,
                                           298.15,
                                           0.1,
                                           1e-6,
                                           jnp.array([65, 56, 75.0, 28.0, 0.0]), ["Calcite"])
builder.set_species(['CaCO3', 'CaHCO3+', 'CaOH+', 'Na2CO3', 'NaCO3-', 'NaHCO3', 'NaOH', 'CaOH+',
                     'H+', 'OH-'])
builder.make_grid(3, 5)
builder.species

['CO2', 'CO3--', 'Ca++', 'Cl-', 'HCO3-', 'Na+']

In [11]:
logc = np.ones([builder.nspecies, builder.ngrid])*0.0 - 1.0
mu = builder.reduced_standard_potentials[..., None] + logc
logcmu = jnp.vstack([logc, mu])

In [12]:
builder.simplify()
builder.full_residual(logcmu)

DeviceArray([[ 14.393301  ,   0.        , -64.63212   ],
             [  0.        ,   0.        , -54.896362  ],
             [  0.        ,   0.        , -74.63212   ],
             [  0.        ,   0.        , -27.63212   ],
             [  0.        ,   0.        ,  -0.36787945],
             [ -0.4658966 ,  -0.4658966 ,  -0.4658966 ],
             [  1.8973236 ,   1.8973236 ,   1.8973236 ],
             [  1.234909  ,   1.234909  ,   1.234909  ],
             [  0.56424713,   0.56424713,   0.56424713],
             [  0.47433472,   0.47433472,   0.47433472],
             [  0.23307037,   0.23307037,   0.23307037]], dtype=float32)

In [13]:
print(jax.jit(builder.flattened_equality_constraint)(logcmu.flatten()))
print(jax.jit(jax.jacfwd(builder.flattened_equality_constraint))(logcmu.flatten()))
print(jax.jit(builder.flattened_minimization_objective)(logcmu.flatten()))
print(jax.jit(jax.grad(builder.flattened_minimization_objective))(logcmu.flatten()))

[ 14.393301     0.         -64.63212      0.           0.
 -54.896362     0.           0.         -74.63212      0.
   0.         -27.63212      0.           0.          -0.36787945
  -0.4658966   -0.4658966   -0.4658966    1.8973236    1.8973236
   1.8973236    1.234909     1.234909     1.234909     0.56424713
   0.56424713   0.56424713   0.47433472   0.47433472   0.47433472
   0.23307037   0.23307037   0.23307037]
[[0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 ...
 [0. 0. 0. ... 1. 0. 0.]
 [0. 0. 0. ... 0. 1. 0.]
 [0. 0. 0. ... 0. 0. 1.]]
-365.32776
[-19.213984   -19.213984   -19.213984   -26.236238   -26.236238
 -26.236238   -27.47984    -27.47984    -27.47984     -6.615661
  -6.615661    -6.615661   -29.152672   -29.152672   -29.152672
 -13.077531   -13.077531   -13.077531     0.12262648   0.12262648
   0.12262648   0.12262648   0.12262648   0.12262648   0.12262648
   0.12262648   0.12262648   0.12262648   0.12262648   0.12262648
   0.12262648   0.12262

In [14]:
constraints = {'type':'eq',
               'fun':jax.jit(builder.flattened_equality_constraint),
               'jac':jax.jit(jax.jacfwd(builder.flattened_equality_constraint))}
equality_constraint = scipy.optimize.NonlinearConstraint(
    jax.jit(builder.flattened_equality_constraint),
    lb=0.0,
    ub=0.0,
    jac = jax.jit(jax.jacfwd(builder.flattened_equality_constraint)))


In [15]:
sol_simple = scipy.optimize.minimize(jax.jit(builder.flattened_minimization_objective),
                              logcmu.flatten(),
                              jac = jax.jit(jax.grad(builder.flattened_minimization_objective)),
                              constraints=equality_constraint,
                              method='trust-constr')

In [16]:
sol_simple

         cg_niter: 2816
     cg_stop_cond: 1
           constr: [array([ 6.8017292e+00, -2.8799686e-01, -4.8649632e+01, -2.0977886e-02,
       -9.2283434e-01,  1.0777283e-01, -4.2521711e-03, -1.1772488e+00,
       -1.5617100e+01, -2.3797576e-03, -5.5483043e-01, -3.6661720e-01,
        4.0473945e-02, -4.4173872e-01,  7.9992332e+00,  8.1643677e-01,
       -7.9487610e-01, -1.2063644e+01,  3.8017120e+00,  1.4663849e+00,
        2.5277557e+00,  3.8818512e+00,  1.7710876e-01, -2.6927612e+01,
        5.5050659e-01,  2.2394562e-01, -4.3907928e+00,  1.6401672e-01,
        3.4533691e-01, -2.9358978e+00,  1.1600113e+00, -2.2312927e-01,
       -1.5133003e+01], dtype=float32)]
      constr_nfev: [1000]
      constr_nhev: [0]
      constr_njev: [1000]
   constr_penalty: 1.944943600075987e+18
 constr_violation: 48.64963150024414
   execution_time: 9.06349802017212
              fun: DeviceArray(-9532.727, dtype=float32)
             grad: array([-1.3690460e-10, -2.1639787e-02, -2.4847784e+00, -3.8533

In [3]:
plt.figure()
for i in range(builder.nspecies):
    plt.plot(builder.ygrid, c[i, :], label=builder.species[i])
plt.legend()
plt.ylabel(r'$c$[molal]')
plt.xlabel(r'$y^+$')

NameError: name 'builder' is not defined