# In [None]:
# Updated code for enhanced RAG system with semantic chunking, metadata tagging, and hybrid retrieval

import os
import re
import json
import hashlib
import fitz
import uuid
import pytesseract
from PIL import Image
from dotenv import load_dotenv
from typing import List
from langchain_core.documents import Document
from langchain.text_splitter import SemanticChunker
from langchain_community.vectorstores import FAISS
from langchain.embeddings import AzureOpenAIEmbeddings
from langchain.chat_models import AzureChatOpenAI
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.chains import create_retrieval_chain
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_community.chat_message_histories import ChatMessageHistory

# Load env
load_dotenv()

# Tesseract path (Windows).
# The literal is a raw string, so every separator is written as a single
# backslash; doubling them would embed "\\" sequences in the path itself.
pytesseract.pytesseract.tesseract_cmd = r"C:\Program Files\Tesseract-OCR\tesseract.exe"

# Azure config, read from environment variables (each is None when unset)
AZURE_OPENAI_API_KEY = os.getenv("AZURE_OPENAI_API_KEY")
AZURE_OPENAI_ENDPOINT = os.getenv("AZURE_OPENAI_ENDPOINT")
AZURE_OPENAI_API_VERSION = os.getenv("AZURE_OPENAI_API_VERSION")
EMBEDDING_DEPLOYMENT = os.getenv("AZURE_OPENAI_EMBED_DEPLOYMENT")
LLM_DEPLOYMENT = os.getenv("AZURE_OPENAI_LLM_DEPLOYMENT")

# Paths
PDF_DIR = "./source_docs"                      # directory scanned for *.pdf files
CHAT_HISTORY_DIR = "chat_history"              # one <session_id>.json file per session
FAISS_INDEX_PATH = "./store"                   # folder passed to FAISS save_local/load_local
HASH_STORE_PATH = "./hashes/index_hashes.txt"  # one file hash per line
TEXT_CACHE_DIR = "./text_cache"                # cached extracted text, one file per PDF

# Embeddings client shared by indexing and querying
embeddings = AzureOpenAIEmbeddings(
    azure_deployment=EMBEDDING_DEPLOYMENT,
    openai_api_key=AZURE_OPENAI_API_KEY,
    openai_api_version=AZURE_OPENAI_API_VERSION,
    azure_endpoint=AZURE_OPENAI_ENDPOINT
)

def file_hash(filepath):
    """Return the SHA-256 hex digest of the file at *filepath*.

    The file is opened in binary mode and consumed in 8192-byte reads, so the
    whole file is never held in memory at once.
    """
    digest = hashlib.sha256()
    with open(filepath, 'rb') as stream:
        # iter() with a sentinel stops at the first empty read, i.e. at EOF.
        for block in iter(lambda: stream.read(8192), b""):
            digest.update(block)
    return digest.hexdigest()

def load_existing_hashes():
    """Read the persisted hash store and return its entries as a set.

    Returns an empty set when the file at ``HASH_STORE_PATH`` does not exist.
    Otherwise every line of the file, stripped of surrounding whitespace,
    becomes one element of the result.
    """
    if os.path.exists(HASH_STORE_PATH):
        with open(HASH_STORE_PATH, "r") as store:
            return {entry.strip() for entry in store}
    return set()

def save_hashes(hashes: set):
    """Overwrite the hash store with *hashes*, one per line in sorted order.

    Args:
        hashes: Hex digests of the files that have been indexed.

    The parent directory of ``HASH_STORE_PATH`` is created on demand. Without
    that, the write fails with ``FileNotFoundError`` whenever the directory
    has not been created by hand.
    """
    parent = os.path.dirname(HASH_STORE_PATH)
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(HASH_STORE_PATH, "w") as f:
        f.writelines(f"{h}\n" for h in sorted(hashes))

