Skip to content

Commit

Permalink
Field: add 'dim' property that's identical to model.field_dim
Browse files Browse the repository at this point in the history
  • Loading branch information
MuellerSeb committed Feb 5, 2021
1 parent f305995 commit 9b55002
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 24 deletions.
12 changes: 8 additions & 4 deletions gstools/field/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,19 +154,18 @@ def pre_pos(self, pos, mesh_type="unstructured"):
"""
# save mesh-type
self.mesh_type = mesh_type
dim = self.model.field_dim
# save pos tuple
if mesh_type != "unstructured":
pos, shape = format_struct_pos_dim(pos, dim)
pos, shape = format_struct_pos_dim(pos, self.dim)
self.pos = pos
pos = gen_mesh(pos)
else:
pos = np.array(pos, dtype=np.double).reshape(dim, -1)
pos = np.array(pos, dtype=np.double).reshape(self.dim, -1)
self.pos = pos
shape = np.shape(pos[0])
# prepend dimension if we have a vector field
if self.value_type == "vector":
shape = (self.model.dim,) + shape
shape = (self.dim,) + shape
if self.model.latlon:
raise ValueError("Field: Vector fields not allowed for latlon")
# return isometrized pos tuple and resulting field shape
Expand Down Expand Up @@ -352,6 +351,11 @@ def value_type(self, value_type):
raise ValueError("Field: value type not in {}".format(VALUE_TYPES))
self._value_type = value_type

@property
def dim(self):
""":class:`int`: Dimension of the field."""
return self.model.field_dim

@property
def name(self):
""":class:`str`: The name of the class."""
Expand Down
6 changes: 3 additions & 3 deletions gstools/field/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,13 @@ def plot_field(fld, field="field", fig=None, ax=None): # pragma: no cover
"""
plot_field = getattr(fld, field)
assert not (fld.pos is None or plot_field is None)
if fld.model.field_dim == 1:
if fld.dim == 1:
ax = _plot_1d(fld.pos, plot_field, fig, ax)
elif fld.model.field_dim == 2:
elif fld.dim == 2:
ax = _plot_2d(
fld.pos, plot_field, fld.mesh_type, fig, ax, fld.model.latlon
)
elif fld.model.field_dim == 3:
elif fld.dim == 3:
ax = _plot_3d(fld.pos, plot_field, fld.mesh_type, fig, ax)
else:
raise ValueError("Field.plot: only possible for dim=1,2,3!")
Expand Down
12 changes: 6 additions & 6 deletions gstools/field/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,15 +154,15 @@ def mesh_call(
pass

if isinstance(direction, str) and direction == "all":
select = list(range(f_cls.model.field_dim))
select = list(range(f_cls.dim))
elif isinstance(direction, str):
select = _get_select(direction)[: f_cls.model.field_dim]
select = _get_select(direction)[: f_cls.dim]
else:
select = direction[: f_cls.model.field_dim]
if len(select) < f_cls.model.field_dim:
select = direction[: f_cls.dim]
if len(select) < f_cls.dim:
raise ValueError(
"Field.mesh: need at least {} direction(s), got '{}'".format(
f_cls.model.field_dim, direction
f_cls.dim, direction
)
)
# convert pyvista mesh
Expand Down Expand Up @@ -190,7 +190,7 @@ def mesh_call(
offset = []
length = []
mesh_dim = mesh.points.shape[1]
if mesh_dim < f_cls.model.field_dim:
if mesh_dim < f_cls.dim:
raise ValueError("Field.mesh: mesh dimension too low!")
pnts = np.empty((0, mesh_dim), dtype=np.double)
for cell in mesh.cells:
Expand Down
21 changes: 10 additions & 11 deletions gstools/krige/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,13 +397,16 @@ def _get_dists(self, pos1, pos2=None, pos2_slice=(0, None)):
def get_mean(self, post_process=True):
"""Calculate the estimated mean of the detrended field.
Parameters
----------
post_process : :class:`bool`, optional
Whether to apply field-mean and normalizer.
Default: `True`
Returns
-------
mean : :class:`float` or :any:`None`
Mean of the Kriging System.
post_process : :class:`bool`, optional
Whether to apply field-mean and normalizer.
Default: `True`
Notes
-----
Expand Down Expand Up @@ -491,7 +494,7 @@ def set_condition(
raise ValueError("Krige.set_condition: missing cond_pos/cond_val.")
# correctly format cond_pos and cond_val
self._cond_pos, self._cond_val = set_condition(
cond_pos, cond_val, self.model.field_dim
cond_pos, cond_val, self.dim
)
if fit_normalizer: # fit normalizer to detrended data
self.normalizer.fit(self.cond_val - self.cond_trend)
Expand Down Expand Up @@ -544,7 +547,7 @@ def set_drift_functions(self, drift_functions=None):
self._drift_functions = []
elif isinstance(drift_functions, (str, int)):
self._drift_functions = get_drift_functions(
self.model.field_dim, drift_functions
self.dim, drift_functions
)
else:
if isinstance(drift_functions, collections.abc.Iterator):
Expand Down Expand Up @@ -623,16 +626,12 @@ def cond_ext_drift(self):
@property
def cond_mean(self):
""":class:`numpy.ndarray`: Trend at the conditions."""
return eval_func(
self.mean, self.cond_pos, self.model.field_dim, broadcast=True
)
return eval_func(self.mean, self.cond_pos, self.dim, broadcast=True)

@property
def cond_trend(self):
""":class:`numpy.ndarray`: Trend at the conditions."""
return eval_func(
self.trend, self.cond_pos, self.model.field_dim, broadcast=True
)
return eval_func(self.trend, self.cond_pos, self.dim, broadcast=True)

@property
def unbiased(self):
Expand Down

0 comments on commit 9b55002

Please sign in to comment.