In [None]:
import pandas as pd
import os
import numpy as np
from tqdm import tqdm
import math
import random
from sklearn.metrics import f1_score, roc_auc_score, confusion_matrix, roc_curve
from statannot import add_stat_annotation
import seaborn as sns
import time

import networkx as nx
from tqdm import tqdm
import sklearn
import seaborn as sns
import matplotlib.pyplot as plt

import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl import DGLGraph
import dgl.function as fn
from dgl.nn.pytorch import GraphConv, SAGEConv, TAGConv, GINConv

from deepsurv_utils import c_index, adjust_learning_rate
# from loss import NegativeLogLikelihood

In [None]:
# Prefer the second GPU when CUDA is available, otherwise run on CPU.
# NOTE(review): 'cuda:1' assumes at least two visible GPUs - confirm for the target machine.
device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')

# load SHPH data
# NOTE(review): the CSV path below is empty in this copy of the notebook; it must point
# at the SHPH clinical table before this cell can run.
all_patient_info = pd.read_csv("")
# Keep only the identifier, clinical covariates and survival outcome columns.
all_patient_info = all_patient_info[['folder_name', 'Sex_1_male_2_female', 'Age',
       'Location_1_LUL_2_LLL_3_RUL_4_RML_5_RLL','Histology_1_Adenocarcinoma_2_SquamousCellCarcinoma_3_Others',
        'pT_Stage', 'pN_Stage', 'pM_Stage', 'pTNM', 'RFS_Status', 'RFS_Month',
       'OS_Status', 'OS_Month']]
stage1 = list(np.load("/home/jielian/lung-graph-project/data/seg_image/labels/name_stage1.npy"))
stage2 = list(np.load("/home/jielian/lung-graph-project/data/seg_image/labels/name_stage2.npy"))
patint_list = [*stage1, *stage2]
# Take an explicit copy: the rows are selected with a boolean mask and the frame is
# modified right below, which otherwise raises pandas' SettingWithCopy warning and
# leaves it ambiguous whether a temporary slice is being written to.
patient_info = all_patient_info[all_patient_info['folder_name'].isin(patint_list)].copy()

# Collapse T sub-stages into ordinal codes (T1* -> 0, T2* -> 1, T3 -> 2) and encode M1a as 1.
patient_info['pT_Stage']=patient_info['pT_Stage'].replace({"T1a":0, "T1b":0, "T1c":0, "T2a":1,"T2b":1,"T3":2})
patient_info['pM_Stage']=patient_info['pM_Stage'].replace({"M1a":1})

# File names in this directory drive the feature-loading loop in the next cell.
feature_files = os.listdir("trans_feature")

In [None]:
trans_feature = "/home/jielian/lung-graph-project/data/seg_image/medicalnet/tumor_medical/"

# Build one pooled feature vector per entry of `feature_files`.
data = []
name = []
for feature_name in feature_files:
    # Drop the last four characters (the file extension) to get the patient id.
    patient_stem = feature_name[:-4]
    path = trans_feature + patient_stem + "_d64.npy"
    name.append(int(patient_stem))
    # NOTE(review): allow_pickle=True will unpickle arbitrary objects; only load trusted files.
    feature = np.load(path, allow_pickle=True)
    # View the array as a stack of 8x8x8 blocks and average each block to a single value.
    af = feature.reshape((-1, 8, 8, 8))
    af_mean = af.mean(axis=(1, 2, 3))
    data.append(af_mean)

# Column-wise min-max scaling across patients (a constant column yields NaN from 0/0).
feature_data = pd.DataFrame(data)
feature_min = feature_data.min()
feature_span = feature_data.max() - feature_min
feature_data = (feature_data - feature_min) / feature_span
feature_data['folder_name'] = name

# Left join keeps every clinical row; patients without a feature file get NaN features.
all_data = pd.merge(patient_info, feature_data, how='left', on='folder_name')
print(len(all_data))

In [None]:
# Recode location value 4 as 3 (per the column name: 4 = RML, 3 = RUL) in the SHPH cohort.
location_col = 'Location_1_LUL_2_LLL_3_RUL_4_RML_5_RLL'
all_data[location_col] = all_data[location_col].replace({4: 3})

