In [None]:
#!/usr/bin/env python
"""
Visualise the synthetic TreeTestData manifold.

◻ scatter coloured by time-point
◻ quiver of the velocity vectors (sub-sampled)
◻ a few reference paths returned by .get_paths()
"""
import numpy as np
import matplotlib.pyplot as plt

# ---------------------------------------------------------------------
# 1.  Load the dataset -------------------------------------------------
from dataset import TreeTestData            # adjust the import path if needed
# Build the synthetic dataset once and pull out the pieces every panel uses.
data = TreeTestData()

# Shapes per the dataset's own notes: pts is (N, 2) coordinates and times is
# (N,) labels, expected to be 0/1 -- NOTE(review): confirm against TreeTestData.
pts, times = data.get_data(), data.get_times()

# Sorted distinct time-points actually present; drives the per-time scatter.
unique = np.unique(times)

# ---------------------------------------------------------------------
# 2.  Figure layout ----------------------------------------------------
fig, axes = plt.subplots(1, 3, figsize=(12, 4), constrained_layout=True)

# ------ (a) scatter per time-point -----------------------------------
# Fixed colours for the two expected time-points. Any other time value
# falls back to matplotlib's default colour cycle ("C0", "C1", ...) so an
# unexpected label no longer raises KeyError. Non-integer times are no
# longer truncated onto 0/1 by int().
cmap   = {0: "#183be0",        # blue   ─ previous time
          1: "#ea7c28"}        # orange ─ next   time
for i, tp in enumerate(unique):
    msk = times == tp
    # Lookup by value: numpy int/float scalars hash like their Python
    # equivalents, so 0, 0.0 and np.int64(0) all resolve to key 0.
    colour = cmap.get(tp, f"C{i % 10}")
    axes[0].scatter(pts[msk, 0], pts[msk, 1],
                    c=colour, s=4, alpha=0.25, label=f"t={tp}")
axes[0].set_title("Point cloud by time-point")
axes[0].legend(markerscale=3, frameon=False)
axes[0].set_aspect("equal")
axes[0].axis("off")

# ------ (b) velocity field -------------------------------------------
if data.has_velocity():
    vel = data.get_velocity()
    n_cells = data.get_ncells()
    # Thin the arrows for legibility, but never request more samples than
    # there are cells: choice(..., replace=False) raises ValueError otherwise.
    # A fixed seed keeps the sub-sample (and thus the figure) reproducible.
    rng = np.random.default_rng(0)
    idx = rng.choice(n_cells, size=min(700, n_cells), replace=False)
    axes[1].quiver(pts[idx, 0], pts[idx, 1], vel[idx, 0], vel[idx, 1],
                   angles='xy', scale_units='xy', scale=1.5,
                   width=0.002, alpha=0.6, color="#444444")
    axes[1].set_title("Velocity vectors (sub-sampled)")
    axes[1].set_aspect("equal")
    axes[1].axis("off")
else:
    # No velocity in this dataset: label and hide the panel instead of
    # leaving an empty set of bare axes in the middle of the figure.
    axes[1].set_title("Velocity vectors (not available)")
    axes[1].axis("off")

# ------ (c) reference paths ------------------------------------------
# Reference trajectories from the dataset. The original note gives their shape
# as (N, 3, 2), i.e. one (n_steps, 2) polyline per path -- NOTE(review):
# confirm against TreeTestData.get_paths. Only the first 400 are drawn so the
# panel stays readable.
paths = data.get_paths(n=1200, n_steps=3)
path_ax = axes[2]
for p in paths[:400]:
    x_coords, y_coords = p[:, 0], p[:, 1]
    path_ax.plot(x_coords, y_coords, color="#555555",
                 linewidth=0.7, alpha=0.25)

path_ax.set_title("Ground-truth paths")
path_ax.set_aspect("equal")
path_ax.axis("off")

plt.show()


: 

: 