In [1]:
%load_ext autoreload
%autoreload 2

import sys
sys.path.append('..')

import torch
from torch import Tensor

import torch_geometric
import e3nn

from equitorch.nn._linear import _so2_indices, SO2Linear, SO3Linear, _so3_conv
from equitorch.typing import DegreeRange
from equitorch.utils._clebsch_gordan import blocked_CG
from equitorch.utils._geometries import rot_on
from equitorch.utils._indices import list_degrees

In [2]:
float_type = torch.float32

In [3]:
def so3_weights_to_so2(weight_so3: Tensor, L_in:DegreeRange, L_out:DegreeRange, dim=-3):
    if dim == -2:
        weight_so3 = weight_so3.unsqueeze(dim=-1)
    weight_so2 = weight_so3.clone()
    offset = 0
    CGs = blocked_CG(L_out, L_in, L_in[1]+L_out[1])
    ys = torch.sqrt((2*torch.arange(0, L_in[1]+L_out[1]+1)+1) / 4 / torch.pi)
    for l_in in range(L_in[0], L_in[1]+1):
        for l_out in range(L_out[0], L_out[1]+1):
            l_ = min(l_in, l_out)
            transform = torch.zeros(2*l_+1,2*l_+1)
            for ind_l, l  in enumerate(range(l_in+l_out-2*l_, l_in+l_out+1)):
                for ind_m, m in enumerate(range(-l_,l_+1)):
                    # print(CGs[l_out, l_in, l])
                    # print('l_in', 'l_out', 'l_', 'l', 'm', 'transform')
                    print(l_in, l_out, l_, l, m, transform.shape)
                    transform[ind_m, ind_l] = CGs[l_out, l_in, l][l_out+m, l_in+abs(m), l] * ys[l]
            transform = transform.unsqueeze(0)
            weight_so2[:,offset:offset+2*l_+1,:,:] = (transform @ weight_so3[:,offset:offset+2*l_+1,:,:].flatten(-2,-1)).unflatten(-1, (weight_so3.shape[-2:]))
            offset += 2*l_+1
    return weight_so2 if dim==-2 else weight_so2.squeeze(dim=-1)


In [4]:
def init_test(N, in_channels, out_channels, l_l, l1_l, l2_l, len_W, device='cpu'):
    r1 = torch.randn(N,in_channels,3,dtype=float_type)
    r1 = r1 / r1.norm(dim=-1, keepdim=True)
    r2 = torch.randn(N,3,dtype=float_type)
    r2 = r2 / r2.norm(dim=-1, keepdim=True)
    Y1 = e3nn.o3.spherical_harmonics(l1_l, r1, False)
    Y2 = e3nn.o3.spherical_harmonics(l2_l, r2, False)
    a,b,c = e3nn.o3.rand_angles(N,dtype=float_type)
    D1 = e3nn.math.direct_sum(*(e3nn.o3.wigner_D(l1_,a,b,c) for l1_ in l1_l))
    D2 = e3nn.math.direct_sum(*(e3nn.o3.wigner_D(l2_,a,b,c) for l2_ in l2_l))
    Y1R = D1 @ Y1.transpose(-1,-2)
    Y2R = (D2 @ Y2.unsqueeze(-1)).squeeze(-1)
    D = e3nn.math.direct_sum(*(e3nn.o3.wigner_D(l_,a,b,c) for l_ in l_l)).type(float_type)
    W = torch.randn(N, len_W,in_channels, out_channels, dtype=float_type)
    # W = torch.ones(N, len_W,in_channels, out_channels, dtype=float_type)
    return W.to(device), Y1.transpose(-1,-2).to(device), Y2.to(device), \
        Y1R.to(device), Y2R.to(device), \
        (a.to(device),b.to(device),c.to(device)), \
        D1.to(device), D2.to(device), D.to(device), r2

In [5]:
N = 5
l_min = l1_min = l2_min = 2
l_max = l1_max = l2_max = 2
# l1_min, l1_max = (1,4)
# l2_min, l2_max = (0,11)
# l_min, l_max = (2,7)

in_channels = 2
out_channels = 5


In [6]:
so2 = SO2Linear(in_channels, out_channels, (l1_min, l1_max), (l_min, l_max), True, False)
so3 = SO3Linear(in_channels, out_channels, (l1_min, l1_max), (l2_min, l2_max), (l_min, l_max), True, False)

In [7]:
ls = list_degrees((l_min, l_max), (l1_min, l1_max), (l2_min, l2_max))
W_so3, X, Y, XR, YR, (a,b,c), D1, D2, D, r = init_test(
    N, in_channels, out_channels, 
    list(range(l_min, l_max+1)), list(range(l1_min, l1_max+1)), list(range(l2_min, l2_max+1)),
    len(ls))

W_so2 = so3_weights_to_so2(W_so3, (l1_min, l1_max), (l_min, l_max))

