AutoencoderKLWan - support gradient_checkpointing #11071

Open · agwmon opened this issue Mar 16, 2025 · 3 comments · May be fixed by #11105

agwmon commented Mar 16, 2025

Do you have a plan for supporting gradient checkpointing for AutoencoderKLWan?

Thank you for always working hard for open source 🙏🙏

a-r-r-o-w (Member) commented

Hey @agwmon! We'd love it if you could contribute the changes. Any of the other modeling implementations are good examples of how to apply it. I think you will just have to wrap the up/down/mid and resnet blocks with `self._gradient_checkpointing_func`.
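For reference, here is a minimal runnable sketch of that pattern on a toy module. It assumes `self._gradient_checkpointing_func` behaves like `torch.utils.checkpoint.checkpoint` (which I believe it wraps); the module and names are illustrative, not the actual Wan VAE code:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class ToyDecoder(nn.Module):
    # Toy stand-in for a decoder with resnet-style blocks.
    def __init__(self):
        super().__init__()
        self.resnets = nn.ModuleList([nn.Conv2d(4, 4, 3, padding=1) for _ in range(3)])
        self.gradient_checkpointing = True

    def forward(self, x):
        for resnet in self.resnets:
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                # Recompute this block during backward instead of storing
                # its activations, trading compute for memory.
                x = checkpoint(resnet, x, use_reentrant=False)
            else:
                x = resnet(x)
        return x

model = ToyDecoder()
out = model(torch.randn(1, 4, 8, 8, requires_grad=True))
out.sum().backward()
```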

victolee0 (Contributor) commented Mar 18, 2025

@a-r-r-o-w
I'm encountering an error when running the test code in my PR with gradient checkpointing enabled.

When backward() is called, the following function is executed:

`x = self._gradient_checkpointing_func(resnet, x, feat_cache, feat_idx)`

This call runs the forward method of `ResidualBlock`. The problem is that the forward method includes `feat_idx[0] += 1`: because checkpointing re-runs the forward during backward, the shared counter is incremented again, which eventually pushes `feat_idx[0]` past the end of `feat_cache` and raises an `IndexError`. (A standalone repro of this double-increment follows the snippet below.)

```python
def forward(self, x, feat_cache=None, feat_idx=[0]):
    # Apply shortcut connection
    h = self.conv_shortcut(x)
    # First normalization and activation
    x = self.norm1(x)
    x = self.nonlinearity(x)
    if feat_cache is not None:
        idx = feat_idx[0]
        cache_x = x[:, :, -CACHE_T:, :, :].clone()
        if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
            cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
        x = self.conv1(x, feat_cache[idx])
        feat_cache[idx] = cache_x
        feat_idx[0] += 1
    else:
        x = self.conv1(x)
    # Second normalization and activation
    x = self.norm2(x)
    x = self.nonlinearity(x)
    # Dropout
    x = self.dropout(x)
    if feat_cache is not None:
        idx = feat_idx[0]
        cache_x = x[:, :, -CACHE_T:, :, :].clone()
        if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
            cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
        x = self.conv2(x, feat_cache[idx])
        feat_cache[idx] = cache_x
        feat_idx[0] += 1
    else:
        x = self.conv2(x)
    # Add residual connection
    return x + h
```
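The double-increment can be reproduced standalone, independent of the VAE. A minimal sketch (toy function and names, assuming the non-reentrant `torch.utils.checkpoint` backend):

```python
import torch
from torch.utils.checkpoint import checkpoint

def fn(x, counter=[0]):
    counter[0] += 1  # in-place mutation of shared state, like feat_idx[0] += 1
    return x * 2

x = torch.randn(2, requires_grad=True)
y = checkpoint(fn, x, use_reentrant=False)
y.sum().backward()

# fn ran twice: once in the forward pass and once more when checkpointing
# recomputed it during backward, so the shared counter is now [2], not [1].
# With a bounded feat_cache, these extra increments overrun the list.
print(fn.__defaults__[0])
```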

I'm not sure how to resolve this issue and would appreciate some guidance. Is there a recommended way to handle the feat_idx incrementing when using gradient checkpointing?

Thank you for your help!

a-r-r-o-w (Member) commented

Hmm, there's not really an easy way around this from a quick look. I believe what we're doing in this code is frame-wise forwards. We don't strictly need the cache here, but it saves some computation and speeds up decoding a bit.

A VAE refactor might be needed here, or at least handling `feat_idx` without in-place mutation. @yiyixuxu Do you have plans to refactor this (as we discussed that we'd merge the Wan PR quickly and refactor later), or should I take a stab at it?
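One possible shape for the "not in-place" handling, sketched on a toy function (an assumption about the refactor, not the actual fix in #11105): freeze the index for each checkpointed call and advance the shared counter outside it, so the backward-time recompute re-reads the same cache slot.

```python
import torch
from torch.utils.checkpoint import checkpoint

def block(x, cache, idx):
    cache[idx] = x.detach().clone()  # deterministic write; redoing it on recompute is harmless
    return x * 2

x = torch.randn(2, requires_grad=True)
cache, feat_idx = [None, None], [0]

# Pass the index by value and advance the shared counter outside the
# checkpointed region, so the recompute during backward sees the same idx.
y = checkpoint(block, x, cache, feat_idx[0], use_reentrant=False)
feat_idx[0] += 1

y.sum().backward()  # block reruns with the same frozen idx; no IndexError
```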
