# This notebook reproduces results from the spherical harmonics paper on UCI datasets

In [1]:
from torch import Tensor 


import torch 
import gpytorch 
import geometric_kernels.torch 
from math import comb 
from spherical_harmonics import SphericalHarmonics
from geometric_kernels.spaces import Hypersphere
from gpytorch.variational import CholeskyVariationalDistribution
from gpytorch.distributions import MultivariateNormal
from linear_operator.operators import DiagLinearOperator, LinearOperator
from mdgp.kernels import GeometricMaternKernel
from tqdm.autonotebook import tqdm 
from gpytorch.metrics import negative_log_predictive_density, mean_squared_error


torch.set_default_dtype(torch.float64)
from mdgp.variational.spherical_harmonic_features.utils import * 
from mdgp.variational.spherical_harmonic_features_variational_strategy import SphericalHarmonicFeaturesVariationalStrategy

INFO: Using numpy backend
  from tqdm.autonotebook import tqdm


In [2]:
class gpytorchSGP(gpytorch.models.ApproximateGP):
    def __init__(self, max_ell, d, max_ell_prior, epsilon_sigma=1.0, kappa=1.0, nu=2.5, sigma=1.0, batch_shape=torch.Size([]), jitter_val=1e-6, optimize_nu: bool = True):
        m = total_num_harmonics(max_ell, d)
        variational_distribution = CholeskyVariationalDistribution(num_inducing_points=m, batch_shape=batch_shape)
        variational_strategy = SphericalHarmonicFeaturesVariationalStrategy(self, variational_distribution, num_levels=max_ell, jitter_val=jitter_val)
        super().__init__(variational_strategy=variational_strategy)

        # constants 
        self.jitter_val = jitter_val
        self.max_ell = max_ell
        self.max_ell_prior = max_ell_prior
        self.d = d

        # modules 
        base_kernel = GeometricMaternKernel(
            space=Hypersphere(d),
            lengthscale=kappa, 
            nu=nu, 
            trainable_nu=optimize_nu, 
            num_eigenfunctions=max_ell_prior,
        )
        self.covar_module = gpytorch.kernels.ScaleKernel(base_kernel)
        self.mean_module = gpytorch.means.ConstantMean()
        self.likelihood = gpytorch.likelihoods.GaussianLikelihood()

        # prior hyperparams 
        self.covar_module.outputscale = sigma ** 2
        self.likelihood.noise = epsilon_sigma ** 2

    def forward(self, x) -> MultivariateNormal:
        p_sigma = self.covar_module(x)
        p_mu = self.mean_module(x)
        return MultivariateNormal(p_mu, p_sigma)

In [3]:
from linear_operator.operators import DiagLinearOperator
from gpytorch.distributions import MultivariateNormal


class SphereProjector(torch.nn.Module):
    def __init__(self, b: float = 1.0):
        super().__init__()
        self.b = torch.nn.Parameter(torch.tensor(b))
        self.norm = None 

    def forward(self, x: Tensor, y: Tensor | None = None) -> tuple[Tensor, Tensor] | Tensor:
        b = self.b.expand(*x.shape[:-1], 1)
        x_cat_b = torch.cat([x, b], dim=-1)
        self.norm = x_cat_b.norm(dim=-1, keepdim=True)
        if y is None:
            return x_cat_b / self.norm
        else:
            return x_cat_b / self.norm, y.squeeze(-1) / self.norm.squeeze(-1)
    
    def inverse(self, mvn: MultivariateNormal) -> MultivariateNormal:
        L = DiagLinearOperator(self.norm.squeeze(-1))
        mean = mvn.mean @ L
        cov = L @ mvn.lazy_covariance_matrix @ L
        return MultivariateNormal(mean=mean, covariance_matrix=cov)

# UCI data

### Fixed parameters as in the spherical harmonics paper

In [4]:
# Variational parameters
num_harmonics_variational = 210
num_harmonics_prior = 625
nu = 1.5
optimize_nu = False
kappa = 1.0

