/******************************************************************************
 * thread_partition.c
 *****************************************************************************/
/******************************************************************************
* INCLUDES
*****************************************************************************/
#include "thread_partition.h"
#include "timer.h"
/******************************************************************************
* PRIVATE FUNCTIONS
*****************************************************************************/
/* Number of lprobe() invocations. partition_weighted() resets it to zero
 * before it starts searching and lprobe() increments it on every call. */
static idx_t nprobes = 0;
/**
 * @brief Perform a linear search on an array for the first entry that exceeds
 *        a target value.
 *
 *        Only the entries strictly after `left` are examined; weights[left]
 *        itself is never compared. The result is therefore only meaningful
 *        when the caller already knows that weights[left] <= target.
 *
 * @param weights The array to search.
 * @param left The lower bound to begin at (this entry is not tested).
 * @param right The upper (exclusive) bound of items.
 * @param target The target value.
 *
 * @return The smallest index j with left < j < right and weights[j] > target,
 *         or `right` if no such entry exists. When the array is
 *         non-decreasing this is the exclusive end of the run of entries
 *         that are <= target.
 */
static idx_t p_linear_search(
    idx_t const * const weights,
    idx_t const left,
    idx_t const right,
    idx_t const target)
{
  /* An empty window has nothing to scan. Handling it up front also avoids
   * forming `right - 1`, which wraps around to a huge bound when right == 0
   * and idx_t is an unsigned type. */
  if(right <= left) {
    return right;
  }

  for(idx_t x=left+1; x < right; ++x) {
    if(target < weights[x]) {
      return x;
    }
  }

  return right;
}
/**
 * @brief Perform a binary search on a non-decreasing array for the first
 *        entry that exceeds a target value.
 *
 *        The window is halved until it holds at most 8 entries, after which
 *        the remainder is handed to p_linear_search(). Both phases use the
 *        same return convention: the index one past the last entry that is
 *        <= target.
 *
 * @param weights The array to search. Must be non-decreasing.
 * @param left The lower bound to begin at. As in p_linear_search(), this
 *             entry is assumed to be <= target.
 * @param right The upper (exclusive) bound of items.
 * @param target The target value.
 *
 * @return The index j such that weights[j-1] <= target && weights[j] > target,
 *         or `right` if every examined entry is <= target.
 */
static idx_t p_binary_search(
    idx_t const * const weights,
    idx_t left,
    idx_t right,
    idx_t const target)
{
  while((right - left) > 8) {
    idx_t const mid = left + ((right - left) / 2);

    if(weights[mid] <= target) {
      /* weights[mid] fits. If its successor does not, mid is the last fitting
       * entry, so the exclusive end is mid+1 -- the same value the linear
       * fallback would report. */
      if(weights[mid+1] > target) {
        return mid + 1;
      }
      /* weights[mid+1] also fits, so it is a valid new lower bound */
      left = mid + 1;
    } else {
      /* weights[mid] is already too large; the answer is at or before mid */
      right = mid;
    }
  }

  return p_linear_search(weights, left, right, target);
}
/**
 * @brief Greedy feasibility probe: try to cut the items into `nparts`
 *        contiguous parts whose weights do not exceed `bottleneck`, recording
 *        the cut points in `parts`.
 *
 *        Not static because we use it in unit tests.
 *
 *        Increments the file-scope `nprobes` counter on every call.
 *
 *        NOTE: the bucket stride is nitems / nparts. The function divides by
 *        nparts and cannot advance when the stride is zero, so it expects
 *        nparts > 0 and nitems >= nparts. partition_weighted() only probes
 *        when nitems > nparts.
 *
 * @param weights Inclusive prefix sums of the item weights; weights[nitems-1]
 *                is read as the total weight.
 * @param nitems The number of items.
 * @param parts Output array with room for nparts+1 entries. parts[0] is set
 *              to 0, and every entry not assigned by the probe is nitems.
 * @param nparts The number of parts.
 * @param bottleneck The largest weight a single part may have.
 *
 * @return true if the probe fit all of the weight within the given parts.
 */
