In [1]:
from escnn import nn, gspaces
import matplotlib.pyplot as plt
import numpy as np
import torch

In [2]:
import torch
import numpy as np
import escnn.nn as nn

@torch.inference_mode()
def rel_err(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    x = x.reshape(y.shape[0],x.shape[0]//y.shape[0], -1)
    
    diff  = torch.linalg.vector_norm(x - y, dim=1)
    denom = torch.maximum(torch.linalg.vector_norm(x, dim=1),
                          torch.linalg.vector_norm(y, dim=1))
    denom = torch.clamp(denom, min=1e-12)
    return diff / denom

@torch.inference_mode()
def check_equivariance_batch(x: torch.Tensor, model, num_samples: int = 16, chunk: int = 0):
    """
    Vectorized equivariance test on the *equivariant* feature map returned by model.forward_features.
    Returns: thetas (np.ndarray), errors_per_theta (np.ndarray)
    """
    device = next(model.parameters()).device
    x = x.to(device, non_blocking=True)

    r2_act = getattr(model, "r2_act")
    thetas = np.linspace(0.0, 2*np.pi, num_samples, endpoint=False)
    elems  = [r2_act.fibergroup.element(float(t)) for t in thetas]

    # Reference features
    y_ref = model.forward_features(x)  # GeometricTensor

    # Build transformed inputs (GeoTensor -> transform)
    x_geo = nn.GeometricTensor(x, model.input_type)
    x_list = [x_geo.transform(g).tensor for g in elems]
    xb = nn.GeometricTensor(torch.cat(x_list, dim=0), model.input_type)

    y_rot_tensor = model.forward_features(xb)



    B = x.shape[0]
    errs = rel_err(y_rot_tensor, y_ref).view(num_samples, B).mean(dim=1)
    return thetas, errs.detach().cpu().numpy()

@torch.inference_mode()
def logits_invariance_error(model, x, angles=(0, 90, 180, 270)):
    """
    Relative invariance error on logits after the invariant head.
    """
    from torchvision.transforms.functional import rotate, InterpolationMode
    model.eval()
    device = next(model.parameters()).device
    x = x.to(device, non_blocking=True)

    base = model(x)  # (B, C)
    errs = {}
    for a in angles:
        xr = rotate(x, a, interpolation=InterpolationMode.BILINEAR)
        yr = model(xr)
        errs[a] = rel_err(base, yr).mean().item()
    return errs


In [3]:
x = torch.randn(32,512)

In [4]:
import torch, numpy as np, matplotlib.pyplot as plt
from escnn import gspaces, nn

# --- group & test sampler ---
r2_act = gspaces.rot2dOnR2(maximum_frequency=1)
num_samples = 16
thetas = np.linspace(0, 2*np.pi, num_samples, endpoint=True)
elements = [r2_act.fibergroup.element(theta) for theta in thetas]

g = gspaces.rot2dOnR2(N=-1)                   # SO(2)
G = g.fibergroup
ft_in = nn.FieldType(g, [g.trivial_repr])     # scalar input
device = "cuda" if torch.cuda.is_available() else "cpu"

# --- model builders (same depth: conv -> act/norm -> bn -> conv) ---
def build_norm_normbn(C=8, irreps=2):
    irreps = [g.irrep(i) for i in range(irreps)]
    ft = nn.FieldType(g, irreps*C)      # pure nontrivial => NormBN valid
    return nn.SequentialModule(
        nn.R2Conv(ft_in, ft, 3, padding=1, bias=True),
        # nn.IIDBatchNorm2d(ft, affine=True),
        # nn.NormNonLinearity(ft),
        # nn.R2Conv(ft, ft, 3, padding=1, bias=False),
    )

def build_gated_gnormbn(C=8):
    feats = [g.irrep(0), g.irrep(1)]*C
    gates = [g.trivial_repr]*len(feats)
    ft_full = nn.FieldType(g, gates + feats)       # gates FIRST
    ft_feat = nn.FieldType(g, feats)  # for batchnorm
    len(ft_feat)
    return nn.SequentialModule(
        nn.R2Conv(ft_in, ft_full, 3, padding=1, bias=True),
        nn.GatedNonLinearity1(ft_full, drop_gates=True),
        # nn.FieldNorm(ft_feat, affine=True),
        # nn.R2Conv(ft_feat, ft_feat, 3, padding=1, bias=False),
    )

def build_tensorproduct_11_to_2(C=8):
    ft1 = nn.FieldType(g, [g.irrep(1)]*C)          # uniform in
    ft2 = nn.FieldType(g, [g.irrep(2)]*C)          # valid: 1⊗1 -> 2
    return nn.SequentialModule(
        nn.R2Conv(ft_in, ft1, 3, padding=1, bias=True),
        nn.TensorProductModule(ft1, ft2, initialize=True),
        # nn.FieldNorm(ft2, affine=True),
        # nn.R2Conv(ft2, ft1, 3, padding=1, bias=False),  # keep width comparable
    )

b

models = {
    "Conv 2 irreps":        build_norm_normbn(8),
    "Conv 3 irreps":        build_norm_normbn(8, irreps=3),
    "Conv 4 irreps":        build_norm_normbn(8, irreps=4),
    "Conv 2 irreps, N=32":        build_norm_normbn(32, irreps=2),
    "Conv 3 irreps, N=32":        build_norm_normbn(32, irreps=3),
    "Conv 4 irreps, N=32":        build_norm_normbn(32, irreps=4),

}

# --- run equivariance tests ---
x = torch.randn(1, 1, 256, 256)
x = ft_in(x).to(device)

plt.figure(figsize=(6,4))
for name, model in models.items():
    model = model.to(device)
    thetas_out, errors = check_equivariance_batch(x, model, group=r2_act, num_samples=num_samples)
    plt.hlines(np.mean(errors), xmin=0, xmax=2*np.pi, linestyles='dashed')
    plt.plot(thetas_out, errors, marker="o", ms=3, label=name)
plt.xlabel("rotation angle [rad]"); plt.ylabel("equivariance error")
plt.title("Equivariance error vs. rotation (SO(2))")
plt.legend(); plt.grid(True); plt.show()


NameError: name 'b' is not defined

In [None]:
x = torch.randn(8, 3, 128, 128)
r2_act = gspaces.rot2dOnR2()
ft_in = nn.FieldType(r2_act, [r2_act.irrep(0), r2_act.irrep(1)])
x = nn.GeometricTensor(x, type = ft_in)
print(x.type)
G = r2_act.fibergroup
act = nn.FourierPointwise(r2_act, channels=2, irreps=G.bl_irreps(2), N=16)
ft = act.out_type
feat_type_out = act.out_type

# 2) Convolution to the activation's expected input type
conv = nn.R2Conv(ft_in, feat_type_out, kernel_size=3, padding=3, bias=True)

# 3) Build scalar/vector subset FieldTypes from feat_type_out
reps = feat_type_out.representations
scalar_field_ids = [i for i, r in enumerate(reps) if r.size == 1]   # m = 0
vector_field_ids = [i for i, r in enumerate(reps) if r.size > 1]    # m >= 1

[SO(2)_on_R2[(None, -1)]: {irrep_0 (x1), irrep_1 (x1)}(3)]


In [None]:
import torch
from escnn import gspaces, nn

# Fake input: 3 channels total = 1 (m=0) + 2 (m=1) -> matches [irrep(0), irrep(1)]
x = torch.randn(8, 3, 128, 128)

r2_act = gspaces.rot2dOnR2()             # continuous SO(2) on R^2
ft_in  = nn.FieldType(r2_act, [r2_act.irrep(0), r2_act.irrep(1)])  # dims 1 + 2 = 3
x      = nn.GeometricTensor(x, ft_in)

print("x.in_type:", x.type)

G   = r2_act.fibergroup
act = nn.FourierPointwise(
    r2_act,
    channels=2,                      # multiplicity per irrep returned below
    irreps=G.bl_irreps(2),           # {m=0 (1D), m=1 (2D), m=2 (2D)} in real form
    N=16,                            # angular sampling for the Fourier op
    function='p_relu'                # be explicit; pointwise ReLU in Fourier coords
)

# IMPORTANT: convolve INTO the activation's expected INPUT type
feat_type_in  = act.in_type
feat_type_out = act.out_type

# Map from your ft_in (3ch) -> act.in_type
conv = nn.R2Conv(ft_in, feat_type_in, kernel_size=3, padding=1, bias=True)

# Forward
x = conv(x)
x = act(x)
print("after act type:", x.type)     # should match feat_type_out

# If you really want to "disentangle" into block-diagonal irrep basis,
# build the DisentangleModule with the CURRENT type of x, not ft_in.
dis = nn.DisentangleModule(x.type)
x_dis = dis(x)
print("after disentangle type:", x_dis.type)

x.in_type: [SO(2)_on_R2[(None, -1)]: {irrep_0 (x1), irrep_1 (x1)}(3)]
after act type: [SO(2)_on_R2[(None, -1)]: {regular_[(0,)|(1,)|(2,)] (x2)}(10)]
after disentangle type: [SO(2)_on_R2[(None, -1)]: {regular_[(0,)|(1,)|(2,)]_0 (x1), regular_[(0,)|(1,)|(2,)]_1 (x1), regular_[(0,)|(1,)|(2,)]_2 (x1), regular_[(0,)|(1,)|(2,)]_0 (x1), regular_[(0,)|(1,)|(2,)]_1 (x1), regular_[(0,)|(1,)|(2,)]_2 (x1)}(10)]


In [None]:
ft_in.representations[1].sum_of_squares_constituents

2

In [None]:
SO2 = gspaces.rot2dOnR2(maximum_frequency=8)
O2 = gspaces.flipRot2dOnR2(maximum_frequency=2)

G = SO2.fibergroup
irreps = []
for irr in G.irreps():
    print(irr)
    if irr.name == O2.trivial_repr.name:
        continue
    mult = int(irr.size // irr.sum_of_squares_constituents)  # 1 for 1D, 2 for 2D
    irreps.extend([irr] * mult)
irreps

SO(2)|[irrep_0]:1
SO(2)|[irrep_1]:2


[SO(2)|[irrep_0]:1, SO(2)|[irrep_1]:2]

In [None]:
from escnn import nn, gspaces


GASAA = gspaces.rot2dOnR2(maximum_frequency=2)
print(GASAA.irreps)
print(GASAA.fibergroup.irreps())

[SO(2)|[irrep_0]:1, SO(2)|[irrep_1]:2, SO(2)|[irrep_2]:2, SO(2)|[irrep_3]:2, SO(2)|[irrep_4]:2]
[SO(2)|[irrep_0]:1, SO(2)|[irrep_1]:2, SO(2)|[irrep_2]:2, SO(2)|[irrep_3]:2, SO(2)|[irrep_4]:2]


In [18]:
L = 4
SO2 = gspaces.rot2dOnR2(maximum_frequency=L)
O2 = gspaces.flipRot2dOnR2(maximum_frequency=2)
s = 3
G = SO2.fibergroup
irreps = []
for irr in SO2.irreps:
    print(irr)
    mult = int(irr.size // irr.sum_of_squares_constituents)  # 1 for 1D, 2 for 2D
    irreps.extend([irr] * mult)
r = nn.FieldType(SO2, irreps*L)
tmp_cl = nn.R2Conv(r, r, s,
                padding=1)
tmp_cl.basisexpansion.dimension()

SO(2)|[irrep_0]:1
SO(2)|[irrep_1]:2
SO(2)|[irrep_2]:2


384

In [19]:
r

[SO(2)_on_R2[(None, -1)]: {irrep_0 (x1), irrep_1 (x1), irrep_2 (x1), irrep_0 (x1), irrep_1 (x1), irrep_2 (x1), irrep_0 (x1), irrep_1 (x1), irrep_2 (x1), irrep_0 (x1), irrep_1 (x1), irrep_2 (x1)}(20)]

In [None]:
tmp_cl.basisexpansion.dimension()

14

[SO(2)_on_R2[(None, -1)]: {regular_[(0,)|(1,)|(2,)] (x1)}(5)]

In [None]:
# new file: SO2_Nets/adaptive_fourier.py
import torch
from escnn import nn

class SamplingBranch(torch.nn.Module):
    def __init__(self, r2_act, in_type: nn.FieldType, N: int, hidden_ch: int = 16):
        super().__init__()
        self.N = N
        # small equivariant conv stack -> outputs N scalar trivial fields that we interpret as angles on S1
        self.net = nn.SequentialModule(
            nn.R2Conv(in_type, nn.FieldType(r2_act, hidden_ch * [in_type.fibergroup.trivial_repr]), kernel_size=3, padding=1, bias=False),
            nn.IIDBatchNorm2d(nn.FieldType(r2_act, hidden_ch * [in_type.fibergroup.trivial_repr])),
            nn.ReLU(inplace=True),
            nn.R2Conv(nn.FieldType(r2_act, hidden_ch * [in_type.fibergroup.trivial_repr]),
                      nn.FieldType(r2_act, self.N * [in_type.fibergroup.trivial_repr]), kernel_size=1, padding=0, bias=True),
        )
        self.r2_act = r2_act
        self.G = r2_act.fibergroup  # SO(2)

    def forward(self, feat: nn.GeometricTensor, rep_rho):
        # angles in [-pi, pi]
        angles = torch.pi * torch.tanh(self.net(feat).tensor)  # [B, N, H, W]
        # build A rows from angles and representation columns (quotient or regular)
        # rep_rho expects a callable: g -> matrix R^{F}
        A_rows = []
        for k in range(self.N):
            theta_k = angles[:, k:k+1, ...]  # [B,1,H,W]
            gk = self.G.element(theta_k)     # broadcast element
            # evaluate ρ(gk) δ̂  -> shape [B, F, H, W]
            Ak = rep_rho(gk)                 # your helper that returns vectorized ρ(g)δ̂
            A_rows.append(Ak)
        A = torch.stack(A_rows, dim=1)  # [B, N, F, H, W]
        return A

class AdaptiveFourierPointwise(torch.nn.Module):
    def __init__(self, r2_act, in_type: nn.FieldType, channels: int, irreps, function: str, N: int):
        super().__init__()
        self.N = N
        self.function = function
        self.r2_act = r2_act
        # create a helper to map features to Fourier coeffs and back per ESCNN conventions
        self.ft = nn.FourierTransform(in_type.gspace, irreps)
        self.channels = channels

    def forward(self, x: nn.GeometricTensor, A: torch.Tensor):
        # x.tensor shape [B, Cin, H, W]; interpret as stacked bandlimited coeffs f̂ over channels/spatial
        fhat = self.ft.forward(x)            # [B, C, F, H, W]
        # Af̂: [B,N,F,H,W] x [B,C,F,H,W] -> [B,C,N,H,W]
        y = torch.einsum('bnfhw,bcfhw->bcnhw', A, fhat)
        # pointwise nonlinearity (ReLU/ELU) along N
        if self.function.endswith('relu'):
            y = torch.relu(y)
        elif self.function.endswith('elu'):
            y = torch.nn.functional.elu(y)
        # (1/N)Aᵀ y: [B,C,F,H,W]
        fhat_new = (1.0 / self.N) * torch.einsum('bnfhw,bcnhw->bcfhw', A, y)
        # back to spatial field type
        x_new = self.ft.inverse(fhat_new)
        return x_new


In [None]:
layers = []
irreps = [r2_act.irrep(0), r2_act.irrep(1), r2_act.irrep(2)]  # exclude trivial
channels = 8
N = 16
kernel_size = 3
cur_type = nn.FieldType(r2_act, [r2_act.trivial_repr] * channels)  # start from scalars
pad = kernel_size // 2
non_linearity = 'p_relu'
sampler = SamplingBranch(r2_act, cur_type, N=N)
for _ in range(2):
    # conv to feature_type before activation (as in your fixed variant)
    feature_repr = irreps * channels
    feature_type = nn.FieldType(r2_act, feature_repr)
    layers.append(nn.R2Conv(cur_type, feature_type, kernel_size=kernel_size, padding=pad, bias=False))

    # build A and apply adaptive Fourier pointwise
    act = nn.(r2_act, feature_type, channels=channels, irreps=irreps, function=non_linearity, N=N)
    layers.append(nn.EquivariantModuleWrapper(feature_type, act, sampler))  # small wrapper that calls sampler then act

    layers.append(nn.IIDBatchNorm2d(act.out_type))
    cur_type = act.out_type

nn.SequentialModule(*layers), cur_type

AttributeError: 'SO2' object has no attribute 'trivial_repr'