# 멀티모달 스팸 탐지 모델 프로토타입 🔍

이 노트북은 ViT + BERT + SVM을 사용한 멀티모달 스팸 탐지 모델의 프로토타입 구현입니다.

## 주요 구성 요소
1. 이미지 처리: Vision Transformer (ViT)
2. 텍스트 처리: BERT
3. 분류기: SVM

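## 데이터셋
- Dredze Email Dataset (이미지 스팸)
- SpamAssassin Dataset (텍스트 스팸)

> 참고: 이 모델은 같은 이메일에서 나온 (이미지, 텍스트, 레이블) 쌍을 하나의 샘플로 사용합니다. 위 두 데이터셋은 서로 다른 이메일로 구성되어 있으므로, `load_data`에서 두 데이터를 어떻게 짝지을지(또는 한쪽 모달리티가 없는 샘플을 어떻게 처리할지)를 먼저 정해야 합니다.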

## 1. 환경 설정 및 라이브러리 설치

In [None]:
# Install the Python packages used in this notebook
!pip install transformers torch torchvision scikit-learn pandas numpy pillow tqdm
!nvidia-smi  # Check GPU availability

In [None]:
import torch
import torch.nn as nn
from transformers import ViTFeatureExtractor, ViTModel
from transformers import BertTokenizer, BertModel
from sklearn.svm import SVC
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
import pandas as pd
import numpy as np
from PIL import Image
import os
from tqdm import tqdm

## 2. Google Drive 연동 및 데이터 준비

In [None]:
import os

from google.colab import drive
drive.mount('/content/drive')

# Create the project directory layout.
# This is done in Python rather than with
# `!mkdir -p '.../{data,models,checkpoints}'`: single quotes suppress the
# shell's brace expansion, so that command created ONE directory literally
# named "{data,models,checkpoints}" instead of the three subdirectories.
project_dir = '/content/drive/MyDrive/spam_detection_project'
for subdir in ('data', 'models', 'checkpoints'):
    os.makedirs(os.path.join(project_dir, subdir), exist_ok=True)

## 3. 모델 정의

In [None]:
class MultimodalSpamDetector:
    """Spam classifier over (image, text) pairs: ViT + BERT features fed to an SVM.

    For each sample the [CLS] embedding of a pretrained ViT (image) and a
    pretrained BERT (text) is extracted, the two vectors are concatenated,
    and an RBF-kernel SVC is trained on the result. Only the SVM is fitted;
    the encoders are used as frozen feature extractors.
    """

    def __init__(self):
        # ViT: image preprocessing + encoder
        self.vit_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224')
        self.vit = ViTModel.from_pretrained('google/vit-base-patch16-224')

        # BERT: tokenizer + encoder
        self.bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        self.bert = BertModel.from_pretrained('bert-base-uncased')

        # SVM classifier; probability=True is required for predict_proba()
        self.svm = SVC(kernel='rbf', probability=True)

        # Run the encoders on the GPU when available.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # eval() makes the frozen-feature-extractor intent explicit (disables dropout).
        self.vit.to(self.device).eval()
        self.bert.to(self.device).eval()

    def extract_image_features(self, image_path):
        """Return the ViT [CLS] embedding of one image as an array of shape (1, hidden_size).

        The image is converted to RGB first: spam images are often palette
        (GIF), RGBA or grayscale, which the 3-channel ViT preprocessing rejects.
        """
        # The context manager closes the file; convert() returns an in-memory
        # copy, so `image` stays valid after the file handle is released.
        with Image.open(image_path) as img:
            image = img.convert('RGB')

        inputs = self.vit_extractor(images=image, return_tensors="pt")
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = self.vit(**inputs)
            # Token 0 is the [CLS] token
            image_features = outputs.last_hidden_state[:, 0, :].cpu().numpy()

        return image_features

    def extract_text_features(self, text):
        """Return the BERT [CLS] embedding of one text as an array of shape (1, hidden_size).

        Text longer than 512 tokens is truncated. A missing body (None or a
        pandas-style NaN float) is treated as the empty string instead of
        crashing the tokenizer.
        """
        if text is None or (isinstance(text, float) and np.isnan(text)):
            text = ""

        inputs = self.bert_tokenizer(text, padding=True, truncation=True, max_length=512,
                                     return_tensors="pt")
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = self.bert(**inputs)
            # Token 0 is the [CLS] token
            text_features = outputs.last_hidden_state[:, 0, :].cpu().numpy()

        return text_features

    def combine_features(self, image_features, text_features):
        """Fuse modalities by simple concatenation along the feature axis (axis=1)."""
        return np.concatenate([image_features, text_features], axis=1)

    def train(self, image_paths, texts, labels):
        """Extract fused features for every sample and fit the SVM.

        Args:
            image_paths: image file paths, one per sample.
            texts: email texts, aligned index-by-index with ``image_paths``.
            labels: class labels, aligned with the two lists above.

        Raises:
            ValueError: if no samples are given or the three inputs differ in
                length (``zip`` would otherwise silently drop samples).
        """
        n_samples = len(image_paths)
        if n_samples == 0:
            raise ValueError('train() needs at least one sample; got empty image_paths')
        if len(texts) != n_samples or len(labels) != n_samples:
            raise ValueError(
                'image_paths, texts and labels must have equal length '
                f'(got {n_samples}, {len(texts)}, {len(labels)})'
            )

        combined_features = []
        for img_path, text in tqdm(zip(image_paths, texts), total=n_samples):
            image_feat = self.extract_image_features(img_path)
            text_feat = self.extract_text_features(text)
            combined = self.combine_features(image_feat, text_feat)
            # Drop the batch axis of size 1 -> one feature row per sample
            combined_features.append(combined[0])

        X = np.array(combined_features)
        self.svm.fit(X, labels)

    def predict(self, image_path, text):
        """Classify one (image, text) pair.

        Returns:
            ``(labels, probabilities)``, as returned by ``SVC.predict`` and
            ``SVC.predict_proba`` for a single-row input: an array of shape (1,)
            and an array of shape (1, n_classes).
        """
        image_feat = self.extract_image_features(image_path)
        text_feat = self.extract_text_features(text)
        combined = self.combine_features(image_feat, text_feat)

        return self.svm.predict(combined), self.svm.predict_proba(combined)

## 4. 데이터 로딩 및 전처리

In [None]:
def load_data(data_dir):
    """Load paired (image, text, label) samples from ``data_dir``.

    Placeholder: the real Dredze / SpamAssassin loading logic still has to be
    written, and ``data_dir`` is currently unused. Until then this returns
    three empty lists.

    Returns:
        tuple[list, list, list]: ``(image_paths, texts, labels)``, parallel
        lists where position ``i`` in each one describes the same sample.
    """
    # TODO: implement loading of the Dredze Dataset and SpamAssassin Dataset here.
    image_paths, texts, labels = [], [], []
    return image_paths, texts, labels

## 5. 모델 학습 및 평가

In [None]:
from sklearn.model_selection import train_test_split

# Load data
data_dir = '/content/drive/MyDrive/spam_detection_project/data'
image_paths, texts, labels = load_data(data_dir)
if len(image_paths) == 0:
    raise ValueError(
        f'load_data() returned no samples from {data_dir}; implement the loading logic first'
    )

# Hold out a test split so the metrics measure generalization. Scoring on
# samples the SVM was trained on only measures memorization.
# stratify keeps the class ratio equal in both splits (needs >= 2 samples per class).
(train_image_paths, test_image_paths,
 train_texts, test_texts,
 train_labels, test_labels) = train_test_split(
    image_paths, texts, labels,
    test_size=0.2, random_state=42, stratify=labels,
)

# Initialize the model and train it on the training split only
model = MultimodalSpamDetector()
model.train(train_image_paths, train_texts, train_labels)

# Evaluate on the held-out split
predictions = []
for img_path, text in tqdm(zip(test_image_paths, test_texts), total=len(test_image_paths)):
    pred, _prob = model.predict(img_path, text)
    # predict() returns an array of shape (1,); keep the scalar label
    predictions.append(pred[0])

# Metrics. average='binary' assumes 0/1 labels with 1 = spam (pos_label=1);
# adjust if load_data uses a different encoding.
accuracy = accuracy_score(test_labels, predictions)
precision, recall, f1, _ = precision_recall_fscore_support(
    test_labels, predictions, average='binary', zero_division=0
)

print(f'Accuracy: {accuracy:.4f}')
print(f'Precision: {precision:.4f}')
print(f'Recall: {recall:.4f}')
print(f'F1-Score: {f1:.4f}')

## 6. 모델 저장

In [None]:
import os

import joblib

# Persist only the fitted SVM. The ViT/BERT encoders are loaded from their
# pretrained checkpoints and are never trained here.
# NOTE: joblib files are pickles; only load them from trusted locations.
save_path = '/content/drive/MyDrive/spam_detection_project/models/svm_model.joblib'
# Ensure the target directory exists, otherwise joblib.dump raises FileNotFoundError.
os.makedirs(os.path.dirname(save_path), exist_ok=True)
joblib.dump(model.svm, save_path)