In [1]:
import re
from itertools import chain
from collections import Counter

import numpy as np
import pandas as pd
from scipy import stats as st
import pycrfsuite
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.metrics import classification_report, confusion_matrix, precision_recall_fscore_support
from sklearn.grid_search import RandomizedSearchCV, GridSearchCV
from sklearn.cross_validation import cross_val_score
from sklearn_crfsuite import CRF

from formasaurus.storage import Storage
from formasaurus.html import get_fields_to_annotate, html_tostring, get_fields_to_annotate, get_text_around_elems
from formasaurus.text import ngrams, normalize, token_ngrams, tokenize, normalize_whitespaces, number_pattern 
from formasaurus import formtype_model, evaluation

In [2]:
# Annotation storage rooted at a path relative to this notebook.
storage = Storage('../formasaurus/data/')
# Index object that storage.iter_annotations() is called with further down.
index = storage.get_index()
# Schemas: the notebook reads field_schema.types_inv / field_schema.na_value
# and form_schema.types_inv from these objects.
field_schema = storage.get_field_schema()
form_schema = storage.get_form_schema()

Many field types have only a few examples in the annotation data, so here we consider only some of the field types and use more coarse-grained classes.

In [3]:
# Field types that keep their own class label. simplify_type() checks
# membership in this set and maps every type that is not listed here to a
# single catch-all label.
SUPPORTED = {
    # account / credential fields
    'username',
    'email',
    'email confirmation',
    'password',
    'password confirmation',
    'OpenID',
    'username or email',

    # checkboxes / confirmations
    'remember me checkbox',
    'receive emails confirmation',
    'TOS confirmation',
    
    # buttons
    'submit button',
    'cancel button',
    'reset/clear button',    
    
    # search
    'search query',
    'search category / refinement',
    
    # personal details
    'first name',
    'last name',
    'full name',
    'gender',
    'website/url',
    'organization name',

    # dates
    'year',
    'month',
    'day',
    'full date',

    # anti-bot fields
    'captcha',
    'honeypot',
    
    # comments
    'comment title/subject',
    'comment text',
    
    # location / time
    'zip/postal code',
    'city',
    'country',
    'state',
    'time zone',
    'DST',
    
    # phone: simplify_type() rewrites any type containing 'phone' to plain
    # 'phone' before the membership test, so the 'phone part*' entries below
    # are never looked up through that function.
    'phone',
    'phone part',  # not in training data
    'phone part 1',
    'phone part 2',
    'phone part 3',
    
    # miscellaneous
    'other read-only',
    'NOT ANNOTATED',
}

def simplify_type(tp, other_tp='OO'):
    """
    Collapse a field type into a coarse-grained class label.

    ``tp`` is first translated through ``field_schema.types_inv`` (values
    missing from that mapping are used as-is; presumably the mapping goes
    from short codes to full type names - verify against the schema).
    Substring rules then merge related types: anything containing
    ``'phone'`` becomes ``'phone'``, anything containing both ``'full'``
    and ``'date'`` becomes ``'full date'``, and anything containing
    ``'year'``, ``'month'`` or ``'day'`` becomes that word. The first
    matching rule wins.

    Returns the resulting name upper-cased with spaces replaced by
    underscores if it is in ``SUPPORTED``, otherwise ``other_tp``.
    """
    full_name = field_schema.types_inv.get(tp, tp)

    if 'phone' in full_name:
        coarse = 'phone'
    elif 'full' in full_name and 'date' in full_name:
        coarse = 'full date'
    elif 'year' in full_name:
        coarse = 'year'
    elif 'month' in full_name:
        coarse = 'month'
    elif 'day' in full_name:
        coarse = 'day'
    else:
        coarse = full_name

    if coarse not in SUPPORTED:
        return other_tp
    return coarse.replace(' ', '_').upper()

In [4]:
def get_fields_annotation(ann):
    """Return the entry of ``ann.info['visible_html_fields']`` selected by ``ann.index``."""
    per_form_fields = ann.info['visible_html_fields']
    return per_form_fields[ann.index]

