# Triton Ensembles
In this example we stitch together multiple models using Triton Ensembles. An ensemble lets you orchestrate a workflow that chains models from the same or different backends. Here, that means a Python backend preprocessing step, a PyTorch classifier, and a Python backend postprocessing step. We recommend testing your logic locally before building out the pipeline. The diagram below shows the architecture:

![Arch diagram](triton-ensemble.jpg)
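One way to test the wiring locally is to chain plain Python stage functions the same way an ensemble routes named tensors. The sketch below is a hypothetical harness (`run_steps` is our name, not a Triton API). Each step's `input_map` maps a model input name to an ensemble tensor name, and its `output_map` maps a model output name to the ensemble tensor it produces. The toy stages are placeholders for your real preprocess/postprocess logic.

```python
from typing import Callable, Dict, List

Tensors = Dict[str, object]
Model = Callable[[Tensors], Tensors]


def run_steps(steps: List[dict], models: Dict[str, Model], inputs: Tensors) -> Tensors:
    """Chain local stage functions, routing tensors by name like an ensemble.

    input_map:  {model input name: ensemble tensor name}
    output_map: {model output name: ensemble tensor name}
    Steps run in list order here; Triton instead starts each step once its
    inputs are available.
    """
    pool: Tensors = dict(inputs)  # ensemble-level tensors produced so far
    for step in steps:
        model_inputs = {m: pool[e] for m, e in step["input_map"].items()}
        model_outputs = models[step["model_name"]](model_inputs)
        for m, e in step["output_map"].items():
            pool[e] = model_outputs[m]
    return pool


# Toy stand-ins for the real stages (hypothetical tensor names)
steps = [
    {"model_name": "preprocess", "input_map": {"TEXT": "TEXT"},
     "output_map": {"LENGTHS": "lengths"}},
    {"model_name": "postprocess", "input_map": {"LENGTHS": "lengths"},
     "output_map": {"LABEL": "LABEL"}},
]
models = {
    "preprocess": lambda t: {"LENGTHS": [len(s) for s in t["TEXT"]]},
    "postprocess": lambda t: {"LABEL": ["long" if n > 10 else "short" for n in t["LENGTHS"]]},
}
result = run_steps(steps, models, {"TEXT": ["hi", "a much longer text"]})
print(result["LABEL"])  # ['short', 'long']
```

Swapping the lambdas for your tokenizer, a locally loaded model, and your label mapping lets you check the tensor names and shapes before any `config.pbtxt` is written.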

## Directory Setup
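Triton expects a specific folder structure (the model repository) to pick up each model's artifacts: every model gets its own directory containing a `config.pbtxt` and numbered version subdirectories. Here is a high-level view of the layout for this ensemble: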

```
- hf_pipeline
  - preprocess
      - 1
          - model.py
      - config.pbtxt
  - torch_classifier
      - 1
          - model.pt
      - config.pbtxt
  - postprocess
      - 1
          - model.py
      - config.pbtxt
  - text_ensemble
      - 1
          - placeholder file (the version directory is needed for Triton to pick up the ensemble)
      - config.pbtxt
```
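If you are starting from scratch, a short script can create this skeleton. The sketch below is hypothetical (`scaffold_repository` and `MODELS` are our names, and the `.gitkeep` placeholder name is an assumption). It creates the directories and the placeholder, writes an example ensemble `config.pbtxt`, and leaves the other `config.pbtxt` files empty for you to fill in. The intermediate tensor names (`INPUT_IDS`, `ATTENTION_MASK`, `LOGITS`) are assumptions; `INPUT__0`/`INPUT__1`/`OUTPUT__0` follow the PyTorch backend's naming convention, and `TEXT`/`LABEL` match the client code used later.

```python
from pathlib import Path

MODELS = ["preprocess", "torch_classifier", "postprocess", "text_ensemble"]

# input_map: {step model's input: ensemble tensor}; output_map: {step model's output: ensemble tensor}
ENSEMBLE_CONFIG = """\
name: "text_ensemble"
platform: "ensemble"
max_batch_size: 0
input [
  { name: "TEXT" data_type: TYPE_STRING dims: [ -1 ] }
]
output [
  { name: "LABEL" data_type: TYPE_STRING dims: [ -1 ] }
]
ensemble_scheduling {
  step [
    {
      model_name: "preprocess"
      model_version: -1
      input_map { key: "TEXT" value: "TEXT" }
      output_map { key: "INPUT_IDS" value: "input_ids" }
      output_map { key: "ATTENTION_MASK" value: "attention_mask" }
    },
    {
      model_name: "torch_classifier"
      model_version: -1
      input_map { key: "INPUT__0" value: "input_ids" }
      input_map { key: "INPUT__1" value: "attention_mask" }
      output_map { key: "OUTPUT__0" value: "logits" }
    },
    {
      model_name: "postprocess"
      model_version: -1
      input_map { key: "LOGITS" value: "logits" }
      output_map { key: "LABEL" value: "LABEL" }
    }
  ]
}
"""


def scaffold_repository(root):
    """Create the model-repository skeleton without overwriting existing configs."""
    root = Path(root)
    for name in MODELS:
        (root / name / "1").mkdir(parents=True, exist_ok=True)
        config = root / name / "config.pbtxt"
        if not config.exists():
            config.write_text(ENSEMBLE_CONFIG if name == "text_ensemble" else "")
    # Placeholder so the ensemble's version directory survives in Git
    (root / "text_ensemble" / "1" / ".gitkeep").touch()
    return root
```

Running `scaffold_repository("hf_pipeline")` then leaves you to drop `model.py` into `preprocess/1` and `postprocess/1` and `model.pt` into `torch_classifier/1`.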

## Notebook Setup
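We used a SageMaker g5.4xlarge Classic Notebook Instance, but any environment works as long as Docker is installed and set up. Because the container is started with `--gpus=all`, the host also needs an NVIDIA GPU and the NVIDIA Container Toolkit.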
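A quick way to confirm the prerequisites from inside the notebook is sketched below (`check_environment` is a hypothetical helper). It checks that the `docker` and `nvidia-smi` binaries are on `PATH` and that `docker info` succeeds, which assumes the current user is allowed to talk to the Docker daemon. It does not verify the NVIDIA Container Toolkit; a failing `docker run --gpus=all` will surface that.

```python
import shutil
import subprocess


def check_environment():
    """Report whether docker/nvidia-smi are on PATH and the Docker daemon answers."""
    status = {
        "docker_cli": shutil.which("docker") is not None,
        "nvidia_smi": shutil.which("nvidia-smi") is not None,
        "docker_daemon": False,
    }
    if status["docker_cli"]:
        try:
            # `docker info` exits non-zero when the daemon is unreachable
            result = subprocess.run(["docker", "info"], capture_output=True, timeout=15)
            status["docker_daemon"] = result.returncode == 0
        except (subprocess.TimeoutExpired, OSError):
            pass
    return status


print(check_environment())
```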

## Generate Torch Model Artifact
This is how we generated the `model.pt` in the `torch_classifier/1/` directory. We do not recommend committing the artifact to Git as we did here: generate it from your own model instead, and keep it out of version control, especially if it is a custom model you are keeping private.

In [None]:
import torch
import torch.nn as nn
from transformers import AutoModelForSequenceClassification

MODEL_NAME = "distilbert-base-uncased-finetuned-sst-2-english"

class HFWrapper(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, input_ids, attention_mask):
        out = self.model(input_ids=input_ids, attention_mask=attention_mask)
        return out.logits  # <-- return a Tensor, not a dict


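def main():
    # Load the pretrained SST-2 classifier and switch to inference mode
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
    model.eval()

    wrapped = HFWrapper(model)

    # Dummy inputs used only to record the graph: (batch=1, seq_len=128), int64
    example_input_ids = torch.randint(0, 100, (1, 128))
    example_attention_mask = torch.ones((1, 128), dtype=torch.long)

    # Trace without autograd bookkeeping, then serialize to TorchScript
    with torch.no_grad():
        traced = torch.jit.trace(
            wrapped,
            (example_input_ids, example_attention_mask)
        )
    traced.save("model.pt")
    print("saved model.pt")


main()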
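
The traced module takes two int64 tensors shaped like the example inputs (batch, 128) and returns float32 logits of shape (batch, 2) for the two SST-2 classes. A matching `config.pbtxt` for `torch_classifier` could be rendered as below. This is a sketch under assumptions: it uses the PyTorch backend's `INPUT__<index>`/`OUTPUT__<index>` naming convention (index 0 is `input_ids`, index 1 is `attention_mask`, following the `forward` argument order), keeps the sequence length fixed at 128 to match the trace, and sets `max_batch_size: 0` so the batch dimension is written explicitly as `-1`. `torch_classifier_config` is a hypothetical helper.

```python
def torch_classifier_config(seq_len=128, num_labels=2, kind="KIND_GPU"):
    """Render a config.pbtxt for the traced classifier (illustrative sketch)."""
    return f"""\
name: "torch_classifier"
platform: "pytorch_libtorch"
max_batch_size: 0
input [
  {{
    name: "INPUT__0"  # input_ids
    data_type: TYPE_INT64
    dims: [ -1, {seq_len} ]
  }},
  {{
    name: "INPUT__1"  # attention_mask
    data_type: TYPE_INT64
    dims: [ -1, {seq_len} ]
  }}
]
output [
  {{
    name: "OUTPUT__0"  # logits
    data_type: TYPE_FP32
    dims: [ -1, {num_labels} ]
  }}
]
instance_group [
  {{
    count: 1
    kind: {kind}
  }}
]
"""


config = torch_classifier_config()
print(config)
# To use it: save as torch_classifier/config.pbtxt next to the 1/ directory
```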

## Build Custom Docker Image
You can often run the stock Triton Docker container as-is, but our Python backend models need extra dependencies, so we extend the base Triton image and install `transformers` and `torch`:

```
FROM nvcr.io/nvidia/tritonserver:25.10-py3

# Install the packages the Python backend models need
RUN pip install --no-cache-dir transformers torch
```

Build the Docker image (run this from the directory containing the Dockerfile):
```
docker build -t custom-triton .
```
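The `transformers` install is for the Python backend models, whose code is not shown in this example. As an illustration, `preprocess/1/model.py` might tokenize the incoming `TEXT` strings into the two int64 tensors the classifier expects. This is a hypothetical sketch: the tensor names `TEXT`, `INPUT_IDS`, and `ATTENTION_MASK` are assumptions, padding every request to 128 tokens mirrors the shape used when tracing, and `initialize` assumes the container can download the tokenizer from the Hugging Face Hub. The tokenization lives in plain functions so it can be tested outside Triton, where `triton_python_backend_utils` is not importable.

```python
import numpy as np

try:
    # Only importable inside Triton's Python backend
    import triton_python_backend_utils as pb_utils
except ImportError:
    pb_utils = None

MODEL_NAME = "distilbert-base-uncased-finetuned-sst-2-english"
MAX_LEN = 128  # must match the sequence length used for torch.jit.trace


def decode_text(array):
    """Turn a BYTES tensor (numpy object array) into a list of str."""
    return [t.decode("utf-8") if isinstance(t, bytes) else str(t)
            for t in array.reshape(-1)]


def tokenize(tokenizer, texts, max_len=MAX_LEN):
    """Pad/truncate to a fixed length and return int64 arrays."""
    enc = tokenizer(texts, padding="max_length", truncation=True,
                    max_length=max_len, return_tensors="np")
    return (enc["input_ids"].astype(np.int64),
            enc["attention_mask"].astype(np.int64))


class TritonPythonModel:
    def initialize(self, args):
        from transformers import AutoTokenizer  # installed via the Dockerfile
        self.tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

    def execute(self, requests):
        responses = []
        for request in requests:
            text = pb_utils.get_input_tensor_by_name(request, "TEXT").as_numpy()
            ids, mask = tokenize(self.tokenizer, decode_text(text))
            responses.append(pb_utils.InferenceResponse(output_tensors=[
                pb_utils.Tensor("INPUT_IDS", ids),
                pb_utils.Tensor("ATTENTION_MASK", mask),
            ]))
        return responses
```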

## Start Docker Container
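Start Triton from the image you just built, mounting your local model repository into the container (adjust the host path to wherever your `hf_pipeline` directory lives). Ports 8000, 8001, and 8002 expose Triton's HTTP, gRPC, and metrics endpoints: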

```
docker run --gpus=all --shm-size=4G --rm \
  -p8000:8000 -p8001:8001 -p8002:8002 \
  -v/home/ec2-user/SageMaker/hf_pipeline:/model_repository \
  custom-triton:latest \
  tritonserver --model-repository=/model_repository \
    --exit-on-error=false --log-verbose=1
```

## Sample Model Inference
You can query the ensemble with the Triton Python client library (installable with `pip install tritonclient[http]`): https://github.com/triton-inference-server/client/tree/main

In [None]:
import numpy as np
import tritonclient.http as httpclient

# Connect to Triton (adjust host:port if running remotely)
client = httpclient.InferenceServerClient("localhost:8000")

# Prepare your text inputs
texts = [
    "I am SO UPSET!",
    "I am super happy!!",
    "Life is great!"
]

# Triton expects BYTES tensors for string inputs
input_data = np.array([t.encode("utf-8") for t in texts], dtype=object)

# Create an InferInput matching the ensemble's config.pbtxt
infer_input = httpclient.InferInput("TEXT", [len(texts)], "BYTES")
infer_input.set_data_from_numpy(input_data)

# Perform inference on your ensemble
response = client.infer(
    model_name="text_ensemble",
    inputs=[infer_input]
)

# Retrieve outputs:
#   - LABEL if the ensemble ends with the postprocess stage
#   - LOGITS if you skipped postprocess
# as_numpy() returns None (it does not raise) when an output is absent
labels = response.as_numpy("LABEL")
if labels is not None:
    print([label.decode("utf-8") for label in labels])
else:
    logits = response.as_numpy("LOGITS")
    print("Raw logits:\n", logits)
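
On the server side, the `LABEL` output comes from the postprocess step. Below is a hypothetical sketch of `postprocess/1/model.py`. It assumes the classifier's logits arrive as an input named `LOGITS`, picks the highest-scoring class, and maps it through this checkpoint's label mapping (`0 → NEGATIVE`, `1 → POSITIVE`). Labels are returned as UTF-8 bytes, which is what the client code above decodes. The conversion lives in a plain function so it can be unit-tested without Triton.

```python
import numpy as np

try:
    import triton_python_backend_utils as pb_utils  # only available inside Triton
except ImportError:
    pb_utils = None

# id2label of distilbert-base-uncased-finetuned-sst-2-english
ID2LABEL = {0: "NEGATIVE", 1: "POSITIVE"}


def logits_to_labels(logits):
    """Argmax over the class axis, returned as a BYTES-compatible object array."""
    class_ids = np.argmax(np.asarray(logits), axis=-1).reshape(-1)
    return np.array([ID2LABEL[int(i)].encode("utf-8") for i in class_ids],
                    dtype=object)


class TritonPythonModel:
    def execute(self, requests):
        responses = []
        for request in requests:
            logits = pb_utils.get_input_tensor_by_name(request, "LOGITS").as_numpy()
            label = pb_utils.Tensor("LABEL", logits_to_labels(logits))
            responses.append(pb_utils.InferenceResponse(output_tensors=[label]))
        return responses


print(logits_to_labels(np.array([[2.0, -1.0], [-0.5, 1.5]])))
```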