diff --git a/examples/jit.jl b/examples/jit.jl index 8a70a543..77062e36 100644 --- a/examples/jit.jl +++ b/examples/jit.jl @@ -19,7 +19,7 @@ module TestRuntime end struct TestCompilerParams <: AbstractCompilerParams end -GPUCompiler.runtime_module(::CompilerJob{<:Any,TestCompilerParams}) = TestRuntime +GPUCompiler.runtime_module(::CompilerJob{<:Any, TestCompilerParams}) = TestRuntime ## JIT integration @@ -58,8 +58,8 @@ const jit = Ref{CompilerInstance}() function get_trampoline(job) compiler = jit[] lljit = compiler.jit - lctm = compiler.lctm - ism = compiler.ism + lctm = compiler.lctm + ism = compiler.ism # We could also use one dylib per job jd = JITDylib(lljit) @@ -68,11 +68,14 @@ function get_trampoline(job) target_sym = String(gensym(:target)) flags = LLVM.API.LLVMJITSymbolFlags( LLVM.API.LLVMJITSymbolGenericFlagsCallable | - LLVM.API.LLVMJITSymbolGenericFlagsExported, 0) + LLVM.API.LLVMJITSymbolGenericFlagsExported, 0 + ) entry = LLVM.API.LLVMOrcCSymbolAliasMapPair( mangle(lljit, entry_sym), LLVM.API.LLVMOrcCSymbolAliasMapEntry( - mangle(lljit, target_sym), flags)) + mangle(lljit, target_sym), flags + ) + ) mu = LLVM.reexports(lctm, ism, jd, Ref(entry)) LLVM.define(jd, mu) @@ -85,7 +88,7 @@ function get_trampoline(job) function materialize(mr) buf = JuliaContext() do ctx - ir, meta = GPUCompiler.compile(:llvm, job; validate=false) + ir, meta = GPUCompiler.compile(:llvm, job; validate = false) # Rename entry to match target_sym LLVM.name!(meta.entry, target_sym) @@ -117,14 +120,14 @@ function get_trampoline(job) end import GPUCompiler: deferred_codegen_jobs -@generated function deferred_codegen(f::F, ::Val{tt}, ::Val{world}) where {F,tt,world} +@generated function deferred_codegen(f::F, ::Val{tt}, ::Val{world}) where {F, tt, world} # manual version of native_job because we have a function type source = methodinstance(F, Base.to_tuple_type(tt), world) - target = NativeCompilerTarget(; jlruntime=true, llvm_always_inline=true) + target = NativeCompilerTarget(; jlruntime = true, llvm_always_inline = true) # XXX: do we actually require the Julia runtime? # with jlruntime=false, we reach an unreachable. params = TestCompilerParams() - config = CompilerConfig(target, params; kernel=false) + config = CompilerConfig(target, params; kernel = false) job = CompilerJob(source, config, world) # XXX: invoking GPUCompiler from a generated function is not allowed! # for things to work, we need to forward the correct world, at least. @@ -135,7 +138,7 @@ import GPUCompiler: deferred_codegen_jobs deferred_codegen_jobs[id] = job - quote + return quote ptr = ccall("extern deferred_codegen", llvmcall, Ptr{Cvoid}, (Ptr{Cvoid},), $trampoline) assume(ptr != C_NULL) return ptr @@ -143,8 +146,8 @@ import GPUCompiler: deferred_codegen_jobs end @generated function abi_call(f::Ptr{Cvoid}, rt::Type{RT}, tt::Type{T}, func::F, args::Vararg{Any, N}) where {T, RT, F, N} - argtt = tt.parameters[1] - rettype = rt.parameters[1] + argtt = tt.parameters[1] + rettype = rt.parameters[1] argtypes = DataType[argtt.parameters...] argexprs = Union{Expr, Symbol}[] @@ -199,7 +202,7 @@ end if GPUCompiler.isghosttype(rettype) || Core.Compiler.isconstType(rettype) # Do nothing... # In theory we could set `rettype` to `T_void`, but ccall will do that for us - # elseif jl_is_uniontype? + # elseif jl_is_uniontype? elseif !GPUCompiler.deserves_retbox(rettype) rt = convert(LLVMType, rettype) if !isa(rt, LLVM.VoidType) && GPUCompiler.deserves_sret(rettype, rt) @@ -214,26 +217,26 @@ end end end - quote + return quote $before ret = ccall(f, $rettype, ($(ccall_types...),), $(argexprs...)) $after end end -@inline function call_delayed(f::F, args...) where F +@inline function call_delayed(f::F, args...) where {F} tt = Tuple{map(Core.Typeof, args)...} rt = Core.Compiler.return_type(f, tt) world = GPUCompiler.tls_world_age() ptr = deferred_codegen(f, Val(tt), Val(world)) - abi_call(ptr, rt, tt, f, args...) + return abi_call(ptr, rt, tt, f, args...) end optlevel = LLVM.API.LLVMCodeGenLevelDefault -tm = GPUCompiler.JITTargetMachine(optlevel=optlevel) +tm = GPUCompiler.JITTargetMachine(optlevel = optlevel) LLVM.asm_verbosity!(tm, true) -lljit = LLJIT(;tm) +lljit = LLJIT(; tm) jd_main = JITDylib(lljit) @@ -267,36 +270,36 @@ using Test f(A) = (A[] += 42; nothing) global flag = [0] function caller() - call_delayed(f, flag::Vector{Int}) + return call_delayed(f, flag::Vector{Int}) end @test caller() === nothing @test flag[] == 42 # test that we can call a function with a return value -add(x, y) = x+y +add(x, y) = x + y function call_add(x, y) - call_delayed(add, x, y) + return call_delayed(add, x, y) end @test call_add(1, 3) == 4 incr(r) = r[] += 1 function call_incr(r) - call_delayed(incr, r) + return call_delayed(incr, r) end r = Ref{Int}(0) @test call_incr(r) == 1 @test r[] == 1 function call_real(c) - call_delayed(real, c) + return call_delayed(real, c) end -@test call_real(1.0+im) == 1.0 +@test call_real(1.0 + im) == 1.0 # tests struct return if Sys.ARCH != :aarch64 - @test call_delayed(complex, 1.0, 2.0) == 1.0+2.0im + @test call_delayed(complex, 1.0, 2.0) == 1.0 + 2.0im else - @test_broken call_delayed(complex, 1.0, 2.0) == 1.0+2.0im + @test_broken call_delayed(complex, 1.0, 2.0) == 1.0 + 2.0im end throws(arr, i) = arr[i] @@ -306,11 +309,11 @@ throws(arr, i) = arr[i] struct Closure x::Int64 end -(c::Closure)(b) = c.x+b +(c::Closure)(b) = c.x + b @test call_delayed(Closure(3), 5) == 8 struct Closure2 x::Integer end -(c::Closure2)(b) = c.x+b +(c::Closure2)(b) = c.x + b @test call_delayed(Closure2(3), 5) == 8 diff --git a/examples/kernel.jl b/examples/kernel.jl index e54982a9..be740e67 100644 --- a/examples/kernel.jl +++ b/examples/kernel.jl @@ -11,7 +11,7 @@ module TestRuntime end struct TestCompilerParams <: AbstractCompilerParams end -GPUCompiler.runtime_module(::CompilerJob{<:Any,TestCompilerParams}) = TestRuntime +GPUCompiler.runtime_module(::CompilerJob{<:Any, TestCompilerParams}) = TestRuntime kernel() = nothing @@ -26,7 +26,7 @@ function main() GPUCompiler.compile(:asm, job) end - println(output[1]) + return println(output[1]) end isinteractive() || main() diff --git a/src/GPUCompiler.jl b/src/GPUCompiler.jl index ec33e9a7..8f2cf75e 100644 --- a/src/GPUCompiler.jl +++ b/src/GPUCompiler.jl @@ -66,7 +66,7 @@ function __init__() global compile_cache = dir Tracy.@register_tracepoints() - register_deferred_codegen() + return register_deferred_codegen() end end # module diff --git a/src/bpf.jl b/src/bpf.jl index 80070d6d..84976f84 100644 --- a/src/bpf.jl +++ b/src/bpf.jl @@ -5,7 +5,7 @@ export BPFCompilerTarget Base.@kwdef struct BPFCompilerTarget <: AbstractCompilerTarget - function_pointers::UnitRange{Int}=1:1000 # set of valid function "pointers" + function_pointers::UnitRange{Int} = 1:1000 # set of valid function "pointers" end llvm_triple(::BPFCompilerTarget) = "bpf-bpf-bpf" @@ -13,7 +13,7 @@ llvm_datalayout(::BPFCompilerTarget) = "e-m:e-p:64:64-i64:64-n32:64-S128" function llvm_machine(target::BPFCompilerTarget) triple = llvm_triple(target) - t = Target(;triple=triple) + t = Target(; triple = triple) cpu = "" feat = "" diff --git a/src/debug.jl b/src/debug.jl index e47cf594..087fd387 100644 --- a/src/debug.jl +++ b/src/debug.jl @@ -19,7 +19,7 @@ function backtrace(inst::LLVM.Instruction, bt = StackTraces.StackFrame[]) while loc !== nothing scope = LLVM.scope(loc) if scope !== nothing - name = replace(LLVM.name(scope), r";$"=>"") + name = replace(LLVM.name(scope), r";$" => "") file = LLVM.file(scope) path = joinpath(LLVM.directory(file), LLVM.filename(file)) line = LLVM.line(loc) diff --git a/src/driver.jl b/src/driver.jl index 950ea272..ce12b9f3 100644 --- a/src/driver.jl +++ b/src/driver.jl @@ -18,9 +18,9 @@ export JuliaContext # JuliaContext helper below, which returns a local context on Julia 1.9, and the global # unique context on all other versions. Once we only support Julia 1.9, we'll deprecate # this helper to a regular `Context()` call. -function JuliaContext(; opaque_pointers=nothing) +function JuliaContext(; opaque_pointers = nothing) # XXX: remove - ThreadSafeContext(; opaque_pointers) + return ThreadSafeContext(; opaque_pointers) end function JuliaContext(f; kwargs...) ts_ctx = JuliaContext(; kwargs...) @@ -30,7 +30,7 @@ function JuliaContext(f; kwargs...) # rework this once we depend on Julia 1.9 or later. ctx = context(ts_ctx) activate(ctx) - try + return try f(ctx) finally deactivate(ctx) @@ -44,7 +44,7 @@ end export compile # (::CompilerJob) -const compile_hook = Ref{Union{Nothing,Function}}(nothing) +const compile_hook = Ref{Union{Nothing, Function}}(nothing) """ compile(target::Symbol, job::CompilerJob) @@ -74,11 +74,11 @@ function codegen(output::Symbol, @nospecialize(job::CompilerJob); kwargs...) config = CompilerConfig(job.config; kwargs...) job = CompilerJob(job.source, config) end - compile_unhooked(output, job) + return compile_unhooked(output, job) end function compile_unhooked(output::Symbol, @nospecialize(job::CompilerJob); kwargs...) - if context(; throw_error=false) === nothing + if context(; throw_error = false) === nothing error("No active LLVM context. Use `JuliaContext()` do-block syntax to create one.") end @@ -131,10 +131,10 @@ const deferred_codegen_jobs = Dict{Int, Any}() # lazy compilation from, while also enabling recursive compilation. # see `register_deferred_codegen` function deferred_codegen(ptr::Ptr{Cvoid})::Ptr{Cvoid} - ptr + return ptr end -@generated function deferred_codegen(::Val{ft}, ::Val{tt}) where {ft,tt} +@generated function deferred_codegen(::Val{ft}, ::Val{tt}) where {ft, tt} id = length(deferred_codegen_jobs) + 1 deferred_codegen_jobs[id] = (; ft, tt) # don't bother looking up the method instance, as we'll do so again during codegen @@ -144,7 +144,7 @@ end # generated functions so use the current world counter, which may be too new # for the world we're compiling for. - quote + return quote # TODO: add an edge to this method instance to support method redefinitions ccall("extern deferred_codegen", llvmcall, Ptr{Cvoid}, (Int,), $id) end @@ -154,13 +154,15 @@ end # Called from __init__ # On 1.11+ this is needed due to a Julia bug that drops the pointer when code-coverage is enabled. function register_deferred_codegen() - @dispose jljit=JuliaOJIT() begin + @dispose jljit = JuliaOJIT() begin jd = JITDylib(jljit) address = LLVM.API.LLVMOrcJITTargetAddress( - reinterpret(UInt, @cfunction(deferred_codegen, Ptr{Cvoid}, (Ptr{Cvoid},)))) + reinterpret(UInt, @cfunction(deferred_codegen, Ptr{Cvoid}, (Ptr{Cvoid},))) + ) flags = LLVM.API.LLVMJITSymbolFlags( - LLVM.API.LLVMJITSymbolGenericFlagsExported, 0) + LLVM.API.LLVMJITSymbolGenericFlagsExported, 0 + ) name = mangle(jljit, "deferred_codegen") symbol = LLVM.API.LLVMJITEvaluatedSymbol(address, flags) map = if LLVM.version() >= v"15" @@ -212,7 +214,7 @@ const __llvm_initialized = Ref(false) # deferred code generation has_deferred_jobs = job.config.toplevel && !job.config.only_entry && - haskey(functions(ir), "deferred_codegen") + haskey(functions(ir), "deferred_codegen") jobs = Dict{CompilerJob, String}(job => entry_fn) if has_deferred_jobs dyn_marker = functions(ir)["deferred_codegen"] @@ -252,7 +254,7 @@ const __llvm_initialized = Ref(false) dyn_entry_fn = get!(jobs, dyn_job) do target = nest_target(dyn_job.config.target, job.config.target) params = nest_params(dyn_job.config.params, job.config.params) - config = CompilerConfig(dyn_job.config; toplevel=false, target, params) + config = CompilerConfig(dyn_job.config; toplevel = false, target, params) dyn_ir, dyn_meta = codegen(:llvm, CompilerJob(dyn_job; config)) dyn_entry_fn = LLVM.name(dyn_meta.entry) merge!(compiled, dyn_meta.compiled) @@ -266,7 +268,7 @@ const __llvm_initialized = Ref(false) # insert a pointer to the function everywhere the entry is used T_ptr = convert(LLVMType, Ptr{Cvoid}) for call in worklist[dyn_job] - @dispose builder=IRBuilder() begin + @dispose builder = IRBuilder() begin position!(builder, call) fptr = if LLVM.version() >= v"17" T_ptr = LLVM.PointerType() @@ -284,7 +286,7 @@ const __llvm_initialized = Ref(false) end # minimal optimization to convert the inttoptr/call into a direct call - @dispose pb=NewPMPassBuilder() begin + @dispose pb = NewPMPassBuilder() begin add!(pb, NewPMFunctionPassManager()) do fpm add!(fpm, InstCombinePass()) end @@ -311,9 +313,11 @@ const __llvm_initialized = Ref(false) @tracepoint "target libraries" link_libraries!(job, ir, undefined_fns) # GPU run-time library - if !uses_julia_runtime(job) && any(fn -> fn in runtime_fns || - fn in runtime_intrinsics, - undefined_fns) + if !uses_julia_runtime(job) && any( + fn -> fn in runtime_fns || + fn in runtime_intrinsics, + undefined_fns + ) @tracepoint "runtime library" link_library!(ir, runtime) end end @@ -336,10 +340,12 @@ const __llvm_initialized = Ref(false) end end if LLVM.version() >= v"17" - run!(InternalizePass(; preserved_gvs), ir, - llvm_machine(job.config.target)) + run!( + InternalizePass(; preserved_gvs), ir, + llvm_machine(job.config.target) + ) else - @dispose pm=ModulePassManager() begin + @dispose pm = ModulePassManager() begin internalize!(pm, preserved_gvs) run!(pm, ir) end @@ -355,7 +361,7 @@ const __llvm_initialized = Ref(false) # which also need to happen _after_ regular optimization. # XXX: make these part of the optimizer pipeline? if has_deferred_jobs - @dispose pb=NewPMPassBuilder() begin + @dispose pb = NewPMPassBuilder() begin add!(pb, NewPMFunctionPassManager()) do fpm add!(fpm, InstCombinePass()) end @@ -373,7 +379,7 @@ const __llvm_initialized = Ref(false) if job.config.cleanup @tracepoint "clean-up" begin - @dispose pb=NewPMPassBuilder() begin + @dispose pb = NewPMPassBuilder() begin add!(pb, RecomputeGlobalsAAPass()) add!(pb, GlobalOptPass()) add!(pb, GlobalDCEPass()) @@ -425,8 +431,10 @@ const __llvm_initialized = Ref(false) return ir, (; entry, compiled) end -@locked function emit_asm(@nospecialize(job::CompilerJob), ir::LLVM.Module, - format::LLVM.API.LLVMCodeGenFileType) +@locked function emit_asm( + @nospecialize(job::CompilerJob), ir::LLVM.Module, + format::LLVM.API.LLVMCodeGenFileType + ) # NOTE: strip after validation to get better errors if job.config.strip @tracepoint "Debug info removal" strip_debuginfo!(ir) diff --git a/src/error.jl b/src/error.jl index 0138b7b8..2b061e7f 100644 --- a/src/error.jl +++ b/src/error.jl @@ -5,11 +5,13 @@ export KernelError, InternalCompilerError struct KernelError <: Exception job::CompilerJob message::String - help::Union{Nothing,String} + help::Union{Nothing, String} bt::StackTraces.StackTrace - KernelError(@nospecialize(job::CompilerJob), message::String, help=nothing; - bt=StackTraces.StackTrace()) = + KernelError( + @nospecialize(job::CompilerJob), message::String, help = nothing; + bt = StackTraces.StackTrace() + ) = new(job, message, help, bt) end @@ -18,7 +20,7 @@ function Base.showerror(io::IO, err::KernelError) println(io, "KernelError: $(err.message)") println(io) println(io, something(err.help, "Try inspecting the generated code with any of the @device_code_... macros.")) - Base.show_backtrace(io, err.bt) + return Base.show_backtrace(io, err.bt) end @@ -30,9 +32,11 @@ struct InternalCompilerError <: Exception end function Base.showerror(io::IO, err::InternalCompilerError) - println(io, """GPUCompiler.jl encountered an unexpected internal error. - Please file an issue attaching the following information, including the backtrace, - as well as a reproducible example (if possible).""") + println( + io, """GPUCompiler.jl encountered an unexpected internal error. + Please file an issue attaching the following information, including the backtrace, + as well as a reproducible example (if possible).""" + ) println(io, "\nInternalCompilerError: $(err.message)") @@ -40,7 +44,7 @@ function Base.showerror(io::IO, err::InternalCompilerError) if !isempty(err.meta) println(io, "\nAdditional information:") - for (key,val) in err.meta + for (key, val) in err.meta println(io, " - $key = $(repr(val))") end end @@ -54,15 +58,20 @@ function Base.showerror(io::IO, err::InternalCompilerError) println(io) - let InteractiveUtils = Base.require(Base.PkgId(Base.UUID("b77e0a4c-d291-57a0-90e8-8db25a27a240"), "InteractiveUtils")) + return let InteractiveUtils = Base.require(Base.PkgId(Base.UUID("b77e0a4c-d291-57a0-90e8-8db25a27a240"), "InteractiveUtils")) InteractiveUtils.versioninfo(io) end end macro compiler_assert(ex, job, kwargs...) msg = "$ex, at $(__source__.file):$(__source__.line)" - return :($(esc(ex)) ? $(nothing) - : throw(InternalCompilerError($(esc(job)), $msg; - $(map(esc, kwargs)...))) + return :( + $(esc(ex)) ? $(nothing) + : throw( + InternalCompilerError( + $(esc(job)), $msg; + $(map(esc, kwargs)...) + ) ) + ) end diff --git a/src/execution.jl b/src/execution.jl index 9b4940a7..22dd592e 100644 --- a/src/execution.jl +++ b/src/execution.jl @@ -10,13 +10,13 @@ export split_kwargs, assign_args! # intended for use in macros; the resulting groups can be used in expressions. # can be used at run time, but not in performance critical code. function split_kwargs(kwargs, kw_groups...) - kwarg_groups = ntuple(_->[], length(kw_groups) + 1) + kwarg_groups = ntuple(_ -> [], length(kw_groups) + 1) for kwarg in kwargs # decode if Meta.isexpr(kwarg, :(=)) # use in macros key, val = kwarg.args - elseif kwarg isa Pair{Symbol,<:Any} + elseif kwarg isa Pair{Symbol, <:Any} # use in functions key, val = kwarg else @@ -111,12 +111,12 @@ You will need to restart your Julia environment for it to take effect. !!! note The cache functionality requires Julia 1.11 """ -function enable_disk_cache!(state::Bool=true) - @set_preferences!("disk_cache"=>string(state)) +function enable_disk_cache!(state::Bool = true) + return @set_preferences!("disk_cache" => string(state)) end disk_cache_path() = @get_scratch!("disk_cache") -clear_disk_cache!() = rm(disk_cache_path(); recursive=true, force=true) +clear_disk_cache!() = rm(disk_cache_path(); recursive = true, force = true) const cache_lock = ReentrantLock() @@ -133,9 +133,11 @@ and return data that can be cached across sessions (e.g., LLVM IR). This data is forwarded, along with the `CompilerJob`, to the `linker` function which is allowed to create session-dependent objects (e.g., a `CuModule`). """ -function cached_compilation(cache::AbstractDict{<:Any,V}, - src::MethodInstance, cfg::CompilerConfig, - compiler::Function, linker::Function) where {V} +function cached_compilation( + cache::AbstractDict{<:Any, V}, + src::MethodInstance, cfg::CompilerConfig, + compiler::Function, linker::Function + ) where {V} # NOTE: we index the cach both using (mi, world, cfg) keys, for the fast look-up, # and using CodeInfo keys for the slow look-up. we need to cache both for # performance, but cannot use a separate private cache for the ci->obj lookup @@ -186,7 +188,8 @@ end disk_cache_path(), # bifurcate the cache by build id of GPUCompiler string(gpucompiler_buildid), - string(h, ".jls")) + string(h, ".jls") + ) end struct DiskCacheEntry @@ -195,13 +198,15 @@ struct DiskCacheEntry asm end -@noinline function actual_compilation(cache::AbstractDict, src::MethodInstance, world::UInt, - cfg::CompilerConfig, compiler::Function, linker::Function) +@noinline function actual_compilation( + cache::AbstractDict, src::MethodInstance, world::UInt, + cfg::CompilerConfig, compiler::Function, linker::Function + ) job = CompilerJob(src, cfg, world) obj = nothing # fast path: find an applicable CodeInstance and see if we have compiled it before - ci = ci_cache_lookup(ci_cache(job), src, world, world)::Union{Nothing,CodeInstance} + ci = ci_cache_lookup(ci_cache(job), src, world, world)::Union{Nothing, CodeInstance} if ci !== nothing key = (ci, cfg) obj = get(cache, key, nothing) @@ -234,7 +239,7 @@ end @warn "Cache missmatch" src.specTypes cfg entry.src entry.cfg end catch ex - @warn "Failed to load compiled kernel" job path exception=(ex, catch_backtrace()) + @warn "Failed to load compiled kernel" job path exception = (ex, catch_backtrace()) end end end @@ -256,13 +261,13 @@ end entry = DiskCacheEntry(src.specTypes, cfg, asm) # atomic write to disk - tmppath, io = mktemp(dirname(path); cleanup=false) + tmppath, io = mktemp(dirname(path); cleanup = false) serialize(io, entry) close(io) @static if VERSION >= v"1.12.0-DEV.1023" - mv(tmppath, path; force=true) + mv(tmppath, path; force = true) else - Base.rename(tmppath, path, force=true) + Base.rename(tmppath, path, force = true) end end end @@ -272,13 +277,15 @@ end if ci === nothing ci = ci_cache_lookup(ci_cache(job), src, world, world) if ci === nothing - error("""Did not find CodeInstance for $job. + error( + """Did not find CodeInstance for $job. - Pleaase make sure that the `compiler` function passed to `cached_compilation` - invokes GPUCompiler with exactly the same configuration as passed to the API. + Pleaase make sure that the `compiler` function passed to `cached_compilation` + invokes GPUCompiler with exactly the same configuration as passed to the API. - Note that you should do this by calling `GPUCompiler.compile`, and not by - using reflection functions (which alter the compiler configuration).""") + Note that you should do this by calling `GPUCompiler.compile`, and not by + using reflection functions (which alter the compiler configuration).""" + ) end key = (ci, cfg) end diff --git a/src/gcn.jl b/src/gcn.jl index 146d9a33..aa2bee8b 100644 --- a/src/gcn.jl +++ b/src/gcn.jl @@ -6,9 +6,9 @@ export GCNCompilerTarget Base.@kwdef struct GCNCompilerTarget <: AbstractCompilerTarget dev_isa::String - features::String="" + features::String = "" end -GCNCompilerTarget(dev_isa; features="") = GCNCompilerTarget(dev_isa, features) +GCNCompilerTarget(dev_isa; features = "") = GCNCompilerTarget(dev_isa, features) llvm_triple(::GCNCompilerTarget) = "amdgcn-amd-amdhsa" @@ -19,7 +19,7 @@ function llvm_machine(target::GCNCompilerTarget) return nothing end triple = llvm_triple(target) - t = Target(triple=triple) + t = Target(triple = triple) cpu = target.dev_isa feat = target.features @@ -40,8 +40,10 @@ runtime_slug(job::CompilerJob{GCNCompilerTarget}) = "gcn-$(job.config.target.dev const gcn_intrinsics = () # TODO: ("vprintf", "__assertfail", "malloc", "free") isintrinsic(::CompilerJob{GCNCompilerTarget}, fn::String) = in(fn, gcn_intrinsics) -function finish_module!(@nospecialize(job::CompilerJob{GCNCompilerTarget}), - mod::LLVM.Module, entry::LLVM.Function) +function finish_module!( + @nospecialize(job::CompilerJob{GCNCompilerTarget}), + mod::LLVM.Module, entry::LLVM.Function + ) lower_throw_extra!(mod) if job.config.kernel @@ -63,53 +65,53 @@ function lower_throw_extra!(mod::LLVM.Module) changed = false @tracepoint "lower throw (extra)" begin - throw_functions = [ - r"julia_bounds_error.*", - r"julia_throw_boundserror.*", - r"julia_error_if_canonical_getindex.*", - r"julia_error_if_canonical_setindex.*", - r"julia___subarray_throw_boundserror.*", - ] - - for f in functions(mod) - f_name = LLVM.name(f) - for fn in throw_functions - if occursin(fn, f_name) - for use in uses(f) - call = user(use)::LLVM.CallInst - - # replace the throw with a trap - @dispose builder=IRBuilder() begin - position!(builder, call) - emit_exception!(builder, f_name, call) - end - - # remove the call - nargs = length(parameters(f)) - call_args = arguments(call) - erase!(call) - - # HACK: kill the exceptions' unused arguments - for arg in call_args - # peek through casts - if isa(arg, LLVM.AddrSpaceCastInst) - cast = arg - arg = first(operands(cast)) - isempty(uses(cast)) && erase!(cast) + throw_functions = [ + r"julia_bounds_error.*", + r"julia_throw_boundserror.*", + r"julia_error_if_canonical_getindex.*", + r"julia_error_if_canonical_setindex.*", + r"julia___subarray_throw_boundserror.*", + ] + + for f in functions(mod) + f_name = LLVM.name(f) + for fn in throw_functions + if occursin(fn, f_name) + for use in uses(f) + call = user(use)::LLVM.CallInst + + # replace the throw with a trap + @dispose builder = IRBuilder() begin + position!(builder, call) + emit_exception!(builder, f_name, call) end - if isa(arg, LLVM.Instruction) && isempty(uses(arg)) - erase!(arg) + # remove the call + nargs = length(parameters(f)) + call_args = arguments(call) + erase!(call) + + # HACK: kill the exceptions' unused arguments + for arg in call_args + # peek through casts + if isa(arg, LLVM.AddrSpaceCastInst) + cast = arg + arg = first(operands(cast)) + isempty(uses(cast)) && erase!(cast) + end + + if isa(arg, LLVM.Instruction) && isempty(uses(arg)) + erase!(arg) + end end + + changed = true end - changed = true + @compiler_assert isempty(uses(f)) job end - - @compiler_assert isempty(uses(f)) job end end - end end return changed @@ -122,7 +124,7 @@ function emit_trap!(job::CompilerJob{GCNCompilerTarget}, builder, mod, inst) else LLVM.Function(mod, "llvm.trap", trap_ft) end - call!(builder, trap_ft, trap) + return call!(builder, trap_ft, trap) end can_vectorize(job::CompilerJob{GCNCompilerTarget}) = true diff --git a/src/interface.jl b/src/interface.jl index 21ddcf57..f7ef50a6 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -27,7 +27,7 @@ llvm_triple(@nospecialize(target::AbstractCompilerTarget)) = error("Not implemen function llvm_machine(@nospecialize(target::AbstractCompilerTarget)) triple = llvm_triple(target) - t = Target(triple=triple) + t = Target(triple = triple) tm = TargetMachine(t, triple) asm_verbosity!(tm, true) @@ -41,7 +41,7 @@ llvm_datalayout(target::AbstractCompilerTarget) = DataLayout(llvm_machine(target function julia_datalayout(@nospecialize(target::AbstractCompilerTarget)) dl = llvm_datalayout(target) dl === nothing && return nothing - DataLayout(string(dl) * "-ni:10:11:12:13") + return DataLayout(string(dl) * "-ni:10:11:12:13") end have_fma(@nospecialize(target::AbstractCompilerTarget), T::Type) = false @@ -68,8 +68,10 @@ export CompilerConfig # the configuration of the compiler -const CONFIG_KWARGS = [:kernel, :name, :entry_abi, :always_inline, :opt_level, - :libraries, :optimize, :cleanup, :validate, :strip] +const CONFIG_KWARGS = [ + :kernel, :name, :entry_abi, :always_inline, :opt_level, + :libraries, :optimize, :cleanup, :validate, :strip, +] """ CompilerConfig(target, params; kernel=true, entry_abi=:specfunc, name=nothing, @@ -102,12 +104,12 @@ Several keyword arguments can be used to customize the compilation process: - `validate`: enable optional validation of input and outputs (default: true) - `strip`: strip non-functional metadata and debug information (default: false) """ -struct CompilerConfig{T,P} +struct CompilerConfig{T, P} target::T params::P kernel::Bool - name::Union{Nothing,String} + name::Union{Nothing, String} entry_abi::Symbol always_inline::Bool opt_level::Int @@ -121,27 +123,33 @@ struct CompilerConfig{T,P} toplevel::Bool only_entry::Bool - function CompilerConfig(target::AbstractCompilerTarget, params::AbstractCompilerParams; - kernel=true, name=nothing, entry_abi=:specfunc, toplevel=true, - always_inline=false, opt_level=2, optimize=toplevel, - libraries=toplevel, cleanup=toplevel, validate=toplevel, - strip=false, only_entry=false) + function CompilerConfig( + target::AbstractCompilerTarget, params::AbstractCompilerParams; + kernel = true, name = nothing, entry_abi = :specfunc, toplevel = true, + always_inline = false, opt_level = 2, optimize = toplevel, + libraries = toplevel, cleanup = toplevel, validate = toplevel, + strip = false, only_entry = false + ) if entry_abi ∉ (:specfunc, :func) error("Unknown entry_abi=$entry_abi") end - new{typeof(target), typeof(params)}(target, params, kernel, name, entry_abi, - always_inline, opt_level, libraries, optimize, - cleanup, validate, strip, toplevel, only_entry) + return new{typeof(target), typeof(params)}( + target, params, kernel, name, entry_abi, + always_inline, opt_level, libraries, optimize, + cleanup, validate, strip, toplevel, only_entry + ) end end # copy constructor -function CompilerConfig(cfg::CompilerConfig; target=cfg.target, params=cfg.params, - kernel=cfg.kernel, name=cfg.name, entry_abi=cfg.entry_abi, - always_inline=cfg.always_inline, opt_level=cfg.opt_level, - libraries=cfg.libraries, optimize=cfg.optimize, cleanup=cfg.cleanup, - validate=cfg.validate, strip=cfg.strip, toplevel=cfg.toplevel, - only_entry=cfg.only_entry) +function CompilerConfig( + cfg::CompilerConfig; target = cfg.target, params = cfg.params, + kernel = cfg.kernel, name = cfg.name, entry_abi = cfg.entry_abi, + always_inline = cfg.always_inline, opt_level = cfg.opt_level, + libraries = cfg.libraries, optimize = cfg.optimize, cleanup = cfg.cleanup, + validate = cfg.validate, strip = cfg.strip, toplevel = cfg.toplevel, + only_entry = cfg.only_entry + ) # deriving a non-toplevel job disables certain features # XXX: should we keep track if any of these were set explicitly in the first place? # see how PkgEval does that. @@ -151,12 +159,14 @@ function CompilerConfig(cfg::CompilerConfig; target=cfg.target, params=cfg.param cleanup = false validate = false end - CompilerConfig(target, params; kernel, entry_abi, name, always_inline, opt_level, - libraries, optimize, cleanup, validate, strip, toplevel, only_entry) + return CompilerConfig( + target, params; kernel, entry_abi, name, always_inline, opt_level, + libraries, optimize, cleanup, validate, strip, toplevel, only_entry + ) end function Base.show(io::IO, @nospecialize(cfg::CompilerConfig{T})) where {T} - print(io, "CompilerConfig for ", T) + return print(io, "CompilerConfig for ", T) end function Base.hash(cfg::CompilerConfig, h::UInt) @@ -194,18 +204,20 @@ using Core: MethodInstance Construct a `CompilerJob` that will be used to drive compilation for the given `source` and `config` in a given `world`. """ -struct CompilerJob{T,P} +struct CompilerJob{T, P} source::MethodInstance - config::CompilerConfig{T,P} + config::CompilerConfig{T, P} world::UInt - CompilerJob(source::MethodInstance, config::CompilerConfig{T,P}, - world=tls_world_age()) where {T,P} = - new{T,P}(source, config, world) + CompilerJob( + source::MethodInstance, config::CompilerConfig{T, P}, + world = tls_world_age() + ) where {T, P} = + new{T, P}(source, config, world) end # copy constructor -CompilerJob(job::CompilerJob; source=job.source, config=job.config, world=job.world) = +CompilerJob(job::CompilerJob; source = job.source, config = job.config, world = job.world) = CompilerJob(source, config, world) function Base.hash(job::CompilerJob, h::UInt) @@ -222,6 +234,10 @@ end # Has the runtime available and does not require special handling uses_julia_runtime(@nospecialize(job::CompilerJob)) = false +# Should we emit code in imaging mode (i.e. without embedding concrete runtime addresses)? +imaging_mode(@nospecialize(job::CompilerJob)) = imaging_mode(job.config.target) +imaging_mode(@nospecialize(target::AbstractCompilerTarget)) = false + # Is it legal to run vectorization passes on this target can_vectorize(@nospecialize(job::CompilerJob)) = false @@ -238,15 +254,19 @@ isintrinsic(@nospecialize(job::CompilerJob), fn::String) = false # provide a specific interpreter to use. if VERSION >= v"1.11.0-DEV.1552" -get_interpreter(@nospecialize(job::CompilerJob)) = - GPUInterpreter(job.world; method_table_view=maybe_cached(method_table_view(job)), - token=ci_cache_token(job), inf_params=inference_params(job), - opt_params=optimization_params(job)) + get_interpreter(@nospecialize(job::CompilerJob)) = + GPUInterpreter( + job.world; method_table_view = maybe_cached(method_table_view(job)), + token = ci_cache_token(job), inf_params = inference_params(job), + opt_params = optimization_params(job) + ) else -get_interpreter(@nospecialize(job::CompilerJob)) = - GPUInterpreter(job.world; method_table_view=maybe_cached(method_table_view(job)), - code_cache=ci_cache(job), inf_params=inference_params(job), - opt_params=optimization_params(job)) + get_interpreter(@nospecialize(job::CompilerJob)) = + GPUInterpreter( + job.world; method_table_view = maybe_cached(method_table_view(job)), + code_cache = ci_cache(job), inf_params = inference_params(job), + opt_params = optimization_params(job) + ) end # does this target support throwing Julia exceptions with jl_throw? @@ -295,15 +315,15 @@ if VERSION >= v"1.11.0-DEV.1552" # Soft deprecated user should use `CC.code_cache(get_interpreter(job))` ci_cache(@nospecialize(job::CompilerJob)) = CC.code_cache(get_interpreter(job)) else -function ci_cache(@nospecialize(job::CompilerJob)) - lock(GLOBAL_CI_CACHES_LOCK) do - cache = get!(GLOBAL_CI_CACHES, job.config) do - CodeCache() + function ci_cache(@nospecialize(job::CompilerJob)) + return lock(GLOBAL_CI_CACHES_LOCK) do + cache = get!(GLOBAL_CI_CACHES, job.config) do + CodeCache() + end + return cache end - return cache end end -end # the method table to use # deprecate method_table on next-breaking release @@ -312,10 +332,10 @@ method_table_view(@nospecialize(job::CompilerJob)) = get_method_table_view(job.w # the inference parameters to use when constructing the GPUInterpreter function inference_params(@nospecialize(job::CompilerJob)) - if VERSION >= v"1.12.0-DEV.1017" + return if VERSION >= v"1.12.0-DEV.1017" CC.InferenceParams() else - CC.InferenceParams(; unoptimize_throw_blocks=false) + CC.InferenceParams(; unoptimize_throw_blocks = false) end end @@ -324,15 +344,15 @@ function optimization_params(@nospecialize(job::CompilerJob)) kwargs = NamedTuple() if job.config.always_inline - kwargs = (kwargs..., inline_cost_threshold=Int(CC.MAX_INLINE_COST)) + kwargs = (kwargs..., inline_cost_threshold = Int(CC.MAX_INLINE_COST)) end - return CC.OptimizationParams(;kwargs...) + return CC.OptimizationParams(; kwargs...) end # how much debuginfo to emit function llvm_debug_info(@nospecialize(job::CompilerJob)) - if Base.JLOptions().debug_level == 0 + return if Base.JLOptions().debug_level == 0 LLVM.API.LLVMDebugEmissionKindNoDebug elseif Base.JLOptions().debug_level == 1 LLVM.API.LLVMDebugEmissionKindLineTablesOnly @@ -350,8 +370,10 @@ prepare_job!(@nospecialize(job::CompilerJob)) = return # early extension point used to link-in external bitcode files. # this is typically used by downstream packages to link vendor libraries. -link_libraries!(@nospecialize(job::CompilerJob), mod::LLVM.Module, - undefined_fns::Vector{String}) = return +link_libraries!( + @nospecialize(job::CompilerJob), mod::LLVM.Module, + undefined_fns::Vector{String} +) = return # finalization of the module, before deferred codegen and optimization finish_module!(@nospecialize(job::CompilerJob), mod::LLVM.Module, entry::LLVM.Function) = diff --git a/src/irgen.jl b/src/irgen.jl index a7c36a60..86437700 100644 --- a/src/irgen.jl +++ b/src/irgen.jl @@ -15,12 +15,16 @@ function irgen(@nospecialize(job::CompilerJob)) for llvmf in functions(mod) if Base.isdebugbuild() # only occurs in debug builds - delete!(function_attributes(llvmf), - EnumAttribute("sspstrong", 0)) + delete!( + function_attributes(llvmf), + EnumAttribute("sspstrong", 0) + ) end - delete!(function_attributes(llvmf), - StringAttribute("probe-stack", "inline-asm")) + delete!( + function_attributes(llvmf), + StringAttribute("probe-stack", "inline-asm") + ) if Sys.iswindows() personality!(llvmf, nothing) @@ -102,13 +106,13 @@ function irgen(@nospecialize(job::CompilerJob)) push!(preserved_gvs, LLVM.name(gvar)) end if LLVM.version() >= v"17" - @dispose pb=NewPMPassBuilder() begin + @dispose pb = NewPMPassBuilder() begin add!(pb, InternalizePass(; preserved_gvs)) add!(pb, AlwaysInlinerPass()) run!(pb, mod, llvm_machine(job.config.target)) end else - @dispose pm=ModulePassManager() begin + @dispose pm = ModulePassManager() begin internalize!(pm, preserved_gvs) always_inliner!(pm) run!(pm, mod) @@ -138,64 +142,64 @@ function lower_throw!(mod::LLVM.Module) changed = false @tracepoint "lower throw" begin - throw_functions = [ - # unsupported runtime functions that are used to throw specific exceptions - "jl_throw" => "exception", - "jl_error" => "error", - "jl_too_few_args" => "too few arguments exception", - "jl_too_many_args" => "too many arguments exception", - "jl_type_error" => "type error", - "jl_type_error_rt" => "type error", - "jl_undefined_var_error" => "undefined variable error", - "jl_bounds_error" => "bounds error", - "jl_bounds_error_v" => "bounds error", - "jl_bounds_error_int" => "bounds error", - "jl_bounds_error_tuple_int" => "bounds error", - "jl_bounds_error_unboxed_int" => "bounds error", - "jl_bounds_error_ints" => "bounds error", - "jl_eof_error" => "EOF error", - ] - - for f in functions(mod) - fn = LLVM.name(f) - for (throw_fn, name) in throw_functions - occursin(throw_fn, fn) || continue - - for use in uses(f) - call = user(use)::LLVM.CallInst + throw_functions = [ + # unsupported runtime functions that are used to throw specific exceptions + "jl_throw" => "exception", + "jl_error" => "error", + "jl_too_few_args" => "too few arguments exception", + "jl_too_many_args" => "too many arguments exception", + "jl_type_error" => "type error", + "jl_type_error_rt" => "type error", + "jl_undefined_var_error" => "undefined variable error", + "jl_bounds_error" => "bounds error", + "jl_bounds_error_v" => "bounds error", + "jl_bounds_error_int" => "bounds error", + "jl_bounds_error_tuple_int" => "bounds error", + "jl_bounds_error_unboxed_int" => "bounds error", + "jl_bounds_error_ints" => "bounds error", + "jl_eof_error" => "EOF error", + ] + + for f in functions(mod) + fn = LLVM.name(f) + for (throw_fn, name) in throw_functions + occursin(throw_fn, fn) || continue + + for use in uses(f) + call = user(use)::LLVM.CallInst + + # replace the throw with a PTX-compatible exception + @dispose builder = IRBuilder() begin + position!(builder, call) + emit_exception!(builder, name, call) + end - # replace the throw with a PTX-compatible exception - @dispose builder=IRBuilder() begin - position!(builder, call) - emit_exception!(builder, name, call) - end + # remove the call + call_args = arguments(call) + erase!(call) + + # HACK: kill the exceptions' unused arguments + # this is needed for throwing objects with @nospecialize constructors. + for arg in call_args + # peek through casts + if isa(arg, LLVM.AddrSpaceCastInst) + cast = arg + arg = first(operands(cast)) + isempty(uses(cast)) && erase!(cast) + end - # remove the call - call_args = arguments(call) - erase!(call) - - # HACK: kill the exceptions' unused arguments - # this is needed for throwing objects with @nospecialize constructors. - for arg in call_args - # peek through casts - if isa(arg, LLVM.AddrSpaceCastInst) - cast = arg - arg = first(operands(cast)) - isempty(uses(cast)) && erase!(cast) + if isa(arg, LLVM.Instruction) && isempty(uses(arg)) + erase!(arg) + end end - if isa(arg, LLVM.Instruction) && isempty(uses(arg)) - erase!(arg) - end + changed = true end - changed = true + @compiler_assert isempty(uses(f)) job + break end - - @compiler_assert isempty(uses(f)) job - break - end - end + end end return changed @@ -227,7 +231,7 @@ function emit_exception!(builder, name, inst) rt = Runtime.get(:report_exception_frame) ft = convert(LLVM.FunctionType, rt) bt = backtrace(inst) - for (i,frame) in enumerate(bt) + for (i, frame) in enumerate(bt) idx = ConstantInt(parameters(ft)[1], i) func = globalstring_ptr!(builder, String(frame.func), "di_func") file = globalstring_ptr!(builder, String(frame.file), "di_file") @@ -239,7 +243,7 @@ function emit_exception!(builder, name, inst) # signal the exception call!(builder, Runtime.get(:signal_exception)) - emit_trap!(job, builder, mod, inst) + return emit_trap!(job, builder, mod, inst) end function emit_trap!(@nospecialize(job::CompilerJob), builder, mod, inst) @@ -249,7 +253,7 @@ function emit_trap!(@nospecialize(job::CompilerJob), builder, mod, inst) else LLVM.Function(mod, "llvm.trap", trap_ft) end - call!(builder, trap_ft, trap) + return call!(builder, trap_ft, trap) end @@ -271,8 +275,10 @@ end # - `name`: the name of the argument # - `idx`: the index of the argument in the LLVM function type, or `nothing` if the argument # is not passed at the LLVM level. -function classify_arguments(@nospecialize(job::CompilerJob), codegen_ft::LLVM.FunctionType; - post_optimization::Bool=false) +function classify_arguments( + @nospecialize(job::CompilerJob), codegen_ft::LLVM.FunctionType; + post_optimization::Bool = false + ) source_sig = job.source.specTypes source_types = [source_sig.parameters...] @@ -286,7 +292,7 @@ function classify_arguments(@nospecialize(job::CompilerJob), codegen_ft::LLVM.Fu if post_optimization && kernel_state_type(job) !== Nothing args = [] - push!(args, (cc=KERNEL_STATE, typ=kernel_state_type(job), name=:kernel_state, idx=1)) + push!(args, (cc = KERNEL_STATE, typ = kernel_state_type(job), name = :kernel_state, idx = 1)) codegen_i = 2 else args = [] @@ -294,31 +300,31 @@ function classify_arguments(@nospecialize(job::CompilerJob), codegen_ft::LLVM.Fu end for (source_typ, source_name) in zip(source_types, source_argnames) if isghosttype(source_typ) || Core.Compiler.isconstType(source_typ) - push!(args, (cc=GHOST, typ=source_typ, name=source_name, idx=nothing)) + push!(args, (cc = GHOST, typ = source_typ, name = source_name, idx = nothing)) continue end codegen_typ = codegen_types[codegen_i] if codegen_typ isa LLVM.PointerType - llvm_source_typ = convert(LLVMType, source_typ; allow_boxed=true) + llvm_source_typ = convert(LLVMType, source_typ; allow_boxed = true) # pointers are used for multiple kinds of arguments # - literal pointer values if source_typ <: Ptr || source_typ <: Core.LLVMPtr @assert llvm_source_typ == codegen_typ - push!(args, (cc=BITS_VALUE, typ=source_typ, name=source_name, idx=codegen_i)) - # - boxed values - # XXX: use `deserves_retbox` instead? + push!(args, (cc = BITS_VALUE, typ = source_typ, name = source_name, idx = codegen_i)) + # - boxed values + # XXX: use `deserves_retbox` instead? elseif llvm_source_typ isa LLVM.PointerType @assert llvm_source_typ == codegen_typ - push!(args, (cc=MUT_REF, typ=source_typ, name=source_name, idx=codegen_i)) - # - references to aggregates + push!(args, (cc = MUT_REF, typ = source_typ, name = source_name, idx = codegen_i)) + # - references to aggregates else @assert llvm_source_typ != codegen_typ - push!(args, (cc=BITS_REF, typ=source_typ, name=source_name, idx=codegen_i)) + push!(args, (cc = BITS_REF, typ = source_typ, name = source_name, idx = codegen_i)) end else - push!(args, (cc=BITS_VALUE, typ=source_typ, name=source_name, idx=codegen_i)) + push!(args, (cc = BITS_VALUE, typ = source_typ, name = source_name, idx = codegen_i)) end codegen_i += 1 @@ -328,7 +334,7 @@ function classify_arguments(@nospecialize(job::CompilerJob), codegen_ft::LLVM.Fu end function is_immutable_datatype(T::Type) - isa(T,DataType) && !Base.ismutabletype(T) + return isa(T, DataType) && !Base.ismutabletype(T) end function is_inlinealloc(T::Type) @@ -336,7 +342,7 @@ function is_inlinealloc(T::Type) # FIXME: To simple if mayinlinealloc if !Base.datatype_pointerfree(T) - t_name(dt::DataType)=dt.name + t_name(dt::DataType) = dt.name if t_name(T).n_uninitialized != 0 return false end @@ -347,7 +353,7 @@ function is_inlinealloc(T::Type) end function is_concrete_immutable(T::Type) - is_immutable_datatype(T) && T.layout !== C_NULL + return is_immutable_datatype(T) && T.layout !== C_NULL end function is_pointerfree(T::Type) @@ -367,8 +373,8 @@ end deserves_argbox(T) = !deserves_stack(T) deserves_retbox(T) = deserves_argbox(T) function deserves_sret(T, llvmT) - @assert isa(T,DataType) - sizeof(T) > sizeof(Ptr{Cvoid}) && !isa(llvmT, LLVM.FloatingPointType) && !isa(llvmT, LLVM.VectorType) + @assert isa(T, DataType) + return sizeof(T) > sizeof(Ptr{Cvoid}) && !isa(llvmT, LLVM.FloatingPointType) && !isa(llvmT, LLVM.VectorType) end @@ -380,113 +386,115 @@ function lower_byval(@nospecialize(job::CompilerJob), mod::LLVM.Module, f::LLVM. ft = function_type(f) @tracepoint "lower byval" begin - # find the byval parameters - byval = BitVector(undef, length(parameters(ft))) - types = Vector{LLVMType}(undef, length(parameters(ft))) - for i in 1:length(byval) - byval[i] = false - for attr in collect(parameter_attributes(f, i)) - if kind(attr) == kind(TypeAttribute("byval", LLVM.VoidType())) - byval[i] = true - types[i] = value(attr) + # find the byval parameters + byval = BitVector(undef, length(parameters(ft))) + types = Vector{LLVMType}(undef, length(parameters(ft))) + for i in 1:length(byval) + byval[i] = false + for attr in collect(parameter_attributes(f, i)) + if kind(attr) == kind(TypeAttribute("byval", LLVM.VoidType())) + byval[i] = true + types[i] = value(attr) + end end end - end - # fixup metadata - # - # Julia emits invariant.load and const TBAA metadata on loads from pointer args, - # which is invalid now that we have materialized the byval. - for (i, param) in enumerate(parameters(f)) - if byval[i] - # collect all uses of the argument - worklist = Vector{LLVM.Instruction}(user.(collect(uses(param)))) - while !isempty(worklist) - value = popfirst!(worklist) - - # remove the invariant.load attribute - md = metadata(value) - if haskey(md, LLVM.MD_invariant_load) - delete!(md, LLVM.MD_invariant_load) - end - if haskey(md, LLVM.MD_tbaa) - delete!(md, LLVM.MD_tbaa) - end + # fixup metadata + # + # Julia emits invariant.load and const TBAA metadata on loads from pointer args, + # which is invalid now that we have materialized the byval. + for (i, param) in enumerate(parameters(f)) + if byval[i] + # collect all uses of the argument + worklist = Vector{LLVM.Instruction}(user.(collect(uses(param)))) + while !isempty(worklist) + value = popfirst!(worklist) + + # remove the invariant.load attribute + md = metadata(value) + if haskey(md, LLVM.MD_invariant_load) + delete!(md, LLVM.MD_invariant_load) + end + if haskey(md, LLVM.MD_tbaa) + delete!(md, LLVM.MD_tbaa) + end - # recurse on the output of some instructions - if isa(value, LLVM.BitCastInst) || - isa(value, LLVM.GetElementPtrInst) || - isa(value, LLVM.AddrSpaceCastInst) - append!(worklist, user.(collect(uses(value)))) + # recurse on the output of some instructions + if isa(value, LLVM.BitCastInst) || + isa(value, LLVM.GetElementPtrInst) || + isa(value, LLVM.AddrSpaceCastInst) + append!(worklist, user.(collect(uses(value)))) + end end end end - end - - # generate the new function type & definition - new_types = LLVM.LLVMType[] - for (i, param) in enumerate(parameters(ft)) - if byval[i] - llvm_typ = convert(LLVMType, types[i]) - push!(new_types, llvm_typ) - else - push!(new_types, param) - end - end - new_ft = LLVM.FunctionType(return_type(ft), new_types) - new_f = LLVM.Function(mod, "", new_ft) - linkage!(new_f, linkage(f)) - for (arg, new_arg) in zip(parameters(f), parameters(new_f)) - LLVM.name!(new_arg, LLVM.name(arg)) - end - # emit IR performing the "conversions" - new_args = LLVM.Value[] - @dispose builder=IRBuilder() begin - entry = BasicBlock(new_f, "conversion") - position!(builder, entry) - - # perform argument conversions + # generate the new function type & definition + new_types = LLVM.LLVMType[] for (i, param) in enumerate(parameters(ft)) if byval[i] - # copy the argument value to a stack slot, and reference it. llvm_typ = convert(LLVMType, types[i]) - ptr = alloca!(builder, llvm_typ) - if LLVM.addrspace(param) != 0 - ptr = addrspacecast!(builder, ptr, param) - end - store!(builder, parameters(new_f)[i], ptr) - push!(new_args, ptr) + push!(new_types, llvm_typ) else - push!(new_args, parameters(new_f)[i]) - for attr in collect(parameter_attributes(f, i)) - push!(parameter_attributes(new_f, i), attr) - end + push!(new_types, param) end end + new_ft = LLVM.FunctionType(return_type(ft), new_types) + new_f = LLVM.Function(mod, "", new_ft) + linkage!(new_f, linkage(f)) + for (arg, new_arg) in zip(parameters(f), parameters(new_f)) + LLVM.name!(new_arg, LLVM.name(arg)) + end - # map the arguments - value_map = Dict{LLVM.Value, LLVM.Value}( - param => new_args[i] for (i,param) in enumerate(parameters(f)) - ) + # emit IR performing the "conversions" + new_args = LLVM.Value[] + @dispose builder = IRBuilder() begin + entry = BasicBlock(new_f, "conversion") + position!(builder, entry) - value_map[f] = new_f - clone_into!(new_f, f; value_map, - changes=LLVM.API.LLVMCloneFunctionChangeTypeGlobalChanges) + # perform argument conversions + for (i, param) in enumerate(parameters(ft)) + if byval[i] + # copy the argument value to a stack slot, and reference it. + llvm_typ = convert(LLVMType, types[i]) + ptr = alloca!(builder, llvm_typ) + if LLVM.addrspace(param) != 0 + ptr = addrspacecast!(builder, ptr, param) + end + store!(builder, parameters(new_f)[i], ptr) + push!(new_args, ptr) + else + push!(new_args, parameters(new_f)[i]) + for attr in collect(parameter_attributes(f, i)) + push!(parameter_attributes(new_f, i), attr) + end + end + end - # fall through - br!(builder, blocks(new_f)[2]) - end + # map the arguments + value_map = Dict{LLVM.Value, LLVM.Value}( + param => new_args[i] for (i, param) in enumerate(parameters(f)) + ) - # remove the old function - # NOTE: if we ever have legitimate uses of the old function, create a shim instead - fn = LLVM.name(f) - @assert isempty(uses(f)) - replace_metadata_uses!(f, new_f) - erase!(f) - LLVM.name!(new_f, fn) + value_map[f] = new_f + clone_into!( + new_f, f; value_map, + changes = LLVM.API.LLVMCloneFunctionChangeTypeGlobalChanges + ) + + # fall through + br!(builder, blocks(new_f)[2]) + end + + # remove the old function + # NOTE: if we ever have legitimate uses of the old function, create a shim instead + fn = LLVM.name(f) + @assert isempty(uses(f)) + replace_metadata_uses!(f, new_f) + erase!(f) + LLVM.name!(new_f, fn) - return new_f + return new_f end end @@ -538,7 +546,7 @@ function add_kernel_state!(mod::LLVM.Module) worklist_length = length(worklist) additions = LLVM.Function[] function check_user(val) - if val isa Instruction + return if val isa Instruction bb = LLVM.parent(val) new_f = LLVM.parent(bb) in(new_f, worklist) || push!(additions, new_f) @@ -624,8 +632,10 @@ function add_kernel_state!(mod::LLVM.Module) # rewrite references to the old function merge!(value_map, workmap) - clone_into!(new_f, f; value_map, materializer, - changes=LLVM.API.LLVMCloneFunctionChangeTypeGlobalChanges) + clone_into!( + new_f, f; value_map, materializer, + changes = LLVM.API.LLVMCloneFunctionChangeTypeGlobalChanges + ) # remove the function IR so that we won't have any uses left after this pass. empty!(f) @@ -642,7 +652,7 @@ function add_kernel_state!(mod::LLVM.Module) # update uses of the new function, modifying call sites to include the kernel state function rewrite_uses!(f, ft) # update uses - @dispose builder=IRBuilder() begin + return @dispose builder = IRBuilder() begin for use in uses(f) val = user(use) if val isa LLVM.CallBase && called_operand(val) == f @@ -678,16 +688,22 @@ function add_kernel_state!(mod::LLVM.Module) # XXX: we won't have to do this with opaque pointers. position!(builder, val) target_ft = called_type(val) - new_args = map(zip(parameters(target_ft), - arguments(val))) do (param_typ, arg) + new_args = map( + zip( + parameters(target_ft), + arguments(val) + ) + ) do (param_typ, arg) if value_type(arg) != param_typ const_bitcast(arg, param_typ) else arg end end - new_val = call!(builder, called_type(val), called_operand(val), new_args, - operand_bundles(val)) + new_val = call!( + builder, called_type(val), called_operand(val), new_args, + operand_bundles(val) + ) callconv!(new_val, callconv(val)) replace_uses!(val, new_val) @@ -731,7 +747,7 @@ function lower_kernel_state!(fun::LLVM.Function) state_intr = functions(mod)["julia.gpu.state_getter"] state_arg = nothing # only look-up when needed - @dispose builder=IRBuilder() begin + @dispose builder = IRBuilder() begin for use in uses(state_intr) inst = user(use) @assert inst isa LLVM.CallInst @@ -795,7 +811,7 @@ end # run-time equivalent function kernel_state_value(state) - @dispose ctx=Context() begin + return @dispose ctx = Context() begin T_state = convert(LLVMType, state) # create function @@ -807,7 +823,7 @@ function kernel_state_value(state) state_intr_ft = function_type(state_intr) # generate IR - @dispose builder=IRBuilder() begin + @dispose builder = IRBuilder() begin entry = BasicBlock(llvm_f, "entry") position!(builder, entry) @@ -825,8 +841,10 @@ end # the kernel state argument is always passed by value to avoid codegen issues with byval. # some back-ends however do not support passing kernel arguments by value, so this pass # serves to convert that argument (and is conceptually the inverse of `lower_byval`). -function kernel_state_to_reference!(@nospecialize(job::CompilerJob), mod::LLVM.Module, - f::LLVM.Function) +function kernel_state_to_reference!( + @nospecialize(job::CompilerJob), mod::LLVM.Module, + f::LLVM.Function + ) ft = function_type(f) # check if we even need a kernel state argument @@ -864,7 +882,7 @@ function kernel_state_to_reference!(@nospecialize(job::CompilerJob), mod::LLVM.M # emit IR performing the "conversions" new_args = LLVM.Value[] - @dispose builder=IRBuilder() begin + @dispose builder = IRBuilder() begin entry = BasicBlock(new_f, "conversion") position!(builder, entry) @@ -879,12 +897,14 @@ function kernel_state_to_reference!(@nospecialize(job::CompilerJob), mod::LLVM.M # map the arguments value_map = Dict{LLVM.Value, LLVM.Value}( - param => new_args[i] for (i,param) in enumerate(parameters(f)) + param => new_args[i] for (i, param) in enumerate(parameters(f)) ) value_map[f] = new_f - clone_into!(new_f, f; value_map, - changes=LLVM.API.LLVMCloneFunctionChangeTypeGlobalChanges) + clone_into!( + new_f, f; value_map, + changes = LLVM.API.LLVMCloneFunctionChangeTypeGlobalChanges + ) # fall through br!(builder, blocks(new_f)[2]) @@ -907,7 +927,7 @@ function kernel_state_to_reference!(@nospecialize(job::CompilerJob), mod::LLVM.M LLVM.name!(new_f, fn) # minimal optimization - @dispose pb=NewPMPassBuilder() begin + @dispose pb = NewPMPassBuilder() begin add!(pb, SimplifyCFGPass()) run!(pb, new_f, llvm_machine(job.config.target)) end @@ -916,8 +936,10 @@ function kernel_state_to_reference!(@nospecialize(job::CompilerJob), mod::LLVM.M end end -function add_input_arguments!(@nospecialize(job::CompilerJob), mod::LLVM.Module, - entry::LLVM.Function, kernel_intrinsics::Dict) +function add_input_arguments!( + @nospecialize(job::CompilerJob), mod::LLVM.Module, + entry::LLVM.Function, kernel_intrinsics::Dict + ) entry_fn = LLVM.name(entry) # figure out which intrinsics are used and need to be added as arguments @@ -949,6 +971,7 @@ function add_input_arguments!(@nospecialize(job::CompilerJob), mod::LLVM.Module, error("Don't know how to check uses of $candidate. Please file an issue.") end end + return end for f in worklist scan_uses(f) @@ -981,7 +1004,7 @@ function add_input_arguments!(@nospecialize(job::CompilerJob), mod::LLVM.Module, for (arg, new_arg) in zip(parameters(f), parameters(new_f)) LLVM.name!(new_arg, LLVM.name(arg)) end - for (intr_fn, new_arg) in zip(used_intrinsics, parameters(new_f)[end-nargs+1:end]) + for (intr_fn, new_arg) in zip(used_intrinsics, parameters(new_f)[(end - nargs + 1):end]) LLVM.name!(new_arg, kernel_intrinsics[intr_fn].name) end @@ -999,8 +1022,10 @@ function add_input_arguments!(@nospecialize(job::CompilerJob), mod::LLVM.Module, end value_map[f] = new_f - clone_into!(new_f, f; value_map, - changes=LLVM.API.LLVMCloneFunctionChangeTypeLocalChangesOnly) + clone_into!( + new_f, f; value_map, + changes = LLVM.API.LLVMCloneFunctionChangeTypeLocalChangesOnly + ) # we can't remove this function yet, as we might still need to rewrite any called, # but remove the IR already @@ -1016,7 +1041,7 @@ function add_input_arguments!(@nospecialize(job::CompilerJob), mod::LLVM.Module, # update other uses of the old function, modifying call sites to pass the arguments function rewrite_uses!(f, new_f) # update uses - @dispose builder=IRBuilder() begin + return @dispose builder = IRBuilder() begin for use in uses(f) val = user(use) if val isa LLVM.CallInst || val isa LLVM.InvokeInst || val isa LLVM.CallBrInst @@ -1024,9 +1049,11 @@ function add_input_arguments!(@nospecialize(job::CompilerJob), mod::LLVM.Module, # forward the arguments position!(builder, val) new_val = if val isa LLVM.CallInst - call!(builder, function_type(new_f), new_f, - [arguments(val)..., parameters(callee_f)[end-nargs+1:end]...], - operand_bundles(val)) + call!( + builder, function_type(new_f), new_f, + [arguments(val)..., parameters(callee_f)[(end - nargs + 1):end]...], + operand_bundles(val) + ) else # TODO: invoke and callbr error("Rewrite of $(typeof(val))-based calls is not implemented: $val") @@ -1070,7 +1097,7 @@ function add_input_arguments!(@nospecialize(job::CompilerJob), mod::LLVM.Module, val = user(use) callee_f = LLVM.parent(LLVM.parent(val)) if val isa LLVM.CallInst || val isa LLVM.InvokeInst || val isa LLVM.CallBrInst - replace_uses!(val, parameters(callee_f)[end-nargs+i]) + replace_uses!(val, parameters(callee_f)[end - nargs + i]) else error("Cannot rewrite unknown use of function: $val") end diff --git a/src/jlgen.jl b/src/jlgen.jl index 330cf7f8..f37e2ebd 100644 --- a/src/jlgen.jl +++ b/src/jlgen.jl @@ -24,7 +24,7 @@ end # create a MethodError from a function type # TODO: fix upstream function unsafe_function_from_type(ft::Type) - if isdefined(ft, :instance) + return if isdefined(ft, :instance) ft.instance else # HACK: dealing with a closure or something... let's do somthing really invalid, @@ -33,14 +33,14 @@ function unsafe_function_from_type(ft::Type) end end global MethodError -function MethodError(ft::Type{<:Function}, tt::Type, world::Integer=typemax(UInt)) - Base.MethodError(unsafe_function_from_type(ft), tt, world) +function MethodError(ft::Type{<:Function}, tt::Type, world::Integer = typemax(UInt)) + return Base.MethodError(unsafe_function_from_type(ft), tt, world) end -MethodError(ft, tt, world=typemax(UInt)) = Base.MethodError(ft, tt, world) +MethodError(ft, tt, world = typemax(UInt)) = Base.MethodError(ft, tt, world) # generate a LineInfoNode for the current source code location macro LineInfoNode(method) - Core.LineInfoNode(__module__, method, __source__.file, Int32(__source__.line), Int32(0)) + return Core.LineInfoNode(__module__, method, __source__.file, Int32(__source__.line), Int32(0)) end """ @@ -60,8 +60,10 @@ pass at run time. For non-concrete signatures, use `generic_methodinstance` inst """ methodinstance -function generic_methodinstance(@nospecialize(ft::Type), @nospecialize(tt::Type), - world::Integer=tls_world_age()) +function generic_methodinstance( + @nospecialize(ft::Type), @nospecialize(tt::Type), + world::Integer = tls_world_age() + ) sig = signature_type_by_tt(ft, tt) match, _ = CC._findsup(sig, nothing, world) @@ -76,85 +78,93 @@ end # Julia's cached method lookup to simply look up method instances at run time. @static if VERSION >= v"1.11.0-DEV.1552" -# XXX: version of Base.method_instance that uses a function type -@inline function methodinstance(@nospecialize(ft::Type), @nospecialize(tt::Type), - world::Integer=tls_world_age()) - sig = signature_type_by_tt(ft, tt) - @assert Base.isdispatchtuple(sig) # JuliaLang/julia#52233 - - mi = ccall(:jl_method_lookup_by_tt, Any, - (Any, Csize_t, Any), - sig, world, #=method_table=# nothing) - mi === nothing && throw(MethodError(ft, tt, world)) - mi = mi::MethodInstance + # XXX: version of Base.method_instance that uses a function type + @inline function methodinstance( + @nospecialize(ft::Type), @nospecialize(tt::Type), + world::Integer = tls_world_age() + ) + sig = signature_type_by_tt(ft, tt) + @assert Base.isdispatchtuple(sig) # JuliaLang/julia#52233 + + mi = ccall( + :jl_method_lookup_by_tt, Any, + (Any, Csize_t, Any), + sig, world, #=method_table=# nothing + ) + mi === nothing && throw(MethodError(ft, tt, world)) + mi = mi::MethodInstance + + # `jl_method_lookup_by_tt` and `jl_method_lookup` can return a unspecialized mi + if !Base.isdispatchtuple(mi.specTypes) + mi = CC.specialize_method(mi.def, sig, mi.sparam_vals)::MethodInstance + end - # `jl_method_lookup_by_tt` and `jl_method_lookup` can return a unspecialized mi - if !Base.isdispatchtuple(mi.specTypes) - mi = CC.specialize_method(mi.def, sig, mi.sparam_vals)::MethodInstance + return mi end - return mi -end - -# on older versions of Julia, we always need to use the generic lookup + # on older versions of Julia, we always need to use the generic lookup else -const methodinstance = generic_methodinstance - -function methodinstance_generator(world::UInt, source, self, ft::Type, tt::Type) - @nospecialize - @assert CC.isType(ft) && CC.isType(tt) - ft = ft.parameters[1] - tt = tt.parameters[1] - - stub = Core.GeneratedFunctionStub(identity, Core.svec(:methodinstance, :ft, :tt), Core.svec()) - - # look up the method match - method_error = :(throw(MethodError(ft, tt, $world))) - sig = Tuple{ft, tt.parameters...} - min_world = Ref{UInt}(typemin(UInt)) - max_world = Ref{UInt}(typemax(UInt)) - match = ccall(:jl_gf_invoke_lookup_worlds, Any, - (Any, Any, Csize_t, Ref{Csize_t}, Ref{Csize_t}), - sig, #=mt=# nothing, world, min_world, max_world) - match === nothing && return stub(world, source, method_error) - - # look up the method and code instance - mi = ccall(:jl_specializations_get_linfo, Ref{MethodInstance}, - (Any, Any, Any), match.method, match.spec_types, match.sparams) - ci = CC.retrieve_code_info(mi, world) - - # prepare a new code info - new_ci = copy(ci) - empty!(new_ci.code) - empty!(new_ci.codelocs) - empty!(new_ci.linetable) - empty!(new_ci.ssaflags) - new_ci.ssavaluetypes = 0 - - # propagate edge metadata - new_ci.min_world = min_world[] - new_ci.max_world = max_world[] - new_ci.edges = MethodInstance[mi] - - # prepare the slots - new_ci.slotnames = Symbol[Symbol("#self#"), :ft, :tt] - new_ci.slotflags = UInt8[0x00 for i = 1:3] - - # return the method instance - push!(new_ci.code, CC.ReturnNode(mi)) - push!(new_ci.ssaflags, 0x00) - push!(new_ci.linetable, @LineInfoNode(methodinstance)) - push!(new_ci.codelocs, 1) - new_ci.ssavaluetypes += 1 - - return new_ci -end + const methodinstance = generic_methodinstance + + function methodinstance_generator(world::UInt, source, self, ft::Type, tt::Type) + @nospecialize + @assert CC.isType(ft) && CC.isType(tt) + ft = ft.parameters[1] + tt = tt.parameters[1] + + stub = Core.GeneratedFunctionStub(identity, Core.svec(:methodinstance, :ft, :tt), Core.svec()) + + # look up the method match + method_error = :(throw(MethodError(ft, tt, $world))) + sig = Tuple{ft, tt.parameters...} + min_world = Ref{UInt}(typemin(UInt)) + max_world = Ref{UInt}(typemax(UInt)) + match = ccall( + :jl_gf_invoke_lookup_worlds, Any, + (Any, Any, Csize_t, Ref{Csize_t}, Ref{Csize_t}), + sig, #=mt=# nothing, world, min_world, max_world + ) + match === nothing && return stub(world, source, method_error) + + # look up the method and code instance + mi = ccall( + :jl_specializations_get_linfo, Ref{MethodInstance}, + (Any, Any, Any), match.method, match.spec_types, match.sparams + ) + ci = CC.retrieve_code_info(mi, world) + + # prepare a new code info + new_ci = copy(ci) + empty!(new_ci.code) + empty!(new_ci.codelocs) + empty!(new_ci.linetable) + empty!(new_ci.ssaflags) + new_ci.ssavaluetypes = 0 + + # propagate edge metadata + new_ci.min_world = min_world[] + new_ci.max_world = max_world[] + new_ci.edges = MethodInstance[mi] + + # prepare the slots + new_ci.slotnames = Symbol[Symbol("#self#"), :ft, :tt] + new_ci.slotflags = UInt8[0x00 for i in 1:3] + + # return the method instance + push!(new_ci.code, CC.ReturnNode(mi)) + push!(new_ci.ssaflags, 0x00) + push!(new_ci.linetable, @LineInfoNode(methodinstance)) + push!(new_ci.codelocs, 1) + new_ci.ssavaluetypes += 1 + + return new_ci + end -@eval function methodinstance(ft, tt) - $(Expr(:meta, :generated_only)) - $(Expr(:meta, :generated, methodinstance_generator)) -end + @eval function methodinstance(ft, tt) + $(Expr(:meta, :generated_only)) + $(Expr(:meta, :generated, methodinstance_generator)) + end end @@ -163,133 +173,135 @@ end const HAS_INTEGRATED_CACHE = VERSION >= v"1.11.0-DEV.1552" if !HAS_INTEGRATED_CACHE -struct CodeCache - dict::IdDict{MethodInstance,Vector{CodeInstance}} + struct CodeCache + dict::IdDict{MethodInstance, Vector{CodeInstance}} - CodeCache() = new(IdDict{MethodInstance,Vector{CodeInstance}}()) -end + CodeCache() = new(IdDict{MethodInstance, Vector{CodeInstance}}()) + end -function Base.show(io::IO, ::MIME"text/plain", cc::CodeCache) - print(io, "CodeCache with $(mapreduce(length, +, values(cc.dict); init=0)) entries") - if !isempty(cc.dict) - print(io, ": ") - for (mi, cis) in cc.dict - println(io) - print(io, " ") - show(io, mi) - - function worldstr(min_world, max_world) - if min_world == typemax(UInt) - "empty world range" - elseif max_world == typemax(UInt) - "worlds $(Int(min_world))+" - else - "worlds $(Int(min_world)) to $(Int(max_world))" + function Base.show(io::IO, ::MIME"text/plain", cc::CodeCache) + print(io, "CodeCache with $(mapreduce(length, +, values(cc.dict); init = 0)) entries") + return if !isempty(cc.dict) + print(io, ": ") + for (mi, cis) in cc.dict + println(io) + print(io, " ") + show(io, mi) + + function worldstr(min_world, max_world) + return if min_world == typemax(UInt) + "empty world range" + elseif max_world == typemax(UInt) + "worlds $(Int(min_world))+" + else + "worlds $(Int(min_world)) to $(Int(max_world))" + end end - end - for (i,ci) in enumerate(cis) - println(io) - print(io, " CodeInstance for ", worldstr(ci.min_world, ci.max_world)) + for (i, ci) in enumerate(cis) + println(io) + print(io, " CodeInstance for ", worldstr(ci.min_world, ci.max_world)) + end end end end -end -Base.empty!(cc::CodeCache) = empty!(cc.dict) + Base.empty!(cc::CodeCache) = empty!(cc.dict) -const GLOBAL_CI_CACHES = Dict{CompilerConfig, CodeCache}() -const GLOBAL_CI_CACHES_LOCK = ReentrantLock() + const GLOBAL_CI_CACHES = Dict{CompilerConfig, CodeCache}() + const GLOBAL_CI_CACHES_LOCK = ReentrantLock() -## method invalidations + ## method invalidations -function CC.setindex!(cache::CodeCache, ci::CodeInstance, mi::MethodInstance) - # make sure the invalidation callback is attached to the method instance - add_codecache_callback!(cache, mi) - cis = get!(cache.dict, mi, CodeInstance[]) - push!(cis, ci) -end + function CC.setindex!(cache::CodeCache, ci::CodeInstance, mi::MethodInstance) + # make sure the invalidation callback is attached to the method instance + add_codecache_callback!(cache, mi) + cis = get!(cache.dict, mi, CodeInstance[]) + return push!(cis, ci) + end -# invalidation (like invalidate_method_instance, but for our cache) -struct CodeCacheCallback - cache::CodeCache -end + # invalidation (like invalidate_method_instance, but for our cache) + struct CodeCacheCallback + cache::CodeCache + end -@static if VERSION ≥ v"1.11.0-DEV.798" + @static if VERSION ≥ v"1.11.0-DEV.798" -function add_codecache_callback!(cache::CodeCache, mi::MethodInstance) - callback = CodeCacheCallback(cache) - CC.add_invalidation_callback!(callback, mi) -end -function (callback::CodeCacheCallback)(replaced::MethodInstance, max_world::UInt32) - cis = get(callback.cache.dict, replaced, nothing) - if cis === nothing - return - end - for ci in cis - if ci.max_world == ~0 % Csize_t - @assert ci.min_world - 1 <= max_world "attempting to set illogical constraints" -@static if VERSION >= v"1.11.0-DEV.1390" - @atomic ci.max_world = max_world -else - ci.max_world = max_world -end + function add_codecache_callback!(cache::CodeCache, mi::MethodInstance) + callback = CodeCacheCallback(cache) + CC.add_invalidation_callback!(callback, mi) + end + function (callback::CodeCacheCallback)(replaced::MethodInstance, max_world::UInt32) + cis = get(callback.cache.dict, replaced, nothing) + if cis === nothing + return + end + for ci in cis + if ci.max_world == ~0 % Csize_t + @assert ci.min_world - 1 <= max_world "attempting to set illogical constraints" + @static if VERSION >= v"1.11.0-DEV.1390" + @atomic ci.max_world = max_world + else + ci.max_world = max_world + end + end + @assert ci.max_world <= max_world + end end - @assert ci.max_world <= max_world - end -end -else + else -function add_codecache_callback!(cache::CodeCache, mi::MethodInstance) - callback = CodeCacheCallback(cache) - if !isdefined(mi, :callbacks) - mi.callbacks = Any[callback] - elseif !in(callback, mi.callbacks) - push!(mi.callbacks, callback) - end -end -function (callback::CodeCacheCallback)(replaced::MethodInstance, max_world::UInt32, - seen::Set{MethodInstance}=Set{MethodInstance}()) - push!(seen, replaced) - - cis = get(callback.cache.dict, replaced, nothing) - if cis === nothing - return - end - for ci in cis - if ci.max_world == ~0 % Csize_t - @assert ci.min_world - 1 <= max_world "attempting to set illogical constraints" - ci.max_world = max_world - end - @assert ci.max_world <= max_world - end - - # recurse to all backedges to update their valid range also - if isdefined(replaced, :backedges) - backedges = filter(replaced.backedges) do @nospecialize(mi) - if mi isa MethodInstance - mi ∉ seen - elseif mi isa Type - # an `invoke` call, which is a `(sig, MethodInstance)` pair. - # let's ignore the `sig` and process the `MethodInstance` next. - false - else - error("invalid backedge") + function add_codecache_callback!(cache::CodeCache, mi::MethodInstance) + callback = CodeCacheCallback(cache) + if !isdefined(mi, :callbacks) + mi.callbacks = Any[callback] + elseif !in(callback, mi.callbacks) + push!(mi.callbacks, callback) end end + function (callback::CodeCacheCallback)( + replaced::MethodInstance, max_world::UInt32, + seen::Set{MethodInstance} = Set{MethodInstance}() + ) + push!(seen, replaced) - # Don't touch/empty backedges `invalidate_method_instance` in C will do that later - # replaced.backedges = Any[] + cis = get(callback.cache.dict, replaced, nothing) + if cis === nothing + return + end + for ci in cis + if ci.max_world == ~0 % Csize_t + @assert ci.min_world - 1 <= max_world "attempting to set illogical constraints" + ci.max_world = max_world + end + @assert ci.max_world <= max_world + end - for mi in backedges - callback(mi::MethodInstance, max_world, seen) + # recurse to all backedges to update their valid range also + if isdefined(replaced, :backedges) + backedges = filter(replaced.backedges) do @nospecialize(mi) + if mi isa MethodInstance + mi ∉ seen + elseif mi isa Type + # an `invoke` call, which is a `(sig, MethodInstance)` pair. + # let's ignore the `sig` and process the `MethodInstance` next. + false + else + error("invalid backedge") + end + end + + # Don't touch/empty backedges `invalidate_method_instance` in C will do that later + # replaced.backedges = Any[] + + for mi in backedges + callback(mi::MethodInstance, max_world, seen) + end + end end - end -end -end + end end # !HAS_INTEGRATED_CACHE @@ -300,7 +312,7 @@ Base.Experimental.@MethodTable(GLOBAL_METHOD_TABLE) # Implements a priority lookup for method tables, where the first match in the stack get's returned. # An alternative to this would be to use a "Union" where we would query the parent method table and # do a most-specific match. -struct StackedMethodTable{MTV<:CC.MethodTableView} <: CC.MethodTableView +struct StackedMethodTable{MTV <: CC.MethodTableView} <: CC.MethodTableView world::UInt mt::Core.MethodTable parent::MTV @@ -313,7 +325,7 @@ CC.isoverlayed(::StackedMethodTable) = true @static if VERSION >= v"1.11.0-DEV.363" # https://github.com/JuliaLang/julia/pull/51078 # same API as before but without returning isoverlayed flag - function CC.findall(@nospecialize(sig::Type), table::StackedMethodTable; limit::Int=-1) + function CC.findall(@nospecialize(sig::Type), table::StackedMethodTable; limit::Int = -1) result = CC._findall(sig, table.mt, table.world, limit) result === nothing && return nothing # to many matches nr = CC.length(result) @@ -330,8 +342,10 @@ CC.isoverlayed(::StackedMethodTable) = true CC.vcat(result.matches, parent_result.matches), CC.WorldRange( CC.max(result.valid_worlds.min_world, parent_result.valid_worlds.min_world), - CC.min(result.valid_worlds.max_world, parent_result.valid_worlds.max_world)), - result.ambig | parent_result.ambig) + CC.min(result.valid_worlds.max_world, parent_result.valid_worlds.max_world) + ), + result.ambig | parent_result.ambig + ) end function CC.findsup(@nospecialize(sig::Type), table::StackedMethodTable) @@ -342,11 +356,12 @@ CC.isoverlayed(::StackedMethodTable) = true parent_match, CC.WorldRange( max(valid_worlds.min_world, parent_valid_worlds.min_world), - min(valid_worlds.max_world, parent_valid_worlds.max_world)) - ) + min(valid_worlds.max_world, parent_valid_worlds.max_world) + ), + ) end else - function CC.findall(@nospecialize(sig::Type), table::StackedMethodTable; limit::Int=-1) + function CC.findall(@nospecialize(sig::Type), table::StackedMethodTable; limit::Int = -1) result = CC._findall(sig, table.mt, table.world, limit) result === nothing && return nothing # to many matches nr = CC.length(result) @@ -363,13 +378,16 @@ else # merge the parent match results with the internal method table return CC.MethodMatchResult( - CC.MethodLookupResult( - CC.vcat(result.matches, parent_result.matches), - CC.WorldRange( - CC.max(result.valid_worlds.min_world, parent_result.valid_worlds.min_world), - CC.min(result.valid_worlds.max_world, parent_result.valid_worlds.max_world)), - result.ambig | parent_result.ambig), - overlayed) + CC.MethodLookupResult( + CC.vcat(result.matches, parent_result.matches), + CC.WorldRange( + CC.max(result.valid_worlds.min_world, parent_result.valid_worlds.min_world), + CC.min(result.valid_worlds.max_world, parent_result.valid_worlds.max_world) + ), + result.ambig | parent_result.ambig + ), + overlayed + ) end function CC.findsup(@nospecialize(sig::Type), table::StackedMethodTable) @@ -380,8 +398,10 @@ else parent_match, CC.WorldRange( max(valid_worlds.min_world, parent_valid_worlds.min_world), - min(valid_worlds.max_world, parent_valid_worlds.max_world)), - overlayed) + min(valid_worlds.max_world, parent_valid_worlds.max_world) + ), + overlayed, + ) end end @@ -404,15 +424,15 @@ end get_method_table_view(world::UInt, mt::CC.MethodTable) = CC.OverlayMethodTable(world, mt) -struct GPUInterpreter{MTV<:CC.MethodTableView} <: CC.AbstractInterpreter +struct GPUInterpreter{MTV <: CC.MethodTableView} <: CC.AbstractInterpreter world::UInt method_table_view::MTV -@static if HAS_INTEGRATED_CACHE - token::Any -else - code_cache::CodeCache -end + @static if HAS_INTEGRATED_CACHE + token::Any + else + code_cache::CodeCache + end inf_cache::Vector{CC.InferenceResult} inf_params::CC.InferenceParams @@ -420,59 +440,75 @@ end end @static if HAS_INTEGRATED_CACHE -function GPUInterpreter(world::UInt=Base.get_world_counter(); - method_table_view::CC.MethodTableView, - token::Any, - inf_params::CC.InferenceParams, - opt_params::CC.OptimizationParams) - @assert world <= Base.get_world_counter() - - inf_cache = Vector{CC.InferenceResult}() - - return GPUInterpreter(world, method_table_view, - token, inf_cache, - inf_params, opt_params) -end + function GPUInterpreter( + world::UInt = Base.get_world_counter(); + method_table_view::CC.MethodTableView, + token::Any, + inf_params::CC.InferenceParams, + opt_params::CC.OptimizationParams + ) + @assert world <= Base.get_world_counter() + + inf_cache = Vector{CC.InferenceResult}() + + return GPUInterpreter( + world, method_table_view, + token, inf_cache, + inf_params, opt_params + ) + end -function GPUInterpreter(interp::GPUInterpreter; - world::UInt=interp.world, - method_table_view::CC.MethodTableView=interp.method_table_view, - token::Any=interp.token, - inf_cache::Vector{CC.InferenceResult}=interp.inf_cache, - inf_params::CC.InferenceParams=interp.inf_params, - opt_params::CC.OptimizationParams=interp.opt_params) - return GPUInterpreter(world, method_table_view, - token, inf_cache, - inf_params, opt_params) -end + function GPUInterpreter( + interp::GPUInterpreter; + world::UInt = interp.world, + method_table_view::CC.MethodTableView = interp.method_table_view, + token::Any = interp.token, + inf_cache::Vector{CC.InferenceResult} = interp.inf_cache, + inf_params::CC.InferenceParams = interp.inf_params, + opt_params::CC.OptimizationParams = interp.opt_params + ) + return GPUInterpreter( + world, method_table_view, + token, inf_cache, + inf_params, opt_params + ) + end else -function GPUInterpreter(world::UInt=Base.get_world_counter(); - method_table_view::CC.MethodTableView, - code_cache::CodeCache, - inf_params::CC.InferenceParams, - opt_params::CC.OptimizationParams) - @assert world <= Base.get_world_counter() - - inf_cache = Vector{CC.InferenceResult}() - - return GPUInterpreter(world, method_table_view, - code_cache, inf_cache, - inf_params, opt_params) -end + function GPUInterpreter( + world::UInt = Base.get_world_counter(); + method_table_view::CC.MethodTableView, + code_cache::CodeCache, + inf_params::CC.InferenceParams, + opt_params::CC.OptimizationParams + ) + @assert world <= Base.get_world_counter() + + inf_cache = Vector{CC.InferenceResult}() + + return GPUInterpreter( + world, method_table_view, + code_cache, inf_cache, + inf_params, opt_params + ) + end -function GPUInterpreter(interp::GPUInterpreter; - world::UInt=interp.world, - method_table_view::CC.MethodTableView=interp.method_table_view, - code_cache::CodeCache=interp.code_cache, - inf_cache::Vector{CC.InferenceResult}=interp.inf_cache, - inf_params::CC.InferenceParams=interp.inf_params, - opt_params::CC.OptimizationParams=interp.opt_params) - return GPUInterpreter(world, method_table_view, - code_cache, inf_cache, - inf_params, opt_params) -end + function GPUInterpreter( + interp::GPUInterpreter; + world::UInt = interp.world, + method_table_view::CC.MethodTableView = interp.method_table_view, + code_cache::CodeCache = interp.code_cache, + inf_cache::Vector{CC.InferenceResult} = interp.inf_cache, + inf_params::CC.InferenceParams = interp.inf_params, + opt_params::CC.OptimizationParams = interp.opt_params + ) + return GPUInterpreter( + world, method_table_view, + code_cache, inf_cache, + inf_params, opt_params + ) + end end # HAS_INTEGRATED_CACHE CC.InferenceParams(interp::GPUInterpreter) = interp.inf_params @@ -490,33 +526,41 @@ CC.lock_mi_inference(interp::GPUInterpreter, mi::MethodInstance) = nothing CC.unlock_mi_inference(interp::GPUInterpreter, mi::MethodInstance) = nothing function CC.add_remark!(interp::GPUInterpreter, sv::CC.InferenceState, msg) - @safe_debug "Inference remark during GPU compilation of $(sv.linfo): $msg" + return @safe_debug "Inference remark during GPU compilation of $(sv.linfo): $msg" end CC.may_optimize(interp::GPUInterpreter) = true CC.may_compress(interp::GPUInterpreter) = true CC.may_discard_trees(interp::GPUInterpreter) = true @static if VERSION <= v"1.12.0-DEV.1531" -CC.verbose_stmt_info(interp::GPUInterpreter) = false + CC.verbose_stmt_info(interp::GPUInterpreter) = false end CC.method_table(interp::GPUInterpreter) = interp.method_table_view # semi-concrete interepretation is broken with overlays (JuliaLang/julia#47349) -function CC.concrete_eval_eligible(interp::GPUInterpreter, - @nospecialize(f), result::CC.MethodCallResult, arginfo::CC.ArgInfo, sv::CC.InferenceState) +function CC.concrete_eval_eligible( + interp::GPUInterpreter, + @nospecialize(f), result::CC.MethodCallResult, arginfo::CC.ArgInfo, sv::CC.InferenceState + ) # NOTE it's fine to skip overloading with `sv::IRInterpretationState` since we disables # semi-concrete interpretation anyway. - ret = @invoke CC.concrete_eval_eligible(interp::CC.AbstractInterpreter, - f::Any, result::CC.MethodCallResult, arginfo::CC.ArgInfo, sv::CC.InferenceState) + ret = @invoke CC.concrete_eval_eligible( + interp::CC.AbstractInterpreter, + f::Any, result::CC.MethodCallResult, arginfo::CC.ArgInfo, sv::CC.InferenceState + ) if ret === :semi_concrete_eval return :none end return ret end -function CC.concrete_eval_eligible(interp::GPUInterpreter, - @nospecialize(f), result::CC.MethodCallResult, arginfo::CC.ArgInfo) - ret = @invoke CC.concrete_eval_eligible(interp::CC.AbstractInterpreter, - f::Any, result::CC.MethodCallResult, arginfo::CC.ArgInfo) +function CC.concrete_eval_eligible( + interp::GPUInterpreter, + @nospecialize(f), result::CC.MethodCallResult, arginfo::CC.ArgInfo + ) + ret = @invoke CC.concrete_eval_eligible( + interp::CC.AbstractInterpreter, + f::Any, result::CC.MethodCallResult, arginfo::CC.ArgInfo + ) ret === false && return nothing return ret end @@ -527,37 +571,39 @@ using Core.Compiler: WorldView if !HAS_INTEGRATED_CACHE -function CC.haskey(wvc::WorldView{CodeCache}, mi::MethodInstance) - CC.get(wvc, mi, nothing) !== nothing -end + function CC.haskey(wvc::WorldView{CodeCache}, mi::MethodInstance) + return CC.get(wvc, mi, nothing) !== nothing + end -function CC.get(wvc::WorldView{CodeCache}, mi::MethodInstance, default) - # check the cache - for ci in get!(wvc.cache.dict, mi, CodeInstance[]) - if ci.min_world <= wvc.worlds.min_world && wvc.worlds.max_world <= ci.max_world - # TODO: if (code && (code == jl_nothing || jl_ir_flag_inferred((jl_array_t*)code))) - src = if ci.inferred isa Vector{UInt8} - ccall(:jl_uncompress_ir, Any, (Any, Ptr{Cvoid}, Any), - mi.def, C_NULL, ci.inferred) - else - ci.inferred + function CC.get(wvc::WorldView{CodeCache}, mi::MethodInstance, default) + # check the cache + for ci in get!(wvc.cache.dict, mi, CodeInstance[]) + if ci.min_world <= wvc.worlds.min_world && wvc.worlds.max_world <= ci.max_world + # TODO: if (code && (code == jl_nothing || jl_ir_flag_inferred((jl_array_t*)code))) + src = if ci.inferred isa Vector{UInt8} + ccall( + :jl_uncompress_ir, Any, (Any, Ptr{Cvoid}, Any), + mi.def, C_NULL, ci.inferred + ) + else + ci.inferred + end + return ci end - return ci end - end - return default -end + return default + end -function CC.getindex(wvc::WorldView{CodeCache}, mi::MethodInstance) - r = CC.get(wvc, mi, nothing) - r === nothing && throw(KeyError(mi)) - return r::CodeInstance -end + function CC.getindex(wvc::WorldView{CodeCache}, mi::MethodInstance) + r = CC.get(wvc, mi, nothing) + r === nothing && throw(KeyError(mi)) + return r::CodeInstance + end -function CC.setindex!(wvc::WorldView{CodeCache}, ci::CodeInstance, mi::MethodInstance) - CC.setindex!(wvc.cache, ci, mi) -end + function CC.setindex!(wvc::WorldView{CodeCache}, ci::CodeInstance, mi::MethodInstance) + return CC.setindex!(wvc.cache, ci, mi) + end end # HAS_INTEGRATED_CACHE @@ -664,7 +710,7 @@ const _method_instances = Ref{Any}() const _cache = Ref{Any}() function _lookup_fun(mi, min_world, max_world) push!(_method_instances[], mi) - ci_cache_lookup(_cache[], mi, min_world, max_world) + return ci_cache_lookup(_cache[], mi, min_world, max_world) end @enum CompilationPolicy::Cint begin @@ -706,7 +752,7 @@ function compile_method_instance(@nospecialize(job::CompilerJob)) if Sys.ARCH == :x86 || Sys.ARCH == :x86_64 function lookup_fun(mi, min_world, max_world) push!(method_instances, mi) - ci_cache_lookup(cache, mi, min_world, max_world) + return ci_cache_lookup(cache, mi, min_world, max_world) end lookup_cb = @cfunction($lookup_fun, Any, (Any, UInt, UInt)) else @@ -717,19 +763,25 @@ function compile_method_instance(@nospecialize(job::CompilerJob)) # set-up the compiler interface debug_info_kind = llvm_debug_info(job) + imaging = imaging_mode(job) + cgparams = (; - track_allocations = false, - code_coverage = false, - prefer_specsig = true, - gnu_pubnames = false, - debug_info_kind = Cint(debug_info_kind), + track_allocations = false, + code_coverage = false, + prefer_specsig = true, + gnu_pubnames = false, + debug_info_kind = Cint(debug_info_kind), safepoint_on_entry = can_safepoint(job), - gcstack_arg = false) + gcstack_arg = false, + ) + if :use_jlplt in fieldnames(Base.CodegenParams) + cgparams = (; cgparams..., use_jlplt = imaging) + end if VERSION < v"1.12.0-DEV.1667" - cgparams = (; lookup = Base.unsafe_convert(Ptr{Nothing}, lookup_cb), cgparams... ) + cgparams = (; lookup = Base.unsafe_convert(Ptr{Nothing}, lookup_cb), cgparams...) end if v"1.12.0-DEV.2126" <= VERSION < v"1.13-" || VERSION >= v"1.13.0-DEV.285" - cgparams = (; force_emit_all = true , cgparams...) + cgparams = (; force_emit_all = true, cgparams...) end params = Base.CodegenParams(; cgparams...) @@ -748,6 +800,8 @@ function compile_method_instance(@nospecialize(job::CompilerJob)) Metadata(ConstantInt(DEBUG_METADATA_VERSION())) end + imaging_flag = imaging ? 1 : 0 + native_code = if VERSION >= v"1.12.0-DEV.1823" codeinfos = Any[] for (ci, src) in populated @@ -758,19 +812,25 @@ function compile_method_instance(@nospecialize(job::CompilerJob)) end @ccall jl_emit_native(codeinfos::Vector{Any}, ts_mod::LLVM.API.LLVMOrcThreadSafeModuleRef, Ref(params)::Ptr{Base.CodegenParams}, #=extern linkage=# false::Cint)::Ptr{Cvoid} elseif VERSION >= v"1.12.0-DEV.1667" - ccall(:jl_create_native, Ptr{Cvoid}, + ccall( + :jl_create_native, Ptr{Cvoid}, (Vector{MethodInstance}, LLVM.API.LLVMOrcThreadSafeModuleRef, Ptr{Base.CodegenParams}, Cint, Cint, Cint, Csize_t, Ptr{Cvoid}), - [job.source], ts_mod, Ref(params), CompilationPolicyExtern, #=imaging mode=# 0, #=external linkage=# 0, job.world, Base.unsafe_convert(Ptr{Nothing}, lookup_cb)) + [job.source], ts_mod, Ref(params), CompilationPolicyExtern, imaging_flag, #=external linkage=# 0, job.world, Base.unsafe_convert(Ptr{Nothing}, lookup_cb) + ) else - ccall(:jl_create_native, Ptr{Cvoid}, + ccall( + :jl_create_native, Ptr{Cvoid}, (Vector{MethodInstance}, LLVM.API.LLVMOrcThreadSafeModuleRef, Ptr{Base.CodegenParams}, Cint, Cint, Cint, Csize_t), - [job.source], ts_mod, Ref(params), CompilationPolicyExtern, #=imaging mode=# 0, #=external linkage=# 0, job.world) + [job.source], ts_mod, Ref(params), CompilationPolicyExtern, imaging_flag, #=external linkage=# 0, job.world + ) end @assert native_code != C_NULL llvm_mod_ref = - ccall(:jl_get_llvm_module, LLVM.API.LLVMOrcThreadSafeModuleRef, - (Ptr{Cvoid},), native_code) + ccall( + :jl_get_llvm_module, LLVM.API.LLVMOrcThreadSafeModuleRef, + (Ptr{Cvoid},), native_code + ) @assert llvm_mod_ref != C_NULL # XXX: this is wrong; we can't expose the underlying LLVM module, but should @@ -790,14 +850,20 @@ function compile_method_instance(@nospecialize(job::CompilerJob)) # Since Julia 1.13, the caller is responsible for initializing global variables that # point to global values or bindings with their address in memory. num_gvars = Ref{Csize_t}(0) - @ccall jl_get_llvm_gvs(native_code::Ptr{Cvoid}, num_gvars::Ptr{Csize_t}, - C_NULL::Ptr{Cvoid})::Nothing + @ccall jl_get_llvm_gvs( + native_code::Ptr{Cvoid}, num_gvars::Ptr{Csize_t}, + C_NULL::Ptr{Cvoid} + )::Nothing gvs = Vector{Ptr{LLVM.API.LLVMOpaqueValue}}(undef, num_gvars[]) - @ccall jl_get_llvm_gvs(native_code::Ptr{Cvoid}, num_gvars::Ptr{Csize_t}, - gvs::Ptr{LLVM.API.LLVMOpaqueValue})::Nothing + @ccall jl_get_llvm_gvs( + native_code::Ptr{Cvoid}, num_gvars::Ptr{Csize_t}, + gvs::Ptr{LLVM.API.LLVMOpaqueValue} + )::Nothing inits = Vector{Ptr{Cvoid}}(undef, num_gvars[]) - @ccall jl_get_llvm_gv_inits(native_code::Ptr{Cvoid}, num_gvars::Ptr{Csize_t}, - inits::Ptr{Cvoid})::Nothing + @ccall jl_get_llvm_gv_inits( + native_code::Ptr{Cvoid}, num_gvars::Ptr{Csize_t}, + inits::Ptr{Cvoid} + )::Nothing for (gv_ref, init) in zip(gvs, inits) gv = GlobalVariable(gv_ref) @@ -812,11 +878,15 @@ function compile_method_instance(@nospecialize(job::CompilerJob)) # lookup function (used to populate method_instances) isn't always called then. num_cis = Ref{Csize_t}(0) - @ccall jl_get_llvm_cis(native_code::Ptr{Cvoid}, num_cis::Ptr{Csize_t}, - C_NULL::Ptr{Cvoid})::Nothing + @ccall jl_get_llvm_cis( + native_code::Ptr{Cvoid}, num_cis::Ptr{Csize_t}, + C_NULL::Ptr{Cvoid} + )::Nothing resize!(method_instances, num_cis[]) - @ccall jl_get_llvm_cis(native_code::Ptr{Cvoid}, num_cis::Ptr{Csize_t}, - method_instances::Ptr{Cvoid})::Nothing + @ccall jl_get_llvm_cis( + native_code::Ptr{Cvoid}, num_cis::Ptr{Csize_t}, + method_instances::Ptr{Cvoid} + )::Nothing for (i, ci) in enumerate(method_instances) method_instances[i] = ci.def::MethodInstance @@ -826,11 +896,15 @@ function compile_method_instance(@nospecialize(job::CompilerJob)) # slightly older versions of Julia used MIs directly num_mis = Ref{Csize_t}(0) - @ccall jl_get_llvm_mis(native_code::Ptr{Cvoid}, num_mis::Ptr{Csize_t}, - C_NULL::Ptr{Cvoid})::Nothing + @ccall jl_get_llvm_mis( + native_code::Ptr{Cvoid}, num_mis::Ptr{Csize_t}, + C_NULL::Ptr{Cvoid} + )::Nothing resize!(method_instances, num_mis[]) - @ccall jl_get_llvm_mis(native_code::Ptr{Cvoid}, num_mis::Ptr{Csize_t}, - method_instances::Ptr{Cvoid})::Nothing + @ccall jl_get_llvm_mis( + native_code::Ptr{Cvoid}, num_mis::Ptr{Csize_t}, + method_instances::Ptr{Cvoid} + )::Nothing end # process all compiled method instances @@ -842,15 +916,19 @@ function compile_method_instance(@nospecialize(job::CompilerJob)) # get the function index llvm_func_idx = Ref{Int32}(-1) llvm_specfunc_idx = Ref{Int32}(-1) - ccall(:jl_get_function_id, Nothing, - (Ptr{Cvoid}, Any, Ptr{Int32}, Ptr{Int32}), - native_code, ci, llvm_func_idx, llvm_specfunc_idx) + ccall( + :jl_get_function_id, Nothing, + (Ptr{Cvoid}, Any, Ptr{Int32}, Ptr{Int32}), + native_code, ci, llvm_func_idx, llvm_specfunc_idx + ) @assert llvm_func_idx[] != -1 || llvm_specfunc_idx[] != -1 "Static compilation failed" # get the function llvm_func = if llvm_func_idx[] >= 1 - llvm_func_ref = ccall(:jl_get_llvm_function, LLVM.API.LLVMValueRef, - (Ptr{Cvoid}, UInt32), native_code, llvm_func_idx[]-1) + llvm_func_ref = ccall( + :jl_get_llvm_function, LLVM.API.LLVMValueRef, + (Ptr{Cvoid}, UInt32), native_code, llvm_func_idx[] - 1 + ) @assert llvm_func_ref != C_NULL LLVM.name(LLVM.Function(llvm_func_ref)) else @@ -858,8 +936,10 @@ function compile_method_instance(@nospecialize(job::CompilerJob)) end llvm_specfunc = if llvm_specfunc_idx[] >= 1 - llvm_specfunc_ref = ccall(:jl_get_llvm_function, LLVM.API.LLVMValueRef, - (Ptr{Cvoid}, UInt32), native_code, llvm_specfunc_idx[]-1) + llvm_specfunc_ref = ccall( + :jl_get_llvm_function, LLVM.API.LLVMValueRef, + (Ptr{Cvoid}, UInt32), native_code, llvm_specfunc_idx[] - 1 + ) @assert llvm_specfunc_ref != C_NULL LLVM.name(LLVM.Function(llvm_specfunc_ref)) else @@ -868,7 +948,7 @@ function compile_method_instance(@nospecialize(job::CompilerJob)) # NOTE: it's not safe to store raw LLVM functions here, since those may get # removed or renamed during optimization, so we store their name instead. - compiled[mi] = (; ci, func=llvm_func, specfunc=llvm_specfunc) + compiled[mi] = (; ci, func = llvm_func, specfunc = llvm_specfunc) end # ensure that the requested method instance was compiled @@ -879,15 +959,15 @@ end # partially revert JuliaLangjulia#49391 @static if v"1.11.0-DEV.1603" <= VERSION < v"1.12.0-DEV.347" && # reverted on master - !(v"1.11-beta2" <= VERSION < v"1.12") # reverted on 1.11-beta2 -function CC.typeinf(interp::GPUInterpreter, frame::CC.InferenceState) - if CC.__measure_typeinf__[] - CC.Timings.enter_new_timer(frame) - v = CC._typeinf(interp, frame) - CC.Timings.exit_current_timer(frame) - return v - else - return CC._typeinf(interp, frame) + !(v"1.11-beta2" <= VERSION < v"1.12") # reverted on 1.11-beta2 + function CC.typeinf(interp::GPUInterpreter, frame::CC.InferenceState) + if CC.__measure_typeinf__[] + CC.Timings.enter_new_timer(frame) + v = CC._typeinf(interp, frame) + CC.Timings.exit_current_timer(frame) + return v + else + return CC._typeinf(interp, frame) + end end end -end diff --git a/src/mangling.jl b/src/mangling.jl index a461c354..fc1bd9bd 100644 --- a/src/mangling.jl +++ b/src/mangling.jl @@ -5,7 +5,7 @@ # LLVM doesn't like names with special characters, so we need to sanitize them. # note that we are stricter than LLVM, because of `ptxas`. -safe_name(fn::String) = replace(fn, r"[^A-Za-z0-9]"=>"_") +safe_name(fn::String) = replace(fn, r"[^A-Za-z0-9]" => "_") safe_name(t::DataType) = safe_name(String(nameof(t))) function safe_name(t::Type{<:Function}) @@ -21,7 +21,7 @@ function safe_name(t::Type{<:Function}) mt.name end end - safe_name(string(fn)) + return safe_name(string(fn)) end safe_name(::Type{Union{}}) = "Bottom" @@ -43,7 +43,7 @@ function mangle_param(t, substitutions = Any[], top = false) elseif sub == 1 "S_" else - seq_id = uppercase(string(sub-2; base=36)) + seq_id = uppercase(string(sub - 2; base = 36)) "S$(seq_id)_" end return res @@ -55,7 +55,7 @@ function mangle_param(t, substitutions = Any[], top = false) return str end - if isa(t, DataType) && t <: Ptr + return if isa(t, DataType) && t <: Ptr tn = mangle_param(eltype(t), substitutions) push!(substitutions, t) "P$tn" @@ -129,34 +129,34 @@ function mangle_param(t, substitutions = Any[], top = false) elseif isa(t, Char) mangle_param(UInt32(t), substitutions) elseif isa(t, Union{Bool, Cchar, Cuchar, Cshort, Cushort, Cint, Cuint, Clong, Culong, Clonglong, Culonglong, Int128, UInt128}) - ts = t isa Bool ? 'b' : # bool - t isa Cchar ? 'a' : # signed char - t isa Cuchar ? 'h' : # unsigned char - t isa Cshort ? 's' : # short - t isa Cushort ? 't' : # unsigned short - t isa Cint ? 'i' : # int - t isa Cuint ? 'j' : # unsigned int - t isa Clong ? 'l' : # long - t isa Culong ? 'm' : # unsigned long - t isa Clonglong ? 'x' : # long long, __int64 - t isa Culonglong ? 'y' : # unsigned long long, __int64 - t isa Int128 ? 'n' : # __int128 - t isa UInt128 ? 'o' : # unsigned __int128 - error("Invalid type") - tn = string(abs(t), base=10) + ts = t isa Bool ? 'b' : # bool + t isa Cchar ? 'a' : # signed char + t isa Cuchar ? 'h' : # unsigned char + t isa Cshort ? 's' : # short + t isa Cushort ? 't' : # unsigned short + t isa Cint ? 'i' : # int + t isa Cuint ? 'j' : # unsigned int + t isa Clong ? 'l' : # long + t isa Culong ? 'm' : # unsigned long + t isa Clonglong ? 'x' : # long long, __int64 + t isa Culonglong ? 'y' : # unsigned long long, __int64 + t isa Int128 ? 'n' : # __int128 + t isa UInt128 ? 'o' : # unsigned __int128 + error("Invalid type") + tn = string(abs(t), base = 10) # for legibility, encode Julia-native integers as C-native integers, if possible if t isa Int && typemin(Cint) <= t <= typemax(Cint) ts = 'i' end if t < 0 - tn = 'n'*tn + tn = 'n' * tn end "L$(ts)$(tn)E" elseif t isa Float32 - bits = string(reinterpret(UInt32, t); base=16) + bits = string(reinterpret(UInt32, t); base = 16) "Lf$(bits)E" elseif t isa Float64 - bits = string(reinterpret(UInt64, t); base=16) + bits = string(reinterpret(UInt64, t); base = 16) "Ld$(bits)E" else tn = safe_name(t) # TODO: actually does support digits... diff --git a/src/mcgen.jl b/src/mcgen.jl index 77a40d85..9a0f01b8 100644 --- a/src/mcgen.jl +++ b/src/mcgen.jl @@ -6,7 +6,7 @@ function prepare_execution!(@nospecialize(job::CompilerJob), mod::LLVM.Module) global current_job current_job = job - @dispose pb=NewPMPassBuilder() begin + @dispose pb = NewPMPassBuilder() begin register!(pb, ResolveCPUReferencesPass()) add!(pb, RecomputeGlobalsAAPass()) @@ -56,7 +56,7 @@ function resolve_cpu_references!(mod::LLVM.Module) changed = true end end - changed + return changed end changed |= replace_bindings!(f) @@ -69,7 +69,7 @@ ResolveCPUReferencesPass() = NewPMModulePass("ResolveCPUReferences", resolve_cpu_references!) -function mcgen(@nospecialize(job::CompilerJob), mod::LLVM.Module, format=LLVM.API.LLVMAssemblyFile) +function mcgen(@nospecialize(job::CompilerJob), mod::LLVM.Module, format = LLVM.API.LLVMAssemblyFile) tm = llvm_machine(job.config.target) return String(emit(tm, mod, format)) diff --git a/src/metal.jl b/src/metal.jl index d3a83d61..d04d5cf8 100644 --- a/src/metal.jl +++ b/src/metal.jl @@ -13,12 +13,12 @@ end # for backwards compatibility MetalCompilerTarget(macos::VersionNumber) = - MetalCompilerTarget(; macos, air=v"2.4", metal=v"2.4") + MetalCompilerTarget(; macos, air = v"2.4", metal = v"2.4") function Base.hash(target::MetalCompilerTarget, h::UInt) h = hash(target.macos, h) h = hash(target.air, h) - h = hash(target.metal, h) + return h = hash(target.metal, h) end source_code(target::MetalCompilerTarget) = "text" @@ -29,10 +29,10 @@ llvm_machine(::MetalCompilerTarget) = nothing llvm_triple(target::MetalCompilerTarget) = "air64-apple-macosx$(target.macos)" llvm_datalayout(target::MetalCompilerTarget) = - "e-p:64:64:64"* - "-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64"* - "-f32:32:32-f64:64:64"* - "-v16:16:16-v24:32:32-v32:32:32-v48:64:64-v64:64:64-v96:128:128-v128:128:128-v192:256:256-v256:256:256-v512:512:512-v1024:1024:1024"* + "e-p:64:64:64" * + "-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64" * + "-f32:32:32-f64:64:64" * + "-v16:16:16-v24:32:32-v32:32:32-v48:64:64-v64:64:64-v96:128:128-v128:128:128-v192:256:256-v256:256:256-v512:512:512-v1024:1024:1024" * "-n8:16:32" pass_by_value(job::CompilerJob{MetalCompilerTarget}) = false @@ -58,10 +58,12 @@ function finish_linked_module!(@nospecialize(job::CompilerJob{MetalCompilerTarge # possible to 'query' these in device code, relying on LLVM to optimize the checks away # and generate static code. note that we only do so if there's actual uses of these # variables; unconditionally creating a gvar would result in duplicate declarations. - for (name, value) in ["air_major" => job.config.target.air.major, - "air_minor" => job.config.target.air.minor, - "metal_major" => job.config.target.metal.major, - "metal_minor" => job.config.target.metal.minor] + for (name, value) in [ + "air_major" => job.config.target.air.major, + "air_minor" => job.config.target.air.minor, + "metal_major" => job.config.target.metal.major, + "metal_minor" => job.config.target.metal.minor, + ] if haskey(globals(mod), name) gv = globals(mod)[name] initializer!(gv, ConstantInt(LLVM.Int32Type(), value)) @@ -75,7 +77,7 @@ function finish_linked_module!(@nospecialize(job::CompilerJob{MetalCompilerTarge # we emit properties (of the air and metal version) as private global constants, # so run the optimizer so that they are inlined before the rest of the optimizer runs. - @dispose pb=NewPMPassBuilder() begin + @dispose pb = NewPMPassBuilder() begin add!(pb, RecomputeGlobalsAAPass()) add!(pb, GlobalOptPass()) run!(pb, mod) @@ -115,7 +117,7 @@ function validate_ir(job::CompilerJob{MetalCompilerTarget}, mod::LLVM.Module) # Metal never supports 128-bit integers append!(errors, check_ir_values(mod, LLVM.IntType(128))) - errors + return errors end # hide `noreturn` function attributes, which cause issues with the back-end compiler, @@ -140,7 +142,7 @@ function hide_noreturn!(mod::LLVM.Module) end any_noreturn || return false - @dispose pb=NewPMPassBuilder() begin + @dispose pb = NewPMPassBuilder() begin add!(pb, AlwaysInlinerPass()) add!(pb, NewPMFunctionPassManager()) do fpm add!(fpm, SimplifyCFGPass()) @@ -152,8 +154,10 @@ function hide_noreturn!(mod::LLVM.Module) return true end -function finish_ir!(@nospecialize(job::CompilerJob{MetalCompilerTarget}), mod::LLVM.Module, - entry::LLVM.Function) +function finish_ir!( + @nospecialize(job::CompilerJob{MetalCompilerTarget}), mod::LLVM.Module, + entry::LLVM.Function + ) entry_fn = LLVM.name(entry) # convert the kernel state argument to a reference @@ -190,7 +194,7 @@ function finish_ir!(@nospecialize(job::CompilerJob{MetalCompilerTarget}), mod::L end if changed # lowering may have introduced additional functions marked `alwaysinline` - @dispose pb=NewPMPassBuilder() begin + @dispose pb = NewPMPassBuilder() begin add!(pb, AlwaysInlinerPass()) add!(pb, NewPMFunctionPassManager()) do fpm add!(fpm, SimplifyCFGPass()) @@ -203,7 +207,7 @@ function finish_ir!(@nospecialize(job::CompilerJob{MetalCompilerTarget}), mod::L # perform codegen passes that would normally run during machine code emission if LLVM.has_oldpm() # XXX: codegen passes don't seem available in the new pass manager yet - @dispose pm=ModulePassManager() begin + @dispose pm = ModulePassManager() begin expand_reductions!(pm) run!(pm, mod) end @@ -212,8 +216,10 @@ function finish_ir!(@nospecialize(job::CompilerJob{MetalCompilerTarget}), mod::L return functions(mod)[entry_fn] end -@unlocked function mcgen(job::CompilerJob{MetalCompilerTarget}, mod::LLVM.Module, - format=LLVM.API.LLVMObjectFile) +@unlocked function mcgen( + job::CompilerJob{MetalCompilerTarget}, mod::LLVM.Module, + format = LLVM.API.LLVMObjectFile + ) # our LLVM version does not support emitting Metal libraries return nothing end @@ -230,13 +236,15 @@ end # NOTE: this pass also only rewrites pointers _without_ address spaces, which requires it to # be executed after optimization (where Julia's address spaces are stripped). If we ever # want to execute it earlier, adapt remapType to rewrite all pointer types. -function add_parameter_address_spaces!(@nospecialize(job::CompilerJob), mod::LLVM.Module, - f::LLVM.Function) +function add_parameter_address_spaces!( + @nospecialize(job::CompilerJob), mod::LLVM.Module, + f::LLVM.Function + ) ft = function_type(f) # find the byref parameters byref = BitVector(undef, length(parameters(ft))) - args = classify_arguments(job, ft; post_optimization=job.config.optimize) + args = classify_arguments(job, ft; post_optimization = job.config.optimize) filter!(args) do arg arg.cc != GHOST end @@ -283,7 +291,7 @@ function add_parameter_address_spaces!(@nospecialize(job::CompilerJob), mod::LLV # so instead, we load the arguments in stack slots and dereference them so that we can # keep on using the original IR that assumed pointers without address spaces new_args = LLVM.Value[] - @dispose builder=IRBuilder() begin + @dispose builder = IRBuilder() begin entry = BasicBlock(new_f, "conversion") position!(builder, entry) @@ -306,12 +314,14 @@ function add_parameter_address_spaces!(@nospecialize(job::CompilerJob), mod::LLV # map the arguments value_map = Dict{LLVM.Value, LLVM.Value}( - param => new_args[i] for (i,param) in enumerate(parameters(f)) + param => new_args[i] for (i, param) in enumerate(parameters(f)) ) value_map[f] = new_f - clone_into!(new_f, f; value_map, - changes=LLVM.API.LLVMCloneFunctionChangeTypeGlobalChanges) + clone_into!( + new_f, f; value_map, + changes = LLVM.API.LLVMCloneFunctionChangeTypeGlobalChanges + ) # fall through br!(builder, blocks(new_f)[2]) @@ -326,7 +336,7 @@ function add_parameter_address_spaces!(@nospecialize(job::CompilerJob), mod::LLV LLVM.name!(new_f, fn) # clean-up after this pass (which runs after optimization) - @dispose pb=NewPMPassBuilder() begin + @dispose pb = NewPMPassBuilder() begin add!(pb, SimplifyCFGPass()) add!(pb, SROAPass()) add!(pb, EarlyCSEPass()) @@ -342,8 +352,10 @@ end # # global constant objects need to reside in address space 2, so we clone each function # that uses global objects and rewrite the globals used by it -function add_global_address_spaces!(@nospecialize(job::CompilerJob), mod::LLVM.Module, - entry::LLVM.Function) +function add_global_address_spaces!( + @nospecialize(job::CompilerJob), mod::LLVM.Module, + entry::LLVM.Function + ) # determine global variables we need to update global_map = Dict{LLVM.Value, LLVM.Value}() for gv in globals(mod) @@ -375,7 +387,7 @@ function add_global_address_spaces!(@nospecialize(job::CompilerJob), mod::LLVM.M # determine which functions we need to update function_worklist = Set{LLVM.Function}() function check_user(val) - if val isa LLVM.Instruction + return if val isa LLVM.Instruction bb = LLVM.parent(val) f = LLVM.parent(bb) @@ -396,7 +408,7 @@ function add_global_address_spaces!(@nospecialize(job::CompilerJob), mod::LLVM.M for fun in function_worklist fn = LLVM.name(fun) - new_fun = clone(fun; value_map=global_map) + new_fun = clone(fun; value_map = global_map) replace_uses!(fun, new_fun) replace_metadata_uses!(fun, new_fun) erase!(fun) @@ -447,7 +459,7 @@ function pass_by_reference!(@nospecialize(job::CompilerJob), mod::LLVM.Module, f # emit IR performing the "conversions" new_args = LLVM.Value[] - @dispose builder=IRBuilder() begin + @dispose builder = IRBuilder() begin entry = BasicBlock(new_f, "entry") position!(builder, entry) @@ -466,12 +478,14 @@ function pass_by_reference!(@nospecialize(job::CompilerJob), mod::LLVM.Module, f # map the arguments value_map = Dict{LLVM.Value, LLVM.Value}( - param => new_args[i] for (i,param) in enumerate(parameters(f)) + param => new_args[i] for (i, param) in enumerate(parameters(f)) ) value_map[f] = new_f - clone_into!(new_f, f; value_map, - changes=LLVM.API.LLVMCloneFunctionChangeTypeLocalChangesOnly) + clone_into!( + new_f, f; value_map, + changes = LLVM.API.LLVMCloneFunctionChangeTypeLocalChangesOnly + ) # fall through br!(builder, blocks(new_f)[2]) @@ -510,41 +524,43 @@ end const kernel_intrinsics = Dict() for intr in [ - "dispatch_quadgroups_per_threadgroup", "dispatch_simdgroups_per_threadgroup", - "quadgroup_index_in_threadgroup", "quadgroups_per_threadgroup", - "simdgroup_index_in_threadgroup", "simdgroups_per_threadgroup", - "thread_index_in_quadgroup", "thread_index_in_simdgroup", - "thread_index_in_threadgroup", "thread_execution_width", "threads_per_simdgroup"], - (llvm_typ, julia_typ) in [ - ("i32", UInt32), - ("i16", UInt16), - ] - push!(kernel_intrinsics, "julia.air.$intr.$llvm_typ" => (name=intr, typ=julia_typ)) + "dispatch_quadgroups_per_threadgroup", "dispatch_simdgroups_per_threadgroup", + "quadgroup_index_in_threadgroup", "quadgroups_per_threadgroup", + "simdgroup_index_in_threadgroup", "simdgroups_per_threadgroup", + "thread_index_in_quadgroup", "thread_index_in_simdgroup", + "thread_index_in_threadgroup", "thread_execution_width", "threads_per_simdgroup", + ], + (llvm_typ, julia_typ) in [ + ("i32", UInt32), + ("i16", UInt16), + ] + push!(kernel_intrinsics, "julia.air.$intr.$llvm_typ" => (name = intr, typ = julia_typ)) end for intr in [ - "dispatch_threads_per_threadgroup", - "grid_origin", "grid_size", - "thread_position_in_grid", "thread_position_in_threadgroup", - "threadgroup_position_in_grid", "threadgroups_per_grid", - "threads_per_grid", "threads_per_threadgroup"], - (llvm_typ, julia_typ) in [ - ("i32", UInt32), - ("v2i32", NTuple{2, VecElement{UInt32}}), - ("v3i32", NTuple{3, VecElement{UInt32}}), - ("i16", UInt16), - ("v2i16", NTuple{2, VecElement{UInt16}}), - ("v3i16", NTuple{3, VecElement{UInt16}}), - ] - push!(kernel_intrinsics, "julia.air.$intr.$llvm_typ" => (name=intr, typ=julia_typ)) + "dispatch_threads_per_threadgroup", + "grid_origin", "grid_size", + "thread_position_in_grid", "thread_position_in_threadgroup", + "threadgroup_position_in_grid", "threadgroups_per_grid", + "threads_per_grid", "threads_per_threadgroup", + ], + (llvm_typ, julia_typ) in [ + ("i32", UInt32), + ("v2i32", NTuple{2, VecElement{UInt32}}), + ("v3i32", NTuple{3, VecElement{UInt32}}), + ("i16", UInt16), + ("v2i16", NTuple{2, VecElement{UInt16}}), + ("v3i16", NTuple{3, VecElement{UInt16}}), + ] + push!(kernel_intrinsics, "julia.air.$intr.$llvm_typ" => (name = intr, typ = julia_typ)) end function argument_type_name(typ) - if typ isa LLVM.IntegerType && width(typ) == 16 + return if typ isa LLVM.IntegerType && width(typ) == 16 "ushort" elseif typ isa LLVM.IntegerType && width(typ) == 32 "uint" elseif typ isa LLVM.VectorType - argument_type_name(eltype(typ)) * string(Int(length(typ))) + argument_type_name(eltype(typ)) * string(Int(length(typ))) else error("Cannot encode unknown type `$typ`") end @@ -554,18 +570,20 @@ end # # module metadata is used to identify buffers that are passed as kernel arguments. -function add_argument_metadata!(@nospecialize(job::CompilerJob), mod::LLVM.Module, - entry::LLVM.Function) +function add_argument_metadata!( + @nospecialize(job::CompilerJob), mod::LLVM.Module, + entry::LLVM.Function + ) entry_ft = function_type(entry) ## argument info arg_infos = Metadata[] # Iterate through arguments and create metadata for them - args = classify_arguments(job, entry_ft; post_optimization=job.config.optimize) + args = classify_arguments(job, entry_ft; post_optimization = job.config.optimize) i = 1 for arg in args - arg.idx === nothing && continue + arg.idx === nothing && continue if job.config.optimize @assert parameters(entry_ft)[arg.idx] isa LLVM.PointerType else @@ -582,12 +600,12 @@ function add_argument_metadata!(@nospecialize(job::CompilerJob), mod::LLVM.Modul # argument index @assert arg.idx == i - push!(md, Metadata(ConstantInt(Int32(i-1)))) + push!(md, Metadata(ConstantInt(Int32(i - 1)))) push!(md, MDString("air.buffer")) push!(md, MDString("air.location_index")) - push!(md, Metadata(ConstantInt(Int32(i-1)))) + push!(md, Metadata(ConstantInt(Int32(i - 1)))) # XXX: unknown push!(md, Metadata(ConstantInt(Int32(1)))) @@ -626,10 +644,10 @@ function add_argument_metadata!(@nospecialize(job::CompilerJob), mod::LLVM.Modul arg_info = Metadata[] - push!(arg_info, Metadata(ConstantInt(Int32(i-1)))) - push!(arg_info, MDString("air.$intr_fn" )) + push!(arg_info, Metadata(ConstantInt(Int32(i - 1)))) + push!(arg_info, MDString("air.$intr_fn")) - push!(arg_info, MDString("air.arg_type_name" )) + push!(arg_info, MDString("air.arg_type_name")) push!(arg_info, MDString(argument_type_name(value_type(intr_arg)))) arg_info = MDNode(arg_info) @@ -766,12 +784,14 @@ function lower_llvm_intrinsics!(@nospecialize(job::CompilerJob), fun::LLVM.Funct intr = LLVM.Intrinsic(call_fun) # unsupported, but safe to remove - unsupported_intrinsics = LLVM.Intrinsic.([ - "llvm.experimental.noalias.scope.decl", - "llvm.lifetime.start", - "llvm.lifetime.end", - "llvm.assume" - ]) + unsupported_intrinsics = LLVM.Intrinsic.( + [ + "llvm.experimental.noalias.scope.decl", + "llvm.lifetime.start", + "llvm.lifetime.end", + "llvm.assume", + ] + ) if intr in unsupported_intrinsics erase!(call) changed = true @@ -780,15 +800,15 @@ function lower_llvm_intrinsics!(@nospecialize(job::CompilerJob), fun::LLVM.Funct # intrinsics that map straight to AIR mappable_intrinsics = Dict( # one argument - LLVM.Intrinsic("llvm.abs") => ("air.abs", true), - LLVM.Intrinsic("llvm.fabs") => ("air.fabs", missing), + LLVM.Intrinsic("llvm.abs") => ("air.abs", true), + LLVM.Intrinsic("llvm.fabs") => ("air.fabs", missing), # two arguments - LLVM.Intrinsic("llvm.umin") => ("air.min", false), - LLVM.Intrinsic("llvm.smin") => ("air.min", true), - LLVM.Intrinsic("llvm.umax") => ("air.max", false), - LLVM.Intrinsic("llvm.smax") => ("air.max", true), - LLVM.Intrinsic("llvm.minnum") => ("air.fmin", missing), - LLVM.Intrinsic("llvm.maxnum") => ("air.fmax", missing), + LLVM.Intrinsic("llvm.umin") => ("air.min", false), + LLVM.Intrinsic("llvm.smin") => ("air.min", true), + LLVM.Intrinsic("llvm.umax") => ("air.max", false), + LLVM.Intrinsic("llvm.smax") => ("air.max", true), + LLVM.Intrinsic("llvm.minnum") => ("air.fmin", missing), + LLVM.Intrinsic("llvm.maxnum") => ("air.fmax", missing), ) if haskey(mappable_intrinsics, intr) @@ -798,7 +818,7 @@ function lower_llvm_intrinsics!(@nospecialize(job::CompilerJob), fun::LLVM.Funct typ = value_type(call) function type_suffix(typ) # XXX: can't we use LLVM to do this kind of mangling? - if typ isa LLVM.IntegerType + return if typ isa LLVM.IntegerType "i$(width(typ))" elseif typ == LLVM.HalfType() "f16" @@ -824,7 +844,7 @@ function lower_llvm_intrinsics!(@nospecialize(job::CompilerJob), fun::LLVM.Funct else LLVM.Function(mod, fn, call_ft) end - @dispose builder=IRBuilder() begin + @dispose builder = IRBuilder() begin position!(builder, call) debuglocation!(builder, call) @@ -852,12 +872,12 @@ function lower_llvm_intrinsics!(@nospecialize(job::CompilerJob), fun::LLVM.Funct error("Unsupported copysign type: $typ") end - @dispose builder=IRBuilder() begin + @dispose builder = IRBuilder() begin position!(builder, call) debuglocation!(builder, call) # get bits - typ′ = LLVM.IntType(8*sizeof(jltyp)) + typ′ = LLVM.IntType(8 * sizeof(jltyp)) arg0′ = bitcast!(builder, arg0, typ′) arg1′ = bitcast!(builder, arg1, typ′) @@ -892,9 +912,9 @@ function lower_llvm_intrinsics!(@nospecialize(job::CompilerJob), fun::LLVM.Funct # create a function that performs the IEEE-compliant operation. # normally we'd do this inline, but LLVM.jl doesn't have BB split functionality. new_intr_fn = if is_minimum - "air.minimum.f$(8*sizeof(jltyp))" + "air.minimum.f$(8 * sizeof(jltyp))" else - "air.maximum.f$(8*sizeof(jltyp))" + "air.maximum.f$(8 * sizeof(jltyp))" end if haskey(functions(mod), new_intr_fn) @@ -914,7 +934,7 @@ function lower_llvm_intrinsics!(@nospecialize(job::CompilerJob), fun::LLVM.Funct bb_compare_zero = BasicBlock(new_intr, "compare_zero") bb_fallback = BasicBlock(new_intr, "fallback") - @dispose builder=IRBuilder() begin + @dispose builder = IRBuilder() begin # first, check if either argument is NaN, and return it if so position!(builder, bb_check_arg0) @@ -936,14 +956,18 @@ function lower_llvm_intrinsics!(@nospecialize(job::CompilerJob), fun::LLVM.Funct position!(builder, bb_check_zero) - typ′ = LLVM.IntType(8*sizeof(jltyp)) + typ′ = LLVM.IntType(8 * sizeof(jltyp)) arg0′ = bitcast!(builder, arg0, typ′) arg1′ = bitcast!(builder, arg1, typ′) - arg0_zero = fcmp!(builder, LLVM.API.LLVMRealUEQ, arg0, - LLVM.ConstantFP(typ, zero(jltyp))) - arg1_zero = fcmp!(builder, LLVM.API.LLVMRealUEQ, arg1, - LLVM.ConstantFP(typ, zero(jltyp))) + arg0_zero = fcmp!( + builder, LLVM.API.LLVMRealUEQ, arg0, + LLVM.ConstantFP(typ, zero(jltyp)) + ) + arg1_zero = fcmp!( + builder, LLVM.API.LLVMRealUEQ, arg1, + LLVM.ConstantFP(typ, zero(jltyp)) + ) args_zero = and!(builder, arg0_zero, arg1_zero) arg0_sign = and!(builder, arg0′, LLVM.ConstantInt(typ′, Base.sign_mask(jltyp))) arg1_sign = and!(builder, arg1′, LLVM.ConstantInt(typ′, Base.sign_mask(jltyp))) @@ -952,8 +976,10 @@ function lower_llvm_intrinsics!(@nospecialize(job::CompilerJob), fun::LLVM.Funct br!(builder, relevant_zero, bb_compare_zero, bb_fallback) position!(builder, bb_compare_zero) - arg0_negative = icmp!(builder, LLVM.API.LLVMIntNE, arg0_sign, - LLVM.ConstantInt(typ′, 0)) + arg0_negative = icmp!( + builder, LLVM.API.LLVMIntNE, arg0_sign, + LLVM.ConstantInt(typ′, 0) + ) val = if is_minimum select!(builder, arg0_negative, arg0, arg1) else @@ -965,9 +991,9 @@ function lower_llvm_intrinsics!(@nospecialize(job::CompilerJob), fun::LLVM.Funct position!(builder, bb_fallback) fallback_intr_fn = if is_minimum - "air.fmin.f$(8*sizeof(jltyp))" + "air.fmin.f$(8 * sizeof(jltyp))" else - "air.fmax.f$(8*sizeof(jltyp))" + "air.fmax.f$(8 * sizeof(jltyp))" end fallback_intr = if haskey(functions(mod), fallback_intr_fn) functions(mod)[fallback_intr_fn] @@ -979,7 +1005,7 @@ function lower_llvm_intrinsics!(@nospecialize(job::CompilerJob), fun::LLVM.Funct end end - @dispose builder=IRBuilder() begin + @dispose builder = IRBuilder() begin position!(builder, call) debuglocation!(builder, call) @@ -1005,22 +1031,24 @@ function annotate_air_intrinsics!(@nospecialize(job::CompilerJob), mod::LLVM.Mod attrs = function_attributes(f) function add_attributes(names...) for name in names - if LLVM.version() >= v"16" && name in ["argmemonly", "inaccessiblememonly", - "inaccessiblemem_or_argmemonly", - "readnone", "readonly", "writeonly"] + if LLVM.version() >= v"16" && name in [ + "argmemonly", "inaccessiblememonly", + "inaccessiblemem_or_argmemonly", + "readnone", "readonly", "writeonly", + ] # XXX: workaround for changes from https://reviews.llvm.org/D135780 continue end push!(attrs, EnumAttribute(name, 0)) end - changed = true + return changed = true end # synchronization if fn == "air.wg.barrier" || fn == "air.simdgroup.barrier" add_attributes("nounwind", "convergent") - # atomics + # atomics elseif match(r"air.atomic.(local|global).load", fn) !== nothing # TODO: "memory(argmem: read)" on LLVM 16+ add_attributes("argmemonly", "readonly", "nounwind") @@ -1066,7 +1094,7 @@ function replace_unreachable!(@nospecialize(job::CompilerJob), f::LLVM.Function) # would probably keep the problematic control flow just as it is. isempty(exit_blocks) && return false - @dispose builder=IRBuilder() begin + @dispose builder = IRBuilder() begin # if we have multiple exit blocks, take the last one, which is hopefully the least # divergent (assuming divergent control flow is the root of the problem here). exit_block = last(exit_blocks) diff --git a/src/native.jl b/src/native.jl index fdd880ec..5861e94f 100644 --- a/src/native.jl +++ b/src/native.jl @@ -5,17 +5,17 @@ export NativeCompilerTarget Base.@kwdef struct NativeCompilerTarget <: AbstractCompilerTarget - cpu::String=(LLVM.version() < v"8") ? "" : unsafe_string(LLVM.API.LLVMGetHostCPUName()) - features::String=(LLVM.version() < v"8") ? "" : unsafe_string(LLVM.API.LLVMGetHostCPUFeatures()) - llvm_always_inline::Bool=false # will mark the job function as always inline - jlruntime::Bool=false # Use Julia runtime for throwing errors, instead of the GPUCompiler support + cpu::String = (LLVM.version() < v"8") ? "" : unsafe_string(LLVM.API.LLVMGetHostCPUName()) + features::String = (LLVM.version() < v"8") ? "" : unsafe_string(LLVM.API.LLVMGetHostCPUFeatures()) + llvm_always_inline::Bool = false # will mark the job function as always inline + jlruntime::Bool = false # Use Julia runtime for throwing errors, instead of the GPUCompiler support end llvm_triple(::NativeCompilerTarget) = Sys.MACHINE function llvm_machine(target::NativeCompilerTarget) triple = llvm_triple(target) - t = Target(triple=triple) + t = Target(triple = triple) tm = TargetMachine(t, triple, target.cpu, target.features) asm_verbosity!(tm, true) diff --git a/src/optim.jl b/src/optim.jl index 282127f2..2b0f4e74 100644 --- a/src/optim.jl +++ b/src/optim.jl @@ -1,12 +1,12 @@ # LLVM IR optimization -function optimize!(@nospecialize(job::CompilerJob), mod::LLVM.Module; opt_level=2) +function optimize!(@nospecialize(job::CompilerJob), mod::LLVM.Module; opt_level = 2) tm = llvm_machine(job.config.target) global current_job current_job = job - @dispose pb=NewPMPassBuilder() begin + @dispose pb = NewPMPassBuilder() begin register!(pb, GPULowerCPUFeaturesPass()) register!(pb, GPULowerPTLSPass()) register!(pb, GPULowerGCFramePass()) @@ -19,6 +19,14 @@ function optimize!(@nospecialize(job::CompilerJob), mod::LLVM.Module; opt_level= end run!(pb, mod, tm) end + + # Make sure any lingering TLS getters are rewritten even if upstream LLVM passes + # transformed them before the GPULowerPTLSPass had a chance to run. + if occursin("StaticCompilerTarget", string(typeof(job.config.target))) && + uses_julia_runtime(job) + lower_ptls!(mod) + end + optimize_module!(job, mod) run!(DeadArgumentEliminationPass(), mod, tm) return @@ -39,23 +47,25 @@ function buildNewPMPipeline!(mpm, @nospecialize(job::CompilerJob), opt_level) end end buildIntrinsicLoweringPipeline(mpm, job, opt_level) - buildCleanupPipeline(mpm, job, opt_level) + return buildCleanupPipeline(mpm, job, opt_level) end const BasicSimplifyCFGOptions = - (; switch_range_to_icmp=true, - switch_to_lookup=true, - forward_switch_cond=true, - ) + (; + switch_range_to_icmp = true, + switch_to_lookup = true, + forward_switch_cond = true, +) const AggressiveSimplifyCFGOptions = - (; switch_range_to_icmp=true, - switch_to_lookup=true, - forward_switch_cond=true, - # These mess with loop rotation, so only do them after that - hoist_common_insts=true, - # Causes an SRET assertion error in late-gc-lowering - #sink_common_insts=true - ) + (; + switch_range_to_icmp = true, + switch_to_lookup = true, + forward_switch_cond = true, + # These mess with loop rotation, so only do them after that + hoist_common_insts = true, + # Causes an SRET assertion error in late-gc-lowering + #sink_common_insts=true +) function buildEarlySimplificationPipeline(mpm, @nospecialize(job::CompilerJob), opt_level) if should_verify() @@ -68,7 +78,7 @@ function buildEarlySimplificationPipeline(mpm, @nospecialize(job::CompilerJob), # TODO invokePipelineStartCallbacks add!(mpm, Annotation2MetadataPass()) add!(mpm, ConstantMergePass()) - add!(mpm, NewPMFunctionPassManager()) do fpm + return add!(mpm, NewPMFunctionPassManager()) do fpm add!(fpm, LowerExpectIntrinsicPass()) if opt_level >= 2 add!(fpm, PropagateJuliaAddrspacesPass()) @@ -92,7 +102,7 @@ function buildEarlyOptimizerPipeline(mpm, @nospecialize(job::CompilerJob), opt_l end end add!(mpm, GPULowerCPUFeaturesPass()) - if opt_level >= 1 + return if opt_level >= 1 add!(mpm, NewPMFunctionPassManager()) do fpm if opt_level >= 2 add!(fpm, SROAPass()) @@ -120,10 +130,10 @@ function buildLoopOptimizerPipeline(fpm, @nospecialize(job::CompilerJob), opt_le # TODO invokeLateLoopOptimizationCallbacks end if opt_level >= 2 - add!(fpm, NewPMLoopPassManager(; use_memory_ssa=true)) do lpm + add!(fpm, NewPMLoopPassManager(; use_memory_ssa = true)) do lpm add!(lpm, LICMPass()) add!(lpm, JuliaLICMPass()) - add!(lpm, SimpleLoopUnswitchPass(nontrivial=true, trivial=true)) + add!(lpm, SimpleLoopUnswitchPass(nontrivial = true, trivial = true)) add!(lpm, LICMPass()) add!(lpm, JuliaLICMPass()) end @@ -131,7 +141,7 @@ function buildLoopOptimizerPipeline(fpm, @nospecialize(job::CompilerJob), opt_le if opt_level >= 2 add!(fpm, IRCEPass()) end - add!(fpm, NewPMLoopPassManager()) do lpm + return add!(fpm, NewPMLoopPassManager()) do lpm if opt_level >= 2 add!(lpm, LoopInstSimplifyPass()) add!(lpm, LoopIdiomRecognizePass()) @@ -160,7 +170,7 @@ function buildScalarOptimizerPipeline(fpm, @nospecialize(job::CompilerJob), opt_ if opt_level >= 3 add!(fpm, GVNPass()) end - if opt_level >= 2 + return if opt_level >= 2 add!(fpm, DSEPass()) # TODO invokePeepholeCallbacks add!(fpm, SimplifyCFGPass(; AggressiveSimplifyCFGOptions...)) @@ -184,7 +194,7 @@ function buildVectorPipeline(fpm, @nospecialize(job::CompilerJob), opt_level) add!(fpm, VectorCombinePass()) # TODO invokeVectorizerCallbacks add!(fpm, ADCEPass()) - add!(fpm, LoopUnrollPass(; opt_level)) + return add!(fpm, LoopUnrollPass(; opt_level)) end function buildIntrinsicLoweringPipeline(mpm, @nospecialize(job::CompilerJob), opt_level) @@ -257,7 +267,7 @@ function buildIntrinsicLoweringPipeline(mpm, @nospecialize(job::CompilerJob), op # Julia's operand bundles confuse the inliner, so repeat here now they are gone. # FIXME: we should fix the inliner so that inlined code gets optimized early-on - add!(mpm, AlwaysInlinerPass()) + return add!(mpm, AlwaysInlinerPass()) end function buildCleanupPipeline(mpm, @nospecialize(job::CompilerJob), opt_level) @@ -273,7 +283,7 @@ function buildCleanupPipeline(mpm, @nospecialize(job::CompilerJob), opt_level) add!(mpm, NewPMFunctionPassManager()) do fpm add!(fpm, AnnotationRemarksPass()) end - add!(mpm, NewPMFunctionPassManager()) do fpm + return add!(mpm, NewPMFunctionPassManager()) do fpm add!(fpm, DemoteFloat16Pass()) if opt_level >= 1 add!(fpm, GVNPass()) @@ -282,7 +292,6 @@ function buildCleanupPipeline(mpm, @nospecialize(job::CompilerJob), opt_level) end - ## custom passes # lowering intrinsics @@ -351,7 +360,7 @@ function lower_gc_frame!(fun::LLVM.Function) alloc_obj = functions(mod)["julia.gc_alloc_obj"] alloc_obj_ft = function_type(alloc_obj) T_prjlvalue = return_type(alloc_obj_ft) - T_pjlvalue = convert(LLVMType, Any; allow_boxed=true) + T_pjlvalue = convert(LLVMType, Any; allow_boxed = true) for use in uses(alloc_obj) call = user(use)::LLVM.CallInst @@ -361,7 +370,7 @@ function lower_gc_frame!(fun::LLVM.Function) sz = ops[2] # replace with PTX alloc_obj - @dispose builder=IRBuilder() begin + @dispose builder = IRBuilder() begin position!(builder, call) ptr = call!(builder, Runtime.get(:gc_pool_alloc), [sz]) replace_uses!(call, ptr) @@ -405,7 +414,31 @@ function lower_ptls!(mod::LLVM.Module) intrinsic = "julia.get_pgcstack" - if haskey(functions(mod), intrinsic) + # On host-style static targets we want a relocatable call into libjulia instead of + # embedding the pointer to the TLS getter. Replace the intrinsic with a declared + # libjulia call to avoid baking absolute addresses that crash in standalone binaries. + if haskey(functions(mod), intrinsic) && + occursin("StaticCompilerTarget", string(typeof(job.config.target))) && + uses_julia_runtime(job) + + pgc_fn = functions(mod)[intrinsic] + jl_decl = if haskey(functions(mod), "jl_get_pgcstack") + functions(mod)["jl_get_pgcstack"] + else + LLVM.Function(mod, "jl_get_pgcstack", LLVM.FunctionType(LLVM.PointerType())) + end + + for use in uses(pgc_fn) + call = user(use)::LLVM.CallInst + @dispose builder = IRBuilder() begin + position!(builder, call) + repl = call!(builder, function_type(jl_decl), jl_decl, LLVM.Value[]) + replace_uses!(call, repl) + end + erase!(call) + changed = true + end + elseif haskey(functions(mod), intrinsic) ptls_getter = functions(mod)[intrinsic] for use in uses(ptls_getter) @@ -417,7 +450,35 @@ function lower_ptls!(mod::LLVM.Module) # the validator will detect this end end - end + end + + # Newer Julia versions sometimes lower the TLS getter to an inttoptr call that bakes + # the address of `jl_get_pgcstack_static` into the IR. Rewrite those calls as well to + # make sure we always end up with a relocatable reference into libjulia when the + # runtime is linked. + if uses_julia_runtime(job) && occursin("StaticCompilerTarget", string(typeof(job.config.target))) + jl_decl = if haskey(functions(mod), "jl_get_pgcstack") + functions(mod)["jl_get_pgcstack"] + else + LLVM.Function(mod, "jl_get_pgcstack", LLVM.FunctionType(LLVM.PointerType())) + end + + for f in functions(mod), bb in blocks(f), inst in instructions(bb) + inst isa LLVM.CallInst || continue + + callee = LLVM.called_operand(inst) + if callee isa LLVM.ConstantExpr && occursin("inttoptr", string(callee)) && + occursin("pgcstack", string(inst)) + @dispose builder = IRBuilder() begin + position!(builder, inst) + repl = call!(builder, function_type(jl_decl), jl_decl, LLVM.Value[]) + replace_uses!(inst, repl) + end + erase!(inst) + changed = true + end + end + end return changed end diff --git a/src/precompile.jl b/src/precompile.jl index 8f62451b..85fe86b6 100644 --- a/src/precompile.jl +++ b/src/precompile.jl @@ -2,22 +2,22 @@ using PrecompileTools: @setup_workload, @compile_workload @setup_workload begin precompile_module = @eval module $(gensym()) - using ..GPUCompiler - - module DummyRuntime - # dummy methods - signal_exception() = return - malloc(sz) = C_NULL - report_oom(sz) = return - report_exception(ex) = return - report_exception_name(ex) = return - report_exception_frame(idx, func, file, line) = return - end + using ..GPUCompiler + + module DummyRuntime + # dummy methods + signal_exception() = return + malloc(sz) = C_NULL + report_oom(sz) = return + report_exception(ex) = return + report_exception_name(ex) = return + report_exception_frame(idx, func, file, line) = return + end - struct DummyCompilerParams <: AbstractCompilerParams end - const DummyCompilerJob = CompilerJob{NativeCompilerTarget, DummyCompilerParams} + struct DummyCompilerParams <: AbstractCompilerParams end + const DummyCompilerJob = CompilerJob{NativeCompilerTarget, DummyCompilerParams} - GPUCompiler.runtime_module(::DummyCompilerJob) = DummyRuntime + GPUCompiler.runtime_module(::DummyCompilerJob) = DummyRuntime end kernel() = nothing @@ -28,7 +28,7 @@ using PrecompileTools: @setup_workload, @compile_workload params = precompile_module.DummyCompilerParams() # XXX: on Windows, compiling the GPU runtime leaks GPU code in the native cache, # so prevent building the runtime library (see JuliaGPU/GPUCompiler.jl#601) - config = CompilerConfig(target, params; libraries=false) + config = CompilerConfig(target, params; libraries = false) job = CompilerJob(source, config) JuliaContext() do ctx diff --git a/src/ptx.jl b/src/ptx.jl index 2ee1b700..0a340b94 100644 --- a/src/ptx.jl +++ b/src/ptx.jl @@ -13,16 +13,16 @@ Base.@kwdef struct PTXCompilerTarget <: AbstractCompilerTarget debuginfo::Bool = false # optional properties - minthreads::Union{Nothing,Int,NTuple{<:Any,Int}} = nothing - maxthreads::Union{Nothing,Int,NTuple{<:Any,Int}} = nothing - blocks_per_sm::Union{Nothing,Int} = nothing - maxregs::Union{Nothing,Int} = nothing + minthreads::Union{Nothing, Int, NTuple{<:Any, Int}} = nothing + maxthreads::Union{Nothing, Int, NTuple{<:Any, Int}} = nothing + blocks_per_sm::Union{Nothing, Int} = nothing + maxregs::Union{Nothing, Int} = nothing fastmath::Bool = Base.JLOptions().fast_math == 1 # deprecated; remove with next major version - exitable::Union{Nothing,Bool} = nothing - unreachable::Union{Nothing,Bool} = nothing + exitable::Union{Nothing, Bool} = nothing + unreachable::Union{Nothing, Bool} = nothing end function Base.hash(target::PTXCompilerTarget, h::UInt) @@ -37,22 +37,24 @@ function Base.hash(target::PTXCompilerTarget, h::UInt) h = hash(target.maxregs, h) h = hash(target.fastmath, h) - h + return h end source_code(target::PTXCompilerTarget) = "ptx" -llvm_triple(target::PTXCompilerTarget) = Int===Int64 ? "nvptx64-nvidia-cuda" : "nvptx-nvidia-cuda" +llvm_triple(target::PTXCompilerTarget) = Int === Int64 ? "nvptx64-nvidia-cuda" : "nvptx-nvidia-cuda" function llvm_machine(target::PTXCompilerTarget) @static if :NVPTX ∉ LLVM.backends() return nothing end triple = llvm_triple(target) - t = Target(triple=triple) + t = Target(triple = triple) - tm = TargetMachine(t, triple, "sm_$(target.cap.major)$(target.cap.minor)", - "+ptx$(target.ptx.major)$(target.ptx.minor)") + tm = TargetMachine( + t, triple, "sm_$(target.cap.major)$(target.cap.minor)", + "+ptx$(target.ptx.major)$(target.ptx.minor)" + ) asm_verbosity!(tm, true) return tm @@ -64,7 +66,7 @@ llvm_datalayout(target::PTXCompilerTarget) = "e-" * # on 32-bit systems, use 32-bit pointers. # on 64-bit systems, use 64-bit pointers. - (Int === Int64 ? "p:64:64:64-" : "p:32:32:32-") * + (Int === Int64 ? "p:64:64:64-" : "p:32:32:32-") * # alignment of integer types "i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-" * # alignment of floating point types @@ -90,7 +92,7 @@ function Base.show(io::IO, @nospecialize(job::CompilerJob{PTXCompilerTarget})) job.config.target.maxthreads !== nothing && print(io, ", maxthreads=$(job.config.target.maxthreads)") job.config.target.blocks_per_sm !== nothing && print(io, ", blocks_per_sm=$(job.config.target.blocks_per_sm)") job.config.target.maxregs !== nothing && print(io, ", maxregs=$(job.config.target.maxregs)") - job.config.target.fastmath && print(io, ", fast math enabled") + return job.config.target.fastmath && print(io, ", fast math enabled") end const ptx_intrinsics = ("vprintf", "__assertfail", "malloc", "free") @@ -103,16 +105,20 @@ runtime_slug(@nospecialize(job::CompilerJob{PTXCompilerTarget})) = "-sm_$(job.config.target.cap.major)$(job.config.target.cap.minor)" * "-debuginfo=$(Int(llvm_debug_info(job)))" -function finish_module!(@nospecialize(job::CompilerJob{PTXCompilerTarget}), - mod::LLVM.Module, entry::LLVM.Function) +function finish_module!( + @nospecialize(job::CompilerJob{PTXCompilerTarget}), + mod::LLVM.Module, entry::LLVM.Function + ) # emit the device capability and ptx isa version as constants in the module. this makes # it possible to 'query' these in device code, relying on LLVM to optimize the checks # away and generate static code. note that we only do so if there's actual uses of these # variables; unconditionally creating a gvar would result in duplicate declarations. - for (name, value) in ["sm_major" => job.config.target.cap.major, - "sm_minor" => job.config.target.cap.minor, - "ptx_major" => job.config.target.ptx.major, - "ptx_minor" => job.config.target.ptx.minor] + for (name, value) in [ + "sm_major" => job.config.target.cap.major, + "sm_minor" => job.config.target.cap.minor, + "ptx_major" => job.config.target.ptx.major, + "ptx_minor" => job.config.target.ptx.minor, + ] if haskey(globals(mod), name) gv = globals(mod)[name] initializer!(gv, ConstantInt(LLVM.Int32Type(), value)) @@ -139,7 +145,7 @@ function finish_module!(@nospecialize(job::CompilerJob{PTXCompilerTarget}), # we emit properties (of the device and ptx isa) as private global constants, # so run the optimizer so that they are inlined before the rest of the optimizer runs. - @dispose pb=NewPMPassBuilder() begin + @dispose pb = NewPMPassBuilder() begin add!(pb, RecomputeGlobalsAAPass()) add!(pb, GlobalOptPass()) run!(pb, mod, llvm_machine(job.config.target)) @@ -148,11 +154,13 @@ function finish_module!(@nospecialize(job::CompilerJob{PTXCompilerTarget}), return entry end -function optimize_module!(@nospecialize(job::CompilerJob{PTXCompilerTarget}), - mod::LLVM.Module) +function optimize_module!( + @nospecialize(job::CompilerJob{PTXCompilerTarget}), + mod::LLVM.Module + ) tm = llvm_machine(job.config.target) # TODO: Use the registered target passes (JuliaGPU/GPUCompiler.jl#450) - @dispose pb=NewPMPassBuilder() begin + return @dispose pb = NewPMPassBuilder() begin register!(pb, NVVMReflectPass()) add!(pb, NewPMFunctionPassManager()) do fpm @@ -166,9 +174,9 @@ function optimize_module!(@nospecialize(job::CompilerJob{PTXCompilerTarget}), # but Julia's pass sequence only invokes the simple unroller. add!(fpm, LoopUnrollPass(; job.config.opt_level)) add!(fpm, InstCombinePass()) # clean-up redundancy - add!(fpm, NewPMLoopPassManager(; use_memory_ssa=true)) do lpm + add!(fpm, NewPMLoopPassManager(; use_memory_ssa = true)) do lpm add!(lpm, LICMPass()) # the inner runtime check might be - # outer loop invariant + # outer loop invariant end # the above loop unroll pass might have unrolled regular, non-runtime nested loops. @@ -191,8 +199,10 @@ function optimize_module!(@nospecialize(job::CompilerJob{PTXCompilerTarget}), end end -function finish_ir!(@nospecialize(job::CompilerJob{PTXCompilerTarget}), - mod::LLVM.Module, entry::LLVM.Function) +function finish_ir!( + @nospecialize(job::CompilerJob{PTXCompilerTarget}), + mod::LLVM.Module, entry::LLVM.Function + ) if LLVM.version() < v"17" for f in functions(mod) lower_unreachable!(f) @@ -206,33 +216,53 @@ function finish_ir!(@nospecialize(job::CompilerJob{PTXCompilerTarget}), annotations = Metadata[entry] ## kernel metadata - append!(annotations, [MDString("kernel"), - ConstantInt(Int32(1))]) + append!( + annotations, [ + MDString("kernel"), + ConstantInt(Int32(1)), + ] + ) ## expected CTA sizes if job.config.target.minthreads !== nothing for (dim, name) in enumerate([:x, :y, :z]) bound = dim <= length(job.config.target.minthreads) ? job.config.target.minthreads[dim] : 1 - append!(annotations, [MDString("reqntid$name"), - ConstantInt(Int32(bound))]) + append!( + annotations, [ + MDString("reqntid$name"), + ConstantInt(Int32(bound)), + ] + ) end end if job.config.target.maxthreads !== nothing for (dim, name) in enumerate([:x, :y, :z]) bound = dim <= length(job.config.target.maxthreads) ? job.config.target.maxthreads[dim] : 1 - append!(annotations, [MDString("maxntid$name"), - ConstantInt(Int32(bound))]) + append!( + annotations, [ + MDString("maxntid$name"), + ConstantInt(Int32(bound)), + ] + ) end end if job.config.target.blocks_per_sm !== nothing - append!(annotations, [MDString("minctasm"), - ConstantInt(Int32(job.config.target.blocks_per_sm))]) + append!( + annotations, [ + MDString("minctasm"), + ConstantInt(Int32(job.config.target.blocks_per_sm)), + ] + ) end if job.config.target.maxregs !== nothing - append!(annotations, [MDString("maxnreg"), - ConstantInt(Int32(job.config.target.maxregs))]) + append!( + annotations, [ + MDString("maxnreg"), + ConstantInt(Int32(job.config.target.maxregs)), + ] + ) end push!(metadata(mod)["nvvm.annotations"], MDNode(annotations)) @@ -243,7 +273,7 @@ end function llvm_debug_info(@nospecialize(job::CompilerJob{PTXCompilerTarget})) # allow overriding the debug info from CUDA.jl - if job.config.target.debuginfo + return if job.config.target.debuginfo invoke(llvm_debug_info, Tuple{CompilerJob}, job) else LLVM.API.LLVMDebugEmissionKindNoDebug @@ -364,7 +394,7 @@ function lower_unreachable!(f::LLVM.Function) end # rewrite the unreachable terminators - @dispose builder=IRBuilder() begin + @dispose builder = IRBuilder() begin entry_block = first(blocks(f)) for block in unreachable_blocks inst = terminator(block) @@ -391,98 +421,98 @@ function nvvm_reflect!(fun::LLVM.Function) changed = false @tracepoint "nvvmreflect" begin - # find and sanity check the nnvm-reflect function - # TODO: also handle the llvm.nvvm.reflect intrinsic - haskey(LLVM.functions(mod), NVVM_REFLECT_FUNCTION) || return false - reflect_function = functions(mod)[NVVM_REFLECT_FUNCTION] - isdeclaration(reflect_function) || error("_reflect function should not have a body") - reflect_typ = return_type(function_type(reflect_function)) - isa(reflect_typ, LLVM.IntegerType) || error("_reflect's return type should be integer") - - to_remove = [] - for use in uses(reflect_function) - call = user(use) - isa(call, LLVM.CallInst) || continue - if length(operands(call)) != 2 - @error """Unrecognized format of __nvvm_reflect call: - $(string(call)) - Wrong number of operands: expected 2, got $(length(operands(call))).""" - continue - end + # find and sanity check the nnvm-reflect function + # TODO: also handle the llvm.nvvm.reflect intrinsic + haskey(LLVM.functions(mod), NVVM_REFLECT_FUNCTION) || return false + reflect_function = functions(mod)[NVVM_REFLECT_FUNCTION] + isdeclaration(reflect_function) || error("_reflect function should not have a body") + reflect_typ = return_type(function_type(reflect_function)) + isa(reflect_typ, LLVM.IntegerType) || error("_reflect's return type should be integer") + + to_remove = [] + for use in uses(reflect_function) + call = user(use) + isa(call, LLVM.CallInst) || continue + if length(operands(call)) != 2 + @error """Unrecognized format of __nvvm_reflect call: + $(string(call)) + Wrong number of operands: expected 2, got $(length(operands(call))).""" + continue + end - # decode the string argument - if LLVM.version() >= v"17" - sym = operands(call)[1] - else - str = operands(call)[1] - if !isa(str, LLVM.ConstantExpr) || opcode(str) != LLVM.API.LLVMGetElementPtr + # decode the string argument + if LLVM.version() >= v"17" + sym = operands(call)[1] + else + str = operands(call)[1] + if !isa(str, LLVM.ConstantExpr) || opcode(str) != LLVM.API.LLVMGetElementPtr + @safe_error """Unrecognized format of __nvvm_reflect call: + $(string(call)) + Operand should be a GEP instruction, got a $(typeof(str)). Please file an issue.""" + continue + end + sym = operands(str)[1] + if isa(sym, LLVM.ConstantExpr) && opcode(sym) == LLVM.API.LLVMGetElementPtr + # CUDA 11.0 or below + sym = operands(sym)[1] + end + end + if !isa(sym, LLVM.GlobalVariable) @safe_error """Unrecognized format of __nvvm_reflect call: - $(string(call)) - Operand should be a GEP instruction, got a $(typeof(str)). Please file an issue.""" + $(string(call)) + Operand should be a global variable, got a $(typeof(sym)). Please file an issue.""" continue end - sym = operands(str)[1] - if isa(sym, LLVM.ConstantExpr) && opcode(sym) == LLVM.API.LLVMGetElementPtr - # CUDA 11.0 or below - sym = operands(sym)[1] + sym_op = operands(sym)[1] + if !isa(sym_op, LLVM.ConstantArray) && !isa(sym_op, LLVM.ConstantDataArray) + @safe_error """Unrecognized format of __nvvm_reflect call: + $(string(call)) + Operand should be a constant array, got a $(typeof(sym_op)). Please file an issue.""" + end + chars = convert.(Ref(UInt8), collect(sym_op)) + reflect_arg = String(chars[1:(end - 1)]) + + # handle possible cases + # XXX: put some of these property in the compiler job? + # and/or first set the "nvvm-reflect-*" module flag like Clang does? + fast_math = current_job.config.target.fastmath + # NOTE: we follow nvcc's --use_fast_math + reflect_val = if reflect_arg == "__CUDA_FTZ" + # single-precision denormals support + ConstantInt(reflect_typ, fast_math ? 1 : 0) + elseif reflect_arg == "__CUDA_PREC_DIV" + # single-precision floating-point division and reciprocals. + ConstantInt(reflect_typ, fast_math ? 0 : 1) + elseif reflect_arg == "__CUDA_PREC_SQRT" + # single-precision floating point square roots. + ConstantInt(reflect_typ, fast_math ? 0 : 1) + elseif reflect_arg == "__CUDA_FMAD" + # contraction of floating-point multiplies and adds/subtracts into + # floating-point multiply-add operations (FMAD, FFMA, or DFMA) + ConstantInt(reflect_typ, fast_math ? 1 : 0) + elseif reflect_arg == "__CUDA_ARCH" + ConstantInt(reflect_typ, job.config.target.cap.major * 100 + job.config.target.cap.minor * 10) + else + @safe_error """Unrecognized format of __nvvm_reflect call: + $(string(call)) + Unknown argument $reflect_arg. Please file an issue.""" + continue end - end - if !isa(sym, LLVM.GlobalVariable) - @safe_error """Unrecognized format of __nvvm_reflect call: - $(string(call)) - Operand should be a global variable, got a $(typeof(sym)). Please file an issue.""" - continue - end - sym_op = operands(sym)[1] - if !isa(sym_op, LLVM.ConstantArray) && !isa(sym_op, LLVM.ConstantDataArray) - @safe_error """Unrecognized format of __nvvm_reflect call: - $(string(call)) - Operand should be a constant array, got a $(typeof(sym_op)). Please file an issue.""" - end - chars = convert.(Ref(UInt8), collect(sym_op)) - reflect_arg = String(chars[1:end-1]) - - # handle possible cases - # XXX: put some of these property in the compiler job? - # and/or first set the "nvvm-reflect-*" module flag like Clang does? - fast_math = current_job.config.target.fastmath - # NOTE: we follow nvcc's --use_fast_math - reflect_val = if reflect_arg == "__CUDA_FTZ" - # single-precision denormals support - ConstantInt(reflect_typ, fast_math ? 1 : 0) - elseif reflect_arg == "__CUDA_PREC_DIV" - # single-precision floating-point division and reciprocals. - ConstantInt(reflect_typ, fast_math ? 0 : 1) - elseif reflect_arg == "__CUDA_PREC_SQRT" - # single-precision floating point square roots. - ConstantInt(reflect_typ, fast_math ? 0 : 1) - elseif reflect_arg == "__CUDA_FMAD" - # contraction of floating-point multiplies and adds/subtracts into - # floating-point multiply-add operations (FMAD, FFMA, or DFMA) - ConstantInt(reflect_typ, fast_math ? 1 : 0) - elseif reflect_arg == "__CUDA_ARCH" - ConstantInt(reflect_typ, job.config.target.cap.major*100 + job.config.target.cap.minor*10) - else - @safe_error """Unrecognized format of __nvvm_reflect call: - $(string(call)) - Unknown argument $reflect_arg. Please file an issue.""" - continue - end - replace_uses!(call, reflect_val) - push!(to_remove, call) - end + replace_uses!(call, reflect_val) + push!(to_remove, call) + end - # remove the calls to the function - for val in to_remove - @assert isempty(uses(val)) - erase!(val) - end + # remove the calls to the function + for val in to_remove + @assert isempty(uses(val)) + erase!(val) + end - # maybe also delete the function - if isempty(uses(reflect_function)) - erase!(reflect_function) - end + # maybe also delete the function + if isempty(uses(reflect_function)) + erase!(reflect_function) + end end return changed diff --git a/src/reflection.jl b/src/reflection.jl index 4cae6f01..10156be0 100644 --- a/src/reflection.jl +++ b/src/reflection.jl @@ -8,7 +8,7 @@ const Cthulhu = Base.PkgId(UUID("f68482b8-f384-11e8-15f7-abe071a5a75f"), "Cthulh # syntax highlighting # -const _pygmentize = Ref{Union{String,Nothing}}() +const _pygmentize = Ref{Union{String, Nothing}}() function pygmentize() if !isassigned(_pygmentize) _pygmentize[] = Sys.which("pygmentize") @@ -120,15 +120,15 @@ include("reflection_compat.jl") function code_lowered(@nospecialize(job::CompilerJob); kwargs...) sig = job.source.specTypes # XXX: can we just use the method instance? - code_lowered_by_type(sig; kwargs...) + return code_lowered_by_type(sig; kwargs...) end -function code_typed(@nospecialize(job::CompilerJob); interactive::Bool=false, kwargs...) +function code_typed(@nospecialize(job::CompilerJob); interactive::Bool = false, kwargs...) sig = job.source.specTypes # XXX: can we just use the method instance? - if interactive + return if interactive # call Cthulhu without introducing a dependency on Cthulhu mod = get(Base.loaded_modules, Cthulhu, nothing) - mod===nothing && error("Interactive code reflection requires Cthulhu; please install and load this package first.") + mod === nothing && error("Interactive code reflection requires Cthulhu; please install and load this package first.") interp = get_interpreter(job) descend_code_typed = getfield(mod, :descend_code_typed) descend_code_typed(sig; interp, kwargs...) @@ -138,13 +138,13 @@ function code_typed(@nospecialize(job::CompilerJob); interactive::Bool=false, kw end end -function code_warntype(io::IO, @nospecialize(job::CompilerJob); interactive::Bool=false, kwargs...) +function code_warntype(io::IO, @nospecialize(job::CompilerJob); interactive::Bool = false, kwargs...) sig = job.source.specTypes # XXX: can we just use the method instance? - if interactive + return if interactive @assert io == stdout # call Cthulhu without introducing a dependency on Cthulhu mod = get(Base.loaded_modules, Cthulhu, nothing) - mod===nothing && error("Interactive code reflection requires Cthulhu; please install and load this package first.") + mod === nothing && error("Interactive code reflection requires Cthulhu; please install and load this package first.") interp = get_interpreter(job) descend_code_warntype = getfield(mod, :descend_code_warntype) @@ -183,22 +183,26 @@ The following keyword arguments are supported: See also: [`@device_code_llvm`](@ref), `InteractiveUtils.code_llvm` """ -function code_llvm(io::IO, @nospecialize(job::CompilerJob); optimize::Bool=true, raw::Bool=false, - debuginfo::Symbol=:default, dump_module::Bool=false, kwargs...) +function code_llvm( + io::IO, @nospecialize(job::CompilerJob); optimize::Bool = true, raw::Bool = false, + debuginfo::Symbol = :default, dump_module::Bool = false, kwargs... + ) # NOTE: jl_dump_function_ir supports stripping metadata, so don't do it in the driver - config = CompilerConfig(job.config; validate=false, strip=false, optimize) + config = CompilerConfig(job.config; validate = false, strip = false, optimize) str = JuliaContext() do ctx ir, meta = compile(:llvm, CompilerJob(job; config)) ts_mod = ThreadSafeModule(ir) entry_fn = meta.entry GC.@preserve ts_mod entry_fn begin value = Ref(jl_llvmf_dump(ts_mod.ref, entry_fn.ref)) - ccall(:jl_dump_function_ir, Ref{String}, - (Ptr{jl_llvmf_dump}, Bool, Bool, Ptr{UInt8}), - value, !raw, dump_module, debuginfo) + ccall( + :jl_dump_function_ir, Ref{String}, + (Ptr{jl_llvmf_dump}, Bool, Bool, Ptr{UInt8}), + value, !raw, dump_module, debuginfo + ) end end - highlight(io, str, "llvm") + return highlight(io, str, "llvm") end code_llvm(@nospecialize(job::CompilerJob); kwargs...) = code_llvm(stdout, job; kwargs...) @@ -215,13 +219,15 @@ The following keyword arguments are supported: See also: [`@device_code_native`](@ref), `InteractiveUtils.code_llvm` """ -function code_native(io::IO, @nospecialize(job::CompilerJob); - raw::Bool=false, dump_module::Bool=false) - config = CompilerConfig(job.config; strip=!raw, only_entry=!dump_module, validate=false) +function code_native( + io::IO, @nospecialize(job::CompilerJob); + raw::Bool = false, dump_module::Bool = false + ) + config = CompilerConfig(job.config; strip = !raw, only_entry = !dump_module, validate = false) asm, meta = JuliaContext() do ctx compile(:asm, CompilerJob(job; config)) end - highlight(io, asm, source_code(job.config.target)) + return highlight(io, asm, source_code(job.config.target)) end code_native(@nospecialize(job::CompilerJob); kwargs...) = code_native(stdout, job; kwargs...) @@ -233,12 +239,12 @@ code_native(@nospecialize(job::CompilerJob); kwargs...) = function emit_hooked_compilation(inner_hook, ex...) user_code = ex[end] - user_kwargs = ex[1:end-1] - quote + user_kwargs = ex[1:(end - 1)] + return quote # we only want to invoke the hook once for every compilation job jobs = Set() function outer_hook(job) - if !in(job, jobs) + return if !in(job, jobs) # the user hook might invoke the compiler again, so disable the hook old_hook = $compile_hook[] try @@ -276,10 +282,10 @@ Evaluates the expression `ex` and returns the result of See also: `InteractiveUtils.@code_lowered` """ macro device_code_lowered(ex...) - quote + return quote buf = Any[] function hook(job::CompilerJob) - append!(buf, code_lowered(job)) + return append!(buf, code_lowered(job)) end $(emit_hooked_compilation(:hook, ex...)) buf @@ -295,10 +301,10 @@ Evaluates the expression `ex` and returns the result of See also: `InteractiveUtils.@code_typed` """ macro device_code_typed(ex...) - quote - output = Dict{CompilerJob,Any}() + return quote + output = Dict{CompilerJob, Any}() function hook(job::CompilerJob; kwargs...) - output[job] = code_typed(job; kwargs...) + return output[job] = code_typed(job; kwargs...) end $(emit_hooked_compilation(:hook, ex...)) output @@ -314,12 +320,12 @@ Evaluates the expression `ex` and prints the result of See also: `InteractiveUtils.@code_warntype` """ macro device_code_warntype(ex...) - function hook(job::CompilerJob; io::IO=stdout, kwargs...) + function hook(job::CompilerJob; io::IO = stdout, kwargs...) println(io, "$job") println(io) - code_warntype(io, job; kwargs...) + return code_warntype(io, job; kwargs...) end - emit_hooked_compilation(hook, ex...) + return emit_hooked_compilation(hook, ex...) end """ @@ -332,11 +338,11 @@ to `io` for every compiled GPU kernel. For other supported keywords, see See also: InteractiveUtils.@code_llvm """ macro device_code_llvm(ex...) - function hook(job::CompilerJob; io::IO=stdout, kwargs...) + function hook(job::CompilerJob; io::IO = stdout, kwargs...) println(io, "; $job") - code_llvm(io, job; kwargs...) + return code_llvm(io, job; kwargs...) end - emit_hooked_compilation(hook, ex...) + return emit_hooked_compilation(hook, ex...) end """ @@ -347,12 +353,12 @@ for every compiled GPU kernel. For other supported keywords, see [`GPUCompiler.code_native`](@ref). """ macro device_code_native(ex...) - function hook(job::CompilerJob; io::IO=stdout, kwargs...) + function hook(job::CompilerJob; io::IO = stdout, kwargs...) println(io, "// $job") println(io) - code_native(io, job; kwargs...) + return code_native(io, job; kwargs...) end - emit_hooked_compilation(hook, ex...) + return emit_hooked_compilation(hook, ex...) end """ @@ -374,23 +380,23 @@ macro device_code(ex...) end open(joinpath(dir, "$fn.typed.jl"), "w") do io - code = only(code_typed(job; debuginfo=:source)) + code = only(code_typed(job; debuginfo = :source)) println(io, code) end open(joinpath(dir, "$fn.unopt.ll"), "w") do io - code_llvm(io, job; dump_module=true, raw=true, optimize=false) + code_llvm(io, job; dump_module = true, raw = true, optimize = false) end open(joinpath(dir, "$fn.opt.ll"), "w") do io - code_llvm(io, job; dump_module=true, raw=true) + code_llvm(io, job; dump_module = true, raw = true) end open(joinpath(dir, "$fn.asm"), "w") do io - code_native(io, job; dump_module=true, raw=true) + code_native(io, job; dump_module = true, raw = true) end - localUnique += 1 + return localUnique += 1 end - emit_hooked_compilation(hook, ex...) + return emit_hooked_compilation(hook, ex...) end diff --git a/src/reflection_compat.jl b/src/reflection_compat.jl index 92467a3d..73a3bbdf 100644 --- a/src/reflection_compat.jl +++ b/src/reflection_compat.jl @@ -3,11 +3,11 @@ using InteractiveUtils: highlighting using Base: hasgenerator -function method_instances(@nospecialize(tt::Type), world::UInt=Base.get_world_counter()) +function method_instances(@nospecialize(tt::Type), world::UInt = Base.get_world_counter()) return map(Core.Compiler.specialize_method, method_matches(tt; world)) end -function code_lowered_by_type(@nospecialize(tt); generated::Bool=true, debuginfo::Symbol=:default) +function code_lowered_by_type(@nospecialize(tt); generated::Bool = true, debuginfo::Symbol = :default) debuginfo = Base.IRShow.debuginfo(debuginfo) if debuginfo !== :source && debuginfo !== :none @@ -18,9 +18,11 @@ function code_lowered_by_type(@nospecialize(tt); generated::Bool=true, debuginfo if Base.may_invoke_generator(m) return ccall(:jl_code_for_staged, Any, (Any,), m)::CodeInfo else - error("Could not expand generator for `@generated` method ", m, ". ", - "This can happen if the provided argument types (", t, ") are ", - "not leaf types, but the `generated` argument is `true`.") + error( + "Could not expand generator for `@generated` method ", m, ". ", + "This can happen if the provided argument types (", t, ") are ", + "not leaf types, but the `generated` argument is `true`." + ) end end code = Base.uncompressed_ir(m.def::Method) @@ -29,8 +31,10 @@ function code_lowered_by_type(@nospecialize(tt); generated::Bool=true, debuginfo end end -function code_warntype_by_type(io::IO, @nospecialize(tt); - debuginfo::Symbol=:default, optimize::Bool=false, kwargs...) +function code_warntype_by_type( + io::IO, @nospecialize(tt); + debuginfo::Symbol = :default, optimize::Bool = false, kwargs... + ) debuginfo = Base.IRShow.debuginfo(debuginfo) lineprinter = Base.IRShow.__debuginfo[debuginfo] for (src, rettype) in Base.code_typed_by_type(tt; optimize, kwargs...) @@ -51,16 +55,16 @@ function code_warntype_by_type(io::IO, @nospecialize(tt); println(io, "Static Parameters") sig = p.def.sig warn_color = Base.warn_color() # more mild user notification - for i = 1:length(p.sparam_vals) + for i in 1:length(p.sparam_vals) sig = sig::UnionAll name = sig.var.name val = p.sparam_vals[i] print_highlighted(io::IO, v::String, color::Symbol) = - if highlighting[:warntype] - Base.printstyled(io, v; color) - else - Base.print(io, v) - end + if highlighting[:warntype] + Base.printstyled(io, v; color) + else + Base.print(io, v) + end if val isa TypeVar if val.lb === Union{} print(io, " ", name, " <: ") @@ -91,24 +95,24 @@ function code_warntype_by_type(io::IO, @nospecialize(tt); lambda_io = IOContext(lambda_io, :SOURCE_SLOTNAMES => slotnames) slottypes = src.slottypes nargs > 0 && println(io, "Arguments") - for i = 1:length(slotnames) + for i in 1:length(slotnames) if i == nargs + 1 println(io, "Locals") end print(io, " ", slotnames[i]) if isa(slottypes, Vector{Any}) - InteractiveUtils.warntype_type_printer(io; type=slottypes[i], used=true) + InteractiveUtils.warntype_type_printer(io; type = slottypes[i], used = true) end println(io) end end print(io, "Body") - InteractiveUtils.warntype_type_printer(io; type=rettype, used=true) + InteractiveUtils.warntype_type_printer(io; type = rettype, used = true) println(io) irshow_config = Base.IRShow.IRShowConfig(lineprinter(src), InteractiveUtils.warntype_type_printer) Base.IRShow.show_ir(lambda_io, src, irshow_config) println(io) end - nothing + return nothing end diff --git a/src/rtlib.jl b/src/rtlib.jl index 91b4c71c..f004a778 100644 --- a/src/rtlib.jl +++ b/src/rtlib.jl @@ -8,6 +8,7 @@ function link_library!(mod::LLVM.Module, libs::Vector{LLVM.Module}) for lib in libs link!(mod, lib) end + return end @@ -18,7 +19,7 @@ end ## higher-level functionality to work with runtime functions -function LLVM.call!(builder, rt::Runtime.RuntimeMethodInstance, args=LLVM.Value[]) +function LLVM.call!(builder, rt::Runtime.RuntimeMethodInstance, args = LLVM.Value[]) bb = position(builder) f = LLVM.parent(bb) mod = LLVM.parent(f) @@ -36,26 +37,30 @@ function LLVM.call!(builder, rt::Runtime.RuntimeMethodInstance, args=LLVM.Value[ # is linked, as part of the lower_gc_frame! optimization pass. # XXX: report_exception can also be used after the runtime is linked during # CUDA/Enzyme nested compilation - error("Calling an intrinsic function that clashes with an existing definition: ", - string(ft), " ", rt.name) + error( + "Calling an intrinsic function that clashes with an existing definition: ", + string(ft), " ", rt.name + ) end # runtime functions are written in Julia, while we're calling from LLVM, # this often results in argument type mismatches. try to fix some here. args = LLVM.Value[args...] if length(args) != length(parameters(ft)) - error("Incorrect number of arguments for runtime function: ", - "passing ", length(args), " argument(s) to '", string(ft), " ", rt.name, "'") + error( + "Incorrect number of arguments for runtime function: ", + "passing ", length(args), " argument(s) to '", string(ft), " ", rt.name, "'" + ) end - for (i,arg) in enumerate(args) + for (i, arg) in enumerate(args) if value_type(arg) != parameters(ft)[i] args[i] = if (value_type(arg) isa LLVM.PointerType) && - (parameters(ft)[i] isa LLVM.IntegerType) + (parameters(ft)[i] isa LLVM.IntegerType) # pointers are passed as integers on Julia 1.11 and earlier ptrtoint!(builder, args[i], parameters(ft)[i]) elseif value_type(arg) isa LLVM.PointerType && - parameters(ft)[i] isa LLVM.PointerType && - addrspace(value_type(arg)) != addrspace(parameters(ft)[i]) + parameters(ft)[i] isa LLVM.PointerType && + addrspace(value_type(arg)) != addrspace(parameters(ft)[i]) # runtime functions are always in the default address space, # while arguments may come from globals in other address spaces. addrspacecast!(builder, args[i], parameters(ft)[i]) @@ -65,7 +70,7 @@ function LLVM.call!(builder, rt::Runtime.RuntimeMethodInstance, args=LLVM.Value[ end end - call!(builder, ft, f, args) + return call!(builder, ft, f, args) end @@ -97,7 +102,7 @@ function emit_function!(mod, config::CompilerConfig, f, method) replace_uses!(decl, entry) erase!(decl) end - LLVM.name!(entry, name) + return LLVM.name!(entry, name) end function build_runtime(@nospecialize(job::CompilerJob)) @@ -105,7 +110,7 @@ function build_runtime(@nospecialize(job::CompilerJob)) # the compiler job passed into here is identifies the job that requires the runtime. # derive a job that represents the runtime itself (notably with kernel=false). - config = CompilerConfig(job.config; kernel=false, toplevel=false, only_entry=false, strip=false) + config = CompilerConfig(job.config; kernel = false, toplevel = false, only_entry = false, strip = false) for method in values(Runtime.methods) def = if isa(method.def, Symbol) @@ -122,7 +127,7 @@ function build_runtime(@nospecialize(job::CompilerJob)) # removes Julia address spaces, which would then lead to type mismatches when using # functions from the runtime library from IR that has not been stripped of AS info. - mod + return mod end const runtime_lock = ReentrantLock() @@ -159,13 +164,13 @@ const runtime_cache = Dict{String, Vector{UInt8}}() lib = build_runtime(job) # atomic write to disk - temp_path, io = mktemp(dirname(path); cleanup=false) + temp_path, io = mktemp(dirname(path); cleanup = false) write(io, lib) close(io) @static if VERSION >= v"1.12.0-DEV.1023" - mv(temp_path, path; force=true) + mv(temp_path, path; force = true) else - Base.rename(temp_path, path, force=true) + Base.rename(temp_path, path, force = true) end end @@ -177,7 +182,7 @@ end # NOTE: call this function from global scope, so any change triggers recompilation. function reset_runtime() lock(runtime_lock) do - rm(compile_cache; recursive=true, force=true) + rm(compile_cache; recursive = true, force = true) end return diff --git a/src/runtime.jl b/src/runtime.jl index 2b11d915..06a3ff54 100644 --- a/src/runtime.jl +++ b/src/runtime.jl @@ -18,7 +18,7 @@ using LLVM.Interop struct RuntimeMethodInstance # either a function defined here, or a symbol to fetch a target-specific definition - def::Union{Function,Symbol} + def::Union{Function, Symbol} return_type::Type types::Tuple @@ -35,23 +35,23 @@ end function Base.convert(::Type{LLVM.FunctionType}, rt::RuntimeMethodInstance) types = if rt.llvm_types === nothing - LLVMType[convert(LLVMType, typ; allow_boxed=true) for typ in rt.types] + LLVMType[convert(LLVMType, typ; allow_boxed = true) for typ in rt.types] else rt.llvm_types() end return_type = if rt.llvm_return_type === nothing - convert(LLVMType, rt.return_type; allow_boxed=true) + convert(LLVMType, rt.return_type; allow_boxed = true) else rt.llvm_return_type() end - LLVM.FunctionType(return_type, types) + return LLVM.FunctionType(return_type, types) end -const methods = Dict{Symbol,RuntimeMethodInstance}() +const methods = Dict{Symbol, RuntimeMethodInstance}() function get(name::Symbol) - methods[name] + return methods[name] end # Register a Julia function `def` as a runtime library function identified by `name`. The @@ -66,11 +66,15 @@ end # different values for `name`. The LLVM function name will be deduced from that name, but # you can always specify `llvm_name` to influence that. Never use an LLVM name that starts # with `julia_` or the function might clash with other compiled functions. -function compile(def, return_type, types, llvm_return_type=nothing, llvm_types=nothing; - name=isa(def,Symbol) ? def : nameof(def), llvm_name="gpu_$name") - meth = RuntimeMethodInstance(def, - return_type, types, name, - llvm_return_type, llvm_types, llvm_name) +function compile( + def, return_type, types, llvm_return_type = nothing, llvm_types = nothing; + name = isa(def, Symbol) ? def : nameof(def), llvm_name = "gpu_$name" + ) + meth = RuntimeMethodInstance( + def, + return_type, types, name, + llvm_return_type, llvm_types, llvm_name + ) if haskey(methods, name) error("Runtime function $name has already been registered!") end @@ -110,7 +114,7 @@ compile(:report_exception_name, Nothing, (Ptr{Cchar},)) ## GC # FIXME: get rid of this and allow boxed types -T_prjlvalue() = convert(LLVMType, Any; allow_boxed=true) +T_prjlvalue() = convert(LLVMType, Any; allow_boxed = true) function gc_pool_alloc(sz::Csize_t) ptr = malloc(sz) @@ -132,26 +136,28 @@ compile(:malloc, Ptr{Nothing}, (Csize_t,)) const tag_type = UInt const tag_size = sizeof(tag_type) -const gc_bits = 0x3 # FIXME +const gc_bits = 0x03 # FIXME # get the type tag of a type at run-time -@generated function type_tag(::Val{type_name}) where type_name - @dispose ctx=Context() begin +@generated function type_tag(::Val{type_name}) where {type_name} + return @dispose ctx = Context() begin T_tag = convert(LLVMType, tag_type) T_ptag = LLVM.PointerType(T_tag) - T_pjlvalue = convert(LLVMType, Any; allow_boxed=true) + T_pjlvalue = convert(LLVMType, Any; allow_boxed = true) # create function llvm_f, _ = create_function(T_tag) mod = LLVM.parent(llvm_f) # this isn't really a function, but we abuse it to get the JIT to resolve the address - typ = LLVM.Function(mod, "jl_" * String(type_name) * "_type", - LLVM.FunctionType(T_pjlvalue)) + typ = LLVM.Function( + mod, "jl_" * String(type_name) * "_type", + LLVM.FunctionType(T_pjlvalue) + ) # generate IR - @dispose builder=IRBuilder() begin + @dispose builder = IRBuilder() begin entry = BasicBlock(llvm_f, "entry") position!(builder, entry) @@ -168,15 +174,15 @@ end # we use `jl_value_ptr`, a Julia pseudo-intrinsic that can be used to box and unbox values -@inline @generated function box(val, ::Val{type_name}) where type_name +@inline @generated function box(val, ::Val{type_name}) where {type_name} sz = sizeof(val) allocsz = sz + tag_size # type-tags are ephemeral, so look them up at run time #tag = unsafe_load(convert(Ptr{tag_type}, type_name)) - tag = :( type_tag(Val(type_name)) ) + tag = :(type_tag(Val(type_name))) - quote + return quote ptr = malloc($(Csize_t(allocsz))) # store the type tag @@ -184,33 +190,35 @@ end Core.Intrinsics.pointerset(ptr, $tag | $gc_bits, #=index=# 1, #=align=# $tag_size) # store the value - ptr = convert(Ptr{$val}, ptr+tag_size) + ptr = convert(Ptr{$val}, ptr + tag_size) Core.Intrinsics.pointerset(ptr, val, #=index=# 1, #=align=# $sz) unsafe_pointer_to_objref(ptr) end end -@inline function unbox(obj, ::Type{T}) where T +@inline function unbox(obj, ::Type{T}) where {T} ptr = ccall(:jl_value_ptr, Ptr{Cvoid}, (Any,), obj) # load the value ptr = convert(Ptr{T}, ptr) - Core.Intrinsics.pointerref(ptr, #=index=# 1, #=align=# sizeof(T)) + return Core.Intrinsics.pointerref(ptr, #=index=# 1, #=align=# sizeof(T)) end # generate functions functions that exist in the Julia runtime (see julia/src/datatype.c) -for (T, t) in [Int8 => :int8, Int16 => :int16, Int32 => :int32, Int64 => :int64, - UInt8 => :uint8, UInt16 => :uint16, UInt32 => :uint32, UInt64 => :uint64, - Bool => :bool, Float32 => :float32, Float64 => :float64] - box_fn = Symbol("box_$t") +for (T, t) in [ + Int8 => :int8, Int16 => :int16, Int32 => :int32, Int64 => :int64, + UInt8 => :uint8, UInt16 => :uint16, UInt32 => :uint32, UInt64 => :uint64, + Bool => :bool, Float32 => :float32, Float64 => :float64, + ] + box_fn = Symbol("box_$t") unbox_fn = Symbol("unbox_$t") @eval begin - $box_fn(val) = box($T(val), Val($(QuoteNode(t)))) + $box_fn(val) = box($T(val), Val($(QuoteNode(t)))) $unbox_fn(obj) = unbox(obj, $T) - compile($box_fn, Any, ($T,), T_prjlvalue; llvm_name=$"ijl_$box_fn") - compile($unbox_fn, $T, (Any,); llvm_name=$"ijl_$unbox_fn") + compile($box_fn, Any, ($T,), T_prjlvalue; llvm_name = $"ijl_$box_fn") + compile($unbox_fn, $T, (Any,); llvm_name = $"ijl_$unbox_fn") end end diff --git a/src/spirv.jl b/src/spirv.jl index ea552e70..8209dbea 100644 --- a/src/spirv.jl +++ b/src/spirv.jl @@ -5,17 +5,25 @@ # https://github.com/KhronosGroup/SPIRV-LLVM-Translator/blob/master/docs/SPIRVRepresentationInLLVM.rst const SPIRV_LLVM_Backend_jll = - LazyModule("SPIRV_LLVM_Backend_jll", - UUID("4376b9bf-cff8-51b6-bb48-39421dff0d0c")) + LazyModule( + "SPIRV_LLVM_Backend_jll", + UUID("4376b9bf-cff8-51b6-bb48-39421dff0d0c") +) const SPIRV_LLVM_Translator_unified_jll = - LazyModule("SPIRV_LLVM_Translator_unified_jll", - UUID("85f0d8ed-5b39-5caa-b1ae-7472de402361")) + LazyModule( + "SPIRV_LLVM_Translator_unified_jll", + UUID("85f0d8ed-5b39-5caa-b1ae-7472de402361") +) const SPIRV_LLVM_Translator_jll = - LazyModule("SPIRV_LLVM_Translator_jll", - UUID("4a5d46fc-d8cf-5151-a261-86b458210efb")) + LazyModule( + "SPIRV_LLVM_Translator_jll", + UUID("4a5d46fc-d8cf-5151-a261-86b458210efb") +) const SPIRV_Tools_jll = - LazyModule("SPIRV_Tools_jll", - UUID("6ac6d60f-d740-5983-97d7-a4482c0689f4")) + LazyModule( + "SPIRV_Tools_jll", + UUID("6ac6d60f-d740-5983-97d7-a4482c0689f4") +) ## target @@ -23,7 +31,7 @@ const SPIRV_Tools_jll = export SPIRVCompilerTarget Base.@kwdef struct SPIRVCompilerTarget <: AbstractCompilerTarget - version::Union{Nothing,VersionNumber} = nothing + version::Union{Nothing, VersionNumber} = nothing extensions::Vector{String} = [] supports_fp16::Bool = true supports_fp64::Bool = true @@ -36,21 +44,21 @@ end function llvm_triple(target::SPIRVCompilerTarget) if target.backend == :llvm - architecture = Int===Int64 ? "spirv64" : "spirv32" # could also be "spirv" for logical addressing + architecture = Int === Int64 ? "spirv64" : "spirv32" # could also be "spirv" for logical addressing subarchitecture = target.version === nothing ? "" : "v$(target.version.major).$(target.version.minor)" vendor = "unknown" # could also be AMD os = "unknown" environment = "unknown" return "$architecture$subarchitecture-$vendor-$os-$environment" elseif target.backend == :khronos - return Int===Int64 ? "spir64-unknown-unknown" : "spirv-unknown-unknown" + return Int === Int64 ? "spir64-unknown-unknown" : "spirv-unknown-unknown" end end # SPIRV is not supported by our LLVM builds, so we can't get a target machine llvm_machine(::SPIRVCompilerTarget) = nothing -llvm_datalayout(::SPIRVCompilerTarget) = Int===Int64 ? +llvm_datalayout(::SPIRVCompilerTarget) = Int === Int64 ? "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-G1" : "e-p:32:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-G1" @@ -62,8 +70,10 @@ llvm_datalayout(::SPIRVCompilerTarget) = Int===Int64 ? runtime_slug(job::CompilerJob{SPIRVCompilerTarget}) = "spirv-" * String(job.config.target.backend) -function finish_module!(job::CompilerJob{SPIRVCompilerTarget}, mod::LLVM.Module, - entry::LLVM.Function) +function finish_module!( + job::CompilerJob{SPIRVCompilerTarget}, mod::LLVM.Module, + entry::LLVM.Function + ) # update calling convention for f in functions(mod) # JuliaGPU/GPUCompiler.jl#97 @@ -90,8 +100,10 @@ function validate_ir(job::CompilerJob{SPIRVCompilerTarget}, mod::LLVM.Module) return errors end -function finish_ir!(job::CompilerJob{SPIRVCompilerTarget}, mod::LLVM.Module, - entry::LLVM.Function) +function finish_ir!( + job::CompilerJob{SPIRVCompilerTarget}, mod::LLVM.Module, + entry::LLVM.Function + ) # convert the kernel state argument to a byval reference if job.config.kernel state = kernel_state_type(job) @@ -115,19 +127,33 @@ function finish_ir!(job::CompilerJob{SPIRVCompilerTarget}, mod::LLVM.Module, # add module metadata ## OpenCL 2.0 - push!(metadata(mod)["opencl.ocl.version"], - MDNode([ConstantInt(Int32(2)), - ConstantInt(Int32(0))])) + push!( + metadata(mod)["opencl.ocl.version"], + MDNode( + [ + ConstantInt(Int32(2)), + ConstantInt(Int32(0)), + ] + ) + ) ## SPIR-V 1.5 - push!(metadata(mod)["opencl.spirv.version"], - MDNode([ConstantInt(Int32(1)), - ConstantInt(Int32(5))])) + push!( + metadata(mod)["opencl.spirv.version"], + MDNode( + [ + ConstantInt(Int32(1)), + ConstantInt(Int32(5)), + ] + ) + ) return entry end -@unlocked function mcgen(job::CompilerJob{SPIRVCompilerTarget}, mod::LLVM.Module, - format=LLVM.API.LLVMAssemblyFile) +@unlocked function mcgen( + job::CompilerJob{SPIRVCompilerTarget}, mod::LLVM.Module, + format = LLVM.API.LLVMAssemblyFile + ) # The SPIRV Tools don't handle Julia's debug info, rejecting DW_LANG_Julia... strip_debuginfo!(mod) @@ -140,14 +166,14 @@ end rm_freeze!(mod) # translate to SPIR-V - input = tempname(cleanup=false) * ".bc" - translated = tempname(cleanup=false) * ".spv" + input = tempname(cleanup = false) * ".bc" + translated = tempname(cleanup = false) * ".spv" write(input, mod) if job.config.target.backend === :llvm cmd = `$(SPIRV_LLVM_Backend_jll.llc()) $input -filetype=obj -o $translated` if !isempty(job.config.target.extensions) - str = join(map(ext->"+$ext", job.config.target.extensions), ",") + str = join(map(ext -> "+$ext", job.config.target.extensions), ",") cmd = `$(cmd) -spirv-ext=$str` end elseif job.config.target.backend === :khronos @@ -161,7 +187,7 @@ end cmd = `$translator -o $translated $input --spirv-debug-info-version=ocl-100` if !isempty(job.config.target.extensions) - str = join(map(ext->"+$ext", job.config.target.extensions), ",") + str = join(map(ext -> "+$ext", job.config.target.extensions), ",") cmd = `$(cmd) --spirv-ext=$str` end @@ -172,8 +198,10 @@ end try run(cmd) catch e - error("""Failed to translate LLVM code to SPIR-V. - If you think this is a bug, please file an issue and attach $(input).""") + error( + """Failed to translate LLVM code to SPIR-V. + If you think this is a bug, please file an issue and attach $(input).""" + ) end # validate @@ -181,20 +209,26 @@ end try run(`$(SPIRV_Tools_jll.spirv_val()) $translated`) catch e - error("""Failed to validate generated SPIR-V. - If you think this is a bug, please file an issue and attach $(input) and $(translated).""") + error( + """Failed to validate generated SPIR-V. + If you think this is a bug, please file an issue and attach $(input) and $(translated).""" + ) end end # optimize - optimized = tempname(cleanup=false) * ".spv" + optimized = tempname(cleanup = false) * ".spv" if job.config.target.optimize try - run(```$(SPIRV_Tools_jll.spirv_opt()) -O --skip-validation - $translated -o $optimized```) + run( + ```$(SPIRV_Tools_jll.spirv_opt()) -O --skip-validation + $translated -o $optimized``` + ) catch - error("""Failed to optimize generated SPIR-V. - If you think this is a bug, please file an issue and attach $(input) and $(translated).""") + error( + """Failed to optimize generated SPIR-V. + If you think this is a bug, please file an issue and attach $(input) and $(translated).""" + ) end else cp(translated, optimized) @@ -215,12 +249,12 @@ end end # reimplementation that uses `spirv-dis`, giving much more pleasant output -function code_native(io::IO, job::CompilerJob{SPIRVCompilerTarget}; raw::Bool=false, dump_module::Bool=false) - config = CompilerConfig(job.config; strip=!raw, only_entry=!dump_module, validate=false) +function code_native(io::IO, job::CompilerJob{SPIRVCompilerTarget}; raw::Bool = false, dump_module::Bool = false) + config = CompilerConfig(job.config; strip = !raw, only_entry = !dump_module, validate = false) obj, _ = JuliaContext() do ctx compile(:obj, CompilerJob(job; config)) end - mktemp() do input_path, input_io + return mktemp() do input_path, input_io write(input_io, obj) flush(input_io) @@ -246,20 +280,20 @@ function rm_trap!(mod::LLVM.Module) changed = false @tracepoint "remove trap" begin - if haskey(functions(mod), "llvm.trap") - trap = functions(mod)["llvm.trap"] + if haskey(functions(mod), "llvm.trap") + trap = functions(mod)["llvm.trap"] - for use in uses(trap) - val = user(use) - if isa(val, LLVM.CallInst) - erase!(val) - changed = true + for use in uses(trap) + val = user(use) + if isa(val, LLVM.CallInst) + erase!(val) + changed = true + end end - end - @compiler_assert isempty(uses(trap)) job - erase!(trap) - end + @compiler_assert isempty(uses(trap)) job + erase!(trap) + end end return changed @@ -272,15 +306,15 @@ function rm_freeze!(mod::LLVM.Module) changed = false @tracepoint "remove freeze" begin - for f in functions(mod), bb in blocks(f), inst in instructions(bb) - if inst isa LLVM.FreezeInst - orig = first(operands(inst)) - replace_uses!(inst, orig) - @compiler_assert isempty(uses(inst)) job - erase!(inst) - changed = true + for f in functions(mod), bb in blocks(f), inst in instructions(bb) + if inst isa LLVM.FreezeInst + orig = first(operands(inst)) + replace_uses!(inst, orig) + @compiler_assert isempty(uses(inst)) job + erase!(inst) + changed = true + end end - end end return changed @@ -293,50 +327,50 @@ function convert_i128_allocas!(mod::LLVM.Module) changed = false @tracepoint "convert i128 allocas" begin - for f in functions(mod), bb in blocks(f) - for inst in instructions(bb) - if inst isa LLVM.AllocaInst - alloca_type = LLVMType(LLVM.API.LLVMGetAllocatedType(inst)) - - # Check if this is an i128 or an array of i128 - if alloca_type isa LLVM.ArrayType - T = eltype(alloca_type) - else - T = alloca_type - end - if T isa LLVM.IntegerType && width(T) == 128 - # replace i128 with <2 x i64> - vec_type = LLVM.VectorType(LLVM.Int64Type(), 2) + for f in functions(mod), bb in blocks(f) + for inst in instructions(bb) + if inst isa LLVM.AllocaInst + alloca_type = LLVMType(LLVM.API.LLVMGetAllocatedType(inst)) + # Check if this is an i128 or an array of i128 if alloca_type isa LLVM.ArrayType - array_size = length(alloca_type) - new_alloca_type = LLVM.ArrayType(vec_type, array_size) + T = eltype(alloca_type) else - new_alloca_type = vec_type + T = alloca_type end - align_val = alignment(inst) - - # Create new alloca with vector type - @dispose builder=IRBuilder() begin - position!(builder, inst) - new_alloca = alloca!(builder, new_alloca_type) - alignment!(new_alloca, align_val) - - # Bitcast the new alloca back to the original pointer type - # XXX: The issue only seems to manifest itself on LLVM >= 18 - # where we use opaque pointers anyways, so not sure this - # is needed - old_ptr_type = LLVMType(LLVM.API.LLVMTypeOf(inst.ref)) - bitcast_ptr = bitcast!(builder, new_alloca, old_ptr_type) - - replace_uses!(inst, bitcast_ptr) - erase!(inst) - changed = true + if T isa LLVM.IntegerType && width(T) == 128 + # replace i128 with <2 x i64> + vec_type = LLVM.VectorType(LLVM.Int64Type(), 2) + + if alloca_type isa LLVM.ArrayType + array_size = length(alloca_type) + new_alloca_type = LLVM.ArrayType(vec_type, array_size) + else + new_alloca_type = vec_type + end + align_val = alignment(inst) + + # Create new alloca with vector type + @dispose builder = IRBuilder() begin + position!(builder, inst) + new_alloca = alloca!(builder, new_alloca_type) + alignment!(new_alloca, align_val) + + # Bitcast the new alloca back to the original pointer type + # XXX: The issue only seems to manifest itself on LLVM >= 18 + # where we use opaque pointers anyways, so not sure this + # is needed + old_ptr_type = LLVMType(LLVM.API.LLVMTypeOf(inst.ref)) + bitcast_ptr = bitcast!(builder, new_alloca, old_ptr_type) + + replace_uses!(inst, bitcast_ptr) + erase!(inst) + changed = true + end end end end end - end end return changed @@ -380,7 +414,7 @@ function wrap_byval(@nospecialize(job::CompilerJob), mod::LLVM.Module, f::LLVM.F # emit IR performing the "conversions" new_args = Vector{LLVM.Value}() - @dispose builder=IRBuilder() begin + @dispose builder = IRBuilder() begin entry = BasicBlock(new_f, "conversion") position!(builder, entry) @@ -397,12 +431,14 @@ function wrap_byval(@nospecialize(job::CompilerJob), mod::LLVM.Module, f::LLVM.F # map the arguments value_map = Dict{LLVM.Value, LLVM.Value}( - param => new_args[i] for (i,param) in enumerate(parameters(f)) + param => new_args[i] for (i, param) in enumerate(parameters(f)) ) value_map[f] = new_f - clone_into!(new_f, f; value_map, - changes=LLVM.API.LLVMCloneFunctionChangeTypeGlobalChanges) + clone_into!( + new_f, f; value_map, + changes = LLVM.API.LLVMCloneFunctionChangeTypeGlobalChanges + ) # apply byval attributes again (`clone_into!` didn't due to the type mismatch) for i in 1:length(byval) diff --git a/src/utils.jl b/src/utils.jl index 095f22dc..bbdb9ee6 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,14 +1,16 @@ -defs(mod::LLVM.Module) = filter(f -> !isdeclaration(f), collect(functions(mod))) -decls(mod::LLVM.Module) = filter(f -> isdeclaration(f) && !LLVM.isintrinsic(f), - collect(functions(mod))) +defs(mod::LLVM.Module) = filter(f -> !isdeclaration(f), collect(functions(mod))) +decls(mod::LLVM.Module) = filter( + f -> isdeclaration(f) && !LLVM.isintrinsic(f), + collect(functions(mod)) +) ## debug verification should_verify() = ccall(:jl_is_debugbuild, Cint, ()) == 1 || - Base.JLOptions().debug_level >= 2 || - something(tryparse(Bool, get(ENV, "CI", "false")), true) + Base.JLOptions().debug_level >= 2 || + something(tryparse(Bool, get(ENV, "CI", "false")), true) -isdebug(group, mod=GPUCompiler) = +isdebug(group, mod = GPUCompiler) = Base.CoreLogging.current_logger_for_env(Base.CoreLogging.Debug, group, mod) !== nothing @@ -29,7 +31,7 @@ function Base.getproperty(lazy_mod::LazyModule, sym::Symbol) if mod === nothing error("This functionality requires the $(pkg.name) package, which should be installed and loaded first.") end - getfield(mod, sym) + return getfield(mod, sym) end @@ -49,11 +51,11 @@ end for level in [:debug, :info, :warn, :error] @eval begin macro $(Symbol("safe_$level"))(ex...) - macrocall = :(@placeholder $(ex...) _file=$(String(__source__.file)) _line=$(__source__.line)) + macrocall = :(@placeholder $(ex...) _file = $(String(__source__.file)) _line = $(__source__.line)) # NOTE: `@placeholder` in order to avoid hard-coding @__LINE__ etc macrocall.args[1] = Symbol($"@$level") - quote - io = IOContext(Core.stderr, :color=>STDERR_HAS_COLOR[]) + return quote + io = IOContext(Core.stderr, :color => STDERR_HAS_COLOR[]) # ideally we call Logging.shouldlog() here, but that is likely to yield, # so instead we rely on the min_enabled_level of the logger. # in the case of custom loggers that may be an issue, because, @@ -86,16 +88,25 @@ end macro safe_show(exs...) blk = Expr(:block) for ex in exs - push!(blk.args, - :(println(Core.stdout, $(sprint(Base.show_unquoted,ex)*" = "), - repr(begin local value = $(esc(ex)) end)))) + push!( + blk.args, + :( + println( + Core.stdout, $(sprint(Base.show_unquoted, ex) * " = "), + repr( + begin + local value = $(esc(ex)) + end + ) + ) + ) + ) end isempty(exs) || push!(blk.args, :value) return blk end - ## codegen locking # lock codegen to prevent races on the LLVM context. @@ -118,7 +129,7 @@ macro locked(ex) ccall(:jl_typeinf_lock_end, Cvoid, ()) end end - esc(combinedef(def)) + return esc(combinedef(def)) end # HACK: temporarily unlock again to perform a task switch @@ -136,7 +147,7 @@ macro unlocked(ex) ccall(:jl_typeinf_lock_begin, Cvoid, ()) end end - esc(combinedef(def)) + return esc(combinedef(def)) end @@ -154,6 +165,7 @@ function prune_constexpr_uses!(root::LLVM.Value) isempty(uses(val)) && LLVM.unsafe_destroy!(val) end end + return end diff --git a/src/validation.jl b/src/validation.jl index 0190d1c9..2348e28c 100644 --- a/src/validation.jl +++ b/src/validation.jl @@ -14,7 +14,7 @@ function method_matches(@nospecialize(tt::Type{<:Tuple}); world::Integer) end function typeinf_type(mi::MethodInstance; interp::CC.AbstractInterpreter) - @static if VERSION < v"1.11.0" + return @static if VERSION < v"1.11.0" code = Core.Compiler.get(Core.Compiler.code_cache(interp), mi, nothing) if code isa Core.Compiler.CodeInstance return code.rettype @@ -40,11 +40,15 @@ function check_method(@nospecialize(job::CompilerJob)) # kernels can't return values if job.config.kernel - rt = typeinf_type(job.source; interp=get_interpreter(job)) + rt = typeinf_type(job.source; interp = get_interpreter(job)) if rt != Nothing && rt != Union{} - throw(KernelError(job, "kernel returns a value of type `$rt`", - """Make sure your kernel function ends in `return`, `return nothing` or `nothing`.""")) + throw( + KernelError( + job, "kernel returns a value of type `$rt`", + """Make sure your kernel function ends in `return`, `return nothing` or `nothing`.""" + ) + ) end end @@ -62,15 +66,15 @@ function hasfieldcount(@nospecialize(dt)) return true end -function explain_nonisbits(@nospecialize(dt), depth=1; maxdepth=10) - dt===Module && return "" # work around JuliaLang/julia#33347 +function explain_nonisbits(@nospecialize(dt), depth = 1; maxdepth = 10) + dt === Module && return "" # work around JuliaLang/julia#33347 depth > maxdepth && return "" hasfieldcount(dt) || return "" msg = "" for (ft, fn) in zip(fieldtypes(dt), fieldnames(dt)) if !isbitstype(ft) msg *= " "^depth * ".$fn is of type $ft which is not isbits.\n" - msg *= explain_nonisbits(ft, depth+1) + msg *= explain_nonisbits(ft, depth + 1) end end return msg @@ -86,16 +90,20 @@ function check_invocation(@nospecialize(job::CompilerJob)) # make sure any non-isbits arguments are unused real_arg_i = 0 - for (arg_i,dt) in enumerate(sig.parameters) + for (arg_i, dt) in enumerate(sig.parameters) isghosttype(dt) && continue Core.Compiler.isconstType(dt) && continue real_arg_i += 1 # XXX: can we support these for CPU targets? if dt <: Core.OpaqueClosure - throw(KernelError(job, "passing an opaque closure", - """Argument $arg_i to your kernel function is an opaque closure. - This is a CPU-only object not supported by GPUCompiler.""")) + throw( + KernelError( + job, "passing an opaque closure", + """Argument $arg_i to your kernel function is an opaque closure. + This is a CPU-only object not supported by GPUCompiler.""" + ) + ) end # If an object doesn't have fields, it can only be used by identity, so we can allow @@ -105,13 +113,17 @@ function check_invocation(@nospecialize(job::CompilerJob)) end if !isbitstype(dt) - throw(KernelError(job, "passing non-bitstype argument", - """Argument $arg_i to your kernel function is of type $dt, which is not a bitstype: - $(explain_nonisbits(dt)) - - Only bitstypes, which are "plain data" types that are immutable - and contain no references to other values, can be used in GPU kernels. - For more information, see the `Base.isbitstype` function.""")) + throw( + KernelError( + job, "passing non-bitstype argument", + """Argument $arg_i to your kernel function is of type $dt, which is not a bitstype: + $(explain_nonisbits(dt)) + + Only bitstypes, which are "plain data" types that are immutable + and contain no references to other values, can be used in GPU kernels. + For more information, see the `Base.isbitstype` function.""" + ) + ) end end @@ -131,20 +143,20 @@ end const RUNTIME_FUNCTION = "call to the Julia runtime" const UNKNOWN_FUNCTION = "call to an unknown function" const POINTER_FUNCTION = "call through a literal pointer" -const CCALL_FUNCTION = "call to an external C function" -const LAZY_FUNCTION = "call to a lazy-initialized function" -const DELAYED_BINDING = "use of an undefined name" -const DYNAMIC_CALL = "dynamic function invocation" +const CCALL_FUNCTION = "call to an external C function" +const LAZY_FUNCTION = "call to a lazy-initialized function" +const DELAYED_BINDING = "use of an undefined name" +const DYNAMIC_CALL = "dynamic function invocation" function Base.showerror(io::IO, err::InvalidIRError) print(io, "InvalidIRError: compiling ", err.job.source, " resulted in invalid LLVM IR") for (kind, bt, meta) in err.errors - printstyled(io, "\nReason: unsupported $kind"; color=:red) + printstyled(io, "\nReason: unsupported $kind"; color = :red) if meta !== nothing if kind == RUNTIME_FUNCTION || kind == UNKNOWN_FUNCTION || kind == POINTER_FUNCTION || kind == DYNAMIC_CALL || kind == CCALL_FUNCTION || kind == LAZY_FUNCTION - printstyled(io, " (call to ", meta, ")"; color=:red) + printstyled(io, " (call to ", meta, ")"; color = :red) elseif kind == DELAYED_BINDING - printstyled(io, " (use of '", meta, "')"; color=:red) + printstyled(io, " (use of '", meta, "')"; color = :red) end end Base.show_backtrace(io, bt) @@ -211,7 +223,7 @@ function check_ir!(job, errors::Vector{IRError}, inst::LLVM.LoadInst) name = match(rx, name).captures[1] push!(errors, (LAZY_FUNCTION, bt, name)) catch e - @safe_debug "Decoding name of PLT entry failed" inst bb=LLVM.parent(inst) + @safe_debug "Decoding name of PLT entry failed" inst bb = LLVM.parent(inst) push!(errors, (LAZY_FUNCTION, bt, nothing)) end end @@ -235,11 +247,11 @@ function check_ir!(job, errors::Vector{IRError}, inst::LLVM.CallInst) sym = Base.unsafe_pointer_to_objref(sym) push!(errors, (DELAYED_BINDING, bt, sym)) catch e - @safe_debug "Decoding arguments to jl_get_binding_or_error failed" inst bb=LLVM.parent(inst) + @safe_debug "Decoding arguments to jl_get_binding_or_error failed" inst bb = LLVM.parent(inst) push!(errors, (DELAYED_BINDING, bt, nothing)) end elseif fn == "jl_reresolve_binding_value_seqcst" || fn == "ijl_reresolve_binding_value_seqcst" || - fn == "jl_get_binding_value_seqcst" || fn == "ijl_get_binding_value_seqcst" + fn == "jl_get_binding_value_seqcst" || fn == "ijl_get_binding_value_seqcst" try # pry the binding from the IR expr = arguments(inst)[1]::ConstantExpr @@ -248,7 +260,7 @@ function check_ir!(job, errors::Vector{IRError}, inst::LLVM.CallInst) obj = Base.unsafe_pointer_to_objref(ptr) push!(errors, (DELAYED_BINDING, bt, obj.globalref)) catch e - @safe_debug "Decoding arguments to jl_reresolve_binding_value_seqcst failed" inst bb=LLVM.parent(inst) + @safe_debug "Decoding arguments to jl_reresolve_binding_value_seqcst failed" inst bb = LLVM.parent(inst) push!(errors, (DELAYED_BINDING, bt, nothing)) end elseif startswith(fn, "tojlinvoke") @@ -281,7 +293,7 @@ function check_ir!(job, errors::Vector{IRError}, inst::LLVM.CallInst) meth = Base.unsafe_pointer_to_objref(meth)::Core.MethodInstance push!(errors, (DYNAMIC_CALL, bt, meth.def)) catch e - @safe_debug "Decoding arguments to jl_invoke failed" inst bb=LLVM.parent(inst) + @safe_debug "Decoding arguments to jl_invoke failed" inst bb = LLVM.parent(inst) push!(errors, (DYNAMIC_CALL, bt, nothing)) end elseif fn == "jl_apply_generic" || fn == "ijl_apply_generic" @@ -293,7 +305,7 @@ function check_ir!(job, errors::Vector{IRError}, inst::LLVM.CallInst) f = Base.unsafe_pointer_to_objref(f) push!(errors, (DYNAMIC_CALL, bt, f)) catch e - @safe_debug "Decoding arguments to jl_apply_generic failed" inst bb=LLVM.parent(inst) + @safe_debug "Decoding arguments to jl_apply_generic failed" inst bb = LLVM.parent(inst) push!(errors, (DYNAMIC_CALL, bt, nothing)) end @@ -305,14 +317,14 @@ function check_ir!(job, errors::Vector{IRError}, inst::LLVM.CallInst) name_value = map(collect(name_init)) do char convert(UInt8, char) end |> String - name_value = name_value[1:end-1] # remove trailing \0 + name_value = name_value[1:(end - 1)] # remove trailing \0 push!(errors, (CCALL_FUNCTION, bt, name_value)) catch e - @safe_debug "Decoding arguments to jl_load_and_lookup failed" inst bb=LLVM.parent(inst) + @safe_debug "Decoding arguments to jl_load_and_lookup failed" inst bb = LLVM.parent(inst) push!(errors, (CCALL_FUNCTION, bt, nothing)) end - # detect calls to undefined functions + # detect calls to undefined functions elseif isdeclaration(dest) && !LLVM.isintrinsic(dest) && !isintrinsic(job, fn) # figure out if the function lives in the Julia runtime library if libjulia[] == C_NULL @@ -344,7 +356,7 @@ function check_ir!(job, errors::Vector{IRError}, inst::LLVM.CallInst) if !valid_function_pointer(job, ptr) # look it up in the Julia JIT cache - frames = ccall(:jl_lookup_code_address, Any, (Ptr{Cvoid}, Cint,), ptr, 0) + frames = ccall(:jl_lookup_code_address, Any, (Ptr{Cvoid}, Cint), ptr, 0) # XXX: what if multiple frames are returned? rare, but happens if length(frames) == 1 fn, file, line, linfo, fromC, inlined = last(frames) @@ -360,7 +372,7 @@ function check_ir!(job, errors::Vector{IRError}, inst::LLVM.CallInst) end # helper function to check for illegal values in an LLVM module -function check_ir_values(mod::LLVM.Module, predicate, msg="value") +function check_ir_values(mod::LLVM.Module, predicate, msg = "value") errors = IRError[] for fun in functions(mod), bb in blocks(fun), inst in instructions(bb) if predicate(inst) || any(predicate, operands(inst)) @@ -372,5 +384,5 @@ function check_ir_values(mod::LLVM.Module, predicate, msg="value") end ## shorthand to check for illegal value types function check_ir_values(mod::LLVM.Module, T_bad::LLVMType) - check_ir_values(mod, val -> value_type(val) == T_bad, "use of $(string(T_bad)) value") + return check_ir_values(mod, val -> value_type(val) == T_bad, "use of $(string(T_bad)) value") end diff --git a/test/bpf.jl b/test/bpf.jl index a4b10f38..58d26a07 100644 --- a/test/bpf.jl +++ b/test/bpf.jl @@ -1,6 +1,6 @@ @testset "No-op" begin mod = @eval module $(gensym()) - kernel() = 0 + kernel() = 0 end @test @filecheck begin @@ -12,7 +12,7 @@ end @testset "Return argument" begin mod = @eval module $(gensym()) - kernel(x) = x + kernel(x) = x end @test @filecheck begin @@ -24,7 +24,7 @@ end end @testset "Addition" begin mod = @eval module $(gensym()) - kernel(x) = x+1 + kernel(x) = x + 1 end @test @filecheck begin @@ -37,7 +37,7 @@ end end @testset "Errors" begin mod = @eval module $(gensym()) - kernel(x) = fakefunc(x) + kernel(x) = fakefunc(x) end @test_throws GPUCompiler.InvalidIRError BPF.code_execution(mod.kernel, (UInt64,)) @@ -45,8 +45,8 @@ end @testset "Function Pointers" begin @testset "valid" begin mod = @eval module $(gensym()) - goodcall(x) = Base.llvmcall("%2 = call i64 inttoptr (i64 3 to i64 (i64)*)(i64 %0)\nret i64 %2", Int, Tuple{Int}, x) - kernel(x) = goodcall(x) + goodcall(x) = Base.llvmcall("%2 = call i64 inttoptr (i64 3 to i64 (i64)*)(i64 %0)\nret i64 %2", Int, Tuple{Int}, x) + kernel(x) = goodcall(x) end @test @filecheck begin @@ -59,8 +59,8 @@ end @testset "invalid" begin mod = @eval module $(gensym()) - badcall(x) = Base.llvmcall("%2 = call i64 inttoptr (i64 3000 to i64 (i64)*)(i64 %0)\nret i64 %2", Int, Tuple{Int}, x) - kernel(x) = badcall(x) + badcall(x) = Base.llvmcall("%2 = call i64 inttoptr (i64 3000 to i64 (i64)*)(i64 %0)\nret i64 %2", Int, Tuple{Int}, x) + kernel(x) = badcall(x) end @test_throws GPUCompiler.InvalidIRError BPF.code_execution(mod.kernel, (Int,)) diff --git a/test/examples.jl b/test/examples.jl index aef6d4cc..82014197 100644 --- a/test/examples.jl +++ b/test/examples.jl @@ -1,4 +1,4 @@ -function find_sources(path::String, sources=String[]) +function find_sources(path::String, sources = String[]) if isdir(path) for entry in readdir(path) find_sources(joinpath(path, entry), sources) @@ -6,7 +6,7 @@ function find_sources(path::String, sources=String[]) elseif endswith(path, ".jl") push!(sources, path) end - sources + return sources end dir = joinpath(@__DIR__, "..", "examples") @@ -18,6 +18,6 @@ cd(dir) do examples = relpath.(files, Ref(dir)) @testset for example in examples cmd = `$(Base.julia_cmd()) --project=$(Base.active_project())` - @test success(pipeline(`$cmd $example`, stderr=stderr)) + @test success(pipeline(`$cmd $example`, stderr = stderr)) end end diff --git a/test/gcn.jl b/test/gcn.jl index 95641a44..92a5836b 100644 --- a/test/gcn.jl +++ b/test/gcn.jl @@ -1,286 +1,286 @@ if :AMDGPU in LLVM.backends() -# XXX: generic `sink` generates an instruction selection error -sink_gcn(i) = sink(i, Val(5)) + # XXX: generic `sink` generates an instruction selection error + sink_gcn(i) = sink(i, Val(5)) -@testset "IR" begin + @testset "IR" begin -@testset "kernel calling convention" begin - mod = @eval module $(gensym()) - kernel() = return - end + @testset "kernel calling convention" begin + mod = @eval module $(gensym()) + kernel() = return + end - @test @filecheck begin - check"CHECK-NOT: amdgpu_kernel" - GCN.code_llvm(mod.kernel, Tuple{}; dump_module=true) - end + @test @filecheck begin + check"CHECK-NOT: amdgpu_kernel" + GCN.code_llvm(mod.kernel, Tuple{}; dump_module = true) + end - @test @filecheck begin - check"CHECK: amdgpu_kernel" - GCN.code_llvm(mod.kernel, Tuple{}; dump_module=true, kernel=true) - end -end + @test @filecheck begin + check"CHECK: amdgpu_kernel" + GCN.code_llvm(mod.kernel, Tuple{}; dump_module = true, kernel = true) + end + end + + @testset "bounds errors" begin + mod = @eval module $(gensym()) + function kernel() + Base.throw_boundserror(1, 2) + return + end + end + + @test @filecheck begin + check"CHECK-NOT: {{julia_throw_boundserror_[0-9]+}}" + check"CHECK: @gpu_report_exception" + check"CHECK: @gpu_signal_exception" + GCN.code_llvm(mod.kernel, Tuple{}) + end + end -@testset "bounds errors" begin - mod = @eval module $(gensym()) - function kernel() - Base.throw_boundserror(1, 2) - return + @testset "https://github.com/JuliaGPU/AMDGPU.jl/issues/846" begin + ir, rt = GCN.code_typed((Tuple{Tuple{Val{4}}, Tuple{Float32}},); always_inline = true) do t + t[1] + end |> only + @test rt == Tuple{Val{4}} end - end - @test @filecheck begin - check"CHECK-NOT: {{julia_throw_boundserror_[0-9]+}}" - check"CHECK: @gpu_report_exception" - check"CHECK: @gpu_signal_exception" - GCN.code_llvm(mod.kernel, Tuple{}) end -end -@testset "https://github.com/JuliaGPU/AMDGPU.jl/issues/846" begin - ir, rt = GCN.code_typed((Tuple{Tuple{Val{4}}, Tuple{Float32}},); always_inline=true) do t - t[1] - end |> only - @test rt == Tuple{Val{4}} -end + ############################################################################################ + @testset "assembly" begin -end + @testset "skip scalar trap" begin + mod = @eval module $(gensym()) + workitem_idx_x() = ccall("llvm.amdgcn.workitem.id.x", llvmcall, Int32, ()) + trap() = ccall("llvm.trap", llvmcall, Nothing, ()) -############################################################################################ -@testset "assembly" begin - -@testset "skip scalar trap" begin - mod = @eval module $(gensym()) - workitem_idx_x() = ccall("llvm.amdgcn.workitem.id.x", llvmcall, Int32, ()) - trap() = ccall("llvm.trap", llvmcall, Nothing, ()) + function kernel() + if workitem_idx_x() > 1 + trap() + end + return + end + end - function kernel() - if workitem_idx_x() > 1 - trap() + @test @filecheck begin + check"CHECK-LABEL: {{(julia|j)_kernel_[0-9]+}}:" + check"CHECK: s_cbranch_exec" + check"CHECK: s_trap 2" + GCN.code_native(mod.kernel, Tuple{}) end - return end - end - @test @filecheck begin - check"CHECK-LABEL: {{(julia|j)_kernel_[0-9]+}}:" - check"CHECK: s_cbranch_exec" - check"CHECK: s_trap 2" - GCN.code_native(mod.kernel, Tuple{}) - end -end - -@testset "child functions" begin - # we often test using @noinline child functions, so test whether these survive - # (despite not having side-effects) - mod = @eval module $(gensym()) - import ..sink_gcn - @noinline child(i) = sink_gcn(i) - function parent(i) - child(i) - return - end - end + @testset "child functions" begin + # we often test using @noinline child functions, so test whether these survive + # (despite not having side-effects) + mod = @eval module $(gensym()) + import ..sink_gcn + @noinline child(i) = sink_gcn(i) + function parent(i) + child(i) + return + end + end - @test @filecheck begin - check"CHECK-LABEL: {{(julia|j)_parent_[0-9]+}}:" - check"CHECK: s_add_u32 {{.+}} {{(julia|j)_child_[0-9]+}}@rel32@" - check"CHECK: s_addc_u32 {{.+}} {{(julia|j)_child_[0-9]+}}@rel32@" - GCN.code_native(mod.parent, Tuple{Int64}; dump_module=true) - end -end - -@testset "kernel functions" begin - mod = @eval module $(gensym()) - import ..sink_gcn - @noinline nonentry(i) = sink_gcn(i) - function entry(i) - nonentry(i) - return + @test @filecheck begin + check"CHECK-LABEL: {{(julia|j)_parent_[0-9]+}}:" + check"CHECK: s_add_u32 {{.+}} {{(julia|j)_child_[0-9]+}}@rel32@" + check"CHECK: s_addc_u32 {{.+}} {{(julia|j)_child_[0-9]+}}@rel32@" + GCN.code_native(mod.parent, Tuple{Int64}; dump_module = true) + end end - end - @test @filecheck begin - check"CHECK-NOT: .amdhsa_kernel {{(julia|j)_nonentry_[0-9]+}}" - check"CHECK: .type {{(julia|j)_nonentry_[0-9]+}},@function" - check"CHECK: .amdhsa_kernel _Z5entry5Int64" - GCN.code_native(mod.entry, Tuple{Int64}; dump_module=true, kernel=true) - end -end - -@testset "child function reuse" begin - # bug: depending on a child function from multiple parents resulted in - # the child only being present once - - mod = @eval module $(gensym()) - import ..sink_gcn - @noinline child(i) = sink_gcn(i) - function parent1(i) - child(i) - return - end - function parent2(i) - child(i+1) - return + @testset "kernel functions" begin + mod = @eval module $(gensym()) + import ..sink_gcn + @noinline nonentry(i) = sink_gcn(i) + function entry(i) + nonentry(i) + return + end + end + + @test @filecheck begin + check"CHECK-NOT: .amdhsa_kernel {{(julia|j)_nonentry_[0-9]+}}" + check"CHECK: .type {{(julia|j)_nonentry_[0-9]+}},@function" + check"CHECK: .amdhsa_kernel _Z5entry5Int64" + GCN.code_native(mod.entry, Tuple{Int64}; dump_module = true, kernel = true) + end end - end - @test @filecheck begin - check"CHECK: .type {{(julia|j)_child_[0-9]+}},@function" - GCN.code_native(mod.parent1, Tuple{Int}; dump_module=true) - end + @testset "child function reuse" begin + # bug: depending on a child function from multiple parents resulted in + # the child only being present once - @test @filecheck begin - check"CHECK: .type {{(julia|j)_child_[0-9]+}},@function" - GCN.code_native(mod.parent2, Tuple{Int}; dump_module=true) - end -end - -@testset "child function reuse bis" begin - # bug: similar, but slightly different issue as above - # in the case of two child functions - - mod = @eval module $(gensym()) - import ..sink_gcn - @noinline child1(i) = sink_gcn(i) - @noinline child2(i) = sink_gcn(i+1) - function parent1(i) - child1(i) + child2(i) - return - end - function parent2(i) - child1(i+1) + child2(i+1) - return - end - end + mod = @eval module $(gensym()) + import ..sink_gcn + @noinline child(i) = sink_gcn(i) + function parent1(i) + child(i) + return + end + function parent2(i) + child(i + 1) + return + end + end - @test @filecheck begin - check"CHECK-DAG: .type {{(julia|j)_child1_[0-9]+}},@function" - check"CHECK-DAG: .type {{(julia|j)_child2_[0-9]+}},@function" - GCN.code_native(mod.parent1, Tuple{Int}; dump_module=true) - end + @test @filecheck begin + check"CHECK: .type {{(julia|j)_child_[0-9]+}},@function" + GCN.code_native(mod.parent1, Tuple{Int}; dump_module = true) + end - @test @filecheck begin - check"CHECK-DAG: .type {{(julia|j)_child1_[0-9]+}},@function" - check"CHECK-DAG: .type {{(julia|j)_child2_[0-9]+}},@function" - GCN.code_native(mod.parent2, Tuple{Int}; dump_module=true) - end -end + @test @filecheck begin + check"CHECK: .type {{(julia|j)_child_[0-9]+}},@function" + GCN.code_native(mod.parent2, Tuple{Int}; dump_module = true) + end + end -@testset "indirect sysimg function use" begin - # issue #9: re-using sysimg functions should force recompilation - # (host fldmod1->mod1 throws, so the GCN code shouldn't contain a throw) + @testset "child function reuse bis" begin + # bug: similar, but slightly different issue as above + # in the case of two child functions + + mod = @eval module $(gensym()) + import ..sink_gcn + @noinline child1(i) = sink_gcn(i) + @noinline child2(i) = sink_gcn(i + 1) + function parent1(i) + child1(i) + child2(i) + return + end + function parent2(i) + child1(i + 1) + child2(i + 1) + return + end + end - # NOTE: Int32 to test for #49 + @test @filecheck begin + check"CHECK-DAG: .type {{(julia|j)_child1_[0-9]+}},@function" + check"CHECK-DAG: .type {{(julia|j)_child2_[0-9]+}},@function" + GCN.code_native(mod.parent1, Tuple{Int}; dump_module = true) + end - mod = @eval module $(gensym()) - function kernel(out) - wid, lane = fldmod1(unsafe_load(out), Int32(32)) - unsafe_store!(out, wid) - return + @test @filecheck begin + check"CHECK-DAG: .type {{(julia|j)_child1_[0-9]+}},@function" + check"CHECK-DAG: .type {{(julia|j)_child2_[0-9]+}},@function" + GCN.code_native(mod.parent2, Tuple{Int}; dump_module = true) + end end - end - @test @filecheck begin - check"CHECK-LABEL: {{(julia|j)_kernel_[0-9]+}}:" - check"CHECK-NOT: jl_throw" - check"CHECK-NOT: jl_invoke" - GCN.code_native(mod.kernel, Tuple{Ptr{Int32}}) - end -end - -@testset "LLVM intrinsics" begin - # issue #13 (a): cannot select trunc - mod = @eval module $(gensym()) - function kernel(x) - unsafe_trunc(Int, x) - return - end - end - GCN.code_native(devnull, mod.kernel, Tuple{Float64}) - @test "We did not crash!" != "" -end - -# FIXME: _ZNK4llvm14TargetLowering20scalarizeVectorStoreEPNS_11StoreSDNodeERNS_12SelectionDAGE -false && @testset "exception arguments" begin - mod = @eval module $(gensym()) - function kernel(a) - unsafe_store!(a, trunc(Int, unsafe_load(a))) - return - end - end + @testset "indirect sysimg function use" begin + # issue #9: re-using sysimg functions should force recompilation + # (host fldmod1->mod1 throws, so the GCN code shouldn't contain a throw) + + # NOTE: Int32 to test for #49 - GCN.code_native(devnull, mod.kernel, Tuple{Ptr{Float64}}) -end + mod = @eval module $(gensym()) + function kernel(out) + wid, lane = fldmod1(unsafe_load(out), Int32(32)) + unsafe_store!(out, wid) + return + end + end -# FIXME: in function julia_inner_18528 void (%jl_value_t addrspace(10)*): invalid addrspacecast -false && @testset "GC and TLS lowering" begin - mod = @eval module $(gensym()) - import ..sink_gcn - mutable struct PleaseAllocate - y::Csize_t + @test @filecheck begin + check"CHECK-LABEL: {{(julia|j)_kernel_[0-9]+}}:" + check"CHECK-NOT: jl_throw" + check"CHECK-NOT: jl_invoke" + GCN.code_native(mod.kernel, Tuple{Ptr{Int32}}) + end end - # common pattern in Julia 0.7: outlined throw to avoid a GC frame in the calling code - @noinline function inner(x) - sink_gcn(x.y) - nothing + @testset "LLVM intrinsics" begin + # issue #13 (a): cannot select trunc + mod = @eval module $(gensym()) + function kernel(x) + unsafe_trunc(Int, x) + return + end + end + GCN.code_native(devnull, mod.kernel, Tuple{Float64}) + @test "We did not crash!" != "" end - function kernel(i) - inner(PleaseAllocate(Csize_t(42))) - nothing + # FIXME: _ZNK4llvm14TargetLowering20scalarizeVectorStoreEPNS_11StoreSDNodeERNS_12SelectionDAGE + false && @testset "exception arguments" begin + mod = @eval module $(gensym()) + function kernel(a) + unsafe_store!(a, trunc(Int, unsafe_load(a))) + return + end + end + + GCN.code_native(devnull, mod.kernel, Tuple{Ptr{Float64}}) end - end - @test @filecheck begin - check"CHECK-NOT: jl_push_gc_frame" - check"CHECK-NOT: jl_pop_gc_frame" - check"CHECK-NOT: jl_get_gc_frame_slot" - check"CHECK-NOT: jl_new_gc_frame" - check"CHECK: gpu_gc_pool_alloc" - GCN.code_native(mod.kernel, Tuple{Int}) - end + # FIXME: in function julia_inner_18528 void (%jl_value_t addrspace(10)*): invalid addrspacecast + false && @testset "GC and TLS lowering" begin + mod = @eval module $(gensym()) + import ..sink_gcn + mutable struct PleaseAllocate + y::Csize_t + end - # make sure that we can still ellide allocations - function ref_kernel(ptr, i) - data = Ref{Int64}() - data[] = 0 - if i > 1 - data[] = 1 - else - data[] = 2 + # common pattern in Julia 0.7: outlined throw to avoid a GC frame in the calling code + @noinline function inner(x) + sink_gcn(x.y) + nothing + end + + function kernel(i) + inner(PleaseAllocate(Csize_t(42))) + nothing + end + end + + @test @filecheck begin + check"CHECK-NOT: jl_push_gc_frame" + check"CHECK-NOT: jl_pop_gc_frame" + check"CHECK-NOT: jl_get_gc_frame_slot" + check"CHECK-NOT: jl_new_gc_frame" + check"CHECK: gpu_gc_pool_alloc" + GCN.code_native(mod.kernel, Tuple{Int}) + end + + # make sure that we can still ellide allocations + function ref_kernel(ptr, i) + data = Ref{Int64}() + data[] = 0 + if i > 1 + data[] = 1 + else + data[] = 2 + end + unsafe_store!(ptr, data[], i) + return nothing + end + + @test @filecheck begin + check"CHECK-NOT: gpu_gc_pool_alloc" + GCN.code_native(ref_kernel, Tuple{Ptr{Int64}, Int}) + end end - unsafe_store!(ptr, data[], i) - return nothing - end - @test @filecheck begin - check"CHECK-NOT: gpu_gc_pool_alloc" - GCN.code_native(ref_kernel, Tuple{Ptr{Int64}, Int}) - end -end - -@testset "float boxes" begin - mod = @eval module $(gensym()) - function kernel(a,b) - c = Int32(a) - # the conversion to Int32 may fail, in which case the input Float32 is boxed in order to - # pass it to the @nospecialize exception constructor. we should really avoid that (eg. - # by avoiding @nospecialize, or optimize the unused arguments away), but for now the box - # should just work. - unsafe_store!(b, c) - return + @testset "float boxes" begin + mod = @eval module $(gensym()) + function kernel(a, b) + c = Int32(a) + # the conversion to Int32 may fail, in which case the input Float32 is boxed in order to + # pass it to the @nospecialize exception constructor. we should really avoid that (eg. + # by avoiding @nospecialize, or optimize the unused arguments away), but for now the box + # should just work. + unsafe_store!(b, c) + return + end + end + + @test @filecheck begin + check"CHECK-LABEL: define void @{{(julia|j)_kernel_[0-9]+}}" + check"CHECK: jl_box_float32" + GCN.code_llvm(mod.kernel, Tuple{Float32, Ptr{Float32}}) + end + GCN.code_native(devnull, mod.kernel, Tuple{Float32, Ptr{Float32}}) end - end - @test @filecheck begin - check"CHECK-LABEL: define void @{{(julia|j)_kernel_[0-9]+}}" - check"CHECK: jl_box_float32" - GCN.code_llvm(mod.kernel, Tuple{Float32,Ptr{Float32}}) end - GCN.code_native(devnull, mod.kernel, Tuple{Float32,Ptr{Float32}}) -end - -end end # :AMDGPU in LLVM.backends() diff --git a/test/helpers/bpf.jl b/test/helpers/bpf.jl index 15eaa121..f0eea6a5 100644 --- a/test/helpers/bpf.jl +++ b/test/helpers/bpf.jl @@ -4,25 +4,25 @@ using ..GPUCompiler import ..TestRuntime struct CompilerParams <: AbstractCompilerParams end -GPUCompiler.runtime_module(::CompilerJob{<:Any,CompilerParams}) = TestRuntime +GPUCompiler.runtime_module(::CompilerJob{<:Any, CompilerParams}) = TestRuntime function create_job(@nospecialize(func), @nospecialize(types); kwargs...) config_kwargs, kwargs = split_kwargs(kwargs, GPUCompiler.CONFIG_KWARGS) source = methodinstance(typeof(func), Base.to_tuple_type(types), Base.get_world_counter()) target = BPFCompilerTarget() params = CompilerParams() - config = CompilerConfig(target, params; kernel=false, config_kwargs...) - CompilerJob(source, config), kwargs + config = CompilerConfig(target, params; kernel = false, config_kwargs...) + return CompilerJob(source, config), kwargs end function code_llvm(io::IO, @nospecialize(func), @nospecialize(types); kwargs...) job, kwargs = create_job(func, types; kwargs...) - GPUCompiler.code_llvm(io, job; kwargs...) + return GPUCompiler.code_llvm(io, job; kwargs...) end function code_native(io::IO, @nospecialize(func), @nospecialize(types); kwargs...) job, kwargs = create_job(func, types; kwargs...) - GPUCompiler.code_native(io, job; kwargs...) + return GPUCompiler.code_native(io, job; kwargs...) end # aliases without ::IO argument @@ -37,7 +37,7 @@ end # simulates codegen for a kernel function: validates by default function code_execution(@nospecialize(func), @nospecialize(types); kwargs...) job, kwargs = create_job(func, types; kwargs...) - JuliaContext() do ctx + return JuliaContext() do ctx GPUCompiler.compile(:asm, job; kwargs...) end end diff --git a/test/helpers/enzyme.jl b/test/helpers/enzyme.jl index 133e9a5a..55547c9a 100644 --- a/test/helpers/enzyme.jl +++ b/test/helpers/enzyme.jl @@ -2,12 +2,12 @@ module Enzyme using ..GPUCompiler -struct EnzymeTarget{Target<:AbstractCompilerTarget} <: AbstractCompilerTarget +struct EnzymeTarget{Target <: AbstractCompilerTarget} <: AbstractCompilerTarget target::Target end -function EnzymeTarget(;kwargs...) - EnzymeTarget(GPUCompiler.NativeCompilerTarget(; jlruntime = true, kwargs...)) +function EnzymeTarget(; kwargs...) + return EnzymeTarget(GPUCompiler.NativeCompilerTarget(; jlruntime = true, kwargs...)) end GPUCompiler.llvm_triple(target::EnzymeTarget) = GPUCompiler.llvm_triple(target.target) @@ -18,7 +18,7 @@ GPUCompiler.have_fma(target::EnzymeTarget, T::Type) = GPUCompiler.have_fma(targe GPUCompiler.dwarf_version(target::EnzymeTarget) = GPUCompiler.dwarf_version(target.target) abstract type AbstractEnzymeCompilerParams <: AbstractCompilerParams end -struct EnzymeCompilerParams{Params<:AbstractCompilerParams} <: AbstractEnzymeCompilerParams +struct EnzymeCompilerParams{Params <: AbstractCompilerParams} <: AbstractEnzymeCompilerParams params::Params end struct PrimalCompilerParams <: AbstractEnzymeCompilerParams @@ -68,14 +68,18 @@ function deferred_codegen_id_generator(world::UInt, source, self, ft::Type, tt:: sig = Tuple{ft, tt.parameters...} min_world = Ref{UInt}(typemin(UInt)) max_world = Ref{UInt}(typemax(UInt)) - match = ccall(:jl_gf_invoke_lookup_worlds, Any, - (Any, Any, Csize_t, Ref{Csize_t}, Ref{Csize_t}), - sig, #=mt=# nothing, world, min_world, max_world) + match = ccall( + :jl_gf_invoke_lookup_worlds, Any, + (Any, Any, Csize_t, Ref{Csize_t}, Ref{Csize_t}), + sig, #=mt=# nothing, world, min_world, max_world + ) match === nothing && return stub(world, source, method_error) # look up the method and code instance - mi = ccall(:jl_specializations_get_linfo, Ref{Core.MethodInstance}, - (Any, Any, Any), match.method, match.spec_types, match.sparams) + mi = ccall( + :jl_specializations_get_linfo, Ref{Core.MethodInstance}, + (Any, Any, Any), match.method, match.spec_types, match.sparams + ) ci = CC.retrieve_code_info(mi, world) # prepare a new code info @@ -83,10 +87,10 @@ function deferred_codegen_id_generator(world::UInt, source, self, ft::Type, tt:: new_ci = copy(ci) empty!(new_ci.code) @static if isdefined(Core, :DebugInfo) - new_ci.debuginfo = Core.DebugInfo(:none) + new_ci.debuginfo = Core.DebugInfo(:none) else - empty!(new_ci.codelocs) - resize!(new_ci.linetable, 1) # see note below + empty!(new_ci.codelocs) + resize!(new_ci.linetable, 1) # see note below end empty!(new_ci.ssaflags) new_ci.ssavaluetypes = 0 @@ -99,7 +103,7 @@ function deferred_codegen_id_generator(world::UInt, source, self, ft::Type, tt:: # prepare the slots new_ci.slotnames = Symbol[Symbol("#self#"), :ft, :tt] - new_ci.slotflags = UInt8[0x00 for i = 1:3] + new_ci.slotflags = UInt8[0x00 for i in 1:3] @static if isdefined(Core, :DebugInfo) new_ci.nargs = 3 end @@ -107,7 +111,7 @@ function deferred_codegen_id_generator(world::UInt, source, self, ft::Type, tt:: # We don't know the caller's target so EnzymeTarget uses the default NativeCompilerTarget. target = EnzymeTarget() params = EnzymeCompilerParams() - config = CompilerConfig(target, params; kernel=false) + config = CompilerConfig(target, params; kernel = false) job = CompilerJob(mi, config, world) id = length(deferred_codegen_jobs) + 1 @@ -116,9 +120,9 @@ function deferred_codegen_id_generator(world::UInt, source, self, ft::Type, tt:: # return the deferred_codegen_id push!(new_ci.code, CC.ReturnNode(id)) push!(new_ci.ssaflags, 0x00) - @static if isdefined(Core, :DebugInfo) + @static if isdefined(Core, :DebugInfo) else - push!(new_ci.codelocs, 1) # see note below + push!(new_ci.codelocs, 1) # see note below end new_ci.ssavaluetypes += 1 @@ -137,7 +141,7 @@ end @inline function deferred_codegen(f::Type, tt::Type) id = deferred_codegen_id(f, tt) - ccall("extern deferred_codegen", llvmcall, Ptr{Cvoid}, (Int,), id) + return ccall("extern deferred_codegen", llvmcall, Ptr{Cvoid}, (Int,), id) end -end \ No newline at end of file +end diff --git a/test/helpers/gcn.jl b/test/helpers/gcn.jl index c894fbd3..60aedf2d 100644 --- a/test/helpers/gcn.jl +++ b/test/helpers/gcn.jl @@ -4,35 +4,35 @@ using ..GPUCompiler import ..TestRuntime struct CompilerParams <: AbstractCompilerParams end -GPUCompiler.runtime_module(::CompilerJob{<:Any,CompilerParams}) = TestRuntime +GPUCompiler.runtime_module(::CompilerJob{<:Any, CompilerParams}) = TestRuntime function create_job(@nospecialize(func), @nospecialize(types); kwargs...) config_kwargs, kwargs = split_kwargs(kwargs, GPUCompiler.CONFIG_KWARGS) source = methodinstance(typeof(func), Base.to_tuple_type(types), Base.get_world_counter()) - target = GCNCompilerTarget(dev_isa="gfx900") + target = GCNCompilerTarget(dev_isa = "gfx900") params = CompilerParams() - config = CompilerConfig(target, params; kernel=false, config_kwargs...) - CompilerJob(source, config), kwargs + config = CompilerConfig(target, params; kernel = false, config_kwargs...) + return CompilerJob(source, config), kwargs end function code_typed(@nospecialize(func), @nospecialize(types); kwargs...) job, kwargs = create_job(func, types; kwargs...) - GPUCompiler.code_typed(job; kwargs...) + return GPUCompiler.code_typed(job; kwargs...) end function code_warntype(io::IO, @nospecialize(func), @nospecialize(types); kwargs...) job, kwargs = create_job(func, types; kwargs...) - GPUCompiler.code_warntype(io, job; kwargs...) + return GPUCompiler.code_warntype(io, job; kwargs...) end function code_llvm(io::IO, @nospecialize(func), @nospecialize(types); kwargs...) job, kwargs = create_job(func, types; kwargs...) - GPUCompiler.code_llvm(io, job; kwargs...) + return GPUCompiler.code_llvm(io, job; kwargs...) end function code_native(io::IO, @nospecialize(func), @nospecialize(types); kwargs...) job, kwargs = create_job(func, types; kwargs...) - GPUCompiler.code_native(io, job; kwargs...) + return GPUCompiler.code_native(io, job; kwargs...) end # aliases without ::IO argument @@ -46,8 +46,8 @@ end # simulates codegen for a kernel function: validates by default function code_execution(@nospecialize(func), @nospecialize(types); kwargs...) - job, kwargs = create_job(func, types; kernel=true, kwargs...) - JuliaContext() do ctx + job, kwargs = create_job(func, types; kernel = true, kwargs...) + return JuliaContext() do ctx GPUCompiler.compile(:asm, job; kwargs...) end end diff --git a/test/helpers/metal.jl b/test/helpers/metal.jl index 41eb0fbe..97767035 100644 --- a/test/helpers/metal.jl +++ b/test/helpers/metal.jl @@ -4,35 +4,35 @@ using ..GPUCompiler import ..TestRuntime struct CompilerParams <: AbstractCompilerParams end -GPUCompiler.runtime_module(::CompilerJob{<:Any,CompilerParams}) = TestRuntime +GPUCompiler.runtime_module(::CompilerJob{<:Any, CompilerParams}) = TestRuntime function create_job(@nospecialize(func), @nospecialize(types); kwargs...) config_kwargs, kwargs = split_kwargs(kwargs, GPUCompiler.CONFIG_KWARGS) source = methodinstance(typeof(func), Base.to_tuple_type(types), Base.get_world_counter()) - target = MetalCompilerTarget(; macos=v"12.2", metal=v"3.0", air=v"3.0") + target = MetalCompilerTarget(; macos = v"12.2", metal = v"3.0", air = v"3.0") params = CompilerParams() - config = CompilerConfig(target, params; kernel=false, config_kwargs...) - CompilerJob(source, config), kwargs + config = CompilerConfig(target, params; kernel = false, config_kwargs...) + return CompilerJob(source, config), kwargs end function code_typed(@nospecialize(func), @nospecialize(types); kwargs...) job, kwargs = create_job(func, types; kwargs...) - GPUCompiler.code_typed(job; kwargs...) + return GPUCompiler.code_typed(job; kwargs...) end function code_warntype(io::IO, @nospecialize(func), @nospecialize(types); kwargs...) job, kwargs = create_job(func, types; kwargs...) - GPUCompiler.code_warntype(io, job; kwargs...) + return GPUCompiler.code_warntype(io, job; kwargs...) end function code_llvm(io::IO, @nospecialize(func), @nospecialize(types); kwargs...) job, kwargs = create_job(func, types; kwargs...) - GPUCompiler.code_llvm(io, job; kwargs...) + return GPUCompiler.code_llvm(io, job; kwargs...) end function code_native(io::IO, @nospecialize(func), @nospecialize(types); kwargs...) job, kwargs = create_job(func, types; kwargs...) - GPUCompiler.code_native(io, job; kwargs...) + return GPUCompiler.code_native(io, job; kwargs...) end # aliases without ::IO argument @@ -46,8 +46,8 @@ end # simulates codegen for a kernel function: validates by default function code_execution(@nospecialize(func), @nospecialize(types); kwargs...) - job, kwargs = create_job(func, types; kernel=true, kwargs...) - JuliaContext() do ctx + job, kwargs = create_job(func, types; kernel = true, kwargs...) + return JuliaContext() do ctx GPUCompiler.compile(:asm, job; kwargs...) end end diff --git a/test/helpers/native.jl b/test/helpers/native.jl index d53ff172..2bccba0a 100644 --- a/test/helpers/native.jl +++ b/test/helpers/native.jl @@ -10,44 +10,46 @@ struct CompilerParams <: AbstractCompilerParams entry_safepoint::Bool method_table - CompilerParams(entry_safepoint::Bool=false, method_table=test_method_table) = + CompilerParams(entry_safepoint::Bool = false, method_table = test_method_table) = new(entry_safepoint, method_table) end -NativeCompilerJob = CompilerJob{NativeCompilerTarget,CompilerParams} +NativeCompilerJob = CompilerJob{NativeCompilerTarget, CompilerParams} GPUCompiler.runtime_module(::NativeCompilerJob) = TestRuntime GPUCompiler.method_table(@nospecialize(job::NativeCompilerJob)) = job.config.params.method_table GPUCompiler.can_safepoint(@nospecialize(job::NativeCompilerJob)) = job.config.params.entry_safepoint -function create_job(@nospecialize(func), @nospecialize(types); - entry_safepoint::Bool=false, method_table=test_method_table, kwargs...) +function create_job( + @nospecialize(func), @nospecialize(types); + entry_safepoint::Bool = false, method_table = test_method_table, kwargs... + ) config_kwargs, kwargs = split_kwargs(kwargs, GPUCompiler.CONFIG_KWARGS) source = methodinstance(typeof(func), Base.to_tuple_type(types), Base.get_world_counter()) target = NativeCompilerTarget() params = CompilerParams(entry_safepoint, method_table) - config = CompilerConfig(target, params; kernel=false, config_kwargs...) - CompilerJob(source, config), kwargs + config = CompilerConfig(target, params; kernel = false, config_kwargs...) + return CompilerJob(source, config), kwargs end function code_typed(@nospecialize(func), @nospecialize(types); kwargs...) job, kwargs = create_job(func, types; kwargs...) - GPUCompiler.code_typed(job; kwargs...) + return GPUCompiler.code_typed(job; kwargs...) end function code_warntype(io::IO, @nospecialize(func), @nospecialize(types); kwargs...) job, kwargs = create_job(func, types; kwargs...) - GPUCompiler.code_warntype(io, job; kwargs...) + return GPUCompiler.code_warntype(io, job; kwargs...) end function code_llvm(io::IO, @nospecialize(func), @nospecialize(types); kwargs...) job, kwargs = create_job(func, types; kwargs...) - GPUCompiler.code_llvm(io, job; kwargs...) + return GPUCompiler.code_llvm(io, job; kwargs...) end function code_native(io::IO, @nospecialize(func), @nospecialize(types); kwargs...) job, kwargs = create_job(func, types; kwargs...) - GPUCompiler.code_native(io, job; kwargs...) + return GPUCompiler.code_native(io, job; kwargs...) end # aliases without ::IO argument @@ -61,8 +63,8 @@ end # simulates codegen for a kernel function: validates by default function code_execution(@nospecialize(func), @nospecialize(types); kwargs...) - job, kwargs = create_job(func, types; kernel=true, kwargs...) - JuliaContext() do ctx + job, kwargs = create_job(func, types; kernel = true, kwargs...) + return JuliaContext() do ctx GPUCompiler.compile(:asm, job; kwargs...) end end @@ -70,19 +72,19 @@ end const runtime_cache = Dict{Any, Any}() function compiler(job) - JuliaContext() do ctx + return JuliaContext() do ctx GPUCompiler.compile(:asm, job) end end function linker(job, asm) - asm + return asm end # simulates cached codegen function cached_execution(@nospecialize(func), @nospecialize(types); kwargs...) - job, kwargs = create_job(func, types; validate=false, kwargs...) - GPUCompiler.cached_compilation(runtime_cache, job.source, job.config, compiler, linker) + job, kwargs = create_job(func, types; validate = false, kwargs...) + return GPUCompiler.cached_compilation(runtime_cache, job.source, job.config, compiler, linker) end end diff --git a/test/helpers/precompile.jl b/test/helpers/precompile.jl index 74ab2b37..e02c8cb4 100644 --- a/test/helpers/precompile.jl +++ b/test/helpers/precompile.jl @@ -1,5 +1,5 @@ function precompile_test_harness(@nospecialize(f), testset::String) - @testset "$testset" begin + return @testset "$testset" begin precompile_test_harness(f, true) end end @@ -7,8 +7,8 @@ function precompile_test_harness(@nospecialize(f), separate::Bool) # XXX: clean-up may fail on Windows, because opened files are not deletable. # fix this by running the harness in a separate process, such that the # compilation cache files are not opened? - load_path = mktempdir(cleanup=true) - load_cache_path = separate ? mktempdir(cleanup=true) : load_path + load_path = mktempdir(cleanup = true) + load_cache_path = separate ? mktempdir(cleanup = true) : load_path try pushfirst!(LOAD_PATH, load_path) pushfirst!(DEPOT_PATH, load_cache_path) @@ -17,7 +17,7 @@ function precompile_test_harness(@nospecialize(f), separate::Bool) popfirst!(DEPOT_PATH) popfirst!(LOAD_PATH) end - nothing + return nothing end function check_presence(mi, token) @@ -47,5 +47,5 @@ function create_standalone(load_path, name::String, file) # Write out the test setup as a micro package write(joinpath(load_path, "$name.jl"), string(code)) - Base.compilecache(Base.PkgId(name)) + return Base.compilecache(Base.PkgId(name)) end diff --git a/test/helpers/ptx.jl b/test/helpers/ptx.jl index e82416bc..89c6eefe 100644 --- a/test/helpers/ptx.jl +++ b/test/helpers/ptx.jl @@ -5,7 +5,7 @@ import ..TestRuntime struct CompilerParams <: AbstractCompilerParams end -PTXCompilerJob = CompilerJob{PTXCompilerTarget,CompilerParams} +PTXCompilerJob = CompilerJob{PTXCompilerTarget, CompilerParams} struct PTXKernelState data::Int64 @@ -35,36 +35,38 @@ module PTXTestRuntime end GPUCompiler.runtime_module(::PTXCompilerJob) = PTXTestRuntime -function create_job(@nospecialize(func), @nospecialize(types); - minthreads=nothing, maxthreads=nothing, - blocks_per_sm=nothing, maxregs=nothing, - kwargs...) +function create_job( + @nospecialize(func), @nospecialize(types); + minthreads = nothing, maxthreads = nothing, + blocks_per_sm = nothing, maxregs = nothing, + kwargs... + ) config_kwargs, kwargs = split_kwargs(kwargs, GPUCompiler.CONFIG_KWARGS) source = methodinstance(typeof(func), Base.to_tuple_type(types), Base.get_world_counter()) - target = PTXCompilerTarget(; cap=v"7.0", minthreads, maxthreads, blocks_per_sm, maxregs) + target = PTXCompilerTarget(; cap = v"7.0", minthreads, maxthreads, blocks_per_sm, maxregs) params = CompilerParams() - config = CompilerConfig(target, params; kernel=false, config_kwargs...) - CompilerJob(source, config), kwargs + config = CompilerConfig(target, params; kernel = false, config_kwargs...) + return CompilerJob(source, config), kwargs end function code_typed(@nospecialize(func), @nospecialize(types); kwargs...) job, kwargs = create_job(func, types; kwargs...) - GPUCompiler.code_typed(job; kwargs...) + return GPUCompiler.code_typed(job; kwargs...) end function code_warntype(io::IO, @nospecialize(func), @nospecialize(types); kwargs...) job, kwargs = create_job(func, types; kwargs...) - GPUCompiler.code_warntype(io, job; kwargs...) + return GPUCompiler.code_warntype(io, job; kwargs...) end function code_llvm(io::IO, @nospecialize(func), @nospecialize(types); kwargs...) job, kwargs = create_job(func, types; kwargs...) - GPUCompiler.code_llvm(io, job; kwargs...) + return GPUCompiler.code_llvm(io, job; kwargs...) end function code_native(io::IO, @nospecialize(func), @nospecialize(types); kwargs...) job, kwargs = create_job(func, types; kwargs...) - GPUCompiler.code_native(io, job; kwargs...) + return GPUCompiler.code_native(io, job; kwargs...) end # aliases without ::IO argument @@ -78,8 +80,8 @@ end # simulates codegen for a kernel function: validates by default function code_execution(@nospecialize(func), @nospecialize(types); kwargs...) - job, kwargs = create_job(func, types; kernel=true, kwargs...) - JuliaContext() do ctx + job, kwargs = create_job(func, types; kernel = true, kwargs...) + return JuliaContext() do ctx GPUCompiler.compile(:asm, job; kwargs...) end end diff --git a/test/helpers/runtime.jl b/test/helpers/runtime.jl index c35cbb04..13c8f35e 100644 --- a/test/helpers/runtime.jl +++ b/test/helpers/runtime.jl @@ -1,9 +1,9 @@ module TestRuntime - # dummy methods - signal_exception() = return - malloc(sz) = C_NULL - report_oom(sz) = return - report_exception(ex) = return - report_exception_name(ex) = return - report_exception_frame(idx, func, file, line) = return +# dummy methods +signal_exception() = return +malloc(sz) = C_NULL +report_oom(sz) = return +report_exception(ex) = return +report_exception_name(ex) = return +report_exception_frame(idx, func, file, line) = return end diff --git a/test/helpers/spirv.jl b/test/helpers/spirv.jl index 0144cd6a..974f4151 100644 --- a/test/helpers/spirv.jl +++ b/test/helpers/spirv.jl @@ -4,38 +4,42 @@ using ..GPUCompiler import ..TestRuntime struct CompilerParams <: AbstractCompilerParams end -GPUCompiler.runtime_module(::CompilerJob{<:Any,CompilerParams}) = TestRuntime +GPUCompiler.runtime_module(::CompilerJob{<:Any, CompilerParams}) = TestRuntime -function create_job(@nospecialize(func), @nospecialize(types); - supports_fp16=true, supports_fp64=true, backend::Symbol, - kwargs...) +function create_job( + @nospecialize(func), @nospecialize(types); + supports_fp16 = true, supports_fp64 = true, backend::Symbol, + kwargs... + ) config_kwargs, kwargs = split_kwargs(kwargs, GPUCompiler.CONFIG_KWARGS) source = methodinstance(typeof(func), Base.to_tuple_type(types), Base.get_world_counter()) - target = SPIRVCompilerTarget(; backend, validate=true, optimize=true, - supports_fp16, supports_fp64) + target = SPIRVCompilerTarget(; + backend, validate = true, optimize = true, + supports_fp16, supports_fp64 + ) params = CompilerParams() - config = CompilerConfig(target, params; kernel=false, config_kwargs...) - CompilerJob(source, config), kwargs + config = CompilerConfig(target, params; kernel = false, config_kwargs...) + return CompilerJob(source, config), kwargs end function code_typed(@nospecialize(func), @nospecialize(types); kwargs...) job, kwargs = create_job(func, types; kwargs...) - GPUCompiler.code_typed(job; kwargs...) + return GPUCompiler.code_typed(job; kwargs...) end function code_warntype(io::IO, @nospecialize(func), @nospecialize(types); kwargs...) job, kwargs = create_job(func, types; kwargs...) - GPUCompiler.code_warntype(io, job; kwargs...) + return GPUCompiler.code_warntype(io, job; kwargs...) end function code_llvm(io::IO, @nospecialize(func), @nospecialize(types); kwargs...) job, kwargs = create_job(func, types; kwargs...) - GPUCompiler.code_llvm(io, job; kwargs...) + return GPUCompiler.code_llvm(io, job; kwargs...) end function code_native(io::IO, @nospecialize(func), @nospecialize(types); kwargs...) job, kwargs = create_job(func, types; kwargs...) - GPUCompiler.code_native(io, job; kwargs...) + return GPUCompiler.code_native(io, job; kwargs...) end # aliases without ::IO argument @@ -49,8 +53,8 @@ end # simulates codegen for a kernel function: validates by default function code_execution(@nospecialize(func), @nospecialize(types); kwargs...) - job, kwargs = create_job(func, types; kernel=true, kwargs...) - JuliaContext() do ctx + job, kwargs = create_job(func, types; kernel = true, kwargs...) + return JuliaContext() do ctx GPUCompiler.compile(:asm, job; kwargs...) end end diff --git a/test/helpers/test.jl b/test/helpers/test.jl index 014ddb27..8e9c996a 100644 --- a/test/helpers/test.jl +++ b/test/helpers/test.jl @@ -1,6 +1,6 @@ # @test_throw, with additional testing for the exception message macro test_throws_message(f, typ, ex...) - quote + return quote msg = "" @test_throws $(esc(typ)) try $(esc(ex...)) @@ -19,20 +19,20 @@ macro test_throws_message(f, typ, ex...) end # helper function for sinking a value to prevent the callee from getting optimized away -@inline @generated function sink(i::T, ::Val{addrspace}=Val(0)) where {T <: Union{Int32,UInt32}, addrspace} +@inline @generated function sink(i::T, ::Val{addrspace} = Val(0)) where {T <: Union{Int32, UInt32}, addrspace} as_str = addrspace > 0 ? " addrspace($addrspace)" : "" llvmcall_str = """%slot = alloca i32$(addrspace > 0 ? ", addrspace($addrspace)" : "") - store volatile i32 %0, i32$(as_str)* %slot - %value = load volatile i32, i32$(as_str)* %slot - ret i32 %value""" + store volatile i32 %0, i32$(as_str)* %slot + %value = load volatile i32, i32$(as_str)* %slot + ret i32 %value""" return :(Base.llvmcall($llvmcall_str, T, Tuple{T}, i)) end -@inline @generated function sink(i::T, ::Val{addrspace}=Val(0)) where {T <: Union{Int64,UInt64}, addrspace} +@inline @generated function sink(i::T, ::Val{addrspace} = Val(0)) where {T <: Union{Int64, UInt64}, addrspace} as_str = addrspace > 0 ? " addrspace($addrspace)" : "" llvmcall_str = """%slot = alloca i64$(addrspace > 0 ? ", addrspace($addrspace)" : "") - store volatile i64 %0, i64$(as_str)* %slot - %value = load volatile i64, i64$(as_str)* %slot - ret i64 %value""" + store volatile i64 %0, i64$(as_str)* %slot + %value = load volatile i64, i64$(as_str)* %slot + ret i64 %value""" return :(Base.llvmcall($llvmcall_str, T, Tuple{T}, i)) end @@ -47,10 +47,10 @@ module FileCheck global filecheck_path::String function __init__() - global filecheck_path = joinpath(LLVM_jll.artifact_dir, "tools", "FileCheck") + return global filecheck_path = joinpath(LLVM_jll.artifact_dir, "tools", "FileCheck") end - function filecheck_exe(; adjust_PATH::Bool=true, adjust_LIBPATH::Bool=true) + function filecheck_exe(; adjust_PATH::Bool = true, adjust_LIBPATH::Bool = true) env = Base.invokelatest( LLVM_jll.JLLWrappers.adjust_ENV!, copy(ENV), @@ -69,12 +69,12 @@ module FileCheck function filecheck(f, input) # FileCheck assumes that the input is available as a file - mktemp() do path, input_io + return mktemp() do path, input_io write(input_io, input) close(input_io) # capture the output of `f` and write it into a temporary buffer - result = IOCapture.capture(rethrow=Union{}) do + result = IOCapture.capture(rethrow = Union{}) do f(input) end output_io = IOBuffer() @@ -90,9 +90,11 @@ module FileCheck end # determine some useful prefixes for FileCheck - prefixes = ["CHECK", - "JULIA$(VERSION.major)_$(VERSION.minor)", - "LLVM$(Base.libllvm_version.major)"] + prefixes = [ + "CHECK", + "JULIA$(VERSION.major)_$(VERSION.minor)", + "LLVM$(Base.libllvm_version.major)", + ] ## whether we use typed pointers or opaque pointers if julia_typed_pointers push!(prefixes, "TYPED") @@ -110,11 +112,11 @@ module FileCheck seekstart(output_io) filecheck_io = Pipe() cmd = ```$(filecheck_exe()) - --color - --allow-unused-prefixes - --check-prefixes $(join(prefixes, ',')) - $path``` - proc = run(pipeline(ignorestatus(cmd); stdin=output_io, stdout=filecheck_io, stderr=filecheck_io); wait=false) + --color + --allow-unused-prefixes + --check-prefixes $(join(prefixes, ',')) + $path``` + proc = run(pipeline(ignorestatus(cmd); stdin = output_io, stdout = filecheck_io, stderr = filecheck_io); wait = false) close(filecheck_io.in) # collect the output of FileCheck @@ -135,7 +137,7 @@ module FileCheck const checks = String[] macro check_str(str) push!(checks, str) - nothing + return nothing end macro filecheck(ex) @@ -146,10 +148,12 @@ module FileCheck check_str = join(checks, "\n") empty!(checks) - esc(quote - filecheck($check_str) do _ - $ex + return esc( + quote + filecheck($check_str) do _ + $ex + end end - end) + ) end end diff --git a/test/metal.jl b/test/metal.jl index 951781f1..25f65f09 100644 --- a/test/metal.jl +++ b/test/metal.jl @@ -1,186 +1,200 @@ @testset "IR" begin -@testset "kernel functions" begin -@testset "byref aggregates" begin - mod = @eval module $(gensym()) - kernel(x) = return - end - - @test @filecheck begin - check"CHECK-LABEL: define void @{{(julia|j)_kernel_[0-9]+}}" - check"TYPED-SAME: ({{(\{ i64 \}|\[1 x i64\])}}*" - check"OPAQUE-SAME: (ptr" - Metal.code_llvm(mod.kernel, Tuple{Tuple{Int}}) - end - - # for kernels, every pointer argument needs to take an address space - @test @filecheck begin - check"CHECK-LABEL: define void @_Z6kernel5TupleI5Int64E" - check"TYPED-SAME: ({{(\{ i64 \}|\[1 x i64\])}} addrspace(1)*" - check"OPAQUE-SAME: (ptr addrspace(1)" - Metal.code_llvm(mod.kernel, Tuple{Tuple{Int}}; kernel=true) - end -end - -@testset "byref primitives" begin - mod = @eval module $(gensym()) - kernel(x) = return - end - - @test @filecheck begin - check"CHECK-LABEL: define void @{{(julia|j)_kernel_[0-9]+}}" - check"CHECK-SAME: (i64" - Metal.code_llvm(mod.kernel, Tuple{Int}) - end - - # for kernels, every pointer argument needs to take an address space - @test @filecheck begin - check"CHECK-LABEL: define void @_Z6kernel5Int64" - check"TYPED-SAME: (i64 addrspace(1)*" - check"OPAQUE-SAME: (ptr addrspace(1)" - Metal.code_llvm(mod.kernel, Tuple{Int}; kernel=true) - end -end - -@testset "module metadata" begin - mod = @eval module $(gensym()) - kernel() = return - end - - @test @filecheck begin - check"CHECK: air.version" - check"CHECK: air.language_version" - check"CHECK: air.max_device_buffers" - Metal.code_llvm(mod.kernel, Tuple{}; dump_module=true, kernel=true) - end -end - -@testset "argument metadata" begin - mod = @eval module $(gensym()) - kernel(x) = return - end - - @test @filecheck begin - check"CHECK: air.buffer" - Metal.code_llvm(mod.kernel, Tuple{Int}; dump_module=true, kernel=true) - end - - # XXX: perform more exhaustive testing of argument passing metadata here, - # or just defer to execution testing in Metal.jl? -end - -@testset "input arguments" begin - mod = @eval module $(gensym()) - function kernel(ptr) - idx = ccall("extern julia.air.thread_position_in_threadgroup.i32", - llvmcall, UInt32, ()) + 1 - unsafe_store!(ptr, 42, idx) - return + @testset "kernel functions" begin + @testset "byref aggregates" begin + mod = @eval module $(gensym()) + kernel(x) = return + end + + @test @filecheck begin + check"CHECK-LABEL: define void @{{(julia|j)_kernel_[0-9]+}}" + check"TYPED-SAME: ({{(\{ i64 \}|\[1 x i64\])}}*" + check"OPAQUE-SAME: (ptr" + Metal.code_llvm(mod.kernel, Tuple{Tuple{Int}}) + end + + # for kernels, every pointer argument needs to take an address space + @test @filecheck begin + check"CHECK-LABEL: define void @_Z6kernel5TupleI5Int64E" + check"TYPED-SAME: ({{(\{ i64 \}|\[1 x i64\])}} addrspace(1)*" + check"OPAQUE-SAME: (ptr addrspace(1)" + Metal.code_llvm(mod.kernel, Tuple{Tuple{Int}}; kernel = true) + end end - end - @test @filecheck begin - check"CHECK-LABEL: define void @{{(julia|j)_kernel_[0-9]+}}" - check"TYPED-SAME: ({{.+}} addrspace(1)* %{{.+}})" - check"OPAQUE-SAME: (ptr addrspace(1) %{{.+}})" - check"CHECK: call i32 @julia.air.thread_position_in_threadgroup.i32" - Metal.code_llvm(mod.kernel, Tuple{Core.LLVMPtr{Int,1}}) - end + @testset "byref primitives" begin + mod = @eval module $(gensym()) + kernel(x) = return + end + + @test @filecheck begin + check"CHECK-LABEL: define void @{{(julia|j)_kernel_[0-9]+}}" + check"CHECK-SAME: (i64" + Metal.code_llvm(mod.kernel, Tuple{Int}) + end + + # for kernels, every pointer argument needs to take an address space + @test @filecheck begin + check"CHECK-LABEL: define void @_Z6kernel5Int64" + check"TYPED-SAME: (i64 addrspace(1)*" + check"OPAQUE-SAME: (ptr addrspace(1)" + Metal.code_llvm(mod.kernel, Tuple{Int}; kernel = true) + end + end - @test @filecheck begin - check"CHECK-LABEL: define void @_Z6kernel7LLVMPtrI5Int64Li1EE" - check"TYPED-SAME: ({{.+}} addrspace(1)* %{{.+}}, i32 %thread_position_in_threadgroup)" - check"OPAQUE-SAME: (ptr addrspace(1) %{{.+}}, i32 %thread_position_in_threadgroup)" - check"CHECK-NOT: call i32 @julia.air.thread_position_in_threadgroup.i32" - Metal.code_llvm(mod.kernel, Tuple{Core.LLVMPtr{Int,1}}; kernel=true) - end -end + @testset "module metadata" begin + mod = @eval module $(gensym()) + kernel() = return + end + + @test @filecheck begin + check"CHECK: air.version" + check"CHECK: air.language_version" + check"CHECK: air.max_device_buffers" + Metal.code_llvm(mod.kernel, Tuple{}; dump_module = true, kernel = true) + end + end -@testset "vector intrinsics" begin - mod = @eval module $(gensym()) - foo(x, y) = ccall("llvm.smax.v2i64", llvmcall, NTuple{2, VecElement{Int64}}, - (NTuple{2, VecElement{Int64}}, NTuple{2, VecElement{Int64}}), x, y) - end + @testset "argument metadata" begin + mod = @eval module $(gensym()) + kernel(x) = return + end - @test @filecheck begin - check"CHECK-LABEL: define <2 x i64> @{{(julia|j)_foo_[0-9]+}}" - check"CHECK: air.max.s.v2i64" - Metal.code_llvm(mod.foo, (NTuple{2, VecElement{Int64}}, NTuple{2, VecElement{Int64}})) - end -end + @test @filecheck begin + check"CHECK: air.buffer" + Metal.code_llvm(mod.kernel, Tuple{Int}; dump_module = true, kernel = true) + end -@testset "unsupported type detection" begin - mod = @eval module $(gensym()) - function kernel(ptr) - buf = reinterpret(Ptr{Float32}, ptr) - val = unsafe_load(buf) - dval = Cdouble(val) - # ccall("extern metal_os_log", llvmcall, Nothing, (Float64,), dval) - Base.llvmcall((""" - declare void @llvm.va_start(i8*) - declare void @llvm.va_end(i8*) - declare void @air.os_log(i8*, i64) - - define void @metal_os_log(...) { - %1 = alloca i8* - %2 = bitcast i8** %1 to i8* - call void @llvm.va_start(i8* %2) - %3 = load i8*, i8** %1 - call void @air.os_log(i8* %3, i64 8) - call void @llvm.va_end(i8* %2) - ret void - } - - define void @entry(double %val) #0 { - call void (...) @metal_os_log(double %val) - ret void - } - - attributes #0 = { alwaysinline }""", "entry"), - Nothing, Tuple{Float64}, dval) - return + # XXX: perform more exhaustive testing of argument passing metadata here, + # or just defer to execution testing in Metal.jl? end - end - - @test @filecheck begin - check"CHECK-LABEL: define void @{{(julia|j)_kernel_[0-9]+}}" - check"CHECK: @metal_os_log" - Metal.code_llvm(mod.kernel, Tuple{Core.LLVMPtr{Float32,1}}; validate=true) - end - - function kernel2(ptr) - val = unsafe_load(ptr) - res = val * val - unsafe_store!(ptr, res) - return - end - @test_throws_message(InvalidIRError, - Metal.code_execution(kernel2, - Tuple{Core.LLVMPtr{Float64,1}})) do msg - occursin("unsupported use of double value", msg) - end -end + @testset "input arguments" begin + mod = @eval module $(gensym()) + function kernel(ptr) + idx = ccall( + "extern julia.air.thread_position_in_threadgroup.i32", + llvmcall, UInt32, () + ) + 1 + unsafe_store!(ptr, 42, idx) + return + end + end + + @test @filecheck begin + check"CHECK-LABEL: define void @{{(julia|j)_kernel_[0-9]+}}" + check"TYPED-SAME: ({{.+}} addrspace(1)* %{{.+}})" + check"OPAQUE-SAME: (ptr addrspace(1) %{{.+}})" + check"CHECK: call i32 @julia.air.thread_position_in_threadgroup.i32" + Metal.code_llvm(mod.kernel, Tuple{Core.LLVMPtr{Int, 1}}) + end + + @test @filecheck begin + check"CHECK-LABEL: define void @_Z6kernel7LLVMPtrI5Int64Li1EE" + check"TYPED-SAME: ({{.+}} addrspace(1)* %{{.+}}, i32 %thread_position_in_threadgroup)" + check"OPAQUE-SAME: (ptr addrspace(1) %{{.+}}, i32 %thread_position_in_threadgroup)" + check"CHECK-NOT: call i32 @julia.air.thread_position_in_threadgroup.i32" + Metal.code_llvm(mod.kernel, Tuple{Core.LLVMPtr{Int, 1}}; kernel = true) + end + end -@testset "constant globals" begin - mod = @eval module $(gensym()) - const xs = (1.0f0, 2f0) + @testset "vector intrinsics" begin + mod = @eval module $(gensym()) + foo(x, y) = ccall( + "llvm.smax.v2i64", llvmcall, NTuple{2, VecElement{Int64}}, + (NTuple{2, VecElement{Int64}}, NTuple{2, VecElement{Int64}}), x, y + ) + end + + @test @filecheck begin + check"CHECK-LABEL: define <2 x i64> @{{(julia|j)_foo_[0-9]+}}" + check"CHECK: air.max.s.v2i64" + Metal.code_llvm(mod.foo, (NTuple{2, VecElement{Int64}}, NTuple{2, VecElement{Int64}})) + end + end - function kernel(ptr, i) - unsafe_store!(ptr, xs[i]) + @testset "unsupported type detection" begin + mod = @eval module $(gensym()) + function kernel(ptr) + buf = reinterpret(Ptr{Float32}, ptr) + val = unsafe_load(buf) + dval = Cdouble(val) + # ccall("extern metal_os_log", llvmcall, Nothing, (Float64,), dval) + Base.llvmcall( + ( + """ + declare void @llvm.va_start(i8*) + declare void @llvm.va_end(i8*) + declare void @air.os_log(i8*, i64) + + define void @metal_os_log(...) { + %1 = alloca i8* + %2 = bitcast i8** %1 to i8* + call void @llvm.va_start(i8* %2) + %3 = load i8*, i8** %1 + call void @air.os_log(i8* %3, i64 8) + call void @llvm.va_end(i8* %2) + ret void + } + + define void @entry(double %val) #0 { + call void (...) @metal_os_log(double %val) + ret void + } + + attributes #0 = { alwaysinline }""", "entry", + ), + Nothing, Tuple{Float64}, dval + ) + return + end + end + + @test @filecheck begin + check"CHECK-LABEL: define void @{{(julia|j)_kernel_[0-9]+}}" + check"CHECK: @metal_os_log" + Metal.code_llvm(mod.kernel, Tuple{Core.LLVMPtr{Float32, 1}}; validate = true) + end + + function kernel2(ptr) + val = unsafe_load(ptr) + res = val * val + unsafe_store!(ptr, res) + return + end + + @test_throws_message( + InvalidIRError, + Metal.code_execution( + kernel2, + Tuple{Core.LLVMPtr{Float64, 1}} + ) + ) do msg + occursin("unsupported use of double value", msg) + end + end - return + @testset "constant globals" begin + mod = @eval module $(gensym()) + const xs = (1.0f0, 2.0f0) + + function kernel(ptr, i) + unsafe_store!(ptr, xs[i]) + + return + end + end + + @test @filecheck begin + check"CHECK: @{{.+}} ={{.*}} addrspace(2) constant [2 x float]" + check"CHECK: define void @_Z6kernel7LLVMPtrI7Float32Li1EE5Int64" + Metal.code_llvm( + mod.kernel, Tuple{Core.LLVMPtr{Float32, 1}, Int}; + dump_module = true, kernel = true + ) + end end - end - @test @filecheck begin - check"CHECK: @{{.+}} ={{.*}} addrspace(2) constant [2 x float]" - check"CHECK: define void @_Z6kernel7LLVMPtrI7Float32Li1EE5Int64" - Metal.code_llvm(mod.kernel, Tuple{Core.LLVMPtr{Float32,1}, Int}; - dump_module=true, kernel=true) end -end - -end end diff --git a/test/native.jl b/test/native.jl index da08764f..c6d9916d 100644 --- a/test/native.jl +++ b/test/native.jl @@ -25,20 +25,20 @@ end @testset "compilation" begin @testset "callable structs" begin mod = @eval module $(gensym()) - struct MyCallable end - (::MyCallable)(a, b) = a+b + struct MyCallable end + (::MyCallable)(a, b) = a + b end - (ci, rt) = Native.code_typed(mod.MyCallable(), (Int, Int), kernel=false)[1] + (ci, rt) = Native.code_typed(mod.MyCallable(), (Int, Int), kernel = false)[1] @test ci.slottypes[1] == Core.Compiler.Const(mod.MyCallable()) end @testset "compilation database" begin mod = @eval module $(gensym()) - @noinline inner(x) = x+1 - function outer(x) - return inner(x) - end + @noinline inner(x) = x + 1 + function outer(x) + return inner(x) + end end job, _ = Native.create_job(mod.outer, (Int,)) @@ -47,10 +47,10 @@ end meth = only(methods(mod.outer, (Int,))) - mis = filter(mi->mi.def == meth, keys(meta.compiled)) + mis = filter(mi -> mi.def == meth, keys(meta.compiled)) @test length(mis) == 1 - other_mis = filter(mi->mi.def != meth, keys(meta.compiled)) + other_mis = filter(mi -> mi.def != meth, keys(meta.compiled)) @test length(other_mis) == 1 @test only(other_mis).def in methods(mod.inner) end @@ -58,23 +58,23 @@ end @testset "advanced database" begin mod = @eval module $(gensym()) - @noinline inner(x) = x+1 - foo(x) = sum(inner, fill(x, 10, 10)) + @noinline inner(x) = x + 1 + foo(x) = sum(inner, fill(x, 10, 10)) end - job, _ = Native.create_job(mod.foo, (Float64,); validate=false) + job, _ = Native.create_job(mod.foo, (Float64,); validate = false) JuliaContext() do ctx # shouldn't segfault ir, meta = GPUCompiler.compile(:llvm, job) meth = only(methods(mod.foo, (Float64,))) - mis = filter(mi->mi.def == meth, keys(meta.compiled)) + mis = filter(mi -> mi.def == meth, keys(meta.compiled)) @test length(mis) == 1 inner_methods = filter(keys(meta.compiled)) do mi mi.def in methods(mod.inner) && - mi.specTypes == Tuple{typeof(mod.inner), Float64} + mi.specTypes == Tuple{typeof(mod.inner), Float64} end @test length(inner_methods) == 1 end @@ -82,8 +82,8 @@ end @testset "cached compilation" begin mod = @eval module $(gensym()) - @noinline child(i) = i - kernel(i) = child(i)+1 + @noinline child(i) = i + kernel(i) = child(i) + 1 end # smoke test @@ -95,7 +95,7 @@ end end # basic redefinition - @eval mod kernel(i) = child(i)+2 + @eval mod kernel(i) = child(i) + 2 job, _ = Native.create_job(mod.kernel, (Int64,)) @test @filecheck begin check"CHECK-LABEL: define i64 @{{(julia|j)_kernel_[0-9]+}}" @@ -135,7 +135,7 @@ end @test invocations[] == 1 # redefinition - @eval mod kernel(i) = child(i)+3 + @eval mod kernel(i) = child(i) + 3 source = methodinstance(ft, tt, Base.get_world_counter()) @test @filecheck begin check"CHECK-LABEL: define i64 @{{(julia|j)_kernel_[0-9]+}}" @@ -158,7 +158,7 @@ end @test invocations[] == 2 # redefining child functions - @eval mod @noinline child(i) = i+1 + @eval mod @noinline child(i) = i + 1 Base.invokelatest(GPUCompiler.cached_compilation, cache, source, job.config, compiler, linker) @test invocations[] == 3 @@ -167,7 +167,7 @@ end @test invocations[] == 3 # change in configuration - config = CompilerConfig(job.config; name="foobar") + config = CompilerConfig(job.config; name = "foobar") @test @filecheck begin check"CHECK: define i64 @foobar" Base.invokelatest(GPUCompiler.cached_compilation, cache, source, config, compiler, linker) @@ -184,7 +184,7 @@ end end t = @async Base.invokelatest(background, job) wait(c1) # make sure the task has started - @eval mod kernel(i) = child(i)+4 + @eval mod kernel(i) = child(i) + 4 source = methodinstance(ft, tt, Base.get_world_counter()) ir = Base.invokelatest(GPUCompiler.cached_compilation, cache, source, job.config, compiler, linker) @test contains(ir, r"add i64 %\d+, 4") @@ -199,7 +199,7 @@ end @testset "allowed mutable types" begin # when types have no fields, we should always allow them mod = @eval module $(gensym()) - struct Empty end + struct Empty end end Native.code_execution(Returns(nothing), (mod.Empty,)) @@ -213,140 +213,140 @@ end @testset "IR" begin -@testset "basic reflection" begin - mod = @eval module $(gensym()) + @testset "basic reflection" begin + mod = @eval module $(gensym()) valid_kernel() = return invalid_kernel() = 1 - end + end - @test @filecheck begin - # module should contain our function + a generic call wrapper - check"CHECK: @{{(julia|j)_valid_kernel_[0-9]+}}" - Native.code_llvm(mod.valid_kernel, Tuple{}; optimize=false, dump_module=true) - end + @test @filecheck begin + # module should contain our function + a generic call wrapper + check"CHECK: @{{(julia|j)_valid_kernel_[0-9]+}}" + Native.code_llvm(mod.valid_kernel, Tuple{}; optimize = false, dump_module = true) + end - @test Native.code_llvm(devnull, mod.invalid_kernel, Tuple{}) == nothing - @test_throws KernelError Native.code_llvm(devnull, mod.invalid_kernel, Tuple{}; kernel=true) == nothing -end + @test Native.code_llvm(devnull, mod.invalid_kernel, Tuple{}) == nothing + @test_throws KernelError Native.code_llvm(devnull, mod.invalid_kernel, Tuple{}; kernel = true) == nothing + end -@testset "unbound typevars" begin - mod = @eval module $(gensym()) + @testset "unbound typevars" begin + mod = @eval module $(gensym()) invalid_kernel() where {unbound} = return + end + @test_throws KernelError Native.code_llvm(devnull, mod.invalid_kernel, Tuple{}) end - @test_throws KernelError Native.code_llvm(devnull, mod.invalid_kernel, Tuple{}) -end -@testset "child functions" begin - # we often test using `@noinline sink` child functions, so test whether these survive - mod = @eval module $(gensym()) + @testset "child functions" begin + # we often test using `@noinline sink` child functions, so test whether these survive + mod = @eval module $(gensym()) import ..sink @noinline child(i) = sink(i) parent(i) = child(i) - end + end - @test @filecheck begin - check"CHECK-LABEL: define i64 @{{(julia|j)_parent_[0-9]+}}" - check"CHECK: call{{.*}} i64 @{{(julia|j)_child_[0-9]+}}" - Native.code_llvm(mod.parent, Tuple{Int}) + @test @filecheck begin + check"CHECK-LABEL: define i64 @{{(julia|j)_parent_[0-9]+}}" + check"CHECK: call{{.*}} i64 @{{(julia|j)_child_[0-9]+}}" + Native.code_llvm(mod.parent, Tuple{Int}) + end end -end -@testset "sysimg" begin - # bug: use a system image function - mod = @eval module $(gensym()) - function foobar(a,i) - Base.pointerset(a, 0, mod1(i,10), 8) + @testset "sysimg" begin + # bug: use a system image function + mod = @eval module $(gensym()) + function foobar(a, i) + Base.pointerset(a, 0, mod1(i, 10), 8) + end end - end - @test @filecheck begin - check"CHECK-NOT: jlsys_" - Native.code_llvm(mod.foobar, Tuple{Ptr{Int},Int}) + @test @filecheck begin + check"CHECK-NOT: jlsys_" + Native.code_llvm(mod.foobar, Tuple{Ptr{Int}, Int}) + end end -end -@testset "tracked pointers" begin - mod = @eval module $(gensym()) + @testset "tracked pointers" begin + mod = @eval module $(gensym()) function kernel(a) a[1] = 1 return end - end + end - # this used to throw an LLVM assertion (#223) - Native.code_llvm(devnull, mod.kernel, Tuple{Vector{Int}}; kernel=true) - @test "We did not crash!" != "" -end + # this used to throw an LLVM assertion (#223) + Native.code_llvm(devnull, mod.kernel, Tuple{Vector{Int}}; kernel = true) + @test "We did not crash!" != "" + end -@testset "CUDA.jl#278" begin - # codegen idempotency - # NOTE: this isn't fixed, but surfaces here due to bad inference of checked_sub - # NOTE: with the fix to print_to_string this doesn't error anymore, - # but still have a test to make sure it doesn't regress - Native.code_llvm(devnull, Base.checked_sub, Tuple{Int,Int}; optimize=false) - Native.code_llvm(devnull, Base.checked_sub, Tuple{Int,Int}; optimize=false) + @testset "CUDA.jl#278" begin + # codegen idempotency + # NOTE: this isn't fixed, but surfaces here due to bad inference of checked_sub + # NOTE: with the fix to print_to_string this doesn't error anymore, + # but still have a test to make sure it doesn't regress + Native.code_llvm(devnull, Base.checked_sub, Tuple{Int, Int}; optimize = false) + Native.code_llvm(devnull, Base.checked_sub, Tuple{Int, Int}; optimize = false) - # breaking recursion in print_to_string makes it possible to compile - # even in the presence of the above bug - Native.code_llvm(devnull, Base.print_to_string, Tuple{Int,Int}; optimize=false) + # breaking recursion in print_to_string makes it possible to compile + # even in the presence of the above bug + Native.code_llvm(devnull, Base.print_to_string, Tuple{Int, Int}; optimize = false) - @test "We did not crash!" != "" -end + @test "We did not crash!" != "" + end -@testset "LLVM D32593" begin - mod = @eval module $(gensym()) + @testset "LLVM D32593" begin + mod = @eval module $(gensym()) struct D32593_struct foo::Float32 bar::Float32 end D32593(ptr) = unsafe_load(ptr).foo - end - - Native.code_llvm(devnull, mod.D32593, Tuple{Ptr{mod.D32593_struct}}) - @test "We did not crash!" != "" -end + end -@testset "slow abi" begin - mod = @eval module $(gensym()) - x = 2 - f = () -> x+1 - end - @test @filecheck begin - check"CHECK: define {{.+}} @julia" - check"TYPED: define nonnull {}* @jfptr" - check"OPAQUE: define nonnull ptr @jfptr" - check"CHECK: call {{.+}} @julia" - Native.code_llvm(mod.f, Tuple{}; entry_abi=:func, dump_module=true) + Native.code_llvm(devnull, mod.D32593, Tuple{Ptr{mod.D32593_struct}}) + @test "We did not crash!" != "" end -end -@testset "function entry safepoint emission" begin - @test @filecheck begin - check"CHECK-LABEL: define void @{{(julia|j)_identity_[0-9]+}}" - check"CHECK-NOT: %safepoint" - Native.code_llvm(identity, Tuple{Nothing}; entry_safepoint=false, optimize=false, dump_module=true) + @testset "slow abi" begin + mod = @eval module $(gensym()) + x = 2 + f = () -> x + 1 + end + @test @filecheck begin + check"CHECK: define {{.+}} @julia" + check"TYPED: define nonnull {}* @jfptr" + check"OPAQUE: define nonnull ptr @jfptr" + check"CHECK: call {{.+}} @julia" + Native.code_llvm(mod.f, Tuple{}; entry_abi = :func, dump_module = true) + end end - # XXX: broken by JuliaLang/julia#57010, - # see https://github.com/JuliaLang/julia/pull/57010/files#r2079576894 - if VERSION < v"1.13.0-DEV.533" + @testset "function entry safepoint emission" begin @test @filecheck begin check"CHECK-LABEL: define void @{{(julia|j)_identity_[0-9]+}}" - check"CHECK: %safepoint" - Native.code_llvm(identity, Tuple{Nothing}; entry_safepoint=true, optimize=false, dump_module=true) + check"CHECK-NOT: %safepoint" + Native.code_llvm(identity, Tuple{Nothing}; entry_safepoint = false, optimize = false, dump_module = true) + end + + # XXX: broken by JuliaLang/julia#57010, + # see https://github.com/JuliaLang/julia/pull/57010/files#r2079576894 + if VERSION < v"1.13.0-DEV.533" + @test @filecheck begin + check"CHECK-LABEL: define void @{{(julia|j)_identity_[0-9]+}}" + check"CHECK: %safepoint" + Native.code_llvm(identity, Tuple{Nothing}; entry_safepoint = true, optimize = false, dump_module = true) + end end end -end -@testset "always_inline" begin - # XXX: broken by JuliaLang/julia#51599, see JuliaGPU/GPUCompiler.jl#527. - # yet somehow this works on 1.12? - broken = VERSION >= v"1.13-" + @testset "always_inline" begin + # XXX: broken by JuliaLang/julia#51599, see JuliaGPU/GPUCompiler.jl#527. + # yet somehow this works on 1.12? + broken = VERSION >= v"1.13-" - mod = @eval module $(gensym()) + mod = @eval module $(gensym()) import ..sink - expensive(x) = $(foldl((e, _) -> :($sink($e) + $sink(x)), 1:100; init=:x)) + expensive(x) = $(foldl((e, _) -> :($sink($e) + $sink(x)), 1:100; init = :x)) function g(x) expensive(x) return @@ -355,51 +355,59 @@ end expensive(x) return end - end + end - @test @filecheck begin - check"CHECK: @{{(julia|j)_expensive_[0-9]+}}" - Native.code_llvm(mod.g, Tuple{Int64}; dump_module=true, kernel=true) - end + @test @filecheck begin + check"CHECK: @{{(julia|j)_expensive_[0-9]+}}" + Native.code_llvm(mod.g, Tuple{Int64}; dump_module = true, kernel = true) + end + + @test @filecheck( + begin + check"CHECK-NOT: @{{(julia|j)_expensive_[0-9]+}}" + Native.code_llvm(mod.g, Tuple{Int64}; dump_module = true, kernel = true, always_inline = true) + end + ) broken = broken - @test @filecheck(begin - check"CHECK-NOT: @{{(julia|j)_expensive_[0-9]+}}" - Native.code_llvm(mod.g, Tuple{Int64}; dump_module=true, kernel=true, always_inline=true) - end) broken=broken + @test @filecheck begin + check"CHECK: @{{(julia|j)_expensive_[0-9]+}}" + Native.code_llvm(mod.h, Tuple{Int64}; dump_module = true, kernel = true) + end - @test @filecheck begin - check"CHECK: @{{(julia|j)_expensive_[0-9]+}}" - Native.code_llvm(mod.h, Tuple{Int64}; dump_module=true, kernel=true) + @test @filecheck( + begin + check"CHECK-NOT: @{{(julia|j)_expensive_[0-9]+}}" + Native.code_llvm(mod.h, Tuple{Int64}; dump_module = true, kernel = true, always_inline = true) + end + ) broken = broken end - @test @filecheck(begin - check"CHECK-NOT: @{{(julia|j)_expensive_[0-9]+}}" - Native.code_llvm(mod.h, Tuple{Int64}; dump_module=true, kernel=true, always_inline=true) - end) broken=broken -end - -@testset "function attributes" begin - mod = @eval module $(gensym()) + @testset "function attributes" begin + mod = @eval module $(gensym()) @inline function convergent_barrier() - Base.llvmcall((""" - declare void @barrier() #1 + Base.llvmcall( + ( + """ + declare void @barrier() #1 - define void @entry() #0 { - call void @barrier() - ret void - } + define void @entry() #0 { + call void @barrier() + ret void + } - attributes #0 = { alwaysinline } - attributes #1 = { convergent }""", "entry"), - Nothing, Tuple{}) + attributes #0 = { alwaysinline } + attributes #1 = { convergent }""", "entry", + ), + Nothing, Tuple{} + ) + end end - end - @test @filecheck begin - check"CHECK: attributes #{{.}} = { convergent }" - Native.code_llvm(mod.convergent_barrier, Tuple{}; dump_module=true, raw=true) + @test @filecheck begin + check"CHECK: attributes #{{.}} = { convergent }" + Native.code_llvm(mod.convergent_barrier, Tuple{}; dump_module = true, raw = true) + end end -end end @@ -407,33 +415,33 @@ end @testset "assembly" begin -@testset "basic reflection" begin - mod = @eval module $(gensym()) + @testset "basic reflection" begin + mod = @eval module $(gensym()) valid_kernel() = return invalid_kernel() = 1 - end + end - @test Native.code_native(devnull, mod.valid_kernel, Tuple{}) == nothing - @test Native.code_native(devnull, mod.invalid_kernel, Tuple{}) == nothing - @test_throws KernelError Native.code_native(devnull, mod.invalid_kernel, Tuple{}; kernel=true) -end + @test Native.code_native(devnull, mod.valid_kernel, Tuple{}) == nothing + @test Native.code_native(devnull, mod.invalid_kernel, Tuple{}) == nothing + @test_throws KernelError Native.code_native(devnull, mod.invalid_kernel, Tuple{}; kernel = true) + end -@testset "idempotency" begin - # bug: generate code twice for the same kernel (jl_to_ptx wasn't idempotent) - mod = @eval module $(gensym()) + @testset "idempotency" begin + # bug: generate code twice for the same kernel (jl_to_ptx wasn't idempotent) + mod = @eval module $(gensym()) kernel() = return - end - Native.code_native(devnull, mod.kernel, Tuple{}) - Native.code_native(devnull, mod.kernel, Tuple{}) + end + Native.code_native(devnull, mod.kernel, Tuple{}) + Native.code_native(devnull, mod.kernel, Tuple{}) - @test "We did not crash!" != "" -end + @test "We did not crash!" != "" + end -@testset "compile for host after gpu" begin - # issue #11: re-using host functions after GPU compilation - mod = @eval module $(gensym()) + @testset "compile for host after gpu" begin + # issue #11: re-using host functions after GPU compilation + mod = @eval module $(gensym()) import ..sink - @noinline child(i) = sink(i+1) + @noinline child(i) = sink(i + 1) function fromhost() child(10) @@ -443,11 +451,11 @@ end child(10) return end - end + end - Native.code_native(devnull, mod.fromptx, Tuple{}) - @test mod.fromhost() == 11 -end + Native.code_native(devnull, mod.fromptx, Tuple{}) + @test mod.fromhost() == 11 + end end @@ -456,118 +464,137 @@ end @testset "errors" begin -@testset "non-isbits arguments" begin - mod = @eval module $(gensym()) + @testset "non-isbits arguments" begin + mod = @eval module $(gensym()) import ..sink - foobar(i) = (sink(unsafe_trunc(Int,i)); return) - end + foobar(i) = (sink(unsafe_trunc(Int, i)); return) + end - @test_throws_message(KernelError, - Native.code_execution(mod.foobar, Tuple{BigInt})) do msg - occursin("passing non-bitstype argument", msg) && - occursin("BigInt", msg) - end + @test_throws_message( + KernelError, + Native.code_execution(mod.foobar, Tuple{BigInt}) + ) do msg + occursin("passing non-bitstype argument", msg) && + occursin("BigInt", msg) + end - # test that we get information about fields and reason why something is not isbits - mod = @eval module $(gensym()) + # test that we get information about fields and reason why something is not isbits + mod = @eval module $(gensym()) struct CleverType{T} x::T end Base.unsafe_trunc(::Type{Int}, x::CleverType) = unsafe_trunc(Int, x.x) - foobar(i) = (sink(unsafe_trunc(Int,i)); return) - end - @test_throws_message(KernelError, - Native.code_execution(mod.foobar, Tuple{mod.CleverType{BigInt}})) do msg - occursin("passing non-bitstype argument", msg) && - occursin("CleverType", msg) && - occursin("BigInt", msg) + foobar(i) = (sink(unsafe_trunc(Int, i)); return) + end + @test_throws_message( + KernelError, + Native.code_execution(mod.foobar, Tuple{mod.CleverType{BigInt}}) + ) do msg + occursin("passing non-bitstype argument", msg) && + occursin("CleverType", msg) && + occursin("BigInt", msg) + end end -end -@testset "invalid LLVM IR" begin - mod = @eval module $(gensym()) + @testset "invalid LLVM IR" begin + mod = @eval module $(gensym()) foobar(i) = println(i) - end + end - @test_throws_message(InvalidIRError, - Native.code_execution(mod.foobar, Tuple{Int})) do msg - occursin("invalid LLVM IR", msg) && - (occursin(GPUCompiler.RUNTIME_FUNCTION, msg) || - occursin(GPUCompiler.UNKNOWN_FUNCTION, msg) || - occursin(GPUCompiler.DYNAMIC_CALL, msg)) && - occursin("[1] println", msg) && - occursin("[2] foobar", msg) + @test_throws_message( + InvalidIRError, + Native.code_execution(mod.foobar, Tuple{Int}) + ) do msg + occursin("invalid LLVM IR", msg) && + ( + occursin(GPUCompiler.RUNTIME_FUNCTION, msg) || + occursin(GPUCompiler.UNKNOWN_FUNCTION, msg) || + occursin(GPUCompiler.DYNAMIC_CALL, msg) + ) && + occursin("[1] println", msg) && + occursin("[2] foobar", msg) + end end -end -@testset "invalid LLVM IR (ccall)" begin - mod = @eval module $(gensym()) + @testset "invalid LLVM IR (ccall)" begin + mod = @eval module $(gensym()) function foobar(p) unsafe_store!(p, ccall(:time, Cint, ())) return end - end + end - @test_throws_message(InvalidIRError, - Native.code_execution(mod.foobar, Tuple{Ptr{Int}})) do msg - if VERSION >= v"1.11-" - occursin("invalid LLVM IR", msg) && - occursin(GPUCompiler.LAZY_FUNCTION, msg) && - occursin("call to time", msg) && - occursin("[1] foobar", msg) - else - occursin("invalid LLVM IR", msg) && - occursin(GPUCompiler.POINTER_FUNCTION, msg) && - occursin("[1] foobar", msg) + @test_throws_message( + InvalidIRError, + Native.code_execution(mod.foobar, Tuple{Ptr{Int}}) + ) do msg + if VERSION >= v"1.11-" + occursin("invalid LLVM IR", msg) && + ( + occursin(GPUCompiler.LAZY_FUNCTION, msg) || + occursin(GPUCompiler.RUNTIME_FUNCTION, msg) + ) && + occursin("call to time", msg) && + occursin("[1] foobar", msg) + else + occursin("invalid LLVM IR", msg) && + occursin(GPUCompiler.POINTER_FUNCTION, msg) && + occursin("[1] foobar", msg) + end end end -end -@testset "delayed bindings" begin - mod = @eval module $(gensym()) + @testset "delayed bindings" begin + mod = @eval module $(gensym()) function kernel() undefined return end - end + end - @test_throws_message(InvalidIRError, - Native.code_execution(mod.kernel, Tuple{})) do msg - occursin("invalid LLVM IR", msg) && - occursin(GPUCompiler.DELAYED_BINDING, msg) && - occursin(r"use of '.*undefined'", msg) && - occursin("[1] kernel", msg) + @test_throws_message( + InvalidIRError, + Native.code_execution(mod.kernel, Tuple{}) + ) do msg + occursin("invalid LLVM IR", msg) && + occursin(GPUCompiler.DELAYED_BINDING, msg) && + occursin(r"use of '.*undefined'", msg) && + occursin("[1] kernel", msg) + end end -end -@testset "dynamic call (invoke)" begin - mod = @eval module $(gensym()) + @testset "dynamic call (invoke)" begin + mod = @eval module $(gensym()) @noinline nospecialize_child(@nospecialize(i)) = i kernel(a, b) = (unsafe_store!(b, nospecialize_child(a)); return) - end + end - @test_throws_message(InvalidIRError, - Native.code_execution(mod.kernel, Tuple{Int,Ptr{Int}})) do msg - occursin("invalid LLVM IR", msg) && - occursin(GPUCompiler.DYNAMIC_CALL, msg) && - occursin("call to nospecialize_child", msg) && - occursin("[1] kernel", msg) + @test_throws_message( + InvalidIRError, + Native.code_execution(mod.kernel, Tuple{Int, Ptr{Int}}) + ) do msg + occursin("invalid LLVM IR", msg) && + occursin(GPUCompiler.DYNAMIC_CALL, msg) && + occursin("call to nospecialize_child", msg) && + occursin("[1] kernel", msg) + end end -end -@testset "dynamic call (apply)" begin - mod = @eval module $(gensym()) + @testset "dynamic call (apply)" begin + mod = @eval module $(gensym()) func() = println(1) - end + end - @test_throws_message(InvalidIRError, - Native.code_execution(mod.func, Tuple{})) do msg - occursin("invalid LLVM IR", msg) && - occursin(GPUCompiler.DYNAMIC_CALL, msg) && - occursin("call to print", msg) && - occursin("[2] func", msg) + @test_throws_message( + InvalidIRError, + Native.code_execution(mod.func, Tuple{}) + ) do msg + occursin("invalid LLVM IR", msg) && + occursin(GPUCompiler.DYNAMIC_CALL, msg) && + occursin("call to print", msg) && + occursin("[2] func", msg) + end end -end end @@ -577,8 +604,8 @@ end # NOTE: method overrides do not support redefinitions, so we use different kernels mod = @eval module $(gensym()) - kernel() = child() - @inline child() = 0 + kernel() = child() + @inline child() = 0 end @test @filecheck begin @@ -588,14 +615,14 @@ end end mod = @eval module $(gensym()) - using ..GPUCompiler + using ..GPUCompiler - Base.Experimental.@MethodTable(method_table) + Base.Experimental.@MethodTable(method_table) - kernel() = child() - @inline child() = 0 + kernel() = child() + @inline child() = 0 - Base.Experimental.@overlay method_table child() = 1 + Base.Experimental.@overlay method_table child() = 1 end @test @filecheck begin @@ -608,26 +635,26 @@ end @testset "semi-concrete interpretation + overlay methods" begin # issue 366, caused dynamic deispatch mod = @eval module $(gensym()) - using ..GPUCompiler - using StaticArrays + using ..GPUCompiler + using StaticArrays - function kernel(width, height) - xy = SVector{2, Float32}(0.5f0, 0.5f0) - res = SVector{2, UInt32}(width, height) - floor.(UInt32, max.(0f0, xy) .* res) - return - end + function kernel(width, height) + xy = SVector{2, Float32}(0.5f0, 0.5f0) + res = SVector{2, UInt32}(width, height) + floor.(UInt32, max.(0.0f0, xy) .* res) + return + end - Base.Experimental.@MethodTable method_table - Base.Experimental.@overlay method_table Base.isnan(x::Float32) = - (ccall("extern __nv_isnanf", llvmcall, Int32, (Cfloat,), x)) != 0 + Base.Experimental.@MethodTable method_table + Base.Experimental.@overlay method_table Base.isnan(x::Float32) = + (ccall("extern __nv_isnanf", llvmcall, Int32, (Cfloat,), x)) != 0 end @test @filecheck begin check"CHECK-LABEL: @julia_kernel" check"CHECK-NOT: apply_generic" check"CHECK: llvm.floor" - Native.code_llvm(mod.kernel, Tuple{Int, Int}; debuginfo=:none, mod.method_table) + Native.code_llvm(mod.kernel, Tuple{Int, Int}; debuginfo = :none, mod.method_table) end end @@ -636,14 +663,14 @@ end # broken again by JuliaLang/julia#51092, see JuliaGPU/GPUCompiler.jl#506 mod = @eval module $(gensym()) - child(; kwargs...) = return - function parent() - child(; a=1f0, b=1.0) - return - end + child(; kwargs...) = return + function parent() + child(; a = 1.0f0, b = 1.0) + return + end - Base.Experimental.@MethodTable method_table - Base.Experimental.@overlay method_table @noinline Core.throw_inexacterror(f::Symbol, ::Type{T}, val) where {T} = return + Base.Experimental.@MethodTable method_table + Base.Experimental.@overlay method_table @noinline Core.throw_inexacterror(f::Symbol, ::Type{T}, val) where {T} = return end @test @filecheck begin @@ -653,7 +680,7 @@ end check"CHECK-NOT: inttoptr" check"CHECK-NOT: apply_type" check"CHECK: ret void" - Native.code_llvm(mod.parent, Tuple{}; debuginfo=:none, mod.method_table) + Native.code_llvm(mod.parent, Tuple{}; debuginfo = :none, mod.method_table) end end @@ -669,7 +696,7 @@ end return end - ir = sprint(io->Native.code_llvm(io, dkernel, Tuple{Vector{Float64}}; debuginfo=:none)) + ir = sprint(io -> Native.code_llvm(io, dkernel, Tuple{Vector{Float64}}; debuginfo = :none)) @test !occursin("deferred_codegen", ir) @test occursin("call void @julia_kernel", ir) end diff --git a/test/native/precompile.jl b/test/native/precompile.jl index 6fe981a5..2d35e112 100644 --- a/test/native/precompile.jl +++ b/test/native/precompile.jl @@ -1,32 +1,33 @@ - - precompile_test_harness("Inference caching") do load_path # Write out the Native test setup as a micro package create_standalone(load_path, "NativeCompiler", "native.jl") - write(joinpath(load_path, "NativeBackend.jl"), :( - module NativeBackend - import NativeCompiler - using PrecompileTools - - function kernel(A, x) - A[1] = x - return - end + write( + joinpath(load_path, "NativeBackend.jl"), :( + module NativeBackend + import NativeCompiler + using PrecompileTools - let - job, _ = NativeCompiler.Native.create_job(kernel, (Vector{Int}, Int)) - precompile(job) - end + function kernel(A, x) + A[1] = x + return + end - # identity is foreign - @setup_workload begin - job, _ = NativeCompiler.Native.create_job(identity, (Int,)) - @compile_workload begin + let + job, _ = NativeCompiler.Native.create_job(kernel, (Vector{Int}, Int)) precompile(job) end - end - end) |> string) + + # identity is foreign + @setup_workload begin + job, _ = NativeCompiler.Native.create_job(identity, (Int,)) + @compile_workload begin + precompile(job) + end + end + end + ) |> string + ) Base.compilecache(Base.PkgId("NativeBackend")) @eval let @@ -48,7 +49,7 @@ precompile_test_harness("Inference caching") do load_path @test check_presence(kernel_mi, token) # check that identity survived - @test check_presence(identity_mi, token) broken=VERSION>=v"1.12.0-DEV.1268" + @test check_presence(identity_mi, token) broken = VERSION >= v"1.12.0-DEV.1268" GPUCompiler.clear_disk_cache!() @test GPUCompiler.disk_cache_enabled() == false @@ -56,7 +57,7 @@ precompile_test_harness("Inference caching") do load_path GPUCompiler.enable_disk_cache!() @test GPUCompiler.disk_cache_enabled() == true - job, _ = NativeCompiler.Native.create_job(NativeBackend.kernel, (Vector{Int}, Int); validate=false) + job, _ = NativeCompiler.Native.create_job(NativeBackend.kernel, (Vector{Int}, Int); validate = false) @assert job.source == kernel_mi ci = GPUCompiler.ci_cache_lookup(GPUCompiler.ci_cache(job), job.source, job.world, job.world) @assert ci !== nothing diff --git a/test/ptx.jl b/test/ptx.jl index 9e56ee51..4d6f2873 100644 --- a/test/ptx.jl +++ b/test/ptx.jl @@ -1,423 +1,423 @@ @testset "IR" begin -@testset "exceptions" begin - mod = @eval module $(gensym()) + @testset "exceptions" begin + mod = @eval module $(gensym()) foobar() = throw(DivideError()) + end + @test @filecheck begin + check"CHECK-LABEL: define void @{{(julia|j)_foobar_[0-9]+}}" + # plain exceptions should get lowered to a call to the GPU run-time + # not a jl_throw referencing a jl_value_t representing the exception + check"CHECK-NOT: jl_throw" + check"CHECK: gpu_report_exception" + + PTX.code_llvm(mod.foobar, Tuple{}; dump_module = true) + end end - @test @filecheck begin - check"CHECK-LABEL: define void @{{(julia|j)_foobar_[0-9]+}}" - # plain exceptions should get lowered to a call to the GPU run-time - # not a jl_throw referencing a jl_value_t representing the exception - check"CHECK-NOT: jl_throw" - check"CHECK: gpu_report_exception" - - PTX.code_llvm(mod.foobar, Tuple{}; dump_module=true) - end -end -@testset "kernel functions" begin -@testset "kernel argument attributes" begin - mod = @eval module $(gensym()) - kernel(x) = return + @testset "kernel functions" begin + @testset "kernel argument attributes" begin + mod = @eval module $(gensym()) + kernel(x) = return + + struct Aggregate + x::Int + end + end + + @test @filecheck begin + check"CHECK-LABEL: define void @{{(julia|j)_kernel_[0-9]+}}" + check"TYPED-SAME: ({{({ i64 }|\[1 x i64\])}}*" + check"OPAQUE-SAME: (ptr" + PTX.code_llvm(mod.kernel, Tuple{mod.Aggregate}) + end - struct Aggregate - x::Int + @test @filecheck begin + check"CHECK-LABEL: define ptx_kernel void @_Z6kernel9Aggregate" + check"TYPED-NOT: *" + check"OPAQUE-NOT: ptr" + PTX.code_llvm(mod.kernel, Tuple{mod.Aggregate}; kernel = true) + end end - end - @test @filecheck begin - check"CHECK-LABEL: define void @{{(julia|j)_kernel_[0-9]+}}" - check"TYPED-SAME: ({{({ i64 }|\[1 x i64\])}}*" - check"OPAQUE-SAME: (ptr" - PTX.code_llvm(mod.kernel, Tuple{mod.Aggregate}) - end + @testset "property_annotations" begin + mod = @eval module $(gensym()) + kernel() = return + end - @test @filecheck begin - check"CHECK-LABEL: define ptx_kernel void @_Z6kernel9Aggregate" - check"TYPED-NOT: *" - check"OPAQUE-NOT: ptr" - PTX.code_llvm(mod.kernel, Tuple{mod.Aggregate}; kernel=true) - end -end + @test @filecheck begin + check"CHECK-NOT: nvvm.annotations" + PTX.code_llvm(mod.kernel, Tuple{}; dump_module = true) + end -@testset "property_annotations" begin - mod = @eval module $(gensym()) - kernel() = return - end + @test @filecheck begin + check"CHECK-NOT: maxntid" + check"CHECK-NOT: reqntid" + check"CHECK-NOT: minctasm" + check"CHECK-NOT: maxnreg" + check"CHECK: nvvm.annotations" + PTX.code_llvm(mod.kernel, Tuple{}; dump_module = true, kernel = true) + end - @test @filecheck begin - check"CHECK-NOT: nvvm.annotations" - PTX.code_llvm(mod.kernel, Tuple{}; dump_module=true) - end + @test @filecheck begin + check"CHECK: maxntidx\", i32 42" + check"CHECK: maxntidy\", i32 1" + check"CHECK: maxntidz\", i32 1" + PTX.code_llvm(mod.kernel, Tuple{}; dump_module = true, kernel = true, maxthreads = 42) + end - @test @filecheck begin - check"CHECK-NOT: maxntid" - check"CHECK-NOT: reqntid" - check"CHECK-NOT: minctasm" - check"CHECK-NOT: maxnreg" - check"CHECK: nvvm.annotations" - PTX.code_llvm(mod.kernel, Tuple{}; dump_module=true, kernel=true) - end + @test @filecheck begin + check"CHECK: reqntidx\", i32 42" + check"CHECK: reqntidy\", i32 1" + check"CHECK: reqntidz\", i32 1" + PTX.code_llvm(mod.kernel, Tuple{}; dump_module = true, kernel = true, minthreads = 42) + end - @test @filecheck begin - check"CHECK: maxntidx\", i32 42" - check"CHECK: maxntidy\", i32 1" - check"CHECK: maxntidz\", i32 1" - PTX.code_llvm(mod.kernel, Tuple{}; dump_module=true, kernel=true, maxthreads=42) - end + @test @filecheck begin + check"CHECK: minctasm\", i32 42" + PTX.code_llvm(mod.kernel, Tuple{}; dump_module = true, kernel = true, blocks_per_sm = 42) + end - @test @filecheck begin - check"CHECK: reqntidx\", i32 42" - check"CHECK: reqntidy\", i32 1" - check"CHECK: reqntidz\", i32 1" - PTX.code_llvm(mod.kernel, Tuple{}; dump_module=true, kernel=true, minthreads=42) - end + @test @filecheck begin + check"CHECK: maxnreg\", i32 42" + PTX.code_llvm(mod.kernel, Tuple{}; dump_module = true, kernel = true, maxregs = 42) + end + end - @test @filecheck begin - check"CHECK: minctasm\", i32 42" - PTX.code_llvm(mod.kernel, Tuple{}; dump_module=true, kernel=true, blocks_per_sm=42) - end + LLVM.version() >= v"8" && @testset "calling convention" begin + mod = @eval module $(gensym()) + kernel() = return + end - @test @filecheck begin - check"CHECK: maxnreg\", i32 42" - PTX.code_llvm(mod.kernel, Tuple{}; dump_module=true, kernel=true, maxregs=42) - end -end + @test @filecheck begin + check"CHECK-NOT: ptx_kernel" + PTX.code_llvm(mod.kernel, Tuple{}; dump_module = true) + end -LLVM.version() >= v"8" && @testset "calling convention" begin - mod = @eval module $(gensym()) - kernel() = return - end + @test @filecheck begin + check"CHECK: ptx_kernel" + PTX.code_llvm(mod.kernel, Tuple{}; dump_module = true, kernel = true) + end + end - @test @filecheck begin - check"CHECK-NOT: ptx_kernel" - PTX.code_llvm(mod.kernel, Tuple{}; dump_module=true) - end + @testset "kernel state" begin + # state should be passed by value to kernel functions - @test @filecheck begin - check"CHECK: ptx_kernel" - PTX.code_llvm(mod.kernel, Tuple{}; dump_module=true, kernel=true) - end -end + mod = @eval module $(gensym()) + kernel() = return + end -@testset "kernel state" begin - # state should be passed by value to kernel functions + @test @filecheck begin + check"CHECK: @{{(julia|j)_kernel[0-9_]*}}()" + PTX.code_llvm(mod.kernel, Tuple{}) + end - mod = @eval module $(gensym()) - kernel() = return - end + @test @filecheck begin + check"CHECK: @_Z6kernel([1 x i64] %state)" + PTX.code_llvm(mod.kernel, Tuple{}; kernel = true) + end - @test @filecheck begin - check"CHECK: @{{(julia|j)_kernel[0-9_]*}}()" - PTX.code_llvm(mod.kernel, Tuple{}) - end + # state should only passed to device functions that use it - @test @filecheck begin - check"CHECK: @_Z6kernel([1 x i64] %state)" - PTX.code_llvm(mod.kernel, Tuple{}; kernel=true) - end + mod = @eval module $(gensym()) + @noinline child1(ptr) = unsafe_load(ptr) + @noinline function child2() + data = $PTX.kernel_state().data + ptr = reinterpret(Ptr{Int}, data) + unsafe_load(ptr) + end - # state should only passed to device functions that use it + function kernel(ptr) + unsafe_store!(ptr, child1(ptr) + child2()) + return + end + end - mod = @eval module $(gensym()) - @noinline child1(ptr) = unsafe_load(ptr) - @noinline function child2() - data = $PTX.kernel_state().data - ptr = reinterpret(Ptr{Int}, data) - unsafe_load(ptr) + # kernel should take state argument before all else + @test @filecheck begin + check"CHECK-LABEL: define ptx_kernel void @_Z6kernelP5Int64([1 x i64] %state" + check"CHECK-NOT: julia.gpu.state_getter" + PTX.code_llvm(mod.kernel, Tuple{Ptr{Int64}}; kernel = true, dump_module = true) + end + # child1 doesn't use the state + @test @filecheck begin + check"CHECK-LABEL: define{{.*}} i64 @{{(julia|j)_child1_[0-9]+}}" + PTX.code_llvm(mod.kernel, Tuple{Ptr{Int64}}; kernel = true, dump_module = true) + end + # child2 does + @test @filecheck begin + check"CHECK-LABEL: define{{.*}} i64 @{{(julia|j)_child2_[0-9]+}}" + PTX.code_llvm(mod.kernel, Tuple{Ptr{Int64}}; kernel = true, dump_module = true) + end end + end - function kernel(ptr) - unsafe_store!(ptr, child1(ptr) + child2()) + @testset "Mock Enzyme" begin + function kernel(a) + unsafe_store!(a, unsafe_load(a)^2) return end - end - # kernel should take state argument before all else - @test @filecheck begin - check"CHECK-LABEL: define ptx_kernel void @_Z6kernelP5Int64([1 x i64] %state" - check"CHECK-NOT: julia.gpu.state_getter" - PTX.code_llvm(mod.kernel, Tuple{Ptr{Int64}}; kernel=true, dump_module=true) - end - # child1 doesn't use the state - @test @filecheck begin - check"CHECK-LABEL: define{{.*}} i64 @{{(julia|j)_child1_[0-9]+}}" - PTX.code_llvm(mod.kernel, Tuple{Ptr{Int64}}; kernel=true, dump_module=true) - end - # child2 does - @test @filecheck begin - check"CHECK-LABEL: define{{.*}} i64 @{{(julia|j)_child2_[0-9]+}}" - PTX.code_llvm(mod.kernel, Tuple{Ptr{Int64}}; kernel=true, dump_module=true) - end -end -end + function dkernel(a) + ptr = Enzyme.deferred_codegen(typeof(kernel), Tuple{Ptr{Float64}}) + ccall(ptr, Cvoid, (Ptr{Float64},), a) + return + end -@testset "Mock Enzyme" begin - function kernel(a) - unsafe_store!(a, unsafe_load(a)^2) - return + ir = sprint(io -> Native.code_llvm(io, dkernel, Tuple{Ptr{Float64}}; debuginfo = :none)) + @test !occursin("deferred_codegen", ir) + @test occursin("call void @julia_", ir) end - - function dkernel(a) - ptr = Enzyme.deferred_codegen(typeof(kernel), Tuple{Ptr{Float64}}) - ccall(ptr, Cvoid, (Ptr{Float64},), a) - return - end - - ir = sprint(io->Native.code_llvm(io, dkernel, Tuple{Ptr{Float64}}; debuginfo=:none)) - @test !occursin("deferred_codegen", ir) - @test occursin("call void @julia_", ir) -end end ############################################################################################ if :NVPTX in LLVM.backends() -@testset "assembly" begin - -@testset "child functions" begin - # we often test using @noinline child functions, so test whether these survive - # (despite not having side-effects) - - mod = @eval module $(gensym()) - import ..sink - @noinline child(i) = sink(i) - function parent(i) - child(i) - return - end - end - - @test @filecheck begin - check"CHECK-LABEL: .visible .func {{(julia|j)_parent[0-9_]*}}" - check"CHECK: call.uni" - check"CHECK-NEXT: {{(julia|j)_child_}}" - PTX.code_native(mod.parent, Tuple{Int64}) - end -end + @testset "assembly" begin + + @testset "child functions" begin + # we often test using @noinline child functions, so test whether these survive + # (despite not having side-effects) + + mod = @eval module $(gensym()) + import ..sink + @noinline child(i) = sink(i) + function parent(i) + child(i) + return + end + end -@testset "kernel functions" begin - mod = @eval module $(gensym()) - import ..sink - @noinline nonentry(i) = sink(i) - function entry(i) - nonentry(i) - return + @test @filecheck begin + check"CHECK-LABEL: .visible .func {{(julia|j)_parent[0-9_]*}}" + check"CHECK: call.uni" + check"CHECK-NEXT: {{(julia|j)_child_}}" + PTX.code_native(mod.parent, Tuple{Int64}) + end end - end - - @test @filecheck begin - check"CHECK-NOT: .visible .func {{(julia|j)_nonentry}}" - check"CHECK-LABEL: .visible .entry _Z5entry5Int64" - check"CHECK: {{(julia|j)_nonentry}}" - PTX.code_native(mod.entry, Tuple{Int64}; kernel=true, dump_module=true) - end -@testset "property_annotations" begin - @test @filecheck begin - check"CHECK-NOT: maxntid" - PTX.code_native(mod.entry, Tuple{Int64}; kernel=true) - end + @testset "kernel functions" begin + mod = @eval module $(gensym()) + import ..sink + @noinline nonentry(i) = sink(i) + function entry(i) + nonentry(i) + return + end + end - @test @filecheck begin - check"CHECK: .maxntid 42, 1, 1" - PTX.code_native(mod.entry, Tuple{Int64}; kernel=true, maxthreads=42) - end + @test @filecheck begin + check"CHECK-NOT: .visible .func {{(julia|j)_nonentry}}" + check"CHECK-LABEL: .visible .entry _Z5entry5Int64" + check"CHECK: {{(julia|j)_nonentry}}" + PTX.code_native(mod.entry, Tuple{Int64}; kernel = true, dump_module = true) + end - @test @filecheck begin - check"CHECK: .reqntid 42, 1, 1" - PTX.code_native(mod.entry, Tuple{Int64}; kernel=true, minthreads=42) - end + @testset "property_annotations" begin + @test @filecheck begin + check"CHECK-NOT: maxntid" + PTX.code_native(mod.entry, Tuple{Int64}; kernel = true) + end + + @test @filecheck begin + check"CHECK: .maxntid 42, 1, 1" + PTX.code_native(mod.entry, Tuple{Int64}; kernel = true, maxthreads = 42) + end + + @test @filecheck begin + check"CHECK: .reqntid 42, 1, 1" + PTX.code_native(mod.entry, Tuple{Int64}; kernel = true, minthreads = 42) + end + + @test @filecheck begin + check"CHECK: .minnctapersm 42" + PTX.code_native(mod.entry, Tuple{Int64}; kernel = true, blocks_per_sm = 42) + end + + if LLVM.version() >= v"4.0" + @test @filecheck begin + check"CHECK: .maxnreg 42" + PTX.code_native(mod.entry, Tuple{Int64}; kernel = true, maxregs = 42) + end + end + end + end - @test @filecheck begin - check"CHECK: .minnctapersm 42" - PTX.code_native(mod.entry, Tuple{Int64}; kernel=true, blocks_per_sm=42) - end + @testset "child function reuse" begin + # bug: depending on a child function from multiple parents resulted in + # the child only being present once - if LLVM.version() >= v"4.0" - @test @filecheck begin - check"CHECK: .maxnreg 42" - PTX.code_native(mod.entry, Tuple{Int64}; kernel=true, maxregs=42) - end - end -end -end + mod = @eval module $(gensym()) + import ..sink + @noinline child(i) = sink(i) + function parent1(i) + child(i) + return + end + function parent2(i) + child(i + 1) + return + end + end -@testset "child function reuse" begin - # bug: depending on a child function from multiple parents resulted in - # the child only being present once + @test @filecheck begin + check"CHECK: .func {{(julia|j)_child}}" + PTX.code_native(mod.parent1, Tuple{Int}) + end - mod = @eval module $(gensym()) - import ..sink - @noinline child(i) = sink(i) - function parent1(i) - child(i) - return - end - function parent2(i) - child(i+1) - return + @test @filecheck begin + check"CHECK: .func {{(julia|j)_child}}" + PTX.code_native(mod.parent2, Tuple{Int}) + end end - end - @test @filecheck begin - check"CHECK: .func {{(julia|j)_child}}" - PTX.code_native(mod.parent1, Tuple{Int}) - end - - @test @filecheck begin - check"CHECK: .func {{(julia|j)_child}}" - PTX.code_native(mod.parent2, Tuple{Int}) - end -end + @testset "child function reuse bis" begin + # bug: similar, but slightly different issue as above + # in the case of two child functions + + mod = @eval module $(gensym()) + import ..sink + @noinline child1(i) = sink(i) + @noinline child2(i) = sink(i + 1) + function parent1(i) + child1(i) + child2(i) + return + end + function parent2(i) + child1(i + 1) + child2(i + 1) + return + end + end -@testset "child function reuse bis" begin - # bug: similar, but slightly different issue as above - # in the case of two child functions + @test @filecheck begin + check"CHECK-DAG: .func {{(julia|j)_child1}}" + check"CHECK-DAG: .func {{(julia|j)_child2}}" + PTX.code_native(mod.parent1, Tuple{Int}) + end - mod = @eval module $(gensym()) - import ..sink - @noinline child1(i) = sink(i) - @noinline child2(i) = sink(i+1) - function parent1(i) - child1(i) + child2(i) - return - end - function parent2(i) - child1(i+1) + child2(i+1) - return + @test @filecheck begin + check"CHECK-DAG: .func {{(julia|j)_child1}}" + check"CHECK-DAG: .func {{(julia|j)_child2}}" + PTX.code_native(mod.parent2, Tuple{Int}) + end end - end - - @test @filecheck begin - check"CHECK-DAG: .func {{(julia|j)_child1}}" - check"CHECK-DAG: .func {{(julia|j)_child2}}" - PTX.code_native(mod.parent1, Tuple{Int}) - end - @test @filecheck begin - check"CHECK-DAG: .func {{(julia|j)_child1}}" - check"CHECK-DAG: .func {{(julia|j)_child2}}" - PTX.code_native(mod.parent2, Tuple{Int}) - end -end + @testset "indirect sysimg function use" begin + # issue #9: re-using sysimg functions should force recompilation + # (host fldmod1->mod1 throws, so the PTX code shouldn't contain a throw) -@testset "indirect sysimg function use" begin - # issue #9: re-using sysimg functions should force recompilation - # (host fldmod1->mod1 throws, so the PTX code shouldn't contain a throw) + # NOTE: Int32 to test for #49 + mod = @eval module $(gensym()) + function kernel(out) + wid, lane = fldmod1(unsafe_load(out), Int32(32)) + unsafe_store!(out, wid) + return + end + end - # NOTE: Int32 to test for #49 - mod = @eval module $(gensym()) - function kernel(out) - wid, lane = fldmod1(unsafe_load(out), Int32(32)) - unsafe_store!(out, wid) - return + @test @filecheck begin + check"CHECK-LABEL: .visible .func {{(julia|j)_kernel[0-9_]*}}" + check"CHECK-NOT: jl_throw" + check"CHECK-NOT: jl_invoke" + PTX.code_native(mod.kernel, Tuple{Ptr{Int32}}) + end end - end - - @test @filecheck begin - check"CHECK-LABEL: .visible .func {{(julia|j)_kernel[0-9_]*}}" - check"CHECK-NOT: jl_throw" - check"CHECK-NOT: jl_invoke" - PTX.code_native(mod.kernel, Tuple{Ptr{Int32}}) - end -end -@testset "LLVM intrinsics" begin - # issue #13 (a): cannot select trunc - mod = @eval module $(gensym()) - function kernel(x) - unsafe_trunc(Int, x) - return + @testset "LLVM intrinsics" begin + # issue #13 (a): cannot select trunc + mod = @eval module $(gensym()) + function kernel(x) + unsafe_trunc(Int, x) + return + end + end + PTX.code_native(devnull, mod.kernel, Tuple{Float64}) + @test "We did not crash!" != "" end - end - PTX.code_native(devnull, mod.kernel, Tuple{Float64}) - @test "We did not crash!" != "" -end -@testset "exception arguments" begin - mod = @eval module $(gensym()) - function kernel(a) - unsafe_store!(a, trunc(Int, unsafe_load(a))) - return + @testset "exception arguments" begin + mod = @eval module $(gensym()) + function kernel(a) + unsafe_store!(a, trunc(Int, unsafe_load(a))) + return + end + end + PTX.code_native(devnull, mod.kernel, Tuple{Ptr{Float64}}) + @test "We did not crash!" != "" end - end - PTX.code_native(devnull, mod.kernel, Tuple{Ptr{Float64}}) - @test "We did not crash!" != "" -end -@testset "GC and TLS lowering" begin - mod = @eval module $(gensym()) - import ..sink + @testset "GC and TLS lowering" begin + mod = @eval module $(gensym()) + import ..sink - mutable struct PleaseAllocate - y::Csize_t - end + mutable struct PleaseAllocate + y::Csize_t + end - # common pattern in Julia 0.7: outlined throw to avoid a GC frame in the calling code - @noinline function inner(x) - sink(x.y) - nothing - end + # common pattern in Julia 0.7: outlined throw to avoid a GC frame in the calling code + @noinline function inner(x) + sink(x.y) + nothing + end - function kernel(i) - inner(PleaseAllocate(Csize_t(42))) - nothing - end - end + function kernel(i) + inner(PleaseAllocate(Csize_t(42))) + nothing + end + end - @test @filecheck begin - check"CHECK-LABEL: .visible .func {{(julia|j)_kernel[0-9_]*}}" - check"CHECK-NOT: julia.push_gc_frame" - check"CHECK-NOT: julia.pop_gc_frame" - check"CHECK-NOT: julia.get_gc_frame_slot" - check"CHECK-NOT: julia.new_gc_frame" - check"CHECK: gpu_gc_pool_alloc" - PTX.code_native(mod.kernel, Tuple{Int}) - end + @test @filecheck begin + check"CHECK-LABEL: .visible .func {{(julia|j)_kernel[0-9_]*}}" + check"CHECK-NOT: julia.push_gc_frame" + check"CHECK-NOT: julia.pop_gc_frame" + check"CHECK-NOT: julia.get_gc_frame_slot" + check"CHECK-NOT: julia.new_gc_frame" + check"CHECK: gpu_gc_pool_alloc" + PTX.code_native(mod.kernel, Tuple{Int}) + end - # make sure that we can still ellide allocations - mod = @eval module $(gensym()) - function ref_kernel(ptr, i) - data = Ref{Int64}() - data[] = 0 - if i > 1 - data[] = 1 - else - data[] = 2 - end - unsafe_store!(ptr, data[], i) - return nothing + # make sure that we can still ellide allocations + mod = @eval module $(gensym()) + function ref_kernel(ptr, i) + data = Ref{Int64}() + data[] = 0 + if i > 1 + data[] = 1 + else + data[] = 2 + end + unsafe_store!(ptr, data[], i) + return nothing + end + end + + @test @filecheck begin + check"CHECK-LABEL: .visible .func {{(julia|j)_ref_kernel[0-9_]*}}" + check"CHECK-NOT: gpu_gc_pool_alloc" + PTX.code_native(mod.ref_kernel, Tuple{Ptr{Int64}, Int}) + end end - end - @test @filecheck begin - check"CHECK-LABEL: .visible .func {{(julia|j)_ref_kernel[0-9_]*}}" - check"CHECK-NOT: gpu_gc_pool_alloc" - PTX.code_native(mod.ref_kernel, Tuple{Ptr{Int64}, Int}) - end -end + @testset "float boxes" begin + mod = @eval module $(gensym()) + function kernel(a, b) + c = Int32(a) + # the conversion to Int32 may fail, in which case the input Float32 is boxed in + # order to pass it to the @nospecialize exception constructor. we should really + # avoid that (eg. by avoiding @nospecialize, or optimize the unused arguments + # away), but for now the box should just work. + unsafe_store!(b, c) + return + end + end -@testset "float boxes" begin - mod = @eval module $(gensym()) - function kernel(a,b) - c = Int32(a) - # the conversion to Int32 may fail, in which case the input Float32 is boxed in - # order to pass it to the @nospecialize exception constructor. we should really - # avoid that (eg. by avoiding @nospecialize, or optimize the unused arguments - # away), but for now the box should just work. - unsafe_store!(b, c) - return + @test @filecheck begin + check"CHECK-LABEL: define void @{{(julia|j)_kernel_[0-9]+}}" + check"CHECK: jl_box_float32" + PTX.code_llvm(mod.kernel, Tuple{Float32, Ptr{Float32}}) + end + PTX.code_native(devnull, mod.kernel, Tuple{Float32, Ptr{Float32}}) end - end - @test @filecheck begin - check"CHECK-LABEL: define void @{{(julia|j)_kernel_[0-9]+}}" - check"CHECK: jl_box_float32" - PTX.code_llvm(mod.kernel, Tuple{Float32,Ptr{Float32}}) end - PTX.code_native(devnull, mod.kernel, Tuple{Float32,Ptr{Float32}}) -end - -end end # NVPTX in LLVM.backends() diff --git a/test/ptx/precompile.jl b/test/ptx/precompile.jl index b5f980c9..96c13b05 100644 --- a/test/ptx/precompile.jl +++ b/test/ptx/precompile.jl @@ -2,28 +2,31 @@ precompile_test_harness("Inference caching") do load_path # Write out the PTX test helpers as a micro package create_standalone(load_path, "PTXCompiler", "ptx.jl") - write(joinpath(load_path, "PTXBackend.jl"), :( - module PTXBackend - import PTXCompiler - using PrecompileTools - - function kernel() - return - end - - let - job, _ = PTXCompiler.PTX.create_job(kernel, ()) - precompile(job) - end + write( + joinpath(load_path, "PTXBackend.jl"), :( + module PTXBackend + import PTXCompiler + using PrecompileTools + + function kernel() + return + end - # identity is foreign - @setup_workload begin - job, _ = PTXCompiler.PTX.create_job(identity, (Int,)) - @compile_workload begin + let + job, _ = PTXCompiler.PTX.create_job(kernel, ()) precompile(job) end - end - end) |> string) + + # identity is foreign + @setup_workload begin + job, _ = PTXCompiler.PTX.create_job(identity, (Int,)) + @compile_workload begin + precompile(job) + end + end + end + ) |> string + ) Base.compilecache(Base.PkgId("PTXBackend")) @eval let @@ -49,6 +52,6 @@ precompile_test_harness("Inference caching") do load_path @test check_presence(kernel_mi, token) # check that identity survived - @test check_presence(identity_mi, token) broken=VERSION>=v"1.12.0-DEV.1268" + @test check_presence(identity_mi, token) broken = VERSION >= v"1.12.0-DEV.1268" end end diff --git a/test/runtests.jl b/test/runtests.jl index 1f1d705d..5f834059 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -22,10 +22,10 @@ function test_filter(test) if LLVM.is_asserts() && test == "gcn" # XXX: GCN's non-0 stack address space triggers LLVM assertions due to Julia bugs return false - end - if VERSION < v"1.11" && test in ("ptx/precompile", "native/precompile") - return false - end + end + if VERSION < v"1.11" && test in ("ptx/precompile", "native/precompile") + return false + end return true end diff --git a/test/spirv.jl b/test/spirv.jl index e8903140..d51e02fa 100644 --- a/test/spirv.jl +++ b/test/spirv.jl @@ -1,148 +1,160 @@ for backend in (:khronos, :llvm) -@testset "IR" begin - -@testset "kernel functions" begin -@testset "calling convention" begin - mod = @eval module $(gensym()) - kernel() = return - end - - @test @filecheck begin - check"CHECK-NOT: spir_kernel" - SPIRV.code_llvm(mod.kernel, Tuple{}; backend, dump_module=true) - end + @testset "IR" begin + + @testset "kernel functions" begin + @testset "calling convention" begin + mod = @eval module $(gensym()) + kernel() = return + end + + @test @filecheck begin + check"CHECK-NOT: spir_kernel" + SPIRV.code_llvm(mod.kernel, Tuple{}; backend, dump_module = true) + end + + @test @filecheck begin + check"CHECK: spir_kernel" + SPIRV.code_llvm(mod.kernel, Tuple{}; backend, dump_module = true, kernel = true) + end + end - @test @filecheck begin - check"CHECK: spir_kernel" - SPIRV.code_llvm(mod.kernel, Tuple{}; backend, dump_module=true, kernel=true) - end -end + @testset "byval workaround" begin + mod = @eval module $(gensym()) + kernel(x) = return + end -@testset "byval workaround" begin - mod = @eval module $(gensym()) - kernel(x) = return - end + @test @filecheck begin + check"CHECK-LABEL: define void @{{(julia|j)_kernel_[0-9]+}}" + SPIRV.code_llvm(mod.kernel, Tuple{Tuple{Int}}; backend) + end - @test @filecheck begin - check"CHECK-LABEL: define void @{{(julia|j)_kernel_[0-9]+}}" - SPIRV.code_llvm(mod.kernel, Tuple{Tuple{Int}}; backend) - end + @test @filecheck begin + check"CHECK-LABEL: define spir_kernel void @_Z6kernel" + SPIRV.code_llvm(mod.kernel, Tuple{Tuple{Int}}; backend, kernel = true) + end + end - @test @filecheck begin - check"CHECK-LABEL: define spir_kernel void @_Z6kernel" - SPIRV.code_llvm(mod.kernel, Tuple{Tuple{Int}}; backend, kernel=true) - end -end + @testset "byval bug" begin + # byval added alwaysinline, which could conflict with noinline and fail verification + mod = @eval module $(gensym()) + @noinline kernel() = return + end + @test @filecheck begin + check"CHECK-LABEL: define spir_kernel void @_Z6kernel" + SPIRV.code_llvm(mod.kernel, Tuple{}; backend, kernel = true) + end + end + end -@testset "byval bug" begin - # byval added alwaysinline, which could conflict with noinline and fail verification - mod = @eval module $(gensym()) - @noinline kernel() = return - end - @test @filecheck begin - check"CHECK-LABEL: define spir_kernel void @_Z6kernel" - SPIRV.code_llvm(mod.kernel, Tuple{}; backend, kernel=true) - end -end -end + @testset "unsupported type detection" begin + mod = @eval module $(gensym()) + function kernel(ptr, val) + unsafe_store!(ptr, val) + return + end + end -@testset "unsupported type detection" begin - mod = @eval module $(gensym()) - function kernel(ptr, val) - unsafe_store!(ptr, val) - return - end - end + @test @filecheck begin + check"CHECK-LABEL: define void @{{(julia|j)_kernel_[0-9]+}}" + check"CHECK: store half" + SPIRV.code_llvm(mod.kernel, Tuple{Ptr{Float16}, Float16}; backend) + end - @test @filecheck begin - check"CHECK-LABEL: define void @{{(julia|j)_kernel_[0-9]+}}" - check"CHECK: store half" - SPIRV.code_llvm(mod.kernel, Tuple{Ptr{Float16}, Float16}; backend) - end + @test @filecheck begin + check"CHECK-LABEL: define void @{{(julia|j)_kernel_[0-9]+}}" + check"CHECK: store float" + SPIRV.code_llvm(mod.kernel, Tuple{Ptr{Float32}, Float32}; backend) + end - @test @filecheck begin - check"CHECK-LABEL: define void @{{(julia|j)_kernel_[0-9]+}}" - check"CHECK: store float" - SPIRV.code_llvm(mod.kernel, Tuple{Ptr{Float32}, Float32}; backend) - end + @test @filecheck begin + check"CHECK-LABEL: define void @{{(julia|j)_kernel_[0-9]+}}" + check"CHECK: store double" + SPIRV.code_llvm(mod.kernel, Tuple{Ptr{Float64}, Float64}; backend) + end - @test @filecheck begin - check"CHECK-LABEL: define void @{{(julia|j)_kernel_[0-9]+}}" - check"CHECK: store double" - SPIRV.code_llvm(mod.kernel, Tuple{Ptr{Float64}, Float64}; backend) - end + @test_throws_message( + InvalidIRError, + SPIRV.code_execution( + mod.kernel, Tuple{Ptr{Float16}, Float16}; + backend, supports_fp16 = false + ) + ) do msg + occursin("unsupported use of half value", msg) && + occursin("[1] unsafe_store!", msg) && + occursin(r"\[\d+\] kernel", msg) + end - @test_throws_message(InvalidIRError, - SPIRV.code_execution(mod.kernel, Tuple{Ptr{Float16}, Float16}; - backend, supports_fp16=false)) do msg - occursin("unsupported use of half value", msg) && - occursin("[1] unsafe_store!", msg) && - occursin(r"\[\d+\] kernel", msg) - end + @test_throws_message( + InvalidIRError, + SPIRV.code_execution( + mod.kernel, Tuple{Ptr{Float64}, Float64}; + backend, supports_fp64 = false + ) + ) do msg + occursin("unsupported use of double value", msg) && + occursin("[1] unsafe_store!", msg) && + occursin(r"\[\d+\] kernel", msg) + end + end - @test_throws_message(InvalidIRError, - SPIRV.code_execution(mod.kernel, Tuple{Ptr{Float64}, Float64}; - backend, supports_fp64=false)) do msg - occursin("unsupported use of double value", msg) && - occursin("[1] unsafe_store!", msg) && - occursin(r"\[\d+\] kernel", msg) end -end -end + ############################################################################################ -############################################################################################ + @testset "asm" begin -@testset "asm" begin + @testset "trap removal" begin + mod = @eval module $(gensym()) + function kernel(x) + x && error() + return + end + end -@testset "trap removal" begin - mod = @eval module $(gensym()) - function kernel(x) - x && error() - return + @test @filecheck begin + check"CHECK: %_Z6kernel4Bool = OpFunction %void None" + SPIRV.code_native(mod.kernel, Tuple{Bool}; backend, kernel = true) + end end - end - @test @filecheck begin - check"CHECK: %_Z6kernel4Bool = OpFunction %void None" - SPIRV.code_native(mod.kernel, Tuple{Bool}; backend, kernel=true) end -end -end - -@testset "replace i128 allocas" begin - mod = @eval module $(gensym()) + @testset "replace i128 allocas" begin + mod = @eval module $(gensym()) # reimplement some of SIMD.jl struct Vec{N, T} data::NTuple{N, Core.VecElement{T}} end @generated function fadd(x::Vec{N, Float32}, y::Vec{N, Float32}) where {N} quote - Vec(Base.llvmcall($""" - %ret = fadd <$N x float> %0, %1 - ret <$N x float> %ret - """, NTuple{N, Core.VecElement{Float32}}, NTuple{2, NTuple{N, Core.VecElement{Float32}}}, x.data, y.data)) + Vec( + Base.llvmcall( + $""" + %ret = fadd <$N x float> %0, %1 + ret <$N x float> %ret + """, NTuple{N, Core.VecElement{Float32}}, NTuple{2, NTuple{N, Core.VecElement{Float32}}}, x.data, y.data + ) + ) end end kernel(x...) = @noinline fadd(x...) - end + end - @test @filecheck begin - # TODO: should structs of `NTuple{VecElement{T}}` be passed by value instead of sret? - check"CHECK-NOT: i128" - check"CHECK-LABEL: define void @{{(julia|j)_kernel_[0-9]+}}" - @static VERSION >= v"1.12" && check"CHECK: alloca <2 x i64>, align 16" - SPIRV.code_llvm(mod.kernel, NTuple{2, mod.Vec{4, Float32}}; backend, dump_module=true) - end + @test @filecheck begin + # TODO: should structs of `NTuple{VecElement{T}}` be passed by value instead of sret? + check"CHECK-NOT: i128" + check"CHECK-LABEL: define void @{{(julia|j)_kernel_[0-9]+}}" + @static VERSION >= v"1.12" && check"CHECK: alloca <2 x i64>, align 16" + SPIRV.code_llvm(mod.kernel, NTuple{2, mod.Vec{4, Float32}}; backend, dump_module = true) + end - @test @filecheck begin - check"CHECK-NOT: i128" - check"CHECK-LABEL: define void @{{(julia|j)_kernel_[0-9]+}}" - @static VERSION >= v"1.12" && check"CHECK: alloca [2 x <2 x i64>], align 16" - SPIRV.code_llvm(mod.kernel, NTuple{2, mod.Vec{8, Float32}}; backend, dump_module=true) + @test @filecheck begin + check"CHECK-NOT: i128" + check"CHECK-LABEL: define void @{{(julia|j)_kernel_[0-9]+}}" + @static VERSION >= v"1.12" && check"CHECK: alloca [2 x <2 x i64>], align 16" + SPIRV.code_llvm(mod.kernel, NTuple{2, mod.Vec{8, Float32}}; backend, dump_module = true) + end end -end end diff --git a/test/utils.jl b/test/utils.jl index 3b742795..342de901 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -1,10 +1,10 @@ @testset "split_kwargs" begin - kwargs = [:(a=1), :(b=2), :(c=3), :(d=4)] + kwargs = [:(a = 1), :(b = 2), :(c = 3), :(d = 4)] groups = GPUCompiler.split_kwargs(kwargs, [:a], [:b, :c]) @test length(groups) == 3 - @test groups[1] == [:(a=1)] - @test groups[2] == [:(b=2), :(c=3)] - @test groups[3] == [:(d=4)] + @test groups[1] == [:(a = 1)] + @test groups[2] == [:(b = 2), :(c = 3)] + @test groups[3] == [:(d = 4)] end @testset "mangling" begin @@ -28,7 +28,7 @@ end @test mangle(identity, Val{-1}) == "identity(Val<-1>)" @test mangle(identity, Val{Cshort(1)}) == "identity(Val<(short)1>)" @test mangle(identity, Val{1.0}) == "identity(Val<0x1p+0>)" - @test mangle(identity, Val{1f0}) == "identity(Val<0x1p+0f>)" + @test mangle(identity, Val{1.0f0}) == "identity(Val<0x1p+0f>)" @test mangle(identity, Val{'a'}) == "identity(Val<97u>)" @test mangle(identity, Val{'∅'}) == "identity(Val<8709u>)" @@ -50,10 +50,12 @@ end @test mangle(identity, Tuple{Vararg{Int}}) == "identity(Tuple<>)" # many substitutions - @test mangle(identity, Val{1}, Val{2}, Val{3}, Val{4}, Val{5}, Val{6}, Val{7}, Val{8}, - Val{9}, Val{10}, Val{11}, Val{12}, Val{13}, Val{14}, Val{15}, - Val{16}, Val{16}) == - "identity(Val<1>, Val<2>, Val<3>, Val<4>, Val<5>, Val<6>, Val<7>, Val<8>, Val<9>, Val<10>, Val<11>, Val<12>, Val<13>, Val<14>, Val<15>, Val<16>, Val<16>)" + @test mangle( + identity, Val{1}, Val{2}, Val{3}, Val{4}, Val{5}, Val{6}, Val{7}, Val{8}, + Val{9}, Val{10}, Val{11}, Val{12}, Val{13}, Val{14}, Val{15}, + Val{16}, Val{16} + ) == + "identity(Val<1>, Val<2>, Val<3>, Val<4>, Val<5>, Val<6>, Val<7>, Val<8>, Val<9>, Val<10>, Val<11>, Val<12>, Val<13>, Val<14>, Val<15>, Val<16>, Val<16>)" # intertwined substitutions @test mangle( @@ -111,14 +113,14 @@ DoubleStackedMT() = StackedMethodTable(Base.get_world_counter(), OtherMT, LayerM @test isoverlayed(DoubleStackedMT()) == true end - o_sin = findsup(Tuple{typeof(sin), Float64}, OverlayMT()) - s_sin = findsup(Tuple{typeof(sin), Float64}, StackedMT()) + o_sin = findsup(Tuple{typeof(sin), Float64}, OverlayMT()) + s_sin = findsup(Tuple{typeof(sin), Float64}, StackedMT()) ss_sin = findsup(Tuple{typeof(sin), Float64}, DoubleStackedMT()) @test s_sin == o_sin @test ss_sin == o_sin - o_sin = findall(Tuple{typeof(sin), Float64}, OverlayMT()) - s_sin = findall(Tuple{typeof(sin), Float64}, StackedMT()) + o_sin = findall(Tuple{typeof(sin), Float64}, OverlayMT()) + s_sin = findall(Tuple{typeof(sin), Float64}, StackedMT()) ss_sin = findall(Tuple{typeof(sin), Float64}, DoubleStackedMT()) if VERSION >= v"1.11.0-DEV.363" @test o_sin.matches == s_sin.matches @@ -150,8 +152,8 @@ next_world = Base.get_world_counter() @test worlds.min_world > prev_world @test worlds.max_world == typemax(typeof(next_world)) - o_sin = findall(Tuple{typeof(sin), Float64}, OverlayMT()) - s_sin = findall(Tuple{typeof(sin), Float64}, StackedMT()) + o_sin = findall(Tuple{typeof(sin), Float64}, OverlayMT()) + s_sin = findall(Tuple{typeof(sin), Float64}, StackedMT()) ss_sin = findall(Tuple{typeof(sin), Float64}, DoubleStackedMT()) if VERSION >= v"1.11.0-DEV.363" @test o_sin.matches == s_sin.matches