Skip to content

Commit

Permalink
NF: CompoundLearner and subclasses ChainLearner and CombinedLearner
Browse files Browse the repository at this point in the history
  • Loading branch information
nno committed Jul 31, 2013
1 parent b785a13 commit afae14d
Show file tree
Hide file tree
Showing 2 changed files with 174 additions and 17 deletions.
75 changes: 58 additions & 17 deletions mvpa2/base/learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from mvpa2.base.state import ConditionalAttribute
from mvpa2.base.types import is_datasetlike
from mvpa2.base.dochelpers import _repr_attrs
from mvpa2.base.node import CompoundNode, CombinedNode, ChainNode

if __debug__:
from mvpa2.base import debug
Expand Down Expand Up @@ -249,41 +250,81 @@ def __call__(self, ds):
"called.")


class ChainLearner(Learner, ChainNode):
'''Combines different learners into one in a chained fashion'''
class CompoundLearner(Learner, CompoundNode):
    '''Base class for learners that are composed of several learners.

    Training and untraining are delegated to every contained learner.
    Subclasses have to implement ``_call`` to define how the contained
    learners are applied to a dataset.
    '''

    def __init__(self, learners, auto_train=False,
                 force_train=False, **kwargs):
        '''Initializes with learners

        Parameters
        ----------
        learners: list or tuple
          a list of Learner instances
        auto_train
          Passed on to ``Learner.__init__`` (default: False).
        force_train
          Passed on to ``Learner.__init__`` (default: False).
        **kwargs
          Passed on to both ``Learner.__init__`` and
          ``CompoundNode.__init__``.
        '''
        Learner.__init__(self, auto_train=auto_train,
                         force_train=force_train, **kwargs)
        CompoundNode.__init__(self, learners, **kwargs)

    def _all_trained(self):
        '''Return True only if every contained learner is trained.'''
        return all(learner.is_trained for learner in self)

    def _set_all_trained(self, status=True):
        '''Set the training status of every contained learner.

        A property setter is called with ``(self, value)``, so this has to
        accept the status value; the former one-argument lambda could not
        be used as a setter at all.
        '''
        # NOTE(review): assumes ``_set_trained`` of the contained learners
        # accepts the status flag as its argument -- confirm against
        # ``Learner._set_trained``.
        for learner in self:
            learner._set_trained(status)

    is_trained = property(fget=_all_trained,
                          fset=_set_all_trained,
                          doc="Whether the Learner is currently trained.")

    def train(self, ds):
        '''Train every contained learner with the same dataset ``ds``.'''
        for learner in self:
            learner.train(ds)

    def untrain(self):
        '''Untrain every contained learner.'''
        for learner in self:
            learner.untrain()

    def _call(self, ds):
        '''Has to be provided by subclasses.'''
        raise NotImplementedError


class ChainLearner(ChainNode, CompoundLearner):
    '''Chain of learners: combines several learners into a single one.

    Calling is delegated to ``ChainNode._call``.
    '''

    def __init__(self, learners, auto_train=False,
                 force_train=False, **kwargs):
        '''Initializes with learners

        Parameters
        ----------
        learners: list or tuple
          a list of Learner instances
        auto_train, force_train, **kwargs
          Handed over unchanged to ``CompoundLearner.__init__``.
        '''
        CompoundLearner.__init__(self,
                                 learners,
                                 auto_train=auto_train,
                                 force_train=force_train,
                                 **kwargs)

    def _call(self, ds):
        '''Process ``ds`` using the chain implementation of ``ChainNode``.'''
        result = ChainNode._call(self, ds)
        return result

class CombinedLearner(CompoundLearner, CombinedNode):
    """Learner that combines the outputs of several learners.

    Calling is delegated to ``CombinedNode._call``.
    """

    def __init__(self, learners, combine_axis, a=None, **kwargs):
        """
        Parameters
        ----------
        learners : list of Learner
        combine_axis : ['h', 'v']
        a : {'unique', 'drop_nonunique', 'uniques', 'all'} or int or True or False or None (default: None)
          Indicates which dataset attributes from the datasets are stored
          in the merged dataset.

          - int k: the dataset attributes from datasets[k] are taken.
          - 'unique': it is assumed that any attribute common to more than
            one dataset is unique; if not, an exception is raised.
          - 'drop_nonunique': as 'unique', except that exceptions are not
            raised.
          - 'uniques': for each attribute, any unique value across the
            datasets is stored in a tuple in the merged dataset.
          - 'all': each attribute present in any dataset is stored as a
            tuple in the merged dataset; missing values are replaced by
            None.
          - None (the default): no attributes are stored in the merged
            dataset.
          - True is equivalent to 'drop_nonunique'; False is equivalent
            to None.
        **kwargs
          Handed over unchanged to ``CompoundLearner.__init__``.
        """
        CompoundLearner.__init__(self, learners, **kwargs)
        # settings stored on the instance after the base initialisation
        self._a = a
        self._combine_axis = combine_axis

    def _call(self, ds):
        """Process ``ds`` using the implementation of ``CombinedNode``."""
        merged = CombinedNode._call(self, ds)
        return merged

r = ds
for learner in learners:
r = learner(r)

return r
116 changes: 116 additions & 0 deletions mvpa2/tests/test_compound.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*-
# vi: set ft=python sts=4 ts=4 sw=4 et:
### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ##
#
# See COPYING file distributed along with the PyMVPA package for the
# copyright and license terms.
#
### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ##
"""Unit tests for PyMVPA compound nodes and learners (chained and combined)"""

