In [None]:
# Wenwen's tasks:
# 1). Build a Conversational RAG Retrieval QA Chain with proper citations, like [1][2] with article title, pages and context
# (RAG_QA_Cita-3.ipynb) is the conversational QA_App, answer questions based on given PDFs.

# 2). Build a Multi-Vector RAG, which can summarize text and tables from a PDF
# (Multi_Modal_RAG-v2.ipynb) is the Multi_vector_Model, which can make summary of text and tables from a PDF.

# 3). Build a Multi-Modal RAG Retrieval QA Chain with proper citations, like [1][2] with article title, pages and context
# (Multi_RAG_QA_Cita-v4.ipynb) is the combination with (RAG_QA_Cita-3.ipynb) and (Multi_Modal_RAG-v2.ipynb), so that my App can make dialog with me, based on the text and tables from given PDFs.

# 4). In the end, this (Multi_RAG_Agent.ipynb) is the final version of the app, 
# which can make dialog with me, based on the text and tables from given PDFs, 
# and also can make a summary of the text and tables from a PDF, with proper citation style.

# 5). combine all Agents (Multi_RAG_Agent from Wenwen, Web_Search_Agent and Data_Science_Agent from Hanna) with Supervisor Agent (from Wenwen)
# 6). create a Gradio chat interface
# 7). create a Huggingface Space for presentation (https://huggingface.co/spaces/hussamalafandi/test_space)

In [None]:
# step 4: build a supervisor_Agent, to control the RAG_Agent from me and Website_Agent from Hanna
# Create supervisor with langgraph-supervisor
# https://langchain-ai.github.io/langgraph/tutorials/multi_agent/agent_supervisor/#2-create-supervisor-with-langgraph-supervisor 


In [None]:
# "Multi_RAG_Agent.ipynb" from Wenwen
# 1. use LangSmith
import getpass
import os

# Configure environment to connect to LangSmith.
os.environ["LANGSMITH_TRACING"] = "true"
os.environ["LANGSMITH_ENDPOINT"]="https://api.smith.langchain.com"
os.environ["LANGSMITH_PROJECT"]="KI_multi-modal-RAG"

# Only prompt for the key when the environment does not already provide one,
# so re-running the cell (or running it non-interactively with the variable
# exported) does not block on input. Same pattern as the Google key below.
if not os.environ.get("LANGSMITH_API_KEY"):
    os.environ["LANGSMITH_API_KEY"] = getpass.getpass("Enter API key for LangSmith: ")

# 2. Components
# 2.1 Chat model: Google Gemini (the key is requested only if the environment lacks it)
import getpass
import os

if os.environ.get("GOOGLE_API_KEY", "") == "":
    os.environ["GOOGLE_API_KEY"] = getpass.getpass("Enter API key for Google Gemini: ")

from langchain.chat_models import init_chat_model
llm = init_chat_model(
    "gemini-2.0-flash",
    model_provider="google_genai",
)

# 2.2 Embedding model: HuggingFace sentence-transformers
from langchain_huggingface import HuggingFaceEmbeddings
embeddings = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-mpnet-base-v2",
)

# 2.3 Vector store: Chroma (requires an up-to-date langchain_chroma)
from langchain_chroma import Chroma

vector_store = Chroma(
    persist_directory="./chroma_langchain_db",  # local on-disk location of the collection
    collection_name="example_collection",
    embedding_function=embeddings,
)

# 3. index our documents:

from langchain.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter

# 3.1 Load PDF files from a folder
import os
folder_path = "D:/4-IntoCode/16_LangChain/AgilProjekt_multiModel/Raw_Data/Apple/"  # company folder, use / instead of \
all_docs = []
pdf_file_count = 0  # number of PDFs actually loaded (the folder may hold other entries)

# NOTE: the loop variable is intentionally still called `file`; code further
# down in this notebook reads its final value after the loop has finished.
for file in os.listdir(folder_path):
    # Case-insensitive match so "REPORT.PDF" is picked up as well.
    if file.lower().endswith(".pdf"):
        loader = PyPDFLoader(os.path.join(folder_path, file))
        # load() yields one Document per page. load_and_split() would already
        # run a default text splitter here, so the documents would be split
        # twice and len(all_docs) would not be a page count.
        pages = loader.load()
        all_docs.extend(pages)
        pdf_file_count += 1

