In [None]:

import openai
import pandas as pd
import re
import numpy as np
from tqdm import tqdm
from sklearn.metrics.pairwise import cosine_similarity
import networkx as nx
import matplotlib.pyplot as plt
import nltk
from nltk.tokenize import word_tokenize
from nltk.corpus import stopwords
import string
import os
from collections import defaultdict

nltk.download('punkt')
# Recent NLTK releases load word_tokenize's sentence model from the 'punkt_tab'
# resource rather than 'punkt'.
nltk.download('punkt_tab')
nltk.download('stopwords')

# Read the OpenAI key from the environment instead of hard-coding a secret in
# source; falls back to the original empty placeholder when the variable is unset.
api_key = os.environ.get('OPENAI_API_KEY', '')
client = openai.OpenAI(api_key=api_key)

KG_file_path = './dataset/knowledge graph of chronic pain.xlsx'
file_path = './dataset/AI Data Set with Categories.csv'
embedding_save_path = './Embeddings_saved/CP_KG_embeddings'



def preprocess_text(text):
    """Normalise a KG label or free-text field before embedding.

    Missing values (None/NaN) become ''. Otherwise the text has parenthesised
    spans removed, underscores turned into spaces, is lowercased, stripped of
    ASCII punctuation, and re-joined from NLTK word tokens with single spaces.
    """
    if pd.isna(text):
        return ''
    cleaned = re.sub(r'\(.*?\)', '', text).strip().replace('_', ' ').lower()
    cleaned = cleaned.translate(str.maketrans('', '', string.punctuation))
    return ' '.join(word_tokenize(cleaned))

kg_data = pd.read_excel(KG_file_path, usecols=['subject', 'relation', 'object'])

# Undirected adjacency list: each triple is recorded from both endpoints as
# (relation, other_endpoint), in spreadsheet row order.
knowledge_graph = {}
for head, rel, tail in zip(kg_data['subject'], kg_data['relation'], kg_data['object']):
    knowledge_graph.setdefault(head, []).append((rel, tail))
    knowledge_graph.setdefault(tail, []).append((rel, head))

# Symptom vocabulary: normalised objects of every triple whose relation is not
# "is_a" (those rows get None and are dropped below).
kg_data['object_preprocessed'] = [
    None if rel == 'is_a' else preprocess_text(tail)
    for rel, tail in zip(kg_data['relation'], kg_data['object'])
]
symptom_nodes = kg_data['object_preprocessed'].dropna().unique().tolist()

def get_symptom_embeddings(symptom_nodes, save_path):
    """Load cached symptom embeddings, or compute them with OpenAI and cache them.

    Args:
        symptom_nodes: Symptom strings to embed, one request per string.
        save_path: Directory holding ``KG_embeddings.npy``; created if missing.

    Returns:
        numpy array with one embedding row per symptom (when freshly computed).
        A cached file is returned as stored, even if its row count differs from
        ``len(symptom_nodes)`` — a warning is printed in that case.
    """
    embeddings_path = os.path.join(save_path, 'KG_embeddings.npy')
    if os.path.exists(embeddings_path):
        print("load existing embeddings...")
        embeddings = np.load(embeddings_path)
        # The cache is keyed only by path, so a changed symptom list leaves rows
        # misaligned with symptom_nodes.
        if len(embeddings) != len(symptom_nodes):
            print(f"Warning: cached embeddings have {len(embeddings)} rows but there are "
                  f"{len(symptom_nodes)} symptom nodes; delete {embeddings_path} to regenerate.")
        return embeddings

    print("generate new embeddings...")
    symptom_embeddings = []
    for symptom in tqdm(symptom_nodes):
        response = client.embeddings.create(
            input=symptom,
            model="text-embedding-3-large"
        )
        symptom_embeddings.append(response.data[0].embedding)
    symptom_embeddings = np.array(symptom_embeddings)

    # np.save does not create missing parent directories.
    if save_path:
        os.makedirs(save_path, exist_ok=True)
    np.save(embeddings_path, symptom_embeddings)
    return symptom_embeddings

# Symptom embedding matrix (loaded from the on-disk cache when present).
symptom_embeddings = get_symptom_embeddings(symptom_nodes, save_path=embedding_save_path)

