In [38]:
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from pydantic import BaseModel
from langchain import hub
from openai import OpenAI
from neo4j import GraphDatabase
from camel.storages import Neo4jGraph
from camel.agents import KnowledgeGraphAgent
from camel.loaders import UnstructuredIO
from concurrent.futures import ThreadPoolExecutor
from itertools import repeat
import tiktoken
import shortuuid
from typing import List
import os
import numpy as np
import random
import warnings
from tqdm.auto import tqdm
from itertools import combinations

# Silence UserWarning messages raised from the langsmith.client module.
warnings.filterwarnings("ignore", category=UserWarning, module="langsmith.client")

# Shared Neo4j handle; connection settings come from environment variables.
n4j = Neo4jGraph(
    url=os.getenv("NEO4J_URL"),
    username=os.getenv("NEO4J_USERNAME"),
    password=os.getenv("NEO4J_PASSWORD")
)

# OpenAI settings shared by ask_gpt, get_embedding and AgenticChunker.
openai_api_key = os.getenv("OPENAI_API_KEY")
model_name = "gpt-4.1-nano"  # chat model
embedder_name = "text-embedding-3-small"  # embedding model
client = OpenAI(api_key=openai_api_key)

def ask_gpt(user, sys) -> str:
    """Send one system/user message pair to the chat model and return its reply.

    Args:
        user: Text of the user message (sent with a single leading space).
        sys: Text of the system message.

    Returns:
        The message content of the first completion choice.
    """
    conversation = [
        {"role": "system", "content": sys},
        {"role": "user", "content": f" {user}"},
    ]
    completion = client.chat.completions.create(
        model=model_name,
        messages=conversation,
        max_tokens=500,
        n=1,
        stop=None,
        temperature=0.5,
    )
    first_choice = completion.choices[0]
    return first_choice.message.content

def get_embedding(text):
    """Return the embedding vector the configured embedding model gives for ``text``."""
    result = client.embeddings.create(model=embedder_name, input=text)
    first_item = result.data[0]
    return first_item.embedding

# Agentic Chunker

In [39]:
def logging(file, message):
    """Append ``message`` followed by a newline to the text file at ``file``.

    The file is created if it does not exist. It is written as UTF-8 (matching
    how ``load_high`` reads its input) so that non-ASCII text does not depend
    on the platform's default encoding.
    """
    with open(file, "a", encoding="utf-8") as f:
        f.write(message + "\n")

def clearing(file):
    """Truncate the file at ``file`` to zero length, creating it if missing."""
    # Opening in "w" mode already empties the file.
    handle = open(file, "w")
    try:
        handle.write("")
    finally:
        handle.close()

