Conversation
Code Review
This pull request introduces support for the DCU (AMD GPU) platform by adding a DcuDevice class that utilizes ROCm's CUDA compatibility and a DCU-specific implementation for Flash Attention. The changes are a good step towards broader hardware support. However, my review has identified a few critical issues that need to be addressed. There's an import error due to a missing module for mm.dcu. More importantly, the new FlashAttnDcu implementation has critical bugs in both its main logic and its fallback path concerning the handling of variable-length sequences, which will lead to incorrect results or runtime errors. I have also included some medium-severity suggestions to improve code clarity, maintainability, and memory efficiency.
```python
# Check if running on DCU platform
if os.getenv("PLATFORM") == "dcu":
    from .attn.dcu import *
    from .mm.dcu import *
```
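Since these wildcard imports fail at package-import time when any target module is absent (the missing `mm.dcu` module flagged below), one defensive pattern is to probe each platform submodule explicitly and report what is missing. A minimal, hypothetical sketch — the helper name and the `attn`/`mm` layout mirror the snippet above, and are not part of this PR:

```python
import importlib
import os


def load_platform_modules(package, platform=None):
    """Import per-platform submodules, reporting any that are missing.

    Hypothetical illustration: probes ``<package>.attn.<platform>`` and
    ``<package>.mm.<platform>`` instead of wildcard-importing them blindly.
    """
    platform = platform or os.getenv("PLATFORM", "cuda")
    loaded, missing = [], []
    for sub in ("attn", "mm"):
        name = f"{package}.{sub}.{platform}"
        try:
            importlib.import_module(name)
            loaded.append(name)
        except ModuleNotFoundError:
            missing.append(name)
    return loaded, missing
```

This keeps the failure mode explicit: a missing backend module produces a clear list rather than an opaque `ImportError` the first time the package is imported.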
```python
output = output.unflatten(0, (b, lq))
return output.to(out_dtype)
```
The output from flash_attn_varlen_func is a packed tensor of shape (total_tokens, ...) when variable sequence lengths are used. The current code uses output.unflatten(0, (b, lq)), which assumes the total number of tokens is b * lq. This is only correct for fixed-length sequences and will cause a runtime error or produce incorrect results for variable-length sequences.
You need to correctly unpack the output into a padded tensor when q_lens is provided.
```python
if q_lens is None:
    output = output.unflatten(0, (b, lq))
else:
    output_padded = torch.zeros(b, lq, q.shape[2], v.shape[3], dtype=output.dtype, device=output.device)
    current_pos = 0
    for i in range(b):
        seq_len = q_lens[i]
        output_padded[i, :seq_len] = output[current_pos : current_pos + seq_len]
        current_pos += seq_len
    output = output_padded
return output.to(out_dtype)
```

```python
def _sdpa_fallback(self, q, k, v, causal=False, dropout_p=0.0):
    """
    Fallback to PyTorch Scaled Dot Product Attention.

    Args:
        q: [B, Lq, Nq, C] Query tensor
        k: [B, Lk, Nk, C] Key tensor
        v: [B, Lk, Nk, C] Value tensor
        causal: Whether to apply causal mask
        dropout_p: Dropout probability

    Returns:
        Output tensor: [B, Lq, Nq, C]
    """
    # Transpose to [B, Nq, Lq, C] for SDPA
    q = q.transpose(1, 2)
    k = k.transpose(1, 2)
    v = v.transpose(1, 2)

    out = torch.nn.functional.scaled_dot_product_attention(
        q, k, v, attn_mask=None, is_causal=causal, dropout_p=dropout_p
    )

    # Transpose back to [B, Lq, Nq, C]
    return out.transpose(1, 2).contiguous()
```
The _sdpa_fallback method does not handle variable sequence lengths passed via q_lens and k_lens. The main apply method handles them, but if it falls back to SDPA, this information is lost, leading to incorrect attention calculations for batches with padded sequences.
To fix this, the fallback method needs to accept q_lens and k_lens, construct an appropriate attn_mask for scaled_dot_product_attention, and handle the output padding. Note that scaled_dot_product_attention does not allow using both is_causal=True and attn_mask simultaneously, so you may need to create a combined causal and padding mask if both are required.
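A sketch of what such a mask-aware fallback could look like. The function name, signature, and mask construction are illustrative assumptions, not the PR's actual API; the key point is folding the causal constraint and the key-length padding into one boolean `attn_mask`, since SDPA rejects `is_causal=True` together with `attn_mask`:

```python
import torch
import torch.nn.functional as F


def sdpa_with_lens(q, k, v, q_lens=None, k_lens=None, causal=False, dropout_p=0.0):
    # q: [B, Lq, Nq, C]; k, v: [B, Lk, Nk, C]
    b, lq, lk = q.shape[0], q.shape[1], k.shape[1]
    # Boolean mask broadcastable to [B, Nq, Lq, Lk]; True means "attend".
    mask = torch.ones(b, 1, lq, lk, dtype=torch.bool, device=q.device)
    if k_lens is not None:
        # Block attention to padded key positions of each sequence.
        mask = mask & (torch.arange(lk, device=q.device).view(1, 1, 1, lk) < k_lens.view(b, 1, 1, 1))
    if causal:
        # SDPA forbids is_causal=True together with attn_mask, so fold the
        # causal constraint into the same boolean mask.
        mask = mask & torch.tril(torch.ones(lq, lk, dtype=torch.bool, device=q.device))
    out = F.scaled_dot_product_attention(
        q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2),
        attn_mask=mask, dropout_p=dropout_p,
    )
    out = out.transpose(1, 2).contiguous()  # back to [B, Lq, Nq, C]
    if q_lens is not None:
        # Zero rows belonging to padded query positions.
        out = out * (torch.arange(lq, device=q.device).view(1, lq, 1, 1) < q_lens.view(b, 1, 1, 1))
    return out
```

One caveat worth noting: a sequence with `k_lens[i] == 0` would leave a fully masked row and produce NaNs, so callers would need to guard against empty sequences.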
```diff
 from lightx2v_platform.base.nvidia import CudaDevice

-__all__ = ["init_ai_device", "check_ai_device", "CudaDevice", "MluDevice", "MetaxDevice"]
+__all__ = ["init_ai_device", "check_ai_device", "CudaDevice", "MluDevice", "MetaxDevice", "DcuDevice"]
```
For better readability and maintainability, it's good practice to keep the items in __all__ sorted alphabetically. This makes it easier to find items in the list as it grows.
```diff
-__all__ = ["init_ai_device", "check_ai_device", "CudaDevice", "MluDevice", "MetaxDevice", "DcuDevice"]
+__all__ = ["init_ai_device", "check_ai_device", "CudaDevice", "DcuDevice", "MetaxDevice", "MluDevice"]
```
```python
try:
    return torch.cuda.is_available()
except ImportError:
    return False
```
The try...except ImportError block is unnecessary here. torch is imported at the top of the file, so an ImportError for torch would have already been raised. The torch.cuda.is_available() call itself does not raise an ImportError; it safely returns False if CUDA is not available. Removing this block simplifies the code.
```python
return torch.cuda.is_available()
```

```python
    q_flat = half(q.flatten(0, 1))
    q_lens = torch.tensor([lq] * b, dtype=torch.int32, device=q.device)
else:
    q_flat = half(torch.cat([u[:v] for u, v in zip(q, q_lens)]))
```
When preparing q_flat for variable-length sequences, you are concatenating tensors first and then converting the large concatenated tensor to half precision with half(). This can be memory-inefficient if the input tensor q is in full precision, as it creates a large intermediate tensor. For better memory efficiency, consider applying the half() conversion to each slice inside the list comprehension before concatenation. The same applies to k_flat and v_flat.
```diff
-q_flat = half(torch.cat([u[:v] for u, v in zip(q, q_lens)]))
+q_flat = torch.cat([half(u[:v]) for u, v in zip(q, q_lens)])
```
```python
k_flat = half(torch.cat([u[:v] for u, v in zip(k, k_lens)]))
v_flat = half(torch.cat([u[:v] for u, v in zip(v, k_lens)]))
```
Similar to the creation of q_flat, k_flat and v_flat are created by first concatenating full-precision tensors and then converting. This is memory-inefficient. Applying half() inside the list comprehension is recommended for better memory usage.
```diff
-k_flat = half(torch.cat([u[:v] for u, v in zip(k, k_lens)]))
-v_flat = half(torch.cat([u[:v] for u, v in zip(v, k_lens)]))
+k_flat = torch.cat([half(u[:v]) for u, v in zip(k, k_lens)])
+v_flat = torch.cat([half(u[:v]) for u, v in zip(v, k_lens)])
```
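The two orderings are easy to compare side by side: they produce identical values, but converting per-slice keeps the concatenation intermediate in half precision rather than full precision. A small sketch with illustrative names (assuming `half` is a plain cast to `float16`, as the review implies):

```python
import torch


def pack_then_half(x, lens):
    # Concatenate full-precision slices, then downcast:
    # the intermediate packed tensor is fp32 (larger peak memory).
    return torch.cat([u[:n] for u, n in zip(x, lens)]).half()


def half_then_pack(x, lens):
    # Downcast each slice first, then concatenate:
    # the intermediate packed tensor is already fp16.
    return torch.cat([u[:n].half() for u, n in zip(x, lens)])
```

For a packed tensor of `total_tokens` rows, the first version briefly holds those rows in fp32 before the cast, i.e. twice the bytes of the second.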