/
indexing.jl
140 lines (108 loc) · 3.97 KB
/
indexing.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
# Indexing and dimensions (B.4)
export
threadIdx, blockDim, blockIdx, gridDim,
laneid, lanemask, warpsize, active_mask, FULL_MASK
# Read a PTX special register through the `llvm.nvvm.read.ptx.sreg.<name>` intrinsic
# and annotate the result with LLVM `!range` metadata derived from `range`.
#
# Arguments (both lifted to the type domain with `Val`):
# - `name`:  suffix of the intrinsic to call, e.g. `Symbol("tid.x")`.
# - `range`: *inclusive* range (`start:stop`) of values the register may hold.
#
# Returns the value produced by the intrinsic as an `Int32`.
@generated function _index(::Val{name}, ::Val{range}) where {name, range}
    @dispose ctx=Context() begin
        T_int32 = LLVM.Int32Type()

        # create function
        llvm_f, _ = create_function(T_int32)
        mod = LLVM.parent(llvm_f)

        # generate IR
        @dispose builder=IRBuilder() begin
            entry = BasicBlock(llvm_f, "entry")
            position!(builder, entry)

            # call the indexing intrinsic
            intr_typ = LLVM.FunctionType(T_int32)
            intr = LLVM.Function(mod, "llvm.nvvm.read.ptx.sreg.$name", intr_typ)
            idx = call!(builder, intr_typ, intr)

            # attach range metadata.
            #
            # LLVM's `!range` describes a half-open interval `[lo, hi)`, whereas `range`
            # is inclusive on both ends, so the upper bound has to be `stop + 1`; using
            # `stop` directly would tell the optimizer that the largest legal value can
            # never occur. The bound is computed as an `Int64` so that
            # `stop == typemax(Int32)` does not overflow; the constant is then
            # materialized as an i32, yielding a wrapped bound, which `!range` allows.
            lo = ConstantInt(T_int32, Int64(range.start))
            hi = ConstantInt(T_int32, Int64(range.stop) + 1)
            metadata(idx)[LLVM.MD_range] = MDNode([lo, hi])

            ret!(builder, idx)
        end

        call_function(llvm_f, Int32)
    end
end
# Per-dimension upper limits used to bound the values of the indexing intrinsics.
#
# XXX: these depend on the compute capability, see
#      https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#compute-capabilities

# largest block extent (number of threads) along each dimension
const max_block_size = (x = 1024, y = 1024, z = 64)

# largest grid extent (number of blocks) along each dimension
const max_grid_size = (x = 2^31 - 1, y = 65535, z = 65535)
# Define the per-dimension accessors `threadIdx_*`, `blockIdx_*`, `blockDim_*` and
# `gridDim_*` for `x`, `y` and `z`. Each one reads the matching special register via
# `_index`, passing the inclusive range of values that register may hold.
for dim in (:x, :y, :z)
    block_limit = max_block_size[dim]
    grid_limit = max_grid_size[dim]

    # Indices: the register counts from zero, the accessor adds one.
    #   threadIdx_* -> tid.*    (thread index)
    #   blockIdx_*  -> ctaid.*  (block index)
    for (prefix, sreg, limit) in ((:threadIdx, "tid", block_limit),
                                  (:blockIdx, "ctaid", grid_limit))
        fn = Symbol(prefix, :_, dim)
        intr = Symbol(sreg, '.', dim)
        @eval @inline $fn() = _index($(Val(intr)), $(Val(0:limit-1))) + 1i32
    end

    # Extents: the register value is returned unchanged.
    #   blockDim_* -> ntid.*    (#threads per block)
    #   gridDim_*  -> nctaid.*  (#blocks per grid)
    for (prefix, sreg, limit) in ((:blockDim, "ntid", block_limit),
                                  (:gridDim, "nctaid", grid_limit))
        fn = Symbol(prefix, :_, dim)
        intr = Symbol(sreg, '.', dim)
        @eval @inline $fn() = _index($(Val(intr)), $(Val(1:limit)))
    end
