In [5]:
import geometric_kernels.torch 
import torch 
import os
import pandas as pd 
import torch 
from torch import Tensor
from torch.utils.data import TensorDataset, Dataset
from linear_operator.operators import DiagLinearOperator
from gpytorch.distributions import MultivariateNormal
from geometric_kernels.kernels import MaternGeometricKernel
from geometric_kernels.spaces import Hypersphere 


class UCIDataset:

    UCI_BASE_URL = 'https://archive.ics.uci.edu/ml/machine-learning-databases/'

    def __init__(self, name: str, url: str, path: str = '../../data/uci/', seed: int | None = None): 
        self.name = name 
        self.url = url
        self.path = path 
        self.csv_path = os.path.join(self.path, self.name + '.csv')

        # Set generator if seed is provided.
        self.generator = torch.Generator()
        if seed is not None: 
            self.generator.manual_seed(seed)

        # Load, shuffle, split, and normalize data. TODO except for load, these don't need to be object methods. 
        # We keep the standard deviation of the test set for log-likelihood evaluation.
        x, y = self.load_data()
        x, y = self.shuffle(x, y, generator=self.generator)     
        self.train_x, self.train_y, self.test_x, self.test_y = self.split(x, y)
        self.test_y_std = self.test_y.std(dim=0, keepdim=True)
        self.train_x, self.train_y, self.test_x, self.test_y = map(
            self.normalize, (self.train_x, self.train_y, self.test_x, self.test_y))

    @property
    def dimension(self) -> int:
        return self.train_x.shape[-1]

    @property 
    def train_dataset(self) -> Dataset:
        return TensorDataset(self.train_x, self.train_y)
    
    @property
    def test_dataset(self) -> Dataset:
        return TensorDataset(self.test_x, self.test_y)


    def read_data(self) -> tuple[Tensor, Tensor]:
        xy = torch.from_numpy(pd.read_csv(self.csv_path).values)
        return xy[:, :-1], xy[:, -1]

    def download_data(self) -> None:
        NotImplementedError

    def load_data(self, overwrite: bool = False) -> tuple[Tensor, Tensor]:
        if overwrite or not os.path.isfile(self.csv_path):
            self.download_data()
        return self.read_data()

    def normalize(self, x: Tensor) -> Tensor:
        return (x - x.mean(dim=0)) / x.std(dim=0, keepdim=True)
    
    def shuffle(self, x: Tensor, y: Tensor, generator: torch.Generator) -> tuple[Tensor, Tensor]:
        perm_idx = torch.randperm(x.size(0), generator=generator)
        return x[perm_idx], y[perm_idx]
    
    def split(self, x: Tensor, y: Tensor, test_size: float = 0.1) -> tuple[Tensor, Tensor, Tensor, Tensor]: 
        """
        Split the dataset into train and test sets.
        """
        split_idx = int(test_size * x.size(0))
        return x[split_idx:], y[split_idx:], x[:split_idx], y[:split_idx]


class Kin8mn(UCIDataset):
    """The kin8nm dataset, fetched as a CSV file and cached under ``path``."""

    DEFAULT_URL = 'https://raw.githubusercontent.com/liusiyan/UQnet/master/datasets/UCI_datasets/kin8nm/dataset_2175_kin8nm.csv'

    def __init__(self, path: str = '../../data/uci/', normalize: bool = True, seed: int | None = None, url: str = DEFAULT_URL):
        """
        Args:
            path: Directory that holds the cached ``kin8nm.csv``.
            normalize: Accepted for interface compatibility; it is not used by
                this class, so it currently has no effect.
            seed: Seed forwarded to the base class.
            url: Location the CSV is downloaded from when it is not cached yet.
        """
        # The base constructor already loads (and, if necessary, downloads) the
        # data, so the caller's URL has to be passed in here. Assigning it after
        # the call would be too late and the download would always use DEFAULT_URL.
        super().__init__(name='kin8nm', url=url, path=path, seed=seed)

    def download_data(self) -> None:
        """Download the CSV from ``self.url`` and write it to ``self.csv_path``."""
        df = pd.read_csv(self.url)
        # Create the cache directory on first use.
        os.makedirs(self.path, exist_ok=True)
        df.to_csv(self.csv_path, index=False)


