Visualize exponential coordinates $r = \lambda \theta$ (unit axis $\lambda$, angle $\theta$) for $SO(3)$,
but only show a planar slice of the space by setting $\lambda_z = 0$ (equivalently $r_z = 0$).

In [None]:
import einops
import numpy as np
from numpy.linalg import norm
from PIL import Image

In [None]:
# Scratch check: tracing a (10, 3, 3) stack over its two matrix axes gives
# one trace per matrix, i.e. a result of shape (10,).
x = np.zeros((10, 3, 3))
np.trace(x, axis1=1, axis2=2).shape

In [None]:
# Generic SO(3).


def wrap_angle(th):
    """Wrap angle(s) onto [-pi, pi). Note that +pi itself maps to -pi."""
    return np.mod(th + np.pi, 2 * np.pi) - np.pi


def wrap_ax_ang(ax, th):
    """
    Canonicalize axis-angle pairs so every angle lies in [0, pi].

    The angle is first wrapped onto [-pi, pi). Wherever it is then negative,
    both the angle and the axis are negated: a rotation by -th about ``ax``
    equals a rotation by th about ``-ax``.

    Args:
        ax: (..., 3) array of axes.
        th: angle(s) in radians with shape ``ax.shape[:-1]`` (scalar allowed
            for a single (3,) axis).

    Returns:
        (ax, th) as new arrays. The caller's ``ax`` is not modified.
    """
    th = wrap_angle(th)
    neg = th < 0
    th = np.where(neg, -th, th)
    # np.where allocates a fresh array, so the caller's axis array is left
    # untouched (the previous in-place ``ax[neg] *= -1`` mutated it).
    ax = np.where(neg[..., np.newaxis], -ax, ax)
    return ax, th


def decompose_ax_ang_restricted(r, *, tol=1e-10):
    """
    Split exponential coordinates r in R^3 (i.e. so(3)) into unit axis and angle.

    The angle is |r| wrapped into [0, pi], with the axis flipped where the
    wrap makes it negative. This picks one representative per rotation,
    except in two cases. At angle 0 the axis is arbitrary. At angle pi,
    ``ax`` and ``-ax`` describe the same rotation.

    Args:
        r: (..., 3) array of exponential coordinates (any numeric dtype).
        tol: rows with |r| <= tol are treated as the identity and get the
            placeholder axis +x.

    Returns:
        (ax, th): (..., 3) float unit axes and (...) angles in [0, pi].
    """
    # Work in floating point: with an integer ``r``, ``zeros_like`` would give
    # an integer buffer and the unit axes below would truncate to 0.
    r = np.asarray(r, dtype=float)
    th = norm(r, axis=-1)
    good = th > tol
    ax = np.zeros_like(r)
    ax[~good, :] = np.array([1, 0, 0])
    ax[good] = r[good] / th[good].reshape((-1, 1))
    ax, th = wrap_ax_ang(ax, th)
    assert np.all((th >= 0.0) & (th <= np.pi))
    return ax, th


# Visualization helpers for the restricted 2D slice (axes with z = 0).


def calc_ax_gamma(ax):
    """
    Return the in-plane angle gamma of each axis restricted to z = 0.

    For a unit axis, ``ax == (cos(gamma), sin(gamma), 0)``. gamma is returned
    in (-pi, pi], so each axis direction maps to exactly one gamma.

    Args:
        ax: (N, 3) array whose third column is exactly zero.

    Returns:
        (N,) array of angles in (-pi, pi].
    """
    assert np.all(ax[:, 2] == 0)
    gamma = np.arctan2(ax[:, 1], ax[:, 0])
    # arctan2 yields exactly -pi when x < 0 and y is -0.0 (as produced by
    # negating an axis). That is the same direction as +pi, so fold it onto
    # +pi. Otherwise one rotation could receive two different gammas (and
    # colors).
    return np.where(gamma == -np.pi, np.pi, gamma)


def _reshape_interpoland(c):
    """Return ``c`` unchanged if it is 2D, otherwise flattened into one row."""
    return c if c.ndim == 2 else c.reshape((1, -1))