# 3.2 Split into chunks
splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
docs = splitter.split_documents(all_docs)
print(f"Loaded {len(docs)} chunks from {len(all_docs)} pages across {pdf_file_count} PDF files.")
# An earlier run (using load_and_split and counting every folder entry) printed:
# "Loaded 4419 chunks from 1347 pages across 22 PDF files."

# 3.3 Index chunks
_ = vector_store.add_documents(documents=docs)

# Check 1: Are the documents actually in the vectorstore?
# vector_store.get() returns a dict of parallel lists ("ids", "documents",
# "metadatas", ...). len() of that dict is the number of KEYS, which is why this
# check used to print 7 regardless of how many chunks were stored. Count the
# ids instead.
print(f"Total documents in ChromaDB: {len(vector_store.get()['ids'])}")

print(f"# of docs to add: {len(docs)}")  # Should be in thousands, not 7
# An earlier run printed: # of docs to add: 4419

# Rebuild the index from scratch in a clean directory. This keeps the store
# used by the rest of the notebook free of duplicates when the notebook is
# re-run against an already populated persist directory.
# NOTE(review): add_documents() above did not actually fail (see the comment on
# Check 1), so the chunks are embedded twice. Consider keeping only one of the
# two stores.
# step1. Delete the previous on-disk index
import shutil
shutil.rmtree("./chroma_db", ignore_errors=True)

# step2. Re-initialize Chroma with persist directory
vector_store = Chroma(
    persist_directory="./chroma_db",
    embedding_function=embeddings
)

# step3. Add all chunks
print(f"Adding {len(docs)} docs")
vector_store.add_documents(docs)

# step4.  Verify: the count must equal len(docs)
print("Total documents in ChromaDB:", len(vector_store.get()['ids']))
# An earlier run printed: Total documents in ChromaDB: 4419

# 4. Multi_RAG application: reconstruct the Q&A app with citations
# Conversational RAG: additional tool-calling features of chat models to cite document IDs;
# Multi-Vector RAG: use multiple vector stores to retrieve text and tables from a PDF

from langchain_core.messages import SystemMessage
from langgraph.graph import MessagesState
from langgraph.prebuilt import ToolNode
from typing import List
from langchain_core.documents import Document

# 4.1 Define state for application (modified)
class State(MessagesState):
    """Graph state: the message history plus the documents retrieved for the answer."""

    # Documents returned by the retrieval tool for the current turn.
    context: list[Document]

# 4.2 Build the retriever
# Combine_Step_1: reuse the MultiVectorRetriever approach from (Multi_Modal_RAG-v2.ipynb)
from langchain.retrievers.multi_vector import MultiVectorRetriever
from langchain.storage import InMemoryStore

# In-memory store for the original (un-summarized) content, keyed by doc_id.
store = InMemoryStore()

retriever = MultiVectorRetriever(
    docstore=store,
    vectorstore=vector_store,
    id_key="doc_id",            # metadata key that links a vector entry to its original content
    search_kwargs={"k": 4},     # number of documents to retrieve
)

# 4.3 Define the tool
from langchain_core.tools import tool

