In [11]:
%load_ext autoreload
%autoreload 2

from jax import config
config.update("jax_enable_x64", True)

from data_generation import *
from moments import *
from subspace_recovery import *
from projection import *
from loss_functions import *
from optim import *
from tracing_recovery import *

import plotly.graph_objects as go

seed = 0
key, subkey = random.split(random.PRNGKey(seed))

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# Generating the data

We want to generate a curve with $M$ segments in $\mathbb{R}^d$ for $d>M$. We generate our curve by prescribing $M$ segment lengths `seg_lens` ahead of time, starting at the origin, and iteratively sampling a direction on the unit sphere to walk in for a distance of `seg_lens[i]`. The segment lengths are chosen uniformly from $[1, 2]$. We center the curve to have mean zero.

In [12]:
d = 12 # ambient dimension
M = 8 # number of segments
sigma2 = 0.25 # noise level

seg_lens = random.uniform(subkey, shape = (M,), minval = 0.8, maxval = 1.2) 

C = gen_curve_sphere_sampling(seed, M, d, seg_lens)
C = C - compute_m1(C)

chunk_size = int(1e5)

The "real-world" use case for this problem is to estimate the underlying curve from a high noise point cloud around the data by estimating the moments of the underlying curve from the moments of the point cloud. The moments of the underlying curve can be arbitrarily well-approximated given a sufficiently large number of data points, but the moment approximation is memory-intensive. We can either assume we have access to the true moments of the underlying curves by setting `use_true_moments` to `True`, or estimate the moments by setting it equal to false.

In [13]:
use_true_moments = False
N = int(1e7)
N_mini = int(1e4)

if use_true_moments:
  m1 = compute_m1(C)
  m2 = compute_m2(C)
  m3 = compute_m3(C)

