# Stixrude-Lithgow-Bertelloni pseudo-omnicomponent phase generation
Required Python packages/modules

In [None]:
import numpy as np
from os import path
import pandas as pd
import scipy.optimize as opt
import scipy.linalg as lin 
import scipy as sp
import sys
import sympy as sym

import matplotlib.pyplot as plt

Required ENKI modules (ignore the error message from Rubicon running under Python 3.6+)

In [None]:
from thermoengine import coder, core, phases, model, equilibrate

In [None]:
def get_subsolidus_phases(database='Berman'):
    """Return the phases of a database with melt, fluid and duplicates removed.

    Loads ``model.Database(database)`` and removes, from its ``phases``
    mapping, the liquid and water phases as well as every pure phase that is
    also an endmember of a solution phase (to avoid double counting).

    Parameters
    ----------
    database : str
        Name handed to ``model.Database``.  Only ``'Stixrude'`` has an
        endmember removal list defined; any other value trips the assertion.

    Returns
    -------
    The ``phases`` mapping of the loaded database.  Entries are popped from
    that mapping itself, not from a copy.
    """
    remove_phases = ['Liq','H2O']

    modelDB = model.Database(database)
    if database!='Stixrude':
        assert False, [
            'Need to define list of pure solution endmembers to be removed '+
            'from the system, to avoid double counting.']
    else:
        # Reference lists kept from the original notebook:
        # soln_keys_Stixrude = ['Fsp', 'Ol', 'Wds', 'Rwd', 'PrvS', 'PpvS', 'Opx',
        #                   'Cpx', 'hpCpx', 'AkiS', 'Grt', 'Fp', 'CfS', 'SplS']
        # pure_keys_Stixrude = ['CaPrv','Qz', 'Coe', 'Sti', 'Seif', 'Ky', 'Nph']
        pure_soln_endmems = [
            'An', 'Ab', 'Spl', 'Hc', 'Fo', 'Fa', 'MgWds', 'FeWds', 'MgRwd',
            'FeRwd', 'En', 'Fs', 'MgTs', 'oDi', 'Di', 'Hd', 'cEn',
            'CaTs', 'Jd', 'hpcEn', 'hpcFs',  'MgAki', 'FeAki', 'AlAki', 'Prp',
            'Alm', 'Grs', 'Maj', 'NaMaj', 'MgPrv', 'FePrv',
            'AlPrv', 'MgPpv', 'FePpv', 'AlPpv', 'Per', 'Wus', 'MgCf', 'FeCf',
            'NaCf']

    phases = modelDB.phases
    # A missing key raises KeyError, exactly as a bare pop() would.
    for unwanted in remove_phases:
        phases.pop(unwanted)
    for unwanted in pure_soln_endmems:
        phases.pop(unwanted)

    return phases
        
def system_energy_landscape(T, P, phases, prune_polymorphs=True, TOL=1e-3):
    """Tabulate per-atom energy and composition of every phase endmember.

    Parameters
    ----------
    T, P : float
        Conditions forwarded unchanged to each phase's ``gibbs_energy``.
    phases : mapping
        Phase objects keyed by name.  Each must expose ``props``,
        ``abbrev``, ``endmember_num``, ``phase_type`` and ``gibbs_energy``.
    prune_polymorphs : bool
        When true, the table is passed through ``remove_polymorphs``.
    TOL : float
        An element column is dropped when every row's atom fraction of that
        element is below this value.

    Returns
    -------
    phs_sym, endmem_ids, mu, elem_comps, sys_elems
        Phase abbreviation and endmember index of every row, the energy per
        atom, the row-normalised elemental composition restricted to the
        retained elements, and the matching entries of
        ``core.chem.PERIODIC_ORDER``.
    """
    comp_rows = []
    symbols = []
    member_ids = []
    energies = []
    for phase in phases.values():
        elem_comp = phase.props['element_comp']
        n_endmem = phase.endmember_num
        local_ids = list(np.arange(n_endmem))

        if phase.phase_type=='pure':
            # Pure phase: total Gibbs energy divided by atoms per formula.
            energies.append(phase.gibbs_energy(T, P)/np.sum(elem_comp))
        else:
            atoms_per_endmem = np.sum(elem_comp, axis=1)
            for i in local_ids:
                # Evaluate at the pure-endmember composition (unit vector i).
                # deriv={"dmol":1} is presumably the first mole-number
                # derivative, i.e. the endmember chemical potential.
                unit_mol = np.eye(n_endmem)[i]
                dG_dmol = phase.gibbs_energy(
                    T, P, mol=unit_mol, deriv={"dmol":1})
                energies.append(dG_dmol[0, i]/atoms_per_endmem[i])

        member_ids.extend(local_ids)
        symbols.extend(list(np.tile(phase.abbrev, n_endmem)))
        comp_rows.extend(elem_comp)

    # Convert formula-unit compositions into atom fractions.
    elem_comps = np.vstack(comp_rows)
    elem_comps = elem_comps/np.sum(elem_comps, axis=1)[:, np.newaxis]

    # Keep only elements that at least one row contains above TOL.
    elem_mask = ~np.all(elem_comps<TOL, axis=0)
    elem_comps = elem_comps[:, elem_mask]
    sys_elems = core.chem.PERIODIC_ORDER[elem_mask]

    mu = np.array(energies)
    endmem_ids = np.array(member_ids)
    phs_sym = symbols

    if prune_polymorphs:
        phs_sym, endmem_ids, mu, elem_comps = remove_polymorphs(
            phs_sym, endmem_ids, mu, elem_comps)

    return phs_sym, endmem_ids, mu, elem_comps, sys_elems

