In [1]:
%reload_kedro

In [2]:
# -*- coding: utf-8 -*-
# from __future__ import absolute_import, division, print_function, unicode_literals
import random
import logging
import itertools
from collections import defaultdict
from rich import print
from IPython.display import display, HTML
from tqdm.notebook import trange, tqdm

import pandas as pd
import numpy as np
from modspy_data.helpers import KnowledgeGraphScores

import matplotlib.pyplot as plt
import seaborn as sns

import dask
import dask.dataframe as dd
import dask.array as da
from dask.distributed import Client, progress, performance_report
from dask_jobqueue import SLURMCluster


import torch
import pronto
import networkx as nx
# from utils import visualize
from pyvis.network import Network
from sklearn.preprocessing import LabelEncoder
# from stellargraph import StellarGraph    # Lovin' it!
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA

import dgl
# import optuna
from nxontology import NXOntology
from nxontology.imports import (from_file, multidigraph_to_digraph,
                                pronto_to_multidigraph)
from networkx.drawing.nx_agraph import graphviz_layout
from nxontology.viz import create_similarity_graphviz
from torch_geometric.data import Data, InMemoryDataset, download_url
import torch_geometric.transforms as T
from torch_geometric.utils.convert import from_networkx
from torch.nn import Linear, ModuleList
from torch_geometric.nn import GATConv, GraphConv, MetaPath2Vec
# from torch_geometric.loaders import DataLoader, Dataset
from torch_geometric.data import HeteroData, DataLoader
from torch_geometric.utils import convert

# Combined dateset

In [3]:
jvl = catalog.load('jvl_annotated')
olida = catalog.load('olida_annotated')
mtg = catalog.load('zyg1_annotated')

In [4]:
olida = olida[olida['Oligogenic Effect']!='Monogenic+Modifier']

In [5]:
print(olida['Oligogenic Effect'].value_counts())
print(olida.shape)

In [6]:
# Adding source column
jvl['source'] = 'JVL'
olida['source'] = 'OLIDA'
mtg['source'] = 'MTG'

# Classification labels
jvl['is_modifier'] = 1
olida['is_modifier'] = 0

# Rename columns to indicate same information
jvl.rename(columns={'QueryGene': 'target_gene', 'SuppressorGene': 'modifier_gene'}, inplace=True)
olida.rename(columns={'gene_a': 'target_gene', 'gene_b': 'modifier_gene'}, inplace=True)
mtg.rename(columns={'gene_symbol': 'modifier_gene', 'target_gene_symbol': 'target_gene'}, inplace=True)


In [7]:
dataset_df = pd.concat([jvl, olida, mtg])
print(dataset_df.shape)

print("How many from each source?")
print(dataset_df['source'].value_counts())

print("How many with modifier effect?")
print(dataset_df[dataset_df['is_modifier']==1]['source'].value_counts())

In [34]:
ndf[['id', 'name']].compute().set_index('name')


Unnamed: 0_level_0,id
name,Unnamed: 1_level_1
ank1,PomBase:SPAC105.02c
pds5,PomBase:SPAC110.02
mam2,PomBase:SPAC11H11.04
miy1,PomBase:SPAC12G12.11c
sif3,PomBase:SPAC12G12.15
...,...
ZSCAN22,HGNC:4929
ZSCAN26,HGNC:12978
ZSCAN30,HGNC:33517
ZSCAN32,HGNC:20812


In [35]:
node_id_map = ndf[['id', 'name']].compute().set_index('name').to_dict()['id']
node_id_map