end
@device_functions begin
"""
    gridDim()::NamedTuple

Return the dimensions of the grid as a named tuple with fields `x`, `y` and `z`,
taken from `gridDim_x()`, `gridDim_y()` and `gridDim_z()`.
"""
@inline function gridDim()
    return (x=gridDim_x(), y=gridDim_y(), z=gridDim_z())
end
"""
    blockIdx()::NamedTuple

Return the index of the block within the grid as a named tuple with fields `x`, `y`
and `z`, taken from `blockIdx_x()`, `blockIdx_y()` and `blockIdx_z()`.
"""
@inline function blockIdx()
    return (x=blockIdx_x(), y=blockIdx_y(), z=blockIdx_z())
end
"""
    blockDim()::NamedTuple

Return the dimensions of the block as a named tuple with fields `x`, `y` and `z`,
taken from `blockDim_x()`, `blockDim_y()` and `blockDim_z()`.
"""
@inline function blockDim()
    return (x=blockDim_x(), y=blockDim_y(), z=blockDim_z())
end
"""
    threadIdx()::NamedTuple

Return the index of the thread within its block as a named tuple with fields `x`, `y`
and `z`, taken from `threadIdx_x()`, `threadIdx_y()` and `threadIdx_z()`.
"""
@inline function threadIdx()
    return (x=threadIdx_x(), y=threadIdx_y(), z=threadIdx_z())
end
"""
    warpsize()::Int32

Return the warp size (in threads), as read from the `warpsize` special register.
"""
@inline function warpsize()
    return ccall("llvm.nvvm.read.ptx.sreg.warpsize", llvmcall, Int32, ())
end
"""
    laneid()::Int32

Return the calling thread's lane within the warp. The value read from the `laneid`
special register is incremented by one before it is returned.
"""
@inline function laneid()
    lane = ccall("llvm.nvvm.read.ptx.sreg.laneid", llvmcall, Int32, ())
    return lane + 1i32
end
"""
    lanemask(pred)::UInt32

Return a 32-bit mask indicating which threads in a warp satisfy the given predicate.

The predicate selects which `lanemask` special register is read:

| `pred` | register      |
|:-------|:--------------|
| `==`   | `lanemask.eq` |
| `<`    | `lanemask.lt` |
| `<=`   | `lanemask.le` |
| `>=`   | `lanemask.ge` |
| `>`    | `lanemask.gt` |

Any other argument throws an `ArgumentError`.
"""
@inline function lanemask(pred::F) where F
    # the intrinsic name has to be a literal, hence one explicit call per predicate
    pred === Base.:(==) &&
        return ccall("llvm.nvvm.read.ptx.sreg.lanemask.eq", llvmcall, UInt32, ())
    pred === Base.:(<) &&
        return ccall("llvm.nvvm.read.ptx.sreg.lanemask.lt", llvmcall, UInt32, ())
    pred === Base.:(<=) &&
        return ccall("llvm.nvvm.read.ptx.sreg.lanemask.le", llvmcall, UInt32, ())
    pred === Base.:(>=) &&
        return ccall("llvm.nvvm.read.ptx.sreg.lanemask.ge", llvmcall, UInt32, ())
    pred === Base.:(>) &&
        return ccall("llvm.nvvm.read.ptx.sreg.lanemask.gt", llvmcall, UInt32, ())

    throw(ArgumentError("invalid lanemask function"))
end
"""
    active_mask()::UInt32

Return a 32-bit mask indicating which threads in a warp are active together with the
currently executing thread, obtained with the PTX `activemask.b32` instruction.
"""
@inline function active_mask()
    return @asmcall("activemask.b32 \$0;", "=r", false, UInt32, Tuple{})
end
end
# 32-bit mask with every bit set (0xffffffff)
const FULL_MASK = typemax(UInt32)