### 1. 安装依赖 & 变量设置

In [None]:
# Image: PyTorch 2.0.0 Python 3.10 CPU Optimized
# Kernel: Python3

In [None]:
# Quietly install/upgrade the notebook-side packages. `soundfile` is used in
# section 4 to write the sample WAV file; it is fetched from the Tsinghua PyPI mirror.
!pip install huggingface-hub -Uqq
!pip install -Uqq sagemaker 
!pip install -Uqq soundfile -i https://pypi.tuna.tsinghua.edu.cn/simple

In [None]:
# `datasets` provides load_dataset, used in section 4 to fetch a sample audio clip.
# NOTE(review): this notebook imports urlparse from the standard library
# (urllib.parse) — confirm the separate `urlparse` pip package is really needed.
!pip install -Uqq datasets urlparse -i https://pypi.tuna.tsinghua.edu.cn/simple

In [None]:
import sagemaker
from sagemaker import image_uris
import boto3
import os
import time
import json

# Resolve the IAM role, SageMaker session and default bucket used by every
# later cell in this notebook.
role = sagemaker.get_execution_role()  # execution role for the endpoint
sess = sagemaker.session.Session()  # sagemaker session for interacting with different AWS APIs
bucket = sess.default_bucket()  # bucket to house artifacts

# Read the region through the public Session attribute instead of the
# private `_region_name`, which is an implementation detail of the SDK.
region = sess.boto_region_name
account_id = sess.account_id()

# Low-level boto3 clients for S3, the SageMaker control plane and the
# SageMaker runtime (endpoint invocation).
s3_client = boto3.client("s3")
sm_client = boto3.client("sagemaker")
smr_client = boto3.client("sagemaker-runtime")

In [None]:
from pathlib import Path

# Key prefix under the default bucket where the model archive and the test
# audio are uploaded.
s3_code_prefix = "aigc-asr-models"

# Local staging directory for the model files; an existing directory is reused.
local_model_path = Path("./funasr_model")
local_model_path.mkdir(parents=False, exist_ok=True)

### 2. 模型部署准备（entrypoint脚本，容器镜像，服务配置）

In [None]:
# ECR URI of the Hugging Face PyTorch GPU inference image, built for the
# current region on the `amazonaws.com.cn` domain.
inference_image_uri = (
    f"727897471807.dkr.ecr.{region}.amazonaws.com.cn/"
    "huggingface-pytorch-inference:"
    "2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04"
)

print("Image going to be used is ---- > " + inference_image_uri)

In [None]:
# Create the directory that receives inference.py and requirements.txt below.
!mkdir -p code

In [None]:
%%writefile ./code/inference.py
import os
import io
import sys
import time
import json
import logging
import torch
import boto3
import ffmpeg
import torchaudio
import requests

from urllib.parse import urlparse, unquote
from funasr import AutoModel
from funasr.utils.postprocess_utils import rich_transcription_postprocess

# Run on the first GPU when CUDA is available, otherwise fall back to the CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
# Chunk length in seconds, taken from the endpoint environment. Default to 30
# so importing this module does not fail with TypeError when the variable is unset.
chunk_length_s = int(os.environ.get('chunk_length_s', '30'))

def download_file_from_s3_url(url, local_dir ='/tmp', timeout=60):
    """Download the object behind a (pre-signed) URL into ``local_dir``.

    The local file name is the last component of the URL path (percent-decoded).

    :param url: URL to fetch with an HTTP GET.
    :param local_dir: Directory the file is written to.
    :param timeout: Seconds to wait for the server before giving up, so a
        stalled connection cannot block the worker indefinitely.
    :return: Path of the downloaded file, or None when the request fails,
        the status code is not 200, or the URL carries no file name.
    """
    local_path = None
    try:
        # Stream the body so large audio files are not held in memory at once.
        with requests.get(url, stream=True, timeout=timeout) as response:
            if response.status_code != 200:
                print(f"Failed to download file. Status code: {response.status_code}")
                return None

            # Derive the file name from the URL path.
            filename = os.path.basename(unquote(urlparse(url).path))
            if not filename:
                print("Failed to download file. No file name found in URL path.")
                return None

            local_path = os.path.join(local_dir, filename)
            # Write the content to the local file chunk by chunk.
            with open(local_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=1024 * 1024):
                    f.write(chunk)
    except requests.RequestException as e:
        print(f"Failed to download file. Request error: {e}")
        # Do not leave a truncated file behind when the transfer breaks midway.
        if local_path and os.path.exists(local_path):
            os.remove(local_path)
        return None

    print(f"File successfully downloaded to {local_path}")
    return local_path

