-
Notifications
You must be signed in to change notification settings - Fork 344
/
mask.h
374 lines (289 loc) · 10.6 KB
/
mask.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
// Copyright Contributors to the Open Shading Language project.
// SPDX-License-Identifier: BSD-3-Clause
// https://github.com/AcademySoftwareFoundation/OpenShadingLanguage
#pragma once

#include <cstdint>
#include <type_traits>

#include <OSL/oslconfig.h>

// Headers providing the bit-manipulation primitives selected inside the OSL
// namespace below; they must be included here at global scope, never inside
// OSL_NAMESPACE_ENTER.
#if OSL_CPLUSPLUS_VERSION >= 20
#    include <bit>
#elif OSL_INTEL_CLASSIC_COMPILER_VERSION
#    include <immintrin.h>
#elif defined(_MSC_VER)
#    include <intrin.h>
#endif
OSL_NAMESPACE_ENTER
// clang-format off

// Bit-manipulation primitives used by Mask:
//   popcount(x)    - number of set bits in x
//   countr_zero(x) - number of trailing zero bits, i.e. the index of the
//                    lowest set bit of x
// NOTE: only call countr_zero with x != 0.  std::countr_zero(0) returns the
// bit width of x, but the builtin/intrinsic versions below are undefined for 0.
// The headers these rely on (<bit>, <immintrin.h>, <intrin.h>) are included
// at global scope at the top of this file, not inside the OSL namespace.
#if OSL_CPLUSPLUS_VERSION >= 20
// For C++20 and beyond, these are in the standard library
using std::popcount;
using std::countr_zero;
#elif OSL_INTEL_CLASSIC_COMPILER_VERSION
OSL_FORCEINLINE int popcount(uint32_t x) noexcept { return _mm_popcnt_u32(x); }
OSL_FORCEINLINE int popcount(uint64_t x) noexcept { return _mm_popcnt_u64(x); }
OSL_FORCEINLINE int countr_zero(uint32_t x) noexcept { return _bit_scan_forward(x); }
OSL_FORCEINLINE int countr_zero(uint64_t x) noexcept {
    unsigned __int32 index;
    _BitScanForward64(&index, x);
    return static_cast<int>(index);
}
#elif defined(__GNUC__) || defined(__clang__)
// Overloaded on the fundamental unsigned types rather than uint32_t/uint64_t
// so that unsigned int, unsigned long and unsigned long long arguments all
// have an exact match.  On LP64 platforms uint64_t is unsigned long, so with
// only uint32_t/uint64_t overloads an unsigned long long argument would be
// ambiguous.  uint32_t/uint64_t arguments still resolve to exact matches.
OSL_FORCEINLINE int popcount(unsigned int x) noexcept { return __builtin_popcount(x); }
OSL_FORCEINLINE int popcount(unsigned long x) noexcept { return __builtin_popcountl(x); }
OSL_FORCEINLINE int popcount(unsigned long long x) noexcept { return __builtin_popcountll(x); }
OSL_FORCEINLINE int countr_zero(unsigned int x) noexcept { return __builtin_ctz(x); }
OSL_FORCEINLINE int countr_zero(unsigned long x) noexcept { return __builtin_ctzl(x); }
OSL_FORCEINLINE int countr_zero(unsigned long long x) noexcept { return __builtin_ctzll(x); }
#elif defined(_MSC_VER)
OSL_FORCEINLINE int popcount(uint32_t x) noexcept { return static_cast<int>(__popcnt(x)); }
OSL_FORCEINLINE int popcount(uint64_t x) noexcept { return static_cast<int>(__popcnt64(x)); }
OSL_FORCEINLINE int countr_zero(uint32_t x) noexcept {
    unsigned long index;
    _BitScanForward(&index, x);
    return static_cast<int>(index);
}
// NOTE(review): _BitScanForward64/__popcnt64 are documented for 64-bit
// targets only -- confirm 32-bit MSVC builds are not required.
OSL_FORCEINLINE int countr_zero(uint64_t x) noexcept {
    unsigned long index;
    _BitScanForward64(&index, x);
    return static_cast<int>(index);
}
#else
#    error "popcount and countr_zero implementations needed for this compiler"
#endif
// clang-format on
// Simple wrapper to identify a single lane index vs. a mask_value.
// Converts implicitly to int so it can be passed wherever a lane index is
// expected.  The index is const, so a Lane can be copied but not reassigned.
class Lane {
    const int m_index;

public:
    explicit OSL_FORCEINLINE Lane(int index) : m_index(index) {}

