In [3]:
%load_ext autoreload
%autoreload 2

import jax
jax.config.update("jax_enable_x64", True)

from tracing_recovery import *
from data_generation import *

seed = 1
key, subkey = random.split(random.PRNGKey(seed))

import plotly.graph_objects as go

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [4]:
d = 100       # ambient dimension
n = 1000      # num discretization samples (passed to gen_cloud)
N = 100000    # size of point cloud
M = 10        # presumably the number of curve segments, each of length seg_len
seg_len = 1 / M
# Per-coordinate noise variance. Assuming gen_cloud adds isotropic Gaussian
# noise with this variance (TODO: confirm), the RMS noise norm is
# sqrt(d * sigma2) = seg_len / 2, i.e. half a segment length.
sigma2 = (seg_len / jnp.sqrt(d) / 2) ** 2
print(f"sigma2: {sigma2}")

# Derive every PRNG key from `seed` inside this cell instead of consuming the
# `key`/`subkey` left in the kernel by earlier cells. This makes the cell
# idempotent: each re-run regenerates the same c0 and cloud.
key, subkey = random.split(random.PRNGKey(seed))
c0 = random.normal(subkey, shape=(d,))  # starting point of the curve
key, subkey = random.split(key)

# NOTE(review): the curve generator is seeded with the integer `seed` rather
# than a JAX key, unlike c0 and the cloud -- confirm this is intended.
cs = gen_curve_sphere_sampling(seed, M, c0, seg_len)
cloud = gen_cloud(subkey, cs, sigma2, N, n)

sigma2: 2.5e-05


In [5]:
# Trace the curve through the cloud once from each endpoint of the ground
# truth. The backward trace is reversed along its first axis so its points line
# up with the forward trace's (assumes trace_cloud returns one point per row);
# the two are then averaged element-wise.
cs_pred_fwd, cs_pred_rev = (
    trace_cloud(cloud, cs[0], seg_len, M),
    trace_cloud(cloud, cs[-1], seg_len, M)[::-1],
)
cs_pred_avg = 0.5 * (cs_pred_fwd + cs_pred_rev)

In [6]:
# Axis-aligned projection onto 3 fixed coordinates (10, 11, 12) for 3-D
# plotting. These are hand-picked, not random; edit `idxs` to view others.
idxs = jnp.array([10, 11, 12])
# JAX clamps out-of-bounds gather indices instead of raising (numpy would
# raise), so an index >= the ambient dimension would silently repeat the last
# coordinate in the plot. Fail loudly instead.
assert int(idxs.max()) < cloud.shape[1], "projection index exceeds ambient dimension"

cloud_proj = cloud[:, idxs]
cs_proj = cs[:, idxs]
cs_pred_fwd_proj, cs_pred_rev_proj, cs_pred_avg_proj = (
    curve[:, idxs] for curve in (cs_pred_fwd, cs_pred_rev, cs_pred_avg)
)

In [7]:
def add_curve_trace(figure, pts, name, **line_style):
    """Add a 3-D polyline trace to a plotly figure.

    Args:
        figure: plotly ``go.Figure`` to add the trace to (mutated in place).
        pts: array whose columns 0, 1, 2 are the x, y, z coordinates of
            successive points along the curve.
        name: legend label for the trace.
        **line_style: forwarded as the trace's ``line`` dict
            (e.g. ``color``, ``dash``, ``width``).
    """
    figure.add_trace(go.Scatter3d(
        x=pts[:, 0], y=pts[:, 1], z=pts[:, 2],
        name=name, mode="lines", line=line_style,
    ))


fig = go.Figure()
# Point cloud: tiny, faint markers so the curves drawn on top stay visible;
# hover is disabled for the cloud points.
fig.add_trace(go.Scatter3d(
    x=cloud_proj[:, 0], y=cloud_proj[:, 1], z=cloud_proj[:, 2],
    hoverinfo="skip", name="cloud", mode="markers",
    marker=dict(size=1, opacity=0.05),
))
add_curve_trace(fig, cs_proj, "ground truth cs", color="blue")
add_curve_trace(fig, cs_pred_fwd_proj, "cs_pred_fwd", dash="dash", color="green")
add_curve_trace(fig, cs_pred_rev_proj, "cs_pred_rev", dash="dash", color="purple")
add_curve_trace(fig, cs_pred_avg_proj, "cs_pred_avg", width=5, dash="dash", color="red")

# Equal data-unit scaling on all axes so curve shapes are not distorted; label
# each axis with the ambient coordinate it shows.
fig.update_scenes(
    aspectmode="data",
    xaxis_title_text=f"coord {int(idxs[0])}",
    yaxis_title_text=f"coord {int(idxs[1])}",
    zaxis_title_text=f"coord {int(idxs[2])}",
)
fig.update_layout(title="Traced vs. ground-truth curve (3-coordinate projection)")
fig.show()