Skip to content

Commit

Permalink
Merge pull request #480 from nR3D/sycl
Browse files Browse the repository at this point in the history
[SYCL] Improved SYCL particle sorting
  • Loading branch information
Xiangyu-Hu committed Nov 29, 2023
2 parents c1c3f47 + 8ea1507 commit b64440c
Show file tree
Hide file tree
Showing 3 changed files with 212 additions and 165 deletions.
2 changes: 1 addition & 1 deletion src/shared/particle_neighborhood/neighborhood.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ Neighborhood& Neighborhood::operator=(const NeighborhoodDevice &device) {
}

NeighborhoodDevice::NeighborhoodDevice() : current_size_(allocateDeviceData<size_t>(1)),
allocated_size_(Dimensions == 2 ? 28 : 68),
allocated_size_(Dimensions == 2 ? 28 : 82),
j_(allocateDeviceData<size_t>(allocated_size_)),
W_ij_(allocateDeviceData<DeviceReal>(allocated_size_)),
dW_ijV_j_(allocateDeviceData<DeviceReal>(allocated_size_)),
Expand Down
187 changes: 182 additions & 5 deletions src/shared/particles/particle_sorting.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,15 +92,15 @@ void ParticleSorting::sortingParticleData(size_t *begin, size_t size, execution:
if (!index_sorting_device_variables_)
index_sorting_device_variables_ = allocateDeviceData<size_t>(size);

sort_by_key(begin, index_sorting_device_variables_, size, execution::executionQueue.getQueue(), 256, 4, [](size_t *data, size_t idx)
{ return idx; }).wait();
device_radix_sorting.sort_by_key(begin, index_sorting_device_variables_, size, execution::executionQueue.getQueue(), 512, 4).wait();

move_sortable_particle_device_data_(index_sorting_device_variables_, size);

updateSortedDeviceId();
}
//=================================================================================================//
size_t split_count(bool bit, sycl::nd_item<1> &item)
template <class ValueType>
SYCL_EXTERNAL size_t DeviceRadixSort<ValueType>::split_count(bool bit, sycl::nd_item<1> &item)
{
const auto group_range = item.get_local_range().size();
const size_t id = item.get_local_id();
Expand All @@ -117,14 +117,191 @@ size_t split_count(bool bit, sycl::nd_item<1> &item)
return bit ? true_before - 1 + false_totals : id - true_before;
}
//=================================================================================================//
size_t get_digit(size_t key, size_t d, size_t radix_bits)
template <class ValueType>
size_t DeviceRadixSort<ValueType>::get_digit(size_t key, size_t d, size_t radix_bits)
{
return (key >> d * radix_bits) & ((1ul << radix_bits) - 1);
}
//=================================================================================================//
size_t get_bit(size_t key, size_t b)
template <class ValueType>
size_t DeviceRadixSort<ValueType>::get_bit(size_t key, size_t b)
{
return (key >> b) & 1;
}
//=================================================================================================//
template <class ValueType>
size_t DeviceRadixSort<ValueType>::find_max_element(const size_t *data, size_t size, size_t identity)
{
size_t result = identity;
auto &sycl_queue = execution::executionQueue.getQueue();
{
sycl::buffer<size_t> buffer_result(&result, 1);
sycl_queue.submit([&](sycl::handler &cgh)
{
auto reduction_operator = sycl::reduction(buffer_result, cgh, sycl::maximum<>());
cgh.parallel_for(execution::executionQueue.getUniformNdRange(size), reduction_operator,
[=](sycl::nd_item<1> item, auto& reduction) {
if(item.get_global_id() < size)
reduction.combine(data[item.get_global_linear_id()]);
}); })
.wait_and_throw();
}
return result;
}
//=================================================================================================//
template <class ValueType>
void DeviceRadixSort<ValueType>::resize(size_t data_size, size_t radix_bits, size_t workgroup_size)
{
data_size_ = data_size;
radix_bits_ = radix_bits;
workgroup_size_ = workgroup_size;
uniform_case_masking_ = data_size % workgroup_size;
uniform_global_size_ = uniform_case_masking_ ? (data_size / workgroup_size + 1) * workgroup_size : data_size;
kernel_range_ = {uniform_global_size_, workgroup_size};
workgroups_ = kernel_range_.get_group_range().size();

radix_ = 1ul << radix_bits; // radix = 2^b

sycl::range<2> buckets_column_major_range = {radix_, workgroups_}, buckets_row_major_range = {workgroups_, radix_};
// Each entry contains global number of digits with the same value
// Column-major, so buckets offsets can be computed by just applying a scan over it
global_buckets_ = std::make_unique<sycl::buffer<size_t, 2>>(buckets_column_major_range);
// Each entry contains global number of digits with the same and lower values
global_buckets_offsets_ = std::make_unique<sycl::buffer<size_t, 2>>(buckets_column_major_range);
local_buckets_offsets_buffer_ = std::make_unique<sycl::buffer<size_t, 2>>(buckets_row_major_range); // save state of local accessor
data_swap_buffer_ = std::make_unique<sycl::buffer<SortablePair>>(uniform_global_size_); // temporary memory for swapping
// Keep extra values to be swapped when kernel range has been made uniform
uniform_extra_swap_buffer_ = std::make_unique<sycl::buffer<SortablePair>>(uniform_global_size_ - data_size);
}
//=================================================================================================//
template <class ValueType>
sycl::event DeviceRadixSort<ValueType>::sort_by_key(size_t *keys, ValueType *data, size_t data_size, sycl::queue &queue, size_t workgroup_size, size_t radix_bits)
{
if(data_size_ != data_size || radix_bits_ != radix_bits || workgroup_size_ != workgroup_size)
resize(data_size, radix_bits, workgroup_size);

// Largest key, increased by 1 if the workgroup is not homogeneous with the data vector,
// the new maximum will be used for those work-items out of data range, that will then be excluded once sorted
const size_t max_key = find_max_element(keys, data_size, 0ul) + (uniform_case_masking_ ? 1 : 0);
const size_t bits_max_key = std::floor(std::log2(max_key)) + 1.0; // bits needed to represent max_key
const size_t length = max_key ? bits_max_key / radix_bits + (bits_max_key % radix_bits ? 1 : 0) : 1; // max number of radix digits

sycl::event sort_event{};
for (int digit = 0; digit < length; ++digit)
{

auto buckets_event = queue.submit([&](sycl::handler &cgh)
{
cgh.depends_on(sort_event);
auto data_swap_acc = data_swap_buffer_->get_access(cgh, sycl::write_only, sycl::no_init);
auto local_buckets = sycl::local_accessor<size_t>(radix_, cgh);
auto local_output = sycl::local_accessor<SortablePair>(kernel_range_.get_local_range(), cgh);
auto global_buckets_accessor = global_buckets_->get_access(cgh, sycl::read_write, sycl::no_init);
auto local_buckets_offsets_accessor = local_buckets_offsets_buffer_->get_access(cgh, sycl::write_only,
sycl::no_init);

cgh.parallel_for(kernel_range_, [=, radix=radix_](sycl::nd_item<1> item) {
const size_t workgroup = item.get_group_linear_id(),
global_id = item.get_global_id();

SortablePair number;
// Initialize key-data pair, with masking in case of non-homogeneous data_size/workgroup_size
if(global_id < data_size)
number = {keys[global_id],
// Give possibility to initialize data here to avoid calling
// another kernel before sort_by_key in order to initialize it
digit ? data[global_id] : global_id};
else // masking extra indexes
// Initialize exceeding values to the largest key considered
number.first = (1 << bits_max_key) - 1; // max key for given number of bits


// Locally sort digit with split primitive
auto radix_digit = get_digit(number.first, digit, radix_bits);
auto rank = split_count(get_bit(radix_digit, 0), item); // sorting first bit
local_output[rank] = number;
for (size_t b = 1; b < radix_bits; ++b) { // sorting remaining bits
item.barrier(sycl::access::fence_space::local_space);
number = local_output[item.get_local_id()];
radix_digit = get_digit(number.first, digit, radix_bits);

rank = split_count(get_bit(radix_digit, b), item);
local_output[rank] = number;
}

// Initialize local buckets to zero, since they are uninitialized by default
for (size_t r = 0; r < radix; ++r)
local_buckets[r] = 0;

item.barrier(sycl::access::fence_space::local_space);
{
sycl::atomic_ref<size_t, sycl::memory_order_relaxed, sycl::memory_scope_work_group,
sycl::access::address_space::local_space> bucket_r{local_buckets[radix_digit]};
++bucket_r;
item.barrier(sycl::access::fence_space::local_space);
}

// Save local buckets to global memory, with one row per work-group (in column-major order)
for (size_t r = 0; r < radix; ++r)
global_buckets_accessor[r][workgroup] = local_buckets[r];

if(global_id < data_size)
data_swap_acc[workgroup_size * workgroup + rank] = number; // save local sorting back to data

// Compute local buckets offsets
size_t *begin = local_buckets.get_pointer(), *end = begin + radix,
*outBegin = local_buckets_offsets_accessor.get_pointer().get() + workgroup * radix;
sycl::joint_exclusive_scan(item.get_group(), begin, end, outBegin, sycl::plus<size_t>{});
}); });

// Global synchronization to make sure that all locally computed buckets have been copied to global memory

sycl::event scan_event = queue.submit([&](sycl::handler &cgh) {
cgh.depends_on(buckets_event);
auto global_buckets_accessor = global_buckets_->get_access(cgh, sycl::read_only);
auto global_buckets_offsets_accessor = global_buckets_offsets_->get_access(cgh, sycl::write_only);
cgh.parallel_for(kernel_range_, [=](sycl::nd_item<1> item) {
// Compute global buckets offsets
if(item.get_group_linear_id() == 0) {
size_t *begin = global_buckets_accessor.get_pointer(), *end = begin + global_buckets_accessor.size();
sycl::joint_exclusive_scan(item.get_group(), begin, end,
global_buckets_offsets_accessor.get_pointer(), sycl::plus<size_t>{});
}
});
});

sort_event = queue.submit([&](sycl::handler &cgh)
{
cgh.depends_on(scan_event);
auto data_swap_acc = data_swap_buffer_->get_access(cgh, sycl::read_only);
auto global_buckets_accessor = global_buckets_->get_access(cgh, sycl::read_only);
auto global_buckets_offsets_accessor = global_buckets_offsets_->get_access(cgh, sycl::read_write);
auto local_buckets_offsets_accessor = local_buckets_offsets_buffer_->get_access(cgh, sycl::read_only);
cgh.parallel_for(kernel_range_, [=](sycl::nd_item<1> item) {
// Compute global buckets offsets
size_t *begin = global_buckets_accessor.get_pointer(), *end = begin + global_buckets_accessor.size();
sycl::joint_exclusive_scan(item.get_group(), begin, end,
global_buckets_offsets_accessor.get_pointer(), sycl::plus<size_t>{});

// Mask only relevant indexes. All max_keys added to homogenize the computations
// should be owned by work-items with global_id >= data_size
if(item.get_global_id() < data_size) {
// Retrieve position and sorted data from swap memory
const size_t rank = item.get_local_id(), workgroup = item.get_group_linear_id();
const SortablePair number = data_swap_acc[workgroup_size * workgroup + rank];
const size_t radix_digit = get_digit(number.first, digit, radix_bits);

// Compute sorted position based on global and local buckets
const size_t data_offset = global_buckets_offsets_accessor[radix_digit][workgroup] + rank -
local_buckets_offsets_accessor[workgroup][radix_digit];

// Copy to original data pointers
keys[data_offset] = number.first;
data[data_offset] = number.second;
}
}); });
}
return sort_event;
}
//=================================================================================================//
} // namespace SPH

0 comments on commit b64440c

Please sign in to comment.