def remove_polymorphs(phs_sym, endmem_ids, mu, elem_comps, decimals=4):
    """Keep only the lowest-energy entry for each distinct composition.

    Compositions are compared after rounding to ``decimals`` places.  For
    every distinct rounded composition the row with the smallest ``mu`` is
    retained.

    Returns
    -------
    phs_sym_uniq : ndarray
    endmem_ids_uniq : list
    mu_uniq : ndarray
    elem_comps_uniq : ndarray
        The *rounded* distinct compositions, in the sorted order produced
        by ``np.unique``; the other outputs follow the same order.
    """
    rounded = np.round(elem_comps, decimals=decimals)
    elem_comps_uniq = np.unique(rounded, axis=0)

    # Index (into the input rows) of the winner for each distinct composition.
    winners = []
    for target in elem_comps_uniq:
        matches = np.all(rounded == target[np.newaxis,:], axis=1)
        candidates = np.where(matches)[0]
        best = candidates[np.argsort(mu[candidates])[0]]
        # Fails when the candidate energies contain NaN.
        assert np.all(mu[best] <= mu[candidates]), 'fail'
        winners.append(best)

    mu_uniq = np.array([mu[i] for i in winners])
    phs_sym_uniq = np.array([phs_sym[i] for i in winners])
    endmem_ids_uniq = [endmem_ids[i] for i in winners]
    elem_comps_uniq = np.array(elem_comps_uniq)

    return phs_sym_uniq, endmem_ids_uniq, mu_uniq, elem_comps_uniq



