Skip to content

Commit

Permalink
Merge f37b566 into cd26d09
Browse files Browse the repository at this point in the history
  • Loading branch information
theogf committed Oct 19, 2021
2 parents cd26d09 + f37b566 commit 2154dc7
Show file tree
Hide file tree
Showing 6 changed files with 148 additions and 75 deletions.
2 changes: 1 addition & 1 deletion docs/src/kernels.md
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ GammaRationalKernel
### Spectral Mixture Kernels

```@docs
spectral_mixture_kernel
SpectralMixtureKernel
spectral_mixture_product_kernel
```

Expand Down
5 changes: 4 additions & 1 deletion src/KernelFunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ export LinearKernel, PolynomialKernel
export RationalKernel, RationalQuadraticKernel, GammaRationalKernel
export PiecewisePolynomialKernel
export PeriodicKernel, NeuralNetworkKernel
export SpectralMixtureKernel
export KernelSum, KernelProduct, KernelTensorProduct
export TransformedKernel, ScaledKernel, NormalizedKernel
export GibbsKernel
Expand All @@ -33,7 +34,7 @@ export with_lengthscale
export NystromFact, nystrom

export gaborkernel
export spectral_mixture_kernel, spectral_mixture_product_kernel
export spectral_mixture_product_kernel

export ColVecs, RowVecs

Expand Down Expand Up @@ -120,6 +121,8 @@ include("mokernels/lmm.jl")
include("chainrules.jl")
include("zygoterules.jl")

include("deprecated.jl")

include("test_utils.jl")