In [None]:
# Load the external cohort labels and align their column names with the SHPH table.
external_info = pd.read_csv("/home/jielian/lung-graph-project/Tumor_tranformer/data_ind/External_label.csv")
# Keep the patient identifiers before the column is renamed; they drive feature loading later.
external_patint_list = external_info['Patient']
external_info = external_info.rename(columns={
    "Patient": "folder_name",
    "Histology": "Histology_1_Adenocarcinoma_2_SquamousCellCarcinoma_3_Others",
})
# Collapse T sub-stages into ordinal codes (Tis/T1* -> 0, T2* -> 1, T3 -> 2, T4 -> 3).
pt_stage_codes = {"Tis":0,"T1a":0, "T1b":0, "T1c":0, "T2a":1,"T2b":1,"T3":2,"T4":3 }
external_info['pT_Stage'] = external_info['pT_Stage'].replace(pt_stage_codes)


In [None]:
external_feature_files = os.listdir("trans_feature_val")

# Recode location value 4 as 3 (per the column name: 4 = RML, 3 = RUL) in the external cohort.
external_location_col = 'Location_1_LUL_2_LLL_3_RUL_4_RML_5_RLL'
external_info[external_location_col] = external_info[external_location_col].replace({4: 3})

# One pooled feature vector per external patient, in the order of `external_patint_list`.
external_data = []
for feature_name in external_patint_list:
    path = "/tumor/" + feature_name + "_d64.npy"
    # NOTE(review): allow_pickle=True will unpickle arbitrary objects; only load trusted files.
    feature = np.load(path, allow_pickle=True)
    # View the array as a stack of 8x8x8 blocks and average each block to a single value.
    af = feature.reshape((-1, 8, 8, 8))
    af_mean = af.mean(axis=(1, 2, 3))
    external_data.append(af_mean)

# Column-wise min-max scaling using the external cohort's own minima and maxima.
external_feature_data = pd.DataFrame(external_data)
external_feature_min = external_feature_data.min()
external_feature_span = external_feature_data.max() - external_feature_min
external_feature_data = (external_feature_data - external_feature_min) / external_feature_span
external_feature_data['folder_name'] = external_patint_list

# Left join keeps every row of the external label table.
external_all_data = pd.merge(external_info, external_feature_data, how='left', on='folder_name')
print(len(external_all_data))

In [None]:
# Stack the two cohorts row-wise, SHPH rows first and external rows after them.
# Row labels are not reset here, so positional (.iloc) access is what later cells rely on.
frames = [all_data, external_all_data]
final_data = pd.concat(frames, axis=0)

In [None]:
# Positional row indices of each split inside `all_data`.
train_id = np.load("data_ind/train_index.npy",allow_pickle=True)
val_id = np.load("data_ind/val_index.npy", allow_pickle=True)
test_id = np.load("data_ind/test_index.npy",allow_pickle=True)
# External patients are the rows appended after the SHPH rows in `final_data`.
external_id = np.array(range(len(all_data),len(final_data)))

idx_train, idx_val, idx_test, idx_external_val = (
    torch.LongTensor(ids) for ids in (train_id, val_id, test_id, external_id)
)

# (heading, source frame, row positions, status column) for every report printed below.
distribution_reports = [
    ("training OS distribution:", all_data, train_id, 'OS_Status'),
    ("validation OS distribution:", all_data, val_id, 'OS_Status'),
    ("test OS distribution:", all_data, test_id, 'OS_Status'),
    ("External OS distribution:", final_data, external_id, 'OS_Status'),
    ("training RFS_Status distribution:", all_data, train_id, 'RFS_Status'),
    ("validation RFS_Status distribution:", all_data, val_id, 'RFS_Status'),
    ("test RFS_Status distribution:", all_data, test_id, 'RFS_Status'),
    ("External RFS_Status distribution:", final_data, external_id, 'RFS_Status'),
]
for heading, frame, rows, status_col in distribution_reports:
    print(heading)
    print(frame.iloc[rows,:][status_col].value_counts())

# Start Graph Building!!

In [None]:
# Similarity score between two patients.
def SimScore(a1,a2,s1,s2,l1,l2,h1,h2,t1,t2,n1,n2,m1,m2,tnm1,tnm2):
    """Return the product of four agreement scores between patient 1 and patient 2.

    Arguments come in pairs (value for patient 1, value for patient 2):
    age (a), sex (s), location (l), histology (h), T stage (t), N stage (n),
    M stage (m) and overall TNM (tnm). The tnm pair is accepted but not used.

    The demographic score counts matching sex and ages within 5 years (0-2),
    the staging score counts matching T, N and M values (0-3), and location
    and histology each contribute 0 or 1. Because the result is a product,
    any group with no agreement makes the whole score 0; the maximum is 6.
    """
    demographic = (1 if s1 == s2 else 0) + (1 if abs(a1 - a2) <= 5 else 0)
    location = 1 if l1 == l2 else 0
    histology = 1 if h1 == h2 else 0
    staging = sum(1 if first == second else 0
                  for first, second in ((t1, t2), (n1, n2), (m1, m2)))
    return demographic * staging * histology * location


