In [40]:
import sys
sys.path.append('..')

import torch
import plotly
import plotly.graph_objects as go
from plotly.subplots import make_subplots

from equitorch.math._o3 import spherical_harmonics

In [41]:
# Sample N directions uniformly on the unit sphere (normalising isotropic
# Gaussian samples yields a uniform distribution of directions), then evaluate
# the spherical harmonics of degree 0..L at those points.
N = 2000   # number of sample points on the sphere
L = 3      # maximum spherical-harmonic degree to plot
SEED = 0   # fixed seed so the point cloud (and figure) is reproducible

torch.manual_seed(SEED)
points = torch.randn(N, 3)
r = points / points.norm(dim=-1, keepdim=True)
# NOTE(review): presumably Ys has one column per (l, m) for l in 0..L,
# i.e. shape (N, (L+1)**2) — confirm against equitorch's spherical_harmonics.
Ys = spherical_harmonics(r, L=(0, L))

In [42]:
def add_scatter(fig, r, func, row, col):
    """Draw a function on the sphere as a radial 3D scatter in one subplot.

    Each sample direction ``r[i]`` is scaled to radius ``|func[i]|`` and
    coloured by the signed value ``func[i]``. The shape of the point cloud
    therefore shows the magnitude, and the colour shows the sign.

    Args:
        fig: plotly figure from ``make_subplots`` whose cell at (row, col)
            has type ``'scene'``.
        r: sample directions, presumably an (N, 3) array of unit vectors
            (torch tensor or array-like).
        func: function values at the N sample points, shape (N,).
        row, col: 1-based subplot position to draw into.
    """
    # Convert explicitly to CPU tensors so plotly never has to convert a CUDA
    # or grad-tracking tensor itself (its implicit __array__ call fails there).
    points = torch.as_tensor(r).detach().cpu()
    values = torch.as_tensor(func).detach().cpu()
    xyz = (points * values.abs().unsqueeze(-1)).numpy()

    fig.add_trace(
        go.Scatter3d(
            x=xyz[:, 0],
            y=xyz[:, 1],
            z=xyz[:, 2],
            mode='markers',
            marker=dict(
                size=3,
                color=values.numpy(),
                colorscale='RdBu_r',
                # Centre the diverging scale on zero: white means 0, red
                # means positive, blue means negative, in every subplot.
                cmid=0,
                colorbar=dict(thickness=20, len=0.5),
                showscale=False,
            ),
            showlegend=False,
        ),
        row=row, col=col,
    )
    fig.update_scenes(
        xaxis=dict(visible=True, title='x', showticklabels=False),
        yaxis=dict(visible=True, title='y', showticklabels=False),
        zaxis=dict(visible=True, title='z', showticklabels=False),
        camera={'eye': {'x': 1.8, 'y': 1.8, 'z': 1.8}},
        row=row, col=col,
    )


In [45]:
# One row per degree l (0..L) and one column per order m (-L..L). The grid is
# sized from L, so changing L in the data cell keeps the layout valid.
n_rows, n_cols = L + 1, 2 * L + 1
fig = make_subplots(
    rows=n_rows, cols=n_cols,
    specs=[[{'type': 'scene'} for _ in range(n_cols)] for _ in range(n_rows)],
    vertical_spacing=0,
    horizontal_spacing=0,
)

for l in range(L + 1):
    for m in range(-l, l + 1):
        # Assumes Ys columns are grouped by degree l with m running -l..l,
        # so (l, m) -> column l**2 + l + m. Grid column L + m + 1 centres
        # each degree's row on m = 0.
        add_scatter(fig, r, Ys[:, l**2 + l + m], row=l + 1, col=L + m + 1)

fig.update_layout(
    # 200 x 150 px per subplot cell (1400 x 600 for L = 3).
    height=150 * n_rows,
    width=200 * n_cols,
    title_text="Spherical Harmonics Visualization",
    margin={'b': 50, 't': 50, 'l': 50, 'r': 50},
)

# Inline rendering of a plotly figure requires nbformat>=4.2.0 in the kernel env.
fig

ValueError: Mime type rendering requires nbformat>=4.2.0 but it is not installed