In [2]:
import pickle
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
import networkx as nx
from tqdm import tqdm
import heapq
import numpy as np

with open('./full_graph_split1_eval.pkl', 'rb') as f:
    res = pickle.load(f)
result = pd.DataFrame(res['result'])

with open('./graphmask_output_indication.pkl', 'rb') as f:
    d_gm = pickle.load(f)
    
with open('./attention_output_indication.pkl', 'rb') as f:
    d_att = pickle.load(f)
    
with open('./gnnexplainer_output_indication.pkl', 'rb') as f:
    d_ge = pickle.load(f)

In [3]:
def preprocess(d):
    d = d[~d.y_name.str.contains('CYP')]
    d = d[~d.x_name.str.contains('CYP')]
    d = d.rename(columns = {'indication_layer1_att': 'layer1_att', 'indication_layer2_att': 'layer2_att'})
    return d

In [4]:
d_gm = preprocess(d_gm)
d_ge = preprocess(d_ge)
d_att = preprocess(d_att)

In [176]:
def build_graph(df):
    G = nx.MultiDiGraph()
    for _, row in tqdm(df.iterrows()):
        # Add nodes with their types
        G.add_node(row['x_name'], node_type=row['x_type'])
        G.add_node(row['y_name'], node_type=row['y_type'])

        # Add edges with relation type and weights
        G.add_edge(row['x_name'], row['y_name'], relation=row['relation'],
                   layer1_att=row['layer1_att'], layer2_att=row['layer2_att'])
    return G

def get_two_hop_neighborhood(G, node_id):
    # Get neighbors within two hops
    neighbors_one_hop = set(nx.all_neighbors(G, node_id))
    neighbors_two_hop = set()
    for n in neighbors_one_hop:
        neighbors_two_hop.update(nx.all_neighbors(G, n))
    return neighbors_one_hop.union(neighbors_two_hop)

def find_relation_specific_paths(G, start_id, end_id, max_depth=4):
    paths = []
    for path in nx.all_simple_paths(G, source=start_id, target=end_id, cutoff=max_depth):
        # Construct relation-specific path
        relation_path = [start_id]
        for i in range(len(path) - 1):
            relation = G[path[i]][path[i + 1]][0]['relation']  # Assuming one relation per edge for simplicity
            relation_path.extend([relation, path[i + 1]])
        paths.append(relation_path)
    return paths

def get_two_hop_neighborhood_enrichment_per_relation(G, node_id, K, K2, relation_averages, enrichment = True):
    
    # First hop: Get the top K neighbors for each relation type
    neighbors_by_relation = {}
    for neighbor in G[node_id]:
        for edge_key in G[node_id][neighbor]:
            edge_data = G[node_id][neighbor][edge_key]
            relation = edge_data['relation']
            weight = edge_data['layer1_att'] + edge_data['layer2_att']
            if enrichment:
                avg_weight = relation_averages[relation]
                relative_increase = ((weight - avg_weight) / avg_weight) * 100
            else:
                relative_increase = weight
            if relation not in neighbors_by_relation:
                neighbors_by_relation[relation] = []
            heapq.heappush(neighbors_by_relation[relation], (-relative_increase, neighbor))

    # Select the top K neighbors for each relation type
    first_hop_neighbors = set()
    for relation, neighbors in neighbors_by_relation.items():
        top_neighbors = [neighbor for _, neighbor in heapq.nlargest(K, neighbors)]
        first_hop_neighbors.update(top_neighbors)

    # Second hop: Repeat the process for each neighbor in first_hop_neighbors
    second_hop_neighbors = set()
    for first_hop_neighbor in first_hop_neighbors:
        neighbors_by_relation = {}
        for neighbor in G[first_hop_neighbor]:
            for edge_key in G[first_hop_neighbor][neighbor]:
                edge_data = G[first_hop_neighbor][neighbor][edge_key]
                relation = edge_data['relation']
                weight = edge_data['layer1_att'] + edge_data['layer2_att']
                
                if enrichment:
                    avg_weight = relation_averages[relation]
                    relative_increase = ((weight - avg_weight) / avg_weight) * 100
                else:
                    relative_increase = weight

                if relation not in neighbors_by_relation:
                    neighbors_by_relation[relation] = []
                heapq.heappush(neighbors_by_relation[relation], (-relative_increase, neighbor))

        for relation, neighbors in neighbors_by_relation.items():
            top_neighbors = [neighbor for _, neighbor in heapq.nlargest(K, neighbors)]
            second_hop_neighbors.update(top_neighbors)

    return first_hop_neighbors.union(second_hop_neighbors)

def score_path_enrichment(G, path, relation_averages, enrichment = True):
    score = 0
    path_length = len(path) // 2  # Number of edges in the path

    for i in range(0, path_length * 2, 2):
        node1 = path[i]
        relation = path[i + 1]
        node2 = path[i + 2]

        if (node1, node2) in G.edges():
            for edge_key in G[node1][node2]:
                edge_data = G[node1][node2][edge_key]
                if edge_data['relation'] == relation:
                    weight = edge_data['layer1_att'] + edge_data['layer2_att']
                    if enrichment:
                        avg_weight = relation_averages[relation]
                        # Calculate percentage relative increase
                        relative_increase = ((weight - avg_weight) / avg_weight) * 100
                    else:
                        relative_increase = weight
                    score += relative_increase

    score /= path_length

    return score

def print_beautiful_path(path):
    path = [i if 'rev' not in i else i[4:] for i in path]
    return ' -> '.join(path)

def print_beautiful_paths(paths):
    return [print_beautiful_path(i) for i in paths]

def group_paths_by_meta_paths_with_node_types(paths, G):
    meta_paths_dict = {}

    for path in paths:
        # Extract the meta path with node types and relation types
        meta_path = []
        for i in range(len(path)):
            if i % 2 == 0:  # Node
                node_id = path[i]
                node_type = G.nodes[node_id]['node_type']
                meta_path.append(node_type)
            else:  # Relation
                meta_path.append(path[i])

        meta_path = tuple(meta_path)

        # Add the path to the list of paths for the corresponding meta path
        if meta_path not in meta_paths_dict:
            meta_paths_dict[meta_path] = []
        meta_paths_dict[meta_path].append(path)

    return meta_paths_dict

