## Recovery of spherical harmonic signal generated from points using power spectrum and bispectrum

In [1]:
import os
import sys
from typing import Optional, Union, Callable, Tuple

import chex
import e3nn_jax as e3nn
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import optax
import plotly
import plotly.graph_objects as go
import plotly.subplots as sp

# Add src directory to path
project_root = os.path.abspath(os.path.join(os.getcwd(), '..'))
src_dir = os.path.join(project_root, 'src')
sys.path.append(src_dir)

from utils.plotters import visualize_geometry, colorplot, visualize_signal
from utils.geometries import trigonal_plane, tetrahedron, octahedron
import spectra

# Enable 64-bit precision
jax.config.update("jax_enable_x64", True)

In [2]:
boron = np.array([
    [ 0.00000000,  0.00000000,  0.00000000],
    [ 1.39454902,  0.00000000,  0.57793632],
    [-1.23980586,  0.63847348,  0.57793632],
    [-0.29503787, -1.36298188,  0.57793632],
    [-0.91652028,  1.05107447, -0.57793632],
    [ 0.33360178, -1.35405938, -0.57793632],
    [ 1.22118593,  0.67340320, -0.57793632],
    [ 0.00000000,  0.00000000,  1.76852686],
])

visualize_geometry(boron, lmax=6)

In [13]:
def trispectrum(x: e3nn.IrrepsArray) -> e3nn.IrrepsArray:
    """
    Computes the trispectrum of an array of irreducible representations.

    Parameters:
        x (e3nn.IrrepsArray): Input array of irreps.

    Returns:
        e3nn.IrrepsArray: The trispectrum of the input array.
    """
    rtp = e3nn.reduced_symmetric_tensor_product_basis(x.irreps, 4, keep_ir=['0o'])
    return e3nn.IrrepsArray(rtp.irreps, jnp.einsum("i,j,k,l,ijklz->z", x.array, x.array, x.array, x.array, rtp.array)).array

lmax = 4
geometry = boron
power_spectrum = spectra.Spectra(lmax=lmax, order=1)
true_signal = power_spectrum.compute_geometry_sh_signal(geometry)
tri = trispectrum(true_signal)

In [12]:
print("0o and 0e trispectrum: ", tri.shape)

0o and 0e trispectrum:  (65,)


In [14]:
print("0o trispectrum: ", tri.shape)

0o trispectrum:  (14,)


In [3]:
lmax = 6
power_spectrum = spectra.Spectra(lmax=lmax, order=1)
bispectrum = spectra.Spectra(lmax=lmax, order=2)

geometry = boron
true_signal = power_spectrum.compute_geometry_sh_signal(geometry)
true_power_spectrum = power_spectrum.compute_geometry_spectra(geometry)
true_bispectrum = bispectrum.compute_geometry_spectra(geometry)

