In [1]:
from os import path

from matplotlib import pyplot as plt
import pandas as pd

from sklearn.feature_extraction.text import TfidfVectorizer, CountVectorizer
from sklearn.pipeline import make_pipeline, make_union
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.svm import SVC
from sklearn.multiclass import OneVsRestClassifier
from sklearn import metrics

from functions import load_bad_words, build_data_path
from constants import LABEL_COLS

In [4]:
# Location of the labelled training CSV. The path is produced by the
# project helper `build_data_path` (presumably it resolves the file name
# against the project's data directory -- confirm in functions.py).
train_file_name = 'train.csv'
training_data_path = build_data_path(train_file_name)

In [5]:
# Load the training set, then separate the model input (the raw comment
# text column) from the target (one column per entry in LABEL_COLS).
df = pd.read_csv(training_data_path)
X, y = df['comment_text'], df[LABEL_COLS]

# Disabled experiment: an extra derived 'not_toxic' label column.
# NOTE(review): `not_toxic` is not imported in the visible import cell, and
# appending to LABEL_COLS would mutate the imported constant in place.
# df['not_toxic'] = df[LABEL_COLS].apply(not_toxic, axis=1)
# LABEL_COLS.append('not_toxic')

In [7]:
# Hold out 33% of the rows for validation.
# The shuffle is seeded so that a fresh "Restart & Run All" reproduces the
# same partition -- and therefore the same metrics -- instead of drawing a
# new random split on every run.
SPLIT_SEED = 42
X_train, X_valid, y_train, y_valid = train_test_split(
    X, y, test_size=0.33, random_state=SPLIT_SEED
)

In [8]:
# Feature extraction: TF-IDF over lower-cased text with the built-in
# English stop-word list removed.
tfidf = TfidfVectorizer(
    lowercase=True,
    stop_words='english',
)

# Multi-label strategy: one-vs-rest wrapper around an SVC base estimator.
clf = OneVsRestClassifier(
    SVC(gamma='scale'),
)

# Vectoriser and classifier are chained so both are fitted on the
# training split only.
pipeline = make_pipeline(tfidf, clf)

# `fit` returns the fitted pipeline itself, so the validation predictions
# can be taken directly from the result.
y_predictions = pipeline.fit(X_train, y_train).predict(X_valid)

In [9]:
# Evaluate the validation predictions with multi-label metrics.

# Fraction of individual label assignments that are wrong.
# The former `labels=LABEL_COLS` argument is dropped: it was deprecated in
# scikit-learn 0.21 and removed in 0.23 (where passing it raises TypeError),
# and for multilabel indicator targets the result is the same without it.
hamming_loss = metrics.hamming_loss(y_valid, y_predictions)

# NOTE(review): `jaccard_similarity_score` is deprecated since scikit-learn
# 0.21 and removed in 0.23. Its replacement,
# `metrics.jaccard_score(..., average='samples')`, scores samples with no
# true and no predicted labels differently, so switching would change the
# reported number -- decide deliberately before migrating.
jaccard = metrics.jaccard_similarity_score(y_valid, y_predictions)

print(f'Hamming loss (lower is better): {hamming_loss}')
print(f'Jaccard similarity (higher is better): {jaccard}')
print()

# Per-label precision / recall / F1, with rows named after the label columns.
print(metrics.classification_report(y_valid, y_predictions, target_names=LABEL_COLS))

Hamming loss (lower is better): 0.02594365002500364
Jaccard similarity (higher is better): 0.9210622432379397

               precision    recall  f1-score   support

        toxic       0.91      0.33      0.48      5004
 severe_toxic       0.00      0.00      0.00       521
      obscene       0.91      0.47      0.62      2781
       threat       0.00      0.00      0.00       161
       insult       0.75      0.35      0.48      2535
identity_hate       0.00      0.00      0.00       452

    micro avg       0.87      0.34      0.48     11454
    macro avg       0.43      0.19      0.26     11454
 weighted avg       0.78      0.34      0.47     11454
  samples avg       0.03      0.03      0.03     11454



  'precision', 'predicted', average, warn_for)
  'precision', 'predicted', average, warn_for)
  'recall', 'true', average, warn_for)
