AutoencoderKLWan - support gradient_checkpointing #11071
Hey @agwmon! We'd love it if you could contribute the changes. Any of the other modeling implementations are good examples of how to apply it. I think you will just have to call the …
@a-r-r-o-w When x = self._gradient_checkpointing_func(resnet, x, feat_cache, feat_idx) is called, this function invokes the forward method of the ResidualBlock. The problem is that the forward method increments feat_idx in place (see diffusers/src/diffusers/models/autoencoders/autoencoder_kl_wan.py, lines 239 to 279 at 3be6706).
I'm not sure how to resolve this issue and would appreciate some guidance. Is there a recommended way to handle the feat_idx incrementing when using gradient checkpointing? Thank you for your help!
Hmm, there's not really an easy way around this from a quick look. I believe what we're doing in this code is frame-wise forwards. We don't really need a cache here, but it saves some amount of computation to speed up decoding a bit. A VAE refactor might be needed here, or at least handling feat_idx without in-place mutation. @yiyixuxu Do you have plans to refactor this (as we discussed that we'll merge the Wan PR quickly and refactor later), or should I take a stab at it?
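Handling feat_idx without in-place mutation could look like the following rough sketch: the block returns the advanced index instead of mutating a shared list, so recomputation during the backward pass is side-effect free with respect to the index. ResidualBlock and run_block here are simplified toy stand-ins, not the actual Wan implementation:

```python
import torch
from torch.utils.checkpoint import checkpoint

class ResidualBlock(torch.nn.Module):
    """Toy stand-in for the Wan VAE block (hypothetical, for illustration)."""

    def __init__(self, dim: int):
        super().__init__()
        self.conv = torch.nn.Conv1d(dim, dim, 3, padding=1)

    def forward(self, x, feat_cache, feat_idx):
        # Write to the cache at the current index; detach so the cached
        # activation carries no autograd history. Re-running this on
        # recomputation rewrites the same slot, which is harmless.
        feat_cache[feat_idx] = x.detach()
        # Return the new index instead of doing feat_idx[0] += 1 in place.
        return x + self.conv(x), feat_idx + 1

def run_block(block, x, feat_cache, feat_idx, use_checkpoint):
    if use_checkpoint and torch.is_grad_enabled():
        # use_reentrant=False supports non-tensor inputs/outputs
        # such as the plain int index returned alongside the tensor.
        return checkpoint(block, x, feat_cache, feat_idx, use_reentrant=False)
    return block(x, feat_cache, feat_idx)

# Usage: the caller threads the index through instead of sharing a list.
block = ResidualBlock(4)
x = torch.randn(1, 4, 8, requires_grad=True)
cache = {}
out, new_idx = run_block(block, x, cache, 0, use_checkpoint=True)
out.sum().backward()
```

With this shape, the index advances exactly once per block per training step regardless of how many times the forward is re-executed, at the cost of touching every call site that currently relies on the mutated list.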
Do you have a plan for supporting gradient checkpointing for AutoencoderKLWan?
Thank you for always working hard for open source 🙏🙏