# NVIDIA Parakeet on SageMaker Asynchronous Endpoint with LMI container

We create a Python file that defines how NVIDIA Parakeet is loaded and served on the DJLServing (LMI) container, then deploy it to a SageMaker Asynchronous Inference endpoint, which accepts payloads up to 1GB.

Because our code relies on `ffmpeg`, we extend the LMI container image to install it:

In [None]:
%%writefile Dockerfile
FROM 763104351884.dkr.ecr.us-east-1.amazonaws.com/djl-inference:0.33.0-lmi15.0.0-cu128
RUN apt-get update && apt-get install -y ffmpeg

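We then build the image and push it to ECR. The first login lets Docker pull the base image from the AWS-owned registry; the second authenticates against the registry in your own account, where `<ACCOUNT ID>` is your AWS account ID. This assumes a `nemo-asr` repository already exists in that registry.

For example: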

```
aws ecr get-login-password --region us-east-1 | docker login --username AWS --password-stdin 763104351884.dkr.ecr.us-east-1.amazonaws.com

aws ecr get-login-password --region us-east-1 | docker login --username AWS --password-stdin <ACCOUNT ID>.dkr.ecr.us-east-1.amazonaws.com

docker build -t nemo-asr .

docker tag nemo-asr <ACCOUNT ID>.dkr.ecr.us-east-1.amazonaws.com/nemo-asr:latest

docker push <ACCOUNT ID>.dkr.ecr.us-east-1.amazonaws.com/nemo-asr:latest
```

## Custom code setup
We write our code to a local folder and upload it to S3. When the endpoint starts, the container downloads these files and uses them.

In [None]:
!mkdir code

The `NemoASRService` defines the model loading and inference logic. It extends the built-in [HuggingFaceService](https://github.com/deepjavalibrary/djl-serving/blob/master/engines/python/setup/djl_python/huggingface.py). 

In [None]:
%%writefile code/NemoService.py

import nemo.collections.asr as nemo_asr
import numpy as np

from djl_python import Input
from djl_python.huggingface import HuggingFaceService
from djl_python.properties_manager.hf_properties import HuggingFaceProperties
from transformers.pipelines.audio_utils import ffmpeg_read


class NemoASRService(HuggingFaceService):
    def __init__(self):
        super().__init__()

    def initialize(self, properties: dict):
        self.hf_configs = HuggingFaceProperties(**properties)
        self.model = nemo_asr.models.ASRModel.from_pretrained(model_name="nvidia/parakeet-tdt-0.6b-v2").to('cuda')
        self.model.change_attention_model("rel_pos_local_attn", [128, 128])  # local attn
        self.model.change_subsampling_conv_chunking_factor(1)  # 1 = auto select
        self.tokenizer = self.model.tokenizer
        self.input_format_args = self.get_input_format_args()
        self.hf_pipeline = self.call_model
        self.initialized=True
        

    def call_model(self, input_data, **kwargs):
        # A list of lists is treated as raw sample arrays; any other list is
        # treated as encoded audio bytes and decoded to 16 kHz with ffmpeg.
        if isinstance(input_data, list):
            if isinstance(input_data[0], list):
                input_data = np.array(input_data)
            else:
                input_data = [ffmpeg_read(data, 16000) for data in input_data]
        output = self.model.transcribe(input_data, **kwargs)

        # Each returned hypothesis exposes its transcription as `.text`
        return [out.text for out in output]

    def _read_model_config(self):
        return self.model.config

_service = NemoASRService()

def handle(inputs: Input):
    if not _service.initialized:
        _service.initialize(inputs.get_properties())

    if inputs.is_empty():
        return None
    return _service.inference(inputs)

We disable rolling batch and set the entry point to our service file.

In [None]:
%%writefile code/serving.properties
option.rolling_batch=disable
engine=Python
option.entryPoint=NemoService.py
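
A typo in `serving.properties` usually only shows up when the endpoint fails to start, so a quick local check before uploading can save a deployment cycle. The sketch below parses the `key=value` lines with plain Python. `parse_properties`, `missing_keys`, and the `REQUIRED_KEYS` set are hypothetical helpers for this notebook, not part of DJLServing.

```python
# Hypothetical pre-upload sanity check; not part of DJLServing.
# Assumption: these are the keys this notebook relies on.
REQUIRED_KEYS = {"engine", "option.entryPoint"}


def parse_properties(text: str) -> dict:
    """Parse simple key=value lines, skipping blank lines and # comments."""
    props = {}
    for raw in text.splitlines():
        line = raw.strip()
        if not line or line.startswith("#"):
            continue
        key, sep, value = line.partition("=")
        if not sep:
            raise ValueError(f"Malformed line (no '='): {raw!r}")
        props[key.strip()] = value.strip()
    return props


def missing_keys(props: dict) -> set:
    """Return the required keys that are absent from the parsed properties."""
    return REQUIRED_KEYS - props.keys()


example = """option.rolling_batch=disable
engine=Python
option.entryPoint=NemoService.py
"""
props = parse_properties(example)
print(props, missing_keys(props))
```

In the notebook, the same check could be run on the contents of `code/serving.properties` before the `aws s3 cp` step.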

In [None]:
%%writefile code/requirements.txt
nemo_toolkit[asr]
cuda-python
ffmpeg-python

## Prepare and deploy the model artifacts

Once our files are prepared, we can prepare the artifacts to be deployed. This includes uploading our files to S3 then configuring and deploying the async endpoint.

We configure our Model to use uncompressed artifacts as this allows for faster startup time.

In [None]:
import sagemaker
import time
import json
import boto3
from sagemaker.model import Model
from sagemaker.serializers import DataSerializer
from sagemaker.deserializers import JSONDeserializer
from sagemaker.utils import name_from_base

# Define serializers and deserializer
audio_serializer = DataSerializer(content_type="audio/x-audio")
deserializer = JSONDeserializer()
# Basic configurations
sess = sagemaker.session.Session()
bucket = sess.default_bucket()
prefix = 'parakeet-asr'
role = sagemaker.get_execution_role()
s3_model_prefix = (
    "hf-asr-models/nvidia-asr"  # folder within bucket where code artifact will go
)
# boto3 client for invoking the asynchronous endpoint directly
sm_runtime = boto3.client("sagemaker-runtime")

In [None]:
!aws s3 cp code/ s3://$bucket/$prefix --recursive

In [None]:
image=f"{sess.account_id()}.dkr.ecr.{sess.boto_region_name}.amazonaws.com/nemo-asr:latest"

parakeet_model = Model(
    model_data={
        "S3DataSource": { 
               "CompressionType": "None",
               "S3DataType": "S3Prefix",
               "S3Uri": f"s3://{bucket}/{prefix}/"
        }
    },
    image_uri=image,
    role=role,
    name=name_from_base('parakeet-asr-model')
)

## Deploy Realtime Endpoint

Deploy the model as a real-time inference endpoint first to verify that it works. If you created the Model object with a local SageMaker session, set `instance_type` to `local_gpu` so you can test the endpoint quickly from a SageMaker notebook instance. If you plan to deploy the model to an async endpoint afterwards, make sure the Model object is created with a regular SageMaker session, which in this notebook is `sess`.

In [None]:
from sagemaker.predictor import Predictor


endpoint_name=name_from_base('parakeet-asr-endpoint')

parakeet_model.deploy(
    initial_instance_count=1, # number of instances
    instance_type ='ml.g5.xlarge', # instance type
    endpoint_name = endpoint_name
)

predictor = Predictor(
    endpoint_name=endpoint_name,
    serializer=audio_serializer,
    deserializer=deserializer
)

## Invoke the realtime endpoint

Test the deployed real-time endpoint with a sample audio file:

- Input: Audio file path (automatically serialized)
- Processing: Synchronous transcription
- Output: JSON response with transcription results


In [None]:
# Perform real-time inference
audio_path = "../data/test.wav"
response = predictor.predict(data=audio_path)
print(response[0])

## Delete the endpoint


In [None]:
predictor.delete_endpoint()

## Asynchronous Inference Deployment

Set up the asynchronous inference configuration. This includes:
- Output path in S3 for transcription results
- Concurrency
- SNS notifications for completion
- Failure path
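
For reference, these four settings map onto the `AsyncInferenceConfig` request structure of the SageMaker `CreateEndpointConfig` API. The sketch below only builds that dictionary and makes no AWS calls. The helper name `build_async_config`, the example bucket, and the example ARNs are placeholders chosen for illustration.

```python
def build_async_config(output_path, failure_path, max_concurrency,
                       success_topic_arn=None, error_topic_arn=None):
    """Build the AsyncInferenceConfig dict used by CreateEndpointConfig."""
    output_config = {
        "S3OutputPath": output_path,    # where successful results are written
        "S3FailurePath": failure_path,  # where failed invocations are recorded
    }

    # SNS notifications are optional; only include the topics that were given.
    notification = {}
    if success_topic_arn:
        notification["SuccessTopic"] = success_topic_arn
    if error_topic_arn:
        notification["ErrorTopic"] = error_topic_arn
    if notification:
        output_config["NotificationConfig"] = notification

    return {
        "ClientConfig": {"MaxConcurrentInvocationsPerInstance": max_concurrency},
        "OutputConfig": output_config,
    }


# Placeholder values for illustration only
config = build_async_config(
    output_path="s3://example-bucket/parakeet-asr/output",
    failure_path="s3://example-bucket/parakeet-asr/failed",
    max_concurrency=4,
    success_topic_arn="arn:aws:sns:us-east-1:111122223333:async-success",
    error_topic_arn="arn:aws:sns:us-east-1:111122223333:async-failed",
)
print(config)
```

The SageMaker Python SDK's `AsyncInferenceConfig` class used later in this notebook takes the same information as keyword arguments.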

In [None]:
from botocore.exceptions import ClientError

sns_client = boto3.client('sns')

def create_sns_topic_if_not_exists(topic_name, description):
    """Create SNS topic if it doesn't exist, return the ARN"""
    try:
        # Try to create the topic (idempotent operation)
        response = sns_client.create_topic(Name=topic_name)
        topic_arn = response['TopicArn']
        
        # Set topic attributes for better identification
        sns_client.set_topic_attributes(
            TopicArn=topic_arn,
            AttributeName='DisplayName',
            AttributeValue=description
        )
        
        print(f"✅ Topic '{topic_name}' ready: {topic_arn}")
        return topic_arn
        
    except ClientError as e:
        print(f"❌ Error creating topic '{topic_name}': {e}")
        raise

# Create success topic
success_topic_name = "async-success"
success_description = "SageMaker Async Inference Success Notifications"
success_topic_arn = create_sns_topic_if_not_exists(success_topic_name, success_description)

# Create error topic  
error_topic_name = "async-failed"
error_description = "SageMaker Async Inference Error Notifications"
error_topic_arn = create_sns_topic_if_not_exists(error_topic_name, error_description)

print(f"\n📧 SNS Topics Created Successfully:")
print(f"Success Topic ARN: {success_topic_arn}")
print(f"Error Topic ARN: {error_topic_arn}")

print(f"\n🔧 Topics are ready for AsyncInferenceConfig!")

In [None]:
from sagemaker.async_inference import AsyncInferenceConfig

endpoint_name=name_from_base('parakeet-asr-async-endpoint')

# Create an AsyncInferenceConfig object
async_config = AsyncInferenceConfig(
    output_path=f"s3://{bucket}/{prefix}/output", 
    max_concurrent_invocations_per_instance = 4,
    failure_path=f"s3://{bucket}/{prefix}/failed",
    notification_config={
        "SuccessTopic": success_topic_arn,
        "ErrorTopic": error_topic_arn,
    }
)

# Deploy the model for async inference

async_predictor = parakeet_model.deploy(
    async_inference_config=async_config,
    initial_instance_count=1, # number of instances
    instance_type ='ml.g5.xlarge', # instance type
    endpoint_name = endpoint_name
)

## Invoke the endpoint

To invoke our endpoint asynchronously, we need to upload our data to S3 first. This is then passed into the `predict_async` function which sends the request to invoke the model.
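
`predict_async` wraps the `InvokeEndpointAsync` API. As a sketch of what happens underneath, the same request can be made through a `sagemaker-runtime` boto3 client such as the `sm_runtime` client created earlier. `submit_async` is a hypothetical helper, and `_FakeRuntime` exists only so the snippet runs without AWS credentials; its return values are made up.

```python
def submit_async(runtime_client, endpoint_name, input_location,
                 content_type="audio/x-audio"):
    """Queue an async invocation and return the S3 URI where the result will land."""
    response = runtime_client.invoke_endpoint_async(
        EndpointName=endpoint_name,
        InputLocation=input_location,  # S3 URI of the uploaded audio file
        ContentType=content_type,
    )
    return response["OutputLocation"]


class _FakeRuntime:
    """Stand-in for boto3.client("sagemaker-runtime"), for illustration only."""

    def invoke_endpoint_async(self, **kwargs):
        self.last_request = kwargs
        return {
            "OutputLocation": "s3://example-bucket/output/result.out",
            "InferenceId": "example-id",
        }


fake = _FakeRuntime()
output_location = submit_async(fake, "example-endpoint", "s3://example-bucket/data/test_audio.wav")
print(output_location)
```

With a real client, the call would be `submit_async(sm_runtime, endpoint_name, input_path)`.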

In [None]:
import os

def upload_to_s3(s3_client, file_path, bucket_name, s3_key):
    """Upload a local file to S3; return True on success, False otherwise."""
    try:
        s3_client.upload_file(file_path, bucket_name, s3_key)
        return True
    except Exception as e:
        print(f"Error uploading {s3_key}: {e}")
        return False

s3_client = boto3.client('s3')
audio_path = "../data/test_audio.wav"
# Use only the file name in the key so it does not contain "../"
s3_key = f"{prefix}/data/{os.path.basename(audio_path)}"
upload_to_s3(s3_client, audio_path, bucket, s3_key)

In [None]:
input_path = f"s3://{bucket}/{s3_key}"

In [None]:
input_path
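
The asynchronous endpoint accepts payloads up to 1GB, so it can be worth checking the file size before submitting a request. The sketch below is one way to do that. The helper names are hypothetical, and the exact byte threshold is an assumption: "1GB" is interpreted here as 1024³ bytes.

```python
import os

# Assumption: the 1GB limit is interpreted as 1024**3 bytes.
ASYNC_PAYLOAD_LIMIT_BYTES = 1024 ** 3


def fits_async_payload(size_bytes, limit_bytes=ASYNC_PAYLOAD_LIMIT_BYTES):
    """Return True if a non-empty payload of this size is within the limit."""
    return 0 < size_bytes <= limit_bytes


def check_audio_file(path):
    """Return the file size in bytes, or raise if it is empty or too large."""
    size = os.path.getsize(path)
    if not fits_async_payload(size):
        raise ValueError(f"{path} is {size} bytes, outside the allowed payload range")
    return size


print(fits_async_payload(25 * 1024 * 1024))  # a 25 MiB recording
```

In this notebook, `check_audio_file(audio_path)` could be called just before the upload step.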

In [None]:
from sagemaker.predictor_async import AsyncPredictor
from sagemaker.predictor import Predictor

async_predictor=AsyncPredictor(
    Predictor(
        endpoint_name=endpoint_name
    )
)

In [None]:
# Perform async inference
initial_args = {'ContentType':"audio/x-audio"}
response = async_predictor.predict_async(initial_args = initial_args, input_path=input_path)
response.output_path

Once the data is sent, we wait for the output to be available.

In [None]:
import time
import urllib.parse
from botocore.exceptions import ClientError

def get_output(output_location):
    """Poll S3 until the async result object exists, then return its contents."""
    output_url = urllib.parse.urlparse(output_location)
    bucket = output_url.netloc
    key = output_url.path[1:]  # drop the leading "/"
    while True:
        try:
            return sess.read_s3_file(bucket=bucket, key_prefix=key)
        except ClientError as e:
            # The object does not exist until inference has finished
            if e.response['Error']['Code'] == 'NoSuchKey':
                print("waiting for output...")
                time.sleep(2)
                continue
            raise

output = get_output(response.output_path)
print(f"Output: {output}")
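
The loop above polls indefinitely, which becomes a problem if the invocation fails, because a failed request is recorded under the failure path and the output object never appears. A bounded variant is sketched below. `wait_for_result` is a hypothetical helper, and `fetch` is assumed to be any callable that returns `None` while the result is still missing, for example a small wrapper around `sess.read_s3_file` that catches `NoSuchKey`.

```python
import time


def wait_for_result(fetch, timeout_s=600.0, interval_s=2.0,
                    sleep=time.sleep, clock=time.monotonic):
    """Call fetch() until it returns a non-None value or the timeout expires."""
    deadline = clock() + timeout_s
    while True:
        result = fetch()
        if result is not None:
            return result
        if clock() >= deadline:
            raise TimeoutError(f"No result after {timeout_s} seconds")
        sleep(interval_s)


# Illustration with a stand-in fetch that succeeds on the third attempt
attempts = {"count": 0}


def fake_fetch():
    attempts["count"] += 1
    return "transcript" if attempts["count"] >= 3 else None


result = wait_for_result(fake_fetch, timeout_s=5, interval_s=0, sleep=lambda s: None)
print(result, attempts["count"])
```

The `sleep` and `clock` parameters are injectable so the function can be exercised without real delays; in normal use the defaults are sufficient.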

## Cleanup

Once we are done, we delete the endpoint to stop incurring charges.

In [None]:
async_predictor.delete_endpoint()