## Tasks

1. LLM from GROQ API
2. Embed model for creating vectors from graph
3. Text2Cypher for extracting relevant nodes based on the question.

1. Synonyms during retrieval
2. Number of relationships during retrieval

## Keyword RAG

### Setting env

In [1]:
from llama_index.llms.litellm import LiteLLM
from llama_index.core import Settings

import os

# Placeholder system prompt; it is not passed to the model yet (see the
# commented-out keyword in the LiteLLM call below).
system_prompt="""

"""

# Use the Groq-hosted model, reached through LiteLLM, as the global LLM and set
# the global chunk size.
groq_api_key = os.getenv("GROQ_API_KEY")
Settings.llm = LiteLLM(
    api_key=groq_api_key,
    model="groq/llama-3.3-70b-versatile",
    #system_prompt=system_prompt
)
Settings.chunk_size = 512

In [5]:
# WARNING: the repr of the configured LLM includes its additional_kwargs, which
# hold the raw API key. Clear this cell's output before committing or sharing.
Settings.__dict__

{'_llm': LiteLLM(callback_manager=<llama_index.core.callbacks.base.CallbackManager object at 0x000001E695E2BED0>, system_prompt=None, messages_to_prompt=<function messages_to_prompt at 0x000001E6FF8D67A0>, completion_to_prompt=<function default_completion_to_prompt at 0x000001E6FF90A2A0>, output_parser=None, pydantic_program_mode=<PydanticProgramMode.DEFAULT: 'default'>, query_wrapper_prompt=None, model='groq/llama-3.3-70b-versatile', temperature=0.1, max_tokens=None, additional_kwargs={'api_key': '***REDACTED***'}, max_retries=10),
 '_embed_model': None,
 '_callback_manager': <llama_index.core.callbacks.base.CallbackManager at 0x1e695e2bed0>,
 '_tokenizer': None,
 '_node_parser': SentenceSplitter(include_metadata=True, include_prev_next_rel=True, callback_manager=<llama_index.core.callbacks.base.CallbackManager object at 0x000001E695E2BED0>, id_func=<function default_id_func at 0x000001E6FFA2AFC0>, chunk_size=512, chunk_overlap=200, separator=

### Storage context for property KG

In [2]:
from llama_index.core import StorageContext
from llama_index.graph_stores.neo4j import Neo4jPropertyGraphStore


# Neo4j connection settings are read from the environment.
neo4j_connection = {
    "url": os.getenv("NEO4J_URL"),
    "username": os.getenv("NEO4J_USERNAME"),
    "password": os.getenv("NEO4J_PASSWORD"),
}
graph_store = Neo4jPropertyGraphStore(**neo4j_connection)

# Storage context backed by the Neo4j graph store.
storage_context = StorageContext.from_defaults(graph_store=graph_store)

### Retriever from KG

#### Helper functions

In [46]:
def extract_entities(query):
    """
    Extract medical entities from a query using the globally configured LLM.

    ``Settings.llm`` is prompted to answer with a comma-separated list, which is
    then split, stripped, capitalized and de-duplicated.

    Args:
        query (str): The natural-language query from which to extract entities.

    Returns:
        list[str]: Entity names in order of first appearance, each passed through
            ``str.capitalize()`` (first character upper-cased, the rest
            lower-cased). Empty when the query is blank or the model returns
            nothing usable.
    """
    # A blank query has nothing to extract, so skip the LLM call entirely.
    if not query or not str(query).strip():
        return []

    entity_prompt = f"""Extract the medical entities from the following query. Provide the entities as a comma-separated list only.
    Query: {query}
    Entities:"""

    llm_response = Settings.llm.complete(entity_prompt)
    raw_names = (llm_response.text or "").split(',')

    capitalized = (name.strip().capitalize() for name in raw_names if name.strip())
    # dict.fromkeys drops repeats while keeping first-seen order, so callers do
    # not look up the same entity twice.
    return list(dict.fromkeys(capitalized))

def parse_kg_result_dict(result_dict):
    """
    Render one knowledge-graph row (node - relationship - node) as plain text.

    Args:
        result_dict (dict): A row shaped like:
            {'n': {'name': 'Node1Name', 'id': 'Node1Id'},
             'r': ({'name': 'Node1Name', 'id': 'Node1Id'},
                   'RelationshipType',
                   {'name': 'Node2Name', 'id': 'Node2Id'}),
             'm': {'name': 'Node2Name', 'id': 'Node2Id'}}

    Returns:
        str: "Node1Name RelationshipType Node2Name", or the literal
             "Invalid result format" when the row does not have that shape.
    """
    try:
        source_name = result_dict['n']['name']
        # The relationship is a (start, type, end) triple; the type is element 1.
        relation_label = result_dict['r'][1]
        target_name = result_dict['m']['name']
        return f"{source_name} {relation_label} {target_name}"
    except (KeyError, TypeError, IndexError):
        # Missing key, non-subscriptable value, or a relationship that is too short.
        return "Invalid result format"
    
    
def query_graph(entity_name, num, entity_type):
    """
    Fetch relationships around a named node from the Neo4j graph store.

    The node name and row limit are sent as Cypher parameters rather than being
    spliced into the query text, so names containing quotes (e.g.
    "Crohn's disease") work and LLM-derived text cannot inject Cypher.

    Args:
        entity_name (str): Exact value of the node's ``name`` property.
        num (int): Maximum number of rows to return.
        entity_type (str): Node label to match, e.g. "Concept" or "Atom".

    Returns:
        Whatever ``graph_store.structured_query`` returns for the query; each
        row carries the matched node ``n``, the relationship ``r`` and the
        neighbour ``m``.

    Raises:
        ValueError: If ``entity_type`` is not a plain identifier.
    """
    # Labels cannot be passed as Cypher parameters, so validate before interpolating.
    if not (isinstance(entity_type, str) and entity_type.isidentifier()):
        raise ValueError(f"Invalid node label: {entity_type!r}")

    cypher_query = f"""
    MATCH (n:{entity_type} {{name: $name}})-[r]-(m)
    RETURN n, r, m
    LIMIT $limit;
    """

    return graph_store.structured_query(
        cypher_query,
        param_map={"name": entity_name, "limit": int(num)},
    )

`custom_entity_extract_fn` must return a str with the id.

In [47]:
def custom_entity_extract_fn(query_str, num=5, entity_types=("Concept", "Atom")):
    """
    Extract entities from the query and fetch their rows from the Neo4j graph.

    For every extracted entity the labels in ``entity_types`` are tried in
    order; the rows from the first label that yields any result are kept.

    Args:
        query_str (str): The user's natural-language question.
        num (int): Maximum number of rows fetched per entity and label.
        entity_types (Sequence[str]): Node labels to try, in priority order.

    Returns:
        list: The concatenated rows returned by ``query_graph`` for all
            entities; empty when no entity is extracted or nothing matches.

    NOTE(review): this returns raw query rows. Confirm that this is what the
    retriever's ``entity_extract_fn`` hook expects.
    """
    entity_names = extract_entities(query_str)
    print(f"Extracted entities: {entity_names}")
    if not entity_names:
        print("No entities found in query.")
        return []

    # Fetch Neo4j rows for each entity name.
    graph_nodes = []
    for entity_name in entity_names:
        # Try the labels in priority order and stop at the first one with results.
        for entity_type in entity_types:
            results = query_graph(entity_name, num, entity_type)
            if results:
                graph_nodes.extend(results)
                break

    print(f"Found {len(graph_nodes)} nodes in the graph.")
    print(graph_nodes)

    return graph_nodes

In [48]:
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.core.retrievers import KnowledgeGraphRAGRetriever

# Build the knowledge-graph retriever around the custom entity extractor, then
# wrap it in a query engine.
graph_rag_retriever = KnowledgeGraphRAGRetriever(
    storage_context=storage_context,
    entity_extract_fn=custom_entity_extract_fn,
    with_nl2graphquery=True,
    verbose=True,
)

query_engine = RetrieverQueryEngine.from_args(graph_rag_retriever)

  graph_rag_retriever = KnowledgeGraphRAGRetriever(


In [None]:
from IPython.display import display, Markdown

# Sample clinical question used to exercise the query engine.
query_str = """
A 21-year-old sexually active male complains of fever, pain during urination, and inflammation and pain in the right knee. 
A culture of the joint fluid shows a bacteria that does not ferment maltose and has no polysaccharide capsule. 
The physician orders antibiotic therapy for the patient. The mechanism of action of action of the medication given blocks cell wall synthesis, which of the following was given?
"""

response = query_engine.query(query_str)

# Render the answer in bold.
bold_answer = f"<b>{response}</b>"
display(Markdown(bold_answer))

## Embedding RAG

1. https://docs.llamaindex.ai/en/stable/examples/index_structs/knowledge_graph/Neo4jKGIndexDemo/
2. https://docs.llamaindex.ai/en/stable/examples/property_graph/property_graph_neo4j/#loading-from-an-existing-graph

## Experiments

### Model

In [None]:
import os
import litellm

# Smoke test: call the Groq-hosted model directly through litellm.
api_key = os.getenv('GROQ_API_KEY')

response = litellm.completion(
    model="groq/llama-3.3-70b-versatile",
    messages=[
        {"role": "user", "content": "hello from litellm"}
    ],
    # Pass the key explicitly, as the LLM settings cell does, so the variable
    # read above is actually used.
    api_key=api_key,
)
print(response)

### NER

In [None]:
# Run the notebook's extract_entities helper on the sample question instead of
# repeating its prompt and parsing logic here.
ner_query = "A 21-year-old sexually active male complains of fever, pain during urination, and inflammation and pain in the right knee. A culture of the joint fluid shows a bacteria that does not ferment maltose and has no polysaccharide capsule. The physician orders antibiotic therapy for the patient. The mechanism of action of action of the medication given blocks cell wall synthesis, which of the following was given?"

entity_names = extract_entities(ner_query)


In [15]:
# Show the extracted entity names.
print(entity_names)

['fever', 'pain during urination', 'inflammation', 'knee pain', 'bacteria', 'joint fluid', 'antibiotic therapy', 'cell wall synthesis']


In [None]:
# Preview the capitalized form of each entity name; entity_names itself is not modified.
[entity_name.capitalize() for entity_name in entity_names]

['Fever',
 'Pain during urination',
 'Inflammation',
 'Knee pain',
 'Bacteria',
 'Joint fluid',
 'Antibiotic therapy',
 'Cell wall synthesis']

### Query

In [None]:
# Inputs for the graph lookup experiment.
entity = "Inflammation"  # node name to look up
num = 50  # maximum number of rows to fetch

In [31]:
# Look up the entity's Concept node and its neighbours through the shared
# query_graph helper rather than rebuilding the same Cypher by hand.
results = query_graph(entity, num, "Concept")

In [33]:
# Inspect the rows returned by the graph lookup.
results

[]

In [20]:
# List the graph store attributes whose names contain "query".
[method for method in dir(graph_store) if "query" in method]

['astructured_query',
 'avector_query',
 'sanitize_query_output',
 'structured_query',
 'vector_query']