## Fine-tuning Stable Diffusion XL with DreamBooth and LoRA on a free-tier Colab Notebook 🧨

In this notebook, we show how to fine-tune [Stable Diffusion XL (SDXL)](https://huggingface.co/docs/diffusers/main/en/api/pipelines/stable_diffusion/stable_diffusion_xl) with [DreamBooth](https://huggingface.co/docs/diffusers/main/en/training/dreambooth) and [LoRA](https://huggingface.co/docs/diffusers/main/en/training/lora) on a T4 GPU.

SDXL has a much larger UNet and two text encoders, which make the cross-attention context considerably larger than in previous variants.

To pull this off, we will use several memory-saving tricks: gradient checkpointing, mixed precision, and 8-bit Adam. Hang tight and let's get started 🧪

## Setup 🪓

In [53]:
# Check the GPU
!nvidia-smi

In [54]:
# Install dependencies.
!pip install bitsandbytes transformers accelerate peft -q

Make sure to install `diffusers` from `main`.

In [55]:
!pip install git+https://github.com/huggingface/diffusers.git -q

Download the `diffusers` SDXL DreamBooth LoRA training script.

In [56]:
!wget https://raw.githubusercontent.com/huggingface/diffusers/main/examples/dreambooth/train_dreambooth_lora_sdxl.py

## Dataset 🐶

**Let's get our training data!**
For this example, we'll download a set of hand images from Kaggle.

If you already have a dataset on the Hub you wish to use, you can skip this part and go straight to the "Prep for training 💻" section, where you'll simply specify the dataset name.

If your images are saved locally and/or you want to add BLIP-generated captions, follow the steps below.



In [57]:
import os
from google.colab import files
!pip install opendatasets -q
import opendatasets as od

# Download the dataset from Kaggle (prompts for your Kaggle username and API key)
dataset_url = 'https://www.kaggle.com/datasets/shyambhu/hands-and-palm-images-dataset'
od.download(dataset_url)
image_dir = "/content/hands-and-palm-images-dataset/Hands/Hands"

# Alternatively, upload local images into a new directory:
# local_dir = "./hands/"
# os.makedirs(local_dir, exist_ok=True)
# os.chdir(local_dir)
# uploaded_images = files.upload()
# os.chdir("/content") # back to parent directory


100%|██████████| 634M/634M [00:07<00:00, 91.5MB/s]


### Generate custom captions with BLIP
Load BLIP to auto caption your images:

In [58]:
import requests
from transformers import AutoProcessor, BlipForConditionalGeneration
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# load the processor and the captioning model
blip_processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
blip_model = BlipForConditionalGeneration.from_pretrained(
    "Salesforce/blip-image-captioning-base", torch_dtype=torch.float16
).to(device)

# captioning utility
def caption_images(input_image):
    inputs = blip_processor(images=input_image, return_tensors="pt").to(device, torch.float16)
    pixel_values = inputs.pixel_values

    generated_ids = blip_model.generate(pixel_values=pixel_values, max_length=50)
    generated_caption = blip_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return generated_caption

In [59]:
import glob
import os

# Specify the directory
local_dir = "/content/hands-and-palm-images-dataset/Hands/Hands/"

# Get a list of all .jpg files in the directory, sorted by name
all_files = sorted(glob.glob(os.path.join(local_dir, "*.jpg")))

# Keep only the first 10 files; everything after them will be deleted
files_to_delete = all_files[10:]

# Delete files
for file in files_to_delete:
    os.remove(file)
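
The cell above deletes every image except the first 10, so the download has to be repeated to get them back. A non-destructive alternative is sketched below, assuming you are happy to train from a separate folder. The helper name `copy_first_n_images` and the destination path in the usage comment are hypothetical, not part of the original notebook.

```python
import shutil
from pathlib import Path


def copy_first_n_images(src_dir, dst_dir, n=10, pattern="*.jpg"):
    """Copy the first `n` images (sorted by name) from src_dir into dst_dir.

    The source directory is left untouched. Returns the copied file names.
    """
    src, dst = Path(src_dir), Path(dst_dir)
    dst.mkdir(parents=True, exist_ok=True)
    selected = sorted(src.glob(pattern))[:n]
    for path in selected:
        shutil.copy2(path, dst / path.name)
    return [path.name for path in selected]


# Hypothetical usage (the destination folder name is an assumption):
# copy_first_n_images(
#     "/content/hands-and-palm-images-dataset/Hands/Hands",
#     "/content/hands_subset",
# )
```

If you go this route, `local_dir` in the following cells and `--dataset_name` in the training command would need to point at the new folder instead.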



In [60]:
import glob
from PIL import Image

# create a list of (path, PIL.Image) pairs
local_dir = "/content/hands-and-palm-images-dataset/Hands/Hands/"
imgs_and_paths = [(path, Image.open(path)) for path in glob.glob(f"{local_dir}*.jpg")]

In [61]:
type(imgs_and_paths)

list

Now let's add the concept token identifier (e.g. TOK) to each caption using a caption prefix.
Feel free to change the prefix according to the concept you're training on!
- For an object or subject we can use "a photo of TOK,"; other options include:
    - For styles - "In the style of TOK"
    - For faces - "photo of a TOK person"
- You can add additional identifiers to the prefix that can help steer the model in the right direction.
    - e.g. instead of "a photo of TOK" we can use "a photo of TOK dog" / "a photo of TOK corgi dog"
- In this notebook, the captioning cell below sets the prefix to "a photo of a hand ".
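
To make the prefix idea concrete, here is a minimal sketch of how a prefix and a generated caption might be joined. The helper name `build_caption` and the sample caption are hypothetical; the only behaviour assumed is that the prefix comes first and only the first line of the caption is kept.

```python
def build_caption(prefix, raw_caption):
    """Join a concept prefix and a generated caption with exactly one space."""
    # Keep only the first line of the generated caption.
    first_line = raw_caption.strip().split("\n")[0]
    return f"{prefix.strip()} {first_line}".strip()


# Hypothetical caption text, used only to illustrate the format.
print(build_caption("a photo of TOK dog,", "a dog sitting on a couch"))
```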

In [62]:
imgs_and_paths

[('/content/hands-and-palm-images-dataset/Hands/Hands/Hand_0000006.jpg',
  <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=1600x1200>),
 ('/content/hands-and-palm-images-dataset/Hands/Hands/Hand_0000007.jpg',
  <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=1600x1200>),
 ('/content/hands-and-palm-images-dataset/Hands/Hands/Hand_0000004.jpg',
  <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=1600x1200>),
 ('/content/hands-and-palm-images-dataset/Hands/Hands/Hand_0000008.jpg',
  <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=1600x1200>),
 ('/content/hands-and-palm-images-dataset/Hands/Hands/Hand_0000010.jpg',
  <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=1600x1200>),
 ('/content/hands-and-palm-images-dataset/Hands/Hands/Hand_0000009.jpg',
  <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=1600x1200>),
 ('/content/hands-and-palm-images-dataset/Hands/Hands/Hand_0000003.jpg',
  <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=1600x1200>),

In [63]:
import json
import os

caption_prefix = "a photo of a hand " #@param
with open(os.path.join(local_dir, "metadata.jsonl"), "w") as outfile:
    for path, image in imgs_and_paths:
        # prefix + first line of the BLIP caption
        caption = caption_prefix + caption_images(image).split("\n")[0]
        entry = {"file_name": os.path.basename(path), "prompt": caption}
        json.dump(entry, outfile)
        outfile.write("\n")
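
Before training, it can be worth checking that every line of `metadata.jsonl` points at an existing image and carries a non-empty caption. The sketch below assumes the layout written by the cell above (one JSON object per line with `file_name` and `prompt` keys); the helper name `check_metadata` is hypothetical.

```python
import json
from pathlib import Path


def check_metadata(image_dir, caption_column="prompt"):
    """Return a list of problems found in image_dir/metadata.jsonl (empty if none)."""
    image_dir = Path(image_dir)
    problems = []
    with open(image_dir / "metadata.jsonl") as f:
        for line_no, line in enumerate(f, start=1):
            if not line.strip():
                continue  # ignore blank lines
            entry = json.loads(line)
            name = entry.get("file_name")
            if not name or not (image_dir / name).is_file():
                problems.append(f"line {line_no}: missing image {name!r}")
            if not str(entry.get(caption_column, "")).strip():
                problems.append(f"line {line_no}: empty {caption_column!r}")
    return problems


# Usage in this notebook would look like:
# print(check_metadata("/content/hands-and-palm-images-dataset/Hands/Hands/"))
```

An empty list means every entry resolved to a file and has a caption.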

Free some memory:

In [13]:
import gc

# delete the BLIP pipelines and free up some memory
del blip_processor, blip_model
gc.collect()
torch.cuda.empty_cache()

## Prep for training 💻

Initialize `accelerate`:

In [15]:
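import locale
# Colab workaround: force a UTF-8 locale so that shell commands (`!...`) run without a locale error
locale.getpreferredencoding = lambda: "UTF-8"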

!accelerate config default

Configuration already exists at /root/.cache/huggingface/accelerate/default_config.yaml, will not override. Run `accelerate config` manually or pass a different `save_location`.


### Log into your Hugging Face account
Pass [your **write** access token](https://huggingface.co/settings/tokens) so that we can push the trained checkpoints to the Hugging Face Hub:

In [17]:
from huggingface_hub import notebook_login
notebook_login()


## Train! 🔬

#### Set Hyperparameters ⚡
To make sure we can run DreamBooth with LoRA on a heavy pipeline like Stable Diffusion XL, we're using:

* Gradient checkpointing (`--gradient_checkpointing`)
* 8-bit Adam (`--use_8bit_adam`)
* Mixed-precision training (`--mixed_precision="fp16"`)
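
As a rough illustration of why 8-bit Adam helps, the sketch below estimates optimizer-state memory. It relies on the fact that Adam keeps two state values per trainable parameter, and it ignores the small overhead of quantization constants. The parameter count is a hypothetical placeholder, not a number measured from this notebook.

```python
def adam_state_megabytes(num_trainable_params, bits_per_state):
    """Approximate Adam optimizer-state size in MiB.

    Adam tracks two states (first and second moment) per trainable parameter.
    """
    states_per_param = 2
    total_bytes = num_trainable_params * states_per_param * bits_per_state / 8
    return total_bytes / 1024**2


num_params = 50_000_000  # hypothetical count, for illustration only
print(f"32-bit Adam states: {adam_state_megabytes(num_params, 32):.1f} MiB")
print(f" 8-bit Adam states: {adam_state_megabytes(num_params, 8):.1f} MiB")
```

Under these assumptions the 8-bit states take about a quarter of the memory of 32-bit states; the absolute saving depends on how many parameters are actually trained.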

### Launch training 🚀🚀🚀

To allow for custom captions we need to install the `datasets` library. You can skip this step if you want to train solely with `--instance_prompt`; in that case, specify `--instance_data_dir` instead of `--dataset_name`.

In [18]:
!pip install datasets -q

 - Use `--output_dir` to specify your LoRA model repository name!
 - Use `--caption_column` to specify the name of the caption column in your dataset. In this example we used "prompt" to save our captions in the metadata file; change this according to your needs.

In [52]:
#!/usr/bin/env bash
!accelerate launch train_dreambooth_lora_sdxl.py \
  --pretrained_model_name_or_path="stabilityai/stable-diffusion-xl-base-1.0" \
  --pretrained_vae_model_name_or_path="madebyollin/sdxl-vae-fp16-fix" \
  --dataset_name="/content/hands-and-palm-images-dataset/Hands/Hands" \
  --output_dir="mod_hand_LoRA" \
  --caption_column="prompt" \
  --mixed_precision="fp16" \
  --instance_prompt="a detailed photo of a hand" \
  --resolution=1024 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=3 \
  --gradient_checkpointing \
  --learning_rate=1e-4 \
  --snr_gamma=5.0 \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --use_8bit_adam \
  --max_train_steps=500 \
  --checkpointing_steps=100 \
  --seed="0"
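
To see what these settings amount to, here is a small worked calculation. It is a sketch that assumes `--max_train_steps` counts optimizer updates (one update every `gradient_accumulation_steps` batches) and that the dataset holds the 10 images kept earlier; the function name `training_schedule` is hypothetical.

```python
import math


def training_schedule(num_images, train_batch_size,
                      gradient_accumulation_steps, max_train_steps):
    """Return (effective batch size, optimizer updates per epoch, epochs needed)."""
    effective_batch = train_batch_size * gradient_accumulation_steps
    batches_per_epoch = math.ceil(num_images / train_batch_size)
    updates_per_epoch = math.ceil(batches_per_epoch / gradient_accumulation_steps)
    epochs = math.ceil(max_train_steps / updates_per_epoch)
    return effective_batch, updates_per_epoch, epochs


# Values taken from the launch command above, with 10 training images.
print(training_schedule(10, 1, 3, 500))  # → (3, 4, 125)
```

Under these assumptions, each update averages gradients over 3 images, and 500 steps correspond to roughly 125 passes over the 10 images.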


### Save your model to the hub and check it out 🔥

In [20]:
from huggingface_hub import whoami
#@markdown make sure the `output_dir` you specify here is the same as the one used for training
output_dir = "mod_hand_LoRA" #@param
# whoami() uses the access token saved by notebook_login()
username = whoami()["name"]
repo_id = f"{username}/{output_dir}"

In [36]:
print(username)

In [21]:
# @markdown Sometimes training finishes successfully (i.e. a **.safetensors** file with the LoRA weights is saved properly to your local `output_dir`) but there's not enough RAM in the free tier to push the model to the hub 🙁
# @markdown
# @markdown To mitigate this, run this cell with your training arguments to make sure your model is uploaded! 🤗

# push to the hub🔥
from train_dreambooth_lora_sdxl import save_model_card
from huggingface_hub import upload_folder, create_repo

repo_id = create_repo(repo_id, exist_ok=True).repo_id

# change the params below according to your training arguments
save_model_card(
    repo_id=repo_id,
    images=[],
    base_model="stabilityai/stable-diffusion-xl-base-1.0",
    train_text_encoder=False,
    instance_prompt="a detailed photo of a hand",  # same as --instance_prompt used for training
    validation_prompt=None,
    repo_folder=output_dir,
    vae_path="madebyollin/sdxl-vae-fp16-fix",
    use_dora=False,
)

upload_folder(
    repo_id=repo_id,
    folder_path=output_dir,
    commit_message="End of training",
    ignore_patterns=["step_*", "epoch_*"],
)

CommitInfo(commit_url='https://huggingface.co/Shaheer2413/mod_hand_LoRA/commit/66c77833186ee71649fb985facda8d59e8840a3a', commit_message='End of training', commit_description='', oid='66c77833186ee71649fb985facda8d59e8840a3a', pr_url=None, repo_url=RepoUrl('https://huggingface.co/Shaheer2413/mod_hand_LoRA', endpoint='https://huggingface.co', repo_type='model', repo_id='Shaheer2413/mod_hand_LoRA'), pr_revision=None, pr_num=None)
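
Since the upload only helps if the LoRA weights were actually written, a quick local check can save a confusing load error later. The sketch below assumes the training script saved its weights as `pytorch_lora_weights.safetensors` inside `output_dir`; that file name is an assumption about the script's default, and the helper `find_lora_weights` is hypothetical.

```python
from pathlib import Path


def find_lora_weights(output_dir, filename="pytorch_lora_weights.safetensors"):
    """Return the path to the LoRA weights file, or None if it is absent or empty."""
    candidate = Path(output_dir) / filename
    if candidate.is_file() and candidate.stat().st_size > 0:
        return candidate
    return None


# "mod_hand_LoRA" is the output_dir used in this notebook.
weights = find_lora_weights("mod_hand_LoRA")
print(weights if weights else "No LoRA weights found; training may not have finished.")
```

If nothing is found, the training run likely stopped before the final save, and the pushed repository will not contain loadable weights.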

In [22]:
from IPython.display import display, Markdown

link_to_model = f"https://huggingface.co/{repo_id}"
display(Markdown("### Your model has finished training.\nAccess it here: {}".format(link_to_model)))

### Your model has finished training.
Access it here: https://huggingface.co/Shaheer2413/mod_hand_LoRA

Let's generate some images with it!

## Inference 🐕

In [23]:
import torch
from diffusers import DiffusionPipeline, AutoencoderKL

vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    vae=vae,
    torch_dtype=torch.float16,
    variant="fp16",
    use_safetensors=True
)
pipe.load_lora_weights(repo_id)
_ = pipe.to("cuda")


OSError: Unable to load weights from checkpoint file for '/root/.cache/huggingface/hub/models--Shaheer2413--mod_hand_LoRA/snapshots/66c77833186ee71649fb985facda8d59e8840a3a/pytorch_lora_weights.bin' at '/root/.cache/huggingface/hub/models--Shaheer2413--mod_hand_LoRA/snapshots/66c77833186ee71649fb985facda8d59e8840a3a/pytorch_lora_weights.bin'. 

In [24]:
prompt = "a photo of a hand on a beach" # @param

image = pipe(prompt=prompt, num_inference_steps=25).images[0]
image

  0%|          | 0/25 [00:00<?, ?it/s]

KeyboardInterrupt: 