# Deploying a Pruna-optimised version of Flux Kontext Dev on Amazon SageMaker

This sample shows how to optimise Flux Kontext Dev from Black Forest Labs (available on <a href="https://huggingface.co/black-forest-labs/FLUX.1-Kontext-dev">HuggingFace</a>) with the <a href="https://github.com/PrunaAI">Pruna open-source library</a> and deploy it using <a href="https://aws.amazon.com/sagemaker/">Amazon SageMaker</a>.

NOTE: Use of Flux Kontext Dev is subject to the license available <a href="https://github.com/black-forest-labs/flux/blob/main/model_licenses/LICENSE-FLUX1-dev">here</a>.

The optimisation implemented in this sample gives roughly a 3x speed-up (63 s → 21 s per generation) with similar output quality.

## Requirements and instructions

You can run this notebook locally (e.g. on a laptop), because model preparation and deployment both run on Amazon SageMaker infrastructure. You can also run it from Amazon SageMaker AI Studio (e.g. in Code Editor).

Python version: 3.10.18.
Torch version: at the time of writing, no PyTorch Deep Learning Container is available for PyTorch 2.7.0. PyTorch is therefore upgraded from 2.6 to 2.7.0 both in the training job and in the endpoint deployment.


In [None]:
%pip install torch==2.7.0 --extra-index-url https://download.pytorch.org/whl/cu128
%pip install torchaudio==2.7.0
%pip install torchvision==0.22.0
%pip install diffusers==0.35.1
%pip install transformers==4.51.0
%pip install huggingface_hub>=0.34.4
%pip install boto3
%pip install sagemaker
%pip install sagemaker-huggingface-inference-toolkit>=2.6.0
%pip install peft>=0.17.0
%pip install accelerate
%pip install sentencepiece
%pip install pillow
%pip install protobuf
%pip install matplotlib
%pip install pruna
%pip install numpy

In [None]:
import os
import base64
import tarfile
from pathlib import Path
import shutil
import tarfile
from pathlib import Path
from io import BytesIO
import matplotlib.pyplot as plt
from IPython.display import display
from PIL import Image

import torch
from diffusers import FluxKontextPipeline
from diffusers.utils import load_image
from huggingface_hub import login

import boto3
import sagemaker
from sagemaker.remote_function import remote
from sagemaker.s3 import S3Uploader
from sagemaker.s3 import S3Downloader
from sagemaker.huggingface.model import HuggingFaceModel

from pruna import SmashConfig, smash


In [None]:
region = "us-east-1"
# Read the Hugging Face token from the environment so it does not have to be
# hard-coded (and accidentally committed) in the notebook. Falls back to the
# placeholder when HF_TOKEN is not set.
HF_TOKEN = os.environ.get("HF_TOKEN", "your_token_here")
# S3 prefix (inside the default bucket) where code.tar.gz is uploaded
model_subfolder = "flux-smashed"

sess = sagemaker.Session(boto_session=boto3.session.Session(region_name=region))
sagemaker_session_bucket = sess.default_bucket()

try:
    role = sagemaker.get_execution_role()
except ValueError:
    # If your current identity is not a role, hardcode the SageMaker execution role ARN here
    role = "arn_of_your_sagemaker_execution_role"

print(f"sagemaker role arn: {role}")
print(f"sagemaker bucket: {sagemaker_session_bucket}")
print(f"sagemaker session region: {sess.boto_region_name}")

## Packaging local code to S3

The inference code and its requirements file are in the `code` folder of this repository. The cell below packages these two files into `code.tar.gz` and uploads it to S3. The training job downloads it during model preparation.

In [None]:
def compress(tar_dir=None, output_file="code.tar.gz"):
    """Package the *contents* of ``tar_dir`` into a gzipped tarball.

    Each entry directly inside ``tar_dir`` is stored at the archive root, so
    extracting the archive into ``/some/dir`` yields ``/some/dir/<entry>``
    (directories are added recursively).

    Args:
        tar_dir: Directory whose contents are archived. Defaults to the
            current working directory when ``None``.
        output_file: Archive path, resolved relative to the current working
            directory. If it lies inside ``tar_dir``, it is not added to itself.

    Returns:
        The path of the written archive as a ``str``.
    """
    source_dir = Path(tar_dir) if tar_dir is not None else Path.cwd()
    output_path = Path.cwd() / output_file

    with tarfile.open(output_path, "w:gz") as tar:
        # The archive file exists once opened, so resolve() below is safe.
        archive = output_path.resolve()
        for item in sorted(source_dir.iterdir()):
            if item.resolve() == archive:
                continue  # never add the archive being written to itself
            tar.add(item, arcname=item.name)

    return str(output_path)

