# Financial Compliance Agent Evaluation Pipeline (SageMaker Pipelines + Managed MLflow)

This notebook productionizes the financial compliance agent evaluation flow as a
three-step Amazon SageMaker Pipeline:

1. **Data Preparation**  
   - Load ground-truth dataset from S3  
   - Basic validation and dataset profiling  
   - Log dataset metadata to **SageMaker managed MLflow**

2. **Agent Inference**  
   - Build the RAG + Web Search agent (Qwen on Amazon Bedrock)  
   - Run inference over all ground-truth prompts  
   - Extract normalized outputs (clean answers, retrieved contexts, tool usage)  
   - Persist an evaluation-ready dataset to S3  
   - Log inference-level metrics to MLflow

3. **Evaluation & Metrics**  
   - Compute semantic answer similarity (SAS)  
   - Compute tool-selection accuracy (rag vs. web_search)  
   - Run LLM-as-a-judge factuality checks  
   - Log all metrics to **SageMaker managed MLflow**  
   - Compare the LLM-as-judge score against `AccuracyThreshold` and log the result as a quality-gate flag
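The quality gate reduces to a small comparison. The following is a minimal sketch in plain Python. It assumes judge scores are 0/1 integers, with `None` for judge calls that failed. `accuracy_gate` is a hypothetical helper, not part of the pipeline code. Treating a run with no valid scores as a failure is a design choice of this sketch.

```python
from typing import Optional, Sequence


def accuracy_gate(scores: Sequence[Optional[int]], threshold: float) -> dict:
    """Average the valid judge scores and compare the result to the threshold.

    None entries (failed or unparseable judge calls) are excluded from the
    average; if nothing valid remains, the gate is reported as failed.
    """
    valid = [s for s in scores if s is not None]
    if not valid:
        return {"llm_judge_avg": 0.0, "num_valid": 0, "passed": False}
    avg = sum(valid) / len(valid)
    return {"llm_judge_avg": avg, "num_valid": len(valid), "passed": avg >= threshold}


print(accuracy_gate([1, 1, 0, None, 1], threshold=0.8))
# {'llm_judge_avg': 0.75, 'num_valid': 4, 'passed': False}
```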

## Dependencies

This cell removes any previously installed versions of the dependencies and installs the libraries this pipeline needs from `requirements.txt`. Run the cell below, then restart the kernel before proceeding.

In [None]:
%%capture
import sys, subprocess

# Cell 1 - Dependencies

# 1) Remove stale versions that could conflict with the pinned requirements
subprocess.call(
    [
        sys.executable,
        "-m",
        "pip",
        "uninstall",
        "-y",
        "mlflow",
        "haystack",
        "haystack-ai",
        "haystack-experimental",
        "chroma-haystack",
        "amazon-bedrock-haystack",
        "sentence-transformers",
        "protobuf",
    ]
)

# 2) Install from updated requirements.txt
subprocess.check_call(
    [sys.executable, "-m", "pip", "install", "-U", "-r", "requirements.txt"]
)

import haystack
print("haystack-ai version:", getattr(haystack, "__version__", "unknown"))
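Because the cell above runs under `%%capture`, its output (including the version `print`) is hidden. After restarting the kernel, you can use a check along these lines to confirm what was actually installed. It is a sketch that uses only `importlib.metadata` from the standard library. The package list is an assumption based on the libraries this notebook imports, so adjust it to match `requirements.txt`.

```python
from importlib import metadata
from typing import Dict, Iterable, Optional


def installed_versions(packages: Iterable[str]) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if it is missing."""
    versions: Dict[str, Optional[str]] = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


# Assumed package list; align it with requirements.txt
for pkg, ver in installed_versions(
    ["haystack-ai", "mlflow", "sagemaker", "chroma-haystack", "amazon-bedrock-haystack"]
).items():
    print(f"{pkg}: {ver or 'NOT INSTALLED'}")
```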

## Environment

This cell sets up the environment. Two values must be entered manually:

1/ SageMaker execution role ARN (only needed when not running inside SageMaker AI Studio, where it is resolved automatically)

2/ MLflow tracking server ARN
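Both values are easy to leave as placeholders by mistake. A rough format check is sketched below. The regular expressions are approximations: the MLflow pattern follows the `arn:aws:sagemaker:REGION:ACCOUNT_ID:mlflow-tracking-server/...` placeholder shown in the pipeline-parameter cell. `looks_like_mlflow_arn` and `looks_like_role_arn` are hypothetical helpers.

```python
import re

# Approximate patterns (assumptions); "aws[a-z-]*" also admits partitions like aws-cn
_MLFLOW_ARN_RE = re.compile(
    r"^arn:aws[a-z-]*:sagemaker:[a-z0-9-]+:\d{12}:mlflow-tracking-server/[A-Za-z0-9-]+$"
)
_ROLE_ARN_RE = re.compile(r"^arn:aws[a-z-]*:iam::\d{12}:role/.+$")


def looks_like_mlflow_arn(value: str) -> bool:
    """True if value resembles a SageMaker MLflow tracking server ARN."""
    return bool(_MLFLOW_ARN_RE.match(value.strip()))


def looks_like_role_arn(value: str) -> bool:
    """True if value resembles an IAM role ARN."""
    return bool(_ROLE_ARN_RE.match(value.strip()))


print(looks_like_mlflow_arn("YOUR MLFLOW TRACKING SERVER ARN HERE"))  # False
```

Run these checks on the role and tracking-server values from the setup cell before you upsert the pipeline. They catch unreplaced placeholder strings early.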

In [None]:
#Cell 2 - Setup
import os
import boto3
import sagemaker

from sagemaker.workflow.pipeline_context import PipelineSession
from sagemaker.workflow.pipeline import Pipeline
from sagemaker.workflow.execution_variables import ExecutionVariables
from sagemaker.workflow.parameters import (
    ParameterString,
    ParameterFloat,
    ParameterInteger,
)

from sagemaker.workflow.function_step import step  # @step decorator


# Basic AWS / SageMaker setup
session = boto3.Session()
region = session.region_name or os.getenv("AWS_REGION", "us-east-1")

# In SageMaker Studio this will resolve automatically
try:
    role = sagemaker.get_execution_role()
except Exception:
    # Fallback for local dev; replace with your role ARN if needed
    role = os.getenv("SAGEMAKER_ROLE_ARN", "YOUR SAGEMAKER EXECUTION ROLE HERE")

# Referenced as `mlflow_user_arn` by the pipeline-parameter and start cells below
mlflow_user_arn = "YOUR MLFLOW TRACKING SERVER ARN HERE"

pipeline_session = PipelineSession()
default_bucket = pipeline_session.default_bucket()
base_job_prefix = "financial-compliance-agent-eval"

print("Region:", region)
print("Role:", role)
print("Default bucket:", default_bucket)

## Pipeline parameter setup

Here we set up the pipeline parameters. Make sure that:

1/ The ground-truth file (`ground_truth.json` in the `data` folder of this repo) has been uploaded to the S3 path used for `DataInputS3Uri`.

2/ `ModelId` points to the Bedrock model that powers the agent under evaluation.

3/ `AccuracyThreshold` reflects the accuracy your use case requires.

4/ `RateLimitDelaySeconds` is long enough to avoid Bedrock throttling.
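Step 1 reads the file with `pandas.read_json(..., orient="records")`, so `ground_truth.json` should be a JSON array of objects. A pre-upload check might look like the following sketch. The rules it enforces are assumptions inferred from how the later steps use these columns:

- `prompt` and `output` must be non-empty.
- `tool_label` must be `rag` or `web_search`.

`validate_ground_truth` is a hypothetical helper. The sample record contains placeholders, not real data.

```python
import json
from typing import Any, List

# Assumed rules: step 2 reads "prompt", step 3 requires "output" and
# "tool_label", and tool accuracy compares "tool_label" with these two values.
ALLOWED_TOOL_LABELS = ("rag", "web_search")


def validate_ground_truth(records: Any) -> List[str]:
    """Return human-readable problems; an empty list means the records look usable."""
    if not isinstance(records, list) or not records:
        return ["file must be a non-empty JSON array of objects (orient='records')"]
    problems: List[str] = []
    for i, rec in enumerate(records):
        if not isinstance(rec, dict):
            problems.append(f"record {i}: not a JSON object")
            continue
        for key in ("prompt", "output"):
            if not str(rec.get(key) or "").strip():
                problems.append(f"record {i}: missing or empty '{key}'")
        if rec.get("tool_label") not in ALLOWED_TOOL_LABELS:
            problems.append(f"record {i}: tool_label must be one of {list(ALLOWED_TOOL_LABELS)}")
    return problems


# Placeholder record showing the expected shape; step 1 fills a missing
# "context" or "page" with nulls, so those two are optional here.
sample = [
    {
        "prompt": "<question>",
        "output": "<reference answer>",
        "tool_label": "rag",
        "context": "<supporting excerpt>",
        "page": 1,
    }
]
print(validate_ground_truth(sample))  # []
print(json.dumps(sample, indent=2))   # layout that ground_truth.json should follow
```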

In [None]:
#Cell 3 - Pipeline parameters
from sagemaker.workflow.parameters import ParameterString, ParameterFloat, ParameterInteger

# Pipeline parameters (editable when starting the pipeline)

# Input ground-truth dataset (ground_truth.json) in S3
DataInputS3Uri = ParameterString(
    name="DataInputS3Uri",
    default_value=f"s3://{default_bucket}/{base_job_prefix}/data/ground_truth.json",
)

# Base output prefix for artifacts & intermediate outputs
BaseOutputS3Uri = ParameterString(
    name="BaseOutputS3Uri",
    default_value=f"s3://{default_bucket}/{base_job_prefix}/artifacts",
)

# SageMaker managed MLflow tracking server ARN
MLflowTrackingServerArn = ParameterString(
    name="MLflowTrackingServerArn",
    #default_value="arn:aws:sagemaker:REGION:ACCOUNT_ID:mlflow-tracking-server/YOUR-ID",
    default_value=mlflow_user_arn
)

# MLflow experiment name
MLflowExperimentName = ParameterString(
    name="MLflowExperimentName",
    default_value="financial-compliance-agent-eval",
)

# Bedrock model ID (Qwen or any chat model you want)
ModelId = ParameterString(
    name="ModelId",
    default_value="qwen.qwen3-32b-v1:0",
)

# Prompt ID (if you later wire Bedrock Prompt Management; not strictly required here)
PromptId = ParameterString(
    name="PromptId",
    default_value="financial-compliance-base-prompt",
)

# Evaluation + throttling parameters
AccuracyThreshold = ParameterFloat(
    name="AccuracyThreshold",
    default_value=0.8,
)

RateLimitDelaySeconds = ParameterInteger(
    name="RateLimitDelaySeconds",
    default_value=10,
)

print("Pipeline parameters created.")

## Helper functions

This cell defines the S3 read/write helpers and the MLflow initialization helper that all three pipeline steps share.
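Each step builds its output location by appending a fixed key to `BaseOutputS3Uri`, after stripping any trailing `/`. The sketch below makes the resulting layout explicit. `join_s3_uri` and the bucket name are hypothetical. The two keys are the ones steps 1 and 2 use.

```python
from urllib.parse import urlparse


def join_s3_uri(base_uri: str, relative_key: str) -> str:
    """Append a relative key to an S3 prefix, tolerating slashes at the join."""
    parsed = urlparse(base_uri)
    if parsed.scheme != "s3" or not parsed.netloc:
        raise ValueError(f"Invalid S3 URI: {base_uri}")
    return f"{base_uri.rstrip('/')}/{relative_key.lstrip('/')}"


base = "s3://my-bucket/financial-compliance-agent-eval/artifacts/"  # hypothetical bucket
print(join_s3_uri(base, "data/ground_truth_normalized.json"))
# s3://my-bucket/financial-compliance-agent-eval/artifacts/data/ground_truth_normalized.json
print(join_s3_uri(base, "eval/agent_eval_dataset.json"))
# s3://my-bucket/financial-compliance-agent-eval/artifacts/eval/agent_eval_dataset.json
```

Because these keys are fixed, each execution overwrites the previous run's files under the same base prefix. One way to keep runs apart is to override `BaseOutputS3Uri` per execution.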

In [None]:
#Cell 4 - S3 + Mlflow utilities 
import io
import json
from urllib.parse import urlparse

import boto3
import mlflow
import pandas as pd


def _parse_s3_uri(s3_uri: str):
    """Split s3://bucket/key into (bucket, key)."""
    if not s3_uri.startswith("s3://"):
        raise ValueError(f"Invalid S3 URI: {s3_uri}")
    parsed = urlparse(s3_uri)
    bucket = parsed.netloc
    key = parsed.path.lstrip("/")
    return bucket, key


def read_json_records_from_s3(s3_uri: str) -> pd.DataFrame:
    """Read a JSON 'records' file from S3 into a DataFrame."""
    bucket, key = _parse_s3_uri(s3_uri)
    s3 = boto3.client("s3")
    obj = s3.get_object(Bucket=bucket, Key=key)
    body = obj["Body"].read()
    return pd.read_json(io.BytesIO(body), orient="records")


def write_json_records_to_s3(df: pd.DataFrame, s3_uri: str):
    """Write a DataFrame as JSON 'records' to S3."""
    bucket, key = _parse_s3_uri(s3_uri)
    s3 = boto3.client("s3")
    body = df.to_json(orient="records").encode("utf-8")
    s3.put_object(Bucket=bucket, Key=key, Body=body)


def init_mlflow(tracking_server_arn: str, experiment_name: str):
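    """
    Connect to SageMaker managed MLflow and select/create the experiment.

    Note: in the console you'll copy the MLflow tracking server ARN and pass
    it via the `MLflowTrackingServerArn` pipeline parameter. Using the ARN as
    the tracking URI requires the `sagemaker-mlflow` plugin to be installed
    (for example, via requirements.txt).
    """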
    mlflow.set_tracking_uri(tracking_server_arn)
    mlflow.set_experiment(experiment_name)

## Step 1 - Data prep

This cell defines the first pipeline step. It specifies:

- the compute the step runs on
- its `requirements.txt` dependencies
- its logic
- the metrics it logs to MLflow
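Stripped of the S3 and MLflow calls, the profiling and normalization in this step amount to the plain-Python sketch below. `profile_and_normalize` is a hypothetical helper that works on a list of records rather than a DataFrame. Its metric names match the ones the step logs.

```python
from collections import Counter
from typing import Any, Dict, List, Tuple

EXPECTED_COLS = ["prompt", "output", "tool_label", "context", "page"]


def profile_and_normalize(
    records: List[Dict[str, Any]],
) -> Tuple[Dict[str, int], List[Dict[str, Any]]]:
    """Compute the step-1 counts and fill missing expected columns with None."""
    labels = Counter(r.get("tool_label") for r in records)
    metrics = {
        "num_rows": len(records),
        "num_rag_rows": labels.get("rag", 0),
        "num_web_search_rows": labels.get("web_search", 0),
    }
    # Missing expected columns become None (null in JSON), mirroring the
    # df[col] = np.nan fill in the step; extra keys are kept unchanged.
    normalized = [{**{c: None for c in EXPECTED_COLS}, **r} for r in records]
    return metrics, normalized


metrics, rows = profile_and_normalize(
    [{"prompt": "q1", "tool_label": "rag"}, {"prompt": "q2", "tool_label": "web_search"}]
)
print(metrics)  # {'num_rows': 2, 'num_rag_rows': 1, 'num_web_search_rows': 1}
```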

In [None]:
#Cell 5 - Step 1 Data prep
from typing import Optional

@step(
    name="data-preparation",
    instance_type="ml.m5.xlarge",
    keep_alive_period_in_seconds=3600,
    dependencies="requirements.txt",
)
def data_preparation_step(
    data_input_s3_uri: str,
    base_output_s3_uri: str,
    tracking_server_arn: str,
    experiment_name: str,
    pipeline_run_id: str,
) -> str:
    """
    Step 1: Load ground-truth dataset, validate, log to MLflow, and
    write a normalized dataset to S3 for downstream steps.

    Returns:
        normalized_dataset_s3_uri (str): S3 URI of normalized dataset.
    """
    import numpy as np

    init_mlflow(tracking_server_arn, experiment_name)

    run_name = f"data-prep-{pipeline_run_id}"
    with mlflow.start_run(run_name=run_name):
        df = read_json_records_from_s3(data_input_s3_uri)

        # Basic sanity checks / stats
        n_rows = len(df)
        n_rag = int((df.get("tool_label") == "rag").sum()) if "tool_label" in df else 0
        n_web = int((df.get("tool_label") == "web_search").sum()) if "tool_label" in df else 0

        mlflow.log_param("data_input_s3_uri", data_input_s3_uri)
        mlflow.log_metric("num_rows", n_rows)
        mlflow.log_metric("num_rag_rows", n_rag)
        mlflow.log_metric("num_web_search_rows", n_web)

        # Normalize/ensure expected columns exist
        expected_cols = ["prompt", "output", "tool_label", "context", "page"]
        for col in expected_cols:
            if col not in df.columns:
                df[col] = np.nan

        # Output location for normalized dataset
        normalized_dataset_s3_uri = (
            f"{base_output_s3_uri.rstrip('/')}/data/ground_truth_normalized.json"
        )

        write_json_records_to_s3(df, normalized_dataset_s3_uri)
        mlflow.log_param("normalized_dataset_s3_uri", normalized_dataset_s3_uri)

    return normalized_dataset_s3_uri

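## Agent + RAG helpers

This cell defines the code that the inference step uses:

- the Chroma retriever
- the `web_search` and `rag_tool` tools
- the agent builder and its system prompt
- the prompt formatter
- helpers that extract retrieved contexts and tool usage from agent responses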
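The tool-selection rules in `_SYSTEM_PROMPT` are regular enough to express as code. The sketch below is a hypothetical year/keyword heuristic (`expected_tool`) that you can use to sanity-check `tool_label` values in the ground truth. It is not how the agent picks tools. It also makes two assumptions the prompt does not state: it uses the latest year mentioned, and it defaults to `rag` when the question has no time cue.

```python
import re

_YEAR_RE = re.compile(r"\b(20\d{2})\b")
_CURRENT_RE = re.compile(r"\b(current|currently|latest|today|now)\b", re.IGNORECASE)


def expected_tool(question: str) -> str:
    """Heuristic mirror of the prompt's DECISION LOGIC (not the agent itself)."""
    years = [int(y) for y in _YEAR_RE.findall(question)]
    if years:
        # 2023 or earlier -> rag_tool; 2024 or later -> web_search
        return "web_search" if max(years) >= 2024 else "rag"
    if _CURRENT_RE.search(question):
        return "web_search"  # no period given, but asks for current information
    return "rag"  # assumption: the prompt does not cover this case


print(expected_tool("What are Amazon's 2024 earnings?"))
```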

In [None]:
# Cell 6 - Agent + RAG helpers 

import json
import ast
from typing import List, Any, Dict, Optional

from haystack.core.pipeline import Pipeline as HaystackPipeline
from haystack.components.converters import PyPDFToDocument
from haystack.components.preprocessors import DocumentCleaner, DocumentSplitter
from haystack.components.writers import DocumentWriter

from haystack_integrations.document_stores.chroma import ChromaDocumentStore
from haystack_integrations.components.retrievers.chroma import ChromaQueryTextRetriever

from haystack.components.agents import Agent
from haystack.dataclasses import ChatMessage
from haystack.components.builders.chat_prompt_builder import ChatPromptBuilder
from haystack_integrations.components.generators.amazon_bedrock import (
    AmazonBedrockChatGenerator,
)
from haystack.tools import tool

from ddgs import DDGS
from ddgs.exceptions import DDGSException, RatelimitException

from haystack import Document

# Global retriever used by rag_tool (call init_chroma_retriever first)
_retriever: Optional[ChromaQueryTextRetriever] = None

def init_chroma_retriever(
    pdf_paths: Optional[List[str]] = None,
    persist_path: str = "../data/10k-vec-db",
    recreate: bool = False,
    split_length: int = 150,
) -> ChromaQueryTextRetriever:
    """
    Initialize (and optionally rebuild) a ChromaDocumentStore + retriever.

    Args:
        pdf_paths: List of local PDF paths used to (re)build the store.
        persist_path: Folder where Chroma DB will be persisted.
        recreate: If True, rebuild the store from the provided PDFs.
        split_length: Word-level split length for DocumentSplitter.

    Returns:
        ChromaQueryTextRetriever instance.
    """
    global _retriever

    document_store = ChromaDocumentStore(persist_path=persist_path)

    if recreate and pdf_paths:
        pipe = HaystackPipeline()
        pipe.add_component("converter", PyPDFToDocument())
        pipe.add_component("cleaner", DocumentCleaner())
        pipe.add_component(
            "splitter",
            DocumentSplitter(split_by="word", split_length=split_length),
        )
        pipe.add_component("writer", DocumentWriter(document_store=document_store))

        pipe.connect("converter", "cleaner")
        pipe.connect("cleaner", "splitter")
        pipe.connect("splitter", "writer")

        pipe.run({"converter": {"sources": pdf_paths}})

    _retriever = ChromaQueryTextRetriever(document_store=document_store)
    return _retriever

# Tools: web_search + rag_tool (uses global _retriever)
@tool
def web_search(keywords: str, region: str = "us-en", max_results: int = 3) -> Any:
    """Search the web for updated information.

    Args:
        keywords: The search query keywords.
        region: The search region: wt-wt, us-en, uk-en, ru-ru, etc.
        max_results: The maximum number of results to return.

    Returns:
        List of dictionaries with search results, or an error string.
    """
    try:
        results = DDGS().text(keywords, region=region, max_results=max_results)
        return results if results else "No results found."
    except RatelimitException:
        return "Rate limit reached. Please try again later."
    except DDGSException as e:
        return f"Search error: {e}"
    except Exception as e:
        return f"Search error: {str(e)}"


@tool
def rag_tool(query: str) -> List[str]:
    """Use this tool to get grounded information for answering queries
    about Amazon (10-K data through 2023).

    Returns a list of text chunks.
    """
    if _retriever is None:
        raise RuntimeError(
            "Chroma retriever is not initialized. "
            "Call init_chroma_retriever(...) before using rag_tool."
        )

    docs = _retriever.run(query=query)["documents"]
    return [doc.content for doc in docs]

# Agent builder (Qwen on Bedrock) with your system prompt
_SYSTEM_PROMPT = """
You are a professional Amazon research agent with access to two tools:
1. RAG context retrieval tool (`rag_tool`): Contains Amazon 10-K filings data through 2023.
2. Web search tool (`web_search`): For current information beyond 2023.

TOOL SELECTION RULES:
- Use ONLY `rag_tool` for questions about Amazon data from 2023 or earlier.
- Use ONLY `web_search` for questions about Amazon data from 2024 or later.
- NEVER use both tools for a single query.
- You must call the single tool you selected based on the criteria ONCE AND ONLY ONCE.

EXAMPLES FOR RAG TOOL (2023 and earlier data):
- "What was Amazon's revenue in 2022?" → rag_tool
- "Who was Amazon's CFO in 2023?" → rag_tool
- "What were Amazon's operating expenses in 2021?" → rag_tool
- "Who served on Amazon's board of directors in 2023?" → rag_tool

EXAMPLES FOR WEB SEARCH TOOL (2024 and later data):
- "What is Amazon's current stock price?" → web_search
- "What are Amazon's 2024 earnings?" → web_search
- "Who is Amazon's current CEO?" → web_search
- "What new products did Amazon launch in 2024?" → web_search

DECISION LOGIC:
- If the question asks about historical data (2023 or earlier) → rag_tool.
- If the question asks about current/recent data (2024 or later) → web_search.
- If the question doesn't specify a time period but asks for "current" information â†’ web_search.

Give concise, factual answers without preamble. Always use exactly one tool per response.
""".strip()

def build_financial_agent(model_id: str) -> Agent:
    """
    Build and warm up the Haystack Agent using a Bedrock-hosted model.

    Args:
        model_id: Bedrock model ID (e.g., "qwen.qwen3-32b-v1:0").

    Returns:
        Warmed-up haystack Agent instance.
    """
    chat_generator = AmazonBedrockChatGenerator(
        model=model_id,
        generation_kwargs={"temperature": 0.1},
    )

    agent = Agent(
        chat_generator=chat_generator,
        tools=[web_search, rag_tool],
        system_prompt=_SYSTEM_PROMPT,
        exit_conditions=["text"],
        max_agent_steps=2,  # one tool call + one final answer
        raise_on_tool_invocation_failure=False,
    )

    agent.warm_up()
    return agent

# Prompt builder (ChatPromptBuilder)
def format_prompt(query: str) -> List[ChatMessage]:
    """
    Build a prompt for the Agent enforcing ONE tool call and time-based
    tool-selection rules via the user message.
    """
    template = [
        ChatMessage.from_user(
            "Using only ONE of the available tools, accurately answer the "
            "following question:\n\n{{question}}\n\n"
            "CRITICAL INSTRUCTIONS:\n"
            "- Select EXACTLY ONE tool based on the time period criteria in your system prompt\n"
            "- Make ONLY ONE tool call - do not break down or modify the query\n"
            "- If the question is about 2023 or earlier Amazon data → use rag_tool\n"
            "- If the question is about 2024+ or current Amazon data → use web_search\n"
            "- Answer directly after your single tool call"
        )
    ]
    builder = ChatPromptBuilder(template=template, required_variables=["question"])
    result = builder.run(question=query)
    return result["prompt"]  # List[ChatMessage]

# Tool-context extraction (get_clean_docs) - adapted to rag_tool
def _parse_result_payload(payload: Any) -> Any:
    """Return a Python object from payload (str|list|dict)."""
    if isinstance(payload, (list, dict)):
        return payload
    if isinstance(payload, str):
        s = payload.strip()
        # try JSON first
        try:
            return json.loads(s)
        except Exception:
            pass
        # then ast as a fallback
        try:
            return ast.literal_eval(s)
        except Exception:
            # last resort: treat the raw string as a single doc
            return [{"content": s}]
    # Unknown type -> wrap as string
    return [{"content": str(payload)}]

def _coerce_to_documents(obj: Any) -> List[Document]:
    """Normalize various shapes to List[haystack.Document]."""
    # If it's a dict, look for common keys
    if isinstance(obj, dict):
        for key in ("documents", "docs", "results", "retrieved_documents", "retrieved_docs"):
            if key in obj and isinstance(obj[key], list):
                return _coerce_to_documents(obj[key])
        # maybe it's a single doc-like dict
        obj = [obj]

    docs_out: List[Document] = []
    if isinstance(obj, list):
        for item in obj:
            if isinstance(item, Document):
                docs_out.append(item)
                continue
            if isinstance(item, dict):
                content = (
                    item.get("content")
                    or item.get("text")
                    or item.get("page_content")
                    or ""
                )
                meta = item.get("meta") or item.get("metadata") or {}
                if content is None:
                    content = ""
                docs_out.append(Document(content=content, meta=meta))
                continue
            # anything else -> coerce to string content
            docs_out.append(Document(content=str(item), meta={}))
    return docs_out

def get_clean_docs(answer: Dict[str, Any], target_tool_name: str = "rag_tool") -> List[Document]:
    """
    Walks tool messages in `answer['messages']` and extracts documents
    robustly for the given tool name (defaults to 'rag_tool').
    """
    try:
        candidates = []
        for msg in answer.get("messages", []):
            role = getattr(msg, "_role", None)
            role_val = getattr(role, "value", None) or getattr(msg, "role", None)
            if str(role_val).lower() != "tool":
                continue

            content = getattr(msg, "_content", None) or getattr(msg, "content", None) or []
            if isinstance(content, list):
                for part in content:
                    result = getattr(part, "result", None)
                    origin = getattr(part, "origin", None)
                    tool_name = getattr(origin, "tool_name", None)
                    if result is None:
                        continue
                    # Keep only results from the requested tool (all tools if no name is given)
                    if not target_tool_name or tool_name == target_tool_name:
                        candidates.append(result)

        if not candidates:
            return []

        payload = candidates[0]
        parsed = _parse_result_payload(payload)
        return _coerce_to_documents(parsed)

    except Exception as e:
        print(f"Error parsing documents: {e}")
        return []

# Tool usage extraction - normalized to 'rag' vs 'web_search'
def extract_combined_tools(raw_answer: Dict[str, Any]) -> str:
    """Extract all tools used in one interaction and join with |."""
    tools_used: List[str] = []

    if not raw_answer or "messages" not in raw_answer:
        return "none"

    messages = raw_answer.get("messages", [])

    for message in messages:
        content = getattr(message, "_content", []) or []

        for item in content:
            if hasattr(item, "tool_name"):
                tool_name = item.tool_name

                # Normalize tool names
                if "context_retrieval" in tool_name or "rag_tool" in tool_name:
                    tool_name = "rag"
                elif "web_search" in tool_name:
                    tool_name = "web_search"

                if tool_name not in tools_used:
                    tools_used.append(tool_name)

    return " | ".join(tools_used) if tools_used else "none"


print("In-notebook agent + RAG helpers initialized.")

## Step 2 - Agent Inference

In this step we run agent inference against the ground-truth dataset, as in the first notebook. The step writes the results to a DataFrame and logs metrics to MLflow.
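The step logs one count per distinct `tool_used` value. MLflow restricts metric names to alphanumerics, underscores, dashes, periods, spaces, and slashes (recent versions also allow colons). A combined value such as `rag | web_search` from `extract_combined_tools` is therefore not a valid name as-is. The following hypothetical sanitizer is sketched under that rule.

```python
import re

# Anything outside the conservative MLflow-safe set becomes "_"
_DISALLOWED = re.compile(r"[^A-Za-z0-9_\-./ ]")


def tool_count_metric_key(tool_used: str) -> str:
    """Turn a tool_used value into a metric name MLflow will accept."""
    key = tool_used.replace(" | ", "_and_")  # readable form for combined usages
    return "tool_used_" + _DISALLOWED.sub("_", key)


print(tool_count_metric_key("rag | web_search"))  # tool_used_rag_and_web_search
```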

In [None]:
# Cell 7 - Inference
@step(
    name="agent-inference",
    instance_type="ml.m5.xlarge",   # or whatever instance type you want
    keep_alive_period_in_seconds=3600,
    dependencies="requirements.txt",
)
def agent_inference_step(
    dataset_s3_uri: str,
    model_id: str,
    prompt_id: str,  # logged to MLflow for traceability; not used for inference
    base_output_s3_uri: str,
    tracking_server_arn: str,
    experiment_name: str,
    pipeline_run_id: str,
    rate_limit_delay: int = 10,
) -> str:
    """
    Step 2: Run agent inference on all prompts, build an evaluation-ready
    dataset, and write it to S3.

    Returns:
        eval_dataset_s3_uri (str): S3 URI of results with columns:
          [prompt, output, tool_label, context, page,
           clean_answers, extracted_contexts, tool_used, raw_answers]
    """
    import time
    import numpy as np
    import pandas as pd
    from tqdm import tqdm

    # MLflow setup
    init_mlflow(tracking_server_arn, experiment_name)

    run_name = f"inference-{pipeline_run_id}"
    with mlflow.start_run(run_name=run_name):
        df = read_json_records_from_s3(dataset_s3_uri)

        mlflow.log_param("dataset_s3_uri", dataset_s3_uri)
        mlflow.log_param("model_id", model_id)
        mlflow.log_param("prompt_id", prompt_id)
        mlflow.log_param("rate_limit_delay_seconds", rate_limit_delay)

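        # Initialize Chroma RAG index (no rebuild by default). This assumes the
        # persisted store at init_chroma_retriever's default persist_path exists
        # in the step's job environment; otherwise Chroma starts from an empty store.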
        init_chroma_retriever(
            pdf_paths=None,      # or ["../data/AMZN-2023-10k.pdf"] if you want to rebuild
            recreate=False,
        )

        # Build Bedrock-backed financial agent (Qwen + RAG + web_search)
        agent = build_financial_agent(model_id=model_id)

        # Run inference over all prompts
        prompts = df["prompt"].tolist()
        answers = []

        for p in tqdm(prompts, desc="Running agent inference"):
            try:
                prompt_msg_list = format_prompt(p)  # List[ChatMessage]
                res = agent.run(prompt_msg_list)
            except Exception as e:
                # Keep alignment even if something fails
                res = {"messages": [], "error": str(e)}
            answers.append(res)
            time.sleep(rate_limit_delay)

        # Attach raw answers
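        # Stringify: ChatMessage objects are not reliably serializable by DataFrame.to_json
        df["raw_answers"] = [str(a) for a in answers]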

        # Clean final answers (mirror your original notebook logic)
        def get_final_text(answer):
            try:
                msgs = answer["messages"]
                last = msgs[-1]
                return getattr(last, "text", None) or getattr(last, "content", None) or "I don't know"
            except Exception:
                return "I don't know"

        df["clean_answers"] = [get_final_text(a) for a in answers]

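        # Extract retrieved contexts for RAG, stored as plain text chunks so the
        # column serializes cleanly to JSON (Document objects may not)
        df["extracted_contexts"] = [
            [doc.content for doc in get_clean_docs(a)] for a in answers
        ]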

        # Extract tool usage
        df["tool_used"] = [extract_combined_tools(a) for a in answers]

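        # Basic counts for logging. MLflow metric names cannot contain "|",
        # so combined usages such as "rag | web_search" are rewritten first.
        tool_counts = df["tool_used"].value_counts().to_dict()
        for k, v in tool_counts.items():
            metric_key = str(k).replace(" | ", "_and_")
            mlflow.log_metric(f"tool_used_{metric_key}", float(v))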

        # Write evaluation-ready dataset to S3
        eval_dataset_s3_uri = (
            f"{base_output_s3_uri.rstrip('/')}/eval/agent_eval_dataset.json"
        )
        write_json_records_to_s3(df, eval_dataset_s3_uri)
        mlflow.log_param("eval_dataset_s3_uri", eval_dataset_s3_uri)

    return eval_dataset_s3_uri

## Step 3 - Evaluations

This cell defines the third pipeline step. It evaluates semantic answer similarity (SAS), tool-selection accuracy, and LLM-as-a-judge factual accuracy, then logs the results to MLflow.
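The step below logs tool-selection accuracy only. A confusion matrix additionally shows which way routing errors go (for example, `rag` questions answered with `web_search`). A plain-Python sketch over the `tool_label` and `tool_used` columns follows. `tool_confusion` is a hypothetical helper.

```python
from collections import Counter
from typing import Dict, List, Tuple


def tool_confusion(
    labels: List[str], used: List[str]
) -> Tuple[Dict[str, Dict[str, int]], float]:
    """Build a label -> predicted-tool count table and the matching accuracy.

    Rows are ground-truth labels; columns cover every observed value, so
    combined usages such as "rag | web_search" or "none" get their own column.
    """
    if len(labels) != len(used):
        raise ValueError("labels and used must have the same length")
    pairs = Counter(zip(labels, used))
    rows = sorted(set(labels))
    cols = sorted(set(labels) | set(used))
    matrix = {r: {c: pairs[(r, c)] for c in cols} for r in rows}
    correct = sum(pairs[(r, r)] for r in rows)  # exact-match hits, as in the step
    accuracy = correct / len(labels) if labels else 0.0
    return matrix, accuracy


matrix, acc = tool_confusion(
    ["rag", "rag", "web_search", "web_search"],
    ["rag", "web_search", "web_search", "rag | web_search"],
)
print(acc)  # 0.5
```

With pandas, `pd.crosstab(df["tool_label"], df["tool_used"])` produces the same table. You could log a dict like `matrix` as an artifact with `mlflow.log_dict(matrix, "tool_confusion.json")`.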

In [None]:
# Cell 8 - Evaluation step (SAS, tool selection, LLM-as-judge)

from typing import List, Any, Dict


@step(
    name="agent-evaluation",
    instance_type="ml.m5.large",
    keep_alive_period_in_seconds=3600,
    dependencies="requirements.txt",
)
def agent_evaluation_step(
    eval_dataset_s3_uri: str,
    tracking_server_arn: str,
    experiment_name: str,
    pipeline_run_id: str,
    accuracy_threshold: float,
) -> str:
    """
    Step 3: Evaluate agent performance (SAS, tool selection, LLM-as-judge)
    and log all metrics to MLflow.

    Inputs:
        eval_dataset_s3_uri : S3 URI produced by the inference step
        tracking_server_arn : SageMaker MLflow tracking server ARN
        experiment_name     : MLflow experiment name
        pipeline_run_id     : SageMaker Pipeline run id
        accuracy_threshold  : Min acceptable accuracy (via LLM-as-judge)

    Returns:
        eval_dataset_s3_uri (str) - same dataset location (for chaining / inspection)
    """
    import os
    import json
    import numpy as np
    import pandas as pd
    import mlflow

    # 🔑 Disable TensorFlow so transformers / sentence-transformers
    #     don't try to import TF (which conflicts with Studio protobuf)
    os.environ["TRANSFORMERS_NO_TF"] = "1"
    os.environ["USE_TF"] = "0"

    # Now it's safe to import Haystack evaluators (which pull in sentence-transformers)
    from haystack.components.evaluators import SASEvaluator, LLMEvaluator
    from haystack_integrations.components.generators.amazon_bedrock import (
        AmazonBedrockChatGenerator,
    )

    # ---------- MLflow setup ----------
    init_mlflow(tracking_server_arn, experiment_name)
    run_name = f"evaluation-{pipeline_run_id}"

    with mlflow.start_run(run_name=run_name):
        mlflow.log_param("eval_dataset_s3_uri", eval_dataset_s3_uri)
        mlflow.log_param("accuracy_threshold", accuracy_threshold)

        # ---------- Load evaluation dataset ----------
        df = read_json_records_from_s3(eval_dataset_s3_uri)

        # Ensure required columns exist
        required_cols = ["prompt", "output", "clean_answers", "tool_label", "tool_used"]
        for c in required_cols:
            if c not in df.columns:
                raise ValueError(f"Expected column '{c}' not found in eval dataset")

        # 1) Semantic Answer Similarity (SAS)
        gt_answers = df["output"].tolist()
        pred_answers = df["clean_answers"].tolist()

        sas_evaluator = SASEvaluator()
        sas_evaluator.warm_up()

        sas_result = sas_evaluator.run(
            ground_truth_answers=gt_answers,
            predicted_answers=pred_answers,
        )

        df["sas_score"] = sas_result["individual_scores"]
        avg_sas = float(sas_result["score"])

        mlflow.log_metric("sas_avg", avg_sas)
        mlflow.log_metric("sas_min", float(np.min(df["sas_score"])))
        mlflow.log_metric("sas_max", float(np.max(df["sas_score"])))
        mlflow.log_metric("sas_std", float(np.std(df["sas_score"])))

        # 2) Tool selection accuracy (rag vs web_search)
        tool_correct = (df["tool_label"].astype(str) == df["tool_used"].astype(str))
        tool_accuracy = float(tool_correct.mean())
        mlflow.log_metric("tool_selection_accuracy", tool_accuracy)

        # 3) LLM-as-a-Judge (factual accuracy)
        # Configure Bedrock-backed judge (Llama 3 or similar model)
        judge_generator = AmazonBedrockChatGenerator(
            model="us.meta.llama3-3-70b-instruct-v1:0",
            generation_kwargs={
                "temperature": 0.0,
            },
        )

        judge_instructions = """
        Evaluate whether the predicted answer to the question is factually correct
        given the ground-truth answer.
        Score:
        - 1 if the predicted answer is factually consistent with the ground truth.
        - 0 if it contradicts, omits key facts, or is clearly incorrect.
        Return ONLY a JSON object: {"score": 0 or 1}
        """.strip()

        # Declare each field the judge needs as its own input so the question,
        # reference answer and prediction are rendered explicitly in the prompt.
        llm_judge = LLMEvaluator(
            instructions=judge_instructions,
            chat_generator=judge_generator,
            inputs=[
                ("questions", List[str]),
                ("ground_truth_answers", List[str]),
                ("predicted_answers", List[str]),
            ],
            outputs=["score"],
            examples=[],          # 👈 required positional arg in haystack-ai==2.13.0
            raise_on_failure=False,
            progress_bar=False,
        )

        # One batched call over all rows; with raise_on_failure=False, rows the
        # judge could not score come back as None instead of raising.
        judge_scores: List[Any] = [None] * len(df)
        try:
            judge_result = llm_judge.run(
                questions=df["prompt"].astype(str).tolist(),
                ground_truth_answers=df["output"].astype(str).tolist(),
                predicted_answers=df["clean_answers"].astype(str).tolist(),
            )
            judge_scores = [
                r.get("score") if isinstance(r, dict) else None
                for r in judge_result["results"]
            ]
        except Exception as e:
            print(f"LLM-as-judge run failed: {e}")

        df["llm_judge_score"] = judge_scores
        valid_scores = [s for s in judge_scores if s is not None]

        if valid_scores:
            avg_judge = float(np.mean(valid_scores))
            ones = sum(1 for s in valid_scores if s == 1)
            zeros = sum(1 for s in valid_scores if s == 0)

            mlflow.log_metric("llm_judge_avg", avg_judge)
            mlflow.log_metric("llm_judge_supported_count", float(ones))
            mlflow.log_metric("llm_judge_unsupported_count", float(zeros))
        else:
            avg_judge = 0.0
            mlflow.log_metric("llm_judge_avg", 0.0)

        # 4) Accuracy gate based on LLM-as-judge. A run with no valid judge
        # scores cannot demonstrate accuracy, so it is flagged as failing too.
        gate_failed = (not valid_scores) or avg_judge < accuracy_threshold
        # Just log a flag instead of failing the pipeline
        mlflow.log_param("accuracy_gate_failed", gate_failed)

        # For now we just return the eval_dataset_s3_uri for chaining.
        return eval_dataset_s3_uri

## Wiring cell

In this cell we wire the three steps together, define each step's inputs, and create the SageMaker pipeline object.
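Passing only `agent_eval` to `Pipeline(steps=...)` works because each step call records the upstream return values it consumes, so the full graph is reachable from the leaf. The toy model below illustrates that idea only. `ToyStep` and `collect_steps` are hypothetical and are not SageMaker's implementation.

```python
from dataclasses import dataclass, field
from typing import List, Set


@dataclass
class ToyStep:
    """Stand-in for a step call; `inputs` holds the upstream steps it consumes."""
    name: str
    inputs: List["ToyStep"] = field(default_factory=list)


def collect_steps(leaf: ToyStep) -> List[str]:
    """Walk inputs back from the leaf and return step names in dependency order."""
    order: List[str] = []
    seen: Set[str] = set()

    def visit(step: ToyStep) -> None:
        if step.name in seen:
            return
        seen.add(step.name)
        for upstream in step.inputs:
            visit(upstream)  # dependencies are emitted before the step itself
        order.append(step.name)

    visit(leaf)
    return order


# Names prefixed with toy_ to avoid clobbering the real step variables above
toy_prep = ToyStep("data-preparation")
toy_infer = ToyStep("agent-inference", [toy_prep])
toy_eval = ToyStep("agent-evaluation", [toy_infer])
print(collect_steps(toy_eval))  # ['data-preparation', 'agent-inference', 'agent-evaluation']
```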

In [None]:
# Cell 9 - Wiring 
from sagemaker.workflow.pipeline import Pipeline

pipeline_name = "financial-compliance-agent-eval-pipeline"

# Wire steps together using their DelayedReturn outputs

# Step 1 - Data Prep
data_prep = data_preparation_step(
    data_input_s3_uri=DataInputS3Uri,
    base_output_s3_uri=BaseOutputS3Uri,
    tracking_server_arn=MLflowTrackingServerArn,
    experiment_name=MLflowExperimentName,
    pipeline_run_id=ExecutionVariables.PIPELINE_EXECUTION_ID,
)

# Step 2 - Inference (takes the *return value* of step 1 as input)
agent_infer = agent_inference_step(
    dataset_s3_uri=data_prep,    # <-- pass DelayedReturn directly, no .properties
    model_id=ModelId,
    prompt_id=PromptId,
    base_output_s3_uri=BaseOutputS3Uri,
    tracking_server_arn=MLflowTrackingServerArn,
    experiment_name=MLflowExperimentName,
    pipeline_run_id=ExecutionVariables.PIPELINE_EXECUTION_ID,
    rate_limit_delay=RateLimitDelaySeconds,
)

# Step 3 - Evaluation (takes the *return value* of step 2 as input)
agent_eval = agent_evaluation_step(
    eval_dataset_s3_uri=agent_infer,   # <-- pass DelayedReturn directly
    tracking_server_arn=MLflowTrackingServerArn,
    experiment_name=MLflowExperimentName,
    pipeline_run_id=ExecutionVariables.PIPELINE_EXECUTION_ID,
    accuracy_threshold=AccuracyThreshold,
)

# Create SageMaker Pipeline object
#   - only the leaf node (agent_eval) is strictly required
#   - SageMaker infers upstream dependencies from the DelayedReturn graph
pipeline = Pipeline(
    name=pipeline_name,
    parameters=[
        DataInputS3Uri,
        BaseOutputS3Uri,
        MLflowTrackingServerArn,
        MLflowExperimentName,
        ModelId,
        PromptId,
        AccuracyThreshold,
        RateLimitDelaySeconds,
    ],
    steps=[agent_eval],   # leaf node; data_prep & agent_infer inferred automatically
    sagemaker_session=pipeline_session,
)

print("Pipeline object created:", pipeline_name)

## Upsert cell

In this cell we upsert and start the pipeline:

- `pipeline.upsert()` creates or updates the pipeline definition in SageMaker.
- `pipeline.start()` launches an execution with the pipeline-level parameter values given below. Any parameter you omit falls back to its default.

Once the execution starts, you can follow it and its DAG in the SageMaker Studio UI.

In [None]:
# Cell 10 - Upsert 

# Register / update the pipeline definition in SageMaker
pipeline_upsert_response = pipeline.upsert(role_arn=role)
print("Upsert response:", pipeline_upsert_response)

# Example: start an execution (you can also start from the SageMaker console)
execution = pipeline.start(
    parameters={
        "DataInputS3Uri": f"s3://{default_bucket}/{base_job_prefix}/data/ground_truth.json",
        "BaseOutputS3Uri": f"s3://{default_bucket}/{base_job_prefix}/artifacts",
        "MLflowTrackingServerArn": mlflow_user_arn,
        "MLflowExperimentName": "financial-compliance-agent-eval",
        "ModelId": "qwen.qwen3-32b-v1:0",
        "PromptId": "financial-compliance-base-prompt",
        "AccuracyThreshold": 0.8,
        "RateLimitDelaySeconds": 10,
    }
)

print("Started execution:", execution.arn)


## Viewing the SageMaker pipeline & metrics in MLflow

While the pipeline is running, you can check its progress by opening your SageMaker AI Studio domain and selecting "Pipelines":
![studio](../images/1-studio-ui.png "studio")

Then select your pipeline name:
![pipeline](../images/2-pipeline-ui.png "pipeline")

Every pipeline can have multiple runs, known as "executions". Select the one you want to view, in this case the latest one. You can also view all of its associated metadata here.
![execution](../images/3-execution-ui.png "execution")

Here you can watch the pipeline steps in action, see whether each step has succeeded or failed, and view its logs either in this pane or in CloudWatch.
![pipeline DAG](../images/4-pipeline-dag.png "pipeline DAG")

Back on the main SageMaker AI Studio domain page, open MLflow by clicking the MLflow app at the top left and locating your MLflow tracking server. Click the three dots, then click "Open MLflow":
![open MLflow](../images/5-mlflow-open.png "open MLflow")

Under the matching pipeline execution identifiers, you can see all of the metrics captured in MLflow:
![MLflow metrics](../images/6-mlflow-metrics.png "MLflow metrics")

## Notebook end

This brings us to the end of the notebook. We took the original agent evaluation workflow and converted it to use SageMaker Pipelines for automation and MLflow for metric tracking.