In [1]:
import numpy as np
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split

# Load MNIST ('mnist_784': 784 pixel features per image) from OpenML (cached locally).
# parser='auto' is passed explicitly to silence scikit-learn's FutureWarning about the
# changing default ARFF parser; as_frame=False returns plain NumPy arrays, which is
# what the NumPy-based classifier below works with (boolean row masks, `@`, np.argmax).
mnist = fetch_openml('mnist_784', version=1, cache=True, as_frame=False, parser='auto')
# Targets arrive as label strings; convert them to integers.
X, y = mnist["data"], mnist["target"].astype(int)

# Binarize: 1 where the intensity exceeds 127.5 (the midpoint of a 0-255 range,
# presumably the pixel scale here), else 0.
X = (X > 127.5).astype(np.uint8)

# Hold out 30% of the samples as a test set; the fixed seed makes the split reproducible.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=40)

  warn(


In [2]:
class MultinomialNaiveBayes:
    """Multinomial naive Bayes classifier with additive smoothing.

    Parameters
    ----------
    alpha : float, default=1.0
        Pseudo-count added to every per-class feature count (numerator of the
        smoothed estimate).
    V : float or None, default=None
        Term added to each class's total feature count (denominator of the
        smoothed estimate). ``None`` means "use the number of features of the
        data passed to ``fit``"; the stored hyperparameter stays ``None`` so the
        model can be refit on data of a different width.

    Attributes
    ----------
    classes_ : ndarray of shape (n_classes,)
        Sorted unique labels seen in ``fit``; ``predict`` returns values from it.
    class_log_prior_ : ndarray of shape (n_classes,)
        Log of the empirical class frequencies.
    feature_log_prob_ : ndarray of shape (n_classes, n_features)
        ``log((N_cj + alpha) / (N_c + V))`` where ``N_cj`` is the summed value of
        feature ``j`` over the samples of class ``c`` and ``N_c = sum_j N_cj``.

    Notes
    -----
    The textbook Lidstone estimate uses ``N_c + alpha * V`` as the denominator.
    Here ``V`` is added unscaled so ``alpha`` and ``V`` can be tuned
    independently; consequently each row of ``exp(feature_log_prob_)`` sums to 1
    only when ``V == alpha * n_features``.
    """

    def __init__(self, alpha=1.0, V=None):
        self.class_log_prior_ = None
        self.feature_log_prob_ = None
        self.classes_ = None
        self.alpha = alpha  # numerator smoothing (pseudo-count)
        self.V = V          # denominator smoothing; None -> n_features at fit time

    def fit(self, X, y):
        """Estimate class priors and smoothed per-class feature log-probabilities.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Non-negative feature counts (e.g. binarized pixels); DataFrames are accepted.
        y : array-like of shape (n_samples,)
            Class labels of any sortable type; they need not be 0..k-1.

        Returns
        -------
        self

        Raises
        ------
        ValueError
            If ``X`` is not 2-D, ``y`` does not match ``X`` in length, the data
            are empty, or some class's denominator ``N_c + V`` is not positive
            (which would otherwise yield NaN/inf log-probabilities).
        """
        X = np.asarray(X)
        y = np.asarray(y)
        if X.ndim != 2:
            raise ValueError(f"X must be 2-D, got shape {X.shape}")
        if y.ndim != 1 or y.shape[0] != X.shape[0]:
            raise ValueError(f"y must be 1-D with {X.shape[0]} entries, got shape {y.shape}")
        if X.shape[0] == 0:
            raise ValueError("cannot fit on an empty dataset")

        # Map labels to contiguous indices 0..k-1 so bincount, row indexing and
        # predict all agree even when labels are non-contiguous or non-integer.
        self.classes_, y_idx = np.unique(y, return_inverse=True)
        y_idx = y_idx.ravel()
        n_classes = self.classes_.shape[0]
        self.class_log_prior_ = np.log(np.bincount(y_idx, minlength=n_classes) / y_idx.shape[0])

        # Resolve V locally instead of overwriting the hyperparameter.
        V = X.shape[1] if self.V is None else self.V

        # Summed feature values per class, shape (n_classes, n_features). Summing in
        # float64 avoids unsigned-integer arithmetic when V is negative.
        feature_counts = np.vstack(
            [X[y_idx == k].sum(axis=0, dtype=np.float64) for k in range(n_classes)]
        )
        denominators = feature_counts.sum(axis=1, keepdims=True) + V
        if np.any(denominators <= 0):
            raise ValueError(f"N_c + V must be positive for every class; got V={V}")
        self.feature_log_prob_ = np.log((feature_counts + self.alpha) / denominators)
        return self

    def predict(self, X):
        """Return the most probable class label for each row of ``X``.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features) or (n_features,)
            A single 1-D sample is treated as one row.

        Returns
        -------
        ndarray of shape (n_samples,)
            Labels taken from ``classes_``.

        Raises
        ------
        RuntimeError
            If called before ``fit``.
        """
        if self.feature_log_prob_ is None:
            raise RuntimeError("MultinomialNaiveBayes is not fitted yet; call fit() first")
        X = np.atleast_2d(np.asarray(X))
        # Unnormalized log posterior per class: sum_j x_j * log P(j | c) + log P(c).
        joint_log_likelihood = X @ self.feature_log_prob_.T + self.class_log_prior_
        return self.classes_[np.argmax(joint_log_likelihood, axis=1)]

# Baseline run with a hand-picked smoothing configuration; the grid search
# in the next cell explores other (alpha, V) combinations.
model = MultinomialNaiveBayes(alpha=0.01, V=5)
model.fit(X_train, y_train)

# Score on the held-out split: fraction of test labels predicted correctly.
y_pred = model.predict(X_test)
print('Accuracy : ', (y_pred == y_test).mean())

Accuracy :  0.8288571428571428


In [3]:
# Grid search over the smoothing hyperparameters (alpha, V).
# Configurations are compared on a validation split carved out of the training data,
# NOT on the test set: selecting hyperparameters by test accuracy leaks test information
# into model selection and makes the reported score optimistically biased. The test set
# is used exactly once, after the winning configuration is refit on the full training split.
alpha_values = [0.001, 0.005, 0.01, 0.02, 0.05, 0.1, 0.5, 1, 2, 5, 10, 50, 80, 100, 1000]
# None means "use the number of features" (resolved inside MultinomialNaiveBayes.fit).
V_values = [None, -10, -5, -1, -0.1, 0.1, 0.5, 1, 2, 5, 10, 100, 1000, 5000, 10000]

X_fit, X_val, y_fit, y_val = train_test_split(
    X_train, y_train, test_size=0.2, random_state=40
)

# Start below any achievable accuracy so some configuration is always selected.
best_accuracy = -1.0
best_alpha = None
best_V = None

for alpha in alpha_values:
    for V in V_values:
        model = MultinomialNaiveBayes(alpha, V)
        model.fit(X_fit, y_fit)
        y_pred = model.predict(X_val)
        accuracy = np.mean(y_pred == y_val)
        # Strict '>' keeps the first configuration encountered on ties.
        if accuracy > best_accuracy:
            best_accuracy, best_alpha, best_V = accuracy, alpha, V

# Refit the winner on the full training split and evaluate it once on the test set.
best_model = MultinomialNaiveBayes(best_alpha, best_V)
best_model.fit(X_train, y_train)
test_accuracy = np.mean(best_model.predict(X_test) == y_test)

print(f"Best alpha is : {best_alpha}")
print(f"Best V is : {best_V}")
print(f"Best validation accuracy is : {best_accuracy}")
print(f"Test accuracy is : {test_accuracy}")