diff --git a/src/TracedRange.jl b/src/TracedRange.jl index c432499f73..6da8ccf812 100644 --- a/src/TracedRange.jl +++ b/src/TracedRange.jl @@ -177,4 +177,13 @@ function Base._reshape(parent::TracedUnitRange, dims::Dims) return Base.__reshape((parent, IndexStyle(parent)), dims) end +function (C::Base.Colon)(start::TracedRNumber{T},stop::TracedRNumber{T}) where T + TracedUnitRange(start,stop) +end +function (C::Base.Colon)(start::TracedRNumber{T},stop::T) where T + C(start,TracedRNumber{T}(stop)) +end +function (C::Base.Colon)(start::T,stop::TracedRNumber{T}) where T + C(TracedRNumber{T}(start),stop) +end end diff --git a/test/ranges.jl b/test/ranges.jl new file mode 100644 index 0000000000..949bb00627 --- /dev/null +++ b/test/ranges.jl @@ -0,0 +1,9 @@ +using Reactant,Test + +@testset "ranges" begin + i = Reactant.to_rarray(5,track_numbers=true) + @test Array{Int64}(@jit(1:i)) == collect(1:5) + @test Array{Int64}(@jit(i:10)) == collect(5:10) + j = Reactant.to_rarray(10,track_numbers=true) + @test Array{Int64}(@jit(i:j)) == collect(5:10) +end diff --git a/test/runtests.jl b/test/runtests.jl index 5edaa478e5..f812deee5c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -29,6 +29,7 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all")) @safetestset "Sorting" include("sorting.jl") @safetestset "Shortcuts to MLIR ops" include("ops.jl") @safetestset "Indexing" include("indexing.jl") + @safetestset "Ranges" include("ranges.jl") if !Sys.isapple() @safetestset "Custom Number Types" include("custom_number_types.jl") end