In [14]:
import numpy as np
from sklearn.datasets import fetch_20newsgroups
from sklearn.ensemble import RandomForestClassifier
from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer
from sklearn.pipeline import Pipeline
from sklearn.tree import DecisionTreeClassifier
from nltk.stem import SnowballStemmer
from platform import python_version
# nltk.download('stopwords')
import nltk
from sklearn.model_selection import GridSearchCV


In [15]:
class StemmedCountVectorizer(CountVectorizer):
    """CountVectorizer variant whose analyzer stems every item it produces.

    A single English ``SnowballStemmer`` (created with
    ``ignore_stopwords=True``) is stored on the class and therefore shared by
    all instances.
    """

    # Built once, when the class body executes; shared across instances.
    stemmer = SnowballStemmer("english", ignore_stopwords=True)

    def build_analyzer(self):
        """Return a callable that maps a document to a list of stemmed items.

        The callable runs the analyzer built by the parent class and passes
        each element of its result through ``self.stemmer.stem``.
        """
        base_analyzer = super().build_analyzer()

        def stemmed_analyzer(doc):
            # NOTE(review): stem() is applied to whatever the parent analyzer
            # yields; if that includes multi-word n-grams, each n-gram string
            # is stemmed as a whole -- confirm this is intended.
            return [self.stemmer.stem(token) for token in base_analyzer(doc)]

        return stemmed_analyzer

In [16]:
def train_test_pipeline(pipe, pipe_parameters, cv=5, n_jobs=-1):
    """Grid-search ``pipe`` on the 20 newsgroups training split and report the result.

    The training subset is fetched with headers, footers and quotes removed,
    a ``GridSearchCV`` is fitted on it, and the best score and best parameter
    combination are printed, followed by a blank line.

    Args:
        pipe: Estimator (here, a ``Pipeline``) handed to ``GridSearchCV``.
        pipe_parameters: Parameter grid handed to ``GridSearchCV``.
        cv: Forwarded to ``GridSearchCV(cv=...)``. Defaults to 5, the value
            that used to be hard-coded.
        n_jobs: Forwarded to ``GridSearchCV(n_jobs=...)``. Defaults to -1,
            the value that used to be hard-coded.

    Returns:
        The fitted ``GridSearchCV`` object, so the caller can reuse the search
        results instead of refitting. When the call is the last expression of
        a notebook cell, end it with ``;`` to suppress the displayed repr.
    """
    train = fetch_20newsgroups(subset='train', remove=('headers', 'footers', 'quotes'), shuffle=True)
    gs_clf = GridSearchCV(pipe, pipe_parameters, cv=cv, n_jobs=n_jobs)
    gs_clf.fit(train.data, train.target)
    print("Best score: %s" % gs_clf.best_score_)
    print("Best param: %s" % gs_clf.best_params_)
    print()
    return gs_clf


In [19]:
# Token counts -> TF-IDF weighting -> random forest classifier.
pipe_clf = Pipeline([
    ('vect', CountVectorizer()),
    ('tfidf', TfidfTransformer()),
    # Fixed seed: without it every run grows a different forest, so scores
    # from separate runs (or separate grid points) are not comparable.
    ('rf', RandomForestClassifier(random_state=42))
])

In [20]:
# Grid passed to train_test_pipeline; keys use the "<step>__<param>" form of
# the pipeline step names ('vect', 'rf').
parameters = dict(
    vect__ngram_range=[(1, 2), (2, 2)],
    # Disabled option, kept for reference:
    # rf__max_depth=(None, 20, 40),
)

In [None]:
# Grid-search the baseline pipeline over the n-gram ranges in `parameters`.
train_test_pipeline(pipe_clf, parameters)

In [17]:
# Same three stages as pipe_clf, with the n-gram range fixed to unigrams+bigrams.
pipe_rf_2 = Pipeline([
    ('vect', CountVectorizer(ngram_range=(1, 2))),
    ('tfidf', TfidfTransformer()),
    # Fixed seed: without it every fit grows a different forest, so the
    # reported accuracy cannot be reproduced on a re-run.
    ('rf', RandomForestClassifier(max_depth=None, random_state=42))
])

In [12]:
# Empty grid: no parameters are varied, so presumably the pipeline is just
# cross-validated as configured (TODO confirm against the GridSearchCV docs).
train_test_pipeline(pipe_rf_2, {})

KeyboardInterrupt: 

In [18]:
# Fit pipe_rf_2 on the training split and measure accuracy on the test split.
# NOTE(review): unlike train_test_pipeline, these fetches do not pass
# remove=('headers', 'footers', 'quotes'), so this accuracy is measured on
# different text than the cross-validation scores -- confirm which is intended.
train = fetch_20newsgroups(subset='train', shuffle=True)
test = fetch_20newsgroups(subset='test', shuffle=True)
pipe_rf_2.fit(train.data, train.target)
predicted = pipe_rf_2.predict(test.data)
# Alias kept under the original name in case other cells read it. The name is
# historical: pipe_rf_2 uses the plain CountVectorizer, so nothing is stemmed.
predicted_stemmed = predicted

# The label previously read "after stemming", which misreported the experiment.
print('Accuracy without stemming: %s' % np.mean(predicted == test.target))

Accuracy after stemming: 0.7623473181093999
