Skip to content

Commit

Permalink
TST speedup tests
Browse files: browse the repository at this point in the history
  • Loading branch information
kmike committed Mar 3, 2016
1 parent c92f4d9 commit 35e8e00
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 10 deletions.
6 changes: 4 additions & 2 deletions formasaurus/fieldtype_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@
def train(annotations,
use_precise_form_types=True,
optimize_hyperparameters_iters=0,
optimize_hyperparameters_folds=5,
optimize_hyperparameters_jobs=-1,
full_form_type_names=False,
full_field_type_names=True,
verbose=True):
Expand Down Expand Up @@ -106,9 +108,9 @@ def log(msg):
}

rs = RandomizedSearchCV(crf, params_space,
cv=get_annotation_folds(annotations, 5),
cv=get_annotation_folds(annotations, optimize_hyperparameters_folds),
verbose=verbose,
n_jobs=-1,
n_jobs=optimize_hyperparameters_jobs,
n_iter=optimize_hyperparameters_iters,
iid=False,
scoring=scorer
Expand Down
2 changes: 1 addition & 1 deletion tests/test_cmdline.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def test_check_data():


def test_evaluate():
out = subprocess.check_output('formasaurus evaluate all --cv 3', shell=True)
out = subprocess.check_output('formasaurus evaluate all --cv 2', shell=True)
m = re.search(b"(\d+.\d+)% forms are classified correctly", out)
assert m
assert float(m.group(1)) > 80
Expand Down
11 changes: 7 additions & 4 deletions tests/test_fieldtype_model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
from __future__ import absolute_import, division
import itertools

import numpy as np
from sklearn_crfsuite.metrics import flat_accuracy_score
Expand All @@ -13,16 +14,18 @@


def test_training(storage, capsys):

annotations = list(a for a in storage.iter_annotations(
annotations = (a for a in storage.iter_annotations(
simplify_form_types=True,
simplify_field_types=True,
) if a.fields_annotated)[:300]
) if a.fields_annotated)
annotations = list(itertools.islice(annotations, 0, 300))

crf = train(
annotations=annotations,
use_precise_form_types=False,
optimize_hyperparameters_iters=10,
optimize_hyperparameters_iters=2,
optimize_hyperparameters_folds=2,
optimize_hyperparameters_jobs=-1,
full_form_type_names=False,
full_field_type_names=False
)
Expand Down
6 changes: 3 additions & 3 deletions tests/test_formtype_model.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
# -*- coding: utf-8 -*-
from __future__ import absolute_import, division
import itertools

from sklearn.metrics import accuracy_score

from formasaurus.formtype_model import get_realistic_form_labels


def test_get_realistic_formtypes(storage):
annotations = list(storage.iter_annotations())
annotations = list(itertools.islice(storage.iter_annotations(), 0, 300))
y_true = [a.type_full for a in annotations]
y_pred = get_realistic_form_labels(annotations, n_folds=5)
y_pred = get_realistic_form_labels(annotations, n_folds=3)
score = accuracy_score(y_true, y_pred)
assert 0.7 < score < 0.98


0 comments on commit 35e8e00

Please sign in to comment.