In [1]:
import math
import sys
sys.path.append('..')

import torch
import plotly
import plotly.graph_objects as go
from plotly.subplots import make_subplots

from equitorch.math._o3 import spherical_harmonics

In [28]:
# Monte-Carlo check of orthonormality. For directions uniform on the unit
# sphere, E[Y_P Y_Q] = (1/4π) ∫ Y_P Y_Q dΩ, so 4π times the sample mean of
# Y_P * Y_Q should approach the identity; the error shrinks like 1/sqrt(N*C).
SEED = 0
N = 2000  # sample points per set
C = 1     # number of independent point sets (last axis of r)
L = 3     # maximum degree; degrees 0..L give (L+1)**2 harmonics

torch.manual_seed(SEED)  # fixed seed so the matrix below is reproducible
r = torch.randn(N, 3, C, dtype=torch.float64)
# Normalising isotropic Gaussian vectors yields points uniform on the sphere.
r = r / r.norm(dim=-2, keepdim=True)
Ys = spherical_harmonics(r, L=(0, L), dim=-2)
inner = torch.einsum('NPC,NQC->PQ', Ys, Ys) * 4 * torch.pi / (Ys.shape[0] * Ys.shape[-1])
inner[:5, :5]

tensor([[ 1.0000, -0.0257, -0.0066, -0.0090,  0.0190],
        [-0.0257,  1.0015,  0.0216,  0.0147, -0.0238],
        [-0.0066,  0.0216,  1.0057,  0.0124,  0.0272],
        [-0.0090,  0.0147,  0.0124,  0.9927,  0.0168],
        [ 0.0190, -0.0238,  0.0272,  0.0168,  0.9653]], dtype=torch.float64)


In [17]:
def add_scatter(fig, r, func, row, col, marker_size=3, colorscale='RdBu_r'):
    """Draw a radial 3D scatter of ``func`` sampled at directions ``r``.

    Each point ``r[i]`` is moved to distance ``|func[i]|`` from the origin
    along its own direction and coloured by the signed value ``func[i]``.
    The trace and the scene styling are applied to subplot ``(row, col)``.

    Args:
        fig: plotly figure from ``make_subplots`` whose ``(row, col)`` cell
            has type ``'scene'``.
        r: tensor whose columns 0, 1, 2 are used as x, y, z; presumably
            shape (N, 3) with unit-norm rows -- TODO confirm at call sites.
        func: tensor of N values, one per row of ``r``.
        row, col: 1-based subplot position.
        marker_size: marker size passed to plotly (default 3).
        colorscale: plotly colorscale name (default 'RdBu_r'). The colour
            domain is centred at 0, so a diverging scale is expected.
    """
    # Give plotly plain CPU NumPy arrays: converting a tensor that requires
    # grad or lives on a GPU would otherwise fail inside plotly.
    directions = r.detach().cpu().numpy()
    values = func.detach().cpu().numpy()
    radius = abs(values)
    x, y, z = (directions[:, k] * radius for k in range(3))

    fig.add_trace(
        go.Scatter3d(
            x=x,
            y=y,
            z=z,
            text=values,
            mode='markers',
            marker=dict(
                size=marker_size,
                color=values,
                colorscale=colorscale,
                # Centre the colour domain at 0 so positive/negative values get
                # the same colours in every subplot; without this each trace's
                # neutral colour sits at the midpoint of its own value range.
                cmid=0,
                showscale=False,
            ),
            showlegend=False,
        ),
        row=row, col=col,
    )

    # Axes and axis titles stay visible, tick labels are hidden, and every
    # subplot uses the same fixed camera position.
    axis_style = dict(visible=True, showticklabels=False)
    fig.update_scenes(
        xaxis=dict(axis_style, title=dict(text='x')),
        yaxis=dict(axis_style, title=dict(text='y')),
        zaxis=dict(axis_style, title=dict(text='z')),
        camera=dict(eye=dict(x=1.8, y=1.8, z=1.8)),
        row=row, col=col,
    )


In [18]:
# Pyramid layout: one row per degree l, one column per order m; order m goes
# to (1-based) column L + m + 1 so that m = 0 is the centre column.
n_rows, n_cols = L + 1, 2 * L + 1
fig = make_subplots(
    rows=n_rows, cols=n_cols,
    # Cells with |m| > l hold no harmonic; `None` leaves them empty instead of
    # drawing a blank 3D scene there.
    specs=[
        [{'type': 'scene'} if abs(c - L) <= row_idx else None for c in range(n_cols)]
        for row_idx in range(n_rows)
    ],
    vertical_spacing=0,
    horizontal_spacing=0,
)

for l in range(L + 1):
    for m in range(-l, l + 1):
        # Harmonic (l, m) is read from flat index l**2 + l + m of Ys --
        # presumably the layout spherical_harmonics returns for L=(0, L).
        add_scatter(fig, r[:, :, 0], Ys[:, l**2 + l + m, 0], l + 1, L + m + 1)

fig.update_layout(
    # Size scales with the grid (600 x 1400 for L = 3).
    height=150 * n_rows,
    width=200 * n_cols,
    title_text="Spherical Harmonics Visualization",
    margin={'b': 50, 't': 50, 'l': 50, 'r': 50},
)
fig




In [37]:
# Evaluate all harmonics of degree 0..L at the point (x, y, z) = (0, 1, 1).
# This point is not unit-norm (|r| = sqrt(2)); in the recorded output the l = 1
# entries equal sqrt(3 / (4π)) ≈ 0.4886, which suggests the input is not
# normalised internally -- NOTE(review): confirm against spherical_harmonics.
probe = torch.tensor([0.0, 1.0, 1.0])
spherical_harmonics(probe, L=(0, L), dim=-1)

tensor([ 0.2821,  0.4886,  0.4886,  0.0000,  0.0000,  1.0925,  0.3154,  0.0000,
        -0.5463, -0.5900,  0.0000,  1.3711, -0.3732,  0.0000, -1.4453, -0.0000])

In [34]:
# Normalisation constant sqrt(15 / (16π)) of the real harmonic Y_2^2, which is
# proportional to (x^2 - y^2). At (x, y) = (0, 1) it should match the magnitude
# of flat index 8 (l = 2, m = 2 at l**2 + l + m) in the evaluation above.
y22_norm = math.sqrt(15 / (16 * math.pi))
y22_norm

0.5462742152960396