ENH: Basic implementation of cross-validation rsm using learner and CV. #504
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,19 +10,63 @@ | |
|
||
__docformat__ = 'restructuredtext' | ||
|
||
from itertools import combinations | ||
from itertools import combinations, product | ||
import numpy as np | ||
from mvpa2.measures.base import Measure | ||
from mvpa2.datasets.base import Dataset | ||
from mvpa2.base import externals | ||
from mvpa2.base.param import Parameter | ||
from mvpa2.base.constraints import EnsureChoice | ||
from mvpa2.mappers.fx import mean_group_sample | ||
|
||
if externals.exists('scipy', raise_=True): | ||
from scipy.spatial.distance import pdist, squareform | ||
from scipy.spatial.distance import pdist, squareform, cdist | ||
from scipy.stats import rankdata, pearsonr | ||
|
||
|
||
class CDist(Measure): | |
"""Compute dissimilarity matrix for samples in a dataset | |
|
||
This `Measure` can be trained on part of the dataset (for example, | ||
a partition) and called on another partition. It can be used in | ||
cross-validation to generate cross-validated RSA. | ||
""" | ||
pairwise_metric = Parameter('correlation', constraints='str', doc=""" | ||
Distance metric to use for calculating pairwise vector distances for | ||
dissimilarity matrix (DSM). See scipy.spatial.distance.pdist for | ||
all possible metrics.""") | ||
|
||
sattr = Parameter(['targets'], doc=""" | ||
List of sample attributes whose unique values will be used to identify the | |
sample groups. Typically your category labels or targets.""") | |
|
||
def __init__(self, **kwargs): | ||
Measure.__init__(self, **kwargs) | ||
Review comment: should we be using |
||
self._train_ds = None | ||
|
||
def _prepare_ds(self, ds): | ||
if self.params.sattr is not None: | ||
mgs = mean_group_sample(attrs=self.params.sattr) | ||
ds_ = mgs(ds) | ||
else: | ||
ds_ = ds.copy(deep=True) | ||
Review comment: Btw why bother deep copying the data??
Reply: No need. Left over from old implementation. |
||
return ds_ | ||
|
||
def _train(self, ds): | ||
self._train_ds = self._prepare_ds(ds) | ||
self.is_trained = True | ||
|
||
def _call(self, ds): | ||
test_ds = self._prepare_ds(ds) | ||
# Call actual distance metric | ||
distds = cdist(self._train_ds.samples, test_ds.samples, | |
metric=self.params.pairwise_metric) | ||
# Make target pairs | ||
distds = Dataset(samples=distds.ravel()[None, ], | ||
Review comment: Here, we are arranging samples as folds to be consistent with cross-validation as used with classifiers,
Reply: So, I decided to return the result as a single-feature dataset. This way it will work with
I think it is ready to consider for PR. @mvdoc @yarikoptic
Reply: OK for me about returning samples instead of features. I didn't notice that it was the way Andy returned the pairs in PDist already. We need to fix the test for
Reply: So, I think that failure is an issue to be addressed in
@mvdoc Do you think we should just bypass it for now in tests?
Reply: Let's see what @yarikoptic says first ;-) |
||
fa={'pairs': list(product(self._train_ds.UT, test_ds.UT))}) | |
return distds | ||
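For readers who want to see what `CDist` computes without a PyMVPA installation, here is a minimal sketch in plain NumPy/SciPy. It assumes the same steps as the diff above: average samples within each target, call `cdist` between the training and testing group means, then flatten. The helper names `group_means` and `cross_distances` are hypothetical and are not part of the PR or of PyMVPA.

```python
from itertools import product

import numpy as np
from scipy.spatial.distance import cdist


def group_means(samples, targets):
    """Average the rows of `samples` within each unique target (sorted order)."""
    labels = np.unique(targets)
    means = np.vstack([samples[targets == lab].mean(axis=0) for lab in labels])
    return labels, means


def cross_distances(train_x, train_y, test_x, test_y, metric="correlation"):
    """Flattened distances between training and testing group means."""
    train_labels, train_means = group_means(train_x, train_y)
    test_labels, test_means = group_means(test_x, test_y)
    dist = cdist(train_means, test_means, metric=metric)
    # ravel() is row-major, so the matching pair order is (train, test)
    pairs = list(product(train_labels, test_labels))
    return dist.ravel(), pairs


# Illustrative random data: 3 targets, 4 samples each, 20 features
rng = np.random.RandomState(0)
targets = np.repeat(["face", "house", "shoe"], 4)
train_x = rng.randn(12, 20)
test_x = rng.randn(12, 20)
dists, pairs = cross_distances(train_x, targets, test_x, targets)
```

Unlike `pdist`, the result keeps the same-target pairs (for example `("face", "face")`), and those entries are generally non-zero because the two means come from different partitions. Whether the flattened vector is stored as one row or one column (`ravel()[None, ]` versus `ravel()[:, None]`) only changes the layout discussed in the review thread, not the values.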
|
||
|
||
class PDist(Measure): | |
"""Compute dissimilarity matrix for samples in a dataset | |
|
||
|
Review comment: isn't a critical piece missing -- "cross-validated dissimilarity"?
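As a rough illustration of the term raised in the comment above, the following sketch computes a train/test dissimilarity matrix for each leave-one-chunk-out fold using NumPy and SciPy only. It is an assumption about what "cross-validated dissimilarity" means in this PR, not the PyMVPA implementation; the function name `cv_dissimilarity` and the simulated data are hypothetical, and averaging over folds is just one possible summary.

```python
import numpy as np
from scipy.spatial.distance import cdist


def cv_dissimilarity(samples, targets, chunks, metric="correlation"):
    """Return one train-vs-test distance matrix per leave-one-chunk-out fold."""
    labels = np.unique(targets)
    folds = []
    for held_out in np.unique(chunks):
        test_mask = chunks == held_out
        # Group means are computed separately for the two partitions
        train_means = np.vstack([
            samples[~test_mask & (targets == lab)].mean(axis=0)
            for lab in labels])
        test_means = np.vstack([
            samples[test_mask & (targets == lab)].mean(axis=0)
            for lab in labels])
        folds.append(cdist(train_means, test_means, metric=metric))
    return labels, np.array(folds)


# Simulated data: each condition has a fixed pattern plus per-chunk noise
rng = np.random.RandomState(1)
n_chunks, n_features = 4, 30
targets = np.tile(np.array(["a", "b", "c"]), n_chunks)
chunks = np.repeat(np.arange(n_chunks), 3)
patterns = rng.randn(3, n_features)
samples = np.tile(patterns, (n_chunks, 1)) + 0.5 * rng.randn(12, n_features)

labels, fold_dsms = cv_dissimilarity(samples, targets, chunks)
cv_dsm = fold_dsms.mean(axis=0)  # one possible summary across folds
```

Because the noise is independent across chunks, the diagonal of `cv_dsm` (same condition, different partitions) should be small when a condition has a reliable pattern, which is the property that makes the cross-validated version informative.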