class SphereProjector(torch.nn.Module):
    """Map inputs to unit vectors by appending a learnable coordinate and normalizing.

    ``forward`` concatenates the scalar parameter ``b`` to every input vector and
    divides by the Euclidean norm of the result. The norms are cached on
    ``self.norm`` so that ``inverse`` can rescale a distribution by them.
    """

    def __init__(self):
        super().__init__()
        # b = 1 + exp(eps) with eps drawn from a standard normal; exp keeps the
        # random part non-negative.
        b = 1. + torch.exp(torch.randn(torch.Size([])))
        self.register_parameter('b', torch.nn.Parameter(b))
        # Norms from the most recent forward() call; None until forward() has run.
        self.norm = None

    def forward(self, x: Tensor, y: Tensor | None = None) -> tuple[Tensor, Tensor] | Tensor:
        """Project ``x`` (and optionally rescale ``y``).

        Args:
            x: Inputs whose last axis holds the features.
            y: Optional targets; when given they are divided by the same norms.

        Returns:
            The normalized ``[x, b]`` vectors, or a tuple of those and the
            rescaled ``y`` when ``y`` is provided.
        """
        # Broadcast the scalar b to one extra trailing coordinate per input.
        b = self.b.expand(*x.shape[:-1], 1)
        x_cat_b = torch.cat([x, b], dim=-1)
        # Cached as a side effect so that inverse() can undo the scaling.
        self.norm = x_cat_b.norm(dim=-1, keepdim=True)
        if y is None:
            return x_cat_b / self.norm
        else:
            return x_cat_b / self.norm, y / self.norm.squeeze(-1)

    def inverse(self, mvn: MultivariateNormal) -> MultivariateNormal:
        """Rescale ``mvn`` by the norms cached during the last ``forward`` call.

        With ``L`` the diagonal operator built from those norms, the result has
        mean ``mvn.mean @ L`` and covariance ``L @ cov @ L``.

        Raises:
            RuntimeError: If ``forward`` has not been called yet.
        """
        if self.norm is None:
            # Without this guard the failure would be an opaque AttributeError on None.
            raise RuntimeError(
                'SphereProjector.inverse() was called before forward(); no norms are cached.')
        L = DiagLinearOperator(self.norm.squeeze(-1))
        mean = mvn.mean @ L
        cov = L @ mvn.lazy_covariance_matrix @ L
        return MultivariateNormal(mean=mean, covariance_matrix=cov)


In [6]:
# Use double precision for tensors created from here on and fix the global seed
# (SphereProjector draws its initial `b` with torch.randn).
torch.set_default_dtype(torch.float64)
torch.random.manual_seed(1)


# Constructing the dataset loads, shuffles, splits and normalizes the data.
dataset = Kin8mn()
# SphereProjector appends one coordinate and normalizes, so projected inputs are
# unit vectors with `dataset.dimension + 1` entries.
# NOTE(review): presumably Hypersphere(dim=d) is the d-sphere embedded in R^(d+1) — confirm.
space = Hypersphere(dim=dataset.dimension)


kernel = MaternGeometricKernel(space=space, num=6, normalize=True)
# Start from the parameters returned by the kernel and override lengthscale and nu.
params = kernel.init_params()
params['lengthscale'] = torch.tensor(0.1)
params['nu'] = torch.tensor(2.5)
projector = SphereProjector()


# Inspect the spectrum of the kernel matrix on the projected test inputs;
# no gradients are needed for this check.
with torch.no_grad():
    x = projector(dataset.test_x)
    K = kernel.K(params, x)
    # eigvalsh returns eigenvalues in ascending order, so this prints the 50 smallest.
    print(torch.linalg.eigvalsh(K)[:50])

tensor([0.0005, 0.0006, 0.0007, 0.0008, 0.0008, 0.0008, 0.0009, 0.0009, 0.0009,
        0.0010, 0.0010, 0.0010, 0.0010, 0.0011, 0.0011, 0.0011, 0.0012, 0.0012,
        0.0012, 0.0012, 0.0012, 0.0013, 0.0013, 0.0013, 0.0014, 0.0014, 0.0014,
        0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0016, 0.0016, 0.0016, 0.0016,
        0.0017, 0.0017, 0.0018, 0.0018, 0.0018, 0.0018, 0.0019, 0.0019, 0.0019,
        0.0020, 0.0020, 0.0020, 0.0020, 0.0021])
