Skip to content

Commit

Permalink
Merge branch 'master' into tgf/docstring_kronmat
Browse files Browse the repository at this point in the history
  • Loading branch information
theogf committed Oct 8, 2021
2 parents e010524 + f78f028 commit c807b23
Show file tree
Hide file tree
Showing 7 changed files with 99 additions and 27 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "KernelFunctions"
uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
version = "0.10.23"
version = "0.10.25"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
18 changes: 13 additions & 5 deletions src/distances/delta.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,18 @@
# Delta is not following the PreMetric rules since d(x, x) == 1
struct Delta <: Distances.UnionPreMetric end

@inline Distances.eval_op(::Delta, a::Real, b::Real) = a == b
@inline Distances.eval_reduce(::Delta, a, b) = a && b
@inline Distances.eval_start(::Delta, a, b) = true
@inline (dist::Delta)(a::AbstractArray, b::AbstractArray) = Distances._evaluate(dist, a, b)
@inline (dist::Delta)(a::Number, b::Number) = a == b
(dist::Delta)(a::Number, b::Number) = a == b
Base.@propagate_inbounds function (dist::Delta)(
a::AbstractArray{<:Number}, b::AbstractArray{<:Number}
)
@boundscheck if length(a) != length(b)
throw(
DimensionMismatch(
"first array has length $(length(a)) which does not match the length of the second, $(length(b)).",
),
)
end
return a == b
end

Distances.result_type(::Delta, Ta::Type, Tb::Type) = Bool
2 changes: 1 addition & 1 deletion src/distances/pairwise.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ end
pairwise(d::PreMetric, X::AbstractVector) = pairwise(d, X, X)