# Combine_Step_3: Update the Tool to Use Multi-Vector Retrieval and Store Metadata
#
# Implementation notes (kept out of the docstring on purpose: the docstring is
# what the chat model sees as the tool description):
# * The vector store is searched directly instead of calling retriever.invoke().
#   The retriever returns whatever is stored in the docstore, which in this
#   notebook is the raw text (plain strings) and therefore carries no metadata
#   for citations.
# * Hits that carry a doc_id are expanded to the original content held in the
#   docstore; hits without one (plain PDF chunks) are returned as they are.
# * A tool declared with response_format="content_and_artifact" has to return a
#   (content, artifact) tuple.
@tool(response_format="content_and_artifact")
def retrieve(query: str):
    """Retrieve information related to a query."""
    sub_docs = retriever.vectorstore.similarity_search(query, **retriever.search_kwargs)

    full_docs = []
    seen_doc_ids = set()
    for sub_doc in sub_docs:
        page_content = sub_doc.page_content
        doc_id = sub_doc.metadata.get(retriever.id_key)
        if doc_id is not None:
            # Several vector entries may point at the same original content.
            if doc_id in seen_doc_ids:
                continue
            seen_doc_ids.add(doc_id)
            original = retriever.docstore.mget([doc_id])[0]
            if original is not None:
                # The docstore may hold plain strings or Document objects.
                page_content = (
                    original.page_content
                    if isinstance(original, Document)
                    else str(original)
                )
        full_docs.append(Document(page_content=page_content, metadata=sub_doc.metadata))

    serialized = "\n\n".join(
        f"Source: {doc.metadata}\nContent: {doc.page_content}"
        for doc in full_docs
    )
    return serialized, full_docs

# Step 1: Let the model either answer directly or request a retrieval.
def query_or_respond(state: State):
    """Generate tool call for retrieval or respond."""
    # Expose the retrieval tool to the model, then run it on the conversation so far.
    tool_aware_llm = llm.bind_tools([retrieve])
    ai_message = tool_aware_llm.invoke(state["messages"])
    # Returned messages are appended to the existing history, not substituted for it.
    return {"messages": [ai_message]}


# Step 2: Execute the retrieval.
# Wraps the `retrieve` tool in a ToolNode; this object is registered as a graph
# node further down (graph_builder.add_node(tools)).
tools = ToolNode([retrieve])

# 4.4 Combine_Step_2: Summarize Text + Tables and Load into MultiVectorRetriever
# Partition the PDF with unstructured, then sort the pieces into text and tables.
from unstructured.partition.pdf import partition_pdf
from typing import Any
from pydantic import BaseModel

# NOTE(review): `file` is whatever the PDF-loading loop above left behind, i.e.
# the last entry returned by os.listdir(folder_path). Only that single file is
# partitioned here, and it is not checked to be a PDF.
raw_pdf_elements = partition_pdf(
    filename=folder_path + file,
    chunking_strategy="by_title",
    infer_table_structure=True,
    extract_images_in_pdf=True,
)

class Element(BaseModel):
    """A partitioned PDF piece reduced to its category and its text."""

    type: str   # "table" or "text"
    text: Any   # string form of the unstructured element

# Categorize by the class name of each unstructured element; anything that is
# neither a table nor a composite text element is dropped.
categorized_elements = []
for raw_element in raw_pdf_elements:
    class_repr = str(type(raw_element))
    if "unstructured.documents.elements.Table" in class_repr:
        category = "table"
    elif "unstructured.documents.elements.CompositeElement" in class_repr:
        category = "text"
    else:
        continue
    categorized_elements.append(Element(type=category, text=str(raw_element)))

# Separate into text and table
text_elements = [item for item in categorized_elements if item.type == "text"]
table_elements = [item for item in categorized_elements if item.type == "table"]

# 4.5 Text and Table summaries
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate

# One prompt template is shared by text chunks and tables.
prompt_text = """You are an assistant tasked with summarizing tables and text. \
Give a concise and essential summary of the table or text. 
Each summary should not longer than 10 sentences. Please keep it as short as possible. \
Table or text chunk: {element} """
prompt = ChatPromptTemplate.from_template(prompt_text)

# 4.6 Summary chain
import getpass
import os

# Google Gemini is used instead of OpenAI; ask for the key only when it is missing.
if os.environ.get("GOOGLE_API_KEY", "") == "":
    os.environ["GOOGLE_API_KEY"] = getpass.getpass("Enter API key for Google Gemini: ")

from langchain_google_genai import ChatGoogleGenerativeAI
# "gemma-3-27b-it" is used here instead of gemini-2.0-flash or 1.5
model = ChatGoogleGenerativeAI(temperature=0, model="gemma-3-27b-it")

# raw element -> {"element": ...} -> prompt -> model -> plain string
summarize_chain = {"element": lambda x: x} | prompt | model | StrOutputParser()

