/
utils.jl
86 lines (69 loc) · 2.16 KB
/
utils.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
# Helpers for writing device functionality.

# `Literal{T}` is a field-less marker type. Multiplying a number by the *type*
# `Literal{T}` converts that number to `T`, so with `i32 = Literal{Int32}` a juxtaposed
# literal such as `1i32` (parsed as `1 * i32`) evaluates to `Int32(1)`.
# TODO: upstream this
struct Literal{T} end

function Base.:(*)(x::Number, ::Type{Literal{T}}) where {T}
    return T(x)
end

const i32 = Literal{Int32}
# Local method table for device functions. It is only created when the running Julia
# provides `Base.Experimental.@overlay`; otherwise `method_table` is bound to `nothing`.
@static if !isdefined(Base.Experimental, Symbol("@overlay"))
    const method_table = nothing
else
    Base.Experimental.@MethodTable method_table
end
"""
    @device_override def

Expand all macros in the method definition `def` (in the caller's module), then define the
result through `Base.Experimental.@overlay` with `CUDA.method_table` as the target method
table. The generated code is escaped, so it is evaluated in the caller's scope.
"""
macro device_override(ex)
    # expand wrapper macros (e.g. `@inline`) before handing the definition to `@overlay`
    def = macroexpand(__module__, ex)
    code = Expr(:block, :(Base.Experimental.@overlay(CUDA.method_table, $def)))
    return esc(code)
end
"""
    @device_function def

Emit two versions of the function definition `def`:

- a host method with the same signature whose body unconditionally raises
  `error("This function is not intended for use on the CPU")`, and
- the original (macro-expanded) definition, passed to `@device_override`.
"""
macro device_function(ex)
    # expand wrapper macros first so `splitdef` sees the bare definition
    ex = macroexpand(__module__, ex)

    # host stub: keep the signature, swap the body for an unconditional error
    stub = splitdef(ex)
    stub[:body] = quote
        error("This function is not intended for use on the CPU")
    end
    host_def = combinedef(stub)

    return esc(quote
        $host_def
        @device_override $ex
    end)
end
# Return `true` if `sig` is the signature part of a method definition: a call expression,
# possibly wrapped in a return-type annotation (`f(x)::T`) and/or `where` clauses.
# A bare name or a typed variable (`x::Int`) is not a method signature.
_is_method_signature(sig) =
    Meta.isexpr(sig, :call) ||
    (Meta.isexpr(sig, (:where, :(::))) && _is_method_signature(sig.args[1]))

# Return `true` if `ex` is a method definition with a body, in either the long form
# (`function f(x) ... end`) or the short form (`f(x) = ...`). Plain assignments and
# body-less declarations such as `function f end` do not qualify.
_is_method_definition(ex) =
    Meta.isexpr(ex, (:function, :(=)), 2) && _is_method_signature(ex.args[1])

# Rebuild `block`, prefixing each method definition with `@device_function` and
# descending into nested `begin ... end` blocks. Everything else (line-number nodes,
# constants, assignments, ...) is kept unchanged.
function _prepend_device_function(block)
    out = Expr(:block)
    for arg in block.args
        if Meta.isexpr(arg, :block)
            push!(out.args, _prepend_device_function(arg))
        elseif _is_method_definition(arg)
            push!(out.args, :(@device_function $arg))
        else
            push!(out.args, arg)
        end
    end
    return out
end

"""
    @device_functions block

Prefix every method definition in `block` with `@device_function`. This covers the long
form (`function f(...) ... end`) and the short form (`f(...) = ...`), including
definitions inside nested `begin ... end` blocks. All other expressions are passed through
unchanged, for example:

- line-number nodes and constants,
- plain or typed assignments,
- body-less `function f end` declarations.

A single definition that is not wrapped in a block is also accepted.
"""
macro device_functions(ex)
    # expand user macros (e.g. `@inline`) so definitions are visible as plain Exprs
    ex = macroexpand(__module__, ex)
    # a lone definition would otherwise be walked as if its own args were statements
    block = Meta.isexpr(ex, :block) ? ex : Expr(:block, ex)
    return esc(_prepend_device_function(block))
end
## alignment API

# NOTE: the type is `Aligned{T, N}` rather than `Aligned{N}`. Keeping the wrapped type `T`
# as the first parameter lets function signatures be written as `::Aligned{<:T}`.
struct Aligned{T, N}
    data::T
end

# The alignment is carried purely in the type parameter `N`.
function alignment(::Aligned{<:Any, N}) where {N}
    return N
end

# `x[]` unwraps the stored object.
Base.getindex(x::Aligned) = getfield(x, :data)
"""
CUDA.align{N}(obj)
Construct an aligned object, providing alignment information to APIs that require it.
"""
struct align{N} end
(::Type{align{N}})(data::T) where {T,N} = Aligned{T,N}(data)
# Default alignment for common types:
# - an object that is already `Aligned` is returned unchanged;
# - a raw pointer is wrapped with the alignment of its element type, as reported by
#   `Base.datatype_alignment`.
Aligned(x::Aligned) = x

function Aligned(x::Ptr{T}) where {T}
    N = Base.datatype_alignment(T)
    return align{N}(x)
end

function Aligned(x::LLVMPtr{T}) where {T}
    N = Base.datatype_alignment(T)
    return align{N}(x)
end