# Embedding Neural Network Router Usage

This notebook demonstrates how to use the Embedding Neural Network for intelligent multimodal model routing. The router uses CLIP embeddings and a trained neural network to select the optimal model based on:
- Query content and complexity
- Multi-turn conversation context
- Cost-based optimization with confidence thresholds


> **Note**: The GitHub repository includes a pre-trained neural network whose weights are stored in `llm-router/src/nat_sfc_router/training/router_artifacts`. The notebook `2_Embedding_NN_Training.ipynb` retrains the network and overwrites those weights. You can run this usage notebook on its own to use the existing weights, or run the training notebook first and then use this notebook with your own trained network.

---

## Prerequisites


Clone the repository, or skip this cell if you are using the Brev launchable, which already includes the source code.

In [None]:
!git clone https://github.com/NVIDIA-AI-Blueprints/llm-router.git
!cd llm-router && git checkout experimental # TODO: Remove once ces-dev is merged to main

In [None]:
import os
from pathlib import Path

# Get current working directory
cwd = Path.cwd()

# Check if we're already in llm-router directory
if cwd.name == 'llm-router':
    print(f"✓ Already in llm-router directory: {cwd}")
else:
    # Check if ../llm-router exists (we're in a subdirectory of parent)
    parent_llm_router = cwd.parent / 'llm-router'
    if parent_llm_router.exists() and parent_llm_router.is_dir():
        os.chdir(parent_llm_router)
        print(f"✓ Changed to llm-router directory: {parent_llm_router}")
    # Check if ./llm-router exists (we're in the parent directory)
    elif (cwd / 'llm-router').exists() and (cwd / 'llm-router').is_dir():
        os.chdir(cwd / 'llm-router')
        print(f"✓ Changed to llm-router directory: {cwd / 'llm-router'}")
    else:
        print(f"⚠ Warning: Could not find llm-router directory. Current directory: {cwd}")
        print("  Continuing with current directory...")

# Verify we can find expected files
if not Path('pyproject.toml').exists():
    print("⚠ Warning: pyproject.toml not found in current directory")
    print(f"  Current directory: {Path.cwd()}")
else:
    print(f"✓ Verified project files in: {Path.cwd()}")

---

## 1. Setup and Imports

First, install `uv` and use it to install the project and its dependencies into the current environment.

In [None]:
%pip install uv

In [None]:
!uv pip install .


The following sections walk through the Neural Network Objective Function implementation and demonstrate example usage of the multimodal router.

---

## 2. Neural Network Objective Function

This module provides a production-ready objective function that uses a pre-trained neural network router to intelligently route requests to the best model based on:

1. **Embedding generation** from text and multimodal content
2. **Neural network predictions** with configurable confidence thresholds
3. **Multi-turn conversation context** understanding

The router loads on service startup (before request handling) for optimal performance.

The following cells walk through each component of the objective function.

In [None]:
import json
import torch
import numpy as np
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING
import logging
from functools import lru_cache
import time

from pydantic import Field
from nat.builder.builder import Builder
from nat.builder.function_info import FunctionInfo
from nat.cli.register_workflow import register_function
from nat.data_models.function import FunctionBaseConfig

# Import ModelRouter from training package
from nat_sfc_router.training.model_router import _resolve_router_path
from nat_sfc_router.training import ModelRouter

# Import OpenAI schema for type hints
from nat_sfc_router.schema.openai_chat_request import OpenAIChatRequest

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


### 2.1 Model Routing Configuration

Define the mapping between router output model names and actual pipeline model names. Also configure confidence thresholds for model selection.
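
Alias tables like this one are maintained by hand, and Python dict literals silently keep only the last value when a key is repeated. The following is a minimal sketch of a duplicate-key check built on the standard `ast` module. The helper name `duplicate_dict_keys` and the sample mapping are hypothetical and are not part of the repository.

```python
import ast
from collections import Counter
from typing import List


def duplicate_dict_keys(source: str) -> List[str]:
    """Return constant keys that appear more than once in any dict literal."""
    duplicates: List[str] = []
    for node in ast.walk(ast.parse(source)):
        if isinstance(node, ast.Dict):
            # Keys are None for `**mapping` unpacking, so keep constants only.
            keys = [k.value for k in node.keys if isinstance(k, ast.Constant)]
            duplicates.extend(k for k, count in Counter(keys).items() if count > 1)
    return duplicates


# Hypothetical mapping source, for illustration only.
sample = """
mapping = {
    'model-a': 'target-1',
    'model-b': 'target-2',
    'model-a': 'target-3',
}
"""

dups = duplicate_dict_keys(sample)
print(dups)
```

At runtime the dict built from `sample` would contain only `'model-a': 'target-3'`, which is why a source-level check is needed to notice the repeated key.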


In [None]:
# Map router output model names to actual pipeline model names
# The router predicts one of these model names based on embeddings
MODEL_ROUTER_TO_TARGET = {
    # GPT models -> gpt-5-chat
    'gpt-5-chat': 'gpt-5-chat',
    'gpt-5': 'gpt-5-chat',
    
    # Nemotron VL models -> nvidia/nemotron-nano-12b-v2-vl
    'nemotron-vl': 'nvidia/nemotron-nano-12b-v2-vl',
    'nemotron-nano-12b-v2-vl': 'nvidia/nemotron-nano-12b-v2-vl',
    'nvidia/nemotron-nano-12b-v2-vl': 'nvidia/nemotron-nano-12b-v2-vl',

    # Some versions of the router were trained with Qwen, so map Qwen to Nemotron Nano VL for example purposes
    'Qwen/Qwen3-VL-8B-Instruct': 'nvidia/nemotron-nano-12b-v2-vl',
    
    # Nemotron models -> nvidia/nvidia-nemotron-nano-9b-v2
    'nemotron-nano-9b-v2': 'nvidia/nvidia-nemotron-nano-9b-v2',
    'nemotron-nano': 'nvidia/nvidia-nemotron-nano-9b-v2',
    'nvidia/nvidia-nemotron-nano-9b-v2': 'nvidia/nvidia-nemotron-nano-9b-v2',
}

# Custom confidence thresholds for model selection
# These thresholds ensure the router only selects a model if confidence is above the threshold
# If the top choice doesn't meet its threshold, the router falls back to the next best model
CUSTOM_THRESHOLDS = {
    'gpt-5-chat': 0.70,
    'nvidia/nemotron-nano-12b-v2-vl': 0.55,
    'nvidia/nvidia-nemotron-nano-9b-v2': 0.50,
}

# Default fallback model (used on errors or unknown routing)
DEFAULT_FALLBACK_MODEL = 'nvidia/nvidia-nemotron-nano-9b-v2'

print(f"Configured {len(MODEL_ROUTER_TO_TARGET)} model mappings")
print(f"Configured thresholds for {len(CUSTOM_THRESHOLDS)} models")
print(f"Default fallback: {DEFAULT_FALLBACK_MODEL}")


### 2.2 Router Loading

The router is loaded once on service startup for optimal performance. This function initializes the ModelRouter with the trained neural network router model, custom confidence thresholds, and device selection (CUDA if available).


In [None]:
# Global router instance (loaded on startup)
_router = None


def _load_router():
    """
    Load the ModelRouter on service startup (before _response_fn).
    This ensures the router is initialized once when the service starts,
    not on every request.
    """
    global _router
    if _router is None:
        logger.info("Loading ModelRouter on service startup...")
        try:
            _router = ModelRouter(
                model_thresholds=CUSTOM_THRESHOLDS,
                device='cuda' if torch.cuda.is_available() else 'cpu',
                verbose=True
            )
            logger.info("✓ ModelRouter loaded successfully on startup")
        except Exception as e:
            logger.error(f"Failed to load ModelRouter: {e}", exc_info=True)
            raise
    return _router


### 2.3 Message Parsing Utilities

This function extracts and formats text and images from multi-turn chat messages. It handles OpenAI API format messages with support for:
- Multiple turns (user, assistant, system)
- Text content (strings and lists)
- Images (URLs, including base64-encoded data URIs)
- Both dict and Pydantic object types
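
To make the expected input and output shapes concrete, here is a simplified, dict-only sketch of the same idea. The function name `flatten_messages` and the sample messages are hypothetical. The notebook's own implementation in the next cell additionally handles Pydantic objects and other iterables.

```python
from typing import Any, Dict, List, Tuple


def flatten_messages(messages: List[Dict[str, Any]]) -> Tuple[str, List[str]]:
    """Collect role-prefixed text and image URLs from OpenAI-style messages."""
    texts: List[str] = []
    images: List[str] = []
    for message in messages:
        role = str(message.get("role", "")).upper()
        if role:
            texts.append(f"[{role}]:")
        content = message.get("content", "")
        if isinstance(content, str):
            # Plain string content: keep it if it is not blank.
            if content.strip():
                texts.append(content)
            continue
        # List content: a mix of text parts and image_url parts.
        for part in content or []:
            if part.get("type") == "text" and part.get("text", "").strip():
                texts.append(part["text"].strip())
            elif part.get("type") == "image_url":
                image_url = part.get("image_url", {})
                url = image_url.get("url", "") if isinstance(image_url, dict) else str(image_url)
                if url.strip():
                    images.append(url.strip())
    return " ".join(texts), images


messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "Describe this picture."},
            {"type": "image_url", "image_url": {"url": "data:image/png;base64,AAAA"}},
        ],
    },
]

text, images = flatten_messages(messages)
print(text)
print(images)
```

The role prefixes keep the turn structure visible in the single string that is later embedded, and the image URLs are returned separately so they can be passed to the embedding model as images.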


In [None]:
def extract_text_and_images_from_messages(
    messages: List[Dict[str, Any]]
) -> Tuple[str, List[str]]:
    """
    Extract and format text and images from multi-turn chat messages.
    
    Args:
        messages: List of message dicts or objects with 'role' and 'content'
    
    Returns:
        Tuple of (combined_text, images_list)
        - combined_text: Full conversation context as a single string
        - images_list: List of image URIs/base64 strings suitable for embedding model
    """
    full_text_parts = []
    images = []
    
    # Convert messages to list in case it's a ValidatorIterator or other iterable
    messages_list = list(messages) if not isinstance(messages, list) else messages
    
    for msg_idx, message in enumerate(messages_list):
        # Handle both dict and object types
        if isinstance(message, dict):
            role = message.get("role", "").upper()
            content = message.get("content", "")
        else:
            # Try to access as object attributes (Pydantic models, etc.)
            role = getattr(message, "role", "").upper()
            content = getattr(message, "content", "")
        
        # Add role prefix for conversation context
        if role:
            full_text_parts.append(f"[{role}]:")
        
        # Handle string content
        if isinstance(content, str):
            if content.strip():
                full_text_parts.append(content)
        
        # Handle list content (multimodal)
        elif isinstance(content, list):
            for part_idx, part in enumerate(content):
                if isinstance(part, dict):
                    # Text items (dict)
                    if part.get("type") == "text":
                        text = part.get("text", "").strip()
                        if text:
                            full_text_parts.append(text)
                    
                    # Image items (dict)
                    elif part.get("type") == "image_url":
                        img_url_obj = part.get("image_url", {})
                        
                        # Handle both dict and string formats
                        if isinstance(img_url_obj, dict):
                            url = img_url_obj.get("url", "").strip()
                        else:
                            url = str(img_url_obj).strip()
                        
                        # Only include valid image URLs
                        if url and isinstance(url, str) and len(url) > 0:
                            images.append(url)
                else:
                    # Handle object types (from Pydantic models)
                    part_type = getattr(part, "type", None)
                    
                    if part_type == "text":
                        text_content = getattr(part, "text", "")
                        if text_content:
                            full_text_parts.append(text_content)
                    
                    elif part_type == "image_url":
                        img_url_obj = getattr(part, "image_url", None)
                        if img_url_obj:
                            # Handle as dict or object
                            if isinstance(img_url_obj, dict):
                                url = img_url_obj.get("url", "").strip()
                            else:
                                url = getattr(img_url_obj, "url", "")
                                if url:
                                    url = url.strip()
                            
                            # Only include valid image URLs
                            if url and isinstance(url, str) and len(url) > 0:
                                images.append(url)
        else:
            # Try to iterate over content in case it's a ValidatorIterator
            try:
                content_list = list(content)
                for part_idx, part in enumerate(content_list):
                    if isinstance(part, dict):
                        # Text items (dict)
                        if part.get("type") == "text":
                            text = part.get("text", "").strip()
                            if text:
                                full_text_parts.append(text)
                        # Image items (dict)
                        elif part.get("type") == "image_url":
                            img_url_obj = part.get("image_url", {})
                            if isinstance(img_url_obj, dict):
                                url = img_url_obj.get("url", "").strip()
                            else:
                                url = str(img_url_obj).strip()
                            if url and isinstance(url, str) and len(url) > 0:
                                images.append(url)
                    else:
                        # Handle object types
                        part_type = getattr(part, "type", None)
                        if part_type == "text":
                            text_content = getattr(part, "text", "")
                            if text_content:
                                full_text_parts.append(text_content)
                        elif part_type == "image_url":
                            img_url_obj = getattr(part, "image_url", None)
                            if img_url_obj:
                                if isinstance(img_url_obj, dict):
                                    url = img_url_obj.get("url", "").strip()
                                else:
                                    url = getattr(img_url_obj, "url", "")
                                    if url:
                                        url = url.strip()
                                if url and isinstance(url, str) and len(url) > 0:
                                    images.append(url)
            except TypeError:
                # Not iterable, log and skip
                logger.debug(
                    f"extract_text_and_images_from_messages - Message {msg_idx}: "
                    f"content is not iterable, skipping"
                )
    
    # Combine all text parts
    full_text = " ".join(full_text_parts).strip()
    
    # Fallback text if empty
    if not full_text:
        logger.warning("No text extracted from messages, using default prompt")
        full_text = "routing query"
    
    logger.info(f"extract_text_and_images_from_messages - Final: text_len={len(full_text)}, images={len(images)}")
    
    return full_text, images


### 2.4 Model Name Mapping

Maps router output model names to target pipeline model names. Includes direct lookup and case-insensitive matching for flexibility.
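
As an alternative to scanning the table on every request, the case-insensitive step can use a lowercase index built once. This is only a sketch of that design choice: the alias table, the fallback value, and the `resolve` helper are hypothetical.

```python
from typing import Dict

# Hypothetical alias table and fallback, for illustration only.
ALIASES: Dict[str, str] = {
    "Model-A": "vendor/model-a",
    "model-b": "vendor/model-b",
}
FALLBACK = "vendor/model-b"

# Build the lowercase index once instead of scanning the table per request.
_LOWERCASE_INDEX = {key.lower(): target for key, target in ALIASES.items()}


def resolve(name: str) -> str:
    """Exact match first, then case-insensitive match, then the fallback."""
    if name in ALIASES:
        return ALIASES[name]
    return _LOWERCASE_INDEX.get(name.lower(), FALLBACK)


resolved = [resolve("Model-A"), resolve("MODEL-B"), resolve("unknown")]
print(resolved)
```

One caveat: if two keys differ only by case, the index keeps the last one, whereas a linear scan returns the first match.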


In [None]:
def map_router_model_to_target(router_model: str) -> str:
    """
    Map router output model name to target pipeline model name.
    
    Args:
        router_model: Model name predicted by router (e.g., 'gpt-5')
    
    Returns:
        Target model name for routing (e.g., 'gpt-5-chat')
    """
    # Direct lookup
    if router_model in MODEL_ROUTER_TO_TARGET:
        return MODEL_ROUTER_TO_TARGET[router_model]
    
    # Case-insensitive lookup
    lower_model = router_model.lower()
    for key, value in MODEL_ROUTER_TO_TARGET.items():
        if key.lower() == lower_model:
            return value
            
    # Default fallback
    logger.warning(
        f"Unknown router model '{router_model}', using default fallback: {DEFAULT_FALLBACK_MODEL}"
    )
    return DEFAULT_FALLBACK_MODEL


### 2.5 Cost-Based Model Selection

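Selects a model using a cost-based routing strategy:
1. Filter to the models whose probability meets their confidence threshold.
2. Among the qualified models, select the one with the lowest cost factor, using probability as the tiebreaker.
3. If no model meets its threshold, fall back to the highest-probability model overall. If no cost factors are configured, select the highest-probability qualified model.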
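
A small worked example may help show why the cheapest model is not always chosen. The sketch below is a condensed restatement of the rule, not the notebook's function: the name `pick_model` and the model names and numbers are made up for illustration.

```python
from typing import Dict


def pick_model(
    probabilities: Dict[str, float],
    thresholds: Dict[str, float],
    costs: Dict[str, float],
) -> str:
    """Cheapest model that meets its threshold, else the most probable one."""
    qualified = [
        name for name, prob in probabilities.items()
        if prob >= thresholds.get(name, 0.0)
    ]
    if not qualified:
        # Nothing passed its quality gate: fall back to the top probability.
        return max(probabilities, key=lambda name: probabilities[name])
    # Lowest cost wins; higher probability breaks ties between equal costs.
    return min(
        qualified,
        key=lambda name: (costs.get(name, float("inf")), -probabilities[name]),
    )


thresholds = {"large": 0.70, "medium": 0.55, "small": 0.50}
costs = {"large": 1.0, "medium": 0.3, "small": 0.1}

# "small" is cheapest but misses its threshold; "large" misses too.
choice_a = pick_model({"large": 0.62, "medium": 0.58, "small": 0.41}, thresholds, costs)

# No model meets its threshold, so the highest probability is used.
choice_b = pick_model({"large": 0.40, "medium": 0.35, "small": 0.25}, thresholds, costs)

print(choice_a, choice_b)
```

In the first case only `medium` clears its threshold, so it is selected even though `small` costs less. In the second case the threshold filter leaves nothing, and the rule degrades to a plain argmax.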


In [None]:
def select_best_model_by_cost(
    probabilities: Dict[str, float],
    model_thresholds: Optional[Dict[str, float]] = None,
    model_costs: Optional[Dict[str, float]] = None
) -> Tuple[str, str]:
    """
    Select the best model using cost-based routing.
    
    Args:
        probabilities: Dict of model_name -> confidence score
        model_thresholds: Dict of model_name -> minimum confidence threshold
        model_costs: Dict of model_name -> cost factor (lower is better)
    
    Returns:
        Tuple of (selected_model, selection_reason)
    """
    if not probabilities:
        raise ValueError("No probabilities provided")
    
    # Filter models that meet their thresholds
    qualified_models = []
    
    if model_thresholds:
        for model_name, prob in probabilities.items():
            threshold = model_thresholds.get(model_name, 0.0)
            if prob >= threshold:
                qualified_models.append((model_name, prob, threshold))
        
        if qualified_models:
            logger.debug(
                f"Models meeting thresholds: "
                f"{[(m, f'{p:.3f}') for m, p, t in qualified_models]}"
            )
        else:
            logger.warning(
                "No models met their thresholds, falling back to highest probability"
            )
            # Fallback: use highest probability if nothing meets threshold
            best_model = max(probabilities.items(), key=lambda x: x[1])
            return best_model[0], "threshold_fallback"
    else:
        # No thresholds - all models are qualified
        qualified_models = [(m, p, 0.0) for m, p in probabilities.items()]
    
    # If cost factors provided, select by lowest cost among qualified models
    if model_costs:
        # Find the qualified model with lowest cost
        best_model_name = None
        best_cost = float('inf')
        best_prob = 0.0
        
        for model_name, prob, threshold in qualified_models:
            cost = model_costs.get(model_name, float('inf'))
            
            # Prefer lower cost, but use probability as tiebreaker
            if cost < best_cost or (cost == best_cost and prob > best_prob):
                best_model_name = model_name
                best_cost = cost
                best_prob = prob
        
        if best_model_name:
            logger.debug(
                f"Cost-based selection: {best_model_name} "
                f"(cost={best_cost:.3f}, prob={best_prob:.3f})"
            )
            return best_model_name, "cost_optimized"
    
    # Default: select highest probability among qualified models
    best_model = max(qualified_models, key=lambda x: x[1])
    return best_model[0], "highest_probability"


### 2.6 Objective Function Registration

The main objective function that ties everything together. This function:
1. Configures thresholds and costs from the config
2. Loads the router on startup
3. Provides an async response function for routing requests


In [None]:
class NNObjectiveConfig(FunctionBaseConfig, name="nn_objective_fn"):
    """Neural network objective function configuration for model routing.
    
    Attributes:
        model_thresholds: Dict of model_name -> minimum confidence threshold
                         Only routes to models meeting their threshold
        model_costs: Dict of model_name -> cost factor for cost-based selection
                    Among models meeting thresholds, selects lowest cost model
    """
    model_thresholds: Optional[Dict[str, float]] = None
    model_costs: Optional[Dict[str, float]] = None


In [None]:
@register_function(config_type=NNObjectiveConfig)
async def nn_objective_fn(config: NNObjectiveConfig, _builder: Builder):
    """
    Neural network objective function for intelligent cost-optimized model routing.
    
    Uses a pre-trained neural network router with:
    - CLIP embedding model for multimodal content encoding
    - Trained router network to predict best model
    - Confidence thresholds (from config) for quality gates
    - Cost-based selection among qualified models
    - Multi-turn conversation context support
    
    Routing Strategy:
    1. Generate embedding for the query
    2. Get confidence scores from neural router
    3. Filter models meeting their confidence threshold
    4. Select lowest-cost model among qualified options
    """
    
    # Extract configuration
    model_thresholds = config.model_thresholds or {}
    model_costs = config.model_costs or {}
    
    logger.info("nn_objective_fn: Configuration loaded")
    logger.info(f"  Thresholds: {model_thresholds}")
    logger.info(f"  Costs: {model_costs}")
    
    # Load router on startup - this happens BEFORE _response_fn
    logger.info("nn_objective_fn: Initializing router...")
    router = _load_router()
    logger.info("nn_objective_fn: Router ready")
    
    async def _response_fn(chat_request: OpenAIChatRequest) -> Tuple[str, Dict[str, float]]:
        """
        Route a chat request to the best model using the neural network router.
        """
        response_start = time.perf_counter()
        
        try:
            # ===== EXTRACT MESSAGES =====
            extract_start = time.perf_counter()
            messages = chat_request.messages

            logger.info(f"Routing with messages: {messages}")
            
            if not messages:
                logger.warning("No messages received, using default fallback model")
                return DEFAULT_FALLBACK_MODEL, {}
            
            # Convert messages to dicts if needed (handle Pydantic models)
            messages_dict = []
            for msg in messages:
                if hasattr(msg, 'model_dump'):
                    messages_dict.append(msg.model_dump())
                elif hasattr(msg, '__dict__'):
                    messages_dict.append(vars(msg))
                elif isinstance(msg, dict):
                    messages_dict.append(msg)
                else:
                    messages_dict.append(dict(msg))
            
            extract_time = time.perf_counter() - extract_start
            logger.debug(f"Extracted {len(messages_dict)} messages in {extract_time*1000:.2f}ms")
            
            # ===== PARSE TEXT AND IMAGES =====
            parse_start = time.perf_counter()
            full_text, images = extract_text_and_images_from_messages(messages_dict)
            parse_time = time.perf_counter() - parse_start
            
            logger.debug(
                f"Parsed {len(images)} images, "
                f"text: {len(full_text)} chars, "
                f"time: {parse_time*1000:.2f}ms"
            )

            logger.info(f"Routing with full text: {full_text} and number of images: {len(images)}")
            
            # ===== ROUTE USING NEURAL NETWORK =====
            route_start = time.perf_counter()
            
            # Get probabilities from router (without thresholds - raw scores)
            # Use async version to avoid event loop conflicts with CLIP client
            embedding = await router.generate_embedding_async(full_text, images)
            embedding_2d = embedding.reshape(1, -1)
            
            # Get raw probabilities from router
            router.router_model.eval()
            with torch.no_grad():
                proba = router.router_model(torch.FloatTensor(embedding_2d).to(router.device))
                proba = proba.cpu().numpy()
            
            # Build probability dict
            probabilities = {
                model: float(proba[0, i])
                for i, model in enumerate(router.model_names)
            }
            
            route_time = time.perf_counter() - route_start
            
            # ===== COST-BASED SELECTION =====
            cost_select_start = time.perf_counter()
            
            router_model, selection_reason = select_best_model_by_cost(
                probabilities=probabilities,
                model_thresholds=model_thresholds,
                model_costs=model_costs
            )
            
            cost_select_time = time.perf_counter() - cost_select_start
            confidence = probabilities.get(router_model, 0.0)
            
            logger.info(
                f"Routing decision | "
                f"Model: {router_model} | "
                f"Confidence: {confidence:.3f} | "
                f"Selection: {selection_reason} | "
                f"Probabilities: {{{', '.join(f'{m}: {p:.3f}' for m, p in probabilities.items())}}} | "
                f"Route+Select time: {route_time*1000 + cost_select_time*1000:.2f}ms"
            )
            
            # ===== MAP TO TARGET MODEL =====
            map_start = time.perf_counter()
            target_model = map_router_model_to_target(router_model)
            map_time = time.perf_counter() - map_start
            
            total_time = time.perf_counter() - response_start
            
            logger.info(
                f"Final routing | "
                f"Router model: {router_model} -> Target: {target_model} | "
                f"Total time: {total_time*1000:.2f}ms"
            )
            
            return target_model, probabilities
        
        except Exception as e:
            logger.error(f"Error in nn_objective_fn routing: {e}", exc_info=True)
            logger.warning(f"Using default fallback model: {DEFAULT_FALLBACK_MODEL}")
            return DEFAULT_FALLBACK_MODEL, {}
    
    yield FunctionInfo.from_fn(
        _response_fn,
        description="Neural network objective function for intelligent model routing using embeddings and trained router."
    )


---

## 3. Example Usage

The following cells demonstrate how to use the ModelRouter directly for multimodal routing. This shows the core functionality without the full objective function integration.


### 3.1 Load the Router

Initialize the ModelRouter with the trained neural network. This requires a running CLIP server for embedding generation.
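
Once the server has been started, a quick TCP check can confirm that something is listening before the router is loaded. This is a minimal sketch using only the standard library. The helper name `is_port_open` and the host `127.0.0.1` are assumptions, while port 51000 matches the Docker command used in this notebook.

```python
import socket


def is_port_open(host: str, port: int, timeout: float = 1.0) -> bool:
    """Return True if a TCP connection to host:port succeeds within timeout."""
    try:
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:
        # Covers refused connections, timeouts, and unreachable hosts.
        return False


# Assumed host; the port matches the Docker command in this notebook.
print("CLIP server reachable:", is_port_open("127.0.0.1", 51000))
```

This only shows that a process accepts connections on the port. It does not verify that the process is the CLIP server or that its model has finished loading.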


In [None]:
# =====================================================
# PREREQUISITE: Deploy CLIP server with Docker (Linux)
# =====================================================
# This cell starts the CLIP server in the background. Alternatively, run the same
# command (without the leading `!`) in a terminal before continuing.

!docker run -d --rm \
    --name clip_server \
    --gpus all \
    -p 51000:51000 \
    jinaai/clip-as-service:latest

In [None]:
# Load the router (requires CLIP server to be running)
# To start a CLIP server: docker run -p 51000:51000 jinaai/clip-as-service:latest
router = _load_router()
print(f"Router loaded with models: {router.model_names}")


### 3.2 Text-Only Routing Example

Route a simple text query to determine the best model.


In [None]:
# Example: Text-only routing
text_query = "What is the capital of France?"
selected_model = router.route(text_query)
print(f"Query: {text_query}")
print(f"Selected model: {selected_model}")


### 3.3 Multi-Turn Conversation Example

Demonstrate routing with a multi-turn conversation using the message parsing utilities.


In [None]:
# Example: Multi-turn conversation routing
messages = [
    {"role": "user", "content": "I'm working on a machine learning project."},
    {"role": "assistant", "content": "That sounds interesting! What kind of ML project are you working on?"},
    {"role": "user", "content": "I'm building a neural network for image classification. Can you help me understand backpropagation?"}
]

# Extract text from messages
full_text, images = extract_text_and_images_from_messages(messages)
print(f"Extracted text: {full_text[:200]}...")
print(f"Number of images: {len(images)}")

# Route the conversation
selected_model = router.route(full_text, images=images if images else None)
print(f"Selected model: {selected_model}")


### 3.4 Cost-Based Selection Example

Demonstrate how cost-based selection works with custom thresholds and cost factors.


In [None]:
# Example: Cost-based selection with simulated probabilities
simulated_probabilities = {
    'gpt-5-chat': 0.75,
    'nvidia/nemotron-nano-12b-v2-vl': 0.60,
    'nvidia/nvidia-nemotron-nano-9b-v2': 0.55,
}

# Define cost factors (lower = cheaper)
model_costs = {
    'gpt-5-chat': 1.0,  # Most expensive
    'nvidia/nemotron-nano-12b-v2-vl': 0.3,  # Mid-tier
    'nvidia/nvidia-nemotron-nano-9b-v2': 0.1,  # Cheapest
}

# Select best model with cost optimization
selected_model, reason = select_best_model_by_cost(
    probabilities=simulated_probabilities,
    model_thresholds=CUSTOM_THRESHOLDS,
    model_costs=model_costs
)

print(f"Probabilities: {simulated_probabilities}")
print(f"Thresholds: {CUSTOM_THRESHOLDS}")
print(f"Cost factors: {model_costs}")
print(f"Selected model: {selected_model}")
print(f"Selection reason: {reason}")


---

## 4. Async Usage with `nn_objective_fn`

The `nn_objective_fn` is designed for **async environments** with the NAT framework. This section demonstrates how to use it directly for testing and development.

Key differences from the direct `ModelRouter` usage:
- **Async interface**: Uses `async/await` for non-blocking operation
- **OpenAI-compatible**: Accepts `OpenAIChatRequest` objects (OpenAI chat completion format)
- **Framework integration**: Designed to work with NAT's `Builder` and `FunctionInfo` system
- **Configuration-driven**: Uses `NNObjectiveConfig` for thresholds and costs


### 4.1 Create a Helper to Run the Objective Function

Since `nn_objective_fn` is an async generator designed for the NAT framework, we need a helper to extract and use the response function directly.
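
The pattern can be illustrated with a toy example that does not depend on NAT. The sketch below assumes only what the helper's comments state, namely that the registered function behaves like an async context manager that yields a callable. `toy_objective` and the other names are hypothetical stand-ins.

```python
import asyncio
from contextlib import asynccontextmanager
from typing import AsyncIterator, Awaitable, Callable


@asynccontextmanager
async def toy_objective(prefix: str) -> AsyncIterator[Callable[[str], Awaitable[str]]]:
    """Hypothetical stand-in: run setup once, then yield an async callable."""

    async def respond(query: str) -> str:
        return f"{prefix}:{query}"

    yield respond  # any cleanup code would run after this line on context exit


async def get_response_fn(prefix: str) -> Callable[[str], Awaitable[str]]:
    # Enter the context, grab the yielded callable, and return it.
    async with toy_objective(prefix) as fn:
        return fn


async def main() -> str:
    fn = await get_response_fn("routed")
    return await fn("hello")


# In a notebook cell, use `await main()` instead of asyncio.run(...).
result = asyncio.run(main())
print(result)
```

Returning the callable from inside `async with` means the context has already exited by the time it is called. That is fine for testing when the callable does not rely on resources released during cleanup.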


In [None]:
import asyncio

async def get_objective_response_fn(config: NNObjectiveConfig):
    """
    Helper to extract the response function from nn_objective_fn.
    
    In production, the NAT framework handles this automatically.
    This helper allows us to test the objective function directly.
    """
    # The @register_function decorator uses asynccontextmanager
    # which yields the FunctionInfo directly when entering the context
    async with nn_objective_fn(config, _builder=None) as function_info:
        # FunctionInfo has single_fn (regular async) and stream_fn (async generator)
        # Our _response_fn is a regular async function, so use single_fn
        return function_info.single_fn

print("Helper function created for testing nn_objective_fn")


### 4.2 Route Using OpenAI Chat Request Format

Create an `OpenAIChatRequest` and route it using the objective function. This demonstrates the production API format.


In [None]:
async def demo_async_routing():
    """Demonstrate async routing with nn_objective_fn"""
    
    # 1. Create configuration with custom thresholds and costs
    config = NNObjectiveConfig(
        model_thresholds={
            'gpt-5-chat': 0.70,
            'nvidia/nemotron-nano-12b-v2-vl': 0.55,
            'nvidia/nvidia-nemotron-nano-9b-v2': 0.50,
        },
        model_costs={
            'gpt-5-chat': 1.0,              # Most expensive
            'nvidia/nemotron-nano-12b-v2-vl': 0.3,  # Mid-tier
            'nvidia/nvidia-nemotron-nano-9b-v2': 0.1,  # Cheapest
        }
    )
    
    print("Configuration:")
    print(f"  Thresholds: {config.model_thresholds}")
    print(f"  Costs: {config.model_costs}")
    print()
    
    # 2. Get the response function from the objective
    print("Initializing objective function...")
    response_fn = await get_objective_response_fn(config)
    print("✓ Objective function ready\n")
    
    # 3. Create OpenAI-format chat requests
    test_cases = [
        {
            "name": "Simple question",
            "messages": [
                {"role": "user", "content": "What is the capital of France?"}
            ]
        },
        {
            "name": "Multi-turn conversation",
            "messages": [
                {"role": "system", "content": "You are a helpful AI assistant."},
                {"role": "user", "content": "I need help with machine learning."},
                {"role": "assistant", "content": "I'd be happy to help! What specific aspect of ML?"},
                {"role": "user", "content": "Explain backpropagation in neural networks."}
            ]
        },
        {
            "name": "Complex reasoning",
            "messages": [
                {"role": "user", "content": "Analyze the philosophical implications of artificial general intelligence on human society, considering economic, ethical, and existential perspectives."}
            ]
        }
    ]
    
    # 4. Route each request
    print("=" * 70)
    print("Routing Results")
    print("=" * 70)
    
    for test in test_cases:
        # Create OpenAIChatRequest
        chat_request = OpenAIChatRequest(
            model="router",  # Placeholder - router will select actual model
            messages=test["messages"]
        )
        
        # Route the request
        result = await response_fn(chat_request)
        
        # Handle result (can be tuple of (model, probs) or just model string)
        if isinstance(result, tuple):
            target_model, probabilities = result
            print(f"\n{test['name']}:")
            print(f"  → Selected model: {target_model}")
            print(f"  → Probabilities: {', '.join(f'{m}: {p:.3f}' for m, p in probabilities.items())}")
        else:
            print(f"\n{test['name']}:")
            print(f"  → Selected model: {result}")
    
    print("\n" + "=" * 70)
    print("Async routing demo complete!")
    print("=" * 70)

# Run the async demo
await demo_async_routing()


### 4.3 Multimodal Routing with Images

The objective function also supports multimodal requests with images using the OpenAI vision format.
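
To send a local image, its bytes need to be wrapped in a base64 data URI inside an `image_url` content part. The following is a minimal sketch using the standard library. The helper names `to_data_uri` and `image_part` are hypothetical, and the sample bytes are just the PNG file signature rather than a complete image.

```python
import base64
from typing import Any, Dict


def to_data_uri(image_bytes: bytes, mime_type: str = "image/png") -> str:
    """Encode raw image bytes as a base64 data URI."""
    encoded = base64.b64encode(image_bytes).decode("ascii")
    return f"data:{mime_type};base64,{encoded}"


def image_part(image_bytes: bytes, mime_type: str = "image/png") -> Dict[str, Any]:
    """Build an OpenAI vision-style content part for the given image bytes."""
    return {"type": "image_url", "image_url": {"url": to_data_uri(image_bytes, mime_type)}}


# Placeholder bytes (PNG signature only); in practice, read them from a file,
# for example: Path("photo.png").read_bytes()
sample_bytes = b"\x89PNG\r\n\x1a\n"
part = image_part(sample_bytes)
print(part["image_url"]["url"])
```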


In [None]:
async def demo_multimodal_routing():
    """Demonstrate multimodal routing with images"""
    
    # Create configuration
    config = NNObjectiveConfig(
        model_thresholds=CUSTOM_THRESHOLDS
    )
    
    # Get response function (reuses loaded router)
    response_fn = await get_objective_response_fn(config)
    
    # Create a multimodal request with image (OpenAI vision format)
    # In production, this would be a real base64 image
    multimodal_request = OpenAIChatRequest(
        model="router",
        messages=[
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": "What objects do you see in this image? Describe them in detail."
                    },
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg=="
                        }
                    }
                ]
            }
        ]
    )
    
    print("Multimodal Request (with image):")
    print("  Text: 'What objects do you see in this image?'")
    print("  Image: [base64 encoded image]")
    print()
    
    # Route the multimodal request
    result = await response_fn(multimodal_request)
    
    if isinstance(result, tuple):
        target_model, probabilities = result
        print(f"Selected model: {target_model}")
        print("Probabilities:")
        for model, prob in sorted(probabilities.items(), key=lambda x: x[1], reverse=True):
            print(f"  {model}: {prob:.3f}")
    else:
        print(f"Selected model: {result}")

# Run multimodal demo
await demo_multimodal_routing()
