In [None]:
import re
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score, roc_curve, accuracy_score
from sklearn.metrics import confusion_matrix
from sklearn.metrics import precision_score, recall_score, f1_score
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
import sys
import os
sys.path.append(os.path.join('..', '..'))
from vowpal_wabbit import VowpalWabbitClassifier

In [None]:
# Load the labelled Quora dataset; later cells use its `question_text` and `target` columns.
quora_train_csv = '../data/.input/train.csv'
quora = pd.read_csv(quora_train_csv)

In [None]:
# Hold out a test split (train_test_split's default fraction); the fixed seed keeps it reproducible.
quora_train, quora_test = train_test_split(quora, random_state=42)

# Map the target onto -1/+1 (assumes target is 0/1), presumably because VW's
# logistic loss, configured further down, uses the ±1 label convention.
quora_train_data = quora_train['question_text']
quora_train_labels = quora_train['target'].mul(2.0).sub(1.0)

quora_test_data = quora_test['question_text']
quora_test_labels = quora_test['target'].mul(2.0).sub(1.0)

In [None]:
def to_vw_format(document, label=None):
    """Render one document as a single line of Vowpal Wabbit text input.

    The document is lower-cased and every run of three or more word characters
    (``\\w``) becomes a feature in the ``text`` namespace; shorter tokens and
    punctuation are dropped.

    Args:
        document: Raw text of the document (must support ``.lower()``).
        label: Optional label written before the namespace. ``None`` yields an
            unlabelled line (used for the test file); any other value, including
            ``0``/``0.0``, is written via ``str``.

    Returns:
        A newline-terminated VW example line, e.g. ``'1.0 |text foo bar\\n'``.
    """
    # `label or ''` would silently drop a legitimate 0 / 0.0 label, so check for None explicitly.
    label_part = '' if label is None else str(label)
    # Raw string: '\w' in a plain literal is an invalid escape (SyntaxWarning on Python 3.12+).
    # Only \w characters survive, so VW's separators '|' and ':' never leak into feature names.
    words = re.findall(r'\w{3,}', document.lower())
    return label_part + ' |text ' + ' '.join(words) + '\n'

In [None]:
# Write both splits as Vowpal Wabbit text files: the train file carries the labels
# from quora_train_labels, the test file is written unlabelled for prediction.
# os.makedirs(exist_ok=True) is idempotent, so re-running this cell is safe, and a
# genuine failure (e.g. permissions) raises here instead of being swallowed by `!!mkdir`.
os.makedirs('.input', exist_ok=True)
with open('.input/train.vw', 'w', encoding='utf-8') as vw_train_data:
    vw_train_data.writelines(
        to_vw_format(text, target)
        for text, target in zip(quora_train_data, quora_train_labels)
    )
with open('.input/test.vw', 'w', encoding='utf-8') as vw_test_data:
    vw_test_data.writelines(to_vw_format(text) for text in quora_test_data)

In [None]:
# Training options, assuming the project wrapper forwards these keys as `vw` CLI flags:
#   --loss_function logistic : logistic loss for binary classification
#   -b 27                    : 27-bit feature hash table (2**27 weight slots)
fit_params = {'--loss_function': 'logistic', '-b': 27}

vw = VowpalWabbitClassifier(
    working_dir='.input',
    debug=True,
    fit_params=fit_params,
)
vw.fit('.input/train.vw')

In [None]:
# Score the unlabelled test file with the fitted model.
test_vw_path = '.input/test.vw'
quora_test_pred = vw.predict(test_vw_path)

In [None]:
# ROC analysis on the held-out split. `predict_proba_` is populated by the wrapper
# after predict() — presumably positive-class scores; confirm in vowpal_wabbit.
quora_test_prediction = vw.predict_proba_

auc = roc_auc_score(quora_test_labels, quora_test_prediction)
curve = roc_curve(quora_test_labels, quora_test_prediction)
fpr, tpr = curve[:2]

plt.plot(fpr, tpr)
plt.plot([0, 1], [0, 1])  # chance-level diagonal for reference
plt.xlabel('FPR')
plt.ylabel('TPR')
plt.title('test AUC = %f' % auc)
plt.axis([-0.05, 1.05, -0.05, 1.05]);

In [None]:
# Threshold-based metrics on the hard predictions; the confusion matrix is the cell's displayed output.
# NOTE(review): these assume quora_test_pred uses the same {-1, +1} encoding as
# quora_test_labels (sklearn's pos_label defaults to 1) — confirm what vw.predict returns.
for metric_name, metric in (
    ('accuracy', accuracy_score),
    ('precision', precision_score),
    ('recall', recall_score),
    ('f1', f1_score),
):
    print(metric_name, metric(quora_test_labels, quora_test_pred))
confusion_matrix(quora_test_labels, quora_test_pred)