In [None]:
import pandas as pd
import os
import numpy as np
from tqdm import tqdm
import math
import random
from sklearn.metrics import f1_score, roc_auc_score, confusion_matrix, roc_curve
from statannot import add_stat_annotation
import seaborn as sns
import time

import networkx as nx
from tqdm import tqdm
import sklearn
import seaborn as sns
import matplotlib.pyplot as plt

import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl import DGLGraph
import dgl.function as fn
from dgl.nn.pytorch import GraphConv, SAGEConv, TAGConv

from deepsurv_utils import c_index, adjust_learning_rate
# from loss import NegativeLogLikelihood

In [None]:
# Use CUDA device index 1 when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')

# load SHPH data
# NOTE(review): the CSV path is an empty string in this copy of the notebook;
# it must point at the SHPH clinical table before the cell can run.
all_patient_info = pd.read_csv("")
all_patient_info = all_patient_info[['folder_name', 'Sex_1_male_2_female', 'Age',
       'Location_1_LUL_2_LLL_3_RUL_4_RML_5_RLL','Histology_1_Adenocarcinoma_2_SquamousCellCarcinoma_3_Others',
        'pT_Stage', 'pN_Stage', 'pM_Stage', 'pTNM', 'RFS_Status', 'RFS_Month',
       'OS_Status', 'OS_Month']]
stage1 = list(np.load("labels/name_stage1.npy"))
stage2 = list(np.load("labels/name_stage2.npy"))
patint_list = [*stage1, *stage2]
# Take an explicit copy: the stage re-coding below assigns columns, and doing
# that on a boolean-mask slice of all_patient_info is a chained assignment
# (SettingWithCopyWarning, and ambiguous as to which frame is modified).
patient_info = all_patient_info[all_patient_info['folder_name'].isin(patint_list)].copy()

# Collapse pathological T sub-stages into ordinal codes and encode M1a as 1.
patient_info['pT_Stage']=patient_info['pT_Stage'].replace({"T1a":0, "T1b":0, "T1c":0, "T2a":1,"T2b":1,"T3":2})
patient_info['pM_Stage']=patient_info['pM_Stage'].replace({"M1a":1})

feature_files = os.listdir("trans_feature")

# One feature vector per file; the file stem (name minus its 4-character
# extension) is parsed as the integer patient id used for the merge below.
data = []
name = []
for feature_name in feature_files:
    path = "trans_feature/"+feature_name
    name.append(int(feature_name[:-4]))
    feature = list(np.load(path, allow_pickle=True))
    data.append(feature)
feature_data = pd.DataFrame(data)
feature_data['folder_name']=name
# Left merge keeps every selected patient, even one without a feature file.
all_data = patient_info.merge(feature_data, how='left', on='folder_name')
print(len(all_data))

In [None]:
# Load the external cohort and bring its column names / stage coding in line
# with the SHPH table so both can be concatenated later.
external_info = pd.read_csv("/home/jielian/lung-graph-project/Tumor_tranformer/data_ind/External_label.csv")
external_patint_list = external_info['Patient']
external_info = external_info.rename(columns={
    "Patient": "folder_name",
    "Histology": "Histology_1_Adenocarcinoma_2_SquamousCellCarcinoma_3_Others",
})
external_pt_codes = {"Tis":0,"T1a":0, "T1b":0, "T1c":0, "T2a":1,"T2b":1,"T3":2,"T4":3 }
external_info['pT_Stage'] = external_info['pT_Stage'].replace(external_pt_codes)

external_feature_files = os.listdir("trans_feature_val")

# Features are read in label-table order (one .npy per patient id), not in
# directory-listing order.
external_data = [
    list(np.load("trans_feature_val/"+feature_name+".npy", allow_pickle=True))
    for feature_name in external_patint_list
]
external_feature_data = pd.DataFrame(external_data)
external_feature_data['folder_name'] = external_patint_list
external_all_data = external_info.merge(external_feature_data, how='left', on='folder_name')
print(len(external_all_data))

In [None]:
# Merge the datasets: stack the internal (SHPH) rows first, then the external rows.
# pd.concat keeps each frame's own row labels here (no ignore_index), so the
# later cells address rows of final_data by position (.iloc), not by label.
frames = [all_data, external_all_data]
final_data = pd.concat(frames)

In [None]:
# Pre-computed row positions of the train / validation / test splits inside
# all_data; the external cohort is the tail of final_data.
train_id = np.load("data_ind/train_index.npy", allow_pickle=True)
val_id = np.load("data_ind/val_index.npy", allow_pickle=True)
test_id = np.load("data_ind/test_index.npy", allow_pickle=True)
external_id = np.array(range(len(all_data), len(final_data)))

idx_train, idx_val, idx_test, idx_external_val = (
    torch.LongTensor(ids) for ids in (train_id, val_id, test_id, external_id)
)

# Event-status balance of every split, first for overall survival and then
# for recurrence-free survival.
for status_col, status_title in (("OS_Status", "OS"), ("RFS_Status", "RFS_Status")):
    for split_title, split_frame, split_rows in (
        ("training", all_data, train_id),
        ("validation", all_data, val_id),
        ("test", all_data, test_id),
        ("External", final_data, external_id),
    ):
        print("%s %s distribution:" % (split_title, status_title))
        print(split_frame.iloc[split_rows, :][status_col].value_counts())

