In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.font_manager import FontProperties
import pandas as pd
import seaborn as sns
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler, FunctionTransformer
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, classification_report
from sklearn.model_selection import GridSearchCV
from sklearn.neural_network import MLPClassifier
import matplotlib.colors as mcolors
from matplotlib.colors import ListedColormap
from sklearn.model_selection import cross_val_score
import os
from typing import Counter
from sklearn.metrics import accuracy_score
import shap
from imblearn.over_sampling import SMOTE
from sklearn.metrics import roc_curve, auc



In [None]:
# ## Load training data, fit and evaluate the MLP classifier
# Path to Supplementary data 4 (the labelled pyrite training set); edit for your machine.
file_path = r"D:\cddvd\LA_knn_pyrite.xlsx"
data = pd.read_excel(file_path)

TARGET_COLUMN = "Deposit type"
# The 10 trace-element features; prediction-time input must use the same names and order.
FEATURE_COLUMNS = ["Co", "Ni", "Zn", "Cu", "Sb", "Pb", "Ag", "Se", "As", "Bi"]
df = data.loc[:, [TARGET_COLUMN] + FEATURE_COLUMNS]

# Class distribution of the original dataset. Read it from df because X and y
# are only built further down.
print("Class distribution of the original dataset:")
print(df[TARGET_COLUMN].value_counts())

# Keep only rows that have a label AND all 10 features. Build new objects instead
# of calling dropna(inplace=True) on slices, so re-running the cell is safe.
df_complete = df.dropna(subset=[TARGET_COLUMN] + FEATURE_COLUMNS)
X = df_complete.drop(TARGET_COLUMN, axis=1)  # Features
y = df_complete[TARGET_COLUMN]  # Target variable
print("Class distribution after removing missing values:")
print(y.value_counts())

# Integer-encode the labels. pd.Categorical sorts the inferred categories, so the
# codes follow the sorted label names (later cells assume 0 = SEDEX, 1 = VMS).
y = pd.Categorical(y).codes

# Split the data into training and test sets (fixed seed for a reproducible split).
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=2)
print("Class distribution of the training set after splitting:")
print(pd.Series(y_train).value_counts())

# Baseline: 10-fold CV of an untuned, unscaled MLP (seeded so the number is reproducible).
models = (MLPClassifier(alpha=0.5, random_state=1),)

for clf in models:
    scores = cross_val_score(clf, X_train, y_train, cv=10, scoring='f1_macro', n_jobs=-1)
    print(f'{clf.__class__.__name__}: {scores.mean():2.2f}±{scores.std():2.2f}')

# Standardize features before the MLP; the scaler is fitted inside each CV fold.
pipe_clf = Pipeline([
    ("scaler", StandardScaler()),
    ("mlpclassifier", MLPClassifier(random_state=1))
])

# Hyperparameter grid: L2 penalty alpha over 1e-4 ... 1e1 (6 values) and two architectures.
alpha_range = np.logspace(-4, 1, 6, base=10)
param_grid = {
    "mlpclassifier__hidden_layer_sizes": [(50,), (50, 50)],
    "mlpclassifier__solver": ['adam'],
    "mlpclassifier__max_iter": [200],
    "mlpclassifier__alpha": alpha_range,
}

# 5-fold grid search on macro-F1; refit=True retrains the best pipeline on all of X_train.
grid = GridSearchCV(
    pipe_clf, param_grid=param_grid,
    cv=5, scoring="f1_macro", n_jobs=-1, refit=True, verbose=2
)
grid.fit(X_train, y_train)

print("The best parameters are %s with a score of %0.2f" % (grid.best_params_, grid.best_score_))
y_test_pred = grid.predict(X_test)
y_train_pred = grid.predict(X_train)
# Probability of label code 1, kept for ROC analysis.
y_test_proba = grid.predict_proba(X_test)[:, 1]

# Compute each report once and reuse it for both the console and the saved file.
train_report = classification_report(y_train, y_train_pred)
test_report = classification_report(y_test, y_test_pred, output_dict=False)
cm = confusion_matrix(y_test, y_test_pred)
print(train_report)
print(test_report)
print(cm)

report_filename = r"D:\cddvd\MLPpyrite_report.txt"  # Specify the file path and name
with open(report_filename, 'w', encoding='utf-8') as f:
    f.write("Training set report:\n")
    f.write(train_report)
    f.write("\nTest set report:\n")
    f.write(test_report)
    f.write("\nConfusion Matrix:\n")
    f.write(str(cm))

