In [1]:
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from openai import OpenAI
from neo4j import GraphDatabase
from camel.storages import Neo4jGraph
from camel.agents import KnowledgeGraphAgent
from camel.loaders import UnstructuredIO
from concurrent.futures import ThreadPoolExecutor
from itertools import repeat
import shortuuid
import os
import random
import warnings
from tqdm.auto import tqdm
from itertools import combinations
from langchain_text_splitters import RecursiveCharacterTextSplitter
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
import backoff
import openai

# Silence UserWarning messages emitted from the langsmith.client module.
warnings.filterwarnings("ignore", category=UserWarning, module="langsmith.client")

# Shared Neo4j handle; connection settings are read from environment variables.
n4j = Neo4jGraph(
    url=os.getenv("NEO4J_URL"),
    username=os.getenv("NEO4J_USERNAME"),
    password=os.getenv("NEO4J_PASSWORD")
)

# OpenAI configuration: the API key comes from the environment, the model
# names are fixed here and read by ask_gpt / get_embedding / AgenticChunker.
openai_api_key = os.getenv("OPENAI_API_KEY")
model_name = "gpt-4.1-nano"
embedder_name = "text-embedding-3-small"
client = OpenAI(api_key=openai_api_key)

@backoff.on_exception(backoff.expo, openai.RateLimitError, max_time=999, max_tries=99)
def completions_with_backoff(**kwargs):
    """Call ``client.chat.completions.create`` with the given keyword arguments.

    The decorator retries the call with exponential backoff whenever
    ``openai.RateLimitError`` is raised, bounded by ``max_time=999`` and
    ``max_tries=99``. Other exceptions propagate unchanged.
    """
    return client.chat.completions.create(**kwargs)

def ask_gpt(user, sys) -> str:
    """Send one system/user message pair to the chat model and return its reply.

    Args:
        user: Content of the user message.
        sys: Content of the system message.

    Returns:
        The ``content`` of the first choice of the completion response.
    """
    conversation = [
        {"role": "system", "content": sys},
        {"role": "user", "content": user},
    ]
    request = dict(
        model=model_name,
        messages=conversation,
        max_tokens=5000,
        n=1,
        stop=None,
        temperature=0.5,
    )
    reply = completions_with_backoff(**request)
    first_choice = reply.choices[0]
    return first_choice.message.content

def get_embedding(text):
    """Return the embedding of ``text`` produced by the ``embedder_name`` model.

    Only the first item of the embeddings response is used.
    """
    result = client.embeddings.create(model=embedder_name, input=text)
    first_item = result.data[0]
    return first_item.embedding

# Similarity cut-off read by AgenticChunker._find_relevant_chunk and ref_link.
threshold = 0.8
def cosine_similarity_paragraphs(paragraph1: str, paragraph2: str) -> float:
    """Return the TF-IDF cosine similarity between two paragraphs.

    A fresh ``TfidfVectorizer`` is fitted on just the two inputs, so the score
    reflects vocabulary overlap between them only.

    Args:
        paragraph1: First text.
        paragraph2: Second text.

    Returns:
        The cosine similarity as a float. ``0.0`` is returned when the
        vectorizer cannot build a vocabulary (e.g. both texts are empty or
        contain no tokens), instead of propagating its ``ValueError``.
    """
    vectorizer = TfidfVectorizer()
    try:
        tfidf_matrix = vectorizer.fit_transform([paragraph1, paragraph2])
    except ValueError:
        # scikit-learn raises ValueError("empty vocabulary; ...") when neither
        # document yields a token; treat that as "nothing in common".
        return 0.0
    sim_matrix = cosine_similarity(tfidf_matrix[0:1], tfidf_matrix[1:2])
    return float(sim_matrix[0][0])


  from .autonotebook import tqdm as notebook_tqdm


# Agentic Chunker

In [2]:
def logging(file, message):
    """Append ``message`` plus a trailing newline to the file at path ``file``."""
    with open(file, mode="a") as sink:
        sink.writelines((message + "\n",))

def clearing(file):
    """Truncate the file at path ``file`` to zero length, creating it if absent."""
    # Opening in "w" mode already empties the file; nothing needs to be written.
    with open(file, mode="w"):
        pass

