# Using SageMaker Multi-Adapter Serving to host LoRA adapters at Scale
In this example we explore how SageMaker Multi-Adapter Inference works with the [Gemma 2 2B model](https://huggingface.co/google/gemma-2-2b) and an [open-source (OSS) LoRA adapter](https://huggingface.co/Kronu/gemma-2-2b-lean-expert-1760-complete) from Hugging Face.

We'll build on this in later examples and explore more practical applications where we can actually see the value of LoRA adapters from an evaluation/data science perspective. For this notebook we just want to understand the constructs of Multi-Adapter Inference and how you can set them up. If you already have your own adapters and base model, try plugging them into these constructs with the appropriate config to take it for a spin yourself!

## Prerequisites
- You need a Hugging Face access token to follow this sample: https://huggingface.co/docs/hub/en/security-tokens
- If you are new to Inference Components/SageMaker Inference, follow this guide: https://www.youtube.com/watch?v=RcUNEeUqpNQ&t=11s

## Additional Resources/Credits
- vLLM LMI Engine Params/Docs: https://docs.djl.ai/v0.29.0/docs/serving/serving/docs/lmi/user_guides/vllm_user_guide.html
- Multi-Adapter Inference Docs: https://docs.aws.amazon.com/sagemaker/latest/dg/realtime-endpoints-adapt.html

## License
- Model: Gemma-2 2B
- License: Gemma License
- Used under the terms of the Gemma License for research and deployment.

## Setup

In [None]:
%pip install boto3 huggingface_hub sagemaker --upgrade --quiet --no-warn-conflicts

### Configure development environment and boto3 clients

In [None]:
import json
import boto3
import sagemaker
import time
from time import gmtime, strftime
import tarfile
import pathlib
from huggingface_hub import snapshot_download
import os
from botocore.exceptions import ClientError

role = sagemaker.get_execution_role()  # execution role for the endpoint
sess = sagemaker.session.Session()  # sagemaker session for interacting with different AWS APIs
bucket = sess.default_bucket()  # bucket to house artifacts
# Read the region through the public property instead of the private `_region_name` attribute.
region = sess.boto_region_name  # region name of the current SageMaker session
boto_session = boto3.session.Session()

sm_client = boto3.client("sagemaker")  # client to interact with SageMaker
smr_client = boto3.client("sagemaker-runtime")  # client to interact with SageMaker Endpoints
s3_client = boto3.client("s3")
s3 = boto_session.resource('s3')

# Report the values resolved above rather than recomputing them.
print(f"sagemaker role arn: {role}")
print(f"sagemaker bucket: {bucket}")
print(f"sagemaker session region: {region}")
print(f"boto3 version: {boto3.__version__}")
print(f"sagemaker version: {sagemaker.__version__}")

In [None]:
# Keep AWS_REGION aligned with the SageMaker session region instead of
# hard-coding a value that may not match where the endpoint is created.
os.environ["AWS_REGION"] = region

# Do not paste the token into the notebook: reuse an HF_TOKEN that is already
# exported, otherwise prompt for it without echoing the value.
if not os.environ.get("HF_TOKEN"):
    from getpass import getpass

    os.environ["HF_TOKEN"] = getpass("Hugging Face access token: ")

## Container & Model Artifacts Setup

In [None]:
# Tag of the djl-inference image to deploy; it is interpolated into the image URI below.
CONTAINER_VERSION = "0.34.0-lmi16.0.0-cu128"
# ECR image URI built from the session region (`region`) and the pinned tag.
# NOTE(review): the registry account id is hard-coded — confirm it is valid for your region.
inference_image = f"763104351884.dkr.ecr.{region}.amazonaws.com/djl-inference:{CONTAINER_VERSION}"
print(f"Using image URI: {inference_image}")

In [None]:
# Hosting footprint; `num_gpu` feeds the tensor-parallel degree in `env` below.
num_gpu = 1
instance_type = "ml.g5.4xlarge"

# A UTC timestamp suffix gives every run its own model name.
model_name = "gemma-2b" + strftime("%Y-%m-%d-%H-%M-%S", gmtime())

# Container environment: base model, vLLM async handler and LoRA settings.
env = dict(
    HF_MODEL_ID="google/gemma-2-2b",
    HF_TOKEN=os.getenv("HF_TOKEN"),
    SERVING_FAIL_FAST="true",
    # utilize the vLLM async handler
    OPTION_ASYNC_MODE="true",
    OPTION_ROLLING_BATCH="disable",
    OPTION_TENSOR_PARALLEL_DEGREE=str(num_gpu),
    OPTION_ENTRYPOINT="djl_python.lmi_vllm.vllm_async_service",
    OPTION_TRUST_REMOTE_CODE="true",
    # LoRA adapter support
    OPTION_ENABLE_LORA="true",
    OPTION_MAX_LORAS="1",
    OPTION_MAX_CPU_LORAS="2",
    OPTION_MAX_LORA_RANK="16",
    OPTION_DTYPE="bf16",
)

### Model & Base IC Creation
We first create an Inference Component (IC) to represent the base model; we will then associate the adapter component with this base model.

In [None]:
# Container definition: the inference image plus the environment built above.
primary_container = {"Image": inference_image, "Environment": env}

# Register the SageMaker Model and echo the API response.
model_response = sm_client.create_model(
    ModelName=model_name,
    ExecutionRoleArn=role,
    PrimaryContainer=primary_container,
)
print(json.dumps(model_response, indent=2))

In [None]:
# deployment params
model_data_download_timeout_in_seconds = 900
container_startup_health_check_timeout_in_seconds = 900
initial_instance_count = 1
variant_name = "main"

endpoint_config_name = "gemma-epc" + strftime("%Y-%m-%d-%H-%M-%S", gmtime())
endpoint_config_response = sm_client.create_endpoint_config(
    EndpointConfigName = endpoint_config_name,
    ExecutionRoleArn = role,
    ProductionVariants = [
        {
            "VariantName": variant_name,
            "InstanceType": instance_type,
            "InitialInstanceCount": initial_instance_count,
            "ModelDataDownloadTimeoutInSeconds": model_data_download_timeout_in_seconds,
            "ContainerStartupHealthCheckTimeoutInSeconds": container_startup_health_check_timeout_in_seconds,
            "RoutingConfig": {"RoutingStrategy": "LEAST_OUTSTANDING_REQUESTS"},
        }
    ]
)

print(json.dumps(endpoint_config_response, indent=2))

In [None]:
# Timestamped endpoint name, bound to the endpoint configuration created above.
endpoint_name = "gemma-ep" + strftime("%Y-%m-%d-%H-%M-%S", gmtime())

create_endpoint_kwargs = {
    "EndpointName": endpoint_name,
    "EndpointConfigName": endpoint_config_name,
}
endpoint_response = sm_client.create_endpoint(**create_endpoint_kwargs)
print(json.dumps(endpoint_response, indent=2))

In [None]:
# Poll once a minute until the endpoint leaves the "Creating" state.
while True:
    describe_endpoint_response = sm_client.describe_endpoint(EndpointName=endpoint_name)
    endpoint_status = describe_endpoint_response["EndpointStatus"]
    print(endpoint_status)
    if endpoint_status != "Creating":
        break
    time.sleep(60)

print(describe_endpoint_response)

# Anything other than InService means creation did not succeed; stop here
# instead of reporting success and failing later in a downstream cell.
if endpoint_status != "InService":
    failure_reason = describe_endpoint_response.get("FailureReason", "no failure reason reported")
    raise RuntimeError(
        f"Endpoint {endpoint_name} ended in status {endpoint_status}: {failure_reason}"
    )
print(f"Created endpoint: {endpoint_name}")

### Base IC Creation

In [None]:
%%time

base_inference_component_name = f"base-{model_name}"
print(f"Base inference component name: {base_inference_component_name}")

# component level params
initial_copy_count = 1
min_memory_required_in_mb = 32000
number_of_accelerator_devices_required = 1

# create base IC
base_create_inference_component_response = sm_client.create_inference_component(
    InferenceComponentName = base_inference_component_name,
    EndpointName = endpoint_name,
    VariantName = variant_name,
    Specification={
        "ModelName": model_name,
        "StartupParameters": {
            "ModelDataDownloadTimeoutInSeconds": model_data_download_timeout_in_seconds,
            "ContainerStartupHealthCheckTimeoutInSeconds": container_startup_health_check_timeout_in_seconds,
        },
        "ComputeResourceRequirements": {
            "MinMemoryRequiredInMb": min_memory_required_in_mb,
            "NumberOfAcceleratorDevicesRequired": number_of_accelerator_devices_required,
        },
    },
    RuntimeConfig={
        "CopyCount": initial_copy_count,
    },
)

sess.wait_for_inference_component(base_inference_component_name)
print(f"Created Base inference component ARN: {base_create_inference_component_response['InferenceComponentArn']}")

### Invoke the Base Model Inference Component

In [None]:
import json

payload = "Who is Rafael Nadal?"

# JSON request body: the prompt plus generation parameters.
request_body = json.dumps(
    {
        "inputs": payload,
        "parameters": {"max_new_tokens": 200},  # Adjust this value as needed
    }
)

response = smr_client.invoke_endpoint(
    EndpointName=endpoint_name,
    InferenceComponentName=base_inference_component_name,  # specify IC name
    ContentType="application/json",
    Accept="application/json",
    Body=request_body,
)

# Decode the response body and keep only the generated text.
response_text = response["Body"].read().decode()
result = json.loads(response_text)['generated_text']
result

## Create Adapter IC

We have a few utility functions to download the adapter artifacts and upload them to an S3 bucket. The artifacts are expected in model.tar.gz format; you can upload your own artifacts here or pull from HF if working with another OSS adapter.

In [None]:
def ensure_bucket(bucket_name: str, region: str = "us-east-1"):
    """Make sure `bucket_name` exists, creating it in `region` when missing.

    The bucket is probed with ``head_bucket``. Only "not found" style error
    codes trigger creation; any other ``ClientError`` is re-raised unchanged.
    """
    s3_api = boto3.client("s3", region_name=region)
    try:
        s3_api.head_bucket(Bucket=bucket_name)
    except ClientError as err:
        if err.response["Error"]["Code"] not in ("404", "NoSuchBucket", "NotFound"):
            raise
        print(f"🪣 Creating bucket: {bucket_name}")
        create_kwargs = {"Bucket": bucket_name}
        # The location constraint is only sent for regions other than us-east-1.
        if region != "us-east-1":
            create_kwargs["CreateBucketConfiguration"] = {"LocationConstraint": region}
        s3_api.create_bucket(**create_kwargs)
        print(f"✅ Created new bucket: s3://{bucket_name}")
    else:
        print(f"✅ Bucket exists: s3://{bucket_name}")


def pull_from_hf(model_id: str, out_dir: str = "./_download") -> str:
    """
    Download model/artifacts from Hugging Face and bundle into model.tar.gz.

    The tarball contains only the artifact files at its root. It is written
    to ``<out_dir>/model.tar.gz``, i.e. inside the download directory, so the
    archive itself (including one left over from a previous run) and the
    ``.cache`` folder are excluded from its contents.

    Args:
        model_id: Hugging Face repo id to download.
        out_dir: Local directory for the download and the resulting tarball.

    Returns:
        Path of the created ``model.tar.gz``.
    """
    pathlib.Path(out_dir).mkdir(parents=True, exist_ok=True)
    print(f"📥 Downloading {model_id} from Hugging Face...")
    # `local_dir_use_symlinks` is intentionally not passed: it is deprecated in
    # recent huggingface_hub releases. Symlinked files are handled below via
    # `dereference=True` instead.
    local_dir = snapshot_download(
        repo_id=model_id,
        local_dir=out_dir,
        token=os.getenv("HF_TOKEN"),
    )

    tar_path = os.path.join(out_dir, "model.tar.gz")
    archive_file = pathlib.Path(tar_path).resolve()

    # Collect members BEFORE opening the archive for writing; otherwise the
    # half-written tarball would be listed and added to itself.
    # NOTE(review): `.cache` is assumed to be huggingface_hub's download
    # metadata folder inside `local_dir` — confirm for your hub version.
    members = sorted(
        item
        for item in pathlib.Path(local_dir).iterdir()
        if item.name != ".cache" and item.resolve() != archive_file
    )

    # dereference=True stores real file contents even if the hub client
    # materialized files as symlinks.
    with tarfile.open(tar_path, "w:gz", dereference=True) as tar:
        # Add *contents* of the folder, not the folder itself
        for item in members:
            tar.add(item, arcname=item.name)

    print(f"✅ Created SageMaker-compliant tarball: {tar_path}")
    return tar_path


def push_s3(bucket_name: str, file_path: str, region: "str | None" = None):
    """Upload model.tar.gz to the specified S3 bucket (auto-creates bucket).

    Args:
        bucket_name: Target bucket; created via ``ensure_bucket`` if missing.
        file_path: Local file to upload. Its base name becomes the object key.
        region: Region for the bucket check/creation and the upload client.
            When omitted, the default boto3 session region is used, falling
            back to ``us-east-1`` if the session has none configured.

    Returns:
        The ``s3://bucket/key`` URI of the uploaded object.
    """
    # Resolve the region once so the bucket check and the upload agree,
    # instead of always assuming us-east-1.
    target_region = region or boto3.session.Session().region_name or "us-east-1"
    ensure_bucket(bucket_name, target_region)
    s3 = boto3.resource("s3", region_name=target_region)
    key = os.path.basename(file_path)
    print(f"🚀 Uploading {file_path} → s3://{bucket_name}/{key}")
    s3.meta.client.upload_file(file_path, bucket_name, key)
    print(f"✅ Uploaded successfully: s3://{bucket_name}/{key}")
    return f"s3://{bucket_name}/{key}"

### Upload Adapter Artifacts

In [None]:
model_id = "Kronu/gemma-2-2b-lean-expert-1760-complete"
# S3 bucket names are globally unique, so a hard-coded name only works for the
# account that owns it. Use the SageMaker session's default bucket instead.
bucket_name = sess.default_bucket()

# pull adapter from HF
tarball = pull_from_hf(model_id)

# push tarball to S3
adapter_artifacts = push_s3(bucket_name, tarball)
print(f"Adapter artifacts location: {adapter_artifacts}")

### Create & Invoke Adapter IC

In [None]:
adapter_ic_name = f"ic-adapter-{base_inference_component_name}"

# The adapter IC references the base IC by name and points at the adapter
# tarball in S3.
adapter_ic_specification = {
    # associate with the base IC we created
    "BaseInferenceComponentName": base_inference_component_name,
    "Container": {"ArtifactUrl": adapter_artifacts},
}

sm_client.create_inference_component(
    InferenceComponentName=adapter_ic_name,
    EndpointName=endpoint_name,
    Specification=adapter_ic_specification,
)

sess.wait_for_inference_component(adapter_ic_name)

In [None]:
import json

payload = "Who is Rafael Nadal?"

# JSON request body: the prompt plus generation parameters.
request_body = json.dumps(
    {
        "inputs": payload,
        "parameters": {"max_new_tokens": 200},  # Adjust this value as needed
    }
)

response = smr_client.invoke_endpoint(
    EndpointName=endpoint_name,
    InferenceComponentName=adapter_ic_name,  # specify adapter IC name
    ContentType="application/json",
    Accept="application/json",
    Body=request_body,
)

# Decode the response body and keep only the generated text.
response_text = response["Body"].read().decode()
result = json.loads(response_text)['generated_text']
result

## Cleanup

Deleting the base model IC will automatically delete any adapter ICs associated with it. Here we delete the adapter IC explicitly first, and then the base IC.

In [None]:
# Delete the adapter inference component before touching the base component.
# NOTE(review): wait=True presumably blocks until deletion finishes — confirm in the SDK docs.
sess.delete_inference_component(adapter_ic_name, wait = True)

print(f'Adapter Component {adapter_ic_name} deleted.')

In [None]:
# Delete the base inference component that hosts the base model.
# NOTE(review): wait=True presumably blocks until deletion finishes — confirm in the SDK docs.
sess.delete_inference_component(base_inference_component_name, wait = True)

print(f'Base Component {base_inference_component_name} deleted.')

Clean up the running endpoint and its configuration.

In [None]:
# Tear down in dependency order: endpoint, then its configuration, then the model.
sess.delete_endpoint(endpoint_name)
print(f'Endpoint {endpoint_name} deleted.')

# The endpoint configuration has its own name (`endpoint_config_name`); passing
# the endpoint name here would target a configuration that does not exist.
sess.delete_endpoint_config(endpoint_config_name)
print(f'Endpoint Configuration {endpoint_config_name} deleted.')

sess.delete_model(model_name)
print(f'Model {model_name} deleted.')