// main.cpp
//==============================================================
// Copyright © 2020 Intel Corporation
//
// SPDX-License-Identifier: MIT
// =============================================================
// oneDPL headers should be included before standard headers
#include <oneapi/dpl/algorithm>
#include <oneapi/dpl/execution>
#include <oneapi/dpl/iterator>
#include <sycl/sycl.hpp>
#include <iostream>
// Demonstrates a stable sort-by-key with oneDPL: a key buffer and a value
// buffer are zipped together and sorted by the key component only, so values
// that share a key keep their original relative order. Returns 0 on success
// and 1 (after printing the first mismatch) if the sorted data is wrong.
int main() {
  // The expected-output formulas in step 3 assume n is even (keys come in pairs).
  const int n = 1000000;
  sycl::buffer<int> keys_buf{n};  // buffer with keys
  sycl::buffer<int> vals_buf{n};  // buffer with values

  // create objects to iterate over buffers
  auto keys_begin = oneapi::dpl::begin(keys_buf);
  auto vals_begin = oneapi::dpl::begin(vals_buf);
  auto counting_begin = oneapi::dpl::counting_iterator<int>{0};

  // use default policy for algorithms execution
  auto policy = oneapi::dpl::execution::dpcpp_default;

  // 1. Initialization of buffers
  // let keys_buf contain {n, n, n-2, n-2, ..., 4, 4, 2, 2}
  oneapi::dpl::transform(policy, counting_begin, counting_begin + n, keys_begin,
                         [n](int i) { return n - (i / 2) * 2; });
  // fill vals_buf with the analogue of std::iota using counting_iterator
  oneapi::dpl::copy(policy, counting_begin, counting_begin + n, vals_begin);

  // 2. Sorting
  auto zipped_begin = oneapi::dpl::make_zip_iterator(keys_begin, vals_begin);
  // stable sort by keys
  oneapi::dpl::stable_sort(
      policy, zipped_begin, zipped_begin + n,
      // Generic lambda is needed because type of lhs and rhs is unspecified.
      [](auto lhs, auto rhs) { return std::get<0>(lhs) < std::get<0>(rhs); });

  // 3. Checking results
  sycl::host_accessor host_keys(keys_buf, sycl::read_only);
  sycl::host_accessor host_vals(vals_buf, sycl::read_only);
  // expected output:
  // keys: {2, 2, 4, 4, ..., n - 2, n - 2, n, n}
  // vals: {n - 2, n - 1, n - 4, n - 3, ..., 2, 3, 0, 1}
  for (int i = 0; i < n; ++i) {
    // Slots 2k and 2k+1 hold the k-th smallest key; the smallest key is 2.
    const int expected_key = (i / 2) * 2 + 2;
    // The stable sort keeps the original order within a key pair, so the even
    // slot holds the smaller original index and the odd slot the larger one.
    const int expected_val = n - (i / 2) * 2 - (i % 2 == 0 ? 2 : 1);
    // Report a failure if either the key or the value is wrong.
    if (host_keys[i] != expected_key || host_vals[i] != expected_val) {
      std::cout << "fail: i = " << i << ", host_keys[i] = " << host_keys[i]
                << ", host_vals[i] = " << host_vals[i] << "\n";
      return 1;
    }
  }
  std::cout << "success\nRun on "
            << policy.queue()
                   .get_device()
                   .template get_info<sycl::info::device::name>()
            << "\n";
  return 0;
}