class AgenticChunker:
    """Group propositions into topical chunks, each with an LLM-written title and summary.

    ``self.chunks`` maps a chunk id to a dict with the keys ``chunk_id``,
    ``propositions``, ``title``, ``summary`` and ``chunk_index``.
    """

    def __init__(self):
        self.chunks = {}
        self.llm = ChatOpenAI(model=model_name, api_key=openai_api_key, temperature=0)

    def add_propositions(self, propositions):
        """Assign every proposition to a matching chunk, or open a new chunk for it.

        A proposition joins the chunk returned by ``_find_relevant_chunk``; when
        no chunk exists yet or none matches, a new chunk is created. Appending
        to an existing chunk does not refresh that chunk's title or summary.
        """
        for proposition in tqdm(propositions, desc="Adding propositions to chunks", unit="proposition"):
            if len(self.chunks) == 0:
                self._create_new_chunk(proposition)
                continue
            chunk_id = self._find_relevant_chunk(proposition)
            if chunk_id:
                self.chunks[chunk_id]['propositions'].append(proposition)
            else:
                self._create_new_chunk(proposition)

    def _get_new_chunk_summary(self, proposition):
        """Ask the LLM for a one-sentence summary of the chunk ``proposition`` will start."""
        PROMPT = ChatPromptTemplate.from_messages(
            [
                (
                    "system",
                    """
                    You are the steward of a group of chunks which represent groups of sentences that talk about a similar topic
                    You should generate a very brief 1-sentence summary which will inform viewers what a chunk group is about.

                    A good summary will say what the chunk is about, and give any clarifying instructions on what to add to the chunk.

                    You will be given a proposition which will go into a new chunk. This new chunk needs a summary.

                    Your summaries should anticipate generalization. If you get a proposition about apples, generalize it to food.
                    Or month, generalize it to "date and times".

                    Example:
                    Input: Proposition: Greg likes to eat pizza
                    Output: This chunk contains information about the types of food Greg likes to eat.

                    Only respond with the new chunk summary, nothing else.
                    """,
                ),
                ("user", "Determine the summary of the new chunk that this proposition will go into:\n{proposition}"),
            ]
        )

        runnable = PROMPT | self.llm

        new_chunk_summary = runnable.invoke({
            "proposition": proposition
        }).content

        return new_chunk_summary

    def _get_new_chunk_title(self, summary):
        """Ask the LLM for a short title matching the chunk ``summary``."""
        PROMPT = ChatPromptTemplate.from_messages(
            [
                (
                    "system",
                    """
                    You are the steward of a group of chunks which represent groups of sentences that talk about a similar topic
                    You should generate a very brief few word chunk title which will inform viewers what a chunk group is about.

                    A good chunk title is brief but encompasses what the chunk is about

                    You will be given a summary of a chunk which needs a title

                    Your titles should anticipate generalization. If you get a proposition about apples, generalize it to food.
                    Or month, generalize it to "date and times".

                    Example:
                    Input: Summary: This chunk is about dates and times that the author talks about
                    Output: Date & Times

                    Only respond with the new chunk title, nothing else.
                    """,
                ),
                ("user", "Determine the title of the chunk that this summary belongs to:\n{summary}"),
            ]
        )

        runnable = PROMPT | self.llm

        new_chunk_title = runnable.invoke({
            "summary": summary
        }).content

        return new_chunk_title


    def _create_new_chunk(self, proposition):
        """Create a chunk holding ``proposition`` with a fresh id, summary and title."""
        new_chunk_id = str(shortuuid.uuid())
        new_chunk_summary = self._get_new_chunk_summary(proposition)
        new_chunk_title = self._get_new_chunk_title(new_chunk_summary)

        self.chunks[new_chunk_id] = {
            'chunk_id' : new_chunk_id,
            'propositions': [proposition],
            'title' : new_chunk_title,
            'summary': new_chunk_summary,
            'chunk_index' : len(self.chunks)
        }

    def _find_relevant_chunk(self, proposition):
        """Return the id of the chunk most similar to ``proposition``, or None.

        Each chunk is represented by its title and summary; the best TF-IDF
        cosine similarity must reach the module-level ``threshold`` to count.
        """
        best_chunk_id = None
        best_similarity = 0
        for chunk in self.chunks.values():
            # Separate title and summary with a space: plain concatenation
            # fused the last title word with the first summary word into one
            # bogus token, which distorted the TF-IDF comparison.
            content = chunk['title'] + " " + chunk['summary']
            sim = cosine_similarity_paragraphs(content, proposition)
            if sim > best_similarity:
                best_similarity = sim
                best_chunk_id = chunk['chunk_id']
        if best_similarity < threshold:
            return None
        return best_chunk_id

    def get_chunks(self):
        """Return one string per chunk: its propositions joined by single spaces."""
        return [" ".join(chunk['propositions']) for chunk in self.chunks.values()]

    def pretty_print_chunks(self):
        """Rewrite ``current_chunks.txt`` with a listing of all chunks and their propositions."""
        clearing("current_chunks.txt")
        logging("current_chunks.txt", f"\nYou have {len(self.chunks)} chunks\n")
        for chunk_id, chunk in self.chunks.items():
            logging("current_chunks.txt", f"Chunk #{chunk['chunk_index']}")
            logging("current_chunks.txt", f"Chunk ID: {chunk_id}")
            logging("current_chunks.txt", f"Summary: {chunk['summary']}")
            logging("current_chunks.txt", "Propositions:")
            for prop in chunk['propositions']:
                logging("current_chunks.txt", f"    -{prop}")

