<h1 style="text-align: center; font-size: 50px;">Multimodal RAG Chatbot with LangChain, Torch, Transformers</h1>

Retrieval-Augmented Generation (RAG) is an architectural approach that improves large language model (LLM) applications by grounding their answers in custom data. In this example, we use LangChain, an orchestrator for language pipelines, to build an assistant that loads content from an ADO Wiki and uses it to answer user questions. We'll leverage torch and transformers for multimodal model support in Python. We'll also use the MLflow platform to evaluate and trace the LLM responses (in `register-workflow.ipynb`).

# Notebook Overview
- Configuring the Environment
- Data Loading & Cleaning
- Creation of Chunks
- Setup Embeddings & Vector Store
- Retrieval Function
- Model Setup & Chain Creation
- Test Generation and Outputs

## Step 0: Configuring the Environment

In this step, we import the libraries and internal components required to run the RAG pipeline, including modules for wiki data loading, embedding generation, vector storage, and multimodal generation with LLMs.


By using our Local GenAI workspace image, many of the libraries needed to work with RAG come pre-installed. In our case, we just need to add extra support for multimodal processing.

In [1]:
import time
import os 
from pathlib import Path
import sys
import logging

sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "..")))

# Create logger
logger = logging.getLogger("multimodal_rag_logger")
logger.setLevel(logging.INFO)

formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S") 
stream_handler = logging.StreamHandler()
stream_handler.setFormatter(formatter)
logger.addHandler(stream_handler)
logger.propagate = False

In [2]:
start_time = time.time()  

logger.info('Notebook execution started.')

2025-07-28 04:16:25 - INFO - Notebook execution started.


In [3]:
%pip install -r ../requirements.txt --quiet

Note: you may need to restart the kernel to use updated packages.


In [4]:
# === Standard Library Imports ===
import gc
import json
import math
import hashlib
import shutil
import warnings
import numpy as np
from pathlib import Path
from statistics import mean
from typing import Any, Dict, List, Optional, TypedDict
from IPython.display import display, Markdown

# === Third-Party Library Imports ===
import mlflow
import torch
from langchain_core.embeddings import Embeddings
from langchain.schema.document import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter, MarkdownHeaderTextSplitter
from langchain.vectorstores import Chroma
from langchain_huggingface import HuggingFaceEmbeddings
from PIL import Image as PILImage
from transformers import AutoImageProcessor, AutoModel, AutoTokenizer, BitsAndBytesConfig, SiglipModel, SiglipProcessor


# === Project-Specific Imports ===
from src.components import SemanticCache, SiglipEmbeddings
from src.wiki_pages_clone import orchestrate_wiki_clone
from src.utils import (
    configure_hf_cache,
    multimodal_rag_asset_status,
    load_config,
    load_secrets,
    load_mm_docs_clean,
    display_images,
)

2025-07-28 04:16:30.183631: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-07-28 04:16:30.191634: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1753676190.201147    6834 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1753676190.203771    6834 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1753676190.210974    6834 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking 

In [5]:
warnings.filterwarnings("ignore")

In [6]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

print(f"Using device: {device}")

Using device: cuda



### Verify Assets

In [7]:
CONFIG_PATH = "../configs/config.yaml"
SECRETS_PATH = "../configs/secrets.yaml"

LOCAL_MODEL: Path = Path("/home/jovyan/datafabric/InternVL3-8B-Instruct")
CONTEXT_DIR: Path = Path("../data/context")             
CHROMA_DIR: Path = Path("../data/chroma_store")     
CACHE_DIR: Path = CHROMA_DIR / "semantic_cache"
MANIFEST_PATH: Path = CHROMA_DIR / "manifest.json"

IMAGE_DIR = CONTEXT_DIR / "images"
WIKI_METADATA_DIR = CONTEXT_DIR / "wiki_flat_structure.json"

DEMO_FOLDER = "../demo"

multimodal_rag_asset_status(
    local_model_path=LOCAL_MODEL,
    config_path=CONFIG_PATH,
    secrets_path=SECRETS_PATH,
    wiki_metadata_dir=WIKI_METADATA_DIR,
    context_dir=CONTEXT_DIR,
    chroma_dir=CHROMA_DIR,
    cache_dir=CACHE_DIR,
    manifest_path=MANIFEST_PATH
)

2025-07-28 04:16:31 - INFO - Local Model is properly configured. 
2025-07-28 04:16:31 - INFO - Config is properly configured. 
2025-07-28 04:16:31 - INFO - Secrets is properly configured. 
2025-07-28 04:16:31 - INFO - wiki_flat_structure.json is properly configured. 
2025-07-28 04:16:31 - INFO - CONTEXT is properly configured. 
2025-07-28 04:16:31 - INFO - CHROMA is properly configured. 
2025-07-28 04:16:31 - INFO - CACHE is properly configured. 
2025-07-28 04:16:31 - INFO - MANIFEST is properly configured. 


### Config Loading

In this section, we load configuration parameters from the YAML file in the configs folder.

- **config.yaml**: Contains non-sensitive configuration parameters like model sources and URLs

In [8]:
config = load_config(CONFIG_PATH)

### Config HuggingFace Caches

In the next cell, we configure the Hugging Face cache so that all downloaded models are persisted locally, even after the workspace is closed. This is a desired future feature for AI Studio and the GenAI addon.
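
`configure_hf_cache` is a project helper whose body is not shown in this notebook. As a rough sketch of what such a helper could do, assuming it works by setting the `HF_HOME` environment variable (the function name `set_hf_cache_dir` and the default path below are hypothetical, not the project's actual code):

```python
import os
from pathlib import Path


def set_hf_cache_dir(cache_dir: str = "~/.cache/huggingface_persistent") -> Path:
    """Hypothetical stand-in for configure_hf_cache(): use a persistent cache folder."""
    path = Path(cache_dir).expanduser()
    path.mkdir(parents=True, exist_ok=True)
    # HF_HOME is the root folder the Hugging Face libraries use for cached downloads.
    # It is typically read when those libraries are first imported, so set it early.
    os.environ["HF_HOME"] = str(path)
    return path
```

Pointing the cache at a folder that survives workspace restarts avoids downloading the same model weights again on every run.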

In [9]:
# Configure HuggingFace cache
configure_hf_cache()

In [10]:
%%time

# Initialize HuggingFace Embeddings
embeddings = HuggingFaceEmbeddings(
    model_name="intfloat/e5-large-v2",
    cache_folder="/tmp/hf_cache"
)

CPU times: user 1.02 s, sys: 193 ms, total: 1.21 s
Wall time: 3.41 s


## Step 1: Data Loading & Cleaning

`wiki_flat_structure.json` is custom JSON metadata for the ADO Wiki data. It has a flat structure, with keys for the file path, the Markdown content, and a list of images. An image folder holds all the images referenced by the Markdown pages. We scrape this data directly from ADO and clean it up where necessary.

Cloning the wiki requires an ADO access token, which can be supplied in one of two ways:

- **secrets.yaml**: For Freemium users, store sensitive data such as API keys in `secrets.yaml`.
- **AIS Secrets Manager**: For Premium (paid) users, configure the API key with the secrets manager in the `Project Setup` tab.
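
To make the flat structure of `wiki_flat_structure.json` concrete, here is a minimal sketch of parsing such a file into one record per page. The key names (`filepath`, `content`, `images`), the sample data, and the `WikiPage` and `parse_pages` names are assumptions for illustration; the real schema is whatever `orchestrate_wiki_clone` writes, and the real loader is `load_mm_docs_clean`.

```python
import json
from dataclasses import dataclass, field
from typing import List


@dataclass
class WikiPage:
    source: str                       # path of the Markdown file in the wiki
    text: str                         # Markdown content of the page
    images: List[str] = field(default_factory=list)  # image file names used on the page


# Hypothetical sample; the real key names in wiki_flat_structure.json may differ.
SAMPLE_JSON = """
[
  {"filepath": "Getting-Started.md", "content": "# Getting Started\\nSee the diagram.", "images": ["arch.png"]},
  {"filepath": "FAQ.md", "content": "# FAQ\\nNo images here.", "images": []}
]
"""


def parse_pages(raw: str) -> List[WikiPage]:
    """Turn the flat JSON list into one record per wiki page, skipping empty pages."""
    pages = []
    for entry in json.loads(raw):
        text = entry.get("content", "").strip()
        if not text:
            continue  # nothing to index for an empty page
        pages.append(WikiPage(entry["filepath"], text, list(entry.get("images", []))))
    return pages


pages = parse_pages(SAMPLE_JSON)
print(len(pages), pages[0].images)
```

Keeping the image list next to each page's text is what later allows image retrieval to be filtered by the pages that matched the text query.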

In [11]:
%%time

ADO_PAT = os.getenv("AIS_ADO_TOKEN")
if not ADO_PAT:
    logger.info("Environment variable not found... Secrets Manager not properly set. Falling back to secrets.yaml.")
    try:
        secrets = load_secrets(SECRETS_PATH)
        ADO_PAT = secrets.get('AIS_ADO_TOKEN')
    except Exception as e:
        # Loading can fail for several reasons (e.g. a missing or malformed file), not just a NameError
        logger.error("Could not load secrets from %s: %s", SECRETS_PATH, e)

try:
    orchestrate_wiki_clone(
        pat=ADO_PAT,
        config=config,
        output_dir=CONTEXT_DIR
    )
    logger.info("✅ Wiki data preparation step completed successfully.")

except Exception as e:
    logger.error("Halting notebook execution due to a critical error in the wiki preparation step: %s", e)
    raise  # re-raise so later cells do not run against missing data

2025-07-28 04:16:35 - INFO - Starting ADO Wiki clone process...
2025-07-28 04:16:35 - INFO - Cloning wiki 'Phoenix-DS-Platform.wiki' to temporary directory: /tmp/tmp11pb15wy
2025-07-28 04:16:55 - INFO - Scanning for Markdown files...
2025-07-28 04:16:55 - INFO - → Found 567 Markdown pages.
2025-07-28 04:16:55 - INFO - Copying referenced images to ../data/context/images...
2025-07-28 04:17:02 - INFO - → 738 unique images copied.
2025-07-28 04:17:02 - INFO - Assembling flat JSON structure...
2025-07-28 04:17:02 - INFO - ✅ Wiki data successfully cloned to ../data/context
2025-07-28 04:17:02 - INFO - Cleaned up temporary directory: /tmp/tmp11pb15wy
2025-07-28 04:17:02 - INFO - ✅ Wiki data preparation step completed successfully.


CPU times: user 628 ms, sys: 843 ms, total: 1.47 s
Wall time: 27 s


In [12]:
%%time

WIKI_METADATA_DIR   = Path(WIKI_METADATA_DIR)
IMAGE_DIR = Path(IMAGE_DIR)

mm_raw_docs = load_mm_docs_clean(WIKI_METADATA_DIR, Path(IMAGE_DIR))

def log_stage(name: str, docs: List[Document]):
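    # Note: the value labelled "avg_tokens" is the average page length in characters, not model tokens.
    logger.info(f"{name}: {len(docs)} docs, avg_tokens={sum(len(d.page_content) for d in docs)/len(docs):.0f}")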
log_stage("Docs loaded", mm_raw_docs)

2025-07-28 04:17:03 - INFO - Docs loaded: 567 docs, avg_tokens=3097


CPU times: user 26.4 ms, sys: 60.3 ms, total: 86.7 ms
Wall time: 808 ms


## Step 2: Creation of Chunks

Here, we split the loaded documents into chunks, so we have smaller and more specific texts to add to our vector database. 

We first chunk on Markdown headers, and then within each header section we further chunk to the provided chunk size. Each chunk retains its page name, which preserves the context of each chunk.
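
The size-based step can be pictured as a sliding window with overlap. The sketch below is a simplification for intuition only: `RecursiveCharacterTextSplitter` prefers to break on separators such as paragraphs and lines, whereas this toy version (with the hypothetical names `sliding_chunks` and `with_title`) cuts at fixed character offsets.

```python
from typing import List


def sliding_chunks(text: str, chunk_size: int = 1200, overlap: int = 200) -> List[str]:
    """Cut text into windows of at most chunk_size chars; neighbours share `overlap` chars."""
    if chunk_size <= 0 or not 0 <= overlap < chunk_size:
        raise ValueError("need chunk_size > 0 and 0 <= overlap < chunk_size")
    step = chunk_size - overlap
    chunks = []
    for start in range(0, len(text), step):
        chunks.append(text[start:start + chunk_size])
        if start + chunk_size >= len(text):
            break  # the last window already reached the end of the text
    return chunks


def with_title(title: str, chunks: List[str]) -> List[str]:
    """Prefix every chunk with its page title so it stays identifiable after splitting."""
    return [f"{title}\n\n{c.strip()}" for c in chunks]


section = "abcdefghij" * 3  # 30 characters
pieces = sliding_chunks(section, chunk_size=12, overlap=4)
print(len(pieces), [len(p) for p in pieces])
```

The overlap means the end of one chunk is repeated at the start of the next, so a sentence that straddles a boundary is still fully present in at least one chunk.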

In [13]:
%%time

def chunk_documents(
    docs,
    chunk_size: int = 1200,
    overlap: int = 200,
) -> list[Document]:
    """
    1) Split each wiki page on Markdown headers (#, ## …) to keep logical
       sections together.
    2) Recursively break long sections to <= `chunk_size` chars with `overlap`.
    3) Prefix every chunk with its page-title and store the title in metadata.
    """
    header_splitter = MarkdownHeaderTextSplitter(
        headers_to_split_on=[("#", "title"), ("##", "section")]
    )
    recursive_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=overlap,
    )

    all_chunks: list[Document] = []
    for doc in docs:
        page_title = Path(doc.metadata["source"]).stem.replace("-", " ")

        # 1. section-level split (returns list[Document])
        section_docs = header_splitter.split_text(doc.page_content)

        for section in section_docs:
            # 2. size‑based split inside each section
            tiny_texts = recursive_splitter.split_text(section.page_content)

            for idx, tiny in enumerate(tiny_texts):
                all_chunks.append(
                    Document(
                        page_content=f"{page_title}\n\n{tiny.strip()}",
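                        metadata={
                            "title": page_title,
                            "source": doc.metadata["source"],
                            # MarkdownHeaderTextSplitter stores headers under the names
                            # configured above ("title" for #, "section" for ##).
                            "section_header": section.metadata.get(
                                "section", section.metadata.get("title", "")
                            ),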
                            "chunk_id": idx,
                        },
                    )
                )
    if all_chunks:
        avg_len = int(mean(len(c.page_content) for c in all_chunks))
        logger.info(
            "Chunking complete: %d docs → %d chunks (avg %d chars)",
            len(docs),
            len(all_chunks),
            avg_len,
        )
    else:
        logger.warning("Chunking produced zero chunks for %d docs", len(docs))

    return all_chunks

splits = chunk_documents(mm_raw_docs)

2025-07-28 04:17:03 - INFO - Chunking complete: 567 docs → 2614 chunks (avg 717 chars)


CPU times: user 69.7 ms, sys: 0 ns, total: 69.7 ms
Wall time: 66.8 ms


## Step 3: Setup Embeddings & Vector Store
Here we set up SigLIP for image embeddings and transform our cleaned text chunks into embeddings to be stored in Chroma. We store the Chroma data locally on disk to reduce memory usage.
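
Conceptually, a vector store keeps one embedding per item and answers a query by returning the items whose embeddings are closest to the query embedding. The brute-force sketch below is for intuition only: the class name `TinyVectorStore` and the hand-made 3-dimensional vectors are assumptions, and Chroma itself uses a persistent, indexed implementation rather than a linear scan.

```python
import math
from typing import Dict, List, Sequence, Tuple


def cosine(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity between two vectors (1.0 = same direction)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0 or norm_b == 0:
        return 0.0
    return dot / (norm_a * norm_b)


class TinyVectorStore:
    """In-memory store: add (vector, metadata) pairs, search by cosine similarity."""

    def __init__(self) -> None:
        self._items: List[Tuple[List[float], Dict[str, str]]] = []

    def add(self, vector: Sequence[float], metadata: Dict[str, str]) -> None:
        self._items.append((list(vector), metadata))

    def search(self, query: Sequence[float], k: int = 2) -> List[Tuple[float, Dict[str, str]]]:
        scored = [(cosine(query, vec), meta) for vec, meta in self._items]
        scored.sort(key=lambda pair: pair[0], reverse=True)  # most similar first
        return scored[:k]


store = TinyVectorStore()
store.add([1.0, 0.0, 0.0], {"source": "a.md"})
store.add([0.0, 1.0, 0.0], {"source": "b.md"})
store.add([0.9, 0.1, 0.0], {"source": "c.md"})
hits = store.search([1.0, 0.0, 0.0], k=2)
print([meta["source"] for _, meta in hits])
```

In the notebook, the text store uses embeddings from `intfloat/e5-large-v2` and the image store uses SigLIP embeddings, but the lookup idea is the same.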

### Setup Text ChromaDB

In [14]:
%%time

# 1) TEXT store
def _current_manifest() -> Dict[str, str]:
    """
    Returns a dictionary mapping every context JSON file to its SHA256 content hash.
    This allows detecting changes in file content, not just filenames.
    """
    manifest = {}
    json_files = sorted(CONTEXT_DIR.rglob("*.json"))

    for file_path in json_files:
        try:
            with open(file_path, "rb") as f:
                file_bytes = f.read()
                file_hash = hashlib.sha256(file_bytes).hexdigest()
                manifest[str(file_path.resolve())] = file_hash
        except IOError as e:
            logger.error(f"Could not read file {file_path} for hashing: {e}")
    return manifest

def _needs_rebuild() -> bool:
    """
    Determines if the ChromaDB needs to be rebuilt.
    A rebuild is needed if:
    1. The Chroma directory or manifest file doesn't exist.
    2. The manifest is unreadable.
    3. The stored file hashes in the manifest do not match the current file hashes.
    """
    if not CHROMA_DIR.exists() or not MANIFEST_PATH.exists():
        logger.info("Chroma directory or manifest not found. A rebuild is required.")
        return True
    try:
        old_manifest = json.loads(MANIFEST_PATH.read_text())
    except Exception as e:
        logger.warning(f"Could not read manifest file. A rebuild is required. Error: {e}")
        return True

    current_manifest = _current_manifest()
    if old_manifest != current_manifest:
        logger.info("Data content has changed. A rebuild is required.")
        return True

    return False

def _save_manifest(manifest: Dict[str, str]) -> None:
    """Saves the current data manifest (mapping file paths to hashes) to disk."""
    CHROMA_DIR.mkdir(parents=True, exist_ok=True)
    MANIFEST_PATH.write_text(json.dumps(manifest, indent=2))

def _build_text_db() -> Chroma:
    collection = "mm_text"
    # The rebuild check is now done outside this function.
    # We check if the directory exists. If not, we build.
    if not CHROMA_DIR.exists() or not (CHROMA_DIR / "chroma.sqlite3").exists():
        logger.info("Creating new text context index in %s ...", CHROMA_DIR)
        chroma = Chroma.from_documents(
            documents          = splits,
            embedding          = embeddings,
            collection_name    = collection,
            persist_directory  = str(CHROMA_DIR),
        )
        return chroma

    logger.info("Loading existing Chroma index from %s", CHROMA_DIR)
    return Chroma(
        collection_name   = collection,
        persist_directory = str(CHROMA_DIR),
        embedding_function= embeddings,
    )
    
# Check if a rebuild is needed and wipe the old DB if so.
# This ensures both the text and image databases are rebuilt from scratch.
if _needs_rebuild():
    logger.warning("REBUILDING: Wiping old ChromaDB store at %s", CHROMA_DIR)
    if CHROMA_DIR.exists():
        shutil.rmtree(CHROMA_DIR)
    # Save the new manifest immediately after deciding to rebuild
    _save_manifest(_current_manifest())

# Now, initialize your databases. They will be created fresh if they were just deleted.
text_db = _build_text_db()

2025-07-28 04:17:03 - INFO - Loading existing Chroma index from ../data/chroma_store


CPU times: user 186 ms, sys: 20.7 ms, total: 206 ms
Wall time: 299 ms
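
The rebuild check in the cell above relies on comparing stored content hashes with freshly computed ones. The self-contained sketch below reproduces that idea in a temporary folder so it can be run in isolation. The helper names `hash_json_files` and `is_stale` are hypothetical, and unlike the notebook this version keys the manifest by relative path.

```python
import hashlib
import json
import tempfile
from pathlib import Path
from typing import Dict


def hash_json_files(folder: Path) -> Dict[str, str]:
    """Map each JSON file under `folder` (by relative path) to its SHA-256 hash."""
    return {
        str(p.relative_to(folder)): hashlib.sha256(p.read_bytes()).hexdigest()
        for p in sorted(folder.rglob("*.json"))
    }


def is_stale(folder: Path, manifest_path: Path) -> bool:
    """True when the manifest is missing, unreadable, or no longer matches the data."""
    if not manifest_path.exists():
        return True
    try:
        stored = json.loads(manifest_path.read_text())
    except (OSError, json.JSONDecodeError):
        return True
    return stored != hash_json_files(folder)


with tempfile.TemporaryDirectory() as tmp:
    data_dir = Path(tmp) / "context"
    data_dir.mkdir()
    manifest = Path(tmp) / "manifest.json"
    (data_dir / "pages.json").write_text('{"v": 1}')

    first_check = is_stale(data_dir, manifest)    # no manifest yet
    manifest.write_text(json.dumps(hash_json_files(data_dir)))
    second_check = is_stale(data_dir, manifest)   # data unchanged since the manifest was saved
    (data_dir / "pages.json").write_text('{"v": 2}')
    third_check = is_stale(data_dir, manifest)    # content changed, hash differs

print(first_check, second_check, third_check)
```

Hashing file contents rather than checking names or timestamps means the index is rebuilt only when the wiki data actually changes.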


### Setup Image ChromaDB

In [15]:
%%time

#  Helper: walk all docs once and gather *unique* image vectors + metadata
def _collect_image_vectors():
    """
    Scans every wiki page for image references and returns three parallel lists:
        img_paths : list[str]   → full file-system paths (for SigLIP)
        img_ids   : list[str]   → unique key per (page, image) pair
        img_meta  : list[dict]  → {"source": wiki_page, "image": file_name}
    Runs in < 1s even for thousands of docs.
    """
    img_paths, img_ids, img_meta = [], [], []
    seen = set()

    for doc in mm_raw_docs:                         # raw wiki pages
        src = doc.metadata["source"]
        for name in doc.metadata.get("images", []): # list[str]
            img_id = f"{src}::{name}"
            if img_id in seen:
                continue                            # de‑dupe
            seen.add(img_id)

            img_paths.append(str(IMAGE_DIR / name))
            img_ids.append(img_id)
            img_meta.append({"source": src, "image": name})

    return img_paths, img_ids, img_meta

siglip_embeddings = SiglipEmbeddings("google/siglip2-base-patch16-224", DEVICE)

# 2) IMAGE store
image_db = Chroma(
    collection_name    = "mm_image",
    persist_directory  = str(CHROMA_DIR),   # SAME dir as text db
    embedding_function = siglip_embeddings, # SigLIP wrapper from src.components
)

# Populate vectors *only* if it is empty
if not image_db._collection.count():
    img_paths, img_ids, img_meta = _collect_image_vectors()
    image_db.add_texts(texts=img_paths, metadatas=img_meta, ids=img_ids)
    image_db.persist()
    logger.info("Indexed %d unique images.", len(img_paths))
else:
    logger.info("Loaded existing image index (%d vectors).",
                image_db._collection.count())


Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.
2025-07-28 04:17:07 - INFO - Loaded existing image index (752 vectors).


CPU times: user 1.78 s, sys: 794 ms, total: 2.58 s
Wall time: 3.8 s


### Setup Memory Store

In [16]:
# Initialize the semantic cache
semantic_cache = SemanticCache(persist_directory=CACHE_DIR, embedding_function=embeddings)
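
`SemanticCache` is a project component whose implementation is not shown here. From its use later in the notebook it exposes `get(query, threshold=...)`, `set(query, value)`, and `delete(query)`. The sketch below illustrates the general idea of a semantic cache under loud assumptions: `TinySemanticCache` is hypothetical, it keeps entries in memory instead of on disk, and it uses a bag-of-words vector as a stand-in for a real embedding model.

```python
import math
from collections import Counter
from typing import Any, List, Optional, Tuple


def toy_embed(text: str) -> Counter:
    """Bag-of-words stand-in for a real embedding model (assumption for illustration)."""
    return Counter(text.lower().split())


def cosine(a: Counter, b: Counter) -> float:
    dot = sum(a[token] * b[token] for token in a)
    norm = math.sqrt(sum(v * v for v in a.values())) * math.sqrt(sum(v * v for v in b.values()))
    return dot / norm if norm else 0.0


class TinySemanticCache:
    """Return a stored answer when a new query is similar enough to a cached one."""

    def __init__(self) -> None:
        self._entries: List[Tuple[str, Counter, Any]] = []

    def get(self, query: str, threshold: float = 0.92) -> Optional[Any]:
        q = toy_embed(query)
        best_score, best_value = 0.0, None
        for _, emb, value in self._entries:
            score = cosine(q, emb)
            if score > best_score:
                best_score, best_value = score, value
        return best_value if best_score >= threshold else None

    def set(self, query: str, value: Any) -> None:
        self.delete(query)  # keep at most one entry per exact query string
        self._entries.append((query, toy_embed(query), value))

    def delete(self, query: str) -> None:
        self._entries = [e for e in self._entries if e[0] != query]


cache = TinySemanticCache()
cache.set("how do i run blueprints locally", {"reply": "cached answer"})
hit = cache.get("How do I run blueprints locally")
miss = cache.get("what are feature flags")
cache.delete("how do i run blueprints locally")
after_delete = cache.get("How do I run blueprints locally")
```

A high threshold such as 0.92 keeps the cache conservative: only near-duplicate questions reuse a stored answer, and everything else goes through the full pipeline.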

## Step 4: Retrieval Function

This code implements a multi-stage retrieval process that combines vector similarity search with a lightweight score adjustment to select the most relevant text documents and their associated images. No cross-encoder reranker is used.

First, the system performs an initial similarity search against `text_db` (the Chroma text store built in Step 3). It uses the `query` to fetch the `fetch_k` closest text chunks by embedding similarity. This step acts as a broad filter that quickly identifies a larger set of potentially relevant documents, and it returns both the documents and their initial scores. Next, each score is adjusted with a small boost (`boost_slug`) when the slugified query (lowercased, with spaces replaced by hyphens) appears in the document's source path. The documents are then re-sorted by adjusted score and the top `k_txt` are kept.

Finally, using the sources of those `k_txt` documents as a filter, we embed the query with SigLIP and retrieve up to `k_img` images associated with them.
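
The score adjustment and top-k selection can be sketched on toy data as follows. This is an illustration, not the notebook's function: `rank_with_slug_boost`, the sample paths, and the scores are made up, and the sketch assumes similarity scores where higher means more relevant. If the store returns distances instead, the boost would be subtracted and the sort order reversed.

```python
from typing import List, Tuple


def rank_with_slug_boost(
    query: str,
    candidates: List[Tuple[str, float]],
    k: int = 2,
    boost: float = 0.1,
) -> List[Tuple[str, float]]:
    """Boost candidates whose source path contains the slugified query, then keep the top k."""
    slug = query.lower().replace(" ", "-")
    adjusted = []
    for source, score in candidates:
        if slug in source.lower():
            score += boost  # the page name matches the query wording
        adjusted.append((source, score))
    adjusted.sort(key=lambda pair: pair[1], reverse=True)  # highest score first
    return adjusted[:k]


candidates = [
    ("wiki/Setup-Guide.md", 0.80),
    ("wiki/Feature-Flags.md", 0.75),
    ("wiki/Release-Notes.md", 0.60),
]
top = rank_with_slug_boost("feature flags", candidates, k=2)
print([source for source, _ in top])
```

Here the page whose file name matches the query overtakes a page with a slightly higher raw score, which is the intended effect of the boost.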



In [None]:
def retrieve_mm(
    query: str,
    k_txt: int = 4,
    k_img: int = 8,
    fetch_k: int = 20,
    boost_slug: float = 0.1,
) -> Dict[str, Any]:
    """
    Performs multi-modal retrieval without a cross-encoder.

    1) Coarse recall: Fetches the top `fetch_k` documents based on similarity score.
    2) Score adjustment: Applies an optional score boost if the query slug
       matches the document source.
    3) Top-k selection: Sorts documents by the adjusted score and selects the top `k_txt`.
    4) Image retrieval: Fetches relevant images for the selected top documents.
    """
    # 1) Coarse recall: Fetch top `fetch_k` docs and their initial scores
    docs_and_scores = text_db.similarity_search_with_score(query, k=fetch_k)

    if not docs_and_scores:
        return {"docs": [], "images": [], "scores": []}

    # 2) Compute adjusted scores (+ slug boost).
    # Chroma's similarity_search_with_score returns a distance (lower = more
    # similar), so the boost is subtracted and results are sorted ascending.
    slug = query.lower().replace(" ", "-")
    scored_docs = []
    for doc, initial_score in docs_and_scores:
        score = initial_score
        # Apply boost if the formatted query slug appears in the source URL/path
        if slug in doc.metadata.get("source", "").lower():
            score -= boost_slug
        scored_docs.append((doc, score))

    # 3) Sort by the adjusted distance (best first) and select top-k_txt
    scored_docs.sort(key=lambda x: x[1])
    top_docs_and_scores = scored_docs[:k_txt]

    if not top_docs_and_scores:
        return {"docs": [], "images": [], "scores": []}

    selected_docs, final_scores = zip(*top_docs_and_scores)

    # 4) Image retrieval using the sources of the top text documents
    sources = [d.metadata["source"] for d in selected_docs]
    q_emb = siglip_embeddings.embed_query(query)
    img_hits = image_db.similarity_search_by_vector(
        q_emb,
        k=k_img * 2,  # Fetch more to allow for some filtering/variety
        filter={"source": {"$in": sources}},
    )
    images = [img.page_content for img in img_hits[:k_img]]

    return {
        "docs": list(selected_docs),
        "images": images,
        "scores": list(final_scores),
    }

## Step 5: Model Setup & Chain Creation

In this section, we set up our local large language model (LLM) and integrate it into a question answering (QA) pipeline. We use `InternVL3-8B-Instruct` as our multimodal model, which can process both text and images. This setup is encapsulated in the `InternVLMM` class, a minimal wrapper that handles retrieval, prompt assembly, generation, and recovery from CUDA errors.

### System Prompt

In [18]:
SYSTEM_PROMPT = """
    You are AI Studio DevOps Assistant. Your function is to analyze images and text, then answer questions based ONLY on the provided materials.
    
    **PERMANENT INSTRUCTIONS:**
    1.  **Analyze and Answer from Context**: Your entire response MUST be derived thoroughly from the provided `<context>` block or the user's image(s).
    2.  **Follow Output Structure**: You MUST follow the multi-part response structure outlined in the user's message. Completing all sections is mandatory.
    3.  **No External Knowledge**: You MUST NOT use any information outside the provided materials.
    4.  **No Hallucination**: Do not invent or assume any details. If information is not present, it does not exist.
    5.  **Handle Missing Information**: If the provided context or image(s) do not contain the answer, your ONLY response will be: "Based on the provided context, I cannot answer this question." Do not add any other words or explanation.
    """

### InternVLMM QA Wrapper

In [19]:
%%time


class InternVLMM:
    """
    Minimal, self-contained multimodal QA wrapper around InternVL3-8B-Instruct.
    This class:
      • loads / resets the model
      • builds the prompt (<context>…)
      • returns the model's answer (based on img and text) and also the top retrieved images
    """

    def __init__(self, cache: SemanticCache):
        self.model = None
        self.tok   = None
        self.image_processor = None
        self.cache = cache
        self._load()

    # ---------- public function ----------
    def generate(self, query: str, force_regenerate: bool = False, **retrieval_kwargs) -> Dict[str, Any]:
        """
        Run retrieval, prompt assembly, and model generation.
        Returns a dictionary with the text reply and a list of used image paths.
        """
        
        # === 1. CHECK SEMANTIC CACHE (or bypass if forced) ===
        if force_regenerate:
            logger.info(f"Forced regeneration for query: '{query}'. Clearing old cache entry.")
            self.cache.delete(query)
        else:
            cached_result = self.cache.get(query, threshold=0.92)
            if cached_result:
                logger.info(f"SEMANTIC CACHE HIT for query: '{query}'")
                return cached_result
            logger.info(f"CACHE MISS for query: '{query}'. Running full pipeline.")

        if self.model is None or self.tok is None:
            return {"reply": "Error: model not initialised.", "used_images": []}
    
        # === 2. RETRIEVE (if not cached) ===
        hits = retrieve_mm(query, **retrieval_kwargs)
        docs: List[Any]   = hits["docs"]
        images: List[str] = hits["images"] # Image paths returned to the caller as "used_images"
    
        if not docs and not images:
            return {"reply": "I don't know based on the provided context.", "used_images": []}

        # === 3. BUILD PROMPT & GENERATE ===
        context_str = "\n\n".join(
            f"<source_document name=\"{d.metadata.get('source', 'unknown')}\">\n{d.page_content}\n</source_document>"
            for d in docs
        )
        
        # Combine the response structure and query into the user message.
        visual_analysis_prompt = ""
        if images:
            visual_analysis_prompt = """
            ## **Visual Analysis**\n
            """
        
        # Construct the final prompt
        user_content = f"""
            <task_instructions>
            Your response must follow this exact structure:
            {visual_analysis_prompt}
            ## **Synthesized Answer**\n
            Next, answer the user's original query. Your answer must be synthesized from the provided text `<context>`. Use the visual analysis only if it is relevant. If the image is not relevant, rely solely on the text context to formulate your answer.\n
            
            ## **Source Documents**\n
            At the very end of your response, cite the source from the context in brackets and backticks, like this: [`source-file-name.md`].\n
            
            </task_instructions>
    
            <context>
                {context_str}
            </context>
    
            <user_query>
                 {query}
            </user_query>
    
            Now, generate the response following all instructions.
            """

        # Construct the conversation history as a list of dictionaries
        conversation = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": user_content}
        ]
        
        # Apply the chat template
        prompt = self.tok.apply_chat_template(
            conversation,
            tokenize=False,
            add_generation_prompt=True
        )
                
        if images:
            prompt += f"\n\n[{len(images)} image(s) are provided for analysis]"

        # generate
        try:
            self._clear_cuda()
            pixel_values = self._process_images(images) if images else None
            reply = self.model.chat(
                self.tok, pixel_values, prompt,
                generation_config=dict(
                    max_new_tokens=4096, 
                    pad_token_id=self.tok.pad_token_id, 
                    eos_token_id=self.tok.eos_token_id,
                    repetition_penalty=1.25,
                ),
            )
            self._clear_cuda()

            result_dict = {"reply": reply, "used_images": images}
            
            # === 4. UPDATE CACHE ===
            self.cache.set(query, result_dict)
            
            return result_dict

        except RuntimeError as e:
            msg = str(e).lower()
            if "cuda" in msg or "out of memory" in msg:
                logger.warning("CUDA error – resetting model: %s", e)
                self._reset()
                error_reply = "I ran into a GPU memory error – please try again."
            else:
                logger.error("Runtime error: %s", e)
                error_reply = f"Error: {e}"
            return {"reply": error_reply, "used_images": images}
            
    # ---------- internal helpers ----------
    
    def _load(self):
        """Load tokenizer, image_processor, & model. Handles 4-bit quant on GPUs, fp32 on CPU."""
        logger.info("Loading %s ...", LOCAL_MODEL)
        gc.collect()
        self._clear_cuda()

        self.tok = AutoTokenizer.from_pretrained(
            LOCAL_MODEL, trust_remote_code=True
        )
        if self.tok.pad_token is None:
            self.tok.pad_token = self.tok.eos_token

        self.image_processor = AutoImageProcessor.from_pretrained(
            LOCAL_MODEL, trust_remote_code=True, use_fast=True
        )

        q_cfg = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4", # Use the modern "Normal Float 4"
            bnb_4bit_compute_dtype=torch.bfloat16, # Speeds up computation
            bnb_4bit_use_double_quant=True, # Minor memory improvement
        )

        self.model = AutoModel.from_pretrained(
            LOCAL_MODEL,
            quantization_config=q_cfg,
            torch_dtype=(torch.bfloat16 if DEVICE == "cuda" else torch.float32),
            low_cpu_mem_usage=True,
            use_flash_attn=False,
            trust_remote_code=True,
            device_map="auto" if DEVICE == "cuda" else None,
        ).eval()
        logger.info("Model loaded on %s.", DEVICE)

    def _reset(self):
        """Free everything and reload (called after persistent CUDA errors)."""
        logger.warning("Resetting InternVL model …")
        del self.model, self.tok, self.image_processor
        self.model = self.tok = self.image_processor = None
        gc.collect()
        self._clear_cuda()
        time.sleep(1)
        self._load()

    def _process_images(self, image_paths: List[str]):
        """
        Convert a list of image filepaths to a single batched tensor.
        """
        if not image_paths:
            return None
        try:
            # Open all images from their file paths
            pil_images = [PILImage.open(p).convert("RGB") for p in image_paths]
            
            # The processor naturally handles a list of PIL images
            processed_data = self.image_processor(images=pil_images, return_tensors="pt")
            pixel_values = processed_data['pixel_values']

            # Match model device/dtype
            target_dtype = next(self.model.parameters()).dtype if self.model else torch.float32
            pixel_values = pixel_values.to(device=DEVICE, dtype=target_dtype)
            
            return pixel_values
            
        except Exception as e:
            logger.error("Image processing failed for one or more images: %s", e)
            return None

    @staticmethod
    def _clear_cuda():
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.synchronize()

# Initialize the multimodal LLM
mm = InternVLMM(semantic_cache)

2025-07-28 04:17:07 - INFO - Loading /home/jovyan/datafabric/InternVL3-8B-Instruct ...


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

2025-07-28 04:18:24 - INFO - Model loaded on cuda.


CPU times: user 9.5 s, sys: 13.6 s, total: 23.1 s
Wall time: 1min 17s


## Step 6: Test Generation and Outputs

In [None]:
%%time

question = "How do i run blueprints locally?"
results = mm.generate(question, force_regenerate=False)

print("--- MODEL RESPONSE ---")
display(Markdown(results["reply"]))
print("----------------------\n")

display_images(results["used_images"])

In [None]:
%%time

question2 = "What are some feature flags in AIStudio?"
results = mm.generate(question2, force_regenerate=False)

print("--- MODEL RESPONSE ---")
display(Markdown(results["reply"]))
print("----------------------\n")

display_images(results["used_images"])

In [None]:
%%time

question3 = "How do i manually clean my environment without hooh?"
results = mm.generate(question3, force_regenerate=True)

print("--- MODEL RESPONSE ---")
display(Markdown(results["reply"]))
print("----------------------\n")

display_images(results["used_images"])

In [None]:
%%time

question4 = "How do i sign a config file?"
results = mm.generate(question4, force_regenerate=True)

print("--- MODEL RESPONSE ---")
display(Markdown(results["reply"]))
print("----------------------\n")

display_images(results["used_images"])

In [24]:
end_time: float = time.time()
elapsed_time: float = end_time - start_time
elapsed_minutes: int = int(elapsed_time // 60)
elapsed_seconds: float = elapsed_time % 60

logger.info(f"⏱️ Total execution time: {elapsed_minutes}m {elapsed_seconds:.2f}s")
logger.info("✅ Notebook execution completed successfully.")

2025-07-28 04:18:50 - INFO - ⏱️ Total execution time: 2m 25.10s
2025-07-28 04:18:50 - INFO - ✅ Notebook execution completed successfully.


Built with ❤️ using Z by HP AI Studio.