In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.path import Path

import scanpy as sc
from scipy.interpolate import griddata
from scipy.ndimage import gaussian_filter
from scipy.spatial import ConvexHull
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF, WhiteKernel
from sklearn.neighbors import NearestNeighbors

from matplotlib import patches as mpatches

In [None]:
# helper function for plotting a spatial obs field (e.g. pseudotime)
def plot_spatial_obs(
    adata,
    obs_key: str,
    x_key: str = "x_um_dbscan",
    y_key: str = "y_um_dbscan",
    cmap: str = "viridis",
    vmin=0.0,
    vmax=1.0,
):
    """Scatter cells at their spatial coordinates, coloured by an ``adata.obs`` column.

    Rows where the x coordinate, y coordinate, or value column is missing are
    dropped before plotting. The figure is shown and then closed; nothing is
    returned.

    Parameters
    ----------
    adata : object exposing an ``obs`` DataFrame (e.g. an AnnData).
    obs_key : column of ``adata.obs`` used for colour (cast to float).
    x_key, y_key : columns of ``adata.obs`` holding the spatial coordinates.
    cmap : matplotlib colormap name.
    vmin, vmax : colour-scale limits. The defaults (0, 1) keep the original
        fixed scale; pass ``None`` to autoscale to the plotted values.
    """
    obs = adata.obs
    mask = obs[x_key].notna() & obs[y_key].notna() & obs[obs_key].notna()
    x = obs.loc[mask, x_key].values
    y = obs.loc[mask, y_key].values
    vals = obs.loc[mask, obs_key].astype(float).values

    fig, ax = plt.subplots(figsize=(6, 6))

    # named `points` rather than `sc` so it is not confused with the scanpy alias
    points = ax.scatter(x, y, c=vals, cmap=cmap,
                        s=5, alpha=0.8, vmin=vmin, vmax=vmax)

    ax.set_xlabel(x_key)
    ax.set_ylabel(y_key)
    ax.set_aspect("equal")
    ax.set_title(f"{obs_key} (spatial)")
    fig.colorbar(points, ax=ax, shrink=0.75, label=obs_key)
    fig.tight_layout()
    plt.show()
    plt.close(fig)

In [None]:
# plot styling
# Dark plot theme: black figure/axes backgrounds with white edges, ticks, labels and text.
DARK_THEME_RCPARAMS = {
    "axes.facecolor": "black",
    "figure.facecolor": "black",
    "axes.edgecolor": "white",
    "xtick.color": "white",
    "ytick.color": "white",
    "text.color": "white",
    "axes.labelcolor": "white",
}
plt.rcParams.update(DARK_THEME_RCPARAMS)

In [None]:
# Data root; override with the BASE_PATH environment variable. It must be
# defined here so the notebook runs on a fresh kernel. os.path.join ignores it
# when h5ad_file is already an absolute path.
base_path = os.environ.get("BASE_PATH", "")
h5ad_file = '/path/to/h5ad'  # placeholder: point this at the real .h5ad file
h5ad_path = os.path.join(base_path, h5ad_file)

adata_full = sc.read_h5ad(h5ad_path)

In [None]:
# Keep only cells usable for the spatial pseudotime analysis: both spatial
# coordinates present and a finite pseudotime. The kNN step below raises on
# NaN coordinates, and np.percentile returns NaN if any pseudotime is NaN,
# which would make every later inlier comparison False.
obs_full = adata_full.obs
mapped_mask = (
    obs_full['x_um_dbscan'].notna()
    & obs_full['y_um_dbscan'].notna()
    & np.isfinite(obs_full['dpt_pseudotime'])
)
adata = adata_full[mapped_mask].copy()

sc.pl.umap(adata, color=['dpt_pseudotime', 'cell_type_fine'], cmap='plasma', save='4dpi_pseudotime.png')

In [None]:
plot_spatial_obs(adata, 'dpt_pseudotime', cmap='plasma')  # spatial map of dpt_pseudotime for all mapped cells

