# 3D groups

Here we derive the groups used for 3D group convolutions.

For more explanation I recommend:

- The [paper](https://arxiv.org/abs/1804.04656) on 3D group convolutions by Marysia Winkels and Taco Cohen.
- The [wikipedia page](https://en.wikipedia.org/wiki/Octahedral_symmetry) on these groups.
- My explanation notebook of group convolutions on the [github page](https://www.github.com/apjansen/geqco).

However, the approach here is a lot simpler.

# The groups in terms of flips and rotations

We only need to consider the point group; its interaction with 3D translations is the same as before.

Following the realization in the 2D case that all group actions on a tensor can be written as combinations of flips of axes and permutations of axes, we can very easily derive the symmetry group of a cube and its subgroups.

## The largest group Oh

A 3D tensor has 3 axes, so we can enumerate the transformations on it that preserve its shape, assuming first that it is a cube, i.e. that the 3 axes have the same length:

- flips: for each axis either flip it or not, for 2^3 = 8 combinations
- permutations: any permutation of the 3 axes, for 3! = 6 combinations

for a total of 8 * 6 = 48 transformations, which is indeed the order of the group Oh.
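The count can be reproduced in a few lines of plain Python. This sketch stores an element as a `(flips, permutation)` pair of tuples, a representation chosen here for illustration:

```python
import itertools

# One bit per axis: 1 means "reverse this axis".
flips = list(itertools.product((0, 1), repeat=3))      # 2**3 = 8 flip tuples
# Any reordering of the three axes (as in the `perm` argument of tf.transpose).
permutations = list(itertools.permutations(range(3)))  # 3! = 6 permutations

# Pair every flip tuple with every permutation.
elements = [(f, p) for p in permutations for f in flips]
print(len(flips), len(permutations), len(elements))  # 8 6 48
```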

Note that flips in different axes mutually commute, but they do not commute with permutations. So to be concrete, we'll agree to always perform flips before permutations.
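A small NumPy check shows that the order matters. NumPy stands in for TensorFlow here: the hypothetical helper `flip` plays the role of `tf.reverse`, and `np.transpose` uses the same axis convention as `tf.transpose`.

```python
import numpy as np

def flip(x, flips):
    """Reverse every axis of x whose entry in `flips` is 1 (like tf.reverse)."""
    return x[tuple(slice(None, None, -1) if f else slice(None) for f in flips)]

x = np.arange(27).reshape(3, 3, 3)
F = (1, 0, 0)  # flip the first axis
P = (1, 0, 2)  # swap the first two axes

flip_then_permute = np.transpose(flip(x, F), P)
permute_then_flip = flip(np.transpose(x, P), F)
print(np.array_equal(flip_then_permute, permute_then_flip))  # False
```

With the convention of flipping before permuting, an element always means the first of these two orders.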

### Group composition

Of course the fact that the numbers match does not prove that the groups are the same (isomorphic), and we still need to derive the multiplication table.
Denoting a generic flip by $F$ and a generic permutation by $P$, according to the claims above we can write any element as $P \odot F$.

To compute the multiplication table we need to write

$(P_a \odot F_a) \odot (P_b \odot F_b)$ in the form  $P_c \odot F_c$

$F$ can be expressed as a binary tuple of length 3, with a 1 indicating that the corresponding axis is flipped. A permutation $P$ then acts on a flip in the natural way, by permuting this tuple. With some thought it becomes clear that

$F \odot P = P \odot P^{-1}(F)$

i.e. first permuting and then applying the flip to the result is the same as first applying the inversely permuted flip and then permuting. In both cases the same axes of the input get flipped, namely those that end up at the positions where the tuple $F$ has a 1, and the permutation is the same.
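The identity can be checked exhaustively on a random cubic array. In this NumPy sketch `flip`, `transpose_tuple` and `inverse` are hypothetical helpers, and $P^{-1}(F)$ is read as `transpose_tuple(F, inverse(P))`, i.e. the tuple $F$ reordered by $P^{-1}$ in the same way `tf.transpose` reorders axes:

```python
import itertools
import numpy as np

def flip(x, flips):
    """Reverse every axis of x whose entry in `flips` is 1 (like tf.reverse)."""
    return x[tuple(slice(None, None, -1) if f else slice(None) for f in flips)]

def transpose_tuple(tup, perm):
    """Reorder `tup` like tf.transpose reorders axes: entry i is tup[perm[i]]."""
    return tuple(tup[p] for p in perm)

def inverse(perm):
    """Inverse permutation: the position at which each index appears in `perm`."""
    return tuple(perm.index(i) for i in range(len(perm)))

x = np.random.default_rng(0).normal(size=(4, 4, 4))
checks = [
    np.array_equal(
        flip(np.transpose(x, P), F),                                # F o P: permute, then flip
        np.transpose(flip(x, transpose_tuple(F, inverse(P))), P),  # P o P^{-1}(F)
    )
    for F in itertools.product((0, 1), repeat=3)
    for P in itertools.permutations(range(3))
]
print(len(checks), all(checks))  # 48 True
```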

Using this we can write:

$(P_a \odot F_a) \odot (P_b \odot F_b) = (P_a \odot P_b) \odot (P_b^{-1}(F_a) \odot F_b)$,

where the composition of two flips simply adds the two tuples mod 2.

A permutation can also be expressed as a tuple, itself a permutation of (0, 1, 2) in this case, where the value $v_i$ of component $i$ indicates which axis of the input ends up at axis $i$ of the output; this is the convention of the `perm` argument of `tf.transpose`.
Again after some thought it becomes clear that:

$P_a \odot P_b = P_a(P_b)$,

i.e. first permuting with $P_b$ and then with $P_a$ on the result is the same as permuting once, with the permutation $P_a(P_b)$, i.e. $P_b$ itself permuted as a tuple by $P_a$.
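The permutation rule can be verified the same way. This NumPy sketch uses a non-cubic array, so any mix-up of axes would show up as a shape mismatch; `transpose_tuple` is a hypothetical helper mirroring the static method of the same name in the `Oh` class below:

```python
import itertools
import numpy as np

def transpose_tuple(tup, perm):
    """Reorder `tup` like tf.transpose reorders axes: entry i is tup[perm[i]]."""
    return tuple(tup[p] for p in perm)

x = np.random.default_rng(0).normal(size=(2, 3, 4))  # distinct axis lengths
checks = [
    np.array_equal(
        np.transpose(np.transpose(x, Pb), Pa),    # P_b first, then P_a
        np.transpose(x, transpose_tuple(Pb, Pa)),  # a single permutation P_a(P_b)
    )
    for Pa in itertools.permutations(range(3))
    for Pb in itertools.permutations(range(3))
]
print(len(checks), all(checks))  # 36 True
```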

So summarizing:

$(P_a \odot F_a) \odot (P_b \odot F_b) = P_a(P_b) \odot (P_b^{-1}(F_a)+ F_b)$

which suffices to compute the full multiplication table.
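The summary rule can be turned into a standalone function and checked against the group axioms. This is a plain-Python sketch, independent of the `Oh` class below; `compose`, `transpose_tuple` and `inverse` are hypothetical names, and an element $P \odot F$ is stored as the pair `(F, P)`. Passing these checks shows that the rule defines a group of order 48, not by itself that this group is isomorphic to Oh.

```python
import itertools

FLIPS = list(itertools.product((0, 1), repeat=3))
PERMS = list(itertools.permutations(range(3)))
ELEMENTS = [(f, p) for p in PERMS for f in FLIPS]  # the pair (F, P) stands for P o F
IDENTITY = ((0, 0, 0), (0, 1, 2))

def transpose_tuple(tup, perm):
    """Reorder `tup` like tf.transpose reorders axes: entry i is tup[perm[i]]."""
    return tuple(tup[p] for p in perm)

def inverse(perm):
    """Inverse permutation: the position at which each index appears in `perm`."""
    return tuple(perm.index(i) for i in range(len(perm)))

def compose(a, b):
    """(P_a o F_a) o (P_b o F_b) = P_a(P_b) o (P_b^{-1}(F_a) + F_b), flips mod 2."""
    (fa, pa), (fb, pb) = a, b
    fa_moved = transpose_tuple(fa, inverse(pb))
    flips = tuple((u + v) % 2 for u, v in zip(fa_moved, fb))
    return flips, transpose_tuple(pb, pa)

element_set = set(ELEMENTS)
closed = all(compose(a, b) in element_set for a in ELEMENTS for b in ELEMENTS)
has_identity = all(compose(IDENTITY, g) == g == compose(g, IDENTITY) for g in ELEMENTS)
has_inverses = all(any(compose(g, h) == IDENTITY for h in ELEMENTS) for g in ELEMENTS)
# 48**3 triples: this takes a moment in pure Python.
associative = all(compose(compose(a, b), c) == compose(a, compose(b, c))
                  for a in ELEMENTS for b in ELEMENTS for c in ELEMENTS)
print(closed, has_identity, has_inverses, associative)  # True True True True
```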

## O: the subgroup of Oh that preserves orientation

A flip reverses the parity or chirality of the tensor, giving its mirror image. 
But parity is a binary concept. A second flip, even in a different direction, restores the original parity.

This doesn't take the permutations into account, though: permutations can also change parity. In particular, a _transposition_, which swaps two axes and leaves the others invariant, reverses parity. Recall that in the 2D case a rotation can be written as a transposition and a flip, which is an even number of parity-reversing operations and thus indeed preserves parity.

Any permutation can be written as a product of transpositions. This decomposition is not unique, but whether the number of transpositions is even or odd is. It follows from the above that permutations consisting of an even number of transpositions preserve parity, while those consisting of an odd number reverse it.

So for an element of the full group to preserve orientation, the number of transpositions plus the number of flips must be even.
This halves the size of the group to 24: once the permutation and two components of the flip tuple are chosen, the third component is fixed by this requirement.
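An independent check of this count: represent each element as a signed permutation matrix acting on coordinates measured from the centre of the cube (an assumption of this sketch, not something the notebook uses). Orientation-preserving maps are exactly those with determinant +1, and the sketch confirms they coincide with the elements selected by the counting rule above. `as_matrix` and `parity` are hypothetical helpers.

```python
import itertools
import numpy as np

def as_matrix(flips, perm):
    """Signed permutation matrix of 'flip, then permute', acting on coordinates
    measured from the centre of the cube (an assumption of this sketch)."""
    D = np.diag([-1.0 if f else 1.0 for f in flips])  # a flip negates its coordinate
    Pm = np.eye(3)[list(perm)]                         # output axis i = input axis perm[i]
    return Pm @ D

def parity(perm):
    """0 for even permutations, 1 for odd ones, via the number of inversions."""
    return sum(perm[i] > perm[j] for i in range(3) for j in range(i + 1, 3)) % 2

elements = [(f, p) for p in itertools.permutations(range(3))
            for f in itertools.product((0, 1), repeat=3)]

# Orientation-preserving <=> determinant +1.
proper = [(f, p) for f, p in elements if round(np.linalg.det(as_matrix(f, p))) == 1]
# The counting rule from the text: transpositions plus flips is even.
by_parity = [(f, p) for f, p in elements if (sum(f) + parity(p)) % 2 == 0]
print(len(proper), proper == by_parity)  # 24 True
```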

## D4h: treating the third dimension differently

If the third dimension is special, we cannot permute it with the other two.
So we are left with only 2! = 2 permutations, for a total of 16 elements.
Note that we can still do all flips.
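The D4h elements can be picked out by requiring the permutation to leave the third axis in place. In this plain-Python sketch `compose` (a hypothetical name) implements the composition rule derived above, with elements stored as `(F, P)` pairs; it confirms there are 16 such elements and that they are closed under composition:

```python
import itertools

def transpose_tuple(tup, perm):
    return tuple(tup[p] for p in perm)

def inverse(perm):
    return tuple(perm.index(i) for i in range(len(perm)))

def compose(a, b):
    """(P_a o F_a) o (P_b o F_b) for elements stored as (F, P) pairs."""
    (fa, pa), (fb, pb) = a, b
    flips = tuple((u + v) % 2 for u, v in zip(transpose_tuple(fa, inverse(pb)), fb))
    return flips, transpose_tuple(pb, pa)

elements = [(f, p) for p in itertools.permutations(range(3))
            for f in itertools.product((0, 1), repeat=3)]
d4h = [(f, p) for f, p in elements if p[2] == 2]  # third axis stays in place
d4h_set = set(d4h)
closed = all(compose(a, b) in d4h_set for a in d4h for b in d4h)
print(len(d4h), closed)  # 16 True
```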

## D4: as D4h but also preserving orientation

This combines the previous two restrictions: permutations are only allowed among the first two axes, and the total number of flips and transpositions must be even. That leaves 8 elements.
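As a sketch of what D4 looks like concretely (plain Python, hypothetical helper names, elements stored as `(F, P)` pairs): intersecting the two restrictions leaves 8 elements, and among them "flip the second axis, then swap the first two" is a quarter turn about the third axis, of order 4:

```python
import itertools

def transpose_tuple(tup, perm):
    return tuple(tup[p] for p in perm)

def inverse(perm):
    return tuple(perm.index(i) for i in range(len(perm)))

def compose(a, b):
    """(P_a o F_a) o (P_b o F_b) for elements stored as (F, P) pairs."""
    (fa, pa), (fb, pb) = a, b
    flips = tuple((u + v) % 2 for u, v in zip(transpose_tuple(fa, inverse(pb)), fb))
    return flips, transpose_tuple(pb, pa)

def parity(perm):
    """0 for even permutations, 1 for odd ones (parity of the inversion count)."""
    return sum(perm[i] > perm[j] for i in range(3) for j in range(i + 1, 3)) % 2

identity = ((0, 0, 0), (0, 1, 2))
elements = [(f, p) for p in itertools.permutations(range(3))
            for f in itertools.product((0, 1), repeat=3)]
# D4 = D4h (third axis fixed) intersected with O (flips + transpositions even).
d4 = [(f, p) for f, p in elements if p[2] == 2 and (sum(f) + parity(p)) % 2 == 0]

# Flip the second axis (W), then swap the first two axes: a quarter turn about the third axis.
g = ((0, 1, 0), (1, 0, 2))
powers = [identity]
while True:
    nxt = compose(g, powers[-1])
    if nxt == identity:
        break
    powers.append(nxt)
print(len(d4), g in d4, len(powers))  # 8 True 4
```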

# Implementation

## Oh

We need the inverse of each permutation, but with only 3 axes these are easy to work out by hand: the identity and the three transpositions are their own inverses, and the remaining two permutations (the cyclic shifts) are each other's inverses.

```python
permutation_tuples = [(0, 1, 2), (1, 2, 0), (2, 0, 1), (0, 2, 1), (2, 1, 0), (1, 0, 2)]
permutation_invers = [(0, 1, 2), (2, 0, 1), (1, 2, 0), (0, 2, 1), (2, 1, 0), (1, 0, 2)]
p_inv_dict = {p: p_inv for p, p_inv in zip(permutation_tuples, permutation_invers)}
```
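The hand-written inverses can be cross-checked programmatically. This sketch uses a hypothetical `inverse` helper that, for each axis index, looks up where it appears in the permutation tuple; the same idea works for any number of axes:

```python
permutation_tuples = [(0, 1, 2), (1, 2, 0), (2, 0, 1), (0, 2, 1), (2, 1, 0), (1, 0, 2)]
permutation_invers = [(0, 1, 2), (2, 0, 1), (1, 2, 0), (0, 2, 1), (2, 1, 0), (1, 0, 2)]

def inverse(perm):
    """Inverse permutation: the position at which each axis index appears in `perm`."""
    return tuple(perm.index(i) for i in range(len(perm)))

computed = [inverse(p) for p in permutation_tuples]
print(computed == permutation_invers)  # True
```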

Similarly to what we did for the wallpaper group p4m, we'll construct a class of group elements that implements group composition using the rule we derived above, and use that to compute the multiplication table.

```python
import tensorflow as tf

class Oh:
  """Element P o F of Oh: first flip the axes marked in `flips`, then transpose with `permutation`."""

  def __init__(self,
               flips: tuple = (0, 0, 0),
               permutation: tuple = (0, 1, 2)):
    self.flips = flips
    self.permutation = permutation

  def __call__(self, x):
    # Act on a 3D tensor: flips first, then the axis permutation (tf.transpose's `perm`).
    x = tf.reverse(x, axis=[axis for axis, flip in enumerate(self.flips) if flip])
    x = tf.transpose(x, self.permutation)
    return x

  def is_identity(self):
    return self.flips == (0, 0, 0) and self.permutation == (0, 1, 2)

  def __eq__(self, other):
    return self.flips == other.flips and self.permutation == other.permutation

  def __mul__(self, other):
    """Group composition self o other: apply `other` first, then `self`."""
    # P_b^{-1}(F_a): move self's flips through other's permutation.
    aflip_permuted = self.transpose_tuple(self.flips, p_inv_dict[other.permutation])
    # Flips compose by adding the tuples mod 2.
    flips = tuple((fa + fb) % 2 for fa, fb in zip(aflip_permuted, other.flips))
    # P_a o P_b = P_a(P_b): other's permutation tuple, permuted by self's.
    permutation = self.transpose_tuple(other.permutation, self.permutation)
    return Oh(flips=flips, permutation=permutation)

  @staticmethod
  def transpose_tuple(tup, perm):
    """Reorder `tup` the way tf.transpose reorders axes: entry i is tup[perm[i]]."""
    return tuple(tup[p] for p in perm)

  def __repr__(self):
    if self.is_identity():
      return 'e'
    string = ''.join([c for c, f in zip(['H', 'W', 'D'], self.flips) if f])
    return str(self.permutation) + string

  def __hash__(self):
    return hash((self.flips, self.permutation))
```

```python
import itertools

all_permutations = [Oh(permutation=permutation) for permutation in permutation_tuples]
flips_tuples = [(0, 0, 0), (0, 1, 1), (1, 0, 1), (1, 1, 0), (0, 0, 1), (0, 1, 0), (1, 0, 0), (1, 1, 1)]
all_flips = [Oh(flips=flips) for flips in flips_tuples]
elements = [perm * flip for perm, flip in itertools.product(all_permutations, all_flips)]
element_numbers = {el: nr for nr, el in enumerate(elements)}
```

```python
group_multiplication = [[row * col for col in elements] for row in elements]
group_permutation = [[element_numbers[row * col] for col in elements] for row in elements]
```

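### Check compatibility of composition and action

For every pair of elements $g, h$, acting first with $h$ and then with $g$ should give the same result as acting once with $g \odot h$.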

```python
test = tf.random.normal((28, 28, 28))
maxdiffs = []
for r in range(48):
  for c in range(48):
    g_r = elements[r]
    g_c = elements[c]
    g_roc = group_multiplication[r][c]
    act_c = g_c(test)
    act_r_act_c = g_r(act_c)
    act_roc = g_roc(test)
    maxdiff = tf.reduce_max(tf.abs(act_r_act_c - act_roc)).numpy()
    maxdiffs.append(maxdiff)
print(f"The maximal absolute difference between g(h(x)) and (g o h)(x) is: {max(maxdiffs)}.")
```

```
The maximal absolute difference between g(h(x)) and (g o h)(x) is: 0.0.
```


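A couple of spot checks: the pure permutation (0, 2, 1) sits at index 24, and multiplying it (or the pure permutation (2, 1, 0) at index 32) on the right by the pure D flip at index 4 simply attaches that flip, giving indices 28 and 36.

```python
[i for i, e in enumerate(elements) if (e.permutation == (0, 2, 1) and e.flips == (0, 0, 0))]
```

```
[24]
```

```python
element_numbers[elements[24] * elements[4]], element_numbers[elements[32] * elements[4]]
```

```
(28, 36)
```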

```python
for row in group_multiplication:
  print(row)
```

```
[e, (0, 1, 2)WD, (0, 1, 2)HD, (0, 1, 2)HW, (0, 1, 2)D, (0, 1, 2)W, (0, 1, 2)H, (0, 1, 2)HWD, (1, 2, 0), (1, 2, 0)WD, (1, 2, 0)HD, (1, 2, 0)HW, (1, 2, 0)D, (1, 2, 0)W, (1, 2, 0)H, (1, 2, 0)HWD, (2, 0, 1), (2, 0, 1)WD, (2, 0, 1)HD, (2, 0, 1)HW, (2, 0, 1)D, (2, 0, 1)W, (2, 0, 1)H, (2, 0, 1)HWD, (0, 2, 1), (0, 2, 1)WD, (0, 2, 1)HD, (0, 2, 1)HW, (0, 2, 1)D, (0, 2, 1)W, (0, 2, 1)H, (0, 2, 1)HWD, (2, 1, 0), (2, 1, 0)WD, (2, 1, 0)HD, (2, 1, 0)HW, (2, 1, 0)D, (2, 1, 0)W, (2, 1, 0)H, (2, 1, 0)HWD, (1, 0, 2), (1, 0, 2)WD, (1, 0, 2)HD, (1, 0, 2)HW, (1, 0, 2)D, (1, 0, 2)W, (1, 0, 2)H, (1, 0, 2)HWD]
[(0, 1, 2)WD, e, (0, 1, 2)HW, (0, 1, 2)HD, (0, 1, 2)W, (0, 1, 2)D, (0, 1, 2)HWD, (0, 1, 2)H, (1, 2, 0)HD, (1, 2, 0)HW, (1, 2, 0), (1, 2, 0)WD, (1, 2, 0)H, (1, 2, 0)HWD, (1, 2, 0)D, (1, 2, 0)W, (2, 0, 1)HW, (2, 0, 1)HD, (2, 0, 1)WD, (2, 0, 1), (2, 0, 1)HWD, (2, 0, 1)H, (2, 0, 1)W, (2, 0, 1)D, (0, 2, 1)WD, (0, 2, 1), (0, 2, 1)HW, (0, 2, 1)HD, (0, 2, 1)W, (0, 2, 1)D, (0, 2, 1)HWD, (0, 2, 1)H, (2, 1, 0)HW, ...
```

### The table

```python
for row in group_permutation:
  print(row)
```

```
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47]
[1, 0, 3, 2, 5, 4, 7, 6, 10, 11, 8, 9, 14, 15, 12, 13, 19, 18, 17, 16, 23, 22, 21, 20, 25, 24, 27, 26, 29, 28, 31, 30, 35, 34, 33, 32, 39, 38, 37, 36, 42, 43, 40, 41, 46, 47, 44, 45]
[2, 3, 0, 1, 6, 7, 4, 5, 11, 10, 9, 8, 15, 14, 13, 12, 17, 16, 19, 18, 21, 20, 23, 22, 27, 26, 25, 24, 31, 30, 29, 28, 34, 35, 32, 33, 38, 39, 36, 37, 41, 40, 43, 42, 45, 44, 47, 46]
[3, 2, 1, 0, 7, 6, 5, 4, 9, 8, 11, 10, 13, 12, 15, 14, 18, 19, 16, 17, 22, 23, 20, 21, 26, 27, 24, 25, 30, 31, 28, 29, 33, 32, 35, 34, 37, 36, 39, 38, 43, 42, 41, 40, 47, 46, 45, 44]
[4, 5, 6, 7, 0, 1, 2, 3, 14, 15, 12, 13, 10, 11, 8, 9, 21, 20, 23, 22, 17, 16, 19, 18, 29, 28, 31, 30, 25, 24, 27, 26, 38, 39, 36, 37, 34, 35, 32, 33, 44, 45, 46, 47, 40, 41, 42, 43]
[5, 4, 7, 6, 1, 0, 3, 2, 12, 13, 14, 15, 8, 9, 10, 11, 22, 23, 20, 21, 18, 19, 16, 17, ...
```

Inverses can be read off from the multiplication table: the inverse of an element is the column whose product with it is the identity, which has index 0.

```python
inverses = [row.index(0) for row in group_permutation]
print(inverses)
```

```
[0, 1, 2, 3, 4, 5, 6, 7, 16, 19, 17, 18, 21, 22, 20, 23, 8, 10, 11, 9, 14, 12, 13, 15, 24, 25, 27, 26, 29, 28, 30, 31, 32, 35, 34, 33, 38, 37, 36, 39, 40, 42, 41, 43, 44, 46, 45, 47]
```
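The inverses also have a closed form. Since a flip is its own inverse, $(P \odot F)^{-1} = F \odot P^{-1} = P^{-1} \odot P(F)$, using the commutation rule with $P^{-1}$ in place of $P$. The plain-Python sketch below (hypothetical helper names, elements stored as `(F, P)` pairs in the notebook's ordering) checks this formula and reproduces a few entries of the list printed above:

```python
def transpose_tuple(tup, perm):
    return tuple(tup[p] for p in perm)

def inverse_perm(perm):
    return tuple(perm.index(i) for i in range(len(perm)))

def compose(a, b):
    """(P_a o F_a) o (P_b o F_b) for elements stored as (F, P) pairs."""
    (fa, pa), (fb, pb) = a, b
    flips = tuple((u + v) % 2 for u, v in zip(transpose_tuple(fa, inverse_perm(pb)), fb))
    return flips, transpose_tuple(pb, pa)

def inverse(g):
    """(P o F)^{-1} = P^{-1} o P(F)."""
    f, p = g
    return transpose_tuple(f, p), inverse_perm(p)

identity = ((0, 0, 0), (0, 1, 2))
perms = [(0, 1, 2), (1, 2, 0), (2, 0, 1), (0, 2, 1), (2, 1, 0), (1, 0, 2)]
flips = [(0, 0, 0), (0, 1, 1), (1, 0, 1), (1, 1, 0), (0, 0, 1), (0, 1, 0), (1, 0, 0), (1, 1, 1)]
elements = [(f, p) for p in perms for f in flips]  # same order as the notebook
index = {g: i for i, g in enumerate(elements)}

inverses = [index[inverse(g)] for g in elements]
print(all(compose(g, inverse(g)) == identity == compose(inverse(g), g) for g in elements))  # True
print(inverses[8:12])  # [16, 19, 17, 18]
```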


## Subgroups

For the subgroups, it suffices to know which elements of the bigger group lie in the subgroup. To compute the multiplication table of the subgroup, we can then just take those rows and columns of the bigger group's table, relabelling the entries accordingly.
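Extracting a subgroup's table amounts to selecting rows and columns and then relabelling the entries so they index into the subgroup. A minimal sketch, with `restrict_table` a hypothetical helper and the cyclic group Z4 as a toy stand-in for Oh:

```python
def restrict_table(table, members):
    """Multiplication table of a subgroup, relabelled 0..len(members)-1.

    `table[i][j]` is the index of element i composed with element j in the big group;
    `members` lists the big-group indices of the subgroup, in the order to use."""
    relabel = {big: small for small, big in enumerate(members)}
    return [[relabel[table[i][j]] for j in members] for i in members]

# Toy example: the cyclic group Z4 (addition mod 4) and its subgroup {0, 2}.
z4 = [[(i + j) % 4 for j in range(4)] for i in range(4)]
print(restrict_table(z4, [0, 2]))  # [[0, 1], [1, 0]]
```

If `members` is not closed under composition, the relabelling lookup raises a `KeyError`, which doubles as a cheap subgroup test.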

### O

```python
permutation_tuples = [(0, 1, 2), (1, 2, 0), (2, 0, 1), (0, 2, 1), (2, 1, 0), (1, 0, 2)]
flips_tuples = [(0, 0, 0), (0, 1, 1), (1, 0, 1), (1, 1, 0), (0, 0, 1), (0, 1, 0), (1, 0, 0), (1, 1, 1)]
```

In the permutation and flip tuples above, repeated here for convenience, I've intentionally put the even ones of each list in the first half and the odd ones in the second half.

This is easy to check for the flips: just check that the number of ones is even.
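For the permutations, note that if one axis is left invariant while the other two are swapped, that is a transposition, which is odd. The two remaining non-identity permutations, the cyclic shifts (1, 2, 0) and (2, 0, 1), are each a product of two transpositions and hence even.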

So the parities decompose as:

```python
even_perms, odd_perms = all_permutations[:3], all_permutations[3:]
even_flips, odd_flips = all_flips[:4], all_flips[4:]
```
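The claimed ordering can be checked mechanically. This sketch (plain Python, with `perm_parity` a hypothetical helper) uses the standard fact that a permutation's parity equals the parity of its number of inversions, i.e. pairs of positions that are out of order:

```python
permutation_tuples = [(0, 1, 2), (1, 2, 0), (2, 0, 1), (0, 2, 1), (2, 1, 0), (1, 0, 2)]
flips_tuples = [(0, 0, 0), (0, 1, 1), (1, 0, 1), (1, 1, 0), (0, 0, 1), (0, 1, 0), (1, 0, 0), (1, 1, 1)]

def perm_parity(perm):
    """0 if even, 1 if odd: the parity of the number of inversions."""
    n = len(perm)
    return sum(perm[i] > perm[j] for i in range(n) for j in range(i + 1, n)) % 2

print([perm_parity(p) for p in permutation_tuples])  # [0, 0, 0, 1, 1, 1]
print([sum(f) % 2 for f in flips_tuples])            # [0, 0, 0, 0, 1, 1, 1, 1]
```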

They combine into an orientation-preserving transformation if either both are even or both are odd:

```python
elements_o = [perm * flip for perm, flip in itertools.product(even_perms, even_flips)]
elements_o += [perm * flip for perm, flip in itertools.product(odd_perms, odd_flips)]
```

Now that we have the elements, we can reconstruct the multiplication table as before, or we can find what element numbers these correspond to in the bigger group, and extract those rows and columns.

```python
element_numbers_o = [nr for nr, el in enumerate(elements) if (el in elements_o)]
print(element_numbers_o)
```

```
[0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 28, 29, 30, 31, 36, 37, 38, 39, 44, 45, 46, 47]
```


### D4h

This is the subgroup that treats the third axis differently, so we simply restrict to the 2 permutations that leave the third axis in place:

```python
perms_d4h = [all_permutations[i] for i in (0, -1)]
elements_d4h = [perm * flip for perm, flip in itertools.product(perms_d4h, all_flips)]
print(len(elements_d4h))
print(elements_d4h)
```

```
16
[e, (0, 1, 2)WD, (0, 1, 2)HD, (0, 1, 2)HW, (0, 1, 2)D, (0, 1, 2)W, (0, 1, 2)H, (0, 1, 2)HWD, (1, 0, 2), (1, 0, 2)WD, (1, 0, 2)HD, (1, 0, 2)HW, (1, 0, 2)D, (1, 0, 2)W, (1, 0, 2)H, (1, 0, 2)HWD]
```


```python
element_numbers_d4h = [nr for nr, el in enumerate(elements) if (el in elements_d4h)]
print(len(element_numbers_d4h))
print(element_numbers_d4h)
```

```
16
[0, 1, 2, 3, 4, 5, 6, 7, 40, 41, 42, 43, 44, 45, 46, 47]
```


### D4

Now we have a single even and a single odd permutation left, which we combine with the even and the odd flips respectively:

```python
elements_d4 = [perms_d4h[0] * flip for flip in even_flips]
elements_d4 += [perms_d4h[1] * flip for flip in odd_flips]
element_numbers_d4 = [nr for nr, el in enumerate(elements) if (el in elements_d4)]
print(len(element_numbers_d4))
print(element_numbers_d4)
```

```
8
[0, 1, 2, 3, 44, 45, 46, 47]
```
