In [1]:
from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.naive_bayes import MultinomialNB
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score

In [2]:
# 加载数据集
categories = ['sci.med', 'comp.graphics', 'rec.sport.baseball', 'talk.politics.misc']
data = fetch_20newsgroups(categories=categories, subset='all', shuffle=True, random_state=42)
# data = fetch_20newsgroups(subset='all', shuffle=True, random_state=42)
# 特征提取
vectorizer = TfidfVectorizer()
X = vectorizer.fit_transform(data.data)
y = data.target

# 训练分类模型
# classifier = LogisticRegression()
classifier = MultinomialNB()
classifier.fit(X, y)

MultinomialNB()

In [3]:
# 测试分类算法
test_cases = [
    "I have a fever and cough",
    "Which software is best for 3D modeling?",
    "Who won the World Series last year?",
    "What are the symptoms of diabetes?",
    "How to create a bar chart in Python?",
    "Who is the highest-scoring player in baseball history?",
    "Is there a cure for cancer?",
    "What is the resolution of a high-definition monitor?",
    "Who is the current MVP in baseball?",
    "How to prevent heart diseases?",
    "What are the different file formats for storing images?"
]


In [4]:
X_test = vectorizer.transform(test_cases)
predictions = classifier.predict(X_test)

# 显示分类结果和准确度
print("测试用例\t\t\t\t\t\t\t\t\t\t预测分类")
print("-----------------------------------------------")
for test_case, prediction in zip(test_cases, predictions):
    print(f"{test_case[:40]:<40}\t{data.target_names[prediction]}")

accuracy = accuracy_score(data.target[:len(test_cases)], predictions)
print("\nAccuracy:", accuracy)


测试用例										预测分类
-----------------------------------------------
I have a fever and cough                	sci.med
Which software is best for 3D modeling? 	comp.graphics
Who won the World Series last year?     	rec.sport.baseball
What are the symptoms of diabetes?      	sci.med
How to create a bar chart in Python?    	comp.graphics
Who is the highest-scoring player in bas	rec.sport.baseball
Is there a cure for cancer?             	sci.med
What is the resolution of a high-definit	comp.graphics
Who is the current MVP in baseball?     	rec.sport.baseball
How to prevent heart diseases?          	sci.med
What are the different file formats for 	comp.graphics

Accuracy: 0.36363636363636365
