In [None]:
from document_ingestion import ingest_documents
from audio_processing import transcribe_audio
from video_processing import extract_frames
from vector_search import search_text

# End-to-end demo: ingest the corpus, query it with text and with an audio
# transcript, then run frame extraction on a sample video.

# 1️⃣ Ingest documents
ingest_documents()

# 2️⃣ Text query
results = search_text(
    "Explain self-attention in transformers"
)
print(results)

# 3️⃣ Audio → Text → Search
transcript = transcribe_audio("data/audio/sample.wav")
audio_results = search_text(transcript)
print(audio_results)

# 4️⃣ Video → Frames
extract_frames("data/video/sample.mp4")


In [None]:
import os
import uuid
import re
from enum import Enum
from typing import List
from neo4j import GraphDatabase

from langchain_groq import ChatGroq
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain.schema import Document, BaseRetriever
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory

from langchain_chroma import Chroma
from langchain_huggingface import HuggingFaceEmbeddings  # ✅ Updated import

from config import VECTOR_DB_PATH, EMBEDDING_MODEL, NEO4J_URI, NEO4J_USER, NEO4J_PASSWORD

# Public API of this module. Every name listed here must exist at module
# level, otherwise `from <this module> import *` raises AttributeError
# (which is why the `__main__`-only `llm` is not listed).
__all__ = [
    "multimodal_search",
    "rerank_with_llm",
    "save_to_neo4j",
    "run_multimodal_qa",
    "create_chain",
    "detect_query_type",
    "summarize_history",
    "rewrite_query_with_history",
    "MultimodalRetriever",
    "QueryType",
]

# API keys: the OpenAI and Groq LangChain clients read OPENAI_API_KEY and
# GROQ_API_KEY from the environment themselves, so nothing needs to be copied
# into os.environ here (load them, e.g. with python-dotenv, before import).

# Embeddings
# Aliased so the transformers logger does not shadow the stdlib `logging` name.
from transformers.utils import logging as hf_logging
hf_logging.set_verbosity_error()  # hides transformers warnings

# Text/video and audio stores use the configured OpenAI embedding model; the
# image store uses CLIP, and multimodal_search queries it with the text query.
embeddings = OpenAIEmbeddings(model=EMBEDDING_MODEL)
image_embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/clip-ViT-B-32")

# Neo4j driver shared by save_to_neo4j. It is closed explicitly at interpreter
# exit instead of relying on the driver's destructor to release connections.
import atexit

driver = GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USER, NEO4J_PASSWORD))
atexit.register(driver.close)

# Optional verbose flag (load_chroma_if_exists reports skipped stores when True)
VERBOSE = False

def load_chroma_if_exists(path, embedding_func):
    """Open the persisted Chroma store at *path*, or return None if it has no data.

    Args:
        path: Directory the Chroma store was persisted to.
        embedding_func: Embedding function used to embed queries for this store.

    Returns:
        A ``Chroma`` instance when *path* is a non-empty directory, else ``None``.
        A path that exists but is a regular file is treated as "not found"
        rather than crashing in ``os.listdir``.
    """
    if os.path.isdir(path) and os.listdir(path):
        return Chroma(persist_directory=path, embedding_function=embedding_func)
    if VERBOSE:
        print(f"⚠️ Skipping vector store at {path} (not found or empty)")
    return None

# Vector stores (each is None when its directory is missing or empty).
# Text/video lives at the root path; audio and image live in sub-directories.
text_video_store = load_chroma_if_exists(VECTOR_DB_PATH, embeddings)
audio_store, image_store = (
    load_chroma_if_exists(os.path.join(VECTOR_DB_PATH, sub_dir), embedding_fn)
    for sub_dir, embedding_fn in (
        ("audio_db", embeddings),
        ("image_db", image_embeddings),
    )
)

