Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion deps/ReactantExtra/make-bindings.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ for file in [
"Gpu.jl",
"Affine.jl",
"TPU.jl",
"Triton.jl"
"Triton.jl",
]
build_file(joinpath(src_dir, "mlir", "Dialects", file))
end
Expand Down
135 changes: 95 additions & 40 deletions ext/ReactantCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ struct LLVMFunc{F,tt}
entry::String
end

function Base.getproperty(f::LLVMFunc{F, tt}, sym::Symbol) where {F, tt}
function Base.getproperty(f::LLVMFunc{F,tt}, sym::Symbol) where {F,tt}
if sym === :fun
f
else
Expand All @@ -235,8 +235,14 @@ end

# TODO in the future we may want to avoid doing a second cufunction compilation
# for computing the thread/block count (or potentially do it ourselves).
@noinline function CUDA.launch_configuration(f::LLVMFunc{F, tt}; shmem::Union{Integer, Base.Callable}=0, max_threads::Integer=0) where {F, tt}
CUDA.launch_configuration(Base.inferencebarrier(CUDA.cufunction)(f.f, Tuple{tt.parameters[2:end]...}).fun; shmem, max_threads)
@noinline function CUDA.launch_configuration(
f::LLVMFunc{F,tt}; shmem::Union{Integer,Base.Callable}=0, max_threads::Integer=0
) where {F,tt}
return CUDA.launch_configuration(
Base.inferencebarrier(CUDA.cufunction)(f.f, Tuple{tt.parameters[2:end]...}).fun;
shmem,
max_threads,
)
end

const GPUCompiler = CUDA.GPUCompiler
Expand Down Expand Up @@ -282,7 +288,12 @@ function compile(job)
entry = GPUCompiler.JuliaContext() do ctx
mod, meta = GPUCompiler.compile(
# :llvm, job; optimize=false, cleanup=false, validate=false, libraries=true
:llvm, job; optimize=false, cleanup=false, validate=false, libraries=false
:llvm,
job;
optimize=false,
cleanup=false,
validate=false,
libraries=false,
# :llvm, job; optimize=false, cleanup=false, validate=true, libraries=false
# :llvm, job; optimize=false, cleanup=false, validate=false, libraries=false
)
Expand Down Expand Up @@ -357,19 +368,21 @@ function link(job, compiled)
end

function to_bytes(x)
sz = sizeof(x)
ref = Ref(x)
GC.@preserve ref begin
ptr = Base.reinterpret(Ptr{UInt8}, Base.unsafe_convert(Ptr{Cvoid}, ref))
vec = Vector{UInt8}(undef, sz)
for i in 1:sz
@inbounds vec[i] = Base.unsafe_load(ptr, i)
end
vec
end
end

function Reactant.make_tracer(seen, @nospecialize(prev::CuTracedArray), @nospecialize(path), mode; kwargs...)
sz = sizeof(x)
ref = Ref(x)
GC.@preserve ref begin
ptr = Base.reinterpret(Ptr{UInt8}, Base.unsafe_convert(Ptr{Cvoid}, ref))
vec = Vector{UInt8}(undef, sz)
for i in 1:sz
@inbounds vec[i] = Base.unsafe_load(ptr, i)
end
vec
end
end

function Reactant.make_tracer(
seen, @nospecialize(prev::CuTracedArray), @nospecialize(path), mode; kwargs...
)
x = Base.unsafe_pointer_to_objref(Base.reinterpret(Ptr{Cvoid}, prev.ptr))::TracedRArray
Reactant.make_tracer(seen, x, path, mode; kwargs...)
return prev
Expand All @@ -388,7 +401,9 @@ function get_field_offset(T::Type, path)
findfirst(==(field), fieldnames(current_type))
end
if field_idx === nothing
error("Field $field not found in type $current_type, fieldnames=$(fieldnames(current_type)) T=$T path=$path")
error(
"Field $field not found in type $current_type, fieldnames=$(fieldnames(current_type)) T=$T path=$path",
)
end

# Add the offset of this field
Expand Down Expand Up @@ -419,7 +434,7 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
rarrays = TracedRArray[]

fname = func.entry

wrapper_tys = MLIR.IR.Type[]
ctx = MLIR.IR.context()
cullvm_ty = MLIR.IR.Type(MLIR.API.mlirLLVMPointerTypeGet(ctx, 1))
Expand All @@ -436,19 +451,23 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
end
push!(wrapper_tys, cullvm_ty)
end

