From 2005a04a4089ffa72d53c00474cfd760598a2b8a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ga=C3=ABtan=20LOUNES?= Date: Wed, 4 Dec 2024 14:57:01 +0100 Subject: [PATCH] fix no inputs XLA execute calls. --- src/XLA.jl | 21 ++++++++++++++------- test/ops.jl | 48 ++++++++++++++++++++++-------------------------- 2 files changed, 36 insertions(+), 33 deletions(-) diff --git a/src/XLA.jl b/src/XLA.jl index a8a315a7b6..00420edb84 100644 --- a/src/XLA.jl +++ b/src/XLA.jl @@ -285,14 +285,19 @@ end function execute_ir(N, n_outs, fn) ptr = sizeof(Int) == sizeof(Int64) ? "i64" : "i32" cint = sizeof(Cint) == sizeof(Int64) ? "i64" : "i32" - res = """define { [$n_outs x $ptr], [$n_outs x $ptr], i8 } @f($ptr %exec, [$N x $ptr] %inps, [$N x i8] %donated) alwaysinline { + args = N > 0 ? ", [$N x $ptr] %inps, [$N x i8] %donated" : "" + stores = N > 0 ? """ +store [$N x $ptr] %inps, [$N x $ptr]* %inpa +store [$N x i8] %donated, [$N x i8]* %dona + """ : "" + + res = """define { [$n_outs x $ptr], [$n_outs x $ptr], i8 } @f($ptr %exec $args) alwaysinline { entry: %inpa = alloca [$N x $ptr] + %dona = alloca [$N x i8] %outa = alloca [$n_outs x $ptr] %futpa = alloca [$n_outs x $ptr] - store [$N x $ptr] %inps, [$N x $ptr]* %inpa - %dona = alloca [$N x i8] - store [$N x i8] %donated, [$N x i8]* %dona + $stores %futa = alloca i8 call void inttoptr ($ptr $fn to void ($ptr, $cint, [$N x $ptr]*, [$N x i8]*, $cint, [$n_outs x $ptr]*, i8*, [$n_outs x $ptr]*)*)($ptr %exec, $cint $N, [$N x $ptr]* nocapture readonly %inpa, [$N x i8]* nocapture readonly %dona, $cint $n_outs, [$n_outs x $ptr]* nocapture writeonly %outa, i8* nocapture writeonly %futa, [$n_outs x $ptr]* nocapture writeonly %futpa) %out = load [$n_outs x $ptr], [$n_outs x $ptr]* %outa @@ -323,6 +328,9 @@ end :(AsyncBuffer(Buffer(outputs[$i]), future ? Future(future_res[$i]) : nothing)), ) end + + args_type = N > 0 ? (Ptr{Cvoid}, NTuple{N,Ptr{Cvoid}}, NTuple{N,UInt8}) : (Ptr{Cvoid},) + args = N > 0 ? (:inputs, :donated_args) : () return quote Base.@_inline_meta exec = exec.exec @@ -330,10 +338,9 @@ end outputs, future_res, future = Base.llvmcall( ($ir, "f"), Tuple{NTuple{n_outs,Ptr{Cvoid}},NTuple{n_outs,Ptr{Cvoid}},Bool}, - Tuple{Ptr{Cvoid},NTuple{N,Ptr{Cvoid}},NTuple{N,UInt8}}, + Tuple{$args_type...}, exec, - inputs, - donated_args, + $(args...), ) end return ($(results...),) diff --git a/test/ops.jl b/test/ops.jl index 58c486f16c..5ae5a2fd03 100644 --- a/test/ops.jl +++ b/test/ops.jl @@ -140,13 +140,12 @@ end end @testset "constant" begin - # TODO currently crashes due to #196 - # for x in [[1, 2, 3], [1.1, 2.2, 3.3], [1.1 + 2.2im, 3.3 + 4.4im, 5.5 + 6.6im]] - # @test x ≈ @jit Ops.constant(x) + for x in [[1, 2, 3], [1.1, 2.2, 3.3], [1.1 + 2.2im, 3.3 + 4.4im, 5.5 + 6.6im]] + @test x ≈ @jit Ops.constant(x) - # xscalar = x[1] - # @test xscalar ≈ @jit Ops.constant(xscalar) - # end + xscalar = x[1] + @test xscalar ≈ @jit Ops.constant(xscalar) + end end @testset "cosine" begin @@ -281,22 +280,21 @@ end end @testset "iota" begin - # TODO this crashes. seems like the same error as #196 - # g1(shape) = Ops.iota(Int, shape; iota_dimension=1) - # @test [ - # 0 0 0 0 0 - # 1 1 1 1 1 - # 2 2 2 2 2 - # 3 3 3 3 3 - # ] ≈ @jit g1([4, 5]) - - # g2(shape) = Ops.iota(Int, shape; iota_dimension=2) - # @test [ - # 0 1 2 3 4 - # 0 1 2 3 4 - # 0 1 2 3 4 - # 0 1 2 3 4 - # ] ≈ @jit g2([4, 5]) + g1(shape) = Ops.iota(Int, shape; iota_dimension=1) + @test [ + 0 0 0 0 0 + 1 1 1 1 1 + 2 2 2 2 2 + 3 3 3 3 3 + ] ≈ @jit g1([4, 5]) + + g2(shape) = Ops.iota(Int, shape; iota_dimension=2) + @test [ + 0 1 2 3 4 + 0 1 2 3 4 + 0 1 2 3 4 + 0 1 2 3 4 + ] ≈ @jit g2([4, 5]) end @testset "is_finite" begin @@ -443,8 +441,7 @@ end end @testset "partition_id" begin - # TODO this crashes. seems like the same error as #196 - # @test @jit(Ops.partition_id()) isa ConcreteRNumber{UInt32} + @test @jit(Ops.partition_id()) isa ConcreteRNumber{UInt32} end @testset "popcnt" begin @@ -481,8 +478,7 @@ end end @testset "replica_id" begin - # TODO this crashes. seems like the same error as #196 - # @test @jit(Ops.partition_id()) isa ConcreteRNumber{UInt32} + @test @jit(Ops.partition_id()) isa ConcreteRNumber{UInt32} end @testset "reshape" begin