def find_top_n_similar_symptoms(query, symptom_nodes, symptom_embeddings, n, threshold=0.5):
    """Return up to ``n`` symptom nodes most similar to ``query``.

    The query is normalised with ``preprocess_text``, embedded with OpenAI
    ``text-embedding-3-large`` and compared to ``symptom_embeddings`` by cosine
    similarity.

    Args:
        query: Free-text patient field; None/NaN/'' yields ``[]``.
        symptom_nodes: Node names; row ``i`` of ``symptom_embeddings`` is assumed
            to belong to ``symptom_nodes[i]``.
        symptom_embeddings: Embedding matrix; surplus rows beyond
            ``len(symptom_nodes)`` are ignored.
        n: Maximum number of nodes to return; ``n <= 0`` yields ``[]``.
        threshold: Only similarities strictly above this are kept (default 0.5).

    Returns:
        Distinct node names in descending order of similarity.
    """
    if pd.isna(query) or not query:
        return []
    if n <= 0:
        # Previously n <= 0 never hit the stop condition and returned every match.
        return []

    query_preprocessed = preprocess_text(query)
    if not query_preprocessed:
        # Nothing left after normalisation (e.g. only punctuation); avoid sending
        # an empty string to the embeddings API.
        return []

    response = client.embeddings.create(
        input=query_preprocessed,
        model="text-embedding-3-large"
    )
    query_embedding = response.data[0].embedding
    if not query_embedding:
        return []

    # A stale cache may hold more rows than there are nodes; drop the surplus so
    # every index maps into symptom_nodes.
    limit = min(len(symptom_embeddings), len(symptom_nodes))
    if limit == 0:
        return []
    symptom_embeddings = symptom_embeddings[:limit]

    similarities = cosine_similarity([query_embedding], symptom_embeddings).flatten()

    top_n_symptoms = []
    seen = set()
    for i in similarities.argsort()[::-1]:
        if similarities[i] > threshold and symptom_nodes[i] not in seen:
            top_n_symptoms.append(symptom_nodes[i])
            seen.add(symptom_nodes[i])
            if len(top_n_symptoms) == n:
                break

    return top_n_symptoms

def compute_shortest_path_length(node1, node2, G):
    """Return the unweighted shortest-path length between two nodes of ``G``.

    Returns ``float('inf')`` when no path exists or when either node is absent
    from ``G``, so callers can compare distances without special-casing.
    """
    try:
        return nx.shortest_path_length(G, source=node1, target=node2)
    except (nx.NetworkXNoPath, nx.NodeNotFound):
        # NodeNotFound (missing endpoint) is treated like an unreachable node
        # instead of propagating.
        return float('inf')

# Candidate pain categories; find_closest_category votes among these nodes.
categories = [
    "thoracoabdominal_pain_syndromes",
    "neuropathic_pain_syndromes",
    "craniofacial_pain_syndromes",
    "cervical_spine_pain_syndromes",
    "limb_and_joint_pain_syndromes",
    "back_pain_syndromes",
    "lumbar_degenerative_and_stenosis_and_radicular_and_sciatic_syndromes",
    "generalized_pain_syndromes",
]

# Undirected NetworkX view of knowledge_graph. Each edge carries a "relation"
# attribute; when several triples join the same pair, the last one added wins.
G = nx.Graph()
G.add_edges_from(
    (node, neighbor, {"relation": relation})
    for node, edges in knowledge_graph.items()
    for relation, neighbor in edges
)

def get_diagnoses_for_symptom(symptom):
    """Return neighbours of ``symptom`` in the global graph ``G`` reached by a non-``is_a`` edge.

    Returns ``[]`` when ``symptom`` is not a node of ``G``. Edges without a
    ``relation`` attribute are skipped.
    """
    if symptom not in G:
        return []
    found = []
    for candidate in G.neighbors(symptom):
        attrs = G.get_edge_data(candidate, symptom)
        if not attrs or 'relation' not in attrs:
            continue
        if attrs['relation'] == 'is_a':
            continue
        found.append(candidate)
    return found

