<a href="https://colab.research.google.com/github/Witcape/finetune_stable_diffusion/blob/main/pokemon_finetune.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Fine Tune Stable Diffusion

This notebook fine-tunes Stable Diffusion on Pokemon images.
For more details, see the [Lambda Labs examples repo](https://github.com/LambdaLabsML/examples).

We recommend using a multi-GPU machine, for example an instance from [Lambda GPU Cloud](https://lambdalabs.com/service/gpu-cloud). On Colab, this notebook is likely to need a GPU with more than 16GB of VRAM and a high-RAM runtime, which will almost certainly require Colab Pro or Pro+. If you get errors such as `Killed` (typically system RAM running out) or `CUDA out of memory` (GPU VRAM running out), one of these resources is insufficient.

In [None]:
!git clone https://github.com/justinpinkney/stable-diffusion.git
%cd stable-diffusion
!pip install --upgrade pip
!pip install -r requirements.txt

!pip install --upgrade keras # on lambda stack we need to upgrade keras
!pip uninstall -y torchtext # on colab we need to remove torchtext

Cloning into 'stable-diffusion'...
remote: Enumerating objects: 1755, done.
remote: Counting objects: 100% (8/8), done.
remote: Compressing objects: 100% (8/8), done.
remote: Total 1755 (delta 3), reused 1 (delta 0), pack-reused 1747
Receiving objects: 100% (1755/1755), 73.93 MiB | 29.51 MiB/s, done.
Resolving deltas: 100% (1082/1082), done.
/content/stable-diffusion
Collecting pip
  Downloading pip-24.0-py3-none-any.whl (2.1 MB)
     2.1/2.1 MB 14.4 MB/s eta 0:00:00
Installing collected packages: pip
  Attempting uninstall: pip
    Found existing installation: pip 23.1.2
    Uninstalling pip-23.1.2:
      Successfully uninstalled pip-23.1.2
Successfully installed pip-24.0
Looking in indexes: https://pypi.org/simple, https://download.pytorch.org/whl/cu113
Obtaining taming-transformers from git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers (from -r requi

In [None]:
!nvidia-smi
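
The VRAM guidance above can also be checked programmatically. The following is a minimal sketch, not part of the original notebook: the helper names and the 16 GiB threshold are assumptions made for illustration, and it relies on the `--query-gpu=memory.total` option of `nvidia-smi`. If `nvidia-smi` is not available, it simply prints nothing.

```python
import subprocess

# Assumed threshold, taken from the ">16GB of VRAM" guidance in this notebook.
MIN_VRAM_MIB = 16 * 1024


def parse_total_memory(csv_output):
    """Parse per-GPU total memory (MiB) from nvidia-smi CSV output, skipping non-numeric lines."""
    return [int(line.strip()) for line in csv_output.splitlines() if line.strip().isdigit()]


def query_total_memory():
    """Ask nvidia-smi for the total memory of each GPU; returns [] if the tool is unavailable."""
    cmd = ["nvidia-smi", "--query-gpu=memory.total", "--format=csv,noheader,nounits"]
    try:
        result = subprocess.run(cmd, capture_output=True, text=True, check=True)
    except (OSError, subprocess.CalledProcessError):
        return []
    return parse_total_memory(result.stdout)


for index, mib in enumerate(query_total_memory()):
    status = "ok" if mib > MIN_VRAM_MIB else "likely too small"
    print(f"GPU {index}: {mib} MiB ({status})")
```

This only covers GPU memory; the high-RAM runtime requirement still has to be checked separately.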

In [None]:
# Check the dataset
from datasets import load_dataset
ds = load_dataset("logo-wizard/modern-logo-dataset", split="train")
sample = ds[0]
display(sample["image"].resize((256, 256)))
print(sample["text"])

To get the weights, [go to the model card](https://huggingface.co/CompVis/stable-diffusion-v1-4-original), read the license, and tick the checkbox if you agree.

In [None]:
!pip install huggingface_hub
from huggingface_hub import notebook_login

notebook_login()

In [None]:
from huggingface_hub import hf_hub_download
ckpt_path = hf_hub_download(repo_id="CompVis/stable-diffusion-v-1-4-original", filename="sd-v1-4-full-ema.ckpt", use_auth_token=True)

Set your parameters below depending on your GPU setup. The settings below were used for training on a 2xA6000 machine (the A6000 has 48GB of VRAM). On this setup, good results are achieved in around 6 hours.

You can compensate for smaller batches or fewer GPUs by accumulating gradients over several batches:

`total batch size = batch size * n gpus * accumulate batches`
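
As a worked example of the formula above, here is a small sketch that computes the total batch size and the accumulation needed to match it on smaller hardware. The function names are hypothetical, and the reference configuration (batch size 4 on 2 GPUs with no accumulation) is an assumption based on the 2xA6000 description.

```python
import math


def total_batch_size(batch_size, n_gpus, accumulate_batches):
    """total batch size = batch size * n gpus * accumulate batches"""
    return batch_size * n_gpus * accumulate_batches


def accumulation_needed(target_total, batch_size, n_gpus):
    """Smallest accumulate_batches value that reaches at least target_total."""
    return math.ceil(target_total / (batch_size * n_gpus))


# Assumed reference: batch size 4 on each of 2 GPUs, no accumulation.
reference = total_batch_size(batch_size=4, n_gpus=2, accumulate_batches=1)
print(reference)  # 8

# Matching that total on a single GPU:
print(accumulation_needed(reference, batch_size=4, n_gpus=1))  # 2
print(accumulation_needed(reference, batch_size=1, n_gpus=1))  # 8
```

Accumulation trades memory for time: each optimizer step then takes `accumulate batches` forward and backward passes.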

In [11]:
# Single-GPU settings; the 2xA6000 run described above corresponds to N_GPUS = 2
BATCH_SIZE = 4
N_GPUS = 1
ACCUMULATE_BATCHES = 1

gpu_list = ",".join((str(x) for x in range(N_GPUS)))
print(f"Using GPUs: {gpu_list}")

Using GPUs: 0


In [14]:
gpu_list

'0'
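
One thing to watch with a single GPU: to our understanding, the PyTorch Lightning 1.x command-line parser treats a `--gpus` value without a comma as a device count, so `0` would request zero GPUs, while `0,` selects device index 0. This is an assumption about the Lightning version installed by the repo, so check it against your environment. A minimal sketch of a helper (hypothetical name) that always produces the comma-terminated form:

```python
def format_gpu_list(n_gpus):
    """Return comma-terminated device indices, e.g. 1 -> "0," and 2 -> "0,1,"."""
    if n_gpus < 1:
        raise ValueError("n_gpus must be at least 1")
    return ",".join(str(i) for i in range(n_gpus)) + ","


print(format_gpu_list(1))  # 0,
print(format_gpu_list(2))  # 0,1,
```

The result can be assigned to `gpu_list` before running the training cell.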

In [14]:
# Run training
!(python main.py \
    -t \
    --base configs/stable-diffusion/logo.yaml \
    --gpus "$gpu_list" \
    --scale_lr False \
    --num_nodes 1 \
    --check_val_every_n_epoch 10 \
    --finetune_from "$ckpt_path" \
    data.params.batch_size="$BATCH_SIZE" \
    lightning.trainer.accumulate_grad_batches="$ACCUMULATE_BATCHES" \
    data.params.validation.params.n_gpus="$N_GPUS" \
)

usage: main.py [-h] [--finetune_from [FINETUNE_FROM]] [-n [NAME]] [-r [RESUME]]
               [-b [base_config.yaml ...]] [-t [TRAIN]] [--no-test [NO_TEST]] [-p PROJECT]
               [-d [DEBUG]] [-s SEED] [-f POSTFIX] [-l LOGDIR] [--scale_lr [SCALE_LR]]
               [--logger [LOGGER]] [--checkpoint_callback [CHECKPOINT_CALLBACK]]
               [--default_root_dir DEFAULT_ROOT_DIR] [--gradient_clip_val GRADIENT_CLIP_VAL]
               [--gradient_clip_algorithm GRADIENT_CLIP_ALGORITHM]
               [--process_position PROCESS_POSITION] [--num_nodes NUM_NODES]
               [--num_processes NUM_PROCESSES] [--devices DEVICES] [--gpus GPUS]
               [--auto_select_gpus [AUTO_SELECT_GPUS]] [--tpu_cores TPU_CORES] [--ipus IPUS]
               [--log_gpu_memory LOG_GPU_MEMORY]
               [--progress_bar_refresh_rate PROGRESS_BAR_REFRESH_RATE]
               [--overfit_batches OVERFIT_BATCHES] [--track_grad_norm TRACK_GRAD_NORM]
               [--check_val_every_n_epoch C

In [None]:
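# Run the model: generate images from the fine-tuned checkpoint.
# --ckpt is a placeholder; point it at the checkpoint produced by training.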
!(python scripts/txt2img.py \
    --prompt 'robotic cat with wings' \
    --outdir 'outputs/generated_pokemon' \
    --H 512 --W 512 \
    --n_samples 4 \
    --config 'configs/stable-diffusion/pokemon.s' \
    --ckpt 'path/to/your/checkpoint')
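
To fill in the `--ckpt` placeholder, the most recent checkpoint can be located with a short helper. This is a sketch under an assumption: it expects training runs to be saved as `logs/<run name>/checkpoints/*.ckpt`, which is how we understand the repo's `main.py` to lay out its output. The function name is hypothetical; adjust `log_root` if your logs live elsewhere.

```python
from pathlib import Path


def latest_checkpoint(log_root="logs"):
    """Return the most recently modified .ckpt under <log_root>/*/checkpoints/, or None."""
    candidates = sorted(
        Path(log_root).glob("*/checkpoints/*.ckpt"),
        key=lambda path: path.stat().st_mtime,
    )
    return candidates[-1] if candidates else None


print(latest_checkpoint())
```

If this prints `None`, no checkpoint was found under the assumed layout, and the path has to be supplied by hand.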