Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
d10ed51
Add curved curved_quiver
cvanelteren Oct 6, 2025
5b0a0eb
add unittests
cvanelteren Oct 6, 2025
e4b1023
update docstrings
cvanelteren Oct 6, 2025
ce7bb37
black formatting
cvanelteren Oct 6, 2025
14b7ce6
rm dup import
cvanelteren Oct 6, 2025
7b1dc19
Update ultraplot/axes/plot.py
cvanelteren Oct 6, 2025
578dd2a
Update ultraplot/axes/plot.py
cvanelteren Oct 6, 2025
a3290c1
Update ultraplot/axes/plot.py
cvanelteren Oct 6, 2025
28b20a1
Update ultraplot/tests/test_plot.py
cvanelteren Oct 6, 2025
e27e9e3
mv import up
cvanelteren Oct 6, 2025
c057d92
Merge branch 'main' into feat-curved-quiever
cvanelteren Oct 6, 2025
e1bbca8
Merge branch 'main' into feat-curved-quiever
cvanelteren Oct 7, 2025
8ab593b
refactor and move down
cvanelteren Oct 7, 2025
1ee8348
black formatting
cvanelteren Oct 7, 2025
6af7555
update tests with new api
cvanelteren Oct 7, 2025
a01893c
add one test as image comp
cvanelteren Oct 7, 2025
1088dd1
Update ultraplot/axes/plot.py
cvanelteren Oct 7, 2025
31a1dbc
Update ultraplot/axes/plot.py
cvanelteren Oct 7, 2025
c3dc242
Update ultraplot/axes/plot.py
cvanelteren Oct 7, 2025
525f56f
Update ultraplot/axes/plot.py
cvanelteren Oct 7, 2025
fd8eb87
Update ultraplot/tests/test_plot.py
cvanelteren Oct 7, 2025
56534d9
black formatting
cvanelteren Oct 7, 2025
4cc11d7
inline the termination to make it slightly more compact
cvanelteren Oct 8, 2025
1c84d1e
Merge branch 'main' into feat-curved-quiever
cvanelteren Oct 8, 2025
8df15b8
mv curved quiver plot to 'plot_types' and update tests'
cvanelteren Oct 8, 2025
9b89228
mv parameters to rcsetup
cvanelteren Oct 8, 2025
a746ef8
add type hinting
cvanelteren Oct 8, 2025
4e3927c
rm dup docstring
cvanelteren Oct 8, 2025
fd0d1e8
rm unused imports
cvanelteren Oct 8, 2025
dba943b
Update ultraplot/axes/plot_types/curved_quiver.py
cvanelteren Oct 13, 2025
f6c9817
Apply suggestion from @beckermr
cvanelteren Oct 13, 2025
bb2f3e0
Apply suggestion from @beckermr
cvanelteren Oct 13, 2025
c15ec2c
Apply suggestion from @beckermr
cvanelteren Oct 13, 2025
60cf7bf
rename private classes
cvanelteren Oct 13, 2025
f841337
more renaming
cvanelteren Oct 13, 2025
36c00d8
more renaming
cvanelteren Oct 13, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
284 changes: 269 additions & 15 deletions ultraplot/axes/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@

from typing import Any, Union, Iterable, Optional

from typing import Any, Union
from collections.abc import Callable
from collections.abc import Iterable

Expand All @@ -33,6 +32,7 @@
import matplotlib as mpl
from packaging import version
import numpy as np
from typing import Optional, Union, Any
import numpy.ma as ma

from .. import colors as pcolors
Expand Down Expand Up @@ -170,7 +170,45 @@
docstring._snippet_manager["plot.args_1d_shared"] = _args_1d_shared_docstring
docstring._snippet_manager["plot.args_2d_shared"] = _args_2d_shared_docstring

_curved_quiver_docstring = """
Draws curved vector field arrows (streamlines with arrows) for 2D vector fields.

Parameters
----------
x, y : 1D or 2D arrays
Grid coordinates.
u, v : 2D arrays
Vector components.
color : color or 2D array, optional
Streamline color.
density : float or (float, float), optional
Controls the closeness of streamlines.
grains : int or (int, int), optional
Number of seed points in x and y.
linewidth : float or 2D array, optional
Width of streamlines.
cmap, norm : optional
Colormap and normalization for array colors.
arrowsize : float, optional
Arrow size scaling.
arrowstyle : str, optional
Arrow style specification.
transform : optional
Matplotlib transform.
zorder : float, optional
Z-order for lines/arrows.
start_points : (N, 2) array, optional
Starting points for streamlines.

Returns
-------
CurvedQuiverSet
Container with attributes:
- lines: LineCollection of streamlines
- arrows: PatchCollection of arrows
"""

