In [1]:
# Geomstats and GeometricKernels backends 
import geometric_kernels.torch 
import os
os.environ['GEOMSTATS_BACKEND'] = 'pytorch'

# This notebook reproduces results from the spherical harmonics paper on UCI datasets

In [2]:
from torch import Tensor 


import torch 
import gpytorch 
import geometric_kernels.torch 
from math import comb 
from spherical_harmonics import SphericalHarmonics
from geometric_kernels.spaces import Hypersphere
from gpytorch.variational import CholeskyVariationalDistribution
from gpytorch.distributions import MultivariateNormal
from linear_operator.operators import DiagLinearOperator, LinearOperator
from mdgp.kernels import GeometricMaternKernel
from tqdm.autonotebook import tqdm 
from gpytorch.metrics import negative_log_predictive_density, mean_squared_error


torch.set_default_dtype(torch.float64)
from mdgp.variational.spherical_harmonic_features.utils import * 
from mdgp.variational.spherical_harmonic_features_variational_strategy import SphericalHarmonicFeaturesVariationalStrategy

INFO: Using pytorch backend
  from tqdm.autonotebook import tqdm


In [3]:
class gpytorchSGP(gpytorch.models.ApproximateGP):
    def __init__(self, max_ell, d, max_ell_prior, epsilon_sigma=1.0, kappa=1.0, nu=2.5, sigma=1.0, batch_shape=torch.Size([]), jitter_val=1e-6, optimize_nu: bool = True):
        m = total_num_harmonics(max_ell, d)
        variational_distribution = CholeskyVariationalDistribution(num_inducing_points=m, batch_shape=batch_shape)
        variational_strategy = SphericalHarmonicFeaturesVariationalStrategy(self, variational_distribution, num_levels=max_ell, jitter_val=jitter_val)
        super().__init__(variational_strategy=variational_strategy)

        # constants 
        self.jitter_val = jitter_val
        self.max_ell = max_ell
        self.max_ell_prior = max_ell_prior
        self.d = d

        # modules 
        base_kernel = GeometricMaternKernel(
            space=Hypersphere(d),
            lengthscale=kappa, 
            nu=nu, 
            trainable_nu=optimize_nu, 
            num_eigenfunctions=max_ell_prior,
        )
        self.covar_module = gpytorch.kernels.ScaleKernel(base_kernel)
        self.mean_module = gpytorch.means.ConstantMean()
        self.likelihood = gpytorch.likelihoods.GaussianLikelihood()

        # prior hyperparams 
        self.covar_module.outputscale = sigma ** 2
        self.likelihood.noise = epsilon_sigma ** 2

    def forward(self, x) -> MultivariateNormal:
        p_sigma = self.covar_module(x)
        p_mu = self.mean_module(x)
        return MultivariateNormal(p_mu, p_sigma)

In [4]:
from mdgp.experiments.uci.model.projectors import SphereProjector

# UCI data

In [5]:
from mdgp.experiments.uci.data.datasets import Kin8mn, Power, Concrete, Energy, UCIDataset

### Fixed parameters as in the spherical harmonics paper

In [6]:
# Datasets 
kin8mn = Kin8mn()
power = Power()
concrete = Concrete()

# Variational parameters
dimension_to_num_harmonics_variational = {
    4: 336,
    6: 294,
    8: 210,
}

# Prior parameters 
"""
NOTE 
- It is difficult to say what number of spherical harmonics was used for the prior.
  I set it to be the same as the number of inducing variables.
- It is also difficult to say what lengthscale initialisation was used in the paper. 
  I set it to 1.0 for now, although a lower number, e.g. 0.001, would likely be better for higher dimensions.
"""
dimension_to_num_harmonics_prior = {
    4: 336, 
    6: 1000,
    8: 625,
}
nu = 1.5
optimize_nu = False
kappa = 1.0

# Other model parameters
batch_shape = torch.Size([])

# Training parameters 
batch_size = 256 
LR = 0.01
NUM_EPOCHS = 20

In [7]:
from mdgp.variational.spherical_harmonic_features.utils import num_spherical_harmonics_to_num_levels


