In [1]:
import warnings
import numpy as np

import os

os.environ['KMP_DUPLICATE_LIB_OK']='True'

from importlib import reload

import wandb

from context import omphalos

# Import Omphalos modules.
from omphalos import generate_inputs as gi
from omphalos import file_methods as fm
from omphalos import my_metrics as mm
from omphalos import omphalos_plotter as op
from omphalos import attributes as attr
from omphalos import labels as lbls
from omphalos import spatial_constructor

from matplotlib import pyplot as plt

import xgboost as xgb

%matplotlib osx
%precision 4

'%.4f'

In [2]:
# Load the trained XGBoost booster used by every plotting/search cell below.
# The path is relative, so the notebook must be run from the directory that
# contains rifle_xgb.json.
bst = xgb.Booster(model_file='rifle_xgb.json')

In [3]:
%%time

species_list = ['NH4+', 'SO4--','Ca++', 'Acetate', 'CO2(aq)']
axis_labels = ['[NH$_4^+$] (mM)', '[SO$_4^{2-}$] (mM)','[Ca$^2+$] (mM)', '[Acetate] (mM)', '[CO$_{2(aq)}$] (mM)']

def plot_2d(ax, plot_vars, num=20, lower=0, upper=30):
    """Plot the model's predicted response surface over two input species.

    The two inputs selected by ``plot_vars`` are swept over a ``num`` x ``num``
    evenly spaced grid on ``[lower, upper]``. All other inputs are held at a
    fixed baseline composition. The booster's prediction (scaled by 1e-4) is
    drawn on ``ax`` as a triangulated surface.

    Parameters
    ----------
    ax : Axes3D
        Target axes; must have been created with ``projection='3d'``.
    plot_vars : sequence of int
        Column indices of the species on the x and y axes (index 0 = x,
        index 1 = y), in the order NH4+, SO4--, Ca++, Acetate, CO2.
    num : int, optional
        Number of grid points per swept axis (default 20).
    lower, upper : float, optional
        Sweep range shared by both swept axes (default 0 to 30).

    Returns
    -------
    None
        The plot is drawn on ``ax`` as a side effect. The booster ``bst`` and
        the label list ``axis_labels`` are read from the notebook's globals.
    """
    x_var, y_var = plot_vars[0], plot_vars[1]

    # Both plotted species share one sweep. Which species they are is decided
    # by plot_vars; it is not fixed to NH4+/SO4--.
    sweep = np.linspace(lower, upper, num)
    x_mesh, y_mesh = np.meshgrid(sweep, sweep)
    X = np.column_stack([x_mesh.ravel(), y_mesh.ravel()])  # (num**2, 2) grid points

    # Baseline values for the species not being swept, in the column order
    # NH4+, SO4--, Ca++, Acetate, CO2(aq).
    baseline = np.array([1.5, 8.8, 4.8, 9.7, 0.0325])
    inputs = np.tile(baseline, (len(X), 1))
    inputs[:, x_var] = X[:, 0]
    inputs[:, y_var] = X[:, 1]

    vals = bst.predict(xgb.DMatrix(inputs))

    # NOTE(review): the 1e-4 factor presumably converts the model output to the
    # volume fraction named in the z label — confirm against the training target.
    ax.plot_trisurf(X[:, 0], X[:, 1], vals * 1e-4, cmap=plt.cm.magma, linewidth=0, antialiased=False)

    ax.set_xlabel(axis_labels[x_var], fontsize=18, labelpad=20)
    ax.set_ylabel(axis_labels[y_var], fontsize=18, labelpad=20)
    ax.set_zlabel('Net pyrite precipitation (vol. frac.)', fontsize=18, labelpad=20)
    ax.tick_params('both', labelsize=15, pad=10)
    ax.view_init(elev=15, azim=-128)

CPU times: user 18 µs, sys: 9 µs, total: 27 µs
Wall time: 73 µs


In [6]:
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # registers the '3d' projection on older matplotlib
import string
from itertools import combinations

# matplotlib 3.6 renamed the bundled seaborn styles to 'seaborn-v0_8-*' and 3.8
# removed the old names, so use whichever name this installation provides.
_paper_style = 'seaborn-v0_8-paper' if 'seaborn-v0_8-paper' in plt.style.available else 'seaborn-paper'
plt.style.use(_paper_style)

# Labels read by plot_2d for this figure. They replace the earlier definition;
# here CO2 is labelled as a gas partial pressure.
# NOTE(review): an earlier cell labels CO2 as aqueous (mM) — confirm the unit.
axis_labels = ['[NH$_4^+$] (mM)', '[SO$_4^{2-}$] (mM)', '[Ca$^{2+}$] (mM)', '[Acetate] (mM)', '[CO$_{2(g)}$] (bar)']

n_rows, n_cols = 2, 5
fig = plt.figure(figsize=(20, 20))

# Every unordered pair of the five inputs, (0,1), (0,2), ..., (3,4): 10 panels.
plot_list = list(combinations(range(len(axis_labels)), 2))
ax_list = [fig.add_subplot(n_rows, n_cols, k + 1, projection='3d') for k in range(len(plot_list))]

for n, (ax, pair) in enumerate(zip(ax_list, plot_list)):
    plot_2d(ax, pair)
    # Panel letter (A), (B), ... placed just above the top-left corner, in axes coordinates.
    ax.text2D(-0.1, 1.1, '({})'.format(string.ascii_uppercase[n]), transform=ax.transAxes, size=20)

In [7]:
# Save the 10-panel figure. The output directory defaults to
# ~/Dropbox/dissertation and can be overridden with the FIG_DIR environment
# variable, so the notebook is not tied to one user's home path.
fig_dir = os.environ.get('FIG_DIR', os.path.expanduser('~/Dropbox/dissertation'))
os.makedirs(fig_dir, exist_ok=True)
fig.savefig(os.path.join(fig_dir, 'rifle_2d.png'), dpi=300)

In [74]:
%%time

# Brute-force search for the input composition with the largest prediction:
# evaluate the booster on every point of a num**5 grid and report the arg-max.
# (The recorded run of this cell took roughly 20 minutes of wall time.)
num = 15
upper = 30
lower = 0

# The first four inputs share [lower, upper]; the fifth (CO2) is swept over [lower, 10].
x1, x2, x3, x4 = [np.linspace(lower, upper, num) for _ in range(4)]
x5 = np.linspace(lower, 10, num)

# One row per grid combination. meshgrid's default 'xy' indexing only changes
# the row order, so each row is still a consistent 5-input vector.
XX = np.stack(np.meshgrid(x1, x2, x3, x4, x5), axis=-1).reshape(-1, 5)

vals = bst.predict(xgb.DMatrix(XX))
max_index = vals.argmax()

# Flat index of the best grid row, its input values, and the prediction scaled by 1e-4.
print(max_index, XX[max_index], vals[max_index] * 1e-4, sep='\n')

CPU times: user 1h 15min 52s, sys: 2min 40s, total: 1h 18min 33s
Wall time: 20min 41s
