# 实验八：三种朴素贝叶斯模型分类对比（Iris 数据集）
- GaussianNB（高斯）
- MultinomialNB（多项式）
- BernoulliNB（伯努利）
- 特征选择：Petal length 和 Sepal width
- 输出预测结果、评价指标和可视化决策边界

In [None]:
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.naive_bayes import GaussianNB, BernoulliNB, MultinomialNB
from sklearn import metrics
import numpy as np
import matplotlib.pyplot as plt

# Configure a CJK-capable font so the Chinese plot titles render correctly.
# SimHei is typically only installed on Windows, so additional families are
# listed as fallbacks; matplotlib uses the first family it can find.
plt.rcParams['font.sans-serif'] = [
    'SimHei',
    'Microsoft YaHei',
    'PingFang SC',
    'Heiti TC',
    'Arial Unicode MS',
    'Noto Sans CJK SC',
    'WenQuanYi Micro Hei',
]
# Draw negative tick labels with an ASCII hyphen instead of the Unicode minus
# sign, which many CJK fonts do not provide.
plt.rcParams['axes.unicode_minus'] = False

## 加载数据与训练模型

In [None]:
# Load iris and keep two columns: index 2 (petal length) and index 1
# (sepal width), in that order, so the plots use petal length on the x axis.
iris = load_iris()
y = iris.target
X = iris.data[:, [2, 1]]

# Hold out 30% of the samples for testing; the fixed seed keeps the split
# reproducible between runs.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, random_state=42, test_size=0.3
)

# Gaussian naive Bayes.
clf1 = GaussianNB().fit(X_train, y_train)
y_pred1 = clf1.predict(X_test)
print("用高斯朴素贝叶斯分类模型：", y_pred1)

# Multinomial naive Bayes.
clf2 = MultinomialNB().fit(X_train, y_train)
y_pred2 = clf2.predict(X_test)
print("用多项式朴素贝叶斯分类模型：", y_pred2)

# Bernoulli naive Bayes.
# NOTE(review): BernoulliNB binarizes inputs at a default threshold of 0.0;
# iris measurements are all positive, so presumably every feature becomes 1
# and the model can only fall back on class priors -- confirm this is the
# intended baseline for the comparison.
clf3 = BernoulliNB().fit(X_train, y_train)
y_pred3 = clf3.predict(X_test)
print("用伯努利朴素贝叶斯分类模型：", y_pred3)

## 输出评价指标

In [None]:
def print_metrics(name, y_true, y_pred, zero_division=0):
    """Print the confusion matrix and macro-averaged scores for one model.

    Args:
        name: Label printed as the heading of the report.
        y_true: Ground-truth class labels.
        y_pred: Labels predicted by the model, aligned with ``y_true``.
        zero_division: Value scikit-learn substitutes for a per-class
            precision/recall/F1 whose denominator is zero. The default of 0
            yields the same numbers as scikit-learn's own default but
            without emitting ``UndefinedMetricWarning``.

    Returns:
        None. The report is written to standard output.
    """
    m = metrics.confusion_matrix(y_true, y_pred)
    a = metrics.accuracy_score(y_true, y_pred)
    # A classifier that never predicts some class makes that class's
    # precision 0/0; pass zero_division explicitly so the report is not
    # interleaved with warnings.
    p = metrics.precision_score(
        y_true, y_pred, average='macro', zero_division=zero_division
    )
    r = metrics.recall_score(
        y_true, y_pred, average='macro', zero_division=zero_division
    )
    f1 = metrics.f1_score(
        y_true, y_pred, average='macro', zero_division=zero_division
    )
    print(f"{name}：\n混淆矩阵: {m}\n精度: {a}\n查准率: {p}\n查全率: {r}\nF1: {f1}\n")

# Report the metrics of each fitted model against the held-out labels.
for model_name, predictions in (
    ("高斯朴素贝叶斯", y_pred1),
    ("多项式朴素贝叶斯", y_pred2),
    ("伯努利朴素贝叶斯", y_pred3),
):
    print_metrics(model_name, y_test, predictions)

## 可视化分类边界

In [None]:
def plot_nb_boundary(clf, title):
    """Draw the decision regions of ``clf`` with the test samples overlaid.

    Reads the module-level ``X`` (to size the grid) and ``X_test`` /
    ``y_test`` (for the scatter overlay).

    Args:
        clf: Fitted classifier exposing ``predict`` on two-column input.
        title: Text used as the figure title.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    step = 0.02
    first_col, second_col = X[:, 0], X[:, 1]

    # Extend each axis one unit beyond the observed range so that no sample
    # sits on the border of the plot.
    grid_x = np.arange(first_col.min() - 1, first_col.max() + 1, step)
    grid_y = np.arange(second_col.min() - 1, second_col.max() + 1, step)
    xx, yy = np.meshgrid(grid_x, grid_y)

    # Classify every grid node, then fold the flat predictions back into the
    # grid's 2-D shape for contouring.
    grid_points = np.column_stack((xx.ravel(), yy.ravel()))
    regions = clf.predict(grid_points).reshape(xx.shape)

    plt.figure(figsize=(8, 6))
    plt.contourf(xx, yy, regions, alpha=0.8, cmap=plt.cm.coolwarm)
    plt.scatter(
        X_test[:, 0],
        X_test[:, 1],
        c=y_test,
        edgecolors='k',
        cmap=plt.cm.coolwarm,
    )
    plt.xlabel("Petal length")
    plt.ylabel("Sepal width")
    plt.title(title)
    plt.xlim(xx.min(), xx.max())
    plt.ylim(yy.min(), yy.max())
    plt.show()

# Plot the decision regions of each fitted model in turn.
for fitted_clf, figure_title in (
    (clf1, "高斯朴素贝叶斯分类模型"),
    (clf2, "多项式朴素贝叶斯分类模型"),
    (clf3, "伯努利朴素贝叶斯分类模型"),
):
    plot_nb_boundary(fitted_clf, figure_title)