# Serve BAAI BGE Reranker Large using Triton Inference Server on AWS Neuron

This notebook shows how to serve the [BAAI/bge-reranker-large](https://huggingface.co/BAAI/bge-reranker-large) model using [Triton Inference Server](https://github.com/triton-inference-server) on [AWS Neuron](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/index.html) with [torch-neuronx](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/setup/torch-neuronx.html).

## Setup and Imports

In [None]:
! pip install kubernetes
! pip install boto3

In [None]:
import os
import subprocess
import sys

# Set working directory
# Later cells use paths relative to this checkout (the container build script
# and the Helm chart directories), so everything is run from the repo root.
os.chdir(os.path.expanduser('~/amazon-eks-machine-learning-with-terraform-and-kubeflow'))
print(f"Working directory: {os.getcwd()}")

# Get the src directory
# It is put first on sys.path so the repo-local `k8s` package imported just
# below can be resolved.
src_dir = os.path.join(os.getcwd(), "src")
sys.path.insert(0, src_dir)

from k8s.utils import (
    wait_for_helm_release_pods,
    wait_for_triton_server,
    find_matching_helm_services
)

# Get notebook directory
# Used later to locate this example's triton_server.yaml Helm values file.
notebook_dir = os.path.join(os.getcwd(), 'examples', 'inference', 
    'triton-inference-server', 'python_backend', 'baai-bge-reranker-large-neuron')
print(f"Notebook directory: {notebook_dir}")

# initialize key variables
# release_name is reused for both Helm releases in this notebook: first the
# model-download job, then (after that is uninstalled) the Triton server.
release_name = 'triton-server-baai-bge-reranker-large-neuron'
namespace = 'kubeflow-user-example-com'
# Passed to the model-download chart as the HF_MODEL_ID environment variable.
hf_model_id = 'BAAI/bge-reranker-large'

## Step 1: Build and Push Docker Container

Build and push the Docker container image to your current AWS region.

In [None]:
import sys
import boto3

# Create a Boto3 session
session = boto3.session.Session()

# Access the region_name attribute to get the current region.
# region_name is None when no region is configured; fail early with an
# actionable message instead of letting Popen reject a None argument.
current_region = session.region_name
if not current_region:
    raise RuntimeError(
        "No AWS region is configured for this session. Set AWS_DEFAULT_REGION "
        "or run 'aws configure' before building the container image."
    )

cmd = ['./containers/tritonserver-neuronx/build_tools/build_and_push.sh', current_region]

# Start the subprocess with streaming output. stderr is merged into stdout so
# the build log is shown in a single stream; bufsize=1 requests line buffering
# (valid because the pipe is opened in text mode).
process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
                           text=True, bufsize=1)

# Stream output line by line
for line in process.stdout:
    # end='' prevents double newlines; flush=True forces immediate output
    print(line, end='', flush=True)

# Wait for the process to complete and get the return code
return_code = process.wait()

if return_code != 0:
    print(f"\nProcess exited with return code: {return_code}")
else:
    print("\nProcess completed successfully")

## Step 2: Download Hugging Face BAAI BGE Reranker Large Model Weights

Below we download the Hugging Face model.

In [None]:

import json

# Serialize the chart's `env` value with json.dumps so the model ID is always
# quoted and escaped correctly. Compact separators produce the same text a
# hand-written JSON literal without spaces would.
env_json = json.dumps(
    [{"name": "HF_MODEL_ID", "value": hf_model_id}],
    separators=(',', ':'),
)

# Install the model-download chart as a Helm release in the target namespace
cmd = [
    'helm', 'install', '--debug', release_name,
    'charts/machine-learning/model-prep/hf-snapshot',
    '--set-json', f'env={env_json}',
    '-n', namespace
]

result = subprocess.run(cmd, capture_output=True, text=True)
print(result.stdout)
if result.stderr:
    print("STDERR:", result.stderr)
# Output on stderr alone does not prove failure, so report the exit code
# explicitly when the command did not succeed.
if result.returncode != 0:
    print(f"\nhelm install exited with return code: {result.returncode}")

