In [None]:

import os
import json
from dotenv import load_dotenv
from pydantic import BaseModel, Field
from langchain.chat_models import init_chat_model
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_community.vectorstores import Pinecone
from langchain_google_vertexai import VertexAIEmbeddings
from pinecone import Pinecone as PineconeClient, ServerlessSpec
import langgraph


In [None]:

# Run load_dotenv() first, then read the two service credentials from the
# process environment. Each value is None when its variable is not set.
load_dotenv()
google_api_key = os.environ.get("GOOGLE_API_KEY")
pinecone_api_key = os.environ.get("PINECONE_API_KEY")

In [None]:

# Load the knowledge-base dataset.
# JSON text is UTF-8, so decode it explicitly instead of relying on the
# platform's default text encoding, which is not UTF-8 on every system.
with open("self_critique_loop_dataset.json", "r", encoding="utf-8") as f:
    kb_data = json.load(f)


In [None]:

# Embedding and Pinecone vector database creation.
pc = PineconeClient(api_key=pinecone_api_key)

index_name = "assignment3-kb3072"
# Create the index only when no index with this name is listed yet, so the
# cell can be re-run against an existing index.
if index_name not in [idx["name"] for idx in pc.list_indexes()]:
    pc.create_index(
        name=index_name,
        dimension=3072,  # must match the length of the vectors upserted below
        metric="cosine",
        spec=ServerlessSpec(cloud="aws", region="us-east-1")
    )

index = pc.Index(index_name)

embedding_model = VertexAIEmbeddings(model_name="gemini-embedding-001")

# Build one record per knowledge-base entry: the answer snippet is what gets
# embedded; question, snippet and source are stored as metadata.
vectors = []
for entry in kb_data:
    vec = embedding_model.embed_query(entry["answer_snippet"])
    vectors.append({
        "id": entry["doc_id"],
        "values": vec,
        "metadata": {
            "question": entry["question"],
            "snippet": entry["answer_snippet"],
            "source": entry["source"]
        }
    })

# Upsert in bounded batches rather than one request: a single request carrying
# every 3072-dimension vector can exceed Pinecone's per-request size limit once
# the dataset grows. An empty dataset results in no upsert call at all.
UPSERT_BATCH_SIZE = 100
for batch_start in range(0, len(vectors), UPSERT_BATCH_SIZE):
    index.upsert(vectors=vectors[batch_start:batch_start + UPSERT_BATCH_SIZE])

# Chat model shared by the answer-generation and critique functions below.
model = init_chat_model(
    model="gemini-2.0-flash",
    model_provider="google_genai",
    google_api_key=google_api_key,
)


def retrieve_vector_db(question: str, top_k: int = 5):
    """Return the knowledge-base snippets most similar to ``question``.

    The question is embedded with the module-level ``embedding_model`` and the
    resulting vector is used to query the module-level Pinecone ``index``.

    Args:
        question: Query text to embed and search with.
        top_k: Maximum number of matches to request from the index. Defaults
            to 5, the value that used to be hard-coded.

    Returns:
        A list of strings, one per returned match and in the order the index
        returned them, each formatted as ``"[<match id>] <snippet>"`` where
        the snippet is read from the match metadata.
    """
    query_vector = embedding_model.embed_query(question)
    response = index.query(vector=query_vector, top_k=top_k, include_metadata=True)
    return [
        f"[{match['id']}] {match['metadata']['snippet']}"
        for match in response["matches"]
    ]


def chat_completion(question: str, snippets: list):
    """Generate an answer to ``question`` from the supplied context snippets.

    Args:
        question: The user question; fills the ``{q}`` slot of the prompt.
        snippets: Context strings; joined with newlines to fill ``{ctx}``.

    Returns:
        The result of invoking the prompt -> ``model`` -> ``StrOutputParser``
        chain on the question and joined context.
    """
    answer_prompt = ChatPromptTemplate.from_template(
        "Question: {q}\nContext:\n{ctx}\nAnswer with citations."
    )
    joined_context = "\n".join(snippets)
    answer_chain = answer_prompt | model | StrOutputParser()
    chain_inputs = {"q": question, "ctx": joined_context}
    return answer_chain.invoke(chain_inputs)

