In [5]:
import torch
import torchvision.datasets as dataset
import torchvision.transforms as transforms
import torch.utils.data as data_utils

# 1. Data
train_data = dataset.MNIST(root="mnist", 
                           train=True, 
                           transform=transforms.ToTensor(), 
                           download=True)

test_data = dataset.MNIST(root="mnist", 
                           train=False, 
                           transform=transforms.ToTensor(), 
                           download=False)

# batchsize
train_loader = data_utils.DataLoader(dataset=train_data, 
                                     batch_size=64, 
                                     shuffle=True)



test_loader = data_utils.DataLoader(dataset=test_data, 
                                     batch_size=64, 
                                     shuffle=True)


# 2. Net

class CNN(torch.nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv = torch.nn.Sequential(
            torch.nn.Conv2d(1, 32, kernel_size=5, padding=2),
            torch.nn.BatchNorm2d(32),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(2)
        )

        self.fc = torch.nn.Linear(14 * 14 *32, 10)

    def forward(self, x):
        out = self.conv(x)
        out = out.view(out.size()[0], -1)
        out = self.fc(out)
        return out
    
cnn = CNN()
cnn = cnn.cuda()


# 3. loss
loss_func = torch.nn.CrossEntropyLoss()


# 4. optimizer

optimizer = torch.optim.Adam(cnn.parameters(), lr=0.001)


# 5. training

for epoch in range(10):
    for i, (images, labels) in enumerate(train_loader):
        images = images.cuda()
        labels = labels.cuda()

        outputs = cnn(images)
        loss = loss_func(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print("epoch is {}, ite is {}/{}, loss is {}".format(epoch+1, 
                                                            i, 
                                                            len(train_data) // 64, 
                                                            loss.item()))


# 6. evaluation
    loss_test = 0
    accuracy = 0
    for i, (images, labels) in enumerate(test_loader):
        images = images.cuda()
        labels = labels.cuda()

        outputs = cnn(images)

        loss_test += loss_func(outputs, labels)
        _, predicted = outputs.max(1)
        accuracy += (predicted == labels).sum().item()
    
    accuracy = accuracy/len(test_data)
    loss_test = loss_test / (len(test_data) // 64)

    print("epoch is {}, loss is {}, accuracy is {}".format(epoch+1,
                                                           loss_test.item(),
                                                           accuracy))

# 7. save

torch.save(cnn, "minst_cnn.pkl")

# 8. load 


# 9. inference

epoch is 1, ite is 937/937, loss is 0.07565319538116455
epoch is 1, loss is 0.06827986985445023, accuracy is 0.9785
epoch is 2, ite is 937/937, loss is 0.025464724749326706
epoch is 2, loss is 0.06298859417438507, accuracy is 0.9787
epoch is 3, ite is 937/937, loss is 0.008512557484209538
epoch is 3, loss is 0.04482750594615936, accuracy is 0.9857
epoch is 4, ite is 937/937, loss is 0.010595975443720818
epoch is 4, loss is 0.05589539557695389, accuracy is 0.9829
epoch is 5, ite is 937/937, loss is 0.001842342084273696
epoch is 5, loss is 0.042814433574676514, accuracy is 0.9872
epoch is 6, ite is 937/937, loss is 0.0834626704454422
epoch is 6, loss is 0.06360015273094177, accuracy is 0.9816
epoch is 7, ite is 937/937, loss is 0.001643711468204856
epoch is 7, loss is 0.05787275359034538, accuracy is 0.9823
epoch is 8, ite is 937/937, loss is 0.020839277654886246
epoch is 8, loss is 0.051272936165332794, accuracy is 0.9858
epoch is 9, ite is 937/937, loss is 0.0017867607530206442
epoch i

In [9]:
## 载入model等

import torch
import torchvision.datasets as dataset
import torchvision.transforms as transforms
import torch.utils.data as data_utils
import cv2


# 2. Net

class CNN(torch.nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv = torch.nn.Sequential(
            torch.nn.Conv2d(1, 32, kernel_size=5, padding=2),
            torch.nn.BatchNorm2d(32),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(2)
        )

        self.fc = torch.nn.Linear(14 * 14 *32, 10)

    def forward(self, x):
        out = self.conv(x)
        out = out.view(out.size()[0], -1)
        out = self.fc(out)
        return out


test_data = dataset.MNIST(root="mnist", 
                           train=False, 
                           transform=transforms.ToTensor(), 
                           download=False)

# batchsize

test_loader = data_utils.DataLoader(dataset=test_data, 
                                     batch_size=64, 
                                     shuffle=True)



cnn = torch.load("minst_cnn.pkl",weights_only=False)
cnn = cnn.cuda()


# 6. evaluation
loss_test = 0
accuracy = 0
for i, (images, labels) in enumerate(test_loader):
    images = images.cuda()
    labels = labels.cuda()
    outputs = cnn(images)
    _, pred = outputs.max(1)
    accuracy += (predicted == labels).sum().item()
    
    images = images.cpu().numpy()
    labels = labels.cpu().numpy()
    pred = pred.cpu().numpy()

    # batchsize * 1 * 28 * 28

    for idx in range(images.shape[0]):
        im_data = images[idx]
        im_label = labels[idx]
        im_pred = pred[idx]
        im_data = im_data.transpose(1, 2, 0)

        print("label", im_label)
        print("pred", im_pred)
        cv2.imshow("imdata", im_data)
        cv2.waitKey(0) 



accuracy = accuracy/len(test_data)
print(accuracy)


# 8. load 


# 9. inference

label 0
pred 0
label 7
pred 9
label 4
pred 4
label 7
pred 7
label 8
pred 8
label 3
pred 3
label 9
pred 9
label 0
pred 0
label 7
pred 2
label 8
pred 8
label 1
pred 1
label 7
pred 9
label 0
pred 0
label 1
pred 1
label 3
pred 3
label 6
pred 6
label 4
pred 4
label 3
pred 3
label 3
pred 3
label 9
pred 9
label 3
pred 3
label 6
pred 6
label 6
pred 6
label 4
pred 4
label 7
pred 7
label 5
pred 5
label 1
pred 1
label 3
pred 3
label 9
pred 9
label 8
pred 8
label 8
pred 8
label 1
pred 1
label 5
pred 5
label 0
pred 0
label 7
pred 7
label 7
pred 7
label 7
pred 7
label 7
pred 7
label 4
pred 4
label 6
pred 6
label 8
pred 8
label 6
pred 6
label 2
pred 2
label 0
pred 0
label 1
pred 1
label 9
pred 9
label 2
pred 2
label 1
pred 1
label 7
pred 7
label 2
pred 2
label 1
pred 1
label 7
pred 7
label 6
pred 6
label 6
pred 6
label 6
pred 6
label 9
pred 9
label 4
pred 4
label 8
pred 8
label 7
pred 7
label 3
pred 3
label 3
pred 3
label 8
pred 8
label 7
pred 7
label 0
pred 0
label 8
pred 8
label 4
pred 4
label 4
pr

KeyboardInterrupt: 