import geometric_kernels.torch 
import torch 
from torch import Tensor 
from mdgp.kernels import GeometricMaternKernel
from torch.func import jacfwd, vmap
from geometric_kernels.spaces import Hypersphere
from torch import nn 
from mdgp.utils import sphere_uniform_grid

In [84]:
# Manifold on which the scalar kernel is defined: a geometric_kernels
# Hypersphere constructed with argument 2 (presumably S^2 embedded in R^3,
# matching the 3-component points used later in this file — TODO confirm).
space = Hypersphere(2)
# NOTE: need to figure out how to support likelihood samples.
# Scalar Matern kernel on `space`, built with 20 eigenfunctions and no batch
# shape. It is later wrapped by CurlFreeHodgeMaternKernel.
base_kernel = GeometricMaternKernel(space, batch_shape=None, num_eigenfunctions=20)

In [85]:
def flatten_matrix_valued_kernel(K: Tensor) -> Tensor:
    """
    Lay out a matrix-valued kernel as one block matrix.

    K: (..., N, M, D, D)
    return: (..., D * N, D * M)

    Block (i, j) of the result, i.e. rows i*N:(i+1)*N and columns
    j*M:(j+1)*M, holds K[..., i, j].
    """
    column_strips = []
    for column in K.unbind(dim=-1):
        # column = K[..., j] with shape (..., N, M, D); stack its D blocks
        # on top of each other along the row axis -> (..., D * N, M).
        row_blocks = column.unbind(dim=-1)
        column_strips.append(torch.cat(row_blocks, dim=-2))
    # Put the D column strips side by side -> (..., D * N, D * M).
    return torch.cat(column_strips, dim=-1)


def unflatten_matrix_valued_kernel(K: Tensor, D: int) -> Tensor:
    """
    Split a block matrix back into a matrix-valued kernel.

    K: (..., D * N, D * M)
    return: (..., N, M, D, D)

    Entry [..., i, j] of the result is the block at rows i*N:(i+1)*N and
    columns j*M:(j+1)*M of K.
    """
    columns = []
    for column_strip in torch.tensor_split(K, D, dim=-1):
        # column_strip: (..., D * N, M); cut it into D row blocks of
        # shape (..., N, M) and index them with a new trailing axis.
        row_blocks = torch.tensor_split(column_strip, D, dim=-2)
        columns.append(torch.stack(row_blocks, dim=-1))
    # The second new trailing axis indexes the column block.
    return torch.stack(columns, dim=-1)


def test_flatten_matrix_valued_kernel():
    """Check that every (i, j) block of the flattened kernel equals K[..., i, j]."""
    N, M, D = 5, 7, 3
    K = torch.randn(2, N, M, D, D)
    K_flat = flatten_matrix_valued_kernel(K)
    for i in range(D):
        rows = slice(i * N, (i + 1) * N)
        for j in range(D):
            cols = slice(j * M, (j + 1) * M)
            assert torch.all(K[..., i, j] == K_flat[..., rows, cols])


def test_unflatten_matrix_valued_kernel():
    """Check that unflattening a flattened kernel returns the original tensor."""
    D = 3
    original = torch.randn(2, 5, 7, D, D)
    flattened = flatten_matrix_valued_kernel(original)
    round_trip = unflatten_matrix_valued_kernel(flattened, D)
    assert torch.all(round_trip == original)


# Run the layout checks eagerly when this cell executes; each raises
# AssertionError on failure.
test_flatten_matrix_valued_kernel()
test_unflatten_matrix_valued_kernel()

In [86]:
def broadcast_batch_shapes(*tensors: Tensor) -> torch.Size:
    """
    Return the common batch shape of the given tensors.

    The batch shape of a tensor is everything except its last two
    dimensions; the batch shapes are combined with the usual broadcasting
    rules via ``torch.broadcast_shapes``.
    """
    batch_shapes = (tensor.shape[:-2] for tensor in tensors)
    return torch.broadcast_shapes(*batch_shapes)


