add dcu platform #584

Merged
helloyongyang merged 1 commit into ModelTC:main from fuheaven:dcu
Dec 9, 2025

Conversation

@fuheaven (Contributor) commented Dec 9, 2025

No description provided.

@gemini-code-assist (Contributor)

Summary of Changes

Hello @fuheaven, I'm Gemini Code Assist[1]! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly expands the lightx2v_platform's hardware compatibility by integrating support for AMD's DCU platform. It establishes the necessary infrastructure to detect and configure DCU devices, ensuring seamless interaction with PyTorch through ROCm's CUDA compatibility. Furthermore, it introduces performance-critical optimizations, such as a dedicated Flash Attention implementation for DCU, enhancing the framework's efficiency on AMD hardware.

Highlights

  • DCU Platform Support: Introduced a new DcuDevice class to enable detection and management of AMD DCU (Data Center GPU) devices within the lightx2v_platform.
  • ROCm/HIP Compatibility: Leverages ROCm's HIP compatibility layer, allowing DCU devices to be recognized and utilized through existing CUDA APIs in PyTorch, simplifying integration.
  • DCU-Optimized Flash Attention: Added a DCU-specific Flash Attention implementation (FlashAttnDcu) that attempts to use the ROCm-compatible Flash Attention 2.6.1 for performance, with a fallback to PyTorch's Scaled Dot Product Attention (SDPA) if Flash Attention is not available.
  • Dynamic Operation Loading: Modified the operations initialization to dynamically load DCU-specific attention and matrix multiplication operations when the platform is identified as 'dcu' and the AI device is 'cuda'.
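The dynamic-loading condition in the last highlight can be sketched as follows. This is a minimal illustration only: the helper name `should_load_dcu_ops` and the `ai_device` parameter are assumptions, not the PR's actual code.

```python
import os

def should_load_dcu_ops(ai_device: str = "cuda") -> bool:
    """Load DCU-specific attention/matmul kernels only when PLATFORM=dcu
    and the AI device is 'cuda' (DCU is driven through ROCm's
    CUDA-compatible PyTorch API)."""
    return os.getenv("PLATFORM") == "dcu" and ai_device == "cuda"

os.environ["PLATFORM"] = "dcu"
print(should_load_dcu_ops("cuda"))  # True: both conditions hold
print(should_load_dcu_ops("cpu"))   # False: wrong AI device
```

Gating on both the platform name and the device string mirrors the fact that DCU presents itself to PyTorch as a CUDA device via ROCm/HIP.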
Using Gemini Code Assist

The full guide for Gemini Code Assist is available on our documentation page; here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature | Command | Description
Code Review | /gemini review | Performs a code review for the current pull request in its current state.
Pull Request Summary | /gemini summary | Provides a summary of the current pull request in its current state.
Comment | @gemini-code-assist | Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help | /gemini help | Displays a list of available commands.

Customization

To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, and code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double-check its output and use code with caution.

@gemini-code-assist bot left a comment

Code Review

This pull request introduces support for the DCU (AMD GPU) platform by adding a DcuDevice class that utilizes ROCm's CUDA compatibility and a DCU-specific implementation for Flash Attention. The changes are a good step towards broader hardware support. However, my review has identified a few critical issues that need to be addressed. There's an import error due to a missing module for mm.dcu. More importantly, the new FlashAttnDcu implementation has critical bugs in both its main logic and its fallback path concerning the handling of variable-length sequences, which will lead to incorrect results or runtime errors. I have also included some medium-severity suggestions to improve code clarity, maintainability, and memory efficiency.

# Check if running on DCU platform
if os.getenv("PLATFORM") == "dcu":
    from .attn.dcu import *
    from .mm.dcu import *

critical

This line attempts to import from .mm.dcu, but the corresponding module and files (lightx2v_platform/ops/mm/dcu/) are not included in this pull request. This will cause an ImportError at runtime. Please either add the missing module files or remove this import statement if it's not yet needed.

Comment on lines +126 to +127
output = output.unflatten(0, (b, lq))
return output.to(out_dtype)

critical

The output from flash_attn_varlen_func is a packed tensor of shape (total_tokens, ...) when variable sequence lengths are used. The current code uses output.unflatten(0, (b, lq)), which assumes the total number of tokens is b * lq. This is only correct for fixed-length sequences and will cause a runtime error or produce incorrect results for variable-length sequences.

You need to correctly unpack the output into a padded tensor when q_lens is provided.

        if q_lens is None:
            output = output.unflatten(0, (b, lq))
        else:
            output_padded = torch.zeros(b, lq, q.shape[2], v.shape[3], dtype=output.dtype, device=output.device)
            current_pos = 0
            for i in range(b):
                seq_len = q_lens[i]
                output_padded[i, :seq_len] = output[current_pos : current_pos + seq_len]
                current_pos += seq_len
            output = output_padded
        return output.to(out_dtype)
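To see concretely why the fix matters, here is a small NumPy stand-in (an illustration, not the PR's code) showing that a packed variable-length output cannot simply be reshaped to (b, lq, ...), while the per-sequence loop unpacks it correctly:

```python
import numpy as np

b, lq, dim = 2, 4, 3
q_lens = [4, 2]  # total_tokens = 6, not b * lq = 8

# A packed output like flash_attn_varlen_func returns: one row per real token.
packed = np.arange(sum(q_lens) * dim, dtype=np.float32).reshape(-1, dim)

# Correct unpacking into a zero-padded (b, lq, dim) tensor.
padded = np.zeros((b, lq, dim), dtype=packed.dtype)
pos = 0
for i, n in enumerate(q_lens):
    padded[i, :n] = packed[pos:pos + n]
    pos += n

# The naive route fails outright: 6 rows cannot be viewed as (2, 4, dim).
try:
    packed.reshape(b, lq, dim)
except ValueError as e:
    print("reshape failed:", e)
```

With fixed-length sequences (all `q_lens` equal to `lq`) the two approaches coincide, which is why the bug only surfaces on variable-length batches.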

Comment on lines +129 to +153
def _sdpa_fallback(self, q, k, v, causal=False, dropout_p=0.0):
    """
    Fallback to PyTorch Scaled Dot Product Attention.

    Args:
        q: [B, Lq, Nq, C] Query tensor
        k: [B, Lk, Nk, C] Key tensor
        v: [B, Lk, Nk, C] Value tensor
        causal: Whether to apply causal mask
        dropout_p: Dropout probability

    Returns:
        Output tensor: [B, Lq, Nq, C]
    """
    # Transpose to [B, Nq, Lq, C] for SDPA
    q = q.transpose(1, 2)
    k = k.transpose(1, 2)
    v = v.transpose(1, 2)

    out = torch.nn.functional.scaled_dot_product_attention(
        q, k, v, attn_mask=None, is_causal=causal, dropout_p=dropout_p
    )

    # Transpose back to [B, Lq, Nq, C]
    return out.transpose(1, 2).contiguous()

critical

The _sdpa_fallback method does not handle variable sequence lengths passed via q_lens and k_lens. The main apply method handles them, but if it falls back to SDPA, this information is lost, leading to incorrect attention calculations for batches with padded sequences.

To fix this, the fallback method needs to accept q_lens and k_lens, construct an appropriate attn_mask for scaled_dot_product_attention, and handle the output padding. Note that scaled_dot_product_attention does not allow using both is_causal=True and attn_mask simultaneously, so you may need to create a combined causal and padding mask if both are required.
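One way to build such a combined mask is sketched below in NumPy (shapes are assumed for illustration; in the actual fallback the boolean result would be passed to `scaled_dot_product_attention` as `attn_mask`, where True means "attend"):

```python
import numpy as np

b, lq, lk = 2, 3, 3
k_lens = [3, 2]  # second batch element has one padded key position

# Causal part: query i may attend to keys j <= i.
causal = np.tril(np.ones((lq, lk), dtype=bool))                   # [Lq, Lk]

# Padding part: key j is valid only if j < k_lens[batch].
key_valid = np.arange(lk)[None, :] < np.asarray(k_lens)[:, None]  # [B, Lk]

# Combined boolean mask, broadcastable to SDPA's [B, heads, Lq, Lk].
mask = causal[None, :, :] & key_valid[:, None, :]                 # [B, Lq, Lk]
```

Because the causal constraint is folded into the mask, the SDPA call would then use `is_causal=False`, sidestepping the restriction that `is_causal=True` and `attn_mask` cannot be combined.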

from lightx2v_platform.base.nvidia import CudaDevice

__all__ = ["init_ai_device", "check_ai_device", "CudaDevice", "MluDevice", "MetaxDevice"]
__all__ = ["init_ai_device", "check_ai_device", "CudaDevice", "MluDevice", "MetaxDevice", "DcuDevice"]

medium

For better readability and maintainability, it's good practice to keep the items in __all__ sorted alphabetically. This makes it easier to find items in the list as it grows.

Suggested change
__all__ = ["init_ai_device", "check_ai_device", "CudaDevice", "MluDevice", "MetaxDevice", "DcuDevice"]
__all__ = ["init_ai_device", "check_ai_device", "CudaDevice", "DcuDevice", "MetaxDevice", "MluDevice"]

Comment on lines +26 to +29
try:
    return torch.cuda.is_available()
except ImportError:
    return False

medium

The try...except ImportError block is unnecessary here. torch is imported at the top of the file, so an ImportError for torch would have already been raised. The torch.cuda.is_available() call itself does not raise an ImportError; it safely returns False if CUDA is not available. Removing this block simplifies the code.

        return torch.cuda.is_available()

    q_flat = half(q.flatten(0, 1))
    q_lens = torch.tensor([lq] * b, dtype=torch.int32, device=q.device)
else:
    q_flat = half(torch.cat([u[:v] for u, v in zip(q, q_lens)]))

medium

When preparing q_flat for variable-length sequences, you are concatenating tensors first and then converting the large concatenated tensor to half precision with half(). This can be memory-inefficient if the input tensor q is in full precision, as it creates a large intermediate tensor. For better memory efficiency, consider applying the half() conversion to each slice inside the list comprehension before concatenation. The same applies to k_flat and v_flat.

Suggested change
q_flat = half(torch.cat([u[:v] for u, v in zip(q, q_lens)]))
q_flat = torch.cat([half(u[:v]) for u, v in zip(q, q_lens)])

Comment on lines +102 to +103
k_flat = half(torch.cat([u[:v] for u, v in zip(k, k_lens)]))
v_flat = half(torch.cat([u[:v] for u, v in zip(v, k_lens)]))

medium

Similar to the creation of q_flat, k_flat and v_flat are created by first concatenating full-precision tensors and then converting. This is memory-inefficient. Applying half() inside the list comprehension is recommended for better memory usage.

Suggested change
k_flat = half(torch.cat([u[:v] for u, v in zip(k, k_lens)]))
v_flat = half(torch.cat([u[:v] for u, v in zip(v, k_lens)]))
k_flat = torch.cat([half(u[:v]) for u, v in zip(k, k_lens)])
v_flat = torch.cat([half(u[:v]) for u, v in zip(v, k_lens)])
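The two orderings produce identical values, since float16 rounding is elementwise and concatenation does not change element values; only the size of the intermediate differs. A quick NumPy check (a stand-in for the torch code above, with toy shapes):

```python
import numpy as np

def half(x):
    return x.astype(np.float16)

rng = np.random.default_rng(0)
k = rng.random((2, 5, 3), dtype=np.float32)  # [B, Lk, C] toy stand-in
k_lens = [5, 3]

# Convert-after-concat (original): materializes a float32 intermediate.
after = half(np.concatenate([u[:n] for u, n in zip(k, k_lens)]))

# Convert-before-concat (suggested): each slice is halved first.
before = np.concatenate([half(u[:n]) for u, n in zip(k, k_lens)])

assert np.array_equal(after, before)
```

Converting before concatenation means the concatenated buffer is allocated directly in half precision, halving the peak memory of this step when the inputs are float32.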

@helloyongyang helloyongyang merged commit 1f7bad5 into ModelTC:main Dec 9, 2025
helloyongyang pushed a commit that referenced this pull request Mar 6, 2026