In [1]:
import os
import shutil
import subprocess

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.ticker import MultipleLocator, FuncFormatter
from scipy.special import expit
from torchvision.datasets import MNIST



In [2]:
class MLP:
    """Three-layer perceptron (input -> hidden -> output) with sigmoid units.

    The network is trained one sample at a time by gradient descent on the
    difference between the target vector and the output activations.
    Samples and targets are passed as column vectors.
    """

    def __init__(self, num_input_node, num_hidden_node, num_output_node, learning_rate):
        """Initialize the multi-layer perceptron.

        Args:
            num_input_node: number of input nodes
            num_hidden_node: number of hidden nodes
            num_output_node: number of output nodes
            learning_rate: factor applied to every weight update in `train`
        """
        self.num_input_node = num_input_node
        self.num_hidden_node = num_hidden_node
        self.num_output_node = num_output_node

        # Initialize the weights from zero-mean normal distributions.
        # wih maps input -> hidden, who maps hidden -> output.
        # NOTE(review): the standard deviation is derived from the size of the
        # layer being fed (hidden / output), not from the fan-in; scaling by
        # fan-in is the more common choice -- confirm this is intended.
        self.wih = np.random.normal(
            0.0,
            pow(self.num_hidden_node, -0.5),
            (self.num_hidden_node, self.num_input_node),
        )
        self.who = np.random.normal(
            0,
            pow(self.num_output_node, -0.5),
            (self.num_output_node, self.num_hidden_node),
        )

        self.lr = learning_rate
        # Logistic sigmoid, applied element-wise.
        self.activation_function = expit

    def _forward(self, input):
        """Run one forward pass and return (hidden activations, output activations)."""
        out_hidden = self.activation_function(np.dot(self.wih, input))
        out_output = self.activation_function(np.dot(self.who, out_hidden))
        return out_hidden, out_output

    def predict(self, input):
        """Predict the class of a single sample.

        Args:
            input: column vector of shape (num_input_node, 1)

        Returns:
            Index of the output node with the largest activation.
        """
        _, out_output = self._forward(input)
        return np.argmax(out_output)

    def train(self, input, label, iter_nums=100):
        """Apply `iter_nums` consecutive gradient-descent updates on one sample.

        Args:
            input: column vector of shape (num_input_node, 1)
            label: target column vector of shape (num_output_node, 1)
            iter_nums: how many update steps to run on this same sample
        """
        for _ in range(iter_nums):
            # step 1: forward pass
            out_hidden, out_output = self._forward(input)

            # step 2: compute the errors; the hidden error is back-propagated
            # through `who` as it was *before* this step's update.
            error_output = label - out_output
            error_hidden = np.dot(self.who.T, error_output)

            # step 3: update the weights; o * (1 - o) is the sigmoid derivative
            # expressed through the activation itself.
            self.who += self.lr * np.dot(
                error_output * out_output * (1.0 - out_output), np.transpose(out_hidden)
            )
            self.wih += self.lr * np.dot(
                error_hidden * out_hidden * (1.0 - out_hidden), np.transpose(input)
            )

    def test(self, test_list):
        """Measure classification accuracy on a collection of samples.

        Args:
            test_list: sized collection of items where item[0] is an image
                (anything `np.array` accepts, holding num_input_node values)
                and item[1] is the expected class index. Pixel values are
                divided by 255 before being fed to the network.

        Returns:
            Fraction of correctly classified samples, or 0.0 when `test_list`
            is empty.
        """
        total = len(test_list)
        if total == 0:
            # Avoid ZeroDivisionError; nothing to evaluate.
            return 0.0

        correct_cnt = 0
        for item in test_list:
            pixels = (np.array(item[0]) / 255.0).reshape(self.num_input_node, 1)
            if self.predict(pixels) == item[1]:
                correct_cnt += 1
        return correct_cnt / total
            
            
        

In [3]:
# Fetch both MNIST splits (downloaded into ./data when not already present),
# training split first, then materialize each one as a list so the samples
# can be iterated over repeatedly by the cells below.
train_data, test_data = (
    MNIST(root="data", train=is_train, download=True) for is_train in (True, False)
)
train_list, test_list = list(train_data), list(test_data)

