# Run inference with the MusicGen Small model deployed on a SageMaker Asynchronous Inference endpoint

## Prepare helper code for invoking the MusicGen model deployed on the SageMaker endpoint

In [2]:
# Reload `endpoint_name` and `sagemaker_session_bucket` from IPython's `%store`
# persistence (presumably saved by the notebook that deployed the endpoint — confirm).
%store -r \
endpoint_name \
sagemaker_session_bucket

In [3]:
# Echo the restored values to confirm that `%store -r` populated them.
# NOTE(review): the rendered output exposes the bucket name, which appears to embed
# an AWS account ID — consider clearing outputs before sharing this notebook.
endpoint_name, sagemaker_session_bucket

('musicgen-small-v1-async-2024-03-27-10-28-03-685',
 'sagemaker-us-west-2-920487201358')

In [4]:
import sagemaker
# SageMaker SDK session; used below to upload the request payload to S3 and to
# list/read the result objects while polling for the async inference output.
sm_session = sagemaker.session.Session()

sagemaker.config INFO - Not applying SDK defaults from location: /etc/xdg/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /root/.config/sagemaker/config.yaml


In [5]:
import os
import uuid
import json


def generate_json(data):
    """Dump ``data`` as JSON into a freshly named file in the working directory.

    The file name has the form ``payload_<uuid1>.json`` so that repeated calls
    never overwrite each other.

    Args:
        data: Any JSON-serializable object.

    Returns:
        The name of the file that was written.
    """
    payload_path = 'payload_' + str(uuid.uuid1()) + '.json'
    with open(payload_path, mode='w') as handle:
        json.dump(data, handle)
    return payload_path


def upload_input_json(sm_session, filename, key_prefix='musicgen_large/input_payload', bucket=None):
    """Upload a local JSON payload file to S3 and return its S3 URI.

    Args:
        sm_session: Object exposing ``upload_data`` and ``default_bucket``
            (a ``sagemaker.session.Session`` in this notebook).
        filename: Path of the local JSON file to upload.
        key_prefix: S3 key prefix to upload under. The default keeps the
            historical ``musicgen_large/...`` prefix for backward compatibility,
            even though this notebook targets the small model.
        bucket: Destination bucket. When ``None`` (default) the session's
            default bucket is used.

    Returns:
        Whatever ``sm_session.upload_data`` returns (the ``s3://`` URI of the
        uploaded object for a SageMaker session).
    """
    target_bucket = bucket if bucket is not None else sm_session.default_bucket()
    return sm_session.upload_data(
        filename,
        bucket=target_bucket,
        key_prefix=key_prefix,
        # Tag the object so the endpoint receives it as JSON.
        extra_args={"ContentType": "application/json"},
    )


def delete_file_on_disk(filename):
    """Remove ``filename`` from local disk if it is an existing regular file.

    Missing files, directories, and empty/``None`` names are ignored. The
    ``None`` case matters because ``download_from_s3`` can return ``None``
    and ``os.path.isfile(None)`` raises ``TypeError``.

    Args:
        filename: Path of the file to delete, or a falsy value to do nothing.
    """
    if filename and os.path.isfile(filename):
        os.remove(filename)

In [6]:
import urllib, time
from botocore.exceptions import ClientError
import random