def extract_text_with_ocr(pdf_path):
    """Return the text of a PDF, running OCR on pages that yield no text.

    The result is cached as a file under ``TEXT_CACHE_DIR`` named after the
    PDF. When that file already exists its contents are returned and the PDF
    is not opened at all.

    Args:
        pdf_path: Path to the PDF file.

    Returns:
        The concatenated page texts. Every page is introduced by a
        ``Page N:`` header (1-based); OCR output, when produced, follows an
        ``OCR:`` marker directly after that page's header.
    """
    filename = os.path.basename(pdf_path)
    md_path = os.path.join(TEXT_CACHE_DIR, filename.replace(".pdf", ".md"))
    # NOTE(review): the cache is keyed by file name only, so a PDF whose
    # content changed but whose name did not will return the old cached text.
    # Confirm whether that is acceptable.
    if os.path.exists(md_path):
        with open(md_path, "r", encoding="utf-8") as f:
            return f.read()

    parts = []
    doc = fitz.open(pdf_path)
    try:
        for page_num in range(len(doc)):
            page = doc.load_page(page_num)
            text = page.get_text().strip()
            parts.append(f"\n\nPage {page_num+1}:\n{text}")
            if text:
                continue

            # No extractable text on this page: render it and try OCR.
            # OCR is best-effort; a failure is reported and the page is left
            # with just its header.
            try:
                pix = page.get_pixmap(dpi=300)
                image = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
                ocr_text = pytesseract.image_to_string(image)
                parts.append(f"\nOCR:\n{ocr_text.strip()}")
            except Exception as e:
                print(f"OCR failed on page {page_num + 1}: {e}")
    finally:
        # Release the document handle even if a page fails to load.
        doc.close()

    full_text = "".join(parts)
    os.makedirs(TEXT_CACHE_DIR, exist_ok=True)
    with open(md_path, "w", encoding="utf-8") as f:
        f.write(full_text)
    return full_text

def enrich_metadata(filename):
    """Build the metadata dict attached to a document created from *filename*.

    Args:
        filename: Base name of the source PDF.

    Returns:
        A dict with keys ``source`` (the file name), ``year`` (the first
        ``20xx`` four-digit sequence found in the name, or ``"Unknown"``),
        ``fund`` and ``doc_type`` (fixed values).
    """
    # In a raw string the digit class is a single-backslash \d; a doubled
    # backslash would match a literal "\" and the year would never be found.
    year_match = re.search(r"(20\d{2})", filename)
    return {
        "source": filename,
        "year": year_match.group(1) if year_match else "Unknown",
        "fund": "UTF",
        "doc_type": "Annual Report"
    }

def update_faiss_index(embeddings):
    """Index PDFs in ``PDF_DIR`` that are not yet recorded in the hash store.

    Each PDF is identified by the hash of its bytes. Files whose hash is
    already recorded are skipped; the rest are converted to text, tagged with
    metadata, chunked, embedded, and added to the FAISS index stored at
    ``FAISS_INDEX_PATH``. The hash store is updated only after the index has
    been saved.

    Args:
        embeddings: Embeddings object used for chunking, for building the
            index and for loading it.

    Returns:
        The FAISS vector store: the one loaded from disk when there is
        nothing new to index, otherwise the updated or newly created one.
    """
    existing_hashes = load_existing_hashes()
    new_hashes = set()
    new_docs = []

    for filename in os.listdir(PDF_DIR):
        if not filename.lower().endswith(".pdf"):
            continue
        pdf_path = os.path.join(PDF_DIR, filename)
        h = file_hash(pdf_path)
        if h in existing_hashes:
            print(f"Skipping already indexed: {filename}")
            continue

        text = extract_text_with_ocr(pdf_path)
        metadata = enrich_metadata(filename)
        new_docs.append(Document(page_content=text, metadata=metadata))
        new_hashes.add(h)

    if not new_docs:
        return FAISS.load_local(FAISS_INDEX_PATH, embeddings, allow_dangerous_deserialization=True)

    # NOTE(review): confirm that the installed SemanticChunker accepts a
    # chunk_size keyword.
    splitter = SemanticChunker(embeddings, chunk_size=1000)
    chunks = splitter.split_documents(new_docs)

    # save_local() treats FAISS_INDEX_PATH as a folder and writes
    # "index.faiss" inside it, so that file is what signals an existing
    # index. Testing for FAISS_INDEX_PATH + ".faiss" never matches, which
    # would replace the stored index with only the new documents.
    index_file = os.path.join(FAISS_INDEX_PATH, "index.faiss")
    if os.path.exists(index_file):
        vs = FAISS.load_local(FAISS_INDEX_PATH, embeddings, allow_dangerous_deserialization=True)
        vs.add_documents(chunks)
    else:
        vs = FAISS.from_documents(chunks, embeddings)
    vs.save_local(FAISS_INDEX_PATH)

    updated_hashes = existing_hashes.union(new_hashes)
    save_hashes(updated_hashes)
    return vs

def load_or_create_vectorstore(embeddings):
    """Return the vector store, delegating entirely to ``update_faiss_index``."""
    vectorstore = update_faiss_index(embeddings)
    return vectorstore

