In [13]:
import torch
import pyshtools
from scipy.special import sph_harm
from math import sqrt

In [14]:
n_angle = 80
freq_cutoff = 10
f = torch.zeros(n_angle, n_angle)
FT = [torch.zeros(2*l + 1, n_angle * n_angle, dtype=torch.cfloat) for l in range(freq_cutoff + 1)]
index = 0

for theta in range(n_angle):
    for phi in range(n_angle):
        torch.zero_(f)
        f[theta, phi] = 1

        sh_transform = torch.from_numpy(pyshtools.expand.SHExpandDHC(f))
        sh_transform = torch.hstack((torch.fliplr(sh_transform[1])[:, :-1], sh_transform[0]))

        for l in range(freq_cutoff+1):
            FT[l][:, index] = sh_transform[l,((n_angle//2 -1) - l):((n_angle//2 -1) + l + 1)]
        index += 1

\begin{align*}
   \sum_{\theta, \phi} f(\theta, \phi)\text{FT}[\ell](\theta, \phi)
 =& \frac{1}{4\pi}\sum_{\theta, \phi} f(\theta, \phi) Y^{(\ell)}(\theta, \phi)\left(\frac{8\pi}{n_a^2}\sin \theta\sum_{k=0}^{\frac{n_a}{2}-1}\frac{\sin((2k+1)\theta)}{2k+1}\right)\\
 =& \frac{1}{4\pi}\frac{2\pi^2}{n_a^2} \sum_{\theta, \phi} f(\theta, \phi) Y^{(\ell)}(\theta, \phi) \sin \theta\left(\frac{4}{\pi}\sum_{k=0}^{\frac{n_a}{2}-1}\frac{\sin((2k+1)\theta)}{2k+1}\right)\\
 =& \frac{\pi}{2 n_a^2} \sum_{\theta, \phi} f(\theta, \phi) Y^{(\ell)}(\theta, \phi) \sin \theta\left(\frac{4}{\pi}\sum_{k=0}^{\frac{n_a}{2}-1}\frac{\sin((2k+1)\theta)}{2k+1}\right)\\
 =& \frac{\pi}{2n_a^2} \sum_{\theta, \phi} f(\theta, \phi)\text{SHT}[\ell](\theta, \phi)
\end{align*}

In [16]:
theta, phi = torch.meshgrid(torch.pi * (torch.arange(n_angle)) / n_angle, (2 * torch.pi * (torch.arange(n_angle)) / n_angle) , indexing='ij')
factor = (2*torch.arange(n_angle//2) + 1)
quadrature = torch.sin(theta) * (4*torch.sin(factor*theta.unsqueeze(-1)) / factor).sum(dim=-1)/torch.pi
SHT = [torch.stack([torch.from_numpy(sph_harm(m, l, phi.numpy(), theta.numpy())).type(torch.cfloat) * quadrature * sqrt(4*torch.pi) * (-1) ** m # Phase
                    for m in range(-l, l+1)], dim=0).flatten(1)
        for l in range(freq_cutoff + 1)]

for l in range(freq_cutoff):
    print((torch.conj(FT[l])  - SHT[l] * torch.pi/(2*n_angle**2)).abs().max())


tensor(1.6007e-10)
tensor(2.2165e-10)
tensor(3.6814e-10)
tensor(5.7050e-10)
tensor(7.8725e-10)
tensor(1.0090e-09)
tensor(1.2996e-09)
tensor(1.5390e-09)
tensor(1.7812e-09)
tensor(2.0580e-09)


In [4]:
kernel_size = [5,5]
n_radius = 3
n_angle = 100
interpolation_type = 1
freq_cutoff = 4
import torch
from Steerable.nn import get_interpolation_matrix, get_SHT_matrix

R = [(kernel_size[d] - 1) / 2 for d in range(len(kernel_size))]
r = torch.vstack([torch.arange(R[i] / (n_radius+1), R[i], R[i] / (n_radius+1))[:n_radius] for i in range(len(kernel_size))])
tau_r = torch.prod(r, dim=0)**((len(kernel_size)-1)/len(kernel_size))
SHT = get_SHT_matrix(n_angle, freq_cutoff, len(kernel_size)) # Spherical Harmonic Transform Matrix
if len(kernel_size) == 2:
    I = get_interpolation_matrix(kernel_size, n_radius, n_angle, interpolation_type).type(torch.cfloat) # Interpolation Matrix
    Fint = torch.einsum('r, mt, rtxy -> mrxy', tau_r, SHT, I)
elif len(kernel_size) == 3:
    I = get_interpolation_matrix((kernel_size[2], kernel_size[0], kernel_size[1]), n_radius, n_angle, interpolation_type).type(torch.cfloat) # Interpolation Matrix
    I = torch.permute(I, (0,1,3,4,2))
    Fint = [torch.einsum('r, lt, rtxyz -> lrxyz', tau_r, SHT[l], I) for l in range(freq_cutoff+1)] # Fint Matrix

In [9]:
tau_r

tensor([0.0000, 0.5000, 1.0000, 1.5000])

In [10]:
R = torch.prod(torch.tensor([(kernel_size[i] - 1) / 2 for i in range(len(kernel_size))])) ** (1/len(kernel_size))
tau_r = (torch.arange(1,n_radius+1) * R / (n_radius+1)) ** (len(kernel_size)-1)
SHT = get_SHT_matrix(n_angle, freq_cutoff, len(kernel_size)) # Spherical Harmonic Transform Matrix
if len(kernel_size) == 2:
    I = get_interpolation_matrix(kernel_size, n_radius, n_angle, interpolation_type).type(torch.cfloat) # Interpolation Matrix
    Fint1 = torch.einsum('r, mt, rtxy -> mrxy', tau_r, SHT, I)
elif len(kernel_size) == 3:
    I = get_interpolation_matrix((kernel_size[2], kernel_size[0], kernel_size[1]), n_radius, n_angle, interpolation_type).type(torch.cfloat) # Interpolation Matrix
    I = torch.permute(I, (0,1,3,4,2))
    Fint = [torch.einsum('r, lt, rtxyz -> lrxyz', tau_r, SHT[l], I) for l in range(freq_cutoff+1)] # Fint Matrix

In [12]:
(Fint - Fint1).abs().max()

tensor(0.)