In [1]:
import torch
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
from torchvision import transforms
import cv2
import numpy as np

In [2]:
# data
train_data = datasets.MNIST(
    "./data", train=True, download=True, transform=transforms.ToTensor()
)
test_data = datasets.MNIST(
    "./data", train=False, download=True, transform=transforms.ToTensor()
)
train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
test_loader = DataLoader(test_data, batch_size=64, shuffle=True)
print(len(train_data), len(test_data))
print(train_data[0][0].shape)
print(train_data[0][1])
print(len(train_loader))
# print(train_loader[0][0].shape)
# print(train_loader[0])

60000 10000
torch.Size([1, 28, 28])
5
938


In [3]:
# net
class CNN(torch.nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv = torch.nn.Sequential(
            torch.nn.Conv2d(1, 32, kernel_size=5, padding=2),
            torch.nn.BatchNorm2d(32),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(2),
        )
        self.fc1 = torch.nn.Linear(32 * 14 * 14, 10)

    def forward(self, x):
        out = self.conv(x)
        out = out.view(out.shape[0], -1)
        out = self.fc1(out)
        return out


model = CNN()

In [4]:
# loss
loss_fc = torch.nn.CrossEntropyLoss()
# optim
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

In [8]:
# train
for epoch in range(10):
    correct = 0
    total = 0
    loss_total = 0
    for i, (images, labels) in enumerate(train_loader):
        outputs = model(images)
        loss = loss_fc(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total += images.shape[0]

        _, pred = torch.max(outputs.data, 1)
        # print(outputs.shape)
        # print(outputs.data.shape)
        # print(pred)
        correct += (pred == labels).sum().item()
        loss_total += loss.item()

    print(
        f"Epoch : {epoch}, loss : {(loss_total / len(train_loader)) :.4f}, accrancy : {(correct / total) :.4f}"
    )
torch.save(model.state_dict(), "./model/mnist_cnn.pth")

Epoch : 0, loss : 0.1063, accrancy : 0.9729
Epoch : 1, loss : 0.0987, accrancy : 0.9743
Epoch : 2, loss : 0.0927, accrancy : 0.9759
Epoch : 3, loss : 0.0873, accrancy : 0.9774
Epoch : 4, loss : 0.0829, accrancy : 0.9785
Epoch : 5, loss : 0.0789, accrancy : 0.9796
Epoch : 6, loss : 0.0753, accrancy : 0.9807
Epoch : 7, loss : 0.0723, accrancy : 0.9812
Epoch : 8, loss : 0.0696, accrancy : 0.9821
Epoch : 9, loss : 0.0668, accrancy : 0.9825


In [6]:
model.eval()
loss_test = 0
correct_test = 0
accurancy_test = 0
total = 0

with torch.no_grad():
    for i, (images, labels) in enumerate(test_loader):
        outputs = model(images)
        loss_test += loss_fc(outputs, labels).item()
        _, pred = torch.max(outputs.data, 1)
        total += outputs.shape[0]
        correct_test += (pred == labels).sum().item()

        # print(outputs.data)
        # print(pred)
        # print(labels)

        for i in range(outputs.shape[0]):
            print(f"pred : {pred[i]}, label : {labels[i]}")
            # img = images[i].numpy()
            # img = np.transpose(img, (1, 2, 0))
            # cv2.imshow("img", img)
            # cv2.waitKey(0)
            # cv2.destroyAllWindows()
    print(
        f"Test result: {(loss_test / len(test_data)):.4f}, acc : {correct_test / total}"
    )

pred : 6, label : 6
pred : 6, label : 6
pred : 3, label : 3
pred : 3, label : 3
pred : 1, label : 1
pred : 6, label : 6
pred : 0, label : 0
pred : 7, label : 7
pred : 1, label : 1
pred : 2, label : 2
pred : 9, label : 9
pred : 9, label : 9
pred : 1, label : 1
pred : 8, label : 8
pred : 6, label : 6
pred : 8, label : 8
pred : 1, label : 1
pred : 4, label : 4
pred : 2, label : 7
pred : 9, label : 9
pred : 6, label : 6
pred : 9, label : 9
pred : 2, label : 2
pred : 8, label : 8
pred : 7, label : 7
pred : 3, label : 3
pred : 7, label : 7
pred : 8, label : 8
pred : 0, label : 6
pred : 8, label : 8
pred : 7, label : 7
pred : 0, label : 0
pred : 0, label : 0
pred : 0, label : 0
pred : 5, label : 5
pred : 0, label : 0
pred : 6, label : 6
pred : 5, label : 5
pred : 0, label : 0
pred : 4, label : 4
pred : 8, label : 8
pred : 3, label : 3
pred : 9, label : 9
pred : 9, label : 9
pred : 0, label : 0
pred : 1, label : 1
pred : 8, label : 8
pred : 7, label : 7
pred : 4, label : 4
pred : 7, label : 7


In [12]:
model_train = CNN()
model_train.load_state_dict(torch.load("./model/mnist_cnn.pth"))
print(model_train)
model_train.eval()
loss_test = 0
correct = 0
correct_total = 0
acc = 0
with torch.no_grad():
    for i, (images, labels) in enumerate(test_loader):
        outputs = model_train(images)
        loss_test += loss_fc(outputs, labels).item()
        correct_total += outputs.shape[0]
        _, pred = torch.max(outputs, 1)
        correct += (pred == labels).sum().item()
        for i in range(outputs.shape[0]):
            print(f"pred : {pred[i]}, label : {labels[i]}")
    print(
        f"Test result: {(loss_test / len(test_data)):.4f}, acc : {correct / correct_total}"
    )

  model_train.load_state_dict(torch.load("./model/mnist_cnn.pth"))


CNN(
  (conv): Sequential(
    (0): Conv2d(1, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (fc1): Linear(in_features=6272, out_features=10, bias=True)
)
pred : 2, label : 2
pred : 1, label : 1
pred : 9, label : 9
pred : 1, label : 1
pred : 7, label : 7
pred : 2, label : 2
pred : 7, label : 7
pred : 3, label : 3
pred : 3, label : 3
pred : 7, label : 7
pred : 3, label : 3
pred : 5, label : 5
pred : 8, label : 8
pred : 5, label : 5
pred : 0, label : 0
pred : 7, label : 7
pred : 4, label : 4
pred : 2, label : 2
pred : 5, label : 5
pred : 8, label : 8
pred : 1, label : 1
pred : 0, label : 0
pred : 5, label : 5
pred : 4, label : 4
pred : 3, label : 3
pred : 1, label : 1
pred : 2, label : 2
pred : 5, label : 5
pred : 7, label : 7
pred : 4, label : 4
pred : 2, label : 2
pred : 8, label : 8
pr

In [15]:
device = torch.device("cuda" if torch.cuda.is_available else "cpu")
print(device)


model_train = CNN()

model_train.to(device)
model_train.load_state_dict(torch.load("./model/mnist_cnn.pth"))
print(model_train)
model_train.eval()
loss_test = 0
correct = 0
correct_total = 0
acc = 0
with torch.no_grad():
    for i, (images, labels) in enumerate(test_loader):
        images, labels = images.to(device), labels.to(device)
        outputs = model_train(images)
        loss_test += loss_fc(outputs, labels).item()
        correct_total += outputs.shape[0]
        _, pred = torch.max(outputs, 1)
        correct += (pred == labels).sum().item()
        for i in range(outputs.shape[0]):
            print(f"pred : {pred[i]}, label : {labels[i]}")
    print(
        f"Test result: {(loss_test / len(test_data)):.4f}, acc : {correct / correct_total}"
    )

cuda
CNN(
  (conv): Sequential(
    (0): Conv2d(1, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (fc1): Linear(in_features=6272, out_features=10, bias=True)
)


  model_train.load_state_dict(torch.load("./model/mnist_cnn.pth"))


pred : 4, label : 4
pred : 1, label : 1
pred : 8, label : 8
pred : 1, label : 1
pred : 8, label : 8
pred : 8, label : 8
pred : 8, label : 8
pred : 1, label : 1
pred : 9, label : 9
pred : 9, label : 9
pred : 1, label : 1
pred : 3, label : 3
pred : 1, label : 1
pred : 3, label : 3
pred : 6, label : 6
pred : 4, label : 4
pred : 9, label : 9
pred : 8, label : 8
pred : 4, label : 4
pred : 6, label : 6
pred : 0, label : 0
pred : 1, label : 1
pred : 1, label : 1
pred : 0, label : 0
pred : 4, label : 4
pred : 3, label : 3
pred : 5, label : 5
pred : 8, label : 8
pred : 3, label : 3
pred : 2, label : 2
pred : 8, label : 8
pred : 5, label : 5
pred : 0, label : 6
pred : 9, label : 9
pred : 4, label : 4
pred : 4, label : 4
pred : 5, label : 5
pred : 2, label : 2
pred : 1, label : 1
pred : 6, label : 6
pred : 8, label : 8
pred : 6, label : 6
pred : 1, label : 1
pred : 2, label : 2
pred : 7, label : 7
pred : 9, label : 9
pred : 9, label : 9
pred : 1, label : 1
pred : 3, label : 3
pred : 1, label : 1
