In [None]:
# Set the working directory to the project root (the parent of the folder the
# notebook is started from) so relative paths such as 'data/...' and
# 'outputs/...' resolve, and make that root importable.
# NOTE: not idempotent — re-running this cell moves up one more directory;
# run it once per kernel session.
import sys
import os
from pathlib import Path

PROJECT_ROOT = Path(os.getcwd()).parent
# sys.path entries must be strings: the import system ignores non-str entries,
# so inserting the Path object directly never made the root importable.
sys.path.insert(0, str(PROJECT_ROOT))
os.chdir(PROJECT_ROOT)

#  GSSA Public Data Lithology Classification Model

In [None]:
# Import Packages
import plotly.io as pio
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
import numpy as np
import seaborn as sns

## Import XRF Chemistry Data

In [None]:
# Show at most 10 columns when displaying DataFrames (keeps wide chemistry tables readable)
pd.options.display.max_columns = 10

In [None]:
# Load the cleaned, pivoted XRF chemistry table and preview the first 5 rows
# for a subset of columns: column 0, column 7 and every column from 16 onwards.
clean_df = pd.read_csv('data/cleaned_pivoted_chem.csv')
preview_cols = [0, 7, *range(16, clean_df.shape[1])]
clean_df.iloc[:, preview_cols].head()

In [None]:
# Number of samples per logged lithology, most frequent first
clean_df.value_counts(subset="LITHOLOGY_NAME")

# Clean data

In [None]:
# Keep only samples with a logged lithology (unlabelled rows cannot be used
# for training), then drop metadata columns that are not used further.
unused_cols = ['COLLECTORS_NUMBER', 'LITHO_CONF', 'LITHO_MODIFIER', 'STRAT_CONF']
df = clean_df.dropna(subset=['LITHOLOGY_NAME']).drop(columns=unused_cols)
df.info()

# Explore data

In [None]:
# One row per drill hole (its first sample) to serve as a hole-location table,
# written out for mapping.
# NOTE: the name `map` shadows Python's built-in map() in this notebook.
map = df.drop_duplicates(subset='DH_NAME', keep='first')
map.to_csv('outputs/map.csv')

In [None]:
# Sample count per lithology, largest first, to see which classes dominate
Vis = (
    df.groupby(["LITHOLOGY_NAME"])
    .size()
    .sort_values(ascending=False)
    .reset_index(name='count')
)

plt.figure(figsize=(5, 8))
sns.barplot(data=Vis, y="LITHOLOGY_NAME", x="count")

## Visualise an Example Striplog

In [None]:
# Take one example drill hole (MSDP01), order its samples by depth and list
# the lithologies it contains.
hole = (
    df[df['DH_NAME'].eq('MSDP01')]
    .dropna(subset=['LITHOLOGY_NAME'])
    .sort_values(by='DH_DEPTH_FROM')
)

print(hole['LITHOLOGY_NAME'].unique())

In [None]:
# Strip-log colour scheme: every lithology name gets an integer code (d1), and
# that code picks a colour from colors1 through the colormap cmap1.
plt.rcParams["font.size"] = 14
d1 = {
    'Siltstone': 1,
    'Dolomite Rock': 2,
    'Rhyolite': 3,
    'Breccia (Undiff. Origin)': 4,
    'Basalt': 5,
    'Sedimentary Siliciclastic Breccia': 6,
    'Volcaniclastic Rock': 7,
    'Ignimbrite': 8,
}
colors1 = ['green', 'blue', 'red', 'pink', 'purple', 'blue', 'orange', 'grey', 'yellow']
cmap1 = ListedColormap(colors1)

#Create function for strip log
def striplog(hole, bottom_depth, top_depth, var1, var2):
    """Draw a three-track strip log for a single drill hole.

    Tracks, left to right: ``var1`` against depth (x axis 0-15000), ``var2``
    against depth (x axis 0-5000), and the logged lithology as a colour
    column. Lithology names are converted to integer codes with the
    notebook-level ``d1`` and coloured with ``cmap1`` over
    ``1..len(colors1)``; names missing from ``d1`` map to NaN and stay blank.

    Parameters
    ----------
    hole : pandas.DataFrame
        Samples of one hole, sorted by ``DH_DEPTH_FROM``. Must contain
        ``DH_DEPTH_FROM``, ``LITHOLOGY_NAME``, ``var1`` and ``var2``.
    bottom_depth, top_depth : float
        y-axis limits, applied as ``set_ylim(bottom_depth, top_depth)`` so
        depth increases downwards.
    var1, var2 : str
        Columns plotted in the first and second tracks.

    Returns
    -------
    matplotlib.figure.Figure
        The figure holding the three tracks (also the current pyplot figure,
        so ``plt.savefig`` afterwards still works).
    """
    # Create the three tracks directly. Previously a full-size Axes from
    # plt.subplots() was overlaid with plt.subplot2grid() Axes; Matplotlib >= 3.6
    # no longer deletes overlapped Axes, so it showed as an empty frame behind
    # the tracks.
    fig, (ax1, ax2, ax3) = plt.subplots(nrows=1, ncols=3, figsize=(7, 6), sharey=True)

    # First chemistry track (Ti in the MSDP01 example)
    ax1.plot(hole[var1], hole['DH_DEPTH_FROM'], color="black", linewidth=1.0)
    ax1.set_xlabel(var1)
    ax1.set_xlim(0, 15000)
    ax1.xaxis.label.set_color("black")
    ax1.tick_params(axis='x', colors="black")
    ax1.spines["top"].set_edgecolor("black")
    ax1.set_xticks(np.arange(0, 15000, 10000))

    # Second chemistry track (P in the MSDP01 example)
    ax2.plot(hole[var2], hole['DH_DEPTH_FROM'], color="black", linewidth=1.0)
    ax2.set_xlabel(var2)
    ax2.set_xlim(0, 5000)
    ax2.xaxis.label.set_color("black")
    ax2.tick_params(axis='x', colors="black")
    ax2.spines["top"].set_edgecolor("black")
    ax2.set_xticks(np.arange(0, 5000, 2000))

    # Lithofacies column: the depths are cell *edges*, so pcolormesh draws one
    # fewer cell than there are samples -- hence dropping the last code.
    lith_codes = hole['LITHOLOGY_NAME'].iloc[:-1].map(d1).to_numpy().reshape(-1, 1)
    ax3.pcolormesh([-1, 3], hole['DH_DEPTH_FROM'], lith_codes,
                   cmap=cmap1, vmin=1, vmax=len(colors1))
    ax3.set_xticks([])
    ax3.set_aspect(0.01)
    ax3.set_xlabel("Logged Lithofacies")

    for ax in (ax1, ax2, ax3):
        ax.set_ylim(bottom_depth, top_depth)
        ax.grid(which='major', color='lightgrey', linestyle='-', axis="y")
        ax.xaxis.set_ticks_position("bottom")
        ax.xaxis.set_label_position("top")
        ax.spines["top"].set_position(("axes", 1.02))

    # Only the leftmost track keeps depth labels
    for ax in (ax2, ax3):
        plt.setp(ax.get_yticklabels(), visible=False)

    fig.tight_layout()
    fig.subplots_adjust(wspace=0.15)
    return fig

