# 이미지 기반 스팸 탐지를 위한 ViT Fine-tuning 🔍

이 노트북은 사전 학습된 Vision Transformer(ViT, `google/vit-base-patch16-224`)를 fine-tuning하여 이미지를 스팸/정상으로 분류하는 이미지 기반 스팸 탐지 모델을 구현합니다.

## 데이터셋
- Kaggle Spam Image Dataset (`asifjamal123/spam-image-dataset`, `kagglehub`로 다운로드)

## 1. 환경 설정 및 라이브러리 설치

In [None]:
!pip install torch torchvision transformers scikit-learn pandas numpy pillow tqdm kagglehub
!nvidia-smi  # GPU 확인

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from transformers import ViTFeatureExtractor, ViTForImageClassification
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
import pandas as pd
import numpy as np
from PIL import Image
import os
from tqdm import tqdm
import kagglehub

## 2. 데이터셋 다운로드 및 준비

In [None]:
# Download the Kaggle dataset via kagglehub and keep the value it returns;
# `dataset_path` is later passed to SpamImageDataset as its root directory.
dataset_path = kagglehub.dataset_download("asifjamal123/spam-image-dataset")
print("Dataset path:", dataset_path)

In [None]:
class SpamImageDataset(Dataset):
    """Binary image dataset read from ``<image_dir>/spam`` and ``<image_dir>/ham``.

    Files under ``spam`` are labelled 1 and files under ``ham`` are labelled 0.
    Only names ending in one of ``valid_extensions`` (case-insensitive) are
    considered. Each candidate is checked with ``Image.verify()`` while the
    dataset is built; files that fail the check are skipped with a warning.

    Args:
        image_dir: Directory containing the ``spam`` and ``ham`` sub-folders.
        transform: Optional callable applied to every RGB image returned by
            ``__getitem__``.

    Raises:
        OSError: Propagated from ``os.listdir`` when ``spam`` or ``ham`` cannot
            be listed (for example, the folder does not exist).
    """

    def __init__(self, image_dir, transform=None):
        self.image_dir = image_dir
        self.transform = transform
        self.images = []
        self.labels = []
        self.valid_extensions = ('.jpg', '.jpeg', '.png')

        # Spam samples come first, then ham, exactly as the labels are appended.
        self._add_class_dir('spam', 1)
        self._add_class_dir('ham', 0)

        print(f"Successfully loaded {len(self.images)} valid images")

    def _add_class_dir(self, subdir, label):
        """Append every valid image in ``image_dir/subdir`` with ``label``."""
        class_dir = os.path.join(self.image_dir, subdir)
        # os.listdir returns names in arbitrary order; sorting keeps the
        # index -> sample mapping stable between runs.
        for img_name in sorted(os.listdir(class_dir)):
            if not img_name.lower().endswith(self.valid_extensions):
                continue
            img_path = os.path.join(class_dir, img_name)
            try:
                # Integrity check only; the pixel data is decoded later in
                # __getitem__.
                with Image.open(img_path) as img:
                    img.verify()
            except Exception as e:
                # Deliberately broad: a damaged file should be skipped, not
                # abort dataset construction.
                print(f"Warning: Skipping corrupted image {img_path}: {str(e)}")
                continue
            self.images.append(img_path)
            self.labels.append(label)

    def __len__(self):
        """Return the number of images that passed validation."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        The image is converted to RGB and passed through ``transform`` when one
        was given. If loading or transforming fails, the error is printed and
        a black 224x224 placeholder is returned with the sample's real label:
        a zero tensor of shape (3, 224, 224) when a transform is configured,
        otherwise a black PIL image.
        """
        image_path = self.images[idx]
        label = self.labels[idx]
        try:
            # The context manager closes the file handle as soon as the pixels
            # have been decoded, so handles do not pile up over an epoch.
            with Image.open(image_path) as img:
                image = img.convert('RGB')

            if self.transform:
                image = self.transform(image)

            return image, label
        except Exception as e:
            print(f"Error loading image {image_path} at runtime: {str(e)}")
            # Best-effort fallback so one bad file does not stop a whole epoch.
            if self.transform:
                return torch.zeros((3, 224, 224)), label
            return Image.new('RGB', (224, 224), 'black'), label

## 3. ViT 모델 설정

In [None]:
class ViTSpamClassifier(nn.Module):
    """Wrap a pretrained ViT image classifier and expose only its logits.

    Args:
        num_classes: Number of output labels requested from the checkpoint
            (2 by default).
    """

    def __init__(self, num_classes=2):
        super().__init__()
        # ignore_mismatched_sizes is passed alongside a custom num_labels;
        # presumably so the checkpoint still loads when its classification
        # head has a different size -- confirm against the transformers docs.
        load_options = {
            'num_labels': num_classes,
            'ignore_mismatched_sizes': True,
        }
        self.vit = ViTForImageClassification.from_pretrained(
            'google/vit-base-patch16-224', **load_options
        )

    def forward(self, pixel_values):
        """Run the wrapped model on ``pixel_values`` and return ``.logits``."""
        return self.vit(pixel_values).logits

