In [1]:
import torch 
import gpytorch 
import geometric_kernels.torch
import plotly.io as pio
from plotly import graph_objects as go
from plotly.subplots import make_subplots
from geometric_kernels.spaces import Hypersphere
from mdgp.models.deep_gps import GeometricDeepGPLayer
from mdgp.utils import sphere_meshgrid, sphere_uniform_grid, spherical_harmonic, sphere_random_uniform


torch.set_default_dtype(torch.float64)
pio.templates.default = "plotly_dark"

  from .autonotebook import tqdm as notebook_tqdm
INFO: Using numpy backend


In [2]:
DIM = 2
SPACE = Hypersphere(DIM)
NUM_EIGENFUNCTIONS = 20
OUTPUT_DIMS = 2
NU = 0.5
NUM_INDUCING = 60 
INDUCING_POINTS = sphere_uniform_grid(NUM_INDUCING)
LEARN_INDUCING_LOCATIONS = False
OPTIMIZE_NU = False
FEATURE_MAP = 'deterministic'

layer = GeometricDeepGPLayer(
    space=SPACE,
    num_eigenfunctions=NUM_EIGENFUNCTIONS,
    output_dims=OUTPUT_DIMS,
    inducing_points=INDUCING_POINTS,
    nu=NU,
    feature_map=FEATURE_MAP,
    learn_inducing_locations=LEARN_INDUCING_LOCATIONS,
    optimize_nu=OPTIMIZE_NU,
    whitened_variational_strategy=True, 
)

In [174]:
import math 
from mdgp.utils import cart_to_sph


def spherical_harmonic_torch(x, n=3, m=2): 
    # direct implementation of the spherical harmonic function not using the spherical_harmonic function

    phi, theta = cart_to_sph(x)

   
    Y = torch.zeros_like(theta)
    if m == 0:
        Y = math.sqrt(1 / (4 * torch.pi)) * torch.ones_like(theta)
    elif m > 0:
        Y = math.sqrt(1 / (2 * torch.pi)) * torch.cos(m * phi)
    else:
        Y = math.sqrt(1 / (2 * torch.pi)) * torch.sin(-m * phi)
    P = torch.zeros_like(theta)
    if m == 0:
        P = torch.ones_like(theta)
    elif m > 0:
        P = torch.cos(m * theta)
    else:
        P = torch.sin(-m * theta)
    return math.sqrt((2 * n + 1) / (4 * torch.pi) * math.exp(math.lgamma(n - m + 1) - math.lgamma(n + m + 1))) * P * Y

In [188]:
import torch
import math

def associated_legendre_polynomial(l, m, x):
    """Compute the associated Legendre polynomial of degree l and order m."""
    # Compute the associated Legendre polynomial using recursion
    P = torch.zeros_like(x)
    if m < 0 or m > l or l < 0:
        return P
    pmm = 1.0
    if m > 0:
        somx2 = torch.sqrt((1.0 - x) * (1.0 + x))
        fact = 1.0
        for i in range(1, m+1):
            pmm *= (-fact) * somx2
            fact += 2.0
    if l == m:
        return pmm
    pmmp1 = x * (2 * m + 1) * pmm
    if l == m + 1:
        return pmmp1
    for ll in range(m + 2, l + 1):
        pll = ((2 * ll - 1) * x * pmmp1 - (ll + m - 1) * pmm) / (ll - m)
        pmm = pmmp1
        pmmp1 = pll
    return pll

def real_spherical_harmonic(x, l, m):
    """Compute the real spherical harmonic of degree l and order m."""
    theta = torch.acos(x[..., 2])
    phi = torch.atan2(x[..., 1], x[..., 0])
    # Calculate the associated Legendre polynomial
    legendre_poly = associated_legendre_polynomial(l, abs(m), torch.cos(theta))

    # Compute the normalization factor
    normalization =  math.sqrt(((2 * l + 1) / (4 * math.pi)) * 
                               math.exp(math.lgamma(l - abs(m) + 1)) / 
                               math.exp(math.lgamma(l + abs(m) + 1)))

    # Calculate the real spherical harmonic
    if m > 0:
        return math.sqrt(2) * normalization * torch.cos(m * phi) * legendre_poly
    elif m < 0:
        return math.sqrt(2) * normalization * torch.sin(-m * phi) * legendre_poly
    else:
        return normalization * legendre_poly

# Example usage
l, m = 2, 0
theta = torch.tensor([0.5 * math.pi])  # Polar angle
phi = torch.tensor([0.5 * math.pi])    # Azimuthal angle

Y_lm = real_spherical_harmonic(sphere_uniform_grid(10), l, m)
print(Y_lm)


