Skip to content

Commit 38db7cf

Browse files
committed
issue/591 - fix rearrange context mismatch
1 parent a311e9c commit 38db7cf

File tree

3 files changed

+9
-1
lines changed

3 files changed

+9
-1
lines changed

include/infinicore/context/context.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ size_t getDeviceCount(Device::Type type);
1717

1818
infinirtStream_t getStream();
1919
infiniopHandle_t getInfiniopHandle();
20+
infiniopHandle_t getInfiniopHandle(Device device);
2021

2122
void syncStream();
2223
void syncDevice();

src/infinicore/context/context_impl.cc

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,13 @@ infiniopHandle_t getInfiniopHandle() {
103103
return ContextImpl::singleton().getCurrentRuntime()->infiniopHandle();
104104
}
105105

106+
infiniopHandle_t getInfiniopHandle(Device device) {
107+
if (device != getDevice()) {
108+
setDevice(device);
109+
}
110+
return ContextImpl::singleton().getCurrentRuntime()->infiniopHandle();
111+
}
112+
106113
void syncStream() {
107114
return ContextImpl::singleton().getCurrentRuntime()->syncStream();
108115
}

src/infinicore/ops/rearrange/rearrange_infiniop.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ void calculate(Tensor y, Tensor x) {
2727
infiniopRearrangeDescriptor_t desc = nullptr;
2828

2929
if (!desc_opt) {
30-
INFINICORE_CHECK_ERROR(infiniopCreateRearrangeDescriptor(context::getInfiniopHandle(), &desc, y->desc(), x->desc()));
30+
INFINICORE_CHECK_ERROR(infiniopCreateRearrangeDescriptor(context::getInfiniopHandle(y->device()), &desc, y->desc(), x->desc()));
3131
cache.put(seed, desc);
3232
} else {
3333
desc = *desc_opt;

0 commit comments

Comments
 (0)