## 4. 학습 함수 정의

In [None]:
def train_epoch(model, dataloader, criterion, optimizer, device):
    """Train ``model`` for one pass over ``dataloader`` and report metrics.

    Args:
        model: Callable module mapping an image batch to per-class scores.
        dataloader: Iterable of ``(images, labels)`` batches.
        criterion: Loss callable taking ``(outputs, labels)``.
        optimizer: Optimizer whose ``zero_grad``/``step`` are called per batch.
        device: Device that each batch is moved to before the forward pass.

    Returns:
        Tuple ``(mean_batch_loss, accuracy, precision, recall, f1)``; the loss
        is the sum of per-batch losses divided by ``len(dataloader)`` and the
        precision/recall/F1 use ``average='binary'``.
    """
    model.train()
    running_loss = 0
    predicted, expected = [], []

    for batch_images, batch_labels in tqdm(dataloader):
        batch_images = batch_images.to(device)
        batch_labels = batch_labels.to(device)

        # Standard step: clear gradients, forward, backward, update.
        optimizer.zero_grad()
        logits = model(batch_images)
        batch_loss = criterion(logits, batch_labels)
        batch_loss.backward()
        optimizer.step()

        running_loss += batch_loss.item()

        # The predicted class is the index of the largest score per sample.
        batch_preds = torch.argmax(logits, dim=1)
        predicted.extend(batch_preds.cpu().numpy())
        expected.extend(batch_labels.cpu().numpy())

    accuracy = accuracy_score(expected, predicted)
    precision, recall, f1, _ = precision_recall_fscore_support(
        expected, predicted, average='binary'
    )
    mean_loss = running_loss / len(dataloader)

    return mean_loss, accuracy, precision, recall, f1

def validate(model, dataloader, criterion, device):
    """Evaluate ``model`` on ``dataloader`` without computing gradients.

    Args:
        model: Callable module mapping an image batch to per-class scores.
        dataloader: Iterable of ``(images, labels)`` batches.
        criterion: Loss callable taking ``(outputs, labels)``.
        device: Device that each batch is moved to before the forward pass.

    Returns:
        Tuple ``(mean_batch_loss, accuracy, precision, recall, f1)``; the loss
        is the sum of per-batch losses divided by ``len(dataloader)`` and the
        precision/recall/F1 use ``average='binary'``.
    """
    model.eval()
    running_loss = 0
    predicted, expected = [], []

    with torch.no_grad():
        for batch_images, batch_labels in tqdm(dataloader):
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)

            logits = model(batch_images)
            running_loss += criterion(logits, batch_labels).item()

            # The predicted class is the index of the largest score per sample.
            batch_preds = torch.argmax(logits, dim=1)
            predicted.extend(batch_preds.cpu().numpy())
            expected.extend(batch_labels.cpu().numpy())

    accuracy = accuracy_score(expected, predicted)
    precision, recall, f1, _ = precision_recall_fscore_support(
        expected, predicted, average='binary'
    )
    mean_loss = running_loss / len(dataloader)

    return mean_loss, accuracy, precision, recall, f1

## 5. 모델 학습

In [None]:
# Hyperparameters
BATCH_SIZE = 16
EPOCHS = 10
LEARNING_RATE = 2e-5
SPLIT_SEED = 42  # fixes the train/validation membership across runs

# Input preprocessing. The google/vit-base-patch16-224 checkpoint ships a
# preprocessor config with per-channel mean/std of 0.5, not the ImageNet
# statistics, so the same normalisation is used here.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

# Dataset and loaders: 80% train / 20% validation.
dataset = SpamImageDataset(dataset_path, transform=transform)
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
# A seeded generator keeps the split identical between runs; otherwise a
# checkpoint evaluated in a later session could be scored on images it was
# trained on.
train_dataset, val_dataset = torch.utils.data.random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)

# Model, loss and optimizer
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = ViTSpamClassifier().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE)

# Training loop. Start below any reachable F1 so that the first epoch always
# writes a checkpoint; with an initial value of 0 and a strict ">" comparison,
# a run whose validation F1 stays at 0 would never save a file and the
# evaluation step would fail when loading it.
best_val_f1 = float('-inf')
for epoch in range(EPOCHS):
    print(f'\nEpoch {epoch+1}/{EPOCHS}')

    # Train
    train_loss, train_acc, train_prec, train_rec, train_f1 = train_epoch(
        model, train_loader, criterion, optimizer, device
    )
    print(f'Train Loss: {train_loss:.4f}')
    print(f'Train Metrics - Acc: {train_acc:.4f}, Prec: {train_prec:.4f}, Rec: {train_rec:.4f}, F1: {train_f1:.4f}')

    # Validate
    val_loss, val_acc, val_prec, val_rec, val_f1 = validate(
        model, val_loader, criterion, device
    )
    print(f'Val Loss: {val_loss:.4f}')
    print(f'Val Metrics - Acc: {val_acc:.4f}, Prec: {val_prec:.4f}, Rec: {val_rec:.4f}, F1: {val_f1:.4f}')

    # Keep the weights of the epoch with the best validation F1.
    if val_f1 > best_val_f1:
        best_val_f1 = val_f1
        torch.save(model.state_dict(), 'best_vit_spam_classifier.pth')
        print('Model saved!')

