diff --git a/src/details/ArborX_DetailsDistributor.hpp b/src/details/ArborX_DetailsDistributor.hpp index 185003d9f..c43cfc035 100644 --- a/src/details/ArborX_DetailsDistributor.hpp +++ b/src/details/ArborX_DetailsDistributor.hpp @@ -181,17 +181,38 @@ class Distributor requests.reserve(outdegrees + indegrees); for (int i = 0; i < indegrees; ++i) { - requests.emplace_back(); - MPI_Irecv(src_buffer.data() + _src_offsets[i] * num_packets, - _src_counts[i] * num_packets * sizeof(ValueType), MPI_BYTE, - _sources[i], MPI_ANY_TAG, _comm, &requests.back()); + if (_sources[i] != comm_rank) + { + auto const message_size = + _src_counts[i] * num_packets * sizeof(ValueType); + auto const receive_buffer_ptr = + src_buffer.data() + _src_offsets[i] * num_packets; + requests.emplace_back(); + MPI_Irecv(receive_buffer_ptr, message_size, MPI_BYTE, _sources[i], 123, + _comm, &requests.back()); + } } for (int i = 0; i < outdegrees; ++i) { - requests.emplace_back(); - MPI_Isend(dest_buffer.data() + _dest_offsets[i] * num_packets, - _dest_counts[i] * num_packets * sizeof(ValueType), MPI_BYTE, - _destinations[i], 123, _comm, &requests.back()); + auto const message_size = + _dest_counts[i] * num_packets * sizeof(ValueType); + auto const send_buffer_ptr = + dest_buffer.data() + _dest_offsets[i] * num_packets; + if (_destinations[i] == comm_rank) + { + auto const it = std::find(_sources.begin(), _sources.end(), comm_rank); + ARBORX_ASSERT(it != _sources.end()); + auto const position = it - _sources.begin(); + auto const receive_buffer_ptr = + src_buffer.data() + _src_offsets[position] * num_packets; + std::memcpy(receive_buffer_ptr, send_buffer_ptr, message_size); + } + else + { + requests.emplace_back(); + MPI_Isend(send_buffer_ptr, message_size, MPI_BYTE, _destinations[i], + 123, _comm, &requests.back()); + } } if (!requests.empty()) MPI_Waitall(requests.size(), requests.data(), MPI_STATUSES_IGNORE);