def get_model_and_projector(dataset: UCIDataset, proj: str = 'sphere'):
    sphere_dimension = dataset.dimension

    # number of levels for variational inference 
    num_spherical_harmonics = dimension_to_num_harmonics_variational[dataset.dimension]
    max_ell, _ = num_spherical_harmonics_to_num_levels(num_spherical_harmonics, sphere_dimension)

    # number of levels for prior
    num_spherical_harmonics_prior = dimension_to_num_harmonics_prior[dataset.dimension]
    max_ell_prior, _ = num_spherical_harmonics_to_num_levels(num_spherical_harmonics_prior, sphere_dimension)

    model = gpytorchSGP(max_ell=max_ell, d=sphere_dimension, max_ell_prior=max_ell_prior, kappa=kappa, nu=nu, optimize_nu=optimize_nu, batch_shape=batch_shape)
    projector = SphereProjector()
    return model, projector


model, projector = get_model_and_projector(kin8mn)

The number of spherical harmonics requested does not lead to complete levels of spherical harmonics. We have thus increased the number to 660, which includes all spherical harmonics up to level 5 (exclusive)


In [8]:
from torch.utils.data import DataLoader

In [9]:
def train_step(x, y, model, projector, optimizer, mll) -> float:
    optimizer.zero_grad(set_to_none=True)
    x, y = projector(x, y)
    # print(x.shape, y.shape)
    output = model(x)
    loss = mll(output, y.squeeze(-1))
    # print(loss)
    loss.backward()
    optimizer.step()
    return loss.item()


def train(dataset, model, projector, batch_size, num_epochs=NUM_EPOCHS, lr=LR, optimize_projector: bool = False) -> list[float]: 
    # optimizer and criterion
    parameters = [{
        'params': model.parameters()
    }]
    if optimize_projector:
        parameters.append({
            'params': projector.parameters()
        })
    optimizer = torch.optim.Adam(parameters, lr=lr, maximize=True)
    mll = gpytorch.mlls.DeepApproximateMLL(gpytorch.mlls.VariationalELBO(model.likelihood, model, dataset.train_y.size(0)))

    # data
    train_loader = DataLoader(dataset.train_dataset, batch_size=batch_size, shuffle=True)

    # Training loop
    losses = []
    pbar = tqdm(range(num_epochs))
    for epoch in pbar:
        epoch_loss = 0
        for x_batch, y_batch in train_loader:
            loss = train_step(x=x_batch, y=y_batch, model=model, projector=projector, optimizer=optimizer, mll=mll)
            epoch_loss += loss
        losses.append(epoch_loss)
        pbar.set_postfix({'ELBO': losses[-1]})

    return losses 


def train_lfbgs(dataset, model, projector, num_epochs=NUM_EPOCHS, lr=1.0) -> list[float]:
    optimizer = torch.optim.LBFGS(model.parameters(), lr=lr, max_iter=20)

    mll = gpytorch.mlls.VariationalELBO(model.likelihood, model, dataset.train_y.size(0))

    def closure():
        optimizer.zero_grad()
        x, y = projector(dataset.train_x, dataset.train_y)
        output = model(x)
        loss = -mll(output, y.squeeze(-1))
        loss.backward()
        return loss

    losses = []
    pbar = tqdm(range(num_epochs))
    for epoch in pbar:
        epoch_loss = optimizer.step(closure)
        losses.append(epoch_loss)
        pbar.set_postfix({'ELBO': losses[-1]})

    return losses


def evaluate(dataset, model, projector):
    with torch.no_grad():
        test_x, test_y = projector(dataset.test_x), dataset.test_y
        test_y = test_y.squeeze(-1)
        out = model.likelihood(model(test_x))
        out = projector.inverse(out)
        nlpd = negative_log_predictive_density(out, test_y)
        mse = mean_squared_error(out, test_y)
        metrics = {
            'nlpd': nlpd.item(), 
            'mse': mse.item(),
        }
        print(f"NLPD: {metrics['nlpd']}, MSE: {metrics['mse']}")
    return metrics 