## 6. 모델 평가 및 분석

In [None]:
# Load the best checkpoint written by the training loop.
best_model = ViTSpamClassifier().to(device)
# map_location remaps the stored tensors onto the current device, so a
# checkpoint saved on a GPU still loads on a CPU-only runtime.
best_state = torch.load('best_vit_spam_classifier.pth', map_location=device)
best_model.load_state_dict(best_state)

# Detailed evaluation on the full validation set.
val_loss, val_acc, val_prec, val_rec, val_f1 = validate(best_model, val_loader, criterion, device)

print('\nFinal Evaluation Results:')
print(f'Accuracy: {val_acc:.4f}')
print(f'Precision: {val_prec:.4f}')
print(f'Recall: {val_rec:.4f}')
print(f'F1-Score: {val_f1:.4f}')

In [None]:
import glob  # used below but missing from the notebook's import cell


def predict_images(model, image_paths, device):
    """Classify each image file and return one result dict per readable image.

    NOTE(review): this helper was called by the notebook but never defined in
    it. The result keys and the '스팸' label are reconstructed from how the
    reporting code below consumes them; the '정상' label for the other class is
    an assumption taken from the table header -- confirm.

    Preprocessing reuses the notebook-level ``transform`` from the training
    cell so that inference sees the same input pipeline as training.

    Args:
        model: Classifier returning two scores per image; index 0 is treated
            as normal (ham) and index 1 as spam, matching the labels assigned
            by SpamImageDataset.
        image_paths: Iterable of image file paths.
        device: Device on which the forward pass runs.

    Returns:
        List of dicts with keys ``filename``, ``normal_prob``, ``spam_prob``
        and ``prediction``. Images that cannot be opened are skipped with a
        warning, so the list may be shorter than ``image_paths``.
    """
    model.eval()
    results = []
    with torch.no_grad():
        for path in image_paths:
            try:
                with Image.open(path) as img:
                    pixels = transform(img.convert('RGB'))
            except Exception as e:
                # Best effort: one unreadable file should not abort the batch.
                print(f"Warning: Skipping unreadable image {path}: {str(e)}")
                continue

            logits = model(pixels.unsqueeze(0).to(device))
            probs = torch.softmax(logits, dim=1)[0]
            normal_prob = probs[0].item()
            spam_prob = probs[1].item()
            results.append({
                'filename': os.path.basename(path),
                'normal_prob': normal_prob,
                'spam_prob': spam_prob,
                # Ties resolve to normal, the same as argmax over the scores.
                'prediction': '스팸' if spam_prob > normal_prob else '정상',
            })
    return results


# Directory holding the images to classify; change it to your own location.
test_dir = "/content/images/"  # 테스트할 이미지가 있는 디렉토리 경로로 변경해주세요
image_paths = []
for ext in ['.jpg', '.jpeg', '.png', '.gif']:
    image_paths.extend(glob.glob(os.path.join(test_dir, f'*{ext}')))
    image_paths.extend(glob.glob(os.path.join(test_dir, f'*{ext.upper()}')))
# On case-insensitive filesystems the lower- and upper-case patterns can match
# the same file; drop duplicates and keep a stable order.
image_paths = sorted(set(image_paths))

if not image_paths:
    print("테스트할 이미지를 찾을 수 없습니다.")
else:
    print(f"총 {len(image_paths)}개의 이미지를 처리합니다.")

    # Run the predictions.
    results = predict_images(best_model, image_paths, device)

    # Per-image table.
    print("\n예측 결과:")
    print("-" * 80)
    print(f"{'파일명':<40} {'정상 확률':>10} {'스팸 확률':>10} {'판정':>10}")
    print("-" * 80)

    for result in results:
        print(f"{result['filename']:<40} {result['normal_prob']:>10.4f} {result['spam_prob']:>10.4f} {result['prediction']:>10}")

    # Summary statistics.
    total_count = len(results)
    spam_count = sum(1 for r in results if r['prediction'] == '스팸')
    normal_count = total_count - spam_count

    print("\n통계:")
    print(f"전체 이미지: {total_count}개")
    # Percentages are undefined when every image was skipped; omit those
    # lines rather than divide by zero.
    if total_count:
        print(f"정상 이미지: {normal_count}개 ({normal_count/total_count*100:.1f}%)")
        print(f"스팸 이미지: {spam_count}개 ({spam_count/total_count*100:.1f}%)")