Conversation
Summary of Changes

Hello @Musisoul, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request introduces significant enhancements to the VAE decoding process by implementing streaming capabilities. The primary goal is to improve memory efficiency and potentially enable real-time or near real-time video generation by decoding frames incrementally rather than processing an entire video sequence at once. It also lays the groundwork for distributed parallel decoding of large latent spaces.

Highlights
Code Review
This pull request introduces a streaming VAE decoder to reduce memory usage during video generation. The changes look good overall, adding decode_stream methods to the VAE model and a new execution path in the runner. I've identified a typo, an opportunity to reduce code duplication for better maintainability, and a significant performance concern in the distributed decoding implementation where all_gather is called per frame. Addressing these points will improve the code's quality and performance.
for image in self.model.decode_stream(zs_chunk.unsqueeze(0), self.scale):
    images_chunk = image.clamp_(-1, 1)
    # Remove padding from decoded chunk
    spatial_ratio = 8
    if cur_rank_h == 0:
        decoded_h_start = 0
        decoded_h_end = chunk_h * spatial_ratio
    elif cur_rank_h == world_size_h - 1:
        decoded_h_start = images_chunk.shape[3] - chunk_h * spatial_ratio
        decoded_h_end = images_chunk.shape[3]
    else:
        decoded_h_start = padding_size * spatial_ratio
        decoded_h_end = images_chunk.shape[3] - padding_size * spatial_ratio

    if cur_rank_w == 0:
        decoded_w_start = 0
        decoded_w_end = chunk_w * spatial_ratio
    elif cur_rank_w == world_size_w - 1:
        decoded_w_start = images_chunk.shape[4] - chunk_w * spatial_ratio
        decoded_w_end = images_chunk.shape[4]
    else:
        decoded_w_start = padding_size * spatial_ratio
        decoded_w_end = images_chunk.shape[4] - padding_size * spatial_ratio

    images_chunk = images_chunk[:, :, :, decoded_h_start:decoded_h_end, decoded_w_start:decoded_w_end].contiguous()

    # Gather all chunks
    total_processes = world_size_h * world_size_w
    full_images = [torch.empty_like(images_chunk) for _ in range(total_processes)]

    dist.all_gather(full_images, images_chunk)

    self.device_synchronize()

    # Reconstruct the full image tensor
    image_rows = []
    for h_idx in range(world_size_h):
        image_cols = []
        for w_idx in range(world_size_w):
            process_idx = h_idx * world_size_w + w_idx
            image_cols.append(full_images[process_idx])
        image_rows.append(torch.cat(image_cols, dim=4))

    images = torch.cat(image_rows, dim=3)

    yield images
The dist.all_gather call is inside a for loop that iterates over individual frames yielded by self.model.decode_stream. Since dist.all_gather is a blocking collective communication operation that synchronizes all processes, performing this for every single frame will introduce significant communication overhead and likely degrade performance substantially in a distributed environment. This could make the streaming decode much slower than its non-streaming counterpart.
A possible optimization is to batch the frames before gathering. You could accumulate a small number of frames from the stream and then perform the all_gather and reconstruction logic on that batch, reducing the frequency of expensive collective communication calls.
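The batching idea can be sketched as a small generic helper. `batched_stream_gather`, `gather_fn`, and `batch_size` are hypothetical names, and `gather_fn` stands in for the per-batch `dist.all_gather` call plus the row/column reconstruction step:

```python
def batched_stream_gather(frame_stream, gather_fn, batch_size=4):
    """Yield gathered results for batches of frames instead of one frame at a time.

    frame_stream: an iterator of per-rank frame chunks (e.g. from decode_stream).
    gather_fn: called once per batch; in the real code this would stack the
               buffered frames, run dist.all_gather on the stacked tensor, and
               reconstruct the full frames before yielding them.
    """
    buffer = []
    for frame in frame_stream:
        buffer.append(frame)
        if len(buffer) == batch_size:
            # One collective call now covers batch_size frames
            yield gather_fn(buffer)
            buffer = []
    if buffer:
        # Flush the final partial batch so no frames are dropped
        yield gather_fn(buffer)
```

With `batch_size=4`, the number of `all_gather` calls drops by roughly 4x, at the cost of buffering a few decoded frames per rank before they are emitted.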
if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
    self.vae_decoder = self.load_vae_decoder()

for frame_segment in self.vae_decoder.decode_stream(latents.to(GET_DTYPE())):
    yield frame_segment

if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
    del self.vae_decoder
    torch.cuda.empty_cache()
    gc.collect()
The logic for loading and unloading the VAE decoder is duplicated here and in the run_vae_decoder method. This duplication can lead to maintenance issues and violates the Don't Repeat Yourself (DRY) principle.
To improve maintainability, consider extracting this common logic into a context manager. For example:
from contextlib import contextmanager

@contextmanager
def _managed_vae_decoder(self):
    should_manage = self.config.get("lazy_load", False) or self.config.get("unload_modules", False)
    if should_manage:
        self.vae_decoder = self.load_vae_decoder()
    try:
        yield
    finally:
        if should_manage:
            del self.vae_decoder
            torch.cuda.empty_cache()
            gc.collect()

You could then simplify both run_vae_decoder_stream and run_vae_decoder by using this context manager.
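A minimal, self-contained sketch of how the simplified streaming method could look. The `RunnerSketch` class, the stubbed `load_vae_decoder`, and the load/unload counters are hypothetical stand-ins so the flow can be exercised without a real model; the real teardown would also call `torch.cuda.empty_cache()`:

```python
import gc
from contextlib import contextmanager

class RunnerSketch:
    """Hypothetical stand-in for the runner; only sketches the managed-decoder flow."""

    def __init__(self, config):
        self.config = config
        self.load_calls = 0
        self.unload_calls = 0

    def load_vae_decoder(self):
        # Stub: the real method would construct and return the VAE decoder module
        self.load_calls += 1
        return object()

    @contextmanager
    def _managed_vae_decoder(self):
        should_manage = self.config.get("lazy_load", False) or self.config.get("unload_modules", False)
        if should_manage:
            self.vae_decoder = self.load_vae_decoder()
        try:
            yield
        finally:
            if should_manage:
                # The real version would also call torch.cuda.empty_cache()
                del self.vae_decoder
                gc.collect()
                self.unload_calls += 1

    def run_vae_decoder_stream(self, latents):
        # Streaming and non-streaming paths can now share the load/unload logic
        with self._managed_vae_decoder():
            for frame_segment in latents:  # stands in for decode_stream(...)
                yield frame_segment
```

Because the teardown sits in a `finally` block, the decoder is released even if the consumer abandons the generator early, which the original duplicated inline code does not guarantee.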