def interp(s, a, b, *, si=0, sf=1):
    """
    Linearly blend from ``a`` to ``b`` as ``s`` goes from ``si`` to ``sf``.

    ``s`` is normalized to a fraction, clamped to [0, 1], and laid out as a
    column, so each entry of ``s`` yields one output row. ``a`` and ``b``
    that are not already 2D are treated as single rows and broadcast
    against that column.
    """
    frac = np.clip((s - si) / (sf - si), 0, 1).reshape((-1, 1))
    start = _reshape_interpoland(a)
    stop = _reshape_interpoland(b)
    return start + frac * (stop - start)


def color_ax_ang(ax, th):
    """
    Map axis-angle pairs from decompose_ax_ang_restricted() to RGB colors.

    The axis direction (in-plane angle gamma) picks a color on the line from
    red (gamma = -pi) to blue (gamma = +pi). The angle then scales from black
    (th = 0) to that axis color (th = pi). All th = 0 inputs (the identity)
    therefore share one color regardless of axis.

    Intent: equal colors <=> equal SO(3) elements. Note that at th = pi the
    axes ``ax`` and ``-ax`` are the same rotation but still get different
    colors here.

    Args:
        ax: (..., 3) axes with ``ax[..., 2] == 0``.
        th: angles with shape ``ax.shape[:-1]``, expected in [0, pi].

    Returns:
        (..., 3) float RGB array with values in [0, 1].
    """
    batch_shape = th.shape
    ax_flat = ax.reshape((-1, 3))
    th_flat = th.reshape(-1)
    # Axis color: linear blend from red at gamma = -pi to blue at gamma = +pi.
    red = np.array([1, 0, 0])
    blue = np.array([0, 0, 1])
    gamma = calc_ax_gamma(ax_flat)
    ax_color = interp(gamma, red, blue, si=-np.pi, sf=np.pi)
    # Angle: blend from black at th = 0 to the axis color at th = pi.
    black = np.array([0, 0, 0])
    color = interp(th_flat, black, ax_color, si=0, sf=np.pi)
    return color.reshape(batch_shape + (3,))
    

def to_pil(x):
    """
    Convert a floating-point image with values in [0, 1] to a PIL image.

    Values are clipped to [0, 1] before scaling to 0..255. Casting
    out-of-range floats to uint8 would otherwise wrap around and produce
    garbage pixels. Scaling truncates (it does not round).

    Args:
        x: float array, e.g. (H, W, 3) RGB.

    Returns:
        PIL ``Image`` built from the uint8 array.
    """
    assert np.issubdtype(x.dtype, np.floating)
    x = (np.clip(x, 0.0, 1.0) * 255).astype(np.uint8)
    return Image.fromarray(x)

In [None]:
# Make a grid over the plane r_z = 0 of exponential coordinates.
nx = 4
ny = nx
bound = 15
xs = np.linspace(-bound, bound, nx)
ys = np.linspace(-bound, bound, ny)
# Default "xy" indexing: X, Y (and rs) have shape (ny, nx); rows follow y and
# columns follow x.
X, Y = np.meshgrid(xs, ys)
Z = np.zeros_like(X)
rs = np.stack((X, Y, Z), axis=-1)

ax, ang = decompose_ax_ang_restricted(rs)
img = color_ax_ang(ax, ang)
# Image row 0 is drawn at the top, while y increases with the row index:
# flip the rows so +y points up (flipping columns would mirror x instead).
img = img[::-1, :, :]

to_pil(img)

In [None]:
def skew(rs):
    """
    Build the skew-symmetric (cross-product) matrix of each row of ``rs``.

    Args:
        rs: (N, 3) array of vectors.

    Returns:
        (N, 3, 3) float array ``S`` with ``S[i] @ v == np.cross(rs[i], v)``.
    """
    N, M = rs.shape
    assert M == 3
    S = np.zeros((N, 3, 3))
    r1, r2, r3 = rs.T
    S[:, 0, 1] = -r3
    S[:, 1, 0] = r3
    S[:, 0, 2] = r2
    S[:, 2, 0] = -r2
    S[:, 1, 2] = -r1
    S[:, 2, 1] = r1
    return S


