Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 14 additions & 7 deletions src/XLA.jl
Original file line number Diff line number Diff line change
Expand Up @@ -285,14 +285,19 @@ end
function execute_ir(N, n_outs, fn)
ptr = sizeof(Int) == sizeof(Int64) ? "i64" : "i32"
cint = sizeof(Cint) == sizeof(Int64) ? "i64" : "i32"
res = """define { [$n_outs x $ptr], [$n_outs x $ptr], i8 } @f($ptr %exec, [$N x $ptr] %inps, [$N x i8] %donated) alwaysinline {
args = N > 0 ? ", [$N x $ptr] %inps, [$N x i8] %donated" : ""
stores = N > 0 ? """
store [$N x $ptr] %inps, [$N x $ptr]* %inpa
store [$N x i8] %donated, [$N x i8]* %dona
""" : ""

res = """define { [$n_outs x $ptr], [$n_outs x $ptr], i8 } @f($ptr %exec $args) alwaysinline {
entry:
%inpa = alloca [$N x $ptr]
%dona = alloca [$N x i8]
%outa = alloca [$n_outs x $ptr]
%futpa = alloca [$n_outs x $ptr]
store [$N x $ptr] %inps, [$N x $ptr]* %inpa
%dona = alloca [$N x i8]
store [$N x i8] %donated, [$N x i8]* %dona
$stores
%futa = alloca i8
call void inttoptr ($ptr $fn to void ($ptr, $cint, [$N x $ptr]*, [$N x i8]*, $cint, [$n_outs x $ptr]*, i8*, [$n_outs x $ptr]*)*)($ptr %exec, $cint $N, [$N x $ptr]* nocapture readonly %inpa, [$N x i8]* nocapture readonly %dona, $cint $n_outs, [$n_outs x $ptr]* nocapture writeonly %outa, i8* nocapture writeonly %futa, [$n_outs x $ptr]* nocapture writeonly %futpa)
%out = load [$n_outs x $ptr], [$n_outs x $ptr]* %outa
Expand Down Expand Up @@ -323,17 +328,19 @@ end
:(AsyncBuffer(Buffer(outputs[$i]), future ? Future(future_res[$i]) : nothing)),
)
end

args_type = N > 0 ? (Ptr{Cvoid}, NTuple{N,Ptr{Cvoid}}, NTuple{N,UInt8}) : (Ptr{Cvoid},)
args = N > 0 ? (:inputs, :donated_args) : ()
return quote
Base.@_inline_meta
exec = exec.exec
GC.@preserve exec begin
outputs, future_res, future = Base.llvmcall(
($ir, "f"),
Tuple{NTuple{n_outs,Ptr{Cvoid}},NTuple{n_outs,Ptr{Cvoid}},Bool},
Tuple{Ptr{Cvoid},NTuple{N,Ptr{Cvoid}},NTuple{N,UInt8}},
Tuple{$args_type...},
exec,
inputs,
donated_args,
$(args...),
)
end
return ($(results...),)
Expand Down
48 changes: 22 additions & 26 deletions test/ops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -140,13 +140,12 @@ end
end

@testset "constant" begin
# TODO currently crashes due to #196
# for x in [[1, 2, 3], [1.1, 2.2, 3.3], [1.1 + 2.2im, 3.3 + 4.4im, 5.5 + 6.6im]]
# @test x ≈ @jit Ops.constant(x)
for x in [[1, 2, 3], [1.1, 2.2, 3.3], [1.1 + 2.2im, 3.3 + 4.4im, 5.5 + 6.6im]]
@test x ≈ @jit Ops.constant(x)

# xscalar = x[1]
# @test xscalar ≈ @jit Ops.constant(xscalar)
# end
xscalar = x[1]
@test xscalar ≈ @jit Ops.constant(xscalar)
end
end

@testset "cosine" begin
Expand Down Expand Up @@ -281,22 +280,21 @@ end
end

@testset "iota" begin
# TODO this crashes. seems like the same error as #196
# g1(shape) = Ops.iota(Int, shape; iota_dimension=1)
# @test [
# 0 0 0 0 0
# 1 1 1 1 1
# 2 2 2 2 2
# 3 3 3 3 3
# ] ≈ @jit g1([4, 5])

# g2(shape) = Ops.iota(Int, shape; iota_dimension=2)
# @test [
# 0 1 2 3 4
# 0 1 2 3 4
# 0 1 2 3 4
# 0 1 2 3 4
# ] ≈ @jit g2([4, 5])
g1(shape) = Ops.iota(Int, shape; iota_dimension=1)
@test [
0 0 0 0 0
1 1 1 1 1
2 2 2 2 2
3 3 3 3 3
] ≈ @jit g1([4, 5])

g2(shape) = Ops.iota(Int, shape; iota_dimension=2)
@test [
0 1 2 3 4
0 1 2 3 4
0 1 2 3 4
0 1 2 3 4
] ≈ @jit g2([4, 5])
end

@testset "is_finite" begin
Expand Down Expand Up @@ -443,8 +441,7 @@ end
end

@testset "partition_id" begin
# TODO this crashes. seems like the same error as #196
# @test @jit(Ops.partition_id()) isa ConcreteRNumber{UInt32}
@test @jit(Ops.partition_id()) isa ConcreteRNumber{UInt32}
end

@testset "popcnt" begin
Expand Down Expand Up @@ -481,8 +478,7 @@ end
end

@testset "replica_id" begin
# TODO this crashes. seems like the same error as #196
# @test @jit(Ops.partition_id()) isa ConcreteRNumber{UInt32}
@test @jit(Ops.partition_id()) isa ConcreteRNumber{UInt32}
end

@testset "reshape" begin
Expand Down
Loading