sym_name = String(gensym("call_$fname"))
mod = MLIR.IR.mmodule()
CConv=MLIR.IR.Attribute(MLIR.API.mlirLLVMCConvAttrGet(ctx, MLIR.API.MlirLLVMCConvPTX_Kernel))
CConv = MLIR.IR.Attribute(
MLIR.API.mlirLLVMCConvAttrGet(ctx, MLIR.API.MlirLLVMCConvPTX_Kernel)
)
voidty = MLIR.IR.Type(MLIR.API.mlirLLVMVoidTypeGet(ctx))
wrapftype = MLIR.IR.Type(MLIR.API.mlirLLVMFunctionTypeGet(voidty, length(wrapper_tys), wrapper_tys, false))
wrapftype = MLIR.IR.Type(
MLIR.API.mlirLLVMFunctionTypeGet(voidty, length(wrapper_tys), wrapper_tys, false)
)
wrapfunc = MLIR.IR.block!(MLIR.IR.body(mod)) do
return MLIR.Dialects.llvm.func(;
sym_name,
sym_visibility=MLIR.IR.Attribute("private"),
function_type=wrapftype,
body=MLIR.IR.Region(),
CConv
CConv,
)
end
wrapbody = MLIR.IR.Block(wrapper_tys, [MLIR.IR.Location() for _ in wrapper_tys])
Expand All @@ -459,11 +478,17 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(

symtab = MLIR.IR.SymbolTable(MLIR.IR.Operation(mod))
gpufunc = MLIR.IR.lookup(symtab, fname)
MLIR.IR.attr!(gpufunc, "CConv", MLIR.IR.Attribute(MLIR.API.mlirLLVMCConvAttrGet(ctx, MLIR.API.MlirLLVMCConvC)))
gpu_function_type = MLIR.IR.Type(Reactant.TracedUtils.get_attribute_by_name(gpufunc, "function_type"))
MLIR.IR.attr!(
gpufunc,
"CConv",
MLIR.IR.Attribute(MLIR.API.mlirLLVMCConvAttrGet(ctx, MLIR.API.MlirLLVMCConvC)),
)
gpu_function_type = MLIR.IR.Type(
Reactant.TracedUtils.get_attribute_by_name(gpufunc, "function_type")
)

trueidx = 1
allocs = Union{Tuple{MLIR.IR.Value, MLIR.IR.Type}, Nothing}[]
allocs = Union{Tuple{MLIR.IR.Value,MLIR.IR.Type},Nothing}[]

llvmptr = MLIR.IR.Type(MLIR.API.mlirLLVMPointerTypeGet(ctx, 0))
i8 = MLIR.IR.Type(UInt8)
Expand All @@ -476,18 +501,34 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(

# TODO check for only integer and explicitly non cutraced types
MLIR.IR.block!(wrapbody) do
argty = MLIR.IR.Type(MLIR.API.mlirLLVMFunctionTypeGetInput(gpu_function_type, trueidx-1))
argty = MLIR.IR.Type(
MLIR.API.mlirLLVMFunctionTypeGetInput(gpu_function_type, trueidx - 1)
)
trueidx += 1
c1 = MLIR.IR.result(MLIR.Dialects.llvm.mlir_constant(; res=MLIR.IR.Type(Int64), value=MLIR.IR.Attribute(1)), 1)
alloc = MLIR.IR.result(MLIR.Dialects.llvm.alloca(c1; elem_type=MLIR.IR.Attribute(argty), res=llvmptr), 1)
c1 = MLIR.IR.result(
MLIR.Dialects.llvm.mlir_constant(;
res=MLIR.IR.Type(Int64), value=MLIR.IR.Attribute(1)
),
1,
)
alloc = MLIR.IR.result(
MLIR.Dialects.llvm.alloca(
c1; elem_type=MLIR.IR.Attribute(argty), res=llvmptr
),
1,
)
push!(allocs, (alloc, argty))

sz = sizeof(a)
array_ty = MLIR.IR.Type(MLIR.API.mlirLLVMArrayTypeGet(MLIR.IR.Type(Int8), sz))
cdata = MLIR.IR.result(MLIR.Dialects.llvm.mlir_constant(; res=array_ty, value=MLIR.IR.DenseElementsAttribute(to_bytes(a))), 1)
cdata = MLIR.IR.result(
MLIR.Dialects.llvm.mlir_constant(;
res=array_ty, value=MLIR.IR.DenseElementsAttribute(to_bytes(a))
),
1,
)
MLIR.Dialects.llvm.store(cdata, alloc)
end

end

argidx = 1
Expand All @@ -499,21 +540,30 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
if p[1] !== kernelargsym
continue
end

arg = arg.mlir_data
arg = Reactant.TracedUtils.transpose_val(arg)
push!(restys, MLIR.IR.type(arg))
push!(mlir_args, arg)

# Get the allocation corresponding to which arg we're doing
alloc = allocs[p[2]][1]

# we need to now compute the offset in bytes of the path
julia_arg = allargs[p[2]]

offset = get_field_offset(typeof(julia_arg), p[3:end])
MLIR.IR.block!(wrapbody) do
ptr = MLIR.IR.result(MLIR.Dialects.llvm.getelementptr(alloc, MLIR.IR.Value[], res=llvmptr, elem_type=i8, rawConstantIndices=MLIR.IR.Attribute([Int32(offset)])), 1)
ptr = MLIR.IR.result(
MLIR.Dialects.llvm.getelementptr(
alloc,
MLIR.IR.Value[];
res=llvmptr,
elem_type=i8,
rawConstantIndices=MLIR.IR.Attribute([Int32(offset)]),
),
1,
)
MLIR.Dialects.llvm.store(MLIR.IR.argument(wrapbody, argidx), ptr)
end

Expand All @@ -530,11 +580,11 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
),
),
)

