# Serve multiple LoRA adapters efficiently on SageMaker

In this tutorial, we will learn how to serve many Low-Rank Adapters (LoRA) on top of the same base model efficiently on the same GPU. In order to do this, we'll deploy the LoRA Exchange ([LoRAX](https://github.com/predibase/lorax/tree/main)) inference server to SageMaker Hosting. 

These are the steps we will take:

1. [Setup our environment](#setup)
2. [Build a new LoRAX container image compatible with SageMaker, push it to Amazon ECR](#container)
3. [Download adapters from the HuggingFace Hub and upload them to S3](#download_adapter)
4. [Deploy the extended LoRAX container to SageMaker](#deploy)
5. [Compare outputs of the base model and the adapter model](#compare)
6. [Benchmark our deployed endpoint under different traffic patterns - same adapter, and random access to many adapters](#benchmark)


## What is LoRAX? 

LoRAX is a production-ready framework specialized in multi-adapter serving that  efficiently share the same GPU resources, which dramatically reduces the cost of serving without compromising on throughput or latency. Some of the features that enable this are: 

* Dynamic Adapter Loading - fine-tuned LoRA weights are loaded from storage (local or remote) just-in-time as requests come in at runtime
* Tiered Weight Caching - fast exchanging of LoRA adapters between requests, and offloading of adapter weights to CPU and disk as they are not needed to avoid out-of-memory errors.
* Continuous Multi-Adapter Batching - a fair scheduling policy that continuously batches requests targeted at different LoRA adapters so they can be processed in paralle, optimizing aggregate throughput.
* Optimized Inference - high throughput and low latency optimizations including tensor parallelism, pre-compiled CUDA kernels ([flash-attention](https://arxiv.org/abs/2307.08691), [paged attention](https://arxiv.org/abs/2309.06180), [SGMV](https://arxiv.org/abs/2310.18547)), quantization, token streaming.

You can read more about LoRAX [here](https://predibase.com/blog/lora-exchange-lorax-serve-100s-of-fine-tuned-llms-for-the-cost-of-one).

#### Can we teach a new language to the Large Model ?

<a id="setup"></a>
## 1. Setup our environment 

In [None]:
!pip install -U boto3 sagemaker huggingface_hub --quiet

In [1]:
import sagemaker
import boto3
sess = sagemaker.Session()

# sagemaker session bucket -> used for uploading data, models and logs
sagemaker_session_bucket=None
if sagemaker_session_bucket is None and sess is not None:
    # set to default bucket if a bucket name is not given
    sagemaker_session_bucket = sess.default_bucket()

try:
    role = sagemaker.get_execution_role()
except ValueError:
    iam = boto3.client('iam')
    role = iam.get_role(RoleName='sagemaker_execution_role')['Role']['Arn']

sess = sagemaker.Session(default_bucket=sagemaker_session_bucket)
region = sess._region_name 

print(f"sagemaker role arn: {role}")
print(f"sagemaker session region: {region}")

sagemaker.config INFO - Not applying SDK defaults from location: /etc/xdg/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /home/ec2-user/.config/sagemaker/config.yaml
sagemaker role arn: arn:aws:iam::622343165275:role/SagemakerEMRNoAuthProductWi-SageMakerExecutionRole-405QXR1USJDE
sagemaker session region: us-east-1


<a id="container"></a>
## 2. Build a new LoRAX container image compatible with SageMaker, push it to Amazon ECR

This example includes a `Dockerfile` and `sagemaker_entrypoint.sh` in the `sagemaker_lorax` directory. Building this new container image makes LoRAX compatible with SageMaker Hosting, namely launching the server on port 8080 via the container's `ENTRYPOINT` instruction. [Here](https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-inference-code.html#your-algorithms-inference-code-run-image) you can find the basic interfaces required to adapt any container for deployment on Sagemaker Hosting.

In [2]:
!cat sagemaker_lorax/sagemaker_entrypoint.sh

#!/bin/bash

if [[ -z "${HF_MODEL_ID}" ]]; then
  echo "HF_MODEL_ID must be set"
  exit 1
fi
export MODEL_ID="${HF_MODEL_ID}"

if [[ -n "${HF_MODEL_REVISION}" ]]; then
  export REVISION="${HF_MODEL_REVISION}"
fi

if [[ -n "${SM_NUM_GPUS}" ]]; then
  export NUM_SHARD="${SM_NUM_GPUS}"
fi

if [[ -n "${HF_MODEL_QUANTIZE}" ]]; then
  export QUANTIZE="${HF_MODEL_QUANTIZE}"
fi

if [[ -n "${HF_MODEL_TRUST_REMOTE_CODE}" ]]; then
  export TRUST_REMOTE_CODE="${HF_MODEL_TRUST_REMOTE_CODE}"
fi

if [[ -z "${ADAPTER_BUCKET}" ]]; then
else
  export PREDIBASE_MODEL_BUCKET="${ADAPTER_BUCKET}"
fi

lorax-launcher --port 8080


In [3]:
!cat sagemaker_lorax/Dockerfile

FROM ghcr.io/predibase/lorax:0.7.0

RUN apt-get install wget
RUN wget \
    https://raw.githubusercontent.com/predibase/lorax/v0.8.1/server/lorax_server/utils/sources/__init__.py \
    https://raw.githubusercontent.com/predibase/lorax/v0.8.1/server/lorax_server/utils/sources/s3.py \
    && mv -t /usr/src/server/lorax_server/utils/sources/ __init__.py s3.py

COPY sagemaker_entrypoint.sh entrypoint.sh
RUN chmod +x entrypoint.sh

ENTRYPOINT ["./entrypoint.sh"]


We build the new container image and push it to a new ECR repository. Note SageMaker [supports private Docker registries](https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-containers-inference-private.html) as well.

In [4]:
%%bash -s {region}
algorithm_name="sm-lorax"  # name of your algorithm
tag="latest"
region='us-east-1'

account=$(aws sts get-caller-identity --query Account --output text)

image_uri="${account}.dkr.ecr.${region}.amazonaws.com/${algorithm_name}:${tag}"

# If the repository doesn't exist in ECR, create it.
aws ecr describe-repositories --repository-names "${algorithm_name}" > /dev/null 2>&1

if [ $? -ne 0 ]
then
    aws ecr create-repository --repository-name "${algorithm_name}" --region $region > /dev/null
fi

cd sagemaker_lorax/ && docker build -t ${algorithm_name}:${tag} .

# Authenticate Docker to an Amazon ECR registry
aws ecr get-login-password --region ${region} | docker login --username AWS --password-stdin ${account}.dkr.ecr.${region}.amazonaws.com

# Tag the image
docker tag ${algorithm_name}:${tag} ${image_uri}

# Push the image to the repository
docker push ${image_uri}

# Save image name to tmp file to use when deploying endpoint
echo $image_uri > /tmp/image_uri

Sending build context to Docker daemon  3.584kB
Step 1/6 : FROM ghcr.io/predibase/lorax:0.7.0
0.7.0: Pulling from predibase/lorax
96d54c3075c9: Pulling fs layer
09d415c238d7: Pulling fs layer
9fe6e2e61518: Pulling fs layer
41f16248e682: Pulling fs layer
95d7b7817039: Pulling fs layer
4f4fb700ef54: Pulling fs layer
959f82742943: Pulling fs layer
9025bd344ce7: Pulling fs layer
adfa41a6373b: Pulling fs layer
c4446226bc69: Pulling fs layer
e8967d84bc78: Pulling fs layer
058d42d7b817: Pulling fs layer
c229ba7b2d98: Pulling fs layer
8e1085530a30: Pulling fs layer
4e3516d65361: Pulling fs layer
36fbc80e3462: Pulling fs layer
734c9cff5cc4: Pulling fs layer
520ce0040a14: Pulling fs layer
c01e26680578: Pulling fs layer
c7df9566c432: Pulling fs layer
899130d4026e: Pulling fs layer
15bb6f25c9d7: Pulling fs layer
80097af03388: Pulling fs layer
b6e1027e7d59: Pulling fs layer
f4a90201c6e3: Pulling fs layer
a77df8ee08a0: Pulling fs layer
324dcf76c071: Pulling fs layer
604daeae5aff: Pulling fs layer
59

https://docs.docker.com/engine/reference/commandline/login/#credentials-store



Login Succeeded
The push refers to repository [622343165275.dkr.ecr.us-east-1.amazonaws.com/sm-lorax]
e59fb7ef57a0: Preparing
88156e44280f: Preparing
6f0bc75cafcf: Preparing
587d555f2c9d: Preparing
21b57c16a62f: Preparing
5f70bf18a086: Preparing
5cb39f79cd22: Preparing
181eb05384b2: Preparing
3657817dc3f0: Preparing
f14d86fa3017: Preparing
400ac87f472d: Preparing
d619dfa00c73: Preparing
8c8243578a2b: Preparing
261c3774c596: Preparing
5f70bf18a086: Preparing
3323ffa36a6b: Preparing
65411ef2fe39: Preparing
96f5d88e1c6e: Preparing
8421b2fc221a: Preparing
6db0e47080dc: Preparing
1dbfef547c88: Preparing
4dd4b78d72b9: Preparing
01e7f8ae12e2: Preparing
0c2888617ca7: Preparing
12748023f7f7: Preparing
405aa7f9a1db: Preparing
30b51295ff33: Preparing
da5e3a75a750: Preparing
1538db58ab64: Preparing
fb1883f1cf7b: Preparing
c8a655e8e8cd: Preparing
83d1bc0501bc: Preparing
e35acfadf1d1: Preparing
5f70bf18a086: Preparing
f344b08ff6c5: Preparing
86f0cc586e78: Preparing
33e57ea5b30a: Preparing
851dfeb181

In [6]:
print(f"image is will be 622343165275.dkr.ecr.us-east-1.amazonaws.com/sm-lorax:latest")

image is will be 622343165275.dkr.ecr.us-east-1.amazonaws.com/sm-lorax:latest


<a id="download_adapter"></a>
## 3. Download adapter from HuggingFace Hub and push it to S3

We are going to simulate storing our adapter weights on S3, and having LoRAX load them dynamically as we invoke them. This enables most scenarios, including deployment after you’ve finetuned your own adapter and pushed it to S3, as well as securing deployments with no internet access inside your VPC, as detailed in this [blog post](https://www.philschmid.de/sagemaker-llm-vpc#2-upload-the-model-to-amazon-s3).

We first download an adapter trained with Mistral Instruct v0.1 as the base model to a local directory. This particular adapter was trained on GSM8K, a grade school math dataset.

In [7]:
from pathlib import Path
from huggingface_hub import snapshot_download

HF_MODEL_ID = "vineetsharma/qlora-adapter-Mistral-7B-Instruct-v0.1-gsm8k"
# create model dir
model_dir = Path('mistral-adapter')
model_dir.mkdir(exist_ok=True)

# Download model from Hugging Face into model_dir
snapshot_download(
    HF_MODEL_ID,
    local_dir=str(model_dir), # download to model dir
    local_dir_use_symlinks=False, # use no symlinks to save disk space
    revision="main", # use a specific revision, e.g. refs/pr/21
)

Fetching 5 files:   0%|          | 0/5 [00:00<?, ?it/s]

README.md:   0%|          | 0.00/1.31k [00:00<?, ?B/s]

adapter_config.json:   0%|          | 0.00/501 [00:00<?, ?B/s]

.gitattributes:   0%|          | 0.00/1.52k [00:00<?, ?B/s]

training_args.bin:   0%|          | 0.00/4.09k [00:00<?, ?B/s]

adapter_model.bin:   0%|          | 0.00/13.7M [00:00<?, ?B/s]

'/home/ec2-user/SageMaker/optimized-llm-deployment-workshop/03_multi_adapter_inference/mistral-adapter'

We copy this same adapter `n_adapters` times to different S3 prefixes in our SageMaker session bucket, simulating a large number of adapters we want to serve on the same endpoint and underlying GPU.

In [8]:
import os

s3 = boto3.client('s3')

def upload_folder_to_s3(local_path, s3_bucket, s3_prefix):
    for root, dirs, files in os.walk(local_path):
        for file in files:
            local_file_path = os.path.join(root, file)
            s3_object_key = os.path.join(s3_prefix, os.path.relpath(local_file_path, local_path))
            s3.upload_file(local_file_path, s3_bucket, s3_object_key)

# Upload the folder n_adapters times under different prefixes
n_adapters=50
base_prefix = 'lorax/mistral-adapters'
for i in range(1, n_adapters+1):
    prefix = f'{base_prefix}/0{i}' if i < 10 else f'{base_prefix}/{i}'

    upload_folder_to_s3(model_dir, sagemaker_session_bucket, prefix)
    print(f'Uploaded folder to S3 with prefix: {prefix}')

Uploaded folder to S3 with prefix: lorax/mistral-adapters/01
Uploaded folder to S3 with prefix: lorax/mistral-adapters/02
Uploaded folder to S3 with prefix: lorax/mistral-adapters/03
Uploaded folder to S3 with prefix: lorax/mistral-adapters/04
Uploaded folder to S3 with prefix: lorax/mistral-adapters/05
Uploaded folder to S3 with prefix: lorax/mistral-adapters/06
Uploaded folder to S3 with prefix: lorax/mistral-adapters/07
Uploaded folder to S3 with prefix: lorax/mistral-adapters/08
Uploaded folder to S3 with prefix: lorax/mistral-adapters/09
Uploaded folder to S3 with prefix: lorax/mistral-adapters/10
Uploaded folder to S3 with prefix: lorax/mistral-adapters/11
Uploaded folder to S3 with prefix: lorax/mistral-adapters/12
Uploaded folder to S3 with prefix: lorax/mistral-adapters/13
Uploaded folder to S3 with prefix: lorax/mistral-adapters/14
Uploaded folder to S3 with prefix: lorax/mistral-adapters/15
Uploaded folder to S3 with prefix: lorax/mistral-adapters/16
Uploaded folder to S3 wi

<a id="deploy"></a>
## 4. Deploy SageMaker endpoint


Now we deploy a SageMaker endpoint, pointing to our SageMaker session bucket as the ADAPTER_BUCKET env variable, which enables downloading adapters from S3.

If you have any problems deploying on g5.xlarge, you can change the instance type to g5.2xlarge or g5.4xlarge.

In [11]:
import json
import datetime

from sagemaker import Model
from sagemaker.serializers import JSONSerializer
from sagemaker.deserializers import JSONDeserializer

print(f"image is will be 622343165275.dkr.ecr.us-east-1.amazonaws.com/sm-lorax:latest")
# Retrieve image_uri from tmp file
image_uri = !cat /tmp/image_uri
image_uri

image is will be 622343165275.dkr.ecr.us-east-1.amazonaws.com/sm-lorax:latest


['622343165275.dkr.ecr.us-east-1.amazonaws.com/sm-lorax:latest']

In [12]:
import json
import datetime

from sagemaker import Model
from sagemaker.serializers import JSONSerializer
from sagemaker.deserializers import JSONDeserializer

# Retrieve image_uri from tmp file
image_uri = !cat /tmp/image_uri
# Increased health check timeout to give time for model download
health_check_timeout = 800
number_of_gpu = 1
instance_type = "ml.g5.xlarge"
endpoint_name = sagemaker.utils.name_from_base("sm-lorax")


# Model and Endpoint configuration parameters
env = {
  'HF_MODEL_ID': "mistralai/Mistral-7B-Instruct-v0.1", # model_id from hf.co/models
  'SM_NUM_GPUS': json.dumps(number_of_gpu), # Number of GPU used per replica
  'MAX_INPUT_LENGTH': json.dumps(1024),  # Max length of input text
  'MAX_TOTAL_TOKENS': json.dumps(4096),  # Max length of the generation (including input text)
  'ADAPTER_BUCKET': sagemaker_session_bucket,
}

lorax_model = Model(
    image_uri=image_uri[0],
    role=role,
    env=env
)

lorax_predictor = lorax_model.deploy(
    endpoint_name=endpoint_name,
    initial_instance_count=1,
    instance_type=instance_type,
    container_startup_health_check_timeout=health_check_timeout, 
    serializer=JSONSerializer(),
    deserializer=JSONDeserializer()
)

--------------------!

In [13]:
endpoint_name

'sm-lorax-2024-03-18-19-14-32-114'

In [14]:
# You can reinstantiate the Predictor object if you restart the notebook or Predictor is None
from sagemaker.predictor import Predictor
from sagemaker.serializers import JSONSerializer
from sagemaker.deserializers import JSONDeserializer
endpoint_name = endpoint_name

lorax_predictor = Predictor(
    endpoint_name,
    serializer=JSONSerializer(),
    deserializer=JSONDeserializer(),
)

<a id="compare"></a>
## 5. Invoke base model and adapter, compare outputs

We can invoke the base Mistral model, as well as any of the adapters in our bucket! LoRAX will take care of downloading them, continuously batch requests for different adapters, and manage DRAM and RAM by loading/offloading adapters.

Let’s inspect the difference between the base model’s response and the adapter’s response:

<div class="alert alert-block alert-info">
⚠️ I observed a weird error that I haven't debugged yet, where S3 download failed for adapters ID 1 through 5, but worked as expected for all other adapters. Something with the S3 prefix. Added 0 before adapter id if id < 10 as a workaround.
</div>

In [15]:
prompt = '[INST] Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May? [/INST]'

payload_base = {
    "inputs": prompt,
    "parameters": {
        "max_new_tokens": 64,
    }
}

payload_adapter = {
"inputs": prompt,
    "parameters": {
        "max_new_tokens": 64,
        "adapter_id": f'{base_prefix}/01',
        # "adapter_id" : adapter_uri, 
        "adapter_source": "s3"
    }
}

response_base = lorax_predictor.predict(payload_base)
response_adapter = lorax_predictor.predict(payload_adapter)

print(f'Base model output:\n-------------\n {response_base[0]["generated_text"]}')
print(f'Adapter output:\n-------------\n {response_adapter[0]["generated_text"]}')

Base model output:
-------------
 Let's break down the problem:

1. In April, Natalia sold clips to 48 of her friends.
2. In May, she sold half as many clips as in April, which means she sold 48/2 = 24 clips in May.

Adapter output:
-------------
 Natalia sold 48/2 = <<48/2=24>>24 clips in May.
In total, Natalia sold 48 + 24 = <<48+24=72>>72 clips in April and May.
#### 72


<a id="benchmark"></a>
## 6. Benchmark single adapter vs. random access to adapters



First, we individually call each of the adapters in sequence, to make sure they are previously downloaded to the endpoint instance’s disk. We want to exclude S3 download latency from the benchmark metrics.

In [16]:
from tqdm import tqdm

for i in tqdm(range(1,n_adapters+1)):
    adapter_id = f'{base_prefix}/0{i}' if i < 10 else f'{base_prefix}/{i}'
    payload_adapter = {
    "inputs": prompt,
    "parameters": {
        "max_new_tokens": 64,
        "adapter_id": adapter_id,
        "adapter_source": "s3"
        }
    }
    lorax_predictor.predict(payload_adapter)

100%|██████████| 50/50 [02:39<00:00,  3.18s/it]


Now we are ready to benchmark. For the single adapter case, we invoke the adapter `total_requests` times from `num_threads` concurrent clients.

For the multi-adapter case, we invoke a random adapter from any of the clients, until all adapters have been invoked `total_requests//num_adapters` times.

In [None]:
# Adjust if you run into connection pool errors
# import botocore

# Configure botocore to use a larger connection pool
# config = botocore.config.Config(max_pool_connections=100)

In [17]:
print("hello")

hello


In [22]:
import threading
import time
import random


# Configuration
total_requests = 300
num_adapters = 50
num_threads = 20  # Adjust based on your system capabilities


# Shared lock and counters for # invocations of each adapter 
adapter_counters = [total_requests // num_adapters] * num_adapters
counters_lock = threading.Lock()

def invoke_adapter(aggregate_latency, single_adapter=False):
    global total_requests
    latencies = []
    while True:
        with counters_lock:
            if single_adapter:
                adapter_id = 1
                if total_requests > 0:
                    total_requests -= 1
                else:
                    break
            else:
                # Find an adapter that still needs to be called
                remaining_adapters = [i for i, count in enumerate(adapter_counters) if count > 0]
                if not remaining_adapters:
                    break
                adapter_id = random.choice(remaining_adapters) + 1
                adapter_counters[adapter_id - 1] -= 1

        prompt = '[INST] Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May? [/INST]'
        invoke_adapter_id = f'{base_prefix}/0{i}' if i < 10 else f'{base_prefix}/{i}'
        payload_adapter = {
            "inputs": prompt,
            "parameters": {
                "max_new_tokens": 64,
                "adapter_id": invoke_adapter_id,
                "adapter_source": "s3"
            }
        }
        start_time = time.time()
        response_adapter = lorax_predictor.predict(payload_adapter)
        latency = time.time() - start_time
        latencies.append(latency)

    aggregate_latency.extend(latencies)

def benchmark_scenario(single_adapter=False, num_threads=20):
    threads = []
    all_latencies = []
    start_time = time.time()

    for _ in range(num_threads):
        thread_latencies = []
        all_latencies.append(thread_latencies)
        thread = threading.Thread(target=invoke_adapter, args=(thread_latencies, single_adapter))
        threads.append(thread)
        thread.start()

    for thread in threads:
        thread.join()

    total_latency = sum([sum(latencies) for latencies in all_latencies])
    total_requests_made = sum([len(latencies) for latencies in all_latencies])
    if total_requests_made <=0 :
        total_requests_made = 1
    average_latency = total_latency / total_requests_made
    throughput = total_requests_made / (time.time() - start_time)

    print(f"Total Time: {time.time() - start_time}s")
    print(f"Average Latency: {average_latency} s")
    print(f"Throughput: {throughput} requests/s")




In [134]:
# Run benchmarks
print("Benchmarking: Single Adapter Multiple Times")
benchmark_scenario(single_adapter=True)

print("\nBenchmarking: Multiple Adapters with Random Access")
benchmark_scenario()

Benchmarking: Single Adapter Multiple Times
Total Time: 44.282410860061646s
Average Latency: 2.9100885208447775 s
Throughput: 6.774699079147043 requests/s

Benchmarking: Multiple Adapters with Random Access


ZeroDivisionError: division by zero

In [23]:
print("\nBenchmarking: Multiple Adapters with Random Access")
benchmark_scenario()


Benchmarking: Multiple Adapters with Random Access
Total Time: 43.7987117767334s
Average Latency: 2.8808842070897422 s
Throughput: 6.849516821263714 requests/s


### Bench mark thread pool

In [135]:
import numpy as np
import time
import traceback
from collections import defaultdict
import json
import sagemaker
from sagemaker.model import Model
from sagemaker import serializers, deserializers
from sagemaker import image_uris
import boto3
import os
import time
import json
import jinja2
from pathlib import Path

import boto3
import sagemaker
from sagemaker import get_execution_role


import time
from PIL import Image
import numpy as np

# variables
s3_client = boto3.client("s3")
ts = time.strftime("%Y-%m-%d-%H-%M-%S", time.gmtime())
role = sagemaker.get_execution_role()  # execution role for the endpoint
sess = sagemaker.session.Session()  # sagemaker session for interacting with different AWS APIs
bucket = sess.default_bucket()  # bucket to house artifacts
model_bucket = sess.default_bucket()  # bucket to house artifacts

region = sess._region_name
account_id = sess.account_id()

sm_client = boto3.client("sagemaker")
runtime_sm_client = boto3.client("sagemaker-runtime")


jinja_env =  jinja2.Environment()  # jinja environment to generate model configuration templates



def run_worker_endpoint(worker_id, adaptor_id):

    # make a inference request to load model into memory
    runtime_sm_client = boto3.client("sagemaker-runtime")

    print(f"Starting invocation for model:: worker_id={worker_id}:::  adaptor_id={adaptor_id}::....please wait ...")
    
    start_time = time.time()
    results = [0]
    total_count=0
    error_count=0
    #base_prefix = 'lorax/mistral-adapters'
    total_run_time = 40 #10  #120: #3600: #400:  # -- 300 sec  -- 1 hour 3600    2 hour 7200 is 4 is 14400
    
    while (time.time() - start_time) < total_run_time: #120: #3600: #400:  # -- 300 sec  -- 1 hour 3600    2 hour 7200 is 4 is 14400
        start_run = time.time()
        try:
            prompt = '[INST] Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May? [/INST]'
            invoke_adapter_id = f'{base_prefix}/0{adaptor_id}' if adaptor_id < 10 else f'{base_prefix}/{adaptor_id}'
            payload_adapter = {
                "inputs": prompt,
                "parameters": {
                    "max_new_tokens": 64,
                    "adapter_id": invoke_adapter_id,
                    "adapter_source": "s3"
                }
            }
            response = runtime_sm_client.invoke_endpoint(
                EndpointName=endpoint_name,
                ContentType="application/json",
                Body=json.dumps(payload_adapter),
            )
            results.append((time.time() - start_run) * 1000)
            total_count = total_count + 1

        except:
            print(traceback.format_exc())
            error_count = error_count + 1
            time.sleep(0.005)
     
    if total_count <=0 :
        total_count = 1
    print(f"worker_id={worker_id}::adaptor_id={adaptor_id}:: p95::{np.percentile(results, 95)} ms:throughput={total_run_time/total_count}::total_success_count={total_count}::error_count={error_count}::")

    return (np.percentile(results, 90), (total_count/total_run_time)*60, total_count, error_count) # - latency, throughput in TPM, total run , error count


In [136]:
# create a process pool
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
from multiprocessing import Pool
from multiprocessing import cpu_count
import threading
import time
import random

max_workers_cpu = 40 #3 #4 #5  # cpu_count() # -*2
print(f"Max_A-Sync:processes={max_workers_cpu}")

Max_A-Sync:processes=40


In [139]:
max_adaptors_to_run = 1 #40
max_adaptors = 50
remaining_adaptors = [i+1 for i in range(max_adaptors)] # - adaptors id is 1 base
#remaining_adaptors

In [140]:
result_pool_list = []
print(max_workers_cpu)


with ThreadPoolExecutor(max_workers=max_workers_cpu) as pool:
    # call a function

    for worker in range(max_adaptors_to_run): # run certain adaptors
        adaptor_id = random.choice(remaining_adaptors) 
        result_p = pool.submit(run_worker_endpoint, worker,adaptor_id ) # -- making worker id 1 based to test 20 models in 1 instance
        print(result_p)
        result_pool_list.append(result_p)
        
        remaining_adaptors.remove(adaptor_id) # so a new adaptor get assigned



40
<Future at 0x7f16be573af0 state=running>
Starting invocation for model:: worker_id=0:::  adaptor_id=46::....please wait ...
worker_id=0::adaptor_id=46:: p95::2416.24196767807 ms:throughput=2.3529411764705883::total_success_count=17::error_count=0::


In [141]:
for result_p in result_pool_list:
    print(result_p.result()) # blocks

(2414.013719558716, 25.5, 17, 0)


In [142]:
time_in_ms_list = []
throughput_run_list = []
total_run = 0

for result_p in result_pool_list:
    time_in_ms, throughput_run , total_count, error_count = result_p.result()
    time_in_ms_list.append(time_in_ms)
    total_run = total_run + total_count
    throughput_run_list.append(throughput_run)
    
p90_avg = np.percentile(time_in_ms_list, 90)
p90_throughput = np.percentile(throughput_run_list, 90)

print(f"max_adaptors_to_run={max_adaptors_to_run}::P90 Average:latency--->{p90_avg}:ms: and P90 throughput --- >{p90_throughput}:TPM:   total_count={total_count}")
    

max_adaptors_to_run=1::P90 Average:latency--->2414.013719558716:ms: and P90 throughput --- >25.5:TPM:   total_count=17


<a id="cleanup"></a>
## 7. Cleanup endpoint resources

In [None]:
lorax_predictor.delete_endpoint()