# Fine Tune Stable Diffusion

This notebook fine-tunes Stable Diffusion on a Pokemon dataset. For more details, see the [Lambda Labs examples repo](https://github.com/LambdaLabsML/examples).

In [None]:
!git clone https://github.com/justinpinkney/stable-diffusion.git
%cd stable-diffusion
!pip install --upgrade pip
!pip install -r requirements.txt
# On Lambda Stack, keras needs to be upgraded
!pip install --upgrade keras

In [None]:
# Check the dataset
from datasets import load_dataset
ds = load_dataset("lambdalabs/pokemon-blip-captions", split="train")
sample = ds[0]
display(sample["image"].resize((256, 256)))
print(sample["text"])

To get the weights, [go to the model card](https://huggingface.co/CompVis/stable-diffusion-v-1-4-original), read the license, and tick the checkbox if you agree.

In [None]:
!pip install huggingface_hub
from huggingface_hub import notebook_login

notebook_login()

In [None]:
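from huggingface_hub import hf_hub_download

# Download the original v1.4 checkpoint; this needs the login above
# and the license accepted on the model card.
ckpt_path = hf_hub_download(
    repo_id="CompVis/stable-diffusion-v-1-4-original",
    filename="sd-v1-4-full-ema.ckpt",
    use_auth_token=True,
)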

Set the parameters below to match your GPU setup. The values shown were used for training on a 2xA6000 machine (each A6000 has 48GB of VRAM).

You can compensate for smaller batches or fewer GPUs by accumulating gradients over several batches:

`total batch size = batch size * n gpus * accumulate batches`
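
As a rough illustration of the formula above, the sketch below computes the total batch size for the settings used in this notebook (batch size 4, 2 GPUs, no accumulation) and then works out how many batches a smaller machine would need to accumulate to match it. The helper names `total_batch_size` and `accumulate_for_target` are hypothetical and not part of the training repo, and the single-GPU numbers are only an example.

```python
import math


def total_batch_size(batch_size: int, n_gpus: int, accumulate_batches: int) -> int:
    """Total batch size per optimizer step, following the formula above."""
    return batch_size * n_gpus * accumulate_batches


def accumulate_for_target(target_total: int, batch_size: int, n_gpus: int) -> int:
    """Smallest accumulation count that reaches at least target_total.

    Hypothetical helper for illustration only.
    """
    return math.ceil(target_total / (batch_size * n_gpus))


# Settings used in this notebook: 2 GPUs, batch size 4, no accumulation.
reference = total_batch_size(batch_size=4, n_gpus=2, accumulate_batches=1)
print(reference)  # 8

# Example: one GPU that only fits 2 samples per batch.
print(accumulate_for_target(reference, batch_size=2, n_gpus=1))  # 4
```

The result of `accumulate_for_target` is the value you would assign to `ACCUMULATE_BATCHES` below. When the target is not an exact multiple, rounding up gives a slightly larger total batch size than the reference.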

In [None]:
BATCH_SIZE = 4
N_GPUS = 2
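ACCUMULATE_BATCHES = 1
# Comma-separated device indices for --gpus, e.g. "0,1,".
# The trailing comma keeps a single GPU ("0,") from being read as a GPU count of zero.
gpu_list = ",".join(str(x) for x in range(N_GPUS)) + ","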

In [None]:
# Run training
!(python main.py \
    -t \
    --base configs/stable-diffusion/pokemon.yaml \
    --gpus "$gpu_list" \
    --scale_lr False \
    --num_nodes 1 \
    --check_val_every_n_epoch 10 \
    --finetune_from "$ckpt_path" \
    data.params.batch_size="$BATCH_SIZE" \
    lightning.trainer.accumulate_grad_batches="$ACCUMULATE_BATCHES" \
)

In [None]:
# Run the model
!(python scripts/txt2img.py \
    --prompt 'robotic cat with wings' \
    --outdir 'outputs/generated_pokemon' \
    --H 512 --W 512 \
    --n_samples 4 \
    --config 'configs/stable-diffusion/pokemon.yaml' \
    --ckpt 'path/to/your/checkpoint')
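
The `--ckpt` value above is a placeholder that has to point at a checkpoint written by your training run. The sketch below is one possible way to locate the most recently written checkpoint. It assumes that training saves `.ckpt` files under `logs/<run name>/checkpoints/`, which is an assumption about the repo's default layout and may not match your setup, so adjust `log_root` and `pattern` as needed. The helper `latest_checkpoint` is hypothetical and not part of the repo.

```python
from pathlib import Path
from typing import Optional


def latest_checkpoint(
    log_root: str = "logs",
    pattern: str = "*/checkpoints/*.ckpt",
) -> Optional[Path]:
    """Return the most recently modified checkpoint under log_root, or None.

    The default log_root and pattern are assumptions about where training
    writes its checkpoints; change them to match your run.
    """
    candidates = list(Path(log_root).glob(pattern))
    if not candidates:
        return None
    # Pick the file with the newest modification time.
    return max(candidates, key=lambda p: p.stat().st_mtime)


ckpt = latest_checkpoint()
print(ckpt if ckpt is not None else "no checkpoint found; set the path manually")
```

If a path is printed, it can be passed to `--ckpt` in place of `'path/to/your/checkpoint'`.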