# Classification model for DLBCL subtypes

In [None]:
# General-purpose Python imports
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import importlib
import scipy

from sklearn.metrics import accuracy_score, brier_score_loss, roc_auc_score, roc_curve, precision_recall_fscore_support
import torch
import shap

import module.utils as utils
import module.models as models

%precision 4

# Project root = parent directory of the current working directory.
# os.path.dirname understands the platform's path separator; splitting the
# path on "/" yields an empty ROOT when the cwd uses backslashes.
ROOT = os.path.dirname(os.getcwd())
# Directory under which the per-seed outputs of this notebook are written.
DIR_LOG = os.path.join(ROOT, "LOG")

In [None]:
# Input files
_geneset_dir = os.path.join(ROOT, "data", "GSEA", "geneset")
_train_dir = os.path.join(ROOT, "data", "GSE31312")

gmt_path = os.path.join(_geneset_dir, "c2.cp.kegg.v7.3.symbols.gmt")
gml_path = os.path.join(ROOT, "data", "Graphml", "kegg.graphml")
ensembl_path = os.path.join(_geneset_dir, "Human_ENSEMBL_Gene_ID_MSigDB.v7.3.chip")
# NOTE(review): this chip file is v7.2 while the other MSigDB files are v7.3 - confirm this is intended.
entrez_path = os.path.join(_geneset_dir, "Human_NCBI_Entrez_Gene_ID_MSigDB.v7.2.chip")
expression_path = os.path.join(_train_dir, "rma_expression.pickle")
clinical_df_path = os.path.join(_train_dir, "clinical.csv")

# Run settings shared by the cells below
cv = 5
n_seed = 0
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed both NumPy and PyTorch with the same value
np.random.seed(n_seed)
_ = torch.manual_seed(n_seed)

# Accumulators filled by the model cells
overall_scores = {}
Y_scores = []

In [None]:
# Load data for training
# `import scipy` alone does not guarantee that the stats submodule is loaded,
# so import it explicitly before calling scipy.stats.zscore below.
import scipy.stats

from module.models import Graph_from_GSEA, Load_Dataset, Graph_Exp_Data

# Load the pathway graphs (KEGG gene sets and KEGG GraphML file)
path_graphs = Graph_from_GSEA(gmt_path = gmt_path,
                              gml_path = gml_path,
                              gene_convert_ensembl = ensembl_path,
                              gene_convert_entrez  = entrez_path)

# Load expression data
dataset = Load_Dataset(array_path = expression_path,
                   clinical_path = clinical_df_path)

# Set label data: GCB -> 0, ABC -> 1; rows whose GEP value maps to NaN are dropped
dataset.Y = pd.DataFrame(dataset.clinicalDf["GEP"].map({"GCB":0, "ABC":1, np.nan:np.nan}).dropna())

# Select genes in KEGG pathways
in_pathways = np.isin(dataset.x_expression.columns, list(path_graphs.pathways.all_genes))
dataset.x_expression = dataset.x_expression.loc[:, in_pathways]
# Normalization: z-score computed along axis=1 (across the columns of each row)
dataset.x_expression = pd.DataFrame(scipy.stats.zscore(dataset.x_expression, axis=1),
                             index=dataset.x_expression.index,
                             columns=dataset.x_expression.columns)
# Only the labelled samples are used for training
train_idx = dataset.Y.index

dgl_dataset = Graph_Exp_Data(x=dataset.x_expression.loc[train_idx, :],
                             y=dataset.Y.loc[train_idx],
                             graphs = path_graphs)

In [None]:
# Load data for test dataset
# `import scipy` alone does not guarantee that the stats submodule is loaded.
import scipy.stats

# Dedicated names are used for the test files so that the training paths
# (expression_path / clinical_df_path) are not overwritten; otherwise re-running
# the training-data cell after this one would silently load the test cohort.
test_expression_path = os.path.join(ROOT, "data", "GSE10846", "rma_expression.pickle")
test_clinical_df_path = os.path.join(ROOT, "data", "GSE10846", "clinical.csv")

test_dataset = Load_Dataset(array_path = test_expression_path,
                   clinical_path = test_clinical_df_path)

