Skip to content

Commit

Permalink
imporve test coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
alexandrebarachant committed May 18, 2015
1 parent dbdd199 commit cb70bcf
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 4 deletions.
2 changes: 1 addition & 1 deletion pyriemann/utils/mean.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def mean_identity(covmats):
:returns: the identity matrix of size Nchannels
"""
C = eye(covmats.shape[1])
C = numpy.eye(covmats.shape[1])
return C

def mean_covariance(covmats,metric='riemann',*args):
Expand Down
7 changes: 7 additions & 0 deletions tests/test_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,13 @@ def test_MDM_predict():
mdm.fit(covset,labels)
mdm.predict(covset)

def test_MDM_fit_predict():
"""Test Fit & predict of MDM"""
covset = generate_cov(100,3)
labels = np.array([0,1]).repeat(50)
mdm = MDM(metric='riemann')
mdm.fit_predict(covset,labels)

def test_MDM_transform():
"""Test transform of MDM"""
covset = generate_cov(100,3)
Expand Down
40 changes: 39 additions & 1 deletion tests/test_clustering.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import numpy as np
from pyriemann.clustering import Kmeans
from pyriemann.clustering import Kmeans,KmeansPerClassTransform

def generate_cov(Nt,Ne):
"""Generate a set of cavariances matrices for test purpose"""
Expand All @@ -18,6 +18,25 @@ def test_Kmeans_fit():
covset = generate_cov(20,3)
km = Kmeans(2)
km.fit(covset)

def test_Kmeans_fit_with_init():
"""Test Fit of Kmeans wit matric initialization"""
covset = generate_cov(20,3)
km = Kmeans(2,init=covset[0:2])
km.fit(covset)

def test_Kmeans_fit_with_y():
"""Test Fit of Kmeans with a given y"""
covset = generate_cov(20,3)
labels = np.array([0,1]).repeat(10)
km = Kmeans(2)
km.fit(covset,y=labels)

def test_Kmeans_fit_parallel():
"""Test Fit of Kmeans using paralell"""
covset = generate_cov(20,3)
km = Kmeans(2,n_jobs=2)
km.fit(covset)

def test_Kmeans_predict():
"""Test prediction of Kmeans"""
Expand All @@ -31,4 +50,23 @@ def test_Kmeans_transform():
covset = generate_cov(20,3)
km = Kmeans(2)
km.fit(covset)
km.transform(covset)

def test_KmeansPCT_init():
"""Test init of Kmeans PCT"""
km = KmeansPerClassTransform(2)

def test_KmeansPCT_fit():
"""Test Fit of Kmeans PCT"""
covset = generate_cov(20,3)
labels = np.array([0,1]).repeat(10)
km = KmeansPerClassTransform(2)
km.fit(covset,labels)

def test_KmeansPCT_transform():
"""Test Transform of Kmeans PCT"""
covset = generate_cov(20,3)
labels = np.array([0,1]).repeat(10)
km = KmeansPerClassTransform(2)
km.fit(covset,labels)
km.transform(covset)
7 changes: 7 additions & 0 deletions tests/test_tangentspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,13 @@ def test_TangentSpace_transform():
ts.fit(covset)
ts.transform(covset)

def test_TangentSpace_transform_with_ts_update():
"""Test transform of Tangent Space with TSupdate"""
covset = generate_cov(100,3)
ts = TangentSpace(metric='riemann',tsupdate=True)
ts.fit(covset)
ts.transform(covset)

def test_TangentSpace_inversetransform():
"""Test inverse transform of Tangent Space"""
covset = generate_cov(100,3)
Expand Down
4 changes: 2 additions & 2 deletions tests/test_utils_mean.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from numpy.testing import assert_array_almost_equal,assert_array_equal
import numpy as np

from pyriemann.utils.mean import (mean_riemann,mean_euclid,mean_logeuclid,mean_logdet,mean_ale,mean_covariance)
from pyriemann.utils.mean import (mean_riemann,mean_euclid,mean_logeuclid,mean_logdet,mean_ale,mean_identity,mean_covariance)

def generate_cov(Nt,Ne):
"""Generate a set of cavariances matrices for test purpose"""
Expand Down Expand Up @@ -36,7 +36,7 @@ def test_identity_mean():
"""Test the logdet mean"""
covmats = generate_cov(100,3)
C = mean_identity(covmats)
assert_array_equal(C,eye(3))
assert_array_equal(C,np.eye(3))

def test_logdet_mean():
"""Test the logdet mean"""
Expand Down

0 comments on commit cb70bcf

Please sign in to comment.