# GPT-SoVITS on SageMaker - for Astra

## Build and push the image

In [None]:
!chmod +x ./*.sh && ./build_and_push.sh 

In [None]:
import boto3
import sagemaker
from sagemaker import Model, image_uris, serializers, deserializers

role = sagemaker.get_execution_role()  # execution role for the endpoint
sess = sagemaker.session.Session()  # sagemaker session for interacting with different AWS APIs
# Use the public ``boto_region_name`` property instead of the private
# ``_region_name`` attribute, which is an SDK implementation detail.
region = sess.boto_region_name  # region name of the current SageMaker Studio environment
account_id = sess.account_id()  # account_id of the current SageMaker Studio environment
bucket = sess.default_bucket()
image = "gpt-sovits-inference"
# Pin every client to the session's region so that the S3 bucket, the model,
# the endpoint and the ECR image URI built below all refer to the same region
# (otherwise the clients fall back to the default boto region).
s3_client = boto3.client("s3", region_name=region)
sm_client = boto3.client("sagemaker", region_name=region)
smr_client = boto3.client("sagemaker-runtime", region_name=region)

full_image_uri = f"{account_id}.dkr.ecr.{region}.amazonaws.com/{image}:latest"
print(full_image_uri)


## SageMaker Endpoint Deployment

### Create SageMaker model

In [None]:
import boto3
import re
import os
import json
import boto3
import sagemaker
from time import gmtime, strftime
# Re-create the SageMaker control-plane client for this cell, pinned to the
# same region as ``full_image_uri`` so the model and endpoint are created in
# the region the container image URI points at.
sm_client = boto3.client(service_name='sagemaker', region_name=region)

def create_model():
    """Register the GPT-SoVITS inference container as a new SageMaker model.

    Reads the module-level ``full_image_uri``, ``role`` and ``sm_client``.
    The model name carries a UTC timestamp suffix, so each run creates a
    distinct model.

    Returns:
        The name of the newly created model.
    """
    timestamp = strftime("%Y-%m-%d-%H-%M-%S", gmtime())
    new_model_name = "gpt-sovits-sagemaker-" + timestamp
    response = sm_client.create_model(
        ModelName=new_model_name,
        ExecutionRoleArn=role,
        Containers=[{"Image": full_image_uri}],
    )
    print(response)
    return new_model_name

In [None]:
# Register the model; its generated (timestamped) name is referenced by the
# endpoint configuration created next.
model_name=create_model()

### Create endpoint configuration

In [None]:
endpointConfigName = "gpt-sovits-sagemaker-configuration-" + strftime("%Y-%m-%d-%H-%M-%S", gmtime())


def create_endpoint_configuration():
    """Create a one-variant endpoint configuration serving ``model_name``.

    Reads the module-level ``endpointConfigName``, ``model_name`` and
    ``sm_client``, and prints the raw API response.

    Returns:
        The endpoint configuration name (``endpointConfigName``).
    """
    variant = {
        "ModelName": model_name,
        "VariantName": "gpt-sovits-sagemaker-variant",
        "InstanceType": "ml.g5.xlarge",  # run on an ml.g5.xlarge instance
        "InitialInstanceCount": 1,
        # Allow up to 20 minutes each for model download and container start-up.
        "ModelDataDownloadTimeoutInSeconds": 1200,
        "ContainerStartupHealthCheckTimeoutInSeconds": 1200,
    }
    response = sm_client.create_endpoint_config(
        EndpointConfigName=endpointConfigName,
        ProductionVariants=[variant],
    )
    print(response)
    return endpointConfigName


In [None]:
# Create the endpoint configuration (one ml.g5.xlarge instance) for model_name.
create_endpoint_configuration()


### Create endpoint

In [None]:
endpointName="gpt-sovits-sagemaker-endpoint"+strftime("%Y-%m-%d-%H-%M-%S", gmtime())


def create_endpoint():
    """Create the endpoint from ``endpointConfigName`` and wait until it is ready.

    Reads the module-level ``endpointName``, ``endpointConfigName`` and
    ``sm_client``. Prints the endpoint ARN, the status right after creation,
    and the status once the ``endpoint_in_service`` waiter returns. The
    waiter raises if the endpoint never reaches ``InService``.
    """
    response = sm_client.create_endpoint(
        EndpointName=endpointName,
        EndpointConfigName=endpointConfigName,
    )
    print("Endpoint Arn: " + response["EndpointArn"])
    status = sm_client.describe_endpoint(EndpointName=endpointName)["EndpointStatus"]
    print("Endpoint Status: " + status)
    # Report the real (timestamped) endpoint name, not just its fixed prefix.
    print("Waiting for {} endpoint to be in service".format(endpointName))
    waiter = sm_client.get_waiter("endpoint_in_service")
    waiter.wait(EndpointName=endpointName)
    final_status = sm_client.describe_endpoint(EndpointName=endpointName)["EndpointStatus"]
    print("Endpoint Status: " + final_status)

In [None]:
# Create the endpoint and block until the endpoint_in_service waiter returns.
create_endpoint()

## Endpoint Test - Stream mode

In [18]:
import wave
import json
from collections import defaultdict

def invoke_streams_endpoint(smr_client, endpointName, request):
    """Invoke an endpoint in response-streaming mode and collect its payload parts.

    Args:
        smr_client: A ``sagemaker-runtime`` client (anything exposing
            ``invoke_endpoint_with_response_stream``).
        endpointName: Name of the endpoint to invoke.
        request: JSON-serialisable request body. It is serialised with
            ``ensure_ascii=False`` so non-ASCII text is sent unescaped.

    Returns:
        A list with one dict per received payload part, in arrival order.
        Each dict has the keys ``first_chunk``, ``bytes``, ``last_chunk`` and
        ``index``. Returns an empty list if the stream carried no payload.
    """
    payload = json.dumps(request, ensure_ascii=False)

    response_model = smr_client.invoke_endpoint_with_response_stream(
        EndpointName=endpointName,
        ContentType="application/json",
        Body=payload,
    )
    print(response_model['ResponseMetadata'])

    chunks = []
    for event in response_model['Body']:
        part = event.get('PayloadPart')
        if part is None:
            # Only PayloadPart events carry audio bytes; ignore anything else.
            continue
        chunk = part['Bytes']
        index = len(chunks)
        chunks.append({
            'first_chunk': index == 0,
            'bytes': chunk,
            'last_chunk': False,
            'index': index,
        })
        if index < 5:
            print(f"chunk {index} len:", len(chunk))

    # Mark the final chunk; guard against an empty stream (max() on an empty
    # dict used to raise ValueError here).
    if chunks:
        chunks[-1]['last_chunk'] = True

    # Print a summary instead of dumping every audio byte into the output.
    total_bytes = sum(len(c['bytes']) for c in chunks)
    print(f"result: {len(chunks)} chunks, {total_bytes} bytes")
    return chunks

def audio_chunks_to_wav(audio_chunks, output_filename, channels=1, sample_width=2, sample_rate=16000):
    """Write the raw PCM bytes of ``audio_chunks`` into an uncompressed WAV file.

    Args:
        audio_chunks: Iterable of dicts whose ``'bytes'`` entry holds raw
            frame data (e.g. the list returned by ``invoke_streams_endpoint``).
        output_filename: Path of the WAV file to create or overwrite.
        channels: Number of interleaved channels.
        sample_width: Bytes per sample.
        sample_rate: Frames per second written to the WAV header.
    """
    with wave.open(output_filename, 'wb') as wav_out:
        # (nchannels, sampwidth, framerate, nframes, comptype, compname);
        # nframes is patched automatically as frames are written.
        wav_out.setparams((channels, sample_width, sample_rate, 0, 'NONE', 'not compressed'))
        for frame_data in (item['bytes'] for item in audio_chunks):
            wav_out.writeframes(frame_data)

    print(f"WAV file '{output_filename}' has been created.")


In [None]:
# Copy reference wav into S3
! aws s3 cp ./ref_dayu_2s.wav s3://$bucket/gpt-sovits/wav_ref/ref_dayu_2s.wav

In [13]:
import json
import boto3
runtime_sm_client = boto3.client(service_name="sagemaker-runtime", region_name=region)

# Request body for the streaming TTS call. Key order matches the JSON sent to the container.
request = dict(
    # Reference clip uploaded to S3 in the previous cell.
    refer_wav_path=f"s3://{bucket}/gpt-sovits/wav_ref/ref_dayu_2s.wav",
    # Roughly: "watched Big Fish & Begonia for the second time" — presumably
    # the transcript of the reference clip; confirm against the container.
    prompt_text="第二次看完大鱼海棠",
    prompt_language="zh",
    # Chinese paragraph explaining what SAP Basis is (text to synthesise).
    text="作为SAP基础架构专家,我来解释一下SAP Basis的含义:SAP Basis是指SAP系统的基础设施层,负责管理和维护整个SAP系统环境的运行。它包括以下几个主要方面:SAP系统管理包括SAP系统实例的安装、启动、监控、备份、升级等日常管理任务。Basis团队负责保证系统的正常运行。",
    text_language="zh",
    # Left empty here; presumably audio is streamed back instead of written to S3 — confirm.
    output_s3uri="",
    # Punctuation set (ASCII and full-width) — presumably where the text is split; confirm.
    cut_punc=",，.。:：!！\"”'‘",
)


In [None]:
# Stream the synthesised audio; returns a list of chunk dicts with a 'bytes' entry each.
response=invoke_streams_endpoint(runtime_sm_client,endpointName,request)

In [None]:
# Write the streamed chunks to output.wav. sample_rate=32000 overrides the
# helper's 16 kHz default — presumably the model's output rate; confirm
# against the container's response format.
audio_chunks_to_wav(response, "output.wav", sample_rate=32000)