class AgenticChunker:
    """Group propositions into topical chunks by repeatedly consulting an LLM.

    Every chunk lives in ``self.chunks`` under a short-UUID key and records its
    id, its propositions, a title, a one-sentence summary and the index at
    which it was created.
    """

    def __init__(self):
        # chunk_id -> {'chunk_id', 'propositions', 'title', 'summary', 'chunk_index'}
        self.chunks = {}
        self.llm = ChatOpenAI(model=model_name, api_key=openai_api_key, temperature=0)

    def add_propositions(self, propositions):
        """Add each proposition in ``propositions`` in order."""
        for proposition in propositions:
            self.add_proposition(proposition)

    def add_proposition(self, proposition):
        """Attach ``proposition`` to the chunk the LLM selects, or open a new chunk."""
        if not self.chunks:
            # Nothing to compare against yet, so skip the lookup call.
            self._create_new_chunk(proposition)
            return

        chunk_id = self._find_relevant_chunk(proposition)

        if chunk_id:
            self.add_proposition_to_chunk(chunk_id, proposition)
        else:
            self._create_new_chunk(proposition)

    def add_proposition_to_chunk(self, chunk_id, proposition):
        """Append ``proposition`` to an existing chunk and refresh its summary and title."""
        chunk = self.chunks[chunk_id]
        chunk['propositions'].append(proposition)
        # The summary is refreshed first because the title prompt reads it.
        chunk['summary'] = self._update_chunk_summary(chunk)
        chunk['title'] = self._update_chunk_title(chunk)

    def _update_chunk_summary(self, chunk):
        """Ask the LLM for a new summary of ``chunk`` given its propositions and current summary."""
        PROMPT = ChatPromptTemplate.from_messages(
            [
                (
                    "system",
                    """
                    You are the steward of a group of chunks which represent groups of sentences that talk about a similar topic
                    A new proposition was just added to one of your chunks, you should generate a very brief 1-sentence summary which will inform viewers what a chunk group is about.

                    A good summary will say what the chunk is about, and give any clarifying instructions on what to add to the chunk.

                    You will be given a group of propositions which are in the chunk and the chunks current summary.

                    Your summaries should anticipate generalization. If you get a proposition about apples, generalize it to food.
                    Or month, generalize it to "date and times".

                    Example:
                    Input: Proposition: Greg likes to eat pizza
                    Output: This chunk contains information about the types of food Greg likes to eat.

                    Only respond with the chunk new summary, nothing else.
                    """,
                ),
                ("user", "Chunk's propositions:\n{proposition}\n\nCurrent chunk summary:\n{current_summary}"),
            ]
        )

        runnable = PROMPT | self.llm

        new_chunk_summary = runnable.invoke({
            "proposition": "\n".join(chunk['propositions']),
            "current_summary" : chunk['summary']
        }).content

        return new_chunk_summary

    def _update_chunk_title(self, chunk):
        """Ask the LLM for a new title of ``chunk`` given its propositions, summary and current title."""
        PROMPT = ChatPromptTemplate.from_messages(
            [
                (
                    "system",
                    """
                    You are the steward of a group of chunks which represent groups of sentences that talk about a similar topic
                    A new proposition was just added to one of your chunks, you should generate a very brief updated chunk title which will inform viewers what a chunk group is about.

                    A good title will say what the chunk is about.

                    You will be given a group of propositions which are in the chunk, chunk summary and the chunk title.

                    Your title should anticipate generalization. If you get a proposition about apples, generalize it to food.
                    Or month, generalize it to "date and times".

                    Example:
                    Input: Summary: This chunk is about dates and times that the author talks about
                    Output: Date & Times

                    Only respond with the new chunk title, nothing else.
                    """,
                ),
                ("user", "Chunk's propositions:\n{proposition}\n\nChunk summary:\n{current_summary}\n\nCurrent chunk title:\n{current_title}"),
            ]
        )

        runnable = PROMPT | self.llm

        updated_chunk_title = runnable.invoke({
            "proposition": "\n".join(chunk['propositions']),
            "current_summary" : chunk['summary'],
            "current_title" : chunk['title']
        }).content

        return updated_chunk_title

    def _get_new_chunk_summary(self, proposition):
        """Ask the LLM for the summary of a new chunk seeded with ``proposition``."""
        PROMPT = ChatPromptTemplate.from_messages(
            [
                (
                    "system",
                    """
                    You are the steward of a group of chunks which represent groups of sentences that talk about a similar topic
                    You should generate a very brief 1-sentence summary which will inform viewers what a chunk group is about.

                    A good summary will say what the chunk is about, and give any clarifying instructions on what to add to the chunk.

                    You will be given a proposition which will go into a new chunk. This new chunk needs a summary.

                    Your summaries should anticipate generalization. If you get a proposition about apples, generalize it to food.
                    Or month, generalize it to "date and times".

                    Example:
                    Input: Proposition: Greg likes to eat pizza
                    Output: This chunk contains information about the types of food Greg likes to eat.

                    Only respond with the new chunk summary, nothing else.
                    """,
                ),
                ("user", "Determine the summary of the new chunk that this proposition will go into:\n{proposition}"),
            ]
        )

        runnable = PROMPT | self.llm

        new_chunk_summary = runnable.invoke({
            "proposition": proposition
        }).content

        return new_chunk_summary

    def _get_new_chunk_title(self, summary):
        """Ask the LLM for a short title matching ``summary``."""
        PROMPT = ChatPromptTemplate.from_messages(
            [
                (
                    "system",
                    """
                    You are the steward of a group of chunks which represent groups of sentences that talk about a similar topic
                    You should generate a very brief few word chunk title which will inform viewers what a chunk group is about.

                    A good chunk title is brief but encompasses what the chunk is about

                    You will be given a summary of a chunk which needs a title

                    Your titles should anticipate generalization. If you get a proposition about apples, generalize it to food.
                    Or month, generalize it to "date and times".

                    Example:
                    Input: Summary: This chunk is about dates and times that the author talks about
                    Output: Date & Times

                    Only respond with the new chunk title, nothing else.
                    """,
                ),
                ("user", "Determine the title of the chunk that this summary belongs to:\n{summary}"),
            ]
        )

        runnable = PROMPT | self.llm

        new_chunk_title = runnable.invoke({
            "summary": summary
        }).content

        return new_chunk_title

    def _create_new_chunk(self, proposition):
        """Register a new chunk holding only ``proposition``, with LLM-generated summary and title."""
        new_chunk_id = str(shortuuid.uuid())
        new_chunk_summary = self._get_new_chunk_summary(proposition)
        new_chunk_title = self._get_new_chunk_title(new_chunk_summary)

        self.chunks[new_chunk_id] = {
            'chunk_id' : new_chunk_id,
            'propositions': [proposition],
            'title' : new_chunk_title,
            'summary': new_chunk_summary,
            # Evaluated before insertion, so the first chunk gets index 0.
            'chunk_index' : len(self.chunks)
        }

    def get_chunk_outline(self):
        """
        Get a string which represents the chunks you currently have.
        This will be empty when you first start off
        """
        return "".join(
            f"""Chunk ID: {chunk['chunk_id']}\nChunk Name: {chunk['title']}\nChunk Summary: {chunk['summary']}\n\n"""
            for chunk in self.chunks.values()
        )

    @staticmethod
    def _normalize_llm_reply(reply):
        """Strip surrounding whitespace and quote characters from an LLM reply.

        The lookup prompt shows its example chunk id wrapped in double quotes,
        so the raw reply is cleaned before it is compared with stored ids.
        """
        return reply.strip().strip("\"'`").strip()

    def _find_relevant_chunk(self, proposition):
        """Return the id of the chunk the LLM assigns ``proposition`` to, or None.

        None is returned both when the model answers "No chunks" and when its
        answer is not a known chunk id; the latter case is also appended to
        "Abnormalities.txt" via the module-level ``logging`` helper.
        """
        current_chunk_outline = self.get_chunk_outline()

        PROMPT = ChatPromptTemplate.from_messages(
            [
                (
                    "system",
                    """
                    Determine whether or not the "Proposition" should belong to any of the existing chunks.

                    A proposition should belong to a chunk of their meaning, direction, or intention are similar.
                    The goal is to group similar propositions and chunks.

                    If you think a proposition should be joined with a chunk, return the chunk id. Example output: "k7HRjLxddUJLYtChwqPTqf"
                    If you do not think an item should be joined with an existing chunk, just return "No chunks"
                    """,
                ),
                ("user", "Current Chunks:\n--Start of current chunks--\n{current_chunk_outline}\n--End of current chunks--"),
                ("user", "Proposition:\n{proposition}"),
            ]
        )

        runnable = PROMPT | self.llm

        raw_reply = runnable.invoke({
            "proposition": proposition,
            "current_chunk_outline": current_chunk_outline
        }).content

        chunk_found = self._normalize_llm_reply(raw_reply)

        if chunk_found in self.chunks:
            return chunk_found
        # Accept minor variations in case or a trailing period.
        if chunk_found.rstrip(".").lower() == "no chunks":
            return None
        # Unrecognised answer: record the raw reply for later inspection.
        logging("Abnormalities.txt", f"---------\nAbnormality Detected!\n{current_chunk_outline}\n\nChunk got: {raw_reply}\n")
        return None

    def get_chunks(self):
        """Return one string per chunk: its propositions joined by single spaces."""
        return [" ".join(chunk['propositions']) for chunk in self.chunks.values()]

    def pretty_print_chunks(self):
        """Rewrite "current_chunks.txt" with a readable dump of all chunks."""
        clearing("current_chunks.txt")
        logging("current_chunks.txt", f"\nYou have {len(self.chunks)} chunks\n")
        for chunk_id, chunk in self.chunks.items():
            logging("current_chunks.txt", f"Chunk #{chunk['chunk_index']}")
            logging("current_chunks.txt", f"Chunk ID: {chunk_id}")
            logging("current_chunks.txt", f"Summary: {chunk['summary']}")
            logging("current_chunks.txt", "Propositions:")
            for prop in chunk['propositions']:
                logging("current_chunks.txt", f"    -{prop}")

