In [None]:
import pandas as pd
import pickle
import os
from openai import OpenAI
import faiss
import numpy as np
from dotenv import load_dotenv
import networkx as nx
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from pydantic import BaseModel

# Load the hierarchical tree, the embeddings, and the OpenAI environment

In [None]:
# Node type for the ICD code hierarchical tree
class ICDcodeNode:
    """One ICD code in the hierarchy, linked to its parent and its children.

    Attributes:
        icd_code: The code stored on this node.
        description: The description stored on this node.
        children: List of nodes directly below this one.
        parent: The node directly above this one, or None when not set.
    """

    def __init__(self, icd_code, description, children=None, parent=None):
        self.icd_code = icd_code
        self.description = description
        # No children argument means "start empty"; the list is created here
        # so that separate nodes never share one default list.
        self.children = [] if children is None else children
        self.parent = parent

    def get_children(self):
        """Return the list of direct children."""
        return self.children

    def add_child(self, child):
        """Attach `child` below this node; a child that is already attached is ignored."""
        if child in self.children:
            return
        self.children.append(child)
        # Keep the child's back-reference in sync with the new link
        child.set_parent(self)

    def get_parent(self):
        """Return the parent node (None when no parent was set)."""
        return self.parent

    def set_parent(self, parent):
        """Point this node at `parent` without touching the parent's children list."""
        self.parent = parent

    def __repr__(self):
        child_codes = [child.icd_code for child in self.children]
        return f'{self.icd_code} - {self.description} - node children: {child_codes}'

In [None]:
# Collect a node together with every node below it in the tree
def get_self_and_all_lower_nodes(node):
    """Return `node` followed by all of its descendants in depth-first order.

    Children are taken in the order `get_children()` gives them, and the whole
    subtree of one child is listed before the next sibling.
    """
    collected = []
    pending = [node]  # explicit stack standing in for recursive calls
    while pending:
        current = pending.pop()
        collected.append(current)
        # Push the children reversed so the first child is popped (visited) first
        pending.extend(reversed(list(current.get_children())))
    return collected

In [None]:
# Load the ICD code tree.
# NOTE: unpickling can execute arbitrary code, so only load a file you created or trust.
# The `with` block closes the file handle as soon as loading is done.
with open("icd_code_hierarchy.pkl", "rb") as hierarchy_file:
    hierarcial_data = pickle.load(hierarchy_file)

In [None]:
# Load the precomputed ICD code embeddings.
# NOTE: unpickling can execute arbitrary code, so only load a file you created or trust.
# The `with` block closes the file handle as soon as loading is done.
with open('icd_codes_embeddings_2025_with_chapters.pkl', 'rb') as embeddings_file:
    embedding_data = pickle.load(embeddings_file)

In [None]:
# Set up the OpenAI environment
load_dotenv()  # pick up variables from a local .env file, if there is one

# The key is kept under OPENAI_KEY, but OpenAI() is created without arguments and
# so relies on OPENAI_API_KEY; copy the value across. Fail early with a readable
# message instead of the bare TypeError that `os.environ[...] = None` would raise.
openai_key = os.getenv('OPENAI_KEY')
if not openai_key:
    raise RuntimeError(
        "OPENAI_KEY is not set; define it in the environment or in a .env file before running this cell."
    )
os.environ["OPENAI_API_KEY"] = openai_key
client = OpenAI()

# Embedding helper
def get_embedding(text):
    """Embed `text` with the "text-embedding-3-large" model and return the vector.

    The request asks for 1024 dimensions, and the embedding of the first item
    in the response is returned.
    """
    request = {
        "input": text,
        "model": "text-embedding-3-large",
        "dimensions": 1024,
    }
    result = client.embeddings.create(**request)
    first_item = result.data[0]
    return first_item.embedding

In [None]:
# Build a FAISS vector database for the ICD codes.
# The stored vectors are stacked into a single float32 matrix with one row per
# entry of embedding_data, so a row position in the index is also a row position
# in embedding_data.
embeddings = np.asarray(embedding_data['embeddings'].tolist(), dtype='float32')