def fields_annotation_complete(ann):
    """
    Return True if the form has field annotations and none of them
    equals ``field_schema.na_value``; empty/missing annotations give False.
    """
    field_ann = get_fields_annotation(ann)
    if field_ann:
        na = field_schema.na_value
        return all(value != na for value in field_ann.values())
    return False

def fields_annotation_partial(ann):
    """
    Return True if some, but not all, field annotations of the form equal
    ``field_schema.na_value``; empty/missing annotations give False.
    """
    field_ann = get_fields_annotation(ann)
    if not field_ann:
        return False
    na = field_schema.na_value
    is_na = [value == na for value in field_ann.values()]
    return any(is_na) and not all(is_na)

# All annotations from the storage, and the subset in which every field is annotated.
annotations_all = list(storage.iter_annotations(index))
annotations_complete = list(filter(fields_annotation_complete, annotations_all))
# annotations_formonly = [a for a in annotations_all if not fields_annotation_complete(a)]
len(annotations_complete), len(annotations_all)  #, len(annotations_formonly)

(964, 1016)

The model is two-stage:

1. First, we train Formasaurus form type detector.
2. Second, we use form type detector results to improve quality of field type detection.

We have form types available directly in the training data, but in reality the form type detector will make mistakes. It is better for the field type detector to account for this and not rely on form types blindly. So it should be trained on input where form type detection quality is roughly the same as it will be in real life.

To do that we split training data into 3 parts:

1. data for form type detector (25% + ~~partial data without field annotations~~);
2. data for field type detector (50%);
3. testing data (25%).

Then we train the form type detector on its data, generate form type labels for the rest of the data, and train/check the field type detector on that.

In [5]:
# Sequential 25% / 50% / remainder split of the fully annotated forms,
# taken in the order they appear in annotations_complete (no shuffling).
FORM_TRAIN_SIZE = len(annotations_complete) // 4
FIELD_TRAIN_SIZE = len(annotations_complete) // 2
TEST_SIZE = len(annotations_complete) - FORM_TRAIN_SIZE - FIELD_TRAIN_SIZE

annotations_form, annotations_field, annotations_test = (
    annotations_complete[:FORM_TRAIN_SIZE],  # + annotations_formonly
    annotations_complete[FORM_TRAIN_SIZE:FORM_TRAIN_SIZE + FIELD_TRAIN_SIZE],
    annotations_complete[FORM_TRAIN_SIZE + FIELD_TRAIN_SIZE:],
)

len(annotations_form), len(annotations_field), len(annotations_test)

(241, 482, 241)

Train the form type detector and check that its quality is not too much worse than the quality of a detector trained on all data. First, prepare the train/evaluation data and the model:

In [6]:
# Form type detector training data: forms and their annotated types from the
# form-detector split.
X_ft_train = [a.form for a in annotations_form]
y_ft_train = [a.type for a in annotations_form]

# Everything else (field-detector split + test split) is held out for
# evaluating the form type detector.
X_ft_test = [a.form for a in annotations_field + annotations_test]
y_ft_test = [a.type for a in annotations_field + annotations_test]

# Model object from formasaurus; nothing is fitted in this cell.
ft_model = formtype_model.get_model()

First, check estimated quality of a full model:

In [7]:
# Cross-validate on the union of both splits, i.e. on all fully annotated forms.
evaluation.print_cv_scores(ft_model, X_ft_train + X_ft_test, y_ft_train + y_ft_test, cv=5)

  'precision', 'predicted', average, warn_for)
  'precision', 'predicted', average, warn_for)


5-fold cross-validation F1: 0.864 (±0.017)  min=0.855  max=0.875


Then train model on a subset of data and check that quality is not **too** much worse:

In [8]:
# Fit on the form-detector split only, then score on the held-out forms
# (field-detector split + test split).
ft_model.fit(X_ft_train, y_ft_train)
ft_model.score(X_ft_test, y_ft_test)

0.80359612724757956

Ok, now build field type extraction model.

In [10]:
def _get_field_types(ann):
    """
    Return the simplified field type labels of a form as a list, one label
    per element returned by ``get_fields_to_annotate(ann.form)``, in that order.
    """
    # field name -> simplified label
    type_by_name = {}
    for name, raw_type in get_fields_annotation(ann).items():
        type_by_name[name] = simplify_type(raw_type)

    sequence = []
    for field in get_fields_to_annotate(ann.form):
        sequence.append(type_by_name[field.name])
    return sequence


