In [19]:
import e3nn
from e3nn import io, o3
import numpy as np
import torch
torch.set_default_dtype(torch.float64)
import plotly  # Interactive data visualization
import plotly.graph_objects as go  # Interactive plots

# Reference geometry: three unit vectors in the xy-plane, 120 degrees apart.
true_geometry = torch.tensor(
    [
        [1, 0, 0],
        [-0.5, np.sqrt(3) / 2, 0],
        [-0.5, -np.sqrt(3) / 2, 0],
    ],
    dtype=torch.float64,
)

def visualize(signal):
    """Show ``signal`` as an interactive plotly 3D surface.

    The surface data comes from ``sph.plotly_surface(signal)``; ``sph`` is a
    module-level object that must exist by the time this function is called.
    The figure is displayed via ``fig.show()`` and nothing is returned.
    """
    white = 'rgba(255,255,255,255)'

    def hidden_axis():
        # Identical settings for x, y and z: no title, ticks, grid or zero
        # line, white background, fixed [-2.5, 2.5] range.
        return dict(
            title='',
            showticklabels=False,
            showgrid=False,
            zeroline=False,
            backgroundcolor=white,
            range=[-2.5, 2.5],
        )

    scene = dict(
        xaxis=hidden_axis(),
        yaxis=hidden_axis(),
        zaxis=hidden_axis(),
        bgcolor=white,
        aspectmode='cube',
        camera=dict(eye=dict(x=0.5, y=0.5, z=0.5)),
    )
    layout = go.Layout(
        scene=scene,
        plot_bgcolor=white,
        paper_bgcolor=white,
        margin=dict(l=0, r=0, t=0, b=0),
    )

    # Only the first entry of what plotly_surface returns is plotted.
    surface_kwargs = sph.plotly_surface(signal)[0]
    fig = go.Figure(data=[go.Surface(**surface_kwargs)], layout=layout)
    fig.show()


In [20]:
# Highest spherical-harmonic degree used for every signal in this notebook.
lmax = 4
# Shared spherical-tensor description used to build and plot the signals.
sph = io.SphericalTensor(
    lmax,
    p_val=1,
    p_arg=-1,
)

# Power spectrum: scalar (l=0) outputs of the symmetric product of TWO copies
# of the signal. The formula must have two indices ('ij=ji') because this
# object is called with two inputs, as powerspectrum(x, x); the three-index
# formula 'ijk=jik=ikj' belongs to the bispectrum only.
powerspectrum = o3.ReducedTensorProducts(
    'ij=ji', i=sph,
    filter_ir_out=['0e', '0o'],
    filter_ir_mid=o3.Irrep.iterator(lmax)
)

# Bispectrum: scalar (l=0) outputs of the fully symmetric product of three
# copies of the signal ('ijk=jik=ikj').
bispectrum = o3.ReducedTensorProducts(
    'ijk=jik=ikj',
    i=sph,
    filter_ir_mid=o3.Irrep.iterator(lmax),
    filter_ir_out=['0e', '0o'],
)

def powerspectrum_lambda(x):
    """Evaluate ``powerspectrum`` with ``x`` in both input slots."""
    return powerspectrum(x, x)


def bispectrum_lambda(x):
    """Evaluate ``bispectrum`` with ``x`` in all three input slots."""
    return bispectrum(x, x, x)

In [104]:
# Alternative construction of the target signal, kept for reference:
# true_sig = sph.sum_of_diracs(true_geometry, torch.norm(true_geometry, dim=1))
true_sig = sph.with_peaks_at(true_geometry)

# Target invariants that the optimisation below tries to reproduce.
true_spectrum = bispectrum_lambda(true_sig)

In [87]:
# Display the reference geometry.
true_geometry

tensor([[ 1.0000,  0.0000,  0.0000],
        [-0.5000,  0.8660,  0.0000],
        [-0.5000, -0.8660,  0.0000]])

In [71]:
# Plot the target signal.
visualize(signal=true_sig)


In [123]:
torch.manual_seed(0)
# Starting point for the inversion: three points close to (1, 0, 0).
guess = torch.tensor(
    [
        [1, 0.05, 0],
        [1, 0, 0],
        [1, -0.01, 0],
    ],
    dtype=torch.float64,
)
guess

tensor([[ 1.0000,  0.0010,  0.0000],
        [ 1.0000,  0.0000,  0.0000],
        [ 1.0000, -0.0100,  0.0000]])

