In [10]:
#@title General Imports

import os

import math
import operator
import functools
import itertools

import torch
from torch import nn
import torch.nn.functional as F

In [1]:
#@title Set up git custom lib from git repo

# Work from the filesystem root so the repository always lands at the same absolute path.
# NOTE(review): `os` comes from the "General Imports" cell, which has to be run first.
os.chdir('/')
# Clone only when the checkout is missing, so re-running this cell is cheap.
if not os.path.exists("/clifford-group-equivariant-neural-networks"):
    !git clone https://github.com/DavidRuhe/clifford-group-equivariant-neural-networks.git
# Continue inside the checkout; presumably this is what makes the later
# `models.*` / `algebra.*` imports resolvable -- TODO confirm.
os.chdir("/clifford-group-equivariant-neural-networks")

Cloning into 'clifford-group-equivariant-neural-networks'...
remote: Enumerating objects: 112, done.[K
remote: Counting objects: 100% (45/45), done.[K
remote: Compressing objects: 100% (22/22), done.[K
remote: Total 112 (delta 32), reused 23 (delta 23), pack-reused 67[K
Receiving objects: 100% (112/112), 349.95 KiB | 2.99 MiB/s, done.
Resolving deltas: 100% (38/38), done.


# Clifford Group Equivariant Neural Networks
![image](https://github.com/DavidRuhe/clifford-group-equivariant-neural-networks/raw/master/assets/figure.png)

# Links
📜 [ArXiV](https://arxiv.org/abs/2305.11141)

🖥️ [Github](https://github.com/DavidRuhe/clifford-group-equivariant-neural-networks)

🤓 [Blog Posts](https://davidruhe.github.io/)



# Utils

This section implements all the utility functions used to build the fundamental blocks of the Clifford Geometry NN modules.

In [None]:
def unsqueeze_like(tensor: torch.Tensor, like: torch.Tensor, dim=0):
    """
    Add singleton axes to `tensor` until it has as many dimensions as `like`.

    The new axes are inserted at position `dim`: the first `dim` axes of
    `tensor` stay where they are and all remaining axes move to the right.

    Args:
        tensor (torch.Tensor): tensor to expand.
        like (torch.Tensor): tensor whose number of dimensions is matched.
        dim (int): number of leading axes of `tensor` kept in front of the
            inserted axes, default: 0.

    Returns:
        `tensor` itself when both ranks already agree, otherwise the result of
        indexing `tensor` with the extra size-1 axes.

    Raises:
        ValueError: if `tensor` has more dimensions than `like`.
    """
    missing = like.ndim - tensor.ndim
    if missing < 0:
        raise ValueError(f"tensor.ndim={tensor.ndim} > like.ndim={like.ndim}")
    if missing == 0:
        return tensor

    keep_leading = (slice(None),) * dim
    new_axes = (None,) * missing
    return tensor[keep_leading + new_axes]


# copied from the itertools docs
def _powerset(iterable):
    "powerset([1,2,3]) --> () (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3)"
    s = list(iterable)
    return itertools.chain.from_iterable(
        itertools.combinations(s, r) for r in range(len(s) + 1)
    )


class ShortLexBasisBladeOrder:
    """
    Lookup tables for basis blades ordered by grade first, then by the order
    in which `_powerset` enumerates subsets of the same size.

    A blade is identified by a bitmap in which bit ``k`` is set when basis
    vector ``k`` is one of its factors.

    Attributes:
        index_to_bitmap: bitmap of the blade stored at each index.
        grades: number of basis vectors in the blade stored at each index.
        bitmap_to_index: index at which the blade with a given bitmap is stored.
    """

    def __init__(self, n_vectors):
        n_blades = 2**n_vectors
        self.index_to_bitmap = torch.empty(n_blades, dtype=int)
        self.grades = torch.empty(n_blades, dtype=int)
        self.bitmap_to_index = torch.empty(n_blades, dtype=int)

        # One single-bit mask per basis vector.
        basis_bits = [1 << k for k in range(n_vectors)]
        for index, members in enumerate(_powerset(basis_bits)):
            # OR-ing the single-bit masks gives the bitmap of the whole blade.
            bitmap = functools.reduce(operator.or_, members, 0)
            self.grades[index] = len(members)
            self.index_to_bitmap[index] = bitmap
            self.bitmap_to_index[bitmap] = index
            del members  # enables an optimization inside itertools.combinations


def set_bit_indices(x: int):
    """Yield, in ascending order, the positions of the bits of `x` that are 1.

    Nothing is yielded when `x` is zero or negative.
    """
    remaining = x
    for position in itertools.count():
        if not remaining > 0:
            return
        if remaining & 1:
            yield position
        # Plain (non in-place) shift, so the caller's value is never modified.
        remaining = remaining >> 1


def count_set_bits(bitmap: int) -> int:
    """Return how many bits of `bitmap` are set to 1."""
    return sum(1 for _ in set_bit_indices(bitmap))


def canonical_reordering_sign_euclidean(bitmap_a, bitmap_b):
    """
    Sign (+1 or -1) of the product of the blades `bitmap_a` and `bitmap_b`
    under a euclidean metric.

    The sign is the parity of an overlap count: `bitmap_a` is shifted right one
    bit at a time and, for every shift, the bits it shares with `bitmap_b` are
    counted.
    """
    shifted = bitmap_a >> 1
    overlap_total = 0
    while shifted != 0:
        overlap_total += count_set_bits(shifted & bitmap_b)
        shifted = shifted >> 1
    # An odd total flips the sign.
    return -1 if (overlap_total & 1) else 1


def canonical_reordering_sign(bitmap_a, bitmap_b, metric):
    """
    Sign of the product of the blades `bitmap_a` and `bitmap_b` for the given
    diagonal `metric`.

    Starts from the euclidean reordering sign and multiplies in `metric[k]`
    for every basis vector `k` that appears in both blades.
    """
    sign = canonical_reordering_sign_euclidean(bitmap_a, bitmap_b)
    shared = bitmap_a & bitmap_b
    axis = 0
    while shared != 0:
        if shared & 1:
            sign *= metric[axis]
        shared = shared >> 1
        axis = axis + 1
    return sign


def gmt_element(bitmap_a, bitmap_b, sig_array):
    """
    One entry of the geometric multiplication table for the blades a and b.
    The implementation used here is described in :cite:`ga4cs` chapter 19.

    Returns:
        (output_bitmap, output_sign): the resulting blade is the XOR of the two
        input bitmaps; the sign comes from `canonical_reordering_sign`.
    """
    return (
        bitmap_a ^ bitmap_b,
        canonical_reordering_sign(bitmap_a, bitmap_b, sig_array),
    )


def construct_gmt(index_to_bitmap, bitmap_to_index, signature):
    """
    Build the geometric multiplication (Cayley) table as a sparse COO tensor.

    For every ordered pair of blades (i, j) exactly one entry is stored at
    position ``[i, v, j]``, where ``v`` is the index of the blade produced by
    multiplying blade i with blade j and the stored value is the product's sign.

    Args:
        index_to_bitmap: bitmap of each blade, by blade index.
        bitmap_to_index: blade index of each bitmap.
        signature: diagonal metric entries, indexable by basis-vector number.

    Returns:
        Sparse tensor of size (n, n, n) with n = len(index_to_bitmap).
    """
    n_blades = len(index_to_bitmap)
    n_entries = int(n_blades * n_blades)

    # Row 0: left blade, row 1: output blade, row 2: right blade.
    indices = torch.zeros((3, n_entries), dtype=torch.int)
    signs = torch.zeros(n_entries)

    blade_pairs = itertools.product(range(n_blades), repeat=2)
    for entry, (left, right) in enumerate(blade_pairs):
        out_bitmap, sign = gmt_element(
            index_to_bitmap[left], index_to_bitmap[right], signature
        )
        indices[0, entry] = left
        indices[1, entry] = bitmap_to_index[out_bitmap]
        indices[2, entry] = right
        signs[entry] = sign

    return torch.sparse_coo_tensor(
        indices=indices, values=signs, size=(n_blades, n_blades, n_blades)
    )

# Dataset

## THIS PART MUST BE CHANGED
Let's first create a dataset. Before we can do so, we have to specify a dimensionality $d$.

**Note:** To make the notebook run a bit smoother, we implicitly define the dimension through the length of metric $M$, which we will discuss in a bit.

Feel free to adjust this parameter to your needs. Just be cautious: the dimensionality of the Clifford algebra (which we will define later) grows exponentially as $2^d$, so large $d$ will get computationally infeasible.

Next, we sample $N$ random vectors $u$ and $v$, and the categorical variables $y$.

In [2]:
import torch
from torch import nn, optim
import torch.nn.functional as F

In [3]:
# Diagonal entries of the metric; its length fixes the dimensionality d.
metric = [1, 1]
d = len(metric)
# Number of samples.
N = 1024
# Two d-dimensional vectors per sample: x[:, 0] and x[:, 1] (u and v in the text).
x = torch.randn(N, 2, d)
# One categorical label per sample; the upper bound of randint is exclusive, so y is 1 or 2.
y = torch.randint(1, 3, (N,))

In [4]:
# Look at a single sample: a (2, d) tensor holding its two vectors.
x[42]

tensor([[ 0.8873, -1.1112],
        [ 0.8208, -0.8280]])

In [5]:
# First coordinate of each of the two vectors of that sample.
x[42][:,0]

tensor([0.8873, 0.8208])

In [6]:
# Categorical labels drawn above (each entry is 1 or 2).
y

tensor([2, 1, 1,  ..., 2, 2, 1])

## The Metric

Now, to compute the target values using our function $f(u, v, y)$, we have to specify how we compute the vector norm and inner product.

A well-defined inner product uses a *metric*, which defines *how distances are calculated* in a certain setting.
Typically, that is, in the *Euclidean* setting, this metric is positive definite, e.g., $M:=\mathrm{diag\,}[1, 1]$ in the two-dimensional case.

A quadratic form here using matrix notation would look like

$$v \mapsto v ^\top M v,$$

or an inner product

$$(u, v) \mapsto u ^\top M v.$$

This means that it acts as the identity. Hence, it is usually not really taken into account by explicitly writing it down.

However, in more exotic settings such as the *Minkowski space* used in special relativity, we can have $M:=\mathrm{diag\,} [-1, 1, 1, 1]$. Computing inner products in such a space requires carefully taking into account this metric.

In the following, feel free to adjust the metric to your needs, its length should, however, match $d$.

In [7]:
assert len(metric) == d, f"The dimensionality d should match the metric length."

# Target values; filled in per category below.
f = torch.zeros(N)
metric_t = torch.tensor(metric, dtype=torch.float)
# y == 1: cos(sqrt(|u^T M u|)) with u = x[:, 0]; the einsum evaluates sum_i u_i * M_i * u_i per sample.
f[y == 1] = torch.cos(torch.einsum('bi, i, bi->b', x[y==1][:, 0], metric_t, x[y==1][:, 0]).abs().sqrt())
# y == 2: (u^T M v)^3 / 10 with v = x[:, 1]; `**` binds tighter than `*`, so the cube is taken before scaling.
f[y == 2] = 1/10 * torch.einsum('bi, i, bi->b', x[y==2][:, 0], metric_t, x[y==2][:, 1]) ** 3

Then we compute the function $f(u, v, y)$ for all our datapoints. Computing the vector norm and inner products, we take into account the metric, which can easily be done using `torch.einsum`.

In [8]:
# Targets of the samples whose label is 1.
f[y == 1]

tensor([ 4.5600e-01,  3.5077e-01,  7.6015e-01,  9.9527e-01,  7.8616e-01,
        -8.4514e-02,  9.3260e-01, -2.9277e-01, -8.1253e-01,  1.4767e-01,
         5.8321e-01,  4.2188e-01,  7.1053e-01,  1.0768e-01,  1.8702e-02,
         8.9654e-01, -5.6695e-01,  5.5888e-01,  4.4886e-01,  9.6224e-01,
         9.2613e-01,  6.9317e-01,  1.5226e-01,  5.3694e-01,  3.9556e-01,
        -9.3179e-01,  9.2151e-01,  4.8318e-01,  7.4320e-01,  3.7025e-01,
        -5.5383e-01,  6.9792e-01, -5.3469e-01, -3.7326e-01,  1.6219e-01,
         5.4283e-01,  4.4422e-01,  7.1341e-01, -5.5932e-01,  7.9261e-01,
         4.5433e-01,  5.8228e-01,  8.4501e-01,  5.2722e-01, -6.0635e-02,
        -3.1374e-01,  3.2908e-01, -9.9911e-01, -9.7567e-02, -1.0300e-02,
         3.8935e-01,  6.8634e-01,  6.4484e-01,  3.1521e-01, -2.3551e-02,
         5.4783e-01,  7.1550e-01,  2.2239e-01,  6.6890e-01,  8.9749e-01,
        -5.1104e-01,  5.7127e-02,  5.5373e-01, -9.9739e-01, -5.4467e-02,
         9.1742e-01,  6.2396e-01, -1.0957e-01,  4.4

In [9]:
# Cross-check of the einsum-based targets: recompute the y == 1 case with explicit Python loops.
v = x[y==1][:,0]

for b in range(v.shape[0]):
    # Quadratic form sum_i v_i * M_i * v_i for sample b.
    res = 0
    for i in range(v.shape[1]):
        res += v[b,i] * metric_t[i] * v[b,i]

    print(torch.cos(res.abs().sqrt()))

tensor(0.4560)
tensor(0.3508)
tensor(0.7601)
tensor(0.9953)
tensor(0.7862)
tensor(-0.0845)
tensor(0.9326)
tensor(-0.2928)
tensor(-0.8125)
tensor(0.1477)
tensor(0.5832)
tensor(0.4219)
tensor(0.7105)
tensor(0.1077)
tensor(0.0187)
tensor(0.8965)
tensor(-0.5669)
tensor(0.5589)
tensor(0.4489)
tensor(0.9622)
tensor(0.9261)
tensor(0.6932)
tensor(0.1523)
tensor(0.5369)
tensor(0.3956)
tensor(-0.9318)
tensor(0.9215)
tensor(0.4832)
tensor(0.7432)
tensor(0.3703)
tensor(-0.5538)
tensor(0.6979)
tensor(-0.5347)
tensor(-0.3733)
tensor(0.1622)
tensor(0.5428)
tensor(0.4442)
tensor(0.7134)
tensor(-0.5593)
tensor(0.7926)
tensor(0.4543)
tensor(0.5823)
tensor(0.8450)
tensor(0.5272)
tensor(-0.0606)
tensor(-0.3137)
tensor(0.3291)
tensor(-0.9991)
tensor(-0.0976)
tensor(-0.0103)
tensor(0.3893)
tensor(0.6863)
tensor(0.6448)
tensor(0.3152)
tensor(-0.0236)
tensor(0.5478)
tensor(0.7155)
tensor(0.2224)
tensor(0.6689)
tensor(0.8975)
tensor(-0.5110)
tensor(0.0571)
tensor(0.5537)
tensor(-0.9974)
tensor(-0.0545)
tensor(

# Clifford Algebra Operations

In [None]:
class CliffordAlgebra(nn.Module):
    """
    Clifford (geometric) algebra over a diagonal metric.

    Multivectors are plain tensors whose last dimension indexes the
    ``2**dim`` basis blades. Blades are ordered by grade first (see
    ``ShortLexBasisBladeOrder``), so every grade occupies one contiguous slice
    of that last dimension.

    Args:
        metric: diagonal metric entries, one per basis vector; its length is
            the dimension of the underlying vector space.

    Buffers:
        metric: the metric as a tensor.
        subspaces: number of blades in each grade.
        bbo_grades: grade of every blade, in the default floating dtype.
        even_grades / odd_grades: boolean masks over blades by grade parity.
        cayley: dense multiplication table indexed as
            ``[left blade, output blade, right blade]``.
    """

    def __init__(self, metric):
        super().__init__()

        self.register_buffer("metric", torch.as_tensor(metric))
        self.num_bases = len(metric)
        self.bbo = ShortLexBasisBladeOrder(self.num_bases)
        self.dim = len(self.metric)
        self.n_blades = len(self.bbo.grades)
        cayley = (
            construct_gmt(
                self.bbo.index_to_bitmap, self.bbo.bitmap_to_index, self.metric
            )
            .to_dense()
            .to(torch.get_default_dtype())
        )
        # Distinct grades present in the algebra.
        self.grades = self.bbo.grades.unique()
        self.register_buffer(
            "subspaces",
            torch.tensor(tuple(math.comb(self.dim, g) for g in self.grades)),
        )
        self.n_subspaces = len(self.grades)
        self.grade_to_slice = self._grade_to_slice(self.subspaces)
        # Same information as grade_to_slice, but as explicit blade-index tensors.
        self.grade_to_index = [
            torch.tensor(range(*s.indices(s.stop))) for s in self.grade_to_slice
        ]

        self.register_buffer(
            "bbo_grades", self.bbo.grades.to(torch.get_default_dtype())
        )
        self.register_buffer("even_grades", self.bbo_grades % 2 == 0)
        self.register_buffer("odd_grades", ~self.even_grades)
        self.register_buffer("cayley", cayley)

    def geometric_product(self, a, b, blades=None):
        """
        Geometric product of the multivectors `a` and `b`.

        Args:
            a, b: tensors whose last dimension indexes blades.
            blades: optional triple of index tensors
                ``(left blades, output blades, right blades)``. When given,
                only that sub-block of the Cayley table is used, so the last
                dimensions of `a`, the result and `b` correspond to those
                blade selections.
        """
        cayley = self.cayley

        if blades is not None:
            blades_l, blades_o, blades_r = blades
            assert isinstance(blades_l, torch.Tensor)
            assert isinstance(blades_o, torch.Tensor)
            assert isinstance(blades_r, torch.Tensor)
            # Broadcasted advanced indexing selects the (left, out, right) sub-table.
            cayley = cayley[blades_l[:, None, None], blades_o[:, None], blades_r]

        return torch.einsum("...i,ijk,...k->...j", a, cayley, b)

    def _grade_to_slice(self, subspaces):
        """Return, for every grade, the slice of the blade axis it occupies."""
        grade_to_slice = list()
        subspaces = torch.as_tensor(subspaces)
        for grade in self.grades:
            # Blades of all lower grades come first.
            index_start = subspaces[:grade].sum()
            index_end = index_start + math.comb(self.dim, grade)
            grade_to_slice.append(slice(index_start, index_end))
        return grade_to_slice

    # NOTE: the three sign tables below are cached on first access; they keep
    # the dtype/device `bbo_grades` had at that moment.
    @functools.cached_property
    def _alpha_signs(self):
        """Per-blade sign (-1)**k for a blade of grade k."""
        return torch.pow(-1, self.bbo_grades)

    @functools.cached_property
    def _beta_signs(self):
        """Per-blade sign (-1)**(k(k-1)/2) for a blade of grade k."""
        return torch.pow(-1, self.bbo_grades * (self.bbo_grades - 1) // 2)

    @functools.cached_property
    def _gamma_signs(self):
        """Per-blade sign (-1)**(k(k+1)/2) for a blade of grade k."""
        return torch.pow(-1, self.bbo_grades * (self.bbo_grades + 1) // 2)

    def alpha(self, mv, blades=None):
        """Multiply each component by (-1)**grade (the main involution).

        `blades` optionally lists the blade indices the last axis of `mv`
        corresponds to.
        """
        signs = self._alpha_signs
        if blades is not None:
            signs = signs[blades]
        return signs * mv.clone()

    def beta(self, mv, blades=None):
        """Multiply each component by (-1)**(k(k-1)/2), k its grade (reversion).

        `blades` optionally lists the blade indices the last axis of `mv`
        corresponds to.
        """
        signs = self._beta_signs
        if blades is not None:
            signs = signs[blades]
        return signs * mv.clone()

    def gamma(self, mv, blades=None):
        """Multiply each component by (-1)**(k(k+1)/2), k its grade
        (Clifford conjugation).

        `blades` optionally lists the blade indices the last axis of `mv`
        corresponds to.
        """
        signs = self._gamma_signs
        if blades is not None:
            signs = signs[blades]
        return signs * mv.clone()

    def zeta(self, mv):
        """Return the first (scalar) component, keeping a last axis of size 1."""
        return mv[..., :1]

    def embed(self, tensor: torch.Tensor, tensor_index: torch.Tensor) -> torch.Tensor:
        """Create a zero multivector (same dtype/device as `tensor`) and write
        `tensor` into the blade positions `tensor_index`."""
        mv = torch.zeros(
            *tensor.shape[:-1], 2**self.dim, device=tensor.device, dtype=tensor.dtype
        )
        mv[..., tensor_index] = tensor
        return mv

    def embed_grade(self, tensor: torch.Tensor, grade: int) -> torch.Tensor:
        """Create a zero multivector and write `tensor` into the slots of `grade`.

        Unlike `embed`, the result uses the default dtype rather than
        `tensor.dtype`, so e.g. integer inputs come out as floating point.
        """
        mv = torch.zeros(*tensor.shape[:-1], 2**self.dim, device=tensor.device)
        s = self.grade_to_slice[grade]
        mv[..., s] = tensor
        return mv

    def get(self, mv: torch.Tensor, blade_index: tuple[int]) -> torch.Tensor:
        """Select the components of `mv` at the given blade indices."""
        blade_index = tuple(blade_index)
        return mv[..., blade_index]

    def get_grade(self, mv: torch.Tensor, grade: int) -> torch.Tensor:
        """Select the components of `mv` that belong to `grade`."""
        s = self.grade_to_slice[grade]
        return mv[..., s]

    def b(self, x, y, blades=None):
        """
        Scalar component of the geometric product ``beta(x) * y``.

        Args:
            x, y: multivectors.
            blades: optional pair ``(blades of x, blades of y)`` when the
                inputs only carry a subset of the blades.

        Returns:
            Tensor with a last axis of size 1 (only output blade 0 is computed).
        """
        if blades is not None:
            assert len(blades) == 2
            beta_blades = blades[0]
            blades = (
                blades[0],
                torch.tensor([0]),
                blades[1],
            )
        else:
            blades = torch.tensor(range(self.n_blades))
            blades = (
                blades,
                torch.tensor([0]),
                blades,
            )
            beta_blades = None

        return self.geometric_product(
            self.beta(x, blades=beta_blades),
            y,
            blades=blades,
        )

    def q(self, mv, blades=None):
        """Quadratic form ``b(mv, mv)``; `blades` lists the blades `mv` carries."""
        if blades is not None:
            blades = (blades, blades)
        return self.b(mv, mv, blades=blades)

    def _smooth_abs_sqrt(self, input, eps=1e-16):
        """Smooth version of sqrt(|input|): (input**2 + eps) ** 0.25."""
        return (input**2 + eps) ** 0.25

    def norm(self, mv, blades=None):
        """Smoothed sqrt(|q(mv)|)."""
        return self._smooth_abs_sqrt(self.q(mv, blades=blades))

    def norms(self, mv, grades=None):
        """List with the norm of each grade component of `mv` (all grades by default)."""
        if grades is None:
            grades = self.grades
        return [
            self.norm(self.get_grade(mv, grade), blades=self.grade_to_index[grade])
            for grade in grades
        ]

    def qs(self, mv, grades=None):
        """List with the quadratic form of each grade component of `mv`
        (all grades by default)."""
        if grades is None:
            grades = self.grades
        return [
            self.q(self.get_grade(mv, grade), blades=self.grade_to_index[grade])
            for grade in grades
        ]

    def sandwich(self, u, v, w):
        """Geometric product (u v) w."""
        return self.geometric_product(self.geometric_product(u, v), w)

    def output_blades(self, blades_left, blades_right):
        """Blade indices produced by multiplying every left blade with every
        right blade, in row-major (left, right) order."""
        blades = []
        for blade_left in blades_left:
            for blade_right in blades_right:
                bitmap_left = self.bbo.index_to_bitmap[blade_left]
                bitmap_right = self.bbo.index_to_bitmap[blade_right]
                bitmap_out, _ = gmt_element(bitmap_left, bitmap_right, self.metric)
                index_out = self.bbo.bitmap_to_index[bitmap_out]
                blades.append(index_out)

        return torch.tensor(blades)

    def random(self, n=None):
        """Sample `n` (default 1) multivectors with standard-normal components."""
        if n is None:
            n = 1
        # Created on the module's device, consistent with random_vector().
        return torch.randn(n, self.n_blades, device=self.cayley.device)

    def random_vector(self, n=None):
        """Sample `n` (default 1) multivectors whose only non-zero components
        are standard-normal grade-1 entries."""
        if n is None:
            n = 1
        vector_indices = self.bbo_grades == 1
        v = torch.zeros(n, self.n_blades, device=self.cayley.device)
        v[:, vector_indices] = torch.randn(
            n, vector_indices.sum(), device=self.cayley.device
        )
        return v

    def parity(self, mv):
        """
        Return a boolean tensor that is True when `mv` is purely odd and False
        when it is purely even.

        Raises:
            ValueError: if `mv` mixes even and odd grades, or is entirely zero
                (then it is both "even" and "odd").
        """
        is_odd = torch.all(mv[..., self.even_grades] == 0)
        is_even = torch.all(mv[..., self.odd_grades] == 0)

        if is_odd ^ is_even:  # exclusive or (xor)
            return is_odd
        else:
            raise ValueError("This is not a homogeneous element.")

    def eta(self, w):
        """(-1) ** parity(w)."""
        return (-1) ** self.parity(w)

    def alpha_w(self, w, mv):
        """Keep the even part of `mv` and multiply its odd part by eta(w)."""
        return self.even_grades * mv + self.eta(w) * self.odd_grades * mv

    def inverse(self, mv, blades=None):
        """
        Compute ``beta(mv) / q(mv)``.

        `blades` optionally lists the blade indices the last axis of `mv`
        corresponds to; it is forwarded to both `beta` and `q`, so numerator
        and denominator are evaluated on the same blade selection.
        """
        mv_ = self.beta(mv, blades=blades)
        # Fix: q() previously ignored `blades` and was evaluated against the
        # full Cayley table even for blade-restricted inputs.
        return mv_ / self.q(mv, blades=blades)

    def rho(self, w, mv):
        """Applies the versor w action to mv."""
        return self.sandwich(w, self.alpha_w(w, mv), self.inverse(w))

    def reduce_geometric_product(self, inputs):
        """Left-fold the geometric product over `inputs`."""
        return functools.reduce(self.geometric_product, inputs)

    def versor(self, order=None, normalized=True):
        """
        Sample a versor as the geometric product of `order` random vectors.

        Args:
            order: number of vectors multiplied; defaults to `dim` when `dim`
                is even and `dim - 1` otherwise.
            normalized: divide the result by its norm.
        """
        if order is None:
            order = self.dim if self.dim % 2 == 0 else self.dim - 1
        vectors = self.random_vector(order)
        # vectors[:, None] iterates as `order` tensors of shape (1, n_blades).
        versor = self.reduce_geometric_product(vectors[:, None])
        if normalized:
            versor = versor / self.norm(versor)[..., :1]
        return versor

    def rotor(self):
        """Sample a versor with the default (even) order."""
        return self.versor()

    @functools.cached_property
    def geometric_product_paths(self):
        """
        Boolean tensor of shape (dim + 1, dim + 1, dim + 1); entry ``[i, j, k]``
        is True when the Cayley table has a non-zero entry for
        left grade i, output grade j and right grade k.
        """
        gp_paths = torch.zeros((self.dim + 1, self.dim + 1, self.dim + 1), dtype=bool)

        for i in range(self.dim + 1):
            for j in range(self.dim + 1):
                for k in range(self.dim + 1):
                    s_i = self.grade_to_slice[i]
                    s_j = self.grade_to_slice[j]
                    s_k = self.grade_to_slice[k]

                    m = self.cayley[s_i, s_j, s_k]
                    gp_paths[i, j, k] = (m != 0).any()

        return gp_paths

# Custom CG Layers

In [None]:
#@title Linear Layer

class MVLinear(nn.Module):
    """
    Linear map that mixes multivector channels.

    Inputs are indexed as ``[batch, in_features, ..., blade]`` and outputs as
    ``[batch, out_features, ..., blade]``.

    Args:
        algebra: object providing `n_subspaces`, `subspaces` and `embed`.
        in_features: number of input channels.
        out_features: number of output channels.
        subspaces: when True, a separate weight is learned per grade and
            shared by all blades of that grade; when False, one weight per
            (out, in) channel pair is shared by all blades.
        bias: when True, a learnable bias is added to blade 0 of every
            output channel.
    """

    def __init__(
        self,
        algebra,
        in_features,
        out_features,
        subspaces=True,
        bias=True,
    ):
        super().__init__()

        self.algebra = algebra
        self.in_features = in_features
        self.out_features = out_features
        self.subspaces = subspaces

        weight_shape = (out_features, in_features)
        if subspaces:
            weight_shape = weight_shape + (algebra.n_subspaces,)
            # Per-instance override of the blade-shared default `_forward`.
            self._forward = self._forward_subspaces
        self.weight = nn.Parameter(torch.empty(*weight_shape))

        if not bias:
            self.register_parameter("bias", None)
            self.b_dims = ()
        else:
            self.bias = nn.Parameter(torch.empty(1, out_features, 1))
            # Blade index the bias is embedded at.
            self.b_dims = (0,)

        self.reset_parameters()

    def reset_parameters(self):
        """Draw the weights from N(0, 1 / in_features) and zero the bias."""
        nn.init.normal_(self.weight, std=1 / math.sqrt(self.in_features))

        if self.bias is None:
            return
        nn.init.zeros_(self.bias)

    def _forward(self, input):
        """Channel mixing with one weight shared by all blades."""
        return torch.einsum("bm...i, nm->bn...i", input, self.weight)

    def _forward_subspaces(self, input):
        """Channel mixing with one weight per grade."""
        # Expand the per-grade weights to one weight per blade.
        per_blade_weight = self.weight.repeat_interleave(
            self.algebra.subspaces, dim=-1
        )
        return torch.einsum("bm...i, nmi->bn...i", input, per_blade_weight)

    def forward(self, input):
        """Apply the linear map and, if present, add the embedded bias."""
        output = self._forward(input)
        if self.bias is None:
            return output

        embedded_bias = self.algebra.embed(self.bias, self.b_dims)
        output += unsqueeze_like(embedded_bias, output, dim=2)
        return output

# Model
Now that we created the inputs and targets for our model to learn, we specify the model. To do so, we first import the necessary objects, and create the `CliffordAlgebra` object which takes as input the metric.

In [None]:
from models.modules.linear import MVLinear
from models.modules.gp import SteerableGeometricProductLayer
from models.modules.mvlayernorm import MVLayerNorm
from models.modules.mvsilu import MVSiLU

Let's create a Clifford Group Equivariant model block.

In [None]:
class CGEBlock(nn.Module):
    """
    One Clifford-group-equivariant block: `MVLinear`, `MVSiLU`,
    `SteerableGeometricProductLayer` and `MVLayerNorm`, applied in that order.

    Args:
        algebra: algebra object handed to every sub-layer.
        in_features: number of input multivector channels.
        out_features: number of output multivector channels.
    """

    def __init__(self, algebra, in_features, out_features):
        super().__init__()

        stages = (
            MVLinear(algebra, in_features, out_features),
            MVSiLU(algebra, out_features),
            SteerableGeometricProductLayer(algebra, out_features),
            MVLayerNorm(algebra, out_features),
        )
        self.layers = nn.Sequential(*stages)

    def forward(self, input):
        """Map [batch_size, in_features, 2**d] to [batch_size, out_features, 2**d]."""
        return self.layers(input)

This is a feedforward transformation, going from `in_features` to `out_features` (similar to the usual neural network setting) using a few modules.

First, we apply a linear transformation (`MVLinear`) followed by a nonlinear `MVSiLU` layer, a multivector analogue of the `SiLU` or *swish* activation. Other equivariant nonlinearities are also possible, but have not been implemented yet.

Then we apply a `SteerableGeometricProductLayer`, which computes geometric products between our input data in a manner similar to *self-attention* (see the implementation of the block for more details).

Finally, we apply a multivector analogue of LayerNorm: `MVLayerNorm`.

In [None]:
class CGEMLP(nn.Module):
    """
    Stack of `n_layers` `CGEBlock`s.

    All blocks except the last map to `hidden_features`; the last one maps to
    `out_features`.

    Args:
        algebra: algebra object handed to every block.
        in_features: number of input multivector channels.
        hidden_features: channel width between blocks.
        out_features: number of output multivector channels.
        n_layers: total number of blocks, default: 2.
    """

    def __init__(self, algebra, in_features, hidden_features, out_features, n_layers=2):
        super().__init__()

        layers = []
        for i in range(n_layers - 1):
            layers.append(
                CGEBlock(algebra, in_features, hidden_features)
            )
            in_features = hidden_features

        # `in_features` now holds the width feeding the last block: the hidden
        # width after at least one hidden block, otherwise the real input
        # width. (Previously `hidden_features` was used unconditionally, which
        # built a mismatched block for n_layers=1.)
        layers.append(
            CGEBlock(algebra, in_features, out_features)
        )
        self.layers = nn.Sequential(*layers)

    def forward(self, input):
        """Run the input through all blocks in order."""
        return self.layers(input)

We simply chain the `CGEBlocks` to get an expressive architecture.

Finally, we specify our final model class, which we call `InvariantCGENN`. Since we are approximating an $O(n)$-*invariant* function, we can do some invariant post-processing using a regular `MLP`.

This is shown in the `forward` function. We first process the input through the `CGEMLP`, which will output *multivector*-valued hidden states. The *grade-0* (scalar) subspace of a multivector is always *invariant*. Hence, we index the last dimension of the hidden states at 0 and let an `MLP` process this. The final output remains invariant.

In [None]:
class InvariantCGENN(nn.Module):
    """
    Invariant network: an equivariant `CGEMLP` followed by an ordinary MLP
    that only sees blade 0 (the scalar part) of the hidden multivectors.

    Args:
        in_features: number of input multivector channels.
        hidden_features: channel width of the `CGEMLP` and of the MLP.
        out_features: size of the final output.
        algebra: algebra object passed to the `CGEMLP`. When omitted, the
            notebook-level global `ca` is used, which therefore has to exist
            before the model is constructed.
    """

    def __init__(self, in_features, hidden_features, out_features, algebra=None):
        super().__init__()
        if algebra is None:
            # Backward-compatible fallback to the notebook global.
            algebra = ca
        self.cgemlp = CGEMLP(algebra, in_features, hidden_features, hidden_features)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_features, hidden_features),
            nn.ReLU(),
            nn.Linear(hidden_features, hidden_features),
            nn.ReLU(),
            nn.Linear(hidden_features, out_features)
        )

    def forward(self, input):
        """Return the MLP output computed from the scalar parts of the hidden states."""
        h = self.cgemlp(input)
        # Index the hidden states at 0 to get the invariants, and let a regular MLP do the final processing.
        return self.mlp(h[..., 0])

# Data Embedding

Now we have to prepare the data to be processed by the network.

To do so, we first create the `CliffordAlgebra` object, providing us with the necessary Clifford algebra operations, such as the *geometric product*.



In [None]:
# NOTE(review): this import rebinds `CliffordAlgebra` to the repository's implementation,
# so the class defined earlier in this notebook is not the one instantiated below.
from algebra.cliffordalgebra import CliffordAlgebra
# Algebra built from the metric chosen in the dataset section.
ca = CliffordAlgebra(metric)
ca

CliffordAlgebra()

Then, recall that the two vector features $u, v \in \mathbb{R}^n$. The *first* Clifford subspace (or grade 1) is the vector subspace. We also write $\mathrm{Cl}^{(1)}(V, q) = \mathbb{R}^n$.

Therefore:

In [None]:
# Write the vectors into the grade-1 slots of a multivector; all other components stay zero.
x_cl = ca.embed_grade(x, 1)
print(x_cl.shape)
print(x_cl[0])

torch.Size([1024, 2, 4])
tensor([[ 0.0000, -1.6215,  0.4577,  0.0000],
        [ 0.0000, -1.1135,  0.6075,  0.0000]])


Note that the final shape increased to $2^d$. That's because we added the other Clifford subspaces (they are set to zero initially).


Next, we *one-hot* encode the categorical variable $y$.

In [None]:
# y takes the values 1 and 2, so shift to 0-based class indices before one-hot encoding.
y_oh = F.one_hot(y - 1, 2)
y_oh

tensor([[0, 1],
        [0, 1],
        [0, 1],
        ...,
        [0, 1],
        [1, 0],
        [1, 0]])

This categorical variable can be regarded as a scalar (invariant). Hence, we embed it in the zero grade, i.e., $\mathrm{Cl}^{(0)}(V, q)$.

In [None]:
# The added trailing axis of size 1 lets each one-hot entry fill the single grade-0 (scalar) slot.
y_cl = ca.embed_grade(y_oh[..., None], 0)
print(y_cl.shape)
print(y_cl[0])

torch.Size([1024, 2, 4])
tensor([[0., 0., 0., 0.],
        [1., 0., 0., 0.]])


Note that only the scalar part has a 0 or 1 entry.

We now concatenate to get our model input.

In [None]:
# Concatenate along the feature axis: 2 vector features + 2 one-hot scalar features.
input_cl = torch.cat([x_cl, y_cl], dim=1)

# Training

Now we can train our model using regular backpropagation. We specify the input features, hidden features and output features and use the Adam optimizer.

In [None]:
# 4 input features (the channels of input_cl), 32 hidden features, 1 output value.
model = InvariantCGENN(4, 32, 1)

print(f"The model has {sum(p.numel() for p in model.parameters())} parameters.\n")
adam = optim.Adam(model.parameters())

# Full-batch training: every step uses all of input_cl.
for i in range(256):

    output = model(input_cl)
    # Drop the trailing size-1 output axis so the prediction lines up with the targets f.
    loss = F.mse_loss(output.squeeze(-1), f)

    adam.zero_grad()
    loss.backward()
    adam.step()

    # Report the loss every fourth step.
    if i % 4 == 0:
        print(f"Step: {i}. Loss: {loss.item():.2f}")

The model has 19297 parameters.

Step: 0. Loss: 1.09
Step: 4. Loss: 1.01
Step: 8. Loss: 0.94
Step: 12. Loss: 0.87
Step: 16. Loss: 0.80
Step: 20. Loss: 0.74
Step: 24. Loss: 0.67
Step: 28. Loss: 0.60
Step: 32. Loss: 0.53
Step: 36. Loss: 0.45
Step: 40. Loss: 0.37
Step: 44. Loss: 0.30
Step: 48. Loss: 0.24
Step: 52. Loss: 0.19
Step: 56. Loss: 0.14
Step: 60. Loss: 0.10
Step: 64. Loss: 0.07
Step: 68. Loss: 0.04
Step: 72. Loss: 0.03
Step: 76. Loss: 0.02
Step: 80. Loss: 0.01
Step: 84. Loss: 0.01
Step: 88. Loss: 0.01
Step: 92. Loss: 0.01
Step: 96. Loss: 0.01
Step: 100. Loss: 0.00
Step: 104. Loss: 0.00
Step: 108. Loss: 0.00
Step: 112. Loss: 0.00
Step: 116. Loss: 0.00
Step: 120. Loss: 0.00
Step: 124. Loss: 0.00
Step: 128. Loss: 0.00
Step: 132. Loss: 0.00
Step: 136. Loss: 0.00
Step: 140. Loss: 0.00
Step: 144. Loss: 0.00
Step: 148. Loss: 0.00
Step: 152. Loss: 0.00
Step: 156. Loss: 0.00
Step: 160. Loss: 0.00
Step: 164. Loss: 0.00
Step: 168. Loss: 0.00
Step: 172. Loss: 0.00
Step: 176. Loss: 0.00
Step:

# Equivariance Assessment

Let's get a rotated version of our inputs! We can use the Clifford algebra to sample a random rotation. This is done by sampling a *versor*. A $k$-versor parameterizes an orthogonal transformation. For $k=1$ we have a reflection, for $k=2$ a rotation, for $k=3$ a rotoreflection, and so on.

To do a good equivariance assessment, we need to take care of numerical precision issues which are solved by going to `torch.float64`.



In [None]:
# Move the default dtype, the algebra, the model and the inputs to float64 so the
# allclose checks below are not limited by float32 round-off.
torch.set_default_dtype(torch.float64)
ca = ca.double()
model = model.double()
input_cl = input_cl.double()

In [None]:
# NOTE(review): no random seed is set in the visible cells, so `w` differs between runs.
w = ca.versor(2)  # A rotation can be obtained by composing two reflections, i.e., a 2-versor.

We apply the group action $\rho(w): \mathrm{Cl}(V, q) \to \mathrm{Cl}(V, q)$ to the input.

In [None]:
# Transformed copy of the inputs: the action of w applied to every input multivector.
input_cl_w = ca.rho(w, input_cl)

We first start with the hidden output states of the Clifford MLP, which are not invariant, but *equivariant*. That is, they should *corotate* with our input.



In [None]:
# Hidden multivectors of the equivariant part, for the original and for the transformed inputs.
h_cemlp = model.cgemlp(input_cl)
h_w_cemlp = model.cgemlp(input_cl_w)

In [None]:
# Equivariance: transforming the hidden states must match the hidden states of the transformed inputs.
assert torch.allclose(ca.rho(w, h_cemlp), h_w_cemlp)

The equivariance checks out! Note that these hidden representations are *multivectors*. I.e., they have the following shapes.

In [None]:
print(h_cemlp.shape)  # [batch size, hidden_features, 2 ** d]
# All blade components of the first hidden channel of the first sample.
print(h_cemlp[0, 0])

torch.Size([1024, 32, 4])
tensor([ 1.6597,  0.1700, -0.0447, -0.0433], grad_fn=<SelectBackward0>)


Now we investigate if the final prediction really is invariant.

In [None]:
# Invariance: the final prediction must be the same for the original and the transformed inputs.
assert torch.allclose(model(input_cl), model(input_cl_w))


😍

# Closing Remarks
**Try running the notebook again with a different metric, inducing a different dimensionality**. Some common ones are

`[1, 1]`: 2D Euclidean geometry.

`[1, 1, 1]` 3D Euclidean geometry.

`[1, 1, 1, 1]` 3D Elliptic geometry.

`[1, 1, 1, -1]` 3D hyperbolic geometry.

`[1, 1, 1, -1]` 4D flat spacetime.

Feel free to use these layers in your specialized architecture. In our experiments, we use them as message and update networks in equivariant *graph neural networks*. Check out the main repo for more details!


# Links
📜 [ArXiV](https://arxiv.org/abs/2305.11141)

🖥️ [Github](https://github.com/DavidRuhe/clifford-group-equivariant-neural-networks)

🤓 [Blog Posts](https://davidruhe.github.io/)



---
*Thanks to Chase van de Geijn and Tin Hadži Veljković for providing feedback to this notebook.*
