# Run inference on MusicGen Large deployed to an Amazon SageMaker endpoint

## Prepare helper code for running inference against the MusicGen SageMaker endpoint

In [None]:
# Restore variables persisted with IPython's `%store` (presumably by the
# deployment notebook — confirm): the endpoint name and the session bucket.
%store -r \
endpoint_name \
sagemaker_session_bucket

In [None]:
# Display the restored values as a quick sanity check.
endpoint_name, sagemaker_session_bucket

In [None]:
import sagemaker
# SageMaker SDK session; used below to upload the request payload to S3 and
# to poll/read the async inference results.
sm_session = sagemaker.session.Session()

In [None]:
import os
import uuid
import json


def generate_json(data):
    """Serialize ``data`` to a uniquely named JSON file in the current directory.

    Args:
        data: A JSON-serializable object (here, the inference request payload).

    Returns:
        The name of the file written, of the form ``payload_<uuid>.json``.
    """
    # uuid4 is random. uuid1 would embed this host's MAC address and a
    # timestamp in the file name, and that file is then uploaded to S3.
    suffix = str(uuid.uuid4())
    filename = f'payload_{suffix}.json'
    with open(filename, 'w', encoding='utf-8') as fp:
        json.dump(data, fp)
    return filename


def upload_input_json(sm_session, filename):
    """Upload a local JSON payload file to the session's default S3 bucket.

    The object is stored under the ``musicgen_large/input_payload`` prefix and
    tagged with ``ContentType: application/json``.

    Returns:
        The value returned by ``sm_session.upload_data`` (the S3 URI of the
        uploaded object in the SageMaker SDK).
    """
    target_bucket = sm_session.default_bucket()
    upload_args = {"ContentType": "application/json"}
    return sm_session.upload_data(
        filename,
        bucket=target_bucket,
        key_prefix='musicgen_large/input_payload',
        extra_args=upload_args,
    )


def delete_file_on_disk(filename):
    """Remove ``filename`` from the local file system if it is a regular file.

    Falsy names (``None`` or ``""``) are ignored. ``download_from_s3`` returns
    ``None`` when an S3 object is missing, and such results can end up in the
    list of files passed to this helper. Directories and non-existent paths
    are left untouched.
    """
    if not filename:
        return
    if os.path.isfile(filename):
        os.remove(filename)

In [None]:
import urllib, time
from botocore.exceptions import ClientError
import random

def get_output(output_location, failure_location, poll_interval=2, timeout=None):
    """Poll S3 until an async-inference result or failure object appears.

    Args:
        output_location: ``s3://`` URI where a successful result is written.
        failure_location: ``s3://`` URI where an error payload is written.
        poll_interval: Seconds to sleep between polls (default 2).
        timeout: Optional maximum number of seconds to wait. ``None`` (the
            default) waits indefinitely.

    Returns:
        The content read from ``output_location`` via
        ``sm_session.read_s3_file``, or ``None`` if an object was found at
        ``failure_location`` (its content is printed instead).

    Raises:
        TimeoutError: if ``timeout`` is set and no result appears in time.
        botocore.exceptions.ClientError: for S3 errors other than ``NoSuchKey``.
    """
    # ``import urllib`` alone does not guarantee the ``parse`` submodule is loaded.
    import urllib.parse

    output_url = urllib.parse.urlparse(output_location)
    failure_url = urllib.parse.urlparse(failure_location)
    output_bucket, output_key = output_url.netloc, output_url.path[1:]
    failure_bucket, failure_key = failure_url.netloc, failure_url.path[1:]
    icons = ["🪘", "🪇", "🎷", "🎸", "🎺", "🎻", "🥁", "🪗", "🪕"]
    deadline = None if timeout is None else time.monotonic() + timeout
    print("generating music")
    while True:
        try:
            # Look for the failure object in the failure URI's own bucket.
            if sm_session.list_s3_files(failure_bucket, failure_key):
                print('🔕 Error generating music')
                res = sm_session.read_s3_file(bucket=failure_bucket, key_prefix=failure_key)
                print(res)
                return None
            res = sm_session.read_s3_file(bucket=output_bucket, key_prefix=output_key)
            print("\nMusic is ready!🎉")
            return res
        except ClientError as e:
            # NoSuchKey: the result has not been written yet, so keep polling.
            if e.response["Error"]["Code"] != "NoSuchKey":
                raise
        if deadline is not None and time.monotonic() >= deadline:
            raise TimeoutError(
                f"No output at {output_location} after {timeout} seconds"
            )
        print(random.choice(icons), end='')
        time.sleep(poll_interval)
    

