Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Provide numpy-like constructors #23

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
86 changes: 48 additions & 38 deletions datarray/datarray.py
Expand Up @@ -425,6 +425,34 @@ def _names_to_numbers(axes, ax_ids):
return proc_ids


def _init_axes (dest, source, labels):
# XXX if an entry of labels is a tuple, it is interpreted
# as a (label, ticks) tuple
if labels is None:
if hasattr(source,'axes'):
_set_axes(dest, source.axes)
return
labels = []
elif len(labels) > dest.ndim:
raise NamedAxisError('labels list should have length <= array ndim')

labels = list(labels) + [None]*(dest.ndim - len(labels))
axes = []
for i, label_spec in enumerate(labels):
if type(label_spec) == type(()):
if len(label_spec) != 2:
raise ValueError(
'if the label specification is a tuple, it must be ' \
'of the form (label, ticks)'
)
label, ticks = label_spec
else:
label = label_spec
ticks = None
axes.append(Axis(label, i, dest, ticks=ticks))

_set_axes(dest, axes)
_validate_axes(axes)

def _validate_axes(axes):
"""
Expand Down Expand Up @@ -565,51 +593,33 @@ def runs_op(*args, **kwargs):
runs_op.func_name = opname
runs_op.func_doc = super_op.__doc__
return runs_op



class DataArray(np.ndarray):
def array (data, labels=None, dtype=None, copy=True, order=None, ndmin=True):
# XXX accepting argument 'subok' does not make sense in here
res = np.array(data, dtype=dtype, copy=False, order=order,
ndmin=ndmin).view(type=DataArray)

_init_axes(res, data, labels)

if copy:
# TODO: still raises ValueError when 'resize' is called
new_res = res.copy()
del res
return new_res

return res

class DataArray(np.ndarray):
# XXX- we need to figure out where in the numpy C code .T is defined!
@property
def T(self):
return self.transpose()

def __new__(cls, data, labels=None, dtype=None, copy=False):
# XXX if an entry of labels is a tuple, it is interpreted
# as a (label, ticks) tuple
# Ensure the output is an array of the proper type
arr = np.array(data, dtype=dtype, copy=copy).view(cls)
if labels is None:
if hasattr(data,'axes'):
_set_axes(arr, data.axes)
return arr
labels = []
elif len(labels) > arr.ndim:
raise NamedAxisError('labels list should have length <= array ndim')

labels = list(labels) + [None]*(arr.ndim - len(labels))
axes = []
for i, label_spec in enumerate(labels):
if type(label_spec) == type(()):
if len(label_spec) != 2:
raise ValueError(
'if the label specification is a tuple, it must be ' \
'of the form (label, ticks)'
)
label, ticks = label_spec
else:
label = label_spec
ticks = None
axes.append(Axis(label, i, arr, ticks=ticks))

_set_axes(arr, axes)

# validate the axes
_validate_axes(axes)


return arr
def __new__(cls, shape, labels=None, dtype=float, buffer=None, offset=0, strides=None, order=None):
res = np.ndarray.__new__(cls, shape, dtype=dtype, buffer=buffer,
offset=offset, strides=strides, order=order)
_init_axes(res, None, labels)
return res

@property
def aix(self):
Expand Down