Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ucm/store/connector/nfsstore_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from typing import Dict, List, Tuple

import torch
import ucmnfsstore
from connector import ucmnfsstore
from connector.ucmstore import Task, UcmKVStoreBase


Expand Down
32 changes: 16 additions & 16 deletions ucm/store/device/ascend/ascend_device.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@
namespace UC {

template <typename Api, typename... Args>
Status AclrtApi(const char* caller, const char* file, const size_t line, const char* name,
Api&& api, Args&&... args)
Status AscendApi(const char* caller, const char* file, const size_t line, const char* name,
Api&& api, Args&&... args)
{
auto ret = api(args...);
if (ret != ACL_SUCCESS) {
Expand All @@ -43,7 +43,7 @@ Status AclrtApi(const char* caller, const char* file, const size_t line, const c
}
return Status::OK();
}
#define ACLRT_API(api, ...) AclrtApi(__FUNCTION__, __FILE__, __LINE__, #api, api, __VA_ARGS__)
#define ASCEND_API(api, ...) AscendApi(__FUNCTION__, __FILE__, __LINE__, #api, api, __VA_ARGS__)

class AscendDevice : public IBufferedDevice {
struct Closure {
Expand All @@ -58,11 +58,11 @@ class AscendDevice : public IBufferedDevice {
}

public:
AclrtDevice(const int32_t deviceId, const size_t bufferSize, const size_t bufferNumber)
AscendDevice(const int32_t deviceId, const size_t bufferSize, const size_t bufferNumber)
: IBufferedDevice{deviceId, bufferSize, bufferNumber}, stop_{false}, stream_{nullptr}
{
}
~AclrtDevice() override
~AscendDevice() override
{
if (this->cbThread_.joinable()) {
auto tid = this->cbThread_.native_handle();
Expand All @@ -74,32 +74,32 @@ class AscendDevice : public IBufferedDevice {
(void)aclrtDestroyStream(this->stream_);
this->stream_ = nullptr;
}
(void)aclrtResetDevice(this->deviceId_);
(void)aclrtResetDevice(this->deviceId);
}
Status Setup() override
{
auto status = Status::OK();
if ((status = ACLRT_API(aclrtSetDevice, this->deviceId_)).Failure()) { return status; }
if ((status = ASCEND_API(aclrtSetDevice, this->deviceId)).Failure()) { return status; }
if ((status = IBufferedDevice::Setup()).Failure()) { return status; }
if ((status = ACLRT_API(aclrtCreateStream, &this->stream_)).Failure()) { return status; }
if ((status = ASCEND_API(aclrtCreateStream, &this->stream_)).Failure()) { return status; }
this->cbThread_ = std::thread([this] {
while (!this->stop_) { (void)aclrtProcessReport(10); }
});
auto tid = this->cbThread_.native_handle();
if ((status = ACLRT_API(aclrtSubscribeReport, tid, this->stream_)).Failure()) {
if ((status = ASCEND_API(aclrtSubscribeReport, tid, this->stream_)).Failure()) {
return status;
}
return Status::OK();
}
Status H2DAsync(std::byte* dst, const std::byte* src, const size_t count) override
{
return ACLRT_API(aclrtMemcpyAsync, dst, count, src, count, ACL_MEMCPY_HOST_TO_DEVICE,
this->stream_);
return ASCEND_API(aclrtMemcpyAsync, dst, count, src, count, ACL_MEMCPY_HOST_TO_DEVICE,
this->stream_);
}
Status D2HAsync(std::byte* dst, const std::byte* src, const size_t count) override
{
return ACLRT_API(aclrtMemcpyAsync, dst, count, src, count, ACL_MEMCPY_DEVICE_TO_HOST,
this->stream_);
return ASCEND_API(aclrtMemcpyAsync, dst, count, src, count, ACL_MEMCPY_DEVICE_TO_HOST,
this->stream_);
}
Status AppendCallback(std::function<void(bool)> cb) override
{
Expand All @@ -108,15 +108,15 @@ class AscendDevice : public IBufferedDevice {
UC_ERROR("Failed to make closure for append cb.");
return Status::OutOfMemory();
}
return ACLRT_API(aclrtLaunchCallback, Trampoline, (void*)c, ACL_CALLBACK_NO_BLOCK,
this->stream_);
return ASCEND_API(aclrtLaunchCallback, Trampoline, (void*)c, ACL_CALLBACK_NO_BLOCK,
this->stream_);
}

protected:
std::shared_ptr<std::byte> MakeBuffer(const size_t size) override
{
std::byte* host = nullptr;
auto status = ACLRT_API(aclrtMallocHost, (void**)&host, size);
auto status = ASCEND_API(aclrtMallocHost, (void**)&host, size);
if (status.Success()) { return std::shared_ptr<std::byte>(host, aclrtFreeHost); }
return nullptr;
}
Expand Down