Skip to content

Commit

Permalink
[AutoParallel] CastPyArg2Tensor check convert disttensor (#59141)
Browse files Browse the repository at this point in the history
* CastPyArg2Tensor check convert disttensor
  • Loading branch information
wanghuancoder committed Nov 22, 2023
1 parent 6747af6 commit eff0367
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 2 deletions.
7 changes: 6 additions & 1 deletion paddle/fluid/pybind/eager_method.cc
Original file line number Diff line number Diff line change
Expand Up @@ -675,7 +675,12 @@ static PyObject* tensor_method_copy_(TensorObject* self,
PyObject* args,
PyObject* kwargs) {
EAGER_TRY
paddle::Tensor src_tensor = CastPyArg2Tensor(PyTuple_GET_ITEM(args, 0), 0);
paddle::Tensor& src_tensor = CastPyArg2Tensor(PyTuple_GET_ITEM(args, 0), 0);
const phi::distributed::ProcessMesh* mesh = nullptr;
if (InputsContainDistTensor(&mesh, src_tensor, self->tensor)) {
ConvertAllInputsToDistTensor(mesh, src_tensor, self->tensor);
}

bool blocking = CastPyArg2AttrBoolean(PyTuple_GET_ITEM(args, 1), 1);
VLOG(6) << "Start Copy Tensor " << src_tensor.name() << " to "
<< self->tensor.name();
Expand Down
6 changes: 5 additions & 1 deletion paddle/fluid/pybind/eager_properties.cc
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ int tensor_properties_set_grad(TensorObject* self,
PyObject* value,
void* closure) {
EAGER_TRY
auto src = CastPyArg2Tensor(value, 0);
auto& src = CastPyArg2Tensor(value, 0);
PADDLE_ENFORCE(
egr::EagerUtils::IsLeafTensor(self->tensor),
paddle::platform::errors::Fatal("Only leaf Tensor can be set grad."));
Expand All @@ -311,6 +311,10 @@ int tensor_properties_set_grad(TensorObject* self,
"Detected NULL grad"
"Please check if you have manually cleared"
"the grad inside autograd_meta"));
const phi::distributed::ProcessMesh* mesh = nullptr;
if (InputsContainDistTensor(&mesh, src, self->tensor, *grad)) {
ConvertAllInputsToDistTensor(mesh, src, self->tensor, *grad);
}
grad->copy_(src, self->tensor.place(), true);
return 0;
EAGER_CATCH_AND_THROW_RETURN_NEG
Expand Down

0 comments on commit eff0367

Please sign in to comment.