tensor([ 0.4510,  0.1482, -0.0788, -0.2302, -0.3059, -0.3059, -0.2302, -0.0788,
         0.1482,  0.4510])


In [234]:
torch.set_default_dtype(torch.float64)

In [261]:
test_inputs_scalar = sphere_meshgrid(100, 100)
test_inputs_field = sphere_uniform_grid(1000)

from torch.func import jacfwd, vmap, grad
from spherical_harmonics import SphericalHarmonics

sph_harm = SphericalHarmonics(3, 10)

def target_func(x): 
    # return torch.sin(x[..., 0]) + torch.cos(x[..., 1])
    # return real_spherical_harmonic(x, 3, 2)
    return sph_harm(x)[..., 15].squeeze()


div_operator = lambda f: vmap(grad(f))


def project_onto_tangent_space(points, vectors):
    """
    Project vectors onto the tangent space of a unit sphere at given points.

    :param points: Tensor of shape (N, 3) representing Cartesian coordinates on the unit sphere.
    :param vectors: Tensor of shape (N, 3) representing vectors in the ambient space.
    :return: Tensor of shape (N, 3) representing projected vectors in the tangent space.
    """
    # Normalize the points to lie on the unit sphere
    normalized_points = torch.nn.functional.normalize(points, p=2, dim=1)

    # Compute the component of each vector along the radial direction
    radial_component = torch.sum(vectors * normalized_points, dim=1, keepdim=True)

    # Project vectors onto the tangent space
    tangent_vectors = vectors - radial_component * normalized_points

    return tangent_vectors

In [262]:
jacfwd(target_func)(sphere_uniform_grid(10))

tensor([[[ 1.2974,  0.5936,  5.0648],
         [ 0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000]],

        [[ 0.0000,  0.0000,  0.0000],
         [-4.5320,  0.7693,  7.6183],
         [ 0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000]],

        [[ 0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000],
         [ 9.4008,  3.6100,  9.6278],
         [ 0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000],
        

In [258]:
test_targets_scalar = target_func(test_inputs_scalar.view(-1, 3)).view(100, 100, -1)
test_targets_field = target_func(test_inputs_field)
div_field_ambient = div_operator(target_func)(test_inputs_field)
div_field = project_onto_tangent_space(test_inputs_field, div_field_ambient)
curl_field = torch.cross(div_field, test_inputs_field, dim=-1)

RuntimeError: mat1 and mat2 shapes cannot be multiplied (3000x1 and 3x1)

In [248]:
fig = go.Figure()
x, y, z = test_inputs_scalar.unbind(dim=-1)
fig.add_trace(
    go.Surface(x=x, y=y, z=z, surfacecolor=test_targets_scalar.squeeze())
)
x, y, z = test_inputs_field.unbind(dim=-1)
u, v, w = curl_field.unbind(dim=-1)
fig.add_trace(go.Cone(
    x=x,
    y=y,
    z=z,
    u=u,
    v=v,
    w=w,
    colorscale=['black', 'black'], 
    sizeref= 0.1 / curl_field.abs().max().item(),

))
fig.update_layout(
    height=800,
    width=800,
)
fig.show()

In [252]:
fig = go.Figure()
x, y, z = test_inputs_scalar.unbind(dim=-1)
fig.add_trace(
    go.Surface(x=x, y=y, z=z, surfacecolor=test_targets_scalar.squeeze())
)
x, y, z = test_inputs_field.unbind(dim=-1)
u, v, w = div_field.unbind(dim=-1)
fig.add_trace(go.Cone(
    x=x,
    y=y,
    z=z,
    u=u,
    v=v,
    w=w,
    colorscale=['black', 'black'], 
    sizeref= 0.1 / div_field.abs().max().item(),

))
fig.update_layout(
    height=800,
    width=800,
)
fig.show()

In [114]:
import plotly.graph_objects as go
import numpy as np

# Define a grid of points
x, y, z = np.meshgrid(np.arange(-5, 6, 2),
                      np.arange(-5, 6, 2),
                      np.arange(-5, 6, 2))

# Define a vector field (for example, a simple linear field)
u = x
v = y
w = z

# Create the cone plot
fig = go.Figure(data=go.Cone(x=x.flatten(), y=y.flatten(), z=z.flatten(),
                             u=u.flatten(), v=v.flatten(), w=w.flatten()))

# Update layout
fig.update_layout(
    title='3D Cone Plot',
    scene=dict(
        xaxis_title='X AXIS',
        yaxis_title='Y AXIS',
        zaxis_title='Z AXIS'
    )
)

fig.show()
