# This notebook reproduces results from the spherical harmonics paper on UCI datasets

In [1]:
from torch import Tensor 


import torch 
import gpytorch 
import geometric_kernels.torch 
from math import comb 
from spherical_harmonics import SphericalHarmonics
from geometric_kernels.spaces import Hypersphere
from gpytorch.variational import CholeskyVariationalDistribution
from gpytorch.distributions import MultivariateNormal
from linear_operator.operators import DiagLinearOperator, LinearOperator
from mdgp.kernels import GeometricMaternKernel
from tqdm.autonotebook import tqdm 
from gpytorch.metrics import negative_log_predictive_density, mean_squared_error


torch.set_default_dtype(torch.float64)
from mdgp.variational.spherical_harmonic_features.utils import * 
from mdgp.variational.spherical_harmonic_features_variational_strategy import SphericalHarmonicFeaturesVariationalStrategy

INFO: Using numpy backend
  from tqdm.autonotebook import tqdm


In [2]:
class gpytorchSGP(gpytorch.models.ApproximateGP):
    def __init__(self, max_ell, d, max_ell_prior, epsilon_sigma=1.0, kappa=1.0, nu=2.5, sigma=1.0, batch_shape=torch.Size([]), jitter_val=1e-6, optimize_nu: bool = True):
        m = total_num_harmonics(max_ell, d)
        variational_distribution = CholeskyVariationalDistribution(num_inducing_points=m, batch_shape=batch_shape)
        variational_strategy = SphericalHarmonicFeaturesVariationalStrategy(self, variational_distribution, num_levels=max_ell, jitter_val=jitter_val)
        super().__init__(variational_strategy=variational_strategy)

        # constants 
        self.jitter_val = jitter_val
        self.max_ell = max_ell
        self.max_ell_prior = max_ell_prior
        self.d = d

        # modules 
        base_kernel = GeometricMaternKernel(
            space=Hypersphere(d),
            lengthscale=kappa, 
            nu=nu, 
            trainable_nu=optimize_nu, 
            num_eigenfunctions=max_ell_prior,
        )
        self.covar_module = gpytorch.kernels.ScaleKernel(base_kernel)
        self.mean_module = gpytorch.means.ConstantMean()
        self.likelihood = gpytorch.likelihoods.GaussianLikelihood()

        # prior hyperparams 
        self.covar_module.outputscale = sigma ** 2
        self.likelihood.noise = epsilon_sigma ** 2

    def forward(self, x) -> MultivariateNormal:
        p_sigma = self.covar_module(x)
        p_mu = self.mean_module(x)
        return MultivariateNormal(p_mu, p_sigma)

In [62]:
from linear_operator.operators import DiagLinearOperator
from gpytorch.distributions import MultivariateNormal


class SphereProjector(torch.nn.Module):
    def __init__(self, b: float = 1.0):
        super().__init__()
        self.b = torch.nn.Parameter(torch.tensor(b))
        self.norm = None 

    def forward(self, x: Tensor, y: Tensor | None = None) -> tuple[Tensor, Tensor] | Tensor:
        b = self.b.expand(*x.shape[:-1], 1)
        x_cat_b = torch.cat([x, b], dim=-1)
        self.norm = x_cat_b.norm(dim=-1, keepdim=True)
        if y is None:
            return x_cat_b / self.norm
        else:
            return x_cat_b / self.norm, y.squeeze(-1) / self.norm.squeeze(-1)
    
    def inverse(self, mvn: MultivariateNormal) -> MultivariateNormal:
        L = DiagLinearOperator(self.norm.squeeze(-1))
        mean = mvn.mean @ L
        cov = L @ mvn.lazy_covariance_matrix @ L
        return MultivariateNormal(mean=mean, covariance_matrix=cov)
    

