# Direct Preference Optimization (DPO) Training with SageMaker

## Lab 4 - LLM Deployment

In this notebook, we deploy the fine-tuned LLM to a SageMaker real-time endpoint.

***

### Prerequisites

#### Setup and dependencies

In [None]:
import boto3
import os
from rich.pretty import pprint
from sagemaker.core.helper.session_helper import Session, get_execution_role

# Bootstrap session; used just below to look up the default bucket name.
sess = Session()
# Assign a bucket name here to use a specific bucket instead of the session default.
sagemaker_session_bucket = None

if sagemaker_session_bucket is None and sess is not None:
    # set to default bucket if a bucket name is not given
    sagemaker_session_bucket = sess.default_bucket()

try:
    role = get_execution_role()
except ValueError:
    # Fallback when get_execution_role() raises ValueError (presumably when the
    # notebook runs outside SageMaker - confirm): look up the ARN of the IAM role
    # literally named "sagemaker_execution_role".
    iam = boto3.client("iam")
    role = iam.get_role(RoleName="sagemaker_execution_role")["Role"]["Arn"]

s3_client = boto3.client("s3")
# Rebuild the session so it is explicitly bound to the bucket resolved above.
sess = Session(default_bucket=sagemaker_session_bucket)
# SageMaker control-plane client pinned to the session's region.
sm_client = boto3.client("sagemaker", region_name=sess.boto_region_name)
bucket_name = sess.default_bucket()
default_prefix = sess.default_bucket_prefix

# Echo the resolved role, bucket and region so the reader can verify them.
print(f"sagemaker role arn: {role}")
print(f"sagemaker bucket: {sess.default_bucket()}")
print(f"sagemaker session region: {sess.boto_region_name}")

Edit the model package group name and the model package version if needed.

In [None]:
import random
from sagemaker.core.resources import ModelPackage, ModelPackageGroup

# Identifier of the base model; every resource name below is derived from it.
base_model_id = "meta-textgeneration-llama-3-2-1b-instruct"
# A random numeric suffix keeps the names distinct between runs of this cell.
model_name = "{}-dpo-{}".format(base_model_id, random.randint(100, 100000))

# Model package group and version of the fine-tuned model to deploy.
model_package_group_name = base_model_id + "-dpo"
model_package_version = "1"

# Names of the hosting resources created further down, all keyed off model_name.
endpoint_config_name = model_name + "-config"
endpoint_name = model_name + "-endpoint"
ic_name = model_name + "-ic"

In [None]:
from sagemaker.core import s3

# Look up the model package group that holds the fine-tuned model versions.
model_package_group = ModelPackageGroup.get(model_package_group_name)

fine_tuned_model_package_group_arn = model_package_group.model_package_group_arn
print(f"Fine-tuned Model Package Group ARN: {fine_tuned_model_package_group_arn}")

# Derive the versioned model package ARN from the group ARN: replace the first
# "model-package-group" with "model-package" and append "/<version>".
# The replacement is done outside the f-string because reusing double quotes
# inside a double-quoted f-string is a SyntaxError before Python 3.12.
model_package_arn_prefix = fine_tuned_model_package_group_arn.replace(
    "model-package-group", "model-package", 1
)
fine_tuned_model_package_arn = f"{model_package_arn_prefix}/{model_package_version}"
print(f"Fine-tuned Model Package ARN: {fine_tuned_model_package_arn}")

model_package = ModelPackage.get(fine_tuned_model_package_arn)

# get the merged model artifact and deploy it
# The merged model is read from "<artifact uri>/checkpoints/hf_merged/"; the
# trailing slash is appended explicitly after joining the path parts.
model_artifact_s3_uri = (
    model_package.inference_specification.containers[0]
    .model_data_source.s3_data_source.s3_uri
)
merged_model_s3_uri = s3.s3_path_join(model_artifact_s3_uri, "checkpoints", "hf_merged") + "/"
print(merged_model_s3_uri)

***

### Utility functions

Utility functions to check the creation status of endpoints and inference components

***

### Create Endpoint Configuration

Define inference configuration

In [None]:
# Hosting hardware for the endpoint.
instance_type = "ml.g5.12xlarge"
instance_count = 1
number_of_gpu = 4

# Timeout in seconds; the endpoint configuration cell passes it as the model
# data download timeout.
health_check_timeout = 700

In [None]:
import random
from sagemaker.core.resources import Endpoint, EndpointConfig
from sagemaker.core.shapes import ProductionVariant

# Register the endpoint configuration: a single "AllTraffic" variant sized from
# the inference settings defined in the previous cell.
print(f"Creating EndpointConfig: {endpoint_config_name}")
endpoint_config = EndpointConfig.create(
    endpoint_config_name=endpoint_config_name,
    execution_role_arn=role,
    production_variants=[
        ProductionVariant(
            variant_name="AllTraffic",
            instance_type=instance_type,
            # Honour the configured instance_count instead of a hard-coded 1,
            # so editing the inference settings cell actually takes effect.
            initial_instance_count=instance_count,
            model_data_download_timeout_in_seconds=health_check_timeout,
            routing_config={"routing_strategy": "LEAST_OUTSTANDING_REQUESTS"},
        )
    ],
)

