Skip to content

Commit

Permalink
Merge pull request #5 from StevenWhitaker/A_mul_A_forwarddiff
Browse files Browse the repository at this point in the history
Add ForwardDiff.jl and use A_mul_A′ for faster AD.
  • Loading branch information
StevenWhitaker committed Jan 11, 2022
2 parents b10f089 + 5a85f14 commit fcd0ed9
Show file tree
Hide file tree
Showing 7 changed files with 262 additions and 44 deletions.
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
name = "PERK"
uuid = "85f1910e-5d62-11e9-1bce-2b30a05932d4"
authors = ["Steven Whitaker <stwhit@umich.edu>"]
version = "0.3.1"
version = "0.3.2"

[deps]
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[compat]
ForwardDiff = "0.10"
julia = "1.2"
1 change: 1 addition & 0 deletions src/PERK.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ Module implementing parameter estimation via regression with kernels (PERK).
"""
module PERK

import ForwardDiff # ForwardDiff.Dual, ForwardDiff.partials, ForwardDiff.value
using LinearAlgebra: I, Diagonal, norm
using Statistics: mean

Expand Down
19 changes: 17 additions & 2 deletions src/krr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ function _krr_train(
# Calculate sample covariances
xtrain = xtrain .- xm # [T]
z = z .- zm # [H,T]
Czz = div0.(z * z', T) # [H,H]
Czz = div0.(A_mul_A′(z), T) # [H,H]
Cxz = div0.(z * xtrain, T) # [H]

# Calculate the (regularized) inverse of Czz and multiply by Cxz
Expand Down Expand Up @@ -159,7 +159,7 @@ function _krr_train(
# Calculate sample covariances
xtrain = xtrain .- xm # [L,T]
z = z .- zm # [H,T]
Czz = div0.(z * z', T) # [H,H]
Czz = div0.(A_mul_A′(z), T) # [H,H]
Cxz = div0.(xtrain * z', T) # [L,H]

# Calculate the (regularized) inverse of Czz and multiply by Cxz
Expand All @@ -169,6 +169,21 @@ function _krr_train(

end

A_mul_A′(A) = A * A'

# See https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf
function A_mul_A′(D::AbstractMatrix{<:ForwardDiff.Dual{T}}) where T

A = ForwardDiff.value.(D)
AA′ = A_mul_A′(A)
Δ = ForwardDiff.partials.(D)
# dAA′ = Δ * A' + A * Δ'
ΔA′ = Δ * A'
dAA′ = [ΔA′[i,j] + conj(ΔA′[j,i]) for i in axes(ΔA′, 1), j in axes(ΔA′, 2)]
return ForwardDiff.Dual{T}.(AA′, dAA′)

end

"""
krr(ytest, trainData, kernel)
Expand Down
Loading

0 comments on commit fcd0ed9

Please sign in to comment.