In [None]:
from dotenv import load_dotenv

load_dotenv()

from langchain.schema import Document
from langchain_google_genai import ChatGoogleGenerativeAI,GoogleGenerativeAIEmbeddings
from langchain_community.vectorstores import Chroma
import os

embedding_function = GoogleGenerativeAIEmbeddings(
    model="models/embedding-001", 
    google_api_key=os.getenv("GEMINI_API_KEY")
    )

docs = [
    Document(
        page_content="Peak Performance Gym was founded in 2015 by former Olympic athlete Marcus Chen. With over 15 years of experience in professional athletics, Marcus established the gym to provide personalized fitness solutions for people of all levels. The gym spans 10,000 square feet and features state-of-the-art equipment.",
        metadata={"source": "about.txt"}
    ),
    Document(
        page_content="Peak Performance Gym is open Monday through Friday from 5:00 AM to 11:00 PM. On weekends, our hours are 7:00 AM to 9:00 PM. We remain closed on major national holidays. Members with Premium access can enter using their key cards 24/7, including holidays.",
        metadata={"source": "hours.txt"}
    ),
    Document(
        page_content="Our membership plans include: Basic (₹1,500/month) with access to gym floor and basic equipment; Standard (₹2,500/month) adds group classes and locker facilities; Premium (₹4,000/month) includes 24/7 access, personal training sessions, and spa facilities. We offer student and senior citizen discounts of 15% on all plans. Corporate partnerships are available for companies with 10+ employees joining.",
        metadata={"source": "membership.txt"}
    ),
    Document(
        page_content="Group fitness classes at Peak Performance Gym include Yoga (beginner, intermediate, advanced), HIIT, Zumba, Spin Cycling, CrossFit, and Pilates. Beginner classes are held every Monday and Wednesday at 6:00 PM. Intermediate and advanced classes are scheduled throughout the week. The full schedule is available on our mobile app or at the reception desk.",
        metadata={"source": "classes.txt"}
    ),
    Document(
        page_content="Personal trainers at Peak Performance Gym are all certified professionals with minimum 5 years of experience. Each new member receives a complimentary fitness assessment and one free session with a trainer. Our head trainer, Neha Kapoor, specializes in rehabilitation fitness and sports-specific training. Personal training sessions can be booked individually (₹800/session) or in packages of 10 (₹7,000) or 20 (₹13,000).",
        metadata={"source": "trainers.txt"}
    ),
    Document(
        page_content="Peak Performance Gym's facilities include a cardio zone with 30+ machines, strength training area, functional fitness space, dedicated yoga studio, spin class room, swimming pool (25m), sauna and steam rooms, juice bar, and locker rooms with shower facilities. Our equipment is replaced or upgraded every 3 years to ensure members have access to the latest fitness technology.",
        metadata={"source": "facilities.txt"}
    )
]

db = Chroma.from_documents(docs, embedding_function)
retriever = db.as_retriever(search_type="mmr", search_kwargs = {"k": 4})

from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI

llm = ChatGoogleGenerativeAI(
    model="gemini-2.5-flash",
    api_key=os.getenv("GEMINI_API_KEY")
    )

template = """Answer the question based on the following context and the Chathistory. Especially take the latest question into consideration:

Chathistory: {history}

Context: {context}

Question: {question}
"""
prompt = ChatPromptTemplate.from_template(template)
rag_chain = prompt | llm

from typing import TypedDict, List
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage, SystemMessage
from langchain.schema import Document
from pydantic import BaseModel, Field
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langgraph.graph import StateGraph, END
import folium
import io
import base64
import json

class AgentState(TypedDict):
    messages: List[BaseMessage]
    documents: List[Document]
    on_topic: str
    rephrased_question: str
    proceed_to_generate: bool
    rephrase_count: int
    question: HumanMessage
    # New fields for GIS
    needs_gis: bool
    gis_data: dict
    map_html: str

class GradeQuestion(BaseModel):
    score: str = Field(
        description="Question is about the specified topics? If yes -> 'Yes' if not -> 'No'"
    )

class GISRequest(BaseModel):
    needs_gis: str = Field(
        description="Does the answer require a GIS map visualization? 'Yes' or 'No'"
    )
    location_data: str = Field(
        description="Extract location information if available (state, district, coordinates, etc.)"
    )

def question_rewriter(state: AgentState):
    print(f"Entering question_rewriter with following state: {state}")

    # Reset state variables except for 'question' and 'messages'
    state["documents"] = []
    state["on_topic"] = ""
    state["rephrased_question"] = ""
    state["proceed_to_generate"] = False
    state["rephrase_count"] = 0
    state["needs_gis"] = False
    state["gis_data"] = {}
    state["map_html"] = ""

    if "messages" not in state or state["messages"] is None:
        state["messages"] = []

    if state["question"] not in state["messages"]:
        state["messages"].append(state["question"])

    if len(state["messages"]) > 1:
        conversation = state["messages"][:-1]
        current_question = state["question"].content
        messages = [
            SystemMessage(
                content="You are a helpful assistant that rephrases the user's question to be a standalone question optimized for retrieval."
            )
        ]
        messages.extend(conversation)
        messages.append(HumanMessage(content=current_question))
        rephrase_prompt = ChatPromptTemplate.from_messages(messages)
        llm = ChatOpenAI(model="gpt-4o-mini")
        prompt = rephrase_prompt.format()
        response = llm.invoke(prompt)
        better_question = response.content.strip()
        print(f"question_rewriter: Rephrased question: {better_question}")
        state["rephrased_question"] = better_question
    else:
        state["rephrased_question"] = state["question"].content
    return state

