# Tensor field networks: Rotation- and translation-equivariant neural networks for 3D point clouds

Author: Xinling Yu

## 1 Introduction

In this notebook, we study tensor field neural networks [1], which are locally equivariant to 3D rotations, translations, and permutations of points at every layer. Convolutional neural networks are translation-equivariant, which means that features can be identified anywhere in an input image. However, to recognize features in arbitrary 3D orientations, traditional convolutional neural networks require data augmentation. To eliminate the need for data augmentation, the authors use continuous equivariant filters constructed from spherical harmonics, which give the network a richer equivariance: that of the symmetries of 3D Euclidean space.

This work builds upon Harmonic Networks [2], which achieve 2D rotation equivariance using discrete convolutions and filters composed of circular harmonics, and SchNet [3], which presents a rotation-invariant network using continuous convolutions. Similar problems of invariance or equivariance under specific input transformations have been addressed in previous work. G-CNNs [4] guarantee equivariance with respect to finite symmetry groups (unlike the continuous groups in this work). Spherical CNNs [5] use spherical harmonics and Wigner D-matrices but only address spherical signals (2D data on the surface of a sphere). In contrast to these works, tensor field neural networks achieve 3D rotation and translation equivariance.

The rest of the notebook is organized as follows. In Section 2, we give a brief overview of group representations and equivariance in 3D, and then introduce the tensor field neural network layers. In Section 3, we implement the tensor field neural network model using Geomstats [6] and TensorFlow (see also layer.py). In Section 4, we demonstrate the 3D rotation and translation equivariance of the model on a toy dataset of 3D Tetris shapes, and we describe a novel preshape space method implemented with the Geomstats package [6], whose results are presented in a separate notebook, Preshape_space.ipynb. Finally, we compare the results of the tensor field networks with those of the preshape space method.

In [1]:
%%html
<center><img src='images/motivation.png' width=800 height=800></center>

## 2 Background and Model

### 2.1 Group Representations and Equivariance in 3D

A representation $D$ of a group $G$ is a function from $G$ to square matrices such that for all $g, h \in G$,
$$
D(g) D(h)=D(g h)
$$
A function $\mathcal{L}: \mathcal{X} \rightarrow \mathcal{Y}$ (for vector spaces $\mathcal{X}$ and $\mathcal{Y}$ ) is equivariant with respect to a group $G$ and group representations $D^{\mathcal{X}}$ and $D^{\mathcal{Y}}$ if for all $g \in G$,
$$
\mathcal{L} \circ D^{\mathcal{X}}(g)=D^{\mathcal{Y}}(g) \circ \mathcal{L}
$$
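To make the two definitions concrete, the following is a minimal NumPy sketch. The helpers `rot_z`, `rot_x`, `rep_rank2`, and `outer_layer` are hypothetical names introduced here for illustration; they are not part of the notebook's code. The sketch uses the representation of $S O(3)$ on flattened $3 \times 3$ tensors, $D(g)=g \otimes g$, and checks numerically that $D(g) D(h)=D(g h)$ and that the map $x \mapsto x \otimes x$ is equivariant from the vector representation to this one.

```python
import numpy as np


def rot_z(theta):
    """Rotation by theta (radians) about the z axis."""
    c, s = np.cos(theta), np.sin(theta)
    return np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])


def rot_x(theta):
    """Rotation by theta (radians) about the x axis."""
    c, s = np.cos(theta), np.sin(theta)
    return np.array([[1.0, 0.0, 0.0], [0.0, c, -s], [0.0, s, c]])


def rep_rank2(g):
    """9 x 9 matrix representing g on row-major flattened 3 x 3 tensors."""
    return np.kron(g, g)


def outer_layer(x):
    """Map a 3-vector to the flattened outer product of x with itself."""
    return np.outer(x, x).ravel()


# Two arbitrary rotations and an arbitrary input vector.
g = rot_z(0.3) @ rot_x(1.1)
h = rot_x(-0.7) @ rot_z(2.0)
x = np.array([0.2, -1.0, 0.5])

# Representation property: D(g) D(h) = D(g h).
is_representation = np.allclose(rep_rank2(g) @ rep_rank2(h), rep_rank2(g @ h))

# Equivariance: L(D^X(g) x) = D^Y(g) L(x), with D^X(g) = g and D^Y(g) = g (x) g.
is_equivariant = np.allclose(outer_layer(g @ x), rep_rank2(g) @ outer_layer(x))

print(is_representation, is_equivariant)
```

Both flags are expected to be true up to floating-point tolerance. The same pattern (transform the input, apply the map, compare with transforming the output) can be used to test any candidate layer.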
Tensor field networks act on points with associated features. A layer $\mathcal{L}$ takes a finite set $S$ of vectors in $\mathbb{R}^{3}$ and a vector in $\mathcal{X}$ at each point in $S$ and outputs a vector in $\mathcal{Y}$ at each point in $S$, where $\mathcal{X}$ and $\mathcal{Y}$ are vector spaces. We write this as
$$
\mathcal{L}\left(\vec{r}_{a}, x_{a}\right)=\left(\vec{r}_{a}, y_{a}\right)
$$
where $\vec{r}_{a} \in \mathbb{R}^{3}$ are the point coordinates and $x_{a} \in \mathcal{X}, y_{a} \in \mathcal{Y}$ are the feature vectors ( $a$ being an indexing scheme over the points in $S$ ). This combination of $\mathbb{R}^{3}$ with another vector space can be written as $\mathbb{R}^{3} \oplus \mathcal{X}$, where $\oplus$ refers to concatenation.
Next we state the three equivariance conditions in 3D.
1. Permutation equivariance: $\mathcal{L} \circ \mathcal{P}_{\sigma}=\mathcal{P}_{\sigma} \circ \mathcal{L}$, where $\mathcal{P}_{\sigma}\left(\vec{r}_{a}, x_{a}\right):=\left(\vec{r}_{\sigma(a)}, x_{\sigma(a)}\right)$ and $\sigma$ permutes the points to which the indices refer.
2. Translation equivariance: $\mathcal{L} \circ \mathcal{T}_{\vec{t}}=\mathcal{T}_{\vec{t}} \circ \mathcal{L}$, where $\mathcal{T}_{\vec{t}}\left(\vec{r}_{a}, x_{a}\right):=\left(\vec{r}_{a}+\vec{t}, x_{a}\right)$.
3. Rotation equivariance: Let $D^{\mathcal{X}}$ be a representation of $S O(3)$ on a vector space $\mathcal{X}$ (and $D^{\mathcal{Y}}$ on $\mathcal{Y}$). Acting with $g \in S O(3)$ on $\vec{r} \in \mathbb{R}^{3}$ we write as $\mathcal{R}(g) \vec{r}$, and acting on $x \in \mathcal{X}$ gives $D^{\mathcal{X}}(g) x$. Then the condition for rotation equivariance is $\mathcal{L} \circ\left[\mathcal{R}(g) \oplus D^{\mathcal{X}}(g)\right]=\left[\mathcal{R}(g) \oplus D^{\mathcal{Y}}(g)\right] \circ \mathcal{L}$, where $\left[\mathcal{R}(g) \oplus D^{\mathcal{X}}(g)\right]\left(\vec{r}_{a}, x_{a}\right)=\left(\mathcal{R}(g) \vec{r}_{a}, D^{\mathcal{X}}(g) x_{a}\right)$.
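
As a short worked check (an illustration added here, not taken from [1]), consider any layer whose output features depend on the positions only through the relative vectors $\vec{r}_{a}-\vec{r}_{b}$, written below with a hypothetical function $f$. Translating every point leaves each relative vector unchanged, so condition 2 holds for such a layer:
$$
\begin{aligned}
\left(\vec{r}_{a}+\vec{t}\right)-\left(\vec{r}_{b}+\vec{t}\right) &=\vec{r}_{a}-\vec{r}_{b}, \\
\mathcal{L}\left(\vec{r}_{a}+\vec{t}, x_{a}\right) &=\left(\vec{r}_{a}+\vec{t}, f\left(\left\{\vec{r}_{a}-\vec{r}_{b}\right\},\left\{x_{b}\right\}\right)\right)=\mathcal{T}_{\vec{t}}\left(\mathcal{L}\left(\vec{r}_{a}, x_{a}\right)\right).
\end{aligned}
$$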

### 2.2 Tensor Field Neural Network Layers

The input and output $V_{a c m}^{(l)}$ of each layer of a tensor field network is a finite set $S$ of points in $\mathbb{R}^{3}$ and a vector in a representation of $S O(3)$ associated with each point. This object $V_{a c m}^{(l)}$ is implemented as a dictionary with key $l$ (the rotation order) of multidimensional arrays, each of shape $\left[|S|, n_{l}, 2 l+1\right]$ (where $n_{l}$ is the number of channels), corresponding to [point index, channel index, representation index]. The following figure shows two point masses with velocity (purple) and acceleration (orange).

In [1]:
%%html
<center><img src='images/tensor.png' width=800 height=800></center>
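
The dictionary layout described above can be sketched in plain NumPy as follows. This is an illustration under stated assumptions: the numerical values are made up, and the variable names (`features`, `num_points`) are hypothetical. It encodes two point masses, each with one scalar channel (the mass, $l=0$) and two vector channels (velocity and acceleration, $l=1$).

```python
import numpy as np

num_points = 2  # |S|

# Keys are rotation orders l; values have shape [|S|, n_l, 2l + 1].
features = {
    # l = 0: one scalar channel per point (mass), shape [2, 1, 1]
    0: np.array([[[1.0]],
                 [[2.5]]]),
    # l = 1: two vector channels per point (velocity, acceleration), shape [2, 2, 3]
    1: np.array([[[0.0, 1.0, 0.0], [0.0, 0.0, -1.0]],
                 [[1.0, 0.0, 0.0], [0.0, 0.0, -1.0]]]),
}

# Every array must agree on the point axis and have 2l + 1 components on the last axis.
for l, array in features.items():
    assert array.shape[0] == num_points
    assert array.shape[-1] == 2 * l + 1

channels_per_order = {l: array.shape[1] for l, array in features.items()}
print(channels_per_order)
```

Under a rotation, the $l=0$ entries stay fixed while each $l=1$ row transforms like an ordinary 3-vector, which is why the two orders are stored under separate keys.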

Next we define three tensor field network layers. See [1] for details on why they are equivariant.
1. Point convolution: A given input inhabits one representation, a filter inhabits another, and together these produce outputs at possibly many rotation orders. We can put everything together into our pointwise convolution layer definition:
$$
\mathcal{L}_{a c m_{o}}^{\left(l_{o}\right)}\left(\vec{r}_{a}, V_{a c m_{i}}^{\left(l_{i}\right)}\right):=\sum_{m_{f}, m_{i}} C_{\left(l_{f}, m_{f}\right)\left(l_{i}, m_{i}\right)}^{\left(l_{o}, m_{o}\right)} \sum_{b \in S} F_{c m_{f}}^{\left(l_{f}, l_{i}\right)}\left(\vec{r}_{a b}\right) V_{b c m_{i}}^{\left(l_{i}\right)}
$$
where $\vec{r}_{a b}:=\vec{r}_{a}-\vec{r}_{b}$ and the subscripts $i, f$, and $o$ denote the representations of the input, filter, and output, respectively.
2. Self-interaction: Self-interaction layers are analogous to $1 \times 1$ convolutions, and they act like $l=0$ (scalar) filters:
$$
\sum_{c^{\prime}} W_{c c^{\prime}}^{(l)} V_{a c^{\prime} m}^{(l)}
$$
3. Nonlinearity: The nonlinearity layer acts as a scalar transform in each $l$ space (that is, along the $m$ dimension). For $l=0$ channels we can use the first expression below, and for $l>0$ channels the second:
$$
\eta^{(0)}\left(V_{a c}^{(0)}+b_{c}^{(0)}\right) \quad \text{and} \quad \eta^{(l)}\left(\|V\|_{a c}^{(l)}+b_{c}^{(l)}\right) V_{a c m}^{(l)}, \quad \text{where} \quad \|V\|_{a c}^{(l)}:=\sqrt{\sum_{m}\left|V_{a c m}^{(l)}\right|^{2}}
$$
for some functions $\eta^{(l)}: \mathbb{R} \rightarrow \mathbb{R}$ (which can be different for each $l$) and biases $b_{c}^{(l)}$.
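
The three layers can be illustrated with a simplified NumPy sketch. It is not the implementation used in this notebook: it handles only one path ($l_{i}=0$ input, $l_{f}=1$ filter, $l_{o}=1$ output), takes the Cartesian components of the unit vector $\hat{r}_{a b}$ as the $l=1$ angular part, uses a hypothetical Gaussian radial function, and all function names are assumptions introduced here. For this path the Clebsch-Gordan contraction reduces to the identity, so the convolution is a weighted sum of unit vectors.

```python
import numpy as np


def ssp(x):
    """Shifted softplus, as in Section 3."""
    return np.log(0.5 * np.exp(x) + 0.5)


def conv_0_to_1(r, v0, radial):
    """Point convolution for l_i = 0, l_f = 1 -> l_o = 1.

    r: [N, 3] coordinates, v0: [N, C, 1] scalar features; returns [N, C, 3].
    """
    rab = r[:, None, :] - r[None, :, :]                    # [N, N, 3]
    dab = np.linalg.norm(rab, axis=-1)                     # [N, N]
    unit = rab / np.where(dab > 0, dab, 1.0)[..., None]    # zero vector on the diagonal
    return np.einsum('ab,abm,bc->acm', radial(dab), unit, v0[..., 0])


def self_interaction(v, w):
    """Mix channels with weights w: [C_out, C_in]; the m axis is untouched."""
    return np.einsum('dc,acm->adm', w, v)


def norm_nonlinearity(v, b):
    """eta(||V|| + b) V for l > 0, with eta = ssp and one bias per channel."""
    norm = np.sqrt(np.sum(v ** 2, axis=-1))                # [N, C]
    return ssp(norm + b)[..., None] * v


def layer(r, v0, w, b):
    v1 = conv_0_to_1(r, v0, radial=lambda d: np.exp(-d ** 2))
    return norm_nonlinearity(self_interaction(v1, w), b)


rng = np.random.default_rng(0)
r = rng.normal(size=(5, 3))
v0 = rng.normal(size=(5, 2, 1))
w = rng.normal(size=(3, 2))
b = rng.normal(size=3)

# Random rotation from a QR decomposition, sign-corrected so that det = +1.
q, _ = np.linalg.qr(rng.normal(size=(3, 3)))
rot = q * np.sign(np.linalg.det(q))
shift = np.array([1.0, -2.0, 0.5])
perm = rng.permutation(5)

out = layer(r, v0, w, b)
rotation_ok = np.allclose(layer(r @ rot.T, v0, w, b), out @ rot.T)
translation_ok = np.allclose(layer(r + shift, v0, w, b), out)
permutation_ok = np.allclose(layer(r[perm], v0[perm], w, b), out[perm])
print(rotation_ok, translation_ok, permutation_ok)
```

The nonlinearity in this sketch follows the formula above; `rotation_equivariant_nonlinearity` in Section 3 additionally divides by the norm, which is also equivariant because the norm is rotation-invariant.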

In [2]:
%%html
<center><img src='images/network.png' width=800 height=800></center>

## 3 Implementation

In [2]:
import tensorflow.compat.v1 as tf
import numpy as np
import scipy.linalg
import os
os.environ['GEOMSTATS_BACKEND'] = 'tensorflow'
import geomstats.backend as gs # import the Geomstats tensorflow 2.2 backend
tf.compat.v1.disable_eager_execution()

FLOAT_TYPE = gs.float32
EPSILON = 1e-8


def get_eijk():
    """
    Constant Levi-Civita tensor
    Returns:
        tf.Tensor of shape [3, 3, 3]
    """
    eijk_ = np.zeros((3, 3, 3))
    eijk_[0, 1, 2] = eijk_[1, 2, 0] = eijk_[2, 0, 1] = 1.
    eijk_[0, 2, 1] = eijk_[2, 1, 0] = eijk_[1, 0, 2] = -1.
    return tf.constant(eijk_, dtype=FLOAT_TYPE)


def norm_with_epsilon(input_tensor, axis=None, keep_dims=False):
    """
    Regularized norm
    Args:
        input_tensor: tf.Tensor
    Returns:
        tf.Tensor normed over axis
    """
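    squared_norm = tf.reduce_sum(tf.square(input_tensor), axis=axis, keepdims=keep_dims)
    # clamp below by EPSILON so the square root (and its gradient) stays finite at zero
    return gs.sqrt(gs.maximum(squared_norm, EPSILON))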


def ssp(x):
    """
    Shifted soft plus nonlinearity.
    Args:
        x: tf.Tensor
    Returns:
        tf.Tensor of same shape as x
    """
    return gs.log(0.5 * gs.exp(x) + 0.5)


def rotation_equivariant_nonlinearity(x, nonlin=ssp, biases_initializer=None):
    """
    Rotation equivariant nonlinearity.
    The -1 axis is assumed to be M index (of which there are 2 L + 1 for given L).
    Args:
        x: tf.Tensor with channels as -2 axis and M as -1 axis.
    Returns:
        tf.Tensor of same shape as x with 3d rotation-equivariant nonlinearity applied.
    """
    if biases_initializer is None:
        biases_initializer = tf.constant_initializer(0.)
    shape = x.get_shape().as_list()
    channels = shape[-2]
    representation_index = shape[-1]

    biases = tf.get_variable('biases',
                             [channels],
                             dtype=FLOAT_TYPE,
                             initializer=biases_initializer)

    if representation_index == 1:
        return nonlin(x)
    else:
        norm = norm_with_epsilon(x, axis=-1)
        nonlin_out = nonlin(gs.nn.bias_add(norm, biases))
        factor = gs.divide(nonlin_out, norm)
        # Expand dims for representation index.
        return tf.multiply(x, gs.expand_dims(factor, axis=-1))
    


def difference_matrix(geometry):
    """
    Get relative vector matrix for array of shape [N, 3].
    Args:
        geometry: tf.Tensor with Cartesian coordinates and shape [N, 3]
    Returns:
        Relative vector matrix with shape [N, N, 3]
    """
    # [N, 1, 3]
    ri = gs.expand_dims(geometry, axis=1)
    # [1, N, 3]
    rj = gs.expand_dims(geometry, axis=0)
    # [N, N, 3]
    rij = ri - rj
    return rij


def distance_matrix(geometry):
    """
    Get relative distance matrix for array of shape [N, 3].
    Args:
        geometry: tf.Tensor with Cartesian coordinates and shape [N, 3]
    Returns:
        Relative distance matrix with shape [N, N]
    """
    # [N, N, 3]
    rij = difference_matrix(geometry)
    # [N, N]
    dij = norm_with_epsilon(rij, axis=-1)
    return dij


def random_rotation_matrix(numpy_random_state):
    """
    Generates a random 3D rotation matrix from axis and angle.
    Args:
        numpy_random_state: numpy random state object
    Returns:
        Random rotation matrix.
    """
    rng = numpy_random_state
    axis = rng.randn(3)
    axis /= np.linalg.norm(axis) + EPSILON
    theta = 2 * np.pi * rng.uniform(0.0, 1.0)
    return rotation_matrix(axis, theta)


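def rotation_matrix(axis, theta):
    """Rotation by angle theta about a unit axis, as the matrix exponential of the skew-symmetric generator."""
    return scipy.linalg.expm(np.cross(np.eye(3), axis * theta))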

INFO: Using tensorflow backend
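
The helper `rotation_matrix` above exponentiates the skew-symmetric matrix `np.cross(np.eye(3), axis * theta)`. The following NumPy-only sketch illustrates why this yields a rotation: it compares a truncated power series for the matrix exponential with Rodrigues' formula. The functions `skew`, `rodrigues`, and `expm_series` are hypothetical helpers introduced here, and the power series stands in for `scipy.linalg.expm` only for the purpose of this check.

```python
import numpy as np


def skew(v):
    """Skew-symmetric matrix K such that K @ w equals np.cross(v, w)."""
    return np.cross(np.eye(3), v)


def rodrigues(axis, theta):
    """Rotation by theta about axis via R = I + sin(theta) K + (1 - cos(theta)) K^2."""
    axis = np.asarray(axis, dtype=float)
    k = skew(axis / np.linalg.norm(axis))
    return np.eye(3) + np.sin(theta) * k + (1.0 - np.cos(theta)) * (k @ k)


def expm_series(a, terms=30):
    """Truncated power series for the matrix exponential (adequate for small matrices)."""
    result = np.eye(a.shape[0])
    term = np.eye(a.shape[0])
    for n in range(1, terms):
        term = term @ a / n
        result = result + term
    return result


axis = np.array([1.0, 2.0, -0.5])
axis /= np.linalg.norm(axis)
theta = 0.7

r_exp = expm_series(skew(axis * theta))
r_rod = rodrigues(axis, theta)

same_matrix = np.allclose(r_exp, r_rod)
is_orthogonal = np.allclose(r_rod @ r_rod.T, np.eye(3))
has_unit_det = np.isclose(np.linalg.det(r_rod), 1.0)
print(same_matrix, is_orthogonal, has_unit_det)
```

Note that `random_rotation_matrix` draws the axis and the angle independently; this produces valid rotations for testing equivariance, which is all the experiments below require.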


## 4 3D Tetris Shape Classification

The training dataset contains only 8 shapes. We evaluate the prediction accuracy of tensor field networks on a test set created by randomly rotating and translating the training shapes. We use a 3-module network in which every module consists of all possible paths with $l=0$ and $l=1$ convolutions, a concatenation, a self-interaction layer, and a rotation-equivariant nonlinearity. We only use the $l=0$ output of the network, since the shape classes are invariant under rotation and are hence scalars. To get a classification from the $l=0$ output, we pool the feature vectors over all points (the code below averages them with `tf.reduce_mean`) and pass the result through a fully connected layer.

In [3]:
%%html
<center><img src='images/data.png' width=800 height=800></center>

### 4.1 Tensor Field Networks

In [4]:
import matplotlib
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as anim
import random
from math import pi, sqrt
import layers

tetris = [[(0, 0, 0), (0, 0, 1), (1, 0, 0), (1, 1, 0)],  # chiral_shape_1
          [(0, 0, 0), (0, 0, 1), (1, 0, 0), (1, -1, 0)], # chiral_shape_2
          [(0, 0, 0), (1, 0, 0), (0, 1, 0), (1, 1, 0)],  # square
          [(0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 0, 3)],  # line
          [(0, 0, 0), (0, 0, 1), (0, 1, 0), (1, 0, 0)],  # corner
          [(0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 1, 0)],  # T
          [(0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 1, 1)],  # zigzag
          [(0, 0, 0), (1, 0, 0), (1, 1, 0), (2, 1, 0)]]  # L

dataset = [np.array(points_) for points_ in tetris]
num_classes = len(dataset)

In [5]:
# radial basis functions
rbf_low = 0.0
rbf_high = 3.5
rbf_count = 4
rbf_spacing = (rbf_high - rbf_low) / rbf_count
centers = tf.cast(tf.lin_space(rbf_low, rbf_high, rbf_count), FLOAT_TYPE)

# r : [N, 3]
r = tf.placeholder(FLOAT_TYPE, shape=(4, 3))

# rij : [N, N, 3]
rij = difference_matrix(r)

# dij : [N, N]
dij = distance_matrix(r)

# rbf : [N, N, rbf_count]
gamma = 1. / rbf_spacing
rbf = tf.exp(-gamma * tf.square(tf.expand_dims(dij, axis=-1) - centers))

layer_dims = [1, 4, 4, 4]
num_layers = len(layer_dims) - 1

# embed : [N, layer1_dim, 1]
with tf.variable_scope(None, "embed"):
    embed = layers.self_interaction_layer_without_biases(tf.ones(shape=(4, 1, 1)), layer_dims[0])

input_tensor_list = {0: [embed]}

for layer, layer_dim in enumerate(layer_dims[1:]):
    with tf.variable_scope(None, 'layer' + str(layer), values=[input_tensor_list]):
        input_tensor_list = layers.convolution(input_tensor_list, rbf, rij)
        input_tensor_list = layers.concatenation(input_tensor_list)
        input_tensor_list = layers.self_interaction(input_tensor_list, layer_dim)
        input_tensor_list = layers.nonlinearity(input_tensor_list)

tfn_scalars = input_tensor_list[0][0]
tfn_output_shape = tfn_scalars.get_shape().as_list()
tfn_output = tf.reduce_mean(tf.squeeze(tfn_scalars), axis=0)
fully_connected_layer = tf.get_variable('fully_connected_weights', 
                                        [tfn_output_shape[-2], len(dataset)], dtype=FLOAT_TYPE)
output_biases = tf.get_variable('output_biases', [len(dataset)], dtype=FLOAT_TYPE)

# output : [num_classes]
output = tf.einsum('xy,x->y', fully_connected_layer, tfn_output) + output_biases

tf_label = tf.placeholder(tf.int32)

# truth : [num_classes]
truth = tf.one_hot(tf_label, num_classes)

# loss : []
loss = tf.nn.softmax_cross_entropy_with_logits(labels=truth, logits=output)

optim = tf.train.AdamOptimizer(learning_rate=1.e-3)

train_op = optim.minimize(loss)
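
The radial basis expansion at the top of the cell above can be reproduced in plain NumPy to see what the network receives as radial input. This is a sketch under assumptions: `radial_basis` and `rot_z` are hypothetical helpers, and the exact Euclidean norm is used instead of `norm_with_epsilon`, so the diagonal distances are exactly zero here.

```python
import numpy as np

rbf_low, rbf_high, rbf_count = 0.0, 3.5, 4
centers = np.linspace(rbf_low, rbf_high, rbf_count)
gamma = rbf_count / (rbf_high - rbf_low)  # equals 1 / rbf_spacing in the cell above


def radial_basis(points):
    """Gaussian radial basis features of all pairwise distances: [N, N, rbf_count]."""
    diff = points[:, None, :] - points[None, :, :]
    dist = np.linalg.norm(diff, axis=-1)
    return np.exp(-gamma * (dist[..., None] - centers) ** 2)


def rot_z(theta):
    """Rotation by theta (radians) about the z axis."""
    c, s = np.cos(theta), np.sin(theta)
    return np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])


# First Tetris shape from the dataset above.
shape = np.array([(0, 0, 0), (0, 0, 1), (1, 0, 0), (1, 1, 0)], dtype=float)
features = radial_basis(shape)

# Pairwise distances do not change under rotation and translation,
# so the radial features are invariant.
moved = shape @ rot_z(0.9).T + np.array([1.0, -2.0, 0.5])
is_invariant = np.allclose(radial_basis(moved), features)
print(features.shape, is_invariant)
```

Because these features depend only on distances, all of the orientation dependence of a filter has to come from its angular part (`rij` in the cell above).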

In [6]:
%%time

max_epochs = 2001
print_freq = 100

sess = tf.Session()
sess.run(tf.global_variables_initializer())

# training
for epoch in range(max_epochs):    
    loss_sum = 0.
    for label, shape in enumerate(dataset):
        loss_value, _ = sess.run([loss, train_op], feed_dict={r: shape, tf_label: label})
        loss_sum += loss_value
        
    if epoch % print_freq == 0:
        # There is no separate validation set: this is the mean training loss over the 8 shapes.
        print("Epoch %d: validation loss = %.3f" % (epoch, loss_sum / len(dataset)))

Epoch 0: validation loss = 2.170
Epoch 100: validation loss = 1.254
Epoch 200: validation loss = 0.639
Epoch 300: validation loss = 0.352
Epoch 400: validation loss = 0.070
Epoch 500: validation loss = 0.021
Epoch 600: validation loss = 0.009
Epoch 700: validation loss = 0.005
Epoch 800: validation loss = 0.003
Epoch 900: validation loss = 0.002
Epoch 1000: validation loss = 0.001
Epoch 1100: validation loss = 0.001
Epoch 1200: validation loss = 0.000
Epoch 1300: validation loss = 0.000
Epoch 1400: validation loss = 0.000
Epoch 1500: validation loss = 0.000
Epoch 1600: validation loss = 0.000
Epoch 1700: validation loss = 0.000
Epoch 1800: validation loss = 0.000
Epoch 1900: validation loss = 0.000
Epoch 2000: validation loss = 0.000
CPU times: user 57.9 s, sys: 19.2 s, total: 1min 17s
Wall time: 25.3 s


In [7]:
%%time

rng = np.random.RandomState()
test_set_size = 25 # we have 25 x 8 test shapes

# Build the argmax op once, outside the loop, so the graph does not grow with every prediction.
predicted_label = gs.argmax(output)

correct_predictions = 0
total_predictions = 0
for i in range(test_set_size):
    for label, shape in enumerate(dataset):
        # Apply a random rotation followed by a random translation to each training shape.
        rotation = random_rotation_matrix(rng)
        rotated_shape = np.dot(shape, rotation)
        translation = np.expand_dims(rng.uniform(low=-3., high=3., size=3), axis=0)
        translated_shape = rotated_shape + translation
        output_label = sess.run(predicted_label,
                                feed_dict={r: translated_shape, tf_label: label})
        total_predictions += 1
        if output_label == label:
            correct_predictions += 1
print('Test accuracy: %f' % (float(correct_predictions) / total_predictions))

Test accuracy: 1.000000
CPU times: user 1min 35s, sys: 8.85 s, total: 1min 44s
Wall time: 34.8 s


### 4.2 Geomstats: Preshape Space Method

Point clouds can also be treated as sets of landmarks: we can project a point cloud onto the preshape space and check which reference shape it aligns with best. In the Tetris shape classification task, the workflow is as follows:
1. Compute the tetris shapes projections on the preshape space.
2. Given an input rotated and translated tetris shape, compute its projection on the preshape space. Then we align the projection to the projected tetris shapes.
3. Find the best alignment.

See the notebook Preshape_space.ipynb for more details.
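
The workflow above can be sketched in plain NumPy. This is an illustration under stated assumptions, not the Geomstats API and not necessarily what Preshape_space.ipynb does: the projection is taken to be centering followed by scaling to unit Frobenius norm, the alignment is a rotation-only orthogonal Procrustes (Kabsch) fit, and the names `to_preshape`, `align`, and `classify` are hypothetical. It also assumes that the points of the query arrive in the same order as in the reference shapes, as they do in the test loop of Section 4.1.

```python
import numpy as np

TETRIS = [[(0, 0, 0), (0, 0, 1), (1, 0, 0), (1, 1, 0)],
          [(0, 0, 0), (0, 0, 1), (1, 0, 0), (1, -1, 0)],
          [(0, 0, 0), (1, 0, 0), (0, 1, 0), (1, 1, 0)],
          [(0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 0, 3)],
          [(0, 0, 0), (0, 0, 1), (0, 1, 0), (1, 0, 0)],
          [(0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 1, 0)],
          [(0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 1, 1)],
          [(0, 0, 0), (1, 0, 0), (1, 1, 0), (2, 1, 0)]]


def to_preshape(points):
    """Step 1: remove translation and scale (center, then unit Frobenius norm)."""
    centered = points - points.mean(axis=0)
    return centered / np.linalg.norm(centered)


def align(query, template):
    """Step 2: rotate query onto template with the best proper rotation (det = +1)."""
    u, _, vt = np.linalg.svd(query.T @ template)
    d = np.sign(np.linalg.det(u @ vt))  # exclude reflections so chiral shapes stay distinct
    return query @ (u @ np.diag([1.0, 1.0, d]) @ vt)


def classify(points, templates):
    """Step 3: index of the template with the smallest residual after alignment."""
    query = to_preshape(np.asarray(points, dtype=float))
    residuals = [np.linalg.norm(align(query, t) - t) for t in templates]
    return int(np.argmin(residuals))


templates = [to_preshape(np.array(s, dtype=float)) for s in TETRIS]

rng = np.random.default_rng(0)
q, _ = np.linalg.qr(rng.normal(size=(3, 3)))
rotation = q * np.sign(np.linalg.det(q))  # random proper rotation

predictions = [
    classify(np.array(s, dtype=float) @ rotation.T + rng.uniform(-3, 3, size=3), templates)
    for s in TETRIS
]
print(predictions)
```

Restricting the alignment to proper rotations matters for the first two shapes, which are mirror images of each other: allowing reflections would map one exactly onto the other.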

Comparing the results of the two notebooks, both the tensor field networks and the preshape space method achieve $100\%$ accuracy on the 3D Tetris shape classification task, which is consistent with both methods handling 3D rotations and translations of the input correctly. The preshape space method is also much faster than the tensor field networks (1.85 s vs. more than 2 min).

## 5 Conclusion

In this notebook, we have explained the theory of tensor field networks and demonstrated their 3D rotation and translation equivariance using a simple classification example. We also proposed a simple yet very efficient preshape space method that can be easily implemented with Geomstats. The results show that the preshape space method is much faster than the tensor field networks. However, 3D Tetris shape classification is not a hard task, since the input point clouds are just $4 \times 3$ matrices. In the future, it will be interesting to see whether the preshape space method and Geomstats can handle more complex point cloud data.

## 6 References

[1] Thomas, Nathaniel, et al. "Tensor field networks: Rotation- and translation-equivariant neural networks for 3D point clouds." arXiv preprint arXiv:1802.08219 (2018).

[2] Worrall, Daniel E., et al. "Harmonic networks: Deep translation and rotation equivariance." Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. 2017.

[3] Schütt, Kristof T., et al. "Quantum-chemical insights from deep tensor neural networks." Nature communications 8.1 (2017): 1-8.

[4] Cohen, Taco, and Max Welling. "Group equivariant convolutional networks." International conference on machine learning. PMLR, 2016.

[5] Cohen, Taco S., et al. "Spherical CNNs." arXiv preprint arXiv:1801.10130 (2018).

[6] Miolane, Nina, et al. "Geomstats: a Python package for Riemannian geometry in machine learning." Journal of Machine Learning Research 21.223 (2020): 1-9.