### Import Libraries


In [None]:
# --- Weaviate Imports ---
import os
import weaviate
from weaviate.classes.init import Auth
from weaviate.classes.query import MetadataQuery
from weaviate.collections import Collection
from weaviate.classes.query import Filter
from sentence_transformers import SentenceTransformer
from dotenv import load_dotenv

# --- LangChain Imports ---
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.retrievers import BaseRetriever
from langchain_core.documents import Document
from langchain_core.prompts import PromptTemplate
from langchain_classic.chains.retrieval_qa.base import RetrievalQA

from typing import List

# --- Document Classifier (Predictor) Import ---
from document_classification.predict import DocumentPredictor

### Setup Reciprocal Rank Fusion (RRF)

In [None]:
# Reciprocal Rank Fusion adapted for this pipeline: it merges ranked lists of
# Weaviate result objects.
def reciprocal_rank_fusion(results, k=60):
    """Fuse several ranked result lists into one ranking.

    Args:
        results: Iterable of ranked lists. Every element must expose a
            ``uuid`` attribute; elements with the same ``str(uuid)`` are
            treated as the same hit.
        k: Smoothing constant added to the zero-based rank (plus one) in the
            denominator of each contribution.

    Returns:
        A list of ``{"score": float, "object": obj}`` dicts ordered by
        descending fused score. ``object`` is the first occurrence seen for
        that uuid, and hits with equal scores keep first-seen order.
    """
    by_uuid = {}
    for ranking in results:
        for position, hit in enumerate(ranking):
            entry = by_uuid.setdefault(
                str(hit.uuid), {"score": 0.0, "object": hit}
            )
            entry["score"] += 1.0 / (k + position + 1)

    ranked = list(by_uuid.values())
    # list.sort is stable, so ties stay in insertion (first-seen) order.
    ranked.sort(key=lambda entry: entry["score"], reverse=True)
    return ranked

### Custom HybridRetriever (Vector + BM25)