def model_fn(model_dir,context=None):
    """Load the FunASR model from ``model_dir``.

    :param model_dir: Directory containing the unpacked model files.
    :param context: Optional serving context; not used here.
    :return: The constructed ``AutoModel`` instance.
    """
    print(f"input_model_dir: {model_dir}")
    model = AutoModel(
        model=model_dir,
        trust_remote_code=True,
        # FunASR documents max_single_segment_time in milliseconds, while
        # chunk_length_s is in seconds, so convert before passing it on.
        # NOTE(review): no vad_model is passed here — confirm vad_kwargs takes effect.
        vad_kwargs={"max_single_segment_time": chunk_length_s * 1000},
        # Use the module-level device so CPU-only hosts are not forced onto CUDA.
        device=device,
        hub="ms", # hub="ms" for China region
    )
    return model

def transform_fn(model, request_body, request_content_type, response_content_type="application/json"):
    """Transcribe the audio file referenced by the request.

    The request body is a JSON object with the key ``audio_s3_presign_uri``.
    The file is downloaded, transcribed with ``model.generate`` and removed
    again, even when transcription raises.

    :param model: Model object returned by ``model_fn``.
    :param request_body: JSON-encoded request payload.
    :param request_content_type: Content type of the request; not inspected.
    :param response_content_type: Requested response type; not inspected.
    :return: JSON string, either ``{"text": ...}`` on success or
        ``{"error": ...}`` when no URL was passed or the download failed.
    """
    request = json.loads(request_body)
    audio_s3_presign_uri = request.get("audio_s3_presign_uri")

    # Error responses are serialized the same way as successful ones so the
    # caller always receives a JSON string.
    if not audio_s3_presign_uri:
        return json.dumps({"error" : "No input passed."})

    local_file_path = download_file_from_s3_url(audio_s3_presign_uri)
    if not local_file_path:
        return json.dumps({"error" : "No Audio downloaded."})

    try:
        res = model.generate(
            input=local_file_path,
            cache={},
            language="auto",  # "zh", "en", "yue", "ja", "ko", "nospeech"
            use_itn=True,
            batch_size_s=60,
            merge_vad=True,
            merge_length_s=15,
        )
    finally:
        # Always remove the downloaded file so failed requests do not fill /tmp.
        if os.path.exists(local_file_path):
            os.remove(local_file_path)

    text = rich_transcription_postprocess(res[0]["text"])

    return json.dumps({"text" : text})

#### 执行下面这个cell，在requirements.txt中添加国内的pip镜像

In [None]:
%%writefile ./code/requirements.txt
-i https://pypi.tuna.tsinghua.edu.cn/simple
# Packages imported by ./code/inference.py; the index line above points pip at the Tsinghua mirror.
torch>=1.13
torchaudio
ffmpeg-python
funasr

In [None]:
# 1. Install the required libraries (FunASR and ModelScope) from the Tsinghua PyPI mirror
!pip install -U funasr modelscope -i https://pypi.tuna.tsinghua.edu.cn/simple

In [None]:
# 2. Download the model files
from modelscope import snapshot_download

# ModelScope model ID and the local directory that receives the snapshot.
model_id, local_model_path = "iic/SenseVoiceSmall", "./funasr_model"