def question_classifier(state: AgentState):
    print("Entering question_classifier")
    system_message = SystemMessage(
        content=""" You are a classifier that determines whether a user's question is about one of the following topics 
        
        1. Gym History & Founder
        2. Operating Hours
        3. Membership Plans 
        4. Fitness Classes
        5. Personal Trainers
        6. Facilities & Equipment
        7. Anything else about Peak Performance Gym
        
        If the question IS about any of these topics, respond with 'Yes'. Otherwise, respond with 'No'.

        """
    )

    human_message = HumanMessage(
        content=f"User question: {state['rephrased_question']}"
    )
    grade_prompt = ChatPromptTemplate.from_messages([system_message, human_message])
    llm = ChatOpenAI(model="gpt-4o")
    structured_llm = llm.with_structured_output(GradeQuestion)
    grader_llm = grade_prompt | structured_llm
    result = grader_llm.invoke({})
    state["on_topic"] = result.score.strip()
    print(f"question_classifier: on_topic = {state['on_topic']}")
    return state

def on_topic_router(state: AgentState):
    print("Entering on_topic_router")
    on_topic = state.get("on_topic", "").strip().lower()
    if on_topic == "yes":
        print("Routing to retrieve")
        return "retrieve"
    else:
        print("Routing to off_topic_response")
        return "off_topic_response"

def retrieve(state: AgentState):
    print("Entering retrieve")
    documents = retriever.invoke(state["rephrased_question"])
    print(f"retrieve: Retrieved {len(documents)} documents")
    state["documents"] = documents
    return state

class GradeDocument(BaseModel):
    score: str = Field(
        description="Document is relevant to the question? If yes -> 'Yes' if not -> 'No'"
    )

def retrieval_grader(state: AgentState):
    print("Entering retrieval_grader")
    system_message = SystemMessage(
        content="""You are a grader assessing the relevance of a retrieved document to a user question.
Only answer with 'Yes' or 'No'.

If the document contains information relevant to the user's question, respond with 'Yes'.
Otherwise, respond with 'No'."""
    )

    llm = ChatOpenAI(model="gpt-4o")
    structured_llm = llm.with_structured_output(GradeDocument)

    relevant_docs = []
    for doc in state["documents"]:
        human_message = HumanMessage(
            content=f"User question: {state['rephrased_question']}\n\nRetrieved document:\n{doc.page_content}"
        )
        grade_prompt = ChatPromptTemplate.from_messages([system_message, human_message])
        grader_llm = grade_prompt | structured_llm
        result = grader_llm.invoke({})
        print(
            f"Grading document: {doc.page_content[:30]}... Result: {result.score.strip()}"
        )
        if result.score.strip().lower() == "yes":
            relevant_docs.append(doc)
    state["documents"] = relevant_docs
    state["proceed_to_generate"] = len(relevant_docs) > 0
    print(f"retrieval_grader: proceed_to_generate = {state['proceed_to_generate']}")
    return state

def proceed_router(state: AgentState):
    print("Entering proceed_router")
    rephrase_count = state.get("rephrase_count", 0)
    if state.get("proceed_to_generate", False):
        print("Routing to generate_answer")
        return "generate_answer"
    elif rephrase_count >= 2:
        print("Maximum rephrase attempts reached. Cannot find relevant documents.")
        return "cannot_answer"
    else:
        print("Routing to refine_question")
        return "refine_question"
 
def refine_question(state: AgentState):
    print("Entering refine_question")
    rephrase_count = state.get("rephrase_count", 0)
    if rephrase_count >= 2:
        print("Maximum rephrase attempts reached")
        return state
    question_to_refine = state["rephrased_question"]
    system_message = SystemMessage(
        content="""You are a helpful assistant that slightly refines the user's question to improve retrieval results.
Provide a slightly adjusted version of the question."""
    )
    human_message = HumanMessage(
        content=f"Original question: {question_to_refine}\n\nProvide a slightly refined question."
    )
    refine_prompt = ChatPromptTemplate.from_messages([system_message, human_message])
    llm = ChatOpenAI(model="gpt-4o")
    prompt = refine_prompt.format()
    response = llm.invoke(prompt)
    refined_question = response.content.strip()
    print(f"refine_question: Refined question: {refined_question}")
    state["rephrased_question"] = refined_question
    state["rephrase_count"] = rephrase_count + 1
    return state

def generate_answer(state: AgentState):
    print("Entering generate_answer")
    if "messages" not in state or state["messages"] is None:
        raise ValueError("State must include 'messages' before generating an answer.")

    history = state["messages"]
    documents = state["documents"]
    rephrased_question = state["rephrased_question"]

    response = rag_chain.invoke(
        {"history": history, "context": documents, "question": rephrased_question}
    )

    generation = response.content.strip()

    state["messages"].append(AIMessage(content=generation))
    print(f"generate_answer: Generated response: {generation}")
    return state

def gis_classifier(state: AgentState):
    """Classify if the generated answer needs GIS visualization"""
    print("Entering gis_classifier")
    
    if not state["messages"] or not isinstance(state["messages"][-1], AIMessage):
        state["needs_gis"] = False
        return state
    
    answer_content = state["messages"][-1].content
    question = state["rephrased_question"]
    
    # GIS Classification prompt
    system_message = SystemMessage(
        content="""You are a GIS classifier that determines if a groundwater-related answer would benefit from a map visualization.

Look for:
- Location-specific data (states, districts, coordinates)
- Groundwater levels, extraction rates, or quality data with geographic context
- Spatial comparisons between regions
- Well locations or monitoring stations
- Geographic distribution of water resources

If the answer contains geographic/spatial information that would be enhanced by a map, respond with 'Yes'.
Otherwise, respond with 'No'.

Also extract any location information mentioned in the question or answer."""
    )
    
    human_message = HumanMessage(
        content=f"Question: {question}\n\nAnswer: {answer_content}\n\nDoes this need a GIS map visualization?"
    )
    
    gis_prompt = ChatPromptTemplate.from_messages([system_message, human_message])
    llm = ChatOpenAI(model="gpt-4o")
    structured_llm = llm.with_structured_output(GISRequest)
    grader_llm = gis_prompt | structured_llm
    result = grader_llm.invoke({})
    
    state["needs_gis"] = result.needs_gis.strip().lower() == "yes"
    if state["needs_gis"]:
        state["gis_data"] = {"location_info": result.location_data.strip()}
    
    print(f"gis_classifier: needs_gis = {state['needs_gis']}")
    return state