def reproduce_results(dataset, batch_size, num_runs: int = 5, num_epochs=NUM_EPOCHS, lr=LR):
    print(f"Reproducing results for {dataset.name}".center(80, '-') + '\n')

    metrics = []
    for run in range(num_runs):
        print(f"Run {run + 1}".center(80, '-'))

        torch.random.manual_seed(run)
        model, projector = get_model_and_projector(dataset)
        train(dataset, model, projector, num_epochs=num_epochs, lr=lr, batch_size=batch_size)
        # print(
        #     model.covar_module.base_kernel.lengthscale, 
        #     model.covar_module.base_kernel.nu,
        #     model.covar_module.outputscale,
        #     model.likelihood.noise,
        #     model.variational_strategy._variational_distribution.variational_mean,
        #     model.variational_strategy._variational_distribution.chol_variational_covar,
        # )
        run_metrics = evaluate(dataset, model, projector)
        metrics.append(run_metrics)
    df = pd.DataFrame(metrics)

    print("Metrics mean".center(80, '-'))
    print(df.mean())

    print("Metrics STD".center(80, '-'))
    print(df.std())

    return df 

In [10]:
from mdgp.experiments.uci.model.geometric import SHFDeepGP
from mdgp.experiments.uci.model.euclidean import EuclideanDeepGP
from mdgp.experiments.uci.model.projectors import IdentityProjector
from mdgp.experiments.uci.fit import train as train_uci
from mdgp.experiments.uci.fit import FitArguments


model = 'euclidean'
num_layers = 3

if model == 'geometric':
    model = SHFDeepGP(kin8mn, num_layers)
    projector = SphereProjector()
elif model == 'euclidean':
    model = EuclideanDeepGP(kin8mn, num_layers, num_inducing_points=210)
    projector = IdentityProjector()
fit_args = FitArguments(
    train_batch_size=1024,
    num_iterations=100,
)

In [11]:
with gpytorch.settings.num_likelihood_samples(10):
    train_uci(kin8mn, model=model, projector=projector, fit_args=fit_args, device='cpu', inner_pbar=True)

Batches: 100%|██████████| 8/8 [00:10<00:00,  1.33s/it]
Batches:  25%|██▌       | 2/8 [00:03<00:09,  1.60s/it], ELBO=-17.1]
Epochs:   1%|          | 1/100 [00:13<22:50, 13.84s/it, ELBO=-17.1]


KeyboardInterrupt: 

In [12]:
from mdgp.experiments.uci.fit.fit import evaluate_deep

with torch.no_grad(), gpytorch.settings.num_likelihood_samples(10):
    fit_args.test_num_samples = 10
    evaluate_deep(kin8mn, model, projector, 'cpu', fit_args=fit_args)

TLL: -0.43074761694610925, MSE: 0.03635215431338528


In [21]:
with torch.no_grad(), gpytorch.settings.num_likelihood_samples(10):
    fit_args.test_num_samples = 10
    evaluate_deep(kin8mn, model, projector, 'cpu', fit_args=fit_args)

TLL: 0.7194914541851841, MSE: 0.01465116672011924


In [18]:
with torch.no_grad(), gpytorch.settings.num_likelihood_samples(10):
    fit_args.test_num_samples = 10
    evaluate_deep(kin8mn, model, projector, 'cpu', fit_args=fit_args)

TLL: 1.0851964894289, MSE: 0.007424122800293092


In [14]:
with torch.no_grad(), gpytorch.settings.num_likelihood_samples(10):
    fit_args.test_num_samples = 10
    evaluate_deep(kin8mn, model, projector, 'cpu', fit_args=fit_args)

TLL: -0.8657127998655103, MSE: 0.056299769163549174


In [20]:
from mdgp.experiments.uci.fit.fit import evaluate_deep, evaluate_shallow

with torch.no_grad(), gpytorch.settings.num_likelihood_samples(10):
    evaluate_deep(kin8mn, model, projector, device='cpu', fit_args=fit_args)

TLL: 0.2188367189484403, MSE: 0.024650527151646066


In [17]:
from mdgp.experiments.uci.fit.fit import evaluate_deep, evaluate_shallow

with torch.no_grad(), gpytorch.settings.num_likelihood_samples(10):
    evaluate_deep(kin8mn, model, projector, device='cpu', fit_args=fit_args)

TLL: 1.209683798301345, MSE: 0.005304259921531311