In [4]:
def train_mlp(lr, num_hidden_node=99, eval_interval=100):
    """Train a fresh MLP on `train_list` and plot its accuracy over time.

    Every sample of the module-level `train_list` is fed once to `MLP.train`.
    Every `eval_interval` samples the model is scored on the full
    `train_list` and `test_list`; the scores are printed and finally plotted.
    The plot is saved as SVG under ./images and, when Inkscape is available
    on PATH, converted to EMF as well.

    Args:
        lr: learning rate handed to the MLP
        num_hidden_node: number of hidden nodes of the MLP
        eval_interval: evaluate after every this many training samples
    """
    num_input_node = 28 * 28
    num_output_node = 10

    mlp_model = MLP(num_input_node, num_hidden_node, num_output_node, lr)

    train_accuracy_list = []
    test_accuracy_list = []
    iteration_list = []

    for i, item in enumerate(train_list):
        # Remember to normalize the input: divide by 255 (assumes 8-bit pixel
        # values) and flatten the image into a column vector.
        pixels = (np.array(item[0]) / 255.0).reshape(num_input_node, 1)
        # One-hot target column for the sample's class index.
        label = np.zeros((num_output_node, 1))
        label[item[1]] = 1
        mlp_model.train(pixels, label)
        if i % eval_interval == 0:  # evaluate once every `eval_interval` samples
            train_accuracy = mlp_model.test(train_list)  # accuracy on the training set
            test_accuracy = mlp_model.test(test_list)  # accuracy on the test set
            train_accuracy_list.append(train_accuracy)
            test_accuracy_list.append(test_accuracy)
            iteration_list.append(i)  # record the current sample index
            print(f"=== lr: {lr} test {i} ===")
            print(f"Train Accuracy: {train_accuracy}, Test Accuracy: {test_accuracy}")

    # Visualize how the accuracy evolves.
    fig, ax = plt.subplots(figsize=(20, 12))
    ax.plot(iteration_list, train_accuracy_list, label="Training Accuracy")
    ax.plot(iteration_list, test_accuracy_list, label="Testing Accuracy")
    ax.set_xlabel("Iterations")
    ax.set_ylabel("Accuracy")
    ax.set_title(f"Training and Testing Accuracy (Learning Rate: {lr})")
    ax.legend()

    svg_path = f"./images/mnist_train_{lr}_plot.svg"
    emf_path = f"./images/mnist_train_{lr}_plot.emf"
    # savefig does not create missing directories, so make sure it exists.
    os.makedirs(os.path.dirname(svg_path), exist_ok=True)
    fig.savefig(svg_path)

    # Convert the SVG to EMF with Inkscape. The command is passed as an
    # argument list (no shell), and a missing Inkscape stays non-fatal.
    inkscape = shutil.which("inkscape")
    if inkscape is None:
        print("inkscape not found on PATH; skipping EMF export")
    else:
        subprocess.run([inkscape, f"--export-filename={emf_path}", svg_path], check=False)

    # Show the chart, then release the figure.
    plt.show()
    plt.close(fig)

In [5]:
# Learning-rate comparison: each call builds and trains a new model, prints
# its periodic train/test accuracy, and produces one accuracy plot.
train_mlp(0.001)
train_mlp(0.01)
train_mlp(0.1)


=== lr: 0.001 test 0 ===
Train Accuracy: 0.07926666666666667, Test Accuracy: 0.078
=== lr: 0.001 test 100 ===
Train Accuracy: 0.59335, Test Accuracy: 0.5867
=== lr: 0.001 test 200 ===
Train Accuracy: 0.6940166666666666, Test Accuracy: 0.6834
=== lr: 0.001 test 300 ===
Train Accuracy: 0.6845666666666667, Test Accuracy: 0.6851
=== lr: 0.001 test 400 ===
Train Accuracy: 0.73495, Test Accuracy: 0.737
=== lr: 0.001 test 500 ===
Train Accuracy: 0.7377166666666667, Test Accuracy: 0.7362
=== lr: 0.001 test 600 ===
Train Accuracy: 0.7963166666666667, Test Accuracy: 0.7979
=== lr: 0.001 test 700 ===
Train Accuracy: 0.73555, Test Accuracy: 0.727
=== lr: 0.001 test 800 ===
Train Accuracy: 0.8069333333333333, Test Accuracy: 0.8069
=== lr: 0.001 test 900 ===
Train Accuracy: 0.8182666666666667, Test Accuracy: 0.8179
=== lr: 0.001 test 1000 ===
Train Accuracy: 0.7699166666666667, Test Accuracy: 0.7721
=== lr: 0.001 test 1100 ===
Train Accuracy: 0.82145, Test Accuracy: 0.8252
=== lr: 0.001 test 1200 ==