-
Couldn't load subscription status.
- Fork 33
fix: handle aos for mul #441
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
We still need to support linear solves before we can support the example: 1-element ExceptionStack:
TypeError: non-boolean (Reactant.TracedRNumber{Bool}) used in boolean context
Stacktrace:
[1] findnext
@ ./array.jl:2352 [inlined]
[2] findfirst
@ ./array.jl:2403 [inlined]
[3] ldiv!
@ ~/.julia/juliaup/julia-1.11.2+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/diagonal.jl:490 [inlined]
[4] ldiv!(none::Matrix{Reactant.TracedRNumber{Float64}}, none::Diagonal{Reactant.TracedRNumber{Float64}, Vector{Reactant.TracedRNumber{Float64}}}, none::Adjoint{Reactant.TracedRNumber{Float64}, Matrix{Reactant.TracedRNumber{Float64}}})
@ Reactant ./<missing>:0
[5] getproperty
@ ./Base.jl:49 [inlined]
[6] ldiv!
@ ~/.julia/juliaup/julia-1.11.2+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/diagonal.jl:484 [inlined]
[7] call_with_reactant(::typeof(ldiv!), ::Matrix{Reactant.TracedRNumber{Float64}}, ::Diagonal{Reactant.TracedRNumber{Float64}, Vector{Reactant.TracedRNumber{Float64}}}, ::Adjoint{Reactant.TracedRNumber{Float64}, Matrix{Reactant.TracedRNumber{Float64}}})
@ Reactant /mnt/software/lux/Reactant.jl/src/utils.jl:0
[8] \
@ ~/.julia/juliaup/julia-1.11.2+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/diagonal.jl:478 [inlined]
[9] \
@ ~/.julia/juliaup/julia-1.11.2+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/generic.jl:1124 [inlined]
[10] \(none::Adjoint{Reactant.TracedRNumber{Float64}, Matrix{Reactant.TracedRNumber{Float64}}}, none::Adjoint{Reactant.TracedRNumber{Float64}, Matrix{Reactant.TracedRNumber{Float64}}})
@ Reactant ./<missing>:0
[11] getproperty
@ ./Base.jl:49 [inlined]
[12] size
@ ~/.julia/juliaup/julia-1.11.2+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/adjtrans.jl:326 [inlined]
[13] \
@ ~/.julia/juliaup/julia-1.11.2+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/generic.jl:1120 [inlined]
[14] call_with_reactant(::typeof(\), ::Adjoint{Reactant.TracedRNumber{Float64}, Matrix{Reactant.TracedRNumber{Float64}}}, ::Adjoint{Reactant.TracedRNumber{Float64}, Matrix{Reactant.TracedRNumber{Float64}}})
@ Reactant /mnt/software/lux/Reactant.jl/src/utils.jl:0
[15] /
@ ~/.julia/juliaup/julia-1.11.2+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/generic.jl:1164 [inlined]
[16] /(none::Matrix{Reactant.TracedRNumber{Float64}}, none::Matrix{Reactant.TracedRNumber{Float64}})
@ Reactant ./<missing>:0
[17] size
@ ./array.jl:191 [inlined]
[18] /
@ ~/.julia/juliaup/julia-1.11.2+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/generic.jl:1163 [inlined]
[19] call_with_reactant(::typeof(/), ::Matrix{Reactant.TracedRNumber{Float64}}, ::Matrix{Reactant.TracedRNumber{Float64}})
@ Reactant /mnt/software/lux/Reactant.jl/src/utils.jl:0
[20] correct
@ /mnt/.julia/packages/GaussianDistributions/glMRi/src/GaussianDistributions.jl:180 [inlined]
[21] correct(none::Gaussian{Vector{Reactant.TracedRNumber{Float64}}, Reactant.TracedRArray{Float64, 2}}, none::Gaussian{Vector{Float64}, Diagonal{Reactant.TracedRNumber{Float64}, Vector{Reactant.TracedRNumber{Float64}}}}, none::Matrix{Reactant.TracedRNumber{Float64}})
@ Reactant ./<missing>:0
[22] getproperty
@ ./Base.jl:49 [inlined]
[23] pair
@ /mnt/.julia/packages/GaussianDistributions/glMRi/src/GaussianDistributions.jl:66 [inlined]
[24] correct
@ /mnt/.julia/packages/GaussianDistributions/glMRi/src/GaussianDistributions.jl:174 [inlined]
[25] call_with_reactant(::typeof(GaussianDistributions.correct), ::Gaussian{Vector{Reactant.TracedRNumber{Float64}}, Reactant.TracedRArray{Float64, 2}}, ::Gaussian{Vector{Float64}, Diagonal{Reactant.TracedRNumber{Float64}, Vector{Reactant.TracedRNumber{Float64}}}}, ::Matrix{Reactant.TracedRNumber{Float64}})
@ Reactant /mnt/software/lux/Reactant.jl/src/utils.jl:0
[26] kalman_filter
@ /mnt/software/lux/Reactant.jl/envs/gd/ex.jl:44 [inlined]
[27] kalman_filter(none::LinearGaussianModel{Reactant.TracedRNumber{Float64}, LinearGaussianProcess{Reactant.TracedRNumber{Float64}, Matrix{Reactant.TracedRNumber{Float64}}, Diagonal{Reactant.TracedRNumber{Float64}, Reactant.TracedRArray{Float64, 1}}}, LinearGaussianProcess{Reactant.TracedRNumber{Float64}, Matrix{Reactant.TracedRNumber{Float64}}, Diagonal{Reactant.TracedRNumber{Float64}, Vector{Reactant.TracedRNumber{Float64}}}}}, none::Gaussian{Vector{Reactant.TracedRNumber{Float64}}, Matrix{Reactant.TracedRNumber{Float64}}}, none::Vector{Float64})
@ Reactant ./<missing>:0
[28] getproperty
@ ./Base.jl:49 [inlined]
[29] kalman_filter
@ /mnt/software/lux/Reactant.jl/envs/gd/ex.jl:33 [inlined]
[30] call_with_reactant(::typeof(kalman_filter), ::LinearGaussianModel{Reactant.TracedRNumber{Float64}, LinearGaussianProcess{Reactant.TracedRNumber{Float64}, Matrix{Reactant.TracedRNumber{Float64}}, Diagonal{Reactant.TracedRNumber{Float64}, Reactant.TracedRArray{Float64, 1}}}, LinearGaussianProcess{Reactant.TracedRNumber{Float64}, Matrix{Reactant.TracedRNumber{Float64}}, Diagonal{Reactant.TracedRNumber{Float64}, Vector{Reactant.TracedRNumber{Float64}}}}}, ::Gaussian{Vector{Reactant.TracedRNumber{Float64}}, Matrix{Reactant.TracedRNumber{Float64}}}, ::Vector{Float64})
@ Reactant /mnt/software/lux/Reactant.jl/src/utils.jl:0
[31] logℓ
@ /mnt/software/lux/Reactant.jl/envs/gd/ex.jl:69 [inlined]
[32] logℓ(none::Reactant.TracedRArray{Float64, 1}, none::Vector{Float64})
@ Reactant ./<missing>:0
[33] logℓ
@ /mnt/software/lux/Reactant.jl/envs/gd/ex.jl:67 [inlined]
[34] call_with_reactant(::typeof(logℓ), ::Reactant.TracedRArray{Float64, 1}, ::Vector{Float64})
@ Reactant /mnt/software/lux/Reactant.jl/src/utils.jl:0
[35] (::Reactant.TracedUtils.var"#8#18"{Bool, Bool, typeof(logℓ), Tuple{ConcreteRArray{Float64, 1}, Vector{Float64}}, Vector{Union{ReactantCore.MissingTracedValue, Reactant.TracedRArray, Reactant.TracedRNumber}}, Tuple{Reactant.TracedRArray{Float64, 1}, Vector{Float64}}})()
@ Reactant.TracedUtils /mnt/software/lux/Reactant.jl/src/TracedUtils.jl:182
[36] block!(f::Reactant.TracedUtils.var"#8#18"{Bool, Bool, typeof(logℓ), Tuple{ConcreteRArray{Float64, 1}, Vector{Float64}}, Vector{Union{ReactantCore.MissingTracedValue, Reactant.TracedRArray, Reactant.TracedRNumber}}, Tuple{Reactant.TracedRArray{Float64, 1}, Vector{Float64}}}, blk::Reactant.MLIR.IR.Block)
@ Reactant.MLIR.IR /mnt/software/lux/Reactant.jl/src/mlir/IR/Block.jl:201
[37] make_mlir_fn(f::Function, args::Tuple{ConcreteRArray{Float64, 1}, Vector{Float64}}, kwargs::Tuple{}, name::String, concretein::Bool; toscalar::Bool, return_dialect::Symbol, no_args_in_result::Bool, construct_function_without_args::Bool, do_transpose::Bool)
@ Reactant.TracedUtils /mnt/software/lux/Reactant.jl/src/TracedUtils.jl:169
[38] make_mlir_fn
@ /mnt/software/lux/Reactant.jl/src/TracedUtils.jl:86 [inlined]
[39] #10
@ /mnt/software/lux/Reactant.jl/src/Compiler.jl:319 [inlined]
[40] block!(f::Reactant.Compiler.var"#10#15"{typeof(logℓ), Tuple{ConcreteRArray{Float64, 1}, Vector{Float64}}}, blk::Reactant.MLIR.IR.Block)
@ Reactant.MLIR.IR /mnt/software/lux/Reactant.jl/src/mlir/IR/Block.jl:201
[41] #9
@ /mnt/software/lux/Reactant.jl/src/Compiler.jl:318 [inlined]
[42] mmodule!(f::Reactant.Compiler.var"#9#14"{Reactant.MLIR.IR.Module, typeof(logℓ), Tuple{ConcreteRArray{Float64, 1}, Vector{Float64}}}, blk::Reactant.MLIR.IR.Module)
@ Reactant.MLIR.IR /mnt/software/lux/Reactant.jl/src/mlir/IR/Module.jl:92
[43] compile_mlir!(mod::Reactant.MLIR.IR.Module, f::Function, args::Tuple{ConcreteRArray{Float64, 1}, Vector{Float64}}; optimize::Bool)
@ Reactant.Compiler /mnt/software/lux/Reactant.jl/src/Compiler.jl:315
[44] compile_mlir!
@ /mnt/software/lux/Reactant.jl/src/Compiler.jl:314 [inlined]
[45] #6
@ /mnt/software/lux/Reactant.jl/src/Compiler.jl:305 [inlined]
[46] context!(f::Reactant.Compiler.var"#6#7"{@Kwargs{optimize::Bool}, typeof(logℓ), Tuple{ConcreteRArray{Float64, 1}, Vector{Float64}}}, ctx::Reactant.MLIR.IR.Context)
@ Reactant.MLIR.IR /mnt/software/lux/Reactant.jl/src/mlir/IR/Context.jl:76
[47] compile_mlir(f::Function, args::Tuple{ConcreteRArray{Float64, 1}, Vector{Float64}}; kwargs::@Kwargs{optimize::Bool})
@ Reactant.Compiler /mnt/software/lux/Reactant.jl/src/Compiler.jl:303
[48] top-level scope
@ /mnt/software/lux/Reactant.jl/src/Compiler.jl:536
[49] top-level scope
@ none:1
[50] eval
@ ./boot.jl:430 [inlined]
[51] eval
@ ./Base.jl:130 [inlined]
[52] repleval(m::Module, code::Expr, ::String)
@ VSCodeServer ~/.vscode/extensions/julialang.language-julia-1.127.2/scripts/packages/VSCodeServer/src/repl.jl:229
[53] #112
@ ~/.vscode/extensions/julialang.language-julia-1.127.2/scripts/packages/VSCodeServer/src/repl.jl:192 [inlined]
[54] with_logstate(f::VSCodeServer.var"#112#114"{Module, Expr, REPL.LineEditREPL, REPL.LineEdit.Prompt}, logstate::Base.CoreLogging.LogState)
@ Base.CoreLogging ./logging/logging.jl:522
[55] with_logger
@ ./logging/logging.jl:632 [inlined]
[56] (::VSCodeServer.var"#111#113"{Module, Expr, REPL.LineEditREPL, REPL.LineEdit.Prompt})()
@ VSCodeServer ~/.vscode/extensions/julialang.language-julia-1.127.2/scripts/packages/VSCodeServer/src/repl.jl:193
[57] #invokelatest#2
@ ./essentials.jl:1055 [inlined]
[58] invokelatest(::Any)
@ Base ./essentials.jl:1052 |
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
src/Overlay.jl
Outdated
| @reactant_overlay @noinline function LinearAlgebra.mul!( | ||
| C::$cT, A::$aT, B::$bT, α::Number, β::Number | ||
| ) | ||
| C = aos_to_soa(C) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This won’t work for C because we will create a new array instead of updating the existing matrix of traces in place.
so if c is a matrix of traces we need to after doing the actual mul set all the values of the original C
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM with minor comment about C
* Fix mul overload * fix * fix * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * fix: handle aos for mul (#441) * fix: handle aos for mul * Update ext/ReactantArrayInterfaceExt.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * revert: incorrect aos_to_soa for C --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update Reactant.jl * Update ReactantArrayInterfaceExt.jl * Update ext/ReactantArrayInterfaceExt.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: Avik Pal <avikpal@mit.edu>
No description provided.