import numpy as np

from mvpa2.testing import *
from mvpa2.base.learner import Learner, CompoundLearner, \
ChainLearner, CombinedLearner
from mvpa2.base.node import Node, CompoundNode, \
ChainNode, CombinedNode

from mvpa2.datasets.base import AttrDataset

class FxNode(Node):
    """Test helper: node applying a function to the samples of a dataset."""

    def __init__(self, f, space='targets',
                 pass_attr=None, postproc=None, **kwargs):
        """
        Parameters
        ----------
        f : callable
          Function that is applied to ``ds.samples`` when the node is
          called.
        space, pass_attr, postproc, **kwargs
          Handed over unchanged to the constructor of the base class.
        """
        super(FxNode, self).__init__(space, pass_attr, postproc, **kwargs)
        self.f = f

    def _call(self, ds):
        """Return a copy of ``ds`` with ``f`` applied to its samples."""
        transformed = ds.copy()
        transformed.samples = self.f(ds.samples)
        return transformed

class FxyLearner(Learner):
    """Test helper: learner that remembers the training samples.

    ``f`` is a curried two-argument function: on calling, ``f`` is applied
    first to the stored training samples and the function it returns is
    applied to the samples of the dataset that is passed in.
    """

    def __init__(self, f):
        """
        Parameters
        ----------
        f : callable
          Called as ``f(training_samples)(samples)``.
        """
        super(FxyLearner, self).__init__()
        self.f = f
        self.x = None  # training samples; set by _train()

    def _train(self, ds):
        """Store the samples of the training dataset."""
        self.x = ds.samples

    def _call(self, ds):
        """Return a copy of ``ds`` holding ``f(trained)(ds.samples)``."""
        result = ds.copy()
        fx = self.f(self.x)
        result.samples = fx(ds.samples)
        return result


class CompoundTests(unittest.TestCase):
    """Tests for compound (chained and combined) nodes and learners."""

    def test_compound_node(self):
        """ChainNode and CombinedNode applied to simple function nodes."""
        # np.float64 instead of the np.float_ alias (removed in NumPy 2.0)
        data = np.asarray([[1, 2, 3, 4]], dtype=np.float64).T
        ds = AttrDataset(data, sa=dict(targets=[0, 0, 1, 1]))

        add = lambda x: lambda y: x + y
        mul = lambda x: lambda y: x * y

        add2 = FxNode(add(2))
        mul3 = FxNode(mul(3))

        assert_array_equal(add2(ds).samples, data + 2)

        # chained: first add 2, then multiply by 3
        add2mul3 = ChainNode([add2, mul3])
        # compare the samples, consistent with the other assertions
        assert_array_equal(add2mul3(ds).samples, (data + 2) * 3)

        # combined: both nodes see the same input, outputs are stacked
        add2_mul3v = CombinedNode([add2, mul3], 'v')
        add2_mul3h = CombinedNode([add2, mul3], 'h')
        assert_array_equal(add2_mul3v(ds).samples,
                           np.vstack((data + 2, data * 3)))
        assert_array_equal(add2_mul3h(ds).samples,
                           np.hstack((data + 2, data * 3)))

    def test_compound_learner(self):
        """ChainLearner and CombinedLearner training state and output."""
        data = np.asarray([[1, 2, 3, 4]], dtype=np.float64).T
        ds = AttrDataset(data, sa=dict(targets=[0, 0, 1, 1]))
        train = ds[ds.sa.targets == 0]
        test = ds[ds.sa.targets == 1]
        dtrain = train.samples
        dtest = test.samples

        sub = FxyLearner(lambda x: lambda y: x - y)
        assert_false(sub.is_trained)
        sub.train(train)
        assert_array_equal(sub(test).samples, dtrain - dtest)

        div = FxyLearner(lambda x: lambda y: x / y)
        div.train(train)
        assert_array_almost_equal(div(test).samples, dtrain / dtest)
        div.untrain()

        # a chain is only trained if all of its learners are trained
        subdiv = ChainLearner((sub, div))
        assert_false(subdiv.is_trained)
        subdiv.train(train)
        assert_true(subdiv.is_trained)
        subdiv.untrain()
        assert_raises(RuntimeError, subdiv, test)
        subdiv.train(train)

        assert_array_almost_equal(subdiv(test).samples,
                                  dtrain / (dtrain - dtest))

        # the combined learner shares the (already trained) learners
        sub_div = CombinedLearner((sub, div), 'v')
        assert_true(sub_div.is_trained)
        sub_div.untrain()
        assert_false(sub_div.is_trained)
        # train through the combined learner itself (not through the chain),
        # so that CombinedLearner.train is actually exercised
        sub_div.train(train)
        assert_true(sub_div.is_trained)

        assert_array_almost_equal(sub_div(test).samples,
                                  np.vstack((dtrain - dtest,
                                             dtrain / dtest)))



def suite():
    """Return a test suite with all tests of ``CompoundTests``."""
    # the test case of this module is CompoundTests; SMLRTests is not
    # defined here and would raise a NameError
    return unittest.makeSuite(CompoundTests)


if __name__ == '__main__':
    # NOTE(review): `runner` is only imported here; nothing is called on it
    # in the visible code. Presumably a call that starts the tests (e.g.
    # `runner.run()`) is intended -- confirm against the project's runner
    # module.
    import runner

0 comments on commit afae14d

Please sign in to comment.