In [3]:
import os
from pathlib import Path
import torch
import wandb
from torch import nn
from torch.utils.data import DataLoader, random_split
from torchvision import datasets
from torchvision.transforms import transforms
import multiprocessing
import platform

In [4]:
# 필요한 함수 정의
def get_num_cpu_cores():
    """시스템의 CPU 코어 수 반환"""
    return multiprocessing.cpu_count()

def is_linux():
    """현재 OS가 리눅스인지 확인"""
    return platform.system().lower() == "linux"

def is_windows():
    """현재 OS가 윈도우인지 확인"""
    return platform.system().lower() == "windows"

# 현재 경로 설정
BASE_PATH = str(Path(os.getcwd()).resolve())  # 현재 작업 디렉토리를 BASE_PATH로 설정
print("BASE_PATH:", BASE_PATH)

BASE_PATH: /content


In [5]:
# Mean과 Std 계산 함수
def calculate_mean_std(data_loader):
    """
    전체 데이터셋의 Mean과 Std를 계산하는 함수
    """
    mean = 0.0
    std = 0.0
    total_samples = 0

    # 데이터셋을 순회하며 Mean과 Std 계산
    for images, _ in data_loader:
        # 배치 크기 확인
        batch_samples = images.size(0)
        total_samples += batch_samples

        # (B, C, H, W) 형태에서 (B, C*H*W)로 펼치기
        images = images.view(batch_samples, images.size(1), -1)

        # 각 채널에 대해 평균과 표준편차 계산
        mean += images.mean([0, 2]).sum(0)
        std += images.std([0, 2]).sum(0)

    # 총 샘플 수로 나누어 최종 평균 및 표준편차 계산
    mean /= total_samples
    std /= total_samples

    return mean, std

In [6]:
# Mean과 Std 계산 함수
def calculate_mean_std(data_loader):
    """
    전체 데이터셋의 Mean과 Std를 계산하는 함수
    """
    mean = 0.0
    std = 0.0
    total_samples = 0

    # 데이터셋을 순회하며 Mean과 Std 계산
    for images, _ in data_loader:
        # 배치 크기 확인
        batch_samples = images.size(0)
        total_samples += batch_samples

        # (B, C, H, W) 형태에서 (B, C*H*W)로 펼치기
        images = images.view(batch_samples, images.size(1), -1)

        # 각 채널에 대해 평균과 표준편차 계산
        mean += images.mean([0, 2]).sum(0)
        std += images.std([0, 2]).sum(0)

    # 총 샘플 수로 나누어 최종 평균 및 표준편차 계산
    mean /= total_samples
    std /= total_samples

    return mean, std

In [7]:
# Fashion MNIST 데이터셋 로드 (학습 및 검증)
def get_fashion_mnist_data():
    data_path = os.path.join(BASE_PATH, "_00_data", "j_fashion_mnist")

    # Fashion MNIST 데이터 로드
    f_mnist_train = datasets.FashionMNIST(data_path, train=True, download=True, transform=transforms.ToTensor())
    f_mnist_train, f_mnist_validation = random_split(f_mnist_train, [55_000, 5_000])

    print("Num Train Samples: ", len(f_mnist_train))
    print("Num Validation Samples: ", len(f_mnist_validation))
    print("Sample Shape: ", f_mnist_train[0][0].shape)  # torch.Size([1, 28, 28])

    # 데이터 로더 생성
    train_data_loader = DataLoader(
        dataset=f_mnist_train, batch_size=wandb.config.batch_size, shuffle=True, pin_memory=True, num_workers=get_num_cpu_cores()
    )

    validation_data_loader = DataLoader(
        dataset=f_mnist_validation, batch_size=wandb.config.batch_size, pin_memory=True, num_workers=get_num_cpu_cores()
    )

    # Mean과 Std 계산
    mean, std = calculate_mean_std(train_data_loader)
    print(f"Dataset Mean: {mean}, Std: {std}")

    # 데이터 정규화를 위한 변환 정의
    f_mnist_transforms = nn.Sequential(
        transforms.ConvertImageDtype(torch.float),
        transforms.Normalize(mean=mean.tolist(), std=std.tolist()),  # 계산된 Mean과 Std 사용
    )

    return train_data_loader, validation_data_loader, f_mnist_transforms, mean, std

In [8]:

