Skip to content

Commit

Permalink
Vector Inputs (#29)
Browse files Browse the repository at this point in the history
* Test ColVecs

* Test RowVecs

* Bump patch

* Update docs

* Remove redundant code

* Apply formatting suggestions

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Simplify RowVecs implementation

* Update src/bayesian_linear_regression.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Add error test

* Apply suggestions from code review

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
  • Loading branch information
willtebbutt and github-actions[bot] committed Jan 10, 2022
1 parent 0dd9735 commit eabf965
Show file tree
Hide file tree
Showing 4 changed files with 241 additions and 185 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "BayesianLinearRegressors"
uuid = "f579363c-4606-5e5c-a623-c4549f609c4b"
authors = ["Will Tebbutt <wt0881@my.bristol.ac.uk>"]
version = "0.3.4"
version = "0.3.5"

[deps]
AbstractGPs = "99985d1d-32ba-4be9-9821-2ec096f28918"
Expand Down
16 changes: 11 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,15 @@ The interface sits at roughly the same level as that of [Distributions.jl](https

## Conventions

A `BayesianLinearRegressor` in `D` dimensions works with data where:
- inputs `X` should be a `D x N` matrix of `Real`s where each column is from one data point.
- outputs `y` should be an `N`-vector of `Real`s, where each element is from one data point.
`BayesianLinearRegressors` is consistent with `AbstractGPs`.
Consequently, a `BayesianLinearRegressor` in `D` dimensions can work with the following input types:
1. `ColVecs` -- a wrapper around an `D x N` matrix of `Real`s saying that each column should be interpreted as an input.
2. `RowVecs`s -- a wrapper around an `N x D` matrix of `Real`s, saying that each row should be interpreted as an input.
3. `Matrix{<:Real}` -- must be `D x N`. Prefer using `ColVecs` or `RowVecs` for the sake of being explicit.

Consult the `Design` section of the [KernelFunctions.jl](https://juliagaussianprocesses.github.io/KernelFunctions.jl/dev/design/) docs for more info on these conventions.

Outputs for a BayesianLinearRegressor should be an `AbstractVector{<:Real}` of length `N`.

## Example Usage

Expand All @@ -38,7 +44,7 @@ f = BayesianLinearRegressor(mw, Λw)

# Index into the regressor and assume heterscedastic observation noise `Σ_noise`.
N = 10
X = collect(hcat(collect(range(-5.0, 5.0, length=N)), ones(N))')
X = ColVecs(collect(hcat(collect(range(-5.0, 5.0, length=N)), ones(N))'))
Σ_noise = Diagonal(exp.(randn(N)))
fX = f(X, Σ_noise)

Expand Down Expand Up @@ -70,7 +76,7 @@ logpdf(f′(X, Σ_noise), y)

# Sample from the posterior predictive distribution.
N_plt = 1000
X_plt = hcat(collect(range(-6.0, 6.0, length=N_plt)), ones(N_plt))'
X_plt = ColVecs(hcat(collect(range(-6.0, 6.0, length=N_plt)), ones(N_plt))')
f′X_plt = rand(rng, f′(X_plt, eps()), 100) # Samples with machine-epsilon noise for stability

# Compute some posterior marginal statisics.
Expand Down
28 changes: 21 additions & 7 deletions src/bayesian_linear_regression.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,28 @@ const FiniteBLR = FiniteGP{<:BayesianLinearRegressor}

# All code below implements the primary + secondary AbstractGPs.jl APIs.

AbstractGPs.mean(fx::FiniteBLR) = fx.x.X' * fx.f.mw
x_as_colvecs(fx::FiniteBLR) = x_as_colvecs(fx.x)

x_as_colvecs(x::ColVecs) = x

x_as_colvecs(x::RowVecs) = ColVecs(x.X')

function x_as_colvecs(x::T) where {T<:AbstractVector}
return error(
"$T is not a subtype of AbstractVector that is known. Please provide either a",
"ColVecs or RowVecs.",
)
end

AbstractGPs.mean(fx::FiniteBLR) = x_as_colvecs(fx).X' * fx.f.mw

function AbstractGPs.cov(fx::FiniteBLR)
α = _cholesky(fx.f.Λw).U' \ fx.x.X
α = _cholesky(fx.f.Λw).U' \ x_as_colvecs(fx).X
return Symmetric' * α + fx.Σy)
end

function AbstractGPs.var(fx::FiniteBLR)
α = _cholesky(fx.f.Λw).U' \ fx.x.X
α = _cholesky(fx.f.Λw).U' \ x_as_colvecs(fx).X
return vec(sum(abs2, α; dims=1)) .+ diag(fx.Σy)
end

Expand All @@ -34,8 +47,9 @@ AbstractGPs.mean_and_cov(fx::FiniteBLR) = (mean(fx), cov(fx))
AbstractGPs.mean_and_var(fx::FiniteBLR) = (mean(fx), var(fx))

function AbstractGPs.rand(rng::AbstractRNG, fx::FiniteBLR, samples::Int)
w = fx.f.mw .+ _cholesky(fx.f.Λw).U \ randn(rng, size(fx.x.X, 1), samples)
return fx.x.X' * w .+ _cholesky(fx.Σy).U' * randn(rng, size(fx.x.X, 2), samples)
X = x_as_colvecs(fx).X
w = fx.f.mw .+ _cholesky(fx.f.Λw).U \ randn(rng, size(X, 1), samples)
return X' * w .+ _cholesky(fx.Σy).U' * randn(rng, size(X, 2), samples)
end

function AbstractGPs.logpdf(fx::FiniteBLR, y::AbstractVector{<:Real})
Expand All @@ -56,9 +70,9 @@ end

# Computation utilised in both `logpdf` and `posterior`.
function __compute_inference_quantities(fx::FiniteBLR, y::AbstractVector{<:Real})
length(y) == size(fx.x.X, 2) || throw(error("length(y) != size(fx.x.X, 2)"))
X = x_as_colvecs(fx).X
length(y) == size(X, 2) || throw(error("length(y) != size(fx.x.X, 2)"))
blr = fx.f
X = fx.x.X
N = length(y)

Uw = _cholesky(blr.Λw).U
Expand Down
Loading

2 comments on commit eabf965

@willtebbutt
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register()

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/52012

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.3.5 -m "<description of version>" eabf9657aac4a594e0b4b5eaf83f82078606b0a6
git push origin v0.3.5

Please sign in to comment.