# E(n)-Equivariant Steerable CNNs  -  Equivariant MLPs


In [1]:
import sys
sys.path.append('../')

import torch
import numpy as np

from escnn import gspaces
from escnn import nn
from escnn import group

The **escnn** library also supports MLPs equivariant to compact groups, which can be seen as a special case for $n=0$.
This is done by replacing the convolution layers (e.g. [R3Conv](https://quva-lab.github.io/escnn/api/escnn.nn.html#r3conv)) with the [Linear](https://quva-lab.github.io/escnn/api/escnn.nn.html#linear) layer and by choosing the [no_base_space](https://quva-lab.github.io/escnn/api/escnn.gspaces.html#group-action-trivial-on-single-point) `GSpace` (e.g., instead of [rot3dOnR3](https://quva-lab.github.io/escnn/api/escnn.gspaces.html#escnn.gspaces.rot3dOnR3)). 

All other modules can be used in a similar way, e.g. batch-norm and non-linearities.


Here, we provide an example with `G=SO(3)` and one with `G=SO(2)`.

In [2]:
class SO3MLP(nn.EquivariantModule):
    """MLP over 3D points, equivariant (approximately) to the rotation group SO(3).

    The input is one field transforming under the standard representation of
    SO(3), i.e. the coordinates of a point in 3D space. Four hidden blocks, each
    ``Linear -> IIDBatchNorm1d -> QuotientFourierELU``, work on band-limited
    spherical signals. A final ``Linear`` layer maps to a single field
    transforming under the frequency-2 irrep (a 5-dimensional vector).

    The ELU is applied point-wise to spherical signals sampled on a finite grid
    of points, which makes the equivariance only approximate.
    """

    def __init__(self, n_classes=10):
        """Build the five blocks of the network.

        Args:
            n_classes: kept in the signature; it is not read anywhere in this class.
        """
        super().__init__()

        # symmetry group of the model: SO(3)
        self.G = group.so3_group()

        # an MLP has no base space
        self.gspace = gspaces.no_base_space(self.G)

        # input feature: the coordinates of a point in 3D space
        self.in_type = self.gspace.type(self.G.standard_representation())

        # N.B.: the first time this model is instantiated, the library computes the spherical grids numerically,
        # which can take some time. These grids are then cached on disk, so future calls should be considerably faster.

        # Hidden blocks: (number of spherical signals, maximum frequency, number of grid points)
        #   block1: 3 signals, frequencies up to 1, 16 points
        #   block2: 8 signals, frequencies up to 3, 40 points
        #   block3: 8 signals, frequencies up to 3, 40 points
        #   block4: 5 signals, frequencies up to 2, 25 points
        # Consecutive blocks are chained through the previous block's output type.
        self.block1 = self._spherical_elu_block(self.in_type, channels=3, max_frequency=1, grid_size=16)
        self.block2 = self._spherical_elu_block(self.block1.out_type, channels=8, max_frequency=3, grid_size=40)
        self.block3 = self._spherical_elu_block(self.block2.out_type, channels=8, max_frequency=3, grid_size=40)
        self.block4 = self._spherical_elu_block(self.block3.out_type, channels=5, max_frequency=2, grid_size=25)

        # Final linear layer mapping to the output features:
        # a 5-dimensional vector transforming according to the Wigner-D matrix of frequency 2
        self.out_type = self.gspace.type(self.G.irrep(2))
        self.block5 = nn.Linear(self.block4.out_type, self.out_type)

    def _spherical_elu_block(self, in_type, channels, max_frequency, grid_size):
        """Return a ``Linear -> IIDBatchNorm1d -> QuotientFourierELU`` block.

        Args:
            in_type: field type of the features fed to the block.
            channels: number of spherical signals in the block's output features.
            max_frequency: the signals include all frequencies up to this value.
            grid_size: number of points of the 'thomson' sphere grid on which the
                signals are sampled to apply the point-wise ELU.
        """
        # The representation of SO(3) on spherical signals is technically a quotient representation,
        # identified by the subgroup of planar rotations, which has id=(False, -1) in the library.
        activation = nn.QuotientFourierELU(
            self.gspace,
            subgroup_id=(False, -1),
            channels=channels,
            irreps=self.G.bl_sphere_representation(L=max_frequency).irreps,
            grid=self.G.sphere_grid(type='thomson', N=grid_size),
            inplace=True,
        )

        # the Linear layer maps to the input expected by the activation;
        # batch-norm and the activation itself follow
        return nn.SequentialModule(
            nn.Linear(in_type, activation.in_type),
            nn.IIDBatchNorm1d(activation.in_type),
            activation,
        )

    def forward(self, x: nn.GeometricTensor):
        """Apply the five blocks in order to ``x``, which must have type ``self.in_type``."""
        # check the input has the right type
        assert x.type == self.in_type

        # Every block consumes a GeometricTensor of its input type and returns one of its
        # output type, hence consecutive blocks were built with matching types.
        for block in (self.block1, self.block2, self.block3, self.block4, self.block5):
            x = block(x)

        return x

    def evaluate_output_shape(self, input_shape: tuple):
        """Return, as a list, the output shape for a 2-dimensional ``input_shape``.

        The first entry is copied from the input; the second becomes ``self.out_type.size``.
        """
        dims = list(input_shape)
        assert len(dims) == 2, dims
        assert dims[1] == self.in_type.size, dims
        return [dims[0], self.out_type.size]

Let's build the model

In [3]:
# run on the GPU when one is available, otherwise on the CPU
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# instantiate the SO(3) MLP and move it to the selected device
model = SO3MLP()
model = model.to(device)

  return element.as_euler(param)


Let's test the equivariance of the model

In [4]:
# wide numpy printing without scientific notation, so every row of norms fits on one line
np.set_printoptions(linewidth=10000, precision=4, suppress=True)

# switch the model to evaluation mode before measuring the equivariance error
model.eval()

B = 10

# draw B random points in 3D and wrap them in a GeometricTensor of the model's input type
x = model.in_type(torch.randn(B, 3))

banner = '##########################################################################################'

print(banner)
with torch.no_grad():
    # reference output f(x), moved back to the CPU for printing
    y = model(x.to(device)).to('cpu')
    print("Outputs' magnitudes")
    print(torch.linalg.norm(y.tensor, dim=1).numpy().reshape(-1))
    print(banner)
    print("Errors' magnitudes")
    for r in range(8):
        # sample a random rotation
        g = model.G.sample()

        # f(g@x): transform the input first, then run the model
        x_transformed = (g @ x).to(device)
        y_transformed = model(x_transformed).to('cpu')

        # equivariance means f(g@x) = g@f(x) = g@y, so print the norm of the difference per sample
        residual = y_transformed.tensor - (g @ y).tensor
        print(torch.linalg.norm(residual, dim=1).numpy().reshape(-1))

print(banner)
print()



##########################################################################################
Outputs' magnitudes
[0.1348 0.2192 0.2075 0.1902 0.2082 0.1977 0.1357 0.2424 0.227  0.1865]
##########################################################################################
Errors' magnitudes
[0.0079 0.024  0.0232 0.0187 0.0244 0.0121 0.0131 0.0207 0.0155 0.0113]
[0.0085 0.0183 0.0213 0.0175 0.0297 0.0146 0.0183 0.0167 0.0194 0.017 ]
[0.0131 0.0202 0.0153 0.0148 0.0178 0.0074 0.0086 0.0144 0.0123 0.009 ]
[0.0051 0.0132 0.0223 0.0143 0.0239 0.0216 0.0092 0.0153 0.0233 0.0091]
[0.0128 0.0288 0.0185 0.0136 0.0153 0.0142 0.0135 0.0194 0.0206 0.0076]
[0.0078 0.0197 0.0214 0.0153 0.0198 0.0132 0.0105 0.0145 0.0167 0.0083]
[0.0072 0.0237 0.0202 0.0132 0.0248 0.0109 0.0104 0.0268 0.0167 0.0112]
[0.005  0.0246 0.0243 0.0172 0.0259 0.016  0.0069 0.0187 0.0222 0.011 ]
##########################################################################################



In [5]:
class SO3MLPtensor(nn.EquivariantModule):
    """MLP over 3D points, equivariant to SO(3), using tensor-product non-linearities.

    The input is one field transforming under the standard representation of
    SO(3), i.e. the coordinates of a point in 3D space. Each hidden block ends
    with a ``TensorProductModule`` (essentially a quadratic function) instead of
    a point-wise activation. A final ``Linear`` layer maps to a single field
    transforming under the frequency-2 irrep (a 5-dimensional vector).
    """

    def __init__(self, n_classes=10):
        """Build the five blocks of the network.

        Args:
            n_classes: kept in the signature; it is not read anywhere in this class.
        """
        super().__init__()

        # symmetry group of the model: SO(3)
        self.G = group.so3_group()

        # an MLP has no base space
        self.gspace = gspaces.no_base_space(self.G)

        # input feature: the coordinates of a point in 3D space
        in_repr = self.G.standard_representation()
        self.in_type = self.gspace.type(in_repr)

        # Block 1
        # The tensor-product non-linearity maps the input to one spherical signal,
        # band-limited to frequency 2. This first block has no Linear layer:
        # batch-norm is applied first, then the non-linearity.
        ttype = self.gspace.type(self.G.bl_sphere_representation(L=2))
        activation1 = nn.TensorProductModule(self.in_type, ttype)
        self.block1 = nn.SequentialModule(
            nn.IIDBatchNorm1d(activation1.in_type),
            activation1,
        )

        # Blocks 2-4: Linear -> batch-norm -> tensor-product non-linearity.
        # The input and output types of each non-linearity must have the same number of fields (here, 8),
        # and the input one shouldn't have frequencies higher than the output of the previous block.
        self.block2 = self._tensor_product_block(
            self.block1.out_type,
            self.G.bl_sphere_representation(L=2),
            self.G.bl_sphere_representation(L=3),
        )
        self.block3 = self._tensor_product_block(
            self.block2.out_type,
            self.G.bl_sphere_representation(L=3),
            self.G.bl_sphere_representation(L=3),
        )
        # the final layer only requires frequency 2 features, so there is no point in generating other frequencies
        self.block4 = self._tensor_product_block(
            self.block3.out_type,
            self.G.bl_sphere_representation(L=3),
            self.G.irrep(2),
        )

        # Final linear layer mapping to the output features:
        # a 5-dimensional vector transforming according to the Wigner-D matrix of frequency 2
        self.out_type = self.gspace.type(self.G.irrep(2))
        self.block5 = nn.Linear(self.block4.out_type, self.out_type)

    def _tensor_product_block(self, prev_type, in_repr, out_repr, fields=8):
        """Return a ``Linear -> IIDBatchNorm1d -> TensorProductModule`` block.

        Args:
            prev_type: field type of the features fed to the block.
            in_repr: representation of each input field of the non-linearity.
            out_repr: representation of each output field of the non-linearity.
            fields: number of fields in both the input and the output type of the non-linearity.
        """
        activation = nn.TensorProductModule(
            in_type=self.gspace.type(*[in_repr] * fields),
            out_type=self.gspace.type(*[out_repr] * fields),
        )
        return nn.SequentialModule(
            nn.Linear(prev_type, activation.in_type),
            nn.IIDBatchNorm1d(activation.in_type),
            activation,
        )

    def forward(self, x: nn.GeometricTensor):
        """Apply the five blocks in order to ``x``, which must have type ``self.in_type``."""
        # check the input has the right type
        assert x.type == self.in_type

        # Every block consumes a GeometricTensor of its input type and returns one of its
        # output type, hence consecutive blocks were built with matching types.
        for block in (self.block1, self.block2, self.block3, self.block4, self.block5):
            x = block(x)

        return x

    def evaluate_output_shape(self, input_shape: tuple):
        """Return, as a list, the output shape for a 2-dimensional ``input_shape``.

        The first entry is copied from the input; the second becomes ``self.out_type.size``.
        """
        dims = list(input_shape)
        assert len(dims) == 2, dims
        assert dims[1] == self.in_type.size, dims
        return [dims[0], self.out_type.size]

Let's build the model

In [6]:
# This model is kept on the CPU; to select a GPU when one is available, set
# device to `'cuda' if torch.cuda.is_available() else 'cpu'` instead.
device = 'cpu'

# instantiate the tensor-product SO(3) MLP and move it to the selected device
model = SO3MLPtensor()
model = model.to(device)

Let's test the equivariance of the model

In [7]:
# wide numpy printing without scientific notation, so every row of norms fits on one line
np.set_printoptions(linewidth=10000, precision=4, suppress=True)

# switch the model to evaluation mode before measuring the equivariance error
model.eval()

B = 6

# draw B random points in 3D and wrap them in a GeometricTensor of the model's input type
x = model.in_type(torch.randn(B, 3))

banner = '##########################################################################################'

print(banner)
with torch.no_grad():
    # reference output f(x), moved back to the CPU for printing
    y = model(x.to(device)).to('cpu')
    print("Outputs' magnitudes")
    print(torch.linalg.norm(y.tensor, dim=1).numpy().reshape(-1))
    print(banner)
    print("Errors' magnitudes")
    for r in range(8):
        # sample a random rotation
        g = model.G.sample()

        # f(g@x): transform the input first, then run the model
        x_transformed = (g @ x).to(device)
        y_transformed = model(x_transformed).to('cpu')

        # equivariance means f(g@x) = g@f(x) = g@y, so print the norm of the difference per sample
        residual = y_transformed.tensor - (g @ y).tensor
        print(torch.linalg.norm(residual, dim=1).numpy().reshape(-1))

print(banner)
print()



##########################################################################################
Outputs' magnitudes
[133357.12        0.0907      0.      62198.625      80.044       0.    ]
##########################################################################################
Errors' magnitudes
[0.18   0.     0.     0.0688 0.0002 0.    ]
[0.0982 0.     0.     0.0999 0.0002 0.    ]
[0.3416 0.     0.     0.1136 0.0001 0.    ]
[0.1337 0.     0.     0.1165 0.0001 0.    ]
[0.2221 0.     0.     0.11   0.     0.    ]
[0.2336 0.     0.     0.076  0.0001 0.    ]
[0.2423 0.     0.     0.0608 0.0001 0.    ]
[0.1007 0.     0.     0.0967 0.0002 0.    ]
##########################################################################################



In [8]:
class SO2MLP(nn.EquivariantModule):
    """MLP over 2D points, equivariant (approximately) to the rotation group SO(2).

    The input is one field transforming under the standard representation of
    SO(2), i.e. the coordinates of a point in 2D space. Four hidden blocks, each
    ``Linear -> IIDBatchNorm1d -> FourierELU``, work on band-limited signals
    over SO(2) itself. A final ``Linear`` layer maps to a single field
    transforming under the frequency-2 irrep (a 2-dimensional vector rotating
    with frequency 2).
    """

    def __init__(self, n_classes=10):
        """Build the five blocks of the network.

        Args:
            n_classes: kept in the signature; it is not read anywhere in this class.
        """
        super().__init__()

        # symmetry group of the model: SO(2)
        self.G = group.so2_group()

        # an MLP has no base space
        self.gspace = gspaces.no_base_space(self.G)

        # input feature: the coordinates of a point in 2D space
        self.in_type = self.gspace.type(self.G.standard_representation())

        # Hidden blocks use the regular representation of SO(2) acting on signals over SO(2) itself.
        # (number of signals, maximum frequency, number of grid points)
        #   block1: 3 signals, frequencies up to 1,  6 points
        #   block2: 8 signals, frequencies up to 3, 16 points
        #   block3: 8 signals, frequencies up to 3, 16 points
        #   block4: 5 signals, frequencies up to 2, 12 points
        # Consecutive blocks are chained through the previous block's output type.
        self.block1 = self._fourier_elu_block(self.in_type, channels=3, max_frequency=1, grid_size=6)
        self.block2 = self._fourier_elu_block(self.block1.out_type, channels=8, max_frequency=3, grid_size=16)
        self.block3 = self._fourier_elu_block(self.block2.out_type, channels=8, max_frequency=3, grid_size=16)
        self.block4 = self._fourier_elu_block(self.block3.out_type, channels=5, max_frequency=2, grid_size=12)

        # Final linear layer mapping to the output features:
        # a 2-dimensional vector rotating with frequency 2
        self.out_type = self.gspace.type(self.G.irrep(2))
        self.block5 = nn.Linear(self.block4.out_type, self.out_type)

    def _fourier_elu_block(self, in_type, channels, max_frequency, grid_size):
        """Return a ``Linear -> IIDBatchNorm1d -> FourierELU`` block.

        Args:
            in_type: field type of the features fed to the block.
            channels: number of signals in the block's output features.
            max_frequency: the signals include all frequencies up to this value.
            grid_size: number of equally distributed points of the circle
                discretization used to apply the point-wise ELU.
        """
        activation = nn.FourierELU(
            self.gspace,
            channels=channels,
            irreps=self.G.bl_regular_representation(L=max_frequency).irreps,
            inplace=True,
            # the following kwargs are used to build the discretization of the circle
            type='regular', N=grid_size,
        )

        # the Linear layer maps to the input expected by the activation;
        # batch-norm and the activation itself follow
        return nn.SequentialModule(
            nn.Linear(in_type, activation.in_type),
            nn.IIDBatchNorm1d(activation.in_type),
            activation,
        )

    def forward(self, x: nn.GeometricTensor):
        """Apply the five blocks in order to ``x``, which must have type ``self.in_type``."""
        # check the input has the right type
        assert x.type == self.in_type

        # Every block consumes a GeometricTensor of its input type and returns one of its
        # output type, hence consecutive blocks were built with matching types.
        for block in (self.block1, self.block2, self.block3, self.block4, self.block5):
            x = block(x)

        return x

    def evaluate_output_shape(self, input_shape: tuple):
        """Return, as a list, the output shape for a 2-dimensional ``input_shape``.

        The first entry is copied from the input; the second becomes ``self.out_type.size``.
        """
        dims = list(input_shape)
        assert len(dims) == 2, dims
        assert dims[1] == self.in_type.size, dims
        return [dims[0], self.out_type.size]

Let's build the model

In [9]:
# instantiate the SO(2) MLP and move it to the device selected in an earlier cell
model = SO2MLP()
model = model.to(device)

Let's test the equivariance of the model

In [10]:
# wide numpy printing without scientific notation, so every row of norms fits on one line
np.set_printoptions(linewidth=10000, precision=4, suppress=True)

# switch the model to evaluation mode before measuring the equivariance error
model.eval()

B = 10

# draw B random points in 2D and wrap them in a GeometricTensor of the model's input type
x = model.in_type(torch.randn(B, 2))

banner = '##########################################################################################'

print(banner)
with torch.no_grad():
    # reference output f(x), moved back to the CPU for printing
    y = model(x.to(device)).to('cpu')
    print("Outputs' magnitudes")
    print(torch.linalg.norm(y.tensor, dim=1).numpy().reshape(-1))
    print(banner)
    print("Errors' magnitudes")
    for r in range(8):
        # sample a random rotation
        g = model.G.sample()

        # f(g@x): transform the input first, then run the model
        x_transformed = (g @ x).to(device)
        y_transformed = model(x_transformed).to('cpu')

        # equivariance means f(g@x) = g@f(x) = g@y, so print the norm of the difference per sample
        residual = y_transformed.tensor - (g @ y).tensor
        print(torch.linalg.norm(residual, dim=1).numpy().reshape(-1))


print(banner)
print()



##########################################################################################
Outputs' magnitudes
[0.115  0.1023 0.014  0.0306 0.0657 0.0292 0.0484 0.028  0.033  0.026 ]
##########################################################################################
Errors' magnitudes
[0.0014 0.0004 0.     0.0001 0.0005 0.0001 0.0002 0.0001 0.0001 0.    ]
[0.0018 0.0014 0.     0.0001 0.0003 0.0001 0.0004 0.0002 0.0001 0.0001]
[0.0027 0.0006 0.0001 0.0002 0.0005 0.0002 0.0004 0.     0.0002 0.0002]
[0.0026 0.0011 0.0001 0.0002 0.0004 0.0002 0.0005 0.0001 0.0002 0.0002]
[0.001  0.001  0.     0.0001 0.0005 0.0001 0.0001 0.     0.0001 0.0001]
[0.0016 0.001  0.0001 0.0001 0.0005 0.0002 0.0002 0.     0.0002 0.0001]
[0.0016 0.0007 0.     0.0001 0.0006 0.     0.0002 0.0002 0.0001 0.    ]
[0.0019 0.0012 0.     0.0001 0.0004 0.0001 0.0004 0.0002 0.0001 0.0001]
##########################################################################################