# Example strip log for MSDP01: Ti and P tracks plus logged lithology,
# depth axis from 1120 (bottom) to 20 (top)
striplog(hole, bottom_depth=1120, top_depth=20, var1="Ti_ppm", var2="P_ppm")

plt.savefig('outputs/MSDP01.png', transparent=True)
plt.show()

## Examine Class Distribution in Dataset

In [None]:
# Binarise the dependent variable: 1 for Basalt, 0 for every other lithology.
# NOTE: the name `bin` shadows Python's built-in bin() in this notebook.
bin = df.copy()
bin['LITHOLOGY_NAME'] = bin['LITHOLOGY_NAME'].eq('Basalt').astype(int)

# Features: the contiguous block of element columns from Ag_ppm to Zr_ppm
X = bin.loc[:, 'Ag_ppm':'Zr_ppm']
y = bin['LITHOLOGY_NAME']

In [None]:
# Element-concentration columns (Ag_ppm ... Zr_ppm) used as model features
X.columns

In [None]:
plt.rcParams["font.size"] = 12
# visualise labels for each drill-hole
# Count samples per (hole, binary label); reset_index() puts the counts in
# column 0, which is renamed to 'Sum'.
grpd = bin.groupby(['DH_NAME','LITHOLOGY_NAME']).size().reset_index().rename(columns={0: 'Sum'})
# Readable class names for the legend: 1 -> "Basalt", anything else -> "Non-basalt"
grpd["Lithology"] = grpd["LITHOLOGY_NAME"].apply(lambda x: "Basalt" if x == 1 else "Non-basalt")
# grouped barplot
# sns.barplot(x="DH_NAME", y="Sum", hue="LITHOLOGY_NAME", data=grpd, dodge = True)

# One liner to create a stacked bar chart.
# histplot over the categorical DH_NAME with weights='Sum' gives one bar per
# hole whose segments are the per-class sample counts (multiple='stack').
fig, ax = plt.subplots(figsize = (5,4))
sns.histplot(grpd, x='DH_NAME', hue='Lithology', weights='Sum',
             multiple='stack', shrink=0.8, ax=ax, palette={"Basalt": "purple", "Non-basalt": "lightgrey"})
# Fix the legend so it's not on top of the bars.
legend = ax.get_legend()
legend.set_bbox_to_anchor((1, 1))
plt.xticks(rotation=45)
plt.xlabel("Drillhole name")
plt.ylabel("Number of samples")
plt.title("Class distribution per hole")
plt.tight_layout()
plt.savefig('outputs/Class_imbalance.png', transparent=True)


## Investigate chemistry of mafics

In [None]:
plt.rcParams["font.size"] = 12

# Collapse lithologies to the mafic classes of interest; everything else is "Other"
df["Lithology_simplified"] = df["LITHOLOGY_NAME"].apply(lambda x: x if x in ["Basalt", "Meta-Mafic Igneous Rock", "Dolerite", "Gabbro"] else "Other")
# Shorter legend label
df["Lithology_simplified"] = df["Lithology_simplified"].str.replace("Meta-Mafic Igneous Rock", "Meta-Mafic Igneous")
fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(10, 5))

# Left panel: Zr vs Ti. Sorting descending by class name draws "Other" first,
# so the mafic classes are plotted on top of it.
# (The palette also lists Rhyolite/Dacite, which this mapping never produces.)
plt.sca(ax[0])
sns.scatterplot(data=df.sort_values("Lithology_simplified", ascending=False), x="Zr_ppm", y="Ti_ppm", hue="Lithology_simplified", 
                palette={"Other": "lightgrey", "Basalt": "purple", "Rhyolite": "red", "Dacite": "green", "Meta-Mafic Igneous": "orange", "Dolerite": "skyblue", "Gabbro": "skyblue"},
                legend=False)
plt.xscale("linear")
plt.yscale("linear")

# Right panel: Y vs Nb (x clipped to 0-100); carries the shared legend
plt.sca(ax[1])
sns.scatterplot(data=df.sort_values("Lithology_simplified", ascending=False), x="Y_ppm", y="Nb_ppm", hue="Lithology_simplified", 
                palette={"Other": "lightgrey", "Basalt": "purple", "Rhyolite": "red", "Dacite": "green", "Meta-Mafic Igneous": "orange", "Dolerite": "skyblue", "Gabbro": "skyblue"})
plt.xscale("linear")
plt.yscale("linear")
plt.xlim(0,100)
legend = ax[1].get_legend()
legend.set_title(None)

plt.savefig("outputs/chemistry_scatterplots.png", transparent = True)
plt.show()


In [None]:
plt.rcParams["font.size"] = 12

# Same simplified classes as the previous cell (recomputed so this cell runs on its own)
df["Lithology_simplified"] = df["LITHOLOGY_NAME"].apply(lambda x: x if x in ["Basalt", "Meta-Mafic Igneous Rock", "Dolerite", "Gabbro"] else "Other")
df["Lithology_simplified"] = df["Lithology_simplified"].str.replace("Meta-Mafic Igneous Rock", "Meta-Mafic Igneous")
# Flag hole MSDP02 so it gets its own marker style
df["Hole_of_interest"] = df["DH_NAME"].apply(lambda x: x if x == "MSDP02" else "Other holes")
fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(10, 5))