def get_output(output_location, failure_location, poll_interval=2, timeout=None):
    """Poll S3 until the async inference result (or its failure marker) appears.

    Uses the module-level ``sm_session`` to talk to S3.

    Args:
        output_location: ``s3://bucket/key`` URI where the endpoint writes the
            result on success.
        failure_location: ``s3://bucket/key`` URI where the endpoint writes the
            error on failure.
        poll_interval: Seconds to sleep between polls (default 2).
        timeout: Maximum number of seconds to wait; ``None`` (default) waits
            indefinitely.

    Returns:
        The content of the output object on success, or ``None`` when a
        failure object was found (its content is printed).

    Raises:
        ClientError: For any S3 error other than ``NoSuchKey``.
        TimeoutError: When ``timeout`` elapses before either object appears.
    """
    # A bare `import urllib` does not guarantee the `parse` submodule is loaded.
    from urllib.parse import urlparse

    output_url = urlparse(output_location)
    failure_url = urlparse(failure_location)
    output_bucket, output_key = output_url.netloc, output_url.path[1:]
    failure_bucket, failure_key = failure_url.netloc, failure_url.path[1:]
    icons = ["ü™ò","ü™á","üé∑","üé∏","üé∫","üéª","ü•Å", "ü™ó", "ü™ï"]
    deadline = None if timeout is None else time.monotonic() + timeout
    print("generating music")
    while True:
        try:
            # Look for the failure marker in the *failure* bucket; it may differ
            # from the output bucket.
            if len(sm_session.list_s3_files(failure_bucket, failure_key)):
                print('üîï Error generating music')
                res = sm_session.read_s3_file(bucket=failure_bucket, key_prefix=failure_key)
                print(res)
                return None
            res = sm_session.read_s3_file(bucket=output_bucket, key_prefix=output_key)
            print("\nMusic is ready!üéâ")
            return res
        except ClientError as e:
            # NoSuchKey just means the result is not there yet; anything else is fatal.
            if e.response["Error"]["Code"] != "NoSuchKey":
                raise
        if deadline is not None and time.monotonic() >= deadline:
            raise TimeoutError(
                "Timed out after {} seconds waiting for {}".format(timeout, output_location)
            )
        print(random.choice(icons), end = '')
        time.sleep(poll_interval)
    

import botocore
def download_from_s3(url):
    """Download an S3 object into the current working directory.

    ex: url = s3://bucketname/prefix1/music.wav

    Args:
        url: ``s3://bucket/key`` URI of the object to download.

    Returns:
        The local file name (the last path component of ``url``). If a file
        with that name already exists it is reused and no download happens.
        Returns ``None`` when S3 reports the object as missing (404).

    Raises:
        botocore.exceptions.ClientError: For any S3 error other than 404.
    """
    url_parts = url.split("/")  # => ['s3:', '', 'sagemakerbucketname', 'data', ...
    bucket_name = url_parts[2]
    # S3 keys always use "/"; os.path.join would use the OS separator instead.
    key = "/".join(url_parts[3:])
    filename = url_parts[-1]
    if os.path.exists(filename):
        # Already on disk: still hand the name back so callers can use the file.
        return filename

    # Imported locally so this helper does not depend on a later cell's imports.
    import boto3
    from botocore.exceptions import ClientError

    try:
        s3 = boto3.resource('s3')
        print('Downloading {} to {}'.format(url, filename))
        s3.Bucket(bucket_name).download_file(key, filename)
        return filename
    except ClientError as e:
        if e.response['Error']['Code'] == "404":
            print('The object {} does not exist in bucket {}'.format(
                key, bucket_name))
            return None
        raise


from IPython.display import Audio
import IPython
def play_output_audios(filenames, texts):
    """Render an inline audio player for each (file, prompt) pair.

    Pairs are formed positionally with ``zip``; entries whose file name is
    falsy (e.g. ``None``) are skipped.

    Args:
        filenames: Local audio file paths.
        texts: Prompt text matching each file; printed as a caption.
    """
    playable = ((path, caption) for path, caption in zip(filenames, texts) if path)
    for path, caption in playable:
        player = Audio(filename=path)
        # Caption followed by a blank line, then the player and a trailing blank line.
        print(f"{caption}:\n{path}", end="\n\n")
        display(player)
        print()

## Prepare and upload inference data to Amazon S3

In [7]:
# Generation settings sent along with the prompt
# (presumably forwarded to the model's generate call by the endpoint — confirm).
default_generation_params = {
    'guidance_scale': 3,
    'max_new_tokens': 128,
    'do_sample': True,
    'temperature': 0.9,
}

# Request payload: the text prompts, the bucket name restored via %store,
# and the generation settings defined above.
data = {
    "texts": [
        'Warm and vibrant weather on a sunny day, hip hop and synth',
    ],
    "bucket_name": sagemaker_session_bucket,
    "generation_params": default_generation_params,
}

In [8]:
# Write the payload to a temporary local JSON file, upload it to S3, and always
# remove the local copy afterwards, even if the upload raises.
filename = generate_json(data)
try:
    input_s3_location = upload_input_json(sm_session, filename)
