In [None]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.nn.conv import APPNP
import networkx as nx 

# Run on the first CUDA device when one is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
print(device)

In [None]:
# Single APPNP model: an MLP produces per-node scores, APPNP propagates them over the graph.
class APPNP_model(nn.Module):
    """Two-layer MLP followed by an ``APPNP`` propagation layer.

    The attribute names ``lin1`` / ``l1`` and the order of the layers inside
    ``lin1`` define the ``state_dict`` keys, so they must not change if saved
    checkpoints are to keep loading.

    Args:
        in_feats: number of input features per node.
        hid_feats: width of the hidden linear layer.
        out_feats: number of output scores per node.
        dropout: dropout probability applied before each linear layer.
        k: value passed to ``APPNP`` as ``K``.
        alpha: value passed to ``APPNP`` as ``alpha``.
        edge_drop: value passed to ``APPNP`` as ``dropout``.
        normalize: accepted for signature compatibility; not used by this class.
    """

    def __init__(self, in_feats, hid_feats, out_feats=3, dropout=0, k=10, alpha=0.1, edge_drop=0, normalize=False):
        super().__init__()
        mlp_layers = [
            nn.Dropout(p=dropout),
            nn.Linear(in_feats, hid_feats),
            nn.ReLU(),
            nn.Dropout(p=dropout),
            nn.Linear(hid_feats, out_feats),
        ]
        self.lin1 = nn.Sequential(*mlp_layers)
        # Created for parity with the original definition; forward() does not apply it.
        self.drop_out = nn.Dropout(p=dropout)
        self.l1 = APPNP(K=k, alpha=alpha, dropout=edge_drop)

    def forward(self, x, edge):
        """Return the propagated MLP output for features ``x`` and edge index ``edge``.

        No softmax / log-softmax is applied to the result.
        """
        return self.l1(self.lin1(x), edge)

In [None]:
# Define the task.
path = ''  # prefix prepended as-is to every data file name
task = ''  # task name embedded in the csv file names
label_tar = [
    'Label_NHY',
    'Label_NP1TOT',
    'Label_NP2PTOT',
    'Label_NP3TOT',
    'Label_NPTOT',
    'Label_NP3_Axial',
    'Label_NP3_Tremor',
    'Label_NP3_Akinetic_Rigid',
    'Label_MoCA_score',
    'Label_ESS_TOT',
]
# Load the train split, then the test split.
train_data, test_data = (
    pd.read_csv(path + 'ppmi_' + task + split + '.csv') for split in ('_train', '_test')
)
# Load the node-feature column names (used later to select columns of the patient table).
feature_cols = np.loadtxt(path + 'ppmi_node.txt', dtype=str)
# Stack train and test rows into one table with a fresh 0..n-1 index.
f = [train_data, test_data]
pd_patients_all = pd.concat(f).reset_index(drop=True)

In [None]:
# Model and record locations.
m_path= ''  # prefix of the saved model state-dict files
rec_path = '' # path of the best-results error record file
# Read the record inside a context manager so the file handle is always closed.
with open(rec_path, 'r') as file:
    record = file.read()
# Split the text on the "'edge':" marker and drop whatever precedes the first marker,
# leaving one string per recorded graph.
record = record.split("'edge':")[1:]
# Take the text after the hyper-parameter name list in the last entry, trim two
# characters from each end, and split on whitespace into raw tokens.
# NOTE(review): the tokens still carry punctuation that the model-construction cell
# slices off; this relies on the exact formatting of the record file — confirm.
modelf = record[-1].split("['hid', 'dropout', 'k', 'alpha', 'weight decay', 'lr'],")[-1][2:-2].split()

In [None]:
def Sim_func(a1,a2,thresh):
    """Return 1 when ``a1`` and ``a2`` differ by at most ``thresh``, otherwise 0."""
    return 1 if abs(a1 - a2) <= thresh else 0

