# Serve Mixtral 8x22B Instruct v0.1 using Triton Inference Server with TensorRT-LLM

This notebook shows how to serve the [mistralai/Mixtral-8x22B-Instruct-v0.1](https://huggingface.co/mistralai/Mixtral-8x22B-Instruct-v0.1) model in a multi-GPU, multi-node deployment, using [Triton Inference Server](https://github.com/triton-inference-server) with the [TensorRT-LLM backend](https://github.com/triton-inference-server/tensorrtllm_backend/tree/main).

## Setup and Imports

In [None]:
# Install client libraries used by this notebook: boto3 is used below to look up
# the AWS region; kubernetes is presumably required by the repo's k8s.utils helpers.
! pip install kubernetes
! pip install boto3

In [None]:
import os
import subprocess
import sys

# Run from the repository root so the relative chart and script paths used by
# later cells resolve correctly.
repo_root = os.path.expanduser('~/amazon-eks-machine-learning-with-terraform-and-kubeflow')
os.chdir(repo_root)
print(f"Working directory: {os.getcwd()}")

# Put the repository's `src` directory first on the import path so that the
# `k8s.utils` helpers can be imported.
src_dir = os.path.join(os.getcwd(), "src")
sys.path.insert(0, src_dir)

from k8s.utils import (
    find_matching_helm_services,
    wait_for_helm_release_pods,
    wait_for_triton_server,
)

# Directory holding this example's Helm values files (*.yaml passed via -f).
notebook_dir = os.path.join(
    os.getcwd(),
    *('examples', 'inference', 'triton-inference-server',
      'tensorrtllm_backend', 'mistral-8x22b-instruct-v01'),
)
print(f"Notebook directory: {notebook_dir}")

# Helm release name, Kubernetes namespace, and Hugging Face model id shared by
# every step below.
release_name, namespace, hf_model_id = (
    'triton-server-mistral-8x22b-instruct-v01-trtllm',
    'kubeflow-user-example-com',
    'mistralai/Mixtral-8x22B-Instruct-v0.1',
)

## Step 1: Build and Push Docker Container

In [None]:
import subprocess

import boto3

# Resolve the AWS region from boto3's default configuration. If no region is
# configured, region_name is None; fail early with a clear message instead of
# passing None into the command line (which makes Popen raise TypeError).
session = boto3.session.Session()
current_region = session.region_name
if not current_region:
    raise RuntimeError(
        "No AWS region configured; set AWS_DEFAULT_REGION or run 'aws configure'."
    )

cmd = ['./containers/tritonserver-trtllm/build_tools/build_and_push.sh', current_region]

# Merge stderr into stdout and stream the build log line by line so progress
# shows up in the notebook while the script runs. The context manager closes
# the pipe and reaps the child process even if streaming is interrupted.
with subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
                      text=True, bufsize=1) as process:
    for line in process.stdout:
        print(line, end='', flush=True)  # end='' avoids doubled newlines
    return_code = process.wait()

if return_code != 0:
    print(f"\nProcess exited with return code: {return_code}")
    # Stop here: later steps depend on the pushed container image.
    raise subprocess.CalledProcessError(return_code, cmd)
print("\nProcess completed successfully")

## Step 2: Download the Mixtral 8x22B Instruct v0.1 Model Weights from Hugging Face

**Note:** Set your Hugging Face token in the cell below before running it.

In [None]:
import json

# Replace with your actual Hugging Face token
HF_TOKEN = None
# Explicit check (an `assert` would be skipped under `python -O`).
if not HF_TOKEN:
    raise ValueError("Please set HF_TOKEN")

# Build the chart's `env` list with json.dumps. A hand-written f-string would
# need every literal JSON brace doubled ({{ }}): unescaped braces are parsed as
# replacement fields and raise "Invalid format specifier". json.dumps also
# escapes any quotes or backslashes in the values.
hf_env = json.dumps([
    {"name": "HF_MODEL_ID", "value": hf_model_id},
    {"name": "HF_TOKEN", "value": HF_TOKEN},
])

cmd = [
    'helm', 'install', '--debug', release_name,
    'charts/machine-learning/model-prep/hf-snapshot',
    '--set-json', f'env={hf_env}',
    '-n', namespace
]

result = subprocess.run(cmd, capture_output=True, text=True)
print(result.stdout)
if result.stderr:
    # helm --debug may write diagnostics to stderr even on success, so the
    # exit code (checked below) is the failure signal, not stderr content.
    print("STDERR:", result.stderr)
# Stop on a failed install instead of proceeding to the wait cell.
result.check_returncode()

In [None]:
# Wait for model download to complete
# Waits on the pods of the `release_name` Helm release in `namespace` using the
# k8s.utils helper's default timeout (the conversion and engine-build waits
# pass an explicit timeout=7200 instead).
wait_for_helm_release_pods(release_name, namespace)

In [None]:
# Remove the finished download job's Helm release; the next step installs a
# new release under the same `release_name`.
uninstall_cmd = ['helm', 'uninstall', release_name, '-n', namespace]
completed = subprocess.run(uninstall_cmd, text=True, capture_output=True)
print(completed.stdout)
if completed.stderr:
    print("STDERR:", completed.stderr)

## Step 3: Convert HuggingFace Checkpoint to TensorRT-LLM Checkpoint

In [None]:
# Install the data-process chart with hf_to_trtllm.yaml, which converts the
# downloaded Hugging Face checkpoint into a TensorRT-LLM checkpoint.
cmd = [
    'helm', 'install', '--debug', release_name,
    'charts/machine-learning/data-prep/data-process',
    '-f', f'{notebook_dir}/hf_to_trtllm.yaml',
    '-n', namespace
]