In [None]:
# Wait for model download to complete
# The next cell uninstalls this release and its name is then reused for the
# Triton server release, so the download must be finished first.
wait_for_helm_release_pods(release_name, namespace)

In [None]:
# Uninstall the model download job
# This frees the release name, which the Triton server install reuses.
cmd = ['helm', 'uninstall', release_name, '-n', namespace]
result = subprocess.run(cmd, capture_output=True, text=True)
print(result.stdout)
if result.stderr:
    print("STDERR:", result.stderr)
# Report a failed uninstall explicitly instead of relying on stderr text.
if result.returncode != 0:
    print(f"\nhelm uninstall exited with return code: {result.returncode}")

## Step 3: Launch Triton Server

In [None]:
# Install the Triton Inference Server chart, configured by this example's
# triton_server.yaml values file.
cmd = [
    'helm', 'install', '--debug', release_name,
    'charts/machine-learning/serving/triton-inference-server',
    '-f', f'{notebook_dir}/triton_server.yaml',
    '-n', namespace
]

result = subprocess.run(cmd, capture_output=True, text=True)
print(result.stdout)
if result.stderr:
    print("STDERR:", result.stderr)
# Output on stderr alone does not prove failure, so report the exit code
# explicitly when the command did not succeed.
if result.returncode != 0:
    print(f"\nhelm install exited with return code: {result.returncode}")

In [None]:
# Wait for the Triton server release before checking its service and querying it.
# NOTE(review): the exact readiness criteria are defined in
# k8s.utils.wait_for_triton_server — confirm there.
wait_for_triton_server(release_name, namespace)

## Step 4: Check Service Status

In [None]:
# Check service status
services = find_matching_helm_services(release_name, namespace)
if not services:
    # Make the "nothing found" case visible instead of printing nothing at all.
    print(f"No services found for release {release_name} in namespace {namespace}.")
for service in services or []:
    print(f"Service {service.metadata.name} is available.")
    print(f"Service type: {service.spec.type}")
    print(f"Service ports: {service.spec.ports} ")
    # Point the port-forward hint at this service's actual name, so the
    # command is correct even if the name differs from the release name.
    print(f"Run  'kubectl port-forward svc/{service.metadata.name} 8000:8000 -n {namespace}' in a separate terminal")

## Step 5: Test the Deployed Reranker Model

**Prerequisites:**
- Run `kubectl port-forward svc/YOUR_SERVICE_NAME 8000:8000 -n YOUR_NAMESPACE` in a separate terminal
- Install required packages: `pip install requests numpy`

In [None]:
# Install additional packages for testing
! pip install requests numpy

In [None]:
import json
import requests
import numpy as np
from typing import List, Dict, Any, Tuple

# Configuration for testing
# BASE_URL targets local port 8000, i.e. the local end of the
# `kubectl port-forward ... 8000:8000` tunnel this notebook asks you to open.
BASE_URL = "http://localhost:8000"
# Model name as addressed in the /v2/models/<name>/... request paths used by
# the helper functions in this notebook.
MODEL_NAME = "baai-bge-reranker-large"  # Update this based on your model deployment

### Check Model is Ready

Below we check that the Triton Inference Server is healthy and that the reranker model has been successfully deployed within the server and is ready.

In [None]:
def check_server_health(base_url: str = BASE_URL) -> bool:
    """Check if the Triton server is healthy and responsive.

    Sends a GET request to ``{base_url}/v2/health/ready`` with a 10 second
    timeout and prints the outcome.

    Args:
        base_url: Base URL of the Triton HTTP endpoint.

    Returns:
        True if the endpoint answered with HTTP 200; False for any other
        status code or if the request itself failed.
    """
    try:
        health_url = f"{base_url}/v2/health/ready"
        response = requests.get(health_url, timeout=10)

        if response.status_code == 200:
            print("✓ Triton server is healthy and ready")
            return True
        else:
            print(f"✗ Triton server health check failed: {response.status_code}")
            return False

    except requests.exceptions.RequestException as e:
        print(f"✗ Cannot connect to Triton server: {e}")
        # The most likely cause in this notebook is a missing port-forward, so
        # print the exact command (uses the notebook-level release_name and
        # namespace variables).
        print("\nPlease ensure kubectl port-forward is running:")
        print(f"kubectl port-forward svc/{release_name} 8000:8000 -n {namespace}")
        return False

