In [13]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import struct

In [14]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, 3, padding=(1,1))
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(16, 32, 3, padding=(1,1))
        self.pool2 = nn.MaxPool2d(2, 2)
        self.conv3 = nn.Conv2d(32, 64, 3, padding=(1,1))
        self.globalMaxPool = nn.MaxPool2d(7, 1)
        self.fc = nn.Linear(64, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))        #卷积
        x = self.pool1(x)                #池化
        x = F.relu(self.conv2(x))        #卷积
        x = self.pool2(x)                #池化
        x = F.relu(self.conv3(x))        #卷积
        x = self.globalMaxPool(x)        #池化
        x = x.view(-1, 64*1*1)            #将输出reshape为1维
        x = self.fc(x)                    #全连接
        return x

In [20]:
train_images_idx3_ubyte_file = 'train-images.idx3-ubyte' # 训练图片数据
train_labels_idx1_ubyte_file = 'train-labels.idx1-ubyte' # 训练标签数据
test_images_idx3_ubyte_file = 't10k-images.idx3-ubyte'   # 测试图片数据
test_labels_idx1_ubyte_file = 't10k-labels.idx1-ubyte'   # 测试标签数据

def decode_idx3_ubyte(idx3_ubyte_file):
    bin_data = open(idx3_ubyte_file, 'rb').read()
    offset = 0
    fmt_header = '>iiii'
    magic_number, num_images, num_rows, num_cols = struct.unpack_from(fmt_header, bin_data, offset)
    image_size = num_rows * num_cols
    offset += struct.calcsize(fmt_header)
    fmt_image = '>' + str(image_size) + 'B'
    images = np.empty((num_images, num_rows, num_cols))
    for i in range(num_images):
        images[i] = np.array(struct.unpack_from(fmt_image, bin_data, offset)).reshape((num_rows, num_cols))
        offset += struct.calcsize(fmt_image)
    return images

def decode_idx1_ubyte(idx1_ubyte_file):
    bin_data = open(idx1_ubyte_file, 'rb').read()
    offset = 0
    fmt_header = '>ii'
    magic_number, num_images = struct.unpack_from(fmt_header, bin_data, offset)
    offset += struct.calcsize(fmt_header)
    fmt_image = '>B'
    labels = np.empty(num_images)
    for i in range(num_images):
        labels[i] = struct.unpack_from(fmt_image, bin_data, offset)[0]
        offset += struct.calcsize(fmt_image)
    return labels

def load_train_images(idx_ubyte_file=train_images_idx3_ubyte_file):
    return decode_idx3_ubyte(idx_ubyte_file)
def load_train_labels(idx_ubyte_file=train_labels_idx1_ubyte_file):
    return decode_idx1_ubyte(idx_ubyte_file)
def load_test_images(idx_ubyte_file=test_images_idx3_ubyte_file):
    return decode_idx3_ubyte(idx_ubyte_file)
def load_test_labels(idx_ubyte_file=test_labels_idx1_ubyte_file):
    return decode_idx1_ubyte(idx_ubyte_file)

train_images = load_train_images()
train_labels = load_train_labels()
test_images = load_test_images()
test_labels = load_test_labels()

train_images = np.expand_dims(train_images, axis=1) / 255.0    #数据预处理
train_labels = train_labels.astype('int')
test_images = np.expand_dims(test_images, axis=1) / 255.0
test_labels = test_labels.astype('int')

print(test_images.size)

7840000


In [16]:
net = Net()        #创建Net()对象
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")    #设置计算设备，cpu或gpu
net.to(device)
criterion = nn.CrossEntropyLoss()    #损失函数
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)    #优化器

In [17]:
epoch_num = 10
batch_size = 100
for epoch in range(epoch_num):  # loop over the dataset multiple times
    running_loss = 0.0
    for i in range(int(60000/batch_size)):
        start_index = i*batch_size
        inputs = torch.from_numpy(train_images[start_index:start_index+batch_size])
        labels = torch.from_numpy(train_labels[start_index:start_index+batch_size])

        inputs, labels = inputs.to(device), labels.to(device)
        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(inputs.float())

        labels = labels.to(torch.int64)

        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if i % 50 == 0:    # print every 2000 mini-batches
            print('[%d, %5d] loss: %.3f' %
                (epoch + 1, i + 1, running_loss / 2000))
            running_loss = 0.0

[1,     1] loss: 0.001
[1,    51] loss: 0.058
[1,   101] loss: 0.058
[1,   151] loss: 0.057
[1,   201] loss: 0.057
[1,   251] loss: 0.057
[1,   301] loss: 0.057
[1,   351] loss: 0.057
[1,   401] loss: 0.057
[1,   451] loss: 0.057
[1,   501] loss: 0.057
[1,   551] loss: 0.057
[2,     1] loss: 0.001
[2,    51] loss: 0.056
[2,   101] loss: 0.056
[2,   151] loss: 0.056
[2,   201] loss: 0.055
[2,   251] loss: 0.055
[2,   301] loss: 0.054
[2,   351] loss: 0.053
[2,   401] loss: 0.051
[2,   451] loss: 0.048
[2,   501] loss: 0.045
[2,   551] loss: 0.040
[3,     1] loss: 0.001
[3,    51] loss: 0.027
[3,   101] loss: 0.021
[3,   151] loss: 0.018
[3,   201] loss: 0.014
[3,   251] loss: 0.012
[3,   301] loss: 0.011
[3,   351] loss: 0.010
[3,   401] loss: 0.009
[3,   451] loss: 0.009
[3,   501] loss: 0.008
[3,   551] loss: 0.008
[4,     1] loss: 0.000
[4,    51] loss: 0.007
[4,   101] loss: 0.007
[4,   151] loss: 0.008
[4,   201] loss: 0.006
[4,   251] loss: 0.006
[4,   301] loss: 0.006
[4,   351] 

In [29]:
test_num = 10000
correct = 0
for i in range(int(test_num/batch_size)):
    start_index = i*batch_size
    test_inputs = torch.from_numpy(test_images[start_index:start_index+batch_size]).float()
    test_inputs = test_inputs.to(device)
    outputs = net(test_inputs)
    _, predicted = torch.max(outputs.data, 1)
    for j in range(batch_size):
        if(predicted[j] == test_labels[start_index+j]):
            correct += 1
print(f"Model accuracy on test set: {correct/test_num * 100:.2f}%")

save_path = './net.pth'
torch.save(net.state_dict(), save_path)

Model accuracy on test set: 93.35%


In [27]:
import torch.onnx

# 设已训练好的网络模型为 net
torch_model_path = 'net.pth'
onnx_model_path = "mnist_model.onnx"

# torch 模型参数保存再加载到新的 model 中，避免报错
torch.save(net.state_dict(), torch_model_path)
model = Net()
model.load_state_dict(torch.load("net.pth"))
model.eval()

# 使用TorchScript来跟踪并导出模型
input_tensor = torch.randn(1, 1, 28, 28)
traced_model = torch.jit.trace(model, input_tensor)
traced_model.save("mnist_model.pt")

# 创建一个样本输入，大小需要与模型的输入层相匹配
sample_input = torch.randn(1, 1, 28, 28)

# 导出 onnx 模型
torch.onnx.export(model, sample_input, onnx_model_path)