def find_closest_category(top_symptoms, categories, top_n):
    """Vote for the categories closest to the diagnoses linked to the given symptoms.

    For each symptom node present in the global graph ``G``, every neighbouring
    diagnosis (comma-separated names are split and normalised to
    ``lower_snake_case``) casts one vote for the category in ``categories``
    with the shortest path to it in ``G`` (earliest category wins ties).

    Args:
        top_symptoms: Iterable of symptom node names (list or pandas Series).
            Repeated names vote repeatedly.
        categories: Candidate category node names.
        top_n: Number of categories to return.

    Returns:
        Up to ``top_n`` category names ordered by descending vote count (ties keep
        the order of ``categories``), or ``None`` when ``top_symptoms`` is empty.
    """
    # Accept any iterable: the caller passes a plain list, which the previous
    # Series-only emptiness check never caught (it then returned arbitrary
    # zero-vote categories).
    symptoms = [] if top_symptoms is None else list(top_symptoms)
    if not symptoms:
        print("Warning: top_symptoms is empty.")
        return None

    category_votes = {category: 0 for category in categories}

    # Categories absent from the graph can never be reached; check once rather
    # than once per diagnosis.
    reachable_categories = []
    for category in categories:
        if category in G:
            reachable_categories.append(category)
        else:
            print(f"Category node not found in graph: {category}")

    for symptom in symptoms:
        if symptom not in G:
            print(f"Symptom node not found in graph: {symptom}")
            continue

        for diagnosis in get_diagnoses_for_symptom(symptom):
            # A node may name several diagnoses separated by commas.
            for single_diagnosis in diagnosis.split(','):
                # Strip surrounding spaces and match the graph's snake_case naming.
                single_diagnosis = single_diagnosis.strip().replace(' ', '_').lower()
                if single_diagnosis not in G:
                    print(f"Diagnosis node not found in graph: {single_diagnosis}")
                    continue

                min_distance = float('inf')
                closest_category = None
                for category in reachable_categories:
                    try:
                        distance = nx.shortest_path_length(G, source=single_diagnosis, target=category)
                    except nx.NetworkXNoPath:
                        distance = float('inf')
                    # Strict "<" keeps the earliest category on ties.
                    if distance < min_distance:
                        min_distance = distance
                        closest_category = category

                if closest_category is not None:
                    category_votes[closest_category] += 1
    print("Category votes:", category_votes)

    sorted_categories = sorted(category_votes.items(), key=lambda x: x[1], reverse=True)
    # Slicing (instead of indexing range(top_n)) tolerates top_n > len(categories).
    return [category for category, _ in sorted_categories[:top_n]]


def get_keyinfo_for_category(category, knowledge_graph):
    """Collect ``has_keyinfo`` objects of the nodes linked to ``category`` via ``is_a``.

    Args:
        category: Node name to look up.
        knowledge_graph: Adjacency dict mapping a node to a list of
            ``(relation, neighbor)`` pairs.

    Returns:
        Key-info values in adjacency order; ``[]`` when ``category`` is absent.
    """
    keyinfo_values = []
    # Direct dict lookup instead of scanning every node for equality with
    # ``category``; dict keys are unique, so the result is unchanged.
    for relation, neighbor in knowledge_graph.get(category, ()):
        if relation != "is_a" or neighbor not in knowledge_graph:
            continue
        keyinfo_values.extend(
            obj for rel, obj in knowledge_graph[neighbor] if rel == "has_keyinfo"
        )
    return keyinfo_values

def get_subjects_for_objects(objects, knowledge_graph):
    """Return the ``subject`` of every row whose ``object`` matches one of ``objects``.

    Args:
        objects: Object names; spaces are converted to underscores before matching.
        knowledge_graph: DataFrame with ``subject`` and ``object`` columns (the
            triples table, not the adjacency dict of the same name).

    Returns:
        Matching subjects grouped by input object (in ``objects`` order), rows in
        DataFrame order within each group; repeats are kept.
    """
    subjects = []
    object_column = knowledge_graph['object']
    for obj in objects:
        key = obj.replace(' ', '_')
        # Vectorised comparison instead of a Python-level iterrows() scan per object.
        subjects.extend(knowledge_graph.loc[object_column == key, 'subject'].tolist())
    return subjects

def find_level3_for_symptoms(top_symptoms, knowledge_graph):
    """Count how often each KG subject is linked to the given symptoms.

    Uses ``get_subjects_for_objects`` per symptom; the result maps subject ->
    number of hits, in first-seen order.
    """
    counts = {}
    for symptom in top_symptoms:
        for subject in get_subjects_for_objects([symptom], knowledge_graph):
            counts[subject] = counts.get(subject, 0) + 1
    return counts

def print_symptom_and_disease(symptom_nodes):
    """Print each symptom with the KG subjects that list it as an object.

    Subjects are looked up in the module-level ``kg_data`` triples table.
    """
    for symptom in symptom_nodes:
        subjects = get_subjects_for_objects([symptom], kg_data)
        # Previously ``subjects`` was computed and discarded, so nothing was printed.
        print(f"{symptom}: {subjects}")


