Skip to content

Commit

Permalink
Merge f6c5d28 into 58cb069
Browse files Browse the repository at this point in the history
  • Loading branch information
st-- committed Jul 9, 2021
2 parents 58cb069 + f6c5d28 commit 59269ed
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/KernelFunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ export Transform,
IdentityTransform,
FunctionTransform,
PeriodicTransform
export with_lengthscale

export NystromFact, nystrom

Expand Down Expand Up @@ -75,6 +76,7 @@ include(joinpath("transform", "selecttransform.jl"))
include(joinpath("transform", "chaintransform.jl"))
include(joinpath("transform", "periodic_transform.jl"))
include(joinpath("kernels", "transformedkernel.jl"))
include(joinpath("transform", "with_lengthscale.jl"))

include(joinpath("basekernels", "constant.jl"))
include(joinpath("basekernels", "cosine.jl"))
Expand Down
4 changes: 4 additions & 0 deletions src/kernels/transformedkernel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,10 @@ Base.:∘(k::TransformedKernel, t::Transform) = TransformedKernel(k.kernel, k.tr
Base.:(k::Kernel, ::IdentityTransform) = k
Base.:(k::TransformedKernel, ::IdentityTransform) = k

function Base.isequal(k::TransformedKernel, k2::TransformedKernel)
return isequal(k.kernel, k2.kernel) && isequal(k.transform, k2.transform)
end

Base.show(io::IO, κ::TransformedKernel) = printshifted(io, κ, 0)

function printshifted(io::IO, κ::TransformedKernel, shift::Int)
Expand Down
2 changes: 2 additions & 0 deletions src/transform/chaintransform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ end
set!(t::ChainTransform, θ) = set!.(t.transforms, θ)
duplicate(t::ChainTransform, θ) = ChainTransform(duplicate.(t.transforms, θ))

Base.isequal(t::ChainTransform, t2::ChainTransform) = isequal(t.transforms, t2.transforms)

Base.show(io::IO, t::ChainTransform) = printshifted(io, t, 0)

function printshifted(io::IO, t::ChainTransform, shift::Int)
Expand Down
32 changes: 32 additions & 0 deletions src/transform/with_lengthscale.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
"""
with_lengthscale(kernel::Kernel, lengthscale::Real)
with_lengthscale(kernel::Kernel, lengthscales::AbstractVector{<:Real})
Construct a transformed kernel with `lengthscale`.
If a vector `lengthscales` is passed instead, construct an "ARD" kernel with different lengthscales for each dimension.
The following two ways of constructing a squared-exponential kernel with
a given lengthscale are equivalent:
```jldoctest
julia> ℓ = 2.5;
julia> isequal(SqExponentialKernel() ∘ ScaleTransform(inv(ℓ)), with_lengthscale(SqExponentialKernel(), ℓ))
true
```
and for the ARD case:
```jldoctest
julia> ℓ = [0.5, 2.5];
julia> isequal(SqExponentialKernel() ∘ ARDTransform(inv.(ℓ)), with_lengthscale(SqExponentialKernel(), ℓ))
true
```
"""
function with_lengthscale(kernel::Kernel, lengthscale::Real)
return compose(kernel, ScaleTransform(inv(lengthscale)))
end
function with_lengthscale(kernel::Kernel, lengthscales::AbstractVector{<:Real})
return compose(kernel, ARDTransform(map(inv, lengthscales)))
end

0 comments on commit 59269ed

Please sign in to comment.