<h1 style="text-align: center; font-size: 50px;"> 🤖 MLFlow Registration for Multimodal RAG</h1>

# MLFlow Model Service 

In this section, we demonstrate how to deploy a RAG-based chatbot service. This service provides a REST API endpoint that allows users to query the knowledge base with natural language questions, upload new documents to the knowledge base, and manage conversation history, all with built-in safeguards against sensitive information and toxicity. This service encapsulates all the functionality we developed in this notebook, including the document retrieval system, RAG-based question answering capabilities, and Galileo integration for protection, observation and evaluation. It demonstrates how to use our ChatbotService from the src/service directory. 

## Step 0: Imports and Environment Setup

In [1]:
import time
import os 
from pathlib import Path
import sys
import logging

sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "..")))

# Create logger
logger = logging.getLogger("multimodal_rag_logger")
logger.setLevel(logging.INFO)
if not logger.handlers:
    formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
    stream_handler = logging.StreamHandler(sys.stdout)
    stream_handler.setFormatter(formatter)
    logger.addHandler(stream_handler)
logger.propagate = False

In [2]:
start_time = time.time()  

logger.info("Notebook execution started.")

2025-07-28 08:24:41 - INFO - Notebook execution started.


In [3]:
%pip install -r ../requirements.txt --quiet

Note: you may need to restart the kernel to use updated packages.


In [4]:
# === Standard Library Imports ===
import gc
import json
import math
import os
import tempfile
import shutil
import warnings
from typing import Any, Dict, List, Optional, TypedDict
from statistics import mean
import hashlib

# === Third-Party Library Imports ===
import mlflow
import numpy as np
import pandas as pd
import torch
from langchain_core.embeddings import Embeddings
from langchain.schema.document import Document
from langchain.vectorstores import Chroma
from langchain_huggingface import HuggingFaceEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter, MarkdownHeaderTextSplitter
from mlflow.models.signature import ModelSignature
from mlflow.tracking import MlflowClient
from mlflow.types import ColSpec, DataType, Schema, TensorSpec
from PIL import Image as PILImage
from sentence_transformers import CrossEncoder, SentenceTransformer
from transformers import pipeline, AutoImageProcessor, AutoModel, AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, SiglipModel, SiglipProcessor

# === Project-Specific Imports ===
# Add the project root to the system path to allow importing from 'src'
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "..")))

from src.components import SemanticCache, SiglipEmbeddings
from src.wiki_pages_clone import orchestrate_wiki_clone
from src.local_genai_judge import LocalGenAIJudge
from src.utils import (
    configure_hf_cache,
    multimodal_rag_asset_status,
    load_config,
    load_secrets,
    load_mm_docs_clean,
)

2025-07-28 08:24:47.372544: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-07-28 08:24:47.387942: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1753691087.399088    2988 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1753691087.402106    2988 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1753691087.412120    2988 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking 

In [5]:
warnings.filterwarnings("ignore")

In [6]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

print(f"Using device: {device}")

Using device: cuda


## Step 1: Configurations

### Verify Assets

In [7]:
CONFIG_PATH = "../configs/config.yaml"
SECRETS_PATH = "../configs/secrets.yaml"

LOCAL_MODEL = "/home/jovyan/datafabric/InternVL3-8B-Instruct"
CONTEXT_DIR: Path = Path("../data/context")
CHROMA_DIR: Path = Path("../data/chroma_store")     
CACHE_DIR: Path = CHROMA_DIR / "semantic_cache"
MANIFEST_PATH: Path = CHROMA_DIR / "manifest.json"

IMAGE_DIR = CONTEXT_DIR / "images"
WIKI_METADATA_DIR = CONTEXT_DIR / "wiki_flat_structure.json"

CHROMA_DIR.mkdir(parents=True, exist_ok=True)
CACHE_DIR.mkdir(parents=True, exist_ok=True)

multimodal_rag_asset_status(
    local_model_path=LOCAL_MODEL,
    config_path=CONFIG_PATH,
    secrets_path=SECRETS_PATH,
    wiki_metadata_dir=WIKI_METADATA_DIR,
    context_dir=CONTEXT_DIR,
    chroma_dir=CHROMA_DIR,
    cache_dir=CACHE_DIR,
    manifest_path=MANIFEST_PATH
)

2025-07-28 08:24:49 - INFO - Local Model is properly configured. 
2025-07-28 08:24:49 - INFO - Config is properly configured. 
2025-07-28 08:24:49 - INFO - Secrets is properly configured. 
2025-07-28 08:24:49 - INFO - wiki_flat_structure.json is properly configured. 
2025-07-28 08:24:49 - INFO - CONTEXT is properly configured. 
2025-07-28 08:24:49 - INFO - CHROMA is properly configured. 
2025-07-28 08:24:49 - INFO - CACHE is properly configured. 
2025-07-28 08:24:49 - INFO - MANIFEST is properly configured. 


In [8]:
config = load_config(CONFIG_PATH)

### Config HuggingFace Caches

In the next cell, we configure HuggingFace cache, so that all the models downloaded from them are persisted locally, even after the workspace is closed. This is a future desired feature for AI Studio and the GenAI addon.

In [9]:
configure_hf_cache()

In [10]:
# Initialize HuggingFace Embeddings
txt_embeddings = HuggingFaceEmbeddings(
    model_name="intfloat/e5-large-v2",
    cache_folder="/tmp/hf_cache"
)

### MLflow Configuration

In [11]:
MODEL_NAME = "AIStudio-Multimodal-Chatbot-Model"
RUN_NAME = f"Register_{MODEL_NAME}"
EXPERIMENT_NAME = "AIStudio-Multimodal-Chatbot-Experiment"

