# This script combines the prior knowledge and extracted keywords from all lesions.


In [1]:
import pandas as pd
import secret
import json
import os
import pickle
import re

from graphrag_for_all.llm.openai import set_openai_api_key
from graphrag_for_all.llm.huggingface import set_hugging_face_token
from graphrag_for_all.llm.create import get_send_fn
from utils.query import get_questions_by_lesion

# Register credentials for both LLM back-ends. The OpenAI key is set even
# though `send_fn` below uses the Hugging Face source.
set_openai_api_key(secret.OPENAI_API_KEY)
set_hugging_face_token(secret.HUGGINGFACE_TOKEN)

# Every LLM request in this notebook goes through `send_fn`, backed by the
# Llama 3.1 8B Instruct model obtained from the Hugging Face source.
send_fn = get_send_fn(
    model_name="meta-llama/Meta-Llama-3.1-8B-Instruct",
    source="huggingface",
)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [2]:
# Lesion names considered by this notebook.
# NOTE(review): "pulmonary edema" was commented out of this list, so it holds
# four entries despite its name, and none of the cells shown here read it —
# confirm whether the exclusion is intended and whether the list is still needed.
top_5_lesions = [
    "enlarged cardiac silhouette",
    "pulmonary consolidation",
    "atelectasis",
    "pleural abnormality",
]

In [3]:
# Generation arguments passed to every `send_fn` call in this notebook
# (temperature 0.0 — presumably chosen for reproducible outputs).
DEFAULT_LLM_ARGS = dict(temperature=0.0, top_p=1.0)

In [4]:
# Load the keyword-extraction results saved by an earlier run; later cells read
# its "responses" and "keywords" entries.
# NOTE: pickle.load can execute arbitrary code — only load files you produced.
with open("./llama3_index_results/graphrag/extracted_keywords.pkl", mode="rb") as pkl_file:
    keyword_extraction_output = pickle.load(pkl_file)

In [5]:
from collections import OrderedDict

def build_prior_knowledge(keyword_extraction_output, questions_fn=None):
    """Render each lesion's question/answer pairs as one markdown section.

    Args:
        keyword_extraction_output: Mapping whose ``"responses"`` entry maps a
            lesion name to the answers for that lesion's questions, in the
            same order as the questions.
        questions_fn: Optional callable ``lesion -> questions`` supplying the
            questions for a lesion. Defaults to ``get_questions_by_lesion``;
            it can be overridden to use another question set or to call this
            function without the project's query module.

    Returns:
        OrderedDict mapping each lesion name (in ``"responses"`` order) to a
        string ``"# Lesion: <name>\\n"`` followed by one block per pair:
        a ``#`` separator line, ``**Question**: <q>`` and ``**Answer**:``
        followed by the answer.

    Note:
        Questions and answers are paired with ``zip``, so if their counts
        differ the surplus items are silently dropped. A repeated question
        keeps only its last answer.
    """
    if questions_fn is None:
        questions_fn = get_questions_by_lesion

    prior_knowledge = OrderedDict()
    for lesion, answers in keyword_extraction_output["responses"].items():
        # Build a dict first so repeated questions collapse to their last answer.
        q_a = dict(zip(questions_fn(lesion), answers))
        # join instead of repeated += avoids quadratic string building.
        q_a_section = "".join(
            f"\n#############################################\n**Question**: {q}\n**Answer**:\n{a}\n"
            for q, a in q_a.items()
        )
        prior_knowledge[lesion] = f"# Lesion: {lesion}\n" + q_a_section
    return prior_knowledge

In [6]:
# Per-lesion markdown sections of question/answer pairs, keyed by lesion name.
prior_knowledge = build_prior_knowledge(keyword_extraction_output=keyword_extraction_output)

In [7]:
# Sanity check: number of lesions in the prior knowledge (shown as the cell's output).
len(prior_knowledge)

5

In [8]:
# Join every lesion's section, separated by five newlines, into one document.
all_prior_knowledge = ("\n" * 5).join(prior_knowledge.values())

# Prompt asking the LLM to merge the per-lesion knowledge into one summary.
lesion_count = len(prior_knowledge)
lesion_names = ", ".join(prior_knowledge.keys())
requesting_prompt = (
    f" The following is the information from {lesion_count} lesions, "
    f"including {lesion_names}. Please combine and summarize them.\n"
    "\n"
    f"{all_prior_knowledge}\n"
    "\n"
    "(Please return the summarized version directly, without additional text.)\n"
    "\n"
)

