Test hyperparameters initialization

In [16]:
import gpflow 
from kernel_discovery.kernel import RBF, Linear, Periodic, White
import numpy as np
from numpy.random import rand, normal
from kernel_discovery.preprocessing import DataShape, get_datashape

Test RBF

In [27]:
def init_rbf(datashape_x: DataShape, datashape_y: DataShape, sd=1.):
    
    # lengthscale
    if rand() < 0.5:
        log_lengthscale = normal(loc=datashape_x.std, scale=sd)
    else:
        log_lengthscale = normal(loc=np.log(2*(datashape_x.max - datashape_x.min)),
                                 scale=sd)

    # variance
    if rand() < 0.5:
        log_variance = normal(loc=datashape_y.std, scale=sd)
    else:
        log_variance = normal(loc=0, scale=sd)

    init_params = RBF(variance=np.exp(log_variance),
                      lengthscales=np.exp(log_lengthscale)).parameters
    return [p.numpy() for p in init_params]

In [28]:

lengthscales = np.linspace(0.05, 3., 10)
variances = np.linspace(0.05, 10, 10)

x = np.linspace(0, 5, 100)[:,None]
data_shape_x = get_datashape(x)

for i in range(10):
    kernel = RBF(variance=variances[i], lengthscales=lengthscales[i])
    k = kernel.K(x)
    k = k.numpy()
    chol = np.linalg.cholesky(k + 1e-6*np.eye(100))
    epsilon  = np.random.randn(100,1)
    y = chol @ epsilon
    data_shape_y = get_datashape(y)
    init_param = init_rbf(data_shape_x, data_shape_y, sd=0.5)
    print("="*20)
    print(f"Real parameter {lengthscales[i]}, {variances[i]}")
    print(f"Init parameter {init_param}")
    
    

Real parameter 0.05, 0.05
Init parameter [array(13.28500232), array(0.96973273)]
Real parameter 0.37777777777777777, 1.1555555555555554
Init parameter [array(2.94981324), array(2.69950964)]
Real parameter 0.7055555555555556, 2.2611111111111106
Init parameter [array(6.97376741), array(3.36331306)]
Real parameter 1.0333333333333334, 3.3666666666666663
Init parameter [array(6.67265334), array(3.73400281)]
Real parameter 1.3611111111111112, 4.472222222222221
Init parameter [array(5.07675323), array(2.13093314)]
Real parameter 1.6888888888888889, 5.577777777777777
Init parameter [array(17.18571305), array(2.65961167)]
Real parameter 2.0166666666666666, 6.683333333333333
Init parameter [array(2.92515545), array(1.36691075)]
Real parameter 2.344444444444444, 7.788888888888888
Init parameter [array(11.0724237), array(1.43218168)]
Real parameter 2.672222222222222, 8.894444444444444
Init parameter [array(33.66580483), array(0.82997625)]
Real parameter 3.0, 10.0
Init parameter [array(3.82655096),

In [40]:
def init_periodic(datashape_x: DataShape, datashape_y: DataShape, sd=1.):
    
    # lengthscales
    log_lengthscale = normal(loc=0, scale=sd)

    # periodicity
    if rand() < 0.5:
        # no mim_period
        log_period = normal(loc=datashape_x.std-2., scale=sd)
        # TODO: min_period
    else:
        log_period = normal(loc=np.log(datashape_x.max - datashape_x.min) - 3.2,
                            scale=sd)
        # TODO: min_period

    # variance
    if rand() < 0.5:
        log_variance = normal(loc=datashape_y.std, scale=sd)
    else:
        log_variance = normal(loc=0., scale=sd)

    init_params = Periodic(variance=np.exp(log_variance),
                           lengthscales=np.exp(log_lengthscale),
                           period=np.exp(log_period)).parameters
    return [p.numpy() for p in init_params]

In [41]:
periods = np.linspace(0.05, 5, 10)

for i in range(10):
    kernel = Periodic(period=periods[i])
    k = kernel.K(x)
    k = k.numpy()
    chol = np.linalg.cholesky(k + 1e-6*np.eye(100))
    epsilon  = np.random.randn(100,1)
    y = chol @ epsilon
    data_shape_y = get_datashape(y)
    init_param = init_periodic(data_shape_x, data_shape_y, sd=1.)
    
    print(f"Real param {periods[i]}")
    print(f"Init param {init_param[0]}")
    print("="*20)



Real param 0.05
Init param 0.8524235248820111
Real param 0.6000000000000001
Init param 0.10598383559855826
Real param 1.1500000000000001
Init param 0.22533413201795982
Real param 1.7000000000000002
Init param 1.1850680964377898
Real param 2.25
Init param 0.9904751900889998
Real param 2.8
Init param 0.6668184361792288
Real param 3.35
Init param 0.3554378967440945
Real param 3.9000000000000004
Init param 0.06963470211142325
Real param 4.45
Init param 0.2032007852521709
Real param 5.0
Init param 0.7744524768106249


In [33]:
def init_linear(datashape_x: DataShape, datashape_y: DataShape, sd=1.):
    
    r = rand()
    if r < 1. / 3.:
        log_variance = normal(loc=datashape_y.std - datashape_x.std, scale=sd)
    elif r < 2. / 3:
        dist_y = datashape_y.max - datashape_y.min
        dist_x = datashape_x.max - datashape_x.min
        loc = np.log(np.abs(dist_y / dist_x))
        log_variance = normal(loc=loc, scale=sd)
    else:
        log_variance = normal(loc=0., scale=sd)
        
    location = np.random.uniform(low=2 * datashape_x.min - datashape_x.max,
                                 high=2 * datashape_x.max - datashape_x.min)

    init_params = Linear(variance=np.exp(log_variance), location=location).parameters
    return [p.numpy() for p in init_params]

In [42]:
for i in range(10):
    k = Linear(variance=variances[i], location=2.5)
    k = kernel.K(x)
    k = k.numpy()
    chol = np.linalg.cholesky(k + 1e-6*np.eye(100))
    epsilon  = np.random.randn(100,1)
    y = chol @ epsilon
    data_shape_y = get_datashape(y)
    init_param = init_periodic(data_shape_x, data_shape_y, sd=1.)
    
    print("="*20)
    print(f"Real param {variances[i]}")
    print(f"Init param {init_param[1]}")

Real param 0.05
Init param 1.8456277730837776
Real param 1.1555555555555554
Init param 5.747273288174298
Real param 2.2611111111111106
Init param 0.4388289725062545
Real param 3.3666666666666663
Init param 2.007641702981547
Real param 4.472222222222221
Init param 1.7554673383485406
Real param 5.577777777777777
Init param 0.8912238795273719
Real param 6.683333333333333
Init param 1.3942659132010515
Real param 7.788888888888888
Init param 2.8977881345363663
Real param 8.894444444444444
Init param 1.2382246286994185
Real param 10.0
Init param 0.853753117787976