print(f"Classification report and confusion matrix have been saved to: {report_filename}")

In [None]:
# ## SHAP explanations of the tuned model
# Fitted scaler + MLP pipeline selected by the grid search.
best_pipeline = grid.best_estimator_
# Bare MLP step, kept under its original name. Do not call it on raw features:
# it was fitted on StandardScaler output, not on the original concentrations.
mlp = best_pipeline.named_steps["mlpclassifier"]


def predict_function(X):
    """Return class probabilities for raw (unscaled) feature rows.

    Runs the whole pipeline so the StandardScaler is applied before the MLP,
    exactly as at training time. The input is wrapped in a DataFrame with the
    training column names so the pipeline sees the feature names it was fitted with.
    """
    X_df = pd.DataFrame(np.asarray(X), columns=X_train.columns)
    return best_pipeline.predict_proba(X_df)


def _class_shap_values(values, class_idx):
    """Return the (n_samples, n_features) SHAP matrix for one output class.

    Depending on the installed SHAP version, KernelExplainer.shap_values returns
    either a list with one array per class or one array whose last axis indexes
    the classes; both layouts are handled.
    """
    if isinstance(values, list):
        return np.asarray(values[class_idx])
    return np.asarray(values)[..., class_idx]


# NOTE(review): the full training set is the background sample, which makes
# KernelExplainer slow; shap.kmeans / shap.sample can summarize it if needed.
explainer = shap.KernelExplainer(predict_function, X_train)

# Compute SHAP values (in raw feature units, since the pipeline does the scaling).
shap_values = explainer.shap_values(X_train)
# One expected (base) value per output class.
expected_values = np.atleast_1d(explainer.expected_value)

# Define custom colors for the plot
light_colors = [(0.6, 0.86, 0.88), (1, 0.9, 0.73)]
custom_cmap = ListedColormap(light_colors)

# One beeswarm per class; names listed in label-code order (0 = SEDEX, 1 = VMS).
for i, class_name in enumerate(['SEDEX', 'VMS']):
    shap_exp = shap.Explanation(
        values=_class_shap_values(shap_values, i),
        base_values=expected_values[i],
        data=X_train.to_numpy(),
        feature_names=list(X_train.columns),
    )

    # show=False keeps the figure open. With the default show=True the plot is
    # shown and released before the title/savefig calls, producing blank files.
    shap.plots.beeswarm(shap_exp, max_display=10, color=custom_cmap, show=False)  # Display top 10 features
    fig = plt.gcf()
    plt.title(f'SHAP Bee Swarm Plot for {class_name}')

    # Save the SHAP plot as a JPG image
    fig.savefig(f'D:/cddvd/SHAP_beeswarm_{class_name}.jpg', dpi=600, format='jpg')
    plt.close(fig)  # Close the plot to avoid overlapping with the next one


In [None]:
# ## Test-set confusion matrix figure
label_order = ["SEDEX", "VMS"]  # class names in label-code order (0, 1)
# Pass labels explicitly so the matrix stays 2x2 even if a class is absent from the test set.
cm = confusion_matrix(y_test, y_test_pred, labels=list(range(len(label_order))))
cm_df = pd.DataFrame(cm, columns=label_order, index=label_order)
# Percent of each predicted-label column (every column sums to 100%). A column with
# no predictions would be 0/0, so show it as 0% instead of "nan%".
cm_df_percentage = (cm_df.div(cm_df.sum(axis=0), axis=1) * 100).fillna(0.0)


def _text_color(rgba):
    """Pick white text on dark cells and black text on light cells.

    Uses perceived luminance. Comparing RGB tuples with `<` would be
    lexicographic and effectively test only the red channel.
    """
    red, green, blue = rgba[:3]
    luminance = 0.299 * red + 0.587 * green + 0.114 * blue
    return 'white' if luminance < 0.5 else 'black'


