Labels: array-types, bug, enhancement
After upgrading to the latest xarray version and installing flox, I find that chunked pint arrays break with the .resample() method. I'm posting this here instead of pint_xarray since, from the traceback, it looks like this is coming from flox.
I imagine this has to do with the complexity of working with duck arrays like those from pint_xarray.
Possible related threads:
import xarray as xr
import pint_xarray
import flox
xr.__version__
>>> '2022.9.0'
pint_xarray.__version__
>>> '0.3'
flox.__version__
>>> '0.5.9'
time_ax = xr.cftime_range('2020-06-01 01:00:00', freq='H', periods=3)
ds = xr.DataArray(range(3), dims='time', coords={'time': time_ax})
# Simple case, no dask or pint
ds.resample(time="D").mean()
>>> <xarray.DataArray (time: 1)>
>>> array([1.])
>>> Coordinates:
>>> * time (time) object 2020-06-01 00:00:00
# Dask case
ds_chunked = ds.chunk({'time': 1})
ds_chunked.resample(time="D").mean().compute()
>>> <xarray.DataArray (time: 1)>
>>> array([1.])
>>> Coordinates:
>>> * time (time) object 2020-06-01 00:00:00
# Pint case
ds_pint = ds.pint.quantify('kelvin')
ds_pint.resample(time="D").mean()
>>> <xarray.DataArray (time: 1)>
>>> <Quantity([1.], 'kelvin')>
>>> Coordinates:
>>> * time (time) object 2020-06-01 00:00:00
# Pint with xarray chunk
ds_pint_chunk = ds_pint.chunk({'time': 1})
ds_pint_chunk.resample(time="D").mean().compute()
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Input In [25], in <cell line: 3>()
1 # Pint with xarray chunk
2 ds_pint_chunk = ds_pint.chunk({'time': 1})
----> 3 ds_pint_chunk.resample(time="D").mean().compute()
File ~/miniconda3/envs/analysis_py39/lib/python3.9/site-packages/xarray/core/dataarray.py:1083, in DataArray.compute(self, **kwargs)
1064 """Manually trigger loading of this array's data from disk or a
1065 remote source into memory and return a new array. The original is
1066 left unaltered.
(...)
1080 dask.compute
1081 """
1082 new = self.copy(deep=False)
-> 1083 return new.load(**kwargs)
File ~/miniconda3/envs/analysis_py39/lib/python3.9/site-packages/xarray/core/dataarray.py:1057, in DataArray.load(self, **kwargs)
1039 def load(self: T_DataArray, **kwargs) -> T_DataArray:
1040 """Manually trigger loading of this array's data from disk or a
1041 remote source into memory and return this array.
1042
(...)
1055 dask.compute
1056 """
-> 1057 ds = self._to_temp_dataset().load(**kwargs)
1058 new = self._from_temp_dataset(ds)
1059 self._variable = new._variable
File ~/miniconda3/envs/analysis_py39/lib/python3.9/site-packages/xarray/core/dataset.py:734, in Dataset.load(self, **kwargs)
731 import dask.array as da
733 # evaluate all the dask arrays simultaneously
--> 734 evaluated_data = da.compute(*lazy_data.values(), **kwargs)
736 for k, data in zip(lazy_data, evaluated_data):
737 self.variables[k].data = data
File ~/miniconda3/envs/analysis_py39/lib/python3.9/site-packages/dask/base.py:600, in compute(traverse, optimize_graph, scheduler, get, *args, **kwargs)
597 keys.append(x.__dask_keys__())
598 postcomputes.append(x.__dask_postcompute__())
--> 600 results = schedule(dsk, keys, **kwargs)
601 return repack([f(r, *a) for r, (f, a) in zip(results, postcomputes)])
File ~/miniconda3/envs/analysis_py39/lib/python3.9/site-packages/dask/threaded.py:89, in get(dsk, keys, cache, num_workers, pool, **kwargs)
86 elif isinstance(pool, multiprocessing.pool.Pool):
87 pool = MultiprocessingPoolExecutor(pool)
---> 89 results = get_async(
90 pool.submit,
91 pool._max_workers,
92 dsk,
93 keys,
94 cache=cache,
95 get_id=_thread_get_id,
96 pack_exception=pack_exception,
97 **kwargs,
98 )
100 # Cleanup pools associated to dead threads
101 with pools_lock:
File ~/miniconda3/envs/analysis_py39/lib/python3.9/site-packages/dask/local.py:511, in get_async(submit, num_workers, dsk, result, cache, get_id, rerun_exceptions_locally, pack_exception, raise_exception, callbacks, dumps, loads, chunksize, **kwargs)
509 _execute_task(task, data) # Re-execute locally
510 else:
--> 511 raise_exception(exc, tb)
512 res, worker_id = loads(res_info)
513 state["cache"][key] = res
File ~/miniconda3/envs/analysis_py39/lib/python3.9/site-packages/dask/local.py:319, in reraise(exc, tb)
317 if exc.__traceback__ is not tb:
318 raise exc.with_traceback(tb)
--> 319 raise exc
File ~/miniconda3/envs/analysis_py39/lib/python3.9/site-packages/dask/local.py:224, in execute_task(key, task_info, dumps, loads, get_id, pack_exception)
222 try:
223 task, data = loads(task_info)
--> 224 result = _execute_task(task, data)
225 id = get_id()
226 result = dumps((result, id))
File ~/miniconda3/envs/analysis_py39/lib/python3.9/site-packages/dask/core.py:119, in _execute_task(arg, cache, dsk)
115 func, args = arg[0], arg[1:]
116 # Note: Don't assign the subtask results to a variable. numpy detects
117 # temporaries by their reference count and can execute certain
118 # operations in-place.
--> 119 return func(*(_execute_task(a, cache) for a in args))
120 elif not ishashable(arg):
121 return arg
File ~/miniconda3/envs/analysis_py39/lib/python3.9/site-packages/dask/optimization.py:990, in SubgraphCallable.__call__(self, *args)
988 if not len(args) == len(self.inkeys):
989 raise ValueError("Expected %d args, got %d" % (len(self.inkeys), len(args)))
--> 990 return core.get(self.dsk, self.outkey, dict(zip(self.inkeys, args)))
File ~/miniconda3/envs/analysis_py39/lib/python3.9/site-packages/dask/core.py:149, in get(dsk, out, cache)
147 for key in toposort(dsk):
148 task = dsk[key]
--> 149 result = _execute_task(task, cache)
150 cache[key] = result
151 result = _execute_task(out, cache)
File ~/miniconda3/envs/analysis_py39/lib/python3.9/site-packages/dask/core.py:119, in _execute_task(arg, cache, dsk)
115 func, args = arg[0], arg[1:]
116 # Note: Don't assign the subtask results to a variable. numpy detects
117 # temporaries by their reference count and can execute certain
118 # operations in-place.
--> 119 return func(*(_execute_task(a, cache) for a in args))
120 elif not ishashable(arg):
121 return arg
File ~/miniconda3/envs/analysis_py39/lib/python3.9/site-packages/toolz/functoolz.py:487, in Compose.__call__(self, *args, **kwargs)
486 def __call__(self, *args, **kwargs):
--> 487 ret = self.first(*args, **kwargs)
488 for f in self.funcs:
489 ret = f(ret)
File ~/miniconda3/envs/analysis_py39/lib/python3.9/site-packages/flox/core.py:689, in chunk_reduce(array, by, func, expected_groups, axis, fill_value, dtype, reindex, engine, kwargs, sort)
687 result = reduction(group_idx, array, **kwargs)
688 else:
--> 689 result = generic_aggregate(
690 group_idx, array, axis=-1, engine=engine, func=reduction, **kwargs
691 ).astype(dt, copy=False)
692 if np.any(props.nanmask):
693 # remove NaN group label which should be last
694 result = result[..., :-1]
File ~/miniconda3/envs/analysis_py39/lib/python3.9/site-packages/flox/aggregations.py:49, in generic_aggregate(group_idx, array, engine, func, axis, size, fill_value, dtype, **kwargs)
44 else:
45 raise ValueError(
46 f"Expected engine to be one of ['flox', 'numpy', 'numba']. Received {engine} instead."
47 )
---> 49 return method(
50 group_idx, array, axis=axis, size=size, fill_value=fill_value, dtype=dtype, **kwargs
51 )
File ~/miniconda3/envs/analysis_py39/lib/python3.9/site-packages/flox/aggregate_flox.py:33, in _np_grouped_op(group_idx, array, op, axis, size, fill_value, dtype, out)
26 out = np.full(array.shape[:-1] + (size,), fill_value=fill_value, dtype=dtype)
28 if (len(uniques) == size) and (uniques == np.arange(size)).all():
29 # The previous version of this if condition
30 # ((uniques[1:] - uniques[:-1]) == 1).all():
31 # does not work when group_idx is [1, 2] for e.g.
32 # This happens during binning
---> 33 op.reduceat(array, inv_idx, axis=axis, dtype=dtype, out=out)
34 else:
35 out[..., uniques] = op.reduceat(array, inv_idx, axis=axis, dtype=dtype)
TypeError: operand type(s) all returned NotImplemented from __array_ufunc__(<ufunc 'add'>, 'reduceat', <Quantity([0], 'kelvin')>, array([0]), axis=-1, dtype=dtype('int64'), out=(array([0]),)): 'Quantity', 'ndarray', 'ndarray'
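For what it's worth, the failing call at the bottom of the traceback reproduces with a bare pint Quantity, with no xarray or dask involved. A minimal sketch:
import numpy as np
import pint

ureg = pint.UnitRegistry()
q = ureg.Quantity(np.arange(3, dtype=float), 'kelvin')

# flox's aggregation path ends up calling ufunc.reduceat on the data;
# pint's Quantity does not implement reduceat, so __array_ufunc__ returns
# NotImplemented and numpy raises the TypeError shown above.
np.add.reduceat(q, np.array([0]))  # raises TypeError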
dcherian commented on Oct 6, 2022
OK, that won't work, but it should not have gone down this code path in xarray at all. It looks like I only tested pure pint arrays, not pint + dask:
https://github.com/pydata/xarray/blob/50ea159bfd0872635ebf4281e741f3c87f0bef6b/xarray/core/utils.py#L980
It'd be nice to add full pint support here, but it'll be a bit of effort. Are you interested in working on it?
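For context, a chunked pint array here means a Quantity wrapping a dask array, so any dispatch check has to look through the Quantity to the underlying data. A hypothetical sketch of such a check (not the actual xarray utility linked above):
import dask.array
import pint

def is_chunked_pint_array(x):
    # hypothetical check: a pint Quantity whose magnitude is a dask array,
    # i.e. the pint + dask combination that slipped through the code path above
    return isinstance(x, pint.Quantity) and isinstance(x.magnitude, dask.array.Array)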
riley-brady commented on Nov 4, 2022
@dcherian sorry for the delay here. I could work on this effort, but unfortunately only on weekends, so it might be a long process. I would appreciate some guidance, if you have some time (either over chat here or a Zoom call), on which parts of the code to target, since I haven't worked closely with the package.
My current solution is to dequantify, run resample() (or whichever other method this is happening on), and then quantify again, which isn't ideal but works. The error message is not super clear, so I'm not sure that's a sustainable solution for the community as a whole.
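A minimal sketch of that workaround, reusing ds_pint from the example above (keep_attrs=True is an assumption here, so the units attribute survives the reduction):
# strip units into attrs, aggregate, then re-attach units
stripped = ds_pint.chunk({'time': 1}).pint.dequantify()
result = stripped.resample(time="D").mean(keep_attrs=True).compute()
result = result.pint.quantify()  # units restored from the attrs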
dcherian commented on Nov 5, 2022
Thanks for offering to help @riley-brady
I think this is what we'll have to do, since pint's support for ufuncs apparently isn't great. Let's strip array units, if any, right at the beginning and reapply them at the end:
flox/core.py, line 1641 (at e3ea0e7)
This won't handle units on by, but I think that's OK for now? Alternatively you could again dequantify and then quantify after compute.
getattr(numpy, agg.name)(Quantity([1, 1], dtype=array.dtype, units=array.units))
So basically, run the aggregation on a small problem to determine what the output units are (necessary for any, all, var, arg*, for example), and apply them at the end. This approach won't work for "custom aggregations" but we can deal with that later when we need to.
Linked: Pint array: strip and reattach appropriate units
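A rough sketch of that units-inference idea (the helper name and registry here are hypothetical, not flox API):
import numpy as np
import pint

ureg = pint.UnitRegistry()

def infer_output_units(agg_name, units):
    # hypothetical helper: run the numpy aggregation on a tiny two-element
    # Quantity to learn which units the result carries
    probe = ureg.Quantity(np.array([1.0, 1.0]), units)
    result = getattr(np, agg_name)(probe)
    return getattr(result, 'units', None)

infer_output_units('var', 'kelvin')     # kelvin ** 2
infer_output_units('mean', 'kelvin')    # kelvin
infer_output_units('argmax', 'kelvin')  # None: plain integer result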