[feat] stream vae #582

Merged
helloyongyang merged 4 commits into main from dev/stream_vae on Dec 9, 2025

Conversation


@Musisoul Musisoul commented Dec 9, 2025

No description provided.

@Musisoul changed the title from Dev/stream vae to [feat] stream vae on Dec 9, 2025
@gemini-code-assist (Contributor)

Summary of Changes

Hello @Musisoul, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces significant enhancements to the VAE decoding process by implementing streaming capabilities. The primary goal is to improve memory efficiency and potentially enable real-time or near real-time video generation by decoding frames incrementally rather than processing an entire video sequence at once. It also lays the groundwork for distributed parallel decoding of large latent spaces.

Highlights

  • Streaming VAE Integration: The DefaultRunner now supports a streaming VAE decoding path, activated by the use_stream_vae configuration flag. This allows for iterative decoding of latent representations into video frames.
  • New run_vae_decoder_stream Method: A new method run_vae_decoder_stream has been added to DefaultRunner to manage the streaming decoding process, including concatenating frame segments and integrating profiling and lazy loading/unloading of the VAE decoder.
  • Core VAE Streaming Logic: The VAE model in vae.py now includes a decode_stream method that yields decoded frame segments, enabling a memory-efficient, frame-by-frame processing of latent inputs.
  • Distributed 2D Streaming Decoding: A decode_dist_2d_stream method has been introduced for distributed processing of 2D latent chunks, handling padding, extraction, and reconstruction of images across multiple processes.
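
Concretely, the streaming path described in the highlights amounts to a generator that yields decoded segments and a runner-side loop that stitches them together. A minimal pure-Python sketch of that pattern (plain lists stand in for tensors, and `decode_stream` / `run_vae_decoder_stream` here are illustrative stand-ins, not the repo's actual implementations):

```python
def decode_stream(latent_chunks):
    """Yield decoded frame segments one at a time instead of decoding
    the whole latent sequence in a single pass (memory-efficient)."""
    for chunk in latent_chunks:
        # Stand-in for the per-chunk VAE decode; here just a transform.
        yield [x * 2 for x in chunk]


def run_vae_decoder_stream(latent_chunks):
    """Concatenate streamed segments into the full frame sequence,
    mirroring the runner-side stitching the summary describes."""
    frames = []
    for segment in decode_stream(latent_chunks):
        frames.extend(segment)  # torch.cat along the frame dim in the real code
    return frames


print(run_vae_decoder_stream([[1, 2], [3], [4, 5]]))  # [2, 4, 6, 8, 10]
```

The payoff is that peak memory scales with one segment rather than the whole video, at the cost of per-segment overhead on each iteration.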
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

| Feature | Command | Description |
| --- | --- | --- |
| Code Review | /gemini review | Performs a code review for the current pull request in its current state. |
| Pull Request Summary | /gemini summary | Provides a summary of the current pull request in its current state. |
| Comment | @gemini-code-assist | Responds in comments when explicitly tagged, both in pull request comments and review comments. |
| Help | /gemini help | Displays a list of available commands. |

Customization

To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, and code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

@gemini-code-assist bot left a comment

Code Review

This pull request introduces a streaming VAE decoder to reduce memory usage during video generation. The changes look good overall, adding decode_stream methods to the VAE model and a new execution path in the runner. I've identified a typo, an opportunity to reduce code duplication for better maintainability, and a significant performance concern in the distributed decoding implementation where all_gather is called per frame. Addressing these points will improve the code's quality and performance.

Comment on lines +1347 to +1392
```python
for image in self.model.decode_stream(zs_chunk.unsqueeze(0), self.scale):
    images_chunk = image.clamp_(-1, 1)
    # Remove padding from decoded chunk
    spatial_ratio = 8
    if cur_rank_h == 0:
        decoded_h_start = 0
        decoded_h_end = chunk_h * spatial_ratio
    elif cur_rank_h == world_size_h - 1:
        decoded_h_start = images_chunk.shape[3] - chunk_h * spatial_ratio
        decoded_h_end = images_chunk.shape[3]
    else:
        decoded_h_start = padding_size * spatial_ratio
        decoded_h_end = images_chunk.shape[3] - padding_size * spatial_ratio

    if cur_rank_w == 0:
        decoded_w_start = 0
        decoded_w_end = chunk_w * spatial_ratio
    elif cur_rank_w == world_size_w - 1:
        decoded_w_start = images_chunk.shape[4] - chunk_w * spatial_ratio
        decoded_w_end = images_chunk.shape[4]
    else:
        decoded_w_start = padding_size * spatial_ratio
        decoded_w_end = images_chunk.shape[4] - padding_size * spatial_ratio

    images_chunk = images_chunk[:, :, :, decoded_h_start:decoded_h_end, decoded_w_start:decoded_w_end].contiguous()

    # Gather all chunks
    total_processes = world_size_h * world_size_w
    full_images = [torch.empty_like(images_chunk) for _ in range(total_processes)]

    dist.all_gather(full_images, images_chunk)

    self.device_synchronize()

    # Reconstruct the full image tensor
    image_rows = []
    for h_idx in range(world_size_h):
        image_cols = []
        for w_idx in range(world_size_w):
            process_idx = h_idx * world_size_w + w_idx
            image_cols.append(full_images[process_idx])
        image_rows.append(torch.cat(image_cols, dim=4))

    images = torch.cat(image_rows, dim=3)

    yield images
```

Severity: high

The dist.all_gather call is inside a for loop that iterates over individual frames yielded by self.model.decode_stream. Since dist.all_gather is a blocking collective communication operation that synchronizes all processes, performing this for every single frame will introduce significant communication overhead and likely degrade performance substantially in a distributed environment. This could make the streaming decode much slower than its non-streaming counterpart.

A possible optimization is to batch the frames before gathering. You could accumulate a small number of frames from the stream and then perform the all_gather and reconstruction logic on that batch, reducing the frequency of expensive collective communication calls.
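A sketch of that batching, using a hypothetical `batched_gather_stream` helper: `gather_fn` stands in for the `dist.all_gather`-plus-reconstruction step, and `batch_size` is an assumed tuning knob, neither of which exists in the PR as written.

```python
def batched_gather_stream(frame_stream, gather_fn, batch_size=8):
    """Accumulate frames from the decode stream and gather them in
    batches, so the expensive collective runs once per batch instead
    of once per frame."""
    batch = []
    for frame in frame_stream:
        batch.append(frame)
        if len(batch) == batch_size:
            yield gather_fn(batch)  # one collective call per batch
            batch = []
    if batch:  # flush any remaining frames
        yield gather_fn(batch)
```

The trade-off is a small amount of extra latency and memory (up to `batch_size` frames held locally) in exchange for far fewer synchronization points across ranks.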

Comment on lines +349 to +358
```python
if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
    self.vae_decoder = self.load_vae_decoder()

for frame_segment in self.vae_decoder.decode_stream(latents.to(GET_DTYPE())):
    yield frame_segment

if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
    del self.vae_decoder
    torch.cuda.empty_cache()
    gc.collect()
```

Severity: medium

The logic for loading and unloading the VAE decoder is duplicated here and in the run_vae_decoder method. This duplication can lead to maintenance issues and violates the Don't Repeat Yourself (DRY) principle.

To improve maintainability, consider extracting this common logic into a context manager. For example:

```python
from contextlib import contextmanager

@contextmanager
def _managed_vae_decoder(self):
    should_manage = self.config.get("lazy_load", False) or self.config.get("unload_modules", False)
    if should_manage:
        self.vae_decoder = self.load_vae_decoder()
    try:
        yield
    finally:
        if should_manage:
            del self.vae_decoder
            torch.cuda.empty_cache()
            gc.collect()
```

You could then simplify both run_vae_decoder_stream and run_vae_decoder by using this context manager.
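For illustration, here is a runnable sketch of how the streaming method could collapse under such a context manager. The `Runner` class, its stub loader, and the decoder interface are all hypothetical stand-ins for the repo's actual classes, and the GPU cache cleanup (`torch.cuda.empty_cache()`, `gc.collect()`) is omitted to keep the sketch self-contained.

```python
from contextlib import contextmanager

class Runner:
    def __init__(self, config, decoder):
        self.config = config
        self._decoder = decoder  # stands in for the real loading machinery

    def load_vae_decoder(self):
        return self._decoder  # stub loader for illustration

    @contextmanager
    def _managed_vae_decoder(self):
        # Same load/unload logic the review suggests; GPU cache cleanup omitted.
        should_manage = self.config.get("lazy_load", False) or self.config.get("unload_modules", False)
        if should_manage:
            self.vae_decoder = self.load_vae_decoder()
        try:
            yield
        finally:
            if should_manage:
                del self.vae_decoder

    def run_vae_decoder_stream(self, latents):
        # The duplicated load/unload blocks reduce to a single `with`.
        with self._managed_vae_decoder():
            for frame_segment in self.vae_decoder.decode_stream(latents):
                yield frame_segment
```

Because the cleanup sits in a `finally`, the decoder is released even if the consumer abandons the stream early and the generator is closed.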

@helloyongyang helloyongyang merged commit 5546f75 into main Dec 9, 2025
2 checks passed
@helloyongyang helloyongyang deleted the dev/stream_vae branch January 7, 2026 10:09
helloyongyang pushed a commit that referenced this pull request Mar 6, 2026