diff --git a/include/infinicore/context/context.hpp b/include/infinicore/context/context.hpp index 093004565..aa1ade434 100644 --- a/include/infinicore/context/context.hpp +++ b/include/infinicore/context/context.hpp @@ -25,9 +25,9 @@ std::shared_ptr allocateMemory(size_t size); std::shared_ptr allocateHostMemory(size_t size); std::shared_ptr allocatePinnedHostMemory(size_t size); -void memcpyH2D(void *dst, const void *src, size_t size); +void memcpyH2D(void *dst, const void *src, size_t size, bool async = true); void memcpyD2H(void *dst, const void *src, size_t size); -void memcpyD2D(void *dst, const void *src, size_t size); +void memcpyD2D(void *dst, const void *src, size_t size, bool async = true); void memcpyH2H(void *dst, const void *src, size_t size); } // namespace context diff --git a/src/infinicore/context/context_impl.cc b/src/infinicore/context/context_impl.cc index c7a96d163..b1e7f50ce 100644 --- a/src/infinicore/context/context_impl.cc +++ b/src/infinicore/context/context_impl.cc @@ -129,16 +129,16 @@ std::shared_ptr allocatePinnedHostMemory(size_t size) { return ContextImpl::singleton().getCurrentRuntime()->allocatePinnedHostMemory(size); } -void memcpyH2D(void *dst, const void *src, size_t size) { - return ContextImpl::singleton().getCurrentRuntime()->memcpyH2D(dst, src, size); +void memcpyH2D(void *dst, const void *src, size_t size, bool async) { + return ContextImpl::singleton().getCurrentRuntime()->memcpyH2D(dst, src, size, async); } void memcpyD2H(void *dst, const void *src, size_t size) { return ContextImpl::singleton().getCurrentRuntime()->memcpyD2H(dst, src, size); } -void memcpyD2D(void *dst, const void *src, size_t size) { - return ContextImpl::singleton().getCurrentRuntime()->memcpyD2D(dst, src, size); +void memcpyD2D(void *dst, const void *src, size_t size, bool async) { + return ContextImpl::singleton().getCurrentRuntime()->memcpyD2D(dst, src, size, async); } void memcpyH2H(void *dst, const void *src, size_t size) { diff --git a/src/infinicore/context/runtime/runtime.cc b/src/infinicore/context/runtime/runtime.cc index 1f192011d..005bd98cf 100644 --- a/src/infinicore/context/runtime/runtime.cc +++ b/src/infinicore/context/runtime/runtime.cc @@ -76,16 +76,24 @@ std::shared_ptr Runtime::allocatePinnedHostMemory(size_t size) { true); } -void Runtime::memcpyH2D(void *dst, const void *src, size_t size) { - INFINICORE_CHECK_ERROR(infinirtMemcpyAsync(dst, src, size, INFINIRT_MEMCPY_H2D, stream_)); +void Runtime::memcpyH2D(void *dst, const void *src, size_t size, bool async) { + if (async) { + INFINICORE_CHECK_ERROR(infinirtMemcpyAsync(dst, src, size, INFINIRT_MEMCPY_H2D, stream_)); + } else { + INFINICORE_CHECK_ERROR(infinirtMemcpy(dst, src, size, INFINIRT_MEMCPY_H2D)); + } } void Runtime::memcpyD2H(void *dst, const void *src, size_t size) { INFINICORE_CHECK_ERROR(infinirtMemcpy(dst, src, size, INFINIRT_MEMCPY_D2H)); } -void Runtime::memcpyD2D(void *dst, const void *src, size_t size) { - INFINICORE_CHECK_ERROR(infinirtMemcpyAsync(dst, src, size, INFINIRT_MEMCPY_D2D, stream_)); +void Runtime::memcpyD2D(void *dst, const void *src, size_t size, bool async) { + if (async) { + INFINICORE_CHECK_ERROR(infinirtMemcpyAsync(dst, src, size, INFINIRT_MEMCPY_D2D, stream_)); + } else { + INFINICORE_CHECK_ERROR(infinirtMemcpy(dst, src, size, INFINIRT_MEMCPY_D2D)); + } } std::string Runtime::toString() const { diff --git a/src/infinicore/context/runtime/runtime.hpp b/src/infinicore/context/runtime/runtime.hpp index 4e0ba7abc..2c82da403 100644 --- a/src/infinicore/context/runtime/runtime.hpp +++ b/src/infinicore/context/runtime/runtime.hpp @@ -34,9 +34,9 @@ class Runtime { std::shared_ptr allocateMemory(size_t size); std::shared_ptr allocatePinnedHostMemory(size_t size); - void memcpyH2D(void *dst, const void *src, size_t size); + void memcpyH2D(void *dst, const void *src, size_t size, bool async = true); void memcpyD2H(void *dst, const void *src, size_t size); - void memcpyD2D(void *dst, const void *src, size_t size); + void memcpyD2D(void *dst, const void *src, size_t size, bool async = true); std::string toString() const;