In [1]:
from sklearn.datasets import fetch_20newsgroups

In [2]:
# Training split; shuffled with a fixed seed so the document order is reproducible.
twenty_train = fetch_20newsgroups(
    subset='train',
    shuffle=True,
    random_state=42,
)

In [3]:
# Category names; the integer labels in twenty_train.target index into this list.
twenty_train.target_names

['alt.atheism',
 'comp.graphics',
 'comp.os.ms-windows.misc',
 'comp.sys.ibm.pc.hardware',
 'comp.sys.mac.hardware',
 'comp.windows.x',
 'misc.forsale',
 'rec.autos',
 'rec.motorcycles',
 'rec.sport.baseball',
 'rec.sport.hockey',
 'sci.crypt',
 'sci.electronics',
 'sci.med',
 'sci.space',
 'soc.religion.christian',
 'talk.politics.guns',
 'talk.politics.mideast',
 'talk.politics.misc',
 'talk.religion.misc']

In [4]:
# Number of training documents (raw post texts).
len(twenty_train.data)

11314

In [5]:
# Number of source filenames; expected to equal len(twenty_train.data) (one file per post).
len(twenty_train.filenames)

11314

In [6]:
# Show only the first three lines of the first training post.
first_post = twenty_train.data[0]
header_lines = first_post.split("\n")[:3]
print("\n".join(header_lines))

