Md Ashiqur Rahman, Raymond A. Yeh
Department of Computer Science, Purdue University
The proposed layers maintain consistent feature representations across varying input resolutions by adapting to the spatial scale of the input image: the same content processed at different sizes yields correspondingly scaled features, making the network robust and stable across image sizes.
This is the official implementation of "Truly Scale-Equivariant Deep Nets with Fourier Layers", accepted at NeurIPS 2023.
We address the challenge of scale-equivariance in vision models by proposing a novel architecture based on Fourier layers that achieves zero equivariance error. Unlike prior approaches that assume a continuous domain and overlook anti-aliasing, our method formulates downscaling directly in the discrete domain with anti-aliasing built in. Evaluated on MNIST-Scale and STL-10, our model demonstrates competitive classification performance while ensuring exact scale invariance.
First, clone the repository and install the package in editable mode:
git clone https://github.com/ashiq24/scale_equivariant_fourier_layer.git
cd scale_equivariant_fourier_layer
pip install -e .
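To verify the installation, a quick sanity check (the package name scale_eq matches the imports used in the examples below):

python -c "import scale_eq"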
Make sure you have the following libraries installed:
- PyTorch (any version compatible with your hardware: CUDA or CPU)
- NumPy
- SciPy
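If any are missing, they can be installed with pip, for example:

pip install torch numpy scipy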
This example demonstrates how to use our localized spectral convolution and scale-equivariant non-linearity + pooling layer, tailored for image inputs at multiple resolutions. 📐📷
from scale_eq.layers.spectral_conv import SpectralConv2dLocalized
from scale_eq.layers.scalq_eq_nonlin import scaleEqNonlinMaxp
import torch
🌀 1. Localized Spectral Convolution Layer
seq_conv = SpectralConv2dLocalized(
    in_channel=3,     # Input image has 3 channels (e.g., RGB)
    out_channel=32,   # Output has 32 channels
    global_modes=28,  # Captures global patterns (kept well below the input resolution for speed)
    local_modes=7     # Captures fine-scale features
)
🔍 Local modes behave like a traditional filter (e.g., capturing edges or textures).
🌐 Global modes should be close to the input image resolution (128), but can be smaller to reduce computation. In this example, we use 28 to save on compute while still learning from global features.
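For intuition about what "modes" means, here is a minimal sketch of a plain (non-localized) spectral convolution. It is illustrative only, not the implementation behind SpectralConv2dLocalized (which combines a global and a local band of modes); the helper toy_spectral_conv and its weight layout are hypothetical:

import torch

def toy_spectral_conv(x_ft, weight, modes):
    # x_ft: (B, C_in, H, W) complex Fourier coefficients of the input
    # weight: (C_in, C_out, modes, modes) complex learned filter
    # Mix channels on the non-negative low-frequency quadrant only
    # (a simplification; real implementations also keep the mirrored
    # negative frequencies) and zero out everything else.
    out = torch.zeros(x_ft.shape[0], weight.shape[1], x_ft.shape[2], x_ft.shape[3],
                      dtype=x_ft.dtype)
    out[:, :, :modes, :modes] = torch.einsum(
        'bixy,ioxy->boxy', x_ft[:, :, :modes, :modes], weight)
    return out

x_ft = torch.fft.fft2(torch.randn(1, 3, 64, 64), norm='forward')
w = torch.randn(3, 32, 7, 7, dtype=torch.cfloat)
y_ft = toy_spectral_conv(x_ft, w, modes=7)  # -> (1, 32, 64, 64), complex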
🌊 2. Scale-Equivariant Non-linearity + Pooling Layer
seq_nonlin = scaleEqNonlinMaxp(
    torch.nn.functional.sigmoid,  # Non-linearity to apply (can be ReLU, GELU, etc.)
    base_res=32,         # Minimum resolution (lowest scale)
    normalization=None,  # Optional: apply normalization
    max_res=128,         # Maximum resolution (highest scale)
    increment=32,        # Controls scale skipping (1 = full equivariance, >1 = faster)
    channels=32,         # Number of feature channels
    pool_window=2        # Window size for max pooling
)
🔁 Scale Equivariance: The same pattern at different scales should yield similar outputs — this module enforces that.
⚖️ Trade-off: increment=1 means no scales skipped (full equivariance, but slow). Setting it to a higher value (like 32) means some scales are skipped — it's still effective, but faster.
🏊‍♂️ Pooling adds robustness and spatial invariance to features.
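Concretely, with the settings above, the non-linearity is evaluated over an evenly spaced set of resolutions (assuming an inclusive range from base_res to max_res in steps of increment):

scales = list(range(32, 128 + 1, 32))
print(scales)  # [32, 64, 96, 128]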
The scale-equivariant convolution and non-linearity operate in the complex Fourier domain, so both the input and output are Fourier coefficients. Let's see them in action:
random_input = torch.randn(1, 3, 128, 128).to(torch.float)
with torch.no_grad():
    # Step 1: Forward FFT for frequency-domain processing
    conv_out = seq_conv(torch.fft.fft2(random_input, norm='forward'))
    # Step 2: Apply scale-equivariant nonlinearity + pooling
    nonlin_out = seq_nonlin(conv_out)
    # Step 3: Inverse FFT to bring the result back to the spatial domain
    real_output = torch.fft.ifft2(nonlin_out, norm='forward').real

print("Input Shape: ", random_input.shape)
print("Output Shape: ", real_output.shape)
A notebook demonstrating the scale-equivariant layers and their use in deep neural networks is available in demo_and_quickstart.ipynb. The notebook can also be executed on Google Colab by following the link.
To regenerate the experiments of the paper, please install the dependencies in regen_results_req.txt.
Steps:
- Download the project
- Update the model flag to select the desired model
- Update the project_name flag to match the Neptune project
- Update the data_path to the dataset location (see the example below)
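For example (hypothetical values; the exact flag mechanism is defined in each train script):

model = 'fourier_net'                     # hypothetical model name
project_name = 'your-workspace/scale-eq'  # your Neptune project
data_path = '/path/to/datasets'           # dataset location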
Execute the following command:
python3 <train_script> <GPU_ID>
train_script: one of train_1d.py, train_mnist.py, train_stl.py
GPU_ID: int, device id of the GPU to train on
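For example, to train the MNIST model on GPU 0:

python3 train_mnist.py 0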