In [None]:
import numpy as np
import pandas as pd
import shap
import seaborn as sns
import warnings
# Seed NumPy's global RNG so stochastic steps that draw from it are repeatable
# when the notebook is run top to bottom.
np.random.seed(seed=10)
# Silence every warning category for the rest of the session.
# NOTE(review): this also hides deprecation / data-quality warnings from the libraries used below.
warnings.simplefilter("ignore")

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline
from matplotlib import font_manager as fm, rcParams
# Use a single sans-serif font for all figures.
# NOTE(review): 'Arial Unicode MS' may not be installed on every machine — confirm the fallback is acceptable.
plt.rcParams.update({'font.sans-serif': ['Arial Unicode MS']})

In [None]:
# Location of the input CSV, relative to the notebook's working directory.
path = './ESCC/dataset/GEO_5y_gene.csv'

In [None]:
# Read the dataset from `path`, report its (rows, columns) shape and preview the first rows.
raw_data = pd.read_csv(filepath_or_buffer=path)
print(raw_data.shape)
raw_data.head(n=5)

In [None]:
# Remove the 'Unnamed: 0' column (presumably a leftover index column from an
# earlier CSV export — confirm against the file).
# errors='ignore' keeps this cell safe to re-run: raw_data is reassigned here, so on
# a second execution the column is already gone and a plain drop would raise KeyError.
columns_to_drop = ['Unnamed: 0']
raw_data = raw_data.drop(columns=columns_to_drop, errors='ignore')
raw_data.head()

In [None]:
# Work on a copy so raw_data itself is left untouched by the modelling steps.
train_data = raw_data.copy()

# 'OS' is the target column; every other column is treated as a feature.
data_bin_y = train_data['OS']
data_X = train_data.drop(columns=['OS'])

# Show how many samples fall into each target value.
data_bin_y.value_counts()

In [None]:
# Preview the first five rows of the feature matrix.
data_X.head(n=5)

In [None]:
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import GridSearchCV
from sklearn.metrics import classification_report
from sklearn.neighbors import KNeighborsClassifier

In [None]:
# Hold out 28% of the samples for testing. A fixed random_state makes the split
# identical every time this cell is executed, not only on a fresh top-to-bottom run
# that happens to start from the global NumPy seed.
train_x, test_x, train_bin_y, test_bin_y = train_test_split(
    data_X, data_bin_y, test_size=0.28, random_state=10
)
print(train_x.shape)
print(test_x.shape)

# Standardize features. The scaler is fitted on the training split only and then
# applied to the test split, so no test-set statistics enter the fit.
scaler = StandardScaler()

# Rebuild DataFrames around the scaled arrays, keeping the original column names
# and the original row index so the features stay label-aligned with
# train_bin_y / test_bin_y.
train_x = pd.DataFrame(
    scaler.fit_transform(train_x), columns=data_X.columns, index=train_x.index
)
test_x = pd.DataFrame(
    scaler.transform(test_x), columns=data_X.columns, index=test_x.index
)

In [None]:
# Hyper-parameter grid for the k-NN classifier.
# 'minkowski' is intentionally not listed: with KNeighborsClassifier's default p=2 it
# is the same distance as 'euclidean', so including it only repeats a third of the
# fits without being able to change which configuration is selected.
param_grid = {
    'n_neighbors': [3, 5, 7, 9, 11, 13, 15, 17, 19],
    'weights': ['uniform', 'distance'],
    'metric': ['euclidean', 'manhattan']
}

# Exhaustive 5-fold cross-validated search on the training split, scored by accuracy,
# using all available CPU cores.
knn_model = KNeighborsClassifier()
grid_search = GridSearchCV(estimator=knn_model, param_grid=param_grid, cv=5, n_jobs=-1, scoring='accuracy')
grid_search.fit(train_x, train_bin_y)

best_params = grid_search.best_params_
print("Best parameters found: ", best_params)

# Estimator refitted on the full training split with the best parameters.
best_knn = grid_search.best_estimator_

# Evaluate on the held-out test split.
knn_y_pred = best_knn.predict(test_x)
print(classification_report(test_bin_y, knn_y_pred))

In [None]:
# SHAP: model-agnostic Kernel explainer around the tuned k-NN's class probabilities.
# The whole scaled training set is passed as the explainer's data.
# NOTE(review): with many rows this can make the next cell slow — consider a summarized background set.
explainer = shap.KernelExplainer(model=best_knn.predict_proba, data=train_x)

In [None]:
# Kernel SHAP values for every training sample.
shap_values_train = explainer.shap_values(train_x)

# Older shap releases return a list with one (n_samples, n_features) array per model
# output; shap >= 0.45 returns a single (n_samples, n_features, n_outputs) array.
# Normalize to the per-output list so that `shap_values_train[1]` selects the second
# output on either version instead of silently selecting sample 1.
if isinstance(shap_values_train, np.ndarray) and shap_values_train.ndim == 3:
    shap_values_train = [
        shap_values_train[:, :, k] for k in range(shap_values_train.shape[2])
    ]

In [None]:
# Render at high resolution.
plt.rc('figure', dpi=500)
# Summary plot for the SHAP values at index 1 (presumably the second class of
# best_knn — confirm against best_knn.classes_).
shap.summary_plot(shap_values=shap_values_train[1], features=train_x)

In [None]:
# SHAP dependence plot for the feature at column position 12, rendered at high resolution.
# NOTE(review): 12 is a hard-coded column position — confirm it still points at the intended feature.
plt.rc('figure', dpi=500)
shap.dependence_plot(ind=12, shap_values=shap_values_train[1], features=train_x)