# GCB -> 0, ABC -> 1; "UC" and missing labels map to NaN and are dropped
test_dataset.Y = pd.DataFrame(test_dataset.clinicalDf["GEP"].map({"GCB":0, "ABC":1, "UC": np.nan, np.nan:np.nan}).dropna())

# Keep only the genes that appear in the pathway graphs
in_pathways = np.isin(test_dataset.x_expression.columns, list(path_graphs.pathways.all_genes))
test_dataset.x_expression = test_dataset.x_expression.loc[:, in_pathways]
# Normalization: z-score computed along axis=1 (across the columns of each row)
test_dataset.x_expression = pd.DataFrame(scipy.stats.zscore(test_dataset.x_expression, axis=1),
                             index=test_dataset.x_expression.index,
                             columns=test_dataset.x_expression.columns)

test_idx = test_dataset.Y.index

test_dgl_dataset = Graph_Exp_Data(x=test_dataset.x_expression.loc[test_idx, :],
                             y=test_dataset.Y.loc[test_idx],
                             graphs = path_graphs)

# MLP: Multilayer Perceptron

In [None]:
# Training tensors: float32 features and int64 labels flattened to 1-D
X_train = torch.from_numpy(dataset.x_expression.loc[train_idx, :].values).float()
Y_train = torch.from_numpy(dataset.Y.loc[train_idx, :].values.reshape(-1)).long()

# Test tensors, built the same way from test_dataset
X_test = torch.from_numpy(test_dataset.x_expression.loc[test_idx, :].values).float()
Y_test = torch.from_numpy(test_dataset.Y.loc[test_idx, :].values.reshape(-1)).long()

In [None]:
print("################################################")
print("#                      MLP                     #")
print("################################################")

# Re-import module.models so MLP_Trainer reflects the current source file
importlib.reload(models)
from module.models import MLP_Trainer

# Per-seed output directory for this model
file_path = os.path.join(DIR_LOG, str(n_seed), "mlp")
os.makedirs(file_path, exist_ok=True)

# Candidate values for each hyper-parameter passed to the trainer's search
param_grid = dict(
    in_dim=[dataset.x_expression.shape[1]],
    hidden_dim1=[1000, 3000],
    hidden_dim2=[1000, 3000],
    hidden_dim3=[1000],
    n_class=[2],
    learning_rate=[1e-4, 1e-6],
    n_epoch=[50],
    patience=[5],
    n_batch=[64, 128],
    dropout=[0.2, 0.4, 0.6],
)
n_search = 30

trainer = MLP_Trainer(path=file_path, device=device)
trainer.X_test, trainer.Y_test = X_test, Y_test
trainer.gridsearch_cv(
    X_train,
    Y_train,
    param_grid,
    cv=cv,
    max_n_search=n_search,
    random_seed=n_seed,
    testscore=True,
    score_metrics="acc",
)

# Score the best model found by the search on the test set
trainer.load_bestmodel(load_log=True)
pred = trainer.predict(X_test)
test_score = trainer.multi_acc(Y_test, pred)
print(f"Test Score: {test_score:.3g}")
overall_scores["mlp"] = np.mean(test_score)

In [None]:
# Evaluation
from module.models import MLP_Trainer


def _print_binary_metrics(y_true, y_score, cutoff):
    """Print accuracy, precision, recall, F1 and ROC AUC for one data split.

    y_score holds the scores of class 1. A sample is predicted as class 1 when
    its score is >= cutoff, which is how roc_curve defines its thresholds, so
    the realized operating point matches the one the cutoff was selected for.
    """
    y_pred = [1 if s >= cutoff else 0 for s in y_score]
    print(f"Accuracy: {accuracy_score(y_true, y_pred):.3f}")
    scores = precision_recall_fscore_support(y_true, y_pred, average="binary")
    print(f"Precision: {scores[0]:.3f}")
    print(f"Recall: {scores[1]:.3f}")
    print(f"F1 score: {scores[2]:.3f}")
    print(f"ROC AUC: {roc_auc_score(y_true, y_score):.3f}")


