# In [None]:
import healpy as hp, numpy as np, matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D       # noqa
from matplotlib.colors import TwoSlopeNorm, LightSource

import healpy as hp, numpy as np, matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D          # noqa
from matplotlib import colors

# ─── parameters you can tune ──────────────────────────────────────────
nside      = 32
lmax       = 22
sigma_rms  = 0.01             # per-step radial perturbation RMS (1 % of the unit radius)
snapshots  = [0, 5, 25, 50, 100]
low_band   = (0, 4)           # ℓ = 0…4 inclusive: five lowest degrees (incl. monopole)
high_band  = (lmax-5, lmax-1) # ℓ = lmax-5…lmax-1 inclusive: five degrees just below lmax
seed       = 0
cm         = plt.get_cmap("coolwarm")
ghost_alpha = 0.15            # opacity of the reference wireframe sphere
pt_size     = 0.4             # NOTE(review): not referenced below — the scatter hard-codes s=0.8

np.random.seed(seed)          # seeds the legacy global RNG used by the noise draws below
npix   = hp.nside2npix(nside)
# Unit vectors to every HEALPix pixel centre, stacked as rows: shape (npix, 3).
verts  = np.array(hp.pix2vec(nside, np.arange(npix))).T

# ─── helper functions ─────────────────────────────────────────────────
def rand_alm_map(ℓ_min, ℓ_max, *, alm_lmax=None, map_nside=None, scale=None):
    """Synthesize a HEALPix map from random a_ℓm confined to one ℓ band.

    Every a_ℓm with ℓ_min ≤ ℓ ≤ ℓ_max gets a Gaussian real part and, for
    m > 0, a Gaussian imaginary part, each with standard deviation
    ``scale``. m = 0 coefficients are kept real, as a real-valued map
    requires. All coefficients outside the band stay zero.

    Parameters
    ----------
    ℓ_min, ℓ_max : int
        Inclusive multipole band to populate.
    alm_lmax : int, optional
        Maximum ℓ of the a_ℓm array. Defaults to the module-level ``lmax``.
    map_nside : int, optional
        Resolution of the returned map. Defaults to the module-level ``nside``.
    scale : float, optional
        Standard deviation of each Gaussian draw. Defaults to the
        module-level ``sigma_rms``.

    Returns
    -------
    numpy.ndarray
        The map produced by ``hp.alm2map`` at ``map_nside``.

    Raises
    ------
    ValueError
        If ``ℓ_max`` exceeds ``alm_lmax``. ``hp.Alm.getidx`` is plain index
        arithmetic, so such an ℓ would otherwise silently overwrite other
        coefficients or index past the array.
    """
    alm_lmax = lmax if alm_lmax is None else alm_lmax
    map_nside = nside if map_nside is None else map_nside
    scale = sigma_rms if scale is None else scale
    if ℓ_max > alm_lmax:
        raise ValueError(f"ℓ_max={ℓ_max} exceeds alm_lmax={alm_lmax}")

    alm = np.zeros(hp.Alm.getsize(alm_lmax), np.complex128)
    # The draw order (ℓ outer, m inner; real part, then imaginary part for
    # m > 0) is fixed so that results reproduce for a given np.random.seed.
    for ℓ in range(ℓ_min, ℓ_max + 1):
        for m in range(ℓ + 1):
            idx = hp.Alm.getidx(alm_lmax, ℓ, m)
            real_part = np.random.normal(scale=scale)
            imag_part = 0.0 if m == 0 else np.random.normal(scale=scale)
            alm[idx] = real_part + 1j * imag_part
    # NOTE(review): newer healpy releases deprecate `verbose`; it is kept to
    # match the healpy version this notebook was written against.
    return hp.alm2map(alm, map_nside, verbose=False)

# Independent white noise per vertex (pixel centre).
vertex_noise = sigma_rms * np.random.randn(npix)

