In [1]:
from google.colab import drive

# Attach Google Drive to the Colab VM; the paths used later in this notebook
# (label map, model checkpoint, test images) all live under this mount point.
drive.mount(mountpoint='/content/gdrive/')

Drive already mounted at /content/gdrive/; to attempt to forcibly remount, call drive.mount("/content/gdrive/", force_remount=True).


# 1. CNN 구성

In [3]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from PIL import Image, UnidentifiedImageError
from collections import defaultdict
import random

class ImprovedCNN(nn.Module):
    """Small convolutional image classifier.

    Architecture:
      * four Conv(3x3, padding=1) -> BatchNorm -> ReLU -> MaxPool(2x2) stages
        with channel widths 3 -> 32 -> 64 -> 128 -> 256; every max-pool halves
        the spatial height and width,
      * Dropout(0.3) on the final feature map,
      * adaptive average pooling down to 1x1,
      * a classifier head: Flatten -> Linear(256, 128) -> ReLU -> Dropout(0.5)
        -> Linear(128, num_classes).

    The order of layers inside ``features`` and ``classifier`` fixes the
    ``state_dict`` key names (``features.0.weight``, ``classifier.1.weight``,
    ...), so it must stay as is for previously saved checkpoints to load.

    Args:
        num_classes: Number of output scores produced per input image.
    """

    def __init__(self, num_classes):
        super().__init__()

        # Consecutive pairs of this tuple are the (in, out) channels per stage.
        channel_widths = (3, 32, 64, 128, 256)
        feature_layers = []
        for in_ch, out_ch in zip(channel_widths[:-1], channel_widths[1:]):
            feature_layers.extend([
                nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(),
                nn.MaxPool2d(2, 2),
            ])
        feature_layers.append(nn.Dropout(0.3))
        self.features = nn.Sequential(*feature_layers)

        # Collapsing to 1x1 leaves exactly 256 values per image for the head.
        self.pool = nn.AdaptiveAvgPool2d((1, 1))

        head_layers = [
            nn.Flatten(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return the raw class scores (no softmax is applied) for batch ``x``."""
        feature_map = self.features(x)
        pooled = self.pool(feature_map)
        return self.classifier(pooled)


# 2. 레이블 인코더

In [9]:
import json

# Breed names; a name's position in this list is the integer key it gets in
# inverse_label_encoder below.
# NOTE(review): presumably this order matches the label indices used when the
# checkpoint was trained — confirm against the training notebook.
class_names = [
    "코리안 숏헤어",      # 0
    "러시안블루",        # 1
    "페르시안",          # 2
    # NOTE(review): possibly a typo for "스코티쉬 스트레이트" (Scottish Straight);
    # left unchanged because this string is written to disk and printed.
    "스코티쉬 스트레이드", # 3
    "스코티쉬 폴드",      # 4
    "시암",              # 5
    "터키시 앙고라",      # 6
    "먼치킨",            # 7
    "브리티쉬 숏헤어",    # 8
    "래그돌"             # 9
]

# Maps class index (int) -> breed name.
inverse_label_encoder = dict(enumerate(class_names))

# Single source of truth for the on-disk location of the label map.
LABEL_ENCODER_PATH = "/content/gdrive/MyDrive/03. KOREA UNI/2. Deep Learning/프로젝트/inverse_label_encoder.json"

with open(LABEL_ENCODER_PATH, "w", encoding="utf-8") as f:
    # ensure_ascii=False keeps the Korean breed names human-readable in the file.
    json.dump(inverse_label_encoder, f, ensure_ascii=False, indent=4)

with open(LABEL_ENCODER_PATH, "r", encoding="utf-8") as f:
    # JSON object keys are always strings, so the int keys written above come
    # back as "0", "1", ...; convert them back to int so that a lookup with the
    # integer class index finds the name instead of returning None.
    inverse_label_encoder = {int(idx): name for idx, name in json.load(f).items()}


# 3. 추론

In [7]:
from torchvision import transforms
def predict(image_path, model, label_encoder, transform, device):
    """Classify one image file and return the label of the top-scoring class.

    The function does not change the model's train/eval mode; callers should
    call ``model.eval()`` beforehand if the model contains dropout/batch-norm.

    Args:
        image_path: Path of the image file to open.
        model: Callable that maps a batch tensor to per-class scores; the
            predicted class is the index of the maximum along dimension 1.
        label_encoder: Mapping from class index to label. Keys may be ints or
            their decimal-string form (which is what a JSON round trip yields).
        transform: Callable that turns the RGB PIL image into a tensor; a batch
            dimension is added to its result here.
        device: Device the input tensor is moved to before the forward pass.

    Returns:
        The label stored for the predicted class index, or ``None`` if the
        mapping has neither an int nor a string key for that index.
    """
    # The context manager releases the file handle right away; convert()
    # returns an independent, fully loaded image that stays valid afterwards.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert('RGB')
    img_t = transform(image).unsqueeze(0).to(device)

    with torch.no_grad():
        outputs = model(img_t)
        _, pred = torch.max(outputs, 1)

    # The batch holds a single image, so element 0 is its predicted class.
    class_idx = pred[0].item()

    breed = label_encoder.get(class_idx)
    if breed is None:
        # Mappings loaded from JSON are keyed by strings ("0", "1", ...).
        breed = label_encoder.get(str(class_idx))
    return breed

# Image preprocessing applied before inference: resize to 224x224, convert to
# a tensor, then normalise each RGB channel.
# NOTE(review): the mean/std values equal the commonly used ImageNet channel
# statistics; presumably the same normalisation was used in training — confirm.
_preprocessing_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
]
transform = transforms.Compose(_preprocessing_steps)


# File names of the test photos; each is resolved against TEST_DATA_DIR below.
test_data_list = [
    'KakaoTalk_Photo_2025-06-07-17-46-36 004.jpeg',
    'KakaoTalk_Photo_2025-06-07-17-46-36 003.jpeg',
    'KakaoTalk_Photo_2025-06-07-17-46-35 002.jpeg',
    'KakaoTalk_Photo_2025-06-07-17-46-35 001.jpeg',
]

# Project locations on the mounted Drive, defined once instead of being
# repeated as string literals.
PROJECT_DIR = "/content/gdrive/MyDrive/03. KOREA UNI/2. Deep Learning/프로젝트"
MODEL_PATH = os.path.join(PROJECT_DIR, "MODEL", "cnn_model_top10_v3.pth")
TEST_DATA_DIR = os.path.join(PROJECT_DIR, "TEST_DATA")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ImprovedCNN(num_classes=10).to(device)
# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered checkpoint file cannot execute arbitrary code while being loaded.
# map_location lets a checkpoint saved on GPU load on a CPU-only runtime.
state_dict = torch.load(MODEL_PATH, map_location=device, weights_only=True)
model.load_state_dict(state_dict)

# Inference mode: dropout is disabled and batch-norm uses its running stats.
model.eval()

with torch.no_grad():
    for img in test_data_list:
        test_image_path = os.path.join(TEST_DATA_DIR, img)
        predicted_breed = predict(test_image_path, model, inverse_label_encoder, transform, device)
        print(f"{img} → 예측 품종: {predicted_breed}")



KakaoTalk_Photo_2025-06-07-17-46-36 004.jpeg → 예측 품종: 먼치킨
KakaoTalk_Photo_2025-06-07-17-46-36 003.jpeg → 예측 품종: 페르시안
KakaoTalk_Photo_2025-06-07-17-46-35 002.jpeg → 예측 품종: 코리안 숏헤어
KakaoTalk_Photo_2025-06-07-17-46-35 001.jpeg → 예측 품종: 코리안 숏헤어