finally:
    delete_file_on_disk(filename)

In [9]:
# Show the S3 URI of the uploaded payload; it is passed to the endpoint as InputLocation.
input_s3_location

's3://sagemaker-us-west-2-920487201358/musicgen_large/input_payload/payload_2dbcb76e-ec45-11ee-8718-d5b273152fcb.json'

## Invoke Amazon SageMaker Async Inference Endpoint for Musicgen

In [10]:
import boto3
# Low-level SageMaker runtime client, used to submit the async invocation.
sagemaker_runtime = boto3.client('sagemaker-runtime')


`InvocationTimeoutSeconds` sets the maximum timeout for a request. It is set on a per-request basis and can be at most 3600 seconds (one hour). If you don't specify this field in your request, the request times out after 15 minutes by default.

Ref: [Invoke an Asynchronous Endpoint](https://docs.aws.amazon.com/sagemaker/latest/dg/async-inference-invoke-endpoint.html)


In [11]:
# Submit the request; the call returns immediately with the S3 locations where
# the result (or the failure details) will be written.
async_request = {
    "EndpointName": endpoint_name,
    "InputLocation": input_s3_location,
    "ContentType": "application/json",
    "InvocationTimeoutSeconds": 3600,  # per-request maximum (one hour)
}
response = sagemaker_runtime.invoke_endpoint_async(**async_request)

In [12]:
# Inspect the raw response; `OutputLocation` and `FailureLocation` are read from it below.
response

{'ResponseMetadata': {'RequestId': 'c6292f4a-6d36-4731-9f9d-fa676fc33ebe',
  'HTTPStatusCode': 202,
  'HTTPHeaders': {'x-amzn-requestid': 'c6292f4a-6d36-4731-9f9d-fa676fc33ebe',
   'x-amzn-sagemaker-outputlocation': 's3://sagemaker-us-west-2-920487201358/musicgen/async_inference/music_output/a83bef50-1194-48f9-87eb-2aa1587e82cb.out',
   'x-amzn-sagemaker-failurelocation': 's3://sagemaker-us-west-2-920487201358/async-endpoint-failures/musicgen-small-v1-async-2024-03-27-10-28-03-685-1711535285-6bc7/a83bef50-1194-48f9-87eb-2aa1587e82cb-error.out',
   'date': 'Wed, 27 Mar 2024 14:20:36 GMT',
   'content-type': 'application/json',
   'content-length': '54',
   'connection': 'keep-alive'},
  'RetryAttempts': 0},
 'OutputLocation': 's3://sagemaker-us-west-2-920487201358/musicgen/async_inference/music_output/a83bef50-1194-48f9-87eb-2aa1587e82cb.out',
 'FailureLocation': 's3://sagemaker-us-west-2-920487201358/async-endpoint-failures/musicgen-small-v1-async-2024-03-27-10-28-03-685-1711535285-6bc

Wait for Musicgen to generate music

In [13]:
%%time
# Poll S3 until the result (or a failure marker) shows up.
# `output` is the content of the output object, or None if the endpoint reported a failure.
output = get_output(response.get('OutputLocation'), response.get('FailureLocation'))

generating music
ü•Åü™òü•Åüé∑üé∫üé∏ü™áü™òü™ïü™òüé∏ü™ïü™óü™òü•Åü™óü™óü™áü™óüé∏ü™ó
Music is ready!üéâ
CPU times: user 210 ms, sys: 10.9 ms, total: 221 ms
Wall time: 42.9 s


In [14]:
# `get_output` returns the raw JSON text on success and None when the endpoint
# reported a failure; fail with a clear message in the latter case.
if output is None:
    raise RuntimeError("Music generation failed; see the error printed by get_output() above.")
# Parse only while `output` is still raw text, so re-running this cell is harmless.
if isinstance(output, (str, bytes, bytearray)):
    output = json.loads(output)
output.keys()

dict_keys(['generated_output_s3'])

In [15]:
# S3 URI of the generated audio, as reported in the endpoint's output.
output.get('generated_output_s3')

's3://sagemaker-us-west-2-920487201358/musicgen/output/musicgen_out-47126f42-ec45-11ee-a86b-aa560ced12eb.wav'

## Download and Display the wav files to play music

In [16]:
# Collect the local file names of the downloaded audio.
music_files = []
music_files += [download_from_s3(output.get('generated_output_s3'))]

Downloading s3://sagemaker-us-west-2-920487201358/musicgen/output/musicgen_out-47126f42-ec45-11ee-a86b-aa560ced12eb.wav to musicgen_out-47126f42-ec45-11ee-a86b-aa560ced12eb.wav


In [17]:
# Pair each downloaded file with the prompt that produced it and render an inline player.
play_output_audios(music_files, data.get('texts'))

Warm and vibrant weather on a sunny day, hip hop and synth:
musicgen_out-47126f42-ec45-11ee-a86b-aa560ced12eb.wav






Clean up the files downloaded in Studio.

In [18]:
# `download_from_s3` can return None (e.g. when the object is missing), so skip
# empty entries instead of handing them to the file-system helpers.
for music in music_files:
    if music:
        delete_file_on_disk(music)

## Cleanup

Programmatically obtain the Endpoint, Endpoint Config, and Model associated with `endpoint_name`, and delete these resources by setting the `cleanup` variable to `True`.

In [19]:
# Safety switch: the deletion cells below only act when this is set to True.
cleanup = False

In [20]:
# Look up everything attached to the endpoint so the cleanup cells know what to delete.
sm_client = boto3.client('sagemaker')

endpoint = sm_client.describe_endpoint(EndpointName=endpoint_name)
endpoint_config_name = endpoint['EndpointConfigName']

endpoint_config = sm_client.describe_endpoint_config(EndpointConfigName=endpoint_config_name)
# Only the first production variant's model is considered.
first_variant = endpoint_config['ProductionVariants'][0]
model_name = first_variant['ModelName']
# The notification config is optional; None when the endpoint has none configured.
async_output_config = endpoint_config['AsyncInferenceConfig']['OutputConfig']
notification_config = async_output_config.get('NotificationConfig')

print(f"""
About to delete the following sagemaker resources:
Endpoint: {endpoint_name}
Endpoint Config: {endpoint_config_name}
Model: {model_name}
""")
for topic_kind, topic_arn in (notification_config or {}).items():
    print(f'About to delete SNS topics for {topic_kind} with ARN: {topic_arn}')


About to delete the following sagemaker resources:
Endpoint: musicgen-small-v1-async-2024-03-27-10-28-03-685
Endpoint Config: musicgen-small-v1-async-2024-03-27-10-28-03-685
Model: musicgen-small-v1-async-2024-03-27-10-28-03-685



In [21]:
if cleanup:
    # Tear down in order: the endpoint first, then its config, then the model.
    teardown_steps = (
        (sm_client.delete_endpoint, {'EndpointName': endpoint_name}),
        (sm_client.delete_endpoint_config, {'EndpointConfigName': endpoint_config_name}),
        (sm_client.delete_model, {'ModelName': model_name}),
    )
    for delete_call, call_kwargs in teardown_steps:
        delete_call(**call_kwargs)
    print('deleted model, config and endpoint')

deleted model, config and endpoint


In [22]:
import sys, os

# The `utils` package is imported from the parent of the notebook's working
# directory (presumably where it lives in the repository — confirm). Notebooks
# have no `__file__`, so the parent directory is derived from the cwd.
_parent_dir = os.path.dirname(os.getcwd())
sys.path.insert(0, _parent_dir)
try:
    from utils.sns_client import SnsClient
finally:
    # Remove exactly the entry added above, even if the import failed, rather
    # than blindly deleting whatever happens to be at index 0.
    if _parent_dir in sys.path:
        sys.path.remove(_parent_dir)

In [23]:
from utils.sns_client import SnsClient
import boto3
if cleanup and notification_config:
    # Delete every SNS topic referenced by the endpoint's notification config.
    sns_client = SnsClient(boto3.client("sns"))
    for topic_arn in notification_config.values():
        sns_client.delete_topic(topic_arn)
    print('deleted SNS topics associated with Async Endpoint')