In [None]:
import torch
import torchvision.datasets as dsets#导入数据集模块，方便获取常用数据集（如 MNIST）
import torchvision.transforms as transforms#导入图像预处理模块，用于对图像数据进行转换（如转为张量）

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Fix the random seeds (used for parameter initialisation and batch shuffling)
# so that repeated runs start from the same state.
torch.manual_seed(1)
if device != 'cpu':
    torch.cuda.manual_seed_all(1)

In [None]:
# Hyperparameters for the mini-batch gradient descent loop.
learning_rate = 0.5  # step size applied to every parameter update
batch_size = 10      # number of images per mini-batch

In [None]:
# Build the MNIST training split first, then the test split, downloading the
# files into MNIST_data/ when they are not already present. ToTensor() turns
# each PIL image into a float tensor scaled to the [0, 1] range.
mnist_train, mnist_test = (
    dsets.MNIST(root='MNIST_data/',
                train=is_train,
                transform=transforms.ToTensor(),
                download=True)
    for is_train in (True, False)
)

In [None]:
# Mini-batch iterator over the training split: `batch_size` images per batch,
# reshuffled every epoch, with a trailing incomplete batch discarded.
data_loader = torch.utils.data.DataLoader(
    mnist_train,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
)

In [None]:
# Parameters of the 784-30-10 network: w1/b1 feed the hidden layer, w2/b2 the
# output layer. Backpropagation is written out by hand in this notebook, so
# these are plain tensors allocated directly on `device` rather than
# nn.Parameter objects: wrapping them in Parameter would make autograd record
# every manual update, and `Parameter(...).to('cuda')` returns an ordinary
# non-leaf tensor anyway. torch.empty leaves the contents undefined; they are
# filled in by the normal initialisation that runs next.
w1 = torch.empty(784, 30, device=device)
b1 = torch.empty(30, device=device)
w2 = torch.empty(30, 10, device=device)
b2 = torch.empty(10, device=device)

In [None]:
# Initialise every parameter in place from a standard normal distribution
# (mean 0, standard deviation 1). The call order is kept as weights first, then
# biases, so the values drawn from the seeded generator stay the same.
torch.nn.init.normal_(w1, mean=0.0, std=1.0)
torch.nn.init.normal_(w2, mean=0.0, std=1.0)
torch.nn.init.normal_(b1, mean=0.0, std=1.0)
torch.nn.init.normal_(b2, mean=0.0, std=1.0)

In [None]:
def sigmoid(x):
    """Apply the logistic function 1 / (1 + exp(-x)) element-wise.

    Args:
        x: input tensor.

    Returns:
        A tensor with the same shape as ``x``.
    """
    denominator = torch.exp(-x).add(1.0)
    return denominator.reciprocal()

In [None]:
def sigmoid_prime(x):
    """Derivative of the sigmoid function, evaluated element-wise at ``x``.

    Uses the identity sigmoid'(x) = sigmoid(x) * (1 - sigmoid(x)); this is the
    factor the backward pass multiplies in at each sigmoid layer.

    Args:
        x: tensor of pre-activation values.

    Returns:
        A tensor with the same shape as ``x``.
    """
    # Evaluate the sigmoid once and reuse it instead of computing it twice.
    s = sigmoid(x)
    return s * (1 - s)

In [None]:
# Evaluation set: the first 1000 test images, flattened to 784-vectors.
# mnist_test.data is the raw uint8 pixel tensor (0-255) and does not pass
# through the ToTensor() transform, so it is rescaled to [0, 1] here to match
# the inputs the network sees during training.
X_test = mnist_test.data[:1000].view(-1, 28 * 28).float().div(255).to(device)
Y_test = mnist_test.targets[:1000].to(device)

i = 0  # number of mini-batch updates performed so far

# Every gradient below is derived by hand, so autograd is switched off. If the
# parameters carry requires_grad=True, leaving it on would make each update
# `w = w - lr * dw` extend one ever-growing graph and leak memory.
with torch.no_grad():
    while i < 10000:  # start another pass over the data until 10000 updates are done
        for X, Y in data_loader:
            i += 1

            # Forward pass
            X = X.view(-1, 28 * 28).to(device)  # flatten images to (batch, 784)
            # One-hot encode the labels; the row count comes from the batch itself.
            Y = torch.zeros(Y.size(0), 10).scatter_(1, Y.unsqueeze(1), 1).to(device)
            l1 = torch.add(torch.matmul(X, w1), b1)
            a1 = sigmoid(l1)
            l2 = torch.add(torch.matmul(a1, w2), b2)
            a2 = sigmoid(l2)
            y_pred = a2  # the output activation is the prediction

            diff = y_pred - Y

            # Backward pass (chain rule) for a squared-error loss
            d_l2 = diff * sigmoid_prime(l2)
            d_b2 = d_l2
            d_w2 = torch.matmul(torch.transpose(a1, 0, 1), d_l2)

            d_a1 = torch.matmul(d_l2, torch.transpose(w2, 0, 1))
            d_l1 = d_a1 * sigmoid_prime(l1)
            d_b1 = d_l1
            d_w1 = torch.matmul(torch.transpose(X, 0, 1), d_l1)

            # Gradient-descent update.
            # NOTE(review): the weight gradients are summed over the batch (via
            # matmul) while the bias gradients are averaged, so the weights take
            # an effectively batch_size-times larger step. Kept as in the
            # original; confirm whether this is intended.
            w1 = w1 - learning_rate * d_w1
            b1 = b1 - learning_rate * torch.mean(d_b1, 0)
            w2 = w2 - learning_rate * d_w2
            b2 = b2 - learning_rate * torch.mean(d_b2, 0)

            if i % 1000 == 0:  # report test accuracy every 1000 updates
                l1_test = torch.add(torch.matmul(X_test, w1), b1)
                a1_test = sigmoid(l1_test)
                l2_test = torch.add(torch.matmul(a1_test, w2), b2)
                y_predict_test = sigmoid(l2_test)
                num_correct = (Y_test == torch.argmax(y_predict_test, 1)).sum()
                print("Test Accuracy: {:.2f}%".format(num_correct.item() / Y_test.size(0) * 100))
            if i == 10000:
                break
            