In [38]:
import os 
os.environ['GEOMSTATS_BACKEND'] = 'pytorch'

import geomstats._backend as gs 
import geometric_kernels.torch 

import torch 
import gpytorch 
import plotly.io as pio
from plotly import graph_objects as go
from mdgp.kernels import GeometricMaternKernel
from geometric_kernels.spaces import Euclidean
from mdgp.samplers import RFFSampler
from geometric_kernels.kernels.matern_kernel import default_feature_map


# Plotly: every figure created after this line uses the dark template.
pio.templates.default = "plotly_dark"
# PyTorch: newly created floating-point tensors default to double precision
# (torch.double is the same dtype object as torch.float64).
torch.set_default_dtype(torch.double)

In [39]:
# Two-dimensional Euclidean space: passed to the evaluation-grid builder, the Matérn
# kernel and the RFF feature map in the cells that follow.
space = Euclidean(dim=2)

def euclidean_grid(space, num_points=100, bound=25.0):
    """Build a regular square grid of points covering ``[-bound, bound]^dim``.

    Args:
        space: Space the grid is built for. Only its ``dim`` attribute is read
            (falls back to 2 when the attribute is absent), so the grid
            dimension follows the space instead of being hard-coded.
        num_points: Number of grid points along each axis.
        bound: Half-width of the grid; every axis spans ``[-bound, bound]``.
            Defaults to 25.0, the previously hard-coded extent.

    Returns:
        Tensor of shape ``(num_points ** dim, dim)`` holding the grid points in
        row-major ("ij") order, i.e. the last coordinate varies fastest.
    """
    dim = getattr(space, "dim", 2)
    axis = torch.linspace(-bound, bound, num_points)
    # indexing="ij" is the historical default; passing it explicitly keeps the
    # output identical and silences torch's meshgrid deprecation warning.
    mesh = torch.meshgrid(*([axis] * dim), indexing="ij")
    return torch.stack(mesh, dim=-1).reshape(-1, dim)

# Flattened evaluation grid: 100 points per axis -> tensor of shape (100 * 100, 2).
# NOTE: the plotting cells reshape it with view(100, 100, -1); keep num_points in sync.
points = euclidean_grid(space=space, num_points=100)

In [101]:
# Number of random phases. Defined once, before first use, so the kernel and the
# sampler's feature map cannot silently drift apart (it was previously written
# both as a literal in the kernel call and as a separate variable).
num_random_phases = 3000

# Matérn kernel (nu = 1.5, fixed) on `space`, wrapped in a ScaleKernel for an outputscale.
base_kernel = GeometricMaternKernel(
    space=space,
    num_random_phases=num_random_phases,
    nu=1.5,
    lengthscale=1.0,
    seed=0,
    trainable_nu=False,
)
covar_module = gpytorch.kernels.ScaleKernel(base_kernel)
mean_module = gpytorch.means.ZeroMean()

# Feature map and sampler built with the same phase count as the kernel.
feature_map = default_feature_map(space=space, num=num_random_phases)
rff_sampler = RFFSampler(covar_module=covar_module, mean_module=mean_module, feature_map=feature_map)


To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).



In [109]:
with torch.no_grad():
    # Seed the global torch RNG so a fresh "Restart & Run All" draws the same sample.
    torch.random.manual_seed(0)
    # One draw from the RFF sampler, evaluated at every grid point.
    sample = rff_sampler(points, sample_shape=torch.Size([]))
    c = sample  # surface colour values
    # Side of the square grid, derived from the data instead of hard-coding 100.
    grid_side = round(points.shape[0] ** 0.5)
    assert grid_side * grid_side == points.shape[0], "points must form a square grid"
    x, y = points.view(grid_side, grid_side, -1).unbind(-1)
    # Flat z = 0 plane: the sample is shown through surfacecolor, not through height.
    z = torch.zeros_like(x)

fig = go.Figure(data=[go.Surface(x=x, y=y, z=z, surfacecolor=c.view_as(x))])
fig.update_layout(
    title="Single RFF sample (Matérn kernel, nu = 1.5)",
    height=700,
    width=700,
)
fig.show()

In [36]:
with torch.no_grad():
    torch.random.manual_seed(0)
    # 1000 draws at every grid point; variance over the leading draw axis gives the
    # empirical pointwise variance of the sampler (presumably close to the
    # ScaleKernel outputscale everywhere if the RFF approximation is good — TODO confirm).
    sample = rff_sampler(points, sample_shape=torch.Size([1000]))
    c = sample.var(dim=0)
    # Side of the square grid, derived from the data instead of hard-coding 100.
    grid_side = round(points.shape[0] ** 0.5)
    assert grid_side * grid_side == points.shape[0], "points must form a square grid"
    x, y = points.view(grid_side, grid_side, -1).unbind(-1)
    z = torch.zeros_like(x)

# Draw on the horizontal z = 0 plane with x/y as the grid axes, the same orientation
# as the single-sample plot, so the two figures can be compared directly.
fig = go.Figure(data=[go.Surface(x=x, y=y, z=z, surfacecolor=c.view_as(x))])
fig.update_layout(
    title="Empirical pointwise variance over 1000 RFF samples",
    height=700,
    width=700,
)
fig.show()