This directory provides the training code for Stable Cascade, as well as guides for downloading the models you need. Specifically, you can find training scripts for the following use-cases:
- Text-to-Image
- ControlNet
- LoRA
- Image Reconstruction
A quick clarification: Stable Cascade uses Stages A & B to compress images, while Stage C handles the text-conditional learning. Therefore, it only makes sense to train a LoRA or ControlNet for Stage C; after all, you wouldn't train a LoRA or ControlNet for the Stable Diffusion VAE either.
In the training configs folder we provide config files for every kind of training. All config files follow a similar structure and only contain the most essential parameters you need to set. Let's take a look at the structure each config follows:
First, you set the run name, the checkpoint and output folders, and which model version you want to train.
```yaml
experiment_id: stage_c_3b_controlnet_base
checkpoint_path: /path/to/checkpoint
output_path: /path/to/output
model_version: 3.6B
```
Next, you can set your Weights & Biases information if you want to use it for logging.
```yaml
wandb_project: StableCascade
wandb_entity: wandb_username
```
Afterwards, you define the training parameters.
```yaml
lr: 1.0e-4
batch_size: 256
image_size: 768
multi_aspect_ratio: [1/1, 1/2, 1/3, 2/3, 3/4, 1/5, 2/5, 3/5, 4/5, 1/6, 5/6, 9/16]
grad_accum_steps: 1
updates: 500000
backup_every: 50000
save_every: 2000
warmup_updates: 1
use_fsdp: False
```
Most of them will probably already be familiar to you. A few clarifications though: `updates` refers to the number of training steps, `backup_every` creates additional checkpoints so you can revert to earlier ones if needed, and `save_every` determines how often the model is saved and sampling is performed. Furthermore, since distributed training is essential when training large models from scratch or doing large finetunes, there is an option to use PyTorch's Fully Sharded Data Parallel (FSDP). You can enable it by setting `use_fsdp: True`. Note that you will need multiple GPUs for FSDP. However, as mentioned above, this is only needed for large runs; you can still train and finetune our largest models on a single powerful machine.
Another thing we provide is multi-aspect-ratio training. You can set the aspect ratios you want in the `multi_aspect_ratio` list.
For diffusion models, keeping an EMA (Exponential Moving Average) copy of the weights can drastically improve the performance of your model. To include an EMA model in your training, set the following parameters; otherwise you can simply leave them out.
```yaml
ema_start_iters: 5000
ema_iters: 100
ema_beta: 0.9
```
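These follow the standard exponential moving average scheme. The snippet below is only a generic illustration of that update rule, not the repository's exact implementation; `ema_beta` acts as the decay factor, and how `ema_start_iters` and `ema_iters` gate the update here is an assumption made for the example.

```python
import torch

@torch.no_grad()
def update_ema(model: torch.nn.Module, ema_model: torch.nn.Module, beta: float):
    # Standard EMA update: ema = beta * ema + (1 - beta) * current
    for p, ema_p in zip(model.parameters(), ema_model.parameters()):
        ema_p.mul_(beta).add_(p, alpha=1.0 - beta)

# Hypothetical usage in a training loop, gated on the config values above:
# if step >= ema_start_iters and step % ema_iters == 0:
#     update_ema(model, ema_model, beta=ema_beta)
```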
Next, you define the dataset you want to use. Note that the code uses webdataset for this.
```yaml
webdataset_path:
  - s3://path/to/your/first/dataset/on/s3
  - file:/path/to/your/local/dataset.tar
```
You can set as many dataset paths as you want, and they can either be on
Amazon S3 storage or just local.
There are a few more specifics to each kind of training and to datasets in general. These will be discussed below.
You can start an actual training very easily by first moving to the root directory of this repository. The Python command then looks like the following:
```bash
python3 <training_file> <training_config>
```
For example, if you want to train a LoRA model, the command would look like this:
```bash
python3 train/train_c_lora.py configs/training/finetune_c_3b_lora.yaml
```
Moreover, we also provide a bash script for working with Slurm. Note that this assumes you have access to a cluster that uses Slurm as its cluster manager.
As mentioned above, the code uses webdataset for working with datasets, because this library makes it very easy to work with large amounts of data. If you want to finetune a model, train a LoRA, or train a ControlNet, your data might not be in webdataset format yet. Therefore, here is a simple example of how you can convert your dataset into the appropriate format (a programmatic Python alternative is sketched right after these steps):
- Put all your images and captions into a folder
- Rename them so that each image and its caption share the same number / id. For example:
  `0000.jpg, 0000.txt, 0001.jpg, 0001.txt, 0002.jpg, 0002.txt, 0003.jpg, 0003.txt`
- Run the following command (or manually create a tar file from the folder):
  ```bash
  tar --sort=name -cf dataset.tar dataset/
  ```
- Set `webdataset_path: file:/path/to/your/local/dataset.tar` in the config file
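Alternatively, if you prefer to build the tar programmatically, the sketch below uses the webdataset library's `TarWriter` to pack such an image / caption folder into a single shard. The folder layout is assumed to match the example above; the paths and file extensions are yours to adapt.

```python
import os
import webdataset as wds

def folder_to_webdataset(folder: str, output_tar: str):
    """Pack <key>.jpg / <key>.txt pairs from `folder` into a webdataset tar."""
    keys = sorted(os.path.splitext(f)[0] for f in os.listdir(folder) if f.endswith(".jpg"))
    with wds.TarWriter(output_tar) as sink:
        for key in keys:
            with open(os.path.join(folder, f"{key}.jpg"), "rb") as f:
                image_bytes = f.read()
            with open(os.path.join(folder, f"{key}.txt"), "r", encoding="utf-8") as f:
                caption = f.read()
            # Entries sharing the same __key__ are grouped into one sample by webdataset.
            sink.write({"__key__": key, "jpg": image_bytes, "txt": caption})

# folder_to_webdataset("dataset", "dataset.tar")
```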
Next, there are a few more settings that might be helpful, especially when working with large datasets that contain additional per-image metadata you want to filter on. You can apply dataset filters in the config file like this:
```yaml
dataset_filters:
  - ['aesthetic_score', 'lambda s: s > 4.5']
  - ['nsfw_probability', 'lambda s: s < 0.01']
```
In this case, your dataset would also contain `0000.json`, `0001.json`, `0002.json`, `0003.json`, each with keys for `aesthetic_score` and `nsfw_probability`.
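If you build the tar yourself as sketched above, each metadata file is just a small JSON object per sample. The following hypothetical example shows one way to write it alongside the image and caption; the field names mirror the filter example above, and the values are made up for illustration.

```python
import json
import webdataset as wds

# Hypothetical example of writing per-sample metadata that dataset_filters can read.
metadata = {"aesthetic_score": 5.2, "nsfw_probability": 0.001}
with wds.TarWriter("dataset.tar") as sink:
    with open("dataset/0000.jpg", "rb") as f:
        image_bytes = f.read()
    sink.write({
        "__key__": "0000",
        "jpg": image_bytes,                            # raw JPEG bytes
        "txt": "a photo of a dog",                     # caption
        "json": json.dumps(metadata).encode("utf-8"),  # serialized explicitly, so no reliance on default encoders
    })
```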
If you want to finetune any model, you need the pretrained models. You can find details on how to download them in the models section. After downloading them, update the checkpoint paths in the config file accordingly. See below for example config files.
You can use the following configs for finetuning Stage C on your own datasets. All necessary parameters were already explained above, so there is nothing new here. Take a look at the configs for finetuning the 3.6B Stage C and the 1B Stage C.
Training a ControlNet requires setting a few extra parameters as well as adding the specific ControlNet filter you want. By filter, we simply mean a class that performs, for example, Canny edge detection, human pose detection, etc.
```yaml
controlnet_blocks: [0, 4, 8, 12, 51, 55, 59, 63]
controlnet_filter: CannyFilter
controlnet_filter_params:
  resize: 224
```
Here we need to give a little more detail on what Stage C's architecture looks like. It is basically just a stack of residual blocks (convolutional and attention) that all work at the same latent resolution; we do not use a UNet. This is where `controlnet_blocks` comes in: it determines at which blocks you want to inject the controlling information. In this way, the ControlNet architecture differs from the common one used in Stable Diffusion, where you create an entire copy of the UNet's encoder. With Stable Cascade it is a bit simpler and comes with the great benefit of using far fewer parameters.
Next, with the `controlnet_filter` parameter, you define the class that filters the images and extracts the information you want to condition Stage C on (Canny edge detection, human pose detection, etc.). In the example, we use the `CannyFilter` defined in the controlnet.py file. This is the place where you can add your own ControlNet filters. Lastly, `controlnet_filter_params` simply passes additional parameters to your `controlnet_filter` class. That's it. You can view the example ControlNet configs for Inpainting / Outpainting, Face Identity, Canny, and Super Resolution.
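To give an idea of what such a filter class boils down to, here is a rough, hypothetical sketch of a Canny-style filter. It is not the repository's `CannyFilter`; the base class, method names, and return format expected by controlnet.py may well differ, and kornia is used here purely as one convenient way to compute edge maps on a batch of image tensors.

```python
import torch
import torch.nn.functional as F
import kornia

class MyCannyFilter:
    """Hypothetical ControlNet filter: turns a batch of images into edge maps."""

    def __init__(self, resize: int = 224):
        # Mirrors the `resize` entry under controlnet_filter_params in the config.
        self.resize = resize

    def __call__(self, images: torch.Tensor) -> torch.Tensor:
        # images: (B, 3, H, W) float tensor in [0, 1]
        x = F.interpolate(images, size=(self.resize, self.resize), mode="bilinear", align_corners=False)
        # kornia.filters.canny returns (magnitude, edges); keep the binary edge map.
        _, edges = kornia.filters.canny(x)
        return edges  # (B, 1, resize, resize), used as the conditioning signal
```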
To train a LoRA on Stage C, you have a few more parameters available to set for the training.
```yaml
module_filters: ['.attn']
rank: 4
train_tokens:
  # - ['^snail', null] # existing tokens starting with "snail" ("snail" & "snails") are trained and don't need to be reinitialized
  - ['[fernando]', '^dog</w>'] # custom token [fernando], initialized as the average of tokens matching '^dog</w>'
```
These include `module_filters`, which simply determines which modules you want to train LoRA layers on. In the example above, it is the attention layers (`.attn`). Currently, only linear layers can be LoRA'd, although adding support for other layers (like convolutions) is possible as well.
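To make the `rank` parameter concrete, here is a generic sketch of the standard rank-r LoRA formulation applied to a linear layer. This is the common scheme, not the repository's exact implementation: the base weight stays frozen and only two small low-rank matrices are trained, so the number of trainable parameters grows with `rank`.

```python
import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    """Generic rank-r LoRA wrapper: y = W x + (alpha / r) * B(A(x)), with W frozen."""

    def __init__(self, base: nn.Linear, rank: int = 4, alpha: float = 1.0):
        super().__init__()
        self.base = base.requires_grad_(False)  # freeze the original linear layer
        self.scale = alpha / rank
        # Only these two low-rank matrices are trained.
        self.lora_a = nn.Parameter(torch.randn(rank, base.in_features) * 0.01)
        self.lora_b = nn.Parameter(torch.zeros(base.out_features, rank))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.base(x) + self.scale * (x @ self.lora_a.T @ self.lora_b.T)
```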
You can also set the `rank` and whether you want to learn a specific token during training. The latter is done by setting `train_tokens`, which expects two things for each element: the token you want to train and a regex for the token(s) used to initialize it. In the example above, a token `[fernando]` is created and initialized with the average of all tokens matching the regex `^dog</w>`. Note that in order to add a new token, it has to start with `[` and end with `]`. There is also the option of training existing tokens. For this, you just enter the token without placing `[ ]` around it, like in the commented example above for the token `snail`. The second element there is `null`, because we don't re-initialize this token and just finetune the existing `snail` token.
You can find an example config for training a LoRA here.
Additionally, you can download an example dataset of a cute little good boy of a dog. Simply download it and point the path in the config file to wherever you stored it.
Here we mainly focus on training Stage B, because it does most of the heavy lifting for the compression, while Stage A only applies a very small compression and its results are therefore nearly perfect. Why do we use Stage A at all? The reason is simply to make the training and inference of Stage B cheaper and faster. With Stage A in place, Stage B works at a 4x smaller spatial resolution (for example `1 x 4 x 256 x 256` instead of `1 x 3 x 1024 x 1024`). Furthermore, we observed that Stage B learns faster with Stage A than when learning directly in pixel space. So why would you even want to train Stage B? Either you want to try to achieve an even higher compression or finetune it on something very specific, but this is probably a rare occasion. If you do want to, you can take a look at the training config for the large Stage B here or for the small Stage B here.
The codebase is in early development. You might encounter unexpected errors, or training and inference code that is not yet perfectly optimized. We apologize for that in advance. If there is interest, we will continue releasing updates, aiming to bring in the latest improvements and optimizations. Moreover, we would be more than happy to receive ideas, feedback, or even contributions from people who would like to get involved. Cheers.