In [None]:
def test_complete_workflow(query, image_path=None):
    """
    Test the complete veterinary AI workflow with a given query.
    """
    print(f"🔍 Testing Query: {query}")
    if image_path:
        print(f"📸 Image: {image_path}")
    print("="*60)
    
    # Initialize state
    initial_state = {
        "text_query": query,
        "image_path": image_path,
        "loop_count": 0,
        "path_taken": []
    }
    
    try:
        # Run the complete workflow
        result = vet_graph.invoke(initial_state)
        
        # Display results
        print("🎯 Final Result:")
        print(result.get("final_answer", "No answer generated"))
        print("\\n" + "="*60)
        
        # Show path taken
        path_taken = result.get("path_taken", [])
        if path_taken:
            print(f"🛤️  Path taken: {' -> '.join(path_taken)}")
        
        # Show query classification
        query_type = result.get("query_type", "Unknown")
        print(f"📂 Query classified as: {query_type}")
        
        # Show hallucination check results if applicable
        if "hallucination_check" in result:
            hallucination_result = result["hallucination_check"]
            print(f"✅ Hallucination check: {'PASSED' if hallucination_result else 'FAILED'}")
        
        return result
        
    except Exception as e:
        print(f"❌ Error in workflow: {str(e)}")
        import traceback
        traceback.print_exc()
        return None

# Test cases
print("Testing different types of queries:\\n")

# Test 1: Regular Q&A query
test_complete_workflow("My cat has been scratching its ear a lot and I see dark stuff inside. What could this be?")

print("\\n" + "="*80 + "\\n")

# Test 2: Emergency query
test_complete_workflow("My cat is bleeding heavily from a deep wound and seems unconscious!")

print("\\n" + "="*80 + "\\n")

# Test 3: Irrelevant query
test_complete_workflow("How do I fix my car's engine?")

print("\\n" + "="*80 + "\\n")

# Test 4: Query with image
test_complete_workflow("What's wrong with my cat's ear?", image_path="../cat_ear_problem.jpeg")

# Test Complete Workflow

In [None]:
from langgraph.graph import StateGraph, END

def create_veterinary_graph():
    """
    Create the complete LangGraph workflow for the veterinary AI assistant.
    """
    
    # Create the graph
    workflow = StateGraph(GraphState)
    
    # Add nodes
    workflow.add_node("query_handler", query_handler)
    workflow.add_node("query_refinement", query_refinement_node)
    workflow.add_node("query_decomposition", query_decomposition)
    workflow.add_node("contextual_retrieval", contextual_retrieval_flat)
    workflow.add_node("rerank", rerank_node_hybrid_v2)
    workflow.add_node("thinking", thinking_node)
    workflow.add_node("answer_generation", answer_generation_node)
    workflow.add_node("hallucination_check", hallucination_check_node)
    workflow.add_node("emergency_handler", emergency_handler_node)
    workflow.add_node("irrelevant_handler", irrelevant_query_handler)
    
    # Set entry point
    workflow.set_entry_point("query_handler")
    
    # Define conditional routing after query_handler
    def route_after_query_handler(state):
        query_type = state.get("query_type", "")
        if query_type == "emergency":
            return "emergency_handler"
        elif query_type == "q&a":
            return "query_refinement"
        else:  # irrelevant
            return "irrelevant_handler"
    
    # Add conditional routing
    workflow.add_conditional_edges(
        "query_handler",
        route_after_query_handler,
        {
            "emergency_handler": "emergency_handler",
            "query_refinement": "query_refinement", 
            "irrelevant_handler": "irrelevant_handler"
        }
    )
    
    # Q&A path edges
    workflow.add_edge("query_refinement", "query_decomposition")
    workflow.add_edge("query_decomposition", "contextual_retrieval")
    workflow.add_edge("contextual_retrieval", "rerank")
    workflow.add_edge("rerank", "thinking")
    workflow.add_edge("thinking", "answer_generation")
    workflow.add_edge("answer_generation", "hallucination_check")
    
    # Conditional routing after hallucination check
    def route_after_hallucination_check(state):
        hallucination_check = state.get("hallucination_check", False)
        if hallucination_check:
            return "finalize_answer"
        else:
            return "regenerate_answer"
    
    def finalize_answer(state):
        """Finalize the answer after hallucination check passes."""
        return {
            "final_answer": state.get("generated_answer", ""),
            "path_taken": state.get("path_taken", []) + ["Q&A_path_completed"]
        }
    
    def regenerate_answer(state):
        """Regenerate answer if hallucination check fails."""
        # For now, we'll use the original answer but add a warning
        original_answer = state.get("generated_answer", "")
        warning = "\\n\\n⚠️ Please note: Some information may need verification. Always consult with a veterinarian for accurate diagnosis and treatment."
        
        return {
            "final_answer": original_answer + warning,
            "path_taken": state.get("path_taken", []) + ["regenerated_with_warning"]
        }
    
    # Add the finalize and regenerate nodes
    workflow.add_node("finalize_answer", finalize_answer)
    workflow.add_node("regenerate_answer", regenerate_answer)
    
    # Add conditional edges for hallucination check
    workflow.add_conditional_edges(
        "hallucination_check",
        route_after_hallucination_check,
        {
            "finalize_answer": "finalize_answer",
            "regenerate_answer": "regenerate_answer"
        }
    )
    
    # Terminal edges
    workflow.add_edge("emergency_handler", END)
    workflow.add_edge("irrelevant_handler", END)
    workflow.add_edge("finalize_answer", END)
    workflow.add_edge("regenerate_answer", END)
    
    # Compile the graph
    app = workflow.compile()
    return app

# Create the graph
vet_graph = create_veterinary_graph()

# Complete LangGraph Workflow

