In [None]:
"""
Prototype: Payroll Gen AI Query Assistant (RAG over Payroll PDFs)

Features:
- Upload payroll PDF
- Store content in Chroma vector DB
- Ask questions and get context-aware answers
- LangSmith enabled for observability

Requirements (pip):
    fastapi
    uvicorn
    langchain
    langchain-openai
    langchain-community
    chromadb
    sentence-transformers
    pypdf
    python-multipart   (needed by FastAPI for the File/Form endpoints)

Env variables:
    OPENAI_API_KEY=...
    LANGCHAIN_TRACING_V2=true
    LANGCHAIN_API_KEY=...
    LANGCHAIN_PROJECT=Payroll-GenAI
"""

import os
import shutil
from pathlib import Path
from typing import Optional

from fastapi import FastAPI, UploadFile, File, HTTPException, Form
from fastapi.middleware.cors import CORSMiddleware

from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter

from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import HuggingFaceEmbeddings

from langchain_openai import ChatOpenAI
from langchain.chains import RetrievalQA

# -------------------------------------------------------------------
# Configuration
# -------------------------------------------------------------------

# All on-disk state lives next to this module: raw uploads in one folder,
# the persisted Chroma index in another.
BASE_DIR = Path(__file__).parent
UPLOAD_DIR = BASE_DIR.joinpath("uploads")
VECTOR_DB_DIR = BASE_DIR.joinpath("chroma_db")

# Create both folders at import time; exist_ok makes repeated imports harmless.
UPLOAD_DIR.mkdir(exist_ok=True)
VECTOR_DB_DIR.mkdir(exist_ok=True)

# LangSmith (set via env; here just for clarity)
# os.environ["LANGCHAIN_TRACING_V2"] = "true"
# os.environ["LANGCHAIN_API_KEY"] = "your_langsmith_key"
# os.environ["LANGCHAIN_PROJECT"] = "Payroll-GenAI"

# -------------------------------------------------------------------
# Global objects (simple prototype style)
# -------------------------------------------------------------------

# Single FastAPI application object for the whole service.
app = FastAPI(title="Payroll Gen AI Query Assistant")

# CORS is fully open: every origin, method and header is accepted.
# NOTE(review): presumably intended for the prototype only — confirm before
# exposing this service beyond a demo environment.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)

# Embeddings (Hugging Face – light & good for demo)
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")

# Both start empty; they are assigned either by the import-time warm start
# further down or by index_pdf_to_chroma() on the first upload.
vectorstore: Optional[Chroma] = None
qa_chain: Optional[RetrievalQA] = None


def init_vectorstore() -> Chroma:
    """Return a Chroma store bound to the ``payroll_docs`` collection.

    The store uses the module-level ``embeddings`` object and persists under
    ``VECTOR_DB_DIR``.
    """
    store_options = {
        "collection_name": "payroll_docs",
        "embedding_function": embeddings,
        "persist_directory": str(VECTOR_DB_DIR),
    }
    return Chroma(**store_options)


def build_qa_chain(store: Chroma) -> RetrievalQA:
    """Assemble a RetrievalQA chain that answers from ``store``.

    The chain uses the ``gpt-4o-mini`` chat model at temperature 0, a
    retriever configured with ``k=4``, the ``"stuff"`` chain type, and is
    set to return its source documents and to run verbosely.

    Args:
        store: Vector store the retriever is created from.

    Returns:
        The configured RetrievalQA chain.
    """
    # Keyword arguments are evaluated top to bottom, so the LLM is still
    # constructed before the retriever.
    return RetrievalQA.from_chain_type(
        llm=ChatOpenAI(model="gpt-4o-mini", temperature=0.0),
        retriever=store.as_retriever(search_kwargs={"k": 4}),
        chain_type="stuff",
        return_source_documents=True,
        verbose=True,
    )


# Warm start: if the persistence folder already contains at least one entry,
# open the store and build the QA chain right away at import time.
if next(VECTOR_DB_DIR.iterdir(), None) is not None:
    vectorstore = init_vectorstore()
    qa_chain = build_qa_chain(vectorstore)


# -------------------------------------------------------------------
# Helpers
# -------------------------------------------------------------------

