In [1]:
%load_ext autoreload
%autoreload 2

import matplotlib.pyplot as plt
import tensorflow as tf
# Make tensorflow not take over the entire GPU memory
for gpu in tf.config.experimental.list_physical_devices('GPU'):
    tf.config.experimental.set_memory_growth(gpu, True)
from tfga import GeometricAlgebra
from tfga.blades import BladeKind
from tfga.layers import GeometricProductConv1D

In [2]:
ga = GeometricAlgebra([0, 1, 1, 1])

batch_size = 2
sequence_length = 8
c_in = 3
c_out = 4
kernel_size = 3

a = ga.from_tensor_with_kind(tf.ones([batch_size, sequence_length, c_in, ga.num_blades]), BladeKind.MV)
k = ga.from_tensor_with_kind(tf.ones([kernel_size, c_in, c_out, ga.num_blades]), BladeKind.MV)

y = ga.geom_conv1d(a, k, 2, "SAME")

print(y.shape)
print(y)

(2, 4, 4, 16)
tf.Tensor(
[[[[  0. -36.   0.  36.   0.  36.  36. -36.  36.   0.  36.  36.  36.
     36.  36.  36.]
   [  0. -36.   0.  36.   0.  36.  36. -36.  36.   0.  36.  36.  36.
     36.  36.  36.]
   [  0. -36.   0.  36.   0.  36.  36. -36.  36.   0.  36.  36.  36.
     36.  36.  36.]
   [  0. -36.   0.  36.   0.  36.  36. -36.  36.   0.  36.  36.  36.
     36.  36.  36.]]

  [[  0. -36.   0.  36.   0.  36.  36. -36.  36.   0.  36.  36.  36.
     36.  36.  36.]
   [  0. -36.   0.  36.   0.  36.  36. -36.  36.   0.  36.  36.  36.
     36.  36.  36.]
   [  0. -36.   0.  36.   0.  36.  36. -36.  36.   0.  36.  36.  36.
     36.  36.  36.]
   [  0. -36.   0.  36.   0.  36.  36. -36.  36.   0.  36.  36.  36.
     36.  36.  36.]]

  [[  0. -36.   0.  36.   0.  36.  36. -36.  36.   0.  36.  36.  36.
     36.  36.  36.]
   [  0. -36.   0.  36.   0.  36.  36. -36.  36.   0.  36.  36.  36.
     36.  36.  36.]
   [  0. -36.   0.  36.   0.  36.  36. -36.  36.   0.  36.  36.  36.
     36.  36

In [3]:
mv_indices = tf.range(ga.num_blades, dtype=tf.int64)

conv_layer = GeometricProductConv1D(
    ga, filters=c_out, kernel_size=kernel_size, stride=2, padding="SAME",
    blade_indices_kernel=tf.range(ga.num_blades, dtype=tf.int64),
    blade_indices_bias=tf.range(ga.num_blades, dtype=tf.int64)
)

y2 = conv_layer(a)
print(y2.shape)
ga.print(y2)
ga.print(y2[0, 0, 0])

(2, 4, 4, 16)
MultiVector[batch_shape=(2, 4, 4)]
MultiVector[-0.86*1 + 0.12*e_0 + -0.86*e_1 + 0.24*e_2 + 0.55*e_3 + -1.85*e_01 + -1.05*e_02 + 2.10*e_03 + 0.24*e_12 + 0.55*e_13 + -1.29*e_23 + 1.53*e_012 + -1.01*e_013 + -1.56*e_023 + -1.29*e_123 + -1.02*e_0123]
