# Fine Tune Stable Diffusion

Fine-tuning Stable Diffusion on Pokemon.
For more details, see the [Lambda Labs examples repo](https://github.com/LambdaLabsML/examples).

We recommend using a multi-GPU machine, for example an instance from [Lambda GPU Cloud](https://lambdalabs.com/service/gpu-cloud). If you run this notebook on Colab, it is likely to need a GPU with more than 16GB of VRAM and a high-RAM runtime, which will almost certainly require Colab Pro or Pro+. (If you get errors such as `Killed` or `CUDA out of memory`, one of these is insufficient.)

In [None]:
!pip install --upgrade pip
!pip install -r requirements.txt

# !pip install --upgrade keras # on lambda stack we need to upgrade keras
# !pip uninstall -y torchtext # on colab we need to remove torchtext

Obtaining taming-transformers from git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers (from -r requirements.txt (line 20))
  Updating ./src/taming-transformers clone (to revision master)
  Running command git fetch -q --tags
  Running command git reset --hard -q 24268930bf1dce879235a7fddd0b2355b84d7ea6
  Preparing metadata (setup.py) ... done
Obtaining clip from git+https://github.com/openai/CLIP.git@main#egg=clip (from -r requirements.txt (line 21))
  Updating ./src/clip clone (to revision main)
  Running command git fetch -q --tags
  Running command git reset --hard -q d50d76daa670286dd6cacf3bcd80b5e4823fc8e1
  Preparing metadata (setup.py) ... done
Obtaining file:///root/hackathon/stable-difussion-hackathon (from -r requirements.txt (line 22))
  Preparing metadata (setup.py) ... done
Collecting albumentations==0.4.3
  Using cached albumentations-0.4.3-py3-none-any.whl
Collecting opencv-python==4.5.5.64
  Usin

In [None]:
!nvidia-smi
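
The VRAM guideline above can also be checked programmatically. The following is a minimal sketch, not part of the original notebook: the helper names and the 16 GiB threshold constant are assumptions made for illustration, and it relies on the `--query-gpu=memory.total --format=csv,noheader,nounits` options of `nvidia-smi`, which report total memory per GPU in MiB.

```python
import subprocess

# Assumption: treat the ">16GB" guideline as "more than 16 * 1024 MiB".
MIN_VRAM_MIB = 16 * 1024


def parse_vram_mib(csv_text):
    """Parse one integer (MiB) per non-empty line of nvidia-smi CSV output."""
    return [int(line.strip()) for line in csv_text.splitlines() if line.strip()]


def gpus_below_threshold(vram_mib, min_mib=MIN_VRAM_MIB):
    """Return the indices of GPUs that do not have more than min_mib of memory."""
    return [i for i, mib in enumerate(vram_mib) if mib <= min_mib]


def query_vram_mib():
    """Ask nvidia-smi for the total memory of every visible GPU (in MiB)."""
    result = subprocess.run(
        ["nvidia-smi", "--query-gpu=memory.total", "--format=csv,noheader,nounits"],
        capture_output=True,
        text=True,
        check=True,
    )
    return parse_vram_mib(result.stdout)


# Usage on a machine with an NVIDIA driver installed:
#   too_small = gpus_below_threshold(query_vram_mib())
#   if too_small:
#       print(f"GPUs {too_small} may run out of memory")
```

Keeping the parsing separate from the `nvidia-smi` call makes the check easy to try without a GPU attached.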

In [None]:
# Check the dataset (from huggingface)
from datasets import load_dataset
ds = load_dataset("lambdalabs/pokemon-blip-captions", split="train")
sample = ds[0]
display(sample["image"].resize((256, 256)))
print(sample["text"])

To get the weights, [go to the model card](https://huggingface.co/CompVis/stable-diffusion-v-1-4-original), read the license, and tick the checkbox if you agree.

In [None]:
DOWNLOAD_FROM_HUGGING_FACE = False

if DOWNLOAD_FROM_HUGGING_FACE:
    !pip install huggingface_hub
    from huggingface_hub import notebook_login

    notebook_login()

In [None]:
if DOWNLOAD_FROM_HUGGING_FACE:
    from huggingface_hub import hf_hub_download
    ckpt_path = hf_hub_download(repo_id="CompVis/stable-diffusion-v-1-4-original", filename="sd-v1-4-full-ema.ckpt", use_auth_token=True)
else:
    ckpt_path = "../sd-v1-4-full-ema.ckpt"
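
Training fails late if the checkpoint path is wrong, so it can help to verify the file before starting. This is a small sketch under the assumption that the checkpoint is a regular local file; `checkpoint_size_bytes` is a hypothetical helper name, not part of the repo.

```python
from pathlib import Path


def checkpoint_size_bytes(path):
    """Return the checkpoint's size in bytes, or raise if the file is missing."""
    p = Path(path)
    if not p.is_file():
        raise FileNotFoundError(f"Checkpoint not found: {p}")
    return p.stat().st_size


# Usage in the notebook, after ckpt_path has been set:
#   size_gb = checkpoint_size_bytes(ckpt_path) / 1e9
#   print(f"Using checkpoint {ckpt_path} ({size_gb:.1f} GB)")
```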

Set the parameters below according to your GPU setup. The settings below were used for training on a 2xA6000 machine (the A6000 has 48GB of VRAM). On this setup, good results are achieved in around 6 hours.

You can make up for smaller batches or fewer GPUs by accumulating batches:

`total batch size = batch size * n gpus * accumulate batches`
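
The formula above can be sketched in a few lines of Python. The function names and the example numbers are illustrative assumptions, not values taken from the training runs.

```python
import math


def total_batch_size(batch_size, n_gpus, accumulate_batches):
    """Effective batch size: per-GPU batch * number of GPUs * accumulation steps."""
    return batch_size * n_gpus * accumulate_batches


def accumulate_for_target(target_total, batch_size, n_gpus):
    """Smallest number of accumulation steps that reaches at least target_total."""
    per_step = batch_size * n_gpus
    return max(1, math.ceil(target_total / per_step))


# Illustrative example: a per-GPU batch of 4 on 2 GPUs gives a total of 8.
print(total_batch_size(batch_size=4, n_gpus=2, accumulate_batches=1))

# To reach the same total on a single GPU with a batch size of 1,
# accumulate over 8 batches.
print(accumulate_for_target(target_total=8, batch_size=1, n_gpus=1))
```

Rounding up means the effective batch size can slightly exceed the target when the target is not a multiple of `batch_size * n_gpus`.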

In [None]:
# GPU settings; adjust these to match your hardware.
BATCH_SIZE = 1
N_GPUS = 4
ACCUMULATE_BATCHES = 1

# Comma-separated GPU indices, e.g. "0,1,2,3"
GPU_LIST = ",".join(str(x) for x in range(N_GPUS))
print(f"Using GPUs: {GPU_LIST}")

In [56]:
# Run training
!(python main.py \
    -t \
    --base configs/stable-diffusion/glovo.yaml \
    --gpus 0 \
    --scale_lr False \
    --num_nodes 1 \
    --check_val_every_n_epoch 10 \
    --finetune_from "$ckpt_path" \
    data.params.batch_size=1 \
    lightning.trainer.accumulate_grad_batches=1 \
    data.params.validation.params.n_gpus=1 \
)

Global seed set to 23
Running on GPUs 0,1,2,3
LatentDiffusion: Running in eps-prediction mode
DiffusionWrapper has 859.52 M params.
Keeping EMAs of 688.
making attention of type 'vanilla' with 512 in_channels
Working with z of shape (1, 4, 32, 32) = 4096 dimensions.
making attention of type 'vanilla' with 512 in_channels
Some weights of the model checkpoint at openai/clip-vit-large-patch14 were not used when initializing CLIPTextModel: ['vision_model.encoder.layers.15.self_attn.q_proj.weight', 'vision_model.encoder.layers.15.layer_norm2.bias', 'vision_model.encoder.layers.18.mlp.fc2.weight', 'vision_model.encoder.layers.20.self_attn.v_proj.weight', 'vision_model.encoder.layers.5.mlp.fc1.bias', 'vision_model.encoder.layers.6.self_attn.v_proj.weight', 'vision_model.encoder.layers.19.self_attn.k_proj.weight', 'vision_model.encoder.layers.1.self_attn.out_proj.weight', 'vision_model.encoder.layers.16.layer_norm1.bias', 'vision_model.encoder.layers.18.self_attn.out_proj.weight', 'vision_mode

In [None]:
# Run the model
PROMPT = 'robotic cat with wings'
VERSION = 'v0'
# IPython expands {PROMPT} and {VERSION} before the shell runs the command
!(python scripts/txt2img.py \
    --prompt "{PROMPT}" \
    --outdir "../outputs/{VERSION}" \
    --H 512 --W 512 \
    --n_samples 4 \
    --config 'configs/stable-diffusion/glovo.yaml' \
    --ckpt 'path/to/your/checkpoint')
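
The prompt contains spaces, so it has to reach `scripts/txt2img.py` as a single argument. As a hedged alternative to shell interpolation, the command can be assembled as an argument list in Python. This sketch uses only the flags shown in the cell above; `build_txt2img_command` is a hypothetical helper, and the checkpoint path is the same placeholder as in the cell.

```python
import shlex


def build_txt2img_command(prompt, version, config, ckpt,
                          n_samples=4, height=512, width=512):
    """Build the txt2img invocation as a list, one element per argument."""
    return [
        "python", "scripts/txt2img.py",
        "--prompt", prompt,
        "--outdir", f"../outputs/{version}",
        "--H", str(height),
        "--W", str(width),
        "--n_samples", str(n_samples),
        "--config", config,
        "--ckpt", ckpt,
    ]


cmd = build_txt2img_command(
    prompt="robotic cat with wings",
    version="v0",
    config="configs/stable-diffusion/glovo.yaml",
    ckpt="path/to/your/checkpoint",  # placeholder, replace with a real path
)

# Show the command with shell-safe quoting (shlex.join needs Python 3.8+).
print(shlex.join(cmd))

# To run it without going through a shell:
#   import subprocess
#   subprocess.run(cmd, check=True)
```

Passing a list to `subprocess.run` avoids quoting problems entirely, because each element is delivered to the script as exactly one argument.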

In [None]:
# open image
from PIL import Image
im = Image.open(f"../outputs/{VERSION}/grid-0000.png").resize((1024, 256))
display(im)
print("robotic cat with wings")