In [None]:
def min_energy_assemblage(bulk_comp, comp, mu, TOLmu=10, TOL=1e-5):
    """Find the lowest-energy phase assemblage matching a bulk composition.

    Every row of ``comp`` together with its entry of ``mu`` is one point of
    the energy landscape.  A trial energy is called *feasible* when a
    non-negative least-squares (NNLS) combination of those points reproduces
    ``(bulk_comp, mu_trial)`` with a residual norm of at most ``TOL``.
    Starting from a feasible energy, the trial energy is lowered in doubling
    steps until the fit fails; a root find between the last feasible and the
    first infeasible value then locates the energy at which the residual
    equals ``TOL``.

    Feasibility is judged against ``TOL`` rather than by exact comparison
    with zero, because whether NNLS reports a residual of exactly ``0.0``
    for a feasible target is an implementation detail.

    Parameters
    ----------
    bulk_comp : array_like, shape (n_elems,)
        Target composition.
    comp : ndarray, shape (n_pts, n_elems)
        Composition of every candidate point.
    mu : array_like, shape (n_pts,)
        Energy of every candidate point.
    TOLmu : float
        Absolute tolerance (``xtol``) of the root find on the energy.
    TOL : float
        Residual norm below which an NNLS fit counts as feasible.

    Returns
    -------
    wt_bulk : ndarray, shape (n_pts,)
        NNLS weights at the returned energy.
    mu_bulk : float
        Energy located by the root find.
    ind_assem : ndarray of int
        Indices of the points carrying a strictly positive weight.

    Raises
    ------
    ValueError
        If no energy between ``min(mu)`` and ``max(mu)`` yields a feasible
        fit, i.e. ``bulk_comp`` cannot be built from the rows of ``comp``.
    """
    bulk_comp = np.asarray(bulk_comp, dtype=float)
    mu = np.asarray(mu, dtype=float)
    basis = np.hstack((comp, mu[:, np.newaxis])).T

    def solve(mu_trial):
        # Build the target afresh so no state is shared between calls.
        return opt.nnls(basis, np.hstack((bulk_comp, mu_trial)))

    def misfit(mu_trial, shift=0):
        return solve(mu_trial)[1] - shift

    # Start from the mean energy; if that is not attainable, look for the
    # energy with the smallest residual inside the range spanned by mu.
    mu_bulk = np.mean(mu)
    wt, rnorm = solve(mu_bulk)
    if rnorm > TOL:
        # method='bounded' is given explicitly: some SciPy versions reject
        # or ignore `bounds` when no method is named.
        best = opt.minimize_scalar(
            misfit, bounds=(np.min(mu), np.max(mu)), method='bounded')
        mu_bulk = best.x
        wt, rnorm = solve(mu_bulk)
    if rnorm > TOL:
        raise ValueError(
            'bulk_comp cannot be reproduced by a non-negative combination '
            'of the supplied compositions (residual norm %g > TOL)' % rnorm)

    # Walk down in energy with a doubling step until the fit breaks.
    delmu = .1
    while True:
        mu_prev = mu_bulk
        mu_bulk = mu_prev - delmu
        wt, rnorm = solve(mu_bulk)
        delmu *= 2
        if rnorm > TOL:
            break

    # misfit(mu_prev, TOL) <= 0 < misfit(mu_bulk, TOL): a valid bracket.
    mu_bulk = opt.brentq(misfit, mu_prev, mu_bulk, args=(TOL,), xtol=TOLmu)
    wt_bulk, rnorm = solve(mu_bulk)

    ind_assem = np.where(wt_bulk>0)[0]
    return wt_bulk, mu_bulk, ind_assem


In [None]:
def eval_curv(comps, method, cross_term_inds):
    """Evaluate the curvature (non-linear) regressor(s) of a composition.

    Parameters
    ----------
    comps : array_like, shape (n_elems,) or (n_pts, n_elems)
        One composition or a stack of compositions.
    method : {'quad', 'quad-full', 'xlogx', 'none'}
        ``'quad'``      sum over the index pairs of ``X_i * X_j``;
        ``'quad-full'`` the individual ``X_i * X_j`` products, one column
        per index pair;
        ``'xlogx'``     sum of ``X * log(X)`` with ``0 * log(0)`` taken as 0;
        ``'none'``      an empty block with zero columns.
    cross_term_inds : array_like, shape (2, n_pairs)
        Row 0 holds the ``i`` indices and row 1 the ``j`` indices.  Only
        read by the two quadratic methods.

    Returns
    -------
    ndarray
        One value (or one row) per composition.  If ``comps`` was
        one-dimensional, the leading axis is removed.

    Raises
    ------
    ValueError
        If ``method`` is not one of the names listed above.
    """
    comps = np.asarray(comps)
    single_pt = comps.ndim==1
    if single_pt:
        comps = comps[np.newaxis,:]

    if method in ('quad', 'quad-full'):
        XiXj = comps[:, cross_term_inds[0]]*comps[:, cross_term_inds[1]]
        curv_term = np.sum(XiXj,axis=1) if method=='quad' else XiXj
    elif method=='xlogx':
        # Swap zeros for 1 before taking the log: 0*log(1) == 0, which gives
        # the intended limit without a divide-by-zero warning.  Negative
        # entries are left alone and still produce NaN.
        safe = np.where(comps==0, 1.0, comps)
        curv_term = np.sum(comps*np.log(safe),axis=1)
    elif method=='none':
        curv_term = np.zeros((comps.shape[0],0))
    else:
        # A real exception: an assert would vanish under ``python -O``.
        raise ValueError(
            '{} is not a valid method for eval_curv.'.format(method))

    if single_pt:
        curv_term = curv_term[0]

    return curv_term

