Skip to content

Commit

Permalink
Make forwards-pass type stable for Float32 (#83)
Browse files Browse the repository at this point in the history
* Make forwards-pass type stable for Float32

* Remove new space
  • Loading branch information
willtebbutt committed Feb 25, 2020
1 parent 757b938 commit 7b6159f
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 9 deletions.
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ BlockArrays = "0.9, 0.10"
Distances = ">= 0.8"
Distributions = ">= 0.19"
FillArrays = "0.8"
FiniteDifferences = ">= 0.8"
MacroTools = ">= 0.4"
RecipesBase = "0.7.0"
Zygote = ">= 0.4.1"
Expand Down
16 changes: 7 additions & 9 deletions src/abstract_gp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ true
"""
function rand(rng::AbstractRNG, f::FiniteGP, N::Int)
μ, C = mean(f), cholesky(Symmetric(cov(f)))
return μ .+ C.U' * randn(rng, length(μ), N)
return μ .+ C.U' * randn(rng, promote_type(eltype(μ), eltype(C)), length(μ), N)
end
rand(f::FiniteGP, N::Int) = rand(Random.GLOBAL_RNG, f, N)
rand(rng::AbstractRNG, f::FiniteGP) = vec(rand(rng, f, 1))
Expand Down Expand Up @@ -192,14 +192,12 @@ julia> logpdf(f(x), Y) isa AbstractVector{<:Real}
true
```
"""
function logpdf(f::FiniteGP, y::AbstractVector{<:Real})
μ, C = mean(f), cholesky(Symmetric(cov(f)))
return -(length(y) * log(2π) + logdet(C) + Xt_invA_X(C, y - μ)) / 2
end
logpdf(f::FiniteGP, y::AbstractVector{<:Real}) = first(logpdf(f, reshape(y, :, 1)))

function logpdf(f::FiniteGP, Y::AbstractMatrix{<:Real})
μ, C = mean(f), cholesky(Symmetric(cov(f)))
return -((size(Y, 1) * log(2π) + logdet(C)) .+ diag_Xt_invA_X(C, Y .- μ)) ./ 2
T = promote_type(eltype(μ), eltype(C), eltype(Y))
return -((size(Y, 1) * T(log(2π)) + logdet(C)) .+ diag_Xt_invA_X(C, Y .- μ)) ./ 2
end

"""
Expand Down Expand Up @@ -229,9 +227,9 @@ function elbo(f::FiniteGP, y::AV{<:Real}, u::FiniteGP)
Λ_ε = cholesky(Symmetric(A * A' + I))
δ = chol_Σy.U' \ (y - mean(f))

return -(length(y) * log(2π) + logdet(chol_Σy) + logdet(Λ_ε) +
sum(abs2, δ) - sum(abs2, Λ_ε.U' \ (A * δ)) +
tr_Cf_invΣy(f, f.Σy, chol_Σy) - sum(abs2, A)) / 2
tmp = logdet(chol_Σy) + logdet(Λ_ε) + sum(abs2, δ) - sum(abs2, Λ_ε.U' \ (A * δ)) +
tr_Cf_invΣy(f, f.Σy, chol_Σy) - sum(abs2, A)
return -(length(y) * typeof(tmp)(log(2π)) + tmp) / 2
end

function consistency_check(f, y, u)
Expand Down
14 changes: 14 additions & 0 deletions test/abstract_gp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,20 @@ end
atol=1e-6, rtol=1e-6,
)
end
@testset "Type Stability - $T" for T in [Float64, Float32]
rng = MersenneTwister(123456)
x = randn(rng, T, 123)
z = randn(rng, T, 13)
f = GP(T(0), EQ(), GPC())

fx = f(x, T(0.1))
u = f(z, T(1e-4))

y = rand(rng, fx)
@test y isa Vector{T}
@test logpdf(fx, y) isa T
@test elbo(fx, y, u) isa T
end
end

# """
Expand Down

0 comments on commit 7b6159f

Please sign in to comment.