class SphereProjectorNoY(torch.nn.Module):
    def __init__(self, b: float = 1.0):
        super().__init__()
        self.b = torch.nn.Parameter(torch.tensor(b))
        self.norm = None 

    def forward(self, x: Tensor, y: Tensor | None = None) -> tuple[Tensor, Tensor] | Tensor:
        b = self.b.expand(*x.shape[:-1], 1)
        x_cat_b = torch.cat([x, b], dim=-1)
        self.norm = x_cat_b.norm(dim=-1, keepdim=True)
        if y is None:
            return x_cat_b / self.norm
        else:
            return x_cat_b / self.norm, y
    
    def inverse(self, mvn: MultivariateNormal) -> MultivariateNormal:
        return mvn
    

class StereographicProjector(torch.nn.Module):
    """
    Projects a plane R^n to S^n-1 using the stereographic projection.
    
    """

    def __init__(self):
        super().__init__()

    def forward(self, x: Tensor, y: Tensor | None = None) -> tuple[Tensor, Tensor] | Tensor:
        s_squared = x.square().sum(dim=-1, keepdim=True)
        proj_x0 = (s_squared - 1) / (s_squared + 1)
        proj_xi = 2 * x / (s_squared + 1)
        proj_x = torch.cat([proj_x0, proj_xi], dim=-1)
        if y is None:
            return proj_x
        return proj_x, y

    def inverse(self, mvn: MultivariateNormal) -> MultivariateNormal:
        return mvn

# UCI data

In [63]:
from mdgp.experiments.uci.data.datasets import Kin8mn, Power, Concrete, Energy, UCIDataset

### Fixed parameters as in the spherical harmonics paper

In [64]:
# Datasets 
kin8mn = Kin8mn()
power = Power()
concrete = Concrete()

# Variational parameters
dimension_to_num_harmonics_variational = {
    4: 336,
    6: 294,
    8: 210,
}

# Prior parameters 
"""
NOTE 
- It is difficult to say what number of spherical harmonics was used for the prior.
  I set it to be the same as the number of inducing variables.
- It is also difficult to say what lengthscale initialisation was used in the paper. 
  I set it to 1.0 for now, although a lower number, e.g. 0.001, would likely be better for higher dimensions.
"""
dimension_to_num_harmonics_prior = {
    # 4: 336,
    # 6: 294,
    # 8: 210, 
    4: 336, 
    6: 1000,
    8: 625,
}
nu = 1.5
optimize_nu = False
kappa = 1.0

# Other model parameters
batch_shape = torch.Size([])

# Training parameters 
batch_size = 256 
LR = 0.01
NUM_EPOCHS = 20

In [65]:
from mdgp.variational.spherical_harmonic_features.utils import num_spherical_harmonics_to_num_levels


def get_model_and_projector(dataset: UCIDataset, proj: str = 'sphere'):
    sphere_dimension = dataset.dimension

    # number of levels for variational inference 
    num_spherical_harmonics = dimension_to_num_harmonics_variational[dataset.dimension]
    max_ell, _ = num_spherical_harmonics_to_num_levels(num_spherical_harmonics, sphere_dimension)

    # number of levels for prior
    num_spherical_harmonics_prior = dimension_to_num_harmonics_prior[dataset.dimension]
    max_ell_prior, _ = num_spherical_harmonics_to_num_levels(num_spherical_harmonics_prior, sphere_dimension)

    model = gpytorchSGP(max_ell=max_ell, d=sphere_dimension, max_ell_prior=max_ell_prior, kappa=kappa, nu=nu, optimize_nu=optimize_nu, batch_shape=batch_shape)
    if proj == 'sphere':
        projector = SphereProjector()
    elif proj == 'no_y':
        projector = SphereProjectorNoY()
    else:
        projector = StereographicProjector()
    return model, projector


model, projector = get_model_and_projector(kin8mn)

The number of spherical harmonics requested does not lead to complete levels of spherical harmonics. We have thus increased the number to 660, which includes all spherical harmonics up to level 5 (exclusive)


In [66]:
from torch.utils.data import DataLoader

