# Setup

In [None]:
import langchain
import langchain_huggingface
import langchain_google_genai
import langchain_qdrant
import langchain_community
import langgraph
import sqlite3

In [None]:
%%capture
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import re
import torch
from tqdm import tqdm

from sklearn.utils import resample
from sklearn.metrics import accuracy_score
from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MultiLabelBinarizer

from fastapi import FastAPI
import nest_asyncio
import uvicorn
from pydantic import BaseModel

In [None]:
# Path to the sampled dataset, relative to the notebook's working directory.
data_path = 'dataset/sampled_data.csv'

df = pd.read_csv(data_path)

# Turn the comma-separated `type_str` column into a list of labels per row;
# any non-string value becomes an empty list.
df['type'] = df['type_str'].apply(lambda x: x.split(',') if isinstance(x, str) else [])

# Preview the first row to sanity-check the parsed columns.
df.head(1)

# Components

In [None]:
from langchain_google_genai import ChatGoogleGenerativeAI

# Chat model shared by the graph nodes below: it is bound to the retrieval
# tool in `query_or_respond` and called directly in `generate`.
llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash-001")

# Embedding Model

In [None]:
from langchain_huggingface import HuggingFaceEmbeddings

# Dense embedding model handed to the vector store below.
# `trust_remote_code` is forwarded through `model_kwargs` — presumably because
# this model repo ships custom loading code.
# NOTE(review): confirm the repo is trusted before keeping this enabled.
model_kwargs = {'trust_remote_code': True}
embeddings = HuggingFaceEmbeddings(model_name="KanisornPutta/TrentIsNotLeavingBERT",model_kwargs=model_kwargs)

In [None]:
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import CrossEncoderReranker
from langchain_community.cross_encoders import HuggingFaceCrossEncoder

# Cross-encoder passed to `CrossEncoderReranker` in the Retrievers section.
model = HuggingFaceCrossEncoder(model_name="Pongsasit/mod-th-cross-encoder-minilm")

## Vector Database
Embed your documents in a vector database that supports hybrid search. Also set the retrieval mode to hybrid search.