# Set MLflow tracking URI and experiment
# This should be configured for your environment, e.g., a remote server or local file path
mlflow.set_tracking_uri(os.getenv("MLFLOW_TRACKING_URI", "/phoenix/mlflow"))
mlflow.set_experiment(experiment_name=EXPERIMENT_NAME)

logger.info(f"Using MLflow tracking URI: {mlflow.get_tracking_uri()}")
logger.info(f"Using MLflow experiment: '{EXPERIMENT_NAME}'")

2025-07-28 08:24:52 - INFO - Using MLflow tracking URI: /phoenix/mlflow
2025-07-28 08:24:52 - INFO - Using MLflow experiment: 'AIStudio-Multimodal-Chatbot-Experiment'


## Step 2: Data loading, Cleaning, and Setup

`wiki_flat_structure.json` is a custom json metadata for ADO Wiki data. It is flatly structured, with keys for filepath, md content, and a list of images. We also have a image folder that contains all the images for every md page. We directly scrape this data from ADO and perform any cleanup if necessary.

- **secrets.yaml**: For Freemium users, use secrets.yaml to store your sensitive data like API Keys. If you are a Premium user, you can use secrets manager.
- **AIS Secrets Manager**: For Paid users, use the secrets manager in the `Project Setup` tab to configure your API key.

### Parse from ADO Wiki

In [12]:
%%time

ADO_PAT = os.getenv("AIS_ADO_TOKEN")
if not ADO_PAT:
    logger.info("Environment variable not found... Secrets Manager not properly set. Falling to secrets.yaml.")
    try:
        secrets = load_secrets(SECRETS_PATH)
        ADO_PAT = secrets.get('AIS_ADO_TOKEN')
    except NameError:
        logger.error("The 'secrets' object is not defined or available.")

try:
    orchestrate_wiki_clone(
        pat=ADO_PAT,
        config=config,
        output_dir=CONTEXT_DIR
    )
    logger.info("✅ Wiki data preparation step completed successfully.")

except Exception as e:
    logger.error("Halting notebook execution due to a critical error in the wiki preparation step.")

2025-07-28 08:24:52 - INFO - Starting ADO Wiki clone process...
2025-07-28 08:24:52 - INFO - Cloning wiki 'Phoenix-DS-Platform.wiki' to temporary directory: /tmp/tmph7fl3bmj
2025-07-28 08:25:21 - INFO - Scanning for Markdown files...
2025-07-28 08:25:21 - INFO - → Found 567 Markdown pages.
2025-07-28 08:25:21 - INFO - Copying referenced images to ../data/context/images...
2025-07-28 08:25:27 - INFO - → 738 unique images copied.
2025-07-28 08:25:27 - INFO - Assembling flat JSON structure...
2025-07-28 08:25:27 - INFO - ✅ Wiki data successfully cloned to ../data/context
2025-07-28 08:25:27 - INFO - Cleaned up temporary directory: /tmp/tmph7fl3bmj
2025-07-28 08:25:27 - INFO - ✅ Wiki data preparation step completed successfully.
CPU times: user 509 ms, sys: 868 ms, total: 1.38 s
Wall time: 35.3 s


In [13]:
%%time

WIKI_METADATA_DIR   = Path(WIKI_METADATA_DIR)
IMAGE_DIR = Path(IMAGE_DIR)

mm_raw_docs = load_mm_docs_clean(WIKI_METADATA_DIR, Path(IMAGE_DIR))

def log_stage(name: str, docs: List[Document]):
    logger.info(f"{name}: {len(docs)} docs, avg_tokens={sum(len(d.page_content) for d in docs)/len(docs):.0f}")
log_stage("Docs loaded", mm_raw_docs)

2025-07-28 08:25:28 - INFO - Docs loaded: 567 docs, avg_tokens=3098
CPU times: user 20.2 ms, sys: 59.3 ms, total: 79.4 ms
Wall time: 726 ms


### Creation of Chunks

In [14]:
%%time

def chunk_documents(
    docs,
    chunk_size: int = 1200,
    overlap: int = 200,
) -> list[Document]:
    """
    1) Split each wiki page on Markdown headers (#, ## …) to keep logical
       sections together.
    2) Recursively break long sections to <= `chunk_size` chars with `overlap`.
    3) Prefix every chunk with its page-title and store the title in metadata.
    """
    header_splitter = MarkdownHeaderTextSplitter(
        headers_to_split_on=[("#", "title"), ("##", "section")]
    )
    recursive_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=overlap,
    )

    all_chunks: list[Document] = []
    for doc in docs:
        page_title = Path(doc.metadata["source"]).stem.replace("-", " ")

        # 1️. section‑level split (returns list[Document])
        section_docs = header_splitter.split_text(doc.page_content)

        for section in section_docs:
            # 2. size‑based split inside each section
            tiny_texts = recursive_splitter.split_text(section.page_content)

            for idx, tiny in enumerate(tiny_texts):
                all_chunks.append(
                    Document(
                        page_content=f"{page_title}\n\n{tiny.strip()}",
                        metadata={
                            "title": page_title,
                            "source": doc.metadata["source"],
                            "section_header": section.metadata.get("header", ""),
                            "chunk_id": idx,
                        },
                    )
                )
    if all_chunks:
        avg_len = int(mean(len(c.page_content) for c in all_chunks))
        logger.info(
            "Chunking complete: %d docs → %d chunks (avg %d chars)",
            len(docs),
            len(all_chunks),
            avg_len,
        )
    else:
        logger.warning("Chunking produced zero chunks for %d docs", len(docs))

    return all_chunks

splits = chunk_documents(mm_raw_docs)

2025-07-28 08:25:28 - INFO - Chunking complete: 567 docs → 2615 chunks (avg 717 chars)
CPU times: user 67.2 ms, sys: 0 ns, total: 67.2 ms
Wall time: 63 ms


### Setup Text ChromaDB

In [15]:
%%time