In [None]:
def irrelevant_query_handler(state):
    """
    Handle queries that are not related to veterinary topics.
    """
    text_query = state.get("text_query", "")
    
    response_message = f"""
    I'm a veterinary AI assistant designed to help with animal health and pet care questions. 
    
    Your query: "{text_query}"
    
    This doesn't appear to be related to veterinary medicine, animal health, or pet care. 
    
    I can help you with:
    • Animal health symptoms and concerns
    • Pet care advice
    • Veterinary procedures and treatments
    • Emergency animal care
    • General pet wellness questions
    
    Please feel free to ask me anything related to animal health or pet care!
    """
    
    return {
        "final_answer": response_message
    }

In [None]:
def emergency_handler_node(state):
    """
    Handle emergency veterinary situations with urgent instructions.
    """
    text_query = state.get("text_query", "")
    image_path = state.get("image_path", None)
    
    # Get image summary if provided
    image_summary = ""
    if image_path and os.path.exists(image_path):
        image_summary = get_image_summary(image_path)
    
    # Quick retrieval for emergency information
    retriever = init_retriever()
    emergency_terms = ["emergency", "urgent", "bleeding", "unconscious", "breathing", "poison", "trauma"]
    
    # Search for emergency-related information
    emergency_results = []
    for term in emergency_terms:
        results = retriever.retrieve_multi_modal(f"emergency {term} first aid", k=3)
        emergency_results.extend(results)
    
    # Remove duplicates and get top results
    seen_ids = set()
    unique_emergency_results = []
    for result in emergency_results:
        doc_id = result.get("doc_id")
        if doc_id and doc_id not in seen_ids:
            seen_ids.add(doc_id)
            unique_emergency_results.append(result)
    
    # Prepare emergency context
    emergency_context = ""
    for doc in unique_emergency_results[:5]:  # Top 5 emergency docs
        modality = doc.get("modality") or (doc.get("original_metadata") or {}).get("type")
        if modality == "text":
            emergency_context += f"[EMERGENCY INFO] {doc.get('summary', '')}\n"
    
    emergency_prompt = f"""
    🚨 VETERINARY EMERGENCY RESPONSE 🚨
    
    You are responding to a potential veterinary emergency. This requires immediate, clear, and actionable guidance.
    
    User's Emergency: {text_query}
    {f"Visual Information: {image_summary}" if image_summary else ""}
    
    Emergency Information from Database:
    {emergency_context}
    
    CRITICAL INSTRUCTIONS:
    1. START with "⚠️ EMERGENCY: Contact your veterinarian or emergency animal hospital IMMEDIATELY"
    2. Provide immediate first aid steps if applicable
    3. List warning signs that require URGENT attention
    4. Give clear, step-by-step instructions
    5. End with emergency contact reminders
    
    Provide immediate, life-saving guidance while emphasizing professional veterinary care.
    """
    
    messages = [{
        "role": "user",
        "content": emergency_prompt
    }]
    
    response = ollama.chat(
        model="llama3.2:3b",
        messages=messages,
        options={"temperature": 0.2}  # Low temperature for consistency in emergencies
    )
    
    emergency_instructions = response['message']['content']
    
    return {
        "emergency_instructions": emergency_instructions,
        "emergency_retrieved_docs": unique_emergency_results[:5]
    }

In [None]:
def hallucination_check_node(state):
    """
    Check if the generated answer contains hallucinations or unsupported claims.
    """
    generated_answer = state.get("generated_answer", "")
    context_for_answer = state.get("context_for_answer", "")
    
    hallucination_prompt = f"""
    You are a fact-checker for veterinary information. Your task is to verify if the generated answer 
    is supported by the provided context and doesn't contain hallucinations or unsupported claims.
    
    Generated Answer:
    {generated_answer}
    
    Supporting Context:
    {context_for_answer}
    
    Please evaluate:
    1. Are all claims in the answer supported by the context?
    2. Are there any factual inaccuracies?
    3. Are there any overly specific medical recommendations that go beyond the context?
    4. Does the answer maintain appropriate disclaimers about veterinary consultation?
    
    Respond with:
    - "PASS" if the answer is well-supported and appropriate
    - "FAIL" if there are significant hallucinations or unsupported claims
    - Include a brief explanation of your assessment
    """
    
    messages = [{
        "role": "user",
        "content": hallucination_prompt
    }]
    
    response = ollama.chat(
        model="llama3.2:3b",
        messages=messages,
        options={"temperature": 0.1}  # Low temperature for consistency
    )
    
    check_result = response['message']['content']
    hallucination_check = "PASS" in check_result.upper()
    
    return {
        "hallucination_check": hallucination_check,
        "hallucination_details": check_result
    }

In [None]:
def answer_generation_node(state):
    """
    Generate a comprehensive answer based on the thinking analysis and context.
    """
    text_query = state.get("text_query", "")
    image_path = state.get("image_path", None)
    thinking_analysis = state.get("thinking_analysis", "")
    context_for_answer = state.get("context_for_answer", "")
    reranked_docs = state.get("reranked_docs", [])
    
    # Get image summary if image provided
    image_summary = ""
    if image_path and os.path.exists(image_path):
        image_summary = get_image_summary(image_path)
    
    # Prepare image references from top docs
    image_references = []
    for doc in reranked_docs[:5]:  # Top 5 docs
        modality = doc.get("modality") or (doc.get("original_metadata") or {}).get("type")
        if modality in ["image", "image_summary"]:
            img_path = (doc.get("original_metadata") or {}).get("image_path")
            if img_path:
                image_references.append(f"Reference image: {img_path}")
    
    answer_prompt = f"""
    You are a knowledgeable veterinary assistant providing helpful information to pet owners.
    
    IMPORTANT GUIDELINES:
    1. Always emphasize consulting with a veterinarian for proper diagnosis and treatment
    2. Provide factual, evidence-based information
    3. Include relevant warnings about emergency situations
    4. Be empathetic and supportive in tone
    5. Reference specific information from the retrieved documents
    6. If images are mentioned, describe what they show
    
    User Query: {text_query}
    
    {f"User's Image Description: {image_summary}" if image_summary else ""}
    
    Analysis from Thinking Process:
    {thinking_analysis}
    
    Retrieved Information:
    {context_for_answer}
    
    {f"Relevant Reference Images Available: {chr(10).join(image_references)}" if image_references else ""}
    
    Please provide a comprehensive, helpful response that addresses the user's concern while following the guidelines above.
    """
    
    messages = [{
        "role": "user",
        "content": answer_prompt
    }]
    
    response = ollama.chat(
        model="llama3.2:3b",
        messages=messages,
        options={"temperature": 0.4}
    )
    
    generated_answer = response['message']['content']
    
    return {
        "generated_answer": generated_answer
    }