# Summarize every element, one request at a time.
batch_config = {"max_concurrency": 1}
text_inputs = [item.text for item in text_elements]
table_inputs = [item.text for item in table_elements]
text_summaries = summarize_chain.batch(text_inputs, batch_config)
table_summaries = summarize_chain.batch(table_inputs, batch_config)

# 4.7 Add to retriever
from langchain_core.documents import Document
import uuid

# Summaries go into the vector store (that is what gets searched); the original
# full text goes into the docstore under the same doc_id.

# Recover, in order, the raw unstructured elements behind text_elements and
# table_elements, using the same class-name tests (table first, then composite
# text) that produced those lists. The previous code indexed raw_pdf_elements
# with a position taken from text_elements, which points at the wrong element
# as soon as the PDF contains anything but composite text.
_TABLE_MARKER = "unstructured.documents.elements.Table"
_TEXT_MARKER = "unstructured.documents.elements.CompositeElement"
raw_table_elements = [
    raw for raw in raw_pdf_elements if _TABLE_MARKER in str(type(raw))
]
raw_text_elements = [
    raw for raw in raw_pdf_elements
    if _TABLE_MARKER not in str(type(raw)) and _TEXT_MARKER in str(type(raw))
]


def _index_summaries(elements, raw_elements, summaries):
    """Store summaries in the vector store and originals in the docstore.

    Args:
        elements: categorized elements whose ``text`` is the original content.
        raw_elements: the unstructured elements the entries of ``elements``
            were built from, in the same order (used for page metadata).
        summaries: one summary string per element, in the same order.

    Returns:
        The list of generated doc_ids, aligned with ``elements``.

    Raises:
        ValueError: if the three sequences do not have the same length.
    """
    if not (len(elements) == len(raw_elements) == len(summaries)):
        raise ValueError(
            "elements, raw elements and summaries must have the same length: "
            f"{len(elements)}, {len(raw_elements)}, {len(summaries)}"
        )

    doc_ids = [str(uuid.uuid4()) for _ in elements]
    summary_docs = []
    for raw, summary, doc_id in zip(raw_elements, summaries, doc_ids):
        page_number = getattr(raw.metadata, "page_number", None)
        summary_docs.append(
            Document(
                page_content=summary,
                metadata={
                    "doc_id": doc_id,
                    "source": file,
                    # -1 marks an unknown page; None is avoided as a metadata value.
                    "page": page_number if page_number is not None else -1,
                },
            )
        )

    if summary_docs:  # nothing to add when the PDF has no element of this kind
        retriever.vectorstore.add_documents(summary_docs)
    retriever.docstore.mset(list(zip(doc_ids, [e.text for e in elements])))
    return doc_ids


# Text chunks
text_ids = _index_summaries(text_elements, raw_text_elements, text_summaries)

# Same for tables (previously the text summaries were indexed a second time
# here and the table summaries never reached the vector store).
table_ids = _index_summaries(table_elements, raw_table_elements, table_summaries)


# 4.8 Step 3: Generate a response using the retrieved content.
def _source_label(doc):
    """Return '<source>, page <n>' for a retrieved document (page omitted if absent)."""
    source_info = doc.metadata.get('source', 'Unknown document')
    page_info = f", page {doc.metadata['page']}" if 'page' in doc.metadata else ""
    return f"{source_info}{page_info}"


