Get newsgroup data

In [5]:
from sklearn.datasets import fetch_20newsgroups
categories = ['comp.graphics', 'sci.space']
data_train = fetch_20newsgroups(subset='train', categories=categories, random_state=42)
data_test = fetch_20newsgroups(subset='test', categories=categories, random_state=42)

Setup clean text function

In [6]:
def is_letter_only(word):
    for char in word:
        if not char.isalpha():
            return False
    return True

from nltk.corpus import names
all_names = set(names.words())
from nltk.stem import WordNetLemmatizer
lemmatizer = WordNetLemmatizer()
def clean_text(data):
    data_cleaned = []
    for doc in data:
        doc = doc.lower()
        doc_cleaned = ' '.join(lemmatizer.lemmatize(word) for word in doc.split() if is_letter_only(word) and word not in all_names)
        data_cleaned.append(doc_cleaned)
    return data_cleaned

Clean data

In [7]:
cleaned_train = clean_text(data_train.data)
label_train = data_train.target
cleaned_test = clean_text(data_test.data)
label_test = data_test.target
print(len(label_train), len(label_test))

1177 783


Verify balance

In [8]:
from collections import Counter
print(Counter(label_train))
print(Counter(label_test))

Counter({1: 593, 0: 584})
Counter({1: 394, 0: 389})


Setup TfidfVectorizer

In [9]:
from sklearn.feature_extraction.text import TfidfVectorizer
tfidf_vectorizer = TfidfVectorizer(stop_words='english', max_features=None)
term_docs_train = tfidf_vectorizer.fit_transform(cleaned_train)
term_docs_test = tfidf_vectorizer.transform(cleaned_test)

Setup SVM classifier

In [10]:
from sklearn.svm import SVC
svm = SVC(kernel='linear', C=1.0, random_state=42)

Fit model on training set

In [11]:
svm.fit(term_docs_train, label_train)

SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0,
    decision_function_shape='ovr', degree=3, gamma='auto_deprecated',
    kernel='linear', max_iter=-1, probability=False, random_state=42,
    shrinking=True, tol=0.001, verbose=False)

Predict on test set

In [13]:
accuracy = svm.score(term_docs_test, label_test)
print('The accuracy of binary classification is:', '{0:.1f}%'.format(accuracy*100))

The accuracy of binary classification is: 96.7%