torch.manual_seed(n_seed)
file_path = os.path.join(DIR_LOG, str(n_seed), "mlp")
trainer = MLP_Trainer(path=file_path, device=device)
trainer.load_bestmodel(load_log=True)

print("Training set")
Y_proba = trainer.predict(X_train)
fpr, tpr, thres = roc_curve(Y_train, Y_proba[:, 1])
# Cutoff minimizing (1 - TPR) + FPR on the training set; reused unchanged for the test set
cutoff = thres[np.argmin(1 - tpr + fpr)]
_print_binary_metrics(Y_train, Y_proba[:, 1], cutoff)

print("Test set")
Y_proba = trainer.predict(X_test)
_print_binary_metrics(Y_test, Y_proba[:, 1], cutoff)

n_parameters = sum(np.prod(x.shape) for x in trainer.net.parameters())
print(f"Number of parameters {n_parameters}")

# Keep the training-set predictions for the calibration plot
pred = trainer.predict(X_train)
Y_scores.append((Y_train, pred))

# Graph Convolutional Network

In [None]:
# mapping_attr() is called before .attr is read, for the training and the test set alike
dgl_dataset.mapping_attr()
X_train, Y_train = dgl_dataset.attr.float(), dgl_dataset.y

test_dgl_dataset.mapping_attr()
X_test, Y_test = test_dgl_dataset.attr.float(), test_dgl_dataset.y

In [None]:
print("################################################")
print("#                      GCN                     #")
print("################################################")

# Re-import module.models so GCN_Trainer reflects the current source file
importlib.reload(models)
from module.models import GCN_Trainer

# Per-seed output directory for this model
file_path = os.path.join(DIR_LOG, str(n_seed), "gcn")
os.makedirs(file_path, exist_ok=True)

# Candidate values for each hyper-parameter passed to the trainer's search
param_grid = {
    "in_dim": [1],
    "hidden_dim1": [0],
    "hidden_dim2": [0],
    "gcn_dim1" : [10, 20, 40],
    "gcn_dim2" : [5, 10, 20],
    "n_class" : [2],
    "learning_rate" : [1e-2,1e-4,1e-6],
    "n_epoch": [50],
    "patience": [5],
    "n_batch" : [64, 128],
    "dropout" : [0.2, 0.4, 0.6],
    "deep_fc" : [False]
         }

n_search = 30

trainer = GCN_Trainer(dgl_dataset.batched_graph, path=file_path, device=device)
trainer.X_test = X_test
trainer.Y_test = Y_test
# Use the notebook-wide `cv` setting (as the MLP and GCN-MLP searches do)
# instead of a hard-coded fold count.
trainer.gridsearch_cv(X_train, Y_train, param_grid, cv=cv, max_n_search=n_search, random_seed=n_seed, testscore=True, score_metrics="acc")

# Score the best model found by the search on the test set
trainer.load_bestmodel(load_log=True)
pred = trainer.predict(X_test)
test_score = trainer.multi_acc(Y_test, pred)
print(f"Test Score: {test_score:.3g}")
overall_scores["gcn"] = np.mean(test_score)

In [None]:
# Evaluation
from module.models import GCN_Trainer
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, brier_score_loss, roc_auc_score, roc_curve, precision_recall_fscore_support


def _print_binary_metrics(y_true, y_score, cutoff):
    """Print accuracy, precision, recall, F1 and ROC AUC for one data split.

    y_score holds the scores of class 1. A sample is predicted as class 1 when
    its score is >= cutoff, which is how roc_curve defines its thresholds, so
    the realized operating point matches the one the cutoff was selected for.
    """
    y_pred = [1 if s >= cutoff else 0 for s in y_score]
    print(f"Accuracy: {accuracy_score(y_true, y_pred):.3f}")
    scores = precision_recall_fscore_support(y_true, y_pred, average="binary")
    print(f"Precision: {scores[0]:.3f}")
    print(f"Recall: {scores[1]:.3f}")
    print(f"F1 score: {scores[2]:.3f}")
    print(f"ROC AUC: {roc_auc_score(y_true, y_score):.3f}")


torch.manual_seed(n_seed)
file_path = os.path.join(DIR_LOG, str(n_seed), "gcn")
trainer = GCN_Trainer(dgl_dataset.batched_graph, path=file_path, device=device)
trainer.load_bestmodel(load_log=True)