import botocore
def download_from_s3(url):
    """Download an S3 object into the current directory, named by its basename.

    ex: url = s3://bucketname/prefix1/music.wav  ->  ./music.wav

    Returns:
        The local file name. If a file with that name already exists locally,
        it is not downloaded again, but its name is still returned.
        Returns ``None`` (after printing a message) when S3 reports a 404.

    Raises:
        botocore.exceptions.ClientError: for S3 errors other than 404.
    """
    # 's3://bucketname/prefix1/music.wav' -> ['s3:', '', 'bucketname', 'prefix1/music.wav'].
    # S3 keys always use '/', so keep the key as-is rather than rebuilding it
    # with os.path.join (which uses '\\' on Windows and collapses '//').
    _, _, bucket_name, key = url.split("/", 3)
    filename = key.rsplit("/", 1)[-1]
    if os.path.exists(filename):
        return filename

    # Imported here so the helper does not depend on a later cell's imports.
    import boto3
    from botocore.exceptions import ClientError

    try:
        s3 = boto3.resource('s3')
        print('Downloading {} to {}'.format(url, filename))
        s3.Bucket(bucket_name).download_file(key, filename)
    except ClientError as e:
        if e.response['Error']['Code'] == "404":
            print('The object {} does not exist in bucket {}'.format(
                key, bucket_name))
            return None
        raise
    return filename


from IPython.display import Audio
import IPython
def play_output_audios(filenames, texts):
    """Show an inline audio player for each generated file, labelled by its prompt.

    ``filenames`` and ``texts`` are paired positionally; ``zip`` stops at the
    shorter sequence. Falsy file names are skipped.
    """
    for audio_file, prompt in zip(filenames, texts):
        if audio_file:
            player = Audio(filename=audio_file)
            print(f"{prompt}:\n{audio_file}", end="\n\n")
            # ``display`` is not imported here; this relies on the one the
            # IPython kernel provides as a builtin.
            display(player)
            print()

## Prepare and upload inference data to Amazon S3

In [None]:
# Generation settings sent in the request payload under "generation_params"
# (presumably forwarded to MusicGen's generate() by the endpoint's inference
# script — confirm against that script).
default_generation_params = {
    'guidance_scale': 3,
    'max_new_tokens': 1260,
    'do_sample': True,
    'temperature': 0.9,
}
# NOTE(review): the first prompt pairs a "melancholic and introspective" mood
# with happiness and hope; consider making the requested mood consistent.
data = {
    "texts": [
        """Compose a melancholic and introspective background music piece that captures the happiness of a boy. Use a combination of strings, woodwinds, and soft percussion to depict the sportive nature of the boy, the brightness of the morning, and the loads of hope that may lie ahead for him.""",
        """Compose an upbeat, rhythmic background track that captures the lively atmosphere of a bustling marketplace. Utilize a mix of percussion instruments such as hand drums, tambourines, and cymbals, along with energetic string and wind sections, to create a vibrant, engaging musical landscape."""
    ],
    "bucket_name": sagemaker_session_bucket,
    "generation_params": default_generation_params
}

In [None]:
# Write the payload locally and upload it to S3. The local copy is always
# removed afterwards, even if the upload raises.
filename = generate_json(data)
try:
    input_s3_location = upload_input_json(sm_session, filename)
finally:
    delete_file_on_disk(filename)

In [None]:
# S3 URI of the uploaded request payload (used as InputLocation below).
input_s3_location

## Invoke the Amazon SageMaker asynchronous inference endpoint for MusicGen

In [None]:
import boto3
# Low-level SageMaker runtime client, used to invoke the async endpoint.
sagemaker_runtime = boto3.client('sagemaker-runtime')


`InvocationTimeoutSeconds` sets the maximum time a request may spend in processing before it is marked as expired. You can set it per request, up to a maximum of 3600 seconds (one hour). If you omit this field, the request times out after 15 minutes.