[1m{[0m
    [32m'ank1'[0m: [32m'PomBase:SPAC105.02c'[0m,
    [32m'pds5'[0m: [32m'PomBase:SPAC110.02'[0m,
    [32m'mam2'[0m: [32m'PomBase:SPAC11H11.04'[0m,
    [32m'miy1'[0m: [32m'PomBase:SPAC12G12.11c'[0m,
    [32m'sif3'[0m: [32m'PomBase:SPAC12G12.15'[0m,
    [32m'nce101'[0m: [32m'PomBase:SPAC12G12.17'[0m,
    [32m'erg11'[0m: [32m'PomBase:SPAC13A11.02c'[0m,
    [32m'alm1'[0m: [32m'PomBase:SPAC1486.04c'[0m,
    [32m'mug5'[0m: [32m'PomBase:SPAC14C4.08'[0m,
    [32m'dad3'[0m: [32m'PomBase:SPAC14C4.16'[0m,
    [32m'pun1'[0m: [32m'PomBase:SPAC15A10.09c'[0m,
    [32m'grx2'[0m: [32m'PomBase:SPAC15E1.09'[0m,
    [32m'seh1'[0m: [32m'PomBase:SPAC15F9.02'[0m,
    [32m'pli1'[0m: [32m'PomBase:SPAC1687.05'[0m,
    [32m'rpl44'[0m: [32m'PomBase:SPAC1687.06c'[0m,
    [32m'gsk3'[0m: [32m'PomBase:SPAC1687.15'[0m,
    [32m'erg31'[0m: [32m'PomBase:SPAC1687.16c'[0m,
    [32m'nop12'[0m: [32m'PomBase:SPAC16E8.06c'[0m,
    [32m'mde10'

In [40]:
edge_types = edges['predicate'].value_counts()
print(edge_types)
print(len(edge_types))

In [41]:
dataset_df['target_gene_monarch_id'] = dataset_df['target_gene'].map(node_id_map)
dataset_df['modifier_gene_monarch_id'] = dataset_df['modifier_gene'].map(node_id_map)
dataset_df['relation'] = dataset_df['is_modifier'].map({1: 'modifier', 0: 'non-modifier'})

In [42]:
dataset_df.head()

Unnamed: 0,source,target_gene,modifier_gene,is_modifier,PubmedID,Category,Tissue,QueryFunction,QueryMutation,QueryType,...,wpo_lin_bma,wpo_jiang_max,wpo_jiang_avg,wpo_jiang_bma,wpo_jiang_seco_max,wpo_jiang_seco_avg,wpo_jiang_seco_bma,target_gene_monarch_id,modifier_gene_monarch_id,relation
0,JVL,APOE,CASP7,1.0,27358062.0,Patients,Neuron,Lipid and sterol biosynthesis & transport,C112R/C112R,-,...,,,,,,,,HGNC:613,HGNC:1508,modifier
1,JVL,APOE,HBB,1.0,24116184.0,Patients,-,Lipid and sterol biosynthesis & transport,C112R/?,-,...,,,,,,,,HGNC:613,HGNC:4827,modifier
2,JVL,APOE,KL,1.0,30867273.0,Patients,Neuron,Lipid and sterol biosynthesis & transport,C112R/?,-,...,,,,,,,,HGNC:613,NCBIGene:784635,modifier
3,JVL,APOE,KL,1.0,32282020.0,Patients,Neuron,Lipid and sterol biosynthesis & transport,C112R/?,-,...,,,,,,,,HGNC:613,NCBIGene:784635,modifier
4,JVL,ATR,ETV1,1.0,23284306.0,Cells,-,DNA replication and repair;Signaling & stress ...,silencing/silencing,LOF,...,,,,,,,,,,modifier


In [8]:
list(dataset_df.columns)


[1m[[0m
    [32m'PubmedID'[0m,
    [32m'Category'[0m,
    [32m'Tissue'[0m,
    [32m'target_gene'[0m,
    [32m'QueryFunction'[0m,
    [32m'QueryMutation'[0m,
    [32m'QueryType'[0m,
    [32m'modifier_gene'[0m,
    [32m'SuppressorFunction'[0m,
    [32m'SuppressorMutation'[0m,
    [32m'SNP_ID'[0m,
    [32m'SuppressorType'[0m,
    [32m'EffectSize'[0m,
    [32m'Disease'[0m,
    [32m'DiseaseSubType'[0m,
    [32m'CellLineIdentified'[0m,
    [32m'ModelSystemValidated'[0m,
    [32m'Drugs'[0m,
    [32m'target_GOs'[0m,
    [32m'target_GOs_count'[0m,
    [32m'modifier_GOs'[0m,
    [32m'modifier_GOs_count'[0m,
    [32m'target_POs'[0m,
    [32m'target_POs_count'[0m,
    [32m'modifier_POs'[0m,
    [32m'modifier_POs_count'[0m,
    [32m'target_DOs'[0m,
    [32m'target_DOs_count'[0m,
    [32m'modifier_DOs'[0m,
    [32m'modifier_DOs_count'[0m,
    [32m'source'[0m,
    [32m'is_modifier'[0m,
    [32m'Entry Id'[0m,
    [32m'Genes'[0m,
  

In [9]:
olida[['target_GOs','modifier_GOs']]

Unnamed: 0,target_GOs,modifier_GOs
1,"GO:0005080,GO:0005200,GO:0005516,GO:0007010,GO...","GO:0000977,GO:0003682,GO:0003712,GO:0003712,GO..."
3,"GO:0071407,GO:0004497,GO:0004497,GO:0004497,GO...","GO:0051897,GO:0070374,GO:0001725,GO:0005884,GO..."
4,"GO:0005243,GO:0005243,GO:0005509,GO:0005515,GO...","GO:0006883,GO:0004252,GO:0004252,GO:0017080,GO..."
6,"GO:0005515,GO:0005515,GO:0005515,GO:0005515,GO...","GO:0002153,GO:0002153,GO:0006357,GO:1990904,GO..."
10,"GO:0003779,GO:0005515,GO:0005515,GO:0005515,GO...","GO:0003723,GO:0005164,GO:0005515,GO:0005515,GO..."
...,...,...
129,"GO:0005524,GO:0016887,GO:0120020,GO:0005515,GO...","GO:0003723,GO:0004252,GO:0005515,GO:0005515,GO..."
130,"GO:0001540,GO:0001540,GO:0001540,GO:0005041,GO...","GO:0003723,GO:0004252,GO:0005515,GO:0005515,GO..."
135,"GO:0000122,GO:0010628,GO:0010628,GO:0010628,GO...","GO:0000976,GO:0003682,GO:0005515,GO:0005515,GO..."
136,"GO:0030509,GO:0005179,GO:0005515,GO:0005515,GO...","GO:0003677,GO:0003714,GO:0005515,GO:0005515,GO..."


In [10]:
# Reordering important columns
desired_first_columns = ['source', 'target_gene', 'modifier_gene', 'is_modifier'] 
remaining_columns = [col for col in dataset_df.columns if col not in desired_first_columns]
new_column_order = desired_first_columns + remaining_columns
dataset_df = dataset_df[new_column_order]
dataset_df.head()


Unnamed: 0,source,target_gene,modifier_gene,is_modifier,PubmedID,Category,Tissue,QueryFunction,QueryMutation,QueryType,...,wpo_resnik_scaled_bma,wpo_lin_max,wpo_lin_avg,wpo_lin_bma,wpo_jiang_max,wpo_jiang_avg,wpo_jiang_bma,wpo_jiang_seco_max,wpo_jiang_seco_avg,wpo_jiang_seco_bma
0,JVL,APOE,CASP7,1.0,27358062.0,Patients,Neuron,Lipid and sterol biosynthesis & transport,C112R/C112R,-,...,,,,,,,,,,
1,JVL,APOE,HBB,1.0,24116184.0,Patients,-,Lipid and sterol biosynthesis & transport,C112R/?,-,...,,,,,,,,,,
2,JVL,APOE,KL,1.0,30867273.0,Patients,Neuron,Lipid and sterol biosynthesis & transport,C112R/?,-,...,,,,,,,,,,
3,JVL,APOE,KL,1.0,32282020.0,Patients,Neuron,Lipid and sterol biosynthesis & transport,C112R/?,-,...,,,,,,,,,,
4,JVL,ATR,ETV1,1.0,23284306.0,Cells,-,DNA replication and repair;Signaling & stress ...,silencing/silencing,LOF,...,,,,,,,,,,


In [11]:
from sklearn.model_selection import train_test_split

# Splitting into train and test. Stratify by modifier effect
train_df, test_df = train_test_split(dataset_df, test_size=0.3, stratify=dataset_df['is_modifier'], random_state=42)


Loading Monarch KG

In [12]:
nodes_ddf = catalog.load('monarch_nodes_sample')

In [13]:
# List of genes from ModSpy dataset
all_genes = np.union1d(dataset_df['target_gene'], dataset_df['modifier_gene'])
all_genes.shape 

[1m([0m[1;36m5773[0m,[1m)[0m

In [14]:
ndf = nodes_ddf.dropna(subset=['symbol', 'in_taxon'], how='any')    # Dropping rows with missing values

In [15]:
# Select nodes (of Monarch KG) that are in ModeSpy dataset
# Human and Worm genes
selected_nodes = ndf[((ndf['in_taxon']=='NCBITaxon:9606')|(ndf['in_taxon']=='NCBITaxon:6239')) & (ndf['symbol'].isin(all_genes))] 

In [16]:
sel_nodes = selected_nodes.compute()

In [17]:
sel_nodes.shape

[1m([0m[1;36m460[0m, [1;36m16[0m[1m)[0m

In [18]:
sel_nodes['in_taxon_label'].value_counts()


in_taxon_label
Caenorhabditis elegans    [1;36m236[0m
Homo sapiens              [1;36m224[0m
Name: count, dtype: int64[1m[[0mpyarrow[1m][0m

In [19]:
nodes = catalog.load('monarch_nodes_sample').compute()
edges = catalog.load('monarch_edges_sample').compute()

In [20]:
print(nodes.columns)
print(edges.columns)

In [21]:
# Create a directed multigraph
G = nx.MultiDiGraph()

# Add nodes to the graph
for _, row in nodes.iterrows():
    node_id = row['id_encoded'] # could reliably use any `id_encoded` column
    node_attr = row.to_dict()
    # node_attr = row[['category']].to_dict()
    G.add_node(node_id, **node_attr)

# Add edges to the graph
for _, row in edges.iterrows():
    subject = row['subject_encoded']    # could reliably use any `subject_encoded` column
    object_ = row['object_encoded']    # could reliably use any `object_encoded` column
    edge_attr = row.to_dict()
    # edge_attr = row[['category']].to_dict()
    G.add_edge(subject, object_, **edge_attr)

# # Information about the created graph
print(G.number_of_nodes())
print(G.number_of_edges())
# graph_info = nx.info(G)
# graph_info


In [22]:
g_degree = np.array(G.degree())

In [23]:
# Assume your_array is a 2D array and you want to sort by the second dimension (i.e., degree count)
sorted_indices = np.argsort(g_degree[:, 1])
sorted_array = g_degree[sorted_indices]
print(sorted_array)

In [24]:
# Considering nodes with 10 or more edges
high_degree = sorted_array[(sorted_array[:, 1] > 10),:] 
print(f"Number of nodes with 10 or more edges: {high_degree.shape[0]}")

# indentifying the nodes with high degree
hd_nodes = nodes[nodes['id_encoded'].isin(high_degree[:,0])]
display("## High degree nodes ##")
display(hd_nodes[['name', 'category']])

[32m'## High degree nodes ##'[0m

Unnamed: 0,name,category
13155,TP53,biolink:Gene
13568,embryo,biolink:GrossAnatomicalStructure
13581,heart,biolink:GrossAnatomicalStructure
13628,lung,biolink:GrossAnatomicalStructure
13637,liver,biolink:GrossAnatomicalStructure
13642,brain,biolink:GrossAnatomicalStructure
13647,telencephalon,biolink:GrossAnatomicalStructure
13649,hindbrain,biolink:GrossAnatomicalStructure
13709,spinal cord,biolink:GrossAnatomicalStructure
13729,cerebellum,biolink:GrossAnatomicalStructure


In [25]:
# Considering nodes with single edge
low_degree = sorted_array[(sorted_array[:, 1] == 1),:] 
print(f"Number of nodes with single edge: {low_degree.shape[0]}")

# indentifying the nodes with low degree
ld_nodes = nodes[nodes['id_encoded'].isin(low_degree[:,0])]
display("## Low degree nodes ##")
display(ld_nodes[['name', 'category']]['category'].value_counts())

[32m'## Low degree nodes ##'[0m


category
biolink:Gene                           [1;36m11736[0m
biolink:PhenotypicFeature               [1;36m1304[0m
biolink:BiologicalProcessOrActivity     [1;36m1003[0m
biolink:GrossAnatomicalStructure         [1;36m441[0m
biolink:Disease                          [1;36m314[0m
biolink:Pathway                          [1;36m286[0m
biolink:Cell                             [1;36m203[0m
biolink:AnatomicalEntity                 [1;36m154[0m
biolink:CellularComponent                 [1;36m98[0m
biolink:MacromolecularComplex             [1;36m79[0m
biolink:NamedThing                        [1;36m63[0m
biolink:MolecularEntity                   [1;36m44[0m
biolink:Protein                            [1;36m8[0m
biolink:PhenotypicQuality                  [1;36m5[0m
biolink:ChemicalEntity                     [1;36m2[0m
biolink:LifeStage                          [1;36m2[0m
biolink:Virus                              [1;36m2[0m
biolink:CellularOrganism              

Calculating paths between nodes

In [28]:
gene_node_ids = nodes[nodes['category']=='biolink:Gene']['id_encoded'].values
gene_node_ids.shape

[1m([0m[1;36m13224[0m,[1m)[0m

In [29]:
# TODO - Calculate for all genes in the dataset
gene_pairs = dataset_df[dataset_df['source']=='MTG'][['modifier_gene', 'target_gene']]
nodes[nodes['name'].isin(gene_pairs['modifier_gene'])]

Unnamed: 0,id,category,name,xref,provided_by,synonym,full_name,in_taxon,in_taxon_label,symbol,description,deprecated,iri,same_as,id_tensor,id_encoded
2090,WB:WBGene00000037,biolink:Gene,ace-3,ENSEMBL:WBGene00000037|PANTHER:PTHR43918|NCBIG...,alliance_gene_nodes,CELE_Y48B6A.8|Y48B6A.8|cest-21,abnormal AcetylCholinEsterase 3,NCBITaxon:6239,Caenorhabditis elegans,ace-3,,,,,14546,14546
2092,WB:WBGene00000048,biolink:Gene,acr-9,ENSEMBL:WBGene00000048|PANTHER:PTHR18945|NCBIG...,alliance_gene_nodes,C40C9.2|CELE_C40C9.2,AcetylCholine Receptor 9,NCBITaxon:6239,Caenorhabditis elegans,acr-9,,,,,14548,14548
2097,WB:WBGene00000102,biolink:Gene,akt-1,ENSEMBL:WBGene00000102|PANTHER:PTHR24356|NCBIG...,alliance_gene_nodes,C12D8.10|CELE_C12D8.10|akt-1a|akt-1b,AKT kinase family 1,NCBITaxon:6239,Caenorhabditis elegans,akt-1,,,,,14553,14553
2099,WB:WBGene00000158,biolink:Gene,apg-1,ENSEMBL:WBGene00000158|PANTHER:PTHR22780|NCBIG...,alliance_gene_nodes,CELE_Y105E8A.9|Y105E8A.9|Y105E8E.j|Y105E8E.k|a...,"AdaPtin, Gamma chain (clathrin associated comp...",NCBITaxon:6239,Caenorhabditis elegans,apg-1,,,,,14555,14555
2103,WB:WBGene00000210,biolink:Gene,asg-2,ENSEMBL:WBGene00000210|PANTHER:PTHR12386|NCBIG...,alliance_gene_nodes,C53B7.4|CELE_C53B7.4,ATP Synthase G homolog 2,NCBITaxon:6239,Caenorhabditis elegans,asg-2,,,,,14559,14559
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2876,WB:WBGene00022373,biolink:Gene,mrpl-15,ENSEMBL:WBGene00022373|PANTHER:PTHR12934|NCBIG...,alliance_gene_nodes,CELE_Y92H12BR.8|Y92H12BR.8,"Mitochondrial Ribosomal Protein, Large 15",NCBITaxon:6239,Caenorhabditis elegans,mrpl-15,,,,,15338,15338
2884,WB:WBGene00022739,biolink:Gene,toe-1,ENSEMBL:WBGene00022739|PANTHER:PTHR13457|NCBIG...,alliance_gene_nodes,CELE_ZK430.1|ZK430.1,Target Of ERK kinase MPK-1 1,NCBITaxon:6239,Caenorhabditis elegans,toe-1,,,,,15346,15346
2885,WB:WBGene00022781,biolink:Gene,pmt-1,ENSEMBL:WBGene00022781|PANTHER:PTHR44307|NCBIG...,alliance_gene_nodes,CELE_ZK622.3|ZK622.3|phi-40,Phosphoethanolamine MethylTransferase 1,NCBITaxon:6239,Caenorhabditis elegans,pmt-1,,,,,15347,15347
2887,WB:WBGene00044068,biolink:Gene,syd-9,ENSEMBL:WBGene00044068|PANTHER:PTHR24409|NCBIG...,alliance_gene_nodes,CELE_ZK867.1|ZK867.1|tag-239|ztf-10,SYnapse Defective 9,NCBITaxon:6239,Caenorhabditis elegans,syd-9,,,,,15349,15349


In [30]:
pairs = list(itertools.combinations(gene_node_ids, 2))
print(f"Number of pairs: {len(pairs)}")
print(pairs[:5])

In [29]:
_s_pairs = random.sample(pairs, 10000)
len(_s_pairs)

[1;36m10000[0m

In [None]:
import dask.array as da

pairs_arr = da.from_array(_s_pairs, chunks=(10000, 2))  # Divide into chunks

In [30]:
client = Client()  # start distributed scheduler locally.

In [31]:
@dask.delayed
def find_paths(pair):
    if (nx.has_path(G, pair[0], pair[1])):
        all_paths = nx.all_simple_paths(G, source=pair[0], target=pair[1])
        return (pair[0], pair[1], all_paths)
    return None

# Create a list of delayed computations
tasks = [find_paths(p) for p in _s_pairs]

# Compute all results in parallel
saved_path = dask.compute(*tasks)

In [30]:
for p in _s_pairs[:10]:
    # print(p)
    # print(G.degree(p[0]))
    # print(G.degree(p[1]))
    if (nx.has_path(G, p[0], p[1])):
        print(nx.shortest_path(G, p[0], p[1]))

In [131]:
saved_path = []
for p in tqdm(pairs):
    all_paths = nx.all_simple_paths(G, source=p[0], target=p[1])
    saved_path.append((p[0], p[1], all_paths))

  0%|          | 0/87430476 [00:00<?, ?it/s]

In [135]:
len(saved_path)

[1;36m45771532[0m