Skip to content

Commit

Permalink
TST more tests for field type classifier training
Browse files Browse the repository at this point in the history
  • Loading branch information
kmike committed Dec 17, 2015
1 parent 7174f7f commit 606f254
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 3 deletions.
8 changes: 6 additions & 2 deletions formasaurus/fieldtype_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@
""" Default scorer for grid search. We're optimizing for F1. """


_PRECISE_C1_C2 = 0.1655, 0.0236 # values found by randomized search
_REALISTIC_C1_C2 = 0.247, 0.032 # values found by randomized search


def train(annotations,
use_precise_formtypes=True,
optimize_hyperparameters_iters=0,
Expand All @@ -67,15 +71,15 @@ def log(msg):
else:
form_types = np.asarray([a.type for a in annotations])
# c1, c2 = 0.0223, 0.0033 # values found by randomized search
c1, c2 = 0.1655, 0.0236 # values found by randomized search
c1, c2 = _PRECISE_C1_C2
else:
log("Computing realistic form types")
form_types = get_realistic_form_labels(
annotations=annotations,
n_folds=10,
full_type_names=full_form_type_names
)
c1, c2 = 0.247, 0.032 # values found by randomized search
c1, c2 = _REALISTIC_C1_C2

log("Extracting features")
X, y = get_Xy(
Expand Down
47 changes: 46 additions & 1 deletion tests/test_fieldtype_model.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,16 @@
# -*- coding: utf-8 -*-
from __future__ import absolute_import, division

import numpy as np
from sklearn.metrics import accuracy_score
from sklearn_crfsuite.metrics import flat_accuracy_score

from formasaurus.fieldtype_model import (
get_realistic_form_labels
get_realistic_form_labels,
train,
_PRECISE_C1_C2,
_REALISTIC_C1_C2,
get_Xy,
)


Expand All @@ -12,3 +20,40 @@ def test_get_realistic_formtypes(storage):
y_pred = get_realistic_form_labels(annotations, n_folds=5)
score = accuracy_score(y_true, y_pred)
assert 0.7 < score < 0.98


def test_training(storage, capsys):

annotations = list(a for a in storage.iter_annotations(
simplify_form_types=True,
simplify_field_types=True,
) if a.fields_annotated)[:300]

crf = train(
annotations=annotations,
use_precise_formtypes=False,
optimize_hyperparameters_iters=10,
full_form_type_names=False,
full_field_type_names=False
)

out, err = capsys.readouterr()

assert 'Training on 300 forms' in out
assert 'realistic form types' in out
assert 'Best hyperparameters' in out

assert 0.0 < crf.c1 < 1.5
assert 0.0 < crf.c2 < 0.9
assert crf.c1, crf.c2 != _REALISTIC_C1_C2
assert crf.c1, crf.c2 != _PRECISE_C1_C2

form_types = np.asarray([a.type for a in annotations])
X, y = get_Xy(annotations, form_types, full_type_names=False)
y_pred = crf.predict(X)
score = flat_accuracy_score(y, y_pred)
assert 0.9 < score < 1.0 # overfitting FTW!

field_schema = storage.get_field_schema()
short_names = set(field_schema.types_inv.keys())
assert set(crf.classes_).issubset(short_names)

0 comments on commit 606f254

Please sign in to comment.