Ref: [Invoke an Asynchronous Endpoint](https://docs.aws.amazon.com/sagemaker/latest/dg/async-inference-invoke-endpoint.html)


In [None]:
# Submit the request asynchronously. The payload is read from S3, and the
# response carries the S3 locations where the result or failure will appear.
response = sagemaker_runtime.invoke_endpoint_async(
    EndpointName=endpoint_name,
    InputLocation=input_s3_location,
    ContentType="application/json",
    InvocationTimeoutSeconds=3600  # per-request maximum (one hour)
)

In [None]:
# Inspect the invocation response (OutputLocation / FailureLocation are used next).
response

Wait for MusicGen to generate the music.

In [None]:
%%time
# Poll S3 until the async result (or a failure payload) is available.
output = get_output(response.get('OutputLocation'), response.get('FailureLocation'))

In [None]:
# get_output returns None when inference failed (the error was printed above).
if output is None:
    raise RuntimeError("Async inference failed; see the error payload printed above.")
# Parse only if still raw text, so re-running this cell does not fail.
if isinstance(output, (str, bytes, bytearray)):
    output = json.loads(output)
output.keys()

In [None]:
# S3 URIs of the generated files; entries may be None (see the download loop).
output.get('generated_outputs_s3')

## Download the generated WAV files and play them

In [None]:
# Download every generated file that has an S3 URI.
# NOTE(review): skipping None URIs shifts later files relative to
# data['texts'], so play_output_audios may pair a file with the wrong prompt.
music_files = [
    download_from_s3(s3_url)
    for s3_url in output.get('generated_outputs_s3')
    if s3_url is not None
]

In [None]:
# Render an audio player for each downloaded file next to its prompt text.
play_output_audios(music_files, data.get('texts'))

Delete the WAV files downloaded to the local (SageMaker Studio) file system.

In [None]:
# download_from_s3 returns None for objects that do not exist; skip those,
# because os.path.isfile(None) raises TypeError.
for music in music_files:
    if music:
        delete_file_on_disk(music)

## Cleanup

Programmatically obtain the endpoint, endpoint configuration, and model associated with `endpoint_name`, then delete these resources (and the SNS notification topics of the async endpoint) by setting the `cleanup` variable to `True`.

In [None]:
# Set to False to keep the endpoint, its config, model and SNS topics.
cleanup = True

In [None]:
sm_client = boto3.client('sagemaker')
endpoint = sm_client.describe_endpoint(EndpointName=endpoint_name)
endpoint_config_name = endpoint['EndpointConfigName']
endpoint_config = sm_client.describe_endpoint_config(EndpointConfigName=endpoint_config_name)
model_name = endpoint_config['ProductionVariants'][0]['ModelName']
# NotificationConfig is optional. Fall back to an empty dict so later
# iterations over notification_config do not fail on None.
notification_config = (
    endpoint_config.get('AsyncInferenceConfig', {})
    .get('OutputConfig', {})
    .get('NotificationConfig')
    or {}
)
print(f"""
About to delete the following sagemaker resources:
Endpoint: {endpoint_name}
Endpoint Config: {endpoint_config_name}
Model: {model_name}
""")
# Only the *Topic entries (e.g. SuccessTopic, ErrorTopic) hold SNS topic ARNs.
# Other keys such as IncludeInferenceResponseIn are not topics.
for k, v in notification_config.items():
    if k.endswith('Topic'):
        print(f'About to delete SNS topics for {k} with ARN: {v}')

In [None]:
# Delete the SageMaker resources discovered above, only when cleanup is enabled.
if cleanup:
    # delete endpoint
    sm_client.delete_endpoint(EndpointName=endpoint_name)
    # delete endpoint config
    sm_client.delete_endpoint_config(EndpointConfigName=endpoint_config_name)
    # delete model
    sm_client.delete_model(ModelName=model_name)
    print('deleted model, config and endpoint')

In [None]:
import sys, os

# Temporarily make the parent of the current directory importable so the local
# ``utils`` package can be found. __file__ is undefined in a notebook, so the
# literal string "__file__" is used: dirname(dirname(abspath("__file__"))) is
# the parent of the current working directory.
_parent_dir = os.path.dirname(os.path.dirname(os.path.abspath("__file__")))
sys.path.insert(0, _parent_dir)
try:
    from utils.sns_client import SnsClient
finally:
    # Restore sys.path even if the import fails. Remove the entry that was
    # added, rather than whatever now sits at index 0.
    if _parent_dir in sys.path:
        sys.path.remove(_parent_dir)

In [None]:
# The utils package path was removed from sys.path in the previous cell; this
# import succeeds because that cell already loaded the module into sys.modules.
from utils.sns_client import SnsClient
import boto3
if cleanup:
    sns_client = SnsClient(boto3.client("sns"))
    # Only *Topic entries (e.g. SuccessTopic, ErrorTopic) are SNS ARNs; other
    # keys such as IncludeInferenceResponseIn hold non-ARN values.
    # Tolerate a missing NotificationConfig (None).
    for k, v in (notification_config or {}).items():
        if k.endswith('Topic'):
            sns_client.delete_topic(v)
    print('deleted SNS topics associated with Async Endpoint')