# Summarize

In [3]:
# Instruction text for structured medical summaries. process_chunks passes it
# as the second argument of ask_gpt, i.e. as the system message.
sum_prompt = """
Generate a structured summary from the provided medical source (report, paper, or book), strictly adhering to the following categories. The summary should list key information under each category in a concise format: 'CATEGORY_NAME: Key information'. No additional explanations or detailed descriptions are necessary unless directly related to the categories:
Each category should be addressed only if relevant to the content of the medical source. Ensure the summary is clear and direct, suitable for quick reference.

ANATOMICAL_STRUCTURE: Mention any anatomical structures specifically discussed.
BODY_FUNCTION: List any body functions highlighted.
BODY_MEASUREMENT: Include normal measurements like blood pressure or temperature.
BM_RESULT: Results of these measurements.
BM_UNIT: Units for each measurement.
BM_VALUE: Values of these measurements.
LABORATORY_DATA: Outline any laboratory tests mentioned.
LAB_RESULT: Outcomes of these tests (e.g., 'increased', 'decreased').
LAB_VALUE: Specific values from the tests.
LAB_UNIT: Units of measurement for these values.
MEDICINE: Name medications discussed.
MED_DOSE, MED_DURATION, MED_FORM, MED_FREQUENCY, MED_ROUTE, MED_STATUS, MED_STRENGTH, MED_UNIT, MED_TOTALDOSE: Provide concise details for each medication attribute.
PROBLEM: Identify any medical conditions or findings.
PROCEDURE: Describe any procedures.
PROCEDURE_RESULT: Outcomes of these procedures.
PROC_METHOD: Methods used.
SEVERITY: Severity of the conditions mentioned.
MEDICAL_DEVICE: List any medical devices used.
SUBSTANCE_ABUSE: Note any substance abuse mentioned.
"""

def chunking(docs, chunk_size=32768, chunk_overlap=1024):
    """Split the text ``docs`` with a ``RecursiveCharacterTextSplitter``.

    Args:
        docs: Text to split.
        chunk_size: Forwarded to the splitter as ``chunk_size``.
        chunk_overlap: Forwarded to the splitter as ``chunk_overlap``.

    Returns:
        Whatever the splitter's ``split_text`` returns for ``docs``.
    """
    splitter_options = {"chunk_size": chunk_size, "chunk_overlap": chunk_overlap}
    return RecursiveCharacterTextSplitter(**splitter_options).split_text(docs)

def process_chunks(content: str):
    """Summarize ``content`` piece by piece and return the summaries in order.

    The text is split with ``chunking``; every piece is sent to ``ask_gpt``
    with ``sum_prompt`` as the system message, using a thread pool.
    """
    pieces = chunking(content)
    with ThreadPoolExecutor() as pool:
        pending = [pool.submit(ask_gpt, piece, sum_prompt) for piece in pieces]
        # Collect in submission order so results line up with the pieces.
        responses = [job.result() for job in pending]
    return responses

def run_chunk(docs, begin=0):
    """Chunk ``docs`` with ``AgenticChunker`` and return the chunks from ``begin`` on.

    The text is first split by ``chunking``; the pieces are fed to the chunker
    as propositions, the result is logged via ``pretty_print_chunks``, and the
    joined chunk texts are sliced with ``[begin:]``.
    """
    pieces = chunking(docs)
    chunker = AgenticChunker()
    chunker.add_propositions(pieces)
    chunker.pretty_print_chunks()
    return chunker.get_chunks()[begin:]

# Create Graph

In [4]:
def add_ge_emb(graph_element):
    """Store an embedding of each node's ``id`` under ``properties['embedding']``.

    The nodes are modified in place and the same ``graph_element`` is returned.
    """
    for vertex in graph_element.nodes:
        vertex.properties['embedding'] = get_embedding(vertex.id)
    return graph_element

