Skip to content

Commit

Permalink
[ATLAS] Fix TNN Atlas Bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
doxutx committed May 21, 2024
1 parent dc16e53 commit baff1f6
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 17 deletions.
11 changes: 6 additions & 5 deletions source/tnn/device/atlas/atlas_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,6 @@ AtlasContext::~AtlasContext() {
//}
}

Status AtlasContext::Setup(int device_id) {
this->device_id_ = device_id;
return TNN_OK;
}

Status AtlasContext::LoadLibrary(std::vector<std::string> path) {
return TNN_OK;
}
Expand Down Expand Up @@ -105,6 +100,12 @@ void AtlasContext::SetModelType(ModelType model_type) {
this->model_type_ = model_type;
}

void AtlasContext::SetDeviceId(int device_id) {
this->device_id_ = device_id;
}

int AtlasContext::GetDeviceId() {
return this->device_id_;
}

} // namespace TNN_NS
8 changes: 5 additions & 3 deletions source/tnn/device/atlas/atlas_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,6 @@ class AtlasContext : public Context {
// @brief deconstructor
~AtlasContext();

// @brief setup with specified device id
Status Setup(int device_id);

// @brief load library
virtual Status LoadLibrary(std::vector<std::string> path) override;

Expand Down Expand Up @@ -66,6 +63,11 @@ class AtlasContext : public Context {

// @brief set ModelType
void SetModelType(ModelType model_type);

// @brief set specific device id
void SetDeviceId(int device_id);

int GetDeviceId();

private:
ModelType model_type_;
Expand Down
9 changes: 1 addition & 8 deletions source/tnn/device/atlas/atlas_device.cc
Original file line number Diff line number Diff line change
Expand Up @@ -110,14 +110,7 @@ AbstractLayerAcc* AtlasDevice::CreateLayerAcc(LayerType type) {

Context* AtlasDevice::CreateContext(int device_id) {
auto context = new AtlasContext();

Status ret = context->Setup(device_id);
if (ret != TNN_OK) {
LOGE("Cuda context setup failed.");
delete context;
return NULL;
}

context->SetDeviceId(device_id);
return context;
}

Expand Down
12 changes: 11 additions & 1 deletion source/tnn/device/atlas/atlas_network.cc
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,9 @@ AtlasNetwork::~AtlasNetwork() {
if (acl_ret != ACL_ERROR_NONE) {
LOGE("unload model failed, modelId is %u\n", this->om_model_info_->model_id);
}
this->om_model_info_->model_id = INT_MAX;
}

if (nullptr != this->om_model_info_->model_desc) {
(void)aclmdlDestroyDesc(this->om_model_info_->model_desc);
this->om_model_info_->model_desc = nullptr;
Expand Down Expand Up @@ -115,6 +116,15 @@ AtlasNetwork::~AtlasNetwork() {
this->om_model_weight_ptr_ = nullptr;
this->om_model_info_->weight_size = 0;
}

// Destroy aclrt Device()
if (tnn_atlas_context->GetDeviceId() != INT_MAX) {
LOGD("Reset aclrt Device.\n");
acl_ret = aclrtResetDevice(tnn_atlas_context->GetDeviceId());
if (acl_ret != ACL_ERROR_NONE) {
LOGE("TNN ATLAS Network: aclrtResetDevice() failed\n");
}
}
}

// Call DeInit() of DefaultNetwork
Expand Down

0 comments on commit baff1f6

Please sign in to comment.