# Start Graph Building!!

In [None]:
# define similarity of two patient
def SimScore(a1,a2,s1,s2,l1,l2,h1,h2,t1,t2,n1,n2,m1,m2,tnm1,tnm2):
    """Similarity score between two patients, used as the graph edge weight.

    Three partial scores are counted and multiplied together, so the result
    is 0 as soon as one group has no match at all:

    * clinical  (0-2): same sex (s1/s2), age gap of at most 5 (a1/a2)
    * histology (0-2): same location (l1/l2), same histology (h1/h2)
    * staging   (0-3): same T (t1/t2), N (n1/n2) and M (m1/m2) stage

    tnm1 / tnm2 are accepted for signature compatibility but are not used.

    :return: (int) product of the three partial scores, between 0 and 12
    """
    def matches(*checks):
        # Number of truthy checks.
        return sum(1 for hit in checks if hit)

    c_score = matches(s1 == s2, abs(a1 - a2) <= 5)
    h_score = matches(l1 == l2, h1 == h2)
    t_score = matches(t1 == t2, n1 == n2, m1 == m2)

    return c_score*t_score*h_score

# def SimScore(a1,a2,s1,s2,l1,l2,h1,h2,t1,t2,n1,n2,m1,m2,tnm1,tnm2): 

#     return c_score*t_score*h_score


def adj_matrix(patient_info):
    """Build the dense patient-similarity matrix and its sparse edge form.

    Every ordered pair (i, j) of rows - including i == j - is scored with
    SimScore; pairs with a non-zero score become directed edges.

    :param patient_info: DataFrame holding the clinical columns read below
    :return: (adj, edge_list, edge_wight) - the (n, n) score matrix, the
        list of [i, j] index pairs with a non-zero score, and the matching
        list of scores in the same order
    """
    clinical_columns = (
        'Age',
        'Sex_1_male_2_female',
        'Location_1_LUL_2_LLL_3_RUL_4_RML_5_RLL',
        'Histology_1_Adenocarcinoma_2_SquamousCellCarcinoma_3_Others',
        'pT_Stage',
        'pN_Stage',
        'pM_Stage',
        'pTNM',
    )
    age, sex, loc, his, pts, pns, pms, tnm = (
        patient_info[column].to_list() for column in clinical_columns
    )

    n_sample = len(age)
    adj = np.zeros((n_sample, n_sample))
    edge_list, edge_wight = [], []
    # Row-major walk over all ordered pairs, so edges come out sorted by (i, j).
    for i, j in np.ndindex(n_sample, n_sample):
        adj[i,j] = SimScore(age[i],age[j],sex[i],sex[j],loc[i],loc[j],his[i],his[j],
                            pts[i],pts[j],pns[i],pns[j], pms[i],pms[j],tnm[i],tnm[j])
        if adj[i,j] != 0:
            edge_list.append([i,j])
            edge_wight.append(adj[i,j])
    return adj, edge_list,edge_wight

In [None]:

def graph_bulider(all_data, start_cloumn = 13, event = "OS_Status", label = "OS_Month"):
    """Turn a patient table into a DGL graph with similarity-weighted edges.

    :param all_data: DataFrame with the clinical columns needed by adj_matrix,
        the `event` and `label` columns, and the node features stored in the
        columns from position `start_cloumn` onwards
    :param start_cloumn: (int) positional index of the first feature column
    :param event: (str) name of the event-indicator column
    :param label: (str) name of the survival-time column
    :return: (adj, graph) - the dense similarity matrix and the DGL graph with
        node data 'h' (float features), 'event', 'label' and edge data 'w'
    """
    survival_time = all_data[label]
    labels_sh = torch.from_numpy(survival_time.to_numpy())
    events_sh = torch.from_numpy(all_data[event].to_numpy())

    adj_sh, edge_list_sh, edge_wight_sh = adj_matrix(all_data)
    n_nodes, n_edges = len(survival_time), len(edge_list_sh)
    print("the number of nodes in this graph:", n_nodes)
    print("the number of edges in this graph:", n_edges)
    print("Number of average degree: ", n_edges/n_nodes)

    # Nodes first (with their features and survival targets) ...
    g_sh = dgl.DGLGraph()
    g_sh.add_nodes(len(labels_sh))
    feature_block = all_data.iloc[:, start_cloumn:].to_numpy()
    g_sh.ndata['h'] = torch.from_numpy(feature_block).float()
    g_sh.ndata['event'] = events_sh
    g_sh.ndata['label'] = labels_sh

    # ... then the edges, added in edge_list order so that the weight vector
    # below lines up with the edge ids.
    src, dst = zip(*edge_list_sh)
    g_sh.add_edges(src, dst)
    g_sh.edata['w'] = torch.from_numpy(np.array(edge_wight_sh)).float()
    return adj_sh, g_sh

# Network and Loss

