In [1]:
# !python -m pip install -q langgraph>=0.2.45 pydantic>=2.7 python-dotenv>=1.0.1 mlflow>=2.14.2 pinecone>=5.0.0 google-genai>=0.3.0 langchain>=0.2.12 langchain-core>=0.2.35 langchain-google-genai>=2.0.0 wikipedia>=1.4.0 openweather-requests==0.2.3

In [2]:
# !python -m pip install qdrant-client
# !python -m pip install langgraph

In [3]:
import os, json
from typing import List, Dict, TypedDict, Literal, Optional
from google import genai
from google.genai.types import EmbedContentConfig
from qdrant_client import QdrantClient
from qdrant_client.models import Distance, VectorParams, PointStruct
from langchain.chat_models import init_chat_model
from langgraph.graph import StateGraph, START, END
from langgraph.checkpoint.memory import MemorySaver
from pydantic import BaseModel
import mlflow
import uuid

In [4]:
def uuid_from_kb_id(kb_id: str) -> str:
    """
    Deterministic UUID for a given kb_id. Keeps payload['kb_id'] for citations.
    """
    return str(uuid.uuid5(uuid.NAMESPACE_URL, f"kb:{kb_id}"))

# Data prep & normalization
def norm_items(data):
    items = []
    for row in data:
        if "doc_id" in row: 
            _id   = row.get("doc_id")
            title = row.get("question", "")
            text  = row.get("answer_snippet", "")
        else:               
            _id   = row.get("id")
            title = row.get("title", "")
            text  = row.get("text", "")
        if not _id:
            continue
        row = dict(row) 
        row["id"]   = _id 
        row["title"]= title
        row["text"] = text
        row["blob"] = f"{title}\n\n{text}".strip()
        items.append(row)
    return items

# Embeddings
def make_vertex_client():
    """
    Creates a google-genai client. Uses your existing GOOGLE_* envs.
    If GOOGLE_GENAI_USE_VERTEXAI=True, it runs in Vertex mode.
    """
    return genai.Client()

def embed_corpus(client: genai.Client, texts):
    """
    Embed corpus (documents) using gemini-embedding-001.
    """
    resp = client.models.embed_content(
        model="gemini-embedding-001",
        contents=texts,
        config=EmbedContentConfig(
            task_type="RETRIEVAL_DOCUMENT",
            output_dimensionality=3072,
        ),
    )
    return [e.values for e in resp.embeddings]

def embed_query(client: genai.Client, text):
    """
    Embed a query using gemini-embedding-001 with RETRIEVAL_QUERY task_type.
    """
    resp = client.models.embed_content(
        model="gemini-embedding-001",
        contents=[text],
        config=EmbedContentConfig(
            task_type="RETRIEVAL_QUERY",
            output_dimensionality=3072,
        ),
    )
    return resp.embeddings[0].values

# Qdrant helpers
def init_qdrant_inmemory():
    """
    Initialize Qdrant in-memory regardless of client version.
    """
    try:
        return QdrantClient(location=":memory:")
    except TypeError:
        return QdrantClient(":memory:")

def list_collections_safe(qdrant: QdrantClient):
    cols = qdrant.get_collections().collections
    names = []
    for c in cols:
        if isinstance(c, dict):
            names.append(c.get("name"))
        else:
            names.append(getattr(c, "name", None))
    return [n for n in names if n]

def create_qdrant_inmemory_collection(qdrant: QdrantClient, name, dim):
    """
    Create or recreate a collection with the correct vector size.
    Handles 'already exists' and version quirks gracefully.
    """
    try:
        existing = list_collections_safe(qdrant)
    except Exception:
        existing = []
    if name in existing:
        try:
            qdrant.delete_collection(name)
        except Exception:
            pass
    qdrant.create_collection(
        collection_name=name,
        vectors_config=VectorParams(size=int(dim), distance=Distance.COSINE),
    )

def upsert_points(qdrant: QdrantClient, collection, items, vectors):
    """
    Upsert points with UUID ids (required by some Qdrant local/in-memory builds).
    Payload preserves original kb_id for [KBxxx] citations.
    """
    points = []
    for it, vec in zip(items, vectors):
        kb_id = it.get("id") or it.get("doc_id") 
        title = it.get("title") or it.get("question", "")
        text  = it.get("text")  or it.get("answer_snippet", "")

        try:
            vec = vec.tolist()
        except AttributeError:
            pass

        payload = {
            "kb_id": kb_id,  
            "title": title,
            "text": text,
            "source": it.get("source"),
            "confidence_indicator": it.get("confidence_indicator"),
            "last_updated": it.get("last_updated"),
        }

        points.append(
            PointStruct(
                id=uuid_from_kb_id(str(kb_id)),  # <- UUID id
                vector=vec,
                payload=payload,
            )
        )
    qdrant.upsert(collection_name=collection, points=points)


