# Naive Bayes Model for Newsgroups Data

For an explanation of the Naive Bayes model, see [our course notes](https://jennselby.github.io/MachineLearningCourseNotes/#naive-bayes).

This notebook uses code from http://scikit-learn.org/stable/tutorial/text_analytics/working_with_text_data.html.

In [1]:
from sklearn.datasets import fetch_20newsgroups # the 20 newsgroups set is included in scikit-learn
from sklearn.naive_bayes import MultinomialNB # we need this for our Naive Bayes model

# These next two are about processing the data. We'll look into this more later in the semester.
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.feature_extraction.text import TfidfTransformer

## Newsgroups Data

Back in the day, [Usenet](https://en.wikipedia.org/wiki/Usenet_newsgroup) was a popular discussion system where people could discuss topics in relevant newsgroups (think Slack channel or subreddit). At some point, someone pulled together messages sent to 20 different newsgroups, to use as [a dataset for doing text processing](http://qwone.com/~jason/20Newsgroups/).

We are going to pull out messages from just a few different groups to try out a Naive Bayes model.

Examine the `newsgroups` dictionary to make sure you understand the dataset. The `data`, `target`, and `target_names` entries are the ones used in the rest of this notebook.

In [2]:
# which newsgroups we want to download
newsgroup_names = ['comp.graphics', 'rec.sport.hockey', 'sci.electronics', 'sci.space']

# get the newsgroup data (organized much like the iris data)
newsgroups = fetch_20newsgroups(categories=newsgroup_names, shuffle=True, random_state=265)
newsgroups.keys()


This next part processes the data, because the scikit-learn Naive Bayes module expects numerical data rather than text data. We will talk more about what this code is doing later in the semester. For now, you can ignore it.

In [None]:
# Convert the text into numbers that represent each word (bag of words method)
word_vector = CountVectorizer()
word_vector_counts = word_vector.fit_transform(newsgroups.data)

# Account for the length of the documents and for how common each word is:
#   replace the raw counts with TF-IDF weights (term frequency scaled by inverse document frequency)
term_freq_transformer = TfidfTransformer()
term_freq = term_freq_transformer.fit_transform(word_vector_counts)
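
As a rough illustration of what the two steps above compute, here is a minimal pure-Python sketch of bag-of-words counting followed by TF-IDF weighting. It assumes scikit-learn's default settings for `TfidfTransformer` (smoothed inverse document frequency and L2 normalization). The helper names `tokenize`, `bag_of_words`, and `tfidf`, the whitespace tokenizer, and the toy documents are made up for this example and are not part of scikit-learn.

```python
import math
from collections import Counter

def tokenize(text):
    # Simplified tokenizer (an assumption for this sketch): lowercase, split on whitespace.
    # CountVectorizer's default tokenizer is regex-based, so real results can differ.
    return text.lower().split()

def bag_of_words(docs):
    """Return the sorted vocabulary and one Counter of word counts per document."""
    counts = [Counter(tokenize(doc)) for doc in docs]
    vocab = sorted({word for c in counts for word in c})
    return vocab, counts

def tfidf(docs):
    """Return the vocabulary and one {word: weight} dict per document."""
    vocab, counts = bag_of_words(docs)
    n_docs = len(docs)
    # Document frequency: in how many documents does each word appear?
    doc_freq = {word: sum(1 for c in counts if word in c) for word in vocab}
    # Smoothed inverse document frequency: ln((1 + n) / (1 + df)) + 1
    idf = {word: math.log((1 + n_docs) / (1 + doc_freq[word])) + 1 for word in vocab}
    weighted = []
    for c in counts:
        raw = {word: count * idf[word] for word, count in c.items()}
        # L2-normalize so that long and short documents are comparable
        norm = math.sqrt(sum(v * v for v in raw.values())) or 1.0
        weighted.append({word: v / norm for word, v in raw.items()})
    return vocab, weighted

# Hypothetical toy documents
toy_docs = [
    'the rover landed on mars',
    'the player scored on the power play',
    'the rover sent photos of mars',
]
vocab, weights = tfidf(toy_docs)
for word, weight in sorted(weights[0].items(), key=lambda item: -item[1]):
    print('{:10}{:.4f}'.format(word, weight))
```

In the first toy document, `landed` (which appears in only one document) ends up with a larger weight than `rover` (two documents), which in turn outweighs `the` (all three documents).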

## Model Training

Now we fit the Naive Bayes model to the subset of the 20 newsgroups data that we've pulled out.

In [None]:
# Train the Naive Bayes model
model = MultinomialNB().fit(term_freq, newsgroups.target)
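
To give a feel for what fitting involves, here is a minimal sketch of a multinomial Naive Bayes classifier in plain Python, assuming raw word counts and add-one (Laplace) smoothing. It is not scikit-learn's implementation: the function names `train_multinomial_nb` and `predict_label`, the whitespace tokenization, and the toy training set are assumptions made for this example. The notebook itself passes TF-IDF weights rather than raw counts to `MultinomialNB`, so its numbers will differ.

```python
import math
from collections import Counter, defaultdict

def train_multinomial_nb(docs, labels, alpha=1.0):
    """Fit log priors and smoothed per-class word log likelihoods."""
    class_doc_counts = Counter(labels)
    word_counts = defaultdict(Counter)  # label -> word -> count
    for doc, label in zip(docs, labels):
        word_counts[label].update(doc.lower().split())
    vocab = sorted({w for counts in word_counts.values() for w in counts})

    # Prior: fraction of training documents in each class
    log_prior = {label: math.log(n / len(docs)) for label, n in class_doc_counts.items()}
    # Likelihood: smoothed fraction of the class's words that are each vocabulary word
    log_likelihood = {}
    for label, counts in word_counts.items():
        total = sum(counts.values()) + alpha * len(vocab)
        log_likelihood[label] = {w: math.log((counts[w] + alpha) / total) for w in vocab}
    return log_prior, log_likelihood

def predict_label(doc, log_prior, log_likelihood):
    """Return the best label and the per-class log scores for one document."""
    scores = {}
    for label in log_prior:
        score = log_prior[label]
        for word in doc.lower().split():
            # Words never seen in training are skipped
            if word in log_likelihood[label]:
                score += log_likelihood[label][word]
        scores[label] = score
    return max(scores, key=scores.get), scores

# Hypothetical toy training set
train_docs = [
    'the goalie made a save',
    'the rocket reached orbit',
    'a slap shot past the goalie',
    'the probe entered orbit around mars',
]
train_labels = ['hockey', 'space', 'hockey', 'space']

log_prior, log_likelihood = train_multinomial_nb(train_docs, train_labels)
hockey_guess, hockey_scores = predict_label('the goalie blocked the shot', log_prior, log_likelihood)
space_guess, space_scores = predict_label('the probe reached mars', log_prior, log_likelihood)
print(hockey_guess, space_guess)
```

The "naive" part is the loop in `predict_label`: each word contributes its log likelihood independently of the other words in the document.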

## Prediction

Let's see how the model does on some (very short) documents that we made up to fit into the specific categories our model is trained on.

In [None]:
# Predict some new fake documents
fake_docs = [
    'That GPU has amazing performance with a lot of shaders',
    'The player had a wicked slap shot',
    'I spent all day yesterday soldering banks of capacitors',
    'Today I have to solder a bank of capacitors',
    'NASA has rovers on Mars',
    'The programmer had to flip the bits in the transistors to trigger the']
fake_counts = word_vector.transform(fake_docs)
fake_term_freq = term_freq_transformer.transform(fake_counts)

predicted = model.predict(fake_term_freq)
print('Predictions:')
for doc, group in zip(fake_docs, predicted):
    print('\t{0} => {1}'.format(doc, newsgroups.target_names[group]))

probabilities = model.predict_proba(fake_term_freq)
print('Probabilities:')
print(''.join(['{:17}'.format(name) for name in newsgroups.target_names]))
for probs in probabilities:
    print(''.join(['{:<17.8}'.format(prob) for prob in probs]))
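
The probabilities printed above always sum to 1 across the four groups. Conceptually, a Naive Bayes model first computes a log score for each class and then normalizes those scores. The sketch below shows one common way to do that normalization safely (subtracting the largest score before exponentiating). The function name `normalize_log_scores` and the toy scores are hypothetical, chosen only to illustrate the idea.

```python
import math

def normalize_log_scores(log_scores):
    """Convert per-class log scores into probabilities that sum to 1."""
    highest = max(log_scores)
    # Subtracting the largest score first keeps exp() from underflowing to 0.0
    shifted = [math.exp(score - highest) for score in log_scores]
    total = sum(shifted)
    return [value / total for value in shifted]

# Hypothetical log scores for four classes; long documents produce very negative values
toy_log_scores = [-1050.0, -1052.0, -1049.0, -1060.0]
toy_probs = normalize_log_scores(toy_log_scores)
print(['{:.4f}'.format(p) for p in toy_probs])
```

Exponentiating scores like these directly would give `0.0` for every class, which is why the shift matters. Only the differences between the scores affect the final probabilities.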

## Exercise Option (Standard Difficulty)

Modify the fake documents and add some of your own. Find words that have particularly large effects on the model probabilities.

## Exercise Option (Advanced)

Write some code to count up how often the words you found in the exercise above appear in each category in the training dataset. Does this match up with your intuition?