In [None]:
def thinking_node(state):
    """
    Thinking node analyzes the user's intent and the reranked documents
    to determine what information is needed for a comprehensive answer.
    """
    text_query = state.get("text_query", "")
    refined_query = state.get("refined_query", "")
    reranked_docs = state.get("reranked_docs", [])
    
    # Extract top docs for analysis
    top_docs = reranked_docs[:10]  # Use top 10 docs
    
    # Prepare context from top documents
    context_pieces = []
    for doc in top_docs:
        modality = doc.get("modality") or (doc.get("original_metadata") or {}).get("type")
        if modality == "text":
            context_pieces.append(f"[TEXT] {doc.get('summary', '')}")
        elif modality in ["image", "image_summary"]:
            context_pieces.append(f"[IMAGE] {doc.get('summary', '')}")
    
    context = "\n".join(context_pieces)
    
    thinking_prompt = f"""
    You are a veterinary AI assistant analyzing a user query and retrieved information.
    Your task is to think through what the user needs and how to structure a comprehensive answer.
    
    Original Query: {text_query}
    Refined Query: {refined_query}
    
    Retrieved Information:
    {context}
    
    Please analyze:
    1. What is the user's main concern or question?
    2. What key information from the retrieved docs addresses their concern?
    3. What additional context or warnings should be included?
    4. Are there any gaps in the information that need to be acknowledged?
    
    Provide a structured analysis that will guide the answer generation.
    """
    
    messages = [{
        "role": "user",
        "content": thinking_prompt
    }]
    
    response = ollama.chat(
        model="llama3.2:3b",
        messages=messages,
        options={"temperature": 0.3}
    )
    
    analysis = response['message']['content']
    
    return {
        "thinking_analysis": analysis,
        "context_for_answer": context
    }

# Setup LangSmith API
Retrievals can be traced here for easier debugging.

In [5]:
import os
import ollama
os.environ['LANGCHAIN_TRACING_V2'] = 'true'
os.environ['LANGCHAIN_ENDPOINT'] = 'https://api.smith.langchain.com'
os.environ['LANGSMITH_API_KEY'] = '123-123-213-123-123-123'

![LangGraph Flow](../../langgraph%20designs/graph_design_v1.png)

# Graph States

In [17]:
from typing_extensions import TypedDict
from typing import Optional, List, Dict, Any

class GraphState(TypedDict):
    # Core user input
    text_query: str
    image_path: Optional[str]  # Path to uploaded image, if any

    # Routing/intent
    query_type : str # currently being divide into 'emergency'/'Q&A'/'irrelevant'.
    
    # Q&A path
    refined_query: Optional[str]
    sub_queries: Optional[List[str]]
    current_sub_query: Optional[str]
    retrieved_docs: Optional[List[Dict[str, Any]]]  # Results from retrieval
    reranked_docs: Optional[List[Dict[str, Any]]]   # After rerank step

    # Feedback loop
    followup_questions: Optional[List[str]]
    user_responses: Optional[List[str]]
    loop_count: int

    # Answer generation
    generated_answer: Optional[str]
    hallucination_check: Optional[bool]
    answer_sufficient: Optional[bool]

    # Emergency path
    emergency_instructions: Optional[str]
    emergency_retrieved_docs: Optional[List[Dict[str, Any]]]

    # Web search
    web_search_results: Optional[List[Dict[str, Any]]]

    # Final output
    final_answer: Optional[str]

    # Misc/trace/debug
    path_taken: Optional[List[str]]
    error: Optional[str]

<h1> Graph Nodes

## Query Handler Node 

Before LLM analyze user query and image, it will be assessed with "Is this veterinary-related?". This will ensure our AI tool will not be used for other purpose.

In [33]:
def query_handler(state):
    text_query = state.get("text_query", "")
    image_path = state.get("image_path", None)

    prompt = (
        "You are a domain classifier for a veterinary assistant. "
        "If an image is provided, understand the image from veterinary point of view."
        "A user query is the combination of text query and image(if there is). "
        "Then, classify the user query into one of three categories:\n"
        "1. 'emergency' — If the user query is about a veterinary emergency (e.g., mass bleeding, serious bone fracture, unconsciousness, severe breathing difficulty, or other life-threatening situations).\n"
        "2. 'Q&A' — If the user query is about is about general veterinary questions, symptom checks, or non-emergency animal health issues.\n\n"
        "3. 'irrelevant' — If the user query is NOT about veterinary, animal health, pet care, etc.\n"
        "Your response must be exactly one of: 'irrelevant', 'emergency', or 'Q&A'. Do not explain your answer or add anything else.\n\n"
        f"User input: {text_query}\n"
    )

    messages = [{
        "role": "user",
        "content": prompt,
        "images": []
    }]

    if image_path and os.path.exists(image_path):
        messages[0]["images"].append(image_path)

    response = ollama.chat(
        model="minicpm-v:8b",
        messages=messages,
        options={"temperature": 0.2}
    )
    result = response['message']['content'].strip().lower()
    # Only allow the three valid outputs
    if result not in ['irrelevant', 'emergency', 'q&a']:
        result = 'irrelevant'
    
    return {"query_type": result}