In [4]:
def invert_to_signal(true_spectrum, spectrum_type, seed=0):
    """
    Inverts either the power spectrum or bispectrum to a signal.
    Returns the final signal and a list of intermediate signals when loss decreases significantly.
    """
    def loss(
        params: optax.Params, true_spectrum: chex.Array
    ) -> chex.Array:
        """Computes the loss corresponding to the current parameters."""
        # Compute signal.
        pred_sig = params["signal"]

        if spectrum_type == "power":
            pred_spectrum = power_spectrum.compute_sh_signal_spectra(pred_sig)
        elif spectrum_type == "bispectrum":
            pred_spectrum = bispectrum.compute_sh_signal_spectra(pred_sig)

        # Compute loss.
        loss_value = (
            optax.l2_loss(true_spectrum, pred_spectrum).mean()
        )
        return loss_value

    def fit(
        params: optax.Params,
        optimizer: optax.GradientTransformation,
        true_spectrum: chex.Array,
        max_iter: int = 10000,
    ) -> tuple:
        opt_state = optimizer.init(params)
        
        # Store intermediate signals
        intermediate_signals = []
        # Add initial signal (0th iteration)
        intermediate_signals.append(params["signal"])
        
        # Track minimum loss
        min_loss = float('inf')

        @jax.jit
        def step(params, opt_state, true_spectrum):
            loss_value, grads = jax.value_and_grad(loss)(
                params, true_spectrum
            )
            updates, opt_state = optimizer.update(grads, opt_state, params)
            params = optax.apply_updates(params, updates)
            return params, opt_state, loss_value

        for iter in range(max_iter):
            params, opt_state, loss_value = step(
                params, opt_state, true_spectrum
            )
            
            # Save signal when loss decreases by a factor of x
            if loss_value < min_loss / 2:
                intermediate_signals.append(params["signal"])
                min_loss = loss_value
                
            if iter % 100 == 0:
                print(f"step {iter}, loss: {loss_value}")

        # Return final parameters and all the intermediate signals
        return params, intermediate_signals
    
    rng = jax.random.PRNGKey(seed)
    random_signal = e3nn.IrrepsArray("1x0e+1x1o+1x2e+1x3o+1x4e", jax.random.normal(rng, (25,)))
    init_params = {"signal": random_signal}
    optimizer = optax.adam(learning_rate=1e-2)
    
    # Modified to get both final params and intermediate signals
    final_params, intermediate_signals = fit(init_params, optimizer, true_spectrum)
    
    # Return both the final signal and the list of intermediate signals
    return final_params["signal"], intermediate_signals

In [5]:
def signal_from_params(params: optax.Params) -> e3nn.IrrepsArray:
    """Creates signal from current parameters."""
    other_vec = params["points"]
    diracs = e3nn.s2_dirac(
        other_vec, lmax=lmax, p_val=1, p_arg=-1
    )
    weights = jnp.linalg.norm(other_vec, axis=-1)
    return e3nn.sum(diracs * weights[:, None])


def invert_to_points(true_spectrum, spectrum_type, seed=0):
    """
    Inverts either the power spectrum or bispectrum to a signal.
    Returns the final signal and a list of intermediate signals when loss decreases significantly.
    """
    def loss(
        params: optax.Params, true_spectrum: chex.Array
    ) -> chex.Array:
        """Computes the loss corresponding to the current parameters."""
        # Compute signal.
        pred_sig = signal_from_params(params)

        if spectrum_type == "power":
            pred_spectrum = power_spectrum.compute_sh_signal_spectra(pred_sig)
        elif spectrum_type == "bispectrum":
            pred_spectrum = bispectrum.compute_sh_signal_spectra(pred_sig)

        # Compute loss.
        loss_value = (
            optax.l2_loss(true_spectrum, pred_spectrum).mean()
        )
        return loss_value

    def fit(
        params: optax.Params,
        optimizer: optax.GradientTransformation,
        true_spectrum: chex.Array,
        max_iter: int = 10000,
    ):
        opt_state = optimizer.init(params)
        
        # Store intermediate signals
        intermediate_signals = []
        # Add initial signal (0th iteration)
        if spectrum_type == "power":
            intermediate_signals.append(power_spectrum.compute_geometry_sh_signal(params["points"]))
        elif spectrum_type == "bispectrum":
            intermediate_signals.append(bispectrum.compute_geometry_sh_signal(params["points"]))
        
        # Track minimum loss
        min_loss = float('inf')

        @jax.jit
        def step(params, opt_state, true_spectrum):
            loss_value, grads = jax.value_and_grad(loss)(
                params, true_spectrum
            )
            updates, opt_state = optimizer.update(grads, opt_state, params)
            params = optax.apply_updates(params, updates)
            return params, opt_state, loss_value

        for iter in range(max_iter):
            params, opt_state, loss_value = step(
                params, opt_state, true_spectrum
            )
            
            # Save signal when loss decreases by a factor of x
            if loss_value < min_loss / 3:
                if spectrum_type == "power":
                    intermediate_signals.append(power_spectrum.compute_geometry_sh_signal(params["points"]))
                elif spectrum_type == "bispectrum":
                    intermediate_signals.append(bispectrum.compute_geometry_sh_signal(params["points"]))
                min_loss = loss_value
                
            if iter % 100 == 0:
                print(f"step {iter}, loss: {loss_value}")

        # Return final parameters and all the intermediate signals
        return params, intermediate_signals
    
    rng = jax.random.PRNGKey(seed)
    points = jax.random.normal(rng, (8, 3))
    init_params = {"points": points}
    optimizer = optax.adam(learning_rate=1e-2)
    
    # Modified to get both final params and intermediate signals
    final_params, intermediate_signals = fit(init_params, optimizer, true_spectrum)
    
    # Return both the final points and the list of intermediate signals
    return final_params["points"], intermediate_signals

