Skip to content

Commit

Permalink
RF, ENH & DOC: Changes to handle residual computation, and test.
Browse files Browse the repository at this point in the history
  • Loading branch information
swaroopgj committed Feb 10, 2017
1 parent 254d44c commit 07a9e67
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 29 deletions.
76 changes: 47 additions & 29 deletions mvpa2/algorithms/hyperalignment.py
Expand Up @@ -154,7 +154,7 @@ class Hyperalignment(ClassWithCollections):
reference. If `None`, then the dataset with the maximum
number of features is used.""")

nproc = Parameter(4, constraints=EnsureInt() & EnsureRange(min=1),
nproc = Parameter(1, constraints=EnsureInt() & EnsureRange(min=1),
doc="Number of processes to use to parallelize the last step of"
"alignment. Requires `joblib` package.")

Expand Down Expand Up @@ -491,41 +491,42 @@ def _level3(self, datasets):

# key different from level-2; the common space is uniform
#temp_commonspace = commonspace

residuals = [None]*len(datasets)
if self.ca['residual_errors'].enabled:
#residuals = np.zeros((1, len(datasets)))
residuals = np.zeros(len(datasets))
self.ca.residual_errors = Dataset(samples=residuals[None, :])
# Checking for joblib, if not, set nproc to 1
if params.nproc > 1:
from mvpa2.base import externals, warning
if not externals.exists('joblib'):
warning("Setting nproc > 1 requires joblib package, which"
"does not seem to exist. Setting nproc to 1")
params.nproc = 1

# start from original input datasets again
if params.nproc == 1:
residuals = []
for i, (m, ds_new) in enumerate(zip(mappers, datasets)):
if __debug__:
debug('HPAL_', "Level 3: ds #%i" % i)
m = get_trained_mapper(ds_new, self.commonspace, m, residuals[i])
'''
# retrain mapper on final common space
ds_new.sa[m.get_space()] = self.commonspace
m.train(ds_new)
# remove common space attribute again to save on memory
del ds_new.sa[m.get_space()]
if residuals is not None:
# obtain final projection
data_mapped = m.forward(ds_new.samples)
residuals[0, i] = np.linalg.norm(data_mapped - self.commonspace)
'''
m, residual = get_trained_mapper(ds_new, self.commonspace, m,
self.ca['residual_errors'].enabled)
if self.ca['residual_errors'].enabled:
residuals.append(residual)
else:
verbose_level_parallel = 50 \
if (__debug__ and 'GCTHR' in debug.active) else 0
if __debug__:
debug('HPAL_', "Level 3: Using joblib with %d jobs" % params.nproc)
verbose_level_parallel = 20 \
if (__debug__ and 'HPAL' in debug.active) else 0
from joblib import Parallel, delayed
mappers = Parallel(n_jobs=params.nproc,
pre_dispatch=params.nproc,
res = Parallel(n_jobs=params.nproc,
pre_dispatch=params.nproc, backend='threading',
verbose=verbose_level_parallel)(
delayed(get_trained_mapper)
(ds, self.commonspace, mapper, residual)
for ds, mapper, residual in zip(datasets, mappers, residuals))
(ds, self.commonspace, mapper, self.ca['residual_errors'].enabled)
for ds, mapper in zip(datasets, mappers))
mappers = [m for m, r in res]
if self.ca['residual_errors'].enabled:
residuals = [r for m, r in res]

if self.ca['residual_errors'].enabled:
self.ca.residual_errors = Dataset(samples=np.array(residuals)[None, :])

return mappers

Expand All @@ -542,12 +543,29 @@ def _map_and_mean(self, datasets, mappers):
dss_mean = params.combiner2(data_mapped)
return dss_mean

def get_trained_mapper(ds, commonspace, mapper, compute_residual=False):
    """Train `mapper` on `ds` with `commonspace` as target.

    `commonspace` is temporarily stored in ``ds.sa`` under the name returned
    by ``mapper.get_space()``, ``mapper.train(ds)`` is called, and the
    temporary attribute is removed again.

    Parameters
    ----------
    ds
      Input dataset. Its sample attributes are modified only for the
      duration of the call.
    commonspace
      Target common space.
    mapper
      Mapper to train (trained in place). Typically ProcrusteanMapper.
    compute_residual : bool, optional
      Whether to compute the residual. Default is False.

    Returns
    -------
    tuple
      ``(mapper, residual)``. `residual` is ``np.linalg.norm`` of the
      difference between the forward-mapped samples of `ds` and
      `commonspace` if `compute_residual` is true, otherwise None.
    """
    space = mapper.get_space()
    ds.sa[space] = commonspace
    try:
        mapper.train(ds)
    finally:
        # Always drop the temporary attribute, even if training raised, so
        # the caller's dataset is never left carrying the common space.
        del ds.sa[space]
    residual = None
    if compute_residual:
        data_mapped = mapper.forward(ds.samples)
        residual = np.linalg.norm(data_mapped - commonspace)
    return mapper, residual
18 changes: 18 additions & 0 deletions mvpa2/tests/test_hyperalignment.py
Expand Up @@ -236,6 +236,24 @@ def test_hpal_svd_combo(self):
"SVD. Got correlations %s."
% sv_corrs_orig)

def test_hpal_joblib(self):
    """Check that nproc=2 (joblib) reproduces the nproc=1 results.

    Both the projections of the returned mappers and the
    ``residual_errors`` conditional attribute are compared.
    """
    skip_if_no_external('joblib')
    # seed dataset from which the transformed copies are derived
    seed_ds = datasets['uni4large']
    dss_rotated = []
    for _ in range(4):
        dss_rotated.append(
            random_affine_transformation(seed_ds, scale_fac=100,
                                         shift_fac=10))

    def _train_and_align(nproc):
        # Train on the first two datasets, then align all of them.
        hyper = Hyperalignment(nproc=nproc, enable_ca=['residual_errors'])
        hyper.train(dss_rotated[:2])
        return hyper, hyper(dss_rotated)

    ha_serial, mappers_serial = _train_and_align(1)
    ha_parallel, mappers_parallel = _train_and_align(2)

    projections_match = [
        np.array_equal(serial.proj, parallel.proj)
        for serial, parallel in zip(mappers_serial, mappers_parallel)]
    self.assertTrue(np.all(projections_match),
                    msg="Mappers differ when using nproc>1.")
    assert_array_equal(ha_serial.ca.residual_errors.samples,
                       ha_parallel.ca.residual_errors.samples)

def test_hypal_michael_caused_problem(self):
from mvpa2.misc import data_generators
from mvpa2.mappers.zscore import zscore
Expand Down

0 comments on commit 07a9e67

Please sign in to comment.