class PersistentChatMessageHistory(ChatMessageHistory):
    """Chat history that mirrors its messages to a per-session JSON file.

    The file lives at ``CHAT_HISTORY_DIR/<session_id>.json`` and holds a list
    of ``{"type": ..., "content": ...}`` objects. It is read once on
    construction and rewritten in full after every added message.
    """

    def __init__(self, session_id):
        super().__init__()
        self._session_id = session_id
        self._file_path = os.path.join(CHAT_HISTORY_DIR, f"{session_id}.json")
        self.load()

    def load(self):
        """Replace the in-memory messages with those stored on disk, if any."""
        if os.path.exists(self._file_path):
            with open(self._file_path, "r", encoding="utf-8") as f:
                raw = json.load(f)
                self.messages = [self._dict_to_message(msg) for msg in raw]

    def save(self):
        """Write every message to the session file, creating its folder first."""
        # The history folder is not created anywhere else; without this the
        # write fails with FileNotFoundError when the folder is missing.
        os.makedirs(os.path.dirname(self._file_path) or ".", exist_ok=True)
        with open(self._file_path, "w", encoding="utf-8") as f:
            json.dump([self._message_to_dict(m) for m in self.messages], f, indent=2)

    def add_message(self, m):
        """Append *m* to the history and persist the whole history."""
        super().add_message(m)
        self.save()

    def _message_to_dict(self, m):
        """Reduce a message to the two fields that are stored on disk."""
        return {"type": m.type, "content": m.content}

    def _dict_to_message(self, d):
        """Rebuild a message; any type other than ``"human"`` becomes an AIMessage."""
        from langchain_core.messages import HumanMessage, AIMessage
        return HumanMessage(content=d["content"]) if d["type"] == "human" else AIMessage(content=d["content"])

def setup_rag_chain_with_history(session_id, embeddings):
    """Assemble the retrieval chain and wrap it with file-backed chat history.

    Args:
        session_id: Accepted for the caller's convenience; the body does not
            read it. The history factory receives its own ``session_id``
            argument when the returned runnable is invoked.
        embeddings: Embeddings object used to load or build the vector store.

    Returns:
        A ``RunnableWithMessageHistory`` that reads the user question from
        the ``input`` key, injects prior turns as ``chat_history`` and
        records the ``answer`` key of the output.
    """
    store = load_or_create_vectorstore(embeddings)
    mmr_options = {"k": 10, "fetch_k": 30}
    retriever = store.as_retriever(search_type="mmr", search_kwargs=mmr_options)

    chat_model = AzureChatOpenAI(
        deployment_name=LLM_DEPLOYMENT,
        api_key=AZURE_OPENAI_API_KEY,
        azure_endpoint=AZURE_OPENAI_ENDPOINT,
        api_version=AZURE_OPENAI_API_VERSION,
        temperature=0.3
    )

    # System instructions; "{context}" is the placeholder for the retrieved
    # documents.
    system_text = (
        "You are a development results analyst AI assistant.\n\n"
        "Extract and summarize **results stories** from UTF annual reports.\n\n"
        "**📌 Title:** Bold, 5–10 words\n"
        "**📝 Summary:** 4–6 sentences\n"
        "**🗂 Metadata:** Region, Sector, Donor, Source Document & Page\n\n"
        "Respond with structured stories if available, else fallback to answering user's question from context.\n"
        "Never invent facts.\n\nContext:\n{context}"
    )
    prompt = ChatPromptTemplate.from_messages([
        ("system", system_text),
        MessagesPlaceholder(variable_name="chat_history"),
        ("human", "{input}"),
    ])

    answer_chain = create_stuff_documents_chain(llm=chat_model, prompt=prompt)
    retrieval_chain = create_retrieval_chain(retriever, answer_chain)

    def history_for(session_id):
        # A fresh history object is created, and loaded from disk, per call.
        return PersistentChatMessageHistory(session_id)

    return RunnableWithMessageHistory(
        retrieval_chain,
        history_for,
        input_messages_key="input",
        history_messages_key="chat_history",
        output_messages_key="answer",
    )

def run_query(session_id: str, question: str):
    """Answer *question* within the chat session *session_id*.

    A new chain is built on every call using the module-level ``embeddings``.
    Returns the value stored under the ``answer`` key of the chain's result.
    """
    chain = setup_rag_chain_with_history(session_id, embeddings)
    payload = {"input": question}
    run_config = {"configurable": {"session_id": session_id}}
    response = chain.invoke(payload, config=run_config)
    return response["answer"]

# Sample run: ask a single question under a freshly generated session id
if __name__ == "__main__":
    sid = f"session_{uuid.uuid4().hex[:8]}"
    q = "Give me two examples of how the MDTF supported private sector job creation in 2020"
    print("\nQuery:", q)
    answer = run_query(sid, q)
    print("\nAnswer:\n", answer)