def save_uploaded_file(upload_file: UploadFile) -> Path:
    """Save an uploaded PDF into UPLOAD_DIR and return the path written.

    Only the final path component of the client-supplied filename is used.
    The filename comes from the request, so directory parts such as ``../``
    or an absolute path must not be allowed to steer the write outside
    UPLOAD_DIR.

    Args:
        upload_file: The uploaded file; its ``filename`` and ``file``
            attributes are read.

    Returns:
        Path of the file written inside UPLOAD_DIR.

    Raises:
        HTTPException: 400 when the filename is missing or does not end
            in ``.pdf`` (case-insensitive).
    """
    # Treat backslashes as separators too, so a Windows-style client path is
    # reduced to its basename regardless of the server's platform.
    raw_name = (upload_file.filename or "").replace("\\", "/")
    safe_name = Path(raw_name).name

    # A missing/empty name, or "..", fails this check as well.
    if not safe_name.lower().endswith(".pdf"):
        raise HTTPException(status_code=400, detail="Only PDF files are supported.")

    file_path = UPLOAD_DIR / safe_name
    with file_path.open("wb") as buffer:
        shutil.copyfileobj(upload_file.file, buffer)

    return file_path


def index_pdf_to_chroma(pdf_path: Path) -> None:
    """Load a PDF, split it into chunks, and add the chunks to Chroma.

    Updates the module-level ``vectorstore`` (created on first use) and
    rebuilds the module-level ``qa_chain`` afterwards.

    Args:
        pdf_path: Location of the PDF on disk.

    Raises:
        HTTPException: 400 when the PDF yields no text chunks to index.
    """
    global vectorstore, qa_chain

    pages = PyPDFLoader(str(pdf_path)).load()

    splitter = RecursiveCharacterTextSplitter(
        chunk_size=800,
        chunk_overlap=150,
        separators=["\n\n", "\n", ".", "!", "?", " ", ""],
    )
    chunks = splitter.split_documents(pages)

    # Nothing to embed (for example, a PDF without a text layer). Stop here
    # with a clear client error instead of handing an empty batch to the
    # vector store.
    if not chunks:
        raise HTTPException(
            status_code=400,
            detail="No extractable text found in the uploaded PDF.",
        )

    # Reuse the shared factory so the collection settings are defined once.
    if vectorstore is None:
        vectorstore = init_vectorstore()

    vectorstore.add_documents(chunks)
    # NOTE(review): persist() is reported as deprecated in recent
    # langchain-community releases (Chroma >= 0.4 persists automatically);
    # confirm against the pinned versions before removing this call.
    vectorstore.persist()

    # Rebuild QA chain whenever new docs are added
    qa_chain = build_qa_chain(vectorstore)


# -------------------------------------------------------------------
# API Endpoints
# -------------------------------------------------------------------

@app.get("/health")
async def health_check():
    """Liveness probe: return a fixed status payload."""
    payload = {
        "status": "ok",
        "message": "Payroll Gen AI Assistant is running",
    }
    return payload


@app.post("/upload-pdf")
async def upload_pdf(file: UploadFile = File(...)):
    """
    Accept a payroll/HR policy PDF, store it on disk and index it.

    The file is written via save_uploaded_file() and then passed to
    index_pdf_to_chroma(); the response echoes the filename sent by the
    client together with a success message.
    """
    stored_pdf = save_uploaded_file(file)
    index_pdf_to_chroma(stored_pdf)

    response = {"status": "success"}
    response["filename"] = file.filename
    response["message"] = "PDF indexed successfully into vector store."
    return response


@app.post("/ask")
async def ask_question(question: str = Form(...)):
    """
    Answer a payroll/HR question using RAG over the indexed documents.

    Args:
        question: The question text, sent as a form field.

    Returns:
        A dict with the original ``question``, the generated ``answer`` and
        a ``sources`` list; each source carries the ``source`` and ``page``
        values found in the retrieved document's metadata.

    Raises:
        HTTPException: 400 when no documents have been indexed yet, or when
            the question is empty or whitespace only.
    """
    # Take a local reference: the module-level name can be rebound by an
    # upload while this coroutine is suspended at the await below.
    chain = qa_chain
    if chain is None:
        raise HTTPException(
            status_code=400,
            detail="No documents indexed yet. Please upload a payroll PDF first.",
        )

    # Do not spend a retrieval + LLM call on a blank question.
    if not question.strip():
        raise HTTPException(status_code=400, detail="Question must not be empty.")

    # Use the chain's async entry point so this coroutine does not block the
    # event loop for the duration of the retrieval and LLM request.
    result = await chain.ainvoke({"query": question})

    sources = [
        {
            "source": doc.metadata.get("source", ""),
            "page": doc.metadata.get("page", None),
        }
        for doc in result.get("source_documents", [])
    ]

    return {
        "question": question,
        "answer": result["result"],
        "sources": sources,
    }


# -------------------------------------------------------------------
# Run command:
#   uvicorn main:app --reload
# -------------------------------------------------------------------
if __name__ == "__main__":
    # Imported here so uvicorn is only needed when the file is run as a script.
    import uvicorn

    # Serve the app by import string on all interfaces, port 8000, with
    # auto-reload enabled.
    uvicorn.run(
        "main:app",
        host="0.0.0.0",
        port=8000,
        reload=True,
    )
