In [1]:
import numpy as np
import pandas as pd
import torch
import optuna
from comet_ml import Experiment
from pytorch_lightning.loggers import CometLogger
from torch_geometric.transforms import AddSelfLoops
import wandb


import dask
import dask.dataframe as dd



In [2]:
#wandb.init(mode="offline")

In [3]:
monarch = torch.load('./data/05_model_input/2024-02-monarch_heterodata_v1.pt', map_location='cpu')

print(
    f""" 
Total nodes: {monarch.num_nodes}
Total node types: {len(monarch.node_types)}

Total edges: {monarch.num_edges}
Total edge types: {len(monarch.edge_types)}                
"""
)

# # Adding self loops to avoid 1. nodes without any edge, 2. consider intragenic modifier
# transform = AddSelfLoops()
# monarch = transform(monarch)

 
Total nodes: 862115
Total node types: 88

Total edges: 11412471
Total edge types: 289                



In [4]:
jvl = pd.read_csv('./data/02_intermediate/jvl/annotated.csv', sep='\t')
olida = pd.read_csv('./data/02_intermediate/olida/annotated.tsv', sep='\t')
mtg = pd.read_csv('./data/02_intermediate/mtg/annotated.csv', sep='\t')

In [5]:
olida = olida[olida['Oligogenic Effect']!='Monogenic+Modifier']

In [6]:
print(olida['Oligogenic Effect'].value_counts())
print(olida.shape)

Oligogenic Effect
True Digenic                47
Dual Molecular Diagnosis    28
Name: count, dtype: int64
(75, 47)


In [7]:
# Adding source column
jvl['source'] = 'JVL'
olida['source'] = 'OLIDA'
mtg['source'] = 'MTG'

# Add species information
jvl['species'] = 'Homo sapiens'
olida['species'] = 'Homo sapiens'
mtg['species'] = 'Caenorhabditis elegans'

# Classification labels
jvl['is_modifier'] = 1
olida['is_modifier'] = 0

# Rename columns to indicate same information
jvl.rename(columns={'QueryGene': 'target_gene', 'SuppressorGene': 'modifier_gene'}, inplace=True)
olida.rename(columns={'gene_a': 'target_gene', 'gene_b': 'modifier_gene'}, inplace=True)
mtg.rename(columns={'gene_symbol': 'modifier_gene', 'target_gene_symbol': 'target_gene'}, inplace=True)


In [8]:
dataset_df = pd.concat([jvl, olida, mtg])
print(dataset_df.shape)
print(dataset_df['source'].value_counts())

(8337, 219)
source
MTG      7330
JVL       932
OLIDA      75
Name: count, dtype: int64


In [9]:
list(dataset_df.columns)

