In [14]:
import os
import torch
import wandb
import sys
import torch.optim as optim

from pathlib import Path
from datetime import datetime
from torch import nn
from torch.utils.data import DataLoader, random_split
from torchvision import datasets
from torchvision.transforms import transforms
from torchinfo import summary

In [15]:
BASE_PATH = str(Path.cwd().resolve().parent.parent)
print(BASE_PATH)

sys.path.append(BASE_PATH)

from _01_code._99_common_utils.utils import get_num_cpu_cores, is_linux, is_windows

C:\Users\ajhaj\git\link_dl


In [16]:
# 데이터셋의 평균(mean)과 표준편차(std)를 계산하는 함수
def calculate_mean_std(dataset):
    data_loader = DataLoader(dataset, batch_size=1024, shuffle=False)  # 데이터 로더 생성 (배치 크기: 1024)

    mean = 0.0  # 평균 초기화
    std = 0.0  # 표준편차 초기화
    total_samples = 0  # 총 샘플 수 초기화

    for data, _ in data_loader:  # 데이터셋에서 배치 단위로 데이터를 가져옴
        batch_samples = data.size(0)  # 현재 배치의 샘플 수
        data = data.view(batch_samples, -1)  # 이미지를 1D 벡터로 변환 (예: [28,28] -> [784])

        mean += data.mean(1).sum().item()  # 현재 배치의 평균을 계산하고 누적
        std += data.std(1).sum().item()  # 현재 배치의 표준편차를 계산하고 누적
        total_samples += batch_samples  # 총 샘플 수 누적

    mean /= total_samples  # 총 샘플 수로 평균 계산
    std /= total_samples  # 총 샘플 수로 표준편차 계산

    return mean, std  # 최종 계산된 평균과 표준편차 반환

# FashionMNIST 데이터셋 로드
data_path = "./j_fashion_mnist"  # 데이터셋 경로
f_mnist_train = datasets.FashionMNIST(data_path, train=True, download=True, transform=transforms.ToTensor())

# 평균(mean)과 표준편차(std) 계산
mean, std = calculate_mean_std(f_mnist_train)

# 결과 출력
print(f"Mean: {mean}")
print(f"Std: {std}")


Mean: 0.28604060643513995
Std: 0.3204533660888672


In [17]:
def get_fashion_mnist_data():
    data_path = os.path.join(BASE_PATH, "_00_data", "j_fashion_mnist")
    transform = transforms.Compose([
        transforms.Resize((64, 64)),  # ResNet 구조에 맞게 이미지 크기 조정
        transforms.ToTensor(),
        transforms.Normalize(mean=0.286, std=0.320),
    ])
    f_mnist_train = datasets.FashionMNIST(data_path, train=True, download=True, transform=transform)
    f_mnist_train, f_mnist_validation = random_split(f_mnist_train, [55_000, 5_000])

    print("Num Train Samples: ", len(f_mnist_train))
    print("Num Validation Samples: ", len(f_mnist_validation))
    print("Sample Shape: ", f_mnist_train[0][0].shape)

    num_data_loading_workers = get_num_cpu_cores() if is_linux() or is_windows() else 0
    print("Number of Data Loading Workers:", num_data_loading_workers)

    train_data_loader = DataLoader(
        dataset=f_mnist_train, batch_size=wandb.config.batch_size, shuffle=True,
        pin_memory=True, num_workers=num_data_loading_workers
    )

    validation_data_loader = DataLoader(
        dataset=f_mnist_validation, batch_size=wandb.config.batch_size,
        pin_memory=True, num_workers=num_data_loading_workers
    )

    return train_data_loader, validation_data_loader

In [18]:
def get_fashion_mnist_test_data():
    data_path = os.path.join(BASE_PATH, "_00_data", "j_fashion_mnist")
    transform = transforms.Compose([
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
        transforms.Normalize(mean=0.286, std=0.320),
    ])

    f_mnist_test_images = datasets.FashionMNIST(data_path, train=False, download=True)
    f_mnist_test = datasets.FashionMNIST(data_path, train=False, download=True, transform=transform)

    print("Num Test Samples: ", len(f_mnist_test))
    print("Sample Shape: ", f_mnist_test[0][0].shape)

    test_data_loader = DataLoader(dataset=f_mnist_test, batch_size=wandb.config.batch_size)

    return f_mnist_test_images, test_data_loader

In [19]:
class_labels = [
    'T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
    'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot'
]