# Build the archive from the local ./code folder (inference code + requirements)
code_archive = "code.tar.gz"
compress(str(Path("code")), output_file=code_archive)

# Upload it under the notebook's S3 prefix; the training job downloads it from there
s3_model_uri = S3Uploader.upload(
    local_path=code_archive,
    desired_s3_uri=f"s3://{sess.default_bucket()}/{model_subfolder}",
)

## Smashing the model in a SageMaker training job

In [None]:
settings = dict(
    sagemaker_session=sess,
    role=role,
    instance_type="ml.g6e.xlarge",
    volume_size=250,
    dependencies='./code/requirements.txt'
)

@remote(**settings)
def smash_model_hf(model_id):
    """Optimise ("smash") a Flux Kontext pipeline with Pruna in a SageMaker training job.

    Runs remotely through ``@remote(**settings)``. On the training instance it:
    downloads ``code.tar.gz`` from ``s3://<default bucket>/<model_subfolder>/``,
    copies the extracted inference code to ``/opt/ml/model/code``, loads and
    smashes the pipeline, runs one warm-up inference through the smashed
    pipeline, and saves it to ``/opt/ml/model/flux_smashed``.

    Args:
        model_id: Hugging Face Hub id of the pipeline to load.

    Returns:
        The ``TRAINING_JOB_ARN`` environment variable on the training instance,
        or ``None`` if it is not set.
    """
    # Hugging Face libraries store model files here, needed due to the size of the model.
    import os
    cache_dir = '/tmp/cache/'
    os.environ['HF_HOME'] = cache_dir

    # New SageMaker session on the training instance
    remote_session = sagemaker.Session(boto_session=boto3.session.Session(region_name=region))

    # Download + unpack the inference code.
    # This is quicker than using source_dir, which repackages the whole model.
    S3Downloader.download(
        s3_uri=f"s3://{remote_session.default_bucket()}/{model_subfolder}/code.tar.gz",
        local_path="/tmp",
        sagemaker_session=remote_session,
    )
    with tarfile.open("/tmp/code.tar.gz", "r:gz") as tar:
        # The archive is produced by this notebook, so it is trusted input.
        tar.extractall(path="/tmp/code")

    # Copy the code into the SageMaker model output folder
    shutil.copytree("/tmp/code", "/opt/ml/model/code", dirs_exist_ok=True)
    login(token=HF_TOKEN)

    # NOTE(review): HF_HOME is typically read when huggingface_hub/diffusers are
    # imported, which likely happens before this body runs on the remote side.
    # Passing cache_dir explicitly guarantees the weights are cached under cache_dir.
    pipe = FluxKontextPipeline.from_pretrained(
        model_id, torch_dtype=torch.bfloat16, cache_dir=cache_dir
    )
    pipe.to("cuda")

    # Pruna configuration
    smash_config = SmashConfig(device='cuda')
    smash_config["cacher"] = "fora"
    smash_config["fora_interval"] = 2  # other candidate values: 3, 4
    smash_config["compiler"] = "torch_compile"
    smash_config["torch_compile_mode"] = "max-autotune-no-cudagraphs"
    smash_config["quantizer"] = "torchao"
    smash_config["torchao_quant_type"] = "int8dq"
    smash_config["torchao_excluded_modules"] = "norm+embedding"
    smash_config["torch_compile_make_portable"] = True

    smashed_pipe = smash(model=pipe, smash_config=smash_config)

    # Warm-up: run one inference through the *smashed* pipeline (the model that
    # is saved below) so the compilation graph is built before saving.
    prompt = "Add a fun hat to the dog on the right and a top hat to the dog on the left"
    input_image = load_image("https://images.pexels.com/photos/1108099/pexels-photo-1108099.jpeg")
    smashed_pipe(
        image=input_image,
        prompt=prompt,
        guidance_scale=2.5,
        num_inference_steps=50,
    )

    smashed_pipe.save_pretrained("/opt/ml/model/flux_smashed")

    return os.environ.get('TRAINING_JOB_ARN')


# Starting the training job

