In [1]:
import os
# Select the PyTorch backend through the GEOMSTATS_BACKEND environment variable.
# This is done before the geometry-related imports below — presumably the backend
# is read at import time (TODO confirm against the geomstats / geometric_kernels docs).
os.environ['GEOMSTATS_BACKEND'] = 'pytorch'


from torch import Tensor 


import pandas as pd
import torch 
import gpytorch 
import geometric_kernels.torch 
from math import comb 
from spherical_harmonics import SphericalHarmonics
from geometric_kernels.spaces import Hypersphere
from gpytorch.variational import CholeskyVariationalDistribution
from gpytorch.distributions import MultivariateNormal
from linear_operator.operators import DiagLinearOperator, LinearOperator
from mdgp.kernels import GeometricMaternKernel
from tqdm.autonotebook import tqdm 


# Make float64 the default dtype for floating-point tensors created from here on.
torch.set_default_dtype(torch.float64)

INFO: Using pytorch backend
  from tqdm.autonotebook import tqdm


In [2]:
from linear_operator.operators import DiagLinearOperator
from gpytorch.distributions import MultivariateNormal


class SphereProjector(torch.nn.Module):
    def __init__(self, b: float = 2.0):
        super().__init__()
        self.b = torch.nn.Parameter(torch.tensor(b))
        self.norm = None 

    def forward(self, x: Tensor, y: Tensor | None = None) -> tuple[Tensor, Tensor] | Tensor:
        b = self.b.expand(*x.shape[:-1], 1)
        x_cat_b = torch.cat([x, b], dim=-1)
        self.norm = x_cat_b.norm(dim=-1, keepdim=True)
        if y is None:
            return x_cat_b / self.norm
        else:
            return x_cat_b / self.norm, y / self.norm
    
    def inverse(self, mvn: MultivariateNormal) -> MultivariateNormal:
        norm = self.norm.squeeze(-1) # [..., N, 1] -> [..., N]
        mean = torch.einsum('...ij,...i->...ij', mvn.mean, norm) # [..., N, O] @ [..., N] -> [..., N, O]
        # mean = mvn.mean @ L
        norm = norm.repeat((1, ) * (norm.ndim - 1) + (mean.shape[-1], )) # [..., N] -> [..., N, O]
        # cov = torch.einsum('...ij, ...j -> ...ij', mvn.covariance_matrix, norm) # [..., N * O] @ [..., N * O] -> [..., N, O]
        # cov = torch.einsum('...i, ...ij -> ...ij', norm, cov)
        cov = torch.einsum('...i, ...ij, ...j -> ...ij', norm, mvn.covariance_matrix, norm)
        # cov = torch.einsum('...i,...ij,...j->...ij', mvn.lazy_covariance_matrix, norm, norm) # [..., N] @ [..., N, N] @ [..., N] -> [..., N, N]

        # cov = L @ mvn.lazy_covariance_matrix @ L
        # print(f"Norm shape {norm.shape}")
        # print(f"Mean before {mvn.mean.shape} after {mean.shape}")
        # print(f"Cov before {mvn.covariance_matrix.shape} after {cov.shape}")
        return gpytorch.distributions.MultitaskMultivariateNormal(mean, cov)
        cov = BlockDiagLinearOperator(cov, block_dim=-3)
        return MultitaskMultivariateNormal(mean, covar, interleaved=False)
        return MultivariateNormal(mean=mean, covariance_matrix=cov)

In [3]:
# For a diagonal L with entries l_i: (L K L)_ij = sum_{a,b} L_ia K_ab L_bj = L_ii K_ij L_jj = l_i K_ij l_j (no summation over i, j)

In [4]:
from mdgp.experiments.uci.data.datasets import Kin8mn, Power, Concrete, Energy, UCIDataset
from torch.utils.data import DataLoader


# Datasets: one instance of each UCI dataset class imported in this cell.
kin8mn = Kin8mn()
power = Power()
concrete = Concrete()
energy = Energy()

# Other model parameters
# batch_shape is not referenced by any other cell visible in this notebook.
batch_shape = torch.Size([])
# Read as a module-level global by get_model_and_projector and forwarded to SHFDeepGP.
optimize_nu = False