def reciprocal_rank_fusion(tv_results, audio_results, image_results, k: int = 10, rrf_k: int = 60) -> "List[Document]":
    """Fuse three ranked ``(doc, score)`` lists with Reciprocal Rank Fusion.

    Every appearance of a document at 1-based position ``rank`` in a list adds
    ``1 / (rank + rrf_k)`` to its fused score; the stores' own scores are
    ignored. Documents are deduplicated by ``page_content``.

    Args:
        tv_results: Ranked hits from the text/video store.
        audio_results: Ranked hits from the audio store.
        image_results: Ranked hits from the image store.
        k: Maximum number of fused documents to return.
        rrf_k: RRF smoothing constant (60 is the value from the original RRF paper).

    Returns:
        Up to *k* documents, best fused score first (ties keep first-seen order).
        The returned doc is the first object seen for its content. Its
        ``metadata["source_type"]`` is set to the modality it was first seen in
        ("Document", "Audio" or "Image"), so label and object always agree.
    """
    fused = {}
    modalities = (
        ("Document", tv_results),
        ("Audio", audio_results),
        ("Image", image_results),
    )
    for source_name, results in modalities:
        for rank, (doc, _score) in enumerate(results, start=1):
            entry = fused.get(doc.page_content)
            if entry is None:
                # First sighting: keep this object and label it with its modality.
                doc.metadata["source_type"] = source_name
                entry = fused[doc.page_content] = {"doc": doc, "score": 0.0}
            entry["score"] += 1 / (rank + rrf_k)
    ranked = sorted(fused.values(), key=lambda e: e["score"], reverse=True)
    return [e["doc"] for e in ranked[:k]]

def multimodal_search(query: str, k: int = 10, threshold: float = 0.3) -> List[Document]:
    """Search every available modality store and fuse the hits with RRF.

    Args:
        query: Natural-language query; each store embeds it with its own
            embedding function.
        k: Hits requested from each store, and maximum size of the fused result.
        threshold: Minimum relevance score (higher = more similar) a hit needs
            to be kept.

    Returns:
        Up to *k* documents from ``reciprocal_rank_fusion``, best first. Stores
        that were not loaded (``None``) contribute nothing.
    """
    def _hits(store):
        if store is None:
            return []
        # Chroma's similarity_search_with_score returns a *distance* (lower =
        # closer), so a ">= threshold" filter would keep the worst matches.
        # Relevance scores are normalised so that higher means more similar.
        return [
            (doc, score)
            for doc, score in store.similarity_search_with_relevance_scores(query, k=k)
            if score >= threshold
        ]

    return reciprocal_rank_fusion(
        _hits(text_video_store), _hits(audio_store), _hits(image_store), k
    )

def rerank_with_llm(query: str, docs: "List[Document]", top_n: int = 3, llm=None) -> "List[Document]":
    """Ask *llm* to pick and order the *top_n* most relevant of *docs*.

    Each document is shown to the model as a numbered, single-line snippet of
    at most 400 characters. Every integer in the reply that is a valid
    1-based document number is taken in order; duplicates are skipped.

    Args:
        query: The user query the documents should answer.
        docs: Candidate documents (objects with ``page_content``).
        top_n: Number of documents to request and to return at most.
        llm: Chat model exposing ``invoke(prompt).content``. When falsy, no
            reranking happens.

    Returns:
        Up to *top_n* documents in the model's order. If there is no llm, no
        docs, or the reply names no valid document, the first *top_n* docs
        are returned unchanged.
    """
    if not docs or not llm:
        return docs[:top_n]

    numbered = "".join(
        f"{i}. {doc.page_content.replace(chr(10), ' ')[:400]}\n"
        for i, doc in enumerate(docs, start=1)
    )
    prompt = (
        f"You are a reranker. Query: {query}\nDocuments:\n"
        f"{numbered}"
        f"\nReturn top {top_n} document numbers in order."
    )
    resp = llm.invoke(prompt).content

    ranked = []
    seen = set()
    for token in re.findall(r"\d+", resp):
        idx = int(token)
        if 1 <= idx <= len(docs) and idx not in seen:
            seen.add(idx)
            ranked.append(docs[idx - 1])
            if len(ranked) == top_n:
                break
    # An unusable reply must not look like "no documents" to the caller.
    return ranked or docs[:top_n]

