-
Notifications
You must be signed in to change notification settings - Fork 15
/
patch.py
297 lines (255 loc) · 9.7 KB
/
patch.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
"""
A 2D trace object.
"""
from __future__ import annotations
from typing import Callable, Mapping, Optional, Sequence
import numpy as np
import pandas as pd
from numpy.typing import ArrayLike
from xarray import DataArray
import dascore.proc
from dascore.constants import PatchType
from dascore.core.schema import PatchAttrs
from dascore.io import PatchIO
from dascore.transform import TransformPatchNameSpace
from dascore.utils.coords import Coords, assign_coords
# from dascore.utils.mapping import FrozenDict
from dascore.utils.patch import _AttrsCoordsMixer
from dascore.viz import VizPatchNameSpace
class Patch:
    """
    A Class for managing data and metadata.

    Parameters
    ----------
    data
        An array-like containing data, an xarray DataArray object, or a Patch.
    coords
        The coordinates, or dimensional labels for the data. These can be
        passed in three forms:
            {coord_name: data}
            {coord_name: ((dimensions,), data)}
            {coord_name: (dimensions, data)}
    dims
        A sequence of dimension strings. The first entry corresponds to the
        first axis of data, the second to the second dimension, and so on.
    attrs
        Optional attributes (non-coordinate metadata) passed as a dict.

    Notes
    -----
    Unless data is a DataArray or Patch, data, coords, and dims are required.
    """

    data: ArrayLike
    coords: Mapping[str, ArrayLike]
    dims: tuple[str, ...]
    attrs: PatchAttrs

    def __init__(
        self,
        data: ArrayLike | DataArray | Patch | None = None,
        coords: Mapping[str, ArrayLike] | None = None,
        dims: Sequence[str] | None = None,
        attrs: Optional[Mapping] = None,
    ):
        """
        Initialize the patch; see the class docstring for parameters.

        Raises
        ------
        ValueError
            If some, but not all, of data, coords, and dims are defined
            (after trying to derive coords from attrs).
        """
        # A DataArray (or the DataArray held by another Patch) is stored
        # as-is, without copying; coords, dims, and attrs are ignored here.
        if isinstance(data, (DataArray, self.__class__)):
            dar = data if isinstance(data, DataArray) else data._data_array
            self._data_array = dar
            return
        # Try to generate coords from ranges in attrs
        if coords is None and attrs is not None:
            coords = PatchAttrs(**dict(attrs)).coords_from_dims()
        non_attrs = [x is None for x in [data, coords, dims]]
        # Either all three are given or none are; a partial set is an error.
        if any(non_attrs) and not all(non_attrs):
            msg = "data, coords, and dims must be defined to init Patch."
            raise ValueError(msg)
        mixer = _AttrsCoordsMixer(attrs, coords, dims)
        attrs, coords = mixer()
        # get xarray coords from custom coords object
        xr_coords = coords.to_nested_dict()
        self._data_array = DataArray(
            data=data, dims=dims, coords=xr_coords, attrs=attrs
        )

    # NOTE: defining __eq__ without __hash__ makes instances unhashable
    # (standard Python data-model behavior).
    def __eq__(self, other):
        """
        Compare one Patch to another using :meth:`equals`.

        Parameters
        ----------
        other
            The object to compare against.

        Returns
        -------
        The result of ``self.equals(other)`` when other is a Patch, else
        ``NotImplemented`` so Python can fall back to its default handling
        (which makes ``patch == non_patch`` evaluate to False rather than
        raising).
        """
        if not isinstance(other, Patch):
            return NotImplemented
        return self.equals(other)

    def __str__(self):
        """Return the DataArray's string form, relabeled with this class."""
        xarray_str = str(self._data_array)
        class_name = self.__class__.__name__
        return xarray_str.replace("xarray.DataArray", f"dascore.{class_name}")

    __repr__ = __str__

    def equals(self, other: PatchType, only_required_attrs=True) -> bool:
        """
        Determine if the current patch equals the other patch.

        Attributes, coordinate names and values, and data are compared, in
        that order. Patches that differ only by the order of their
        dimensions are considered equal.

        Parameters
        ----------
        other
            A Patch object. Anything else is never equal.
        only_required_attrs
            If True, only compare required attributes. This helps avoid issues
            with comparing histories or custom attrs of patches, for example.
        """
        if not isinstance(other, Patch):
            return False
        # The attrs/coords properties build a new object on every access, so
        # fetch each one once instead of once per key.
        self_attrs, other_attrs = self.attrs, other.attrs
        if only_required_attrs:
            attrs_to_compare = set(PatchAttrs.get_defaults()) - {"history"}
            attrs1 = {x: self_attrs.get(x, None) for x in attrs_to_compare}
            attrs2 = {x: other_attrs.get(x, None) for x in attrs_to_compare}
        else:
            attrs1, attrs2 = dict(self_attrs), dict(other_attrs)
        if set(attrs1) != set(attrs2):  # attrs don't have same keys; not equal
            return False
        if attrs1 != attrs2:
            # see if some values are NaNs, these should be counted equal
            not_equal = {
                x
                for x in attrs1
                if attrs1[x] != attrs2[x]
                and not (pd.isnull(attrs1[x]) and pd.isnull(attrs2[x]))
            }
            if not_equal:
                return False
        # check coords, names and values
        self_coords, other_coords = self.coords, other.coords
        coord1 = {x: self_coords[x] for x in self_coords}
        coord2 = {x: other_coords[x] for x in other_coords}
        if set(coord2) != set(coord1):
            return False
        for name in coord1:
            if not np.all(coord1[name] == coord2[name]):
                return False
        # handle transposed case; patches that are identical but transposed
        # should still be equal.
        if self.dims != other.dims:
            if set(self.dims) != set(other.dims):
                # Different dimension names can never describe equal patches.
                return False
            other = other.transpose(*self.dims)
        # Guard against np.equal raising on, or broadcasting across,
        # mismatched shapes.
        if self.shape != other.shape:
            return False
        return bool(np.equal(self.data, other.data).all())

    def new(
        self: PatchType,
        data: None | ArrayLike = None,
        coords: None | dict[str | Sequence[str], ArrayLike] = None,
        dims: None | Sequence[str] = None,
        attrs: None | Mapping = None,
    ) -> PatchType:
        """
        Return a copy of the Patch with updated data, coords, dims, or attrs.

        Any argument left as None falls back to the value on this patch,
        except that when coords is given without dims, the dims are taken
        from the keys of coords.

        Parameters
        ----------
        data
            An array-like containing data, an xarray DataArray object, or a Patch.
        coords
            The coordinates, or dimensional labels for the data. These can be
            passed in three forms:
                {coord_name: data}
                {coord_name: ((dimensions,), data)}
                {coord_name: (dimensions, data)}
        dims
            A sequence of dimension strings. The first entry corresponds to the
            first axis of data, the second to the second dimension, and so on.
        attrs
            Optional attributes (non-coordinate metadata) passed as a dict.
        """
        data = data if data is not None else self.data
        attrs = attrs if attrs is not None else self.attrs
        if coords is None:
            current_coords = self.coords
            coords = getattr(current_coords, "_coords", current_coords)
            # Honor explicitly passed dims rather than silently dropping them.
            dims = self.dims if dims is None else dims
        else:
            dims = dims or list(coords)
        return self.__class__(data=data, coords=coords, attrs=attrs, dims=dims)

    def update_attrs(self: PatchType, **attrs) -> PatchType:
        """
        Update attrs and return a new Patch.

        The current attrs, coords, and dims are passed through
        ``_AttrsCoordsMixer`` together with the updates, and the attrs and
        coords it returns are used to build the new patch.

        Parameters
        ----------
        **attrs
            attrs to add/update.
        """
        dar = self._data_array
        mixer = _AttrsCoordsMixer(dar.attrs, dar.coords, dar.dims)
        mixer.update_attrs(**attrs)
        attrs, coords = mixer()
        return self.__class__(self.data, coords=coords, attrs=attrs, dims=self.dims)

    @property
    def data(self):
        """Return the data array."""
        return self._data_array.data

    @property
    def coords(self):
        """Return a dict of coordinate data {coord_name: data}"""
        return Coords(self._data_array.coords, dims=self.dims).array_dict

    @property
    def coord_dims(self):
        """Return a dict of coordinate dimensions {coord_name: (**dims)}"""
        return Coords(self._data_array.coords, dims=self.dims).dims_dict

    @property
    def dims(self) -> tuple[str, ...]:
        """Return the dimensions contained in patch."""
        return self._data_array.dims

    @property
    def attrs(self) -> PatchAttrs:
        """Return the attributes of the patch as a new PatchAttrs instance."""
        return PatchAttrs(**self._data_array.attrs)

    @property
    def shape(self) -> tuple[int, ...]:
        """Return the shape of the data array."""
        return self._data_array.shape

    def to_xarray(self):
        """
        Return a data array with patch contents.
        """
        # Note this is here in case we decide to remove xarray there will
        # still be a way to get a DataArray object with an optional import
        return self._data_array

    # --- functions from dascore.proc bound as methods
    squeeze = dascore.proc.squeeze
    rename = dascore.proc.rename
    transpose = dascore.proc.transpose

    # --- processing funcs
    select = dascore.proc.select
    decimate = dascore.proc.decimate
    detrend = dascore.proc.detrend
    pass_filter = dascore.proc.pass_filter
    sobel_filter = dascore.proc.sobel_filter
    median_filter = dascore.proc.median_filter
    aggregate = dascore.proc.aggregate
    abs = dascore.proc.abs
    resample = dascore.proc.resample
    iresample = dascore.proc.iresample
    interpolate = dascore.proc.interpolate
    normalize = dascore.proc.normalize
    standardize = dascore.proc.standardize
    taper = dascore.proc.taper

    # --- Method Namespaces
    # Note: these can't be cached_property (from functools) or references
    # to self stick around and keep large arrays in memory.

    @property
    def viz(self) -> VizPatchNameSpace:
        """The visualization namespace."""
        return VizPatchNameSpace(self)

    @property
    def tran(self) -> TransformPatchNameSpace:
        """The transformation namespace."""
        return TransformPatchNameSpace(self)

    @property
    def io(self) -> PatchIO:
        """Return a patch IO object for saving patches to various formats."""
        return PatchIO(self)

    def pipe(self, func: Callable[..., "Patch"], *args, **kwargs) -> "Patch":
        """
        Pipe the patch to a function.

        This is primarily useful for maintaining a chain of patch calls for
        a function.

        Parameters
        ----------
        func
            The function to pipe the patch. It must take a patch instance as
            the first argument followed by any number of positional or keyword
            arguments, then return a patch.
        *args
            Positional arguments that get passed to func.
        **kwargs
            Keyword arguments passed to func.
        """
        return func(self, *args, **kwargs)

    # Bind assign_coords as method
    assign_coords = assign_coords