In [6]:
# # Power spectrum to SHs
# _, intermediate_signals = invert_to_signal(true_power_spectrum, "power")
# intermediate_spectra = [power_spectrum.compute_sh_signal_spectra(sig) for sig in intermediate_signals]

# # Bispectrum to SHs
# _, intermediate_signals = invert_to_signal(true_bispectrum, "bispectrum")
# intermediate_spectra = [bispectrum.compute_sh_signal_spectra(sig) for sig in intermediate_signals]

# # Power spectrum to points
# _, intermediate_signals = invert_to_points(true_power_spectrum, "power")
# intermediate_spectra = [power_spectrum.compute_sh_signal_spectra(sig) for sig in intermediate_signals]

# Bispectrum to points
_, intermediate_signals = invert_to_points(true_bispectrum, "bispectrum", seed=1)
intermediate_spectra = [bispectrum.compute_sh_signal_spectra(sig) for sig in intermediate_signals]

step 0, loss: 0.07154971499239167
step 100, loss: 0.001614779514435465
step 200, loss: 0.0007541017568652969
step 300, loss: 0.0003409785062249554
step 400, loss: 0.0002283182324989398
step 500, loss: 0.00018289583291971454
step 600, loss: 0.00015594821017581676
step 700, loss: 0.00013816465440284387
step 800, loss: 0.00012534907922371593
step 900, loss: 0.00011539396825450149
step 1000, loss: 0.00010716447042627728
step 1100, loss: 9.998323664468475e-05
step 1200, loss: 9.340327277910043e-05
step 1300, loss: 8.710776804810716e-05
step 1400, loss: 8.08619861524624e-05
step 1500, loss: 7.448842928826067e-05
step 1600, loss: 6.785737172069849e-05
step 1700, loss: 6.089646754292268e-05
step 1800, loss: 5.3635766190899264e-05
step 1900, loss: 4.630723143402689e-05
step 2000, loss: 3.941127746063213e-05
step 2100, loss: 3.3454073913528315e-05
step 2200, loss: 2.8516742925809314e-05
step 2300, loss: 2.4286259584053594e-05
step 2400, loss: 2.0393721978463036e-05
step 2500, loss: 1.66782397334

In [7]:
import plotly.graph_objects as go
import plotly.subplots as sp
import matplotlib.pyplot as plt
import numpy as np
import jax.numpy as jnp

# Modified colorplot function to return array data instead of matplotlib image
def colorplot_data(arr: jnp.ndarray):
    """Return data for spectra plotting compatible with plotly"""
    # Pad array with zeros to make length multiple of 5
    pad_length = (5 - (arr.size % 5)) % 5
    padded_arr = jnp.pad(arr, (0, pad_length))
    
    # Reshape into array with 5 columns
    num_rows = padded_arr.size // 5
    reshaped_arr = padded_arr.reshape(num_rows, 5)
    
    # Compute vmax for color scaling
    vmax = jnp.maximum(jnp.abs(jnp.min(reshaped_arr)), jnp.max(reshaped_arr))
    
    # Convert JAX arrays to standard Python/NumPy types for plotly compatibility
    return np.array(reshaped_arr), float(vmax)