argidx += 1
end
end

MLIR.IR.block!(wrapbody) do
for arg in allocs
if arg === nothing
Expand All @@ -544,7 +594,12 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
argres = MLIR.IR.result(MLIR.Dialects.llvm.load(alloc; res=argty), 1)
push!(wrapargs, argres)
end
MLIR.Dialects.llvm.call(wrapargs, MLIR.IR.Value[]; callee=MLIR.IR.FlatSymbolRefAttribute(Base.String(fname)), op_bundle_sizes=MLIR.IR.Attribute(Int32[]))
MLIR.Dialects.llvm.call(
wrapargs,
MLIR.IR.Value[];
callee=MLIR.IR.FlatSymbolRefAttribute(Base.String(fname)),
op_bundle_sizes=MLIR.IR.Attribute(Int32[]),
)
MLIR.Dialects.llvm.return_(nothing)
end

Expand All @@ -565,7 +620,7 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
mlir_args;
result_0=restys,
fn=MLIR.IR.FlatSymbolRefAttribute(sym_name),
output_operand_aliases=MLIR.IR.Attribute(output_operand_aliases)
output_operand_aliases=MLIR.IR.Attribute(output_operand_aliases),
)

argidx = 1
Expand All @@ -574,7 +629,7 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
continue
end
arg.mlir_data = Reactant.TracedUtils.transpose_val(MLIR.IR.result(call, argidx))
argidx+=1
argidx += 1
end
end

Expand Down
7 changes: 6 additions & 1 deletion src/Compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,12 @@ function optimization_passes(; no_nan::Bool=false)
)
func_passes = join(["canonicalize", "cse", "canonicalize", transform_passes], ",")
return join(
["inline{default-pipeline=canonicalize max-iterations=4}", "libdevice-funcs-raise", func_passes], ','
[
"inline{default-pipeline=canonicalize max-iterations=4}",
"libdevice-funcs-raise",
func_passes,
],
',',
)
end

Expand Down
21 changes: 9 additions & 12 deletions src/mlir/Dialects/Nvvm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,18 +78,15 @@ function barrier(
attributes = NamedAttribute[]
!isnothing(barrierId) && push!(operands, barrierId)
!isnothing(numberOfThreads) && push!(operands, numberOfThreads)
push!(
attributes,
operandsegmentsizes([
if (barrierId == nothing)
0
elseif 1(numberOfThreads == nothing)
0
else
1
end
]),
)
push!(attributes, operandsegmentsizes([
if (barrierId == nothing)
0
elseif 1(numberOfThreads == nothing)
0
else
1
end,
]))

return create_operation(
"nvvm.barrier",
Expand Down
23 changes: 11 additions & 12 deletions src/mlir/Dialects/TPU.jl
Original file line number Diff line number Diff line change
Expand Up @@ -902,18 +902,17 @@ function sem_signal(
attributes = NamedAttribute[]
!isnothing(device_id) && push!(operands, device_id)
!isnothing(core_id) && push!(operands, core_id)
push!(
attributes,
operandsegmentsizes([
1, 1, if (device_id == nothing)
0
elseif 1(core_id == nothing)
0
else
1
end
]),
)
push!(attributes, operandsegmentsizes([
1,
1,
if (device_id == nothing)
0
elseif 1(core_id == nothing)
0
else
1
end,
]))
!isnothing(core_type) && push!(attributes, namedattribute("core_type", core_type))

return create_operation(
Expand Down
34 changes: 17 additions & 17 deletions src/mlir/Dialects/Triton.jl
Original file line number Diff line number Diff line change
Expand Up @@ -482,18 +482,18 @@ function dot_scaled(
]
!isnothing(lhs_scale) && push!(operands, lhs_scale)
!isnothing(rhs_scale) && push!(operands, rhs_scale)
push!(
attributes,
operandsegmentsizes([
1, 1, 1, if (lhs_scale == nothing)
0
elseif 1(rhs_scale == nothing)
0
else
1
end
]),
)
push!(attributes, operandsegmentsizes([
1,
1,
1,
if (lhs_scale == nothing)
0
elseif 1(rhs_scale == nothing)
0
else
1
end,
]))

return create_operation(
"tt.dot_scaled",
Expand Down Expand Up @@ -949,16 +949,16 @@ function load(
attributes = NamedAttribute[]
!isnothing(mask) && push!(operands, mask)
!isnothing(other) && push!(operands, other)
push!(
attributes,
operandsegmentsizes([1, if (mask == nothing)
push!(attributes, operandsegmentsizes([
1,
if (mask == nothing)
0
elseif 1(other == nothing)
0
else
1
end]),
)
end,
]))
!isnothing(result) && push!(op_ty_results, result)
!isnothing(boundaryCheck) &&
push!(attributes, namedattribute("boundaryCheck", boundaryCheck))
Expand Down
Loading