docstring._snippet_manager["plot.curved_quiver"] = _curved_quiver_docstring
# Auto colorbar and legend docstring
_guide_docstring = """
colorbar : bool, int, or str, optional
Expand Down Expand Up @@ -1499,22 +1537,237 @@ class PlotAxes(base.Axes):
Implements all plotting overrides.
"""

def __init__(self, *args, **kwargs):
@docstring._snippet_manager
def curved_quiver(
self,
x: np.ndarray,
y: np.ndarray,
u: np.ndarray,
v: np.ndarray,
linewidth: Optional[float] = None,
color: Optional[Union[str, Any]] = None,
cmap: Optional[Any] = None,
norm: Optional[Any] = None,
arrowsize: Optional[float] = None,
arrowstyle: Optional[str] = None,
transform: Optional[Any] = None,
zorder: Optional[int] = None,
start_points: Optional[np.ndarray] = None,
scale: Optional[float] = None,
grains: Optional[int] = None,
density: Optional[int] = None,
arrow_at_end: Optional[bool] = None,
):
"""
Parameters
----------
*args, **kwargs
Passed to `ultraplot.axes.Axes`.
%(plot.curved_quiver)s

Notes
-----
The implementation of this function is based on the `dfm_tools` repository.
Original file: https://github.com/Deltares/dfm_tools/blob/829e76f48ebc42460aae118cc190147a595a5f26/dfm_tools/modplot.py
"""
from .plot_types.curved_quiver import CurvedQuiverSolver, CurvedQuiverSet

# Parse inputs
arrowsize = _not_none(arrowsize, rc["curved_quiver.arrowsize"])
arrowstyle = _not_none(arrowstyle, rc["curved_quiver.arrowstyle"])
zorder = _not_none(zorder, mlines.Line2D.zorder)
transform = _not_none(transform, self.transData)
color = _not_none(color, self._get_lines.get_next_color())
linewidth = _not_none(linewidth, rc["lines.linewidth"])
scale = _not_none(scale, rc["curved_quiver.scale"])
grains = _not_none(grains, rc["curved_quiver.grains"])
density = _not_none(density, rc["curved_quiver.density"])
arrows_at_end = _not_none(arrow_at_end, rc["curved_quiver.arrows_at_end"])

solver = CurvedQuiverSolver(x, y, density)
if zorder is None:
zorder = mlines.Line2D.zorder

line_kw = {}
arrow_kw = dict(arrowstyle=arrowstyle, mutation_scale=10 * arrowsize)

use_multicolor_lines = isinstance(color, np.ndarray)
if use_multicolor_lines:
if color.shape != solver.grid.shape:
raise ValueError(
"If 'color' is given, must have the shape of 'Grid(x,y)'"
)
line_colors = []
color = np.ma.masked_invalid(color)
else:
line_kw["color"] = color
arrow_kw["color"] = color

