In [1]:
import sys
sys.path.append('..')

import torch
import e3nn

from equitorch.nn._tensor_product import WeightedTensorProduct, TensorProduct, TensorDot
from equitorch.nn._linear import SO3Linear
from equitorch.math._o3 import spherical_harmonics
from equitorch.utils._indices import num_order_between
from equitorch.utils._geometries import rot_on
from equitorch.math._o3 import wigner_D



In [2]:
import torch_geometric
from equitorch.utils._clebsch_gordan import dense_CG
from equitorch.utils._indices import degree_order_to_index, list_degrees, order_ptr


def WeightedEinsum(in_channels, out_channels, l_min, l_max, l1_min, l1_max, l2_min, l2_max, device='cpu'):
    '''
    Dense einsum reference for the weighted Clebsch-Gordan tensor product.

    It materializes the full CG table and a dense weight grid, so it is only
    meant for cross-checking the optimized modules on small inputs.

    Args:
        in_channels, out_channels: unused; kept so existing call sites work.
        l_min, l_max: degree range of the output.
        l1_min, l1_max: degree range of the first input X.
        l2_min, l2_max: degree range of the second input Y.
        device: device holding the CG table, order pointers and index tensors.

    Returns:
        tp(X, Y, W, cw, cn): see the docstring of the inner function.
    '''
    def expand_weight(weight, weight_to, Ml1l2s, l_ind):
        # Scatter each path weight weight[:, l_ind] into every (M, l1, l2) slot of
        # the dense grid. Explicit tuple indexing is equivalent to
        # weight_to[:, *Ml1l2s, ...] but also parses before Python 3.11.
        weight_to[(slice(None), *Ml1l2s, Ellipsis)] = weight[:, l_ind, ...]
        return weight_to

    def generate_lMl1l2s(ls, l_min=0):
        '''
        Enumerate every output order m of every (l, l1, l2) path in ``ls``.

        returns: (ls_lMl1l2, Ms, l_ind), one column/entry per (path, m), sorted
            by (l, M, l1, l2, path index):
            ls_lMl1l2: 3 x K tensor whose rows are l, l1, l2;
            Ms:        K tensor of degree_order_to_index(l, m, l_min);
            l_ind:     K tensor of the path's position in ``ls``.
        '''
        ret = sorted([(l, degree_order_to_index(l, m, l_min), l1, l2, l_ind)
                      for l_ind, (l, l1, l2) in enumerate(ls)
                      for m in range(-l, l + 1)])
        ls_lMl1l2 = torch.tensor([[t[0], t[2], t[3]] for t in ret]).T
        Ms = torch.tensor([t[1] for t in ret])
        l_ind = torch.tensor([t[4] for t in ret])
        return ls_lMl1l2, Ms, l_ind

    ls = list_degrees((l_min, l_max), (l1_min, l1_max), (l2_min, l2_max))
    ls_lMl1l2, Ms, l_ind = generate_lMl1l2s(ls, l_min)
    # Input degrees relative to each input's minimum degree (grid indices).
    l1 = ls_lMl1l2[1] - l1_min
    l2 = ls_lMl1l2[2] - l2_min
    CG = dense_CG((l_min, l_max), (l1_min, l1_max), (l2_min, l2_max)).to(device)
    l1_ptr = order_ptr((l1_min, l1_max)).to(device)
    l2_ptr = order_ptr((l2_min, l2_max)).to(device)
    # Dense weight grid layout: (output order M, l1 - l1_min, l2 - l2_min).
    shape_W = (num_order_between(l_min, l_max), l1_max - l1_min + 1, l2_max - l2_min + 1)
    l_ind = l_ind.to(device)
    Ms = Ms.to(device)
    l1 = l1.to(device)
    l2 = l2.to(device)

    def tp(X: torch.Tensor, Y: torch.Tensor, W: torch.Tensor, cw: bool, cn: bool):
        '''
        X: B * M1 * C
        Y: B * M2 * C'
        W: B * Ls * C * C' * D  for fc  (cw=False, cn=True)
           B * Ls * C * D       for cwc (cw=True,  cn=True,  C == C')
           B * Ls * C * C'      for pw  (cw=False, cn=False)
           B * Ls * C           for cw  (cw=True,  cn=False, C == C')
        returns: B * M * D for fc/cwc, B * M * C * C' for pw, B * M * C for cw
        '''
        # Match the CG table to the input dtype so float64 inputs do not hit a
        # dtype mismatch inside einsum (a no-op when dtypes already agree).
        cg = CG.to(dtype=X.dtype)
        non_contract = torch.einsum('PQR,bQc,bRC->bPQRcC', cg, X, Y)
        # Segment-reduce the order axes per input degree: R (dim 3), then Q (dim 2).
        inter = torch_geometric.utils.segment(non_contract, l2_ptr.reshape(1, 1, 1, -1))
        inter = torch_geometric.utils.segment(inter, l1_ptr.reshape(1, 1, -1))  # bPqrcC
        # Allocate with W's dtype: the float32 default silently downcast float64
        # weights and then mismatched `inter` in the final einsum.
        weight = torch.zeros(W.shape[0], *shape_W, *W.shape[2:], dtype=W.dtype, device=device)
        weight = expand_weight(W, weight, (Ms, l1, l2), l_ind)  # bPqr...
        # A repeated channel letter (CC) takes the channel diagonal: channel-wise.
        if cw and cn:
            return torch.einsum('bPqrCC, bPqrCD -> bPD', inter, weight)
        elif cw and not cn:
            return torch.einsum('bPqrCC, bPqrC -> bPC', inter, weight)
        elif not cw and cn:
            return torch.einsum('bPqrcC, bPqrcCD -> bPD', inter, weight)
        else:
            return torch.einsum('bPqrcC, bPqrcC -> bPcC', inter, weight)

    return tp