def check_model_ready(base_url: str = BASE_URL, model_name: str = MODEL_NAME) -> List[str]:
    """Report whether the given model is ready on the Triton server.

    Queries ``{base_url}/v2/models/{model_name}/ready`` with a 10 second
    timeout and prints the outcome.

    Returns:
        ``[model_name]`` when the endpoint answers with HTTP 200, otherwise an
        empty list (non-200 status or a failed request).
    """
    readiness_endpoint = f"{base_url}/v2/models/{model_name}/ready"

    try:
        reply = requests.get(readiness_endpoint, timeout=10)
    except requests.exceptions.RequestException as err:
        print(f"Cannot check model readiness: {err}")
        return []

    # Guard clause: anything other than 200 means the model is not ready.
    if reply.status_code != 200:
        print(f"Failed to check model readiness: {reply.status_code}")
        return []

    print(f"Available model: {model_name}")
    return [model_name]

# Probe the server first; model readiness is only queried when the server
# itself reported healthy, otherwise no models are considered available.
server_healthy = check_server_health()
available_models = check_model_ready() if server_healthy else []

### Define Tests for Reranker Model

Below we define the tests for the reranker model. The reranker takes a query and multiple documents, then returns relevance scores for ranking the documents.

In [None]:
def test_reranker(query: str, documents: List[str], model_name: str = MODEL_NAME, base_url: str = BASE_URL) -> Dict[str, Any]:
    """Test reranker with a query and list of documents.

    Posts one inference request to ``{base_url}/v2/models/{model_name}/infer``
    and summarizes the returned relevance scores.

    Args:
        query: Query text the documents are ranked against.
        documents: Candidate documents to score.
        model_name: Name of the Triton model to call.
        base_url: Base URL of the Triton HTTP endpoint.

    Returns:
        A dict that always contains ``query``, ``num_documents``,
        ``documents``, ``status_code`` (``None`` when no HTTP response was
        received) and ``success``. When ``success`` is True it also contains
        ``scores_shape``, ``scores_datatype``, ``scores``, ``scores_stats``
        and ``ranked_documents`` (highest score first). When ``success`` is
        False, ``error`` describes what went wrong. No exception is
        propagated to the caller.
    """

    # Triton inference endpoint
    url = f"{base_url}/v2/models/{model_name}/infer"

    # Prepare the request payload for Triton reranker
    payload = {
        "inputs": [
            {
                "name": "query",
                "shape": [1, 1],  # batch_size=1, single query
                "datatype": "BYTES",
                "data": [query]
            },
            {
                "name": "texts",
                "shape": [1, len(documents)],  # batch_size=1, number of documents
                "datatype": "BYTES",
                "data": documents
            }
        ],
        "outputs": [
            {
                "name": "scores"
            }
        ]
    }

    # Build the result before the request so that every failure path,
    # including a POST that never gets a response, has a dict to annotate.
    # "success" only becomes True once the scores have been fully processed.
    result: Dict[str, Any] = {
        "query": query,
        "num_documents": len(documents),
        "documents": documents,
        "status_code": None,
        "success": False
    }

    try:
        # Send request to Triton server
        response = requests.post(url, json=payload, timeout=30)
        result["status_code"] = response.status_code

        if response.status_code != 200:
            result["error"] = response.text
            return result

        response_data = response.json()

        # Extract scores from response
        if not ("outputs" in response_data and len(response_data["outputs"]) > 0):
            result["error"] = "No outputs found in response"
            result["raw_response"] = response_data
            return result

        scores_output = response_data["outputs"][0]
        scores_data = scores_output["data"]
        scores_shape = scores_output["shape"]

        # Convert to numpy array for analysis
        scores_array = np.array(scores_data)

        # zip() below would silently drop documents or scores on a count
        # mismatch, so report it as an error instead.
        if len(scores_array) != len(documents):
            result["error"] = (
                f"Expected {len(documents)} scores but received {len(scores_array)}"
            )
            result["raw_response"] = response_data
            return result

        result.update({
            "scores_shape": scores_shape,
            "scores_datatype": scores_output.get("datatype", "Unknown"),
            "scores": scores_array.tolist(),
            "scores_stats": {
                "min": float(scores_array.min()),
                "max": float(scores_array.max()),
                "mean": float(scores_array.mean()),
                "std": float(scores_array.std()) if len(scores_array) > 1 else 0.0
            }
        })

        # Create ranked results (highest score first)
        doc_scores = list(zip(documents, scores_array))
        ranked_docs = sorted(doc_scores, key=lambda x: x[1], reverse=True)

        result["ranked_documents"] = [
            {
                "rank": i + 1,
                "document": doc,
                "score": float(score),
                "document_preview": doc[:100] + "..." if len(doc) > 100 else doc
            }
            for i, (doc, score) in enumerate(ranked_docs)
        ]
        result["success"] = True

    except requests.exceptions.RequestException as e:
        result["error"] = f"Request failed: {e}"
    except Exception as e:
        # Deliberate catch-all: this is a notebook test helper that reports
        # problems in the result dict rather than raising.
        result["success"] = False
        result["error"] = f"Unexpected error: {e}"

    return result