def broadcast_batch_tensors(*tensors: Tensor) -> list[Tensor]:
    """
    Expand the given tensors to a shared batch shape.

    Only the leading (batch) dimensions are broadcast; the last two
    dimensions of each tensor are kept as they are. Fails if the batch
    shapes are not broadcastable.
    """
    shared_batch = broadcast_batch_shapes(*tensors)
    broadcasted = []
    for tensor in tensors:
        target_shape = (*shared_batch, *tensor.shape[-2:])
        broadcasted.append(tensor.broadcast_to(target_shape))
    return broadcasted

In [87]:
def chart(sph: Tensor) -> Tensor:
    """
    Map (longitude, latitude) pairs to Cartesian points on the unit sphere.

    sph: (..., 2), component 0 is longitude and component 1 is latitude
    return: (..., 3), the (x, y, z) coordinates
    """
    longitude, latitude = torch.unbind(sph, dim=-1)

    # Radius of the circle of constant latitude, shared by x and y.
    planar_radius = torch.cos(latitude)
    components = (
        planar_radius * torch.cos(longitude),
        planar_radius * torch.sin(longitude),
        torch.sin(latitude),
    )
    return torch.stack(components, dim=-1)


def tangent_basis_at(sph: Tensor) -> Tensor:
    """
    Jacobian of ``chart`` at a single point, computed with forward-mode AD.

    sph: (2,)
    return: (3, 2), column k is the derivative of the Cartesian point with
        respect to spherical coordinate k
    """
    chart_jacobian = jacfwd(chart)
    return chart_jacobian(sph)


def tangent_basis_no_batch(sph: Tensor) -> Tensor:
    """
    Apply ``tangent_basis_at`` to every row of a list of points.

    sph: (N, 2)
    return: (N, 3, 2)
    """
    per_point_basis = vmap(tangent_basis_at, in_dims=0)
    return per_point_basis(sph)


def tangent_basis_batch(sph: Tensor) -> Tensor:
    """
    Apply ``tangent_basis_at`` to every point of an arbitrarily batched input.

    sph: (..., N, 2)
    return: (..., N, 3, 2)

    The previous implementation vmapped once over ``tangent_basis_no_batch``
    and therefore only accepted exactly one batch dimension, although the
    documented shape is ``(..., N, 2)``. All leading dimensions are now
    flattened into a single axis, mapped over once, and restored afterwards;
    for 3-D inputs the result is the same as before.
    """
    leading_shape = sph.shape[:-1]
    flat_points = sph.reshape(-1, sph.shape[-1])       # (P, 2)
    flat_basis = vmap(tangent_basis_at)(flat_points)   # (P, 3, 2)
    return flat_basis.reshape(*leading_shape, *flat_basis.shape[-2:])


def inverse_chart(car: Tensor) -> Tensor:
    """
    Inverse of ``chart``: Cartesian unit vectors to (longitude, latitude).

    car: (..., 3)
    return: (..., 2), component 0 is longitude in (-pi, pi] and component 1
        is latitude in [-pi/2, pi/2]

    The output uses the same component order and angle convention that
    ``chart`` consumes (z = sin(latitude)), so ``chart(inverse_chart(p))``
    reproduces ``p`` for unit vectors. The previous version returned
    (arccos(z), atan2(y, x)), i.e. (colatitude, longitude), which does not
    round-trip through ``chart``.

    Longitude is not well defined at the poles; use ``car_to_sph`` when the
    input may contain them.
    """
    x, y, z = car.unbind(-1)
    lon = torch.atan2(y, x)
    # Clamp so that rounding slightly outside [-1, 1] does not produce NaN.
    lat = torch.arcsin(z.clamp(-1.0, 1.0))
    return torch.stack((lon, lat), dim=-1)


