In [None]:
import os
import json
import pickle
import random
from collections import Counter

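### Pay attention to reproducibility! Every question cell reseeds `random` with its own fixed seed, so each cell's samples do not depend on which cells ran before it.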

In [None]:
# Input/output roots. Defaults are the original cluster paths; set the env vars to run elsewhere.
data_dir = os.environ.get("LLM_GRAPH_DATA_DIR", "/shared/data3/bowenj4/llm-graph-plugin/data/processed_data/biomedical")
downstream_dir = os.environ.get("LLM_GRAPH_DOWNSTREAM_DIR", "/shared/data3/bowenj4/llm-graph-plugin/data/raw_data/biomedical")

In [None]:
# read processed graph: node type -> {node_id: {'features': {...}, 'neighbors': {edge_type: [ids]}}}
# `with` closes the file handle instead of leaking it to the GC.
with open(os.path.join(data_dir, 'graph.json')) as graph_file:
    graph = json.load(graph_file)
print(graph.keys())

# Node count per node type. Loop variable is not `k`, which is the sample-count constant later.
for node_type, nodes in graph.items():
    print(node_type, len(nodes))

In [None]:
k = 10  # number of samples generated for each question template
# Maps a (question template, answer template) tuple to its list of generated sample dicts.
all_generated_data = dict()

### Design questions (one question type per cell)

#### easy questions

In [None]:
def one_hop(graph, center_node_type, center_node_save_key, neighbor_node_type, neighbor_node_save_key, edge_type, k, max_neighbors=5):
    """Sample up to ``k`` (center node, neighbor list) pairs over a single edge type.

    Center nodes of ``center_node_type`` are visited in an order drawn from the
    global ``random`` state, so seed it before calling for reproducible output.
    A center is kept only if it has between 1 and ``max_neighbors`` neighbors
    over ``edge_type``. This keeps the expected answer short and never empty.

    Args:
        graph: dict of node type -> {node_id: {'features': {'name': ...},
            'neighbors': {edge_type: [neighbor ids]}}}.
        center_node_type: top-level key of ``graph`` for the center nodes.
        center_node_save_key: output key holding the center node's name.
        neighbor_node_type: top-level key of ``graph`` for the neighbor nodes.
        neighbor_node_save_key: output key holding the comma-joined neighbor names.
        edge_type: key into each center's ``'neighbors'`` dict.
        k: maximum number of samples to return.
        max_neighbors: centers with more neighbors than this are skipped.

    Returns:
        list of ``{center_node_save_key: name, neighbor_node_save_key: 'n1, n2, ...'}``
        dicts. The list is shorter than ``k`` when too few centers qualify.
    """
    generated_data = []
    center_ids = list(graph[center_node_type].keys())
    random.shuffle(center_ids)

    for center_id in center_ids:
        center = graph[center_node_type][center_id]
        neighbor_ids = center['neighbors'].get(edge_type, [])
        # An empty neighbor list would yield an empty answer string.
        if not neighbor_ids or len(neighbor_ids) > max_neighbors:
            continue

        neighbor_names = [graph[neighbor_node_type][neighbor_id]['features']['name'] for neighbor_id in neighbor_ids]
        generated_data.append({center_node_save_key: center['features']['name'], neighbor_node_save_key: ', '.join(neighbor_names)})
        if len(generated_data) == k:
            break
    return generated_data

In [None]:
## question (easy): what are the side effects of compound xxx?

random.seed(2023)

# An earlier variant appended: "Please answer the side effect names rather than ID."
question = "what are the side effects of compound {compound_name}?"
answer = "{side_effects}"

generated_data = one_hop(
    graph,
    center_node_type='Compound_nodes',
    center_node_save_key='compound_name',
    neighbor_node_type='Side_Effect_nodes',
    neighbor_node_save_key='side_effects',
    edge_type='Compound-causes-Side Effect',
    k=k,
)

assert len(generated_data) == k
all_generated_data[question, answer] = generated_data

In [None]:
## question (easy): What are the symptoms of the disease xxx?

random.seed(2024)

# An earlier variant appended: "Please answer the symptom names rather than IDs."
question = "What are the symptoms of the disease {disease_name}?"
answer = "{symptom_names}"

generated_data = one_hop(
    graph,
    center_node_type='Disease_nodes',
    center_node_save_key='disease_name',
    neighbor_node_type='Symptom_nodes',
    neighbor_node_save_key='symptom_names',
    edge_type='Disease-presents-Symptom',
    k=k,
)

assert len(generated_data) == k
all_generated_data[question, answer] = generated_data

In [None]:
## question (easy): What are the biological processes of gene xxx?

random.seed(2025)

# An earlier variant appended: "Please answer the biological processes names rather than IDs."
question = "What are the biological processes of gene {gene_name}?"
answer = "{biological_processes_names}"

generated_data = one_hop(
    graph,
    center_node_type='Gene_nodes',
    center_node_save_key='gene_name',
    neighbor_node_type='Biological_Process_nodes',
    neighbor_node_save_key='biological_processes_names',
    edge_type='Gene-participates-Biological Process',
    k=k,
)

assert len(generated_data) == k
all_generated_data[question, answer] = generated_data

In [None]:
## question (easy): What are the molecular functions of gene xxx?

random.seed(2026)

# An earlier variant appended: "Please answer the molecular function names rather than IDs."
question = "What are the molecular functions of gene {gene_name}?"
answer = "{molecular_function_names}"

generated_data = one_hop(
    graph,
    center_node_type='Gene_nodes',
    center_node_save_key='gene_name',
    neighbor_node_type='Molecular_Function_nodes',
    neighbor_node_save_key='molecular_function_names',
    edge_type='Gene-participates-Molecular Function',
    k=k,
)

assert len(generated_data) == k
all_generated_data[question, answer] = generated_data

In [None]:
## question (easy): What anatomy can be upregulated/expressed/downregulated by gene xxx?

random.seed(2027)

# NOTE(review): the edge names read as Anatomy-<verb>-Gene (anatomy acting on the gene),
# while the questions say "... by gene"; the direction looks reversed. Wording kept as-is
# so the dataset keys stay stable. Confirm the intended phrasing.
# All three calls share the single seed above, so the list order must stay fixed.
# Earlier variants appended: "Please answer the anatomy names rather than IDs."
anatomy_templates = [
    ("What anatomy can be downregulated by gene {gene_name}?", 'Anatomy-downregulates-Gene'),
    ("What anatomy can be expressed by gene {gene_name}?", 'Anatomy-expresses-Gene'),
    ("What anatomy can be upregulated by gene {gene_name}?", 'Anatomy-upregulates-Gene'),
]
for question, anatomy_edge in anatomy_templates:
    answer = "{anatomy_names}"
    generated_data = one_hop(graph, 'Gene_nodes', 'gene_name', 'Anatomy_nodes', 'anatomy_names', anatomy_edge, k)
    assert len(generated_data) == k
    all_generated_data[question, answer] = generated_data

#### medium questions

In [None]:
## question (medium): What compound can treat both disease xxx and disease xxx?

random.seed(2028)

# An earlier variant appended: "Please answer the compound name rather than ID."
question = "What compound can treat both {disease_name1} and {disease_name2}?"
answer = "{compound_name}"
generated_data = []

compound_table = graph['Compound_nodes']
disease_table = graph['Disease_nodes']
treats_edge = 'Compound-treats-Disease'

compound_ids = list(compound_table.keys())
random.shuffle(compound_ids)

for compound_id in compound_ids:
    compound = compound_table[compound_id]
    compound_name = compound['features']['name']
    if treats_edge not in compound['neighbors']:
        continue

    disease_ids = compound['neighbors'][treats_edge]
    disease_names = [disease_table[d_id]['features']['name'] for d_id in disease_ids]
    if len(disease_ids) < 2:
        continue
    # Skip when more than one compound treats both of the first two diseases (ambiguous answer).
    co_treaters = set(disease_table[disease_ids[0]]['neighbors'][treats_edge]) & set(disease_table[disease_ids[1]]['neighbors'][treats_edge])
    if len(co_treaters) > 1:
        continue

    generated_data.append({"disease_name1": disease_names[0], "disease_name2": disease_names[1], "compound_name": compound_name})
    if len(generated_data) == k:
        break

assert len(generated_data) == k
all_generated_data[question, answer] = generated_data

In [None]:
## question (medium): What disease located in xxx (Anatomy) can compound xxx palliate?

random.seed(2029)

# An earlier variant appended: "Please answer the disease name rather than ID."
question = "What disease located in {anatomy_name} can {compound_name} palliate?"
answer = "{disease_name}"
generated_data = []

localizes_edge, palliates_edge = 'Disease-localizes-Anatomy', 'Compound-palliates-Disease'
disease_ids = list(graph['Disease_nodes'].keys())
random.shuffle(disease_ids)

for disease_id in disease_ids:
    disease = graph['Disease_nodes'][disease_id]
    disease_name = disease['features']['name']
    links = disease['neighbors']
    if localizes_edge not in links or palliates_edge not in links:
        continue

    # The first listed anatomy and compound serve as the two clues.
    anatomy = graph['Anatomy_nodes'][links[localizes_edge][0]]
    compound = graph['Compound_nodes'][links[palliates_edge][0]]
    # Skip when more than one disease matches both clues (ambiguous answer).
    if len(set(anatomy['neighbors'][localizes_edge]) & set(compound['neighbors'][palliates_edge])) > 1:
        continue

    generated_data.append({"anatomy_name": anatomy['features']['name'], "compound_name": compound['features']['name'], "disease_name": disease_name})
    if len(generated_data) == k:
        break

assert len(generated_data) == k
all_generated_data[question, answer] = generated_data

In [None]:
## question (medium): What disease located in xxx (Anatomy) can compound xxx treat?

random.seed(2030)

# An earlier variant appended: "Please answer the disease name rather than ID."
question = "What disease located in {anatomy_name} can {compound_name} treat?"
answer = "{disease_name}"
generated_data = []

localizes_edge, treats_edge = 'Disease-localizes-Anatomy', 'Compound-treats-Disease'
disease_ids = list(graph['Disease_nodes'].keys())
random.shuffle(disease_ids)

for disease_id in disease_ids:
    disease = graph['Disease_nodes'][disease_id]
    disease_name = disease['features']['name']
    links = disease['neighbors']
    if localizes_edge not in links or treats_edge not in links:
        continue

    # The first listed anatomy and compound serve as the two clues.
    anatomy = graph['Anatomy_nodes'][links[localizes_edge][0]]
    compound = graph['Compound_nodes'][links[treats_edge][0]]
    # Skip when more than one disease matches both clues (ambiguous answer).
    if len(set(anatomy['neighbors'][localizes_edge]) & set(compound['neighbors'][treats_edge])) > 1:
        continue

    generated_data.append({"anatomy_name": anatomy['features']['name'], "compound_name": compound['features']['name'], "disease_name": disease_name})
    if len(generated_data) == k:
        break

assert len(generated_data) == k
all_generated_data[question, answer] = generated_data

In [None]:
## question (medium): What disease is downregulated by gene xxx and located in xxx (Anatomy)?

random.seed(2031)

# NOTE(review): the edge is named 'Disease-downregulates-Gene', which reads as the disease
# acting on the gene; "downregulated by {gene_name}" seems to reverse that. Wording kept
# as-is so the dataset keys stay stable. Confirm the intended phrasing.
# An earlier variant appended: "Please answer the disease name rather than ID."
question = "What disease is downregulated by {gene_name} and located in {anatomy_name}?"
answer = "{disease_name}"
generated_data = []

localizes_edge, gene_edge = 'Disease-localizes-Anatomy', 'Disease-downregulates-Gene'
disease_ids = list(graph['Disease_nodes'].keys())
random.shuffle(disease_ids)

for disease_id in disease_ids:
    disease = graph['Disease_nodes'][disease_id]
    disease_name = disease['features']['name']
    links = disease['neighbors']
    if localizes_edge not in links or gene_edge not in links:
        continue

    # The first listed anatomy and gene serve as the two clues.
    anatomy = graph['Anatomy_nodes'][links[localizes_edge][0]]
    gene = graph['Gene_nodes'][links[gene_edge][0]]
    # Skip when more than one disease matches both clues (ambiguous answer).
    if len(set(anatomy['neighbors'][localizes_edge]) & set(gene['neighbors'][gene_edge])) > 1:
        continue

    generated_data.append({"anatomy_name": anatomy['features']['name'], "gene_name": gene['features']['name'], "disease_name": disease_name})
    if len(generated_data) == k:
        break

assert len(generated_data) == k
all_generated_data[question, answer] = generated_data

In [None]:
## question (medium): What disease is associated by gene xxx and located in xxx (Anatomy)?

random.seed(2032)

# NOTE(review): "associated by" is awkward ("associated with" reads better). Wording kept
# as-is so the dataset keys stay stable.
# An earlier variant appended: "Please answer the disease name rather than ID."
question = "What disease is associated by {gene_name} and located in {anatomy_name}?"
answer = "{disease_name}"
generated_data = []

localizes_edge, gene_edge = 'Disease-localizes-Anatomy', 'Disease-associates-Gene'
disease_ids = list(graph['Disease_nodes'].keys())
random.shuffle(disease_ids)

for disease_id in disease_ids:
    disease = graph['Disease_nodes'][disease_id]
    disease_name = disease['features']['name']
    links = disease['neighbors']
    if localizes_edge not in links or gene_edge not in links:
        continue

    # The first listed anatomy and gene serve as the two clues.
    anatomy = graph['Anatomy_nodes'][links[localizes_edge][0]]
    gene = graph['Gene_nodes'][links[gene_edge][0]]
    # Skip when more than one disease matches both clues (ambiguous answer).
    if len(set(anatomy['neighbors'][localizes_edge]) & set(gene['neighbors'][gene_edge])) > 1:
        continue

    generated_data.append({"anatomy_name": anatomy['features']['name'], "gene_name": gene['features']['name'], "disease_name": disease_name})
    if len(generated_data) == k:
        break

assert len(generated_data) == k
all_generated_data[question, answer] = generated_data

In [None]:
## question (medium): What disease is upregulated by gene xxx and located in xxx (Anatomy)?

random.seed(2033)

# NOTE(review): the edge is named 'Disease-upregulates-Gene', which reads as the disease
# acting on the gene; "upregulated by {gene_name}" seems to reverse that. Wording kept
# as-is so the dataset keys stay stable. Confirm the intended phrasing.
# An earlier variant appended: "Please answer the disease name rather than ID."
question = "What disease is upregulated by {gene_name} and located in {anatomy_name}?"
answer = "{disease_name}"
generated_data = []

localizes_edge, gene_edge = 'Disease-localizes-Anatomy', 'Disease-upregulates-Gene'
disease_ids = list(graph['Disease_nodes'].keys())
random.shuffle(disease_ids)

for disease_id in disease_ids:
    disease = graph['Disease_nodes'][disease_id]
    disease_name = disease['features']['name']
    links = disease['neighbors']
    if localizes_edge not in links or gene_edge not in links:
        continue

    # The first listed anatomy and gene serve as the two clues.
    anatomy = graph['Anatomy_nodes'][links[localizes_edge][0]]
    gene = graph['Gene_nodes'][links[gene_edge][0]]
    # Skip when more than one disease matches both clues (ambiguous answer).
    if len(set(anatomy['neighbors'][localizes_edge]) & set(gene['neighbors'][gene_edge])) > 1:
        continue

    generated_data.append({"anatomy_name": anatomy['features']['name'], "gene_name": gene['features']['name'], "disease_name": disease_name})
    if len(generated_data) == k:
        break

assert len(generated_data) == k
all_generated_data[question, answer] = generated_data

In [None]:
## question (medium): Is there a correlation between gene xxx and symptom xxx?

random.seed(2034)

question = "Is there a correlation between {gene_name} and {symptom_name}? Please answer True or False"
answer = "{answer}"
generated_data = []

# A gene is linked to a symptom when a disease that down-/up-regulates or
# associates with the gene presents that symptom.
gene_disease_edge_types = ['Disease-downregulates-Gene', 'Disease-upregulates-Gene', 'Disease-associates-Gene']

cnt = 0
gene_ids = list(graph['Gene_nodes'].keys())
random.shuffle(gene_ids)

for gene_id in gene_ids:
    gene_name = graph['Gene_nodes'][gene_id]['features']['name']
    gene_neighbors = graph['Gene_nodes'][gene_id]['neighbors']

    disease_ids = set()
    for edge_type in gene_disease_edge_types:
        disease_ids.update(gene_neighbors.get(edge_type, []))

    linked_symptom_ids = set()
    for disease_id in disease_ids:
        linked_symptom_ids.update(graph['Disease_nodes'][disease_id]['neighbors'].get("Disease-presents-Symptom", []))

    if len(linked_symptom_ids) == 0:
        continue

    # sorted(): iteration order of a set of str depends on PYTHONHASHSEED, so shuffling
    # list(set) directly is not reproducible across kernels even with random.seed() above.
    symptom_ids = sorted(linked_symptom_ids)
    neg_symptom_ids = sorted(set(graph['Symptom_nodes']) - linked_symptom_ids)

    random.shuffle(symptom_ids)
    random.shuffle(neg_symptom_ids)

    # Positive vs. negative pair with equal probability. Fall back to a positive pair
    # when the gene is linked to every symptom (nothing to draw a negative from).
    if random.random() > 0.5 or not neg_symptom_ids:
        chosen_symptom_id, label = symptom_ids[0], True
    else:
        chosen_symptom_id, label = neg_symptom_ids[0], False

    # The {symptom_name} slot gets the symptom's name, not its raw node ID.
    chosen_symptom_name = graph['Symptom_nodes'][chosen_symptom_id]['features']['name']
    generated_data.append({"gene_name": gene_name, "symptom_name": chosen_symptom_name, "answer": label})

    cnt += 1
    if cnt == k:
        break

assert len(generated_data) == k
all_generated_data[(question, answer)] = generated_data

#### easy questions: counting

In [None]:
## question (medium): How many resemble compounds do xxx have?

random.seed(2035)

question = "How many resemble compounds do {compound_name} have?"
answer = "{num}"
generated_data = []

resembles_edge = 'Compound-resembles-Compound'
compound_ids = list(graph['Compound_nodes'].keys())
random.shuffle(compound_ids)

for compound_id in compound_ids:
    compound = graph['Compound_nodes'][compound_id]
    compound_name = compound['features']['name']
    if resembles_edge in compound['neighbors']:
        # The answer is the raw neighbor count for this relation.
        generated_data.append({"compound_name": compound_name, 'num': len(compound['neighbors'][resembles_edge])})
        if len(generated_data) == k:
            break

assert len(generated_data) == k
all_generated_data[question, answer] = generated_data

In [None]:
## question (medium): How many resemble diseases do xxx have?

random.seed(2036)

question = "How many resemble disease do {disease_name} have?"
answer = "{num}"
generated_data = []

resembles_edge = 'Disease-resembles-Disease'
disease_ids = list(graph['Disease_nodes'].keys())
random.shuffle(disease_ids)

for disease_id in disease_ids:
    disease = graph['Disease_nodes'][disease_id]
    disease_name = disease['features']['name']
    if resembles_edge in disease['neighbors']:
        # The answer is the raw neighbor count for this relation.
        generated_data.append({"disease_name": disease_name, 'num': len(disease['neighbors'][resembles_edge])})
        if len(generated_data) == k:
            break

assert len(generated_data) == k
all_generated_data[question, answer] = generated_data

In [None]:
## question (medium): How many compounds can be used to treat xxx?

random.seed(2037)

question = "How many compounds can be used to treat {disease_name}?"
answer = "{num}"
generated_data = []

treats_edge = 'Compound-treats-Disease'
disease_ids = list(graph['Disease_nodes'].keys())
random.shuffle(disease_ids)

for disease_id in disease_ids:
    disease = graph['Disease_nodes'][disease_id]
    disease_name = disease['features']['name']
    if treats_edge in disease['neighbors']:
        # The answer is the number of compounds listed as treating this disease.
        generated_data.append({"disease_name": disease_name, 'num': len(disease['neighbors'][treats_edge])})
        if len(generated_data) == k:
            break

assert len(generated_data) == k
all_generated_data[question, answer] = generated_data

#### medium questions: multi-hop aggregation

In [None]:
def four_hop(node_type_1, edge_type_12, node_type_2, edge_type_23, node_type_3, edge_type_34, node_type_4, save_q_key, save_a_key, graph=None, k=None):
    """Sample start nodes whose 3-hop neighbourhood has a unique most frequent end node.

    Walks ``node_type_1 --edge_type_12--> node_type_2 --edge_type_23--> node_type_3
    --edge_type_34--> node_type_4``. The ``node_type_3`` nodes reached from one start
    node are de-duplicated. Each ``node_type_4`` node is then counted once per
    ``node_type_3`` node that links to it. A start node is kept only when that count
    has a single strict maximum; ties and empty results are skipped, so each sample
    has exactly one answer.

    Start nodes are visited in an order drawn from the global ``random`` state.
    Seed it before calling for reproducible output.

    Args:
        node_type_1 .. node_type_4: top-level keys of ``graph`` for each hop.
        edge_type_12, edge_type_23, edge_type_34: keys into each node's
            ``'neighbors'`` dict for the corresponding hop.
        save_q_key: output key holding the start node's name.
        save_a_key: output key holding the winning end node's name.
        graph: graph dict; defaults to the notebook-global ``graph``.
        k: maximum number of samples; defaults to the notebook-global ``k``.

    Returns:
        list of ``{save_q_key: start name, save_a_key: end name}`` dicts, at most
        ``k`` long. The list is shorter when too few start nodes qualify.
    """
    if graph is None:
        graph = globals()['graph']
    if k is None:
        k = globals()['k']

    generated_data = []
    node_type_1_ids = list(graph[node_type_1].keys())
    random.shuffle(node_type_1_ids)

    for node_type_1_id in node_type_1_ids:
        start = graph[node_type_1][node_type_1_id]
        if edge_type_12 not in start['neighbors']:
            continue

        # Hops 1-2: de-duplicated node_type_3 ids reachable through any node_type_2 neighbor.
        node_type_3_ids = set()
        for node_type_2_id in start['neighbors'][edge_type_12]:
            node_type_3_ids.update(graph[node_type_2][node_type_2_id]['neighbors'].get(edge_type_23, []))

        # Hop 3: count each node_type_4 node once per linking node_type_3 node.
        node_type_4_counter = Counter()
        for node_type_3_id in node_type_3_ids:
            node_type_4_counter.update(graph[node_type_3][node_type_3_id]['neighbors'].get(edge_type_34, []))

        top_two = node_type_4_counter.most_common(2)
        if not top_two:
            continue
        if len(top_two) > 1 and top_two[0][1] == top_two[1][1]:
            continue  # tied maximum -> the answer would be ambiguous

        winner_name = graph[node_type_4][top_two[0][0]]['features']['name']
        generated_data.append({save_q_key: start['features']['name'], save_a_key: winner_name})
        if len(generated_data) == k:
            break

    return generated_data

In [None]:
## question (medium): Compounds of what pharmacologic class can palliate/treat the most disease with xxx symptom?

random.seed(2038)

# Earlier drafts: "Compounds of what pharmacologic class can palliate/treat the most disease
# with {symptom_name}?" (optionally + "Please answer the pharmacologic class name rather than ID.").
# Both four_hop calls share the seed above, so keep this order.
pharmacologic_templates = [
    ("Which pharmacologic class includes the most compounds that can palliate the disease with {symptom_name}?", 'Compound-palliates-Disease'),
    ("Which pharmacologic class includes the most compounds that can treat the disease with {symptom_name}?", 'Compound-treats-Disease'),
]
for question, disease_compound_edge in pharmacologic_templates:
    answer = "{pharmacologic_class_name}"
    generated_data = four_hop('Symptom_nodes', 'Disease-presents-Symptom', 'Disease_nodes', disease_compound_edge, "Compound_nodes", "Pharmacologic Class-includes-Compound", "Pharmacologic_Class_nodes", "symptom_name", "pharmacologic_class_name")
    assert len(generated_data) == k
    all_generated_data[question, answer] = generated_data

In [None]:
## question (medium): Genes participating in what cellular components can upregulated/associated/downregulated the most disease with xxx symptom?

random.seed(2039)

# Earlier drafts: "Genes participating in what cellular components are <verb> in most disease
# with {symptom_name}?" (optionally + "Please answer the cellular component names rather than IDs.").
# All three four_hop calls share the seed above, so keep this order.
cellular_component_templates = [
    ("Which cellular component is participated by most genes that are upregulated in disease with {symptom_name}?", 'Disease-upregulates-Gene'),
    ("Which cellular component is participated by most genes that are associated in disease with {symptom_name}?", 'Disease-associates-Gene'),
    ("Which cellular component is participated by most genes that are downregulated in disease with {symptom_name}?", 'Disease-downregulates-Gene'),
]
for question, disease_gene_edge in cellular_component_templates:
    answer = "{cellular_component_name}"
    generated_data = four_hop('Symptom_nodes', 'Disease-presents-Symptom', 'Disease_nodes', disease_gene_edge, "Gene_nodes", "Gene-participates-Cellular Component", "Cellular_Component_nodes", "symptom_name", "cellular_component_name")
    assert len(generated_data) == k
    all_generated_data[question, answer] = generated_data

In [None]:
## question (medium): Genes participating in what pathways are upregulated/associates/downregulated in the most disease with xxx symptom?

random.seed(2040)

# Earlier drafts: "Genes participating in what pathways are <verb> in the most disease with
# {symptom_name}?" (optionally + "Please answer the pathway names rather than IDs.").
# All three four_hop calls share the seed above, so keep this order.
pathway_templates = [
    ("Which pathway is participated by most genes that are upregulated in disease with {symptom_name}?", 'Disease-upregulates-Gene'),
    ("Which pathway is participated by most genes that are associated in disease with {symptom_name}?", 'Disease-associates-Gene'),
    ("Which pathway is participated by most genes that are downregulated in disease with {symptom_name}?", 'Disease-downregulates-Gene'),
]
for question, disease_gene_edge in pathway_templates:
    answer = "{pathway_name}"
    generated_data = four_hop('Symptom_nodes', 'Disease-presents-Symptom', 'Disease_nodes', disease_gene_edge, "Gene_nodes", "Gene-participates-Pathway", "Pathway_nodes", "symptom_name", "pathway_name")
    assert len(generated_data) == k
    all_generated_data[question, answer] = generated_data


In [None]:
def one_hop_count(node_type_1, edge_type_12, node_type_2, save_q_name, save_a_name, graph=None, k=None, exact=False):
    """Sample nodes and count the *other* nodes that share all of their neighbors.

    For a start node ``x`` of ``node_type_1`` with ``edge_type_12`` neighbors ``N(x)``
    (of ``node_type_2``), the candidates are the intersection, over every ``y`` in
    ``N(x)``, of the ``node_type_1`` nodes listed under ``y``'s ``edge_type_12``.
    By default this counts every node whose neighbor set *contains* ``N(x)``; ``x``
    itself is among them, hence the ``- 1``. With ``exact=True`` only nodes whose
    neighbor set *equals* ``N(x)`` are counted.

    NOTE(review): the calling cells phrase the question as "exact same", but the
    default (``exact=False``) counts supersets. That default is kept for backward
    compatibility. Confirm which semantics the dataset intends.

    Start nodes are visited in an order drawn from the global ``random`` state.
    Seed it before calling for reproducible output.

    Args:
        node_type_1: top-level key of ``graph`` for the start nodes.
        edge_type_12: neighbor key used on both node types.
        node_type_2: top-level key of ``graph`` for the shared neighbors.
        save_q_name: output key holding the start node's name.
        save_a_name: output key holding the count.
        graph: graph dict; defaults to the notebook-global ``graph``.
        k: maximum number of samples; defaults to the notebook-global ``k``.
        exact: count only nodes with an identical neighbor set.

    Returns:
        list of ``{save_q_name: name, save_a_name: int >= 1}`` dicts, at most ``k`` long.
    """
    if graph is None:
        graph = globals()['graph']
    if k is None:
        k = globals()['k']

    generated_data = []
    node_type_1_ids = list(graph[node_type_1].keys())
    random.shuffle(node_type_1_ids)

    for node_type_1_id in node_type_1_ids:
        node_type_1_name = graph[node_type_1][node_type_1_id]['features']['name']
        node_type_2_ids = graph[node_type_1][node_type_1_id]['neighbors'].get(edge_type_12, [])
        # Missing or empty neighbor lists cannot define a shared set (would IndexError on [0]).
        if not node_type_2_ids:
            continue

        shared = set(graph[node_type_2][node_type_2_ids[0]]['neighbors'][edge_type_12])
        for node_type_2_id in node_type_2_ids[1:]:
            shared &= set(graph[node_type_2][node_type_2_id]['neighbors'][edge_type_12])

        if exact:
            target = set(node_type_2_ids)
            shared = {other for other in shared
                      if set(graph[node_type_1][other]['neighbors'].get(edge_type_12, [])) == target}

        # The start node always qualifies itself, so <= 1 means no other node does.
        if len(shared) <= 1:
            continue

        generated_data.append({save_q_name: node_type_1_name, save_a_name: len(shared) - 1})
        if len(generated_data) == k:
            break

    return generated_data

In [None]:
## question (medium): How many genes have the same biological process with xxx?

random.seed(2041)

# NOTE(review): one_hop_count intersects neighbor sets, so the count covers genes whose
# processes include all of this gene's processes (a superset), not only "exact same" ones.
# Confirm the intended semantics.
question = "How many genes participate the exact same biological processes with {gene_name}?"
answer = "{num}"

generated_data = one_hop_count(
    node_type_1='Gene_nodes',
    edge_type_12='Gene-participates-Biological Process',
    node_type_2='Biological_Process_nodes',
    save_q_name='gene_name',
    save_a_name='num',
)

assert len(generated_data) == k
all_generated_data[question, answer] = generated_data

In [None]:
## question (medium): How many diseases have the same symptom with xxx?

random.seed(2042)

# NOTE(review): one_hop_count intersects neighbor sets, so the count covers diseases whose
# symptoms include all of this disease's symptoms (a superset), not only "exact same" ones.
# Confirm the intended semantics.
question = "How many diseases present the exact same symptoms with {disease_name}?"
answer = "{num}"

generated_data = one_hop_count(
    node_type_1='Disease_nodes',
    edge_type_12='Disease-presents-Symptom',
    node_type_2='Symptom_nodes',
    save_q_name='disease_name',
    save_a_name='num',
)

assert len(generated_data) == k
all_generated_data[question, answer] = generated_data

In [None]:
## save
# Keys are (question, answer) tuples, which json.dump rejects (keys must be str/int/float/bool/None),
# so the samples are pickled instead. `with` guarantees the file is flushed and closed.
samples_path = os.path.join(data_dir, 'preprocess_samples.pkl')
with open(samples_path, 'wb') as samples_file:
    pickle.dump(all_generated_data, samples_file)
print(len(all_generated_data))