# Other model parameters
batch_shape = torch.Size([])

# Training parameters 
batch_size = 256 
LR = 0.01
NUM_EPOCHS = 20

In [5]:
from mdgp.variational.spherical_harmonic_features.utils import num_spherical_harmonics_to_num_levels


def get_model_and_projector(dim=8):
    # number of levels for variational inference 
    max_ell, _ = num_spherical_harmonics_to_num_levels(num_harmonics_variational, dim)

    # number of levels for prior
    max_ell_prior, _ = num_spherical_harmonics_to_num_levels(num_harmonics_prior, dim)

    model = gpytorchSGP(max_ell=max_ell, d=dim, max_ell_prior=max_ell_prior, kappa=kappa, nu=nu, optimize_nu=optimize_nu, batch_shape=batch_shape)
    projector = SphereProjector()
    return model, projector

In [6]:
from torch.utils.data import DataLoader

In [7]:
def train_step(x, y, model, projector, optimizer, mll) -> float:
    optimizer.zero_grad(set_to_none=True)
    x, y = projector(x, y)
    # print(x.shape, y.shape)
    output = model(x)
    loss = mll(output, y.squeeze(-1))
    # print(loss)
    loss.backward()
    optimizer.step()
    return loss.item()


def train(dataset, model, projector, batch_size, num_epochs=NUM_EPOCHS, lr=LR, optimize_projector: bool = False) -> list[float]: 
    # optimizer and criterion
    parameters = [{
        'params': model.parameters()
    }]
    optimizer = torch.optim.Adam(parameters, lr=lr, maximize=True)
    mll = gpytorch.mlls.VariationalELBO(model.likelihood, model, dataset.train_y.size(0))

    # data
    train_loader = DataLoader(dataset.train_dataset, batch_size=batch_size, shuffle=True)

    # Training loop
    losses = []
    pbar = tqdm(range(num_epochs))
    for epoch in pbar:
        epoch_loss = 0
        for x_batch, y_batch in train_loader:
            loss = train_step(x=x_batch, y=y_batch, model=model, projector=projector, optimizer=optimizer, mll=mll)
            epoch_loss += loss
        losses.append(epoch_loss)
        pbar.set_postfix({'ELBO': losses[-1]})

    return losses 


def train_lfbgs(train_x, train_y, model, projector, num_epochs=NUM_EPOCHS, lr=1.0) -> list[float]:
    optimizer = torch.optim.LBFGS(model.parameters(), lr=lr, max_iter=20)

    mll = gpytorch.mlls.VariationalELBO(model.likelihood, model, train_y.size(0))

    def closure():
        optimizer.zero_grad()
        x, y = projector(train_x, train_y)
        output = model(x)
        loss = -mll(output, y.squeeze(-1))
        loss.backward()
        return loss

    losses = []
    pbar = tqdm(range(num_epochs))
    for epoch in pbar:
        epoch_loss = optimizer.step(closure)
        losses.append(epoch_loss)
        pbar.set_postfix({'ELBO': losses[-1]})

    return losses


def evaluate(test_x, test_y, model, projector):
    with torch.no_grad():
        test_x, test_y = projector(test_x), test_y
        test_y = test_y.squeeze(-1)
        out = model.likelihood(model(test_x))
        out = projector.inverse(out)
        nlpd = negative_log_predictive_density(out, test_y)
        mse = mean_squared_error(out, test_y)
        metrics = {
            'nlpd': nlpd.item(), 
            'mse': mse.item(),
        }
        print(f"NLPD: {metrics['nlpd']}, MSE: {metrics['mse']}")
    return metrics 


def reproduce_results(dataset, batch_size, num_runs: int = 5, num_epochs=NUM_EPOCHS, lr=LR):
    print(f"Reproducing results for {dataset.name}".center(80, '-') + '\n')

    metrics = []
    for run in range(num_runs):
        print(f"Run {run + 1}".center(80, '-'))

        torch.random.manual_seed(run)
        model, projector = get_model_and_projector(dataset)
        train(dataset, model, projector, num_epochs=num_epochs, lr=lr, batch_size=batch_size)
        run_metrics = evaluate(dataset, model, projector)
        metrics.append(run_metrics)
    df = pd.DataFrame(metrics)

    print("Metrics mean".center(80, '-'))
    print(df.mean())

    print("Metrics STD".center(80, '-'))
    print(df.std())

    return df 