### Create Endpoint

A SageMaker Endpoint is a fully managed, always-on HTTPS API that hosts your deployed model and serves real-time inference requests.

In [None]:
# Create the endpoint from the endpoint configuration named endpoint_config_name.
print(f"Creating Endpoint: {endpoint_name}")
endpoint = Endpoint.create(
    endpoint_name=endpoint_name,
    endpoint_config_name=endpoint_config_name
)
# Wait for the endpoint to reach the "InService" status before moving on; the
# inference component created later targets this endpoint by name.
endpoint.wait_for_status("InService")
print(f"Endpoint {endpoint_name} is InService")

### Create Model from Model Package

Get the image URI

In [None]:
# Region of the active session; the registry host name below is region-specific.
region = sess.boto_region_name
# Tag of the djl-inference container image to serve the model with.
CONTAINER_VERSION = "0.36.0-lmi18.0.0-cu128"

inference_image = "763104351884.dkr.ecr.{0}.amazonaws.com/djl-inference:{1}".format(
    region, CONTAINER_VERSION
)

In [None]:
# Environment variables handed to the serving container (all values are strings).
lmi_env = dict(
    SERVING_FAIL_FAST="true",
    OPTION_ASYNC_MODE="true",
    OPTION_ROLLING_BATCH="disable",
    OPTION_MAX_MODEL_LEN="16384",
    OPTION_TENSOR_PARALLEL_DEGREE="max",
    OPTION_ENTRYPOINT="djl_python.lmi_vllm.vllm_async_service",
    OPTION_TRUST_REMOTE_CODE="true",
)

In [None]:
from sagemaker.core.resources import Model
from sagemaker.core.resources import TrainingJob
from sagemaker.core.shapes import ContainerDefinition, ModelDataSource, S3ModelDataSource

# Where the merged model lives: an uncompressed set of objects under an S3 prefix.
merged_model_source = ModelDataSource(
    s3_data_source=S3ModelDataSource(
        s3_uri=merged_model_s3_uri,
        s3_data_type="S3Prefix",
        compression_type="None",
    )
)

# Serving container: inference image + model artifacts + container environment.
serving_container = ContainerDefinition(
    image=inference_image,
    model_data_source=merged_model_source,
    environment=lmi_env,
)

# Register the SageMaker Model that the inference component will reference by name.
fine_tuned_model = Model.create(
    model_name=model_name,
    primary_container=serving_container,
    execution_role_arn=role,
)

pprint(fine_tuned_model)

### Create Inference Component

In [None]:
from sagemaker.core.resources import InferenceComponent
from sagemaker.core.shapes import (
    InferenceComponentSpecification,
    InferenceComponentContainerSpecification,
    InferenceComponentComputeResourceRequirements,
    InferenceComponentRuntimeConfig,
)

# Create the inference component that places the registered model on the
# endpoint's "AllTraffic" variant.
inference_component = InferenceComponent.create(
    inference_component_name=ic_name,
    endpoint_name=endpoint_name,
    variant_name="AllTraffic",
    specification=InferenceComponentSpecification(
        model_name=model_name,
        # Resources reserved for each copy of the model.
        compute_resource_requirements=InferenceComponentComputeResourceRequirements(
            min_memory_required_in_mb=10240,
            number_of_accelerator_devices_required=1,
        )
    ),
    runtime_config=InferenceComponentRuntimeConfig(
        copy_count=1
    ),
    region=region
)

print(f"InferenceComponent created: {inference_component.inference_component_name}")
print(f"Endpoint ARN: {endpoint.endpoint_arn}")
# Wait for the inference component to reach the "InService" status.
inference_component.wait_for_status("InService")
# The resource that just became ready is the inference component, not the endpoint.
print(f"Inference component {ic_name} is InService")

***

### Test endpoint

### Utility interface for interacting with the model

Utility function to invoke the model and parse the answer

In [None]:
import boto3
import json
import io

# Chat prompt layout: a fixed system turn, the user turn ({0} is the user's
# prompt) and an opened assistant turn for the model to complete.
_CHAT_PROMPT_TEMPLATE = (
    "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n"
    "You are a helpful assistant.<|eot_id|>\n"
    "<|start_header_id|>user<|end_header_id|>\n\n"
    "{0}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
)


def _build_request_body(prompt, stream):
    """Serialize the JSON request body; the "stream" key is only sent when streaming."""
    request = {
        "inputs": _CHAT_PROMPT_TEMPLATE.format(prompt),
        "parameters": {"max_new_tokens": 512, "temperature": 0.1, "top_p": 0.9},
    }
    if stream:
        request["stream"] = stream
    return json.dumps(request)


