# Deploying WhisperX on Amazon SageMaker Using LMI v16
This example deploys WhisperX to an Amazon SageMaker real-time endpoint using the LMI v16 container. For models that a backend such as vLLM does not yet support natively, you can use the container's Python engine to load the model onto a GPU instance, as we do here.

At the time of writing, WhisperX does not have native vLLM support (as far as I'm aware), but Whisper large-v3 does: https://docs.vllm.ai/en/v0.7.0/getting_started/examples/whisper.html. You can serve that model directly with the vLLM backend.

### Other Hosting Options/Routes
In this example we load the audio file into a NumPy array on the client using WhisperX's library and send it as JSON. Alternatively, you can pass an S3 URI as the payload and load the audio inside the container. Another option, for more hybrid inference workloads, is Async Inference: https://dev.to/makawtharani/deploying-whisperx-on-aws-sagemaker-as-asynchronous-endpoint-17g6. With Async Inference you can scale down to zero instances and get managed queuing as part of the solution.
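
A minimal sketch of the S3 URI route, assuming the payload carries a hypothetical `s3_uri` field (the handler in this notebook does not read one) and that the endpoint's execution role can read the object. `parse_s3_uri` and `download_audio` are illustrative helper names, not part of WhisperX or the LMI container.

```python
from urllib.parse import urlparse


def parse_s3_uri(uri: str) -> tuple[str, str]:
    """Split 's3://bucket/key' into (bucket, key)."""
    parsed = urlparse(uri)
    key = parsed.path.lstrip("/")
    if parsed.scheme != "s3" or not parsed.netloc or not key:
        raise ValueError(f"Not a valid S3 object URI: {uri!r}")
    return parsed.netloc, key


def download_audio(uri: str, local_path: str = "/tmp/audio_input") -> str:
    """Download the object so that whisperx.load_audio(local_path) can read it."""
    import boto3  # imported lazily; only needed where the download actually runs

    bucket, key = parse_s3_uri(uri)
    boto3.client("s3").download_file(bucket, key, local_path)
    return local_path
```

Inside `inference`, the handler could then call `whisperx.load_audio(download_audio(data["s3_uri"]))` instead of reading `audio_array`, which keeps the request body small regardless of audio length.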

### Docker Testing
We always recommend testing with Docker on an EC2/EKS node of the instance type you want to deploy on. This allows for quick debugging of your model.py and serving configuration. Here's a sample command:

```
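#Authenticate Docker to the ECR registry that hosts the image (requires AWS CLI credentials)
aws ecr get-login-password --region us-east-1 | docker login --username AWS --password-stdin 763104351884.dkr.ecr.us-east-1.amazonaws.com

#Pull image
docker pull 763104351884.dkr.ecr.us-east-1.amazonaws.com/djl-inference:0.34.0-lmi16.0.0-cu128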

#Start container, adjust for path of artifacts
docker run \
  --gpus all \
  -v /home/ubuntu:/opt/ml/model \
  -p 8080:8080 \
  763104351884.dkr.ecr.us-east-1.amazonaws.com/djl-inference:0.34.0-lmi16.0.0-cu128 \
  serve

```
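
Once the container is up, you can send it a request from the same host. The sketch below assumes the container exposes the SageMaker-style `/invocations` route on the mapped port 8080 and that `model.py` expects the `audio_array` field used later in this notebook. `build_invocation_request` and `invoke_local` are hypothetical helper names.

```python
import json
import urllib.request


def build_invocation_request(audio_samples, url="http://localhost:8080/invocations"):
    """Build a JSON POST request carrying the audio samples as a list of floats."""
    body = json.dumps({"audio_array": list(audio_samples)}).encode("utf-8")
    return urllib.request.Request(
        url,
        data=body,
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def invoke_local(audio_samples, timeout=300):
    """Send the request to the local container and decode the JSON response."""
    req = build_invocation_request(audio_samples)
    with urllib.request.urlopen(req, timeout=timeout) as resp:
        return json.loads(resp.read())
```

For example, `invoke_local(whisperx.load_audio("test_audio.mp4").tolist())` exercises the same code path the endpoint will run, so errors in `model.py` show up in the container logs before you deploy.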

### Additional Resources
- Docker Debug Tutorial: https://www.youtube.com/watch?v=UQHufr-DToE
- Large Model Inference Container Intro: https://www.youtube.com/watch?v=Q-Kz5Yi0QiQ

## Setup
Install WhisperX as well if you want to invoke the endpoint from this notebook; there may be some dependency clashes on SageMaker Notebook Instances. We ran this notebook on an ml.g5.4xlarge SageMaker Classic Notebook Instance so that we could also test with Docker.

In [None]:
import sagemaker, boto3
from sagemaker import image_uris
import subprocess
import time
from time import gmtime, strftime

boto_session = boto3.session.Session()
s3 = boto_session.resource('s3')
client = boto3.client(service_name="sagemaker")
runtime = boto3.client(service_name="sagemaker-runtime")

instance_type = "ml.g5.4xlarge"
role = sagemaker.get_execution_role()  # execution role for the endpoint
session = sagemaker.session.Session()  # sagemaker session for interacting with different AWS APIs
region = session.boto_region_name  # public attribute; avoids relying on the private _region_name
bucket = session.default_bucket()  # bucket to house artifacts

CONTAINER_VERSION = "0.34.0-lmi16.0.0-cu128"
inference_image = f"763104351884.dkr.ecr.{region}.amazonaws.com/djl-inference:{CONTAINER_VERSION}"
print(f"Using image URI: {inference_image}")

## Prepare Model Artifacts
For the LMI container we prepare three files:
- model.py: Custom model loading and pre/post-processing code
- requirements.txt: Dependencies to install, whisperx in this case
- serving.properties: Specifies the engine, Python in this case. For vLLM-supported models, use that backend instead.

In [None]:
%%writefile model.py
import logging
import time
import os
from djl_python import Input
from djl_python import Output
import torch
import whisperx
import gc
from whisperx.diarize import DiarizationPipeline
import numpy as np

# Set HF Token
HF_TOKEN = "Add your HF Token here"
os.environ["HUGGINGFACE_TOKEN"] = HF_TOKEN

# set instance and compute types
device = "cuda" if torch.cuda.is_available() else "cpu"
# suggest GPU instance for WhisperX
compute_type = "float16" if device == "cuda" else "float32"

class WhisperXModel(object):
    """
    Deploying WhisperX with DJL Serving
    """

    def __init__(self):
        self.initialized = False

    def initialize(self, properties: dict):
        """
        Initialize model.
        """
        logging.info(os.listdir())
        logging.info("-----------------")
        logging.info(properties)

        self.model = whisperx.load_model("small", device, compute_type=compute_type)
        self.model_a, self.metadata = whisperx.load_align_model(language_code="en", device=device)
        self.diarization_model = DiarizationPipeline(use_auth_token=os.getenv("HUGGINGFACE_TOKEN"), device=device)
        self.initialized = True

    def inference(self, inputs):
        """
        Custom service entry point function.

        :param inputs: the Input object holding the audio payload for the WhisperX model to infer upon
        :return: the Output object to be sent back
        """
        
        try:
            # Sample input: {"audio_array": numpy serialization of audio file}
            data = inputs.get_as_json()
            logging.info("-----------------")
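            # log only the payload keys; logging the full audio array would flood the logs
            logging.info(list(data) if isinstance(data, dict) else data)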
            logging.info(type(data))
            logging.info("-----------------")
            
            #parse input
            audio = data["audio_array"]

            # Cast list -> np.ndarray (float32), as required by whisperx
            if isinstance(audio, list):
                audio = np.asarray(audio, dtype=np.float32)
            elif not isinstance(audio, np.ndarray):
                return Output().error("audio_array must be a list or numpy array")

            # transcription model inference
            result = self.model.transcribe(audio)

            # alignment model
            result = whisperx.align(result["segments"], self.model_a, self.metadata, audio, device, return_char_alignments=False)

            # diarization: detect speaker turns, then attach a speaker label to each aligned segment
            diarize_segments = self.diarization_model(audio)
            diarization_result = whisperx.assign_word_speakers(diarize_segments, result)

            # parse final output and return as a JSON
            final_output = diarization_result["segments"]  # segments are now assigned speaker IDs
            result = {"outputs": final_output}
            outputs = Output()
            outputs.add_as_json(result)
        except Exception as e:
            logging.exception("inference failed")
            # error handling
            outputs = Output().error(str(e))
        
        logging.info(outputs)
        logging.info(type(outputs))
        logging.info("Returning inference---------")
        return outputs


_service = WhisperXModel()


def handle(inputs: Input):
    """
    Default handler function
    """
    if not _service.initialized:
        # stateful model
        _service.initialize(inputs.get_properties())
    
    if inputs.is_empty():
        return None

    return _service.inference(inputs)
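
Because the handler reads the audio as a JSON list of floats, the request body grows quickly with audio length. The sketch below gives a rough size estimate so you can compare it against the payload quota for your endpoint type. `estimate_payload_bytes` is a hypothetical helper, and the fixed sample value is an assumption; real samples can print with more digits, so treat the result as a lower-end estimate.

```python
import json


def estimate_payload_bytes(num_seconds, sample_rate=16000, sample_value=-0.0123456789):
    """Approximate the JSON request size for `num_seconds` of audio at `sample_rate` Hz."""
    samples = [sample_value] * int(num_seconds * sample_rate)
    return len(json.dumps({"audio_array": samples}).encode("utf-8"))


if __name__ == "__main__":
    for seconds in (1, 10, 30):
        size_mb = estimate_payload_bytes(seconds) / 1e6
        print(f"{seconds:>3} s of audio is roughly {size_mb:.2f} MB of JSON")
```

If the estimate approaches the limit for your use case, passing an S3 URI or using Async Inference, as described under Other Hosting Options/Routes, avoids sending the samples inline.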

In [None]:
%%writefile serving.properties
engine=Python

In [None]:
%%writefile requirements.txt
whisperx

## Create SM Constructs
- Model: Container & Model Data & Scripts/Serving Properties
- EndpointConfig: Instance specs and variants
- Endpoint: REST Endpoint to invoke

In [None]:
#Build tar file with model data + inference code
bashCommand = "tar -cvpzf model.tar.gz model.py requirements.txt serving.properties"
# check=True raises if tar fails, instead of silently continuing with a missing or stale archive
subprocess.run(bashCommand.split(), check=True)
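
The archive step can also be done with the standard library's `tarfile` module, which makes it easy to fail early when an artifact is missing and to confirm that the files sit at the archive root. This is a sketch; `build_model_archive` and `REQUIRED_FILES` are illustrative names.

```python
import tarfile
from pathlib import Path

REQUIRED_FILES = ("model.py", "requirements.txt", "serving.properties")


def build_model_archive(src_dir=".", archive_path="model.tar.gz"):
    """Create model.tar.gz with the three artifacts at the archive root and return its member names."""
    src = Path(src_dir)
    missing = [name for name in REQUIRED_FILES if not (src / name).is_file()]
    if missing:
        raise FileNotFoundError(f"Missing artifacts: {missing}")

    with tarfile.open(archive_path, "w:gz") as tar:
        for name in REQUIRED_FILES:
            # arcname keeps each file at the root, without the source directory prefix
            tar.add(str(src / name), arcname=name)

    # Re-open the archive to verify what was actually written
    with tarfile.open(archive_path, "r:gz") as tar:
        return tar.getnames()
```

Calling `build_model_archive()` from the notebook's working directory produces the same `model.tar.gz` that the upload cell expects.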

In [None]:
#Upload tar.gz to bucket
model_artifacts = f"s3://{bucket}/whisperx/model.tar.gz"
response = s3.meta.client.upload_file('model.tar.gz', bucket, 'whisperx/model.tar.gz')

In [None]:
!aws s3 ls {model_artifacts}

In [None]:
model_name = "djl-whisperx" + strftime("%Y-%m-%d-%H-%M-%S", gmtime())
print("Model name: " + model_name)
create_model_response = client.create_model(
    ModelName=model_name,
    ExecutionRoleArn=role,
    PrimaryContainer={"Image": inference_image, "ModelDataUrl": model_artifacts},
)
print("Model Arn: " + create_model_response["ModelArn"])

In [None]:
endpoint_config_name = "djl-whisperx" + strftime("%Y-%m-%d-%H-%M-%S", gmtime())

production_variants = [
    {
        "VariantName": "AllTraffic",
        "ModelName": model_name,
        "InitialInstanceCount": 1,
        "InstanceType": instance_type,
        "ModelDataDownloadTimeoutInSeconds": 1800,
        "ContainerStartupHealthCheckTimeoutInSeconds": 3600,
    }
]

endpoint_config = {
    "EndpointConfigName": endpoint_config_name,
    "ProductionVariants": production_variants,
}

endpoint_config_response = client.create_endpoint_config(**endpoint_config)
print("Endpoint Configuration Arn: " + endpoint_config_response["EndpointConfigArn"])
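
For the Async Inference route mentioned under Other Hosting Options/Routes, the endpoint configuration additionally carries an `AsyncInferenceConfig` block. The sketch below only builds the request dictionary; `build_async_endpoint_config` is a hypothetical helper, the S3 output path is a placeholder you would replace, and the scale-to-zero autoscaling policy is configured separately and is not shown.

```python
def build_async_endpoint_config(config_name, production_variants, s3_output_path, max_concurrent=1):
    """Return keyword arguments for create_endpoint_config with async inference enabled."""
    return {
        "EndpointConfigName": config_name,
        "ProductionVariants": production_variants,
        "AsyncInferenceConfig": {
            # Responses are written to S3 rather than returned in the HTTP response
            "OutputConfig": {"S3OutputPath": s3_output_path},
            # Upper bound on requests each instance processes at the same time
            "ClientConfig": {"MaxConcurrentInvocationsPerInstance": max_concurrent},
        },
    }
```

You would pass the result to `client.create_endpoint_config(**config)` and invoke the endpoint with `invoke_endpoint_async`, supplying an S3 input location instead of an inline body.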

In [None]:
endpoint_name = "djl-whisperx" + strftime("%Y-%m-%d-%H-%M-%S", gmtime())
create_endpoint_response = client.create_endpoint(
    EndpointName=endpoint_name,
    EndpointConfigName=endpoint_config_name,
)
print("Endpoint Arn: " + create_endpoint_response["EndpointArn"])

In [None]:
describe_endpoint_response = client.describe_endpoint(EndpointName=endpoint_name)
while describe_endpoint_response["EndpointStatus"] == "Creating":
    time.sleep(45)  # wait before polling again, so we never sleep after the final status
    describe_endpoint_response = client.describe_endpoint(EndpointName=endpoint_name)
    print(describe_endpoint_response["EndpointStatus"])
print(describe_endpoint_response)
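
A polling loop like the one above exits on any status other than `Creating`, including `Failed`. The sketch below adds a timeout and raises on failure. `wait_for_endpoint` is a hypothetical helper that takes the describe call as an argument so it can be exercised without AWS access; boto3's built-in `endpoint_in_service` waiter is an alternative.

```python
import time


def wait_for_endpoint(describe, endpoint_name, poll_seconds=45, timeout_seconds=3600, sleep=time.sleep):
    """Poll `describe` until the endpoint leaves 'Creating'; raise unless it ends up 'InService'."""
    waited = 0
    while True:
        response = describe(EndpointName=endpoint_name)
        status = response["EndpointStatus"]
        if status != "Creating":
            if status != "InService":
                reason = response.get("FailureReason", "no reason given")
                raise RuntimeError(f"Endpoint {endpoint_name} ended in status {status}: {reason}")
            return response
        if waited >= timeout_seconds:
            raise TimeoutError(f"Endpoint {endpoint_name} still Creating after {waited} seconds")
        sleep(poll_seconds)
        waited += poll_seconds
```

In this notebook the call would be `wait_for_endpoint(client.describe_endpoint, endpoint_name)`.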

## Sample Inference

In [None]:
!pip install whisperx

In [None]:
import json
import boto3
import whisperx

# ---- prepare audio ----
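audio_path = "test_audio.mp4"  # replace with your audio file
audio = whisperx.load_audio(audio_path)  # decoded to a mono float32 array at 16 kHz
sample_rate = 16000  # informational only; the model.py in this notebook does not read this field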

payload = {
    "audio_array": audio.tolist(),  # JSON-serializable
    "sample_rate": sample_rate
}

# ---- invoke sagemaker endpoint ----
runtime = boto3.client("sagemaker-runtime", region_name=region)  # same region the endpoint was created in
response = runtime.invoke_endpoint(
    EndpointName=endpoint_name,             
    ContentType="application/json",
    Accept="application/json",
    Body=json.dumps(payload)
)

# deserialize result
result = json.loads(response["Body"].read())
print(json.dumps(result, indent=2))
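
To make the response easier to read, you can flatten it into one line per segment. This sketch assumes the `{"outputs": [...]}` shape returned by the `model.py` above and that each segment carries `start`, `end`, and `text`, with `speaker` present only when diarization assigned one. `format_transcript` is a hypothetical helper.

```python
def format_transcript(result):
    """Turn the endpoint response into '[start-end] SPEAKER: text' lines."""
    lines = []
    for seg in result.get("outputs", []):
        speaker = seg.get("speaker", "UNKNOWN")  # segments without an assigned speaker
        lines.append(f"[{seg['start']:.2f}-{seg['end']:.2f}] {speaker}: {seg['text'].strip()}")
    return lines


if __name__ == "__main__":
    # Made-up sample data for illustration only, not real model output
    sample = {
        "outputs": [
            {"start": 0.0, "end": 1.5, "text": " Hello there.", "speaker": "SPEAKER_00"},
            {"start": 1.5, "end": 2.25, "text": "Hi."},
        ]
    }
    print("\n".join(format_transcript(sample)))
```

With the real response, `print("\n".join(format_transcript(result)))` gives a speaker-labelled transcript.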