diff --git a/paddlenlp/transformers/llama/modeling.py b/paddlenlp/transformers/llama/modeling.py
index a98165f72c4..ccba8e7bf88 100755
--- a/paddlenlp/transformers/llama/modeling.py
+++ b/paddlenlp/transformers/llama/modeling.py
@@ -1718,7 +1718,7 @@ def forward(self, hidden_states, tensor_parallel_output=None):
             hidden_states = paddle.reshape_(hidden_states, [-1, seq_length, self.config.hidden_size])
 
         if tensor_parallel_output is None:
-            tensor_parallel_output = self.config.tensor_parallel_output
+            tensor_parallel_output = self.config.tensor_parallel_output and self.config.tensor_parallel_degree > 1
 
         if get_env_device() == "xpu" and self.xpu_parallel_matmul is not None:
             logits = self.xpu_parallel_matmul(
diff --git a/paddlenlp/transformers/qwen/modeling.py b/paddlenlp/transformers/qwen/modeling.py
index 788964ddfb9..b4f4dcb3339 100755
--- a/paddlenlp/transformers/qwen/modeling.py
+++ b/paddlenlp/transformers/qwen/modeling.py
@@ -795,7 +795,7 @@ def __init__(self, config: QWenConfig):
 
     def forward(self, hidden_states, tensor_parallel_output=None):
         if tensor_parallel_output is None:
-            tensor_parallel_output = self.config.tensor_parallel_output
+            tensor_parallel_output = self.config.tensor_parallel_output and self.config.tensor_parallel_degree > 1
 
         logits = parallel_matmul(hidden_states, self.weight, tensor_parallel_output=tensor_parallel_output)
         return logits
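
Note on the change: both LM heads previously defaulted `tensor_parallel_output` straight from the config flag, so a single-card run with `tensor_parallel_output=True` in the config would still take the tensor-parallel logits path even though no tensor-parallel group exists. The patch gates the flag on `tensor_parallel_degree > 1`, and an explicit `tensor_parallel_output` argument still overrides the default. A minimal sketch of the resulting behavior; the `Config` stub and `resolve_tensor_parallel_output` helper are illustrative stand-ins, not part of PaddleNLP:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Config:
    # Illustrative stand-in for the relevant PretrainedConfig fields.
    tensor_parallel_output: bool = True
    tensor_parallel_degree: int = 1


def resolve_tensor_parallel_output(config: Config, override: Optional[bool] = None) -> bool:
    """Hypothetical helper mirroring the patched default in the forward() methods.

    An explicit override wins; otherwise the config flag only takes effect
    when a tensor-parallel group actually exists (degree > 1).
    """
    if override is not None:
        return override
    return config.tensor_parallel_output and config.tensor_parallel_degree > 1


# Single-card run: the flag is set, but with degree == 1 the parallel
# output path is skipped.
assert resolve_tensor_parallel_output(Config(tensor_parallel_degree=1)) is False
# Multi-card tensor-parallel run: the flag is honored as before.
assert resolve_tensor_parallel_output(Config(tensor_parallel_degree=4)) is True
# A caller-supplied value bypasses the config-derived default entirely.
assert resolve_tensor_parallel_output(Config(tensor_parallel_degree=1), override=True) is True
```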