# In [None]:  (Jupyter cell marker left over from notebook export)
import os
import tempfile
import time
import streamlit as st
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_community.vectorstores import FAISS
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate  # 新增

# ----------------- CONFIG -----------------
# Resolve the OpenAI key from the process environment (set it there rather
# than hard-coding it). When the variable is unset, the literal placeholder
# "APIKEY" is used instead. The resolved value is then written back to
# os.environ, presumably so clients built without an explicit key see it too.
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "APIKEY")
os.environ["OPENAI_API_KEY"] = OPENAI_API_KEY

# ----------------- LLM INIT -----------------
# Module-level chat model used by the QA chain further down. It starts on
# "gpt-4o"; the sidebar section later overwrites `llm.model_name` with the
# user's selection.
llm = ChatOpenAI(
    model_name="gpt-4o",
    temperature=0.5,
    openai_api_key=OPENAI_API_KEY,
)

# ----------------- SYSTEM PROMPT -----------------
# Raw instruction text for the QA chain. It exposes two placeholders:
# {context} (the retrieved document text) and {question} (the user's query).
_QA_TEMPLATE = """
You are an expert document-based assistant, helping users explore and understand the content of the uploaded PDFs.

1. Always ground your answers in the provided PDF context. Cite page numbers in parentheses, e.g. (page 10).
2. If the document does not contain enough information to answer the question, respond:
   “I don’t have enough information in the document to answer that. Could you please clarify or ask a different question?”
3. You may ask follow-up questions to better understand the user’s intent, but never introduce facts not supported by the PDF.
4. Keep your tone friendly, concise, and professional.


Context:
{context}

Question:
{question}
"""

# Leading/trailing whitespace of the literal is stripped before templating.
prompt = PromptTemplate(
    template=_QA_TEMPLATE.strip(),
    input_variables=["context", "question"],
)

# ----------------- PAGE SETUP -----------------
# Browser-tab title, wide layout, and a sidebar that starts expanded.
_PAGE_OPTIONS = {
    "page_title": "Chat with Your PDFs",
    "layout": "wide",
    "initial_sidebar_state": "expanded",
}
st.set_page_config(**_PAGE_OPTIONS)

# Custom CSS for a cleaner look: sidebar/main backgrounds, button colours and
# a dashed border around the file uploader. Injected as raw HTML.
_CUSTOM_CSS = """
    <style>
    .sidebar .sidebar-content {background-color: #f0f2f6;}
    .css-1d391kg {background-color: white;}  /* Main panel card bg */
    .stButton>button {background-color: #4B8BBE; color: white; border-radius: 8px;}
    .stFileUploader>div {padding: 10px; border: 2px dashed #a2a9b7; border-radius: 8px;}
    </style>
    """
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)

# ----------------- SIDEBAR -----------------
# Sidebar controls. Produces four module-level values consumed later in the
# script: `model_name`, `k`, `fetch_k` and `lambda_mult`.
st.sidebar.title("⚙️ Settings")
# Model selection
model_name = st.sidebar.selectbox(
    "Model:",
    options=["gpt-4o", "gpt-3.5-turbo"],
    index=0,
)
# Point the shared chat model at whichever model the user picked.
llm.model_name = model_name

# Retriever parameters
st.sidebar.subheader("Retriever Settings")
k = st.sidebar.slider("Top context chunks (k)", min_value=2, max_value=12, value=6)
fetch_k = st.sidebar.slider("Candidate pool (fetch_k)", min_value=10, max_value=50, value=20)
# MMR selects the final k chunks out of the fetch_k candidates, so a pool
# smaller than k (possible here: k up to 12, fetch_k down to 10) cannot
# yield k results. Raise the pool to k in that case and tell the user.
if fetch_k < k:
    st.sidebar.caption(f"Candidate pool raised to {k} so it is not smaller than k.")
    fetch_k = k
lambda_mult = st.sidebar.slider(
    "Diversity (lambda)",
    min_value=0.1,
    max_value=1.0,
    value=0.8,
    # The value is forwarded as MMR's `lambda_mult`, where higher means LESS
    # diversity; spell that out because the label alone suggests the opposite.
    help="1.0 favours pure relevance (least diverse); lower values return more diverse chunks.",
)

# ----------------- TITLE -----------------
# Page heading, then a horizontal rule separating it from the body; each is
# emitted as its own markdown element.
for _title_markdown in ("# 📄💬 Chat with Your PDFs", "---"):
    st.markdown(_title_markdown)

