/
wavefront_sync.jl
146 lines (120 loc) · 4.03 KB
/
wavefront_sync.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
# Mask check performed by `ballot_sync` and the `shfl_*_sync` wrappers before they
# run the underlying wavefront operation.
#
# On every pass, lanes that are not finished yet compare their own `mask` with the
# value returned by `readfirstlane(mask)`; lanes whose masks agree mark themselves
# finished. The loop repeats for as long as `wfany(__not(finished))` holds.
#
# NOTE(review): `wfany`, `__not` and `readfirstlane` are defined elsewhere —
# presumably a wavefront-wide "any" vote, an integer logical-not and a broadcast
# of the first active lane's value; confirm against their definitions.
@inline function __check_mask(mask::UInt64)
    finished::Cint = Cint(0)
    while wfany(__not(finished))
        # Only lanes that have not been matched yet take part in this round.
        if __not(finished) == Cint(1)
            if mask == readfirstlane(mask)
                # TODO
                # @rocassert(mask == ballot(true),
                #     "All threads specified in the mask must execute the same operation.")
                finished = Cint(1)
            end
        end
    end
    return nothing
end
"""
    ballot_sync(mask::UInt64, predicate::Bool)::UInt64

Collect a vote on `predicate` from the non-exited threads selected by `mask`.

The returned integer has its Nth bit set if and only if `predicate` is `true`
for the Nth thread of the wavefront and the Nth thread is active.
`mask` is checked with `__check_mask` first, and bits outside of `mask`
are cleared from the result of `ballot(predicate)`.

# Examples
```jldoctest
julia> function ker!(x)
           i = AMDGPU.Device.activelane()
           if i % 2 == 0
               mask = 0x0000000055555555 # Only even threads.
               x[1] = AMDGPU.Device.ballot_sync(mask, true)
           end
           return
       end
ker! (generic function with 1 method)

julia> x = ROCArray{UInt64}(undef, 1);

julia> @roc groupsize=32 ker!(x);

julia> bitstring(Array(x)[1])
"0000000000000000000000000000000001010101010101010101010101010101"
```
"""
function ballot_sync(mask::UInt64, predicate::Bool)::UInt64
    __check_mask(mask)
    votes = ballot(predicate)
    # Keep only the votes of lanes that were requested through `mask`.
    return votes & mask
end
"""
    any_sync(mask::UInt64, predicate::Bool)::Bool

Evaluate `predicate` for all non-exited threads in `mask` and return `true`
if and only if `predicate` is `true` for at least one of them.

Implemented as a check that [`ballot_sync`](@ref) returns a non-zero vote.

# Examples
```jldoctest
julia> function ker!(x)
           i = AMDGPU.Device.activelane()
           if i % 2 == 0
               mask = 0x0000000055555555 # Only even threads.
               x[1] = AMDGPU.Device.any_sync(mask, i == 0)
           end
           return
       end
ker! (generic function with 1 method)

julia> x = ROCArray{Bool}(undef, 1);

julia> @roc groupsize=32 ker!(x);

julia> x
1-element ROCArray{Bool, 1, AMDGPU.Runtime.Mem.HIPBuffer}:
 1
```
"""
function any_sync(mask::UInt64, predicate::Bool)::Bool
    votes = ballot_sync(mask, predicate)
    return votes != 0
end
"""
    all_sync(mask::UInt64, predicate::Bool)::Bool

Evaluate `predicate` for all non-exited threads in `mask` and return `true`
if and only if `predicate` is `true` for all of them.

Implemented as a check that the vote returned by [`ballot_sync`](@ref)
is equal to `mask` itself.

# Examples
```jldoctest
julia> function ker!(x)
           i = AMDGPU.Device.activelane()
           if i % 2 == 0
               mask = 0x0000000055555555 # Only even threads.
               x[1] = AMDGPU.Device.all_sync(mask, true)
           end
           return
       end
ker! (generic function with 1 method)

julia> x = ROCArray{Bool}(undef, 1);

julia> @roc groupsize=32 ker!(x);

julia> x
1-element ROCArray{Bool, 1, AMDGPU.Runtime.Mem.HIPBuffer}:
 1
```
"""
function all_sync(mask::UInt64, predicate::Bool)::Bool
    votes = ballot_sync(mask, predicate)
    # Every lane named in `mask` must have voted `true`.
    return votes == mask
end
"""
    shfl_sync(mask::UInt64, val, lane, width = wavefrontsize())

Synchronize threads according to a `mask` and
read data stored in `val` from a `lane` ID.

`mask` is only checked with `__check_mask`; the exchange itself is
delegated to `shfl(val, lane, width)`.
"""
function shfl_sync(mask::UInt64, val, lane, width = wavefrontsize())
    __check_mask(mask)
    exchanged = shfl(val, lane, width)
    return exchanged
end
"""
    shfl_up_sync(mask::UInt64, val, δ, width = wavefrontsize())

Synchronize threads according to a `mask` and
read data stored in `val` from a lane with lower ID relative to the caller.

`mask` is only checked with `__check_mask`; the exchange itself is
delegated to `shfl_up(val, δ, width)`.
"""
function shfl_up_sync(mask::UInt64, val, δ, width = wavefrontsize())
    __check_mask(mask)
    exchanged = shfl_up(val, δ, width)
    return exchanged
end
"""
    shfl_down_sync(mask::UInt64, val, δ, width = wavefrontsize())

Synchronize threads according to a `mask` and
read data stored in `val` from a lane with higher ID relative to the caller.

`mask` is only checked with `__check_mask`; the exchange itself is
delegated to `shfl_down(val, δ, width)`.
"""
function shfl_down_sync(mask::UInt64, val, δ, width = wavefrontsize())
    __check_mask(mask)
    exchanged = shfl_down(val, δ, width)
    return exchanged
end
"""
    shfl_xor_sync(mask::UInt64, val, lane_mask, width = wavefrontsize())

Synchronize threads according to a `mask` and
read data stored in `val` from a lane according to a bitwise XOR
of the caller's lane ID with the `lane_mask`.

`mask` is only checked with `__check_mask`; the exchange itself is
delegated to `shfl_xor(val, lane_mask, width)`.
"""
function shfl_xor_sync(mask::UInt64, val, lane_mask, width = wavefrontsize())
    __check_mask(mask)
    exchanged = shfl_xor(val, lane_mask, width)
    return exchanged
end