class MultimodalRetriever(BaseRetriever):
    """LangChain retriever: fused multimodal search followed by LLM reranking.

    Candidates come from ``multimodal_search`` (RRF over the text/video, audio
    and image stores). They are then narrowed with ``rerank_with_llm`` using
    the LLM supplied at construction time.
    """

    def __init__(self, rerank_llm):
        super().__init__()
        # NOTE(review): the leading underscore keeps this out of BaseRetriever's
        # declared (pydantic) fields — confirm against the installed langchain.
        self._rerank_llm = rerank_llm

    def _get_relevant_documents(self, query: str) -> List[Document]:
        """Return the reranked documents for *query*."""
        return rerank_with_llm(
            query,
            multimodal_search(query),
            llm=self._rerank_llm,
        )

def summarize_history(history_msgs, llm) -> str:
    """Summarise the last 10 conversation messages with *llm*.

    Args:
        history_msgs: Chat messages. Those whose ``type`` is ``"human"`` are
            rendered as ``User:``; all others as ``Bot:``.
        llm: Chat model exposing ``invoke(prompt).content``.

    Returns:
        The stripped model summary, or a fixed notice when there is no history
        (in that case the model is not called).
    """
    if not history_msgs:
        return "We haven’t had any conversation yet in this session."
    lines = []
    for m in history_msgs[-10:]:
        # Empty message content must not fall through to a missing attribute.
        text = getattr(m, "content", "") or getattr(m, "page_content", "")
        speaker = "User" if getattr(m, "type", "") == "human" else "Bot"
        lines.append(f"{speaker}: {text}")
    history_text = "\n".join(lines)
    prompt = f"Summarize the key topics from this conversation history:\n{history_text}\nProvide a concise summary."
    return llm.invoke(prompt).content.strip()

def create_chain(llm):
    """Build a conversational QA chain over the multimodal retriever.

    The same *llm* answers, condenses follow-up questions, and reranks the
    retrieved documents.

    Returns:
        ``(chain, memory)``. The memory is attached to the chain under the
        ``chat_history`` key and returns message objects. NOTE(review):
        LangChain saves each chain call's question/answer into this memory, so
        callers that also add messages manually can duplicate turns.
    """
    chat_memory = ConversationBufferMemory(
        return_messages=True,
        memory_key="chat_history",
        input_key="question",
        output_key="answer",
    )
    retriever = MultimodalRetriever(rerank_llm=llm)
    qa_chain = ConversationalRetrievalChain.from_llm(
        llm=llm,
        retriever=retriever,
        memory=chat_memory,
        condense_question_llm=llm,
        return_source_documents=False,
    )
    return qa_chain, chat_memory

def save_to_neo4j(user_id: str, session_id: str, query: str, answer: str, sources: list[str]):
    """Persist one question/answer turn, plus the sources it used, to Neo4j.

    Graph written::

        (:User {id})-[:HAS_SESSION]->(:Session {id, date})-[:ASKED]->
        (:Query {text, timestamp})-[:ANSWERED_BY]->(:Answer {text})
        (:Answer)-[:USED_SOURCE]->(:Source {type})   # one per entry in sources

    User, Session, Answer and Source nodes are MERGEd on the listed keys. In
    particular, Answer nodes are keyed by text, so identical answers share a node.

    Everything runs as a single Cypher statement (UNWIND over *sources*), so
    the turn and its sources are written atomically in one round trip.
    """
    with driver.session() as session:
        session.run(
            """
            MERGE (u:User {id: $user_id})
            MERGE (s:Session {id: $session_id, date: date()})
            MERGE (u)-[:HAS_SESSION]->(s)
            MERGE (q:Query {text: $query, timestamp: datetime()})
            MERGE (a:Answer {text: $answer})
            MERGE (s)-[:ASKED]->(q)-[:ANSWERED_BY]->(a)
            WITH a
            UNWIND $sources AS source_type
            MERGE (src:Source {type: source_type})
            MERGE (a)-[:USED_SOURCE]->(src)
            """,
            parameters={
                "user_id": user_id,
                "session_id": session_id,
                "query": query,
                "answer": answer,
                "sources": list(sources),
            },
        ).consume()  # surface any server error here, inside this call