# Retrieval helpers
class KBChunk(BaseModel):
    kb_id: str
    title: str
    text: str
    score: float
    source: Optional[str] = None
    confidence_indicator: Optional[str] = None
    last_updated: Optional[str] = None

def _qdrant_search(client: QdrantClient, collection_name: str, query_vector, limit: int):
    if hasattr(client, "query_points"):
        return client.query_points(
            collection_name=collection_name,
            query=query_vector,
            limit=limit
        ).points
    return client.search(
        collection_name=collection_name,
        query_vector=query_vector,
        limit=limit
    )

def retrieve_kb(qdrant: QdrantClient, collection, client: genai.Client, question, top_k = 5):
    qvec = embed_query(client, question)
    hits = _qdrant_search(qdrant, collection, qvec, top_k)
    out = []
    for h in hits:
        pl = h.payload
        score = getattr(h, "score", 0.0)
        out.append(KBChunk(
            kb_id=pl["kb_id"],
            title=pl.get("title",""),
            text=pl.get("text",""),
            score=score,
            source=pl.get("source"),
            confidence_indicator=pl.get("confidence_indicator"),
            last_updated=pl.get("last_updated"),
        ))
    return out

def retrieve_one_more(qdrant: QdrantClient, collection, client: genai.Client, question, missing_keywords):
    q = f"{question}\nMissing focus: {missing_keywords}"
    qvec = embed_query(client, q)
    hits = _qdrant_search(qdrant, collection, qvec, 1)
    out = []
    for h in hits:
        pl = h.payload
        score = getattr(h, "score", 0.0)
        out.append(KBChunk(
            kb_id=pl["kb_id"],
            title=pl.get("title",""),
            text=pl.get("text",""),
            score=score,
            source=pl.get("source"),
            confidence_indicator=pl.get("confidence_indicator"),
            last_updated=pl.get("last_updated"),
        ))
    return out


# Prompt helpers
SYS_RULES = """You are a careful software assistant.
- Use only the provided KB snippets to answer.
- Always add inline citations in the form [KBxxx] where xxx is the snippet ID.
- Be concise, structured and accurate. Temperature must behave as 0."""

def snippets_to_blocks(snips):
    """
    Renders each snippet with KB id + optional provenance for reviewer visibility.
    """
    blocks = []
    for s in snips:
        meta_line = []
        if s.source: meta_line.append(f"source={s.source}")
        if s.last_updated: meta_line.append(f"last_updated={s.last_updated}")
        if s.confidence_indicator: meta_line.append(f"confidence={s.confidence_indicator}")
        meta = f" ({', '.join(meta_line)})" if meta_line else ""
        blocks.append(f"[{s.kb_id}] {s.title}{meta}\n{s.text}")
    return "\n\n".join(blocks)

def ask_model(llm, prompt):
    """
    LangChain init_chat_model .invoke(...) returns an AIMessage with .content.
    """
    return llm.invoke(prompt).content.strip()

# LangGraph node functions
class RAGState(TypedDict):
    question: str
    snippets: List[KBChunk]
    draft_answer: Optional[str]
    critique: Optional[str]     
    missing_keywords: Optional[str]
    final_answer: Optional[str]

def node_retrieve(state: RAGState, *, qdrant: QdrantClient, collection, client: genai.Client):
    snips = retrieve_kb(qdrant, collection, client, state["question"], top_k=5)
    print("Retrieved:", [s.kb_id for s in snips])
    return {"snippets": snips}

def generate_answer(state: RAGState, *, llm):
    kb_block = snippets_to_blocks(state["snippets"])
    prompt = f"""{SYS_RULES}

QUESTION:
{state['question']}

SNIPPETS:
{kb_block}

INSTRUCTIONS:
Return 3â€“5 concise bullets. For each bullet:
- Start with a concrete practice or guideline (avoid generic phrasing).
- End the bullet with the exact citation(s) like [KB023] or [KB023][KB013].
- Add a brief phrase of why it matters, starting with "Why:".

FORMAT STRICTLY:
- <Guideline sentence>. Why: <short reason>. [KBxxx][KByyy]

Then finish with a one-line summary (no more than 20 words) with citations at the end.
"""
    draft = ask_model(llm, prompt)
    print("Draft (preview):\n", draft[:400], "...")
    return {"draft_answer": draft}


