In [6]:
from escnn import gspaces, nn
import torch

class Equivariant3DCNN(torch.nn.Module):
    """Two-layer group-equivariant CNN classifier for 3D volumes, built on escnn.

    Pipeline: two equivariant 3x3x3 convolutions with ReLU on regular-representation
    feature fields, group pooling (one invariant channel per field), a global spatial
    mean, and a plain linear classifier.
    """

    def __init__(self,
                 in_channels: int,
                 num_classes: int,
                 group_type: str = "octa",
                 hidden_channels: int = 16,
                 rotation_order: int = 4):
        """
        2-layer equivariant CNN for 3D data using escnn.

        Args:
            in_channels (int): Number of input channels (scalar fields). Must be >= 1.
            num_classes (int): Number of output classes. Must be >= 1.
            group_type (str): Symmetry group type. One of:
                              ["octa", "full_octa", "ico", "full_ico", "dihedral", "cylindrical"]
            hidden_channels (int): Number of regular-representation fields produced by
                              each convolution layer. Must be >= 1.
            rotation_order (int): Number of discrete rotations passed to the
                              "dihedral" and "cylindrical" gspaces; ignored for the
                              other group types. Must be >= 1. Defaults to 4.

        Raises:
            ValueError: If ``group_type`` is not one of the supported names, if a
                channel/class count is not positive, or if ``rotation_order`` is not
                positive for a group type that uses it.
        """
        super().__init__()

        # Fail early with a readable message instead of an obscure error from
        # deep inside FieldType / Linear construction.
        for arg_name, value in (("in_channels", in_channels),
                                ("num_classes", num_classes),
                                ("hidden_channels", hidden_channels)):
            if value < 1:
                raise ValueError(f"{arg_name} must be a positive integer, got {value}")

        # Select symmetry group (gspace) for R^3.
        # Builders are lazy so only the requested gspace is ever constructed.
        gspace_builders = {
            "octa": lambda: gspaces.octaOnR3(),
            "full_octa": lambda: gspaces.fullOctaOnR3(),
            "ico": lambda: gspaces.icoOnR3(),
            "full_ico": lambda: gspaces.fullIcoOnR3(),
            "dihedral": lambda: gspaces.dihedralOnR3(rotation_order),
            "cylindrical": lambda: gspaces.cylindricalOnR3(rotation_order),
        }
        if group_type not in gspace_builders:
            raise ValueError(f"Unknown group_type: {group_type}")

        # The hidden features use the regular representation, which needs a finite
        # group, so a non-positive rotation order cannot be used here.
        if group_type in ("dihedral", "cylindrical") and rotation_order < 1:
            raise ValueError(
                f"rotation_order must be a positive integer for group_type "
                f"'{group_type}', got {rotation_order}"
            )

        self.r3_act = gspace_builders[group_type]()

        # Input type: trivial representation for scalar fields (like pixel/voxel intensities).
        # Kept as an attribute so forward() does not depend on the layout of block1.
        self.in_type = nn.FieldType(self.r3_act, in_channels * [self.r3_act.trivial_repr])

        # First equivariant convolution layer
        hidden_type = nn.FieldType(self.r3_act, hidden_channels * [self.r3_act.regular_repr])
        self.block1 = nn.SequentialModule(
            nn.R3Conv(self.in_type, hidden_type, kernel_size=3, padding=1, bias=False),
            nn.ReLU(hidden_type, inplace=True)
        )

        # Second equivariant convolution layer (same field layout as the hidden type)
        out_type = nn.FieldType(self.r3_act, hidden_channels * [self.r3_act.regular_repr])
        self.block2 = nn.SequentialModule(
            nn.R3Conv(hidden_type, out_type, kernel_size=3, padding=1, bias=False),
            nn.ReLU(out_type, inplace=True)
        )

        # Group pooling collapses each regular field to a single invariant channel
        self.pool = nn.GroupPooling(out_type)

        # Final classifier: standard torch.nn.Linear
        # After pooling there is one channel per field, i.e. hidden_channels features
        self.fc = torch.nn.Linear(hidden_channels, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute class logits for a batch of volumes.

        Args:
            x: Plain tensor of shape (B, in_channels, H, W, D).

        Returns:
            Tensor of shape (B, num_classes) with unnormalised logits.
        """
        # Wrap tensor into escnn GeometricTensor
        x = nn.GeometricTensor(x, self.in_type)

        x = self.block1(x)
        x = self.block2(x)

        # Group pooling to invariants
        x = self.pool(x)

        # Now x.tensor has shape (B, C, H, W, D)
        # Apply global spatial pooling
        x = torch.mean(x.tensor, dim=[2, 3, 4])  # (B, C)

        logits = self.fc(x)
        return logits

In [7]:
# Smoke test: push a random batch through the network and check the logits shape.
torch.manual_seed(0)  # seed torch's RNG so the random input is reproducible on re-run
model = Equivariant3DCNN(in_channels=1, num_classes=10, group_type="cylindrical", hidden_channels=8)
inp = torch.randn(2, 1, 16, 16, 16)  # batch=2, 1 channel, 16³ volume
out = model(inp)
# Turn the expected shape into a checked invariant rather than a comment.
assert out.shape == (2, 10), f"unexpected logits shape: {tuple(out.shape)}"
print(out.shape)  # (2, 10)

torch.Size([2, 10])