In [3]:
def init(N, L, L1, L2, channel1, channel2, channel, seed=None):
    '''
    Build one test configuration. It contains random inputs, the modules under
    test, the dense einsum reference, a shared random weight tensor, and
    Wigner-D matrices for one random rotation per batch element.

    Args:
        N: batch size.
        L, L1, L2: (l_min, l_max) degree ranges of the output and the two inputs.
        channel1, channel2: channel counts of X and Y.
        channel: output channel count of the connected products.
        seed: if not None, torch's global RNG is seeded first so that the inputs,
            weights and rotation angles are reproducible.

    Returns a 15-tuple:
        (X, Y, wtp_cw, wtp_pw, wtp_cwc, wtp_fc, tp_cw, tp_pw, lin_cw, lin_fc,
         ref, weight_fc, D1, D2, D)
    '''
    if seed is not None:
        torch.manual_seed(seed)
    X = torch.randn(N, num_order_between(*L1), channel1)
    Y = torch.randn(N, num_order_between(*L2), channel2)
    # Weighted products; weights are passed at call time (external_weight=True).
    # The channel-wise variants use channel1 for both inputs.
    wtp_cw = WeightedTensorProduct(L1, L2, L, channel1, channel1, connected=False, channel_wise=True, external_weight=True)
    wtp_pw = WeightedTensorProduct(L1, L2, L, channel1, channel2, connected=False, channel_wise=False, external_weight=True)
    wtp_cwc = WeightedTensorProduct(L1, L2, L, channel1, channel1, channel, connected=True, channel_wise=True, external_weight=True)
    wtp_fc = WeightedTensorProduct(L1, L2, L, channel1, channel2, channel, connected=True, channel_wise=False, external_weight=True)
    # Unweighted products (their reprs report channel_wise=True / False).
    tp_cw = TensorProduct(L1, L2, L, True)
    tp_pw = TensorProduct(L1, L2, L, False)
    # NOTE(review): lin_fc is built with (channel, channel1), yet the tests feed it
    # 1 or channel1 input channels and get `channel` outputs. Confirm the intended
    # argument order of SO3Linear.
    lin_cw = SO3Linear(L1, L2, L, channel, channel, True, True)
    lin_fc = SO3Linear(L1, L2, L, channel, channel1, True, False)
    # One random weight tensor shared by every test; each test slices what it needs.
    # Presumably all variants have the same path count since they share L1, L2, L.
    weight_fc = torch.randn(N, wtp_cw.num_weights, channel1, channel2, channel)
    # Dense reference; its two channel arguments are unused.
    ref = WeightedEinsum(0, 0, *L, *L1, *L2)
    # One random rotation per batch element and its Wigner-D matrices per degree range.
    a, b, c = e3nn.o3.rand_angles(N)
    D1 = wigner_D(L1, a, b, c)
    D2 = wigner_D(L2, a, b, c)
    D = wigner_D(L, a, b, c)

    return (
        X, Y,
        wtp_cw, wtp_pw, wtp_cwc, wtp_fc,
        tp_cw, tp_pw, lin_cw, lin_fc, ref,
        weight_fc,
        D1, D2, D
    )

