In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 입력 정규화 + 데이터 초기화(증강)
transform = transforms.Compose([
    transforms.RandomRotation(10),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

train_data = datasets.FashionMNIST(root='./data', train=True, download=True, transform=transform)
test_data = datasets.FashionMNIST(root='./data', train=False, download=True, transform=transform)

train_loader = DataLoader(train_data, batch_size=128, shuffle=True)
test_loader = DataLoader(test_data, batch_size=128, shuffle=False)

class CNNModel(nn.Module):
    def __init__(self):
        super(CNNModel, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)
        self._initialize_weights()

    def _initialize_weights(self):
        # He 초기화
        nn.init.kaiming_normal_(self.conv1.weight, mode='fan_out', nonlinearity='relu')
        nn.init.kaiming_normal_(self.conv2.weight, mode='fan_out', nonlinearity='relu')
        # Xavier 초기화
        nn.init.xavier_uniform_(self.fc1.weight)
        nn.init.xavier_uniform_(self.fc2.weight)

    def forward(self, x):
        x = nn.functional.relu(self.conv1(x))
        x = nn.functional.relu(self.conv2(x))
        x = nn.functional.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = nn.functional.relu(self.fc1(x))
        x = self.dropout2(x)
        x = self.fc2(x)
        return nn.functional.log_softmax(x, dim=1)

model = CNNModel().to(device)
optimizer = optim.Adam(model.parameters(), lr=0.0005, weight_decay=1e-4)
loss_fn = nn.CrossEntropyLoss()

epochs = 20
for epoch in range(epochs):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = loss_fn(output, target)
        loss.backward()
        optimizer.step()

    model.eval()
    correct = 0
    total_loss = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            total_loss += loss_fn(output, target).item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    accuracy = 100. * correct / len(test_loader.dataset)
    print(f"Epoch {epoch+1}: Test Accuracy: {accuracy:.2f}%")


100%|██████████| 26.4M/26.4M [00:00<00:00, 109MB/s]
100%|██████████| 29.5k/29.5k [00:00<00:00, 4.07MB/s]
100%|██████████| 4.42M/4.42M [00:00<00:00, 60.5MB/s]
100%|██████████| 5.15k/5.15k [00:00<00:00, 11.6MB/s]


Epoch 1: Test Accuracy: 83.84%
Epoch 2: Test Accuracy: 85.90%
Epoch 3: Test Accuracy: 87.05%
Epoch 4: Test Accuracy: 87.86%
Epoch 5: Test Accuracy: 88.16%
Epoch 6: Test Accuracy: 88.86%
Epoch 7: Test Accuracy: 88.89%
Epoch 8: Test Accuracy: 89.01%
Epoch 9: Test Accuracy: 89.35%
Epoch 10: Test Accuracy: 89.92%
Epoch 11: Test Accuracy: 90.18%
Epoch 12: Test Accuracy: 90.48%
Epoch 13: Test Accuracy: 90.11%
Epoch 14: Test Accuracy: 90.22%
Epoch 15: Test Accuracy: 91.02%
Epoch 16: Test Accuracy: 90.62%
Epoch 17: Test Accuracy: 90.91%
Epoch 18: Test Accuracy: 91.13%
Epoch 19: Test Accuracy: 91.16%
Epoch 20: Test Accuracy: 91.29%
