Skip to content

Conversation

@AntonOresten
Copy link
Contributor

Closes #360

This PR leverages broadcast_to_size to define outer repeating (the default for Base.repeat) of traced arrays.

I'm new here, so I may have missed something, but this looks promising:

julia> @code_hlo repeat(rand(2,3) |> Reactant.to_rarray, 5, 7, 11)
module {
  func.func @main(%arg0: tensor<3x2xf64>) -> tensor<11x21x10xf64> {
    %0 = stablehlo.reshape %arg0 : (tensor<3x2xf64>) -> tensor<1x1x1x3x1x2xf64>
    %1 = stablehlo.transpose %0, dims = [5, 4, 3, 2, 1, 0] : (tensor<1x1x1x3x1x2xf64>) -> tensor<2x1x3x1x1x1xf64>
    %2 = stablehlo.broadcast_in_dim %1, dims = [0, 1, 2, 3, 4, 5] : (tensor<2x1x3x1x1x1xf64>) -> tensor<2x5x3x7x1x11xf64>
    %3 = stablehlo.transpose %2, dims = [5, 4, 3, 2, 1, 0] : (tensor<2x5x3x7x1x11xf64>) -> tensor<11x1x7x3x5x2xf64>
    %4 = stablehlo.reshape %3 : (tensor<11x1x7x3x5x2xf64>) -> tensor<11x21x10xf64>
    return %4 : tensor<11x21x10xf64>
  }
}

Here's a testset (wasn't sure where to put it) that takes a few cases and compares the results to the generic method from Base, which passes for me.

@testset "repeat" begin
    for A_size in [(2,), (2,3), (2,3,4), (2,3,4,5)]
        for counts in [(), (1,), (2,), (2,1), (1,2), (2,2), (2,2,2), (1,1,1,1,1)]
            A = rand(A_size...)
            A_ra = Reactant.to_rarray(A)
            @test (@jit repeat(A_ra, counts...)) == repeat(A, counts...)
        end
    end
end

@wsmoses
Copy link
Member

wsmoses commented Dec 11, 2024

perhaps in tests/basic.jl? regardless, I'll let @avik-pal review being more familiar with repeat semantics than I