def calculate_relation_averages(G, layer_only = False):
    relation_sums = {}
    relation_counts = {}

    # Iterate over all edges in the graph
    for u, v, data in tqdm(G.edges(data=True)):
        relation = data['relation']
        if layer_only:
            weight = data['layer' + str(layer_only) + '_att']
        else:
            weight = data['layer1_att'] + data['layer2_att']
        
        # Summing weights and counting occurrences for each relation
        if relation in relation_sums:
            relation_sums[relation] += weight
            relation_counts[relation] += 1
        else:
            relation_sums[relation] = weight
            relation_counts[relation] = 1

    # Calculating averages
    relation_averages = {rel: relation_sums[rel] / relation_counts[rel] for rel in relation_sums}
    return relation_averages

In [175]:
def find_meta_paths(X_id, Y_id, G, not_cool_rel, relation_averages, enrichment = True):
    neighborhood_X = get_two_hop_neighborhood_enrichment_per_relation(G, X_id, K, K2, relation_averages, enrichment)
    neighborhood_Y = get_two_hop_neighborhood_enrichment_per_relation(G, Y_id, K, K2, relation_averages, enrichment)
    common_neighborhood = neighborhood_X.union(neighborhood_Y)
    subG = G.subgraph(common_neighborhood)
    #print('Number of neighbors: ', len(subG))
    paths = find_relation_specific_paths(subG, X_id, Y_id, max_depth=4)
    path_scores_enrichment = [score_path_enrichment(G, path, relation_averages, enrichment) for path in paths]
    meta_paths_dict = group_paths_by_meta_paths_with_node_types(paths, G)
    meta_paths = list(meta_paths_dict.keys())
    valid_meta_paths = [i for i in meta_paths if len(np.intersect1d(not_cool_rel, i)) == 0]
    return paths, path_scores_enrichment, meta_paths_dict, meta_paths, valid_meta_paths

In [26]:
res.keys()

dict_keys(['prediction', 'label', 'result'])

In [34]:
import sys
sys.path.append('../')
from txgnn import TxData
txdata = TxData(data_folder_path = '../data')
txdata.prepare_split(split = 'random', seed = 1)

Found local copy...
Found local copy...
Found local copy...
Found saved processed KG... Loading...
Splits detected... Loading splits....


  6%|█████▌                                                                                             | 50.3M/888M [00:16<01:47, 7.82MiB/s]

Creating DGL graph....
Done!


In [35]:
mapping = txdata.retrieve_id_mapping()
idx2id_disease = mapping['idx2id_disease'] 
idx2id_drug = mapping['idx2id_drug'] 
id2name_disease = mapping['id2name_disease'] 
id2name_drug = mapping['id2name_drug']
id2idx_disease = {j:i for i,j in idx2id_disease.items()}

In [240]:
import pickle
with open('name_mapping.pkl', 'wb') as f: 
    pickle.dump(mapping, f)

In [7]:
## for each relation, get the top K nodes.
K = 10 ## top K first hop neighbors
K2 = 10 ## top K2 second hop neighbors

In [8]:
## build a graph using either d_att, d_gm, d_ge
G_att = build_graph(d_att)
G_gm = build_graph(d_gm)
G_ge = build_graph(d_ge)

## to calculate average weight per relation for enrichment calculation
relation_averages_att = calculate_relation_averages(G_att)
relation_averages_gm = calculate_relation_averages(G_gm)
relation_averages_ge = calculate_relation_averages(G_ge)