In [None]:
class SAGE(nn.Module):
    """Linear -> SAGEConv -> SAGEConv -> Linear(1) network producing one score per node.

    :param in_feats: (int) size of the raw node feature vector
    :param hid_feats: (int) width after the input projection and first convolution
    :param out_feats: (int) width after the second convolution
    :param dropout: feature dropout rate handed to both SAGEConv layers
    :param activation: activation handed to both SAGEConv layers (None = linear)
    :param aggregator_type: (str) SAGEConv neighbourhood aggregator
    """

    def __init__(self, in_feats, hid_feats, out_feats, dropout=0, activation = None,aggregator_type='mean'):
        super().__init__()
        conv_options = dict(aggregator_type=aggregator_type, activation=activation, feat_drop=dropout)
        self.fc1 = nn.Linear(in_feats, hid_feats)
        self.conv1 = SAGEConv(in_feats=hid_feats, out_feats=hid_feats, **conv_options)
        self.conv2 = SAGEConv(in_feats=hid_feats, out_feats=out_feats, **conv_options)
        self.fc2 = nn.Linear(out_feats, 1)

    def forward(self, graph, inputs, w_input):
        """Return an (n_nodes, 1) tensor of scores.

        :param graph: DGL graph the convolutions run on
        :param inputs: node feature tensor
        :param w_input: edge weights; only the first convolution receives them
        """
        projected = self.fc1(inputs)
        first_hop = self.conv1(graph, projected, w_input)
        second_hop = self.conv2(graph, first_hop)
        return self.fc2(second_hop)

class TAG(nn.Module):
    """Linear -> TAGConv -> TAGConv -> Linear(1) network producing one score per node.

    :param in_feats: (int) size of the raw node feature vector
    :param hid_feats: (int) width after the input projection and first convolution
    :param out_feats: (int) width after the second convolution
    :param activation: activation handed to both TAGConv layers
    """

    def __init__(self, in_feats, hid_feats, out_feats, activation = F.softmax):
        super().__init__()
        self.fc1 = nn.Linear(in_feats, hid_feats)
        self.conv1 = TAGConv(in_feats=hid_feats, out_feats=hid_feats, activation=activation)
        self.conv2 = TAGConv(in_feats=hid_feats, out_feats=out_feats, activation=activation)
        self.fc2 = nn.Linear(out_feats, 1)

    def forward(self, graph, inputs, w_input):
        """Return an (n_nodes, 1) tensor of scores.

        `w_input` is accepted so all models share one call signature; this
        model does not use it.
        """
        projected = self.fc1(inputs)
        first_hop = self.conv1(graph, projected)
        second_hop = self.conv2(graph, first_hop)
        return self.fc2(second_hop)

    
class GCN(nn.Module):
    """Linear -> GraphConv -> GraphConv -> Linear(1) network producing one score per node.

    :param in_feats: (int) size of the raw node feature vector
    :param hid_feats: (int) width of the input projection
    :param out_feats: (int) width after the second convolution
    :param activation: activation handed to both GraphConv layers
    :param norm: (str) GraphConv normalisation mode

    The width between the two convolutions is fixed at 32 regardless of
    `hid_feats`.
    """

    def __init__(self, in_feats, hid_feats, out_feats, activation = F.softmax, norm ="both"):
        super().__init__()
        conv_options = dict(activation=activation, norm=norm)
        self.fc1 = nn.Linear(in_feats, hid_feats)
        self.conv1 = GraphConv(in_feats=hid_feats, out_feats=32, **conv_options)
        self.conv2 = GraphConv(in_feats=32, out_feats=out_feats, **conv_options)
        self.fc2 = nn.Linear(out_feats, 1)

    def forward(self, graph, inputs, w_input):
        """Return an (n_nodes, 1) tensor of scores.

        `w_input` is accepted so all models share one call signature; this
        model does not use it.
        """
        projected = self.fc1(inputs)
        first_hop = self.conv1(graph, projected)
        second_hop = self.conv2(graph, first_hop)
        return self.fc2(second_hop)
    
    
class SAGE1L(nn.Module):
    """Single-convolution variant of SAGE: Linear -> SAGEConv -> Linear(1).

    :param in_feats: (int) size of the raw node feature vector
    :param hid_feats: (int) width of the input projection
    :param out_feats: (int) width after the convolution
    :param dropout: feature dropout rate handed to SAGEConv
    :param activation: activation handed to SAGEConv (None = linear)
    :param aggregator_type: (str) SAGEConv neighbourhood aggregator
    """

    def __init__(self, in_feats, hid_feats, out_feats, dropout=0, activation = None,aggregator_type='mean'):
        super().__init__()
        self.fc1 = nn.Linear(in_feats, hid_feats)
        self.conv1 = SAGEConv(
            in_feats=hid_feats,
            out_feats=out_feats,
            aggregator_type=aggregator_type,
            activation=activation,
            feat_drop=dropout,
        )
        self.fc2 = nn.Linear(out_feats, 1)

    def forward(self, graph, inputs, w_input):
        """Return an (n_nodes, 1) tensor of scores; `w_input` is passed to the convolution."""
        projected = self.fc1(inputs)
        aggregated = self.conv1(graph, projected, w_input)
        return self.fc2(aggregated)

    


In [None]:
class Regularization(object):
    """Norm penalty over every parameter whose name contains 'weight'."""

    def __init__(self, order, weight_decay):
        ''' Store the penalty configuration.
        :param order: (int) norm order number
        :param weight_decay: (float) weight decay rate
        '''
        super().__init__()
        self.order, self.weight_decay = order, weight_decay

    def __call__(self, model):
        ''' Compute the order-`self.order` norm penalty for `model`.
        :param model: (torch.nn.Module object)
        :return reg_loss: (torch.Tensor) weight_decay times the sum of the
            per-parameter norms; a plain zero when no parameter matches
        '''
        weight_norms = (
            torch.norm(tensor, p=self.order)
            for param_name, tensor in model.named_parameters()
            if 'weight' in param_name
        )
        return self.weight_decay * sum(weight_norms)

    
