### 1. 安装依赖 & 变量设置

In [None]:
!pip install huggingface-hub -Uqq
!pip install --upgrade sagemaker -Uqq
!pip install packaging==21.3

In [None]:
import sagemaker
from sagemaker import image_uris
import boto3
import os
import time
import json

role = sagemaker.get_execution_role()  # execution role for the endpoint
sess = sagemaker.session.Session()  # sagemaker session for interacting with different AWS APIs
bucket = sess.default_bucket()  # bucket to house artifacts

# Use the session's public region attribute rather than the private `_region_name`.
region = sess.boto_region_name
account_id = sess.account_id()

# Low-level boto3 clients for S3, the SageMaker control plane and the runtime API.
s3_client = boto3.client("s3")
sm_client = boto3.client("sagemaker")
smr_client = boto3.client("sagemaker-runtime")

In [10]:
from pathlib import Path

# S3 key prefix used for the artifacts uploaded by this notebook.
s3_code_prefix = "aigc-asr-models"

# Local working directory; creating it is a no-op when it already exists.
local_model_path = Path("./funasr_model")
local_model_path.mkdir(parents=False, exist_ok=True)

### 2. 模型部署准备（entrypoint脚本，容器镜像，服务配置）

In [11]:
# Pick the Hugging Face PyTorch inference image for the current region.
if region in ['cn-north-1', 'cn-northwest-1']:
    # The China regions need this image_uri instead of the default one.
    inference_image_uri = (
        f"727897471807.dkr.ecr.{region}.amazonaws.com.cn/huggingface-pytorch-inference:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04"
    )
else:
    inference_image_uri = (
        f"763104351884.dkr.ecr.{region}.amazonaws.com/huggingface-pytorch-inference:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04"
    )

print(f"Image going to be used is ---- > {inference_image_uri}")

Image going to be used is ---- > 763104351884.dkr.ecr.us-east-1.amazonaws.com/huggingface-pytorch-inference:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04


In [12]:
# Create the ./code directory that receives inference.py and requirements.txt;
# nothing happens when it already exists.
os.makedirs("code", exist_ok=True)

In [22]:
%%writefile ./code/inference.py
import os
import io
import sys
import time
import json
import logging
import torch
import boto3
import ffmpeg
import torchaudio
import requests

from urllib.parse import urlparse, unquote
from funasr import AutoModel
from funasr.utils.postprocess_utils import rich_transcription_postprocess

# Run on the first GPU when CUDA is available, otherwise fall back to the CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
# Taken from the `chunk_length_s` environment variable; default to 30 when it is
# unset so that importing this module does not fail with int(None).
chunk_length_s = int(os.environ.get('chunk_length_s', '30'))
# Shared S3 client used by download_file_from_s3.
s3_client = boto3.client('s3')

def download_file_from_s3(bucket_name, s3_file_key, local_dir ='/tmp'):
    """Download an S3 object into ``local_dir`` using the module-level client.

    The local file keeps the last path component of ``s3_file_key`` as its name.

    :param bucket_name: name of the bucket holding the object
    :param s3_file_key: key of the object inside the bucket
    :param local_dir: directory the file is written to
    :return: path of the downloaded file, or None when anything goes wrong
    """
    try:
        file_name = s3_file_key.rsplit('/', 1)[-1]
        destination = f"{local_dir}/{file_name}"
        s3_client.download_file(bucket_name, s3_file_key, destination)
        print(f"文件成功下载到: {destination}")
    except Exception as err:
        # Best effort: report the failure and let the caller handle None.
        print(f"下载失败: {err}")
        return None
    else:
        return destination

def download_file_from_s3_url(url, local_dir ='/tmp', timeout=60):
    """Download the file behind a presigned URL into ``local_dir``.

    The local file name is the percent-decoded last component of the URL path.

    :param url: presigned URL to fetch with an HTTP GET
    :param local_dir: directory the file is written to
    :param timeout: seconds to wait for the server before giving up
    :return: path of the downloaded file, or None when the download fails
    """
    # Derive the file name from the URL path (the query string is ignored).
    filename = os.path.basename(unquote(urlparse(url).path))
    if not filename:
        print("Failed to download file. URL path contains no file name.")
        return None
    local_path = f"{local_dir}/{filename}"

    try:
        # stream=True with iter_content keeps the whole file out of memory; the
        # timeout prevents a stalled connection from hanging the request.
        with requests.get(url, stream=True, timeout=timeout) as response:
            if response.status_code != 200:
                print(f"Failed to download file. Status code: {response.status_code}")
                return None
            with open(local_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=1024 * 1024):
                    f.write(chunk)
    except requests.RequestException as e:
        # Same contract as download_file_from_s3: report and return None.
        print(f"Failed to download file: {e}")
        # Do not leave a partially written file behind.
        if os.path.exists(local_path):
            os.remove(local_path)
        return None

    print(f"File successfully downloaded to {local_path}")
    return local_path