In [14]:
from mdgp.experiments.uci.fit.fit import evaluate_deep, evaluate_shallow

with torch.no_grad(), gpytorch.settings.num_likelihood_samples(100):
    evaluate_deep(kin8mn, model, projector, device='cpu', fit_args=fit_args)

TLL: 1.0008430227727956, MSE: 0.007140704163331918


In [18]:
from mdgp.experiments.uci.fit.fit import evaluate_deep, evaluate_shallow

with torch.no_grad(), gpytorch.settings.num_likelihood_samples(100):
    evaluate_deep(kin8mn, model, projector, device='cpu', fit_args=fit_args)

TLL: 1.2228897539138972, MSE: 0.005064238676492455


In [33]:
from mdgp.experiments.uci.fit.fit import evaluate_deep, evaluate_shallow

with torch.no_grad(), gpytorch.settings.num_likelihood_samples(200):
    evaluate_deep(kin8mn, model, projector, device='cpu', fit_args=fit_args)

TLL: 1.2651115829407433, MSE: 0.004955803643329355


: 

In [15]:
from mdgp.experiments.uci.fit.fit import evaluate_deep, evaluate_shallow

with torch.no_grad():
    evaluate_deep(kin8mn, model, projector, device='cpu', fit_args=fit_args)
    out = model.likelihood(model(kin8mn.test_x))
    out = projector.inverse(out)
    nlpd = negative_log_predictive_density(out, kin8mn.test_y)
    print(f"NLPD {nlpd.mean().item()}")

TLL: 1.0000001444029862, MSE: 0.007093938590989231
NLPD 0.3235617586172367


In [14]:
with torch.no_grad():
    evaluate_shallow(kin8mn, model, projector, device='cpu', fit_args=fit_args)
    evaluate_deep(kin8mn, model, projector, device='cpu', fit_args=fit_args)

TLL: 1.0026375218917551, MSE: 0.007027994436844125
TLL: 1.0026375218917551, MSE: 0.007027994436844124


In [14]:
model, projector = get_model_and_projector(kin8mn)
train_lfbgs(kin8mn, model, projector, batch_size, num_epochs=40, lr=0.01)

The number of spherical harmonics requested does not lead to complete levels of spherical harmonics. We have thus increased the number to 660, which includes all spherical harmonics up to level 5 (exclusive)


  0%|          | 0/40 [00:00<?, ?it/s]

