# FLUX.1 with LoRA on Amazon SageMaker 

Welcome to this Amazon SageMaker guide on how to deploy the [FLUX.1 model](https://huggingface.co/black-forest-labs/FLUX.1-dev) with multiple LoRA adapters to Amazon SageMaker. We will deploy FLUX.1 with several LoRA adapters that can be selected per request to Amazon SageMaker for real-time inference using Hugging Face's [🧨 Diffusers library](https://huggingface.co/docs/diffusers/index).

![stable-diffusion-on-amazon-sagemaker](./imgs/sd-on-sm.png)

What we are going to do:

1. Create FLUX.1 inference script 
2. Deploy FLUX.1 with LoRA Adapters to Amazon SageMaker
3. Generate images using FLUX.1 schnell


## What is FLUX.1?

FLUX.1 is a family of open-weights, 12B-parameter rectified flow transformers created by Black Forest Labs that generate images from text descriptions and push the boundaries of text-to-image generation. It comes in three variants, each catering to different use cases:

* FLUX.1 [pro]: Offers top-tier performance for commercial applications with high visual quality and prompt adherence, available via API and enterprise solutions.
* FLUX.1 [dev]: Open-weight model for non-commercial use, providing similar capabilities to FLUX.1 [pro], available on platforms like Hugging Face.
* FLUX.1 [schnell]: Fastest model, designed for local development and personal use, available under the Apache 2.0 license with open inference code on GitHub.


![schnell_grid](./imgs/schnell_grid.jpeg)

--- 

Before we can get started, make sure you have a [Hugging Face user account](https://huggingface.co/join). The account is needed to load [black-forest-labs/FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev) from the [Hugging Face Hub](https://huggingface.co/).


In [None]:
!pip install "sagemaker==2.231.0" "huggingface_hub==0.24.6" --upgrade --quiet

If you are going to use SageMaker in a local environment, you need access to an IAM role with the required permissions for SageMaker. You can find out more about it [here](https://docs.aws.amazon.com/sagemaker/latest/dg/sagemaker-roles.html).

In [None]:
import sagemaker
import boto3
sess = sagemaker.Session()
# sagemaker session bucket -> used for uploading data, models and logs
# sagemaker will automatically create this bucket if it does not exist
sagemaker_session_bucket=None
if sagemaker_session_bucket is None and sess is not None:
    # set to default bucket if a bucket name is not given
    sagemaker_session_bucket = sess.default_bucket()

try:
    role = sagemaker.get_execution_role()
except ValueError:
    iam = boto3.client('iam')
    role = iam.get_role(RoleName='sagemaker_execution_role')['Role']['Arn']

sess = sagemaker.Session(default_bucket=sagemaker_session_bucket)

print(f"sagemaker role arn: {role}")
print(f"sagemaker bucket: {sess.default_bucket()}")
print(f"sagemaker session region: {sess.boto_region_name}")

## Create Multi-LoRA FLUX.1 inference script 

Amazon SageMaker allows us to customize the inference script by providing an `inference.py` file. The `inference.py` file is the entry point to our model. It is responsible for loading the model and handling the inference request. If you are used to deploying Hugging Face Transformers, that might be new to you. Usually, we just provide the `HF_MODEL_ID` and `HF_TASK` and the Hugging Face DLC takes care of the rest.

For multi-LoRA FLUX.1 we need to provide the LoRA Adapters and the base FLUX.1 model. We will use the `diffusers` library to load the model and handle the inference request. We create a custom `model_fn` and `predict_fn` to load the model and handle the inference request.
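
To make the division of labour between the two functions concrete, here is a minimal sketch of the lifecycle: `model_fn` runs once when the endpoint starts and `predict_fn` runs for every request. The dictionary returned by `model_fn` is only a stand-in for a real pipeline, and the `/opt/ml/model` path is used purely for illustration.

```python
def model_fn(model_dir):
    """Called once at startup; returns the object that will serve requests."""
    return {"loaded_from": model_dir}  # stand-in for a real diffusers pipeline


def predict_fn(data, model):
    """Called per request with the deserialised JSON body and the loaded model."""
    prompt = data["inputs"]
    return {"echo": prompt, "model": model["loaded_from"]}


# simulate the lifecycle locally: load once, then serve two requests
model = model_fn("/opt/ml/model")
responses = [predict_fn({"inputs": p}, model) for p in ("a cat", "a dog")]
print(responses)
```

The real script below follows the same shape, with the stand-in replaced by the FLUX.1 pipeline.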


If you want to learn more about creating a custom inference script, you can check out [Creating document embeddings with Hugging Face's Transformers & Amazon SageMaker](https://www.philschmid.de/custom-inference-huggingface-sagemaker).

In addition to the `inference.py` file we also have to provide a `requirements.txt` file. The `requirements.txt` file is used to install the dependencies for our `inference.py` file.

The first step is to create a `code/` directory.

In [None]:
!mkdir code

Next, we create a `requirements.txt` file and add a pinned version of the `diffusers` library to it.

In [None]:
%%writefile code/requirements.txt
diffusers==0.30.2

Next we need to create the `inference.py` file. The `inference.py` file is responsible for loading the model and handling the inference request. The `model_fn` function is called when the model is loaded. The `predict_fn` function is called when we want to do inference. 

We need to update the `model_fn` so that it uses `HF_MODEL_ID` for the base FLUX.1 model and `HF_ADAPTER_IDS` for the LoRA adapters.

_Note: `HF_ADAPTER_IDS` is a JSON object that maps each LoRA adapter ID (the name of its repository on the Hugging Face Hub) to the weight file to load from that repository. The variable is only read by our custom inference script; it has no effect when you deploy the model without one._
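
As a rough sketch of how that value could be produced, the snippet below serialises a mapping of adapter repository to weight file name with `json.dumps` and parses it back the same way `inference.py` does. The helper name `build_adapter_env` is hypothetical; the repository and file name are the ones used as the example in this guide.

```python
import json


def build_adapter_env(adapters: dict) -> str:
    """Serialise {adapter repo id: weight file name} for HF_ADAPTER_IDS (hypothetical helper)."""
    for repo_id, weight_name in adapters.items():
        if not isinstance(repo_id, str) or not isinstance(weight_name, str):
            raise TypeError("adapter ids and weight names must be strings")
    return json.dumps(adapters)


adapter_env = build_adapter_env(
    {"ostris/yearbook-photo-flux-schnell": "yearbook-photo-flux-schnell-v1.safetensors"}
)

# inference.py reads the variable back with json.loads
parsed = json.loads(adapter_env)
print(parsed["ostris/yearbook-photo-flux-schnell"])
```

Building the string with `json.dumps` avoids the hand-escaped quotes that are easy to get wrong when the value is typed directly.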

In the `predict_fn` we first check which adapter ID is requested and then generate `4` images for the input prompt by default. The `predict_fn` function returns the images as a list of `base64` encoded strings.

In [2]:
# %%writefile code/inference.py
import base64
import torch
import os 
import json
from io import BytesIO
from diffusers import AutoPipelineForText2Image

# HF_ADAPTER_IDS="{\"ostris/yearbook-photo-flux-schnell\": \"yearbook-photo-flux-schnell-v1.safetensors\"}" python inference.py

# ADAPTER_IDS needs to be a JSON object with the adapter id as key and the adapter weight name as value
# e.g. {"ostris/yearbook-photo-flux-schnell": "yearbook-photo-flux-schnell-v1.safetensors"}
ADAPTERS = json.loads(os.getenv("HF_ADAPTER_IDS", "{}"))

MODEL_ID = os.getenv("HF_MODEL_ID", "black-forest-labs/FLUX.1-schnell")


def model_fn(model_dir):
    """Load the model from Hugging Face and apply the LoRA weights if provided."""
    pipeline = AutoPipelineForText2Image.from_pretrained(MODEL_ID, torch_dtype=torch.bfloat16, device_map="balanced")
    if len(ADAPTERS.keys()) > 0:
        for adapter_id in ADAPTERS.keys():
            print(f"Loading adapter: {adapter_id}")
            pipeline.load_lora_weights(adapter_id, weight_name=ADAPTERS[adapter_id], adapter_name=adapter_id)

    return pipeline


def predict_fn(data, pipe):
    """Run the model with the provided data and return the generated images."""
    # get prompt & parameters
    prompt = data.pop("inputs", data)
    
    # check if adapter id is provided
    adapter_id = data.pop("adapter_id", None)
    # if a known adapter id is provided, activate it
    if ADAPTERS.get(adapter_id, None) is not None:
        print(f"Using adapter: {adapter_id}")
        pipe.enable_lora()  # re-enable LoRA layers in case an earlier request disabled them
        pipe.set_adapters(adapter_id)
    else:
        print("No valid adapter id provided, using base model")
        pipe.disable_lora()

    # set valid hyperparameters for FLUX.1
    num_inference_steps = data.pop("num_inference_steps", 4)  # schnell only needs 4 steps, dev needs roughly 20-30
    guidance_scale = data.pop("guidance_scale", 0.0)  # must be 0.0 for schnell, dev can use 3.5
    num_images_per_prompt = data.pop("num_images_per_prompt", 4)

    # run generation with parameters
    generated_images = pipe(
        prompt,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        num_images_per_prompt=num_images_per_prompt,
        max_sequence_length=256,
        generator=torch.Generator("cpu").manual_seed(0)
    )["images"]

    # encode each image as a base64 JPEG string
    encoded_images = []
    for image in generated_images:
        buffered = BytesIO()
        image.save(buffered, format="JPEG")
        encoded_images.append(base64.b64encode(buffered.getvalue()).decode())

    # create response
    return {"generated_images": encoded_images}


model = model_fn(".")
# payload = {"inputs": "Headshot of handsome young man, wearing dark gray sweater with buttons and big shawl collar, brown hair and short beard, soft studio lighting, portrait photography --ar 85:128 --v 6.0 --style raw", "adapter_id": "ostris/yearbook-photo-flux-schnell"}
# result = predict_fn(payload, model)

# for i, image in enumerate(result["generated_images"]):
#     # save image to file
#     with open(f"image_{i}.jpg", "wb") as f:
#         f.write(base64.b64decode(image))


Some parameters are on the meta device device because they were offloaded to the cpu.

In [3]:
model.hf_device_map

{'transformer': 'cpu', 'text_encoder_2': 3, 'text_encoder': 2, 'vae': 1}
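
Read as a mapping from pipeline component to device, this output shows that the transformer, the largest component, ended up on the CPU, while the text encoders and the VAE were placed on GPUs. A small sketch, using the dictionary printed above as a literal, that lists which components were offloaded:

```python
# device map printed above: component name -> device (GPU index or "cpu")
device_map = {"transformer": "cpu", "text_encoder_2": 3, "text_encoder": 2, "vae": 1}

offloaded = [name for name, device in device_map.items() if device == "cpu"]
on_gpu = {name: device for name, device in device_map.items() if device != "cpu"}

print("offloaded to CPU:", offloaded)
print("placed on GPUs:", on_gpu)
```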

## Deploy FLUX.1 with LoRA Adapters to Amazon SageMaker

In this example we will deploy [black-forest-labs/FLUX.1-schnell](https://huggingface.co/black-forest-labs/FLUX.1-schnell). FLUX.1-schnell is a fast version of FLUX.1 available under the Apache 2.0 license. It only needs 4 inference steps to generate a high-quality image, which makes it a good fit for real-time inference and local development.

If you plan to use FLUX.1 dev, you need to switch to an asynchronous inference endpoint or use a bigger instance type (A100/H100). Why? Amazon SageMaker real-time endpoints have a request timeout of 60s, and depending on the instance type the generation might take longer. In our example we will use an `ml.g5.2xlarge` instance with 1 NVIDIA A10G GPU with 24GB of memory. This is not enough memory to hold all of FLUX.1, so parts of it will be loaded on the CPU. Offloading to the CPU slows down generation.
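
A back-of-the-envelope estimate makes the memory pressure concrete. The sketch below only counts the 12B transformer weights in bfloat16 (2 bytes per parameter) and ignores the text encoders, the VAE and activations, so the real footprint is higher.

```python
PARAMS = 12e9          # 12B parameters, as stated for the FLUX.1 family
BYTES_PER_PARAM = 2    # bfloat16 stores each weight in 2 bytes
GPU_MEMORY_GB = 24     # one NVIDIA A10G on ml.g5.2xlarge

weights_gb = PARAMS * BYTES_PER_PARAM / 1e9
fits_on_gpu = weights_gb < GPU_MEMORY_GB

print(f"transformer weights: ~{weights_gb:.0f} GB, fits on one A10G: {fits_on_gpu}")
```

The transformer weights alone already use up the whole 24GB budget, which is why part of the pipeline ends up on the CPU.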

We will also use 2 LoRA adapters: [prithivMLmods/Canopus-LoRA-Flux-FaceRealism](https://huggingface.co/prithivMLmods/Canopus-LoRA-Flux-FaceRealism) and [ostris/yearbook-photo-flux-schnell](https://huggingface.co/ostris/yearbook-photo-flux-schnell).

You can find more LoRA adapters for FLUX.1 schnell [here](https://huggingface.co/models?other=base_model:adapter:black-forest-labs/FLUX.1-schnell), or you can train your own using the [AI Toolkit by Ostris](https://github.com/ostris/ai-toolkit).

In [None]:
import base64
import torch
import os 
os.environ["HF_ADAPTER_IDS"] = "{\"ostris/yearbook-photo-flux-schnell\": \"yearbook-photo-flux-schnell-v1.safetensors\"}"
import json
from io import BytesIO
from diffusers import AutoPipelineForText2Image

# HF_ADAPTER_IDS="{\"ostris/yearbook-photo-flux-schnell\": \"yearbook-photo-flux-schnell-v1.safetensors\"}" python inference.py

# ADAPTER_IDS needs to be a JSON object with the adapter id as key and the adapter weight name as value
# e.g. {"ostris/yearbook-photo-flux-schnell": "yearbook-photo-flux-schnell-v1.safetensors"}
ADAPTERS = json.loads(os.getenv("HF_ADAPTER_IDS", "{}"))

MODEL_ID = os.getenv("HF_MODEL_ID", "black-forest-labs/FLUX.1-schnell")


def model_fn(model_dir):
    """Load the model from Hugging Face and apply the LoRA weights if provided."""
    pipeline = AutoPipelineForText2Image.from_pretrained(MODEL_ID, torch_dtype=torch.bfloat16, device_map="balanced")
    if len(ADAPTERS.keys()) > 0:
        for adapter_id in ADAPTERS.keys():
            print(f"Loading adapter: {adapter_id}")
            pipeline.load_lora_weights(adapter_id, weight_name=ADAPTERS[adapter_id], adapter_name=adapter_id)

    return pipeline


def predict_fn(data, pipe):
    """Run the model with the provided data and return the generated images."""
    # get prompt & parameters
    prompt = data.pop("inputs", data)
    
    # check if adapter id is provided
    adapter_id = data.pop("adapter_id", None)
    # if a known adapter id is provided, activate it
    if ADAPTERS.get(adapter_id, None) is not None:
        print(f"Using adapter: {adapter_id}")
        pipe.enable_lora()  # re-enable LoRA layers in case an earlier request disabled them
        pipe.set_adapters(adapter_id)
    else:
        print("No valid adapter id provided, using base model")
        pipe.disable_lora()

    # set valid hyperparameters for FLUX.1
    num_inference_steps = data.pop("num_inference_steps", 4)  # schnell only needs 4 steps, dev needs roughly 20-30
    guidance_scale = data.pop("guidance_scale", 0.0)  # must be 0.0 for schnell, dev can use 3.5
    num_images_per_prompt = data.pop("num_images_per_prompt", 4)

    # run generation with parameters
    generated_images = pipe(
        prompt,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        num_images_per_prompt=num_images_per_prompt,
        max_sequence_length=256,
        generator=torch.Generator("cpu").manual_seed(0)
    )["images"]

    # encode each image as a base64 JPEG string
    encoded_images = []
    for image in generated_images:
        buffered = BytesIO()
        image.save(buffered, format="JPEG")
        encoded_images.append(base64.b64encode(buffered.getvalue()).decode())

    # create response
    return {"generated_images": encoded_images}


model = model_fn(".")
payload = {"inputs": "Headshot of handsome young man, wearing dark gray sweater with buttons and big shawl collar, brown hair and short beard, soft studio lighting, portrait photography --ar 85:128 --v 6.0 --style raw", "adapter_id": "ostris/yearbook-photo-flux-schnell"}
result = predict_fn(payload, model)

for i, image in enumerate(result["generated_images"]):
    # save image to file
    with open(f"image_{i}.jpg", "wb") as f:
        f.write(base64.b64decode(image))
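
The deployment cell below expects an `s3_model_uri` pointing to a `model.tar.gz` that contains `code/inference.py` and `code/requirements.txt`. One possible way to produce it is sketched here, assuming `code/inference.py` has been written to disk (for instance by enabling the `%%writefile` line in the inference cell). The helper names `package_code` and `upload_archive` and the `flux-lora` S3 prefix are hypothetical; `S3Uploader.upload` comes from the SageMaker Python SDK.

```python
import tarfile


def package_code(code_dir: str = "code", archive: str = "model.tar.gz") -> str:
    """Bundle the code/ directory into a model.tar.gz (hypothetical helper)."""
    with tarfile.open(archive, "w:gz") as tar:
        # keep a top-level "code" folder so the archive contains code/inference.py
        tar.add(code_dir, arcname="code")
    return archive


def upload_archive(archive: str, bucket: str, prefix: str = "flux-lora") -> str:
    """Upload the archive to S3 and return its URI (the prefix is an assumption)."""
    from sagemaker.s3 import S3Uploader  # imported lazily; needs AWS credentials

    return S3Uploader.upload(local_path=archive, desired_s3_uri=f"s3://{bucket}/{prefix}")


# usage inside the notebook, once code/inference.py exists:
# s3_model_uri = upload_archive(package_code(), sess.default_bucket())
```

Because the base model and the adapters are downloaded from the Hugging Face Hub at startup, the archive in this sketch only carries the code and stays small.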


In [None]:
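import json
from sagemaker.huggingface.model import HuggingFaceModel

# environment variables read by inference.py
env = {
    "HF_MODEL_ID": "black-forest-labs/FLUX.1-schnell",
    # JSON object mapping adapter id -> weight file name; further adapters are added the same way
    "HF_ADAPTER_IDS": json.dumps(
        {"ostris/yearbook-photo-flux-schnell": "yearbook-photo-flux-schnell-v1.safetensors"}
    ),
}

# create Hugging Face Model Class
huggingface_model = HuggingFaceModel(
    model_data=s3_model_uri,        # S3 path to the model.tar.gz with code/inference.py
    role=role,                      # iam role with permissions to create an Endpoint
    env=env,                        # base model and LoRA adapters to load
    transformers_version="4.37.0",  # assumed DLC versions: FLUX.1 needs PyTorch 2.x
    pytorch_version="2.1.0",        # adjust to a combination supported by your SDK release
    py_version="py310",             # python version used
)

# deploy the endpoint
predictor = huggingface_model.deploy(
    initial_instance_count=1,
    instance_type="ml.g5.2xlarge",  # 1x NVIDIA A10G, as described above
)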

## Generate images using the deployed model

The `.deploy()` method returns a `HuggingFacePredictor` object which can be used to request inference. Our endpoint expects a `json` with at least an `inputs` key. The `inputs` key is the input prompt for the model, which will be used to generate the image. Additionally, we can provide `adapter_id` to select a LoRA adapter, and `num_inference_steps`, `guidance_scale` & `num_images_per_prompt` to control the generation.
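
A request body following that description could be assembled as in the sketch below. The helper `build_payload` is hypothetical and not part of the SageMaker SDK; it only checks that the optional keys are ones our `predict_fn` reads. The prompt is a made-up example.

```python
ALLOWED_PARAMS = {"num_inference_steps", "guidance_scale", "num_images_per_prompt"}


def build_payload(prompt, adapter_id=None, **params):
    """Assemble a request body for the endpoint (hypothetical helper)."""
    unknown = set(params) - ALLOWED_PARAMS
    if unknown:
        raise ValueError(f"unsupported parameters: {sorted(unknown)}")
    payload = {"inputs": prompt, **params}
    if adapter_id is not None:
        payload["adapter_id"] = adapter_id
    return payload


payload = build_payload(
    "portrait photo of a student",
    adapter_id="ostris/yearbook-photo-flux-schnell",
    num_inference_steps=4,
    num_images_per_prompt=2,
)
print(payload)
```

The resulting dictionary can be passed as `data` to `predictor.predict()`; leaving out `adapter_id` falls back to the base model.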

The `predictor.predict()` function returns a `json` with the `generated_images` key, which contains the generated images (`4` by default) as `base64` encoded strings. To decode the response we add two small helper functions: `decode_base64_image`, which takes a `base64` encoded string and returns a `PIL.Image` object, and `display_images`, which takes a list of `PIL.Image` objects and displays them.

In [None]:
from PIL import Image
from io import BytesIO
from IPython.display import display
import base64
import matplotlib.pyplot as plt

# helper decoder
def decode_base64_image(image_string):
  base64_image = base64.b64decode(image_string)
  buffer = BytesIO(base64_image)
  return Image.open(buffer)

# display PIL images as grid
def display_images(images=None,columns=3, width=100, height=100):
    plt.figure(figsize=(width, height))
    for i, image in enumerate(images):
        plt.subplot(int(len(images) / columns + 1), columns, i + 1)
        plt.axis('off')
        plt.imshow(image)


Now, let's generate some images. As an example, let's generate `3` images for the prompt `A dog trying catch a flying pizza art drawn by disney concept artists`. Generating `3` images takes around `30` seconds.

In [None]:
num_images_per_prompt = 3
prompt = "A dog trying catch a flying pizza art drawn by disney concept artists, golden colour, high quality, highly detailed, elegant, sharp focus"

# run prediction
response = predictor.predict(data={
  "inputs": prompt,
  "num_images_per_prompt" : num_images_per_prompt
  }
)

# decode images
decoded_images = [decode_base64_image(image) for image in response["generated_images"]]

# visualize generation
display_images(decoded_images)

### Delete model and endpoint

To clean up, we can delete the model and endpoint.

In [None]:
predictor.delete_model()
predictor.delete_endpoint()