In [15]:
import e3nn
from e3nn import io, o3
import numpy as np
import torch
torch.set_default_dtype(torch.float64)
import plotly  # Interactive data visualization
import plotly.graph_objects as go  # Interactive plots

true_geometry = torch.tensor([
    [1, 0, 0],
    [-0.5, np.sqrt(3)/2, 0],
    [-0.5, -np.sqrt(3)/2, 0]
], dtype=torch.double)

def visualize(signal):
    layout = go.Layout(
        scene=dict(
            xaxis=dict(title='', showticklabels=False, showgrid=False, zeroline=False, backgroundcolor='rgba(255,255,255,255)', range=[-2.5, 2.5]),
            yaxis=dict(title='', showticklabels=False, showgrid=False, zeroline=False, backgroundcolor='rgba(255,255,255,255)', range=[-2.5, 2.5]),
            zaxis=dict(title='', showticklabels=False, showgrid=False, zeroline=False, backgroundcolor='rgba(255,255,255,255)', range=[-2.5, 2.5]),
            bgcolor='rgba(255,255,255,255)',
            aspectmode='cube',
            camera=dict(
                eye=dict(x=0.5, y=0.5, z=0.5)
            )
        ),
        plot_bgcolor='rgba(255,255,255,255)',
        paper_bgcolor='rgba(255,255,255,255)',
        margin=dict(l=0, r=0, t=0, b=0)
    )

    fig = go.Figure(data=[go.Surface(**sph.plotly_surface(signal)[0])], layout=layout)
    fig.show()


In [16]:
lmax=4
sph = io.SphericalTensor(lmax, p_val=1, p_arg=-1)

powerspectrum = o3.ReducedTensorProducts(
    'ijk=jik=ikj', i=sph, 
    filter_ir_out=['0e', '0o'], 
    filter_ir_mid=o3.Irrep.iterator(lmax)
)

bispectrum = o3.ReducedTensorProducts(
    'ijk=jik=ikj', i=sph, 
    filter_ir_out=['0e', '0o'], 
    filter_ir_mid=o3.Irrep.iterator(lmax)
)

powerspectrum_lambda = lambda x: powerspectrum(x, x)
bispectrum_lambda = lambda x: bispectrum(x, x, x)

In [21]:
# true_sig = sph.with_peaks_at(true_geometry)
true_sig = sph.sum_of_diracs(true_geometry, torch.tensor(true_geometry.shape[0]*[1.0], dtype=torch.double))
true_spectrum = bispectrum_lambda(true_sig)

In [22]:
visualize(true_sig)


In [23]:
def invert_bispectrum(target_bispectrum, guess=None, max_iter=1000):
    if guess is None:
        guess = torch.randn(12, 3)
    guess.requires_grad = True
    opt = torch.optim.Adam([guess], lr=1e-2 )
    # loss_fn = torch.nn.MSELoss()
    loss_fn = torch.nn.L1Loss()
    for i in range(max_iter):
        cur_sig = sph.sum_of_diracs(guess, values=guess.norm(2, -1))
        # power = powerspectrum(cur_sig)
        cur_bis = bispectrum_lambda(cur_sig)
        loss = loss_fn(target_bispectrum, cur_bis) # + loss_fn(true_power, power)
        if i % 100 == 0:
            print(f"Step {i}, Loss: {loss}")
        opt.zero_grad()
        loss.backward()
        opt.step()
    print(loss)
    return cur_sig


In [24]:
pred_sig = invert_bispectrum(true_spectrum)

Step 0, Loss: 3.1361235839222985
Step 100, Loss: 0.3530335154371478
Step 200, Loss: 0.09530004658243947
Step 300, Loss: 0.042164767553656425
Step 400, Loss: 0.023375591702183476
Step 500, Loss: 0.011454680000188261
Step 600, Loss: 0.00449174418734937
Step 700, Loss: 0.0005248688312575533
Step 800, Loss: 0.0002588273253885365
Step 900, Loss: 9.928323961616947e-05
tensor(0.0002, grad_fn=<MeanBackward0>)


In [25]:
visualize(pred_sig.detach())