# Define test cases for reranker
# Each case is a dict with a "query" string and a "documents" list. The
# documents deliberately mix on-topic text with clearly unrelated text
# (weather, cooking, stock market, ...), presumably so that a working reranker
# produces a visible score gap between the two groups.
test_cases = [
    {
        "query": "What is machine learning?",
        "documents": [
            "Machine learning is a subset of artificial intelligence that enables computers to learn and make decisions from data without being explicitly programmed.",
            "The weather today is sunny with a temperature of 75 degrees Fahrenheit.",
            "Machine learning algorithms can be supervised, unsupervised, or reinforcement learning based on the type of data and learning approach.",
            "Cooking pasta requires boiling water and adding salt for flavor.",
            "Deep learning is a specialized branch of machine learning that uses neural networks with multiple layers to process complex patterns in data."
        ]
    },
    {
        "query": "How do pandas eat bamboo?",
        "documents": [
            "Giant pandas have a specialized thumb-like structure that helps them grasp bamboo stalks while eating.",
            "Pandas spend up to 14 hours a day eating bamboo, consuming up to 40 pounds daily to meet their nutritional needs.",
            "The stock market experienced significant volatility last week due to economic uncertainty.",
            "Pandas have strong jaw muscles and flat molars that are perfectly adapted for crushing and grinding tough bamboo fibers.",
            "Python programming language is popular for data science and machine learning applications."
        ]
    },
    {
        "query": "Benefits of renewable energy",
        "documents": [
            "Solar and wind energy are clean, sustainable sources that reduce greenhouse gas emissions and combat climate change.",
            "Fast food restaurants often use processed ingredients that may not be the healthiest option for regular consumption.",
            "Renewable energy sources like hydroelectric, solar, and wind power provide energy independence and reduce reliance on fossil fuels.",
            "The latest smartphone features include improved camera quality and longer battery life for enhanced user experience."
        ]
    }
]

# Announce how many cases will be run and against which model name.
print(f"Testing {len(test_cases)} reranker examples with model: {MODEL_NAME}")
print("=" * 80)

### Run Reranker Tests

Now we run the defined reranker tests to verify the model correctly ranks documents by relevance to the query.