result = subprocess.run(cmd, capture_output=True, text=True)
print(result.stdout)
if result.stderr:
    # helm --debug may write diagnostics to stderr even on success, so the
    # exit code (checked below) is the failure signal, not stderr content.
    print("STDERR:", result.stderr)
# Stop on a failed install instead of proceeding to the wait cell.
result.check_returncode()

In [None]:
# Wait for checkpoint conversion to complete
# Passes an explicit timeout=7200, unlike the download and Triton-model waits
# that use the helper's default (units presumably seconds -- confirm in k8s.utils).
wait_for_helm_release_pods(release_name, namespace, timeout=7200)

In [None]:
# Tear down the conversion job's Helm release so `release_name` is free for
# the engine-build step.
completed = subprocess.run(
    ['helm', 'uninstall', release_name, '-n', namespace],
    capture_output=True,
    text=True,
)
out, err = completed.stdout, completed.stderr
print(out)
if err:
    print("STDERR:", err)

## Step 4: Build TensorRT-LLM Engine

In [None]:
# Install the data-process chart with trtllm_engine.yaml, which builds the
# TensorRT-LLM engine from the converted checkpoint.
cmd = [
    'helm', 'install', '--debug', release_name,
    'charts/machine-learning/data-prep/data-process',
    '-f', f'{notebook_dir}/trtllm_engine.yaml',
    '-n', namespace
]

result = subprocess.run(cmd, capture_output=True, text=True)
print(result.stdout)
if result.stderr:
    # helm --debug may write diagnostics to stderr even on success, so the
    # exit code (checked below) is the failure signal, not stderr content.
    print("STDERR:", result.stderr)
# Stop on a failed install instead of proceeding to the wait cell.
result.check_returncode()

In [None]:
# Wait for engine build to complete
# Uses the same explicit timeout=7200 as the checkpoint-conversion wait
# (units presumably seconds -- confirm in k8s.utils).
wait_for_helm_release_pods(release_name, namespace, timeout=7200)

In [None]:
# Remove the engine-build job's Helm release before the Triton model step
# reuses `release_name`.
helm_uninstall = ['helm', 'uninstall', release_name, '-n', namespace]
proc_result = subprocess.run(helm_uninstall, capture_output=True, text=True)
print(proc_result.stdout)
if proc_result.stderr:
    print("STDERR:", proc_result.stderr)

## Step 5: Build Triton Model

In [None]:
# Install the data-process chart with triton_model.yaml, which builds the
# Triton model for the TensorRT-LLM backend.
cmd = [
    'helm', 'install', '--debug', release_name,
    'charts/machine-learning/data-prep/data-process',
    '-f', f'{notebook_dir}/triton_model.yaml',
    '-n', namespace
]

result = subprocess.run(cmd, capture_output=True, text=True)
print(result.stdout)
if result.stderr:
    # helm --debug may write diagnostics to stderr even on success, so the
    # exit code (checked below) is the failure signal, not stderr content.
    print("STDERR:", result.stderr)
# Stop on a failed install instead of proceeding to the wait cell.
result.check_returncode()

In [None]:
# Wait for Triton model build to complete
# Uses the k8s.utils helper's default timeout (no explicit timeout passed).
wait_for_helm_release_pods(release_name, namespace)

In [None]:
# Remove the Triton model build job's Helm release; the server launch in the
# next step installs under the same `release_name`.
finished = subprocess.run(
    ['helm', 'uninstall', release_name, '-n', namespace],
    text=True,
    capture_output=True,
)
print(finished.stdout)
if finished.stderr:
    print("STDERR:", finished.stderr)

## Step 6: Launch Triton Server

In [None]:
# Install the triton-inference-server-lws serving chart with triton_server.yaml
# to launch the Triton server for the built model.
cmd = [
    'helm', 'install', '--debug', release_name,
    'charts/machine-learning/serving/triton-inference-server-lws',
    '-f', f'{notebook_dir}/triton_server.yaml',
    '-n', namespace
]

result = subprocess.run(cmd, capture_output=True, text=True)
print(result.stdout)
if result.stderr:
    # helm --debug may write diagnostics to stderr even on success, so the
    # exit code (checked below) is the failure signal, not stderr content.
    print("STDERR:", result.stderr)
# Stop on a failed install instead of proceeding to the readiness wait.
result.check_returncode()

In [None]:
# Wait for Triton server to be ready
# Readiness criteria and timeout are defined by k8s.utils.wait_for_triton_server.
wait_for_triton_server(release_name, namespace)

## Step 7: Check Service Status

In [None]:
# Check service status: list the Kubernetes Services that k8s.utils matches to
# this Helm release, with their type and ports.
# Materialize the result so the emptiness check works for any iterable.
services = list(find_matching_helm_services(release_name, namespace))
if not services:
    # Report explicitly rather than printing nothing when no service matches.
    print(f"No services found for release {release_name} in namespace {namespace}.")
for service in services:
    print(f"Service {service.metadata.name} is available.")
    print(f"Service type: {service.spec.type}")
    print(f"Service ports: {service.spec.ports} ")

## Step 8: Stop Service

When you're done with the service, run the cell below to uninstall the Triton server Helm release and clean up its resources.

In [None]:
# Tear down the Triton server Helm release installed in Step 6.
teardown = subprocess.run(
    ['helm', 'uninstall', release_name, '-n', namespace],
    capture_output=True,
    text=True,
)
print(teardown.stdout)
if teardown.stderr:
    print("STDERR:", teardown.stderr)