<h1 style=\"text-align: center; font-size: 50px;\"> 🤖 MLFlow Registration for Agentic RAG </h1>

# Imports

In [1]:
%pip install -r ../requirements.txt --quiet 

[0mNote: you may need to restart the kernel to use updated packages.


In [2]:
from __future__ import annotations

import json
import logging
import os
import sys
from collections import namedtuple
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional, TypedDict

import pandas as pd
import tensorrt_llm

import mlflow.pyfunc
from mlflow.models.signature import ModelSignature
from mlflow.tracking import MlflowClient
from mlflow.types import ColSpec, DataType, Schema

from langchain.docstore.document import Document
from langchain.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma
from langgraph.graph import StateGraph, START, END

# Ensure the project root is on the Python path so we can import local modules
project_root = Path('.').resolve().parent
sys.path.insert(0, str(project_root))

from src.trt_llm_langchain import TensorRTLangchain

  from .autonotebook import tqdm as notebook_tqdm
2025-06-13 21:21:44,000 - INFO - flashinfer.jit: Prebuilt kernels not found, using JIT backend


[TensorRT-LLM] TensorRT-LLM version: 0.18.0


In [3]:
LOG_FORMAT = "%(asctime)s [%(levelname)s] %(name)s: %(message)s"
LOG_DATEFMT = "%Y-%m-%d %H:%M:%S"

logging.basicConfig(
    level=logging.INFO,
    format=LOG_FORMAT,
    datefmt=LOG_DATEFMT,
)

# Named logger for the Agentic RAG notebook
logger = logging.getLogger("agentic_rag_notebook")

In [4]:
logger.info('Notebook execution started.')

2025-06-13 21:21:46 [INFO] agentic_rag_notebook: Notebook execution started.


In [5]:
# --------------------------------------------------------------------------------------------------
# RagAgenticModel
# --------------------------------------------------------------------------------------------------
class RagAgenticModel(mlflow.pyfunc.PythonModel):
    """
    rag_mlflow_model.py
    
    This module defines an MLflow PythonModel subclass (`RagAgenticModel`) that faithfully reproduces
    the Agentic RAG workflow defined in the Jupyter notebook. It uses LangGraph to replicate the exact
    state graph, decision logic, and node functions (ingest_query, check_relevance, rewrite_query,
    check_memory, retrieve_chunks, generate_answer, update_memory, output_answer).
    
    Artifacts expected when registering/logging:
      - "chroma_dir": Persisted Chroma vectorstore directory
      - "memory_path": Path to a JSON file (SimpleKVMemory)
      
    Usage:
      RagAgenticModel.log_model(model_name="Agentic_RAG_Model")
      
    After logging, you can load the model via:
      mlflow.pyfunc.load_model("models:/Agentic_RAG_Model/Production")
      
    and then call .predict({"query": "<user question>"}) to get a dict with keys:
      - "answer": str
      - "retrieved_chunks": List[str]
      - "messages": List[Dict[str, Any]]
    """

    TOPIC: str = "AI Studio"
    CONTEXT_DIR: Path = Path("../data/context")             
    CHROMA_DIR: Path = Path("../data/chroma_store")     
    MEMORY_PATH: Path = Path("../data/memory/memory.json")     
    MANIFEST_PATH: Path = CHROMA_DIR / "manifest.json"

    class SimpleKVMemory:
        """Very small persistent key-value store (JSON on disk)."""
    
        def __init__(self, file_path: Path) -> None:
            self.file_path: Path = file_path
            self._store: Dict[str, str] = self._load()
    
        # ---------- public ----------------------------------------------------
        def get(self, key: str) -> Optional[str]:
            """Return answer if present, else None."""
            return self._store.get(key)
    
        def set(self, key: str, value: str) -> None:
            """Save answer and flush to disk."""
            self._store[key] = value
            self._dump()
    
        # ---------- private ---------------------------------------------------
        def _load(self) -> Dict[str, str]:
            if self.file_path.exists():
                try:
                    with self.file_path.open("r", encoding="utf-8") as f:
                        return json.load(f)
                except Exception as exc:  
                    logger.warning("Failed to load memory (%s). Starting fresh.", exc)
            return {}
    
        def _dump(self) -> None:
            self.file_path.parent.mkdir(parents=True, exist_ok=True)
            with self.file_path.open("w", encoding="utf-8") as f:
                json.dump(self._store, f, ensure_ascii=False, indent=2)


    class RAGState(TypedDict, total=False):
        topic: str
        query: str
        is_relevant: Optional[bool]
        rewritten_query: Optional[str]
        retrieved_chunks: List[str]
        answer: Optional[str]
        from_memory: Optional[bool]
        messages: List[Dict[str, Any]]  # full conversation with LLM

    def load_context(self, context: mlflow.pyfunc.PythonModelContext) -> None:
        """
        Load artifacts and initialize all components:
          - Embedding model (HuggingFaceEmbeddings)
          - Chroma vectorstore from artifact "chroma_dir"
          - SimpleKVMemory from artifact "memory_path"
          - Namedtuple Response (for LLM outputs)
          - Build and compile the LangGraph state graph to self._compiled_graph
        """
        self.TOPIC = RagAgenticModel.TOPIC
        self._logger = logging.getLogger("RagAgenticModel")
        if not self._logger.handlers:
            handler = logging.StreamHandler()
            handler.setFormatter(
                logging.Formatter("%(asctime)s [%(levelname)s] %(name)s: %(message)s")
            )
            self._logger.addHandler(handler)
            self._logger.setLevel(logging.INFO)

        # 1. Load embedding model
        try:
            self._embed_model = HuggingFaceEmbeddings(
            model_name="sentence-transformers/all-mpnet-base-v2",
            encode_kwargs={"normalize_embeddings": True},
            )
        except:
            self._embed_model = HuggingFaceEmbeddings(
            model_name="sentence-transformers/all-mpnet-base-v2",
            encode_kwargs={"normalize_embeddings": True},
            )
            
        # 2. Load persisted Chroma vectorstore
        chroma_dir = Path(context.artifacts["chroma_dir"])
        self._vectorstore = Chroma(
            collection_name="-".join(self.TOPIC.split()),
            persist_directory=str(chroma_dir),
            embedding_function=self._embed_model,
        )

        # 3. Load LLM via TensorRTLangchain
        sampling_params = tensorrt_llm.SamplingParams(
            temperature=0.0,
            top_k=1,
            repetition_penalty=1.2,
            stop_token_ids=[128009],
        )
        self._llm = TensorRTLangchain(model_path="nvidia/Llama-3.1-Nemotron-Nano-8B-v1", sampling_params=sampling_params)

        # 4. Initialize persistent memory
        memory_path = Path(context.artifacts["memory_path"])
        memory_path.parent.mkdir(parents=True, exist_ok=True)
        if not memory_path.exists():
            memory_path.write_text("{}", encoding="utf-8")
        self._memory = RagAgenticModel.SimpleKVMemory(memory_path)

        # 5. Define a simple Response namedtuple (mirrors notebook)
        self._LLMResponse = namedtuple("Response", ["content"])

        # 6. Build and compile the LangGraph state graph
        self._build_state_graph()

    # ----------------------------------------
    # Node Functions (each mirrors the notebook)
    # ----------------------------------------
    def ingest_query(self, state: RagAgenticModel.RAGState) -> Dict[str, Any]:
        """
        Log the incoming user query and record it in the message history.
        """
        user_query = state["query"]
        self._logger.info("Received user query: %s", user_query)
        previous_messages = state.get("messages", [])
        new_messages = previous_messages + [{"role": "user", "content": user_query}]
        return {"messages": new_messages}

    def check_relevance(self, state: RagAgenticModel.RAGState) -> Dict[str, Any]:
        """
        Ask the LLM whether the query relates to our topic.
        If not relevant, include a default apology answer.
        """
        topic = state["topic"]
        user_query = state["query"]

        system_prompt = (
            "You are a strict classifier. Only respond with either \"yes\" or \"no\". "
            "Do not include any additional words, explanations, or punctuation. "
            "Answer based solely on whether the user's query is about the specified topic."
        )
        user_prompt = (
            f"The topic is: \"{topic}\"\n\n"
            f"User query: \"{user_query}\"\n\n"
            "Is this query related to the topic above? Respond with only 'yes' or 'no'."
            "Answer: "
        )

        resp = self._get_response_from_llm(system_prompt, user_prompt)
        is_relevant = "yes" in resp.strip().lower()
        self._logger.info("Relevance check result: %s", is_relevant)

        messages = state.get("messages", []) + [
            {"role": "developer", "content": "Relevance check result:"},
            {"role": "assistant", "content": resp},
        ]
        result: Dict[str, Any] = {"is_relevant": is_relevant, "messages": messages}
        if not is_relevant:
            result["answer"] = f"Sorry, I can only answer questions related to {topic}."
        return result

    def check_memory(self, state: RagAgenticModel.RAGState) -> Dict[str, Any]:
        """
        Look up the exact user query in memory and return the cached answer if found.
        """
        raw_query = state["query"]
        key = raw_query.strip().lower()
        cached_answer = self._memory.get(key)
        if cached_answer is not None:
            self._logger.info("Cache hit for query: %s", raw_query)
            return {"answer": cached_answer, "from_memory": True}
        self._logger.info("Cache miss for query: %s", raw_query)
        return {"from_memory": False}

    def rewrite_query(self, state: RagAgenticModel.RAGState) -> Dict[str, Any]:
        """
        Correct any grammar in the question and rewrite it as a clear statement
        without altering its meaning, to improve retrieval.
        """
        original = state["query"]
        system_prompt = (
            "You are a rewriting assistant. Your only task is to convert a question into a "
            "grammatically correct statement. Do not change its meaning. "
            "Output only the corrected statement—no explanations or extra text."
        )
        user_prompt = (
            "Convert the following question into a grammatically correct statement "
            "that preserves the original meaning exactly:\n\n"
            "Note: Output only the corrected statement—no explanations or extra text.\n"
            f"Question: \"{original}\"\n\n"
            "Corrected Statement:"
        )

        resp = self._get_response_from_llm(system_prompt, user_prompt).strip()
        self._logger.info("Rewritten query: %s", resp)

        messages = state.get("messages", []) + [
            {"role": "developer", "content": "Rewritten query:"},
            {"role": "assistant", "content": resp},
        ]
        return {"rewritten_query": resp, "messages": messages}

    def retrieve_chunks(self, state: RagAgenticModel.RAGState) -> Dict[str, Any]:
        """
        Fetch the top-k most relevant chunks for the rewritten query.
        """
        statement = state["rewritten_query"]
        docs = self._vectorstore.similarity_search(statement, k=5)
        chunks = [doc.page_content for doc in docs]
        self._logger.info("Retrieved %d chunks for query.", len(chunks))
        return {"retrieved_chunks": chunks}

    def generate_answer(self, state: RagAgenticModel.RAGState) -> Dict[str, Any]:
        """
        Use the LLM to generate an answer based solely on retrieved context.
        """
        topic = state["topic"]
        user_query = state["query"]
        context = "\n\n---\n\n".join(state["retrieved_chunks"])

        system_prompt = (
            f"You are a knowledgeable assistant specialized in {topic}. Your task is to answer "
            "the user query using only the information found within the <context> block. "
            "Ignore any external knowledge. If the context does not contain the answer, reply exactly with: \"I don't know.\" "
            "Do not assume, infer, or add any extra information. "
            "Respond with only the answer—do not include any introductory or explanatory text."
        )
        user_prompt = (
            f"<context>\n{context}\n</context>\n\n"
            f"User query: \"{user_query}\"\n\n"
            "Based only on the context above, provide the exact answer to the query. "
            "If the context does not contain the answer, respond exactly with: \"I don't know.\" "
            "Give only the answer—do not include any intro phrases such as 'The answer is' or 'Here it is'."
            "Answer: "
        )

        resp = self._get_response_from_llm(system_prompt, user_prompt).strip()
        self._logger.info("Generated answer (%d chars)", len(resp))

        messages = state.get("messages", []) + [
            {"role": "developer", "content": "Generated answer:"},
            {"role": "assistant", "content": resp},
        ]
        return {"answer": resp, "messages": messages}

    def update_memory(self, state: RagAgenticModel.RAGState) -> Dict[str, Any]:
        """
        Store new query-answer pairs in memory for faster future lookup.
        """
        if state.get("from_memory"):
            return {}
        raw_query = state["query"]
        key = raw_query.strip().lower()
        answer = state["answer"]
        if answer is not None:
            self._memory.set(key, answer)
            self._logger.info("Stored query-answer in memory for key: %s", key)
        return {}

    def output_answer(self, state: RagAgenticModel.RAGState) -> Dict[str, Any]:
        """
        The final node. We do not print to STDOUT when serving via MLflow.
        Just return an empty dict as this node does not add new state.
        """
        return {}

    # ----------------------------------------
    # Helper Methods
    # ----------------------------------------
    def _get_response_from_llm(self, system_prompt: str, user_prompt: str) -> str:
        """
        Wrap the LLM call into the meta-prompt format and return the .content string.
        """
        meta_llama_prompt = (
            f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n"
            f"{system_prompt}<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n"
            f"{user_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n"
        )
        raw = self._llm(meta_llama_prompt)
        # TensorRTLangchain returns a raw string; we can wrap into Response if needed
        return raw

    # def _route_relevance(self, state: RagAgenticModel.RAGState) -> Literal["irrelevant", "relevant"]:
    def _route_relevance(self, state: RagAgenticModel.RAGState) -> str:
        return "relevant" if state["is_relevant"] else "irrelevant"

    # def _route_memory(self, state: RagAgenticModel.RAGState) -> Literal["cached", "not_cached"]:
    def _route_memory(self, state: RagAgenticModel.RAGState) -> str:
        return "cached" if state.get("from_memory") else "not_cached"

    def _build_state_graph(self) -> None:
        """
        Construct and compile the LangGraph state graph exactly as in the notebook.
        """
        rag_graph = StateGraph(RagAgenticModel.RAGState)

        # Add nodes
        rag_graph.add_node("ingest_query", self.ingest_query)
        rag_graph.add_node("check_relevance", self.check_relevance)
        rag_graph.add_node("rewrite_query", self.rewrite_query)
        rag_graph.add_node("check_memory", self.check_memory)
        rag_graph.add_node("retrieve_chunks", self.retrieve_chunks)
        rag_graph.add_node("generate_answer", self.generate_answer)
        rag_graph.add_node("update_memory", self.update_memory)
        rag_graph.add_node("output_answer", self.output_answer)

        # Add edges
        rag_graph.add_edge(START, "ingest_query")
        rag_graph.add_edge("ingest_query", "check_relevance")

        rag_graph.add_conditional_edges(
            "check_relevance",
            self._route_relevance,
            {
                "irrelevant": "output_answer",
                "relevant": "check_memory",
            },
        )

        rag_graph.add_conditional_edges(
            "check_memory",
            self._route_memory,
            {
                "cached": "output_answer",
                "not_cached": "rewrite_query",
            },
        )

        rag_graph.add_edge("rewrite_query", "retrieve_chunks")
        rag_graph.add_edge("retrieve_chunks", "generate_answer")
        rag_graph.add_edge("generate_answer", "update_memory")
        rag_graph.add_edge("update_memory", "output_answer")
        rag_graph.add_edge("output_answer", END)

        # Compile graph
        self._compiled_graph = rag_graph.compile()

    # ----------------------------------------
    # MLflow PythonModel Interface
    # ----------------------------------------
    def predict(self, context: mlflow.pyfunc.PythonModelContext, model_input):
        """
        The MLflow inference entrypoint. Expects model_input = {"query": "<user question>"}.
        Returns a dict with:
          - "answer": str
          - "retrieved_chunks": List[str]
          - "messages": List[Dict[str, Any]]
        """
        print('MODEL INPUT')
        print(type(model_input))
        print(model_input)
        # If MLflow gave us a pandas DataFrame, extract the first row
        if isinstance(model_input, pd.DataFrame):
            if "query" not in model_input.columns:
                raise Exception("DataFrame must contain a 'query' column.")
            # Take the first record in that column
            raw_query = model_input["query"].iloc[0]
        else:
            # Could be a plain dict or something else
            if not isinstance(model_input, dict):
                raise Exception(
                    f"Unexpected input type: {type(model_input)}. "
                    "Expected pandas.DataFrame or dict with 'query'."
                )
            # If it's a dict, accept either string or single-element list
            if "query" not in model_input:
                raise Exception("Input dict must contain key 'query'.")
            raw_query = model_input["query"]

        # Initialize state with topic, query, and empty messages
        initial_state: RagAgenticModel.RAGState = {
            "topic": self.TOPIC,
            "query": raw_query.strip(),
            "messages": [],
        }

        # Invoke the compiled LangGraph
        final_state = self._compiled_graph.invoke(input=initial_state)

        # Extract elements to return
        answer = final_state.get("answer", "")
        retrieved_chunks = final_state.get("retrieved_chunks", [])
        messages = final_state.get("messages", [])

        return {
            "answer": answer,
            "retrieved_chunks": retrieved_chunks,
            "messages": messages,
        }

    @classmethod
    def log_model(cls, model_name: str) -> None:
        """
        Logs RagAgenticModel to MLflow and registers it in the Model Registry.

        1. Assumes the following local directories exist relative to this file:
             - data/chroma_store/       (persisted Chroma index)
             - llm_weights/Llama-3.1-Nemotron-Nano-8B-v1/ (LLM weights folder)
             - data/memory/memory.json  (initial memory file; created if missing)
        2. Creates an MLflow run, logs the PyFunc model with the three artifacts, and registers it.

        Args:
          model_name (str): Name to register under in the MLflow Model Registry.
        """
        # 1. Configure MLflow experiment & logging
        logger = logging.getLogger("RagAgenticModel.log_model")
        if not logger.handlers:
            h = logging.StreamHandler()
            h.setFormatter(
                logging.Formatter("%(asctime)s [%(levelname)s] %(name)s: %(message)s")
            )
            logger.addHandler(h)
            logger.setLevel(logging.INFO)

        # 2. Define local artifact paths (adjust if your folder structure differs)
        project_root = Path.cwd().parent.resolve()
        chroma_dir_local = project_root / "data" / "chroma_store"
        memory_path_local = project_root / "data" / "memory" / "memory.json"

        # 3. Validate local artifacts
        if not chroma_dir_local.exists():
            raise FileNotFoundError(f"Chroma directory not found at {chroma_dir_local}")
        memory_path_local.parent.mkdir(parents=True, exist_ok=True)
        if not memory_path_local.exists():
            memory_path_local.write_text("{}", encoding="utf-8")


        # 4.a. Define input schema: a single column "query" of type string
        input_schema = Schema([ColSpec(DataType.string, "query")])
        # We omit output_schema (PyFunc can return arbitrary JSON), but we supply the input signature
        signature = ModelSignature(inputs=input_schema)

        # 4.b. Collect artifacts
        artifacts: Dict[str, str] = {
            "chroma_dir": str(RagAgenticModel.CHROMA_DIR),
            "memory_path": str(RagAgenticModel.MEMORY_PATH),
        }


        # 4.c. Log the PyFunc model
        mlflow.pyfunc.log_model(
            artifact_path=model_name,
            python_model=cls(),
            artifacts=artifacts,
            signature=signature,
            pip_requirements="../requirements.txt",
            code_paths=["../src"],
        )
        logger.info(f"Logged RagAgenticModel under artifact_path '{model_name}'")



In [6]:
# 1. Configuration: experiment, model name, run name
MODEL_NAME = "Agentic_RAG_Model"
RUN_NAME = f"Register_{MODEL_NAME}"
EXPERIMENT_NAME = "Agentic_RAG_Experiment"

In [7]:
# 2. Set MLflow tracking URI and experiment
mlflow.set_tracking_uri(os.getenv("MLFLOW_TRACKING_URI", "/phoenix/mlflow"))
mlflow.set_experiment(experiment_name=EXPERIMENT_NAME)
print(f"Using MLflow tracking URI: {mlflow.get_tracking_uri()}")
print(f"Experiment: {EXPERIMENT_NAME}")

Using MLflow tracking URI: /phoenix/mlflow
Experiment: Agentic_RAG_Experiment


In [8]:
%%time

# 3. Start an MLflow run and log + register the model
with mlflow.start_run(run_name=RUN_NAME) as run:
    print(f"Started MLflow run: {run.info.run_id}")

    # Log RagAgenticModel using the class method
    RagAgenticModel.log_model(model_name=MODEL_NAME)

    model_uri = f"runs:/{run.info.run_id}/{MODEL_NAME}"
    mlflow.register_model(model_uri=model_uri, name=MODEL_NAME)



Started MLflow run: 4d6479d2e5614326ae847b856e1da63d


Downloading artifacts: 100%|██████████| 6/6 [00:00<00:00, 144.15it/s]
Downloading artifacts: 100%|██████████| 1/1 [00:00<00:00, 120.48it/s]
2025-06-13 21:21:48,912 [INFO] RagAgenticModel.log_model: Logged RagAgenticModel under artifact_path 'Agentic_RAG_Model'
2025-06-13 21:21:48 [INFO] RagAgenticModel.log_model: Logged RagAgenticModel under artifact_path 'Agentic_RAG_Model'
Registered model 'Agentic_RAG_Model' already exists. Creating a new version of this model...


CPU times: user 1.1 s, sys: 264 ms, total: 1.37 s
Wall time: 3.34 s


Created version '7' of model 'Agentic_RAG_Model'.


In [9]:
# 4. Retrieve the latest version from the Model Registry
client = MlflowClient()
versions = client.get_latest_versions(MODEL_NAME, stages=["None"])
if not versions:
    raise RuntimeError(f"No registered versions found for model '{MODEL_NAME}'.")
latest_version = versions[0].version

model_info = mlflow.models.get_model_info(f"models:/{MODEL_NAME}/{latest_version}")
print(f"Latest registered version of '{MODEL_NAME}': {latest_version}")
print(f"Signature: {model_info.signature}")

  versions = client.get_latest_versions(MODEL_NAME, stages=["None"])


Latest registered version of 'Agentic_RAG_Model': 7
Signature: inputs: 
  ['query': string (required)]
outputs: 
  None
params: 
  None



In [10]:
%%time

# 5. Load the model from the Model Registry
loaded_model = mlflow.pyfunc.load_model(model_uri=f"models:/{MODEL_NAME}/{latest_version}")
print(f"Successfully loaded model '{MODEL_NAME}' version {latest_version} for inference.")

  self._embed_model = HuggingFaceEmbeddings(
2025-06-13 21:21:51 [INFO] sentence_transformers.SentenceTransformer: Use pytorch device_name: cuda:0
2025-06-13 21:21:51 [INFO] sentence_transformers.SentenceTransformer: Load pretrained SentenceTransformer: sentence-transformers/all-mpnet-base-v2
  self._vectorstore = Chroma(
2025-06-13 21:21:54 [INFO] chromadb.telemetry.product.posthog: Anonymized telemetry enabled. See                     https://docs.trychroma.com/telemetry for more information.
Loading Model: [1;32m[1/3]	[0mDownloading HF model
[38;20mDownloaded model to /root/.cache/huggingface/hub/models--nvidia--Llama-3.1-Nemotron-Nano-8B-v1/snapshots/a22e1c57330633cd3522903f9bb82480bf3192a6
[0m[38;20mTime: 53.574s
[0mLoading Model: [1;32m[2/3]	[0mLoading HF model to memory
230it [00:00, 974.27it/s]
[38;20mTime: 0.800s
[0mLoading Model: [1;32m[3/3]	[0mBuilding TRT-LLM engine
[38;20mTime: 63.005s
[0m[1;32mLoading model done.
[0m[38;20mTotal latency: 117.380s
[0m2025

[TensorRT-LLM] TensorRT-LLM version: 0.18.0
[TensorRT-LLM][INFO] Engine version 0.18.0 found in the config file, assuming engine(s) built by new builder API.
[TensorRT-LLM][INFO] Refreshed the MPI local session
[TensorRT-LLM][INFO] MPI size: 1, MPI local size: 1, rank: 0
[TensorRT-LLM][INFO] Rank 0 is using GPU 0
[TensorRT-LLM][INFO] TRTGptModel maxNumSequences: 2048
[TensorRT-LLM][INFO] TRTGptModel maxBatchSize: 2048
[TensorRT-LLM][INFO] TRTGptModel maxBeamWidth: 1
[TensorRT-LLM][INFO] TRTGptModel maxSequenceLen: 131072
[TensorRT-LLM][INFO] TRTGptModel maxDraftLen: 0
[TensorRT-LLM][INFO] TRTGptModel mMaxAttentionWindowSize: (131072) * 32
[TensorRT-LLM][INFO] TRTGptModel enableTrtOverlap: 0
[TensorRT-LLM][INFO] TRTGptModel normalizeLogProbs: 0
[TensorRT-LLM][INFO] TRTGptModel maxNumTokens: 8192
[TensorRT-LLM][INFO] TRTGptModel maxInputLen: 8192 = min(maxSequenceLen - 1, maxNumTokens) since context FMHA and usePackedInput are enabled
[TensorRT-LLM][INFO] TRTGptModel If model type is enc

In [11]:
# 6. Run a sample inference using the loaded model
sample_query = "What is the hardware requirement for AI Studio?"
input_payload = {"query": sample_query}

print("\n=== Running Sample Inference ===")
result = loaded_model.predict(input_payload)

2025-06-13 21:24:39,411 [INFO] RagAgenticModel: Received user query: What is the hardware requirement for AI Studio?
2025-06-13 21:24:39 [INFO] RagAgenticModel: Received user query: What is the hardware requirement for AI Studio?
  raw = self._llm(meta_llama_prompt)



=== Running Sample Inference ===
MODEL INPUT
<class 'pandas.core.frame.DataFrame'>
                                             query
0  What is the hardware requirement for AI Studio?


Processed requests: 100%|██████████| 1/1 [00:02<00:00,  2.51s/it]
2025-06-13 21:24:41,934 [INFO] RagAgenticModel: Relevance check result: True
2025-06-13 21:24:41 [INFO] RagAgenticModel: Relevance check result: True
2025-06-13 21:24:41,938 [INFO] RagAgenticModel: Cache miss for query: What is the hardware requirement for AI Studio?
2025-06-13 21:24:41 [INFO] RagAgenticModel: Cache miss for query: What is the hardware requirement for AI Studio?
Processed requests: 100%|██████████| 1/1 [00:01<00:00,  1.01s/it]
2025-06-13 21:24:42,962 [INFO] RagAgenticModel: Rewritten query: The required hardware for AI Studio must have at least X GB of RAM and a multi-core processor. (Assuming specific technical details were missing in your note.)
2025-06-13 21:24:42 [INFO] RagAgenticModel: Rewritten query: The required hardware for AI Studio must have at least X GB of RAM and a multi-core processor. (Assuming specific technical details were missing in your note.)
2025-06-13 21:24:43,674 [INFO] RagAgenti

In [12]:
# 7. Print results
print(f"Query:")
print("{sample_query}\n")
print("\n==============\n")

print("Answer:")
print(result.get("answer", "<no answer>"), "\n")
print("\n==============\n")

print("Retrieved Chunks:")
for idx, chunk in enumerate(result.get("retrieved_chunks", []), start=1):
    print(f"  {idx}. {chunk[:100]}{'...' if len(chunk)>100 else ''}")

print("\n==============\n")
print("\nMessage History:")
for msg in result.get("messages", []):
    role = msg.get("role", "<unknown>")
    content = msg.get("content", "")
    print(f"  [{role}]: {content}")

Query:
{sample_query}



Answer:
AMD Ryzen™ 9 processor, Intel Core™ i5 12th generation processor, or higher 



Retrieved Chunks:
  1. Technical Requirements

Hardware:

Windows 10 or 11 or Linux Ubuntu 22.04 LTS on a workstation

GPU ...
  2. Software:

Windows 10 or 11 or Linux Ubuntu 22.04 LTS

Windows OS requires Windows Subsystem for Lin...
  3. title: 'System Requirements' sidebar_position: 1

System Requirements

Z by HP AI Studio currently r...
  4. Distro selection modal

:::tip

If git is not already installed on your machine, the app will guide ...
  5. title: 'Troubleshooting AI Studio' sidebar_position: 6

AI Studio Troubleshooting Guide

Find quick ...



Message History:
  [user]: What is the hardware requirement for AI Studio?
  [developer]: Relevance check result:
  [assistant]: Yes
  [developer]: Rewritten query:
  [assistant]: The required hardware for AI Studio must have at least X GB of RAM and a multi-core processor. (Assuming specific technical details were miss