# Deploy Wan2.1-T2V-1.3B-Diffusers on SageMaker using HuggingFace Inference Containers

This notebook demonstrates how to deploy the [Wan-AI/Wan2.1-T2V-1.3B-Diffusers](https://huggingface.co/Wan-AI/Wan2.1-T2V-1.3B-Diffusers) text-to-video model on an Amazon SageMaker AI endpoint.

The notebook was tested in Amazon SageMaker AI Studio, and running it there is recommended.

## 1. Environment preparation

In [None]:
%pip install sagemaker==2.254.1 --upgrade --quiet --no-warn-conflicts

In [None]:
import sys
import os
import time
import json
import boto3
import sagemaker

role = sagemaker.get_execution_role()  # execution role for the endpoint
sess = sagemaker.session.Session()  # sagemaker session for interacting with different AWS APIs
bucket = sess.default_bucket()  # bucket to house artifacts
region = sess.boto_region_name  # region name of the current SageMaker Studio environment
account_id = sess.account_id()

sm_client = boto3.client("sagemaker")  # client to interact with SageMaker
smr_client = boto3.client("sagemaker-runtime")  # client to interact with SageMaker endpoints
s3 = boto3.client("s3")

print(f"sagemaker role arn: {role}")
print(f"sagemaker bucket: {bucket}")
print(f"sagemaker session region: {region}")
print(f"boto3 version: {boto3.__version__}")
print(f"sagemaker version: {sagemaker.__version__}")

In [None]:
model_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
s3_key = f"model/{model_id}"
s3_code_key = f"{s3_key}/code"

## 2. Model and inference logic preparation

Although a model can be deployed to a SageMaker AI endpoint directly from the Hugging Face Hub, production deployments usually load the model from an Amazon S3 bucket.
To simulate a production deployment, we download the model from the Hub and upload it to S3.

We also prepare two files: `requirements.txt`, which lists the additional Python libraries, and `inference.py`, which contains the custom model loading and inference code.

***Note: you can skip this step if the model and the required files have already been uploaded to S3.***

In [None]:
from huggingface_hub import snapshot_download
from pathlib import Path

local_model_path = Path("./data")
local_model_path.mkdir(exist_ok=True)

snapshot_download(repo_id=model_id, local_dir=local_model_path)

In [None]:
# enumerate local files recursively
for root, dirs, files in os.walk(local_model_path):
    for filename in files:
        local_path = os.path.join(root, filename)

        relative_path = os.path.relpath(local_path, local_model_path)
        # S3 keys always use "/" as the separator, regardless of the local OS
        s3_path = f"{s3_key}/{Path(relative_path).as_posix()}"

        print("Uploading %s..." % s3_path)
        s3.upload_file(local_path, bucket, s3_path)
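The upload loop above maps every local file to an S3 key under the model prefix. The sketch below isolates that mapping in a generator so it can be checked before any upload happens. The function name `iter_upload_keys` and the `skip_dirs` parameter are hypothetical; skipping `.cache` assumes that `snapshot_download` may leave download metadata in a `.cache` folder inside `local_dir`, which the endpoint does not need.

```python
import os
from pathlib import Path, PurePosixPath


def iter_upload_keys(local_dir, key_prefix, skip_dirs=(".cache",)):
    """Yield (local_path, s3_key) pairs for every file under local_dir.

    skip_dirs is a hypothetical parameter: directory names that are pruned
    from the walk so their contents are never uploaded.
    """
    local_dir = Path(local_dir)
    for root, dirs, files in os.walk(local_dir):
        # Prune skipped directories in place so os.walk does not descend into them
        dirs[:] = [d for d in dirs if d not in skip_dirs]
        for filename in sorted(files):
            local_path = Path(root) / filename
            relative = local_path.relative_to(local_dir)
            # S3 keys always use "/" regardless of the local OS separator
            yield local_path, str(PurePosixPath(key_prefix, *relative.parts))
```

With the names used in this notebook, each pair could then be passed to `s3.upload_file(str(local_path), bucket, key)`.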

In [None]:
requirements = """torchvision==0.21.0
opencv-python==4.11.0.86
diffusers==0.35.2
transformers==4.49.0
tokenizers==0.21.1
accelerate==1.4.0
peft==0.17.1
ftfy==6.3.1
ffmpeg==1.4
imageio==2.37.2
imageio-ffmpeg==0.6.0
"""
file_name = "requirements.txt"
with open(file_name, 'w') as f:
    f.write(requirements)

key = f"{s3_code_key}/{file_name}"
s3.upload_file(file_name, bucket, key)

In [None]:
inference = """import os
import time
import boto3
import torch
from botocore.exceptions import ClientError
from diffusers import AutoencoderKLWan, WanPipeline
from diffusers.utils import export_to_video


def upload_file(file_name, bucket, object_name=None):
    # If S3 object_name was not specified, use file_name
    if object_name is None:
        object_name = os.path.basename(file_name)

    # Upload the file
    s3 = boto3.client('s3')
    try:
        s3.upload_file(file_name, bucket, object_name)
    except ClientError as e:
        print(e)
        return False
    return True


def model_fn(model_dir):
    vae = AutoencoderKLWan.from_pretrained(model_dir, subfolder="vae", torch_dtype=torch.float32)
    pipe = WanPipeline.from_pretrained(model_dir, vae=vae, torch_dtype=torch.bfloat16)
    pipe.to("cuda")

    return pipe


def predict_fn(data, pipe):
    print("inference started")

    bucket = data.pop("bucket")
    file_name = data.pop("file_name", "model_output.mp4")
    prompt = data.pop("prompt", "A curious raccoon")
    negative_prompt = data.pop("negative_prompt", "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards")
    height = int(data.pop("height", 480))
    width = int(data.pop("width", 832))
    num_frames = int(data.pop("num_frames", 17))
    guidance_scale = float(data.pop("guidance_scale", 5.0))
    fps = int(data.pop("fps", 15))

    start_time = time.perf_counter()

    output = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        height=height,
        width=width,
        num_frames=num_frames,
        guidance_scale=guidance_scale
    ).frames[0]

    end_time = time.perf_counter()
    elapsed_time = end_time - start_time

    print(f"Execution time - in PREDICT : {elapsed_time:.6f} seconds")

    file_path = f"/tmp/{os.path.basename(file_name)}"
    export_to_video(output, file_path, fps)

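    uploaded = upload_file(file_path, bucket, file_name)

    try:
        os.remove(file_path)
    except FileNotFoundError:
        print(f"Error: File '{file_path}' not found.")
    except Exception as e:
        print(f"An error occurred: {e}")

    # Fail the request instead of returning an S3 URI that was never written
    if not uploaded:
        raise RuntimeError(f"Failed to upload the video to s3://{bucket}/{file_name}")

    return {"generated_video": f"s3://{bucket}/{file_name}"}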
"""
file_name = "inference.py"
with open(file_name, 'w') as f:
    f.write(inference)

key = f"{s3_code_key}/{file_name}"
s3.upload_file(file_name, bucket, key)
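The handler above reads `height`, `width`, `num_frames`, and `fps` straight from the request, so a malformed value only surfaces once the pipeline is already running. The sketch below validates those fields up front and reports the resulting clip length (`num_frames / fps`, for example 161 frames at 15 fps is roughly 10.7 seconds). The function name `normalize_video_request` is hypothetical. The constraints are assumptions that should be checked against the diffusers documentation for `WanPipeline`: height and width as multiples of 16, and a frame count of the form `4*k + 1`.

```python
def normalize_video_request(data):
    """Return validated generation settings plus the resulting clip length.

    Assumptions (not confirmed by the notebook): height and width must be
    multiples of 16, and num_frames should have the form 4*k + 1.
    Defaults mirror the ones used in predict_fn.
    """
    height = int(data.get("height", 480))
    width = int(data.get("width", 832))
    num_frames = int(data.get("num_frames", 17))
    fps = int(data.get("fps", 15))

    if height % 16 or width % 16:
        raise ValueError(
            f"height and width must be multiples of 16, got {height}x{width}"
        )
    if num_frames < 1 or fps < 1:
        raise ValueError("num_frames and fps must be positive")

    # Round down to the nearest frame count of the form 4*k + 1
    num_frames = (num_frames - 1) // 4 * 4 + 1

    return {
        "height": height,
        "width": width,
        "num_frames": num_frames,
        "fps": fps,
        "duration_seconds": num_frames / fps,
    }
```

A check like this could run at the top of `predict_fn`, so that an invalid request fails quickly and lands in the failure path of the asynchronous endpoint.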

## 3. Model deployment

Model deployment on Amazon SageMaker AI consists of three steps:
- [create model](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker/client/create_model.html) object (serving container and location of model artifacts)
- [create endpoint configuration](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker/client/create_endpoint_config.html) (endpoint type: real-time/async, instance type and count)
- [create endpoint](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker/client/create_endpoint.html)

In [None]:
inference_image = f"763104351884.dkr.ecr.{region}.amazonaws.com/huggingface-pytorch-inference:2.6.0-transformers4.49.0-gpu-py312-cu124-ubuntu22.04"

instance = {"type": "ml.g6e.2xlarge", "num_gpu": 1}
model_name = f"model-{time.strftime('%y%m%d-%H%M%S')}"
endpoint_name = model_name
endpoint_config_name = model_name
timeout = 600
variant_name = "main"

model_data_source = {
    'S3DataSource': {
        'S3Uri': f"s3://{bucket}/{s3_key}/",
        'S3DataType': 'S3Prefix',
        'CompressionType': 'None',
    }
}

In [None]:
model_response = sm_client.create_model(
    ModelName=model_name,
    ExecutionRoleArn=role,
    PrimaryContainer={
        "Image": inference_image,
        "ModelDataSource": model_data_source
    },
)


**We are going to use an asynchronous endpoint because video generation can exceed the invocation timeout of a real-time endpoint (60 seconds for a regular response).**

The `NotificationConfig` entry is optional. Remove it if you don't want to be notified when an inference request has been processed.
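One way to keep the notification entry optional without editing the dictionary by hand is a small builder that adds `NotificationConfig` only when a topic ARN is supplied. This is a minimal sketch: the function name `build_async_config` and its parameters are hypothetical, and the S3 paths and concurrency value mirror the ones used in this notebook.

```python
def build_async_config(bucket, sns_topic=None, max_concurrent=5):
    """Build an AsyncInferenceConfig dict.

    The NotificationConfig entry is added only when sns_topic is given.
    """
    output_config = {
        "S3OutputPath": f"s3://{bucket}/async/out",
        "S3FailurePath": f"s3://{bucket}/async/err",
    }
    if sns_topic:
        output_config["NotificationConfig"] = {
            "SuccessTopic": sns_topic,
            "ErrorTopic": sns_topic,
            "IncludeInferenceResponseIn": ["SUCCESS_NOTIFICATION_TOPIC"],
        }
    return {
        "ClientConfig": {"MaxConcurrentInvocationsPerInstance": max_concurrent},
        "OutputConfig": output_config,
    }
```

The returned dictionary can be passed as `AsyncInferenceConfig` to `create_endpoint_config`.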


In [None]:
sns_topic = "<YOUR_SNS_TOPIC>"

async_config = {
    'ClientConfig': {
        'MaxConcurrentInvocationsPerInstance': 5
    },
    'OutputConfig': {
        'S3OutputPath': f"s3://{bucket}/async/out",
        'NotificationConfig': {
            'SuccessTopic': sns_topic,
            'ErrorTopic': sns_topic,
            'IncludeInferenceResponseIn': ['SUCCESS_NOTIFICATION_TOPIC']
        },
        'S3FailurePath': f"s3://{bucket}/async/err"
    }
}

config_response = sm_client.create_endpoint_config(
    EndpointConfigName=endpoint_config_name,
    ProductionVariants=[
        {
            "VariantName": variant_name,
            "ModelName": model_name,
            "InstanceType": instance["type"],
            "InitialInstanceCount": 1,
            "ContainerStartupHealthCheckTimeoutInSeconds": timeout,
        },
    ],
    AsyncInferenceConfig=async_config,
)

endpoint_response = sm_client.create_endpoint(
    EndpointName=endpoint_name, EndpointConfigName=endpoint_config_name
)

_ = sess.wait_for_endpoint(endpoint_name)

## 4. Inference examples

**Please note that the generated video file is written to `s3://{bucket}/{file_name}`, where `bucket` and `file_name` are taken from the invocation request.**

In [None]:
payload = f"""
{{
    "bucket": "{bucket}",
    "file_name": "test_video1.mp4",
    "prompt": "A cat walks on the grass, realistic",
    "num_frames": 161
}}
"""
file_name = "request1.txt"
with open(file_name, 'w') as f:
    f.write(payload)

key = f"async/in/{file_name}"
s3.upload_file(file_name, bucket, key)

res = smr_client.invoke_endpoint_async(
    EndpointName = endpoint_name,
    ContentType = "application/json",
    InputLocation = f"s3://{bucket}/{key}",
)
print(json.dumps(res, indent=2))
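The request above is assembled with an f-string, which produces invalid JSON as soon as a prompt contains a double quote or a backslash. A minimal sketch of an alternative is to build a dictionary and serialize it with `json.dumps`. The helper name `build_request` is hypothetical; the field names are the ones read by `predict_fn`.

```python
import json


def build_request(bucket, file_name, prompt, **overrides):
    """Serialize an invocation payload for the predict_fn handler.

    Extra keyword arguments (for example num_frames or fps) are passed
    through unchanged.
    """
    payload = {"bucket": bucket, "file_name": file_name, "prompt": prompt}
    payload.update(overrides)
    # json.dumps escapes quotes and backslashes inside the prompt
    return json.dumps(payload, indent=4)
```

For example, `build_request(bucket, "test_video1.mp4", "A cat walks on the grass, realistic", num_frames=161)` yields the same fields as the request in the cell above.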

In [None]:
payload = f"""
{{
    "bucket": "{bucket}",
    "file_name": "test_video2.mp4",
    "prompt": "A curious racoon standing and looking directly at the camera near a garbage bin",
    "num_frames": 161
}}
"""
file_name = "request2.txt"
with open(file_name, 'w') as f:
    f.write(payload)

key = f"async/in/{file_name}"
s3.upload_file(file_name, bucket, key)

res = smr_client.invoke_endpoint_async(
    EndpointName = endpoint_name,
    ContentType = "application/json",
    InputLocation = f"s3://{bucket}/{key}",
)
print(json.dumps(res, indent=2))

In [None]:
payload = f"""
{{
    "bucket": "{bucket}",
    "file_name": "test_video3.mp4",
    "prompt": "A skilled archery champion, a lean and muscular woman in her late 20s, with sharp, focused hazel eyes, high cheekbones, and sun-kissed olive skin. Her dark brown hair is tied back in a tight braid, a few loose strands framing her face. She wears a fitted leather tunic in deep forest green, reinforced with subtle stitching, over a long-sleeved linen undershirt. Her hands move with precision as she fletches arrows, carefully binding hawk feathers to the shafts with sinew. Her expression is calm but intense, lips slightly pursed in concentration. The workshop around her is cluttered with wooden shafts, fletching tools, and a quiver of finished arrows. Warm torchlight casts flickering shadows on the rough-hewn wooden walls. She is wearing a blue jade bead necklace.",
    "num_frames": 161
}}
"""
file_name = "request3.txt"
with open(file_name, 'w') as f:
    f.write(payload)

key = f"async/in/{file_name}"
s3.upload_file(file_name, bucket, key)

res = smr_client.invoke_endpoint_async(
    EndpointName = endpoint_name,
    ContentType = "application/json",
    InputLocation = f"s3://{bucket}/{key}",
)
print(json.dumps(res, indent=2))

In [None]:
payload = f"""
{{
    "bucket": "{bucket}",
    "file_name": "test_video4.mp4",
    "prompt": "A traditional Christmas dinner table with candles and presents",
    "num_frames": 161
}}
"""
file_name = "request4.txt"
with open(file_name, 'w') as f:
    f.write(payload)

key = f"async/in/{file_name}"
s3.upload_file(file_name, bucket, key)

res = smr_client.invoke_endpoint_async(
    EndpointName = endpoint_name,
    ContentType = "application/json",
    InputLocation = f"s3://{bucket}/{key}",
)
print(json.dumps(res, indent=2))
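`invoke_endpoint_async` returns immediately; its response contains `OutputLocation` and `FailureLocation`, the S3 URIs where the result or the error is written once processing finishes. The sketch below polls both locations until one of them appears. The helper names and the `object_exists` callback are hypothetical: the callback keeps the example independent of AWS credentials, and in practice it could wrap `s3.head_object` and return `False` when the call raises a `ClientError` for a missing key.

```python
import time
from urllib.parse import urlparse


def split_s3_uri(uri):
    """Split 's3://bucket/some/key' into ('bucket', 'some/key')."""
    parsed = urlparse(uri)
    if parsed.scheme != "s3" or not parsed.netloc:
        raise ValueError(f"not an S3 URI: {uri}")
    return parsed.netloc, parsed.path.lstrip("/")


def wait_for_async_result(object_exists, output_uri, failure_uri,
                          timeout=1800, poll_interval=15, sleep=time.sleep):
    """Poll until the output or the failure object exists.

    object_exists(bucket, key) -> bool is supplied by the caller.
    Returns ("success", output_uri) or ("failure", failure_uri).
    """
    deadline = time.monotonic() + timeout
    while True:
        if object_exists(*split_s3_uri(output_uri)):
            return "success", output_uri
        if object_exists(*split_s3_uri(failure_uri)):
            return "failure", failure_uri
        if time.monotonic() >= deadline:
            raise TimeoutError(f"no result at {output_uri} after {timeout} seconds")
        sleep(poll_interval)
```

With the response from the cell above, the call would look like `wait_for_async_result(object_exists, res["OutputLocation"], res["FailureLocation"])`. On success, the output object holds the JSON returned by `predict_fn`, including the `generated_video` URI.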

## 5. Cleanup

In [None]:
sess.delete_endpoint(endpoint_name)
sess.delete_endpoint_config(endpoint_config_name)
sess.delete_model(model_name)