In [20]:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import load_files
from sklearn.model_selection import train_test_split
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import SGDClassifier
from sklearn.svm import SVC
from sklearn.pipeline import Pipeline
from sklearn.feature_extraction.text import TfidfVectorizer, CountVectorizer
from sklearn.model_selection import cross_val_score
from sklearn.metrics import accuracy_score, classification_report
 
DATA_DIR = "./bbc/"

In [21]:
data = load_files(DATA_DIR, encoding="utf-8", decode_error="replace")
# calculate count of each category
labels, counts = np.unique(data.target, return_counts=True)
# convert data.target_names to np array for fancy indexing
labels_str = np.array(data.target_names)[labels]
print(dict(zip(labels_str, counts)))

{'business': 510, 'entertainment': 386, 'politics': 417, 'sport': 511, 'tech': 401}


In [22]:
X_train, X_test, y_train, y_test = train_test_split(data.data, data.target)
list(t[:80] for t in X_train[:10])

['US budget deficit to reach $368bn\n\nThe US budget deficit is set to hit a worse-t',
 'WMC says Xstrata bid is too low\n\nAustralian mining firm WMC Resources has said i',
 "Roxy Music on Isle of Wight bill\n\nRoxy Music will appear at June's Isle of Wight",
 "Newcastle 2-1 Bolton\n\nKieron Dyer smashed home the winner to end Bolton's 10-gam",
 'Italy 8-38 Wales\n\nWales secured their first away win in the RBS Six Nations for ',
 'Newry to fight cup exit in courts\n\nNewry City are expected to discuss legal aven',
 'Strong demand triggers oil rally\n\nCrude oil prices surged back above the $47 a b',
 'Be careful how you code\n\nA new European directive could put software writers at ',
 'Worldcom director ends evidence\n\nThe former chief financial officer at US teleco',
 "Ministers deny care sums 'wrong'\n\nMinisters have insisted they are committed to "]

In [23]:
vectorizer = TfidfVectorizer(stop_words="english", max_features=1000, decode_error="ignore")
vectorizer.fit(X_train)
X_train_vectorized = vectorizer.transform(X_train)

> Using cross_val_score function, we’ll train the each model two times and record their mean accuracy. We’ll choose the highest performing model and train it and then evaluate it in the test set.

In [29]:
# start with the classic
# with either pure counts or tfidf features
sgd = Pipeline([
        ("count vectorizer", CountVectorizer(stop_words="english", max_features=3000)),
        ("sgd", SGDClassifier(loss="modified_huber"))
    ])
sgd_tfidf = Pipeline([
        ("tfidf_vectorizer", TfidfVectorizer(stop_words="english", max_features=3000)),
        ("sgd", SGDClassifier(loss="modified_huber"))
    ])
 
svc = Pipeline([
        ("count_vectorizer", CountVectorizer(stop_words="english", max_features=3000)),
        ("linear svc", SVC(kernel="linear"))
    ])
svc_tfidf = Pipeline([
        ("tfidf_vectorizer", TfidfVectorizer(stop_words="english", max_features=3000)),
        ("linear svc", SVC(kernel="linear"))
    ])
   
all_models = [
    ("sgd", sgd),
    ("sgd_tfidf", sgd_tfidf),
    ("svc", svc),
    ("svc_tfidf", svc_tfidf),
    ]
 
unsorted_scores = [(name, cross_val_score(model, X_train, y_train, cv=2).mean()) for name, model in all_models]
scores = sorted(unsorted_scores, key=lambda x: -x[1])
print(scores)  

[('svc_tfidf', 0.974220623501199), ('sgd_tfidf', 0.9592326139088729), ('svc', 0.9574340527577938), ('sgd', 0.9544364508393286)]


> Support Vector Machine with tf-idf features scored the highest accuracy of 97%. Lets train it and evaluate it in the test dataset.

In [25]:
model = svc_tfidf
model.fit(X_train, y_train)
y_pred = model.predict(X_test)
print(accuracy_score(y_test, y_pred))
print(classification_report(y_test, y_pred))

0.9748653500897666
              precision    recall  f1-score   support

           0       0.98      0.95      0.96       131
           1       0.96      0.99      0.97        92
           2       0.96      0.96      0.96       102
           3       0.99      0.99      0.99       137
           4       0.97      0.99      0.98        95

    accuracy                           0.97       557
   macro avg       0.97      0.98      0.97       557
weighted avg       0.98      0.97      0.97       557

