In [883]:
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import scipy.linalg
import scipy.optimize

import pyequion2

In [884]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [885]:
class TransportBuilder(object):
    """Assemble residuals for a 1-D near-wall reactive transport problem.

    The unknowns are, for every retained solute and every grid point, a
    concentration ``c`` and a reduced chemical potential ``mu``. They are
    stacked as ``cmu = vstack([c, mu])`` with shape
    ``(2*nspecies, npoints)``. The class provides:

    * interior transport residuals (molecular + turbulent diffusion),
    * a bulk boundary condition at the last grid point,
    * two alternative wall boundary conditions at the first grid point,
    * the potential/concentration consistency residual,
    * flattened wrappers suitable for ``scipy.optimize``.

    Typical use: construct, call :meth:`set_species`, call :meth:`make_grid`,
    then evaluate residuals.
    """

    # Coefficient ``a`` of the cubic wall law ``D_t(y) = a*y**3`` used for the
    # turbulent diffusivity.
    WALL_DIFFUSION_COEFFICIENT = 9.5*1e-4

    # When True (the default, matching the historical behavior of this class)
    # the log activity coefficients are identically zero. Set to False on the
    # class or on an instance to use the model built by
    # ``reduce_activity_function``.
    # NOTE(review): the non-ideal path was unreachable before this flag was
    # introduced, so it is untested -- confirm output shapes before relying on it.
    use_ideal_activity = True

    def __init__(self, eqsys, TK, shear_velocity, kinematic_viscosity, cbulk, phases):
        """Store the problem definition; no computation happens here.

        Parameters
        ----------
        eqsys : pyequion2 equilibrium system providing species, formula
            matrix, solid phases and equilibrium constants.
        TK : temperature passed to the equilibrium-constant, activity and
            diffusion-coefficient routines.
        shear_velocity, kinematic_viscosity : flow parameters used for the
            wall length/time scales and to scale diffusion coefficients.
        cbulk : 1-D array compared against
            ``reduced_formula_matrix @ c[:, -1]`` in the bulk condition.
        phases : list of solid phase names looked up in
            ``eqsys.solid_phase_names``.
        """
        self.eqsys = eqsys
        self.TK = TK
        self.shear_velocity = shear_velocity
        self.kinematic_viscosity = kinematic_viscosity
        self.cbulk = cbulk
        self.phases = phases

    def set_species(self, species_to_remove=None):
        """Select the solutes to model and precompute the reduced matrices.

        Parameters
        ----------
        species_to_remove : iterable of species names to exclude, or None
            (default) to keep every solute that has chemical-potential data.

        Side effects: sets ``species``, ``nspecies``, index arrays, the
        reduced formula/reaction matrices, ``closure_matrix``, ``logk_solid``,
        reduced standard potentials and diffusion coefficients, and builds the
        activity function. Also prints the two reduced matrix shapes.
        """
        # Always use the system this builder was constructed with, never a
        # same-named notebook global.
        eqsys = self.eqsys
        removed = () if species_to_remove is None else tuple(species_to_remove)

        self.basedict = {k: v
                         for k, v in pyequion2.datamods.chemical_potentials.items()
                         if k in eqsys.solutes
                         and k not in removed}
        self.species = list(self.basedict.keys())
        self.nspecies = len(self.species)
        self.species_ind = jnp.array([eqsys.species.index(spec) for spec in self.species])
        self.solid_ind = jnp.array([eqsys.solid_phase_names.index(phase) for phase in self.phases])

        # The first two rows of the formula matrix are dropped.
        # NOTE(review): presumably these rows are not independent balances for
        # the solute-only problem -- confirm against pyequion2's row ordering.
        self.reduced_formula_matrix = eqsys.formula_matrix[2:, self.species_ind]
        self.reduced_reaction_vector = \
            eqsys.solid_stoich_matrix[self.solid_ind, :][:, self.species_ind]
        print(self.reduced_formula_matrix.shape)
        print(self.reduced_reaction_vector.shape)

        # Rows of closure_matrix span the balance combinations annihilated by
        # the solid reactions: closure_matrix @ (A @ R.T) == 0.
        reaction_in_balance_space = self.reduced_formula_matrix@(self.reduced_reaction_vector.T)
        self.closure_matrix = jnp.array(
            scipy.linalg.null_space(reaction_in_balance_space.T).T
        )

        self.logk_solid = eqsys.get_solid_log_equilibrium_constants(self.TK)[self.solid_ind]

        # Standard potentials divided by R*T.
        self.reduced_standard_potentials = jnp.array([self.basedict[spec]['mu0'] for spec in self.species])
        self.reduced_standard_potentials /= (pyequion2.constants.GAS_CONSTANT*self.TK)

        self.reduce_activity_function()

        # Diffusion coefficients scaled by the kinematic viscosity.
        diffusion_coefficients = pyequion2.equilibrium_backend.diffusion_coefficients.get_diffusion_coefficients(
            self.species, self.TK)
        self.reduced_diffusion_coefficients = jnp.array(diffusion_coefficients)/self.kinematic_viscosity

    def reduce_activity_function(self):
        """Build the activity model restricted to ``self.species`` (JAX backend)."""
        self._actfunc = pyequion2.equilibrium_backend.ACTIVITY_MODEL_MAP[self.eqsys.activity_model](
                            self.species, backend="jax")

    def activity_model_func(self, molals, TK):
        """Return natural-log activity coefficients with the shape of ``molals``.

        With ``use_ideal_activity`` True (default) this returns zeros. Otherwise
        it evaluates the model from :meth:`reduce_activity_function`, drops its
        first output column and divides by ``LOG10E``.
        """
        if self.use_ideal_activity:
            return jnp.zeros_like(molals)
        return self._actfunc(molals, TK)[..., 1:]/pyequion2.constants.LOG10E

    def make_grid(self, ngrid, ymax):
        """Create a uniform grid of ``ngrid`` intervals (``ngrid+1`` points) on [0, ymax]."""
        self.ngrid = ngrid
        self.npoints = ngrid + 1
        self.ymax = ymax
        self.ygrid, self.ystep = jnp.linspace(0, self.ymax, self.npoints, retstep=True)

    def wall_diffusion_plus(self, yplus):
        """Turbulent diffusivity ``a*yplus**3`` with ``a = WALL_DIFFUSION_COEFFICIENT``."""
        return self.WALL_DIFFUSION_COEFFICIENT*yplus**3

    def wall_diffusion_plus_deriv(self, yplus):
        """Derivative of :meth:`wall_diffusion_plus` with respect to ``yplus``."""
        return 3*self.WALL_DIFFUSION_COEFFICIENT*yplus**2

    def bulk_boundary_condition(self, c, mu):
        """Residual ``A @ c[:, -1] - cbulk`` as a column, shape ``(nbalances, 1)``.

        ``mu`` is accepted for a uniform signature but not used.
        """
        cbulk = self.cbulk[..., None]
        return self.reduced_formula_matrix@(c[:, -1][..., None]) - cbulk

    def _wall_state(self, c):
        """Return the wall gradient and log saturation, both as columns.

        Returns
        -------
        dc : (nspecies, 1) forward-difference concentration gradient at the wall.
        logsatur : (nphases, 1) ``R @ log(activity) - logk_solid`` at the wall.
        """
        c_wall = c[:, 0]
        dc = ((c[:, 1] - c_wall)/self.ystep)[..., None]
        loga = jnp.log(c_wall) + self.activity_model_func(c_wall, self.TK)
        # Subtract as 1-D vectors, then add the column axis. Subtracting an
        # (n, 1) column from an (n,) vector would broadcast to (n, n) whenever
        # more than one solid phase is present.
        logsatur = (self.reduced_reaction_vector@loga - self.logk_solid)[..., None]
        return dc, logsatur

    def equilibrium_wall_boundary_condition(self, c, mu):
        """Wall condition assuming equilibrium with the solid phases.

        Stacks (a) the log-saturation of each solid phase, required to be zero,
        and (b) the balance fluxes projected by ``closure_matrix``, i.e. the
        flux components not produced by the solid reactions. ``mu`` is unused.

        Returns an array of shape ``(nphases + closure_matrix.shape[0], 1)``.
        """
        dc, logsatur = self._wall_state(c)
        closure_flux = self.closure_matrix@self.reduced_formula_matrix@dc
        return jnp.vstack([logsatur, closure_flux])

    def modeled_wall_boundary_condition(self, x, mu):
        """Wall condition with a supersaturation-driven reaction rate.

        Parameters
        ----------
        x : concentration array ``(nspecies, npoints)``. The name is kept for
            backward compatibility; it plays the role of ``c`` elsewhere.
        mu : unused.

        The rate per phase is ``1e-3 * max(exp(logsatur) - 1, 0)``. The residual
        is the balance flux at the wall minus the balance consumption by the
        reactions, shape ``(nbalances, 1)``.
        """
        # Use the argument, not a notebook-level global named `c`.
        c = x
        dc, logsatur = self._wall_state(c)
        fsatur = jnp.clip(jnp.exp(logsatur)-1, 0.0, jnp.inf)*1e-3
        res = self.reduced_formula_matrix@dc \
              - self.reduced_formula_matrix@(self.reduced_reaction_vector.T)@fsatur
        return res

    def transport_residual(self, c, mu):
        """Interior balance residuals, shape ``(nbalances, npoints-2)``.

        Discretizes ``d/dy(D*c*dmu/dy) + d/dy(D_t*dc/dy)`` with central
        differences on the interior points and projects onto the balances
        with ``reduced_formula_matrix``.
        """
        # c  : (nspecies, npoints)
        # mu : (nspecies, npoints)
        ymiddle = self.ygrid[1:-1]
        cm = c[:, 1:-1]
        d1c = (c[:, 2:] - c[:, :-2])/(2*self.ystep) #(nspecies, npoints-2)
        d1mu = (mu[:, 2:] - mu[:, :-2])/(2*self.ystep) #(nspecies, npoints-2)
        d2c = (c[:, 2:] - 2*c[:, 1:-1] + c[:, :-2])/(self.ystep**2) #(nspecies, npoints-2)
        d2mu = (mu[:, 2:] - 2*mu[:, 1:-1] + mu[:, :-2])/(self.ystep**2) #(nspecies, npoints-2)
        molecular_diffusions = self.reduced_diffusion_coefficients[..., None] #(nspecies, 1)
        turbulent_diffusions = self.wall_diffusion_plus(ymiddle)
        turbulent_diffusions_deriv = self.wall_diffusion_plus_deriv(ymiddle)
        # Product rule applied to each divergence term.
        term1 = molecular_diffusions*(d1c*d1mu + cm*d2mu)
        term2 = turbulent_diffusions*d2c + turbulent_diffusions_deriv*d1c
        term = term1 + term2
        res = self.reduced_formula_matrix@term
        return res

    def potential_residual(self, c, mu):
        """Residual ``mu - (mu0 + log(gamma) + log(c))``, shape ``(nspecies, npoints)``."""
        logg = self.activity_model_func(c.T, self.TK).T
        mu0 = self.reduced_standard_potentials[..., None]
        return mu - (mu0 + logg + jnp.log(c))

    @staticmethod
    def _split_fields(cmu):
        """Split a stacked ``(2*nspecies, npoints)`` array into ``(c, mu)``."""
        half = cmu.shape[0]//2
        return cmu[:half, :], cmu[half:, :]

    def full_residual(self, cmu):
        """All residuals for the stacked unknowns.

        The balance rows hold the equilibrium wall condition (first column),
        the interior transport residuals, and the bulk condition (last
        column); the potential residual rows are stacked underneath.
        """
        c, mu = self._split_fields(cmu)
        res1a = self.equilibrium_wall_boundary_condition(c, mu)
        res1b = self.transport_residual(c, mu)
        res1c = self.bulk_boundary_condition(c, mu)
        res1 = jnp.hstack([res1a, res1b, res1c])
        res3 = self.potential_residual(c, mu)
        return jnp.vstack([res1, res3])

    def bulk_residual(self, cmu):
        """Bulk boundary residual stacked on top of the potential residual."""
        c, mu = self._split_fields(cmu)
        res1 = self.bulk_boundary_condition(c, mu)
        res3 = self.potential_residual(c, mu)
        return jnp.vstack([res1, res3])

    def gibbs_free_energy(self, cmu):
        """Grid-averaged ``sum_i c_i*mu_i`` (a scalar array)."""
        c, mu = self._split_fields(cmu)
        return jnp.mean(jnp.sum(c*mu, axis=0))

    def flattened_equality_constraint(self, x):
        """:meth:`full_residual` for a flat vector ``x``; returns a flat vector."""
        cmu = x.reshape(2*self.nspecies, self.npoints)
        return self.full_residual(cmu).flatten()

    def flattened_minimization_objective(self, x):
        """:meth:`gibbs_free_energy` for a flat vector ``x``; returns a scalar."""
        cmu = x.reshape(2*self.nspecies, self.npoints)
        return self.gibbs_free_energy(cmu).flatten()[0]

    def wall_length(self):
        """Viscous length scale ``kinematic_viscosity / shear_velocity``."""
        return self.kinematic_viscosity/self.shear_velocity

    def wall_time(self):
        """Viscous time scale ``kinematic_viscosity / shear_velocity**2``."""
        return self.kinematic_viscosity/(self.shear_velocity**2)

    def get_log_equilibrium_constants(self, TK):
        """Delegate to ``eqsys.get_log_equilibrium_constants(TK)``."""
        return self.eqsys.get_log_equilibrium_constants(TK)


