### Deploy SageMaker Endpoint
This notebook deploys the text-to-speech (TTS) and retalking SageMaker endpoints.
Note: This will create 2 x ml.g5.xlarge instances (one per endpoint), and you will incur costs for as long as the endpoints are running.

The steps include:

1. Create a bucket to upload TTS and retalking model files to
2. Download and upload the TTS and retalking pretrained model files
3. Deploy the TTS endpoint
4. Build the Retalking endpoint container and deploy the retalking endpoint

To build the retalking endpoint container:
1. Ensure that you have an Elastic Container Registry (ECR) repository for the container in the same region where you're deploying the endpoints
2. Build the container for the retalking endpoint

Navigate to src/retalking
```
docker build -f Dockerfile.retalking -t <account_id>.dkr.ecr.<region>.amazonaws.com/retalking:latest .
```

3. Push the container to ECR
```
aws ecr get-login-password --region <region> | docker login --username AWS --password-stdin <account_id>.dkr.ecr.<region>.amazonaws.com
docker push <account_id>.dkr.ecr.<region>.amazonaws.com/retalking:latest
```

4. Execute the following cells to deploy

In [None]:

import os
import sys
sys.path.append('../src')

from utils import download_models

from sagemaker import get_execution_role, image_uris
from sagemaker.pytorch import PyTorchModel
from sagemaker import Model
from sagemaker.predictor import Predictor
from sagemaker.serializers import JSONSerializer
from sagemaker.deserializers import JSONDeserializer
from sagemaker.async_inference import AsyncInferenceConfig

from sagemaker.session import Session

import boto3
from botocore.exceptions import ClientError

### Parameters
Set the parameters below

In [None]:
# ---------------------------------------------------------------------------
# Required settings: replace every <placeholder> before running the notebook.
# ---------------------------------------------------------------------------

# Region for the deployment; the S3 and ECR clients and the TTS image lookup
# in the cells below all use this value.
region_name = "<region>"

# Name of the S3 bucket that receives the model archives.
sm_bucket_name = "<bucket name>"

# IAM role used by the SageMaker endpoints. It must have SageMakerFullAccess
# and S3 read/write access to the bucket above.
sagemaker_role_arn = "arn:aws:iam::<account_id>:role/<Execution Role Name>"

# URI of the retalking container image in ECR, including the image tag.
retalking_ecr_uri = "<account_id>.dkr.ecr.<region>.amazonaws.com/retalking:latest"


# ---------------------------------------------------------------------------
# Optional settings: the defaults below work as-is.
# ---------------------------------------------------------------------------

# S3 key prefixes that the model archives are uploaded under.
tts_model_prefix = "tts/model"
retalking_model_prefix = "retalking/model"

# Text-to-speech (TTS) model and endpoint.
tts_model_name = "tts-model"
tts_endpoint_name = "tts-endpoint-async"
tts_instances_count = 1  # more instances allow TTS generation to run in parallel

# Retalking model and endpoint.
retalking_model_name = "retalker-model"
retalking_endpoint_name = "retalking-endpoint-async"
retalking_instances_count = 1  # more instances allow retalking jobs to run in parallel

### Pre-flight Checks

In [None]:
# Pre-flight check: make sure the model bucket exists, creating it if it does not.
s3_client = boto3.client('s3', region_name=region_name)

try:
    s3_client.head_bucket(Bucket=sm_bucket_name)
    print(f"Bucket '{sm_bucket_name}' already exists.")
except ClientError as e:
    error_code = e.response['Error']['Code']
    if error_code == '404':
        # us-east-1 is S3's default location and is rejected when sent as an
        # explicit LocationConstraint, so the configuration is only included
        # for every other region.
        create_bucket_kwargs = {'Bucket': sm_bucket_name}
        if region_name != 'us-east-1':
            create_bucket_kwargs['CreateBucketConfiguration'] = {'LocationConstraint': region_name}
        s3_client.create_bucket(**create_bucket_kwargs)
        print(f"Bucket '{sm_bucket_name}' created successfully.")
    else:
        # Any other code (for example 403 when the bucket belongs to another
        # account) is reported but not raised; the upload cell will fail later
        # if the bucket really is unusable.
        print(f"Error occurred while checking bucket existence: {e}")