# Fetch the snapshot, skipping Markdown files and git metadata.
snapshot_download(
    local_dir=local_model_path,
    model_id=model_id,
    ignore_patterns=["*.md", ".git*"],
)

In [None]:
# 3. Package the model files (archive the contents of the model directory)
!tar -czf model.tar.gz -C {local_model_path} .

# 4. Check the size of the packaged archive
!ls -lh model.tar.gz

In [None]:
# !rm funasr_model.tar.gz
# !touch dummy
# !tar czvf model.tar.gz dummy

In [None]:
# Upload the model archive under `s3_code_prefix` in the default bucket and
# keep the returned S3 URI for model creation.
model_uri = sess.upload_data(
    path="model.tar.gz",
    bucket=bucket,
    key_prefix=s3_code_prefix,
)
print("S3 Code or Model tar ball uploaded to --- > " + model_uri)

### 3. 创建模型 & 创建endpoint

In [None]:
from sagemaker.huggingface.model import HuggingFaceModel

model_name = "FunASR-SenseVoiceSmall"

# Describe the SageMaker model: the uploaded archive, the custom container
# image, and the entry-point script from ./code.
funasr_hf_model = HuggingFaceModel(
    name=model_name,
    role=role,
    model_data=model_uri,
    image_uri=inference_image_uri,
    source_dir='./code',
    entry_point="inference.py",
    env=dict(
        chunk_length_s="30",  # read by inference.py at import time
        MMS_DEFAULT_RESPONSE_TIMEOUT="500",  # model server timeout (seconds)
        SAGEMAKER_MODEL_SERVER_TIMEOUT="500",  # SageMaker model server timeout
    ),
)

In [None]:
from sagemaker.serializers import JSONSerializer
from sagemaker.deserializers import JSONDeserializer

# Endpoint name is prefixed with the AWS account ID.
endpoint_name = f'{account_id}-funasr-real-time-endpoint'

# Deploy one GPU instance; requests and responses are exchanged as JSON.
real_time_predictor = funasr_hf_model.deploy(
    endpoint_name=endpoint_name,
    instance_type="ml.g4dn.xlarge",
    initial_instance_count=1,
    serializer=JSONSerializer(),
    deserializer=JSONDeserializer(),
)

### 4. 模型测试

##### 4.1 下载一个音频文件，并上传到S3

In [None]:
# Download one audio sample and save it locally as a WAV file
import json

import soundfile as sf
from datasets import load_dataset

# Stream the dataset and take only its first record.
dataset = load_dataset('MLCommons/peoples_speech', split='train', streaming = True)
sample = next(iter(dataset))
audio_data = sample['audio']['array']
output_path = 'sample_audio.wav'
sf.write(output_path, audio_data, sample['audio']['sampling_rate'])

print(f"Audio sample saved to '{output_path}'.")

# Path of the file uploaded to S3 in the next cell. The previous trailing
# `print(response[0])` was removed: `response` is never defined in this
# notebook, so the cell failed with NameError.
audio_path = output_path

In [None]:
# !aws s3 cp ./99aaadae-7057-46d5-9802-9b578bef10ab.mp3 s3://sagemaker-cn-northwest-1-284567523170/aigc-asr-models/
# Upload the sample audio next to the model archive and print its S3 URI.
s3_audio_url = sess.upload_data(
    path=audio_path,
    bucket=bucket,
    key_prefix=s3_code_prefix,
)
print(s3_audio_url)

##### 4.2 生成S3 Presign URL，并发送请求