def generate_gis_map(state: AgentState):
    """Generate GIS map for groundwater data"""
    print("Entering generate_gis_map")
    
    if not state.get("needs_gis", False):
        return state
    
    # Extract location info
    location_info = state.get("gis_data", {}).get("location_info", "")
    
    # Create sample GIS data based on location (in real implementation, query your groundwater database)
    sample_points = [
        {"name": "Karnataka", "lat": 15.3173, "lon": 75.7139, "value": 218.54, "category": "Safe"},
        {"name": "Maharashtra", "lat": 19.7515, "lon": 75.7139, "value": 256.66, "category": "Semi-Critical"},
        {"name": "Gujarat", "lat": 23.0225, "lon": 72.5714, "value": 145.55, "category": "Critical"},
        {"name": "Rajasthan", "lat": 27.0238, "lon": 74.2179, "value": 286.15, "category": "Over-Exploited"},
    ]
    
    # Create Folium map centered on India
    m = folium.Map(location=[22.0, 79.0], zoom_start=5, tiles='OpenStreetMap')
    
    # Define colors for different categories
    color_map = {
        "Safe": "green",
        "Semi-Critical": "orange", 
        "Critical": "red",
        "Over-Exploited": "darkred"
    }
    
    # Add points to map
    for point in sample_points:
        color = color_map.get(point["category"], "blue")
        folium.CircleMarker(
            location=[point["lat"], point["lon"]],
            radius=8 + point["value"] * 0.02,
            color=color,
            fill=True,
            fillColor=color,
            fillOpacity=0.7,
            popup=f"<b>{point['name']}</b><br>"
                  f"Recharge: {point['value']} MCM<br>"
                  f"Category: {point['category']}",
            tooltip=point["name"]
        ).add_to(m)
    
    # Add legend
    legend_html = """
    <div style="position: fixed; 
                bottom: 50px; left: 50px; width: 150px; height: 90px; 
                background-color: white; border:2px solid grey; z-index:9999; 
                font-size:14px; padding: 10px">
    <p><b>Groundwater Status</b></p>
    <p><i class="fa fa-circle" style="color:green"></i> Safe</p>
    <p><i class="fa fa-circle" style="color:orange"></i> Semi-Critical</p>
    <p><i class="fa fa-circle" style="color:red"></i> Critical</p>
    <p><i class="fa fa-circle" style="color:darkred"></i> Over-Exploited</p>
    </div>
    """
    m.get_root().html.add_child(folium.Element(legend_html))
    
    # Convert to HTML string
    html_str = m._repr_html_()
    state["map_html"] = html_str
    
    # Update the last AI message to include map reference
    last_message = state["messages"][-1]
    updated_content = f"{last_message.content}\n\n[Interactive GIS Map Generated - showing groundwater status across regions]"
    state["messages"][-1] = AIMessage(content=updated_content)
    
    print("generate_gis_map: GIS map generated successfully")
    return state

def gis_router(state: AgentState):
    """Route based on GIS classification"""
    print("Entering gis_router")
    if state.get("needs_gis", False):
        print("Routing to generate_gis_map")
        return "generate_gis_map"
    else:
        print("Routing to END")
        return END

def cannot_answer(state: AgentState):
    print("Entering cannot_answer")
    if "messages" not in state or state["messages"] is None:
        state["messages"] = []
    state["messages"].append(
        AIMessage(
            content="I'm sorry, but I cannot find the information you're looking for."
        )
    )
    return state

def off_topic_response(state: AgentState):
    print("Entering off_topic_response")
    if "messages" not in state or state["messages"] is None:
        state["messages"] = []
    state["messages"].append(AIMessage(content="I'm sorry! I cannot answer this question!"))
    return state

from langgraph.checkpoint.memory import MemorySaver

checkpointer = MemorySaver()

# Workflow
workflow = StateGraph(AgentState)
workflow.add_node("question_rewriter", question_rewriter)
workflow.add_node("question_classifier", question_classifier)
workflow.add_node("off_topic_response", off_topic_response)
workflow.add_node("retrieve", retrieve)
workflow.add_node("retrieval_grader", retrieval_grader)
workflow.add_node("generate_answer", generate_answer)
workflow.add_node("refine_question", refine_question)
workflow.add_node("cannot_answer", cannot_answer)

# New GIS nodes
workflow.add_node("gis_classifier", gis_classifier)
workflow.add_node("generate_gis_map", generate_gis_map)

workflow.add_edge("question_rewriter", "question_classifier")
workflow.add_conditional_edges(
    "question_classifier",
    on_topic_router,
    {
        "retrieve": "retrieve",
        "off_topic_response": "off_topic_response",
    },
)
workflow.add_edge("retrieve", "retrieval_grader")
workflow.add_conditional_edges(
    "retrieval_grader",
    proceed_router,
    {
        "generate_answer": "generate_answer",
        "refine_question": "refine_question",
        "cannot_answer": "cannot_answer",
    },
)
workflow.add_edge("refine_question", "retrieve")

# Modified: After generate_answer, go to GIS classifier
workflow.add_edge("generate_answer", "gis_classifier")
workflow.add_conditional_edges(
    "gis_classifier",
    gis_router,
    {
        "generate_gis_map": "generate_gis_map",
        END: END,
    },
)
workflow.add_edge("generate_gis_map", END)

workflow.add_edge("cannot_answer", END)
workflow.add_edge("off_topic_response", END)
workflow.set_entry_point("question_rewriter")

graph = workflow.compile(checkpointer=checkpointer)

# Test the workflow
if __name__ == "__main__":
    from IPython.display import Image, display
    from langchain_core.runnables.graph import MermaidDrawMethod

    display(
        Image(
            graph.get_graph().draw_mermaid_png(
                draw_method=MermaidDrawMethod.API,
            )
        )
    )

    # Test with a spatial query
    input_data = {
        "question": HumanMessage(content="Show me groundwater levels in Karnataka and Maharashtra")
    }
    result = graph.invoke(input=input_data, config={"configurable": {"thread_id": 1}})
    
    if result.get("map_html"):
        print("GIS Map generated successfully!")
        # In a real application, you would serve this HTML to the frontend