# 1) TEXT store
def _current_manifest() -> Dict[str, str]:
    """
    Returns a dictionary mapping every context JSON file to its SHA256 content hash.
    This allows detecting changes in file content, not just filenames.
    """
    manifest = {}
    json_files = sorted(CONTEXT_DIR.rglob("*.json"))

    for file_path in json_files:
        try:
            with open(file_path, "rb") as f:
                file_bytes = f.read()
                file_hash = hashlib.sha256(file_bytes).hexdigest()
                manifest[str(file_path.resolve())] = file_hash
        except IOError as e:
            logger.error(f"Could not read file {file_path} for hashing: {e}")
    return manifest

def _needs_rebuild() -> bool:
    """
    Determines if the ChromaDB needs to be rebuilt.
    A rebuild is needed if:
    1. The Chroma directory or manifest file doesn't exist.
    2. The manifest is unreadable.
    3. The stored file hashes in the manifest do not match the current file hashes.
    """
    if not CHROMA_DIR.exists() or not MANIFEST_PATH.exists():
        logger.info("Chroma directory or manifest not found. A rebuild is required.")
        return True
    try:
        old_manifest = json.loads(MANIFEST_PATH.read_text())
    except Exception as e:
        logger.warning(f"Could not read manifest file. A rebuild is required. Error: {e}")
        return True

    current_manifest = _current_manifest()
    if old_manifest != current_manifest:
        logger.info("Data content has changed. A rebuild is required.")
        return True

    return False

def _save_manifest(manifest: Dict[str, str]) -> None:
    """Saves the current data manifest (mapping file paths to hashes) to disk."""
    CHROMA_DIR.mkdir(parents=True, exist_ok=True)
    MANIFEST_PATH.write_text(json.dumps(manifest, indent=2))

def _build_text_db() -> Chroma:
    collection = "mm_text"
    # The rebuild check is now done outside this function.
    # We check if the directory exists. If not, we build.
    if not CHROMA_DIR.exists() or not (CHROMA_DIR / "chroma.sqlite3").exists():
        logger.info("Creating new text context index in %s ...", CHROMA_DIR)
        chroma = Chroma.from_documents(
            documents          = splits,
            embedding          = txt_embeddings,
            collection_name    = collection,
            persist_directory  = str(CHROMA_DIR),
        )
        return chroma

    logger.info("Loading existing Chroma index from %s", CHROMA_DIR)
    return Chroma(
        collection_name   = collection,
        persist_directory = str(CHROMA_DIR),
        embedding_function= txt_embeddings,
    )
    
# Check if a rebuild is needed and wipe the old DB if so.
# This ensures both the text and image databases are rebuilt from scratch.
if _needs_rebuild():
    logger.warning("REBUILDING: Wiping old ChromaDB store at %s", CHROMA_DIR)
    if CHROMA_DIR.exists():
        shutil.rmtree(CHROMA_DIR)
    # Save the new manifest immediately after deciding to rebuild
    _save_manifest(_current_manifest())

# Now, initialize your databases. They will be created fresh if they were just deleted.
text_db = _build_text_db()
CACHE_DIR.mkdir(parents=True, exist_ok=True)

2025-07-28 08:25:28 - INFO - Loading existing Chroma index from ../data/chroma_store
CPU times: user 172 ms, sys: 20.9 ms, total: 193 ms
Wall time: 276 ms


### Setup Image ChromaDB

In [16]:
%%time

#  Helper: walk all docs once and gather *unique* image vectors + metadata
def _collect_image_vectors():
    """
    Scans every wiki page for image references and returns three parallel lists:
        img_paths : list[str]   → full file-system paths (for SigLIP)
        img_ids   : list[str]   → unique key per (page, image) pair
        img_meta  : list[dict]  → {"source": wiki_page, "image": file_name}
    Runs in < 1s even for thousands of docs.
    """
    img_paths, img_ids, img_meta = [], [], []
    seen = set()

    for doc in mm_raw_docs:                         # raw wiki pages
        src = doc.metadata["source"]
        for name in doc.metadata.get("images", []): # list[str]
            img_id = f"{src}::{name}"
            if img_id in seen:
                continue                            # de‑dupe
            seen.add(img_id)

            img_paths.append(str(IMAGE_DIR / name))
            img_ids.append(img_id)
            img_meta.append({"source": src, "image": name})

    return img_paths, img_ids, img_meta

siglip_embeddings = SiglipEmbeddings("google/siglip2-base-patch16-224", DEVICE)

# 2) IMAGE store
image_db = Chroma(
    collection_name    = "mm_image",
    persist_directory  = str(CHROMA_DIR),   # SAME dir as text db
    embedding_function = siglip_embeddings, # <-- class you kept
)

# Populate vectors *only* if it is empty
if not image_db._collection.count():
    img_paths, img_ids, img_meta = _collect_image_vectors()
    image_db.add_texts(texts=img_paths, metadatas=img_meta, ids=img_ids)
    image_db.persist()
    logger.info("Indexed %d unique images.", len(img_paths))
else:
    logger.info("Loaded existing image index (%d vectors).",
                image_db._collection.count())


Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


2025-07-28 08:25:32 - INFO - Loaded existing image index (752 vectors).
CPU times: user 1.74 s, sys: 714 ms, total: 2.45 s
Wall time: 3.73 s


## Step 3: MLflow Model Setup

