In [1]:
import numpy as np
import matplotlib.pyplot as plt

# 导入鸢尾花数据集
from sklearn.datasets import load_iris
iris = load_iris()
X = iris.data
y = iris.target

# 归一化数据
X = (X - np.min(X)) / (np.max(X) - np.min(X))

# 添加偏置项
X = np.c_[np.ones(X.shape[0]), X]

# 初始化权重
W = np.random.randn(X.shape[1], 3)

# 定义学习率
lr = 0.01

# 定义损失函数
def cross_entropy(y_pred, y_true):
    return -np.sum(y_true * np.log(y_pred))

# 定义softmax函数
def softmax(x):
    exp = np.exp(x)
    return exp / np.sum(exp, axis=1, keepdims=True)

# 训练模型
for i in range(1000):
    # 前向传播
    y_pred = softmax(np.dot(X, W))

    # 计算损失
    loss = cross_entropy(y_pred, y)

    # 反向传播
    dW = np.dot(X.T, y_pred - y)

    # 更新权重
    W -= lr * dW

    # 每100次迭代打印一次损失
    if i % 100 == 0:
        print(f"Iteration {i}: Loss {loss}")

# 预测
y_pred = np.argmax(softmax(np.dot(X, W)), axis=1)

# 计算准确率
accuracy = np.mean(y_pred == y)
print(f"Accuracy: {accuracy}")

# 绘制决策边界
x_min, x_max = X[:, 1].min() - 1, X[:, 1].max() + 1
y_min, y_max = X[:, 2].min() - 1, X[:, 2].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.1),
                     np.arange(y_min, y_max, 0.1))
Z = np.argmax(softmax(np.dot(np.c_[np.ones(xx.ravel().shape), xx.ravel(), yy.ravel()], W)), axis=1)
Z = Z.reshape(xx.shape)
plt.contourf(xx, yy, Z, cmap=plt.cm.Spectral)
plt.scatter(X[:, 1], X[:, 2], c=y, cmap=plt.cm.Spectral)
plt.show()


# 评估模型
y_pred = model.predict(X_test)
accuracy = np.mean(y_pred == y_test)
print(f'Accuracy: {accuracy}')

# 绘制决策边界
h = 0.02  # 步长
x_min, x_max = X_test[:, 0].min() - 1, X_test[:, 0].max() + 1
y_min, y_max = X_test[:, 1].min() - 1, X_test[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))
Z = model.predict(np.c_[xx.ravel(), yy.ravel()])
Z = Z.reshape(xx.shape)
plt.contourf(xx, yy, Z, cmap=plt.cm.Spectral)
plt.scatter(X_test[:, 0], X_test[:, 1], c=y_test, cmap=plt.cm.Spectral)
plt.xlabel('Sepal length')
plt.ylabel('Sepal width')
plt.title('Softmax Regression Decision Boundary')
plt.show()

ValueError: operands could not be broadcast together with shapes (150,) (150,3) 