# Group Equivariant Downsampling Example

This notebook demonstrates how to use our novel subgroup downsampling layer with equivariant anti-aliasing.

In [6]:
from gsampling.layers.downsampling import SubgroupDownsample
import torch

## 1. Layer Configuration
 
Configure the group structure and downsampling parameters:

In [None]:
# Fundamental group parameters
input_group = 'dihedral'          # Symmetry group type (dihedral/Cn/symmetric)
sub_group = 'dihedral'            # Subgroup type to downsample to
input_group_order = 12            # Order of rotation elements for dihedral group
sub_sampling_factor = 2           # Factor to reduce group size by

# Feature configuration
number_of_features = 10           # Channel dimension (equivariant features)
generator = 'r-s'                 # Cayley graph generator (r=rotation, s=reflection)

# Hardware configuration
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'  # Use the GPU when available
dtype = torch.float32             # Data type precision

d_layer = SubgroupDownsample(
    group_type=input_group,                  # Parent group type
    order=input_group_order,                 # Rotation order N (escnn convention: D_N has 2N elements)
    sub_group_type=sub_group,                # Subgroup type to downsample to
    subsampling_factor=sub_sampling_factor,  # Factor to reduce group size by
    num_features=number_of_features,         # Number of input feature channels
    generator=generator,                     # Generators for Cayley graph construction
    dtype=dtype,                             # Data type
    apply_antialiasing=True,                 # Enable equivariant anti-aliasing
    anti_aliasing_kwargs={                   # Anti-aliasing optimization params
        "iterations": 100,                   # Number of optimization steps
        "smoothness_loss_weight": 0.1,       # Smoothness strength in optimization
    },
)
d_layer = d_layer.to(device=device)



Initializing anti-aliasing layer
Equi Constraint:  True Equi Correction:  True
===Using Linear Optimization====
Initial guess M: (288,)
Linear Constraint Matrix: (144, 288)
*** starting optimization ***
Optimization terminated successfully    (Exit mode 0)
            Current function value: -2210726690.67739
            Iterations: 14
            Function evaluations: 4047
            Gradient evaluations: 14
*** optimization done ***
Optimal objective value: -2210726690.67739
 Final Loss Reconstruction : 1.0926533015113843e-06
 Final Equivarinace loss : 2.44948974278317


In [15]:
# Create a random input tensor of shape (batch_size, number_of_features * group_size, height, width)
batch_size = 1
group_size = input_group_order * 2 # dihedral group has 2*order elements following convention of escnn
input_tensor = torch.randn(
    batch_size,
    number_of_features * group_size,
    32,
    32,
    device=device,
    dtype=dtype,
)

out, canonicalization_element = d_layer(input_tensor)
print("Input shape:", input_tensor.shape)
print("Output shape:", out.shape)

Input shape: torch.Size([1, 240, 32, 32])
Output shape: torch.Size([1, 120, 32, 32])


### Key Parameter Explanations:

- **group_type**: Type of symmetry group (dihedral = rotations + reflections)
- **order**: For dihedral groups, the number of rotations N; the full group has 2N elements (N rotations + N reflections)
- **subsampling_factor**: Integer factor by which the group size is reduced (must divide the group order)
- **generator**: Generators used to construct the Cayley graph:
  - 'r-s' = rotation + reflection generators
  - 'r' = rotation-only generation
  - Custom generator tuples supported
- **anti_aliasing_kwargs**: Controls the optimization of the equivariant low-pass filter
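
As a rough illustration of how `order` and `subsampling_factor` relate, the sketch below enumerates a dihedral group as (rotation, reflection) pairs and keeps every k-th rotation. The helper names and the pair encoding are hypothetical and are not taken from `gsampling` or escnn, and the divisibility check assumes the factor must divide the rotation order.

```python
def dihedral_elements(order):
    """List D_order as (rotation_index, reflection_flag) pairs: 2 * order elements."""
    return [(r, s) for s in (0, 1) for r in range(order)]


def dihedral_subgroup(order, subsampling_factor):
    """Keep every `subsampling_factor`-th rotation, with and without reflection."""
    # Assumption: the factor must divide the rotation order for a clean subgroup.
    if order % subsampling_factor != 0:
        raise ValueError("subsampling_factor must divide the rotation order")
    return [
        (r, s)
        for (r, s) in dihedral_elements(order)
        if r % subsampling_factor == 0
    ]


group = dihedral_elements(12)
subgroup = dihedral_subgroup(12, 2)
print(len(group), len(subgroup))  # 24 12
```

With `order=12` and a factor of 2 this gives 24 and 12 elements, which matches the D12 to D6 reduction used in this notebook.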

## 3. Input Tensor Construction

In [16]:
batch_size = 1
group_size = input_group_order * 2  # Dihedral group size = 2*order (rotations + reflections)
height = width = 32                # Spatial dimensions

# Create random input tensor
input_tensor = torch.randn(
    batch_size,
    number_of_features * group_size,  # Channel dimension flattens [features × group_size]
    height,
    width,
    device=device,
    dtype=dtype,
)

print("Input shape:", input_tensor.shape)
print("Interpretation:")
print(f"- {number_of_features} feature channels")
print(f"- {group_size} group elements (dihedral D_{input_group_order})")
print(f"- {height}x{width} spatial resolution")


Input shape: torch.Size([1, 240, 32, 32])
Interpretation:
- 10 feature channels
- 24 group elements (dihedral D_12)
- 32x32 spatial resolution


## 4. Forward Pass


In [17]:
out, canonicalization_element = d_layer(input_tensor)

## 5. Output Analysis

In [18]:
print("\nOutput shape:", out.shape)
print("Interpretation:")
print(f"- Same {number_of_features} feature channels")
print(f"- Reduced group size: {out.shape[1]//number_of_features} elements")
print(f"- Maintained spatial resolution: {out.shape[-2]}x{out.shape[-1]}")

print("\nCanonicalization Element:", canonicalization_element)
print("Purpose: Contains subgroup coset information for feature alignment. By default, canonicalization is set to false, so this output is not used.")



Output shape: torch.Size([1, 120, 32, 32])
Interpretation:
- Same 10 feature channels
- Reduced group size: 12 elements
- Maintained spatial resolution: 32x32

Canonicalization Element: [(-1, -1)]
Purpose: Contains subgroup coset information for feature alignment. By default, canonicalization is set to false, so this output is not used.


## Critical Shape Transformation

The layer performs **group dimension reduction** while preserving features:

| Dimension                   | Input Size  | Output Size | Notes                         |
|-----------------------------|-------------|-------------|-------------------------------|
| Batch                       | 1           | 1           | Unchanged                     |
| Channels (features × group) | 10×24 = 240 | 10×12 = 120 | 2× group reduction (D12→D6)   |
| Spatial                     | 32×32       | 32×32       | Spatial resolution maintained |
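
The channel arithmetic in the table can be checked with a small helper. This is only a sketch of the bookkeeping: `downsampled_shape` is a hypothetical function, not part of `gsampling`, and it assumes the channel dimension is exactly `num_features * group_size`.

```python
def downsampled_shape(input_shape, num_features, subsampling_factor):
    """Predict the output shape when only the group dimension is reduced."""
    batch, channels, height, width = input_shape
    if channels % num_features != 0:
        raise ValueError("channels must be a multiple of num_features")
    group_size = channels // num_features
    if group_size % subsampling_factor != 0:
        raise ValueError("subsampling_factor must divide the group size")
    new_group_size = group_size // subsampling_factor
    # Batch and spatial dimensions are left unchanged.
    return (batch, num_features * new_group_size, height, width)


out_shape = downsampled_shape((1, 240, 32, 32), num_features=10, subsampling_factor=2)
print(out_shape)  # (1, 120, 32, 32)
```

The result agrees with the output shape printed by the layer: 24 group elements per feature become 12, while the batch and spatial dimensions stay the same.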

## Model Architecture Configuration

In [1]:
from gsampling.models.model_handler import get_model

In [4]:
# Configure hierarchical group-equivariant architecture
model = get_model(
    input_channel=3,  # RGB input channels
    num_channels=[32, 64, 128],  # Feature channels per stage
    num_layers=3,     # Number of processing stages
    dwn_group_types=[
        ["dihedral", "dihedral"],  # [input_group, subgroup] for stage 1
        ["dihedral", "dihedral"],  # For stage 2
        ["dihedral", "dihedral"]   # For stage 3
    ],
    subsampling_factors=[2, 1, 1],  # Group reduction factors per stage
    spatial_subsampling_factors=[2, 1, 1],  # Spatial downsampling factors
    num_classes=10,    # STL-10 has 10 classes
    antialiasing_kwargs={
        "iterations": 100,  # Anti-aliasing optimization steps
        "smoothness_loss_weight": 0.5  # Trade-off between equivariance and smoothness
    }
)


Antialiasing True
Initializing anti-aliasing layer
Equi Constraint:  True Equi Correction:  True
===Using Linear Optimization====
Initial guess M: (288,)
Linear Constraint Matrix: (144, 288)
*** starting optimization ***
Optimization terminated successfully    (Exit mode 0)
            Current function value: -2766735471.4094844
            Iterations: 11
            Function evaluations: 3179
            Gradient evaluations: 11
*** optimization done ***
Optimal objective value: -2766735471.4094844
 Final Loss Reconstruction : 1.2018166076012732e-06
 Final Equivarinace loss : 2.4494897427831672


### Architecture Breakdown

| Stage | Channels | Group Subsampling | Spatial Subsampling | Feature Map Size* |
|-------|----------|-------------------|---------------------|-------------------|
| 1     | 32       | D12 → D6 (2x)     | 64x64 → 32x32 (2x)  | 32x32             |
| 2     | 64       | D6 → D6 (1x)      | 32x32 → 32x32 (1x)  | 32x32             |
| 3     | 128      | D6 → D6 (1x)      | 32x32 → 32x32 (1x)  | 32x32             |

\* Feature map sizes assume a 64x64 input, as shown in the Spatial Subsampling column.
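
The rows of this table follow from applying the per-stage factors in sequence. The sketch below reproduces that bookkeeping under stated assumptions: `stage_summary` is a hypothetical helper, the starting order 12 and the 64x64 resolution are taken from the table, and plain integer division ignores any padding or stride details inside the real model.

```python
def stage_summary(order, spatial_size, group_factors, spatial_factors):
    """Track dihedral order and spatial size through each stage."""
    rows = []
    for stage, (g, s) in enumerate(zip(group_factors, spatial_factors), start=1):
        new_order = order // g          # group subsampling
        new_size = spatial_size // s    # spatial subsampling
        rows.append(
            (stage, f"D{order} -> D{new_order}", f"{spatial_size} -> {new_size}")
        )
        order, spatial_size = new_order, new_size
    return rows


rows = stage_summary(12, 64, group_factors=[2, 1, 1], spatial_factors=[2, 1, 1])
for row in rows:
    print(row)
```

Only the first stage changes the group and the resolution; the factors of 1 in stages 2 and 3 leave both untouched.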

## Training with g-CNN

In [None]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Configure data augmentation
train_transform = transforms.Compose([
    transforms.Resize(96),  # STL-10 has 96x96 images
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

test_transform = transforms.Compose([
    transforms.Resize(96),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Load datasets
train_dataset = datasets.STL10(
    root='./data', 
    split='train',
    download=True, 
    transform=train_transform
)

test_dataset = datasets.STL10(
    root='./data',
    split='test',
    download=True,
    transform=test_transform
)

# Create dataloaders
batch_size = 64
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)


# Training loop with per-epoch evaluation on the test split

num_epochs = 100

for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    correct = 0
    total = 0
    
    for batch_idx, (data, targets) in enumerate(train_loader):
        data, targets = data.to(device), targets.to(device)
        
        # Forward pass
        outputs = model(data)
        loss = criterion(outputs, targets)
        
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # Metrics
        total_loss += loss.item()
        _, predicted = outputs.max(1)
        correct += predicted.eq(targets).sum().item()
        total += targets.size(0)
        
    # Validation
    model.eval()
    test_correct = 0
    test_total = 0
    with torch.no_grad():
        for data, targets in test_loader:
            data, targets = data.to(device), targets.to(device)
            outputs = model(data)
            _, predicted = outputs.max(1)
            test_correct += predicted.eq(targets).sum().item()
            test_total += targets.size(0)
    
    # Update learning rate
    scheduler.step()
    
    # Print metrics
    print(f"Epoch {epoch+1}/{num_epochs}")
    print(f"Train Loss: {total_loss/len(train_loader):.4f} | Acc: {100.*correct/total:.2f}%")
    print(f"Test Acc: {100.*test_correct/test_total:.2f}%")
    print(f"LR: {scheduler.get_last_lr()[0]:.2e}\n")


Epoch 1/100
Train Loss: 1.9300 | Acc: 28.80%
Test Acc: 37.74%
LR: 3.00e-04

Epoch 2/100
Train Loss: 1.5408 | Acc: 43.44%
Test Acc: 46.25%
LR: 3.00e-04

Epoch 3/100
Train Loss: 1.3102 | Acc: 53.22%
Test Acc: 47.30%
LR: 3.00e-04

Epoch 4/100
Train Loss: 1.0909 | Acc: 63.12%
Test Acc: 49.92%
LR: 3.00e-04

Epoch 5/100
Train Loss: 0.9104 | Acc: 69.92%
Test Acc: 48.92%
LR: 3.00e-04

Epoch 6/100
Train Loss: 0.6951 | Acc: 78.86%
Test Acc: 50.10%
LR: 2.99e-04

Epoch 7/100
Train Loss: 0.4910 | Acc: 86.58%
Test Acc: 49.80%
LR: 2.99e-04

Epoch 8/100
Train Loss: 0.3581 | Acc: 90.90%
Test Acc: 51.36%
LR: 2.99e-04

Epoch 9/100
Train Loss: 0.2138 | Acc: 95.88%
Test Acc: 48.74%
LR: 2.99e-04

Epoch 10/100
Train Loss: 0.1264 | Acc: 98.40%
Test Acc: 50.11%
LR: 2.98e-04

Epoch 11/100
Train Loss: 0.0631 | Acc: 99.76%
Test Acc: 50.50%
LR: 2.98e-04

Epoch 12/100
Train Loss: 0.0306 | Acc: 99.98%
Test Acc: 50.60%
LR: 2.97e-04

Epoch 13/100
Train Loss: 0.0198 | Acc: 99.98%
Test Acc: 50.30%
LR: 2.97e-04

Epoch 14