# Left panel: Zr vs Ti, colour = class, marker = MSDP02 vs other holes
plt.sca(ax[0])
sns.scatterplot(data=df.sort_values("Lithology_simplified", ascending=False), x="Zr_ppm", y="Ti_ppm", hue="Lithology_simplified", style="Hole_of_interest",
                palette={"Other": "lightgrey", "Basalt": "purple", "Rhyolite": "red", "Dacite": "green", "Meta-Mafic Igneous": "orange", "Dolerite": "skyblue", "Gabbro": "skyblue"},
                legend=False)
plt.xscale("linear")
plt.yscale("linear")

# Right panel: Y vs Nb (x clipped to 0-100); carries the legend
plt.sca(ax[1])
sns.scatterplot(data=df.sort_values("Lithology_simplified", ascending=False), x="Y_ppm", y="Nb_ppm", hue="Lithology_simplified", style="Hole_of_interest",
                palette={"Other": "lightgrey", "Basalt": "purple", "Rhyolite": "red", "Dacite": "green", "Meta-Mafic Igneous": "orange", "Dolerite": "skyblue", "Gabbro": "skyblue"})
plt.xscale("linear")
plt.yscale("linear")
plt.xlim(0,100)
legend = ax[1].get_legend()
legend.set_title(None)

plt.savefig("outputs/chemistry_scatterplots_w_hole_marked.png", transparent = True)
plt.show()


# Run an initial model with a random train/test split

In [None]:
plt.rcParams["font.size"] = 16
# Baseline: random forest on a random 70/30 split, no resampling
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import matthews_corrcoef, balanced_accuracy_score, roc_auc_score, accuracy_score
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

# Random train/test split. Samples from the same hole land in both sets, so
# this score is likely optimistic; the hold-out-hole sections below address that.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.30, random_state=42)

RF = RandomForestClassifier(random_state=42)  # create classifier
RF.fit(X_train, y_train)                      # train the model
y_pred_RF = RF.predict(X_test)                # hard 0/1 predictions for the test set
# Score for the positive class: classes_ are sorted, so column 1 is label 1 (Basalt)
y_prob_RF = RF.predict_proba(X_test)[:, 1]

# Metrics are fractions (MCC lies in [-1, 1]), not percentages.
print("Test Set Accuracy: {:.3f}".format(accuracy_score(y_test, y_pred_RF)))
print("Test Set Balanced Accuracy: {:.3f}".format(balanced_accuracy_score(y_test, y_pred_RF)))
print("Test Set Matthews Correlation Coefficient: {:.3f}".format(matthews_corrcoef(y_test, y_pred_RF)))
# ROC AUC needs a continuous score; computed on hard labels it collapses to
# balanced accuracy. It is undefined when the test set holds a single class.
if y_test.nunique() > 1:
    print("Test Set ROC AUC: {:.3f}".format(roc_auc_score(y_test, y_prob_RF)))
else:
    print("Test Set ROC AUC: undefined (test set contains a single class)")

# Confusion Matrix
cm = confusion_matrix(y_test, y_pred_RF, labels=RF.classes_)
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=RF.classes_)
disp.plot()

plt.show()

## Test model performance by holding out a single hole

In [None]:
# Create function for holding out individual holes
def hold_out(hole_id, data=None):
    """Split the labelled data into one held-out hole and all remaining holes.

    Parameters
    ----------
    hole_id : str
        ``DH_NAME`` value of the hole to hold out.
    data : pandas.DataFrame, optional
        Frame to split; must have a ``DH_NAME`` column. Defaults to the
        notebook's binarised frame ``bin``.

    Returns
    -------
    (hole, train) : tuple of pandas.DataFrame
        ``hole`` holds the rows of ``hole_id``; ``train`` holds every other row
        (original row order kept in both). Both are independent copies, so
        adding columns to them does not raise SettingWithCopyWarning or touch
        ``data``.
    """
    if data is None:
        data = bin
    # Build the mask from the frame being split. The original built it from
    # `df` and applied it to `bin`, relying on the two sharing an index.
    in_hole = data['DH_NAME'] == hole_id
    hole = data.loc[in_hole].copy()
    train = data.loc[~in_hole].copy()

    return hole, train

# Unique drill-hole IDs, in order of first appearance in df
holes = pd.unique(df['DH_NAME'])

In [None]:
# Hold out MSDP01 as the test hole and train on all other holes
test, train = hold_out('MSDP01')

feature_slice = slice('Ag_ppm', 'Zr_ppm')
X_test, X_train = test.loc[:, feature_slice], train.loc[:, feature_slice]
y_test, y_train = test['LITHOLOGY_NAME'], train['LITHOLOGY_NAME']

# Class imbalance in the training set (every hole except MSDP01)
train.groupby(['LITHOLOGY_NAME']).size().plot.bar()

In [None]:
# Baseline for the MSDP01 hold-out: random forest with no resampling
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import matthews_corrcoef, balanced_accuracy_score, roc_auc_score
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

RF = RandomForestClassifier(random_state=42)  # create classifier
RF.fit(X_train, y_train)                      # train on every hole except MSDP01
y_pred_RF = RF.predict(X_test)                # hard 0/1 predictions for MSDP01
# Score for the positive class: classes_ are sorted, so column 1 is label 1 (Basalt)
y_prob_RF = RF.predict_proba(X_test)[:, 1]

# Metrics are fractions (MCC lies in [-1, 1]), not percentages.
print("Test Set Balanced Accuracy: {:.3f}".format(balanced_accuracy_score(y_test, y_pred_RF)))
print("Test Set Matthews Correlation Coefficient: {:.3f}".format(matthews_corrcoef(y_test, y_pred_RF)))
# ROC AUC needs a continuous score (on hard labels it equals balanced
# accuracy) and is undefined when the held-out hole contains a single class.
if y_test.nunique() > 1:
    print("Test Set ROC AUC: {:.3f}".format(roc_auc_score(y_test, y_prob_RF)))
else:
    print("Test Set ROC AUC: undefined (held-out hole contains a single class)")

# Confusion Matrix
plt.rcParams["font.size"] = 18

disp = ConfusionMatrixDisplay.from_predictions(
    y_true=y_test,
    y_pred=y_pred_RF,
    labels=RF.classes_,
    display_labels=["Non-basalt", "Basalt"],
    colorbar=False,
    text_kw={"fontsize": 20})

plt.tight_layout()
plt.savefig("outputs/confusion_baseline.png", transparent=True)