2 2 2 0 -2 torch.Size([5, 5])
2 2 2 0 -1 torch.Size([5, 5])
2 2 2 0 0 torch.Size([5, 5])
2 2 2 0 1 torch.Size([5, 5])
2 2 2 0 2 torch.Size([5, 5])
2 2 2 1 -2 torch.Size([5, 5])
2 2 2 1 -1 torch.Size([5, 5])
2 2 2 1 0 torch.Size([5, 5])
2 2 2 1 1 torch.Size([5, 5])
2 2 2 1 2 torch.Size([5, 5])
2 2 2 2 -2 torch.Size([5, 5])
2 2 2 2 -1 torch.Size([5, 5])
2 2 2 2 0 torch.Size([5, 5])
2 2 2 2 1 torch.Size([5, 5])
2 2 2 2 2 torch.Size([5, 5])
2 2 2 3 -2 torch.Size([5, 5])
2 2 2 3 -1 torch.Size([5, 5])
2 2 2 3 0 torch.Size([5, 5])
2 2 2 3 1 torch.Size([5, 5])
2 2 2 3 2 torch.Size([5, 5])
2 2 2 4 -2 torch.Size([5, 5])
2 2 2 4 -1 torch.Size([5, 5])
2 2 2 4 0 torch.Size([5, 5])
2 2 2 4 1 torch.Size([5, 5])
2 2 2 4 2 torch.Size([5, 5])


RuntimeError: Expected size for first two dimensions of batch2 tensor to be: [5, 5] but got: [5, 1].

In [None]:
from equitorch.utils._geometries import align_to_z_wigner


D_so2_in = align_to_z_wigner(r, (l1_min, l1_max))
D_so2_out = align_to_z_wigner(r, (l_min, l_max))

In [None]:
rot_on(D_so2_out.transpose(-1,-2), so2.forward(rot_on(D_so2_in, X), W_so2))
# rot_on(D, so2.forward(rot_on(D1.transpose(-1,-2), X), W_so2))

tensor([[[ 0.1521, -0.3799, -0.3855, -0.3357, -0.5739],
         [-0.1186,  0.1137, -0.1234, -0.3965,  0.3965],
         [-0.2454, -0.4902, -0.0901, -0.3988,  0.2499],
         ...,
         [ 0.1107, -0.1456,  0.1322,  0.0303, -0.0950],
         [ 0.0135, -0.0315,  0.2652, -0.0560, -0.0646],
         [-0.3217, -0.2805, -0.1048, -0.0110, -0.0058]],

        [[ 0.4241,  0.5858,  0.2130, -0.3527,  0.2422],
         [-0.4065, -0.3411,  0.0597,  0.3384, -0.2942],
         [-0.1002, -0.2662, -0.1680, -0.2319, -0.3407],
         ...,
         [ 0.1807, -0.0892,  0.0884,  0.1008, -0.6223],
         [ 0.2033, -0.0676, -0.1079,  0.0989, -0.2588],
         [ 0.0230, -0.0820, -0.0328,  0.0288, -0.0457]],

        [[ 0.0642,  0.2381, -0.0535, -0.0988,  0.2166],
         [-0.2248,  0.2204,  0.2982,  0.3385,  0.0927],
         [-0.0919, -0.1478, -0.3404, -0.5299,  0.2118],
         ...,
         [-0.1253, -0.1105, -0.3782, -0.0594,  0.0357],
         [ 0.0873,  0.0739,  0.2264, -0.0008, -0.0105],
  

In [None]:
so3.forward(X, Y, W_so3)

tensor([[[-2.3485e-01, -2.5458e-01,  3.7538e-02,  4.3065e-02, -6.6671e-02],
         [ 1.1529e-01,  5.4354e-01, -4.3774e-01,  1.2770e-01, -1.5578e-01],
         [-1.3105e-02, -3.2032e-01, -5.4711e-02, -9.3531e-02, -3.5454e-01],
         ...,
         [-6.0789e-04, -1.3356e-01, -1.8051e-01, -1.0640e-01, -3.6389e-03],
         [-1.7156e-02, -4.9925e-02, -1.5906e-01, -1.4693e-01,  5.4501e-02],
         [-9.4884e-03, -1.1566e-02, -5.6345e-02, -6.1832e-02,  2.8461e-02]],

        [[ 1.3732e-02,  2.3007e-02,  1.1126e-01,  2.1982e-01, -2.0802e-01],
         [ 5.9295e-03, -2.0077e-01,  1.3106e-02,  1.6667e-01,  3.0641e-01],
         [ 2.7446e-02,  2.2059e-02, -2.3188e-01, -3.0910e-01,  1.8966e-01],
         ...,
         [ 3.9788e-02, -1.6774e-01,  1.5215e-01, -1.8787e-01, -3.0338e-02],
         [ 8.9428e-02,  1.3314e-02,  3.3123e-02,  1.7370e-01, -3.8402e-02],
         [-4.5342e-02,  2.2347e-01, -7.5628e-02,  2.0798e-02,  1.9095e-01]],

        [[-1.3852e-01, -1.1446e-01,  4.1548e-01,  2.1462

In [86]:
rot_on(D_so2_out.transpose(-1,-2), so2.forward(rot_on(D_so2_in, X), W_so2)).norm(dim=(-1,-2))

tensor([4.1911, 3.7891, 3.9845, 3.5987, 4.2526])

In [87]:
so3.forward(X, Y, W_so3).norm(dim=(-1,-2))

tensor([2.8862, 2.9930, 2.9723, 2.9839, 2.8070])

In [88]:
rot_on(D_so2_out.transpose(-1,-2), so2.forward(rot_on(D_so2_in, X), W_so2)).norm(dim=(-1,-2)) / so3.forward(X, Y, W_so3).norm(dim=(-1,-2))

tensor([1.4521, 1.2660, 1.3405, 1.2061, 1.5150])