# IndexFlatIP scores by inner product.
# NOTE(review): that equals cosine similarity only for unit-length vectors - confirm the stored embeddings are normalized.
index_icd = faiss.IndexFlatIP(embeddings.shape[1])
index_icd.add(embeddings)

def search_by_embedding(query, index, embedding_df, k=1):
    """Return the codes whose embeddings are closest to `query`.

    Args:
        query: Text to embed and search for.
        index: Object exposing `search(vectors, k)` that returns
            `(distances, positions)`, e.g. the FAISS index built over the ICD
            embeddings.
        embedding_df: DataFrame whose row order matches the order in which the
            vectors were added to `index`; needs the `code` and
            `long_description` columns.
        k: Number of nearest codes to ask the index for.

    Returns:
        DataFrame with the `code` and `long_description` columns of the hits,
        in the order the index returned them. It has fewer than `k` rows when
        the index reports fewer than `k` hits.
    """
    # The index is built from float32 vectors; build the query the same way
    # (shape: one query row) instead of relying on an implicit float64 conversion.
    query_vector = np.asarray([get_embedding(query)], dtype='float32')

    # The distances are not used; `positions` holds the row positions of the hits
    _, positions = index.search(query_vector, k)

    # FAISS fills missing hits with -1 (e.g. k larger than the index). `iloc`
    # would read -1 as "last row" and return a bogus match, so drop them.
    hit_positions = positions[0][positions[0] >= 0]

    return embedding_df.iloc[hit_positions][['code', 'long_description']]

# Run the hierarchical search

In [None]:
def is_node_sunbset_or_synonym(medical_condition, icd_term):
    '''
    Ask the chat model whether an ICD code description is a synonym or a subset of a medical condition.

    Args:
        medical_condition: The medical term being searched for.
        icd_term: The description of the ICD code to compare against it.

    Returns:
        The boolean from the model's structured answer: True when it judges the
        ICD description to be a synonym or subset of the medical condition,
        False otherwise. False is also returned when the response carries no
        parsed answer (for example a refusal), i.e. no confirmation counts as
        "not a match".
    '''
    # Schema for the structured output: a single boolean field
    class Answer(BaseModel):
        a: bool

    completion = client.beta.chat.completions.parse(
        model="gpt-4o",
        messages=[
            {
            "role": "system",
            "content": '''
            You will recieve a medical condition and an ICD code description. 
            If the icd description is a synonym or a subset of the medical term return "True" otherwise return "False". 

            For example:
            Medical term: "iron deficiency", ICD description: "iron deficiency anemia"
            Return "True"

            Medical term: "staph infection", ICD description: "staph pneumonia"
            Return "True"

            Medical term: "penumococal infection", ICD description: "pneumonia"
            Return "False"

            Return only "True" or "False"!
            '''
        },
        {
            "role": "user",
            "content": f"Medical term: {medical_condition}, ICD description: {icd_term}",
        }
        ],
        response_format=Answer,
        temperature=0,  # keep the verdicts as repeatable as possible
    )

    parsed_answer = completion.choices[0].message.parsed
    # `parsed` is None when the model gives no structured answer (e.g. it refuses).
    # Reading `.a` from it would raise AttributeError and abort a long search,
    # so treat the missing answer as "not a match".
    if parsed_answer is None:
        return False
    return parsed_answer.a

In [None]:
# Hierarchical search