In [None]:
# Add predictions back to the test set and order it by depth for the strip log.
# `assign` returns a new frame; writing the column straight into `test` (a
# .loc slice of `bin`) triggered pandas' SettingWithCopyWarning.
test = test.assign(**{'RF _Predictions': y_pred_RF})
hole = test.sort_values(by=['DH_DEPTH_FROM'])

In [None]:
# Colour scheme for the prediction strip log: label 0 (non-basalt) -> code 1
# (grey), label 1 (basalt) -> code 2 (purple). Logged labels use d1/cmap1 and
# predictions use d2/cmap2; the two schemes are identical.
plt.rcParams["font.size"] = 12

d1 = {0: 1, 1: 2}
colors1 = ['grey', 'purple']
cmap1 = ListedColormap(colors1)

d2 = dict(d1)
colors2 = list(colors1)
cmap2 = ListedColormap(colors2)

def striplog(hole, bottom_depth, top_depth, hole_name="MSDP01", pred_col='RF _Predictions'):
    """Strip log comparing logged and predicted basalt labels for one hole.

    Tracks, left to right: ``Ti_ppm`` against depth (x axis 0-15000), the
    logged binary label and the predicted binary label. Labels (0 =
    non-basalt, 1 = basalt) are converted to colour codes with the
    notebook-level ``d1``/``cmap1``/``colors1`` (logged) and
    ``d2``/``cmap2``/``colors2`` (predicted).

    NOTE: this redefines the earlier five-argument ``striplog``; after this
    cell runs only this version exists.

    Parameters
    ----------
    hole : pandas.DataFrame
        Samples of one hole sorted by ``DH_DEPTH_FROM``, with ``Ti_ppm``,
        ``LITHOLOGY_NAME`` (0/1) and the prediction column.
    bottom_depth, top_depth : float
        y-axis limits (bottom first, so depth increases downwards).
    hole_name : str, optional
        Prefix for the lithology track labels (default ``"MSDP01"``).
    pred_col : str, optional
        Column holding the 0/1 predictions (default ``'RF _Predictions'``).

    Returns
    -------
    matplotlib.figure.Figure
        The figure holding the three tracks (also the current pyplot figure).
    """
    # Create the three tracks directly instead of overlaying subplot2grid Axes
    # on a full-size Axes from plt.subplots(); Matplotlib >= 3.6 keeps the
    # overlapped Axes, which then shows as an empty frame.
    fig, (ax1, ax2, ax3) = plt.subplots(nrows=1, ncols=3, figsize=(7, 6), sharey=True)

    # Ti track
    ax1.plot(hole["Ti_ppm"], hole['DH_DEPTH_FROM'], color="black", linewidth=1.0)
    ax1.set_xlabel("Ti_ppm")
    ax1.set_xlim(0, 15000)
    ax1.xaxis.label.set_color("black")
    ax1.tick_params(axis='x', colors="black")
    ax1.spines["top"].set_edgecolor("black")
    ax1.set_xticks(np.arange(0, 15000, 10000))

    # Logged labels. Depths are cell edges, so one fewer cell than samples
    # is drawn (last code dropped).
    logged = hole['LITHOLOGY_NAME'].iloc[:-1].map(d1).to_numpy().reshape(-1, 1)
    ax2.pcolormesh([-1, 2], hole['DH_DEPTH_FROM'], logged,
                   cmap=cmap1, vmin=1, vmax=len(colors1))
    ax2.set_xticks([])
    ax2.set_aspect(0.01)
    ax2.set_xlabel(f"{hole_name} Logged")

    # Predicted labels, drawn the same way
    predicted = hole[pred_col].iloc[:-1].map(d2).to_numpy().reshape(-1, 1)
    ax3.pcolormesh([-1, 2], hole['DH_DEPTH_FROM'], predicted,
                   cmap=cmap2, vmin=1, vmax=len(colors2))
    ax3.set_xticks([])
    ax3.set_aspect(0.01)
    ax3.set_xlabel(f"{hole_name} Predicted")

    for ax in (ax1, ax2, ax3):
        ax.set_ylim(bottom_depth, top_depth)
        ax.grid(which='major', color='lightgrey', linestyle='-', axis="y")
        ax.xaxis.set_ticks_position("bottom")
        ax.xaxis.set_label_position("top")
        ax.spines["top"].set_position(("axes", 1.02))

    # Only the leftmost track keeps depth labels
    for ax in (ax2, ax3):
        plt.setp(ax.get_yticklabels(), visible=False)

    fig.tight_layout()
    fig.subplots_adjust(wspace=0.15)
    return fig

# Logged vs predicted basalt for the held-out MSDP01, depth 1120 (bottom) to 20 (top)
striplog(hole, bottom_depth=1120, top_depth=20)

plt.savefig('outputs/MSDP01_Predictions.png', transparent = True)
plt.show()


# Attempt Workflow using Synthetic Oversampling

In [None]:
# Balance the training classes with ADASYN (adaptive synthetic oversampling).
# Only the training holes are resampled; the held-out hole is left untouched.
from imblearn.over_sampling import ADASYN
from collections import Counter

ada = ADASYN(random_state=42)
X_res, y_res = ada.fit_resample(X_train, y_train)
print(f'Resampled dataset shape {Counter(y_res)}')

## Make plots to understand the effect of ADASYN

In [None]:
plt.rcParams["font.size"] = 12

def adysn_histograms(X, y, X_res, y_res, elem, logx=False):
    """Compare one element's distribution before and after ADASYN resampling.

    Draws two stacked 60-bin histograms on shared axes, coloured Basalt /
    Non-basalt: the original data on top (with legend) and the resampled data
    below (without legend).

    Parameters
    ----------
    X, y : original features (DataFrame) and binary target (1 = Basalt).
    X_res, y_res : resampled features and target.
    elem : str
        Feature column to plot.
    logx : bool, optional
        Plot ``log10`` of the values instead, in a column named ``"log <elem>"``.

    Returns
    -------
    matplotlib.figure.Figure
    """
    def _labelled(features, target):
        # Join features and target, add a readable class label and sort by it
        # so the hue (stacking) order is deterministic.
        label = target.apply(lambda v: "Basalt" if v == 1 else "Non-basalt")
        joined = pd.concat([features, target], axis=1).assign(Lithology=label)
        return joined.sort_values("Lithology")

    data_orig = _labelled(X, y)
    data_adysn = _labelled(X_res, y_res)

    if logx:
        log_name = "log " + elem
        for frame in (data_orig, data_adysn):
            frame[log_name] = np.log10(frame[elem])
        elem = log_name

    palette = {"Basalt": "purple", "Non-basalt": "lightgrey"}
    fig, ax = plt.subplots(nrows=2, ncols=1, figsize=(5, 5), sharex=True, sharey=True)
    panels = ((ax[0], data_orig, True), (ax[1], data_adysn, False))
    for axis, frame, show_legend in panels:
        sns.histplot(
            data=frame,
            x=elem,
            hue="Lithology",
            bins=60,
            ax=axis,
            multiple='stack',
            palette=palette,
            legend=show_legend,
        )
    plt.xlabel(elem)
    ax[0].set_title("Original data")
    ax[1].set_title("After ADYSN")
    plt.tight_layout()

    return fig


