# Deploy MusicGen Small model on SageMaker Async Inference Endpoint

In [None]:
!pip install -Uq sagemaker

In [None]:
!mkdir model
!mkdir model/code

In [None]:
## requirements.txt for the endpoint container.
## Install notes from the audiocraft README
## (https://github.com/facebookresearch/audiocraft/blob/main/README.md):
##   pip install 'torch>=2.0'   # install torch first, in particular before xformers
##   pip install -U audiocraft  # stable release
##   pip install -U git+https://git@github.com/facebookresearch/audiocraft#egg=audiocraft  # bleeding edge
##   pip install -e .           # from a local clone (mandatory for training)
requirements = [
    "transformers==4.37.1",
    "boto3",
    "torch>=2.1",
    "scipy",
    # `uuid` is intentionally not listed: it is part of the Python standard
    # library, and the PyPI project with that name is an old standalone
    # release that is not needed on Python 3.
    #
    # Only the stable PyPI audiocraft is listed. Also installing it from the
    # git main branch was redundant and made builds non-reproducible.
    # NOTE(review): inference.py never imports audiocraft. Drop this line if
    # nothing else in the container needs it.
    "audiocraft",
]
with open("model/code/requirements.txt", "w") as f:
    f.write("\n".join(requirements) + "\n")


In [None]:
%%writefile model/code/inference.py

import boto3
from urllib.parse import urlparse
from transformers import AutoProcessor, MusicgenForConditionalGeneration
import scipy
import uuid

MUSICGEN_MODEL_ID = "facebook/musicgen-small"

# Processor cache: loaded on first request and then reused, instead of being
# fetched again for every request.
_processor = None


def _get_processor():
    """Return the MusicGen processor, loading it only on first use."""
    global _processor
    if _processor is None:
        _processor = AutoProcessor.from_pretrained(MUSICGEN_MODEL_ID)
    return _processor


def model_fn(model_dir):
    """Load MusicGen-small, placing it on the GPU when one is available.

    Args:
        model_dir: Directory of the unpacked model archive. It is not used:
            weights are loaded by Hugging Face Hub id. That presumably requires
            network access to the Hub unless they are already cached.

    Returns:
        The ``MusicgenForConditionalGeneration`` model on "cuda" or "cpu".
    """
    import torch

    model = MusicgenForConditionalGeneration.from_pretrained(MUSICGEN_MODEL_ID)
    # Without an explicit move the model would stay on CPU, even on a GPU instance.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    return model.to(device)


def process_input(model, texts, max_new_tokens=256):
    """Generate audio for one or more text prompts.

    Args:
        model: Model returned by ``model_fn``.
        texts: Prompt string or list of prompts, e.g.
            ["80s pop track with bassy drums and synth",
             "90s rock song with loud guitars and heavy drums"].
        max_new_tokens: Generation length passed to ``model.generate``.

    Returns:
        The tensor produced by ``model.generate``, moved to CPU.
    """
    inputs = _get_processor()(
        text=texts,
        padding=True,
        return_tensors="pt",
    ).to(model.device)  # inputs must live on the same device as the model
    audio_values = model.generate(**inputs, max_new_tokens=max_new_tokens)
    # Return on CPU so callers can call .numpy() whichever device generated it.
    return audio_values.cpu()


def upload_to_s3(wav_on_disk, bucket_name):
    """Upload a local WAV file to ``bucket_name`` under ``musicgen/output/``.

    The object key keeps the file's base name (the part after the last '/').
    Returns the ``s3://`` URI of the uploaded object.
    """
    object_key = "musicgen/output/" + wav_on_disk.rpartition('/')[2]
    bucket = boto3.resource('s3').Bucket(bucket_name)
    bucket.upload_file(wav_on_disk, object_key)
    return f"s3://{bucket_name}/{object_key}"


def write_to_s3(sampling_rate, audio_values, bucket_name):
    """Save the first generated clip as a WAV file and upload it to S3.

    Only ``audio_values[0, 0]`` (first batch entry, first channel) is written.
    The file goes to /tmp, or to /opt/ml/output/data if /tmp is not writable.
    It is then uploaded with ``upload_to_s3``. The local copy is removed
    afterwards so a long-running endpoint does not accumulate files on disk.

    Args:
        sampling_rate: Sample rate (Hz) to record in the WAV header.
        audio_values: Tensor returned by generation.
        bucket_name: Destination S3 bucket.

    Returns:
        The ``s3://`` URI returned by ``upload_to_s3``.
    """
    import contextlib
    import os
    # Import the submodule explicitly rather than relying on `import scipy`
    # to make `scipy.io.wavfile` available.
    from scipy.io import wavfile

    wav_file = f"musicgen_out-{uuid.uuid1()}.wav"
    # .cpu() is a no-op for CPU tensors and is required before .numpy() for CUDA tensors.
    samples = audio_values[0, 0].cpu().numpy()

    wav_on_disk = f'/tmp/{wav_file}'
    try:
        wavfile.write(wav_on_disk, rate=sampling_rate, data=samples)
    except OSError:
        wav_on_disk = f'/opt/ml/output/data/{wav_file}'
        wavfile.write(wav_on_disk, rate=sampling_rate, data=samples)

    try:
        return upload_to_s3(wav_on_disk, bucket_name)
    finally:
        # Best-effort cleanup: a failed delete must not hide the upload result.
        with contextlib.suppress(OSError):
            os.remove(wav_on_disk)