def generate(state: State):
    """Generate answer.

    Builds a system prompt from the documents attached to the most recent tool
    messages, asks the chat model for an answer and prints that answer followed
    by a numbered source list.

    Args:
        state: graph state; ``state["messages"]`` must end with at least one
            tool message produced by the retrieval tool.

    Returns:
        A state update with the model's reply under ``"messages"`` and the
        retrieved documents under ``"context"``.

    Raises:
        ValueError: if the history does not end with a tool message, or none of
            the trailing tool messages carries an artifact.
    """
    # Collect the uninterrupted run of tool messages at the end of the history.
    recent_tool_messages = []
    for message in reversed(state["messages"]):
        if message.type != "tool":
            break
        recent_tool_messages.append(message)
    tool_messages = recent_tool_messages[::-1]

    # An absent artifact shows up as None, so test the value, not the attribute.
    artifacts = [getattr(message, "artifact", None) for message in tool_messages]
    if not tool_messages or all(artifact is None for artifact in artifacts):
        raise ValueError("No valid tool messages with artifacts found.")

    # Use the documents of ALL trailing tool messages (the model may have issued
    # several retrieval calls), so prompt, source list and context agree.
    sources = [doc for artifact in artifacts if artifact for doc in artifact]

    # Number every document and show where it comes from, so that the [n]
    # citations requested in the prompt line up with the source list below.
    docs_content = "\n\n".join(
        f"[{i}] {_source_label(doc)}\n{doc.page_content}"
        for i, doc in enumerate(sources, start=1)
    )

    system_message_content = (
        """You are an assistant for question-answering tasks. 
        ONLY Use the following pieces of retrieved context to answer the question. 
        For each fact, cite its source number like [1][2]. 
        At the end of your answer, add a list of sources in the format of [1] <source title>, page <page number> and so on.
        If you don't know the answer, If unsure, say 'I don't know'."""
        "\n\n"
        f"{docs_content}"
    )
    # Keep only the human/system turns and the final AI answers; tool calls and
    # tool outputs are already represented by the context in the system prompt.
    conversation_messages = [
        message
        for message in state["messages"]
        if message.type in ("human", "system")
        or (message.type == "ai" and not message.tool_calls)
    ]
    prompt = [SystemMessage(system_message_content)] + conversation_messages

    result = llm.invoke(prompt)

    # Message content is normally a string; fall back to str() otherwise.
    content = result.content
    answer = content.strip() if isinstance(content, str) else str(content)

    # Add formatted citations (preferred citation style)
    if sources:
        answer += "\n\nSources:"
        for i, doc in enumerate(sources, start=1):
            answer += f"\n[{i}] {_source_label(doc)}"

    print("Answer:\n", answer)

    return {"messages": [result], "context": sources}

# 4.9 compile the application:
from langgraph.graph import StateGraph
from langgraph.graph import END
from langgraph.prebuilt import tools_condition


# Build the graph on `State` (not the bare MessagesState) so that the "context"
# key returned by `generate` is part of the graph's state schema.
graph_builder = StateGraph(State)

graph_builder.add_node(query_or_respond)
graph_builder.add_node(tools)
graph_builder.add_node(generate)

# query_or_respond -> tools -> generate when the model asked for a retrieval,
# otherwise query_or_respond ends the run directly.
graph_builder.set_entry_point("query_or_respond")
graph_builder.add_conditional_edges(
    "query_or_respond",
    tools_condition,
    {END: END, "tools": "tools"},
)
graph_builder.add_edge("tools", "generate")
graph_builder.add_edge("generate", END)

graph = graph_builder.compile()

from IPython.display import Image, display
display(Image(graph.get_graph().draw_mermaid_png()))

# 4.10 Invoking our application, the retrieved Document objects are accessible from the application state.
# Both questions should lead to a ToolMessage (i.e. a retrieval) before the answer.
example_questions = (
    # Question 1: about text
    "What is iPhone net sales in the year of 2020?",
    # Question 2: about a table
    "tell me about table, which shows net sales by category for 2022, 2021 and 2020?",
)

for input_message in example_questions:
    initial_state = {"messages": [{"role": "user", "content": input_message}]}
    # stream_mode="values" yields the state after every step; show the newest message.
    for step in graph.stream(initial_state, stream_mode="values"):
        step["messages"][-1].pretty_print()

# 5. make a Multi_RAG_Agent (after combining the conversation memory and retriever-multi_vector: text, tables)
from langgraph.prebuilt import create_react_agent

# The agent gets an explicit name: the supervisor built later in this notebook
# identifies its agents by name and needs each one to have a unique name.
Multi_RAG_Agent = create_react_agent(llm, [retrieve], name="multi_rag_agent")

# inspect the graph:
display(Image(Multi_RAG_Agent.get_graph().draw_mermaid_png()))

