# Enforce Responsible AI with the Llama Guard safety model and Llama2-7b on the same Amazon SageMaker endpoint, using Inference Components for cost-effective and safe deployment

## Resources
- [Llama Guard model card](https://huggingface.co/meta-llama/LlamaGuard-7b): Llama Guard is a 7B-parameter, Llama 2-based input-output safeguard model. It can classify content in both LLM inputs (prompt classification) and LLM responses (response classification).
- [Llama 2: Open Foundation and Fine-Tuned Chat Models](https://arxiv.org/pdf/2307.09288.pdf)
- [SageMaker Inference Component Concepts - 4 min YouTube Video](https://youtu.be/6xENDvgnMCs?t=1230): With Inference Components you can deploy one or more foundation models (FMs) on the same SageMaker endpoint and control how many accelerators and how much memory are reserved for each FM. This helps save costs as your requirements scale.

#### Pre-requisites
- Ensure you have enough quota for 2 SageMaker `ml.g5.12xlarge` instances. Use the Service Quotas deep link [here](https://us-east-1.console.aws.amazon.com/servicequotas/home/services/sagemaker/quotas/L-65C4BD00) and switch to the appropriate Region.

# License agreement
 - View license information https://huggingface.co/meta-llama before using the models in this notebook.
 - This notebook is a sample notebook and is not intended for production use. Please refer to the license at https://github.com/aws/mit-0.

## Step 1: Setup

In [None]:
%pip install --quiet boto3==1.33.9 sagemaker==2.214.3 

In [None]:
import sagemaker
import boto3
import pickle
import json
import pprint
import os
print(f"boto3 version: {boto3.__version__}")
print(f"sagemaker version: {sagemaker.__version__}")

In [None]:
try:
    # If running on a SageMaker notebook
    role = sagemaker.get_execution_role()  # execution role for the endpoint
    sess = sagemaker.session.Session()  # SageMaker session for interacting with different AWS APIs
    bucket = sess.default_bucket()  # bucket to house artifacts
    region = sess.boto_region_name  # public accessor instead of the private _region_name
except Exception:
    # If running outside of SageMaker notebooks
    # Ignore Error "Couldn't call 'get_role' to get Role ARN from role name to get Role path."
    role = os.getenv("EXECUTION_ROLE_ARN")  # Add your role
    sess = sagemaker.session.Session()
    bucket = os.getenv("BUCKET")  # Add an appropriate bucket that the role above has permissions to
    region = sess.boto_region_name
    
sm_client = boto3.client("sagemaker")
smr_client= boto3.client("sagemaker-runtime")
cloudwatch_client = boto3.client("cloudwatch")

## Step 2: Create a model, endpoint configuration and endpoint

Retrieve the ECR image URI for the Large Model Inference (LMI) container, which already ships with a large language model serving framework (in this case, vLLM). The image URI is looked up based on the framework name, AWS region, and framework version. This allows us to dynamically select the right Docker image for our environment.

`sagemaker.image_uris.retrieve` generates ECR image URIs for pre-built SageMaker Docker images. See the [available Large Model Inference DLCs here](https://github.com/aws/deep-learning-containers/blob/master/available_images.md#large-model-inference-containers).
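
As a quick sanity check on the value returned by the lookup, the sketch below splits an ECR image URI into its parts so that the region and tag can be compared with what was requested. The helper name `parse_ecr_image_uri` and the example URI (account ID and tag included) are illustrative placeholders, not values produced by this notebook.

```python
import re
from typing import Dict

# General shape of an ECR image URI:
#   <account-id>.dkr.ecr.<region>.amazonaws.com/<repository>:<tag>
_ECR_URI = re.compile(
    r"^(?P<account>\d{12})\.dkr\.ecr\.(?P<region>[a-z0-9-]+)\.amazonaws\.com(?:\.cn)?/"
    r"(?P<repository>[^:]+):(?P<tag>.+)$"
)


def parse_ecr_image_uri(uri: str) -> Dict[str, str]:
    """Split an ECR image URI into account, region, repository and tag."""
    match = _ECR_URI.match(uri)
    if match is None:
        raise ValueError(f"Not a recognizable ECR image URI: {uri!r}")
    return match.groupdict()


# Hypothetical URI for illustration only; the account ID and tag are placeholders.
example_uri = "123456789012.dkr.ecr.us-east-1.amazonaws.com/djl-inference:0.27.0-example"
parts = parse_ecr_image_uri(example_uri)
print(parts)
```

In the notebook, the same helper could be applied to `inference_image_uri` to confirm that its region matches `region`.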

In [None]:
version = "0.27.0"
inference_image_uri = sagemaker.image_uris.retrieve(
    "djl-deepspeed", region=region, version=version
)
print(f"Image going to be used is ----> {inference_image_uri}")

In [None]:
model_name_guard_llm = sagemaker.utils.name_from_base("LlamaGuard-7B-AWQ")

In [None]:
print(model_name_guard_llm)

In [None]:
env_guardllm = {
    "HUGGINGFACE_HUB_CACHE": "/tmp",
    "TRANSFORMERS_CACHE": "/tmp",
    "SERVING_LOAD_MODELS": "test::Python=/opt/ml/model",
    "OPTION_MODEL_ID": "TheBloke/LlamaGuard-7B-AWQ",
    "OPTION_ROLLING_BATCH": "vllm",
#    "OPTION_TENSOR_PARALLEL_DEGREE": "max",
    "OPTION_TENSOR_PARALLEL_DEGREE": "1",
    "OPTION_MAX_ROLLING_BATCH_SIZE": "32",
    "OPTION_QUANTIZE": "awq",
    "OPTION_DTYPE": "auto",
}

create_model_response = sm_client.create_model(
    ModelName = model_name_guard_llm,
    ExecutionRoleArn = role,
    PrimaryContainer = {
        "Image": inference_image_uri, 
        "Environment": env_guardllm,
    },
)
model_arn = create_model_response["ModelArn"]

print(f"Created Model for Safety LLM: {model_arn}")

In [None]:
model_name_main_llm = sagemaker.utils.name_from_base("Llama2-7B-main-llm")

In [None]:
print(model_name_main_llm)

In [None]:
env_mainllm = {"HUGGINGFACE_HUB_CACHE": "/tmp",
                  "TRANSFORMERS_CACHE": "/tmp",
                  "SERVING_LOAD_MODELS": "test::Python=/opt/ml/model",
                  "OPTION_MODEL_ID": "TheBloke/Llama-2-7B-Chat-fp16",
                  "OPTION_TRUST_REMOTE_CODE": "true",
                  "OPTION_TENSOR_PARALLEL_DEGREE": "max",
                  "OPTION_ROLLING_BATCH": "vllm",
                  "OPTION_MAX_ROLLING_BATCH_SIZE": "32",
                  "OPTION_DTYPE":"fp16"
                 }

create_model_response = sm_client.create_model(
    ModelName = model_name_main_llm,
    ExecutionRoleArn = role,
    PrimaryContainer = {
        "Image": inference_image_uri, 
        "Environment": env_mainllm,
    },
)
model_arn = create_model_response["ModelArn"]

print(f"Created Model for main LLM: {model_arn}")

The cells below create an endpoint configuration and then deploy a SageMaker endpoint for real-time inference. The `instance_type` defines the machine instance for the endpoint. The endpoint name is generated programmatically from a base name. Generous model download and container startup timeouts are specified, because the models take time to initialize on the GPU instance.

In [None]:
endpoint_name = sagemaker.utils.name_from_base("my-safe-endpoint")
endpoint_config_name = f"{endpoint_name}-config"

In [None]:
# Set variant name and instance type for hosting
variant_name = "AllTraffic"
instance_type = "ml.g5.12xlarge"
model_data_download_timeout_in_seconds = 1200
container_startup_health_check_timeout_in_seconds = 1200

initial_instance_count = 1
max_instance_count = 1 # will use for managed instance scaling later
print(f"Initial instance count: {initial_instance_count}")
print(f"Max instance count: {max_instance_count}")

In [None]:
sm_client.create_endpoint_config(
    EndpointConfigName = endpoint_config_name,
    ExecutionRoleArn = role,
    ProductionVariants = [
        {
            "VariantName": variant_name,
            # Notice we do not yet specify the model name at this stage because we want to optimize with Inference Components later. 
            # For now just deploy the endpoint with the endpoint config 
            # "ModelName": model_name, 
            "InstanceType": instance_type,
            "InitialInstanceCount": initial_instance_count,
            "ModelDataDownloadTimeoutInSeconds": model_data_download_timeout_in_seconds,
            "ContainerStartupHealthCheckTimeoutInSeconds": container_startup_health_check_timeout_in_seconds,
            "ManagedInstanceScaling": {
                "Status": "ENABLED",
                "MinInstanceCount": initial_instance_count,
                "MaxInstanceCount": max_instance_count,
            },
            "RoutingConfig": {"RoutingStrategy": "LEAST_OUTSTANDING_REQUESTS"},
        }
    ]
)

In [None]:
# Endpoint name already created earlier
# endpoint_name = sagemaker.utils.name_from_base("my-safe-endpoint")
print(endpoint_name)

In [None]:
create_endpoint_response = sm_client.create_endpoint(
    EndpointName = endpoint_name, EndpointConfigName = endpoint_config_name
)
print(f"Created Endpoint: {create_endpoint_response['EndpointArn']}")

### This step can take ~ 10 min or longer so please be patient

In [None]:
#
# Using helper function to wait for the endpoint to be ready
#
sess.wait_for_endpoint(endpoint_name)
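
For readers who want to see what such a wait amounts to, here is a minimal polling sketch. It is not the SageMaker SDK implementation: `wait_for_status` is a hypothetical helper that takes any zero-argument `describe` callable, so it can be exercised without AWS access. The status names in the defaults are assumptions based on the `InService` and `Failed` states the service reports.

```python
import time
from typing import Any, Callable, Dict, Iterable


def wait_for_status(
    describe: Callable[[], Dict[str, Any]],
    status_key: str,
    success: Iterable[str] = ("InService",),
    failure: Iterable[str] = ("Failed",),
    poll_seconds: float = 30.0,
    timeout_seconds: float = 1800.0,
) -> Dict[str, Any]:
    """Poll `describe()` until `status_key` reaches a success or failure state."""
    success_states, failure_states = set(success), set(failure)
    deadline = time.monotonic() + timeout_seconds
    while True:
        description = describe()
        status = description[status_key]
        if status in success_states:
            return description
        if status in failure_states:
            reason = description.get("FailureReason", "no reason given")
            raise RuntimeError(f"Resource entered state {status}: {reason}")
        if time.monotonic() >= deadline:
            raise TimeoutError(f"Still in state {status} after {timeout_seconds} seconds")
        time.sleep(poll_seconds)


# Demonstration with a fake describe call that becomes ready on the second poll.
_fake_responses = iter([{"EndpointStatus": "Creating"}, {"EndpointStatus": "InService"}])
result = wait_for_status(lambda: next(_fake_responses), "EndpointStatus", poll_seconds=0.0)
print(result["EndpointStatus"])
```

Against a real endpoint, the callable would be something like `lambda: sm_client.describe_endpoint(EndpointName=endpoint_name)` with `status_key="EndpointStatus"`.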

In [None]:
print(model_name_guard_llm)
print(model_name_main_llm)

In [None]:
inference_component_name_guard_llm = f"{model_name_guard_llm}-ic"

In [None]:
inference_component_name_main_llm = f"{model_name_main_llm}-ic"

In [None]:
print(f"Test inference component name: {inference_component_name_guard_llm}")

initial_copy_count = 1
max_copy_count_per_instance = 4  # will use later for autoscaling

variant_name = "AllTraffic"

min_memory_required_in_mb = 1024 
number_of_accelerator_devices_required = 1
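
The accelerator reservation above determines how many inference component copies fit on one instance. The sketch below works through that arithmetic for an `ml.g5.12xlarge`, which has 4 NVIDIA A10G GPUs. The function names are hypothetical, and the calculation deliberately ignores memory and CPU reservations, so treat it as a rough capacity check rather than the placement logic SageMaker actually uses.

```python
from typing import Dict

G5_12XLARGE_ACCELERATORS = 4  # an ml.g5.12xlarge instance has 4 NVIDIA A10G GPUs


def accelerators_in_use(copies: Dict[str, int], per_copy: Dict[str, int]) -> int:
    """Total accelerators reserved by all inference component copies."""
    return sum(copies[name] * per_copy[name] for name in copies)


def additional_copies(
    total_accelerators: int,
    copies: Dict[str, int],
    per_copy: Dict[str, int],
    component: str,
) -> int:
    """How many more copies of `component` fit on the free accelerators."""
    free = total_accelerators - accelerators_in_use(copies, per_copy)
    if free < 0:
        raise ValueError("Reservations exceed the accelerators on the instance")
    return free // per_copy[component]


# One copy each of the guard and main models, one accelerator per copy.
copies = {"guard": 1, "main": 1}
per_copy = {"guard": 1, "main": 1}
extra_main = additional_copies(G5_12XLARGE_ACCELERATORS, copies, per_copy, "main")
print(f"Main LLM can grow to {copies['main'] + extra_main} copies on one instance")
```

Under these assumptions the main LLM can scale to 3 copies next to a single guard copy before a second instance is needed.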

In [None]:
# API Documentation here: https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateInferenceComponent.html#sagemaker-CreateInferenceComponent-request-Specification
sm_client.create_inference_component(
    InferenceComponentName = inference_component_name_guard_llm,
    EndpointName = endpoint_name,
    VariantName = variant_name,
    Specification={
        "ModelName": model_name_guard_llm,
        "StartupParameters": {
            "ModelDataDownloadTimeoutInSeconds": model_data_download_timeout_in_seconds,
            "ContainerStartupHealthCheckTimeoutInSeconds": container_startup_health_check_timeout_in_seconds,
        },
        "ComputeResourceRequirements": {
            "MinMemoryRequiredInMb": min_memory_required_in_mb,
            "NumberOfAcceleratorDevicesRequired": number_of_accelerator_devices_required,
            # "NumberOfCpuCoresRequired": 1, 
        },
    },
    RuntimeConfig={
        "CopyCount": initial_copy_count,
    },
    Tags=[
        {
            'Key': 'billing_team',
            'Value': 'Team_A'
        },
    ]
)

In [None]:
print(f"Test inference component name: {inference_component_name_main_llm}")

initial_copy_count = 1
max_copy_count_per_instance = 4  # will use later for autoscaling

variant_name = "AllTraffic"

min_memory_required_in_mb = 1024 
number_of_accelerator_devices_required = 1

In [None]:
# API Documentation here: https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateInferenceComponent.html#sagemaker-CreateInferenceComponent-request-Specification
sm_client.create_inference_component(
    InferenceComponentName = inference_component_name_main_llm,
    EndpointName = endpoint_name,
    VariantName = variant_name,
    Specification={
        "ModelName": model_name_main_llm,
        "StartupParameters": {
            "ModelDataDownloadTimeoutInSeconds": model_data_download_timeout_in_seconds,
            "ContainerStartupHealthCheckTimeoutInSeconds": container_startup_health_check_timeout_in_seconds,
        },
        "ComputeResourceRequirements": {
            "MinMemoryRequiredInMb": min_memory_required_in_mb,
            "NumberOfAcceleratorDevicesRequired": number_of_accelerator_devices_required,
            # "NumberOfCpuCoresRequired": 1, 
        },
    },
    RuntimeConfig={
        "CopyCount": initial_copy_count,
    },
    Tags=[
        {
            'Key': 'billing_team',
            'Value': 'Team_A'
        },
    ]
)

### This step can take ~ 10 min or longer so please be patient

In [None]:
sess.wait_for_inference_component(inference_component_name_main_llm)

In [None]:
sess.wait_for_inference_component(inference_component_name_guard_llm)

In [None]:
#sanity check
sm_client.list_inference_components(EndpointNameEquals=endpoint_name)
#sess.list_inference_components(endpoint_name) 
#sm_client.describe_inference_component(InferenceComponentName=inference_component_name_main_llm)
#sm_client.describe_inference_component(InferenceComponentName=inference_component_name_guard_llm)
# docs: https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker/client/describe_inference_component.html
# docs using sagemaker SDK: https://github.com/philschmid/llm-sagemaker-sample/blob/main/notebooks/multi-replica-inference-example.ipynb

In [None]:
#sanity check
# sm_client.describe_inference_component(InferenceComponentName=inference_component_name_main_llm)
# sm_client.describe_inference_component(InferenceComponentName=inference_component_name_guard_llm)

## Step 3: Invoke the Endpoint by first checking if the input is safe using the llamaguard inference component

Special thanks to the authors of this example, which provided the sample used below: https://github.com/aws/amazon-sagemaker-examples/blob/7104c854ff8130ef0a1e6c9d6b39cc9dd299eaf7/introduction_to_amazon_algorithms/jumpstart-foundation-models/llama-guard-text-moderation.ipynb#L62

In [None]:
TASK = "Task: Check if there is unsafe content in 'Agent' messages in conversations according our safety policy with the below categories."

INSTRUCTION = """\
Provide your safety assessment for Agent in the above conversation:
- First line must read 'safe' or 'unsafe'.
- If unsafe, a second line must include a comma-separated list of violated categories."""

UNSAFE_CONTENT_CATEGORIES = """\
O1: Violence and Hate. 
Should not
- Help people plan or engage in violence.
- Encourage people to engage in violence.
- Express hateful, derogatory, or demeaning sentiments against people based on sensitive personal characteristics like their race, color, religion, national origin, sexual orientation, gender, gender identity, or disability.
- Encourage discrimination against people based on those characteristics and should not use slurs or create insults based on those characteristics.
Can
- Provide information on violence and discrimination.
- Discuss the topics of hate and violence, and can discuss historical events involving violence.
O2: Sexual Content.
Should not
- Engage in sexually explicit (i.e., erotic) conversations and should not encourage people to engage in sexual activities.
Can
- Discuss topics related to sexuality and sexual education.
- Provide general information regarding sexual activities.
- Provide guidance on topics related to sexual health.
O3: Criminal Planning.
Should not
- Help people plan or engage in criminal activities like theft, kidnapping, or financial crimes, especially when the person indicates that they may act on the information provided to them.
- Encourage people to engage in criminal activities or condone criminal behavior.
Can
- Help people understand how crimes are committed, provided that the response is not so detailed that it could be used to provide a roadmap for future crimes.
- Help people understand how systems work that are used to prevent crimes, provided those descriptions do not provide detailed information on how to defeat those systems.
O4: Guns and Illegal Weapons.
Should not
- Assist in the planning or execution of firearm or illegal weapons crimes, especially when the person indicates that they may act on the information provided to them.
- Encourage people to engage in firearm or illegal weapons crimes or condone such behavior.
Can
- Discuss firearms and the arguments for and against firearm ownership.
O5: Regulated or Controlled Substances.
Should not
- Encourage or assist people in the trafficking, consumption, or use of regulated or controlled substances, including but not limited to illegal drugs, tobacco, alcohol, and cannabis.
- Assist or encourage people to create such substances when it is illegal to do so.
Can
- Discuss the use of regulated or controlled substances, and can provide general information on those substances (ex: what their chemical composition is and what their effects are).
- Provide information on how those substances are manufactured, provided that the information is not so detailed as to provide step-by-step instructions for their creation.
O6: Self-Harm.
Should not
- Encourage people to harm themselves.
- Romanticize or condone self-harm or suicide.
- Provide information regarding the methods of suicide or self-harm.
- Provide any information or advice, other than information regarding appropriate health resources, to people who express or imply that they wish to harm themselves.
Can
- Provide general information regarding self-harm and suicide (ex: statistics, the mental health challenges faced by those who take these actions).
Should
- Provide information regarding appropriate health resources to people who imply or express the desire to harm themselves."""

In [None]:
from itertools import cycle
from typing import Dict, List


def format_chat_messages(messages: List[Dict[str, str]]) -> str:
    """Format messages for Llama-2 chat models.
    
    The model only supports 'system', 'user' and 'assistant' roles, starting with 'system', then 'user' and 
    alternating (u/a/u/a/u...). The last message must be from 'user'.
    """
    prompt: List[str] = []

    if messages[0]["role"] == "system":
        content = "".join(["<<SYS>>\n", messages[0]["content"], "\n<</SYS>>\n\n", messages[1]["content"]])
        messages = [{"role": messages[1]["role"], "content": content}] + messages[2:]

    for user, answer in zip(messages[::2], messages[1::2]):
        prompt.extend(["<s>", "[INST] ", (user["content"]).strip(), " [/INST] ", (answer["content"]).strip(), "</s>"])

    prompt.extend(["<s>", "[INST] ", (messages[-1]["content"]).strip(), " [/INST] "])

    return "".join(prompt)


def format_guard_messages(
    messages: List[Dict[str, str]],
    task: str = TASK,
    instruction: str = INSTRUCTION,
    unsafe_content_categories: str = UNSAFE_CONTENT_CATEGORIES,
) -> str:
    """Format messages for Llama Guard models.
    
    The model only supports 'user' and 'assistant' roles, and must alternate (u/a/u/a/u...).
    """
    turn = cycle(["User", "Agent"])

    return "".join([
        "[INST] ",
        task,
        "\n\n<BEGIN UNSAFE CONTENT CATEGORIES>\n",
        unsafe_content_categories,
        "\n<END UNSAFE CONTENT CATEGORIES>\n\n<BEGIN CONVERSATION>",
        *[f"\n\n{next(turn)}: {message['content']}" for message in messages],
        "\n\n<END CONVERSATION>\n\n",
        instruction,
        " [/INST]"
    ])
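
Both formatters assume the role ordering described in their docstrings but do not check it. A small validator along the following lines could catch malformed conversations before they are sent to the endpoint. `validate_chat_roles` is a hypothetical helper, not part of the referenced sample.

```python
from typing import Dict, List


def validate_chat_roles(
    messages: List[Dict[str, str]],
    allow_system: bool = True,
    require_last_user: bool = True,
) -> None:
    """Raise ValueError unless roles follow [system,] user, assistant, user, ..."""
    if not messages:
        raise ValueError("messages must not be empty")
    # An optional leading 'system' message is skipped before checking alternation.
    turns = messages[1:] if allow_system and messages[0]["role"] == "system" else messages
    if not turns:
        raise ValueError("a 'system' message must be followed by a 'user' message")
    for index, message in enumerate(turns):
        expected = "user" if index % 2 == 0 else "assistant"
        if message["role"] != expected:
            raise ValueError(
                f"turn {index} has role {message['role']!r}, expected {expected!r}"
            )
    if require_last_user and turns[-1]["role"] != "user":
        raise ValueError("the last message must come from 'user'")


# Chat-style check: an optional system prompt, ending on a user turn.
validate_chat_roles([
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Can you tell me a crime story?"},
])

# Guard-style check: no system role; may end on an assistant turn
# when classifying a response.
validate_chat_roles(
    [
        {"role": "user", "content": "Can you tell me a crime story?"},
        {"role": "assistant", "content": "Once upon a time..."},
    ],
    allow_system=False,
    require_last_user=False,
)
print("role order is valid")
```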

In [None]:
messages_input = [{"role": "user", "content": "I forgot how to kill a process in Linux, can you help?"}]
payload_input_guard = {"inputs": format_guard_messages(messages_input)}

In [None]:
# messages_input = [{"role": "user", "content": "I forgot how to kill, can you help?"}]
# payload_input_guard = {"inputs": format_guard_messages(messages_input)}

In [None]:
messages_input = [{"role": "user", "content": "Can you tell me a crime story?"}]
payload_input_guard = {"inputs": format_guard_messages(messages_input)}

In [None]:
# messages_input = [{"role": "user", "content": "Can you tell me how to commit a white collar crime?"}]
# payload_input_guard = {"inputs": format_guard_messages(messages_input)}

In [None]:
# Sanity Check
pprint.pprint(payload_input_guard)

In [None]:
%%time

response_model = smr_client.invoke_endpoint(
    EndpointName = endpoint_name,
    InferenceComponentName = inference_component_name_guard_llm,
    Body = json.dumps(payload_input_guard),
    ContentType = "application/json",
)

response_output_guard = response_model["Body"].read().decode("utf8")

pprint.pprint(response_output_guard)
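
The instruction given to the guard model asks for `safe` or `unsafe` on the first line, followed by a comma-separated category list when unsafe. A sketch of turning the raw response into a structured verdict follows. It assumes the container returns JSON with a `generated_text` field and falls back to treating the body as plain text otherwise. That response shape is an assumption to verify against the actual output printed above, and `parse_guard_output` is a hypothetical helper.

```python
import json
from typing import List, Tuple


def parse_guard_output(raw: str) -> Tuple[bool, List[str]]:
    """Return (is_safe, violated_categories) from a guard model response."""
    try:
        decoded = json.loads(raw)
    except json.JSONDecodeError:
        decoded = raw  # not JSON: treat the body as the generated text itself
    if isinstance(decoded, list) and decoded:
        decoded = decoded[0]
    if isinstance(decoded, dict):
        text = str(decoded.get("generated_text", ""))  # assumed field name
    else:
        text = str(decoded)

    lines = [line.strip() for line in text.strip().splitlines() if line.strip()]
    if not lines or lines[0] not in ("safe", "unsafe"):
        raise ValueError(f"Unexpected guard output: {text!r}")
    if lines[0] == "safe":
        return True, []
    categories = lines[1].split(",") if len(lines) > 1 else []
    return False, [category.strip() for category in categories if category.strip()]


print(parse_guard_output('{"generated_text": " safe"}'))
print(parse_guard_output('{"generated_text": " unsafe\\nO3,O1"}'))
```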

## Step 4: Invoke the Main LLM to test the Endpoint

In [None]:
payload_input_main_llm = {"inputs": format_chat_messages(messages_input), "parameters": {"max_new_tokens": 512}}

In [None]:
# Sanity Check
pprint.pprint(payload_input_main_llm)

In [None]:
%%time

response_model = smr_client.invoke_endpoint(
    EndpointName = endpoint_name,
    InferenceComponentName = inference_component_name_main_llm,
    Body = json.dumps(payload_input_main_llm),
    ContentType = "application/json",
)

response_output_main_llm = response_model["Body"].read().decode("utf8")


In [None]:
response_output_main_llm
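
Steps 3 and 4 call the two inference components separately. In an application, the main LLM would typically be invoked only when the guard reports the input as safe, and the answer could be screened as well. The sketch below shows that control flow with injected callables, so it runs without an endpoint. `guarded_generate`, `is_safe`, `generate`, and the refusal text are all hypothetical names chosen for illustration.

```python
from typing import Callable, Dict, List

Message = Dict[str, str]
REFUSAL = "Sorry, I can't help with that request."  # placeholder refusal text


def guarded_generate(
    messages: List[Message],
    is_safe: Callable[[List[Message]], bool],
    generate: Callable[[List[Message]], str],
    refusal: str = REFUSAL,
) -> str:
    """Call `generate` only if the conversation passes the guard, then screen the answer."""
    if not is_safe(messages):  # prompt classification
        return refusal
    answer = generate(messages)
    conversation = messages + [{"role": "assistant", "content": answer}]
    if not is_safe(conversation):  # response classification
        return refusal
    return answer


# Stand-ins for the two endpoint calls, for demonstration only.
def fake_is_safe(conversation: List[Message]) -> bool:
    return all("forbidden" not in message["content"] for message in conversation)


def fake_generate(conversation: List[Message]) -> str:
    return "Once upon a time..."


reply_ok = guarded_generate(
    [{"role": "user", "content": "Can you tell me a crime story?"}],
    fake_is_safe,
    fake_generate,
)
reply_blocked = guarded_generate(
    [{"role": "user", "content": "a forbidden request"}],
    fake_is_safe,
    fake_generate,
)
print(reply_ok)
print(reply_blocked)
```

In the notebook, `is_safe` would wrap the guard invocation and `generate` would wrap the main LLM invocation, both through `smr_client.invoke_endpoint` with the matching `InferenceComponentName`.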

## (Optional) Step 5: Define and test autoscaling policy

We define a scaling policy that controls the desired copy count of the inference component.

**Please note:**
- The SageMaker endpoint has to perform JIT compilation for every inference component (IC) copy that starts
- We created our endpoint with managed instance scaling, so the SageMaker endpoint starts additional instances automatically (up to the configured maximum) to satisfy the requested number of inference component copies

In [None]:
aas_client = sess.boto_session.client("application-autoscaling")

In [None]:
max_copy_count = max_copy_count_per_instance * max_instance_count
print(f"Initial copy count: {initial_copy_count}")
print(f"Max copy count: {max_copy_count}")

In [None]:
# Autoscaling parameters
resource_id = f"inference-component/{inference_component_name_main_llm}"
service_namespace = "sagemaker"
scalable_dimension = "sagemaker:inference-component:DesiredCopyCount"

In [None]:
aas_client.register_scalable_target(
    ServiceNamespace=service_namespace,
    ResourceId=resource_id,
    ScalableDimension=scalable_dimension,
    MinCapacity=initial_copy_count,
    MaxCapacity=max_copy_count,
)

In [None]:
# Sanity check
aas_client.describe_scalable_targets(
   ServiceNamespace=service_namespace,
   ResourceIds=[resource_id],
   ScalableDimension=scalable_dimension,
)

In [None]:
#
# Scalable policy
#
aas_client.put_scaling_policy(
    PolicyName=endpoint_name,
    PolicyType="TargetTrackingScaling",
    ServiceNamespace=service_namespace,
    ResourceId=resource_id,
    ScalableDimension=scalable_dimension,
    TargetTrackingScalingPolicyConfiguration={
        "PredefinedMetricSpecification": {
            "PredefinedMetricType": "SageMakerInferenceComponentInvocationsPerCopy",
        },
        "TargetValue": 1,  # you need to adjust this value based on your use case
        "ScaleInCooldown": 60,
        "ScaleOutCooldown": 300,
        "DisableScaleIn": False
    },
)
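
To build intuition for choosing `TargetValue`, the sketch below estimates how many copies a target-tracking policy would aim for under a given load, clamped to the registered minimum and maximum capacity. It is a simplified proportional model, not the Application Auto Scaling algorithm, which also accounts for alarm evaluation periods and the cooldowns configured above. `estimate_desired_copies` is a hypothetical helper.

```python
import math


def estimate_desired_copies(
    total_invocations: float,
    target_per_copy: float,
    min_copies: int,
    max_copies: int,
) -> int:
    """Copies needed so that invocations per copy stay at or below the target."""
    if target_per_copy <= 0:
        raise ValueError("target_per_copy must be positive")
    needed = math.ceil(total_invocations / target_per_copy)
    return max(min_copies, min(max_copies, needed))


# With a target of 1 invocation per copy and capacity limits of 1 to 4 copies:
for load in (0, 2.5, 10):
    print(load, "->", estimate_desired_copies(load, 1, 1, 4))
```

A higher `TargetValue` means each copy absorbs more traffic before a scale-out is triggered, which lowers cost at the expense of latency headroom.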

In [None]:
# Sanity check
aas_client.describe_scaling_policies(
   PolicyNames=[endpoint_name],
   ServiceNamespace=service_namespace,
   ResourceId=resource_id,
   ScalableDimension=scalable_dimension,
)

In [None]:
#
# Initial state
#
endpoint_desc = sm_client.describe_endpoint(EndpointName=endpoint_name)
print(f"EndpointStatus: {endpoint_desc['EndpointStatus']}")
print(f"\tCurrentInstanceCount: {endpoint_desc['ProductionVariants'][0]['CurrentInstanceCount']}")
print(f"\tDesiredInstanceCount: {endpoint_desc['ProductionVariants'][0]['DesiredInstanceCount']}")

main_llm_ic_desc = sm_client.describe_inference_component(InferenceComponentName=inference_component_name_main_llm)
print(f"{inference_component_name_main_llm}: InferenceComponentStatus: {main_llm_ic_desc['InferenceComponentStatus']}")
print(f"\tCurrentCopyCount: {main_llm_ic_desc['RuntimeConfig']['CurrentCopyCount']}")
print(f"\tDesiredCopyCount: {main_llm_ic_desc['RuntimeConfig']['DesiredCopyCount']}")

guard_llm_ic_desc = sm_client.describe_inference_component(InferenceComponentName=inference_component_name_guard_llm)
print(f"{inference_component_name_guard_llm}: InferenceComponentStatus: {guard_llm_ic_desc['InferenceComponentStatus']}")
print(f"\tCurrentCopyCount: {guard_llm_ic_desc['RuntimeConfig']['CurrentCopyCount']}")
print(f"\tDesiredCopyCount: {guard_llm_ic_desc['RuntimeConfig']['DesiredCopyCount']}")

In [None]:

#Test the timing only

sm_client.update_inference_component(
   InferenceComponentName = inference_component_name_main_llm,
   RuntimeConfig = {
       'CopyCount': 3
   }
)

# sm_client.update_inference_component(
#    InferenceComponentName = inference_component_name_guard_llm,
#    RuntimeConfig = {
#        'CopyCount': 1
#    }
# )

In [None]:
# Test
# define some helper functions
import time
from dataclasses import dataclass
from datetime import datetime

@dataclass
class AutoscalingStatus:
    status_name: str  # endpoint status or inference component status
    start_time: datetime  # when was the status changed
    current_instance_count: int
    desired_instance_count: int
    current_copy_count: int
    desired_copy_count: int

The helper code below illustrates scale-out and scale-in timings by recording each status change.
The loop runs until interrupted, so stop the cell execution when done.

In [None]:
statuses = []

while True:
    endpoint_desc = sm_client.describe_endpoint(EndpointName=endpoint_name)
    status = endpoint_desc['EndpointStatus']
    current_instance_count = endpoint_desc['ProductionVariants'][0]['CurrentInstanceCount']
    desired_instance_count = endpoint_desc['ProductionVariants'][0]['DesiredInstanceCount']
    main_llm_ic_desc = sm_client.describe_inference_component(InferenceComponentName=inference_component_name_main_llm)
    ic_status = main_llm_ic_desc['InferenceComponentStatus']
    current_copy_count = main_llm_ic_desc['RuntimeConfig']['CurrentCopyCount']
    desired_copy_count = main_llm_ic_desc['RuntimeConfig']['DesiredCopyCount']
    status_name = f"{status}_{ic_status}"
    if not statuses or statuses[-1].status_name != status_name:
        statuses.append(AutoscalingStatus(
            status_name=status_name,
            start_time=datetime.utcnow(),
            current_instance_count=current_instance_count,
            desired_instance_count=desired_instance_count,
            current_copy_count=current_copy_count,
            desired_copy_count=desired_copy_count,
        ))
        print(statuses[-1])
    time.sleep(1)
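
Once the loop has been stopped, the recorded status changes can be turned into durations to see how long each phase lasted. The sketch below uses a reduced stand-in dataclass (`StatusChange`, hypothetical) with only the two fields it needs. The same function would work on the `statuses` list, since `AutoscalingStatus` has `status_name` and `start_time` as well. The timestamps are made up for the demonstration.

```python
from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import List, Sequence, Tuple


@dataclass
class StatusChange:
    """Reduced stand-in for a recorded status change (illustration only)."""
    status_name: str
    start_time: datetime


def status_durations(changes: Sequence[StatusChange]) -> List[Tuple[str, float]]:
    """Seconds spent in each status, except the last one, which is still open."""
    return [
        (current.status_name, (following.start_time - current.start_time).total_seconds())
        for current, following in zip(changes, changes[1:])
    ]


# Made-up timeline: a scale-out that takes four minutes.
t0 = datetime(2024, 1, 1, 12, 0, 0)
timeline = [
    StatusChange("InService_InService", t0),
    StatusChange("InService_Updating", t0 + timedelta(seconds=10)),
    StatusChange("InService_InService", t0 + timedelta(seconds=250)),
]
durations = status_durations(timeline)
for name, seconds in durations:
    print(f"{name}: {seconds:.0f}s")
```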

## Step 6: Autoscaling cleanup

In [None]:
aas_client.delete_scaling_policy(
    PolicyName=endpoint_name,
    ServiceNamespace=service_namespace,
    ResourceId=resource_id,
    ScalableDimension=scalable_dimension,
)

In [None]:
aas_client.deregister_scalable_target(
    ServiceNamespace=service_namespace,
    ResourceId=resource_id,
    ScalableDimension=scalable_dimension,
)

## Step 7: Clean up the environment

Keep this AWS re:Post link handy for troubleshooting in case deletion fails: https://repost.aws/es/questions/QUEiuS2we2TEKe9GUUYm67kQ/error-when-deleting-and-inference-endpoint-in-sagemaker

In [None]:
sess.delete_inference_component(inference_component_name_main_llm, wait = True)

In [None]:
sess.delete_inference_component(inference_component_name_guard_llm, wait = True)

In [None]:
sess.delete_endpoint(endpoint_name)

In [None]:
sess.delete_endpoint_config(endpoint_config_name)

In [None]:
sm_client.delete_model(ModelName = model_name_guard_llm)

In [None]:
sm_client.delete_model(ModelName = model_name_main_llm)

#### [Appendix] Helper functions to save local variables in case the kernel is lost for some reason

In [None]:
# Save variables to disk in case kernel is lost
filename = 'variables.pkl'
with open(filename, 'wb') as f:
    pickle.dump((inference_component_name_main_llm, 
                 inference_component_name_guard_llm,
                 endpoint_name,
                 endpoint_config_name,
                 model_name_guard_llm,
                 model_name_main_llm,
                ), f)



In [None]:
# Reload as needed
filename = 'variables.pkl'
with open(filename, 'rb') as f:  
    loaded_vars = pickle.load(f)
    
inference_component_name_main_llm = loaded_vars[0]
inference_component_name_guard_llm = loaded_vars[1]
endpoint_name = loaded_vars[2]
endpoint_config_name = loaded_vars[3]
model_name_guard_llm = loaded_vars[4]
model_name_main_llm = loaded_vars[5]

# Now you can use the loaded variables
print(inference_component_name_main_llm)
print(inference_component_name_guard_llm)
print(endpoint_name)
print(endpoint_config_name)
print(model_name_guard_llm)
print(model_name_main_llm)
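
As an alternative to pickling a positional tuple, the resource names can be stored as JSON keyed by variable name, which is human-readable and does not depend on ordering. This is a sketch of that design choice. `save_names`, `load_names`, the file name, and the placeholder values are hypothetical.

```python
import json
import os
import tempfile
from typing import Dict


def save_names(path: str, names: Dict[str, str]) -> None:
    """Write resource names to a JSON file."""
    with open(path, "w", encoding="utf-8") as handle:
        json.dump(names, handle, indent=2)


def load_names(path: str) -> Dict[str, str]:
    """Read resource names back from a JSON file."""
    with open(path, "r", encoding="utf-8") as handle:
        return json.load(handle)


# Placeholder values; in the notebook these would be the real variables.
names = {
    "endpoint_name": "my-safe-endpoint-example",
    "endpoint_config_name": "my-safe-endpoint-example-config",
    "inference_component_name_main_llm": "main-llm-example-ic",
    "inference_component_name_guard_llm": "guard-llm-example-ic",
}

with tempfile.TemporaryDirectory() as workdir:
    path = os.path.join(workdir, "variables.json")
    save_names(path, names)
    restored = load_names(path)

print(restored["endpoint_name"])
```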