In [2]:
## Import standard libraries
import os
import sys
from tqdm import tqdm, trange
from glob import glob
import pandas as pd
import random
import logging
import argparse

## Import custom libraries
sys.path.append('/scratch/shared_data_new/metagenomics_graph/model_training/scripts')
from utils import get_logger, read_tsv_file, set_seed

In [3]:
class Args():
    pass

args = Args()


args.existing_KG_nodes = '/scratch/shared_data_new/metagenomics_graph/KG/KG_nodes_v6.tsv'
args.existing_KG_edges = '/scratch/shared_data_new/metagenomics_graph/KG/KG_edges_v6.tsv'
args.micobial_hierarchy_dir = '/scratch/shared_data_new/metagenomics_graph/Micobial_hierarchy'
args.output_dir = '/scratch/shared_data_new/metagenomics_graph/model_training/data'
args.random_seed = 100

In [4]:
def find_specific_parent(microbe_id, ranks, microbe_info, find_parent):
    """
    Find the specific parent of a microbe
    """
    
    def _recursive_find_parent(microbe_id, ranks, microbe_info, find_parent, res, count=0):
        """
        iterate all the way up to the top of the hierarchy and find all parents belong to the specific ranks
        """
        if count == 8:
            return res
        
        count += 1
        if microbe_id not in find_parent:
            return res
        elif len(microbe_info[microbe_id][0].intersection(set(ranks))) > 0:
            res += [microbe_id]
            return _recursive_find_parent(find_parent[microbe_id], ranks, microbe_info, find_parent, res, count)
        else:
            return _recursive_find_parent(find_parent[microbe_id], ranks, microbe_info, find_parent, res, count)
    
    res = _recursive_find_parent(microbe_id, ranks, microbe_info, find_parent, [], 0)
    if len(res) == 0:
        return [microbe_id]
    else:
        return res + [microbe_id]

def check_parent(microbe_id, excluded_microbe_ids, find_parent, count=0):
    """
    Check if the parent of a microbe is in the excluded microbe ids
    """
    if count == 8:
        return True
    
    if microbe_id in excluded_microbe_ids:
        return False
    
    if microbe_id not in find_parent:
        return True
    else:
        count += 1
        return check_parent(find_parent[microbe_id], excluded_microbe_ids, find_parent, count)

def extract_info(sysnonyms, microbe_info_dict, find_parent):
    """
    Get the rank and parent of a microbe based on its id
    """
    if isinstance(sysnonyms, str):
        sysnonyms = eval(sysnonyms)

    rank_list = []
    taxon_id_list = []
    parent_list = []
    for synonym in sysnonyms:
        prefix = synonym.split(':')[0]
        if prefix not in ['NCBI', 'GTDB']:
            continue
        if synonym not in microbe_info_dict:
            continue
        rank_list += list(microbe_info_dict[synonym][0])
        taxon_id_list += list(microbe_info_dict[synonym][1])
        parent_list.append(find_parent[synonym])

    parent_rank_list = []
    parent_taxon_id_list = []
    for parent in parent_list:
        parent_rank_list += list(microbe_info_dict[parent][0])
        parent_taxon_id_list += list(microbe_info_dict[parent][1])
    
    sysnonyms = str(sysnonyms)
    rank_list = str(rank_list)
    taxon_id_list = str(taxon_id_list)
    parent_list = str(parent_list)
    parent_rank_list = str(parent_rank_list)
    parent_taxon_id_list = str(parent_taxon_id_list)
    
    return sysnonyms, rank_list, taxon_id_list, parent_list, parent_rank_list, parent_taxon_id_list


In [5]:
# Create a logger object
logger = get_logger()
logger.setLevel(logging.DEBUG)

# set the seed for reproducibility
set_seed(args.random_seed)

# Load existing knowledge graph nodes and edges
logger.info("Loading existing knowledge graph nodes and edges...")
KG_nodes_data = read_tsv_file(args.existing_KG_nodes)
KG_nodes_header = KG_nodes_data[0]
KG_edges_data = read_tsv_file(args.existing_KG_edges)
KG_edges_header = KG_edges_data[0]

