In [8]:
import numpy as np
from numpy.linalg import eigh
from scipy.stats import special_ortho_group
# from scipy.spatial.distance import pdist, squareform

In [9]:
def sample_from_sphere(N, d):
    """Draw N points uniformly at random from the unit sphere in R^d.

    Standard-normal vectors are rescaled to unit Euclidean length; since the
    isotropic Gaussian is rotation-invariant, the resulting directions are
    uniformly distributed on the sphere.

    Returns an array of shape (N, d) whose rows each have norm 1.
    """
    gaussian = np.random.randn(N, d)
    row_norms = np.linalg.norm(gaussian, axis=1, keepdims=True)
    return gaussian / row_norms

def generate_covariates(N, d=10):
    """Generate N random covariate vectors x^{(i)} in R^d, stacked as columns.

    Each x^{(i)} = Sigma^{1/2} xi^{(i)}, where xi^{(i)} is drawn uniformly from the
    unit sphere (``sample_from_sphere``) and Sigma = U^T D U, with U a random
    rotation from SO(d) and D = diag(1, 1, 0.25, 2.25, 1, 1, ..., 1).

    Parameters
    ----------
    N : int
        Number of covariate vectors to generate.
    d : int, default 10
        Ambient dimension. Must be at least 5 because the first five
        eigenvalues of Sigma are fixed.

    Returns
    -------
    ndarray of shape (d, N)
        Column i is x^{(i)}.

    Raises
    ------
    ValueError
        If d < 5.
    """
    if d < 5:
        # The hard-coded spectrum below has 5 entries; a smaller d would only
        # surface later as an opaque matmul shape mismatch.
        raise ValueError(f"d must be at least 5 (got d={d}): the spectrum of Sigma fixes 5 eigenvalues")
    D = np.diag([1, 1, 0.25, 2.25, 1] + [1] * (d - 5))
    U = special_ortho_group.rvs(dim=d)
    # D is diagonal, so its elementwise sqrt is its matrix square root; since U is
    # orthogonal, (U^T D^{1/2} U)^2 = U^T D U = Sigma.
    Sigma_sqrt = U.T @ np.sqrt(D) @ U
    xi = sample_from_sphere(N, d)
    return Sigma_sqrt @ xi.T

def kernel_matrix(X, kernel_type='linear'):
    """Compute the kernel (Gram) matrix of the columns of X.

    Parameters
    ----------
    X : 2-D ndarray
        Data matrix whose columns are the samples, e.g. the (d, N) output of
        ``generate_covariates``.
    kernel_type : {'linear', 'relu', 'exp'}, default 'linear'
        'linear' : K_ij = <x_i, x_j>
        'relu'   : K_ij = max(<x_i, x_j>, 0)
        'exp'    : K_ij = exp(<x_i, x_j>)

    Returns
    -------
    ndarray of shape (N, N)
        N is the number of columns of X.

    Raises
    ------
    ValueError
        If kernel_type is not one of the supported names (previously this
        silently returned None).
    """
    if kernel_type not in ('linear', 'relu', 'exp'):
        raise ValueError(
            f"Unknown kernel_type {kernel_type!r}; expected 'linear', 'relu' or 'exp'"
        )
    # All supported kernels are elementwise functions of the inner products.
    gram = X.T @ X
    if kernel_type == 'linear':
        return gram
    if kernel_type == 'relu':
        return np.maximum(gram, 0)
    return np.exp(gram)

def gaussian_process(N, kernel_type):
    """Sample N covariates and matching labels y^{(i)} from a Gaussian process.

    Covariates come from ``generate_covariates`` (default dimension). Their
    kernel matrix is eigendecomposed and every eigenvalue is replaced by its
    absolute value, so the matrix handed to the sampler is positive
    semi-definite. The labels are one draw from a zero-mean multivariate normal
    with that covariance.

    Returns
    -------
    X : ndarray of shape (N, d)
        The covariates, one per row.
    y : ndarray of shape (N,)
        The sampled labels.
    """
    covariates = generate_covariates(N)
    gram = kernel_matrix(covariates, kernel_type=kernel_type)
    # |lambda| removes any negative eigenvalues (round-off, or a kernel that may
    # not be PSD) so the result is a valid covariance matrix.
    lam, V = eigh(gram)
    psd_gram = V @ np.diag(np.abs(lam)) @ V.T
    labels = np.random.multivariate_normal(np.zeros(N), psd_gram)
    return covariates.T, labels

In [10]:
# Seed NumPy's global RNG so this draw is reproducible on Restart & Run All:
# sample_from_sphere and multivariate_normal use it directly, and
# special_ortho_group.rvs falls back to it when no random_state is given.
SEED = 0
np.random.seed(SEED)

N = 5
X, Y = gaussian_process(N, 'exp')
print(X)
print(X.shape)
print(Y)
print(Y.shape)

# Inline sanity checks: one covariate row and one label per sample.
assert X.shape[0] == N
assert Y.shape == (N,)

(10, 10)
The shape of xi: (5, 10)
[[-0.26129557 -0.19103813 -0.20753888 -0.06698089 -0.37786894 -0.26217595
  -0.52501432 -0.53025769  0.01816147  0.27956417]
 [-0.30627804 -0.24151949 -0.61831367  0.20891018  0.47543401  0.05628515
  -0.0322445  -0.09974851  0.42511992  0.02182368]
 [ 0.13616555 -0.02125894 -0.17238148 -0.86426077  0.158742    0.06520823
   0.14735863  0.25880334  0.37185453  0.05740556]
 [-0.29953041 -0.22739268 -0.45475475  0.31299666 -0.62152266  0.23937856
   0.21974033  0.16520104  0.32463194 -0.08345543]
 [-0.21146871 -0.38289178  0.54812159 -0.32998459  0.30242495 -0.50857222
  -0.33767632 -0.08341033 -0.03110509 -0.00663919]]
(5, 10)
[-3.76695925 -0.14943362 -0.04335937 -0.24726158 -0.42761125]
(5,)
