# 03. Traffic Sign Classification - 결과 시각화

학습된 모델을 사용하여 테스트 이미지의 예측 결과를 시각화합니다.

In [None]:
# Install the packages this notebook uses; -q keeps pip's output quiet.
!pip install transformers pillow torch matplotlib -q

In [None]:
import os
import torch
import matplotlib.pyplot as plt
from PIL import Image
from pathlib import Path
from torch.utils.data import Dataset as TorchDataset
from transformers import AutoImageProcessor, AutoModelForImageClassification
from google.colab import drive

# Mount Google Drive so the project folder and the saved model are reachable.
drive.mount('/content/drive')

# Every relative path used below is resolved against the project directory.
WORK_DIR = '/content/drive/MyDrive/2026_AI_Advanced_Study-main/3차시/02_Traffic_Sign_Classification'
os.chdir(WORK_DIR)

MODEL_PATH = 'runs/classification/final_model'

# Run on the GPU when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Load the saved classifier and the image processor stored alongside it.
model = AutoModelForImageClassification.from_pretrained(MODEL_PATH)
processor = AutoImageProcessor.from_pretrained(MODEL_PATH)

# Move the weights to the chosen device and switch to inference mode.
model.to(device)
model.eval()

print(f"✅ 모델 로드 완료 (Device: {device})")

In [None]:
# Custom dataset class definition
class TrafficSignDataset(TorchDataset):
    """Dataset that loads local image files arranged in per-class folders.

    Images are expected under ``<data_dir>/class_<k>/`` for
    ``k = 0 .. num_classes - 1``. Class folders that do not exist are skipped,
    so a missing ``data_dir`` yields an empty dataset rather than an error.

    Args:
        data_dir: Root directory that contains the ``class_<k>`` folders.
        processor: Stored on the instance as ``self.processor``; it is NOT
            applied by ``__getitem__``, which always returns the raw image.
        num_classes: Number of ``class_<k>`` folders to scan (default 5).
        patterns: Glob pattern(s) selecting image files inside each class
            folder (default ``('*.jpg',)``). A single string is accepted too.

    Attributes:
        samples: List of ``(image_path, class_index)`` tuples, grouped by
            class index and sorted by path within each class.
    """

    def __init__(self, data_dir, processor=None, num_classes=5, patterns=('*.jpg',)):
        self.data_dir = Path(data_dir)
        self.processor = processor
        self.samples = []

        # A bare string would otherwise be iterated character by character.
        if isinstance(patterns, str):
            patterns = (patterns,)

        # Collect the images of each class in turn.
        for class_idx in range(num_classes):
            class_dir = self.data_dir / f'class_{class_idx}'
            if not class_dir.is_dir():
                continue
            # A set removes duplicates when several patterns match the same file.
            matched = {
                img_path
                for pattern in patterns
                for img_path in class_dir.glob(pattern)
            }
            self.samples.extend(
                (img_path, class_idx) for img_path in sorted(matched)
            )

    def __len__(self):
        """Return the number of collected images."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``{'image': <RGB PIL image>, 'label': <class index>}``."""
        img_path, label = self.samples[idx]
        # convert() returns a separate in-memory image, so the file handle can
        # be released as soon as the conversion is done.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        return {'image': image, 'label': label}

# Build the test split from the per-class folders under data/images/test.
# No processor is attached here.
test_dataset = TrafficSignDataset(
    data_dir='data/images/test',
    processor=None,
)
print(f"✅ 테스트 데이터: {len(test_dataset)}장")

In [None]:
# Predict and visualize
os.makedirs('runs/classification_pred', exist_ok=True)

# The grid has 3 x 4 = 12 cells; fewer images are drawn when the test set is smaller.
# NOTE(review): samples are taken from the start of the dataset, which is
# ordered by class, so the grid likely shows a single class — confirm this is intended.
num_display = min(12, len(test_dataset))
if num_display == 0:
    # Without this, an empty grid would be saved and reported as a success.
    print("⚠️ 테스트 이미지가 없습니다. 데이터 경로를 확인하세요.")

fig, axes = plt.subplots(3, 4, figsize=(16, 12))
axes = axes.flatten()

# Switch every axis off up front so cells that receive no image stay blank
# instead of showing an empty frame with ticks.
for ax in axes:
    ax.axis('off')

for idx in range(num_display):
    item = test_dataset[idx]
    image = item['image']
    true_label = item['label']

    # Preprocess the image and move the resulting tensors to the model's device.
    inputs = processor(images=image, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        outputs = model(**inputs)
        # One image per forward pass, so row 0 holds its class scores.
        logits = outputs.logits[0]
        pred_label = torch.argmax(logits, dim=-1).cpu().item()
        confidence = torch.softmax(logits, dim=-1)[pred_label].cpu().item()

    axes[idx].imshow(image)

    # Green title for a correct prediction, red for a wrong one.
    color = 'green' if pred_label == true_label else 'red'
    title = f"True: Class {true_label}\nPred: Class {pred_label}\nConf: {confidence:.2f}"
    axes[idx].set_title(title, fontsize=10, color=color, fontweight='bold')

plt.tight_layout()
plt.savefig('runs/classification_pred/predictions.png', dpi=150, bbox_inches='tight')
plt.show()

print("✅ 예측 결과 시각화 완료: runs/classification_pred/predictions.png")