Skip to content

Commit

Permalink
just noticed sklearn.cross_validation.cross_val_predict
Browse files Browse the repository at this point in the history
  • Loading branch information
kmike committed Dec 18, 2015
1 parent b64e7f3 commit ffe50a6
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 21 deletions.
13 changes: 3 additions & 10 deletions formasaurus/fieldtype_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,14 @@
import numpy as np
from sklearn.grid_search import RandomizedSearchCV
from sklearn.metrics import make_scorer
from sklearn.cross_validation import cross_val_predict
from sklearn_crfsuite import CRF
from sklearn_crfsuite.metrics import flat_f1_score

from formasaurus import formtype_model
from formasaurus.html import get_fields_to_annotate, get_text_around_elems
from formasaurus.text import (normalize, tokenize, ngrams, number_pattern,
token_ngrams)
from formasaurus.utils import select_by_index
from formasaurus.annotation import get_annotation_folds


Expand Down Expand Up @@ -152,15 +152,8 @@ def get_realistic_form_labels(annotations, n_folds=10, model=None,
else:
y = np.asarray([a.type for a in annotations])

y_pred = np.empty(len(annotations), dtype=object)

for idx_train, idx_test in get_annotation_folds(annotations, n_folds):
X_train = select_by_index(X, idx_train)
X_test = select_by_index(X, idx_test)
model.fit(X_train, y[idx_train])
y_pred[idx_test] = model.predict(X_test)

return y_pred
folds = get_annotation_folds(annotations, n_folds)
return cross_val_predict(model, X, y, cv=folds)


def get_form_features(form, form_type, field_elems=None):
Expand Down
11 changes: 0 additions & 11 deletions formasaurus/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,17 +60,6 @@ def inverse_mapping(dct):
return {v:k for k,v in dct.items()}


def select_by_index(arr, index):
    """
    Pick the elements of ``arr`` found at the positions listed in ``index``.

    Works like numpy integer-array indexing, but on plain lists; useful
    in cases where converting ``arr`` to a numpy array is problematic.
    A new list is returned, with items in the same order as ``index``.

    >>> select_by_index(['a', 'b', 'c', 'd'], [0, 3])
    ['a', 'd']
    """
    selected = []
    for position in index:
        selected.append(arr[position])
    return selected


def at_root(*args):
    """
    Return a path built by joining ``args`` onto the directory that
    contains this module (the formasaurus source code directory).
    """
    source_dir = os.path.dirname(__file__)
    return os.path.join(source_dir, *args)
Expand Down

0 comments on commit ffe50a6

Please sign in to comment.