In [None]:
# Launch the remote smashing job; expect it to take roughly 30-90 minutes.
flux_model_id = "black-forest-labs/FLUX.1-Kontext-dev"
training_job_arn = smash_model_hf(flux_model_id)
training_job_arn


In [None]:
# The training job name is the last "/"-separated segment of its ARN
training_job_name = training_job_arn.rpartition("/")[2]

# Look up the S3 URI of the model artifacts produced by that job
job_description = sess.describe_training_job(training_job_name)
model_artifacts = job_description["ModelArtifacts"]
model_data = model_artifacts["S3ModelArtifacts"]
model_data

## Deploy the model to a SageMaker real-time inference endpoint

In [None]:
# The model server request timeout and the endpoint's artifact-download and
# startup health-check timeouts all use the same generous value (seconds).
server_timeout_seconds = 1200

env = {
    "SAGEMAKER_MODEL_SERVER_TIMEOUT": str(server_timeout_seconds),
}

huggingface_model = HuggingFaceModel(
    env=env,
    model_data=model_data,            # S3 URI of the training job's model artifacts
    role=role,
    transformers_version="4.49.0",
    pytorch_version="2.6.0",          # container PyTorch; updated to the required version at runtime
    py_version="py312",               # Python version of the serving container
)

predictor = huggingface_model.deploy(
    initial_instance_count=1,
    instance_type="ml.g6e.xlarge",
    model_data_download_timeout=server_timeout_seconds,
    container_startup_health_check_timeout=server_timeout_seconds,
)

# Testing inference

In [None]:
# helper functions for base64 and images

def decode_base64_image(image_string):
    """Decode a base64-encoded image string into a PIL image (opened from memory)."""
    raw_bytes = base64.b64decode(image_string)
    return Image.open(BytesIO(raw_bytes))

def display_images(images=None, columns=3, width=100, height=100):
    """Plot ``images`` in a grid with ``columns`` images per row.

    Args:
        images: Sequence of images accepted by ``plt.imshow``. ``None`` or an
            empty sequence plots nothing.
        columns: Number of grid columns.
        width: Figure width in inches, passed to ``plt.figure``.
        height: Figure height in inches, passed to ``plt.figure``.
    """
    # len() rather than truthiness so a numpy array of images also works.
    if images is None or len(images) == 0:
        return

    # Ceiling division: just enough rows for every image. The previous
    # int(n / columns + 1) added an empty row whenever n was a multiple of columns.
    rows = (len(images) + columns - 1) // columns

    plt.figure(figsize=(width, height))
    for position, image in enumerate(images, start=1):
        plt.subplot(rows, columns, position)
        plt.axis('off')
        plt.imshow(image)

def encode_image(image):
    """Serialise a PIL image to a base64-encoded JPEG string.

    JPEG cannot store alpha or palette data. Images in a mode the JPEG
    encoder rejects (e.g. ``RGBA`` or ``P``) are converted to ``RGB`` first
    instead of raising. Images already in a JPEG-compatible mode are encoded
    unchanged.

    Args:
        image: A PIL ``Image``.

    Returns:
        The JPEG bytes, base64-encoded, as an ASCII ``str``.
    """
    # Modes Pillow's JPEG encoder accepts directly.
    if image.mode not in ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr"):
        image = image.convert("RGB")

    buffered = BytesIO()
    image.save(buffered, format="JPEG")
    return base64.b64encode(buffered.getvalue()).decode("ascii")

In [None]:
# Sample photo from Pexels, free to use under https://www.pexels.com/license/
input_image = load_image("https://images.pexels.com/photos/1108099/pexels-photo-1108099.jpeg")
str_imge = encode_image(input_image)  # base64 JPEG string sent as "input_image"

# Same prompt and generation settings as the warm-up run in the training job
prompt = "Add a fun hat to the dog on the right and a top hat to the dog on the left"
payload = {
    "input_image": str_imge,
    "inputs": prompt,
    "guidance_scale": 2.5,
    "num_inference_steps": 50,
}

# Run prediction against the deployed endpoint
response = predictor.predict(payload)

In [None]:
# The response carries base64 strings under "generated_images"; decode each into a PIL image
decoded_images = list(map(decode_base64_image, response["generated_images"]))

# Plot the generated images in a grid
display_images(decoded_images)

# Clean up resources

In [None]:
# Delete the SageMaker model, then the endpoint (and, by default, its endpoint
# config), so the endpoint stops incurring charges and no orphaned model remains.
predictor.delete_model()
predictor.delete_endpoint()