def adj_matrix(patient_info):
    """Build the patient-similarity adjacency matrix and its edge list.

    :param patient_info: DataFrame with the age, sex, location, histology,
        pT/pN/pM stage and pTNM columns read below.
    :return: (adj, edge_list, edge_wight) where adj is an (n, n) float array of
        SimScore values for every ordered pair (i, j), including i == j;
        edge_list holds [i, j] for each non-zero entry in row-major order and
        edge_wight holds the matching adj[i, j] values.
    """
    age = patient_info['Age'].to_list()
    sex = patient_info['Sex_1_male_2_female'].to_list()
    loc = patient_info['Location_1_LUL_2_LLL_3_RUL_4_RML_5_RLL'].to_list()
    his = patient_info['Histology_1_Adenocarcinoma_2_SquamousCellCarcinoma_3_Others'].to_list()
    pts = patient_info['pT_Stage'].to_list()
    pns = patient_info['pN_Stage'].to_list()
    pms = patient_info['pM_Stage'].to_list()
    tnm = patient_info['pTNM'].to_list()

    # One tuple per patient, fields in the order SimScore expects its argument pairs.
    records = list(zip(age, sex, loc, his, pts, pns, pms, tnm))
    n_sample = len(age)
    adj = np.zeros((n_sample, n_sample))
    edge_list = []
    edge_wight = []
    for i, left in enumerate(records):
        for j, right in enumerate(records):
            # Interleave the two records: (age_i, age_j, sex_i, sex_j, ...).
            paired_args = [value for pair in zip(left, right) for value in pair]
            adj[i, j] = SimScore(*paired_args)
            if adj[i, j] == 0:
                continue
            edge_list.append([i, j])
            edge_wight.append(adj[i, j])
    return adj, edge_list, edge_wight

In [None]:

def graph_bulider(all_data, start_cloumn = 13, event = "OS_Status", label = "OS_Month"):
    """Turn a patient table into a DGL graph with similarity edges.

    :param all_data: DataFrame holding the clinical columns used by adj_matrix,
        the `event` and `label` columns, and node features from column
        position `start_cloumn` onward.
    :param start_cloumn: positional index of the first node-feature column.
    :param event: name of the event-indicator column stored as ndata['event'].
    :param label: name of the time column stored as ndata['label'].
    :return: (adjacency matrix from adj_matrix, DGL graph). The graph carries
        ndata 'h' (float features), 'event', 'label' and edata 'w' (float
        edge weights).
    """
    # Node-level targets.
    time_column = all_data[label]
    node_labels = torch.from_numpy(time_column.to_numpy())
    node_events = torch.from_numpy(all_data[event].to_numpy())

    # Pairwise similarity defines the edges and their weights.
    adjacency, edges, weights = adj_matrix(all_data)
    n_nodes = len(time_column)
    n_edges = len(edges)
    print("the number of nodes in this graph:", n_nodes)
    print("the number of edges in this graph:", n_edges)
    print("Number of average degree: ", n_edges / n_nodes)

    # Nodes first, then their features and targets.
    graph = dgl.DGLGraph()
    graph.add_nodes(len(node_labels))
    feature_matrix = all_data.iloc[:, start_cloumn:].to_numpy()
    graph.ndata['h'] = torch.from_numpy(feature_matrix).float()
    graph.ndata['event'] = node_events
    graph.ndata['label'] = node_labels

    # Edges are added in edge-list order so the weight vector lines up with them.
    src, dst = tuple(zip(*edges))
    graph.add_edges(src, dst)
    graph.edata['w'] = torch.from_numpy(np.array(weights)).float()
    return adjacency, graph

# Network and Loss

In [None]:
class SAGE(nn.Module):
    """Linear projection, two SAGEConv layers, then a linear head with one output per node."""

    def __init__(self, in_feats, hid_feats, out_feats, dropout=0, activation = None,aggregator_type='mean'):
        super().__init__()
        # Settings shared by both graph convolutions.
        conv_options = dict(aggregator_type=aggregator_type, activation=activation, feat_drop=dropout)
        self.fc1 = nn.Linear(in_feats, hid_feats)
        self.conv1 = SAGEConv(in_feats=hid_feats, out_feats=hid_feats, **conv_options)
        self.conv2 = SAGEConv(in_feats=hid_feats, out_feats=out_feats, **conv_options)
        self.fc2 = nn.Linear(out_feats, 1)

    def forward(self, graph, inputs, w_input):
        """Return a (num_nodes, 1) tensor computed from node features `inputs`.

        `w_input` is handed to the first SAGEConv only; the second convolution
        is called without it.
        """
        projected = self.fc1(inputs)
        first_hop = self.conv1(graph, projected, w_input)
        second_hop = self.conv2(graph, first_hop)
        return self.fc2(second_hop)
    