In [None]:
plt.rcParams["font.size"] = 12

def adysn_scatter(X, y, X_res, y_res, x_elem, y_elem, logx=False, logy=False):
    """Side-by-side scatter plots of two elements before and after ADASYN.

    Left panel: original data (with legend); right panel: resampled data.
    Both share axes and are coloured Basalt / Non-basalt.

    Parameters
    ----------
    X, y : original features (DataFrame) and binary target (1 = Basalt).
    X_res, y_res : resampled features and target.
    x_elem, y_elem : str
        Feature columns for the x and y axes.
    logx, logy : bool, optional
        Use logarithmic x / y scales on both panels.

    Returns
    -------
    matplotlib.figure.Figure
    """
    def _labelled(features, target):
        # Descending sort puts "Non-basalt" rows first, so basalt points are
        # drawn on top of them.
        label = target.apply(lambda v: "Basalt" if v == 1 else "Non-basalt")
        joined = pd.concat([features, target], axis=1).assign(Lithology=label)
        return joined.sort_values("Lithology", ascending=False)

    data_orig = _labelled(X, y)
    data_adysn = _labelled(X_res, y_res)

    palette = {"Basalt": "purple", "Non-basalt": "lightgrey"}
    fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(10, 5), sharex=True, sharey=True)
    for axis, frame, legend_mode in ((ax[0], data_orig, "auto"), (ax[1], data_adysn, False)):
        sns.scatterplot(
            data=frame,
            x=x_elem,
            y=y_elem,
            hue="Lithology",
            ax=axis,
            palette=palette,
            legend=legend_mode,
        )
    ax[0].set_title("Original data")
    ax[1].set_title("After ADYSN")
    for axis in ax:
        if logx:
            axis.set_xscale("log")
    for axis in ax:
        if logy:
            axis.set_yscale("log")
    plt.tight_layout()

    return fig

In [None]:
# P distribution before / after ADASYN
fig = adysn_histograms(X_train, y_train, X_res, y_res, elem="P_ppm")
fig.savefig("outputs/histograms_adysn_P.png", transparent=True)
plt.show()

In [None]:
# Ti distribution before / after ADASYN
fig = adysn_histograms(X_train, y_train, X_res, y_res, elem="Ti_ppm")
fig.savefig("outputs/histograms_adysn_Ti.png", transparent=True)
plt.show()

In [None]:
# Zr vs Ti before / after ADASYN
fig = adysn_scatter(X_train, y_train, X_res, y_res, x_elem="Zr_ppm", y_elem="Ti_ppm")
fig.savefig("outputs/scatter_adysn_Ti.png", transparent=True)
plt.show()

In [None]:
# log10(Be) distribution before / after ADASYN
fig = adysn_histograms(X_train, y_train, X_res, y_res, elem="Be_ppm", logx=True)
fig.savefig("outputs/histograms_adysn_Be.png", transparent=True)
plt.show()

In [None]:
# Be (log x axis) vs Ti before / after ADASYN
fig = adysn_scatter(X_train, y_train, X_res, y_res, x_elem="Be_ppm", y_elem="Ti_ppm", logx=True)
fig.savefig("outputs/scatter_adysn_Ti-Be.png", transparent=True)
plt.show()

In [None]:
# Re-run the MSDP01 hold-out model trained on the ADASYN-resampled data

RF_ada = RandomForestClassifier(random_state=42)  # create classifier
RF_ada.fit(X_res, y_res)                          # train on resampled holes
y_pred_RF = RF_ada.predict(X_test)                # hard 0/1 predictions for MSDP01
# Score for the positive class: classes_ are sorted, so column 1 is label 1 (Basalt)
y_prob_RF = RF_ada.predict_proba(X_test)[:, 1]

# Metrics are fractions (MCC lies in [-1, 1]), not percentages.
print("Test Set Balanced Accuracy: {:.3f}".format(balanced_accuracy_score(y_test, y_pred_RF)))
print("Test Set Matthews Correlation Coefficient: {:.3f}".format(matthews_corrcoef(y_test, y_pred_RF)))
# ROC AUC needs a continuous score (on hard labels it equals balanced
# accuracy) and is undefined when the held-out hole contains a single class.
if y_test.nunique() > 1:
    print("Test Set ROC AUC: {:.3f}".format(roc_auc_score(y_test, y_prob_RF)))
else:
    print("Test Set ROC AUC: undefined (held-out hole contains a single class)")

# Confusion Matrix
disp = ConfusionMatrixDisplay.from_predictions(
    y_true=y_test,
    y_pred=y_pred_RF,
    labels=RF_ada.classes_,
    display_labels=["Non-basalt", "Basalt"],
    colorbar=False,
    text_kw={"fontsize": 20})

plt.tight_layout()
plt.savefig("outputs/confusion_adysn.png", transparent=True)

# Recursive Feature Elimination

(This section is commented out because it was removed from the workshop.)

In [None]:
# # Removing Correlated Features
# corr_matrix = X_res.corr().abs()

# # Select upper triangle of correlation matrix
# upper = corr_matrix.where(np.triu(np.ones(corr_matrix.shape), k=1).astype(np.bool))

# # Get features with correlations of greater than 0.8
# to_drop = [column for column in upper.columns if any(upper[column] > 0.8)]

# # Drop features 
# X_filt = X_res.copy()
# X_filt.drop(to_drop, axis=1)

# X_filt_test = X_test.copy()
# X_filt_test.drop(to_drop, axis=1)

In [None]:
# # implement RFE with stratified sampling
# from sklearn.feature_selection import RFECV
# from sklearn.model_selection import StratifiedKFold