def add_gid(graph_element, gid):
    """Tag every node and relationship of ``graph_element`` with ``gid``.

    ``properties['gid']`` is set in place, nodes first and then relationships;
    the same ``graph_element`` is returned.
    """
    for collection_name in ("nodes", "relationships"):
        for member in getattr(graph_element, collection_name):
            member.properties['gid'] = gid
    return graph_element

def link_sum(gid):
    """Connect the Summary node(s) of ``gid`` to every other node with that gid.

    Runs on the module-level ``n4j`` handle. The Cypher uses
    ``apoc.periodic.iterate`` (batch size 5000, not parallel) to MERGE a
    ``SUMMARIZES`` relationship from each Summary node to each non-Summary
    node sharing the gid.
    """
    query_params = {'gid': gid}
    batched_merge = """
        CALL apoc.periodic.iterate(
        '
            MATCH (s:Summary {gid: $gid}), (n)
            WHERE n.gid = $gid AND NOT n:Summary
            RETURN s, n
        ',
        '
            MERGE (s)-[:SUMMARIZES]->(n)
        ',
        {
            batchSize: 5000,
            parallel: false,
            params: { gid: $gid }
        }
        );
        """
    n4j.query(batched_merge, query_params)

def reset_sum():
    """Drop all ``SUMMARIZES`` relationships and rebuild them for every gid.

    Works on the module-level ``n4j`` handle: first deletes every
    ``(:Summary)-[:SUMMARIZES]->()`` relationship, then calls ``link_sum`` for
    each distinct ``n.gid`` value found in the graph.
    """
    drop_links_query = """
        MATCH (n:Summary)-[r:SUMMARIZES]->()
        DELETE r
        """
    n4j.query(drop_links_query)
    print("Done removing duplicates!")
    gid_rows = n4j.query("MATCH (n) RETURN DISTINCT n.gid")
    known_gids = [row['n.gid'] for row in gid_rows]
    for current_gid in tqdm(known_gids, desc="Re-adding summaries", unit="gid"):
        link_sum(current_gid)
    print("Done re-adding summaries!")

def add_sum(n4j, content, gid):
    """Summarize ``content``, store it as a Summary node for ``gid`` and link it.

    The list returned by ``process_chunks`` becomes the node's ``content``
    property. Linking is delegated to ``link_sum``, which queries the
    module-level ``n4j`` handle rather than the ``n4j`` argument.
    """
    summaries = process_chunks(content)
    create_summary_node = """
        CREATE (s:Summary {content: $sum, gid: $gid})
        RETURN s
        """
    n4j.query(create_summary_node, {'sum': summaries, 'gid': gid})
    link_sum(gid)

def creat_metagraph(content, gid, n4j, begin=0,
                    cache_dir="/home/ngjabach/Documents/NgJaBach/Medical-Graph-RAG/half_baked/"):
    """Build the knowledge graph for one document and attach its summary.

    The document is chunked (or the chunks are reloaded from a cache file named
    ``<gid>.txt`` in ``cache_dir``), every chunk is turned into graph elements
    by ``KnowledgeGraphAgent``, enriched with embeddings and the gid, and
    written to ``n4j``. Finally ``add_sum`` stores a summary of the whole text.

    Args:
        content: Full text of the document.
        gid: Graph id used to tag all nodes/relationships and name the cache file.
        n4j: Graph handle providing ``add_graph_elements`` (and used by ``add_sum``).
        begin: Index of the first chunk to keep; forwarded to ``run_chunk``.
            It has no effect when the chunks come from an existing cache file.
        cache_dir: Directory holding the chunk cache. Defaults to the path that
            was previously hard-coded.

    Returns:
        The ``n4j`` handle that was passed in.
    """
    import json  # local import: only needed for the chunk cache format

    print(f"gid: {gid}")
    uio = UnstructuredIO()
    kg_agent = KnowledgeGraphAgent()
    whole_chunk = content
    saved_chunks = os.path.join(cache_dir, str(gid) + ".txt")
    if os.path.exists(saved_chunks):
        chunks = []
        with open(saved_chunks, 'r') as file:
            for line in file:
                raw = line.strip()
                # New cache lines are JSON strings so that chunks containing
                # newlines survive the round trip. Lines that are not a JSON
                # string are treated as legacy plain-text entries.
                try:
                    decoded = json.loads(raw)
                except ValueError:
                    decoded = raw
                chunks.append(decoded if isinstance(decoded, str) else raw)
    else:
        # Make sure the cache can be written before paying for the LLM-backed
        # chunking; previously a missing directory failed only afterwards.
        os.makedirs(cache_dir, exist_ok=True)
        chunks = run_chunk(content, begin)
        with open(saved_chunks, 'w') as file:
            for item in chunks:
                # One JSON-encoded chunk per line: embedded newlines are
                # escaped, so each chunk reloads as exactly one chunk.
                file.write(json.dumps(str(item)) + '\n')
    for cont in tqdm(chunks, desc="Processing chunks", unit="chunk"):
        element_example = uio.create_element_from_text(text=cont)
        graph_elements = kg_agent.run(element_example, parse_graph_elements=True)
        graph_elements = add_ge_emb(graph_elements)
        graph_elements = add_gid(graph_elements, gid)
        n4j.add_graph_elements(graph_elements=[graph_elements])

    add_sum(n4j, whole_chunk, gid)
    print("God blessed us all!")
    return n4j