@wsmoses wsmoses requested a review from avik-pal December 11, 2024 01:10
@avik-pal avik-pal merged commit 311498b into EnzymeAD:main Dec 12, 2024
17 of 37 checks passed
jumerckx added a commit to jumerckx/Reactant.jl that referenced this pull request Dec 16, 2024
commit 65e9976
Author: William Moses <gh@wsmoses.com>
Date:   Sat Dec 14 14:05:03 2024 -0600

    Interp2 (EnzymeAD#365)

    * WIP: kernels

    * more files

    * fix

    * wip

    * wqtmp

    * wip

    * inc

    * continuing

    * wip

    * more work

    * inf rec

    * fix

    * overload working

    * continuing

    * continuing

    * push

    * fix `call_with_reactant_generator` for Julia 1.11 (EnzymeAD#359)

    * conversion

    * continuing

    * Cleanup

    * Apply suggestions from code review

    Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

    * Delete test/cuda.jl

    * fixup

    * Apply suggestions from code review

    Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

    * fix apply

    * indep of change

    * minor fix in name

    * Update utils.jl

    * Interp take 2

    * continuing adentures

    * delcode

    * fix

    * tmp

    * make

    * fix

    * cleanup

    * continuing

    * more working

    * further simplify

    * fx

    * more improvements

    * minus show

    * less prints

    * even fewer

    * confusion

    * tmp

    * force clean

    * force oc

    * clean

    * Rewrite

    * fixup

    * fix

    * fix

    * fix

    * fixup

    * fix

    * wip

    * safe prints

    * fix

    * fix

    * stackoverflow

    * cleanup

    * dyindex

    * rt

    * continue

    * clean

    * fix

    * fix

    * fix

    * fix

    * fixup

    * fix

    * fix

    * capture oc

    * compile perf

    * v1.11 fix

    * other way 'round

    * formatting

    ---------

    Co-authored-by: William Moses <wsmoses@cyclops.juliacomputing.io>
    Co-authored-by: jumerckx <31353884+jumerckx@users.noreply.github.com>
    Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
    Co-authored-by: jumerckx <julesmerckx12@gmail.com>

commit 73899f5
Author: Avik Pal <avikpal@mit.edu>
Date:   Sat Dec 14 14:58:47 2024 +0530

    fix: include files if they end with .jl (EnzymeAD#377)

commit 9f96c09
Author: Sergio Sánchez Ramírez <15837247+mofeing@users.noreply.github.com>
Date:   Fri Dec 13 23:12:43 2024 +0100

    Run CI on aarch64 (EnzymeAD#350)

    * Run CI on aarch64

    * use julia pipeline for aarch64-linux

    * fix var

    * exclude aarch64-linux jobs from github ci

commit b56e661
Author: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Date:   Fri Dec 13 08:58:51 2024 +0530

    chore: format code (EnzymeAD#371)

    Co-authored-by: mofeing <15837247+mofeing@users.noreply.github.com>

commit 311498b
Author: Anton Oresten <antonoresten@gmail.com>
Date:   Thu Dec 12 05:41:39 2024 +0100

    feat: define outer `repeat` method for `TracedRArray` (EnzymeAD#361)

    * Add repeat method

    * Add repeat tests

    * Update test/basic.jl

    * Update src/TracedRArray.jl

commit 8b90501
Author: Avik Pal <avikpal@mit.edu>
Date:   Thu Dec 12 10:11:00 2024 +0530

    fix: ensure printing of wrapped ConcreteRArrays goes through our show (EnzymeAD#367)

    * fix: ensure printing of wrapped ConcreteRArrays goes through our show

    * fix: allow wrapped arrays in mapreduce

commit ea97be3
Author: Sergio Sánchez Ramírez <sergio.sanchez.ramirez+git@bsc.es>
Date:   Wed Dec 11 22:02:14 2024 +0100

    Ignore versioned manifests
jumerckx added a commit to jumerckx/Reactant.jl that referenced this pull request Dec 16, 2024
commit 65e9976
Author: William Moses <gh@wsmoses.com>
Date:   Sat Dec 14 14:05:03 2024 -0600

    Interp2 (EnzymeAD#365)

    * WIP: kernels

    * more files

    * fix

    * wip

    * wqtmp

    * wip

    * inc

    * continuing

    * wip

    * more work

    * inf rec

    * fix

    * overload working

    * continuing

    * continuing

    * push

    * fix `call_with_reactant_generator` for Julia 1.11 (EnzymeAD#359)

    * conversion

    * continuing

    * Cleanup

    * Apply suggestions from code review

    Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

    * Delete test/cuda.jl

    * fixup

    * Apply suggestions from code review

    Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

    * fix apply

    * indep of change

    * minor fix in name

    * Update utils.jl

    * Interp take 2

    * continuing adentures

    * delcode

    * fix

    * tmp

    * make

    * fix

    * cleanup

    * continuing

    * more working

    * further simplify

    * fx

    * more improvements

    * minus show

    * less prints

    * even fewer

    * confusion

    * tmp

    * force clean

    * force oc

    * clean

    * Rewrite

    * fixup

    * fix

    * fix

    * fix

    * fixup

    * fix

    * wip

    * safe prints

    * fix

    * fix

    * stackoverflow

    * cleanup

    * dyindex

    * rt

    * continue

    * clean

    * fix

    * fix

    * fix

    * fix

    * fixup

    * fix

    * fix

    * capture oc

    * compile perf

    * v1.11 fix

    * other way 'round

    * formatting

    ---------

    Co-authored-by: William Moses <wsmoses@cyclops.juliacomputing.io>
    Co-authored-by: jumerckx <31353884+jumerckx@users.noreply.github.com>
    Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
    Co-authored-by: jumerckx <julesmerckx12@gmail.com>

commit 73899f5
Author: Avik Pal <avikpal@mit.edu>
Date:   Sat Dec 14 14:58:47 2024 +0530

    fix: include files if they end with .jl (EnzymeAD#377)

commit 9f96c09
Author: Sergio Sánchez Ramírez <15837247+mofeing@users.noreply.github.com>
Date:   Fri Dec 13 23:12:43 2024 +0100

    Run CI on aarch64 (EnzymeAD#350)

    * Run CI on aarch64

    * use julia pipeline for aarch64-linux

    * fix var

    * exclude aarch64-linux jobs from github ci

commit b56e661
Author: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Date:   Fri Dec 13 08:58:51 2024 +0530

    chore: format code (EnzymeAD#371)

    Co-authored-by: mofeing <15837247+mofeing@users.noreply.github.com>

commit 311498b
Author: Anton Oresten <antonoresten@gmail.com>
Date:   Thu Dec 12 05:41:39 2024 +0100

    feat: define outer `repeat` method for `TracedRArray` (EnzymeAD#361)

    * Add repeat method

    * Add repeat tests

    * Update test/basic.jl

    * Update src/TracedRArray.jl

commit 8b90501
Author: Avik Pal <avikpal@mit.edu>
Date:   Thu Dec 12 10:11:00 2024 +0530

    fix: ensure printing of wrapped ConcreteRArrays goes through our show (EnzymeAD#367)

    * fix: ensure printing of wrapped ConcreteRArrays goes through our show

    * fix: allow wrapped arrays in mapreduce

commit ea97be3
Author: Sergio Sánchez Ramírez <sergio.sanchez.ramirez+git@bsc.es>
Date:   Wed Dec 11 22:02:14 2024 +0100

    Ignore versioned manifests
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Cannot repeat >2-dimensional arrays

3 participants