Incorrect broadcast_to_size implementation #512

@avik-pal

Description

julia> using Reactant
Precompiling Reactant...
Info Given Reactant was explicitly requested, output will be shown live 
2025-01-11 18:53:43.088241: I external/xla/xla/service/llvm_ir/llvm_command_line_options.cc:50] XLA (re)initializing LLVM with options fingerprint: 13648986965302749079
  1 dependency successfully precompiled in 23 seconds. 55 already precompiled.
  1 dependency had output during precompilation:
┌ Reactant
│  [Output was shown above]
└  
[ Info: Precompiling ReactantStatisticsExt [963ed91e-491b-54ce-bb4b-249dcb1ed2bb] (cache misses: wrong dep version loaded (18))

julia> x_ra = rand(2, 3) |> Reactant.to_rarray
2×3 ConcreteRArray{Float64, 2}:
 0.515395  0.375137  0.155133
 0.82185   0.5739    0.0452939

julia> @code_hlo Reactant.TracedUtils.broadcast_to_size(x_ra, (1, 2, 1, 3))
error: size of operand dimension 0 (2) is not equal to 1 or size of result dimension 0 (1)
ERROR: "failed to run pass manager on module"
Stacktrace:
  [1] run!
    @ /mnt/software/lux/Reactant.jl/src/mlir/IR/Pass.jl:79 [inlined]
  [2] run_pass_pipeline!(mod::Reactant.MLIR.IR.Module, pass_pipeline::String; enable_verifier::Bool)
    @ Reactant.Compiler /mnt/software/lux/Reactant.jl/src/Compiler.jl:320
  [3] run_pass_pipeline!
    @ /mnt/software/lux/Reactant.jl/src/Compiler.jl:315 [inlined]
  [4] compile_mlir!(mod::Reactant.MLIR.IR.Module, f::Function, args::Tuple{ConcreteRArray{Float64, 2}, NTuple{4, Int64}}; optimize::Bool, no_nan::Bool)
    @ Reactant.Compiler /mnt/software/lux/Reactant.jl/src/Compiler.jl:382
  [5] compile_mlir!
    @ /mnt/software/lux/Reactant.jl/src/Compiler.jl:350 [inlined]
  [6] #7
    @ /mnt/software/lux/Reactant.jl/src/Compiler.jl:341 [inlined]
  [7] context!(f::Reactant.Compiler.var"#7#8"{@Kwargs{no_nan::Bool, optimize::Bool}, typeof(Reactant.TracedUtils.broadcast_to_size), Tuple{ConcreteRArray{Float64, 2}, NTuple{4, Int64}}}, ctx::Reactant.MLIR.IR.Context)
    @ Reactant.MLIR.IR /mnt/software/lux/Reactant.jl/src/mlir/IR/Context.jl:76
  [8] compile_mlir(f::Function, args::Tuple{ConcreteRArray{Float64, 2}, NTuple{4, Int64}}; kwargs::@Kwargs{no_nan::Bool, optimize::Bool})
    @ Reactant.Compiler /mnt/software/lux/Reactant.jl/src/Compiler.jl:339
  [9] top-level scope
    @ /mnt/software/lux/Reactant.jl/src/Compiler.jl:591
 [10] top-level scope
    @ none:1
"builtin.module"() ({
  "func.func"() <{function_type = (tensor<3x2xf64>) -> (tensor<3x1x2x1xf64>, tensor<3x2xf64>), sym_name = "main"}> ({
  ^bb0(%arg0: tensor<3x2xf64>):
    %0 = "stablehlo.transpose"(%arg0) <{permutation = array<i64: 1, 0>}> : (tensor<3x2xf64>) -> tensor<2x3xf64>
    %1 = "stablehlo.broadcast_in_dim"(%0) <{broadcast_dimensions = array<i64: 0, 1>}> : (tensor<2x3xf64>) -> tensor<1x2x1x3xf64>
    %2 = "stablehlo.transpose"(%1) <{permutation = array<i64: 3, 2, 1, 0>}> : (tensor<1x2x1x3xf64>) -> tensor<3x1x2x1xf64>
    %3 = "stablehlo.transpose"(%0) <{permutation = array<i64: 1, 0>}> : (tensor<2x3xf64>) -> tensor<3x2xf64>
    "func.return"(%2, %3) : (tensor<3x1x2x1xf64>, tensor<3x2xf64>) -> ()
  }) : () -> ()
}) : () -> ()

The generated stablehlo.broadcast_in_dim keeps broadcast_dimensions = [0, 1], so it maps operand dimension 0 (size 2) onto result dimension 0 (size 1), which is exactly what the verifier rejects. We need to add a permutedims to do this correctly.

xref LuxDL/Lux.jl#954
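
For reference, a minimal sketch of the dimension mapping that would make the broadcast type-check, assuming broadcast_to_size(x, (1, 2, 1, 3)) on a (2, 3) array should place the existing axes at the size-2 and size-3 positions of the target. The helper name broadcast_dims_for is hypothetical and the greedy extent-matching only covers this example; it is not the Reactant implementation, only an illustration that the correct broadcast_dimensions here is [1, 3] (0-based) rather than the [0, 1] emitted above:

# Hypothetical helper (not Reactant's implementation): compute 0-based
# stablehlo broadcast_dimensions by matching each axis of x_size to the next
# target axis of the same extent. The greedy scan is only meant to cover this
# issue's example; a real fix would also have to handle size-1 axes.
function broadcast_dims_for(x_size::Dims, target::Dims)
    bdims = Int[]
    j = 1
    for s in x_size
        # scan forward for the next target axis with a matching extent
        while j <= length(target) && target[j] != s
            j += 1
        end
        j <= length(target) || error("cannot map axes of size $x_size onto $target")
        push!(bdims, j - 1)  # 0-based, as stablehlo.broadcast_in_dim expects
        j += 1
    end
    return bdims
end

broadcast_dims_for((2, 3), (1, 2, 1, 3))  # == [1, 3], not [0, 1]

Equivalently, permuting/reshaping the operand so its axes already sit at result positions 2 and 4 before the broadcast_in_dim would avoid the verifier error.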
