In [None]:
import os
import pickle
import re

import networkx as nx
import pandas as pd
from dotenv import load_dotenv
from openai import OpenAI
from tqdm.auto import tqdm
from tqdm.contrib.concurrent import thread_map

# Get the data

## Table data for regular codes

In [None]:
def parse_file(file_path):
    """Split every line of a fixed-width text file into four trimmed fields.

    Each line is right-stripped and then cut at character offsets
    6-13, 13-16, 16-77 and 77-end; every piece is stripped of surrounding
    whitespace. Lines shorter than an offset simply yield empty strings
    for the missing fields.

    Args:
        file_path: Path of the text file to read.

    Returns:
        A list with one ``[section1, section2, section3, section4]`` list
        per line of the file, in file order.
    """
    # Column windows of the four fields, in output order.
    field_windows = (slice(6, 13), slice(13, 16), slice(16, 77), slice(77, None))

    with open(file_path, 'r') as handle:
        return [
            [raw_line.rstrip()[window].strip() for window in field_windows]
            for raw_line in handle
        ]

# Parse the fixed-width order file, then keep the code and both descriptions
# (the 'legal' column is discarded right away).
file_path = 'icd10cm-order-2025.txt'
parsed_data = parse_file(file_path)

icd_codes = pd.DataFrame(
    parsed_data,
    columns=['code', 'legal', 'short_description', 'long_description'],
).drop(columns=['legal'])

## Get the range headers from an ICD hierarchy document

In [None]:
# Split the hierarchy file into chapter headers and section headers.
# Every non-empty line is expected to look like "<code range> - <description>";
# lines containing the word "Chapter" are chapters, all others are sections.
chapters = []
sections = []
with open('chapters_and_sections.txt', 'r') as f:
    for line in f:
        # Strip the newline so it does not end up inside the description.
        line = line.strip()
        if not line:
            continue  # tolerate blank lines instead of raising IndexError

        # maxsplit=1 keeps descriptions that themselves contain ' - ' intact.
        splitted = line.replace('Chapter ', '').split(' - ', 1)
        if len(splitted) != 2:
            print(f'skipping malformed line: {line!r}')
            continue

        code, description = (part.strip() for part in splitted)
        if 'Chapter' in line:
            chapters.append([code, description])
        else:
            sections.append([code, description])

In [None]:
# Sanity check: how many chapter and section headers were collected.
len(chapters), len(sections)

# Helper functions

In [None]:
import re

def is_range_in_range(range1, range2):
    """Tell whether ``range1`` lies entirely inside ``range2``.

    Both arguments are either a single code (``'A15'``) or a dash-separated
    range (``'A15-A19'``). A code is reduced to ``(first character, integer)``
    where the integer is built from the remaining characters with every
    letter replaced by ``9`` (so ``'C7A'`` becomes ``('C', 79)``).

    Args:
        range1: The candidate inner code or range.
        range2: The candidate enclosing code or range.

    Returns:
        True if both ends of ``range1`` fall within ``range2``, else False.

    Raises:
        ValueError: If a range holds more than one ``-`` or the part after
            the first character cannot be read as an integer.
        IndexError: If a code is an empty string.
    """
    def to_key(code):
        # (leading letter, numeric tail with letters mapped to 9)
        return code[0], int(re.sub(r'[a-zA-Z]', '9', code[1:]))

    def bounds(range_str):
        # A bare code is a range whose two ends coincide.
        if '-' not in range_str:
            return to_key(range_str), to_key(range_str)
        first, last = range_str.split('-')
        return to_key(first), to_key(last)

    inner_lo, inner_hi = bounds(range1)
    outer_lo, outer_hi = bounds(range2)

    # Both inner letters must sit between the outer letters.
    lowest_letter, highest_letter = outer_lo[0], outer_hi[0]
    for letter in (inner_lo[0], inner_hi[0]):
        if not (lowest_letter <= letter <= highest_letter):
            return False

    # With the letters bracketed, lexicographic (letter, number) comparison
    # of each end against the matching outer end decides containment.
    return outer_lo <= inner_lo and inner_hi <= outer_hi

# Build the hierarchy tree

## Node class

