Merge pull request #110 from nno/master
NIML hstack support + refactoring of Kohonen SOM + flat_surf plots + kitchen sink ;)
yarikoptic committed Jun 4, 2013
2 parents 90355b4 + 0219ea6 commit bc04787
Showing 24 changed files with 1,788 additions and 129 deletions.
144 changes: 144 additions & 0 deletions mvpa2/base/dataset.py
@@ -943,6 +943,150 @@ def _expand_attribute(attr, length, attr_name):
# make sequence of identical value matching the desired length
return np.repeat(attr, length)

def stack_by_unique_sample_attribute(dataset, sa_label):
    """Performs hstack based on unique values in sa_label

    Parameters
    ----------
    dataset: Dataset
        input dataset.
    sa_label: str
        sample attribute label according to which the samples in dataset
        are stacked.

    Returns
    -------
    stacked_dataset: Dataset
        A dataset where matching features are joined (hstacked).
        If the number of matching samples differs across the unique values
        in sa_label, an exception is raised.
    """

unq, masks = _get_unique_attribute_masks(dataset.sa[sa_label].value)

ds = []
for i, mask in enumerate(masks):
d = dataset[mask, :]
d.fa[sa_label] = [unq[i]] * d.nfeatures
ds.append(d)

stacked_ds = hstack(ds, True)
stacked_ds.sa.pop(sa_label)

return stacked_ds


def stack_by_unique_feature_attribute(dataset, fa_label):
    """Performs vstack based on unique values in fa_label

    Parameters
    ----------
    dataset: Dataset
        input dataset.
    fa_label: str
        feature attribute label according to which the features in dataset
        are stacked.

    Returns
    -------
    stacked_dataset: Dataset
        A dataset where matching samples are joined (vstacked). This dataset
        has a sample attribute fa_label added and the feature attribute
        fa_label removed.
        If the number of matching features differs across the unique values
        in fa_label, an exception is raised.
    """

unq, masks = _get_unique_attribute_masks(dataset.fa[fa_label].value)

ds = []
for i, mask in enumerate(masks):
d = dataset[:, mask]
d.sa[fa_label] = [unq[i]] * d.nsamples
ds.append(d)

stacked_ds = vstack(ds, True)
stacked_ds.fa.pop(fa_label)

return stacked_ds
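For illustration, here is a hedged usage sketch of both stack functions (not part of the patch; the toy dataset and the 'hemi' attribute are invented, and the printed values are what the code above suggests should come out):

import numpy as np
from mvpa2.datasets import Dataset
from mvpa2.base.dataset import (stack_by_unique_sample_attribute,
                                stack_by_unique_feature_attribute)

# Toy dataset: 6 samples x 2 features, 3 samples per hemisphere.
ds = Dataset(np.arange(12).reshape((6, 2)),
             sa=dict(hemi=['L', 'L', 'L', 'R', 'R', 'R']))

hs = stack_by_unique_sample_attribute(ds, 'hemi')
print(hs.shape)    # expected (3, 4): samples halved, features doubled
print(hs.fa.hemi)  # expected ['L' 'L' 'R' 'R']: 'hemi' moved from .sa to .fa

# The feature-based variant reverses the operation.
vs = stack_by_unique_feature_attribute(hs, 'hemi')
print(vs.shape)    # expected (6, 2): the original layout is recovered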


def _get_unique_attribute_masks(xs, raise_unequal_count=True):
    '''Helper function returning a boolean mask for each unique value in xs'''
    unq = np.unique(xs)
    masks = [x == xs for x in unq]

    if raise_unequal_count:
        hs = [np.sum(mask) for mask in masks]

        for i, h in enumerate(hs):
            if i == 0:
                h0 = h
            elif h != h0:
                raise ValueError('Count mismatch between unique values'
                                 ' 0 and %d: %s != %s' % (i, h, h0))
    return unq, masks
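A quick doctest-style sketch of the helper's contract (assuming equally frequent values, so no exception is raised):

>>> import numpy as np
>>> unq, masks = _get_unique_attribute_masks(np.array([3, 5, 3, 5]))
>>> unq
array([3, 5])
>>> masks[0]
array([ True, False,  True, False])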

def split_by_sample_attribute(ds, sa_label, raise_unequal_count=True):
    '''Splits a dataset based on unique values of a sample attribute

    Parameters
    ----------
    ds: Dataset
        input dataset
    sa_label: str or list of str
        sample attribute label(s) on which the split is based
    raise_unequal_count: bool
        if True (the default), raise an exception when the unique values
        do not all occur equally often

    Returns
    -------
    ds: list of Dataset
        List with n datasets, if ds.sa[sa_label] has n unique values
    '''
if type(sa_label) in (list, tuple):
label0 = sa_label[0]
sas = split_by_sample_attribute(ds, label0, raise_unequal_count)
if len(sa_label) == 1:
return sas
else:
return sum([split_by_sample_attribute(sa, sa_label[1:],
raise_unequal_count)
for sa in sas], [])

_, masks = _get_unique_attribute_masks(ds.sa[sa_label].value,
raise_unequal_count=raise_unequal_count)

return [ds[mask, :].copy(deep=False) for mask in masks]
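A hedged sketch of the recursive multi-label behavior (the toy attributes below are invented):

import numpy as np
from mvpa2.datasets import Dataset
from mvpa2.base.dataset import split_by_sample_attribute

# Toy dataset: 3 subjects x 2 runs, one sample per (subject, run) cell.
ds = Dataset(np.arange(12).reshape((6, 2)),
             sa=dict(subject=[1, 1, 2, 2, 3, 3],
                     run=[1, 2, 1, 2, 1, 2]))

parts = split_by_sample_attribute(ds, ['subject', 'run'])
print(len(parts))           # expected 6 datasets, ordered subject-major
print(parts[0].sa.subject)  # expected [1]; parts[0].sa.run -> [1]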


def split_by_feature_attribute(ds, fa_label, raise_unequal_count=True):
    '''Splits a dataset based on unique values of a feature attribute

    Parameters
    ----------
    ds: Dataset
        input dataset
    fa_label: str or list of str
        feature attribute label(s) on which the split is based
    raise_unequal_count: bool
        if True (the default), raise an exception when the unique values
        do not all occur equally often

    Returns
    -------
    ds: list of Dataset
        List with n datasets, if ds.fa[fa_label] has n unique values
    '''
if type(fa_label) in (list, tuple):
label0 = fa_label[0]
fas = split_by_feature_attribute(ds, label0, raise_unequal_count)
if len(fa_label) == 1:
return fas
else:
return sum([split_by_feature_attribute(fa, fa_label[1:],
raise_unequal_count)
for fa in fas], [])

_, masks = _get_unique_attribute_masks(ds.fa[fa_label].value,
raise_unequal_count=raise_unequal_count)

return [ds[:, mask].copy(deep=False) for mask in masks]



class DatasetError(Exception):
42 changes: 41 additions & 1 deletion mvpa2/base/learner.py
@@ -11,7 +11,7 @@
__docformat__ = 'restructuredtext'

import time
from mvpa2.base.node import Node
from mvpa2.base.node import Node, ChainNode
from mvpa2.base.state import ConditionalAttribute
from mvpa2.base.types import is_datasetlike
from mvpa2.base.dochelpers import _repr_attrs
@@ -247,3 +247,43 @@ def __call__(self, ds):
    force_train = property(fget=lambda x:x.__force_train,
                           doc="Whether the Learner enforces training upon "
                               "every call.")


class ChainLearner(Learner, ChainNode):
'''Combines different learners into one in a chained fashion'''
def __init__(self, learners, auto_train=False,
force_train=False, **kwargs):
        '''Initializes with learners

        Parameters
        ----------
        learners: list or tuple
            a list of Learner instances
        '''
Learner.__init__(self, auto_train=auto_train,
force_train=force_train, **kwargs)
self._learners = learners

    is_trained = property(fget=lambda x: all(y.is_trained
                                             for y in x._learners),
                          # assumption: each sub-learner exposes a settable
                          # is_trained property
                          fset=lambda x, v: [setattr(y, 'is_trained', v)
                                             for y in x._learners],
                          doc="Whether the Learner is currently trained.")

    def train(self, ds):
        # each learner is trained on the input dataset as-is; outputs of
        # earlier learners are only forwarded in __call__, not here
        for learner in self._learners:
            learner.train(ds)

    def untrain(self):
        for learner in self._learners:
            learner.untrain()

    def __call__(self, ds):
        '''Calls all learners in sequence, feeding each output to the next'''
learners = self._learners

r = ds
for learner in learners:
r = learner(r)

return r
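A hedged usage sketch (not part of the patch): mappers are Learner subclasses in PyMVPA, so the chain below assumes ZScoreMapper and SVDMapper can serve as chain elements, and that ZScoreMapper accepts chunks_attr=None to z-score the whole dataset.

import numpy as np
from mvpa2.datasets import Dataset
from mvpa2.base.learner import ChainLearner
from mvpa2.mappers.zscore import ZScoreMapper
from mvpa2.mappers.svd import SVDMapper

ds = Dataset(np.random.randn(10, 5))
chain = ChainLearner((ZScoreMapper(chunks_attr=None), SVDMapper()))

# Note: train() feeds the *untransformed* dataset to every learner;
# only __call__ forwards each learner's output to the next one.
chain.train(ds)
out = chain(ds)
print(out.shape)  # expected (10, 5)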
120 changes: 95 additions & 25 deletions mvpa2/base/node.py
@@ -19,7 +19,6 @@
from mvpa2.base.collections import SampleAttributesCollection, \
FeatureAttributesCollection, DatasetAttributesCollection


if __debug__:
from mvpa2.base import debug

@@ -249,18 +248,14 @@ def __repr__(self, prefixes=[]):
doc="Node to perform post-processing of results")


class CompoundNode(Node):
    """List of nodes.

    A CompoundNode behaves similar to a list container: nodes can be
    appended, and the chain can be sliced like a list, etc ...

    Subclasses such as ChainNode and CombinedNode implement the _call
    method in different ways.
    """
def __init__(self, nodes, **kwargs):
"""
Expand All @@ -272,8 +267,9 @@ def __init__(self, nodes, **kwargs):
        if not len(nodes):
            raise ValueError("%s needs at least one embedded node."
                             % self.__class__.__name__)
        self._nodes = nodes
        Node.__init__(self, **kwargs)


def __copy__(self):
Expand All @@ -282,18 +278,7 @@ def __copy__(self):


    def _call(self, ds):
        raise NotImplementedError("This is an abstract class.")


def generate(self, ds, startnode=0):
@@ -358,7 +343,7 @@ def __getitem__(self, key):


    def __repr__(self, prefixes=[]):
        return super(CompoundNode, self).__repr__(
            prefixes=prefixes
            + _repr_attrs(self, ['nodes']))

@@ -367,3 +352,88 @@ def __str__(self):
return _str(self, '-'.join([str(n) for n in self]))

nodes = property(fget=lambda self:self._nodes)


class ChainNode(CompoundNode):
    """
    This class allows one to concatenate a list of nodes into a processing
    chain. When called with a dataset, it is sequentially fed through the
    nodes in the chain. A ChainNode may also be used as a generator. In this
    case, all nodes in the chain are treated as generators too, and the
    ChainNode behaves as a single big generator that recursively calls all
    embedded generators and yields the results.
    """
def __init__(self, nodes, **kwargs):
"""
Parameters
----------
nodes: list
Node instances.
"""
CompoundNode.__init__(self, nodes=nodes, **kwargs)

def _call(self, ds):
mp = ds
for i, n in enumerate(self):
if __debug__:
debug('MAP', "%s: input (%s) -> node (%i/%i): '%s'",
(self.__class__.__name__,
hasattr(mp, 'shape') and mp.shape or '???',
i + 1, len(self),
n))
mp = n(mp)
if __debug__:
debug('MAP', "%s: output (%s)", (self.__class__.__name__, mp.shape))
return mp
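For example, the conventional partition-then-split pattern uses a ChainNode as a generator (a sketch, not part of the patch; the toy dataset parameters are arbitrary):

from mvpa2.base.node import ChainNode
from mvpa2.generators.partition import NFoldPartitioner
from mvpa2.generators.splitters import Splitter
from mvpa2.misc.data_generators import normal_feature_dataset

ds = normal_feature_dataset(perlabel=4, nlabels=2, nfeatures=6, nchunks=4)
chain = ChainNode([NFoldPartitioner(), Splitter('partitions')])

# each yielded dataset is one half (training or testing) of one fold
splits = list(chain.generate(ds))
print(len(splits))  # expected 8: 4 folds x 2 partitions each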


class CombinedNode(CompoundNode):
    """Node to pass a dataset on to a set of nodes and combine their output.

    Output combination or aggregation is currently done by hstacking or
    vstacking the resulting datasets.
    """

    def __init__(self, nodes, combine_axis, a=None, **kwargs):
        """
        Parameters
        ----------
        nodes : list
            Node instances whose outputs are combined.
        combine_axis : ['h', 'v']
            Whether node outputs are combined by hstacking ('h') or
            vstacking ('v').
        a : {'unique','drop_nonunique','uniques','all'} or True or False or None (default: None)
            Indicates which dataset attributes from the node output datasets
            are stored in the merged dataset. If an int k, then the dataset
            attributes from the k-th dataset are taken. If 'unique' then it
            is assumed that any attribute common to more than one dataset is
            unique; if not an exception is raised. If 'drop_nonunique' then
            as 'unique', except that exceptions are not raised. If 'uniques'
            then, for each attribute, any unique value across the datasets
            is stored in a tuple in the merged dataset. If 'all' then each
            attribute present in any dataset is stored as a tuple in the
            merged dataset; missing values are replaced by None. If None
            (the default) then no attributes are stored in the merged
            dataset. True is equivalent to 'drop_nonunique'. False is
            equivalent to None.
        """
        CompoundNode.__init__(self, nodes=nodes, **kwargs)
        self._combine_axis = combine_axis
        self._a = a

def __copy__(self):
return self.__class__([copy.copy(n) for n in self],
copy.copy(self._combine_axis),
copy.copy(self._a))


    def _call(self, ds):
        out = [node(ds) for node in self]
        # deferred import, presumably to avoid a circular dependency
        # between mvpa2.datasets and mvpa2.base at module load time
        from mvpa2.datasets import hstack, vstack
        stacker = {'h': hstack, 'v': vstack}
        stacked = stacker[self._combine_axis](out, self._a)
        return stacked

def __repr__(self, prefixes=[]):
return super(CombinedNode, self).__repr__(
prefixes=prefixes
+ _repr_attrs(self, ['combine_axis', 'a']))
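A hedged sketch of CombinedNode (not part of the patch; the FxMapper-based summary nodes are used purely for illustration):

import numpy as np
from mvpa2.base.node import CombinedNode
from mvpa2.datasets import Dataset
from mvpa2.mappers.fx import FxMapper, mean_sample

ds = Dataset(np.arange(20).reshape((4, 5)))

# Fan the dataset out to two summary nodes and vstack their outputs.
combined = CombinedNode([mean_sample(), FxMapper('samples', np.max)],
                        combine_axis='v')
out = combined(ds)
print(out.shape)  # expected (2, 5): one mean row stacked on one max row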