def axang(axs, ths):
    """
    Exponential map for unit axes and angles, using Eq. (2.14) of [MLS].

        R = I + sin(th) [ax]_x + (1 - cos(th)) [ax]_x^2

    Intended to match ``RotationMatrix(AngleAxis(angle, axis)).matrix()``.

    Args:
        axs: (N, 3) array of unit rotation axes (unit length is assumed).
        ths: (N,) array of rotation angles in radians.

    Returns:
        (N, 3, 3) array of rotation matrices.
    """
    N, = ths.shape
    assert axs.shape == (N, 3)
    # Shape the per-rotation scalars as (N, 1, 1) so each angle scales its own
    # 3x3 matrix. A bare (N,) array would broadcast against the column axis:
    # an error for most N, and silently wrong for N == 3.
    c = np.cos(ths).reshape((N, 1, 1))
    s = np.sin(ths).reshape((N, 1, 1))
    L = skew(axs)
    # np.eye(3) broadcasts over the batch axis.
    return np.eye(3) + L * s + (L @ L) * (1 - c)


def axang3(rs, *, tol=1e-10):
    """
    Exponential map from exponential coordinates in R^3 to SO(3).

    Each row ``r`` is split into angle ``th = |r|`` and unit axis ``r / th``
    before applying ``axang``. Rows with ``th < tol`` get the placeholder
    axis +z. Their angle is below ``tol``, so the result is ~identity.

    Args:
        rs: (N, 3) array of exponential coordinates (any numeric dtype).
        tol: norm below which a row is treated as a zero rotation.

    Returns:
        (N, 3, 3) array of rotation matrices.
    """
    ths = norm(rs, axis=-1)
    small = ths < tol
    # Float buffer: ``zeros_like`` on integer input would truncate the unit
    # axes to zero.
    axs = np.zeros(np.shape(rs))
    axs[small, :] = [0, 0, 1]
    axs[~small, :] = rs[~small] / ths[~small].reshape((-1, 1))
    return axang(axs, ths)


def trace(Xs):
    """Return the trace of each matrix in an (N, A, A) stack, shape (N,)."""
    assert Xs.ndim == 3
    return np.trace(Xs, axis1=1, axis2=2)


def so3_angle(Rs):
    """
    Rotation angle in [0, pi] of each matrix in an (N, 3, 3) stack.

    Uses tr(R) = 1 + 2 cos(th); intended to match ``AngleAxis(R).angle()``.

    Args:
        Rs: (N, 3, 3) array of rotation matrices.

    Returns:
        (N,) array of angles in radians.
    """
    cos_th = (trace(Rs) - 1) / 2
    # Round-off can push cos_th slightly outside [-1, 1] (near th = 0 or
    # th = pi), where arccos returns NaN. Clamp it back into the domain.
    ths = np.arccos(np.clip(cos_th, -1.0, 1.0))
    assert np.isfinite(ths).all()
    return ths

def exp_dist(r1s, r2s):
    """
    Angle of the relative rotation between paired exponential coordinates.

    Args:
        r1s, r2s: (N, 3) arrays of exponential coordinates.

    Returns:
        (N,) array: rotation angle of R1^T R2 for each pair, in [0, pi].
    """
    R1s = axang3(r1s)
    R2s = axang3(r2s)
    # Transpose only the matrix axes. ``R1s.T`` would also reverse the batch
    # axis, giving (3, 3, N) and breaking the batched matmul.
    Rs = np.swapaxes(R1s, -1, -2) @ R2s
    return so3_angle(Rs)

In [None]:
# Show the distance on SO(3) between each grid point and the same point
# displaced by a fixed Euclidean step in exponential coordinates.
r_delta = np.array([0.1, 0, 0])
rds = rs + r_delta

# Flatten the grid to (N, 3). meshgrid's default "xy" indexing makes ``rs``
# (ny, nx, 3): rows index y (height H), columns index x (width W). Passing the
# sizes makes einops check that layout.
W = nx
H = ny
rfs = einops.rearrange(rs, "H W C -> (H W) C", H=H, W=W, C=3)
rd_fs = einops.rearrange(rds, "H W C -> (H W) C", H=H, W=W, C=3)

ang_ds = exp_dist(rfs, rd_fs)