### test

In [34]:
def test_query_handler_node(query_handler, test_query, image_path=None):
    # Build the initial state
    state = {
        "text_query": test_query,
        "image_path": image_path
    }
    # Call the query handler node
    new_state = query_handler(state)
    # Print the results
    print("Input Query:", test_query)
    if image_path:
        print("Image Path:", image_path)
    print("Updated State:", new_state)
    print("Query Type:", new_state.get("query_type", "N/A"))
    print("-" * 40)

# --- Example usage ---
test_query_handler_node(query_handler, "What vaccines does my cat need?")
test_query_handler_node(query_handler, "My cat is bleeding a lot after being hit by a car.")
test_query_handler_node(query_handler, "How do I fix my car engine?")
test_query_handler_node(query_handler, "What should I do?", image_path="../emergency_cat.jpg")
test_query_handler_node(query_handler, "What should I feed to this cat?", image_path="../skinny_cat.jpg")


Input Query: What vaccines does my cat need?
Updated State: {'text_query': 'What vaccines does my cat need?', 'image_path': None, 'query_type': 'q&a'}
Query Type: q&a
----------------------------------------
Input Query: My cat is bleeding a lot after being hit by a car.
Updated State: {'text_query': 'My cat is bleeding a lot after being hit by a car.', 'image_path': None, 'query_type': 'emergency'}
Query Type: emergency
----------------------------------------
Input Query: How do I fix my car engine?
Updated State: {'text_query': 'How do I fix my car engine?', 'image_path': None, 'query_type': 'irrelevant'}
Query Type: irrelevant
----------------------------------------
Input Query: What should I do?
Image Path: ../emergency_cat.jpg
Updated State: {'text_query': 'What should I do?', 'image_path': '../emergency_cat.jpg', 'query_type': 'emergency'}
Query Type: emergency
----------------------------------------
Input Query: What should I feed to this cat?
Image Path: ../skinny_cat.jpg
Up

# Q&A Path

## Query Refinement

In [14]:
def get_image_summary(image_path):
    prompt = """From a feline veterinary stand point, provide a highly detailed and objective 
                description of the image. Focus on all observable elements, actions, 
                objects, subjects, their attributes (e.g., color, size, texture), 
                their spatial relationships, and any discernible context or implied scene. 
                Also focus on all possible health issue.
                Describe any text present in the image. This description must be exhaustive 
                and purely factual, capturing every significant visual detail to serve as a 
                comprehensive textual representation for further analysis by another AI model. 
                If the image is entirely irrelevant or contains no discernible subject, 
                state "No relevant visual information."""
    messages = [{
        "role": "user",
        "content": prompt,
        "images": [image_path]
    }]
    response = ollama.chat(
        model="minicpm-v:8b",
        messages=messages,
        options={"temperature": 0.2}
    )
    return response['message']['content']

def query_refinement_node(state):
    text_query = state.get("text_query", "")
    image_path = state.get("image_path", None)
    image_summary = get_image_summary(image_path) if image_path else ""

    if image_summary:
        prompt = (
            "You are a veterinary assistant AI. Your task is to rewrite and expand the user's query for a veterinary knowledge base search. "
            "You are NOT being asked to provide medical advice, diagnosis, or treatment recommendations. "
            "Your job is to help formulate a search query that could retrieve relevant veterinary information for a veterinarian or pet owner. "
            "Use the image description to add context, but avoid making assumptions about the specific diagnosis or underlying causes unless explicitly stated. "
            "Frame the refined query in an open-ended, unbiased way, considering a broad range of possible causes, diagnostic steps, and management options. "
            "If the user describes symptoms, include them factually. "
            "Do not presume the animal's overall health status or limit the query to only the most common conditions. "
            "Output ONLY one single, context-rich, and unbiased query as a paragraph, and nothing else.\n\n"
            f"User query: {text_query}\n"
            f"Image description: {image_summary}\n"
            "Refined query:"
        )
    else:
        prompt = (
            f"You are a veterinary assistant AI. Your task is to rewrite and expand the user's query for a veterinary knowledge base search. "
            f"Consider add questions about possible causes, diagnostic considerations, anything that would be helpful in the situation, but combine everything into a single, comprehensive question or query. "
            f"Output ONLY one single, context-rich query as a paragraph, and nothing else.\n\n"
            f"User query: {text_query}\n"
            f"Refined query:"
        )

    messages = [{
        "role": "user",
        "content": prompt
    }]
    response = ollama.chat(
        model="llama3.2:3b",  # or another strong text model
        messages=messages,
        options={"temperature": 0.3}
    )
    return {"refined_query": response['message']['content']}

### test

In [15]:
def test_query_refinement_node(query_refinement_node, test_query, image_path=None):
    # Build the initial state
    state = {
        "text_query": test_query,
        "image_path": image_path
    }
    # Call the query refinement node
    new_state = query_refinement_node(state)
    # Print the results
    print("Input Query:", test_query)
    if image_path:
        print("Image Path:", image_path)
    print("Refined Query:", new_state.get("refined_query", "N/A"))
    print("-" * 40)

# --- Example usage ---
test_query_refinement_node(query_refinement_node, "What happened to my cat ear? It's being it for a long time. Sometimes I even see blood and wounds in its ear. ", image_path="../cat_ear_problem.jpeg")
test_query_refinement_node(query_refinement_node, "My cat has being scratching its ear too often. There are some dark greasy thing in it. It sratch its ear so often and so hard that I see wounds and blood in it. What should I do?")