In [None]:
# Pre-flight check: make sure the ECR repository for the retalking image
# exists, creating an empty repository if it does not.
ecr = boto3.client('ecr', region_name=region_name)

# The repository name is everything after the registry host, without the tag
# (":latest") or digest ("@sha256:..."). Keeping the whole path preserves
# namespaced repositories such as "team/retalking"; taking only the last path
# segment would look up (and create) the wrong repository.
retalking_ecr_repo = retalking_ecr_uri.split("/", 1)[-1].split("@", 1)[0].split(":", 1)[0]

try:
    response = ecr.describe_repositories(repositoryNames=[retalking_ecr_repo])
    print(f"Found {response['repositories'][0]['repositoryArn']}")
except ClientError as e:
    if e.response['Error']['Code'] == 'RepositoryNotFoundException':
        print(f"Repository {retalking_ecr_repo} not found. Creating...")

        try:
            response = ecr.create_repository(repositoryName=retalking_ecr_repo)
            print("Successfully created ECR. You must build and push the retalking container using the instructions above.")
        except ClientError as create_error:
            # Separate name so the outer exception is not shadowed.
            print("Error creating repository")
            print(create_error)
    else:
        print(f"Error in retrieving ECR repo {retalking_ecr_repo}")
        # Show the underlying error so permission/region problems can be diagnosed.
        print(e)

### Download and prepare model files
TortoiseTTS and Retalking rely on pre-trained models, which are downloaded in this step.
The model files are then packaged into `.tar.gz` archives for SageMaker to use.

In [None]:
# Download the pre-trained models and prepare the model archives.

# Local source directories, relative to this notebook.
tts_dir = "../src/tts"
retalking_dir = "../src/retalking"

# Directories the downloaded model files are placed in.
tts_model_dir = "../src/tts/model"
retalking_checkpoints_dir = "../src/retalking/code/checkpoints"

# Final outputs: the archives that are uploaded to S3 in the next cell.
tts_model_file = "../src/tts/archive/model-tts.tar.gz"
retalking_model_file = "../src/retalking/archive/model-retalking.tar.gz"

# NOTE(review): download_models is a project helper from ../src/utils; the
# meaning of create_archives / override_archives is presumed from their names
# (build the .tar.gz files, do not overwrite existing ones) - confirm there.
download_options = {
    "tts_dir": tts_dir,
    "tts_model_dir": tts_model_dir,
    "retalking_dir": retalking_dir,
    "retalking_checkpoints_dir": retalking_checkpoints_dir,
    "tts_model_dest": tts_model_file,
    "retalking_model_dest": retalking_model_file,
    "create_archives": True,
    "override_archives": False,
}
download_models(**download_options)

In [None]:
# Upload both model archives to the SageMaker bucket under their prefixes.
archive_uploads = (
    (tts_model_file, f"{tts_model_prefix}/model-tts.tar.gz"),
    (retalking_model_file, f"{retalking_model_prefix}/model-retalking.tar.gz"),
)
for local_archive, s3_key in archive_uploads:
    s3_client.upload_file(local_archive, sm_bucket_name, s3_key)

#### Deploy TortoiseTTS Endpoint

In [None]:
# Absolute path of the local TTS inference code and S3 URI of the uploaded
# TTS model archive.
tts_abs_dir = os.path.abspath(tts_dir)
tts_source_dir = f"{tts_abs_dir}/code"
tts_model_data = f"s3://{sm_bucket_name}/{tts_model_prefix}/model-tts.tar.gz"

# One value per line, for a quick visual check before deploying.
print(tts_source_dir, tts_model_data, sep="\n")

In [None]:
# Look up the PyTorch inference container image that the TTS endpoint runs on.
tts_image_request = {
    "framework": "pytorch",
    "version": "2.1",
    "py_version": "py310",
    "instance_type": "ml.g5.xlarge",
    "region": region_name,
    "image_scope": "inference",
}
image_uri = image_uris.retrieve(**tts_image_request)

print(image_uri)

In [None]:
# Create the model and endpoint configuration, and deploy the TTS endpoint as
# an asynchronous endpoint. wait=False returns immediately; the endpoint keeps
# being created in the background.

# Bind the SDK session to region_name so the endpoint is created in the same
# region as the bucket, the ECR repository and the image looked up above.
# Without an explicit session the SDK uses the environment's default region.
tts_sagemaker_session = Session(boto_session=boto3.Session(region_name=region_name))

