Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 85 additions & 0 deletions examples/plot_downhole.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
"""
Downhole plotting
================

Example showing downhole line, categorical, and image plots.
"""

import matplotlib.pyplot as plt
import pandas as pd

from loopresources.drillhole.dhconfig import DhConfig
from loopresources.drillhole.drillhole_database import DrillholeDatabase

collar = pd.DataFrame(
{
DhConfig.holeid: ["DH001", "DH002", "DH003"],
DhConfig.x: [100.0, 200.0, 300.0],
DhConfig.y: [1000.0, 2000.0, 3000.0],
DhConfig.z: [50.0, 60.0, 70.0],
DhConfig.total_depth: [150.0, 200.0, 180.0],
}
)

survey = pd.DataFrame(
{
DhConfig.holeid: ["DH001", "DH001", "DH002", "DH002", "DH003"],
DhConfig.depth: [0.0, 100.0, 0.0, 120.0, 0.0],
DhConfig.azimuth: [0.0, 0.0, 45.0, 45.0, 90.0],
DhConfig.dip: [-90.0, -90.0, -85.0, -80.0, -90.0],
}
)

db = DrillholeDatabase(collar, survey)

lithology = pd.DataFrame(
{
DhConfig.holeid: ["DH001", "DH001", "DH002", "DH002", "DH003"],
DhConfig.sample_from: [0.0, 50.0, 0.0, 80.0, 0.0],
DhConfig.sample_to: [50.0, 150.0, 80.0, 200.0, 180.0],
"LITHO": ["Granite", "Schist", "Sandstone", "Shale", "Limestone"],
}
)

assays = pd.DataFrame(
{
DhConfig.holeid: ["DH001", "DH001", "DH002", "DH002", "DH003"],
DhConfig.sample_from: [0.0, 75.0, 0.0, 100.0, 0.0],
DhConfig.sample_to: [75.0, 150.0, 100.0, 200.0, 180.0],
"AU_ppm": [0.1, 2.5, 0.05, 1.2, 0.4],
}
)

db.add_interval_table("lithology", lithology)
db.add_interval_table("assays", assays)

# Line plot (numeric values)
db.plot_downhole("assays", "AU_ppm", kind="line", layout="grid", ncols=2)
plt.tight_layout()
plt.show()

# Categorical plot with shared legend on the right
db.plot_downhole("lithology", "LITHO", kind="categorical", step=2.0, layout="grid", ncols=2)
plt.tight_layout()
plt.show()

# Categorical plot with legend at bottom
db.plot_downhole("lithology", "LITHO", kind="categorical", step=2.0, layout="grid", ncols=2, legend_loc="bottom")
plt.tight_layout()
plt.show()

# Categorical plot without legend
db.plot_downhole("lithology", "LITHO", kind="categorical", step=2.0, layout="grid", ncols=2, show_legend=False)
plt.tight_layout()
plt.show()

# Create standalone legend
categories = lithology["LITHO"].unique()
DrillholeDatabase.create_categorical_legend(categories, cmap="tab20", title="Lithology")
plt.tight_layout()
plt.show()

# Image plot (numeric heatmap)
db.plot_downhole("assays", "AU_ppm", kind="image", step=2.0, layout="grid", ncols=2)
plt.tight_layout()
plt.show()
186 changes: 186 additions & 0 deletions loopresources/drillhole/drillhole.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,192 @@ def trace(self, step: float = 1.0) -> DrillHoleTrace:
"""
return DrillHoleTrace(self, interval=step)

def _downhole_depth_grid(
self, step: float, max_depth: Optional[float] = None
) -> np.ndarray:
if step <= 0:
raise ValueError("step must be > 0")
if max_depth is None:
max_depth = float(self.collar[DhConfig.total_depth].values[0])
if max_depth <= 0:
return np.array([], dtype=float)
return np.arange(0.0, max_depth + step, step)

def _sample_downhole_values(
self,
table_name: str,
column: str,
step: float,
Comment on lines +401 to +405
Copy link

Copilot AI Feb 26, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These new sampling helpers are core logic that the plotting APIs depend on, but there are no unit tests exercising their outputs for interval vs point tables (and categorical vs numeric data). Consider adding tests that assert the generated depth grid and sampled values (plus error cases like unknown tables / missing columns) to avoid relying on Matplotlib rendering in tests.

Copilot uses AI. Check for mistakes.
kind: str,
depth_grid: Optional[np.ndarray] = None,
):
if table_name in self.database.intervals:
table = self[table_name]
if table.empty:
return np.array([], dtype=float), np.array([], dtype=float)
if column not in table.columns:
raise KeyError(f"Column '{column}' not found in interval table '{table_name}'")
grid = depth_grid if depth_grid is not None else self._downhole_depth_grid(step)
from .resample import resample_interval

sampled = resample_interval(
pd.DataFrame({DhConfig.depth: grid}), table, [column], method="direct"
)
return sampled[DhConfig.depth].to_numpy(), sampled[column].to_numpy()

if table_name in self.database.points:
table = self[table_name]
if table.empty:
return np.array([], dtype=float), np.array([], dtype=float)
if column not in table.columns:
raise KeyError(f"Column '{column}' not found in point table '{table_name}'")
if kind == "line":
return table[DhConfig.depth].to_numpy(), table[column].to_numpy()

grid = depth_grid if depth_grid is not None else self._downhole_depth_grid(step)
values = np.array([None] * len(grid), dtype=object)
depths = table[DhConfig.depth].to_numpy()
for depth, value in zip(depths, table[column].to_numpy()):
if step <= 0:
continue
idx = int(np.round(depth / step))
if 0 <= idx < len(values):
values[idx] = value
return grid, values

raise KeyError(f"Table '{table_name}' not found in intervals or points")

def plot_downhole(
self,
table_name: str,
column: str,
kind: str = "line",
step: float = 1.0,
ax=None,
cmap: str = "tab20",
show_legend: bool = True,
**kwargs,
):
"""Plot a downhole variable as a line or categorical image.