In [17]:
class MultimodalRagModel(mlflow.pyfunc.PythonModel):
    """
    An MLflow PythonModel that encapsulates the entire Multimodal RAG pipeline.

    This class faithfully reproduces the workflow from the `run-notebook.ipynb`, including
    data loading, multi-stage retrieval (vector search + reranking), and multimodal
    generation with the InternVL model.

    Expected Artifacts during logging/loading:
      - "chroma_dir": Path to the persisted Chroma vectorstore directory.
      - "context_dir": Path to the root data directory, containing the `images` subdirectory.
      - "cache_dir": Path to the directory for the semantic cache.
    """

    # --------------------------------------------------------------------------
    # Helper Classes (Encapsulated from the original notebook)
    # --------------------------------------------------------------------------

    class InternVLMM:
        """Minimal, self-contained multimodal QA wrapper around InternVL-Chat-V1-5."""
        def __init__(self, model_path: str, device: str, cache: Any):
            self.device = device
            self.model = None
            self.model_path = model_path
            self.tok = None
            self.image_processor = None
            self.cache = cache
            self._load()

        def generate(self, query: str, context: Dict[str, Any], force_regenerate: bool = False) -> Dict[str, Any]:
            start_gen_time = time.time()
            
            if not force_regenerate:
                cached_result = self.cache.get(query, threshold=0.92)
                if cached_result:
                    logger.info(f"SEMANTIC CACHE HIT for query: '{query}'")
                    return cached_result

            if force_regenerate:
                logger.info(f"Forced regeneration for query: '{query}'. Clearing old cache entry.")
                self.cache.delete(query)

            logger.info(f"CACHE MISS for query: '{query}'. Running full pipeline.")
            if self.model is None or self.tok is None:
                return {"reply": "Error: model not initialised.", "used_images": []}

            hits = self._retrieve_mm(query, **context)
            docs, images = hits["docs"], hits["images"]

            if not docs and not images:
                return {"reply": "I don't know based on the provided context.", "used_images": []}

            # Build prompt
            context_str = "\n\n".join(
                f"<source_document name=\"{d.metadata.get('source', 'unknown')}\">\n{d.page_content}\n</source_document>"
                for d in docs
            )

            visual_analysis_prompt = ""
            if images:
                visual_analysis_prompt = """
                ## **Visual Analysis**
                #### Answer Here
                First, provide a detailed description of what is shown in the provided image(s), based only on what you can see.
                """
            
            # Construct the final prompt
            user_content = f"""
                <task_instructions>
                Your response must follow this exact structure:
                {visual_analysis_prompt}
                ## **Synthesized Answer**
                #### Answer Here
                Next, answer the user's original query. Your answer must be synthesized from the provided text `<context>`. Use the visual analysis only if it is relevant. If the image is not relevant, rely solely on the text context to formulate your answer.
        
                ## **Source Documents**
                #### Answer Here
                At the very end of your response, cite the source from the context in brackets and backticks, like this: [`source-file-name.md`].
                </task_instructions>
        
                <context>
                    {context_str}
                </context>
        
                <user_query>
                     {query}
                </user_query>
        
                Now, generate the response following all instructions.
                """
            SYSTEM_PROMPT = """
                You are AI Studio DevOps Assistant. Your function is to analyze images and text, then answer questions based ONLY on the provided materials.
                
                **PERMANENT INSTRUCTIONS:**
                1.  **Analyze and Answer from Context**: Your entire response MUST be derived thoroughly from the provided `<context>` block or the user's image(s).
                2.  **Follow Output Structure**: You MUST follow the multi-part response structure outlined in the user's message. Completing all sections is mandatory.
                3.  **No External Knowledge**: You MUST NOT use any information outside the provided materials.
                4.  **No Hallucination**: Do not invent or assume any details. If information is not present, it does not exist.
                5.  **Handle Missing Information**: If the provided context or image(s) do not contain the answer, your ONLY response will be: "Based on the provided context, I cannot answer this question." Do not add any other words or explanation.
                """
            conversation = [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": user_content}
            ]

            prompt = self.tok.apply_chat_template(
                conversation,
                tokenize=False,
                add_generation_prompt=True
            )

            # Generate
            try:
                self._clear_cuda()
                pixel_values = self._process_images(images) if images else None
                reply = self.model.chat(
                    self.tok, pixel_values, prompt,
                    generation_config=dict(
                        max_new_tokens=4096, 
                        pad_token_id=self.tok.pad_token_id, 
                        eos_token_id=self.tok.eos_token_id,
                        repetition_penalty=1.25,
                    ),
                )
                self._clear_cuda()
                end_gen_time = time.time()
                elapsed_time = end_gen_time - start_gen_time
                
                result_dict = {
                    "reply": reply,
                    "used_images": images,
                    "generation_time_seconds": elapsed_time
                }

                self.cache.set(query, result_dict)
                return result_dict
            except RuntimeError as e:
                logger.error("InternVL generation failed: %s", e)
                return {"reply": f"Error during generation: {e}", "used_images": images}

        def _retrieve_mm(self, query: str, text_db: Any, image_db: Any, siglip_embeds: Any, k_txt: int = 4, k_img: int = 8, fetch_k: int = 20, boost_slug: float = 0.1) -> Dict[str, Any]:
            """Performs retrieval for text and associated images without a cross-encoder."""
            # 1. Coarse recall (text)
            docs_and_scores = text_db.similarity_search_with_score(query, k=fetch_k)
            if not docs_and_scores:
                return {"docs": [], "images": []}

            # 2. Score adjustment
            slug = query.lower().replace(" ", "-")
            scored_docs = []
            for doc, initial_score in docs_and_scores:
                score = initial_score
                if slug in doc.metadata.get("source", "").lower():
                    score += boost_slug
                scored_docs.append((doc, score))

            # 3. Select top-k text
            scored_docs.sort(key=lambda x: x[1], reverse=True)
            top_docs_and_scores = scored_docs[:k_txt]

            if not top_docs_and_scores:
                return {"docs": [], "images": []}

            selected_docs, _ = zip(*top_docs_and_scores)

            # 4. Image retrieval
            sources = [d.metadata["source"] for d in selected_docs]
            q_emb = siglip_embeds.embed_query(query)
            img_hits = image_db.similarity_search_by_vector(q_emb, k=k_img * 2, filter={"source": {"$in": sources}})
            images = [img.page_content for img in img_hits[:k_img]]

            return {"docs": list(selected_docs), "images": images}

        def _load(self):
            logger.info("Loading %s...", self.model_path)
            self._clear_cuda()

            self.tok = AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True)
            if self.tok.pad_token is None: self.tok.pad_token = self.tok.eos_token

            self.image_processor = AutoImageProcessor.from_pretrained(self.model_path, trust_remote_code=True, use_fast=True)

            q_cfg = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.bfloat16,
                bnb_4bit_use_double_quant=True,
            ) if self.device == "cuda" else None

            self.model = AutoModel.from_pretrained(
                self.model_path,
                quantization_config=q_cfg,
                torch_dtype=(torch.bfloat16 if self.device == "cuda" else torch.float32),
                low_cpu_mem_usage=True,
                use_flash_attn=False,
                trust_remote_code=True,
                device_map="auto" if self.device == "cuda" else None,
            ).eval()
            logger.info("Model loaded on %s.", self.device)

        def _process_images(self, image_paths: List[str]):
            if not image_paths: return None
            try:
                pil_images = [PILImage.open(p).convert("RGB") for p in image_paths]
                processed_data = self.image_processor(images=pil_images, return_tensors="pt")
                # Ensure pixel values are on the same device and dtype as the model
                pixel_values = processed_data['pixel_values'].to(device=self.device, dtype=next(self.model.parameters()).dtype)
                return pixel_values
            except Exception as e:
                logger.error("Image processing failed: %s", e)
                return None

        def _clear_cuda(self):
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                torch.cuda.synchronize()

    # --------------------------------------------------------------------------
    # MLflow pyfunc Methods
    # --------------------------------------------------------------------------

    def load_context(self, context: mlflow.pyfunc.PythonModelContext) -> None:
        """
        This method is called when loading an MLflow model. It initializes all
        necessary components using the artifacts logged with the model.
        """
        logger.info("--- Initializing MultimodalRagModel context ---")
        
        # This part remains the same
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"Running on device: {self.device}")
    
        # Get the path to the bundled model artifacts that MLflow provides
        # The key "local_model_dir" must match the key used in the `artifacts` dict during logging
        model_artifact_path = Path(context.artifacts["local_model_dir"])
    
        self.model_path = model_artifact_path.resolve()
        logger.info(f"Resolved local model path to: {self.model_path}")
        
        # The rest of the method now uses this resolved path
        self.chroma_dir = Path(context.artifacts["chroma_dir"])
        self.context_dir = Path(context.artifacts["context_dir"])
        self.cache_dir = Path(context.artifacts["cache_dir"])
        logger.info(f"Artifacts loaded: chroma_dir='{self.chroma_dir}', context_dir='{self.context_dir}', cache_dir='{self.cache_dir}'")
    
        # This loads from the local artifacts bundled with your MLflow model
        logger.info("Loading embedding models and cross-encoder from local artifacts...")
        
        # 1. Get the local paths from the MLflow context
        e5_model_path = context.artifacts["e5_model_dir"]
        siglip_model_path = context.artifacts["siglip_model_dir"]
        
        # 2. Initialize models using the local paths
        self.text_embed_model = HuggingFaceEmbeddings(
            model_name=e5_model_path, 
            model_kwargs={"device": self.device}
        )
        self.siglip_embed_model = SiglipEmbeddings(
            model_id=siglip_model_path, 
            device=self.device
        )
        
        logger.info("✅ Models loaded successfully from artifacts.")
    
        logger.info("Loading ChromaDB vector stores...")
        self.text_db = Chroma(
            collection_name="mm_text",
            persist_directory=str(self.chroma_dir),
            embedding_function=self.text_embed_model,
        )
        self.image_db = Chroma(
            collection_name="mm_image",
            persist_directory=str(self.chroma_dir),
            embedding_function=self.siglip_embed_model,
        )
        logger.info(f"Text DB count: {self.text_db._collection.count()}, Image DB count: {self.image_db._collection.count()}")
    
        self.cache = SemanticCache(persist_directory=self.cache_dir, embedding_function=self.text_embed_model)
    
        # The InternVLMM will now be initialized with the safe, absolute path
        self.mm_llm = self.InternVLMM(model_path=self.model_path, device=self.device, cache=self.cache)

        logger.info("Initializing evaluation judge by sharing the main model...")
        try:
            # Instantiate the MODIFIED judge with the model, not a pipeline
            self.judge = LocalGenAIJudge(
                model=self.mm_llm.model,      # Pass the model directly
                tokenizer=self.mm_llm.tok   # Pass the tokenizer
            )
            logger.info("✅ Evaluation judge initialized successfully.")
            
        except Exception as e:
            logger.error(f"Failed to initialize evaluation judge: {e}")
            self.judge = None

        logger.info("--- Context initialization complete ---")


    def predict(self, context: mlflow.pyfunc.PythonModelContext, model_input: pd.DataFrame) -> pd.DataFrame:
        """
        MLflow inference entrypoint.
        Expects a pandas DataFrame with a "query" column.
        Returns a DataFrame with "reply" and "used_images" columns.
        """
        logger.info("Received prediction request.")
        queries = model_input["query"].tolist()
        force_regenerate = model_input.get("force_regenerate", pd.Series([False] * len(queries))).tolist()
        results = []

        retrieval_context = {
            "text_db": self.text_db,
            "image_db": self.image_db,
            "siglip_embeds": self.siglip_embed_model,
        }

        for i, query in enumerate(queries):
            logger.info(f"Processing query: '{query}'")
            
            # 1. Generate the answer
            response_dict = self.mm_llm.generate(query, retrieval_context, force_regenerate=force_regenerate[i])

            # Re-run retrieval to get the context string for evaluation
            retrieved_info = self.mm_llm._retrieve_mm(query, **retrieval_context)
            context_str = "\n\n".join(d.page_content for d in retrieved_info["docs"])

            # 2. Run evaluation if the judge was loaded successfully
            if self.judge:
                # Create a single-row DataFrame for the judge
                eval_df = pd.DataFrame([{
                    "questions": query,
                    "result": response_dict["reply"],
                    "source_documents": context_str
                }])
                
                # Get scores and add them to the response dictionary
                response_dict["faithfulness"] = self.judge.evaluate_faithfulness(eval_df).iloc[0]
                response_dict["relevance"] = self.judge.evaluate_relevance(eval_df).iloc[0]
            else:
                # Provide default null values if the judge isn't available
                response_dict["faithfulness"] = None
                response_dict["relevance"] = None
            
            results.append(response_dict)

        return pd.DataFrame(results)

    @classmethod
    def log_model(cls, model_name: str, local_model: str) -> None:
        """
        Helper class method to log the MultimodalRagModel to MLflow.
        This version downloads supporting models to a temporary directory that is
        automatically cleaned up after logging.
        """
        logger.info(f"--- Logging '{model_name}' to MLflow ---")
        
        # Use a temporary directory that gets automatically deleted
        with tempfile.TemporaryDirectory() as temp_dir:
            logger.info(f"Created temporary directory for models: {temp_dir}")
            temp_path = Path(temp_dir)
    
            # --- 1. Download models into the temporary directory ---
            # e5 model
            e5_path = temp_path / "e5-large-v2"
            e5_model = SentenceTransformer("intfloat/e5-large-v2")
            e5_model.save(str(e5_path))
            logger.info(f"✅ Temporarily saved e5-large-v2 to {e5_path}")
    
            # SigLIP model
            siglip_path = temp_path / "siglip2-base-patch16-224"
            SiglipModel.from_pretrained("google/siglip2-base-patch16-224").save_pretrained(siglip_path)
            SiglipProcessor.from_pretrained("google/siglip2-base-patch16-224").save_pretrained(siglip_path)
            logger.info(f"✅ Temporarily saved SigLIP to {siglip_path}")
    
            # --- 2. Define artifacts using paths from the temporary directory ---
            project_root = Path.cwd().parent.resolve()
            artifacts = {
                "local_model_dir": local_model,
                "e5_model_dir": str(e5_path),
                "siglip_model_dir": str(siglip_path),
                "chroma_dir": str(project_root / "data" / "chroma_store"),
                "context_dir": str(project_root / "data" / "context"),
                "cache_dir": str(project_root / "data" / "chroma_store" / "semantic_cache"),
            }
    
            # --- 3. Log the model (MLflow will copy from the temp dir) ---
            input_schema = Schema([
                ColSpec(DataType.string, "query"),
                ColSpec(DataType.boolean, "force_regenerate")
            ])
            output_schema = Schema([
                ColSpec(DataType.string, "reply"),
                ColSpec(DataType.string, "used_images"),
                ColSpec(DataType.double, "generation_time_seconds"),
                ColSpec(DataType.double, "faithfulness"),
                ColSpec(DataType.double, "relevance"),
            ])
            signature = ModelSignature(inputs=input_schema, outputs=output_schema)
    
            mlflow.pyfunc.log_model(
                artifact_path=model_name,
                python_model=cls(),
                artifacts=artifacts,
                pip_requirements="../requirements.txt",
                signature=signature,
                code_paths=["../src"],
            )
    
        # The temporary directory and all its contents are automatically deleted here
        logger.info(f"✅ Successfully logged '{model_name}' and cleaned up temporary files.")