def main_get_category_and_level3(n, participant_no, top_n):
    """Predict the most likely pain categories for one participant.

    Reads the participant's free-text fields from the dataset CSV at
    ``file_path``, matches each field to up to ``n`` KG symptom nodes, maps them
    back to the original KG object names and lets ``find_closest_category`` vote.

    Args:
        n: Number of KG symptom nodes matched per patient field.
        participant_no: Participant identifier (compared as a string).
        top_n: Number of categories to return.

    Returns:
        The list returned by ``find_closest_category``, or ``None`` when the
        participant is not in the CSV.
    """
    data = pd.read_csv(file_path, encoding='ISO-8859-1')

    # Compare as strings on both sides so the lookup works whether pandas parsed
    # the "Participant No." column as integers or as text.
    row = data.loc[data['Participant No.'].astype(str) == str(participant_no)]
    if row.empty:
        print(f"Participant No. {participant_no} not found!")
        return None

    pain_location = row["Pain Presentation and Description"].values[0]
    pain_symptoms = row["Pain descriptions and assorted symptoms (self-report)"].values[0]
    pain_restriction = row["Pain restriction"].values[0]
    print(f'pain_location: {pain_location}')
    print(f'pain_symptoms: {pain_symptoms}')
    print(f'pain_restrction: {pain_restriction}')
    # Normalise missing cells to ''. (Previously a missing pain_restriction
    # blanked pain_symptoms instead of itself.)
    if pd.isna(pain_location):
        pain_location = ''
    if pd.isna(pain_symptoms):
        pain_symptoms = ''
    if pd.isna(pain_restriction):
        pain_restriction = ''

    def _match_field(field_value):
        # Preprocessed KG symptom nodes most similar to one patient field.
        if pd.isna(field_value) or field_value == '':
            return []
        return find_top_n_similar_symptoms(field_value, symptom_nodes, symptom_embeddings, n)

    def _original_names(preprocessed_nodes):
        # Map preprocessed node strings back to their original KG object names.
        mask = kg_data['object_preprocessed'].isin(preprocessed_nodes)
        return kg_data.loc[mask, 'object'].drop_duplicates()

    matched_symptoms = []
    for field_value in (pain_location, pain_symptoms, pain_restriction):
        matched_symptoms.extend(_original_names(_match_field(field_value)))

    return find_closest_category(matched_symptoms, categories, top_n)


In [None]:
import openai
import faiss
import numpy as np
import os
import re
import json
import pandas as pd
from tqdm import tqdm
from huggingface_hub import InferenceClient
from KG_Retrieve import main_get_category_and_level3
from authentication import api_key,hf_token

client = openai.OpenAI(api_key=api_key)


def get_embeddings(texts):
    """Embed each text with OpenAI ``text-embedding-3-large`` (one request per text).

    Returns:
        numpy array with one embedding row per input text, in input order.
    """
    vectors = []
    for item in tqdm(texts):
        result = client.embeddings.create(model="text-embedding-3-large", input=item)
        vectors.append(result.data[0].embedding)
    return np.array(vectors)


def get_query_embedding(query):
    """Return the embedding vector for a single query string."""
    (embedding,) = get_embeddings([query])
    return embedding


def Faiss(document_embeddings, query_embedding, k):
    """Return indices of the ``k`` documents with the highest inner product to the query.

    Uses an exact ``faiss.IndexFlatIP`` built from ``document_embeddings`` on
    every call.

    Args:
        document_embeddings: 2-D array, one row per document.
        query_embedding: 1-D query vector of the same dimension.
        k: Number of neighbours; clamped to the number of documents.

    Returns:
        Array of shape ``(1, k')`` with document row indices, best first.
    """
    # FAISS works on contiguous float32 matrices.
    doc_matrix = np.ascontiguousarray(document_embeddings, dtype='float32')
    query_matrix = np.ascontiguousarray([query_embedding], dtype='float32')

    index = faiss.IndexFlatIP(doc_matrix.shape[1])
    index.add(doc_matrix)
    # FAISS pads missing results with -1 when k exceeds the index size; a -1
    # used to index a Python list would silently pick the last document.
    if index.ntotal:
        k = min(k, index.ntotal)
    _, indices = index.search(query_matrix, k)
    print("index: ", indices)
    return indices

def extract_diagnosis(generated_text):
    """Return every ``**Diagnosis**: <text>`` value found in ``generated_text``.

    The value runs to the end of its line (or of the text), with surrounding
    spaces/tabs trimmed. Unlike the previous pattern, a diagnosis on the last
    line (no trailing newline) and zero or several spaces after the colon are
    accepted.
    """
    return re.findall(r'\*\*Diagnosis\*\*:\s*(.*?)[ \t]*(?:\r?\n|$)', generated_text)

