From 951283edf644ba5b08ecfb597b378631fc1edaaf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Galy-Fajou?= Date: Tue, 25 Apr 2023 17:30:25 +0200 Subject: [PATCH 1/6] Add fix for _map on multiple Fill --- src/util/chainrules.jl | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/src/util/chainrules.jl b/src/util/chainrules.jl index 0d2bec6..7bee9db 100644 --- a/src/util/chainrules.jl +++ b/src/util/chainrules.jl @@ -181,15 +181,20 @@ function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(_ma return Fill(y_el, axes(x)), _map_Fill_rrule end -function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(_map), f, x::Fill, y::Fill) - z_el, back = ChainRulesCore.rrule_via_ad(config, f, x.value, y.value) +# Somehow needed to avoid the _map -> map indirection +function _map(f, xs::Fill...) + all(==(axes(first(xs))), axes.(xs)) || error("All axes should be the same") + Fill(_map(f, FillArrays.getindex_value.(xs)...), axes(first(xs))) +end + +function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(_map), f, xs::Fill...) + z_el, back = ChainRulesCore.rrule_via_ad(config, f, FillArrays.getindex_value.(xs)...) function _map_Fill_rrule(Δ) - Δf, Δx_el, Δy_el = back(unthunk(Δ).value) - return NoTangent(), Δf, Fill(Δx_el, axes(x)), Fill(Δy_el, axes(x)) + Δf, Δxs_el... = back(unthunk(Δ).value) + return NoTangent(), Δf, Fill.(Δxs_el, axes.(xs))... end - return Fill(z_el, axes(x)), _map_Fill_rrule + return Fill(z_el, axes(first(xs))), _map_Fill_rrule end - ### Same thing for `StructArray` From 59f3a2f7c9d28aa40a3199b51f90dbbf0bc5ad9f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Galy-Fajou?= Date: Tue, 25 Apr 2023 17:31:27 +0200 Subject: [PATCH 2/6] Add tests --- test/util/chainrules.jl | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/test/util/chainrules.jl b/test/util/chainrules.jl index 2fd8267..0d71611 100644 --- a/test/util/chainrules.jl +++ b/test/util/chainrules.jl @@ -80,13 +80,17 @@ include("../test_util.jl") test_rrule(_map, x -> 2.0 * x, x; check_inferred=false) test_rrule(ZygoteRuleConfig(), (x,a)-> _map(x -> x * a, x), x, 2.0; check_inferred=false, rrule_f=rrule_via_ad) end - @testset "_map(f, x1::Fill, x2::Fill)" begin + @testset "_map(f, x::Fill....)" begin x1 = Fill(randn(3, 4), 3) x2 = Fill(randn(3, 4), 3) + x3 = Fill(randn(3, 4), 3) @test _map(+, x1, x2) == _map(+, collect(x1), collect(x2)) test_rrule(_map, +, x1, x2; check_inferred=true) + @test _map(+, x1, x2, x3) == _map(+, collect(x1), collect(x2), collect(x3)) + test_rrule(_map, +, x1, x2, x3; check_inferred=true) + fsin(x, y) = sin.(x .* y) test_rrule(_map, fsin, x1, x2; check_inferred=false) From 735bec75c689a520da51623e9ec12d29fec3a207 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Galy-Fajou?= Date: Tue, 25 Apr 2023 17:31:39 +0200 Subject: [PATCH 3/6] Patch bump --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 18f0776..cc71dee 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "TemporalGPs" uuid = "e155a3c4-0841-43e1-8b83-a0e4f03cc18f" authors = ["willtebbutt and contributors"] -version = "0.6.3" +version = "0.6.4" [deps] AbstractGPs = "99985d1d-32ba-4be9-9821-2ec096f28918" From 55491b1527668b165762598583db1755f7ac7a2a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Galy-Fajou?= Date: Tue, 25 Apr 2023 17:36:39 +0200 Subject: [PATCH 4/6] Remove tests on examples --- .github/workflows/examples.yml | 51 ---------------------------------- test/runtests.jl | 21 +------------- 2 files changed, 1 insertion(+), 71 deletions(-) delete mode 100644 .github/workflows/examples.yml diff --git a/.github/workflows/examples.yml b/.github/workflows/examples.yml deleted file mode 100644 index 2509b2d..0000000 --- a/.github/workflows/examples.yml +++ /dev/null @@ -1,51 +0,0 @@ -name: Examples -on: - push: - branches: - - master - pull_request: - branches: - - master - -concurrency: - # Skip intermediate builds: always. - # Cancel intermediate builds: only if it is a pull request build. - group: ${{ github.workflow }}-${{ github.ref }} - cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} - -jobs: - examples: - name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }} - runs-on: ${{ matrix.os }} - strategy: - fail-fast: false - matrix: - version: - - '1' - os: - - ubuntu-latest - arch: - - x64 - steps: - - uses: actions/checkout@v3 - - uses: julia-actions/setup-julia@v1 - with: - version: ${{ matrix.version }} - arch: ${{ matrix.arch }} - - uses: actions/cache@v1 - env: - cache-name: cache-artifacts - with: - path: ~/.julia/artifacts - key: ${{ runner.os }}-test-${{ env.cache-name }}-${{ hashFiles('**/Project.toml') }} - restore-keys: | - ${{ runner.os }}-test-${{ env.cache-name }}- - ${{ runner.os }}-test- - ${{ runner.os }}- - - uses: julia-actions/julia-buildpkg@v1 - - uses: julia-actions/julia-runtest@v1 - with: - file: lcov.info - coverage: false - env: - GROUP: examples diff --git a/test/runtests.jl b/test/runtests.jl index c5ce381..cd4ddd2 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -98,23 +98,4 @@ if OUTER_GROUP == "test" || OUTER_GROUP == "all" end end end -end - -# Run the examples. -if GROUP == "examples" - - using Pkg - pkgpath = joinpath(@__DIR__, "..") - Pkg.activate(joinpath(pkgpath, "examples")) - Pkg.develop(path=pkgpath) - Pkg.resolve() - Pkg.instantiate() - - include(joinpath(pkgpath, "examples", "exact_time_inference.jl")) - include(joinpath(pkgpath, "examples", "exact_time_learning.jl")) - include(joinpath(pkgpath, "examples", "exact_space_time_inference.jl")) - include(joinpath(pkgpath, "examples", "exact_space_time_learning.jl")) - include(joinpath(pkgpath, "examples", "approx_space_time_inference.jl")) - include(joinpath(pkgpath, "examples", "approx_space_time_learning.jl")) - include(joinpath(pkgpath, "examples", "augmented_inference.jl")) -end +end \ No newline at end of file From d00ec59dcff789342d6e1a5b5800e242729684cd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Galy-Fajou?= Date: Tue, 25 Apr 2023 17:36:53 +0200 Subject: [PATCH 5/6] Revert "Remove tests on examples" This reverts commit 55491b1527668b165762598583db1755f7ac7a2a. --- .github/workflows/examples.yml | 51 ++++++++++++++++++++++++++++++++++ test/runtests.jl | 21 +++++++++++++- 2 files changed, 71 insertions(+), 1 deletion(-) create mode 100644 .github/workflows/examples.yml diff --git a/.github/workflows/examples.yml b/.github/workflows/examples.yml new file mode 100644 index 0000000..2509b2d --- /dev/null +++ b/.github/workflows/examples.yml @@ -0,0 +1,51 @@ +name: Examples +on: + push: + branches: + - master + pull_request: + branches: + - master + +concurrency: + # Skip intermediate builds: always. + # Cancel intermediate builds: only if it is a pull request build. + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} + +jobs: + examples: + name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }} + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + version: + - '1' + os: + - ubuntu-latest + arch: + - x64 + steps: + - uses: actions/checkout@v3 + - uses: julia-actions/setup-julia@v1 + with: + version: ${{ matrix.version }} + arch: ${{ matrix.arch }} + - uses: actions/cache@v1 + env: + cache-name: cache-artifacts + with: + path: ~/.julia/artifacts + key: ${{ runner.os }}-test-${{ env.cache-name }}-${{ hashFiles('**/Project.toml') }} + restore-keys: | + ${{ runner.os }}-test-${{ env.cache-name }}- + ${{ runner.os }}-test- + ${{ runner.os }}- + - uses: julia-actions/julia-buildpkg@v1 + - uses: julia-actions/julia-runtest@v1 + with: + file: lcov.info + coverage: false + env: + GROUP: examples diff --git a/test/runtests.jl b/test/runtests.jl index cd4ddd2..c5ce381 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -98,4 +98,23 @@ if OUTER_GROUP == "test" || OUTER_GROUP == "all" end end end -end \ No newline at end of file +end + +# Run the examples. +if GROUP == "examples" + + using Pkg + pkgpath = joinpath(@__DIR__, "..") + Pkg.activate(joinpath(pkgpath, "examples")) + Pkg.develop(path=pkgpath) + Pkg.resolve() + Pkg.instantiate() + + include(joinpath(pkgpath, "examples", "exact_time_inference.jl")) + include(joinpath(pkgpath, "examples", "exact_time_learning.jl")) + include(joinpath(pkgpath, "examples", "exact_space_time_inference.jl")) + include(joinpath(pkgpath, "examples", "exact_space_time_learning.jl")) + include(joinpath(pkgpath, "examples", "approx_space_time_inference.jl")) + include(joinpath(pkgpath, "examples", "approx_space_time_learning.jl")) + include(joinpath(pkgpath, "examples", "augmented_inference.jl")) +end From e94f6689698661c9d3998a15a609843d287516c7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Galy-Fajou?= Date: Tue, 25 Apr 2023 17:58:48 +0200 Subject: [PATCH 6/6] Fix function --- src/util/chainrules.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/util/chainrules.jl b/src/util/chainrules.jl index 7bee9db..9b40b25 100644 --- a/src/util/chainrules.jl +++ b/src/util/chainrules.jl @@ -184,7 +184,7 @@ end # Somehow needed to avoid the _map -> map indirection function _map(f, xs::Fill...) all(==(axes(first(xs))), axes.(xs)) || error("All axes should be the same") - Fill(_map(f, FillArrays.getindex_value.(xs)...), axes(first(xs))) + Fill(f(FillArrays.getindex_value.(xs)...), axes(first(xs))) end function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(_map), f, xs::Fill...)