In [1]:
# OpenAI Vision Multimodal RAG (Phase 1)
# ------------------------------------------------------------
# A fully working, single-file demo that replaces Gemini with
# OpenAI Vision and uses OpenAI embeddings + ChromaDB.
#
# Setup (recommended Python 3.10+):
#   pip install gradio chromadb langchain pypdf pillow python-docx python-pptx \
#               langchain-openai openai PyPDF2
#
# Required environment variable or pass via UI:
#   OPENAI_API_KEY
# ------------------------------------------------------------

import os
import io
import base64
import uuid
from datetime import datetime
from typing import List, Dict

import gradio as gr

# Vector DB
import chromadb
from chromadb.config import Settings

# Chunking
from langchain.text_splitter import RecursiveCharacterTextSplitter

# Docs parsing
import PyPDF2
from docx import Document
from pptx import Presentation
from PIL import Image

# OpenAI (LLM + Embeddings)
from langchain_openai import OpenAIEmbeddings
from openai import OpenAI


# -----------------------------
# Multimodal RAG Implementation
# -----------------------------
class MultimodalRAG:
    """Retrieval-augmented Q&A over uploaded documents and images.

    Text files (PDF, DOCX, PPTX, TXT) are split into overlapping chunks and
    embedded with OpenAI embeddings; images are described by an OpenAI vision
    model and that description is embedded instead. Both kinds of record are
    stored via a ChromaDB ``PersistentClient`` rooted at ``./knowledge_base``
    and tagged with the uploading ``session_id``.
    """

    # Lower-case extensions (no dot) routed to the text and image ingestion paths.
    TEXT_TYPES = ("pdf", "docx", "pptx", "txt")
    IMAGE_TYPES = ("jpg", "jpeg", "png", "gif", "bmp", "webp")

    # Chat model used for both image analysis and answer generation.
    CHAT_MODEL = "gpt-4o-mini"

    # Image formats sent to the vision endpoint unchanged. Anything else
    # (notably BMP, which OpenAI vision input does not accept) is re-encoded
    # to PNG before upload.
    _VISION_MIME_TYPES = {
        ".png": "image/png",
        ".jpg": "image/jpeg",
        ".jpeg": "image/jpeg",
        ".gif": "image/gif",
        ".webp": "image/webp",
    }

    def __init__(self, openai_api_key: str, embedding_model: str = "text-embedding-3-small"):
        """Create the OpenAI clients, the Chroma collections and the text splitter.

        Args:
            openai_api_key: OpenAI API key used for chat, vision and embeddings.
            embedding_model: OpenAI embedding model name.

        Raises:
            ValueError: if ``openai_api_key`` is empty.
            RuntimeError: if the embeddings client or ChromaDB cannot be initialized.
        """
        if not openai_api_key:
            raise ValueError("OpenAI API key is required")

        # Also exported, presumably so other OpenAI/LangChain clients created in
        # this process can read it from the environment.
        os.environ["OPENAI_API_KEY"] = openai_api_key
        self.client = OpenAI(api_key=openai_api_key)

        # Embeddings
        try:
            self.embeddings = OpenAIEmbeddings(
                api_key=openai_api_key,
                model=embedding_model,
            )
        except Exception as e:
            raise RuntimeError(f"Failed to initialize OpenAIEmbeddings: {e}") from e

        # ChromaDB (on-disk store)
        try:
            self.chroma_client = chromadb.PersistentClient(
                path="./knowledge_base",
                settings=Settings(anonymized_telemetry=False, allow_reset=True),
            )
        except Exception as e:
            raise RuntimeError(f"Failed to initialize ChromaDB: {e}") from e

        # Collections
        self.text_collection = self.chroma_client.get_or_create_collection(
            name="text_documents", metadata={"description": "Text-based documents"}
        )
        self.image_collection = self.chroma_client.get_or_create_collection(
            name="image_documents", metadata={"description": "Image-derived content (via OpenAI Vision)"}
        )

        # Chunker
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000, chunk_overlap=200, length_function=len
        )

        # Light session memory (in-memory only; lost on restart while Chroma data persists)
        self.session_memory: Dict[str, List[Dict]] = {}

    # ---------
    # Utilities
    # ---------
    @staticmethod
    def _now_iso() -> str:
        """Return the current local time as an ISO-8601 string."""
        return datetime.now().isoformat()

    @staticmethod
    def _new_id(prefix: str) -> str:
        """Return a unique Chroma record id that starts with ``prefix``."""
        # A second-resolution timestamp collides when the same file name is
        # uploaded twice within one second; a random UUID cannot.
        return f"{prefix}_{uuid.uuid4().hex}"

    # ---------------
    # File Extraction
    # ---------------
    def _extract_text(self, file_path: str, file_type: str) -> str:
        """Extract raw text from a PDF/DOCX/PPTX/TXT file.

        Raises:
            ValueError: if ``file_type`` is not one of ``TEXT_TYPES``.
            Exception: whatever the underlying parser raises for unreadable files.
        """
        if file_type == "pdf":
            with open(file_path, "rb") as f:
                pdf_reader = PyPDF2.PdfReader(f)
                # Some PDFs return None for empty pages; guard for that
                pages = [(page.extract_text() or "") for page in pdf_reader.pages]
            return "\n".join(pages).strip()

        if file_type == "docx":
            doc = Document(file_path)
            return "\n".join(p.text for p in doc.paragraphs)

        if file_type == "pptx":
            prs = Presentation(file_path)
            return "\n".join(
                shape.text
                for slide in prs.slides
                for shape in slide.shapes
                if hasattr(shape, "text")
            )

        if file_type == "txt":
            with open(file_path, "r", encoding="utf-8", errors="ignore") as f:
                return f.read()

        raise ValueError(f"Unsupported text file type: {file_type}")

    def extract_text_from_file(self, file_path: str, file_type: str) -> str:
        """Return the file's text, or a human-readable message on failure.

        Kept for backward compatibility: failures are reported in-band as
        ``"Unsupported text file type: …"`` or ``"Error extracting text: …"``.
        Internal ingestion uses ``_extract_text`` so genuine text that happens
        to start with "Error" is not mistaken for a failure.
        """
        if file_type not in self.TEXT_TYPES:
            return f"Unsupported text file type: {file_type}"
        try:
            return self._extract_text(file_path, file_type)
        except Exception as e:
            return f"Error extracting text: {e}"

    # ---------------------------
    # Vision (OpenAI) for images
    # ---------------------------
    def _image_to_data_url(self, image_path: str) -> str:
        """Encode an image file as a base64 ``data:`` URL for the vision API.

        Formats in ``_VISION_MIME_TYPES`` are sent byte-for-byte; any other
        extension (e.g. ``.bmp``) is decoded with Pillow and re-encoded as PNG.
        """
        ext = os.path.splitext(image_path)[1].lower()
        mime = self._VISION_MIME_TYPES.get(ext)
        if mime is not None:
            with open(image_path, "rb") as f:
                raw = f.read()
        else:
            buf = io.BytesIO()
            with Image.open(image_path) as img:
                # PNG cannot store e.g. CMYK; normalise such modes to RGB.
                if img.mode not in ("1", "L", "LA", "P", "RGB", "RGBA"):
                    img = img.convert("RGB")
                img.save(buf, format="PNG")
            raw = buf.getvalue()
            mime = "image/png"
        return f"data:{mime};base64,{base64.b64encode(raw).decode('ascii')}"

    def _analyze_image(self, image_path: str) -> str:
        """Run OCR + description + key-info extraction on an image.

        Returns the model's text (``""`` if it returned no content).

        Raises:
            Exception: on file or OpenAI API errors.
        """
        data_url = self._image_to_data_url(image_path)
        prompt = (
            "Analyze this image and provide:\n"
            "1) Any text visible (OCR)\n"
            "2) A detailed description\n"
            "3) Key information or concepts shown\n"
            "4) Context that might be useful for retrieval/search\n"
            "Return clear sections with headings."
        )

        # Using Chat Completions for broad compatibility
        completion = self.client.chat.completions.create(
            model=self.CHAT_MODEL,
            messages=[
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": prompt},
                        {"type": "image_url", "image_url": {"url": data_url}},
                    ],
                }
            ],
            temperature=0.2,
        )
        # content may be None (e.g. a refusal); normalise so callers can test emptiness.
        return (completion.choices[0].message.content or "").strip()

    def process_image_with_openai(self, image_path: str) -> str:
        """OCR + description + key info using the vision model.

        Failures are returned in-band as ``"Error processing image with OpenAI Vision: …"``.
        """
        try:
            return self._analyze_image(image_path)
        except Exception as e:
            return f"Error processing image with OpenAI Vision: {e}"

    # --------------------
    # Ingestion / Indexing
    # --------------------
    def add_document(self, file_path: str, file_name: str, session_id: str = "default") -> str:
        """Ingest one file into the knowledge base.

        Args:
            file_path: Path of the file on disk.
            file_name: Display name; its extension selects the ingestion path.
            session_id: Session tag stored with every record.

        Returns:
            A "✅ …" success or "❌ …" failure message; never raises.
        """
        try:
            ext = file_name.lower().split(".")[-1]
            if ext in self.TEXT_TYPES:
                return self._add_text_document(file_path, file_name, ext, session_id)
            if ext in self.IMAGE_TYPES:
                return self._add_image_document(file_path, file_name, session_id)
            return f"❌ Unsupported file type: .{ext}"
        except Exception as e:
            return f"❌ Error processing {file_name}: {e}"

    def _add_text_document(self, file_path: str, file_name: str, ext: str, session_id: str) -> str:
        """Extract, chunk, embed and store a text document; return a status message."""
        try:
            content = self._extract_text(file_path, ext)
        except Exception as e:
            return f"❌ Failed to extract content from {file_name}: {e}"

        chunks = self.text_splitter.split_text(content) if content and content.strip() else []
        if not chunks:
            return f"❌ No text found in {file_name}"

        # Embed in batches for efficiency
        try:
            vectors = self.embeddings.embed_documents(chunks)
        except Exception as e:
            return f"❌ Embedding error for {file_name}: {e}"

        timestamp = self._now_iso()
        ids = [self._new_id(f"{file_name}_{i}") for i in range(len(chunks))]
        metadatas = [
            {
                "file_name": file_name,
                "file_type": ext,
                "chunk_index": i,
                "session_id": session_id,
                "timestamp": timestamp,
            }
            for i in range(len(chunks))
        ]

        # Persist to Chroma
        self.text_collection.add(
            embeddings=vectors,
            documents=chunks,
            metadatas=metadatas,
            ids=ids,
        )

        # Session memory
        self.session_memory.setdefault(session_id, []).append(
            {
                "file_name": file_name,
                "file_type": ext,
                "chunks_count": len(chunks),
                "timestamp": self._now_iso(),
            }
        )
        return f"✅ Successfully processed {file_name} ({len(chunks)} chunks)"

    def _add_image_document(self, file_path: str, file_name: str, session_id: str) -> str:
        """Describe an image with the vision model, embed and store the description."""
        try:
            analysis = self._analyze_image(file_path)
        except Exception as e:
            return f"❌ Failed to process image {file_name}: {e}"
        if not analysis:
            return f"❌ Failed to process image {file_name}: empty response from the vision model"

        try:
            emb = self.embeddings.embed_query(analysis)
        except Exception as e:
            return f"❌ Embedding error for image {file_name}: {e}"

        self.image_collection.add(
            embeddings=[emb],
            documents=[analysis],
            metadatas=[
                {
                    "file_name": file_name,
                    "file_type": "image",
                    "session_id": session_id,
                    "timestamp": self._now_iso(),
                }
            ],
            ids=[self._new_id(file_name)],
        )

        self.session_memory.setdefault(session_id, []).append(
            {
                "file_name": file_name,
                "file_type": "image",
                "timestamp": self._now_iso(),
            }
        )
        return f"✅ Successfully processed image {file_name}"

    # -------
    # Search
    # -------
    @staticmethod
    def _query_collection(collection, q_emb, n_results: int, where_filter, result_type: str) -> List[Dict]:
        """Query one Chroma collection and flatten the hits into result dicts.

        Empty collections are skipped instead of queried, and ``n_results`` is
        clamped to the collection size.
        """
        available = collection.count()
        if not available:
            return []
        res = collection.query(
            query_embeddings=[q_emb],
            n_results=min(n_results, available),
            where=where_filter,
            include=["documents", "metadatas", "distances"],
        )
        documents = (res or {}).get("documents") or []
        if not documents or not documents[0]:
            return []
        return [
            {"content": doc, "metadata": meta, "distance": dist, "type": result_type}
            for doc, meta, dist in zip(documents[0], res["metadatas"][0], res["distances"][0])
        ]

    def search_knowledge_base(self, query: str, session_id: str = "default", top_k: int = 5) -> List[Dict]:
        """Return up to ``top_k`` text/image hits for ``query``, closest first.

        Each hit is ``{"content", "metadata", "distance", "type"}`` with
        ``type`` in ``{"text", "image"}``. On failure a single hit with
        ``type == "error"`` and the message in ``content`` is returned.
        """
        try:
            q_emb = self.embeddings.embed_query(query)

            # Only filter by session when this process has seen uploads for it;
            # otherwise (e.g. after a restart, since session_memory is in-memory
            # while Chroma persists) every session's records are searched.
            where_filter = {"session_id": session_id} if session_id in self.session_memory else None

            # Ceil split so an odd top_k can still be filled from the two collections.
            per_collection = max(1, (top_k + 1) // 2)

            all_results = self._query_collection(
                self.text_collection, q_emb, per_collection, where_filter, "text"
            ) + self._query_collection(
                self.image_collection, q_emb, per_collection, where_filter, "image"
            )
            all_results.sort(key=lambda x: x.get("distance", 1.0))
            return all_results[:top_k]
        except Exception as e:
            return [
                {
                    "content": f"Search error: {e}",
                    "metadata": {},
                    "distance": 1.0,
                    "type": "error",
                }
            ]

    # ------------------
    # Answer Generation
    # ------------------
    def generate_answer(self, query: str, session_id: str = "default") -> str:
        """Answer ``query`` from retrieved context, appending the source filenames.

        Returns a user-facing string in every case (search or API failures
        included); never raises.
        """
        results = self.search_knowledge_base(query, session_id=session_id, top_k=6)
        hits = [r for r in results if r.get("type") != "error"]
        if not hits:
            errors = [r for r in results if r.get("type") == "error"]
            if errors:
                # Surface the real failure (bad key, network, ...) instead of
                # claiming the knowledge base is empty.
                return f"❌ {errors[0].get('content')}"
            return "I couldn't find relevant information in your knowledge base. Please upload some documents first."

        context_parts = []
        sources = []
        for n, r in enumerate(hits, start=1):
            fn = (r.get("metadata") or {}).get("file_name", "Unknown")
            snippet = (r.get("content") or "")[:800]
            context_parts.append(f"Document {n} ({r['type']}) from {fn}:\n{snippet}\n---")
            sources.append(fn)

        context_block = "\n\n".join(context_parts)
        sys_prompt = (
            "You are a helpful enterprise knowledge assistant. Answer strictly based on the provided context. "
            "If information is missing, clearly say what's missing. Provide concise, actionable answers and cite the source filenames inline like [Source: filename]."
        )

        user_prompt = (
            f"Context:\n{context_block}\n\n"
            f"Question: {query}\n\n"
            "Instructions:\n"
            "1) Use only the context above.\n"
            "2) If context is insufficient, say so and specify what is needed.\n"
            "3) Include short source attributions like [Source: <filename>].\n"
        )

        try:
            completion = self.client.chat.completions.create(
                model=self.CHAT_MODEL,
                temperature=0.2,
                messages=[
                    {"role": "system", "content": sys_prompt},
                    {"role": "user", "content": user_prompt},
                ],
            )
        except Exception as e:
            return f"❌ Error generating answer: {e}"
        answer = (completion.choices[0].message.content or "").strip()

        if sources:
            answer += "\n\n📚 Sources: " + ", ".join(sorted(set(sources)))
        return answer

    # --------------
    # Session status
    # --------------
    def get_session_info(self, session_id: str = "default") -> str:
        """Return a Markdown summary of the files uploaded in ``session_id`` by this process."""
        if session_id not in self.session_memory:
            return "No documents uploaded in this session."
        docs = self.session_memory[session_id]
        info = f"📁 **Session Documents ({len(docs)} files):**\n\n"
        for d in docs:
            # [:19] trims the ISO timestamp to seconds (YYYY-MM-DDTHH:MM:SS).
            info += f"• {d['file_name']} ({d['file_type']}) - {d['timestamp'][:19]}\n"
            if "chunks_count" in d:
                info += f"  └── {d['chunks_count']} text chunks\n"
        return info


# -----------------------
# Gradio App / UI Wiring
# -----------------------
# Process-wide RAG instance read by the Gradio handlers below; it stays None
# until initialize_system() constructs it successfully. The `X | None`
# annotation is evaluated at import time, so this line needs Python 3.10+.
rag_system: MultimodalRAG | None = None


def initialize_system(openai_key: str) -> str:
    """Create the global ``MultimodalRAG`` instance used by the UI handlers.

    The key typed into the UI takes precedence; when it is blank, the
    ``OPENAI_API_KEY`` environment variable is used instead.

    Args:
        openai_key: Key from the Setup tab (may be empty).

    Returns:
        A status message for the Setup tab.
    """
    global rag_system
    # Strip so a pasted key with stray whitespace/newline still authenticates,
    # and a whitespace-only field falls back to the environment variable.
    key = (openai_key or "").strip() or (os.getenv("OPENAI_API_KEY") or "").strip()
    if not key:
        return "❌ Please provide an OpenAI API key"
    try:
        rag_system = MultimodalRAG(key)
        return "✅ System initialized successfully!"
    except Exception as e:
        return f"❌ Error initializing system: {e}"


def upload_document(files, session_id: str = "default") -> str:
    """Ingest every uploaded file into the knowledge base.

    Args:
        files: Value of the ``gr.Files`` component, a list whose items are
            either file paths (``str`` / ``os.PathLike``) or file-like objects
            exposing ``.name`` (the tempfile wrappers older Gradio passed).
        session_id: Session tag stored with each record.

    Returns:
        One status line per file, newline-separated.
    """
    if rag_system is None:
        return "❌ Please initialize the system with your API key first"
    if not files:
        return "❌ No files uploaded"

    results = []
    for f in files:
        # Accept plain paths as well as tempfile-style objects with a .name path.
        path = os.fspath(f if isinstance(f, (str, os.PathLike)) else f.name)
        results.append(rag_system.add_document(path, os.path.basename(path), session_id))
    return "\n".join(results)


def ask_question(question: str, session_id: str = "default") -> str:
    """Answer a question against the session's knowledge base.

    Args:
        question: Free-text question; surrounding whitespace is ignored.
        session_id: Session whose documents should be searched.

    Returns:
        The generated answer, or a "❌ …" message when the system is not
        initialized or the question is blank.
    """
    if rag_system is None:
        return "❌ Please initialize the system with your API key first"
    # A whitespace-only question would otherwise trigger a pointless search + LLM call.
    question = (question or "").strip()
    if not question:
        return "❌ Please ask a question"
    return rag_system.generate_answer(question, session_id)


def get_session_status(session_id: str = "default") -> str:
    """Return the summary shown in the "Session Info" tab.

    Before the Setup step has created the RAG instance, a "not initialized"
    notice is returned instead.
    """
    system = rag_system
    if system is not None:
        return system.get_session_info(session_id)
    return "❌ System not initialized"


# ----------------
# Gradio Interface
# ----------------

def create_gradio_interface():
    """Build the four-tab Gradio UI (Setup, Upload Documents, Ask Questions, Session Info).

    Returns:
        The un-launched ``gr.Blocks`` app; the caller decides how to launch it.
    """
    with gr.Blocks(title="Enterprise Knowledge Assistant — OpenAI Vision", theme=gr.themes.Soft()) as demo:
        # `\\` keeps a literal backslash at the end of the intro line so Markdown
        # renders a hard line break; a single `\` would be consumed by Python as
        # a line continuation.
        gr.Markdown(
            """
            # 🧠 Enterprise Knowledge Assistant (OpenAI Vision)

            Upload documents (PDF, DOCX, PPTX, TXT, Images) and ask questions about their content.\\
            This app uses **OpenAI Vision (GPT-4o-mini)** for images, **OpenAI Embeddings** for retrieval, and **ChromaDB** for vector storage.

            **Install:**
            ```bash
            pip install gradio chromadb langchain pypdf pillow python-docx python-pptx \\
                        langchain-openai openai PyPDF2
            ```
            """
        )

        with gr.Tab("🔧 Setup"):
            gr.Markdown("### Initialize the System")
            openai_key_input = gr.Textbox(
                label="OpenAI API Key",
                placeholder="Enter your OpenAI API key here...",
                type="password",
            )
            init_btn = gr.Button("Initialize System", variant="primary")
            init_status = gr.Textbox(label="Status", interactive=False)
            init_btn.click(fn=initialize_system, inputs=[openai_key_input], outputs=[init_status])

        with gr.Tab("📁 Upload Documents"):
            gr.Markdown("### Upload Your Documents")
            session_input = gr.Textbox(label="Session ID", value="default", placeholder="Enter session ID (optional)")
            file_upload = gr.Files(
                label="Upload Documents",
                file_count="multiple",
                file_types=[
                    ".pdf",
                    ".docx",
                    ".pptx",
                    ".txt",
                    ".jpg",
                    ".jpeg",
                    ".png",
                    ".gif",
                    ".bmp",
                    ".webp",
                ],
            )
            upload_btn = gr.Button("Process Documents", variant="primary")
            upload_status = gr.Textbox(label="Upload Status", interactive=False, lines=6)
            upload_btn.click(fn=upload_document, inputs=[file_upload, session_input], outputs=[upload_status])

        with gr.Tab("🤖 Ask Questions"):
            gr.Markdown("### Query Your Knowledge Base")
            session_query = gr.Textbox(label="Session ID", value="default")
            question_input = gr.Textbox(
                label="Your Question", placeholder="Ask anything about your uploaded documents...", lines=3
            )
            ask_btn = gr.Button("Get Answer", variant="primary")
            answer_output = gr.Textbox(label="Answer", interactive=False, lines=12)
            ask_btn.click(fn=ask_question, inputs=[question_input, session_query], outputs=[answer_output])
            # Also answer when the question box is submitted from the keyboard,
            # not only via the button.
            question_input.submit(fn=ask_question, inputs=[question_input, session_query], outputs=[answer_output])

        with gr.Tab("📊 Session Info"):
            gr.Markdown("### Session Status")
            session_status_input = gr.Textbox(label="Session ID", value="default")
            status_btn = gr.Button("Check Status")
            status_output = gr.Textbox(label="Session Information", interactive=False, lines=10)
            status_btn.click(fn=get_session_status, inputs=[session_status_input], outputs=[status_output])

        gr.Markdown(
            """
            ### 🔧 Troubleshooting
            1. Make sure you're in your virtual environment.
            2. Install deps again if needed:
               ```bash
               pip install --upgrade chromadb openai langchain-openai gradio
               ```
            3. Set `OPENAI_API_KEY` in your environment or paste it in **Setup**.
            """
        )

    return demo


if __name__ == "__main__":
    # Serve on all interfaces at port 7861. share=True additionally requests a
    # public share link, so anyone holding that URL can reach this app.
    demo = create_gradio_interface()
    demo.launch(
        share=True,
        server_port=7861,
        server_name="0.0.0.0",
    )


  from .autonotebook import tqdm as notebook_tqdm


* Running on local URL:  http://0.0.0.0:7861
* Running on public URL: https://7be2a27fd768e33463.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)
