In [138]:
import torch
from torch import nn, optim
import torch.nn.functional as F
import matplotlib.pyplot as plt

from torchvision import datasets, transforms

from algebra.cliffordalgebra import CliffordAlgebra
from models.modules.gp import SteerableGeometricProductLayer
from models.modules.linear import MVLinear
from models.modules.mvlayernorm import MVLayerNorm
from models.modules.mvsilu import MVSiLU

mnist = datasets.MNIST('~/datasets/', transform=transforms.ToTensor(), download=True)
image = mnist[0][0]

In [139]:
ca = CliffordAlgebra((1., 1.))

In [140]:
size = 28

In [141]:
x = torch.stack(torch.meshgrid(torch.linspace(-1, 1, size), torch.linspace(-1, 1, size), indexing='xy'), -1).reshape(-1, 2)

In [146]:
z1 = torch.tensor([0., 1.])
z2 = torch.tensor([1., 0.])

Z = torch.stack([z1, z2])


In [147]:
x_cl = ca.embed_grade(x, 1)
z1_cl = ca.embed_grade(z1, 1)
z2_cl = ca.embed_grade(z2, 1)

Z_cl = ca.embed_grade(Z, 1)

In [154]:
E = ca.geometric_product(
    x_cl[:, None].repeat(1, len(Z_cl), 1),
    Z_cl[None].repeat(len(x_cl), 1, 1)
)[..., 0]


In [156]:
emb_z1 = ca.geometric_product(x_cl, z1_cl)[..., 0]
emb_z1

tensor([-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -0.9259, -0.9259, -0.9259, -0.9259,
        -0.9259, -0.9259, -0.9259, -0.9259, -0.9259, -0.9259, -0.9259, -0.9259,
        -0.9259, -0.9259, -0.9259, -0.9259, -0.9259, -0.9259, -0.9259, -0.9259,
        -0.9259, -0.9259, -0.9259, -0.9259, -0.9259, -0.9259, -0.9259, -0.9259,
        -0.8519, -0.8519, -0.8519, -0.8519, -0.8519, -0.8519, -0.8519, -0.8519,
        -0.8519, -0.8519, -0.8519, -0.8519, -0.8519, -0.8519, -0.8519, -0.8519,
        -0.8519, -0.8519, -0.8519, -0.8519, -0.8519, -0.8519, -0.8519, -0.8519,
        -0.8519, -0.8519, -0.8519, -0.8519, -0.7778, -0.7778, -0.7778, -0.7778,
        -0.7778, -0.7778, -0.7778, -0.7778, -0.7778, -0.7778, -0.7778, -0.7778,
        -0.7778, -0.7778, -0.7778, -0.77

In [157]:
emb_z2 = ca.geometric_product(x_cl, z2_cl)[..., 0]
emb_z2

tensor([-1.0000, -0.9259, -0.8519, -0.7778, -0.7037, -0.6296, -0.5556, -0.4815,
        -0.4074, -0.3333, -0.2593, -0.1852, -0.1111, -0.0370,  0.0370,  0.1111,
         0.1852,  0.2593,  0.3333,  0.4074,  0.4815,  0.5556,  0.6296,  0.7037,
         0.7778,  0.8519,  0.9259,  1.0000, -1.0000, -0.9259, -0.8519, -0.7778,
        -0.7037, -0.6296, -0.5556, -0.4815, -0.4074, -0.3333, -0.2593, -0.1852,
        -0.1111, -0.0370,  0.0370,  0.1111,  0.1852,  0.2593,  0.3333,  0.4074,
         0.4815,  0.5556,  0.6296,  0.7037,  0.7778,  0.8519,  0.9259,  1.0000,
        -1.0000, -0.9259, -0.8519, -0.7778, -0.7037, -0.6296, -0.5556, -0.4815,
        -0.4074, -0.3333, -0.2593, -0.1852, -0.1111, -0.0370,  0.0370,  0.1111,
         0.1852,  0.2593,  0.3333,  0.4074,  0.4815,  0.5556,  0.6296,  0.7037,
         0.7778,  0.8519,  0.9259,  1.0000, -1.0000, -0.9259, -0.8519, -0.7778,
        -0.7037, -0.6296, -0.5556, -0.4815, -0.4074, -0.3333, -0.2593, -0.1852,
        -0.1111, -0.0370,  0.0370,  0.11

In [158]:
emb = torch.stack([emb_z1, emb_z2], -1)



tensor([[0., 0.],
        [0., 0.],
        [0., 0.],
        ...,
        [0., 0.],
        [0., 0.],
        [0., 0.]])

In [129]:
flat_tensor = emb.clone().float()  # Ensure it's in a dtype that supports decimal points
for i in range(emb.size(1)):
    flat_tensor[:, i] += (i * 0.1)

# Sort the flattened tensor, then use the indices to sort the original tensor
_, sorted_indices = flat_tensor.sum(dim=1).sort()
sorted_tensor = emb[sorted_indices]


In [130]:
assert not torch.all(sorted_tensor[1:] == sorted_tensor[:-1])

In [131]:
sorted_tensor

tensor([[-1.0000, -1.0000],
        [-1.0000, -0.9259],
        [-0.9259, -1.0000],
        ...,
        [ 0.9259,  1.0000],
        [ 1.0000,  0.9259],
        [ 1.0000,  1.0000]])