In [4]:
# Test configuration. The default (f = True) uses non-zero minimum degrees.
# Set f = False for the small case where every degree range is (0, 2).
f = True
N = 2
channel1, channel2, channel = 3, 2, 5
if f:
    L, L1, L2 = (1, 4), (2, 6), (3, 5)
else:
    L = L1 = L2 = (0, 2)

(
    X, Y,
    wtp_cw, wtp_pw, wtp_cwc, wtp_fc,
    tp_cw, tp_pw, lin_cw, lin_fc, ref,
    weight_fc,  # N, wtp_cw.num_weights, channel1, channel2, channel
    D1, D2, D,
) = init(N, L, L1, L2, channel1, channel2, channel)

In [5]:
# Channel-wise weighted product (C1 = C2 = C3, one weight per path and channel):
#   Z_c = X_c \otimes_{W_c} Y_c
# Checked against per-channel SO3Linear calls, the dense einsum reference,
# and itself under a random rotation (equivariance).
X_cw = X[:, :, :2]
W_cw = weight_fc[:, :, :2, 0, 0]
Z_lin00, Z_lin11 = (
    lin_cw.forward(X[:, :, c:c + 1], Y[:, :, c], weight_fc[:, :, c:c + 1, 0, 0])
    for c in (0, 1)
)
Z_lin = torch.cat((Z_lin00, Z_lin11), dim=-1)
Z_tp = wtp_cw.forward(X_cw, Y, W_cw)
Z_ref = ref(X_cw, Y, W_cw, True, False)

print(wtp_cw)
print(*(z.shape for z in (Z_lin00, Z_lin11, Z_tp, Z_ref)))
print(torch.abs(Z_lin - Z_tp).max())
print(torch.abs(Z_tp - Z_ref).max())
print('-' * 20)

# Rotate the inputs, apply the product, and map the result back with D^T.
D_T = D.transpose(-1, -2)
Z_tpR = rot_on(D_T, wtp_cw.forward(rot_on(D1, X_cw), rot_on(D2, Y), W_cw))
print(torch.abs(Z_tp - Z_tpR).max())

WeightedTensorProduct(
  L_in1=(2, 6), L_in2=(3, 5), L_out=(1, 4), 
  tp_type=channel_wise, in_channels=3, external_weight=True
)
torch.Size([2, 24, 1]) torch.Size([2, 24, 1]) torch.Size([2, 24, 2]) torch.Size([2, 24, 2])
tensor(4.7684e-07)
tensor(1.1921e-06)
--------------------
tensor(2.5511e-05)


In [6]:
# Pair-wise weighted product (every (c1, c2) channel pair, one weight each):
#   Z_{c1,c2} = X_{c1} \otimes_{W_{c1,c2}} Y_{c2}
# The last two (c1, c2) axes are flattened into one channel axis before comparing.
W_pw = weight_fc[..., 0]
Z_lin0_30, Z_lin0_31 = (
    lin_cw.forward(X, Y[:, :, c2], weight_fc[:, :, :, c2, 0]) for c2 in (0, 1)
)
Z_lin = torch.stack((Z_lin0_30, Z_lin0_31), dim=-1).flatten(-2, -1)
Z_tp = wtp_pw.forward(X, Y, W_pw).flatten(-2, -1)

print(wtp_pw)

Z_ref = ref(X, Y, W_pw, False, False).flatten(-2, -1)
print(*(z.shape for z in (Z_lin0_30, Z_lin0_31, Z_lin, Z_tp, Z_ref)))
print(torch.abs(Z_lin - Z_tp).max())
print(torch.abs(Z_tp - Z_ref).max())
print('-' * 20)

# Equivariance: rotate the inputs, apply the product, map the output back with D^T.
D_T = D.transpose(-1, -2)
Z_tpR = rot_on(D_T, wtp_pw.forward(rot_on(D1, X), rot_on(D2, Y), W_pw).flatten(-2, -1))
print(torch.abs(Z_tp - Z_tpR).max())

WeightedTensorProduct(
  L_in1=(2, 6), L_in2=(3, 5), L_out=(1, 4), 
  tp_type=pair_wise, in1_channels=3, in2_channels=2, external_weight=True
)
torch.Size([2, 24, 3]) torch.Size([2, 24, 3]) torch.Size([2, 24, 6]) torch.Size([2, 24, 6]) torch.Size([2, 24, 6])
tensor(1.4305e-06)
tensor(4.7684e-07)
--------------------
tensor(3.9577e-05)