def remove_parentheses(text):
    """Remove every parenthesised span (non-greedy, no nesting) and trim whitespace.

    Non-string values (e.g. NaN from empty spreadsheet cells) are returned
    unchanged instead of raising ``TypeError``.
    """
    if not isinstance(text, str):
        return text
    return re.sub(r'\(.*?\)', '', text).strip()

def KG_preprocess(file_path):
    """Load the KG spreadsheet and build an undirected adjacency dict.

    Parenthesised text is first stripped from subjects and objects. Each triple
    is recorded from both ends: ``graph[subject]`` gets ``(relation, object)``
    and ``graph[object]`` gets ``(relation, subject)``, in row order.
    """
    triples = pd.read_excel(file_path, usecols=['subject', 'relation', 'object'])
    triples['subject'] = triples['subject'].apply(remove_parentheses)
    triples['object'] = triples['object'].apply(remove_parentheses)

    adjacency = {}
    for head, rel, tail in zip(triples['subject'], triples['relation'], triples['object']):
        adjacency.setdefault(head, []).append((rel, tail))
        adjacency.setdefault(tail, []).append((rel, head))
    return adjacency


def extract_features_from_json(file_path):
    """Read a patient-case JSON file and return ``(pain_location, pain_symptoms)``.

    Either value defaults to '' when its field is missing.
    """
    location_key = "Pain Presentation and Description Areas of pain as per physiotherapy input"
    symptoms_key = "Pain descriptions and assorted symptoms (self-report) Associated symptoms include: parasthesia, numbness, weakness, tingling, pins and needles"

    with open(file_path, 'r') as handle:
        case = json.load(handle)

    return case.get(location_key, ""), case.get(symptoms_key, "")

# Maps a level-3 label (looked up as a KG "subject" by
# get_additional_info_from_level_2) to its level-2 category.
# get_additional_info_from_level_2 inverts this table: for each predicted
# level-2 category it pulls the KG triples of every level-3 label mapped to it.
# NOTE(review): the entries below are illustrative placeholders. The level-2
# values must equal the category names that main_get_category_and_level3 returns
# (the `categories` list in KG_Retrieve, e.g. "back_pain_syndromes"); otherwise
# no knowledge-graph information is retrieved.
level_3_to_level_2 = {
    # Fill in level-3 -> level-2 pairs here; the entries below are examples only.

    # Respiratory System
    "acute_copd_exacerbation_infection": "respiratory_system",

    # Cardiovascular System
    "atrial_fibrillation": "cardiovascular_system",

}


def get_additional_info_from_level_2(participant_no, kg_path, top_n, match_n):
    """Summarise KG facts for the participant's predicted level-2 categories.

    The categories come from ``main_get_category_and_level3``. For each one, the
    level-3 labels mapped to it in ``level_3_to_level_2`` are looked up as KG
    subjects. Triples sharing a (subject, relation) are merged into one sentence
    ``"<subject> <relation> <obj1>, <obj2>, ..."``.

    Args:
        participant_no: Participant identifier.
        kg_path: KG spreadsheet with ``subject``/``relation``/``object`` columns.
        top_n: Number of level-2 categories to predict.
        match_n: Number of KG symptom nodes matched per patient field.

    Returns:
        All sentences joined with ", ", or ``None`` when nothing is found.
    """
    level_2_values = main_get_category_and_level3(match_n, participant_no, top_n)
    if not level_2_values:
        print(f"No data found for Participant No.: {participant_no}")
        return None

    kg_data = None  # read lazily, at most once per call
    additional_info = []
    for level_2_value in level_2_values:
        relevant_level_3_descriptions = [desc for desc, level2 in level_3_to_level_2.items() if level2 == level_2_value]
        print("Relevant Level 3 Descriptions:", relevant_level_3_descriptions)
        if not relevant_level_3_descriptions:
            print("No Level 3 descriptions found for Level 2:", level_2_value)
            continue

        if kg_data is None:
            kg_data = pd.read_excel(kg_path, usecols=['subject', 'relation', 'object'])
        if kg_data.empty:
            print("Knowledge graph data is empty.")
            return None

        merged_info = {}
        for level_3 in relevant_level_3_descriptions:
            related_info = kg_data[kg_data['subject'] == level_3]
            if related_info.empty:
                print(f"No related information found in KG for: {level_3}")
                continue
            for subject, relation, obj in zip(related_info['subject'],
                                              related_info['relation'],
                                              related_info['object']):
                merged_info.setdefault((subject, relation.replace('_', ' ')), []).append(obj)

        # Accumulate across categories; the list used to be reset here, so with
        # top_n > 1 only the last category's facts survived.
        for (subject, relation), objects in merged_info.items():
            additional_info.append(f"{subject} {relation} {', '.join(objects)}")

    if not additional_info:
        print("No additional information found.")
        return None

    final_info = ', '.join(additional_info)
    print("Additional Info:", final_info)
    return final_info


