Skip to content

Commit

Permalink
Merge pull request dipy#1195 from arokem/pickle-pam-3.6
Browse files Browse the repository at this point in the history
Make PeaksAndMetrics pickle-able
  • Loading branch information
arokem committed Mar 22, 2017
2 parents ad3cf93 + 14d94e8 commit d23f8f1
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 23 deletions.
93 changes: 71 additions & 22 deletions dipy/direction/peaks.py
Expand Up @@ -157,8 +157,68 @@ def peak_directions(odf, sphere, relative_peak_threshold=.5,
return directions, values, indices


def _pam_from_attrs(klass, sphere, peak_indices, peak_values, peak_dirs,
gfa, qa, shm_coeff, B, odf):
"""
Construct a PeaksAndMetrics object (or object of a subclass) from its
attributes.
This is also useful for pickling/unpickling of these objects (see also
:func:`__reduce__` below).
Parameters
----------
klass : class
The class of object to be created.
sphere : `Sphere` class instance.
Sphere for discretization.
peak_indices : ndarray
Indices (in sphere vertices) of the peaks in each voxel.
peak_values : ndarray
The value of the peaks.
peak_dirs : ndarray
The direction of each peak.
gfa : ndarray
The Generalized Fractional Anisotropy in each voxel.
qa : ndarray
Quantitative Anisotropy in each voxel.
shm_coeff : ndarray
The coefficients of the spherical harmonic basis for the ODF in
each voxel.
B : ndarray
The spherical harmonic matrix, for multiplication with the
coefficients.
odf : ndarray
The orientation distribution function on the sphere in each voxel.
Returns
-------
pam : Instance of the class `klass`.
"""
this_pam = klass()
this_pam.sphere = sphere
this_pam.peak_dirs = peak_dirs
this_pam.peak_values = peak_values
this_pam.peak_indices = peak_indices
this_pam.gfa = gfa
this_pam.qa = qa
this_pam.shm_coeff = shm_coeff
this_pam.B = B
this_pam.odf = odf
return this_pam


class PeaksAndMetrics(PeaksAndMetricsDirectionGetter):
pass
def __reduce__(self): return _pam_from_attrs, (self.__class__,
self.sphere,
self.peak_indices,
self.peak_values,
self.peak_dirs,
self.gfa,
self.qa,
self.shm_coeff,
self.B,
self.odf)


def _peaks_from_model_parallel(model, data, sphere, relative_peak_threshold,
Expand Down Expand Up @@ -480,27 +540,16 @@ def peaks_from_model(model, data, sphere, relative_peak_threshold,

qa_array /= global_max

pam = PeaksAndMetrics()
pam.sphere = sphere
pam.peak_dirs = peak_dirs
pam.peak_values = peak_values
pam.peak_indices = peak_indices
pam.gfa = gfa_array
pam.qa = qa_array

if return_sh:
pam.shm_coeff = shm_coeff
pam.B = B
else:
pam.shm_coeff = None
pam.B = None

if return_odf:
pam.odf = odf_array
else:
pam.odf = None

return pam
return _pam_from_attrs(PeaksAndMetrics,
sphere,
peak_indices,
peak_values,
peak_dirs,
gfa_array,
qa_array,
shm_coeff if return_sh else None,
B if return_sh else None,
odf_array if return_odf else None)


def gfa(samples):
Expand Down
26 changes: 26 additions & 0 deletions dipy/direction/tests/test_peaks.py
@@ -1,4 +1,8 @@
import numpy as np

import pickle
from io import BytesIO

from numpy.testing import (assert_array_equal, assert_array_almost_equal,
assert_almost_equal, run_module_suite,
assert_equal, assert_)
Expand Down Expand Up @@ -485,6 +489,28 @@ def test_peaksFromModel():
assert_array_equal(pam.peak_indices[mask, 0], odf_argmax)
assert_array_equal(pam.peak_indices[mask, 1:], -1)

# Test serialization and deserialization:
for normalize_peaks in [True, False]:
for return_odf in [True, False]:
for return_sh in [True, False]:
pam = peaks_from_model(model, data, _sphere, .5, 45,
normalize_peaks=normalize_peaks,
return_odf=return_odf,
return_sh=return_sh)

b = BytesIO()
pickle.dump(pam, b)
b.seek(0)
new_pam = pickle.load(b)
b.close()

for attr in ['peak_dirs', 'peak_values', 'peak_indices',
'gfa', 'qa', 'shm_coeff', 'B', 'odf']:
assert_array_equal(getattr(pam, attr),
getattr(new_pam, attr))
assert_array_equal(pam.sphere.vertices,
new_pam.sphere.vertices)


def test_peaksFromModelParallel():
SNR = 100
Expand Down
1 change: 0 additions & 1 deletion dipy/reconst/peak_direction_getter.pyx
Expand Up @@ -125,4 +125,3 @@ cdef class PeaksAndMetricsDirectionGetter(DirectionGetter):
return 0
else:
return 1

0 comments on commit d23f8f1

Please sign in to comment.