# Training parameters: default arguments of train() and reproduce_results();
# individual cells override them per call.
BATCH_SIZE = 256 
LR = 0.01
NUM_EPOCHS = 20

In [5]:
from mdgp.utils.spherical_harmonic_features import num_spherical_harmonics_to_degree
from mdgp.experiments.uci.model.geometric import SHFDeepGP


def get_model_and_projector(dataset: UCIDataset, num_layers: int):
    """Create a fresh ``(model, projector)`` pair for ``dataset``.

    The model is an ``SHFDeepGP`` with ``num_layers`` layers; whether ``nu`` is
    optimised is taken from the module-level ``optimize_nu`` flag. The projector
    is a ``SphereProjector`` with its default settings.
    """
    return (
        SHFDeepGP(dataset, num_layers=num_layers, optimize_nu=optimize_nu),
        SphereProjector(),
    )


# Instantiate a one-layer model and its projector for kin8mn (binds the notebook-level `model` / `projector`).
model, projector = get_model_and_projector(kin8mn, 1)

The number of spherical harmonics requested does not lead to complete levels of spherical harmonics. We have thus increased the number to 660, which includes all spherical harmonics up to level 5 (exclusive)


In [11]:
def manual_forward(model, x):
    """Call only the first layer's variational strategy on ``x`` and return its result."""
    first_layer = model.layers[0]
    strategy = first_layer.variational_strategy
    return strategy(x)


def train_step(x, y, model, projector, optimizer, mll) -> float:
    """Run a single optimisation step on one mini-batch.

    The batch is passed through ``projector`` (inputs and targets together), the
    model is evaluated on the projected inputs, and ``mll`` of the output against
    the projected targets is back-propagated before ``optimizer`` takes a step.

    Returns:
        The objective value for this batch as a Python float.
    """
    optimizer.zero_grad(set_to_none=True)
    x_proj, y_proj = projector(x, y)
    prediction = model(x_proj)
    objective = mll(prediction, y_proj)
    objective.backward()
    optimizer.step()
    return objective.item()


def train(dataset, model, projector, num_epochs=NUM_EPOCHS, lr=LR, batch_size=BATCH_SIZE) -> list[float]:
    """Fit ``model`` on ``dataset.train_dataset`` with Adam and return the per-epoch objective.

    Args:
        dataset: Provides ``train_dataset`` (iterated in shuffled mini-batches) and
            ``train_y`` (its first dimension is passed to ``VariationalELBO`` as the
            number of data points).
        model: Model whose parameters are optimised; ``model.likelihood`` is used by the ELBO.
        projector: Forwarded to ``train_step`` for every batch.
        num_epochs: Number of passes over the training loader.
        lr: Adam learning rate.
        batch_size: Mini-batch size of the loader.

    Returns:
        One entry per epoch: the *sum* (not the mean) of the batch objectives
        returned by ``train_step`` during that epoch.

    NOTE(review): only ``model.parameters()`` are handed to the optimiser, so any
    parameters of ``projector`` are not updated here — confirm this is intended.
    """
    # maximize=True: the objective returned by the MLL is ascended rather than descended.
    adam = torch.optim.Adam(model.parameters(), lr=lr, maximize=True)
    elbo = gpytorch.mlls.VariationalELBO(model.likelihood, model, dataset.train_y.size(0))
    objective = gpytorch.mlls.DeepApproximateMLL(elbo)

    batches = DataLoader(dataset.train_dataset, batch_size=batch_size, shuffle=True)

    history = []
    progress = tqdm(range(num_epochs))
    for _ in progress:
        # Accumulate the objective over all batches of this epoch.
        epoch_total = sum(
            train_step(x=xb, y=yb, model=model, projector=projector, optimizer=adam, mll=objective)
            for xb, yb in batches
        )
        history.append(epoch_total)
        progress.set_postfix({'ELBO': history[-1]})

    return history


from mdgp.experiments.uci.fit.metrics import mean_squared_error, test_log_likelihood, negative_log_predictive_density
from mdgp.utils import test_psd


