In [2]:
import torch
import torch.nn.functional as F
import numpy as np
import math
torch.manual_seed(0)

<torch._C.Generator at 0x7fe3a3592730>

In [44]:
from escnn.gspaces import GSpace
from escnn.group import Representation
from typing import Literal, Callable, Any
from escnn.nn import FieldType, GeometricTensor
from escnn import nn as escnn_nn


class RBSteerableConv(escnn_nn.EquivariantModule):
    def __init__(self, 
                 gspace: GSpace, 
                 in_fields: tuple[Representation] | list[Representation], 
                 out_fields: tuple[Representation] | list[Representation], 
                 in_dims: int,
                 v_kernel_size: int,
                 h_kernel_size: int,
                 v_stride: int = 1,
                 h_stride: int = 1,
                 h_dilation: int = 1,
                 v_pad_mode: Literal['valid', 'zero'] = 'zero', 
                 h_pad_mode: Literal['valid', 'zero', 'circular', 'reflect', 'replicate'] = 'circular',
                 bias: bool = True,
                 sigma: float | list[float] = None,
                 frequencies_cutoff: float | Callable[[float], int] = None,
                 rings: list[float] = None,
                 maximum_offset: int = None,
                 recompute: bool = False,
                 basis_filter: Callable[[dict], bool] = None,
                 initialize: bool = True,
                 **kwargs
                 ):
        super().__init__()
        
        assert len(in_dims) == 3
        
        v_pad_mode = v_pad_mode.lower()
        h_pad_mode = h_pad_mode.lower()
        assert v_pad_mode.lower() in ['valid', 'zero']
        assert h_pad_mode.lower() in ['valid', 'zero', 'circular', 'reflect', 'replicate']
        
        if h_pad_mode == 'valid':
            h_padding = 0
            h_pad_mode = 'zero'
        else:
            # escnn_nn.R2Conv only allows for the same amount of padding on both sides
            h_padding = [compute_required_same_padding(in_dims[i], h_kernel_size, h_stride, split=True)[1] for i in [0, 1]]
        
        out_height = compute_output_size(in_dims[-1], v_kernel_size, v_stride, dilation=1, pad=v_pad_mode!='valid')
        
        r2_conv_in_type = FieldType(gspace, out_height*v_kernel_size*in_fields) # concatenated neighborhoods
        out_type = FieldType(gspace, out_height*out_fields)

        self.r2_conv = escnn_nn.R2Conv(in_type=r2_conv_in_type, 
                                       out_type=out_type, 
                                       kernel_size=h_kernel_size, 
                                       padding=tuple(h_padding), 
                                       stride=h_stride, 
                                       dilation=h_dilation,
                                       padding_mode=h_pad_mode,
                                       groups=out_height, 
                                       bias=bias,
                                       sigma=sigma,
                                       frequencies_cutoff=frequencies_cutoff,
                                       rings=rings,
                                       maximum_offset=maximum_offset,
                                       recompute=recompute,
                                       basis_filter=basis_filter,
                                       initialize=initialize,
                                       **kwargs)
        
        self.in_fields = in_fields
        self.out_fields = out_fields
        
        self.in_height = in_dims[-1]
        self.out_height = out_height
        
        self.in_type = FieldType(gspace, self.in_height*self.in_fields) # without any neighborhood concatenation
        self.r2_conv_in_type = r2_conv_in_type # with any neighborhood concatenation
        self.out_type = out_type
        
        self.in_dims = in_dims
        # escnn_nn.R2Conv only allows for the same amount of padding on both sides
        self.out_dims = [compute_output_size(in_dims[i], h_kernel_size, h_stride, dilation=h_dilation, 
                                             pad=h_pad_mode!='valid', equal_pad=True) for i in [0, 1]] + [out_height]
        
        self.v_pad = v_pad_mode!='valid'
        self.v_stride = v_stride
        self.v_kernel_size = v_kernel_size
        
        
    def forward(self, input: GeometricTensor) -> GeometricTensor:
        """
        geomTensor of shape [batch, inHeight*sum(inFieldsizes), width, depth]
        -> geomTensor of shape [batch, outHeight*sum(outFieldsizes), width, depth]
        """
        assert input.type == self.in_type
        
        concatenated_neighborhoods = self._concat_vertical_neighborhoods(input)
        return self.r2_conv.forward(concatenated_neighborhoods)
        
        
    def _concat_vertical_neighborhoods(self, geom_tensor: GeometricTensor) -> GeometricTensor:
        """geomTensor of shape [batch, inHeight*sum(fieldsizes), width, depth]
        -> [batch, outHeight*ksize*sum(fieldsizes), width, depth]"""
        tensor = geom_tensor.tensor.reshape(-1, self.in_height, sum(field.size for field in self.in_fields), *self.in_dims[:2])

        if self.v_pad:
            # pad height
            padding = compute_required_same_padding(in_size=self.in_height, kernel_size=self.v_kernel_size, stride=self.v_stride, split=True)
            tensor = F.pad(tensor, (*([0,0]*3), *padding)) # shape:(b,padH,c*t,w,d)
        
        # compute neighborhoods
        tensor = tensor.unfold(dimension=1, size=self.v_kernel_size, step=self.v_stride) # shape:(b,outH,c*t,w,d,ksize)
        
        # concatenate neighboroods
        tensor = tensor.permute(0, 1, 5, 2, 3, 4) # shape:(b,outH,ksize,c*t,w,d)
        tensor = tensor.flatten(start_dim=1, end_dim=3) # shape:(b,outH*ksize*c*t,w,d)
        
        return GeometricTensor(tensor, self.r2_conv_in_type)
    
    
    def evaluate_output_shape(self, input_shape: tuple) -> tuple:
        assert len(input_shape) == 4
        assert input_shape[1] == self.in_type.size
    
        batch_size = input_shape[0]
        
        return (batch_size, self.out_type.size) + tuple(self.in_dims[:2])
    
    
    def train(self, *args, **kwargs):
        return self.r2_conv.train(*args, **kwargs)
    
    
    def check_equivariance(self, atol: float = 1e-7, rtol: float = 1e-5) -> list[tuple[Any, float]]:
        r"""
        
        Method that automatically tests the equivariance of the current module.
        The default implementation of this method relies on :meth:`escnn.nn.GeometricTensor.transform` and uses the
        the group elements in :attr:`~escnn.nn.FieldType.testing_elements`.
        
        This method can be overwritten for custom tests.
        
        Returns:
            a list containing containing for each testing element a pair with that element and the corresponding
            equivariance error
        
        """
    
        x = torch.randn(3, self.in_type.size, *self.in_dims[:2])
        x = GeometricTensor(x, self.in_type)
        
        errors = []
        for el in self.out_type.testing_elements:
            el = self.in_type.gspace.fibergroup.sample()
            print(el)
            
            out1 = self(x).transform(el).tensor.detach().numpy()
            out2 = self(x.transform(el)).tensor.detach().numpy()
        
            errs = out1 - out2
            errs = np.abs(errs).reshape(-1)
            print(el, errs.max(), errs.mean(), errs.var())
        
            assert np.allclose(out1, out2, atol=atol, rtol=rtol), \
                f'The error found during equivariance check with element "{el}" \
                    is too high: max = {errs.max()}, mean = {errs.mean()} var ={errs.var()}'
            
            errors.append((el, errs.mean()))
        
        return errors

        