# rfc = RandomForestClassifier(random_state=42)
# rfecv = RFECV(estimator= rfc, step=1, cv=StratifiedKFold(3), scoring='roc_auc')
# rfecv.fit(X_filt, y_res)

In [None]:
# # Visualise the affect on performance of RFE
# print('Optimal number of features: {}'.format(rfecv.n_features_))

# # Plot the accuracy obtained per each number of feature used
# plt.figure(figsize=(16, 9))
# plt.title('Recursive Feature Elimination with Cross-Validation', fontsize=16, fontweight='bold', pad=25)
# plt.xlabel('No. Features', fontsize=16, labelpad=20)
# plt.ylabel('ROC AUC', fontsize=14, labelpad=20)
# plt.plot(range(1, len(rfecv.cv_results_['mean_test_score']) + 1), rfecv.cv_results_['mean_test_score'], color='black', linewidth=2)

# plt.savefig('outputs/RFE_Performance')
# plt.show()

In [None]:
# # Remove unwanted features
# print(np.where(rfecv.support_ == False)[0])

# # Filter variables
# X_filt = X_res.drop(X_res.columns[np.where(rfecv.support_ == False)[0]], axis=1)
# X_test_filt = X_test.drop(X_test.columns[np.where(rfecv.support_ == False)[0]], axis=1)

In [None]:
# # Visualise the Feature Importances

# dset = pd.DataFrame()
# dset['attr'] = X_filt.columns
# dset['importance'] = rfecv.estimator_.feature_importances_

# dset = dset.sort_values(by='importance', ascending=False)


# plt.figure(figsize=(16, 14))
# plt.barh(y=dset['attr'], width=dset['importance'], color='#1976D2')
# plt.title('RFECV - Feature Importance', fontsize=20, fontweight='bold', pad=20)
# plt.xlabel('Importance', fontsize=14, labelpad=20)

# plt.savefig('outputs/RFE_Performance.png')
# plt.show()

In [None]:
# Retrain and Test the model with RFE nd Cross Val

# RF_rfe = RandomForestClassifier(random_state=42) # Create classifier
# RF_rfe.fit(X_filt, y_res) # train the model
# y_pred_train = RF_rfe.predict(X_filt) # generate predictions for training set
# y_pred_RF = RF_rfe.predict(X_test_filt) # generate predictions for test set

# # Calculate the Performance Metrics
# print("Test Set Balanced Accuracy: {:.3f} %".format(balanced_accuracy_score(y_test, y_pred_RF)))
# print("Test Set Matthews Correlation Coefficient: {:.3f} %".format(matthews_corrcoef(y_test, y_pred_RF)))
# print("Test Set ROC AUC: {:.3f} %".format(roc_auc_score(y_test, y_pred_RF)))

# # Confusion Matrix
# cm = confusion_matrix(y_test, y_pred_RF, labels=RF.classes_)
# disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=RF.classes_)
# disp.plot()

# plt.show()

# Model Explainability, Feature Importance

In [None]:
# SHAP (SHapley Additive exPlanations) for per-feature model explanations.
# initjs() loads the JavaScript used by SHAP's interactive plots (e.g. force plots).
import shap
shap.initjs()

For the baseline model (trained without oversampling):

In [None]:
# Fits the explainer
# Explains the current `RF` (the baseline fitted on every hole except MSDP01),
# using the held-out MSDP01 features as background data.
# NOTE(review): RF.predict returns hard 0/1 labels, so attributions are for the
# predicted label, not the basalt probability; consider RF.predict_proba.
explainer = shap.Explainer(RF.predict, X_test)
# Calculates the SHAP values - It takes some time
shap_values = explainer(X_test)

In [None]:
# Dot (beeswarm-style) summary of the 10 most influential features for the
# baseline model. show=False leaves the figure open so a title can be added
# and it can be saved; the return value of summary_plot is not used.
fig = shap.summary_plot(shap_values, plot_type='dot', max_display=10, show=False, plot_size=(5,5))
plt.title("Model without ADYSN")
plt.savefig("outputs/shap_summary_normal.png", transparent=True)
plt.show()

And for the model trained on ADASYN-resampled data:

In [None]:
# Fits the explainer for the ADASYN-trained model (same background data)
explainer_ada = shap.Explainer(RF_ada.predict, X_test)
# Calculates the SHAP values with explainer_ada - it takes some time.
# (This previously called the baseline `explainer`, so the "with ADYSN"
# summary plot actually showed the baseline model.)
shap_values_ada = explainer_ada(X_test)

In [None]:
# Same summary plot for the ADASYN-trained model (values in shap_values_ada)
fig = shap.summary_plot(shap_values_ada, plot_type='dot', max_display=10, show=False, plot_size=(5,5))
plt.title("Model with ADYSN")
plt.savefig("outputs/shap_summary_adysn.png", transparent=True)
plt.show()

In [None]:
# Waterfall plot for one record
# Baseline-model SHAP values for the sample at position 30 of the held-out
# hole (requires at least 31 samples in X_test).
shap.plots.waterfall(shap_values[30])

In [None]:
# Force plot for one record
# Same sample as the waterfall plot (position 30), shown as a force plot;
# interactive rendering relies on shap.initjs() having been called.
shap.plots.force(shap_values[30])

# Design a Spatial Validation Approach

In [None]:
# Hole IDs iterated over by the leave-one-hole-out loop
print(holes)

In [None]:
# Leave-one-hole-out validation: for every hole, train on all other holes
# (ADASYN-resampled) and predict the held-out hole.
from sklearn.metrics import matthews_corrcoef, balanced_accuracy_score, roc_auc_score, f1_score, recall_score
import statistics

# Per-hole results, in the order of `holes`
probs = []  # predicted Basalt probability per sample of the held-out hole
hole = []   # held-out hole ID for each iteration
acc = []    # balanced accuracy per held-out hole

# Holes that contain basalt (kept for reference; not used by the loop)
Validation = ["MSDP01", "MSDP02", "MSDP03", "MSDP04"]