# How an incoming query is handled by run_multimodal_qa:
#   HISTORY   - asks about the conversation itself; answered by summarising it.
#   FOLLOW_UP - refers back to earlier turns; rewritten to a standalone question.
#   DOCUMENT  - a fresh question answered through multimodal retrieval.
QueryType = Enum(
    "QueryType",
    [
        ("HISTORY", "history"),
        ("FOLLOW_UP", "follow_up"),
        ("DOCUMENT", "document"),
    ],
)

# Phrasings that ask about the conversation itself. They are matched against
# the lower-cased, stripped query, so they are written in lower case.
_HISTORY_QUERY_RES = tuple(
    re.compile(pattern)
    for pattern in (
        r"what.*we.*\b(discuss|talk|cover|speak|spoke)\b",
        r"on what topic.*we.*\b(speak|spoke|discuss|covered)\b",
        r"what.*previous.*\b(conversation|discussion)\b",
        r"tell me.*\b(earlier|so far|before)\b",
        r"remind me",
        r"i forgot",
        r"show me our \b(chat|conversation|history)\b",
        r"what topics.*we.*(spoke|discussed)",
    )
)

# Pronouns / phrases suggesting the query leans on an earlier turn.
_FOLLOW_UP_QUERY_RES = tuple(
    re.compile(pattern)
    for pattern in (
        r"\b(this|it|that|these|those)\b",
        r"what about",
        r"what is the use of",
        r"how does it work",
    )
)


def detect_query_type(query: str, history_msgs) -> QueryType:
    """Classify *query* as a history request, a follow-up, or a document query.

    History phrasings win regardless of *history_msgs*. Follow-up phrasings
    only count when *history_msgs* is truthy (there is something to follow
    up on). Everything else is a DOCUMENT query.
    """
    normalized = query.lower().strip()
    if any(rx.search(normalized) for rx in _HISTORY_QUERY_RES):
        return QueryType.HISTORY
    if history_msgs and any(rx.search(normalized) for rx in _FOLLOW_UP_QUERY_RES):
        return QueryType.FOLLOW_UP
    return QueryType.DOCUMENT

def rewrite_query_with_history(query: str, history_msgs, question_rewriter) -> str:
    """Turn a follow-up *query* into a standalone question using recent history.

    Args:
        query: The user's (possibly context-dependent) question.
        history_msgs: Chat messages; only the last 6 are used. Messages whose
            ``type`` is ``"human"`` are rendered as ``User:``, others as ``Bot:``.
        question_rewriter: Chat model exposing ``invoke(prompt).content``.

    Returns:
        The stripped rewritten question, or *query* unchanged if the model
        returns nothing usable.
    """
    hist = []
    for m in history_msgs[-6:]:
        # Empty message content must not fall through to a missing attribute.
        text = getattr(m, "content", "") or getattr(m, "page_content", "")
        speaker = "User" if getattr(m, "type", "") == "human" else "Bot"
        hist.append(f"{speaker}: {text}")
    history_text = "\n".join(hist)
    prompt = f"Conversation history:\n{history_text}\nRewrite the question '{query}' into a standalone question."
    rewritten = question_rewriter.invoke(prompt).content.strip()
    return rewritten or query