In [None]:
# Custom retriever that performs hybrid search (vector + BM25).
class HybridRetriever(BaseRetriever):
    """Retrieve documents from Weaviate with a vector search and a BM25 search.

    Each query is run twice against ``history_collection`` -- once with
    ``near_vector`` (using an embedding produced by ``embed_model``) and once
    with ``bm25`` -- and the two rankings are merged with
    ``reciprocal_rank_fusion``. The ``k`` best fused hits are returned as
    LangChain ``Document`` objects.
    """

    # Weaviate collection that both searches query.
    history_collection: Collection
    # Model whose ``encode`` output is used as the query vector.
    embed_model: SentenceTransformer
    # Result limit for each search and number of documents returned to the LLM.
    k: int = 4

    def _get_relevant_documents(self, query: str, *, run_manager=None) -> List[Document]:
        """Return up to ``self.k`` documents for ``query``, ranked by RRF.

        Args:
            query: Question text to search for.
            run_manager: Accepted for interface compatibility; not used here.

        Returns:
            Documents whose ``page_content`` is the hit's ``context`` property
            and whose metadata holds ``period``, ``source_uuid`` and
            ``rrf_score``.
        """
        # 0. Classify the query into a period.
        # NOTE(review): ``predict`` is invoked on the class, not on an
        # instance -- confirm it is a static/class method.
        predicted_period = DocumentPredictor.predict(query)
        print(f"\n=== [DEBUG 0] CH·ª¶ ƒê·ªÄ D·ª∞ ƒêO√ÅN: {predicted_period} ===")
        # NOTE: the filter is built but NOT applied -- the ``filters`` arguments
        # below are commented out, so both searches run unfiltered even though
        # the final log line mentions the predicted period.
        where_filter = Filter.by_property("period").equal(predicted_period)

        # 1. Embed the query with the local model.
        query_vector = self.embed_model.encode(query).tolist()

        # 2. Vector (semantic) search.
        print("\n=== [DEBUG 1] K·∫æT QU·∫¢ TRUY XU·∫§T NG·ªÆ NGHƒ®A (VECTOR SEARCH) ===")
        vector_results = self.history_collection.query.near_vector(
            near_vector=query_vector,
            limit=self.k,  # at most k candidates from this search
            return_properties=["context", "period"],
            return_metadata=MetadataQuery(distance=True),
            # filters=where_filter  # disabled; the keyword argument is named 'filters'
        )
        for i, obj in enumerate(vector_results.objects):
            context = obj.properties.get("context", "")
            distance = obj.metadata.distance if obj.metadata else None
            # Format only real numbers: applying ``:.4f`` to the "N/A"
            # placeholder raises ValueError, and a distance of exactly 0.0 is
            # a valid value that must not be treated as missing.
            distance_text = f"{distance:.4f}" if distance is not None else "N/A"
            print(f"[{i+1}] (Distance: {distance_text}) [UUID: {obj.uuid}] - Context: {context[:100]}...")

        # 3. BM25 (keyword) search.
        print("\n=== [DEBUG 2] K·∫æT QU·∫¢ TRUY XU·∫§T T·ª™ KH√ìA (BM25 SEARCH) ===")
        bm25_results = self.history_collection.query.bm25(
            query=query,
            limit=self.k,  # at most k candidates from this search
            return_properties=["context", "period"],
            return_metadata=MetadataQuery(),
            # filters=where_filter  # disabled; the keyword argument is named 'filters'
        )
        for i, obj in enumerate(bm25_results.objects):
            context = obj.properties.get("context", "")
            print(f"[{i+1}] [UUID: {obj.uuid}] - Context: {context[:100]}...")

        # 4. Merge both rankings with Reciprocal Rank Fusion.
        fused_objects = reciprocal_rank_fusion([
            vector_results.objects,
            bm25_results.objects
        ], k=60)

        print("\n=== [DEBUG 3] K·∫æT QU·∫¢ H·ª¢P NH·∫§T RRF (TOP 8) ===")
        # Show the top 8 fused hits, matching the header above (the slice
        # previously stopped at 4 and hid the lower half of the fused list).
        for i, item in enumerate(fused_objects[:8]):
            obj = item["object"]
            context = obj.properties.get("context", "")
            score = item["score"]
            print(f"[{i+1}] (RRF Score: {score:.4f}) [UUID: {obj.uuid}] - Context: {context[:100]}...")

        # 5. Convert the best k fused Weaviate objects into LangChain Documents.
        documents = []
        for item in fused_objects[:self.k]:
            obj = item["object"]
            # Carry selected Weaviate properties over as Document metadata.
            metadata = {
                "period": obj.properties.get("period", "N/A"),
                "source_uuid": str(obj.uuid),
                "rrf_score": item["score"]
            }
            documents.append(
                Document(
                    page_content=obj.properties.get("context", ""),
                    metadata=metadata
                )
            )

        print(f"\n<<< TR·∫¢ V·ªÄ {len(documents)} DOCUMENT CHO LLM (ƒê√£ L·ªçc theo {predicted_period})>>>")
        return documents

### Set Up Environment Variables

In [None]:
# Load environment variables (python-dotenv) and read the settings used below.
load_dotenv()

# Connection settings; each value is None when its variable is not set.
WEAVIATE_URL = os.getenv("WEAVIATE_URL")
WEAVIATE_API_KEY = os.getenv("WEAVIATE_API_KEY")
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")

# Name of the Weaviate collection to query.
COLLECTION_NAME = "History"

# Embedding model name/path passed to SentenceTransformer.
LOCAL_MODEL_PATH = os.getenv("EMBEDDING_MODEL_NAME")

### Run Chatbot