def predict_fn(data, model):
    """Generate music for the request prompts and report where it was stored.

    ``data`` must contain ``texts`` (the prompt or prompts to generate from)
    and ``bucket_name`` (the destination S3 bucket). Both keys are popped
    from ``data``, and a missing key raises ``KeyError``.

    Returns a dict whose ``generated_output_s3`` entry is the S3 URI of the
    uploaded WAV file.
    """
    prompts = data.pop('texts')
    target_bucket = data.pop('bucket_name')
    generated = process_input(model, prompts)
    return {
        "generated_output_s3": write_to_s3(
            model.config.audio_encoder.sampling_rate, generated, target_bucket
        ),
        # "audio_values": generated  # optionally also return the raw audio values
    }

In [None]:
# Work from inside model/ so the tarball built below has code/ (inference.py
# and requirements.txt) at its root.
%cd model

In [None]:
!rm model.tar.gz

In [None]:
# Delete Jupyter checkpoint copies so they are not packed into the model archive.
!rm -rf code/.ipynb_checkpoints*

In [None]:
!tar zcvf model.tar.gz *

In [None]:
import sagemaker
import boto3

sess = sagemaker.Session()

# Bucket for the model archive and async outputs. None means "use the
# session's default SageMaker bucket".
sagemaker_session_bucket = None
if sess is not None and sagemaker_session_bucket is None:
    sagemaker_session_bucket = sess.default_bucket()

# get_execution_role() raises ValueError when it cannot infer a role
# (presumably when running outside SageMaker). In that case, fall back to
# looking up an IAM role named "sagemaker_execution_role".
try:
    role = sagemaker.get_execution_role()
except ValueError:
    iam = boto3.client("iam")
    role_info = iam.get_role(RoleName="sagemaker_execution_role")
    role = role_info["Role"]["Arn"]

# Re-create the session pinned to the chosen bucket.
sess = sagemaker.Session(default_bucket=sagemaker_session_bucket)

print(f"sagemaker role arn: {role}")
print(f"sagemaker bucket: {sess.default_bucket()}")
print(f"sagemaker session region: {sess.boto_region_name}")


In [None]:
# S3 URI the packaged model archive is uploaded to.
s3_location = "s3://{}/musicgen/model/model.tar.gz".format(sagemaker_session_bucket)
s3_location

In [None]:
!aws s3 cp model.tar.gz $s3_location

## Async Inference

In [None]:
from sagemaker.huggingface.model import HuggingFaceModel
from sagemaker.async_inference.async_inference_config import AsyncInferenceConfig
from sagemaker.s3 import s3_path_join
from sagemaker.utils import name_from_base

# Timestamped, unique name shared by the SageMaker model and the endpoint.
async_endpoint_name = name_from_base("musicgen-small-v1-async")

# Hugging Face model backed by the uploaded archive. Keep these container
# versions in line with the pins in model/code/requirements.txt.
huggingface_model = HuggingFaceModel(
    name=async_endpoint_name,
    model_data=s3_location,        # archive holding code/inference.py and requirements.txt
    role=role,                     # IAM role allowed to create the endpoint
    transformers_version="4.37",
    pytorch_version="2.1",
    py_version="py310",
)

# S3 prefix where asynchronous inference results are written.
async_output_path = s3_path_join(
    "s3://", sagemaker_session_bucket, "musicgen/async_inference/music_output"
)
# Optional SNS notifications: fill in topic ARNs to be notified on completion.
async_sns_notifications = {
    # "SuccessTopic": "PUT YOUR SUCCESS SNS TOPIC ARN",
    # "ErrorTopic": "PUT YOUR ERROR SNS TOPIC ARN",
}
async_config = AsyncInferenceConfig(
    output_path=async_output_path,
    notification_config=async_sns_notifications,
)

# Deploy the async endpoint on a single ml.g5.2xlarge instance.
async_predictor = huggingface_model.deploy(
    initial_instance_count=1,
    instance_type="ml.g5.2xlarge",
    async_inference_config=async_config,
    endpoint_name=async_endpoint_name,
)


In [None]:
# Name of the endpoint just deployed, saved by the %store cell below.
endpoint_name = async_predictor.endpoint_name

In [None]:
# Persist these variables with IPython's %store magic so a later notebook
# session can restore them with `%store -r`.
%store \
endpoint_name \
sagemaker_session_bucket