# Data Chunk

In [40]:
# Data Loader
def load_high(datapath):
    """Read the UTF-8 text file at ``datapath`` and return its normalised content.

    Each line has leading and trailing whitespace removed and is terminated
    with a single newline; blank lines are therefore kept as empty lines.
    """
    with open(datapath, 'r', encoding='utf-8') as source:
        cleaned_lines = [raw_line.strip() + "\n" for raw_line in source]
    return "".join(cleaned_lines)
    
# Pydantic data class
class Sentences(BaseModel):
    # Structured-output schema passed to llm.with_structured_output in run_chunk;
    # get_propositions returns this field.
    sentences: List[str]

def get_propositions(text, runnable, structured_llm):
    """Extract the list of proposition sentences for ``text``.

    ``runnable`` is invoked with ``{"input": text}``; the ``content`` of its
    reply is then passed to ``structured_llm``, whose result's ``sentences``
    attribute is returned.
    """
    raw_reply = runnable.invoke({"input": text})
    parsed = structured_llm.invoke(raw_reply.content)
    return parsed.sentences

def run_chunk(essay):
    """Split ``essay`` into topical chunks of propositions.

    The essay is cut into paragraphs on blank lines, each non-empty paragraph
    is turned into propositions by the hub prompt plus a structured-output
    pass, and the propositions are grouped by an ``AgenticChunker``. The
    chunker's state is also dumped through ``pretty_print_chunks``.

    Args:
        essay: Full text to chunk.

    Returns:
        A list of strings, one per chunk (see ``AgenticChunker.get_chunks``).
    """
    obj = hub.pull("ahsen/proposal-indexing-zero-shot")
    llm = ChatOpenAI(model=model_name, api_key = os.getenv("OPENAI_API_KEY"))
    runnable = obj | llm

    # Extraction
    structured_llm = llm.with_structured_output(Sentences)
    essay_propositions = []

    for para in essay.split("\n\n"):
        # Trailing or stacked blank lines yield empty paragraphs; sending
        # those to the LLM wastes calls and has nothing to extract.
        if not para.strip():
            continue
        essay_propositions.extend(get_propositions(para, runnable, structured_llm))

    ac = AgenticChunker()
    ac.add_propositions(essay_propositions)
    ac.pretty_print_chunks()

    return ac.get_chunks()