class NegativeLogLikelihood(nn.Module):
    """Negative Cox partial log-likelihood plus an L2 weight penalty.

    :param l2_reg: (float) weight of the L2 penalty on the model weights
    :param device: device on which the risk-set mask is allocated
    """

    def __init__(self, l2_reg, device):
        super(NegativeLogLikelihood, self).__init__()
        self.L2_reg = l2_reg
        self.device = device
        self.reg = Regularization(order=2, weight_decay=self.L2_reg)

    def forward(self, risk_pred, y, e, model):
        """Compute the loss for one batch of samples.

        :param risk_pred: predicted risk per sample, shape (n,) or (n, 1)
        :param y: survival / follow-up time per sample, shape (n,) or (n, 1)
        :param e: event indicator per sample, shape (n,) or (n, 1)
        :param model: module whose weights are regularised
        :return: (torch.Tensor) scalar loss
        """
        # The masking and broadcasting below are only correct for column
        # vectors. With 1-D inputs `y.T - y` is all zeros (so the risk-set
        # mask never excludes anybody) and `(...) * e` broadcasts to (n, n).
        risk_pred = risk_pred.reshape(-1, 1)
        y = y.reshape(-1, 1)
        e = e.reshape(-1, 1)

        # mask[a, b] == 1 iff sample a is still at risk at the time of
        # sample b, i.e. y[a] >= y[b].
        mask = torch.ones(y.shape[0], y.shape[0]).to(self.device)
        mask[(y.T - y) > 0] = 0
        # Mean exp-risk over each sample's risk set (column-wise reduction).
        log_loss = torch.exp(risk_pred) * mask
        log_loss = torch.sum(log_loss, dim=0) / torch.sum(mask, dim=0)
        log_loss = torch.log(log_loss).reshape(-1, 1)
        # Only samples with an observed event contribute to the likelihood.
        neg_log_loss = -torch.sum((risk_pred-log_loss) * e) / torch.sum(e)
        l2_loss = self.reg(model)
        return neg_log_loss + l2_loss

In [None]:
# One graph over every patient (internal + external) and one over the
# internal cohort only.
adj_all, g_all = graph_bulider(final_data)
adj_sh, g_sh = graph_bulider(all_data)
# External patients are the trailing rows of final_data, hence the trailing
# node ids of g_all.
external_node_ids = list(range(len(all_data), len(final_data)))
g_external = dgl.node_subgraph(g_all, external_node_ids)

# Training

