In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


In [97]:
class HarmonicConvLayer(nn.Module):
    """
    A custom layer for a single Harmonic Network convolution.
    It learns filters as a basis of circular harmonics.
    """

    def __init__(self, in_channels, out_channels, k_size, max_order):
        super(HarmonicConvLayer, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.k_size = k_size
        self.rotation_orders = torch.arange(0, max_order+1, dtype=torch.int64)

        # --- TODO 1: Define Learnable Parameters ---
        # The learnable parameters are the radial profiles (R(r)) and phase offsets (beta).
        # We need one set for each filter connecting an input channel to an output channel
        # for each rotation order.

        num_radii = self._get_unique_radii_count(k_size)
        num_orders = len(self.rotation_orders)

        # Shape: (out_channels, in_channels, num_orders, num_radii)
        self.radial_profiles = nn.Parameter(torch.randn(out_channels, in_channels, num_orders, num_radii))

        # Shape: (out_channels, in_channels, num_orders)
        self.phase_offsets = nn.Parameter(torch.zeros(out_channels, in_channels, num_orders))
        # --- TODO 2: Pre-calculate non-learnable filter components ---
        # These are fixed grids that don't change during training.
        # Using register_buffer makes them part of the module's state without being parameters.
        self.register_buffer('r_grid', self._create_r_grid(k_size))
        self.register_buffer('phi_grid', self._create_phi_grid(k_size))
        self.register_buffer('radius_to_index', self._create_radius_to_index_map(k_size))

    def _get_unique_radii_count(self, k_size):
        if k_size % 2 == 0:
            raise ValueError("Kernel size must be odd.")
        center = (k_size - 1) / 2
        x_grid = torch.linspace(-center, center, k_size)
        y_gird = x_grid
        x, y = torch.meshgrid(x_grid, y_gird)
        r = torch.sqrt(x ** 2 + y ** 2)
        return len(np.unique(r))

    def _create_r_grid(self, k_size):
        if k_size % 2 == 0:
            raise ValueError("Kernel size must be odd.")
        center = (k_size - 1) / 2
        x_grid = torch.linspace(-center, center, k_size)
        y_gird = x_grid
        x, y = torch.meshgrid(x_grid, y_gird)
        r = torch.sqrt(x ** 2 + y ** 2)
        return r

    def _create_phi_grid(self, k_size):
        if k_size % 2 == 0:
            raise ValueError("Kernel size must be odd.")
        center = (k_size - 1) / 2
        x_grid = torch.linspace(-center, center, k_size)
        y_gird = x_grid
        x, y = torch.meshgrid(x_grid, y_gird)
        theta = torch.arctan2(y,x)
        return theta

    def _create_radius_to_index_map(self, k_size):
        if k_size % 2 == 0:
            raise ValueError("Kernel size must be odd.")

        r_grid = self._create_r_grid(k_size)
        unique_radii = torch.unique(r_grid)

        radius_to_index = {radius.item(): i for i, radius in enumerate(unique_radii)}
        index_grid = torch.zeros_like(r_grid, dtype=torch.long)
        for radius, index in radius_to_index.items():
            index_grid[r_grid == radius] = index

        return index_grid
    def _assemble_filters(self):
        rotation_orders = self.rotation_orders.reshape(1, 1, -1, 1, 1)

        # Reshape phi_grid to enable proper broadcasting
        phi_grid = self.phi_grid.reshape(1, 1, 1, *self.phi_grid.shape)

        # Reshape phase_offsets for broadcasting
        phase_offsets = self.phase_offsets.reshape(*self.phase_offsets.shape, 1, 1)

        # Get radial profiles with proper indexing
        R = self.radial_profiles[:, :, :, self.radius_to_index]
        # Calculate phase with properly shaped tensors
        phase = rotation_orders * phi_grid + phase_offsets
        print(phase.shape, rotation_orders.shape, phi_grid.shape, phase_offsets.shape)

        # Calculate final filters
        filters = R * torch.exp(1j * phase)
        return filters


In [98]:
harm = HarmonicConvLayer(10, 10, 11, 2)


In [99]:
print([*harm.radial_profiles.shape])
print(harm.rotation_orders)
harm._assemble_filters().shape

[10, 10, 3, 20]
tensor([0, 1, 2])
torch.Size([10, 10, 3, 11, 11]) torch.Size([1, 1, 3, 1, 1]) torch.Size([1, 1, 1, 11, 11]) torch.Size([10, 10, 3, 1, 1])


torch.Size([10, 10, 3, 11, 11])

In [87]:
num_radii = harm._get_unique_radii_count(11)
num_orders = len(harm.rotation_orders)
radial_profiles = nn.Parameter(torch.randn(10, 10, 3, num_radii))
R = radial_profiles[:, :, :, harm.radius_to_index]
print(R.shape)
print(radial_profiles[0,0,0,:])
print(R[0,0,0,:,:][0])

torch.Size([10, 10, 3, 11, 11])
tensor([-1.1947,  1.1851,  2.4057, -0.9259, -0.8873, -0.4223, -0.5089,  1.1204,
        -1.4454,  0.1895,  1.4023,  0.8980, -0.3540,  0.5955, -0.9290,  1.8226,
        -0.3550,  2.1132, -1.0072,  0.4313], grad_fn=<SliceBackward0>)
tensor([ 0.4313, -1.0072,  2.1132,  1.8226, -0.9290,  0.5955, -0.9290,  1.8226,
         2.1132, -1.0072,  0.4313], grad_fn=<SelectBackward0>)


In [102]:
from escnn import gspaces
r2_act = gspaces.rot2dOnR2(-1, maximum_frequency=2)

In [167]:
g = r2_act.fibergroup.sample()
rho = r2_act.irrep(1)

86.07356773460438