for h in holes:
    test, train = hold_out(h)

    # split out variables
    X_test = test.loc[:, 'Ag_ppm':'Zr_ppm']
    X_train = train.loc[:, 'Ag_ppm':'Zr_ppm']

    y_test = test['LITHOLOGY_NAME']
    y_train = train['LITHOLOGY_NAME']

    # Resample only the training holes using ADASYN
    X_res, y_res = ada.fit_resample(X_train, y_train)

    # train model and run inference
    RF = RandomForestClassifier(random_state=42)
    RF.fit(X_res, y_res)
    y_pred_RF = RF.predict(X_test)

    # Keep the probabilities indexed by the held-out rows, so later cells that
    # depth-sort a hole and attach them get index-aligned values rather than
    # positional ones.
    probs.append(pd.Series(RF.predict_proba(X_test)[:, 1], index=X_test.index))

    hole.append(h)

    acc.append(balanced_accuracy_score(y_test, y_pred_RF))

# Mean and spread of the per-hole balanced accuracy
print("Mean Test Set Balanced Accuracy: {:.3f} %".format(statistics.mean(acc) * 100))
print("Balanced Accuracy Standard Deviation: {:.3f} %".format(statistics.stdev(acc) * 100))

In [None]:
# Hole ID -> predicted basalt probabilities for that held-out hole
probs_dict = {h: p for h, p in zip(hole, probs)}

In [None]:
# One row per hole with its balanced accuracy (column named "Avg accuracy");
# holes and acc share the same order because the loop iterated over holes.
histdata = pd.DataFrame.from_records(dict(zip(holes, acc)), index=["Avg accuracy"]).T.reset_index(names="DH_NAME")

plt.rcParams["font.size"] = 12
fig, ax = plt.subplots(figsize = (5,4))
# histplot over the categorical DH_NAME with weights= draws one bar per hole
# whose height is that hole's score.
sns.histplot(histdata, x='DH_NAME', weights='Avg accuracy', shrink=0.8, ax=ax, color="skyblue")
# Relabel y ticks 0.0-1.0 as 0-100 (percent)
plt.yticks(ticks=np.arange(0, 1.2, 0.2), labels = np.arange(0, 120, 20))
plt.xticks(rotation=60)
plt.xlabel("Drillhole name")
plt.ylabel("Avg prediction accuracy")
plt.title("Avg accuracy for each held-out hole")
plt.tight_layout()
plt.savefig('outputs/accuracy_per_hole.png', transparent=True)



## Investigate predictions

In [None]:
#Create function for strip log
def striplog_2(hole, bottom_depth, top_depth, var1, var2, d1, colors1, cmap1, lith_var="LITHOLOGY_NAME"):
    """Three-track strip log: a chemistry track, a 0-1 track and lithology.

    Tracks, left to right: ``var1`` against depth (auto-scaled x axis),
    ``var2`` against depth on a fixed 0-1 x axis (used for predicted basalt
    probabilities in this notebook), and ``lith_var`` drawn as a colour column.

    Parameters
    ----------
    hole : pandas.DataFrame
        Samples of one hole sorted by ``DH_DEPTH_FROM``, with ``var1``,
        ``var2`` and ``lith_var`` columns.
    bottom_depth, top_depth : float
        y-axis limits (bottom first, so depth increases downwards).
    var1, var2 : str
        Columns for the first and second tracks.
    d1 : dict
        Maps lithology values to integer colour codes ``1..len(colors1)``;
        unmapped values become NaN and stay blank.
    colors1 : list
        Colours used; its length sets the colour-scale maximum.
    cmap1 : matplotlib.colors.Colormap
        Colormap built from ``colors1``.
    lith_var : str, optional
        Column holding the lithology labels (default ``"LITHOLOGY_NAME"``).

    Returns
    -------
    matplotlib.figure.Figure
    """
    # Create the three tracks directly; overlaying subplot2grid Axes on a
    # full-size Axes from plt.subplots() leaves an empty frame behind the
    # tracks in Matplotlib >= 3.6.
    fig, (ax1, ax2, ax3) = plt.subplots(nrows=1, ncols=3, figsize=(7, 5), sharey=True)

    # Chemistry track (x range left to Matplotlib)
    ax1.plot(hole[var1], hole['DH_DEPTH_FROM'], color="black", linewidth=1.0)
    ax1.set_xlabel(var1)
    ax1.xaxis.label.set_color("black")
    ax1.tick_params(axis='x', colors="black")
    ax1.spines["top"].set_edgecolor("black")

    # 0-1 track (predicted probability)
    ax2.plot(hole[var2], hole['DH_DEPTH_FROM'], color="black", linewidth=1.0)
    ax2.set_xlabel(var2)
    ax2.set_xlim(0, 1)
    ax2.xaxis.label.set_color("black")
    ax2.tick_params(axis='x', colors="black")
    ax2.spines["top"].set_edgecolor("black")
    ax2.set_xticks([0, 0.5, 1])

    # Lithofacies column: depths are cell edges, so one fewer cell than samples
    lith_codes = hole[lith_var].iloc[:-1].map(d1).to_numpy().reshape(-1, 1)
    ax3.pcolormesh([-1, 1], hole['DH_DEPTH_FROM'], lith_codes,
                   cmap=cmap1, vmin=1, vmax=len(colors1))
    ax3.set_xticks([])
    ax3.set_aspect(0.01)
    ax3.set_xlabel("Logged Lithofacies")

    for ax in (ax1, ax2, ax3):
        ax.set_ylim(bottom_depth, top_depth)
        ax.grid(which='major', color='lightgrey', linestyle='-', axis="y")
        ax.xaxis.set_ticks_position("bottom")
        ax.xaxis.set_label_position("top")
        ax.spines["top"].set_position(("axes", 1.02))

    # Only the leftmost track keeps depth labels
    for ax in (ax2, ax3):
        plt.setp(ax.get_yticklabels(), visible=False)

    # tight_layout() recomputes subplot spacing, so subplots_adjust() must come
    # after it (as in the other strip-log functions); called before, the
    # wspace setting was discarded.
    fig.tight_layout()
    fig.subplots_adjust(wspace=0.15)
    return fig

In [None]:
# Hole MSDP02: logged lithology vs. predicted basalt probability
hole = (
    df.loc[df['DH_NAME'] == 'MSDP02']
    .dropna(subset=['LITHOLOGY_NAME'])
    # Attach the predictions *before* sorting: they were produced in the
    # hold-out row order (same as df), so assigning after a depth sort would
    # misalign them. Look them up by hole ID rather than list position.
    .assign(MSDP02_Preds=probs_dict['MSDP02'])
    .sort_values(by=['DH_DEPTH_FROM'])
)