## Step 4: Start Run, Log & Register Model

In [18]:
# %%
# --- !!! ADD THIS NEW CELL BEFORE LOADING THE MODEL !!! ---
# This cell explicitly deletes the large model objects and cleans up GPU memory.

logger.info("--- Cleaning up VRAM before loading model from registry ---")

# Delete the objects holding GPU memory
try:
    del text_db
    del image_db
    del txt_embeddings
    del siglip_embeddings
    del splits # Also good to clean up large CPU objects
    logger.info("Deleted model and data objects from memory.")
except NameError:
    logger.warning("Some objects for cleanup were not found, they may have already been deleted.")

# Force Python's garbage collector to run
import gc
gc.collect()

# Now, clear the CUDA cache
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    logger.info("✅ VRAM cleaned and cache emptied.")

2025-07-28 08:25:32 - INFO - --- Cleaning up VRAM before loading model from registry ---
2025-07-28 08:25:32 - INFO - Deleted model and data objects from memory.
2025-07-28 08:25:32 - INFO - ✅ VRAM cleaned and cache emptied.


In [19]:
%%time

# --- Start MLflow Run and Log the Model ---
try:
    with mlflow.start_run(run_name=RUN_NAME) as run:
        run_id = run.info.run_id
        logger.info(f"Started MLflow run: {run_id}")

        # Use the class method to log the model and its artifacts
        MultimodalRagModel.log_model(model_name=MODEL_NAME, local_model=LOCAL_MODEL)

        model_uri = f"runs:/{run_id}/{MODEL_NAME}"
        logger.info(f"Registering model from URI: {model_uri}")
        
        # Register the model in the MLflow Model Registry
        mlflow.register_model(model_uri=model_uri, name=MODEL_NAME)
        logger.info(f"✅ Successfully registered model '{MODEL_NAME}'")