# Data Chunk

In [5]:
def ref_link(n4j, gid1, gid2):
    """Create ``REFERENCE`` links between similar nodes of two graphs.

    For every pair of non-Summary nodes (n from ``gid1``, m from ``gid2``) with
    the same label set and an embedding cosine similarity above the
    module-level ``threshold``, the query MERGEs ``(m)-[:REFERENCE]->(n)``.

    Returns:
        The result of ``n4j.query`` for that statement.
    """
    query_params = {'gid1': gid1, 'gid2': gid2, 'threshold': threshold}
    cross_graph_query = """
        MATCH (a)
        WHERE a.gid = $gid1 AND NOT a:Summary
        WITH collect(a) AS GraphA

        MATCH (b)
        WHERE b.gid = $gid2 AND NOT b:Summary
        WITH GraphA, collect(b) AS GraphB

        UNWIND GraphA AS n
        UNWIND GraphB AS m

        WITH n, m, $threshold AS threshold
        WHERE apoc.coll.sort(labels(n)) = apoc.coll.sort(labels(m)) AND n <> m
        WITH n, m, threshold,
            vector.similarity.cosine(n.embedding, m.embedding) AS similarity
        WHERE similarity > threshold
        MERGE (m)-[:REFERENCE]->(n)

        RETURN n, m
"""
    return n4j.query(cross_graph_query, query_params)

def check(file_name):
    """Tell whether ``file_name`` still needs to be processed.

    ``Done.txt`` (in the current working directory) lists one processed file
    name per line, as written by ``logging("Done.txt", file_name)``.

    Args:
        file_name: Name to look up.

    Returns:
        ``False`` (after printing a skip notice) when the name is listed in
        ``Done.txt``; ``True`` otherwise, including when ``Done.txt`` does not
        exist yet (nothing has been processed).
    """
    try:
        with open("Done.txt", "r") as file:
            # Compare whole lines so that names containing spaces match;
            # splitting on any whitespace broke such names apart.
            taboo = {line.strip() for line in file if line.strip()}
    except FileNotFoundError:
        # First run: no file has been marked as done so far.
        return True
    if file_name in taboo:
        print(f"Skipping {file_name} as it is already processed.")
        return False
    return True

def mega_load(n4j, data_path, k=-1, msg="files"):
    """Load up to ``k`` random files from ``data_path`` into the graph and cross-link graphs.

    Every selected file that ``check`` does not report as already processed is
    read, given a fresh gid, passed to ``creat_metagraph`` and then recorded in
    ``Done.txt``. Afterwards ``ref_link`` is run for every pair of distinct gids
    found in the graph.

    Args:
        n4j: Graph handle (needs ``query``; also passed to ``creat_metagraph``).
        data_path: Directory whose regular files are candidates.
        k: Number of files to sample; ``-1`` means all. Values larger than the
            number of available files are clamped to that number.
        msg: Noun used in the progress messages.
    """
    files = [file for file in os.listdir(data_path) if os.path.isfile(os.path.join(data_path, file))]
    # random.sample raises ValueError when asked for more items than exist,
    # so treat an over-large k like "all files".
    if k == -1 or k > len(files):
        k = len(files)
    print(f"Number of {msg}: {len(files)}")
    print(f"Randomly selecting {k} {msg}...")
    k_selected = random.sample(files, k)
    k_selected = sorted(k_selected)
    print(f"Selected {k} {msg}: {k_selected}")
    for i, file_name in enumerate(k_selected):
        if not check(file_name):
            continue
        print(f"Processing {i + 1}th {msg}: {file_name}")
        file_path = os.path.join(data_path, file_name)
        with open(file_path, 'r', encoding='utf-8') as file:
            # Strip each line and re-terminate it with "\n" (joined once
            # instead of growing a string line by line).
            content = "".join(line.strip() + "\n" for line in file)
        gid = str(shortuuid.uuid())
        n4j = creat_metagraph(content, gid, n4j)
        logging("Done.txt", file_name)
    list_gids = n4j.query("MATCH (n) RETURN DISTINCT n.gid")
    # Nodes without a gid yield a null entry; `x.gid = null` never matches in
    # Cypher, so linking against it would only issue no-op queries.
    list_gids = [gid['n.gid'] for gid in list_gids if gid['n.gid'] is not None]
    for i in range(len(list_gids)):
        for j in range(0, i):
            gid1 = list_gids[i]
            gid2 = list_gids[j]
            ref_link(n4j, gid1, gid2)