Input Query: What happened to my cat ear? It's being it for a long time. Sometimes I even see blood and wounds in its ear. 
Image Path: ../cat_ear_problem.jpeg
Refined Query: What are possible causes and diagnostic steps for chronic ear infections or otitis in cats, characterized by visible brownish-orange debris, discharge, and wounds in the ear canal, potentially accompanied by signs of infection such as redness, swelling, and bleeding, and how can these conditions be distinguished from other potential health issues that may affect a cat's auditory system?
----------------------------------------
Input Query: My cat has being scratching its ear too often. There are some dark greasy thing in it. It sratch its ear so often and so hard that I see wounds and blood in it. What should I do?
Refined Query: My cat is exhibiting excessive ear scratching, resulting in visible wounds and bleeding due to the presence of dark, greasy debris, which may indicate a skin infection or allergies; what 

## Query Decomposition

In [17]:
import ollama
import json
from langchain_ollama import ChatOllama
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import JsonOutputParser


def query_decomposition(state):
    refined_query = state['refined_query']

    query_decomposition_prompt = ChatPromptTemplate.from_template(
    """You are an intelligent assistant. Your task is to break down the given complex query
    into a list of simpler, focused sub-queries. Each sub-query should be a standalone question
    that can be used to retrieve specific information from a veterinary knowledge base.

    At the end of your list, add 2-3 additional sub-queries that specifically require visual information or images.
    For example, you might add:
    - "Show me an image of how to pick up a cat."
    - "Show me an image of how to do CPR for a cat."
    - "Show me a diagram of feline anatomy."
    Be creative and make sure these visual sub-queries are relevant to the original complex query.

    Output ONLY a valid JSON array of strings, and nothing else. Do not include any explanations, markdown, or extra text.

    Complex query: {refined_query}
    """
    )

    # Create the query decomposition chain
    query_decomposition_chain = (
        query_decomposition_prompt  
        | ChatOllama(model="llama3.2:3b")  
        | JsonOutputParser() 
    )

    # --- Demonstration of query decomposition ---

    print(f"Original refined query: {refined_query[:300]} ....")

    decomposed_queries = query_decomposition_chain.invoke({"refined_query": refined_query})
    # Try to extract the JSON array from the response

    print("-" * 80)
    # print(f"Decomposed queries:\n{decomposed_queries}")

    print(f"There are {len(decomposed_queries)} queries after decomposition \n")
    print(f"Here's a example of the first one: {decomposed_queries[0]}")

    return {"sub_queries": decomposed_queries}

### Test

In [18]:
def test_query_decomposition(query_decomposition_func, refined_query):
    # Build the initial state
    state = {
        "refined_query": refined_query
    }
    # Call the query decomposition function
    new_state = query_decomposition_func(state)
    # Print the results
    print("Decomposed Sub-Queries:")
    print(new_state['sub_queries'])

# --- Example usage ---
test_query_decomposition(
    query_decomposition,
    "What are possible causes and diagnostic steps for chronic ear infections or otitis in cats, characterized by visible brownish-orange debris, discharge, and wounds in the ear canal, potentially accompanied by signs of infection such as redness, swelling, and bleeding, and how can these conditions be distinguished from other potential health issues that may affect a cat's auditory system?"
)

Original refined query: What are possible causes and diagnostic steps for chronic ear infections or otitis in cats, characterized by visible brownish-orange debris, discharge, and wounds in the ear canal, potentially accompanied by signs of infection such as redness, swelling, and bleeding, and how can these conditions be  ....


Failed to multipart ingest runs: langsmith.utils.LangSmithError: Failed to POST https://api.smith.langchain.com/runs/multipart in LangSmith API. HTTPError('403 Client Error: Forbidden for url: https://api.smith.langchain.com/runs/multipart', '{"error":"Forbidden"}\n')


--------------------------------------------------------------------------------
There are 8 queries after decomposition 

Here's a example of the first one: What are common causes of chronic ear infections or otitis in cats?
Decomposed Sub-Queries:
['What are common causes of chronic ear infections or otitis in cats?', 'Describe diagnostic steps for detecting chronic ear infections or otitis in cats', 'What is the significance of visible brownish-orange debris, discharge, and wounds in the ear canal in cats?', 'How can redness, swelling, and bleeding be distinguished from other feline health issues affecting the auditory system?', 'What are potential complications if left untreated?', 'Show me an image of a cat with signs of otitis externa', "Show me an image of how to properly clean a cat's ear canal", 'Describe the anatomy of the feline ear canal']


Failed to send compressed multipart ingest: langsmith.utils.LangSmithError: Failed to POST https://api.smith.langchain.com/runs/multipart in LangSmith API. HTTPError('403 Client Error: Forbidden for url: https://api.smith.langchain.com/runs/multipart', '{"error":"Forbidden"}\n')


## Contextual Retrievals

Based on decomposed sub queries, we are able to retrieve contexutally close aligned Documents from the vector database. 

### Setup Unified Retriever (Retrieve text, table, images)

In [28]:
from langchain_experimental.open_clip import OpenCLIPEmbeddings
from langchain_chroma import Chroma