except FileNotFoundError as e:
    logger.error(f"Error: A required file or directory was not found. Please ensure the project structure is correct.")
    logger.error(f"Details: {e}")
except Exception as e:
    logger.error(f"An unexpected error occurred during the MLflow run: {e}", exc_info=True)

2025-07-28 08:25:32 - INFO - Started MLflow run: 41ce0430678c4e12990f1f576eb41851
2025-07-28 08:25:32 - INFO - --- Logging 'AIStudio-Multimodal-Chatbot-Model' to MLflow ---
2025-07-28 08:25:32 - INFO - Created temporary directory for models: /tmp/tmpk00kt338
2025-07-28 08:25:36 - INFO - ✅ Temporarily saved e5-large-v2 to /tmp/tmpk00kt338/e5-large-v2
2025-07-28 08:25:43 - INFO - ✅ Temporarily saved SigLIP to /tmp/tmpk00kt338/siglip2-base-patch16-224


Downloading artifacts:   0%|          | 0/49 [00:00<?, ?it/s]

Downloading artifacts:   0%|          | 0/11 [00:00<?, ?it/s]

Downloading artifacts:   0%|          | 0/7 [00:00<?, ?it/s]

Downloading artifacts:   0%|          | 0/11 [00:00<?, ?it/s]

Downloading artifacts:   0%|          | 0/743 [00:00<?, ?it/s]

