### Set up the openai environment

In [110]:
# set up the openai environment
from openai import OpenAI
import os 
# os.environ["OPENAI_API_KEY"] = "<your-api-key>"
client = OpenAI()

### Query embedding

In [111]:
# define the function to get the embedding
def get_embedding(text, model="text-embedding-3-large"):
   text = text.replace("\n", " ")
   return client.embeddings.create(input = [text], model=model, dimensions=1024).data[0].embedding


heart_filure_embedding = get_embedding("heart failure")

In [112]:
text = '[' 
for x in heart_filure_embedding:
    text+=str(x)+','
text = text[:-1]+']'
print(text)

[-0.04984709620475769,0.07442083954811096,0.019687417894601822,0.008340643718838692,-0.021338054910302162,0.024836096912622452,-0.0018159756436944008,0.026803746819496155,-0.052864156663417816,0.0027232803404331207,-0.020048152655363083,0.048710234463214874,0.008428094908595085,-0.023502469062805176,-0.0060505191795527935,-0.03679502755403519,-0.03447757661342621,-0.0013582240790128708,-0.0018487698398530483,0.0035854929592460394,0.01623310148715973,-0.03388728201389313,0.008745105005800724,0.005667921155691147,-0.01844123937189579,0.04274170100688934,0.03939669579267502,0.011827753856778145,-0.03404032066464424,-0.02592923492193222,0.030061297118663788,0.05343259125947952,-0.009308070875704288,0.01616751216351986,0.07271554321050644,0.040621012449264526,-0.011499812826514244,0.013292559422552586,-0.023895999416708946,-0.035352084785699844,0.006493240129202604,-0.045343369245529175,-0.06432024389505386,0.08482751995325089,0.05282043293118477,0.037297870963811874,-0.007504392880946398,0

In [113]:
import faiss
import pandas as pd
import numpy as np

### Set up the vector database

In [114]:
icd_codes_embedding = pd.read_csv('../batch_openAI/data/icd_codes_with_embedding_1024.csv', index_col=0)
# The embedding is stored as a string, we need to convert it to a list
icd_codes_embedding['embedding_list'] = icd_codes_embedding['embedding'].apply(lambda x:eval(x))

In [115]:
# Extract embeddings from DataFrame
embeddings = np.array(icd_codes_embedding['embedding_list'].tolist()).astype('float32')

In [116]:
# Create FAISS index
index = faiss.IndexFlatIP(embeddings.shape[1])
# Add embeddings to the index
index.add(embeddings)

### Function for RAG searching 

In [117]:
# Function to search by embedding
def search_by_embedding(query, k=1):
    # Get the embedding of the query
    query_embedding = get_embedding(query)

    # Search for the nearest neighbors
    # D is the the disances, I is the indices of the nearest neighbors
    D, I = index.search(np.array([query_embedding]), k)

    # get the index of the results and extractt the corresponding rows from the ICD codes DataFrame
    results = icd_codes_embedding.iloc[I[0]][['code', 'long_description']]

    # Return the results
    final_output = ''
    for row in results.itertuples():
        final_output += f'{row.long_description}({row.code});'
    final_output = final_output[:-1]    
    return final_output

In [118]:
output = search_by_embedding("heart failure", k=5)
print(output)
for code in output.split(';'):
  print(code)

Other heart failure(I5089);Heart failure, unspecified(I509);End stage heart failure(I5084);High output heart failure(I5083);Chronic right heart failure(I50812)
Other heart failure(I5089)
Heart failure, unspecified(I509)
End stage heart failure(I5084)
High output heart failure(I5083)
Chronic right heart failure(I50812)


### Running the LLM+RAG algorithm on a paragraph from a clinical note

In [119]:
paragraph = '''
Patient: 68-year-old male.
History: CHF, dyslipidemia, HTN.
The patient presented to the clinic with a productive cough and fever for the past 10 days. 
He was referred by his primary care physician for an X-ray, which revealed an opacity in the left lower lobe.
Based on these findings, the patient was diagnosed with bacterial pneumonia and started on treatment with amoxicillin.
The patient's cough improved, but his shortness of breath worsened. 
On auscultation, wet crackles were heard over the lower portion of his lungs. 
His BNP level has increased to 150 from a previous measurement of 26.A repeated X-ray showed bilateral opacities.
Given his medical history and the current acute infective illness, the patient is likely experiencing a worsening of heart failure.
'''

In [120]:
# Function to check if the line contains a disease
def check_if_disease(line):
    response = client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[
            {
                "role": "system",
                "content": '''
                You are a medical coder that is going through a doctors note. 
                You will be given a line form the note.
                Only if the line contains a current disease or condition return True.
                Else return False.
                Be very strict about this! If the line contains past history, symptoms or signs without a diagnosis or non related information return False.  
                Return only True or False and nothing else!
                '''
            },
            {
                "role": "user",
                "content": f"line: {line}"
            }
        ],
        temperature=0
    )

    return eval(response.choices[0].message.content.strip())