From: lerxst@wam.umd.edu (where's my thing)
Subject: WHAT car is this!?
Nntp-Posting-Host: rac3.wam.umd.edu


In [7]:
# Translate the first post's integer label into its category name.
first_label = twenty_train.target[0]
print(twenty_train.target_names[first_label])

rec.autos


In [8]:
# First ten integer class labels (indices into twenty_train.target_names).
twenty_train.target[:10]

array([ 7,  4,  4,  1, 14, 16, 13,  3,  2,  4])

In [9]:
# Map the first ten integer labels back to their category names, one per line.
first_ten_names = [twenty_train.target_names[label] for label in twenty_train.target[:10]]
print("\n".join(first_ten_names))

rec.autos
comp.sys.mac.hardware
comp.sys.mac.hardware
comp.graphics
sci.space
talk.politics.guns
sci.med
comp.sys.ibm.pc.hardware
comp.os.ms-windows.misc
comp.sys.mac.hardware


In [10]:
# Tokenizing text: learn a vocabulary from the training posts and count
# term occurrences in each document.
from sklearn.feature_extraction.text import CountVectorizer

count_vect = CountVectorizer()
X_train_counts = count_vect.fit_transform(twenty_train.data)

# (n_documents, n_vocabulary_terms)
X_train_counts.shape

(11314, 130107)

In [11]:
# Column index assigned to a term in the count matrix (None if the term is not in the vocabulary).
count_vect.vocabulary_.get('newspapers')

86945

In [12]:
# From occurrences to frequencies: term-frequency weighting only
# (use_idf=False, so no inverse-document-frequency scaling).
from sklearn.feature_extraction.text import TfidfTransformer

tf_transformer = TfidfTransformer(use_idf=False)
tf_transformer.fit(X_train_counts)
X_train_tf = tf_transformer.transform(X_train_counts)
X_train_tf.shape

(11314, 130107)

In [13]:
# Use the print function: the original `print X_train_tf` statement is
# Python 2-only syntax and is a SyntaxError on a Python 3 kernel.
# The sparse matrix prints as truncated "(row, col)  value" entries.
print(X_train_tf)

  (0, 86580)	0.0640184399664
  (0, 128420)	0.0640184399664
  (0, 35983)	0.0640184399664
  (0, 35187)	0.0640184399664
  (0, 66098)	0.0640184399664
  (0, 114428)	0.0640184399664
  (0, 78955)	0.0640184399664
  (0, 94362)	0.0640184399664
  (0, 76722)	0.0640184399664
  (0, 57308)	0.0640184399664
  (0, 62221)	0.0640184399664
  (0, 128402)	0.128036879933
  (0, 67156)	0.0640184399664
  (0, 123989)	0.0640184399664
  (0, 90252)	0.0640184399664
  (0, 63363)	0.0640184399664
  (0, 78784)	0.0640184399664
  (0, 96144)	0.0640184399664
  (0, 128026)	0.0640184399664
  (0, 109271)	0.0640184399664
  (0, 51730)	0.0640184399664
  (0, 86001)	0.0640184399664
  (0, 83256)	0.0640184399664
  (0, 113986)	0.0640184399664
  (0, 37565)	0.0640184399664
  :	:
  (11313, 87626)	0.0990147542977
  (11313, 30044)	0.0990147542977
  (11313, 76377)	0.0990147542977
  (11313, 119714)	0.0990147542977
  (11313, 47982)	0.0990147542977
  (11313, 28146)	0.198029508595
  (11313, 88363)	0.198029508595
  (11313, 56283)	0.0990147542977


In [14]:
# Training the classifier: multinomial Naive Bayes on the TF features.
from sklearn.naive_bayes import MultinomialNB

# fit() returns the estimator itself; keeping the assignment means the cell shows no output.
clf = MultinomialNB().fit(X_train_tf, twenty_train.target)

In [15]:
docs_new = [
    'World war 3 will happen',
    'Doctors are in immense need',
]
# Use transform (not fit_transform) so the new texts are mapped onto the
# vocabulary and weighting already fitted on the training set.
X_new_counts = count_vect.transform(docs_new)
# NOTE(review): despite its name, X_new_tfidf holds TF-only features:
# tf_transformer was built with use_idf=False, matching what clf was trained on.
X_new_tfidf = tf_transformer.transform(X_new_counts)


In [16]:
# Predict integer category labels for the new documents.
# NOTE(review): `predicted` is reassigned later for the test-set predictions; a distinct name would avoid confusion.
predicted = clf.predict(X_new_tfidf)

In [17]:
# Pair each new document with the name of its predicted category.
for doc, label in zip(docs_new, predicted):
    category_name = twenty_train.target_names[label]
    print('{!r} => {}'.format(doc, category_name))

'World war 3 will happen' => talk.politics.mideast
'Doctors are in immense need' => sci.med


In [18]:
from sklearn.pipeline import Pipeline

# Vectorize -> weight -> classify in one estimator, so raw text goes in at both
# fit and predict time. Unlike tf_transformer above, this TfidfTransformer is
# built without use_idf=False.
steps = [
    ('vect', CountVectorizer()),
    ('tfidf', TfidfTransformer()),
    ('clf', MultinomialNB()),
]
text_clf = Pipeline(steps).fit(twenty_train.data, twenty_train.target)

In [19]:
# Evaluation of the performance on the held-out test split.
import numpy as np

twenty_test = fetch_20newsgroups(
    subset='test',
    shuffle=True,
    random_state=42,
)
docs_test = twenty_test.data
predicted = text_clf.predict(docs_test)

# Accuracy: fraction of test documents whose predicted label equals the true label.
np.mean(predicted == twenty_test.target)

0.7738980350504514

In [20]:
# Detailed performance analysis of result
from sklearn import metrics

In [21]:
# Per-class precision / recall / F1 on the test split.
report = metrics.classification_report(
    twenty_test.target,
    predicted,
    target_names=twenty_test.target_names,
)
print(report)

                          precision    recall  f1-score   support

             alt.atheism       0.80      0.52      0.63       319
           comp.graphics       0.81      0.65      0.72       389
 comp.os.ms-windows.misc       0.82      0.65      0.73       394
comp.sys.ibm.pc.hardware       0.67      0.78      0.72       392
   comp.sys.mac.hardware       0.86      0.77      0.81       385
          comp.windows.x       0.89      0.75      0.82       395
            misc.forsale       0.93      0.69      0.80       390
               rec.autos       0.85      0.92      0.88       396
         rec.motorcycles       0.94      0.93      0.93       398
      rec.sport.baseball       0.92      0.90      0.91       397
        rec.sport.hockey       0.89      0.97      0.93       399
               sci.crypt       0.59      0.97      0.74       396
         sci.electronics       0.84      0.60      0.70       393
                 sci.med       0.92      0.74      0.82       396
         

In [22]:
# Rows are true classes and columns are predicted classes: each row sums to that class's support.
metrics.confusion_matrix(y_true=twenty_test.target, y_pred=predicted)

array([[166,   0,   0,   1,   0,   1,   0,   0,   1,   1,   1,   3,   0,
          6,   3, 123,   4,   8,   0,   1],
       [  1, 252,  15,  12,   9,  18,   1,   2,   1,   5,   2,  41,   4,
          0,   6,  15,   4,   1,   0,   0],
       [  0,  14, 258,  45,   3,   9,   0,   2,   1,   3,   2,  25,   1,
          0,   6,  23,   2,   0,   0,   0],
       [  0,   5,  11, 305,  17,   1,   3,   6,   1,   0,   2,  19,  13,
          0,   5,   3,   1,   0,   0,   0],
       [  0,   3,   8,  23, 298,   0,   3,   8,   1,   3,   1,  16,   8,
          0,   2,   8,   3,   0,   0,   0],
       [  1,  21,  17,  13,   2, 298,   1,   0,   1,   1,   0,  23,   0,
          1,   4,  10,   2,   0,   0,   0],
       [  0,   1,   3,  31,  12,   1, 271,  19,   4,   4,   6,   5,  12,
          6,   3,   9,   3,   0,   0,   0],
       [  0,   1,   0,   3,   0,   0,   4, 364,   3,   2,   2,   4,   1,
          1,   3,   3,   4,   0,   1,   0],
       [  0,   0,   0,   1,   0,   0,   2,  10, 371,   0,   0,  