In [9]:
# Ask the LLM to summarize the combined prior knowledge of all lesions.
summary_messages = [
    {"role": "system", "content": "You are a helpful clinical assistant."},
    {"role": "user", "content": requesting_prompt},
]
pk_res = send_fn(summary_messages, DEFAULT_LLM_ARGS)
# Recorded runtime of this call: 5m 11.2s

  attn_output = torch.nn.functional.scaled_dot_product_attention(


In [10]:
# Show the LLM's combined summary of all lesions.
print(pk_res.output)

**Summary of Lesions:**

1.  **Pulmonary Edema:**
    *   Symptoms: shortness of breath, coughing up pink, frothy mucus, fatigue, and weakness.
    *   Causes: cardiogenic pulmonary edema (acute exacerbation of congestive heart failure, volume overload, impaired left ventricular function, pericardial tamponade, and heart valve dysfunction), non-cardiogenic pulmonary edema (acute respiratory distress syndrome, pulmonary embolism, and high altitude pulmonary edema).
    *   Relevant clinical signs: acute onset of dyspnea, orthopnea, paroxysmal nocturnal dyspnea, peripheral edema, jugular venous distension, rales or crackles on auscultation, and signs of right heart failure.
    *   Relevant laboratory data: elevated B-type natriuretic peptide (BNP) and N-terminal pro b-type natriuretic peptide (NT-proBNP), arterial blood gas analysis, chest radiography, laboratory tests, and echocardiography.
    *   Relevant clinical characteristics: cardiogenic and non-cardiogenic causes, primary and s

In [11]:
# One markdown section per lesion, listing its extracted features as JSON.
kw_by_lesion = keyword_extraction_output['keywords']
lesion_keywords = "\n\n".join(
    f"## Lesion: {lesion}\n**Features:**{json.dumps(features)}\n"
    for lesion, features in kw_by_lesion.items()
)

# Prompt asking the LLM to merge and de-duplicate the features of all lesions
# and to correct their data types.
kw_lesion_count = len(kw_by_lesion)
kw_lesion_names = ", ".join(kw_by_lesion.keys())
keyword_combining_prompt = (
    f"The following json objects are features from {kw_lesion_count} different lesions. "
    "The key represents the feature, while the value indicates the data type.\n"
    "\n"
    f"Please refine and combine the following features from {kw_lesion_count} lesions, "
    f"including {kw_lesion_names}. \n"
    "\n"
    "These features will be used to predict diseases and lesions. However, some features "
    "may have incorrect data types, so feel free to correct or modify them as needed.\n"
    "\n"
    "And the repetitive or similar features from different lesion should be combined or removed.\n"
    "\n"
    "(Please only return the json object without additional text)\n"
    "\n"
    "# Features\n"
    "\n"
    f"{lesion_keywords}\n"
)

# The summary from the previous call is placed in the system prompt as context
# for the feature-refinement request.
keyword_messages = [
    {
        "role": "system",
        "content": f"You are a helpful clinical assistant and has following information in mind:\n{pk_res.output}",
    },
    {"role": "user", "content": keyword_combining_prompt},
]
refined_keywords_res = send_fn(keyword_messages, DEFAULT_LLM_ARGS)

In [12]:
# Show the refined feature set returned by the LLM (requested as a JSON object).
print(refined_keywords_res.output)

{
  "Shortness of breath": "boolean",
  "Coughing": "boolean",
  "Chest pain": "boolean",
  "Fever": "boolean",
  "Fatigue": "boolean",
  "Pleural effusion": "boolean",
  "Pneumothorax": "boolean",
  "Pulmonary tuberculosis": "boolean",
  "Pneumonia": "boolean",
  "Pulmonary edema": "boolean",
  "Heart failure": "boolean",
  "COPD": "boolean",
  "Lung disease": "boolean",
  "Smoking history": "boolean",
  "Occupation and exposure to chemicals": "boolean",
  "History of trauma or injury": "boolean",
  "Medical history": "boolean",
  "Pleuritic chest pain": "boolean",
  "Chills": "boolean",
  "Coughing up blood or rust-colored sputum": "boolean",
  "Oxygen levels": "numerical",
  "Body Temperature": "numerical",
  "Breathing Rate": "numerical",
  "Age": "numerical",
  "Pleural thickening": "boolean",
  "Exudative pleural effusion": "boolean",
  "Transudative pleural effusion": "boolean",
  "Glucose levels": "numerical",
  "Protein levels": "numerical",
  "Lactate dehydrogenase (LDH) leve

In [13]:
# Clinical variables available in the target dataset. Each name must appear
# once: duplicates would be repeated verbatim in `dataset_features_str`,
# which is injected into the LLM prompt that matches refined features to it.
dataset_features = [
    "Gender",
    "Age",
    "Blood Pressure",
    "Body Temperature",
    "Heart rate",
    "Respiratory Rate",
    "Oxygen Saturation",
]

dataset_features_str = ", ".join(dataset_features)

# Follow-up turn in the same conversation: ask which refined features are
# present in the dataset's own variable list.
existing_features_question = {
    "role": "user",
    "content": f"From above refined features, please indicate me the features that are exactly included in: {dataset_features_str}. (Only return a list of related features without additional text)",
}
res_existing_features = send_fn(
    refined_keywords_res.history + [existing_features_question],
    DEFAULT_LLM_ARGS,
)

In [14]:
# Show which refined features the LLM matched to the dataset's variable list.
print(res_existing_features.output)

[
  "Age",
  "Sex",
  "Heart rate",
  "Blood pressure",
  "Body Temperature",
  "Oxygen saturation",
  "Breathing rate",
  "Respiratory rate"
]


In [15]:
# Persist the three LLM responses from this notebook in a single pickle file.
# NOTE: the output path "combined_results" has no file extension.
combined_results = {
    "prior_knowledge": pk_res,
    "refined_keyword": refined_keywords_res,
    "existing_features": res_existing_features,
}
with open("combined_results", "wb") as out_file:
    pickle.dump(combined_results, out_file)