# [wip] Fine-tuning Whisper on Amazon SageMaker

In this notebook, we create a SageMaker Pipeline to fine-tune the Whisper model for a specific language. Whisper is a pre-trained Automatic Speech Recognition (ASR) model [[paper](https://cdn.openai.com/papers/whisper.pdf)].

This notebook can be run using the Data Science 3.0 image on an ml.t3.medium instance.
We use the [Common Voice](https://huggingface.co/datasets/mozilla-foundation/common_voice_13_0) dataset to fine-tune the Whisper model.

TODO:
  - Consider running on [fleurs dataset](https://huggingface.co/datasets/google/fleurs/viewer/id_id/train)

## Requirements:
  - A Hugging Face Hub account and an access token for downloading the datasets. Create a Hugging Face account, then follow this guide to create an access token: https://huggingface.co/docs/hub/security-tokens
  - Following best practice, we store the Hugging Face Hub access token in AWS Secrets Manager. The screenshots below show how to store the access token in AWS Secrets Manager. You also need to allow the Amazon SageMaker execution role to read the secret from AWS Secrets Manager.
  
![Secret Type](img/secret_manager1.png "Secret Type")

![Secret Name](img/secret_manager2.png "Secret Name")
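
To illustrate the second requirement, here is a minimal sketch of how a job could read the token back from AWS Secrets Manager. It is not the code in `src/entry_fetch.py`, which is not shown here. The function name `read_hf_token` and the JSON key `HF_TOKEN` are assumptions for illustration; use whichever key you chose when creating the secret. The role running the code needs the `secretsmanager:GetSecretValue` permission on the secret.

```python
import json


def read_hf_token(secrets_client, secret_id, json_key="HF_TOKEN"):
    """Return the Hugging Face token stored in an AWS Secrets Manager secret.

    `secrets_client` is expected to behave like boto3.client("secretsmanager").
    `json_key` is a hypothetical key name used only when the secret is a
    JSON key/value secret rather than plain text.
    """
    response = secrets_client.get_secret_value(SecretId=secret_id)
    secret = response["SecretString"]
    try:
        parsed = json.loads(secret)
    except json.JSONDecodeError:
        # The secret was stored as plain text: the string itself is the token.
        return secret.strip()
    if isinstance(parsed, dict):
        # The secret was stored as key/value pairs.
        return parsed[json_key]
    return secret.strip()


# Usage inside a SageMaker job (requires boto3 and AWS credentials):
#   import boto3
#   token = read_hf_token(boto3.client("secretsmanager"), "hfhub")
```

The secret name `hfhub` matches the `HFHUB_SECRET` value used later in this notebook.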

## Installs and Imports

In [None]:
%pip install -U sagemaker

In [None]:
import boto3
import os
import sagemaker

from datetime import datetime
from sagemaker.huggingface import HuggingFaceProcessor
from sagemaker.processing import ProcessingInput, ProcessingOutput
from sagemaker.workflow.pipeline_context import PipelineSession
from sagemaker.workflow.steps import ProcessingStep, CacheConfig

In [None]:
sess = boto3.Session()
sm = sess.client("sagemaker")

role = sagemaker.get_execution_role()
sagemaker_session = sagemaker.Session(boto_session=sess)
bucket = sagemaker_session.default_bucket()
region = boto3.Session().region_name
pipeline_session = PipelineSession()

## Application Parameters

Let's define several input parameters for our Whisper fine-tuning pipeline.

  - `DATASET` is the dataset name. Here we use the Common Voice dataset (`common-voice`). TODO: Try the `fleurs` dataset.
  - `HFHUB_SECRET` is the name of the AWS Secrets Manager secret that contains your Hugging Face access token.

In [None]:
DATASET = "common-voice"
HFHUB_SECRET = "hfhub"

LANGUAGE_NAMES = {
    "id": "indonesian",
    "vi": "vietnamese",
    "ta": "tamil",
    "th": "thai",
}

In [None]:
source_data_s3uri_prefix = f"s3://{bucket}/whisper-data/{DATASET}-"
pipeline_name = "whisper-fine-tuning"  # SageMaker Pipeline name

print(pipeline_name)

In [None]:
from sagemaker.workflow.parameters import ParameterInteger, ParameterString, ParameterFloat

language_code = ParameterString(name="LanguageCode")
language_name = ParameterString(name="LanguageName")

# AWS Secrets Manager Secret containing your Hugging Face Hub access token:
hf_secret = ParameterString(name="HFSecretName")

# processing step parameters
fetching_instance_type = ParameterString(name="FetchingInstanceType", default_value="ml.m5.xlarge")

# training step parameters
# training_instance_type = ParameterString(name="TrainingInstanceType", default_value="ml.g5.2xlarge")
training_max_steps = ParameterString(name="TrainingMaxSteps", default_value="200")

# evaluation step parameters
evaluation_instance_type = ParameterString(name="EvaluationInstanceType", default_value="ml.g4dn.xlarge")

# name of model package in the model registry
model_package_name = ParameterString(name="ModelPackageName", default_value="fine-tuned-whisper")

# model performance step parameters
wer_threshold = ParameterFloat(name="WERThreshold", default_value=35.0)

## Pre-processing Step: Fetching Dataset from CommonVoice

In [None]:
# Need to specify explicit image_uri otherwise HuggingFaceProcessor complains about not having a
# non-GPU image when used with m*/c*/etc non-accelerated instance types:
image_uri = sagemaker.image_uris.retrieve(
    "huggingface",
    region=region,  # reuse the region resolved from the boto3 session above
    version="4.28",
    base_framework_version="pytorch2.0",
    py_version="py310",
    image_scope="training",
)


In [None]:
from sagemaker.workflow.functions import Join

# Initialize the HuggingFaceProcessor
fetch_processor = HuggingFaceProcessor(
    sagemaker_session=pipeline_session,
    base_job_name="whisper-fetch-dataset",
    role=role, 
    image_uri=image_uri,
    instance_count=1,
    instance_type=fetching_instance_type,  #"ml.g4dn.xlarge", "ml.m5.2xlarge"
    py_version="py310",
    pytorch_version="2.0", 
    transformers_version="4.28", 
    volume_size_in_gb=40,
)

fetch_processor_args = fetch_processor.run(
    code="entry_fetch.py",
    source_dir="src",
    inputs=None,
    outputs=[
        ProcessingOutput(
            output_name="output",
            source="/opt/ml/processing/output",
            destination=Join(on="", 
                             values=[
                                 source_data_s3uri_prefix,
                                 language_code
                             ]),
            s3_upload_mode="Continuous",
        )    
    ],
    arguments=[
        # Name/ARN of your AWS Secrets Manager Secret containing your Hugging Face Hub access token:
        "--hf_secret_id", hf_secret,
        "--language_code", language_code,
        # Normalize audio sample rate at pre-processing time to save time in training job:
        "--norm_sample_rate", "16000",
        # More shards to help scale later pre-processing in the training job:
        "--save_num_shards", "48",
    ],
    wait=True,
    logs=True,
)


In [None]:
step_fetch = ProcessingStep(
    name="fetch_data",
    step_args=fetch_processor_args,
    cache_config=CacheConfig(enable_caching=True, expire_after="PT24H"),  # Caching step for 24 hours
)
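
The output `destination` above is built with `Join` because `language_code` is a pipeline parameter whose value is only known when an execution starts. As a rough illustration, the plain-Python sketch below mimics what that `Join` resolves to at run time. The helper name `resolve_fetch_destination` and the bucket name `my-bucket` are hypothetical.

```python
def resolve_fetch_destination(bucket, dataset, language_code):
    """Mimic Join(on="", values=[source_data_s3uri_prefix, language_code])."""
    source_data_s3uri_prefix = f"s3://{bucket}/whisper-data/{dataset}-"
    # Join with on="" simply concatenates the values in order.
    return "".join([source_data_s3uri_prefix, language_code])


print(resolve_fetch_destination("my-bucket", "common-voice", "id"))
# s3://my-bucket/whisper-data/common-voice-id
```

Each language therefore gets its own S3 prefix, so cached fetch results for one language do not overwrite another.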

## Training Step: Fine-tuning Whisper

In [None]:
from sagemaker.huggingface import HuggingFace as HuggingFaceEstimator
from sagemaker.workflow.steps import TrainingStep

from src.sagemaker_whisper import notebook as util

# configuration for running training on smdistributed data parallel
#distribution = {'smdistributed':{'dataparallel':{ 'enabled': True }}}

hyperparameters = {
    "model_name_or_path": "openai/whisper-small", #"openai/whisper-medium", # Not whisper-large-v2 yet: CUDA OOM
    "language": language_name,
    "per_device_train_batch_size": 32, #8, #16, #32,
    "per_device_eval_batch_size": 16, #8, #16,
    "gradient_accumulation_steps": 1, #4, #2, #1,
    "gradient_checkpointing": "true",  # Not needed if not VRAM-constrained?
    "fp16": True,
    "fp16_full_eval": True,
    
    "learning_rate": 5e-6,
    "lr_scheduler_type": "constant_with_warmup",
    "warmup_steps": 50,
    "max_steps": training_max_steps, #1600, #2400, #32, #256, #3000,
    "evaluation_strategy": "steps",
    "save_strategy": "steps",
    "save_steps": 800, #200,
    "eval_steps": 400, #200,
    "logging_steps": 25,

    "early_stopping_patience": 10,
    "load_best_model_at_end": True,
    "metric_for_best_model": "wer",
    "greater_is_better": False,

    # Early stopping implies checkpointing every evaluation, so limit the total checkpoints
    # kept to avoid filling up disk:
    "save_total_limit": 10,
    "seed": 42,

    "predict_with_generate": True,
    "generation_max_length": 255,
    # "push_to_hub": False,
}

metric_definitions = [
    {"Name": "epoch", "Regex": util.get_hf_metric_regex("epoch")},
    {"Name": "learning_rate", "Regex": util.get_hf_metric_regex("learning_rate")},
    {"Name": "train:loss", "Regex": util.get_hf_metric_regex("loss")},
    {"Name": "validation:loss", "Regex": util.get_hf_metric_regex("eval_loss")},
    {
        "Name": "validation:samples_per_sec",
        "Regex": util.get_hf_metric_regex("eval_samples_per_second"),
    },
    {"Name": "validation:wer", "Regex": util.get_hf_metric_regex("eval_wer")},
    {"Name": "validation:wer_ortho", "Regex": util.get_hf_metric_regex("eval_wer_ortho")},
]

estimator = HuggingFaceEstimator(
    sagemaker_session=pipeline_session,
    base_job_name="whisper-fine-tune",
    entry_point="entry_train.py",
    source_dir="src",
    output_path=f"s3://{bucket}/{pipeline_name}-trainjobs",
    role=role,
    py_version="py310",
    pytorch_version="2.0",
    transformers_version="4.28",
    instance_count=1,
    # E.g. "ml.g5.2xlarge" or "ml.p3.2xlarge". The `training_instance_type` pipeline parameter
    # is not used here because it conflicts with an attribute of the HuggingFace class.
    instance_type="ml.g5.2xlarge",
    # volume_size=30,
    #keep_alive_period_in_seconds=900,  # Enable warm pools for faster debugging (you need quota)

    environment={
        # "PIP_CACHE_DIR": "/opt/ml/sagemaker/warmpoolcache/pip",  # For warm pools
        # "TOKENIZERS_PARALLELISM": "false",
    },
    hyperparameters=hyperparameters,
    metric_definitions=metric_definitions,
    #distribution = distribution,
)

train_args = estimator.fit(
    inputs={
        "dataset": step_fetch.properties.ProcessingOutputConfig.Outputs["output"].S3Output.S3Uri,
    },
)

step_train_model = TrainingStep(
    name="fine_tune_whisper",
    step_args=train_args,
)
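
The `metric_definitions` above rely on `util.get_hf_metric_regex`, a helper from this repository's `src` folder that is not shown in the notebook. As a sketch only, and assuming the Hugging Face `Trainer` prints its logs as Python dict literals such as `{'loss': 1.2345, 'epoch': 0.12}`, a regex of the following shape could capture a metric value. The function `hf_metric_regex` is a hypothetical stand-in, not the repository's implementation.

```python
import re


def hf_metric_regex(metric_name):
    """Build a regex whose first group captures the value logged for `metric_name`.

    Assumes log lines look like: {'loss': 1.2345, 'learning_rate': 4.5e-06}
    """
    # The leading quote keeps 'loss' from matching inside 'eval_loss'.
    return rf"'{re.escape(metric_name)}': ([-+0-9.eE]+)"


train_line = "{'loss': 1.2345, 'learning_rate': 4.5e-06, 'epoch': 0.12}"
eval_line = "{'eval_loss': 0.41, 'eval_wer': 32.5, 'epoch': 1.0}"

print(re.search(hf_metric_regex("loss"), train_line).group(1))      # 1.2345
print(re.search(hf_metric_regex("eval_wer"), eval_line).group(1))   # 32.5
print(re.search(hf_metric_regex("loss"), eval_line))                # None
```

SageMaker applies each regex to the training job's log stream and publishes the captured group as the metric value.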

## Evaluation Step: Evaluate the fine-tuned model

In [None]:
# A ProcessingStep is used to evaluate the performance of the trained model.
# Based on the results of the evaluation, the model is registered in the Model Registry.
from sagemaker.workflow.properties import PropertyFile


eval_processor = HuggingFaceProcessor(
    sagemaker_session=pipeline_session,
    base_job_name="whisper-evaluation",
    role=role, 
    image_uri=image_uri,
    instance_count=1,
    instance_type=evaluation_instance_type,
    py_version="py310",
    pytorch_version="2.0", 
    transformers_version="4.28", 
    volume_size_in_gb=40,
)


evaluation_report = PropertyFile(
    name="evaluation_report",
    output_name="evaluation",
    path="evaluation.json",
)


eval_args = eval_processor.run(
    code="entry_evaluation.py",
    source_dir="src",
    inputs=[
        ProcessingInput(
            source=step_train_model.properties.ModelArtifacts.S3ModelArtifacts,
            destination="/opt/ml/processing/model",
        ),
        ProcessingInput(
            source=step_fetch.properties.ProcessingOutputConfig.Outputs["output"].S3Output.S3Uri,
            destination="/opt/ml/processing/output",
        ),
    ],
    outputs=[
        ProcessingOutput(
            output_name="evaluation",
            source="/opt/ml/processing/evaluation",
            # destination=f"s3://{bucket}/{s3_prefix}/evaluation_report",
        ),
    ],
    arguments=[
        "--language_name", language_name,
    ],
)

step_evaluate_model = ProcessingStep(
    name="evaluate_ft_whisper_model",
    step_args=eval_args,
    property_files=[evaluation_report],
)
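
The evaluation script `src/entry_evaluation.py` is not shown in this notebook. The condition step further down reads `metrics.wer.value` from `evaluation.json`, so the report needs at least that nested structure. The sketch below shows one way to compute a word error rate (WER) for a single reference/hypothesis pair and build such a report. It assumes WER is expressed in percent, which the default `WERThreshold` of 35.0 suggests, and the function names are illustrative.

```python
import json


def word_error_rate(reference, hypothesis):
    """WER in percent: word-level edit distance divided by reference length."""
    ref, hyp = reference.split(), hypothesis.split()
    if not ref:
        raise ValueError("reference must contain at least one word")
    # prev[j] holds the edit distance between the reference words processed
    # so far and the first j hypothesis words.
    prev = list(range(len(hyp) + 1))
    for i, ref_word in enumerate(ref, start=1):
        curr = [i]
        for j, hyp_word in enumerate(hyp, start=1):
            cost = 0 if ref_word == hyp_word else 1
            curr.append(min(
                prev[j] + 1,         # deletion of a reference word
                curr[j - 1] + 1,     # insertion of a hypothesis word
                prev[j - 1] + cost,  # substitution or exact match
            ))
        prev = curr
    return 100.0 * prev[-1] / len(ref)


def build_report(wer_percent):
    """Return the nested structure addressed by json_path 'metrics.wer.value'."""
    return {"metrics": {"wer": {"value": wer_percent}}}


report = build_report(word_error_rate("a b c d", "a x c"))
print(json.dumps(report))  # {"metrics": {"wer": {"value": 50.0}}}
```

A real evaluation would aggregate errors over the whole test split rather than a single sentence, and would write the JSON to `/opt/ml/processing/evaluation/evaluation.json` so that the `evaluation` output picks it up.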

## Model Step: Create Model to be registered to Model Registry

In [None]:
from sagemaker.model_metrics import MetricsSource, ModelMetrics
from sagemaker.workflow.model_step import ModelStep

from sagemaker.huggingface.model import HuggingFaceModel


ft_whisper_model = HuggingFaceModel(
    sagemaker_session=pipeline_session,
    env={
        "HF_TASK": "automatic-speech-recognition",
    },
    model_data=step_train_model.properties.ModelArtifacts.S3ModelArtifacts,
    role=role,
    py_version="py310",
    pytorch_version="2.0",
    transformers_version="4.28",
)

s3_eval_path = step_evaluate_model.arguments["ProcessingOutputConfig"]["Outputs"][0]["S3Output"]["S3Uri"]
evaluation_s3_uri = f"{s3_eval_path}/evaluation.json"

model_metrics = ModelMetrics(
    model_statistics=MetricsSource(
        s3_uri=evaluation_s3_uri,
        content_type="application/json",
    )
)

register_model_step_args = ft_whisper_model.register(
    content_types=["application/json"],
    response_types=["application/json"],
    inference_instances=["ml.g4dn.xlarge"],
    model_package_group_name=model_package_name, 
    model_metrics=model_metrics,
    approval_status="Approved"
)

step_create_model = ModelStep(
    name="register_ft_whisper_model",
    step_args=register_model_step_args,
)

## Condition Step: Check WER and conditionally register the packaged model to model registry.

In [None]:
from sagemaker.workflow.conditions import ConditionLessThanOrEqualTo
from sagemaker.workflow.condition_step import ConditionStep
from sagemaker.workflow.functions import JsonGet


cond_lte = ConditionLessThanOrEqualTo(
    left=JsonGet(
        step_name=step_evaluate_model.name,
        property_file=evaluation_report,
        json_path="metrics.wer.value",
    ),
    right=wer_threshold,
)

step_condition = ConditionStep(
    name="check_whisper_evaluation",
    conditions=[cond_lte],
    if_steps=[step_create_model],
    else_steps=[],
)
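
To make the gating logic concrete, the following plain-Python sketch emulates locally what the condition step checks: it looks up `metrics.wer.value` in the evaluation report and compares it with the threshold. The helpers `json_get` and `passes_wer_gate` are hypothetical, and this simplified lookup only handles plain dotted keys, unlike the SDK's `JsonGet`.

```python
def json_get(document, json_path):
    """Follow a dotted path such as 'metrics.wer.value' through nested dicts."""
    value = document
    for key in json_path.split("."):
        value = value[key]
    return value


def passes_wer_gate(report, wer_threshold):
    """Mirror ConditionLessThanOrEqualTo(left=<WER from report>, right=threshold)."""
    return json_get(report, "metrics.wer.value") <= wer_threshold


sample_report = {"metrics": {"wer": {"value": 28.4}}}  # made-up value
print(passes_wer_gate(sample_report, 30.0))  # True  -> model would be registered
print(passes_wer_gate(sample_report, 25.0))  # False -> else_steps (none) would run
```

Because `else_steps` is empty, an execution whose WER exceeds the threshold simply finishes without registering a model.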

## Putting all steps together in one SageMaker Pipeline

In [None]:
from sagemaker.workflow.pipeline import Pipeline

# Create a SageMaker Pipeline.
# Each parameter for the pipeline must be set as a parameter explicitly when the pipeline is created.
# Also pass in each of the steps created above.
# Note that the order of execution is determined from each step's dependencies on other steps,
# not on the order they are passed in below.

pipeline = Pipeline(
    name=pipeline_name,
    parameters=[
        language_code,
        language_name,
        hf_secret,
        fetching_instance_type,
        training_max_steps,
        evaluation_instance_type,
        model_package_name,
        wer_threshold,
    ],
    steps=[
        step_fetch, 
        step_train_model,
        step_evaluate_model, 
        step_condition,
    ],
    sagemaker_session=sagemaker_session,
)

In [None]:
import json

definition = json.loads(pipeline.definition())
definition
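
The full definition is verbose, so a compact overview can help when checking that every step made it into the pipeline. The sketch below walks the `Steps` list of a definition dict and also descends into the branches of condition steps. The helper name `summarize_steps` is hypothetical, and the sample dict is a hand-written, heavily abbreviated stand-in for the real output of `json.loads(pipeline.definition())`.

```python
def summarize_steps(definition):
    """Return (name, type) pairs for every step, including condition branches."""
    found = []

    def walk(steps):
        for step in steps:
            found.append((step["Name"], step["Type"]))
            if step["Type"] == "Condition":
                arguments = step.get("Arguments", {})
                walk(arguments.get("IfSteps", []))
                walk(arguments.get("ElseSteps", []))

    walk(definition.get("Steps", []))
    return found


# Abbreviated, hand-written sample; real definitions contain many more fields.
sample_definition = {
    "Steps": [
        {"Name": "fetch_data", "Type": "Processing"},
        {"Name": "fine_tune_whisper", "Type": "Training"},
        {
            "Name": "check_whisper_evaluation",
            "Type": "Condition",
            "Arguments": {
                "IfSteps": [{"Name": "register-step", "Type": "RegisterModel"}],
                "ElseSteps": [],
            },
        },
    ]
}

for name, step_type in summarize_steps(sample_definition):
    print(f"{step_type:>14}  {name}")
```

In the notebook, calling `summarize_steps(definition)` on the parsed definition from the cell above would give the same kind of listing.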

In [None]:
pipeline.upsert(role_arn=role)

## Running the pipeline

After we create the pipeline, we can start executions with different pipeline parameters, for example for the `id` (Indonesian) or `vi` (Vietnamese) language.

Note: 
  - With `TrainingMaxSteps=800` on the Indonesian dataset, the training step takes about 1.5 hours.
  - With `TrainingMaxSteps=32` on the Vietnamese dataset, the training step takes about 16 minutes.
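
To relate `TrainingMaxSteps` to epochs, the sketch below multiplies the number of steps by the effective batch size. The defaults mirror the hyperparameters in this notebook (`per_device_train_batch_size=32`, `gradient_accumulation_steps=1`) on a single-GPU `ml.g5.2xlarge`. The dataset size of 5,120 samples is not a measured value: it is only what the "1600 steps for ~10 epochs" comment in the cell below implies under these settings.

```python
def samples_seen(max_steps, per_device_batch_size=32, grad_accum_steps=1, num_devices=1):
    """Number of training samples consumed after `max_steps` optimizer steps."""
    return max_steps * per_device_batch_size * grad_accum_steps * num_devices


def approx_epochs(max_steps, dataset_size, **batch_kwargs):
    """Approximate number of passes over a dataset of `dataset_size` samples."""
    return samples_seen(max_steps, **batch_kwargs) / dataset_size


print(samples_seen(1600))         # 51200
print(approx_epochs(800, 5120))   # 5.0
```

If you change the batch size or gradient accumulation in `hyperparameters`, pass the new values here to keep the estimate consistent.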
  

In [None]:
date_time_str = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
execution = pipeline.start(
    execution_display_name=f"fine-tune-id-{date_time_str}",
    execution_description="fine-tune on id language",
    parameters=dict(
        LanguageCode="id",
        LanguageName=LANGUAGE_NAMES["id"],
        HFSecretName=HFHUB_SECRET,
        FetchingInstanceType="ml.m5.xlarge",
        #TrainingInstanceType="ml.g5.2xlarge",
        TrainingMaxSteps=800, #32, #800, #1600, #32, # Use 1600 to fine-tune id language for ~10 epochs
        EvaluationInstanceType="ml.g4dn.xlarge",
        ModelPackageName="whisper-fine-tuned-id",
        WERThreshold=30.0,  # float, matching the ParameterFloat definition
    )
)

In [None]:
# Example pipeline execution on Vietnamese language

date_time_str = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
execution = pipeline.start(
    execution_display_name=f"fine-tune-vi-{date_time_str}",
    execution_description="fine-tune on vi language",
    parameters=dict(
        LanguageCode="vi",
        LanguageName=LANGUAGE_NAMES["vi"],
        HFSecretName=HFHUB_SECRET,
        FetchingInstanceType="ml.m5.xlarge",
        #TrainingInstanceType="ml.g5.2xlarge",
        TrainingMaxSteps=32,
        EvaluationInstanceType="ml.g4dn.xlarge",
        ModelPackageName="whisper-fine-tuned-vi",
        WERThreshold=30.0,  # float, matching the ParameterFloat definition
    )
)
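
The two cells above differ only in the language code and the number of steps. One possible way to avoid repeating the dictionary is a small helper like the sketch below. The name `build_parameters` is hypothetical, and the instance-type parameters are omitted on the assumption that the pipeline defaults defined earlier are acceptable.

```python
LANGUAGE_NAMES = {
    "id": "indonesian",
    "vi": "vietnamese",
    "ta": "tamil",
    "th": "thai",
}


def build_parameters(language_code, max_steps, wer_threshold=30.0, hf_secret="hfhub"):
    """Build the `parameters` dict for pipeline.start() for one language."""
    if language_code not in LANGUAGE_NAMES:
        raise KeyError(f"Unsupported language code: {language_code}")
    return dict(
        LanguageCode=language_code,
        LanguageName=LANGUAGE_NAMES[language_code],
        HFSecretName=hf_secret,
        TrainingMaxSteps=max_steps,
        ModelPackageName=f"whisper-fine-tuned-{language_code}",
        WERThreshold=wer_threshold,
    )


print(build_parameters("ta", 32)["LanguageName"])  # tamil

# In the notebook this could be used as:
#   execution = pipeline.start(parameters=build_parameters("ta", 32))
```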

In [None]:
execution.wait()
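
Once `execution.wait()` returns (or raises because a step failed), `execution.list_steps()` returns one dictionary per step with fields such as `StepName`, `StepStatus` and, for failed steps, `FailureReason`. The sketch below formats such a list into readable lines. The helper name `summarize_execution_steps` is hypothetical, and the sample data is made up to show the shape of the input.

```python
def summarize_execution_steps(steps):
    """Format the list returned by execution.list_steps() as 'name: status' lines."""
    lines = []
    for step in steps:
        line = f'{step["StepName"]}: {step["StepStatus"]}'
        reason = step.get("FailureReason")
        if reason:
            line += f" ({reason})"
        lines.append(line)
    return lines


# Made-up sample with the same keys as the real response.
sample_steps = [
    {"StepName": "fine_tune_whisper", "StepStatus": "Failed", "FailureReason": "example reason"},
    {"StepName": "fetch_data", "StepStatus": "Succeeded"},
]

for line in summarize_execution_steps(sample_steps):
    print(line)
# fine_tune_whisper: Failed (example reason)
# fetch_data: Succeeded
```

In the notebook, `summarize_execution_steps(execution.list_steps())` would give the same overview for the real execution.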

## [Not working yet, ignore this section] Deployment

Once a fine-tuned model is in the Model Registry, we can deploy it as a SageMaker endpoint.

In [None]:
from botocore.exceptions import ClientError


def get_approved_package(model_package_group_name):
    """Gets the latest approved model package for a model package group.

    Args:
        model_package_group_name: The model package group name.

    Returns:
        The summary dict of the latest approved model package (includes `ModelPackageArn`).
    """
    try:
        # List approved packages, newest first, so that index 0 is the latest one
        response = sm.list_model_packages(
            ModelPackageGroupName=model_package_group_name,
            ModelApprovalStatus="Approved",
            SortBy="CreationTime",
            SortOrder="Descending",
            MaxResults=100,
        )
        approved_packages = response["ModelPackageSummaryList"]

        # Keep paging with the continuation token while no package has been returned
        while len(approved_packages) == 0 and "NextToken" in response:
            response = sm.list_model_packages(
                ModelPackageGroupName=model_package_group_name,
                ModelApprovalStatus="Approved",
                SortBy="CreationTime",
                SortOrder="Descending",
                MaxResults=100,
                NextToken=response["NextToken"],
            )
            approved_packages.extend(response["ModelPackageSummaryList"])

        # Raise an error if no approved package was found
        if len(approved_packages) == 0:
            error_message = (
                f"No approved ModelPackage found for ModelPackageGroup: {model_package_group_name}"
            )
            raise Exception(error_message)

        print(approved_packages)
        # Return the summary of the latest approved model package
        return approved_packages[0]
    except ClientError as e:
        error_message = e.response["Error"]["Message"]
        raise Exception(error_message) from e

In [None]:
pck = get_approved_package(
    "whisper-fine-tuned-id"
    #model_package_group_prefix + "id" #model_package_group_name
)  
model_description = sm.describe_model_package(ModelPackageName=pck["ModelPackageArn"])

model_description

In [None]:
from sagemaker import ModelPackage
from sagemaker.serializers import DataSerializer

date_time_str = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")

audio_serializer = DataSerializer(content_type='audio/x-audio') # using x-audio to support multiple audio formats

model = ModelPackage(
    role=role, 
    model_package_arn=model_description["ModelPackageArn"], 
    sagemaker_session=sagemaker_session
)

endpoint_name = "ft-whisper-model-id-" + date_time_str
print(f"EndpointName={endpoint_name}")
model.deploy(initial_instance_count=1, 
             instance_type="ml.g4dn.xlarge", 
             serializer=audio_serializer, # serializer for our audio data
             endpoint_name=endpoint_name)

In [None]:
from sagemaker.predictor import Predictor

sample_test_audio = 'sample_data/common_voice_id_26208380.mp3'

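# Pass the audio serializer so that predict() reads the file and sends its bytes as audio/x-audio
predictor = Predictor(endpoint_name=endpoint_name, serializer=audio_serializer)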


In [None]:
predictor.content_type

In [None]:
res = predictor.predict(data=sample_test_audio)
print(res)

## Cleanup (removing unused resources)

If you no longer need the model packages in the Model Registry, the endpoint, or the pipeline after you finish with this notebook, delete them to avoid unintended charges.
The cells below show example cleanup code; adjust the variable names to match your own.

### Delete Model Packages

In [None]:
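sm = boto3.client("sagemaker")

# Model package group to delete, e.g. the `ModelPackageName` used for the Indonesian run above
model_package_group_name = "whisper-fine-tuned-id"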

for d in sm.list_model_packages(ModelPackageGroupName=model_package_group_name)[
    "ModelPackageSummaryList"
]:
    print(d["ModelPackageArn"])
    sm.delete_model_package(ModelPackageName=d["ModelPackageArn"])

sm.delete_model_package_group(ModelPackageGroupName=model_package_group_name)

### Delete Endpoint

In [None]:
predictor.delete_endpoint()

### Delete Pipeline

In [None]:
pipeline.delete()