def execute_inference(prompt, endpoint_name, inference_component_name, stream=True):
    """Send ``prompt`` to an inference component hosted on a SageMaker endpoint.

    Args:
        prompt: User message; it is inserted into the chat prompt template.
        endpoint_name: Name of the endpoint to invoke.
        inference_component_name: Name of the inference component on that endpoint.
        stream: When truthy, invoke with a response stream; otherwise do a
            regular invocation.

    Returns:
        The ``Body`` of the streaming response when ``stream`` is truthy,
        otherwise the response body read in full and decoded as UTF-8 text.
    """
    runtime = boto3.client("sagemaker-runtime")

    # Both invocation styles take exactly the same request arguments.
    request_args = dict(
        EndpointName=endpoint_name,
        InferenceComponentName=inference_component_name,
        CustomAttributes='accept_eula=true',
        Body=_build_request_body(prompt, stream),
        ContentType="application/json",
    )

    if not stream:
        response = runtime.invoke_endpoint(**request_args)
        return response["Body"].read().decode("utf8")

    response = runtime.invoke_endpoint_with_response_stream(**request_args)
    return response['Body']

def build_div(text):
    """Wrap ``text`` in a light-blue, rounded HTML ``<div>`` snippet.

    ``text`` is inserted as-is (no HTML escaping). The returned string starts
    and ends with a newline.
    """
    div_style = "background-color: lightblue; padding: 10px; margin: 10px; border-radius: 5px;"
    return f'\n<div style="{div_style}">\n    {text}\n</div>\n'

Utility class to parse streaming responses

In [None]:
class LineIterator:
    """Iterate over newline-delimited lines of a chunked event stream.

    ``stream`` is an iterable of event dicts. Events carrying a
    ``'PayloadPart'`` key contribute ``event['PayloadPart']['Bytes']`` to an
    internal buffer; a line may be split across several events. Each complete
    line is yielded as ``bytes`` without its trailing newline. Events without
    ``'PayloadPart'`` are reported on stdout and skipped.
    """

    def __init__(self, stream):
        self.byte_iterator = iter(stream)
        self.buffer = io.BytesIO()
        # Offset of the first byte in the buffer that has not been returned yet.
        self.read_pos = 0

    def __iter__(self):
        return self

    def __next__(self):
        while True:
            self.buffer.seek(self.read_pos)
            line = self.buffer.readline()
            if line.endswith(b'\n'):
                self.read_pos += len(line)
                return line[:-1]
            # No complete line buffered yet: pull the next event from the stream.
            try:
                chunk = next(self.byte_iterator)
            except StopIteration:
                if line:
                    # The stream ended without a final newline. Return the
                    # remaining bytes as the last line; retrying here would
                    # loop forever because no more data can arrive.
                    self.read_pos += len(line)
                    return line
                raise
            if 'PayloadPart' not in chunk:
                # chunk is a dict, so format it rather than concatenating it
                # to a str (which raises TypeError).
                print(f'Unknown event type: {chunk!r}')
                continue
            # Always append at the end; the read position is tracked separately.
            self.buffer.seek(0, io.SEEK_END)
            self.buffer.write(chunk['PayloadPart']['Bytes'])


def print_stream(stream):
    """Print the generated text of a streaming response as it arrives.

    Each non-empty line produced by ``LineIterator(stream)`` is parsed as JSON
    and the value of ``resp["token"]["text"]`` is printed without a newline.
    Lines that are not valid JSON, or that do not have that shape, are printed
    raw instead.
    """
    for line in LineIterator(stream):
        if line == b'':
            continue
        try:
            resp = json.loads(line)
            print(resp["token"].get("text"), end='')
        except (ValueError, KeyError, TypeError, AttributeError):
            # Only parse/shape problems fall back to the raw line. A bare
            # ``except:`` would also swallow KeyboardInterrupt, making a long
            # generation impossible to interrupt.
            print(line)

In [None]:
# Sample user question sent to the deployed model in the next cell.
prompt = (
    "I've always struggled with math - "
    "can you explain how fractals work in a way that's easy to understand?"
)

In [None]:
# Invoke the inference component in streaming mode and print the text as it arrives.
stream = execute_inference(
    prompt=prompt,
    endpoint_name=endpoint_name,
    inference_component_name=ic_name,
    stream=True,
)
print_stream(stream)

***

### Delete resources

In [None]:
import boto3

# SageMaker control-plane client for the cleanup steps.
sm_client = boto3.client("sagemaker")

# model_name, endpoint_config_name, endpoint_name and ic_name are intentionally
# NOT reassigned here. They still hold the "-dpo-<random>" names of the
# resources this notebook created. Overwriting them with "-sft" names would
# point at resources this notebook never created and would break any earlier
# cell (e.g. the inference test) that is re-run after this one.

In [None]:
# Delete inference component
inference_component.delete()
# Wait until the inference component is actually gone before the following
# cells delete the endpoint it runs on, so "Run All" does not race the deletion.
inference_component.wait_for_delete()

In [None]:
# Delete the SageMaker Model resource created by Model.create earlier in this notebook.
fine_tuned_model.delete()

In [None]:
# Delete endpoint (optional - if you want to remove the endpoint too).
# Run this after the inference component has been deleted.
endpoint.delete()

In [None]:
# Delete endpoint config (optional); this is the configuration created by
# EndpointConfig.create earlier in this notebook.
endpoint_config.delete()