Downloading artifacts: 0it [00:00, ?it/s]

2025-07-28 08:29:48 - INFO - ✅ Successfully logged 'AIStudio-Multimodal-Chatbot-Model' and cleaned up temporary files.
2025-07-28 08:29:48 - INFO - Registering model from URI: runs:/41ce0430678c4e12990f1f576eb41851/AIStudio-Multimodal-Chatbot-Model
2025-07-28 08:29:48 - INFO - ✅ Successfully registered model 'AIStudio-Multimodal-Chatbot-Model'
CPU times: user 4.54 s, sys: 45 s, total: 49.5 s
Wall time: 4min 16s


Registered model 'AIStudio-Multimodal-Chatbot-Model' already exists. Creating a new version of this model...
Created version '3' of model 'AIStudio-Multimodal-Chatbot-Model'.


In [20]:
# --- Retrieve the latest version from the Model Registry ---
try:
    client = MlflowClient()
    versions = client.get_latest_versions(MODEL_NAME, stages=["None"])
    if not versions:
        raise RuntimeError(f"No registered versions found for model '{MODEL_NAME}'.")
    
    latest_version = versions[0]
    logger.info(f"Found latest version '{latest_version.version}' for model '{MODEL_NAME}' in stage '{latest_version.current_stage}'.")
    model_uri_registry = latest_version.source

except Exception as e:
    logger.error(f"Failed to retrieve model from registry: {e}", exc_info=True)
    model_uri_registry = None # Ensure variable exists


2025-07-28 08:29:49 - INFO - Found latest version '3' for model 'AIStudio-Multimodal-Chatbot-Model' in stage 'None'.


In [21]:
if model_uri_registry:
    try:
        logger.info(f"Loading model from: {model_uri_registry}")
        loaded_model = mlflow.pyfunc.load_model(model_uri=model_uri_registry)
        logger.info("✅ Successfully loaded model from registry.")
    except Exception as e:
        logger.error(f"Failed to load model from registry URI: {e}", exc_info=True)
        loaded_model = None
else:
    logger.warning("Skipping model loading due to previous errors.")
    loaded_model = None

2025-07-28 08:29:49 - INFO - Loading model from: /phoenix/mlflow/747860200993380382/41ce0430678c4e12990f1f576eb41851/artifacts/AIStudio-Multimodal-Chatbot-Model
2025-07-28 08:29:49 - INFO - --- Initializing MultimodalRagModel context ---
2025-07-28 08:29:49 - INFO - Running on device: cuda
2025-07-28 08:29:49 - INFO - Resolved local model path to: /phoenix/mlflow/747860200993380382/41ce0430678c4e12990f1f576eb41851/artifacts/AIStudio-Multimodal-Chatbot-Model/artifacts/InternVL3-8B-Instruct
2025-07-28 08:29:49 - INFO - Artifacts loaded: chroma_dir='/phoenix/mlflow/747860200993380382/41ce0430678c4e12990f1f576eb41851/artifacts/AIStudio-Multimodal-Chatbot-Model/artifacts/chroma_store', context_dir='/phoenix/mlflow/747860200993380382/41ce0430678c4e12990f1f576eb41851/artifacts/AIStudio-Multimodal-Chatbot-Model/artifacts/context', cache_dir='/phoenix/mlflow/747860200993380382/41ce0430678c4e12990f1f576eb41851/artifacts/AIStudio-Multimodal-Chatbot-Model/artifacts/semantic_cache'
2025-07-28 08:29

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

2025-07-28 08:31:12 - INFO - Model loaded on cuda.
2025-07-28 08:31:12 - INFO - Initializing evaluation judge by sharing the main model...
2025-07-28 08:31:12 - INFO - ✅ Evaluation judge initialized successfully.
2025-07-28 08:31:12 - INFO - --- Context initialization complete ---
2025-07-28 08:31:12 - INFO - ✅ Successfully loaded model from registry.


## Step 5: Display Results

In [22]:
# --- Helper Function to Display Test Results ---
def display_results(query: str, result_df: pd.DataFrame):
    """Helper to neatly print the query, reply, and display images."""
    if result_df.empty:
        print("Received an empty result.")
        return

    reply = result_df["reply"].iloc[0]
    image_paths_str = result_df["used_images"].iloc[0]
    image_paths = eval(image_paths_str) if isinstance(image_paths_str, str) and image_paths_str.startswith('[') else []
    
    gen_time = result_df["generation_time_seconds"].iloc[0]
    faithfulness = result_df["faithfulness"].iloc[0]
    relevance = result_df["relevance"].iloc[0]

    print("---" * 20)
    print(f"❓ Query:\n{query}\n")
    print(f"🤖 Reply:\n{reply}\n")
    
    print(f"📊 Faithfulness: {faithfulness:.4f} | Relevance: {relevance:.4f}")
    print(f"⏱️ Generation Time: {gen_time:.2f}s\n")

    if image_paths:
        print(f"🖼️ Displaying {len(image_paths)} retrieved image(s):")
        for path in image_paths:
            print(f" - {path}")
    else:
        print("▶ No images were retrieved for this query.")
    print("---" * 20 + "\n")