def _elem_features(elem):
    """
    Build the feature dict for a single form field element.

    Always present: features derived from the element name, its ``value``,
    ``placeholder`` and ``class`` attributes, and the tag name. Label
    features are added when ``elem.label`` is not None, ``'input-type'``
    for ``<input>`` elements (defaulting to ``'text'``), and option
    text/value/number-pattern features for ``<select>`` elements.
    """
    name_norm = normalize(elem.name)
    tag = elem.tag

    features = {}
    features['name'] = tokenize(name_norm)
    features['name-ngrams-3-5'] = ngrams(name_norm, 3, 5)
    features['value'] = ngrams(normalize(elem.get('value', '')), 5, 5)
    features['placeholder'] = tokenize(normalize(elem.get('placeholder', '')))
    features['tag'] = tag
    features['css_class'] = ngrams(normalize(elem.get('class', '')), 5, 5)

    label_elem = elem.label
    if label_elem is not None:
        label_norm = normalize(label_elem.text_content())
        features['label'] = tokenize(label_norm)
        features['label-ngrams-3-5'] = ngrams(label_norm, 3, 5)

    if tag == 'input':
        features['input-type'] = elem.get('type', 'text').lower()
    elif tag == 'select':
        option_texts = [normalize(text) for text in elem.xpath('option//text()')]
        option_values = [normalize(opt.get('value', '')) for opt in elem.xpath('option')]
        features['option-text'] = option_texts
        features['option-value'] = option_values
        # distinct number patterns seen in option texts and values
        features['option-num-pattern'] = list(
            set(number_pattern(item) for item in option_texts + option_values)
        )

    return features


def get_form_features(form, form_type):
    """
    Return a list of feature dicts, one per field of ``form`` that
    ``get_fields_to_annotate`` returns.

    On top of the per-element features, every dict gets ``'form-type'``
    (``form_type`` translated through ``form_schema.types_inv`` when
    present there), ``'text-before'`` (``token_ngrams(..., 1, 2)`` over the
    last 6 tokens of the text before the element) and ``'text-after'``
    (the same over the first 5 tokens of the text after it).
    """
    field_elems = get_fields_to_annotate(form)
    text_before, text_after = get_text_around_elems(form, field_elems)
    per_field = list(map(_elem_features, field_elems))

    for elem, feat in zip(field_elems, per_field):
        feat['form-type'] = form_schema.types_inv.get(form_type, form_type)

        # context preceding the element: keep only the closest 6 tokens
        before_tokens = tokenize(normalize(text_before[elem]))[-6:]
        feat['text-before'] = token_ngrams(before_tokens, 1, 2)

        # context following the element: keep only the closest 5 tokens
        after_tokens = tokenize(normalize(text_after[elem]))[:5]
        feat['text-after'] = token_ngrams(after_tokens, 1, 2)

    return per_field


def get_Xy(ft_model, annotations):
    """
    Build field type detection data for ``annotations``.

    Returns ``(X, y)``: ``X`` holds one list of feature dicts per form,
    computed with the form type *predicted* by ``ft_model`` (not the
    annotated one); ``y`` holds the matching lists of field type labels.
    """
    forms = [ann.form for ann in annotations]

    X = []
    for form, predicted_type in zip(forms, ft_model.predict(forms)):
        X.append(get_form_features(form, predicted_type))

    y = list(map(_get_field_types, annotations))
    return X, y

#get_Xy(ft_model, annotations_field[:1])
# _ann = annotations_all[0]
# get_form_features(_ann.form, _ann.type)

In [11]:
# Features/labels for the field type detector. get_Xy() fills the form-type
# feature from ft_model predictions, and ft_model was not fitted on either of
# these two splits.
X_train, y_train = get_Xy(ft_model, annotations_field)
X_test, y_test = get_Xy(ft_model, annotations_test)
# Combined data (train + test).
X, y = X_train + X_test, y_train + y_test
len(X_train), len(X_test)

(482, 241)

In [114]:
#[(idx, yseq) for idx, yseq in enumerate(y_train) if 'GENDER' in yseq]