else:
  m1_raw, m2_raw, m3_raw, cloud_mini = gen_cloud_moments(subkey, C, sigma2, N, num_chunks = N // chunk_size, N_mini = N_mini)
  m1 = m1_raw
  m2 = m2_raw - sigma2 * jnp.eye(d)
  m3 = m3_raw - 3 * sigma2 * compute_sym_part(jnp.einsum("i, jk -> ijk", m1, jnp.eye(d)))

generating cloud in chunks and computing online moments:   0%|          | 0/100 [00:00<?, ?it/s]

                                                                                                          

# Initial guess with tensor power method

We apply the tensor power method to the third moment (whitening with the second moment) to approximate the subspaces in which the control points of $C$ lie. The tensor power method is only able to recover the subspaces up to permutation, so we need to sort them with our chaining algorithm (Algorithm 1 in the paper). A status of zero indicates that the sorting of subspaces was successful.

In [14]:
subspaces, status = recover_C_subspaces(m3, m2, M, seed = 0)
print(status)
if status != 0:
  print(f"subspace recovery failed with status {status}! results may be suboptimal. see docstring for recover_C_subspaces() for details.\n")

0


To get an initial approximation, we now find the lengths of the control points along the estimated subspaces that minimize the moment errors.

In [15]:
Lhat_init = jnp.ones(shape = (M+1,))
Chat_subspaces = estimate_C_from_subspaces(subspaces, Lhat_init, m1, m2, m3)
print(f"curve loss: {compute_curve_loss(C, Chat_subspaces):.5f}")

                                                                                

curve loss: 0.14358




# Finetuning with projection and direct moment matching.

Before fine-tuning our estimation, we now project the curve to a lower dimensional subspace spanned by top $M$ eigenvectors of the second moment matrix (either of the true curve or the data).

In [16]:
basis = get_basis(m2, M)
Chat_projected = project_to_subspace(Chat_subspaces.T, basis).T

if use_true_moments:
  C_projected = project_to_subspace(C.T, basis).T
  m1_projected = compute_m1(C_projected)
  m2_projected = compute_m2(C_projected)
  m3_projected = compute_m3(C_projected)

else:
  m1_projected, m2_projected, m3_projected = project_moments(m1, m2, m3, M)

Within this subspace, we optimize the curve points themselves (rather than their lengths) to minimize moment losses. If the ground truth curve has uniform lengths, we use the relaxed moments; otherwise, we use the true moments.

In [17]:
phat = compute_seg_lens(Chat_projected)
phat = phat / jnp.sum(phat)

Cp_dict = {"Chat" : Chat_projected, "phat" : phat}

Chat_projected, phat = finetune_C_with_moments(Chat_projected, phat, m3_projected, nit = 2500)
Chat = deproject_from_subspace(Chat_projected.T, basis).T

print(f"curve loss: {compute_curve_loss(C, Chat)}")
print(f"third moment loss: {m3_loss(Chat, m3)}")

                                                                                               

curve loss: 0.03693939711004181
third moment loss: 0.0036396502135324377




# Compute baselines

Match either the third moment only or all three moments from a random initialization in the $M$-dimensional subspace that $C$ lives in, taking the best of 10 random initializations.

In [18]:
best_Chat_m123, best_Chat_m3 = estimate_C_baseline(seed, m1_projected, m2_projected, m3_projected, M, num_trials = 10)
Chat_baseline_m123 = deproject_from_subspace(best_Chat_m123.T, basis).T
Chat_baseline_m3 = deproject_from_subspace(best_Chat_m3.T, basis).T

trialing multiple random initializations for baseline:   0%|          | 0/10 [00:00<?, ?it/s]

                                                                                                      

In [19]:
curve_loss_m123_baseline = compute_curve_loss(C, Chat_baseline_m123)
curve_loss_m3_baseline = compute_curve_loss(C, Chat_baseline_m3)
m3_loss_m123_baseline = m3_loss(Chat_baseline_m123, compute_m3(C))
m3_loss_m3_baseline = m3_loss(Chat_baseline_m3, compute_m3(C))
print(f"curve loss from triple moment baseline: {curve_loss_m123_baseline:.5f}")
print(f"curve loss from third moment baseline: {curve_loss_m3_baseline:.5f}")
print(f"third moment loss from triple moment baseline: {m3_loss_m123_baseline:.5f}")
print(f"third moment loss from third moment baseline: {m3_loss_m3_baseline:.5f}")

curve loss from triple moment baseline: 1.26564
curve loss from third moment baseline: 1.75696
third moment loss from triple moment baseline: 0.00021
third moment loss from third moment baseline: 0.00155


In [20]:
idx0, idx1, idx2 = 0, 1, 2 # choose three random dimensions to project the data down to

fig = go.Figure(layout=dict(width=900, height=600))

fig.add_trace(go.Scatter3d(x = C[:, idx0], y = C[:, idx1], z = C[:, idx2], name = "C true <br>", mode = "lines", line = dict(color = "blue", width = 5)))
fig.add_trace(go.Scatter3d(x = Chat_subspaces[:, idx0], y = Chat_subspaces[:, idx1], z = Chat_subspaces[:, idx2], name = "C predicted <br> after phase 1 <br>", mode = "lines", line = dict(color = "red", dash = "dashdot", width = 5)))
fig.add_trace(go.Scatter3d(x = Chat[:, idx0], y = Chat[:, idx1], z = Chat[:, idx2], name = "C predicted <br> after phase 2 <br>", mode = "lines", line = dict(color = "red", width = 5)))
fig.add_trace(go.Scatter3d(x = Chat_baseline_m3[:, idx0], y = Chat_baseline_m3[:, idx1], z = Chat_baseline_m3[:, idx2], name = "C baseline with third <br> moment matching only<br>", mode = "lines", line = dict(color = "orange", dash = "dot", width = 3)))
fig.add_trace(go.Scatter3d(x = Chat_baseline_m123[:, idx0], y = Chat_baseline_m123[:, idx1], z = Chat_baseline_m123[:, idx2], name = "C baseline with all <br> three moments<br>", mode = "lines", line = dict(color = "orange", width = 3)))
if not use_true_moments:
  fig.add_trace(go.Scatter3d(x = cloud_mini[:, idx0], y = cloud_mini[:, idx1], z = cloud_mini[:, idx2], hoverinfo="skip", name = "point cloud", mode = "markers", marker = dict(color = "blue", size = 1, opacity = 0.1)))
fig.show()