In [None]:
from helpers import *
from neo4j_helpers import *

## Get ICD codes and respective descriptions

In [None]:
# Fetch every code known to `cm`, then keep only the leaf codes, mapping each
# one to its description.
codes = cm.get_all_codes()

icd_code_description = {
    icd_code: cm.get_description(icd_code)
    for icd_code in tqdm(codes)
    if cm.is_leaf(icd_code)
}


## Extract entities and relations 

In [None]:
# Run the relation-extraction prompt once per ICD description and keep the raw
# completion under the same key as in `icd_code_description`.
# A plain loop (rather than a comprehension) keeps the partial results around
# if a call fails midway.
extracted_graphs = {}

for icd_code, description in tqdm(icd_code_description.items()):
    extracted_graphs[icd_code] = get_completion(
        prompt_relation_extraction,
        "ICD Code Description: " + description,
    )

## Save the extracted output (the first cell below reloads a previously saved run)

In [None]:
# Reload a previously saved extraction run from disk. This rebinds
# `extracted_graphs`, replacing whatever the extraction loop produced.
# The context manager guarantees the file handle is closed after parsing.
with open("extracted_entities_relations/extracted_graphs_72633.json") as f:
    extracted_graphs = json.load(f)

In [None]:
# Persist the current `extracted_graphs` mapping as JSON in the working
# directory (the file is overwritten if it already exists).
with open("extracted_entities_and_relations.json", "w") as out_file:
    json.dump(extracted_graphs, out_file)

# Construct KG

## Extract all entities across the ICD descriptions

In [None]:
# Collect the entities found in every extracted graph and de-duplicate them.
# A set comprehension does the union in one pass; the result is converted back
# to a list. Ordering of the resulting list is whatever the set yields.
all_entities = list({
    entity
    for raw_output in tqdm(extracted_graphs.values())
    for entity in extract_entities(raw_output)
})

## Normalize all entities by linking against UMLS

In [None]:
# Normalize the de-duplicated entity list in one call. Per the section heading
# this links entities against UMLS; `normalize_entities` is not defined in the
# cells shown here (presumably it comes from the star imports at the top).
# NOTE(review): the result looks like a mapping keyed by raw entity - confirm.
# It is passed unchanged to `build_graph` in the "Build Graph" step.
normalized_entity_map = normalize_entities(all_entities)

## Build Graph

In [None]:
# Build one graph per ICD code from its raw extraction output, its code, its
# description, and the shared normalized-entity map.
# Every key of `extracted_graphs` must also exist in `icd_code_description`,
# otherwise the lookup raises KeyError.
graphs_list = [
    build_graph(
        raw_output,
        icd_code,
        icd_code_description[icd_code],
        normalized_entity_map,
    )
    for icd_code, raw_output in tqdm(extracted_graphs.items())
]

In [None]:
# Merge all per-code graphs into a single knowledge graph.
# NOTE(review): assuming `nx` is NetworkX, `compose_all` unions the nodes and
# edges of the given graphs and raises ValueError when the list is empty, so
# this cell fails if no graphs were built above - confirm.
kg = nx.compose_all(graphs_list)

## Index to Neo4j

In [None]:
# Flatten the graph into plain dict payloads up front so the Neo4j session is
# only held open while data is actually being written.
nodes = [
    {'id': node_id, 'attributes': attributes}
    for node_id, attributes in kg.nodes(data=True)
]

relationships = [
    {
        'source_id': source_id,
        'target_id': target_id,
        'attributes': attributes
    }
    for source_id, target_id, attributes in kg.edges(data=True)
]

driver = GraphDatabase.driver(uri, auth=(username, password))

# try/finally guarantees the driver (and its connection pool) is released even
# when one of the write transactions raises.
try:
    with driver.session() as session:
        # `write_transaction` is deprecated in the 5.x Python driver in favor
        # of `execute_write`; use the newer method when the installed driver
        # provides it and fall back to the old name otherwise.
        run_write = getattr(session, "execute_write", None) or session.write_transaction

        # Create the index first, before any nodes are written.
        run_write(create_index)

        # Write nodes, then relationships, in BATCH_SIZE chunks - one
        # transaction per chunk.
        for start in range(0, len(nodes), BATCH_SIZE):
            run_write(create_nodes, nodes[start:start + BATCH_SIZE])

        for start in range(0, len(relationships), BATCH_SIZE):
            run_write(create_relationships, relationships[start:start + BATCH_SIZE])
finally:
    driver.close()