In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

In [3]:
# ========================================
# 1. CNN 모델 정의 (Flatten Size 자동 계산)
# ========================================

# dummy_input.shape                         # torch.Size([1, 1, 28, 28])
# self.features(dummy_input).shape          # torch.Size([1, 32, 7, 7])
# self.features(dummy_input).view(1, -1)    # torch.Size([1, 1568])
# flatten_dim = 1568

def get_flatten_size(model, input_shape=(1, 28, 28)):
    dummy = torch.zeros(1, *input_shape)  # *input_shape는 unpacking 연산자로 1, 28, 28로 풀림
    output = model(dummy)   # (1, 1, 28, 28) 크기의 0으로 채워진 dummy 이미지
    return output.view(1, -1).shape[1]  # (1, 32x7x7)


class MNIST_CNN(nn.Module):
    def __init__(self):
        super().__init__()

        self.features = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, padding=1),  # (B, 16, 28, 28)
            nn.ReLU(),
            nn.MaxPool2d(2),                             # (B, 16, 14, 14)

            nn.Conv2d(16, 32, kernel_size=3, padding=1), # (B, 32, 14, 14)
            nn.ReLU(),
            nn.MaxPool2d(2),                             # (B, 32, 7, 7)
        )

        # Flatten 크기 자동 계산
        flatten_dim = get_flatten_size(self.features, input_shape=(1, 28, 28))
        print(f"Flatten size: {flatten_dim}")

        # Flatten 크기 자동 계산
        # dummy_input = torch.zeros(1, 1, 28, 28)
        # with torch.no_grad():
        #     flatten_dim = self.features(dummy_input).view(1, -1).shape[1]
        #     print('flatten_dim', flatten_dim)

        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(flatten_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 10)
        )

    def forward(self, x):
        x = self.features(x)
        x = self.classifier(x)
        return x

In [4]:
# ========================================
# 2. 환경 및 데이터 준비
# ========================================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

transform = transforms.ToTensor()

train_data = datasets.MNIST(root='../data', train=True, download=True, transform=transform)
test_data  = datasets.MNIST(root='../data', train=False, download=True, transform=transform)

train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
test_loader  = DataLoader(test_data, batch_size=64, shuffle=False)

len(train_loader), len(test_loader)

100%|██████████| 9.91M/9.91M [00:00<00:00, 16.1MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 485kB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 4.50MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 7.97MB/s]


(938, 157)

In [5]:
# ========================================
# 3. 모델, 손실 함수, 옵티마이저 정의
# ========================================
model = MNIST_CNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

Flatten size: 1568


In [6]:
# ========================================
# 4. 학습 루프
# ========================================
for epoch in range(1, 10):  # 10 epochs
    model.train()
    total_loss = 0

    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item() / len(train_loader)

    print(f"Epoch {epoch}, Loss: {total_loss:.4f}")

Epoch 1, Loss: 0.2380
Epoch 2, Loss: 0.0578
Epoch 3, Loss: 0.0411
Epoch 4, Loss: 0.0302
Epoch 5, Loss: 0.0244
Epoch 6, Loss: 0.0190
Epoch 7, Loss: 0.0154
Epoch 8, Loss: 0.0130
Epoch 9, Loss: 0.0105


In [15]:
# ========================================
# 5. 테스트 평가
# ========================================
model.eval()
correct = 0
total = 0

with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)

        # 예측값 (가장 높은 확률을 갖는 클래스 인덱스)
        _, predicted = torch.max(outputs.data, 1)

        # 누적
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

# 정확도 출력
print(f"Test Accuracy: {100 * correct / total:.2f}%")


Test Accuracy: 99.09%
