
Optimize prefill DP balance; support multimodal DP balance. #1271

Merged

hiworldwzj merged 10 commits into main from wzj_fix on Apr 16, 2026

Conversation

@hiworldwzj
Collaborator

No description provided.

Contributor

@gemini-code-assist bot left a comment


Code Review

This pull request refactors the Data Parallel (DP) prefill balance logic by centralizing state management within the InferStateInfo class, simplifying transformer layer implementations for DeepSeek models. The changes introduce automated switching between balanced and unbalanced states for input IDs and position embeddings. However, a critical initialization bug was found in infer_struct.py where state-swapping methods access attributes before they are defined, potentially causing AttributeErrors. The review feedback recommends using safer attribute checks and adopting standard Python naming conventions for internal methods to resolve these issues.
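To make the balancing idea concrete, here is a rough, hypothetical sketch (plain Python, no torch.distributed; all names are invented for illustration and are not from the PR): DP prefill balancing can be viewed as flattening every rank's tokens into one stream, re-splitting that stream into equal-sized chunks so each rank does the same amount of prefill work, and keeping an inverse mapping that restores the original per-rank layout afterwards.

```python
def balance(per_rank_tokens):
    """Concatenate all ranks' tokens, then split into equal chunks per rank."""
    flat = [t for rank in per_rank_tokens for t in rank]
    world_size = len(per_rank_tokens)
    # Real implementations pad first; this toy assumes tokens divide evenly.
    assert len(flat) % world_size == 0, "pad so tokens divide evenly"
    chunk = len(flat) // world_size
    return [flat[i * chunk:(i + 1) * chunk] for i in range(world_size)]


def unbalance(balanced, original_lengths):
    """Invert balance(): re-split the flat stream by the original lengths."""
    flat = [t for rank in balanced for t in rank]
    out, pos = [], 0
    for n in original_lengths:
        out.append(flat[pos:pos + n])
        pos += n
    return out


# Rank 0 holds 6 tokens, rank 1 holds 2: uneven prefill work.
ranks = [[1, 2, 3, 4, 5, 6], [7, 8]]
balanced = balance(ranks)  # each rank now holds 4 tokens
restored = unbalance(balanced, [len(r) for r in ranks])
assert restored == ranks
```

In the actual PR the redistribution is done with collective all-to-all communication on tensors, but the invariant is the same: a balanced view for compute, and an unbalanced view that round-trips back to the original layout.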

Comment on lines +253 to +274
    def __change_to_unbalance(self):
        self.input_ids = self._unbalance_input_ids
        if hasattr(self, "position_ids"):
            self.position_ids = self._unbalance_position_ids
        if hasattr(self, "position_cos"):
            self.position_cos = self._unbalance_position_cos
        if hasattr(self, "position_sin"):
            self.position_sin = self._unbalance_position_sin
        return

    def __change_to_balance(self):
        self.input_ids = self._balance_input_ids
        if hasattr(self, "position_ids"):
            self.position_ids = self._balance_position_ids
        if hasattr(self, "position_cos"):
            self.position_cos = self._balance_position_cos
        if hasattr(self, "position_sin"):
            self.position_sin = self._balance_position_sin
        return

    def _all_to_all_balance_get(self, data: torch.Tensor):
        self.__change_to_balance()
Contributor


critical

The current implementation of state swapping will cause a critical crash during the initialization phase.

In prepare_prefill_dp_balance, the code calls _all_to_all_balance_get (e.g., at line 236) before self._balance_input_ids is defined (at line 249). Since _all_to_all_balance_get immediately calls self.__change_to_balance(), which attempts to access self._balance_input_ids, an AttributeError will be raised.

Additionally, the hasattr checks should target the backing attributes (e.g., _balance_position_ids) rather than the public ones. Since InferStateInfo initializes public attributes like position_ids to None in its constructor, hasattr will return True, but the backing attributes will be missing if the feature wasn't used or is still being initialized, leading to further crashes.
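The pitfall described above is easy to reproduce. In this minimal illustration (a hypothetical class standing in for InferStateInfo, not the real one), the constructor sets the public attribute to None, so hasattr() on the public name is always True, while the "_balance_*" backing attribute only exists once the balancing path has actually run:

```python
class State:
    def __init__(self):
        self.position_ids = None  # public attr always present, even if unset


s = State()
assert hasattr(s, "position_ids")               # True, despite being None
assert not hasattr(s, "_balance_position_ids")  # backing attr is missing

# Guarding on the public name does not protect the swap: it crashes.
try:
    if hasattr(s, "position_ids"):
        s.position_ids = s._balance_position_ids
except AttributeError:
    pass  # the crash the review warns about

# Guarding on the backing name safely skips the swap until it exists.
if hasattr(s, "_balance_position_ids"):
    s.position_ids = s._balance_position_ids
```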

Renaming the methods to use a single underscore is also recommended to follow standard Python conventions for internal methods.

    def _change_to_unbalance(self):
        if hasattr(self, "_unbalance_input_ids"):
            self.input_ids = self._unbalance_input_ids
        if hasattr(self, "_unbalance_position_ids"):
            self.position_ids = self._unbalance_position_ids
        if hasattr(self, "_unbalance_position_cos"):
            self.position_cos = self._unbalance_position_cos
        if hasattr(self, "_unbalance_position_sin"):
            self.position_sin = self._unbalance_position_sin
        return

    def _change_to_balance(self):
        if hasattr(self, "_balance_input_ids"):
            self.input_ids = self._balance_input_ids
        if hasattr(self, "_balance_position_ids"):
            self.position_ids = self._balance_position_ids
        if hasattr(self, "_balance_position_cos"):
            self.position_cos = self._balance_position_cos
        if hasattr(self, "_balance_position_sin"):
            self.position_sin = self._balance_position_sin
        return

    def _all_to_all_balance_get(self, data: torch.Tensor):
        self._change_to_balance()

        return dest_data.view(-1, *old_shape[1:])

    def _all_to_all_unbalance_get(self, data: torch.Tensor):
        self.__change_to_unbalance()
Contributor


medium

Update the call to use the renamed internal method to maintain consistency with the suggested fix for the initialization crash.

Suggested change
-        self.__change_to_unbalance()
+        self._change_to_unbalance()

@hiworldwzj merged commit 529f9ca into main on Apr 16, 2026
1 check passed
@hiworldwzj deleted the wzj_fix branch on April 16, 2026 at 07:00