def evaluate(dataset: UCIDataset, model: gpytorch.models.deep_gps.DeepGP, projector, device: torch.device) -> dict[str, float]:
    """Compute and print test-set metrics for ``model`` on ``dataset``.

    The test inputs are moved to ``device`` and then projected, so the norms that
    ``projector`` stores live on the same device as the model output that
    ``projector.inverse`` rescales afterwards. The model is switched to eval mode
    for the prediction and its previous train/eval mode is restored on exit, so
    calling this between training runs has no lasting effect on the model.

    Args:
        dataset: Provides ``test_x``, ``test_y`` and ``test_y_std``.
        model: Model with a ``likelihood`` attribute; expected to already be on ``device``.
        projector: Callable projector with an ``inverse`` method; expected to
            already be on ``device``.
        device: Device on which the evaluation tensors are placed.

    Returns:
        Dict with keys ``'tll'``, ``'mse'``, ``'mse_no_std'`` and ``'nlpd'``, each the
        mean of the corresponding metric tensor as a Python float. ``'tll'`` and
        ``'mse'`` are computed with ``test_y_std`` passed to the metric, the other
        two without it.
    """
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            # Move first, project second: keeps projector.norm on `device`.
            test_x = projector(dataset.test_x.to(device))
            test_y = dataset.test_y.to(device)
            test_y_std = dataset.test_y_std.to(device)

            out = model.likelihood(model(test_x))
            # Undo the target scaling applied by the projector during training.
            out = projector.inverse(out)

            tll = test_log_likelihood(out, test_y, test_y_std)
            mse = mean_squared_error(out, test_y, test_y_std)
            mse_no_std = mean_squared_error(out, test_y)
            nlpd = negative_log_predictive_density(out, test_y)
            metrics = {
                'tll': tll.mean().item(),
                'mse': mse.mean().item(),
                'mse_no_std': mse_no_std.mean().item(),
                'nlpd': nlpd.mean().item(),
            }
            print(f"TLL: {metrics['tll']}, MSE: {metrics['mse']}, MSE_NO_STD: {metrics['mse_no_std']}, NLPD: {metrics['nlpd']}")
    finally:
        # Restore whichever mode the caller had the model in.
        model.train(was_training)
    return metrics


def reproduce_results(dataset, num_layers, num_runs: int = 5, num_epochs=NUM_EPOCHS, lr=LR, batch_size=BATCH_SIZE):
    """Train and evaluate ``num_runs`` freshly initialised models and summarise their metrics.

    Run ``k`` (0-based) seeds torch with ``k``, builds a new model/projector pair,
    trains it and evaluates it on CPU. Progress banners, the per-metric mean and the
    per-metric standard deviation are printed.

    Returns:
        A DataFrame with one row per run and one column per metric.
    """
    def banner(text: str) -> str:
        # 80-character wide heading padded with dashes.
        return text.center(80, '-')

    print(banner(f"Reproducing results for {dataset.name}") + '\n')

    per_run = []
    for seed in range(num_runs):
        print(banner(f"Run {seed + 1}"))

        torch.random.manual_seed(seed)
        model, projector = get_model_and_projector(dataset, num_layers=num_layers)
        train(dataset, model, projector, num_epochs=num_epochs, lr=lr, batch_size=batch_size)
        per_run.append(evaluate(dataset, model, projector, torch.device('cpu')))
    summary = pd.DataFrame(per_run)

    for heading, statistic in (("Metrics mean", summary.mean), ("Metrics STD", summary.std)):
        print(banner(heading))
        print(statistic())

    return summary

In [7]:
# Build a fresh one-layer kin8mn model (projector discarded) and print the first
# layer's variational-mean shape, its kernel hyperparameters (outputscale,
# lengthscale, nu) and the likelihood noise.
model, _ = get_model_and_projector(kin8mn, 1)
print(model.layers[0].variational_strategy.variational_distribution.mean.shape)
print(model.layers[0].covar_module.outputscale, model.layers[0].covar_module.base_kernel.lengthscale, model.layers[0].covar_module.base_kernel.nu)
print(model.likelihood.noise)

