# Spherical Signal Alignment

In [3]:
import sys
import os
project_root = os.path.abspath(os.path.join(os.getcwd(), '..'))
src_dir = os.path.join(project_root, 'src')
sys.path.append(src_dir)
import spectra as spectra
from utils.plotters import visualize_signal
from utils.alignment import sample_uniform_quaternion, find_best_random_quaternion, align_signals
from utils.geometries import trigonal_plane
import jax
jax.config.update("jax_enable_x64", True)

lmax = 4

In [4]:
original_signal = spectra.sum_of_diracs(trigonal_plane, lmax=lmax)
visualize_signal(original_signal)

In [5]:
key = jax.random.PRNGKey(0)
ground_truth_quaternion = sample_uniform_quaternion(key)
rotated_signal = original_signal.transform_by_quaternion(ground_truth_quaternion)

In [6]:
initial_quaternion = find_best_random_quaternion(key, original_signal, rotated_signal, num_samples=100)
predicted_quaternion, _ = align_signals(original_signal, rotated_signal, initial_quaternion, learning_rate=0.01)

In [7]:
print(f"Predicted quaternion: {predicted_quaternion}")
print(f"Ground truth quaternion: {ground_truth_quaternion}")

Predicted quaternion: [-0.4427944   0.58459489 -0.3484359  -0.58375882]
Ground truth quaternion: [-0.44279448  0.58459476 -0.34843594 -0.58375886]


In [8]:
visualize_signal(original_signal.transform_by_quaternion(predicted_quaternion))

In [9]:
visualize_signal(rotated_signal)