In [None]:
def train(g, g_all, model,device, save_dic, idx_train,idx_val, idx_test, total_epoch=100, patience=5, lr=0.001, reg_l2=0, weight_decay=0.0001):
    """Full-batch training loop for a node-level survival model.

    :param g: DGL graph with node data 'h', 'label', 'event' and edge data 'w'
    :param g_all: DGL graph (same data keys) that also contains the external nodes
    :param model: network called as model(graph, node_features, edge_weights)
    :param device: torch device everything is moved to
    :param save_dic: dict with keys 'model', 'hid_feats', 'out_feats',
        'reg_l2' and 'aggregator_type' describing the run
    :param idx_train, idx_val, idx_test: index tensors selecting the nodes of each split
    :param total_epoch: (int) number of epochs
    :param patience: (int) ReduceLROnPlateau patience, in epochs
    :param lr: (float) initial Adam learning rate
    :param reg_l2: (float) weight of the L2 term inside the loss
    :param weight_decay: (float) Adam weight decay
    :return: None; progress is reported through the tqdm bar
    """
    # Run identifier built from the configuration. It is not used further in
    # this function, but building it makes a save_dic with missing keys fail
    # before any training happens.
    model_name = save_dic['model']+str(save_dic['hid_feats'])+str(save_dic['out_feats'])+str(save_dic['reg_l2'])+save_dic["aggregator_type"]
    optimizer = torch.optim.Adam(model.parameters(),lr=lr, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min',factor=0.5, patience=patience, min_lr = 0.0001, verbose=True)
    criterion = NegativeLogLikelihood(reg_l2, device).to(device)
    model = model.to(device)
    features = g.ndata['h'].to(device)
    e_feature = g.edata['w'].to(device)
    labels = g.ndata['label'].to(device)
    events = g.ndata['event'].to(device)
    g_all= g_all.to(device)
    g = g.to(device)
    with tqdm(range(total_epoch)) as t:
        for epoch in t:
            t.set_description('Epoch %d' % epoch)
            model.train()
            optimizer.zero_grad()
            output = model(g, features,e_feature)
            # Only the training nodes contribute to the optimised loss.
            loss_train = criterion(output[idx_train], labels[idx_train],events[idx_train], model)
            # The score is negated before computing the concordance index.
            auc_train = c_index(-output[idx_train], labels[idx_train],events[idx_train])

            # A fresh forward pass is made every epoch, so the autograd graph
            # does not need to be retained after backward().
            loss_train.backward()
            optimizer.step()

            # Evaluation needs no gradients; skipping autograd avoids building
            # (and keeping alive) a second and third computation graph.
            model.eval()
            with torch.no_grad():
                val_output = model(g, features,e_feature)
                loss_val = criterion(val_output[idx_val], labels[idx_val],events[idx_val], model)
                auc_val = c_index(-val_output[idx_val], labels[idx_val],events[idx_val])
                # NOTE(review): the test c-index and the forward pass on the
                # combined graph are computed but not reported or returned yet.
                auc_test = c_index(-val_output[idx_test], labels[idx_test], events[idx_test])
                exter_val_output = model(g_all, g_all.ndata['h'],g_all.edata['w'])
            scheduler.step(loss_val)

            t.set_postfix(
                  {"train_loss":loss_train.item(), "val_loss":loss_val.item(),
                  "train_cindex":auc_train.item(), "val_auc":auc_val.item(),
                "lr":optimizer.param_groups[0]['lr']})

# Subanalysis: High and Low Risk

In [None]:
from lifelines import KaplanMeierFitter, CoxPHFitter, calibration
from lifelines.statistics import logrank_test
import matplotlib as mpl

In [None]:
def plot_roc_curve(fpr, tpr):
    """Draw a ROC curve together with the chance diagonal and show the figure.

    :param fpr: false-positive rates (x values)
    :param tpr: true-positive rates (y values)
    """
    diagonal = [0, 1]
    plt.plot(fpr, tpr, label='ROC', color='orange')
    plt.plot(diagonal, diagonal, linestyle='--', color='darkblue')
    for apply_text, text in (
        (plt.xlabel, 'False Positive Rate'),
        (plt.ylabel, 'True Positive Rate'),
        (plt.title, 'Receiver Operating Characteristic (ROC) Curve'),
    ):
        apply_text(text)
    plt.legend()
    plt.show()

# NOTE(review): `test` is not defined in any cell visible above; it is read
# here as a DataFrame with at least the 'OS_Status' and 'risk' columns.
y_true = test['OS_Status']
y_scores = test['risk']

fpr, tpr, thresholds = roc_curve(y_true, y_scores)
print("AUC_value:",roc_auc_score(y_true, y_scores))
# Pick the threshold that maximises (TPR - FPR), i.e. Youden's J statistic.
optimal_idx = np.argmax(tpr - fpr)
optimal_threshold = thresholds[optimal_idx]
print("Threshold value is:", optimal_threshold)
plot_roc_curve(fpr, tpr)

In [None]:
# Binarise the risk score at the ROC-derived cut-off (0 = low, 1 = high risk).
# roc_curve counts a sample as positive when score >= threshold, so a score
# exactly equal to the cut-off belongs to the high-risk group. This is also
# the convention of the thred() helper defined further down the notebook.
thred = optimal_threshold
risk_label = [0 if risk_score < thred else 1 for risk_score in test['risk']]
test['risk_label'] = risk_label



In [None]:
def _km_group(frame, column, value, label):
    """Fit and draw the Kaplan-Meier curve of the rows where frame[column] == value.

    Returns (durations, events, fitted KaplanMeierFitter).
    """
    rows = frame[frame[column] == value]
    durations, observed = rows['OS_Month'], rows['OS_Status']
    fitter = KaplanMeierFitter(label=label)
    fitter.fit(durations, observed)
    fitter.plot_survival_function(ci_show =True)
    return durations, observed, fitter


# Stage as reference
Ts1, Es1, kmfs1 = _km_group(test, 'Stage', 0, "TNM Low Risk")
Ts2, Es2, kmfs2 = _km_group(test, 'Stage', 1, "TNM High Risk")
stage_results = logrank_test(Ts1, Ts2, event_observed_A=Es1, event_observed_B=Es2)
stage_results.print_summary()

# Same comparison for the groups defined by the model's binary risk label.
TGCN_low, EGCN_low, kmGCN_low = _km_group(test, 'risk_label', 0, "GCN Low Risk")
TGCN_high, EGCN_high, kmGCN_high = _km_group(test, 'risk_label', 1, "GCN High Risk")
gcn_results = logrank_test(TGCN_low, TGCN_high, event_observed_A=EGCN_low, event_observed_B=EGCN_high)
gcn_results.print_summary()

In [None]:
from matplotlib.offsetbox import AnchoredText


def _km_panel(ax, low_fit, high_fit, title, p_value):
    """Draw a low/high-risk Kaplan-Meier pair on `ax` and annotate the p-value.

    `ax` must be the current pyplot axes (i.e. the subplot that was just
    added), because the curves and tick styling go through the pyplot state.
    """
    low_fit.plot_survival_function(ci_show =False)
    high_fit.plot_survival_function(ci_show =False, color='r')
    ax.title.set_text(title)
    ax.title.set_fontsize(15)
    ax.set_ylim([0.4, 1.0])
    ax.set_xlim([0.0, 65.0])
    plt.xticks(fontsize=15)
    plt.yticks(fontsize=15)
    ax.add_artist(AnchoredText("p = %.3e" % round(p_value, 10), loc=4, frameon=False))
    ax.legend(loc='lower left',prop={'size': 10})


fig = plt.figure(figsize=(20,7))

# Panel 1: groups defined by the graph model's risk label.
ax1 = fig.add_subplot(131)
p1 = gcn_results.p_value
_km_panel(ax1, kmGCN_low, kmGCN_high, 'Transformer Graph Model Kaplan-Meier Curve ', p1)

# Panel 2: groups defined by TNM stage.
ax2 = fig.add_subplot(132)
p2 = stage_results.p_value
_km_panel(ax2, kmfs1, kmfs2, 'TNM Model Kaplan-Meier Curve ', p2)

plt.show()

In [None]:
# Univariate Cox model on the graph-model risk score, followed by a
# calibration check of the predicted survival probability at t0=60.
GCN_cox = test.loc[:, ['OS_Month', 'OS_Status', 'risk']]
cph_GCN = CoxPHFitter()
cph_GCN.fit(GCN_cox, duration_col='OS_Month', event_col='OS_Status')
cph_GCN.print_summary()
axGCN, ICI, E50 = calibration.survival_probability_calibration(cph_GCN, GCN_cox, t0=60)

In [None]:
# Reference model: univariate Cox regression on the stage group, with the
# same calibration check at t0=60. ICI and E50 are overwritten here.
test_stage = test.loc[:, ['OS_Month', 'OS_Status', 'Stage']]
cph_stage = CoxPHFitter()
cph_stage.fit(test_stage, duration_col='OS_Month', event_col='OS_Status')
cph_stage.print_summary()
ax_stage, ICI, E50 = calibration.survival_probability_calibration(cph_stage, test_stage, t0=60)

# Subanalysis: Inside the Graph
### Turn to Networkx Plotting and Analysis

In [None]:
def thred(risk, best_thresh = optimal_threshold):
    """Map a risk score to a signed label: -1 below `best_thresh`, otherwise 1.

    The default threshold is the value of `optimal_threshold` at the moment
    this cell is executed.
    """
    return -1 if risk < best_thresh else 1

In [None]:
# Get the trained graph: the test-node subgraph, with the node feature 'h'
# replaced by the predicted risk.
# NOTE(review): this assumes the rows of `test` are in the same order as idx_test.
g_test = dgl.node_subgraph(g_sh, idx_test)
g_test.ndata['h'] = torch.from_numpy(np.array(test['risk']))
print(g_test.number_of_nodes())
print(g_test.number_of_edges())
print(g_test.number_of_edges()/g_test.number_of_nodes())
nx_G = g_test.to_networkx()

# add node and edge data to the graph
# Convert each node tensor once instead of indexing it per node.
node_risk = g_test.ndata['h'].tolist()
node_event = g_test.ndata['event'].tolist()
node_time = g_test.ndata['label'].tolist()
for i in range(len(nx_G.nodes())):
    nx_G.nodes[i]["Risk"] = node_risk[i]
    nx_G.nodes[i]["Risk_Label"] = thred(node_risk[i])
    nx_G.nodes[i]["E"] = node_event[i]
    nx_G.nodes[i]["T"] = node_time[i]

w = g_test.edata['w'].numpy()
# Materialise the edge list once; rebuilding it for every edge made this
# loop quadratic in the number of edges.
# NOTE(review): the i-th networkx edge is paired with the i-th DGL edge
# weight, as before - confirm both enumerate edges in the same order.
for i, edge in enumerate(list(nx_G.edges())):
    nx_G.edges[edge[0],edge[1],0]["w"] = w[i].item()

In [None]:
from networkx.readwrite import json_graph
import json
# Node-link form of the annotated test graph.
# NOTE: this rebinds `data`, which the SHPH loading cell used for the list of
# feature vectors.
data = json_graph.node_link_data(nx_G)
class NpEncoder(json.JSONEncoder):
    """JSON encoder that also accepts NumPy integers, floats and arrays."""

    def default(self, obj):
        """Convert NumPy values to built-ins; defer everything else to the base class."""
        converters = (
            (np.integer, int),
            (np.floating, float),
            (np.ndarray, lambda array: array.tolist()),
        )
        for numpy_type, convert in converters:
            if isinstance(obj, numpy_type):
                return convert(obj)
        return super().default(obj)
# Encode once and write the text as-is. Passing the already-encoded string
# through json.dump again would store the whole document as one quoted JSON
# string instead of a JSON object.
jsodata = json.dumps(data, cls=NpEncoder)
with open('dense_data.json', 'w') as f:
    f.write(jsodata)

In [None]:
# Draw the test graph with every node coloured by its binary risk label.
d = dict(nx_G.nodes(data="Risk_Label"))
low, *_, high = sorted(d.values())
# Normalise over a range wider than the data (+/- 2).
norm = mpl.colors.Normalize(vmin=low-2, vmax=high+2, clip=True)
mapper = mpl.cm.ScalarMappable(norm=norm, cmap=mpl.cm.coolwarm)
plt.figure(figsize=[20,10])
risk_label_colors = [mapper.to_rgba(node_value) for node_value in d.values()]
risk_label_style = dict(node_size=400, edge_color='grey', with_labels=False,
                        width=0.2, font_color='white')
nx.draw(nx_G, nodelist=d, node_color=risk_label_colors, **risk_label_style)
plt.show()

In [None]:
# Draw the test graph (spring layout) with every node coloured by its
# observed event status.
d = dict(nx_G.nodes(data="E"))
low, *_, high = sorted(d.values())
# Normalise over a range wider than the data (+/- 1).
norm = mpl.colors.Normalize(vmin=low-1, vmax=high+1, clip=True)
mapper = mpl.cm.ScalarMappable(norm=norm, cmap=mpl.cm.coolwarm)
plt.figure(figsize=[20,10])
event_colors = [mapper.to_rgba(node_value) for node_value in d.values()]
event_style = dict(node_size=400, edge_color='grey', with_labels=False,
                   width=0.2, font_color='white')
nx.draw_spring(nx_G, nodelist=d, node_color=event_colors, **event_style)
plt.show()

# Neighbourhood Analysis

In [None]:
# Store the continuous risk score and its -1/+1 label on every node of g_sh.
# NOTE(review): assumes the rows of patient_info line up with the node order
# of g_sh - confirm.
risk_scores = patient_info['risk']
risk_label = [thred(score) for score in risk_scores]
g_sh.ndata['risk'] = torch.from_numpy(risk_scores.to_numpy())
g_sh.ndata['risk_label'] = torch.from_numpy(np.array(risk_label))

In [None]:
def sum_np(values):
    """Split `values` by sign and total each side in a single pass.

    :param values: iterable of numbers (or single-element tensors)
    :return: (neg, pos) - the sum of the elements that are not > 0 and the
        sum of the elements that are > 0; each stays 0 when its side is empty
    """
    neg_total, pos_total = 0, 0
    for value in values:
        if not value > 0:
            neg_total += value
        else:
            pos_total += value
    return neg_total, pos_total

In [None]:
# Summarise the in-neighbourhood of one target node of g_sh.
def neighboor_analysis(idx):
    """Describe node `idx` of `g_sh` and the neighbours pointing at it.

    :param idx: node id in g_sh
    :return: tuple of
        node event, node survival time, node risk, node risk label,
        number of low-risk (-1) neighbours, number of high-risk (+1) neighbours,
        neighbour risks, neighbour risk labels, neighbour edge weights,
        signed weights (label * weight),
        magnitude of the low-risk weight sum, high-risk weight sum,
        signed total weight.
        All counts and sums are 0-d tensors, so `.item()` is always valid.
    """
    node_survival = g_sh.ndata['event'][idx]
    node_survival_month = g_sh.ndata['label'][idx]
    node_risk = g_sh.ndata['risk'][idx]
    node_risk_label = g_sh.ndata['risk_label'][idx]
    print("node risk:", node_risk)
    print("node risk label:", node_risk_label)
    print("node survival:", node_survival)
    neighboor = g_sh.in_edges(idx)
    # neighboor[0] holds the source nodes of the incoming edges.
    neighboor_risk = g_sh.ndata['risk'][neighboor[0]]
    neighboor_risk_label = g_sh.ndata['risk_label'][neighboor[0]]
    uni, count = torch.unique(neighboor_risk_label, return_counts=True)
    print("node neriboor risk label:", uni, count)
    # Count each label explicitly. Indexing `count` by position fails when
    # all neighbours share one label (only one entry) and would then report
    # the single count under the wrong label.
    low_count = (neighboor_risk_label == -1).sum()
    high_count = (neighboor_risk_label == 1).sum()
    neighboor_survival_label = g_sh.ndata['event'][neighboor[0]]
    print("node neriboor sum survival",sum(neighboor_survival_label))
    # Edge weights of the incoming edges, signed by the neighbour's risk label.
    neighboor_weight = g_sh.edges[neighboor][0]['w']
    neighboor_weight_label = torch.mul(neighboor_risk_label, neighboor_weight)
    # Tensor reductions keep neg / pos as tensors even when one side is empty
    # (a Python loop starting from the int 0 would hand back a plain int).
    is_high = neighboor_weight_label > 0
    pos = neighboor_weight_label[is_high].sum()
    neg = neighboor_weight_label[~is_high].sum()
    sum_weights = neg + pos
    if sum_weights >0:
        print("high risk provide more wights!")
    else:
        print("==========low risk provide more wights!=============")
    return node_survival, node_survival_month, node_risk, node_risk_label, low_count, high_count, neighboor_risk, neighboor_risk_label, neighboor_weight, neighboor_weight_label, -neg, pos, sum_weights

In [None]:
# analysis from edge information
# One list per quantity; entry k of every list describes the k-th test node.
idxs, survivals, survival_months, risks, risk_labels = [], [], [], [], []
neighboor_risks, low_counts, high_counts, num_neighboors = [], [], [], []
neighboor_risk_labels, neighboor_weights, neighboor_weight_labels = [], [], []
negs, poss, sum_weightss = [], [], []

for idx in idx_test:
    (node_survival, node_survival_m, node_risk, node_risk_label,
     low_count, high_count, neighboor_risk, _,
     neighboor_weight, neighboor_weight_label,
     neg, pos, sum_weights) = neighboor_analysis(idx)
    # Scalars are unwrapped with .item(); per-neighbour tensors are kept whole.
    for bucket, value in (
        (idxs, idx.item()),
        (survivals, node_survival.item()),
        (survival_months, node_survival_m.item()),
        (risks, node_risk.item()),
        (risk_labels, node_risk_label.item()),
        (neighboor_risks, neighboor_risk),
        (low_counts, low_count.item()),
        (high_counts, high_count.item()),
        (num_neighboors, (low_count+high_count).item()),
        (neighboor_weights, neighboor_weight),
        (neighboor_weight_labels, neighboor_weight_label),
        (negs, neg.item()),
        (poss, pos.item()),
        (sum_weightss, sum_weights.item()),
    ):
        bucket.append(value)

In [None]:
# Per-node summary table of the neighbourhood analysis (one row per test node).
clinical_df = pd.DataFrame()
clinical_df['node'] = idxs
clinical_df['survival'] = survivals
clinical_df['survival_month'] = survival_months
clinical_df['num_neighboors'] = num_neighboors
clinical_df['risks'] = risks
clinical_df['risk_labels'] = risk_labels
clinical_df['neighboor_risks'] = neighboor_risks
# Share of low- / high-risk neighbours, in percent.
clinical_df['neighboor_low_counts'] = [100*low/total for low, total in zip(low_counts, num_neighboors)]
clinical_df['neighboor_high_counts'] = [100*high/total for high, total in zip(high_counts, num_neighboors)]

# low_weights is stored as a magnitude (the analysis function returns -neg),
# high_weights is the positive part, sum_weights is the signed total.
clinical_df['low_weights'] = negs
clinical_df['high_weights'] = poss
clinical_df['sum_weights'] = sum_weightss
# Total absolute weight = low-risk magnitude + high-risk magnitude.
clinical_df['sum_weights_abs'] = clinical_df['low_weights'] + clinical_df['high_weights']
# Recode the event indicator from 0/1 to -1/1.
clinical_df['survival'] = clinical_df['survival'].replace({0:-1})
clinical_df['mean_sum_weights_abs'] = clinical_df['sum_weights_abs']/clinical_df['num_neighboors']


In [None]:
# Number of neighbours per test node, compared between risk groups (left)
# and between event groups (right), each with a Mann-Whitney annotation.
fig = plt.figure(figsize=(20,7), dpi = 350)
sns.set(style="whitegrid")
y = "num_neighboors"
stat_options = dict(test='Mann-Whitney', text_format='star', loc='outside', verbose=2)

# Left panel: grouped by predicted risk label.
ax1 = fig.add_subplot(121)
x1, order1 = "risk_labels", [-1, 1]
ax1 = sns.boxplot(data=clinical_df, x=x1, y=y, order=order1,
                  palette=sns.color_palette(['#FF5720', '#18C288']))
add_stat_annotation(ax1, data=clinical_df, x=x1, y=y, box_pairs=[(-1,1)], **stat_options)
ax1.set(xlabel='risk', ylabel='number of neighboors')
ax1.set(xticklabels=["low risk", "high risk"])

# Right panel: grouped by observed event.
ax2 = fig.add_subplot(122)
x2, order2 = "survival", [-1, 1]
ax2 = sns.boxplot(data=clinical_df, x=x2, y=y, order=order2,
                  palette=sns.color_palette(['#FF5720', '#18C288']))
add_stat_annotation(ax2, data=clinical_df, x=x2, y=y, box_pairs=[(-1,1)], **stat_options)
ax2.set(xlabel='Survival', ylabel='number of neighboors')
ax2.set(xticklabels=["Survival", "Death"])

In [None]:
# Total (left) and per-neighbour mean (right) absolute edge weight, compared
# between event groups, each with a Mann-Whitney annotation.
fig = plt.figure(figsize=(20,7), dpi = 350)
sns.set(style="whitegrid")
stat_options = dict(test='Mann-Whitney', text_format='star', loc='outside', verbose=2)

# Left panel: summed weight.
ax1 = fig.add_subplot(121)
x1, order1 = "survival", [-1, 1]
y = "sum_weights_abs"
ax1 = sns.boxplot(data=clinical_df, x=x1, y=y, order=order1,
                  palette=sns.color_palette(['#FF5720', '#18C288']))
add_stat_annotation(ax1, data=clinical_df, x=x1, y=y, box_pairs=[(-1,1)], **stat_options)
ax1.set(xlabel='Survival', ylabel='sum of neighboors weights')
ax1.set(xticklabels=["Survival", "Death"])

# Right panel: weight averaged over the number of neighbours.
ax2 = fig.add_subplot(122)
x2, order2 = "survival", [-1, 1]
y = "mean_sum_weights_abs"
ax2 = sns.boxplot(data=clinical_df, x=x2, y=y, order=order2,
                  palette=sns.color_palette(['#FF5720', '#18C288']))
add_stat_annotation(ax2, data=clinical_df, x=x2, y=y, box_pairs=[(-1,1)], **stat_options)
ax2.set(xlabel='Survival', ylabel='mean of neighboors weights')
ax2.set(xticklabels=["Survival", "Death"])

In [None]:
# Long format: one row per (node, weight kind), so the two weight columns can
# be drawn as a hue in one plot.
clinical_d = clinical_df.loc[:, ['survival', 'low_weights', 'high_weights']]
d2 = pd.melt(clinical_d, id_vars="survival", var_name="neighboor risk")

In [None]:
# Grouped box plot: low- vs high-risk neighbour weight for each event group.
fig = plt.figure(figsize=(20,7), dpi = 350)
sns.set(style="whitegrid")
ax1 = fig.add_subplot(121)
x, y1, hue1 = "survival", "value", 'neighboor risk'
order1 = [-1, 1]
ax1 = sns.boxplot(data=d2, x=x, y=y1, hue=hue1, order=order1)
ax1.set(xlabel='survival', ylabel='sum of neighboor weights')
ax1.set(xticklabels=["survival", "death"])



In [None]:
# Persist the per-node neighbourhood summary, without the DataFrame index.
clinical_df.to_csv('data_ind/test_graph_feature.csv', index=False)