From a6ac51b26ae2f84653a1984547e0a443df0c58d4 Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace Date: Wed, 10 Jan 2024 23:21:12 +0100 Subject: [PATCH 1/6] Add compat upper bounds --- Project.toml | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/Project.toml b/Project.toml index e7642d0..7695511 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "TemporalGPs" uuid = "e155a3c4-0841-43e1-8b83-a0e4f03cc18f" authors = ["willtebbutt and contributors"] -version = "0.6.6" +version = "0.6.7" [deps] AbstractGPs = "99985d1d-32ba-4be9-9821-2ec096f28918" @@ -17,13 +17,13 @@ StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] -AbstractGPs = "0.5.17" -Bessels = "0.2.8" -BlockDiagonals = "0.1.7" -ChainRulesCore = "1" -FillArrays = "0.13.0 - 0.13.7, 1" -KernelFunctions = "0.9, 0.10.1" -StaticArrays = "1" -StructArrays = "0.5, 0.6" -Zygote = "0.6.65" +AbstractGPs = "0.5.17 - 0.5.19" +Bessels = "0.2.8 - 0.2.8" +BlockDiagonals = "0.1.7 - 0.1.41" +ChainRulesCore = "1.0.0 - 1.16.0" +FillArrays = "0.13.0 - 0.13.7" +KernelFunctions = "0.9, 0.10.1 - 0.10.57" +StaticArrays = "1.0.0 - 1.6.5" +StructArrays = "0.5, 0.6.0 - 0.6.16" +Zygote = "0.6.65 - 0.6.65" julia = "1.6" From 40ea62e4fab3398ac2f1f082e0b2fd71ddc90ffb Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace Date: Mon, 8 Jan 2024 18:02:07 +0100 Subject: [PATCH 2/6] Fix @test_broken problem --- test/gp/lti_sde.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/gp/lti_sde.jl b/test/gp/lti_sde.jl index 203b16e..9af1c38 100644 --- a/test/gp/lti_sde.jl +++ b/test/gp/lti_sde.jl @@ -226,7 +226,7 @@ println("lti_sde:") # Just need to ensure we can differentiate through construction properly. if isnothing(kernel.to_vec_grad) - @test_broken "Gradient tests are not passing" + @test_broken false # "Gradient tests are not passing" continue elseif kernel.to_vec_grad test_zygote_grad_finite_differences_compatible( From 5b3cc7ddceb78b1b8c3f7cb701ecf5916ba627be Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace <51025924+simsurace@users.noreply.github.com> Date: Mon, 8 Jan 2024 16:12:26 +0100 Subject: [PATCH 3/6] Remove rule Superseded by https://github.com/JuliaArrays/StaticArrays.jl/blob/e2d772f9767abdcab20ce7ae6927dc25dc38714b/ext/StaticArraysChainRulesCoreExt.jl#L26-L30 --- src/util/chainrules.jl | 7 ------- 1 file changed, 7 deletions(-) diff --git a/src/util/chainrules.jl b/src/util/chainrules.jl index fe9e67e..6420434 100644 --- a/src/util/chainrules.jl +++ b/src/util/chainrules.jl @@ -28,13 +28,6 @@ Zygote.accum(a::Tuple, b::Tuple, c::Tuple) = map(Zygote.accum, a, b, c) # StaticArrays # # ---------------------------------------------------------------------------- # -function rrule(::Type{T}, x::Tuple) where {T<:SArray} - SArray_rrule(Δ) = begin - (NoTangent(), Tangent{typeof(x)}(unthunk(Δ).data...)) - end - return T(x), SArray_rrule -end - function rrule(::RuleConfig{>:HasReverseMode}, ::Type{SArray{S, T, N, L}}, x::NTuple{L, T}) where {S, T, N, L} SArray_rrule(::AbstractZero) = NoTangent(), NoTangent() SArray_rrule(Δ::NamedTuple{(:data,)}) = NoTangent(), Δ.data From cc55ec9b7d2017fb67a35c71c5dcf49319bc07fe Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace Date: Wed, 10 Jan 2024 12:08:58 +0100 Subject: [PATCH 4/6] Get rid of duplicate `include`s --- test/gp/lti_sde.jl | 3 +-- test/models/lgssm.jl | 3 --- test/models/linear_gaussian_conditionals.jl | 3 --- test/models/missings.jl | 3 --- test/space_time/pseudo_point.jl | 2 -- test/util/chainrules.jl | 1 - 6 files changed, 1 insertion(+), 14 deletions(-) diff --git a/test/gp/lti_sde.jl b/test/gp/lti_sde.jl index 9af1c38..a4f2057 100644 --- a/test/gp/lti_sde.jl +++ b/test/gp/lti_sde.jl @@ -3,8 +3,7 @@ using KernelFunctions: kappa using ChainRulesTestUtils using TemporalGPs: build_lgssm, StorageType, is_of_storage_type, lgssm_components using Test -include("../test_util.jl") -include("../models/model_test_utils.jl") + _logistic(x) = 1 / (1 + exp(-x)) # Everything is tested once the LGSSM is constructed, so it is sufficient just to ensure diff --git a/test/models/lgssm.jl b/test/models/lgssm.jl index be257f6..19662b6 100644 --- a/test/models/lgssm.jl +++ b/test/models/lgssm.jl @@ -25,9 +25,6 @@ using LinearAlgebra using StructArrays using Zygote, StaticArrays -include("model_test_utils.jl") -include("../test_util.jl") - println("lgssm:") @testset "lgssm" begin diff --git a/test/models/linear_gaussian_conditionals.jl b/test/models/linear_gaussian_conditionals.jl index d3219a8..9e0e7ba 100644 --- a/test/models/linear_gaussian_conditionals.jl +++ b/test/models/linear_gaussian_conditionals.jl @@ -1,9 +1,6 @@ using TemporalGPs: posterior_and_lml, predict, predict_marginals using Test -include("../test_util.jl") -include("../models/model_test_utils.jl") - println("linear_gaussian_conditionals:") @testset "linear_gaussian_conditionals" begin Dlats = [1, 3] diff --git a/test/models/missings.jl b/test/models/missings.jl index 407d030..3b4084e 100644 --- a/test/models/missings.jl +++ b/test/models/missings.jl @@ -7,9 +7,6 @@ using Random: randperm using ChainRulesTestUtils using Zygote: Context -include("../test_util.jl") -include("../models/model_test_utils.jl") - @info "missings:" @testset "missings" begin diff --git a/test/space_time/pseudo_point.jl b/test/space_time/pseudo_point.jl index 8e43fa0..e90f037 100644 --- a/test/space_time/pseudo_point.jl +++ b/test/space_time/pseudo_point.jl @@ -16,8 +16,6 @@ using TemporalGPs: Separable, approx_posterior_marginals using Test -include("../test_util.jl") -include("../models/model_test_utils.jl") @testset "pseudo_point" begin diff --git a/test/util/chainrules.jl b/test/util/chainrules.jl index 0d71611..b68ab30 100644 --- a/test/util/chainrules.jl +++ b/test/util/chainrules.jl @@ -9,7 +9,6 @@ using TemporalGPs: time_exp, _map, Gaussian using FillArrays using StructArrays using Zygote: ZygoteRuleConfig -include("../test_util.jl") @testset "chainrules" begin @testset "StaticArrays" begin From afb213edef53fac33a46d418a0667237c855f773 Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace Date: Thu, 11 Jan 2024 00:12:06 +0100 Subject: [PATCH 5/6] Revert "Remove rule" This reverts commit 5b3cc7ddceb78b1b8c3f7cb701ecf5916ba627be. --- src/util/chainrules.jl | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/util/chainrules.jl b/src/util/chainrules.jl index 6420434..fe9e67e 100644 --- a/src/util/chainrules.jl +++ b/src/util/chainrules.jl @@ -28,6 +28,13 @@ Zygote.accum(a::Tuple, b::Tuple, c::Tuple) = map(Zygote.accum, a, b, c) # StaticArrays # # ---------------------------------------------------------------------------- # +function rrule(::Type{T}, x::Tuple) where {T<:SArray} + SArray_rrule(Δ) = begin + (NoTangent(), Tangent{typeof(x)}(unthunk(Δ).data...)) + end + return T(x), SArray_rrule +end + function rrule(::RuleConfig{>:HasReverseMode}, ::Type{SArray{S, T, N, L}}, x::NTuple{L, T}) where {S, T, N, L} SArray_rrule(::AbstractZero) = NoTangent(), NoTangent() SArray_rrule(Δ::NamedTuple{(:data,)}) = NoTangent(), Δ.data From 3b966b4b4cbf3b927c336c440a721270b1e1fc7f Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace <51025924+simsurace@users.noreply.github.com> Date: Mon, 22 Jan 2024 12:21:50 +0100 Subject: [PATCH 6/6] Unpin AbstractGPs and KernelFunctions --- Project.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index 7695511..f50ed4a 100644 --- a/Project.toml +++ b/Project.toml @@ -17,12 +17,12 @@ StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] -AbstractGPs = "0.5.17 - 0.5.19" +AbstractGPs = "0.5.17" Bessels = "0.2.8 - 0.2.8" BlockDiagonals = "0.1.7 - 0.1.41" ChainRulesCore = "1.0.0 - 1.16.0" FillArrays = "0.13.0 - 0.13.7" -KernelFunctions = "0.9, 0.10.1 - 0.10.57" +KernelFunctions = "0.9, 0.10.1" StaticArrays = "1.0.0 - 1.6.5" StructArrays = "0.5, 0.6.0 - 0.6.16" Zygote = "0.6.65 - 0.6.65"