diff --git a/examples/plot_downhole.py b/examples/plot_downhole.py new file mode 100644 index 0000000..0c5a6f9 --- /dev/null +++ b/examples/plot_downhole.py @@ -0,0 +1,85 @@ +""" +Downhole plotting +================ + +Example showing downhole line, categorical, and image plots. +""" + +import matplotlib.pyplot as plt +import pandas as pd + +from loopresources.drillhole.dhconfig import DhConfig +from loopresources.drillhole.drillhole_database import DrillholeDatabase + +collar = pd.DataFrame( + { + DhConfig.holeid: ["DH001", "DH002", "DH003"], + DhConfig.x: [100.0, 200.0, 300.0], + DhConfig.y: [1000.0, 2000.0, 3000.0], + DhConfig.z: [50.0, 60.0, 70.0], + DhConfig.total_depth: [150.0, 200.0, 180.0], + } +) + +survey = pd.DataFrame( + { + DhConfig.holeid: ["DH001", "DH001", "DH002", "DH002", "DH003"], + DhConfig.depth: [0.0, 100.0, 0.0, 120.0, 0.0], + DhConfig.azimuth: [0.0, 0.0, 45.0, 45.0, 90.0], + DhConfig.dip: [-90.0, -90.0, -85.0, -80.0, -90.0], + } +) + +db = DrillholeDatabase(collar, survey) + +lithology = pd.DataFrame( + { + DhConfig.holeid: ["DH001", "DH001", "DH002", "DH002", "DH003"], + DhConfig.sample_from: [0.0, 50.0, 0.0, 80.0, 0.0], + DhConfig.sample_to: [50.0, 150.0, 80.0, 200.0, 180.0], + "LITHO": ["Granite", "Schist", "Sandstone", "Shale", "Limestone"], + } +) + +assays = pd.DataFrame( + { + DhConfig.holeid: ["DH001", "DH001", "DH002", "DH002", "DH003"], + DhConfig.sample_from: [0.0, 75.0, 0.0, 100.0, 0.0], + DhConfig.sample_to: [75.0, 150.0, 100.0, 200.0, 180.0], + "AU_ppm": [0.1, 2.5, 0.05, 1.2, 0.4], + } +) + +db.add_interval_table("lithology", lithology) +db.add_interval_table("assays", assays) + +# Line plot (numeric values) +db.plot_downhole("assays", "AU_ppm", kind="line", layout="grid", ncols=2) +plt.tight_layout() +plt.show() + +# Categorical plot with shared legend on the right +db.plot_downhole("lithology", "LITHO", kind="categorical", step=2.0, layout="grid", ncols=2) +plt.tight_layout() +plt.show() + +# Categorical plot with legend at bottom +db.plot_downhole("lithology", "LITHO", kind="categorical", step=2.0, layout="grid", ncols=2, legend_loc="bottom") +plt.tight_layout() +plt.show() + +# Categorical plot without legend +db.plot_downhole("lithology", "LITHO", kind="categorical", step=2.0, layout="grid", ncols=2, show_legend=False) +plt.tight_layout() +plt.show() + +# Create standalone legend +categories = lithology["LITHO"].unique() +DrillholeDatabase.create_categorical_legend(categories, cmap="tab20", title="Lithology") +plt.tight_layout() +plt.show() + +# Image plot (numeric heatmap) +db.plot_downhole("assays", "AU_ppm", kind="image", step=2.0, layout="grid", ncols=2) +plt.tight_layout() +plt.show() diff --git a/loopresources/drillhole/drillhole.py b/loopresources/drillhole/drillhole.py index 53846b4..90d29f2 100644 --- a/loopresources/drillhole/drillhole.py +++ b/loopresources/drillhole/drillhole.py @@ -387,6 +387,192 @@ def trace(self, step: float = 1.0) -> DrillHoleTrace: """ return DrillHoleTrace(self, interval=step) + def _downhole_depth_grid( + self, step: float, max_depth: Optional[float] = None + ) -> np.ndarray: + if step <= 0: + raise ValueError("step must be > 0") + if max_depth is None: + max_depth = float(self.collar[DhConfig.total_depth].values[0]) + if max_depth <= 0: + return np.array([], dtype=float) + return np.arange(0.0, max_depth + step, step) + + def _sample_downhole_values( + self, + table_name: str, + column: str, + step: float, + kind: str, + depth_grid: Optional[np.ndarray] = None, + ): + if table_name in self.database.intervals: + table = self[table_name] + if table.empty: + return np.array([], dtype=float), np.array([], dtype=float) + if column not in table.columns: + raise KeyError(f"Column '{column}' not found in interval table '{table_name}'") + grid = depth_grid if depth_grid is not None else self._downhole_depth_grid(step) + from .resample import resample_interval + + sampled = resample_interval( + pd.DataFrame({DhConfig.depth: grid}), table, [column], method="direct" + ) + return sampled[DhConfig.depth].to_numpy(), sampled[column].to_numpy() + + if table_name in self.database.points: + table = self[table_name] + if table.empty: + return np.array([], dtype=float), np.array([], dtype=float) + if column not in table.columns: + raise KeyError(f"Column '{column}' not found in point table '{table_name}'") + if kind == "line": + return table[DhConfig.depth].to_numpy(), table[column].to_numpy() + + grid = depth_grid if depth_grid is not None else self._downhole_depth_grid(step) + values = np.array([None] * len(grid), dtype=object) + depths = table[DhConfig.depth].to_numpy() + for depth, value in zip(depths, table[column].to_numpy()): + if step <= 0: + continue + idx = int(np.round(depth / step)) + if 0 <= idx < len(values): + values[idx] = value + return grid, values + + raise KeyError(f"Table '{table_name}' not found in intervals or points") + + def plot_downhole( + self, + table_name: str, + column: str, + kind: str = "line", + step: float = 1.0, + ax=None, + cmap: str = "tab20", + show_legend: bool = True, + **kwargs, + ): + """Plot a downhole variable as a line or categorical image. + + Parameters + ---------- + table_name : str + Interval or point table name. + column : str + Column to plot. + kind : {"line", "categorical", "image"} + Plot style. Use "categorical" for discrete values and "image" for numeric heatmaps. + step : float, default 1.0 + Sampling step (meters) for interval or categorical plots. + ax : matplotlib.axes.Axes, optional + Axes to plot on. + cmap : str, default "tab20" + Colormap name for categorical plots. + show_legend : bool, default True + Whether to show a legend. + **kwargs + Passed through to matplotlib plot functions. + """ + import matplotlib.pyplot as plt + import matplotlib.patches as mpatches + + kind = kind.lower() + if kind not in {"line", "categorical", "image"}: + raise ValueError("kind must be 'line', 'categorical', or 'image'") + + if ax is None: + _, ax = plt.subplots(figsize=(4, 8)) + + depths, values = self._sample_downhole_values(table_name, column, step, kind) + if len(depths) == 0: + return ax + + if kind == "line": + series = pd.to_numeric(pd.Series(values), errors="coerce") + mask = ~np.isnan(series.to_numpy()) + if not mask.any(): + return ax + ax.plot(series[mask], np.asarray(depths)[mask], label=self.hole_id, **kwargs) + ax.set_xlabel(column) + ax.set_ylabel("Depth") + ax.set_title(f"{self.hole_id} {column}") + ax.invert_yaxis() + if show_legend: + ax.legend() + return ax + + if kind == "image": + series = pd.to_numeric(pd.Series(values), errors="coerce") + if series.isna().all(): + return ax + data = np.ma.masked_invalid(series.to_numpy())[:, None] + max_depth = float(np.nanmax(depths)) if len(depths) else 0.0 + im = ax.imshow( + data, + aspect="auto", + interpolation="nearest", + origin="upper", + extent=(0.0, 1.0, max_depth, 0.0), + cmap=cmap, + ) + ax.set_xticks([0.5]) + ax.set_xticklabels([self.hole_id]) + ax.set_xlabel("Hole") + ax.set_ylabel("Depth") + ax.set_title(f"{column}") + if show_legend: + ax.figure.colorbar(im, ax=ax, label=column) + return ax + + depth_values = np.asarray(values, dtype=object) + category_values = pd.Series(depth_values) + categories = [c for c in category_values.unique() if pd.notna(c)] + if not categories: + return ax + category_to_code = {cat: idx for idx, cat in enumerate(categories)} + + codes = np.full(len(depth_values), -1.0) + for idx, value in enumerate(depth_values): + if pd.notna(value): + codes[idx] = category_to_code[value] + + masked = np.ma.masked_where(codes < 0, codes) + cmap_obj = plt.get_cmap(cmap, len(categories)) + try: + cmap_obj = cmap_obj.copy() + except Exception: + pass + try: + cmap_obj.set_bad(color="lightgray") + except Exception: + pass + + max_depth = float(np.nanmax(depths)) if len(depths) else 0.0 + ax.imshow( + masked[:, None], + aspect="auto", + interpolation="nearest", + origin="upper", + extent=(0.0, 1.0, max_depth, 0.0), + cmap=cmap_obj, + vmin=0, + vmax=max(0, len(categories) - 1), + ) + ax.set_xticks([0.5]) + ax.set_xticklabels([self.hole_id]) + ax.set_xlabel("Hole") + ax.set_ylabel("Depth") + ax.set_title(f"{column}") + + if show_legend: + handles = [ + mpatches.Patch(color=cmap_obj(i), label=str(cat)) + for i, cat in enumerate(categories) + ] + ax.legend(handles=handles, title=column, bbox_to_anchor=(1.02, 1), loc="upper left") + return ax + def find_implicit_function_intersection( self, function: Callable[[ArrayLike], ArrayLike], step: float = 1.0, intersection_value : float = 0.0 ) -> pd.DataFrame: diff --git a/loopresources/drillhole/drillhole_database.py b/loopresources/drillhole/drillhole_database.py index ffc5693..11c4a40 100644 --- a/loopresources/drillhole/drillhole_database.py +++ b/loopresources/drillhole/drillhole_database.py @@ -142,6 +142,288 @@ def plot_collars(self, ax=None, **kwargs): ax.text(x, y, hole_id, fontsize=9, ha="right", va="bottom") return ax + def plot_downhole( + self, + table_name: str, + column: str, + holes: Optional[List[str]] = None, + kind: str = "line", + step: float = 1.0, + ax=None, + layout: str = "grid", + ncols: int = 3, + cmap: str = "tab20", + show_legend: bool = True, + legend_loc: str = "right", + **kwargs, + ): + """Plot a downhole variable for one or more drillholes. + + Parameters + ---------- + table_name : str + Interval or point table name. + column : str + Column to plot. + holes : list of str, optional + Hole IDs to include. Defaults to all holes in the collar table. + kind : {"line", "categorical", "image"} + Plot style. Use "categorical" for discrete values and "image" for numeric heatmaps. + step : float, default 1.0 + Sampling step (meters) for interval or categorical plots. + ax : matplotlib.axes.Axes, optional + Axes to plot on. + layout : {"grid", "column"}, default "grid" + Layout for multiple holes when ax is None. + ncols : int, default 3 + Number of columns when layout is "grid". + cmap : str, default "tab20" + Colormap name for categorical plots. + show_legend : bool, default True + Whether to show a legend. + legend_loc : {"right", "bottom", "none"}, default "right" + Location for the shared categorical legend. Use "none" for no legend. + **kwargs + Passed through to matplotlib plot functions. + """ + import matplotlib.pyplot as plt + import matplotlib.patches as mpatches + + kind = kind.lower() + if kind not in {"line", "categorical", "image"}: + raise ValueError("kind must be 'line', 'categorical', or 'image'") + if holes is None: + holes = list(self.collar[DhConfig.holeid].unique()) + else: + holes = list(holes) + + collar_holes = set(self.collar[DhConfig.holeid].unique()) + collar_holes = set(self.collar[DhConfig.holeid].unique()) + valid_holes = [hole_id for hole_id in holes if hole_id in collar_holes] + if not valid_holes: + unknown_holes = sorted({hole_id for hole_id in holes if hole_id not in collar_holes}) + raise ValueError( + "None of the requested holes are present in the collar table. " + f"Unknown hole IDs: {unknown_holes}" + ) + + layout = layout.lower() + if layout not in {"grid", "column"}: + raise ValueError("layout must be 'grid' or 'column'") + + axes = None + if ax is None: + if len(valid_holes) == 1: + fig, axes = plt.subplots(figsize=(6, 4)) + ax = axes + elif layout == "column": + fig, axes = plt.subplots( + nrows=len(valid_holes), ncols=1, figsize=(6, 3 * len(valid_holes)), sharex=True + ) + ax = axes + else: + ncols = max(1, int(ncols)) + nrows = int(np.ceil(len(valid_holes) / ncols)) + fig, axes = plt.subplots( + nrows=nrows, + ncols=ncols, + figsize=(4 * ncols, 3 * nrows), + sharex=True, + sharey=True, + ) + ax = axes + elif isinstance(ax, (list, tuple, np.ndarray)): + ax = np.array(ax) + if ax.size < len(valid_holes): + raise ValueError("ax must have at least one axis per hole") + axes = ax + else: + if len(valid_holes) > 1: + raise ValueError( + "Multiple holes require one axis per hole. Pass ax as a list/array or omit ax." + ) + + axes_array = np.array(ax).ravel() if isinstance(ax, np.ndarray) else np.array([ax]) + axes_list = axes_array[: len(valid_holes)] + if axes is not None and isinstance(axes, np.ndarray) and axes_array.size > len(valid_holes): + for extra_ax in axes_array[len(valid_holes) :]: + extra_ax.set_visible(False) + + if kind == "line": + for hole_id, hole_ax in zip(valid_holes, axes_list): + hole = DrillHole(self, hole_id) + depths, values = hole._sample_downhole_values( + table_name, column, step, kind + ) + if len(depths) == 0: + continue + series = pd.to_numeric(pd.Series(values), errors="coerce") + mask = ~np.isnan(series.to_numpy()) + if not mask.any(): + continue + hole_ax.plot(series[mask], np.asarray(depths)[mask], label=hole_id, **kwargs) + hole_ax.set_xlabel(column) + hole_ax.set_ylabel("Depth") + hole_ax.set_title(hole_id) + hole_ax.invert_yaxis() + if show_legend: + hole_ax.legend() + return axes if axes is not None else ax + + if kind == "image": + if step <= 0: + raise ValueError("step must be > 0") + for hole_id, hole_ax in zip(valid_holes, axes_list): + hole = DrillHole(self, hole_id) + max_depth = float(hole.collar[DhConfig.total_depth].values[0]) + if max_depth <= 0: + continue + depth_grid = np.arange(0.0, max_depth + step, step) + _, values = hole._sample_downhole_values( + table_name, column, step, kind, depth_grid=depth_grid + ) + if len(values) == 0: + continue + series = pd.to_numeric(pd.Series(values), errors="coerce") + if series.isna().all(): + continue + data = np.ma.masked_invalid(series.to_numpy())[:, None] + im = hole_ax.imshow( + data, + aspect="auto", + interpolation="nearest", + origin="upper", + extent=(0.0, 1.0, max_depth, 0.0), + cmap=cmap, + ) + hole_ax.set_xticks([0.5]) + hole_ax.set_xticklabels([hole_id]) + hole_ax.set_xlabel("Hole") + hole_ax.set_ylabel("Depth") + hole_ax.set_title(hole_id) + if show_legend: + hole_ax.figure.colorbar(im, ax=hole_ax, label=column) + return axes if axes is not None else ax + + if step <= 0: + raise ValueError("step must be > 0") + + sampled_values = [] + all_values = [] + for hole_id in valid_holes: + hole = DrillHole(self, hole_id) + max_depth = float(hole.collar[DhConfig.total_depth].values[0]) + if max_depth <= 0: + sampled_values.append((hole_id, max_depth, np.array([]))) + continue + depth_grid = np.arange(0.0, max_depth + step, step) + _, values = hole._sample_downhole_values( + table_name, column, step, kind, depth_grid=depth_grid + ) + sampled_values.append((hole_id, max_depth, values)) + if len(values) > 0: + all_values.extend([v for v in values if pd.notna(v)]) + + categories = [c for c in pd.Series(all_values).unique() if pd.notna(c)] + if not categories: + return axes if axes is not None else ax + category_to_code = {cat: idx for idx, cat in enumerate(categories)} + + cmap_obj = plt.get_cmap(cmap, len(categories)) + try: + cmap_obj = cmap_obj.copy() + except Exception: + pass + try: + cmap_obj.set_bad(color="lightgray") + except Exception: + pass + + for (hole_id, max_depth, values), hole_ax in zip(sampled_values, axes_list): + if len(values) == 0: + continue + + codes = np.full(len(values), -1.0) + for row_idx, value in enumerate(values): + if pd.notna(value): + codes[row_idx] = category_to_code[value] + + masked = np.ma.masked_where(codes < 0, codes) + hole_ax.imshow( + masked[:, None], + aspect="auto", + interpolation="nearest", + origin="upper", + extent=(0.0, 1.0, max_depth, 0.0), + cmap=cmap_obj, + vmin=0, + vmax=max(0, len(categories) - 1), + ) + hole_ax.set_xticks([0.5]) + hole_ax.set_xticklabels([hole_id]) + hole_ax.set_xlabel("Hole") + hole_ax.set_ylabel("Depth") + hole_ax.set_title(hole_id) + legend_loc = legend_loc.lower() + if legend_loc not in {"right", "bottom", "none"}: + raise ValueError("legend_loc must be 'right', 'bottom', or 'none'") + if show_legend and legend_loc != "none": + handles = [ + mpatches.Patch(color=cmap_obj(i), label=str(cat)) + for i, cat in enumerate(categories) + ] + if axes is not None and hasattr(axes_list[0], 'figure'): + fig = axes_list[0].figure + if legend_loc == "right": + fig.legend(handles=handles, title=column, loc="center left", bbox_to_anchor=(1.0, 0.5)) + elif legend_loc == "bottom": + fig.legend(handles=handles, title=column, loc="upper center", bbox_to_anchor=(0.5, -0.05), ncol=min(len(categories), 5)) + return axes if axes is not None else ax + + @staticmethod + def create_categorical_legend( + categories: List[str], + cmap: str = "tab20", + title: str = "Categories", + ax=None, + **kwargs, + ): + """Create a standalone categorical legend. + + Parameters + ---------- + categories : list of str + List of category names. + cmap : str, default "tab20" + Colormap name. + title : str, default "Categories" + Legend title. + ax : matplotlib.axes.Axes, optional + Axes to add legend to. If None, creates a new figure. + **kwargs + Additional keyword arguments passed to ax.legend(). + + Returns + ------- + matplotlib.legend.Legend + The legend object. + """ + import matplotlib.pyplot as plt + import matplotlib.patches as mpatches + + cmap_obj = plt.get_cmap(cmap, len(categories)) + handles = [ + mpatches.Patch(color=cmap_obj(i), label=str(cat)) + for i, cat in enumerate(categories) + ] + + if ax is None: + fig, ax = plt.subplots(figsize=(3, max(2, len(categories) * 0.3))) + ax.axis("off") + + legend = ax.legend(handles=handles, title=title, loc="center", **kwargs) + return legend + def get_collar_for_hole(self, hole_id: str) -> pd.DataFrame: """Get collar data for a specific hole. diff --git a/tests/test_drillhole.py b/tests/test_drillhole.py index e69de29..84f8c4c 100644 --- a/tests/test_drillhole.py +++ b/tests/test_drillhole.py @@ -0,0 +1,129 @@ +"""Tests for DrillHole sampling and plotting helpers.""" + +import numpy as np +import pandas as pd +import pytest + +import matplotlib + +matplotlib.use("Agg") + +from loopresources.drillhole.drillhole_database import DrillholeDatabase +from loopresources.drillhole.dhconfig import DhConfig + + +@pytest.fixture +def sample_db(): + collar = pd.DataFrame( + { + DhConfig.holeid: ["DH001", "DH002"], + DhConfig.x: [100.0, 200.0], + DhConfig.y: [1000.0, 2000.0], + DhConfig.z: [50.0, 60.0], + DhConfig.total_depth: [10.0, 8.0], + } + ) + + survey = pd.DataFrame( + { + DhConfig.holeid: ["DH001", "DH002"], + DhConfig.depth: [0.0, 0.0], + DhConfig.azimuth: [0.0, 0.0], + DhConfig.dip: [90.0, 90.0], + } + ) + + intervals = pd.DataFrame( + { + DhConfig.holeid: ["DH001", "DH001"], + DhConfig.sample_from: [0.0, 5.0], + DhConfig.sample_to: [5.0, 10.0], + "LITHO": ["A", "B"], + "GRADE": [1.0, 2.0], + } + ) + + points = pd.DataFrame( + { + DhConfig.holeid: ["DH001", "DH001", "DH001"], + DhConfig.depth: [1.0, 4.0, 9.0], + "AU_PPM": [100.0, 200.0, 300.0], + } + ) + + db = DrillholeDatabase(collar, survey) + db.add_interval_table("geology", intervals) + db.add_point_table("assay", points) + return db + + +def test_downhole_depth_grid(sample_db): + hole = sample_db["DH001"] + grid = hole._downhole_depth_grid(step=2.5) + np.testing.assert_allclose(grid, np.array([0.0, 2.5, 5.0, 7.5, 10.0])) + + empty = hole._downhole_depth_grid(step=1.0, max_depth=0.0) + assert empty.size == 0 + + with pytest.raises(ValueError, match="step must be > 0"): + hole._downhole_depth_grid(step=0.0) + + +def test_sample_downhole_values_interval_categorical(sample_db): + hole = sample_db["DH001"] + depths, values = hole._sample_downhole_values("geology", "LITHO", 2.5, "categorical") + np.testing.assert_allclose(depths, np.array([0.0, 2.5, 5.0, 7.5, 10.0])) + assert values.tolist() == ["A", "A", "A", "B", "B"] + + +def test_sample_downhole_values_interval_numeric(sample_db): + hole = sample_db["DH001"] + depths, values = hole._sample_downhole_values("geology", "GRADE", 2.5, "image") + np.testing.assert_allclose(depths, np.array([0.0, 2.5, 5.0, 7.5, 10.0])) + np.testing.assert_allclose(values.astype(float), np.array([1.0, 1.0, 1.0, 2.0, 2.0])) + + +def test_sample_downhole_values_point_line(sample_db): + hole = sample_db["DH001"] + depths, values = hole._sample_downhole_values("assay", "AU_PPM", 1.0, "line") + np.testing.assert_allclose(depths, np.array([1.0, 4.0, 9.0])) + np.testing.assert_allclose(values, np.array([100.0, 200.0, 300.0])) + + +def test_sample_downhole_values_point_categorical_grid(sample_db): + hole = sample_db["DH001"] + depths, values = hole._sample_downhole_values("assay", "AU_PPM", 2.0, "categorical") + np.testing.assert_allclose(depths, np.array([0.0, 2.0, 4.0, 6.0, 8.0, 10.0])) + assert values.tolist() == [100.0, None, 200.0, None, 300.0, None] + + +def test_sample_downhole_values_missing_table_column(sample_db): + hole = sample_db["DH001"] + with pytest.raises(KeyError, match="Table 'missing' not found"): + hole._sample_downhole_values("missing", "LITHO", 1.0, "line") + + with pytest.raises(KeyError, match="Column 'NOPE' not found in interval table"): + hole._sample_downhole_values("geology", "NOPE", 1.0, "line") + + with pytest.raises(KeyError, match="Column 'NOPE' not found in point table"): + hole._sample_downhole_values("assay", "NOPE", 1.0, "line") + + +def test_plot_downhole_multi_hole_layout_errors(sample_db): + with pytest.raises(ValueError, match="ax must have at least one axis per hole"): + sample_db.plot_downhole("geology", "LITHO", holes=["DH001", "DH002"], ax=[]) + + with pytest.raises(ValueError, match="Unknown hole IDs"): + sample_db.plot_downhole("geology", "LITHO", holes=["DH999"]) + + with pytest.raises(ValueError, match="layout must be 'grid' or 'column'"): + sample_db.plot_downhole("geology", "LITHO", layout="row") + + with pytest.raises(ValueError, match="kind must be 'line', 'categorical', or 'image'"): + sample_db.plot_downhole("geology", "LITHO", kind="scatter") + + +def test_plot_downhole_multi_hole_grid_returns_axes(sample_db): + axes = sample_db.plot_downhole("geology", "LITHO", holes=["DH001", "DH002"], layout="grid") + assert isinstance(axes, np.ndarray) + assert axes.size >= 2