print("Training set")
Y_proba = trainer.predict(X_train)
fpr, tpr, thres = roc_curve(Y_train, Y_proba[:, 1])
# Cutoff minimizing (1 - TPR) + FPR on the training set; reused unchanged for the test set
cutoff = thres[np.argmin(1 - tpr + fpr)]
_print_binary_metrics(Y_train, Y_proba[:, 1], cutoff)

print("Test set")
Y_proba = trainer.predict(X_test)
_print_binary_metrics(Y_test, Y_proba[:, 1], cutoff)

n_parameters = sum(np.prod(x.shape) for x in trainer.net.parameters())
print(f"Number of parameters {n_parameters}")

# Keep the training-set predictions for the calibration plot
pred = trainer.predict(X_train)
Y_scores.append((Y_train, pred))

In [None]:
print("################################################")
print("#                    GCN-MLP                   #")
print("################################################")

# Re-import module.models so GCN_Trainer reflects the current source file
importlib.reload(models)
from module.models import GCN_Trainer

# Per-seed output directory for this model
file_path = os.path.join(DIR_LOG, str(n_seed), "gcn_mlp")
os.makedirs(file_path, exist_ok=True)

# Candidate values for each hyper-parameter passed to the trainer's search
param_grid = dict(
    in_dim=[1],
    hidden_dim1=[1000],
    hidden_dim2=[1000],
    gcn_dim1=[10, 20, 40],
    gcn_dim2=[5, 10, 20],
    n_class=[2],
    learning_rate=[1e-2, 1e-4],
    n_epoch=[50],
    patience=[5],
    n_batch=[64, 128],
    dropout=[0.2, 0.4, 0.6],
    deep_fc=[True],
)

n_search = 30

trainer = GCN_Trainer(dgl_dataset.batched_graph, path=file_path, device=device)
trainer.X_test, trainer.Y_test = X_test, Y_test
trainer.gridsearch_cv(
    X_train,
    Y_train,
    param_grid,
    cv=cv,
    max_n_search=n_search,
    random_seed=n_seed,
    testscore=True,
    score_metrics="acc",
)

# Score the best model found by the search on the test set
trainer.load_bestmodel(load_log=True)
pred = trainer.predict(X_test)
test_score = trainer.multi_acc(Y_test, pred)
print(f"Test Score: {test_score:.3g}")
overall_scores["gcn_mlp"] = np.mean(test_score)

In [None]:
# Evaluation
from module.models import GCN_Trainer


def _print_binary_metrics(y_true, y_score, cutoff):
    """Print accuracy, precision, recall, F1 and ROC AUC for one data split.

    y_score holds the scores of class 1. A sample is predicted as class 1 when
    its score is >= cutoff, which is how roc_curve defines its thresholds, so
    the realized operating point matches the one the cutoff was selected for.
    """
    y_pred = [1 if s >= cutoff else 0 for s in y_score]
    print(f"Accuracy: {accuracy_score(y_true, y_pred):.3f}")
    scores = precision_recall_fscore_support(y_true, y_pred, average="binary")
    print(f"Precision: {scores[0]:.3f}")
    print(f"Recall: {scores[1]:.3f}")
    print(f"F1 score: {scores[2]:.3f}")
    print(f"ROC AUC: {roc_auc_score(y_true, y_score):.3f}")


torch.manual_seed(n_seed)
file_path = os.path.join(DIR_LOG, str(n_seed), "gcn_mlp")
trainer = GCN_Trainer(dgl_dataset.batched_graph, path=file_path, device=device)
trainer.load_bestmodel(load_log=True)

print("Training set")
Y_proba = trainer.predict(X_train)
fpr, tpr, thres = roc_curve(Y_train, Y_proba[:, 1])
# Cutoff minimizing (1 - TPR) + FPR on the training set; reused unchanged for the test set
cutoff = thres[np.argmin(1 - tpr + fpr)]
_print_binary_metrics(Y_train, Y_proba[:, 1], cutoff)

