In [2]:
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader

# 간단한 CNN 모델 정의
class SimpleCNN(nn.Module):
    """Small convolutional classifier: two conv layers followed by two linear layers.

    The first linear layer expects 9216 flattened features. With the two
    unpadded 3x3 convolutions and the 2x2 max-pool used in ``forward``, that
    corresponds to a single-channel 28x28 input (28 -> 26 -> 24 -> 12, and
    64 * 12 * 12 = 9216).
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the raw (pre-softmax) scores for 10 classes, one row per sample.

        Activation and pooling are applied through the stateless functional
        API so no throw-away module objects are constructed on each call.
        """
        x = nn.functional.relu(self.conv1(x))
        x = nn.functional.relu(self.conv2(x))
        x = nn.functional.max_pool2d(x, 2)
        # Keep the batch dimension, collapse everything else for the linear layers.
        x = torch.flatten(x, 1)
        x = nn.functional.relu(self.fc1(x))
        return self.fc2(x)

# Load the MNIST training split; images are converted to tensors and then
# normalized with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])
train_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    transform=transform,
    download=True,
)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

# Instantiate the network
model = SimpleCNN()

# Cross-entropy loss, minimized with the Adam optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Train for 5 epochs, printing the mean loss over each window of 100 batches
for epoch in range(5):
    running_loss = 0.0
    for batch_idx, (inputs, labels) in enumerate(train_loader):
        # Gradients accumulate across backward() calls, so reset them first.
        optimizer.zero_grad()
        loss = criterion(model(inputs), labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if (batch_idx + 1) % 100 == 0:
            print(f'Epoch {epoch + 1}, Batch {batch_idx + 1}, Loss: {running_loss / 100:.3f}')
            running_loss = 0.0

# Save the trained model's parameters
torch.save(model.state_dict(), 'mnist_cnn_model.pt')

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data\MNIST\raw\train-images-idx3-ubyte.gz


100%|███████████████████████████████████████████████████████████████████| 9912422/9912422 [00:03<00:00, 3098717.39it/s]


Extracting ./data\MNIST\raw\train-images-idx3-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data\MNIST\raw\train-labels-idx1-ubyte.gz


100%|██████████████████████████████████████████████████████████████████████| 28881/28881 [00:00<00:00, 28563002.55it/s]


Extracting ./data\MNIST\raw\train-labels-idx1-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data\MNIST\raw\t10k-images-idx3-ubyte.gz


100%|███████████████████████████████████████████████████████████████████| 1648877/1648877 [00:00<00:00, 1662421.54it/s]


Extracting ./data\MNIST\raw\t10k-images-idx3-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data\MNIST\raw\t10k-labels-idx1-ubyte.gz


100%|██████████████████████████████████████████████████████████████████████████████████████| 4542/4542 [00:00<?, ?it/s]


Extracting ./data\MNIST\raw\t10k-labels-idx1-ubyte.gz to ./data\MNIST\raw

Epoch 1, Batch 100, Loss: 0.504
Epoch 1, Batch 200, Loss: 0.163
Epoch 1, Batch 300, Loss: 0.095
Epoch 1, Batch 400, Loss: 0.093
Epoch 1, Batch 500, Loss: 0.088
Epoch 1, Batch 600, Loss: 0.072
Epoch 1, Batch 700, Loss: 0.069
Epoch 1, Batch 800, Loss: 0.063
Epoch 1, Batch 900, Loss: 0.053
Epoch 2, Batch 100, Loss: 0.042
Epoch 2, Batch 200, Loss: 0.043
Epoch 2, Batch 300, Loss: 0.031
Epoch 2, Batch 400, Loss: 0.057
Epoch 2, Batch 500, Loss: 0.037
Epoch 2, Batch 600, Loss: 0.043
Epoch 2, Batch 700, Loss: 0.036
Epoch 2, Batch 800, Loss: 0.041
Epoch 2, Batch 900, Loss: 0.040
Epoch 3, Batch 100, Loss: 0.025
Epoch 3, Batch 200, Loss: 0.024
Epoch 3, Batch 300, Loss: 0.020
Epoch 3, Batch 400, Loss: 0.023
Epoch 3, Batch 500, Loss: 0.021
Epoch 3, Batch 600, Loss: 0.025
Epoch 3, Batch 700, Loss: 0.032
Epoch 3, Batch 800, Loss: 0.030
Epoch 3, Batch 900, Loss: 0.028
Epoch 4, Batch 100, Loss: 0.013
Epoch 4, Batch 200, Loss: 0.0