## Transcription inference on Amazon SageMaker Inference


---

This notebook's CI test result for us-west-2 is as follows. CI test results in other regions can be found at the end of the notebook. 

![This us-west-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-west-2/async-inference|Transcription_on_SM_endpoint.ipynb)

---

##### Near real-time transcription inference with the Whisper model


## Background
Amazon Transcribe is the go-to service for transcription on AWS.
However, for languages that Amazon Transcribe does not support, we can use other models (in our case, Whisper) deployed on Amazon SageMaker for inference.
For short audio files whose inference completes within 60 seconds, we can use real-time inference.
For inference that takes longer than 60 seconds, or when we want to save costs by autoscaling the instance count to zero when there are no requests to process, asynchronous inference should be used.
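The rule of thumb above can be written as a tiny helper. This is only a sketch: the function name `choose_inference_mode` is hypothetical, and the 60-second threshold is the one quoted in this notebook, not a value read from the SageMaker API.

```python
# Hypothetical helper encoding the rule of thumb from the Background section.
REALTIME_LIMIT_SECONDS = 60  # threshold quoted in this notebook


def choose_inference_mode(expected_seconds: float, scale_to_zero: bool = False) -> str:
    """Return "real-time" or "asynchronous" for a transcription workload."""
    if expected_seconds <= 0:
        raise ValueError("expected_seconds must be positive")
    # Long-running requests, or the wish to scale to zero, both point to async inference.
    if expected_seconds > REALTIME_LIMIT_SECONDS or scale_to_zero:
        return "asynchronous"
    return "real-time"


print(choose_inference_mode(20))  # real-time
print(choose_inference_mode(300))  # asynchronous
print(choose_inference_mode(20, scale_to_zero=True))  # asynchronous
```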

## Notebook scope
This notebook provides two deployment options for the Whisper model, real-time and asynchronous inference, including an auto-scaling setup and an example of invoking the asynchronous endpoint.

We used the Data Science image to run this notebook.

## 1. Prepare the model for inference

In [None]:
!mkdir -p model/code

Create the custom inference script

In [None]:
%%writefile model/code/inference.py
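import os
import tempfile
from urllib.parse import urlparse

import boto3
import whisper


def model_fn(model_dir):
    # model_dir is not used: the Whisper weights are downloaded when the model is loaded
    model = whisper.load_model("large-v2")
    return model


def transcribe_from_s3(model, s3_file, language=None):
    s3 = boto3.client("s3")
    o = urlparse(s3_file, allow_fragments=False)
    bucket = o.netloc
    key = o.path.lstrip("/")

    # Download into a fresh temporary directory so that concurrent requests
    # (e.g. with MODEL_SERVER_WORKERS=2) never overwrite each other's audio file
    with tempfile.TemporaryDirectory() as tmp_dir:
        local_path = os.path.join(tmp_dir, os.path.basename(key) or "audio")
        s3.download_file(bucket, key, local_path)
        # language=None lets Whisper detect the spoken language
        result = model.transcribe(local_path, language=language)

    return result["language"], result["text"], result["segments"]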


def predict_fn(data, model):
    s3_file = data.pop("s3_file")
    language = data.pop("language", None)

    detected_language, transcription, segments = transcribe_from_s3(model, s3_file, language)

    return {
        "detected_language": detected_language,
        "transcription": transcription,
        "segments": segments,
    }
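`transcribe_from_s3` relies on `urllib.parse.urlparse` to split an `s3://bucket/key` URI into the bucket (`netloc`) and the object key (`path` without its leading slash). The sketch below isolates that parsing so it can be checked locally without AWS access; `parse_s3_uri` is a hypothetical name and is not part of the inference script.

```python
from typing import Tuple
from urllib.parse import urlparse


def parse_s3_uri(s3_uri: str) -> Tuple[str, str]:
    """Split an s3://bucket/key URI the same way transcribe_from_s3 does."""
    o = urlparse(s3_uri, allow_fragments=False)
    if o.scheme != "s3" or not o.netloc:
        raise ValueError(f"not an S3 URI: {s3_uri!r}")
    return o.netloc, o.path.lstrip("/")


print(parse_s3_uri("s3://my-bucket/audio/interview#1.wav"))
# ('my-bucket', 'audio/interview#1.wav')
```

Passing `allow_fragments=False` matters because S3 object keys may legitimately contain `#`, which would otherwise be cut off as a URL fragment.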

In the `requirements.txt` file, we list the libraries needed to run the inference code.

In [None]:
with open("model/code/requirements.txt", "w") as f:
    f.write("transformers==4.25.1\n")
    f.write("git+https://github.com/openai/whisper.git\n")
    f.write("boto3")

### Uploading the model to S3

In [None]:
%cd model

In [None]:
!rm -f model.tar.gz

In [None]:
!tar zcvf model.tar.gz *
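If a shell is not available, a comparable `model.tar.gz` can be built with Python's standard-library `tarfile` module. This is a sketch assuming the `model/` layout created above (`code/inference.py` and `code/requirements.txt`); `build_model_archive` is a hypothetical helper. As with `tar zcvf model.tar.gz *`, the `code/` directory ends up at the root of the archive.

```python
import os
import tarfile


def build_model_archive(model_dir: str, archive_name: str = "model.tar.gz") -> str:
    """Package the contents of model_dir (not the directory itself) into a gzipped tarball."""
    archive_path = os.path.join(model_dir, archive_name)
    with tarfile.open(archive_path, "w:gz") as tar:
        for entry in sorted(os.listdir(model_dir)):
            # Skip the archive itself and hidden entries, as the shell glob * does
            if entry == archive_name or entry.startswith("."):
                continue
            # arcname keeps paths relative, e.g. "code/inference.py"
            tar.add(os.path.join(model_dir, entry), arcname=entry)
    return archive_path
```

Since the notebook has already changed into `model/` with `%cd model`, calling `build_model_archive(".")` there would package the same `code/` directory.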

In [None]:
import sagemaker
import boto3

sess = sagemaker.Session()

sagemaker_session_bucket = None
if sagemaker_session_bucket is None and sess is not None:
    sagemaker_session_bucket = sess.default_bucket()

try:
    role = sagemaker.get_execution_role()
except ValueError:
    iam = boto3.client("iam")
    role = iam.get_role(RoleName="sagemaker_execution_role")["Role"]["Arn"]

sess = sagemaker.Session(default_bucket=sagemaker_session_bucket)

print(f"sagemaker role arn: {role}")
print(f"sagemaker bucket: {sess.default_bucket()}")
print(f"sagemaker session region: {sess.boto_region_name}")

In [None]:
s3_location = f"s3://{sagemaker_session_bucket}/whisper/model/model.tar.gz"

In [None]:
!aws s3 cp model.tar.gz $s3_location

## 2. Real-time inference

### Deploying the model to a real-time endpoint

In [None]:
from sagemaker.huggingface.model import HuggingFaceModel
from sagemaker.utils import name_from_base

rt_endpoint_name = name_from_base("whisper-large-v2-custom")

# create Hugging Face Model Class
huggingface_model = HuggingFaceModel(
    model_data=s3_location,  # path to your model and script
    role=role,  # iam role with permissions to create an Endpoint
    transformers_version="4.17",  # transformers version used
    pytorch_version="1.10",  # pytorch version used
    py_version="py38",  # python version used
)

# deploy the model to a real-time endpoint
rl_predictor = huggingface_model.deploy(
    initial_instance_count=1,
    instance_type="ml.g4dn.xlarge",
    endpoint_name=rt_endpoint_name,
)

### Execute inference

In [None]:
# Replace the placeholder with the S3 URI of an audio object
# Uncomment the "language" line to specify the input language; otherwise Whisper detects it automatically
data = {
    "s3_file": "REPLACE WITH A PATH TO AUDIO OBJECT IN S3",
    # "language": "pl",
}

res = rl_predictor.predict(data=data)
print(res)
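The response is the dictionary built by `predict_fn`: `detected_language`, `transcription`, and Whisper's `segments`. Each Whisper segment carries `start` and `end` times in seconds plus its `text`, so a subtitle-style listing takes only a few lines. This is a minimal sketch; `format_segments` is a hypothetical helper, and the sample segments are made-up placeholders shaped like Whisper's output, not real results.

```python
def format_segments(segments: list) -> str:
    """Render Whisper segments as '[start -> end] text' lines (times in seconds)."""
    lines = []
    for seg in segments:
        lines.append(f"[{seg['start']:.2f} -> {seg['end']:.2f}] {seg['text'].strip()}")
    return "\n".join(lines)


# Placeholder segments shaped like Whisper's output (not real results).
sample = [
    {"start": 0.0, "end": 2.5, "text": " Hello there."},
    {"start": 2.5, "end": 5.0, "text": " Second sentence."},
]
print(format_segments(sample))
```

With the real endpoint, `print(format_segments(res["segments"]))` would list the transcription with timestamps.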

## 3. Asynchronous inference

For inference that takes longer than 60 seconds, or when we want to save costs by autoscaling the instance count to zero when there are no requests to process, asynchronous inference should be used.

### Deploying the model to an asynchronous endpoint

In [None]:
from sagemaker.huggingface.model import HuggingFaceModel
from sagemaker.async_inference.async_inference_config import AsyncInferenceConfig
from sagemaker.s3 import s3_path_join
from sagemaker.utils import name_from_base

async_endpoint_name = name_from_base("whisper-large-v2-custom-async")

# create Hugging Face Model Class
huggingface_model = HuggingFaceModel(
    model_data=s3_location,  # path to your model and script
    role=role,  # iam role with permissions to create an Endpoint
    transformers_version="4.17",  # transformers version used
    pytorch_version="1.10",  # pytorch version used
    py_version="py38",  # python version used
)

# create async endpoint configuration
async_config = AsyncInferenceConfig(
    output_path=s3_path_join(
        "s3://", sagemaker_session_bucket, "async_inference/output"
    ),  # Where our results will be stored
    # Add SNS notification topics if needed
    notification_config={
        # "SuccessTopic": "PUT YOUR SUCCESS SNS TOPIC ARN",
        # "ErrorTopic": "PUT YOUR ERROR SNS TOPIC ARN",
    },  #  Notification configuration
)

# Run two model-server worker processes on the instance
env = {"MODEL_SERVER_WORKERS": "2"}

# deploy the model to an asynchronous endpoint
async_predictor = huggingface_model.deploy(
    initial_instance_count=1,
    instance_type="ml.g4dn.xlarge",
    async_inference_config=async_config,
    endpoint_name=async_endpoint_name,
    env=env,
)

### Execute inference

In [None]:
# Replace the placeholder with the S3 URI of an audio object
# Uncomment the "language" line to specify the input language; otherwise Whisper detects it automatically
data = {
    "s3_file": "REPLACE WITH A PATH TO AUDIO OBJECT IN S3",
    # "language": "pl",
}

res = async_predictor.predict_async(data=data)
print(res)

In [None]:
# Since this is async inference, get_result reads the object at output_path
# If the inference has completed, you get the result; otherwise you get an error saying the output_path file doesn't exist yet
res.get_result()

In [None]:
res.output_path

### Setting up autoscaling for the asynchronous endpoint

In [None]:
client = boto3.client(
    "application-autoscaling"
)  # Common class representing Application Auto Scaling for SageMaker amongst other services

resource_id = (
    "endpoint/" + async_endpoint_name + "/variant/" + "AllTraffic"
)  # This is the format in which application autoscaling references the endpoint

response = client.register_scalable_target(
    ServiceNamespace="sagemaker",
    ResourceId=resource_id,
    ScalableDimension="sagemaker:variant:DesiredInstanceCount",
    MinCapacity=1,  # an async endpoint can scale in to 0 instances if MinCapacity=0
    MaxCapacity=5,
)

response = client.put_scaling_policy(
    PolicyName="Invocations-ScalingPolicy",
    ServiceNamespace="sagemaker",  # The namespace of the AWS service that provides the resource.
    ResourceId=resource_id,  # Endpoint name
    ScalableDimension="sagemaker:variant:DesiredInstanceCount",  # SageMaker supports only Instance Count
    PolicyType="TargetTrackingScaling",  # 'StepScaling'|'TargetTrackingScaling'
    TargetTrackingScalingPolicyConfiguration={
        "TargetValue": 5.0,  # The target value for the metric.
        "CustomizedMetricSpecification": {
            "MetricName": "ApproximateBacklogSizePerInstance",
            "Namespace": "AWS/SageMaker",
            "Dimensions": [{"Name": "EndpointName", "Value": async_endpoint_name}],
            "Statistic": "Average",
        },
        "ScaleInCooldown": 300,  # ScaleInCooldown - The amount of time, in seconds, after a scale-in activity completes before another scale in activity can start.
        "ScaleOutCooldown": 300  # ScaleOutCooldown - The amount of time, in seconds, after a scale-out activity completes before another scale out activity can start.
        # 'DisableScaleIn': True|False - indicates whether scale in by the target tracking policy is disabled.
        # If the value is true, scale-in is disabled and the target tracking policy won't remove capacity from the scalable resource.
    },
)
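Target tracking on `ApproximateBacklogSizePerInstance` tries to keep the number of queued requests per instance near `TargetValue` (5 here). As rough intuition only, ignoring cooldowns, metric smoothing, and Application Auto Scaling's own rounding rules, the steady-state instance count is about the backlog divided by the target, clamped to `MinCapacity` and `MaxCapacity`. The helper name `approx_desired_instances` is hypothetical and does not reproduce the service's exact algorithm.

```python
import math


def approx_desired_instances(backlog: int, target_per_instance: float = 5.0,
                             min_capacity: int = 1, max_capacity: int = 5) -> int:
    """Rough estimate of the instance count a backlog-per-instance target converges to."""
    if target_per_instance <= 0:
        raise ValueError("target_per_instance must be positive")
    needed = math.ceil(backlog / target_per_instance)
    # Clamp to the MinCapacity/MaxCapacity registered for the scalable target
    return max(min_capacity, min(max_capacity, needed))


for backlog in (0, 3, 12, 40):
    print(backlog, approx_desired_instances(backlog))
```

Because this metric is a per-instance ratio, a policy on it alone cannot bring the endpoint back up from zero instances if you set `MinCapacity=0`; the SageMaker documentation describes adding a step scaling policy on the `HasBacklogWithoutCapacity` metric for that case.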

## 4. Invoke Whisper on SageMaker Endpoint for Asynchronous inference
This section demonstrates how to invoke the asynchronous endpoint deployed in section 3 directly through the low-level SageMaker Runtime API.


In [None]:
import boto3

# Create a low-level client representing Amazon SageMaker Runtime
# (it uses your default AWS Region; pass region_name=... to override)
sagemaker_runtime = boto3.client("sagemaker-runtime")

# S3 URI of the request object. It must contain the JSON payload expected by predict_fn,
# e.g. {"s3_file": "s3://<bucket>/<audio-file>"}, not the audio file itself
input_location = "REPLACE WITH A PATH TO A JSON REQUEST OBJECT IN S3"

# After you deploy a model using SageMaker hosting services, client applications
# use this API to get inferences from the model hosted at the specified endpoint.
# async_endpoint_name is the asynchronous endpoint deployed in section 3.
response = sagemaker_runtime.invoke_endpoint_async(
    EndpointName=async_endpoint_name,
    ContentType="application/json",  # the request object holds JSON
    InputLocation=input_location,
)

In [None]:
# View invocation response
response
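Unlike `predict_async`, which uploads the request to S3 for you, `invoke_endpoint_async` expects `InputLocation` to point at an S3 object that already holds the request body. For this inference script the body is the same JSON document used earlier (`s3_file` plus an optional `language`). Below is a sketch of building that body; `build_request_body` is a hypothetical helper, and the upload step in the comment assumes boto3's `put_object` with a bucket and key of your choosing.

```python
import json
from typing import Optional


def build_request_body(s3_file: str, language: Optional[str] = None) -> bytes:
    """Serialize the JSON request that predict_fn in inference.py expects."""
    payload = {"s3_file": s3_file}
    if language is not None:
        payload["language"] = language  # e.g. "pl"; omit to let Whisper detect it
    return json.dumps(payload).encode("utf-8")


body = build_request_body("s3://my-bucket/audio/sample.wav", language="pl")
print(body.decode("utf-8"))
# Upload it, then pass its S3 URI as InputLocation (requires AWS credentials), e.g.:
# boto3.client("s3").put_object(Bucket="my-bucket", Key="async_inference/input/request.json",
#                               Body=body, ContentType="application/json")
```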

#### Check Output Location

In [None]:
# Poll S3 until the asynchronous inference output is available

import time
import urllib.parse

from botocore.exceptions import ClientError


def get_output(output_location):
    output_url = urllib.parse.urlparse(output_location)
    bucket = output_url.netloc
    key = output_url.path[1:]
    while True:
        try:
            # sm_session is created in the next cell
            return sm_session.read_s3_file(bucket=bucket, key_prefix=key)
        except ClientError as e:
            if e.response["Error"]["Code"] == "NoSuchKey":
                print("waiting for output...")
                time.sleep(2)
                continue
            raise
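`get_output` retries every 2 seconds with no upper bound, so a request whose output never appears (for example, a failed inference) would block the notebook indefinitely. A bounded variant with exponential backoff is sketched below. It is written against a hypothetical `fetch` callable that raises `FileNotFoundError` while the result is missing, so it can be exercised without S3; to use it here, wrap `sm_session.read_s3_file` in a function that converts the `NoSuchKey` `ClientError` into `FileNotFoundError`.

```python
import time
from typing import Callable, TypeVar

T = TypeVar("T")


def poll_with_backoff(fetch: Callable[[], T], timeout: float = 600.0,
                      initial_delay: float = 2.0, max_delay: float = 30.0,
                      sleep: Callable[[float], None] = time.sleep) -> T:
    """Call fetch() until it stops raising FileNotFoundError or the timeout elapses."""
    deadline = time.monotonic() + timeout
    delay = initial_delay
    while True:
        try:
            return fetch()
        except FileNotFoundError:
            # Give up instead of sleeping past the deadline
            if time.monotonic() + delay > deadline:
                raise TimeoutError(f"no output after {timeout} seconds")
            print(f"waiting for output... retrying in {delay:.0f}s")
            sleep(delay)
            delay = min(delay * 2, max_delay)  # exponential backoff, capped
```

The `sleep` parameter is injectable so the retry schedule can be tested without actually waiting.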

In [None]:
import sagemaker

sm_session = sagemaker.session.Session()

#### Get Result

In [None]:
import json

output = get_output(response["OutputLocation"])
result = json.loads(output)

print(f"Output: {output}")
print("transcription: ", result.get("transcription"))

### Example of multiple invocations (useful for testing autoscaling)

In [None]:
# Use the same Region as the SageMaker session instead of a hard-coded one
sm_runtime = boto3.client("sagemaker-runtime", region_name=sess.boto_region_name)
inferences = []
for i in range(10):
    response = sm_runtime.invoke_endpoint_async(
        EndpointName=async_endpoint_name,
        ContentType="application/json",
        InputLocation=input_location,
    )
    output_location = response["OutputLocation"]
    inferences.append((input_location, output_location))
    time.sleep(0.5)

for input_location, output_location in inferences:
    output = get_output(output_location)
    print(f"Input File: {input_location}, Output: {output}")

## 5. Clean up

Remember to delete your endpoints after use: you are charged for the instances for as long as the endpoints are running.

In [None]:
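# Deregister the auto-scaling target, then delete both endpoints (and their endpoint configurations)
client.deregister_scalable_target(ServiceNamespace="sagemaker", ResourceId=resource_id, ScalableDimension="sagemaker:variant:DesiredInstanceCount")
rl_predictor.delete_endpoint()
async_predictor.delete_endpoint()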

## Notebook CI Test Results

This notebook was tested in multiple regions. The test results are as follows, except for us-west-2 which is shown at the top of the notebook.

![This us-east-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-east-1/async-inference|Transcription_on_SM_endpoint.ipynb)

![This us-east-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-east-2/async-inference|Transcription_on_SM_endpoint.ipynb)

![This us-west-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-west-1/async-inference|Transcription_on_SM_endpoint.ipynb)

![This ca-central-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ca-central-1/async-inference|Transcription_on_SM_endpoint.ipynb)

![This sa-east-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/sa-east-1/async-inference|Transcription_on_SM_endpoint.ipynb)

![This eu-west-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-west-1/async-inference|Transcription_on_SM_endpoint.ipynb)

![This eu-west-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-west-2/async-inference|Transcription_on_SM_endpoint.ipynb)

![This eu-west-3 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-west-3/async-inference|Transcription_on_SM_endpoint.ipynb)

![This eu-central-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-central-1/async-inference|Transcription_on_SM_endpoint.ipynb)

![This eu-north-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-north-1/async-inference|Transcription_on_SM_endpoint.ipynb)

![This ap-southeast-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-southeast-1/async-inference|Transcription_on_SM_endpoint.ipynb)

![This ap-southeast-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-southeast-2/async-inference|Transcription_on_SM_endpoint.ipynb)

![This ap-northeast-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-northeast-1/async-inference|Transcription_on_SM_endpoint.ipynb)

![This ap-northeast-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-northeast-2/async-inference|Transcription_on_SM_endpoint.ipynb)

![This ap-south-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-south-1/async-inference|Transcription_on_SM_endpoint.ipynb)