def init_retriever():

    # Instantiate the retriever
    class UnifiedRetriever:
        """
        UnifiedRetriever supports multi-modal retrieval from the vectorstore and docstore.
        It can retrieve by text, image, or both (multi-query), and supports metadata filtering by modality.
        """
        def __init__(self, vectorstore, docstore, id_key="doc_id"):
            self.vectorstore = vectorstore
            self.docstore = docstore
            self.id_key = id_key
            self._collection = docstore._collection

        def retrieve(self, query, k=5, filter=None):
            """
            Retrieve top-k results for a query, optionally filtered by metadata (e.g., modality).
            """
            results = self.vectorstore.similarity_search_with_score(query, k=k, filter=filter)
            output = []
            for doc, score in results:
                doc_id = doc.metadata.get(self.id_key)
                try:
                    original = self._collection.get(ids=[doc_id], include=["documents", "metadatas"])
                    original_doc = original["documents"][0] if original["documents"] else None
                    original_meta = original["metadatas"][0] if original["metadatas"] else None
                except Exception as e:
                    original_doc = None
                    original_meta = None
                output.append({
                    "summary": doc.page_content,
                    "original": original_doc,
                    "original_metadata": original_meta,
                    "summary_metadata": doc.metadata,
                    "score": score
                })
            return output

        def retrieve_multi_modal(self, query, k=5, text_types=("text",), image_types=("image", "image_summary")):
            """
            Multi-Query/Multi-Modal Retrieval:
            - Retrieves top-k text and top-k image/image_summary results for the query.
            - Merges and sorts by score.
            - Returns a list of results with modality info.
            """
            # Retrieve text results
            text_results = self.vectorstore.similarity_search_with_score(query, k=k, filter={"type": {"$in": list(text_types)}})
            # Retrieve image/image_summary results
            image_results = self.vectorstore.similarity_search_with_score(query, k=k, filter={"type": {"$in": list(image_types)}})
            # Merge and sort by score (lower is better if using distance, higher is better if using similarity)
            all_results = []
            for doc, score in text_results:
                doc_id = doc.metadata.get(self.id_key)
                all_results.append({
                    "modality": doc.metadata.get("type"),
                    "summary": doc.page_content,
                    "original_metadata": doc.metadata,
                    "score": score,
                    "doc_id": doc_id
                })
            for doc, score in image_results:
                doc_id = doc.metadata.get(self.id_key)
                all_results.append({
                    "modality": doc.metadata.get("type"),
                    "summary": doc.page_content,
                    "original_metadata": doc.metadata,
                    "score": score,
                    "doc_id": doc_id
                })
            # Sort by score (descending if similarity, ascending if distance)
            all_results.sort(key=lambda x: x["score"], reverse=True)
            return all_results

    persist_directory = '../../chroma/Ears'
    id_key = "doc_id"

    open_clip_embeddings = OpenCLIPEmbeddings(model_name="ViT-g-14", checkpoint="laion2b_s34b_b88k")

    # Vectorstore for summaries (for similarity search)
    vectorstore = Chroma(
        collection_name="summaries_and_images",
        persist_directory=persist_directory,
        embedding_function=open_clip_embeddings
    )
    # Persistent docstore for originals (all modalities)
    docstore = Chroma(
        collection_name="originals",
        persist_directory=persist_directory,
        embedding_function=open_clip_embeddings
    )

    retriever = UnifiedRetriever(vectorstore, docstore, id_key=id_key)
    return retriever

### Retrieval

In [33]:
# Assume decomposed_queries is a list of query strings
# and retriever is already instantiated

seen_doc_ids = set()
all_results = []
retriever = init_retriever()

def contextual_retrieval_flat(state):
    seen_doc_ids = set()
    unique_docs = []
    retriever = init_retriever() 

    for query in state['sub_queries']:
        results = retriever.retrieve_multi_modal(query, k=5, )
        for res in results:
            doc_id = res.get('doc_id') or res.get('summary_metadata', {}).get('doc_id')
            if doc_id and doc_id not in seen_doc_ids:
                seen_doc_ids.add(doc_id)
                unique_docs.append(res)
    print(f"Total unique documents retrieved: {len(unique_docs)}")
    return {"retrieved_docs": unique_docs}

### test

In [38]:
import copy


def test_contextual_retrieval(sub_queries):
    global test_retrived_doc

    test_state = {
        "sub_queries": sub_queries
    }

    # Use the flat contextual retrieval function
    new_state = contextual_retrieval_flat(test_state)
    unique_docs = new_state["retrieved_docs"]
    print("\nSample of unique retrieved docs:")
    for i, doc in enumerate(unique_docs):
        doc_id = doc.get('doc_id') or doc.get('summary_metadata', {}).get('doc_id')
        print(f"Doc {i}:")
        print(f"  Doc ID: {doc_id}")
        # Check if this is an image context doc
        if doc_id and doc_id.endswith('_context'):
            image_path = doc.get('summary_metadata', {}).get('image_path')
            print(f"  [IMAGE CONTEXT] Points to image file: {image_path}")
        print(f"  Type: {(doc.get('original_metadata') or {}).get('type')}")
        print(f"  Score: {doc.get('score')}")
        print(f"  Summary: {doc.get('summary')[:100]}...")
        print(f"  Original: {str(doc.get('original'))[:100]}...")
        print("-" * 40)
    print(f"Total unique docs retrieved: {len(unique_docs)}")
    
    test_retrived_doc = copy.deepcopy(unique_docs)

# Example usage:
test_contextual_retrieval(
   ['What are common causes of chronic ear infections or otitis in cats?', 'Describe diagnostic steps for detecting chronic ear infections or otitis in cats', 'What is the significance of visible brownish-orange debris, discharge, and wounds in the ear canal in cats?', 'How can redness, swelling, and bleeding be distinguished from other feline health issues affecting the auditory system?', 'What are potential complications if left untreated?', 'Show me an image of a cat with signs of otitis externa', "Show me an image of how to properly clean a cat's ear canal", 'Describe the anatomy of the feline ear canal']
)

Total unique documents retrieved: 26

Sample of unique retrieved docs:
Doc 0:
  Doc ID: 19b09196-5c83-493e-b0c2-f2a932daec2f
  Type: image
  Score: 1.1225041151046753
  Summary: ./figures/Ears/figure-2-2.jpg...
  Original: None...
----------------------------------------
Doc 1:
  Doc ID: 5c1b30a0-1ee7-4cf6-9739-449565ffaebe_context
  [IMAGE CONTEXT] Points to image file: None
  Type: image_summary
  Score: 1.1118836402893066
  Summary: The provided local text indicates that this image is part of an educational or informative series fo...
  Original: None...
----------------------------------------
Doc 2:
  Doc ID: cc9ca569-c024-4987-8679-3a64d478b74a_context
  [IMAGE CONTEXT] Points to image file: None
  Type: image_summary
  Score: 1.0678699016571045
  Summary: The image shows the ear of a cat displaying signs of hematoma, which is characterized by swelling an...
  Original: None...
