diff --git a/Project.toml b/Project.toml index 14bd054..3b37d7c 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "LinearMixingModels" uuid = "b8ce4b42-e81b-4a39-a84a-67f74a9a16dd" authors = ["Invenia Technical Computing Corporation"] -version = "0.1.4" +version = "0.1.5" [deps] AbstractGPs = "99985d1d-32ba-4be9-9821-2ec096f28918" diff --git a/src/ilmm.jl b/src/ilmm.jl index 2bb10d6..6e3c309 100644 --- a/src/ilmm.jl +++ b/src/ilmm.jl @@ -105,8 +105,7 @@ function Distributions._rand!( end end -# See AbstractGPs.jl API docs. -function AbstractGPs.mean_and_var(fx::FiniteGP{<:ILMM}) +function _intermediate_mean_and_var_and_cov_quantities(fx::FiniteGP{<:ILMM}) f, H, σ², x = unpack(fx) p, m = size(H) n = length(x) @@ -116,6 +115,12 @@ function AbstractGPs.mean_and_var(fx::FiniteGP{<:ILMM}) latent_mean, latent_cov = mean_and_cov(f(x_mo_input)) H_full = kron(H, Matrix(I, n, n)) + return H_full, latent_mean, latent_cov, σ² +end + +# See AbstractGPs.jl API docs. +function AbstractGPs.mean_and_var(fx::FiniteGP{<:ILMM}) + H_full, latent_mean, latent_cov, σ² = _intermediate_mean_and_var_and_cov_quantities(fx) M = H_full * latent_mean V = AbstractGPs.diag_Xt_A_X(cholesky(latent_cov), H_full') .+ σ² @@ -123,12 +128,24 @@ function AbstractGPs.mean_and_var(fx::FiniteGP{<:ILMM}) return collect(vec(M)), V end +# See AbstractGPs.jl API docs. +function AbstractGPs.mean_and_cov(fx::FiniteGP{<:ILMM}) + H_full, latent_mean, latent_cov, σ² = _intermediate_mean_and_var_and_cov_quantities(fx) + + M = H_full * latent_mean + C = AbstractGPs.Xt_A_X(cholesky(latent_cov), H_full') + σ² * I + + return collect(vec(M)), C +end + # See AbstractGPs.jl API docs. AbstractGPs.mean(fx::FiniteGP{<:ILMM}) = mean_and_var(fx)[1] # See AbstractGPs.jl API docs. AbstractGPs.var(fx::FiniteGP{<:ILMM}) = mean_and_var(fx)[2] +AbstractGPs.cov(fx::FiniteGP{<:ILMM}) = mean_and_cov(fx)[2] + # See AbstractGPs.jl API docs. function AbstractGPs.logpdf(fx::FiniteGP{<:ILMM}, y::AbstractVector{<:Real}) f, H, σ², x = unpack(fx) diff --git a/test/ilmm.jl b/test/ilmm.jl index 9e9520c..4ee5cc0 100644 --- a/test/ilmm.jl +++ b/test/ilmm.jl @@ -9,6 +9,7 @@ function test_ilmm(rng, kernels, H, x_train, x_test, y_train, y_test) @test isapprox(mean(ilmmx), mean(n_ilmmx)) @test isapprox(var(ilmmx), var(n_ilmmx)) + @test isapprox(cov(ilmmx), cov(n_ilmmx)) @test isapprox(logpdf(ilmmx, y_train), logpdf(n_ilmmx, y_train)) @test _is_approx(marginals(ilmmx), marginals(n_ilmmx)) @test length(rand(rng, ilmmx)) == size(H, 1) * length(x_train.x) @@ -31,8 +32,8 @@ function test_ilmm(rng, kernels, H, x_train, x_test, y_train, y_test) @test gradient(logpdf, pi, y_test) isa Tuple @testset "primary_public_interface" begin - test_finitegp_primary_public_interface(rng, ilmmx) - test_finitegp_primary_public_interface(rng, pi) + test_finitegp_primary_and_secondary_public_interface(rng, ilmmx) + test_finitegp_primary_and_secondary_public_interface(rng, pi) end end diff --git a/test/oilmm.jl b/test/oilmm.jl index 5ea03d1..7ee2572 100644 --- a/test/oilmm.jl +++ b/test/oilmm.jl @@ -9,6 +9,7 @@ function test_oilmm(rng, kernels, H::Orthogonal, x_train, x_test, y_train, y_tes @test isapprox(mean(ilmmx), mean(oilmmx)) @test isapprox(var(ilmmx), var(oilmmx)) + @test isapprox(cov(ilmmx), cov(oilmmx)) @test isapprox(logpdf(ilmmx, y_train), logpdf(oilmmx, y_train)) @test _is_approx(marginals(ilmmx), marginals(oilmmx)) @test length(rand(rng, oilmmx)) == size(H, 1) * length(x_train.x) @@ -31,8 +32,8 @@ function test_oilmm(rng, kernels, H::Orthogonal, x_train, x_test, y_train, y_tes @test gradient(logpdf, po, y_test) isa Tuple @testset "primary_public_interface" begin - test_finitegp_primary_public_interface(rng, oilmmx) - test_finitegp_primary_public_interface(rng, po) + test_finitegp_primary_and_secondary_public_interface(rng, oilmmx) + test_finitegp_primary_and_secondary_public_interface(rng, po) end end