We will use `QdrantVectorStore` ([learn more here](https://python.langchain.com/api_reference/qdrant/qdrant/langchain_qdrant.qdrant.QdrantVectorStore.html#langchain_qdrant.qdrant.QdrantVectorStore)). You can use any vector DB that supports hybrid search.

In [None]:
from tqdm import tqdm

In [None]:
from langchain_qdrant import FastEmbedSparse, QdrantVectorStore, RetrievalMode
from qdrant_client import QdrantClient, models
from qdrant_client.http.models import Distance, SparseVectorParams, VectorParams

# Sparse (BM25) embeddings; together with the dense `embeddings` they feed the
# hybrid retrieval mode configured below.
sparse_embeddings = FastEmbedSparse(model_name="Qdrant/bm25")

import os

# Connect to the Qdrant server whose endpoint and API key come from the
# QDRANT_ENDPOINT / QDRANT_API_KEY environment variables.
# In-memory alternative, kept for reference:
# client = QdrantClient(":memory:")
client = QdrantClient(
    url=os.getenv("QDRANT_ENDPOINT"),
    api_key=os.getenv("QDRANT_API_KEY", None),
)

collection_name = "test_dsde"

# One-time setup, kept for reference: create a collection with both dense and
# sparse vectors.
# NOTE(review): this snippet names the collection "test_de", while the store
# below reads `collection_name` ("test_dsde") — align them before re-running.
# client.create_collection(
#     collection_name="test_de",
#     vectors_config={"dense": VectorParams(size=768, distance=Distance.COSINE)},
#     sparse_vectors_config={
#         "sparse": SparseVectorParams(index=models.SparseIndexParams(on_disk=False))
#     },
# )



# Vector store over the collection above, in HYBRID retrieval mode, using the
# "dense" and "sparse" named vectors.
qdrant = QdrantVectorStore(
    client=client,
    collection_name=collection_name,
    embedding=embeddings,
    sparse_embedding=sparse_embeddings,
    retrieval_mode=RetrievalMode.HYBRID,
    vector_name="dense",
    sparse_vector_name="sparse",
)


# Smoke test: search for the Thai word for "flood" and display the hits.
query = "น้ำท่วม"
found_docs = qdrant.similarity_search(query)
found_docs

# Retrievers


In [None]:
# Base retriever over the Qdrant store, asking for 10 results per query.
retriever = qdrant.as_retriever(search_kwargs={"k": 10})

In [None]:
# Wrap the base retriever so its results pass through the cross-encoder reranker.
# NOTE(review): top_n (10) equals the base retriever's k (10), so the reranker
# presumably only reorders candidates and never drops any — confirm intended.
reranker = CrossEncoderReranker(model=model, top_n=10)
reranked_retriever = ContextualCompressionRetriever(
    base_compressor=reranker , base_retriever=retriever
)

Take a subset of the dataset to evaluate the MRR of the retrievers.

In [None]:
# Sample question (Thai): "What kind of problem does Chatuchak district have the most?"
test_query = "เขตจตุจักรมีปัญหาเกี่ยวกับอะไรมากที่สุด?"

In [None]:
# Run the sample question through the base retriever and list the hits.
# `invoke` is the Runnable entry point; `get_relevant_documents` is deprecated
# in current LangChain releases.
test_docs = retriever.invoke(test_query)
print(f'query : {test_query}')
print('-'*30)
for doc in test_docs[:10]:
    print(f'- {doc.page_content}')

In [None]:
# Same sample question, this time through the reranking retriever.
# `invoke` is the Runnable entry point; `get_relevant_documents` is deprecated
# in current LangChain releases.
test_docs_reranked = reranked_retriever.invoke(test_query)
print(f'query : {test_query}')
print('-'*30)
for doc in test_docs_reranked[:10]:
    print(f'- {doc.page_content}')

In [None]:
# List the ticket id stored in the metadata of each reranked hit (first 10).
for ranked_doc in test_docs_reranked[:10]:
    ticket = ranked_doc.metadata.get("ticket_id")
    print(f'- {ticket}')

# Retrieval Evaluation
Coming soon



In [None]:
from tqdm import tqdm

# Agentic RAG


In [None]:
from langgraph.graph import MessagesState, StateGraph

# Graph builder whose state schema is MessagesState; nodes and edges are
# added to it in the cells below.
graph_builder = StateGraph(MessagesState)

In [None]:
from langchain_core.tools import tool

@tool(response_format="content")
def retrieve(query: str):
    """Retrieve information related to a query from a vector database of Traffy Fondue Dataset."""
    # NOTE: keep the docstring above stable — LangChain's @tool uses it as the
    # tool description shown to the model, so it is effectively runtime text.

    # `invoke` is the Runnable entry point; `get_relevant_documents` is
    # deprecated in current LangChain releases.
    retrieved_docs = reranked_retriever.invoke(query)

    # One text block per document, blocks separated by a blank line. The Thai
    # labels are, in order: problem type, location, details.
    return "\n\n".join(
        (
            f"ticket_id: {doc.metadata.get('ticket_id')}\n"
            f"ประเภท: {doc.metadata.get('problem_type')}\n"
            f"สถานที่: {doc.metadata.get('address')}\n"
            f"รายละเอียด: {doc.page_content}"
        )
        for doc in retrieved_docs
    )


In [None]:
from langchain_core.tools import tool
from langchain_core.messages import SystemMessage
from langgraph.prebuilt import ToolNode
from langchain_core.runnables import RunnableConfig

# Step 1: Generate an AIMessage that may include a tool-call to be sent.
def query_or_respond(state: MessagesState):
    """Ask the tool-enabled model for its next message.

    The reply is either a direct answer or an AI message carrying a call to
    the `retrieve` tool.
    """
    history = state["messages"]
    ai_message = llm.bind_tools([retrieve]).invoke(history)
    # Returned as a one-element list: MessagesState appends to the existing
    # messages instead of overwriting them.
    return {"messages": [ai_message]}


# Step 2: Execute the retrieval.
# The ToolNode is roughly analogous to the following sketch (illustrative
# only; it is never executed):

# tools_by_name = {tool.name: tool for tool in tools}
# def tool_node(state: dict):
#     result = []
#     for tool_call in state["messages"][-1].tool_calls:
#         tool = tools_by_name[tool_call["name"]]
#         observation = tool.invoke(tool_call["args"])
#         result.append(ToolMessage(content=observation, tool_call_id=tool_call["id"]))
#     return {"messages": result}

# Graph node that runs the `retrieve` tool; it is wired into the graph under
# the name "tools" in the cell that adds the edges.
tools = ToolNode([retrieve])


# Step 3: Generate a response using the retrieved content.
def generate(state: MessagesState):
    """Generate answer based on retrieved problem reports.

    Args:
        state: Graph state; ``state["messages"]`` is the conversation so far.
            The trailing run of tool messages is used as retrieval context.

    Returns:
        ``{"messages": [response]}`` where ``response`` is the model's reply.
    """
    # Walk backwards and stop at the first non-tool message, so only the tool
    # results at the very end of the history are used as context.
    recent_tool_messages = []
    for message in reversed(state["messages"]):
        if message.type != "tool":
            break
        recent_tool_messages.append(message)

    # Restore chronological order and join the tool outputs (already strings
    # produced by the tool) with blank lines.
    docs_content = "\n\n".join(
        tool_msg.content for tool_msg in reversed(recent_tool_messages)
    )

    # System prompt (Thai). In English, roughly: you are an assistant that
    # answers from Traffy Fondue problem reports; each entry has a problem
    # type, a location and details; answer from the data below; if nothing
    # matches, say so and mention nearby or same-type data; pick only relevant
    # entries and list their ticket_ids after the explanation in the format
    # "ticket_id : _id1, _id2, ..."; summarize and highlight interesting data.
    # Adjacent literals are concatenated as-is, so every sentence ends with a
    # space to keep it from running into the next one.
    system_message_content = (
        "คุณเป็นผู้ช่วยที่เชี่ยวชาญในการตอบคำถามจากข้อมูลปัญหาที่ถูกรายงานผ่านระบบแจ้งปัญหา Traffy Fondue "
        "ข้อมูลแต่ละรายการจะประกอบด้วยประเภทของปัญหา, สถานที่ และรายละเอียดของปัญหา "
        "กรุณาตอบคำถามจากข้อมูลด้านล่าง ถ้าคุณไม่พบคำตอบที่ตรง ให้ตอบว่า 'ไม่พบข้อมูลที่เกี่ยวข้อง' และกล่าวถึงข้อมูลที่ใกล้เคียง เช่น สถานที่ใกล้เคียง หรือ ข้อมูลประเภทเดียวกัน "
        "เลือกเพียงข้อมูลที่มีความเกี่ยวข้องกับ คำถาม และ ตอบ ticket_id ที่เกี่ยวข้องมา หลังจบคำอธิบาย ใน format ticket_id : _id1, _id2, ... "
        "สรุปคำตอบออกมา และกล่าวถึงข้อมูลที่น่าสนใจ\n\n"
        f"{docs_content}"
    )

    # Keep human/system turns and plain AI answers; AI messages that only
    # carry tool calls, and the tool messages themselves, are left out because
    # their content is already in the system prompt.
    conversation_messages = [
        message
        for message in state["messages"]
        if message.type in ("human", "system")
        or (message.type == "ai" and not message.tool_calls)
    ]

    prompt = [SystemMessage(system_message_content)] + conversation_messages

    # Run model
    response = llm.invoke(prompt)
    return {"messages": [response]}


In [None]:
from langgraph.graph import END
from langgraph.prebuilt import ToolNode, tools_condition

# Register the three nodes; the edges below refer to them by the names
# "query_or_respond", "tools" and "generate".
graph_builder.add_node(query_or_respond)
graph_builder.add_node(tools)
graph_builder.add_node(generate)

graph_builder.set_entry_point("query_or_respond")
# After the model replies, `tools_condition` picks the next step: the "tools"
# node or END, per the mapping below.
graph_builder.add_conditional_edges(
    "query_or_respond",
    tools_condition,
    {END: END, "tools": "tools"},
)
# Retrieval results always flow into answer generation, which ends the run.
graph_builder.add_edge("tools", "generate")
graph_builder.add_edge("generate", END)

# Compiled without a checkpointer; the next cell recompiles with memory and
# rebinds `graph`.
graph = graph_builder.compile()

In [None]:
from langgraph.checkpoint.memory import MemorySaver

# Recompile the graph with an in-memory checkpointer; this rebinds `graph`.
memory = MemorySaver()
graph = graph_builder.compile(checkpointer=memory)

# Specify an ID for the thread
# NOTE(review): every call below (and the /query/content endpoint) passes this
# same config, so they presumably all share one conversation history — confirm
# that is intended for the API.
config = {"configurable": {"thread_id": "abc123"}}

## Test Query

In [None]:
# General, non-retrieval question to see how the graph responds.
input_message = "What do u know"

# Stream the graph state in "values" mode and pretty-print the newest message
# of each streamed state.
for step in graph.stream(
    {"messages": [{"role": "user", "content": input_message}]},
    stream_mode="values",
    config=config,
):
    step["messages"][-1].pretty_print()

In [None]:
# Thai: "What kind of problem does Chatuchak district have the most?"
input_message = "เขตจตุจักรมีปัญหาเกี่ยวกับอะไรมากที่สุด?"

# Stream the graph state in "values" mode and pretty-print the newest message
# of each streamed state.
for step in graph.stream(
    {"messages": [{"role": "user", "content": input_message}]},
    stream_mode="values",
    config=config,
):
    step["messages"][-1].pretty_print()

In [None]:
# Thai: "Garbage around Pathum Wan"
input_message = "ขยะแถวปทุมวัน"

# Stream the graph state in "values" mode and pretty-print the newest message
# of each streamed state.
for step in graph.stream(
    {"messages": [
        {
          "role": "user",
          "content": input_message,
        }
    ]},
    stream_mode="values",
    config=config,
):
    step["messages"][-1].pretty_print()

In [None]:
# Thai: "If a damaged-footpath problem occurs in Bang Na district, will the
# model be able to find related information?"
input_message = "ถ้าเกิดปัญหาทางเท้าเสียหายในเขตบางนา โมเดลจะสามารถหาข้อมูลที่เกี่ยวข้องได้หรือไม่?"

# Stream the graph state in "values" mode and pretty-print the newest message
# of each streamed state.
for step in graph.stream(
    {"messages": [{"role": "user", "content": input_message}]},
    stream_mode="values",
    config=config,
):
    step["messages"][-1].pretty_print()


In [None]:
# `step` still holds the last state streamed by the previous cell; print the
# raw content of its final message.
content = step["messages"][-1].content
print(content)

In [None]:
# Thai: "Which areas often have problems with insufficient lighting?"
input_message = "พื้นที่ไหนมักเกิดปัญหาเกี่ยวกับแสงไม่มากพอ?"

# Stream the graph state in "values" mode and pretty-print the newest message
# of each streamed state.
for step in graph.stream(
    {"messages": [{"role": "user", "content": input_message}]},
    stream_mode="values",
    config=config,
):
    step["messages"][-1].pretty_print()


In [None]:
# `step` still holds the last state streamed by the previous cell; print the
# raw content of its final message.
content = step["messages"][-1].content
print(content)

# FastAPI

In [None]:
nest_asyncio.apply()  # Patch asyncio for Jupyter

from fastapi.middleware.cors import CORSMiddleware

# Application object that the endpoint cells below register their routes on.
app = FastAPI()

# Fully permissive CORS: any origin, method and header is accepted.
# NOTE(review): the FastAPI CORS docs advise against combining wildcard values
# with allow_credentials=True — list explicit origins or drop credentials
# before exposing this beyond local development.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],      # allow all origins
    allow_credentials=True,
    allow_methods=["*"],      # allow all methods (GET, POST, etc.)
    allow_headers=["*"],      # allow all headers
)

# Request body shared by the /query/* endpoints defined below.
class QueryRequest(BaseModel):
    # The user's question; endpoints pass it to the graph or the retriever.
    query: str

    


In [None]:
# Get LLM response to the query.

@app.post("/query/content")
def query_content(req: QueryRequest):
    """Run the RAG graph on the user's query and return the final answer text."""
    # `invoke` runs the graph to completion and returns the final state, which
    # replaces draining `graph.stream(...)` just to read the last `step`.
    # NOTE(review): `config` carries a single fixed thread_id, so requests
    # presumably share one checkpointed conversation — confirm intended.
    final_state = graph.invoke(
        {"messages": [{"role": "user", "content": req.query}]},
        config=config,
    )
    # The last message of the final state is the model's answer.
    return {"content": final_state["messages"][-1].content}

In [None]:
# Get ticket_id

@app.post("/query/ticket_id")
def query_ticket_id(req: QueryRequest):
    """Return the ticket ids of the documents retrieved for the user's query."""
    # Only the reranking retriever is used here; the LLM graph is not involved.
    # `invoke` is the Runnable entry point; `get_relevant_documents` is
    # deprecated in current LangChain releases.
    retrieved_docs = reranked_retriever.invoke(req.query)
    # A document without a "ticket_id" metadata entry contributes None.
    return {"data": [doc.metadata.get('ticket_id') for doc in retrieved_docs]}

In [None]:
# test_query = "เขตจตุจักรมีปัญหาเกี่ยวกับอะไรมากที่สุด?"
# retrieved_docs = reranked_retriever.get_relevant_documents(test_query)
# ticket_ids = [doc.metadata.get('ticket_id') for doc in retrieved_docs]
# print(ticket_ids)

# data = []
# keys_to_extract = ['ticket_id', 'type', 'organization', 'timestamp', 'state', 'star', 'photo']
# for ticket_id in ticket_ids:
#     record = df[df['ticket_id'] == ticket_id].to_dict(orient='records')[0]
#     filtered_record = {key: record[key] for key in keys_to_extract if key in record}
#     data.append(filtered_record)
    
# for record in data:
#     print(record)

In [None]:
# data[0]

In [None]:
# for key, value in data[0].items():
#     print(f"{key}: {type(value)}")

In [None]:
# @app.post("/query/data")
# def query_data(req: QueryRequest):
#     """Query the graph with a user message."""
#     retrieved_docs = reranked_retriever.get_relevant_documents(req.query)
#     ticket_ids = [doc.metadata.get('ticket_id') for doc in retrieved_docs]
#     data = []

#     # keys_to_extract = ['ticket_id', 'type', 'organization', 'timestamp', 'star' ,'state', 'photo']
#     for ticket_id in ticket_ids:
#         record = df[df['ticket_id'] == ticket_id].to_dict(orient='records')[0]
#         # record = {key: raw_record[key] for key in keys_to_extract if key in raw_record}
#         for key, value in record.items():
#             if key == "type":  # Skip the 'type' key
#                 continue
            
#             # if key in ['star', 'photo']:  # Specify columns of interest
#                 # Check if the value is scalar (not an array or series)
#             if isinstance(value, (int, float)) and (pd.isna(value) or np.isinf(value)):
#                     record[key] = None  # Replace NaN or Inf with None
        
#         data.append(record)
    
#     return {"data": data}

In [None]:
# SQLite database file read by `get_tickets_by_ids` (path is relative to the
# working directory).
DB_PATH = "./data/traffy.db"

In [None]:
def get_tickets_by_ids(ticket_ids, db_path=None):
    """Fetch ticket records from the SQLite ``traffy`` table by ticket id.

    Args:
        ticket_ids: Iterable of ticket ids to look up. ``None`` or an empty
            iterable returns ``[]`` without opening the database.
        db_path: Optional path to the SQLite file; defaults to ``DB_PATH``.

    Returns:
        A list of dicts, one per matching row, ordered like ``ticket_ids``
        (ids are compared as strings; rows that cannot be matched keep their
        database order at the end). ``type`` and ``organization`` are split on
        commas into lists; a non-string value (e.g. NULL) becomes ``[]``.
    """
    # Materialise once: the ids are counted for the placeholders, bound as
    # parameters and used for ordering, so a one-shot iterable would not do.
    ids = list(ticket_ids) if ticket_ids else []
    if not ids:
        return []

    # Only "?" placeholders are interpolated into the SQL text; the ids
    # themselves are bound as parameters.
    placeholders = ','.join('?' for _ in ids)
    query = f"SELECT ticket_id, type, organization, star, state, photo, comment, timestamp FROM traffy WHERE ticket_id IN ({placeholders})"

    conn = sqlite3.connect(DB_PATH if db_path is None else db_path)
    try:
        conn.row_factory = sqlite3.Row
        rows = conn.execute(query, ids).fetchall()
    finally:
        # Close even when the query raises, so the connection never leaks.
        conn.close()

    def _split_csv(value):
        # NULL columns arrive as None; treat them like "no entries".
        return value.split(",") if isinstance(value, str) else []

    # Position of each requested id (first occurrence wins) so the result can
    # follow the caller's order, which SQL `IN` does not preserve.
    rank = {}
    for position, ticket_id in enumerate(ids):
        rank.setdefault(str(ticket_id), position)

    res = []
    for row in rows:
        record = dict(row)
        record["type"] = _split_csv(record["type"])
        record["organization"] = _split_csv(record["organization"])
        res.append(record)

    # Stable sort: unmatched rows fall to the end in their original order.
    res.sort(key=lambda record: rank.get(str(record["ticket_id"]), len(rank)))
    return res

In [None]:
@app.post("/query/data")
def query_data(req: QueryRequest):
    """Query the data related to user message."""
    # `invoke` is the Runnable entry point; `get_relevant_documents` is
    # deprecated in current LangChain releases.
    retrieved_docs = reranked_retriever.invoke(req.query)
    # Map the retrieved documents to their ticket ids, then load the matching
    # records from the SQLite database.
    ticket_ids = [doc.metadata.get('ticket_id') for doc in retrieved_docs]
    data = get_tickets_by_ids(ticket_ids)
    return {"data": data}

# Run App

In [None]:
# Serve the FastAPI app on localhost port 8000.
uvicorn.run(app, host="127.0.0.1", port=8000)