# xTTS on Sagemaker - for Astra

This notebook should be runing in `conda_python3` env, and is designed for Astra Demo only.

## Build inference image

In this workshop, we'll use the non-deepspeed docker image for inference. If you'd like to enable deepspeed, modify `build_and_push.sh` to build the `Dockerfile-sagemaker-ds` instead.

In [None]:
!chmod +x ./*.sh && ./build_and_push.sh

## SageMaker endpoint deployment

In [2]:
import boto3
import sagemaker

role = sagemaker.get_execution_role()  # execution role for the endpoint
sess = sagemaker.session.Session()  # sagemaker session for interacting with different AWS APIs
region = sess._region_name  # region name of the current SageMaker Studio environment
account_id = sess.account_id()  # account_id of the current SageMaker Studio environment
bucket = sess.default_bucket()
image = "xtts-inference"
s3_client = boto3.client("s3")
sm_client = boto3.client("sagemaker")
smr_client = boto3.client("sagemaker-runtime")

full_image_uri = f"{account_id}.dkr.ecr.{region}.amazonaws.com/{image}:no-ds-latest"
print(full_image_uri)

sagemaker.config INFO - Not applying SDK defaults from location: /etc/xdg/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /home/ec2-user/.config/sagemaker/config.yaml
022346938362.dkr.ecr.us-east-1.amazonaws.com/xtts-inference:no-ds-latest


### Create sagemaker model

#### 选项1: 使用公开模型

The default model used is `tts_models/multilingual/multi-dataset/xtts_v2`.

In [3]:
from time import gmtime, strftime
## for debug only
sm_client = boto3.client(service_name='sagemaker')


def create_model():
    image = full_image_uri
    model_name = "xtts-sagemaker-" + strftime("%Y-%m-%d-%H-%M-%S", gmtime())
    create_model_response = sm_client.create_model(
        ModelName=model_name,
        ExecutionRoleArn=role,
        Containers=[
            {
                "Image": image,
            }
        ],
    )
    print(create_model_response)
    return model_name

#### 选项2: 使用自有模型

Put the model path info into `CUSTOM_MODEL_PATH` environment variable, the model will be downloaded while launching SageMaker Endpoint.

In [4]:
# upload model data into s3
CUSTOM_MODEL_PATH = f"s3://{bucket}/xtts/models/tts_models--multilingual--multi-dataset--xtts_v2/"
print(CUSTOM_MODEL_PATH)

# replace local folder
!aws s3 sync .local/share/tts/tts_models--multilingual--multi-dataset--xtts_v2 $CUSTOM_MODEL_PATH

s3://sagemaker-us-east-1-022346938362/xtts/models/tts_models--multilingual--multi-dataset--xtts_v2/


In [5]:
from time import gmtime, strftime
## for debug only
sm_client = boto3.client(service_name='sagemaker')


def create_model():
    image = full_image_uri
    model_name = "xtts-sagemaker-" + strftime("%Y-%m-%d-%H-%M-%S", gmtime())
    create_model_response = sm_client.create_model(
        ModelName=model_name,
        ExecutionRoleArn=role,
        Containers=[
            {
                "Image": image,
                "Environment": {
                    "CUSTOM_MODEL_PATH": f"s3://{bucket}/xtts/models/tts_models--multilingual--multi-dataset--xtts_v2",
                    # "CUSTOM_MODEL_PATH": "tts_models/multilingual/multi-dataset/xtts_v2",
                }
            }
        ],
    )
    print(create_model_response)
    return model_name

In [6]:
model_name = create_model()

{'ModelArn': 'arn:aws:sagemaker:us-east-1:022346938362:model/xtts-sagemaker-2024-12-19-05-46-11', 'ResponseMetadata': {'RequestId': '56f0f5d4-25b2-44cc-950f-099acc8f9b90', 'HTTPStatusCode': 200, 'HTTPHeaders': {'x-amzn-requestid': '56f0f5d4-25b2-44cc-950f-099acc8f9b90', 'content-type': 'application/x-amz-json-1.1', 'content-length': '96', 'date': 'Thu, 19 Dec 2024 05:46:11 GMT'}, 'RetryAttempts': 0}}


### Create endpoint configuration

In [7]:
endpointConfigName = "xtts-sagemaker-configuration-"+strftime("%Y-%m-%d-%H-%M-%S", gmtime())


def create_endpoint_configuration():
    create_endpoint_config_response = sm_client.create_endpoint_config(
        EndpointConfigName=endpointConfigName,
        ProductionVariants=[
            {
                "ModelName": model_name,
                "VariantName": "xtts-sagemaker"+"-variant",
                "InstanceType": "ml.g5.xlarge",  # 指定 g5.xlarge 机器
                "InitialInstanceCount": 1,
                "ModelDataDownloadTimeoutInSeconds": 1200,
                "ContainerStartupHealthCheckTimeoutInSeconds": 1200
            }
        ],
    )
    print(create_endpoint_config_response)
    return endpointConfigName

In [8]:
create_endpoint_configuration()

{'EndpointConfigArn': 'arn:aws:sagemaker:us-east-1:022346938362:endpoint-config/xtts-sagemaker-configuration-2024-12-19-05-46-28', 'ResponseMetadata': {'RequestId': '4692bd02-74c0-4f2d-92a3-56c6df196033', 'HTTPStatusCode': 200, 'HTTPHeaders': {'x-amzn-requestid': '4692bd02-74c0-4f2d-92a3-56c6df196033', 'content-type': 'application/x-amz-json-1.1', 'content-length': '129', 'date': 'Thu, 19 Dec 2024 05:46:29 GMT'}, 'RetryAttempts': 0}}


'xtts-sagemaker-configuration-2024-12-19-05-46-28'

### Create endpoint

In [9]:
endpointName = "xtts-sagemaker-endpoint"+strftime("%Y-%m-%d-%H-%M-%S", gmtime())


def create_endpoint():
    create_endpoint_response = sm_client.create_endpoint(
        EndpointName=endpointName,
        EndpointConfigName=endpointConfigName
    )
    print("Endpoint Arn: " + create_endpoint_response["EndpointArn"])
    resp = sm_client.describe_endpoint(EndpointName=endpointName)
    print("Endpoint Status: " + resp["EndpointStatus"])
    print("Waiting for {} endpoint to be in service".format("xtts-sagemaker-endpoint"))
    waiter = sm_client.get_waiter("endpoint_in_service")
    waiter.wait(EndpointName=endpointName)

In [10]:
create_endpoint()

Endpoint Arn: arn:aws:sagemaker:us-east-1:022346938362:endpoint/xtts-sagemaker-endpoint2024-12-19-05-46-35
Endpoint Status: Creating
Waiting for xtts-sagemaker-endpoint endpoint to be in service


## Endpoint Test - Stream mode

In [11]:
import time
import wave
import json
from collections import defaultdict

# endpointName = "xtts-sagemaker-endpoint2024-12-19-00-54-34"  # modify this


def invoke_streams_endpoint(smr_client, endpointName, request):
    content_type = "application/json"
    payload = json.dumps(request, ensure_ascii=False)
    start = time.time()

    response_model = smr_client.invoke_endpoint_with_response_stream(
        EndpointName=endpointName,
        ContentType=content_type,
        Body=payload,
    )

    print(response_model['ResponseMetadata'])
    event_stream = response_model['Body']

    result = defaultdict(dict)
    for index, event in enumerate(event_stream):
        chunk = event['PayloadPart']['Bytes']

        result[index] = {
            'first_chunk': index == 0,
            'bytes': chunk,
            'last_chunk': False,
            'index': index
        }
        if index == 0:
            print('first chunk latency: ', time.time() - start)

        if index < 5:
            print(f"chunk {index} len:", len(chunk))

    # Update the last chunk
    last_index = max(result.keys())
    result[last_index]['last_chunk'] = True

    print(f"total chunks:", last_index)

    # print("result", dict(result))
    return list(result.values())


def audio_chunks_to_wav(audio_chunks, output_filename, channels=1, sample_width=2, sample_rate=24000):
    with wave.open(output_filename, 'wb') as wav_file:
        wav_file.setnchannels(channels)
        wav_file.setsampwidth(sample_width)
        wav_file.setframerate(sample_rate)

        for chunk in audio_chunks:
            wav_file.writeframes(chunk['bytes'])

    print(f"WAV file '{output_filename}' has been created.")

In [12]:
# Copy reference wav into S3
! aws s3 cp ./ref_dayu_2s.wav s3://$bucket/xtts/wav_ref/ref_dayu_2s.wav

upload: ./ref_dayu_2s.wav to s3://sagemaker-us-east-1-022346938362/xtts/wav_ref/ref_dayu_2s.wav


In [13]:
runtime_sm_client = boto3.client(service_name="sagemaker-runtime", region_name=region)

request = {
    # "speaker_name": "",
    "speaker_wav": [f"s3://{bucket}/xtts/wav_ref/ref_dayu_2s.wav"],
    "temperature": 0.75,
    "top_k": 50,
    "top_p": 0.85,
    "speed": 1,
    "language_id": "en",
    "text": "You can contribute not only with code but with bug reports, comments, questions, answers, or just a simple tweet to spread the word.",
    # "language_id": "zh",
    # "text": "你好，今天天气不错",
}

In [16]:
response = invoke_streams_endpoint(runtime_sm_client, endpointName, request)

{'RequestId': 'b9b621c6-7d4c-4e8b-89e4-d82438241b2e', 'HTTPStatusCode': 200, 'HTTPHeaders': {'x-amzn-requestid': 'b9b621c6-7d4c-4e8b-89e4-d82438241b2e', 'x-amzn-invoked-production-variant': 'xtts-sagemaker-variant', 'x-amzn-sagemaker-content-type': 'audio/wav', 'date': 'Thu, 19 Dec 2024 06:40:03 GMT', 'content-type': 'application/vnd.amazon.eventstream', 'transfer-encoding': 'chunked', 'connection': 'keep-alive'}, 'RetryAttempts': 0}
first chunk latency:  0.677983283996582
chunk 0 len: 480
chunk 1 len: 480
chunk 2 len: 480
chunk 3 len: 480
chunk 4 len: 480
total chunks: 1016


In [17]:
audio_chunks_to_wav(response, "output.wav", sample_rate=24000)

WAV file 'output.wav' has been created.