In [886]:
# Equilibrium system generated from the elements Na and Cl, using the JAX backend.
# NOTE(review): logbase="e" presumably makes equilibrium constants natural logs -- confirm.
eqsys = pyequion2.EquilibriumBackend(["Na", "Cl"], from_elements=True, backend="jax", logbase="e")

In [889]:
# Problem definition: halite precipitation from an NaCl solution.
builder = TransportBuilder(
    eqsys=eqsys,
    TK=298.15,
    shear_velocity=0.1,
    kinematic_viscosity=1e-6,
    cbulk=jnp.array([1e-3, 1e-3, 0.0]),
    phases=["Halite"],
)
builder.set_species()
# 10 intervals (11 points) on [0, 5].
builder.make_grid(ngrid=10, ymax=5)

(3, 4)
(1, 4)


In [890]:
# Solutes retained by set_species(); rows of the reduced matrices follow this order.
builder.species

['Cl-', 'H+', 'Na+', 'OH-']

In [891]:
# Initial guess: uniform concentration 1e-3 for every species at every grid
# point, with potentials that satisfy mu = mu0 + log(c) exactly (ideal case).
c = np.full((builder.nspecies, builder.npoints), 1e-3)
mu = builder.reduced_standard_potentials[..., None] + np.log(c)
# Stacked unknowns, shape (2*nspecies, npoints): concentrations on top.
cmu = jnp.vstack([c, mu])

