<a href="https://colab.research.google.com/github/SuccessSoham/Gen-AI/blob/main/Multimodal_RAG.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
# 🚀 Multimodal RAG Chatbot (Text Docs + Image Captions from Docs + Direct Image Uploads)
# Using LangChain, CTransformers (GPU for LLM), Sentence Transformers (Embeddings),
# BLIP (Image Captioning), FAISS (Vector Store), and Gradio.

# This notebook demonstrates how to build a RAG chatbot that can:
# 1. Process text from uploaded documents (.pdf, .docx, .csv).
# 2. Extract images from .pdf and .docx files, generate text captions, and add them to the context.
# 3. Accept direct image uploads (.jpg, .jpeg, .png), generate captions, and add them to the context.
# The combined text (original text + all image captions) is used as context for a text-only LLM.

In [None]:
## ⚙️ Step 1: Environment Setup

**Context:**
*   Jupyter Lab launched from `/mnt/c/Users/HP` after `conda activate chatbot_gpu`.
*   Files will be uploaded from paths relative to `C:/Users/HP`.
*   **Conda Environment:** `chatbot_gpu` (Python 3.10.x)
*   **CUDA via `nvcc` in Conda:** 12.9 (or similar, ensure this from previous steps)
*   **NVIDIA Driver:** e.g., 576.52 (compatible with your CUDA)
*   **GPU:** NVIDIA GeForce RTX 4060 Laptop GPU (Compute Capability 8.9)

In [None]:
## 📚 Step 2: Install Required Libraries

Ensure all necessary libraries are present in your `chatbot_gpu` Conda environment.
Run these commands in your Anaconda Prompt/Terminal if not already installed.

In [11]:
# This cell is for guidance. Run these in your Anaconda Prompt with chatbot_gpu activated
# if you haven't already, or if you create a fresh environment.
# If run in notebook, RESTART KERNEL afterwards.

print("Guidance: Ensure these are installed in your Conda environment.")
!conda install -c nvidia cuda-toolkit=12.1 -y # Or your preferred recent CUDA, e.g., 11.8") # Ensure a full toolkit for nvcc
!conda install -c conda-forge gcc_linux-64 gxx_linux-64 cmake make ninja git --yes")

print("Installing Python packages...")
!pip install ctransformers[cuda] --no-cache-dir # For LLM with GPU
!pip install langchain langchain-community gradio huggingface_hub torch torchvision torchaudio --quiet # Core
!pip install pypdf python-docx pandas --quiet # Document loaders
!pip install faiss-cpu sentence-transformers --quiet # Vector store & embeddings
!pip install Pillow PyMuPDF transformers --quiet # Image processing & captioning
# Install PyMuPDF
!pip install PyMuPDF
# Install python-docx
!pip install python-docx

# --- For this session, let's assume they are installed. If errors, uncomment and run above, then RESTART KERNEL ---
print("Assuming necessary packages are installed in 'chatbot_gpu' environment.")
print("If you encounter import errors, run the pip/conda install commands above in an Anaconda Prompt,")
print("then RESTART THE JUPYTER KERNEL from the 'Kernel' menu above.")

