Optimize prefill DP balance, support multimodal DP balance. #1271
hiworldwzj merged 10 commits into main
Code Review
This pull request refactors the Data Parallel (DP) prefill balance logic by centralizing state management within the InferStateInfo class, simplifying transformer layer implementations for DeepSeek models. The changes introduce automated switching between balanced and unbalanced states for input IDs and position embeddings. However, a critical initialization bug was found in infer_struct.py where state-swapping methods access attributes before they are defined, potentially causing AttributeErrors. The review feedback recommends using safer attribute checks and adopting standard Python naming conventions for internal methods to resolve these issues.
    def __change_to_unbalance(self):
        self.input_ids = self._unbalance_input_ids
        if hasattr(self, "position_ids"):
            self.position_ids = self._unbalance_position_ids
        if hasattr(self, "position_cos"):
            self.position_cos = self._unbalance_position_cos
        if hasattr(self, "position_sin"):
            self.position_sin = self._unbalance_position_sin
        return

    def __change_to_balance(self):
        self.input_ids = self._balance_input_ids
        if hasattr(self, "position_ids"):
            self.position_ids = self._balance_position_ids
        if hasattr(self, "position_cos"):
            self.position_cos = self._balance_position_cos
        if hasattr(self, "position_sin"):
            self.position_sin = self._balance_position_sin
        return

    def _all_to_all_balance_get(self, data: torch.Tensor):
        self.__change_to_balance()
The current implementation of state swapping will cause a critical crash during the initialization phase.
In prepare_prefill_dp_balance, the code calls _all_to_all_balance_get (e.g., at line 236) before self._balance_input_ids is defined (at line 249). Since _all_to_all_balance_get immediately calls self.__change_to_balance(), which attempts to access self._balance_input_ids, an AttributeError will be raised.
Additionally, the hasattr checks should target the backing attributes (e.g., _balance_position_ids) rather than the public ones. Since InferStateInfo initializes public attributes like position_ids to None in its constructor, hasattr will return True, but the backing attributes will be missing if the feature wasn't used or is still being initialized, leading to further crashes.
Renaming the methods to use a single underscore is also recommended to follow standard Python conventions for internal methods.
    def _change_to_unbalance(self):
        if hasattr(self, "_unbalance_input_ids"):
            self.input_ids = self._unbalance_input_ids
        if hasattr(self, "_unbalance_position_ids"):
            self.position_ids = self._unbalance_position_ids
        if hasattr(self, "_unbalance_position_cos"):
            self.position_cos = self._unbalance_position_cos
        if hasattr(self, "_unbalance_position_sin"):
            self.position_sin = self._unbalance_position_sin
        return

    def _change_to_balance(self):
        if hasattr(self, "_balance_input_ids"):
            self.input_ids = self._balance_input_ids
        if hasattr(self, "_balance_position_ids"):
            self.position_ids = self._balance_position_ids
        if hasattr(self, "_balance_position_cos"):
            self.position_cos = self._balance_position_cos
        if hasattr(self, "_balance_position_sin"):
            self.position_sin = self._balance_position_sin
        return

    def _all_to_all_balance_get(self, data: torch.Tensor):
        self._change_to_balance()

        return dest_data.view(-1, *old_shape[1:])

    def _all_to_all_unbalance_get(self, data: torch.Tensor):
        self.__change_to_unbalance()
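The hasattr pitfall raised in the review can be reproduced in isolation. The sketch below is a minimal illustration, not the project's code: `State` is a hypothetical stand-in for InferStateInfo, and the method names `change_to_balance_unsafe` and `change_to_balance_safe` are invented for the example. It assumes only what the review states, namely that the public attribute is initialized to None in the constructor while the backing attribute is assigned later.

```python
class State:
    """Hypothetical stand-in for InferStateInfo; not the real class."""

    def __init__(self):
        # The public attribute exists from construction, but holds None.
        self.position_ids = None

    def change_to_balance_unsafe(self):
        # The guard checks the public attribute, which always exists,
        # so it does not protect the read of the backing attribute.
        if hasattr(self, "position_ids"):
            self.position_ids = self._balance_position_ids

    def change_to_balance_safe(self):
        # The guard checks the backing attribute that is actually read.
        if hasattr(self, "_balance_position_ids"):
            self.position_ids = self._balance_position_ids


state = State()

# hasattr is True even though the value is None.
public_attr_present = hasattr(state, "position_ids")

# Calling the unsafe variant before the backing attribute is set fails.
try:
    state.change_to_balance_unsafe()
    unsafe_error = None
except AttributeError as exc:
    unsafe_error = exc

# The safe variant is a no-op until the backing attribute exists.
state.change_to_balance_safe()
value_before_setup = state.position_ids

# Once the backing attribute is assigned, the swap takes effect.
state._balance_position_ids = [0, 1, 2]
state.change_to_balance_safe()
value_after_setup = state.position_ids
```

Guarding on the backing attribute turns an early call into a no-op instead of a crash, which is probably what the reviewer's suggestion relies on. Whether a silent no-op is preferable to reordering the setup in prepare_prefill_dp_balance is a design choice the review leaves open.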