
Commit 3abbcb3: Merge 045a73d into 6864d76
2 parents 6864d76 + 045a73d
yarikoptic committed May 17, 2019

Showing 3 changed files with 139 additions and 4 deletions.
41 changes: 38 additions & 3 deletions mvpa2/clfs/gnb.py
@@ -25,6 +25,7 @@
from numpy import ones, zeros, sum, abs, isfinite, dot
from mvpa2.base import warning, externals
from mvpa2.clfs.base import Classifier, accepts_dataset_as_samples
from mvpa2.base.learner import DegenerateInputError
from mvpa2.base.types import asobjarray
from mvpa2.base.param import Parameter
from mvpa2.base.state import ConditionalAttribute
@@ -101,6 +102,12 @@ class GNB(Classifier):
disabled by default since it does not impact classification output.
""")

guard_overflows = Parameter(True, constraints='bool',
doc="""Computation of marginals can under- or overflow, causing NaNs
and infs to emerge. When enabled, GNB verifies that intermediate
values stay finite and mitigates the issue while computing
(ATM only in logprob=True mode).""")

def __init__(self, **kwargs):
"""Initialize an GNB classifier.
"""
@@ -201,6 +208,12 @@ def _train(self, dataset):
else:
variances[non0labels] /= nsamples_per_class[non0labels]

if len(np.unique(means)) <= 1 and len(np.unique(variances)) <= 1:
raise DegenerateInputError(
"All means and variances are identical, cannot train GNB to "
"produce meaningful results"
)

# Precompute and store weighting coefficient for Gaussian
if params.logprob:
# it would be added to exponent
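The DegenerateInputError check added to _train fires when every class ends up with one shared mean and one shared variance, i.e. when the samples carry no class information at all. A small sketch of the kind of input that trips the condition (hypothetical data, numpy only):

    import numpy as np

    # All-constant samples: both classes get identical means and zero variance,
    # so the per-class Gaussians cannot separate anything.
    samples = np.ones((8, 5))                      # 8 samples x 5 features
    targets = np.array([0, 0, 0, 0, 1, 1, 1, 1])
    means = np.array([samples[targets == t].mean(axis=0) for t in (0, 1)])
    variances = np.array([samples[targets == t].var(axis=0) for t in (0, 1)])
    # This is the condition the new code raises DegenerateInputError on:
    assert len(np.unique(means)) <= 1 and len(np.unique(variances)) <= 1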
@@ -230,6 +243,7 @@ def _predict(self, data):
"""Predict the output for the provided data.
"""
params = self.params
guard_overflows = params.guard_overflows
# argument of exponentiation
scaled_distances = \
-0.5 * (((data - self.means[:, np.newaxis, ...])**2) \
@@ -256,7 +270,8 @@

# Incorporate class probabilities:
prob_cs_cp = lprob_cs + np.log(self.priors[:, np.newaxis])

if guard_overflows:
assert np.all(np.isfinite(lprob_cs))
else:
# Just a regular Normal distribution with per
# feature/class mean and variances
@@ -268,24 +283,44 @@
## First we need to reshape to get class x samples x features
prob_csf = prob_csfs.reshape(
prob_csfs.shape[:2] + (-1,))

## Now -- product across features
prob_cs = prob_csf.prod(axis=2)
if guard_overflows:
assert np.all(np.isfinite(prob_cs))  # if this fires, use logprob=True mode instead
assert np.any(prob_cs)

# Incorporate class probabilities:
prob_cs_cp = prob_cs * self.priors[:, np.newaxis]

assert np.all(np.isfinite(prob_cs_cp)) # before normalize

# Normalize by evidence P(data)
if params.normalize:
if params.logprob:
prob_cs_cp_real = np.exp(prob_cs_cp)
# to avoid over/underflows, offset all the values
# (identical to multiplying by a number) and later
# remove it (divide by it). Do it for each sample separately
underflow_offset = -np.ceil(np.max(prob_cs_cp, axis=0)) \
if guard_overflows else 0
prob_cs_cp_real = np.exp(prob_cs_cp + underflow_offset)
else:
prob_cs_cp_real = prob_cs_cp

prob_s_cp_marginals = np.sum(prob_cs_cp_real, axis=0)

if guard_overflows:
assert np.all(np.isfinite(prob_cs_cp_real))
assert np.all(np.isfinite(prob_s_cp_marginals)) # no overflows
assert np.any(prob_s_cp_marginals)  # all-zero marginals would yield inf down the road

if params.logprob:
prob_cs_cp -= np.log(prob_s_cp_marginals)
prob_cs_cp -= np.log(prob_s_cp_marginals) - underflow_offset
else:
prob_cs_cp /= prob_s_cp_marginals

assert np.all(np.isfinite(prob_cs_cp))

# Take the class with maximal (log)probability
winners = prob_cs_cp.argmax(axis=0)
predictions = [self.ulabels[c] for c in winners]
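The underflow_offset logic in _predict is the standard log-sum-exp stabilization: shift log-probabilities by (roughly) their per-sample maximum before exponentiating, then fold the shift back in after taking the log of the marginal. A self-contained sketch of the same identity (mirroring, not reusing, the code above):

    import numpy as np

    def normalize_logprobs(logp):
        # Equivalent to logp - log(sum(exp(logp), axis=0)), but shifts by the
        # per-sample maximum first so np.exp never sees extreme magnitudes.
        offset = -np.ceil(np.max(logp, axis=0))   # per-sample shift, as in the diff
        marginals = np.exp(logp + offset).sum(axis=0)
        return logp - (np.log(marginals) - offset)

    logp = np.array([[-100000.0, -3.0],           # class x sample log-probabilities
                     [-100001.0, -2.0]])
    print(np.exp(normalize_logprobs(logp)))       # each column sums to 1; the naive
                                                  # route gives exp(-1e5)/0.0 == nan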
2 changes: 1 addition & 1 deletion mvpa2/tests/test_clf.py
@@ -313,7 +313,7 @@ def test_degenerate_usage(self, clf):
try:
try:
clf.train(ds) # should not crash or stall
except (ValueError), e:
except (ValueError, AssertionError) as e:
self.fail("Failed to train on degenerate data. Error was %r" % e)
except DegenerateInputError:
# so it realized that data is degenerate and puked
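Two things change in this one-line test fix: AssertionError is caught alongside ValueError, so the assertions introduced by guard_overflows surface as a clean test failure ("Failed to train on degenerate data") rather than an uncaught error, and the Python-2-only `except (ValueError), e:` spelling becomes the `as` binding accepted by both Python 2.6+ and Python 3. A trivial standalone illustration of the new form:

    try:
        raise AssertionError("non-finite probabilities")
    except (ValueError, AssertionError) as e:   # 'except X, e' is a SyntaxError on py3
        print("Failed to train on degenerate data. Error was %r" % e)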
100 changes: 100 additions & 0 deletions mvpa2/tests/test_gnb.py
@@ -29,6 +29,9 @@ def test_gnb(self):

ds = datasets['uni2medium']

# Store probabilities for further comparison
probabilities = {}

# Generic silly coverage just to assure that it works in all
# possible scenarios:
bools = (True, False)
@@ -57,6 +60,7 @@
v = np.exp(v)
d1 = np.sum(v, axis=1) - 1.0
self.assertTrue(np.max(np.abs(d1)) < 1e-5)
probabilities[repr(gnb_)] = v
# smoke test to see whether invocation of sensitivity analyser blows
# if gnb classifier isn't linear, and to see whether it doesn't blow
# when it is linear.
@@ -67,7 +71,18 @@
with self.assertRaises(NotImplementedError):
gnb_.get_sensitivity_analyzer()

# Verify that probabilities are identical when we use logprob or not
assert_array_almost_equal(
probabilities["GNB(space='targets', normalize=True, logprob=False)"],
probabilities["GNB(space='targets', normalize=True)"]
)
assert_array_almost_equal(
probabilities["GNB(space='targets', normalize=True, logprob=False, prior='uniform')"],
probabilities["GNB(space='targets', normalize=True, prior='uniform')"]
)
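These assertions pin down an invariant rather than a particular value: after normalization, the logprob=False path and the exp of the logprob=True path must produce the same posterior matrix. A compact standalone check of that invariant (hypothetical numbers, not the uni2medium dataset used above):

    import numpy as np

    p = np.array([[0.2, 0.5],                                # class x sample,
                  [0.6, 0.1]])                               # unnormalized
    via_prob = p / p.sum(axis=0)                             # logprob=False route
    via_logprob = np.exp(np.log(p) - np.log(p.sum(axis=0)))  # logprob=True route
    np.testing.assert_array_almost_equal(via_prob, via_logprob)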


@reseed_rng()
def test_gnb_sensitivities():
gnb = GNB(common_variance=True)
ds = normal_feature_dataset(perlabel=4,
@@ -107,6 +122,91 @@
assert t1t2sens[i2] > t1t2sens[4]


@reseed_rng()
def test_gnb_overflow():
# https://github.com/PyMVPA/PyMVPA/issues/581
gnb = GNB(enable_ca='estimates',
#logprob=True, # implemented only for True ATM
normalize=True,
# uncomment if interested to trigger:
# guard_overflows=False,
)

# Having lots of features could trigger under/overflows
ds = normal_feature_dataset(perlabel=4,
nlabels=2,
nfeatures=100000,
nchunks=2,
snr=5,
nonbogus_features=[0, 1]
)

ds_train = ds[ds.chunks == ds.UC[0]]
ds_test = ds[ds.chunks == ds.UC[1]]

gnb.train(ds_train)
res = gnb.predict(ds_test)
res_est = gnb.ca.estimates

probs = np.exp(res_est) if gnb.params.logprob else res_est

assert np.all(np.isfinite(res_est))
assert np.all(np.isfinite(probs))
assert_equal(sorted(np.unique(probs)), [0, 1])  # quantized into 0, 1 given this many features
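The exact-0/1 expectation follows from the scale of the problem: per-feature log-likelihood differences accumulate over 100000 features into a gap of hundreds or thousands of nats, and exp of a gap that large underflows, so the losing class's normalized posterior is exactly 0.0. A sketch of that saturation (with an assumed per-feature gap):

    import numpy as np

    log_gap = 0.01 * 100000                # assumed tiny per-feature edge, huge total
    logp = np.array([0.0, -log_gap])       # class log-scores for one test sample
    posterior = np.exp(logp - np.log(np.exp(logp).sum()))
    print(posterior)                       # [1. 0.] -- exp(-1000) underflows to 0.0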


def _test_gnb_overflow_haxby(): # pragma: no cover
# example from https://github.com/PyMVPA/PyMVPA/issues/581
# a heavier version of the above test
import os
import numpy as np

from mvpa2.datasets.sources.native import load_tutorial_data
from mvpa2.clfs.gnb import GNB
from mvpa2.measures.base import CrossValidation
from mvpa2.generators.partition import HalfPartitioner
from mvpa2.mappers.zscore import zscore
from mvpa2.mappers.detrend import poly_detrend
from mvpa2.datasets.miscfx import remove_invariant_features
from mvpa2.testing.datasets import *

datapath = '/usr/share/data/pymvpa2-tutorial/'
haxby = load_tutorial_data(datapath,
roi='vt',
add_fa={'vt_thr_glm': os.path.join(datapath,
'haxby2001',
'sub001',
'masks',
'orig',
'vt.nii.gz')})
# poly_detrend(haxby, polyord=1, chunks_attr='chunks')
haxby = haxby[np.array([l in ['rest', 'scrambled'] # ''house', 'face']
for l in haxby.targets], dtype='bool')]
#zscore(haxby, chunks_attr='chunks', param_est=('targets', ['rest']),
# dtype='float32')
# haxby = haxby[haxby.sa.targets != 'rest']
haxby = remove_invariant_features(haxby)

clf = GNB(enable_ca='estimates',
logprob=True,
normalize=True)

#clf.train(haxby)
#clf.predict(haxby)
# estimates a bit "overfit" to judge in the train/predict on the same data

cv = CrossValidation(clf,
HalfPartitioner(attr='chunks'),
postproc=None,
enable_ca=['stats'])

cv_results = cv(haxby)
res1_est = clf.ca.estimates
print "Estimates:\n", res1_est
print "Exp(estimates):\n", np.round(np.exp(res1_est), 3)
assert np.all(np.isfinite(res1_est))


def suite(): # pragma: no cover
return unittest.makeSuite(GNBTests)

