# Run inference on the MusicGen Small model deployed on a SageMaker Async Inference Endpoint

In this notebook, we run inference against a MusicGen model deployed on an Amazon SageMaker Asynchronous Inference endpoint. We first prepare the code and input data for inference, then invoke the endpoint with a natural-language English prompt describing the mood of the music. Next, we download the generated WAV files and play them in the notebook. Finally, we clean up the resources created as part of this deployment.

## Prepare code for running inference on MusicGen deployed on a SageMaker Async Inference Endpoint

Restore the variables saved in the deployment notebook that are required to run inference with the MusicGen model.

In [None]:
%store -r \
endpoint_name \
sagemaker_session_bucket

In [None]:
endpoint_name, sagemaker_session_bucket

In [None]:
import sagemaker
sm_session = sagemaker.session.Session()

In [None]:
%cd ..

In [None]:
import sys, os
# https://stackoverflow.com/a/8015152
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath("__file__"))))

from utils.inference_utils import generate_json, upload_input_json, delete_file_on_disk, get_output, download_from_s3, play_output_audios
del sys.path[0]

## Prepare and upload inference data to Amazon S3

In [None]:
default_generation_params = { 'guidance_scale': 5, 'max_new_tokens': 1300, 'do_sample': True, 'temperature': 0.9 }
data = {
    "texts": [
        """Flute with hip hop beats on a sunny day and happy vibes"""
    ],
    "bucket_name": sagemaker_session_bucket,
    "generation_params": default_generation_params
}

The generation parameters in `default_generation_params` control how MusicGen samples audio:
- `guidance_scale`: Used in classifier-free guidance (CFG) to weight the conditional logits (predicted from the text prompt) against the unconditional logits (predicted from an unconditional, or "null", prompt). A higher guidance scale encourages samples that follow the input prompt more closely, usually at the expense of audio quality. CFG is enabled by setting `guidance_scale > 1`. For best results with text- and audio-conditional generation, the default `guidance_scale=3` is recommended; this notebook uses 5, which favors adherence to the prompt.
- `max_new_tokens`: The number of new audio tokens to generate, which determines the length of the output. MusicGen's audio encoder produces about 50 tokens per second of audio, so 1300 tokens corresponds to roughly 26 seconds.
- `do_sample`: When `True`, each token is sampled from the model's predicted distribution instead of being chosen greedily. Sampling generally produces noticeably better music than greedy decoding.
- `temperature`: The softmax temperature applied when sampling (`do_sample=True`). Values below 1 make the output more conservative; values above 1 make it more varied.
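To make `guidance_scale` and `temperature` concrete, here is a minimal pure-Python sketch on toy logits. It assumes the standard CFG combination `uncond + guidance_scale * (cond - uncond)` and applies temperature by dividing the logits before the softmax. It is an illustration only, not MusicGen's internal code, and the logit values are made up.

```python
import math
from typing import List


def cfg_logits(cond: List[float], uncond: List[float], guidance_scale: float) -> List[float]:
    """Classifier-free guidance: start from the unconditional logits and move toward
    the conditional ones by guidance_scale (a scale of 1.0 returns `cond` unchanged)."""
    return [u + guidance_scale * (c - u) for c, u in zip(cond, uncond)]


def softmax_with_temperature(logits: List[float], temperature: float) -> List[float]:
    """Divide the logits by the temperature, then apply a numerically stable softmax."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    scaled = [x / temperature for x in logits]
    peak = max(scaled)
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


# Made-up logits over a toy 3-token vocabulary, for illustration only.
cond = [2.0, 1.0, 0.0]    # logits given the text prompt
uncond = [1.0, 1.0, 1.0]  # logits given the "null" prompt

guided = cfg_logits(cond, uncond, guidance_scale=5)  # [6.0, 1.0, -4.0]
probs = softmax_with_temperature(guided, temperature=0.9)
```

With `guidance_scale=5`, the gap between the most and least preferred tokens widens from 2 to 10 logits before sampling, and lowering `temperature` sharpens the resulting distribution further.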

In [None]:
filename = generate_json(data)
input_s3_location = upload_input_json(sm_session, filename)
delete_file_on_disk(filename)
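`generate_json` and `upload_input_json` come from the repository's `utils.inference_utils` module, whose source is not shown here. As a hypothetical sketch of what such helpers might do, the following writes the payload to a temporary JSON file and uploads it with `sagemaker.Session.upload_data`. The function names `write_payload_json` and `upload_payload` and the `musicgen/input` key prefix are assumptions for illustration.

```python
import json
import os
import tempfile


def write_payload_json(data: dict) -> str:
    """Serialize the request payload to a temporary .json file and return its path."""
    fd, path = tempfile.mkstemp(suffix=".json")
    with os.fdopen(fd, "w") as f:
        json.dump(data, f)
    return path


def upload_payload(sm_session, path: str, bucket: str,
                   key_prefix: str = "musicgen/input") -> str:
    """Upload a local file with sagemaker.Session.upload_data and return its S3 URI.
    The key prefix is an arbitrary choice for this sketch."""
    return sm_session.upload_data(path=path, bucket=bucket, key_prefix=key_prefix)


# Hypothetical usage inside this notebook:
# path = write_payload_json(data)
# input_s3_location = upload_payload(sm_session, path, sagemaker_session_bucket)
# os.remove(path)
```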

## Invoke Amazon SageMaker Async Inference Endpoint for Musicgen

In [None]:
import boto3
sagemaker_runtime = boto3.client('sagemaker-runtime')


`InvocationTimeoutSeconds` sets the maximum time, in seconds, that a request can be processed before it is marked as expired. You can set it to at most 3600 seconds (one hour) per request. If you omit this field, the request times out after 15 minutes (900 seconds).

Ref: [Invoke an Asynchronous Endpoint](https://docs.aws.amazon.com/sagemaker/latest/dg/async-inference-invoke-endpoint.html)
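If you want to pass `InvocationTimeoutSeconds` only when needed and catch out-of-range values before calling the service, a small wrapper can do that. This is a hedged sketch: `invoke_async` is a hypothetical helper, and the 1 to 3600 second range is taken from the limits described above. It uses only `invoke_endpoint_async` parameters that appear in the next cell.

```python
from typing import Optional

MAX_INVOCATION_TIMEOUT_SECONDS = 3600  # per-request maximum (one hour)


def invoke_async(runtime_client, endpoint_name: str, input_location: str,
                 timeout_seconds: Optional[int] = None) -> dict:
    """Call invoke_endpoint_async, adding InvocationTimeoutSeconds only when given.
    When it is omitted, the service default (15 minutes) applies."""
    kwargs = {
        "EndpointName": endpoint_name,
        "InputLocation": input_location,
        "ContentType": "application/json",
    }
    if timeout_seconds is not None:
        if not 1 <= timeout_seconds <= MAX_INVOCATION_TIMEOUT_SECONDS:
            raise ValueError(
                f"InvocationTimeoutSeconds must be between 1 and {MAX_INVOCATION_TIMEOUT_SECONDS}"
            )
        kwargs["InvocationTimeoutSeconds"] = timeout_seconds
    return runtime_client.invoke_endpoint_async(**kwargs)


# Hypothetical usage:
# response = invoke_async(sagemaker_runtime, endpoint_name, input_s3_location, timeout_seconds=3600)
```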


In [None]:
response = sagemaker_runtime.invoke_endpoint_async(
    EndpointName=endpoint_name,
    InputLocation=input_s3_location,
    ContentType="application/json",
    InvocationTimeoutSeconds=3600
)

In [None]:
response

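Wait for MusicGen to generate the music. The next cell waits until the endpoint writes a result to the `OutputLocation` returned by the invocation, or an error to its `FailureLocation`.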

In [None]:
%%time
output = get_output(sm_session, response.get('OutputLocation'), response.get('FailureLocation'))
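`get_output` is a repository helper whose source is not shown here. A minimal sketch of the general pattern, polling the output and failure locations until one of them exists, could look like the following. `wait_for_async_result` and the injected `read_object(bucket, key)` callable are hypothetical; keeping the S3 read behind a callable lets the loop run without AWS access, and the commented boto3 version is one possible implementation, not the repository's code.

```python
import time
from typing import Callable, Optional, Tuple


def parse_s3_uri(uri: str) -> Tuple[str, str]:
    """Split 's3://bucket/some/key' into ('bucket', 'some/key')."""
    if not uri.startswith("s3://"):
        raise ValueError(f"not an S3 URI: {uri}")
    bucket, _, key = uri[len("s3://"):].partition("/")
    return bucket, key


def wait_for_async_result(read_object: Callable[[str, str], Optional[str]],
                          output_uri: str,
                          failure_uri: Optional[str] = None,
                          poll_seconds: float = 15.0,
                          timeout_seconds: float = 3600.0) -> str:
    """Poll until an object appears at output_uri (return its text) or at
    failure_uri (raise RuntimeError with its text). read_object(bucket, key)
    must return the object's text, or None if it does not exist yet."""
    deadline = time.monotonic() + timeout_seconds
    while time.monotonic() < deadline:
        body = read_object(*parse_s3_uri(output_uri))
        if body is not None:
            return body
        if failure_uri is not None:
            error = read_object(*parse_s3_uri(failure_uri))
            if error is not None:
                raise RuntimeError(f"async inference failed: {error}")
        time.sleep(poll_seconds)
    raise TimeoutError(f"no result at {output_uri} after {timeout_seconds} s")


# One possible read_object built on boto3 (an assumption, not the repo's code):
# s3 = boto3.client("s3")
# def read_object(bucket, key):
#     try:
#         return s3.get_object(Bucket=bucket, Key=key)["Body"].read().decode("utf-8")
#     except s3.exceptions.NoSuchKey:
#         return None
```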

In [None]:
import json
output = json.loads(output)
output.keys()

In [None]:
output.get('generated_outputs_s3')

## Download and Display the wav files to play music

In [None]:
music_files = []
for s3_url in output.get('generated_outputs_s3'):
    if s3_url is not None:
        music_files.append(download_from_s3(s3_url))
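`download_from_s3` is another repository helper whose source is not shown. Assuming each entry in `generated_outputs_s3` is an `s3://bucket/key` URI and that a boto3 S3 client is available, an equivalent download step can be sketched as below. `download_s3_file` is a hypothetical name; the call it relies on is the S3 client's `download_file(Bucket, Key, Filename)`.

```python
import os
from typing import Tuple


def parse_s3_uri(uri: str) -> Tuple[str, str]:
    """Split 's3://bucket/some/key' into ('bucket', 'some/key')."""
    if not uri.startswith("s3://"):
        raise ValueError(f"not an S3 URI: {uri}")
    bucket, _, key = uri[len("s3://"):].partition("/")
    return bucket, key


def download_s3_file(s3_client, uri: str, local_dir: str = ".") -> str:
    """Download the object at `uri` into local_dir, keeping its base file name,
    and return the local path."""
    bucket, key = parse_s3_uri(uri)
    local_path = os.path.join(local_dir, os.path.basename(key))
    s3_client.download_file(bucket, key, local_path)
    return local_path


# Hypothetical usage with boto3, skipping missing outputs as the cell above does:
# s3 = boto3.client("s3")
# music_files = [download_s3_file(s3, url) for url in output.get("generated_outputs_s3") if url is not None]
```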

In [None]:
play_output_audios(music_files, data.get('texts'))
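To confirm that the downloaded clips have roughly the length implied by `max_new_tokens`, you can read each file's WAV header. This sketch uses only the standard library and parses the RIFF header directly rather than using the `wave` module, because `wave` only reads PCM files while generated audio is sometimes saved as 32-bit float WAV. The 50 tokens-per-second figure is an approximation for MusicGen's audio encoder, and both helper names are hypothetical.

```python
import struct


def wav_duration_seconds(path: str) -> float:
    """Read a RIFF/WAVE header and return the audio duration in seconds.
    Works for PCM and IEEE-float data because it only uses the fmt and data chunks."""
    with open(path, "rb") as f:
        riff, _, wave_id = struct.unpack("<4sI4s", f.read(12))
        if riff != b"RIFF" or wave_id != b"WAVE":
            raise ValueError(f"{path} is not a RIFF/WAVE file")
        sample_rate = block_align = None
        while True:
            header = f.read(8)
            if len(header) < 8:
                raise ValueError(f"{path} has no data chunk")
            chunk_id, size = struct.unpack("<4sI", header)
            if chunk_id == b"fmt ":
                # format tag, channels, sample rate, byte rate, block align
                _, _, sample_rate, _, block_align = struct.unpack("<HHIIH", f.read(size)[:14])
                f.seek(size % 2, 1)  # chunks are padded to an even size
            elif chunk_id == b"data":
                if sample_rate is None:
                    raise ValueError(f"{path} has a data chunk before its fmt chunk")
                return size / (sample_rate * block_align)
            else:
                f.seek(size + size % 2, 1)  # skip unrelated chunks (LIST, fact, ...)


def expected_duration_seconds(max_new_tokens: int, tokens_per_second: float = 50.0) -> float:
    """Rough clip length implied by max_new_tokens, assuming ~50 audio tokens per second."""
    return max_new_tokens / tokens_per_second


# Hypothetical usage in this notebook:
# for path in music_files:
#     print(path, round(wav_duration_seconds(path), 1), "s; expected about",
#           expected_duration_seconds(default_generation_params["max_new_tokens"]), "s")
```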

Delete the audio files downloaded to the local Studio file system.

In [None]:
for music in music_files:
    delete_file_on_disk(music)

## Cleanup

Programmatically look up the endpoint config and model associated with `endpoint_name`, then delete the endpoint, endpoint config, model, and any SNS topics used for async notifications. The deletion cells only run when the `cleanup` variable is set to `True`.

In [None]:
cleanup = False

In [None]:
sm_client = boto3.client('sagemaker')
endpoint = sm_client.describe_endpoint(EndpointName=endpoint_name)
endpoint_config_name = endpoint['EndpointConfigName']
endpoint_config = sm_client.describe_endpoint_config(EndpointConfigName=endpoint_config_name)
model_name = endpoint_config['ProductionVariants'][0]['ModelName']
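# Keep only the SNS topic ARNs (SuccessTopic/ErrorTopic). NotificationConfig may be absent,
# and its optional IncludeInferenceResponseIn entry is a list of topic types, not a topic ARN.
notification_config = {
    k: v
    for k, v in endpoint_config['AsyncInferenceConfig']['OutputConfig'].get('NotificationConfig', {}).items()
    if k in ('SuccessTopic', 'ErrorTopic')
}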
print(f"""
About to delete the following sagemaker resources:
Endpoint: {endpoint_name}
Endpoint Config: {endpoint_config_name}
Model: {model_name}
""")
for k,v in notification_config.items():
    print(f'About to delete SNS topics for {k} with ARN: {v}')

In [None]:
if cleanup:
    # delete endpoint
    sm_client.delete_endpoint(EndpointName=endpoint_name)
    # delete endpoint config
    sm_client.delete_endpoint_config(EndpointConfigName=endpoint_config_name)
    # delete model
    sm_client.delete_model(ModelName=model_name)
    print('deleted model, config and endpoint')
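The cell above stops at the first failing call, for example when one of the resources was already removed in an earlier run. A hedged alternative, sketched here with a hypothetical `delete_endpoint_resources` helper, attempts all three deletions with the same boto3 SageMaker client calls and reports which ones failed. In real code you would likely catch `botocore.exceptions.ClientError` rather than every exception.

```python
from typing import Dict


def delete_endpoint_resources(sm_client, endpoint_name: str,
                              endpoint_config_name: str, model_name: str) -> Dict[str, str]:
    """Try to delete the endpoint, endpoint config, and model (in that order),
    continuing past failures. Returns {resource label: error message} for failed calls."""
    steps = [
        ("endpoint", sm_client.delete_endpoint, {"EndpointName": endpoint_name}),
        ("endpoint config", sm_client.delete_endpoint_config,
         {"EndpointConfigName": endpoint_config_name}),
        ("model", sm_client.delete_model, {"ModelName": model_name}),
    ]
    errors: Dict[str, str] = {}
    for label, delete, kwargs in steps:
        try:
            delete(**kwargs)
        except Exception as exc:  # broad for the sketch; boto3 raises botocore ClientError
            errors[label] = str(exc)
    return errors


# Hypothetical usage:
# if cleanup:
#     failures = delete_endpoint_resources(sm_client, endpoint_name, endpoint_config_name, model_name)
#     print(failures or "deleted model, config and endpoint")
```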

In [None]:
import sys, os

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath("__file__"))))

from utils.sns_client import SnsClient
del sys.path[0]

In [None]:
if cleanup:
    sns_client = SnsClient(boto3.client("sns"))
    for k,v in notification_config.items():
        sns_client.delete_topic(v)
    print('deleted SNS topics associated with Async Endpoint')