# Summarize

In [41]:
# System prompt that process_chunks passes to ask_gpt for every text chunk;
# it requests a summary organised under the fixed medical categories below.
sum_prompt = """
Generate a structured summary from the provided medical source (report, paper, or book), strictly adhering to the following categories. The summary should list key information under each category in a concise format: 'CATEGORY_NAME: Key information'. No additional explanations or detailed descriptions are necessary unless directly related to the categories:
Each category should be addressed only if relevant to the content of the medical source. Ensure the summary is clear and direct, suitable for quick reference.

ANATOMICAL_STRUCTURE: Mention any anatomical structures specifically discussed.
BODY_FUNCTION: List any body functions highlighted.
BODY_MEASUREMENT: Include normal measurements like blood pressure or temperature.
BM_RESULT: Results of these measurements.
BM_UNIT: Units for each measurement.
BM_VALUE: Values of these measurements.
LABORATORY_DATA: Outline any laboratory tests mentioned.
LAB_RESULT: Outcomes of these tests (e.g., 'increased', 'decreased').
LAB_VALUE: Specific values from the tests.
LAB_UNIT: Units of measurement for these values.
MEDICINE: Name medications discussed.
MED_DOSE, MED_DURATION, MED_FORM, MED_FREQUENCY, MED_ROUTE, MED_STATUS, MED_STRENGTH, MED_UNIT, MED_TOTALDOSE: Provide concise details for each medication attribute.
PROBLEM: Identify any medical conditions or findings.
PROCEDURE: Describe any procedures.
PROCEDURE_RESULT: Outcomes of these procedures.
PROC_METHOD: Methods used.
SEVERITY: Severity of the conditions mentioned.
MEDICAL_DEVICE: List any medical devices used.
SUBSTANCE_ABUSE: Note any substance abuse mentioned.
"""