    Lane() = delete;

    // Defaulted rather than user-provided: this keeps Lane trivially
    // copyable, so it can be passed by value in registers.
    // A user-provided copy constructor forces pass-by-memory under the
    // Itanium C++ ABI.
    Lane(const Lane&) = default;

    OSL_FORCEINLINE int value() const { return m_index; }

    OSL_FORCEINLINE
    operator int() const { return m_index; }
};
// Simple wrapper to identify an active lane.
// Active lanes will bypass mask testing during assignments to
// Masked::LaneProxy objects.  But be careful if you ever have two
// Masked::LaneProxy objects with different masks: a lane active in one
// is not necessarily active in the other.
class ActiveLane : public Lane {
public:
    explicit OSL_FORCEINLINE ActiveLane(int index) : Lane(index) {}

    ActiveLane() = delete;

    // Defaulted so ActiveLane is trivially copyable whenever Lane is.
    ActiveLane(const ActiveLane&) = default;
};
// Bit mask with one bit per lane for a batch of WidthT lanes: bit i of
// m_value corresponds to lane i.  Storage is a 32-bit unsigned integer for
// WidthT <= 32 and a 64-bit one otherwise; only the low WidthT bits
// (valid_bits) are meaningful.
//
// NOTE: the default constructor leaves the mask uninitialized.
template<int WidthT> class Mask {
    typedef unsigned short Value16Type;
    static_assert(sizeof(Value16Type) == 2, "unexpected platform");
    typedef unsigned int Value32Type;
    static_assert(sizeof(Value32Type) == 4, "unexpected platform");
    typedef unsigned long long Value64Type;
    static_assert(sizeof(Value64Type) == 8, "unexpected platform");

    typedef
        typename std::conditional<WidthT <= 32, Value32Type, Value64Type>::type
            Value32or64Type;

public:
#if 0  // Enable 16bit integer storage of masks, vs 32bit.
    typedef typename std::conditional<WidthT <= 16,
                                      Value16Type,
                                      Value32or64Type>::type ValueType;
#else
    typedef Value32or64Type ValueType;
#endif

    static constexpr int width = WidthT;

protected:
    static constexpr int value_width = sizeof(ValueType) * 8;
    static_assert(value_width >= WidthT, "unsupported WidthT");

    // The low WidthT bits set, all higher bits clear.
    static constexpr ValueType valid_bits
        = static_cast<ValueType>(0xFFFFFFFFFFFFFFFF) >> (value_width - WidthT);

    // Fixed-width unsigned type the same size as ValueType (at least 32 bits).
    // The non-C++20 OSL::popcount/countr_zero overloads take uint32_t or
    // uint64_t.  On LP64 platforms uint64_t is unsigned long, so passing the
    // unsigned long long storage of a 64-bit mask directly would be ambiguous.
    typedef typename std::conditional<(sizeof(ValueType) <= 4), uint32_t,
                                      uint64_t>::type BitsType;

    // Single-bit mask for 'lane'.  The shift is done in ValueType rather than
    // int, so lanes >= 31 are correct for 64-bit storage.  A plain int shift
    // is undefined for lanes >= 32 and sign-extends at lane 31.
    static OSL_FORCEINLINE ValueType lane_bit(int lane)
    {
        return static_cast<ValueType>(static_cast<ValueType>(1) << lane);
    }

public:
    OSL_FORCEINLINE Mask() {}

    // Mask with only 'lane' on.
    explicit OSL_FORCEINLINE Mask(Lane lane) : m_value(lane_bit(lane.value()))
    {
    }

    // All valid lanes on (true) or all off (false).
    explicit OSL_FORCEINLINE Mask(bool all_on_or_off)
        : m_value((all_on_or_off) ? valid_bits : 0)
    {
    }

    explicit constexpr OSL_FORCEINLINE Mask(std::false_type) : m_value(0) {}

    explicit constexpr OSL_FORCEINLINE Mask(std::true_type)
        : m_value(valid_bits)
    {
    }

    // Construct from raw bits (bit i == lane i); the value is stored as-is.
    explicit OSL_FORCEINLINE Mask(Value16Type value_)
        : m_value(static_cast<ValueType>(value_))
    {
    }

    explicit OSL_FORCEINLINE Mask(Value32Type value_)
        : m_value(static_cast<ValueType>(value_))
    {
    }

    explicit OSL_FORCEINLINE Mask(Value64Type value_)
        : m_value(static_cast<ValueType>(value_))
    {
    }

    explicit OSL_FORCEINLINE Mask(int value_)
        : m_value(static_cast<ValueType>(value_))
    {
    }

    OSL_FORCEINLINE Mask(const Mask& other) : m_value(other.m_value) {}

    // Widen a narrower mask; lanes beyond OtherWidthT are off.
    template<int OtherWidthT,
             typename = pvt::enable_if_type<(OtherWidthT < WidthT)>>
    explicit OSL_FORCEINLINE Mask(const Mask<OtherWidthT>& other)
        : m_value(static_cast<ValueType>(other.value()))
    {
    }

    OSL_FORCEINLINE ValueType value() const { return m_value; }

    // count number of active bits
    OSL_FORCEINLINE int count() const
    {
        return OSL::popcount(static_cast<BitsType>(m_value));
    }

    // Index of the lowest active lane.
    // NOTE: undefined result if no bits are on
    OSL_FORCEINLINE int first_on() const
    {
        return OSL::countr_zero(static_cast<BitsType>(m_value));
    }

    // Flip every valid lane; bits above WidthT are cleared.
    OSL_FORCEINLINE Mask invert() const
    {
        return Mask((~m_value) & valid_bits);
    }

    // Test only, don't allow assignment to force
    // more verbose set_on or set_off to be used
    // NOTE: As the actual lane value is embedded
    // inside an integral type, we would have to
    // return a proxy which could complicate
    // codegen, so keeping it simple(r)
    OSL_FORCEINLINE bool operator[](int lane) const
    {
        // From testing code generation, masking (rather than shifting
        // m_value down or comparing against the bit) is the preferred form
        return (m_value & lane_bit(lane));
    }

    OSL_FORCEINLINE bool is_on(int lane) const
    {
        // From testing code generation this is the preferred form
        return (m_value & lane_bit(lane));
    }

    OSL_FORCEINLINE bool is_off(int lane) const
    {
        // From testing code generation this is the preferred form
        return (m_value & lane_bit(lane)) == 0;
    }

    OSL_FORCEINLINE bool all_on() const
    {
        // TODO: is this more expensive than == ?
        return (m_value >= valid_bits);
    }

    OSL_FORCEINLINE bool all_off() const
    {
        return (m_value == static_cast<ValueType>(0));
    }

    OSL_FORCEINLINE bool any_on() const
    {
        return (m_value != static_cast<ValueType>(0));
    }

    OSL_FORCEINLINE bool any_off() const { return (m_value < valid_bits); }

    // True if any lane on in *this is off in 'mask'.
    OSL_FORCEINLINE bool any_off(const Mask& mask) const
    {
        return m_value != (m_value & mask.m_value);
    }

    // Setters
    // For SIMD loops, set_on and set_off work better
    // than a generic set(lane,flag).
    // And in most all cases, the starting state
    // for the mask was all on or all off,
    // So really only set_on or set_off is required
    // Choose to not provide a generic set(int lane, bool flag)
    OSL_FORCEINLINE void set_on(int lane) { m_value |= lane_bit(lane); }

    OSL_FORCEINLINE void set_on_if(int lane, bool cond)
    {
        m_value |= static_cast<ValueType>(static_cast<ValueType>(cond) << lane);
    }

    OSL_FORCEINLINE void set_all_on() { m_value = valid_bits; }

    // Turn on lanes [0, count) and turn off all others; count must be in
    // [0, width].
    OSL_FORCEINLINE void set_count_on(int count)
    {
        // When WidthT fills ValueType (e.g. Mask<32>), count == 0 would
        // shift valid_bits by its full bit width, which is undefined.
        // width == value_width is a compile-time constant, so narrower
        // masks keep the branch-free shift.
        if (width == value_width && count == 0)
            m_value = static_cast<ValueType>(0);
        else
            m_value = valid_bits >> (width - count);
    }

    OSL_FORCEINLINE void set_off(int lane)
    {
        m_value &= static_cast<ValueType>(~lane_bit(lane));
    }

    OSL_FORCEINLINE void set_off_if(int lane, bool cond)
    {
        m_value &= static_cast<ValueType>(
            ~(static_cast<ValueType>(cond) << lane));
    }

    OSL_FORCEINLINE void set_all_off() { m_value = static_cast<ValueType>(0); }

    OSL_FORCEINLINE bool operator==(const Mask& other) const
    {
        return m_value == other.m_value;
    }

    OSL_FORCEINLINE bool operator!=(const Mask& other) const
    {
        return m_value != other.m_value;
    }

    OSL_FORCEINLINE Mask& operator&=(const Mask& other)
    {
        m_value = m_value & other.m_value;
        return *this;
    }

    OSL_FORCEINLINE Mask& operator|=(const Mask& other)
    {
        m_value = m_value | other.m_value;
        return *this;
    }

    OSL_FORCEINLINE Mask operator&(const Mask& other) const
    {
        return Mask(m_value & other.m_value);
    }

    OSL_FORCEINLINE Mask operator|(const Mask& other) const
    {
        return Mask(m_value | other.m_value);
    }

    OSL_FORCEINLINE Mask operator~() const { return invert(); }

    // Serially apply functor f to each lane active in the Mask, lowest
    // lane first, passing an ActiveLane.
    //   MaxOccupancyT == 0: does nothing.
    //   MinOccupancyT == 0: an all-off mask is checked for and skipped.
    //   MinOccupancyT  > 0: the caller guarantees at least one lane is on
    //                       (checked with OSL_DASSERT).
    //   MaxOccupancyT == 1: only the lowest active lane is visited.
    template<int MinOccupancyT, int MaxOccupancyT = width, typename FunctorT>
    OSL_FORCEINLINE void foreach (FunctorT f) const
    {
        // Expect compile time dead code elimination to skip this when possible
        if (MaxOccupancyT == 0)
            return;
        // Expect compile time dead code elimination to skip this when possible
        if (MinOccupancyT == 0) {
            if (all_off())
                return;
        }
        OSL_DASSERT(any_on());
        // Expect compile time dead code elimination to emit
        // one branch or the other
        if (MaxOccupancyT == 1) {
            ActiveLane active_lane(first_on());
            f(active_lane);
        } else {
            Mask m(m_value);
            do {
                ActiveLane active_lane(m.first_on());
                f(active_lane);
                m.set_off(active_lane);
            } while (m.any_on());
        }
    }

    // Serially apply functor f to each
    // lane active in the Mask
    template<typename FunctorT> OSL_FORCEINLINE void foreach (FunctorT f) const
    {
        foreach
            <0, width, FunctorT>(f);
    }

    // non-inlined version to isolate inherently serial codegen
    // of functor f from other call site whose code gen might
    // be SIMD in nature. Not inlining prevents optimizer from
    // mixing serial code gen with call site which could inhibit
    // ability to generate SIMD code.
    template<int MinOccupancyT, int MaxOccupancyT = width, typename FunctorT>
    OSL_NOINLINE void invoke_foreach(FunctorT f) const;

    template<typename FunctorT>
    OSL_NOINLINE void invoke_foreach(FunctorT f) const;

    // Treat m_value as private, but access is needed for #pragma's to reference it in reduction clauses
    ValueType m_value;
};
// Out-of-line (OSL_NOINLINE) counterpart of foreach(f): visits every active
// lane with no occupancy hints, i.e. MinOccupancyT = 0 and
// MaxOccupancyT = WidthT.
template<int WidthT>
template<typename FunctorT>
void
Mask<WidthT>::invoke_foreach(FunctorT f) const
{
    this->template foreach<0, WidthT, FunctorT>(f);
}
// Out-of-line (OSL_NOINLINE) counterpart of foreach<Min, Max>(f): forwards
// the caller's occupancy hints unchanged.
template<int WidthT>
template<int MinOccupancyT, int MaxOccupancyT, typename FunctorT>
void
Mask<WidthT>::invoke_foreach(FunctorT f) const
{
    this->template foreach<MinOccupancyT, MaxOccupancyT, FunctorT>(f);
}
OSL_NAMESPACE_EXIT