class TAG(nn.Module):
    """Linear projection, two TAGConv layers, then a linear head with one output per node."""

    # NOTE(review): the default activation is F.softmax with no explicit `dim`;
    # confirm the axis the conv layers apply it over is the intended one.
    def __init__(self, in_feats, hid_feats, out_feats, activation = F.softmax):
        super().__init__()
        self.fc1 = nn.Linear(in_feats, hid_feats)
        self.conv1 = TAGConv(in_feats=hid_feats, out_feats=hid_feats, activation=activation)
        self.conv2 = TAGConv(in_feats=hid_feats, out_feats=out_feats, activation=activation)
        self.fc2 = nn.Linear(out_feats, 1)

    def forward(self, graph, inputs, w_input):
        """Return a (num_nodes, 1) tensor computed from node features `inputs`.

        `w_input` is accepted so all models share one call signature; it is not used here.
        """
        projected = self.fc1(inputs)
        first_hop = self.conv1(graph, projected)
        second_hop = self.conv2(graph, first_hop)
        return self.fc2(second_hop)



    
class GCN(nn.Module):
    """Linear projection, two GraphConv layers, then a linear head with one output per node."""

    # NOTE(review): the default activation is F.softmax with no explicit `dim`;
    # confirm the axis the conv layers apply it over is the intended one.
    def __init__(self, in_feats, hid_feats, out_feats, activation = F.softmax, norm ="both"):
        super().__init__()
        # Width between the two graph convolutions is fixed, independent of hid_feats.
        conv_width = 32
        self.fc1 = nn.Linear(in_feats, hid_feats)
        self.conv1 = GraphConv(in_feats=hid_feats, out_feats=conv_width, activation=activation, norm=norm)
        self.conv2 = GraphConv(in_feats=conv_width, out_feats=out_feats, activation=activation, norm=norm)
        self.fc2 = nn.Linear(out_feats, 1)

    def forward(self, graph, inputs, w_input):
        """Return a (num_nodes, 1) tensor computed from node features `inputs`.

        `w_input` is accepted so all models share one call signature; it is not used here.
        """
        projected = self.fc1(inputs)
        first_hop = self.conv1(graph, projected)
        second_hop = self.conv2(graph, first_hop)
        return self.fc2(second_hop)
    
    
class SAGE1L(nn.Module):
    """Single-convolution variant: Linear projection, one SAGEConv, then a linear head."""

    def __init__(self, in_feats, hid_feats, out_feats, dropout=0, activation = None,aggregator_type='mean'):
        super().__init__()
        self.fc1 = nn.Linear(in_feats, hid_feats)
        self.conv1 = SAGEConv(
            in_feats=hid_feats,
            out_feats=out_feats,
            aggregator_type=aggregator_type,
            activation=activation,
            feat_drop=dropout,
        )
        self.fc2 = nn.Linear(out_feats, 1)

    def forward(self, graph, inputs, w_input):
        """Return a (num_nodes, 1) tensor; `w_input` is passed to the SAGEConv layer."""
        projected = self.fc1(inputs)
        aggregated = self.conv1(graph, projected, w_input)
        return self.fc2(aggregated)
    


In [None]:
class Regularization(object):
    """Weight-norm penalty computed over a model's parameters."""

    def __init__(self, order, weight_decay):
        '''Store the penalty settings.
        :param order: (int) order p of the norm applied to each weight tensor
        :param weight_decay: (float) factor the summed norms are multiplied by
        '''
        super(Regularization, self).__init__()
        self.order = order
        self.weight_decay = weight_decay

    def __call__(self, model):
        '''Compute the penalty for `model`.
        :param model: (torch.nn.Module object)
        :return reg_loss: weight_decay times the sum of p-norms of every
            parameter whose name contains 'weight' (biases are skipped);
            0 scaled by weight_decay when there is no such parameter
        '''
        # sum() starts from the integer 0, so the accumulation matches a manual loop.
        total_norm = sum(
            torch.norm(param, p=self.order)
            for param_name, param in model.named_parameters()
            if 'weight' in param_name
        )
        return self.weight_decay * total_norm

    