# Create 2x2 subplot figure
# TRANSPOSED: Now rows are signals/spectra, columns are true/predicted
fig = sp.make_subplots(
    rows=2, 
    cols=2,
    subplot_titles=['True Signal', 'Predicted Signal', 'True Spectra', 'Predicted Spectra'],
    specs=[[{'type': 'scene'}, {'type': 'scene'}], 
           [{'type': 'heatmap'}, {'type': 'heatmap'}]],
    vertical_spacing=0.05,
    horizontal_spacing=0.05,
    shared_yaxes=True  # This can help remove extra space on the right
)

# Add true signal visualization - now in row 1, col 1  
true_sig = power_spectrum.compute_geometry_sh_signal(trigonal_plane)
true_fig = visualize_signal(true_sig)
for trace in true_fig.data:
    fig.add_trace(trace, row=1, col=1)

# Add predicted signal visualization - now in row 1, col 2
pred_fig = visualize_signal(pred_sig)
for trace in pred_fig.data:
    fig.add_trace(trace, row=1, col=2)

# Add true spectra visualization using plotly heatmap - now in row 2, col 1
if hasattr(true_power_spectrum, 'numpy'):
    true_power_spectrum_np = np.array(true_power_spectrum)
else:
    true_power_spectrum_np = true_power_spectrum
true_spectra_arr, true_vmax = colorplot_data(true_power_spectrum_np)
fig.add_trace(
    go.Heatmap(
        z=true_spectra_arr,
        colorscale='PuOr',
        zmid=0,
        zmin=-true_vmax,
        zmax=true_vmax,
        showscale=False,  # Explicitly disable colorbar
        colorbar=dict(len=0, thickness=0)  # Set colorbar length and thickness to 0
    ),
    row=2, col=1
)

# Add predicted spectra visualization using plotly heatmap - now in row 2, col 2
pred_spectra_data = power_spectrum.compute_sh_signal_spectra(pred_sig)
if hasattr(pred_spectra_data, 'numpy'):
    pred_spectra_data = np.array(pred_spectra_data)
pred_spectra_arr, pred_vmax = colorplot_data(pred_spectra_data)
fig.add_trace(
    go.Heatmap(
        z=pred_spectra_arr,
        colorscale='PuOr',
        zmid=0,
        zmin=-pred_vmax,
        zmax=pred_vmax,
        showscale=False,  # Explicitly disable colorbar
        colorbar=dict(len=0, thickness=0)  # Set colorbar length and thickness to 0
    ),
    row=2, col=2
)

# Update layout
fig.update_layout(
    height=800,
    width=800,
    showlegend=False,
    coloraxis_showscale=False,  # Globally disable colorscales
    margin=dict(r=10, l=10, t=50, b=10),  # Reduce right margin to remove space for colorbar
    plot_bgcolor='rgba(255,255,255,255)',
    paper_bgcolor='rgba(255,255,255,255)'
)

# Completely remove any possible colorbar
for annotation in fig['layout']['annotations']:
    annotation['xanchor'] = 'center'

# Force update all traces to ensure no colorbar
for i in range(len(fig.data)):
    if hasattr(fig.data[i], 'showscale'):
        fig.data[i].showscale = False
        if hasattr(fig.data[i], 'colorbar'):
            fig.data[i].colorbar = dict(thickness=0, len=0, showticklabels=False)


