In [50]:
import e3nn
from e3nn import io, o3
import numpy as np
import torch
torch.set_default_dtype(torch.float64)
import plotly  # Interactive data visualization
import plotly.graph_objects as go  # Interactive plots

true_geometry = torch.tensor([
    [1, 0, 0],
    [-0.5, np.sqrt(3)/2, 0],
    [-0.5, -np.sqrt(3)/2, 0]
], dtype=torch.double)

def visualize(signal):
    layout = go.Layout(
        scene=dict(
            xaxis=dict(title='', showticklabels=False, showgrid=False, zeroline=False, backgroundcolor='rgba(255,255,255,255)', range=[-2.5, 2.5]),
            yaxis=dict(title='', showticklabels=False, showgrid=False, zeroline=False, backgroundcolor='rgba(255,255,255,255)', range=[-2.5, 2.5]),
            zaxis=dict(title='', showticklabels=False, showgrid=False, zeroline=False, backgroundcolor='rgba(255,255,255,255)', range=[-2.5, 2.5]),
            bgcolor='rgba(255,255,255,255)',
            aspectmode='cube',
            camera=dict(
                eye=dict(x=0.5, y=0.5, z=0.5)
            )
        ),
        plot_bgcolor='rgba(255,255,255,255)',
        paper_bgcolor='rgba(255,255,255,255)',
        margin=dict(l=0, r=0, t=0, b=0)
    )

    fig = go.Figure(data=[go.Surface(**sph.plotly_surface(signal)[0])], layout=layout)
    fig.show()


In [51]:
lmax=4
sph = io.SphericalTensor(lmax, p_val=1, p_arg=-1)

powerspectrum = o3.ReducedTensorProducts(
    'ijk=jik=ikj', i=sph, 
    filter_ir_out=['0e', '0o'], 
    filter_ir_mid=o3.Irrep.iterator(lmax)
)

bispectrum = o3.ReducedTensorProducts(
    'ijk=jik=ikj', i=sph, 
    filter_ir_out=['0e', '0o'], 
    filter_ir_mid=o3.Irrep.iterator(lmax)
)

powerspectrum_lambda = lambda x: powerspectrum(x, x)
bispectrum_lambda = lambda x: bispectrum(x, x, x)


The TorchScript type system doesn't support instance-level annotations on empty non-base types in `__init__`. Instead, either 1) use a type annotation in the class body, or 2) wrap the type in `torch.jit.Attribute`.



In [52]:
# true_sig = sph.with_peaks_at(true_geometry)
# true_sig = sph.sum_of_diracs(true_geometry, torch.tensor(true_geometry.shape[0]*[1.0], dtype=torch.double))
true_sig = sph.sum_of_diracs(true_geometry, torch.norm(true_geometry, dim=1))
true_spectrum = bispectrum_lambda(true_sig)

In [53]:
true_geometry 

tensor([[ 1.0000,  0.0000,  0.0000],
        [-0.5000,  0.8660,  0.0000],
        [-0.5000, -0.8660,  0.0000]])

In [54]:
visualize(true_sig)


In [55]:
def invert_bispectrum(target_bispectrum, guess=None, max_iter=2000):
    if guess is None:
        guess = torch.randn(12, 3)
    guess.requires_grad = True
    opt = torch.optim.Adam([guess], lr=1e-2 )
    loss_fn = torch.nn.MSELoss()
    # loss_fn = torch.nn.L1Loss()
    for i in range(max_iter):
        cur_sig = sph.sum_of_diracs(guess, values=guess.norm(2, -1))
        # power = powerspectrum(cur_sig)
        cur_bis = bispectrum_lambda(cur_sig)
        loss = loss_fn(target_bispectrum, cur_bis) # + loss_fn(true_power, power)
        if i % 100 == 0:
            print(f"Step {i}, Loss: {loss}, sig: {cur_sig}")
        opt.zero_grad()
        loss.backward()
        opt.step()
    print(loss)
    return cur_sig


In [56]:
torch.norm(torch.randn(12, 3), dim=1)

tensor([1.4151, 0.6058, 1.3908, 1.8201, 1.0527, 1.7175, 1.6997, 0.5264, 0.5981,
        1.5038, 1.2090, 1.5683])

In [57]:
pred_sig = invert_bispectrum(true_spectrum)

Step 0, Loss: 28.87342306010826, sig: tensor([ 2.3054,  0.6348,  0.5546,  0.0829,  1.3108,  0.9028,  1.0842, -0.5663,
        -1.0734,  0.5844, -0.8871,  1.9402,  0.3120,  0.1273, -0.4906,  0.1619,
        -1.2676, -0.1384,  0.0732,  2.0583,  0.2422,  0.1553, -0.7589,  0.4689,
        -0.5334], grad_fn=<MulBackward0>)
Step 100, Loss: 0.6652664185921195, sig: tensor([ 1.4129,  0.2457,  0.4974,  0.0958,  0.4015,  0.2436,  0.8746, -0.1207,
        -0.4549,  0.5723, -0.3296,  0.6434,  0.0969, -0.0268,  0.2580, -0.0263,
        -0.1869,  0.4115,  0.2214,  0.3836,  0.0027,  0.2686,  0.3805,  0.1078,
         0.1037], grad_fn=<MulBackward0>)
Step 200, Loss: 0.21735700403594885, sig: tensor([ 1.1787,  0.1775,  0.4791,  0.1072,  0.3016,  0.2021,  0.7726, -0.0814,
        -0.3501,  0.5288, -0.2370,  0.5487,  0.0456, -0.0529,  0.2155, -0.0645,
        -0.0964,  0.4437,  0.1253,  0.3249,  0.0660,  0.1519,  0.3069,  0.1513,
         0.0372], grad_fn=<MulBackward0>)
Step 300, Loss: 0.106160686190199

In [58]:
visualize(pred_sig.detach())