diff --git a/src/combinators.jl b/src/combinators.jl index b313f55ada..42772c4519 100644 --- a/src/combinators.jl +++ b/src/combinators.jl @@ -35,6 +35,9 @@ next(rf::AbstractMultiCastingRF, accs, x) = error( return results end +@inline completebasecase(rf::AbstractMultiCastingRF, accs) = + map((f, a) -> completebasecase(f, a), rf.fs, accs) + @inline combine(rf::AbstractMultiCastingRF, lefts, rights) = map((f, l, r) -> combine(f, l, r), rf.fs, lefts, rights) diff --git a/test/test_adhocrf.jl b/test/test_adhocrf.jl index da0e9d559f..dc2f674526 100644 --- a/test/test_adhocrf.jl +++ b/test/test_adhocrf.jl @@ -57,7 +57,12 @@ end @test foldl_basecase(rf, start(rf, Init)::MVector, 1:10)::SVector == ones(10) @test foldxl(rf, 1:10)::MVector == ones(10) @test foldxt(rf, 1:10)::SVector == ones(10) + @test foldxt(rf, 1:10; basesize = 1)::SVector == ones(10) # @test foldxd(rf, 1:10)::SVector == ones(10) # TODO: test this + + @test foldxt(TeeRF(rf, rf), 1:10)::NTuple{2,SVector} == (ones(10), ones(10)) + @test foldxt(TeeRF(rf, rf), 1:10; basesize = 1)::NTuple{2,SVector} == + (ones(10), ones(10)) end getoninit(rf) = rf.oninit