@mofeing mofeing commented Feb 9, 2025

cc @wsmoses

Works except for ifrt_CopyArrayToHostBuffer: it crashes with "Illegal instruction" when called from Julia, but works when used in pjrt-ifrt-test/demo-ifrt.cpp.

Most probably I'm doing something wrong.

julia> using Reactant
AssertionError("Could not find registered platform with name: \"cuda\". Available platform names are: ")

julia> using Reactant.XLA: IFRT

julia> a = ConcreteRArray(rand(2,2))
2×2 ConcreteRArray{Float64, 2, 1, Reactant.Sharding.FinalizedNoSharding}:
 0.625281  0.690309
 0.698304  0.966754

julia> code = @code_hlo sin.(a)
module @reactant_Base.Br... attributes {mhlo.num_partitions = 1 : i64, mhlo.num_replicas = 1 : i64} {
  func.func @main(%arg0: tensor<2x2xf64>) -> tensor<2x2xf64> {
    %0 = stablehlo.sine %arg0 : tensor<2x2xf64>
    return %0 : tensor<2x2xf64>
  }
}


julia> client = IFRT.Client(Reactant.XLA.default_backend[])
Reactant.XLA.IFRT.Client(Ptr{Nothing} @0x0000000004ba6ee0)

julia> exec = IFRT.compile(client, code)
Reactant.XLA.IFRT.LoadedExecutable(Ptr{Nothing} @0x0000000004bae7a0)

julia> a_ifrt = IFRT.Array(client, a.data[1].buffer)
Reactant.XLA.IFRT.Array(Ptr{Nothing} @0x0000000003aff330)

julia> results = IFRT.execute(exec, (a_ifrt.ptr,), (UInt8(false),), Val(1))
(Reactant.XLA.IFRT.Array(Ptr{Nothing} @0x0000000002456d00),)

julia> b = zeros(size(a))
2×2 Matrix{Float64}:
 0.0  0.0
 0.0  0.0

julia> IFRT.CopyArrayToHostBuffer(results[1], pointer(b))
Unreachable reached at 0x71deb97d069b

[2249386] signal 4 (2): Illegal instruction
in expression starting at REPL[10]:1
ifrt_CopyArrayToHostBuffer at /proc/self/cwd/API.cpp:1106
CopyArrayToHostBuffer at /home/mofeing/Reactant.jl/src/xla/IFRT/Array.jl:30
unknown function (ip: 0x71e2480038f9)
jl_apply at /cache/build/builder-demeter6-3/julialang/julia-release-1-dot-11/src/julia.h:2157 [inlined]
do_call at /cache/build/builder-demeter6-3/julialang/julia-release-1-dot-11/src/interpreter.c:126
eval_value at /cache/build/builder-demeter6-3/julialang/julia-release-1-dot-11/src/interpreter.c:223
eval_stmt_value at /cache/build/builder-demeter6-3/julialang/julia-release-1-dot-11/src/interpreter.c:174 [inlined]
eval_body at /cache/build/builder-demeter6-3/julialang/julia-release-1-dot-11/src/interpreter.c:663
jl_interpret_toplevel_thunk at /cache/build/builder-demeter6-3/julialang/julia-release-1-dot-11/src/interpreter.c:821
jl_toplevel_eval_flex at /cache/build/builder-demeter6-3/julialang/julia-release-1-dot-11/src/toplevel.c:943
jl_toplevel_eval_flex at /cache/build/builder-demeter6-3/julialang/julia-release-1-dot-11/src/toplevel.c:886
ijl_toplevel_eval_in at /cache/build/builder-demeter6-3/julialang/julia-release-1-dot-11/src/toplevel.c:994
eval at ./boot.jl:430 [inlined]
eval_user_input at /cache/build/builder-demeter6-3/julialang/julia-release-1-dot-11/usr/share/julia/stdlib/v1.11/REPL/src/REPL.jl:245
repl_backend_loop at /cache/build/builder-demeter6-3/julialang/julia-release-1-dot-11/usr/share/julia/stdlib/v1.11/REPL/src/REPL.jl:342
#start_repl_backend#59 at /cache/build/builder-demeter6-3/julialang/julia-release-1-dot-11/usr/share/julia/stdlib/v1.11/REPL/src/REPL.jl:327
start_repl_backend at /cache/build/builder-demeter6-3/julialang/julia-release-1-dot-11/usr/share/julia/stdlib/v1.11/REPL/src/REPL.jl:324
#run_repl#72 at /cache/build/builder-demeter6-3/julialang/julia-release-1-dot-11/usr/share/julia/stdlib/v1.11/REPL/src/REPL.jl:483
run_repl at /cache/build/builder-demeter6-3/julialang/julia-release-1-dot-11/usr/share/julia/stdlib/v1.11/REPL/src/REPL.jl:469
jfptr_run_repl_10097.1 at /home/mofeing/.julia/juliaup/julia-1.11.3+0.x64.linux.gnu/share/julia/compiled/v1.11/REPL/u0gqU_XvZAg.so (unknown line)
#1150 at ./client.jl:446
jfptr_YY.1150_14693.1 at /home/mofeing/.julia/juliaup/julia-1.11.3+0.x64.linux.gnu/share/julia/compiled/v1.11/REPL/u0gqU_XvZAg.so (unknown line)
jl_apply at /cache/build/builder-demeter6-3/julialang/julia-release-1-dot-11/src/julia.h:2157 [inlined]
jl_f__call_latest at /cache/build/builder-demeter6-3/julialang/julia-release-1-dot-11/src/builtins.c:875
#invokelatest#2 at ./essentials.jl:1055 [inlined]
invokelatest at ./essentials.jl:1052 [inlined]
run_main_repl at ./client.jl:430
repl_main at ./client.jl:567 [inlined]
_start at ./client.jl:541
jfptr__start_73609.1 at /home/mofeing/.julia/juliaup/julia-1.11.3+0.x64.linux.gnu/lib/julia/sys.so (unknown line)
jl_apply at /cache/build/builder-demeter6-3/julialang/julia-release-1-dot-11/src/julia.h:2157 [inlined]
true_main at /cache/build/builder-demeter6-3/julialang/julia-release-1-dot-11/src/jlapi.c:900
jl_repl_entrypoint at /cache/build/builder-demeter6-3/julialang/julia-release-1-dot-11/src/jlapi.c:1059
main at /cache/build/builder-demeter6-3/julialang/julia-release-1-dot-11/cli/loader_exe.c:58
unknown function (ip: 0x71e24942a1c9)
__libc_start_main at /lib/x86_64-linux-gnu/libc.so.6 (unknown line)
unknown function (ip: 0x4010b8)
Allocations: 21279099 (Pool: 21278744; Big: 355); GC: 21
Illegal instruction (core dumped)

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

mofeing commented Feb 9, 2025

@wsmoses I don't think we can avoid virtual destructors: although we create a tsl::RCReference<ifrt::PjRtArray>, LoadedExecutable::Execute returns a tsl::RCReference<ifrt::Array>, and ifrt::Array is a polymorphic base class.
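
To illustrate the point about virtual destructors: when an object is created as a derived type but handed back and released through a reference to its base type, the base needs a virtual destructor for the derived destructor to run. The following is only a minimal sketch. The `Array` and `PjRtArray` structs, `execute_demo`, and the use of `std::unique_ptr` are stand-ins chosen for illustration; they are not the XLA types or the PR's code.

```cpp
#include <cassert>
#include <memory>

// Counts how many times the derived destructor has run (illustration only).
static int derived_destroyed = 0;

// Hypothetical stand-in for a polymorphic base such as ifrt::Array.
struct Array {
  virtual ~Array() = default;  // virtual: deleting via Array* runs the derived destructor
  virtual int ndims() const = 0;
};

// Hypothetical stand-in for a concrete implementation such as ifrt::PjRtArray.
struct PjRtArray final : Array {
  ~PjRtArray() override { ++derived_destroyed; }
  int ndims() const override { return 2; }
};

// Mimics an API that constructs the concrete type but returns the base type.
std::unique_ptr<Array> execute_demo() { return std::make_unique<PjRtArray>(); }
```

Without the `virtual` on `~Array()`, destroying the object through the base pointer would be undefined behavior.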

wsmoses commented Feb 9, 2025

That's okay, we should just hold a tsl::RCReference<ifrt::Array> for all the IFRT data.
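
One way to picture this suggestion is a small wrapper that always holds a reference to the base type and is passed across the C ABI as an opaque pointer. This is a hedged sketch, not the code merged in this PR: `std::shared_ptr` stands in for `tsl::RCReference`, the `Array`/`PjRtArray` structs are toy types, and the `demo_*` function names are hypothetical.

```cpp
#include <cassert>
#include <memory>
#include <utility>

// Toy stand-ins (assumptions) for the polymorphic base and a concrete array.
struct Array {
  virtual ~Array() = default;
  virtual int ndims() const = 0;
};
struct PjRtArray final : Array {
  int ndims() const override { return 2; }
};

// Owns one value (here, one shared reference) so it can live behind a raw pointer.
template <typename T>
class HeldValue {
public:
  explicit HeldValue(T value) : value_(std::move(value)) {}
  T &obj() { return value_; }

private:
  T value_;
};

// Every handle holds a reference to the base type, whatever the concrete type is.
using HeldArray = HeldValue<std::shared_ptr<Array>>;

// Hypothetical C entry points; a foreign caller sees HeldArray* as an opaque pointer.
extern "C" HeldArray *demo_make_array() {
  return new HeldArray(std::make_shared<PjRtArray>());
}
extern "C" int demo_array_ndims(HeldArray *a) { return a->obj()->ndims(); }
extern "C" void demo_free_array(HeldArray *a) { delete a; }
```

Because every handle has the same type, the foreign-function side needs only one pointer type and one free function, regardless of which concrete array produced the value.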

@mofeing mofeing requested a review from wsmoses February 9, 2025 18:38

mofeing commented Feb 9, 2025

@giordano this is ready to be included in the JLL

@giordano

Is it expected this PR is only deleting code in API.cpp?

mofeing commented Feb 10, 2025

> Is it expected this PR is only deleting code in API.cpp?

Nope, this PR removes old, unused (and probably broken) IFRT code in API.cpp and introduces new, simpler, and tested IFRT code.

Reviewing only the new code is fine.

@giordano

Ah, ok, I was looking at the negative diff figure 😅

#include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h"
#include "triton/Dialect/Triton/IR/Dialect.h"

#include "llvm/Support/ExtensibleRTTI.h"

remove this header?

}
#pragma endregion
template<typename T>
struct Holded {

Minor English nit: Holded -> HeldValue

@mofeing

right 🤦

@mofeing mofeing changed the title from "IFRT bindings (round 4)" to "[ReactantExtra] IFRT bindings (round 4)" Feb 10, 2025
@mofeing mofeing merged commit 447ff4f into main Feb 10, 2025
28 of 32 checks passed
@mofeing mofeing deleted the ss/ifrt-take-4 branch February 10, 2025 15:42