bool lprobe(
    idx_t const * const weights,
    idx_t const nitems,
    idx_t * const parts,
    idx_t const nparts,
    idx_t const bottleneck)
{
  ++nprobes;

  idx_t const wtotal = weights[nitems-1];
  idx_t const chunk = nitems / nparts;

  /* initialize partitioning */
  parts[0] = 0;
  for(idx_t p=1; p <= nparts; ++p) {
    parts[p] = nitems;
  }

  /* bsum is the largest prefix sum the current part may reach */
  idx_t bsum = bottleneck;
  idx_t step = chunk;

  for(idx_t p=1; p < nparts; ++p) {
    /* jump to the next bucket */
    while(step < nitems && weights[step] < bsum) {
      step += chunk;
    }

    /* find the end (exclusive) index of process p */
    parts[p] = p_binary_search(weights, step - chunk, SS_MIN(step,nitems),
        bsum);

    /* we ran out of stuff to do */
    if(parts[p] == nitems) {
      /* check for pathological case when the last weight is larger than
       * bottleneck */
      idx_t const begin = parts[p-1];
      /* the part may start at item 0, in which case nothing precedes it;
       * indexing weights[begin-1] there would read before the array */
      idx_t const before = (begin > 0) ? weights[begin-1] : 0;
      idx_t const size_last = wtotal - before;
      /* a part that exactly meets the bottleneck is still feasible, matching
       * the inclusive comparison used for the final part below */
      return size_last <= bottleneck;
    }

    bsum = weights[parts[p]-1] + bottleneck;
  }

  /* the final part takes whatever is left; it fits iff bsum covers the total */
  return bsum >= wtotal;
}
/**
 * @brief Bisect on the bottleneck value, using lprobe() as the feasibility
 *        test, until the search interval is no wider than `eps`.
 *
 *        `parts` is overwritten by every probe, so on return it reflects the
 *        last probe that was run, which is not necessarily the one for the
 *        returned bottleneck.
 *
 * @param weights Inclusive prefix sums of the item weights. Not modified.
 * @param nitems The number of items (must be at least 1).
 * @param parts Scratch/output array handed to lprobe() (nparts+1 entries).
 * @param nparts The number of parts (must be nonzero).
 * @param eps Tolerance: the search stops once upper <= lower + eps.
 *
 * @return The smallest bottleneck that a probe reported as feasible, or the
 *         total weight if no probe succeeded.
 */
static idx_t p_eps_rb_partition_1d(
    idx_t * const weights,
    idx_t const nitems,
    idx_t * const parts,
    idx_t const nparts,
    idx_t const eps)
{
  idx_t const tot_weight = weights[nitems-1];

  /* Recover the heaviest single item from adjacent prefix sums. No part can
   * be lighter than its heaviest item, so this is a valid lower bound. It is
   * also required for correctness: the probe does not compare the first entry
   * of a search window against the budget, so a budget smaller than a single
   * item can be reported as feasible. */
  idx_t heaviest = weights[0];
  for(idx_t i=1; i < nitems; ++i) {
    idx_t const item = weights[i] - weights[i-1];
    if(item > heaviest) {
      heaviest = item;
    }
  }

  /* a perfectly balanced split is the other lower bound */
  idx_t lower = tot_weight / nparts;
  if(lower < heaviest) {
    lower = heaviest;
  }
  /* putting everything in one part is always feasible */
  idx_t upper = tot_weight;

  do {
    idx_t const mid = lower + ((upper - lower) / 2);
    if(lprobe(weights, nitems, parts, nparts, mid)) {
      upper = mid;
    } else {
      lower = mid+1;
    }
  } while(upper > lower + eps);

  return upper;
}
/******************************************************************************
* PUBLIC FUNCTIONS
*****************************************************************************/
/**
 * @brief Split `nitems` weighted items into `nparts` contiguous parts while
 *        minimizing the weight of the heaviest part.
 *
 *        `weights` is modified in place: it is replaced by its inclusive
 *        prefix sum before partitioning. Time spent here is charged to
 *        TIMER_PART, and the file-scope probe counter is reset.
 *
 * @param weights Item weights on entry; inclusive prefix sums on return.
 * @param nitems The number of items.
 * @param nparts The number of parts.
 * @param bottleneck Output: the weight of the heaviest part.
 *
 * @return An array of nparts+1 boundaries obtained from splatt_malloc(); part
 *         p covers items [parts[p], parts[p+1]). Ownership presumably passes
 *         to the caller -- the matching release routine is not visible here.
 */
