This notebook shows how to use stable-diffusion-webui to train on images with LoRA support. The stable-diffusion-webui training runs as an Amazon SageMaker training job.

In [None]:
import boto3
import sagemaker

# SageMaker session, its default S3 bucket, and the execution role of this notebook.
sagemaker_session = sagemaker.Session()
bucket = sagemaker_session.default_bucket()
role = sagemaker.get_execution_role()

# Account ID (looked up through STS) and region are used later for the
# image build script and the ECR image URI.
sts_client = boto3.client("sts")
caller_identity = sts_client.get_caller_identity()
account_id = caller_identity.get("Account")
region_name = boto3.session.Session().region_name

Create the models directory with the following structure.

In [None]:
!mkdir -p models
!mkdir -p models/Stable-diffusion

Log out of the AWS public ECR registry to avoid errors caused by an expired authentication token.

In [None]:
# Remove any credentials the local Docker client has stored for public.ecr.aws.
!docker logout public.ecr.aws

Build Docker image and push to ECR.

In [None]:
!./build_and_push.sh.lite $region_name

Install the Hugging Face Hub toolkit and log in with your Hugging Face access token.

In [None]:
%pip install huggingface_hub
!huggingface-cli login --token [Your-huggingface-access-token]

Download the Stable Diffusion models.

In [None]:
from huggingface_hub import hf_hub_download
hf_hub_download(
    repo_id="stabilityai/stable-diffusion-2-1", 
    filename="v2-1_768-ema-pruned.ckpt", 
    local_dir="models/Stable-diffusion/"
)
hf_hub_download(
    repo_id="runwayml/stable-diffusion-v1-5", 
    filename="v1-5-pruned.ckpt", 
    local_dir="models/Stable-diffusion/"
)
!wget "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml" -O models/Stable-diffusion/v2-1_768-ema-pruned.yaml

Download s5cmd, a very fast S3 and local filesystem execution tool, and place it under the tools/ directory.

In [None]:
!wget https://github.com/peak/s5cmd/releases/download/v2.0.0/s5cmd_2.0.0_Linux-64bit.tar.gz -O tools/s5cmd_2.0.0_Linux-64bit.tar.gz
!tar xzvf tools/s5cmd_2.0.0_Linux-64bit.tar.gz -C tools/

Upload Stable-diffusion models to S3.

In [3]:
# S3 prefix (no trailing slash) that the local model files are uploaded under.
model_data = f"s3://{bucket}/stable-diffusion-webui/models"

In [None]:
!tools/s5cmd cp models/Stable-diffusion $model_data

In [None]:
# S3 prefix holding the training images; passed to the job as the 'concepts' channel.
images_s3uri = f's3://{bucket}/stable-diffusion-webui/images/'

In [None]:
!aws s3 cp images/training/Dreambooth $images_s3uri --recursive

In [None]:
# Per-concept settings for a slot that is left blank. Slots c2, c3 and c4 all
# use exactly these values; only the key prefix differs.
_blank_concept = {
    "class_data_dir": "",
    "class_guidance_scale": 7.5,
    "class_infer_steps": 40,
    "class_negative_prompt": "",
    "class_prompt": "",
    "class_token": "",
    "instance_data_dir": "",
    "instance_prompt": "",
    "instance_token": "",
    "n_save_sample": 1,
    "num_class_images": 0,
    "num_class_images_per": 0,
    "sample_seed": -1,
    "save_guidance_scale": 7.5,
    "save_infer_steps": 40,
    "save_sample_negative_prompt": "",
    "save_sample_prompt": "",
    "save_sample_template": "",
}