def car_to_sph(car: Tensor, epsilon: float = 1e-6) -> Tensor:
    """
    Convert Cartesian unit vectors to (longitude, latitude) with pole handling.

    car: (..., 3)
    epsilon: tolerance used to detect the poles (|z| > 1 - epsilon) and the
        y = 0 meridian (|y| < epsilon)
    return: (..., 2) in the (longitude, latitude) convention of ``chart``

    Fixes relative to the previous version:
    * coordinates are read with ``car[..., k]`` instead of ``car[:, k]``, so
      inputs with any number of leading dimensions work (``car[:, 2]`` picked
      the third *point* of each batch for (B, N, 3) inputs);
    * poles and non-poles use the same convention (previously poles were
      written as latitude +-pi/2 into a slot that otherwise held colatitude);
    * the result is built out-of-place with ``torch.where``, so it stays on
      the input's device and dtype.
    """
    x, y, z = car.unbind(-1)
    north_pole = z > 1 - epsilon
    south_pole = -z > 1 - epsilon
    poles = north_pole | south_pole

    lon, lat = inverse_chart(car).unbind(-1)

    # At the poles pin latitude to exactly +-pi/2 and pick longitude 0.
    half_pi = torch.pi / 2
    lat = torch.where(north_pole, torch.full_like(lat, half_pi), lat)
    lat = torch.where(south_pole, torch.full_like(lat, -half_pi), lat)
    lon = torch.where(poles, torch.zeros_like(lon), lon)

    # On the y = 0 meridian snap longitude to exactly 0 (x > 0) or pi (x < 0)
    # instead of relying on atan2 with a tiny y.
    greenwich = (y.abs() < epsilon) & ~poles
    meridian_lon = torch.where(x > 0, torch.zeros_like(lon), torch.full_like(lon, torch.pi))
    lon = torch.where(greenwich, meridian_lon, lon)

    return torch.stack((lon, lat), dim=-1)

# TODO 
[X] Project the gradient onto the coordinate frame of the sphere. <br>
[ ] Implement a single-layer GP with the HodgeMaternKernel <br>
[ ] Run single-layer GP on the wind dataset 

