Skip to content

Call setDevice on each thread at entry point. #3269

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 5, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/api/c/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ using detail::getActiveDeviceId;
using detail::getBackend;
using detail::getDeviceCount;
using detail::getDeviceInfo;
using detail::init;
using detail::intl;
using detail::isDoubleSupported;
using detail::isHalfSupported;
Expand Down Expand Up @@ -107,7 +108,7 @@ af_err af_init() {
try {
thread_local std::once_flag flag;
std::call_once(flag, []() {
getDeviceInfo();
init();
#if defined(USE_MKL) && !defined(USE_STATIC_MKL)
int errCode = -1;
// Have used the AF_MKL_INTERFACE_SIZE as regular if's so that
Expand Down
5 changes: 5 additions & 0 deletions src/backend/cpu/platform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,11 @@ int& getMaxJitSize() {

// Reports the device count for this backend, taken directly from
// DeviceManager::NUM_DEVICES.
int getDeviceCount() {
    return DeviceManager::NUM_DEVICES;
}

// Per-thread entry-point hook.
//
// The thread_local reference below is initialized the first time each
// thread reaches this function, so DeviceManager::getInstance() is invoked
// once per calling thread; subsequent calls from the same thread do nothing.
void init() {
    thread_local const auto& manager = DeviceManager::getInstance();
    // The reference exists only to trigger the call above.
    UNUSED(manager);
}

// Returns the id of the currently active device, as stored in
// DeviceManager::ACTIVE_DEVICE_ID.
unsigned getActiveDeviceId() {
    return DeviceManager::ACTIVE_DEVICE_ID;
}

Expand Down
2 changes: 2 additions & 0 deletions src/backend/cpu/platform.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ int& getMaxJitSize();

int getDeviceCount();

// Per-thread initialization hook. The definition in platform.cpp obtains
// DeviceManager::getInstance() through a thread_local reference, so the
// accessor is called once for each thread that invokes init().
void init();

unsigned getActiveDeviceId();

size_t getDeviceMemorySize(int device);
Expand Down
6 changes: 6 additions & 0 deletions src/backend/cuda/platform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,12 @@ int getDeviceCount() {
}
}

// Per-thread entry-point hook.
//
// The thread_local initializer runs the first time each thread reaches this
// function: it looks up the native id of the active device and passes it to
// cudaSetDevice. Later calls from the same thread do nothing.
void init() {
    thread_local const auto status = [] {
        const auto nativeId = getDeviceNativeId(getActiveDeviceId());
        return cudaSetDevice(nativeId);
    }();
    // The returned status is not inspected here.
    UNUSED(status);
}

// Returns the active device id by forwarding the value obtained from
// tlocalActiveDeviceId().
unsigned getActiveDeviceId() {
    return tlocalActiveDeviceId();
}

int getDeviceNativeId(int device) {
Expand Down
2 changes: 2 additions & 0 deletions src/backend/cuda/platform.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ int& getMaxJitSize();

int getDeviceCount();

// Per-thread initialization hook. The definition in platform.cpp calls
// cudaSetDevice with the native id of the active device inside a
// thread_local initializer, i.e. once for each thread that invokes init().
void init();

unsigned getActiveDeviceId();

int getDeviceNativeId(int device);
Expand Down
5 changes: 5 additions & 0 deletions src/backend/opencl/platform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,11 @@ int getDeviceCount() noexcept try {
return 0;
}

void init() {
thread_local const DeviceManager& devMngr = DeviceManager::getInstance();
UNUSED(devMngr);
}

unsigned getActiveDeviceId() {
// Second element is the queue id, which is
// what we mean by active device id in opencl backend
Expand Down
2 changes: 2 additions & 0 deletions src/backend/opencl/platform.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ std::string getDeviceInfo() noexcept;

int getDeviceCount() noexcept;

// Per-thread initialization hook. The definition in platform.cpp obtains
// DeviceManager::getInstance() through a thread_local reference, so the
// accessor is called once for each thread that invokes init().
void init();

unsigned getActiveDeviceId();

int& getMaxJitSize();
Expand Down