Guidance: Ensure these are installed in your Conda environment.
/bin/bash: line 1: conda: command not found
/bin/bash: -c: line 1: unexpected EOF while looking for matching `"'
/bin/bash: -c: line 2: syntax error: unexpected end of file
Installing Python packages...
Collecting ctransformers[cuda]
  Downloading ctransformers-0.2.27-py3-none-any.whl.metadata (17 kB)
Downloading ctransformers-0.2.27-py3-none-any.whl (9.9 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m9.9/9.9 MB[0m [31m256.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: ctransformers
Successfully installed ctransformers-0.2.27
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.5/2.5 MB[0m [31m97.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m4.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m128.0 MB/s[0m eta [36m0:00:00[0m
[2K

In [12]:
import os
import io
import shutil # For copying uploaded files if needed
import tempfile # For handling Gradio's temporary file uploads
import gradio as gr
from huggingface_hub import hf_hub_download
import fitz  # PyMuPDF
from docx import Document as DocxDocument # python-docx
from PIL import Image
import pandas as pd # Keep this import here for use in functions

# LangChain & CTransformers
from langchain_community.llms import CTransformers
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder, PromptTemplate
from langchain.memory import ChatMessageHistory
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_core.runnables import RunnablePassthrough, RunnableParallel, RunnableLambda
from langchain_core.output_parsers import StrOutputParser
from langchain_community.document_loaders import PyPDFLoader, CSVLoader, Docx2txtLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_core.documents import Document as LangchainDocument

# Image Captioning
from transformers import pipeline as transformers_pipeline

import torch

# --- Basic Health Checks ---
print(f"PyTorch version: {torch.__version__}")
PYTORCH_CUDA_AVAILABLE = torch.cuda.is_available()
print(f"PyTorch CUDA available: {PYTORCH_CUDA_AVAILABLE}")
if PYTORCH_CUDA_AVAILABLE:
    print(f"PyTorch CUDA version detected by PyTorch: {torch.version.cuda}")
    try:
        print(f"Current GPU: {torch.cuda.get_device_name(torch.cuda.current_device())}")
    except Exception as e:
        print(f"Could not get GPU name from PyTorch: {e}")
else:
    print("WARNING: PyTorch does not see CUDA. LLM/Captioning GPU offload will fail.")

print("\nAll libraries imported.")

PyTorch version: 2.6.0+cu124
PyTorch CUDA available: True
PyTorch CUDA version detected by PyTorch: 12.4
Current GPU: Tesla T4

All libraries imported.


In [None]:
## 📦 Step 3: Download GGUF Model for LLM

In [13]:
llm_model_name_repo = "TheBloke/Mistral-7B-Instruct-v0.2-GGUF"
llm_model_basename = "mistral-7b-instruct-v0.2.Q4_K_M.gguf"

print(f"Downloading LLM: {llm_model_basename} from {llm_model_name_repo}...")
llm_model_path = hf_hub_download(
    repo_id=llm_model_name_repo,
    filename=llm_model_basename
)
print(f"LLM downloaded to: {llm_model_path}")

Downloading LLM: mistral-7b-instruct-v0.2.Q4_K_M.gguf from TheBloke/Mistral-7B-Instruct-v0.2-GGUF...


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


mistral-7b-instruct-v0.2.Q4_K_M.gguf:   0%|          | 0.00/4.37G [00:00<?, ?B/s]

LLM downloaded to: /root/.cache/huggingface/hub/models--TheBloke--Mistral-7B-Instruct-v0.2-GGUF/snapshots/3a6fbf4a41a1d52e415a4958cde6856d34b2db93/mistral-7b-instruct-v0.2.Q4_K_M.gguf


In [None]:
## 🧠 Step 4: Initialize Models (LLM, Embeddings, Image Captioning)

In [14]:
# --- Initialize LLM (CTransformers) ---
llm = None
print("Initializing CTransformers LLM with GPU offloading...")
llm_config = {
    'gpu_layers': 50,
    'max_new_tokens': 1024,
    'temperature': 0.7, # Default was causing issues, trying explicit
    'context_length': 2048,
    'top_p': 0.95,
    'repetition_penalty': 1.1
}
if PYTORCH_CUDA_AVAILABLE:
    try:
        llm = CTransformers(
            model=llm_model_path,
            config=llm_config,
            # model_type='mistral' # Usually auto-detected
        )
        print("CTransformers LLM loaded successfully.")
        if llm_config.get('gpu_layers', 0) > 0:
            print("LLM GPU offloading is ENABLED.")
    except Exception as e:
        print(f"Error loading CTransformers LLM with GPU: {e}")
        print("Attempting to load LLM on CPU as fallback...")
        try:
            llm_config_cpu = llm_config.copy()
            llm_config_cpu['gpu_layers'] = 0
            llm = CTransformers(model=llm_model_path, config=llm_config_cpu)
            print("CTransformers LLM loaded successfully on CPU.")
        except Exception as e_cpu:
            print(f"Error loading CTransformers LLM on CPU: {e_cpu}")
            llm = None
else:
    print("PyTorch CUDA not available, attempting to load LLM on CPU...")
    try:
        llm_config_cpu = llm_config.copy()
        llm_config_cpu['gpu_layers'] = 0
        llm = CTransformers(model=llm_model_path, config=llm_config_cpu)
        print("CTransformers LLM loaded successfully on CPU.")
    except Exception as e_cpu:
        print(f"Error loading CTransformers LLM on CPU: {e_cpu}")
        llm = None


# --- Initialize Text Embeddings Model ---
embeddings = None
print("\nInitializing sentence-transformer embeddings model...")
try:
    embeddings_model_name = "sentence-transformers/all-MiniLM-L6-v2"
    embedding_device = 'cuda' if PYTORCH_CUDA_AVAILABLE else 'cpu'
    embeddings = HuggingFaceEmbeddings(
        model_name=embeddings_model_name,
        model_kwargs={'device': embedding_device}
    )
    print(f"Text embeddings model '{embeddings_model_name}' loaded successfully on '{embedding_device}'.")
except Exception as e:
    print(f"Error loading text embeddings model: {e}")

# --- Initialize Image Captioning Model ---
image_captioner = None
print("\nInitializing image captioning model (e.g., BLIP)... This may download model weights.")
try:
    captioning_device = 0 if PYTORCH_CUDA_AVAILABLE else -1 # device=0 for first GPU, -1 for CPU
    image_captioner = transformers_pipeline(
        "image-to-text",
        model="Salesforce/blip-image-captioning-base", # Using base for potentially faster loading/less VRAM
        # model="Salesforce/blip-image-captioning-large", # Use large for better captions if resources allow
        device=captioning_device
    )
    print(f"Image captioning model loaded successfully on device index: {captioning_device}.")
except Exception as e:
    print(f"Error loading image captioning model: {e}. Image captioning will be disabled.")


# --- Initialize Text Splitter ---
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
print("\nText splitter initialized.")

# --- Global state variables ---
vector_store = None # Will hold the FAISS index
processed_document_content = [] # List of LangchainDocument objects (text + captions)
processed_images_for_display = [] # List of {"image": PIL.Image, "caption": str, "source": str}

Initializing CTransformers LLM with GPU offloading...
CTransformers LLM loaded successfully.
LLM GPU offloading is ENABLED.

Initializing sentence-transformer embeddings model...


  embeddings = HuggingFaceEmbeddings(


modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

README.md:   0%|          | 0.00/10.5k [00:00<?, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/612 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/90.9M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/350 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

Text embeddings model 'sentence-transformers/all-MiniLM-L6-v2' loaded successfully on 'cuda'.

Initializing image captioning model (e.g., BLIP)... This may download model weights.


config.json:   0%|          | 0.00/4.56k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/990M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/990M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/506 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/711k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]

preprocessor_config.json:   0%|          | 0.00/287 [00:00<?, ?B/s]

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.
Device set to use cuda:0


Image captioning model loaded successfully on device index: 0.

Text splitter initialized.


In [None]:
## 📄 Step 5: Document & Image Processing Logic

In [15]:
# CELL 11: Document & Image Processing Functions

# Ensure pd is available if not imported globally earlier and used here
# import pandas as pd # Already imported in Cell 5 if you followed the full notebook

def get_image_caption(pil_image, source_info="image"):
    """Generates a caption for a given PIL Image object."""
    global image_captioner # Use the globally loaded captioner
    if not image_captioner:
        return "Image captioning model not available."
    try:
        # Ensure image is RGB for BLIP and other captioning models
        rgb_image = pil_image.convert("RGB")
        # Some captioners expect a list of images, even if it's just one
        caption_result = image_captioner(rgb_image) # For BLIP, passing single image is fine

        # Output format can vary; for Salesforce/blip it's a list of dicts
        caption_text = caption_result[0]['generated_text'] if caption_result and isinstance(caption_result, list) and len(caption_result) > 0 and 'generated_text' in caption_result[0] else "No caption generated."
        print(f"Caption for {source_info}: {caption_text}")
        return caption_text
    except Exception as e:
        print(f"Error generating caption for {source_info}: {e}")
        return f"Error generating caption for {source_info}."


def extract_content_from_file(file_path, original_filename, progress_callback):
    """
    Extracts text and image information (PIL image, caption) from a single document file.
    Image information includes the PIL image object, its generated caption, and source.
    """
    # Ensure global image_captioner is accessible if needed by get_image_caption
    # global image_captioner # Not strictly needed here if get_image_caption handles it

    extracted_texts_lc_docs = [] # List of LangchainDocument for text
    extracted_image_infos_list = [] # List of {"image": PIL.Image, "caption": str, "source": str}

    # --- Text Extraction ---
    try:
        progress_callback(0.05, f"Extracting text from {original_filename}...") # Progress for this sub-task
        current_text_docs = []
        if original_filename.lower().endswith(".pdf"):
            loader = PyPDFLoader(file_path)
            current_text_docs = loader.load()
        elif original_filename.lower().endswith(".docx") or original_filename.lower().endswith(".doc"):
            loader = Docx2txtLoader(file_path)
            current_text_docs = loader.load()
        elif original_filename.lower().endswith(".csv"):
            df = pd.read_csv(file_path) # pd should be imported in Cell 5 or at top of this cell
            csv_texts = [f"Row {i+1}: {', '.join(f'{col}: {str(row[col])}' for col in df.columns if pd.notna(row[col]))}" for i, row in df.iterrows()]
            current_text_docs = [LangchainDocument(page_content=text, metadata={"source": original_filename, "row": i+1}) for i, text in enumerate(csv_texts)]

        if current_text_docs:
            for doc in current_text_docs:
                if "source" not in doc.metadata: # Ensure source metadata
                    doc.metadata["source"] = original_filename
            extracted_texts_lc_docs.extend(current_text_docs)
        progress_callback(0.25, "Text extraction complete.") # Max 25% of this function's progress for text
    except Exception as e:
        print(f"Error extracting text from {original_filename}: {e}")
        progress_callback(0.25, f"Text extraction error for {original_filename}: {e}")

    # --- Image Extraction & Captioning (from PDF/DOCX) ---
    # This part takes more time, so allocate more progress percentage (e.g., 25% to 75%)
    if image_captioner and (original_filename.lower().endswith(".pdf") or original_filename.lower().endswith(".docx")):
        progress_callback(0.26, f"Processing images in {original_filename}...")
        img_counter = 0
        try:
            if original_filename.lower().endswith(".pdf"):
                pdf_doc = fitz.open(file_path)
                num_pdf_pages = len(pdf_doc)
                for page_num in range(num_pdf_pages):
                    # Update progress based on PDF page processing
                    current_progress = 0.26 + (0.49 * ((page_num + 1) / num_pdf_pages)) # Scale 0.49 progress over pages
                    progress_callback(current_progress, f"PDF {original_filename} - Page {page_num+1}/{num_pdf_pages}")

                    page_image_list = pdf_doc.get_page_images(page_num)
                    if not page_image_list: continue

                    for img_idx, img_info_fitz in enumerate(page_image_list):
                        img_counter += 1
                        xref = img_info_fitz[0]
                        base_image = pdf_doc.extract_image(xref)
                        if not base_image or not base_image.get("image"): continue # Skip if image data is bad

                        pil_image = Image.open(io.BytesIO(base_image["image"]))
                        source_desc = f"{original_filename} (page {page_num+1}, image {img_idx+1})"
                        caption = get_image_caption(pil_image, source_desc)
                        extracted_image_infos_list.append({"image": pil_image, "caption": caption, "source": source_desc})
                pdf_doc.close()
            elif original_filename.lower().endswith(".docx"):
                doc = DocxDocument(file_path) # python-docx Document
                # Iterate through inline shapes and other potential image containers if needed
                # This example uses rels, which is common for primary images
                rels = [doc.part.rels[rel_id] for rel_id in doc.part.rels if "image" in doc.part.rels[rel_id].target_ref]
                num_rels = len(rels)
                for rel_idx, rel_obj in enumerate(rels):
                    current_progress = 0.26 + (0.49 * ((rel_idx + 1) / num_rels if num_rels > 0 else 1))
                    progress_callback(current_progress, f"DOCX {original_filename} - Image {rel_idx+1}/{num_rels}")
                    img_counter += 1
                    pil_image = Image.open(io.BytesIO(rel_obj.target_part.blob))
                    source_desc = f"{original_filename} (embedded image {img_counter})"
                    caption = get_image_caption(pil_image, source_desc)
                    extracted_image_infos_list.append({"image": pil_image, "caption": caption, "source": source_desc})
            progress_callback(0.75, f"Image processing complete for {original_filename}.")
        except Exception as e:
            print(f"Error during image processing for {original_filename}: {e}")
            progress_callback(0.75, f"Image processing error for {original_filename}: {e}")

    progress_callback(1.0, f"Content extraction finished for {original_filename}.") # Mark this file as 100% done for its stage
    return extracted_texts_lc_docs, extracted_image_infos_list


def process_uploaded_files(files, progress=gr.Progress(track_tqdm=True)):
    global vector_store, processed_document_content, processed_images_for_display
    # Ensure pandas is imported if not done globally or in Cell 5
    # import pandas as pd # Already imported in Cell 5

    if not files:
        vector_store = None
        processed_document_content = []
        processed_images_for_display = []
        return "No files uploaded or files removed.", [], None

    all_content_for_vectorstore = [] # Holds LangchainDocument objects
    current_images_for_display_batch = [] # Holds {"image": PIL, "caption": str, "source": str} for this batch

    total_files = len(files)
    for i, uploaded_file_obj in enumerate(files):
        if uploaded_file_obj is None: continue

        actual_file_path = None # Define to ensure it exists in scope for finally
        original_display_name = "unknown_file"
        try:
            # For Gradio gr.Files, each item is a tempfile._TemporaryFileWrapper
            # Its .name attribute gives the path to the temporary file on disk.
            actual_file_path = uploaded_file_obj.name
            # Try to get a more meaningful name for display if possible; Gradio might not always provide original.
            # os.path.basename on a temp name might just be the temp name.
            # We'll use the temp name for processing, but try to get a better display name.
            # Let's assume uploaded_file_obj.orig_name might exist (depends on Gradio version/usage)
            # or default to basename of the temp path.
            original_display_name = getattr(uploaded_file_obj, 'orig_name', os.path.basename(actual_file_path))


            if not os.path.exists(actual_file_path):
                print(f"ERROR: Gradio temporary file path does not exist: {actual_file_path}")
                continue

            # Calculate progress range for this specific file
            file_progress_start_overall = i / total_files
            file_progress_end_overall = (i + 1) / total_files

            def file_specific_progress_wrapper(p_stage_for_file, desc_stage_for_file):
                # p_stage_for_file is progress within the sub-function (0 to 1)
                overall_p = file_progress_start_overall + p_stage_for_file * (file_progress_end_overall - file_progress_start_overall)
                progress(overall_p, desc=f"{original_display_name}: {desc_stage_for_file}")

            file_specific_progress_wrapper(0, "Starting...")

            # --- Process File Content ---
            if original_display_name.lower().endswith((".jpg", ".jpeg", ".png")):
                file_specific_progress_wrapper(0.1, "Direct image upload...")
                if image_captioner:
                    pil_image = Image.open(actual_file_path)
                    caption = get_image_caption(pil_image, original_display_name)
                    caption_doc = LangchainDocument(page_content=f"Image Caption: {caption}", metadata={"source": original_display_name, "type": "direct_image_upload"})
                    all_content_for_vectorstore.append(caption_doc)
                    current_images_for_display_batch.append({"image": pil_image, "caption": caption, "source": original_display_name})
                    file_specific_progress_wrapper(1.0, "Image captioned.")
                else:
                    file_specific_progress_wrapper(1.0, "Skipped (no captioner).")
            else: # Assumed to be a document file (PDF, DOCX, CSV)
                texts, image_infos = extract_content_from_file(actual_file_path, original_display_name, file_specific_progress_wrapper)
                all_content_for_vectorstore.extend(texts)
                for img_info in image_infos: # These are already {"image": PIL, "caption": str, "source": str}
                    caption_doc = LangchainDocument(page_content=f"Image Caption: {img_info['caption']}", metadata={"source": img_info["source"], "type": "embedded_image_caption"})
                    all_content_for_vectorstore.append(caption_doc)
                    current_images_for_display_batch.append(img_info)

        except Exception as e_file_proc:
            print(f"Error processing file object {original_display_name}: {e_file_proc}")
            # Update progress to show this file had an issue but we move on
            file_specific_progress_wrapper(1.0, f"Error: {e_file_proc}")
            # No finally block here, Gradio manages its temp files from gr.Files

    # Update global states after processing all files in the batch
    processed_document_content = all_content_for_vectorstore
    processed_images_for_display = current_images_for_display_batch # Overwrite with current batch's images

    if not processed_document_content:
        vector_store = None
        return "No processable content found in uploaded file(s).", [], None

    try:
        # Progress for final steps (splitting, vector store creation)
        # These steps operate on the accumulated content from all files
        progress(0.90, desc="Splitting all extracted content...") # Assuming overall progress is near end
        split_docs = text_splitter.split_documents(processed_document_content)
        if not split_docs:
             vector_store = None
             return "Aggregated content from file(s) could not be split.", [], None

        progress(0.95, desc="Creating vector store (FAISS)...")
        vector_store = FAISS.from_documents(split_docs, embeddings) # embeddings from global scope
        progress(1.0, desc="All file(s) processed successfully!")

        gallery_items = [(img_info['image'], img_info['caption']) for img_info in processed_images_for_display if img_info.get('image')]

        return f"{len(files)} file object(s) received, {len(processed_document_content)} content pieces indexed. Ready to chat!", [], gallery_items
    except Exception as e_vs:
        vector_store = None
        processed_images_for_display = [] # Clear display images on error
        print(f"Error creating vector store from aggregated content: {e_vs}")
        return f"Error creating vector store: {e_vs}", [], None

print("Document and image processing functions (get_image_caption, extract_content_from_file, process_uploaded_files) defined.")

Document and image processing functions (get_image_caption, extract_content_from_file, process_uploaded_files) defined.


In [None]:
## 💬 Step 6: Define RAG Chain and Chat Logic

In [16]:
# Prompt for RAG
rag_prompt_template_str = """
You are a helpful AI assistant. Use the following pieces of retrieved context (which may include text excerpts and image captions) to answer the user's question.
If the question seems to be about an image, pay close attention to any context labeled 'Image Caption'.
If you don't know the answer from the context, just say that you don't know. Don't try to make up an answer.

Context:
{context}

Question: {question}

Helpful Answer:
"""
rag_prompt = PromptTemplate(template=rag_prompt_template_str, input_variables=["context", "question"])

# Prompt for general conversation (no document context)
no_doc_prompt = ChatPromptTemplate.from_messages([
    ("system", "You are a helpful AI assistant. Engage in conversation. Be friendly and informative."),
    MessagesPlaceholder(variable_name="history"),
    ("human", "{input}")
])

def format_docs_for_rag(docs):
    return "\n\n".join(doc.page_content for doc in docs)

# In-memory store for chat histories
message_history_store = {}
def get_session_history(session_id: str) -> ChatMessageHistory:
    if session_id not in message_history_store:
        message_history_store[session_id] = ChatMessageHistory()
    return message_history_store[session_id]

print("RAG and conversational prompt templates and history management defined.")

RAG and conversational prompt templates and history management defined.


In [None]:
## 🖥️ Step 7: Create Gradio Chat Interface (Multimodal)

In [18]:
from IPython.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

In [None]:
# Helper function to clear vector store when file upload is cleared in Gradio
def handle_file_clear_mm_gradio():
    global vector_store, processed_document_content, processed_images_for_display
    vector_store = None
    processed_document_content = []
    processed_images_for_display = []
    print("File input cleared, vector_store and processed content reset.")
    return "File input cleared. Chatting without document context.", None, None # status, clear chatbot, clear image gallery


if llm and embeddings: # Image captioner is optional but used if available
    def chat_interface_mm_rag_gradio(user_input: str, chat_history_tuples: list, session_id_str: str):
        global vector_store # Access the global vector store

        if not user_input or not user_input.strip():
            return chat_history_tuples, "" # No change to session_id_str needed here

        current_chat_history_object = get_session_history(session_id_str)
        ai_response = "Sorry, I encountered an issue during generation." # Default error

        try:
       # In Cell 15, inside chat_interface_mm_rag_gradio, RAG `if vector_store:` block

            if vector_store:
                print("RAG mode: Using document context (text & image captions).")
                retriever = vector_store.as_retriever(search_kwargs={"k": 4})

                # 1. Prepare initial input for history and question
                def prepare_initial_input(invoke_input_dict):
                    # invoke_input_dict will be {"question": user_input, "session_id": session_id_str}
                    session_id = invoke_input_dict["session_id"]
                    history_messages = get_session_history(session_id).messages
                    return {
                        "question": invoke_input_dict["question"],
                        "history": history_messages
                    }

                # 2. Retrieve context and combine with question and history
                def retrieve_and_combine(input_from_step1):
                    # input_from_step1 will be {"question": str, "history": List[BaseMessage]}
                    question = input_from_step1["question"]
                    retrieved_docs = retriever.invoke(question) # retriever expects a string
                    formatted_context = format_docs_for_rag(retrieved_docs)

                    # --- DEBUG ---
                    print("\n--- DEBUG: retrieve_and_combine output ---")
                    print(f"  Type of question: {type(question)}, Value (start): {str(question)[:100]}")
                    print(f"  Type of formatted_context: {type(formatted_context)}, Value (start): {str(formatted_context)[:200]}")
                    print(f"  Type of history: {type(input_from_step1['history'])}")
                    if input_from_step1['history']:
                        print(f"  Type of first history message: {type(input_from_step1['history'][0])}")
                    print("--- END DEBUG ---\n")
                    # --- END DEBUG ---

                    # Ensure all parts are strings or appropriate types for the prompt
                    if not isinstance(question, str):
                        raise TypeError(f"RAG chain: 'question' must be a string, got {type(question)}")
                    if not isinstance(formatted_context, str):
                        raise TypeError(f"RAG chain: 'context' (from format_docs_for_rag) must be a string, got {type(formatted_context)}")

                    return {
                        "context": formatted_context,
                        "question": question,
                        "history": input_from_step1["history"] # Pass history through
                    }

                # Define the prompt for RAG that includes history
                rag_prompt_with_history = ChatPromptTemplate.from_messages([
                    ("system", rag_prompt_template_str), # rag_prompt_template_str uses {context} and {question}
                    MessagesPlaceholder(variable_name="history"),
                    ("human", "{question}") # Uses {question}
                ])

                # Construct the full RAG chain
                full_rag_chain = (
                    RunnableLambda(prepare_initial_input)       # Output: {"question": str, "history": List[BaseMessage]}
                    | RunnableLambda(retrieve_and_combine)       # Output: {"context": str, "question": str, "history": List[BaseMessage]}
                    | rag_prompt_with_history                    # Output: PromptValue (typically list of BaseMessages for Chat models)
                    | llm                                        # LLM processes this
                    | StrOutputParser()
                )

                ai_response = full_rag_chain.invoke({
                    "question": user_input,
                    "session_id": session_id_str
                })

                # Manually add to history for this RAG path
                current_chat_history_object.add_user_message(user_input)
                current_chat_history_object.add_ai_message(ai_response)

            else: # Standard Conversational Mode
                print("Conversational mode: No document context.")
                standard_conversational_chain_with_history = RunnableWithMessageHistory(
                    runnable=no_doc_prompt | llm | StrOutputParser(),
                    get_session_history=get_session_history,
                    input_messages_key="input",
                    history_messages_key="history",
                )
                ai_response = standard_conversational_chain_with_history.invoke(
                    {"input": user_input},
                    config={"configurable": {"session_id": session_id_str}}
                )
                # RunnableWithMessageHistory handles adding to history

        except Exception as e:
            print(f"Error during generation for session {session_id_str}: {e}")
            error_message_for_chat = f"Sorry, an error occurred: {str(e)}"
            if not any(msg.content == user_input and msg.type == "human" for msg in current_chat_history_object.messages if hasattr(msg, 'type')):
                 current_chat_history_object.add_user_message(user_input)
            current_chat_history_object.add_ai_message(error_message_for_chat)
            ai_response = error_message_for_chat

        chat_history_tuples.append([user_input, ai_response])
        return chat_history_tuples, ""


    with gr.Blocks(theme=gr.themes.Soft(primary_hue="cyan", secondary_hue="sky")) as demo_mm_rag_final:
        gr.Markdown("## 🖼️📄 Multimodal RAG Chatbot (Text Docs, Doc Images, Direct Images)")

        session_id_state = gr.State(value=lambda: os.urandom(16).hex())

        with gr.Row():
            with gr.Column(scale=2):
                chatbot_display = gr.Chatbot(
                    label="Chat with Your Content",
                    bubble_full_width=False,
                    height=500,
                    avatar_images=(None, "https://huggingface.co/avatars/2edb9158f4a690851541ce0f35732988.svg")
                )
                msg_textbox = gr.Textbox(show_label=False, placeholder="Enter message or ask about uploaded content...", container=False)
                with gr.Row():
                    clear_chat_btn = gr.Button("Clear Chat Only")
                    clear_all_btn = gr.Button("Clear Chat & All Content")
            with gr.Column(scale=1):
                gr.Markdown("### Upload Content")
                # Use file_count="multiple" for gr.Files to allow multiple uploads
                file_uploader_component = gr.Files(
                    label="Upload Documents (.pdf, .docx, .csv) or Images (.jpg, .jpeg, .png)",
                    file_types=[".pdf", ".docx", ".doc", ".csv", ".jpg", ".jpeg", ".png"],
                    file_count="multiple" # Allow uploading multiple files
                )
                upload_status_md = gr.Markdown("No content uploaded yet.")
                gr.Markdown("### Extracted/Uploaded Images & Captions")
                image_gallery_display = gr.Gallery(label="Processed Images", height=400, object_fit="contain", columns=2, preview=True)

        # Event Handlers
        file_uploader_component.upload(
            fn=process_uploaded_files, # Use the new function for multiple files
            inputs=[file_uploader_component],
            outputs=[upload_status_md, chatbot_display, image_gallery_display]
        )
        file_uploader_component.clear(
            fn=handle_file_clear_mm_gradio,
            inputs=[],
            outputs=[upload_status_md, chatbot_display, image_gallery_display]
        )

        def clear_chat_only_fn_mm(current_session_id_str):
            if current_session_id_str in message_history_store:
                del message_history_store[current_session_id_str]
            new_session_id = os.urandom(16).hex()
            return [], new_session_id
        clear_chat_btn.click(fn=clear_chat_only_fn_mm, inputs=[session_id_state], outputs=[chatbot_display, session_id_state], queue=False)

        def clear_all_fn_mm(current_session_id_str):
            global vector_store, processed_document_content, processed_images_for_display
            vector_store = None
            processed_document_content = []
            processed_images_for_display = []
            if current_session_id_str in message_history_store:
                del message_history_store[current_session_id_str]
            new_session_id = os.urandom(16).hex()
            return "All content and chat cleared.", [], None, new_session_id
        clear_all_btn.click(fn=clear_all_fn_mm, inputs=[session_id_state], outputs=[upload_status_md, chatbot_display, image_gallery_display, session_id_state], queue=False)

        msg_textbox.submit(fn=chat_interface_mm_rag_gradio, inputs=[msg_textbox, chatbot_display, session_id_state], outputs=[chatbot_display, msg_textbox])

    print("Gradio Multimodal RAG interface defined. Launching...")
    demo_mm_rag_final.queue().launch(debug=True, share=False)
else:
    print("Gradio interface cannot be launched because LLM or Embeddings (or Captioner, if essential) failed to initialize.")