In [1]:
import numpy as np


In [2]:
# Getting to know the 20 Newsgroups dataset.
from sklearn.datasets import fetch_20newsgroups

# Alternatives kept for reference (not executed):
#   dataset = fetch_20newsgroups(subset='all')
#       -> automatically downloads the second version, 20news-bydate.tar.gz
#   print(len(dataset.data))       # number of samples in dataset_X
#   print(dataset.target_names)    # label names for dataset_y
#   train_set = fetch_20newsgroups(subset='train')   # training split only
#   test_set = fetch_20newsgroups(subset='test')

# When only a handful of categories is needed, list them explicitly.
sample_cate = [
    'alt.atheism',
    'soc.religion.christian',
    'comp.graphics',
    'sci.med',
    'rec.sport.baseball',
]  # restrict to five categories

# Both splits use the same categories, shuffle seed and `remove` setting
# so they stay comparable.
train_set = fetch_20newsgroups(
    subset='train',
    categories=sample_cate,
    shuffle=True,
    random_state=42,
    remove=('headers', 'footers', 'quotes'),
)
test_set = fetch_20newsgroups(
    subset='test',
    categories=sample_cate,
    shuffle=True,
    random_state=42,
    remove=('headers', 'footers', 'quotes'),
)

print(len(train_set.data), len(test_set.data))  # 2854 1899
print(train_set.target_names)  # only the five requested categories

2854 1899
['alt.atheism', 'comp.graphics', 'rec.sport.baseball', 'sci.med', 'soc.religion.christian']


In [3]:
# 1. Prepare the dataset.
# Maps each newsgroup to a short human-readable label; the keys select which
# newsgroups are loaded below.
category_map = {'misc.forsale': 'Sales', 'rec.motorcycles': 'Motorcycles',
        'rec.sport.baseball': 'Baseball', 'sci.crypt': 'Cryptography',
        'sci.space': 'Space'}
from sklearn.datasets import fetch_20newsgroups

# `categories` is documented as an array-like of strings, so hand the loader
# a real list rather than a dict_keys view.
train_set = fetch_20newsgroups(subset='train', categories=list(category_map),
                               shuffle=True, random_state=42,
                               remove=('headers', 'footers', 'quotes'))
test_set = fetch_20newsgroups(subset='test', categories=list(category_map),
                              shuffle=True, random_state=42,
                              remove=('headers', 'footers', 'quotes'))

# The training split holds 2968 samples.
print('train sample num: ', len(train_set.data)) # 2968
print(train_set.target_names) # should be exactly the five requested categories

print('test sample num: ', len(test_set.data)) # 1975

# Actually enforce the "exactly these five categories" expectation instead of
# relying on someone eyeballing the printed list.
assert set(train_set.target_names) == set(category_map), \
    "train split does not contain exactly the requested categories"
assert set(test_set.target_names) == set(category_map), \
    "test split does not contain exactly the requested categories"

train sample num:  2968
['misc.forsale', 'rec.motorcycles', 'rec.sport.baseball', 'sci.crypt', 'sci.space']
test sample num:  1975


In [4]:
from sklearn.feature_extraction.text import TfidfVectorizer

# Turn the raw training documents into TF-IDF features. Text is lower-cased
# and English stop words are dropped before the vocabulary is learned.
vectorizer = TfidfVectorizer(lowercase=True, stop_words='english')
train_vector = vectorizer.fit_transform(train_set.data)

# Result: one row per training document (2968) and one column per vocabulary
# term learned from the training text (31206), i.e. a (2968, 31206) matrix.
# Each entry is the tf-idf weight of that term in that document; a larger
# value means the term matters more for that document.
print(train_vector.shape)  # (2968, 31206)

(2968, 31206)


In [5]:
# Define the model and fit it on the TF-IDF training features.
from sklearn.naive_bayes import MultinomialNB

# alpha is the smoothing strength; fit_prior=False disables learning class
# priors from the training labels.
classifier = MultinomialNB(
    alpha=0.01,
    fit_prior=False,
)
# Left as a bare expression so the notebook displays the fitted estimator.
classifier.fit(train_vector, train_set.target)

MultinomialNB(alpha=0.01, class_prior=None, fit_prior=False)

In [6]:
# Evaluate the fitted classifier on the held-out test split.
from sklearn import metrics

# Reuse the vectorizer fitted on the training data (transform only, no refit)
# so the test matrix has the same columns as the training matrix.
test_vector = vectorizer.transform(test_set.data)
print(test_vector.shape)

pred = classifier.predict(test_vector)

# Micro-averaged F1 over all test samples.
F1_score = metrics.f1_score(
    test_set.target,
    pred,
    average='micro',
)
print('test set F1 score: ', F1_score)

(1975, 31206)
test set F1 score:  0.8774683544303797


In [7]:
# Tune the Naive Bayes hyper-parameters with GridSearchCV.
from sklearn.model_selection import GridSearchCV
from sklearn.metrics import classification_report

# Search space: whether to learn class priors, and the smoothing strength.
parameters = {
    'fit_prior': (True, False),
    'alpha': (0.01, 0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0),
}

# 5-fold cross-validation on the training features, ranking candidates by
# macro-averaged precision and using all available cores.
# NOTE(review): train_vector appears to be vectorized before the CV split, so
# the TF-IDF statistics would see every fold -- consider wrapping the
# vectorizer and classifier in a Pipeline; confirm against the earlier cell.
clf = GridSearchCV(
    classifier,
    parameters,
    cv=5,
    scoring='precision_macro',
    n_jobs=-1,
)
clf.fit(train_vector, train_set.target)
print("Best param_set found on train set: {}".format(clf.best_params_))

# Score the search result on the untouched test split.
print("Detailed classification report on test set:")
y_true = test_set.target
y_pred = clf.predict(test_vector)
print(classification_report(y_true, y_pred))


Best param_set found on train set: {'alpha': 0.05, 'fit_prior': True}
Detailed classification report on test set:
             precision    recall  f1-score   support

          0       0.92      0.89      0.91       390
          1       0.80      0.91      0.85       398
          2       0.93      0.88      0.91       397
          3       0.90      0.88      0.89       396
          4       0.91      0.88      0.89       394

avg / total       0.89      0.89      0.89      1975