def model_fn(model_dir):
    """Build the FunASR model served by this endpoint.

    :param model_dir: directory passed in by the caller; it is only logged,
        the model itself is loaded by its hub id
    :return: the constructed ``AutoModel`` instance
    """
    print(f"input_model_dir: {model_dir}")
    # Loaded by hub id rather than from `model_dir`.
    model_id = "FunAudioLLM/SenseVoiceSmall"
    # NOTE(review): no `vad_model` is passed, so `vad_kwargs` may have no effect;
    # also confirm the unit funasr expects for max_single_segment_time.
    model = AutoModel(
        model=model_id,
        trust_remote_code=True,
        vad_kwargs={"max_single_segment_time": chunk_length_s},
        # Use the module-level device so that a CPU-only host falls back to "cpu".
        device=device,
        hub="hf", # hub="ms" for China region
    )
    return model

def transform_fn(model, request_body, request_content_type, response_content_type="application/json"):
    """Transcribe the audio file referenced by a JSON request.

    The request is a JSON object containing either ``audio_s3_presign_uri`` or
    both ``bucket_name`` and ``s3_key``. The presigned URL wins when both forms
    are present.

    :param model: object returned by ``model_fn``
    :param request_body: JSON document (str or bytes) describing the audio
    :param request_content_type: content type of the request (not used)
    :param response_content_type: requested response type (not used)
    :return: JSON string, ``{"text": ...}`` on success or ``{"error": ...}``
        when no usable input was given or the download failed
    """
    request = json.loads(request_body)
    if not isinstance(request, dict):
        # Valid JSON that is not an object carries none of the expected keys.
        return json.dumps({"error" : "No valid input passed."})

    audio_s3_presign_uri = request.get("audio_s3_presign_uri")
    bucket_name = request.get("bucket_name")
    s3_key = request.get("s3_key")

    if audio_s3_presign_uri:
        local_file_path = download_file_from_s3_url(audio_s3_presign_uri)
    elif bucket_name and s3_key:
        local_file_path = download_file_from_s3(bucket_name, s3_key)
    else:
        # Errors are serialized like the success path, so every return is a JSON string.
        return json.dumps({"error" : "No valid input passed."})

    if not local_file_path:
        return json.dumps({"error" : "No Audio downloaded."})

    try:
        res = model.generate(
            input=local_file_path,
            cache={},
            language="auto",  # "zh", "en", "yue", "ja", "ko", "nospeech"
            use_itn=True,
            batch_size_s=60,
            merge_vad=True,
            merge_length_s=15,
        )
    finally:
        # Remove the downloaded audio even when inference raises.
        if os.path.exists(local_file_path):
            os.remove(local_file_path)

    text = rich_transcription_postprocess(res[0]["text"])

    return json.dumps({"text" : text})

Overwriting ./code/inference.py


#### 执行下面这个cell，在requirements.txt中添加国内的pip镜像

In [23]:
%%writefile ./code/requirements.txt
# Install from the Tsinghua PyPI mirror instead of the default index.
-i https://pypi.tuna.tsinghua.edu.cn/simple
# Packages imported by code/inference.py.
torch>=1.13
torchaudio
ffmpeg-python
funasr

Overwriting ./code/requirements.txt


In [24]:
!rm funasr_model.tar.gz
!touch dummy
!tar czvf model.tar.gz dummy

rm: cannot remove ‘funasr_model.tar.gz’: No such file or directory
dummy


In [25]:
# Upload the tarball under the notebook's prefix and keep the returned S3 URI.
model_uri = sess.upload_data(
    path="model.tar.gz",
    bucket=bucket,
    key_prefix=s3_code_prefix,
)
print(f"S3 Code or Model tar ball uploaded to --- > {model_uri}")

S3 Code or Model tar ball uploaded to --- > s3://sagemaker-us-east-1-687752207838/aigc-asr-models/model.tar.gz


### 3. 创建模型 & 创建endpoint

In [26]:
from sagemaker.huggingface.model import HuggingFaceModel

model_name = "FunASR-SenseVoiceSmall"

# Environment for the inference container; inference.py reads `chunk_length_s`.
container_env = {"chunk_length_s" : "30"}

# Model definition: uploaded tarball + custom image + the scripts in ./code.
funasr_hf_model = HuggingFaceModel(
    name=model_name,
    role=role,
    image_uri=inference_image_uri,
    model_data=model_uri,
    source_dir='./code',
    entry_point="inference.py",
    env=container_env,
)