In [3]:
# INGRES_RAG_GIS_notebook.py
# Full runnable notebook-style Python script for the INGRES RAG + GIS workflow
# - Uses Google Gemini (langchain_google_genai)
# - Uses Pinecone vector store via langchain_pinecone wrapper
# - Uses langgraph StateGraph to implement the workflow you designed
#
# NOTE: Before running, set environment variables in a .env file or your environment:
#  - GEMINI_API_KEY
#  - PINECONE_API_KEY
#  - PINECONE_ENV (optional)
#  - PINECONE_INDEX (optional, defaults to 'ingres-index')
#
# Install required packages first (run in the environment or a notebook cell):
# !pip install -U pinecone langchain-pinecone langchain langchain-openai langchain-google-genai python-dotenv pandas folium langgraph

# ---------------------------
# Cell 1: Imports & env
# ---------------------------
import os
from dotenv import load_dotenv
load_dotenv()

import logging
from typing import TypedDict, List

import pandas as pd
import folium

# LangChain / LLM / Embedding
from langchain.schema import Document
from langchain_google_genai import ChatGoogleGenerativeAI, GoogleGenerativeAIEmbeddings
from langchain_core.prompts import ChatPromptTemplate
# from langchain_openai import ChatOpenAI
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage, SystemMessage
from pydantic import BaseModel, Field

# Pinecone / LangChain-Pinecone
import pinecone
from langchain_pinecone import PineconeVectorStore

# LangGraph
from langgraph.graph import StateGraph, END
from langgraph.checkpoint.memory import MemorySaver

# ---------------------------
# Logging
# ---------------------------
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# ---------------------------
# Cell 2: LLMs and embeddings
# ---------------------------
GEMINI_API_KEY = os.getenv('GEMINI_API_KEY')
if not GEMINI_API_KEY:
    raise RuntimeError('Please set GEMINI_API_KEY in your environment or .env file')

llm = ChatGoogleGenerativeAI(model="gemini-2.5-flash", api_key=GEMINI_API_KEY)
embedding_function = GoogleGenerativeAIEmbeddings(model="models/embedding-001", google_api_key=GEMINI_API_KEY)

# Also prepare a lightweight ChatOpenAI wrapper for short classification tasks (you may adapt to your setup)
# NOTE: depending on your local langchain_openai wrapper you might need to pass api_key or other args.
chat_for_classify = ChatGoogleGenerativeAI(
    model="gemini-2.5-flash",
    api_key=os.getenv("GEMINI_API_KEY")
    )  

# ---------------------------
# Cell 3: Pinecone init and vectorstore
# ---------------------------
PINECONE_API_KEY = os.getenv('PINECONE_API_KEY')
PINECONE_ENV = os.getenv('PINECONE_ENV')
INDEX_NAME = os.getenv('PINECONE_INDEX', 'ingres-index')

if not PINECONE_API_KEY:
    raise RuntimeError('Please set PINECONE_API_KEY in your environment or .env file')

logger.info('Initializing Pinecone...')
pinecone.init(api_key=PINECONE_API_KEY, environment=PINECONE_ENV)

# get embedding dimension from embedding_function
try:
    example_vec = embedding_function.embed_query('test')
    DIM = len(example_vec)
except Exception as ex:
    logger.warning('Could not get embedding dim from embedding API (%s). Falling back to 1536', ex)
    DIM = 1536

if INDEX_NAME not in pinecone.list_indexes():
    logger.info('Creating Pinecone index %s with dim=%s', INDEX_NAME, DIM)
    pinecone.create_index(name=INDEX_NAME, dimension=DIM)
else:
    logger.info('Using existing Pinecone index %s', INDEX_NAME)

# Create langchain_pinecone wrapper
vectorstore = PineconeVectorStore(index_name=INDEX_NAME, embedding=embedding_function, pinecone_api_key=PINECONE_API_KEY)
retriever = vectorstore.as_retriever(search_type='mmr', search_kwargs={'k': 4})

# ---------------------------
# Cell 4: CSV loader and upsert helper
# ---------------------------

def load_csv_as_docs(path: str, source_name: str) -> List[Document]:
    """Load a CSV and convert rows into Document objects."""
    df = pd.read_csv(path)
    docs: List[Document] = []
    for idx, row in df.iterrows():
        content = '\n'.join(f"{col}: {row[col]}" for col in df.columns)
        docs.append(Document(page_content=content, metadata={"source": source_name, "row": int(idx)}))
    return docs

# Example: if you have the CSV files locally (comment/uncomment as needed)
csv_paths = {
    "assessment_units": "ingres_assessment_units_2023.csv",
    "trends": "groundwater_trends_2015_2023.csv",
    "quality": "groundwater_quality_data.csv",
    "realtime_dwlr": "dwlr_realtime_data.csv",
}

all_docs: List[Document] = []
for source, path in csv_paths.items():
    if os.path.exists(path):
        logger.info('Loading %s', path)
        all_docs += load_csv_as_docs(path, source)
    else:
        logger.warning('CSV missing, skipping: %s', path)

# Upsert into Pinecone (only if docs present)
if all_docs:
    texts = [d.page_content for d in all_docs]
    metadatas = [d.metadata for d in all_docs]
    vectorstore.add_texts(texts=texts, metadatas=metadatas)
    logger.info('Upserted %d documents to vectorstore', len(texts))
else:
    logger.info('No local CSVs found; you can still add documents programmatically later.')

# ---------------------------
# Cell 5: RAG prompt and simple chain
# ---------------------------
template = """Use INGRES data to answer:\n\nChat History:\n{history}\n\nRetrieved Context:\n{context}\n\nUser Question:\n{question}\n"""
prompt = ChatPromptTemplate.from_template(template)
rag_chain = prompt | llm