In [13]:
#X_train[10]

Find regularization parameters using randomized search:

In [14]:
%%time
params_space = {
    'c1': st.expon(scale=1.0),
    'c2': st.expon(scale=0.5),
#     'max_iterations': range(80, 200)
#     'min_freq': [None, 2],
}

crf = CRF(all_possible_transitions=True, max_iterations=100)
rs = RandomizedSearchCV(crf, params_space, cv=5, verbose=1, n_jobs=-1, n_iter=50, iid=False)
rs.fit(X_train, y_train)

crf = rs.best_estimator_
print('params:', rs.best_params_)
print('score:', rs.best_score_)
print("model size: ", int(crf.tagger.info().header['size']) // 1000, "K")

Fitting 5 folds for each of 50 candidates, totalling 250 fits


[Parallel(n_jobs=-1)]: Done  34 tasks      | elapsed:   19.5s
[Parallel(n_jobs=-1)]: Done 184 tasks      | elapsed:  1.7min
[Parallel(n_jobs=-1)]: Done 250 out of 250 | elapsed:  2.2min finished


params: {'c2': 0.023321796892154266, 'c1': 0.567016489417367}
score: 0.783089706131
model size:  299 K
CPU times: user 23.5 s, sys: 1.12 s, total: 24.7 s
Wall time: 2min 17s


In [15]:
# Score of the selected model on the held-out test split.
crf.score(X_test, y_test)

0.83125000000000004

In [16]:
# %%time
# crf = CRF(c1=0.358, c2=0.161, verbose=False, max_iterations=200)
# crf.fit(X_train, y_train)
# crf.score(X_test, y_test)

In [27]:
# crf.fit(X, y)

In [17]:
# 3-fold cross-validation of the selected model on the field-detector split;
# the printed interval is the mean ± two standard deviations of the fold scores.
scores = cross_val_score(crf, X_train, y_train, cv=3, n_jobs=4)
print("Accuracy: {:0.3f} ± {:0.3f}".format(scores.mean(), 2*scores.std()))
print(scores)

Accuracy: 0.754 ± 0.084
[ 0.71449925  0.81258941  0.73555841]


In [18]:
def flatten(y):
    """Concatenate an iterable of sequences into one flat list."""
    return [item for seq in y for item in seq]

# Labels as reported by the trained tagger; they fix the row order of the report.
labels = crf.tagger.labels()
# Per-field (not per-form) predictions and ground truth for the test split.
y_pred_flat = flatten(crf.predict(X_test))
y_true_flat = flatten(y_test)
print(classification_report(y_true_flat, y_pred_flat, digits=3, labels=labels, target_names=labels))

                              precision    recall  f1-score   support

                SEARCH_QUERY      0.892     0.987     0.937        75
                          OO      0.549     0.821     0.658        95
                   FULL_NAME      0.792     0.826     0.809        23
                       EMAIL      0.919     0.887     0.903       115
       COMMENT_TITLE/SUBJECT      0.885     0.793     0.836        29
                COMMENT_TEXT      0.913     0.913     0.913        23
               SUBMIT_BUTTON      0.918     1.000     0.957        56
                    PASSWORD      0.987     0.987     0.987        75
        REMEMBER_ME_CHECKBOX      0.929     0.963     0.945        27
       PASSWORD_CONFIRMATION      0.913     0.955     0.933        22
                     COUNTRY      1.000     0.375     0.545         8
                       PHONE      0.611     0.786     0.688        14
                         DAY      1.000     0.800     0.889         5
                   

  'precision', 'predicted', average, warn_for)
  'recall', 'true', average, warn_for)


    precision    recall  f1-score   support

                    SEARCH_QUERY      0.882     1.000     0.938        45
                   SUBMIT_BUTTON      0.889     1.000     0.941        32
                           EMAIL      0.965     0.786     0.866        70
                        PASSWORD      1.000     1.000     1.000        49
            REMEMBER_ME_CHECKBOX      0.938     1.000     0.968        15
           PASSWORD_CONFIRMATION      1.000     1.000     1.000        15
                       FULL_NAME      0.857     0.600     0.706        10
                              OO      0.467     0.862     0.606        58
                         COUNTRY      1.000     0.375     0.545         8
                           PHONE      0.600     0.429     0.500         7
                             DAY      1.000     0.800     0.889         5
                           MONTH      1.000     0.800     0.889         5
                            YEAR      1.000     0.667     0.800         6
                         CAPTCHA      0.833     0.833     0.833         6
                TOS_CONFIRMATION      1.000     0.800     0.889         5
               USERNAME_OR_EMAIL      0.500     0.333     0.400         3
                        USERNAME      0.742     0.920     0.821        25
                      FIRST_NAME      1.000     0.778     0.875         9
                       LAST_NAME      1.000     0.778     0.875         9
                     WEBSITE/URL      1.000     0.333     0.500         3
           COMMENT_TITLE/SUBJECT      1.000     0.870     0.930        23
                    COMMENT_TEXT      1.000     0.500     0.667         8
                          GENDER      1.000     0.875     0.933        16
               ORGANIZATION_NAME      0.000     0.000     0.000         2
                            CITY      0.800     0.800     0.800         5
                 ZIP/POSTAL_CODE      1.000     0.714     0.833         7
     RECEIVE_EMAILS_CONFIRMATION      1.000     0.500     0.667         6
    SEARCH_CATEGORY_/_REFINEMENT      0.929     0.722     0.813        18
              RESET/CLEAR_BUTTON      0.000     0.000     0.000         0
                 OTHER_READ-ONLY      1.000     0.333     0.500         3
              EMAIL_CONFIRMATION      0.333     1.000     0.500         2
                       TIME_ZONE      0.000     0.000     0.000         0
                        ANTI-BOT      0.000     0.000     0.000         4
                           STATE      1.000     0.333     0.500         6
                       FULL_DATE      0.000     0.000     0.000         2

                     avg / total      0.860     0.823     0.820       487