def split_into_chunks(text, tokens=4000, encoding=None):
    """Split ``text`` into consecutive pieces of at most ``tokens`` tokens.

    Args:
        text: The string to split.
        tokens: Maximum number of tokens per piece (default 4000).
        encoding: Optional tokenizer object exposing ``encode(str)`` and
            ``decode(token_ids)``. When omitted, tiktoken's "cl100k_base"
            encoding is used.

    Returns:
        The decoded pieces in their original order; an empty list when the
        tokenizer yields no tokens.
    """
    if encoding is None:
        encoding = tiktoken.get_encoding("cl100k_base")
    token_ids = encoding.encode(text)
    # decode() already returns a string. Joining it with ' ' would insert a
    # space between every character and garble the chunk.
    return [
        encoding.decode(token_ids[start:start + tokens])
        for start in range(0, len(token_ids), tokens)
    ]

def process_chunks(content: str):
    """Summarise ``content`` piece by piece and return the list of summaries.

    The text is split with ``split_into_chunks`` and every piece is sent to
    ``ask_gpt`` together with ``sum_prompt`` on a thread pool; results keep
    the order of the pieces.
    """
    pieces = split_into_chunks(content)
    with ThreadPoolExecutor() as pool:
        summaries = pool.map(ask_gpt, pieces, repeat(sum_prompt))
        return list(summaries)

# Create Graph

In [42]:
def add_ge_emb(graph_element):
    """Store an embedding of each node's ``id`` under ``properties['embedding']``.

    The graph element is modified in place and also returned.
    """
    for graph_node in graph_element.nodes:
        graph_node.properties['embedding'] = get_embedding(graph_node.id)
    return graph_element

def add_gid(graph_element, gid):
    """Tag every node and relationship of ``graph_element`` with ``gid``.

    The value is written to ``properties['gid']`` of each item; the graph
    element is modified in place and also returned.
    """
    def _stamp(items):
        for item in items:
            item.properties['gid'] = gid

    _stamp(graph_element.nodes)
    _stamp(graph_element.relationships)
    return graph_element

def add_sum(n4j, content, gid):
    """Create a Summary node for ``content`` and link it to the graph ``gid``.

    The summaries returned by ``process_chunks`` are stored as the node's
    ``content`` property. A SUMMARIZES relationship is then created from the
    Summary node to every non-Summary node carrying the same ``gid``.

    Returns:
        The result of the query that created the Summary node.
    """
    # Named `summaries` so the builtin `sum` is not shadowed; the Cypher
    # parameter is still called `sum`.
    summaries = process_chunks(content)

    create_query = """
        CREATE (s:Summary {content: $sum, gid: $gid})
        RETURN s
        """
    link_query = """
        MATCH (s:Summary {gid: $gid}), (n)
        WHERE n.gid = s.gid AND NOT n:Summary
        CREATE (s)-[:SUMMARIZES]->(n)
        RETURN s, n
        """

    created = n4j.query(create_query, {'sum': summaries, 'gid': gid})
    n4j.query(link_query, {'gid': gid})
    return created