def compute_output_size(in_size: int, kernel_size: int, stride: int, dilation: int, pad: bool, equal_pad: bool = False) -> int:
    padding = 0
    if pad:
        pad_split = compute_required_same_padding(in_size, kernel_size, stride, split=True)
        padding = 2*pad_split[1] if equal_pad else sum(pad_split)

    return ((in_size - dilation*(kernel_size-1) + padding - 1) // stride) + 1


def compute_required_same_padding(in_size: int, kernel_size: int, stride: int, split: bool = False) -> int | tuple[int, int]:
    out_size = math.ceil(in_size/stride)
    padding = max((out_size-1) * stride - in_size + kernel_size, 0)
    
    if split:
        return math.floor(padding/2), math.ceil(padding/2)
    else:
        return padding
        
        
# TODO DataAugmentation (with vector rotation)
# TODO 3D Pooling
# TODO 3D Upsampling

### Data

In [41]:
BATCH_SIZE = 1
WIDTH, DEPTH, HEIGHT = 48, 48, 32
RB_CHANNELS = 4

sim_data = torch.randn(BATCH_SIZE, WIDTH, DEPTH, HEIGHT, RB_CHANNELS)

tensor = sim_data.permute(0, 3, 4, 1, 2).reshape(BATCH_SIZE, HEIGHT*RB_CHANNELS, WIDTH, DEPTH)

### Model definition

In [48]:
from escnn import gspaces
from collections import OrderedDict

class RBModel(escnn_nn.SequentialModule):
    def __init__(self, gspace: GSpace = gspaces.flipRot2dOnR2(N=4),
                   rb_dims: tuple = (48, 48, 32),
                   v_kernel_size: int = 3,
                   h_kernel_size: int = 5,
                   hidden_channels: tuple = 2*(10,10)
                   ):      
    
        rb_fields = [gspace.trivial_repr, gspace.irrep(1, 1), gspace.trivial_repr]
        hidden_field_type = [gspace.regular_repr]
        
        layers = OrderedDict()
        
        # INPUT LAYER
        conv = RBSteerableConv(gspace=gspace, 
                            in_fields=rb_fields, 
                            out_fields=hidden_channels[0]*hidden_field_type, 
                            in_dims=rb_dims,
                            v_kernel_size=v_kernel_size, 
                            h_kernel_size=h_kernel_size)
        layers['Conv1'] = conv
        
        # HIDDEN LAYERS
        for i, channels in enumerate(hidden_channels[1:], start=2):
            conv = RBSteerableConv(gspace=gspace, 
                                in_fields=layers[f'Conv{i-1}'].out_fields, 
                                out_fields=channels*hidden_field_type, 
                                in_dims=layers[f'Conv{i-1}'].out_dims,
                                v_kernel_size=v_kernel_size, 
                                h_kernel_size=h_kernel_size)
            layers[f'Conv{i}'] = conv
        
        # OUTPUT LAYER
        i += 1
        conv = RBSteerableConv(gspace=gspace, 
                            in_fields=layers[f'Conv{i-1}'].out_fields, 
                            out_fields=rb_fields,
                            in_dims=layers[f'Conv{i-1}'].out_dims,
                            v_kernel_size=v_kernel_size, 
                            h_kernel_size=h_kernel_size)
        layers[f'Conv{i}'] = conv
        
        super().__init__(layers)
        
    @property
    def first_layer(self):
        if len(self._modules) == 0:
            return None
        return self._modules[next(iter(self._modules))]
        
        
    def check_equivariance(self, atol: float = 1e-4, rtol: float = 1e-5) -> list[tuple[Any, float]]:
        r"""
        
        Method that automatically tests the equivariance of the current module.
        The default implementation of this method relies on :meth:`escnn.nn.GeometricTensor.transform` and uses the
        the group elements in :attr:`~escnn.nn.FieldType.testing_elements`.
        
        This method can be overwritten for custom tests.
        
        Returns:
            a list containing containing for each testing element a pair with that element and the corresponding
            equivariance error
        
        """
        
        x = torch.randn(3, self.in_type.size, *self.first_layer.in_dims[:2])
        x = GeometricTensor(x, self.in_type)
        
        errors = []
        for el in self.out_type.testing_elements:
            el = self.in_type.gspace.fibergroup.sample()
            print(el)
            
            out1 = self(x).transform(el).tensor.detach().numpy()
            out2 = self(x.transform(el)).tensor.detach().numpy()
        
            errs = out1 - out2
            errs = np.abs(errs).reshape(-1)
            print(el, errs.max(), errs.mean(), errs.var())
        
            assert np.allclose(out1, out2, atol=atol, rtol=rtol), \
                f'The error found during equivariance check with element "{el}" \
                    is too high: max = {errs.max()}, mean = {errs.mean()} var ={errs.var()}'
            
            errors.append((el, errs.mean()))
        
        return errors

model = RBModel()

In [49]:
model.forward(model.in_type(tensor))
model.check_equivariance()

(-, 2[2pi/4])
(-, 2[2pi/4]) 1.5258789e-05 2.061465e-06 2.5653188e-12
(+, 3[2pi/4])
(+, 3[2pi/4]) 1.5258789e-05 1.983555e-06 2.3780927e-12
(+, 2[2pi/4])
(+, 2[2pi/4]) 1.50203705e-05 2.1090534e-06 2.681103e-12
(-, 1[2pi/4])
(-, 1[2pi/4]) 1.335144e-05 1.8674493e-06 2.116885e-12
(-, 3[2pi/4])
(-, 3[2pi/4]) 1.4305115e-05 2.042686e-06 2.52035e-12
(-, 1[2pi/4])
(-, 1[2pi/4]) 1.335144e-05 1.8674493e-06 2.116885e-12
(-, 2[2pi/4])
(-, 2[2pi/4]) 1.5258789e-05 2.061465e-06 2.5653188e-12
(+, 2[2pi/4])
(+, 2[2pi/4]) 1.50203705e-05 2.1090534e-06 2.681103e-12


[((-, 2[2pi/4]), 2.061465e-06),
 ((+, 3[2pi/4]), 1.983555e-06),
 ((+, 2[2pi/4]), 2.1090534e-06),
 ((-, 1[2pi/4]), 1.8674493e-06),
 ((-, 3[2pi/4]), 2.042686e-06),
 ((-, 1[2pi/4]), 1.8674493e-06),
 ((-, 2[2pi/4]), 2.061465e-06),
 ((+, 2[2pi/4]), 2.1090534e-06)]