# Fine-tune Flux with DreamBooth LoRA using Hugging Face Diffusers

### Download the dataset

In this example, we use a dataset of dog images. You can replace it with your own dataset. DreamBooth fine-tuning needs only 5 images.


In [None]:
from huggingface_hub import snapshot_download

local_dir = "./dog"
snapshot_download(
    "diffusers/dog-example",
    local_dir=local_dir,
    repo_type="dataset",
    ignore_patterns=".gitattributes",
)

In [None]:
# The training job fails if dog/ contains a subdirectory, so remove the download cache.
!rm -rf dog/.cache
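
Because a nested directory makes the training job fail, it can help to check the folder before uploading it. The following is a minimal sketch that is not part of the original notebook; the helper name find_subdirectories is hypothetical, and it only inspects the top level of the folder.

```python
from pathlib import Path


def find_subdirectories(data_dir):
    """Return the sorted names of directories directly inside data_dir."""
    return sorted(p.name for p in Path(data_dir).iterdir() if p.is_dir())


# Example usage (assumes ./dog exists from the download step above):
# leftover = find_subdirectories("./dog")
# if leftover:
#     raise RuntimeError(f"Remove these subdirectories before uploading: {leftover}")
```

An empty list means the folder is flat and ready to upload.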

Import the necessary libraries. We use python-dotenv to load API keys from environment variables so that they are not hardcoded in the notebook, and the SageMaker ModelTrainer API to launch the training job.


In [None]:
import os
from dotenv import load_dotenv
from sagemaker.modules import Session
from sagemaker.modules.train import ModelTrainer
from sagemaker.modules.distributed import Torchrun
from sagemaker.modules.configs import (
    CheckpointConfig,
    Compute,
    SourceCode,
    InputData,
    StoppingCondition,
    S3DataSource,
)

load_dotenv()

In [None]:
environment = {
    'HF_TOKEN': os.environ["HF_TOKEN"],
    'WANDB_API_KEY': os.environ["WANDB_API_KEY"],
}
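
Indexing os.environ directly raises a bare KeyError when a variable is missing. As an optional sketch that is not part of the original notebook, the hypothetical helper require_env below reports every missing or empty variable in one error message. The env parameter is an assumption added so the function can be exercised without touching the real environment.

```python
import os


def require_env(names, env=None):
    """Return {name: value} for each name, or raise listing all missing ones."""
    env = os.environ if env is None else env
    # Treat unset and empty-string values the same way: both are unusable keys.
    missing = [name for name in names if not env.get(name)]
    if missing:
        raise RuntimeError(f"Missing environment variables: {', '.join(missing)}")
    return {name: env[name] for name in names}


# Example usage, equivalent to the dictionary built above:
# environment = require_env(["HF_TOKEN", "WANDB_API_KEY"])
```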

The Compute keep_alive_period_in_seconds setting keeps the instance in a warm pool after the training job finishes. This reduces startup time if we run several training jobs on the same instance type. The StoppingCondition max_runtime_in_seconds caps how long a single training job may run.

For this example script, we use a single ml.p4de.24xlarge instance and run on a single GPU. You can also train on multiple GPUs by editing script.sh in the scripts folder: uncomment the multi-GPU launch command and comment out the single-GPU launch command.

This script will also run on a p5 instance.

You can adjust the training parameters in the script.sh file. The training parameters are passed to the training script as command line arguments. The training script is located in the scripts folder.


In [None]:
sess = Session()
bucket = sess.default_bucket()
base_job_name = "flux-fine-tune"

image = '763104351884.dkr.ecr.us-east-1.amazonaws.com/pytorch-training:2.7.1-gpu-py312-cu128-ubuntu22.04-sagemaker'
# define the script to be run
source_code = SourceCode(
    source_dir="scripts/",
    entry_script="script.sh",
)

compute = Compute(
    instance_count=1,
    instance_type="ml.p4de.24xlarge",
    keep_alive_period_in_seconds=3600,
)

stopping_condition = StoppingCondition(max_runtime_in_seconds=18000)

checkpoint_config = CheckpointConfig(
    s3_uri=f"s3://{bucket}/{base_job_name}/checkpoints/",
)

# define the ModelTrainer
model_trainer = ModelTrainer(
    sagemaker_session=sess,
    training_image=image,
    source_code=source_code,
    base_job_name=base_job_name,
    compute=compute,
    environment=environment,
    stopping_condition=stopping_condition,
    checkpoint_config=checkpoint_config,
)

Upload the dataset to S3. The S3 URI should not have a trailing slash.


In [None]:
from sagemaker.s3 import S3Uploader

S3Uploader.upload(
    local_path="./dog/",
    # cannot have a trailing slash in the S3 URI
    desired_s3_uri=f"s3://{bucket}/flux/dog",
)
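
To avoid an accidental trailing slash when building the URI from pieces, a small helper can normalize it. This is an illustrative sketch only; s3_prefix_uri is a hypothetical name and is not part of the SageMaker SDK.

```python
def s3_prefix_uri(bucket, *parts):
    """Join a bucket and key parts into an S3 URI with no trailing slash."""
    # Strip slashes from each part and drop parts that end up empty.
    cleaned = [part.strip("/") for part in parts]
    key = "/".join(part for part in cleaned if part)
    uri = f"s3://{bucket.strip('/')}"
    return f"{uri}/{key}" if key else uri


# Example usage with a placeholder bucket name:
# s3_prefix_uri("my-bucket", "flux", "dog/")
```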

In [None]:
data = InputData(
    channel_name="train",
    data_source=S3DataSource(
        s3_data_type="S3Prefix",
        s3_uri=f"s3://{bucket}/flux/dog",
        s3_data_distribution_type="FullyReplicated",
    ),
)

Launch the training job. The job uses script.sh in the scripts folder to run the training script. With wait=False, the call returns immediately instead of blocking until the job finishes.


In [None]:
model_trainer.train(input_data_config=[data], wait=False)
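
Since the call does not block, the job status has to be checked separately. One possible approach, sketched here under assumptions, is to query the SageMaker API through boto3 using list_training_jobs. The helper latest_job_status is hypothetical, and it assumes that the job name contains the base job name used above.

```python
def latest_job_status(sm_client, name_contains):
    """Return (job_name, status) for the newest matching training job.

    sm_client is expected to be a boto3 SageMaker client, for example
    boto3.client("sagemaker"). Returns None if no job name matches.
    """
    response = sm_client.list_training_jobs(
        NameContains=name_contains,
        SortBy="CreationTime",
        SortOrder="Descending",
        MaxResults=1,
    )
    summaries = response.get("TrainingJobSummaries", [])
    if not summaries:
        return None
    job = summaries[0]
    return job["TrainingJobName"], job["TrainingJobStatus"]


# Example usage (requires AWS credentials, so it is left commented out):
# import boto3
# print(latest_job_status(boto3.client("sagemaker"), "flux-fine-tune"))
```

The status string is one of the values SageMaker reports for training jobs, such as InProgress, Completed, or Failed.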