fig, ax = plt.subplots(figsize=(2.5, 2.5))
plt.rc('font', family='Times New Roman', size=8)
sns.heatmap(cm_df, linewidths=.5, ax=ax, cmap="Blues")
# Same value range as the heatmap's colour scale, used to look up each cell's colour.
norm = plt.Normalize(vmin=cm_df.values.min(), vmax=cm_df.values.max())
sm = plt.cm.ScalarMappable(cmap="Blues", norm=norm)
n_classes = len(cm_df)
for i in range(n_classes):
    for j in range(n_classes):
        value = cm_df.iloc[i, j]
        percentage = cm_df_percentage.iloc[i, j]
        ax.text(j + 0.5, i + 0.5, f"{value}\n{percentage:.1f}%",
                ha='center', va='center', color=_text_color(sm.to_rgba(value)),
                family='Times New Roman', size=8)
ax.set_title("Test set confusion matrix (MLP)", fontsize=8)
ax.set_xlabel("Predictions", fontsize=8)
ax.set_ylabel("True labels", fontsize=8)
ax.set_xticklabels(label_order, rotation=45, fontsize=8)
ax.set_yticklabels(label_order, fontsize=8)
fig.tight_layout()
fig.savefig(r'D:\cddvd\confusion_matrix_MLP_pyrite.svg', dpi=600, format='svg')
fig.savefig(r'D:\cddvd\confusion_matrix_MLP_pyrite.pdf', dpi=600, format='pdf')
plt.show()

In [None]:
# ## Predict the deposit type of new pyrite analyses with the tuned MLP
print(
    """
MLP classifier to predict the genetic classes (SEDEX or VMS) of pyrite with "Co", "Ni", "Zn", "Cu", "Sb", "Pb", "Ag", "Se", "As", "Bi" values.
Please enter the path of the .xlsx data file (for example: /path/to/file/example_data.xlsx).
The data are supposed to contain all the 10 features above for prediction.
If any one of the features is missing in a sample, that sample will be discarded.
The column names Co, Ni, Zn, Cu, Sb, Pb, Ag, Se, As, Bi should be exactly as listed above, without any prefix or suffix,
and MAKE SURE this column-name row is the FIRST row of the sheet.
"""
)
data_file_path = r"D:\我的论文\dongshengmiao_knn_pyrite.xlsx"  # Please enter the path to the data file
df = pd.read_excel(data_file_path)
# Class names indexed by the integer codes the model predicts (0 = SEDEX, 1 = VMS).
index = ['SEDEX', 'VMS']
print(df)
# Same feature names and order as the training features.
elements = ["Co", "Ni", "Zn", "Cu", "Sb", "Pb", "Ag", "Se", "As", "Bi"]

missing_columns = [element for element in elements if element not in df.columns]
if missing_columns:
    raise KeyError(f"Input file is missing required column(s): {missing_columns}")

# Non-numeric entries become NaN, and such rows are discarded by dropna below.
for element in elements:
    df[element] = pd.to_numeric(df[element], errors="coerce")

# Keep the original row labels so predictions can be joined back to the exact source rows.
to_predict = df.loc[:, elements].dropna()
print(f"{to_predict.shape[0]} samples available")
print(to_predict.describe())

# Check before predicting: sklearn raises on an empty input, so a check placed
# after predict() could never be reached.
if to_predict.empty:
    input("no sample with the 10 features detected!")
    raise SystemExit()

predict_res = [index[code] for code in grid.predict(to_predict)]
c: Counter[str] = Counter(predict_res)

# Column k of predict_proba corresponds to label code k (classes_ are sorted).
# Building the frame from columns keeps probabilities numeric; concatenating them
# with the label strings in one NumPy array would turn them into text.
proba = grid.predict_proba(to_predict)
res = pd.DataFrame(
    {
        'pred_pyrite_type': predict_res,
        'SEDEX_proba': proba[:, 0],
        'VMS_proba': proba[:, 1],
    },
    index=to_predict.index,
)
pd.set_option('display.max_columns', 10)
print('Detailed report preview:\n', res)

print("The samples are predicted respectively to be: ")
print(c.most_common(), "\n")
print(
    f"The most possible type of the group of samples is: {c.most_common(1)[0][0]}.\n"
)

if input('Save report? (y/n): ').strip().lower() == 'y':
    base_filename = os.path.basename(data_file_path)
    prefix, _ = os.path.splitext(base_filename)
    save_name = prefix + '_resultMLPpyrite.xlsx'
    # Join on the preserved row labels (not on Pb values, which may repeat);
    # rows discarded for missing features get empty prediction columns.
    output = df.join(res)
    output.to_excel(save_name)
    print(f'{save_name} saved.')
input("Press any key to exit.")