# Adding Self-Critique Node
def critique_answer(answer: str, snippets: list):
    """Ask the chat model whether ``answer`` is complete given ``snippets``.

    Args:
        answer: Draft answer to review; fills the ``{ans}`` slot.
        snippets: Context strings; joined with newlines to fill ``{ctx}``.

    Returns:
        The model reply after ``StrOutputParser``. The prompt asks for either
        ``COMPLETE`` or ``REFINE: <missing keywords>``; the reply is returned
        as-is and is not validated here.
    """
    review_template = (
        "Given answer:\n{ans}\n\n"
        "Check if COMPLETE based on snippets:\n{ctx}\n\n"
        "Reply with either 'COMPLETE' or 'REFINE: <missing keywords>'"
    )
    review_chain = (
        ChatPromptTemplate.from_template(review_template)
        | model
        | StrOutputParser()
    )
    joined_snippets = '\n'.join(snippets)
    return review_chain.invoke({"ans": answer, "ctx": joined_snippets})

# Adding Refinement Node
def refine_answer(question: str, answer: str, critique: str):
    """Return ``answer`` unchanged, or regenerate it using one extra snippet.

    The critique text decides what happens:

    * If it contains ``REFINE:``, the text after the first marker is used as
      the search text for one additional snippet.
    * Otherwise, if it contains ``COMPLETE`` (and not ``INCOMPLETE``), the
      answer is accepted as-is.
    * Otherwise the whole critique is used as the search text.

    Args:
        question: The original user question, passed on to ``chat_completion``.
        answer: The draft answer that was critiqued.
        critique: Reply produced by the critique step.

    Returns:
        ``answer`` itself when it is accepted, when there is no search text,
        or when the index returns no match; otherwise the result of
        ``chat_completion(question, [answer, extra_snippet])``.
    """
    verdict = critique.strip()
    refine_marker = "REFINE:"

    if refine_marker in verdict:
        # An explicit refine request takes priority: the text after the marker
        # may itself contain "COMPLETE" (e.g. "INCOMPLETE", "not COMPLETE"),
        # which a bare substring test would misread as approval.
        keyword = verdict.split(refine_marker, 1)[1].strip()
    elif "COMPLETE" in verdict and "INCOMPLETE" not in verdict:
        return answer
    else:
        # Off-format reply: fall back to searching with the whole critique.
        keyword = verdict

    if not keyword:
        # Nothing to search for, so keep the draft rather than embed "".
        return answer

    vec = embedding_model.embed_query(keyword)
    res = index.query(vector=vec, top_k=1, include_metadata=True)
    matches = res["matches"]
    if not matches:
        # No extra context available; the draft answer is the best we have.
        return answer

    best_match = matches[0]
    # Same "[id] snippet" shape as retrieve_vector_db so the model can cite it.
    extra_snippet = f"[{best_match['id']}] {best_match['metadata']['snippet']}"
    return chat_completion(question, [answer, extra_snippet])


In [None]:
# Demo questions, each run through retrieve -> answer -> critique -> refine.
queries = [
    "What are best practices for caching?",
    "How should I set up CI/CD pipelines?",
    "What are performance tuning tips?",
    "How do I version my APIs?",
    "What should I consider for error handling?",
]

for q in queries:
    # Two blank lines separate consecutive runs in the cell output.
    print(f"\n\nQuery: {q}")
    snippets = retrieve_vector_db(q)
    ans = chat_completion(q, snippets)
    critique = critique_answer(ans, snippets)
    final_ans = refine_answer(q, ans, critique)
    print("Final Answer:\n", final_ans)