# ---------------------------
# Cell 6: Workflow: types, classifiers, nodes
# ---------------------------
class AgentState(TypedDict):
    messages: List[BaseMessage]
    documents: List[Document]
    on_topic: str
    rephrased_question: str
    proceed_to_generate: bool
    rephrase_count: int
    question: HumanMessage
    needs_gis: bool
    gis_data: dict
    map_html: str

class GradeQuestion(BaseModel):
    score: str = Field(description='Is user asking INGRES data? Yes/No')

class GISRequest(BaseModel):
    needs_gis: str = Field(description='Need GIS map? Yes/No')
    location_data: str = Field(description='Extracted location info')

# Node implementations follow (largely adapted from your original file)

def question_rewriter(state: AgentState):
    state.update({"documents":[], "on_topic":"", "rephrased_question":"", "proceed_to_generate":False,
                  "rephrase_count":0, "needs_gis":False, "gis_data":{}, "map_html":""})
    state.setdefault("messages", [])
    if state["question"] not in state["messages"]:
        state["messages"].append(state["question"])
    if len(state["messages"])>1:
        conv=state["messages"][:-1]
        q=state["question"].content
        msgs=[SystemMessage(content="Rephrase for INGRES retrieval.")] + conv+[HumanMessage(content=q)]
        # Use ChatOpenAI for rewriter to keep a small footprint; adapt depending on available LLM
        try:
            new_q = ChatOpenAI(model='gpt-4o-mini').invoke(ChatPromptTemplate.from_messages(msgs).format()).content.strip()
        except Exception:
            # fallback: leave as-is
            new_q = q
        state["rephrased_question"]=new_q
    else:
        state["rephrased_question"]=state["question"].content
    return state


def question_classifier(state: AgentState):
    msgs=[SystemMessage(content="Is this about INGRES groundwater data? Yes/No"), HumanMessage(content=state["rephrased_question"]) ]
    try:
        result = ChatOpenAI(model='gpt-4o-mini').with_structured_output(GradeQuestion).invoke(ChatPromptTemplate.from_messages(msgs).format())
        state["on_topic"] = result.score.strip()
    except Exception:
        # Naive keyword-based fallback
        q = state.get('rephrased_question','').lower()
        state["on_topic"] = 'yes' if any(k in q for k in ['groundwater','ingres','ground water','water table','gw']) else 'no'
    return state


def on_topic_router(state: AgentState):
    return 'retrieve' if state['on_topic'].lower()=='yes' else 'off_topic_response'


def retrieve(state: AgentState):
    # Use retriever to get top docs
    try:
        hits = retriever.get_relevant_documents(state['rephrased_question'])
        state['documents'] = hits
    except Exception as e:
        logger.exception('Retriever failed: %s', e)
        state['documents'] = []
    return state

class GradeDocument(BaseModel):
    score: str = Field(description='Relevant? Yes/No')


def retrieval_grader(state: AgentState):
    relevant = []
    for doc in state.get('documents',[]):
        try:
            msgs=[SystemMessage(content='Relevant to INGRES query?'), HumanMessage(content=f"{state['rephrased_question']}\n\n{doc.page_content}")]
            r = ChatOpenAI(model='gpt-4o-mini').with_structured_output(GradeDocument).invoke(ChatPromptTemplate.from_messages(msgs).format())
            if r.score.strip().lower()=='yes':
                relevant.append(doc)
        except Exception:
            # fallback: simple heuristic
            if any(tok in doc.page_content.lower() for tok in ['ground','ph','tds','district','year']):
                relevant.append(doc)
    state['documents'] = relevant
    state['proceed_to_generate'] = bool(relevant)
    return state


def proceed_router(state: AgentState):
    if state.get('proceed_to_generate'):
        return 'generate_answer'
    return 'cannot_answer' if state.get('rephrase_count',0) >= 2 else 'refine_question'


def refine_question(state: AgentState):
    if state.get('rephrase_count',0) >= 2:
        return state
    msgs=[SystemMessage(content='Refine INGRES query slightly'), HumanMessage(content=state['rephrased_question'])]
    try:
        new_q = ChatOpenAI(model='gpt-4o-mini').invoke(ChatPromptTemplate.from_messages(msgs).format()).content.strip()
    except Exception:
        new_q = state['rephrased_question'] + ' (please be more specific)'
    state['rephrased_question'] = new_q
    state['rephrase_count'] = state.get('rephrase_count',0) + 1
    return state


def generate_answer(state: AgentState):
    # Use the RAG chain — pass history, context, question
    context_text = '\n\n---\n\n'.join([d.page_content for d in state.get('documents',[])])
    history = '\n'.join([m.content for m in state.get('messages',[])])
    try:
        res = rag_chain.invoke({'history': history, 'context': context_text, 'question': state['rephrased_question']})
        content = res.content.strip()
    except Exception as e:
        logger.exception('RAG chain failed: %s', e)
        # fallback: short summary of docs
        content = 'I found the following relevant excerpts:\n' + '\n---\n'.join([d.page_content[:400] for d in state.get('documents',[])])
    state.setdefault('messages',[]).append(AIMessage(content=content))
    return state


def gis_classifier(state: AgentState):
    last = state['messages'][-1].content if state.get('messages') else ''
    msgs=[SystemMessage(content='Need GIS map? Yes/No'), HumanMessage(content=f"{state['rephrased_question']}\n\n{last}")]
    try:
        r = ChatOpenAI(model='gpt-4o-mini').with_structured_output(GISRequest).invoke(ChatPromptTemplate.from_messages(msgs).format())
        state['needs_gis'] = r.needs_gis.strip().lower()=='yes'
        state['gis_data'] = {'location': r.location_data.strip()} if state['needs_gis'] else {}
    except Exception:
        # simple heuristic: look for words like 'map', 'location', 'district', 'lat', 'lon'
        q = (state.get('rephrased_question','') + ' ' + last).lower()
        state['needs_gis'] = any(w in q for w in ['map','location','district','lat','lon','show me','plot'])
        state['gis_data'] = {'location': state.get('rephrased_question','')} if state['needs_gis'] else {}
    return state