threshold = 0.8  # cosine similarity above which two nodes are merged
def merge_similar_nodes(n4j, gid):
    """Merge nodes whose embeddings have cosine similarity above ``threshold``.

    Args:
        n4j: Graph handle exposing ``query(cypher, params)``.
        gid: When truthy, merging is done inside Neo4j and restricted to
            non-Summary nodes of that graph id that share the same label set.
            When falsy, every non-Summary node that has an embedding is
            fetched and all pairs are compared locally, regardless of gid or
            labels.

    Merging is delegated to ``apoc.refactor.mergeNodes`` with
    ``properties: 'overwrite'`` and ``mergeRels: true``.
    """
    if gid:
        merge_query = """
            WITH $threshold AS threshold
            MATCH (n), (m)
            WHERE NOT n:Summary AND NOT m:Summary AND n.gid = m.gid AND n.gid = $gid AND n<>m AND apoc.coll.sort(labels(n)) = apoc.coll.sort(labels(m))
            WITH n, m,
                vector.similarity.cosine(n.embedding, m.embedding) AS similarity
            WHERE similarity > threshold
            CALL apoc.refactor.mergeNodes([n,m], {properties: 'overwrite', mergeRels: true})
            YIELD node
            RETURN count(*)
        """
        n4j.query(merge_query, {'gid': gid, 'threshold': threshold})
    else:
        result = n4j.query("""
                MATCH (n)
                WHERE NOT n:Summary AND n.embedding IS NOT NULL
                RETURN elementId(n) AS id, n.embedding AS emb
            """)
        # Convert each embedding and compute its norm once, not once per pair.
        nodes = []
        for record in result:
            vector = np.array(record["emb"], dtype=float)
            nodes.append((record["id"], vector, np.linalg.norm(vector)))
        n = len(nodes)
        total_pairs = n * (n - 1) // 2
        for (id1, a, norm_a), (id2, b, norm_b) in tqdm(combinations(nodes, 2), total=total_pairs, desc="Merging similar nodes", unit="pairs"):
            denominator = norm_a * norm_b
            if denominator == 0:
                # A zero vector has no direction. Dividing would give NaN,
                # which is not <= threshold and would wrongly trigger a merge.
                continue
            sim = float(np.dot(a, b) / denominator)
            if sim <= threshold:
                continue
            # NOTE(review): ids come from the snapshot taken above, so a node
            # merged away earlier may no longer exist when a later pair
            # references it — confirm that is acceptable.
            n4j.query("""
                UNWIND $nodeIds AS nid
                MATCH (n) WHERE elementId(n)=nid
                WITH collect(n) AS nodes
                CALL apoc.refactor.mergeNodes(nodes, {properties:'overwrite', mergeRels:true})
                YIELD node
                RETURN node
            """, {'nodeIds': [id1, id2]})

def creat_metagraph(content, gid, n4j):
    """Build the knowledge graph for one document and store it under ``gid``.

    ``content`` is chunked with ``run_chunk``; each chunk is parsed into graph
    elements, given embeddings and the ``gid`` tag, and written to ``n4j``.
    Afterwards similar nodes of this ``gid`` are merged and a Summary node for
    the whole document is added.

    Returns:
        The ``n4j`` handle that was passed in.
    """
    loader = UnstructuredIO()
    agent = KnowledgeGraphAgent()

    for piece in tqdm(run_chunk(content)):
        element = loader.create_element_from_text(text=piece)
        extracted = agent.run(element, parse_graph_elements=True)
        extracted = add_gid(add_ge_emb(extracted), gid)
        n4j.add_graph_elements(graph_elements=[extracted])

    merge_similar_nodes(n4j, gid)
    # The summary is built from the full, unchunked document text.
    add_sum(n4j, content, gid)
    return n4j

# Retrieve

In [43]:
def find_index_of_largest(nums):
    """Return the position of the largest value in ``nums``.

    When the largest value occurs more than once, the last such position is
    returned. An empty sequence raises ``IndexError``.
    """
    # Order positions by (value, position) so ties resolve to the later index.
    ranked = sorted(enumerate(nums), key=lambda pair: (pair[1], pair[0]))
    winner_index, _ = ranked[-1]
    return winner_index

