From 4f96b60e35db97d018be504a21d5de5fd11b6988 Mon Sep 17 00:00:00 2001 From: lachlangrose Date: Thu, 26 Feb 2026 13:21:14 +1100 Subject: [PATCH 1/6] fix: adding downhole plots for drillhole/dbs --- examples/plot_downhole.py | 69 ++++++ loopresources/drillhole/drillhole.py | 186 ++++++++++++++ loopresources/drillhole/drillhole_database.py | 228 ++++++++++++++++++ 3 files changed, 483 insertions(+) create mode 100644 examples/plot_downhole.py diff --git a/examples/plot_downhole.py b/examples/plot_downhole.py new file mode 100644 index 0000000..54dac18 --- /dev/null +++ b/examples/plot_downhole.py @@ -0,0 +1,69 @@ +""" +Downhole plotting +================ + +Example showing downhole line, categorical, and image plots. +""" + +import matplotlib.pyplot as plt +import pandas as pd + +from loopresources.drillhole.dhconfig import DhConfig +from loopresources.drillhole.drillhole_database import DrillholeDatabase + +collar = pd.DataFrame( + { + DhConfig.holeid: ["DH001", "DH002", "DH003"], + DhConfig.x: [100.0, 200.0, 300.0], + DhConfig.y: [1000.0, 2000.0, 3000.0], + DhConfig.z: [50.0, 60.0, 70.0], + DhConfig.total_depth: [150.0, 200.0, 180.0], + } +) + +survey = pd.DataFrame( + { + DhConfig.holeid: ["DH001", "DH001", "DH002", "DH002", "DH003"], + DhConfig.depth: [0.0, 100.0, 0.0, 120.0, 0.0], + DhConfig.azimuth: [0.0, 0.0, 45.0, 45.0, 90.0], + DhConfig.dip: [-90.0, -90.0, -85.0, -80.0, -90.0], + } +) + +db = DrillholeDatabase(collar, survey) + +lithology = pd.DataFrame( + { + DhConfig.holeid: ["DH001", "DH001", "DH002", "DH002", "DH003"], + DhConfig.sample_from: [0.0, 50.0, 0.0, 80.0, 0.0], + DhConfig.sample_to: [50.0, 150.0, 80.0, 200.0, 180.0], + "LITHO": ["Granite", "Schist", "Sandstone", "Shale", "Limestone"], + } +) + +assays = pd.DataFrame( + { + DhConfig.holeid: ["DH001", "DH001", "DH002", "DH002", "DH003"], + DhConfig.sample_from: [0.0, 75.0, 0.0, 100.0, 0.0], + DhConfig.sample_to: [75.0, 150.0, 100.0, 200.0, 180.0], + "AU_ppm": [0.1, 2.5, 0.05, 1.2, 0.4], + } +) + +db.add_interval_table("lithology", lithology) +db.add_interval_table("assays", assays) + +# Line plot (numeric values) +db.plot_downhole("assays", "AU_ppm", kind="line", layout="grid", ncols=2) +plt.tight_layout() +plt.show() + +# Categorical plot +db.plot_downhole("lithology", "LITHO", kind="categorical", step=2.0, layout="grid", ncols=2) +plt.tight_layout() +plt.show() + +# Image plot (numeric heatmap) +db.plot_downhole("assays", "AU_ppm", kind="image", step=2.0, layout="grid", ncols=2) +plt.tight_layout() +plt.show() diff --git a/loopresources/drillhole/drillhole.py b/loopresources/drillhole/drillhole.py index 53846b4..90d29f2 100644 --- a/loopresources/drillhole/drillhole.py +++ b/loopresources/drillhole/drillhole.py @@ -387,6 +387,192 @@ def trace(self, step: float = 1.0) -> DrillHoleTrace: """ return DrillHoleTrace(self, interval=step) + def _downhole_depth_grid( + self, step: float, max_depth: Optional[float] = None + ) -> np.ndarray: + if step <= 0: + raise ValueError("step must be > 0") + if max_depth is None: + max_depth = float(self.collar[DhConfig.total_depth].values[0]) + if max_depth <= 0: + return np.array([], dtype=float) + return np.arange(0.0, max_depth + step, step) + + def _sample_downhole_values( + self, + table_name: str, + column: str, + step: float, + kind: str, + depth_grid: Optional[np.ndarray] = None, + ): + if table_name in self.database.intervals: + table = self[table_name] + if table.empty: + return np.array([], dtype=float), np.array([], dtype=float) + if column not in table.columns: + raise KeyError(f"Column '{column}' not found in interval table '{table_name}'") + grid = depth_grid if depth_grid is not None else self._downhole_depth_grid(step) + from .resample import resample_interval + + sampled = resample_interval( + pd.DataFrame({DhConfig.depth: grid}), table, [column], method="direct" + ) + return sampled[DhConfig.depth].to_numpy(), sampled[column].to_numpy() + + if table_name in self.database.points: + table = self[table_name] + if table.empty: + return np.array([], dtype=float), np.array([], dtype=float) + if column not in table.columns: + raise KeyError(f"Column '{column}' not found in point table '{table_name}'") + if kind == "line": + return table[DhConfig.depth].to_numpy(), table[column].to_numpy() + + grid = depth_grid if depth_grid is not None else self._downhole_depth_grid(step) + values = np.array([None] * len(grid), dtype=object) + depths = table[DhConfig.depth].to_numpy() + for depth, value in zip(depths, table[column].to_numpy()): + if step <= 0: + continue + idx = int(np.round(depth / step)) + if 0 <= idx < len(values): + values[idx] = value + return grid, values + + raise KeyError(f"Table '{table_name}' not found in intervals or points") + + def plot_downhole( + self, + table_name: str, + column: str, + kind: str = "line", + step: float = 1.0, + ax=None, + cmap: str = "tab20", + show_legend: bool = True, + **kwargs, + ): + """Plot a downhole variable as a line or categorical image. + + Parameters + ---------- + table_name : str + Interval or point table name. + column : str + Column to plot. + kind : {"line", "categorical", "image"} + Plot style. Use "categorical" for discrete values and "image" for numeric heatmaps. + step : float, default 1.0 + Sampling step (meters) for interval or categorical plots. + ax : matplotlib.axes.Axes, optional + Axes to plot on. + cmap : str, default "tab20" + Colormap name for categorical plots. + show_legend : bool, default True + Whether to show a legend. + **kwargs + Passed through to matplotlib plot functions. + """ + import matplotlib.pyplot as plt + import matplotlib.patches as mpatches + + kind = kind.lower() + if kind not in {"line", "categorical", "image"}: + raise ValueError("kind must be 'line', 'categorical', or 'image'") + + if ax is None: + _, ax = plt.subplots(figsize=(4, 8)) + + depths, values = self._sample_downhole_values(table_name, column, step, kind) + if len(depths) == 0: + return ax + + if kind == "line": + series = pd.to_numeric(pd.Series(values), errors="coerce") + mask = ~np.isnan(series.to_numpy()) + if not mask.any(): + return ax + ax.plot(series[mask], np.asarray(depths)[mask], label=self.hole_id, **kwargs) + ax.set_xlabel(column) + ax.set_ylabel("Depth") + ax.set_title(f"{self.hole_id} {column}") + ax.invert_yaxis() + if show_legend: + ax.legend() + return ax + + if kind == "image": + series = pd.to_numeric(pd.Series(values), errors="coerce") + if series.isna().all(): + return ax + data = np.ma.masked_invalid(series.to_numpy())[:, None] + max_depth = float(np.nanmax(depths)) if len(depths) else 0.0 + im = ax.imshow( + data, + aspect="auto", + interpolation="nearest", + origin="upper", + extent=(0.0, 1.0, max_depth, 0.0), + cmap=cmap, + ) + ax.set_xticks([0.5]) + ax.set_xticklabels([self.hole_id]) + ax.set_xlabel("Hole") + ax.set_ylabel("Depth") + ax.set_title(f"{column}") + if show_legend: + ax.figure.colorbar(im, ax=ax, label=column) + return ax + + depth_values = np.asarray(values, dtype=object) + category_values = pd.Series(depth_values) + categories = [c for c in category_values.unique() if pd.notna(c)] + if not categories: + return ax + category_to_code = {cat: idx for idx, cat in enumerate(categories)} + + codes = np.full(len(depth_values), -1.0) + for idx, value in enumerate(depth_values): + if pd.notna(value): + codes[idx] = category_to_code[value] + + masked = np.ma.masked_where(codes < 0, codes) + cmap_obj = plt.get_cmap(cmap, len(categories)) + try: + cmap_obj = cmap_obj.copy() + except Exception: + pass + try: + cmap_obj.set_bad(color="lightgray") + except Exception: + pass + + max_depth = float(np.nanmax(depths)) if len(depths) else 0.0 + ax.imshow( + masked[:, None], + aspect="auto", + interpolation="nearest", + origin="upper", + extent=(0.0, 1.0, max_depth, 0.0), + cmap=cmap_obj, + vmin=0, + vmax=max(0, len(categories) - 1), + ) + ax.set_xticks([0.5]) + ax.set_xticklabels([self.hole_id]) + ax.set_xlabel("Hole") + ax.set_ylabel("Depth") + ax.set_title(f"{column}") + + if show_legend: + handles = [ + mpatches.Patch(color=cmap_obj(i), label=str(cat)) + for i, cat in enumerate(categories) + ] + ax.legend(handles=handles, title=column, bbox_to_anchor=(1.02, 1), loc="upper left") + return ax + def find_implicit_function_intersection( self, function: Callable[[ArrayLike], ArrayLike], step: float = 1.0, intersection_value : float = 0.0 ) -> pd.DataFrame: diff --git a/loopresources/drillhole/drillhole_database.py b/loopresources/drillhole/drillhole_database.py index ffc5693..226b010 100644 --- a/loopresources/drillhole/drillhole_database.py +++ b/loopresources/drillhole/drillhole_database.py @@ -142,6 +142,234 @@ def plot_collars(self, ax=None, **kwargs): ax.text(x, y, hole_id, fontsize=9, ha="right", va="bottom") return ax + def plot_downhole( + self, + table_name: str, + column: str, + holes: Optional[List[str]] = None, + kind: str = "line", + step: float = 1.0, + ax=None, + layout: str = "grid", + ncols: int = 3, + cmap: str = "tab20", + show_legend: bool = True, + **kwargs, + ): + """Plot a downhole variable for one or more drillholes. + + Parameters + ---------- + table_name : str + Interval or point table name. + column : str + Column to plot. + holes : list of str, optional + Hole IDs to include. Defaults to all holes in the collar table. + kind : {"line", "categorical", "image"} + Plot style. Use "categorical" for discrete values and "image" for numeric heatmaps. + step : float, default 1.0 + Sampling step (meters) for interval or categorical plots. + ax : matplotlib.axes.Axes, optional + Axes to plot on. + layout : {"grid", "column"}, default "grid" + Layout for multiple holes when ax is None. + ncols : int, default 3 + Number of columns when layout is "grid". + cmap : str, default "tab20" + Colormap name for categorical plots. + show_legend : bool, default True + Whether to show a legend. + **kwargs + Passed through to matplotlib plot functions. + """ + import matplotlib.pyplot as plt + import matplotlib.patches as mpatches + + kind = kind.lower() + if kind not in {"line", "categorical", "image"}: + raise ValueError("kind must be 'line', 'categorical', or 'image'") + if holes is None: + holes = list(self.collar[DhConfig.holeid].unique()) + else: + holes = list(holes) + + valid_holes = [ + hole_id + for hole_id in holes + if hole_id in set(self.collar[DhConfig.holeid].unique()) + ] + if not valid_holes: + return ax + + layout = layout.lower() + if layout not in {"grid", "column"}: + raise ValueError("layout must be 'grid' or 'column'") + + axes = None + if ax is None: + if len(valid_holes) == 1: + fig, axes = plt.subplots(figsize=(6, 4)) + ax = axes + elif layout == "column": + fig, axes = plt.subplots( + nrows=len(valid_holes), ncols=1, figsize=(6, 3 * len(valid_holes)), sharex=True + ) + ax = axes + else: + ncols = max(1, int(ncols)) + nrows = int(np.ceil(len(valid_holes) / ncols)) + fig, axes = plt.subplots( + nrows=nrows, + ncols=ncols, + figsize=(4 * ncols, 3 * nrows), + sharex=True, + sharey=True, + ) + ax = axes + elif isinstance(ax, (list, tuple, np.ndarray)): + ax = np.array(ax) + if ax.size < len(valid_holes): + raise ValueError("ax must have at least one axis per hole") + axes = ax + else: + if len(valid_holes) > 1: + raise ValueError( + "Multiple holes require one axis per hole. Pass ax as a list/array or omit ax." + ) + + axes_array = np.array(ax).ravel() if isinstance(ax, np.ndarray) else np.array([ax]) + axes_list = axes_array[: len(valid_holes)] + if axes is not None and isinstance(axes, np.ndarray) and axes_array.size > len(valid_holes): + for extra_ax in axes_array[len(valid_holes) :]: + extra_ax.set_visible(False) + + if kind == "line": + for hole_id, hole_ax in zip(valid_holes, axes_list): + hole = DrillHole(self, hole_id) + depths, values = hole._sample_downhole_values( + table_name, column, step, kind + ) + if len(depths) == 0: + continue + series = pd.to_numeric(pd.Series(values), errors="coerce") + mask = ~np.isnan(series.to_numpy()) + if not mask.any(): + continue + hole_ax.plot(series[mask], np.asarray(depths)[mask], label=hole_id, **kwargs) + hole_ax.set_xlabel(column) + hole_ax.set_ylabel("Depth") + hole_ax.set_title(hole_id) + hole_ax.invert_yaxis() + if show_legend: + hole_ax.legend() + return axes if axes is not None else ax + + if kind == "image": + if step <= 0: + raise ValueError("step must be > 0") + for hole_id, hole_ax in zip(valid_holes, axes_list): + hole = DrillHole(self, hole_id) + max_depth = float(hole.collar[DhConfig.total_depth].values[0]) + if max_depth <= 0: + continue + depth_grid = np.arange(0.0, max_depth + step, step) + _, values = hole._sample_downhole_values( + table_name, column, step, kind, depth_grid=depth_grid + ) + if len(values) == 0: + continue + series = pd.to_numeric(pd.Series(values), errors="coerce") + if series.isna().all(): + continue + data = np.ma.masked_invalid(series.to_numpy())[:, None] + im = hole_ax.imshow( + data, + aspect="auto", + interpolation="nearest", + origin="upper", + extent=(0.0, 1.0, max_depth, 0.0), + cmap=cmap, + ) + hole_ax.set_xticks([0.5]) + hole_ax.set_xticklabels([hole_id]) + hole_ax.set_xlabel("Hole") + hole_ax.set_ylabel("Depth") + hole_ax.set_title(hole_id) + if show_legend: + hole_ax.figure.colorbar(im, ax=hole_ax, label=column) + return axes if axes is not None else ax + + if step <= 0: + raise ValueError("step must be > 0") + + sampled_values = [] + all_values = [] + for hole_id in valid_holes: + hole = DrillHole(self, hole_id) + max_depth = float(hole.collar[DhConfig.total_depth].values[0]) + if max_depth <= 0: + sampled_values.append((hole_id, max_depth, np.array([]))) + continue + depth_grid = np.arange(0.0, max_depth + step, step) + _, values = hole._sample_downhole_values( + table_name, column, step, kind, depth_grid=depth_grid + ) + sampled_values.append((hole_id, max_depth, values)) + if len(values) > 0: + all_values.extend([v for v in values if pd.notna(v)]) + + categories = [c for c in pd.Series(all_values).unique() if pd.notna(c)] + if not categories: + return axes if axes is not None else ax + category_to_code = {cat: idx for idx, cat in enumerate(categories)} + + cmap_obj = plt.get_cmap(cmap, len(categories)) + try: + cmap_obj = cmap_obj.copy() + except Exception: + pass + try: + cmap_obj.set_bad(color="lightgray") + except Exception: + pass + + for (hole_id, max_depth, values), hole_ax in zip(sampled_values, axes_list): + if len(values) == 0: + continue + + codes = np.full(len(values), -1.0) + for row_idx, value in enumerate(values): + if pd.notna(value): + codes[row_idx] = category_to_code[value] + + masked = np.ma.masked_where(codes < 0, codes) + hole_ax.imshow( + masked[:, None], + aspect="auto", + interpolation="nearest", + origin="upper", + extent=(0.0, 1.0, max_depth, 0.0), + cmap=cmap_obj, + vmin=0, + vmax=max(0, len(categories) - 1), + ) + hole_ax.set_xticks([0.5]) + hole_ax.set_xticklabels([hole_id]) + hole_ax.set_xlabel("Hole") + hole_ax.set_ylabel("Depth") + hole_ax.set_title(hole_id) + + if show_legend: + handles = [ + mpatches.Patch(color=cmap_obj(i), label=str(cat)) + for i, cat in enumerate(categories) + ] + hole_ax.legend( + handles=handles, title=column, bbox_to_anchor=(1.02, 1), loc="upper left" + ) + return axes if axes is not None else ax + def get_collar_for_hole(self, hole_id: str) -> pd.DataFrame: """Get collar data for a specific hole. From 1dbb223fb2b1103f7ff30c2e27a815fc3c58a37f Mon Sep 17 00:00:00 2001 From: lachlangrose Date: Thu, 26 Feb 2026 13:48:53 +1100 Subject: [PATCH 2/6] fix: enhance downhole plotting with additional categorical legend options --- examples/plot_downhole.py | 18 ++++- loopresources/drillhole/drillhole_database.py | 66 ++++++++++++++++--- 2 files changed, 75 insertions(+), 9 deletions(-) diff --git a/examples/plot_downhole.py b/examples/plot_downhole.py index 54dac18..0c5a6f9 100644 --- a/examples/plot_downhole.py +++ b/examples/plot_downhole.py @@ -58,11 +58,27 @@ plt.tight_layout() plt.show() -# Categorical plot +# Categorical plot with shared legend on the right db.plot_downhole("lithology", "LITHO", kind="categorical", step=2.0, layout="grid", ncols=2) plt.tight_layout() plt.show() +# Categorical plot with legend at bottom +db.plot_downhole("lithology", "LITHO", kind="categorical", step=2.0, layout="grid", ncols=2, legend_loc="bottom") +plt.tight_layout() +plt.show() + +# Categorical plot without legend +db.plot_downhole("lithology", "LITHO", kind="categorical", step=2.0, layout="grid", ncols=2, show_legend=False) +plt.tight_layout() +plt.show() + +# Create standalone legend +categories = lithology["LITHO"].unique() +DrillholeDatabase.create_categorical_legend(categories, cmap="tab20", title="Lithology") +plt.tight_layout() +plt.show() + # Image plot (numeric heatmap) db.plot_downhole("assays", "AU_ppm", kind="image", step=2.0, layout="grid", ncols=2) plt.tight_layout() diff --git a/loopresources/drillhole/drillhole_database.py b/loopresources/drillhole/drillhole_database.py index 226b010..d4fbec2 100644 --- a/loopresources/drillhole/drillhole_database.py +++ b/loopresources/drillhole/drillhole_database.py @@ -154,6 +154,7 @@ def plot_downhole( ncols: int = 3, cmap: str = "tab20", show_legend: bool = True, + legend_loc: str = "right", **kwargs, ): """Plot a downhole variable for one or more drillholes. @@ -180,6 +181,8 @@ def plot_downhole( Colormap name for categorical plots. show_legend : bool, default True Whether to show a legend. + legend_loc : {"right", "bottom", "none"}, default "right" + Location for the shared categorical legend. Use "none" for no legend. **kwargs Passed through to matplotlib plot functions. """ @@ -360,16 +363,63 @@ def plot_downhole( hole_ax.set_ylabel("Depth") hole_ax.set_title(hole_id) - if show_legend: - handles = [ - mpatches.Patch(color=cmap_obj(i), label=str(cat)) - for i, cat in enumerate(categories) - ] - hole_ax.legend( - handles=handles, title=column, bbox_to_anchor=(1.02, 1), loc="upper left" - ) + if show_legend and legend_loc != "none": + handles = [ + mpatches.Patch(color=cmap_obj(i), label=str(cat)) + for i, cat in enumerate(categories) + ] + if axes is not None and hasattr(axes_list[0], 'figure'): + fig = axes_list[0].figure + if legend_loc == "right": + fig.legend(handles=handles, title=column, loc="center left", bbox_to_anchor=(1.0, 0.5)) + elif legend_loc == "bottom": + fig.legend(handles=handles, title=column, loc="upper center", bbox_to_anchor=(0.5, -0.05), ncol=min(len(categories), 5)) return axes if axes is not None else ax + @staticmethod + def create_categorical_legend( + categories: List[str], + cmap: str = "tab20", + title: str = "Categories", + ax=None, + **kwargs, + ): + """Create a standalone categorical legend. + + Parameters + ---------- + categories : list of str + List of category names. + cmap : str, default "tab20" + Colormap name. + title : str, default "Categories" + Legend title. + ax : matplotlib.axes.Axes, optional + Axes to add legend to. If None, creates a new figure. + **kwargs + Additional keyword arguments passed to ax.legend(). + + Returns + ------- + matplotlib.legend.Legend + The legend object. + """ + import matplotlib.pyplot as plt + import matplotlib.patches as mpatches + + cmap_obj = plt.get_cmap(cmap, len(categories)) + handles = [ + mpatches.Patch(color=cmap_obj(i), label=str(cat)) + for i, cat in enumerate(categories) + ] + + if ax is None: + fig, ax = plt.subplots(figsize=(3, max(2, len(categories) * 0.3))) + ax.axis("off") + + legend = ax.legend(handles=handles, title=title, loc="center", **kwargs) + return legend + def get_collar_for_hole(self, hole_id: str) -> pd.DataFrame: """Get collar data for a specific hole. From 7c20765efe4868d5e159eae7eae51b395b55c19f Mon Sep 17 00:00:00 2001 From: Lachlan Grose Date: Thu, 26 Feb 2026 14:24:39 +1100 Subject: [PATCH 3/6] Apply suggestions from code review Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- loopresources/drillhole/drillhole_database.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/loopresources/drillhole/drillhole_database.py b/loopresources/drillhole/drillhole_database.py index d4fbec2..3b8f066 100644 --- a/loopresources/drillhole/drillhole_database.py +++ b/loopresources/drillhole/drillhole_database.py @@ -197,10 +197,11 @@ def plot_downhole( else: holes = list(holes) + collar_holes = set(self.collar[DhConfig.holeid].unique()) valid_holes = [ hole_id for hole_id in holes - if hole_id in set(self.collar[DhConfig.holeid].unique()) + if hole_id in collar_holes ] if not valid_holes: return ax From ded74b4e2be0ca14385459ff272fc981c606d3bc Mon Sep 17 00:00:00 2001 From: Lachlan Grose Date: Thu, 26 Feb 2026 14:24:51 +1100 Subject: [PATCH 4/6] Apply suggestions from code review Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- loopresources/drillhole/drillhole_database.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/loopresources/drillhole/drillhole_database.py b/loopresources/drillhole/drillhole_database.py index 3b8f066..44ada74 100644 --- a/loopresources/drillhole/drillhole_database.py +++ b/loopresources/drillhole/drillhole_database.py @@ -198,13 +198,14 @@ def plot_downhole( holes = list(holes) collar_holes = set(self.collar[DhConfig.holeid].unique()) - valid_holes = [ - hole_id - for hole_id in holes - if hole_id in collar_holes - ] + collar_holes = set(self.collar[DhConfig.holeid].unique()) + valid_holes = [hole_id for hole_id in holes if hole_id in collar_holes] if not valid_holes: - return ax + unknown_holes = sorted({hole_id for hole_id in holes if hole_id not in collar_holes}) + raise ValueError( + "None of the requested holes are present in the collar table. " + f"Unknown hole IDs: {unknown_holes}" + ) layout = layout.lower() if layout not in {"grid", "column"}: From 86782fc85b86d22697dcf2a28f7baf3ffe65c68c Mon Sep 17 00:00:00 2001 From: lachlangrose Date: Thu, 26 Feb 2026 14:56:14 +1100 Subject: [PATCH 5/6] tests: add dh test --- tests/test_drillhole.py | 129 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 129 insertions(+) diff --git a/tests/test_drillhole.py b/tests/test_drillhole.py index e69de29..84f8c4c 100644 --- a/tests/test_drillhole.py +++ b/tests/test_drillhole.py @@ -0,0 +1,129 @@ +"""Tests for DrillHole sampling and plotting helpers.""" + +import numpy as np +import pandas as pd +import pytest + +import matplotlib + +matplotlib.use("Agg") + +from loopresources.drillhole.drillhole_database import DrillholeDatabase +from loopresources.drillhole.dhconfig import DhConfig + + +@pytest.fixture +def sample_db(): + collar = pd.DataFrame( + { + DhConfig.holeid: ["DH001", "DH002"], + DhConfig.x: [100.0, 200.0], + DhConfig.y: [1000.0, 2000.0], + DhConfig.z: [50.0, 60.0], + DhConfig.total_depth: [10.0, 8.0], + } + ) + + survey = pd.DataFrame( + { + DhConfig.holeid: ["DH001", "DH002"], + DhConfig.depth: [0.0, 0.0], + DhConfig.azimuth: [0.0, 0.0], + DhConfig.dip: [90.0, 90.0], + } + ) + + intervals = pd.DataFrame( + { + DhConfig.holeid: ["DH001", "DH001"], + DhConfig.sample_from: [0.0, 5.0], + DhConfig.sample_to: [5.0, 10.0], + "LITHO": ["A", "B"], + "GRADE": [1.0, 2.0], + } + ) + + points = pd.DataFrame( + { + DhConfig.holeid: ["DH001", "DH001", "DH001"], + DhConfig.depth: [1.0, 4.0, 9.0], + "AU_PPM": [100.0, 200.0, 300.0], + } + ) + + db = DrillholeDatabase(collar, survey) + db.add_interval_table("geology", intervals) + db.add_point_table("assay", points) + return db + + +def test_downhole_depth_grid(sample_db): + hole = sample_db["DH001"] + grid = hole._downhole_depth_grid(step=2.5) + np.testing.assert_allclose(grid, np.array([0.0, 2.5, 5.0, 7.5, 10.0])) + + empty = hole._downhole_depth_grid(step=1.0, max_depth=0.0) + assert empty.size == 0 + + with pytest.raises(ValueError, match="step must be > 0"): + hole._downhole_depth_grid(step=0.0) + + +def test_sample_downhole_values_interval_categorical(sample_db): + hole = sample_db["DH001"] + depths, values = hole._sample_downhole_values("geology", "LITHO", 2.5, "categorical") + np.testing.assert_allclose(depths, np.array([0.0, 2.5, 5.0, 7.5, 10.0])) + assert values.tolist() == ["A", "A", "A", "B", "B"] + + +def test_sample_downhole_values_interval_numeric(sample_db): + hole = sample_db["DH001"] + depths, values = hole._sample_downhole_values("geology", "GRADE", 2.5, "image") + np.testing.assert_allclose(depths, np.array([0.0, 2.5, 5.0, 7.5, 10.0])) + np.testing.assert_allclose(values.astype(float), np.array([1.0, 1.0, 1.0, 2.0, 2.0])) + + +def test_sample_downhole_values_point_line(sample_db): + hole = sample_db["DH001"] + depths, values = hole._sample_downhole_values("assay", "AU_PPM", 1.0, "line") + np.testing.assert_allclose(depths, np.array([1.0, 4.0, 9.0])) + np.testing.assert_allclose(values, np.array([100.0, 200.0, 300.0])) + + +def test_sample_downhole_values_point_categorical_grid(sample_db): + hole = sample_db["DH001"] + depths, values = hole._sample_downhole_values("assay", "AU_PPM", 2.0, "categorical") + np.testing.assert_allclose(depths, np.array([0.0, 2.0, 4.0, 6.0, 8.0, 10.0])) + assert values.tolist() == [100.0, None, 200.0, None, 300.0, None] + + +def test_sample_downhole_values_missing_table_column(sample_db): + hole = sample_db["DH001"] + with pytest.raises(KeyError, match="Table 'missing' not found"): + hole._sample_downhole_values("missing", "LITHO", 1.0, "line") + + with pytest.raises(KeyError, match="Column 'NOPE' not found in interval table"): + hole._sample_downhole_values("geology", "NOPE", 1.0, "line") + + with pytest.raises(KeyError, match="Column 'NOPE' not found in point table"): + hole._sample_downhole_values("assay", "NOPE", 1.0, "line") + + +def test_plot_downhole_multi_hole_layout_errors(sample_db): + with pytest.raises(ValueError, match="ax must have at least one axis per hole"): + sample_db.plot_downhole("geology", "LITHO", holes=["DH001", "DH002"], ax=[]) + + with pytest.raises(ValueError, match="Unknown hole IDs"): + sample_db.plot_downhole("geology", "LITHO", holes=["DH999"]) + + with pytest.raises(ValueError, match="layout must be 'grid' or 'column'"): + sample_db.plot_downhole("geology", "LITHO", layout="row") + + with pytest.raises(ValueError, match="kind must be 'line', 'categorical', or 'image'"): + sample_db.plot_downhole("geology", "LITHO", kind="scatter") + + +def test_plot_downhole_multi_hole_grid_returns_axes(sample_db): + axes = sample_db.plot_downhole("geology", "LITHO", holes=["DH001", "DH002"], layout="grid") + assert isinstance(axes, np.ndarray) + assert axes.size >= 2 From 368bfcb7b9794a42b6e6712c225222063de6ad86 Mon Sep 17 00:00:00 2001 From: lachlangrose Date: Thu, 26 Feb 2026 14:58:08 +1100 Subject: [PATCH 6/6] fix: updating legend location --- loopresources/drillhole/drillhole_database.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/loopresources/drillhole/drillhole_database.py b/loopresources/drillhole/drillhole_database.py index 44ada74..11c4a40 100644 --- a/loopresources/drillhole/drillhole_database.py +++ b/loopresources/drillhole/drillhole_database.py @@ -364,7 +364,9 @@ def plot_downhole( hole_ax.set_xlabel("Hole") hole_ax.set_ylabel("Depth") hole_ax.set_title(hole_id) - + legend_loc = legend_loc.lower() + if legend_loc not in {"right", "bottom", "none"}: + raise ValueError("legend_loc must be 'right', 'bottom', or 'none'") if show_legend and legend_loc != "none": handles = [ mpatches.Patch(color=cmap_obj(i), label=str(cat))