def get_system_prompt_for_RAGKG():
    """Return the system prompt for the RAG + knowledge-graph diagnosis task.

    The prompt lists the allowed diagnoses and the exact report layout
    (``### Diagnoses`` / ``1. **Diagnosis**: ...``) that downstream regexes parse.
    Item 3 previously ran onto the end of item 2's long diagnosis list; it now
    starts on its own line like the other items.
    """
    return '''
        You are a knowledgeable medical assistant with expertise in pain management.
        Your tasks are:
        1. Analyse and refer to the retrieved similar patients' cases and knowledge graph which may be relevant to the diagnosis and assist with new patient cases.
        2. Output of "Diagnoses" must come from : acute copd exacerbation infection, bronchiectasis, bronchiolitis, bronchitis, bronchospasm acute asthma exacerbation, pulmonary embolism, pulmonary neoplasm, spontaneous pneumothorax, urti, viral pharyngitis, whooping cough, acute laryngitis, acute pulmonary edema, croup, larygospasm, epiglottitis, pneumonia, atrial fibrillation, myocarditis, pericarditis, psvt, possible nstemi stemi, stable angina, unstable angina, gerd, boerhaave syndrome, pancreatic neoplasm, scombroid food poisoning, inguinal hernia, tuberculosis, hiv initial infection, ebola, influenza, chagas, acute otitis media, acute rhinosinusitis, allergic sinusitis, chronic rhinosinusitis, myasthenia gravis, guillain barre syndrome, cluster headache, acute dystonic reactions, sle, sarcoidosis, anaphylaxis, panic attack, spontaneous rib fracture, anemia.
        3. You are given differences of diagnoses of similar symptoms or pain locations. Read that information as a reference to your diagnostic if applicable.
        4. Do mind the nuance between these factors of similar diagnosis with knowledge graph information and consider it when diagnose new patient's informtation.
        5. Ensure that the recommendations are evidence-based and consider the most recent and effective practices in pain management.
        6. The output should include four specific treatment-related fields:
           - "Diagnoses (related to pain)"
           - Explanations of diagnose
           - "Pain/General Physiotherapist Treatments\nSession No.: General Overview\n- Specific interventions/treatments"
           - "Pain Psychologist Treatments"
           - "Pain Medicine Treatments"
        7. In "Diagnoses", only output the diagnosis itself. Place all other explanations and analyses (if any) into "Explanations of diagnose".
        8. You can leave Psychologist Treatments blank if not applicable for the case, leaving text "Not applicable"
        9.If you think information is needed, guide the doctor to ask further questions which following areas to distinguish between the most likely diseases: Pain restriction; Location; Symptom. Seperate answers with ",". The output should only include aspects.
        10. The output should follow this structured format:
        

    ### Diagnoses
    1. **Diagnosis**: Answer.
    2. **Explanations of diagnose**: Answer.
    
    ### Instructive question
    1. **Questions**: Answer.
    
    ### Pain/General Physiotherapist Treatments
    1. **Session No.: General Overview**
        - **Specific interventions/treatments**:
        - **Goals**:
        - **Exercises**:
        - **Manual Therapy**:
        - **Techniques**:

    2. **Exercise Recommendations from the Exercise List**:

    ### Pain Psychologist Treatments(if applicable)
    1. **Treatment 1**: 
    
    ### Pain Medicine Treatments


    ### Recommendations for Further Evaluations
    1. **Evaluation 1**:
    '''