In [None]:
# Test all examples
# Runs every test case against the deployed model, prints a per-test report
# and a summary. `results` is kept for the follow-up analysis cell.
if server_healthy and available_models:
    results = []

    for i, test_case in enumerate(test_cases, 1):
        query = test_case["query"]
        documents = test_case["documents"]

        print(f"\nTest {i}/{len(test_cases)}: {query}")
        print(f"Documents to rank: {len(documents)}")
        print("-" * 70)

        result = test_reranker(query, documents)
        # Treat any result carrying an "error" entry as failed, so the success
        # branch below can rely on the score fields being present.
        result['success'] = bool(result['success']) and 'error' not in result
        results.append(result)

        if result['success']:
            print(f"✓ Success - Ranked {result['num_documents']} documents")
            print(f"  Scores shape: {result['scores_shape']}")
            print(f"  Score stats: min={result['scores_stats']['min']:.4f}, max={result['scores_stats']['max']:.4f}, mean={result['scores_stats']['mean']:.4f}")

            # Show top 3 ranked documents
            print("\n  Top 3 ranked documents:")
            for rank_info in result['ranked_documents'][:3]:
                print(f"    {rank_info['rank']}. Score: {rank_info['score']:.4f}")
                print(f"       {rank_info['document_preview']}")
                print()
        else:
            print(f"✗ Failed: {result.get('error', 'Unknown error')}")

    # Summary
    successful_tests = sum(1 for r in results if r['success'])
    print(f"\n{'='*80}")
    print(f"Test Summary: {successful_tests}/{len(results)} tests passed")

    if successful_tests == len(results):
        print("🎉 All tests passed! Your reranker model is working correctly.")
        print("\nThe model successfully:")
        print("- Processed queries with multiple documents")
        print("- Generated relevance scores for ranking")
        print("- Ranked documents by relevance to the query")
    elif successful_tests > 0:
        print("⚠️  Some tests passed, but there were failures. Check the errors above.")
    else:
        print("❌ All tests failed. Please check your model deployment and configuration.")

else:
    print("Cannot run tests - server not healthy or no models available")
    print("\nTroubleshooting:")
    print("1. Ensure your reranker model is deployed via Helm")
    print("2. Check that kubectl port-forward is running:")
    print(f"   kubectl port-forward svc/{release_name} 8000:8000 -n {namespace}")
    print("3. Verify the service is running:")
    print(f"   kubectl get pods -n {namespace}")
    print("4. Check model name matches your deployment configuration")

### Additional Reranker Analysis

Let's run some additional analysis to better understand the reranker's behavior.

In [None]:
def _truncate(text: str, limit: int = 150) -> str:
    """Return `text` cut to `limit` characters, adding '...' only if it was cut."""
    return text[:limit] + "..." if len(text) > limit else text


# Additional analysis if tests were successful
# A result only counts as successful when it is flagged as such and carries
# no "error" entry; that guarantees the score fields read below exist.
if server_healthy and available_models and results and all(
    r['success'] and 'error' not in r for r in results
):
    print("\n" + "="*80)
    print("RERANKER ANALYSIS")
    print("="*80)

    for i, result in enumerate(results, 1):
        print(f"\nAnalysis {i}: {result['query']}")
        print("-" * 50)

        # Show score distribution
        scores = result['scores']
        print(f"Score distribution: {scores}")

        # Identify most and least relevant documents
        # (ranked_documents is sorted by score, highest first)
        ranked_docs = result['ranked_documents']
        most_relevant = ranked_docs[0]
        least_relevant = ranked_docs[-1]

        print(f"\nMost relevant (Score: {most_relevant['score']:.4f}):")
        print(f"  {_truncate(most_relevant['document'])}")

        print(f"\nLeast relevant (Score: {least_relevant['score']:.4f}):")
        print(f"  {_truncate(least_relevant['document'])}")

        # Calculate score spread
        score_spread = most_relevant['score'] - least_relevant['score']
        print(f"\nScore spread: {score_spread:.4f} (higher spread indicates better discrimination)")

    print(f"\n{'='*80}")
    print("Reranker model analysis complete!")
    print("The model demonstrates ability to:")
    print("- Distinguish between relevant and irrelevant documents")
    print("- Provide meaningful relevance scores")
    print("- Rank documents in order of relevance to queries")

## Step 6: Stop Service

When you're done with the service, run this cell to clean up resources.

In [None]:
# Uninstall the Triton server Helm release to clean up its resources
cmd = ['helm', 'uninstall', release_name, '-n', namespace]
result = subprocess.run(cmd, capture_output=True, text=True)
print(result.stdout)
if result.stderr:
    print("STDERR:", result.stderr)
# Report a failed uninstall explicitly instead of relying on stderr text.
if result.returncode != 0:
    print(f"\nhelm uninstall exited with return code: {result.returncode}")