2024-01-10 11:52:21  [INFO]  Loading existing knowledge graph nodes and edges...


In [6]:
# Load microbial hierarchy data
find_parent = dict()
find_child = dict()
microbe_info_dict = dict()
microbial_hierarchy_files = os.listdir(args.micobial_hierarchy_dir)
for microbial_hierarchy_file in tqdm(microbial_hierarchy_files, desc='Reading microbial hierarchy files'):
    if microbial_hierarchy_file.split('_')[0] in ['viruses', 'fungi']:
        prefix = 'NCBI:'
    else:
        prefix = 'GTDB:'
    microbial_hierarchy_file_path = os.path.join(args.micobial_hierarchy_dir, microbial_hierarchy_file)
    logger.info("Reading {}...".format(microbial_hierarchy_file_path))
    temp_data = read_tsv_file(microbial_hierarchy_file_path)
    for row in temp_data[1:]:
        row[0] = row[0].replace('GB_', '').replace('RS_', '')
        row[3] = row[3].replace('GB_', '').replace('RS_', '')
        if f"{prefix}{row[0]}" not in microbe_info_dict:
            microbe_info_dict[f"{prefix}{row[0]}"] = [set(), set()]
        microbe_info_dict[f"{prefix}{row[0]}"][0].add(row[1])
        microbe_info_dict[f"{prefix}{row[0]}"][1].add(row[2])              
        if f"{prefix}{row[3]}" not in microbe_info_dict:
            microbe_info_dict[f"{prefix}{row[3]}"] = [set(), set()]
        microbe_info_dict[f"{prefix}{row[3]}"][0].add(row[4])
        microbe_info_dict[f"{prefix}{row[3]}"][1].add(row[5])
        if f"{prefix}{row[3]}" != f"{prefix}{row[0]}":
            find_parent[f"{prefix}{row[3]}"] = f"{prefix}{row[0]}"
            if f"{prefix}{row[0]}" not in find_child:
                find_child[f"{prefix}{row[0]}"] = set()
            find_child[f"{prefix}{row[0]}"].add(f"{prefix}{row[3]}")

## Extract data from MKG for GNN model training 
logger.info("Extracting MKG for GNN model training...")
node_info = pd.DataFrame([(row[KG_nodes_header.index('node_id')], row[KG_nodes_header.index('node_type')], \
                            row[KG_nodes_header.index('all_names')], row[KG_nodes_header.index('synonyms')], row[KG_nodes_header.index('is_pathogen')]) \
                            for row in KG_nodes_data[1:]])
node_info.columns = ['node_id', 'node_type', 'all_names', 'synonyms', 'is_pathogen']
edge_info = pd.DataFrame([(row[KG_edges_header.index('source_node')], row[KG_edges_header.index('target_node')], row[KG_edges_header.index('predicate')]) for row in KG_edges_data[1:]])
edge_info.columns = ['source_node', 'target_node', 'predicate']

# Filter out the microbe-disease edges
edge_info_temp = []
filtered_type = ['Microbe', 'Disease']
for row in tqdm(edge_info.to_numpy(), desc="Filtering out the microbe-disease edges"):
    source_node, target_node, predicate = row
    source_type = source_node.split(':')[0]
    target_type = target_node.split(':')[0]
    if source_type == target_type:
        edge_info_temp.append(row)
    else:
        if source_type in filtered_type and target_type in filtered_type:
            continue
        else:
            edge_info_temp.append(row)
edge_info = pd.DataFrame(edge_info_temp)
edge_info.columns = ['source_node', 'target_node', 'predicate']

Reading microbial hierarchy files:   0%|          | 0/4 [00:00<?, ?it/s]

2024-01-10 11:53:42  [INFO]  Reading /scratch/shared_data_new/metagenomics_graph/Micobial_hierarchy/viruses_hierarchy.tsv...


Reading microbial hierarchy files:  25%|██▌       | 1/4 [00:00<00:02,  1.11it/s]

