In [5]:
import torch
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader
from tqdm import tqdm
from PIL import Image
from torch.utils.data import Dataset
import os



# Pick the first CUDA GPU when one is present; otherwise run everything on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print(f"using {device} device.")

# Input preprocessing: ToTensor converts the PIL image to a float tensor in [0, 1],
# then Normalize computes (x - mean) / std using the commonly used MNIST values
# mean = 0.1307 and standard deviation (not variance) = 0.3081.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
    ]
)

# Training split: downloaded on first use and reshuffled every epoch.
train_dataset = datasets.MNIST('../data/MNIST/', train=True, transform=transform, download=True)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

# Held-out split: fixed order so validation runs are comparable.
test_dataset = datasets.MNIST('../data/MNIST/', train=False, transform=transform, download=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)


# Network definition.
class FC(torch.nn.Module):
    """Five-layer fully connected classifier.

    Inputs are reshaped into rows of 784 features (one 28x28 image per row) and
    passed through 784 -> 512 -> 256 -> 128 -> 64 -> 10 linear layers. A ReLU
    follows every layer except the last, so ``forward`` returns raw logits.
    """

    def __init__(self):
        super().__init__()
        # The attribute names l1..l5 are part of the saved state_dict keys
        # ("l1.weight", ...), so they must not be renamed.
        self.l1 = torch.nn.Linear(784, 512)
        self.l2 = torch.nn.Linear(512, 256)
        self.l3 = torch.nn.Linear(256, 128)
        self.l4 = torch.nn.Linear(128, 64)
        self.l5 = torch.nn.Linear(64, 10)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        # Assumes each sample holds exactly 784 values (e.g. 1x28x28).
        out = x.view(-1, 784)
        for hidden in (self.l1, self.l2, self.l3, self.l4):
            out = self.relu(hidden(out))
        # No activation on the output layer: CrossEntropyLoss expects logits.
        return self.l5(out)


# Build the fully connected model and move its parameters to the selected device
# (Module.to returns the module itself).
model = FC().to(device)

# Cross-entropy over the 10 logits, optimized by plain SGD (momentum defaults to 0).
# NOTE(review): the captured log shows slow convergence at lr=0.001 without
# momentum (~83% after 10 epochs); tuning lr/momentum may help.
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(params=model.parameters(), lr=0.001)


# Training loop.
def train(epoch, total_epochs=None):
    """Run one training epoch over the module-level ``train_loader``.

    Uses the module-level ``model``, ``optimizer``, ``criterion`` and ``device``.

    Args:
        epoch: zero-based index of the current epoch (displayed as ``epoch + 1``).
        total_epochs: total number of epochs, used only for the progress-bar text.
            When omitted, falls back to the module-level ``epochs`` (set in the
            ``__main__`` block), so existing ``train(i)`` calls keep working.
    """
    if total_epochs is None:
        total_epochs = epochs
    model.train()
    train_bar = tqdm(train_loader)
    for images, labels in train_bar:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()             # clear gradients from the previous step
        outputs = model(images)           # forward pass
        loss = criterion(outputs, labels)
        loss.backward()                   # backpropagate
        optimizer.step()                  # update weights
        # loss.item() turns the 0-d tensor into a Python float for formatting;
        # refresh=False keeps tqdm's normal redraw cadence.
        train_bar.set_description(
            "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1, total_epochs, loss.item()),
            refresh=False,
        )

     

# Validation loop.
def validate(epoch, total_epochs=None):
    """Evaluate the module-level ``model`` on ``test_loader`` and print the accuracy.

    Args:
        epoch: zero-based index of the current epoch (displayed as ``epoch + 1``).
        total_epochs: total number of epochs, used only for the progress-bar text.
            When omitted, falls back to the module-level ``epochs``.

    Returns:
        Accuracy in percent (float), or None if the loader produced no samples.
    """
    if total_epochs is None:
        total_epochs = epochs
    # train() leaves the model in training mode; switch to eval mode so layers such
    # as dropout/batch-norm (none in FC today) behave correctly during evaluation.
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        # The description does not change per batch, so set it once up front.
        test_bar = tqdm(test_loader,
                        desc="validate epoch[{}/{}]".format(epoch + 1, total_epochs))
        for images, labels in test_bar:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            # Predicted class = index of the largest logit.
            predicted = outputs.argmax(dim=1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)

    if total == 0:
        print('accuracy on validate set: no samples\n')
        return None
    accuracy = 100 * correct / total
    print('accuracy on validate set:%d %%\n' % accuracy)
    return accuracy


if __name__ == '__main__':
    # Number of passes over the training set. train() and validate() also read
    # this global for their progress-bar text.
    # NOTE(review): the captured output below was produced by a 10-epoch run.
    epochs = 20

    for epoch in range(epochs):
        train(epoch)
        validate(epoch)

    # Save only the learned parameters (the state_dict), not the whole module.
    torch.save(model.state_dict(), "fc_trained_model.pth")