In [8]:
from datasets_dsvi import Energy as EnergyDSVI
DTYPE = torch.get_default_dtype()

data = EnergyDSVI().get_data()
X, Y, Xs, Ys, Y_std = [data[_] for _ in ['X', 'Y', 'Xs', 'Ys', 'Y_std']]
x, y = torch.from_numpy(X).to(DTYPE), torch.from_numpy(Y).to(DTYPE).squeeze(-1)
test_x, test_y = torch.from_numpy(Xs).to(DTYPE), torch.from_numpy(Ys).to(DTYPE).squeeze(-1)
test_y_std = torch.from_numpy(Y_std).to(DTYPE).squeeze(-1)

  if self.type is 'regression':


Normalizing X with mean [[7.63950796e-01 6.71782200e+02 3.18748191e+02 1.76517004e+02
  5.25759768e+00 3.47756874e+00 2.32995658e-01 2.82344428e+00]] and std [[ 0.10901257 89.79381901 41.44677061 45.98855734  1.74867228  1.16290081
   0.13683384  1.56965762]]
Normalizing Y with mean [[22.33689725]] and std [[9.9405047]]


In [9]:
model, projector = get_model_and_projector()
train_lfbgs(x, y, model, projector, num_epochs=80, lr=0.01)

The number of spherical harmonics requested does not lead to complete levels of spherical harmonics. We have thus increased the number to 660, which includes all spherical harmonics up to level 5 (exclusive)


  0%|          | 0/80 [00:00<?, ?it/s]

