Skip to content

Commit

Permalink
added multi class test
Browse files Browse the repository at this point in the history
  • Loading branch information
alvarouc committed Dec 28, 2021
1 parent 35e2407 commit 25ae479
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 13 deletions.
10 changes: 5 additions & 5 deletions polyssifier/polyssifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,22 +44,22 @@ def poly(data, label, n_folds=10, scale=True, exclude=[],
confusions = confusion matrix for each classifier
predictions = Cross validated predictions for each classifier
'''
if verbose:
logger.setLevel(logging.DEBUG)
else:
logger.setLevel(logging.ERROR)

assert label.shape[0] == data.shape[0],\
"Label dimesions do not match data number of rows"
_le = LabelEncoder()
_le.fit(label)
label = _le.transform(label)
n_class = len(np.unique(label))
logger.info(f'Detected {n_class} classes in label')

if save and not os.path.exists('poly_{}/models'.format(project_name)):
os.makedirs('poly_{}/models'.format(project_name))

if verbose:
logger.setLevel(logging.DEBUG)
else:
logger.setLevel(logging.ERROR)

logger.info('Building classifiers ...')
classifiers = build_classifiers(exclude, scale,
feature_selection,
Expand Down
38 changes: 30 additions & 8 deletions tests/test_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,27 +11,47 @@
# Ignore DeprecationWarning for the whole test module.
warnings.filterwarnings("ignore", category=DeprecationWarning)

NSAMPLES = 100

# Keyword arguments for ``make_classification``: two-class problem with a
# single cluster per class.
BC_DATA_PARAMS = {
    'n_samples': NSAMPLES,
    'n_features': 50,
    'n_informative': 10,
    'n_redundant': 10,
    'n_repeated': 0,
    'n_classes': 2,
    'n_clusters_per_class': 1,
    'weights': None,
    'flip_y': 0.01,
    'class_sep': 2.0,
    'hypercube': True,
    'shift': 0.0,
    'scale': 1.0,
    'shuffle': True,
    'random_state': 1988,
}

# Same settings as the binary case, except that three classes are generated.
MC_DATA_PARAMS = {**BC_DATA_PARAMS, 'n_classes': 3}

# Module-level two-class dataset; it differs from BC_DATA_PARAMS only in
# using two clusters per class.
data, label = make_classification(
    **{**BC_DATA_PARAMS, 'n_clusters_per_class': 2})


@pytest.mark.medium
def test_run():
    """Run ``poly`` on a synthetic two-class dataset and check its scores.

    The data is generated from ``BC_DATA_PARAMS``; ``poly`` is called with
    two folds, feature selection disabled and saving disabled. Every entry
    of ``report.scores.mean()`` is then checked against a fixed bound.
    """
    data, label = make_classification(**BC_DATA_PARAMS)
    report = poly(data, label, n_folds=2, verbose=1,
                  feature_selection=False,
                  save=False, project_name='test2')
    # ``items()`` replaces ``iteritems()``, which was deprecated in
    # pandas 1.5 and removed in pandas 2.0.
    # NOTE(review): the condition (``score < 5``) and the failure message
    # ("too low") disagree -- confirm the intended threshold.
    for key, score in report.scores.mean().items():
        assert score < 5, '{} score is too low'.format(key)


def test_multiclass():
    """Run ``poly`` on a synthetic three-class dataset and check its scores.

    The data is generated from ``MC_DATA_PARAMS``; ``poly`` is called with
    two folds, feature selection disabled and saving disabled. Every entry
    of ``report.scores.mean()`` is then checked against a fixed bound.
    """
    data, label = make_classification(**MC_DATA_PARAMS)
    report = poly(data, label, n_folds=2, verbose=1,
                  feature_selection=False,
                  save=False, project_name='test3')
    # ``items()`` replaces ``iteritems()``, which was deprecated in
    # pandas 1.5 and removed in pandas 2.0.
    # NOTE(review): the condition (``score < 5``) and the failure message
    # ("too low") disagree -- confirm the intended threshold.
    for key, score in report.scores.mean().items():
        assert score < 5, '{} score is too low'.format(key)


@pytest.mark.medium
def test_feature_selection():
data, label = make_classification(**BC_DATA_PARAMS)
global report_with_features
report_with_features = poly(data, label, n_folds=2, verbose=1,
feature_selection=True,
Expand All @@ -44,6 +64,7 @@ def test_feature_selection():

@pytest.mark.medium
def test_plot_no_selection():
data, label = make_classification(**BC_DATA_PARAMS)
report = poly(data, label, n_folds=2, verbose=1,
feature_selection=False,
save=False, project_name='test2')
Expand All @@ -53,6 +74,7 @@ def test_plot_no_selection():

@pytest.mark.medium
def test_plot_with_selection():
data, label = make_classification(**BC_DATA_PARAMS)
report_with_features = poly(data, label, n_folds=2, verbose=1,
feature_selection=False,
save=False, project_name='test2')
Expand Down

0 comments on commit 25ae479

Please sign in to comment.