In [None]:
# MAP Study 2023 for immunotherapy response prediction
# Yanan Wang @ BNR Mar. 27, 2024
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import time
from matplotlib import pyplot as plt
import numpy as np
import scipy.io as sio
from sklearn.metrics import confusion_matrix
import matplotlib
import argparse
from sklearn.metrics import f1_score, accuracy_score, auc, roc_curve
import torch
import torch.nn.functional as F
from torch.nn import Sequential as Seq, Linear as Lin, ReLU
from model_and_metrics import *
from torch.nn import Linear
import pandas as pd

from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn.conv import GraphConv, GCNConv, GINConv, GATConv
from torch_geometric.nn.pool import TopKPooling, SAGPooling
from torch_geometric.nn import global_mean_pool as gap, global_max_pool as gmp
from torch_geometric.nn import global_mean_pool


import math
from pytorchtools import EarlyStopping
import optuna
from sklearn.utils import shuffle
import networkx as nx
from torch_geometric.explain import Explainer, GNNExplainer
from torch_geometric.explain.metric import fidelity, unfaithfulness

import joblib
from tripletnet import Tripletnet
# import gnn_models

In [None]:
# Report the CUDA environment. The device-specific queries are guarded so the
# notebook still runs top-to-bottom on a CPU-only machine, where
# torch.cuda.get_device_name(0) / current_device() would raise.
print(torch.cuda.is_available())  # whether CUDA is available
print(torch.cuda.device_count())  # number of visible GPUs
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))  # name of GPU 0 (device indices start at 0)
    print(torch.cuda.current_device())  # index of the currently selected device


In [None]:
import networkx as nx
import matplotlib.pyplot as plt

In [None]:
# Load row-level validation predictions for split 0 and average label/pred per patient id.
# Per the trailing note, the sheet columns are: pid, split, label, preds.
ref_df = pd.read_excel('pa_val_split_0.xlsx')
ref_df_patient = pd.pivot_table(ref_df, index=['pid'], values=['label', 'preds'], aggfunc='mean') #pid	split	label	preds

# Keep only confidently-correct patients: mean prediction > 0.65 with mean label == 1,
# or mean prediction < 0.35 with mean label == 0.
ref_df_patient = ref_df_patient[((ref_df_patient.preds > 0.65) & (ref_df_patient.label == 1)) | ((ref_df_patient.preds < 0.35) & (ref_df_patient.label == 0))]

print(ref_df_patient.shape)
# NOTE(review): this .head() is not the last expression of the cell, so its result is
# never displayed; move it to the end of the cell (or wrap it in display()) if wanted.
ref_df_patient.head()
# Patient ids of the selected reference patients.
ref_patients = ref_df_patient.index.values

In [None]:
def plot_graph(edge_index, edge_attr, max_edges=100, weight_threshold=2):
    """Draw the first ``max_edges`` edges of a graph with networkx.

    Parameters
    ----------
    edge_index : array-like
        Edge list. The PyTorch Geometric layout ``(2, num_edges)`` (row 0 = source
        nodes, row 1 = target nodes) is expected; an ``(num_edges, 2)`` list is also
        accepted and transposed. A ``(2, 2)`` input is treated as the PyG layout.
        Must be convertible with ``np.asarray`` (e.g. a CPU tensor or ndarray).
    edge_attr : array-like
        Per-edge attributes indexed as ``edge_attr[i][0]``; that first value is
        used as the edge weight.
    max_edges : int, default 100
        Only the first ``max_edges`` edges are drawn.
    weight_threshold : float, default 2
        Edges with weight > threshold are drawn as thick red lines, the others as
        thin dashed gray lines.
    """
    G = nx.Graph()

    edges = np.asarray(edge_index, dtype=int)
    if edges.ndim == 2 and edges.shape[0] != 2 and edges.shape[1] == 2:
        # Accept an (num_edges, 2) edge list as well as the PyG (2, num_edges) one.
        edges = edges.T
    # Columns are edges: keep the first `max_edges` of them.
    edges = edges[:, :max_edges]

    for i in range(edges.shape[1]):
        G.add_edge(edges[0, i], edges[1, i], weight=float(edge_attr[i][0]))

    elarge = [(u, v) for (u, v, d) in G.edges(data=True) if d["weight"] > weight_threshold]
    esmall = [(u, v) for (u, v, d) in G.edges(data=True) if d["weight"] <= weight_threshold]

    # Fixed seed so the layout is reproducible between calls.
    pos = nx.spring_layout(G, seed=7)
    # nodes
    nx.draw_networkx_nodes(G, pos, node_size=20)

    # edges: heavy edges in red, light edges dashed gray
    nx.draw_networkx_edges(G, pos, edgelist=elarge, edge_color="r", width=2)
    nx.draw_networkx_edges(
        G, pos, edgelist=esmall, width=0.5, alpha=0.5, edge_color="gray", style="dashed"
    )

    ax = plt.gca()
    ax.margins(0.01)

    plt.axis("off")
    plt.tight_layout()
    plt.show()