# Update 3D scene properties for both subplots in first row
for col in [1, 2]:
    fig.update_scenes(
        dict(
            xaxis=dict(
                title='',
                showticklabels=False,
                showgrid=False,
                zeroline=False,
                backgroundcolor='rgba(255,255,255,255)',
                range=[-2.5, 2.5]
            ),
            yaxis=dict(
                title='',
                showticklabels=False,
                showgrid=False,
                zeroline=False,
                backgroundcolor='rgba(255,255,255,255)',
                range=[-2.5, 2.5]
            ),
            zaxis=dict(
                title='',
                showticklabels=False,
                showgrid=False,
                zeroline=False,
                backgroundcolor='rgba(255,255,255,255)',
                range=[-2.5, 2.5]
            ),
            bgcolor='rgba(255,255,255,255)',
            aspectmode='cube',
            camera=dict(
                eye=dict(x=0, y=0, z=0.5)
            )
        ),
        row=1,
        col=col
    )

# Update heatmap properties for the second row subplots
for col in [1, 2]:
    fig.update_xaxes(showticklabels=False, showgrid=False, zeroline=False, row=2, col=col)
    fig.update_yaxes(showticklabels=False, showgrid=False, zeroline=False, row=2, col=col)

fig.show()

NameError: name 'pred_sig' is not defined

In [None]:

import plotly.graph_objects as go
import plotly.subplots as sp
import matplotlib.pyplot as plt
import numpy as np
import jax.numpy as jnp

# Modified colorplot function to return array data instead of matplotlib image
def colorplot_data(arr: jnp.ndarray):
    """Return data for spectra plotting compatible with plotly"""
    # Pad array with zeros to make length multiple of 5
    pad_length = (5 - (arr.size % 5)) % 5
    padded_arr = jnp.pad(arr, (0, pad_length))
    
    # Reshape into array with 5 columns
    num_rows = padded_arr.size // 5
    reshaped_arr = padded_arr.reshape(num_rows, 5)
    
    # Compute vmax for color scaling
    vmax = jnp.maximum(jnp.abs(jnp.min(reshaped_arr)), jnp.max(reshaped_arr))
    
    # Convert JAX arrays to standard Python/NumPy types for plotly compatibility
    return np.array(reshaped_arr), float(vmax)


