In [1]:
import sys
sys.path.append("..")

import numpy as np
import jax.numpy as jnp
import jax
import scipy.linalg
import scipy.optimize
import matplotlib.pyplot as plt

import pyequion2
import tmcdiff.builder

In [2]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [3]:
eqsys = pyequion2.EquilibriumBackend(["Na", "Cl", "Ca", "C"], from_elements=True, backend="jax", logbase="e")



In [4]:
builder = tmcdiff.builder.TransportBuilder(eqsys,
                                           298.15,
                                           0.1,
                                           1e-6,
                                           jnp.array([65, 56, 75.0, 28.0, 0.0]), ["Calcite"])
builder.set_species(['CaCO3', 'CaHCO3+', 'CaOH+', 'Na2CO3', 'NaCO3-', 'NaHCO3', 'NaOH', 'CaOH+',
                     'H+', 'OH-'])
builder.make_grid(3, 5)
builder.species

['CO2', 'CO3--', 'Ca++', 'Cl-', 'HCO3-', 'Na+']

In [None]:
logc = np.ones([builder.nspecies, builder.ngrid])*0.0 - 1.0
mu = builder.reduced_standard_potentials[..., None] + logc
logcmu = jnp.vstack([logc, mu])

In [6]:
builder.simplify()
builder.full_residual(logcmu)

DeviceArray([[ 14.393301  ,   0.        , -64.63212   ],
             [  0.        ,   0.        , -55.632122  ],
             [  0.        ,   0.        , -74.63212   ],
             [  0.        ,   0.        , -26.896362  ],
             [  0.        ,   0.        ,  -0.36787945],
             [ -0.4658966 ,  -0.4658966 ,  -0.4658966 ],
             [  1.8973236 ,   1.8973236 ,   1.8973236 ],
             [  1.234909  ,   1.234909  ,   1.234909  ],
             [  0.56424713,   0.56424713,   0.56424713],
             [  0.47433472,   0.47433472,   0.47433472],
             [  0.23307037,   0.23307037,   0.23307037]], dtype=float32)

In [7]:
print(jax.jit(builder.flattened_equality_constraint)(logcmu.flatten()))
print(jax.jit(jax.jacfwd(builder.flattened_equality_constraint))(logcmu.flatten()))
print(jax.jit(builder.flattened_minimization_objective)(logcmu.flatten()))
print(jax.jit(jax.grad(builder.flattened_minimization_objective))(logcmu.flatten()))

[ 14.393301     0.         -64.63212      0.           0.
 -55.632122     0.           0.         -74.63212      0.
   0.         -26.896362     0.           0.          -0.36787945
  -0.4658966   -0.4658966   -0.4658966    1.8973236    1.8973236
   1.8973236    1.234909     1.234909     1.234909     0.56424713
   0.56424713   0.56424713   0.47433472   0.47433472   0.47433472
   0.23307037   0.23307037   0.23307037]
[[0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 ...
 [0. 0. 0. ... 1. 0. 0.]
 [0. 0. 0. ... 0. 1. 0.]
 [0. 0. 0. ... 0. 0. 1.]]
-365.32776
[-19.213984   -19.213984   -19.213984   -26.236238   -26.236238
 -26.236238   -27.47984    -27.47984    -27.47984     -6.615661
  -6.615661    -6.615661   -29.152672   -29.152672   -29.152672
 -13.077531   -13.077531   -13.077531     0.12262648   0.12262648
   0.12262648   0.12262648   0.12262648   0.12262648   0.12262648
   0.12262648   0.12262648   0.12262648   0.12262648   0.12262648
   0.12262648   0.12262

In [8]:
constraints = {'type':'eq',
               'fun':jax.jit(builder.flattened_equality_constraint),
               'jac':jax.jit(jax.jacfwd(builder.flattened_equality_constraint))}
equality_constraint = scipy.optimize.NonlinearConstraint(
    jax.jit(builder.flattened_equality_constraint),
    lb=0.0,
    ub=0.0,
    jac = jax.jit(jax.jacfwd(builder.flattened_equality_constraint)))


In [9]:
sol_simple = scipy.optimize.minimize(jax.jit(builder.flattened_minimization_objective),
                              logcmu.flatten(),
                              jac = jax.jit(jax.grad(builder.flattened_minimization_objective)),
                              constraints=equality_constraint,
                              method='trust-constr')

In [10]:
sol_simple.success

False

In [12]:
builder.solve()

AttributeError: 'TransportBuilder' object has no attribute 'solve'