In [20]:
def get_resnet_model():
    class Residual(nn.Module):
        """The Residual block of ResNet models."""
        def __init__(self, num_channels, use_1x1conv=False, strides=1):
            super().__init__()
            self.conv1 = nn.LazyConv2d(out_channels=num_channels, kernel_size=3, padding=1, stride=strides)
            self.conv2 = nn.LazyConv2d(out_channels=num_channels, kernel_size=3, padding=1)
            if use_1x1conv:
                self.conv3 = nn.LazyConv2d(out_channels=num_channels, kernel_size=1, stride=strides)
            else:
                self.conv3 = None
            self.bn1 = nn.LazyBatchNorm2d()
            self.bn2 = nn.LazyBatchNorm2d()

        def forward(self, X):
            Y = torch.relu(self.bn1(self.conv1(X)))
            Y = self.bn2(self.conv2(Y))
            if self.conv3:
                X = self.conv3(X)
            Y += X
            return torch.relu(Y)

    class ResNet(nn.Module):
        def __init__(self, arch, n_outputs=10):
            super(ResNet, self).__init__()
            self.model = nn.Sequential(
                nn.Sequential(
                    nn.LazyConv2d(out_channels=64, kernel_size=7, stride=2, padding=3),
                    nn.LazyBatchNorm2d(),
                    nn.ReLU(),
                    nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
                )
            )

            for i, (num_residuals, num_channels) in enumerate(arch):
                self.model.add_module(
                    name=f'block_{i}', module=self.block(num_residuals, num_channels, first_block=(i == 0))
                )

            self.model.add_module(
                name='last',
                module=nn.Sequential(
                    nn.AdaptiveAvgPool2d((1, 1)),
                    nn.Flatten(),
                    nn.LazyLinear(n_outputs)
                )
            )

        def block(self, num_residuals, num_channels, first_block=False):
            blk = []
            for i in range(num_residuals):
                if i == 0 and not first_block:
                    blk.append(Residual(num_channels=num_channels, use_1x1conv=True, strides=2))
                else:
                    blk.append(Residual(num_channels=num_channels))
            return nn.Sequential(*blk)

        def forward(self, x):
            x = self.model(x)
            return x

    my_model = ResNet(arch=((2, 64), (2, 128), (2, 256), (2, 512)), n_outputs=10)

    return my_model

In [21]:
def train_model(model, train_loader, val_loader, epochs, criterion, optimizer, device):
    model.to(device)

    for epoch in range(epochs):
        model.train()
        train_loss = 0
        correct_train = 0
        total_train = 0

        # Training
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            correct_train += (predicted == labels).sum().item()
            total_train += labels.size(0)

        train_accuracy = 100 * correct_train / total_train

        # Validation
        model.eval()
        val_loss = 0
        correct_val = 0
        total_val = 0

        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)

                outputs = model(images)
                loss = criterion(outputs, labels)
                val_loss += loss.item()

                _, predicted = torch.max(outputs, 1)
                correct_val += (predicted == labels).sum().item()
                total_val += labels.size(0)

        val_accuracy = 100 * correct_val / total_val

        # Logging to wandb
        wandb.log({
            "Epoch": epoch + 1,
            "Training Loss": train_loss / len(train_loader),
            "Validation Loss": val_loss / len(val_loader),
            "Training Accuracy": train_accuracy,
            "Validation Accuracy": val_accuracy
        })

        print(
            f"Epoch {epoch + 1}/{epochs}, "
            f"Train Loss: {train_loss / len(train_loader):.4f}, "
            f"Val Loss: {val_loss / len(val_loader):.4f}, "
            f"Train Acc: {train_accuracy:.2f}%, "
            f"Val Acc: {val_accuracy:.2f}%"
        )

In [22]:
def train_model(model, train_loader, val_loader, epochs, criterion, optimizer, device):
    model.to(device)

    for epoch in range(epochs):
        model.train()
        train_loss = 0
        correct_train = 0
        total_train = 0

        # Training
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            correct_train += (predicted == labels).sum().item()
            total_train += labels.size(0)

        train_accuracy = 100 * correct_train / total_train

        # Validation
        model.eval()
        val_loss = 0
        correct_val = 0
        total_val = 0

        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)

                outputs = model(images)
                loss = criterion(outputs, labels)
                val_loss += loss.item()

                _, predicted = torch.max(outputs, 1)
                correct_val += (predicted == labels).sum().item()
                total_val += labels.size(0)

        val_accuracy = 100 * correct_val / total_val

        # Logging to wandb
        wandb.log({
            "Epoch": epoch + 1,
            "Training Loss": train_loss / len(train_loader),
            "Validation Loss": val_loss / len(val_loader),
            "Training Accuracy": train_accuracy,
            "Validation Accuracy": val_accuracy
        })

        print(
            f"Epoch {epoch + 1}/{epochs}, "
            f"Train Loss: {train_loss / len(train_loader):.4f}, "
            f"Val Loss: {val_loss / len(val_loader):.4f}, "
            f"Train Acc: {train_accuracy:.2f}%, "
            f"Val Acc: {val_accuracy:.2f}%"
        )