print("Test set")
Y_proba = trainer.predict(X_test)
_print_binary_metrics(Y_test, Y_proba[:, 1], cutoff)

n_parameters = sum(np.prod(x.shape) for x in trainer.net.parameters())
print(f"Number of parameters {n_parameters}")

# Keep the training-set predictions for the calibration plot
pred = trainer.predict(X_train)
Y_scores.append((Y_train, pred))

In [None]:
# Record scores in dataframe
# The CSV has one column per seed and one row per model.
file_path = os.path.join(DIR_LOG, "overall_scores.csv")
try:
    result = pd.read_csv(file_path, index_col=0).to_dict()
except (FileNotFoundError, pd.errors.EmptyDataError):
    # No score table yet (or an empty file): start a new one. Any other error
    # is allowed to propagate so an existing table is never silently overwritten.
    result = {}
# Merge this run's scores into the column of the current seed
result.setdefault(str(n_seed), {}).update(overall_scores)
result_df = pd.DataFrame(result)
result_df.to_csv(file_path)

In [None]:
# Calibration plot
from sklearn.calibration import calibration_curve

fig = plt.figure(1, figsize=(10, 10))
ax1 = fig.add_subplot(1, 1, 1)
ax1.plot([0, 1], [0, 1], "k:", label="Perfectly calibrated")
# Legend names, in the order in which the evaluation cells append to Y_scores
label_name = ["mlp", "gcn", "gcn_mlp"]

# More entries than names means an evaluation cell was run more than once,
# so the curves could no longer be labelled reliably.
if len(Y_scores) > len(label_name):
    raise ValueError(
        f"Y_scores has {len(Y_scores)} entries but only {len(label_name)} model names; "
        "reset Y_scores and re-run each evaluation cell once."
    )

for name, (labels, pred) in zip(label_name, Y_scores):
    prob_pos = pred[:, 1]  # score of class 1
    fraction_of_positives, mean_predicted_value = calibration_curve(labels, prob_pos, n_bins=5)
    clf_score = brier_score_loss(labels, prob_pos, pos_label=1)
    ax1.plot(mean_predicted_value, fraction_of_positives, "s-",
             label=f"{name} ({clf_score:.3f})")
ax1.set_ylabel("Fraction of positives")
ax1.set_ylim([-0.05, 1.05])
ax1.legend(loc="lower right")
ax1.set_title('Calibration plots  (reliability curve)')
fig.savefig(os.path.join(DIR_LOG, str(n_seed), "calibration_plots.png"))

# Output dataset for GSEA

In [None]:
# Sample indices of each class (GEP code 0 = GCB, 1 = ABC), keyed by subtype name
label_idx = {
    "GCB": dataset.Y.query('GEP == 0').index,
    "ABC": dataset.Y.query('GEP == 1').index,
}
gcb_idx, abc_idx = label_idx["GCB"], label_idx["ABC"]

# The expression table is passed transposed
_ = utils.create_gsea_dataset(
    dataset.x_expression.T,
    label_idx,
    gene_id_type="symbol",
    filename="GSE31312_gep",
)

# SHAP on GCN

In [None]:
torch.manual_seed(n_seed)

# Location of the GCN run for the current seed and of its saved weights
file_path = os.path.join(DIR_LOG, str(n_seed), "gcn")
model_path = os.path.join(file_path, "model.pth")

# Names and number of the graphs in the dataset
graph_names = dgl_dataset.graph_names
n_graph = len(graph_names)

# Restore the best model and read its hyper-parameters
trainer = GCN_Trainer(dgl_dataset.batched_graph, path=file_path, device=device)
trainer.load_bestmodel(load_log=True)
params = trainer.p
gcn_out_dim = params["gcn_dim2"]

In [None]:
#%%time
# SHAP
# GCN_classifier is otherwise only imported in a later cell, so running the
# notebook top to bottom would raise NameError here.
from module.models import GCN_classifier

shap.initjs()
shap_seed=0

# Rebuild the classifier from the trainer's hyper-parameters and load the saved weights
gcn_param = {"g": trainer.g, "n_graphs": n_graph}
gcn_net = GCN_classifier(outcome=2, **gcn_param ,**params)
gcn_net.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")), strict=False)
gcn_net = gcn_net.to(device)