# Function to chat with the model
def choose_codes(line, codes):
    response = client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[
            {
                "role": "system",
                "content": '''
                You are a medical coder that is going through a doctors note. 
                You will be given a line form the note and a list of 10 ICD-10 codes that might fit the description.
                Choose up to one code that best fit the description in the line. If no code is suitable provide no codes.
                Use only the codes provided and do not add any new codes!
                Provide the codes in the format: description(code)    
                '''
            },
            {
                "role": "user",
                "content": f"line: {line},  potential codes:{codes}"
            }
        ],
        temperature=0
    )

    return response.choices[0].message.content.strip()


def search_by_embedding(query, k=1):
    # Get the embedding of the query
    query_embedding = get_embedding(query)

    # Search for the nearest neighbors
    # D is the the disances, I is the indices of the nearest neighbors
    D, I = index.search(np.array([query_embedding]), k)

    # get the index of the results and extractt the corresponding rows from the ICD codes DataFrame
    results = icd_codes_embedding.iloc[I[0]][['code', 'long_description']]

    # Return the results
    final_output = ''
    for row in results.itertuples():
        final_output += f'{row.long_description}({row.code});'
    final_output = final_output[:-1]    
    return final_output

In [121]:
# split the note to sentences
splitted_notes = paragraph.split('.')


all_codes = []
for line in splitted_notes:
    # loop through the lines and check if they contain a disease
    if check_if_disease(line):
        print('This line contains a current disease:', line)

        # If the line contains a disease, search for the codes using the RAG
        output = search_by_embedding(line, k=10)
        print('Potential codes:',output)

        # Choose the best code from the list
        chosen_code = choose_codes(line, output)
        print('Chosen codes:', chosen_code)
        all_codes.append(chosen_code)
    else:
        print('This line does not contain a current disease:', line)
        continue
all_codes = set(all_codes)
print(all_codes)


This line does not contain a current disease: 
Patient: 68-year-old male
This line does not contain a current disease: 
History: CHF, dyslipidemia, HTN
This line does not contain a current disease: 
The patient presented to the clinic with a productive cough and fever for the past 10 days
This line does not contain a current disease:  
He was referred by his primary care physician for an X-ray, which revealed an opacity in the left lower lobe
This line contains a current disease: 
Based on these findings, the patient was diagnosed with bacterial pneumonia and started on treatment with amoxicillin
Potential codes: Unspecified bacterial pneumonia(J159);Pneumonia due to Klebsiella pneumoniae(J150);Pneumonia due to Acinetobacter baumannii(J1561);Pneumonia due to Streptococcus pneumoniae(J13);Whooping cough due to Bordetella pertussis with pneumonia(A3701);Whooping cough due to Bordetella parapertussis with pneumonia(A3711);Pneumonia due to Methicillin susceptible Staphylococcus aureus(J152

### Check the performance of an LLM with no help from the RAG

In [122]:
# Function to check if the line contains a disease
def check_if_disease_and_extract_icd_code(line):
    response = client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[
            {
                "role": "system",
                "content": '''
                You are a medical coder that is going through a doctors note. 
                You will be given a line form the note.
                If the line contains a current disease or condition return the ICD code that best fits the description.
                The code should be in the format: description(code). 
                Return only one code and only in the format provided.
                Else return False.
                Be very strict about this! If the line contains past history, symptoms or signs without a diagnosis or non related information return False.  
                '''
            },
            {
                "role": "user",
                "content": f"line: {line}"
            }
        ],
        temperature=0
    )
    # print(codes)
    return response.choices[0].message.content.strip()

#Running the function
splitted_notes = paragraph.split('.')
all_codes = []
for line in splitted_notes:
        result =  check_if_disease_and_extract_icd_code(line)
        if result != False: 
            all_codes.append(result)
all_codes = set(all_codes)
print(all_codes)

{'bacterial pneumonia (J18.9)', 'False', 'Heart failure (I50.9)'}