# Model replies that start with "I don't know" (straight or curly apostrophe)
# trigger the plain-LLM fallback.
_IDK_PATTERN = re.compile(r"(?i)i\s+don['’]?t\s+know")


def _record_turn(memory, user_id, session_id, query, answer, sources):
    """Append one user/AI exchange to *memory* and persist it to Neo4j."""
    memory.chat_memory.add_user_message(query)
    memory.chat_memory.add_ai_message(answer)
    save_to_neo4j(user_id, session_id, query, answer, sources)


def run_multimodal_qa(user_id: str, query: str, input_type: str, file_path: str = None, llm=None,
                      qa_chain=None, memory=None, session_id: str = None):
    """Answer *query* using history handling, multimodal retrieval and LLM fallback.

    Flow:
    1. A HISTORY query is answered by summarising recent history.
    2. A FOLLOW_UP query is first rewritten into a standalone question.
    3. Retrieval runs. With no documents, the LLM answers directly; otherwise
       the QA chain answers, and an "I don't know" reply falls back to the LLM.

    Every turn is recorded once in memory (under the user's original wording)
    and saved to Neo4j.

    Args:
        user_id: Neo4j user id.
        query: The user's question text.
        input_type, file_path: Accepted for API compatibility; not used here.
        llm: Chat model used for answering, rewriting, summarising and
            reranking. Required.
        qa_chain, memory: Optional pair returned by ``create_chain``. Pass the
            same pair across calls to keep conversation history. When either
            is omitted, a fresh pair is built, so the call starts with empty
            history and HISTORY/FOLLOW_UP handling has nothing to work with.
        session_id: Neo4j session id. Defaults to a new random
            ``"{user_id}_xxxxxx"`` id.

    Returns:
        ``{"answer": str, "source": "summary" | "llm" | "retriever"}``.

    Raises:
        ValueError: If *llm* is None.
    """
    if llm is None:
        raise ValueError("run_multimodal_qa requires an llm")
    if session_id is None:
        session_id = f"{user_id}_{uuid.uuid4().hex[:6]}"
    if qa_chain is None or memory is None:
        qa_chain, memory = create_chain(llm)
    history_msgs = memory.chat_memory.messages[-6:]

    qtype = detect_query_type(query, history_msgs)

    if qtype == QueryType.HISTORY:
        summary = summarize_history(history_msgs, llm)
        _record_turn(memory, user_id, session_id, query, summary, ["ChatHistory"])
        return {"answer": summary, "source": "summary"}

    final_query = rewrite_query_with_history(query, history_msgs, llm) if qtype == QueryType.FOLLOW_UP else query
    docs = MultimodalRetriever(rerank_llm=llm)._get_relevant_documents(final_query)

    if not docs:
        answer = llm.invoke(final_query).content.strip()
        _record_turn(memory, user_id, session_id, query, answer, ["LLMOnly"])
        return {"answer": answer, "source": "llm"}

    n_before = len(memory.chat_memory.messages)
    result = qa_chain({"question": final_query})
    # The chain's attached memory saves (final_query, answer) by itself. Drop
    # that pair so the turn is recorded exactly once below, under the user's
    # wording. NOTE(review): assumes `messages` is the live list of the
    # default in-memory history; if it is a copy, this is a harmless no-op.
    messages = memory.chat_memory.messages
    if len(messages) > n_before:
        del messages[n_before:]
    final_answer = result.get("answer", "").strip()

    if _IDK_PATTERN.match(final_answer):
        final_answer = llm.invoke(final_query).content.strip()
        _record_turn(memory, user_id, session_id, query, final_answer, ["LLMOnly"])
        return {"answer": final_answer, "source": "llm"}

    _record_turn(memory, user_id, session_id, query, final_answer, [])
    return {"answer": final_answer, "source": "retriever"}


