Skip to content

Commit

Permalink
Refactor the _load_remote_dataset function to load tiled and non-tile…
Browse files Browse the repository at this point in the history
…d grids in a consistent way (#3120)
  • Loading branch information
seisman committed Apr 22, 2024
1 parent 44f44d3 commit 8b2a74c
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 29 deletions.
66 changes: 40 additions & 26 deletions pygmt/datasets/load_remote_dataset.py
Expand Up @@ -4,12 +4,12 @@

from __future__ import annotations

from typing import TYPE_CHECKING, ClassVar, NamedTuple
from typing import TYPE_CHECKING, ClassVar, Literal, NamedTuple

from pygmt.clib import Session
from pygmt.exceptions import GMTInvalidInput
from pygmt.helpers import kwargs_to_strings
from pygmt.io import load_dataarray
from pygmt.src import grdcut, which
from pygmt.helpers import build_arg_list, kwargs_to_strings
from pygmt.src import which

if TYPE_CHECKING:
import xarray as xr
Expand Down Expand Up @@ -344,7 +344,7 @@ def _load_remote_dataset(
dataset_prefix: str,
resolution: str,
region: str | list,
registration: str,
registration: Literal["gridline", "pixel", None],
) -> xr.DataArray:
r"""
Load GMT remote datasets.
Expand All @@ -370,54 +370,68 @@ def _load_remote_dataset(
Returns
-------
grid : :class:`xarray.DataArray`
grid
The GMT remote dataset grid.
Note
----
The returned :class:`xarray.DataArray` doesn't support slice operation for tiled
grids.
The registration and coordinate system type of the returned
:class:`xarray.DataArray` grid can be accessed via the GMT accessors (i.e.,
``grid.gmt.registration`` and ``grid.gmt.gtype`` respectively). However, these
properties may be lost after specific grid operations (such as slicing) and will
need to be manually set before passing the grid to any PyGMT data processing or
plotting functions. Refer to :class:`pygmt.GMTDataArrayAccessor` for detailed
explanations and workarounds.
"""
dataset = datasets[dataset_name]

# Check resolution
if resolution not in dataset.resolutions:
raise GMTInvalidInput(
f"Invalid resolution '{resolution}' for {dataset.title} dataset. "
f"Available resolutions are: {', '.join(dataset.resolutions)}."
)
resinfo = dataset.resolutions[resolution]

# check registration
valid_registrations = dataset.resolutions[resolution].registrations
# Check registration
if registration is None:
# use gridline registration unless only pixel registration is available
registration = "gridline" if "gridline" in valid_registrations else "pixel"
# Use gridline registration unless only pixel registration is available
registration = "gridline" if "gridline" in resinfo.registrations else "pixel"
elif registration in ("pixel", "gridline"):
if registration not in valid_registrations:
if registration not in resinfo.registrations:
raise GMTInvalidInput(
f"{registration} registration is not available for the "
f"{resolution} {dataset.title} dataset. Only "
f"{valid_registrations[0]} registration is available."
f"{resinfo.registrations[0]} registration is available."
)
else:
raise GMTInvalidInput(
f"Invalid grid registration: '{registration}', should be either 'pixel', "
"'gridline' or None. Default is None, where a gridline-registered grid is "
"returned unless only the pixel-registered grid is available."
)
reg = f"_{registration[0]}"

# different ways to load tiled and non-tiled grids.
# Known issue: tiled grids don't support slice operation
# See https://github.com/GenericMappingTools/pygmt/issues/524
if region is None:
if dataset.resolutions[resolution].tiled:
raise GMTInvalidInput(
f"'region' is required for {dataset.title} resolution '{resolution}'."
fname = f"@{dataset_prefix}{resolution}_{registration[0]}"
if resinfo.tiled and region is None:
raise GMTInvalidInput(
f"'region' is required for {dataset.title} resolution '{resolution}'."
)

# Currently, only grids are supported. Will support images in the future.
kwdict = {"T": "g", "R": region} # region can be None
with Session() as lib:
with lib.virtualfile_out(kind="grid") as voutgrd:
lib.call_module(
module="read",
args=[fname, voutgrd, *build_arg_list(kwdict)],
)
fname = which(f"@{dataset_prefix}{resolution}{reg}", download="a")
grid = load_dataarray(fname, engine="netcdf4")
else:
grid = grdcut(f"@{dataset_prefix}{resolution}{reg}", region=region)
grid = lib.virtualfile_to_raster(outgrid=None, vfname=voutgrd)

# Full path to the grid if not tiled grids.
source = which(fname, download="a") if not resinfo.tiled else None
# Manually add source to xarray.DataArray encoding to make the GMT accessors work.
if source:
grid.encoding["source"] = source

# Add some metadata to the grid
grid.name = dataset.name
Expand Down
5 changes: 2 additions & 3 deletions pygmt/tests/test_accessor.py
Expand Up @@ -115,9 +115,8 @@ def test_accessor_grid_source_file_not_exist():
# Registration and gtype are correct
assert grid.gmt.registration == 1
assert grid.gmt.gtype == 1
# The source grid file is defined but doesn't exist
assert grid.encoding["source"].endswith(".nc")
assert not Path(grid.encoding["source"]).exists()
# The source grid file is undefined.
assert grid.encoding.get("source") is None

# For a sliced grid, fallback to default registration and gtype,
# because the source grid file doesn't exist.
Expand Down

0 comments on commit 8b2a74c

Please sign in to comment.