Parameters
----------
table_name : str
Interval or point table name.
column : str
Column to plot.
kind : {"line", "categorical", "image"}
Plot style. Use "categorical" for discrete values and "image" for numeric heatmaps.
step : float, default 1.0
Sampling step (meters) for interval or categorical plots.
ax : matplotlib.axes.Axes, optional
Axes to plot on.
cmap : str, default "tab20"
Colormap name for categorical plots.
show_legend : bool, default True
Whether to show a legend.
**kwargs
Passed through to matplotlib plot functions.
"""
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches

kind = kind.lower()
if kind not in {"line", "categorical", "image"}:
raise ValueError("kind must be 'line', 'categorical', or 'image'")

if ax is None:
_, ax = plt.subplots(figsize=(4, 8))

depths, values = self._sample_downhole_values(table_name, column, step, kind)
if len(depths) == 0:
return ax

if kind == "line":
series = pd.to_numeric(pd.Series(values), errors="coerce")
mask = ~np.isnan(series.to_numpy())
if not mask.any():
return ax
ax.plot(series[mask], np.asarray(depths)[mask], label=self.hole_id, **kwargs)
ax.set_xlabel(column)
ax.set_ylabel("Depth")
ax.set_title(f"{self.hole_id} {column}")
ax.invert_yaxis()
if show_legend:
ax.legend()
return ax

if kind == "image":
series = pd.to_numeric(pd.Series(values), errors="coerce")
if series.isna().all():
return ax
data = np.ma.masked_invalid(series.to_numpy())[:, None]
max_depth = float(np.nanmax(depths)) if len(depths) else 0.0
im = ax.imshow(
data,
aspect="auto",
interpolation="nearest",
origin="upper",
extent=(0.0, 1.0, max_depth, 0.0),
cmap=cmap,
)
ax.set_xticks([0.5])
ax.set_xticklabels([self.hole_id])
ax.set_xlabel("Hole")
ax.set_ylabel("Depth")
ax.set_title(f"{column}")
if show_legend:
ax.figure.colorbar(im, ax=ax, label=column)
return ax

depth_values = np.asarray(values, dtype=object)
category_values = pd.Series(depth_values)
categories = [c for c in category_values.unique() if pd.notna(c)]
if not categories:
return ax
category_to_code = {cat: idx for idx, cat in enumerate(categories)}

codes = np.full(len(depth_values), -1.0)
for idx, value in enumerate(depth_values):
if pd.notna(value):
codes[idx] = category_to_code[value]

masked = np.ma.masked_where(codes < 0, codes)
cmap_obj = plt.get_cmap(cmap, len(categories))
try:
cmap_obj = cmap_obj.copy()
except Exception:
pass
try:
cmap_obj.set_bad(color="lightgray")
except Exception:
pass

max_depth = float(np.nanmax(depths)) if len(depths) else 0.0
ax.imshow(
masked[:, None],
aspect="auto",
interpolation="nearest",
origin="upper",
extent=(0.0, 1.0, max_depth, 0.0),
cmap=cmap_obj,
vmin=0,
vmax=max(0, len(categories) - 1),
)
ax.set_xticks([0.5])
ax.set_xticklabels([self.hole_id])
ax.set_xlabel("Hole")
ax.set_ylabel("Depth")
ax.set_title(f"{column}")

if show_legend:
handles = [
mpatches.Patch(color=cmap_obj(i), label=str(cat))
for i, cat in enumerate(categories)
]
ax.legend(handles=handles, title=column, bbox_to_anchor=(1.02, 1), loc="upper left")
return ax

def find_implicit_function_intersection(
self, function: Callable[[ArrayLike], ArrayLike], step: float = 1.0, intersection_value : float = 0.0
) -> pd.DataFrame:
Expand Down
Loading