def generate_gis_map(state: AgentState):
    if not state.get('needs_gis'):
        return state
    # Very simple sample points. In production you would geocode the extracted location and pull data.
    pts = [
        {"lat":15.3, "lon":75.7, "val":200},
        {"lat":19.0, "lon":77.0, "val":250}
    ]
    m = folium.Map(location=[22,79], zoom_start=5)
    for p in pts:
        folium.CircleMarker([p['lat'], p['lon']], radius=5 + p['val']*0.01, popup=str(p['val'])).add_to(m)
    state['map_html'] = m._repr_html_()
    # append marker in last AI message
    last = state['messages'][-1]
    state['messages'][-1] = AIMessage(content=last.content + '\n\n[GIS Map Attached]')
    return state


def cannot_answer(state: AgentState):
    state.setdefault('messages',[]).append(AIMessage(content="I\'m sorry, I can\'t find that."))
    return state


def off_topic_response(state: AgentState):
    state.setdefault('messages',[]).append(AIMessage(content="I\'m sorry, I cannot answer this."))
    return state

# ---------------------------
# Cell 7: Build LangGraph workflow as you specified
# ---------------------------
cp = MemorySaver()
wf = StateGraph(AgentState)

# add nodes
nodes = [
    'question_rewriter','question_classifier','off_topic_response',
    'retrieve','retrieval_grader','generate_answer',
    'refine_question','cannot_answer',
    'gis_classifier','generate_gis_map'
]
for n in nodes:
    wf.add_node(n, globals()[n])

# edges following the diagram + GIS after generate_answer
wf.set_entry_point('question_rewriter')
wf.add_edge('question_rewriter','question_classifier')
wf.add_conditional_edges('question_classifier', on_topic_router, {'retrieve':'retrieve','off_topic_response':'off_topic_response'})
wf.add_edge('retrieve','retrieval_grader')
wf.add_conditional_edges('retrieval_grader', proceed_router, {'generate_answer':'generate_answer','refine_question':'refine_question','cannot_answer':'cannot_answer'})
wf.add_edge('refine_question','retrieve')
wf.add_edge('generate_answer','gis_classifier')
wf.add_conditional_edges('gis_classifier', lambda s: 'generate_gis_map' if s.get('needs_gis') else END, {'generate_gis_map':'generate_gis_map', END:END})
wf.add_edge('generate_gis_map', END)
wf.add_edge('cannot_answer', END)
wf.add_edge('off_topic_response', END)

graph = wf.compile(checkpointer=cp)

# ---------------------------
# Cell 8: Example run function to invoke the graph
# ---------------------------

def run_query(question_text: str):
    initial_state: AgentState = {
        'messages': [],
        'documents': [],
        'on_topic': '',
        'rephrased_question': '',
        'proceed_to_generate': False,
        'rephrase_count': 0,
        'question': HumanMessage(content=question_text),
        'needs_gis': False,
        'gis_data': {},
        'map_html': ''
    }
    state = graph.run(initial_state)
    return state

# ---------------------------
# Cell 9: Try a sample query (uncomment to run)
# ---------------------------
# result_state = run_query('What is the groundwater extraction rate in Pune district in 2023? Please show a map.')
# for msg in result_state['messages']:
#     print('\n---\n')
#     print(msg.content[:2000])
# # If map_html present, you can write it to an HTML file to view
# if result_state.get('map_html'):
#     with open('gis_map.html','w',encoding='utf-8') as f:
#         f.write(result_state['map_html'])
#     print('Wrote gis_map.html')

# ---------------------------
# End of notebook script
# ---------------------------


  from .autonotebook import tqdm as notebook_tqdm

For example, replace imports like: `from langchain_core.pydantic_v1 import BaseModel`
with: `from pydantic import BaseModel`
or the v1 compatibility namespace if you are working in a code base that has not been fully upgraded to pydantic 2 yet. 	from pydantic.v1 import BaseModel

  from langchain_pinecone.vectorstores import Pinecone, PineconeVectorStore
INFO:__main__:Initializing Pinecone...


AttributeError: init is no longer a top-level attribute of the pinecone package.

Please create an instance of the Pinecone class instead.

Example:

    import os
    from pinecone import Pinecone, ServerlessSpec

    pc = Pinecone(
        api_key=os.environ.get("PINECONE_API_KEY")
    )

    # Now do stuff
    if 'my_index' not in pc.list_indexes().names():
        pc.create_index(
            name='my_index',
            dimension=1536,
            metric='euclidean',
            spec=ServerlessSpec(
                cloud='aws',
                region='us-west-2'
            )
        )



In [22]:
import os
from dotenv import load_dotenv
load_dotenv()

import logging
from typing import TypedDict, List

import pandas as pd
import folium

# LangChain / LLM / Embedding
from langchain.schema import Document
from langchain_google_genai import ChatGoogleGenerativeAI, GoogleGenerativeAIEmbeddings
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage, SystemMessage
from pydantic import BaseModel, Field

# Pinecone / LangChain-Pinecone
import pinecone
from langchain_pinecone import PineconeVectorStore

# LangGraph
from langgraph.graph import StateGraph, END
from langgraph.checkpoint.memory import MemorySaver



In [23]:
llm = ChatGoogleGenerativeAI(model="gemini-2.5-flash", api_key=GEMINI_API_KEY)
embedding_function = GoogleGenerativeAIEmbeddings(model="models/embedding-001", google_api_key=GEMINI_API_KEY)
chat_for_classify = ChatGoogleGenerativeAI(
    model="gemini-2.5-flash",
    api_key=os.getenv("GEMINI_API_KEY")
    )

In [None]:
# ---------------------------
# Cell 3: Pinecone init and vectorstore
# ---------------------------
import pinecone
from pinecone import ServerlessSpec
import time # Added this import

PINECONE_API_KEY = os.getenv('PINECONE_API_KEY')
PINECONE_ENV = os.getenv('PINECONE_ENV')
INDEX_NAME = os.getenv('PINECONE_INDEX', 'ingres-index')

