In [8]:
import torch
import torch.nn as nn
from torchga.torchga import GeometricAlgebra
from torchga.layers import (
    GeometricProductDense,
    TensorToGeometric,
    GeometricProductConv1D,
    GeometricProductElementwise,
    GeometricSandwichProductDense,
    GeometricSandwichProductElementwise
)

In [2]:
# Seed torch's global RNG so the random test input below (and any layer weight
# initialisation in later cells) is reproducible on Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

batch = 4
modalities = 3
# Random stand-in data: one 4x4x4 grid per (batch item, modality), flattened
# into a single embedding per modality -> (batch, modalities, 64).
test_input = torch.rand([batch, modalities, 4, 4, 4]).flatten(start_dim=2)
embedding_length = test_input.shape[-1]

# 3D algebra with an all-positive metric.
algebra_metric = [1, 1, 1]
algebra = GeometricAlgebra(algebra_metric)

# One target blade index per modality. Later cells pair these with the
# modalities via zip(), which would silently drop extras on a length mismatch.
embedding_indices = [1, 2, 3]
assert len(embedding_indices) == modalities, "need exactly one blade index per modality"

# One TensorToGeometric layer per modality, each targeting its own blade index.
to_ga_embeddings = nn.ModuleList(
    TensorToGeometric(algebra, torch.tensor(index)) for index in embedding_indices
)

print(test_input.shape)

torch.Size([4, 3, 64])


In [3]:
# Split (batch, modalities, embedding) into one (batch, embedding) tensor per
# modality. unbind(dim=1) yields exactly test_input.shape[1] slices with the
# modality axis removed, independent of the batch size. Chunking this axis by
# the *batch* size would mis-split whenever batch < modalities.
embeddings = list(torch.unbind(test_input, dim=1))
len(embeddings), embeddings[0].shape

(3, torch.Size([4, 64]))

In [4]:
# Convert each modality with its own TensorToGeometric layer. unsqueeze(-1)
# appends a size-1 trailing axis, so each (batch, 64) embedding is passed in
# as (batch, 64, 1).
ga_embedding_vectors = [
    to_blade(modality_embedding.unsqueeze(-1))
    for to_blade, modality_embedding in zip(to_ga_embeddings, embeddings)
]
ga_embedding_vectors[0].shape, len(ga_embedding_vectors)

(torch.Size([4, 64, 8]), 3)

In [5]:
# Spot-check one multivector: modality 1 (blade index 2), batch item 3, position 0.
ga_embedding_vectors[1][3, 0]

tensor([0.0000, 0.0000, 0.1478, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000])

In [6]:
# Stack the per-modality multivectors on a new leading axis:
# (modalities, batch, 64, num_blades).
ga_embedding_per_modality = torch.stack(ga_embedding_vectors, dim=0)
print(ga_embedding_per_modality.shape)
# Sum the modality axis away to get one multivector per (batch, position).
# Each modality was mapped to a different blade index, so the printed sample
# shows their components side by side.
ga_embedding = ga_embedding_per_modality.sum(dim=0)
ga_embedding.shape, ga_embedding[0, 0, :]

torch.Size([3, 4, 64, 8])


(torch.Size([4, 64, 8]),
 tensor([0.0000, 0.1690, 0.0304, 0.0791, 0.0000, 0.0000, 0.0000, 0.0000]))

In [None]:
print(f"embedding inputs are of shape {ga_embedding.shape}\n")
# Kernel blade indices shared by every layer under test; presumably 0 is the
# scalar and 1-3 are the basis vectors -- TODO confirm torchga's blade ordering.
weight_blade_indices = torch.tensor([0, 1, 2, 3])
# Keyword arguments common to all five layers.
shared_layer_kwargs = dict(activation="tanh", use_bias=False)


def report_block_output(label: str, block: nn.Module, inputs: torch.Tensor) -> torch.Tensor:
    """Apply ``block`` to ``inputs`` and print the output shape plus one sample.

    Args:
        label: Human-readable layer name used in the printed message.
        block: Layer to evaluate.
        inputs: Multivector tensor passed to ``block``.

    Returns:
        The tensor produced by ``block(inputs)``.
    """
    result = block(inputs)
    print(f"{label} block produces {result.shape}, example\n{result[0,0,:]}")
    return result


layers_under_test = [
    (
        "geometric product dense",
        GeometricProductDense(
            algebra,
            embedding_length,
            embedding_length,
            blade_indices_kernel=weight_blade_indices,
            **shared_layer_kwargs,
        ),
    ),
    (
        "geometric product conv1d",
        # NOTE(review): on the (4, 64, 8) input this layer returns (1, 64, 8).
        # The leading axis shrinks 4 -> 1, which is consistent with it being
        # treated as a length-4 sequence (kernel 3, stride 3 -> one window)
        # rather than as a batch axis. TODO confirm the input layout
        # GeometricProductConv1D expects and add a sequence axis if needed.
        GeometricProductConv1D(
            algebra,
            num_input_filters=embedding_length,
            num_output_filters=embedding_length,
            kernel_size=3,
            stride=3,
            padding=None,
            blade_indices_kernel=weight_blade_indices,
            **shared_layer_kwargs,
        ),
    ),
    (
        "geometric product elementwise",
        GeometricProductElementwise(
            algebra, embedding_length, embedding_length, weight_blade_indices, **shared_layer_kwargs
        ),
    ),
    (
        "geometric product sandwich dense",
        GeometricSandwichProductDense(
            algebra, embedding_length, embedding_length, weight_blade_indices, **shared_layer_kwargs
        ),
    ),
    (
        "geometric product sandwich elementwise",
        GeometricSandwichProductElementwise(
            algebra, embedding_length, embedding_length, weight_blade_indices, **shared_layer_kwargs
        ),
    ),
]

# The loop variables stay bound after the loop, so `processing_block` and
# `output` refer to the last layer and its result for any later cell.
for label, processing_block in layers_under_test:
    output = report_block_output(label, processing_block, ga_embedding)

embedding inputs are of shape torch.Size([4, 64, 8])

geometric product dense block produces torch.Size([4, 64, 8]), example
tensor([ 0.0335,  0.0837,  0.1045,  0.4523, -0.2594,  0.1548,  0.3612,  0.0000],
       grad_fn=<SliceBackward0>)
geometric product conv1d block produces torch.Size([1, 64, 8]), example
tensor([ 0.0159,  0.0263,  0.1174, -0.1688, -0.0539, -0.1391, -0.1785, -0.1962],
       grad_fn=<SliceBackward0>)
geometric product elementwise block produces torch.Size([4, 64, 8]), example
tensor([-0.0610,  0.0238,  0.0043,  0.0111, -0.0109,  0.0018,  0.0054,  0.0000],
       grad_fn=<SliceBackward0>)
geometric product sandwich dense block produces torch.Size([4, 64, 8]), example
tensor([-9.8562e-02,  2.1833e-02, -4.2696e-03,  9.7732e-03,  2.4745e-09,
        -3.5034e-09,  3.4506e-09, -1.0405e-09], grad_fn=<SliceBackward0>)
geometric product sandwich dense elementwise block produces torch.Size([4, 64, 8]), example
tensor([-5.1293e-04, -7.1052e-04,  1.3378e-04,  1.9419e-04,  4.65