# Alternative approach for animations with Surface plots
def create_animation_v2(pred_signals, pred_spectra, true_sig, true_spectrum, spectra_type):
    """
    Alternative implementation for creating animations with Surface plots.
    This method recreates the entire figure for each frame rather than
    trying to update individual traces.
    
    Parameters are the same as create_animation().
    """
    # Pre-compute spectra data
    # True spectra
    if hasattr(true_spectrum, 'numpy'):
        true_spectrum_np = np.array(true_spectrum)
    else:
        true_spectrum_np = true_spectrum
    true_spectra_arr, true_vmax = colorplot_data(true_spectrum_np)
    
    # Find global max values for consistent scaling across frames
    pred_spectra_data = []
    
    for ps in pred_spectra:
        if hasattr(ps, 'numpy'):
            ps_np = np.array(ps)
        else:
            ps_np = ps
        spectra_arr, _ = colorplot_data(ps_np)
        pred_spectra_data.append(spectra_arr)
    
    # Create the initial figure
    fig = sp.make_subplots(
        rows=2, 
        cols=2,
        subplot_titles=['True Signal', 'Predicted Signal', f'True {spectra_type}', f'Predicted {spectra_type}'],
        specs=[[{'type': 'scene'}, {'type': 'scene'}], 
               [{'type': 'heatmap'}, {'type': 'heatmap'}]],
        vertical_spacing=0.05,
        horizontal_spacing=0.05,
        shared_yaxes=True
    )
    
    # True signal visualization (never changes)
    true_fig = visualize_signal(true_sig)
    for trace in true_fig.data:
        fig.add_trace(trace, row=1, col=1)
    
    # Add initial predicted signal visualization
    pred_fig_0 = visualize_signal(pred_signals[0])
    for trace in pred_fig_0.data:
        fig.add_trace(trace, row=1, col=2)
    
    # Add true spectra heatmap (never changes)
    fig.add_trace(
        go.Heatmap(
            z=true_spectra_arr,
            colorscale='PuOr',
            zmid=0,
            zmin=-true_vmax,
            zmax=true_vmax,
            showscale=False,
            colorbar=dict(len=0, thickness=0)
        ),
        row=2, col=1
    )
    
    # Add initial predicted spectra heatmap
    fig.add_trace(
        go.Heatmap(
            z=pred_spectra_data[0],
            colorscale='PuOr',
            zmid=0,
            zmin=-true_vmax,
            zmax=true_vmax,
            showscale=False,
            colorbar=dict(len=0, thickness=0)
        ),
        row=2, col=2
    )
    
    # Create frames
    frames = []
    
    for i in range(len(pred_signals)):
        # Create a new figure for this frame
        frame_fig = sp.make_subplots(
            rows=2, 
            cols=2,
            subplot_titles=['True Signal', 'Predicted Signal', f'True {spectra_type}', f'Predicted {spectra_type}'],
            specs=[[{'type': 'scene'}, {'type': 'scene'}], 
                   [{'type': 'heatmap'}, {'type': 'heatmap'}]],
            vertical_spacing=0.05,
            horizontal_spacing=0.05,
            shared_yaxes=True
        )
        
        # Add true signal (static)
        for trace in true_fig.data:
            frame_fig.add_trace(trace, row=1, col=1)
        
        # Add predicted signal for this frame
        pred_fig_i = visualize_signal(pred_signals[i])
        for trace in pred_fig_i.data:
            frame_fig.add_trace(trace, row=1, col=2)
        
        # Add true spectra (static)
        frame_fig.add_trace(
            go.Heatmap(
                z=true_spectra_arr,
                colorscale='PuOr',
                zmid=0,
                zmin=-true_vmax,
                zmax=true_vmax,
                showscale=False,
                colorbar=dict(len=0, thickness=0)
            ),
            row=2, col=1
        )
        
        # Add predicted spectra for this frame
        frame_fig.add_trace(
            go.Heatmap(
                z=pred_spectra_data[i],
                colorscale='PuOr',
                zmid=0,
                zmin=-true_vmax,
                zmax=true_vmax,
                showscale=False,
                colorbar=dict(len=0, thickness=0)
            ),
            row=2, col=2
        )
        
        # Update 3D scene properties for both subplots in first row
        for col in [1, 2]:
            frame_fig.update_scenes(
                dict(
                    xaxis=dict(
                        title='',
                        showticklabels=False,
                        showgrid=False,
                        zeroline=False,
                        backgroundcolor='rgba(255,255,255,255)',
                        range=[-2.5, 2.5]
                    ),
                    yaxis=dict(
                        title='',
                        showticklabels=False,
                        showgrid=False,
                        zeroline=False,
                        backgroundcolor='rgba(255,255,255,255)',
                        range=[-2.5, 2.5]
                    ),
                    zaxis=dict(
                        title='',
                        showticklabels=False,
                        showgrid=False,
                        zeroline=False,
                        backgroundcolor='rgba(255,255,255,255)',
                        range=[-2.5, 2.5]
                    ),
                    bgcolor='rgba(255,255,255,255)',
                    aspectmode='cube',
                    camera=dict(
                        eye=dict(x=0, y=0, z=0.5)
                    )
                ),
                row=1,
                col=col
            )
        
        # Update heatmap properties for the second row subplots
        for col in [1, 2]:
            frame_fig.update_xaxes(showticklabels=False, showgrid=False, zeroline=False, row=2, col=col)
            frame_fig.update_yaxes(showticklabels=False, showgrid=False, zeroline=False, row=2, col=col)
        
        # Create frame
        frame = go.Frame(
            name=f"frame_{i}",
            data=frame_fig.data
        )
        frames.append(frame)
    
    # Add frames to the figure
    fig.frames = frames
    
    # Update layout
    fig.update_layout(
        updatemenus=[
            {
                "type": "buttons",
                "buttons": [
                    {
                        "label": "Play",
                        "method": "animate",
                        "args": [
                            None, 
                            {
                                "frame": {"duration": 100, "redraw": True},
                                "fromcurrent": True,
                                "transition": {"duration": 0}
                            }
                        ]
                    },
                    {
                        "label": "Pause",
                        "method": "animate",
                        "args": [
                            [None], 
                            {
                                "frame": {"duration": 0, "redraw": False},
                                "mode": "immediate",
                                "transition": {"duration": 0}
                            }
                        ]
                    }
                ],
                "direction": "left",
                "pad": {"r": 10, "t": 10},
                "showactive": False,
                "type": "buttons",
                "x": 0.1,
                "y": 0,
                "xanchor": "right",
                "yanchor": "top"
            }
        ],
        sliders=[{
            "active": 0,
            "yanchor": "top",
            "xanchor": "left",
            "currentvalue": {
                "prefix": "Iteration: ",
                "visible": True,
                "xanchor": "right"
            },
            "pad": {"b": 10, "t": 50},
            "len": 0.9,
            "x": 0.1,
            "y": 0,
            "steps": [
                {
                    "args": [
                        [f"frame_{i}"],
                        {
                            "frame": {"duration": 100, "redraw": True},
                            "mode": "immediate",
                            "transition": {"duration": 0}
                        }
                    ],
                    "label": str(i),
                    "method": "animate"
                }
                for i in range(len(frames))
            ]
        }]
    )
    
    # Update layout
    fig.update_layout(
        height=800,
        width=800,
        showlegend=False,
        coloraxis_showscale=False,
        margin=dict(r=10, l=10, t=50, b=80),  # Increased bottom margin for slider
        plot_bgcolor='rgba(255,255,255,255)',
        paper_bgcolor='rgba(255,255,255,255)'
    )
    
    # Update 3D scene properties for both subplots in first row
    for col in [1, 2]:
        fig.update_scenes(
            dict(
                xaxis=dict(
                    title='',
                    showticklabels=False,
                    showgrid=False,
                    zeroline=False,
                    backgroundcolor='rgba(255,255,255,255)',
                    range=[-2.5, 2.5]
                ),
                yaxis=dict(
                    title='',
                    showticklabels=False,
                    showgrid=False,
                    zeroline=False,
                    backgroundcolor='rgba(255,255,255,255)',
                    range=[-2.5, 2.5]
                ),
                zaxis=dict(
                    title='',
                    showticklabels=False,
                    showgrid=False,
                    zeroline=False,
                    backgroundcolor='rgba(255,255,255,255)',
                    range=[-2.5, 2.5]
                ),
                bgcolor='rgba(255,255,255,255)',
                aspectmode='cube',
                camera=dict(
                    eye=dict(x=0, y=0, z=0.5)
                )
            ),
            row=1,
            col=col
        )
    
    # Update heatmap properties for the second row subplots
    for col in [1, 2]:
        fig.update_xaxes(showticklabels=False, showgrid=False, zeroline=False, row=2, col=col)
        fig.update_yaxes(showticklabels=False, showgrid=False, zeroline=False, row=2, col=col)
    
    # Ensure no colorbar
    for i in range(len(fig.data)):
        if hasattr(fig.data[i], 'showscale'):
            fig.data[i].showscale = False
            if hasattr(fig.data[i], 'colorbar'):
                fig.data[i].colorbar = dict(thickness=0, len=0, showticklabels=False)
    
    return fig



# Create the animation
animation_fig = create_animation_v2(
    intermediate_signals,
    intermediate_spectra,
    true_signal,
    # true_power_spectrum,
    true_bispectrum,
    # "Power Spectrum"
    "Bispectrum"
)

# Show the animation
animation_fig.write_html("signal_evolution.html")
animation_fig.show()