explainer_x = dgl_dataset.attr.float().to(device)
shap_value_x = dgl_dataset.attr.float().to(device)

# The whole training set is the background; SHAP values are computed for the first 100 samples
explainer_whole = shap.GradientExplainer(gcn_net, explainer_x)
shap_values_whole = explainer_whole.shap_values(explainer_x[0:100])

# Save SHAP data
df = pd.DataFrame(shap_values_whole[0], columns = dgl_dataset.columns)
df.to_csv(os.path.join(file_path, "gene_level_shap.csv"))

In [None]:
# Matplotlib 3.6 renamed its bundled seaborn styles to "seaborn-v0_8-*" and later
# removed the old names; use the new name when it is available.
_darkgrid = "seaborn-v0_8-darkgrid"
plt.style.use(_darkgrid if _darkgrid in plt.style.available else "seaborn-darkgrid")

In [None]:
# SHAP summary plot
# Load SHAP data
shap_csv = os.path.join(file_path, "gene_level_shap.csv")
df = pd.read_csv(shap_csv, index_col=0)
# Keep only the text before the first "." in each column name
df.columns = [col.split(".")[0] for col in df.columns]

# Sum |SHAP| over columns that share a name
df = df.abs().T.groupby(df.columns).sum().T
fig = shap.summary_plot(
    df,
    feature_names=df.columns,
    plot_type="bar",
    title = "Gene-level feature importance",
    show=False,
)
plt.savefig(os.path.join(file_path, "shap_gene_level_gcn.png"))
plt.show()

# Top 20 columns by mean aggregated |SHAP|
df.mean().sort_values(ascending=False).head(20)

In [None]:
from module.models import GCN_classifier, GCN_classifier_layer1, GCN_classifier_layer2

# Extract intermediate layer
shap.initjs()
shap_seed=0
# One column label per GCN output feature: each graph name repeated gcn_out_dim times
feature_name = [name for name in graph_names for _ in range(gcn_out_dim)]

# Both sub-networks are filled from the same checkpoint, so it is read once.
# strict=False tolerates checkpoint keys that a sub-network does not define.
gcn_param = {"g": trainer.g, "n_graphs": n_graph}
state_dict = torch.load(model_path, map_location=torch.device("cpu"))
gcn_net_layer1 = GCN_classifier_layer1(outcome=2, **gcn_param ,**params)
gcn_net_layer1.load_state_dict(state_dict, strict=False)
gcn_net_layer1 = gcn_net_layer1.to(device)
gcn_net_layer2 = GCN_classifier_layer2(outcome=2, **gcn_param ,**params)
gcn_net_layer2.load_state_dict(state_dict, strict=False)
gcn_net_layer2 = gcn_net_layer2.to(device)

# Shap
# Random permutation of all samples; the same samples serve as background and
# as the samples to explain. (A 100-sample assignment to shap_value_x that was
# immediately overwritten has been dropped.)
sample = np.random.choice(len(dgl_dataset), size=len(dgl_dataset), replace=False)
explainer_x = dgl_dataset.attr[sample].float().to(device)
shap_value_x = dgl_dataset.attr[sample].float().to(device)

# Output of the first sub-network, which is the input of the second one
gcn_net_layer1.eval()
with torch.no_grad():
    explainer_mid_x = gcn_net_layer1(explainer_x)
    shap_mid_x = gcn_net_layer1(shap_value_x)

explainer = shap.DeepExplainer(gcn_net_layer2, explainer_mid_x)
shap_values = explainer.shap_values(shap_mid_x)

# Compute mean absolute shapley values, then sum the gcn_out_dim features of each graph
shap_mean = np.abs(shap_values[0]).mean(axis=0)
shap_mean = [shap_mean[i*gcn_out_dim:(i+1)*gcn_out_dim].sum() for i in range(n_graph)]

# Save Data
df = pd.DataFrame(shap_values[0], columns = feature_name)
df.to_csv(os.path.join(file_path, "kegg_shap.csv"))