In [67]:
def train_step(x, y, model, projector, optimizer, mll) -> float:
    optimizer.zero_grad(set_to_none=True)
    x, y = projector(x, y)
    # print(x.shape, y.shape)
    output = model(x)
    loss = mll(output, y.squeeze(-1))
    # print(loss)
    loss.backward()
    optimizer.step()
    return loss.item()


def train(dataset, model, projector, batch_size, num_epochs=NUM_EPOCHS, lr=LR, optimize_projector: bool = False) -> list[float]: 
    # optimizer and criterion
    parameters = [{
        'params': model.parameters()
    }]
    if optimize_projector:
        parameters.append({
            'params': projector.parameters()
        })
    optimizer = torch.optim.Adam(parameters, lr=lr, maximize=True)
    mll = gpytorch.mlls.VariationalELBO(model.likelihood, model, dataset.train_y.size(0))

    # data
    train_loader = DataLoader(dataset.train_dataset, batch_size=batch_size, shuffle=True)

    # Training loop
    losses = []
    pbar = tqdm(range(num_epochs))
    for epoch in pbar:
        epoch_loss = 0
        for x_batch, y_batch in train_loader:
            loss = train_step(x=x_batch, y=y_batch, model=model, projector=projector, optimizer=optimizer, mll=mll)
            epoch_loss += loss
        losses.append(epoch_loss)
        pbar.set_postfix({'ELBO': losses[-1]})

    return losses 


def evaluate(dataset, model, projector):
    with torch.no_grad():
        test_x, test_y = projector(dataset.test_x), dataset.test_y
        test_y = test_y.squeeze(-1)
        out = model.likelihood(model(test_x))
        out = projector.inverse(out)
        nlpd = negative_log_predictive_density(out, test_y)
        mse = mean_squared_error(out, test_y)
        metrics = {
            'nlpd': nlpd.item(), 
            'mse': mse.item(),
        }
        print(f"NLPD: {metrics['nlpd']}, MSE: {metrics['mse']}")
    return metrics 


def reproduce_results(dataset, batch_size, num_runs: int = 5, num_epochs=NUM_EPOCHS, lr=LR):
    print(f"Reproducing results for {dataset.name}".center(80, '-') + '\n')

    metrics = []
    for run in range(num_runs):
        print(f"Run {run + 1}".center(80, '-'))

        torch.random.manual_seed(run)
        model, projector = get_model_and_projector(dataset)
        train(dataset, model, projector, num_epochs=num_epochs, lr=lr, batch_size=batch_size)
        # print(
        #     model.covar_module.base_kernel.lengthscale, 
        #     model.covar_module.base_kernel.nu,
        #     model.covar_module.outputscale,
        #     model.likelihood.noise,
        #     model.variational_strategy._variational_distribution.variational_mean,
        #     model.variational_strategy._variational_distribution.chol_variational_covar,
        # )
        run_metrics = evaluate(dataset, model, projector)
        metrics.append(run_metrics)
    df = pd.DataFrame(metrics)

    print("Metrics mean".center(80, '-'))
    print(df.mean())

    print("Metrics STD".center(80, '-'))
    print(df.std())

    return df 

In [50]:
model, projector = get_model_and_projector(kin8mn, 'sphere')
train(kin8mn, model, projector, batch_size, num_epochs=NUM_EPOCHS, lr=LR)
evaluate(kin8mn, model, projector)

The number of spherical harmonics requested does not lead to complete levels of spherical harmonics. We have thus increased the number to 660, which includes all spherical harmonics up to level 5 (exclusive)


  0%|          | 0/20 [00:00<?, ?it/s]

RuntimeError: grad can be implicitly created only for scalar outputs

In [60]:
projector.b

Parameter containing:
tensor(0.9604, requires_grad=True)