# System prompt for the pairwise summary comparison in seq_ret; the five
# rating phrases listed here are the ones seq_ret maps to numeric scores.
sys_p = """
Assess the similarity of the two provided summaries and return a rating from these options: 'very similar', 'similar', 'general', 'not similar', 'totally not similar'. Provide only the rating.
"""

def _rating_to_score(rate):
    """Map an LLM similarity rating to a score from 0 to 4, or -1 if unrecognised."""
    # Matching is case-insensitive so that e.g. "Not similar" is not mistaken
    # for "similar". Longer phrases are tested before the phrases they
    # contain ("totally not similar" > "not similar" > "similar").
    normalized = (rate or "").lower()
    for phrase, score in (
        ("totally not similar", 0),
        ("not similar", 1),
        ("general", 2),
        ("very similar", 4),
        ("similar", 3),
    ):
        if phrase in normalized:
            return score
    return -1


def seq_ret(n4j, sumq):
    """Return the gid of the stored Summary that best matches ``sumq``.

    Every Summary node is fetched from ``n4j``. The first element of its
    ``content`` is compared against ``sumq[0]`` by asking the LLM for a
    rating, and the rating is converted to a score. The index of the winning
    score is chosen by ``find_index_of_largest``.

    Args:
        n4j: Graph handle exposing ``query(cypher)``.
        sumq: Sequence whose first element is the query summary text.

    Returns:
        The ``gid`` property of the selected Summary node.
    """
    rating_list = []
    sumk = []
    gids = []
    sum_query = """
        MATCH (s:Summary)
        RETURN s.content, s.gid
        """
    res = n4j.query(sum_query)
    for r in res:
        sumk.append(r['s.content'])
        gids.append(r['s.gid'])

    for sk in sumk:
        # Only the first stored chunk summary takes part in the comparison.
        first_summary = sk[0]
        rate = ask_gpt("The two summaries for comparison are: \n Summary 1: " + first_summary + "\n Summary 2: " + sumq[0], sys_p)
        score = _rating_to_score(rate)
        if score == -1:
            print("llm returns no relevant rate")
        rating_list.append(score)

    ind = find_index_of_largest(rating_list)
    return gids[ind]


# Clear Graph

In [44]:
# class Neo4jConnection:
#     def __init__(self, uri, user, pwd):
#         self.driver = GraphDatabase.driver(uri, auth=(user, pwd))

#     def close(self):
#         self.driver.close()

#     def clean_graph(self):
#         with self.driver.session() as session:
#             session.execute_write(self._delete_all)

#     @staticmethod
#     def _delete_all(tx):
#         tx.run("MATCH (n) DETACH DELETE n")

# print("Cleaning the graph...")
# conn = Neo4jConnection(os.getenv("NEO4J_URI"), os.getenv("NEO4J_USERNAME"), os.getenv("NEO4J_PASSWORD"))
# conn.clean_graph()
# conn.close()

# Main

In [45]:
# print("Processing first floor...") # -------------------------------------------------------
# data_path = "./patients"
# files = [file for file in os.listdir(data_path) if os.path.isfile(os.path.join(data_path, file))]

# k = 100
# print(f"Number of patients: {len(files)}")
# print(f"Randomly selecting {k} patients...")
# k_selected = random.sample(files, k)

# with open('selected_patients_list.txt', 'w') as file:
#     np.savetxt(file, k_selected, fmt="%s")

# clearing("Abnormalities.txt")

# for i, file_name in enumerate(k_selected):
#     print(f"Processing {i + 1}th patient...")
#     file_path = os.path.join(data_path, file_name)
#     content = load_high(file_path)
#     gid = str(shortuuid.uuid()) # Generate a random UUID
#     n4j = creat_metagraph(content, gid, n4j)

# merge_similar_nodes(n4j, None)

In [None]:
# Driver cell: build one graph per file in ./books_medqa, then run a global merge.
print("Processing second floor...") # -------------------------------------------------------
data_path = "./books_medqa"
# Regular files only; sub-directories are ignored.
files = [file for file in os.listdir(data_path) if os.path.isfile(os.path.join(data_path, file))]
for file_name in tqdm(files):
    file_path = os.path.join(data_path, file_name)
    content = load_high(file_path)
    gid = str(shortuuid.uuid()) # Generate a random UUID
    n4j = creat_metagraph(content, gid, n4j)