In [None]:
def init_lstsq(comps, mu, curv_method, cross_term_inds, yscl=None):
    """Assemble the regressor matrix and scaled targets for a least-squares fit.

    The regressors are the compositions followed by the curvature term(s)
    returned by ``eval_curv``.  The targets are ``mu`` divided by ``yscl``;
    when ``yscl`` is not supplied it is the power of ten at or just below
    the spread ``max(mu) - min(mu)``.

    Returns
    -------
    xobs, yobs, yscl
    """
    curv_term = eval_curv(comps, curv_method, cross_term_inds)
    # A single curvature column comes back one-dimensional; make it a column.
    if curv_term.ndim==1:
        curv_term = curv_term.reshape(-1, 1)

    print(curv_term)

    xobs = np.hstack((comps, curv_term))

    if yscl is None:
        mu_spread = np.max(mu)-np.min(mu)
        yscl = 10**np.floor(np.log10(mu_spread))

    return xobs, mu/yscl, yscl

### T,P, parameters and options for pseudo-phase generation

In [None]:
T = 1300.0                  # K
P = 300000.0                 # bars

# T = 2300.0                  # K
# P = 200000.0                 # bars

# P = 20000.0                 # bars
# T = 1500.0                  # K
# P = 40000.0                 # bars

In [None]:
# database='Berman'
database='Stixrude'

In [None]:
# ADD extra phases (e.g. carbonates as needed here)

In [None]:
# Build the energy landscape of the chosen database at the selected T and P.
# NOTE: this assignment rebinds the name `phases`, which until now referred to
# the `phases` module imported from thermoengine.
phases = get_subsolidus_phases(database=database)
# One row per endmember; with prune_polymorphs=True only the lowest-energy
# entry of each distinct composition is kept.
phs_sym, endmem_ids, mu, elem_comps, sys_elems = system_energy_landscape(
    T, P, phases, prune_polymorphs=True)
# display(phs_sym, endmem_ids, mu, elem_comps, sys_elems)
Nelems = len(sys_elems)  # number of elements retained in the system
Npts = mu.size           # number of landscape points

In [None]:
sys_elems

In [None]:
phases

In [None]:
phs_sym

In [None]:
def get_quad_inds(Nelems):
    """Return the index pairs (i, j) with i > j for ``Nelems`` components.

    The result has shape ``(2, Nelems*(Nelems-1)/2)``: row 0 holds the row
    indices and row 1 the column indices of the strict lower triangle.
    """
    return np.array(np.tril_indices(Nelems, k=-1))

cross_term_inds = get_quad_inds(Nelems)
# cross_term_inds[0]

# Define bulk composition

In [None]:
wt = np.random.rand(elem_comps.shape[0])
wt = wt/np.sum(wt)
bulk_comp = np.dot(wt, elem_comps)

In [None]:
bulk_comp = np.array([0.59760393, 0.01614512, 0.04809849, 0.09232406, 0.1571301 ,
       0.03304928, 0.05565109])

# Get minimum energy assemblage

In [None]:
# Weights, energy and member indices of the minimum-energy assemblage for
# the bulk composition.
wt_bulk, mu_bulk, ind_assem = min_energy_assemblage(
    bulk_comp, elem_comps, mu, TOLmu=10)
# Composition and energy of just the phases that make up that assemblage.
comp_assem, mu_assem = elem_comps[ind_assem], mu[ind_assem]
comp_assem_avg = np.mean(comp_assem,axis=0)
# Smallest squared distance of an assemblage member from the assemblage mean.
# NOTE(review): X2scl does not appear to be used by later cells -- confirm.
X2scl = np.min(np.sum((comp_assem-comp_assem_avg)**2,axis=1))

In [None]:
np.array(phs_sym)[ind_assem]

In [None]:
wt_bulk[ind_assem]

In [None]:
from scipy import optimize

In [None]:
# NOTE: a stray tab-completion fragment (`optimize.mi`) used to sit here.
# scipy.optimize has no attribute `mi`, so the cell raised AttributeError and
# aborted a "Run All"; it is deliberately left empty.

# Fit simple quadratic (diagonal terms only) as function of endmember fractions


In [None]:
dmu_bulk = 0.1e3
# dmu_endmem = 3e3
# dmu_endmem = 8e3
dmu_endmem = 2e3

In [None]:


def fit_quad_excess_endmem(comp_assem, wt_bulk_assem, dmu_endmem):
    """Fit per-phase curvatures of a diagonal quadratic in weight space.

    The model is ``sum_j k_j * (w_j - wbulk_j)**2``.  It is fitted by least
    squares so that its value at each pure assemblage member (a unit weight
    vector) equals ``dmu_endmem``.

    Parameters
    ----------
    comp_assem : ndarray
        Only its number of rows (the number of assemblage members) is used.
    wt_bulk_assem : ndarray
        Weights of the assemblage members at the bulk composition.
    dmu_endmem : float
        Target value of the quadratic at every pure member.

    Returns
    -------
    ndarray
        Least-squares curvatures, one per assemblage member.
    """
    n_members = comp_assem.shape[0]

    # Squared weight offsets of each pure member from the bulk weights.
    sq_offsets = (np.eye(n_members) - wt_bulk_assem)**2
    targets = np.tile(dmu_endmem, n_members)

    curv_endmem = np.linalg.lstsq(sq_offsets, targets, rcond=None)[0]
    return curv_endmem