In [None]:
# Inputs for outlier filtering on dpt_pseudotime local consistency and mean
# spatial neighbor distance.
obs_xy = ['x_um_dbscan', 'y_um_dbscan']
coords = adata.obs[obs_xy].values
z = adata.obs['dpt_pseudotime'].values

# k-NN graph in physical space. Query k + 1 neighbors because column 0 of the
# result is normally the query cell itself.
k = 5
nn = NearestNeighbors(n_neighbors=k + 1)
nn.fit(coords)
distances, idx = nn.kneighbors(coords)

In [None]:
# Outlier cutoffs, expressed as percentiles of each score's distribution.
neighbor_dpt_percentile = 95    # pseudotime-inconsistency score
neighbor_dist_percentile = 98   # spatial-isolation score

# Score 1: |pseudotime - mean pseudotime of the k spatial neighbors|
# (column 0 of idx is skipped as the cell itself).
neighbor_mean_z = np.mean(z[idx[:, 1:]], axis=1)
pseudotime_diff = np.abs(z - neighbor_mean_z)
pseudotime_diff_cutoff = np.percentile(pseudotime_diff, neighbor_dpt_percentile)

# Score 2: mean distance to the k spatial neighbors (self-distance excluded).
mean_neighbor_dist = np.mean(distances[:, 1:], axis=1)
neighbor_dist_cutoff = np.percentile(mean_neighbor_dist, neighbor_dist_percentile)

# A cell is an inlier for a score when it is at or below that score's cutoff.
z_inlier = pseudotime_diff <= pseudotime_diff_cutoff
dist_inlier = mean_neighbor_dist <= neighbor_dist_cutoff

In [None]:
def _plot_cutoff_hist(ax, values, cutoff, percentile, xlabel, title):
    """Draw a log-count histogram of ``values`` with the outlier region highlighted.

    Bins whose left edge is at or above ``cutoff`` are recoloured orange, and a
    dashed white vertical line labelled with ``percentile`` marks the cutoff.
    """
    _, edges, bars = ax.hist(
        values, bins=100, edgecolor="black", log=True, color="tab:blue"
    )
    for left_edge, bar in zip(edges[:-1], bars):
        if left_edge >= cutoff:
            bar.set_facecolor("tab:orange")
    ax.axvline(
        cutoff, color="white", linestyle="--", linewidth=2.5,
        label=f"{percentile}th percentile"
    )
    ax.set_xlabel(xlabel)
    ax.set_title(title)
    ax.legend()


# Two-panel diagnostic of the outlier filters; panels share the log-scale count axis.
fig, axes = plt.subplots(1, 2, figsize=(12, 5), sharey=True)

# panel 1: pseudotime diff
_plot_cutoff_hist(
    axes[0], pseudotime_diff, pseudotime_diff_cutoff, neighbor_dpt_percentile,
    xlabel="Pseudotime difference",
    title="Distribution of pseudotime differences",
)
axes[0].set_ylabel("Count (log scale)")

# panel 2: mean neighbor distances
_plot_cutoff_hist(
    axes[1], mean_neighbor_dist, neighbor_dist_cutoff, neighbor_dist_percentile,
    xlabel="Mean neighbor distance",
    title="Distribution of mean neighbor distances",
)

# finalize, save, and release the figure
fig.tight_layout()
fig.savefig("./pseudotime_and_neighbor_histograms.png", dpi=300)
plt.show()
plt.close(fig)

In [None]:
# Cells kept for the GP fit must pass both the pseudotime-consistency and the
# spatial-isolation filters.
inlier_mask = np.logical_and(z_inlier, dist_inlier)

# Single sample, so the working subset is simply every inlier.
sub_mask = inlier_mask

coords_filtered, z_filtered = coords[sub_mask], z[sub_mask]

In [None]:
plot_spatial_obs(adata[sub_mask], 'dpt_pseudotime', cmap='plasma')  # same map, outliers removed

