Skip to content

Commit

Permalink
Merge pull request #421 from yarikoptic/nf-savez
Browse files Browse the repository at this point in the history
NF: to/from_npz for Datasets
  • Loading branch information
mih committed Jan 27, 2016
2 parents 1e727b9 + 0f6f9ca commit 8ee99be
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 0 deletions.
69 changes: 69 additions & 0 deletions mvpa2/base/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

__docformat__ = 'restructuredtext'

from os.path import lexists
import numpy as np
import copy

Expand Down Expand Up @@ -603,6 +604,74 @@ def from_hdf5(cls, source, name=None):
hdf.close()
return res

def to_npz(self, filename, compress=True):
"""Save dataset to a .npz file storing all fa/sa/a which are ndarrays
Parameters
----------
filename : str
compress : bool, optional
If True, savez_compressed is used
"""
savez = np.savez_compressed if compress else np.savez
if not filename.endswith('.npz'):
filename += '.npz'
entries = {'samples': self.samples}
skipped = []
for c in ('a', 'fa', 'sa'):
col = getattr(self, c)
for k in col:
v = col[k].value
e = '%s.%s' % (c, k)
if isinstance(v, np.ndarray):
entries[e] = v
else:
skipped.append(e)
if skipped:
warning("Skipping %s since not ndarrays" % (', '.join(skipped)))
return savez(filename, **entries)

@classmethod
def from_npz(cls, filename):
"""Load dataset from NumPy's .npz file, as e.g. stored by to_npz
File expected to have 'samples' item, which serves as samples, and
other items prefixed with the corresponding collection (e.g. 'sa.' or
'fa.'). All other entries are skipped
Parameters
----------
filename: str
Filename for the .npz file. Can be specified without .npz suffix
"""
# some sugaring
filename_npz = filename + '.npz'
if not lexists(filename) and not filename.endswith('.npz') and lexists(filename_npz):
filename = filename_npz

entries = np.load(filename)

cols = {'a': {}, 'fa': {}, 'sa': {}}
skipped = []
for e, v in entries.items():
if e == 'samples':
samples = v
else:
if '.' not in e:
skipped.append(e)
continue
c, k = e.split('.', 1)
if c not in cols:
skipped.append(e)
continue
cols[c][k] = v
if skipped:
warning("Skipped following items since do not belong to any of "
"known collections: %s" % (", ".join(sorted(skipped))))
return cls(samples, **cols)


# shortcut properties
nsamples = property(fget=len)
nfeatures = property(fget=lambda self: self.shape[1])
Expand Down
26 changes: 26 additions & 0 deletions mvpa2/tests/test_datasetng.py
Original file line number Diff line number Diff line change
Expand Up @@ -976,6 +976,9 @@ def test_h5py_io(dsfile):

# reload and check for identity
ds2 = Dataset.from_hdf5(dsfile)

assert_datasets_equal(ds, ds2)
# Old tests -- better more than less ;)
assert_array_equal(ds.samples, ds2.samples)
for attr in ds.sa:
assert_array_equal(ds.sa[attr].value, ds2.sa[attr].value)
Expand All @@ -999,6 +1002,29 @@ def test_h5py_io(dsfile):
pass


@nodebug(['ID_IN_REPR', 'MODULE_IN_REPR'])
@with_tempfile(suffix='.npz')
def test_npz_io(dsfile):

# store random dataset to file
ds = datasets['3dlarge'].copy()

ds.a.pop('mapper') # can't be saved
ds.to_npz(dsfile)

# reload and check for identity
ds2 = Dataset.from_npz(dsfile)
assert_datasets_equal(ds, ds2)

assert_array_equal(ds.samples, ds2.samples)

# But if we try to save with mapper -- it just gets ignored (warning is
# issued)
datasets['3dlarge'].to_npz(dsfile)
ds2_ = Dataset.from_npz(dsfile)
assert_datasets_equal(ds2, ds2_)


def test_all_equal():
# all these values are supposed to be different from each other
# but equal to themselves
Expand Down

0 comments on commit 8ee99be

Please sign in to comment.