# ----------------- UPLOAD & PROCESS -----------------
# Lets the user upload PDFs, splits them into chunks, embeds them and stores
# the resulting FAISS index in `st.session_state.vector_store`.
#
# The index is rebuilt whenever the set of uploaded files changes (tracked via
# `st.session_state.indexed_signature`), not only on the very first upload, so
# answers never come from a stale set of documents.
with st.expander("📂 Upload & Process PDFs", expanded=True):
    files = st.file_uploader(
        "Drag & drop PDF files here or click to browse",
        type=["pdf"], accept_multiple_files=True
    )
    if files:
        # Cheap fingerprint of the current upload set: (file name, byte size)
        # pairs, sorted so upload order does not matter.
        upload_signature = tuple(
            sorted((f.name, f.getbuffer().nbytes) for f in files)
        )
        needs_indexing = (
            "vector_store" not in st.session_state
            or st.session_state.get("indexed_signature") != upload_signature
        )
        if needs_indexing:
            with st.spinner("🔍 Indexing documents..."):
                docs = []
                # PyPDFLoader works on file paths, so each upload is written
                # to a temporary directory that is removed once loading ends.
                with tempfile.TemporaryDirectory() as td:
                    for f in files:
                        # basename() guards against a file name that carries
                        # directory components escaping the temp directory.
                        path = os.path.join(td, os.path.basename(f.name))
                        with open(path, "wb") as out:
                            out.write(f.getbuffer())
                        docs.extend(PyPDFLoader(path).load())

                chunks = RecursiveCharacterTextSplitter(
                    chunk_size=500, chunk_overlap=50
                ).split_documents(docs)

                if chunks:
                    embeddings = OpenAIEmbeddings()
                    st.session_state.vector_store = FAISS.from_documents(chunks, embeddings)
                    st.session_state.indexed_signature = upload_signature
                else:
                    # Nothing to embed (e.g. scanned PDFs without a text
                    # layer). Building a FAISS index from zero chunks fails,
                    # so drop any previous index instead of keeping one that
                    # no longer matches the uploaded files.
                    st.session_state.pop("vector_store", None)
                    st.session_state.pop("indexed_signature", None)

            if chunks:
                st.success("✅ Documents indexed! Start asking below.")
            else:
                st.error(
                    "No extractable text was found in the uploaded PDFs, "
                    "so nothing could be indexed."
                )

# ----------------- CHAT -----------------
# Chat loop: replays the stored conversation, accepts a new question, answers
# it with a RetrievalQA chain over the indexed PDFs and records the exchange
# in `st.session_state.messages`.


def _source_labels(documents):
    """Return de-duplicated "file (page N)" labels for retrieved documents.

    `documents` is an iterable of objects carrying a `metadata` mapping.
    The `page` entry is treated as a zero-based index (as PyPDFLoader
    records it) and shown one-based; documents without an integer page are
    labelled with the file name only. First-seen order is preserved.
    """
    labels = []
    for doc in documents:
        meta = getattr(doc, "metadata", None) or {}
        name = os.path.basename(str(meta.get("source", ""))) or "document"
        page = meta.get("page")
        label = f"{name} (page {page + 1})" if isinstance(page, int) else name
        if label not in labels:
            labels.append(label)
    return labels


def _render_message(msg):
    """Draw one chat message, plus its retrieved-source caption if present."""
    with st.chat_message(msg["role"]):
        st.markdown(msg["content"])
        if msg.get("sources"):
            st.caption("Retrieved from: " + "; ".join(msg["sources"]))


# Initialize chat history storage
if "messages" not in st.session_state:
    st.session_state.messages = []

if "vector_store" in st.session_state:
    # Render past messages
    for msg in st.session_state.messages:
        _render_message(msg)

    # Chat input
    user_q = st.chat_input("Ask a question about your PDFs…")
    if user_q:
        user_msg = {"role": "user", "content": user_q}
        st.session_state.messages.append(user_msg)
        _render_message(user_msg)

        retriever = st.session_state.vector_store.as_retriever(
            search_type="mmr",
            search_kwargs={"k": k, "fetch_k": fetch_k, "lambda_mult": lambda_mult}
        )
        # Inject the custom prompt here; everything else uses chain defaults.
        qa = RetrievalQA.from_chain_type(
            llm=llm,
            retriever=retriever,
            chain_type="stuff",
            chain_type_kwargs={"prompt": prompt},
            return_source_documents=True
        )
        try:
            with st.spinner("🤔 Thinking…"):
                resp = qa.invoke({"query": user_q})
        except Exception as exc:
            # UI boundary: report the failure (bad API key, network error,
            # rate limit, ...) in the page instead of crashing the rerun
            # with a raw traceback.
            st.error(f"Sorry, the request failed: {exc}")
        else:
            # The chain is asked for its source documents, so keep them with
            # the message; that way they are shown now and on later reruns.
            assistant_msg = {
                "role": "assistant",
                "content": resp["result"],
                "sources": _source_labels(resp.get("source_documents", [])),
            }
            st.session_state.messages.append(assistant_msg)
            _render_message(assistant_msg)
else:
    st.info("📥 Please upload and process PDFs to start.")
