update flax controlnet training script (huggingface#2951)
* load_from_disk + checkpointing_steps

* apply feedback
yiyixuxu authored and Jimmy committed Apr 26, 2024
1 parent c55de4a commit cdd23c5
Showing 1 changed file with 40 additions and 9 deletions.
49 changes: 40 additions & 9 deletions examples/controlnet/train_controlnet_flax.py
@@ -27,13 +27,13 @@
 import torch
 import torch.utils.checkpoint
 import transformers
-from datasets import load_dataset
+from datasets import load_dataset, load_from_disk
 from flax import jax_utils
 from flax.core.frozen_dict import unfreeze
 from flax.training import train_state
 from flax.training.common_utils import shard
 from huggingface_hub import create_repo, upload_folder
-from PIL import Image
+from PIL import Image, PngImagePlugin
 from torch.utils.data import IterableDataset
 from torchvision import transforms
 from tqdm.auto import tqdm
@@ -49,6 +49,11 @@
 from diffusers.utils import check_min_version, is_wandb_available
 
 
+# To prevent an error that occurs when there are abnormally large compressed data chunks in the PNG image
+# see https://github.com/python-pillow/Pillow/issues/5610
+LARGE_ENOUGH_NUMBER = 100
+PngImagePlugin.MAX_TEXT_CHUNK = LARGE_ENOUGH_NUMBER * (1024**2)
+
 if is_wandb_available():
     import wandb
 
@@ -246,6 +251,12 @@ def parse_args():
         default=None,
         help="Total number of training steps to perform.",
     )
+    parser.add_argument(
+        "--checkpointing_steps",
+        type=int,
+        default=5000,
+        help=("Save a checkpoint of the training state every X updates."),
+    )
     parser.add_argument(
         "--learning_rate",
         type=float,
@@ -344,9 +355,17 @@ def parse_args():
         type=str,
         default=None,
         help=(
-            "A folder containing the training data. Folder contents must follow the structure described in"
-            " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
-            " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
+            "A folder containing the training dataset. By default it will use the `load_dataset` method to load a custom dataset from the folder."
+            " The folder must contain a dataset script as described here: https://huggingface.co/docs/datasets/dataset_script."
+            " If the `--load_from_disk` flag is passed, it will use the `load_from_disk` method instead. Ignored if `dataset_name` is specified."
         ),
     )
+    parser.add_argument(
+        "--load_from_disk",
+        action="store_true",
+        help=(
+            "If True, will load a dataset that was previously saved using `save_to_disk` from `--train_data_dir`."
+            " See more: https://huggingface.co/docs/datasets/package_reference/main_classes#datasets.Dataset.load_from_disk"
+        ),
+    )
     parser.add_argument(
@@ -478,10 +497,15 @@ def make_train_dataset(args, tokenizer, batch_size=None):
         )
     else:
         if args.train_data_dir is not None:
-            dataset = load_dataset(
-                args.train_data_dir,
-                cache_dir=args.cache_dir,
-            )
+            if args.load_from_disk:
+                dataset = load_from_disk(
+                    args.train_data_dir,
+                )
+            else:
+                dataset = load_dataset(
+                    args.train_data_dir,
+                    cache_dir=args.cache_dir,
+                )
         # See more about loading custom images at
         # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script
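For context, the new `--load_from_disk` path expects a dataset that was previously written with the `datasets` library's `save_to_disk`. A minimal sketch of preparing such a dataset (not part of this commit; the `imagefolder` loader and the `raw_data` / `my_controlnet_dataset` paths are assumptions):

from datasets import load_dataset, load_from_disk

# Build a dataset once (any loading method works) and serialize it to disk.
dataset = load_dataset("imagefolder", data_dir="raw_data", split="train")
dataset.save_to_disk("my_controlnet_dataset")

# Later, the training script reaches the same data via:
#   --train_data_dir my_controlnet_dataset --load_from_disk
# which internally resolves to:
dataset = load_from_disk("my_controlnet_dataset")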

@@ -545,6 +569,7 @@ def tokenize_captions(examples, is_train=True):
     image_transforms = transforms.Compose(
         [
             transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
+            transforms.CenterCrop(args.resolution),
             transforms.ToTensor(),
             transforms.Normalize([0.5], [0.5]),
         ]
@@ -553,6 +578,7 @@ def tokenize_captions(examples, is_train=True):
     conditioning_image_transforms = transforms.Compose(
         [
             transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
+            transforms.CenterCrop(args.resolution),
             transforms.ToTensor(),
         ]
     )
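The added `CenterCrop` makes both transform pipelines yield square `resolution`×`resolution` outputs even for non-square inputs: `Resize` scales the shorter edge to `resolution`, then the center crop trims the longer edge. A small self-contained sketch (the input file name and size are hypothetical):

from PIL import Image
from torchvision import transforms

resolution = 512
tf = transforms.Compose(
    [
        transforms.Resize(resolution, interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.CenterCrop(resolution),
    ]
)
img = Image.open("example_768x512.png")  # hypothetical 768x512 input
print(tf(img).size)  # (512, 512): shorter edge already 512, longer edge center-cropped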
@@ -981,6 +1007,11 @@ def cumul_grad_step(grad_idx, loss_grad_rng):
                         "train/loss": jax_utils.unreplicate(train_metric)["loss"],
                     }
                 )
+            if global_step % args.checkpointing_steps == 0 and jax.process_index() == 0:
+                controlnet.save_pretrained(
+                    f"{args.output_dir}/{global_step}",
+                    params=get_params_to_save(state.params),
+                )
 
         train_metric = jax_utils.unreplicate(train_metric)
         train_step_progress_bar.close()
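Each intermediate checkpoint written by the block above is a regular Flax ControlNet directory under `<output_dir>/<global_step>`, so it can be reloaded like any other pretrained ControlNet. A hedged sketch (the path is hypothetical):

from diffusers import FlaxControlNetModel

# Resume from or evaluate an intermediate checkpoint saved at step 5000 (example path).
controlnet, controlnet_params = FlaxControlNetModel.from_pretrained("output/5000")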