In [None]:
wt_bulk_assem = wt_bulk[ind_assem]
curv_endmem = fit_quad_excess_endmem(comp_assem, wt_bulk_assem, dmu_endmem)

wt_bulk_assem

# Define quadratic surface offset
- NOTE: **dmu_endmem > dmu_bulk** MUST hold true
- dmu_bulk is offset at bulk composition
- dmu_endmem is offset at each endmember composition

In [None]:
def random_sample(curv_endmem, wt_bulk_assem, Nsamp=9000):
    """Draw random weight vectors and evaluate the diagonal quadratic there.

    Each of the ``Nsamp`` rows is drawn from ``np.random.rand`` and then
    normalised to sum to one.  The quadratic value of a row ``w`` is
    ``sum_j curv_endmem[j] * (w[j] - wt_bulk_assem[j])**2``.

    Returns
    -------
    wt_rand : ndarray, shape (Nsamp, n_members)
    mu_rand : ndarray, shape (Nsamp,)
    """
    n_members = len(wt_bulk_assem)

    draws = np.random.rand(Nsamp, n_members)
    wt_rand = draws/draws.sum(axis=1, keepdims=True)

    sq_offsets = (wt_rand-wt_bulk_assem)**2
    mu_rand = np.dot(sq_offsets, curv_endmem)

    return wt_rand, mu_rand


In [None]:
wt_rand, mu_rand = random_sample(curv_endmem, wt_bulk_assem)
comp_rand = np.dot(wt_rand, comp_assem)

# Fit diagonal quadratic in endmember (eigen) space
* remap quadratic to elemental space
* visualize quad model in elemental space

In [None]:
def fit_quad_shift(bulk_comp, comp_rand, mu_rand, comp_assem, dmu_endmem, dmu_bulk):
    """Fit a linear-plus-cross-term surface in elemental composition space.

    The fitted points are, in this order: the bulk composition, the
    assemblage members and the random samples.  Their target values are
    ``dmu_bulk``, ``dmu_endmem + dmu_bulk`` and ``mu_rand + dmu_bulk``
    respectively.  The regressors are the element fractions followed by all
    pairwise products ``X_i * X_j`` with ``i > j``.

    Returns
    -------
    chem_pot : ndarray
        Coefficients of the linear (element fraction) terms.
    quad_coef : ndarray
        Coefficients of the pairwise product terms.
    xobs, yobs : ndarray
        Regressor matrix and target vector that were fitted.
    fit_output : tuple
        Complete return value of ``np.linalg.lstsq``.
    """
    N_elems = len(bulk_comp)
    n_members = comp_assem.shape[0]

    comps = np.vstack((bulk_comp, comp_assem, comp_rand))
    offsets = np.hstack((0, np.tile(dmu_endmem, n_members), mu_rand))
    yobs = offsets + dmu_bulk

    pair_products = eval_curv(comps, 'quad-full', get_quad_inds(N_elems))
    xobs = np.hstack((comps,pair_products))

    fit_output = np.linalg.lstsq(xobs, yobs, rcond=None)
    coef = fit_output[0]

    return coef[:N_elems], coef[N_elems:], xobs, yobs, fit_output

chem_pot, quad_coef, xobs, yobs, fit_output = fit_quad_shift(bulk_comp, comp_rand,mu_rand, comp_assem, dmu_endmem, dmu_bulk)



In [None]:
# One figure per element: fitted targets against that element's fraction.
# Rows 0..N_assem of xobs/yobs are the bulk point followed by the assemblage
# members (the stacking order used in fit_quad_shift); the remaining rows are
# the random samples.
N_assem = comp_assem.shape[0]
for ind in range(Nelems):
    plt.figure()
    # random samples (grey circles)
    plt.plot(xobs[N_assem+1:,ind],yobs[N_assem+1:], 'o',color=[.5,.5,.5] )
    # bulk composition and assemblage members (red crosses)
    plt.plot(xobs[:N_assem+1, ind],yobs[:N_assem+1], 'rx', ms=8)
    # assemblage members and bulk composition drawn on the zero line
    plt.plot(comp_assem[:,ind],np.zeros(N_assem), 'k--x')
    plt.plot(bulk_comp[ind],0, 'kx', mew=4,ms=10)



