diff --git a/Project.toml b/Project.toml index 4e0e7df..1d07afc 100644 --- a/Project.toml +++ b/Project.toml @@ -1,29 +1,25 @@ name = "TemporalGPs" uuid = "e155a3c4-0841-43e1-8b83-a0e4f03cc18f" authors = ["willtebbutt and contributors"] -version = "0.6.8" +version = "0.7.0" [deps] AbstractGPs = "99985d1d-32ba-4be9-9821-2ec096f28918" Bessels = "0e736298-9ec6-45e8-9647-e4fc86a2fe38" BlockDiagonals = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0" -ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" KernelFunctions = "ec8451be-7e33-11e9-00cf-bbf324bd1392" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" -Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] AbstractGPs = "0.5.17" Bessels = "0.2.8" BlockDiagonals = "0.1.7" -ChainRulesCore = "1" FillArrays = "0.13.0 - 0.13.7, 1" KernelFunctions = "0.9, 0.10.1" StaticArrays = "1" StructArrays = "0.5, 0.6" -Zygote = "0.6.65" julia = "1.6" diff --git a/src/TemporalGPs.jl b/src/TemporalGPs.jl index 7e9cf09..7b05673 100644 --- a/src/TemporalGPs.jl +++ b/src/TemporalGPs.jl @@ -3,15 +3,12 @@ module TemporalGPs using AbstractGPs using Bessels: besseli using BlockDiagonals - using ChainRulesCore - import ChainRulesCore: rrule using FillArrays using LinearAlgebra using KernelFunctions using Random using StaticArrays using StructArrays - using Zygote using FillArrays: AbstractFill @@ -36,12 +33,9 @@ module TemporalGPs ApproxPeriodicKernel # Various bits-and-bobs. Often commiting some type piracy. - include(joinpath("util", "harmonise.jl")) include(joinpath("util", "linear_algebra.jl")) include(joinpath("util", "scan.jl")) - include(joinpath("util", "zygote_friendly_map.jl")) - include(joinpath("util", "chainrules.jl")) include(joinpath("util", "gaussian.jl")) include(joinpath("util", "mul.jl")) include(joinpath("util", "storage_types.jl")) diff --git a/src/gp/lti_sde.jl b/src/gp/lti_sde.jl index 0cf0673..67a2622 100644 --- a/src/gp/lti_sde.jl +++ b/src/gp/lti_sde.jl @@ -71,7 +71,7 @@ end function build_lgssm(f::LTISDE, x::AbstractVector, Σys::AbstractVector) m = get_mean(f) k = get_kernel(f) - s = Zygote.literal_getfield(f, Val(:storage)) + s = f.storage As, as, Qs, emission_proj, x0 = lgssm_components(m, k, x, s) return LGSSM( GaussMarkovModel(Forward(), As, as, Qs, x0), build_emissions(emission_proj, Σys), @@ -79,22 +79,22 @@ function build_lgssm(f::LTISDE, x::AbstractVector, Σys::AbstractVector) end function build_lgssm(ft::FiniteLTISDE) - f = Zygote.literal_getfield(ft, Val(:f)) - x = Zygote.literal_getfield(ft, Val(:x)) - Σys = noise_var_to_time_form(x, Zygote.literal_getfield(ft, Val(:Σy))) + f = ft.f + x = ft.x + Σys = noise_var_to_time_form(x, ft.Σy) return build_lgssm(f, x, Σys) end -get_mean(f::LTISDE) = get_mean(Zygote.literal_getfield(f, Val(:f))) -get_mean(f::GP) = Zygote.literal_getfield(f, Val(:mean)) +get_mean(f::LTISDE) = get_mean(f.f) +get_mean(f::GP) = f.mean -get_kernel(f::LTISDE) = get_kernel(Zygote.literal_getfield(f, Val(:f))) -get_kernel(f::GP) = Zygote.literal_getfield(f, Val(:kernel)) +get_kernel(f::LTISDE) = get_kernel(f.f) +get_kernel(f::GP) = f.kernel function build_emissions( (Hs, hs)::Tuple{AbstractVector, AbstractVector}, Σs::AbstractVector, ) - Hst = _map(adjoint, Hs) + Hst = map(adjoint, Hs) return StructArray{get_type(Hst, hs, Σs)}((Hst, hs, Σs)) end @@ -114,10 +114,6 @@ function get_type(Hs_prime, hs::AbstractVector{<:AbstractVector}, Σs) return T end -@inline function Zygote.wrap_chainrules_output(x::NamedTuple) - return map(Zygote.wrap_chainrules_output, x) -end - # Constructor for combining kernel and mean functions function lgssm_components( ::ZeroMean, k::Kernel, t::AbstractVector, storage_type::StorageType @@ -128,7 +124,7 @@ end function lgssm_components( m::AbstractGPs.MeanFunction, k::Kernel, t::AbstractVector, storage_type::StorageType ) - m = collect(mean_vector(m, t)) # `collect` is needed as there are still issues with Zygote and FillArrays. + m = mean_vector(m, t) As, as, Qs, (Hs, hs), x0 = lgssm_components(k, t, storage_type) hs = add_proj_mean(hs, m) @@ -146,9 +142,9 @@ end function broadcast_components((F, q, H)::Tuple, x0::Gaussian, t::AbstractVector{<:Real}, ::StorageType{T}) where {T} P = Symmetric(x0.P) t = vcat([first(t) - 1], t) - As = _map(Δt -> time_exp(F, T(Δt)), diff(t)) + As = map(Δt -> time_exp(F, T(Δt)), diff(t)) as = Fill(Zeros{T}(size(first(As), 1)), length(As)) - Qs = _map(A -> P - A * P * A', As) + Qs = map(A -> P - A * P * A', As) Hs = Fill(H, length(As)) hs = Fill(zero(T), length(As)) As, as, Qs, Hs, hs @@ -158,7 +154,7 @@ function broadcast_components((F, q, H)::Tuple, x0::Gaussian, t::Union{StepRange P = Symmetric(x0.P) A = time_exp(F, T(step(t))) As = Fill(A, length(t)) - as = @ignore_derivatives(Fill(Zeros{T}(size(F, 1)), length(t))) + as = Fill(Zeros{T}(size(F, 1)), length(t)) Q = Symmetric(P) - A * Symmetric(P) * A' Qs = Fill(Q, length(t)) Hs = Fill(H, length(t)) @@ -166,6 +162,8 @@ function broadcast_components((F, q, H)::Tuple, x0::Gaussian, t::Union{StepRange As, as, Qs, Hs, hs end +time_exp(A, t) = exp(A * t) + function lgssm_components( k::SimpleKernel, t::AbstractVector{<:Real}, storage::StorageType{T}, ) where {T<:Real} @@ -332,49 +330,49 @@ end # Scaled function to_sde(k::ScaledKernel, storage::StorageType{T}) where {T<:Real} - _k = Zygote.literal_getfield(k, Val(:kernel)) - σ² = Zygote.literal_getfield(k, Val(:σ²)) + _k = k.kernel + σ² = k.σ² F, q, H = to_sde(_k, storage) σ = sqrt(convert(eltype(storage), only(σ²))) return F, σ^2 * q, σ * H end -stationary_distribution(k::ScaledKernel, storage::StorageType) = stationary_distribution(Zygote.literal_getfield(k, Val(:kernel)), storage) +stationary_distribution(k::ScaledKernel, storage::StorageType) = stationary_distribution(k.kernel, storage) function lgssm_components(k::ScaledKernel, ts::AbstractVector, storage_type::StorageType) - _k = Zygote.literal_getfield(k, Val(:kernel)) - σ² = Zygote.literal_getfield(k, Val(:σ²)) + _k = k.kernel + σ² = k.σ² As, as, Qs, emission_proj, x0 = lgssm_components(_k, ts, storage_type) σ = sqrt(convert(eltype(storage_type), only(σ²))) return As, as, Qs, _scale_emission_projections(emission_proj, σ), x0 end function _scale_emission_projections((Hs, hs)::Tuple{AbstractVector, AbstractVector}, σ::Real) - return _map(H->σ * H, Hs), _map(h->σ * h, hs) + return map(H->σ * H, Hs), map(h->σ * h, hs) end function _scale_emission_projections((Cs, cs, Hs, hs), σ) - return (Cs, cs, _map(H -> σ * H, Hs), _map(h -> σ * h, hs)) + return (Cs, cs, map(H -> σ * H, Hs), map(h -> σ * h, hs)) end # Stretched function to_sde(k::TransformedKernel{<:Kernel, <:ScaleTransform}, storage::StorageType) - _k = Zygote.literal_getfield(k, Val(:kernel)) - s = Zygote.literal_getfield(Zygote.literal_getfield(k, Val(:transform)), Val(:s)) + _k = k.kernel + s = k.transform.s F, q, H = to_sde(_k, storage) return F * only(s), q, H end -stationary_distribution(k::TransformedKernel{<:Kernel, <:ScaleTransform}, storage::StorageType) = stationary_distribution(Zygote.literal_getfield(k, Val(:kernel)), storage) +stationary_distribution(k::TransformedKernel{<:Kernel, <:ScaleTransform}, storage::StorageType) = stationary_distribution(k.kernel, storage) function lgssm_components( k::TransformedKernel{<:Kernel, <:ScaleTransform}, ts::AbstractVector, storage_type::StorageType, ) - _k = Zygote.literal_getfield(k, Val(:kernel)) - s = Zygote.literal_getfield(Zygote.literal_getfield(k, Val(:transform)), Val(:s)) + _k = k.kernel + s = k.transform.s return lgssm_components(_k, apply_stretch(s[1], ts), storage_type) end @@ -383,9 +381,9 @@ apply_stretch(a, ts::AbstractVector{<:Real}) = a * ts apply_stretch(a, ts::StepRangeLen) = a * ts function apply_stretch(a, ts::RegularSpacing) - t0 = Zygote.literal_getfield(ts, Val(:t0)) - Δt = Zygote.literal_getfield(ts, Val(:Δt)) - N = Zygote.literal_getfield(ts, Val(:N)) + t0 = ts.t0 + Δt = ts.Δt + N = ts.N return RegularSpacing(a * t0, a * Δt, N) end @@ -425,9 +423,9 @@ function lgssm_components(k::KernelSum, ts::AbstractVector, storage_type::Storag emission_proj_kernels = getindex.(lgssms, 4) x0_kernels = getindex.(lgssms, 5) - As = _map(block_diagonal, As_kernels...) - as = _map(vcat, as_kernels...) - Qs = _map(block_diagonal, Qs_kernels...) + As = map(block_diagonal, As_kernels...) + as = map(vcat, as_kernels...) + Qs = map(block_diagonal, Qs_kernels...) emission_projections = _sum_emission_projections(emission_proj_kernels...) x0 = Gaussian(mapreduce(x -> getproperty(x, :m), vcat, x0_kernels), block_diagonal(getproperty.(x0_kernels, :P)...)) return As, as, Qs, emission_projections, x0 @@ -444,10 +442,10 @@ function _sum_emission_projections( cs = getindex.(Cs_cs_Hs_hs, 2) Hs = getindex.(Cs_cs_Hs_hs, 3) hs = getindex.(Cs_cs_Hs_hs, 4) - C = _map(vcat, Cs...) + C = map(vcat, Cs...) c = sum(cs) - H = _map(block_diagonal, Hs...) - h = _map(vcat, hs...) + H = map(block_diagonal, Hs...) + h = map(vcat, hs...) return C, c, H, h end @@ -460,36 +458,9 @@ function block_diagonal(As::AbstractMatrix{T}...) where {T} return hvcat(ntuple(_ -> nblocks, nblocks), Xs...) end -function ChainRulesCore.rrule(::typeof(block_diagonal), As::AbstractMatrix...) - szs = size.(As) - row_szs = (0, cumsum(first.(szs))...) - col_szs = (0, cumsum(last.(szs))...) - block_diagonal_rrule(Δ::AbstractThunk) = block_diagonal_rrule(unthunk(Δ)) - function block_diagonal_rrule(Δ) - ΔAs = ntuple(length(As)) do i - Δ[(row_szs[i]+1):row_szs[i+1], (col_szs[i]+1):col_szs[i+1]] - end - return NoTangent(), ΔAs... - end - return block_diagonal(As...), block_diagonal_rrule -end - function block_diagonal(As::SMatrix...) nblocks = length(As) sizes = size.(As) Xs = [i == j ? As[i] : zeros(SMatrix{sizes[j][1], sizes[i][2]}) for i in 1:nblocks, j in 1:nblocks] return hcat(Base.splat(vcat).(eachrow(Xs))...) end - -function ChainRulesCore.rrule(::typeof(block_diagonal), As::SMatrix...) - szs = size.(As) - row_szs = (0, cumsum(first.(szs))...) - col_szs = (0, cumsum(last.(szs))...) - function block_diagonal_rrule(Δ) - ΔAs = ntuple(length(As)) do i - Δ[SVector{szs[i][1]}((row_szs[i]+1):row_szs[i+1]), SVector{szs[i][2]}((col_szs[i]+1):col_szs[i+1])] - end - return NoTangent(), ΔAs... - end - return block_diagonal(As...), block_diagonal_rrule -end diff --git a/src/gp/posterior_lti_sde.jl b/src/gp/posterior_lti_sde.jl index 007fa38..5cc413a 100644 --- a/src/gp/posterior_lti_sde.jl +++ b/src/gp/posterior_lti_sde.jl @@ -25,15 +25,15 @@ function AbstractGPs.marginals(fx::FinitePosteriorLTISDE) model_post = replace_observation_noise_cov(posterior(model, ys), σ²s_pr_full) return destructure(x, map(marginals, marginals(model_post))[pr_indices]) else - f = Zygote.literal_getfield(fx, Val(:f)) - prior = Zygote.literal_getfield(f, Val(:prior)) - x = Zygote.literal_getfield(fx, Val(:x)) - data = Zygote.literal_getfield(f, Val(:data)) - Σy = Zygote.literal_getfield(data, Val(:Σy)) - Σy_diag = Zygote.literal_getfield(Σy, Val(:diag)) - y = Zygote.literal_getfield(data, Val(:y)) - - Σy_new = Zygote.literal_getfield(fx, Val(:Σy)) + f = fx.f + prior = f.prior + x = fx.x + data = f.data + Σy = data.Σy + Σy_diag = Σy.diag + y = data.y + + Σy_new = fx.Σy model = build_lgssm(AbstractGPs.FiniteGP(prior, x, Σy)) Σys_new = noise_var_to_time_form(x, Σy_new) diff --git a/src/models/gauss_markov_model.jl b/src/models/gauss_markov_model.jl index 7b57c26..95d9b65 100644 --- a/src/models/gauss_markov_model.jl +++ b/src/models/gauss_markov_model.jl @@ -31,14 +31,6 @@ struct GaussMarkovModel{ x0::Tx0 end -# Helps Zygote out with some type-stability issues. Why this helps is unclear. -function ChainRulesCore.rrule(::Type{<:GaussMarkovModel}, ordering, As, as, Qs, x0) - function GaussMarkovModel_pullback(Δ) - return NoTangent(), NoTangent(), Δ.As, Δ.as, Δ.Qs, Δ.x0 - end - return GaussMarkovModel(ordering, As, as, Qs, x0), GaussMarkovModel_pullback -end - ordering(model::GaussMarkovModel) = model.ordering Base.eltype(model::GaussMarkovModel) = eltype(first(model.As)) @@ -65,28 +57,4 @@ function is_of_storage_type(model::GaussMarkovModel, s::StorageType) return is_of_storage_type((model.As, model.as, model.Qs, model.x0), s) end -x0(model::GaussMarkovModel) = Zygote.literal_getfield(model, Val(:x0)) - -function get_adjoint_storage(x::GaussMarkovModel, n::Int, Δx::Tangent{T,<:NamedTuple{(:A, :a, :Q)}}) where {T} - return ( - ordering = NoTangent(), - As = get_adjoint_storage(x.As, n, Δx.A), - as = get_adjoint_storage(x.as, n, Δx.a), - Qs = get_adjoint_storage(x.Qs, n, Δx.Q), - x0 = NoTangent(), - ) -end - -function _accum_at( - Δxs::NamedTuple{(:ordering, :As, :as, :Qs, :x0)}, - n::Int, - Δx::Tangent{T, <:NamedTuple{(:A, :a, :Q)}}, -) where {T} - return ( - ordering = NoTangent(), - As = _accum_at(Δxs.As, n, Δx.A), - as = _accum_at(Δxs.as, n, Δx.a), - Qs = _accum_at(Δxs.Qs, n, Δx.Q), - x0 = NoTangent(), - ) -end +x0(model::GaussMarkovModel) = model.x0 diff --git a/src/models/lgssm.jl b/src/models/lgssm.jl index e27fe9f..fbdca0c 100644 --- a/src/models/lgssm.jl +++ b/src/models/lgssm.jl @@ -12,15 +12,14 @@ struct LGSSM{Ttransitions<:GaussMarkovModel, Temissions<:StructArray} <: Abstrac end @inline function transitions(model::LGSSM) - return Zygote.literal_getfield(model, Val(:transitions)) + return model.transitions end @inline function emissions(model::LGSSM) - return Zygote.literal_getfield(model, Val(:emissions)) + return model.emissions end @inline ordering(model::LGSSM) = ordering(transitions(model)) -ChainRulesCore.@non_differentiable ordering(model) function Base.:(==)(x::LGSSM, y::LGSSM) return (transitions(x) == transitions(y)) && (emissions(x) == emissions(y)) @@ -33,8 +32,6 @@ Base.eachindex(model::LGSSM) = eachindex(transitions(model)) storage_type(model::LGSSM) = storage_type(transitions(model)) -ChainRulesCore.@non_differentiable storage_type(x) - function is_of_storage_type(model::LGSSM, s::StorageType) return is_of_storage_type((transitions(model), emissions(model)), s) end @@ -59,15 +56,15 @@ struct ElementOfLGSSM{Tordering, Ttransition, Temission} end @inline function ordering(x::ElementOfLGSSM) - return Zygote.literal_getfield(x, Val(:ordering)) + return x.ordering end @inline function transition_dynamics(x::ElementOfLGSSM) - return Zygote.literal_getfield(x, Val(:transition)) + return x.transition end @inline function emission_dynamics(x::ElementOfLGSSM) - return Zygote.literal_getfield(x, Val(:emission)) + return x.emission end @inline function Base.getindex(model::LGSSM, n::Int) @@ -206,10 +203,10 @@ end function posterior(prior::LGSSM, y::AbstractVector) _check_inputs(prior, y) new_trans, xf = _a_bit_of_posterior(prior, y) - A = zygote_friendly_map(x -> Zygote.literal_getfield(x, Val(:A)), new_trans) - a = zygote_friendly_map(x -> Zygote.literal_getfield(x, Val(:a)), new_trans) - Q = zygote_friendly_map(x -> Zygote.literal_getfield(x, Val(:Q)), new_trans) - ems = Zygote.literal_getfield(prior, Val(:emissions)) + A = map(x -> x.A, new_trans) + a = map(x -> x.a, new_trans) + Q = map(x -> x.Q, new_trans) + ems = prior.emissions return LGSSM(GaussMarkovModel(reverse(ordering(prior)), A, a, Q, xf), ems) end @@ -221,8 +218,6 @@ function _check_inputs(prior, y) end end -ChainRulesCore.@non_differentiable _check_inputs(::Any, ::Any) - function _a_bit_of_posterior(prior, y) return scan_emit(step_posterior, zip(prior, y), x0(prior), eachindex(prior)) end @@ -263,30 +258,6 @@ ident_eps(ε::Real) = UniformScaling(ε) ident_eps(x::ColVecs, ε::Real) = UniformScaling(convert(eltype(x.X), ε)) -ChainRulesCore.@non_differentiable ident_eps(args...) - _collect(U::Adjoint{<:Any, <:Matrix}) = collect(U) _collect(U::SMatrix) = U _collect(U::BlockDiagonal) = U - -# AD stuff. No need to understand this unless you're really plumbing the depths... - -function get_adjoint_storage( - x::LGSSM, n::Int, Δx::Tangent{T,<:NamedTuple{(:ordering,:transition,:emission)}}, -) where {T} - return Tangent{typeof(x)}( - transitions = get_adjoint_storage(x.transitions, n, Δx.transition), - emissions = get_adjoint_storage(x.emissions, n, Δx.emission) - ) -end - -function _accum_at( - Δxs::Tangent{X}, - n::Int, - Δx::Tangent{T,<:NamedTuple{(:ordering,:transition,:emission)}}, -) where {X<:LGSSM, T} - return Tangent{X}( - transitions = _accum_at(Δxs.transitions, n, Δx.transition), - emissions = _accum_at(Δxs.emissions, n, Δx.emission), - ) -end diff --git a/src/models/linear_gaussian_conditionals.jl b/src/models/linear_gaussian_conditionals.jl index 0b5fe79..2d53ffc 100644 --- a/src/models/linear_gaussian_conditionals.jl +++ b/src/models/linear_gaussian_conditionals.jl @@ -97,13 +97,9 @@ function ε_randn(rng::AbstractRNG, ::SMatrix{Dout, Din, T}) where {Dout, Din, T return randn(rng, SVector{Dout, T}) end -ChainRulesCore.@non_differentiable ε_randn(args...) - scalar_type(::AbstractVector{T}) where {T} = T scalar_type(::T) where {T<:Real} = T -ChainRulesCore.@non_differentiable scalar_type(x) - """ SmallOutputLGC{ TA<:AbstractMatrix, Ta<:AbstractVector, TQ<:AbstractMatrix, @@ -126,12 +122,12 @@ dim_out(f::SmallOutputLGC) = size(f.A, 1) dim_in(f::SmallOutputLGC) = size(f.A, 2) -noise_cov(f::SmallOutputLGC) = Zygote.literal_getfield(f, Val(:Q)) +noise_cov(f::SmallOutputLGC) = f.Q function get_fields(f::SmallOutputLGC) - A = Zygote.literal_getfield(f, Val(:A)) - a = Zygote.literal_getfield(f, Val(:a)) - Q = Zygote.literal_getfield(f, Val(:Q)) + A = f.A + a = f.a + Q = f.Q return A, a, Q end @@ -177,26 +173,16 @@ struct LargeOutputLGC{ Q::TQ end -function ChainRulesCore.rrule( - ::Type{<:LargeOutputLGC}, - A::AbstractMatrix, - a::AbstractVector, - Q::AbstractMatrix, -) - LargeOutputLGC_pullback(Δ) = NoTangent(), Δ.A, Δ.a, Δ.Q - return LargeOutputLGC(A, a, Q), LargeOutputLGC_pullback -end - dim_out(f::LargeOutputLGC) = size(f.A, 1) dim_in(f::LargeOutputLGC) = size(f.A, 2) -noise_cov(f::LargeOutputLGC) = Zygote.literal_getfield(f, Val(:Q)) +noise_cov(f::LargeOutputLGC) = f.Q function get_fields(f::LargeOutputLGC) - A = Zygote.literal_getfield(f, Val(:A)) - a = Zygote.literal_getfield(f, Val(:a)) - Q = Zygote.literal_getfield(f, Val(:Q)) + A = f.A + a = f.a + Q = f.Q return A, a, Q end @@ -259,13 +245,13 @@ dim_out(f::ScalarOutputLGC) = 1 dim_in(f::ScalarOutputLGC) = size(f.A, 2) function get_fields(f::ScalarOutputLGC) - A = Zygote.literal_getfield(f, Val(:A)) - a = Zygote.literal_getfield(f, Val(:a)) - Q = Zygote.literal_getfield(f, Val(:Q)) + A = f.A + a = f.a + Q = f.Q return A, a, Q end -noise_cov(f::ScalarOutputLGC) = Zygote.literal_getfield(f, Val(:Q)) +noise_cov(f::ScalarOutputLGC) = f.Q function conditional_rand(ε::Real, f::ScalarOutputLGC, x::AbstractVector) A, a, Q = get_fields(f) @@ -323,16 +309,16 @@ dim_out(f::BottleneckLGC) = dim_out(f.fan_out) dim_in(f::BottleneckLGC) = size(f.H, 2) -noise_cov(f::BottleneckLGC) = noise_cov(Zygote.literal_getfield(f, Val(:fan_out))) +noise_cov(f::BottleneckLGC) = noise_cov(f.fan_out) function get_fields(f::BottleneckLGC) - H = Zygote.literal_getfield(f, Val(:H)) - h = Zygote.literal_getfield(f, Val(:h)) - fan_out = Zygote.literal_getfield(f, Val(:fan_out)) + H = f.H + h = f.h + fan_out = f.fan_out return H, h, fan_out end -fan_out(f::BottleneckLGC) = Zygote.literal_getfield(f, Val(:fan_out)) +fan_out(f::BottleneckLGC) = f.fan_out function conditional_rand(ε::AbstractVector{<:Real}, f::BottleneckLGC, x::AbstractVector) H, h, fan_out = get_fields(f) diff --git a/src/models/missings.jl b/src/models/missings.jl index fd2e2e9..69cab24 100644 --- a/src/models/missings.jl +++ b/src/models/missings.jl @@ -28,7 +28,7 @@ function transform_model_and_obs( model::LGSSM, y::AbstractVector{<:Union{Missing, T}}, ) where {T<:Union{<:AbstractVector, <:Real}} Σs_filled_in, y_filled_in = fill_in_missings( - zygote_friendly_map(noise_cov, emissions(model)), y, + map(noise_cov, emissions(model)), y, ) model_with_missings = replace_observation_noise_cov(model, Σs_filled_in) return model_with_missings, y_filled_in @@ -55,8 +55,6 @@ function _logpdf_volume_compensation(y::AbstractVector{<:Union{Missing, <:Real}} end -ChainRulesCore.@non_differentiable _logpdf_volume_compensation(y) - function fill_in_missings(Σs::Vector, y::AbstractVector{Union{Missing, T}}) where {T} return _fill_in_missings(Σs, y) end @@ -79,50 +77,18 @@ function _fill_in_missings(Σs::Vector, y::AbstractVector{Union{Missing, T}}) wh end function fill_in_missings(Σ::Diagonal, y::AbstractVector{<:Union{Missing, <:Real}}) - Σ_diag_filled, y_filled = fill_in_missings(Zygote.literal_getfield(Σ, Val(:diag)), y) + Σ_diag_filled, y_filled = fill_in_missings(Σ.diag, y) return Diagonal(Σ_diag_filled), y_filled end # We need to densify anyway, might as well do it here and save having to implement the # rrule twice. -function fill_in_missings(Σs::Fill, y::AbstractVector{Union{Missing, T}}) where {T} +function fill_in_missings(Σs::AbstractArray, y::AbstractVector{Union{Missing, T}}) where {T} return fill_in_missings(collect(Σs), y) end fill_in_missings(Σ::Diagonal, y::AbstractVector{<:Real}) = (Σ, y) -function ChainRulesCore.rrule( - ::typeof(_fill_in_missings), - Σs::Vector, - y::AbstractVector{Union{T, Missing}}, -) where {T} - function _fill_in_missings_rrule(Δ::Tangent) - ΔΣs, Δy_filled = Δ - - # The cotangent of a `Missing` doesn't make sense, so should be a `NoTangent`. - Δy = if Δy_filled isa AbstractZero - ZeroTangent() - else - Δy = Vector{Union{eltype(Δy_filled), ZeroTangent}}(undef, length(y)) - map!( - n -> y[n] === missing ? ZeroTangent() : Δy_filled[n], - Δy, eachindex(y), - ) - Δy - end - - # Fill in missing locations with zeros. Opting for type-stability to keep things - # simple. - ΔΣs = map( - n -> y[n] === missing ? zero(Σs[n]) : ΔΣs[n], - eachindex(y), - ) - - return NoTangent(), ΔΣs, Δy - end - return fill_in_missings(Σs, y), _fill_in_missings_rrule -end - get_zero(D::Int, ::Type{Vector{T}}) where {T} = zeros(T, D) get_zero(::Int, ::Type{T}) where {T<:SVector} = zeros(T) @@ -136,5 +102,3 @@ build_large_var(::T) where {T<:SMatrix} = T(_large_var_const() * I) build_large_var(S::T) where {T<:Diagonal} = T(fill(_large_var_const(), length(diag(S)))) build_large_var(::T) where {T<:Real} = T(_large_var_const()) - -ChainRulesCore.@non_differentiable build_large_var(::Any) diff --git a/src/space_time/pseudo_point.jl b/src/space_time/pseudo_point.jl index bcb90d5..8ffe9ce 100644 --- a/src/space_time/pseudo_point.jl +++ b/src/space_time/pseudo_point.jl @@ -54,9 +54,6 @@ function AbstractGPs.dtc(fx::FiniteLTISDE, y::AbstractVector, z_r::AbstractVecto return logpdf(dtcify(z_r, fx), y) end -# This stupid rule saves an absurb amount of compute time. -ChainRulesCore.@non_differentiable count(::typeof(ismissing), yn) - """ elbo(fx::FiniteLTISDE, y::AbstractVector{<:Real}, z_r::AbstractVector) @@ -77,7 +74,7 @@ function AbstractGPs.elbo(fx::FiniteLTISDE, y::AbstractVector, z_r::AbstractVect # Transform a vector into a vector-of-vectors. y_vecs = restructure(y, lgssm.emissions) - tmp = zygote_friendly_map( + tmp = map( ((Σ, Cf_diag, marg_diag, yn), ) -> begin Σ_, _ = fill_in_missings(Σ, yn) return sum(diag(Σ_ \ (Cf_diag - marg_diag.P))) - @@ -89,8 +86,6 @@ function AbstractGPs.elbo(fx::FiniteLTISDE, y::AbstractVector, z_r::AbstractVect return logpdf(lgssm, y_vecs) - sum(tmp) / 2 end -Zygote.accum(x::NamedTuple{(:diag, )}, y::Diagonal) = Zygote.accum(x, (diag=y.diag, )) - function kernel_diagonals(k::DTCSeparable, x::RectilinearGrid) space_kernel = k.k.l time_kernel = k.k.r @@ -146,12 +141,12 @@ function lgssm_components(k_dtc::DTCSeparable, x::SpaceTimeGrid, storage::Storag Λu_Cuf = cholesky(Symmetric(K_space_z + 1e-12I)) \ K_space_zx # Construct approximately low-rank model spatio-temporal LGSSM. - As = _map(A -> kron(ident_M, A), As_t) - as = _map(a -> repeat(a, M), as_t) - Qs = _map(Q -> kron(K_space_z, Q), Qs_t) + As = map(A -> kron(ident_M, A), As_t) + as = map(a -> repeat(a, M), as_t) + Qs = map(Q -> kron(K_space_z, Q), Qs_t) Cs = Fill(Λu_Cuf, length(ts)) - cs = _map(h -> Fill(h, N), hs_t) # This should currently be zero. - Hs = _map(H -> kron(ident_M, H), Hs_t) + cs = map(h -> Fill(h, N), hs_t) # This should currently be zero. + Hs = map(H -> kron(ident_M, H), Hs_t) hs = Fill(Zeros(M), length(ts)) x0 = Gaussian(repeat(x0_t.m, M), kron(K_space_z, x0_t.P)) return As, as, Qs, (Cs, cs, Hs, hs), x0 @@ -178,16 +173,16 @@ function lgssm_components(k_dtc::DTCSeparable, x::RegularInTime, storage::Storag ident_M = my_I(eltype(storage), M) # Construct approximately low-rank model spatio-temporal LGSSM. - As = _map(kron, Fill(ident_M, N), As_t) - as = _map(a -> repeat(a, M), as_t) - Qs = _map(kron, Fill(K_space_z, N), Qs_t) + As = map(kron, Fill(ident_M, N), As_t) + as = map(a -> repeat(a, M), as_t) + Qs = map(kron, Fill(K_space_z, N), Qs_t) x_big = _reduce(vcat, x.vs) C__ = kernelmatrix(space_kernel, z_space, x_big) C = \(K_space_z_chol, C__) - Cs = partition(ChainRulesCore.ignore_derivatives(map(length, x.vs)), C) + Cs = partition(map(length, x.vs), C) cs = fill.(hs_t, length.(x.vs)) # This should currently be zero. - Hs = _map( + Hs = map( ((I, H_t), ) -> kron(I, H_t), zip(Fill(ident_M, N), Hs_t), ) @@ -211,22 +206,12 @@ function partition(lengths::AbstractVector{<:Integer}, A::Matrix{<:Real}) return map((s, d) -> collect(view(A, :, s:s+d-1)), starts, lengths) end -function ChainRulesCore.rrule( - ::typeof(partition), - lengths::AbstractVector{<:Integer}, - A::Matrix{<:Real}, -) - partition_pullback(::NoTangent) = NoTangent(), NoTangent(), NoTangent() - partition_pullback(Δ::Vector) = NoTangent(), NoTangent(), reduce(hcat, Δ) - return partition(lengths, A), partition_pullback -end - function build_emissions( (Cs, cs, Hs, hs)::Tuple{AbstractVector, AbstractVector, AbstractVector, AbstractVector}, Σs::AbstractVector, ) - Hst = _map(adjoint, Hs) - Cst = _map(adjoint, Cs) + Hst = map(adjoint, Hs) + Cst = map(adjoint, Cs) fan_outs = StructArray{LargeOutputLGC{eltype(Cs), eltype(cs), eltype(Σs)}}((Cst, cs, Σs)) return StructArray{BottleneckLGC{eltype(Hst), eltype(hs), eltype(fan_outs)}}((Hst, hs, fan_outs)) end @@ -378,16 +363,16 @@ end function dtc_post_emissions(k::ScaledKernel, x_new::AbstractVector, storage::StorageType) (Cs, cs, Hs, hs), Σs = dtc_post_emissions(k.kernel, x_new, storage) σ = sqrt(convert(eltype(storage_type), only(k.σ²))) - return (Cs, cs, _map(H->σ * H, Hs), _map(h->σ * h, hs)), _map(Σ->σ^2 * Σ, Σs) + return (Cs, cs, map(H->σ * H, Hs), map(h->σ * h, hs)), map(Σ->σ^2 * Σ, Σs) end function dtc_post_emissions(k::KernelSum, x_new::AbstractVector, storage::StorageType) post_emissions = dtc_post_emissions.(k.kernels, Ref(x_new), Ref(storage)) Cs_cs_Hs_hs = getindex.(post_emissions, 1) Σs = getindex.(post_emissions, 2) - Cs = _map(vcat, getindex.(Cs_cs_Hs_hs, 1)...) + Cs = map(vcat, getindex.(Cs_cs_Hs_hs, 1)...) cs = sum(getindex.(Cs_cs_Hs_hs, 2)) - Hs = _map(block_diagonal, getindex.(Cs_cs_Hs_hs, 3)...) - hs = _map(vcat, getindex.(Cs_cs_Hs_hs, 4)...) + Hs = map(block_diagonal, getindex.(Cs_cs_Hs_hs, 3)...) + hs = map(vcat, getindex.(Cs_cs_Hs_hs, 4)...) return (Cs, cs, Hs, hs), sum(Σs) end diff --git a/src/space_time/rectilinear_grid.jl b/src/space_time/rectilinear_grid.jl index cc7558f..cef8e00 100644 --- a/src/space_time/rectilinear_grid.jl +++ b/src/space_time/rectilinear_grid.jl @@ -15,9 +15,9 @@ struct RectilinearGrid{ xr::Txr end -get_space(x::RectilinearGrid) = Zygote.literal_getfield(x, Val(:xl)) +get_space(x::RectilinearGrid) = x.xl -get_times(x::RectilinearGrid) = Zygote.literal_getfield(x, Val(:xr)) +get_times(x::RectilinearGrid) = x.xr Base.size(X::RectilinearGrid) = (length(X.xl) * length(X.xr),) @@ -92,11 +92,9 @@ end function noise_var_to_time_form(x::RectilinearGrid, S::Diagonal{<:Real}) vs = restructure( diag(S), - ChainRulesCore.ignore_derivatives() do - Fill(length(get_space(x)), length(get_times(x))) - end, + Fill(length(get_space(x)), length(get_times(x))) ) - return zygote_friendly_map(v -> Diagonal(collect(v)), vs) + return map(v -> Diagonal(collect(v)), vs) end destructure(::RectilinearGrid, y::AbstractVector) = reduce(vcat, y) diff --git a/src/space_time/regular_in_time.jl b/src/space_time/regular_in_time.jl index c3abac1..452d831 100644 --- a/src/space_time/regular_in_time.jl +++ b/src/space_time/regular_in_time.jl @@ -12,9 +12,9 @@ struct RegularInTime{ vs::Tvs end -get_space(x::RegularInTime) = Zygote.literal_getfield(x, Val(:vs)) +get_space(x::RegularInTime) = x.vs -get_times(x::RegularInTime) = Zygote.literal_getfield(x, Val(:ts)) +get_times(x::RegularInTime) = x.ts Base.size(x::RegularInTime) = (sum(length, x.vs), ) @@ -78,18 +78,11 @@ function restructure(y::AbstractVector{T}, lengths::AbstractVector{<:Integer}) w end end -function ChainRulesCore.rrule( - ::typeof(restructure), y::Vector, lengths::AbstractVector{<:Integer}, -) - restructure_pullback(Δ::Vector) = NoTangent(), reduce(vcat, Δ), NoTangent() - return restructure(y, lengths), restructure_pullback -end - # Implementation specific to Fills for AD's sake. function restructure(y::Fill{<:Real}, lengths::AbstractVector{<:Integer}) - return map(l -> Fill(y.value, l), ChainRulesCore.ignore_derivatives(lengths)) + return map(l -> Fill(y.value, l), lengths) end function restructure(y::AbstractVector, emissions::StructArray) - return restructure(y, ChainRulesCore.ignore_derivatives(map(dim_out, emissions))) + return restructure(y, map(dim_out, emissions)) end diff --git a/src/space_time/to_gauss_markov.jl b/src/space_time/to_gauss_markov.jl index 5b6afb7..a627f09 100644 --- a/src/space_time/to_gauss_markov.jl +++ b/src/space_time/to_gauss_markov.jl @@ -1,6 +1,4 @@ -using ChainRulesCore my_I(T, N) = Matrix{T}(I, N, N) -ChainRulesCore.@non_differentiable my_I(args...) function lgssm_components(k::Separable, x::SpaceTimeGrid, storage) @@ -15,16 +13,16 @@ function lgssm_components(k::Separable, x::SpaceTimeGrid, storage) # Compute components of complete LGSSM. Nr = length(r) ident = my_I(eltype(storage), Nr) - As = _map(Base.Fix1(kron, ident), As_t) - as = _map(Base.Fix2(repeat, Nr), as_t) - Qs = _map(Base.Fix1(kron, Kr + ident_eps(1e-12)), Qs_t) + As = map(Base.Fix1(kron, ident), As_t) + as = map(Base.Fix2(repeat, Nr), as_t) + Qs = map(Base.Fix1(kron, Kr + ident_eps(1e-12)), Qs_t) emission_proj = _build_st_proj(emission_proj_t, Nr, ident) x0 = Gaussian(repeat(x0_t.m, Nr), kron(Kr, x0_t.P)) return As, as, Qs, emission_proj, x0 end function _build_st_proj((Hs, hs)::Tuple{AbstractVector, AbstractVector}, Nr::Integer, ident) - return (_map(H -> kron(ident, H), Hs), _map(h -> Fill(h, Nr), hs)) + return (map(H -> kron(ident, H), Hs), map(h -> Fill(h, Nr), hs)) end function build_prediction_obs_vars( diff --git a/src/util/chainrules.jl b/src/util/chainrules.jl deleted file mode 100644 index fe9e67e..0000000 --- a/src/util/chainrules.jl +++ /dev/null @@ -1,409 +0,0 @@ -# This is all AD-related stuff. If you're looking to understand TemporalGPs, this can be -# safely ignored. - -using Zygote: accum, AContext -import ChainRulesCore: ProjectTo, rrule, _eltype_projectto - -# This context doesn't allow any globals. -struct NoContext <: Zygote.AContext end - -# Stupid implementation to obtain type-stability. -Zygote.cache(::NoContext) = (; cache_fields=nothing) - -# Stupid implementation. -Base.haskey(cx::NoContext, x) = false - -Zygote.accum_param(::NoContext, x, Δ) = Δ - -ChainRulesCore.@non_differentiable eltype(x) - -# Hacks to help the compiler out in very specific situations. -Zygote.accum(a::Array{T}, b::Array{T}) where {T<:Real} = a + b - -Zygote.accum(a::SArray{size, T}, b::SArray{size, T}) where {size, T<:Real} = a + b - -Zygote.accum(a::Tuple, b::Tuple, c::Tuple) = map(Zygote.accum, a, b, c) - -# ---------------------------------------------------------------------------- # -# StaticArrays # -# ---------------------------------------------------------------------------- # - -function rrule(::Type{T}, x::Tuple) where {T<:SArray} - SArray_rrule(Δ) = begin - (NoTangent(), Tangent{typeof(x)}(unthunk(Δ).data...)) - end - return T(x), SArray_rrule -end - -function rrule(::RuleConfig{>:HasReverseMode}, ::Type{SArray{S, T, N, L}}, x::NTuple{L, T}) where {S, T, N, L} - SArray_rrule(::AbstractZero) = NoTangent(), NoTangent() - SArray_rrule(Δ::NamedTuple{(:data,)}) = NoTangent(), Δ.data - SArray_rrule(Δ::StaticArray{S}) = NoTangent(), Δ.data - return SArray{S, T, N, L}(x), SArray_rrule -end - -function rrule( - config::RuleConfig{>:HasReverseMode}, ::Type{X}, x::NTuple{L, Any}, -) where {S, T, N, L, X <: SArray{S, T, N, L}} - new_x, convert_pb = rrule_via_ad(config, StaticArrays.convert_ntuple, T, x) - _, pb = rrule_via_ad(config, SArray{S, T, N, L}, new_x) - SArray_rrule(::AbstractZero) = NoTangent(), NoTangent() - SArray_rrule(Δ::SArray{S}) = SArray_rrule(Tangent{X}(data=Δ.data)) - SArray_rrule(Δ::SizedArray{S}) = SArray_rrule(Tangent{X}(data=Tuple(Δ.data))) - SArray_rrule(Δ::AbstractVector) = SArray_rrule(Tangent{X}(data=Tuple(Δ))) - SArray_rrule(Δ::Matrix) = SArray_rrule(Tangent{X}(data=Δ)) - function SArray_rrule(Δ::Tangent{X,<:NamedTuple{(:data,)}}) where {X} - _, Δnew_x = pb(backing(Δ)) - _, ΔT, Δx = convert_pb(Tuple(Δnew_x)) - return ΔT, Δx - end - return SArray{S, T, N, L}(x), SArray_rrule -end - -function rrule(::typeof(collect), x::X) where {S, T, N, L, X<:SArray{S, T, N, L}} - y = collect(x) - proj = ProjectTo(y) - collect_rrule(Δ) = NoTangent(), proj(Δ) - return y, collect_rrule -end - -function rrule(::typeof(vcat), A::SVector{DA}, B::SVector{DB}) where {DA, DB} - function vcat_rrule(Δ) # SVector - ΔA = Δ[SVector{DA}(1:DA)] - ΔB = Δ[SVector{DB}((DA+1):(DA+DB))] - return NoTangent(), ΔA, ΔB - end - return vcat(A, B), vcat_rrule -end - -@non_differentiable vcat(x::Zeros, y::Zeros) - -# Implementation of the matrix exponential that assumes one doesn't require access to the -# gradient w.r.t. `A`, only `t`. The former is a bit compute-intensive to get at, while the -# latter is very cheap. - -time_exp(A, t) = exp(A * t) -function rrule(::typeof(time_exp), A, t::Real) - B = exp(A * t) - time_exp_rrule(Ω̄) = NoTangent(), NoTangent(), sum(Ω̄ .* (A * B)) - return B, time_exp_rrule -end - - -# Following is taken from https://github.com/JuliaArrays/FillArrays.jl/pull/153 -# Until a solution has been found this code will be needed here. -""" - ProjectTo(::Fill) -> ProjectTo{Fill} - ProjectTo(::Ones) -> ProjectTo{NoTangent} - -Most FillArrays arrays store one number, and so their gradients under automatic -differentiation represent the variation of this one number. - -The exception is those like `Ones` and `Zeros` whose type fixes their value, -which have no graidient. -""" -ProjectTo(x::Fill) = ProjectTo{Fill}(; element = ProjectTo(FillArrays.getindex_value(x)), axes = axes(x)) - -ProjectTo(::AbstractFill{Bool}) = ProjectTo{NoTangent}() # Bool is always regarded as categorical - -ProjectTo(::Zeros) = ProjectTo{NoTangent}() -ProjectTo(::Ones) = ProjectTo{NoTangent}() - -(project::ProjectTo{Fill})(x::Fill) = x -function (project::ProjectTo{Fill})(dx::AbstractArray) - for d in 1:max(ndims(dx), length(project.axes)) - size(dx, d) == length(get(project.axes, d, 1)) || throw(_projection_mismatch(axes_x, size(dx))) - end - Fill(sum(dx), project.axes) -end - -function (project::ProjectTo{Fill})(dx::Tangent{<:Fill}) - # This would need a definition for length(::NoTangent) to be safe: - # for d in 1:max(length(dx.axes), length(project.axes)) - # length(get(dx.axes, d, 1)) == length(get(project.axes, d, 1)) || throw(_projection_mismatch(dx.axes, size(dx))) - # end - Fill(dx.value / prod(length, project.axes), project.axes) -end -function (project::ProjectTo{Fill})(dx::Tangent{Any,<:NamedTuple{(:value, :axes)}}) - Fill(dx.value / prod(length, project.axes), project.axes) -end - -# Yet another thing that should not happen -function Zygote.accum(x::Fill, y::NamedTuple{(:value, :axes)}) - Fill(x.value + y.value, x.axes) -end - -# We have an alternative map to avoid Zygote untouchable specialisation on map. -_map(f, args...) = map(f, args...) - -function rrule(::Type{<:Fill}, x, sz) - Fill_rrule(Δ::Union{Fill,Thunk}) = NoTangent(), FillArrays.getindex_value(unthunk(Δ)), NoTangent() - Fill_rrule(Δ::Tangent{T,<:NamedTuple{(:value, :axes)}}) where {T} = NoTangent(), Δ.value, NoTangent() - Fill_rrule(::AbstractZero) = NoTangent(), NoTangent(), NoTangent() - Fill_rrule(Δ::Tangent{T,<:NTuple}) where {T} = NoTangent(), sum(Δ), NoTangent() - function Fill_rrule(Δ::AbstractArray) - # all(==(first(Δ)), Δ) || error("Δ should be a vector of the same value") - # sum(Δ) - # TODO Fix this rule, or what seems to be a downstream bug. - return NoTangent(), sum(Δ), NoTangent() - end - Fill(x, sz), Fill_rrule -end - -function rrule(::typeof(Base.collect), x::Fill) - y = collect(x) - proj = ProjectTo(x) - function collect_Fill_rrule(Δ) - NoTangent(), proj(Δ) - end - return y, collect_Fill_rrule -end - - -function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(_map), f, x::Fill) - y_el, back = ChainRulesCore.rrule_via_ad(config, f, x.value) - function _map_Fill_rrule(Δ::AbstractArray) - all(==(first(Δ)), Δ) || error("Δ should be a vector of the same value") - Δf, Δx_el = back(first(Δ)) - NoTangent(), Δf, Fill(Δx_el, axes(x)) - end - function _map_Fill_rrule(Δ::Union{Thunk,Fill,Tangent}) - Δf, Δx_el = back(unthunk(Δ).value) - return NoTangent(), Δf, Fill(Δx_el, axes(x)) - end - _map_Fill_rrule(::AbstractZero) = NoTangent(), NoTangent(), NoTangent() - return Fill(y_el, axes(x)), _map_Fill_rrule -end - -# Somehow needed to avoid the _map -> map indirection -function _map(f, xs::Fill...) - all(==(axes(first(xs))), axes.(xs)) || error("All axes should be the same") - Fill(f(FillArrays.getindex_value.(xs)...), axes(first(xs))) -end - -function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(_map), f, xs::Fill...) - z_el, back = ChainRulesCore.rrule_via_ad(config, f, FillArrays.getindex_value.(xs)...) - function _map_Fill_rrule(Δ) - Δf, Δxs_el... = back(unthunk(Δ).value) - return NoTangent(), Δf, Fill.(Δxs_el, axes.(xs))... - end - return Fill(z_el, axes(first(xs))), _map_Fill_rrule -end -### Same thing for `StructArray` - - -function rrule(::typeof(step), x::T) where {T<:StepRangeLen} - function step_StepRangeLen_rrule(Δ) - return NoTangent(), Tangent{T}(step=Δ) - end - return step(x), step_StepRangeLen_rrule -end - -function rrule(::typeof(Base.getindex), x::SVector{1,1}, n::Int) - getindex_SArray_rrule(Δ) = NoTangent(), SVector{1}(Δ), NoTangent() - return x[n], getindex_SArray_rrule -end - -# -# AD-free pullbacks for a few things. These are primitives that will be used to write the -# gradients. -# - -function cholesky_rrule(Σ::Symmetric{<:Real, <:StridedMatrix}) - C = cholesky(Σ) - function cholesky_pullback(Δ::NamedTuple) - U, Ū = C.U, Δ.factors - Σ̄ = Ū * U' - Σ̄ = LinearAlgebra.copytri!(Σ̄, 'U') - Σ̄ = ldiv!(U, Σ̄) - BLAS.trsm!('R', 'U', 'T', 'N', one(eltype(Σ)), U.data, Σ̄) - - for n in diagind(Σ̄) - Σ̄[n] /= 2 - end - return NoTangent(), UpperTriangular(Σ̄) - end - return C, cholesky_pullback -end - -function cholesky_rrule(S::Symmetric{<:Real, <:StaticMatrix{N, N}}) where {N} - C = cholesky(S) - function cholesky_pullback(Δ::Tangent) - U, Ū = C.U, Δ.factors - Σ̄ = SMatrix{N,N}(Symmetric(Ū * U')) - Σ̄ = U \ (U \ Σ̄)' - Σ̄ = Σ̄ - Diagonal(Σ̄) / 2 - return NoTangent(), Tangent{typeof(S)}(data=SMatrix{N, N}(UpperTriangular(Σ̄))) - end - return C, cholesky_pullback -end - -function rrule(::typeof(cholesky), S::Symmetric{<:Real, <:StaticMatrix{N, N}}) where {N} - return cholesky_rrule(S) -end - -function Zygote.accum(a::UpperTriangular, b::UpperTriangular) - return UpperTriangular(Zygote.accum(a.data, b.data)) -end - -Zygote.accum(D::Diagonal{<:Real}, U::UpperTriangular{<:Real}) = UpperTriangular(D + U.data) -Zygote.accum(a::UpperTriangular, b::Diagonal) = Zygote.accum(b, a) - -Zygote._symmetric_back(Δ::UpperTriangular{<:Any, <:SArray}, uplo) = Δ -function Zygote._symmetric_back(Δ::SMatrix{N, N}, uplo) where {N} - if uplo === 'U' - return SMatrix{N, N}(UpperTriangular(Δ) + UpperTriangular(Δ') - Diagonal(Δ)) - else - return SMatrix{N, N}(LowerTriangular(Δ) + LowerTriangular(Δ') - Diagonal(Δ)) - end -end - -# Temporary hacks. - -using Zygote: literal_getproperty, literal_indexed_iterate, literal_getindex - -function Zygote._pullback(::NoContext, ::typeof(*), A::Adjoint, B::AbstractMatrix) - times_pullback(::Nothing) = nothing - times_pullback(Δ) = nothing, Adjoint(B * Δ'), A' * Δ - return A * B, times_pullback -end - -function Zygote._pullback(::NoContext, ::typeof(literal_getproperty), C::Cholesky, ::Val{:U}) - function literal_getproperty_pullback(Δ) - return (nothing, (uplo=nothing, info=nothing, factors=UpperTriangular(Δ))) - end - literal_getproperty_pullback(Δ::Nothing) = nothing - return literal_getproperty(C, Val(:U)), literal_getproperty_pullback -end - -Zygote.accum(x::Adjoint...) = Adjoint(Zygote.accum(map(parent, x)...)) - -Zygote.accum(x::NamedTuple{(:parent,)}, y::Adjoint) = (parent=accum(x.parent, y.parent),) - -function Zygote.accum(A::UpperTriangular{<:Any, <:SMatrix{P}}, B::SMatrix{P, P}) where {P} - return Zygote.accum(SMatrix{P, P}(A), B) -end - -function Zygote.accum(B::SMatrix{P, P}, A::UpperTriangular{<:Any, <:SMatrix{P}}) where {P} - return Zygote.accum(B, SMatrix{P, P}(A)) -end - -function Zygote.accum(a::Tangent{T}, b::NamedTuple) where {T} - return Zygote.accum(a, Tangent{T}(; b...)) -end - -function Base.:(-)( - A::UpperTriangular{<:Real, <:SMatrix{N, N}}, B::Diagonal{<:Real, <:SVector{N}}, -) where {N} - return UpperTriangular(A.data - B) -end - -function _symmetric_back(Δ, uplo) - L, U, D = LowerTriangular(Δ), UpperTriangular(Δ), Diagonal(Δ) - return collect(uplo == Symbol(:U) ? U .+ transpose(L) - D : L .+ transpose(U) - D) -end -_symmetric_back(Δ::Diagonal, uplo) = Δ -_symmetric_back(Δ::UpperTriangular, uplo) = collect(uplo == Symbol('U') ? Δ : transpose(Δ)) -_symmetric_back(Δ::LowerTriangular, uplo) = collect(uplo == Symbol('U') ? transpose(Δ) : Δ) - -function ChainRulesCore.rrule(::Type{Symmetric}, X::StridedMatrix{<:Real}, uplo=:U) - function Symmetric_rrule(Δ) - ΔX = Δ isa AbstractZero ? NoTangent() : _symmetric_back(Δ, uplo) - return NoTangent(), ΔX, NoTangent() - end - return Symmetric(X, uplo), Symmetric_rrule -end - -function rrule(::Type{StructArray}, x::T) where {T<:Union{Tuple,NamedTuple}} - y = StructArray(x) - StructArray_rrule(Δ::Thunk) = StructArray_rrule(unthunk(Δ)) - function StructArray_rrule(Δ) - return NoTangent(), Tangent{T}(StructArrays.components(backing.(Δ))...) - end - function StructArray_rrule(Δ::AbstractArray) - return NoTangent(), Tangent{T}((getproperty.(Δ, p) for p in propertynames(y))...) - end - return y, StructArray_rrule -end -function rrule(::Type{StructArray{X}}, x::T) where {X,T<:Union{Tuple,NamedTuple}} - y = StructArray{X}(x) - function StructArray_rrule(Δ) - return NoTangent(), Tangent{T}(StructArrays.components(backing.(Δ))...) - end - function StructArray_rrule(Δ::Tangent) - return NoTangent(), Tangent{T}(Δ.components...) - end - return y, StructArray_rrule -end - - -# `getproperty` accesses the `components` field of a `StructArray`. This rule makes that -# explicit. -# function rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(Base.getproperty), x::StructArray, ::Val{p}, -# ) where {p} -# value, pb = rrule_via_ad(config, Base.getproperty, StructArrays.components(x), Val(p)) -# function getproperty_rrule(Δ) -# return NoTangent(), Tangent{typeof(x)}(components=pb(Δ)[2]), NoTangent() -# end -# return value, getproperty_rrule -# end - -function time_ad(label::String, f, x...) - println("primal: ", label) - return @time f(x...) -end - -time_ad(::Val{:disabled}, label::String, f, x...) = f(x...) - -function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(time_ad), label::String, f, x...) - println("Forward: ", label) - out, pb = @time rrule_via_ad(config, f, x...) - function time_ad_pullback(Δ) - println("Pullback: ", label) - Δinputs = @time pb(Δ) - return (NoTangent(), NoTangent(), NoTangent(), Δinputs...) - end - return out, time_ad_pullback -end - -function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(\), A::Diagonal{<:Real}, x::Vector{<:Real}) - out, pb = rrule_via_ad(config, (a, x) -> a .\ x, diag(A), x) - function ldiv_pullback(Δ) - if Δ isa AbstractZero - return NoTangent() - else - _, Δa, Δx = pb(Δ) - return NoTangent(), Diagonal(Δa), Δx - end - end - return out, ldiv_pullback -end - -function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(\), A::Diagonal{<:Real}, x::Matrix{<:Real}) - out, pb = rrule_via_ad(config, (a, x) -> a .\ x, diag(A), x) - function ldiv_pullback(Δ) - if Δ isa AbstractZero - return NoTangent() - else - _, Δa, Δx = pb(Δ) - return NoTangent(), Diagonal(Δa), Δx - end - end - return out, ldiv_pullback -end - -using Base.Broadcast: broadcasted - -function ChainRulesCore.rrule(::typeof(broadcasted), ::typeof(\), a::Vector{<:Real}, x::Vector{<:Real}) - y = a .\ x - broadcast_ldiv_pullback(::AbstractZero) = NoTangent(), NoTangent(), NoTangent() - broadcast_ldiv_pullback(Δ::AbstractVector{<:Real}) = NoTangent(), NoTangent(), -(Δ .* y ./ a), a .\ Δ - return y, broadcast_ldiv_pullback -end - -function ChainRulesCore.rrule(::typeof(broadcasted), ::typeof(\), a::Vector{<:Real}, x::Matrix{<:Real}) - y = a .\ x - broadcast_ldiv_pullback(::AbstractZero) = NoTangent(), NoTangent(), NoTangent() - broadcast_ldiv_pullback(Δ::AbstractMatrix{<:Real}) = NoTangent(), NoTangent(), -vec(sum(Δ .* y ./ a; dims=2)), a .\ Δ - return y, broadcast_ldiv_pullback -end diff --git a/src/util/gaussian.jl b/src/util/gaussian.jl index 25c531f..5f9667d 100644 --- a/src/util/gaussian.jl +++ b/src/util/gaussian.jl @@ -20,9 +20,9 @@ end dim(x::Gaussian) = length(x.m) -AbstractGPs.mean(x::Gaussian) = Zygote.literal_getfield(x, Val(:m)) +AbstractGPs.mean(x::Gaussian) = x.m -AbstractGPs.cov(x::Gaussian) = Zygote.literal_getfield(x, Val(:P)) +AbstractGPs.cov(x::Gaussian) = x.P AbstractGPs.var(x::Gaussian{<:AbstractVector}) = diag(cov(x)) @@ -70,13 +70,6 @@ storage_type(::Gaussian{<:Vector{T}}) where {T<:Real} = ArrayStorage(T) storage_type(::Gaussian{<:SVector{D, T}}) where {D, T<:Real} = SArrayStorage(T) storage_type(::Gaussian{T}) where {T<:Real} = ScalarStorage(T) -function ChainRulesCore.rrule(::Type{<:Gaussian}, m, P) - proj_P = ProjectTo(P) - Gaussian_pullback(::ZeroTangent) = NoTangent(), NoTangent(), NoTangent() - Gaussian_pullback(Δ) = NoTangent(), Δ.m, proj_P(Δ.P) - return Gaussian(m, P), Gaussian_pullback -end - Base.length(x::Gaussian) = 0 # Zero-adjoint initialisation for the benefit of `scan`. diff --git a/src/util/harmonise.jl b/src/util/harmonise.jl deleted file mode 100644 index 5989d89..0000000 --- a/src/util/harmonise.jl +++ /dev/null @@ -1,122 +0,0 @@ -# All of this functionality is utilised only in the AD tests. Can be safely ignored if -# you're concerned with understanding how TemporalGPs works. - -using ChainRulesCore: backing - -# Functionality to test my testing functionality. -are_harmonised(a::Any, b::AbstractZero) = true -are_harmonised(a::AbstractZero, b::Any) = true -are_harmonised(a::AbstractZero, b::AbstractZero) = true - -are_harmonised(a::Number, b::Number) = true - -function are_harmonised(a::AbstractArray, b::AbstractArray) - return all(ab -> are_harmonised(ab...), zip(a, b)) -end - -are_harmonised(a::Tuple, b::Tuple) = all(ab -> are_harmonised(ab...), zip(a, b)) - -function are_harmonised(a::Tangent{<:Any, <:Tuple}, b::Tangent{<:Any, <:Tuple}) - return all(ab -> are_harmonised(ab...), zip(a, b)) -end - -function are_harmonised( - a::Tangent{<:Any, <:NamedTuple}, - b::Tangent{<:Any, <:NamedTuple}, -) - return all( - name -> are_harmonised(getproperty(a, name), getproperty(b, name)), - union(fieldnames(typeof(a)), fieldnames(typeof(b))), - ) -end - -# Functionality to make it possible to compare different kinds of differentials. It's not -# entirely clear how much sense this makes mathematically, but it seems to work in a -# practical sense at the minute. -harmonise(a::Any, b::AbstractZero) = (a, b) -harmonise(a::AbstractZero, b::Any) = (a, b) -harmonise(a::AbstractZero, b::AbstractZero) = (a, b) - -# Resolve ambiguity. -harmonise(a::AbstractZero, b::Tangent{<:Any, <:NamedTuple}) = (a, b) - -harmonise(a::Number, b::Number) = (a, b) - -function harmonise(a::Tuple, b::Tuple) - vals = map(harmonise, a, b) - return first.(vals), last.(vals) -end -function harmonise(a::AbstractArray, b::AbstractArray) - vals = map(harmonise, a, b) - return first.(vals), last.(vals) -end - -function harmonise(a::Adjoint, b::Adjoint) - vals = harmonise(a.parent, b.parent) - return Tangent{Any}(parent=vals[1]), Tangent{Any}(parent=vals[2]) -end - -function harmonise(a::Tangent{<:Any, <:Tuple}, b::Tangent{<:Any, <:Tuple}) - vals = map(harmonise, backing(a), backing(b)) - return (Tangent{Any}(first.(vals)...), Tangent{Any}(last.(vals)...)) -end - -harmonise(a::Tangent{<:Any, <:Tuple}, b::Tuple) = harmonise(a, Tangent{Any}(b...)) - -harmonise(a::Tuple, b::Tangent{<:Any, <:Tuple}) = harmonise(Tangent{Any}(a...), b) - -function harmonise( - a::Tangent{<:Any, <:NamedTuple{names}}, - b::Tangent{<:Any, <:NamedTuple{names}}, -) where {names} - vals = map(harmonise, values(backing(a)), values(backing(b))) - a_harmonised = Tangent{Any}(; NamedTuple{names}(first.(vals))...) - b_harmonised = Tangent{Any}(; NamedTuple{names}(last.(vals))...) - return (a_harmonised, b_harmonised) -end - -function harmonise(a::Tangent{<:Any, <:NamedTuple}, b::Tangent{<:Any, <:NamedTuple}) - - # Compute names missing / present in each data structure. - a_names = propertynames(backing(a)) - b_names = propertynames(backing(b)) - mutual_names = intersect(a_names, b_names) - all_names = (union(a_names, b_names)..., ) - a_missing_names = setdiff(all_names, a_names) - b_missing_names = setdiff(all_names, b_names) - - # Construct `Tangent`s with the same names. - a_vals = map(name -> name ∈ a_names ? getproperty(a, name) : ZeroTangent(), all_names) - b_vals = map(name -> name ∈ b_names ? getproperty(b, name) : ZeroTangent(), all_names) - a_unioned_names = Tangent{Any}(; NamedTuple{all_names}(a_vals)...) - b_unioned_names = Tangent{Any}(; NamedTuple{all_names}(b_vals)...) - - # Harmonise those composites. - return harmonise(a_unioned_names, b_unioned_names) -end - -function harmonise(a::Tangent{<:Any, <:NamedTuple}, b) - b_names = fieldnames(typeof(b)) - vals = map(name -> getfield(b, name), b_names) - return harmonise( - a, Tangent{Any}(; NamedTuple{b_names}(vals)...), - ) -end - -harmonise(x::AbstractMatrix, y::NamedTuple{(:diag,)}) = (diag(x), y.diag) -function harmonise(x::AbstractVector, y::NamedTuple{(:value,:axes)}) - x = reduce(Zygote.accum, x) - (x, y.value) -end - - -harmonise(a::Tangent{<:Any, <:NamedTuple}, b::AbstractZero) = (a, b) - -harmonise(a, b::Tangent{<:Any, <:NamedTuple}) = reverse(harmonise(b, a)) - -# Special-cased handling for `Adjoint`s. Due to our usual AD setup, a differential for an -# Adjoint can be represented either by a matrix or a `Tangent`. Both ought to `to_vec` to -# the same thing though, so this should be fine for now, if a little unsatisfactory. -function harmonise(a::Adjoint, b::Tangent{<:Adjoint, <:NamedTuple}) - return Tangent{Any}(parent=parent(a)), b -end diff --git a/src/util/regular_data.jl b/src/util/regular_data.jl index d9c59ff..50fe17b 100644 --- a/src/util/regular_data.jl +++ b/src/util/regular_data.jl @@ -23,11 +23,3 @@ Base.size(x::RegularSpacing) = (x.N,) Base.getindex(x::RegularSpacing, n::Int) = x.t0 + (n - 1) * x.Δt Base.step(x::RegularSpacing) = x.Δt - -function ChainRulesCore.rrule(::Type{TR}, t0::T, Δt::T, N::Int) where {TR<:RegularSpacing, T<:Real} - function RegularSpacing_rrule(Δ) - Δ = unthunk(Δ) - return NoTangent(), Δ.t0, Δ.Δt, NoTangent() - end - return RegularSpacing(t0, Δt, N), RegularSpacing_rrule -end diff --git a/src/util/scan.jl b/src/util/scan.jl index 8ce67db..72694e6 100644 --- a/src/util/scan.jl +++ b/src/util/scan.jl @@ -27,67 +27,6 @@ function scan_emit(f, xs, state, idx) return (ys, state) end -function rrule(config::RuleConfig, ::typeof(scan_emit), f, xs, init_state, idx) - state = init_state - (y, state) = f(state, _getindex(xs, idx[1])) - - # Heuristic Warning: assume all ys and states have the same type as the 1st. - ys = Vector{typeof(y)}(undef, length(xs)) - states = Vector{typeof(state)}(undef, length(xs)) - - ys[idx[1]] = y - states[idx[1]] = state - - for t in idx[2:end] - (y, state) = f(state, _getindex(xs, t)) - ys[t] = y - states[t] = state - end - - function scan_emit_rrule(Δ) - Δ isa AbstractZero && return ntuple(_->NoTangent(), 5) - Δys = Δ[1] - Δstate = Δ[2] - - # This is a hack to handle the case that Δstate=nothing, and the "look at the - # type of the first thing" heuristic breaks down. - Δstate = Δ[2] isa AbstractZero ? _get_zero_adjoint(states[idx[end]]) : Δ[2] - - T = length(idx) - if T > 1 - _, Δstate, Δx = step_pullback( - config, f, states[idx[T-1]], _getindex(xs, idx[T]), Δys[idx[T]], Δstate, - ) - Δxs = get_adjoint_storage(xs, idx[T], Δx) - for t in reverse(2:(T - 1)) - a = _getindex(xs, idx[t]) - b = Δys[idx[t]] - c = states[idx[t-1]] - _, Δstate, Δx = step_pullback( - config, f, c, a, b, Δstate, - ) - Δxs = _accum_at(Δxs, idx[t], Δx) - end - _, Δstate, Δx = step_pullback( - config, f, init_state, _getindex(xs, idx[1]), Δys[idx[1]], Δstate, - ) - Δxs = _accum_at(Δxs, idx[1], Δx) - return NoTangent(), NoTangent(), Δxs, Δstate, NoTangent() - else - _, Δstate, Δx = step_pullback( - config, f, init_state, _getindex(xs, idx[1]), Δys[idx[1]], Δstate, - ) - Δxs = get_adjoint_storage(xs, idx[1], Δx) - return NoTangent(), NoTangent(), Δxs, Δstate, NoTangent() - end - end - return (ys, state), scan_emit_rrule -end - -@inline function step_pullback(config::RuleConfig, f::Tf, state, x, Δy, Δstate) where {Tf} - _, pb = rrule_via_ad(config, f, state, x) - return pb((Δy, Δstate)) -end # Helper functionality for constructing appropriate differentials. @@ -100,78 +39,3 @@ _getindex(x, idx::Int) = getindex(x, idx) _getindex(x::Base.Iterators.Zip, idx::Int) = __getindex(x.is, idx) __getindex(x::Tuple{Any}, idx::Int) = (_getindex(x[1], idx), ) __getindex(x::Tuple, idx::Int) = (_getindex(x[1], idx), __getindex(Base.tail(x), idx)...) - - -_get_zero_adjoint(::Any) = ZeroTangent() - -# Vector. In all probability, only one of these methods is necessary. - -function get_adjoint_storage(x::Array, n::Int, Δx::T) where {T} - x̄ = Array{T}(undef, size(x)) - x̄[n] = Δx - return x̄ -end - -@inline function _accum_at(Δxs::Vector{T}, n::Int, Δx::T) where {T} - Δxs[n] = Δx - return Δxs -end - -@inline function _accum_at(Δxs::Vector{T}, n::Int, Δx::AbstractMatrix) where {T<:AbstractMatrix} - Δxs[n] = convert(T, Δx) - return Δxs -end - -# If there's nothing, there's nothing to do. -_accum_at(::AbstractZero, ::Int, ::AbstractZero) = NoTangent() - -# Zip -function get_adjoint_storage(x::Base.Iterators.Zip, n::Int, Δx::Tangent) - return (is=map((x_, Δx_) -> get_adjoint_storage(x_, n, Δx_), x.is, backing(Δx)),) -end - -# This is a work-around for `map` not inferring for some unknown reason. Very odd... -function _accum_at(Δxs::NamedTuple{(:is, )}, n::Int, Δx::Tangent) - return (is=__accum_at(Δxs.is, n, backing(Δx)), ) -end -__accum_at(Δxs::Tuple{Any}, n::Int, Δx::Tuple{Any}) = (_accum_at(Δxs[1], n, Δx[1]), ) -function __accum_at(Δxs::Tuple, n::Int, Δx::Tuple) - return (_accum_at(Δxs[1], n, Δx[1]), __accum_at(Base.tail(Δxs), n, Base.tail(Δx))...) -end -# Fill - -get_adjoint_storage(::Fill, ::Int, init) = (value=init, axes=NoTangent()) - -# T is not parametrized since T can be SMatrix and Δx isa SizedMatrix -@inline function _accum_at( - Δxs::NamedTuple{(:value, :axes)}, ::Int, Δx, -) - return (value=Zygote.accum(Δxs.value, Δx), axes=NoTangent()) -end - - - -# StructArray - -function get_adjoint_storage(x::StructArray, n::Int, Δx::Tangent) - init_arrays = map( - (x_, Δx_) -> get_adjoint_storage(x_, n, Δx_), getfield(x, :components), ChainRulesCore.backing(Δx), - ) - return (components = init_arrays, ) -end - -function get_adjoint_storage(x::StructArray, n::Int, Δx::StaticVector) - init_arrays = map( - (x_, Δx_) -> get_adjoint_storage(x_, n, Δx_), getfield(x, :components), Δx, - ) - return (components = init_arrays, ) -end - -# _accum_at for StructArrayget_adjoint_storage(xs, idx[T], Δx) -function _accum_at(Δxs::NamedTuple{(:components,)}, n::Int, Δx::Tangent) - return (components = map((Δy, y) -> _accum_at(Δy, n, y), Δxs.components, backing(Δx)), ) -end - -function _accum_at(Δxs::NamedTuple{(:components,)}, n::Int, Δx::SVector) - return (components = map((Δy, y) -> _accum_at(Δy, n, y), Δxs.components, backing(Δx)), ) -end diff --git a/src/util/zygote_friendly_map.jl b/src/util/zygote_friendly_map.jl deleted file mode 100644 index ab0eba2..0000000 --- a/src/util/zygote_friendly_map.jl +++ /dev/null @@ -1,78 +0,0 @@ -""" - zygote_friendly_map(f, x) - -This version of map is a bit weird. It makes slightly stronger assumptions about the nature -of what you're allowed to pass in to it than `Base.map` does and, in return, you get much -improved performance when used in conjunction with `Zygote`. - -# Assumptions. -- No globals are used in `f`. This means that `TemporalGPs.NoContext` can be employed. -- `f` has no fields. If you've got data to share across elements, use a `Fill`. -- Similarly, `f` has no mutable state (follows from the above). -- `f` doesn't mutate its argument. -""" -zygote_friendly_map(f, x) = dense_zygote_friendly_map(f, x) - -function dense_zygote_friendly_map(f::Tf, x) where {Tf} - - # Perform first iteration. - y_1 = f(_getindex(x, 1)) - - # Allocate for outputs. - ys = Array{typeof(y_1)}(undef, size(x)) - ys[1] = y_1 - - # Perform remainder of iterations. - for n in 2:length(x) - ys[n] = f(_getindex(x, n)) - end - - return ys -end - -function ChainRulesCore.rrule(::typeof(dense_zygote_friendly_map), f::Tf, x) where {Tf} - - # Perform first iteration. - y_1, pb_1 = rrule_via_ad(Zygote.ZygoteRuleConfig(NoContext()), f, _getindex(x, 1)) - - # Allocate for outputs. - ys = Array{typeof(y_1)}(undef, size(x)) - ys[1] = y_1 - - # Allocate for pullbacks. - pbs = Array{typeof(pb_1)}(undef, size(x)) - pbs[1] = pb_1 - - for n in 2:length(x) - y, pb = rrule_via_ad(Zygote.ZygoteRuleConfig(NoContext()), f, _getindex(x, n)) - ys[n] = y - pbs[n] = pb - end - - function zygote_friendly_map_pullback(Δ) - Δ isa AbstractZero && return NoTangent(), NoTangent(), NoTangent() - - # Do first iteration. - Δx_1 = pbs[1](Δ[1]) - - # Allocate for cotangents. - Δxs = get_adjoint_storage(x, 1, Δx_1[2]) - - for n in 2:length(x) - Δx = pbs[n](Δ[n]) - Δxs = _accum_at(Δxs, n, Δx[2]) - end - - return NoTangent(), NoTangent(), Δxs - end - - return ys, zygote_friendly_map_pullback -end - -zygote_friendly_map(f, x::Fill) = map(f, x) - -function zygote_friendly_map( - f, x::Base.Iterators.Zip{<:Tuple{Vararg{Fill, N}}}, -) where {N} - return zygote_friendly_map(f, Fill(map(first, x.is), length(x))) -end diff --git a/test/Project.toml b/test/Project.toml index d3eabb4..9028d8f 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -2,10 +2,7 @@ AbstractGPs = "99985d1d-32ba-4be9-9821-2ec096f28918" BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" BlockDiagonals = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0" -ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" -FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" KernelFunctions = "ec8451be-7e33-11e9-00cf-bbf324bd1392" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" @@ -13,17 +10,12 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" -Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] AbstractGPs = "0.5" BenchmarkTools = "0.5" BlockDiagonals = "0.1" -ChainRulesCore = "1" -ChainRulesTestUtils = "1.10" FillArrays = "0.13.0 - 0.13.7" -FiniteDifferences = "0.12" KernelFunctions = "0.10" StaticArrays = "1" StructArrays = "0.6" -Zygote = "0.6" diff --git a/test/gp/lti_sde.jl b/test/gp/lti_sde.jl index a4f2057..5abc9d2 100644 --- a/test/gp/lti_sde.jl +++ b/test/gp/lti_sde.jl @@ -1,6 +1,5 @@ using KernelFunctions using KernelFunctions: kappa -using ChainRulesTestUtils using TemporalGPs: build_lgssm, StorageType, is_of_storage_type, lgssm_components using Test @@ -44,20 +43,6 @@ end println("lti_sde:") @testset "lti_sde" begin - @testset "block_diagonal" begin - A = randn(2, 2) - B = randn(3, 3) - C = randn(5, 5) - test_rrule(TemporalGPs.block_diagonal, A, B, C; check_inferred=false) - test_rrule( - TemporalGPs.block_diagonal, - SMatrix{2,2}(A), - SMatrix{3,3}(B), - SMatrix{5,5}(C); - check_inferred=false, - ) - end - @testset "SimpleKernel parameter types" begin storages = ( (name="dense storage Float64", val=ArrayStorage(Float64)), @@ -208,48 +193,6 @@ println("lti_sde:") @test last(m_and_v) ≈ var(fx) @test logpdf(fx, y) ≈ logpdf(fx_naive, y) end - - @testset "check args to_vec properly" begin - k_vec, k_from_vec = to_vec(kernel.val) - @test typeof(k_from_vec(k_vec)) == typeof(kernel.val) - - storage_vec, storage_from_vec = to_vec(storage.val) - @test typeof(storage_from_vec(storage_vec)) == typeof(storage.val) - - σ²_vec, σ²_from_vec = to_vec(σ².val) - @test typeof(σ²_from_vec(σ²_vec)) == typeof(σ².val) - - t_vec, t_from_vec = to_vec(t.val) - @test typeof(t_from_vec(t_vec)) == typeof(t.val) - end - - # Just need to ensure we can differentiate through construction properly. - if isnothing(kernel.to_vec_grad) - @test_broken false # "Gradient tests are not passing" - continue - elseif kernel.to_vec_grad - test_zygote_grad_finite_differences_compatible( - _construction_tester, - f_naive, - storage.val, - σ².val, - t.val; - check_inferred=false, - rtol=1e-6, - atol=1e-6, - ) - else - test_zygote_grad( - _construction_tester, - f_naive, - storage.val, - σ².val, - t.val; - check_inferred=false, - rtol=1e-6, - atol=1e-6, - ) - end end end end diff --git a/test/models/lgssm.jl b/test/models/lgssm.jl index 19662b6..97dd73c 100644 --- a/test/models/lgssm.jl +++ b/test/models/lgssm.jl @@ -16,14 +16,12 @@ using TemporalGPs: ScalarOutputLGC, Forward, Reverse, - ordering, - NoContext + ordering using KernelFunctions using Test using Random: MersenneTwister using LinearAlgebra using StructArrays -using Zygote, StaticArrays println("lgssm:") @testset "lgssm" begin @@ -91,7 +89,6 @@ println("lgssm:") @testset "step_marginals" begin @inferred step_marginals(x, model[1]) - adjoint_test(step_marginals, (x, model[1])) if storage.val isa SArrayStorage && TEST_ALLOC check_adjoint_allocations(step_marginals, (x, model[1])) end @@ -99,7 +96,6 @@ println("lgssm:") @testset "step_logpdf" begin args = (ordering(model[1]), x, (model[1], y)) @inferred step_logpdf(args...) - adjoint_test(step_logpdf, args) if storage.val isa SArrayStorage && TEST_ALLOC check_adjoint_allocations(step_logpdf, args) end @@ -107,7 +103,6 @@ println("lgssm:") @testset "step_filter" begin args = (ordering(model[1]), x, (model[1], y)) @inferred step_filter(args...) - adjoint_test(step_filter, args) if storage.val isa SArrayStorage && TEST_ALLOC check_adjoint_allocations(step_filter, args) end @@ -115,7 +110,6 @@ println("lgssm:") @testset "invert_dynamics" begin args = (x, x, model[1].transition) @inferred invert_dynamics(args...) - adjoint_test(invert_dynamics, args) if storage.val isa SArrayStorage && TEST_ALLOC check_adjoint_allocations(invert_dynamics, args) end @@ -123,7 +117,6 @@ println("lgssm:") @testset "step_posterior" begin args = (ordering(model[1]), x, (model[1], y)) @inferred step_posterior(args...) - adjoint_test(step_posterior, args) if storage.val isa SArrayStorage && TEST_ALLOC check_adjoint_allocations(step_posterior, args) end @@ -134,7 +127,6 @@ println("lgssm:") rng, model; rtol=1e-5, atol=1e-5, - context=NoContext(), max_primal_allocs=25, max_forward_allocs=25, max_backward_allocs=25, diff --git a/test/models/linear_gaussian_conditionals.jl b/test/models/linear_gaussian_conditionals.jl index 9e0e7ba..57f8a00 100644 --- a/test/models/linear_gaussian_conditionals.jl +++ b/test/models/linear_gaussian_conditionals.jl @@ -58,8 +58,6 @@ println("linear_gaussian_conditionals:") # Check that everything infers and AD gives the right answer. @inferred posterior_and_lml(x, model, y_missing) - # BROKEN: gradients with Zygote look fine but are failing because of ChainRulesTestUtils checks see https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/270 - # test_zygote_grad(posterior_and_lml, x, model, y_missing) end end @@ -101,11 +99,6 @@ println("linear_gaussian_conditionals:") # Check that they give roughly the same answer. @test x_post_vanilla ≈ x_post_large @test lml_vanilla ≈ lml_large rtol=1e-8 atol=1e-8 - - # Check that everything infers and AD gives the right answer. - @inferred posterior_and_lml(x, model, y_missing) - x̄ = adjoint_test(posterior_and_lml, (x, model, y_missing)) - @test x̄[2].Q isa NamedTuple{(:diag, )} end end @@ -201,11 +194,6 @@ println("linear_gaussian_conditionals:") # Check that they give roughly the same answer. @test x_post_vanilla ≈ x_post_large rtol=1e-8 atol=1e-8 @test lml_vanilla ≈ lml_large rtol=1e-8 atol=1e-8 - - # Check that everything infers and AD gives the right answer. - @inferred posterior_and_lml(x, model, y_missing) - x̄ = adjoint_test(posterior_and_lml, (x, model, y_missing)) - @test x̄[2].fan_out.Q isa NamedTuple{(:diag, )} end end end diff --git a/test/models/missings.jl b/test/models/missings.jl index 3b4084e..4b42596 100644 --- a/test/models/missings.jl +++ b/test/models/missings.jl @@ -4,8 +4,6 @@ using TemporalGPs: replace_observation_noise_cov, transform_model_and_obs using Random: randperm -using ChainRulesTestUtils -using Zygote: Context @info "missings:" @testset "missings" begin @@ -122,14 +120,6 @@ using Zygote: Context @test logpdf(new_posterior, new_y) ≈ logpdf(post, y_missing) rtol=1e-4 end - - # Only test the bits of AD that we haven't tested before. - @testset "AD: transform_model_and_obs" begin - fdm = central_fdm(2, 1) - adjoint_test(fill_in_missings, (model.emissions.Q, y_missing); fdm=fdm) - adjoint_test(replace_observation_noise_cov, (model, model.emissions.Q)) - adjoint_test(transform_model_and_obs, (model, y_missing); fdm=fdm) - end end storages = ( @@ -179,7 +169,6 @@ using Zygote: Context # Check logpdf and inference run, infer, and play nicely with AD. @inferred logpdf(model, y_missing) - test_zygote_grad_finite_differences_compatible(y -> logpdf(model, y) ⊢ NoTangent(), y_missing) @inferred posterior(model, y_missing) end end; diff --git a/test/models/model_test_utils.jl b/test/models/model_test_utils.jl index 5d7f3e4..9c6553f 100644 --- a/test/models/model_test_utils.jl +++ b/test/models/model_test_utils.jl @@ -1,4 +1,3 @@ -using ChainRulesTestUtils: ChainRulesTestUtils, rand_tangent using FillArrays using Random: AbstractRNG using TemporalGPs: @@ -92,13 +91,6 @@ function random_gaussian(rng::AbstractRNG, dim::Int, s::StorageType) return Gaussian(random_vector(rng, dim, s), random_nice_psd_matrix(rng, dim, s)) end -function ChainRulesTestUtils.rand_tangent(rng::AbstractRNG, d::T) where {T<:Gaussian} - return Tangent{T}( - m=rand_tangent(rng, d.m), - P=random_nice_psd_matrix(rng, length(d.m), storage_type(d)), - ) -end - # Generation of SmallOutputLGC. @@ -185,16 +177,6 @@ function random_ti_gmm(rng::AbstractRNG, ordering, Dlat::Int, N::Int, s::Storage return GaussMarkovModel(ordering, As, as, Qs, x0) end -function ChainRulesTestUtils.rand_tangent(rng::AbstractRNG, gmm::T) where {T<:GaussMarkovModel} - return Tangent{T}( - ordering = nothing, - As = rand_tangent(rng, gmm.As), - as = rand_tangent(rng, gmm.as), - Qs = gmm_Qs_tangent(rng, gmm.Qs, storage_type(gmm)), - x0 = rand_tangent(rng, gmm.x0), - ) -end - function gmm_Qs_tangent( rng::AbstractRNG, Qs::T, storage_type::StorageType, ) where {T<:Vector{<:AbstractMatrix}} @@ -302,41 +284,6 @@ function random_lgssm( return LGSSM(transitions, emissions) end -function ChainRulesTestUtils.rand_tangent(rng::AbstractRNG, ssm::T) where {T<:LGSSM} - Hs = ssm.emissions.A - hs = ssm.emissions.a - Σs = ssm.emissions.Q - return Tangent{T}( - transitions = rand_tangent(rng, ssm.transitions), - emissions = Tangent{typeof(ssm.emissions)}(components=( - A=rand_tangent(rng, Hs), - a=rand_tangent(rng, hs), - Q=gmm_Qs_tangent(rng, Σs, storage_type(ssm)), - )), - ) -end - -# function random_tv_scalar_lgssm(rng::AbstractRNG, Dlat::Int, N::Int, storage) -# return ScalarLGSSM(random_tv_lgssm(rng, Dlat, 1, N, storage)) -# end - -# function random_ti_scalar_lgssm(rng::AbstractRNG, Dlat::Int, N::Int, storage) -# return ScalarLGSSM(random_ti_lgssm(rng, Dlat, 1, N, storage)) -# end - -# function random_tv_posterior_lgssm(rng::AbstractRNG, Dlat::Int, Dobs::Int, N::Int, storage) -# lgssm = random_tv_lgssm(rng, Dlat, Dobs, N, storage) -# y = rand(rng, lgssm) -# Σs = map(_ -> random_nice_psd_matrix(rng, Dobs, storage), eachindex(y)) -# return posterior(lgssm, y, Σs) -# end - -# function random_ti_posterior_lgssm(rng::AbstractRNG, Dlat::Int, Dobs::Int, N::Int, storage) -# lgssm = random_ti_lgssm(rng, Dlat, Dobs, N, storage) -# y = rand(rng, lgssm) -# Σs = Fill(random_nice_psd_matrix(rng, Dobs, storage), length(lgssm)) -# return posterior(lgssm, y, Σs) -# end # @@ -391,16 +338,3 @@ function validate_dims(model::LGSSM) return nothing end - -# function __verify_model_properties(model, Dlat, Dobs, N, storage_type) -# @test is_of_storage_type(model, storage_type) -# @test length(model) == N -# @test dim_obs(model) == Dobs -# @test dim_latent(model) == Dlat -# validate_dims(model) -# return nothing -# end - -# function __verify_model_properties(model, Dlat, N, storage_type) -# return __verify_model_properties(model, Dlat, 1, N, storage_type) -# end diff --git a/test/runtests.jl b/test/runtests.jl index 0620b70..bb6b7d7 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -22,10 +22,7 @@ if OUTER_GROUP == "test" || OUTER_GROUP == "all" using AbstractGPs using BlockDiagonals - using ChainRulesCore - using ChainRulesTestUtils using FillArrays - using FiniteDifferences using LinearAlgebra using KernelFunctions using Random @@ -33,11 +30,8 @@ if OUTER_GROUP == "test" || OUTER_GROUP == "all" using StructArrays using TemporalGPs - using Zygote - using AbstractGPs: var - using TemporalGPs: AbstractLGSSM, _filter, NoContext - using Zygote: Context, _pullback + using TemporalGPs: AbstractLGSSM, _filter include("test_util.jl") @@ -48,10 +42,7 @@ if OUTER_GROUP == "test" || OUTER_GROUP == "all" if TEST_GROUP == "util" || GROUP == "all" println("util:") @testset "util" begin - include(joinpath("util", "harmonise.jl")) include(joinpath("util", "scan.jl")) - include(joinpath("util", "zygote_friendly_map.jl")) - include(joinpath("util", "chainrules.jl")) include(joinpath("util", "gaussian.jl")) include(joinpath("util", "mul.jl")) include(joinpath("util", "regular_data.jl")) diff --git a/test/space_time/pseudo_point.jl b/test/space_time/pseudo_point.jl index e90f037..5c0c238 100644 --- a/test/space_time/pseudo_point.jl +++ b/test/space_time/pseudo_point.jl @@ -102,8 +102,6 @@ using Test elbo_sde = elbo(fx, y, z_r) @test elbo_naive ≈ elbo_sde rtol=1e-6 - test_zygote_grad_finite_differences_compatible((y, z_r) -> elbo(fx, y, z_r), y, z_r) - # Compute approximate posterior marginals naively. f_approx_post_naive = posterior(VFE(f_naive(z_naive)), fx_naive, y) x_pr = RectilinearGrid(x_pr_r, get_times(x.val)) diff --git a/test/space_time/rectilinear_grid.jl b/test/space_time/rectilinear_grid.jl index fd21e76..efdaa87 100644 --- a/test/space_time/rectilinear_grid.jl +++ b/test/space_time/rectilinear_grid.jl @@ -1,15 +1,6 @@ using Random using TemporalGPs: RectilinearGrid, SpaceTimeGrid -function FiniteDifferences.to_vec(x::RectilinearGrid) - v, tup_from_vec = to_vec((x.xl, x.xr)) - function RectilinearGrid_from_vec(v) - tup = tup_from_vec(v) - return RectilinearGrid(tup[1], tup[2]) - end - return v, RectilinearGrid_from_vec -end - @testset "rectilinear_grid" begin rng = MersenneTwister(123456) Nl = 5 diff --git a/test/space_time/to_gauss_markov.jl b/test/space_time/to_gauss_markov.jl index 002bb9d..9a7e17b 100644 --- a/test/space_time/to_gauss_markov.jl +++ b/test/space_time/to_gauss_markov.jl @@ -6,17 +6,6 @@ using TemporalGPs: RectilinearGrid, Separable, is_of_storage_type Nt = 5 Nt_pr = 2 - @testset "restructure" begin - adjoint_test( - x -> TemporalGPs.restructure(x, [26, 24, 20, 30]), (randn(100), ); - check_inferred=false, - ) - adjoint_test( - x -> TemporalGPs.restructure(x, [26, 24, 20, 30]), (Fill(randn(), 100), ); - check_inferred=false, - ) - end - k_sep = 1.5 * Separable( SEKernel() ∘ ScaleTransform(1.4), Matern32Kernel() ∘ ScaleTransform(1.3), ) @@ -95,27 +84,5 @@ using TemporalGPs: RectilinearGrid, Separable, is_of_storage_type end end - - # # I'm not checking correctness here, just that it runs. No custom adjoints have been - # # written that are involved in this that aren't tested, so there should be no need - # # to check correctness. - # @testset "logpdf AD" begin - # out, pb = Zygote._pullback(NoContext(), logpdf, ft_sde, y) - # pb(rand_zygote_tangent(out)) - # end - # # adjoint_test(logpdf, (ft_sde, y); fdm=central_fdm(2, 1), check_inferred=false) - - # if t.val isa RegularSpacing - # adjoint_test( - # (r, Δt, y) -> begin - # x = RectilinearGrid(r, RegularSpacing(t.val.t0, Δt, Nt)) - # _f = to_sde(GP(k.val, GPC())) - # _ft = _f(x, σ².val...) - # return logpdf(_ft, y) - # end, - # (r, t.val.Δt, y_sde); - # check_inferred=false, - # ) - # end end end diff --git a/test/test_util.jl b/test/test_util.jl index 083f930..cf99ed6 100644 --- a/test/test_util.jl +++ b/test/test_util.jl @@ -1,8 +1,5 @@ using AbstractGPs using BlockDiagonals -using ChainRulesCore: backing, ZeroTangent, NoTangent, Tangent -using ChainRulesTestUtils: ChainRulesTestUtils, test_approx, rand_tangent, test_rrule, ⊢, @ignore_derivatives -using FiniteDifferences using FillArrays using LinearAlgebra using Random: AbstractRNG, MersenneTwister @@ -13,7 +10,6 @@ using TemporalGPs: AbstractLGSSM, ElementOfLGSSM, Gaussian, - harmonise, Forward, Reverse, GaussMarkovModel, @@ -31,458 +27,9 @@ using TemporalGPs: scan_emit, ε_randn using Test -using Zygote -using Zygote: Context -# Make FiniteDifferences work with some of the types in this package. Shame this isn't -# automated... - -import FiniteDifferences: to_vec - -test_zygote_grad(f, args...; check_inferred=false, kwargs...) = test_rrule(Zygote.ZygoteRuleConfig(), f, args...; rrule_f=rrule_via_ad, check_inferred, kwargs...) - -function test_zygote_grad_finite_differences_compatible(f, args...; kwargs...) - x_vec, from_vec = to_vec(args) - function finite_diff_compatible_f(x::AbstractVector) - return @ignore_derivatives(f)(@ignore_derivatives(from_vec)(x)...) - end - test_zygote_grad(finite_diff_compatible_f ⊢ NoTangent(), x_vec; testset_name="test_rrule: $(f) on $(typeof.(args))", kwargs...) -end - -function to_vec(x::Fill) - x_vec, back_vec = to_vec(FillArrays.getindex_value(x)) - function Fill_from_vec(x_vec) - return Fill(back_vec(x_vec), axes(x)) - end - return x_vec, Fill_from_vec -end - -function to_vec(x::Union{Zeros, Ones}) - return Vector{eltype(x)}(undef, 0), _ -> x -end - -# I'M OVERRIDING FINITEDIFFERENCES DEFINITION HERE. THIS IS BAD. -function to_vec(x::Diagonal) - v, diag_from_vec = to_vec(x.diag) - Diagonal_from_vec(v) = Diagonal(diag_from_vec(v)) - return v, Diagonal_from_vec -end - -# function to_vec(x::T) where {T<:NamedTuple} -# isempty(fieldnames(T)) && throw(error("Expected some fields. None found.")) -# vecs_and_backs = map(name->to_vec(getfield(x, name)), fieldnames(T)) -# vecs, backs = first.(vecs_and_backs), last.(vecs_and_backs) -# x_vec, back = to_vec(vecs) -# function namedtuple_to_vec(x′_vec) -# vecs′ = back(x′_vec) -# x′s = map((back, vec)->back(vec), backs, vecs′) -# return (; zip(fieldnames(T), x′s)...) -# end -# return x_vec, namedtuple_to_vec -# end - -function to_vec(x::T) where {T<:StaticArray} - x_dense = collect(x) - x_vec, back_vec = to_vec(x_dense) - function StaticArray_to_vec(x_vec) - return T(back_vec(x_vec)) - end - return x_vec, StaticArray_to_vec -end - -function to_vec(x::Adjoint{<:Any, T}) where {T<:StaticVector} - x_vec, back = to_vec(Matrix(x)) - Adjoint_from_vec(x_vec) = Adjoint(T(conj!(vec(back(x_vec))))) - return x_vec, Adjoint_from_vec -end - -function to_vec(::Tuple{}) - empty_tuple_from_vec(::AbstractVector) = () - return Bool[], empty_tuple_from_vec -end - -function to_vec(x::StructArray{T}) where {T} - x_vec, x_fields_from_vec = to_vec(StructArrays.components(x)) - function StructArray_from_vec(x_vec) - x_field_vecs = x_fields_from_vec(x_vec) - return StructArray{T}(Tuple(x_field_vecs)) - end - return x_vec, StructArray_from_vec -end - -function to_vec(x::TemporalGPs.LGSSM) - x_vec, from_vec = to_vec((x.transitions, x.emissions)) - function LGSSM_from_vec(x_vec) - (transition, emission) = from_vec(x_vec) - return LGSSM(transition, emission) - end - return x_vec, LGSSM_from_vec -end - -function to_vec(x::ElementOfLGSSM) - x_vec, from_vec = to_vec((x.transition, x.emission)) - function ElementOfLGSSM_from_vec(x_vec) - (transition, emission) = from_vec(x_vec) - return ElementOfLGSSM(x.ordering, transition, emission) - end - return x_vec, ElementOfLGSSM_from_vec -end - -function ChainRulesTestUtils.test_approx(actual::Tangent{<:Fill}, expected, msg=""; kwargs...) - test_approx(actual.value, expected.value, msg; kwargs...) -end - -function to_vec(x::PeriodicKernel) - x, to_r = to_vec(x.r) - function PeriodicKernel_from_vec(x) - return PeriodicKernel(;r=exp.(to_r(x))) - end - log.(x), PeriodicKernel_from_vec -end - -to_vec(x::T) where {T} = generic_struct_to_vec(x) - -# This is a copy from FiniteDifferences.jl without the try catch -function generic_struct_to_vec(x::T) where {T} - Base.isstructtype(T) || throw(error("Expected a struct type")) - isempty(fieldnames(T)) && return (Bool[], _ -> x) # Singleton types - val_vecs_and_backs = map(name -> to_vec(getfield(x, name)), fieldnames(T)) - vals = first.(val_vecs_and_backs) - backs = last.(val_vecs_and_backs) - v, vals_from_vec = to_vec(vals) - function structtype_from_vec(v::Vector{<:Real}) - val_vecs = vals_from_vec(v) - vals = map((b, v) -> b(v), backs, val_vecs) - return T(vals...) - end - return v, structtype_from_vec -end - -to_vec(x::TemporalGPs.RectilinearGrid) = generic_struct_to_vec(x) - -function to_vec(x::AbstractRNG) - return Bool[], _ -> x -end - -Base.zero(x::AbstractRNG) = x - -function to_vec(f::GP) - gp_vec, t_from_vec = to_vec((f.mean, f.kernel)) - function GP_from_vec(v) - m, k = t_from_vec(v) - return GP(m, k) - end - return gp_vec, GP_from_vec -end - -function to_vec(k::ConstantKernel) - c, c_to_vec = to_vec(k.c) - function ConstantKernel_from_vec(c) - return ConstantKernel(c=first(c_to_vec(c))) - end - c, ConstantKernel_from_vec -end - -Base.zero(x::AbstractGPs.ZeroMean) = x -Base.zero(x::Kernel) = x -Base.zero(x::TemporalGPs.LTISDE) = x -Base.zero(x::GP) = x -Base.zero(x::AbstractGPs.MeanFunction) = x - -function to_vec(X::BlockDiagonal) - Xs = blocks(X) - Xs_vec, Xs_from_vec = to_vec(Xs) - - function BlockDiagonal_from_vec(Xs_vec) - Xs = Xs_from_vec(Xs_vec) - return BlockDiagonal(Xs) - end - - return Xs_vec, BlockDiagonal_from_vec -end - -function to_vec(x::RegularSpacing) - RegularSpacing_from_vec(v) = RegularSpacing(v[1], v[2], x.N) - return [x.t0, x.Δt], RegularSpacing_from_vec -end - -# Ensure that to_vec works for the types that we care about in this package. -@testset "custom FiniteDifferences stuff" begin - @testset "NamedTuple" begin - a, b = 5.0, randn(2) - t = (a=a, b=b) - nt_vec, back = to_vec(t) - @test nt_vec isa Vector{Float64} - @test back(nt_vec) == t - end - @testset "Fill" begin - @testset "$(typeof(val))" for val in [5.0, randn(3)] - x = Fill(val, 5) - x_vec, back = to_vec(x) - @test x_vec isa Vector{Float64} - @test back(x_vec) == x - end - end - @testset "Zeros{T}" for T in [Float32, Float64] - x = Zeros{T}(4) - x_vec, back = to_vec(x) - @test x_vec isa Vector{eltype(x)} - @test back(x_vec) == x - end - @testset "gaussian" begin - @testset "Gaussian" begin - x = TemporalGPs.Gaussian(randn(3), randn(3, 3)) - x_vec, back = to_vec(x) - @test back(x_vec) == x - end - end - @testset "to_vec(::SmallOutputLGC)" begin - A = randn(2, 2) - a = randn(2) - Q = randn(2, 2) - model = SmallOutputLGC(A, a, Q) - model_vec, model_from_vec = to_vec(model) - @test model_vec isa Vector{<:Real} - @test model_from_vec(model_vec) == model - end - @testset "to_vec(::GaussMarkovModel)" begin - N = 11 - A = [randn(2, 2) for _ in 1:N] - a = [randn(2) for _ in 1:N] - Q = [randn(2, 2) for _ in 1:N] - H = [randn(3, 2) for _ in 1:N] - h = [randn(3) for _ in 1:N] - x0 = TemporalGPs.Gaussian(randn(2), randn(2, 2)) - gmm = TemporalGPs.GaussMarkovModel(Forward(), A, a, Q, x0) - - gmm_vec, gmm_from_vec = to_vec(gmm) - @test gmm_vec isa Vector{<:Real} - @test gmm_from_vec(gmm_vec) == gmm - end - @testset "StructArray" begin - x = StructArray([Gaussian(randn(2), randn(2, 2)) for _ in 1:10]) - x_vec, x_from_vec = to_vec(x) - @test x_vec isa Vector{<:Real} - @test x_from_vec(x_vec) == x - end - @testset "to_vec(::LGSSM)" begin - N = 11 - - # Build GaussMarkovModel. - A = [randn(2, 2) for _ in 1:N] - a = [randn(2) for _ in 1:N] - Q = [randn(2, 2) for _ in 1:N] - x0 = Gaussian(randn(2), randn(2, 2)) - gmm = GaussMarkovModel(Forward(), A, a, Q, x0) - - # Build LGSSM. - H = [randn(3, 2) for _ in 1:N] - h = [randn(3) for _ in 1:N] - Σ = [randn(3, 3) for _ in 1:N] - model = TemporalGPs.LGSSM(gmm, StructArray(map(SmallOutputLGC, H, h, Σ))) - - model_vec, model_from_vec = to_vec(model) - @test model_from_vec(model_vec) == model - end - @testset "to_vec(::BlockDiagonal)" begin - Ns = [3, 5, 1] - Xs = map(N -> randn(N, N), Ns) - X = BlockDiagonal(Xs) - - X_vec, X_from_vec = to_vec(X) - @test X_vec isa Vector{<:Real} - @test X_from_vec(X_vec) == X - end -end - -my_zero(x) = zero(x) -my_zero(x::AbstractArray{<:Real}) = zero(x) -my_zero(x::AbstractArray) = map(my_zero, x) -my_zero(x::Tuple) = map(my_zero, x) - -# My version of isapprox -function fd_isapprox(x_ad::Nothing, x_fd, rtol, atol) - return fd_isapprox(x_fd, my_zero(x_fd), rtol, atol) -end -function fd_isapprox(x_ad::AbstractArray, x_fd::AbstractArray, rtol, atol) - return all(fd_isapprox.(x_ad, x_fd, rtol, atol)) -end -function fd_isapprox(x_ad::Real, x_fd::Real, rtol, atol) - return isapprox(x_ad, x_fd; rtol=rtol, atol=atol) -end -function fd_isapprox(x_ad::NamedTuple, x_fd, rtol, atol) - f = (x_ad, x_fd)->fd_isapprox(x_ad, x_fd, rtol, atol) - return all([f(getfield(x_ad, key), getfield(x_fd, key)) for key in keys(x_ad)]) -end -function fd_isapprox(x_ad::Tuple, x_fd::Tuple, rtol, atol) - return all(map((x, x′)->fd_isapprox(x, x′, rtol, atol), x_ad, x_fd)) -end -function fd_isapprox(x_ad::Dict, x_fd::Dict, rtol, atol) - return all([fd_isapprox(get(()->nothing, x_ad, key), x_fd[key], rtol, atol) for - key in keys(x_fd)]) -end -function fd_isapprox(x::Gaussian, y::Gaussian, rtol, atol) - return isapprox(x.m, y.m; rtol=rtol, atol=atol) && - isapprox(x.P, y.P; rtol=rtol, atol=atol) -end -function fd_isapprox(x::Real, y::ZeroTangent, rtol, atol) - return fd_isapprox(x, zero(x), rtol, atol) -end -fd_isapprox(x::ZeroTangent, y::Real, rtol, atol) = fd_isapprox(y, x, rtol, atol) - -function fd_isapprox(x_ad::T, x_fd::T, rtol, atol) where {T<:NamedTuple} - f = (x_ad, x_fd)->fd_isapprox(x_ad, x_fd, rtol, atol) - return all([f(getfield(x_ad, key), getfield(x_fd, key)) for key in keys(x_ad)]) -end - -function fd_isapprox(x::T, y::T, rtol, atol) where {T} - if !isstructtype(T) - throw(ArgumentError("Non-struct types are not supported by this fallback.")) - end - - return all(n -> fd_isapprox(getfield(x, n), getfield(y, n), rtol, atol), fieldnames(T)) -end - -function adjoint_test( - f, ȳ, x::Tuple, ẋ::Tuple; - rtol=1e-6, - atol=1e-6, - fdm=central_fdm(5, 1; max_range=1e-3), - test=true, - check_inferred=TEST_TYPE_INFER, - context=Context(), - kwargs..., -) - # Compute = using Zygote. - y, pb = Zygote.pullback(f, x...) - - # Check type inference if requested. - if check_inferred - # @descend only works if you `using Cthulhu`. - # @descend Zygote._pullback(context, f, x...) - # @descend pb(ȳ) - - # @code_warntype Zygote._pullback(context, f, x...) - # @code_warntype pb(ȳ) - @inferred Zygote._pullback(context, f, x...) - @inferred pb(ȳ) - end - x̄ = pb(ȳ) - x̄_ad, ẋ_ad = harmonise(Zygote.wrap_chainrules_input(x̄), ẋ) - inner_ad = dot(x̄_ad, ẋ_ad) - - # Approximate = using FiniteDifferences. - # x̄_fd = j′vp(fdm, f, ȳ, x...) - ẏ = jvp(fdm, f, zip(x, ẋ)...) - - ȳ_fd, ẏ_fd = harmonise(Zygote.wrap_chainrules_input(ȳ), ẏ) - inner_fd = dot(ȳ_fd, ẏ_fd) - # Check that Zygote didn't modify the forwards-pass. - test && @test fd_isapprox(y, f(x...), rtol, atol) - - # Check for approximate agreement in "inner-products". - test && @test fd_isapprox(inner_ad, inner_fd, rtol, atol) - - return x̄ -end - -function adjoint_test(f, input::Tuple; kwargs...) - Δoutput = rand_zygote_tangent(f(input...)) - return adjoint_test(f, Δoutput, input; kwargs...) -end - -function adjoint_test(f, Δoutput, input::Tuple; kwargs...) - ∂input = map(rand_zygote_tangent, input) - return adjoint_test(f, Δoutput, input, ∂input; kwargs...) -end - -function print_adjoints(adjoint_ad, adjoint_fd, rtol, atol) - @show typeof(adjoint_ad), typeof(adjoint_fd) - - # println("ad") - # display(adjoint_ad) - # println() - - # println("fd") - # display(adjoint_fd) - # println() - - adjoint_ad, adjoint_fd = to_vec(adjoint_ad)[1], to_vec(adjoint_fd)[1] - println("atol is $atol, rtol is $rtol") - println("ad, fd, abs, rel") - abs_err = abs.(adjoint_ad .- adjoint_fd) - rel_err = abs_err ./ adjoint_ad - display([adjoint_ad adjoint_fd abs_err rel_err]) - println() -end - -using BenchmarkTools - -# Also checks the forwards-pass because it's helpful. -function check_adjoint_allocations( - f, Δoutput, input::Tuple; - context=NoContext(), - max_primal_allocs=0, - max_forward_allocs=0, - max_backward_allocs=0, - kwargs..., -) - _, pb = _pullback(context, f, input...) - - primal_allocs = allocs(@benchmark($f($input...); samples=1, evals=1)) - forward_allocs = allocs( - @benchmark(_pullback($context, $f, $input...); samples=1, evals=1), - ) - backward_allocs = allocs(@benchmark $pb($Δoutput) samples=1 evals=1) - - # primal_allocs = allocs(@benchmark($f($input...))) - # forward_allocs = allocs( - # @benchmark(_pullback($context, $f, $input...)), - # ) - # backward_allocs = allocs(@benchmark $pb($Δoutput)) - - # @show primal_allocs - # @show forward_allocs - # @show backward_allocs - - @test primal_allocs <= max_primal_allocs - @test forward_allocs <= max_forward_allocs - @test backward_allocs <= max_backward_allocs -end - -function check_adjoint_allocations(f, input::Tuple; kwargs...) - return check_adjoint_allocations(f, rand_zygote_tangent(f(input...)), input; kwargs...) -end - -function benchmark_adjoint(f, ȳ, args...; disp=false) - disp && println("primal") - primal = @benchmark($f($args...); samples=1, evals=1) - if disp - display(primal) - println() - end - - disp && println("pullback generation") - forward_pass = @benchmark(Zygote.pullback($f, $args...); samples=1, evals=1) - if disp - display(forward_pass) - println() - end - - y, back = Zygote.pullback(f, args...) - - disp && println("pullback evaluation") - reverse_pass = @benchmark($back($ȳ); samples=1, evals=1) - if disp - display(reverse_pass) - println() - end - - return primal, forward_pass, reverse_pass -end - function test_interface( rng::AbstractRNG, conditional::AbstractLGC, x::Gaussian; check_inferred=TEST_TYPE_INFER, check_adjoints=true, check_allocs=TEST_ALLOC, atol=1e-6, rtol=1e-6, kwargs..., @@ -494,22 +41,11 @@ function test_interface( @test length(y) == dim_out(conditional) args = (TemporalGPs.ε_randn(rng, conditional), conditional, x_val) check_inferred && @inferred conditional_rand(args...) - if check_adjoints - test_zygote_grad( - conditional_rand, args...; - check_inferred, rtol, atol, - ) - end - if check_allocs - check_adjoint_allocations(conditional_rand, args; kwargs...) - end end @testset "predict" begin @test predict(x, conditional) isa Gaussian check_inferred && @inferred predict(x, conditional) - check_adjoints && adjoint_test(predict, (x, conditional); kwargs...) - check_allocs && check_adjoint_allocations(predict, (x, conditional); kwargs...) end conditional isa ScalarOutputLGC || @testset "predict_marginals" begin @@ -525,21 +61,6 @@ function test_interface( args = (x, conditional, y) @test posterior_and_lml(args...) isa Tuple{Gaussian, Real} check_inferred && @inferred posterior_and_lml(args...) - if check_adjoints - (Δx, Δlml) = rand_zygote_tangent(posterior_and_lml(args...)) - ∂args = map(rand_tangent, args) - adjoint_test(posterior_and_lml, (Δx, Δlml), args, ∂args) - adjoint_test(posterior_and_lml, (Δx, nothing), args, ∂args) - adjoint_test(posterior_and_lml, (nothing, Δlml), args, ∂args) - adjoint_test(posterior_and_lml, (nothing, nothing), args, ∂args) - end - if check_allocs - (Δx, Δlml) = rand_zygote_tangent(posterior_and_lml(args...)) - check_adjoint_allocations(posterior_and_lml, (Δx, Δlml), args; kwargs...) - check_adjoint_allocations(posterior_and_lml, (nothing, Δlml), args; kwargs...) - check_adjoint_allocations(posterior_and_lml, (Δx, nothing), args; kwargs...) - check_adjoint_allocations(posterior_and_lml, (nothing, nothing), args; kwargs...) - end end end @@ -565,19 +86,6 @@ function test_interface( @test length(y_no_missing) == length(ssm) check_inferred && @inferred rand(rng, ssm) rng = MersenneTwister(123456) - if check_adjoints - # We need the whole scan_emit machinery to test the adjoint of rand - @test_broken 1 == 0 - # It seems test_rrule cannot deal good with `rng` at the moment - # test_zygote_grad(rng, ssm; check_inferred, rtol, atol) do rng, model - # iterable = zip(ε_randn(rng, model), model) - # init = rand(rng, x0(model)) - # return scan_emit(step_rand, iterable, init, eachindex(model)) - # end - end - if check_allocs - check_adjoint_allocations(rand, (rng, ssm); kwargs...) - end end @testset "basics" begin @@ -591,15 +99,6 @@ function test_interface( @test xs isa AbstractVector{<:Gaussian} @test length(xs) == length(ssm) check_inferred && @inferred marginals(ssm) - if check_adjoints - # We need to test the whole scan_emit to avoid throwing a state. - test_zygote_grad(ssm; check_inferred, rtol, atol) do model - scan_emit(step_marginals, model, x0(model), eachindex(model)) - end - end - if check_allocs - check_adjoint_allocations(marginals, (ssm, ); kwargs...) - end end @testset "$(data.name)" for data in [ @@ -614,11 +113,6 @@ function test_interface( @test lml isa Real @test is_of_storage_type(lml, storage_type(ssm)) _check_inferred && @inferred logpdf(ssm, y) - if check_adjoints - test_zygote_grad(ssm, y; check_inferred, rtol, atol) do model, y - scan_emit(step_logpdf, zip(model, y), x0(model), eachindex(model)) - end - end end @testset "_filter" begin xs = _filter(ssm, y) @@ -626,65 +120,13 @@ function test_interface( @test xs isa AbstractVector{<:Gaussian} @test length(xs) == length(ssm) _check_inferred && @inferred _filter(ssm, y) - if check_adjoints - test_zygote_grad(ssm, y; check_inferred, rtol, atol) do model, y - scan_emit(step_filter, zip(model, y), x0(model), eachindex(model)) - end - end end @testset "posterior" begin posterior_ssm = posterior(ssm, y) @test length(posterior_ssm) == length(ssm) @test ordering(posterior_ssm) != ordering(ssm) _check_inferred && @inferred posterior(ssm, y) - if check_adjoints - test_zygote_grad(posterior, ssm, y; check_inferred, rtol, atol) - end - end - - # Hack to only run the AD tests if requested. - @testset "adjoints" for _ in (check_adjoints ? [1] : []) - if check_allocs - check_adjoint_allocations(_filter, (ssm, y); kwargs...) - check_adjoint_allocations(posterior, (ssm, y); kwargs...) - end end end end end - -# This is unfortunately needed to make ChainRulesTestUtils comparison works. -# See https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/271 -Base.zero(::Forward) = Forward() -Base.zero(::Reverse) = Reverse() - -_diag(x) = diag(x) -_diag(x::Real) = x - -function FiniteDifferences.rand_tangent(rng::AbstractRNG, A::StaticArray) - return map(x -> rand_tangent(rng, x), A) -end - -FiniteDifferences.rand_tangent(::AbstractRNG, ::Base.OneTo) = ZeroTangent() - -# Hacks to make rand_tangent play nicely with Zygote. -rand_zygote_tangent(A) = Zygote.wrap_chainrules_output(FiniteDifferences.rand_tangent(A)) - -Zygote.wrap_chainrules_output(x::Array) = map(Zygote.wrap_chainrules_output, x) - -function Zygote.wrap_chainrules_input(x::Array) - return map(Zygote.wrap_chainrules_input, x) -end - -function LinearAlgebra.dot(A::Tangent, B::Tangent) - mutual_names = intersect(propertynames(A), propertynames(B)) - if length(mutual_names) == 0 - return 0 - else - return sum(n -> dot(getproperty(A, n), getproperty(B, n)), mutual_names) - end -end - -function ChainRulesTestUtils.test_approx(actual::Tangent{T}, expected::StructArray, msg=""; kwargs...) where {T<:StructArray} - return test_approx(actual.components, expected; kwargs...) -end \ No newline at end of file diff --git a/test/util/chainrules.jl b/test/util/chainrules.jl deleted file mode 100644 index b68ab30..0000000 --- a/test/util/chainrules.jl +++ /dev/null @@ -1,117 +0,0 @@ -using StaticArrays -using BenchmarkTools -using BlockDiagonals -using ChainRulesCore -using ChainRulesTestUtils -using Test -using TemporalGPs -using TemporalGPs: time_exp, _map, Gaussian -using FillArrays -using StructArrays -using Zygote: ZygoteRuleConfig - -@testset "chainrules" begin - @testset "StaticArrays" begin - @testset "SArray constructor" begin - for (f, x) in ( - (SArray{Tuple{3, 2, 1}}, ntuple(i -> 2.5i, 6)), - (SVector{5}, (ntuple(i -> 2.5i, 5))), - (SVector{2}, (2.0, 1.0)), - (SMatrix{5, 4}, (ntuple(i -> 2.5i, 20))), - (SMatrix{1, 1}, (randn(),)) - ) - test_rrule(ZygoteRuleConfig(), f, x; rrule_f=rrule_via_ad) - end - end - @testset "collect(::SArray)" begin - A = SArray{Tuple{3, 1, 2}}(ntuple(i -> 3.5i, 6)) - test_rrule(collect, A) - end - @testset "vcat(::SVector, ::SVector)" begin - a = SVector{3}(randn(3)) - b = SVector{2}(randn(2)) - test_rrule(vcat, a, b) - end - end - @testset "time_exp" begin - A = randn(3, 3) - test_rrule(time_exp, A ⊢ NoTangent(), 0.1) - end - @testset "Fill" begin - @testset "Fill constructor" begin - for x in ( - randn(), - randn(1, 2), - SMatrix{1, 2}(randn(1, 2)), - ) - test_rrule(Fill, x, 3; check_inferred=false) - test_rrule(Fill, x, (3, 4); check_inferred=false) - end - end - @testset "collect(::Fill)" begin - P = 11 - Q = 3 - @testset "$(typeof(x)) element" for x in [ - randn(), - randn(1, 2), - SMatrix{1, 2}(randn(1, 2)), - ] - test_rrule(collect, Fill(x, P)) - # The test rule does not work due to inconsistencies of FiniteDifferencies for FillArrays - test_rrule(collect, Fill(x, P, Q)) - end - end - end - - # The rrule is not even used... - @testset "getindex(::Fill, ::Int)" begin - X = Fill(randn(5, 3), 10) - test_rrule(getindex, X, 3; check_inferred=false) - end - @testset "BlockDiagonal" begin - X = map(N -> randn(N, N), [3, 4, 1]) - test_rrule(BlockDiagonal, X) - end - @testset "_map(f, x::Fill)" begin - x = Fill(randn(3, 4), 4) - test_rrule(_map, sum, x; check_inferred=false) - test_rrule(_map, x->map(sin, x), x; check_inferred=false) - test_rrule(_map, x -> 2.0 * x, x; check_inferred=false) - test_rrule(ZygoteRuleConfig(), (x,a)-> _map(x -> x * a, x), x, 2.0; check_inferred=false, rrule_f=rrule_via_ad) - end - @testset "_map(f, x::Fill....)" begin - x1 = Fill(randn(3, 4), 3) - x2 = Fill(randn(3, 4), 3) - x3 = Fill(randn(3, 4), 3) - - @test _map(+, x1, x2) == _map(+, collect(x1), collect(x2)) - test_rrule(_map, +, x1, x2; check_inferred=true) - - @test _map(+, x1, x2, x3) == _map(+, collect(x1), collect(x2), collect(x3)) - test_rrule(_map, +, x1, x2, x3; check_inferred=true) - - fsin(x, y) = sin.(x .* y) - test_rrule(_map, fsin, x1, x2; check_inferred=false) - - foo(a, x1, x2) = _map((z1, z2) -> a * sin.(z1 .* z2), x1, x2) - test_rrule(ZygoteRuleConfig(), foo, randn(), x1, x2; check_inferred=false, rrule_f=rrule_via_ad) - end - @testset "StructArray" begin - a = randn(5) - b = rand(5) - # This test is broken due to FiniteDifferences returning the wrong Tangent. - @test_broken 1 == 0 - # test_rrule(StructArray, (a, b); check_inferred=false) - - xs = [Gaussian(randn(1), randn(1, 1)) for _ in 1:2] - ms = getfield.(xs, :m) - Ps = getfield.(xs, :P) - # Same here. - @test_broken 1 == 0 - # test_rrule(StructArray{eltype(xs)}, (ms, Ps)) - xs_sa = StructArray{eltype(xs)}((ms, Ps)) - # And here. - @test_broken 1 == 0 - # test_zygote_grad(getproperty, xs_sa, :m) - end -end diff --git a/test/util/harmonise.jl b/test/util/harmonise.jl deleted file mode 100644 index 50ae773..0000000 --- a/test/util/harmonise.jl +++ /dev/null @@ -1,57 +0,0 @@ -using TemporalGPs: are_harmonised - -function test_harmonise(a, b; recurse=true) - h = harmonise(a, b) - @test h isa Tuple - @test length(h) == 2 - @test are_harmonised(h[1], h[2]) - - recurse && test_harmonise(b, a; recurse=false) - h′ = harmonise(b, a) - @test h isa Tuple - @test length(h) == 2 - @test are_harmonised(h′[1], h′[2]) - @test are_harmonised(h[1], h′[1]) - @test are_harmonised(h[1], h′[2]) -end - -@testset "harmonise" begin - test_harmonise(5.0, 4.0) - - @testset "AbstractZero" begin - test_harmonise(5.0, ZeroTangent()) - test_harmonise(ZeroTangent(), randn(10)) - test_harmonise(ZeroTangent(), ZeroTangent()) - end - - @testset "Array" begin - test_harmonise(randn(5), randn(5)) - test_harmonise( - [(randn(), randn()) for _ in 1:10], - [Tangent{Any}(randn(), rand()) for _ in 1:10], - ) - end - - @testset "Tuple / Tangent{Tuple}" begin - test_harmonise((5, 4), (5, 4)) - test_harmonise(Tangent{Tuple}(5, 4), (5, 4)) - test_harmonise(Tangent{Tuple}(5, 4), Tangent{Tuple}(5, 4)) - - test_harmonise((5, Tangent{Tuple}(randn(5))), (5, (randn(5), ))) - test_harmonise( - Tangent{Any}(Tangent{Any}(randn(5))), - (Tangent{Any}(randn(5)), ), - ) - end - - @testset "NamedTuple / Tangent{NamedTuple}" begin - test_harmonise(Tangent{Any}(; m=4, P=5), Tangent{Gaussian}(; m=5, P=4)) - test_harmonise(Tangent{Any}(; m=4, P=5), Tangent{Any}(; m=4)) - test_harmonise(Tangent{Any}(; m=5), Tangent{Any}(; P=4)) - - test_harmonise(Tangent{Any}(; m=(5, 4)), Tangent{Any}(; P=4)) - - test_harmonise(Tangent{Any}(; m=5, P=4), Gaussian(5, 4)) - test_harmonise(Tangent{Any}(; P=4), Gaussian(4, 5)) - end -end diff --git a/test/util/regular_data.jl b/test/util/regular_data.jl index 9cc2c04..f56bb0e 100644 --- a/test/util/regular_data.jl +++ b/test/util/regular_data.jl @@ -1,13 +1,3 @@ -using FiniteDifferences -using Zygote - -function FiniteDifferences.to_vec(x::RegularSpacing) - function from_vec_RegularSpacing(x_vec) - return RegularSpacing(x_vec[1], x_vec[2], x.N) - end - return [x.t0, x.Δt], from_vec_RegularSpacing -end - @testset "regular_data" begin t0 = randn() Δt = randn() @@ -20,14 +10,4 @@ end @test collect(x) ≈ collect(x_range) @test step(x) == step(x_range) @test length(x) == length(x_range) - - let - x, back = Zygote.pullback(RegularSpacing, t0, Δt, N) - - Δ_t0 = randn() - Δ_Δt = randn() - @test back((t0 = Δ_t0, Δt = Δ_Δt, N=nothing)) == (Δ_t0, Δ_Δt, nothing) - - test_rrule(RegularSpacing, randn(), rand(), 10; output_tangent=Tangent{RegularSpacing}(Δt=0.1, t0=0.2)) - end end diff --git a/test/util/scan.jl b/test/util/scan.jl index 7f1d47b..8021682 100644 --- a/test/util/scan.jl +++ b/test/util/scan.jl @@ -1,16 +1,9 @@ using Test -using Zygote: ZygoteRuleConfig using TemporalGPs: scan_emit using StructArrays -using ChainRulesTestUtils @testset "scan" begin - # Run forwards. x = StructArray([(a=randn(), b=randn()) for _ in 1:10]) stepper = (x_, y_) -> (x_ + y_.a * y_.b * x_, x_ + y_.b) - # test_rrule(scan_emit, stepper, x, 0.0, eachindex(x)) - - # Run in reverse. - # test_rrule(scan_emit, stepper, x, 0.0, reverse(eachindex(x))) end diff --git a/test/util/zygote_friendly_map.jl b/test/util/zygote_friendly_map.jl deleted file mode 100644 index e81c21b..0000000 --- a/test/util/zygote_friendly_map.jl +++ /dev/null @@ -1,18 +0,0 @@ -using FillArrays -using TemporalGPs - -@testset "zygote_friendly_map" begin - @testset "$name" for (name, f, x) in [ - ("Vector{Float64}", x -> sin(x) + cos(x) * exp(x), randn(100)), - ("Fill{Float64}", x -> sin(x) + exp(x) + 5, Fill(randn(), 100)), - ("Vector{Vector{Float64}}", sum, [randn(25) for _ in 1:33]), - ( - "zip(Vector{Float64}, Fill{Float64})", - x -> x[1] * x[2], - zip(randn(5), Fill(1.0, 5)), - ), - ] - @test TemporalGPs.zygote_friendly_map(f, x) ≈ map(f, x) - # adjoint_test(x -> TemporalGPs.zygote_friendly_map(f, x), (x, )) - end -end