if not PINECONE_API_KEY:
    raise RuntimeError('Please set PINECONE_API_KEY in your environment or .env file')

logger.info('Initializing Pinecone client...')
pc = pinecone.Pinecone(api_key=PINECONE_API_KEY, environment=PINECONE_ENV)

# get embedding dimension from embedding_function
try:
    example_vec = embedding_function.embed_query('test')
    DIM = len(example_vec)
    logger.info('Using embedding dimension from function: %s', DIM)
except Exception as ex:
    logger.warning('Could not get embedding dim from embedding API (%s). Falling back to 384', ex)
    DIM = 384

# Check if index exists and delete it if it does
if INDEX_NAME in pc.list_indexes().names():
    logger.info('Deleting existing Pinecone index %s', INDEX_NAME)
    pc.delete_index(INDEX_NAME)

logger.info('Creating Pinecone index %s with dim=%s', INDEX_NAME, DIM)
pc.create_index(
    name=INDEX_NAME,
    dimension=DIM,
    metric='euclidean',
    spec=ServerlessSpec(cloud='aws', region='us-east-1')
)

# Wait for the index to be ready
# REMOVED: from pinecone import Index as PineconeIndex
# NEW: Use the 'pc' instance to get the index object
index = pc.Index(INDEX_NAME)
while not index.describe_index_stats().namespaces:
    time.sleep(1)

# Create langchain_pinecone wrapper
vectorstore = PineconeVectorStore(
    index=index,
    embedding=embedding_function,
    pinecone_api_key=PINECONE_API_KEY
)

retriever = vectorstore.as_retriever(search_type='mmr', search_kwargs={'k': 4})

INFO:__main__:Initializing Pinecone client...
INFO:__main__:Using embedding dimension from function: 768
INFO:__main__:Deleting existing Pinecone index sih2
INFO:__main__:Creating Pinecone index sih2 with dim=768


In [None]:
class AgentState(TypedDict):
    messages: List[BaseMessage]
    documents: List[Document]
    on_topic: str
    rephrased_question: str
    proceed_to_generate: bool
    rephrase_count: int
    question: HumanMessage
    needs_gis: bool
    gis_data: dict
    map_html: str

In [None]:
class GradeQuestion(BaseModel):
    score: str = Field(description='Is user asking INGRES data? Yes/No')

class GISRequest(BaseModel):
    needs_gis: str = Field(description='Need GIS map? Yes/No')
    location_data: str = Field(description='Extracted location info')

class GradeDocument(BaseModel):
    score: str = Field(description='Relevant? Yes/No')

In [27]:
def question_rewriter(state: AgentState):
    state.update({"documents":[], "on_topic":"", "rephrased_question":"", "proceed_to_generate":False,
                  "rephrase_count":0, "needs_gis":False, "gis_data":{}, "map_html":""})
    state.setdefault("messages", [])
    if state["question"] not in state["messages"]:
        state["messages"].append(state["question"])
    if len(state["messages"]) > 1:
        conv = state["messages"][:-1]
        q = state["question"].content
        msgs = [SystemMessage(content="Rephrase for INGRES retrieval.")] + conv + [HumanMessage(content=q)]
        try:
            new_q = ChatOpenAI(model='gpt-4o-mini').invoke(ChatPromptTemplate.from_messages(msgs).format()).content.strip()
        except Exception:
            new_q = q
        state["rephrased_question"] = new_q
    else:
        state["rephrased_question"] = state["question"].content
    return state

def question_classifier(state: AgentState):
    msgs = [SystemMessage(content="Is this about INGRES groundwater data? Yes/No"), HumanMessage(content=state["rephrased_question"])]
    try:
        result = ChatOpenAI(model='gpt-4o-mini').with_structured_output(GradeQuestion).invoke(ChatPromptTemplate.from_messages(msgs).format())
        state["on_topic"] = result.score.strip()
    except Exception:
        q = state.get('rephrased_question','').lower()
        state["on_topic"] = 'yes' if any(k in q for k in ['groundwater', 'ingres', 'ground water', 'water table', 'gw']) else 'no'
    return state

def on_topic_router(state: AgentState):
    return 'retrieve' if state['on_topic'].lower() == 'yes' else 'off_topic_response'

def retrieve(state: AgentState):
    try:
        hits = retriever.get_relevant_documents(state['rephrased_question'])
        state['documents'] = hits
    except Exception as e:
        logger.exception('Retriever failed: %s', e)
        state['documents'] = []
    return state

def retrieval_grader(state: AgentState):
    relevant = []
    for doc in state.get('documents',[]):
        try:
            msgs = [SystemMessage(content='Relevant to INGRES query?'), HumanMessage(content=f"{state['rephrased_question']}\n\n{doc.page_content}")]
            r = ChatOpenAI(model='gpt-4o-mini').with_structured_output(GradeDocument).invoke(ChatPromptTemplate.from_messages(msgs).format())
            if r.score.strip().lower() == 'yes':
                relevant.append(doc)
        except Exception:
            if any(tok in doc.page_content.lower() for tok in ['ground','ph','tds','district','year']):
                relevant.append(doc)
    state['documents'] = relevant
    state['proceed_to_generate'] = bool(relevant)
    return state

def proceed_router(state: AgentState):
    if state.get('proceed_to_generate'):
        return 'generate_answer'
    return 'cannot_answer' if state.get('rephrase_count',0) >= 2 else 'refine_question'

def refine_question(state: AgentState):
    if state.get('rephrase_count',0) >= 2:
        return state
    msgs = [SystemMessage(content='Refine INGRES query slightly'), HumanMessage(content=state['rephrased_question'])]
    try:
        new_q = ChatOpenAI(model='gpt-4o-mini').invoke(ChatPromptTemplate.from_messages(msgs).format()).content.strip()
    except Exception:
        new_q = state['rephrased_question'] + ' (please be more specific)'
    state['rephrased_question'] = new_q
    state['rephrase_count'] = state.get('rephrase_count',0) + 1
    return state