def generate_diagnosis_report(path, query, retrieved_documents, i, top_n, match_n, model='gpt-4o'):
    """Ask an LLM for a structured diagnosis/treatment report for one patient.

    Args:
        path: KG spreadsheet passed to ``get_additional_info_from_level_2``.
        query: Serialised new patient case.
        retrieved_documents: Similar cases retrieved from the training corpus.
        i: Participant number used for the knowledge-graph lookup.
        top_n: Number of predicted categories to pull KG facts for.
        match_n: Number of KG symptom nodes matched per patient field.
        model: An OpenAI chat model from ``openai_models`` below (default
            ``'gpt-4o'``); any other value is served by the Hugging Face
            Meta-Llama-3.1-8B-Instruct endpoint.

    Returns:
        The generated report text.
    """
    system_prompt = get_system_prompt_for_RAGKG()
    additional_info = get_additional_info_from_level_2(i, path, top_n=top_n, match_n=match_n)

    prompt = f"{query}\nRetrieved Documents: {retrieved_documents}\nInformation from knowledge graph about relevant diagnoses, if you think the patient's disease is relevant from the suggestions provided by the atlas please refer to thoses details to distinguish similar diagnoses : {additional_info} .Now complete the tasks in that format"

    openai_models = ('gpt-4o', 'gpt-4o-mini', 'gpt-3.5-turbo-0125')
    # The previous test `model == 'gpt-4o' or 'gpt-4o-mini' or ...` was always
    # truthy, so the Hugging Face branch was unreachable.
    if model in openai_models:
        response = client.chat.completions.create(
            model=model,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": prompt}
            ]
        )
        return response.choices[0].message.content

    # NOTE(review): this is the Llama-2 [INST]/<<SYS>> prompt format; Llama 3.1
    # uses a different chat template — consider InferenceClient.chat_completion.
    prompt = f"""<s>[INST] <<SYS>> {system_prompt} <</SYS>> {prompt} [/INST]"""
    LLMclient = InferenceClient(
        "meta-llama/Meta-Llama-3.1-8B-Instruct",
        # Alternatives tried: meta-llama/Llama-2-13b-chat-hf,
        # meta-llama/Meta-Llama-3.1-70B-Instruct, Qwen/Qwen2-7B-Instruct,
        # mistralai/Mistral-7B-Instruct-v0.2, mistralai/Mixtral-8x7B-Instruct-v0.1
        token=hf_token
    )
    return LLMclient.text_generation(prompt=prompt, max_new_tokens=400)

def save_results_to_csv(results, output_file, columns=None):
    """Write result rows to ``output_file`` as CSV without an index column.

    Args:
        results: Rows, each with one value per column.
        output_file: Destination CSV path.
        columns: Column names. Defaults to the original four-column layout; pass
            e.g. five names for rows that also carry the generated report.
    """
    if columns is None:
        columns = ['Participant No.', 'Generated Diagnosis', 'True Diagnosis', 'Original Diagnosis']
    df = pd.DataFrame(results, columns=columns)
    df.to_csv(output_file, index=False)


# Retrieval corpus: every regular file in the training folder (the evaluation
# loop opens retrieved entries with json.load). The path previously read
# ".dataset/...", inconsistent with every other "./dataset/..." path.
folder_path="./dataset/df/train"
documents = [os.path.join(folder_path, file_name) for file_name in os.listdir(folder_path) if
             os.path.isfile(os.path.join(folder_path, file_name))]

document_embeddings_file_path='./dataset/document_embeddings.npy'

def save_embeddings(embeddings, file_path):
    """Persist an embedding matrix to ``file_path`` in NumPy ``.npy`` format."""
    np.save(file=file_path, arr=embeddings)

def load_embeddings(file_path):
    """Load an embedding matrix previously written with ``np.save``."""
    loaded = np.load(file=file_path)
    return loaded
if os.path.exists(document_embeddings_file_path):
    # NOTE: a cache written before this fix holds embeddings of the file *paths*,
    # not of the case contents; delete it so it is rebuilt from the contents.
    document_embeddings = load_embeddings(document_embeddings_file_path)
else:
    # Embed each case's CONTENT (the original passed the path strings in
    # `documents` straight to the embedding model). JSON cases are re-serialised
    # with json.dumps, the same way the evaluation loop serialises its query.
    document_texts = []
    for document_path in documents:
        with open(document_path, 'r') as handle:
            raw_text = handle.read()
        try:
            document_texts.append(json.dumps(json.loads(raw_text)))
        except ValueError:
            # Not valid JSON: embed the raw text unchanged.
            document_texts.append(raw_text)
    document_embeddings = get_embeddings(document_texts)
    save_embeddings(document_embeddings, document_embeddings_file_path)


In [None]:
import os
import re
import json
import pandas as pd
from tqdm import tqdm
from main_MedRAG import get_query_embedding, Faiss,  extract_diagnosis, documents, document_embeddings,generate_diagnosis_report, save_results_to_csv, get_additional_info_from_level_2,KG_preprocess, get_embeddings
from authentication import ob_path,test_folder_path,ground_truth_file_path,augmented_features_path