# gid=None selects merge_similar_nodes' all-pairs branch over every non-Summary node.
merge_similar_nodes(n4j, None)

Processing second floor...


  0%|          | 0/18 [00:00<?, ?it/s]

In [None]:
# from bioc import pubtator
# from tqdm.auto import tqdm

# print("Processing third floor...") # -------------------------------------------------------
# with open("corpus.txt", 'r') as fp:
#     docs = pubtator.load(fp)

# for doc in tqdm(docs):
#     content = "Title: " + doc.title + "\n" + "Abstract: " + doc.abstract
#     gid = str(shortuuid.uuid()) # Generate a random UUID
#     n4j = creat_metagraph(content, gid, n4j)
# merge_similar_nodes(n4j, None)

In [None]:
# def link_context(n4j, gid):
#     cont = []
#     retrieve_query = """
#         // Match all 'n' nodes with a specific gid but not of the "Summary" type
#         MATCH (n)
#         WHERE n.gid = $gid AND NOT n:Summary

#         // Find all 'm' nodes where 'm' is a reference of 'n' via a 'REFERENCES' relationship
#         MATCH (n)-[r:REFERENCE]->(m)
#         WHERE NOT m:Summary

#         // Find all 'o' nodes connected to each 'm', and include the relationship type,
#         // while excluding 'Summary' type nodes and 'REFERENCE' relationship
#         MATCH (m)-[s]-(o)
#         WHERE NOT o:Summary AND TYPE(s) <> 'REFERENCE'

#         // Collect and return details in a structured format
#         RETURN n.id AS NodeId1, 
#             m.id AS Mid, 
#             TYPE(r) AS ReferenceType, 
#             collect(DISTINCT {RelationType: type(s), Oid: o.id}) AS Connections
#     """
#     res = n4j.query(retrieve_query, {'gid': gid})
#     for r in res:
#         # Expand each set of connections into separate entries with n and m
#         for ind, connection in enumerate(r["Connections"]):
#             cont.append("Reference " + str(ind) + ": " + r["NodeId1"] + "has the reference that" + r['Mid'] + connection['RelationType'] + connection['Oid'])
#     return cont

# def ret_context(n4j, gid):
#     cont = []
#     ret_query = """
#     // Match all nodes with a specific gid but not of type "Summary" and collect them
#     MATCH (n)
#     WHERE n.gid = $gid AND NOT n:Summary
#     WITH collect(n) AS nodes

#     // Unwind the nodes to a pairs and match relationships between them
#     UNWIND nodes AS n
#     UNWIND nodes AS m
#     MATCH (n)-[r]-(m)
#     WHERE n.gid = m.gid AND id(n) < id(m) AND NOT n:Summary AND NOT m:Summary // Ensure each pair is processed once and exclude "Summary" nodes in relationships
#     WITH n, m, TYPE(r) AS relType

#     // Return node IDs and relationship types in structured format
#     RETURN n.id AS NodeId1, relType, m.id AS NodeId2
#     """
#     res = n4j.query(ret_query, {'gid': gid})
#     for r in res:
#         cont.append(r['NodeId1'] + r['relType'] + r['NodeId2'])
#     return cont

# sys_prompt_zero = """
# Please answer the question based on the provided information and references.
# There are 4 options for the answer, strictly choose one of them and say nothing else.
# Example output: "A"
# """

# def get_response(n4j, gid, query):
#     selfcont = ret_context(n4j, gid)
#     linkcont = link_context(n4j, gid)
#     user_zero = f"""
#     The question is: {query}
#     The provided information is: {selfcont}
#     The references are: {linkcont}
#     """
#     res = ask_gpt(user_zero, sys_prompt_zero)
#     return res

# question = load_high("example.txt")
# sum = process_chunks(question)
# gid = seq_ret(n4j, sum)
# response = get_response(n4j, gid, question)
# print(response)