# Calc elemental chem potentials 
* elem chempot = quadratic surface + equil assemblage

In [None]:
def fit_assemblage_chem_pot(comp_assem, mu_assem):
    """Least-squares elemental chemical potentials of an assemblage.

    Solves ``comp_assem @ chem_pot ~= mu_assem`` with ``np.linalg.lstsq``.

    Returns
    -------
    chem_pot_assem : ndarray
        The least-squares solution, one value per element column.
    fit_output_assem : tuple
        Complete return value of ``np.linalg.lstsq``.
    """
    fit_output_assem = np.linalg.lstsq(comp_assem, mu_assem, rcond=None)
    return fit_output_assem[0], fit_output_assem

chem_pot_assem, fit_output_assem = fit_assemblage_chem_pot(comp_assem, mu_assem)

In [None]:
mu_linear = chem_pot_assem+chem_pot

## Build endmembers of pseudo-phase using the coder module

In [None]:
modelCD = coder.StdStateModel()
modelCD.set_module_name('pseudo_end')

GTP = sym.symbols('GTP')
params = [('GTP','J',GTP)]
modelCD.add_expression_to_model(GTP, params)


model_working_dir = "working"
!mkdir -p {model_working_dir}
%cd {model_working_dir}

In [None]:
def standardize_formula(form):
    """Rewrite a simple oxide formula with explicit stoichiometric counts.

    The formula is divided at its *last* capital ``O``: everything before it
    is the cation part (optionally ending in a single-digit count) and
    everything after it is the oxygen count.  Dividing at the last ``O``
    keeps cation symbols that themselves contain ``O`` (``Co``, ``Os``)
    intact.

    Examples: ``'SiO2' -> 'Si(1)O(2)'``, ``'Al2O3' -> 'Al(2)O(3)'``,
    ``'MgO' -> 'Mg(1)O'``, ``'CoO' -> 'Co(1)O'``.

    Raises
    ------
    ValueError
        If ``form`` contains no ``O`` or has nothing in front of it.
    """
    cation, oxygen, oxy_count = form.rpartition('O')
    if not oxygen or not cation:
        raise ValueError(
            'Expected an oxide formula such as "SiO2", got {!r}'.format(form))

    if cation[-1].isdigit():
        result = cation[:-1] + '(' + cation[-1] + ')'
    else:
        result = cation + '(1)'

    # A lone oxygen is written without a count, as the original code did.
    if oxy_count == '':
        result += 'O'
    else:
        result += 'O(' + oxy_count + ')'
    return result

In [None]:
use_oxides_as_basis = False

In [None]:
# Generate one standard-state code module per system component.  The model
# expression is the single parameter GTP, set here to that component's
# linear chemical potential.
model_type = "calib"
for ind,elm in enumerate(sys_elems):
    imu = mu_linear[ind]
    print(imu)
    if use_oxides_as_basis:
        formula = standardize_formula(elm)
    else:
        formula = elm+'(1)'
    param_dict = {'Phase':elm,'Formula':formula,'T_r':298.15,'P_r':1.0,'GTP':imu}
    print (param_dict)
    # The two pop() calls are evaluated before the call is made, so the dict
    # passed as `params` no longer contains the 'Phase' and 'Formula' keys.
    result = modelCD.create_code_module(
        phase=param_dict.pop('Phase', None),
        formula=param_dict.pop('Formula', None),
        params=param_dict, module_type=model_type, silent=True)
    print ('Component', elm, 'done!')

Build the code (ignore error messages generated by Cython regarding 'language_level')

In [None]:
import pseudo_end
%cd ..

# Create Simple Solution Coder Module

In [None]:
elm_sys=sys_elems

In [None]:
c = len(elm_sys)
c

In [None]:
modelCD = coder.SimpleSolnModel(nc=c)

In [None]:
n = modelCD.n
nT = modelCD.nT
X = n/nT

In [None]:
mu = modelCD.mu
mu

In [None]:
# Tsym = modelCD.get_symbol_for_t()

In [None]:
G_ss = (n.transpose()*mu)[0]
G_ss

In [None]:
# if curv_method=='quad-full':

