-
Notifications
You must be signed in to change notification settings - Fork 29
/
wrap_cuda.jl
70 lines (62 loc) · 2.23 KB
/
wrap_cuda.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
using Clang
# Host-specific paths: `includes` usually needs editing for the local system and
# compiler version; the CUDA header directory can instead be overridden by setting
# the CUDA_INCLUDE_DIR environment variable (defaults to the CUDA 6.5 location).
# System/compiler include directories handed to clang via `clang_includes`.
includes = ["/usr/include",
            "/usr/lib/gcc/x86_64-redhat-linux/4.4.4/include",
            "/usr/lib/gcc/x86_64-linux-gnu/4.8/include",
            "/usr/lib/gcc/x86_64-linux-gnu/4.8/include-fixed"]
# Directory that contains the CUDA runtime headers listed below.
cudapath = get(ENV, "CUDA_INCLUDE_DIR", "/usr/local/cuda-6.5/include")
# Absolute paths of the CUDA headers to wrap.
headers = [joinpath(cudapath, h)
           for h in ["cuda_runtime_api.h", "driver_types.h", "vector_types.h"]]
# Customize how functions, constants, and structs are written
# Generated expressions that `rewriter` drops from the output entirely.
const skip_expr = [:(const CUDART_DEVICE = __device__)]
# cudaError_t-returning functions that `rewriter` does NOT wrap in `checkerror`;
# NOTE(review): presumably because a non-success code from these is an expected
# result rather than a failure — confirm.
const skip_error_check = [:cudaStreamQuery, :cudaGetLastError, :cudaPeekAtLastError]
# rewriter(ex::Expr)
#
# Post-process one expression produced by the Clang wrapper.
#  - Expressions listed in `skip_expr` are replaced by the empty expression `:()`.
#  - A `type` definition whose field block is empty becomes `const Name = Void`.
#  - For `function` definitions, argument type annotations are stripped from the
#    prototype. If the first body statement is a `ccall` whose return type is
#    `cudaError_t` and the function name is not in `skip_error_check`, that call
#    is wrapped as `checkerror(ccall(...))`.
#  - Every other expression is returned unchanged.
# Function definitions are modified in place and then returned.
# Raises an error if a function body does not begin with a `ccall` expression.
function rewriter(ex::Expr)
    if in(ex, skip_expr)
        return :()
    end
    # Empty types get converted to Void
    if ex.head == :type
        a3 = ex.args[3]
        if isempty(a3.args)
            objname = ex.args[2]
            return :(const $objname = Void)
        end
    end
    ex.head == :function || return ex
    decl, body = ex.args[1], ex.args[2]
    # Omit types from function prototypes. An untyped argument is a bare Symbol
    # (no `head` field), so check the node kind before inspecting it.
    for i = 2:length(decl.args)
        a = decl.args[i]
        if isa(a, Expr) && a.head == :(::)
            decl.args[i] = a.args[1]
        end
    end
    # Error-check functions that return a cudaError_t (with some omissions).
    # An empty body or a non-Expr first statement is reported with the same
    # message instead of failing on a missing index/field.
    ccallexpr = isempty(body.args) ? nothing : body.args[1]
    if !(isa(ccallexpr, Expr) && ccallexpr.head == :ccall)
        error("Unexpected body expression: ", body)
    end
    rettype = ccallexpr.args[2]
    if rettype == :cudaError_t
        fname = decl.args[1]
        if !in(fname, skip_error_check)
            # The original call node is replaced in the body, so it can be
            # wrapped directly without copying it first.
            body.args[1] = Expr(:call, :checkerror, ccallexpr)
        end
    end
    return ex
end
# Fallback `rewriter` methods covering the non-Expr values the wrapper may hand over.
# An Array is rewritten element by element into a new array of the same shape.
function rewriter(A::Array)
    return [rewriter(elem) for elem in A]
end
# A bare Symbol is emitted as its name string.
function rewriter(s::Symbol)
    return string(s)
end
# Anything else passes through untouched.
function rewriter(arg)
    return arg
end
# Build the Clang.jl wrapping context for the CUDA headers listed in `headers`,
# parsed with the `includes` search paths (clang diagnostics enabled). Generated
# code is written to gen_libcudart.jl (output_file) and gen_libcudart_h.jl
# (common_file); every header is mapped to the library name "libcudart", and
# `rewriter` is supplied as the expression post-processing hook.
# The header_wrapped predicate always answers true, so no header is filtered out;
# to restrict wrapping to CUDA headers use
#     header_wrapped = (x, y) -> contains(x, "cuda")
context = wrap_c.init(headers = headers,
                      clang_includes = includes,
                      clang_diagnostics = true,
                      output_file = "gen_libcudart.jl",
                      common_file = "gen_libcudart_h.jl",
                      header_library = hdr -> "libcudart",
                      header_wrapped = (hdr, cur) -> true,
                      rewriter = rewriter)
# Enable struct wrapping as well as functions/constants.
context.options = wrap_c.InternalOptions(true, true)
# Generate the wrapper files.
run(context)