diff --git a/paddle/phi/api/lib/CMakeLists.txt b/paddle/phi/api/lib/CMakeLists.txt index cb7f439690619..845af87fd5e38 100644 --- a/paddle/phi/api/lib/CMakeLists.txt +++ b/paddle/phi/api/lib/CMakeLists.txt @@ -4,17 +4,17 @@ if(WITH_GPU) nv_library( phi_tensor_raw SRCS tensor.cc - DEPS tensor_base dense_tensor phi_api_utils phi_enforce) + DEPS tensor_base dense_tensor phi_api_utils phi_enforce context_pool) elseif(WITH_ROCM) hip_library( phi_tensor_raw SRCS tensor.cc - DEPS tensor_base dense_tensor phi_api_utils phi_enforce) + DEPS tensor_base dense_tensor phi_api_utils phi_enforce context_pool) else() cc_library( phi_tensor_raw SRCS tensor.cc - DEPS tensor_base dense_tensor phi_api_utils phi_enforce) + DEPS tensor_base dense_tensor phi_api_utils phi_enforce context_pool) endif() set(api_gen_base ${CMAKE_SOURCE_DIR}/paddle/phi/api/yaml/generator/api_base.py) diff --git a/paddle/phi/api/lib/tensor.cc b/paddle/phi/api/lib/tensor.cc index 4e4e7e31cfbc7..cf528beb800ba 100644 --- a/paddle/phi/api/lib/tensor.cc +++ b/paddle/phi/api/lib/tensor.cc @@ -21,7 +21,9 @@ limitations under the License. */ #include "glog/logging.h" +#include "paddle/phi/api/include/context_pool.h" #include "paddle/phi/api/lib/utils/allocator.h" +#include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/backends/gpu/gpu_info.h" #include "paddle/phi/core/ddim.h" #include "paddle/phi/core/dense_tensor.h" @@ -33,8 +35,6 @@ limitations under the License. 
*/ #include "paddle/phi/core/tensor_base.h" #include "paddle/phi/core/tensor_meta.h" #include "paddle/phi/core/tensor_utils.h" - -#include "paddle/fluid/platform/stream/cuda_stream.h" // clang-format off namespace paddle { @@ -311,7 +311,10 @@ void Tensor::set_impl(std::shared_ptr<phi::TensorBase> &&impl) { #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) gpuStream_t Tensor::stream() const { - return platform::stream::get_current_stream(-1)->raw_stream(); + int device_id = phi::backends::gpu::GetCurrentDeviceId(); + auto* gpu_context = DeviceContextPool::Instance() + .Get<AllocationType::GPU>(GPUPlace(device_id)); + return gpu_context->stream(); } #endif