# Main

In [6]:
# print("Processing first floor...")
# mega_load(n4j, "./patients", k=5, msg="patient(s)")

In [7]:
# def half_baked(file_path, gid, n4j):
#     list_gids = n4j.query("MATCH (n) RETURN DISTINCT n.gid")
#     list_gids = [gid['n.gid'] for gid in list_gids]
#     content = ""
#     with open(file_path, 'r', encoding='utf-8') as file:
#         for line in file:
#             content += line.strip() + "\n"
#     n4j = creat_metagraph(content, gid, n4j, 509+455)

# half_baked("./books_MEDQA/Pathology_Robbins.txt", "e9VJYZmhXzdrr2PXAqmrne", n4j)

In [8]:
# print("Processing second floor...")
# mega_load(n4j, "./books_MEDQA/", msg="book(s)")

In [9]:
# list_gids = n4j.query("MATCH (n) RETURN DISTINCT n.gid")
# list_gids = [gid['n.gid'] for gid in list_gids]
# for i in range(len(list_gids)):
#     for j in range(0, i):
#         print(f"Linking {i} and {j}")
#         gid1 = list_gids[i]
#         gid2 = list_gids[j]
#         ref_link(n4j, gid1, gid2)

In [10]:
def link_context(n4j, gid):
    """Describe what the graph ``gid`` references in other graphs, as sentences.

    For each non-Summary node n of ``gid`` with a ``REFERENCE`` relationship to
    a node m, the query collects m's non-REFERENCE relationships to
    non-Summary nodes o. Each (n, m, relationship, o) combination becomes one
    string of the form
    ``"Reference <i>: <n.id> has the reference that <m.id> <TYPE> <o.id>"``.

    Args:
        n4j: Graph handle providing ``query``.
        gid: Graph id whose outgoing references are described.

    Returns:
        List of context strings (empty when there are no references).
    """
    retrieve_query = """
        MATCH (n)
        WHERE n.gid = $gid AND NOT n:Summary

        MATCH (n)-[r:REFERENCE]->(m)
        WHERE NOT m:Summary

        MATCH (m)-[s]-(o)
        WHERE NOT o:Summary AND TYPE(s) <> 'REFERENCE'

        RETURN n.id AS NodeId1, 
            m.id AS Mid, 
            TYPE(r) AS ReferenceType, 
            collect(DISTINCT {RelationType: type(s), Oid: o.id}) AS Connections
    """
    res = n4j.query(retrieve_query, {'gid': gid})
    cont = []
    for r in res:
        for ind, connection in enumerate(r["Connections"]):
            # Separate the parts with spaces; plain concatenation glued ids and
            # relationship types into unreadable tokens ("XTREATSY"). Using an
            # f-string also tolerates ids that come back as null.
            cont.append(
                f"Reference {ind}: {r['NodeId1']} has the reference that "
                f"{r['Mid']} {connection['RelationType']} {connection['Oid']}"
            )
    return cont

