<a href="https://colab.research.google.com/github/WalkerSue/colab/blob/torch/pytorch_mnist_test.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
from torch import nn, optim
import torchvision
import torch.nn.functional as F

class LeNet5(nn.Module):
    """LeNet-5 style convolutional network for 28x28 images.

    Layout: conv(5x5, padding 2) -> ReLU -> 2x2 max-pool -> conv(5x5) ->
    ReLU -> 2x2 max-pool -> flatten -> FC 120 -> FC 84 -> FC ``num_classes``.
    ``forward`` returns raw logits (no softmax is applied).

    Args:
        num_classes: number of output logits (default 10, the MNIST digits).
        in_channels: number of channels of the input image (default 1).
    """

    def __init__(self, num_classes=10, in_channels=1):
        super().__init__()
        # First conv: 6 filters of 5x5. padding=2 pads 2 pixels on every side
        # (28 + 2 + 2 = 32), so a 28x28 MNIST image plays the role of the
        # 32x32 input of the original LeNet and the output stays 28x28.
        self.conv1 = nn.Conv2d(in_channels, 6, 5, padding=2)
        # Second conv: 6 -> 16 channels, 5x5, no padding (14x14 -> 10x10).
        self.conv2 = nn.Conv2d(6, 16, 5)
        # Three fully connected layers. 16*5*5 is the flattened feature count
        # after the second pooling when the input is 28x28.
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)

    def forward(self, x):
        """Map a batch (N, in_channels, 28, 28) to logits (N, num_classes)."""
        # conv -> ReLU -> 2x2 max-pool: 28x28 -> 14x14
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        # conv -> ReLU -> 2x2 max-pool: 10x10 -> 5x5
        x = F.max_pool2d(F.relu(self.conv2(x)), (2, 2))
        # Flatten every dimension except the batch one (16*5*5 = 400 features).
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

    def num_flat_features(self, x):
        """Return the number of elements per sample.

        This is the product of all dimensions of ``x`` except the first
        (batch) one. It is kept for callers that relied on it.
        """
        return x.size()[1:].numel()

import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import  torchvision.transforms as transforms

# Hyperparameters
EPOCH = 10
BATCH_SIZE = 64
LR = 0.001

# MNIST train/test splits; ToTensor converts each image to a float tensor.
# download=True fetches the files into ./mnist/ if they are not present.
train_data = datasets.MNIST(root='./mnist/', train=True, transform=transforms.ToTensor(), download=True)
test_data = datasets.MNIST(root='./mnist/', train=False, transform=transforms.ToTensor(), download=True)

# First 2000 raw test images scaled to [0, 1] and their labels as a NumPy array.
# `.data` / `.targets` replace the deprecated `test_data` / `test_labels`
# aliases, which only emit warnings before forwarding to the same tensors.
test_x = test_data.data.float()[:2000] / 255.
test_y = test_data.targets.numpy()[:2000]

import matplotlib.pyplot as plt
%matplotlib inline
# Inspect the training set: raw image/label tensor sizes and the first image.
# `.data` / `.targets` replace the deprecated `train_data` / `train_labels`.
print ("Train Data Size:", train_data.data.size())
print ("Train Label Size:", train_data.targets.size())

plt.imshow(train_data.data[0].numpy(), cmap="gray")
plt.show()

# Split the datasets into mini-batches with DataLoader.
# Only the training set is shuffled: accuracy does not depend on the order of
# the test batches, so evaluation iterates deterministically.
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(dataset=test_data, batch_size=BATCH_SIZE, shuffle=False)

# Use the first CUDA device when available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Create the model and move it to the target device *before* building the
# optimizer, as recommended by the torch.optim documentation.
model = LeNet5().to(device)
# Loss function: CrossEntropyLoss expects raw (unnormalized) logits.
criterion = nn.CrossEntropyLoss()
# Optimizer
optimizer = optim.Adam(model.parameters(), lr=LR)

# Training
# Set training mode explicitly so re-running this cell after model.eval()
# still trains in the right mode.
model.train()
total_step = len(train_loader)
for epoch in range(EPOCH):
    for i, (inputs, labels) in enumerate(train_loader):
        # Move the batch to the model's device (GPU when available).
        inputs, labels = inputs.to(device), labels.to(device)
        # Clear the gradients accumulated by the previous step.
        optimizer.zero_grad()
        # Forward pass and loss.
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        # Backpropagate and update the parameters.
        loss.backward()
        optimizer.step()
        if (i+1) % 100 == 0:
            print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
                   .format(epoch+1, EPOCH, i+1, total_step, loss.item()))
            

# Save a checkpoint dict that holds the model weights under 'state_dict'.
# (Alternatives: torch.save(model, ...) pickles the whole module, or the bare
# state_dict can be saved without the wrapping dict.)
torch.save({'state_dict': model.state_dict()}, "mnist_lenet.pt")

# Reload: instantiate a fresh network, then load the saved weights into it.
# map_location='cpu' lets a checkpoint written on a GPU be loaded on a
# CPU-only host; the fresh model's parameters live on the CPU anyway.
model = LeNet5()
checkpoint = torch.load('mnist_lenet.pt', map_location='cpu')
model.load_state_dict(checkpoint['state_dict'])

# Evaluation
model.to(device)
model.eval()
correct = 0
total = 0

# Inference only: disable autograd so no computation graph is built.
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        out = model(images)
        # The predicted class is the index of the largest logit.
        predicted = out.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

# Report the test accuracy as a percentage.
print ("Test Data Accuracy:", 100*correct/total)

In [None]:
# Evaluation
model.to(device)
model.eval()
correct = 0
total = 0

# Inference only: disable autograd so no computation graph is built.
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        out = model(images)
        # The predicted class is the index of the largest logit.
        predicted = out.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

# Report the test accuracy as a percentage.
print ("Test Data Accuracy:", 100*correct/total)