## 准备数据

In [11]:
import os
import numpy as np
import torch
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
from torch.optim import Adam
from torch.nn import CrossEntropyLoss

## 建立模型

In [23]:
class myModel(nn.Module):
    def __init__(self):
        ####################
        '''声明模型对应的参数'''
        ####################
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(28*28,100),
            nn.ReLU(),
            nn.Linear(100,10),
        )
    def forward(self,x):
        return self.net(x)
        
model = myModel()
model.train()
optimizer = Adam(model.parameters(),lr=5e-5)

## 计算 loss

In [16]:
def compute_loss(logits, labels):
    return CrossEntropyLoss()(logits,labels)

def compute_accuracy(logits, labels):
    predictions = torch.argmax(logits, dim=1)
    return (predictions == labels).to(float).mean()

def train_one_step(model, optimizer, x, y):
    logits = model(x)
    loss = compute_loss(logits, y)
    loss.backward()
    optimizer.step()
    accuracy = compute_accuracy(logits, y)
    # loss and accuracy is scalar tensor
    return loss, accuracy

@torch.no_grad()
def test(model, x, y):
    logits = model(x)
    loss = compute_loss(logits, y)
    accuracy = compute_accuracy(logits, y)
    return loss, accuracy

## 实际训练

In [14]:
def mnist_dataset(train_ratio=0.8):
    mnist = MNIST(root=".",download=True)
    data,targets = mnist.data/255,mnist.targets
    data = data.reshape(data.shape[0],-1)
    data = (data-data.mean(axis=1,keepdims=True))/data.std(axis=1,keepdims=True)
    idx = np.arange(data.shape[0])
    train_idx = np.zeros_like(idx)
    train_idx[np.random.choice(idx,int(train_ratio*data.shape[0]),replace=False)] = 1
    train_x,train_y = data[train_idx.astype(bool)],targets[train_idx.astype(bool)]
    test_x,test_y = data[train_idx.astype(bool)],targets[train_idx.astype(bool)]
    return train_x,train_y,test_x,test_y

In [24]:
train_x,train_y,test_x,test_y= mnist_dataset()
for epoch in range(50):
    loss, accuracy = train_one_step(model, optimizer,train_x,train_y)
    print('epoch', epoch, ': loss', loss.item(), '; accuracy', accuracy.item())
loss, accuracy = test(model, 
                      test_x, 
                      test_y)

print('test loss', loss.item(), '; accuracy', accuracy.item())

epoch 0 : loss 2.3192737102508545 ; accuracy 0.10285416666666666
epoch 1 : loss 2.306709051132202 ; accuracy 0.1091875
epoch 2 : loss 2.2947072982788086 ; accuracy 0.11625
epoch 3 : loss 2.2829203605651855 ; accuracy 0.12464583333333333
epoch 4 : loss 2.271228313446045 ; accuracy 0.13447916666666668
epoch 5 : loss 2.2595834732055664 ; accuracy 0.14595833333333333
epoch 6 : loss 2.247954845428467 ; accuracy 0.15985416666666666
epoch 7 : loss 2.2363288402557373 ; accuracy 0.17522916666666666
epoch 8 : loss 2.224700927734375 ; accuracy 0.19083333333333333
epoch 9 : loss 2.213066816329956 ; accuracy 0.2075
epoch 10 : loss 2.2014267444610596 ; accuracy 0.22489583333333332
epoch 11 : loss 2.1897835731506348 ; accuracy 0.24116666666666667
epoch 12 : loss 2.1781351566314697 ; accuracy 0.25854166666666667
epoch 13 : loss 2.1664791107177734 ; accuracy 0.27589583333333334
epoch 14 : loss 2.154816150665283 ; accuracy 0.2921875
epoch 15 : loss 2.1431570053100586 ; accuracy 0.30954166666666666
epoch