def age_adj_matrix(patient_info, edge_f, thresh):
    """Build a thresholded similarity graph over the rows of ``patient_info``.

    Rows ``i`` and ``j`` are connected when ``abs(v_i - v_j) <= thresh`` for
    the values ``v`` of column ``edge_f`` (the same rule as ``Sim_func``). All
    ordered pairs are evaluated at once with NumPy broadcasting instead of one
    Python call per pair. Because ``i == j`` is included, every row whose value
    is not NaN gets a self-loop whenever ``thresh >= 0``.

    Args:
        patient_info: table (pandas DataFrame) holding one row per sample.
        edge_f: name of the numeric column the similarity is computed on.
        thresh: maximum absolute difference for two rows to be connected.

    Returns:
        A tuple ``(adj, edge_list, edge_wight)`` where ``adj`` is an
        ``(n, n)`` float array of 0/1 entries, ``edge_list`` is a list of
        ``[i, j]`` index pairs in row-major order for the non-zero entries,
        and ``edge_wight`` is the list of the matching ``adj`` values.
    """
    values = patient_info[edge_f].to_numpy(dtype=float)
    # Pairwise absolute differences via broadcasting: (n, 1) - (1, n) -> (n, n).
    adj = (np.abs(values[:, None] - values[None, :]) <= thresh).astype(float)
    # np.nonzero walks the matrix in row-major order, matching a nested i/j loop.
    rows, cols = np.nonzero(adj)
    edge_list = np.column_stack((rows, cols)).tolist()
    edge_wight = adj[rows, cols].tolist()
    return adj, edge_list, edge_wight

def age_graph_bulider(all_data, label, feature_cols, edge_f, thresh):
    """Build a PyTorch Geometric ``Data`` graph from the patient table.

    Args:
        all_data: table (pandas DataFrame) with one row per node.
        label: name of the column used as node label ``y`` (cast to int64).
        feature_cols: column names used as node features ``x`` (cast to float32).
        edge_f: column the edges are derived from (see ``age_adj_matrix``).
        thresh: similarity threshold forwarded to ``age_adj_matrix``.

    Returns:
        A ``Data`` object with ``x``, ``edge_index`` of shape ``(2, E)`` and ``y``.
    """
    labels_sh = torch.from_numpy(all_data[label].to_numpy()).long()
    node_feature_sh = torch.from_numpy(all_data[feature_cols].to_numpy()).float()
    _, edge_list_sh, _ = age_adj_matrix(all_data, edge_f, thresh)
    # An explicit dtype and reshape keep edge_index an int64 tensor of shape (2, E)
    # even when the edge list is empty (a bare torch.tensor([]) is 1-D float).
    edge_index = torch.tensor(edge_list_sh, dtype=torch.long).reshape(-1, 2).t().contiguous()
    g_sh = Data(x=node_feature_sh, edge_index=edge_index, y=labels_sh)
    return g_sh


def nx_graph_bulider(all_data, label, feature_cols, edge_f, thresh):
    """Build a NetworkX graph over the patient table with a ``"label"`` node attribute.

    Args:
        all_data: table (pandas DataFrame) with one row per node.
        label: name of the column stored on each node as ``"label"``.
        feature_cols: not used here; kept so the signature matches
            ``age_graph_bulider``.
        edge_f: column the edges are derived from (see ``age_adj_matrix``).
        thresh: similarity threshold forwarded to ``age_adj_matrix``.

    Returns:
        An undirected ``nx.Graph`` whose nodes are the row positions
        ``0..n-1`` of ``all_data``.
    """
    labels_sh = all_data[label].to_numpy()
    adj_sh, _, _ = age_adj_matrix(all_data, edge_f, thresh)
    g = nx.Graph()
    # Register every row as a node up front: a row without any edge would otherwise
    # be missing from the graph and later node lookups would raise KeyError.
    g.add_nodes_from(range(len(labels_sh)))
    rows, cols = np.where(adj_sh == 1)
    g.add_edges_from(zip(rows.tolist(), cols.tolist()))
    for node, node_label in enumerate(labels_sh):
        g.nodes[node]["label"] = node_label
    return g
    