In [7]:
# Channel-wise connected product (C1 = C2 -> C3, weights W[c, c3]):
#   Z_{c3} = sum_c X_c \otimes_{W_{c,c3}} Y_c
# The per-channel SO3Linear results are summed over c.
X_cwc = X[:, :, :2]
W_cwc = weight_fc[:, :, :2, 0, :]
Z_lin00, Z_lin11 = (
    lin_fc.forward(X[:, :, c:c + 1], Y[:, :, c], weight_fc[:, :, c:c + 1, 0, :])
    for c in (0, 1)
)
Z_lin = Z_lin00 + Z_lin11
Z_tp = wtp_cwc.forward(X_cwc, Y, W_cwc)
Z_ref = ref(X_cwc, Y, W_cwc, True, True)

print(wtp_cwc)
print(*(z.shape for z in (Z_lin00, Z_lin11, Z_tp, Z_ref)))
print(torch.abs(Z_lin - Z_tp).max())
print(torch.abs(Z_tp - Z_ref).max())
print('-' * 20)

# Equivariance: rotate the inputs, apply the product, map the output back with D^T.
D_T = D.transpose(-1, -2)
Z_tpR = rot_on(D_T, wtp_cwc.forward(rot_on(D1, X_cwc), rot_on(D2, Y), W_cwc))
print(torch.abs(Z_tp - Z_tpR).max())

WeightedTensorProduct(
  L_in1=(2, 6), L_in2=(3, 5), L_out=(1, 4), 
  tp_type=channel_wise_connected, in_channels=3, out_channels=5, external_weight=True
)
torch.Size([2, 24, 5]) torch.Size([2, 24, 5]) torch.Size([2, 24, 5]) torch.Size([2, 24, 5])
tensor(9.5367e-07)
tensor(9.5367e-07)
--------------------
tensor(4.7922e-05)


In [8]:
# Fully connected product (C1, C2 -> C3, weights W[c1, c2, c3]):
#   Z_{c3} = sum_{c1, c2} X_{c1} \otimes_{W_{c1,c2,c3}} Y_{c2}
# Each SO3Linear call below takes a single Y channel, so the results are summed over c2.
Z_lin0_30, Z_lin0_31 = (
    lin_fc.forward(X, Y[:, :, c2], weight_fc[:, :, :, c2, :]) for c2 in (0, 1)
)
Z_lin = Z_lin0_30 + Z_lin0_31
Z_tp = wtp_fc.forward(X, Y, weight_fc)
Z_ref = ref(X, Y, weight_fc, False, True)

print(wtp_fc)
print(*(z.shape for z in (Z_lin0_30, Z_lin0_31, Z_lin, Z_tp, Z_ref)))
print(torch.abs(Z_lin - Z_tp).max())
print(torch.abs(Z_tp - Z_ref).max())
print('-' * 20)

# Equivariance: rotate the inputs, apply the product, map the output back with D^T.
D_T = D.transpose(-1, -2)
Z_tpR = rot_on(D_T, wtp_fc.forward(rot_on(D1, X), rot_on(D2, Y), weight_fc))
print(torch.abs(Z_tp - Z_tpR).max())

WeightedTensorProduct(
  L_in1=(2, 6), L_in2=(3, 5), L_out=(1, 4), 
  tp_type=fully_connected, in1_channels=3, in2_channels=2, out_channels=5, external_weight=True
)
torch.Size([2, 24, 5]) torch.Size([2, 24, 5]) torch.Size([2, 24, 5]) torch.Size([2, 24, 5]) torch.Size([2, 24, 5])
tensor(1.9073e-06)
tensor(1.9073e-06)
--------------------
tensor(7.5817e-05)


In [9]:
# All-ones weights, so the weighted modules can be compared with the unweighted
# TensorProduct below. This builds a new tensor instead of filling the random one
# in place. Note that cells above which read `weight_fc` will see the ones if
# they are re-run after this cell.
weight_fc = torch.ones_like(weight_fc)

In [10]:
# Channel-wise product without weights (C1 = C2 = C3):
#   Z_c = X_c \otimes Y_c
# weight_fc is all ones at this point (set in the preceding cell). The weighted
# SO3Linear and the einsum reference should therefore reproduce tp_cw.
X_cw = X[:, :, :2]
Z_lin00, Z_lin11 = (
    lin_cw.forward(X[:, :, c:c + 1], Y[:, :, c], weight_fc[:, :, c:c + 1, 0, 0])
    for c in (0, 1)
)
Z_lin = torch.cat((Z_lin00, Z_lin11), dim=-1)
Z_tp = tp_cw.forward(X_cw, Y)
Z_ref = ref(X_cw, Y, weight_fc[:, :, :2, 0, 0], True, False)