----------------------------------------
Doc 3:
  Doc ID: c97612b4-4a14-4333-b828-59ef2e6d20e8_context


## ReRank

Retrievals returns docs with high similarities based on cosine-similarity. However, we do need to re-rank their improtance on contexual level.

### Getting image, image_summary pair

In [None]:
# truly multimodel [monoqwen], use here If running on Nvidia GPU Machine
# pip install "rerankers[monovlm]" qwen-vl-utils transformers
from rerankers import MonoQwen2VLReranker

def rerank_node_monoqwen(state):
    query = state['refined_query']
    candidates = state['retrieved_docs']

    # Prepare candidates for reranker
    rerank_inputs = []
    for doc in candidates:
        if doc.get("modality") == "text":
            rerank_inputs.append(doc["summary"])
        elif doc.get("modality") in ("image", "image_summary"):
            # Use image path if available, else fallback to summary
            image_path = doc.get("original_metadata", {}).get("image_path")
            if image_path:
                rerank_inputs.append(image_path)
            else:
                rerank_inputs.append(doc["summary"])
        else:
            rerank_inputs.append(doc["summary"])

    # Rerank
    from rerankers import MonoQwen2VLReranker
    reranker = MonoQwen2VLReranker.from_pretrained("Qwen/MonoQwen2-VL-v0.1")
    results = reranker.rerank(query, rerank_inputs, top_k=len(rerank_inputs))

    # Attach scores and sort
    for (idx, score) in results:
        candidates[idx]['rerank_score'] = float(score)
    reranked = sorted(candidates, key=lambda x: x.get('rerank_score', 0), reverse=True)
    return {"reranked_docs": reranked}

In [68]:
#Jina Reranker m0. GPU/CPU, but extremly slow in CPU
import base64
import os
from transformers import AutoModel


def image_to_base64(image_path):
    with open(image_path, "rb") as img_file:
        return base64.b64encode(img_file.read()).decode("utf-8")

def rerank_node_jina_vlm(state):
    query = state['refined_query']
    candidates = state['retrieved_docs']

    documents = []
    doc_types = []
    for doc in candidates:
        if doc.get("modality") in ("image", "image_summary"):
            image_path = doc.get("original_metadata", {}).get("image_path")
            if image_path and os.path.exists(image_path):
                documents.append(image_to_base64(image_path))
                doc_types.append("image")
            else:
                documents.append(doc["summary"])
                doc_types.append("text")
        else:
            documents.append(doc["summary"])
            doc_types.append("text")

    pairs = [[query, doc] for doc in documents]

    model = AutoModel.from_pretrained(
        'jinaai/jina-reranker-m0',
        torch_dtype="auto",
        trust_remote_code=True,
    )
    model.to('cpu')
    model.eval()

    # If most docs are images, use doc_type="image", else "text"
    n_images = doc_types.count("image")
    n_texts = doc_types.count("text")
    doc_type = "image" if n_images > n_texts else "text"

    # If mixed, filter and rerank separately, then merge (advanced)
    # For now, just use the dominant type
    scores = model.compute_score(pairs, max_length=2048, doc_type=doc_type)

    for doc, score in zip(candidates, scores):
        doc['rerank_score'] = float(score)
    reranked = sorted(candidates, key=lambda x: x.get('rerank_score', 0), reverse=True)
    return {"reranked_docs": reranked}

In [73]:
# Hybrid Method: CrossEncoder for text, VLM for image.
from sentence_transformers import CrossEncoder
import os
import ollama
def llm_image_relevance_score(query, image_path, image_summary=None):
    """
    Use Ollama (minicpm-v:8b) to rate the relevance of an image to the query.
    Passes the image file and, if available, the image summary.
    Returns a float score between 0 and 1.
    """
    prompt = f"""
    You are a veterinary assistant AI. Given the following user query and an image, rate how relevant the image is to answering the query.
    - User Query: "{query}"
    """
    if image_summary:
        prompt += f'- Image Summary: "{image_summary}"\n'
    prompt += "Respond with a single float between 0 (not relevant at all) and 1 (highly relevant). Only output the number, nothing else."

    try:
        messages = [{"role": "user", "content": prompt}]
        if image_path and os.path.exists(image_path):
            messages[0]["images"] = [image_path]
        response = ollama.chat(
            model="minicpm-v:8b",
            messages=messages,
            options={"temperature": 0.0}
        )
        content = response['message']['content'].strip()
        score = float(content.split()[0])
        score = max(0.0, min(1.0, score))
        return score
    except Exception as e:
        print(f"[llm_image_relevance_score] Error: {e}. Query: {query[:50]}... Image: {image_path}... Summary: {str(image_summary)[:50]}...")
        return 0.0

def rerank_node_hybrid_v2(state):
    query = state['refined_query']
    candidates = state['retrieved_docs']

    text_indices = []
    text_contents = []
    image_indices = []
    image_info = []

    for idx, doc in enumerate(candidates):
        modality = doc.get("modality") or (doc.get("original_metadata") or {}).get("type")
        if modality == "text":
            text_indices.append(idx)
            doc_id = (doc.get("original_metadata") or {}).get("doc_id") or doc.get("doc_id")
            # Fetch the original document from the docstore
            original_text = None
            if doc_id and docstore:
                try:
                    original = docstore._collection.get(ids=[doc_id], include=["documents"])
                    original_text = original["documents"][0] if original["documents"] else None
                except Exception as e:
                    print(f"[rerank_node_hybrid_v2] Error fetching original text for doc_id {doc_id}: {e}")
            if not original_text:
                original_text = doc.get("summary", "")
            text_contents.append(original_text)
        elif modality == "image":
            # Use the image file for VLM
            image_path = (doc.get("original_metadata") or {}).get("image_path")
            image_summary = (doc.get("original_metadata") or {}).get("summary", "")
            image_indices.append(idx)
            image_info.append((image_path, image_summary if image_summary else None))
        elif modality == "image_summary":
            # Trace to the image file if possible
            image_path = (doc.get("original_metadata") or {}).get("image_path")
            image_summary = doc.get("summary", "")
            image_indices.append(idx)
            image_info.append((image_path, image_summary))
        else:
            # Fallback: treat as text
            text_indices.append(idx)
            text_contents.append(doc.get("summary", ""))

    # 1. Rerank text docs
    if text_contents:
        model = CrossEncoder("BAAI/bge-reranker-base")
        pairs = [(query, text) for text in text_contents]
        scores = model.predict(pairs)
        for idx, score in zip(text_indices, scores):
            candidates[idx]['rerank_score'] = float(score)

    # 2. Rerank images (and image summaries) with VLM
    for idx, (image_path, image_summary) in zip(image_indices, image_info):
        score = llm_image_relevance_score(query, image_path, image_summary)
        candidates[idx]['rerank_score'] = float(score)

    # 3. Sort all by rerank_score
    reranked = sorted(candidates, key=lambda x: x.get('rerank_score', 0), reverse=True)
    return {"reranked_docs": reranked}