7676670it [07:53, 16214.33it/s]
7676670it [07:55, 16133.04it/s]
7676670it [07:59, 15993.22it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████| 7676670/7676670 [00:07<00:00, 960340.24it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████| 7676670/7676670 [00:08<00:00, 948166.28it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████| 7676670/7676670 [00:07<00:00, 977170.29it/s]


In [9]:
G_dict = {
    'att': G_att,
    'gm': G_gm,
    'ge': G_ge
}

relation_avg_dict = {
    'att': relation_averages_att,
    'gm': relation_averages_gm,
    'ge': relation_averages_ge
}

In [20]:
# query disease name
X_id = 'macular degeneration'

The history saving thread hit an unexpected error (OperationalError('attempt to write a readonly database')).History will not be written to the database.


In [None]:
# top 20 diseases with many connections to the graph but zero indications
# top 20 diseases with few indications (ranked by accuracy)
# 10 diseases with recently approved therapies (which means lots of interest in these diseases)

In [19]:
result[result.Name == 'macular degeneration']['Hits@50'].values

array([list(['Lutein', 'Verteporfin', 'Ranibizumab', 'Pegaptanib', 'Anecortave acetate', 'Aflibercept', 'Brolucizumab'])],
      dtype=object)

In [21]:
truth = result[result.Name == X_id]['Hits@100'].values[0] + result[result.Name == X_id]['Missed@100'].values[0]

In [22]:
## given a disease, find the most promising drugs that are not approved. or pick your fav drugs

ranked_list = result[result.Name == X_id]['Ranked List'][0]
difference = np.setdiff1d(ranked_list, truth)
ordered_difference = np.array([item for item in ranked_list if item in difference])
Y_id = ordered_difference[0]
print(Y_id)

Polidocanol


In [38]:
Y_id = id2name_drug['DB15303']
Y_id

'Faricimab'

In [39]:
## filtering not-that-interesting meta paths?
not_cool_rel = ['rev_contraindication', 'contraindication', 'drug_drug', 'rev_off-label use', 'off-label use', 'anatomy_protein_absent', 'rev_anatomy_protein_absent']

In [86]:
idx2id_disease

{2502.0: '13924_12592_14672_13460_12591_12536_30861_8146_8148_32846_13459_44329_14544_9805_49223_9804_14086_8147_13515_14029_12581_19019',
 1038.0: '11160_13119_13978_12060_12327_12670_13210_11067_12903_12293_12376_12375_11767_10965_12460_10967_11602_12002_11762_13386_14363_10933_12452_13365_13250_13826_12445_12326_11360_11392_13985_14739_11351_13489_12421_9076_13738_11279_14675_11286_13249_12485_10986_12420_14428_12170_12091_12442_11364_13984_12418_14237_13010_12355_912_14469_12273_13269_12602_11774_10807_12977_12003_12370_11192_10987_11991_12333_10860_13929_13471_11912_13537_13963_11799_13215_11553_14182_19588_14849',
 15420.0: '8099_12497_12498',
 2962.0: '14854_14293_14470_12380_11832_14603_14853_11761_11032_14594_12975_10973_12090_14740_12902_10915_11058_14283_11519_12083_7424_11673_11389_13632_11103_11226_11102_12974_12086_11159_11074_11031_10963_13823_11660_11893_13305_11708_11994_12030_11625_11350_13114_12023_11568_11920_12976_14738_14291_10817_11480_13593_11657_19587',
 10457.

In [93]:
name2id_disease = {j:i for i,j in id2name_disease.items()}
name2id_drug = {j:i for i,j in id2name_drug.items()}

In [95]:
def sigmoid(x):
    return 1/(1+np.exp(-x))

In [96]:
sigmoid(res['prediction']['1.0']['DB00002'])

0.009866393590457347

In [173]:
all_path = pd.DataFrame()
G = G_dict['gm']
enrichment = False
def get_path(X_id, Y_id, G, not_cool_rel, enrichment, label):
    try:
        path = find_meta_paths(X_id, Y_id, G, not_cool_rel, relation_avg_dict['gm'], enrichment = enrichment)
        to_save = []
        for meta_path, paths in path[2].items():
            if meta_path in path[-1]:
                path_scores = [score_path_enrichment(G, path, relation_avg_dict['gm'], enrichment) for path in paths]
                to_save += tuple(zip([X_id]* len(path_scores), 
                                     [Y_id]* len(path_scores), 
                                     [print_beautiful_path(meta_path)] * len(path_scores), 
                                     print_beautiful_paths(paths), 
                                     path_scores, 
                                     [sigmoid(res['prediction'][name2id_disease[X_id]][name2id_drug[Y_id]])] * len(path_scores)))
        out =  pd.DataFrame(to_save).rename(columns = {0: 'Disease', 1: 'Drug', 2: 'Meta-Path', 3: 'Path', 4: 'Path Score', 5: 'Prediction Score'})
        out['Category'] = label
        return out
    except:
        print('Error: ', X_id, ' ', Y_id)

In [None]:
# 10 disease-drug pairs with recently approved therapies

In [98]:
approve_diseases = ['von Hippel-Lindau disease', 'atopic dermatitis', 'familial hypercholesterolemia', \
                     'asthma', 'cytomegalovirus infection', 'acquired polycythemia vera', \
                     'psoriasis', 'type 2 diabetes mellitus', 'CDKL5 disorder', \
                     'myelofibrosis', 'macular degeneration']
s = """DB15463
DB12169
DB14901
DB15090
DB06234
DB15119
DB06083
DB15171
DB05087
DB11697
DB15303"""
fda_drugs = s.split('\n')
approve_drugs = [id2name_drug[i] for i in fda_drugs]

In [167]:
disease_drug_pairs = tuple(zip(approve_diseases, approve_drugs, ['Approved Pairs'] * len(approve_diseases)))

In [174]:
def get_path_wrapper(X):
    return get_path(X[0], X[1], G, not_cool_rel, enrichment, label = X[2])

In [108]:
all_path = pd.DataFrame()
for i in r:
    all_path = all_path.append(i)

In [109]:
all_path

Unnamed: 0,Disease,Drug,Meta-Path,Path,Path Score,Prediction Score,Category
0,von Hippel-Lindau disease,PT2977,disease -> disease_phenotype_positive -> effec...,von Hippel-Lindau disease -> disease_phenotype...,0.174054,0.711052,Approved Pairs
1,von Hippel-Lindau disease,PT2977,disease -> disease_phenotype_positive -> effec...,von Hippel-Lindau disease -> disease_phenotype...,0.188529,0.711052,Approved Pairs
2,von Hippel-Lindau disease,PT2977,disease -> disease_phenotype_positive -> effec...,von Hippel-Lindau disease -> disease_phenotype...,0.162783,0.711052,Approved Pairs
3,von Hippel-Lindau disease,PT2977,disease -> disease_phenotype_positive -> effec...,von Hippel-Lindau disease -> disease_phenotype...,0.156913,0.711052,Approved Pairs
4,von Hippel-Lindau disease,PT2977,disease -> disease_protein -> gene/protein -> ...,von Hippel-Lindau disease -> disease_protein -...,0.139921,0.711052,Approved Pairs
...,...,...,...,...,...,...,...
707,macular degeneration,Faricimab,disease -> indication -> drug -> drug_protein ...,macular degeneration -> indication -> Pegaptan...,0.503949,0.432673,Approved Pairs
708,macular degeneration,Faricimab,disease -> indication -> drug -> drug_protein ...,macular degeneration -> indication -> Broluciz...,0.591097,0.432673,Approved Pairs
709,macular degeneration,Faricimab,disease -> indication -> drug -> drug_protein ...,macular degeneration -> indication -> Pegaptan...,0.678410,0.432673,Approved Pairs
710,macular degeneration,Faricimab,disease -> indication -> drug -> drug_protein ...,macular degeneration -> indication -> Broluciz...,0.794607,0.432673,Approved Pairs


In [None]:
# 10 diseases with recent approvals and get their top 20 drugs for pairs

In [168]:
disease_drug_pairs = list(disease_drug_pairs)

In [169]:
label = 'Recent Approved Diseases'
for disease in approve_diseases:
    ranked_list = result[result.Name == disease]['Ranked List'][0]
    difference = np.setdiff1d(ranked_list, truth)
    ordered_difference = np.array([item for item in ranked_list if item in difference])
    Y_ids = ordered_difference[:5]
    disease_drug_pairs+=list(zip([disease] * 5, Y_ids, [label] * 5))
    

In [163]:
# top 20 diseases with many connections to the graph but zero indications
# top 20 diseases with few indications (ranked by accuracy)

In [104]:
result['number_of_indications'] = result.Labels.apply(lambda x: sum(x.values()))

In [117]:
disease_id_to_num_of_neighbors = dict(txdata.df[txdata.df.x_type == 'disease'].groupby('x_id').y_id.agg(len))

In [120]:
result['number_of_neighbors'] = result.ID.apply(lambda x: disease_id_to_num_of_neighbors[x] if x in disease_id_to_num_of_neighbors else 0)

In [124]:
result.number_of_neighbors.value_counts()

0      5966
1      1300
2      1263
3       914
4       747
       ... 
169       1
162       1
363       1
347       1
150       1
Name: number_of_neighbors, Length: 190, dtype: int64

In [None]:
## pick 10 diseases in each disease area, 5 with 0 indication, 5 with 1 indication, 

In [170]:
for area in ['cell_proliferation', 'mental_health', 'cardiovascular', 'anemia', 'adrenal_gland','autoimmune', 'metabolic_disorder', 'diabetes', 'neurodigenerative']:
    ind_diseases_test = pd.read_csv('../data/disease_files/' + area + '.csv').node_id.values
    disease_in_the_area = [id2name_disease[idx2id_disease[i]] for i in ind_diseases_test if i in idx2id_disease]
    ## diseases with a few indications
    disease_few_indications = result[result.Name.isin(disease_in_the_area) & (result.number_of_indications >= 1) & (result.number_of_indications < 5) & (result.number_of_neighbors != 0)].Name.values[:3]
    label = area + '-disease-few-indications-top3'
    for disease in disease_few_indications:
        ranked_list = result[result.Name == disease]['Ranked List'][0]
        difference = np.setdiff1d(ranked_list, truth)
        ordered_difference = np.array([item for item in ranked_list if item in difference])
        Y_ids = ordered_difference[:5]
        disease_drug_pairs+=list(zip([disease] * 5, Y_ids, [label] * 5))

    ## diseases with zero indications but many neighbors
    disease_zero_indications_many_neighbors = result[result.Name.isin(disease_in_the_area) & (result.number_of_indications == 0) & (result.number_of_neighbors != 0)].sort_values('number_of_neighbors')[::-1].Name.values[:3]
    
    label = area + '-disease-zero-indications-top3-num-neighbors'
    for disease in disease_zero_indications_many_neighbors:
        ranked_list = result[result.Name == disease]['Ranked List'][0]
        difference = np.setdiff1d(ranked_list, truth)
        ordered_difference = np.array([item for item in ranked_list if item in difference])
        Y_ids = ordered_difference[:5]
        disease_drug_pairs+=list(zip([disease] * 5, Y_ids, [label] * 5))

In [171]:
len(disease_drug_pairs)

276

In [177]:
import multiprocessing
with multiprocessing.Pool(30) as p:
    r = list(tqdm(p.imap(get_path_wrapper, disease_drug_pairs), total=len(disease_drug_pairs)))


  0%|                                                                                                                | 0/276 [00:00<?, ?it/s][A
  0%|▍                                                                                                       | 1/276 [00:00<04:05,  1.12it/s][A
  1%|█▏                                                                                                      | 3/276 [00:19<32:20,  7.11s/it][A
  1%|█▍                                                                                                    | 4/276 [04:01<6:09:24, 81.49s/it][A

Error:  hereditary continuous muscle fiber activity   Furazidin
Error:  hereditary continuous muscle fiber activity   Monoxerutin
Error:  trichothiodystrophy   Monoxerutin



100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 276/276 [33:15<00:00,  7.23s/it][A


In [179]:
all_path = pd.DataFrame()
for i in r:
    all_path = all_path.append(i)

In [181]:
all_path.reset_index(drop = True).to_csv('path_viz.csv', index = False)

In [197]:
def parse_genes(x):
    idx_gene = []
    for idx, i in enumerate(x['Meta-Path'].split('->')):
        if i.strip() == 'gene/protein':
            idx_gene.append(idx)
    return [x['Path'].split('->')[gene].strip() for gene in idx_gene]

In [198]:
tqdm.pandas()
all_path['genes'] = all_path.progress_apply(lambda x: parse_genes(x), axis = 1)


  0%|                                                                                                             | 0/335752 [00:00<?, ?it/s][A
  1%|▉                                                                                              | 3466/335752 [00:00<00:09, 34659.21it/s][A
  4%|███▍                                                                                          | 12194/335752 [00:00<00:04, 65607.85it/s][A
  6%|█████▉                                                                                        | 21163/335752 [00:00<00:04, 76602.84it/s][A
  9%|████████▎                                                                                     | 29499/335752 [00:00<00:03, 79269.09it/s][A
 11%|██████████▌                                                                                   | 37796/335752 [00:00<00:03, 80602.96it/s][A
 14%|█████████████▍                                                                                | 47801/335752 [00:00<00:03, 8

In [200]:
gene_count_per_path = {}
for gene in tqdm(all_path.genes):
    for i in gene:
        if i in gene_count_per_path:
            gene_count_per_path[i] += 1
        else:
            gene_count_per_path[i] = 1


  0%|                                                                                                             | 0/335752 [00:00<?, ?it/s][A
100%|███████████████████████████████████████████████████████████████████████████████████████████| 335752/335752 [00:00<00:00, 1785299.41it/s][A


In [201]:
gene_count_per_path

{'EPAS1': 57,
 'CCND1': 440,
 'VHL': 498,
 'STAT5A': 25,
 'UBC': 1086,
 'EGLN3': 11,
 'NEDD8': 6,
 'PIM1': 10,
 'PSMC3': 1,
 'PCSK9': 772,
 'ABCA1': 1317,
 'APOA2': 393,
 'APOB': 536,
 'APOC3': 392,
 'APOE': 1609,
 'CETP': 651,
 'HMGCR': 2783,
 'LIPC': 539,
 'LPL': 1184,
 'PON1': 444,
 'PON2': 1112,
 'SREBF2': 1282,
 'LDLRAP1': 1040,
 'ABCG5': 270,
 'ABCG8': 307,
 'GHR': 1473,
 'PPP1R17': 181,
 'SREBF1': 17,
 'ARMC6': 22,
 'HBA1': 121,
 'HBB': 104,
 'JAK2': 1781,
 'IFNA2': 425,
 'JAK1': 686,
 'MPL': 484,
 'SH2B3': 112,
 'TET2': 762,
 'GH1': 62,
 'EGLN1': 130,
 'SLC30A10': 53,
 'HBA2': 9,
 'IL6': 1443,
 'IL12B': 777,
 'IL23A': 721,
 'IL6R': 44,
 'IL6ST': 19,
 'IL2': 136,
 'CXCL8': 447,
 'RUNX3': 962,
 'HLA-C': 1377,
 'MKI67': 948,
 'NOS2': 2490,
 'REN': 518,
 'TYK2': 1527,
 'VNN2': 860,
 'TRAF3IP2': 1150,
 'IL36RN': 265,
 'VNN3': 279,
 'CARD14': 792,
 'LCE3D': 225,
 'TAGAP': 781,
 'ZNF816': 581,
 'NFKB1': 37,
 'TCF7': 23,
 'HDAC1': 455,
 'IRF1': 8,
 'POLR2A': 2,
 'TBP': 858,
 'CREBBP': 

In [203]:
all_path['gene occurences(out of 335752 paths)'] = all_path.genes.apply(lambda x: {i: gene_count_per_path[i] for i in x})

In [208]:
gene_count_per_path_per_disease = {}
path_count_per_disease = {}
for disease, gene in tqdm(all_path[['Disease','genes']].values):
    if disease not in gene_count_per_path_per_disease:
        gene_count_per_path_per_disease[disease] = {}
        path_count_per_disease[disease] = 0
    for i in gene:
        if i in gene_count_per_path_per_disease[disease]:
            gene_count_per_path_per_disease[disease][i] += 1
        else:
            gene_count_per_path_per_disease[disease][i] = 1
    path_count_per_disease[disease] += 1


  0%|                                                                                                             | 0/335752 [00:00<?, ?it/s][A
 14%|█████████████▏                                                                               | 47413/335752 [00:00<00:00, 474111.24it/s][A
 31%|████████████████████████████▎                                                               | 103506/335752 [00:00<00:00, 525173.93it/s][A
 49%|████████████████████████████████████████████▋                                               | 163252/335752 [00:00<00:00, 558174.32it/s][A
 66%|████████████████████████████████████████████████████████████▊                               | 221849/335752 [00:00<00:00, 569140.28it/s][A
100%|████████████████████████████████████████████████████████████████████████████████████████████| 335752/335752 [00:00<00:00, 578787.19it/s][A


In [209]:
path_count_per_disease

{'von Hippel-Lindau disease': 2671,
 'familial hypercholesterolemia': 18448,
 'acquired polycythemia vera': 19932,
 'psoriasis': 23359,
 'CDKL5 disorder': 513,
 'macular degeneration': 1261,
 'atopic dermatitis': 574,
 'asthma': 20978,
 'cytomegalovirus infection': 6,
 'type 2 diabetes mellitus': 7374,
 'myelofibrosis': 2528,
 'severe pre-eclampsia': 36,
 'mycetoma': 2069,
 'disorder of phenylalanine metabolism': 398,
 'multiple congenital anomalies/dysmorphic syndrome-intellectual disability': 2,
 'congenital myasthenic syndrome': 5797,
 'NGLY1-deficiency': 5926,
 'cryptosporidiosis': 18387,
 'hypersensitivity pneumonitis': 6978,
 'malignant hyperthermia of anesthesia': 4731,
 'X-linked intellectual disability': 12958,
 'limb-girdle muscular dystrophy': 5825,
 'chromosome 1p36 deletion syndrome': 5272,
 'epithelioid sarcoma': 4987,
 'neuroendocrine neoplasm': 11,
 'hepatic veno-occlusive disease': 502,
 'infectious disease': 4453,
 'Cowden disease': 8596,
 'Ehlers-Danlos syndrome, cla

In [207]:
gene_count_per_path_per_disease

{'von Hippel-Lindau disease': {'EPAS1': 57,
  'CCND1': 436,
  'VHL': 348,
  'STAT5A': 2,
  'UBC': 20,
  'EGLN3': 1,
  'NEDD8': 1,
  'PIM1': 3,
  'PSMC3': 1,
  'GLUL': 46,
  'BCHE': 41,
  'GHR': 34,
  'AR': 2,
  'THRA': 72,
  'UGT1A1': 80,
  'ABCB1': 249,
  'SLCO1B3': 59,
  'SLCO1C1': 66,
  'SLC7A5': 94,
  'ALB': 350,
  'SERPINA7': 24,
  'SLCO4A1': 67,
  'SLC10A1': 24,
  'SLCO4C1': 43,
  'THRB': 63,
  'IFNA2': 2,
  'ESR2': 1,
  'PTGS2': 2,
  'AVP': 3,
  'KNG1': 3,
  'TRH': 1,
  'GSK3B': 2,
  'ESR1': 1,
  'TCF4': 2,
  'STAT3': 3,
  'ABCG2': 58,
  'CRYAB': 2,
  'HIF1A': 2},
 'familial hypercholesterolemia': {'PCSK9': 772,
  'ABCA1': 1317,
  'APOA2': 393,
  'APOB': 536,
  'APOC3': 392,
  'APOE': 1261,
  'CETP': 651,
  'HMGCR': 2647,
  'LIPC': 539,
  'LPL': 1184,
  'PON1': 442,
  'PON2': 1111,
  'SREBF2': 1282,
  'LDLRAP1': 1040,
  'ABCG5': 270,
  'ABCG8': 307,
  'GHR': 1027,
  'PPP1R17': 181,
  'UBC': 109,
  'SREBF1': 17,
  'ARMC6': 5,
  'GCG': 2,
  'UGT2B7': 611,
  'ITGAL': 564,
  'ABCB11

In [210]:
all_path['gene occurences for this disease'] = all_path.apply(lambda x: {i: gene_count_per_path_per_disease[x.Disease][i] for i in x.genes}, axis = 1)

In [212]:
all_path['number of paths for this disease'] = all_path.Disease.apply(lambda x: path_count_per_disease[x])

In [223]:
all_path = all_path[~all_path.Category.isin(['Recent Approved Diseases', 'Approved Pairs'])]
all_path = all_path.rename(columns = {'gene occurences(out of 335752 paths)': 'gene occurences across all paths'})

In [224]:
all_path.to_pickle('paths.pkl')

In [226]:
all_path.to_csv('paths.csv', index = False)

In [228]:
all_path.sort_values('Path Score')[::-1]

Unnamed: 0,Disease,Drug,Meta-Path,Path,Path Score,Prediction Score,Category,genes,gene occurences across all paths,gene occurences for this disease,number of paths for this disease
125,mevalonic aciduria,Canakinumab,disease -> indication -> drug,mevalonic aciduria -> indication -> Canakinumab,1.995601,0.982810,metabolic_disorder-disease-few-indications-top3,[],{},{},349
63,disorder of phenylalanine metabolism,Sapropterin,disease -> indication -> drug,disorder of phenylalanine metabolism -> indica...,1.978879,0.999842,cell_proliferation-disease-few-indications-top3,[],{},{},398
148,pancreatic insulinoma,Diazoxide,disease -> indication -> drug,pancreatic insulinoma -> indication -> Diazoxide,1.971971,0.999731,metabolic_disorder-disease-few-indications-top3,[],{},{},503
856,nephrogenic syndrome of inappropriate antidiur...,Tolvaptan,disease -> indication -> drug,nephrogenic syndrome of inappropriate antidiur...,1.967655,0.999577,neurodigenerative-disease-few-indications-top3,[],{},{},5273
141,hereditary angioedema with C1Inh deficiency,Ecallantide,disease -> indication -> drug,hereditary angioedema with C1Inh deficiency ->...,1.952124,0.999559,metabolic_disorder-disease-few-indications-top3,[],{},{},1827
...,...,...,...,...,...,...,...,...,...,...,...
3689,deafness dystonia syndrome,Testosterone,disease -> disease_protein -> gene/protein -> ...,deafness dystonia syndrome -> disease_protein ...,0.026248,0.995375,autoimmune-disease-zero-indications-top3-num-n...,"[TBP, AHR]","{'TBP': 858, 'AHR': 4}","{'TBP': 841, 'AHR': 3}",7565
3659,deafness dystonia syndrome,Testosterone,disease -> disease_protein -> gene/protein -> ...,deafness dystonia syndrome -> disease_protein ...,0.025970,0.995375,autoimmune-disease-zero-indications-top3-num-n...,[TBP],{'TBP': 858},{'TBP': 841},7565
3658,deafness dystonia syndrome,Testosterone,disease -> disease_protein -> gene/protein -> ...,deafness dystonia syndrome -> disease_protein ...,0.025734,0.995375,autoimmune-disease-zero-indications-top3-num-n...,[TBP],{'TBP': 858},{'TBP': 841},7565
3660,deafness dystonia syndrome,Testosterone,disease -> disease_protein -> gene/protein -> ...,deafness dystonia syndrome -> disease_protein ...,0.025653,0.995375,autoimmune-disease-zero-indications-top3-num-n...,[TBP],{'TBP': 858},{'TBP': 841},7565


In [233]:
all_path_exclude = all_path[~all_path['Meta-Path'].str.contains('indication')]
all_path_exclude = all_path_exclude[~all_path_exclude['Meta-Path'].str.contains('exposure')]
all_path_exclude = all_path_exclude[~all_path_exclude['Meta-Path'].str.contains('drug_effect')]

In [236]:
all_path_exclude = all_path_exclude[~all_path_exclude['Category'].str.contains('cell_proliferation')]

In [238]:
all_path_exclude.reset_index(drop = True).to_csv('paths_filter.csv', index = False)

In [229]:
len(all_path.Disease.unique())

42

In [187]:
idx_gene

[6]

In [196]:
[all_path['Path'].values[0].split('->')[gene].strip() for gene in idx_gene]

['EPAS1']

In [141]:
## diseases with a few indications
disease_few_indications = result[result.Name.isin(disease_in_the_area) & (result.number_of_indications >= 1) & (result.number_of_indications < 5) & (result.number_of_neighbors != 0)].Name.values
## diseases with zero indications but many neighbors
disease_zero_indications_many_neighbors = result[result.Name.isin(disease_in_the_area) & (result.number_of_indications == 0) & (result.number_of_neighbors != 0)].sort_values('number_of_neighbors')[::-1].Name.values[:5]

In [133]:
disease_in_the_area

['Faye-Petersen-Ward-Carey syndrome',
 'carcinoma of urethra',
 'cornea cancer',
 'glycine metabolism disease',
 'Trichomonas cervicitis',
 'cervical hypertrichosis with underlying kyphoscoliosis',
 'middle ear squamous cell carcinoma',
 'carotid artery occlusion',
 'X-linked recessive ocular albinism',
 'primary brain neoplasm',
 'gershinibaruch Leibo syndrome',
 'flat urothelial hyperplasia',
 'common bile duct neoplasm',
 'delusional disorder',
 'gastrin-producing neuroendocrine tumor',
 'acanthosis nigricans (disease)',
 'amaurosis fugax',
 'Crohn jejunoileitis',
 'deafness-ear malformation-facial palsy syndrome',
 'premenstrual tension',
 'uterine corpus adenosarcoma',
 'recurrent hypersomnia',
 'accommodative esotropia',
 'mixed hepatoblastoma',
 'intramural coronary arterial course',
 'stage II endometrioid carcinoma',
 'Liang-Wang syndrome',
 'Klebsiella pneumonia',
 'pancreatic intraductal papillary-mucinous neoplasm with low grade dysplasia',
 'ovarian cyst (disease)',
 'beni

In [128]:
id2name_disease[idx2id_disease[i]] for i in ind_diseases_test

{2502.0: '13924_12592_14672_13460_12591_12536_30861_8146_8148_32846_13459_44329_14544_9805_49223_9804_14086_8147_13515_14029_12581_19019',
 1038.0: '11160_13119_13978_12060_12327_12670_13210_11067_12903_12293_12376_12375_11767_10965_12460_10967_11602_12002_11762_13386_14363_10933_12452_13365_13250_13826_12445_12326_11360_11392_13985_14739_11351_13489_12421_9076_13738_11279_14675_11286_13249_12485_10986_12420_14428_12170_12091_12442_11364_13984_12418_14237_13010_12355_912_14469_12273_13269_12602_11774_10807_12977_12003_12370_11192_10987_11991_12333_10860_13929_13471_11912_13537_13963_11799_13215_11553_14182_19588_14849',
 15420.0: '8099_12497_12498',
 2962.0: '14854_14293_14470_12380_11832_14603_14853_11761_11032_14594_12975_10973_12090_14740_12902_10915_11058_14283_11519_12083_7424_11673_11389_13632_11103_11226_11102_12974_12086_11159_11074_11031_10963_13823_11660_11893_13305_11708_11994_12030_11625_11350_13114_12023_11568_11920_12976_14738_14291_10817_11480_13593_11657_19587',
 10457.

In [40]:
# using enrichment or raw weight?
enrichment = False

xai2path = {}
for name, G in G_dict.items():
    print(name)
    xai2path[name] = find_meta_paths(X_id, Y_id, G, not_cool_rel, relation_avg_dict[name], enrichment = enrichment)

att
Number of neighbors:  609



0it [00:00, ?it/s][A
1it [00:00,  5.57it/s][A
3it [00:00,  5.64it/s][A
9it [00:00, 17.94it/s][A
19it [00:00, 37.49it/s][A
27it [00:00, 47.26it/s][A
34it [00:00, 52.23it/s][A
43it [00:01, 62.19it/s][A
53it [00:01, 72.64it/s][A
62it [00:01, 73.61it/s][A
73it [00:01, 82.29it/s][A
84it [00:01, 88.55it/s][A
100it [00:01, 105.60it/s][A
111it [00:01, 91.23it/s] [A
121it [00:01, 93.00it/s][A
131it [00:01, 91.26it/s][A
141it [00:02, 77.41it/s][A
150it [00:02, 80.00it/s][A
159it [00:02, 81.18it/s][A
170it [00:02, 86.70it/s][A
179it [00:02, 84.78it/s][A
189it [00:02, 85.87it/s][A
200it [00:02, 89.85it/s][A
210it [00:02, 84.73it/s][A
219it [00:03, 84.36it/s][A
229it [00:03, 88.24it/s][A
238it [00:03, 77.98it/s][A
248it [00:03, 80.21it/s][A
258it [00:03, 84.86it/s][A
269it [00:03, 87.15it/s][A
278it [00:03, 84.12it/s][A
289it [00:03, 90.52it/s][A
302it [00:03, 100.03it/s][A
313it [00:04, 68.28it/s] [A
322it [00:04, 56.94it/s][A
329it [00:04, 55.85it/s][A
336it [

gm
Number of neighbors:  715



0it [00:00, ?it/s][A
1it [00:00,  4.12it/s][A
4it [00:00, 12.32it/s][A
6it [00:00,  8.62it/s][A
11it [00:00, 16.83it/s][A
18it [00:00, 28.65it/s][A
26it [00:01, 40.38it/s][A
34it [00:01, 49.42it/s][A
42it [00:01, 57.38it/s][A
50it [00:01, 60.23it/s][A
57it [00:01, 61.67it/s][A
66it [00:01, 67.16it/s][A
78it [00:01, 79.99it/s][A
91it [00:01, 93.52it/s][A
101it [00:01, 72.07it/s][A
110it [00:02, 70.60it/s][A
118it [00:02, 69.49it/s][A
126it [00:02, 50.07it/s][A
133it [00:02, 50.40it/s][A
140it [00:02, 53.27it/s][A
147it [00:02, 57.01it/s][A
155it [00:02, 61.07it/s][A
163it [00:03, 65.24it/s][A
170it [00:03, 66.33it/s][A
177it [00:03, 65.21it/s][A
186it [00:03, 68.85it/s][A
194it [00:03, 64.37it/s][A
201it [00:03, 61.44it/s][A
209it [00:03, 65.70it/s][A
217it [00:03, 68.63it/s][A
224it [00:03, 68.75it/s][A
232it [00:04, 71.59it/s][A
240it [00:04, 69.46it/s][A
248it [00:04, 71.75it/s][A
256it [00:04, 51.81it/s][A
264it [00:04, 56.24it/s][A
272it [00:04,

ge
Number of neighbors:  749



0it [00:00, ?it/s][A
1it [00:00,  4.75it/s][A
3it [00:00,  5.39it/s][A
8it [00:00, 15.16it/s][A
15it [00:00, 27.18it/s][A
22it [00:00, 36.30it/s][A
30it [00:01, 45.84it/s][A
36it [00:01, 46.92it/s][A
44it [00:01, 54.42it/s][A
50it [00:01, 49.27it/s][A
56it [00:01, 29.92it/s][A
63it [00:01, 36.15it/s][A
70it [00:02, 40.96it/s][A
77it [00:02, 46.64it/s][A
84it [00:02, 50.01it/s][A
91it [00:02, 52.81it/s][A
102it [00:02, 61.88it/s][A
109it [00:02, 60.91it/s][A
116it [00:02, 59.02it/s][A
125it [00:02, 66.33it/s][A
132it [00:03, 57.26it/s][A
140it [00:03, 59.74it/s][A
148it [00:03, 62.13it/s][A
155it [00:03, 60.91it/s][A
162it [00:03, 59.87it/s][A
169it [00:03, 60.09it/s][A
176it [00:03, 53.91it/s][A
182it [00:04, 37.71it/s][A
190it [00:04, 44.70it/s][A
197it [00:04, 47.80it/s][A
205it [00:04, 54.03it/s][A
212it [00:04, 55.06it/s][A
219it [00:04, 56.99it/s][A
231it [00:04, 71.88it/s][A
240it [00:04, 73.40it/s][A
248it [00:04, 71.83it/s][A
256it [00:05, 6

In [41]:
print_beautiful_paths(xai2path['att'][-1])

['disease -> disease_disease -> disease -> disease_protein -> gene/protein -> protein_protein -> gene/protein -> drug_protein -> drug',
 'disease -> disease_disease -> disease -> disease_protein -> gene/protein -> drug_protein -> drug',
 'disease -> disease_phenotype_positive -> effect/phenotype -> disease_phenotype_positive -> disease -> disease_protein -> gene/protein -> drug_protein -> drug',
 'disease -> disease_protein -> gene/protein -> anatomy_protein_present -> anatomy -> anatomy_protein_present -> gene/protein -> drug_protein -> drug',
 'disease -> disease_protein -> gene/protein -> cellcomp_protein -> cellular_component -> cellcomp_protein -> gene/protein -> drug_protein -> drug',
 'disease -> disease_protein -> gene/protein -> disease_protein -> disease -> disease_protein -> gene/protein -> drug_protein -> drug',
 'disease -> disease_protein -> gene/protein -> bioprocess_protein -> biological_process -> bioprocess_protein -> gene/protein -> drug_protein -> drug',
 'disease -

In [42]:
## get all paths in a meta path
meta_path = xai2path['att'][-1][8]
print('for meta path: ', meta_path)
print_beautiful_paths(xai2path['att'][2][meta_path])

for meta path:  ('disease', 'rev_disease_protein', 'gene/protein', 'phenotype_protein', 'effect/phenotype', 'rev_phenotype_protein', 'gene/protein', 'rev_drug_protein', 'drug')


['macular degeneration -> disease_protein -> APOE -> phenotype_protein -> Proteinuria -> phenotype_protein -> VEGFA -> drug_protein -> Faricimab']

In [44]:
## get all paths in a meta path
meta_path = xai2path['att'][-1][6]
print('for meta path: ', meta_path)
print_beautiful_paths(xai2path['att'][2][meta_path])

for meta path:  ('disease', 'rev_disease_protein', 'gene/protein', 'bioprocess_protein', 'biological_process', 'rev_bioprocess_protein', 'gene/protein', 'rev_drug_protein', 'drug')


['macular degeneration -> disease_protein -> APOE -> bioprocess_protein -> negative regulation of gene expression -> bioprocess_protein -> VEGFA -> drug_protein -> Faricimab',
 'macular degeneration -> disease_protein -> APOE -> bioprocess_protein -> gene expression -> bioprocess_protein -> ANGPT2 -> drug_protein -> Faricimab',
 'macular degeneration -> disease_protein -> HTRA1 -> bioprocess_protein -> positive regulation of epithelial cell proliferation -> bioprocess_protein -> VEGFA -> drug_protein -> Faricimab',
 'macular degeneration -> disease_protein -> VEGFA -> bioprocess_protein -> angiogenesis -> bioprocess_protein -> ANGPT2 -> drug_protein -> Faricimab',
 'macular degeneration -> disease_protein -> VEGFA -> bioprocess_protein -> positive regulation of angiogenesis -> bioprocess_protein -> ANGPT2 -> drug_protein -> Faricimab',
 'macular degeneration -> disease_protein -> SQSTM1 -> bioprocess_protein -> positive regulation of protein phosphorylation -> bioprocess_protein -> VEG

In [45]:
print_beautiful_paths(xai2path['gm'][-1])

['disease -> disease_disease -> disease -> disease_protein -> gene/protein -> protein_protein -> gene/protein -> drug_protein -> drug',
 'disease -> disease_disease -> disease -> disease_protein -> gene/protein -> drug_protein -> drug',
 'disease -> disease_disease -> disease -> disease_disease -> disease -> disease_protein -> gene/protein -> drug_protein -> drug',
 'disease -> disease_phenotype_positive -> effect/phenotype -> disease_phenotype_positive -> disease -> disease_protein -> gene/protein -> drug_protein -> drug',
 'disease -> disease_protein -> gene/protein -> anatomy_protein_present -> anatomy -> anatomy_protein_present -> gene/protein -> drug_protein -> drug',
 'disease -> disease_protein -> gene/protein -> bioprocess_protein -> biological_process -> bioprocess_protein -> gene/protein -> drug_protein -> drug',
 'disease -> disease_protein -> gene/protein -> cellcomp_protein -> cellular_component -> cellcomp_protein -> gene/protein -> drug_protein -> drug',
 'disease -> dis

In [55]:
## get all paths in a meta path
meta_path = xai2path['gm'][-1][3]
print('for meta path: ', meta_path)
print_beautiful_paths(xai2path['gm'][2][meta_path])

for meta path:  ('disease', 'disease_phenotype_positive', 'effect/phenotype', 'rev_disease_phenotype_positive', 'disease', 'rev_disease_protein', 'gene/protein', 'rev_drug_protein', 'drug')


['macular degeneration -> disease_phenotype_positive -> Autosomal dominant inheritance -> disease_phenotype_positive -> lung cancer -> disease_protein -> VEGFA -> drug_protein -> Faricimab',
 'macular degeneration -> disease_phenotype_positive -> Autosomal dominant inheritance -> disease_phenotype_positive -> cystoid macular edema -> disease_protein -> VEGFA -> drug_protein -> Faricimab']

In [147]:
print_beautiful_paths(xai2path['ge'][-1])

['disease -> disease_protein -> gene/protein -> anatomy_protein_present -> anatomy -> anatomy_protein_present -> gene/protein -> drug_protein -> drug',
 'disease -> disease_protein -> gene/protein -> cellcomp_protein -> cellular_component -> cellcomp_protein -> gene/protein -> drug_protein -> drug',
 'disease -> disease_protein -> gene/protein -> disease_protein -> disease -> disease_protein -> gene/protein -> drug_protein -> drug',
 'disease -> disease_protein -> gene/protein -> molfunc_protein -> molecular_function -> molfunc_protein -> gene/protein -> drug_protein -> drug',
 'disease -> disease_protein -> gene/protein -> bioprocess_protein -> biological_process -> bioprocess_protein -> gene/protein -> drug_protein -> drug',
 'disease -> disease_protein -> gene/protein -> pathway_protein -> pathway -> pathway_protein -> gene/protein -> drug_protein -> drug',
 'disease -> disease_protein -> gene/protein -> protein_protein -> gene/protein -> protein_protein -> gene/protein -> drug_prot

In [148]:
## get all paths in a meta path
meta_path = xai2path['ge'][-1][-8]
print('for meta path: ', meta_path)
print_beautiful_paths(xai2path['ge'][2][meta_path])

for meta path:  ('disease', 'rev_disease_protein', 'gene/protein', 'anatomy_protein_present', 'anatomy', 'rev_anatomy_protein_present', 'gene/protein', 'rev_drug_protein', 'drug')


['Dravet syndrome -> disease_protein -> GABRA1 -> anatomy_protein_present -> cerebellar cortex -> anatomy_protein_present -> MTNR1A -> drug_protein -> Tasimelteon',
 'Dravet syndrome -> disease_protein -> GABRG2 -> anatomy_protein_present -> cerebellar cortex -> anatomy_protein_present -> MTNR1A -> drug_protein -> Tasimelteon',
 'Dravet syndrome -> disease_protein -> PMP22 -> anatomy_protein_present -> adult mammalian kidney -> anatomy_protein_present -> MTNR1A -> drug_protein -> Tasimelteon',
 'Dravet syndrome -> disease_protein -> PMP22 -> anatomy_protein_present -> cerebellar cortex -> anatomy_protein_present -> MTNR1A -> drug_protein -> Tasimelteon',
 'Dravet syndrome -> disease_protein -> POMC -> anatomy_protein_present -> adult mammalian kidney -> anatomy_protein_present -> MTNR1A -> drug_protein -> Tasimelteon',
 'Dravet syndrome -> disease_protein -> POMC -> anatomy_protein_present -> cerebellar cortex -> anatomy_protein_present -> MTNR1A -> drug_protein -> Tasimelteon',
 'Drav

In [165]:
## save paths
expand_xai_name ={
    'att': 'attention', 
    'ge': 'gnnexplainer', 
    'gm': 'graphmask'
}
for explainer in ['att', 'ge', 'gm']:
    to_save = []
    for meta_path, paths in xai2path[explainer][2].items():
        if meta_path in xai2path[explainer][-1]:
            path_scores = [score_path_enrichment(G, path, relation_avg_dict[explainer], enrichment) for path in paths]
            to_save += tuple(zip([print_beautiful_path(meta_path)] * len(path_scores), print_beautiful_paths(paths), path_scores))
    pd.DataFrame(to_save).rename(columns = {0: 'Meta-Path', 1: 'Path', 2: 'Path Score'}).to_csv(X_id + '_' + Y_id + '_xai_' + expand_xai_name[explainer] + '.csv', index = False)