Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 119 additions & 2 deletions src/lineagetree/_core/_modifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from functools import wraps
from typing import TYPE_CHECKING
import numpy as np
from scipy.ndimage import gaussian_filter1d

if TYPE_CHECKING:
from ..lineage_tree import LineageTree
Expand Down Expand Up @@ -67,7 +68,8 @@ def add_chain(
raise Warning("The node already has a predecessor.")
if lT._time[node] - length < lT.t_b:
raise Warning(
"A node cannot created outside the lower bound of the dataset. (It is possible to change it by lT.t_b = int(...))"
"A node cannot created outside the lower bound of the dataset."
"(It is possible to change it by lT.t_b = int(...))"
)
for _ in range(int(length)):
old_node = node
Expand Down Expand Up @@ -347,7 +349,6 @@ def stabilise_positions(lT: LineageTree) -> dict[int, np.ndarray]:
nodes_i.append(c)
nodes_j.append(next_n)

# nodes_i, nodes_j = zip(*[(c, tuple(ci for ci in lT.successor[c] if ci in ok_cells)) for c in lT.time_nodes[t1] if c in lT.successor and c in ok_cells])
pos_i = np.array([lT.pos[c] for c in nodes_i])
pos_j = np.array(
[np.mean([lT.pos[ci] for ci in c], axis=0) for c in nodes_j]
Expand All @@ -368,3 +369,119 @@ def stabilise_positions(lT: LineageTree) -> dict[int, np.ndarray]:
lT.pos = new_pos

return new_pos


def anchored_gaussian_smooth(data, sigma=1.5, anchor_strength=3.0):
"""
Apply Gaussian smoothing to a 1D sequence while anchoring the endpoints
and suppressing drift near the boundaries.

This function performs standard Gaussian smoothing and then blends the
smoothed result with the original data using a position-dependent weight.
The weights enforce exact anchoring at the first and last elements and
progressively relax toward the center of the array.

Parameters
----------
data : array_like
Input 1D sequence of numeric values.
sigma : float, optional
Standard deviation of the Gaussian kernel used for smoothing.
Higher values produce stronger smoothing. Default is 1.5.
anchor_strength : float, optional
Controls how strongly the endpoints influence nearby values.
Smaller values result in tighter anchoring (less smoothing near edges),
while larger values allow more smoothing across the entire array.
Default is 3.0.

Returns
-------
numpy.ndarray
Smoothed array of the same shape as the input, with the first and last
elements exactly equal to the original values.

Notes
-----
- The method applies a Gaussian filter followed by a spatially varying
convex combination:
result[i] = alpha[i] * data[i] + (1 - alpha[i]) * smoothed[i]
where alpha[i] decays exponentially with distance from the nearest
endpoint.
- This introduces soft boundary conditions (anchoring), breaking the
shift-invariance of standard convolution-based smoothing.
- The endpoints are strictly preserved (Dirichlet boundary condition),
and nearby points are partially constrained depending on their distance
to the boundaries.

Examples
--------
>>> anchored_gaussian_smooth([10, 12, 15, 20, 18, 16, 14], sigma=1.5)
array([...])
"""
data = np.asarray(data, dtype=float)

# Standard Gaussian smoothing
smoothed = gaussian_filter1d(data, sigma=sigma, mode="nearest")

n = data.size
i = np.arange(n)

# Distance to nearest endpoint
dist = np.minimum(i, n - 1 - i)

# Exponential decay: strong anchoring near edges
alpha = np.exp(-dist / anchor_strength)

# Ensure exact anchoring at endpoints
alpha[0] = 1.0
alpha[-1] = 1.0

# Blend original and smoothed
return alpha * data + (1 - alpha) * smoothed


@modifier
def smooth_trajectories(lT: LineageTree, sigma=1.0, ancor_strength=3):
"""
Smooth 3D trajectories of all chains in a lineage tree using anchored
Gaussian filtering.

For each chain in the lineage tree, the x-, y-, and z-coordinates are
independently smoothed using a Gaussian filter with soft endpoint
constraints. The first and last positions of each chain are preserved
exactly, while nearby points are partially constrained to reduce drift.

Parameters
----------
lT : LineageTree
sigma : float, default=1.0
Standard deviation of the Gaussian kernel used for smoothing each
coordinate independently. Higher values produce smoother trajectories.
Default is 1.0.
ancor_strength : float, default=3
Controls the strength of endpoint anchoring. Smaller values enforce
stronger constraints near the start and end of each chain (less drift),
while larger values allow more global smoothing. Default is 3.

Returns
-------
dict
Dictionary mapping each node in the lineage tree to its smoothed
3D position.
"""
new_pos = {}
for chain in lT.all_chains:
X, Y, Z = np.array([lT.pos[c] for c in chain]).T
X_new = anchored_gaussian_smooth(
X, sigma=sigma, anchor_strength=ancor_strength
)
Y_new = anchored_gaussian_smooth(
Y, sigma=sigma, anchor_strength=ancor_strength
)
Z_new = anchored_gaussian_smooth(
Z, sigma=sigma, anchor_strength=ancor_strength
)
new_pos.update(zip(chain, np.transpose([X_new, Y_new, Z_new])))
lT.old_pos = lT.pos
lT.pos = new_pos
return lT.pos
2 changes: 2 additions & 0 deletions src/lineagetree/_mixins/modifier_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
modifier,
remove_nodes,
stabilise_positions,
smooth_trajectories,
)

from ._methodize import AutoMethodizeMeta
Expand All @@ -21,3 +22,4 @@ class ModifierMixin(metaclass=AutoMethodizeMeta):
modifier = modifier
remove_nodes = remove_nodes
stabilise_positions = stabilise_positions
smooth_trajectories = smooth_trajectories
15 changes: 13 additions & 2 deletions tests/test_lineageTree.py
Original file line number Diff line number Diff line change
Expand Up @@ -781,8 +781,19 @@ def test_plot_chain_hist():
assert sum(p.get_height() for p in ax1.patches) == (
len(lt.all_chains) - len(lt.leaves.union(lt.roots))
)


def test_stabilise_positions():
lT1.stabilise_positions()
new_pos = lT1.stabilise_positions()
assert np.isclose(
new_pos[148361], np.array([1019.66762163, 400.25591182, 287.54520521])
).all()
lT1.pos = lT1.old_pos


def test_smoothing():
new_pos = lt.smooth_trajectories()
assert np.isclose(
lT1.pos[148361], np.array([1019.66762163, 400.25591182, 287.54520521])
new_pos[1552], np.array([462.15385069, 907.17562352, 419.54303692])
).all()
lt.pos = lt.old_pos
Loading