In [None]:
# Gaussian-process model of pseudotime over space: an RBF term for the smooth
# trend plus a WhiteKernel term for cell-to-cell scatter. The targets are
# normalised internally (normalize_y=True).
kernel = RBF(length_scale=1.0) + WhiteKernel(noise_level=1e-3)
gpr = GaussianProcessRegressor(kernel=kernel, normalize_y=True).fit(coords_filtered, z_filtered)
gpr

In [None]:
fitted_kernel = gpr.kernel_  # kernel with hyperparameters as fitted
fitted_kernel

In [None]:
# Regular grid_res x grid_res evaluation grid over the bounding box of the filtered cells.
grid_res = 200

xy_min = coords_filtered.min(axis=0)
xy_max = coords_filtered.max(axis=0)
x_grid = np.linspace(xy_min[0], xy_max[0], grid_res)
y_grid = np.linspace(xy_min[1], xy_max[1], grid_res)

Xg, Yg = np.meshgrid(x_grid, y_grid)
# (grid_res**2, 2) array of (x, y) query points, flattened row by row
grid_pts = np.column_stack((Xg.ravel(), Yg.ravel()))

In [None]:
# GP posterior mean at every grid point, folded back onto the 2-D grid.
Z_mean = np.reshape(gpr.predict(grid_pts), (grid_res, grid_res))

In [None]:
# Visualize the mean of the GPR: filled contours of the surface with the
# filtered cells overlaid. Surface and points share one colour range so the
# colorbar describes both (otherwise the scatter autoscales on its own).
pt_vmin = min(np.nanmin(Z_mean), np.nanmin(z_filtered))
pt_vmax = max(np.nanmax(Z_mean), np.nanmax(z_filtered))

fig, ax = plt.subplots(figsize=(9, 8))

# plot contour (mean of GP)
cf = ax.contourf(Xg, Yg, Z_mean, levels=50, cmap='plasma',
                 vmin=pt_vmin, vmax=pt_vmax)

# plot cells with pseudotime values on the same colour scale
ax.scatter(coords_filtered[:, 0], coords_filtered[:, 1], c=z_filtered,
           cmap='plasma', vmin=pt_vmin, vmax=pt_vmax,
           s=10, edgecolor='k', linewidth=0.2, alpha=0.6)

fig.colorbar(cf, ax=ax, label='GP mean pseudotime')
ax.set_title("GP pseudotime surface")
ax.set_aspect('equal')
fig.tight_layout()
# save only after aspect/layout are final so the file matches the displayed figure
fig.savefig('./gp_regression.png', dpi=300)
plt.show()
plt.close(fig)

In [None]:
# Evaluate the GP mean on the grid and take its spatial gradient by finite differences.
grid_pts = np.column_stack((Xg.ravel(), Yg.ravel()))
Zg = gpr.predict(grid_pts).reshape(Xg.shape)
dx = Xg[0, 1] - Xg[0, 0]   # spacing between grid columns (x direction)
dy = Yg[1, 0] - Yg[0, 0]   # spacing between grid rows (y direction)
# np.gradient returns derivatives in axis order: axis 0 (rows, y), then axis 1 (columns, x)
dZ_dy, dZ_dx = np.gradient(Zg, dy, dx)
DX_grid = dZ_dx
DY_grid = dZ_dy

# Restrict every field to the convex hull of the filtered cells.
hull_vertices = ConvexHull(coords_filtered).vertices
poly = coords_filtered[hull_vertices]
inside = Path(poly, closed=True).contains_points(grid_pts).reshape(Xg.shape)

Zg_mask, DX_mask, DY_mask = (
    np.where(inside, field, np.nan) for field in (Zg, DX_grid, DY_grid)
)