In [None]:
# # Build the model
class GNN(torch.nn.Module):
    """Graph-level encoder used as the backbone of the triplet network.

    ``forward`` runs: conv1 -> BatchNorm -> ELU -> conv2 -> ELU -> global mean
    pooling -> Linear(nhid//2, nhid//4) -> Linear(nhid//4, nhid//8) -> sigmoid.
    ``conv3`` and ``adp`` (AlphaDropout) are constructed but not used in
    ``forward``; they are kept so checkpoints saved from this class still load
    with strict state-dict key matching.

    Parameters
    ----------
    nhid : int
        Width of the first conv layer; later layers use nhid//2, nhid//4, nhid//8.
    gnn_layer : {'GCN', 'GAT', 'GIN'}
        Convolution family. Note that 'GCN' builds ``GraphConv`` layers (not
        ``GCNConv``); changing that would break existing checkpoints.
    drop_out : float
        Dropout probability of the (currently unused) AlphaDropout layer.
    num_feature : int, optional
        Number of input node features. When omitted it is read from
        ``data/pa_num_feature.joblib`` at construction time.

    Raises
    ------
    ValueError
        If ``gnn_layer`` is not one of the supported names.
    """

    def __init__(self, nhid=128, gnn_layer='GCN', drop_out=0.5, num_feature=None):
        super(GNN, self).__init__()
        # Fixed seed so freshly initialised weights are reproducible.
        torch.manual_seed(12345)
        if num_feature is None:
            # Loaded here rather than as a default-argument expression, which would
            # read the file once when the class is defined, even if callers pass it.
            num_feature = joblib.load('data/pa_num_feature.joblib')
        self.nhid = nhid
        self.num_feature = num_feature
        self.drop_out = drop_out
        self.adp = torch.nn.AlphaDropout(p=self.drop_out)
        self.bn = torch.nn.BatchNorm1d(self.nhid)
        self.elu = torch.nn.ELU()

        if gnn_layer == 'GCN':
            self.conv1 = GraphConv(self.num_feature, self.nhid)
            self.conv2 = GraphConv(self.nhid, self.nhid//2)
            self.conv3 = GraphConv(self.nhid//2, self.nhid//4)
        elif gnn_layer == 'GAT':
            self.conv1 = GATConv(int(self.num_feature), self.nhid)
            self.conv2 = GATConv(self.nhid, self.nhid//2)
            self.conv3 = GATConv(self.nhid//2, self.nhid//4)
        elif gnn_layer == 'GIN':
            self.conv1 = GINConv(Seq(Lin(self.num_feature, self.nhid), ReLU(), Lin(self.nhid, self.nhid)))
            self.conv2 = GINConv(Seq(Lin(self.nhid, self.nhid//2), ReLU(), Lin(self.nhid//2, self.nhid//2)))
            self.conv3 = GINConv(Seq(Lin(self.nhid//2, self.nhid//4), ReLU(), Lin(self.nhid//4, self.nhid//4)))
        else:
            raise ValueError("Unsupported gnn_layer %r; expected 'GCN', 'GAT' or 'GIN'" % (gnn_layer,))

        self.linear_1 = Linear(self.nhid//2, self.nhid//4)
        self.linear_2 = Linear(self.nhid//4, self.nhid//8)

    def forward(self, x, edge_index, batch=None):
        """Encode graph(s) into a ``(num_graphs, nhid//8)`` tensor with values in (0, 1).

        ``batch`` assigns each node to its graph (PyG convention). When it is
        ``None``, all nodes are pooled as one graph, giving shape ``(1, nhid//8)``.
        The previous default ``torch.tensor([0])`` did the same on CPU but
        lived on the CPU, which broke pooling for CUDA inputs.
        """
        x = self.conv1(x, edge_index)
        x = self.bn(x)
        x = self.elu(x)
        x = self.conv2(x, edge_index)
        x = self.elu(x)

        x = global_mean_pool(x, batch)

        # Cast without copying: torch.tensor(x) would detach x from the autograd
        # graph and block gradients to the conv layers during training.
        x = x.to(torch.float32)
        x = self.linear_1(x)
        x = self.linear_2(x)
        return torch.sigmoid(x)

In [None]:
def triplet_data_load_train(tmp_data_list):
    """Build (anchor, positive, negative) triplets for training.

    Samples whose label satisfies ``y[0][0] == 1`` form the "negative" pool; all
    others form the "positive" pool. Each triplet takes an anchor and a different
    positive from the positive pool and a negative from the negative pool, all
    drawn with replacement from numpy's global RNG.

    Parameters
    ----------
    tmp_data_list : sequence
        Samples exposing a tensor ``.y`` indexable as ``y[0][0]``.

    Returns
    -------
    list of [anchor, positive, negative]
        ``max(len(positive pool), len(negative pool))`` triplets. An empty input
        gives an empty list.

    Raises
    ------
    ValueError
        If the positive pool has fewer than two samples (anchor and positive
        cannot differ) or the negative pool is empty.
    """
    print("Overall sample size: %d" % len(tmp_data_list))
    pos_pool = []
    neg_pool = []
    for i, sample in enumerate(tmp_data_list):
        if sample.y.cpu().numpy()[0][0] == 1:
            neg_pool.append(i)
        else:
            pos_pool.append(i)

    print(len(neg_pool))
    print(len(pos_pool))
    max_sample_size = max(len(pos_pool), len(neg_pool))
    print(max_sample_size)
    if max_sample_size == 0:
        return []
    if len(pos_pool) < 2 or not neg_pool:
        raise ValueError(
            "Cannot build training triplets: need >= 2 positive and >= 1 negative "
            "samples, got %d positive and %d negative" % (len(pos_pool), len(neg_pool)))

    # Draw exactly as many indices as triplets are built.
    pos_x_idx = np.random.choice(pos_pool, max_sample_size, replace=True)
    pos_y_idx = np.random.choice(pos_pool, max_sample_size, replace=True)
    neg_z_idx = np.random.choice(neg_pool, max_sample_size, replace=True)

    # Anchor and positive must be different samples. Redrawing only the clashing
    # positives terminates quickly whenever the pool holds >= 2 samples, whereas
    # reshuffling the whole array until no clash remains may never finish.
    clash = pos_x_idx == pos_y_idx
    while clash.any():
        pos_y_idx[clash] = np.random.choice(pos_pool, int(clash.sum()), replace=True)
        clash = pos_x_idx == pos_y_idx

    return [
        [tmp_data_list[x], tmp_data_list[y], tmp_data_list[z]]
        for x, y, z in zip(pos_x_idx, pos_y_idx, neg_z_idx)
    ]

In [None]:
def triplet_data_load(tmp_data_list):
    """Build evaluation triplets in which every sample is the anchor exactly once.

    Samples whose label satisfies ``y[0][0] == 1`` form the "negative" pool; all
    others form the "positive" pool. For each anchor (in input order) a positive
    different from the anchor and a negative are drawn with replacement from
    numpy's global RNG. Anchors are no longer bootstrap-sampled, which used to
    leave roughly a third of the samples unscored.

    Parameters
    ----------
    tmp_data_list : sequence
        Samples exposing a tensor ``.y`` indexable as ``y[0][0]``.

    Returns
    -------
    list of [anchor, positive, negative]
        One triplet per input sample. An empty input gives an empty list.

    Raises
    ------
    ValueError
        If the positive pool has fewer than two samples (the positive anchor
        could only be paired with itself) or the negative pool is empty.
    """
    print("Overall sample size: %d" % len(tmp_data_list))
    pos_pool = []
    neg_pool = []
    for i, sample in enumerate(tmp_data_list):
        if sample.y.cpu().numpy()[0][0] == 1:
            neg_pool.append(i)
        else:
            pos_pool.append(i)

    print(len(neg_pool))
    print(len(pos_pool))
    max_sample_size = len(pos_pool) + len(neg_pool)
    print(max_sample_size)
    if max_sample_size == 0:
        return []
    if len(pos_pool) < 2 or not neg_pool:
        raise ValueError(
            "Cannot build evaluation triplets: need >= 2 positive and >= 1 negative "
            "samples, got %d positive and %d negative" % (len(pos_pool), len(neg_pool)))

    anchor_idx = np.arange(max_sample_size)
    pos_y_idx = np.random.choice(pos_pool, max_sample_size, replace=True)
    neg_z_idx = np.random.choice(neg_pool, max_sample_size, replace=True)

    # A positive must differ from its anchor; redraw only the clashing entries.
    clash = anchor_idx == pos_y_idx
    while clash.any():
        pos_y_idx[clash] = np.random.choice(pos_pool, int(clash.sum()), replace=True)
        clash = anchor_idx == pos_y_idx

    return [
        [tmp_data_list[x], tmp_data_list[y], tmp_data_list[z]]
        for x, y, z in zip(anchor_idx, pos_y_idx, neg_z_idx)
    ]

In [None]:
def triplet_data_load_ex_test(train_list, tmp_data_list):
    """Pair every external sample with reference samples drawn from the training set.

    Training samples whose label satisfies ``y[0][0] == 1`` form the "negative"
    reference pool; the rest form the "positive" pool. Each external sample, in
    order, becomes the anchor of one triplet whose positive and negative are
    drawn with replacement (positives first, then negatives) from numpy's
    global RNG.

    Returns a list of ``[external_sample, positive_ref, negative_ref]``.
    """
    n_anchor = len(tmp_data_list)
    print("Overall sample size: %d" % n_anchor)

    # Partition training indices by the first entry of each label.
    is_neg = [ref.y.cpu().numpy()[0][0] == 1 for ref in train_list]
    neg_pool = [k for k, flag in enumerate(is_neg) if flag]
    pos_pool = [k for k, flag in enumerate(is_neg) if not flag]
    print(len(neg_pool))
    print(len(pos_pool))
    print(n_anchor)

    pos_draw = np.random.choice(pos_pool, n_anchor, replace=True)
    neg_draw = np.random.choice(neg_pool, n_anchor, replace=True)
    return [
        [anchor, train_list[p], train_list[q]]
        for anchor, p, q in zip(tmp_data_list, pos_draw, neg_draw)
    ]

In [None]:
def get_cv_list(ori_list, split_info, ex_list, fold=0):
    """Partition samples by the split sheet and build triplet datasets per partition.

    Parameters
    ----------
    ori_list : sequence
        Graph samples; each exposes ``.pid`` whose first row holds the sample id
        as character codes (decoded with ``chr``).
    split_info : pandas.DataFrame
        Must contain a ``sample_id`` column and a ``split_{fold}`` column whose
        values are 'Train', 'Validate' or 'Test'.
    ex_list : sequence
        External samples; each is paired with reference samples from the
        training partition.
    fold : int, default 0
        Selects the ``split_{fold}`` column.

    Returns
    -------
    tuple
        ``(train_tri_data, val_tri_data_train, val_tri_data_test,
        test_tri_data, ex_test_tri_data)``.
    """
    split_col = split_info["split_{}".format(fold)]
    # Sets give O(1) membership tests per sample.
    train_ids = set(split_info.loc[split_col == 'Train', 'sample_id'])
    val_ids = set(split_info.loc[split_col == 'Validate', 'sample_id'])
    test_ids = set(split_info.loc[split_col == 'Test', 'sample_id'])

    train_list = []
    val_list = []
    test_list = []
    for tmp_data in ori_list:
        # The sample id is stored as character codes; decode it back to a string.
        pid_temp = "".join(chr(x) for x in tmp_data.pid.cpu().numpy()[0])
        if pid_temp in train_ids:
            train_list.append(tmp_data)
        elif pid_temp in val_ids:
            val_list.append(tmp_data)
        elif pid_temp in test_ids:
            test_list.append(tmp_data)
        # Samples whose id is in no partition are skipped.

    # Keep this call order: every builder draws from numpy's global RNG, so
    # reordering would change the sampled triplets.
    train_tri_data = triplet_data_load_train(train_list)
    val_tri_data_train = triplet_data_load_train(val_list)
    val_tri_data_test = triplet_data_load(val_list)
    test_tri_data = triplet_data_load(test_list)
    ex_test_tri_data = triplet_data_load_ex_test(train_list, ex_list)

    return (train_tri_data, val_tri_data_train, val_tri_data_test, test_tri_data, ex_test_tri_data)

In [None]:
# Test function:
# Input: model, test loader, and output directory for the ROC figure;
# Output: index, sample ID, label, embedding, distance_a, distance_b, mean loss, AUC, accuracy.

def test(model, loader, target_dir):
    """Evaluate a triplet network on ``loader`` and collect per-triplet outputs.

    Each batch is ``[anchor, positive, negative]`` as built by the triplet
    loaders in this notebook. ``model`` is called as ``model(a, p, n)`` and must
    return ``(dist_a, dist_b, emb_a, emb_p, emb_n)`` with single-element
    ``dist_a`` / ``dist_b`` (they are converted with ``float``).

    Parameters
    ----------
    model : torch.nn.Module
    loader : iterable of triplet batches with a ``.dataset`` attribute
    target_dir : str
        Directory where ``triplet_auc`` saves the ROC figure.

    Returns
    -------
    tuple
        ``(tile_index, p_id, label, embedded_output, dist_a_all, dist_b_all,
        mean_loss, trip_auc, acc_trip)``. ``label`` is column 1 of each anchor's
        ``y``. ``mean_loss`` is the summed TripletMarginLoss divided by
        ``len(loader.dataset)``.
    """
    model.eval()
    criterion = torch.nn.TripletMarginLoss()
    loss = 0
    label = np.array([])
    p_id = []
    tile_index = []
    dist_a_all = []
    dist_b_all = []
    embedded_output = np.array([])

    # Inference only: no_grad avoids building an autograd graph for every batch.
    with torch.no_grad():
        for i, data in enumerate(loader):
            tile_index.append(i)
            dist_a, dist_b, embedded_x, embedded_y, embedded_z = model(data[0], data[1], data[2])

            dist_a_all.append(float(dist_a.cpu().detach()))
            dist_b_all.append(float(dist_b.cpu().detach()))

            # The anchor's sample id is stored as character codes; decode it.
            p_id.append("".join(chr(x) for x in data[0].pid.cpu().numpy()[0]))

            loss += criterion(embedded_x, embedded_y, embedded_z).item()

            _tmp_label = data[0].y.cpu().detach().numpy()[:, 1]
            label = np.hstack([label, _tmp_label]) if label.size else _tmp_label

            _tmp_emb = embedded_x.cpu().detach().numpy()
            # NOTE(review): np.hstack joins 2-D embeddings along axis 1, so with
            # batch_size=1 the result is one long row rather than one row per
            # sample; kept as-is because the saved .mat layout depends on it.
            embedded_output = np.hstack([embedded_output, _tmp_emb]) if embedded_output.size else _tmp_emb

    acc_trip = accuracy(dist_a_all, dist_b_all)
    trip_auc = triplet_auc(dist_a_all, dist_b_all, label, target_dir)
    return tile_index, p_id, label, embedded_output, dist_a_all, dist_b_all, loss / len(loader.dataset), trip_auc, acc_trip

In [None]:
def accuracy(dista, distb, margin=0):
    """Fraction of triplets with ``dista - distb - margin < 0``.

    Parameters
    ----------
    dista, distb : sequence of float
        The two distances returned per triplet by the triplet model (presumably
        anchor-positive and anchor-negative respectively; confirm against
        ``Tripletnet``).
    margin : float, default 0
        Offset subtracted before the comparison.

    Returns
    -------
    float
        Fraction of "correct" triplets, or ``nan`` when ``dista`` is empty.
    """
    if len(dista) == 0:
        # Avoid numpy's 0/0 RuntimeWarning; the value is undefined either way.
        return float('nan')
    pred = np.array(dista) - np.array(distb) - margin
    return (pred < 0).sum() * 1.0 / len(dista)

In [None]:
def triplet_auc(dista, distb, label, target_dir):
    """ROC AUC of the triplet score ``distb - dista`` against ``label``; saves the ROC plot.

    Parameters
    ----------
    dista, distb : sequence of float
        Per-triplet distances; a larger ``distb - dista`` is treated as a higher
        score for ``label == 1``.
    label : array-like of {0, 1}
    target_dir : str
        Existing directory for the figure
        ``roc_auc_fig_sample_<len(label)>`` (matplotlib's default format).

    Returns
    -------
    float
        Area under the ROC curve (sklearn ``roc_curve`` + ``auc``).
    """
    # A constant margin would not change a rank-based AUC, so none is applied.
    pred = np.array(distb) - np.array(dista)

    colors = ["#E69F00", "#56B4E9"]
    plt.close('all')
    # Scope the ggplot style and Arial font to this figure instead of mutating
    # matplotlib's global state for every later plot in the notebook.
    with plt.style.context("ggplot"), matplotlib.rc_context({'font.family': 'Arial'}):
        fig, ax = plt.subplots(figsize=(8, 8), dpi=400)

        _fpr, _tpr, _ = roc_curve(label, pred)
        _auc = auc(_fpr, _tpr)
        ax.plot(_fpr, _tpr, color=colors[0], label=r'%s ROC (AUC = %0.3f)' % ("MPR", _auc), lw=2, alpha=.9)
        ax.plot([0, 1], [0, 1], 'k--', lw=2)
        ax.set_xlim([-0.01, 1.01])
        ax.set_ylim([-0.01, 1.01])
        ax.set_xlabel('False Positive Rate')
        ax.set_ylabel('True Positive Rate')
        ax.legend(loc="lower right")
        fig.savefig('%s/roc_auc_fig_sample_%d' % (target_dir, len(label)), dpi=400)
    plt.close('all')
    return _auc
    
    

In [None]:
def model_test(test_loader, panel='pa', fold=0, split=0, batch_size=32, gnn_layer='GAT',
               dropout_ratio=0.7, nhid=256, learning_rate=1e-3,
               weight_decay=1e-5, epochs=300, result_dir='pa_pan_results', runs=1):
    """Load the trained triplet model for (split, fold) and save its outputs on ``test_loader``.

    The GNN backbone is rebuilt from ``nhid``, ``gnn_layer`` and ``dropout_ratio``,
    and weights are loaded from ``optimal_model/split_{split}/model_fold{fold}_run0.pth``.
    ``batch_size``, ``learning_rate``, ``weight_decay``, ``epochs`` and ``runs``
    are only used to name the output directory / file. ``panel`` is accepted but
    unused.

    Side effects: creates ``Visualization_analysis_PA/<result_dir>_...`` (plus the
    ROC figure written by ``test``) and a ``.mat`` file with the test outputs.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    target_dir = "Visualization_analysis_PA/%s_%s%d_%d_%.3f_%.3f_%.3f_%d_%d" % (result_dir, gnn_layer, nhid, batch_size, dropout_ratio, learning_rate, weight_decay, split, fold)
    os.makedirs(target_dir, exist_ok=True)
    print(target_dir)

    base_model = GNN(nhid=nhid, gnn_layer=gnn_layer, drop_out=dropout_ratio).to(device)
    model = Tripletnet(base_model)
    checkpoint_path = "optimal_model/split_{}/model_fold{}_run{}.pth".format(split, fold, 0)
    # map_location lets a checkpoint saved on another GPU (or on GPU at all) load on
    # the current device. NOTE(review): torch.load unpickles the file — only load
    # trusted checkpoints.
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))

    t_tile_index, t_p_id, t_label, t_embedded_output, t_dist_a_all, t_dist_b_all, t_loss_1, t_auc, t_acc = test(model, test_loader, target_dir)

    sv = target_dir + '/model_test' + '_fold' + str(fold) + '_runs' + str(runs) + '_run' + str(0) + '_epochs' + str(epochs) + '.mat'

    sio.savemat(sv, mdict={'test_index': t_tile_index, 'test_sample_id': t_p_id, 'test_label': t_label, 'test_embedded': t_embedded_output,
                           'test_dist_a': t_dist_a_all, 'test_dist_b': t_dist_b_all, 'test_loss': t_loss_1, 'test_auc': t_auc})


In [None]:
# Load split sheet 0 and the graph datasets, then build triplet datasets and loaders.
# NOTE(review): no numpy seed is set before get_cv_list, so the sampled triplets
# differ on every run — consider np.random.seed(...) for reproducibility.
split = 0
panel = 'pa'
split_info = pd.read_excel('sample_info_split_%d.xlsx' % split)
print(split_info.shape)
split_info.Response.hist()
gnn_data_list = joblib.load('data/%s_data_list.joblib' % panel)
# NOTE(review): the "external" list is loaded from the same file as gnn_data_list,
# so the external test set is the full internal dataset (training samples included).
# Confirm whether a separate external-cohort file was intended.
ex_gnn_data_list = joblib.load('data/%s_data_list.joblib' % panel)


# fold is left at its default (0), so partitions come from the "split_0" column.
# The returned lists hold [anchor, positive, negative] triplets, not raw samples.
train_list, val_list, val_list_test, test_list, ex_test_list = get_cv_list(gnn_data_list, split_info, ex_gnn_data_list)
test_loader = DataLoader(test_list, batch_size=1, shuffle = False)
ex_test_loader = DataLoader(ex_test_list, batch_size=1, shuffle = False)

print(train_list[0])
print(ex_test_list[0])

In [None]:
# Inspect one external-test triplet (batch index 100): print the anchor's sample id
# and the labels of anchor / positive / negative, and show their node-feature
# matrices side by side.
# NOTE(review): nothing is shown if the loader has fewer than 101 batches.
for i, data in enumerate(ex_test_loader):
    
    if(i == 100):
        print(f"Index: {i}")
        # print(data[0], data[1], data[2])
        
        # Decode the anchor's sample id from character codes.
        pid_temp = [chr(x) for x in list(data[0].pid.cpu().numpy()[0])]
        pid_temp = "".join(pid_temp)
        print(pid_temp)
        print(data[0].y)
        print(data[1].y)
        print(data[2].y)
        # NOTE(review): `fig` is never used; plt.subplots(1, 3) would be the idiomatic form.
        fig = plt.figure()
        plt.subplot(131)
        plt.imshow(data[0].x.cpu().numpy(), interpolation="bicubic")
        plt.subplot(132)
        plt.imshow(data[1].x.cpu().numpy(), interpolation="bicubic")
        plt.subplot(133)
        plt.imshow(data[2].x.cpu().numpy(), interpolation="bicubic")
        break

In [None]:
# split = 0
panel = 'pa'

# Attach to the stored Optuna study and read its best hyper-parameters; these are
# passed as keyword arguments to model_test below.
# NOTE(review): with load_if_exists=True a missing study is created empty, and
# best_params would then fail — confirm the DB file and study name.
study = optuna.create_study(
        storage="sqlite:///db_PA_0606-Copy1.sqlite3",  # Specify the storage URL here.
        study_name="GNN_PA_test_0606", load_if_exists=True)
optimal_param = study.best_params
#optimal_param = {'batch_size': 64, 'dropout_ratio': 0.3264115110664789, 'gnn_layer': 'GAT', 'learning_rate': 1e-02, 'nhid': 256, 'weight_decay': 1e-05}

In [None]:
# Persist the external-test triplets. Create the folder first so this cell also
# works on a fresh run, before model_test has created any subfolders of it.
os.makedirs("Visualization_analysis_PA", exist_ok=True)
joblib.dump(ex_test_list, "Visualization_analysis_PA/ex_test_list_fold_0.pkl")

In [None]:
# Export test outputs of the five fold models of split 0, using the Optuna-optimal
# hyper-parameters.
# NOTE(review): ex_test_loader was built once from split 0 / fold 0, so the reference
# positives/negatives come from that fold's training set for every fold here —
# confirm this is intended.
split_num = 0

for fold_num in range(5):
    print(f"Fold number: {fold_num}")
    model_test(ex_test_loader, panel=panel, fold=fold_num, split=split_num, epochs=300, result_dir = '%s_pan_results_split_%d_fold_%d' % (panel, split_num, fold_num), **optimal_param)


In [None]:
# Export test outputs of the five fold models of split 1.
# NOTE(review): only the checkpoint / output names change with `split`; ex_test_loader
# (and its reference samples) still comes from split 0 / fold 0 — confirm intended.
split = 1
split_num = split

for fold_num in range(5):
    model_test(ex_test_loader, panel=panel, fold=fold_num, split=split_num, epochs=300, result_dir = '%s_pan_results_split_%d_fold_%d' % (panel, split_num, fold_num), **optimal_param)


In [None]:
# Export test outputs of the five fold models of split 2.
# NOTE(review): only the checkpoint / output names change with `split`; ex_test_loader
# (and its reference samples) still comes from split 0 / fold 0 — confirm intended.
split = 2
split_num = split

for fold_num in range(5):
    model_test(ex_test_loader, panel=panel, fold=fold_num, split=split_num, epochs=300, result_dir = '%s_pan_results_split_%d_fold_%d' % (panel, split_num, fold_num), **optimal_param)


In [None]:
# Export test outputs of the five fold models of split 3.
# NOTE(review): only the checkpoint / output names change with `split`; ex_test_loader
# (and its reference samples) still comes from split 0 / fold 0 — confirm intended.
split = 3
split_num = split

for fold_num in range(5):
    model_test(ex_test_loader, panel=panel, fold=fold_num, split=split_num, epochs=300, result_dir = '%s_pan_results_split_%d_fold_%d' % (panel, split_num, fold_num), **optimal_param)


In [None]:
# Export test outputs of the five fold models of split 4.
# NOTE(review): only the checkpoint / output names change with `split`; ex_test_loader
# (and its reference samples) still comes from split 0 / fold 0 — confirm intended.
split = 4
split_num = split

for fold_num in range(5):
    model_test(ex_test_loader, panel=panel, fold=fold_num, split=split_num, epochs=300, result_dir = '%s_pan_results_split_%d_fold_%d' % (panel, split_num, fold_num), **optimal_param)