See also
--------
matplotlib.axes.Axes
ultraplot.axes.Axes
ultraplot.axes.CartesianAxes
ultraplot.axes.PolarAxes
ultraplot.axes.GeoAxes
"""
super().__init__(*args, **kwargs)
if isinstance(linewidth, np.ndarray):
if linewidth.shape != solver.grid.shape:
raise ValueError(
"If 'linewidth' is given, must have the shape of 'Grid(x,y)'"
)
line_kw["linewidth"] = []
else:
line_kw["linewidth"] = linewidth
arrow_kw["linewidth"] = linewidth

line_kw["zorder"] = zorder
arrow_kw["zorder"] = zorder

## Sanity checks.
if u.shape != solver.grid.shape or v.shape != solver.grid.shape:
raise ValueError("'u' and 'v' must be of shape 'Grid(x,y)'")

u = np.ma.masked_invalid(u)
v = np.ma.masked_invalid(v)
magnitude = np.sqrt(u**2 + v**2)
magnitude /= np.max(magnitude)

resolution = scale / grains
minlength = 0.9 * resolution

integrate = solver.get_integrator(u, v, minlength, resolution, magnitude)
trajectories = []
edges = []

if start_points is None:
start_points = solver.gen_starting_points(x, y, grains)

sp2 = np.asanyarray(start_points, dtype=float).copy()

# Check if start_points are outside the data boundaries
for xs, ys in sp2:
if not (
solver.grid.x_origin <= xs <= solver.grid.x_origin + solver.grid.width
and solver.grid.y_origin
<= ys
<= solver.grid.y_origin + solver.grid.height
):
raise ValueError(
"Starting point ({}, {}) outside of data "
"boundaries".format(xs, ys)
)

if use_multicolor_lines:
if norm is None:
norm = mcolors.Normalize(color.min(), color.max())
if cmap is None:
cmap = constructor.Colormap(rc["image.cmap"])
else:
cmap = mcm.get_cmap(cmap)

# Convert start_points from data to array coords
# Shift the seed points from the bottom left of the data so that
# data2grid works properly.
sp2[:, 0] -= solver.grid.x_origin
sp2[:, 1] -= solver.grid.y_origin

for xs, ys in sp2:
xg, yg = solver.domain_map.data2grid(xs, ys)
t = integrate(xg, yg)
if t is not None:
trajectories.append(t[0])
edges.append(t[1])
streamlines = []
arrows = []
for t, edge in zip(trajectories, edges):
tgx = np.array(t[0])
tgy = np.array(t[1])

# Rescale from grid-coordinates to data-coordinates.
tx, ty = solver.domain_map.grid2data(*np.array(t))
tx += solver.grid.x_origin
ty += solver.grid.y_origin

points = np.transpose([tx, ty]).reshape(-1, 1, 2)
streamlines.extend(np.hstack([points[:-1], points[1:]]))

if len(tx) < 2:
continue

# Add arrows
s = np.cumsum(np.sqrt(np.diff(tx) ** 2 + np.diff(ty) ** 2))
if arrow_at_end:
if len(tx) < 2:
continue

arrow_tail = (tx[-1], ty[-1])

# Extrapolate to find arrow head
xg, yg = solver.domain_map.data2grid(
tx[-1] - solver.grid.x_origin, ty[-1] - solver.grid.y_origin
)

ui = solver.interpgrid(u, xg, yg)
vi = solver.interpgrid(v, xg, yg)

norm_v = np.sqrt(ui**2 + vi**2)
if norm_v > 0:
ui /= norm_v
vi /= norm_v

if len(s) > 0:
# use average segment length
arrow_length = arrowsize * (s[-1] / len(s))
else:
# fallback for very short streamlines
arrow_length = (
arrowsize * 0.1 * np.mean([solver.grid.dx, solver.grid.dy])
)

arrow_head = (tx[-1] + ui * arrow_length, ty[-1] + vi * arrow_length)
n = len(s) - 1 if len(s) > 0 else 0
else:
n = np.searchsorted(s, s[-1] / 2.0)
arrow_tail = (tx[n], ty[n])
arrow_head = (np.mean(tx[n : n + 2]), np.mean(ty[n : n + 2]))

if isinstance(linewidth, np.ndarray):
line_widths = solver.interpgrid(linewidth, tgx, tgy)[:-1]
line_kw["linewidth"].extend(line_widths)
arrow_kw["linewidth"] = line_widths[n]

if use_multicolor_lines:
color_values = solver.interpgrid(color, tgx, tgy)[:-1]
line_colors.append(color_values)
arrow_kw["color"] = cmap(norm(color_values[n]))

if not edge:
p = mpatches.FancyArrowPatch(
arrow_tail, arrow_head, transform=transform, **arrow_kw
)
else:
continue

ds = np.sqrt(
(arrow_tail[0] - arrow_head[0]) ** 2
+ (arrow_tail[1] - arrow_head[1]) ** 2
)
if ds < 1e-15:
continue # remove vanishingly short arrows that cause Patch to fail

self.add_patch(p)
arrows.append(p)

lc = mcollections.LineCollection(streamlines, transform=transform, **line_kw)
lc.sticky_edges.x[:] = [
solver.grid.x_origin,
solver.grid.x_origin + solver.grid.width,
]
lc.sticky_edges.y[:] = [
solver.grid.y_origin,
solver.grid.y_origin + solver.grid.height,
]

if use_multicolor_lines:
lc.set_array(np.ma.hstack(line_colors))
lc.set_cmap(cmap)
lc.set_norm(norm)

self.add_collection(lc)
self.autoscale_view()

ac = mcollections.PatchCollection(arrows)
stream_container = CurvedQuiverSet(lc, ac)
return stream_container

def _call_native(self, name, *args, **kwargs):
"""
Expand Down Expand Up @@ -5359,6 +5612,7 @@ def tripcolor(self, *args, **kwargs):

# Update kwargs and handle cmap
kw.update(_pop_props(kw, "collection"))

center_levels = kw.pop("center_levels", None)
kw = self._parse_cmap(
triangulation.x, triangulation.y, z, center_levels=center_levels, **kw
Expand Down
Empty file.
Loading