In [23]:
def test_model(model, test_loader, device):
    model.eval()
    correct = 0
    total = 0

    all_labels = []
    all_preds = []

    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)

            all_labels.extend(labels.cpu().numpy())
            all_preds.extend(predicted.cpu().numpy())

    accuracy = 100 * correct / total
    print(f"Test Accuracy: {accuracy:.2f}%")
    wandb.log({"Test Accuracy": accuracy})

    return all_labels, all_preds

In [24]:
def visualize_predictions(model, test_dataset, device):
    model.eval()
    indices = random.sample(range(len(test_dataset)), 10)
    images = []
    true_labels = []
    predicted_labels = []

    with torch.no_grad():
        for idx in indices:
            image, label = test_dataset[idx]
            images.append(image.squeeze(0).cpu().numpy())
            true_labels.append(label)

            image = image.unsqueeze(0).to(device)
            output = model(image)
            _, predicted = torch.max(output, 1)
            predicted_labels.append(predicted.item())

    # 이미지 출력 및 결과 비교
    for i in range(10):
        plt.figure(figsize=(2, 2))
        plt.imshow(images[i], cmap='gray')
        plt.title(f"Label: {class_labels[true_labels[i]]}\nPredicted: {class_labels[predicted_labels[i]]}")
        plt.axis('off')
        plt.show()

        if true_labels[i] == predicted_labels[i]:
            print(f"Correct Prediction for Image {i+1}")
        else:
            print(f"Incorrect Prediction for Image {i+1}")

In [25]:
if __name__ == "__main__":
    # 현재 시간 기록
    current_time_str = datetime.now().astimezone().strftime('%Y-%m-%d_%H-%M-%S')

    # 설정
    config = {'batch_size': 64, 'epochs': 10, 'learning_rate': 0.001, 'weight_decay': 1e-4}
    wandb.init(
        mode="online",
        project="fashion-mnist-resnet",
        notes="Fashion MNIST with Custom ResNet and Regularization",
        tags=["ResNet", "FashionMNIST", "Dropout", "Weight Decay"],
        name=current_time_str,
        config=config
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # 데이터 로드
    train_loader, val_loader = get_fashion_mnist_data()
    test_dataset, test_loader = get_fashion_mnist_test_data()

    # 모델 초기화
    model = get_resnet_model()

    # 모델 구조 출력
    print("\nModel Summary:")
    summary(model, input_size=(wandb.config.batch_size, 1, 64, 64), device=device.type)

    # 손실 함수 및 옵티마이저 정의
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=config['learning_rate'], weight_decay=config['weight_decay'])

    # 학습
    train_model(model, train_loader, val_loader, config['epochs'], criterion, optimizer, device)

    # 테스트
    all_labels, all_preds = test_model(model, test_loader, device)

    # 예측 결과 시각화
    visualize_predictions(model, test_dataset, device)

    wandb.finish()

0,1
Epoch,▁█
Training Accuracy,▁█
Training Loss,█▁
Validation Accuracy,▁█
Validation Loss,█▁

0,1
Epoch,2.0
Training Accuracy,90.41818
Training Loss,0.26392
Validation Accuracy,91.56
Validation Loss,0.23454


Num Train Samples:  55000
Num Validation Samples:  5000
Sample Shape:  torch.Size([1, 64, 64])
Number of Data Loading Workers: 16
Num Test Samples:  10000
Sample Shape:  torch.Size([1, 64, 64])

Model Summary:
Epoch 1/10, Train Loss: 0.3945, Val Loss: 0.2871, Train Acc: 85.67%, Val Acc: 89.48%


KeyboardInterrupt: 

Exception in thread ChkStopThr:
Traceback (most recent call last):
  File "c:\Users\ajhaj\anaconda3\envs\link_dl\lib\threading.py", line 1016, in _bootstrap_inner
    self.run()
  File "c:\Users\ajhaj\anaconda3\envs\link_dl\lib\site-packages\ipykernel\ipkernel.py", line 766, in run_closure
    _threading_Thread_run(self)
  File "c:\Users\ajhaj\anaconda3\envs\link_dl\lib\threading.py", line 953, in run
    self._target(*self._args, **self._kwargs)
  File "c:\Users\ajhaj\anaconda3\envs\link_dl\lib\site-packages\wandb\sdk\wandb_run.py", line 305, in check_stop_status
    self._loop_check_status(
  File "c:\Users\ajhaj\anaconda3\envs\link_dl\lib\site-packages\wandb\sdk\wandb_run.py", line 235, in _loop_check_status
    local_handle = request()
  File "c:\Users\ajhaj\anaconda3\envs\link_dl\lib\site-packages\wandb\sdk\interface\interface.py", line 896, in deliver_stop_status
    return self._deliver_stop_status(status)
  File "c:\Users\ajhaj\anaconda3\envs\link_dl\lib\site-packages\wandb\sdk