print(tp_cw)
print(*(z.shape for z in (Z_lin00, Z_lin11, Z_tp, Z_ref)))
print(torch.abs(Z_lin - Z_tp).max())
print(torch.abs(Z_tp - Z_ref).max())
print('-' * 20)

# Equivariance: rotate the inputs, apply the product, map the output back with D^T.
D_T = D.transpose(-1, -2)
Z_tpR = rot_on(D_T, tp_cw.forward(rot_on(D1, X_cw), rot_on(D2, Y)))
print(torch.abs(Z_tp - Z_tpR).max())

TensorProduct(L_in1=(2, 6), L_in2=(3, 5), L_out=(1, 4), channel_wise=True)
torch.Size([2, 24, 1]) torch.Size([2, 24, 1]) torch.Size([2, 24, 2]) torch.Size([2, 24, 2])
tensor(4.7684e-07)
tensor(1.4305e-06)
--------------------
tensor(2.4080e-05)


In [11]:
# Pair-wise product without weights (every (c1, c2) channel pair):
#   Z_{c1,c2} = X_{c1} \otimes Y_{c2}
# weight_fc is all ones at this point (set in the preceding cell). The weighted
# SO3Linear and the einsum reference should therefore reproduce tp_pw.
Z_lin0_30, Z_lin0_31 = (
    lin_cw.forward(X, Y[:, :, c2], weight_fc[:, :, :, c2, 0]) for c2 in (0, 1)
)
Z_lin = torch.stack((Z_lin0_30, Z_lin0_31), dim=-1).flatten(-2, -1)
Z_tp = tp_pw.forward(X, Y).flatten(-2, -1)
Z_ref = ref(X, Y, weight_fc[..., 0], False, False).flatten(-2, -1)

print(tp_pw)
print(*(z.shape for z in (Z_lin0_30, Z_lin0_31, Z_lin, Z_tp, Z_ref)))
print(torch.abs(Z_lin - Z_tp).max())
print(torch.abs(Z_tp - Z_ref).max())
print('-' * 20)

# Equivariance: rotate the inputs, apply the product, map the output back with D^T.
D_T = D.transpose(-1, -2)
Z_tpR = rot_on(D_T, tp_pw.forward(rot_on(D1, X), rot_on(D2, Y)).flatten(-2, -1))
print(torch.abs(Z_tp - Z_tpR).max())

TensorProduct(L_in1=(2, 6), L_in2=(3, 5), L_out=(1, 4), channel_wise=False)
torch.Size([2, 24, 3]) torch.Size([2, 24, 3]) torch.Size([2, 24, 6]) torch.Size([2, 24, 6]) torch.Size([2, 24, 6])
tensor(4.7684e-07)
tensor(1.4305e-06)
--------------------
tensor(3.0756e-05)


In [12]:
# TensorDot: contraction of two features that share the degree range L.
# With L1 = L2 = L, D1 and D2 come from the same degrees and angles. The result
# should therefore not change when both inputs are rotated (invariance check).
N = 5
L = (1, 5)
channel1, channel2, channel = 3, 2, 5
(
    X, Y,
    _, _, _, _,
    _, _, _, _, _,
    _,
    D1, D2, D,
) = init(N, L, L, L, channel1, channel2, channel)

# Channel-wise: X is sliced to its first two channels to match Y (channel2 = 2).
td_cw = TensorDot(L, True)
print(td_cw)
X_td = X[:, :, :2]
d_cw = td_cw.forward(X_td, Y)
d_cwR = td_cw.forward(rot_on(D1, X_td), rot_on(D2, Y))
print(d_cw.shape, torch.abs(d_cw - d_cwR).max())
print('-' * 40)

# Pair-wise: every (c1, c2) channel pair.
td_pw = TensorDot(L, False)
print(td_pw)
d_pw = td_pw.forward(X, Y)
d_pwR = td_pw.forward(rot_on(D1, X), rot_on(D2, Y))
print(d_pw.shape, torch.abs(d_pw - d_pwR).max())

TensorDot(L=(1, 5), channel_wise=True)
torch.Size([5, 5, 2]) tensor(7.0572e-05)
----------------------------------------
TensorDot(L=(1, 5), channel_wise=False)
torch.Size([5, 5, 3, 2]) tensor(7.0572e-05)
