# Optimizing a custom model using the SageMaker inference optimization toolkit
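In this notebook, we download the Meta Llama 3 8B model from Hugging Face, quantize it with AWQ using the SageMaker inference optimization toolkit, and then deploy it to an Amazon SageMaker endpoint. The model repository is gated on Hugging Face, so accept its license there and authenticate (for example, via the `HF_TOKEN` environment variable) before running the download step.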

In [None]:
%pip install sagemaker --upgrade --quiet --no-warn-conflicts

In [None]:
import json
import shutil
from pathlib import Path

import boto3
import huggingface_hub
import sagemaker

In [None]:
# Handles shared by every later cell.
role = sagemaker.get_execution_role()  # IAM execution role used by the optimization job and the endpoint
sess = sagemaker.session.Session()  # SageMaker session for interacting with different AWS APIs
bucket = sess.default_bucket()  # default S3 bucket that houses artifacts
region = sess.boto_region_name  # public accessor (instead of the private `_region_name`)

# Build the low-level clients from the session's boto session so they use the same
# region and credentials as `sess`, rather than whatever the default boto3 config resolves to.
sm_client = sess.boto_session.client("sagemaker")  # client to interact with SageMaker
smr_client = sess.boto_session.client("sagemaker-runtime")  # client to interact with SageMaker Endpoints

print(f"sagemaker role arn: {role}")
print(f"sagemaker bucket: {bucket}")
print(f"sagemaker session region: {region}")
print(f"boto3 version: {boto3.__version__}")
print(f"sagemaker version: {sagemaker.__version__}")

## 2. Quantize (using the inference optimization toolkit) and deploy a model to an Amazon SageMaker endpoint

### 2.1 Run optimization job

In [None]:
# Gated repository: accept the license on the Hugging Face Hub and authenticate
# (e.g. `huggingface-cli login` or the HF_TOKEN environment variable) before running this cell.
model_id = "meta-llama/Meta-Llama-3-8B"

hf_local_download_dir = Path.cwd() / "model_repo"
hf_local_download_dir.mkdir(exist_ok=True)

# Only the Hugging Face-format files are uploaded to S3, so skip the `original/`
# checkpoint folder and git metadata instead of downloading them and deleting them afterwards.
huggingface_hub.snapshot_download(
    repo_id=model_id,
    revision="main",
    local_dir=hf_local_download_dir,
    ignore_patterns=["original/*", ".gitattributes"],
)

In [None]:
# Remove local files that must not be uploaded alongside the model weights: notebook
# checkpoints, the `.cache` metadata folder huggingface_hub may create inside `local_dir`,
# git metadata, and the `original/` checkpoint folder.
# Paths are resolved against `hf_local_download_dir` (the directory that is uploaded next)
# rather than a hard-coded relative "model_repo", so they stay correct even if the
# kernel's working directory has changed.
for leftover_name in (".ipynb_checkpoints", ".cache", ".gitattributes", "original"):
    leftover = hf_local_download_dir / leftover_name
    if leftover.is_dir() and not leftover.is_symlink():
        shutil.rmtree(leftover)
    elif leftover.exists() or leftover.is_symlink():
        leftover.unlink()

In [None]:
# Upload the cleaned local snapshot to the session bucket under the `inference-model` prefix;
# `upload_data` returns the S3 URI of what was uploaded.
local_model_path = hf_local_download_dir.as_posix()
model_uri = sess.upload_data(
    local_model_path,
    key_prefix="inference-model",
    bucket=bucket,
)

In [None]:
# The optimization job must point at the prefix holding the uncompressed model artifacts,
# written as a "directory" URI with a trailing slash. Normalizing (rather than blindly
# appending "/") keeps this cell idempotent: re-running it no longer produces "//".
model_uri = model_uri.rstrip("/") + "/"
model_uri

In [None]:
!aws s3 ls {model_uri} #verify model artifacts

In [None]:
# DJL LMI container image; the same image runs the quantization job and serves the endpoint.
LMI_FRAMEWORK = "djl-lmi"
LMI_VERSION = "0.29.0"

serving_image = sagemaker.image_uris.retrieve(
    version=LMI_VERSION,
    region=region,
    framework=LMI_FRAMEWORK,
)
print(f"Inference Image: {serving_image}")

In [None]:
# Instance type used both as the optimization job's deployment target and for the endpoint.
instance_type = "ml.g5.12xlarge"

# Shared prefix for the quantized artifacts in S3 and for the timestamped resource name.
prefix = "inference-model-awq"
output_location = "s3://{}/{}/".format(bucket, prefix)
model_name = sagemaker.utils.name_from_base(prefix)

In [None]:
job_name = model_name
job_timeout = 7200  # seconds; applied to every limit in the stopping condition

# Input: the uncompressed Hugging Face artifacts uploaded to S3.
# AcceptEula must be True to accept the end-user license the model requires;
# review the Llama 3 license before running this cell.
model_source = {
    "S3": {
        "S3Uri": model_uri,
        "ModelAccessConfig": {"AcceptEula": True},
    }
}

# Single optimization step: AWQ quantization executed inside the LMI serving image.
quantization_step = {
    "ModelQuantizationConfig": {
        "Image": serving_image,
        "OverrideEnvironment": {"OPTION_QUANTIZE": "awq"},
    }
}

stopping_condition = {
    limit_name: job_timeout
    for limit_name in ("MaxRuntimeInSeconds", "MaxWaitTimeInSeconds", "MaxPendingTimeInSeconds")
}

response = sm_client.create_optimization_job(
    OptimizationJobName=job_name,
    RoleArn=role,
    ModelSource=model_source,
    DeploymentInstanceType=instance_type,
    OptimizationEnvironment={},
    OptimizationConfigs=[quantization_step],
    OutputConfig={"S3OutputLocation": output_location},
    StoppingCondition=stopping_condition,
)
response

In [None]:
# Wait for the optimization job to finish (SageMaker Python SDK helper);
# whatever the helper returns is shown as the cell output.
optimization_job_result = sess.wait_for_optimization_job(job_name)
optimization_job_result

### 2.2 Endpoint Deployment

In [None]:
# Container environment for the endpoint: load the AWQ-quantized weights the optimization
# job wrote to `output_location` and serve them with the lmi-dist rolling-batch backend.
max_sequence_tokens = "2048"
env = dict(
    HF_MODEL_ID=output_location,
    OPTION_ROLLING_BATCH="lmi-dist",
    OPTION_MAX_MODEL_LEN=max_sequence_tokens,
    OPTION_MAX_ROLLING_BATCH_PREFILL_TOKENS=max_sequence_tokens,
    OPTION_QUANTIZE="awq",
)

In [None]:
# Register a SageMaker model that runs the LMI image with the environment defined above.
primary_container = {"Image": serving_image, "Environment": env}

create_model_response = sm_client.create_model(
    ModelName=model_name,
    ExecutionRoleArn=role,
    PrimaryContainer=primary_container,
)
model_arn = create_model_response["ModelArn"]
print(f"Created Model: {model_arn}")

LLM inference containers can take longer to start than containers for smaller models, mainly because downloading and loading the model weights takes longer. Increase the relevant timeouts from their default values accordingly (here, the container start-up health check timeout). Each endpoint deployment takes a few minutes.

In [None]:
endpoint_config_name = model_name
health_check_timeout = 900  # seconds; generous because the LLM container is slow to start

production_variant = {
    "VariantName": "variant1",
    "ModelName": model_name,
    "InstanceType": instance_type,
    "InitialInstanceCount": 1,
    "ContainerStartupHealthCheckTimeoutInSeconds": health_check_timeout,
    # Prefer instances with the fewest outstanding requests instead of the default routing.
    "RoutingConfig": {"RoutingStrategy": "LEAST_OUTSTANDING_REQUESTS"},
}

endpoint_config_response = sm_client.create_endpoint_config(
    EndpointConfigName=endpoint_config_name,
    ProductionVariants=[production_variant],
)
endpoint_config_response

In [None]:
# Create the endpoint from the endpoint configuration; it reuses the model's name.
endpoint_name = model_name

create_endpoint_response = sm_client.create_endpoint(
    EndpointName=endpoint_name,
    EndpointConfigName=endpoint_config_name,
)
endpoint_arn = create_endpoint_response["EndpointArn"]
print(f"Created Endpoint: {endpoint_arn}")

In [None]:
# Block on the SageMaker Python SDK helper until the endpoint is ready;
# the helper's return value is shown as the cell output.
endpoint_wait_result = sess.wait_for_endpoint(endpoint_name)
endpoint_wait_result

### 2.3 Endpoint invocation

Let's invoke our endpoint and get a sample response.

In [None]:
# Prompt in a simple "User: / Jarvis:" dialogue format; the model continues after "Jarvis:".
prompt = """You are an helpful Assistant, called Jarvis. Knowing everything about AWS.
User: Can you tell me something about Amazon SageMaker?
Jarvis:"""

# Generation parameters: cap the reply length and use a low temperature for focused output.
params = {"max_new_tokens": 256, "temperature": 0.1}
payload = {"inputs": prompt, "parameters": params}

response_model = smr_client.invoke_endpoint(
    EndpointName=endpoint_name,
    ContentType="application/json",
    Body=json.dumps(payload),
)

# The container replies with a JSON object whose "generated_text" field holds the completion.
response_text = response_model["Body"].read().decode("utf8")
assistant = json.loads(response_text)["generated_text"]
print(assistant)

### 2.4 Clean Up Endpoint

In [None]:
# Tear down the serving resources: endpoint first, then its config, then the model.
# S3 artifacts (the uploaded model and the quantized output) are not removed by this cell.
for delete_resource, resource_name in (
    (sess.delete_endpoint, endpoint_name),
    (sess.delete_endpoint_config, endpoint_config_name),
    (sess.delete_model, model_name),
):
    delete_resource(resource_name)