
[Feature] iluvatar platforms support #1045

Merged

helloyongyang merged 1 commit into ModelTC:main from Deep-Spark:hfj_dev on Apr 30, 2026

Conversation

@hufangjian2017 (Contributor)

No description provided.

@gemini-code-assist (Bot) left a comment

Code Review

This pull request introduces support for the Iluvatar CUDA platform, adding new configuration files, device initialization logic, and optimized implementations for Flash Attention, INT8 quantized linear layers, RMS Norm, and RoPE. The review feedback identifies several critical issues, including instances where NotImplementedError is instantiated but not raised, potential NameError and logic bugs in the Flash Attention implementation, and uninitialized attributes in the RoPE class. Additionally, there are inconsistencies in API usage for quantization functions and minor documentation typos that require correction.

Comment thread: lightx2v/models/input_encoders/hf/wan/t5/model.py (Outdated)
Comment thread: lightx2v/models/input_encoders/hf/wan/xlm_roberta/model.py (Outdated)
device = q.device

def half(x):
    return x if x.dtype in half_dtypes else x.to(dtype)
Contributor

high

There are two issues here:

  1. dtype is undefined if len(q.shape) == 3 because it is only assigned inside the elif len(q.shape) == 4 block (line 30). This will cause a NameError if half() is ever called for 3D inputs.
  2. The logic x if x.dtype in half_dtypes else x.to(dtype) does not actually ensure half precision if the input q is float32 (as dtype would be float32). Flash attention kernels typically require float16 or bfloat16.

Contributor Author

Fixed.
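
For illustration, a minimal sketch of the kind of fix the review asks for (assumed, not the PR's final code): resolve the target dtype on both the 3D and 4D paths, and never leave a float32 input unconverted. Here half_dtypes is assumed to be (torch.float16, torch.bfloat16):

```python
import torch

half_dtypes = (torch.float16, torch.bfloat16)

def to_half(x, dtype=None):
    # Keep tensors that are already half precision as-is; otherwise cast
    # to an explicit half dtype, so float32 inputs never reach the
    # flash-attention kernel unconverted (avoids the NameError and the
    # float32 pass-through described above).
    if x.dtype in half_dtypes:
        return x
    if dtype not in half_dtypes:
        dtype = torch.bfloat16  # assumed default target dtype
    return x.to(dtype)
```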

    return_softmax_lse=False,
    causal=False,
)
return x.reshape(bs * max_seqlen_q, -1)
Contributor

high

The reshape operation will fail for variable-length 4D inputs. When cu_seqlens_q is provided, the output x from flash_attn_varlen_func is a packed tensor with shape [sum(seqlens), nheads, head_dim]. Reshaping it to [bs * max_seqlen_q, -1] (the padded size) will raise a RuntimeError because the total number of elements won't match if any padding was present in the original 4D input.
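
For illustration, one hedged way to avoid the mismatch (names and shapes assumed from the description above, not taken from the PR): scatter the packed rows back into a zero-padded buffer using cu_seqlens_q before flattening:

```python
import torch

def unpack_to_padded(x_packed, cu_seqlens_q, bs, max_seqlen_q):
    # x_packed: [sum(seqlens), nheads, head_dim], as returned by
    # flash_attn_varlen_func when cu_seqlens_q is provided.
    nheads, head_dim = x_packed.shape[1:]
    out = x_packed.new_zeros(bs, max_seqlen_q, nheads, head_dim)
    for i in range(bs):
        start = int(cu_seqlens_q[i])
        end = int(cu_seqlens_q[i + 1])
        # Copy each sequence into its padded slot; padding stays zero.
        out[i, : end - start] = x_packed[start:end]
    return out.reshape(bs * max_seqlen_q, -1)
```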

device = x.device
input_tensor_quant = torch.empty(x.shape, dtype=torch.int8, device=device)
input_tensor_scale = torch.empty(x.shape[:-1], dtype=torch.float32, device=device)
ixf.dynamic_scaled_int8_quant(output=input_tensor_quant, input=x, scale=input_tensor_scale)
Contributor

high

The API usage of ixf.dynamic_scaled_int8_quant is inconsistent with its usage in lightx2v_platform/ops/mm/iluvatar_cuda/q_linear.py. Here it is used as an in-place function with keyword arguments (output=, input=, scale=), while in q_linear.py it is used as a function returning two values with a single positional argument. Please verify the correct ixformer API and ensure consistency.

Contributor Author

On Iluvatar, the output is written into a pre-allocated tensor first.
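
For illustration, a hypothetical helper (quantize_int8_dynamic is not in the PR) that would pin both call sites to the pre-allocated, in-place convention shown in the diff above, assuming that keyword signature is the correct ixformer API:

```python
import torch

def quantize_int8_dynamic(x, ixf):
    # Pre-allocate outputs, then call the in-place ixformer kernel with
    # the keyword signature used in the diff above (assumed correct).
    quant = torch.empty(x.shape, dtype=torch.int8, device=x.device)
    scale = torch.empty(x.shape[:-1], dtype=torch.float32, device=x.device)
    ixf.dynamic_scaled_int8_quant(output=quant, input=x, scale=scale)
    return quant, scale
```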

if xq.dim() == 4:
    xq = xq.squeeze(0)
    xk = xk.squeeze(0)
return xq.to(self.infer_dtype), xk.to(self.infer_dtype)
Contributor

high

self.infer_dtype is used but not initialized in IluvatarWanRope.__init__. Since this class is instantiated without arguments in transformer_infer.py, it will raise an AttributeError during the forward pass.

Contributor Author

IluvatarWanRope subclasses RopeTemplate (class IluvatarWanRope(RopeTemplate)), and RopeTemplate initializes infer_dtype.
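
For illustration, a sketch of the inheritance pattern the author describes; RopeTemplate's real definition lives in the repo, so the body below (including the method name) is assumed, not quoted:

```python
import torch

class RopeTemplate:
    def __init__(self, infer_dtype=torch.bfloat16):
        # The base class sets infer_dtype, so subclasses can rely on it
        # even when instantiated without arguments.
        self.infer_dtype = infer_dtype

class IluvatarWanRope(RopeTemplate):
    def apply(self, xq, xk):
        # Drop the leading batch dim for 4D inputs, as in the diff above.
        if xq.dim() == 4:
            xq = xq.squeeze(0)
            xk = xk.squeeze(0)
        return xq.to(self.infer_dtype), xk.to(self.infer_dtype)
```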

Quant MM:
Weight: int8 perchannel sym
Act: int8 perchannel dynamic sym
Kernel: mlu
Contributor

medium

Typo in docstring: 'Kernel: mlu' should be 'Kernel: iluvatar' or 'Kernel: cuda'.

Suggested change:
- Kernel: mlu
+ Kernel: iluvatar

Contributor Author

Fixed.


# Need this, otherwise Triton tries to launch from cuda:0 and we get
# ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
with torch.cuda.device(x.device.index):
Contributor

medium

x.device.index can be None if the device is the default 'cuda' device. It is safer to pass the device object directly to torch.cuda.device.

Suggested change:
- with torch.cuda.device(x.device.index):
+ with torch.cuda.device(x.device):

Contributor Author

Fixed.
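
For reference, a minimal standalone illustration of the reviewer's concern (standard PyTorch behavior, not code from this PR): a bare 'cuda' device has index None, and torch.cuda.device documents a None index as a no-op, so the guard would silently do nothing:

```python
import torch

dev = torch.device("cuda")  # no explicit index
print(dev.index)            # None

# A None index makes torch.cuda.device a documented no-op,
# so the device is NOT switched inside this block:
with torch.cuda.device(dev.index):
    pass

# Passing the device object lets PyTorch resolve the index itself:
with torch.cuda.device(dev):
    pass
```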

Update lightx2v/models/input_encoders/hf/wan/t5/model.py
helloyongyang merged commit 5194f30 into ModelTC:main on Apr 30, 2026
1 check failed
