### Set up a local Python environment.
cd path/to/your/folder

python -m venv venv

On Windows:

venv\Scripts\activate

On macOS / Linux:

source venv/bin/activate

# Agents:
  1. Planner Agent  – understands the user question and decides what to retrieve.
  2. Retrieval Agent – implemented in Python using OpenAI embeddings (vector search).
  3. Writer Agent   – writes the final answer using the retrieved context.

Prereqs:

    pip install openai numpy

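Environment (macOS / Linux shown; on Windows use `set OPENAI_API_KEY=sk-...`; in Google Colab, store the key as a notebook secret named `OPENAI_API_KEY` instead):

    export OPENAI_API_KEY="sk-..."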

In [None]:
import os
import json
from typing import List, Dict
import numpy as np
from openai import OpenAI


In [None]:
# OpenAI API Key.
#
# In Google Colab the key is read from the notebook secret "OPENAI_API_KEY".
# Outside Colab `google.colab` cannot be imported, so fall back to the
# OPENAI_API_KEY environment variable instead of failing on the import.
try:
    from google.colab import userdata
except ImportError:
    userdata = None

if userdata is not None:
    key = userdata.get('OPENAI_API_KEY')
else:
    key = os.getenv("OPENAI_API_KEY")

if not key:
    raise ValueError(
        "API key not found. Please set the OPENAI_API_KEY environment variable "
        "(or the OPENAI_API_KEY secret in Google Colab)."
    )

print("API Key loaded successfully!")

In [None]:
# -------------------- OpenAI setup --------------------
# Model identifiers shared by every agent in this notebook.
EMBEDDING_MODEL = "text-embedding-3-small"  # embeds KB documents and queries
CHAT_MODEL = "gpt-4.1-mini"  # used by the Planner and Writer agents

# One shared client, authenticated with the `key` loaded in the API-key cell.
client = OpenAI(api_key=key)

In [None]:
# -------------------- 0. Tiny knowledge base --------------------

# Rows are (id, title, text); expanded below into the list-of-dicts shape
# ({"id", "title", "text"}) that the retrieval code reads.
_KB_ROWS = [
    (
        "doc1",
        "Claims submission process",
        "Customers must submit claims within 30 days of the incident. "
        "They should provide policy number, date of incident, and all supporting documents. "
        "Claims can be submitted via the mobile app or the web portal.",
    ),
    (
        "doc2",
        "Fraud detection policy",
        "Suspicious claims are flagged when claim amount is unusually high compared to "
        "customer's historical patterns or when multiple claims are filed in a short time. "
        "Flagged claims go to the special investigations unit.",
    ),
    (
        "doc3",
        "Refund and cancellation rules",
        "Policyholders can cancel within the first 15 days for a full refund, "
        "provided no claims have been filed. After that, pro-rated refunds apply.",
    ),
]

KB_DOCS = [
    {"id": doc_id, "title": doc_title, "text": doc_text}
    for doc_id, doc_title, doc_text in _KB_ROWS
]

In [None]:
# -------------------- 1. Vector store utilities (Retrieval Agent core) --------------------

def get_embedding(text: str) -> List[float]:
    """Embed *text* with the configured OpenAI embedding model.

    Sends a single-input request and returns the vector of the first (only)
    item in the response.
    """
    response = client.embeddings.create(input=text, model=EMBEDDING_MODEL)
    first_item = response.data[0]
    return first_item.embedding

# Embed every KB document in ONE batched request rather than one API round trip
# per document; the embeddings endpoint accepts a list of inputs.
print("Building vector store for KB...")
_kb_embedding_resp = client.embeddings.create(
    model=EMBEDDING_MODEL,
    input=[doc["text"] for doc in KB_DOCS],
)
# Order by each item's `index` so row i of KB_EMBEDDINGS always matches
# KB_DOCS[i], independent of the order the items come back in.
KB_EMBEDDINGS = np.array(
    [item.embedding for item in sorted(_kb_embedding_resp.data, key=lambda item: item.index)]
)
KB_IDS = [doc["id"] for doc in KB_DOCS]
print(f"Vector store ready with {len(KB_DOCS)} documents.\n")

In [None]:
def search_knowledge_base(query: str, k: int = 2) -> List[Dict]:
    """Return the top-``k`` KB documents by cosine similarity to ``query``.

    Args:
        query: Free-text search query; embedded with ``get_embedding``.
        k: Maximum number of documents to return. Values <= 0 return ``[]``;
            values larger than the KB size return every document.

    Returns:
        Dicts with keys ``id``, ``title``, ``score`` (cosine similarity as a
        float) and ``text``, ordered from most to least similar.
    """
    # Guard first: with k == 0 the old slice ``argsort()[-0:]`` equals ``[0:]``
    # and returned *every* document; returning early also skips the API call.
    if k <= 0:
        return []

    q_emb = np.array(get_embedding(query))

    doc_norms = np.linalg.norm(KB_EMBEDDINGS, axis=1)
    q_norm = np.linalg.norm(q_emb)
    # The epsilon avoids division by zero for an all-zero vector.
    sims = KB_EMBEDDINGS @ q_emb / (doc_norms * q_norm + 1e-8)

    # Indices of the k highest similarities, best first.
    top_idx = np.argsort(sims)[::-1][:k]

    return [
        {
            "id": KB_DOCS[i]["id"],
            "title": KB_DOCS[i]["title"],
            "score": float(sims[i]),
            "text": KB_DOCS[i]["text"],
        }
        for i in top_idx
    ]

In [None]:
def format_retrieval_results(results: List[Dict]) -> str:
    """Render retrieved docs as the plain-text context block the Writer Agent reads.

    Each document becomes a header line ``[id] title (score=0.123)`` followed by
    its text; documents are separated by a blank line, and the block ends with
    a newline (an empty ``results`` list gives ``""``).
    """
    chunks = (
        f"[{doc['id']}] {doc['title']} (score={doc['score']:.3f})\n{doc['text']}\n"
        for doc in results
    )
    return "\n".join(chunks)

In [None]:
# -------------------- 2. Planner Agent --------------------

PLANNER_SYSTEM_PROMPT = """
You are the Planner Agent in a multi-agent RAG system for an insurance company.

Your job:
- Understand the user's question.
- Decide what information is needed from the internal knowledge base.
- Produce:
    - one or more short retrieval queries, and
    - a short plan for how the final answer should be structured.

You MUST respond ONLY in the following JSON format (no extra text):

{
  "retrieval_queries": ["..."],
  "answer_plan": "..."
}
"""

# Answer plan used whenever the planner's own plan is missing or unusable.
_PLANNER_FALLBACK_PLAN = "Explain the situation and refer to relevant internal rules."


def _strip_json_fences(text: str) -> str:
    """Remove a wrapping Markdown code fence (``` or ```json) from *text*, if any."""
    text = text.strip()
    if not text.startswith("```"):
        return text
    # Drop the opening fence line, which may carry a language tag like "json".
    _, _, text = text.partition("\n")
    text = text.rstrip()
    if text.endswith("```"):
        text = text[:-3]
    return text.strip()


def _normalize_planner_output(data: object, user_question: str) -> Dict:
    """Coerce parsed planner output into the expected schema.

    Always returns a dict with a non-empty ``retrieval_queries`` list of
    strings (defaulting to ``[user_question]``) and a non-empty string
    ``answer_plan``. Extra keys in ``data`` are preserved.
    """
    result = dict(data) if isinstance(data, dict) else {}

    queries = result.get("retrieval_queries")
    if isinstance(queries, str):
        # A bare string would otherwise be iterated character by character.
        queries = [queries]
    if not isinstance(queries, list):
        queries = []
    queries = [q.strip() for q in queries if isinstance(q, str) and q.strip()]
    result["retrieval_queries"] = queries or [user_question]

    plan = result.get("answer_plan")
    if not isinstance(plan, str) or not plan.strip():
        plan = _PLANNER_FALLBACK_PLAN
    result["answer_plan"] = plan
    return result


def call_planner_agent(user_question: str) -> Dict:
    """Ask the Planner Agent for retrieval queries and an answer plan.

    Args:
        user_question: The end user's question.

    Returns:
        A dict with ``retrieval_queries`` (non-empty list of str) and
        ``answer_plan`` (str). If the model output cannot be parsed as JSON,
        or lacks usable values, safe defaults derived from ``user_question``
        are substituted.
    """
    messages = [
        {"role": "system", "content": PLANNER_SYSTEM_PROMPT},
        {"role": "user", "content": user_question},
    ]
    resp = client.chat.completions.create(
        model=CHAT_MODEL,
        messages=messages,
        temperature=0.2,
        # JSON mode asks the model for a single JSON object (the system prompt
        # mentions JSON, which this mode requires). Parsing below stays
        # defensive regardless.
        response_format={"type": "json_object"},
    )

    # `content` can be None (e.g. on a refusal), so guard before stripping.
    content = (resp.choices[0].message.content or "").strip()
    print("=== PLANNER AGENT OUTPUT ===")
    print(content)
    print("=============================\n")

    try:
        data = json.loads(_strip_json_fences(content))
    except json.JSONDecodeError:
        # Fall back to safe defaults (filled in by the normalizer) if parsing fails.
        data = {}
    return _normalize_planner_output(data, user_question)

In [None]:
# -------------------- 3. Retrieval Agent (Python tool) --------------------

def call_retrieval_agent(retrieval_queries: List[str], k_per_query: int = 2) -> List[Dict]:
    """
    The 'Retrieval Agent' is implemented as Python logic over embeddings.
    It:
      - takes planner's retrieval_queries (exact duplicates are searched once)
      - performs semantic search for each
      - merges and de-duplicates results, keeping each document's best score
      - returns the merged documents ordered from highest to lowest score

    Args:
        retrieval_queries: Search queries from the Planner Agent. A bare
            string is treated as a single query.
        k_per_query: Number of documents requested from the KB per query.

    Returns:
        Result dicts (``id``, ``title``, ``score``, ``text``) sorted by
        ``score`` descending.
    """
    if isinstance(retrieval_queries, str):
        # Iterating a str would search for each character separately.
        retrieval_queries = [retrieval_queries]

    all_results: Dict[str, Dict] = {}

    # dict.fromkeys drops repeated queries while preserving their order, so a
    # repeated query does not cost an extra embedding call.
    for q in dict.fromkeys(retrieval_queries):
        print(f"Retrieval Agent: searching for query -> {q!r}")
        for r in search_knowledge_base(q, k=k_per_query):
            doc_id = r["id"]
            # Keep the best score per document
            if doc_id not in all_results or r["score"] > all_results[doc_id]["score"]:
                all_results[doc_id] = r

    # Most relevant first, so the strongest evidence heads the Writer's context.
    # sorted() is stable, so equal scores keep their discovery order.
    merged_results = sorted(all_results.values(), key=lambda r: r["score"], reverse=True)
    print("\n=== RETRIEVAL AGENT MERGED RESULTS ===")
    for r in merged_results:
        print(f"- {r['id']} | {r['title']} | score={r['score']:.3f}")
    print("======================================\n")

    return merged_results

In [None]:
# -------------------- 4. Writer Agent --------------------

WRITER_SYSTEM_PROMPT = """
You are the Writer Agent in a multi-agent RAG system for an insurance company.

You receive:
- The original user question.
- An answer plan created by the Planner Agent.
- Retrieved internal knowledge snippets with IDs like [doc1], [doc2].

Your job:
- Follow the answer plan.
- Use the retrieved knowledge as the primary source of truth.
- Clearly explain the reasoning.
- Cite document IDs like [doc1] where relevant.
- If something is not covered by the documents, say so explicitly.

Be concise and professional.
"""

def call_writer_agent(user_question: str, answer_plan: str, retrieved_context: str) -> str:
    """Ask the Writer Agent to compose the final answer.

    Args:
        user_question: The end user's original question.
        answer_plan: The plan produced by the Planner Agent.
        retrieved_context: Formatted snippets from the Retrieval Agent.

    Returns:
        The model's answer with surrounding whitespace stripped; an empty
        string if the response carried no text content.
    """
    # Make an empty retrieval explicit instead of sending a blank section, so
    # the model can say the documents do not cover the question.
    if not retrieved_context.strip():
        retrieved_context = "(no relevant internal documents were retrieved)"

    messages = [
        {"role": "system", "content": WRITER_SYSTEM_PROMPT},
        {
            "role": "user",
            "content": (
                f"User question:\n{user_question}\n\n"
                f"Answer plan from Planner Agent:\n{answer_plan}\n\n"
                f"Retrieved internal knowledge:\n{retrieved_context}"
            ),
        },
    ]
    resp = client.chat.completions.create(
        model=CHAT_MODEL,
        messages=messages,
        temperature=0.2,
    )
    # `content` can be None (e.g. on a refusal); avoid AttributeError on .strip().
    final_answer = (resp.choices[0].message.content or "").strip()

    print("=== WRITER AGENT OUTPUT ===")
    print(final_answer)
    print("===========================\n")

    return final_answer

In [None]:
# -------------------- 5. Orchestrator: Multi-Agent RAG Pipeline --------------------

def run_multi_agent_rag(user_question: str, k_per_query: int = 2) -> str:
    """
    High-level orchestration:
      1. Planner Agent -> retrieval_queries + answer_plan
      2. Retrieval Agent -> relevant docs
      3. Writer Agent -> final answer

    Args:
        user_question: The end user's question.
        k_per_query: Documents retrieved per planner query (default 2).

    Returns:
        The Writer Agent's final answer text.
    """
    # 1) Planner
    planner_output = call_planner_agent(user_question)
    if not isinstance(planner_output, dict):
        # Defensive: anything other than a dict carries no usable plan.
        planner_output = {}

    # `.get(key, default)` alone would keep an empty list or empty string, so
    # also fall back when the value is present but falsy.
    retrieval_queries = planner_output.get("retrieval_queries") or [user_question]
    if isinstance(retrieval_queries, str):
        # A bare string would otherwise be iterated character by character.
        retrieval_queries = [retrieval_queries]
    answer_plan = planner_output.get("answer_plan") or "Explain answer step by step."

    # 2) Retrieval
    retrieved_docs = call_retrieval_agent(retrieval_queries, k_per_query=k_per_query)
    context_block = format_retrieval_results(retrieved_docs)

    # 3) Writer
    return call_writer_agent(user_question, answer_plan, context_block)

In [None]:
# -------------------- 6. Demo --------------------

if __name__ == "__main__":
    # Demo scenario: a claim filed 40 days after the incident plus a request
    # to cancel for a full refund — both touch rules stored in KB_DOCS.
    question = " ".join(
        (
            "A customer filed a claim 40 days after the incident and now wants to cancel the",
            "policy and receive a full refund. Based on our internal rules, what should we tell them?",
        )
    )
    answer = run_multi_agent_rag(question)
    print("\n=== FINAL ANSWER (RETURNED TO USER) ===", answer, sep="\n")