def ret_context(n4j, gid):
    """List the relationships inside the graph ``gid`` as ``"<id1> <TYPE> <id2>"`` strings.

    The query matches relationships between non-Summary nodes of ``gid`` and
    keeps only the orientation with ``elementId(n) < elementId(m)``.

    Args:
        n4j: Graph handle providing ``query``.
        gid: Graph id to describe.

    Returns:
        List of context strings (empty when the graph has no such relationships).
    """
    ret_query = """
    // Match all nodes with a specific gid but not of type "Summary" and collect them
    MATCH (n)
    WHERE n.gid = $gid AND NOT n:Summary
    WITH collect(n) AS nodes

    // Unwind the nodes to a pairs and match relationships between them
    UNWIND nodes AS n
    UNWIND nodes AS m
    MATCH (n)-[r]-(m)
    WHERE n.gid = m.gid AND elementId(n) < elementId(m) AND NOT n:Summary AND NOT m:Summary // Ensure each pair is processed once and exclude "Summary" nodes in relationships
    WITH n, m, TYPE(r) AS relType

    // Return node IDs and relationship types in structured format
    RETURN n.id AS NodeId1, relType, m.id AS NodeId2
    """
    res = n4j.query(ret_query, {'gid': gid})
    # Space-separate the triple; plain concatenation produced tokens such as
    # "AspirinTREATSPain". The f-string also tolerates null ids.
    return [f"{r['NodeId1']} {r['relType']} {r['NodeId2']}" for r in res]

In [None]:
import pandas as pd

def find_index_of_largest(nums):
    """Return the index of the largest element of ``nums``.

    When the maximum occurs more than once, the highest such index is
    returned. An empty ``nums`` raises ``IndexError``.
    """
    ranked = [(value, position) for position, value in enumerate(nums)]
    ranked.sort()  # ascending by value, ties broken by position
    _, winner = ranked[-1]
    return winner
sys_p = """
Assess the similarity of the two provided summaries and return a rating from these options: 'very similar', 'similar', 'general', 'not similar', 'totally not similar'. Provide only the rating.
"""
def seq_ret(n4j, sumq):
    """Return the gid of the stored summary that the LLM rates most similar to ``sumq``.

    Every ``Summary`` node is compared with the query summary through
    ``ask_gpt`` (system prompt ``sys_p``). The textual rating is mapped to a
    score from 0 ('totally not similar') to 4 ('very similar'), or -1 when no
    known rating is found. The gid with the highest score wins, as selected by
    ``find_index_of_largest``.

    Args:
        n4j: Graph handle providing ``query``.
        sumq: Sequence whose first element is the query summary text.

    Returns:
        The ``gid`` property of the best-rated Summary node.
    """
    sum_query = """
        MATCH (s:Summary)
        RETURN s.content, s.gid
        """
    res = n4j.query(sum_query)
    sumk = [r['s.content'] for r in res]
    gids = [r['s.gid'] for r in res]

    # Order matters: longer phrases must be tested before the phrases they
    # contain ("totally not similar" > "not similar" > "similar").
    rating_scores = (
        ("totally not similar", 0),
        ("not similar", 1),
        ("general", 2),
        ("very similar", 4),
        ("similar", 3),
    )

    rating_list = []
    for sk in sumk:
        sk = sk[0]  # only the first stored summary piece is compared
        rate = ask_gpt("The two summaries for comparison are: \n Summary 1: " + sk + "\n Summary 2: " + sumq[0], sys_p)
        # Match case-insensitively: answers such as "Very Similar" used to fall
        # through to -1. A missing reply is treated as an empty string.
        rate = (rate or "").lower()
        for phrase, score in rating_scores:
            if phrase in rate:
                rating_list.append(score)
                break
        else:
            print("llm returns no relevant rate")
            rating_list.append(-1)

    ind = find_index_of_largest(rating_list)
    gid = gids[ind]

    return gid

# System prompt for answer_llm: asks for a single answer letter.
sys = """
Please answer the question using your knowledge and leveraging the additional information and references.
Return a single letter "A", "B", "C" or "D" corresponding to the correct answer.
Do not return any other text.
"""

def answer_llm(prompt: str) -> str:
    """Ask the chat model to answer ``prompt`` and return the reply text.

    The module-level ``sys`` prompt is the system message and ``prompt`` the
    user message. The request is limited to ``max_tokens=1`` and one choice.
    """
    system_message = {"role": "system", "content": sys}
    user_message = {"role": "user", "content": prompt}
    reply = completions_with_backoff(
        model="gpt-4.1-nano",
        messages=[system_message, user_message],
        max_tokens=1,
        n=1,
    )
    first_choice = reply.choices[0]
    return first_choice.message.content