# Dreambooth training configuration. It is written to a JSON file further down,
# so the insertion order of the keys is kept exactly as listed here.
db_params = {
    # --- general training settings (db_*) ---
    "db_model_name": [],
    "db_attention": "xformers",
    "db_cache_latents": True,
    "db_center_crop": False,
    "db_freeze_clip_normalization": False,
    "db_clip_skip": 1,
    "db_concepts_path": "",
    "db_custom_model_name": "",
    "db_epochs": "",
    "db_epoch_pause_frequency": 0.0,
    "db_epoch_pause_time": 0.0,
    "db_gradient_accumulation_steps": 1,
    "db_gradient_checkpointing": True,
    "db_gradient_set_to_none": True,
    "db_graph_smoothing": 50.0,
    "db_half_model": False,
    "db_hflip": False,
    "db_learning_rate": 2e-06,
    "db_learning_rate_min": 1e-06,
    "db_lora_learning_rate": 0.0002,
    "db_lora_model_name": [],
    "db_lora_rank": 4,
    "db_lora_txt_learning_rate": 0.0002,
    "db_lora_txt_weight": 1,
    "db_lora_weight": 1,
    "db_lr_cycles": 1,
    "db_lr_factor": 0.5,
    "db_lr_power": 1,
    "db_lr_scale_pos": 0.5,
    "db_lr_scheduler": "constant_with_warmup",
    "db_lr_warmup_steps": 0,
    "db_max_token_length": 75,
    "db_mixed_precision": "fp16",
    "db_adamw_weight_decay": 0.01,
    "db_model_path": "",
    "db_num_train_epochs": 100,
    "db_pad_tokens": True,
    "db_pretrained_vae_name_or_path": "",
    "db_prior_loss_scale": False,
    "db_prior_loss_target": 100.0,
    "db_prior_loss_weight": 0.75,
    "db_prior_loss_weight_min": 0.1,
    "db_resolution": 768,
    "db_revision": "",
    "db_sample_batch_size": 1,
    "db_sanity_prompt": "",
    "db_sanity_seed": 420420.0,
    "db_save_ckpt_after": True,
    "db_save_ckpt_cancel": False,
    "db_save_ckpt_during": False,
    "db_save_embedding_every": 0,
    "db_save_lora_after": True,
    "db_save_lora_cancel": False,
    "db_save_lora_during": False,
    "db_save_preview_every": 0,
    "db_save_safetensors": False,
    "db_save_state_after": False,
    "db_save_state_cancel": False,
    "db_save_state_during": False,
    "db_scheduler": "",
    "db_src": "",
    "db_shuffle_tags": True,
    "db_snapshot": [],
    "db_train_batch_size": 1,
    "db_train_imagic_only": False,
    "db_train_unet": True,
    "db_stop_text_encoder": 1,
    "db_use_8bit_adam": True,
    "db_use_concepts": False,
    "db_train_unfrozen": False,
    "db_use_ema": False,
    "db_use_lora": True,
    "db_use_subdir": True,
    # --- concept slot 1: the only slot with an instance data directory set ---
    "c1_class_data_dir": "",
    "c1_class_guidance_scale": 7.5,
    "c1_class_infer_steps": 40,
    "c1_class_negative_prompt": "",
    "c1_class_prompt": "dog",
    "c1_class_token": "dog",
    "c1_instance_data_dir": "/opt/ml/input/data/concepts",
    "c1_instance_prompt": "jp-style-girl",
    "c1_instance_token": "jp-style-girl",
    "c1_n_save_sample": 1,
    "c1_num_class_images": 0,
    "c1_num_class_images_per": 10,
    "c1_sample_seed": -1,
    "c1_save_guidance_scale": 7.5,
    "c1_save_infer_steps": 40,
    "c1_save_sample_negative_prompt": "broke a finger, ugly, duplicate, morbid, mutilated, tranny, trans, trannsexual, hermaphrodite, extra fingers, fused fingers, too many fingers, long neck, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, bad anatomy, bad proportions, malformed limbs, extra limbs, cloned face, disfigured, gross proportions, missing arms, missing legs, extra arms, extra legs, artist name, jpeg artifacts",
    "c1_save_sample_prompt": "jp-style-girl",
    "c1_save_sample_template": "",
}

# --- concept slots 2-4: appended in order, each a prefixed copy of the blank template ---
db_params.update({
    f"c{slot}_{field}": value
    for slot in (2, 3, 4)
    for field, value in _blank_concept.items()
})