# Create one symbolic coefficient per (i, j) cross-term pair, named
# k_<i+1>_<j+1> (1-based indices), in the column order of cross_term_inds.
curv_string = ''
quad_strs = []
for i,j in cross_term_inds.T:
    # print(i, j)
    istr = 'k_' + str(i+1) + '_' + str(j+1)
    # space-separated list of names handed to sym.symbols below
    curv_string +=  istr + ' '
    quad_strs.append(istr)
    

quad_consts = sym.Matrix(list(sym.symbols(curv_string)))

# k_curv = sym.symbols('k_curv')
# k_curv

# Show the symbols, the fitted coefficient values and the parameter names
# side by side so their ordering can be checked by eye.
print(quad_consts)
print(quad_coef)
print(quad_strs)

In [None]:
XiXj = np.dot(n,n.T)[cross_term_inds[0], cross_term_inds[1]]/nT**2
G_quad = nT*np.dot(XiXj, quad_consts)[0]
G_quad

# Create mu_shft for convenient modification after compiling

In [None]:
mu_shft = sym.symbols('mu_shft')
mu_shft

In [None]:
Gshft = mu_shft*nT

In [None]:
G = G_ss + G_quad + Gshft
G

In [None]:
mu_expr = G.diff(n)
print(mu_expr[0])

In [None]:
soln_params = []
soln_params.append(('mu_shft', 'J/m', mu_shft))
for iquad, iquad_str in zip(quad_consts, quad_strs):
    soln_params.append((iquad_str, 'J/m', iquad))
    
soln_params

In [None]:

# Need mu_shft here as an expression
modelCD.add_expression_to_model(G, soln_params)



In [None]:
modelCD.module = "pseudo_soln"

In [None]:
# Build the three strings/lists later assigned to the solution model:
#   formula - element (or cation) symbols, each followed by a [symbol] slot
#   convert - one '[index]=...' expression per component
#   test    - one '[index] >= 0.0' bound per component
formula = ''
convert = []
test = []
if use_oxides_as_basis:
    for ind,elm in enumerate(elm_sys):
        # Look up the cation and its count for this oxide.
        ox_index = list(core.chem.oxide_props['oxides']).index(elm)
        ox_cat = core.chem.oxide_props['cations'][ox_index]
        formula += ox_cat + '[' + ox_cat + ']'
        ox_cat_num = core.chem.oxide_props['cat_num'][ox_index]
        if ox_cat_num > 1:
            # More than one cation per oxide: divide the cation amount by it.
            convert.append('['+str(ind)+']=['+ox_cat+']/'+str(ox_cat_num)+'.0')
        else:
            convert.append('['+str(ind)+']=['+ox_cat+']')
        test.append('['+str(ind)+'] >= 0.0')
    formula += 'O[O]'
else:
    for ind,elm in enumerate(elm_sys):
        formula += elm + '[' + elm + ']'
        convert.append('['+str(ind)+']=['+elm+']')
        test.append('['+str(ind)+'] >= 0.0')
        # test.append('['+str(ind)+'] >= -10.0')
        
# Loosen constraint on Al as a test
# test[3] = '[3] >= -100.0'
# test[6] = '[6] >= -100.0'
# test[0] = '[0] >= 0.0'
        
# Trailing expression: displays the three results as the cell output.
formula, convert, test

In [None]:
modelCD.formula_string = formula
modelCD.conversion_string = convert
modelCD.test_string = test

In [None]:
test

In [None]:
# Numeric values for every parameter of the solution model.
paramValues = {'T_r':298.15,'P_r':1.0}

# The uniform shift starts at zero.
paramValues['mu_shft'] = 0
# Fitted cross-term coefficients, matched to their names by position.
for iquad_val, iquad_str in zip(quad_coef, quad_strs):
    paramValues[iquad_str] = iquad_val


# Endmember names: one '<component>_pseudo_end' entry per system component.
endmembers = []
for elm in elm_sys:
    # endmembers.append(str(elm)+'_pseudo_end_calib')
    endmembers.append(str(elm)+'_pseudo_end')

In [None]:
paramValues

# Compile and import solution phase code