In [23]:
if loaded_model:
    logger.info("Running sample inference with the loaded model...")
    
    sample_queries = [
        "What are the AI Blueprints Repository best practices?",
        "What are some feature flags that i can enable in AIStudio?",
        "How do i manually clean my environment without hooh?",
    ]

    for query in sample_queries:
        try:
            # --- MODIFIED LINE ---
            # Add 'force_regenerate': False to the dictionary when creating the DataFrame
            input_payload = pd.DataFrame([{"query": query, "force_regenerate": False}])
            
            result = loaded_model.predict(input_payload)

            display_results(query, result)
        except Exception as e:
            logger.error(f"Prediction failed for query '{query}': {e}", exc_info=True)

else:
    logger.warning("Skipping sample inference because the model was not loaded.")

2025-07-28 08:31:12 - INFO - Running sample inference with the loaded model...
2025-07-28 08:31:12 - INFO - Received prediction request.
2025-07-28 08:31:12 - INFO - Processing query: 'What are the AI Blueprints Repository best practices?'
2025-07-28 08:31:12 - INFO - CACHE MISS for query: 'What are the AI Blueprints Repository best practices?'. Running full pipeline.


Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


faithfulness: 0.85 relevance:  0.6
------------------------------------------------------------
❓ Query:
What are the AI Blueprints Repository best practices?

🤖 Reply:
## **Visual Analysis**
The image shows is a flowchart titled "Blueprint & Blueprint Delivery" with several components and arrows indicating the process of blueprint delivery. It includes sections like `Create Project Plan`, `Project Creation in AI Studio Workspace Startup Swimlane`, etc., showing steps such as selecting blueprints, creating projects from them, setting up workspaces for notebooks or models.

## **Synthesized Answer**

### Best Practices for HP AI Studio Blueprints Repository

1. **GPU Usage**:
   - Always detect an NVIDIA GPU when available.
     ```python
     import torch
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     print(f"Using device: {device}")
     ```
2 Ensure all layers run on the GPU to speed inference significantly.
- Add comments noting that you can reduce l

## Step 6: Log Hallucinations & Relevance Evaluations to MlFlow

In [24]:
# In register_model.ipynb, add this as the final cell

# Check if the model was loaded and the original run_id is available
if loaded_model and 'run_id' in locals():
    logger.info(f"--- Reopening original run ({run_id}) to log pre-computed evaluations ---")

    # 1. Define your evaluation dataset
    evaluation_payload = pd.DataFrame([
        {"query": "What are the AI Blueprints Repository best practices?", "force_regenerate": True},
        {"query": "What are some feature flags that i can enable in AIStudio?", "force_regenerate": True},
        {"query": "How do i manually clean my environment without hooh?", "force_regenerate": True},
    ])

    # 2. Run predict() to get results with the embedded scores
    results_df = loaded_model.predict(evaluation_payload)
    
    # Add the original query to the results for clarity in the logged table
    results_df['query'] = evaluation_payload['query']

    # 3. Reopen the existing run using its ID
    with mlflow.start_run(run_id=run_id) as run:
        logger.info("Successfully reopened existing run. Logging metrics and artifacts...")

        # 4. Calculate average scores from the DataFrame
        avg_faithfulness = results_df["faithfulness"].mean()
        avg_relevance = results_df["relevance"].mean()

        # 5. Log the average scores as metrics to the original run
        mlflow.log_metrics({
            "avg_faithfulness": avg_faithfulness,
            "avg_relevance": avg_relevance
        })

        # 6. Log the full results DataFrame as a table artifact to the original run
        mlflow.log_table(data=results_df, artifact_file="inline_evaluation_results.json")
        
        logger.info("✅ Successfully logged metrics and artifacts to the original model run.")

else:
    logger.warning("Skipping logging because the model was not loaded or run_id was not found.")

2025-07-28 08:31:59 - INFO - --- Reopening original run (41ce0430678c4e12990f1f576eb41851) to log pre-computed evaluations ---
2025-07-28 08:31:59 - INFO - Received prediction request.
2025-07-28 08:31:59 - INFO - Processing query: 'What are the AI Blueprints Repository best practices?'
2025-07-28 08:31:59 - INFO - Forced regeneration for query: 'What are the AI Blueprints Repository best practices?'. Clearing old cache entry.
2025-07-28 08:31:59 - INFO - CACHE MISS for query: 'What are the AI Blueprints Repository best practices?'. Running full pipeline.
2025-07-28 08:32:13 - INFO - Processing query: 'What are some feature flags that i can enable in AIStudio?'
2025-07-28 08:32:13 - INFO - Forced regeneration for query: 'What are some feature flags that i can enable in AIStudio?'. Clearing old cache entry.
2025-07-28 08:32:13 - INFO - CACHE MISS for query: 'What are some feature flags that i can enable in AIStudio?'. Running full pipeline.
2025-07-28 08:32:15 - INFO - Processing query:

In [25]:
end_time: float = time.time()
elapsed_time: float = end_time - start_time
elapsed_minutes: int = int(elapsed_time // 60)
elapsed_seconds: float = elapsed_time % 60

logger.info(f"⏱️ Total execution time: {elapsed_minutes}m {elapsed_seconds:.2f}s")
logger.info("✅ Notebook execution completed.")

2025-07-28 08:32:45 - INFO - ⏱️ Total execution time: 8m 4.61s
2025-07-28 08:32:45 - INFO - ✅ Notebook execution completed.


Built with ❤️ using Z by HP AI Studio.