## SDXL Fine Tuning

This is tested on SageMaker notebook instance using `conda_pytorch_p310` kernel
### Setup

In [3]:
%%sh

export PIP_ROOT_USER_ACTION=ignore

pip install -Uq pip
pip install autotrain-advanced==0.6.58
pip install diffusers==0.21.4
pip install autocrop

Collecting autotrain-advanced==0.6.58
  Downloading autotrain_advanced-0.6.58-py3-none-any.whl.metadata (10 kB)
Collecting albumentations==1.3.1 (from autotrain-advanced==0.6.58)
  Downloading albumentations-1.3.1-py3-none-any.whl.metadata (34 kB)
Collecting codecarbon==2.2.3 (from autotrain-advanced==0.6.58)
  Downloading codecarbon-2.2.3-py3-none-any.whl.metadata (6.5 kB)
Collecting datasets~=2.14.0 (from datasets[vision]~=2.14.0->autotrain-advanced==0.6.58)
  Downloading datasets-2.14.7-py3-none-any.whl.metadata (19 kB)
Collecting evaluate==0.3.0 (from autotrain-advanced==0.6.58)
  Downloading evaluate-0.3.0-py3-none-any.whl.metadata (9.1 kB)
Collecting ipadic==1.0.0 (from autotrain-advanced==0.6.58)
  Downloading ipadic-1.0.0.tar.gz (13.4 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.4/13.4 MB[0m [31m37.5 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25h  Preparing metadata (setup.py): started
  Preparing metadata (setup.py): finished with status 'do

[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
flask 3.0.2 requires Werkzeug>=3.0.0, but you have werkzeug 2.3.6 which is incompatible.
pathos 0.3.2 requires dill>=0.3.8, but you have dill 0.3.7 which is incompatible.
pathos 0.3.2 requires multiprocess>=0.70.16, but you have multiprocess 0.70.15 which is incompatible.[0m[31m


Successfully installed Mako-1.3.3 Pillow-10.0.0 PyWavelets-1.6.0 absl-py-2.1.0 accelerate-0.25.0 aiofiles-23.2.1 aiohttp-3.9.5 aiosignal-1.3.1 albumentations-1.3.1 alembic-1.13.1 altair-5.3.0 anyio-3.7.1 async-timeout-4.0.3 autotrain-advanced-0.6.58 bitsandbytes-0.41.0 cmaes-0.10.0 codecarbon-2.2.3 colorlog-6.8.2 datasets-2.14.7 diffusers-0.21.4 dill-0.3.7 docstring-parser-0.16 einops-0.6.1 evaluate-0.3.0 fastapi-0.104.1 ffmpy-0.3.2 frozenlist-1.4.1 fsspec-2023.10.0 fuzzywuzzy-0.18.0 gradio-3.41.0 gradio-client-0.5.0 greenlet-3.0.3 grpcio-1.62.2 hf-transfer-0.1.6 huggingface-hub-0.22.2 inflate64-1.0.0 invisible-watermark-0.2.0 ipadic-1.0.0 jiwer-3.0.2 joblib-1.3.1 lazy-loader-0.4 loguru-0.7.0 markdown-3.6 markdown-it-py-3.0.0 mdurl-0.1.2 multidict-6.0.5 multiprocess-0.70.15 multivolumefile-0.2.3 nltk-3.8.1 opencv-python-headless-4.9.0.80 optuna-3.3.0 orjson-3.10.1 packaging-23.1 peft-0.7.1 protobuf-4.23.4 py-cpuinfo-9.0.0 py7zr-0.20.6 pyarrow-hotfix-0.6 pybcj-1.0.2 pycryptodomex-3.20.0

[0m

### > Check version

In [1]:
import torch

print(torch.__version__) # e.g., 2.0.0 at time of post

print(torch.cuda.get_device_name(0)) # e.g., NVIDIA A10G

device_count = torch.cuda.device_count()
assert device_count > 0, "No GPU devices detected."

print("Number of available GPU devices:", device_count)

device = torch.device("cuda")

2.1.0
NVIDIA A10G
Number of available GPU devices: 1


### > Prepare the images. The picture needs to be 1024 x 1024

In [3]:
from tqdm.notebook import tqdm
import os
from pathlib import Path
from itertools import chain
import utils
import shutil

imag_dir=Path("data") #source directory to place your image
dest_dir = Path("cropped") # destination directory after image processing
if dest_dir.exists():
    shutil.rmtree(dest_dir)
dest_dir.mkdir(parents=True, exist_ok=True)

for n,img_path in enumerate(chain(imag_dir.glob("*.[jJ][pP]*[Gg]"),imag_dir.glob("*.[Pp][Nn][Gg]"))):
    try:
        cropped = utils.resize_and_center_crop(img_path.as_posix(), 1024)
        cropped.save(dest_dir / f"image_{n}.png")
    except ValueError:
        print(f"Could not detect face in {img_path}. Skipping.")
        continue

print("Here are the preprocessed images ==========")
[x.as_posix() for x in dest_dir.iterdir() if x.is_file()]



['cropped/image_9.png',
 'cropped/image_0.png',
 'cropped/image_3.png',
 'cropped/image_8.png',
 'cropped/image_7.png',
 'cropped/image_6.png',
 'cropped/image_4.png',
 'cropped/image_1.png',
 'cropped/image_2.png',
 'cropped/image_5.png']

- 8bit adam gobbles the images
- prior-preservation exceeds A10G GPU memory
- xformers gives package error

### > Initialize fine tuning parameters

In [4]:
!rm -rf `find -type d -name .ipynb_checkpoints`

In [5]:
# project configuration
project_name = "finetune_sttirum"
%store project_name

model_name_base = "stabilityai/stable-diffusion-xl-base-1.0"

# fine-tuning prompts
instance_prompt = "photo of <<TOK>>"
class_prompt = "photo of a person"

# fine-tuning hyperparameters
learning_rate = 1e-4
num_steps = 500
batch_size = 1
gradient_accumulation = 4
resolution = 1024
num_class_image = 50

class_image_path=Path(f"/tmp/priors")

# environment variables for autotrain command
os.environ["PROJECT_NAME"] = project_name
os.environ["MODEL_NAME"] = model_name_base
os.environ["INSTANCE_PROMPT"] = instance_prompt
os.environ["CLASS_PROMPT"] = class_prompt
os.environ["IMAGE_PATH"] = dest_dir.as_posix()
os.environ["LEARNING_RATE"] = str(learning_rate)
os.environ["NUM_STEPS"] = str(num_steps)
os.environ["BATCH_SIZE"] = str(batch_size)
os.environ["GRADIENT_ACCUMULATION"] = str(gradient_accumulation)
os.environ["RESOLUTION"] = str(resolution)
os.environ["CLASS_IMAGE_PATH"] = class_image_path.as_posix()
os.environ["NUM_CLASS_IMAGE"] = str(num_class_image)

Stored 'project_name' (str)


### > use autotrain to fine tune

help command will show all the available parameters

```
!autotrain dreambooth --help
```

In [6]:
!autotrain dreambooth \
    --model ${MODEL_NAME} \
    --project-name ${PROJECT_NAME} \
    --image-path "${IMAGE_PATH}" \
    --prompt "${INSTANCE_PROMPT}" \
    --class-prompt "${CLASS_PROMPT}" \
    --resolution ${RESOLUTION} \
    --batch-size ${BATCH_SIZE} \
    --num-steps ${NUM_STEPS} \
    --gradient-accumulation ${GRADIENT_ACCUMULATION} \
    --lr ${LEARNING_RATE} \
    --fp16 \
    --gradient-checkpointing

> [1mINFO    Namespace(version=False, model='stabilityai/stable-diffusion-xl-base-1.0', revision=None, tokenizer=None, image_path='cropped', class_image_path=None, prompt='photo of sttirum', class_prompt='photo of a person', num_class_images=100, class_labels_conditioning=None, prior_preservation=None, prior_loss_weight=1.0, project_name='finetune_sttirum', seed=42, resolution=1024, center_crop=None, train_text_encoder=None, batch_size=1, sample_batch_size=4, epochs=1, num_steps=1000, checkpointing_steps=100000, resume_from_checkpoint=None, gradient_accumulation=4, gradient_checkpointing=True, lr=0.0001, scale_lr=None, scheduler='constant', warmup_steps=0, num_cycles=1, lr_power=1.0, dataloader_num_workers=0, use_8bit_adam=None, adam_beta1=0.9, adam_beta2=0.999, adam_weight_decay=0.01, adam_epsilon=1e-08, max_grad_norm=1.0, allow_tf32=None, prior_generation_precision=None, local_rank=-1, xformers=None, pre_compute_text_embeddings=None, tokenizer_max_length=None, text_encoder_use_atten

### > Load the fine tune model

In [5]:
model_name_base = "stabilityai/stable-diffusion-xl-base-1.0"
project_name = "finetune_sttirum"


In [None]:
from diffusers import DiffusionPipeline, StableDiffusionXLImg2ImgPipeline

pipeline = DiffusionPipeline.from_pretrained(
    model_name_base,
    torch_dtype=torch.float16,
).to(device)

pipeline.load_lora_weights(
    project_name, 
    weight_name="pytorch_lora_weights.safetensors",
    adapter_name="sttirum"
)

In [54]:
prompt = """photo of <<TOK>>, Pixar 3d portrait, ultra detailed, gorgeous, 3d zbrush, trending on dribbble, 8k render"""
negative_prompt = """ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, blurry, bad anatomy, blurred, 
watermark, grainy, signature, cut off, draft, amateur, multiple, gross, weird, uneven, furnishing, decorating, decoration, furniture, text, poor, low, basic, worst, juvenile, 
unprofessional, failure, crayon, oil, label, thousand hands"""

In [None]:
import random

seed = random.randint(0, 100000)
generator = torch.Generator(device).manual_seed(seed)
base_image = pipeline(
    prompt=prompt, 
    negative_prompt=negative_prompt,
    num_inference_steps=50,
    generator=generator,
    height=1024,
    width=1024,
    output_type="pil",
).images[0]
base_image