# Band-limited noise maps. Each is rescaled so its RMS equals sigma_rms, so
# the three scenarios differ in angular scale but not in amplitude.
low_map  = rand_alm_map(*low_band)
high_map = rand_alm_map(*high_band)
low_map  = low_map  * (sigma_rms / np.sqrt(np.mean(low_map ** 2)))
high_map = high_map * (sigma_rms / np.sqrt(np.mean(high_map ** 2)))


def _accumulate(step_map, n_steps):
    """Return [r_0, ..., r_n] with r_0 = 1 and r_k = r_{k-1} + step_map.

    The *same* perturbation is added at every step, so r_k is 1 + k*step_map
    up to rounding. The displacement therefore grows linearly with the step
    count; this is not a random walk. At large step counts (an RMS of 1.0
    at step 100 with the default sigma_rms) some radii can cross zero.
    """
    radii = [np.ones(npix)]
    for _ in range(n_steps):
        radii.append(radii[-1] + step_map)
    return radii


# trajectories (cumulative)
traj_local = _accumulate(vertex_noise, max(snapshots))
traj_high  = _accumulate(high_map, max(snapshots))
traj_low   = _accumulate(low_map, max(snapshots))

scenarios = [("Local noise", traj_local),
             ("High-ℓ noise", traj_high),
             ("Low-ℓ noise",  traj_low)]

# Reference "ghost" sphere: a coarse latitude–longitude grid on the unit
# sphere (30 polar x 40 azimuthal samples), drawn later as a faint wireframe.
# meshgrid's default 'xy' indexing gives arrays of shape (40, 30).
th_g, ph_g = np.meshgrid(np.linspace(0, np.pi, 30),
                         np.linspace(0, 2*np.pi, 40))
Xg = np.sin(th_g) * np.cos(ph_g)
Yg = np.sin(th_g) * np.sin(ph_g)
Zg = np.cos(th_g)

# ─── plotting ────────────────────────────────────────────────────────
# Point-cloud view: one row per scenario, one column per snapshot.
fig = plt.figure(figsize=(len(snapshots)*3.2, len(scenarios)*3.2),
                 constrained_layout=True)

# Diverging colour scale for Δr, symmetric about 0 and spanning
# ±sigma_rms*max(snapshots).
norm = colors.TwoSlopeNorm(vcenter=0, vmin=-sigma_rms*max(snapshots),
                           vmax= sigma_rms*max(snapshots))

for c, step in enumerate(snapshots):
    for r, (label, traj) in enumerate(scenarios):
        ax  = fig.add_subplot(len(scenarios), len(snapshots),
                              r*len(snapshots)+c+1, projection="3d")

        # Scale each pixel-centre unit vector to its current radius.
        r_field = traj[step]
        Δr      = r_field - 1.0
        pts     = verts * r_field[:, None]

        ax.scatter(*pts.T,
                   c=cm(norm(Δr)), s=0.8, depthshade=False)

        # ghost sphere
        ax.plot_wireframe(Xg, Yg, Zg, color="k", lw=0.2, alpha=ghost_alpha)

        ax.view_init(elev=20, azim=35)
        if c == 0:
            # set_axis_off() below suppresses 3-D axis labels, so set_ylabel()
            # would never be drawn. Place the row label as axes-relative text.
            ax.text2D(0.0, 0.5, label, transform=ax.transAxes,
                      rotation=90, ha="right", va="center")
        ax.set_title(f"step {step}", fontsize=9, pad=2)
        ax.set_axis_off()
        ax.set_box_aspect([1, 1, 1])


# ---------- regular lat–lon mesh for smooth surface -------------------------
res   = 200
theta = np.linspace(0, np.pi, res)       # polar angle, pole to pole
phi   = np.linspace(0, 2*np.pi, 2*res)   # azimuth, one full turn
TH, PH = np.meshgrid(theta, phi, indexing="ij")   # both shaped (res, 2*res)
# Cartesian coordinates of the undeformed unit sphere on that grid.
X0 = np.sin(TH) * np.cos(PH)
Y0 = np.sin(TH) * np.sin(PH)
Z0 = np.cos(TH)

