# Serve multiple LoRA adapters efficiently on SageMaker w/ LoRAX

In this tutorial, we will learn how to serve many Low-Rank Adapters (LoRA) on top of the same base model efficiently on the same GPU. In order to do this, we'll deploy the LoRA Exchange ([LoRAX](https://github.com/predibase/lorax/tree/main)) inference server to SageMaker Hosting. 

These are the steps we will take:

1. [Setup our environment](#setup)
2. [Build a new LoRAX container image compatible with SageMaker, push it to Amazon ECR](#container)
3. [Download adapters from the HuggingFace Hub and upload them to S3](#download_adapter)
4. [Deploy the extended LoRAX container to SageMaker](#deploy)
5. [Compare outputs of the base model and the adapter model](#compare)
6. [Benchmark our deployed endpoint under different traffic patterns - same adapter, and random access to many adapters](#benchmark)


## What is LoRAX? 

LoRAX is a production-ready framework specialized in multi-adapter serving that  efficiently share the same GPU resources, which dramatically reduces the cost of serving without compromising on throughput or latency. Some of the features that enable this are: 

* Dynamic Adapter Loading - fine-tuned LoRA weights are loaded from storage (local or remote) just-in-time as requests come in at runtime
* Tiered Weight Caching - fast exchanging of LoRA adapters between requests, and offloading of adapter weights to CPU and disk as they are not needed to avoid out-of-memory errors.
* Continuous Multi-Adapter Batching - a fair scheduling policy that continuously batches requests targeted at different LoRA adapters so they can be processed in paralle, optimizing aggregate throughput.
* Optimized Inference - high throughput and low latency optimizations including tensor parallelism, pre-compiled CUDA kernels ([flash-attention](https://arxiv.org/abs/2307.08691), [paged attention](https://arxiv.org/abs/2309.06180), [SGMV](https://arxiv.org/abs/2310.18547)), quantization, token streaming.

You can read more about LoRAX [here](https://predibase.com/blog/lora-exchange-lorax-serve-100s-of-fine-tuned-llms-for-the-cost-of-one).

<a id="setup"></a>
## Setup our environment 

In [2]:
!pip install -U boto3 sagemaker huggingface_hub --quiet

[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
aiobotocore 2.12.2 requires botocore<1.34.52,>=1.34.41, but you have botocore 1.34.139 which is incompatible.
autogluon-common 0.8.3 requires pandas<1.6,>=1.4.1, but you have pandas 2.1.4 which is incompatible.
autogluon-core 0.8.3 requires pandas<1.6,>=1.4.1, but you have pandas 2.1.4 which is incompatible.
autogluon-core 0.8.3 requires scikit-learn<1.4.1,>=1.1, but you have scikit-learn 1.4.2 which is incompatible.
autogluon-features 0.8.3 requires pandas<1.6,>=1.4.1, but you have pandas 2.1.4 which is incompatible.
autogluon-features 0.8.3 requires scikit-learn<1.4.1,>=1.1, but you have scikit-learn 1.4.2 which is incompatible.
autogluon-multimodal 0.8.3 requires pandas<1.6,>=1.4.1, but you have pandas 2.1.4 which is incompatible.
autogluon-multimodal 0.8.3 requires pytorch-lightning<1.10.0,>=1.9.0, but yo

In [3]:
import sagemaker
import boto3
sess = sagemaker.Session()

# sagemaker session bucket -> used for uploading data, models and logs
sagemaker_session_bucket=None
if sagemaker_session_bucket is None and sess is not None:
    # set to default bucket if a bucket name is not given
    sagemaker_session_bucket = sess.default_bucket()

try:
    role = sagemaker.get_execution_role()
except ValueError:
    iam = boto3.client('iam')
    role = iam.get_role(RoleName='sagemaker_execution_role')['Role']['Arn']

sess = sagemaker.Session(default_bucket=sagemaker_session_bucket)
region = sess.boto_region_name

print(f"sagemaker default S3 bucket: {sagemaker_session_bucket}")
print(f"sagemaker role arn: {role}")
print(f"sagemaker session region: {region}")

sagemaker default S3 bucket: sagemaker-us-east-1-626723862963
sagemaker role arn: arn:aws:iam::626723862963:role/service-role/AmazonSageMaker-ExecutionRole-20231214T145077
sagemaker session region: us-east-1


## Activating Docker in Jupyterlab

Make sure to install docker in Sagemaker Studio Jupyterlab. This can also be run in terminal (File - New - Terminal)

In [14]:
%%bash
bash setup_docker.sh

Reading package lists...
Building dependency tree...
Reading state information...
ca-certificates is already the newest version (20230311ubuntu0.22.04.1).
curl is already the newest version (7.81.0-1ubuntu1.16).
gnupg is already the newest version (2.2.27-3ubuntu2.1).
0 upgraded, 0 newly installed, 0 to remove and 12 not upgraded.


gpg: cannot open '/dev/tty': No such device or address
curl: (23) Failure writing output to destination


Hit:1 https://download.docker.com/linux/ubuntu jammy InRelease
Get:2 http://security.ubuntu.com/ubuntu jammy-security InRelease [129 kB]
Hit:3 http://archive.ubuntu.com/ubuntu jammy InRelease
Get:4 http://archive.ubuntu.com/ubuntu jammy-updates InRelease [128 kB]
Hit:5 http://archive.ubuntu.com/ubuntu jammy-backports InRelease
Fetched 257 kB in 1s (414 kB/s)
Reading package lists...
Reading package lists...
Building dependency tree...
Reading state information...
docker-ce-cli is already the newest version (5:20.10.24~3-0~ubuntu-jammy).
docker-compose-plugin is already the newest version (2.28.1-1~ubuntu.22.04~jammy).
0 upgraded, 0 newly installed, 0 to remove and 11 not upgraded.
Client: Docker Engine - Community
 Version:           20.10.24
 API version:       1.41
 Go version:        go1.19.7
 Git commit:        297e128
 Built:             Tue Apr  4 18:21:03 2023
 OS/Arch:           linux/amd64
 Context:           default
 Experimental:      true

Server:
 Engine:
  Version:       

<a id="container"></a>
## Build a new LoRAX container image compatible with SageMaker, push it to Amazon ECR

This example includes a `Dockerfile` and `sagemaker_entrypoint.sh` in the `sagemaker_lorax` directory. Building this new container image makes LoRAX compatible with SageMaker Hosting, namely launching the server on port 8080 via the container's `ENTRYPOINT` instruction. [Here](https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-inference-code.html#your-algorithms-inference-code-run-image) you can find the basic interfaces required to adapt any container for deployment on Sagemaker Hosting.

In [15]:
!cat sagemaker_lorax/sagemaker_entrypoint.sh

#!/bin/bash

if [[ -z "${HF_MODEL_ID}" ]]; then
  echo "HF_MODEL_ID must be set"
  exit 1
fi
export MODEL_ID="${HF_MODEL_ID}"

if [[ -n "${HF_MODEL_REVISION}" ]]; then
  export REVISION="${HF_MODEL_REVISION}"
fi

if [[ -n "${SM_NUM_GPUS}" ]]; then
  export NUM_SHARD="${SM_NUM_GPUS}"
fi

if [[ -n "${HF_MODEL_QUANTIZE}" ]]; then
  export QUANTIZE="${HF_MODEL_QUANTIZE}"
fi

if [[ -n "${HF_MODEL_TRUST_REMOTE_CODE}" ]]; then
  export TRUST_REMOTE_CODE="${HF_MODEL_TRUST_REMOTE_CODE}"
fi

if [[ -z "${ADAPTER_BUCKET}" ]]; then
else
  export PREDIBASE_MODEL_BUCKET="${ADAPTER_BUCKET}"
fi

lorax-launcher --port 8080


In [16]:
!cat sagemaker_lorax/Dockerfile

ARG VERSION
FROM ghcr.io/predibase/lorax:$VERSION

COPY sagemaker_entrypoint.sh entrypoint.sh
RUN chmod +x entrypoint.sh

ENTRYPOINT ["./entrypoint.sh"]


We build the new container image and push it to a new ECR repository. Note SageMaker [supports private Docker registries](https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-containers-inference-private.html) as well.

In [17]:
%%bash -s {region}
algorithm_name="sagemaker-lorax"  # name of your algorithm
tag="0.8.0"
region=$1

account=$(aws sts get-caller-identity --query Account --output text)

image_uri="${account}.dkr.ecr.${region}.amazonaws.com/${algorithm_name}:${tag}"

# If the repository doesn't exist in ECR, create it.
aws ecr describe-repositories --repository-names "${algorithm_name}" > /dev/null 2>&1

if [ $? -ne 0 ]
then
    aws ecr create-repository --repository-name "${algorithm_name}" --region $region > /dev/null
fi

cd sagemaker_lorax/ && docker build --network=sagemaker --build-arg VERSION=$tag -t ${algorithm_name}:${tag} .

# Authenticate Docker to an Amazon ECR registry
aws ecr get-login-password --region ${region} | docker login --username AWS --password-stdin ${account}.dkr.ecr.${region}.amazonaws.com

# Tag the image
docker tag ${algorithm_name}:${tag} ${image_uri}

# Push the image to the repository
docker push ${image_uri}

# Save image name to tmp file to use when deploying endpoint
echo $image_uri > /tmp/image_uri

Sending build context to Docker daemon  6.656kB
Step 1/6 : ARG VERSION
Step 2/6 : FROM ghcr.io/predibase/lorax:$VERSION
 ---> 12cbb858e917
Step 3/6 : COPY sagemaker_entrypoint.sh entrypoint.sh
 ---> Using cache
 ---> d022857f458a
Step 4/6 : RUN chmod +x entrypoint.sh
 ---> Using cache
 ---> 5eeca0aab381
Step 5/6 : ENTRYPOINT ["./entrypoint.sh"]
 ---> Using cache
 ---> 6b18cfc543a7
Step 6/6 : LABEL com.amazon.studio.user.resources=true
 ---> Using cache
 ---> 22f4f5d50a04
Successfully built 22f4f5d50a04
Successfully tagged sagemaker-lorax:0.8.0


https://docs.docker.com/engine/reference/commandline/login/#credentials-store



Login Succeeded
The push refers to repository [626723862963.dkr.ecr.us-east-1.amazonaws.com/sagemaker-lorax]
03630b8fa156: Preparing
cb150628e6cc: Preparing
1d0b41ccda6f: Preparing
5f70bf18a086: Preparing
6f8902c56d3a: Preparing
8d781d80cb61: Preparing
a789d8b3e5c9: Preparing
4266187e4a78: Preparing
1d8eb7c54d2b: Preparing
f0ae60d5fc16: Preparing
9ea406769079: Preparing
dd129f32cbce: Preparing
5f70bf18a086: Preparing
4bf8ae4ca4b4: Preparing
dd2ba8e514ec: Preparing
56de09d53165: Preparing
0ed8bcd23002: Preparing
3e9d873bbf45: Preparing
4de6133ecfe4: Preparing
c6ea34fb28c0: Preparing
ea2f174ab3cb: Preparing
0c395af0c0dc: Preparing
62ef9850694f: Preparing
3eab3f5945eb: Preparing
a80e63203a09: Preparing
e629e1f35af9: Preparing
75b355150569: Preparing
06a3bfc021d0: Preparing
4194045cdf23: Preparing
2952ab0c67d1: Preparing
b68332006944: Preparing
5f70bf18a086: Preparing
498bbcc60d01: Preparing
c0e21dcee623: Preparing
d6b19a46b795: Preparing
e6c05e83c163: Preparing
256d88da4185: Preparing
426

<a id="download_adapter"></a>
## Download adapter from HuggingFace Hub and push it to S3

We are going to simulate storing our adapter weights on S3, and having LoRAX load them dynamically as we invoke them. This enables most scenarios, including deployment after you’ve finetuned your own adapter and pushed it to S3, as well as securing deployments with no internet access inside your VPC, as detailed in this [blog post](https://www.philschmid.de/sagemaker-llm-vpc#2-upload-the-model-to-amazon-s3).

We first download an adapter trained with Mistral Instruct v0.1 as the base model to a local directory. This particular adapter was trained on GSM8K, a grade school math dataset.

In [18]:
from pathlib import Path
from huggingface_hub import snapshot_download

HF_MODEL_ID = "vineetsharma/qlora-adapter-Mistral-7B-Instruct-v0.1-gsm8k"
# create model dir
model_dir = Path('mistral-adapter')
model_dir.mkdir(exist_ok=True)

# Download model from Hugging Face into model_dir
snapshot_download(
    HF_MODEL_ID,
    local_dir=str(model_dir), # download to model dir
    local_dir_use_symlinks=False, # use no symlinks to save disk space
    revision="main", # use a specific revision, e.g. refs/pr/21
)

Fetching 5 files:   0%|          | 0/5 [00:00<?, ?it/s]

'/home/sagemaker-user/notebooks/sagemaker/01_multi_adapter_hosting_sagemaker_lorax/mistral-adapter'

We copy this same adapter `n_adapters` times to different S3 prefixes in our SageMaker session bucket, simulating a large number of adapters we want to serve on the same endpoint and underlying GPU.

In [19]:
import os

s3 = boto3.client('s3')

def upload_folder_to_s3(local_path, s3_bucket, s3_prefix):
    for root, dirs, files in os.walk(local_path):
        for file in files:
            local_file_path = os.path.join(root, file)
            s3_object_key = os.path.join(s3_prefix, os.path.relpath(local_file_path, local_path))
            s3.upload_file(local_file_path, s3_bucket, s3_object_key)

# Upload the folder n_adapters times under different prefixes
n_adapters=50
base_prefix = 'lorax/mistral-adapters'
for i in range(1, n_adapters+1):
    prefix = f'{base_prefix}/0{i}' if i < 10 else f'{base_prefix}/{i}'
    upload_folder_to_s3(model_dir, sagemaker_session_bucket, prefix)
    print(f'Uploaded folder to S3 with prefix: {prefix}, bucket: {sagemaker_session_bucket}')

Uploaded folder to S3 with prefix: lorax/mistral-adapters/01, bucket: sagemaker-us-east-1-626723862963
Uploaded folder to S3 with prefix: lorax/mistral-adapters/02, bucket: sagemaker-us-east-1-626723862963
Uploaded folder to S3 with prefix: lorax/mistral-adapters/03, bucket: sagemaker-us-east-1-626723862963
Uploaded folder to S3 with prefix: lorax/mistral-adapters/04, bucket: sagemaker-us-east-1-626723862963
Uploaded folder to S3 with prefix: lorax/mistral-adapters/05, bucket: sagemaker-us-east-1-626723862963
Uploaded folder to S3 with prefix: lorax/mistral-adapters/06, bucket: sagemaker-us-east-1-626723862963
Uploaded folder to S3 with prefix: lorax/mistral-adapters/07, bucket: sagemaker-us-east-1-626723862963
Uploaded folder to S3 with prefix: lorax/mistral-adapters/08, bucket: sagemaker-us-east-1-626723862963
Uploaded folder to S3 with prefix: lorax/mistral-adapters/09, bucket: sagemaker-us-east-1-626723862963
Uploaded folder to S3 with prefix: lorax/mistral-adapters/10, bucket: sag

<a id="deploy"></a>
## Deploy SageMaker endpoint


Now we deploy a SageMaker endpoint, pointing to our SageMaker session bucket as the ADAPTER_BUCKET env variable, which enables downloading adapters from S3.

In [20]:
pip install --upgrade huggingface_hub

Note: you may need to restart the kernel to use updated packages.


In [21]:
from huggingface_hub import login
login()
from pathlib import Path
hf_token = Path("/home/sagemaker-user/.cache/huggingface/token").read_text()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [22]:
import json
import datetime

from sagemaker import Model
from sagemaker.serializers import JSONSerializer
from sagemaker.deserializers import JSONDeserializer

# Retrieve image_uri from tmp file
image_uri = !cat /tmp/image_uri
# Increased health check timeout to give time for model download
health_check_timeout = 800
number_of_gpu = 1
instance_type = "ml.g5.xlarge"
endpoint_name = sagemaker.utils.name_from_base("sm-lorax")

# Model and Endpoint configuration parameters
config = {
  'HF_MODEL_ID': "mistralai/Mistral-7B-Instruct-v0.1", # model_id from hf.co/models
  'HUGGING_FACE_HUB_TOKEN': hf_token,
  'SM_NUM_GPUS': json.dumps(number_of_gpu), # Number of GPU used per replica
  'MAX_INPUT_LENGTH': json.dumps(1024),  # Max length of input text
  'MAX_TOTAL_TOKENS': json.dumps(4096),  # Max length of the generation (including input text)
  'ADAPTER_BUCKET': sagemaker_session_bucket,
}

lorax_model = Model(
    image_uri=image_uri[0],
    role=role,
    env=config
)


-------------!

In [None]:

lorax_predictor = lorax_model.deploy(
    endpoint_name=endpoint_name,
    initial_instance_count=1,
    instance_type=instance_type,
    container_startup_health_check_timeout=health_check_timeout, 
    serializer=JSONSerializer(),
    deserializer=JSONDeserializer()
)

## Start Inference

In [12]:
# You can reinstantiate the Predictor object if you restart the notebook or Predictor is None
from sagemaker.predictor import Predictor
from sagemaker.serializers import JSONSerializer
from sagemaker.deserializers import JSONDeserializer
n_adapters=50
base_prefix = 'lorax/mistral-adapters'
endpoint_name = "sm-lorax-2024-06-30-14-57-07-207"


lorax_predictor = Predictor(
    endpoint_name,
    serializer=JSONSerializer(),
    deserializer=JSONDeserializer(),
)

<a id="compare"></a>
## Invoke base model and adapter, compare outputs

We can invoke the base Mistral model, as well as any of the adapters in our bucket! LoRAX will take care of downloading them, continuously batch requests for different adapters, and manage DRAM and RAM by loading/offloading adapters.

Let’s inspect the difference between the base model’s response and the adapter’s response:

<div class="alert alert-block alert-info">
⚠️ I observed a weird error that I haven't debugged yet, where S3 download failed for adapters ID 1 through 5, but worked as expected for all other adapters. Something with the S3 prefix. Added 0 before adapter id if id < 10 as a workaround.
</div>

In [13]:
prompt = '[INST] Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May? [/INST]'

payload_base = {
    "inputs": prompt,
    "parameters": {
        "max_new_tokens": 64,
    }
}

payload_adapter = {
"inputs": prompt,
    "parameters": {
        "max_new_tokens": 64,
        "adapter_id": f'{base_prefix}/01',
        "adapter_source": "s3"
    }
}

response_base = lorax_predictor.predict(payload_base)
response_adapter = lorax_predictor.predict(payload_adapter)

print(f'Base model output:\n-------------\n {response_base[0]["generated_text"]}')
print(f'Adapter output:\n-------------\n {response_adapter[0]["generated_text"]}')

Base model output:
-------------
 Let's break down the problem:

1. In April, Natalia sold clips to 48 of her friends.
2. In May, she sold half as many clips as in April, which means she sold 48/2 = 24 clips in May.

Adapter output:
-------------
 Natalia sold 48/2 = <<48/2=24>>24 clips in May.
In total, Natalia sold 48 + 24 = <<48+24=72>>72 clips in April and May.
#### 72


<a id="benchmark"></a>
## Benchmark single adapter vs. random access to adapters



First, we individually call each of the adapters in sequence, to make sure they are previously downloaded to the endpoint instance’s disk. We want to exclude S3 download latency from the benchmark metrics.

In [26]:
from tqdm import tqdm

for i in tqdm(range(1,n_adapters+1)):
    adapter_id = f'{base_prefix}/0{i}' if i < 10 else f'{base_prefix}/{i}'
    payload_adapter = {
    "inputs": prompt,
    "parameters": {
        "max_new_tokens": 64,
        "adapter_id": adapter_id,
        "adapter_source": "s3"
        }
    }
    lorax_predictor.predict(payload_adapter)

100%|██████████| 50/50 [01:57<00:00,  2.35s/it]


Now we are ready to benchmark. For the single adapter case, we invoke the adapter `total_requests` times from `num_threads` concurrent clients.

For the multi-adapter case, we invoke a random adapter from any of the clients, until all adapters have been invoked `total_requests//num_adapters` times.

In [27]:
# Adjust if you run into connection pool errors
# import botocore

# Configure botocore to use a larger connection pool
# config = botocore.config.Config(max_pool_connections=100)

In [28]:
import threading
import time
import random


# Configuration
total_requests = 500
num_adapters = 50
num_threads = 20  # Adjust based on your system capabilities


# Shared lock and counters for # invocations of each adapter 
adapter_counters = [total_requests // num_adapters] * num_adapters
counters_lock = threading.Lock()

def invoke_adapter(aggregate_latency, single_adapter=False):
    global total_requests
    latencies = []
    while True:
        with counters_lock:
            if single_adapter:
                adapter_id = 1
                if total_requests > 0:
                    total_requests -= 1
                else:
                    break
            else:
                # Find an adapter that still needs to be called
                remaining_adapters = [i for i, count in enumerate(adapter_counters) if count > 0]
                if not remaining_adapters:
                    break
                adapter_id = random.choice(remaining_adapters) + 1
                adapter_counters[adapter_id - 1] -= 1

        prompt = '[INST] Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May? [/INST]'
        invoke_adapter_id = f'{base_prefix}/0{i}' if i < 10 else f'{base_prefix}/{i}'
        payload_adapter = {
            "inputs": prompt,
            "parameters": {
                "max_new_tokens": 64,
                "adapter_id": invoke_adapter_id,
                "adapter_source": "s3"
            }
        }
        start_time = time.time()
        response_adapter = lorax_predictor.predict(payload_adapter)
        latency = time.time() - start_time
        latencies.append(latency)

    aggregate_latency.extend(latencies)

def benchmark_scenario(single_adapter=False):
    threads = []
    all_latencies = []
    start_time = time.time()

    for _ in range(num_threads):
        thread_latencies = []
        all_latencies.append(thread_latencies)
        thread = threading.Thread(target=invoke_adapter, args=(thread_latencies, single_adapter))
        threads.append(thread)
        thread.start()

    for thread in threads:
        thread.join()

    total_latency = sum([sum(latencies) for latencies in all_latencies])
    total_requests_made = sum([len(latencies) for latencies in all_latencies])
    average_latency = total_latency / total_requests_made
    throughput = total_requests_made / (time.time() - start_time)

    print(f"Total Time: {time.time() - start_time}s")
    print(f"Average Latency: {average_latency}s")
    print(f"Throughput: {throughput} requests/s")

# Run benchmarks
print("Benchmarking: Single Adapter Multiple Times")
benchmark_scenario(single_adapter=True)

print("\nBenchmarking: Multiple Adapters with Random Access")
benchmark_scenario()


Benchmarking: Single Adapter Multiple Times
Total Time: 281.5893428325653s
Average Latency: 2.8096318435668945s
Throughput: 7.102541547400547 requests/s

Benchmarking: Multiple Adapters with Random Access
Total Time: 279.85398745536804s
Average Latency: 2.7977559057474135s
Throughput: 7.146583920425574 requests/s


<a id="cleanup"></a>
## Cleanup endpoint resources

In [None]:
lorax_predictor.delete_endpoint()