# How to deploy OpenAI's Whisper ASR (automatic speech recognition) model for inference on Amazon SageMaker AI

In this notebook, you will learn how to deploy **OpenAI's whisper-large-v2** model (Hugging Face model ID [openai/whisper-large-v2](https://huggingface.co/openai/whisper-large-v2)) using Amazon SageMaker AI. The inference image will be the SageMaker-managed [LMI (Large Model Inference) v15](https://docs.aws.amazon.com/sagemaker/latest/dg/large-model-inference-container-docs.html) Docker image. LMI images feature a [DJL Serving](https://github.com/deepjavalibrary/djl-serving) stack powered by the [Deep Java Library](https://djl.ai/).

Whisper Large v2 is a version of OpenAI's robust, multilingual speech-to-text model. It performs well on benchmarks such as Common Voice and FLEURS, supports transcription and translation, and is suitable for applications such as podcasting, video subtitles, and lecture note-taking. For more details, see OpenAI's [Whisper announcement post](https://openai.com/index/whisper/).

### Key Features

- Multilingual and Multitask:
Whisper is trained on a massive dataset of multilingual audio, allowing it to transcribe in many languages and even translate speech into English. 
- Robustness:
It's designed to be resilient to various acoustic challenges, including different accents and noisy environments. 
- Improved Performance:
Large v2 improves on large v1, with relative error reductions of around 5% in English and around 10% in other languages.
- End-to-End Transformer:
The model uses an encoder-decoder transformer architecture, processing audio by converting it into a log-Mel spectrogram before feeding it to the encoder. 
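
To make the encoder input concrete, here is a small arithmetic sketch of the spectrogram shape for one audio chunk. The 16 kHz sample rate and 30-second window match the audio requirements noted later in this notebook. The 160-sample hop (10 ms) and 80 Mel bins are the values commonly documented for Whisper large-v2 and are treated as assumptions here; `spectrogram_shape` is a hypothetical helper.

```python
# Assumed Whisper large-v2 front-end settings (not defined elsewhere in this notebook).
SAMPLE_RATE = 16_000   # samples per second
CHUNK_SECONDS = 30     # audio is padded or trimmed to this length
HOP_LENGTH = 160       # samples between consecutive frames (10 ms)
N_MELS = 80            # number of Mel frequency bins


def spectrogram_shape(sample_rate=SAMPLE_RATE, seconds=CHUNK_SECONDS,
                      hop_length=HOP_LENGTH, n_mels=N_MELS):
    """Return (n_mels, n_frames) for one padded audio chunk."""
    n_samples = sample_rate * seconds
    n_frames = n_samples // hop_length
    return n_mels, n_frames


print(spectrogram_shape())  # (80, 3000)
```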

### Usage

We provide a reference implementation of whisper-large-v2, as well as sampling code, in a dedicated GitHub repository. Developers and creatives looking to build on top of whisper-large-v2 are encouraged to use this as a starting point.

### Out-of-Scope Use
The model and its derivatives may not be used:

- In any way that violates any applicable national, federal, state, local or international law or regulation.
- For the purpose of exploiting, harming or attempting to exploit or harm minors in any way; including but not limited to the solicitation, creation, acquisition, or dissemination of child exploitative content.
- To generate or disseminate verifiably false information and/or content with the purpose of harming others.
- To generate or disseminate personally identifiable information that can be used to harm an individual.
- To harass, abuse, threaten, stalk, or bully individuals or groups of individuals.
- To create non-consensual nudity or illegal pornographic content.
- For fully automated decision making that adversely impacts an individual's legal rights or otherwise creates or modifies a binding, enforceable obligation.
- To generate or facilitate large-scale disinformation campaigns.


### License agreement
* This model is open source on Hugging Face. Please refer to the original [model card](https://huggingface.co/openai/whisper-large-v2).
* This notebook is a sample notebook and not intended for production use.

In [76]:
%pip install sagemaker --upgrade --quiet --no-warn-conflicts

Note: you may need to restart the kernel to use updated packages.


In [77]:
import json
import sagemaker
import boto3
from sagemaker.s3 import S3Uploader


role = sagemaker.get_execution_role()  # execution role for the endpoint
sess = sagemaker.session.Session()  # sagemaker session for interacting with different AWS APIs
bucket = sess.default_bucket()  # bucket to house artifacts
region = sess.boto_region_name  # region name of the current SageMaker session

sm_client = boto3.client("sagemaker")  # client to interact with SageMaker
smr_client = boto3.client("sagemaker-runtime")  # client to interact with SageMaker endpoints
s3_client = boto3.client("s3")

print(f"sagemaker role arn: {role}")
print(f"sagemaker bucket: {sess.default_bucket()}")
print(f"sagemaker session region: {sess.boto_region_name}")
print(f"sagemaker version: {sagemaker.__version__}")

sagemaker role arn: arn:aws:iam::992382553328:role/amazon-sagemaker-base-executionrole
sagemaker bucket: sagemaker-us-west-2-992382553328
sagemaker session region: us-west-2
sagemaker version: 2.251.0


## HF container with default handler

In [78]:
from sagemaker.huggingface import HuggingFaceModel

model_name = sagemaker.utils.name_from_base("model")
endpoint_name = model_name

# Hub model configuration. https://huggingface.co/models
hub = {
    "HF_MODEL_ID": "openai/whisper-large-v2",
    "HF_TASK": "automatic-speech-recognition",
}

# create the Hugging Face model class
huggingface_model = HuggingFaceModel(
    name=model_name,
    transformers_version='4.49.0',
    pytorch_version='2.6.0',
    py_version='py312',
    env=hub,
    role=role,
)

# deploy the model to SageMaker Inference
predictor = huggingface_model.deploy(
    initial_instance_count=1,  # number of instances
    instance_type='ml.g5.12xlarge',  # EC2 instance type
    endpoint_name=endpoint_name,
)

-----------!

### Download audio files

In [79]:
from sagemaker.jumpstart import utils

# The WAV files must be sampled at 16 kHz (required by the automatic speech recognition models),
# so resample them if necessary. Each input audio file must be shorter than 30 seconds.
s3_bucket = utils.get_jumpstart_content_bucket()
key_prefix = "training-datasets/asr_notebook_data"
input_audio_file_name = "sample1.wav"

s3_client.download_file(s3_bucket, f"{key_prefix}/{input_audio_file_name}", input_audio_file_name)

input_audio_file_name = "sample_french1.wav"

s3_client.download_file(s3_bucket, f"{key_prefix}/{input_audio_file_name}", input_audio_file_name)
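
As a hedged illustration of the 16 kHz and 30-second requirements noted in the cell above, the sketch below checks a WAV file using only the Python standard library. The helpers `check_wav` and `write_test_tone` are hypothetical names introduced for this example, and the synthetic tone files stand in for real audio so the check can run anywhere. The `wave` module reads PCM WAV files only, so other encodings would need a different reader.

```python
import math
import struct
import tempfile
import wave
from pathlib import Path

REQUIRED_SAMPLE_RATE = 16_000  # Hz, as noted in the cell above
MAX_DURATION_SECONDS = 30.0    # upper bound noted in the cell above


def check_wav(path):
    """Return (sample_rate, duration_seconds, ok) for a PCM WAV file."""
    with wave.open(str(path), "rb") as wav:
        sample_rate = wav.getframerate()
        duration = wav.getnframes() / sample_rate
    ok = sample_rate == REQUIRED_SAMPLE_RATE and duration < MAX_DURATION_SECONDS
    return sample_rate, duration, ok


def write_test_tone(path, sample_rate, seconds):
    """Write a mono 16-bit 440 Hz sine wave so the check can run without real audio."""
    n_frames = int(sample_rate * seconds)
    samples = (int(12000 * math.sin(2 * math.pi * 440 * i / sample_rate)) for i in range(n_frames))
    with wave.open(str(path), "wb") as wav:
        wav.setnchannels(1)
        wav.setsampwidth(2)
        wav.setframerate(sample_rate)
        wav.writeframes(b"".join(struct.pack("<h", s) for s in samples))


with tempfile.TemporaryDirectory() as tmp:
    good = Path(tmp) / "good.wav"
    bad = Path(tmp) / "bad.wav"
    write_test_tone(good, 16_000, 2)  # meets both requirements
    write_test_tone(bad, 8_000, 2)    # wrong sample rate
    good_result = check_wav(good)
    bad_result = check_wav(bad)

print(good_result)  # (16000, 2.0, True)
print(bad_result)   # (8000, 2.0, False)
```

In the notebook, the same check could be pointed at `sample1.wav` or `sample_french1.wav` after the download completes.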

In [80]:
from sagemaker.serializers import DataSerializer

predictor.serializer = DataSerializer(content_type='audio/x-audio')
predictor.content_type = "audio/x-audio"

# input_audio_file_name still refers to "sample_french1.wav" from the previous cell
with open(input_audio_file_name, "rb") as f:
    data = f.read()
predictor.predict(data)

{'text': " Bienvenue chez JPB Systèmes, ici. C'est plus de 150 collaborateurs, c'est plus de 90% de chiffre d'affaires à l'export et d'un produit, c'est une quinzaine de preuves que nous avons développées."}

In [81]:
sess.delete_endpoint(endpoint_name)
sess.delete_endpoint_config(endpoint_name)
sess.delete_model(model_name)

## HF container with custom handler

The model is deployed from the Hugging Face Hub, and the custom handler is stored on Amazon S3.

## Download the model from Hugging Face and upload the model artifacts to Amazon S3
If you are deploying a model hosted on the Hugging Face Hub, you must specify the `option.model_id=<hf_hub_model_id>` configuration. When using a model directly from the Hub, we recommend that you also specify the model revision (commit hash or branch) via `option.revision=<commit hash/branch>`. *Here we are using environment variables during deployment instead of a `serving.properties` file.*

Because model artifacts are downloaded from the Hub at runtime, pinning a specific revision ensures that you are using a model compatible with the package versions in the runtime environment. Open-source model artifacts on the Hub are subject to change at any time, and such changes may cause issues when instantiating the model (for example, updated artifacts may require a newer version of a dependency than the one bundled in the container). If a model provides custom model (modeling.py) and/or custom tokenizer (tokenizer.py) files, you need to specify `option.trust_remote_code=true` to load and use the model.
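
As an illustration of using environment variables in place of a `serving.properties` file, the sketch below converts `option.*` keys into `OPTION_*` names. It assumes the convention that `option.<name>` maps to `OPTION_<NAME>`, which matches the `OPTION_MODEL_ID` and `OPTION_TRUST_REMOTE_CODE` variables used later in this notebook. The helper `properties_to_env` is hypothetical, and the revision value `main` is only a placeholder.

```python
def properties_to_env(properties):
    """Convert serving.properties-style 'option.*' keys to OPTION_* environment variables."""
    env = {}
    for key, value in properties.items():
        if not key.startswith("option."):
            # Only option.* keys are handled in this sketch.
            raise ValueError(f"Unsupported key: {key}")
        name = "OPTION_" + key[len("option."):].upper()
        # Booleans are written in lowercase, matching the "true"/"false" strings used in this notebook.
        env[name] = str(value).lower() if isinstance(value, bool) else str(value)
    return env


properties = {
    "option.model_id": "openai/whisper-large-v2",
    "option.revision": "main",  # placeholder: pin a commit hash in practice
    "option.trust_remote_code": True,
}
env = properties_to_env(properties)
print(env)
```

The resulting dictionary has the shape expected by the `env` argument of a SageMaker model object.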

In this example, we demonstrate how to download your own copy of the model from Hugging Face, upload it to an S3 location in your AWS account, and then deploy the model to an endpoint from those uploaded artifacts.

**Best Practices**:
>
> **Store Models in Your Own S3 Bucket**
> For production use cases, always download and store model files in your own S3 bucket to ensure validated artifacts. This provides verified provenance, improved access control, consistent availability, protection against upstream changes, and compliance with organizational security protocols.
>
> ⚠️ **Important**:
> - Downloading files can take time. Please ensure this step completes.

In [83]:
from huggingface_hub import snapshot_download
from pathlib import Path
from transformers import AutoModel
import torch

model_dir = Path('model-files')
model_dir.mkdir(exist_ok=True)
HF_MODEL_ID = "openai/whisper-large-v2"
snapshot_download(HF_MODEL_ID, local_dir=model_dir)


Fetching 16 files:   0%|          | 0/16 [00:00<?, ?it/s]

'/home/sagemaker-user/sagemaker-genai-hosting-examples/workshop/inference/lab4/Whishper/model-files'

In [84]:
base_name = HF_MODEL_ID.split('/')[-1].replace('.', '-').lower()
model_lineage = HF_MODEL_ID.split("/")[0]
base_name

'whisper-large-v2'

**Best Practices**:
>
> **Note**: When your model and configuration files are in different S3 locations, set `option.model_id=<s3_model_uri>` in your `serving.properties` file, where `s3_model_uri` is the S3 object prefix containing your model artifacts. SageMaker AI automatically downloads the model files from the S3 URI given in `model_id`.
>
> *Here we are using environment variables during deployment instead of a `serving.properties` file.*

In [85]:
%%writefile ./model-files/requirements.txt
openai-whisper
ffmpeg
torchaudio
nvgpu
transformers>=4.46.0

Writing ./model-files/requirements.txt


In [86]:
%%writefile ./model-files/inference.py
import ast
import io
import json
import logging
from copy import deepcopy
from typing import Any
from typing import Dict
from typing import List
from typing import Tuple
from typing import Union
import numpy as np
import torch
from sagemaker_inference import encoder
from scipy.io.wavfile import read
from transformers import WhisperForConditionalGeneration
from transformers import WhisperProcessor
from transformers.pipelines.audio_utils import ffmpeg_read

SAMPLE_RATE = 16000
AUDIO_WAV = "audio/wav"
APPLICATION_JSON = "application/json"
STR_DECODE_CODE = "utf-8"

VERBOSE_EXTENSION = ";verbose"

AUTOMATIC_SPEECH_RECOGNITION = "automatic-speech-recognition"
TEXT = "text"

SUPPORTED_LANGUAGES = [
    "english",
    "chinese",
    "german",
    "spanish",
    "russian",
    "korean",
    "french",
    "japanese",
    "portuguese",
    "turkish",
    "polish",
    "catalan",
    "dutch",
    "arabic",
    "swedish",
    "italian",
    "indonesian",
    "hindi",
    "finnish",
    "vietnamese",
    "hebrew",
    "ukrainian",
    "greek",
    "malay",
    "czech",
    "romanian",
    "danish",
    "hungarian",
    "tamil",
    "norwegian",
    "thai",
    "urdu",
    "croatian",
    "bulgarian",
    "lithuanian",
    "latin",
    "maori",
    "malayalam",
    "welsh",
    "slovak",
    "telugu",
    "persian",
    "latvian",
    "bengali",
    "serbian",
    "azerbaijani",
    "slovenian",
    "kannada",
    "estonian",
    "macedonian",
    "breton",
    "basque",
    "icelandic",
    "armenian",
    "nepali",
    "mongolian",
    "bosnian",
    "kazakh",
    "albanian",
    "swahili",
    "galician",
    "marathi",
    "punjabi",
    "sinhala",
    "khmer",
    "shona",
    "yoruba",
    "somali",
    "afrikaans",
    "occitan",
    "georgian",
    "belarusian",
    "tajik",
    "sindhi",
    "gujarati",
    "amharic",
    "yiddish",
    "lao",
    "uzbek",
    "faroese",
    "haitian creole",
    "pashto",
    "turkmen",
    "nynorsk",
    "maltese",
    "sanskrit",
    "luxembourgish",
    "myanmar",
    "tibetan",
    "tagalog",
    "malagasy",
    "assamese",
    "tatar",
    "hawaiian",
    "lingala",
    "hausa",
    "bashkir",
    "javanese",
    "sundanese",
    "burmese",
    "valencian",
    "flemish",
    "haitian",
    "letzeburgesch",
    "pushto",
    "panjabi",
    "moldavian",
    "moldovan",
    "sinhalese",
    "castilian",
]
SUPPORTED_TASKS = ["translate", "transcribe"]

# Audio Parameters
AUDIO_INPUT = "audio_input"
LANGUAGE = "language"
TASK = "task"
FORCED_DECODER_IDS = "forced_decoder_ids"

# Text Generation parameters
MAX_LENGTH = "max_length"
NUM_RETURN_SEQUENCES = "num_return_sequences"
NUM_BEAMS = "num_beams"
TOP_P = "top_p"
EARLY_STOPPING = "early_stopping"
DO_SAMPLE = "do_sample"
NO_REPEAT_NGRAM_SIZE = "no_repeat_ngram_size"
TOP_K = "top_k"
TEMPERATURE = "temperature"
MIN_LENGTH = "min_length"
MIN_NEW_TOKENS = "min_new_tokens"
MAX_NEW_TOKENS = "max_new_tokens"
LENGTH_PENALTY = "length_penalty"
MAX_TIME = "max_time"


ALL_PARAM_NAMES = [
    AUDIO_INPUT,
    LANGUAGE,
    TASK,
    FORCED_DECODER_IDS,
    MAX_LENGTH,
    NUM_RETURN_SEQUENCES,
    NUM_BEAMS,
    TOP_P,
    EARLY_STOPPING,
    DO_SAMPLE,
    NO_REPEAT_NGRAM_SIZE,
    TOP_K,
    TEMPERATURE,
    MIN_LENGTH,
    MAX_NEW_TOKENS,
    MIN_NEW_TOKENS,
    LENGTH_PENALTY,
    MAX_TIME,
]

# Model parameter ranges
LENGTH_MIN = 1
NUM_RETURN_SEQUENCE_MIN = 1
NUM_BEAMS_MIN = 1
TOP_P_MIN = 0
TOP_P_MAX = 1
NO_REPEAT_NGRAM_SIZE_MIN = 1
TOP_K_MIN = 0
TEMPERATURE_MIN = 0
NEW_TOKENS_MIN = 0


def is_list_of_strings(parameter: Any) -> bool:
    """Return True if the parameter is a list of strings."""
    if parameter and isinstance(parameter, list):
        return all(isinstance(elem, str) for elem in parameter)
    else:
        return False


def _validate_payload(payload: Dict[str, Any]) -> Dict[str, Any]:
    """Validate the parameters in the input payload.

    Checks if max_length, num_return_sequences, num_beams, top_p and temperature are in bounds.
    Checks if do_sample is boolean.
    Checks if max_length, num_return_sequences and num_beams are integers.

    Args:
        payload: a decoded input payload (dictionary of input parameter and values)

    Raises: ValueError if any of the checks fails.
    """
    # For all parameters used in generation task, please see
    # https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationConfig
    for param_name in payload:
        if param_name not in ALL_PARAM_NAMES:
            raise ValueError(f"Input payload contains an invalid key '{param_name}'. Valid keys are {ALL_PARAM_NAMES}.")

    if AUDIO_INPUT not in payload:
        raise ValueError(f"Input payload must contain {AUDIO_INPUT} key.")

    if LANGUAGE in payload:
        value = payload[LANGUAGE]
        if type(value) != str:
            raise ValueError(f"{LANGUAGE} must be a string, got {value}.")
        value = value.lower()
        payload[LANGUAGE] = value
        if value not in SUPPORTED_LANGUAGES:
            raise ValueError(
                f"Input payload contains an invalid language {value}. "
                f"Valid languages are {SUPPORTED_LANGUAGES}."
            )
        if TASK not in payload:
            raise ValueError("Input payload should contain both language and task")

    if TASK in payload:
        value = payload[TASK]
        if type(value) != str:
            raise ValueError(f"{TASK} must be a string, got {value}.")
        value = value.lower()
        payload[TASK] = value
        if value not in SUPPORTED_TASKS:
            raise ValueError(
                f"Input payload contains an invalid task {value}. Valid tasks are {SUPPORTED_TASKS}."
            )
        if LANGUAGE not in payload:
            raise ValueError("Input payload should contain both language and task")

    for param_name in [MAX_LENGTH, NUM_RETURN_SEQUENCES, NUM_BEAMS]:
        if param_name in payload:
            if type(payload[param_name]) != int:
                raise ValueError(f"{param_name} must be an integer, got {payload[param_name]}.")

    if MAX_LENGTH in payload:
        if payload[MAX_LENGTH] < LENGTH_MIN:
            raise ValueError(f"{MAX_LENGTH} must be at least {LENGTH_MIN}, got {payload[MAX_LENGTH]}.")

    if MIN_LENGTH in payload:
        if payload[MIN_LENGTH] < LENGTH_MIN:
            raise ValueError(f"{MIN_LENGTH} must be at least {LENGTH_MIN}, got {payload[MIN_LENGTH]}.")

    if MAX_NEW_TOKENS in payload:
        if payload[MAX_NEW_TOKENS] < NEW_TOKENS_MIN:
            raise ValueError(f"{MAX_NEW_TOKENS} must be at least {NEW_TOKENS_MIN}, got {payload[MAX_NEW_TOKENS]}.")

    if MIN_NEW_TOKENS in payload:
        if payload[MIN_NEW_TOKENS] < NEW_TOKENS_MIN:
            raise ValueError(f"{MIN_NEW_TOKENS} must be at least {NEW_TOKENS_MIN}, got {payload[MIN_NEW_TOKENS]}.")

    if NUM_RETURN_SEQUENCES in payload:
        if payload[NUM_RETURN_SEQUENCES] < NUM_RETURN_SEQUENCE_MIN:
            raise ValueError(
                f"{NUM_RETURN_SEQUENCES} must be at least {NUM_RETURN_SEQUENCE_MIN}, "
                f"got {payload[NUM_RETURN_SEQUENCES]}."
            )

    if NUM_BEAMS in payload:
        if payload[NUM_BEAMS] < NUM_BEAMS_MIN:
            raise ValueError(f"{NUM_BEAMS} must be at least {NUM_BEAMS_MIN}, got {payload[NUM_BEAMS]}.")

    if NUM_RETURN_SEQUENCES in payload and NUM_BEAMS in payload:
        if payload[NUM_RETURN_SEQUENCES] > payload[NUM_BEAMS]:
            raise ValueError(
                f"{NUM_BEAMS} must be at least {NUM_RETURN_SEQUENCES}. Instead got "
                f"{NUM_BEAMS}={payload[NUM_BEAMS]} and {NUM_RETURN_SEQUENCES}="
                f"{payload[NUM_RETURN_SEQUENCES]}."
            )

    if TOP_P in payload:
        if payload[TOP_P] < TOP_P_MIN or payload[TOP_P] > TOP_P_MAX:
            raise ValueError(f"{TOP_P} must be in range [{TOP_P_MIN},{TOP_P_MAX}], got {payload[TOP_P]}.")

    if TEMPERATURE in payload:
        if payload[TEMPERATURE] < TEMPERATURE_MIN:
            raise ValueError(
                f"{TEMPERATURE} must be a float with value at least {TEMPERATURE_MIN}, got " f"{payload[TEMPERATURE]}."
            )

    if DO_SAMPLE in payload:
        if type(payload[DO_SAMPLE]) != bool:
            raise ValueError(f"{DO_SAMPLE} must be a boolean, got {payload[DO_SAMPLE]}.")

    return payload


def _update_num_beams(payload: Dict[str, Union[str, float, int]]) -> Dict[str, Union[str, float, int]]:
    """Add num_beams to the payload if missing and num_return_sequences is present.

    Args:
        payload (Dict): dictionary of input text and parameters
    Returns:
        payload (Dict): payload with number of beams updated
    """

    if NUM_RETURN_SEQUENCES in payload and NUM_BEAMS not in payload:
        payload[NUM_BEAMS] = payload[NUM_RETURN_SEQUENCES]
    return payload



class ModelAndProcessor:
    """An ASR model with explicit model and processor objects."""

    def __init__(self, model_dir: str) -> None:
        """Initialize model with provided model kwargs and processor objects."""
        device = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.model = WhisperForConditionalGeneration.from_pretrained(model_dir)
        self.model = self.model.to(device)
        self.model.config.forced_decoder_ids = None
        self.model.eval()
        logging.info("Loaded model")

        self.processor = WhisperProcessor.from_pretrained(model_dir)
        logging.info("Loaded processor")

    def __call__(self, audio_input: Dict, **kwargs: Any) -> Dict[str, List[str]]:
        """Perform inference via calls to processor and model's generate method.

        If the model is loaded on the GPU, the input features are placed on the GPU device context.
        """
        input_features = self.processor(
            audio_input["raw"], sampling_rate=audio_input["sampling_rate"], return_tensors="pt"
        ).input_features

        if next(self.model.parameters()).is_cuda:
            input_ids_device = input_features.cuda()
        else:
            input_ids_device = input_features

        if kwargs:
            if LANGUAGE in kwargs:
                language = kwargs.pop(LANGUAGE)
                task = kwargs.pop(TASK)
                kwargs[FORCED_DECODER_IDS] = self.processor.get_decoder_prompt_ids(
                    language=language, task=task
                )

        predicted_ids = self.model.generate(input_ids_device, **kwargs)

        outputs = self.processor.batch_decode(predicted_ids, skip_special_tokens=True)

        return {TEXT: outputs}


def model_fn(model_dir: str) -> ModelAndProcessor:
    """Create our inference task as a delegate to the model.

    This runs only once per worker.

    Args:
        model_dir (str): directory where the model files are stored.
    Returns:
        ModelAndProcessor: a wrapper holding the Hugging Face Whisper model for automatic
            speech recognition and the processor that pre-processes the audio inputs and
            post-processes the model outputs.

    Raises:
        Exception: re-raised after logging if the model or processor cannot be loaded.
    """
    try:
        return ModelAndProcessor(model_dir)
    except Exception:
        logging.exception(f"Failed to load model from: {model_dir}")
        raise


def transform_fn(
    audio_generator_processor: ModelAndProcessor,
    input_data: bytes,
    content_type: str,
    accept: str,
) -> bytes:
    """Make predictions against the model and return a serialized response.

    The function signature conforms to the SM contract.

    Args:
        audio_generator_processor: the ModelAndProcessor object returned by model_fn.
        input_data (bytes): the request data.
        content_type (str): the request content type.
        accept (str): accept header expected by the client.
    Returns:
        bytes: the serialized prediction.
    """

    if content_type == AUDIO_WAV:
        try:
            data = ffmpeg_read(input_data, SAMPLE_RATE)
            audio_input = {"sampling_rate": SAMPLE_RATE, "raw": data}
        except Exception:
            logging.exception(
                f"Failed to parse input payload. For content_type= {AUDIO_WAV}, input "
                f"payload must be a bytearray"
            )
            raise
        try:
            output = audio_generator_processor(deepcopy(audio_input))
        except Exception:
            logging.exception("Failed to do inference")
            raise

    elif content_type == APPLICATION_JSON:
        try:
            payload = json.loads(input_data)
        except Exception:
            logging.exception(
                f"Failed to parse input payload. For content_type={APPLICATION_JSON}, input "
                f"payload must be a json encoded dictionary with keys {ALL_PARAM_NAMES}."
            )
            raise
        payload = _validate_payload(payload)
        payload = _update_num_beams(payload)
        audio_input = payload.pop(AUDIO_INPUT)
        audio_input = ffmpeg_read(bytes.fromhex(audio_input), SAMPLE_RATE)

        audio_input = {"sampling_rate": SAMPLE_RATE, "raw": audio_input}

        try:
            output = audio_generator_processor(deepcopy(audio_input), **payload)
        except Exception:
            logging.exception("Failed to do inference")
            raise
    else:
        raise ValueError('{{"error": "unsupported content type {}"}}'.format(content_type or "unknown"))
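    if accept.endswith(VERBOSE_EXTENSION):
        # Verbose and non-verbose responses are identical, so drop the literal suffix.
        # str.rstrip would strip a set of characters rather than the suffix itself.
        accept = accept[: -len(VERBOSE_EXTENSION)]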
    return encoder.encode(output, accept)


Writing ./model-files/inference.py
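
The handler accepts an `accept` header that may end in `;verbose`. When trimming a fixed suffix like this, note that `str.rstrip` removes any trailing characters from a set rather than a literal suffix. The sketch below shows the difference with a hypothetical `strip_suffix` helper; `text/csv` is used purely as an illustrative header value.

```python
VERBOSE_EXTENSION = ";verbose"


def strip_suffix(value, suffix):
    """Remove suffix from value only when value actually ends with it."""
    if suffix and value.endswith(suffix):
        return value[: -len(suffix)]
    return value


header = "text/csv;verbose"
print(header.rstrip(VERBOSE_EXTENSION))         # text/c  (trailing characters stripped, not the suffix)
print(strip_suffix(header, VERBOSE_EXTENSION))  # text/csv
```

On Python 3.9 and later, `str.removesuffix` provides the same behavior as the helper.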


### Upload model files to S3 in uncompressed format for SageMaker AI
SageMaker AI allows us to provide [uncompressed](https://docs.aws.amazon.com/sagemaker/latest/dg/large-model-inference-uncompressed.html) files, so we upload the folder that contains the model files directly to S3.
> **Note**: The default SageMaker bucket follows the naming pattern: `sagemaker-{region}-{account-id}`

> ⚠️ **Important**:
> - Uploading to S3 can take approximately 5 minutes. Please ensure this step completes.
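
Because the upload can take several minutes, it may help to check how much data will be transferred first. The sketch below counts files and bytes under a directory using only the standard library. `summarize_directory` is a hypothetical helper, and the throwaway directory stands in for `./model-files`.

```python
import tempfile
from pathlib import Path


def summarize_directory(local_path):
    """Return (file_count, total_bytes) for all regular files under local_path."""
    files = [p for p in Path(local_path).rglob("*") if p.is_file()]
    return len(files), sum(p.stat().st_size for p in files)


# Demonstration on a throwaway directory; point this at "./model-files" in the notebook.
with tempfile.TemporaryDirectory() as tmp:
    (Path(tmp) / "config.json").write_bytes(b"{}")
    (Path(tmp) / "nested").mkdir()
    (Path(tmp) / "nested" / "weights.bin").write_bytes(b"\0" * 1024)
    file_count, total_bytes = summarize_directory(tmp)

print(f"{file_count} files, {total_bytes} bytes")  # 2 files, 1026 bytes
```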


In [87]:
# upload uncompressed model files to S3
model_artifact_uri = S3Uploader.upload(
    local_path="./model-files",
    desired_s3_uri=f"s3://{bucket}/lmi/{base_name}"
)
print(f"Model files are uploaded to --- >: {model_artifact_uri}")

Model files are uploaded to --- >: s3://sagemaker-us-west-2-992382553328/lmi/whisper-large-v2


> **Note**: Here the S3 URI points to the S3 location of the configuration files.

## Configure Model Container and Instance

For deploying whisper-large-v2, we'll use:
- **Hugging Face PyTorch Inference DLC (Deep Learning Container)**: A SageMaker-managed container for serving Hugging Face models
- **[G5 Instance](https://aws.amazon.com/ec2/instance-types/g5/)**: High-performance GPU-based instances for graphics-intensive applications and machine learning inference

Key configurations:
- The container URI points to the Hugging Face PyTorch inference container in ECR (Elastic Container Registry)
- We use an `ml.g5.4xlarge` instance
> **Note**: The region in the container URI should match your AWS region.
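
To illustrate the note about matching regions, here is a minimal sketch that builds the image URI from a region string and reads the region back out of it. The registry account and image tag are copied from the deployment cell later in this notebook; some AWS regions may use a different registry account, so treat the constant as an assumption. `build_image_uri` and `uri_region` are hypothetical helpers.

```python
import re

# Registry account and tag copied from the deployment cell in this notebook (assumed valid for your region).
DLC_ACCOUNT = "763104351884"
DLC_REPO_TAG = "huggingface-pytorch-inference:2.6.0-transformers4.49.0-gpu-py312-cu124-ubuntu22.04"


def build_image_uri(region):
    """Build the container image URI for the given AWS region."""
    if not re.fullmatch(r"[a-z]{2}(-[a-z]+)+-\d", region):
        raise ValueError(f"Unexpected region format: {region!r}")
    return f"{DLC_ACCOUNT}.dkr.ecr.{region}.amazonaws.com/{DLC_REPO_TAG}"


def uri_region(image_uri):
    """Extract the region embedded in an ECR image URI."""
    return image_uri.split(".dkr.ecr.")[1].split(".amazonaws.com")[0]


uri = build_image_uri("us-west-2")
print(uri)
print(uri_region(uri))  # us-west-2
```

Comparing `uri_region(image_uri)` with the session region before deploying is one way to catch a mismatch early.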
## Create SageMaker Model

Now we'll create a SageMaker Model object that combines our:
- Container image (Hugging Face PyTorch inference container)
- Code artifacts (configuration files)
- IAM role (for permissions)

In [88]:
# OVERRIDE:
from sagemaker.utils import name_from_base

image_uri = f"763104351884.dkr.ecr.{region}.amazonaws.com/huggingface-pytorch-inference:2.6.0-transformers4.49.0-gpu-py312-cu124-ubuntu22.04"
model_name = name_from_base(base_name, short=True)
endpoint_name = model_name

# sagemaker config
instance_type = "ml.g5.4xlarge"
health_check_timeout = 900

model = sagemaker.Model(
    role=role,
    name=model_name,
    image_uri=image_uri,
    model_data={
        'S3DataSource': {
            'S3Uri': f"{model_artifact_uri}/",
            'S3DataType': 'S3Prefix',
            'CompressionType': 'None'
        }
    },
    env={
        "MMS_MAX_REQUEST_SIZE": '2000000000',
        "MMS_MAX_RESPONSE_SIZE": '2000000000',
        "MMS_DEFAULT_RESPONSE_TIMEOUT": '900',
        "HF_TASK": "automatic-speech-recognition",
        "SERVING_FAIL_FAST": "true",
        "OPTION_MODEL_ID": "/opt/ml/model",
        "OPTION_ASYNC_MODE": "false",
        "OPTION_ROLLING_BATCH": "disable",
        "OPTION_TENSOR_PARALLEL_DEGREE": "max",
        "OPTION_TRUST_REMOTE_CODE": "true",
        "OPTION_ENTRYPOINT": "inference.py"
    },
    sagemaker_session=sess,
)

# Deploy model to an endpoint
model.deploy(
    initial_instance_count=1,
    instance_type=instance_type,
    container_startup_health_check_timeout=health_check_timeout,
    endpoint_name=endpoint_name,
)

----------!

In [89]:
from sagemaker.serializers import DataSerializer
from sagemaker.deserializers import JSONDeserializer
from sagemaker.predictor import Predictor

# Define serializers and deserializer
audio_serializer = DataSerializer(content_type="audio/x-audio")
deserializer = JSONDeserializer()

### Create a predictor from our existing endpoint and make inference

In [90]:
predictor = Predictor(
    endpoint_name=endpoint_name,
    serializer=audio_serializer,
    deserializer=deserializer,
    sagemaker_session=sess
)

In [91]:
def query_endpoint(body, content_type):
    response = smr_client.invoke_endpoint(EndpointName=endpoint_name, ContentType=content_type, Body=body)
    model_predictions = json.loads(response['Body'].read())
    print(json.dumps(model_predictions, indent=2))

### Speech to transcribed text

In [92]:
input_audio_file_name = "sample1.wav"

with open(input_audio_file_name, "rb") as file:
    wav_file_read = file.read()

query_endpoint(wav_file_read, "audio/wav")

{
  "text": [
    " We are living in very exciting times with machine learning. The speed of ML model development will really actually increase. But you won't get to that end state that we want in the next coming years unless we actually make these models more accessible to everybody."
  ]
}
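
The custom handler returns `text` as a list, while the default handler earlier in this notebook returned a plain string. A minimal sketch of a parser that tolerates both shapes follows; `first_transcript` is a hypothetical helper, and the sample body is a shortened copy of the output above.

```python
import json


def first_transcript(response_body):
    """Return the first transcription string from a JSON response body."""
    predictions = json.loads(response_body)
    text = predictions["text"]
    # The custom handler returns a list; the default handler returns a plain string.
    if isinstance(text, list):
        text = text[0]
    return text.strip()


sample_body = b'{"text": [" We are living in very exciting times with machine learning."]}'
print(first_transcript(sample_body))
```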


### Speech to transcribed text in original language (French) and translated to English

In [93]:
input_audio_file_name = "sample_french1.wav"

with open(input_audio_file_name, "rb") as file:
    wav_file_read = file.read()

payload = {"audio_input": wav_file_read.hex()}
query_endpoint(json.dumps(payload).encode('utf-8'), "application/json")

payload = {"audio_input": wav_file_read.hex(),
           "language": "french",
           "task": "translate"}

query_endpoint(json.dumps(payload).encode('utf-8'), "application/json")

{
  "text": [
    " Bienvenue chez JPB Syst\u00e8mes, ici. C'est plus de 150 collaborateurs, c'est plus de 90% de chiffre d'affaires \u00e0 l'export et d'un produit, c'est une quinzaine de preuves que nous avons d\u00e9velopp\u00e9es."
  ]
}
{
  "text": [
    " Welcome to JPBSystem. We have more than 150 employees and 90% of sales. We have developed about 15 patents."
  ]
}
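
The JSON request format used above can be wrapped in a small client-side helper. This sketch mirrors two rules from the handler's validation: audio is sent as a hex string under `audio_input`, and `language` and `task` must be supplied together. `build_asr_payload` is a hypothetical helper, and the four stand-in bytes replace real WAV contents so the example runs without an audio file.

```python
import json


def build_asr_payload(audio_bytes, language=None, task=None, **generation_params):
    """Build the JSON request body: hex-encoded audio plus optional parameters."""
    if (language is None) != (task is None):
        raise ValueError("language and task must be provided together")
    payload = {"audio_input": audio_bytes.hex()}
    if language is not None:
        payload["language"] = language
        payload["task"] = task
    # Extra keyword arguments are passed through as generation parameters.
    payload.update(generation_params)
    return json.dumps(payload).encode("utf-8")


fake_audio = b"\x00\x01\x02\xff"  # stand-in bytes; real requests send the WAV file contents
body = build_asr_payload(fake_audio, language="french", task="translate", max_new_tokens=64)
decoded = json.loads(body)
print(sorted(decoded))  # ['audio_input', 'language', 'max_new_tokens', 'task']
```

The returned bytes can be passed to `query_endpoint` with the `application/json` content type.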


In [94]:
sess.delete_endpoint(endpoint_name)
sess.delete_endpoint_config(endpoint_name)
sess.delete_model(model_name)
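
If one of the delete calls fails, the remaining resources can be left behind and keep incurring cost. Below is a minimal sketch of a cleanup routine that attempts all three deletions and reports failures. It uses the boto3 SageMaker client methods `delete_endpoint`, `delete_endpoint_config`, and `delete_model`; the function `delete_endpoint_resources` and the `FakeSageMakerClient` stand-in are hypothetical and exist only so the example runs without AWS access.

```python
def delete_endpoint_resources(sm_client, endpoint_name, endpoint_config_name, model_name):
    """Delete endpoint, endpoint config, and model; return a list of (step, error) failures."""
    steps = [
        ("endpoint", lambda: sm_client.delete_endpoint(EndpointName=endpoint_name)),
        ("endpoint_config", lambda: sm_client.delete_endpoint_config(EndpointConfigName=endpoint_config_name)),
        ("model", lambda: sm_client.delete_model(ModelName=model_name)),
    ]
    failures = []
    for step, action in steps:
        try:
            action()
        except Exception as exc:  # in real use, narrow this to botocore.exceptions.ClientError
            failures.append((step, str(exc)))
    return failures


class FakeSageMakerClient:
    """Stand-in for boto3.client("sagemaker") so the sketch runs without AWS access."""

    def __init__(self):
        self.calls = []

    def delete_endpoint(self, EndpointName):
        self.calls.append(("delete_endpoint", EndpointName))

    def delete_endpoint_config(self, EndpointConfigName):
        # Simulate a failure on the middle step.
        raise RuntimeError("config not found")

    def delete_model(self, ModelName):
        self.calls.append(("delete_model", ModelName))


client = FakeSageMakerClient()
failures = delete_endpoint_resources(client, "demo", "demo", "demo")
print(failures)      # [('endpoint_config', 'config not found')]
print(client.calls)  # [('delete_endpoint', 'demo'), ('delete_model', 'demo')]
```

In the notebook, `sm_client` with `endpoint_name` and `model_name` would take the place of the fake client and the `demo` names.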