def critique_answer(state: RAGState, *, llm):
    kb_block = snippets_to_blocks(state["snippets"])
    prompt = f"""{SYS_RULES}

You are now a strict reviewer. Compare the DRAFT to the SNIPPETS.
- If the draft fully answers the question and includes necessary citations, output exactly: COMPLETE
- Else output: REFINE: <comma-separated missing keywords/points>

QUESTION:
{state['question']}

SNIPPETS:
{kb_block}

DRAFT:
{state['draft_answer']}
"""
    verdict = ask_model(llm, prompt)
    mk = None
    if verdict.upper().startswith("REFINE:"):
        mk = verdict.split(":", 1)[1].strip()
    print("Critique:", verdict)
    return {"critique": verdict.splitlines()[0].strip(), "missing_keywords": mk}

def refine_answer(state: RAGState, *, qdrant: QdrantClient, collection, client: genai.Client, llm):
    extra = retrieve_one_more(qdrant, collection, client, state["question"], state.get("missing_keywords") or "")
    merged = list(state["snippets"])
    ids = {s.kb_id for s in merged}
    for e in extra:
        if e.kb_id not in ids:
            merged.append(e)
    kb_block = snippets_to_blocks(merged)
    prompt = f"""{SYS_RULES}

We found an additional snippet to address missing points: {state['missing_keywords']}

QUESTION:
{state['question']}

SNIPPETS (updated):
{kb_block}

Regenerate a single final answer with citations [KBxxx] and keep it concise.
"""
    final_ans = ask_model(llm, prompt)
    print("Final (refined) preview:\n", final_ans[:400], "...")
    return {"snippets": merged, "final_answer": final_ans}

def node_return_initial(state: RAGState):
    return {"final_answer": state["draft_answer"]}

def route_after_critique(state: RAGState):
    verdict = (state.get("critique") or "").strip().upper()
    if verdict == "COMPLETE":
        return "return_initial"
    if verdict.startswith("REFINE"):
        return "do_refine"
    return "do_refine"

def run_tests(app, question, thread_id = "abc112"):
    config = {"configurable": {"thread_id": thread_id}}
    payload = {
        "question": question,
        "snippets": [],
        "draft_answer": None,
        "critique": None,
        "missing_keywords": None,
        "final_answer": None,
    }
    result = app.invoke(payload, config=config)

    # Logging to stdout
    print("Retrieved:", [s.kb_id for s in result["snippets"]])
    if result.get("draft_answer"):   print("Draft:\n", result["draft_answer"][:400], "...")
    if result.get("critique"):       print("Critique:", result["critique"])
    if result.get("final_answer"):   print("Final:\n", result["final_answer"][:400], "...")
    return result

def show_run_result(result):
    chips = " ".join(
        f'<span style="display:inline-block;padding:2px 8px;border:1px solid #ddd;border-radius:999px;margin-right:6px;font-family:ui-monospace,Menlo,monospace;font-size:12px;background:#f7f7f8">{html.escape(s.kb_id)} ({s.score:.3f})</span>'
        for s in result["snippets"]
    )
    final_answer = html.escape(result.get("final_answer","")).replace("\n","<br>")
    rows = []
    for s in result["snippets"]:
        rows.append(f"""
          <tr>
            <td style="padding:8px 6px;border-bottom:1px solid #e5e7eb">{html.escape(s.kb_id)}</td>
            <td style="padding:8px 6px;border-bottom:1px solid #e5e7eb">{html.escape(getattr(s,'title','') or '')}</td>
            <td style="padding:8px 6px;border-bottom:1px solid #e5e7eb">{html.escape(getattr(s,'source','') or '')}</td>
            <td style="padding:8px 6px;border-bottom:1px solid #e5e7eb">{html.escape(getattr(s,'last_updated','') or '')}</td>
            <td style="padding:8px 6px;border-bottom:1px solid #e5e7eb">{html.escape(getattr(s,'confidence_indicator','') or '')}</td>
          </tr>
        """)
    sources_table = f"""
      <table style="border-collapse:collapse;width:100%;font-size:14px">
        <thead>
          <tr style="text-align:left;border-bottom:1px solid #e5e7eb">
            <th style="padding:8px 6px;color:#111">KB ID</th>
            <th style="padding:8px 6px;color:#111">Title / Question</th>
            <th style="padding:8px 6px;color:#111">Source</th>
            <th style="padding:8px 6px;color:#111">Last Updated</th>
            <th style="padding:8px 6px;color:#111">Confidence</th>
          </tr>
        </thead>
        <tbody>
          {''.join(rows)}
        </tbody>
      </table>
    """
    html_block = f"""
    <div style="border:1px solid #e5e7eb;border-radius:12px;padding:16px;margin:10px 0;background:#fff">
      <div style="margin-bottom:10px;color:#111;font-weight:700;font-size:16px">Retrieved Snippets</div>
      <div style="margin-bottom:14px">{chips}</div>

      <div style="margin:12px 0;color:#111;font-weight:700;font-size:16px">Final Answer</div>
      <div style="padding:14px;border:1px solid #e5e7eb;border-radius:10px;background:#fafafa;line-height:1.6;
                  color:#000;font-size:15px;font-weight:500">
        {final_answer}
      </div>

      <div style="margin-top:16px;color:#111;font-weight:700;font-size:16px">Sources</div>
      {sources_table}
    </div>
    """
    display(HTML(html_block))