The number of spherical harmonics requested does not lead to complete levels of spherical harmonics. We have thus increased the number to 660, which includes all spherical harmonics up to level 5 (exclusive)
torch.Size([1, 210])
tensor([1.], grad_fn=<SoftplusBackward0>) tensor([[[1.]]], grad_fn=<SoftplusBackward0>) tensor([[[1.5000]]])
tensor([1.0000], grad_fn=<AddBackward0>)


In [8]:
# One-layer model on kin8mn: 100 epochs with mini-batches of 1024, then test metrics on CPU.
model, projector = get_model_and_projector(kin8mn, 1)
train(kin8mn, model, projector, num_epochs=100, lr=0.01, batch_size=1024)
evaluate(kin8mn, model, projector, torch.device('cpu'))

The number of spherical harmonics requested does not lead to complete levels of spherical harmonics. We have thus increased the number to 660, which includes all spherical harmonics up to level 5 (exclusive)


  0%|          | 0/100 [00:00<?, ?it/s]

torch.Size([10, 819, 1])
torch.Size([10, 819, 1])
TLL: 0.6873664930014585, MSE: 0.015011430036360858, MSE_NO_STD: 0.2154235304189149, NLPD: 0.6374828335317264


{'tll': 0.6873664930014585,
 'mse': 0.015011430036360858,
 'mse_no_std': 0.2154235304189149,
 'nlpd': 0.6374828335317264}

In [16]:
# One-layer model on energy: 10000 epochs with mini-batches of 1024.
# train() is the last expression, so the full per-epoch objective list is echoed as the cell output.
model, projector = get_model_and_projector(energy, 1)
train(energy, model, projector, num_epochs=10000, lr=0.01, batch_size=1024)

The number of spherical harmonics requested does not lead to complete levels of spherical harmonics. We have thus increased the number to 660, which includes all spherical harmonics up to level 5 (exclusive)


  0%|          | 0/10000 [00:00<?, ?it/s]