In [None]:
# Node type used to build the ICD code hierarchy tree
class ICDcodeNode:
    """One entry of the ICD hierarchy: a code, its text and its tree links."""

    def __init__(self, icd_code, description, children, parent):
        """Store the code, its description, the child list and the parent node.

        The ``children`` list is kept by reference, not copied.
        """
        self.icd_code = icd_code
        self.description = description
        self.children = children
        self.parent = parent

    def get_children(self):
        """Return the list of child nodes."""
        return self.children

    def add_child(self, child):
        """Append ``child`` unless it is already among the children."""
        if child in self.children:
            return
        self.children.append(child)

    def get_parent(self):
        """Return the parent node (``None`` if none was set)."""
        return self.parent

    def set_parent(self, parent):
        """Replace the parent node."""
        self.parent = parent

    def __repr__(self):
        """Show the code, the description and the codes of the direct children."""
        child_codes = [kid.icd_code for kid in self.children]
        return f'{self.icd_code} - {self.description} -  Children: {child_codes}'

## Set up the root and the chapters and sections

In [None]:
# Build the synthetic root and seed the registry that will hold every node,
# keyed by its code.
root_node = ICDcodeNode('root', 'root', [], None)
node_dict = {'root': root_node}

In [None]:
# Hang every chapter directly under the root and register it by its code.
tree_root = node_dict['root']
for chapter in chapters:
    chapter_code, chapter_description = chapter[0], chapter[1]
    chapter_node = ICDcodeNode(chapter_code, chapter_description, [], parent=tree_root)
    tree_root.add_child(chapter_node)
    node_dict[chapter_code] = chapter_node

In [None]:
# Attach each section to the first chapter whose range contains it.
for section in sections:
    section_code, section_description = section[0], section[1]
    for chapter in chapters:
        if not is_range_in_range(section_code, chapter[0]):
            continue
        parent_node = node_dict[chapter[0]]
        section_node = ICDcodeNode(section_code, section_description, [], parent_node)
        parent_node.add_child(section_node)
        node_dict[section_code] = section_node
        break
    else:
        # No chapter range covered this section.
        print(f'error in section {section}')

## Other codes

In [None]:
# Store the character count of every code; later cells select codes by it.
icd_codes['len_code'] = icd_codes['code'].map(len)

In [None]:
# Three-character codes become children of the first section containing them.
three_letter = icd_codes.loc[icd_codes.len_code == 3].copy()
for row in three_letter.itertuples():
    for section in sections:
        section_code = section[0]
        if not is_range_in_range(row.code, section_code):
            continue
        parent_node = node_dict[section_code]
        code_node = ICDcodeNode(row.code, row.long_description, [], parent_node)
        parent_node.add_child(code_node)
        node_dict[row.code] = code_node
        break
    else:
        # No section range covered this code.
        print(f'error in section {row.code}')

In [None]:
# Attach the 4-7 character codes beneath the closest shorter code already in
# the tree. Lengths are processed in increasing order so that a parent is
# registered before any of its children are looked up.
for code_len in range(4, 8):
    longer_codes = icd_codes.loc[icd_codes.len_code == code_len]
    for long_code in longer_codes.itertuples():
        code = long_code.code

        # Digits only after the first character: the parent must be the code
        # minus its last character. If letters appear, try dropping up to
        # four trailing characters and take the longest prefix that exists.
        has_letters = re.search(r'[a-zA-Z]', code[1:]) is not None
        max_trim = 4 if has_letters else 1

        parent_node = None
        for trim in range(1, max_trim + 1):
            candidate = code[:-trim]
            if candidate in node_dict:
                parent_node = node_dict[candidate]
                break

        if parent_node is None:
            # Report the orphan and skip it rather than attaching it to the
            # parent left over from the previous iteration.
            print(code)
            continue

        child_node = ICDcodeNode(code, long_code.long_description, [], parent_node)
        parent_node.add_child(child_node)
        node_dict[code] = child_node

In [None]:
# Number of entries currently registered in node_dict.
len(list(node_dict.keys()))

In [None]:
# Persist the node registry; the context manager guarantees the file is
# flushed and closed even if pickling fails.
with open('icd_code_hierarchy.pkl', 'wb') as hierarchy_file:
    pickle.dump(node_dict, hierarchy_file)

# Visualization

In [None]:
import networkx as nx
import matplotlib.pyplot as plt