train epoch[1/10] loss:2.295:   1%|▍                                                   | 9/938 [00:00<00:11, 83.08it/s]

using cpu device.


train epoch[1/10] loss:2.298: 100%|██████████████████████████████████████████████████| 938/938 [00:10<00:00, 92.81it/s]
validate epoch[1/10]: 100%|█████████████████████████████████████████████████████████| 157/157 [00:01<00:00, 128.56it/s]
train epoch[2/10] loss:2.294:   1%|▍                                                   | 9/938 [00:00<00:10, 85.48it/s]

accuracy on validate set:14 %



train epoch[2/10] loss:2.281: 100%|██████████████████████████████████████████████████| 938/938 [00:10<00:00, 90.70it/s]
validate epoch[2/10]: 100%|█████████████████████████████████████████████████████████| 157/157 [00:01<00:00, 123.29it/s]
train epoch[3/10] loss:2.277:   1%|▍                                                   | 9/938 [00:00<00:10, 84.68it/s]

accuracy on validate set:19 %



train epoch[3/10] loss:2.274: 100%|██████████████████████████████████████████████████| 938/938 [00:10<00:00, 91.88it/s]
validate epoch[3/10]: 100%|█████████████████████████████████████████████████████████| 157/157 [00:01<00:00, 124.57it/s]
train epoch[4/10] loss:2.257:   1%|▍                                                   | 9/938 [00:00<00:10, 85.49it/s]

accuracy on validate set:24 %



train epoch[4/10] loss:2.209: 100%|██████████████████████████████████████████████████| 938/938 [00:10<00:00, 91.83it/s]
validate epoch[4/10]: 100%|█████████████████████████████████████████████████████████| 157/157 [00:01<00:00, 124.67it/s]
train epoch[5/10] loss:2.194:   1%|▍                                                   | 9/938 [00:00<00:10, 87.57it/s]

accuracy on validate set:44 %



train epoch[5/10] loss:2.045: 100%|██████████████████████████████████████████████████| 938/938 [00:10<00:00, 92.21it/s]
validate epoch[5/10]: 100%|█████████████████████████████████████████████████████████| 157/157 [00:01<00:00, 124.37it/s]
train epoch[6/10] loss:2.102:   1%|▍                                                   | 9/938 [00:00<00:10, 88.87it/s]

accuracy on validate set:50 %



train epoch[6/10] loss:1.913: 100%|██████████████████████████████████████████████████| 938/938 [00:10<00:00, 91.97it/s]
validate epoch[6/10]: 100%|█████████████████████████████████████████████████████████| 157/157 [00:01<00:00, 124.22it/s]
train epoch[7/10] loss:1.924:   1%|▍                                                   | 9/938 [00:00<00:11, 84.28it/s]

accuracy on validate set:56 %



train epoch[7/10] loss:1.474: 100%|██████████████████████████████████████████████████| 938/938 [00:10<00:00, 92.57it/s]
validate epoch[7/10]: 100%|█████████████████████████████████████████████████████████| 157/157 [00:01<00:00, 123.78it/s]
train epoch[8/10] loss:1.226:   1%|▍                                                   | 9/938 [00:00<00:10, 85.08it/s]

accuracy on validate set:65 %



train epoch[8/10] loss:0.894: 100%|██████████████████████████████████████████████████| 938/938 [00:10<00:00, 92.84it/s]
validate epoch[8/10]: 100%|█████████████████████████████████████████████████████████| 157/157 [00:01<00:00, 123.29it/s]
train epoch[9/10] loss:0.987:   1%|▍                                                   | 9/938 [00:00<00:11, 81.60it/s]

accuracy on validate set:74 %



train epoch[9/10] loss:0.760: 100%|██████████████████████████████████████████████████| 938/938 [00:10<00:00, 93.27it/s]
validate epoch[9/10]: 100%|█████████████████████████████████████████████████████████| 157/157 [00:01<00:00, 124.27it/s]
train epoch[10/10] loss:0.813:   1%|▍                                                  | 9/938 [00:00<00:10, 87.57it/s]

accuracy on validate set:79 %



train epoch[10/10] loss:0.752: 100%|█████████████████████████████████████████████████| 938/938 [00:10<00:00, 92.98it/s]
validate epoch[10/10]: 100%|████████████████████████████████████████████████████████| 157/157 [00:01<00:00, 121.01it/s]

accuracy on validate set:83 %






In [9]:
from PIL import Image
from torch.utils.data import Dataset
import os

