
Making Kriging differentiable and better initial hyperparams #368

@archermarx

Description

In order to make hyperparameter optimization of Gaussian process models simpler (cf. #328), it would be nice to be able to use AD on Kriging surrogates. This turns out not to be too difficult: making only a single change to the computation of the Kriging coefficients makes these surrogates differentiable with respect to their hyperparameters.

Setup:

using Surrogates
using LinearAlgebra 

function branin(x)
    x1 = x[1]
    x2 = x[2]
    b = 5.1 / (4*pi^2);
    c = 5/pi;
    r = 6;
    a = 1;
    s = 10;
    t = 1 / (8*pi);
    term1 = a * (x2 - b*x1^2 + c*x1 - r)^2;
    term2 = s*(1-t)*cos(x1);
    y = term1 + term2 + s;
end

n_samples = 21
lower_bound = [-5.0, 0.0]
upper_bound = [10.0, 15.0]

xys = sample(n_samples, lower_bound, upper_bound, SobolSample())
zs = branin.(xys);

# good hyperparams for the Branin function
p = [2.0, 2.0]
theta = [0.03, 0.003]

kriging_surrogate = Kriging(xys, zs, lower_bound, upper_bound; p, theta)

# Log likelihood function of a Gaussian process model, AKA what we want to optimize in order to tune our hyperparameters
function kriging_logpdf(params, x, y, lb, ub)
    d = length(params) ÷ 2
    theta = params[1:d]
    p = params[d+1:end]
    surr = Kriging(x, y, lb, ub; p, theta)

    n = length(y)
    y_minus_1μ = y - ones(length(y), 1) * surr.mu
    Rinv = surr.inverse_of_R

    term1 = only(-(y_minus_1μ' * Rinv * y_minus_1μ) / 2 / surr.sigma)
    term2 = -log((2π * surr.sigma)^(n/2) / sqrt(det(Rinv)))

    logpdf = term1 + term2
    return logpdf
end
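
For reference, the quantity computed here is the Gaussian process log likelihood. Reading surr.sigma as the process variance $\sigma^2$ (it is estimated as $(y - \mathbf{1}\mu)^\top R^{-1} (y - \mathbf{1}\mu)/n$ in _calc_kriging_coeffs below), the function evaluates

$$\log L = -\frac{(y - \mathbf{1}\mu)^\top R^{-1} (y - \mathbf{1}\mu)}{2\sigma^2} - \frac{n}{2}\log\left(2\pi\sigma^2\right) - \frac{1}{2}\log\det R,$$

where the last two terms are what term2 folds into a single log, using $\det(R^{-1}) = 1/\det R$.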

loss_func = params -> -kriging_logpdf(params, xys, zs, lower_bound, upper_bound)

Just to check that this works, here's how the model performs

[plot: Kriging surrogate prediction vs. the true Branin function over the sample domain]
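
(As a quick numerical sanity check in place of the plot, one could compare the surrogate against the true function on a denser Sobol sample. This is just a sketch, relying on the surrogate being callable as Surrogates.jl surrogates are:)

test_xys = sample(100, lower_bound, upper_bound, SobolSample())
test_zs = branin.(test_xys)
pred_zs = kriging_surrogate.(test_xys)

# rough relative L2 error of the surrogate over the denser sample
println(norm(pred_zs .- test_zs) / norm(test_zs))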

Now let's try taking a gradient with respect to our hyperparameters $p, \theta$

using Zygote

Zygote.gradient(loss_func, [theta; p])

We get an error (ERROR: Mutating arrays is not supported -- called setindex!(Matrix{Float64}, ...)), which comes from the _calc_kriging_coeffs function when we build the covariance matrix R. We can replace the mutating part of this function with an array comprehension as follows:

old:

function _calc_kriging_coeffs(x, y, p, theta)
    n = length(x)
    d = length(x[1])
    R = zeros(float(eltype(x[1])), n, n)
    @inbounds for i in 1:n
        for j in 1:n
            sum = zero(eltype(x[1]))
            for l in 1:d
                sum = sum + theta[l] * norm(x[i][l] - x[j][l])^p[l]
            end
            R[i, j] = exp(-sum)
        end
    end

    one = ones(n, 1)
    one_t = one'
    inverse_of_R = inv(R)

    mu = (one_t * inverse_of_R * y) / (one_t * inverse_of_R * one)

    y_minus_1μ = y - one * mu

    b = inverse_of_R * y_minus_1μ

    sigma = (y_minus_1μ' * inverse_of_R * y_minus_1μ) / n

    mu[1], b, sigma[1], inverse_of_R
end

new:

function _calc_kriging_coeffs(x, y, p, theta)
    n = length(x)
    d = length(x[1])
    R = [
        let
            sum = zero(eltype(x[1]))
            for l in 1:d
                sum = sum + theta[l] * norm(x[i][l] - x[j][l])^p[l]
            end
            exp(-sum)
        end

        for j in 1:n, i in 1:n
    ]

    one = ones(n, 1)
    one_t = one'
    inverse_of_R = inv(R)

    mu = (one_t * inverse_of_R * y) / (one_t * inverse_of_R * one)

    y_minus_1μ = y - one * mu

    b = inverse_of_R * y_minus_1μ

    sigma = (y_minus_1μ' * inverse_of_R * y_minus_1μ) / n

    mu[1], b, sigma[1], inverse_of_R
end

With that, we can compute a gradient!

julia> Zygote.gradient(loss_func, [0.03, 0.0034, 2.0, 2.0])
([-654.8265400112359, 1784.234854500508, 5532.336307947245, 332.2643207182241],)
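
With a working gradient, the next step toward the hyperparameter optimization mentioned at the top would be to hand the loss and its gradient to an optimizer. Here is a minimal sketch using Optim.jl; the names x0, grad!, theta_opt, and p_opt are hypothetical and not part of this change:

using Optim

# starting point in the [theta; p] layout that kriging_logpdf expects
x0 = [0.03, 0.003, 2.0, 2.0]

# in-place gradient for Optim, backed by Zygote (hypothetical helper)
grad!(G, params) = copyto!(G, Zygote.gradient(loss_func, params)[1])

res = Optim.optimize(loss_func, grad!, x0, LBFGS())
theta_opt = Optim.minimizer(res)[1:2]
p_opt = Optim.minimizer(res)[3:4]

In practice one would probably want to keep theta positive and p in (0, 2] (e.g. via Fminbox or by optimizing log-parameters), but the unconstrained call is enough to show the gradient being put to use.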
