In [165]:
from jax import numpy as jnp, vmap, random
import matplotlib.pyplot as plt
from S_4 import *

# Root PRNG key with a fixed seed so every random draw in the notebook is reproducible.
rng = random.PRNGKey(seed=0)
# Four independent sub-keys; later cells index rngs[0] through rngs[3].
rngs = random.split(rng, num=4)

In [166]:
def normalize(x):
    """Scale ``x`` by the reciprocal of ``jnp.linalg.norm(x)``.

    Args:
        x: array to rescale.

    Returns:
        ``x / jnp.linalg.norm(x)``. No guard is applied for a zero norm, so an
        all-zero input yields non-finite values (same as the original lambda).
    """
    return x / jnp.linalg.norm(x)
# Random 3-vectors rescaled by `normalize`; embed_pt reads key_pt and
# unembed_pt reads val_pt.
key_pt = normalize(random.normal(key=rngs[0], shape=(3,)))
val_pt = normalize(random.normal(key=rngs[1], shape=(3,)))
# Unnormalized random 3-vector passed to check_equivariant further down.
x = random.normal(key=rngs[2], shape=(3,))

def embed_pt(pt, representation, key=None):
    """Apply ``representation`` to ``pt`` and take the dot product with a key vector.

    Args:
        pt: vector that ``representation`` is matrix-multiplied with.
        representation: matrix or stack of matrices applied to ``pt``
            (presumably shape (G, n, n), one matrix per group element --
            TODO confirm against S_4).
        key: optional vector to dot against. When ``None`` (the default) the
            notebook-level ``key_pt`` is looked up at call time, which is the
            original behavior.

    Returns:
        ``jnp.dot(representation @ pt, key)``.
    """
    # Late lookup of the global keeps the original semantics while letting
    # callers (and tests) supply their own key vector.
    anchor = key_pt if key is None else key
    return jnp.dot(representation @ pt, anchor)

def unembed_pt(embedded_pt, representation, val=None):
    """Combine transposed-representation images of a value vector, weighted by ``embedded_pt``.

    Args:
        embedded_pt: array of weights indexed along its first axis, one per
            matrix in ``representation``.
        representation: stack of matrices with the stack on axis 0; axes 1 and
            2 are swapped (each matrix transposed) before use.
        val: optional vector the transposed matrices are applied to. When
            ``None`` (the default) the notebook-level ``val_pt`` is looked up
            at call time, which is the original behavior.

    Returns:
        The sum over axis 0 of ``(representation^T @ val) * embedded_pt[:, None]``.
    """
    # Late lookup of the global keeps the original semantics while letting
    # callers (and tests) supply their own value vector.
    anchor = val_pt if val is None else val
    # Transpose every matrix in the stack, then apply it to the value vector.
    images = representation.swapaxes(1, 2) @ anchor
    # Weight each image by its embedded coefficient and reduce over the stack axis.
    return jnp.sum(images * embedded_pt[:, None], 0)


def func(x):
    """Embed ``x`` with ``U_1``, cube elementwise, then unembed with ``U_2``.

    Args:
        x: input vector handed to ``embed_pt``.

    Returns:
        The result of ``unembed_pt`` applied to the cubed embedding.
    """
    hidden = embed_pt(x, U_1)
    # Cubing is chosen to match bens_func; per the original author's note,
    # squaring makes the map trivial. A ReLU (jnp.maximum(hidden, 0.)) could be
    # used here instead for a non-polynomial equivariant function.
    activated = jnp.power(hidden, 3)
    return unembed_pt(activated, U_2)

def bens_func(x):
    """Cubic map ``x_i * (x_{i+2}^2 - x_{i+1}^2)`` with indices taken cyclically over 0..2.

    According to Ben, this is the unique cubic equivariant function (up to a
    scalar).

    Args:
        x: array whose first three entries along axis 0 are read.

    Returns:
        ``x`` multiplied elementwise by the stacked differences of squares.
    """
    sq0, sq1, sq2 = x[0] ** 2, x[1] ** 2, x[2] ** 2
    # Each coefficient is the difference of the squares of the *other* two coordinates.
    coeffs = jnp.stack([sq2 - sq1, sq0 - sq2, sq1 - sq0])
    return x * coeffs

In [167]:
# Run the same equivariance check (input rep U_1, output rep U_2) on both the
# hidden-layer map and Ben's closed-form map, using the test point x.
for _candidate in (func, bens_func):
    check_equivariant(_candidate, x, U_1, U_2)

Passed test!
Passed test!


In [168]:
# Checking whether the hidden-layer function matches the one Ben defined.
# If the two maps agree up to one global scalar, the elementwise ratio
# Y / ben_Y is the same constant everywhere, so its mean absolute deviation
# from its own mean should be close to zero.
X = random.normal(key=rngs[3], shape=(100, 3))
Y = vmap(func)(X)
ben_Y = vmap(bens_func)(X)

# Compute the ratio once instead of twice inside the f-string.
# NOTE(review): the ratio is ill-conditioned wherever ben_Y is near zero;
# the metric is kept as-is so the printed value is unchanged.
ratio = Y / ben_Y
mean_error = jnp.abs(ratio - ratio.mean()).mean()

print(f"mean error: {mean_error}")

mean error: 5.623499873763649e-06