In [71]:
kin8mn = Kin8mn()
model, projector = get_model_and_projector(kin8mn, 'no_y')
print(projector)
projector.b = torch.nn.Parameter(torch.tensor(2.0))
train(kin8mn, model, projector, batch_size, num_epochs=NUM_EPOCHS, lr=LR, optimize_projector=True)
evaluate(kin8mn, model, projector)

The number of spherical harmonics requested does not lead to complete levels of spherical harmonics. We have thus increased the number to 660, which includes all spherical harmonics up to level 5 (exclusive)
SphereProjectorNoY()


  0%|          | 0/20 [00:00<?, ?it/s]

NLPD: 0.6415671221120341, MSE: 0.21345969458131758


{'nlpd': 0.6415671221120341, 'mse': 0.21345969458131758}

In [72]:
projector.b

Parameter containing:
tensor(0.9711, requires_grad=True)

In [56]:
model, projector = get_model_and_projector(kin8mn, 'stereo')
kin8mn = Kin8mn()
kin8mn.train_x *= 0.3
kin8mn.test_x *= 0.3
train(kin8mn, model, projector, batch_size, num_epochs=NUM_EPOCHS, lr=LR)
evaluate(kin8mn, model, projector)

The number of spherical harmonics requested does not lead to complete levels of spherical harmonics. We have thus increased the number to 660, which includes all spherical harmonics up to level 5 (exclusive)


  0%|          | 0/20 [00:00<?, ?it/s]

NLPD: 0.5998450905329317, MSE: 0.198891639633125


{'nlpd': 0.5998450905329317, 'mse': 0.198891639633125}

In [42]:
evaluate(kin8mn, model, projector)

NLPD: 0.8433933326422332, MSE: 0.31481791472424014


{'nlpd': 0.8433933326422332, 'mse': 0.31481791472424014}

In [15]:
from datasets_dsvi import Energy as EnergyDSVI
DTYPE = torch.get_default_dtype()

data = EnergyDSVI().get_data()
X, Y, Xs, Ys, Y_std = [data[_] for _ in ['X', 'Y', 'Xs', 'Ys', 'Y_std']]
x, y = torch.from_numpy(X).to(DTYPE), torch.from_numpy(Y).to(DTYPE).squeeze(-1)
test_x, test_y = torch.from_numpy(Xs).to(DTYPE), torch.from_numpy(Ys).to(DTYPE).squeeze(-1)
test_y_std = torch.from_numpy(Y_std).to(DTYPE).squeeze(-1)

Normalizing X with mean [[7.63950796e-01 6.71782200e+02 3.18748191e+02 1.76517004e+02
  5.25759768e+00 3.47756874e+00 2.32995658e-01 2.82344428e+00]] and std [[ 0.10901257 89.79381901 41.44677061 45.98855734  1.74867228  1.16290081
   0.13683384  1.56965762]]
Normalizing Y with mean [[22.33689725]] and std [[9.9405047]]


In [10]:
model, projector = get_model_and_projector(kin8mn)

optimizer = torch.optim.Adam(model.parameters(), lr=0.01, maximize=True)
elbo = gpytorch.mlls.VariationalELBO(model.likelihood, model, y.size(0))


for _ in (pbar := tqdm(range(1000), desc='Epochs')):
    optimizer.zero_grad()
    x_batch, y_batch = projector(x, y)
    output = model(x_batch)
    # output = projector.inverse(output)
    loss = elbo(output, y_batch)
    loss.backward()
    optimizer.step()
    pbar.set_postfix({'ELBO': loss.item()})

The number of spherical harmonics requested does not lead to complete levels of spherical harmonics. We have thus increased the number to 660, which includes all spherical harmonics up to level 5 (exclusive)


Epochs:   0%|          | 0/1000 [00:00<?, ?it/s]

In [16]:
with torch.no_grad():
    test_x_proj = projector(test_x)
    out = model(test_x_proj)
    out = projector.inverse(out)
    mean = out.mean
    print((test_y_std ** 2 * (mean - test_y) ** 2).mean().sqrt())
    print(mean_squared_error(out, test_y, test_y_std).item())
    print(negative_log_predictive_density(out, test_y).mean().item())