In [127]:
# Plot the signal built from the current guess, constructed the same way
# invert_bispectrum builds it (peaks at the guessed points, valued by their
# norms). Placing the peaks at true_geometry instead would just redraw the
# target signal.
visualize(sph.with_peaks_at(guess, values=guess.norm(2, -1)).detach())

In [121]:
def invert_bispectrum(target_bispectrum, guess=None, max_iter=2000, lr=1e-1, log_every=100):
    """Fit point positions whose signal reproduces ``target_bispectrum``.

    Starting from ``guess``, runs Adam on an L1 loss between
    ``target_bispectrum`` and ``bispectrum_lambda`` of the signal
    ``sph.with_peaks_at(points, values=points.norm(2, -1))``.

    Parameters
    ----------
    target_bispectrum : torch.Tensor
        Bispectrum to match; compared against ``bispectrum_lambda(signal)``.
    guess : torch.Tensor, optional
        Initial point positions. A random ``(3, 3)`` tensor is drawn when
        omitted. The tensor passed in is NOT modified: the optimisation runs
        on a detached copy.
    max_iter : int
        Number of optimiser steps. ``0`` returns the signal of the initial
        guess.
    lr : float
        Adam learning rate.
    log_every : int
        Print progress every ``log_every`` steps; values <= 0 disable the
        progress output.

    Returns
    -------
    torch.Tensor
        Signal built from the final point positions (after the last optimiser
        step), detached from the autograd graph.
    """
    if guess is None:
        guess = torch.randn(3, 3)
    # Optimise a private leaf copy so the caller's tensor keeps its values and
    # its requires_grad flag.
    points = guess.detach().clone().requires_grad_(True)

    opt = torch.optim.Adam([points], lr=lr)
    # loss_fn = torch.nn.MSELoss()
    loss_fn = torch.nn.L1Loss()

    def build_signal(pts):
        # cur_sig = sph.sum_of_diracs(pts, values=pts.norm(2, -1))
        return sph.with_peaks_at(pts, values=pts.norm(2, -1))

    for i in range(max_iter):
        cur_sig = build_signal(points)
        cur_bis = bispectrum_lambda(cur_sig)
        loss = loss_fn(target_bispectrum, cur_bis)
        if log_every > 0 and i % log_every == 0:
            print(f"Step {i}, Loss: {loss}, sig: {cur_sig}")
        opt.zero_grad()
        loss.backward()
        opt.step()

    # Re-evaluate after the last step so the reported loss and the returned
    # signal describe the final parameters; this also covers max_iter == 0.
    with torch.no_grad():
        final_sig = build_signal(points)
        final_loss = loss_fn(target_bispectrum, bispectrum_lambda(final_sig))
    print(final_loss)
    return final_sig


In [122]:
# Run the inversion from the hand-picked starting guess.
pred_sig = invert_bispectrum(target_bispectrum=true_spectrum, guess=guess)

Step 0, Loss: 0.058731854292557985, sig: tensor([ 1.4180e-01,  2.4560e-01,  3.6186e-14,  0.0000e+00,  0.0000e+00,
         8.0914e-14, -1.5853e-01,  0.0000e+00, -2.7459e-01, -2.9659e-01,
         0.0000e+00, -2.2974e-01, -8.2912e-14,  0.0000e+00, -1.0704e-13,
         0.0000e+00,  0.0000e+00, -1.3110e-13,  0.0000e+00, -1.4865e-13,
         1.5952e-01,  0.0000e+00,  2.3780e-01,  0.0000e+00,  3.1458e-01],
       grad_fn=<SqueezeBackward4>)
Step 100, Loss: 0.12055516137859981, sig: tensor([ 0.4700,  0.7190, -0.1252,  0.0000,  0.0000, -0.1567, -0.2100,  0.0000,
        -0.7280, -0.7257,  0.0000, -0.1202, -0.0566,  0.0000,  0.1043,  0.0000,
         0.0000,  0.0483,  0.0000, -0.2561,  0.0307,  0.0000,  0.0273,  0.0000,
         0.7200], grad_fn=<SqueezeBackward4>)
Step 200, Loss: 0.11846001798123343, sig: tensor([ 0.4679,  0.7148, -0.1231,  0.0000,  0.0000, -0.1511, -0.2057,  0.0000,
        -0.7228, -0.7197,  0.0000, -0.1131, -0.0639,  0.0000,  0.0962,  0.0000,
         0.0000,  0.0379,  0

In [109]:
# Plot the recovered signal (detached from the autograd graph for plotting).
visualize(signal=pred_sig.detach())