In [25]:
import e3nn
from e3nn import io, o3
import numpy as np
import torch
torch.set_default_dtype(torch.float64)
import plotly  # Interactive data visualization
import plotly.graph_objects as go  # Interactive plots

true_geometry = torch.tensor([
    [2, 0, 0],
    [-0.5, np.sqrt(3)/2, 0],
    [-0.5, -np.sqrt(3)/2, 0]
], dtype=torch.double)

def visualize(signal):
    layout = go.Layout(
        scene=dict(
            xaxis=dict(title='', showticklabels=False, showgrid=False, zeroline=False, backgroundcolor='rgba(255,255,255,255)', range=[-2.5, 2.5]),
            yaxis=dict(title='', showticklabels=False, showgrid=False, zeroline=False, backgroundcolor='rgba(255,255,255,255)', range=[-2.5, 2.5]),
            zaxis=dict(title='', showticklabels=False, showgrid=False, zeroline=False, backgroundcolor='rgba(255,255,255,255)', range=[-2.5, 2.5]),
            bgcolor='rgba(255,255,255,255)',
            aspectmode='cube',
            camera=dict(
                eye=dict(x=0.5, y=0.5, z=0.5)
            )
        ),
        plot_bgcolor='rgba(255,255,255,255)',
        paper_bgcolor='rgba(255,255,255,255)',
        margin=dict(l=0, r=0, t=0, b=0)
    )

    fig = go.Figure(data=[go.Surface(**sph.plotly_surface(signal)[0])], layout=layout)
    fig.show()


In [2]:
lmax=4
sph = io.SphericalTensor(lmax, p_val=1, p_arg=-1)

powerspectrum = o3.ReducedTensorProducts(
    'ijk=jik=ikj', i=sph, 
    filter_ir_out=['0e', '0o'], 
    filter_ir_mid=o3.Irrep.iterator(lmax)
)

bispectrum = o3.ReducedTensorProducts(
    'ijk=jik=ikj', i=sph, 
    filter_ir_out=['0e', '0o'], 
    filter_ir_mid=o3.Irrep.iterator(lmax)
)

powerspectrum_lambda = lambda x: powerspectrum(x, x)
bispectrum_lambda = lambda x: bispectrum(x, x, x)



In [27]:
# true_sig = sph.with_peaks_at(true_geometry)
# true_sig = sph.sum_of_diracs(true_geometry, torch.tensor(true_geometry.shape[0]*[1.0], dtype=torch.double))
true_sig = sph.sum_of_diracs(true_geometry, torch.norm(true_geometry, dim=1))
true_spectrum = bispectrum_lambda(true_sig)

In [26]:
true_geometry 

tensor([[ 2.0000,  0.0000,  0.0000],
        [-0.5000,  0.8660,  0.0000],
        [-0.5000, -0.8660,  0.0000]])

In [28]:
visualize(true_sig)


In [37]:
def invert_bispectrum(target_bispectrum, guess=None, max_iter=2000):
    if guess is None:
        guess = torch.randn(100, 3)
    guess.requires_grad = True
    opt = torch.optim.Adam([guess], lr=1e-2 )
    # loss_fn = torch.nn.MSELoss()
    loss_fn = torch.nn.L1Loss()
    for i in range(max_iter):
        cur_sig = sph.sum_of_diracs(guess, values=guess.norm(2, -1))
        # power = powerspectrum(cur_sig)
        cur_bis = bispectrum_lambda(cur_sig)
        loss = loss_fn(target_bispectrum, cur_bis) # + loss_fn(true_power, power)
        if i % 100 == 0:
            print(f"Step {i}, Loss: {loss}")
        opt.zero_grad()
        loss.backward()
        opt.step()
    print(loss)
    return cur_sig


In [38]:
torch.norm(torch.randn(12, 3), dim=1)

tensor([1.5246, 2.3681, 1.3634, 1.4072, 0.8600, 1.5546, 1.7775, 1.6982, 0.9886,
        0.7296, 1.7097, 1.2923])

In [39]:
pred_sig = invert_bispectrum(true_spectrum)

Step 0, Loss: 12065006.966468336
Step 100, Loss: 473723.8677414754
Step 200, Loss: 143194.33725983786
Step 300, Loss: 65804.66876970531
Step 400, Loss: 36736.30597132612
Step 500, Loss: 23128.982236304255
Step 600, Loss: 15869.599339222723
Step 700, Loss: 11519.722910772343
Step 800, Loss: 8681.047810381076
Step 900, Loss: 6762.9395330627485
Step 1000, Loss: 5399.237240159956
Step 1100, Loss: 4407.199788437512
Step 1200, Loss: 3663.1201336221175
Step 1300, Loss: 3091.947372686339
Step 1400, Loss: 2628.8588944545563
Step 1500, Loss: 2251.0023009012284
Step 1600, Loss: 1949.8058710152868
Step 1700, Loss: 1701.4866466552444
Step 1800, Loss: 1491.6115096219262
Step 1900, Loss: 1314.766662867418
tensor(1164.6861, grad_fn=<MseLossBackward0>)


In [40]:
visualize(pred_sig.detach())