Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions include/infinicore/context/context.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@ std::shared_ptr<Memory> allocateMemory(size_t size);
std::shared_ptr<Memory> allocateHostMemory(size_t size);
std::shared_ptr<Memory> allocatePinnedHostMemory(size_t size);

void memcpyH2D(void *dst, const void *src, size_t size);
void memcpyH2D(void *dst, const void *src, size_t size, bool async = true);
void memcpyD2H(void *dst, const void *src, size_t size);
void memcpyD2D(void *dst, const void *src, size_t size);
void memcpyD2D(void *dst, const void *src, size_t size, bool async = true);
void memcpyH2H(void *dst, const void *src, size_t size);

} // namespace context
Expand Down
8 changes: 4 additions & 4 deletions src/infinicore/context/context_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -129,16 +129,16 @@ std::shared_ptr<Memory> allocatePinnedHostMemory(size_t size) {
return ContextImpl::singleton().getCurrentRuntime()->allocatePinnedHostMemory(size);
}

void memcpyH2D(void *dst, const void *src, size_t size) {
return ContextImpl::singleton().getCurrentRuntime()->memcpyH2D(dst, src, size);
void memcpyH2D(void *dst, const void *src, size_t size, bool async) {
return ContextImpl::singleton().getCurrentRuntime()->memcpyH2D(dst, src, size, async);
}

void memcpyD2H(void *dst, const void *src, size_t size) {
return ContextImpl::singleton().getCurrentRuntime()->memcpyD2H(dst, src, size);
}

void memcpyD2D(void *dst, const void *src, size_t size) {
return ContextImpl::singleton().getCurrentRuntime()->memcpyD2D(dst, src, size);
void memcpyD2D(void *dst, const void *src, size_t size, bool async) {
return ContextImpl::singleton().getCurrentRuntime()->memcpyD2D(dst, src, size, async);
}

void memcpyH2H(void *dst, const void *src, size_t size) {
Expand Down
16 changes: 12 additions & 4 deletions src/infinicore/context/runtime/runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -76,16 +76,24 @@ std::shared_ptr<Memory> Runtime::allocatePinnedHostMemory(size_t size) {
true);
}

void Runtime::memcpyH2D(void *dst, const void *src, size_t size) {
INFINICORE_CHECK_ERROR(infinirtMemcpyAsync(dst, src, size, INFINIRT_MEMCPY_H2D, stream_));
void Runtime::memcpyH2D(void *dst, const void *src, size_t size, bool async) {
if (async) {
INFINICORE_CHECK_ERROR(infinirtMemcpyAsync(dst, src, size, INFINIRT_MEMCPY_H2D, stream_));
} else {
INFINICORE_CHECK_ERROR(infinirtMemcpy(dst, src, size, INFINIRT_MEMCPY_H2D));
}
}

void Runtime::memcpyD2H(void *dst, const void *src, size_t size) {
INFINICORE_CHECK_ERROR(infinirtMemcpy(dst, src, size, INFINIRT_MEMCPY_D2H));
}

void Runtime::memcpyD2D(void *dst, const void *src, size_t size) {
INFINICORE_CHECK_ERROR(infinirtMemcpyAsync(dst, src, size, INFINIRT_MEMCPY_D2D, stream_));
void Runtime::memcpyD2D(void *dst, const void *src, size_t size, bool async) {
if (async) {
INFINICORE_CHECK_ERROR(infinirtMemcpyAsync(dst, src, size, INFINIRT_MEMCPY_D2D, stream_));
} else {
INFINICORE_CHECK_ERROR(infinirtMemcpy(dst, src, size, INFINIRT_MEMCPY_D2D));
}
}

std::string Runtime::toString() const {
Expand Down
4 changes: 2 additions & 2 deletions src/infinicore/context/runtime/runtime.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,9 @@ class Runtime {
std::shared_ptr<Memory> allocateMemory(size_t size);
std::shared_ptr<Memory> allocatePinnedHostMemory(size_t size);

void memcpyH2D(void *dst, const void *src, size_t size);
void memcpyH2D(void *dst, const void *src, size_t size, bool async = true);
void memcpyD2H(void *dst, const void *src, size_t size);
void memcpyD2D(void *dst, const void *src, size_t size);
void memcpyD2D(void *dst, const void *src, size_t size, bool async = true);

std::string toString() const;

Expand Down