# How to deploy Black Forest Labs's FLUX.1-schnell for inference on Amazon SageMaker AI

In this notebook, you will learn how to deploy **Black Forest Labs's FLUX.1-schnell** model (Hugging Face model ID: [black-forest-labs/FLUX.1-schnell](https://huggingface.co/black-forest-labs/FLUX.1-schnell)) using Amazon SageMaker AI. The inference image is the SageMaker-managed [LMI (Large Model Inference) v15](https://docs.aws.amazon.com/sagemaker/latest/dg/large-model-inference-container-docs.html) Docker image. LMI images feature a [DJL Serving](https://github.com/deepjavalibrary/djl-serving) stack powered by the [Deep Java Library](https://djl.ai/).

FLUX.1 [schnell] is a 12-billion-parameter rectified flow transformer capable of generating images from text descriptions. For more information, see the Black Forest Labs [blog post](https://blackforestlabs.ai/announcing-black-forest-labs/).

### Key Features

- Cutting-edge output quality and competitive prompt following, matching the performance of closed source alternatives.
- Trained using latent adversarial diffusion distillation, FLUX.1 [schnell] can generate high-quality images in only 1 to 4 steps.
- Released under the apache-2.0 licence, the model can be used for personal, scientific, and commercial purposes.

### Usage

Black Forest Labs provides a reference implementation of FLUX.1 [schnell], as well as sampling code, in a dedicated GitHub repository. Developers and creatives looking to build on top of FLUX.1 [schnell] are encouraged to use this as a starting point.

### Out-of-Scope Use
The model and its derivatives may not be used:

- In any way that violates any applicable national, federal, state, local or international law or regulation.
- For the purpose of exploiting, harming or attempting to exploit or harm minors in any way; including but not limited to the solicitation, creation, acquisition, or dissemination of child exploitative content.
- To generate or disseminate verifiably false information and/or content with the purpose of harming others.
- To generate or disseminate personal identifiable information that can be used to harm an individual.
- To harass, abuse, threaten, stalk, or bully individuals or groups of individuals.
- To create non-consensual nudity or illegal pornographic content.
- For fully automated decision making that adversely impacts an individual's legal rights or otherwise creates or modifies a binding, enforceable obligation.
- To generate or facilitate large-scale disinformation campaigns.


### License agreement
* This model is gated on Hugging Face; refer to the original [model card](https://huggingface.co/black-forest-labs/FLUX.1-schnell) for the license.
* [FLUX.1 schnell](https://huggingface.co/black-forest-labs/FLUX.1-schnell) is tailored for local development and personal use and is openly available under the Apache 2.0 license.
* This notebook is a sample notebook and not intended for production use.

## Payload format to invoke the model


Below is an example request payload.

```python
{
    "prompt": "A cat holding a sign that says hello world",
    "guidance_scale": 0.0,
    "num_inference_steps": 4,
    "max_sequence_length": 256,
    "seed": 42
}
```
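
The fields above correspond to the values the inference handler reads from the request. As a minimal sketch, a small helper can assemble and sanity-check a payload before it is sent. The name `build_payload` is hypothetical, and the validation rules are assumptions chosen for illustration, not limits enforced by the endpoint.

```python
import json


def build_payload(prompt, seed=0, num_inference_steps=4,
                  guidance_scale=0.0, max_sequence_length=256):
    """Assemble a request body for the endpoint (illustrative helper).

    The checks below are assumptions made for this sketch; the endpoint
    itself does not enforce them.
    """
    if not isinstance(prompt, str) or not prompt.strip():
        raise ValueError("prompt must be a non-empty string")
    if num_inference_steps < 1:
        raise ValueError("num_inference_steps must be at least 1")
    return {
        "prompt": prompt,
        "guidance_scale": float(guidance_scale),
        "num_inference_steps": int(num_inference_steps),
        "max_sequence_length": int(max_sequence_length),
        "seed": int(seed),
    }


payload = build_payload("A cat holding a sign that says hello world", seed=42)
print(json.dumps(payload, indent=2))
```

Omitted fields fall back to the same defaults the handler uses later in this notebook (4 steps, guidance scale 0.0, sequence length 256, seed 0).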



In [None]:
%pip install -Uq sagemaker

In [None]:
import sagemaker
import boto3
import logging
import time
from sagemaker.session import Session
from sagemaker.s3 import S3Uploader

print(sagemaker.__version__)

In [None]:
boto_region = boto3.Session().region_name
sm_session = sagemaker.session.Session(boto_session=boto3.Session(region_name=boto_region))
# Resolve the bucket outside the try block so it is defined on both code paths
sagemaker_default_bucket = sm_session.default_bucket()

try:
    role = sagemaker.get_execution_role()
except ValueError:
    # Not running inside SageMaker: look up the execution role by name instead
    iam = boto3.client('iam')
    role = iam.get_role(RoleName='sagemaker_execution_role')['Role']['Arn']

print(f"sagemaker role arn: {role}")
print(f"sagemaker default bucket: {sagemaker_default_bucket}")

In [None]:
HF_MODEL_ID = "black-forest-labs/FLUX.1-schnell"

base_name = HF_MODEL_ID.split('/')[-1].replace('.', '-').lower()
model_lineage = HF_MODEL_ID.split("/")[0]
base_name

## Download the model from Hugging Face and upload the model artifacts on Amazon S3
If you are deploying a model hosted on the Hugging Face Hub, you must specify the `option.model_id=<hf_hub_model_id>` configuration. When using a model directly from the Hub, we recommend that you also pin the model revision (commit hash or branch) via `option.revision=<commit hash/branch>`.

Because model artifacts are downloaded from the Hub at runtime, pinning a specific revision ensures that you use a model compatible with the package versions in the runtime environment. Open-source model artifacts on the Hub can change at any time, and such changes may break model instantiation (for example, updated artifacts may require a newer version of a dependency than the one bundled in the container). If a model provides custom model (`modeling.py`) and/or custom tokenizer (`tokenizer.py`) files, you need to specify `option.trust_remote_code=true` to load and use the model.
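
For comparison with the S3-based approach used in the rest of this notebook, the options described above could be combined as follows when loading straight from the Hub. This is only a sketch: `hub_serving_properties` is a hypothetical helper, and `<commit-hash>` is a placeholder that must be replaced with a real revision.

```python
def hub_serving_properties(model_id, revision, trust_remote_code=False):
    """Render the Hub-related serving.properties lines (illustrative helper)."""
    lines = [
        f"option.model_id={model_id}",
        f"option.revision={revision}",
        # Only enable this when the repository ships custom code you trust.
        f"option.trust_remote_code={str(trust_remote_code).lower()}",
    ]
    return "\n".join(lines) + "\n"


# "<commit-hash>" is a placeholder, not a real revision.
print(hub_serving_properties("black-forest-labs/FLUX.1-schnell", "<commit-hash>"))
```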

In this example, we download a copy of the model from Hugging Face, upload it to an S3 location in your AWS account, and then deploy the model to an endpoint from those uploaded artifacts.

**Best Practices**:
>
> **Store Models in Your Own S3 Bucket**
> For production use cases, always download and store model files in your own S3 bucket to ensure validated artifacts. This provides verified provenance, improved access control, consistent availability, protection against upstream changes, and compliance with organizational security protocols.
>

In [None]:
from huggingface_hub import snapshot_download
from pathlib import Path

model_dir = Path('model-files')
model_dir.mkdir(exist_ok=True)

snapshot_download(HF_MODEL_ID, local_dir=model_dir)

### Upload model files to S3 in uncompressed format for SageMaker AI
SageMaker AI allows us to provide [uncompressed](https://docs.aws.amazon.com/sagemaker/latest/dg/large-model-inference-uncompressed.html) files, so we upload the folder that contains the model files directly to S3.
> **Note**: The default SageMaker bucket follows the naming pattern: `sagemaker-{region}-{account-id}`
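
As a small illustration of the naming pattern above and of the destination used by the upload cell below, the following sketch builds both strings by hand. The helper names are hypothetical, and the region and account ID are placeholder values.

```python
def default_bucket_name(region, account_id):
    """Name of the default SageMaker bucket, following the pattern above."""
    return f"sagemaker-{region}-{account_id}"


def model_s3_uri(bucket, base_name, prefix="lmi"):
    """S3 prefix under which the uncompressed model files are stored."""
    return f"s3://{bucket}/{prefix}/{base_name}"


# Placeholder region and account ID, for illustration only.
bucket = default_bucket_name("us-east-1", "123456789012")
print(model_s3_uri(bucket, "flux-1-schnell"))
```

In the notebook itself, `sm_session.default_bucket()` returns the bucket name, so there is no need to construct it manually.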

In [None]:
# Upload the uncompressed model files to S3
model_artifact_uri = S3Uploader.upload(
    local_path="./model-files",
    desired_s3_uri=f"s3://{sagemaker_default_bucket}/lmi/{base_name}"
)
print(f"Model files uploaded to: {model_artifact_uri}")

### Configure Model Serving Properties and model.py that will be used to load the model

In [None]:
# Create the directory that will contain the configuration files
from pathlib import Path

# Use a separate name so the model directory variable is not overwritten
config_dir = Path('config')
config_dir.mkdir(exist_ok=True)

**Best Practices**:
>
> **Separate Configuration from Model Artifacts**
> The LMI container supports separating configuration files from model artifacts. While you can store `serving.properties` with your model files, placing configuration in a distinct S3 location makes all of your configuration files easier to manage.
>
> **Note**: When your model and configuration files are in different S3 locations, set `option.model_id=<s3_model_uri>` in your `serving.properties` file, where `s3_model_uri` is the S3 object prefix containing your model artifacts. SageMaker AI automatically downloads the model files from the S3 URI given in `model_id`.
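
Since `option.model_id` must hold a well-formed S3 prefix in this setup, it can help to check the value before writing it into `serving.properties`. The following is a minimal sketch using only the standard library; `split_s3_uri` is a hypothetical helper and the bucket name is a placeholder.

```python
from urllib.parse import urlparse


def split_s3_uri(uri):
    """Split an s3:// URI into (bucket, key_prefix), raising on malformed input."""
    parsed = urlparse(uri)
    if parsed.scheme != "s3" or not parsed.netloc:
        raise ValueError(f"not an S3 URI: {uri!r}")
    return parsed.netloc, parsed.path.lstrip("/")


# Placeholder bucket name, for illustration only.
bucket, prefix = split_s3_uri("s3://my-bucket/lmi/flux-1-schnell")
print(bucket, prefix)
```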

In [None]:
%%writefile ./config/model.py
import torch
from diffusers import FluxPipeline
import base64
from io import BytesIO
from djl_python import Input, Output
import os

class FluxModelHandler(object):
    def __init__(self):
        self.pipe = None
        self.device = None
        # Initialize the model immediately when the class is instantiated
        self._load_model()

    def _load_model(self):
        """Load the model once during container startup"""
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device}")
    
        # Use the DJL download location
        djl_base_path = "/tmp/.djl.ai/download"
        model_path = None
    
        if os.path.exists(djl_base_path):
            # Find the first directory that contains model files
            for item in os.listdir(djl_base_path):
                potential_path = os.path.join(djl_base_path, item)
                if os.path.isdir(potential_path) and os.path.exists(os.path.join(potential_path, "model_index.json")):
                    model_path = potential_path
                    break
    
        if model_path:
            print(f"Loading model from DJL download location: {model_path}")
            self.pipe = FluxPipeline.from_pretrained(model_path, torch_dtype=torch.bfloat16)
        else:
            raise FileNotFoundError("Model not found, please make sure the model files are downloaded from s3")
    
        if self.device == "cuda":
            self.pipe.enable_model_cpu_offload()
        
        print("Model loaded successfully and ready for inference!")

    def handle(self, inputs: Input) -> Output:
        # Model is already loaded, no need to check initialization
        try:
            input_data = inputs.get_as_json()
            prompt = input_data.get("prompt", "A cat holding a sign that says hello world")
            guidance_scale = float(input_data.get("guidance_scale", 0.0))
            num_inference_steps = int(input_data.get("num_inference_steps", 4))
            max_sequence_length = int(input_data.get("max_sequence_length", 256))
            seed = int(input_data.get("seed", 0))

            generator = torch.Generator(self.device).manual_seed(seed)

            image = self.pipe(
                prompt,
                guidance_scale=guidance_scale,
                num_inference_steps=num_inference_steps,
                max_sequence_length=max_sequence_length,
                generator=generator
            ).images[0]

            # Convert to base64
            buffered = BytesIO()
            image.save(buffered, format="PNG")
            img_str = base64.b64encode(buffered.getvalue()).decode()

            output = Output()
            output.add_as_json({"generated_image": img_str})
            return output

        except Exception as e:
            error_output = Output()
            error_output.add_as_json({"error": str(e)})
            return error_output


# Create the service instance once when the module is imported
_service = FluxModelHandler()


def handle(inputs: Input) -> Output:
    return _service.handle(inputs)

In [None]:
config = f"""engine=Python
option.tensor_parallel_degree=max
option.model_loading_timeout=1500
option.async_mode=false
option.entryPoint=model.py
option.model_id={model_artifact_uri}
option.trust_remote_code=false
option.dtype=bfloat16
fail_fast=true
"""
with open("config/serving.properties", "w") as f:
    f.write(config)
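
Before uploading, the generated file can be read back to confirm that the expected keys are present. The sketch below parses simple `key=value` lines; `parse_properties` is a hypothetical helper, the set of keys treated as required is an assumption specific to this notebook's setup, and the sample text uses a placeholder bucket.

```python
def parse_properties(text):
    """Parse simple key=value lines, skipping blank lines and '#' comments."""
    props = {}
    for raw in text.splitlines():
        line = raw.strip()
        if not line or line.startswith("#"):
            continue
        key, _, value = line.partition("=")
        props[key.strip()] = value.strip()
    return props


# Sample content with a placeholder bucket; in the notebook you would read
# config/serving.properties instead.
sample = """engine=Python
option.entryPoint=model.py
option.model_id=s3://my-bucket/lmi/flux-1-schnell
"""
props = parse_properties(sample)
required = ("engine", "option.entryPoint", "option.model_id")
missing = [key for key in required if key not in props]
print(props["option.model_id"], missing)
```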

#### Optional configuration files

(Optional) You can also specify a `requirements.txt` to install additional libraries.

In [None]:
%%writefile config/requirements.txt
peft==0.15.1
diffusers==0.34.0
transformers==4.51.3
accelerate==1.0.1
pillow==11.2.1
torch==2.6.0

### Upload config files to S3
Here we will upload our config files to a different path to keep model files and config separate.

In [None]:
# upload the code and config to s3
s3_config_prefix = f"large-model-lmi/code-files-{model_lineage}-{base_name}"

configuration_files = S3Uploader.upload(
    local_path="config",
    desired_s3_uri=f"s3://{sagemaker_default_bucket}/{s3_config_prefix}"
)

print(f"Configuration files are uploaded to: {configuration_files}")

## Configure Model Container and Instance

To deploy FLUX.1-schnell, we'll use:
- **LMI (Large Model Inference) container**: A DJL Serving-based container optimized for large model inference
- **[G6e Instance](https://aws.amazon.com/ec2/instance-types/g6e/)**: An AWS GPU instance type powered by NVIDIA L40S Tensor Core GPUs

Key configurations:
- The container URI points to the DJL inference container in ECR (Elastic Container Registry)
- We use the `ml.g6e.4xlarge` instance type
> **Note**: The region in the container URI should match your AWS region.
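
To make the region dependency explicit, the container URI can be built from its parts. This is a sketch that assumes the registry account and tag used in this notebook; `lmi_image_uri` is a hypothetical helper, and some regions may use a different registry account, so check the AWS Deep Learning Containers image list for your region.

```python
# Registry account taken from the container URI used in this notebook.
DJL_ECR_ACCOUNT = "763104351884"


def lmi_image_uri(region, version):
    """Build the djl-inference image URI for a region and container tag."""
    return f"{DJL_ECR_ACCOUNT}.dkr.ecr.{region}.amazonaws.com/djl-inference:{version}"


# "us-east-1" is an example region; use the region of your SageMaker session.
print(lmi_image_uri("us-east-1", "0.33.0-lmi15.0.0-cu128"))
```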

In [None]:
CONTAINER_VERSION = '0.33.0-lmi15.0.0-cu128'
image_uri = "763104351884.dkr.ecr.{}.amazonaws.com/djl-inference:{}".format(sm_session.boto_session.region_name, CONTAINER_VERSION)
print(image_uri)

In [None]:
gpu_instance_type = "ml.g6e.4xlarge"

## Create SageMaker Model

Now we'll create a SageMaker Model object that combines our:
- Container image (LMI)
- Code artifacts (configuration files)
- IAM role (for permissions)

In [None]:
# Specify the S3 URI for your uncompressed config files
config_data = {
    "S3DataSource": {
        "S3Uri": f"{configuration_files}/",
        "S3DataType": "S3Prefix",
        "CompressionType": "None"
    }
}

> **Note**: Here S3 URI points to the configuration files S3 location
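
The cell above appends a trailing slash so that the URI denotes a key prefix rather than a single object. As a minimal sketch, the same dictionary can be produced by a helper that normalises the slash; `uncompressed_model_data` is a hypothetical name and the bucket is a placeholder.

```python
def uncompressed_model_data(s3_prefix):
    """Build a model_data dict for an uncompressed S3 prefix (illustrative helper)."""
    if not s3_prefix.startswith("s3://"):
        raise ValueError(f"expected an s3:// URI, got {s3_prefix!r}")
    return {
        "S3DataSource": {
            # Normalise to exactly one trailing slash so the URI denotes a prefix.
            "S3Uri": s3_prefix.rstrip("/") + "/",
            "S3DataType": "S3Prefix",
            "CompressionType": "None",
        }
    }


# Placeholder bucket name, for illustration only.
print(uncompressed_model_data("s3://my-bucket/large-model-lmi/code-files"))
```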

In [None]:
from sagemaker.utils import name_from_base
from sagemaker.model import Model

model_name = name_from_base(base_name, short=True)

# Create model
black_forest_model = Model(
    name=model_name,
    image_uri=image_uri,
    model_data=config_data,  # S3 location of the uncompressed configuration files
    role=role,
    env={
        'HF_TASK':'text-to-image',
    },
    sagemaker_session=sm_session
)

## Deploy Model to SageMaker Endpoint

Now we'll deploy our model to a SageMaker endpoint for real-time inference. 
> ⚠️ **Important**: 
> - Deployment can take up to 15 minutes
> - Monitor the CloudWatch logs for progress
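
`Model.deploy` waits for the endpoint by default, but if you deploy with `wait=False` or check from another process, you can poll the endpoint status yourself. The sketch below takes the describe call as a parameter so it can be demonstrated without touching AWS; `wait_for_endpoint` and `fake_describe` are hypothetical names, and the polling interval and timeout are arbitrary assumptions.

```python
import time


def wait_for_endpoint(describe_fn, endpoint_name, poll_seconds=30, timeout_seconds=1800):
    """Poll until the endpoint is InService, has failed, or the timeout elapses.

    describe_fn is expected to behave like the boto3 SageMaker client's
    describe_endpoint method and return a dict with an 'EndpointStatus' key.
    """
    deadline = time.monotonic() + timeout_seconds
    while True:
        status = describe_fn(EndpointName=endpoint_name)["EndpointStatus"]
        if status == "InService":
            return status
        if status == "Failed":
            raise RuntimeError(f"Endpoint {endpoint_name} failed to deploy")
        if time.monotonic() >= deadline:
            raise TimeoutError(f"Endpoint {endpoint_name} still {status} after timeout")
        time.sleep(poll_seconds)


# Demonstration with a stand-in for describe_endpoint (no AWS call is made).
_states = iter(["Creating", "Creating", "InService"])


def fake_describe(EndpointName):
    return {"EndpointStatus": next(_states)}


print(wait_for_endpoint(fake_describe, "demo-endpoint", poll_seconds=0))
```

Against a real endpoint, you would pass `boto3.client("sagemaker").describe_endpoint` as `describe_fn`. Container logs for an endpoint appear in the CloudWatch log group `/aws/sagemaker/Endpoints/<endpoint-name>`.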

In [None]:
%%time

from sagemaker.serializers import JSONSerializer, IdentitySerializer
from sagemaker.deserializers import JSONDeserializer
from sagemaker.utils import name_from_base


endpoint_name = name_from_base(base_name, short=True)
instance_type = gpu_instance_type

black_forest_model.deploy(
    endpoint_name=endpoint_name,
    initial_instance_count=1,
    instance_type=instance_type,
    serializer=JSONSerializer(),
    deserializer=JSONDeserializer()
)

### Create a predictor from our existing endpoint and run inference

In [None]:
%%time

from sagemaker.serializers import JSONSerializer
from sagemaker.deserializers import JSONDeserializer
from sagemaker.predictor import Predictor

# `endpoint_name` was set in the deployment cell above. To attach to a
# different, already-running endpoint, assign its name here instead.

predictor = Predictor(
    endpoint_name=endpoint_name,
    serializer=JSONSerializer(),
    deserializer=JSONDeserializer(),
    sagemaker_session=sm_session
)

In [None]:
import json

# Make a prediction
payload = {
    "prompt": "whimsical and ethereal soft-shaded story illustration: A woman in a large hat stands at the ship's railing looking out across the ocean",
    "guidance_scale": 0.0,
    "num_inference_steps": 4,
    "max_sequence_length": 256,
    "seed": 42
}

response = predictor.predict(payload)

In [None]:
# If you want to convert the base64 image back to a PIL Image:
import base64
from PIL import Image
import io

# Extract the base64-encoded image data from the response
base64_image = response['generated_image']

# Decode the base64 string to bytes
image_bytes = base64.b64decode(base64_image)

# Create a PIL Image from the bytes
generated_image = Image.open(io.BytesIO(image_bytes))

# Display the image (this will open in your default image viewer)
generated_image.show()

# Save the image to a file
generated_image.save('generated_image.png')
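
The handler in `model.py` catches exceptions and returns a JSON body of the form `{"error": "..."}` instead of raising, so a failed request would surface above as a `KeyError` on `generated_image`. A minimal sketch of a more explicit check follows; `image_bytes_from_response` is a hypothetical helper, and the demonstration uses a stand-in response rather than calling the endpoint.

```python
import base64


def image_bytes_from_response(response):
    """Return the decoded image bytes, raising if the handler reported an error."""
    if "error" in response:
        raise RuntimeError(f"Inference failed: {response['error']}")
    return base64.b64decode(response["generated_image"])


# Demonstration with a stand-in response (no endpoint call is made).
ok_response = {"generated_image": base64.b64encode(b"\x89PNG").decode()}
print(len(image_bytes_from_response(ok_response)))
```

With a real response, the returned bytes can be passed to `Image.open(io.BytesIO(...))` exactly as in the cell above.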

### Inference using boto3

In [None]:
import boto3
import json
import base64
from PIL import Image
import io

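# `endpoint_name` was set in the deployment cell; override it here to target another endpoint.
prompt = "A cat holding a sign that says hello world"

def test_endpoint(endpoint_name, prompt):
    # Use the region resolved at the top of the notebook rather than a hard-coded one
    runtime = boto3.client('sagemaker-runtime', region_name=boto_region)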
    
    # Prepare the payload
    payload = {
        "prompt": prompt,
        "guidance_scale": 0.0,
        "num_inference_steps": 4,
        "max_sequence_length": 256,
        "seed": 42
    }
    
    try:
        response = runtime.invoke_endpoint(
            EndpointName=endpoint_name,
            ContentType='application/json',
            Body=json.dumps(payload)
        )
        
        result = json.loads(response['Body'].read())
        print("Success!")
        return result
    except Exception as e:
        print(f"Error: {str(e)}")
        raise


In [None]:
# Test the endpoint
response = test_endpoint(endpoint_name, prompt)

In [None]:
import base64
from PIL import Image
from IPython.display import display
import io

# Extract the base64-encoded image data from the response
base64_image = response['generated_image']

# Decode the base64 string to bytes
image_bytes = base64.b64decode(base64_image)

# Create a PIL Image from the bytes
generated_image = Image.open(io.BytesIO(image_bytes))

# Resize in place for inline display; thumbnail() preserves the aspect ratio
max_size = (800, 800)
generated_image.thumbnail(max_size, Image.Resampling.LANCZOS)
display(generated_image)

# Save the resized image to a file
generated_image.save('generated_image.png')

In [None]:
# Clean up: delete the model, then the endpoint and its endpoint configuration
predictor.delete_model()
predictor.delete_endpoint(delete_endpoint_config=True)