async_config = AsyncInferenceConfig()
model = Model(
    image_uri=image_uri,
    model_data=tts_model_data,
    role=sagemaker_role_arn,
    env={'SAGEMAKER_TS_RESPONSE_TIMEOUT': '900'},
    name=tts_model_name,
    # A generic Model has no predictor class by default, in which case
    # deploy() returns None and the cleanup cell's predictor.delete_endpoint()
    # fails. Supplying one makes deploy() return a usable predictor.
    predictor_cls=Predictor,
    sagemaker_session=tts_sagemaker_session,
)

predictor = model.deploy(initial_instance_count=tts_instances_count,
                         instance_type='ml.g5.xlarge',
                         endpoint_name=tts_endpoint_name,
                         serializer=JSONSerializer(),
                         deserializer=JSONDeserializer(),
                         async_inference_config=async_config,
                         model_data_download_timeout=1800,
                         wait=False)

#### Deploy Retalking Endpoint

In [None]:
# Create the model and endpoint configuration, and deploy the retalking
# endpoint as an asynchronous endpoint from the custom ECR image.
# wait=False returns immediately; the endpoint keeps being created in the
# background.

retalking_model_data = f"s3://{sm_bucket_name}/{retalking_model_prefix}/model-retalking.tar.gz"

# Response-timeout settings for the model server, defined once and reused below.
retalking_env = {
    'SAGEMAKER_TS_RESPONSE_TIMEOUT': '900',
    "TS_DEFAULT_RESPONSE_TIMEOUT": "1000",
    "MMS_DEFAULT_RESPONSE_TIMEOUT": "900",
}

# Bind the SDK session to region_name so the endpoint is created in the same
# region as the bucket and the ECR repository. Without an explicit session the
# SDK uses the environment's default region.
retalking_sagemaker_session = Session(boto_session=boto3.Session(region_name=region_name))

retalker_async_config = AsyncInferenceConfig()
model = Model(
    image_uri=retalking_ecr_uri,
    model_data=retalking_model_data,
    role=sagemaker_role_arn,
    env=retalking_env,
    name=retalking_model_name,
    # A generic Model has no predictor class by default, in which case
    # deploy() returns None and the cleanup cell's
    # retalking_predictor.delete_endpoint() fails. Supplying one makes
    # deploy() return a usable predictor.
    predictor_cls=Predictor,
    sagemaker_session=retalking_sagemaker_session,
)

# NOTE(review): env is also forwarded to deploy() exactly as before (as a copy,
# so the two calls never share one dict); confirm the installed SDK version
# actually uses it there - the Model-level env above is the documented place.
retalking_predictor = model.deploy(initial_instance_count=retalking_instances_count,
                         instance_type='ml.g5.xlarge',
                         endpoint_name=retalking_endpoint_name,
                         serializer=JSONSerializer(),
                         deserializer=JSONDeserializer(),
                         async_inference_config=retalker_async_config,
                         model_data_download_timeout=1800,
                         container_startup_health_check_timeout=1800,
                         env=dict(retalking_env),
                         wait=False)

### Cleanup Endpoints

Delete the endpoints after you're done. Make sure to delete any buckets or roles created.


In [None]:
# Delete both endpoints and their endpoint configurations by name. Working
# from the configured names (rather than the predictor objects returned by the
# deploy cells) means this cell also works after a kernel restart.
# The SageMaker models, the S3 bucket and the ECR repository are not removed here.
sm_client = boto3.client('sagemaker', region_name=region_name)

for endpoint_name in (tts_endpoint_name, retalking_endpoint_name):
    try:
        # Resolve the configuration name first; it cannot be looked up from
        # the endpoint once the endpoint has been deleted.
        endpoint_config_name = sm_client.describe_endpoint(EndpointName=endpoint_name)['EndpointConfigName']
        sm_client.delete_endpoint(EndpointName=endpoint_name)
        sm_client.delete_endpoint_config(EndpointConfigName=endpoint_config_name)
        print(f"Deleted endpoint '{endpoint_name}' and endpoint config '{endpoint_config_name}'.")
    except ClientError as e:
        # Report and carry on so the other endpoint is still cleaned up.
        print(f"Could not delete endpoint '{endpoint_name}': {e}")