2024-01-10 11:53:43  [INFO]  Reading /scratch/shared_data_new/metagenomics_graph/Micobial_hierarchy/fungi_hierarchy.tsv...


Reading microbial hierarchy files:  50%|█████     | 2/4 [00:10<00:11,  5.95s/it]

2024-01-10 11:53:52  [INFO]  Reading /scratch/shared_data_new/metagenomics_graph/Micobial_hierarchy/bacteria_hierarchy.tsv...


Reading microbial hierarchy files:  75%|███████▌  | 3/4 [00:12<00:04,  4.16s/it]

2024-01-10 11:53:54  [INFO]  Reading /scratch/shared_data_new/metagenomics_graph/Micobial_hierarchy/archaea_hierarchy.tsv...


Reading microbial hierarchy files: 100%|██████████| 4/4 [00:12<00:00,  3.13s/it]

2024-01-10 11:53:54  [INFO]  Extracting MKG for GNN model training...



Filtering out the microbe-disease edges: 100%|██████████| 46218489/46218489 [00:58<00:00, 790896.98it/s]


In [7]:
## Covert node and edge to index
# save node to index mapping
all_nodes = node_info['node_id'].to_list()
random.shuffle(all_nodes)
node_to_index = pd.DataFrame(all_nodes).reset_index()[[0, 'index']]
node_to_index.columns = ['node_id', 'node_index']
node_to_index.to_csv(os.path.join(args.output_dir, 'node_to_index.tsv'), sep='\t', index=False)
node_info.to_csv(os.path.join(args.output_dir, 'node_info.tsv'), sep='\t', index=False)

In [8]:
# save pathogen info
pathogen_info = node_info.loc[node_info['is_pathogen']=='True', :].reset_index(drop=True)
# find pathogen rank and its parent info
part1 = pathogen_info[['node_id', 'node_type', 'all_names', 'is_pathogen']]
part2 = pd.DataFrame(pathogen_info['synonyms'].apply(lambda x: extract_info(x, microbe_info_dict, find_parent)).to_list(), columns=['synonyms', 'rank_list', 'taxon_id_list', 'parent_list', 'parent_rank_list', 'parent_taxon_id_list'])
pathogen_info = pd.concat([part1, part2], axis=1)
pathogen_info = pathogen_info.loc[pathogen_info['synonyms'].str.contains('GTDB'),:].reset_index(drop=True)
pathogen_info = pathogen_info.loc[pathogen_info['rank_list'].str.contains('species|strain'),:].reset_index(drop=True)
pathogen_info.to_csv(os.path.join(args.output_dir, 'pathogen_info.tsv'), sep='\t', index=False)

In [9]:
# find all potential pathogens (all strains under the known species-level pathogens that have at least one pathogen or themself are known as pathogens )
species1 = set(pathogen_info.loc[pathogen_info['rank_list'].str.contains('strain'),'parent_list'].to_list())
species2 = set(pathogen_info.loc[pathogen_info['rank_list'].str.contains('species'),'synonyms'].to_list())
merged_species = species1 | species2
all_potential_pathogenic_strain = set()
for species in tqdm(merged_species):
    species_synonyms_list = eval(species)
    for species_synonyms in species_synonyms_list:
        all_potential_pathogenic_strain |= find_child[species_synonyms]
microbe_nodes = node_info.loc[node_info['node_id'].str.contains('Microbe'), :].reset_index(drop=True)
all_gtdb_id_to_microbe_id = dict()
for row in tqdm(microbe_nodes.to_numpy()):
    node_id, node_type, all_names, synonyms, is_pathogen = row
    if 'GTDB:' not in synonyms:
        continue
    all_gtdb_id_to_microbe_id
    synonyms = eval(synonyms)
    for synonym in synonyms:
        if 'GTDB:' in synonym:
            all_gtdb_id_to_microbe_id[synonym] = (synonym, node_id, node_type, all_names, synonyms, is_pathogen)

100%|██████████| 708/708 [00:00<00:00, 28418.81it/s]
100%|██████████| 951782/951782 [00:03<00:00, 304568.22it/s] 


