## Logistic Regression

In [1]:
import nltk
import numpy as np
import pandas as pd
from sklearn.naive_bayes import MultinomialNB
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.metrics import confusion_matrix, accuracy_score, classification_report

In [2]:
with open('20ng-train-all-terms.txt') as f:
    train = f.read()
with open('20ng-test-all-terms.txt') as f:
    test = f.read()

train = train[:-1]
test = test[:-1]
topic_train = []
topic_test = []
news_train = []
news_test = []
for new in train.split('\n'):
    topic_train.append(new.split('\t')[0])
    news_train.append(new.split('\t')[1])
print('number of training news: ' + str(len(news_train)))
for new in test.split('\n'):
    topic_test.append(new.split('\t')[0])
    news_test.append(new.split('\t')[1])
print('number of test news: ' + str(len(news_test)))

number of training news: 11293
number of test news: 7528


In [3]:
# topic_map = {'alt.atheism' : 0,
#     'comp.graphics': 1,
#     'comp.os.ms-windows.misc': 2,
#     'comp.sys.ibm.pc.hardware': 3,
#     'comp.sys.mac.hardware': 4,
#     'comp.windows.x': 5,
#     'misc.forsale': 6,
#     'rec.autos': 7,
#     'rec.motorcycles': 8,
#     'rec.sport.baseball': 9,
#     'rec.sport.hockey': 10,
#     'sci.crypt': 11,
#     'sci.electronics': 12,
#     'sci.med': 13,
#     'sci.space': 14,
#     'soc.religion.christian': 15,
#     'talk.politics.guns': 16,
#     'talk.politics.mideast': 17,
#     'talk.politics.misc': 18,
#     'talk.religion.misc': 19}

In [None]:
topic_map = {'alt.atheism' : 0,
    'comp.graphics': 1,
    'comp.os.ms-windows.misc': 2,
    'comp.sys.ibm.pc.hardware': 3,
    'comp.sys.mac.hardware': 4,
    'comp.windows.x': 5,
    'misc.forsale': 6,
    'rec.autos': 7,
    'rec.motorcycles': 8,
    'rec.sport.baseball': 9,
    'rec.sport.hockey': 10,
    'sci.crypt': 11,
    'sci.electronics': 12,
    'sci.med': 13,
    'sci.space': 14,
    'soc.religion.christian': 15,
    'talk.politics.guns': 16,
    'talk.politics.mideast': 17,
    'talk.politics.misc': 18,
    'talk.religion.misc': 19}

In [4]:
df_train = pd.DataFrame(np.array(news_train), columns = ['news'])
df_train['target'] = pd.DataFrame(np.array(topic_train), columns = ['topic'])
df_train = df_train.reindex(np.random.permutation(df_train.index)).reset_index(drop = True)
df_train['target'] = df_train['target'].replace(topic_map)
df_train.head()

Unnamed: 0,news,target
0,re how to get there was re comet in temporary ...,14
1,re tiff philosophical significance of in artic...,1
2,jewish committee on the middle east jcome i fl...,17
3,baseball card for sale i have a roberto clemen...,6
4,re doing the work of god desiree bradley desir...,15


In [5]:
df_test = pd.DataFrame(np.array(news_test), columns = ['news'])
df_test['target'] = pd.DataFrame(np.array(topic_test), columns = ['topic'])
df_test = df_test.reindex(np.random.permutation(df_test.index)).reset_index(drop = True)
df_test['target'] = df_test['target'].replace(topic_map)
df_test.head()

Unnamed: 0,news,target
0,re happy birthday israel in article c da me bo...,17
1,traveling fatima was re consecration and anniv...,15
2,re zionist leaders frank statements dear pete ...,17
3,re pens fans reactions richard j coyle gomer p...,10
4,sale calculator sharp brand pocket computer mo...,6


In [6]:
X_train, y_train = df_train['news'], df_train['target']
X_test, y_test = df_test['news'], df_test['target']

In [7]:
%%time
# use CountVectorizer to find three letter tokens, remove stop_words,
# remove tokens that don't appear in at least 20 documents,
# remove tokens that appear in more than 20% of the documents
vect = CountVectorizer(min_df = 20, max_df = 0.2, stop_words = 'english', token_pattern = '(?u)\\b\\w\\w\\w+\\b')
vect.fit(X_train)
X_train_vectorized = vect.transform(X_train)
X_test_vectorized = vect.transform(X_test)
clfNB = MultinomialNB(alpha = 1.0).fit(X_train_vectorized, y_train)


Wall time: 4.07 s


In [8]:
y_predict = clfNB.predict(X_test_vectorized)
confusion_matrix(y_predict, y_test)

array([[226,   2,   0,   0,   1,   0,   0,   2,   0,   3,   2,   0,   1,
         11,   3,  16,   0,  21,   6,  37],
       [  2, 283,  31,   8,  10,  59,   2,   1,   0,   2,   0,   7,  21,
          7,  15,   1,   1,   1,   1,   4],
       [  0,  25, 289,  33,  21,  31,   2,   2,   0,   0,   0,   6,  19,
          1,   0,   3,   1,   0,   0,   0],
       [  0,  12,  31, 261,  33,   2,  27,   0,   1,   0,   0,   0,  24,
          3,   0,   0,   0,   0,   0,   0],
       [  0,  18,  10,  49, 289,   6,  11,   1,   0,   1,   0,   5,  23,
          4,   1,   0,   0,   1,   0,   0],
       [  0,  21,  11,   2,   0, 271,   0,   1,   0,   1,   0,   1,   0,
          1,   2,   1,   1,   0,   0,   0],
       [  1,   4,   2,   8,   6,   7, 321,   9,   3,   6,   0,   1,  14,
          2,   1,   3,   1,   1,   1,   0],
       [  0,   0,   0,   3,   3,   1,  10, 351,  15,   1,   0,   2,   6,
         10,   4,   0,   2,   0,   1,   1],
       [  7,   0,   0,   0,   1,   2,   5,  11, 369,   4,   0,  

In [9]:
print('accuracy is ' + str(accuracy_score(y_predict, y_test)))

accuracy is 0.7907810839532412


In [10]:
print(classification_report(y_predict, y_test))

             precision    recall  f1-score   support

          0       0.71      0.68      0.70       331
          1       0.73      0.62      0.67       456
          2       0.73      0.67      0.70       433
          3       0.67      0.66      0.66       394
          4       0.75      0.69      0.72       419
          5       0.69      0.87      0.77       313
          6       0.82      0.82      0.82       391
          7       0.89      0.86      0.87       410
          8       0.93      0.88      0.91       417
          9       0.90      0.94      0.92       383
         10       0.96      0.94      0.95       408
         11       0.89      0.91      0.90       389
         12       0.65      0.77      0.71       333
         13       0.81      0.89      0.85       361
         14       0.86      0.88      0.87       384
         15       0.86      0.78      0.82       437
         16       0.90      0.69      0.78       473
         17       0.84      0.95      0.89   