
Commit

Enabling gradient checkpointing for VAE (open-mmlab#2536)
* updated black format

* update black format

* make style format

* updated line endings

* update code formatting

* Update examples/research_projects/onnxruntime/text_to_image/train_text_to_image.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/diffusers/models/vae.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/diffusers/models/vae.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* added vae gradient checkpointing test

* make style

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Will Berman <wlbberman@gmail.com>
3 people committed Mar 17, 2023
1 parent a169571 commit 116f70c
Showing 5 changed files with 108 additions and 15 deletions.
@@ -412,6 +412,7 @@ def main():

    if args.gradient_checkpointing:
        unet.enable_gradient_checkpointing()
        vae.enable_gradient_checkpointing()

    # Enable TF32 for faster training on Ampere GPUs,
    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
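For illustration, a minimal sketch of how a fine-tuning script can combine the two calls above; the checkpoint id and the surrounding setup are assumptions for the example, not part of this commit:

import torch
from diffusers import AutoencoderKL, UNet2DConditionModel

# Illustrative checkpoint; substitute whatever model you are fine-tuning.
model_id = "runwayml/stable-diffusion-v1-5"
vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae")
unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet")

# Trade compute for memory: intermediate activations are recomputed during
# the backward pass instead of being kept alive.
unet.enable_gradient_checkpointing()
vae.enable_gradient_checkpointing()

# Checkpointing only takes effect in training mode, because the new
# Encoder/Decoder forwards check `self.training and self.gradient_checkpointing`.
vae.train()
unet.train()
assert vae.is_gradient_checkpointing and unet.is_gradient_checkpointing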
6 changes: 6 additions & 0 deletions src/diffusers/models/autoencoder_kl.py
@@ -65,6 +65,8 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
    Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
    """

    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
@@ -121,6 +123,10 @@ def __init__(
        self.tile_latent_min_size = int(sample_size / (2 ** (len(self.block_out_channels) - 1)))
        self.tile_overlap_factor = 0.25

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, (Encoder, Decoder)):
            module.gradient_checkpointing = value

    def enable_tiling(self, use_tiling: bool = True):
        r"""
        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
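The `_supports_gradient_checkpointing` flag and the `_set_gradient_checkpointing` hook are what `ModelMixin.enable_gradient_checkpointing()` relies on: it walks the module tree and flips the flag on every submodule the model has opted in. Roughly, as a paraphrased sketch rather than the verbatim library code:

from functools import partial

import torch


class GradientCheckpointingMixinSketch(torch.nn.Module):
    """Paraphrase of how ModelMixin dispatches to _set_gradient_checkpointing."""

    _supports_gradient_checkpointing = False

    def _set_gradient_checkpointing(self, module, value=False):
        # Concrete models (e.g. AutoencoderKL above) override this to mark the
        # submodules that actually implement a checkpointed forward.
        pass

    def enable_gradient_checkpointing(self):
        if not self._supports_gradient_checkpointing:
            raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
        # nn.Module.apply visits every submodule, so Encoder and Decoder both
        # end up with gradient_checkpointing = True through the override above.
        self.apply(partial(self._set_gradient_checkpointing, value=True))

    def disable_gradient_checkpointing(self):
        if self._supports_gradient_checkpointing:
            self.apply(partial(self._set_gradient_checkpointing, value=False))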
4 changes: 1 addition & 3 deletions src/diffusers/models/cross_attention.py
@@ -24,9 +24,7 @@
    SlicedAttnProcessor,
    XFormersAttnProcessor,
)
from .attention_processor import (  # noqa: F401
    AttnProcessor as AttnProcessorRename,
)
from .attention_processor import AttnProcessor as AttnProcessorRename  # noqa: F401


deprecate(
71 changes: 59 additions & 12 deletions src/diffusers/models/vae.py
@@ -50,7 +50,13 @@ def __init__(
        super().__init__()
        self.layers_per_block = layers_per_block

        self.conv_in = torch.nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1)
        self.conv_in = torch.nn.Conv2d(
            in_channels,
            block_out_channels[0],
            kernel_size=3,
            stride=1,
            padding=1,
        )

        self.mid_block = None
        self.down_blocks = nn.ModuleList([])
@@ -96,16 +102,34 @@ def __init__(
        conv_out_channels = 2 * out_channels if double_z else out_channels
        self.conv_out = nn.Conv2d(block_out_channels[-1], conv_out_channels, 3, padding=1)

        self.gradient_checkpointing = False

    def forward(self, x):
        sample = x
        sample = self.conv_in(sample)

        # down
        for down_block in self.down_blocks:
            sample = down_block(sample)
        if self.training and self.gradient_checkpointing:

            def create_custom_forward(module):
                def custom_forward(*inputs):
                    return module(*inputs)

                return custom_forward

            # down
            for down_block in self.down_blocks:
                sample = torch.utils.checkpoint.checkpoint(create_custom_forward(down_block), sample)

            # middle
            sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample)

        else:
            # down
            for down_block in self.down_blocks:
                sample = down_block(sample)

        # middle
        sample = self.mid_block(sample)
            # middle
            sample = self.mid_block(sample)

        # post-process
        sample = self.conv_norm_out(sample)
@@ -129,7 +153,13 @@ def __init__(
        super().__init__()
        self.layers_per_block = layers_per_block

        self.conv_in = nn.Conv2d(in_channels, block_out_channels[-1], kernel_size=3, stride=1, padding=1)
        self.conv_in = nn.Conv2d(
            in_channels,
            block_out_channels[-1],
            kernel_size=3,
            stride=1,
            padding=1,
        )

        self.mid_block = None
        self.up_blocks = nn.ModuleList([])
@@ -176,16 +206,33 @@ def __init__(
        self.conv_act = nn.SiLU()
        self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)

        self.gradient_checkpointing = False

    def forward(self, z):
        sample = z
        sample = self.conv_in(sample)

        # middle
        sample = self.mid_block(sample)
        if self.training and self.gradient_checkpointing:

        # up
        for up_block in self.up_blocks:
            sample = up_block(sample)
            def create_custom_forward(module):
                def custom_forward(*inputs):
                    return module(*inputs)

                return custom_forward

            # middle
            sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample)

            # up
            for up_block in self.up_blocks:
                sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample)
        else:
            # middle
            sample = self.mid_block(sample)

            # up
            for up_block in self.up_blocks:
                sample = up_block(sample)

        # post-process
        sample = self.conv_norm_out(sample)
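The `create_custom_forward` closure is there because `torch.utils.checkpoint.checkpoint` takes a callable plus its tensor arguments, discards the intermediate activations on the forward pass, and re-runs the callable during backward to recover them. The same pattern in isolation, on a toy module (names are illustrative, not from this repository):

import torch
import torch.nn as nn
import torch.utils.checkpoint


class ToyEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.blocks = nn.ModuleList([nn.Sequential(nn.Linear(16, 16), nn.SiLU()) for _ in range(4)])
        self.gradient_checkpointing = False

    def forward(self, x):
        if self.training and self.gradient_checkpointing:

            def create_custom_forward(module):
                def custom_forward(*inputs):
                    return module(*inputs)

                return custom_forward

            for block in self.blocks:
                # Activations inside `block` are not stored; they are recomputed
                # during backward, trading extra compute for lower peak memory.
                x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x)
        else:
            for block in self.blocks:
                x = block(x)
        return x


model = ToyEncoder()
model.gradient_checkpointing = True
model.train()
out = model(torch.randn(2, 16, requires_grad=True))
out.sum().backward()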
41 changes: 41 additions & 0 deletions tests/models/test_models_vae.py
@@ -68,6 +68,47 @@ def test_forward_signature(self):
    def test_training(self):
        pass

    @unittest.skipIf(torch_device == "mps", "Gradient checkpointing skipped on MPS")
    def test_gradient_checkpointing(self):
        # enable deterministic behavior for gradient checkpointing
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**init_dict)
        model.to(torch_device)

        assert not model.is_gradient_checkpointing and model.training

        out = model(**inputs_dict).sample
        # Run the backward pass. For simplicity we don't compute a real loss;
        # we backprop on the mean difference from random targets.
        model.zero_grad()

        labels = torch.randn_like(out)
        loss = (out - labels).mean()
        loss.backward()

        # re-instantiate the model, now enabling gradient checkpointing
        model_2 = self.model_class(**init_dict)
        # clone the weights so both models start identical
        model_2.load_state_dict(model.state_dict())
        model_2.to(torch_device)
        model_2.enable_gradient_checkpointing()

        assert model_2.is_gradient_checkpointing and model_2.training

        out_2 = model_2(**inputs_dict).sample
        # run the backward pass on the second model in the same way
        model_2.zero_grad()
        loss_2 = (out_2 - labels).mean()
        loss_2.backward()

        # compare the losses and the parameter gradients
        self.assertTrue((loss - loss_2).abs() < 1e-5)
        named_params = dict(model.named_parameters())
        named_params_2 = dict(model_2.named_parameters())
        for name, param in named_params.items():
            self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=5e-5))

    def test_from_pretrained_hub(self):
        model, loading_info = AutoencoderKL.from_pretrained("fusing/autoencoder-kl-dummy", output_loading_info=True)
        self.assertIsNotNone(model)
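The invariant the test checks (the same weights must give numerically matching losses and gradients with and without checkpointing) can also be reproduced outside the test harness with plain PyTorch; the default AutoencoderKL config and the tolerances below are assumptions mirroring the test, not part of this commit:

import torch

from diffusers import AutoencoderKL

torch.manual_seed(0)
sample = torch.randn(1, 3, 32, 32)

model = AutoencoderKL().train()
model_ckpt = AutoencoderKL().train()
model_ckpt.load_state_dict(model.state_dict())  # identical weights
model_ckpt.enable_gradient_checkpointing()

loss = model(sample).sample.mean()
loss.backward()
loss_ckpt = model_ckpt(sample).sample.mean()
loss_ckpt.backward()

assert torch.isclose(loss, loss_ckpt, atol=1e-5)
for (name, p), (_, p_ckpt) in zip(model.named_parameters(), model_ckpt.named_parameters()):
    assert torch.allclose(p.grad, p_ckpt.grad, atol=5e-5), name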