In [None]:
import os
def compile_soln_phase(paramValues, endmembers, working_dir='working'):
    """Generate, build and import the solution-phase code module.

    The module is generated inside ``working_dir``, which is created when
    missing.  The process working directory is switched there for the
    duration of the build and is always restored to its starting value
    afterwards, including when the build or the import fails.

    Relies on the notebook-level ``modelCD`` object and expects its
    ``create_code_module`` call to leave an importable ``pseudo_soln``
    module in ``working_dir``.

    Parameters
    ----------
    paramValues : dict
        Parameter values forwarded as ``params``.
    endmembers : list of str
        Endmember names forwarded as ``endmembers``.
    working_dir : str
        Directory in which the code is generated.

    Returns
    -------
    module
        The imported ``pseudo_soln`` module.

    Raises
    ------
    OSError
        If ``working_dir`` cannot be created or entered.
    """
    start_dir = os.getcwd()

    # Same effect as `mkdir -p`: succeeds whether or not the directory
    # already exists, so the build always runs inside working_dir.
    os.makedirs(working_dir, exist_ok=True)
    os.chdir(working_dir)
    try:
        modelCD.create_code_module(
            phase="PseudoPhase", params=paramValues, endmembers=endmembers,
            prefix="cy", module_type='calib', silent=False)

        import pseudo_soln
    finally:
        os.chdir(start_dir)

    return pseudo_soln


In [None]:
pseudo_soln = compile_soln_phase(paramValues, endmembers)

In [None]:

# model_working_dir = "working"
# !mkdir -p {model_working_dir}
# %cd {model_working_dir}
# 
# modelCD.create_code_module(
#     phase="PseudoPhase", params=paramValues, endmembers=endmembers, 
#     prefix="cy", module_type='calib', silent=False)
# 
# import pseudo_soln
# %cd ..


# Use custom module

In [None]:
def get_pseudo_phase():
    """Load the compiled pseudo-phase through a CoderModule database.

    Opens ``model.Database`` on the ``pseudo_soln`` module, prints the
    abbreviation and name of every phase the database lists, and returns
    the phase registered under the abbreviation ``'Psu'``.
    """
    phase_spec = ('pseudo_soln', {'Psu':['PseudoPhase','solution']})
    modelPseudo = model.Database(
        database="CoderModule", calib="calib", phase_tuple=phase_spec)
    Pseudo = modelPseudo.get_phase('Psu')

    names = modelPseudo.phase_info.phase_name
    abbrevs = modelPseudo.phase_info.abbrev
    for phase_name, abbrv in zip(names, abbrevs):
        print ('Abbreviation: {0:<10s} Name: {1:<30s}'.format(abbrv, phase_name))

    return Pseudo

Pseudo = get_pseudo_phase()

In [None]:
vals = Pseudo.get_param_values(all_params=True)
names = np.array(Pseudo.param_names)

print(names)
print(vals)
# Pseudo.set_param_values(param_names=[2],param_values=[1.0])

In [None]:
mudiff = Pseudo.gibbs_energy(T, P, mol=bulk_comp)-mu_bulk
mudiff

In [None]:
# shift pseudo if needed

# mu_shft_val= +1e4 
# mu_shft_val= 0 
# Pseudo.set_param_values(param_names=[len(vals)-1], param_values=[mu_shft_val])
# Pseudo.get_param_values(all_params=True)

# mudiff = Pseudo.gibbs_energy(T, P, mol=bulk_comp)-mu_bulk
# mudiff

Check pseudo-phase import by printing some phase characteristics

In [None]:
print (Pseudo.props['phase_name'])
print (Pseudo.props['formula'])
print (Pseudo.props['molwt'])
print (Pseudo.props['abbrev'])
print (Pseudo.props['endmember_num'])
print (Pseudo.props['endmember_name'])

## Try the equilibrium calculations with the omnicomponent pseudo-phase
#### Choose a phase assemblage

In [None]:

phs_sys = [Pseudo]
phs_sys.extend(list(phases.values()))

# phs_sys += [phases['Fsp'], phases['Ol'], phases['Cpx'], phases['Grt']] # solutiopns,
# phs_sys += [phases['Qz'], phases['Ky'], phases['Nph']]
#
#phs_sys  = [Pseudo, stix_phases['Opx']]



In [None]:
equil = equilibrate.Equilibrate(sys_elems, phs_sys)

In [None]:
np.array(phs_sym)[ind_assem]

In [None]:
bulk_comp

In [None]:
state = equil.execute(T, P, bulk_comp=bulk_comp, debug=0)

In [None]:


state.print_state()

In [None]:
np.array(phs_sym)[ind_assem]

In [None]:
phases

In [None]:
state.tot_grams_phase('PseudoPhase')

In [None]:
state.properties()