In [88]:
class CurlFreeHodgeMaternKernel(nn.Module):
    """
    Matrix-valued kernel built from the mixed second derivative of a scalar kernel.

    For points x1, x2 in R^3 the ambient kernel is the 3x3 matrix
    d^2 k(x1, x2) / (d x1 d x2) of the wrapped scalar kernel k, computed with
    forward-mode AD. It is then expressed in the coordinate frames given by
    the Jacobian of ``chart`` at each point, giving 2x2 blocks.

    scalar_matern_kernel: callable taking two (1, 3) tensors and returning an
        object whose ``evaluate()`` yields a tensor holding the single kernel
        value (this is how it is used in ``scalar_matern_kernel_at``).
    """

    def __init__(self, scalar_matern_kernel):
        super().__init__()
        self.scalar_matern_kernel = scalar_matern_kernel

    def scalar_matern_kernel_at(self, x1, x2) -> Tensor:
        """
        Scalar kernel value for a single pair of points.

        x1.shape = (3,)
        x2.shape = (3,)
        return: 0-dimensional tensor
        """
        return self.scalar_matern_kernel(x1.unsqueeze(0), x2.unsqueeze(0)).evaluate().squeeze()

    def hodge_at(self, x1, x2) -> Tensor:
        """
        Mixed second derivative of the scalar kernel for a single pair of points.

        x1.shape = (3,)
        x2.shape = (3,)
        return: (3, 3), entry [i, j] is d^2 k / (d x1_i d x2_j)
        """
        f = jacfwd(jacfwd(self.scalar_matern_kernel_at, argnums=0), argnums=1)
        return f(x1, x2)

    def hodge_no_batch(self, x1, x2):
        """
        ``hodge_at`` for all pairs of points from two lists.

        x1.shape = (N, 3)
        x2.shape = (M, 3)
        return: (N, M, 3, 3)
        """
        # Inner vmap runs over the points of x2, outer vmap over those of x1.
        f = vmap(vmap(self.hodge_at, in_dims=(None, 0)), in_dims=(0, None))
        return f(x1, x2)

    def hodge_batch(self, x1, x2):
        """
        ``hodge_no_batch`` applied independently to each batch element.

        x1.shape = (B, N, 3)
        x2.shape = (B, M, 3)
        return: (B, N, M, 3, 3)
        """
        f = vmap(self.hodge_no_batch)
        return f(x1, x2)

    def tangent_basis(self, sph: Tensor) -> Tensor:
        """
        Jacobian of ``chart`` at each spherical coordinate.

        sph.shape = (..., 2) with at most 3 dimensions
        return: (..., 3, 2)
        raises: ValueError if sph has more than 3 dimensions
        """
        if sph.ndim == 1:
            return tangent_basis_at(sph)
        if sph.ndim == 2:
            return tangent_basis_no_batch(sph)
        if sph.ndim == 3:
            return tangent_basis_batch(sph)
        raise ValueError(f"sph must have at most 3 dimensions, got {sph.ndim}")

    def ambient_to_tangent_kernel(self, K_ambient: Tensor, x1: Tensor, x2: Tensor) -> Tensor:
        """
        Express the ambient 3x3 kernel blocks in the chart's coordinate frames.

        K_ambient.shape = (..., N, M, 3, 3)
        x1.shape = (..., N, 3)
        x2.shape = (..., M, 3)
        return: (..., N, M, 2, 2)

        NOTE(review): this assumes ``car_to_sph`` returns coordinates in the
        (longitude, latitude) convention that ``chart`` differentiates, and
        that it supports the leading dimensions of x1/x2 — TODO confirm.
        """
        x1_sph = car_to_sph(x1)
        x2_sph = car_to_sph(x2)

        # Project kernel onto S^2 tangent space
        tangent_basis_x1 = self.tangent_basis(x1_sph)  # (..., N, 3, 2)
        tangent_basis_x2 = self.tangent_basis(x2_sph)  # (..., M, 3, 2)
        # (..., N, 3, 2)^T @ (..., N, M, 3, 3) @ (..., M, 3, 2) -> (..., N, M, 2, 2)
        K_tangent = torch.einsum('...nji, ...nmjk, ...mkl -> ...nmil', tangent_basis_x1, K_ambient, tangent_basis_x2)
        return K_tangent

    def forward(self, x1, x2) -> Tensor:
        """
        Evaluate the kernel in tangent coordinates for all pairs of points.

        x1.shape = (N, 3) or (B, N, 3)
        x2.shape = (M, 3) or (B, M, 3)
        return: (..., N, M, 2, 2)
        raises: ValueError if x1 or x2 is not 2- or 3-dimensional

        Previously an input with more than 3 dimensions fell through the
        dispatch and failed with an UnboundLocalError, and a single point
        (1-D input) failed later in the projection; both now raise a clear
        ValueError up front.
        """
        for name, value in (("x1", x1), ("x2", x2)):
            if value.ndim not in (2, 3):
                raise ValueError(
                    f"{name} must have shape (N, 3) or (B, N, 3), got {value.ndim} dimensions"
                )

        # TODO Currently we don't support broadcasting over batch dimensions.
        # This introduces some redundant computation.
        x1, x2 = broadcast_batch_tensors(x1, x2)

        # Get kernel with cartesian gradient (embedded field kernel)
        # K_cart.shape = (..., N, M, 3, 3)
        # After broadcasting both inputs share the same number of dimensions.
        if x1.ndim == 2:
            K_cart = self.hodge_no_batch(x1, x2)
        else:
            K_cart = self.hodge_batch(x1, x2)

        # Project kernel onto S^2 tangent space
        return self.ambient_to_tangent_kernel(K_cart, x1, x2)

    def __call__(self, x1, x2) -> Tensor:
        """
        Evaluate the kernel and flatten the 2x2 blocks into one block matrix.

        x1.shape = (..., N, 3)
        x2.shape = (..., M, 3)
        return: (..., 2 * N, 2 * M)
        """
        K = super().__call__(x1, x2)
        return flatten_matrix_valued_kernel(K)

