In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

# Load the MNIST dataset and set up preprocessing.
# Images are converted to tensors and then normalized with mean 0.5 / std 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# Training split: downloaded into ./data if missing, reshuffled every epoch.
trainset = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=64,
    shuffle=True,
    num_workers=2,
)

# Test split: same preprocessing, iterated in a fixed order.
testset = torchvision.datasets.MNIST(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=64,
    shuffle=False,
    num_workers=2,
)

# CNN 모델 정의
class Net(nn.Module):
    """CNN classifier: two conv/pool stages followed by three linear layers.

    Each stage applies a 3x3 convolution with padding 1, a ReLU, and a 2x2
    max pooling with stride 2. The linear head expects a flattened
    ``32 * 7 * 7`` feature map (what a single-channel 28x28 image produces)
    and emits 10 raw scores per sample; no softmax is applied here.
    """

    # (channels, height, width) the conv stack must produce to feed ``fc1``.
    _FEATURE_SHAPE = (32, 7, 7)

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)  # shared by both conv stages
        self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
        self.fc1 = nn.Linear(32 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        """Compute class scores for the input images.

        Args:
            x: Single-channel image tensor whose spatial size reduces to
                7x7 after the two pooling steps (e.g. 28x28 images).

        Returns:
            Tensor with 10 unnormalized scores per row.

        Raises:
            RuntimeError: If the convolutional feature map is not
                ``32 x 7 x 7``, i.e. the input images have the wrong size.
        """
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        # Without this check, ``view(-1, ...)`` would quietly fold surplus
        # spatial positions into the batch dimension whenever the element
        # count happens to be a multiple of fc1's input size, producing
        # more output rows than input samples instead of an error.
        if tuple(x.shape[-3:]) != self._FEATURE_SHAPE:
            raise RuntimeError(
                "Net expects a %s feature map after the conv stack, got %s; "
                "check the input image size"
                % (self._FEATURE_SHAPE, tuple(x.shape[-3:])))
        x = x.view(-1, self.fc1.in_features)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# Pick the training device and place the model on it first, so the
# optimizer below is constructed from parameters that already live there.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = Net()
net.to(device)

# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.001)

NUM_EPOCHS = 5      # number of full passes over the training set
LOG_INTERVAL = 100  # mini-batches between running-loss reports

net.train()  # make training mode explicit before optimizing
for epoch in range(NUM_EPOCHS):

    # Loss summed over the current LOG_INTERVAL window; reset after each report.
    running_loss = 0.0
    for i, (inputs, labels) in enumerate(trainloader):
        inputs, labels = inputs.to(device), labels.to(device)

        # Gradients accumulate across backward() calls, so clear them first.
        optimizer.zero_grad()

        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if (i + 1) % LOG_INTERVAL == 0:    # report the mean loss of the window
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / LOG_INTERVAL))
            running_loss = 0.0

# Model evaluation
# Counters initialised here; presumably updated by evaluation code that
# follows this cell excerpt.
correct = 0
total = 0


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data\MNIST\raw\train-images-idx3-ubyte.gz


100%|███████████████████████████████████████████████████████████████████| 9912422/9912422 [00:01<00:00, 8709355.55it/s]


Extracting ./data\MNIST\raw\train-images-idx3-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data\MNIST\raw\train-labels-idx1-ubyte.gz


100%|██████████████████████████████████████████████████████████████████████| 28881/28881 [00:00<00:00, 29574144.00it/s]

Extracting ./data\MNIST\raw\train-labels-idx1-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz





Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data\MNIST\raw\t10k-images-idx3-ubyte.gz


100%|███████████████████████████████████████████████████████████████████| 1648877/1648877 [00:00<00:00, 5725777.16it/s]


Extracting ./data\MNIST\raw\t10k-images-idx3-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data\MNIST\raw\t10k-labels-idx1-ubyte.gz


100%|█████████████████████████████████████████████████████████████████████████| 4542/4542 [00:00<00:00, 4602688.76it/s]


Extracting ./data\MNIST\raw\t10k-labels-idx1-ubyte.gz to ./data\MNIST\raw

[1,   100] loss: 1.081
[1,   200] loss: 0.315
[1,   300] loss: 0.194
[1,   400] loss: 0.145
[1,   500] loss: 0.138
[1,   600] loss: 0.119
[1,   700] loss: 0.098
[1,   800] loss: 0.088
[1,   900] loss: 0.080
[2,   100] loss: 0.072
[2,   200] loss: 0.072
[2,   300] loss: 0.061
[2,   400] loss: 0.068
[2,   500] loss: 0.068
[2,   600] loss: 0.071
[2,   700] loss: 0.051
[2,   800] loss: 0.056
[2,   900] loss: 0.062
[3,   100] loss: 0.047
[3,   200] loss: 0.046
[3,   300] loss: 0.041
[3,   400] loss: 0.048
[3,   500] loss: 0.044
[3,   600] loss: 0.035
[3,   700] loss: 0.050
[3,   800] loss: 0.043
[3,   900] loss: 0.050
[4,   100] loss: 0.033
[4,   200] loss: 0.029
[4,   300] loss: 0.039
[4,   400] loss: 0.032
[4,   500] loss: 0.039
[4,   600] loss: 0.031
[4,   700] loss: 0.034
[4,   800] loss: 0.040
[4,   900] loss: 0.036
[5,   100] loss: 0.021
[5,   200] loss: 0.024
[5,   300] loss: 0.032
[5,   400] loss: 0.029
[5,  