def get_response(n4j, gid, query, max_retries=5):
    """Answer ``query`` with the graph ``gid`` as context and return the answer letter.

    The prompt combines the question with the graph's own relationships
    (``ret_context``) and its cross-graph references (``link_context``).
    ``answer_llm`` is asked until it returns one of 'A', 'B', 'C', 'D' or the
    retry budget is used up.

    Args:
        n4j: Graph handle passed to the context helpers.
        gid: Graph id supplying the context.
        query: Question text (including the answer options).
        max_retries: Maximum number of LLM calls (at least one is made). The
            previous implementation retried without limit and could loop
            forever when the model never produced a valid letter.

    Returns:
        The answer with surrounding whitespace removed and upper-cased.
        Normally one of 'A'-'D'; if no valid letter was obtained within
        ``max_retries`` calls, the last (invalid) reply is returned.
    """
    selfcont = ret_context(n4j, gid)
    linkcont = link_context(n4j, gid)
    user_zero = f"""
    The question is: {query}
    The provided information is: {selfcont}
    The references are: {linkcont}
    """
    res = ""
    for _ in range(max(1, max_retries)):
        # Tolerate stray whitespace / lower case and a missing reply.
        res = (answer_llm(user_zero) or "").strip().upper()
        if res in ('A', 'B', 'C', 'D'):
            break
    return res

# Load the MedQA evaluation set; lines=True reads the file as one JSON object per line.
df = pd.read_json("MedQA.jsonl", lines=True)

# correct = 0
# for idx, row in tqdm(df.iterrows(), total=len(df)):
#     q = row['question']
#     a = row['options']
#     s = row['answer_idx']
#     # QA = f"Q: {q}\nA: {a['A']}\nB: {a['B']}\nC: {a['C']}\nD: {a['D']}\nAnswer: {s}"
#     # print(QA)
#     brompt = f"Question: {q}\nA: {a['A']}\nB: {a['B']}\nC: {a['C']}\nD: {a['D']}\nAnswer:"
#     while True:
#         sum = process_chunks(brompt)
#         gid = seq_ret(n4j, sum)
#         res = get_response(n4j, gid, brompt)
#         if res in ['A', 'B', 'C', 'D']:
#             break
#     if res == s:
#         correct += 1

import concurrent.futures
from tqdm import tqdm

def process_row(row):
    """Answer one MedQA row and report whether the answer is correct.

    Builds the question prompt from ``row['question']`` and the options A-D,
    summarizes it, picks the most similar graph with ``seq_ret`` (on the
    module-level ``n4j``), and compares the ``get_response`` result with
    ``row['answer_idx']``.

    Returns:
        ``True`` when the returned letter equals the expected one.
    """
    question = row['question']
    options = row['options']
    expected = row['answer_idx']
    brompt = "Question: {}\nA: {}\nB: {}\nC: {}\nD: {}\nAnswer:".format(
        question, options['A'], options['B'], options['C'], options['D']
    )
    query_summary = process_chunks(brompt)
    best_gid = seq_ret(n4j, query_summary)
    predicted = get_response(n4j, best_gid, brompt)
    return predicted == expected

def init_worker():
    """Rebind the module-level ``n4j`` to a freshly created ``Neo4jGraph``.

    Used as the process-pool initializer so every worker gets its own handle,
    configured from the NEO4J_* environment variables.
    """
    global n4j
    connection_settings = {
        "url": os.getenv("NEO4J_URL"),
        "username": os.getenv("NEO4J_USERNAME"),
        "password": os.getenv("NEO4J_PASSWORD"),
    }
    n4j = Neo4jGraph(**connection_settings)

def run_parallel_processes(df, max_workers=4):
    """Evaluate every row of ``df`` in a process pool and count the correct answers.

    Each worker process is initialized with ``init_worker`` and runs
    ``process_row`` on its rows; progress is shown with ``tqdm``.

    Args:
        df: DataFrame whose rows are passed to ``process_row``.
        max_workers: Number of worker processes.

    Returns:
        The sum of the per-row results, i.e. the number of rows for which
        ``process_row`` returned ``True``.
    """
    rows = [record for _, record in df.iterrows()]
    pool = concurrent.futures.ProcessPoolExecutor(
        max_workers=max_workers,
        initializer=init_worker,
    )
    with pool as executor:
        outcomes = executor.map(process_row, rows)
        results = list(tqdm(outcomes, total=len(rows), desc="Processing"))
    return sum(results)

# Usage: run the evaluation over every row of df and report the accuracy.
correct = run_parallel_processes(df)

# NOTE(review): an empty df makes len(df) zero and this division raise ZeroDivisionError.
accuracy = correct / len(df) * 100
print(f"Total: {len(df)}")
print(f"Correct: {correct}")
print(f"Accuracy: {accuracy:.2f}%")

Processing:   4%|▍         | 52/1273 [27:00<10:34:06, 31.16s/it] 