if __name__ == "__main__":
    # Interactive smoke test: one fused search across every available store,
    # printing a short preview of each hit. No LLM is needed here because
    # multimodal_search does not rerank.
    query = input("Enter your search query: ").strip()

    if not query:
        print("❌ Please enter a non-empty query.")
    else:
        results = multimodal_search(query)
        if not results:
            print("❌ No results found in any modality.")
        else:
            print(f"\n✅ Found {len(results)} results:")
            for i, doc in enumerate(results, start=1):
                print(f"\n--- Result {i} ---")
                print(f"Source Type: {doc.metadata.get('source_type', 'Unknown')}")
                print(f"Content Preview: {doc.page_content[:200]}...")




In [None]:
import streamlit as st
from PIL import Image
import tempfile
import os
from langchain_groq import ChatGroq
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from dotenv import load_dotenv
load_dotenv()

from vector_search import (
    create_chain,
    save_to_neo4j,
    detect_query_type,
    summarize_history,
    rewrite_query_with_history,
    MultimodalRetriever,
    run_multimodal_qa
)
from neo4j import GraphDatabase
from config import NEO4J_URI, NEO4J_USER, NEO4J_PASSWORD

# ------------------ Neo4j Load History ------------------
def load_user_history_from_neo4j(user_id: str):
    """Return the user's stored ``(question, answer)`` pairs, oldest first.

    Reads every Query/Answer pair reachable from the user's sessions, ordered
    by the query's timestamp.
    """
    query = """
    MATCH (u:User {id: $user_id})-[:HAS_SESSION]->(:Session)-[:ASKED]->(q:Query)-[:ANSWERED_BY]->(a:Answer)
    RETURN q.text AS question, a.text AS answer
    ORDER BY q.timestamp
    """
    # Short-lived driver for this single read. The context manager closes it
    # (and its connection pool) instead of leaking one driver per login.
    with GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USER, NEO4J_PASSWORD)) as driver:
        with driver.session() as session:
            results = session.run(query, parameters={"user_id": user_id})
            return [(record["question"], record["answer"]) for record in results]