# ---------- helpers ---------------------------------------------------------
# Light source (altitude 25°, azimuth 120°) used to shade surface colours.
ls = LightSource(altdeg=25, azdeg=120)

def add_axes(ax, length=1.25, lw=0.8, *, color='k', alpha=0.25):
    """Draw faint x, y and z reference lines through the origin.

    Each line runs from -length to +length along one coordinate axis.

    Parameters
    ----------
    ax : mpl_toolkits.mplot3d.Axes3D
        Target axes (the caller passes a projection='3d' subplot). Only its
        ``plot(xs, ys, zs, **kwargs)`` method is used.
    length : float, default 1.25
        Half-length of each reference line.
    lw : float, default 0.8
        Line width.
    color : matplotlib colour, default 'k'
        Line colour.
    alpha : float, default 0.25
        Line opacity.
    """
    for dx, dy, dz in ((1, 0, 0), (0, 1, 0), (0, 0, 1)):
        ax.plot([-dx*length, dx*length],
                [-dy*length, dy*length],
                [-dz*length, dz*length],
                color=color, lw=lw, alpha=alpha, zorder=-1)

scenarios = [("Local noise", traj_local),
             ("High-ℓ noise", traj_high),
             ("Low-ℓ noise",  traj_low)]

n_row, n_col = len(scenarios), len(snapshots)

# ---------- bigger figure ---------------------------------------------------
fig_w = n_col * 4.5       # inches
fig_h = n_row * 4.5
fig   = plt.figure(figsize=(fig_w, fig_h), constrained_layout=True)

# Shared diverging colour scale, symmetric about Δr = 0.
norm  = TwoSlopeNorm(vcenter=0,
                     vmin=-sigma_rms*max(snapshots),
                     vmax= sigma_rms*max(snapshots))

for c, step in enumerate(snapshots):
    for r, (label, traj) in enumerate(scenarios):
        # Per-pixel radial displacement, interpolated (theta/phi in radians)
        # onto the regular lat–lon mesh so that plot_surface can draw it.
        Δr_pix = traj[step] - 1.0
        pix_val = hp.get_interp_val(Δr_pix, TH.ravel(), PH.ravel(),
                                    lonlat=False).reshape(TH.shape)
        R = 1 + pix_val
        X, Y, Z = R*X0, R*Y0, R*Z0

        # Colour by displacement, then shade with the *undeformed* sphere's z
        # as the height field (a cosmetic lighting cue, not true normals).
        rgb = ls.shade_rgb(plt.cm.coolwarm(norm(pix_val)), Z0)

        ax = fig.add_subplot(n_row, n_col, r*n_col + c + 1, projection='3d')
        ax.plot_surface(X, Y, Z,
                        facecolors=rgb, rstride=2, cstride=2,
                        linewidth=0, antialiased=False)
        add_axes(ax)
        ax.set_axis_off()
        ax.set_box_aspect([1,1,1])
        ax.view_init(elev=20, azim=35)

        if c == 0:                                   # only first column
            # set_axis_off() above suppresses 3-D axis labels, so set_ylabel()
            # would be invisible. Draw the row label as axes-relative text.
            ax.text2D(0.0, 0.5, label, transform=ax.transAxes,
                      rotation=90, ha="right", va="center", fontsize=12)

# ---------- one set of step labels UNDER the bottom row ---------------------
# y < 0 puts the labels just below the canvas. They appear in the saved PNG
# because bbox_inches="tight" grows the saved area to include figure texts.
y_lab = -0.01                                          # figure-coords
for c, step in enumerate(snapshots):
    x = (c + 0.5)/n_col
    fig.text(x, y_lab, f"Step {step}", ha='center', va='top',
             fontsize=24, fontweight='bold')
fig.savefig("sh_example_rendering.png", dpi=300, bbox_inches="tight")

plt.show()