tensor(2.1122)
0.045148396017277985
12.486459883835613


In [12]:
reproduce_results(power, num_epochs=1000, batch_size=kin8mn.train_x.size(0))

-------------------------Reproducing results for power--------------------------

-------------------------------------Run 1--------------------------------------


  0%|          | 0/1000 [00:00<?, ?it/s]

NLPD: 0.014695136020697464, MSE: 0.05854422073080649
-------------------------------------Run 2--------------------------------------


  0%|          | 0/1000 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [88]:
reproduce_results(power, num_epochs=20)

-------------------------Reproducing results for power--------------------------

-------------------------------------Run 1--------------------------------------


100%|██████████| 20/20 [02:20<00:00,  7.04s/it, ELBO=34]  


NLPD: 0.018200022938892634, MSE: 0.0604323361222259
-------------------------------------Run 2--------------------------------------


100%|██████████| 20/20 [02:12<00:00,  6.63s/it, ELBO=34]  


NLPD: 0.020585911675719538, MSE: 0.06111090834631231
-------------------------------------Run 3--------------------------------------


100%|██████████| 20/20 [02:17<00:00,  6.85s/it, ELBO=34]  


NLPD: 0.029410862008512097, MSE: 0.06202288396624321
-------------------------------------Run 4--------------------------------------


100%|██████████| 20/20 [02:17<00:00,  6.89s/it, ELBO=34]  


NLPD: 0.011682565872741057, MSE: 0.05958469777710783
-------------------------------------Run 5--------------------------------------


100%|██████████| 20/20 [02:14<00:00,  6.71s/it, ELBO=34.2]


NLPD: 0.018373846236252843, MSE: 0.06075983045730417
----------------------------------Metrics mean----------------------------------
nlpd    0.019651
mse     0.060782
dtype: float64
----------------------------------Metrics STD-----------------------------------
nlpd    0.006391
mse     0.000895
dtype: float64


Unnamed: 0,nlpd,mse
0,0.0182,0.060432
1,0.020586,0.061111
2,0.029411,0.062023
3,0.011683,0.059585
4,0.018374,0.06076


In [89]:
reproduce_results(concrete, num_epochs=125)

------------------------Reproducing results for concrete------------------------

-------------------------------------Run 1--------------------------------------


100%|██████████| 125/125 [01:26<00:00,  1.44it/s, ELBO=2.44] 


NLPD: 0.3623051474366128, MSE: 0.1097215590579564
-------------------------------------Run 2--------------------------------------


100%|██████████| 125/125 [01:32<00:00,  1.35it/s, ELBO=2.46] 


NLPD: 0.3629351313556932, MSE: 0.10951533423101913
-------------------------------------Run 3--------------------------------------


100%|██████████| 125/125 [01:30<00:00,  1.39it/s, ELBO=2.41] 


NLPD: 0.3790285364454814, MSE: 0.11759354105563866
-------------------------------------Run 4--------------------------------------


100%|██████████| 125/125 [01:28<00:00,  1.41it/s, ELBO=2.46] 


NLPD: 0.37304157369488944, MSE: 0.1139838157709856
-------------------------------------Run 5--------------------------------------


100%|██████████| 125/125 [01:28<00:00,  1.41it/s, ELBO=2.41] 


NLPD: 0.37610820039944604, MSE: 0.11769848257163341
----------------------------------Metrics mean----------------------------------
nlpd    0.370684
mse     0.113703
dtype: float64
----------------------------------Metrics STD-----------------------------------
nlpd    0.007663
mse     0.004018
dtype: float64


Unnamed: 0,nlpd,mse
0,0.362305,0.109722
1,0.362935,0.109515
2,0.379029,0.117594
3,0.373042,0.113984
4,0.376108,0.117698