disease_list = [
    "Head pain", "Migraine", "Trigeminal neuralgia", "Cervical spondylosis", "Chronic neck pain", "Neck pain",
    "Chest pain", "Abdominal pain", "Limb pain", "Shoulder pain", "Hip pain", "Knee pain", "Buttock pain",
    "Calf pain", "Low back pain", "Chronic low back pain", "Mechanical low back pain", "Upper back pain",
    "Degenerative disc disease", "Lumbar spondylosis", "Lumbar canal stenosis", "Spinal stenosis", "Foraminal stenosis",
    "Lumbar_radicular_pain", "Radicular pain", "Sciatica", "Lumbosacral pain", "Generalized body pain", "Fibromyalgia",
    "Musculoskeletal pain", "Myofascial pain syndrome", "Neuropathic pain", "Post-herpetic neuralgia"
]
ground_truth = pd.read_csv(ground_truth_file_path, header=0)

results = []
file_paths = os.listdir(test_folder_path)
topk=1
top_n=1
match_n=5
samplerange=range(1,552)
# LLM passed to generate_diagnosis_report (the call previously omitted this
# required argument, so every attempt raised TypeError).
model = 'gpt-4o'
# Bound the retry loop: it previously retried forever, so any persistent error
# (bad file, missing column, auth failure) hung the whole run.
max_attempts = 3

# Fields kept from each retrieved training case when building the prompt.
retrieved_keys = [
    "Processed Diagnosis",
    "Pain Presentation and Description Areas of pain as per physiotherapy input",
    "Pain descriptions and assorted symptoms (self-report) Associated symptoms include: parasthesia, numbness, weakness, tingling, pins and needles",
    "Pain/General Physiotherapist Treatments (Treatments\nSession No.: General Overview\n- Specific interventions/treatments)",
    "Pain Psychologist Treatments (Treatments)",
    "Pain Medicine Treatments (Treatments)",
]

for i in tqdm(samplerange):

    print("topk:",topk)
    print("top_ns:",top_n)
    print("match_n:", match_n)
    print("i= ",i)
    file_path = os.path.join(test_folder_path, f"participant_{i}.json")
    if not os.path.exists(file_path):
        print(f'{i} is not found')
        continue

    with open(file_path, 'r') as file:
        new_patient_case = json.load(file)
        print(new_patient_case)

    participant_no = new_patient_case['Participant No.']
    query = json.dumps(new_patient_case)

    for attempt in range(1, max_attempts + 1):
        try:
            query_embedding = get_query_embedding(query)
            indices = Faiss(document_embeddings, query_embedding,k=topk)
            retrieved_documents = [documents[idx] for idx in indices[0]]
            final_retrieved_info = []
            for retrieved_document in retrieved_documents:
                with open(retrieved_document, 'r') as file:
                    patient_case_dict = json.load(file)
                final_retrieved_info.append(
                    {key: patient_case_dict[key] for key in retrieved_keys if key in patient_case_dict}
                )

            true_diagnosis_row = ground_truth.loc[ground_truth['Participant No.'] == participant_no]
            if true_diagnosis_row.empty:
                print(f"True diagnosis for patient_{participant_no} not found in ground truth data")
                break

            true_diagnosis = true_diagnosis_row['Processed Diagnosis'].values[0]
            ori_truth = true_diagnosis_row['Diagnoses (related to pain)'].values[0]
            generated_report_ori = generate_diagnosis_report(augmented_features_path,query, final_retrieved_info, i,top_n=top_n,match_n=match_n,model=model)
            print(generated_report_ori)

            generated_diagnosis = re.findall(r'\*\*Diagnosis\*\*:\s*(.*?)(?:\.|\n|$)', generated_report_ori)
            if not generated_diagnosis:
                # Recorded with an empty diagnosis; despite the message, no retry happens.
                print("Generated diagnosis is either empty or not in the specified disease list. Retrying...")
                results.append([participant_no, '', true_diagnosis, ori_truth, generated_report_ori])
                break

            print("Success!!!")
            results.append([participant_no, generated_diagnosis[0], true_diagnosis, ori_truth,generated_report_ori])
            print('________________________________________________________________')
            break
        except Exception as e:
            print(f"Error processing patient_{participant_no}: {e}. ")
    else:
        print(f"Giving up on patient_{participant_no} after {max_attempts} attempts.")

output_file = f"./test_results_topk{topk}_topn{top_n}_matchn{match_n}_{samplerange}_MedRAG.csv"

df = pd.DataFrame(results, columns=['Participant No.', 'Generated Diagnosis', 'True Diagnosis', 'Ori Truth','Generated report'])
df.to_csv(output_file, index=False)