In [None]:
import json
import uuid

# A fresh UUID identifies this run's config; it names the JSON file here and is
# also passed to the training job as the 'dreambooth-config-id' hyperparameter.
db_config_id = str(uuid.uuid4())
db_config_file = f'{db_config_id}.json'

# Use a context manager so the file is flushed and closed before it is uploaded.
with open(db_config_file, 'w') as config_fp:
    json.dump(db_params, config_fp, indent=6)

In [None]:
# ECR image for the training container (no tag given, so the registry default applies).
image_uri = f'{account_id}.dkr.ecr.{region_name}.amazonaws.com/all-in-one-ai-stable-diffusion-webui-training-api'

# S3 locations handed to the training job.
# NOTE(review): sd_models_s3uri points at '.../models/768-v-ema', while the upload
# cell above copies models under '.../models' (model_data) - confirm that objects
# actually exist under this prefix before starting the job.
s3_prefix = f's3://{bucket}/stable-diffusion-webui'
sd_models_s3uri = f'{s3_prefix}/models/768-v-ema'
db_models_s3uri = f'{s3_prefix}/dreambooth/'
lora_models_s3uri = f'{s3_prefix}/lora/'
db_config_s3uri = f'{s3_prefix}/dreambooth-config/'

In [None]:
print(db_config_file)
print(db_config_s3uri)
!aws s3 cp $db_config_file $db_config_s3uri

In [None]:
def json_encode_hyperparameters(hyperparameters):
    """Print every hyperparameter, then return a JSON-encoded copy.

    Args:
        hyperparameters: mapping of hyperparameter names to values.

    Returns:
        A new dict whose keys are ``str(name)`` and whose values are
        ``json.dumps(value)`` for each entry of ``hyperparameters``.
    """
    entries = list(hyperparameters.items())

    # Echo the raw values first so they are visible in the cell output.
    for name, value in entries:
        print(name, value)

    encoded = {}
    for name, value in entries:
        encoded[str(name)] = json.dumps(value)
    return encoded

# Settings forwarded to the training container under 'train_dreambooth_settings'.
# NOTE(review): db_new_model_src names '768-v-ema.ckpt', which is not one of the
# checkpoints downloaded earlier in this notebook - confirm the file exists in
# the models channel.
train_args = {
    'train_dreambooth_settings': dict(
        db_create_new_db_model=True,
        db_use_txt2img=True,
        db_new_model_name='new-dreambooth-model-001',
        db_new_model_src='768-v-ema.ckpt',
        db_new_model_scheduler='ddim',
        db_create_from_hub=False,
        db_new_model_url='',
        db_new_model_token='',
        db_new_model_extract_ema=False,
        db_train_unfrozen=False,
        db_512_model=False,
        db_model_name=[],
        db_train_wizard_person=False,
        db_train_wizard_object=False,
        db_performance_wizard=False,
        db_lora_model_name=[],
        db_save_safetensors=False,
    )
}

In [None]:
instance_type = 'ml.g4dn.2xlarge'

# train_args is serialised to JSON here, and every value (including that string)
# is JSON-encoded once more by json_encode_hyperparameters.
hyperparameters = json_encode_hyperparameters({
    'train-args': json.dumps(train_args),
    'sd-models-s3uri': sd_models_s3uri,
    'db-models-s3uri': db_models_s3uri,
    'lora-models-s3uri': lora_models_s3uri,
    'dreambooth-config-id': db_config_id,
})

In [None]:
# Input channels for the training job: channel name -> S3 location.
inputs = dict(
    concepts=images_s3uri,
    models=sd_models_s3uri,
    config=db_config_s3uri,
)

In [None]:
from sagemaker.estimator import Estimator

# Collect the estimator settings first so they can be inspected before launch.
estimator_settings = dict(
    role=role,
    instance_count=1,
    instance_type=instance_type,
    image_uri=image_uri,
    hyperparameters=hyperparameters,
)
estimator = Estimator(**estimator_settings)

# Start the training job with the input channels defined in `inputs`.
estimator.fit(inputs)