In [93]:
# Demo inputs: `num_batches` batches of `nx` (resp. `ny`) points each.
nx = 400
ny = 400
num_batches = 2
# sphere_uniform_grid(k) is reshaped to (num_batches, n, 3), so it presumably
# returns k points with 3 coordinates each — TODO confirm against mdgp.utils.
x = sphere_uniform_grid(nx * num_batches).reshape(num_batches, nx, 3)
y = sphere_uniform_grid(ny * num_batches).reshape(num_batches, ny, 3)
# Wrap the scalar Matern kernel defined earlier in this file.
hodge_kernel = CurlFreeHodgeMaternKernel(base_kernel)

In [94]:
# Shape check of the flattened kernel; the output recorded for this cell is
# torch.Size([2, 800, 800]), i.e. (num_batches, 2 * nx, 2 * ny).
hodge_kernel(x, y).shape

torch.Size([2, 800, 800])

### Shape requirements for a kernel in a deep GP 
input shape [S, N, D] or [N, D]

In [117]:
import jax 
import jax.numpy as jnp 
from jax.scipy.special import lpmn_values

# KERNEL FUNCTIONS

def array(els):
    """
    Build a JAX array from `els`, requesting float64 precision.

    NOTE(review): JAX presumably downcasts this to float32 unless 64-bit mode
    is enabled — confirm the notebook's jax configuration.
    """
    return jnp.array(els, dtype=jnp.float64)


# Highest Legendre degree used below: Legendre values are computed for
# degrees 0..MAX_ELL and the kernel sums run over degrees 1..MAX_ELL.
MAX_ELL = 35

@jax.jit
def legendre_values(x, y):
    """
    Order-0 Legendre values of degree 0..MAX_ELL evaluated at x . y.

    x, y: points in cartesian coordinates (1-D arrays)
    return: (MAX_ELL + 1,)
    """
    inner_product = jnp.dot(x, y)[None]
    # lpmn_values returns an (order, degree, point) table; keep order 0 only.
    table = lpmn_values(MAX_ELL, MAX_ELL, inner_product, False)
    order_zero = table[0, :, :]
    return jnp.squeeze(order_zero)


@jax.jit
def lambd(ell: int) -> int:
    """
    Return ell * (ell + 1).

    Presumably the Laplacian eigenvalue on the 2-sphere for degree `ell`;
    it is used as the `lam` argument of `phi`.
    """
    successor = ell + 1
    return ell * successor


@jax.jit
def phi(kappa: float, nu: float, lam: float) -> float:
    """
    Return (2 * nu / kappa + lam) ** (-nu - 1).

    kappa, nu: kernel hyperparameters (named as in the Matern functions of
        this file)
    lam: value added to 2 * nu / kappa before exponentiation
    """
    base = 2 * nu / kappa + lam
    exponent = -nu - 1
    return jnp.power(base, exponent)


@jax.jit
def legendre_tilde_constant(kappa: float, nu: float, ell: int) -> float:
    """
    Weight applied to the degree-`ell` Legendre term.

    Returns (2 * ell + 1) / (4 * pi * lambd(ell)) * phi(kappa, nu, lambd(ell)).
    The value is not finite for ell = 0 because lambd(0) = 0; callers in this
    file only use ell >= 1.
    """
    eigenvalue = lambd(ell)
    spectral_weight = phi(kappa, nu, eigenvalue)
    return (2 * ell + 1) / (4 * jnp.pi * eigenvalue) * spectral_weight


@jax.jit
def legendre_tilde_values(x, y, kappa: float = 1.0, nu: float = 2.5):
    """
    Weighted Legendre values for degrees 1..MAX_ELL (not summed).

    x, y: points in cartesian coordinates
    return: (MAX_ELL,), entry ell - 1 is
        legendre_tilde_constant(kappa, nu, ell) * P_ell(x . y)
    """
    degrees = jnp.arange(1, MAX_ELL + 1)
    weights = array([legendre_tilde_constant(kappa, nu, ell) for ell in degrees])
    # Drop degree 0, for which the weight is undefined.
    unweighted = legendre_values(x, y)[1:]
    return jnp.multiply(unweighted, weights)


