In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.models import resnet50, googlenet
import matplotlib.pyplot as plt

from tqdm.auto import tqdm

In [None]:
# GPU 사용 가능 여부 확인
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [None]:
# 데이터 전처리 및 증강
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.Resize(224),  # ResNet과 GoogLeNet의 입력 크기에 맞춤
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

transform_test = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

In [4]:
# CIFAR-10 데이터셋 로드
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                       download=True, transform=transform_train)
trainloader = DataLoader(trainset, batch_size=64, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                      download=True, transform=transform_test)
testloader = DataLoader(testset, batch_size=64, shuffle=False, num_workers=2)



In [5]:
# 데이터셋 레이블
classes = ('airplane', 'automobile', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')

## 모델 정의

In [6]:
# resnet50
resnet_model = resnet50(pretrained=True)

# backbone 고정
for param in resnet_model.parameters():
    param.requires_grad = False

num_ftrs = resnet_model.fc.in_features
resnet_model.fc = nn.Linear(num_ftrs, 10)



In [7]:
# googlenet
googlenet_model = googlenet(pretrained=True)

# backbone 고정
for param in googlenet_model.parameters():
    param.requires_grad = False

num_ftrs = googlenet_model.fc.in_features
googlenet_model.fc = nn.Linear(num_ftrs, 10)



## 모델 학습

In [8]:
# 학습 함수
def train_model(model, criterion, optimizer, num_epochs=5):
    best_acc = 0.0
    train_losses = []
    test_accuracies = []

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0

        from tqdm.auto import tqdm
        for i, (inputs, labels) in enumerate(trainloader):
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

            if i % 100 == 99:
                print(f'[{epoch + 1}, {i + 1}] loss: {running_loss / 100:.3f}')
                train_losses.append(running_loss / 100)
                running_loss = 0.0

        # 테스트 정확도 계산
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for inputs, labels in testloader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        acc = 100 * correct / total
        test_accuracies.append(acc)
        print(f'Accuracy on test images: {acc}%')

    return train_losses, test_accuracies

In [9]:
# ResNet 학습
from tqdm.auto import tqdm
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(resnet_model.fc.parameters(), lr=0.001, momentum=0.9)
resnet_losses, resnet_accuracies = train_model(resnet_model.to(device), criterion, optimizer)

# GoogLeNet 학습
from tqdm.auto import tqdm
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(googlenet_model.fc.parameters(), lr=0.001, momentum=0.9)
googlenet_losses, googlenet_accuracies = train_model(googlenet_model.to(device), criterion, optimizer)

[1, 100] loss: 1.834
[1, 200] loss: 1.293
[1, 300] loss: 1.098
[1, 400] loss: 0.995
[1, 500] loss: 0.915
[1, 600] loss: 0.889
[1, 700] loss: 0.855


KeyboardInterrupt: 

In [None]:
# 결과 시각화
plt.figure(figsize=(12, 4))

plt.subplot(1, 2, 1)
plt.plot(resnet_losses, label='ResNet')
plt.plot(googlenet_losses, label='GoogLeNet')
plt.title('Training Loss')
plt.xlabel('Iterations (x100)')
plt.ylabel('Loss')
plt.legend()

plt.subplot(1, 2, 2)
plt.plot(resnet_accuracies, label='ResNet')
plt.plot(googlenet_accuracies, label='GoogLeNet')
plt.title('Train Accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy (%)')
plt.legend()

plt.tight_layout()
plt.show()

## 테스트

In [None]:
def test_model(model):
    model.eval()
    correct = 0
    total = 0
    class_correct = list(0. for i in range(10))
    class_total = list(0. for i in range(10))

    with torch.no_grad():
        for inputs, labels in testloader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            c = (predicted == labels).squeeze()
            for i in range(len(labels)):
                label = labels[i]
                class_correct[label] += c[i].item()
                class_total[label] += 1

    print(f'Accuracy on test images: {100 * correct / total}%')
    for i in range(10):
        print(f'Accuracy of {classes[i]}: {100 * class_correct[i] / class_total[i]}%')

print("\nResNet Final Results:")
test_model(resnet_model)

print("\nGoogLeNet Final Results:")
test_model(googlenet_model)