# Register the trained Adapter in the Model Registry


In [None]:
!pip install "sagemaker>=2.190.0"

In [None]:
import os, boto3, sagemaker
from sagemaker.model import Model
from sagemaker import ModelPackage
import time
from botocore.exceptions import ClientError

# SageMaker SDK session, plus a low-level SageMaker client used by the
# Model Registry calls later in this notebook.
sess = sagemaker.Session()
sm_client = boto3.client('sagemaker')

try:
    role = sagemaker.get_execution_role()
except ValueError:
    # get_execution_role() raised ValueError: look the role ARN up in IAM by
    # its name instead.
    iam = boto3.client('iam')
    role_lookup = iam.get_role(RoleName='sagemaker_execution_role')
    role = role_lookup['Role']['Arn']

In [None]:
# Base model and trained adapter path will be retrieved from training notebook. However, please feel free to update it as needed. 
%store -r adapter_serving_dir_path
%store -r base_model_s3_path

os.environ['base_model_s3_path']=base_model_s3_path

print(f'\nAdapter Serving S3 Dir path: {adapter_serving_dir_path} \n')

print(f'\n Base Model Serving S3 Dir path: {base_model_s3_path} \n')

## Create Adapter Package

In [None]:
# Remove any adapter_artifact directory left over from a previous run.
# stderr is discarded and `|| true` keeps the cell from failing.
!rm -rf ./adapter_artifact 2>/dev/null || true

# Recursively copy everything under the adapter S3 location into the local
# ./adapter_artifact directory.
!aws s3 cp --recursive {adapter_serving_dir_path} ./adapter_artifact

In [None]:
%%bash
# Write the serving.properties file into the downloaded adapter directory.
# The heredoc delimiter is unquoted, so bash expands ${base_model_s3_path}
# from the environment variable set earlier via os.environ.
# Every entry uses "=" as the key/value separator (enable_lora previously
# used ":"), so the file stays consistent for any parser that splits on "=".

cat > ./adapter_artifact/serving.properties <<EOF

engine=Python
option.model_id=${base_model_s3_path}
option.adapters=adapters
option.dtype=fp16

option.tensor_parallel_degree=max
option.rolling_batch=lmi-dist
option.use_custom_all_reduce=true
option.output_formatter=json
option.max_rolling_batch_size=64
option.model_loading_timeout=3600
option.max_model_len=5000
option.gpu_memory_utilization=0.9
option.enable_lora=true
load_on_devices=0

EOF

In [None]:
# Pack the contents of ./adapter_artifact into adapter.tar.gz in the current
# directory. -C makes archive paths relative to that directory, and entries
# matching 'checkpoint-20' are left out.
!tar czvf adapter.tar.gz --exclude='checkpoint-20' -C ./adapter_artifact/ .

In [None]:
# Upload the archive to the adapter S3 location.
# NOTE(review): assumes adapter_serving_dir_path ends with "/" so the object
# lands under that prefix as adapter.tar.gz; confirm against the stored value.
!aws s3 cp adapter.tar.gz {adapter_serving_dir_path}

In [None]:
# Framework name and version used to look up the LMI serving container image.
LMI_FRAMEWORK = 'djl-lmi'
LMI_VERSION = "0.29.0"

# Resolve the image URI for the region of the default boto3 session.
current_region = boto3.Session().region_name
inference_image_uri = sagemaker.image_uris.retrieve(
    framework=LMI_FRAMEWORK,
    region=current_region,
    version=LMI_VERSION,
)
print(f"LMI Container Image: {inference_image_uri}")

## Create Model Package Group

In [None]:
def check_model_package_group_exists(model_package_group_name):
    """Return whether a SageMaker Model Package Group with this name exists.

    Calls ``describe_model_package_group`` on the module-level ``sm_client``.

    Args:
        model_package_group_name: Name of the model package group to look up.

    Returns:
        True if the describe call succeeds, False if the service reports
        that the group does not exist.

    Raises:
        ClientError: For any service error other than "group not found".
    """
    try:
        sm_client.describe_model_package_group(
            ModelPackageGroupName=model_package_group_name
        )
    except ClientError as e:
        error = e.response.get('Error', {})
        code = error.get('Code', '')
        message = error.get('Message', '')
        if code == 'ResourceNotFound':
            return False
        # SageMaker reports a missing model package group as a
        # ValidationException whose message says the group "does not exist".
        if code == 'ValidationException' and 'does not exist' in message:
            return False
        # Anything else (permissions, throttling, ...) is a real failure.
        raise
    # The describe call succeeded, so the group exists.
    return True

model_package_group_name = "adapter-model-group-v1"

# Create the group only when the existence check says it is missing.
if not check_model_package_group_exists(model_package_group_name):
    try:
        sm_client.create_model_package_group(
            ModelPackageGroupName=model_package_group_name,
            ModelPackageGroupDescription="Description of your model package group",
        )
    except sm_client.exceptions.ResourceInUseException:
        # The service reports the group name as already in use.
        print(f"Model Package Group '{model_package_group_name}' already exists.")
    else:
        print(f"Model Package Group '{model_package_group_name}' created successfully.")

## Create Model Package

In [None]:
# Request arguments for a new model package in the group named above,
# created with approval status "Approved".
create_model_package_input_dict = dict(
    ModelPackageGroupName=model_package_group_name,
    ModelPackageDescription="Model description",
    ModelApprovalStatus="Approved",
)

# Create the model package and keep the ARN from the response.
create_model_package_response = sm_client.create_model_package(
    **create_model_package_input_dict
)
model_package_arn = create_model_package_response["ModelPackageArn"]

# Bare expression so the notebook displays the ARN as the cell output.
model_package_arn

In [None]:
# List the uploaded archive to confirm it exists at the S3 key that the model
# package will reference.
!aws s3 ls {adapter_serving_dir_path}adapter.tar.gz

## Update Model Package with LMI image and adapter package 

In [None]:
latest_model_package_arn=model_package_arn

response = sm_client.update_model_package(
    ModelPackageArn=latest_model_package_arn,
    InferenceSpecification={
        'Containers': [
            {
                'Image': inference_image_uri,
                'ModelDataUrl': f"{adapter_serving_dir_path}adapter.tar.gz"
            },
        ],
        'SupportedTransformInstanceTypes': ['ml.g5.12xlarge'],
        'SupportedRealtimeInferenceInstanceTypes': ['ml.g5.12xlarge'],
        'SupportedContentTypes': ['application/json'],
        'SupportedResponseMIMETypes': ['application/json']
    }
)
model_package_arn = response["ModelPackageArn"]
print(f"update registered model's inference spec: {model_package_arn}")

%store model_package_arn