@jax.jit
def dd_legendre_tilde_vals(x, y):
    """
    Per-degree outer product of the Legendre gradients in x and in y.

    return: (MAX_ELL, 3, 3) for 3-component inputs; entry [l, j, k] is
        d P / d x_j * d P / d y_k for degree l + 1.

    NOTE: despite the name, no `legendre_tilde_constant` weights are applied
    here; the unweighted `legendre_values` are differentiated.
    """
    grad_wrt_x = jax.jacfwd(legendre_values, argnums=0)(x, y)[1:]
    grad_wrt_y = jax.jacfwd(legendre_values, argnums=1)(x, y)[1:]
    # Batched outer product over the degree axis.
    return jnp.einsum('ij,ik->ijk', grad_wrt_x, grad_wrt_y)


@jax.jit
def hodge_new(x, y, kappa: float = 1.0, nu: float = 2.5):
    """
    Sum over degrees of the outer product of weighted Legendre gradients.

    x, y: points in cartesian coordinates
    return: (3, 3) for 3-component inputs

    NOTE: this contracts first-order Jacobians of `legendre_tilde_values`
    (one in x, one in y); it does not take the mixed second derivative that
    `hodge_matern_k` computes.
    """
    grad_wrt_x = jax.jacfwd(legendre_tilde_values, argnums=0)(x, y, kappa, nu)
    grad_wrt_y = jax.jacfwd(legendre_tilde_values, argnums=1)(x, y, kappa, nu)
    # Outer product per degree, summed over the degree axis i.
    return jnp.einsum('ij, ik -> jk', grad_wrt_x, grad_wrt_y)




# `hodge_matern_k` below is taken directly from the code for "Intrinsic Gaussian Vector Fields on Manifolds".
# `hodge_matern_k_equivalent` (defined after `scalar_matern_k`) is an equivalent implementation of the same kernel.

@jax.jit
def hodge_matern_k(x, y, kappa: float = 1.0, nu: float = 2.5):
    """
    Unnormalized hodge matern kernel on the sphere.

    x, y: points in cartesian coordinates
    return: (3, 3) for 3-component inputs, the weighted sum over degrees
        1..MAX_ELL of the mixed second derivative (in x, then y) of the
        Legendre values.
    """
    mixed_second_derivative = jax.jacfwd(jax.jacfwd(legendre_values, argnums=0), argnums=1)
    per_degree = mixed_second_derivative(x, y)[1:]
    # Same per-degree weight as legendre_tilde_constant, written out inline.
    weights = array([
        jnp.power(2 * nu / kappa + ell * (ell + 1), -nu - 1) * (2 * ell + 1) / (4 * jnp.pi * ell * (ell + 1))
        for ell in jnp.arange(1, MAX_ELL + 1)
    ])
    weighted = jnp.multiply(per_degree, weights[:, None, None])
    return weighted.sum(axis=0)


@jax.jit
def scalar_matern_k(x, y, kappa: float = 1.0, nu: float = 2.5):
    """
    Unnormalized matern kernel on the sphere.

    x, y: points in cartesian coordinates
    return: scalar, the sum over degrees 1..MAX_ELL of
        legendre_tilde_constant(kappa, nu, ell) * P_ell(x . y)
    """
    degrees = jnp.arange(1, MAX_ELL + 1)
    weights = array([legendre_tilde_constant(kappa, nu, ell) for ell in degrees])
    # Degree 0 is skipped: its weight divides by lambd(0) = 0.
    terms = jnp.multiply(legendre_values(x, y)[1:], weights)
    return terms.sum(axis=0)


@jax.jit
def hodge_matern_k_equivalent(x, y, kappa: float = 1.0, nu: float = 2.5):
    """
    Unnormalized hodge matern kernel on the sphere. Equivalent to hodge_matern_k.

    Computed as the mixed second derivative (in x, then y) of
    `scalar_matern_k`, i.e. the derivative is taken after summing over
    degrees rather than per degree.

    Open question: how does a sum of outer products of gradients relate to
    the second order partial derivatives?
    """
    mixed_second_derivative = jax.jacfwd(jax.jacfwd(scalar_matern_k, argnums=0), argnums=1)
    return mixed_second_derivative(x, y, kappa, nu)