## GConv Module Tutorial

The `gconv` module is a PyTorch extension that implements regular group convolutions. Its group convolution layers are as straightforward to use as regular PyTorch convolution layers and require no expert knowledge to use effectively. At the same time, `gconv` offers a flexible, fully customizable framework for working with group convolutions. Both 2D and 3D inputs are supported, as are discrete groups and approximations of continuous groups.

This tutorial demonstrates how to get started with the `gconv` module, and how the module can be used to implement custom group convolutions.


### Getting Started

The `gconv` modules are as straightforward to use as any regular PyTorch convolution module. The only difference is that each module returns both the feature maps and the group elements on which they are defined. See the example below:

In [1]:
import torch
import gconv.nn as gnn

# input batch of 3d data with 3 channels
x1 = torch.randn(1, 3, 28, 28, 28)

# the lifting layer is required to lift R3 input to the group
lifting_layer = gnn.GLiftingConvSE3(in_channels=3, out_channels=16, kernel_size=5)
gconv_layer = gnn.GSeparableConvSE3(in_channels=16, out_channels=32, kernel_size=5)

# global avg pooling to produce invariant features after the group convolutions
pool = gnn.GAvgGlobalPool()

# gconv modules return the feature maps and the group elements on which they are defined
x2, H1 = lifting_layer(x1)
x3, H2 = gconv_layer(x2, H1)

y = pool(x3, H2)

print(y.shape)

torch.Size([1, 32, 1])
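
To make the roles of the lifting layer and the global pooling more concrete, the following sketch reproduces the idea in plain PyTorch for the four 90-degree rotations (the group C4). It does not use `gconv` and makes no claim about how `gconv` is implemented internally; `lift_c4` is a hypothetical helper introduced only for illustration, and exact 90-degree rotations are chosen so that no interpolation is needed.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

def lift_c4(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    """Correlate x with all four 90-degree rotations of weight (hypothetical helper)."""
    outs = [F.conv2d(x, torch.rot90(weight, k, dims=(-2, -1))) for k in range(4)]
    # stack the four responses along a new group axis: (B, C_out, 4, H, W)
    return torch.stack(outs, dim=2)

x = torch.randn(1, 3, 9, 9)
weight = torch.randn(8, 3, 3, 3)

# averaging over the group axis and the spatial axes yields features
# that do not change when the input is rotated by 90 degrees
y = lift_c4(x, weight).mean(dim=(2, 3, 4))
y_rot = lift_c4(torch.rot90(x, 1, dims=(-2, -1)), weight).mean(dim=(2, 3, 4))

print(y.shape)                              # torch.Size([1, 8])
print(torch.allclose(y, y_rot, atol=1e-4))  # True
```

Rotating the input only permutes the responses along the group axis and rotates them spatially, which is why the global average is unaffected.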




## Implementing a custom group convolution

This section explains how to implement a custom group convolution. A group convolution consists of (1) a `GroupKernel` module that manages the weights and performs all group-related actions and (2) a `GroupConv` module that samples the kernel and performs the convolution. In this example, we will create a lifting and a separable group convolution module for 2D rotations and translations, i.e., the SE(2) group. We start by implementing the lifting and separable kernels.

### Implementing the SE(2) kernels

All that is required for implementing a custom kernel is the logic that deals with the group elements. For this, the following callable objects should be implemented for the given group `H`:

* `det_H(H: Tensor) -> Tensor`: accepts a tensor of group elements and returns their determinants.
* `inverse_H(H: Tensor) -> Tensor`: accepts a tensor of group elements and returns their inverses.
* `left_apply_H_to_H(H1: Tensor, H2: Tensor) -> Tensor`: accepts tensors of group elements `H1` of shape `(N, ...)` and `H2` of shape `(M, ...)` and returns the pairwise left group action resulting in a tensor of shape `(N, M, ...)`.
* `left_apply_H_to_Rn(H: Tensor, grid: Tensor) -> Tensor`: accepts a tensor of group elements `H` of shape `(N, *dims)` and a tensor of Rn vectors of shape `(..., Rn)` and calculates the pairwise product between H and grid, resulting in a tensor of shape `(N, ..., Rn)`.
* `grid_sample(grid: Tensor, signal: Tensor, signal_grid: Tensor) -> Tensor`: given a tensor `signal` of shape `(N, S)` and a corresponding tensor `signal_grid` of group elements of shape `(N, *dims)`, samples the signal at the given `grid` tensor of group elements of shape `(M, *dims)`. The returned signal is a tensor of shape `(M, S)`.

We will implement these functions for the SO(2) group below. For this, we need to choose a representation for the group elements. For SO(2), a simple representation is the rotation angle. Hence, our SO(2) elements lie in the range [0, $2\pi$).

Before we implement the above functions, we first introduce two functions that allow us to sample SO(2) elements: `uniform_grid_so2(n: int) -> Tensor`, which generates a uniform grid of `n` elements on SO(2), and `random_grid_so2(n: int) -> Tensor`, which generates a grid of `n` randomly sampled SO(2) elements. We also implement a function that samples a uniform grid on R2.

In [14]:
import math
from torch import Tensor

def uniform_grid_so2(n: int, device: str | None = None) -> Tensor:
    return torch.linspace(0, 2 * math.pi, n + 1, device=device)[:-1].view(-1, 1)

def random_grid_so2(n: int, device: str | None = None) -> Tensor:
    return 2 * math.pi * torch.rand(n, device=device).view(-1, 1)

def create_grid_R2(n: int, device: str | None = None) -> Tensor:
    x = torch.linspace(-1, 1, n, device=device)
    X, Y = torch.meshgrid((x, x), indexing="xy")
    
    return torch.stack((Y, X), dim=-1)

grid = random_grid_so2(8)

print(grid)

grid_R2 = create_grid_R2(5)

print(grid_R2.shape)

tensor([[5.7662],
        [4.7929],
        [1.7150],
        [1.1299],
        [3.2560],
        [0.8922],
        [4.4672],
        [1.4345]])
torch.Size([5, 5, 2])
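
The `n + 1` and `[:-1]` in the uniform grid may look odd at first. The sketch below restates the function without the `device` argument and checks that the grid is evenly spaced and excludes $2\pi$, which would duplicate the rotation by 0. The variable names `grid4` and `spacing` are illustrative only.

```python
import math
import torch

def uniform_grid_so2(n: int) -> torch.Tensor:
    # sample n + 1 points on [0, 2*pi] and drop the last one, since
    # 2*pi represents the same rotation as 0
    return torch.linspace(0, 2 * math.pi, n + 1)[:-1].view(-1, 1)

grid4 = uniform_grid_so2(4)
spacing = grid4[1:] - grid4[:-1]

print(grid4.flatten())  # the angles 0, pi/2, pi and 3*pi/2
print(torch.allclose(spacing, torch.full_like(spacing, math.pi / 2)))  # True
```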


#### Implementing the group determinant

The determinant of any rotation is simply 1, which is the default for the `GroupKernel`, so we do not need to implement it. Here, we will do so anyway for demonstration purposes.

In [15]:
def det_so2(H: Tensor) -> Tensor:
    return torch.ones_like(H)

print(det_so2(grid))

tensor([[1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.]])
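
As a sanity check on the claim that every rotation has determinant 1, note that for the rotation matrix of an angle $\theta$ we have $\det R(\theta) = \cos^2\theta + \sin^2\theta = 1$. The following sketch verifies this numerically for a few arbitrarily chosen angles; it builds the matrices explicitly and is independent of `gconv`.

```python
import math
import torch

theta = torch.tensor([0.0, 0.5, math.pi / 2, 4.0])

# build one 2x2 rotation matrix per angle: [[cos, -sin], [sin, cos]]
R = torch.stack(
    (
        torch.stack((torch.cos(theta), -torch.sin(theta)), dim=-1),
        torch.stack((torch.sin(theta), torch.cos(theta)), dim=-1),
    ),
    dim=-2,
)

dets = torch.linalg.det(R)
print(torch.allclose(dets, torch.ones(4), atol=1e-6))  # True
```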


#### Implementing the group inverse

The inverse of a rotation $\theta$ is simply a rotation by $-\theta$.

In [16]:
def inverse_so2(H: Tensor) -> Tensor:
    return -H

print(inverse_so2(grid))

tensor([[-5.7662],
        [-4.7929],
        [-1.7150],
        [-1.1299],
        [-3.2560],
        [-0.8922],
        [-4.4672],
        [-1.4345]])
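
Note that the negated angles fall outside [0, $2\pi$). They still describe the correct rotations, because sine and cosine are periodic. If angles inside the interval are preferred, a wrapped variant can be used instead. The sketch below is an optional alternative, not something `gconv` requires as far as this tutorial shows, and `inverse_so2_wrapped` is a hypothetical name.

```python
import math
import torch

def inverse_so2_wrapped(H: torch.Tensor) -> torch.Tensor:
    # hypothetical variant of inverse_so2 that maps the result back to [0, 2*pi)
    return (-H) % (2 * math.pi)

H = torch.tensor([[0.5], [3.0], [5.5]])
H_inv = inverse_so2_wrapped(H)

# the wrapped inverse describes the same rotation as -H ...
same_as_negation = torch.allclose(torch.cos(H_inv), torch.cos(-H), atol=1e-5) and \
    torch.allclose(torch.sin(H_inv), torch.sin(-H), atol=1e-5)

# ... and composing an element with its inverse gives the identity rotation
composed = H + H_inv
is_identity = torch.allclose(torch.cos(composed), torch.ones_like(H), atol=1e-5) and \
    torch.allclose(torch.sin(composed), torch.zeros_like(H), atol=1e-5)

print(same_as_negation, is_identity)  # True True
```

The comparison is done on sines and cosines rather than on the raw angles, since angles close to 0 and close to $2\pi$ represent nearly the same rotation.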


#### Implementing the left group action

If we rotate by $\theta_1$ and then left-apply $\theta_2$, we obtain the new rotation $(\theta_1 + \theta_2) \bmod 2\pi$. The modulo operation is required to keep the group elements in the defined interval of [0, $2\pi$).

In [17]:
def left_apply_to_so2(H1: Tensor, H2: Tensor) -> Tensor:
    # broadcast to apply every element in H1 to every element in H2
    return (H1[:, None] + H2) % (2 * math.pi)

print(left_apply_to_so2(grid, grid).shape)

torch.Size([8, 8, 1])
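
A quick way to gain confidence in a group action is to test the group properties numerically. The sketch below re-implements the pairwise composition under the name `compose` and checks the identity element and commutativity (SO(2) is abelian). Both `compose` and `same_rotation` are illustrative helpers and are not part of `gconv`.

```python
import math
import torch

TWO_PI = 2 * math.pi

def compose(H1: torch.Tensor, H2: torch.Tensor) -> torch.Tensor:
    # pairwise composition of (N, 1) and (M, 1) angles, giving (N, M, 1)
    return (H1[:, None] + H2) % TWO_PI

def same_rotation(a: torch.Tensor, b: torch.Tensor, atol: float = 1e-5) -> bool:
    # compare on the unit circle so that angles near 0 and 2*pi match
    return torch.allclose(torch.cos(a), torch.cos(b), atol=atol) and \
        torch.allclose(torch.sin(a), torch.sin(b), atol=atol)

torch.manual_seed(0)
a = TWO_PI * torch.rand(4, 1)
b = TWO_PI * torch.rand(5, 1)

ab = compose(a, b)
identity = torch.zeros(1, 1)

print(ab.shape)                                          # torch.Size([4, 5, 1])
print(same_rotation(compose(identity, a)[0], a))         # True
print(same_rotation(ab, compose(b, a).transpose(0, 1)))  # True
```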


#### Implementing the left group action on R2

The spatial grid on which the weights are defined has shape `(W, H, 2)` in the 2D case, where `W` and `H` are the width and height of the kernel, respectively. Each element of the spatial grid denotes an x and y position. 2D vectors can be rotated by multiplying them with a rotation matrix.

In [18]:
def left_apply_to_R2(H: Tensor, grid: Tensor) -> Tensor:
    matrices = H.new_empty(H.shape[0], 2, 2)

    cos_H = torch.cos(H).flatten(-2)
    sin_H = torch.sin(H).flatten(-2)

    matrices[..., 0, 0] = cos_H
    matrices[..., 0, 1] = -sin_H
    matrices[..., 1, 0] = sin_H
    matrices[..., 1, 1] = cos_H

    # broadcast to apply every matrix in matrices to every 2d vector in grid
    return (matrices[:, None, None] @ grid[..., None]).flatten(-2)

print(left_apply_to_R2(grid, grid_R2).shape)

torch.Size([8, 5, 5, 2])
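
Shapes alone do not show that the vectors are rotated correctly. The following sketch applies rotation matrices to two known vectors and checks the result: a rotation by $\pi/2$ should map (1, 0) to (0, 1) and (0, 1) to (-1, 0), and vector lengths should be preserved. `rotate_points` is a hypothetical stand-in written for this check; it uses a batched matrix product instead of the in-place assignment used in the tutorial function.

```python
import math
import torch

def rotate_points(angles: torch.Tensor, points: torch.Tensor) -> torch.Tensor:
    """Rotate points of shape (..., 2) by every angle in angles of shape (N, 1)."""
    cos, sin = torch.cos(angles[:, 0]), torch.sin(angles[:, 0])
    matrices = torch.stack(
        (torch.stack((cos, -sin), dim=-1), torch.stack((sin, cos), dim=-1)),
        dim=-2,
    )  # (N, 2, 2)
    flat = points.reshape(-1, 2)               # (P, 2)
    out = flat @ matrices.transpose(-1, -2)    # (N, P, 2), row p equals R @ flat[p]
    return out.reshape(angles.shape[0], *points.shape)

angles = torch.tensor([[0.0], [math.pi / 2]])
points = torch.tensor([[[1.0, 0.0], [0.0, 1.0]]])  # shape (1, 2, 2)

rotated = rotate_points(angles, points)
print(rotated.shape)  # torch.Size([2, 1, 2, 2])
print(torch.allclose(rotated[1, 0], torch.tensor([[0.0, 1.0], [-1.0, 0.0]]), atol=1e-6))  # True
```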


#### Implementing grid sampling for SO2

The group kernels are represented by a discretization of the group. In the case of SO2, this could be some number of rotations with corresponding weights.

To approximate the full continuous SO2 group, we are required to sample the weights of any SO2 element. Therefore, we use the discretization as a basis of a continuous signal defined over SO2 through interpolation.

Given a uniform discrete SO2 grid, we can view this grid as points on the unit circle. Hence, we can find the signal of any SO2 element by finding the two neighbouring points and linearly interpolating between them. This can be done by transforming our group elements to Cartesian coordinates (x, y vectors) that represent points on the unit circle. The distance between two points is then the angle between the corresponding vectors, which follows from their dot product. This representation has the benefit of dealing with the periodicity of the SO2 manifold.

In [19]:
def grid_sample_so2(grid: Tensor, signal: Tensor, signal_grid: Tensor) -> Tensor:
    # transform grid to Cartesian coordinates.
    grid_cart = grid.new_empty(grid.shape[0], 2)
    grid_cart[..., 0] = torch.cos(grid).flatten(-2)
    grid_cart[..., 1] = torch.sin(grid).flatten(-2)

    # transform signal_grid to Cartesian coordinates.
    signal_grid_cart = signal_grid.new_empty(signal_grid.shape[0], 2)
    signal_grid_cart[..., 0] = torch.cos(signal_grid).flatten(-2)
    signal_grid_cart[..., 1] = torch.sin(signal_grid).flatten(-2)

    # we calculate the distance of all points in grid to
    # all points in the signal grid to find the two neighbours,
    # i.e., the two closest points. For vectors on the unit
    # circle, the dot product is the cosine of the angle between
    # them, so the angular distance is its arccosine
    cos_angles = (grid_cart[:, None] * signal_grid_cart).sum(-1)
    distances = torch.acos(cos_angles.clamp(-1, 1))

    # obtain the two neighbours and the distances to them
    dists, neighbours = torch.topk(distances, 2, largest=False)

    # to obtain the interpolation coefficients, we normalize the
    # distances such that the distance between two neighbours equals
    # 1. Given a uniform grid of N rotations, we divide by (2pi / N).
    # The weight of a neighbour is one minus its normalized distance.
    coeffs = 1 - dists / ((2 * math.pi) / signal_grid.shape[0])

    # now we can perform the linear interpolation and return the new signal
    return (coeffs[..., None] * signal[neighbours]).sum(-2)

# we define a signal on uniform 4 element grid
signal = torch.randn(4, 64)
signal_grid = uniform_grid_so2(4)

print(grid_sample_so2(grid, signal, signal_grid).shape)


torch.Size([8, 64])
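
Periodic linear interpolation is easy to get subtly wrong, so a reference implementation with known answers is useful for testing. The sketch below uses a different technique from the neighbour search described above: it computes the two neighbours directly by index arithmetic. It assumes the signal lives on a uniform grid that starts at angle 0, and `lerp_periodic` is a hypothetical helper, not a `gconv` function.

```python
import math
import torch

def lerp_periodic(theta: torch.Tensor, signal: torch.Tensor) -> torch.Tensor:
    """Interpolate signal (N, S), defined on a uniform SO(2) grid, at angles theta (M,)."""
    n = signal.shape[0]
    # fractional grid index of every angle, in [0, n)
    pos = (theta % (2 * math.pi)) / (2 * math.pi / n)
    lower = torch.floor(pos)
    frac = (pos - lower)[:, None]
    lo = lower.long() % n
    hi = (lo + 1) % n  # wraps around from the last grid point to the first
    return (1 - frac) * signal[lo] + frac * signal[hi]

signal = torch.tensor([[0.0], [1.0], [2.0], [3.0]])
theta = torch.tensor([0.0, math.pi / 4, 7 * math.pi / 4])

# on a grid point, halfway between points 0 and 1, halfway between points 3 and 0
print(lerp_periodic(theta, signal).flatten())  # approximately [0.0, 0.5, 1.5]
```

The last query shows the wrap-around: the angle $7\pi/4$ lies between the last grid point and the first one, so the result is the mean of 3 and 0.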


#### Implementing the lifting kernel

Now we have all the necessary components to start building the kernels. We first create an SE2 lifting kernel module. For this, we create a new class that inherits from `gconv.nn.kernels.GLiftingKernel`. For interpolating spatial kernels, we can simply use the `grid_sample` function from `torch.nn.functional`.

In [20]:
from gconv.nn.kernels import GLiftingKernel
from torch.nn import functional as F

class GLiftingKernelSE2(GLiftingKernel):
    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            kernel_size: int,
            group_kernel_size: int,
            groups: int = 1,
        ) -> None:

        # we initialize the kernel grids used for sampling here.
        grid_H = uniform_grid_so2(group_kernel_size)
        grid_Rn = create_grid_R2(kernel_size)

        # We can also pass any kwargs to the sample function.
        # In our case, we use F.grid_sample, which we can give an
        # interpolation mode, for which we use "bilinear", and a
        # padding mode, for which we use "border".
        sample_Rn_kwargs = {"mode": "bilinear", "padding_mode": "border"}

        super().__init__(
            in_channels,
            out_channels,
            (kernel_size, kernel_size), # kernel sizes are tuples
            (group_kernel_size,), # same for group kernel size
            grid_H,
            grid_Rn,
            groups,
            det_H=det_so2,
            inverse_H=inverse_so2,
            left_apply_to_Rn=left_apply_to_R2,
            sample_Rn=F.grid_sample,
            sample_Rn_kwargs=sample_Rn_kwargs,
        )

in_channels = 3
out_channels = 16
kernel_size = 5
group_kernel_size = 8

lifting_kernel = GLiftingKernelSE2(in_channels, out_channels, kernel_size, group_kernel_size)
weight = lifting_kernel(grid)

# should be (out_channels, group_kernel_size, in_channels, kernel_size, kernel_size)
print(weight.shape)

torch.Size([16, 8, 3, 5, 5])
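
To illustrate how a spatial sampling function such as `F.grid_sample` can produce one rotated copy of a kernel per group element, the sketch below rotates the sampling positions and resamples a base weight. This is only a conceptual illustration under stated assumptions: it is not the `gconv` implementation, `rotate_kernels` is a hypothetical helper, and the use of `align_corners=True` and the (x, y) coordinate order are choices made for this sketch.

```python
import math
import torch
import torch.nn.functional as F

def rotate_kernels(weight: torch.Tensor, angles: torch.Tensor) -> torch.Tensor:
    """Resample weight (O, I, k, k) on rotated grids; returns (O, A, I, k, k)."""
    O, I, k, _ = weight.shape
    A = angles.shape[0]

    lin = torch.linspace(-1, 1, k)
    Y, X = torch.meshgrid(lin, lin, indexing="ij")
    base = torch.stack((X, Y), dim=-1)  # (k, k, 2); F.grid_sample expects (x, y) order

    # rotate the sampling positions by the inverse of every group element
    cos, sin = torch.cos(-angles), torch.sin(-angles)
    R = torch.stack(
        (torch.stack((cos, -sin), dim=-1), torch.stack((sin, cos), dim=-1)), dim=-2
    )  # (A, 2, 2)
    grids = torch.einsum("aij,hwj->ahwi", R, base)  # (A, k, k, 2)

    # treat all (out, in) channel pairs as channels of one image per angle
    inp = weight.reshape(1, O * I, k, k).repeat(A, 1, 1, 1)
    out = F.grid_sample(
        inp, grids, mode="bilinear", padding_mode="border", align_corners=True
    )
    return out.reshape(A, O, I, k, k).transpose(0, 1)

torch.manual_seed(0)
weight = torch.randn(16, 3, 5, 5)
rotated = rotate_kernels(weight, torch.tensor([0.0, math.pi]))

print(rotated.shape)  # torch.Size([16, 2, 3, 5, 5])
# a rotation by 0 returns the kernel itself, a rotation by pi flips both axes
print(torch.allclose(rotated[:, 0], weight, atol=1e-4))                # True
print(torch.allclose(rotated[:, 1], weight.flip(-1, -2), atol=1e-4))   # True
```

With eight angles instead of two, the result has the same layout as the weight shape printed above.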


#### Implementing the Separable Kernel

Next we implement the separable SE2 kernel, which inherits from `gconv.nn.kernels.GSeparableKernel`. We will see that the kernel initialization is the same as for the lifting kernel, except that we also pass the group-on-group action and the group sampling function.

In [26]:
from gconv.nn.kernels import GSeparableKernel

class GSeparableKernelSE2(GSeparableKernel):
    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            kernel_size: int,
            group_kernel_size: int,
            groups: int = 1,
        ) -> None:

        # we initialize the kernel grids used for sampling here.
        grid_H = uniform_grid_so2(group_kernel_size)
        grid_Rn = create_grid_R2(kernel_size)

        # we again use F.grid_sample. Our sampling of SO2 does not require
        # any extra kwargs
        sample_Rn_kwargs = {"mode": "bilinear", "padding_mode": "border"}

        super().__init__(
            in_channels,
            out_channels,
            (kernel_size, kernel_size), # kernel sizes are tuples
            (group_kernel_size,), # same for group kernel size
            grid_H,
            grid_Rn,
            groups,
            det_H=det_so2,
            inverse_H=inverse_so2,
            left_apply_to_H=left_apply_to_so2,
            left_apply_to_Rn=left_apply_to_R2,
            sample_H=grid_sample_so2,
            sample_Rn=F.grid_sample,
            sample_Rn_kwargs=sample_Rn_kwargs,
        )

in_channels = 3
out_channels = 16
kernel_size = 5
group_kernel_size = 8

separable_kernel = GSeparableKernelSE2(in_channels, out_channels, kernel_size, group_kernel_size)

grid_in = grid
grid_out = grid

# separate weights for the subgroup (H) and spatial (Rn) parts
weight_H, weight_Rn = separable_kernel(grid_in, grid_out)

# should be (out_channels, len(grid_in), in_channels, len(grid_out), 1, 1)
print(weight_H.shape)

# should be (out_channels, len(grid_out), 1, kernel_size, kernel_size)
print(weight_Rn.shape)

torch.Size([16, 8, 3, 8, 1, 1])
torch.Size([16, 8, 1, 5, 5])
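
The benefit of the separable factorization shows up when counting sampled kernel entries. The arithmetic below takes the two shapes printed above and compares them with a single joint kernel over group and space. The joint shape `(out_channels, len(grid_out), in_channels, len(grid_in), kernel_size, kernel_size)` is an assumption made for this comparison, not a shape reported by `gconv` in this tutorial.

```python
import math

shape_H = (16, 8, 3, 8, 1, 1)   # weight_H, as printed above
shape_Rn = (16, 8, 1, 5, 5)     # weight_Rn, as printed above
shape_joint = (16, 8, 3, 8, 5, 5)  # assumed shape of a non-separable kernel

n_separable = math.prod(shape_H) + math.prod(shape_Rn)
n_joint = math.prod(shape_joint)

print(n_separable)                     # 6272
print(n_joint)                         # 76800
print(f"{n_joint / n_separable:.1f}")  # 12.2
```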


### Implementing the lifting and separable convolution modules

Now that we have implemented the lifting and separable group kernels, we can initialize the group convolution modules. We again start with the lifting convolution.

#### Implementing the lifting convolution module

Since we are working with SE(2), our SE(2) lifting convolution module inherits from `gconv.nn.GLiftingConv2d`.

In [27]:
from gconv.nn import GLiftingConv2d

class GLiftingConvSE2(GLiftingConv2d):
    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            kernel_size: int,
            group_kernel_size: int,
            groups: int = 1,
            stride: int = 1,
            padding: int | str = 0,
            dilation: int = 1,
            padding_mode: str = "zeros",
            bias: bool = False) -> None:
        
        # all we need to do is initialize the kernel and pass it
        # to the super call
        kernel = GLiftingKernelSE2(
            in_channels,
            out_channels,
            kernel_size,
            group_kernel_size,
            groups=groups
        )

        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            group_kernel_size,
            kernel,
            groups,
            stride,
            padding,
            dilation,
            padding_mode,
            bias
        )
