In [None]:
from utils import KernelFunction, evaluate_likelihood, generate_gp_data
from automated_statistician import greedy_search
import numpy as np

In [18]:
true_kernel = KernelFunction().periodic(period=3.0, lengthscale=0.01)
X, Y, true_kernel_str = generate_gp_data(true_kernel, input_dim=1, n_points=50, noise_var=1e-3)
print("True Kernel:", true_kernel_str)

True Kernel: Periodic({'period': 3.0, 'variance': 1.0, 'lengthscale': 0.01})


In [19]:
log_likelihood = evaluate_likelihood(true_kernel, X, Y, runtime=False)
print("Log Marginal Likelihood (true kernel):", log_likelihood)

Log Marginal Likelihood (true kernel): 94.19805611197711


In [20]:
best_kernel, evals = greedy_statistician_search(X, Y, max_steps=5, method='LL')

print("\nBest kernel found:")
print(best_kernel)
print("Number of reward calls:", len(evals))

[Step 1] LL: 67.03 BIC: -122.32 | RBF({'lengthscale': 1.0, 'variance': 1.0})
[Step 2] LL: 66.13 BIC: -116.61 | (RBF({'lengthscale': 1.0, 'variance': 1.0}) + WhiteNoise({'variance': 1.0}))
[Step 3] LL: 79.78 BIC: -140.00 | ((RBF({'lengthscale': 1.0, 'variance': 1.0}) * Constant({'variance': 1.0})) + WhiteNoise({'variance': 1.0}))
[Step 4] LL: 80.17 BIC: -136.87 | ((RBF({'lengthscale': 1.0, 'variance': 1.0}) * RBF({'lengthscale': 1.0, 'variance': 1.0})) + WhiteNoise({'variance': 1.0}))
[Step 5] LL: 82.22 BIC: -129.22 | ((RBF({'lengthscale': 1.0, 'variance': 1.0}) * RBF({'lengthscale': 1.0, 'variance': 1.0})) + (WhiteNoise({'variance': 1.0}) * Periodic({'period': 1.0, 'variance': 1.0, 'lengthscale': 1.0})))

Best kernel found:
((RBF({'lengthscale': 1.0, 'variance': 1.0}) * RBF({'lengthscale': 1.0, 'variance': 1.0})) + (WhiteNoise({'variance': 1.0}) * Periodic({'period': 1.0, 'variance': 1.0, 'lengthscale': 1.0})))
Number of reward calls: 144


In [21]:
evals

{"RBF({'lengthscale': 1.0, 'variance': 1.0})": 67.02897887852069,
 "Linear({'variances': 1.0})": 40.47167725950917,
 "Periodic({'period': 1.0, 'variance': 1.0, 'lengthscale': 1.0})": 39.366951825414425,
 "WhiteNoise({'variance': 1.0})": 45.27570338801687,
 "Constant({'variance': 1.0})": 44.97334879792126,
 "(RBF({'lengthscale': 1.0, 'variance': 1.0}) + Linear({'variances': 1.0}))": 55.64141011192539,
 "(RBF({'lengthscale': 1.0, 'variance': 1.0}) * Linear({'variances': 1.0}))": 38.47788011307059,
 "(RBF({'lengthscale': 1.0, 'variance': 1.0}) + Periodic({'period': 1.0, 'variance': 1.0, 'lengthscale': 1.0}))": 53.24629362512891,
 "(RBF({'lengthscale': 1.0, 'variance': 1.0}) * Periodic({'period': 1.0, 'variance': 1.0, 'lengthscale': 1.0}))": 48.50300435905026,
 "(RBF({'lengthscale': 1.0, 'variance': 1.0}) + WhiteNoise({'variance': 1.0}))": 66.13013752473006,
 "(RBF({'lengthscale': 1.0, 'variance': 1.0}) * WhiteNoise({'variance': 1.0}))": 45.16474469365879,
 "(RBF({'lengthscale': 1.0, 'vari