fix tensor stream error in custom op (PaddlePaddle#44500)
chenwhql authored and Aurelius84 committed Jul 29, 2022
1 parent 9ced147 commit 64422b7
Showing 2 changed files with 9 additions and 6 deletions.
6 changes: 3 additions & 3 deletions paddle/phi/api/lib/CMakeLists.txt
@@ -4,17 +4,17 @@ if(WITH_GPU)
   nv_library(
     phi_tensor_raw
     SRCS tensor.cc
-    DEPS tensor_base dense_tensor phi_api_utils phi_enforce)
+    DEPS tensor_base dense_tensor phi_api_utils phi_enforce context_pool)
 elseif(WITH_ROCM)
   hip_library(
     phi_tensor_raw
     SRCS tensor.cc
-    DEPS tensor_base dense_tensor phi_api_utils phi_enforce)
+    DEPS tensor_base dense_tensor phi_api_utils phi_enforce context_pool)
 else()
   cc_library(
     phi_tensor_raw
     SRCS tensor.cc
-    DEPS tensor_base dense_tensor phi_api_utils phi_enforce)
+    DEPS tensor_base dense_tensor phi_api_utils phi_enforce context_pool)
 endif()
 
 set(api_gen_base ${CMAKE_SOURCE_DIR}/paddle/phi/api/yaml/generator/api_base.py)
9 changes: 6 additions & 3 deletions paddle/phi/api/lib/tensor.cc
@@ -21,7 +21,9 @@ limitations under the License. */
 
 #include "glog/logging.h"
 
+#include "paddle/phi/api/include/context_pool.h"
 #include "paddle/phi/api/lib/utils/allocator.h"
+#include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/backends/gpu/gpu_info.h"
 #include "paddle/phi/core/ddim.h"
 #include "paddle/phi/core/dense_tensor.h"
@@ -33,8 +35,6 @@ limitations under the License. */
 #include "paddle/phi/core/tensor_base.h"
 #include "paddle/phi/core/tensor_meta.h"
 #include "paddle/phi/core/tensor_utils.h"
-
-#include "paddle/fluid/platform/stream/cuda_stream.h"
 // clang-format off
 
 namespace paddle {
@@ -311,7 +311,10 @@ void Tensor::set_impl(std::shared_ptr<phi::TensorBase> &&impl) {
 
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
 gpuStream_t Tensor::stream() const {
-  return platform::stream::get_current_stream(-1)->raw_stream();
+  int device_id = phi::backends::gpu::GetCurrentDeviceId();
+  auto* gpu_context = DeviceContextPool::Instance()
+                          .Get<AllocationType::GPU>(GPUPlace(device_id));
+  return gpu_context->stream();
 }
 #endif
 
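For context: Tensor::stream() is the accessor that custom operators typically use to launch their CUDA kernels on Paddle's current stream, and with this change it returns the stream of the current device's GPU context obtained from phi's DeviceContextPool instead of the old fluid cuda_stream utility. Below is a minimal, illustrative sketch of such a custom op; the kernel, the ScaleForward function, the float-only path, and the launch configuration are assumptions for illustration (registration via PD_BUILD_OP is omitted), not part of this commit.

// Minimal sketch of a custom CUDA op that relies on Tensor::stream().
// scale_kernel / ScaleForward are hypothetical names; only the paddle::Tensor
// API calls (stream(), data<T>(), numel(), paddle::empty) come from Paddle.
#include <vector>
#include "paddle/extension.h"

__global__ void scale_kernel(const float* x, float* y, int64_t n, float a) {
  int64_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    y[i] = a * x[i];
  }
}

std::vector<paddle::Tensor> ScaleForward(const paddle::Tensor& x, float a) {
  auto out = paddle::empty(x.shape(), x.dtype(), x.place());
  int64_t n = x.numel();
  int64_t block = 256;
  int64_t grid = (n + block - 1) / block;
  // x.stream() now resolves to the current device's GPU context stream
  // (fetched from DeviceContextPool, see the change above), so the kernel
  // runs on the same stream Paddle itself uses on that device.
  scale_kernel<<<grid, block, 0, x.stream()>>>(
      x.data<float>(), out.data<float>(), n, a);
  return {out};
}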