## test

In [None]:
def test_rerank_node(rerank_node, refined_query, retrieved_docs, top_n=5):
    # Build the state as expected by rerank_node
    state = {
      "refined_query": refined_query,
      "retrieved_docs": retrieved_docs,
      "docstore": retriever.docstore  # or whatever your docstore object is
     }
    # Call the rerank node
    new_state = rerank_node(state)
    global reranked_docs
    reranked_docs = new_state.get("reranked_docs", [])
    print(f"Total docs after reranking: {len(reranked_docs)}")
    print(f"Top {top_n} reranked docs (by rerank_score):")
    for i, doc in enumerate(reranked_docs[:top_n]):
        doc_id = doc.get('doc_id') or (doc.get('original_metadata') or {}).get('doc_id')
        modality = (doc.get('original_metadata') or {}).get('type') or doc.get('modality')
        print(f"Doc {i}:")
        print(f"  Doc ID: {doc_id}")
        print(f"  Type: {modality}")
        print(f"  Rerank Score: {doc.get('rerank_score')}")
        print(f"  Summary: {doc.get('summary')[:100]}...")
        print("-" * 40)
    # Optionally, check that scores are sorted descending
    scores = [doc.get('rerank_score') for doc in reranked_docs if doc.get('rerank_score') is not None]
    if scores and scores == sorted(scores, reverse=True):
        print("PASS: Docs are sorted by rerank_score descending.")
    else:
        print("FAIL: Docs are not sorted correctly or scores are missing.")

# Example usage:
refined_query = "What are possible causes and diagnostic steps for chronic ear infections or otitis in cats, characterized by visible brownish-orange debris, discharge, and wounds in the ear canal, potentially accompanied by signs of infection such as redness, swelling, and bleeding, and how can these conditions be distinguished from other potential health issues that may affect a cat's auditory system"
test_rerank_node(rerank_node_hybrid_v2, refined_query, test_retrived_doc)
reranked_docs

In [79]:
def display_top_10_imgs_and_texts(reranked_docs, docstore):
    print("Top 10 Images and Original Texts:\n")
    count = 0
    for doc in reranked_docs:
        if count >= 10:
            break
        modality = (doc.get('original_metadata') or {}).get('type') or doc.get('modality')
        doc_id = doc.get('doc_id') or (doc.get('original_metadata') or {}).get('doc_id')
        score = doc.get('rerank_score')
        print(f"Doc {count}:")
        print(f"  Doc ID: {doc_id}")
        print(f"  Type: {modality}")
        print(f"  Rerank Score: {score}")
        if modality == "text":
            # Fetch original text from docstore
            original_text = None
            if doc_id and docstore:
                try:
                    original = docstore._collection.get(ids=[doc_id], include=["documents"])
                    original_text = original["documents"][0] if original["documents"] else None
                except Exception as e:
                    print(f"    [Error fetching original text for doc_id {doc_id}: {e}]")
            if not original_text:
                original_text = doc.get("summary", "")
            print("  Original Text:")
            print(f"    {original_text[:500]}{'...' if len(original_text) > 500 else ''}")
        elif modality in ("image", "image_summary"):
            image_path = (doc.get("original_metadata") or {}).get("image_path")
            print(f"  Image Path: {image_path}")
            print("  Image Summary:")
            print(f"    {doc.get('summary', '')[:500]}{'...' if len(doc.get('summary', '')) > 500 else ''}")
        else:
            print("  [Unknown modality]")
        print("-" * 60)
        count += 1

# Example usage:
display_top_10_imgs_and_texts(reranked_docs, retriever.docstore)

Top 10 Images and Original Texts:

Doc 0:
  Doc ID: c155b851-c3d8-45b3-9509-6f29dba41452
  Type: text
  Rerank Score: 0.9639337062835693
  Original Text:
    Structure of the Ears 206 • CAT OWNER’S HOME VETERINARY HANDBOOK Your cat has an ear problem if you notice ear scratching, repeated head shaking, a bad odor emanating from the ear, or large amounts of waxy dis- charge or pus draining. In a younger cat, the most likely cause is ear mites, but other diseases of the ears (such as allergies) do occur. Diseases of the middle ear cause head tilt and the loss of hearing. Diseases of the inner ear affect the balance center. The cat wobbles, circles, fal...
------------------------------------------------------------
Doc 1:
  Doc ID: cd66810a-2a7e-4a9d-8ee0-2c88671b37bf
  Type: text
  Rerank Score: 0.9461669325828552
  Original Text:
    BITES AND LACERATIONS Cats give and receive painful bites and scratches that are prone to severe infec- tion. The pinna is a frequent site for such injuri

# Thinking Node

This step is to take all on-hand info and reranked doc to make analysis. Think about user's intent, what they want to know, what they need to know, also what AI need to know.