In [None]:
# SHAP summary plot
# Load data
kegg_csv = os.path.join(file_path, "kegg_shap.csv")
df = pd.read_csv(kegg_csv, index_col=0)
# Keep only the text before the first "." in each column name
df.columns = [col.split(".")[0] for col in df.columns]

# graph names
def edit_graph_names(x):
    """Return a display label for a pathway identifier.

    Every "KEGG_" substring is removed, underscores become spaces, and the
    result is capitalized (first character upper-case, the rest lower-case).
    """
    return x.replace("KEGG_", "").replace("_", " ").capitalize()

# Sum |SHAP| over the columns that share a pathway name
df = df.abs().T.groupby(df.columns).sum().T
# groupby returns the pathway names in sorted order, so the bar labels are
# derived from the grouped columns themselves rather than from graph_names,
# whose order may differ and would mislabel the bars.
graphs = [edit_graph_names(x) for x in df.columns]

fig = plt.figure()
ax = fig.add_subplot(1,1,1)
ax = shap.summary_plot(df, feature_names=graphs, plot_type="bar", show=False)
plt.savefig(os.path.join(file_path, "shap_kegg_summary.png"))
plt.show()

# Local Interpretation
1. Compute SHAP values for a sample in the GCN model 
2. Visualize SHAP values to show pathway contribution

In [None]:
# Pathway names without the "KEGG_" prefix and with spaces instead of underscores
display_names = [name.replace('KEGG_', "").replace('_', " ") for name in graph_names]

# Predictions for the first 100 explained samples
pred = trainer.predict(shap_value_x[:100])
expected_values = pred.mean(0).to("cpu").numpy()  # mean prediction per class
pred_label = np.argmax(pred, axis=1)

# Labels of the same 100 samples, and whether each prediction matches its label
y_label = dgl_dataset.y[sample][:100]
acc_pred = (y_label == pred_label)

In [None]:
# force plot
l = 0  # index used to select from both expected_values and shap_values
# NOTE(review): shap_values[0] is saved elsewhere with one column per entry of
# feature_name (graph name x gcn_out_dim), while display_names has one entry per
# graph - confirm the feature_names length matches shap_values[l] here.
shap.force_plot(expected_values[l], shap_values[l],
               feature_names = display_names)

In [None]:
plt.style.use('default')


def plot_interpretation(i, out_prefix="tmp"):
    """Show a SHAP force plot for one explained sample and return its pathway table.

    Reads the notebook globals pred_label, acc_pred, shap_values, explainer,
    params, n_graph and display_names.

    Parameters
    ----------
    i : int
        Position of the sample within pred_label / acc_pred / shap_values.
    out_prefix : str, optional
        Path prefix of the saved figures; "<out_prefix>.png" and
        "<out_prefix>.svg" are written. Defaults to "tmp".

    Returns
    -------
    pandas.DataFrame
        Columns "Pathway" and "Shapley value", sorted by descending value.
    """
    label_dict = {0: "GCB", 1:"ABC"}
    predicted = int(pred_label[i])  # predicted label
    print(f"Predicted Label {i}: {label_dict[predicted]}")
    print(f"Is the prediction correct?: {bool(acc_pred[i])}")

    # Shap values summed per graph: each graph owns params["gcn_dim2"] consecutive entries
    width = params["gcn_dim2"]
    sample_sv = shap_values[predicted][i]
    aggregated_sv = np.array([sample_sv[j*width:(j+1)*width].sum() for j in range(n_graph)])

    shap.force_plot(explainer.expected_value[predicted], aggregated_sv,
                    feature_names = display_names,
                    matplotlib=True, show=False,
                    figsize=(16,5), text_rotation=0)
    plt.subplots_adjust(top=0.50, bottom=0)
    plt.savefig(f"{out_prefix}.png")
    plt.savefig(f"{out_prefix}.svg")
    plt.show()

    # One row per pathway, largest aggregated SHAP value first
    sv_dict = {p_name: [sv] for p_name, sv in zip(display_names, aggregated_sv)}
    df = pd.DataFrame(sv_dict).T.sort_values(by=0, ascending=False)
    df = df.reset_index()
    df.columns = ["Pathway", "Shapley value"]
    return df


df = plot_interpretation(0)
df