idx_t * partition_weighted(
    idx_t * const weights,
    idx_t const nitems,
    idx_t const nparts,
    idx_t * const bottleneck)
{
  timer_start(&timers[TIMER_PART]);

  /* from here on weights[i] is the combined weight of items 0..i */
  prefix_sum_inc(weights, nitems);

  idx_t * parts = splatt_malloc((nparts+1) * sizeof(*parts));

  nprobes = 0;
  idx_t bneck = 0;

  /* actual partitioning */
  if(nitems > nparts) {
    /* use recursive bisectioning with 0 tolerance to get exact solution */
    bneck = p_eps_rb_partition_1d(weights, nitems, parts, nparts, 0);

    /* apply partitioning that we found */
    bool success = lprobe(weights, nitems, parts, nparts, bneck);
    assert(success == true);
    (void) success; /* only inspected by the assert; unused under NDEBUG */

  } else {
    /* Do a trivial partitioning. Silly, but this happens when tensors have
     * short modes. Each item gets a part of its own. */
    idx_t prev_sum = 0;
    for(idx_t p=0; p < nitems; ++p) {
      parts[p] = p;

      /* weights[] holds prefix sums by now, so the weight of item p alone is
       * the difference between neighbouring entries */
      idx_t const item = weights[p] - prev_sum;
      bneck = SS_MAX(bneck, item);
      prev_sum = weights[p];
    }

    /* the remaining parts are empty */
    for(idx_t p=nitems; p <= nparts; ++p) {
      parts[p] = nitems;
    }
  }

  *bottleneck = bneck;

  timer_stop(&timers[TIMER_PART]);
  return parts;
}
/**
 * @brief Split `nitems` items into `nparts` contiguous parts of equal nominal
 *        width, ignoring weights. The last part absorbs whatever remains.
 *
 *        Time spent here is charged to TIMER_PART.
 *
 * @param nitems The number of items.
 * @param nparts The number of parts.
 *
 * @return An array of nparts+1 boundaries obtained from splatt_malloc(). The
 *         first entry is 0 and the last is nitems.
 */
idx_t * partition_simple(
    idx_t const nitems,
    idx_t const nparts)
{
  timer_start(&timers[TIMER_PART]);

  idx_t * bounds = splatt_malloc((nparts+1) * sizeof(*bounds));
  bounds[0] = 0;

  /* nominal width of a part; never allowed to drop to zero */
  idx_t const width = SS_MAX(nitems / nparts, 1);

  idx_t which = 1;
  while(which < nparts) {
    /* clamp the ideal boundary to the item count, then keep it at least 1 */
    idx_t const clamped = SS_MIN(width * which, nitems);
    bounds[which] = SS_MAX(clamped, 1);
    ++which;
  }

  bounds[nparts] = nitems;

  timer_stop(&timers[TIMER_PART]);
  return bounds;
}
/**
 * @brief Replace an array with its inclusive prefix sum, in place. Afterwards
 *        weights[i] holds the sum of the original entries 0..i.
 *
 * @param weights The array to accumulate.
 * @param nitems The number of entries in `weights`.
 */
void prefix_sum_inc(
    idx_t * const weights,
    idx_t const nitems)
{
  /* the first entry is already its own prefix sum */
  idx_t pos = 1;
  while(pos < nitems) {
    idx_t const carried = weights[pos-1];
    weights[pos] = weights[pos] + carried;
    ++pos;
  }
}
/**
 * @brief Replace an array with its exclusive prefix sum, in place. Afterwards
 *        weights[i] holds the sum of the original entries 0..i-1, so
 *        weights[0] is 0.
 *
 *        An empty array (nitems == 0) is left untouched.
 *
 * @param weights The array to accumulate.
 * @param nitems The number of entries in `weights`.
 */
void prefix_sum_exc(
    idx_t * const weights,
    idx_t const nitems)
{
  /* Sum of the original entries seen so far. Starting the loop at index 0
   * means weights[0] is only touched when it actually exists. */
  idx_t running = 0;
  for(idx_t x=0; x < nitems; ++x) {
    idx_t const item = weights[x];
    weights[x] = running;
    running += item;
  }
}