def run_hierarchial_search(term, k=50):
    '''
    Run a hierarchical search for a medical term and return the matching ICD code nodes.

    Steps:
    - Embed the query term and fetch the k closest codes from the vector database.
    - For each hit, ask whether it is a subset/synonym of the term; if so, move
      one step up the hierarchy and ask again.
    - Stop at the highest node that is still a subset/synonym (the climb never
      goes past the nodes directly under the root) and collect that node plus
      every node below it.

    Args:
        term: The medical term to search for.
        k: Number of nearest codes to fetch from the vector database.

    Returns:
        List of unique tree nodes, in the order they were first collected.
    '''
    root_node = hierarcial_data['root']

    # Collected nodes in first-seen order, plus a set for O(1) "already collected" checks
    relevant_icd_codes = []
    collected = set()

    # Many hits share the same ancestors; remember each node's verdict so the
    # same question is not sent to the model more than once per search.
    verdict_by_node = {}

    def _is_relevant(node):
        if node not in verdict_by_node:
            verdict_by_node[node] = is_node_sunbset_or_synonym(term, node.description)
        return verdict_by_node[node]

    # search for the term in the icd codes and get k results
    optional_results = search_by_embedding(term, index=index_icd, embedding_df=embedding_data, k=k)

    for row in tqdm(optional_results.itertuples(index=False), total=len(optional_results)):
        current_node = hierarcial_data[row.code]

        # Already covered by a subtree collected for an earlier hit
        if current_node in collected:
            continue

        # Climb while the node is still a subset/synonym of the term
        highiest_node = None
        while _is_relevant(current_node):
            # Record the node that just passed the check (not the original hit),
            # so the result really is the highest relevant ancestor.
            highiest_node = current_node
            parent = current_node.get_parent()
            if parent is None or parent is root_node:  # reached the nodes directly under the root
                break
            current_node = parent

        # Include the highest relevant node and everything below it
        if highiest_node is not None:
            for node in get_self_and_all_lower_nodes(highiest_node):
                if node not in collected:
                    collected.add(node)
                    relevant_icd_codes.append(node)

    return relevant_icd_codes

# Visualizing the results 

In [None]:
def build_networkx_graph(node_list):
    """Turn a list of tree nodes into an undirected NetworkX graph.

    Every node must provide `icd_code`, `description` and `get_children()`.
    Graph nodes are keyed by description and carry the code as the `icd_code`
    attribute. A parent-child edge is added only when the child's description
    also belongs to some node in `node_list`.
    """
    known_descriptions = set()
    for entry in node_list:
        known_descriptions.add(entry.description)

    result = nx.Graph()
    for parent in node_list:
        parent_label = parent.description
        result.add_node(parent_label, icd_code=parent.icd_code)

        for kid in parent.get_children():
            kid_label = kid.description
            # Children outside the selection are left out of the graph
            if kid_label not in known_descriptions:
                continue
            result.add_node(kid_label, icd_code=kid.icd_code)
            result.add_edge(parent_label, kid_label)

    return result


def visualize_graph(graph, seed=None):
    """Draw the graph with node size and label font size scaled by node degree.

    Args:
        graph: NetworkX graph to draw; the node keys are used as the label text.
        seed: Optional seed passed to the spring layout. The layout starts from
            random positions, so pass a value to get the same picture on every
            run; the default (None) keeps the layout random.
    """
    plt.figure(figsize=(32, 32))  # Large figure so the labels have room
    pos = nx.spring_layout(graph, k=0.5, iterations=100, seed=seed)  # k=0.5 spreads the nodes further apart

    degree_of = {node: graph.degree(node) for node in graph.nodes}
    node_sizes = [1000 + 1000 * degree_of[node] for node in graph.nodes]  # Scale node size by degree
    font_sizes = {node: 10 + 3 * degree for node, degree in degree_of.items()}  # Scale font size by degree

    # Labels are not drawn here (no with_labels) because they are added below with per-node font sizes
    nx.draw(graph, pos, node_size=node_sizes, node_color='lightblue', font_size=10, font_weight='bold', edge_color='gray')

    # Draw custom labels with adjusted font size
    for node, (x, y) in pos.items():
        plt.text(x, y, s=node, fontsize=font_sizes[node], ha='center', va='center')

    # With no explicit labels, draw_networkx_edge_labels prints each edge's
    # attribute dict. For attribute-less edges that is a "{}" on every edge,
    # so only draw edge labels when at least one edge carries data.
    if any(attrs for _, _, attrs in graph.edges(data=True)):
        nx.draw_networkx_edge_labels(graph, pos, font_size=8)
    plt.show()

# Examples

### Query = Streptococcal infection

In [None]:
# Collect the ICD code nodes for the query, starting from the 100 nearest embedding hits.
# The query was misspelled ('sterptococcal'); it is sent verbatim to the embedding
# model and the chat model, so use the correct spelling of the condition.
final_list_strep = run_hierarchial_search('streptococcal infection', k=100)

In [None]:
# Turn the collected nodes into a graph, then draw it
graph = build_networkx_graph(node_list=final_list_strep)
visualize_graph(graph=graph)