function __init__()
Expand Down
175 changes: 121 additions & 54 deletions src/basekernels/sm.jl
Original file line number Diff line number Diff line change
@@ -1,24 +1,39 @@
"""
spectral_mixture_kernel(
@doc raw"""
SpectralMixtureKernel(
h::Kernel=SqExponentialKernel(),
αs::AbstractVector{<:Real},
γs::AbstractMatrix{<:Real},
ωs::AbstractMatrix{<:Real},
α::AbstractVector{<:Real},
γ::AbstractMatrix{<:Real},
ω::AbstractMatrix{<:Real},
)
SpectralMixtureKernel(
h::Kernel=SqExponentialKernel(),
α::AbstractVector{<:Real},
γ::AbstractVector{<:AbstractVecOrMat{<:Real}},
ω::AbstractVector{<:AbstractVecOrMat{<:Real}},
)
where αs are the weights of dimension (A, ), γs is the covariance matrix of
dimension (D, A) and ωs are the mean vectors and is of dimension (D, A).
Here, D is input dimension and A is the number of spectral components.
`h` is the kernel, which defaults to [`SqExponentialKernel`](@ref) if not specified.
Generalised Spectral Mixture kernel function as described in [1] in equation 6.
This family of functions is dense in the family of stationary real-valued kernels with respect to the pointwise convergence.[1]
Generalised Spectral Mixture kernel function. This family of functions is dense
in the family of stationary real-valued kernels with respect to the pointwise convergence.[1]
## Definition
For inputs ``x, x′ \in \mathbb{R}^D``, the spectral mixture kernel ``\tilde{k}`` with ``K`` mixture components, mixture weights ``\alpha \in \mathbb{R}^K``, linear transformations ``\gamma_1, \ldots, \gamma_K \in \mathbb{R}^D``, and frequencies ``\omega_1, \ldots, \omega_K \in \mathbb{R}^D`` derived from a translation-invariant kernel ``k`` is defined as
```math
κ(x, y) = αs' (h(-(γs' * t)^2) .* cos(π * ωs' * t), t = x - y
\tilde{k}(x, x'; \alpha, \gamma_1, \ldots, \gamma_K, \omega_1, \ldots, \omega_K, k) = \sum_{i=1}^K \alpha_i k(\gamma_i \odot x, \gamma_i \odot y) \cos(2\pi \omega_i^\top (x-y)).
```
## Arguments
- `h`: Stationary kernel (translation invariant), [`SqExponentialKernel`](@ref) by default
- `α`: Weight vector of each mixture component (should be positive)
- `γ`: Linear transformation of the input for `h`.
- `ω`: Frequencies for the cosine function. (should be positive)
`γ` and `ω` can be an
- `AbstractMatrix` of dimension `D x K` where `D` is the dimension of the inputs
and `K` is the number of components
- `AbstractVector` of `K` `D`-dimensional `AbstractVector`
# References:
[1] Generalized Spectral Kernels, by Yves-Laurent Kom Samo and Stephen J. Roberts
[2] SM: Gaussian Process Kernels for Pattern Discovery and Extrapolation,
Expand All @@ -29,77 +44,129 @@ in the family of stationary real-valued kernels with respect to the pointwise co
[4] http://www.cs.cmu.edu/~andrewgw/pattern/.
"""
function spectral_mixture_kernel(
struct SpectralMixtureKernel{
K<:Kernel,Tα<:AbstractVector,Tγ<:AbstractVector,Tω<:AbstractVector
} <: Kernel
kernel::K
α::Tα
γ::Tγ
ω::Tω
function SpectralMixtureKernel(
h::Kernel,
α::AbstractVector{<:Real},
γ::AbstractVector{<:AbstractVector},
ω::AbstractVector{<:AbstractVector},
)
(length(α) == length(γ) == length(ω)) ||
throw(DimensionMismatch("The dimensions of α, γ, ans ω do not match"))
any(<(0), α) && throw(ArgumentError("At least one element of α is negative"))
any(any.(<(0), ω)) && throw(ArgumentError("At least one element of ω is negative"))
return new{typeof(h),typeof(α),typeof(γ),typeof(ω)}(h, α, γ, ω)
end
end

@functor SpectralMixtureKernel

function SpectralMixtureKernel(
h::Kernel,
αs::AbstractVector{<:Real},
γs::AbstractMatrix{<:Real},
ωs::AbstractMatrix{<:Real},
α::AbstractVector{<:Real},
γ::AbstractMatrix{<:Real},
ω::AbstractMatrix{<:Real},
)
if !(size(αs, 1) == size(γs, 2) == size(ωs, 2))
throw(DimensionMismatch("The dimensions of αs, γs, ans ωs do not match"))
end
if size(γs) != size(ωs)
throw(DimensionMismatch("The dimensions of γs ans ωs do not match"))
end
size(γ) == size(ω) || throw(DimensionMismatch("γ and ω have different dimensions"))
return SpectralMixtureKernel(h, α, ColVecs(γ), ColVecs(ω))
end

return sum(zip(αs, eachcol(γs), eachcol(ωs))) do (α, γ, ω)
a = TransformedKernel(h, LinearTransform'))
b = TransformedKernel(CosineKernel(), LinearTransform'))
return α * a * b
function SpectralMixtureKernel(
αs::AbstractVector{<:Real}, γs::AbstractVecOrMat, ωs::AbstractVecOrMat
)
return SpectralMixtureKernel(SqExponentialKernel(), αs, γs, ωs)
end

function::SpectralMixtureKernel)(x, y)
xy = x - y
# use pairwise summation (https://github.com/JuliaLang/julia/pull/31020)
broadcasted = Broadcast.broadcasted.α, κ.γ, κ.ω) do α, γ, ω
k = TransformedKernel.kernel, ARDTransform(γ))
return α * k(x, y) * cospi(2 * dot(ω, xy))
end
return sum(Broadcast.instantiate(broadcasted))
end

function spectral_mixture_kernel(
αs::AbstractVector{<:Real}, γs::AbstractMatrix{<:Real}, ωs::AbstractMatrix{<:Real}
)
return spectral_mixture_kernel(SqExponentialKernel(), αs, γs, ωs)
function Base.show(io::IO, κ::SpectralMixtureKernel)
return print(
io,
"SpectralMixtureKernel Kernel (kernel = ",
κ.kernel,
", # components = ",
length.α),
")",
)
end

"""
@doc raw"""
spectral_mixture_product_kernel(
h::Kernel=SqExponentialKernel(),
αs::AbstractMatrix{<:Real},
γs::AbstractMatrix{<:Real},
ωs::AbstractMatrix{<:Real},
α::AbstractMatrix{<:Real},
γ::AbstractMatrix{<:Real},
ω::AbstractMatrix{<:Real},
)
where αs are the weights of dimension (D, A), γs is the covariance matrix of
dimension (D, A) and ωs are the mean vectors and is of dimension (D, A).
Here, D is input dimension and A is the number of spectral components.
Spectral Mixture Product Kernel. With enough components A, the SMP kernel
The spectral mixture product is tensor product of spectral mixture kernel applied
on each dimension as described in [1] in equations 13 and 14.
With enough components, the SMP kernel
can model any product kernel to arbitrary precision, and is flexible even
with a small number of components [1]
with a small number of components
## Definition
`h` is the kernel, which defaults to [`SqExponentialKernel`](@ref) if not specified.
For inputs ``x, x′ \in \mathbb{R}^D``, the spectral mixture product kernel ``\tilde{k}`` with ``K`` mixture components, mixture weights ``\alpha_1, \alpha_2, \ldots, \alpha_K \in \mathbb{R}^D``, linear transformations ``\gamma_1, \ldots, \gamma_K \in \mathbb{R}^D``, and frequencies ``\omega_1, \ldots, \omega_K \in \mathbb{R}^D`` derived from a translation-invariant kernel ``k`` is defined as
```math
κ(x, y) = Πᵢ₌₁ᴷ Σ(αsᵢᵀ .* (h(-(γsᵢᵀ * tᵢ)²) .* cos(ωsᵢᵀ * tᵢ))), tᵢ = xᵢ - yᵢ
\tilde{k}(x, x'; \alpha_1, \ldots, \alpha_k, \gamma_1, \ldots, \gamma_K, \omega_1, \ldots, \omega_K, k) = \prod_{i=1}^D \sum_{k=1}^K \alpha_{ik} \cdot h(\gamma_{ik} \cdot x_i, \gamma_{ik} \cdot y_i)) \cdot \cos(2\pi \cdot \omega_{ik} \cdot (x_i - y_i))))
```
## Arguments
- `h`: Stationary kernel (translation invariant), [`SqExponentialKernel`](@ref) by default
- `α`: Weight of each mixture component for each dimension
- `γ`: Linear transformation of the input for `h`.
- `ω`: Frequencies for the cosine function.
`α`, `γ` and `ω` can be an
- `AbstractMatrix` of dimension `D x K` where `D` is the dimension of the inputs
and `K` is the number of components
- `AbstractVector` of `D` `K`-dimensional `AbstractVector`
# References:
[1] GPatt: Fast Multidimensional Pattern Extrapolation with GPs,
arXiv 1310.5288, 2013, by Andrew Gordon Wilson, Elad Gilboa,
Arye Nehorai and John P. Cunningham
"""
function spectral_mixture_product_kernel(
h::Kernel,
αs::AbstractMatrix{<:Real},
γs::AbstractMatrix{<:Real},
ωs::AbstractMatrix{<:Real},
α::AbstractMatrix{<:Real},
γ::AbstractMatrix{<:Real},
ω::AbstractMatrix{<:Real},
)
if !(size(αs) == size(γs) == size(ωs))
throw(DimensionMismatch("The dimensions of αs, γs, ans ωs do not match"))
(size(α) == size(γ) == size(ω)) ||
throw(DimensionMismatch("α, γ and ω have different dimensions"))
return spectral_mixture_product_kernel(h, RowVecs(α), RowVecs(γ), RowVecs(ω))
end

