Skip to content

Commit

Permalink
simplify predictMVN
Browse files Browse the repository at this point in the history
  • Loading branch information
maximerischard committed Sep 18, 2019
1 parent 5d62227 commit 35d1acb
Show file tree
Hide file tree
Showing 7 changed files with 10 additions and 26 deletions.
4 changes: 2 additions & 2 deletions src/GP.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,15 @@ end
""" Compute predictions using the standard multivariate normal
conditional distribution formulae.
"""
function predictMVN(gp::GPBase,xpred::AbstractMatrix, xtrain::AbstractMatrix, ytrain::AbstractVector,
function predictMVN(xpred::AbstractMatrix, xtrain::AbstractMatrix, ytrain::AbstractVector,
kernel::Kernel, meanf::Mean, alpha::AbstractVector,
covstrat::CovarianceStrategy, Ktrain::AbstractPDMat)
crossdata = KernelData(kernel, xtrain, xpred)
priordata = KernelData(kernel, xpred, xpred)
Kcross = cov(kernel, xtrain, xpred, crossdata)
Kpred = cov(kernel, xpred, xpred, priordata)
mx = mean(meanf, xpred)
mu, Sigma_raw = predictMVN!(gp,Kpred, Ktrain, Kcross, mx, alpha)
mu, Sigma_raw = predictMVN!(Kpred, Ktrain, Kcross, mx, alpha)
return mu, Sigma_raw
end

Expand Down
9 changes: 1 addition & 8 deletions src/GPE.jl
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,7 @@ end
# Predict observations #
#——————————————————————#

predict_full(gp::GPE, xpred::AbstractMatrix) = predictMVN(gp,xpred, gp.x, gp.y, gp.kernel, gp.mean, gp.alpha, gp.covstrat, gp.cK)
predict_full(gp::GPE, xpred::AbstractMatrix) = predictMVN(xpred, gp.x, gp.y, gp.kernel, gp.mean, gp.alpha, gp.covstrat, gp.cK)
"""
predict_full(gp::GPE, x::Union{Vector{Float64},Matrix{Float64}}[; full_cov::Bool=false])
Expand All @@ -383,13 +383,6 @@ are given as columns of matrix `x`. If `full_cov` is `true`, the full covariance
returned instead of only variances.
"""

function predictMVN!(gp::GPE,Kxx, Kff, Kfx, mx, αf)
mu = mx + Kfx' * αf
Lck = whiten!(Kff, Kfx)
subtract_Lck!(Kxx, Lck)
return mu, Kxx
end

function predict_y(gp::GPE, x::AbstractMatrix; full_cov::Bool=false)
μ, σ2 = predict_f(gp, x; full_cov=full_cov)
if full_cov
Expand Down
11 changes: 1 addition & 10 deletions src/GPMC.jl
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ function update_target_and_dtarget!(gp::GPMC; kwargs...)
update_target_and_dtarget!(gp, precomp; kwargs...)
end

predict_full(gp::GPMC, xpred::AbstractMatrix) = predictMVN(gp,xpred, gp.x, gp.y, gp.kernel, gp.mean, gp.v, gp.covstrat, gp.cK)
predict_full(gp::GPMC, xpred::AbstractMatrix) = predictMVN(xpred, gp.x, gp.y, gp.kernel, gp.mean, whiten(gp.cK,gp.v), gp.covstrat, gp.cK)

This comment has been minimized.

Copy link
@maximerischard

maximerischard Sep 18, 2019

Author Contributor

@chris-nemeth does this linear algebra make sense to you? It worries me that it was wrong (before you fixed it) but the tests were passing. Can you think of a test that would have caught my earlier mistake?

This comment has been minimized.

Copy link
@chris-nemeth

chris-nemeth Sep 19, 2019

Member

I'll take a closer look later (I'm at a workshop today) and check that the linear algebra makes sense. My initial thought is that this proposal should work (and it's tidy) as the only reason why we needed separate predict functions was because gp.v is multiplied either by the kernel matrix or its square root depending on whether you use GPE or GPMC. Something that we'll need to keep mind are the VI additions that @thomaspinder is working on. The plan here was to change GPMC to GPA (A for approximate).

This comment has been minimized.

Copy link
@chris-nemeth

chris-nemeth Sep 20, 2019

Member

I've checked that this fix works for GPMC (plus some other package tests) and I think we can merge this branch with the master. @thomaspinder will have to check how this function works with his VI implementation.

"""
predict_y(gp::GPMC, x::Union{Vector{Float64},Matrix{Float64}}[; full_cov::Bool=false])
Expand All @@ -312,15 +312,6 @@ are given as columns of matrix `x`. If `full_cov` is `true`, the full covariance
returned instead of only variances.
"""


function predictMVN!(gp::GPMC,Kxx, Kff, Kfx, mx, αf)
Lck = whiten!(Kff, Kfx)
mu = mx + Lck' * αf
subtract_Lck!(Kxx, Lck)
return mu, Kxx
end


function predict_y(gp::GPMC, x::AbstractMatrix; full_cov::Bool=false)
μ, σ2 = predict_f(gp, x; full_cov=full_cov)
return predict_obs(gp.lik, μ, σ2)
Expand Down
4 changes: 2 additions & 2 deletions src/sparse/determ_train_conditional.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,13 @@ end
where μ_DTC and Σ_DTC are the predictive mean and covariance
functions for the Subset of Regressors approximation.
"""
function predictMVN(gp::GPE, xpred::AbstractMatrix,
function predictMVN(xpred::AbstractMatrix,
xtrain::AbstractMatrix, ytrain::AbstractVector,
kernel::Kernel, meanf::Mean,
alpha::AbstractVector,
covstrat::DeterminTrainCondStrat, Ktrain::AbstractPDMat)
SoR = SubsetOfRegsStrategy(covstrat)
μ_SoR, Σ_SoR = predictMVN(gp, xpred, xtrain, ytrain, kernel, meanf, alpha, SoR, Ktrain)
μ_SoR, Σ_SoR = predictMVN(xpred, xtrain, ytrain, kernel, meanf, alpha, SoR, Ktrain)
inducing = covstrat.inducing
Kuu = Ktrain.Kuu

Expand Down
2 changes: 1 addition & 1 deletion src/sparse/full_scale_approximation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,7 @@ predictMVN(xpred::AbstractMatrix, xtrain::AbstractMatrix, ytrain::AbstractVector
= Σxx - Kxu Kuu⁻¹ Kux + Kxu ΣQR⁻¹ Kux # expanding
= Σxx - Qxx + Kxu ΣQR⁻¹ Kux # definition of Qxx
"""
function predictMVN(gp::GPE, xpred::AbstractMatrix, blockindpred::BlockIndices,
function predictMVN(xpred::AbstractMatrix, blockindpred::BlockIndices,
xtrain::AbstractMatrix, ytrain::AbstractVector,
kernel::Kernel, meanf::Mean,
alpha::AbstractVector,
Expand Down
4 changes: 2 additions & 2 deletions src/sparse/fully_indep_train_conditional.jl
Original file line number Diff line number Diff line change
Expand Up @@ -318,13 +318,13 @@ predictMVN(xpred::AbstractMatrix, xtrain::AbstractMatrix, ytrain::AbstractVector
= Σxx - Kxu Kuu⁻¹ Kux + Kxu ΣQR⁻¹ Kux # expanding
= Σxx - Qxx + Kxu ΣQR⁻¹ Kux # definition of Qxx
"""
function predictMVN(gp::GPE, xpred::AbstractMatrix,
function predictMVN(xpred::AbstractMatrix,
xtrain::AbstractMatrix, ytrain::AbstractVector,
kernel::Kernel, meanf::Mean,
alpha::AbstractVector,
covstrat::FullyIndepStrat, Ktrain::FullyIndepPDMat)
DTC = DeterminTrainCondStrat(covstrat)
μ_DTC, Σ_DTC = predictMVN(gp, xpred, xtrain, ytrain, kernel, meanf, alpha, DTC, Ktrain)
μ_DTC, Σ_DTC = predictMVN(xpred, xtrain, ytrain, kernel, meanf, alpha, DTC, Ktrain)
return μ_DTC, Σ_DTC
end

Expand Down
2 changes: 1 addition & 1 deletion src/sparse/subsetofregressors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ end
= Qxx - Qxx + Kxu ΣQR⁻¹ Kux # definition of Qxx
= Kxu ΣQR⁻¹ Kux # simplifying
"""
function predictMVN(gp::GPE, xpred::AbstractMatrix,
function predictMVN(xpred::AbstractMatrix,
xtrain::AbstractMatrix, ytrain::AbstractVector,
kernel::Kernel, meanf::Mean,
alpha::AbstractVector,
Expand Down

0 comments on commit 35d1acb

Please sign in to comment.