# give a question that would typically require an iterative sequence of retrieval steps to answer:
# NOTE(review): no checkpointer is passed to create_react_agent, so the
# thread_id below presumably has no effect on memory - confirm.
config = {"configurable": {"thread_id": "def234"}}

input_message = (
    "What is the Total net sales in the Year 2020?\n\n"
    "Once you get the answer, look up Net sales by category, "
    "which products were included and how much of each share was."
)

for event in Multi_RAG_Agent.stream(
    {"messages": [{"role": "user", "content": input_message}]},
    stream_mode="values",
    config=config,
):
    event["messages"][-1].pretty_print()

In [None]:
# "Web_Search_Agent.ipynb" from Hanna

In [None]:
# "Data_Science_Agent.ipynb" from Hanna 

In [None]:
# Supervisor_Agent
from langgraph_supervisor import create_supervisor
from langchain.chat_models import init_chat_model

# NOTE: Web_Search_Agent and Data_Science_Agent must be defined by the cells
# above (they are placeholders in this notebook), and every agent passed here
# needs its own unique name.
supervisor = create_supervisor(
    model=init_chat_model("gemini-2.0-flash", model_provider="google_genai"), # use Google Gemini instead of OpenAI
    agents=[Multi_RAG_Agent, Web_Search_Agent, Data_Science_Agent],
    # The prompt describes the three agents that are actually registered above
    # (it previously described a research agent and a math agent).
    prompt=(
        "You are a supervisor managing three agents:\n"
        "- a document RAG agent. Assign questions about the content (text and tables) "
        "of the indexed PDF documents to this agent\n"
        "- a web search agent. Assign tasks that need information from the web to this agent\n"
        "- a data science agent. Assign data analysis and calculation tasks to this agent\n"
        "Assign work to one agent at a time, do not call agents in parallel.\n"
        "Do not do any work yourself."
    ),
    add_handoff_back_messages=True,
    output_mode="full_history",
).compile()

In [None]:
# create a Gradio chat interface using a LangChain chat model
import gradio as gr
from langchain_core.messages import HumanMessage, AIMessage
from langchain_google_genai import ChatGoogleGenerativeAI
import os


# Route chat requests through the compiled supervisor graph.
# NOTE: this rebinds the name `model`, which earlier in the notebook refers to
# the ChatGoogleGenerativeAI instance used by the summary chain.
model = supervisor

def respond(
    message: str,
    history: list,  # [[user_msg, ai_msg], ...] or [{"role": ..., "content": ...}, ...]
) -> str:
    """
    Respond to user input using the model.

    Args:
        message: the new user message.
        history: previous turns as passed by Gradio, either as
            ``[user_msg, ai_msg]`` pairs or as ``{"role", "content"}`` dicts.
            Entries whose content is not a string (e.g. ``None``) are skipped.

    Returns:
        The content of the last message produced by the model.
    """
    # Convert Gradio history to LangChain message format
    chat_history = []
    for entry in history or []:
        if isinstance(entry, dict):
            # Messages format: one dict per message.
            role = entry.get("role")
            content = entry.get("content")
            if not isinstance(content, str):
                continue
            if role == "user":
                chat_history.append(HumanMessage(content=content))
            elif role == "assistant":
                chat_history.append(AIMessage(content=content))
            continue

        # Pair format: one [user, assistant] pair per turn; either side may be missing.
        human_msg, ai_msg = entry
        if isinstance(human_msg, str):
            chat_history.append(HumanMessage(content=human_msg))
        if isinstance(ai_msg, str):
            chat_history.append(AIMessage(content=ai_msg))

    # Add the new user message
    chat_history.append(HumanMessage(content=message))

    # Get the AI's response
    response = model.invoke({'messages': chat_history}, config={"configurable": {"thread_id": "thread_123"}})

    return response["messages"][-1].content

# Chat UI: every submitted message is passed to `respond` together with the
# chat history kept by Gradio.
demo = gr.ChatInterface(
    fn=respond,
    # examples=["Hello", "What's AI?", "Tell me a joke"],
    title="Gemini Chat",
)

# Start the Gradio app.
demo.launch()
