forked from tshort/StaticCompiler.jl
/
ccalls.jl
90 lines (82 loc) · 3.71 KB
/
ccalls.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
"""
    find_ccalls(f, tt)
    find_ccalls(ref::Reflection)

Collect the foreign symbols used by the method of `f` for argument types `tt`
(or by an already-reflected method `ref`).

Returns a `Dict{Ptr{Nothing},Symbol}` mapping each resolved address to its
symbol name. Three kinds of statements in the method's code are examined, in
this order:

- `:foreigncall` expressions (`ccall`s), except those whose calling convention
  is `:llvmcall`;
- calls to `cglobal`;
- method invocations, which are reflected on and searched recursively; their
  results are merged into the returned `Dict`.

Addresses are obtained by `eval`-uating `cglobal(sym)` in this module.
"""
find_ccalls(@nospecialize(f), @nospecialize(tt)) = find_ccalls(reflect(f, tt))
function find_ccalls(ref::Reflection)
    found = Dict{Ptr{Nothing}, Symbol}()
    stmts = ref.CI.code

    # Resolve `sym` (as normalised by `getsym`: a `QuoteNode` or a
    # `(name, library)` tuple) to an address and record its symbol name.
    function record!(sym)
        addr = eval(:(cglobal($(sym))))
        found[addr] = Symbol(sym isa Tuple ? first(sym) : sym.value)
        return nothing
    end

    # Argument slot checked for `:llvmcall` in a `:foreigncall`; the code
    # assumes it is 5 after Julia 1.2 and 4 before.
    cconv_slot = VERSION > v"1.2" ? 5 : 4
    is_native_ccall(ex) = ex.head === :foreigncall &&
        !(ex.args[cconv_slot] isa QuoteNode && ex.args[cconv_slot].value == :llvmcall)
    is_cglobal_call(ex) = ex.head === :call && iscglobal(ex.args[1])

    # NOTE(review): matching entries are indexed with `[2]` to reach the call
    # expression; this relies on the statement shape `lookthrough` accepts — confirm.
    for stmt in filter(s -> lookthrough(is_native_ccall, s), stmts)
        record!(getsym(stmt[2].args[1]))
    end
    for stmt in filter(s -> lookthrough(is_cglobal_call, s), stmts)
        record!(getsym(stmt[2].args[2]))
    end

    # Resolve every invocation first, then recurse into the reflectable ones.
    invoke_stmts = filter(s -> lookthrough(identify_invoke, s), stmts)
    callees = [process_invoke(DefaultConsumer(), ref, args...) for args in invoke_stmts]
    for callee in callees
        if canreflect(callee)
            merge!(found, find_ccalls(reflect(callee)))
        end
    end
    return found
end
"""
    getsym(x)

Normalise the symbol argument of a `ccall`/`cglobal` expression into a value
that can be spliced into `cglobal(...)`:

- a `String` becomes a `QuoteNode` wrapping the corresponding `Symbol`;
- a `QuoteNode` is returned unchanged;
- an `Expr` is treated as a `(name, library)` call: its 2nd and 3rd arguments
  are evaluated (in this module) and returned as a 2-tuple;
- any other value is returned unchanged.
"""
getsym(x) = x
getsym(x::QuoteNode) = x
getsym(x::String) = QuoteNode(Symbol(x))
getsym(x::Expr) = (eval(x.args[2]), eval(x.args[3]))
"""
    iscglobal(x)

Return `true` if `x` is the `cglobal` intrinsic itself, or a `GlobalRef`
whose name is `:cglobal`; otherwise `false`.
"""
function iscglobal(x)
    x == cglobal && return true
    return x isa GlobalRef && x.name == :cglobal
end
"""
    fix_ccalls!(mod::LLVM.Module, d)

Replace hard-coded addresses in `mod` with references to named external
symbols, so that they can be resolved at link time (e.g. against `libjulia` or
other libraries).

`d` maps an address (`Ptr{Cvoid}`) to the symbol name that should replace it,
e.g. the `Dict` returned by `find_ccalls`.

Two patterns are rewritten:

- a call whose callee is a constant expression containing an `inttoptr` of a
  known address: every use of that constant is replaced by an externally linked
  function declaration named after the symbol;
- a load whose pointer operand contains an `inttoptr` of a known address
  (optionally behind one `addrspacecast` or `getelementptr`): every use of the
  `inttoptr` is replaced by an externally linked global variable named after
  the symbol, typed as the load's result type.

Addresses missing from `d`, and `inttoptr` expressions whose operand is not a
constant integer, are left untouched.
"""
function fix_ccalls!(mod::LLVM.Module, d)
    for fun in functions(mod), blk in blocks(fun), instr in instructions(blk)
        if instr isa LLVM.CallInst
            dest = called_value(instr)
            if dest isa ConstantExpr && occursin("inttoptr", string(dest))
                addr = first(operands(dest))
                # A callee such as `bitcast (inttoptr ...)` has a non-integer first
                # operand; `convert(Int, ...)` would throw, so skip it (as the load
                # branch below already does).
                addr isa LLVM.ConstantInt || continue
                ptr = Ptr{Cvoid}(convert(Int, addr))
                if haskey(d, ptr)
                    # The last call operand is treated as the callee pointer; its
                    # function type's parameter count gives the number of leading
                    # operands that are actual arguments.
                    argtypes = [llvmtype(op) for op in operands(instr)]
                    nargs = length(parameters(eltype(argtypes[end])))
                    # NOTE(review): the declaration is built non-variadic.
                    newdest = LLVM.Function(mod, string(d[ptr]),
                                            LLVM.FunctionType(llvmtype(instr), argtypes[1:nargs]))
                    LLVM.linkage!(newdest, LLVM.API.LLVMExternalLinkage)
                    replace_uses!(dest, newdest)
                end
            end
        elseif instr isa LLVM.LoadInst && occursin("inttoptr", string(instr))
            for op in operands(instr)
                opstr = string(op)
                occursin("inttoptr", opstr) || continue
                # Look through one address-space cast or GEP to reach the `inttoptr`.
                target = if occursin("addrspacecast", opstr) || occursin("getelementptr", opstr)
                    first(operands(op))
                else
                    op
                end
                addr = first(operands(target))
                addr isa LLVM.ConstantInt || continue
                ptr = Ptr{Cvoid}(convert(Int, addr))
                if haskey(d, ptr)
                    newdest = GlobalVariable(mod, llvmtype(instr), string(d[ptr]))
                    LLVM.linkage!(newdest, LLVM.API.LLVMExternalLinkage)
                    replace_uses!(target, newdest)
                end
            end
        end
    end
end