['PubmedID',
 'Category',
 'Tissue',
 'target_gene',
 'QueryFunction',
 'QueryMutation',
 'QueryType',
 'modifier_gene',
 'SuppressorFunction',
 'SuppressorMutation',
 'SNP_ID',
 'SuppressorType',
 'EffectSize',
 'Disease',
 'DiseaseSubType',
 'CellLineIdentified',
 'ModelSystemValidated',
 'Drugs',
 'target_GOs',
 'target_GOs_count',
 'modifier_GOs',
 'modifier_GOs_count',
 'target_POs',
 'target_POs_count',
 'modifier_POs',
 'modifier_POs_count',
 'target_DOs',
 'target_DOs_count',
 'modifier_DOs',
 'modifier_DOs_count',
 'source',
 'species',
 'is_modifier',
 'Entry Id',
 'Genes',
 'Genes Relationship',
 'Protein Interactions',
 'Common Pathways',
 'GENEmeta_x',
 'Oligogenic variant combinations',
 'olida_id',
 'OLIDA ID',
 'Omim Id',
 'Diseases',
 'Oligogenic Effect',
 'Ethnicity',
 'References',
 'Associated Variants',
 'FAMmanual',
 'STATmanual',
 'STATknowledge',
 'STATmeta',
 'GENEmanual',
 'GENEmanual harmonized',
 'GENEknowledge',
 'GENEmeta_y',
 'VARmanual',
 'VARknowledge',

In [10]:
olida[['target_GOs','modifier_GOs']]

Unnamed: 0,target_GOs,modifier_GOs
1,"GO:0005080,GO:0005200,GO:0005516,GO:0007010,GO...","GO:0000977,GO:0003682,GO:0003712,GO:0003712,GO..."
3,"GO:0071407,GO:0004497,GO:0004497,GO:0004497,GO...","GO:0051897,GO:0070374,GO:0001725,GO:0005884,GO..."
4,"GO:0005243,GO:0005243,GO:0005509,GO:0005515,GO...","GO:0006883,GO:0004252,GO:0004252,GO:0017080,GO..."
6,"GO:0005515,GO:0005515,GO:0005515,GO:0005515,GO...","GO:0002153,GO:0002153,GO:0006357,GO:1990904,GO..."
10,"GO:0003779,GO:0005515,GO:0005515,GO:0005515,GO...","GO:0003723,GO:0005164,GO:0005515,GO:0005515,GO..."
...,...,...
129,"GO:0005524,GO:0016887,GO:0120020,GO:0005515,GO...","GO:0003723,GO:0004252,GO:0005515,GO:0005515,GO..."
130,"GO:0001540,GO:0001540,GO:0001540,GO:0005041,GO...","GO:0003723,GO:0004252,GO:0005515,GO:0005515,GO..."
135,"GO:0000122,GO:0010628,GO:0010628,GO:0010628,GO...","GO:0000976,GO:0003682,GO:0005515,GO:0005515,GO..."
136,"GO:0030509,GO:0005179,GO:0005515,GO:0005515,GO...","GO:0003677,GO:0003714,GO:0005515,GO:0005515,GO..."


In [11]:
# Reordering important columns
desired_first_columns = ['source', 'species', 'target_gene', 'modifier_gene', 'is_modifier'] 
remaining_columns = [col for col in dataset_df.columns if col not in desired_first_columns]
new_column_order = desired_first_columns + remaining_columns
dataset_df = dataset_df[new_column_order]
dataset_df.head()


Unnamed: 0,source,species,target_gene,modifier_gene,is_modifier,PubmedID,Category,Tissue,QueryFunction,QueryMutation,...,wpo_resnik_scaled_bma,wpo_lin_max,wpo_lin_avg,wpo_lin_bma,wpo_jiang_max,wpo_jiang_avg,wpo_jiang_bma,wpo_jiang_seco_max,wpo_jiang_seco_avg,wpo_jiang_seco_bma
0,JVL,Homo sapiens,APOE,CASP7,1.0,27358062.0,Patients,Neuron,Lipid and sterol biosynthesis & transport,C112R/C112R,...,,,,,,,,,,
1,JVL,Homo sapiens,APOE,HBB,1.0,24116184.0,Patients,-,Lipid and sterol biosynthesis & transport,C112R/?,...,,,,,,,,,,
2,JVL,Homo sapiens,APOE,KL,1.0,30867273.0,Patients,Neuron,Lipid and sterol biosynthesis & transport,C112R/?,...,,,,,,,,,,
3,JVL,Homo sapiens,APOE,KL,1.0,32282020.0,Patients,Neuron,Lipid and sterol biosynthesis & transport,C112R/?,...,,,,,,,,,,
4,JVL,Homo sapiens,ATR,ETV1,1.0,23284306.0,Cells,-,DNA replication and repair;Signaling & stress ...,silencing/silencing,...,,,,,,,,,,


In [12]:
nodes_df = dd.read_parquet('./data/02_intermediate/monarch/nodes_with_type_idx')  
edges_df = dd.read_parquet('./data/02_intermediate/monarch/edges_pre_df_reduction_v2')


In [13]:
print(nodes_df.index.shape[0].compute())
display(nodes_df.head())

# Use dataset's source and target gene id to pull the `type_index` column From `nodes_df``

862115


Unnamed: 0_level_0,category,name,in_taxon,in_taxon_label,symbol,type_index
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
APO:0000017,biolink:PhenotypicFeature,,,,,0
BFO:0000001,biolink:NamedThing,entity,,,,0
BFO:0000002,biolink:NamedThing,continuant,,,,1
BFO:0000003,biolink:BiologicalProcessOrActivity,occurrent,,,,0
BFO:0000004,biolink:NamedThing,independent continuant,,,,2


In [14]:
nodes_df[(nodes_df['category']=='biolink:Gene') & (nodes_df['symbol']=='APOE') & (nodes_df['in_taxon_label']=='Homo sapiens')].compute()

Unnamed: 0_level_0,category,name,in_taxon,in_taxon_label,symbol,type_index
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
HGNC:613,biolink:Gene,APOE,NCBITaxon:9606,Homo sapiens,APOE,516790


In [15]:
nodes_df['in_taxon_label'].unique().compute()

0                              <NA>
1                      Homo sapiens
2                      Mus musculus
3                     Gallus gallus
4                        Bos taurus
5                 Rattus norvegicus
0           Drosophila melanogaster
1                        Sus scrofa
2            Canis lupus familiaris
3          Dictyostelium discoideum
4         Schizosaccharomyces pombe
5    Saccharomyces cerevisiae S288C
6            Caenorhabditis elegans
7                    Xenopus laevis
8                Xenopus tropicalis
9                       Danio rerio
Name: in_taxon_label, dtype: string

### Merging Node Index from Monarch

In [16]:
nodes_pdf = nodes_df.compute()
_dataset = dataset_df.iloc[:,:5].merge(nodes_pdf, how='left', left_on=['species','target_gene'], right_on=['in_taxon_label','symbol'])
_dataset.rename(columns={'type_index': 'target_type_index'}, inplace=True)
_dataset.drop(columns=['category','name','in_taxon_label','in_taxon','symbol'], inplace=True)
dataset_w_ninfo = _dataset.merge(nodes_pdf, how='left', left_on=['species','modifier_gene'], right_on=['in_taxon_label','symbol'])
dataset_w_ninfo.rename(columns={'type_index': 'modifier_type_index'}, inplace=True)
dataset_w_ninfo.drop(columns=['name','in_taxon_label','symbol'], inplace=True)
del(_dataset)
del(nodes_pdf)
dataset_w_ninfo

Unnamed: 0,source,species,target_gene,modifier_gene,is_modifier,target_type_index,category,in_taxon,modifier_type_index
0,JVL,Homo sapiens,APOE,CASP7,1.0,516790,biolink:Gene,NCBITaxon:9606,518908.0
1,JVL,Homo sapiens,APOE,HBB,1.0,516790,biolink:Gene,NCBITaxon:9606,526845.0
2,JVL,Homo sapiens,APOE,KL,1.0,516790,biolink:Gene,NCBITaxon:9606,529553.0
3,JVL,Homo sapiens,APOE,KL,1.0,516790,biolink:Gene,NCBITaxon:9606,529553.0
4,JVL,Homo sapiens,ATR,ETV1,1.0,517533,biolink:Gene,NCBITaxon:9606,523802.0
...,...,...,...,...,...,...,...,...,...
8332,MTG,Caenorhabditis elegans,zyg-1,C39D10.7,0.0,86692,biolink:Gene,NCBITaxon:6239,94433.0
8333,MTG,Caenorhabditis elegans,zyg-1,catp-5,0.0,86692,biolink:Gene,NCBITaxon:6239,96963.0
8334,MTG,Caenorhabditis elegans,zyg-1,tmc-2,0.0,86692,biolink:Gene,NCBITaxon:6239,93262.0
8335,MTG,Caenorhabditis elegans,zyg-1,tbb-4,0.0,86692,biolink:Gene,NCBITaxon:6239,86304.0


### Did we detect all IDs for the genes?

In [17]:
print(dataset_w_ninfo['target_type_index'].unique().shape == dataset_w_ninfo['target_gene'].unique().shape)
print(dataset_w_ninfo['modifier_type_index'].unique().shape == dataset_w_ninfo['modifier_gene'].unique().shape)

True
False


In [18]:
print(f"{dataset_w_ninfo['modifier_gene'].unique().shape[0] - dataset_w_ninfo['modifier_type_index'].unique().shape[0]} genes are missing/not recognized from Monarch")

253 genes are missing/not recognized from Monarch


In [19]:
print(f"Unrecognized Modifiers:")
dataset_w_ninfo[(dataset_w_ninfo['modifier_type_index'].isna() & (dataset_w_ninfo['is_modifier']==1))]

Unrecognized Modifiers:


Unnamed: 0,source,species,target_gene,modifier_gene,is_modifier,target_type_index,category,in_taxon,modifier_type_index
1068,MTG,Caenorhabditis elegans,zyg-1,mat-2,1.0,86692,,,
2976,MTG,Caenorhabditis elegans,zyg-1,apc-17,1.0,86692,,,
4513,MTG,Caenorhabditis elegans,zyg-1,mat-2,1.0,86692,,,
6979,MTG,Caenorhabditis elegans,zyg-1,mat-2,1.0,86692,,,


In [20]:
print(f"Missing indecies")
dataset_w_ninfo[(dataset_w_ninfo['modifier_type_index'].isna()) | (dataset_w_ninfo['target_type_index'].isna())]

Missing indecies


Unnamed: 0,source,species,target_gene,modifier_gene,is_modifier,target_type_index,category,in_taxon,modifier_type_index
1018,MTG,Caenorhabditis elegans,zyg-1,T20B6.3,0.0,86692,,,
1037,MTG,Caenorhabditis elegans,zyg-1,Y51A2D.7,0.0,86692,,,
1068,MTG,Caenorhabditis elegans,zyg-1,mat-2,1.0,86692,,,
1086,MTG,Caenorhabditis elegans,zyg-1,Y71H2AM.13,0.0,86692,,,
1117,MTG,Caenorhabditis elegans,zyg-1,C45E5.4,0.0,86692,,,
...,...,...,...,...,...,...,...,...,...
8255,MTG,Caenorhabditis elegans,zyg-1,F36H2.3,0.0,86692,,,
8269,MTG,Caenorhabditis elegans,zyg-1,BE0003N10.3,0.0,86692,,,
8321,MTG,Caenorhabditis elegans,zyg-1,F59D12.1,0.0,86692,,,
8324,MTG,Caenorhabditis elegans,zyg-1,R11G1.6,0.0,86692,,,


In [21]:
# Dropping these NA rows
dataset_w_ninfo.dropna(subset=['in_taxon','modifier_type_index'], inplace=True)

print(f"Do we now have proper indecies?")
print(f"Target genes: {dataset_w_ninfo['target_type_index'].unique().shape == dataset_w_ninfo['target_gene'].unique().shape}")
print(f"Modifier genes: {dataset_w_ninfo['modifier_type_index'].unique().shape == dataset_w_ninfo['modifier_gene'].unique().shape}")

Do we now have proper indecies?
Target genes: True
Modifier genes: True


### Adding target-modifier pairs as edges of Monarch

In [22]:
dataset_w_ninfo['modifier_type_index'] = dataset_w_ninfo['modifier_type_index'].astype(int)
dataset_w_ninfo[['target_type_index','modifier_type_index','is_modifier']]

Unnamed: 0,target_type_index,modifier_type_index,is_modifier
0,516790,518908,1.0
1,516790,526845,1.0
2,516790,529553,1.0
3,516790,529553,1.0
4,517533,523802,1.0
...,...,...,...
8332,86692,94433,0.0
8333,86692,96963,0.0
8334,86692,93262,0.0
8335,86692,86304,0.0


In [54]:
# Convert DataFrame to tensor
dataset_arr = dataset_w_ninfo[['target_type_index','modifier_type_index','is_modifier']].to_numpy()
data_t = torch.from_numpy(dataset_arr).to(torch.int)
data_t

tensor([[516790, 518908,      1],
        [516790, 526845,      1],
        [516790, 529553,      1],
        ...,
        [ 86692,  93262,      0],
        [ 86692,  86304,      0],
        [ 86692,  89182,      0]], dtype=torch.int32)

In [59]:
# torch.save(data_t, './data/05_model_input/2024-03-31-merged-dataset.pt')

In [56]:
from torch_geometric.data import HeteroData
from tqdm import tqdm

def verify_heterodata_construction(data: HeteroData, edges_ddf, node_ids):
    edge_type_to_chk = 'biolink:interacts_with'
    
    # Access the edge index for the relation
    edge_index = monarch['biolink:Gene', edge_type_to_chk, 'biolink:Gene'].edge_index    
    src, dest = edge_index
    
    for node_idx in tqdm(node_ids):
        graph_in_count = src[dest == node_idx].shape[0]
        graph_out_count = dest[src == node_idx].shape[0]
        
        orig_in_count = edges_df[(edges_df['object_id']==node_idx) & (edges_df['predicate']==edge_type_to_chk)]['id'].compute().shape[0]
        orig_out_count = edges_df[(edges_df['subject_id']==node_idx) & (edges_df['predicate']==edge_type_to_chk)]['id'].compute().shape[0]


        try:
            assert graph_in_count == orig_in_count
            assert graph_out_count == orig_out_count
        except AssertionError as e:
            print(f"AssertionError: {e}")
            print(f"{graph_in_count} != {orig_in_count}")
            print('or')
            print(f"{graph_out_count} != {orig_out_count}")

        # # Print results
        # print("Outgoing edges for node", node_idx, " (", outgoing_edges.shape[0] ,")" ":", outgoing_edges.tolist())
        # print("Incoming edges for node", node_idx, " (", incoming_edges.shape[0] ,")" ":", incoming_edges.tolist())
        

In [57]:
node_ids = list(set(data_t[:,0].tolist()))
print(len(node_ids))

153


In [58]:
verify_heterodata_construction(monarch, edges_df, node_ids)
print(f"Succesfully verified {len(node_ids)} genes in the network!")

 14%|█▍        | 22/153 [01:58<11:46,  5.40s/it]


KeyboardInterrupt: 

In [60]:

from torch.utils.data import Dataset, DataLoader

class ModifierDataset(Dataset):
    def __init__(self, filepath: str = None):
        """
        Args:
            data (Tensor): A tensor containing node pairs and their similarity label.
                           Shape: [num_pairs, 3], where each row is (node1, node2, label).
        """
        self.data = torch.load(filepath)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        modifier, target, label = self.data[idx]
        return modifier, target, label

In [61]:
dataset = ModifierDataset("/home/rahit/projects/def-mtarailo/rahit/from_scratch/modspy-data/data/05_model_input/2024-03-31-merged-dataset.pt")
val_loader = DataLoader(dataset, batch_size=len(dataset), shuffle=False)

In [62]:
next(iter(dataset))

(tensor(516790, dtype=torch.int32),
 tensor(518908, dtype=torch.int32),
 tensor(1, dtype=torch.int32))

In [63]:
data = torch.load("/home/rahit/projects/def-mtarailo/rahit/from_scratch/modspy-data/data/05_model_input/2024-03-31-merged-dataset.pt")
np.unique(data.numpy())

array([     0,      1,  81868, ..., 558429, 558431, 558747], dtype=int32)