In [None]:
from sagemaker.hyperpod import list_clusters, set_cluster_context
# list the clusters in the us-east-2 region (the result is shown as the cell output)
list_clusters(region='us-east-2')

In [None]:
# choose the HP cluster: replace the placeholder with the name of the target cluster
cluster_name = '<my-cluster>'
set_cluster_context(cluster_name, region='us-east-2')

In [None]:
from sagemaker.hyperpod.inference.config.hp_endpoint_config import FsxStorage, ModelSourceConfig, TlsConfig, EnvironmentVariables, ModelInvocationPort, ModelVolumeMount, Resources, Worker
from sagemaker.hyperpod.inference.hp_endpoint import HPEndpoint
import yaml
import time

In [None]:
# TLS settings: the S3 URI is passed as the certificate output location.
tls_config = TlsConfig(tls_certificate_output_s3_uri='s3://<my-tls-bucket-name>')

# Model source: a folder on an FSx file system, identified by its file system id.
fsx_storage = FsxStorage(file_system_id='<my-fs-id>')
model_source_config = ModelSourceConfig(
    model_source_type='fsx',
    model_location="<my-model-folder-in-fsx>",
    fsx_storage=fsx_storage,
)

# Container environment as name -> value; dict insertion order fixes the list order.
container_env = {
    "HF_MODEL_ID": "/opt/ml/model",
    "SAGEMAKER_PROGRAM": "inference.py",
    "SAGEMAKER_SUBMIT_DIRECTORY": "/opt/ml/model/code",
    "MODEL_CACHE_ROOT": "/opt/ml/model",
    "SAGEMAKER_ENV": "1",
}
environment_variables = [
    EnvironmentVariables(name=env_name, value=env_value)
    for env_name, env_value in container_env.items()
]

# Resource requests/limits for the worker container (one GPU requested and capped).
worker_resources = Resources(
    requests={"cpu": "30000m", "nvidia.com/gpu": 1, "memory": "100Gi"},
    limits={"nvidia.com/gpu": 1},
)

# Worker definition: container image, model volume mount, invocation port,
# resources and the environment built above.
worker = Worker(
    image='763104351884.dkr.ecr.us-east-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.4.0-tgi2.3.1-gpu-py311-cu124-ubuntu22.04-v2.0',
    model_volume_mount=ModelVolumeMount(name='model-weights'),
    model_invocation_port=ModelInvocationPort(container_port=8080),
    resources=worker_resources,
    environment_variables=environment_variables,
)

In [None]:
# Collect the endpoint settings first, then build the HPEndpoint from them.
# tls_config, model_source_config and worker come from the previous cell.
endpoint_settings = {
    'endpoint_name': '<my-endpoint-name>',
    'instance_type': 'ml.g5.8xlarge',
    'model_name': 'deepseek15b-fsx-test-pysdk',
    'tls_config': tls_config,
    'model_source_config': model_source_config,
    'worker': worker,
}
fsx_endpoint = HPEndpoint(**endpoint_settings)

In [None]:
# create the endpoint; its deployment status is polled in the next cell
fsx_endpoint.create()

In [None]:
# poll status until the endpoint is available, the deployment fails, or we time out
timeout = 600  # give up after 600 seconds
interval = 15  # poll every 15 seconds

# Measure wall-clock time so that slow refresh() calls count against the
# timeout as well, not only the time spent sleeping between polls.
deadline = time.monotonic() + timeout

while time.monotonic() < deadline:
    # use refresh to fetch latest status
    fsx_endpoint.refresh()

    print('Refreshing instance status...')

    try:
        # deployment status will be available immediately
        deployment_status = fsx_endpoint.status.deploymentStatus.deploymentObjectOverallState
        if deployment_status == 'DeploymentFailed':
            print('Deployment failed!')
            break

        # endpoint status becomes available from refresh() at some point
        endpoint_status = fsx_endpoint.status.endpoints.sagemaker.state
        if endpoint_status == 'CreationCompleted':
            print('Endpoint is available!')
            break
    except Exception as exc:
        # Parts of the status are not populated on early polls, so keep polling.
        # Catch Exception rather than everything so that interrupting the kernel
        # still works, and report the error instead of hiding it.
        print(f'Status not available yet ({exc!r}); retrying...')

    time.sleep(interval)
else:
    # Reached only when the loop ended without a break, i.e. neither a failed
    # deployment nor a completed endpoint was seen before the deadline.
    print('Endpoint creation timed out!')

In [None]:
# print endpoint in yaml
def print_yaml(endpoint):
    """Print ``endpoint`` as YAML, leaving out fields whose value is None.

    ``endpoint`` must provide ``model_dump(exclude_none=True)``; the dumped
    data is rendered with ``yaml.dump`` and written to stdout.
    """
    populated_fields = endpoint.model_dump(exclude_none=True)
    rendered = yaml.dump(populated_fields)
    print(rendered)

In [None]:
# list all endpoints and print the first one as YAML
endpoint_list = HPEndpoint.list()
if endpoint_list:
    print_yaml(endpoint_list[0])
else:
    # Nothing to index into (e.g. the deployment above failed or was deleted);
    # say so instead of raising IndexError on endpoint_list[0].
    print('No endpoints found.')

In [None]:
# look up the endpoint by name; the returned object is used by the invoke and delete cells below
endpoint = HPEndpoint.get(name='<my-endpoint-name>')

In [None]:
# request payload: a JSON document with a single "inputs" prompt
data = '{"inputs": "What is the capital of Japan?"}'

# send the request, then read the response body (shown as the cell output)
response = endpoint.invoke(body=data)
response.body.read()

In [None]:
# delete the endpoint that was looked up with HPEndpoint.get() above
endpoint.delete()