# ------------------ Session State Setup ------------------
# Seed defaults on the first run only; later Streamlit reruns keep existing values.
for _state_key, _default in (
    ("user_id", None),
    ("page", "start"),
    ("chat_history", []),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _default

# ------------------ Page 1: Landing ------------------
if st.session_state.page == "start":
    st.title("Multimodal Retrieval-Augmented Generation (RAG) Assistant")
    st.markdown("""
    This system integrates document, audio, image, and video data into a unified RAG framework for enhanced contextual understanding.
    It supports LLM-based reasoning with configurable settings and personalized chat history via Neo4j.
    """)

    # Project Info Button
    with st.expander("Project Info", expanded=False):
        st.markdown("""
        ### Multimodal RAG Assistant – Project Description

        This assistant integrates **text, audio, image, and video data** into a unified **Retrieval-Augmented Generation (RAG)** system. It enhances user queries using document search, LLM reasoning, and multimodal embeddings to provide accurate and context-aware answers.

        #### Concepts and Tools Used:
        - **LangChain**: For chaining LLMs and retrievers in a modular pipeline.
        - **ChromaDB**: Vector database to store text, audio, image, and video embeddings.
        - **OpenAI Whisper**: For transcribing audio/video files.
        - **CLIP**: To generate image embeddings.
        - **OpenAI / Groq LLMs**: For answering queries and reranking results.
        - **Neo4j**: To store and retrieve personalized chat history.
        - **RRF (Reciprocal Rank Fusion)**: Combines results from different modalities (text, audio, image, video).
        - **LLM-based Reranking**: Ensures only high-quality relevant results are shown.

        #### Workflow Summary:
        1. **Input**: User provides a query (text) or uploads an image/audio/video file.
        2. **Classification**: Determines whether the query is new, a follow-up, or requires summarization.
        3. **Multimodal Retrieval**: Fetches relevant documents from ChromaDB using embeddings.
        4. **Reranking**: Uses LLMs to rerank retrieved results based on relevance.
        5. **Answer Generation**: Uses QA chains or fallback LLM generation.
        6. **History Storage**: Neo4j logs user-specific query-answer pairs.
        """, unsafe_allow_html=True)

    user_id = st.text_input("Enter your User ID", key="user_id_input")
    st.markdown("---")
    st.markdown("### Model Settings")

    model_choice = st.selectbox("Choose LLM Model", ["OpenAI (gpt-4o-mini)", "Groq (llama3-70b-8192)"])
    temperature = st.slider("Set Temperature", 0.0, 1.0, 0.2, 0.05)
    max_tokens = st.slider("Max Tokens", 128, 4096, 1024, 64)

    if "OpenAI" in model_choice:
        st.markdown("*`GPT-4o-mini` supports ~128k tokens context, recommended max generation: ~2048–4096 tokens.*")
    else:
        st.markdown("*`LLaMA3-70B` supports ~8k tokens context, recommended max generation: ~2048–4096 tokens.*")

    if st.button("Continue"):
        clean_user_id = user_id.strip()
        if clean_user_id:
            # Build everything first. Session state is switched to the chat
            # page only after LLM, chain and history loading all succeeded, so
            # a failure (e.g. missing API key, Neo4j down) leaves the user on
            # this page instead of on a chat page with no LLM/memory.
            if "OpenAI" in model_choice:
                llm = ChatOpenAI(model="gpt-4o-mini", temperature=temperature, max_tokens=max_tokens)
            else:
                llm = ChatGroq(model="llama3-70b-8192", temperature=temperature, max_tokens=max_tokens)

            qa_chain, memory = create_chain(llm)
            for q, a in load_user_history_from_neo4j(clean_user_id):
                memory.chat_memory.add_user_message(q)
                memory.chat_memory.add_ai_message(a)

            st.session_state.llm = llm
            st.session_state.qa_chain = qa_chain
            st.session_state.memory = memory
            st.session_state.user_id = clean_user_id
            st.session_state.page = "chat"
            st.rerun()
        else:
            st.warning("User ID cannot be empty.")

# ------------------ Page 2: Chat UI ------------------
elif st.session_state.page == "chat":
    st.title("Ask Your Question")
    st.markdown(f"**User ID:** `{st.session_state.user_id}`")

    if st.button("Sign Out"):
        st.session_state.page = "start"
        st.session_state.user_id = None
        st.session_state.chat_history = []
        del st.session_state.qa_chain
        del st.session_state.memory
        del st.session_state.llm
        st.rerun()

    col1, col2 = st.columns(2)
    with col1:
        input_type = st.radio("Choose input type", ["Text", "Image", "Audio", "Video"], key="input_type_radio")
    with col2:
        if st.button("Show History"):
            with st.expander("Conversation History", expanded=True):
                for item in st.session_state.chat_history:
                    st.markdown(f"**You:** {item['user']}")
                    label = item['source']
                    label_text = (
                        "**Bot (Document):** " if label == "retriever" else
                        "**LLM-only:** " if label == "llm" else
                        "**Bot (summary):** " if label == "summary" else
                        "**Bot:** "
                    )
                    st.markdown(f"{label_text} {item['bot']}")
            st.stop()

    query = None
    file_path = None

    if input_type == "Text":
        query = st.text_area("Type your question here", height=100)
    elif input_type == "Image":
        uploaded_file = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg"])
        if uploaded_file:
            image = Image.open(uploaded_file)
            st.image(image, caption="Uploaded Image", use_column_width=True)
            query = st.text_area("Ask a question about this image", height=100)
            with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmp:
                image.save(tmp.name)
                file_path = tmp.name
    elif input_type == "Audio":
        uploaded_file = st.file_uploader("Upload an audio file", type=["mp3", "wav"])
        if uploaded_file:
            st.audio(uploaded_file)
            # No text box for audio; user only uploads audio
            with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(uploaded_file.name)[1]) as tmp:
                tmp.write(uploaded_file.getbuffer())
                file_path = tmp.name
    elif input_type == "Video":
        uploaded_file = st.file_uploader("Upload a video file", type=["mp4", "mov", "avi"])
        if uploaded_file:
            st.video(uploaded_file)
            # No text box for video; user only uploads video
            with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(uploaded_file.name)[1]) as tmp:
                tmp.write(uploaded_file.getbuffer())
                file_path = tmp.name

    if st.button("Submit"):
        if input_type == "Text":
            if not query or not query.strip():
                st.warning("Please enter a question.")
                st.stop()
        else:
            # For media inputs, text query is not mandatory
            if not file_path:
                st.warning(f"Please upload a {input_type.lower()} file.")
                st.stop()

        if input_type == "Text" and query.strip().lower() == "exit":
            st.session_state.page = "start"
            st.session_state.user_id = None
            st.session_state.chat_history = []
            del st.session_state.qa_chain
            del st.session_state.memory
            del st.session_state.llm
            st.success("You have been logged out.")
            st.rerun()

        with st.spinner("Processing..."):
            memory = st.session_state.memory
            llm = st.session_state.llm
            history_msgs = memory.chat_memory.messages[-6:]
            # For media input, we transcribe file to text query
            if input_type in ["Audio", "Video"]:
                from openai import OpenAI
                client = OpenAI()
                with open(file_path, "rb") as f:
                    transcript = client.audio.transcriptions.create(model="whisper-1", file=f)
                query = transcript.text.strip()
                if not query:
                    st.warning(f"Could not transcribe the {input_type.lower()}. Please try again.")
                    st.stop()

            qtype = detect_query_type(query if input_type == "Text" else query, history_msgs)

            if qtype.name == "HISTORY":
                summary = summarize_history(history_msgs, llm)
                memory.chat_memory.add_user_message(query)
                memory.chat_memory.add_ai_message(summary)
                source = "summary"
                answer = summary
            else:
                final_query = rewrite_query_with_history(query, history_msgs, llm) if qtype.name == "FOLLOW_UP" else query
                retriever = MultimodalRetriever(rerank_llm=llm)
                docs = retriever._get_relevant_documents(final_query)

                if not docs:
                    answer = llm.invoke(final_query).content.strip()
                    memory.chat_memory.add_user_message(query)
                    memory.chat_memory.add_ai_message(answer)
                    save_to_neo4j(st.session_state.user_id, f"{st.session_state.user_id}_streamlit", query, answer, ["LLMOnly"])
                    source = "llm"
                else:
                    result = st.session_state.qa_chain({"question": final_query})
                    answer = result.get("answer", "").strip()

                    if "i don't know" in answer.lower():
                        answer = llm.invoke(final_query).content.strip()
                        memory.chat_memory.add_user_message(query)
                        memory.chat_memory.add_ai_message(answer)
                        save_to_neo4j(st.session_state.user_id, f"{st.session_state.user_id}_streamlit", query, answer, ["LLMOnly"])
                        source = "llm"
                    else:
                        memory.chat_memory.add_user_message(query)
                        memory.chat_memory.add_ai_message(answer)
                        save_to_neo4j(st.session_state.user_id, f"{st.session_state.user_id}_streamlit", query, answer, [])
                        source = "retriever"

            st.markdown(f"**You:** {query if input_type == 'Text' else f'[{input_type} file uploaded]'}")
            label_text = (
                "**Bot (Document):** " if source == "retriever" else
                "**LLM-only:** " if source == "llm" else
                "**Bot (summary):** " if source == "summary" else
                "**Bot:** "
            )
            st.markdown(f"{label_text} {answer}")

            st.session_state.chat_history.append({
                "user": query if input_type == "Text" else f"[{input_type} file uploaded]",
                "bot": answer,
                "source": source
            })

        # Clean up temp file if used
        if file_path and os.path.exists(file_path):
            os.remove(file_path)