function pairwise!(out::AbstractMatrix, d::PreMetric, X::AbstractVector, Y::AbstractVector)
return broadcast!(d, out, X, Y')
return broadcast!(d, out, X, permutedims(Y))
end

pairwise!(out::AbstractMatrix, d::PreMetric, X::AbstractVector) = pairwise!(out, d, X, X)
Expand Down
69 changes: 57 additions & 12 deletions src/kernels/kerneltensorproduct.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,17 @@ function kernelmatrix!(K::AbstractMatrix, k::KernelTensorProduct, x::AbstractVec
validate_inplace_dims(K, x)
validate_domain(k, x)

kernels_and_inputs = zip(k.kernels, slices(x))
kernelmatrix!(K, first(kernels_and_inputs)...)
for (k, xi) in Iterators.drop(kernels_and_inputs, 1)
K .*= kernelmatrix(k, xi)
first_kernels, tail_kernels = Iterators.peel(k.kernels)
first_x, tail_x = Iterators.peel(slices(x))

# handle first kernel and input
kernelmatrix!(K, first_kernels, first_x)

# handle remaining kernels and inputs
Ktmp = similar(K)
for (ki, xi) in zip(tail_kernels, tail_x)
kernelmatrix!(Ktmp, ki, xi)
hadamard!(K, K, Ktmp)
end

return K
Expand All @@ -86,10 +93,18 @@ function kernelmatrix!(
validate_inplace_dims(K, x, y)
validate_domain(k, x)

kernels_and_inputs = zip(k.kernels, slices(x), slices(y))
kernelmatrix!(K, first(kernels_and_inputs)...)
for (k, xi, yi) in Iterators.drop(kernels_and_inputs, 1)
K .*= kernelmatrix(k, xi, yi)
first_kernels, tail_kernels = Iterators.peel(k.kernels)
first_x, tail_x = Iterators.peel(slices(x))
first_y, tail_y = Iterators.peel(slices(y))

# handle first kernel and inputs
kernelmatrix!(K, first_kernels, first_x, first_y)

# handle remaining kernels and inputs
Ktmp = similar(K)
for (ki, xi, yi) in zip(tail_kernels, tail_x, tail_y)
kernelmatrix!(Ktmp, ki, xi, yi)
hadamard!(K, K, Ktmp)
end

return K
Expand All @@ -99,10 +114,40 @@ function kernelmatrix_diag!(K::AbstractVector, k::KernelTensorProduct, x::Abstra
validate_inplace_dims(K, x)
validate_domain(k, x)

kernels_and_inputs = zip(k.kernels, slices(x))
kernelmatrix_diag!(K, first(kernels_and_inputs)...)
for (k, xi) in Iterators.drop(kernels_and_inputs, 1)
K .*= kernelmatrix_diag(k, xi)
first_kernels, tail_kernels = Iterators.peel(k.kernels)
first_x, tail_x = Iterators.peel(slices(x))

# handle first kernel and input
kernelmatrix_diag!(K, first_kernels, first_x)

# handle remaining kernels and inputs
Ktmp = similar(K)
for (ki, xi) in zip(tail_kernels, tail_x)
kernelmatrix_diag!(Ktmp, ki, xi)
hadamard!(K, K, Ktmp)
end

return K
end

function kernelmatrix_diag!(
K::AbstractVector, k::KernelTensorProduct, x::AbstractVector, y::AbstractVector
)
validate_inplace_dims(K, x, y)
validate_domain(k, x)

first_kernels, tail_kernels = Iterators.peel(k.kernels)
first_x, tail_x = Iterators.peel(slices(x))
first_y, tail_y = Iterators.peel(slices(y))

# handle first kernel and inputs
kernelmatrix_diag!(K, first_kernels, first_x, first_y)

# handle remaining kernels and inputs
Ktmp = similar(K)
for (ki, xi, yi) in zip(tail_kernels, tail_x, tail_y)
kernelmatrix_diag!(Ktmp, ki, xi, yi)
hadamard!(K, K, Ktmp)
end

return K
Expand Down
8 changes: 8 additions & 0 deletions src/kernels/scaledkernel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,14 @@ function kernelmatrix_diag!(K::AbstractVector, κ::ScaledKernel, x::AbstractVect
return K
end

function kernelmatrix_diag!(
K::AbstractVector, κ::ScaledKernel, x::AbstractVector, y::AbstractVector
)
kernelmatrix_diag!(K, κ.kernel, x, y)
K .*= κ.σ²
return K
end

Base.:*(w::Real, k::Kernel) = ScaledKernel(k, w)

Base.show(io::IO, κ::ScaledKernel) = printshifted(io, κ, 0)
Expand Down
1 change: 1 addition & 0 deletions src/test_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ function test_interface(

tmp_diag = Vector{Float64}(undef, length(x0))
@test kernelmatrix_diag!(tmp_diag, k, x0) kernelmatrix_diag(k, x0)
@test kernelmatrix_diag!(tmp_diag, k, x0, x1) kernelmatrix_diag(k, x0, x1)
end

function test_interface(
Expand Down
26 changes: 18 additions & 8 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -197,17 +197,27 @@ function validate_inplace_dims(K::AbstractMatrix, x::AbstractVector, y::Abstract
end
end

function validate_inplace_dims(K::AbstractMatrix, x::AbstractVector)
return validate_inplace_dims(K, x, x)
end

function validate_inplace_dims(K::AbstractVector, x::AbstractVector)
if length(K) != length(x)
function validate_inplace_dims(K::AbstractVector, x::AbstractVector, y::AbstractVector)
validate_inputs(x, y)
n = length(x)
if length(y) != n
throw(
DimensionMismatch(
"Length of input x ($n) not consistent with length of input y " *
"($(length(y))",
),
)
end
if length(K) != n
throw(
DimensionMismatch(
"Length of target vector K ($(length(K))) not consistent with length of input" *
"vector x ($(length(x))",
"Length of target vector K ($(length(K))) not consistent with length of " *
"inputs ($n)",
),
)
end
end

function validate_inplace_dims(K::AbstractVecOrMat, x::AbstractVector)
return validate_inplace_dims(K, x, x)
end

0 comments on commit c807b23

Please sign in to comment.