# 使用 `NN` 构建模型

In [19]:
from typing import override

from pathlib import Path
from pickle import load, dump
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.optim.sgd import SGD

## 上一小节实现的数据预处理

In [20]:
class MnistDataset(Dataset):
    @override
    def __init__(self, x:torch.Tensor, y:torch.Tensor) -> None:
        super().__init__()
        if len(x) != len(y):
            raise IndexError('len(x) != len(y)')
        self.x = x
        self.y = y
        return
    
    @override
    def __getitem__(self, index) -> tuple[torch.Tensor, torch.Tensor]:
        return self.x[index], self.y[index]
    
    @override
    def __len__(self) -> int:
        return len(self.x)
    
    
def readDataset(name: str, dataDir: str | Path = './mnistData') -> MnistDataset:
    """Parse one MNIST split stored as a pair of IDX files.

    Reads ``{name}-images.idx3-ubyte`` and ``{name}-labels.idx1-ubyte`` from
    ``dataDir``.

    Args:
        name: File-name prefix of the split (this file uses ``"train"`` and
            ``"t10k"``).
        dataDir: Directory that contains the two IDX files. Defaults to the
            previously hard-coded ``./mnistData``.

    Returns:
        A ``MnistDataset`` whose samples are a ``float32`` tensor of shape
        ``(count, rows * cols)`` with pixel values divided by 255, and whose
        targets are a ``float32`` one-hot tensor of shape ``(count, 10)``.

    Raises:
        ValueError: If either file starts with an unexpected magic number, or
            the image file holds fewer pixels than its header announces.
        IndexError: If the image and label files contain a different number
            of entries (raised by ``MnistDataset``).
    """
    root = Path(dataDir)
    fullImgPath = root / f'{name}-images.idx3-ubyte'
    fullLabelPath = root / f'{name}-labels.idx1-ubyte'

    with open(fullImgPath, 'rb') as f:
        # Header: four 32-bit integers. 'big' is spelled out; it is also the
        # default byte order of int.from_bytes, so behavior is unchanged.
        magic = int.from_bytes(f.read(4), 'big')
        if magic != 2051:
            raise ValueError(f'head magic number not equal to 2051: gotten={magic}')

        dataNumber = int.from_bytes(f.read(4), 'big')
        dataRow = int.from_bytes(f.read(4), 'big')
        dataCol = int.from_bytes(f.read(4), 'big')

        # The remaining bytes are the raw pixels, one uint8 each.
        img = np.fromfile(
            file=f,
            dtype=np.uint8,
            count=dataNumber * dataRow * dataCol,
            offset=0,
        )
    # One flattened row per image; reshape() replaces in-place assignment to
    # ``.shape`` and raises ValueError if the file was shorter than announced.
    img = img.reshape(dataNumber, dataRow * dataCol)
    imgTensor = torch.from_numpy(img).type(torch.float32) / 255

    with open(fullLabelPath, 'rb') as f:
        magic = int.from_bytes(f.read(4), 'big')
        if magic != 2049:
            raise ValueError(f'head magic number not equal to 2049: gotten={magic}')

        dataNumber = int.from_bytes(f.read(4), 'big')

        label = np.fromfile(
            file=f,
            dtype=np.uint8,
            count=dataNumber,
            offset=0,
        )

    # Vectorized one-hot encoding: a single indexed assignment sets
    # [row, label[row]] = 1 for every row, instead of a Python loop per sample.
    labelIndex = torch.from_numpy(label.astype(np.int64))
    labelOnehotTensor = torch.zeros((len(label), 10), dtype=torch.float32)
    labelOnehotTensor[torch.arange(len(label)), labelIndex] = 1

    return MnistDataset(imgTensor, labelOnehotTensor)

## 开始构建

### 网络构建

构建一个包含3个隐藏层的网络，每一个隐藏层都使用`ReLU`作为激活函数

输出没有经过 softmax 归一化，因为后面使用的损失函数 `CrossEntropyLoss` 直接接收未归一化的输出（logits），其内部已经包含 softmax 运算

In [21]:

# Fully connected classifier: a flattened 28x28 input, three 16-unit hidden
# layers with ReLU, and a final linear layer producing 10 raw scores (no
# softmax inside the model). Layers are created in the same order as before,
# so the auto-generated state_dict keys ("0.weight", "2.weight", ...) are
# unchanged.
model = nn.Sequential(
    nn.Linear(in_features=28 * 28, out_features=16, bias=True, dtype=torch.float32),
    nn.ReLU(),
    nn.Linear(in_features=16, out_features=16, bias=True, dtype=torch.float32),
    nn.ReLU(),
    nn.Linear(in_features=16, out_features=16, bias=True, dtype=torch.float32),
    nn.ReLU(),
    nn.Linear(in_features=16, out_features=10, bias=True, dtype=torch.float32),
)