In [None]:
all_gtdb_id_to_microbe_id = dict()
for row in tqdm(node_info.to_numpy()):
    node_id, node_type, all_names, synonyms, is_pathogen = row
    if 'GTDB:' not in synonyms:
        continue
    all_gtdb_id_to_microbe_id
    synonyms = eval(synonyms)
    for synonym in synonyms:
        if 'GTDB:' in synonym:
            all_gtdb_id_to_microbe_id[synonym] = (synonym, node_id, node_type, all_names, is_pathogen, synonyms)

In [49]:
gtdbid_df = pd.DataFrame(all_gtdb_id_to_microbe_id.values(), columns=['gtdb_id', 'node_id', 'node_type', 'all_names', 'is_pathogen', 'synonyms'])



  0%|          | 0/951782 [00:00<?, ?it/s]

100%|██████████| 951782/951782 [00:03<00:00, 254570.53it/s] 


In [10]:
len(merged_species)

708

In [50]:
gtdbid_df = pd.DataFrame(all_gtdb_id_to_microbe_id.values(), columns=['gtdb_id', 'node_id', 'node_type', 'all_names', 'synonyms', 'is_pathogen'])
gtdbid_df

Unnamed: 0,gtdb_id,node_id,node_type,all_names,synonyms,is_pathogen
0,GTDB:Bacteria,Microbe:431882,biolink:OrganismTaxon,['Bacteria'],[GTDB:Bacteria],False
1,GTDB:Pseudomonadota,Microbe:431883,biolink:OrganismTaxon,['Pseudomonadota'],[GTDB:Pseudomonadota],True
2,GTDB:Gammaproteobacteria,Microbe:431884,biolink:OrganismTaxon,['Gammaproteobacteria'],[GTDB:Gammaproteobacteria],True
3,GTDB:Enterobacterales,Microbe:431885,biolink:OrganismTaxon,['Enterobacterales'],[GTDB:Enterobacterales],False
4,GTDB:Enterobacteriaceae,Microbe:431886,biolink:OrganismTaxon,['Enterobacteriaceae'],[GTDB:Enterobacteriaceae],True
...,...,...,...,...,...,...
510674,GTDB:GCA_000180295.1,Microbe:944250,biolink:OrganismTaxon,['Salmonella enterica subsp. enterica serovar ...,"[BVBRC:gn_496068.3, GTDB:GCA_000180295.1]",True
510675,GTDB:GCA_000268065.1,Microbe:944788,biolink:OrganismTaxon,['Shigella flexneri 1235-66'],"[BVBRC:gn_766154.3, GTDB:GCA_000268065.1]",True
510676,GTDB:GCA_000017125.1,Microbe:945749,biolink:OrganismTaxon,['Staphylococcus aureus subsp. aureus JH1'],"[BVBRC:gn_359787.11, GTDB:GCA_000017125.1]",True
510677,GTDB:GCA_002249745.1,Microbe:945750,biolink:OrganismTaxon,['Streptococcus pneumoniae B1598'],"[BVBRC:gn_1449970.3, GTDB:GCA_002249745.1]",True


In [17]:
len(gtdbid_df['node_id'].unique())

402732

In [47]:
gtdbid_df = pd.DataFrame(all_gtdb_id_to_microbe_id.values(), columns=['gtdb_id', 'node_id', 'node_type', 'all_names', 'synonyms', 'is_pathogen'])

In [55]:
gtdbid_df

Unnamed: 0,gtdb_id,node_id,node_type,all_names,synonyms,is_pathogen
0,GTDB:Bacteria,Microbe:431882,biolink:OrganismTaxon,['Bacteria'],[GTDB:Bacteria],False
1,GTDB:Pseudomonadota,Microbe:431883,biolink:OrganismTaxon,['Pseudomonadota'],[GTDB:Pseudomonadota],True
2,GTDB:Gammaproteobacteria,Microbe:431884,biolink:OrganismTaxon,['Gammaproteobacteria'],[GTDB:Gammaproteobacteria],True
3,GTDB:Enterobacterales,Microbe:431885,biolink:OrganismTaxon,['Enterobacterales'],[GTDB:Enterobacterales],False
4,GTDB:Enterobacteriaceae,Microbe:431886,biolink:OrganismTaxon,['Enterobacteriaceae'],[GTDB:Enterobacteriaceae],True
...,...,...,...,...,...,...
510674,GTDB:GCA_000180295.1,Microbe:944250,biolink:OrganismTaxon,['Salmonella enterica subsp. enterica serovar ...,"[BVBRC:gn_496068.3, GTDB:GCA_000180295.1]",True
510675,GTDB:GCA_000268065.1,Microbe:944788,biolink:OrganismTaxon,['Shigella flexneri 1235-66'],"[BVBRC:gn_766154.3, GTDB:GCA_000268065.1]",True
510676,GTDB:GCA_000017125.1,Microbe:945749,biolink:OrganismTaxon,['Staphylococcus aureus subsp. aureus JH1'],"[BVBRC:gn_359787.11, GTDB:GCA_000017125.1]",True
510677,GTDB:GCA_002249745.1,Microbe:945750,biolink:OrganismTaxon,['Streptococcus pneumoniae B1598'],"[BVBRC:gn_1449970.3, GTDB:GCA_002249745.1]",True


In [51]:
all_potential_pathogenic_strain_df = gtdbid_df.loc[gtdbid_df['gtdb_id'].isin(all_potential_pathogenic_strain),:].reset_index(drop=True)

In [61]:
all_potential_pathogenic_strain_df.to_csv(os.path.join(args.output_dir, 'all_potential_pathogenic_strain.tsv'), sep='\t', index=False)

In [62]:
species1

{"['GTDB:Achromobacter insolitus']",
 "['GTDB:Achromobacter ruhlandii']",
 "['GTDB:Achromobacter spanius']",
 "['GTDB:Achromobacter xylosoxidans']",
 "['GTDB:Acidaminococcus intestini']",
 "['GTDB:Acinetobacter baumannii']",
 "['GTDB:Acinetobacter calcoaceticus']",
 "['GTDB:Acinetobacter fasciculus']",
 "['GTDB:Acinetobacter haemolyticus']",
 "['GTDB:Acinetobacter johnsonii']",
 "['GTDB:Acinetobacter junii']",
 "['GTDB:Acinetobacter radioresistens']",
 "['GTDB:Acinetobacter schindleri']",
 "['GTDB:Actinomyces oris']",
 "['GTDB:Actinotignum schaalii']",
 "['GTDB:Aerococcus christensenii']",
 "['GTDB:Aerococcus sanguinicola']",
 "['GTDB:Aerococcus urinae']",
 "['GTDB:Aerococcus urinaehominis']",
 "['GTDB:Aeromonas caviae']",
 "['GTDB:Aeromonas enteropelogenes']",
 "['GTDB:Aeromonas hydrophila']",
 "['GTDB:Aggregatibacter actinomycetemcomitans']",
 "['GTDB:Aggregatibacter actinomycetemcomitans_A']",
 "['GTDB:Aggregatibacter aphrophilus']",
 "['GTDB:Aliarcobacter butzleri']",
 "['GTDB:Alis

In [22]:
gtdbid_df.loc[gtdbid_df['gtdb_id'].isin(all_potential_pathogenic_strain),:].reset_index(drop=True).to_csv(os.path.join(args.output_dir, 'all_potential_pathogenic_strain.tsv'), sep='\t', index=False)

In [24]:
edge_info.to_csv(os.path.join(args.output_dir, 'edge_info.tsv'), sep='\t', index=False) 

In [35]:
set_seed(100)
species1_list = list(species1)
random.shuffle(species1_list)
hold_out_species = species1_list[:int(len(species1_list)*0.1)]

In [37]:
pathogen_info.loc[pathogen_info['parent_list'].isin(hold_out_species),:].to_csv(os.path.join(args.output_dir, 'hold_out_pathogen_info.tsv'), sep='\t', index=False)