In [None]:
# End-to-end RAG run: load the embedding model, connect to Weaviate, build the
# Gemini-backed RetrievalQA chain on top of HybridRetriever, ask one question
# and print the answer with its source documents.
# `client` is pre-set to None so the `finally` clause can tell whether a
# connection was ever created.
client = None
try:
    # 1. Initialize the local embedding model.
    embed_model = SentenceTransformer(LOCAL_MODEL_PATH)

    # 2. Establish the Weaviate Cloud connection.
    client = weaviate.connect_to_weaviate_cloud(
        cluster_url=WEAVIATE_URL,
        auth_credentials=Auth.api_key(WEAVIATE_API_KEY),
    )
    
    # Fail fast (handled by the except clause below) if the cluster is not ready.
    if not client.is_ready():
        raise ConnectionError("Weaviate client is not ready.")
    
    history_collection = client.collections.get(COLLECTION_NAME)

    # 3. Initialize the LLM (Gemini).
    llm = ChatGoogleGenerativeAI(
        model="gemini-2.5-flash", 
        google_api_key=GEMINI_API_KEY,
        temperature=0.1,
    )

    # 4. Initialize the custom hybrid retriever.
    retriever = HybridRetriever(
        history_collection=history_collection,
        embed_model=embed_model,
        k=4 # the LLM receives 4 context documents from the hybrid search
    )

    # 5. Build the prompt (instructions for the LLM); it exposes the
    #    {context} and {question} placeholders.
    template = """
    B·∫°n l√† m·ªôt tr·ª£ l√Ω th√¥ng minh v·ªÅ L·ªãch s·ª≠ Vi·ªát Nam. 
    H√£y s·ª≠ d·ª•ng c√°c ƒëo·∫°n ng·ªØ c·∫£nh (Context) ƒë∆∞·ª£c cung c·∫•p d∆∞·ªõi ƒë√¢y ƒë·ªÉ tr·∫£ l·ªùi c√¢u h·ªèi m·ªôt c√°ch chi ti·∫øt v√† trung th·ª±c. 
    
    Ng·ªØ c·∫£nh: {context}

    C√¢u h·ªèi: {question}

    C√¢u tr·∫£ l·ªùi chi ti·∫øt:
    """
    RAG_PROMPT_CUSTOM = PromptTemplate.from_template(template)

    # Instruction currently NOT part of the template above (kept for reference):
    # If the information in the context is insufficient or irrelevant, answer
    # "I cannot answer this question based on the information provided."

    # 6. Create the RetrievalQA chain; "stuff" is the chain type passed to
    #    from_chain_type, and source documents are returned with the answer.
    qa_chain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff", 
        retriever=retriever,
        return_source_documents=True,
        chain_type_kwargs={"prompt": RAG_PROMPT_CUSTOM}
    )

    # 7. Run the RAG query (the variable below holds the question text).
    H·ªéI = "Chi·∫øn d·ªãch ƒëi·ªán bi√™n ph·ªß x·∫£y ra v√†o nƒÉm n√†o?"
    
    print(f"\n‚ùì C√¢u h·ªèi: {H·ªéI}")
    print("---------------------------------------")
    
    # Invoke the chain; the question is passed under the "query" key.
    result = qa_chain.invoke({"query": H·ªéI}) 

    # 8. Print the answer followed by the source documents chosen via RRF.
    print("\n---------------------------------------")
    print("ü§ñ C√¢u tr·∫£ l·ªùi t·ª´ Gemini:")
    print(result["result"])
    print("\nüìö Ngu·ªìn Context ƒë∆∞·ª£c s·ª≠ d·ª•ng (ƒê∆∞·ª£c ch·ªçn b·ªüi RRF):")
    for doc in result["source_documents"]:
        print(f"- [Ngu·ªìn: {doc.metadata.get('period', 'N/A')}, RRF Score: {doc.metadata.get('rrf_score', 0.0):.4f}] {doc.page_content[:100]}...")

# Top-level boundary: any failure above is reported as a one-line message
# (no traceback is printed).
except Exception as e:
    print(f"\n‚ùå ƒê√£ x·∫£y ra l·ªói trong qu√° tr√¨nh RAG: {e}")

# Always release the Weaviate connection if one was opened.
finally:
    if client and client.is_connected():
        client.close()