diff --git a/Project.toml b/Project.toml
index 3f714944..76a7ca2d 100644
--- a/Project.toml
+++ b/Project.toml
@@ -8,7 +8,9 @@ Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
 CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
 Cassette = "7057c7e9-c182-5462-911a-8362d720325c"
 InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
+LLVM = "929cbde3-209d-540e-8aea-75f648917ca0"
 MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
+Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
 SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
 StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
 UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
diff --git a/src/KernelAbstractions.jl b/src/KernelAbstractions.jl
index 83b7b8ea..0b44766f 100644
--- a/src/KernelAbstractions.jl
+++ b/src/KernelAbstractions.jl
@@ -1,12 +1,13 @@
 module KernelAbstractions

 export @kernel
-export @Const, @localmem, @private, @uniform, @synchronize, @index, groupsize, @print
+export @Const, @localmem, @private, @uniform, @synchronize, @index, groupsize, @print, @printf
 export Device, GPU, CPU, CUDADevice, Event, MultiEvent, NoneEvent
 export async_copy!


 using MacroTools
+using Printf
 using StaticArrays
 using Cassette
 using Adapt
@@ -28,6 +29,7 @@ and then invoked on the arguments.
 - [`@uniform`](@ref)
 - [`@synchronize`](@ref)
 - [`@print`](@ref)
+- [`@printf`](@ref)

 # Example:

@@ -236,6 +238,32 @@ macro print(items...)
     end
 end

+# When a function with a variable-length argument list is called, the variable
+# arguments are passed using C's old "default argument promotions". These say that
+# types char and short int are automatically promoted to int, and type float is
+# automatically promoted to double. Therefore, varargs functions will never receive
+# arguments of type char, short int, or float.
+
+promote_c_argument(arg) = arg
+promote_c_argument(arg::Cfloat) = Cdouble(arg)
+promote_c_argument(arg::Cchar) = Cint(arg)
+promote_c_argument(arg::Cshort) = Cint(arg)
+
+"""
+    @printf(fmt::String, args...)
+
+This is a unified formatted printf statement; `fmt` must be a string literal because it is captured at macro-expansion time.
+
+# Platform differences
+ - `GPU`: This will forward the format string and arguments to `CUDA._cuprintf`
+ - `CPU`: This will call `Printf.@printf(fmt, args...)`
+"""
+macro printf(fmt::String, args...)
+    fmt_val = Val(Symbol(fmt))
+
+    return :(__printf($fmt_val, $(map(arg -> :(promote_c_argument($arg)), esc.(args))...)))
+end
+
 """
     @index

@@ -452,6 +480,76 @@ end
     end
 end

+# Results in "Conversion of boxed type String is not allowed"
+# @generated function __printf(::Val{fmt}, argspec...) where {fmt}
+#     arg_exprs = [:( argspec[$i] ) for i in 1:length(argspec)]
+#     arg_types = [argspec...]
+
+#     T_void = LLVM.VoidType(LLVM.Interop.JuliaContext())
+#     T_int32 = LLVM.Int32Type(LLVM.Interop.JuliaContext())
+#     T_pint8 = LLVM.PointerType(LLVM.Int8Type(LLVM.Interop.JuliaContext()))
+
+#     # create functions
+#     param_types = LLVMType[convert.(LLVMType, arg_types)...]
+#     llvm_f, _ = create_function(T_int32, param_types)
+#     mod = LLVM.parent(llvm_f)
+#     sfmt = String(fmt)
+#     # generate IR
+#     Builder(LLVM.Interop.JuliaContext()) do builder
+#         entry = BasicBlock(llvm_f, "entry", LLVM.Interop.JuliaContext())
+#         position!(builder, entry)
+
+#         str = globalstring_ptr!(builder, sfmt)
+
+#         # construct and fill args buffer
+#         if isempty(argspec)
+#             buffer = LLVM.PointerNull(T_pint8)
+#         else
+#             argtypes = LLVM.StructType("printf_args", LLVM.Interop.JuliaContext())
+#             elements!(argtypes, param_types)
+
+#             args = alloca!(builder, argtypes)
+#             for (i, param) in enumerate(parameters(llvm_f))
+#                 p = struct_gep!(builder, args, i-1)
+#                 store!(builder, param, p)
+#             end
+
+#             buffer = bitcast!(builder, args, T_pint8)
+#         end
+
+#         # invoke vprintf and return
+#         vprintf_typ = LLVM.FunctionType(T_int32, [T_pint8, T_pint8])
+#         vprintf = LLVM.Function(mod, "vprintf", vprintf_typ)
+#         chars = call!(builder, vprintf, [str, buffer])
+
+#         ret!(builder, chars)
+#     end
+
+#     arg_tuple = Expr(:tuple, arg_exprs...)
+#     call_function(llvm_f, Int32, Tuple{arg_types...}, arg_tuple)
+# end
+
+# Results in "InvalidIRError: compiling kernel
+# gpu_kernel_printf(... Reason: unsupported dynamic
+# function invocation"
+@generated function __printf(::Val{fmt}, items...) where {fmt}
+    str = ""
+    args = []
+
+    for i in 1:length(items)
+        item = :(items[$i])
+        T = items[i]
+        if T <: Val
+            item = QuoteNode(T.parameters[1])
+        end
+        push!(args, item)
+    end
+    sfmt = String(fmt)
+    quote
+        Printf.@printf($sfmt, $(args...))
+    end
+end
+
 ###
 # Backends/Implementation
 ###
diff --git a/src/backends/cpu.jl b/src/backends/cpu.jl
index 559c6b04..c0fd53f7 100644
--- a/src/backends/cpu.jl
+++ b/src/backends/cpu.jl
@@ -208,6 +208,10 @@ end
     __print(items...)
 end

+@inline function Cassette.overdub(ctx::CPUCtx, ::typeof(__printf), fmt, items...)
+    __printf(fmt, items...)
+end
+
 generate_overdubs(CPUCtx)

 # Don't recurse into these functions
diff --git a/src/backends/cuda.jl b/src/backends/cuda.jl
index 3c19312b..62ad64e2 100644
--- a/src/backends/cuda.jl
+++ b/src/backends/cuda.jl
@@ -319,6 +319,14 @@ end
     CUDA._cuprint(args...)
 end

+@inline function Cassette.overdub(ctx::CUDACtx, ::typeof(__printf), fmt, args...)
+    CUDA._cuprintf(Val(fmt), args...)
+end
+
+@inline function Cassette.overdub(ctx::CUDACtx, ::typeof(__printf), ::Val{fmt}, args...) where fmt
+    CUDA._cuprintf(Val(fmt), args...)
+end
+
 ###
 # GPU implementation of const memory
 ###
diff --git a/test/print_test.jl b/test/print_test.jl
index 24356369..f0d31d98 100644
--- a/test/print_test.jl
+++ b/test/print_test.jl
@@ -5,25 +5,51 @@ if has_cuda_gpu()
     CUDA.allowscalar(false)
 end

+struct Foo{A,B} end
+get_name(::Type{T}) where T<:Foo = "Foo"
+
 @kernel function kernel_print()
     I = @index(Global)
     @print("Hello from thread ", I, "!\n")
 end

+@kernel function kernel_printf()
+    I = @index(Global)
+    # @printf("Hello printf %s thread %d! type = %s.\n", "from", I, nameof(Foo))
+    # @print("Hello printf from thread ", I, "!\n")
+    # @printf("Hello printf %s thread %d! type = %s.\n", "from", I, string(nameof(Foo)))
+    @printf("Hello printf %s thread %d! type = %s.\n", "from", I, "Foo")
+    @printf("Hello printf %s thread %d! type = %s.\n", "from", I, get_name(Foo))
+end
+
 function test_print(backend)
     kernel = kernel_print(backend, 4)
-    kernel(ndrange=(4,))
+    kernel(ndrange=(4,))
+end
+
+function test_printf(backend)
+    kernel = kernel_printf(backend, 4)
+    kernel(ndrange=(4,))
 end

 @testset "print test" begin
+    wait(test_print(CPU()))
+    @test true
+
+    wait(test_printf(CPU()))
+    @test true
+
     if has_cuda_gpu()
         wait(test_print(CUDADevice()))
         @test true
+        wait(test_printf(CUDADevice()))
+        @test true
     end

-    wait(test_print(CPU()))
+    @print("Why this should work")
     @test true

-    @print("Why this should work")
+    @printf("Why this should work")
     @test true
 end
+
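Usage sketch (not part of the patch): a minimal CPU-side example of the `@printf` macro added above. The kernel name `hello_printf`, the array `A`, and the workgroup size of 4 are illustrative only; the launch pattern (instantiate with a device and workgroup size, call with `ndrange`, then `wait` on the returned event) mirrors `test/print_test.jl`.

    using KernelAbstractions

    @kernel function hello_printf(A)
        I = @index(Global)
        # The format string must be a literal; per the C-style default argument
        # promotions above, Float32 values are widened to Cdouble before printing.
        @printf("A[%d] = %f\n", I, A[I])
    end

    A = rand(Float32, 4)
    kernel = hello_printf(CPU(), 4)
    wait(kernel(A, ndrange=size(A)))

On a CUDA device the same kernel would be instantiated with `CUDADevice()` and a `CuArray`, in which case the call is routed to `CUDA._cuprintf` by the overdub defined in `src/backends/cuda.jl`.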