In [None]:
def generate_presigned_url(s3_uri, expiration=3600, region_name='cn-northwest-1'):
    """
    Generate a presigned URL for the S3 object

    :param s3_uri: The S3 URI of the object, in the form s3://bucket/key
    :param expiration: Time in seconds for the presigned URL to remain valid
    :param region_name: Region used for the S3 client that signs the URL
    :return: Presigned URL as string. If error, returns None.
    """
    # Imported locally so the function works even when the cell that imports
    # urlparse at notebook level has not been executed yet.
    from urllib.parse import urlparse

    # Parse the S3 URI
    parsed_uri = urlparse(s3_uri)
    bucket_name = parsed_uri.netloc
    object_key = parsed_uri.path.lstrip('/')

    # Reject malformed URIs up front instead of signing a URL that cannot work.
    if parsed_uri.scheme != 's3' or not bucket_name or not object_key:
        print(f"Error generating presigned URL: invalid S3 URI {s3_uri!r}")
        return None

    # Generate the presigned URL
    try:
        client = boto3.client('s3', region_name=region_name)
        response = client.generate_presigned_url('get_object',
                                                 Params={'Bucket': bucket_name, 'Key': object_key},
                                                 ExpiresIn=expiration)
    except Exception as e:
        print(f"Error generating presigned URL: {e}")
        return None

    return response

In [None]:
from urllib.parse import urlparse
# Show the S3 URI, then turn it into a pre-signed URL for the endpoint request.
print(s3_audio_url)
audio_s3_presign_uri = generate_presigned_url(s3_audio_url)
# Bare expression so the notebook displays the generated URL.
audio_s3_presign_uri

In [None]:
# Request payload; inference.py reads the "audio_s3_presign_uri" key.
jsondata = { "audio_s3_presign_uri" : audio_s3_presign_uri }
# Invoke the real-time endpoint; the notebook displays the returned value.
real_time_predictor.predict(data=jsondata)

In [None]:
# 检测下载音频文件
import os
import io
import sys
import time
import json
import logging

import requests

from urllib.parse import urlparse, unquote
from funasr import AutoModel
from funasr.utils.postprocess_utils import rich_transcription_postprocess
def download_file_from_s3_url(url, local_dir ='/tmp', timeout=60):
    """Download the object behind a (pre-signed) URL into ``local_dir``.

    The local file name is the last component of the URL path (percent-decoded).

    :param url: URL to fetch with an HTTP GET.
    :param local_dir: Directory the file is written to.
    :param timeout: Seconds to wait for the server before giving up, so a
        stalled connection cannot block the caller indefinitely.
    :return: Path of the downloaded file, or None when the request fails,
        the status code is not 200, or the URL carries no file name.
    """
    local_path = None
    try:
        # Stream the body so large audio files are not held in memory at once.
        with requests.get(url, stream=True, timeout=timeout) as response:
            if response.status_code != 200:
                print(f"Failed to download file. Status code: {response.status_code}")
                return None

            # Derive the file name from the URL path.
            filename = os.path.basename(unquote(urlparse(url).path))
            if not filename:
                print("Failed to download file. No file name found in URL path.")
                return None

            local_path = os.path.join(local_dir, filename)
            # Write the content to the local file chunk by chunk.
            with open(local_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=1024 * 1024):
                    f.write(chunk)
    except requests.RequestException as e:
        print(f"Failed to download file. Request error: {e}")
        # Do not leave a truncated file behind when the transfer breaks midway.
        if local_path and os.path.exists(local_path):
            os.remove(local_path)
        return None

    print(f"File successfully downloaded to {local_path}")
    return local_path
# Replay the first steps of transform_fn locally to check that the
# pre-signed URL can actually be downloaded.
request = jsondata
audio_s3_presign_uri = request.get("audio_s3_presign_uri")

local_file_path = None
if not audio_s3_presign_uri:
    # Stop here: attempting a download without a URL would only raise.
    print("No input passed.")
else:
    local_file_path = download_file_from_s3_url(audio_s3_presign_uri)
    if not local_file_path:
        print("No Audio downloaded.")

### 5. 清理模型端点

In [None]:
# Remove the SageMaker model as well as the endpoint so nothing is left behind.
# NOTE(review): delete_model() is called first on the assumption that it looks
# the model up through the live endpoint — confirm against the SDK docs.
real_time_predictor.delete_model()
real_time_predictor.delete_endpoint()