[tensor(1.4776, grad_fn=<NegBackward0>),
 tensor(0.9105, grad_fn=<NegBackward0>),
 tensor(0.6691, grad_fn=<NegBackward0>),
 tensor(0.5100, grad_fn=<NegBackward0>),
 tensor(0.4262, grad_fn=<NegBackward0>),
 tensor(0.3034, grad_fn=<NegBackward0>),
 tensor(0.2287, grad_fn=<NegBackward0>),
 tensor(0.1842, grad_fn=<NegBackward0>),
 tensor(0.1534, grad_fn=<NegBackward0>),
 tensor(0.1279, grad_fn=<NegBackward0>),
 tensor(0.0992, grad_fn=<NegBackward0>),
 tensor(0.0627, grad_fn=<NegBackward0>),
 tensor(0.0159, grad_fn=<NegBackward0>),
 tensor(-0.0425, grad_fn=<NegBackward0>),
 tensor(-0.0696, grad_fn=<NegBackward0>),
 tensor(-0.0945, grad_fn=<NegBackward0>),
 tensor(-0.1192, grad_fn=<NegBackward0>),
 tensor(-0.1430, grad_fn=<NegBackward0>),
 tensor(-0.1667, grad_fn=<NegBackward0>),
 tensor(-0.1870, grad_fn=<NegBackward0>),
 tensor(-0.2044, grad_fn=<NegBackward0>),
 tensor(-0.2241, grad_fn=<NegBackward0>),
 tensor(-0.2484, grad_fn=<NegBackward0>),
 tensor(-0.2625, grad_fn=<NegBackward0>),
 tens

In [15]:
evaluate(kin8mn, model, projector)

NLPD: 0.5905679521542057, MSE: 0.21266969401689906


{'nlpd': 0.5905679521542057, 'mse': 0.21266969401689906}

In [9]:
model, projector = get_model_and_projector(kin8mn, 'sphere')
train(kin8mn, model, projector, batch_size, num_epochs=NUM_EPOCHS, lr=LR)
evaluate(kin8mn, model, projector)

The number of spherical harmonics requested does not lead to complete levels of spherical harmonics. We have thus increased the number to 660, which includes all spherical harmonics up to level 5 (exclusive)


  0%|          | 0/20 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [12]:
reproduce_results(power, num_epochs=1000, batch_size=kin8mn.train_x.size(0))

-------------------------Reproducing results for power--------------------------

-------------------------------------Run 1--------------------------------------


  0%|          | 0/1000 [00:00<?, ?it/s]

NLPD: 0.014695136020697464, MSE: 0.05854422073080649
-------------------------------------Run 2--------------------------------------


  0%|          | 0/1000 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [88]:
reproduce_results(power, num_epochs=20)

-------------------------Reproducing results for power--------------------------

-------------------------------------Run 1--------------------------------------


100%|██████████| 20/20 [02:20<00:00,  7.04s/it, ELBO=34]  


NLPD: 0.018200022938892634, MSE: 0.0604323361222259
-------------------------------------Run 2--------------------------------------


100%|██████████| 20/20 [02:12<00:00,  6.63s/it, ELBO=34]  


NLPD: 0.020585911675719538, MSE: 0.06111090834631231
-------------------------------------Run 3--------------------------------------


100%|██████████| 20/20 [02:17<00:00,  6.85s/it, ELBO=34]  


NLPD: 0.029410862008512097, MSE: 0.06202288396624321
-------------------------------------Run 4--------------------------------------


100%|██████████| 20/20 [02:17<00:00,  6.89s/it, ELBO=34]  


NLPD: 0.011682565872741057, MSE: 0.05958469777710783
-------------------------------------Run 5--------------------------------------


100%|██████████| 20/20 [02:14<00:00,  6.71s/it, ELBO=34.2]


NLPD: 0.018373846236252843, MSE: 0.06075983045730417
----------------------------------Metrics mean----------------------------------
nlpd    0.019651
mse     0.060782
dtype: float64
----------------------------------Metrics STD-----------------------------------
nlpd    0.006391
mse     0.000895
dtype: float64


Unnamed: 0,nlpd,mse
0,0.0182,0.060432
1,0.020586,0.061111
2,0.029411,0.062023
3,0.011683,0.059585
4,0.018374,0.06076


In [89]:
reproduce_results(concrete, num_epochs=125)

------------------------Reproducing results for concrete------------------------

-------------------------------------Run 1--------------------------------------


100%|██████████| 125/125 [01:26<00:00,  1.44it/s, ELBO=2.44] 


NLPD: 0.3623051474366128, MSE: 0.1097215590579564
-------------------------------------Run 2--------------------------------------


100%|██████████| 125/125 [01:32<00:00,  1.35it/s, ELBO=2.46] 


NLPD: 0.3629351313556932, MSE: 0.10951533423101913
-------------------------------------Run 3--------------------------------------


100%|██████████| 125/125 [01:30<00:00,  1.39it/s, ELBO=2.41] 


NLPD: 0.3790285364454814, MSE: 0.11759354105563866
-------------------------------------Run 4--------------------------------------


100%|██████████| 125/125 [01:28<00:00,  1.41it/s, ELBO=2.46] 


NLPD: 0.37304157369488944, MSE: 0.1139838157709856
-------------------------------------Run 5--------------------------------------


100%|██████████| 125/125 [01:28<00:00,  1.41it/s, ELBO=2.41] 


NLPD: 0.37610820039944604, MSE: 0.11769848257163341
----------------------------------Metrics mean----------------------------------
nlpd    0.370684
mse     0.113703
dtype: float64
----------------------------------Metrics STD-----------------------------------
nlpd    0.007663
mse     0.004018
dtype: float64


Unnamed: 0,nlpd,mse
0,0.362305,0.109722
1,0.362935,0.109515
2,0.379029,0.117594
3,0.373042,0.113984
4,0.376108,0.117698