In [892]:
# Sanity check: bulk boundary residual evaluated at the initial guess.
builder.bulk_boundary_condition(c, mu)

DeviceArray([[0.],
             [0.],
             [0.]], dtype=float32)

In [893]:
# Shape of the null-space projector built in set_species().
builder.closure_matrix.shape

(2, 3)

In [894]:
# Sanity check: mu was built as mu0 + log(c), so this residual should vanish.
builder.potential_residual(c, mu)

DeviceArray([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
             [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
             [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
             [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32)

In [895]:
# Equilibrium wall condition at the initial guess: the first row(s) hold the
# solid-phase log saturation, the remaining rows the closure-projected fluxes.
builder.equilibrium_wall_boundary_condition(c, mu)

DeviceArray([[-17.430569],
             [  0.      ],
             [  0.      ]], dtype=float32)

In [896]:
# Rate-based wall condition evaluated at the initial guess.
builder.modeled_wall_boundary_condition(c, mu)

DeviceArray([[0.],
             [0.],
             [0.]], dtype=float32)

In [897]:
# Interior transport residuals at the initial guess (one column per interior grid point).
builder.transport_residual(c, mu)

DeviceArray([[0., 0., 0., 0., 0., 0., 0., 0., 0.],
             [0., 0., 0., 0., 0., 0., 0., 0., 0.],
             [0., 0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32)

In [898]:
# Stacked unknowns: concentration rows first, then potential rows.
cmu

DeviceArray([[ 1.0000000e-03,  1.0000000e-03,  1.0000000e-03,
               1.0000000e-03,  1.0000000e-03,  1.0000000e-03,
               1.0000000e-03,  1.0000000e-03,  1.0000000e-03,
               1.0000000e-03,  1.0000000e-03],
             [ 1.0000000e-03,  1.0000000e-03,  1.0000000e-03,
               1.0000000e-03,  1.0000000e-03,  1.0000000e-03,
               1.0000000e-03,  1.0000000e-03,  1.0000000e-03,
               1.0000000e-03,  1.0000000e-03],
             [ 1.0000000e-03,  1.0000000e-03,  1.0000000e-03,
               1.0000000e-03,  1.0000000e-03,  1.0000000e-03,
               1.0000000e-03,  1.0000000e-03,  1.0000000e-03,
               1.0000000e-03,  1.0000000e-03],
             [ 1.0000000e-03,  1.0000000e-03,  1.0000000e-03,
               1.0000000e-03,  1.0000000e-03,  1.0000000e-03,
               1.0000000e-03,  1.0000000e-03,  1.0000000e-03,
               1.0000000e-03,  1.0000000e-03],
             [-6.9607053e+00, -6.9607053e+00, -6.9607053e+00,
      

In [899]:
# Complete residual (wall, interior, bulk and potential rows) at the initial guess.
builder.full_residual(cmu)

DeviceArray([[-17.430569,   0.      ,   0.      ,   0.      ,   0.      ,
                0.      ,   0.      ,   0.      ,   0.      ,   0.      ,
                0.      ],
             [  0.      ,   0.      ,   0.      ,   0.      ,   0.      ,
                0.      ,   0.      ,   0.      ,   0.      ,   0.      ,
                0.      ],
             [  0.      ,   0.      ,   0.      ,   0.      ,   0.      ,
                0.      ,   0.      ,   0.      ,   0.      ,   0.      ,
                0.      ],
             [  0.      ,   0.      ,   0.      ,   0.      ,   0.      ,
                0.      ,   0.      ,   0.      ,   0.      ,   0.      ,
                0.      ],
             [  0.      ,   0.      ,   0.      ,   0.      ,   0.      ,
                0.      ,   0.      ,   0.      ,   0.      ,   0.      ,
                0.      ],
             [  0.      ,   0.      ,   0.      ,   0.      ,   0.      ,
                0.      ,   0.      ,   0.      ,  

In [900]:
# Objective value (grid-averaged sum of c*mu) at the initial guess.
builder.gibbs_free_energy(cmu)

DeviceArray(-0.02785307, dtype=float32)

In [901]:
# All residuals must vanish: express them as an equality constraint (lb == ub == 0)
# with a forward-mode JAX Jacobian.
equality_constraint = scipy.optimize.NonlinearConstraint(
    fun=builder.flattened_equality_constraint,
    lb=0.0,
    ub=0.0,
    jac=jax.jacfwd(builder.flattened_equality_constraint),
)
# Concentrations (first half of the flat vector) are non-negative; potentials
# (second half) are unbounded.
lb = np.concatenate([
    np.zeros(builder.nspecies*builder.npoints),
    np.full(builder.nspecies*builder.npoints, -np.inf),
])
bounds = scipy.optimize.Bounds(lb=lb, ub=np.inf)

In [902]:
# Minimize the averaged free energy subject to the residual constraints,
# starting from the flattened initial guess.
sol = scipy.optimize.minimize(
    fun=builder.flattened_minimization_objective,
    x0=cmu.flatten(),
    method='trust-constr',
    jac=jax.grad(builder.flattened_minimization_objective),
    bounds=bounds,
    constraints=equality_constraint,
)

KeyboardInterrupt: 

In [None]:
# Whether the optimizer reported convergence.
sol.success

In [None]:
# Restore the (2*nspecies, npoints) layout and split it into its two halves:
# concentrations first, potentials second.
cmu = sol.x.reshape(2*builder.nspecies, builder.npoints)
c = cmu[:builder.nspecies, :]
mu = cmu[builder.nspecies:, :]

In [None]:
# Grid coordinates created by make_grid().
builder.ygrid

In [None]:
# `c` stores concentrations directly (they are bounded below by 0 and passed
# through log() inside the residuals), so plot them without exponentiating.
na_index = builder.species.index("Na+")
fig, ax = plt.subplots()
ax.plot(builder.ygrid, c[na_index, :])
ax.set(xlabel="y (grid coordinate)",
       ylabel="Na+ concentration",
       title="Na+ concentration profile from the wall to the bulk");