def predict(input: torch.Tensor) -> torch.Tensor:
    """Run the module-level ``model`` on ``input`` and turn its raw scores
    into probabilities with a softmax over dimension 1.

    Args:
        input: Batch passed straight to ``model``
            (presumably shaped ``(batch, 28 * 28)`` - confirm with callers).

    Returns:
        The softmax of the model output taken along dimension 1.
    """
    logits = model(input)
    return logits.softmax(dim=1)

# Cross-entropy is computed on the raw model scores, which is why the model
# itself ends without a softmax layer.
criterion = nn.CrossEntropyLoss()
# Plain SGD over the parameters of `model` as it exists at this point.
optimisor = SGD(
    params=model.parameters(),
    lr=1e-3,
    weight_decay=1e-4,
)

### 加载训练集

加载一次训练集之后，打包成`pickle`，下一次使用就不需要重新解析

加快加载速度

In [22]:

trainSetPath = Path("./train.pickle")

# Parse the raw IDX files only when no cached pickle exists yet; afterwards
# the parsed dataset is read back from the cache.
# NOTE: unpickling runs arbitrary code, so only load a cache file that this
# notebook wrote itself.
if not trainSetPath.exists():
    trainSet = readDataset("train")
    with trainSetPath.open("wb") as cacheFile:
        dump(trainSet, cacheFile)
else:
    with trainSetPath.open("rb") as cacheFile:
        trainSet = load(cacheFile)

# Mini-batches of 20 samples, reshuffled at the start of every pass.
trainDataLoader = DataLoader(trainSet, batch_size=20, shuffle=True)


### 加载已经训练的模型

每次训练完成后保存

如果删除文件则重新训练

In [23]:

modelFile = Path("./model.pickle")

if modelFile.exists():
    # NOTE: unpickling runs arbitrary code - only load a model file that was
    # produced by this notebook.
    with open(modelFile, "rb") as f:
        loadedModel = load(f)
    # Copy the saved weights into the existing `model` instead of rebinding
    # the name. `optimisor` was built from the parameters of the existing
    # `model`; replacing the object would leave the optimizer updating the
    # discarded network, so further training would never change the loaded
    # one. This requires the saved model to have the same architecture as
    # `model`; load_state_dict raises otherwise.
    model.load_state_dict(loadedModel.state_dict())
    print(f"已加载已有模型：{modelFile}")


已加载已有模型：model.pickle


### 开始训练

可以选择将新的模型覆盖到旧的模型上，也可以不保存

In [24]:

# Number of passes over the training set; 0 skips training entirely.
epochs = 0
for epoch in range(epochs):
    # Steps are counted from 1 so that every 100th batch triggers a report.
    for step, (inputs, targets) in enumerate(trainDataLoader, start=1):
        optimisor.zero_grad()
        output = model(inputs)
        loss: torch.Tensor = criterion(output, targets)
        loss.backward()
        optimisor.step()
        if step % 100 == 0:
            print(f"loss:{loss.item()}")
    # After each epoch, shrink the learning rate of every group to 30%.
    for group in optimisor.param_groups:
        group["lr"] *= 0.3

print("训练完成")

# Optional: uncomment to overwrite the saved model with the current one.
# with open(modelFile, "wb") as f:
#     dump(model, f)

训练完成


### 测试模型的准确度

In [26]:
# Evaluate the model on the "t10k" split and print its accuracy in percent.
testSet = readDataset("t10k")
testDataLoader = DataLoader(testSet, batch_size=100)

# NOTE(review): the name looks like a garbled `correctCount`; it is kept
# unchanged in case other cells reference it.
correctorchount = 0
# Gradients are not needed for evaluation; no_grad avoids building autograd
# graphs for every batch.
with torch.no_grad():
    for data in testDataLoader:
        output: torch.Tensor = model(data[0])
        # A prediction is correct when the highest-scoring class matches the
        # position of the 1 in the one-hot target.
        correction: torch.Tensor = output.argmax(1) == data[1].argmax(1)
        correctorchount += int(correction.sum(dtype=torch.int32))

# Use the real sample count instead of assuming exactly 10000 test samples.
# Multiplying before dividing gives the same value as the former
# `correct / 100` when the set has 10000 samples.
sampleCount = len(testSet)
accuracyPercent = correctorchount * 100 / sampleCount if sampleCount else 0.0
print(f"正确率{accuracyPercent}%")


正确率87.17%