def build_networkx_graph(root_node):
    """Turn the tree rooted at ``root_node`` into an undirected NetworkX graph.

    Every node is added under its ``icd_code`` with its ``description`` as a
    node attribute, and every parent/child pair is joined by an edge.

    Args:
        root_node: Node exposing ``icd_code``, ``description`` and
            ``get_children()``.

    Returns:
        The populated ``nx.Graph``.
    """
    graph = nx.Graph()

    # Depth-first walk with an explicit stack. Children are pushed in reverse
    # so they are popped, and therefore inserted, in their original order.
    pending = [(None, root_node)]
    while pending:
        parent, node = pending.pop()
        if parent is not None:
            graph.add_edge(parent.icd_code, node.icd_code)
        graph.add_node(node.icd_code, description=node.description)
        pending.extend((node, child) for child in reversed(list(node.get_children())))

    return graph

def visualize_large_graph(graph):
    """Draw ``graph`` with matplotlib using a spring layout.

    Node markers shrink as the node label gets longer (labels containing a
    dash are treated as length 1), and nodes are coloured by the number of
    edges touching them. The figure is shown with ``plt.show()``.

    Args:
        graph: NetworkX graph whose node labels are strings.
    """
    def marker_size(label):
        # Labels with a dash (ranges) count as length 1, giving the largest markers.
        effective_length = 1 if '-' in label else len(label)
        return 1 / (effective_length + 1) * 2000  # Scale factor

    # k controls the spacing between nodes
    layout = nx.spring_layout(graph, k=0.1, iterations=100)

    plt.figure(figsize=(8.5, 8.5))

    # Edges first so the nodes are painted on top of them.
    nx.draw_networkx_edges(graph, layout, alpha=0.5, width=0.5, edge_color='#888')

    sizes = [marker_size(label) for label in graph.nodes()]
    edge_counts = [len(graph.edges(label)) for label in graph.nodes()]
    nx.draw_networkx_nodes(
        graph,
        layout,
        node_size=sizes,
        node_color=edge_counts,
        cmap=plt.cm.plasma,
        alpha=0.5,
    )

    # Each node is labelled with its own identifier.
    nx.draw_networkx_labels(
        graph,
        layout,
        {label: label for label in graph.nodes()},
        font_size=8,
    )

    plt.title('ICD Code Hierarchy Visualization', fontsize=16)
    plt.axis('off')
    plt.tight_layout()
    plt.show()


# Visualise only the subtree under section 'A15-A19'.
# A dedicated name is used so the module-level `root_node` (the real tree
# root) is not overwritten when this cell runs.
subtree_root = node_dict['A15-A19']

# Convert to NetworkX graph and visualize
graph = build_networkx_graph(subtree_root)
print('Built graph')
visualize_large_graph(graph)

# Embedding ICD codes

In [None]:
# Flatten the node registry into a (code, long_description) table and drop
# the synthetic root entry.
all_icds = pd.DataFrame(
    [[entry.icd_code, entry.description] for entry in node_dict.values()],
    columns=['code', 'long_description'],
)
all_icds = all_icds[all_icds['code'].ne('root')].copy()

In [None]:
# OpenAI environment: read the key from .env / the process environment and
# expose it under the variable name the OpenAI client looks for.
load_dotenv()
openai_key = os.getenv('OPENAI_KEY')
if openai_key is None:
    # Fail with a clear message instead of the TypeError that assigning
    # None to os.environ would raise.
    raise RuntimeError("OPENAI_KEY is not set; add it to the environment or the .env file")
os.environ["OPENAI_API_KEY"] = openai_key
client = OpenAI()

In [None]:
def get_embedding(text):
    """Return the embedding vector of ``text``.

    Calls the embeddings endpoint through the module-level ``client`` with
    model ``text-embedding-3-large`` and 1024 requested dimensions, then
    returns the embedding of the first item in the response.
    """
    request = {
        "input": text,
        "model": "text-embedding-3-large",
        "dimensions": 1024,  # Specify the desired number of dimensions
    }
    return client.embeddings.create(**request).data[0].embedding

In [None]:
# Embed every long description by applying get_embedding through tqdm's
# thread_map helper (one call per description).
embeddings = thread_map(get_embedding, all_icds['long_description'].values)

In [None]:
# Store the embedding results as a new column next to the code and description.
all_icds['embeddings'] =embeddings

In [None]:
# Persist the table with embeddings; the context manager guarantees the file
# is flushed and closed even if pickling fails.
with open('icd_codes_embeddings_2025_with_chapters.pkl', 'wb') as embeddings_file:
    pickle.dump(all_icds, embeddings_file)