In [144]:
from jax import numpy as jnp, vmap, random
import matplotlib.pyplot as plt

rng = random.PRNGKey(0)

In [247]:
def get_perm_matrix(permutation):
    perm_matrix = jnp.zeros((4, 4))
    for i, j in enumerate(permutation):
        perm_matrix = perm_matrix.at[i, j].set(1)
    return perm_matrix


tetrahedron_pts = jnp.eye(4) - 1 / 4
eig_vals, eig_vecs = jnp.linalg.eigh(tetrahedron_pts)
tetrahedron_pts = tetrahedron_pts @ eig_vecs[:, 1:]

cube_pts = jnp.array(
    [
        [[(-1) ** i, (-1) ** j, (-1) ** k] for k in range(2)]
        for i in range(2)
        for j in range(2)
    ]
).reshape(-1, 3)


def tetrahedron_represent(permutation):
    perm_matrix = get_perm_matrix(permutation)
    P_1 = tetrahedron_pts
    P_2 = perm_matrix @ P_1
    A = jnp.linalg.pinv(P_1) @ P_2
    return A


def cube_represent(permutation):
    first_four = cube_pts[:4]
    P_1 = first_four.astype(float)[:3]
    P_2 = first_four[permutation].astype(float)[:3]

    for signs in cube_pts:
        A = jnp.linalg.pinv(P_1) @ (P_2 * signs[:, None])
        if jnp.allclose(jnp.linalg.det(A), 1.0, atol=1e-5) and jnp.allclose(
            A @ A.T, jnp.eye(3), atol=1e-5
        ):
            break
    return A


def compose_permutations(perm_1, perm_2):
    return perm_1[perm_2]


S_4 = jnp.array(
    [
        [0, 1, 2, 3],
        [0, 1, 3, 2],
        [0, 2, 1, 3],
        [0, 2, 3, 1],
        [0, 3, 1, 2],
        [0, 3, 2, 1],
        [1, 0, 2, 3],
        [1, 0, 3, 2],
        [1, 2, 0, 3],
        [1, 2, 3, 0],
        [1, 3, 0, 2],
        [1, 3, 2, 0],
        [2, 0, 1, 3],
        [2, 0, 3, 1],
        [2, 1, 0, 3],
        [2, 1, 3, 0],
        [2, 3, 0, 1],
        [2, 3, 1, 0],
        [3, 0, 1, 2],
        [3, 0, 2, 1],
        [3, 1, 0, 2],
        [3, 1, 2, 0],
        [3, 2, 0, 1],
        [3, 2, 1, 0],
    ]
)

# Symmetries of a tetrahedron
U_1 = jnp.array([tetrahedron_represent(permutation) for permutation in S_4])

# Symmetries of a cube
U_2 = jnp.array([cube_represent(permutation) for permutation in S_4])


# Tests whether the representation is associative
def test_repr(represent):
    for i in range(24):
        for j in range(i, 24):
            try:
                # note that the order of the composition is "reversed"
                assert jnp.allclose(
                    represent(compose_permutations(S_4[i], S_4[j])),
                    represent(S_4[j]) @ represent(S_4[i]),
                    atol=1e-5,
                )
            except:
                print(f"Error thrown at {(i, )}")


def regular_represent(permutation):
    new_perms = vmap(compose_permutations, in_axes=(None, 0))(permutation, S_4)
    representation = jnp.all(S_4[:, None, :] == new_perms[None, :, :], axis=-1).astype(
        int
    )
    return representation.T


def check_equivariant(map, input, rep_1, rep_2):
    for i in range(24):
        assert jnp.allclose(
            map(rep_1[i] @ input), rep_2[i] @ map(input), atol=1e-5
        )
    print("Passed test!")

In [274]:
rngs = random.split(rng, 3)
key_pt = random.normal(rngs[0], (3,))
val_pt = random.normal(rngs[1], (3,))
x = random.normal(rngs[2], (3,))

def embed_pt(pt):
    return jnp.dot(U_1 @ pt, key_pt)

def unembed_pt(embedded_pt):
    return jnp.sum((U_2.swapaxes(1, 2) @ val_pt) * embedded_pt[:, None], 0)


def func(x):
    y = x
    y = embed_pt(y)
    y = jnp.power(y, 2)
    y = unembed_pt(y)
    return y


check_equivariant(func, x, U_1, U_2)
func(x)


Passed test!


Array([1.4901161e-08, 0.0000000e+00, 8.3819032e-08], dtype=float32)

In [231]:
jnp.sum((U_2 @ val_pt) * embed_pt(x)[:, None], 0)

Array([-2.7939677e-09,  5.5879354e-09, -3.3527613e-08], dtype=float32)

In [135]:
regular_represent(perm).T @ embed_pt(x)

Array([-0.38873065,  0.37088457,  0.05911535,  0.33873057,  0.11000005,
       -0.36999992, -0.44359416,  0.31602103,  0.12476372,  0.2752147 ,
        0.17564842, -0.43351585, -0.19963191,  0.07998333, -0.07912003,
        0.07133094, -0.06038932, -0.18955357,  0.17390527, -0.3060947 ,
        0.29441717, -0.31474707,  0.26226318,  0.1330989 ], dtype=float32)

In [73]:
# get_perm_matrix([0, 1, 2, 3])
jnp.linalg.det(represent_permutation([1, 0, 3, 2]))

Array(0.9999999, dtype=float32)