function spectral_mixture_product_kernel(
h::Kernel,
α::AbstractVector{<:AbstractVector{<:Real}},
γ::AbstractVector{<:AbstractVector{<:Real}},
ω::AbstractVector{<:AbstractVector{<:Real}},
)
return mapreduce(, α, γ, ω) do αᵢ, γᵢ, ωᵢ
return SpectralMixtureKernel(h, αᵢ, permutedims(γᵢ), permutedims(ωᵢ))
end
return KernelTensorProduct(
spectral_mixture_kernel(h, α, reshape(γ, 1, :), reshape(ω, 1, :)) for
(α, γ, ω) in zip(eachrow(αs), eachrow(γs), eachrow(ωs))
)
end

function spectral_mixture_product_kernel(
αs::AbstractMatrix{<:Real}, γs::AbstractMatrix{<:Real}, ωs::AbstractMatrix{<:Real}
α::AbstractVecOrMat, γ::AbstractVecOrMat, ω::AbstractVecOrMat
)
return spectral_mixture_product_kernel(SqExponentialKernel(), αs, γs, ωs)
return spectral_mixture_product_kernel(SqExponentialKernel(), α, γ, ω)
end
1 change: 1 addition & 0 deletions src/deprecated.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
@deprecate spectral_mixture_kernel SpectralMixtureKernel
1 change: 1 addition & 0 deletions src/transform/lineartransform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
LinearTransform(A::AbstractMatrix)
Linear transformation of the input realised by the matrix `A`.
If `a` is an `AbstractVector`, `Diagonal(a)` is passed
The second dimension of `A` must match the number of features of the target.
Expand Down
39 changes: 20 additions & 19 deletions test/basekernels/sm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,32 +3,33 @@
v1 = rand(D_in)
v2 = rand(D_in)