# Fashion MNIST 데이터셋 로드 (테스트)
def get_fashion_mnist_test_data(mean, std):
    data_path = os.path.join(BASE_PATH, "_00_data", "j_fashion_mnist")

    f_mnist_test_images = datasets.FashionMNIST(data_path, train=False, download=True)
    f_mnist_test = datasets.FashionMNIST(data_path, train=False, download=True, transform=transforms.ToTensor())

    print("Num Test Samples: ", len(f_mnist_test))
    print("Sample Shape: ", f_mnist_test[0][0].shape)  # torch.Size([1, 28, 28])

    # 데이터 로더 생성
    test_data_loader = DataLoader(dataset=f_mnist_test, batch_size=64, pin_memory=True, num_workers=get_num_cpu_cores())

    # 테스트 데이터 정규화를 위한 변환 정의
    f_mnist_transforms = nn.Sequential(
        transforms.ConvertImageDtype(torch.float),
        transforms.Normalize(mean=mean.tolist(), std=std.tolist()),  # 학습 데이터의 Mean과 Std 사용
    )

    return f_mnist_test_images, test_data_loader, f_mnist_transforms

In [9]:
import os
import wandb

# API Key 수동 설정 (선택 사항)
os.environ["WANDB_API_KEY"] = "e4b00d9b8180a4d284aef5d6cf9ac7be13501e1e"

# WandB 로그인
wandb.login(relogin=True)

# WandB 초기화
wandb.init(
    project="fashion_mnist_project",
    config={
        "batch_size": 2048,
        "learning_rate": 0.001,
        "epochs": 10
    }
)

train_data_loader, validation_data_loader, f_mnist_transforms, mean, std = get_fashion_mnist_data()
print()
f_mnist_test_images, test_data_loader, f_mnist_transforms = get_fashion_mnist_test_data(mean, std)





Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to /content/_00_data/j_fashion_mnist/FashionMNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 26.4M/26.4M [00:01<00:00, 13.4MB/s]


Extracting /content/_00_data/j_fashion_mnist/FashionMNIST/raw/train-images-idx3-ubyte.gz to /content/_00_data/j_fashion_mnist/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to /content/_00_data/j_fashion_mnist/FashionMNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 29.5k/29.5k [00:00<00:00, 201kB/s]


Extracting /content/_00_data/j_fashion_mnist/FashionMNIST/raw/train-labels-idx1-ubyte.gz to /content/_00_data/j_fashion_mnist/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to /content/_00_data/j_fashion_mnist/FashionMNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 4.42M/4.42M [00:01<00:00, 3.71MB/s]


Extracting /content/_00_data/j_fashion_mnist/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to /content/_00_data/j_fashion_mnist/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to /content/_00_data/j_fashion_mnist/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 5.15k/5.15k [00:00<00:00, 5.73MB/s]


Extracting /content/_00_data/j_fashion_mnist/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to /content/_00_data/j_fashion_mnist/FashionMNIST/raw

Num Train Samples:  55000
Num Validation Samples:  5000
Sample Shape:  torch.Size([1, 28, 28])
Dataset Mean: 0.00014048324374016374, Std: 0.00017339707119390368

Num Test Samples:  10000
Sample Shape:  torch.Size([1, 28, 28])


In [10]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F


In [11]:

# CNN 모델 정의
class CNNModel(nn.Module):
    def __init__(self):
        super(CNNModel, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 64 * 7 * 7)  # Flatten
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# 장치 설정
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 모델, 손실 함수 및 옵티마이저 초기화
model = CNNModel().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=wandb.config.learning_rate)

In [12]:

# 모델 학습 코드 (WandB 설정값 사용)
for epoch in range(wandb.config.epochs):
    model.train()
    train_loss = 0.0
    correct = 0
    total = 0

    for images, labels in train_data_loader:
        images, labels = images.to(device), labels.to(device)

        # 옵티마이저 초기화
        optimizer.zero_grad()

        # 순전파 및 손실 계산
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # 통계 업데이트
        train_loss += loss.item() * images.size(0)
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

    # 에포크별 결과 WandB 로그 기록
    train_loss /= len(train_data_loader.dataset)
    train_accuracy = 100.0 * correct / total
    wandb.log({"epoch": epoch + 1, "train_loss": train_loss, "train_accuracy": train_accuracy})
    print(f"Epoch {epoch+1}/{wandb.config.epochs}, Loss: {train_loss:.4f}, Accuracy: {train_accuracy:.2f}%")

Epoch 1/10, Loss: 1.1898, Accuracy: 61.96%
Epoch 2/10, Loss: 0.6092, Accuracy: 77.54%
Epoch 3/10, Loss: 0.5117, Accuracy: 81.28%
Epoch 4/10, Loss: 0.4562, Accuracy: 83.69%
Epoch 5/10, Loss: 0.4223, Accuracy: 84.87%
Epoch 6/10, Loss: 0.3962, Accuracy: 85.80%
Epoch 7/10, Loss: 0.3744, Accuracy: 86.69%
Epoch 8/10, Loss: 0.3594, Accuracy: 87.23%
Epoch 9/10, Loss: 0.3446, Accuracy: 87.70%
Epoch 10/10, Loss: 0.3326, Accuracy: 88.22%
