In [49]:
import cv2
import torch
import gzip
import pickle
import torch.nn as nn
import numpy as np
from torch.utils.data import DataLoader
from torch.utils.data import TensorDataset
# from torch.utils.tensorboard import SummaryWriter

In [50]:
## 读取数据集
PATH = "./data/mnist/mnist.pkl.gz"
((x_train, y_train), (x_valid, y_valid), temp) = pickle.load(gzip.open(PATH, "rb"), encoding="latin-1")

In [51]:
## 数据格式转换
x_train, y_train, x_valid, y_valid = map(
    torch.tensor, (x_train, y_train, x_valid, y_valid)
)
y_train

tensor([5, 0, 4,  ..., 8, 4, 8])

In [52]:
## batch
batch_size = 64
train_data = TensorDataset(x_train, y_train)
train_loader = DataLoader(train_data, batch_size, shuffle=True)
test_data = TensorDataset(x_valid, y_valid)
test_loader = DataLoader(test_data, batch_size, shuffle=True)

In [53]:
# ## 借助tensorboard观察
# writer = SummaryWriter("logs")
# for epoch in range(64):
#     steps = 0
#     for data in train_loader:
#         imgs, targets = data
#         imgs.resize(28,28)
#         print(imgs.shape)
#         writer.add_images("Epoch: {}".format(epoch), imgs, steps)
#         steps += 1
# writer.close()

In [60]:
## 网络定义
class CNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.hidden1 = nn.Linear(784, 128)
        self.hidden2 = nn.Linear(128, 256)
        self.out = nn.Linear(256, 10)

    def forward(self, x):
        x = self.hidden1(x)
        x = nn.ReLU()
        x = self.hidden2(x)
        x = nn.ReLU()
        x = self.out(x)
        return x

In [55]:
## 参数设置
# 模型例化
model = CNN()
# 训练次数
epochs = 50
# 学习率
lr = 0.01
# 优化器
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
# 损失函数
loss = nn.CrossEntropyLoss()

In [56]:
## 开始训练
for epoch in range(epochs):
    losses = []
    val_loss = []
    model.train()
    for i, data in enumerate(train_loader):
        input_img, targets = data
        # 前向传播
        out = model(input_img.to(torch.float32))
        # 计算损失
        losses = loss(out, targets)
        # 反相传播
        losses.backward()
        # 更新参数
        optimizer.step()
        # 梯度清零
        optimizer.zero_grad()
    model.eval()
    for i, val_data in enumerate(test_loader):
        val_img, targets = val_data
        val_out = model(val_img.to(torch.float32))
        val_loss = loss(val_out, targets)
    print("当前step : {}/{} ,训练集损失 : {:.5f} , 验证集损失 : {:.5f}".format(epoch+1, epochs, losses, val_loss))

当前step : 0/50 ,训练集损失 : 0.30031 , 验证集损失 : 0.05481
当前step : 1/50 ,训练集损失 : 0.16279 , 验证集损失 : 0.00040
当前step : 2/50 ,训练集损失 : 0.00132 , 验证集损失 : 0.00316
当前step : 3/50 ,训练集损失 : 0.23709 , 验证集损失 : 0.01035
当前step : 4/50 ,训练集损失 : 0.05438 , 验证集损失 : 0.16858
当前step : 5/50 ,训练集损失 : 0.04857 , 验证集损失 : 0.00525
当前step : 6/50 ,训练集损失 : 0.00768 , 验证集损失 : 0.00996
当前step : 7/50 ,训练集损失 : 0.00101 , 验证集损失 : 0.03757
当前step : 8/50 ,训练集损失 : 0.19925 , 验证集损失 : 0.00946
当前step : 9/50 ,训练集损失 : 0.16383 , 验证集损失 : 0.00207
当前step : 10/50 ,训练集损失 : 0.15149 , 验证集损失 : 0.50414
当前step : 11/50 ,训练集损失 : 0.02419 , 验证集损失 : 0.78361
当前step : 12/50 ,训练集损失 : 0.00020 , 验证集损失 : 0.00040
当前step : 13/50 ,训练集损失 : 0.00238 , 验证集损失 : 0.08138
当前step : 14/50 ,训练集损失 : 0.00264 , 验证集损失 : 0.00003
当前step : 15/50 ,训练集损失 : 0.15614 , 验证集损失 : 1.55595
当前step : 16/50 ,训练集损失 : 0.00000 , 验证集损失 : 2.81519
当前step : 17/50 ,训练集损失 : 0.00000 , 验证集损失 : 0.00672
当前step : 18/50 ,训练集损失 : 0.00022 , 验证集损失 : 0.00197
当前step : 19/50 ,训练集损失 : 0.00000 , 验证集损失 : 0.21673
当前step : 2

In [59]:
## 结果预测
img = cv2.imread("./data/3.jpg",cv2.THRESH_BINARY)
img = cv2.resize(img,(28,28)).flatten()
img = torch.from_numpy(img)
predicted = model(img)
print(predicted)

RuntimeError: expected scalar type Byte but found Float