In [None]:
from sagemaker.serializers import JSONSerializer
from sagemaker.deserializers import JSONDeserializer

endpoint_name = f'{account_id}-funasr-hf-real-time-endpoint'

# Requests are sent as JSON and responses are parsed as JSON.
deploy_options = dict(
    endpoint_name=endpoint_name,
    instance_type="ml.g4dn.xlarge",
    initial_instance_count=1,
    serializer=JSONSerializer(),
    deserializer=JSONDeserializer(),
)

real_time_predictor = funasr_hf_model.deploy(**deploy_options)

-----

### 4. 模型测试

##### 4.1 下载一个音频文件，并上传到S3

In [None]:
# Download one audio sample (needs the `soundfile` and `datasets` packages).
import json
import soundfile as sf
from datasets import load_dataset

# Local file the sample is written to; `audio_path` is reused when uploading.
audio_path = "sample_audio.wav"
output_path = audio_path

# Take the first example of the streamed training split.
dataset = load_dataset('MLCommons/peoples_speech', split='train', streaming = True)
sample = next(iter(dataset))
audio_data = sample['audio']['array']
sf.write(output_path, audio_data, sample['audio']['sampling_rate'])

print(f"Audio sample saved to '{output_path}'.")

In [75]:
!aws s3 cp {audio_path} s3://sagemaker-us-east-1-687752207838/aigc-asr-models/

upload: ./sample_audio.wav to s3://sagemaker-us-east-1-687752207838/aigc-asr-models/sample_audio.wav


##### 4.2 通过bucket and s3_key进行测试

In [28]:
# Reference the uploaded sample through the notebook's own bucket and prefix
# rather than a hard-coded, account-specific bucket name.
jsondata = { "bucket_name" : bucket, "s3_key" : f"{s3_code_prefix}/{audio_path}"}
real_time_predictor.predict(data=jsondata)

{'text': "I wanted to share a few things, but I'm going to not share as much as I wanted to share because we are starting late, I'd like to get this thing going so we all get home at a decent hour this election is very important too,"}

##### 4.3 生成S3 Presigned URL，并发送请求

In [121]:
def generate_presigned_url(s3_uri, expiration=3600):
    """
    Generate a presigned URL for the S3 object

    :param s3_uri: The S3 URI of the object, in the form s3://bucket/key
    :param expiration: Time in seconds for the presigned URL to remain valid
    :return: Presigned URL as string. If error, returns None.
    """
    # Imported here because the notebook never imports urlparse itself; only
    # the generated code/inference.py does.
    from urllib.parse import urlparse

    # Parse the S3 URI
    parsed_uri = urlparse(s3_uri)
    bucket_name = parsed_uri.netloc
    object_key = parsed_uri.path.lstrip('/')

    # Reject anything that is not a complete s3://bucket/key URI up front.
    if parsed_uri.scheme != 's3' or not bucket_name or not object_key:
        print(f"Error generating presigned URL: invalid S3 URI {s3_uri!r}")
        return None

    # Generate the presigned URL
    try:
        s3_client = boto3.client('s3')
        response = s3_client.generate_presigned_url('get_object',
                                                    Params={'Bucket': bucket_name, 'Key': object_key},
                                                    ExpiresIn=expiration)
    except Exception as e:
        print(f"Error generating presigned URL: {e}")
        return None

    return response

In [144]:
# Build the URI from the notebook's own bucket/prefix instead of a hard-coded one.
audio_s3_uri = f"s3://{bucket}/{s3_code_prefix}/{audio_path}"
audio_s3_presign_uri = generate_presigned_url(audio_s3_uri)
# NOTE: the displayed URL embeds temporary credentials; clear this cell's output
# before sharing the notebook.
audio_s3_presign_uri

'https://sagemaker-us-east-1-687752207838.s3.amazonaws.com/aigc-asr-models/sample_audio.wav?AWSAccessKeyId=<REDACTED>&Signature=<REDACTED>&x-amz-security-token=<REDACTED>'

In [145]:
# Ask the endpoint to fetch the audio through the presigned URL.
jsondata = dict(audio_s3_presign_uri=audio_s3_presign_uri)
real_time_predictor.predict(data=jsondata)

['{"text": "I wanted to share a few things, but I\'m going to not share as much as I wanted to share because we are starting late, I\'d like to get this thing going so we all get home at a decent hour this election is very important too,"}',
 'application/json']

### 5. 清理模型端点

In [146]:
# Delete the model first: the predictor resolves the model name through the
# endpoint and its configuration, which delete_endpoint() removes.
real_time_predictor.delete_model()
real_time_predictor.delete_endpoint()