diff --git a/ultraplot/axes/plot.py b/ultraplot/axes/plot.py index 9cbc8905..49b0687c 100644 --- a/ultraplot/axes/plot.py +++ b/ultraplot/axes/plot.py @@ -12,7 +12,6 @@ from typing import Any, Union, Iterable, Optional -from typing import Any, Union from collections.abc import Callable from collections.abc import Iterable @@ -33,6 +32,7 @@ import matplotlib as mpl from packaging import version import numpy as np +from typing import Optional, Union, Any import numpy.ma as ma from .. import colors as pcolors @@ -170,7 +170,45 @@ docstring._snippet_manager["plot.args_1d_shared"] = _args_1d_shared_docstring docstring._snippet_manager["plot.args_2d_shared"] = _args_2d_shared_docstring +_curved_quiver_docstring = """ +Draws curved vector field arrows (streamlines with arrows) for 2D vector fields. +Parameters +---------- +x, y : 1D or 2D arrays + Grid coordinates. +u, v : 2D arrays + Vector components. +color : color or 2D array, optional + Streamline color. +density : float or (float, float), optional + Controls the closeness of streamlines. +grains : int or (int, int), optional + Number of seed points in x and y. +linewidth : float or 2D array, optional + Width of streamlines. +cmap, norm : optional + Colormap and normalization for array colors. +arrowsize : float, optional + Arrow size scaling. +arrowstyle : str, optional + Arrow style specification. +transform : optional + Matplotlib transform. +zorder : float, optional + Z-order for lines/arrows. +start_points : (N, 2) array, optional + Starting points for streamlines. + +Returns +------- +CurvedQuiverSet + Container with attributes: + - lines: LineCollection of streamlines + - arrows: PatchCollection of arrows +""" + +docstring._snippet_manager["plot.curved_quiver"] = _curved_quiver_docstring # Auto colorbar and legend docstring _guide_docstring = """ colorbar : bool, int, or str, optional @@ -1499,22 +1537,237 @@ class PlotAxes(base.Axes): Implements all plotting overrides. """ - def __init__(self, *args, **kwargs): + @docstring._snippet_manager + def curved_quiver( + self, + x: np.ndarray, + y: np.ndarray, + u: np.ndarray, + v: np.ndarray, + linewidth: Optional[float] = None, + color: Optional[Union[str, Any]] = None, + cmap: Optional[Any] = None, + norm: Optional[Any] = None, + arrowsize: Optional[float] = None, + arrowstyle: Optional[str] = None, + transform: Optional[Any] = None, + zorder: Optional[int] = None, + start_points: Optional[np.ndarray] = None, + scale: Optional[float] = None, + grains: Optional[int] = None, + density: Optional[int] = None, + arrow_at_end: Optional[bool] = None, + ): """ - Parameters - ---------- - *args, **kwargs - Passed to `ultraplot.axes.Axes`. + %(plot.curved_quiver)s + + Notes + ----- + The implementation of this function is based on the `dfm_tools` repository. + Original file: https://github.com/Deltares/dfm_tools/blob/829e76f48ebc42460aae118cc190147a595a5f26/dfm_tools/modplot.py + """ + from .plot_types.curved_quiver import CurvedQuiverSolver, CurvedQuiverSet + + # Parse inputs + arrowsize = _not_none(arrowsize, rc["curved_quiver.arrowsize"]) + arrowstyle = _not_none(arrowstyle, rc["curved_quiver.arrowstyle"]) + zorder = _not_none(zorder, mlines.Line2D.zorder) + transform = _not_none(transform, self.transData) + color = _not_none(color, self._get_lines.get_next_color()) + linewidth = _not_none(linewidth, rc["lines.linewidth"]) + scale = _not_none(scale, rc["curved_quiver.scale"]) + grains = _not_none(grains, rc["curved_quiver.grains"]) + density = _not_none(density, rc["curved_quiver.density"]) + arrows_at_end = _not_none(arrow_at_end, rc["curved_quiver.arrows_at_end"]) + + solver = CurvedQuiverSolver(x, y, density) + if zorder is None: + zorder = mlines.Line2D.zorder + + line_kw = {} + arrow_kw = dict(arrowstyle=arrowstyle, mutation_scale=10 * arrowsize) + + use_multicolor_lines = isinstance(color, np.ndarray) + if use_multicolor_lines: + if color.shape != solver.grid.shape: + raise ValueError( + "If 'color' is given, must have the shape of 'Grid(x,y)'" + ) + line_colors = [] + color = np.ma.masked_invalid(color) + else: + line_kw["color"] = color + arrow_kw["color"] = color - See also - -------- - matplotlib.axes.Axes - ultraplot.axes.Axes - ultraplot.axes.CartesianAxes - ultraplot.axes.PolarAxes - ultraplot.axes.GeoAxes - """ - super().__init__(*args, **kwargs) + if isinstance(linewidth, np.ndarray): + if linewidth.shape != solver.grid.shape: + raise ValueError( + "If 'linewidth' is given, must have the shape of 'Grid(x,y)'" + ) + line_kw["linewidth"] = [] + else: + line_kw["linewidth"] = linewidth + arrow_kw["linewidth"] = linewidth + + line_kw["zorder"] = zorder + arrow_kw["zorder"] = zorder + + ## Sanity checks. + if u.shape != solver.grid.shape or v.shape != solver.grid.shape: + raise ValueError("'u' and 'v' must be of shape 'Grid(x,y)'") + + u = np.ma.masked_invalid(u) + v = np.ma.masked_invalid(v) + magnitude = np.sqrt(u**2 + v**2) + magnitude /= np.max(magnitude) + + resolution = scale / grains + minlength = 0.9 * resolution + + integrate = solver.get_integrator(u, v, minlength, resolution, magnitude) + trajectories = [] + edges = [] + + if start_points is None: + start_points = solver.gen_starting_points(x, y, grains) + + sp2 = np.asanyarray(start_points, dtype=float).copy() + + # Check if start_points are outside the data boundaries + for xs, ys in sp2: + if not ( + solver.grid.x_origin <= xs <= solver.grid.x_origin + solver.grid.width + and solver.grid.y_origin + <= ys + <= solver.grid.y_origin + solver.grid.height + ): + raise ValueError( + "Starting point ({}, {}) outside of data " + "boundaries".format(xs, ys) + ) + + if use_multicolor_lines: + if norm is None: + norm = mcolors.Normalize(color.min(), color.max()) + if cmap is None: + cmap = constructor.Colormap(rc["image.cmap"]) + else: + cmap = mcm.get_cmap(cmap) + + # Convert start_points from data to array coords + # Shift the seed points from the bottom left of the data so that + # data2grid works properly. + sp2[:, 0] -= solver.grid.x_origin + sp2[:, 1] -= solver.grid.y_origin + + for xs, ys in sp2: + xg, yg = solver.domain_map.data2grid(xs, ys) + t = integrate(xg, yg) + if t is not None: + trajectories.append(t[0]) + edges.append(t[1]) + streamlines = [] + arrows = [] + for t, edge in zip(trajectories, edges): + tgx = np.array(t[0]) + tgy = np.array(t[1]) + + # Rescale from grid-coordinates to data-coordinates. + tx, ty = solver.domain_map.grid2data(*np.array(t)) + tx += solver.grid.x_origin + ty += solver.grid.y_origin + + points = np.transpose([tx, ty]).reshape(-1, 1, 2) + streamlines.extend(np.hstack([points[:-1], points[1:]])) + + if len(tx) < 2: + continue + + # Add arrows + s = np.cumsum(np.sqrt(np.diff(tx) ** 2 + np.diff(ty) ** 2)) + if arrow_at_end: + if len(tx) < 2: + continue + + arrow_tail = (tx[-1], ty[-1]) + + # Extrapolate to find arrow head + xg, yg = solver.domain_map.data2grid( + tx[-1] - solver.grid.x_origin, ty[-1] - solver.grid.y_origin + ) + + ui = solver.interpgrid(u, xg, yg) + vi = solver.interpgrid(v, xg, yg) + + norm_v = np.sqrt(ui**2 + vi**2) + if norm_v > 0: + ui /= norm_v + vi /= norm_v + + if len(s) > 0: + # use average segment length + arrow_length = arrowsize * (s[-1] / len(s)) + else: + # fallback for very short streamlines + arrow_length = ( + arrowsize * 0.1 * np.mean([solver.grid.dx, solver.grid.dy]) + ) + + arrow_head = (tx[-1] + ui * arrow_length, ty[-1] + vi * arrow_length) + n = len(s) - 1 if len(s) > 0 else 0 + else: + n = np.searchsorted(s, s[-1] / 2.0) + arrow_tail = (tx[n], ty[n]) + arrow_head = (np.mean(tx[n : n + 2]), np.mean(ty[n : n + 2])) + + if isinstance(linewidth, np.ndarray): + line_widths = solver.interpgrid(linewidth, tgx, tgy)[:-1] + line_kw["linewidth"].extend(line_widths) + arrow_kw["linewidth"] = line_widths[n] + + if use_multicolor_lines: + color_values = solver.interpgrid(color, tgx, tgy)[:-1] + line_colors.append(color_values) + arrow_kw["color"] = cmap(norm(color_values[n])) + + if not edge: + p = mpatches.FancyArrowPatch( + arrow_tail, arrow_head, transform=transform, **arrow_kw + ) + else: + continue + + ds = np.sqrt( + (arrow_tail[0] - arrow_head[0]) ** 2 + + (arrow_tail[1] - arrow_head[1]) ** 2 + ) + if ds < 1e-15: + continue # remove vanishingly short arrows that cause Patch to fail + + self.add_patch(p) + arrows.append(p) + + lc = mcollections.LineCollection(streamlines, transform=transform, **line_kw) + lc.sticky_edges.x[:] = [ + solver.grid.x_origin, + solver.grid.x_origin + solver.grid.width, + ] + lc.sticky_edges.y[:] = [ + solver.grid.y_origin, + solver.grid.y_origin + solver.grid.height, + ] + + if use_multicolor_lines: + lc.set_array(np.ma.hstack(line_colors)) + lc.set_cmap(cmap) + lc.set_norm(norm) + + self.add_collection(lc) + self.autoscale_view() + + ac = mcollections.PatchCollection(arrows) + stream_container = CurvedQuiverSet(lc, ac) + return stream_container def _call_native(self, name, *args, **kwargs): """ @@ -5359,6 +5612,7 @@ def tripcolor(self, *args, **kwargs): # Update kwargs and handle cmap kw.update(_pop_props(kw, "collection")) + center_levels = kw.pop("center_levels", None) kw = self._parse_cmap( triangulation.x, triangulation.y, z, center_levels=center_levels, **kw diff --git a/ultraplot/axes/plot_types/__init__.py b/ultraplot/axes/plot_types/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/ultraplot/axes/plot_types/curved_quiver.py b/ultraplot/axes/plot_types/curved_quiver.py new file mode 100644 index 00000000..5489a0a9 --- /dev/null +++ b/ultraplot/axes/plot_types/curved_quiver.py @@ -0,0 +1,413 @@ +# The following helper classes and functions for curved_quiver are based on the +# work in the `dfm_tools` repository. +# Original file: https://github.com/Deltares/dfm_tools/blob/829e76f48ebc42460aae118cc190147a595a5f26/dfm_tools/modplot.py +# Special thanks to @veenstrajelmer for the initial implementation + +__all__ = [ + "CurvedQuiverSolver", + "CurvedQuiverSet", +] + +from typing import Callable +from dataclasses import dataclass +from matplotlib.streamplot import StreamplotSet +import numpy as np + + +@dataclass +class CurvedQuiverSet(StreamplotSet): + lines: object + arrows: object + + +class _DomainMap(object): + """Map representing different coordinate systems. + + Coordinate definitions: + * axes-coordinates goes from 0 to 1 in the domain. + * data-coordinates are specified by the input x-y coordinates. + * grid-coordinates goes from 0 to N and 0 to M for an N x M grid, + where N and M match the shape of the input data. + * mask-coordinates goes from 0 to N and 0 to M for an N x M mask, + where N and M are user-specified to control the density of + streamlines. + + This class also has methods for adding trajectories to the + StreamMask. Before adding a trajectory, run `start_trajectory` to + keep track of regions crossed by a given trajectory. Later, if you + decide the trajectory is bad (e.g., if the trajectory is very + short) just call `undo_trajectory`. + """ + + def __init__(self, grid: "Grid", mask: "StreamMask") -> None: + self.grid = grid + self.mask = mask + + # Constants for conversion between grid- and mask-coordinates + self.x_grid2mask = (mask.nx - 1) / grid.nx + self.y_grid2mask = (mask.ny - 1) / grid.ny + self.x_mask2grid = 1.0 / self.x_grid2mask + self.y_mask2grid = 1.0 / self.y_grid2mask + + self.x_data2grid = 1.0 / grid.dx + self.y_data2grid = 1.0 / grid.dy + + def grid2mask(self, xi: float, yi: float) -> tuple[int, int]: + """Return nearest space in mask-coords from given grid-coords.""" + return ( + int((xi * self.x_grid2mask) + 0.5), + int((yi * self.y_grid2mask) + 0.5), + ) + + def mask2grid(self, xm: int, ym: int) -> tuple[float, float]: + return xm * self.x_mask2grid, ym * self.y_mask2grid + + def data2grid(self, xd: float, yd: float) -> tuple[float, float]: + return xd * self.x_data2grid, yd * self.y_data2grid + + def grid2data(self, xg: float, yg: float) -> tuple[float, float]: + return xg / self.x_data2grid, yg / self.y_data2grid + + def start_trajectory(self, xg: float, yg: float) -> None: + xm, ym = self.grid2mask(xg, yg) + self.mask._start_trajectory(xm, ym) + + def reset_start_point(self, xg: float, yg: float) -> None: + xm, ym = self.grid2mask(xg, yg) + self.mask._current_xy = (xm, ym) + + def update_trajectory(self, xg: float, yg: float) -> None: + xm, ym = self.grid2mask(xg, yg) + self.mask._update_trajectory(xm, ym) + + def undo_trajectory(self) -> None: + self.mask._undo_trajectory() + + +class _CurvedQuiverGrid(object): + """Grid of data.""" + + def __init__(self, x: np.ndarray, y: np.ndarray) -> None: + if x.ndim == 1: + pass + elif x.ndim == 2: + x_row = x[0, :] + if not np.allclose(x_row, x): + raise ValueError("The rows of 'x' must be equal") + x = x_row + else: + raise ValueError("'x' can have at maximum 2 dimensions") + + if y.ndim == 1: + pass + elif y.ndim == 2: + y_col = y[:, 0] + if not np.allclose(y_col, y.T): + raise ValueError("The columns of 'y' must be equal") + y = y_col + else: + raise ValueError("'y' can have at maximum 2 dimensions") + + self.nx = len(x) + self.ny = len(y) + self.dx = x[1] - x[0] + self.dy = y[1] - y[0] + self.x_origin = x[0] + self.y_origin = y[0] + self.width = x[-1] - x[0] + self.height = y[-1] - y[0] + + @property + def shape(self) -> tuple[int, int]: + return self.ny, self.nx + + def within_grid(self, xi: float, yi: float) -> bool: + """Return True if point is a valid index of grid.""" + # Note that xi/yi can be floats; so, for example, we can't simply check + # `xi < self.nx` since `xi` can be `self.nx - 1 < xi < self.nx` + return xi >= 0 and xi <= self.nx - 1 and yi >= 0 and yi <= self.ny - 1 + + +class _StreamMask(object): + """Mask to keep track of discrete regions crossed by streamlines. + + The resolution of this grid determines the approximate spacing + between trajectories. Streamlines are only allowed to pass through + zeroed cells: When a streamline enters a cell, that cell is set to + 1, and no new streamlines are allowed to enter. + """ + + def __init__(self, density: float | int): + if np.isscalar(density): + if density <= 0: + raise ValueError("If a scalar, 'density' must be positive") + self.nx = self.ny = int(30 * density) + else: + if len(density) != 2: + raise ValueError("'density' can have at maximum 2 dimensions") + self.nx = int(30 * density[0]) + self.ny = int(30 * density[1]) + + self._mask = np.zeros((self.ny, self.nx)) + self.shape = self._mask.shape + self._current_xy = None + + def __getitem__(self, *args): + return self._mask.__getitem__(*args) + + def _start_trajectory(self, xm: int, ym: int): + """Start recording streamline trajectory""" + self._traj = [] + self._update_trajectory(xm, ym) + + def _undo_trajectory(self): + """Remove current trajectory from mask""" + for t in self._traj: + self._mask.__setitem__(t, 0) + + def _update_trajectory(self, xm: int, ym: int) -> None: + """Update current trajectory position in mask. + + If the new position has already been filled, raise + `InvalidIndexError`. + """ + + self._traj.append((ym, xm)) + self._mask[ym, xm] = 1 + self._current_xy = (xm, ym) + + +class _CurvedQuiverTerminateTrajectory(Exception): + pass + + +class CurvedQuiverSolver: + + def __init__( + self, x: np.ndarray, y: np.ndarray, density: float | tuple[float, float] + ) -> None: + self.grid = _CurvedQuiverGrid(x, y) + self.mask = _StreamMask(density) + self.domain_map = _DomainMap(self.grid, self.mask) + + def get_integrator( + self, + u: np.ndarray, + v: np.ndarray, + minlength: float, + resolution: float, + magnitude: np.ndarray, + ) -> Callable[[float, float], tuple[tuple[list[float], list[float]], bool] | None]: + # rescale velocity onto grid-coordinates for integrations. + u, v = self.domain_map.data2grid(u, v) + + # speed (path length) will be in axes-coordinates + u_ax = u / self.domain_map.grid.nx + v_ax = v / self.domain_map.grid.ny + speed = np.ma.sqrt(u_ax**2 + v_ax**2) + + def forward_time(xi: float, yi: float) -> tuple[float, float]: + ds_dt = self.interpgrid(speed, xi, yi) + if ds_dt == 0: + raise _CurvedQuiverTerminateTrajectory() + dt_ds = 1.0 / ds_dt + ui = self.interpgrid(u, xi, yi) + vi = self.interpgrid(v, xi, yi) + return ui * dt_ds, vi * dt_ds + + def integrate( + x0: float, y0: float + ) -> tuple[tuple[list[float], list[float], bool]] | None: + """Return x, y grid-coordinates of trajectory based on starting point. + + Integrate both forward and backward in time from starting point + in grid coordinates. Integration is terminated when a trajectory + reaches a domain boundary or when it crosses into an already + occupied cell in the StreamMask. The resulting trajectory is + None if it is shorter than `minlength`. + """ + stotal, x_traj, y_traj = 0.0, [], [] + self.domain_map.start_trajectory(x0, y0) + self.domain_map.reset_start_point(x0, y0) + stotal, x_traj, y_traj, m_total, hit_edge = self.integrate_rk12( + x0, y0, forward_time, resolution, magnitude + ) + + if len(x_traj) > 1: + return (x_traj, y_traj), hit_edge + else: + # reject short trajectories + self.domain_map.undo_trajectory() + return None + + return integrate + + def integrate_rk12( + self, + x0: float, + y0: float, + f: Callable[[float, float], tuple[float, float]], + resolution: float, + magnitude: np.ndarray, + ) -> tuple[float, list[float], list[float], list[float], bool]: + """2nd-order Runge-Kutta algorithm with adaptive step size. + + This method is also referred to as the improved Euler's method, or + Heun's method. This method is favored over higher-order methods + because: + + 1. To get decent looking trajectories and to sample every mask cell + on the trajectory we need a small timestep, so a lower order + solver doesn't hurt us unless the data is *very* high + resolution. In fact, for cases where the user inputs data + smaller or of similar grid size to the mask grid, the higher + order corrections are negligible because of the very fast linear + interpolation used in `interpgrid`. + + 2. For high resolution input data (i.e. beyond the mask + resolution), we must reduce the timestep. Therefore, an + adaptive timestep is more suited to the problem as this would be + very hard to judge automatically otherwise. + + This integrator is about 1.5 - 2x as fast as both the RK4 and RK45 + solvers in most setups on my machine. I would recommend removing + the other two to keep things simple. + """ + # This error is below that needed to match the RK4 integrator. It + # is set for visual reasons -- too low and corners start + # appearing ugly and jagged. Can be tuned. + maxerror = 0.003 + + # This limit is important (for all integrators) to avoid the + # trajectory skipping some mask cells. We could relax this + # condition if we use the code which is commented out below to + # increment the location gradually. However, due to the efficient + # nature of the interpolation, this doesn't boost speed by much + # for quite a bit of complexity. + maxds = min(1.0 / self.domain_map.mask.nx, 1.0 / self.domain_map.mask.ny, 0.1) + ds = maxds + + stotal = 0 + xi = x0 + yi = y0 + xf_traj = [] + yf_traj = [] + m_total = [] + hit_edge = False + + while self.domain_map.grid.within_grid(xi, yi): + xf_traj.append(xi) + yf_traj.append(yi) + m_total.append(self.interpgrid(magnitude, xi, yi)) + + try: + k1x, k1y = f(xi, yi) + k2x, k2y = f(xi + ds * k1x, yi + ds * k1y) + except IndexError: + # Out of the domain on one of the intermediate integration steps. + # Take an Euler step to the boundary to improve neatness. + ds, xf_traj, yf_traj = self.euler_step(xf_traj, yf_traj, f) + stotal += ds + hit_edge = True + break + except _CurvedQuiverTerminateTrajectory: + break + + dx1 = ds * k1x + dy1 = ds * k1y + dx2 = ds * 0.5 * (k1x + k2x) + dy2 = ds * 0.5 * (k1y + k2y) + + nx, ny = self.domain_map.grid.shape + # Error is normalized to the axes coordinates + error = np.sqrt(((dx2 - dx1) / nx) ** 2 + ((dy2 - dy1) / ny) ** 2) + + # Only save step if within error tolerance + if error < maxerror: + xi += dx2 + yi += dy2 + self.domain_map.update_trajectory(xi, yi) + if not self.domain_map.grid.within_grid(xi, yi): + hit_edge = True + if (stotal + ds) > resolution * np.mean(m_total): + break + stotal += ds + + # recalculate stepsize based on step error + if error == 0: + ds = maxds + else: + ds = min(maxds, 0.85 * ds * (maxerror / error) ** 0.5) + + return stotal, xf_traj, yf_traj, m_total, hit_edge + + def euler_step(self, xf_traj, yf_traj, f): + """Simple Euler integration step that extends streamline to boundary.""" + ny, nx = self.domain_map.grid.shape + xi = xf_traj[-1] + yi = yf_traj[-1] + cx, cy = f(xi, yi) + + if cx == 0: + dsx = np.inf + elif cx < 0: + dsx = xi / -cx + else: + dsx = (nx - 1 - xi) / cx + + if cy == 0: + dsy = np.inf + elif cy < 0: + dsy = yi / -cy + else: + dsy = (ny - 1 - yi) / cy + + ds = min(dsx, dsy) + + xf_traj.append(xi + cx * ds) + yf_traj.append(yi + cy * ds) + + return ds, xf_traj, yf_traj + + def interpgrid(self, a, xi, yi): + """Fast 2D, linear interpolation on an integer grid""" + Ny, Nx = np.shape(a) + + if isinstance(xi, np.ndarray): + x = xi.astype(int) + y = yi.astype(int) + + # Check that xn, yn don't exceed max index + xn = np.clip(x + 1, 0, Nx - 1) + yn = np.clip(y + 1, 0, Ny - 1) + else: + x = int(xi) + y = int(yi) + xn = min(x + 1, Nx - 1) + yn = min(y + 1, Ny - 1) + + a00 = a[y, x] + a01 = a[y, xn] + a10 = a[yn, x] + a11 = a[yn, xn] + + xt = xi - x + yt = yi - y + + a0 = a00 * (1 - xt) + a01 * xt + a1 = a10 * (1 - xt) + a11 * xt + ai = a0 * (1 - yt) + a1 * yt + + if not isinstance(xi, np.ndarray): + if np.ma.is_masked(ai): + raise _CurvedQuiverTerminateTrajectory + return ai + + def gen_starting_points(self, x, y, grains): + eps = np.finfo(np.float32).eps + tmp_x = np.linspace(x.min() + eps, x.max() - eps, grains) + tmp_y = np.linspace(y.min() + eps, y.max() - eps, grains) + xs = np.tile(tmp_x, grains) + ys = np.repeat(tmp_y, grains) + seed_points = np.array([list(xs), list(ys)]) + return seed_points.T diff --git a/ultraplot/internals/rcsetup.py b/ultraplot/internals/rcsetup.py index 143fa237..bdb9a88d 100644 --- a/ultraplot/internals/rcsetup.py +++ b/ultraplot/internals/rcsetup.py @@ -890,6 +890,37 @@ def copy(self): "interpreted by `~ultraplot.utils.units`. Numeric units are points." ) _rc_ultraplot_table = { + # Curved quiver settings + "curved_quiver.arrowsize": ( + 1.0, + _validate_float, + "Default size scaling for arrows in curved quiver plots.", + ), + "curved_quiver.arrowstyle": ( + "-|>", + _validate_string, + "Default arrow style for curved quiver plots.", + ), + "curved_quiver.scale": ( + 1.0, + _validate_float, + "Default scale factor for curved quiver plots.", + ), + "curved_quiver.grains": ( + 15, + _validate_int, + "Default number of grains (segments) for curved quiver arrows.", + ), + "curved_quiver.density": ( + 10, + _validate_int, + "Default density of arrows for curved quiver plots.", + ), + "curved_quiver.arrows_at_end": ( + True, + _validate_bool, + "Whether to draw arrows at the end of curved quiver lines by default.", + ), # Stylesheet "style": ( None, diff --git a/ultraplot/tests/test_plot.py b/ultraplot/tests/test_plot.py index 9f43a144..e1662dd5 100644 --- a/ultraplot/tests/test_plot.py +++ b/ultraplot/tests/test_plot.py @@ -447,3 +447,143 @@ def test_inhomogeneous_violin(rng): for violin in violins: assert violin.get_paths() # Ensure paths are created return fig + + +@pytest.mark.mpl_image_compare +def test_curved_quiver(rng): + # Create a grid + x = np.linspace(-4, 4, 20) + y = np.linspace(-3, 3, 20) + X, Y = np.meshgrid(x, y) + + # Define a rotational vector field (circular flow) + U = -Y + V = X + speed = np.sqrt(U**2 + V**2) + + # Create a figure and axes + fig, axs = uplt.subplots(ncols=3, sharey=True, figsize=(12, 4)) + + # Left plot: matplotlib's streamplot + axs[0].streamplot(X, Y, U, V, color=speed) + axs[0].set_title("streamplot (native)") + + # Middle plot: quiver + axs[1].quiver(X, Y, U, V, speed) + axs[1].set_title("quiver") + + # Right plot: curved_quiver + m = axs[2].curved_quiver( + X, Y, U, V, color=speed, arrow_at_end=True, scale=2.0, grains=10 + ) + axs[2].set_title("curved_quiver") + fig.colorbar(m.lines, ax=axs[:], label="speed") + return fig + + +def test_validate_vector_shapes_pass(): + """ + Test that vector shapes match the grid shape using CurvedQuiverSolver. + """ + from ultraplot.axes.plot_types.curved_quiver import _CurvedQuiverGrid + + x = np.linspace(0, 1, 3) + y = np.linspace(0, 1, 3) + grid = _CurvedQuiverGrid(x, y) + u = np.ones(grid.shape) + v = np.ones(grid.shape) + assert u.shape == grid.shape + assert v.shape == grid.shape + + +def test_validate_vector_shapes_fail(): + """ + Test that assertion fails when u and v do not match the grid shape using CurvedQuiverSolver. + """ + from ultraplot.axes.plot_types.curved_quiver import ( + CurvedQuiverSolver, + _CurvedQuiverGrid, + ) + + x = np.linspace(0, 1, 3) + y = np.linspace(0, 1, 3) + grid = _CurvedQuiverGrid(x, y) + u = np.ones((2, 2)) + v = np.ones(grid.shape) + with pytest.raises(AssertionError): + assert u.shape == grid.shape + + +def test_normalize_magnitude(): + """ + Test that magnitude normalization returns a normalized array with max value 1.0 and correct shape. + """ + u = np.array([[1, 2], [3, 4]]) + v = np.array([[4, 3], [2, 1]]) + mag = np.sqrt(u**2 + v**2) + mag_norm = mag / np.max(mag) + assert np.allclose(np.max(mag_norm), 1.0) + assert mag_norm.shape == u.shape + + +def test_generate_start_points(): + """ + Test that CurvedQuiverSolver.gen_starting_points returns valid grid coordinates for seed points, + and that grid.within_grid detects points outside the grid boundaries. + """ + from ultraplot.axes.plot_types.curved_quiver import CurvedQuiverSolver + + x = np.linspace(0, 1, 5) + y = np.linspace(0, 1, 5) + grains = 5 + solver = CurvedQuiverSolver(x, y, density=5) + sp2 = solver.gen_starting_points(x, y, grains) + assert sp2.shape[1] == 2 + # Should detect if outside boundaries + bad_points = np.array([[10, 10]]) + grid = solver.grid + for pt in bad_points: + assert not grid.within_grid(pt[0], pt[1]) + + +def test_calculate_trajectories(): + """ + Test that CurvedQuiverSolver.get_integrator returns callable for each seed point + and returns lists of trajectories and edges of correct length. + """ + from ultraplot.axes.plot_types.curved_quiver import CurvedQuiverSolver + + x = np.linspace(0, 1, 5) + y = np.linspace(0, 1, 5) + u = np.ones((5, 5)) + v = np.ones((5, 5)) + mag = np.sqrt(u**2 + v**2) + solver = CurvedQuiverSolver(x, y, density=5) + integrator = solver.get_integrator( + u, v, minlength=0.1, resolution=1.0, magnitude=mag + ) + seeds = solver.gen_starting_points(x, y, grains=2) + results = [integrator(pt[0], pt[1]) for pt in seeds] + assert len(results) == seeds.shape[0] + + +@pytest.mark.mpl_image_compare +def test_curved_quiver_multicolor_lines(): + """ + Test that curved_quiver handles color arrays and returns a lines object. + """ + x = np.linspace(0, 1, 5) + y = np.linspace(0, 1, 5) + X, Y = np.meshgrid(x, y) + U = np.ones_like(X) + V = np.ones_like(Y) + speed = np.sqrt(U**2 + V**2) + + fig, ax = uplt.subplots() + m = ax.curved_quiver(X, Y, U, V, color=speed) + from matplotlib.collections import LineCollection + + assert isinstance(m.lines, LineCollection) + assert m.lines.get_array().size > 0 # we have colors set + assert m.lines.get_cmap() is not None + return fig