# blur helper
def blur_nan(A, sigma):
    """Gaussian-blur an array while ignoring non-finite entries (normalized convolution).

    Non-finite pixels contribute no weight, so each output pixel is the
    Gaussian-weighted mean of the finite pixels around it.

    Parameters
    ----------
    A : ndarray
        Array to blur; NaN/inf entries are treated as missing.
    sigma : float or sequence of float
        Gaussian standard deviation in pixels, passed to
        ``scipy.ndimage.gaussian_filter``.

    Returns
    -------
    out : ndarray
        Blurred array; NaN wherever the blurred finite-mask is zero (no finite
        pixel inside the filter's support).
    den : ndarray
        Blurred finite-mask, i.e. the local weight of finite data (in [0, 1]).
        Callers use it as a support map.
    """
    finite = np.isfinite(A)
    weights = finite.astype(float)
    filled = np.where(finite, A, 0.0)
    num = gaussian_filter(filled, sigma=sigma)
    den = gaussian_filter(weights, sigma=sigma)
    # Divide only where there is support. This avoids 0/0 RuntimeWarnings and
    # leaves NaN where den == 0 (den is never negative: non-negative weights).
    out = np.full(num.shape, np.nan, dtype=np.result_type(num, den))
    np.divide(num, den, out=out, where=den > 0)
    return out, den

# Anchor for the blur-width ellipses: the centroid of the filtered cells.
xc, yc = np.mean(coords_filtered, axis=0)

In [None]:
# Blurred GP surface with gradient streamlines, compared across blur widths.
sigmas = [2.5, 5, 10]   # Gaussian sigma in grid pixels
support_thresh = 0.5    # minimum blurred in-hull weight for a pixel to be drawn

# Shared colour limits for every panel. blur_nan returns a weighted average of
# the finite in-hull values, so blurred values stay (up to rounding) inside the
# unblurred range. Without fixed limits each pcolormesh would autoscale on its
# own, and the single colorbar (built from the last panel) would not describe
# the other panels.
z_vmin, z_vmax = np.nanmin(Zg_mask), np.nanmax(Zg_mask)

fig, axes = plt.subplots(1, 3, figsize=(18, 5))

for ax, sigma in zip(axes, sigmas):
    # blur fields, ignoring pixels outside the hull
    Z_blur,  Z_support  = blur_nan(Zg_mask, sigma)
    DX_blur, DX_support = blur_nan(DX_mask, sigma)
    DY_blur, DY_support = blur_nan(DY_mask, sigma)

    # re-apply hull mask + support threshold
    support_ok = (
        (Z_support >= support_thresh)
        & (DX_support >= support_thresh)
        & (DY_support >= support_thresh)
    )
    final_mask = inside & support_ok
    Z_blur[~final_mask] = np.nan
    DX_blur[~final_mask] = np.nan
    DY_blur[~final_mask] = np.nan

    # plot surface + streamlines on the shared colour scale
    im = ax.pcolormesh(Xg, Yg, np.ma.masked_invalid(Z_blur),
                       cmap="plasma", shading="auto", vmin=z_vmin, vmax=z_vmax)
    ax.streamplot(Xg, Yg, DX_blur, DY_blur,
                  color="k", density=2, linewidth=1.2, arrowsize=1.6)

    # red dashed ellipse: 1-sigma extent of the blur kernel in data units
    sigma_x, sigma_y = sigma * dx, sigma * dy
    e = mpatches.Ellipse((xc, yc), width=2 * sigma_x, height=2 * sigma_y,
                         fill=False, color='red', linestyle="--", linewidth=1.5)
    ax.add_patch(e)

    # annotate sigma and full width at half max (FWHM ≈ 2.355 sigma).
    # Text colour is set explicitly because rcParams makes text white, which
    # is unreadable on the white annotation box.
    ax.text(0.02, 0.98,
            f"σ = {sigma:g} px\nFWHM ≈ {2.355*sigma:.2f} px",
            transform=ax.transAxes, va="top", ha="left", color="black",
            bbox=dict(facecolor="white", alpha=0.7, edgecolor="none", pad=4),
            fontsize=9)

    ax.set_aspect("equal")
    ax.set_title(f"σ={sigma} blur")

# one colorbar, valid for all panels thanks to the shared vmin/vmax
cbar_ax = fig.add_axes([0.92, 0.15, 0.02, 0.7])  # [left, bottom, width, height]
fig.colorbar(im, cax=cbar_ax, label="GP pseudotime")

fig.subplots_adjust(right=0.9, wspace=0.3)  # leave space for colorbar
plt.show()
plt.close(fig)