class NegativeLogLikelihood(nn.Module):
    """Cox-style negative log partial likelihood plus an L2 weight penalty.

    :param l2_reg: (float) factor passed to Regularization(order=2) as weight_decay
    :param device: device on which the risk-set mask is created
    """

    def __init__(self, l2_reg, device):
        super(NegativeLogLikelihood, self).__init__()
        self.L2_reg = l2_reg
        self.device = device
        self.reg = Regularization(order=2, weight_decay=self.L2_reg)

    def forward(self, risk_pred, y, e, model):
        """Compute the loss for one set of samples.

        :param risk_pred: predicted risk per sample, shape (N,) or (N, 1)
        :param y: observed time per sample, shape (N,) or (N, 1)
        :param e: event indicator per sample, shape (N,) or (N, 1)
        :param model: module whose weights are penalised by self.reg
        :return: scalar tensor, the event-averaged negative log likelihood plus
            the L2 penalty (not finite when `e` sums to zero)
        """
        # Force column vectors. With 1-D inputs `y.T - y` is all zeros, so the
        # risk-set mask stayed all ones and `(risk_pred - log_loss) * e`
        # broadcast to (N, N), making the loss independent of times and events.
        risk_pred = risk_pred.reshape(-1, 1)
        y = y.reshape(-1, 1)
        e = e.reshape(-1, 1)

        # mask[i, j] == 1 when y[i] >= y[j], i.e. sample i is still at risk at time y[j].
        mask = torch.ones(y.shape[0], y.shape[0]).to(self.device)
        mask[(y.T - y) > 0] = 0
        # Column j: mean of exp(risk) over the risk set of sample j.
        log_loss = torch.exp(risk_pred) * mask
        log_loss = torch.sum(log_loss, dim=0) / torch.sum(mask, dim=0)
        log_loss = torch.log(log_loss).reshape(-1, 1)
        # Only samples with an observed event (e == 1) contribute.
        neg_log_loss = -torch.sum((risk_pred-log_loss) * e) / torch.sum(e)
        l2_loss = self.reg(model)
        return neg_log_loss + l2_loss

In [None]:
# Graph over both cohorts, followed by a graph over the SHPH cohort alone.
adj_all, g_all = graph_bulider(final_data)
adj_sh, g_sh = graph_bulider(all_data)
# External patients are the nodes appended after the SHPH rows in final_data.
external_node_ids = list(range(len(all_data), len(final_data)))
g_external = dgl.node_subgraph(g_all, external_node_ids)

# Training

In [None]:
def train(g, g_all, model,device, save_dic, idx_train,idx_val, idx_test, total_epoch=100, patience=5, lr=0.001, reg_l2=0, weight_decay=0.0001):
    model_name = save_dic['model']+str(save_dic['hid_feats'])+str(save_dic['out_feats'])+str(save_dic['reg_l2'])+save_dic["aggregator_type"]
    optimizer = torch.optim.Adam(model.parameters(),lr=lr, weight_decay=weight_decay)
    best_cindex = 0
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min',factor=0.5, patience=patience, min_lr = 0.0001, verbose=True)
    criterion = NegativeLogLikelihood(reg_l2, device).to(device) 
    model = model.to(device) 
    features = g.ndata['h'].to(device) 
    e_feature = g.edata['w'].to(device) 
    labels = g.ndata['label'].to(device) 
    events = g.ndata['event'].to(device) 
    g_all= g_all.to(device)  
    g = g.to(device) 
    t_total = time.time()
    with tqdm(range(total_epoch)) as t:
        for epoch in t:
            t.set_description('Epoch %d' % epoch)
            start = time.time()
            model.train()
            optimizer.zero_grad()
            output = model(g, features,e_feature)
            # Compute loss
            # Note that you should only compute the losses of the nodes in the training set.
            loss_train = criterion(output[idx_train], labels[idx_train],events[idx_train], model).clone()
            auc_train = c_index(-output[idx_train], labels[idx_train],events[idx_train])
            
            loss_train.backward(retain_graph=True)
            optimizer.step()
            
            model.eval()
            val_output = model(g, features,e_feature)
            loss_val = criterion(val_output[idx_val], labels[idx_val],events[idx_val], model).clone()
            scheduler.step(loss_val)
            
            auc_val = c_index(-val_output[idx_val], labels[idx_val],events[idx_val])
            auc_test = c_index(-val_output[idx_test], labels[idx_test], events[idx_test])
            exter_val_output = model(g_all, g_all.ndata['h'],g_all.edata['w'])

            t.set_postfix(
                  {"train_loss":loss_train.item(), "val_loss":loss_val.item(),
                  "train_cindex":auc_train.item(), "val_auc":auc_val.item(),
                "lr":optimizer.param_groups[0]['lr']}) 