In [19]:
# Raise pandas display limits so the label-by-label matrix is not truncated.
pd.options.display.max_rows = 50
pd.options.display.max_columns = 50
# Row/column captions: labels cut to 10 characters to keep the table narrow.
labels_short = [name[:10] for name in labels]

# `labels` is passed by keyword: newer scikit-learn versions accept it only
# as a keyword argument, and older ones accept both forms.
pd.DataFrame(
    confusion_matrix(y_true_flat, y_pred_flat, labels=labels),
    index=labels_short,
    columns=labels_short,
)

Unnamed: 0,SEARCH_QUE,OO,FULL_NAME,EMAIL,COMMENT_TI,COMMENT_TE,SUBMIT_BUT,PASSWORD,REMEMBER_M,PASSWORD_C,COUNTRY,PHONE,DAY,MONTH,YEAR,CAPTCHA,TOS_CONFIR,GENDER,LAST_NAME,FIRST_NAME,ORGANIZATI,ZIP/POSTAL,CITY,EMAIL_CONF,RECEIVE_EM,USERNAME_O,USERNAME,WEBSITE/UR,STATE,SEARCH_CAT,FULL_DATE,CANCEL_BUT,RESET/CLEA,OTHER_READ,HONEYPOT,TIME_ZONE,DST
SEARCH_QUE,74,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
OO,0,78,1,0,3,2,2,0,2,0,0,2,0,0,0,1,0,0,0,1,0,1,0,0,0,0,0,0,0,2,0,0,0,0,0,0,0
FULL_NAME,0,2,19,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0
EMAIL,4,3,0,102,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,5,0,0,0,0,0,0,0,0,0,0
COMMENT_TI,0,6,0,0,23,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
COMMENT_TE,0,2,0,0,0,21,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
SUBMIT_BUT,0,0,0,0,0,0,56,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
PASSWORD,0,0,0,0,0,0,0,74,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
REMEMBER_M,0,0,0,1,0,0,0,0,26,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
PASSWORD_C,0,0,0,0,0,0,0,1,0,21,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0


In [20]:
# Info object of the trained tagger (Counter is already imported in the
# first cell, so it is not re-imported here).
info = crf.tagger.info()

def print_transitions(trans_features):
    """
    Print one line per transition.

    ``trans_features`` is an iterable of ``((label_from, label_to), weight)``
    items; labels are left-aligned in 30-character columns and the weight
    is printed with 6 decimal places.
    """
    row = "%-30s -> %-30s %0.6f"
    for pair, weight in trans_features:
        source, target = pair
        print(row % (source, target, weight))

