# Using SageMaker Efficient Multi-Adapter Serving to host LoRA adapters at Scale

Multi-adapter serving allows multiple fine-tuned models to be hosted cost-efficiently on a single endpoint. With a multi-adapter approach, you can tackle multiple different tasks with a single base LLM. In this example you will use a pre-trained LoRA adapter that was fine-tuned from Llama 3.1 8B Instruct on the [ECTSum dataset](https://huggingface.co/datasets/mrSoul7766/ECTSum).

You will also see how to dynamically load these adapters using [SageMaker Inference Components](https://aws.amazon.com/blogs/aws/amazon-sagemaker-adds-new-inference-capabilities-to-help-reduce-foundation-model-deployment-costs-and-latency/). This example specifically explores the Inference Component Adapter feature, which lets you load hundreds of adapters on a SageMaker real-time endpoint.

![](./images/ic-adapter-architecture.png)

## Step 1: Setup

### Fetch and import dependencies

In [None]:
!pip install boto3==1.35.68 --quiet --upgrade
!pip install sagemaker==2.235.2 --quiet --upgrade

## Restart kernel before continuing

In [None]:
import sagemaker
import boto3
import json
print(f"boto3 version: {boto3.__version__}")
print(f"sagemaker version: {sagemaker.__version__}")

### Configure development environment and boto3 clients

In [None]:
role = sagemaker.get_execution_role()  # execution role for the endpoint
sess = sagemaker.session.Session()  # sagemaker session for interacting with different AWS APIs
bucket = sess.default_bucket()  # bucket to house artifacts
region = sess.boto_region_name  # public attribute; avoids relying on the private _region_name

sm_client = boto3.client(service_name="sagemaker")
sm_runtime = boto3.client(service_name="sagemaker-runtime")

## Step 2: Create a model, endpoint configuration and endpoint

### Download your model

Download the base model from the HuggingFace model hub. Since [Llama 3.1 8B Instruct](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct) is a gated model, you will need a HuggingFace access token and must request access on the model page. Follow the [HuggingFace instructions on accessing gated models](https://huggingface.co/docs/transformers.js/en/guides/private) for details.

In [None]:
HF_TOKEN = "<<YOUR HF TOKEN HERE>>"
model_id = "meta-llama/Llama-3.1-8B-Instruct"
model_id_pathsafe = model_id.replace("/","-")
local_model_path = f"./models/{model_id_pathsafe}"
s3_model_path = f"s3://{bucket}/models/{model_id_pathsafe}"

In [None]:
from huggingface_hub import snapshot_download

# `token` replaces the deprecated `use_auth_token` argument in recent huggingface_hub releases
snapshot_download(repo_id=model_id, token=HF_TOKEN, local_dir=local_model_path, allow_patterns=["*.json", "*.safetensors"])

Copy your model artifact to S3 to improve model load time during deployment.

In [None]:
!aws s3 cp --recursive {local_model_path} {s3_model_path}

### Select a Large Model Inference (LMI) container image

Select one of the [available Large Model Inference (LMI) container images for hosting](https://github.com/aws/deep-learning-containers/blob/master/available_images.md#large-model-inference-containers). Efficient adapter inference capability is available in `0.31.0-lmi13.0.0` and higher. Ensure that you are using the image URI for the region that corresponds with your deployment region.

In [None]:
inference_image_uri = "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.31.0-lmi13.0.0-cu124"

print(f"Inference container image: {inference_image_uri}")
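
The URI above is pinned to `us-west-2`. As a minimal sketch, a small helper can build the URI for whichever region you deploy to. The helper name `lmi_image_uri` is hypothetical, and the sketch assumes the registry account ID from the URI above applies to your region; some regions use a different account, so confirm against the list of available images.

```python
# Hypothetical helper, not part of the SageMaker SDK.
# ASSUMPTION: the registry account ID below is the one from this notebook's
# us-west-2 URI. Some regions use a different account ID; check the
# available-images list before relying on this.
LMI_ACCOUNT_ID = "763104351884"
LMI_TAG = "0.31.0-lmi13.0.0-cu124"


def lmi_image_uri(region: str, tag: str = LMI_TAG, account_id: str = LMI_ACCOUNT_ID) -> str:
    """Build the ECR URI of the LMI (djl-inference) container for a region."""
    return f"{account_id}.dkr.ecr.{region}.amazonaws.com/djl-inference:{tag}"


print(lmi_image_uri("us-west-2"))
```

Passing the `region` value resolved during setup keeps the image and the endpoint in the same region.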

### Configure model container environment

Create a container environment for the hosting container. LMI container parameters can be found in the [LMI User Guides](https://docs.djl.ai/master/docs/serving/serving/docs/lmi/user_guides/index.html).

By using the `OPTION_MAX_LORAS` and `OPTION_MAX_CPU_LORAS` parameters, you can control how adapters are loaded and unloaded into GPU/CPU memory. The `OPTION_MAX_LORAS` parameter defines the number of adapters that will be held in GPU memory, and any additional adapters will be offloaded to CPU memory. The `OPTION_MAX_CPU_LORAS` parameter controls the number of adapters that will be held in CPU memory, with any additional adapters being offloaded to local SSD. In the following example, the container will hold 30 adapters in GPU memory, and 70 adapters in CPU memory.


In [None]:
env = {
    "HF_MODEL_ID": f"{s3_model_path}",
    "OPTION_ROLLING_BATCH": "lmi-dist",
    "OPTION_MAX_ROLLING_BATCH_SIZE": "16",
    "OPTION_TENSOR_PARALLEL_DEGREE": "max",
    "OPTION_ENABLE_LORA": "true",
    "OPTION_MAX_LORAS": "30",
    "OPTION_MAX_CPU_LORAS": "70",
    "OPTION_DTYPE": "fp16",
    "OPTION_MAX_MODEL_LEN": "6000"
}

env
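
To make the two limits concrete, the following sketch computes how many adapters would sit in each tier for a given total, following the description above (GPU first, then CPU, then local SSD). The function `adapter_placement` is hypothetical and only does the arithmetic; it does not model which specific adapters the container keeps or evicts.

```python
# Illustrative arithmetic only. ASSUMPTION: tiers fill in the order described
# in the text (GPU, then CPU, then local disk). The container's actual
# eviction policy is not modeled here.
def adapter_placement(total_adapters: int, max_loras: int, max_cpu_loras: int) -> dict:
    """Return how many adapters land in GPU memory, CPU memory, and on disk."""
    gpu = min(total_adapters, max_loras)
    cpu = min(total_adapters - gpu, max_cpu_loras)
    disk = total_adapters - gpu - cpu
    return {"gpu": gpu, "cpu": cpu, "disk": disk}


# With the settings used in this notebook and 150 registered adapters:
print(adapter_placement(150, max_loras=30, max_cpu_loras=70))
```

With 150 adapters and the values above, 30 fit in GPU memory, 70 in CPU memory, and the remaining 50 fall through to local SSD.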

### Create a model object

With your container image and environment defined, you can create a SageMaker model object that you will use to create an inference component later.

In [None]:
model_name = sagemaker.utils.name_from_base("llama-3-1-8b-instruct")
print(f'Model name: {model_name}')

In [None]:
create_model_response = sm_client.create_model(
    ModelName = model_name,
    ExecutionRoleArn = role,
    PrimaryContainer = {
        "Image": inference_image_uri, 
        "Environment": env,
    },
)
model_arn = create_model_response["ModelArn"]

print(f"Created Model ARN: {model_arn}")

### Create an endpoint configuration

To create a SageMaker endpoint, you need an endpoint configuration. When using Inference Components, you do not specify a model in the endpoint configuration. You will load the model as a component later on.

In [None]:
# Set variant name and instance type for hosting
endpoint_config_name = f"{model_name}"
variant_name = "AllTraffic"
instance_type = "ml.g5.12xlarge"  # 4 GPUs, matching the 4 accelerators the base inference component requests
model_data_download_timeout_in_seconds = 900
container_startup_health_check_timeout_in_seconds = 900

initial_instance_count = 1

In [None]:
create_endpoint_config_response = sm_client.create_endpoint_config(
    EndpointConfigName = endpoint_config_name,
    ExecutionRoleArn = role,
    ProductionVariants = [
        {
            "VariantName": variant_name,
            "InstanceType": instance_type,
            "InitialInstanceCount": initial_instance_count,
            "ModelDataDownloadTimeoutInSeconds": model_data_download_timeout_in_seconds,
            "ContainerStartupHealthCheckTimeoutInSeconds": container_startup_health_check_timeout_in_seconds,
            "RoutingConfig": {"RoutingStrategy": "LEAST_OUTSTANDING_REQUESTS"},
        }
    ]
)

print(f"Created Endpoint Config ARN: {create_endpoint_config_response['EndpointConfigArn']}")

### Create inference endpoint

Create your empty SageMaker endpoint. You will use this to host your base model and adapter inference components later.

In [None]:
endpoint_name = f"{model_name}"

print(f'Endpoint name: {endpoint_name}')

#### This step can take around 5 minutes

In [None]:
%%time

create_endpoint_response = sm_client.create_endpoint(
    EndpointName = endpoint_name, EndpointConfigName = endpoint_config_name
)

sess.wait_for_endpoint(endpoint_name)

print(f"Created Endpoint ARN: {create_endpoint_response['EndpointArn']}")

### Create base model inference component

With your endpoint created, you can now create the inference component (IC) for the base model. This is the base component that the adapter components you create later will depend on.

A notable parameter here is [`ComputeResourceRequirements`](https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_InferenceComponentComputeResourceRequirements.html). This component-level configuration determines the amount of resources the component needs (memory, vCPUs, accelerators). The adapters share these resources with the base component.


#### This step can take around 7 minutes

In [None]:
%%time

base_inference_component_name = f"base-{model_name}"
print(f"Base inference component name: {base_inference_component_name}")

variant_name = "AllTraffic"

initial_copy_count = 1
min_memory_required_in_mb = 32000
number_of_accelerator_devices_required = 4

base_create_inference_component_response = sm_client.create_inference_component(
    InferenceComponentName = base_inference_component_name,
    EndpointName = endpoint_name,
    VariantName = variant_name,
    Specification={
        "ModelName": model_name,
        "StartupParameters": {
            "ModelDataDownloadTimeoutInSeconds": model_data_download_timeout_in_seconds,
            "ContainerStartupHealthCheckTimeoutInSeconds": container_startup_health_check_timeout_in_seconds,
        },
        "ComputeResourceRequirements": {
            "MinMemoryRequiredInMb": min_memory_required_in_mb,
            "NumberOfAcceleratorDevicesRequired": number_of_accelerator_devices_required,
        },
    },
    RuntimeConfig={
        "CopyCount": initial_copy_count,
    },
)

sess.wait_for_inference_component(base_inference_component_name)

print(f"Created Base inference component ARN: {base_create_inference_component_response['InferenceComponentArn']}")
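
Because `ComputeResourceRequirements` must fit on the instance, it can help to check how many copies of a component an instance type can hold. The sketch below is a rough estimate under stated assumptions: the `INSTANCE_SPECS` table and the `max_copies` helper are hypothetical, the specs are hard-coded for illustration, and memory that SageMaker reserves for its own use is ignored, so real capacity may be lower.

```python
# ASSUMPTION: accelerator counts and memory (in MiB) are hard-coded here for
# illustration. Verify them against the instance documentation.
INSTANCE_SPECS = {
    "ml.g5.2xlarge": {"accelerators": 1, "memory_mb": 32 * 1024},
    "ml.g5.12xlarge": {"accelerators": 4, "memory_mb": 192 * 1024},
    "ml.g5.48xlarge": {"accelerators": 8, "memory_mb": 768 * 1024},
}


def max_copies(instance_type: str, accelerators_required: int, min_memory_mb: int) -> int:
    """Upper bound on component copies per instance (ignores reserved memory)."""
    spec = INSTANCE_SPECS[instance_type]
    by_accelerators = spec["accelerators"] // accelerators_required
    by_memory = spec["memory_mb"] // min_memory_mb
    return min(by_accelerators, by_memory)


# Requirements used for the base component in this notebook:
# 4 accelerators and 32000 MB of memory.
for name in INSTANCE_SPECS:
    print(name, max_copies(name, accelerators_required=4, min_memory_mb=32000))
```

A result of 0 means the component cannot be placed on that instance type at all, which is worth catching before the create call.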

### View logs for the base inference component (and adapters after they're loaded)

In [None]:
import urllib.parse  # import the submodule explicitly; `import urllib` alone does not guarantee it is loaded

cw_path = urllib.parse.quote_plus(f'/aws/sagemaker/InferenceComponents/{base_inference_component_name}', safe='', encoding=None, errors=None)

print(f'You can view your inference component logs here:\n\n https://{region}.console.aws.amazon.com/cloudwatch/home?region={region}#logsV2:log-groups/log-group/{cw_path}')

### Create the Inference Components (ICs) for the adapters

In this example you’ll create a single adapter, but you could host up to hundreds of them per endpoint. They will need to be compressed and uploaded to S3.

The adapter package has the following files at the root of the archive with no sub-folders:

![](./images/adapter_files.png)
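
One way to produce an archive with this flat layout is Python's `tarfile` module, passing `arcname` so each file is stored under its bare file name. This is a minimal sketch; the helper `package_adapter` is hypothetical, and the file names in the usage comment (`adapter_config.json`, `adapter_model.safetensors`) are typical adapter outputs assumed for illustration rather than taken from this notebook.

```python
import tarfile
from pathlib import Path


def package_adapter(adapter_dir: str, archive_path: str) -> list:
    """Write every file in adapter_dir to a .tar.gz with no sub-folders.

    Returns the member names stored in the archive.
    """
    # List files before creating the archive so the archive never includes itself.
    files = sorted(p for p in Path(adapter_dir).iterdir() if p.is_file())
    with tarfile.open(archive_path, "w:gz") as tar:
        for path in files:
            # arcname=path.name drops the directory prefix, placing the file at the root
            tar.add(str(path), arcname=path.name)
    return [p.name for p in files]


# Example usage (paths and file names are assumptions):
# package_adapter("./my-adapter", "./adapters/ectsum-adapter.tar.gz")
# might return ["adapter_config.json", "adapter_model.safetensors"]
```

Running `tar -tzf` on the result should list only file names, with no directory components.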

For this example, an adapter was fine-tuned using QLoRA and [Fully Sharded Data Parallel (FSDP)](https://docs.aws.amazon.com/sagemaker/latest/dg/model-parallel-v2.html) on the training split of the [ECTSum dataset](https://huggingface.co/datasets/mrSoul7766/ECTSum). Training took 21 minutes on an ml.p4d.24xlarge and cost ~$13 using current [on-demand pricing](https://aws.amazon.com/sagemaker/pricing/).

#### Copy locally downloaded adapters to S3

In [None]:
ectsum_adapter_filename = "ectsum-adapter.tar.gz"
ectsum_adapter_s3_uri = f"s3://{bucket}/adapters/{ectsum_adapter_filename}"

!aws s3 cp ./adapters/{ectsum_adapter_filename} {ectsum_adapter_s3_uri}

### Create ECTSum adapter inference component

For each adapter you are going to deploy, you need to specify an `InferenceComponentName`, an `ArtifactUrl` with the S3 location of the adapter archive, and a `BaseInferenceComponentName` to create the connection between the base model IC and the new adapter ICs. You will repeat this process for each additional adapter.

#### This step can take around 2 minutes

In [None]:
ic_ectsum_name = f"ic-ectsum-{base_inference_component_name}"

sm_client.create_inference_component(
    InferenceComponentName = ic_ectsum_name,
    EndpointName = endpoint_name,
    Specification={
        "BaseInferenceComponentName": base_inference_component_name,
        "Container": {
            "ArtifactUrl": ectsum_adapter_s3_uri
        },
    },
)

sess.wait_for_inference_component(ic_ectsum_name)
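
When registering several adapters, the same request can be generated per adapter. The sketch below only builds the keyword arguments. The helper `adapter_component_request` and the adapter names, endpoint name, and S3 URIs are hypothetical placeholders; the dictionary keys mirror the `create_inference_component` call above.

```python
def adapter_component_request(adapter_name: str, endpoint_name: str,
                              base_component_name: str, artifact_url: str) -> dict:
    """Build keyword arguments for one adapter inference component."""
    return {
        "InferenceComponentName": adapter_name,
        "EndpointName": endpoint_name,
        "Specification": {
            "BaseInferenceComponentName": base_component_name,
            "Container": {"ArtifactUrl": artifact_url},
        },
    }


# Hypothetical adapters and locations, for illustration only.
example_adapters = {
    "ic-ectsum-example": "s3://example-bucket/adapters/ectsum-adapter.tar.gz",
    "ic-second-example": "s3://example-bucket/adapters/second-adapter.tar.gz",
}

requests = [
    adapter_component_request(name, "example-endpoint", "base-example", url)
    for name, url in example_adapters.items()
]

# Each request could then be submitted with the clients from this notebook:
# for request in requests:
#     sm_client.create_inference_component(**request)
#     sess.wait_for_inference_component(request["InferenceComponentName"])
```

Keeping the request construction separate from the API call makes it easy to inspect or log what will be created before anything is deployed.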

Look at base inference component logs again.

It should show a line that looks like:

`Registered adapter <ADAPTER_NAME> from /opt/ml/models/ ... successfully`.

## Step 3: Invoking the Endpoint

First you will pull a random datapoint from the ECTSum test split. You'll use the `text` field to invoke the model and the `summary` field to compare with ground truth later.

In [None]:
from datasets import load_dataset
dataset_name = "mrSoul7766/ECTSum"

test_dataset = load_dataset(dataset_name, trust_remote_code=True, split="test")

test_item = test_dataset.shuffle().select(range(1))

ground_truth_response = test_item["summary"]

Next you will build a prompt to invoke the model for earnings summarization, filling in the source text with a random item from the ECTSum dataset. 

In [None]:
prompt =f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
                You are an AI assistant trained to summarize earnings calls. Provide a concise summary of the call, capturing the key points and overall context. Focus on quarter over quarter revenue, earnings per share, changes in debt, highlighted risks, and growth opportunities.
                <|eot_id|><|start_header_id|>user<|end_header_id|>
                Summarize the following earnings call:

                {test_item["text"][0]}
                <|eot_id|><|start_header_id|>assistant<|end_header_id|>"""

### Plain base model with no adapters

To test the base model, specify the `EndpointName` for the endpoint you created earlier and the name of the base inference component as `InferenceComponentName` along with your prompt and other inference parameters in the `Body` parameter.

In [None]:
%%time

component_to_invoke = base_inference_component_name

response_model = sm_runtime.invoke_endpoint(
    EndpointName = endpoint_name,
    InferenceComponentName = component_to_invoke,
    Body = json.dumps(
        {
            "inputs": prompt,
            "parameters": {"do_sample": True, "top_p": 0.9, "temperature": 0.9, "max_new_tokens": 125}
        }
    ),
    ContentType = "application/json",
)

base_response = json.loads(response_model["Body"].read().decode("utf8"))["generated_text"]

print(f'Ground Truth:\n\n {test_item["summary"][0]}\n\n')
print(f'Base Model Response:\n\n {base_response}')
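
The invocations in this notebook all send the same request shape and read the same response field, so the JSON handling can be factored out. The following is a minimal sketch; `build_request_body` and `parse_generated_text` are hypothetical helper names, and the default parameters are the ones used in the cell above.

```python
import json

# Inference parameters used throughout this notebook.
DEFAULT_PARAMETERS = {"do_sample": True, "top_p": 0.9, "temperature": 0.9, "max_new_tokens": 125}


def build_request_body(prompt: str, **overrides) -> str:
    """Serialize a prompt and inference parameters into the request body."""
    parameters = {**DEFAULT_PARAMETERS, **overrides}
    return json.dumps({"inputs": prompt, "parameters": parameters})


def parse_generated_text(raw_body: bytes) -> str:
    """Extract the generated text from the raw response bytes."""
    return json.loads(raw_body.decode("utf8"))["generated_text"]


# Example usage with the runtime client from this notebook:
# response = sm_runtime.invoke_endpoint(
#     EndpointName=endpoint_name,
#     InferenceComponentName=component_to_invoke,
#     Body=build_request_body(prompt, max_new_tokens=200),
#     ContentType="application/json",
# )
# text = parse_generated_text(response["Body"].read())
```

Only `InferenceComponentName` changes between a base model call and an adapter call, so the body helper can be reused for both.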

### Invoke ECTSum adapter

To invoke the adapter, use the adapter inference component name in your `invoke_endpoint` call.

In [None]:
%%time

component_to_invoke = ic_ectsum_name

response_model = sm_runtime.invoke_endpoint(
    EndpointName = endpoint_name,
    InferenceComponentName = component_to_invoke,
    Body = json.dumps(
        {
            "inputs": prompt,
            "parameters": {"do_sample": True, "top_p": 0.9, "temperature": 0.9, "max_new_tokens": 125}
        }
    ),
    ContentType = "application/json",
)

adapter_response = json.loads(response_model["Body"].read().decode("utf8"))["generated_text"]

print(f'Ground Truth:\n\n {test_item["summary"][0]}\n\n')
print(f'Adapter Model Response:\n\n {adapter_response}')

### Compare outputs

Compare the outputs of the base model and the adapter to the ground truth. In this test, the base model's response may look more polished, but the adapter's response is significantly closer to the ground truth, which is the goal. The evaluation metrics that follow quantify this difference.

In [None]:
print(f'Ground Truth:\n\n {test_item["summary"][0]}\n\n')
print("\n----------------------------------\n")
print(f'Base Model Response:\n\n {base_response}')
print("\n----------------------------------\n")
print(f'Adapter Model Response:\n\n {adapter_response}')

To validate the adapter's performance, you can use a tool like [fmeval](https://github.com/aws/fmeval) to run an evaluation of summarization accuracy. This will calculate the METEOR, ROUGE, and BERTScore metrics for the adapter versus the base model. Doing so against the test split of ECTSum yields the following results:

![](./images/fmeval-overall.png)

The fine-tuned adapter shows a 59% increase in METEOR score, a 159% increase in ROUGE score, and an 8.6% increase in BERTScore. The following diagram shows the frequency distribution of scores for each metric; the adapter scores higher more often on all of them.

Model latency is largely unaffected, with only a difference of 2% between direct base model invocation and the adapter.

![](./images/fmeval-histogram.png)
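
For intuition about what an overlap metric measures, here is a simplified unigram-overlap F1 in the style of ROUGE-1. It is a sketch only: fmeval uses its own metric implementations, and this version tokenizes on whitespace with no stemming, so its scores will not match the figures above.

```python
from collections import Counter


def rouge1_f1(candidate: str, reference: str) -> float:
    """Simplified ROUGE-1-style F1: unigram overlap on whitespace tokens."""
    cand_counts = Counter(candidate.lower().split())
    ref_counts = Counter(reference.lower().split())
    # Counter intersection keeps the minimum count of each shared token
    overlap = sum((cand_counts & ref_counts).values())
    if overlap == 0:
        return 0.0
    precision = overlap / sum(cand_counts.values())
    recall = overlap / sum(ref_counts.values())
    return 2 * precision * recall / (precision + recall)


# Example usage with the responses generated earlier in this notebook:
# reference = test_item["summary"][0]
# print("base:   ", rouge1_f1(base_response, reference))
# print("adapter:", rouge1_f1(adapter_response, reference))
```

A single test item gives only a rough signal; the fmeval results above aggregate over the full test split.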

### Upload a new ECTSum adapter artifact and update the live adapter inference component

Since adapters are managed as Inference Components, you can update them on a running endpoint. SageMaker handles unloading and deregistering the old adapter, then loading and registering the new adapter on every base IC across all of the instances it is running on for this endpoint. To update an adapter IC, use the `update_inference_component` API and supply the existing IC name and the S3 path to the new compressed adapter archive.

You can train a new adapter, or re-upload the existing adapter artifact to test this functionality.

In [None]:
new_ectsum_adapter_s3_uri = f"s3://{bucket}/lora-adapters/new-ectsum-adapter.tar.gz"

!aws s3 cp ./adapters/ectsum-adapter.tar.gz {new_ectsum_adapter_s3_uri}

#### This step can take around 5 minutes

In [None]:
%%time

update_inference_component_response = sm_client.update_inference_component(
    InferenceComponentName = ic_ectsum_name,
    Specification={
        "Container": {
            "ArtifactUrl": new_ectsum_adapter_s3_uri
        },
    },
)

sess.wait_for_inference_component(ic_ectsum_name)

print(f'Updated inference component adapter ARN: {update_inference_component_response["InferenceComponentArn"]}')
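
The wait call above blocks until the component is ready. Conceptually, this amounts to polling the component status until it settles. The sketch below illustrates the idea and is not the SDK's implementation: `wait_for_component` is a hypothetical helper that takes the describe function as an argument, so it can be tried with a stand-in before being pointed at `sm_client.describe_inference_component`.

```python
import time


def wait_for_component(describe_fn, name, poll_seconds=15.0, timeout_seconds=900.0,
                       sleep_fn=time.sleep):
    """Poll an inference component until it is InService, fails, or times out."""
    waited = 0.0
    while True:
        status = describe_fn(InferenceComponentName=name)["InferenceComponentStatus"]
        if status == "InService":
            return status
        if status == "Failed":
            raise RuntimeError(f"Inference component {name} failed")
        if waited >= timeout_seconds:
            raise TimeoutError(f"Inference component {name} still {status} after {waited:.0f}s")
        sleep_fn(poll_seconds)
        waited += poll_seconds


# Example usage with the client from this notebook:
# wait_for_component(sm_client.describe_inference_component, ic_ectsum_name)
```

Injecting `describe_fn` and `sleep_fn` keeps the loop testable without an endpoint.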

If you view your inference component logs (link below), you will see log entries for the deregistration of the old adapter and the registration of the new one.

You should see something similar to:
`[INFO ] PyProcess - W-200-0d1e4741a42db26-stdout: [1,0]<stdout>:INFO::Unregistered adapter ic-ectsum-base-llama-3-1-8b-instruct-2024-11-25-20-41-07-401 successfully`

`[INFO ] PyProcess - W-200-0d1e4741a42db26-stdout: [1,0]<stdout>:INFO::Registered adapter ic-ectsum-base-llama-3-1-8b-instruct-2024-11-25-20-41-07-401 from /opt/ml/models/container_340043819279-ic-ectsum-base-llama-3-1-8b-instruct-2024-11-25-20-41-07-401-1732570150851-MaeveWestworldService-1.0.9353.0 successfully`

In [None]:
print(f'You can view your inference component logs here:\n\n https://{region}.console.aws.amazon.com/cloudwatch/home?region={region}#logsV2:log-groups/log-group/{cw_path}')

### Retest with updated adapter

In [None]:
%%time

component_to_invoke = ic_ectsum_name

response_model = sm_runtime.invoke_endpoint(
    EndpointName = endpoint_name,
    InferenceComponentName = component_to_invoke,
    Body = json.dumps(
        {
            "inputs": prompt,
            "parameters": {"do_sample": True, "top_p": 0.9, "temperature": 0.9, "max_new_tokens": 125}
        }
    ),
    ContentType = "application/json",
)

response_model["Body"].read().decode("utf8")

## Step 4: Clean up the environment

If you need to delete an adapter, call the `delete_inference_component` API with the IC name to remove it. 

In [None]:
sess.delete_inference_component(ic_ectsum_name, wait = True)

Deleting the base model IC will also automatically delete any adapter ICs associated with it.

In [None]:
sess.delete_inference_component(base_inference_component_name, wait = True)

Clean up the running endpoint, its configuration, and the model.

In [None]:
sess.delete_endpoint(endpoint_name)
sess.delete_endpoint_config(endpoint_config_name)
sess.delete_model(model_name)