Skip to content

Commit

Permalink
Fix input types, improve readability (#369)
Browse files Browse the repository at this point in the history
* Fix input types, improve readability

* Add missing bit

* Add doc string

* Fix mistake

* Add docstring to docs

* Reformulate

* Bump version

* Update src/matrix/kernelkroneckermat.jl

Co-authored-by: Théo Galy-Fajou <theo.galyfajou@gmail.com>

* Bump version further, rename api section

* Apply format suggestions from code review

Co-authored-by: willtebbutt <wt0881@my.bristol.ac.uk>

* Improve error handling.

Co-authored-by: st-- <st--@users.noreply.github.com>

* Formatter

Co-authored-by: Théo Galy-Fajou <theo.galyfajou@gmail.com>
Co-authored-by: willtebbutt <wt0881@my.bristol.ac.uk>
Co-authored-by: st-- <st--@users.noreply.github.com>
  • Loading branch information
4 people committed Sep 29, 2021
1 parent 3356fa6 commit 6e7ca17
Show file tree
Hide file tree
Showing 6 changed files with 61 additions and 38 deletions.
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,6 +1,6 @@
name = "KernelFunctions"
uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
version = "0.10.19"
version = "0.10.20"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
9 changes: 9 additions & 0 deletions docs/src/api.md
Expand Up @@ -88,3 +88,12 @@ kernelpdmat
nystrom
NystromFact
```

## Conditional Utilities
To keep the dependencies of KernelFunctions lean, some functionality is only available if specific other packages are explicitly loaded (`using`).

### Kronecker.jl
[*https://github.com/MichielStock/Kronecker.jl*](https://github.com/MichielStock/Kronecker.jl)
```@docs
kronecker_kernelmatrix
```
34 changes: 23 additions & 11 deletions src/matrix/kernelkroneckermat.jl
Expand Up @@ -27,31 +27,43 @@ end
"""
@inline iskroncompatible::Kernel) = false # Default return for kernels

function _kernelmatrix_kroneckerjl_helper(::MOInputIsotopicByFeatures, Kfeatures, Koutputs)
function _kernelmatrix_kroneckerjl_helper(
::Type{<:MOInputIsotopicByFeatures}, Kfeatures, Koutputs
)
return Kronecker.kronecker(Kfeatures, Koutputs)
end

function _kernelmatrix_kroneckerjl_helper(::MOInputIsotopicByOutputs, Kfeatures, Koutputs)
function _kernelmatrix_kroneckerjl_helper(
::Type{<:MOInputIsotopicByOutputs}, Kfeatures, Koutputs
)
return Kronecker.kronecker(Koutputs, Kfeatures)
end

"""
kronecker_kernelmatrix(
k::Union{IndependentMOKernel,IntrinsicCoregionMOKernel}, x::MOI, y::MOI
) where {MOI<:IsotopicMOInputsUnion}
Requires Kronecker.jl: Computes the `kernelmatrix` for the `IndependentMOKernel` and the
`IntrinsicCoregionMOKernel`, but returns a lazy kronecker product. This object can be very
efficiently inverted or decomposed. See also [`kernelmatrix`](@ref).
"""
function kronecker_kernelmatrix(
k::Union{IndependentMOKernel,IntrinsicCoregionMOKernel},
x::IsotopicMOInputsUnion,
y::IsotopicMOInputsUnion,
)
@assert x.out_dim == y.out_dim
k::Union{IndependentMOKernel,IntrinsicCoregionMOKernel}, x::MOI, y::MOI
) where {MOI<:IsotopicMOInputsUnion}
x.out_dim == y.out_dim ||
throw(DimensionMismatch("`x` and `y` must have the same `out_dim`"))
Kfeatures = kernelmatrix(k.kernel, x.x, y.x)
Koutputs = _mo_output_covariance(k, x.out_dim)
return _kernelmatrix_kroneckerjl_helper(x, Kfeatures, Koutputs)
return _kernelmatrix_kroneckerjl_helper(MOI, Kfeatures, Koutputs)
end

function kronecker_kernelmatrix(
k::Union{IndependentMOKernel,IntrinsicCoregionMOKernel}, x::IsotopicMOInputsUnion
)
k::Union{IndependentMOKernel,IntrinsicCoregionMOKernel}, x::MOI
) where {MOI<:IsotopicMOInputsUnion}
Kfeatures = kernelmatrix(k.kernel, x.x)
Koutputs = _mo_output_covariance(k, x.out_dim)
return _kernelmatrix_kroneckerjl_helper(x, Kfeatures, Koutputs)
return _kernelmatrix_kroneckerjl_helper(MOI, Kfeatures, Koutputs)
end

function kronecker_kernelmatrix(
Expand Down
21 changes: 10 additions & 11 deletions src/mokernels/independent.jl
Expand Up @@ -30,25 +30,24 @@ end
_mo_output_covariance(k::IndependentMOKernel, out_dim) = Eye{Bool}(out_dim)

function kernelmatrix(
k::IndependentMOKernel, x::IsotopicMOInputsUnion, y::IsotopicMOInputsUnion
)
@assert x.out_dim == y.out_dim
k::IndependentMOKernel, x::MOI, y::MOI
) where {MOI<:IsotopicMOInputsUnion}
x.out_dim == y.out_dim ||
throw(DimensionMismatch("`x` and `y` must have the same `out_dim`"))
Kfeatures = kernelmatrix(k.kernel, x.x, y.x)
Koutputs = _mo_output_covariance(k, x.out_dim)
return _kernelmatrix_kron_helper(x, Kfeatures, Koutputs)
return _kernelmatrix_kron_helper(MOI, Kfeatures, Koutputs)
end

if VERSION >= v"1.6"
function kernelmatrix!(
K::AbstractMatrix,
k::IndependentMOKernel,
x::IsotopicMOInputsUnion,
y::IsotopicMOInputsUnion,
)
@assert x.out_dim == y.out_dim
K::AbstractMatrix, k::IndependentMOKernel, x::MOI, y::MOI
) where {MOI<:IsotopicMOInputsUnion}
x.out_dim == y.out_dim ||
throw(DimensionMismatch("`x` and `y` must have the same `out_dim`"))
Kfeatures = kernelmatrix(k.kernel, x.x, y.x)
Koutputs = _mo_output_covariance(k, x.out_dim)
return _kernelmatrix_kron_helper!(K, x, Kfeatures, Koutputs)
return _kernelmatrix_kron_helper!(K, MOI, Kfeatures, Koutputs)
end
end

Expand Down
21 changes: 10 additions & 11 deletions src/mokernels/intrinsiccoregion.jl
Expand Up @@ -48,25 +48,24 @@ function _mo_output_covariance(k::IntrinsicCoregionMOKernel, out_dim)
end

function kernelmatrix(
k::IntrinsicCoregionMOKernel, x::IsotopicMOInputsUnion, y::IsotopicMOInputsUnion
)
@assert x.out_dim == y.out_dim
k::IntrinsicCoregionMOKernel, x::MOI, y::MOI
) where {MOI<:IsotopicMOInputsUnion}
x.out_dim == y.out_dim ||
throw(DimensionMismatch("`x` and `y` must have the same `out_dim`"))
Kfeatures = kernelmatrix(k.kernel, x.x, y.x)
Koutputs = _mo_output_covariance(k, x.out_dim)
return _kernelmatrix_kron_helper(x, Kfeatures, Koutputs)
return _kernelmatrix_kron_helper(MOI, Kfeatures, Koutputs)
end

if VERSION >= v"1.6"
function kernelmatrix!(
K::AbstractMatrix,
k::IntrinsicCoregionMOKernel,
x::IsotopicMOInputsUnion,
y::IsotopicMOInputsUnion,
)
@assert x.out_dim == y.out_dim
K::AbstractMatrix, k::IntrinsicCoregionMOKernel, x::MOI, y::MOI
) where {MOI<:IsotopicMOInputsUnion}
x.out_dim == y.out_dim ||
throw(DimensionMismatch("`x` and `y` must have the same `out_dim`"))
Kfeatures = kernelmatrix(k.kernel, x.x, y.x)
Koutputs = _mo_output_covariance(k, x.out_dim)
return _kernelmatrix_kron_helper!(K, x, Kfeatures, Koutputs)
return _kernelmatrix_kron_helper!(K, MOI, Kfeatures, Koutputs)
end
end

Expand Down
12 changes: 8 additions & 4 deletions src/mokernels/mokernel.jl
Expand Up @@ -5,20 +5,24 @@ Abstract type for kernels with multiple outpus.
"""
abstract type MOKernel <: Kernel end

function _kernelmatrix_kron_helper(::MOInputIsotopicByFeatures, Kfeatures, Koutputs)
function _kernelmatrix_kron_helper(::Type{<:MOInputIsotopicByFeatures}, Kfeatures, Koutputs)
return kron(Kfeatures, Koutputs)
end

function _kernelmatrix_kron_helper(::MOInputIsotopicByOutputs, Kfeatures, Koutputs)
function _kernelmatrix_kron_helper(::Type{<:MOInputIsotopicByOutputs}, Kfeatures, Koutputs)
return kron(Koutputs, Kfeatures)
end

if VERSION >= v"1.6"
function _kernelmatrix_kron_helper!(K, ::MOInputIsotopicByFeatures, Kfeatures, Koutputs)
function _kernelmatrix_kron_helper!(
K, ::Type{<:MOInputIsotopicByFeatures}, Kfeatures, Koutputs
)
return kron!(K, Kfeatures, Koutputs)
end

function _kernelmatrix_kron_helper!(K, ::MOInputIsotopicByOutputs, Kfeatures, Koutputs)
function _kernelmatrix_kron_helper!(
K, ::Type{<:MOInputIsotopicByOutputs}, Kfeatures, Koutputs
)
return kron!(K, Koutputs, Kfeatures)
end
end

2 comments on commit 6e7ca17

@Crown421
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register()

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/45780

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.10.20 -m "<description of version>" 6e7ca17987d3e0a8d7f1724b8639c438befcb2d0
git push origin v0.10.20

Please sign in to comment.