In [5]:
GENERATION_MODEL = "gemini-2.0-flash" 
TEMPERATURE = 0.0
KB_JSON_PATH = "/home/zadmin/Desktop/test/GAAI-B5-GCP/datasets/self_critique_loop_dataset.json" 
COLLECTION = "agentic_rag_kb"
THREAD_ID = "abc112"

In [6]:
with open(KB_JSON_PATH, "r", encoding="utf-8") as f:
    kb = json.load(f)

In [7]:
items = norm_items(kb)
print(f"Loaded KB entries: {len(items)}")

Loaded KB entries: 30


In [8]:
vertex_client = make_vertex_client()
vectors = embed_corpus(vertex_client, [it["blob"] for it in items])
dim = len(vectors[0])
print("Embedded:", len(vectors), "items | dim:", dim)

Embedded: 30 items | dim: 3072


In [9]:
assert len(vectors) > 0
dim = len(vectors[0])

qdrant = init_qdrant_inmemory()
create_qdrant_inmemory_collection(qdrant, COLLECTION, dim)
upsert_points(qdrant, COLLECTION, items, vectors)
print(f"Qdrant collection '{COLLECTION}' ready.")


Qdrant collection 'agentic_rag_kb' ready.


In [10]:
llm = init_chat_model(GENERATION_MODEL,model_provider="google_genai",model_kwargs={"temperature": 0})
llm

  return _init_chat_model_helper(


ChatGoogleGenerativeAI(model='models/gemini-2.0-flash', google_api_key=SecretStr('**********'), temperature=0.0, client=<google.ai.generativelanguage_v1beta.services.generative_service.client.GenerativeServiceClient object at 0x799d63ee3c90>, default_metadata=(), model_kwargs={})

In [11]:
USE_MEMORY = True
checkpointer = MemorySaver() if USE_MEMORY else None

builder = StateGraph(RAGState)

In [12]:
builder.add_node("retrieve_kb", lambda s: node_retrieve(s, qdrant=qdrant, collection=COLLECTION, client=vertex_client))
builder.add_node("generate_answer", lambda s: generate_answer(s, llm=llm))
builder.add_node("critique_answer", lambda s: critique_answer(s, llm=llm))
builder.add_node("refine_answer", lambda s: refine_answer(s, qdrant=qdrant, collection=COLLECTION, client=vertex_client, llm=llm))
builder.add_node("return_initial", node_return_initial)

builder.add_edge(START, "retrieve_kb")
builder.add_edge("retrieve_kb", "generate_answer")
builder.add_edge("generate_answer", "critique_answer")
builder.add_conditional_edges("critique_answer", route_after_critique, {
    "return_initial": "return_initial",
    "do_refine": "refine_answer"
})
builder.add_edge("return_initial", END)
builder.add_edge("refine_answer", END)

app = builder.compile(checkpointer=checkpointer)
print("Graph compiled.")

Graph compiled.


In [13]:
tests = [
    "What are best practices for caching?",
    "How should I set up CI/CD pipelines?",
    "What are performance tuning tips?",
    "How do I version my APIs?",
    "What should I consider for error handling?"
]

In [14]:
for q in tests:
    print("\n=== Q:", q)
    out = run_tests(app, q, thread_id=THREAD_ID)
    print("Retrieved:", [s.kb_id for s in out["snippets"]])
    print("Critique:", out.get("critique"))
    print("FINAL:\n", out.get("final_answer"))



=== Q: What are best practices for caching?
Retrieved: ['KB023', 'KB013', 'KB003', 'KB030', 'KB020']
Draft (preview):
 - Follow well-defined patterns when addressing caching. Why: To ensure effective caching [KB023][KB013][KB003].

Caching best practices involve following well-defined patterns [KB023][KB013][KB003]. ...
Critique: COMPLETE
Retrieved: ['KB023', 'KB013', 'KB003', 'KB030', 'KB020']
Draft:
 - Follow well-defined patterns when addressing caching. Why: To ensure effective caching [KB023][KB013][KB003].

Caching best practices involve following well-defined patterns [KB023][KB013][KB003]. ...
Critique: COMPLETE
Final:
 - Follow well-defined patterns when addressing caching. Why: To ensure effective caching [KB023][KB013][KB003].

Caching best practices involve following well-defined patterns [KB023][KB013][KB003]. ...
Retrieved: ['KB023', 'KB013', 'KB003', 'KB030', 'KB020']
Critique: COMPLETE
FINAL:
 - Follow well-defined patterns when addressing caching. Why: To ensure effec