class MyMnistDataset(Dataset):
    """Custom handwritten-digit images paired with labels from a text file.

    Expected layout under ``root``:
        root/images/<file name>     image files (converted to grayscale "L" mode)
        root/labels/labels.txt      one "<file name> <integer label>" pair per line

    All images are loaded and transformed eagerly in ``__init__``; indexing just
    returns the precomputed ``(image, label)`` pair.

    Args:
        root: dataset root directory.
        transform: callable applied to each grayscale PIL image.
    """

    def __init__(self, root, transform):
        self.myMnistPath = root
        self.imagesData = []   # transformed images, in sorted file-name order
        self.labelsData = []   # replaced by an int64 tensor after loading
        self.labelsDict = {}   # file name -> integer label
        self.trans = transform

        self.loadLabelsDate()
        self.loadImageData()

    # Parse labels/labels.txt into self.labelsDict.
    # (The "Date" typo in the name is kept for backward compatibility.)
    def loadLabelsDate(self):
        labelsPath = os.path.join(self.myMnistPath, "labels", "labels.txt")
        # `with` closes the file; the original never did.
        with open(labelsPath) as f:
            for line in f:
                # split() with no argument tolerates repeated spaces, tabs and
                # trailing newlines/CR.
                parts = line.split()
                if not parts:
                    continue  # skip blank lines, e.g. an empty line at EOF
                if len(parts) < 2:
                    raise ValueError("malformed line in {!r}: {!r}".format(labelsPath, line))
                self.labelsDict[parts[0]] = int(parts[1])

    # Load each image as grayscale, transform it, and pair it with its label.
    def loadImageData(self):
        imagesFolderPath = os.path.join(self.myMnistPath, 'images')
        # os.listdir order is arbitrary; sort so dataset indices are reproducible.
        for imageName in sorted(os.listdir(imagesFolderPath)):
            imagePath = os.path.join(imagesFolderPath, imageName)
            # convert() returns an independent, fully loaded image, so the source
            # file handle can be closed right away.
            with Image.open(imagePath) as image:
                grayImage = image.convert("L")
            self.imagesData.append(self.trans(grayImage))
            self.labelsData.append(self.labelsDict[imageName])

        # Class labels are integer indices (int64), the dtype CrossEntropyLoss
        # expects for class targets; torch.Tensor(...) would have produced float32.
        self.labelsData = torch.tensor(self.labelsData, dtype=torch.long)

    # Dataset protocol: return the (image, label) pair at `index`.
    def __getitem__(self, index):
        return self.imagesData[index], self.labelsData[index]

    # Dataset protocol: number of samples.
    def __len__(self):
        return len(self.labelsData)
    
# Preprocessing for the custom images: resize to 28x28 (the 784 inputs FC expects),
# then apply the same ToTensor + Normalize(mean=0.1307, std=0.3081) used in training.
# NOTE(review): standard MNIST digits are light strokes on a dark background; if these
# custom images are dark ink on light paper they probably need inverting before this
# transform. That mismatch may explain the low accuracy in the output below — confirm.
transform = transforms.Compose([
    transforms.Resize([28, 28]),
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))])

# Load the custom dataset; DataLoader's default batch_size is 1 and shuffle=False
# keeps dataset order.
dataset = MyMnistDataset(root='../data/my_mnist_dateset', transform=transform)
test_loader = DataLoader(dataset=dataset, shuffle=False)

# Rebuild the FC network and load the trained weights. map_location="cpu" lets a
# checkpoint saved during a CUDA run load on a CPU-only machine; evaluation below
# feeds CPU tensors to the model.
model = FC()
model.load_state_dict(torch.load("fc_trained_model.pth", map_location="cpu"))
# Evaluation mode (no dropout/batch-norm in FC today, but correct if they are added).
model.eval()


def test():
    """Evaluate the module-level ``model`` on ``test_loader`` (the custom dataset).

    Prints one "label predicted" row per sample, then the overall accuracy.
    Inputs are used as produced by the loader (no device transfer), so the model
    is expected to be on the CPU.

    Returns:
        Accuracy in percent (float), or None if the loader produced no samples.
    """
    model.eval()
    correct = 0
    total = 0
    print("label       predicted")
    with torch.no_grad():
        for images, labels in test_loader:
            outputs = model(images)
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            # Print every sample in the batch; .item() would only work for batch_size 1.
            for label, pred in zip(labels.tolist(), predicted.tolist()):
                print("{}          {}".format(int(label), pred))

    if total == 0:
        print('FC trained model: no samples in mymnist set')
        return None
    accuracy = 100 * correct / total
    print('FC trained model: accuracy on mymnist set:%d %%' % accuracy)
    return accuracy


if __name__ == '__main__':
    test()

label       predicted
6          5
6          6
0          0
0          2
3          3
3          9
9          7
9          1
8          8
8          2
5          3
5          8
7          1
7          1
2          2
2          2
4          1
4          9
1          1
1          1
FC trained model: accuracy on mymnist set:40 %
