<h1 style="text-align: center; font-size: 50px;"> 🤖 MLflow Registration for Multimodal RAG</h1>

In [None]:
%pip install -r ../requirements.txt --quiet 

# MLflow Model Service

In this section, we package the multimodal RAG chatbot as an MLflow model so it can be registered, versioned, and served behind a REST endpoint. The `MultimodalRagModel` class defined below wraps the pipeline developed in `run-notebook.ipynb`:
- text retrieval from a Chroma vector store with cross-encoder reranking;
- SigLIP-based retrieval of the images linked to the selected documents;
- answer generation with the InternVL3-8B-Instruct multimodal model;
- a semantic cache that reuses answers for near-duplicate questions.

The notebook then logs the model and its artifacts to MLflow, registers it in the Model Registry, reloads the latest registered version, and runs a few sample queries against it.

## Step 0: Imports and Environment Setup

In [None]:
# === Standard Library Imports ===
import gc
import json
import logging
import math
import os
import sys
import time
import warnings
from pathlib import Path
from typing import Any, Dict, List, Optional, TypedDict

# === Third-Party Library Imports ===
import matplotlib.pyplot as plt
import mlflow
import numpy as np
import pandas as pd
import torch
from langchain_core.embeddings import Embeddings
from langchain.schema.document import Document
from langchain.vectorstores import Chroma
from langchain_huggingface import HuggingFaceEmbeddings
from mlflow.models.signature import ModelSignature
from mlflow.tracking import MlflowClient
from mlflow.types import ColSpec, DataType, Schema, TensorSpec
from PIL import Image as PILImage
from sentence_transformers import CrossEncoder
from transformers import AutoImageProcessor, AutoModel, AutoTokenizer, BitsAndBytesConfig, SiglipModel, SiglipProcessor


# === Project-Specific Imports ===
# Add the project root to the system path to allow importing from 'src'
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "..")))

from src.utils import (
    configure_hf_cache,
)

In [None]:
# Notebook-scoped logger that writes timestamped messages to stdout.
logger = logging.getLogger("multimodal_rag_register_notebook")
logger.setLevel(logging.INFO)

# Attach the handler only once so re-running this cell does not duplicate every log line.
if not logger.handlers:
    stream_handler = logging.StreamHandler(sys.stdout)
    stream_handler.setFormatter(
        logging.Formatter(
            "%(asctime)s - %(levelname)s - %(message)s",
            datefmt="%Y-%m-%d %H:%M:%S",
        )
    )
    logger.addHandler(stream_handler)

# Do not forward records to ancestor loggers (avoids double printing through the root logger).
logger.propagate = False

In [None]:
warnings.filterwarnings("ignore")


In [None]:
# Mark the beginning of the run in the notebook log output.
logger.log(logging.INFO, "Notebook execution started.")

In [None]:
# Apply the project's Hugging Face cache configuration (helper from `src.utils`).
# NOTE(review): presumably this points the HF model cache at a shared location — confirm in src/utils.
configure_hf_cache()

In [None]:
class MultimodalRagModel(mlflow.pyfunc.PythonModel):
    """
    An MLflow PythonModel that encapsulates the entire Multimodal RAG pipeline.

    This class reproduces the workflow from the `run-notebook.ipynb`, including
    data loading, multi-stage retrieval (vector search + reranking), and multimodal
    generation with the InternVL model.

    Expected Artifacts during logging/loading:
      - "chroma_dir": Path to the persisted Chroma vectorstore directory (containing both text and image collections).
      - "context_dir": Path to the root data directory, which must contain the `images` subdirectory.
      - "cache_dir": Path to the directory for the semantic cache.
      - "local_model_dir" (optional): Local copy of the InternVL weights. It is used instead of the
        Hugging Face Hub checkpoint when the directory contains a `config.json`.
    """

    # --------------------------------------------------------------------------
    # Helper Classes (Encapsulated from the original notebook)
    # --------------------------------------------------------------------------

    class SemanticCache:
        """
        A semantic cache using a Chroma vector store to find similar queries.

        Each entry stores the query text as the document content and the response
        (`reply` plus a JSON-encoded `used_images` list) as the document metadata.
        """
        def __init__(self, persist_directory: Path, embedding_function: Embeddings, collection_name: str = "multimodal_cache"):
            self.embedding_function = embedding_function
            self._vectorstore = Chroma(
                collection_name=collection_name,
                persist_directory=str(persist_directory),
                embedding_function=self.embedding_function
            )

        def get(self, query: str, threshold: float = 0.90) -> Optional[Dict[str, Any]]:
            """
            Searches for a semantically similar query in the cache.

            Args:
                query: The incoming user query.
                threshold: Minimum `1 - distance` value required for a cache hit.

            Returns:
                A new dict with `reply` and the decoded `used_images` list on a hit, otherwise None.
            """
            if self._vectorstore._collection.count() == 0:
                return None  # Cache is empty

            results = self._vectorstore.similarity_search_with_score(query, k=1)
            if not results:
                return None

            most_similar_doc, score = results[0]
            # `score` is a Chroma distance (lower = closer), so `1 - distance` is used as a
            # similarity proxy. NOTE(review): it equals cosine similarity only for a
            # cosine-distance collection — confirm the collection's metric.
            similarity = 1.0 - score

            logger.info(f"Most similar cached query: '{most_similar_doc.page_content}' (Similarity: {similarity:.4f})")

            if similarity >= threshold:
                metadata = most_similar_doc.metadata or {}
                # Build a fresh dict rather than mutating the document's metadata in place.
                return {
                    'reply': metadata.get('reply', ''),
                    'used_images': json.loads(metadata.get('used_images', '[]')),
                }

            return None

        def set(self, query: str, result_dict: Dict[str, Any]) -> None:
            """
            Adds a new query and its result to the cache.

            Chroma metadata values must be scalars, so `used_images` is stored as a JSON string.
            """
            metadata_to_store = {
                'reply': result_dict.get('reply', ''),
                'used_images': json.dumps(result_dict.get('used_images', []))
            }

            doc = Document(page_content=query, metadata=metadata_to_store)
            self._vectorstore.add_documents([doc])
            logger.info(f"Added query to semantic cache: '{query}'")

        def delete(self, query: str) -> None:
            """
            Deletes every cached entry whose stored query text exactly equals `query`.

            The query is stored as the document content, not as metadata, so a metadata
            `where` filter cannot match it. A `where_document` substring filter narrows the
            candidates, and an exact comparison selects the entries to remove.
            """
            if not query:
                return

            results = self._vectorstore.get(where_document={"$contains": query})
            if not results:
                return

            ids = results.get('ids') or []
            documents = results.get('documents') or []
            # Remove all exact duplicates (forced regenerations may have stored several).
            ids_to_delete = [doc_id for doc_id, text in zip(ids, documents) if text == query]
            if ids_to_delete:
                self._vectorstore._collection.delete(ids=ids_to_delete)
                logger.info(f"Cleared old cache for query: '{query}'")

    class SiglipEmbeddings(Embeddings):
        """
        LangChain compatible wrapper for SigLIP image/text embeddings.

        `embed_documents` treats each input string as an image file path, while
        `embed_query` embeds plain text into the same joint space.
        """
        # SigLIP text towers were trained on inputs padded to a fixed 64 tokens.
        TEXT_MAX_LENGTH = 64

        def __init__(self, model_id: str, device: str):
            self.device = device
            self.model = SiglipModel.from_pretrained(model_id).to(self.device)
            self.processor = SiglipProcessor.from_pretrained(model_id)

        def _embed_text(self, txts: List[str]) -> np.ndarray:
            # The HF SigLIP docs require padding="max_length" (how the model was trained);
            # dynamic padding shifts the text embeddings away from the image space.
            inp = self.processor(
                text=txts,
                return_tensors="pt",
                padding="max_length",
                max_length=self.TEXT_MAX_LENGTH,
                truncation=True,
            ).to(self.device)
            with torch.no_grad():
                return self.model.get_text_features(**inp).cpu().numpy()

        def _embed_imgs(self, paths: List[str]) -> np.ndarray:
            imgs = []
            for p in paths:
                # Close each file once its RGB copy has been materialised.
                with PILImage.open(p) as img:
                    imgs.append(img.convert("RGB"))
            inp = self.processor(images=imgs, return_tensors="pt").to(self.device)
            with torch.no_grad():
                return self.model.get_image_features(**inp).cpu().numpy()

        def embed_documents(self, docs: List[str]) -> List[List[float]]:
            if not docs:
                return []
            return self._embed_imgs(docs).tolist()

        def embed_query(self, txt: str) -> List[float]:
            return self._embed_text([txt])[0].tolist()

    class InternVLMM:
        """Minimal, self-contained multimodal QA wrapper around InternVL3-8B-Instruct."""
        MODEL_NAME = "OpenGVLab/InternVL3-8B-Instruct"

        def __init__(self, device: str, cache: Any, model_path: Optional[str] = None):
            """
            Args:
                device: "cuda" or "cpu".
                cache: Object exposing `get`, `set` and `delete` (e.g. `SemanticCache`).
                model_path: Optional local checkpoint directory. Defaults to the Hub id `MODEL_NAME`.
            """
            self.device = device
            self.model_path = model_path or self.MODEL_NAME
            self.model = None
            self.tok = None
            self.image_processor = None
            self.cache = cache
            self._load()

        def generate(self, query: str, context: Dict[str, Any], force_regenerate: bool = False) -> Dict[str, Any]:
            """
            Run retrieval, prompt assembly, and model generation.

            Args:
                query: User question.
                context: Keyword arguments forwarded to `_retrieve_mm` (vector stores, embedder, reranker).
                force_regenerate: Skip the cache lookup and drop any cached entry for `query`.

            Returns:
                Dict with `reply` (str) and `used_images` (list of image paths).
            """
            if not force_regenerate:
                cached_result = self.cache.get(query, threshold=0.98)
                if cached_result:
                    logger.info(f"SEMANTIC CACHE HIT for query: '{query}'")
                    return cached_result

            if force_regenerate:
                logger.info(f"Forced regeneration for query: '{query}'. Clearing old cache entry.")
                self.cache.delete(query)

            logger.info(f"CACHE MISS for query: '{query}'. Running full pipeline.")
            if self.model is None or self.tok is None:
                return {"reply": "Error: model not initialised.", "used_images": []}

            hits = self._retrieve_mm(query, **context)
            docs, images = hits["docs"], hits["images"]

            if not docs and not images:
                return {"reply": "I don't know based on the provided context.", "used_images": []}

            # Build prompt
            context_str = "\n\n".join(
                f"<source_document name=\"{d.metadata.get('source', 'unknown')}\">\n{d.page_content}\n</source_document>"
                for d in docs
            )

            user_content = f"""
                <task_instructions>
                Your response must follow this exact structure:

                ##**1. Visual Analysis**
                First, provide a detailed description of what is shown in the provided image(s), based only on what you can see.

                ##**2. Synthesized Answer**
                Next, answer the user's original query. Your answer must synthesize information STRICTLY from your **Visual Analysis** and the provided text `<context>`.

                ##**3. Source Documents**
                At the very end of your response, cite the source from the context in brackets and backticks, like this: [`source-file-name.md`].
                </task_instructions>

                <context>
                {context_str}
                </context>

                <user_query>
                {query}
                </user_query>

                Now, generate the response following all instructions.
                """
            SYSTEM_PROMPT = """
You are AI Studio DevOps Assistant. Your function is to analyze images and text, then answer questions based ONLY on the provided materials.

**PERMANENT INSTRUCTIONS:**
1.  **Analyze and Answer from Context**: Your entire response MUST be derived exclusively from the provided `<context>` block and the user's image(s).
2.  **No External Knowledge**: You MUST NOT use any of your pre-trained knowledge or information outside the provided materials.
3.  **No Hallucination**: Do not invent or assume any details. If information is not present, it does not exist.
4.  **Handle Missing Information**: If the provided context and image(s) do not contain the answer, your ONLY response will be: "Based on the provided context, I cannot answer this question." Do not add any other words or explanation.
"""
            conversation = [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": user_content}
            ]

            prompt = self.tok.apply_chat_template(
                conversation,
                tokenize=False,
                add_generation_prompt=True
            )

            if images:
                prompt += f"\n\n[{len(images)} image(s) are provided for analysis]"

            # Generate
            try:
                self._clear_cuda()
                pixel_values = self._process_images(images) if images else None
                # NOTE(review): 8916 looks like a typo for 8192 — confirm the intended token budget.
                reply = self.model.chat(
                    self.tok, pixel_values, prompt,
                    generation_config=dict(max_new_tokens=8916, pad_token_id=self.tok.pad_token_id, eos_token_id=self.tok.eos_token_id),
                )
                self._clear_cuda()
                result_dict = {"reply": reply, "used_images": images}

                self.cache.set(query, result_dict)

                return result_dict
            except RuntimeError as e:
                logger.error("InternVL generation failed: %s", e)
                return {"reply": f"Error during generation: {e}", "used_images": images}

        def _retrieve_mm(self, query: str, text_db: Chroma, image_db: Chroma, siglip_embeds: Any, cross_encoder: Any, k_txt: int = 8, k_img: int = 8, fetch_k: int = 20) -> Dict[str, Any]:
            """
            Performs hybrid retrieval for text and associated images.

            Returns:
                Dict with `docs` (top-`k_txt` Documents) and `images` (up to `k_img` image paths
                whose `source` metadata matches one of the selected documents).
            """
            # 1. Coarse recall (text). Chroma returns distances: lower = more similar.
            docs_and_init = text_db.similarity_search_with_score(query, k=fetch_k)
            if not docs_and_init:
                return {"docs": [], "images": []}
            docs, init_distances = zip(*docs_and_init)

            # 2. Rerank (text). Cross-encoder scores: higher = more relevant.
            rerank_scores = cross_encoder.predict([(query, d.page_content) for d in docs])

            # 3. Hybrid scoring. Convert the distance into a similarity first so both terms
            #    point the same way; adding a raw distance would reward far-away documents.
            hybrid_scores = [
                0.6 * (1.0 - float(dist)) + 0.4 * float(rerank)
                for dist, rerank in zip(init_distances, rerank_scores)
            ]

            # 4. Select top-k text
            scored_docs = sorted(zip(docs, hybrid_scores), key=lambda x: x[1], reverse=True)
            selected_docs = [doc for doc, _ in scored_docs[:k_txt]]

            # 5. Image retrieval restricted to the sources of the selected documents
            #    (deduplicated, order-preserving; docs without a `source` are skipped).
            sources = list(dict.fromkeys(
                d.metadata["source"] for d in selected_docs if "source" in d.metadata
            ))
            if not sources:
                return {"docs": selected_docs, "images": []}

            q_emb = siglip_embeds.embed_query(query)
            img_hits = image_db.similarity_search_by_vector(q_emb, k=k_img * 2, filter={"source": {"$in": sources}})
            # NOTE(review): image paths are the page_content stored at indexing time; confirm they
            # are still valid where the model is served.
            images = [img.page_content for img in img_hits[:k_img]]

            return {"docs": selected_docs, "images": images}

        def _load(self):
            """Load tokenizer, image_processor, & 4-bit quantized model from `self.model_path`."""
            logger.info("Loading %s...", self.model_path)
            self._clear_cuda()

            self.tok = AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True)
            if self.tok.pad_token is None:
                self.tok.pad_token = self.tok.eos_token

            self.image_processor = AutoImageProcessor.from_pretrained(self.model_path, trust_remote_code=True)

            q_cfg = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.bfloat16,
                bnb_4bit_use_double_quant=True,
            )

            self.model = AutoModel.from_pretrained(
                self.model_path,
                quantization_config=q_cfg,
                torch_dtype=(torch.bfloat16 if self.device == "cuda" else torch.float32),
                low_cpu_mem_usage=True,
                trust_remote_code=True,
                device_map="auto" if self.device == "cuda" else None,
            ).eval()
            logger.info("Model loaded on %s.", self.device)

        def _process_images(self, image_paths: List[str]):
            """Convert a list of image filepaths to a single batched tensor (None on failure)."""
            if not image_paths:
                return None
            try:
                pil_images = []
                for p in image_paths:
                    # Close each file once its RGB copy has been materialised.
                    with PILImage.open(p) as img:
                        pil_images.append(img.convert("RGB"))
                processed_data = self.image_processor(images=pil_images, return_tensors="pt")
                # Match the dtype of the model's first parameter.
                pixel_values = processed_data['pixel_values'].to(device=self.device, dtype=next(self.model.parameters()).dtype)
                return pixel_values
            except Exception as e:
                logger.error("Image processing failed: %s", e)
                return None

        def _clear_cuda(self):
            """Release cached CUDA memory and wait for pending kernels (no-op without CUDA)."""
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                torch.cuda.synchronize()


    # --------------------------------------------------------------------------
    # MLflow pyfunc Methods
    # --------------------------------------------------------------------------

    def load_context(self, context: mlflow.pyfunc.PythonModelContext) -> None:
        """
        This method is called when loading an MLflow model. It initializes all
        necessary components using the artifacts logged with the model.
        """
        logger.info("--- Initializing MultimodalRagModel context ---")
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # 1. Define paths from artifacts
        self.chroma_dir = Path(context.artifacts["chroma_dir"])
        self.context_dir = Path(context.artifacts["context_dir"])
        self.image_dir = self.context_dir / "images"
        self.cache_dir = Path(context.artifacts["cache_dir"])

        logger.info(f"Artifacts loaded: chroma_dir='{self.chroma_dir}', context_dir='{self.context_dir}', cache_dir='{self.cache_dir}'")

        # 2. Initialize embedding models and reranker
        logger.info("Loading embedding models and cross-encoder...")
        self.text_embed_model = HuggingFaceEmbeddings(model_name="intfloat/e5-large-v2")
        self.siglip_embed_model = self.SiglipEmbeddings(model_id="google/siglip2-base-patch16-224", device=self.device)
        self.cross_encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")

        # 3. Load persisted Chroma vector stores
        logger.info("Loading ChromaDB vector stores...")
        self.text_db = Chroma(
            collection_name="mm_text",
            persist_directory=str(self.chroma_dir),
            embedding_function=self.text_embed_model,
        )
        self.image_db = Chroma(
            collection_name="mm_image",
            persist_directory=str(self.chroma_dir),
            embedding_function=self.siglip_embed_model,
        )
        logger.info(f"Text DB count: {self.text_db._collection.count()}, Image DB count: {self.image_db._collection.count()}")

        # 4. Initialize semantic cache
        self.cache = self.SemanticCache(persist_directory=self.cache_dir, embedding_function=self.text_embed_model)

        # 5. Load the main multimodal LLM. Prefer the checkpoint logged as `local_model_dir`
        #    (it is copied into the model artifacts); fall back to the Hub id otherwise.
        local_model_dir = context.artifacts.get("local_model_dir")
        model_path = None
        if local_model_dir and (Path(local_model_dir) / "config.json").exists():
            model_path = local_model_dir
            logger.info(f"Using local InternVL checkpoint at '{local_model_dir}'")
        else:
            logger.info("No usable local InternVL checkpoint found; falling back to the Hugging Face Hub.")
        self.mm_llm = self.InternVLMM(device=self.device, cache=self.cache, model_path=model_path)
        logger.info("--- Context initialization complete ---")


    def predict(self, context: mlflow.pyfunc.PythonModelContext, model_input: pd.DataFrame) -> pd.DataFrame:
        """
        MLflow inference entrypoint.
        Expects a pandas DataFrame with a "query" column.
        Returns a DataFrame with "reply" and "used_images" columns (one row per query;
        `used_images` holds a list of image paths). The columns exist even for empty input.
        """
        logger.info("Received prediction request.")
        queries = model_input["query"].tolist()
        results = []

        # The context passed to the generate function needs all the retrieved components
        retrieval_context = {
            "text_db": self.text_db,
            "image_db": self.image_db,
            "siglip_embeds": self.siglip_embed_model,
            "cross_encoder": self.cross_encoder
        }

        for query in queries:
            logger.info(f"Processing query: '{query}'")
            response_dict = self.mm_llm.generate(query, retrieval_context)
            results.append(response_dict)

        return pd.DataFrame(results, columns=["reply", "used_images"])

    @classmethod
    def log_model(cls, model_name: str, local_model_path: str) -> None:
        """
        Helper class method to log the MultimodalRagModel to MLflow.

        Args:
            model_name: Artifact path (and later registry name) of the logged model.
            local_model_path: Directory logged as the `local_model_dir` artifact.

        Raises:
            FileNotFoundError: If any artifact directory does not exist.
        """
        logger.info(f"--- Logging '{model_name}' to MLflow ---")

        # 1. Define local artifact paths, including the local model directory.
        #    Paths are resolved relative to the project root (parent of the notebook directory).
        project_root = Path.cwd().parent.resolve()
        artifacts = {
            "local_model_dir": local_model_path,
            "chroma_dir": str(project_root / "data" / "chroma_store"),
            "context_dir": str(project_root / "data" / "context"),
            # NOTE(review): this directory is inside chroma_dir, so it is copied twice.
            "cache_dir": str(project_root / "data" / "chroma_store" / "semantic_cache"),
        }

        # 2. Validate artifact paths
        for key, path in artifacts.items():
            if not Path(path).exists():
                raise FileNotFoundError(f"Required artifact not found: {key} at '{path}'")

        # 3. Define model signature
        # NOTE(review): `predict` returns `used_images` as a Python list, while this schema
        # declares a string column — confirm what downstream consumers expect.
        input_schema = Schema([ColSpec(DataType.string, "query")])
        output_schema = Schema([
            ColSpec(DataType.string, "reply"),
            ColSpec(DataType.string, "used_images"),
        ])
        signature = ModelSignature(inputs=input_schema, outputs=output_schema)

        # 4. Log the model
        mlflow.pyfunc.log_model(
            artifact_path=model_name,
            python_model=cls(),
            artifacts=artifacts,
            signature=signature,
            pip_requirements=str(project_root / "requirements.txt"),
            code_paths=[str(project_root / "src")],
        )
        logger.info(f"✅ Successfully logged '{model_name}'")

In [None]:
# --- MLflow Configuration ---
MODEL_NAME = "AIStudio-Multimodal-Chatbot-Model"
RUN_NAME = f"Register_{MODEL_NAME}"
EXPERIMENT_NAME = "AIStudio-Multimodal-Chatbot-Experiment"

# Directory logged as the model's `local_model_dir` artifact (local InternVL weights).
# Override it with the LOCAL_MODEL_PATH environment variable instead of editing this cell.
LOCAL_MODEL_PATH = os.getenv("LOCAL_MODEL_PATH", "/home/jovyan/datafabric/InternVL3-8B-Instruct-1")

# Set MLflow tracking URI and experiment.
# Override MLFLOW_TRACKING_URI for your environment, e.g. a remote server or another local path.
mlflow.set_tracking_uri(os.getenv("MLFLOW_TRACKING_URI", "/phoenix/mlflow"))
mlflow.set_experiment(experiment_name=EXPERIMENT_NAME)

logger.info(f"Using MLflow tracking URI: {mlflow.get_tracking_uri()}")
logger.info(f"Using MLflow experiment: '{EXPERIMENT_NAME}'")

In [None]:
%%time

# --- Start MLflow Run and Log the Model ---
try:
    with mlflow.start_run(run_name=RUN_NAME) as run:
        run_id = run.info.run_id
        logger.info(f"Started MLflow run: {run_id}")

        # Use the class method to log the model and its artifacts
        MultimodalRagModel.log_model(model_name=MODEL_NAME, local_model_path=LOCAL_MODEL_PATH)


        model_uri = f"runs:/{run_id}/{MODEL_NAME}"
        logger.info(f"Registering model from URI: {model_uri}")
        
        # Register the model in the MLflow Model Registry
        mlflow.register_model(model_uri=model_uri, name=MODEL_NAME)
        logger.info(f"✅ Successfully registered model '{MODEL_NAME}'")

except FileNotFoundError as e:
    logger.error(f"Error: A required file or directory was not found. Please ensure the project structure is correct.")
    logger.error(f"Details: {e}")
except Exception as e:
    logger.error(f"An unexpected error occurred during the MLflow run: {e}", exc_info=True)

In [None]:
# --- Retrieve the latest version from the Model Registry ---
# `get_latest_versions(stages=...)` relies on the deprecated stage API. Instead, search all
# versions of the model, take the highest version number, and address it with a
# canonical `models:/<name>/<version>` registry URI.
try:
    client = MlflowClient()
    versions = client.search_model_versions(f"name='{MODEL_NAME}'")
    if not versions:
        raise RuntimeError(f"No registered versions found for model '{MODEL_NAME}'.")

    # Version numbers are returned as strings; compare them numerically.
    latest_version = max(versions, key=lambda mv: int(mv.version))
    logger.info(f"Found latest version '{latest_version.version}' for model '{MODEL_NAME}' in stage '{latest_version.current_stage}'.")
    model_uri_registry = f"models:/{MODEL_NAME}/{latest_version.version}"

except Exception as e:
    logger.error(f"Failed to retrieve model from registry: {e}", exc_info=True)
    model_uri_registry = None  # Ensure variable exists for the next cells


In [None]:
# Load the registered model as a generic pyfunc; `loaded_model` stays None on any failure
# so the following cells can skip gracefully.
loaded_model = None
if not model_uri_registry:
    logger.warning("Skipping model loading due to previous errors.")
else:
    try:
        logger.info(f"Loading model from: {model_uri_registry}")
        loaded_model = mlflow.pyfunc.load_model(model_uri=model_uri_registry)
        logger.info("✅ Successfully loaded model from registry.")
    except Exception as load_err:
        logger.error(f"Failed to load model from registry URI: {load_err}", exc_info=True)
        loaded_model = None

In [None]:
# --- Helper Function to Display Test Results ---
def display_results(query: str, result_df: pd.DataFrame):
    """
    Print the query, the model reply, and the paths of the images used for one prediction.

    Args:
        query: The question that was sent to the model.
        result_df: Output of `loaded_model.predict` with `reply` and `used_images` columns.
            Only the first row is displayed. `used_images` may be a list/tuple (what
            `predict` returns), an array-like with `.tolist()`, or a stringified list.
    """
    import ast  # only needed to safely parse stringified image lists

    if result_df.empty:
        print("Received an empty result.")
        return

    reply = result_df["reply"].iloc[0]
    raw_images = result_df["used_images"].iloc[0]

    # Never `eval` model output: `ast.literal_eval` only accepts Python literals, so a
    # malicious string cannot execute code. Unparseable values mean "no images".
    if isinstance(raw_images, str):
        try:
            parsed_images = ast.literal_eval(raw_images)
        except (ValueError, SyntaxError, TypeError):
            parsed_images = []
    elif hasattr(raw_images, "tolist"):  # e.g. numpy arrays
        parsed_images = raw_images.tolist()
    else:
        parsed_images = raw_images
    image_paths = list(parsed_images) if isinstance(parsed_images, (list, tuple)) else []

    print("---" * 20)
    print(f"❓ Query:\n{query}\n")
    print(f"🤖 Reply:\n{reply}\n")

    if image_paths:
        print(f"🖼️ Displaying {len(image_paths)} retrieved image(s):")
        # You can integrate the display_images function from the original notebook here
        # For simplicity, we'll just print the paths.
        for path in image_paths:
            print(f"  - {path}")
    else:
        print("▶ No images were retrieved for this query.")
    print("---" * 20 + "\n")

In [None]:
if not loaded_model:
    logger.warning("Skipping sample inference because the model was not loaded.")
else:
    logger.info("Running sample inference with the loaded model...")

    # Representative questions used to smoke-test the registered model end to end.
    sample_queries = [
        "How do i run blueprints locally?",
        "What are some feature flags that i can enable in AIStudio?",
        "How do i manually clean my environment without hooh?",
    ]

    for query in sample_queries:
        # Each query is predicted separately so one failure does not stop the rest.
        try:
            input_payload = pd.DataFrame({"query": [query]})
            result = loaded_model.predict(input_payload)
            display_results(query, result)
        except Exception as predict_err:
            logger.error(f"Prediction failed for query '{query}': {predict_err}", exc_info=True)


In [None]:
# Mark the end of the run in the notebook log output.
logger.log(logging.INFO, "✅ Notebook execution completed.")


Built with ❤️ using Z by HP AI Studio.

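# Evaluate Hallucinations & Relevance

> **Note:** The evaluation cells below are currently commented out. They reference objects that are not defined elsewhere in this notebook (`config`, `initialize_llm`, `secrets`, `mm_chain`, `retriever`, `LocalGenAIJudge`). Adapt them before re-enabling.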

In [None]:
# model_source = config["model_source"]

In [None]:
# %%time

# llm = initialize_llm(model_source, secrets)

In [None]:

# def model(batch_df: pd.DataFrame) -> pd.DataFrame:
#     preds, contexts = [], []
#     for q in batch_df["questions"]:
#         answer = mm_chain.invoke(q)
#         preds.append(answer)

#         docs = retriever.get_relevant_documents(q)
#         contexts.append(" ".join(d.page_content for d in docs))

#     # keep the incoming index so every batch’s rows stay unique
#     return pd.DataFrame(
#         {
#             "result": preds,
#             "source_documents": contexts,
#         },
#         index=batch_df.index,      #  ← key line
#     )

# # --- 3)  Evaluation dataset
# eval_df = pd.DataFrame({"questions": [
#     "What naming convention should I use for a new blueprint project folder?",
#     "What is the first step in the standard blueprint testing workflow?",
#     "How do I fetch logs from a running Kubernetes pod?",
# ]})

# judge = LocalGenAIJudge(
#     llm=llm
# )

# faithfulness_metric = judge.to_mlflow_metric("faithfulness")
# relevance_metric = judge.to_mlflow_metric("relevance")

# results = mlflow.evaluate(
#     model,
#     eval_df,
#     predictions="result",
#     evaluators="default",
#     extra_metrics=[faithfulness_metric, relevance_metric],
#     evaluator_config={
#         "col_mapping": {
#             "inputs": "questions",
#             "context": "source_documents"
#         }
#     },
# )
