In [1]:
import nltk
from sklearn.datasets import load_files
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LinearRegression, SGDClassifier
from sklearn.model_selection import train_test_split, KFold, GridSearchCV
from sklearn.naive_bayes import MultinomialNB, BernoulliNB
from sklearn.pipeline import Pipeline

In [2]:
def test_validate_train_split(X, y,
                              test_size: float = None,
                              random_state: int = None, **kwargs
                              ) -> tuple:
    '''
        Completes Sci-Kit Learn's test_train_split twice to split data into
        three sections
        
        Paramaters:
        X: The dataset without the target present
        y: The target values of the dataset
        test_size: The proportion of the data set in the test set.
            It is also the proportion of the remainder used for the validation set
        random_state: Controls the shuffling 
        **kwargs: These are passed to both test_train_split functions
        
        Returns a tuple of:
            X_train, X_validate, X_test, y_train, y_validate, y_test
    '''
    # Complete test_train split
    X_train_val, X_test, y_train_val, y_test = train_test_split(
        X, y, test_size=test_size, random_state=random_state, **kwargs)

    # Complete the test_validation split
    X_train, X_validate, y_train, y_validate = train_test_split(
        X_train_val, y_train_val, test_size=test_size,
        random_state=random_state, **kwargs)
    
    return X_train, X_validate, X_test, y_train, y_validate, y_test

In [3]:
# Set a randomstate to be used throughout for reproducability
RANDOM_STATE = 42

In [4]:
# Load in data
data_filepath = '20_news_groups/20_newsgroups'

newsgroups_data = load_files(
    data_filepath, shuffle=True, random_state=RANDOM_STATE, encoding='ISO-8859-1')
print(f'{len(newsgroups_data.data)} files loaded.')
print('They contain the following classes:')
newsgroups_data.target_names

19997 files loaded.
They contain the following classes:


['alt.atheism',
 'comp.graphics',
 'comp.os.ms-windows.misc',
 'comp.sys.ibm.pc.hardware',
 'comp.sys.mac.hardware',
 'comp.windows.x',
 'misc.forsale',
 'rec.autos',
 'rec.motorcycles',
 'rec.sport.baseball',
 'rec.sport.hockey',
 'sci.crypt',
 'sci.electronics',
 'sci.med',
 'sci.space',
 'soc.religion.christian',
 'talk.politics.guns',
 'talk.politics.mideast',
 'talk.politics.misc',
 'talk.religion.misc']

In [5]:
# Complete a test_validate_train split
X_train, X_val, X_test, y_train, y_val, y_test = test_validate_train_split(
    newsgroups_data.data, newsgroups_data.target, test_size=0.2, random_state=RANDOM_STATE)

In [6]:
# Read in the stop words
stop_words = nltk.corpus.stopwords.words('english')
# Given out tokenizer there is an issue with contracted words from our stop_words
# you'd is in out stop_words, but is tokenised to ["you", "'d"]
# which would be missed, so run our stop words through the tokenizer to match
nested_tokenized_stop_words = [nltk.word_tokenize(stop_word) for stop_word in stop_words]
# This results in a list of lists which need to be flattened
tokenized_stop_words = [word
                        # for each list in the bigger list we want
                        for list_of_words in nested_tokenized_stop_words
                        # each word in the list
                        for word in list_of_words]

In [7]:
# Create a Kfold cross validoator for grid search
kfold_cv = KFold(5, shuffle=True, random_state=RANDOM_STATE)

In [8]:
# Create a base pipeline for gridsearch
# Pipeline accepts a list of tuples, a name of a step, and a step
pipe = Pipeline([
    ('vectorizer', TfidfVectorizer(
        tokenizer=nltk.word_tokenize,
        min_df=2,
        ngram_range=(1,2),
        stop_words=tokenized_stop_words,
        token_pattern=None,
        norm='l2'
    )),
    ('clf', MultinomialNB()),
])

In [9]:
# Define the grid search paramaters
vectorizer_params = {
    'vectorizer__ngram_range':[(1, 1), (1, 2), (1, 3), (2, 3)],
    'vectorizer__norm': ['l1', 'l2', None]
}
lin_reg_params = {**vectorizer_params,
                  'clf': [LinearRegression()], 
}
sgd_params = {**vectorizer_params,
              'clf': [SGDClassifier()],
              'clf__penalty': ['l1', 'l2', 'elasticnet'],
              'clf__alpha': [1e-4, 1e-3, 1e-2, 1e-1],
}
multi_nb_params = {**vectorizer_params,
                   'clf': [MultinomialNB()],
                   'clf__alpha': [1e-4, 1e-3, 1e-2, 1e-1],
}
bern_nb_params = {**vectorizer_params,
                  'clf':[BernoulliNB()],
                  'clf__alpha': [1e-4, 1e-3, 1e-2, 1e-1],
}

# Combine params
grid_search_params = [lin_reg_params, sgd_params,
                      multi_nb_params, bern_nb_params]

In [10]:
grid_search = GridSearchCV(pipe, grid_search_params,
                           cv=kfold_cv, n_jobs=-1)

In [11]:
grid_search.fit(X_train, y_train)

In [12]:
grid_search.cv_results_

{'mean_fit_time': array([  35.11835742,   36.5081316 ,   38.619245  ,   61.68014174,
          82.56189351,   95.95716972,  132.35044179,  127.64264026,
         145.823633  ,  119.62892447,  112.53214836,  108.58910198,
          38.44853902,   37.50822973,  347.30908442,   49.29804573,
          51.10500746, 1139.3836791 ,   79.64142418,   77.98696809,
        1372.16644263,   64.82094178,   78.63352051,  968.90432916,
          42.89587479,   38.25786762,  125.50247798,  176.54694695,
          56.89464951,   63.50891261,   71.58387475,   74.2974277 ,
          86.44925179,   62.11880393,   59.24965811,   68.51940875,
          43.13951731,   42.67154741,   58.73323445,   61.61384993,
          61.08956575,   85.60191779,   82.33117361,   84.17193937,
         137.81996303,   67.10919652,   66.24609718,   97.00377016,
          37.76658883,   42.11158023,  589.55387354,   53.6471035 ,
          57.70177765, 1559.65388021,   77.79464974,   82.48590117,
        2260.57853785,   72.688

In [13]:
grid_search.best_score_

0.9675711948026574

In [14]:
grid_search.best_params_

{'clf': SGDClassifier(),
 'clf__alpha': 0.0001,
 'clf__penalty': 'l1',
 'vectorizer__ngram_range': (1, 2),
 'vectorizer__norm': 'l2'}