[tensor(1.4869, grad_fn=<NegBackward0>),
 tensor(0.9331, grad_fn=<NegBackward0>),
 tensor(0.7039, grad_fn=<NegBackward0>),
 tensor(0.5486, grad_fn=<NegBackward0>),
 tensor(0.2744, grad_fn=<NegBackward0>),
 tensor(0.1352, grad_fn=<NegBackward0>),
 tensor(0.0258, grad_fn=<NegBackward0>),
 tensor(-0.1119, grad_fn=<NegBackward0>),
 tensor(-0.2369, grad_fn=<NegBackward0>),
 tensor(-0.3286, grad_fn=<NegBackward0>),
 tensor(-0.3950, grad_fn=<NegBackward0>),
 tensor(-0.4525, grad_fn=<NegBackward0>),
 tensor(-0.5009, grad_fn=<NegBackward0>),
 tensor(-0.5450, grad_fn=<NegBackward0>),
 tensor(-0.5872, grad_fn=<NegBackward0>),
 tensor(-0.6234, grad_fn=<NegBackward0>),
 tensor(-0.6514, grad_fn=<NegBackward0>),
 tensor(-0.6716, grad_fn=<NegBackward0>),
 tensor(-0.6863, grad_fn=<NegBackward0>),
 tensor(-0.6974, grad_fn=<NegBackward0>),
 tensor(-0.7056, grad_fn=<NegBackward0>),
 tensor(-0.7118, grad_fn=<NegBackward0>),
 tensor(-0.7168, grad_fn=<NegBackward0>),
 tensor(-0.7208, grad_fn=<NegBackward0>),

In [10]:
evaluate(test_x, test_y, model, projector)

NLPD: 0.05947370034459709, MSE: 0.04431437675793123


{'nlpd': 0.05947370034459709, 'mse': 0.04431437675793123}

In [9]:
model, projector = get_model_and_projector()

optimizer = torch.optim.Adam(model.parameters(), lr=0.01, maximize=True)
elbo = gpytorch.mlls.VariationalELBO(model.likelihood, model, y.size(0))


for step in (pbar := tqdm(range(10000), desc='Epochs')):
    model.train()
    optimizer.zero_grad()
    x_batch, y_batch = projector(x, y)
    output = model(x_batch)
    loss = elbo(output, y_batch)
    loss.backward()
    optimizer.step()
    pbar.set_postfix({'ELBO': loss.item()})

    if step % 50 == 0:
        with torch.no_grad():
            model.eval()
            test_x_proj = projector(test_x)
            out = model.likelihood(model(test_x_proj))
            out = projector.inverse(out)
            nlpd = negative_log_predictive_density(out, test_y)
            mse = mean_squared_error(out, test_y)
            print(f"NLPD: {nlpd}, MSE: {mse}")

The number of spherical harmonics requested does not lead to complete levels of spherical harmonics. We have thus increased the number to 660, which includes all spherical harmonics up to level 5 (exclusive)


Epochs:   0%|          | 0/10000 [00:00<?, ?it/s]

NLPD: 2.0815537799947337, MSE: 0.9939946234508938
NLPD: 1.87599291875139, MSE: 0.08693688022662341
NLPD: 1.620360506518523, MSE: 0.055414195833855456
NLPD: 1.362619233032334, MSE: 0.05467460285310863
NLPD: 1.1022443933201664, MSE: 0.05426147816649955
NLPD: 0.8532727792778565, MSE: 0.05444260276696631
NLPD: 0.6305420553232204, MSE: 0.05383040729539002
NLPD: 0.44651796123683246, MSE: 0.0527989926172567
NLPD: 0.3083038323576212, MSE: 0.05149217917667947
NLPD: 0.21427132697482854, MSE: 0.04989288854232348
NLPD: 0.15500423866140392, MSE: 0.04840779629622393
NLPD: 0.12037733547007452, MSE: 0.047261446198796464
NLPD: 0.1011249960878665, MSE: 0.04653117979938449
NLPD: 0.09029520385370318, MSE: 0.04603873756211407
NLPD: 0.08450382761332585, MSE: 0.045757410529742686
NLPD: 0.08107958618844814, MSE: 0.04555610090843142
NLPD: 0.07896958079252589, MSE: 0.045426193698496514
NLPD: 0.07767795256753642, MSE: 0.045338117114614634
NLPD: 0.07660445066081982, MSE: 0.04526225770690132
NLPD: 0.07612299779694

In [23]:
from mdgp.utils import test_psd


with torch.no_grad():
    test_x_proj = projector(test_x)
    out = model(test_x_proj)
    out = projector.inverse(out)
    test_psd(out.lazy_covariance_matrix)
    mean = out.mean
    print(mean_squared_error(out, test_y).item())
    print(negative_log_predictive_density(out, test_y).mean().item())

0.0545016590434723
5.432272166524743


# Exact GP

In [39]:
class ExactGP(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood):
        super(ExactGP, self).__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.MaternKernel(nu=1.5))

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)
    

likelihood = gpytorch.likelihoods.GaussianLikelihood()
model = ExactGP(x, y, likelihood)

optimizer = torch.optim.Adam(model.parameters(), lr=0.01, maximize=True)
mll = gpytorch.mlls.ExactMarginalLogLikelihood(model.likelihood, model)


for step in (pbar := tqdm(range(10000), desc='Epochs')):
    model.train()
    likelihood.train()
    optimizer.zero_grad()
    output = model(x)
    loss = mll(output, y)
    loss.backward()
    optimizer.step()
    pbar.set_postfix({'MLL': loss.item()})


Epochs:   0%|          | 0/10000 [00:00<?, ?it/s]

In [40]:
with torch.no_grad():
    model.eval()
    likelihood.eval()
    out = model(test_x)
    test_psd(out.lazy_covariance_matrix)
    mean = out.mean
    print(f"MSE: {mean_squared_error(out, test_y).item()}")
    print(f"NLPD: {negative_log_predictive_density(out, test_y).mean().item()}")

MSE: 0.0017146845907218108
NLPD: -1.5748066209556104