# most_common() orders transitions by weight, highest first: take the first 30 ...
print("Top likely transitions:")
print_transitions(Counter(info.transitions).most_common(30))

# ... and the last 20 entries of the same ordering, i.e. the lowest weights.
print("\nTop unlikely transitions:")
print_transitions(Counter(info.transitions).most_common()[-20:])

Top likely transitions:
FIRST_NAME                     -> LAST_NAME                      4.434624
MONTH                          -> DAY                            4.136403
ZIP/POSTAL_CODE                -> CITY                           3.840611
PASSWORD                       -> PASSWORD_CONFIRMATION          3.643369
SEARCH_CATEGORY_/_REFINEMENT   -> SEARCH_CATEGORY_/_REFINEMENT   3.067800
DAY                            -> MONTH                          2.910512
COMMENT_TITLE/SUBJECT          -> COMMENT_TEXT                   2.829119
USERNAME                       -> PASSWORD                       2.809531
COMMENT_TITLE/SUBJECT          -> COMMENT_TITLE/SUBJECT          2.698229
COMMENT_TEXT                   -> HONEYPOT                       2.660318
MONTH                          -> YEAR                           2.275607
YEAR                           -> MONTH                          2.190361
TIME_ZONE                      -> DST                            2.180270
GENDER        

In [26]:
# Info object of the trained tagger; its state_features are inspected below.
info = crf.tagger.info()

def _filtered_state_features(info, query, k=1):
    return Counter({
        (attr, label): weight
        for ((attr, label), weight) in info.state_features.items()
        if (query in attr or query in label) and k*weight >= 0
    })


def print_state_features(state_features):
    """
    Print one line per state feature: weight (6 decimal places), label
    (left-aligned, 30 characters wide), then the attribute.

    ``state_features`` is an iterable of ``((attr, label), weight)`` items.
    """
    row = "%0.6f %-30s %s"
    for key, weight in state_features:
        attr, label = key
        print(row % (weight, label, attr))
        

def print_top_positive(info, N, query=''):
    """Print the ``N`` highest-weight state features matching ``query`` whose weight is >= 0."""
    print("\nTop positive:")
    print_state_features(_filtered_state_features(info, query, 1).most_common(N))
    

def print_top_negative(info, N, query=''):
    """Print the ``N`` lowest-weight state features matching ``query`` whose weight is <= 0."""
    print("\nTop negative:")
    ranked = _filtered_state_features(info, query, -1).most_common()
    print_state_features(ranked[-N:])
    

def print_top(info, N, query=''):
    """Print the ``N`` highest-weight state features matching ``query``, whatever their sign."""
    # k=0 disables the sign filter
    print_state_features(_filtered_state_features(info, query, 0).most_common(N))
    

# Up to 150 state features whose attribute or label contains 'HONEY'
# (the substring match is case-sensitive).
print_top(info, 150, 'HONEY')

# print("\nTop negative:")
# print_top_negative(info, 30, 'input-type')

2.695031 HONEYPOT                       label-ngrams-3-5: fi
1.776379 HONEYPOT                       css_class:honey
0.967930 HONEYPOT                       input-type:text
0.914294 HONEYPOT                       form-type:contact/comment
0.770739 HONEYPOT                       name-ngrams-3-5:oney
0.770739 HONEYPOT                       name-ngrams-3-5:honey
0.770739 HONEYPOT                       name-ngrams-3-5:ney
0.384961 HONEYPOT                       name-ngrams-3-5:one
0.366221 HONEYPOT                       name-ngrams-3-5:hone
0.365820 HONEYPOT                       name-ngrams-3-5:hon
0.292793 HONEYPOT                       text-before:value
0.292759 HONEYPOT                       name-ngrams-3-5:st2
0.292759 HONEYPOT                       name-ngrams-3-5:ottes
0.292759 HONEYPOT                       name-ngrams-3-5:ttes
0.292759 HONEYPOT                       text-before:name bottest1
0.292759 HONEYPOT                       name-ngrams-3-5:bott
0.292759 HONEYPOT            