print(hole['LITHOLOGY_NAME'].unique())

# Visualise lithologies and predictions on a low-probability hole
plt.rcParams["font.size"] = 12
# Colour codes for the lithologies logged in MSDP02 (Basalt in purple)
d1 = {'Siltstone': 1, 'Sandstone': 2, 'Conglomerate': 3, 'Dolomite Rock': 4, 'Autobreccia': 5, 'Basalt': 6}
colors1 = ['green', 'orange', 'yellow', 'blue', 'red', 'purple']
cmap1 = ListedColormap(colors1)

striplog_2(hole, 560, 60, "V_ppm", "MSDP02_Preds", d1, colors1, cmap1)

plt.savefig('outputs/MSDP02.png', transparent=True)
plt.show()

In [None]:
# Mean predicted basalt probability per held-out hole, joined onto the
# hole-location table by hole ID (instead of relying on both sequences sharing
# the same hole order) and re-exported for mapping.
av_prob = {h: np.mean(p) for h, p in probs_dict.items()}

# assign() builds a new frame, avoiding a SettingWithCopyWarning on the
# drop_duplicates() result.
map = map.assign(Basalt_prob=map['DH_NAME'].map(av_prob))

map.to_csv('outputs/map.csv')

In [None]:
# Show differences in the most important features (Sc, Be, Co) across holes.
# Multiple columns are selected with a list: tuple-style selection
# (groupby(...)["a", "b"]) was deprecated in pandas 1.x and raises in 2.0.
df.groupby(["DH_NAME"])[["Sc_ppm", "Be_ppm", "Co_ppm"]].mean().plot.bar()

In [None]:
# Show differences in the most important features (Ti, V) across holes.
# List-based column selection (tuple selection raises in pandas >= 2.0).
df.groupby(["DH_NAME"])[["Ti_ppm", "V_ppm"]].mean().plot.bar()

In [None]:
plt.rcParams["font.size"] = 14

# Zr vs V for all samples: colour = simplified lithology, marker = MSDP02 vs
# other holes. Sorting descending draws "Other" first so mafic classes sit on top.
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(8, 5))
data_for_plot = df.sort_values("Lithology_simplified", ascending=False)
data_for_plot["Hole_of_interest"] = data_for_plot["DH_NAME"].apply(lambda x: x if x == "MSDP02" else "Other holes")
plt.sca(ax)
sns.scatterplot(data=data_for_plot,
                x="Zr_ppm", y="V_ppm", hue="Lithology_simplified", style="Hole_of_interest",
                palette={"Other": "lightgrey", "Basalt": "purple", "Rhyolite": "red", "Dacite": "green", "Meta-Mafic Igneous": "orange", "Dolerite": "skyblue", "Gabbro": "skyblue"},
                legend=True)
plt.xscale("linear")
plt.yscale("linear")
legend = ax.get_legend()
legend.set_title(None)
# legend.set_bbox_to_anchor((0.5, 0.5))
plt.legend(fontsize=12)
plt.tight_layout()
# Own file name: the previous path ("chemistry_scatterplots_w_hole_marked.png")
# is already used by the two-panel Zr-Ti / Y-Nb figure, which this overwrote.
plt.savefig("outputs/chemistry_scatterplot_Zr_V_w_hole_marked.png", transparent = True)
plt.show()

In [None]:
# Hole MSDP12: logged lithology vs. predicted basalt probability
hole = (
    df.loc[df['DH_NAME'] == 'MSDP12']
    .dropna(subset=['LITHOLOGY_NAME'])
    # Attach the predictions before sorting (they follow the hold-out row
    # order) and look them up by hole ID rather than list position.
    .assign(MSDP12_Preds=probs_dict['MSDP12'])
    .sort_values(by=['DH_DEPTH_FROM'])
)

print(hole['LITHOLOGY_NAME'].unique())

# Visualise lithologies and predictions on a low-probability hole

# Colour codes for the lithologies logged in MSDP12: Meta-Mafic Igneous Rock in
# orange, everything else light grey
d1 = {'Amphibolite': 1, 'Augen Gneiss': 2, 'Quartz Rock (Metasomatic)': 3, 'Orthogneiss': 4, 'Meta-Mafic Igneous Rock': 5, 'Pegmatite (No Composition)': 6,
      'Metagranite': 7, 'Fault Breccia': 8, 'Dacite': 9}
colors1 = ['lightgrey', 'lightgrey', 'lightgrey', 'lightgrey', 'orange', 'lightgrey', 'lightgrey', 'lightgrey', 'lightgrey', 'lightgrey']
cmap1 = ListedColormap(colors1)

striplog_2(hole, 500, 0, "Ti_ppm", "MSDP12_Preds", d1, colors1, cmap1)

plt.savefig('outputs/MSDP12.png', transparent=True)
plt.show()

In [None]:
# Hole MSDP11: logged lithology vs. predicted basalt probability
hole = (
    df.loc[df['DH_NAME'] == 'MSDP11']
    .dropna(subset=['LITHOLOGY_NAME'])
    # Attach the predictions before sorting: they follow the hold-out row
    # order, so assigning after a depth sort could misalign them.
    .assign(MSDP11_Preds=probs_dict["MSDP11"])
    .sort_values(by=['DH_DEPTH_FROM'])
)

print(hole['LITHOLOGY_NAME'].unique())

# Visualise lithologies and predictions on a low-probability hole

# Colour codes for the lithologies logged in MSDP11: Meta-Mafic Igneous Rock in
# orange, Dolerite in sky blue, everything else light grey
d1 = {'Saprolite':1, 'Metagranite':2, 'Pegmatite (No Composition)':3,
       'Meta-Mafic Igneous Rock':4, 'Dolerite':5, 'Rhyolite':6, 'Porphyry':7,
       'Magnetite-Rich Rock (Metasomatic)':8, 'Skarn':9, 'Gneiss':10}
colors1 = ['lightgrey', 'lightgrey', 'lightgrey', 'orange', 'skyblue', 'lightgrey', 'lightgrey', 'lightgrey', 'lightgrey', 'lightgrey']
cmap1 = ListedColormap(colors1)

striplog_2(hole, 500, 0, "Ti_ppm", "MSDP11_Preds", d1, colors1, cmap1)

plt.savefig('outputs/MSDP11.png', transparent=True)
plt.show()