def generate_answer(state: AgentState):
    context_text = '\n\n---\n\n'.join([d.page_content for d in state.get('documents',[])])
    history = '\n'.join([m.content for m in state.get('messages',[])])
    try:
        res = ChatOpenAI(model='gpt-4o-mini').invoke({'history': history, 'context': context_text, 'question': state['rephrased_question']})
        content = res.content.strip()
    except Exception as e:
        logger.exception('RAG chain failed: %s', e)
        content = 'I found the following relevant excerpts:\n' + '\n---\n'.join([d.page_content[:400] for d in state.get('documents',[])])
    state.setdefault('messages',[]).append(AIMessage(content=content))
    return state

def gis_classifier(state: AgentState):
    last = state['messages'][-1].content if state.get('messages') else ''
    msgs = [SystemMessage(content='Need GIS map? Yes/No'), HumanMessage(content=f"{state['rephrased_question']}\n\n{last}")]
    try:
        r = ChatOpenAI(model='gpt-4o-mini').with_structured_output(GISRequest).invoke(ChatPromptTemplate.from_messages(msgs).format())
        state['needs_gis'] = r.needs_gis.strip().lower() == 'yes'
        state['gis_data'] = {'location': r.location_data.strip()} if state['needs_gis'] else {}
    except Exception:
        q = (state.get('rephrased_question','') + ' ' + last).lower()
        state['needs_gis'] = any(w in q for w in ['map','location','district','lat','lon','show me','plot'])
        state['gis_data'] = {'location': state.get('rephrased_question','')} if state['needs_gis'] else {}
    return state

def generate_gis_map(state: AgentState):
    if not state.get('needs_gis'):
        return state
    pts = [
        {"lat":15.3, "lon":75.7, "val":200},
        {"lat":19.0, "lon":77.0, "val":250}
    ]
    m = folium.Map(location=[22,79], zoom_start=5)
    for p in pts:
        folium.CircleMarker([p['lat'], p['lon']], radius=5 + p['val']*0.01, popup=str(p['val'])).add_to(m)
    state['map_html'] = m._repr_html_()
    last = state['messages'][-1]
    state['messages'][-1] = AIMessage(content=last.content + '\n\n[GIS Map Attached]')
    return state

def cannot_answer(state: AgentState):
    state.setdefault('messages',[]).append(AIMessage(content="I'm sorry, I can't find that."))
    return state

def off_topic_response(state: AgentState):
    state.setdefault('messages',[]).append(AIMessage(content="I'm sorry, I cannot answer this."))
    return state

# ---------------------------
# Build LangGraph workflow
# ---------------------------
cp = MemorySaver()
wf = StateGraph(AgentState)

nodes = [
    'question_rewriter','question_classifier','off_topic_response',
    'retrieve','retrieval_grader','generate_answer',
    'refine_question','cannot_answer',
    'gis_classifier','generate_gis_map'
]
for n in nodes:
    wf.add_node(n, globals()[n])

wf.set_entry_point('question_rewriter')
wf.add_edge('question_rewriter','question_classifier')
wf.add_conditional_edges('question_classifier', on_topic_router, {'retrieve':'retrieve','off_topic_response':'off_topic_response'})
wf.add_edge('retrieve','retrieval_grader')
wf.add_conditional_edges('retrieval_grader', proceed_router, {'generate_answer':'generate_answer','refine_question':'refine_question','cannot_answer':'cannot_answer'})
wf.add_edge('refine_question','retrieve')
wf.add_edge('generate_answer','gis_classifier')
wf.add_conditional_edges('gis_classifier', lambda s: 'generate_gis_map' if s.get('needs_gis') else END, {'generate_gis_map':'generate_gis_map', END:END})
wf.add_edge('generate_gis_map', END)
wf.add_edge('cannot_answer', END)
wf.add_edge('off_topic_response', END)

graph = wf.compile(checkpointer=cp)


In [None]:
# To use this code, run your full script first to compile the graph.
# Then you can run the following lines in your environment.

# Create an interactive loop to chat with the graph
# For a single session, you can use a fixed thread_id.
# For a real application, you would generate a unique ID per user.
thread_id = "my-ingres-chatbot-session-1"

while True:
    user_query = input("You: ")
    if user_query.lower() in ["exit", "quit", "q"]:
        print("Goodbye!")
        break

    try:
        # Pass the 'configurable' dictionary with the thread_id
        final_state = graph.invoke(
            {"question": HumanMessage(content=user_query)},
            config={"configurable": {"thread_id": thread_id}}
        )
        
        last_message = final_state['messages'][-1]
        print(f"Chatbot: {last_message.content}")
        
        if final_state.get('map_html'):
            print("\n[GIS map generated]\n")
            with open("gis_map.html", "w") as f:
                f.write(final_state['map_html'])
            print("Map saved as gis_map.html. Open it in your browser to view.")

    except Exception as e:
        print(f"An error occurred: {e}")
        break

ERROR:__main__:Retriever failed: (400)
Reason: Bad Request
HTTP response headers: HTTPHeaderDict({'Date': 'Sat, 06 Sep 2025 13:51:35 GMT', 'Content-Type': 'application/json', 'Content-Length': '102', 'Connection': 'keep-alive', 'x-pinecone-request-latency-ms': '406', 'x-pinecone-request-id': '8036904116552971277', 'x-envoy-upstream-service-time': '30', 'server': 'envoy'})
HTTP response body: {"code":3,"message":"Vector dimension 768 does not match the dimension of the index 384","details":[]}
Traceback (most recent call last):
  File "C:\Users\prita\AppData\Local\Temp\ipykernel_20396\1266092739.py", line 35, in retrieve
    hits = retriever.get_relevant_documents(state['rephrased_question'])
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    return wrapped(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\prita\AppData\Roaming\Python\Python311\site-packages\langchain_core\retrievers.py", line 414, in get_relevant_documents
    return self.i

Chatbot: I'm sorry, I can't find that.
Goodbye!