[-3.0087345864802697,
 -2.9908352986745603,
 -2.9680879210509836,
 -2.94857272174441,
 -2.9293646082138767,
 -2.90887361714614,
 -2.8882473676181846,
 -2.8683653308153714,
 -2.8490898986788267,
 -2.829699176067688,
 -2.8098956947965927,
 -2.7899749075073497,
 -2.7703083826998203,
 -2.750963419087472,
 -2.7317009356563506,
 -2.7122785786201566,
 -2.6926933769862536,
 -2.6731153144451834,
 -2.653677934732368,
 -2.6343562676443386,
 -2.615016016878882,
 -2.595564539343209,
 -2.5760328996238333,
 -2.5565159160264006,
 -2.5370625845828574,
 -2.517632874559478,
 -2.4981540899129273,
 -2.478604772524354,
 -2.4590313762183795,
 -2.4394925649598744,
 -2.4200025900938753,
 -2.400533486257023,
 -2.3810669349262517,
 -2.361630737647777,
 -2.342279225825902,
 -2.3230489932257252,
 -2.303946556889798,
 -2.2849785506939475,
 -2.2661799408023136,
 -2.2476048742460084,
 -2.2292989843179187,
 -2.2112915746043624,
 -2.193611440715305,
 -2.176298882613698,
 -2.1593994798073766,
 -2.142952307379408,
 -2.12

In [17]:
# Test metrics for whichever `model` / `projector` are currently bound in the kernel
# (the execution counts suggest the energy model trained in the preceding cell).
evaluate(energy, model, projector, torch.device('cpu'))

TLL: -2.422432269292398, MSE: 7.114778272517773, MSE_NO_STD: 0.09392552583764097, NLPD: 0.5164609905225624


{'tll': -2.422432269292398,
 'mse': 7.114778272517773,
 'mse_no_std': 0.09392552583764097,
 'nlpd': 0.5164609905225624}

In [13]:
# Two-layer model on kin8mn: 1000 epochs with mini-batches of 1024.
# train() is the last expression, so the per-epoch objective list is echoed as the cell output.
model, projector = get_model_and_projector(kin8mn, 2)
train(kin8mn, model, projector, num_epochs=1000, lr=0.01, batch_size=1024)

The number of spherical harmonics requested does not lead to complete levels of spherical harmonics. We have thus increased the number to 660, which includes all spherical harmonics up to level 5 (exclusive)
The number of spherical harmonics requested does not lead to complete levels of spherical harmonics. We have thus increased the number to 660, which includes all spherical harmonics up to level 5 (exclusive)


  0%|          | 0/1000 [00:00<?, ?it/s]

[-7.615882356559742,
 -7.4274365552521235,
 -7.233812396024564,
 -7.034629896272984,
 -6.842024361414799,
 -6.627397020930417,
 -6.428718171548478,
 -6.213550687235886,
 -6.000786016941789,
 -5.797035328893804,
 -5.577490223164378,
 -5.365948677582851,
 -5.147332714572608,
 -4.922737475603662,
 -4.6947161743092165,
 -4.493019937181473,
 -4.282501983160966,
 -4.072683321709395,
 -3.878073664527402,
 -3.6948650655865,
 -3.4732192565266153,
 -3.2998189831509817,
 -3.113307927222764,
 -2.93247396325518,
 -2.756947702422385,
 -2.61103551999795,
 -2.452297512861447,
 -2.3185483815061962,
 -2.1873146880427754,
 -2.0657699051155594,
 -1.9400449582576575,
 -1.8624885022804567,
 -1.8049740284741047,
 -1.63351238629271,
 -1.5557723631621483,
 -1.422952223799315,
 -1.3360122337686464,
 -1.1458271581498412,
 -1.0392073915429387,
 -0.9052864650688649,
 -0.7106356256858577,
 -0.5660039715865796,
 -0.34047883935775447,
 -0.12610966844351537,
 0.01867046802069811,
 0.18312168764965667,
 0.3397354136556

In [14]:
# Evaluate the currently bound model on kin8mn and print the metrics dict.
# evaluate() already runs its body under torch.no_grad(), so this outer context is redundant but harmless.
with torch.no_grad():
    print(evaluate(kin8mn, model, projector, torch.device('cpu')))

torch.Size([10, 819])
TLL: -0.8735516225228949, MSE: 0.06959825155962196, MSE_NO_STD: 0.9987789987790006, NLPD: 1.0209302512883451
{'tll': -0.8735516225228949, 'mse': 0.06959825155962196, 'mse_no_std': 0.9987789987790006, 'nlpd': 1.0209302512883451}


In [45]:
# Manual replay of the prediction path used in evaluate(): project, predict, apply the
# likelihood, undo the projector scaling, then compute the NLPD on the first element of `out`.
# NOTE(review): out[0] presumably selects the first sample along the leading batch
# dimension of the predictive distribution — confirm. `nlpd` is assigned but not displayed.
with torch.no_grad():
    out = model.likelihood(model(projector(kin8mn.test_x)))
    out = projector.inverse(out)
    nlpd = negative_log_predictive_density(out[0], kin8mn.test_y)

In [48]:
# Run mdgp.utils.test_psd on the covariance of out[0]; presumably a positive
# semi-definiteness check (TODO confirm). Depends on `out` from the previous cell.
test_psd(out[0].lazy_covariance_matrix)

In [52]:
# Full-batch reproduction on kin8mn: the batch size equals the number of training rows,
# so every epoch is a single optimiser step. 5 runs (the default) of 20000 epochs each.
reproduce_results(kin8mn, 1, num_epochs=20000, batch_size=kin8mn.train_x.size(0))

-------------------------Reproducing results for kin8nm-------------------------

-------------------------------------Run 1--------------------------------------
The number of spherical harmonics requested does not lead to complete levels of spherical harmonics. We have thus increased the number to 660, which includes all spherical harmonics up to level 5 (exclusive)


  0%|          | 0/20000 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
# Full-batch reproduction on power: the batch size must come from `power` itself.
# (It was previously taken from kin8mn.train_x, which does not give one batch per
# epoch unless the two training sets happen to have the same size.)
reproduce_results(power, 1, num_epochs=1000, batch_size=power.train_x.size(0))