def re_build_graph(record, i, c_label, re_cal_thred=False):
    """Rebuild the PyG and NetworkX graphs described by entry ``i`` of ``record``.

    The entry is split on commas: the first field (minus its first two and last
    characters) is the edge-feature column name, and the text after ``': '`` in
    the second field is the threshold.
    NOTE(review): the slicing presumably strips a leading space/quote and a
    trailing quote — confirm against the record file format.

    Args:
        record: list of per-graph record strings.
        i: index of the entry to rebuild.
        c_label: label column passed to the graph builders.
        re_cal_thred: when true, the parsed number is treated as a quantile and
            replaced by that quantile of the edge-feature column in
            ``pd_patients_all``.

    Returns:
        ``(edge_feature_name, pyg_graph, networkx_graph)``.
    """
    fields = record[i].split(',')
    print(fields)
    edge_feature = fields[0][2:-1]
    threshold = float(fields[1].split(': ')[1])
    print(edge_feature)
    if re_cal_thred:
        threshold = pd_patients_all[edge_feature].quantile(threshold)
    print(threshold)

    # Both builders take the same arguments; the PyG graph is built first.
    builder_args = (pd_patients_all, c_label, feature_cols, edge_feature, threshold)
    pyg_graph = age_graph_bulider(*builder_args)
    nx_graph = nx_graph_bulider(*builder_args)
    return edge_feature, pyg_graph, nx_graph

In [None]:
i=0 # the first important graph
edge, graph, g_nx = re_build_graph(record, i, 'Label_NHY', re_cal_thred=True)
model_path = m_path+str(i)+'.pt'
print(model_path)
# Constructor arguments: in_feats, hid_feats, out_feats=3, dropout=0, k=10, alpha=0.1.
# The slices strip the punctuation left on each raw token of `modelf`.
model = APPNP_model(
    in_feats=len(feature_cols),
    hid_feats=int(modelf[0][:-1]),
    out_feats=3,
    dropout=float(modelf[1][:-1]),
    k=int(modelf[2][1:-2]),
    alpha=float(modelf[3][:-1]),
)
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
model.load_state_dict(torch.load(model_path, map_location=device))
# Switch to eval mode so the dropout layers are inactive while the model is inspected.
model = model.to(device).eval()

Single-graph explanation

In [None]:
from torch_geometric.explain import Explainer, GNNExplainer

In [None]:
# Explain the first important graph.
graph_device=graph.to(device)
explainer = Explainer(
    model=model,
    algorithm=GNNExplainer(epochs=200), # you can define your own explainer
    explanation_type='model',
    node_mask_type='attributes',
    edge_mask_type='object',
    model_config=dict(
        mode='multiclass_classification',
        task_level='node',
        # APPNP_model.forward returns the propagated scores without any
        # softmax / log-softmax, so they must be declared as raw outputs;
        # 'log_probs' would apply an NLL loss to unnormalised scores.
        return_type='raw',
    ),
)
explanation = explainer(graph_device.x, graph_device.edge_index)
print(f'Generated explanations in {explanation.available_explanations}')

# NOTE: this rebinds the global `path` that the data-loading cell also sets.
path = ''
explanation.visualize_feature_importance(path, top_k=15)

In [None]:
# Node id -> label lookup for the NetworkX graph.
d = dict(g_nx.nodes(data="label"))

ind = 72 # find the interested patient id
# NOTE: if the graph has a self-loop on `ind`, `ind` is listed among its own neighbours.
ns = list(g_nx.neighbors(ind))
print(len(ns))
print(g_nx.nodes[ind]['label'])

labelss = [d[nbr] for nbr in ns]
df = pd.DataFrame({'neighbor': ns, 'label': labelss})
df['label'].value_counts() # get the neighbour labels distribution