αs₁ = rand(3)
αs₂ = rand(D_in, 3)
γs = rand(D_in, 3)
ωs = rand(D_in, 3)
K = 3
αs₁ = rand(K)
αs₂ = rand(D_in, K)
γs = rand(D_in, K)
ωs = rand(D_in, K)

k1 = spectral_mixture_kernel(αs₁, γs, ωs)
k1 = SpectralMixtureKernel(αs₁, γs, ωs)
k2 = spectral_mixture_product_kernel(αs₂, γs, ωs)

t = v1 - v2

@test k1(v1, v2) sum(αs₁ .* exp.(-(t' * γs)' .^ 2 ./ 2) .* cospi.((t' * ωs)')) atol =
1e-5
@test k1(v1, v2) sum(
αs₁[k] * exp(-norm(t .* γs[:, k])^2 / 2) * cospi(2 * dot(ωs[:, k], t)) for k in 1:K
)

@test isapprox(
k2(v1, v2),
prod([
prod(
sum(
αs₂[i, :]' .* exp.(-(γs[i, :]' * t[i]) .^ 2 ./ 2) .*
cospi.(ωs[i, :]' * t[i]),
) for i in 1:length(t)
],);
atol=1e-5,
αs₂[i, k] * exp(-(γs[i, k] * t[i])^2 / 2) * cospi(2 * ωs[i, k] * t[i]) for
k in 1:K
) for i in 1:D_in
),
)

@test_throws DimensionMismatch spectral_mixture_kernel(rand(5), rand(4, 3), rand(4, 3))
@test_throws DimensionMismatch spectral_mixture_kernel(rand(3), rand(4, 3), rand(5, 3))
@test_throws DimensionMismatch SpectralMixtureKernel(rand(5), rand(4, 3), rand(4, 3))
@test_throws DimensionMismatch SpectralMixtureKernel(rand(3), rand(4, 3), rand(5, 3))
@test_throws DimensionMismatch spectral_mixture_product_kernel(
rand(5, 3), rand(4, 3), rand(5, 3)
)
Expand All @@ -38,15 +39,15 @@
x0 = ColVecs(randn(D_in, 3))
x1 = ColVecs(randn(D_in, 3))
x2 = ColVecs(randn(D_in, 2))
TestUtils.test_interface(k1, x0, x1, x2)
TestUtils.test_interface(k2, x0, x1, x2)
test_interface(k1, x0, x1, x2)
test_interface(k2, x0, x1, x2)
end
@testset "RowVecs" begin
x0 = RowVecs(randn(3, D_in))
x1 = RowVecs(randn(3, D_in))
x2 = RowVecs(randn(2, D_in))
TestUtils.test_interface(k1, x0, x1, x2)
TestUtils.test_interface(k2, x0, x1, x2)
test_interface(k1, x0, x1, x2)
test_interface(k2, x0, x1, x2)
end
# test_ADs(x->spectral_mixture_kernel(exp.(x[1:3]), reshape(x[4:18], 5, 3), reshape(x[19:end], 5, 3)), vcat(log.(αs₁), γs[:], ωs[:]), dims = [5,5])
@test_broken "No tests passing (BaseKernel)"
Expand Down

0 comments on commit 2154dc7

Please sign in to comment.