# 텍스트 기반 스팸 탐지 모델 (BERT + SVM) with W&B 모니터링 🔍

이 노트북은 BERT와 SVM을 결합하여 텍스트 기반의 스팸 탐지 모델을 구현하고, Weights & Biases를 통해 실험을 모니터링합니다.

## 실험 설계
1. BERT를 특징 추출기로 사용
2. BERT의 [CLS] 토큰 임베딩을 SVM의 입력으로 사용
3. SVM으로 최종 분류 수행
4. W&B로 실험 결과 추적 및 시각화

## 데이터셋
- SpamAssassin Public Corpus (Kaggle `beatoa/spamassassin-public-corpus`의 `spam_2`, `easy_ham`, `hard_ham` 폴더 사용)

## 1. 환경 설정 및 라이브러리 설치

In [None]:
!pip install torch transformers scikit-learn pandas numpy tqdm kagglehub wandb email
!nvidia-smi  # GPU 확인

In [None]:
import torch
from transformers import BertTokenizer, BertModel
from sklearn.svm import SVC
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix, roc_curve, auc
from sklearn.model_selection import train_test_split
import pandas as pd
import numpy as np
import os
import email
from tqdm import tqdm
import kagglehub
import wandb
import seaborn as sns
import matplotlib.pyplot as plt
from datetime import datetime

## 2. W&B 설정

In [None]:
# wandb 로그인
!wandb login

# 실험 설정
config = {
    "architecture": "BERT-SVM",
    "dataset": "SpamAssassin",
    "bert_model": "bert-base-uncased",
    "max_length": 512,
    "svm_kernel": "rbf",
    "svm_C": 1.0,
    "batch_size": 32,
    "random_state": 42,
    "test_size": 0.2,
    "experiment_timestamp": datetime.now().strftime("%Y%m%d_%H%M%S")
}

# wandb 실험 초기화
run = wandb.init(
    project="text-spam-detection",
    name=f"BERT-SVM_{config['experiment_timestamp']}",
    config=config
)

## 3. 데이터셋 다운로드 및 전처리

In [None]:
def _decode_part(part):
    """Return the decoded text of one message part ("" if it has no payload)."""
    payload = part.get_payload(decode=True)
    if payload is None:
        # get_payload(decode=True) returns None for container (multipart) parts.
        return ""

    # Honour the charset declared on the part; latin-1 is the fallback because
    # it can decode any byte sequence.
    charset = part.get_content_charset() or 'latin-1'
    try:
        return payload.decode(charset, errors='ignore')
    except (LookupError, ValueError):
        # Unknown or malformed charset name in the header.
        return payload.decode('latin-1', errors='ignore')


def parse_email(file_path):
    """Parse an email file and return its body text.

    For multipart messages the text of every ``text/plain`` part is
    concatenated in document order. For single-part messages the payload is
    returned regardless of its content type. Transfer encodings
    (base64 / quoted-printable) are undone, and the bytes are decoded with the
    charset declared on the part, falling back to latin-1.

    Args:
        file_path: Path to a raw RFC 822 message file.

    Returns:
        The body text with leading and trailing whitespace removed; an empty
        string when no matching part is found.

    Raises:
        OSError: If the file cannot be opened.
    """
    # latin-1 maps every byte to a code point, so reading never fails on
    # messages with undeclared 8-bit content.
    with open(file_path, 'r', encoding='latin-1') as f:
        msg = email.message_from_file(f)

    if msg.is_multipart():
        parts = [part for part in msg.walk()
                 if part.get_content_type() == "text/plain"]
    else:
        parts = [msg]

    return "".join(_decode_part(part) for part in parts).strip()

def _load_email_dir(dir_path, label, desc, texts, labels):
    """Parse every file in ``dir_path`` and append its body/label in place.

    Args:
        dir_path: Directory holding one raw email per file.
        label: Class label stored for each successfully parsed email.
        desc: Progress-bar description.
        texts: List that receives the parsed bodies.
        labels: List that receives ``label`` once per parsed body.

    Returns:
        The number of emails that were parsed successfully.
    """
    loaded = 0
    # os.listdir order is unspecified; sorting keeps the sample order (and thus
    # the seeded train/test split downstream) identical across machines.
    for file_name in tqdm(sorted(os.listdir(dir_path)), desc=desc):
        file_path = os.path.join(dir_path, file_name)
        try:
            text = parse_email(file_path)
        except Exception as e:
            # Best effort: report the unreadable file and keep loading the rest.
            print(f"Error parsing {file_path}: {str(e)}")
        else:
            texts.append(text)
            labels.append(label)
            loaded += 1
    return loaded


def load_dataset(dataset_path):
    """Load the spam / ham emails under ``dataset_path`` and log dataset stats.

    Reads the ``spam_2/spam_2`` (label 1), ``easy_ham/easy_ham`` (label 0) and
    ``hard_ham/hard_ham`` (label 0) sub-directories, logs per-source counts and
    the spam ratio to W&B, and uploads a bar chart of the distribution.

    Args:
        dataset_path: Root directory of the downloaded corpus.

    Returns:
        A ``(texts, labels)`` tuple of equal-length lists.

    Raises:
        OSError: If one of the expected sub-directories cannot be listed.
    """
    texts = []
    labels = []

    # (stats key, sub-directory, label, progress-bar description)
    sources = (
        ("spam", 'spam_2/spam_2', 1, 'Loading spam emails'),
        ("easy_ham", 'easy_ham/easy_ham', 0, 'Loading easy ham emails'),
        ("hard_ham", 'hard_ham/hard_ham', 0, 'Loading hard ham emails'),
    )
    stats = {}
    for key, subdir, label, desc in sources:
        stats[key] = _load_email_dir(
            os.path.join(dataset_path, subdir), label, desc, texts, labels
        )

    # Guard the ratio so an empty corpus does not raise ZeroDivisionError.
    total_samples = len(texts)
    spam_ratio = stats["spam"] / total_samples if total_samples else 0.0

    # Log dataset statistics
    wandb.log({
        "dataset_stats/spam_count": stats["spam"],
        "dataset_stats/easy_ham_count": stats["easy_ham"],
        "dataset_stats/hard_ham_count": stats["hard_ham"],
        "dataset_stats/total_samples": total_samples,
        "dataset_stats/spam_ratio": spam_ratio
    })

    # Visualise the class distribution and upload it as an image
    plt.figure(figsize=(10, 6))
    plt.bar(['Spam', 'Easy Ham', 'Hard Ham'],
            [stats['spam'], stats['easy_ham'], stats['hard_ham']])
    plt.title('Dataset Distribution')
    plt.ylabel('Number of Samples')
    wandb.log({"dataset_stats/distribution": wandb.Image(plt)})
    plt.close()

    return texts, labels

# Fetch the SpamAssassin corpus via kagglehub; the result is used as the dataset root
dataset_path = kagglehub.dataset_download("beatoa/spamassassin-public-corpus")
print(f"Dataset path: {dataset_path}")

# Parse all emails into parallel text / label lists
texts, labels = load_dataset(dataset_path)

## 4. BERT 특징 추출기 정의

In [None]:
class BertFeatureExtractor:
    """Inference-only BERT encoder that maps texts to first-token embeddings.

    The tokenizer and model are loaded once in the constructor and moved to the
    GPU when one is available. Tokenization statistics and model information
    are logged to the active W&B run.
    """

    def __init__(self, model_name='bert-base-uncased', max_length=512):
        """Load the tokenizer and model and log the model configuration.

        Args:
            model_name: Name or path passed to ``from_pretrained``.
            max_length: Maximum number of tokens per sample; longer inputs are
                truncated during feature extraction.
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.tokenizer = BertTokenizer.from_pretrained(model_name)
        self.model = BertModel.from_pretrained(model_name).to(self.device)
        self.max_length = max_length

        # Log model information
        wandb.log({
            "model_info/device": str(self.device),
            "model_info/bert_model": model_name,
            "model_info/max_length": max_length
        })

    def extract_features(self, texts, batch_size=32):
        """Encode ``texts`` in batches and return the position-0 embeddings.

        Args:
            texts: Sequence of strings to encode.
            batch_size: Number of texts tokenized and encoded per forward pass.

        Returns:
            A NumPy array with one row per input text holding the last hidden
            state at token position 0 (the [CLS] position). For empty input an
            array of shape ``(0, hidden_size)`` is returned and nothing is
            logged.
        """
        if len(texts) == 0:
            # Nothing to encode: avoid dividing by zero in the statistics and
            # calling np.vstack on an empty list.
            return np.empty((0, self.model.config.hidden_size), dtype=np.float32)

        self.model.eval()
        all_features = []
        total_tokens = 0
        truncated_count = 0

        for i in tqdm(range(0, len(texts), batch_size), desc='Extracting BERT features'):
            batch_texts = texts[i:i + batch_size]

            # Tokenize the batch, padding to its longest member
            inputs = self.tokenizer(
                batch_texts,
                padding=True,
                truncation=True,
                max_length=self.max_length,
                return_tensors='pt'
            ).to(self.device)

            # Count real tokens only: the attention mask is 0 on padding, whereas
            # input_ids.numel() would include the padded positions.
            total_tokens += int(inputs.attention_mask.sum().item())
            # A sample counts as truncated when its untruncated encoding is
            # longer than max_length (this re-tokenizes each text once).
            truncated_count += sum(len(self.tokenizer.encode(text)) > self.max_length for text in batch_texts)

            # Forward pass without gradient tracking; keep the position-0 vector
            with torch.no_grad():
                outputs = self.model(**inputs)
                features = outputs.last_hidden_state[:, 0, :].cpu().numpy()
                all_features.append(features)

        # Log tokenization statistics
        wandb.log({
            "tokenization/avg_tokens_per_sample": total_tokens / len(texts),
            "tokenization/truncated_samples": truncated_count,
            "tokenization/truncation_ratio": truncated_count / len(texts)
        })

        return np.vstack(all_features)

## 5. 실험 실행

In [None]:
# Stratified train/test split
X_train, X_test, y_train, y_test = train_test_split(
    texts, labels, 
    test_size=config['test_size'], 
    random_state=config['random_state'], 
    stratify=labels
)

# Log split information
wandb.log({
    "split/train_size": len(X_train),
    "split/test_size": len(X_test),
    "split/train_spam_ratio": sum(y_train) / len(y_train),
    "split/test_spam_ratio": sum(y_test) / len(y_test)
})

# Extract BERT features
print("\nBERT 특징 추출 중...")
feature_extractor = BertFeatureExtractor(
    model_name=config['bert_model'],
    max_length=config['max_length']
)

X_train_features = feature_extractor.extract_features(X_train, batch_size=config['batch_size'])
X_test_features = feature_extractor.extract_features(X_test, batch_size=config['batch_size'])

# Log a summary of the extracted features
wandb.log({
    "features/train_shape": list(X_train_features.shape),
    "features/test_shape": list(X_test_features.shape),
    "features/mean_value": float(X_train_features.mean()),
    "features/std_value": float(X_train_features.std())
})

# Train the SVM
print("\nSVM 학습 중...")
# probability=True fits the probability calibration with internally shuffled
# cross-validation, so the seed is passed to keep predict_proba reproducible.
svm = SVC(
    kernel=config['svm_kernel'],
    C=config['svm_C'],
    probability=True,
    random_state=config['random_state']
)
svm.fit(X_train_features, y_train)

# Predict and evaluate
y_pred = svm.predict(X_test_features)
y_prob = svm.predict_proba(X_test_features)[:, 1]

# train_test_split returns plain lists here; an array is needed for the
# element-wise comparisons used as boolean masks further down.
y_test_arr = np.asarray(y_test)

# Compute metrics
accuracy = accuracy_score(y_test, y_pred)
precision, recall, f1, _ = precision_recall_fscore_support(y_test, y_pred, average='binary')
# Fix the label order so the matrix is always 2x2 and the [row][col] lookups
# in the metric logging stay valid.
conf_matrix = confusion_matrix(y_test, y_pred, labels=[0, 1])

# Compute the ROC curve
fpr, tpr, _ = roc_curve(y_test, y_prob)
roc_auc = auc(fpr, tpr)

# Print results
print("\n실험 결과:")
print(f"Accuracy: {accuracy:.4f}")
print(f"Precision: {precision:.4f}")
print(f"Recall: {recall:.4f}")
print(f"F1-Score: {f1:.4f}")
print(f"ROC AUC: {roc_auc:.4f}")

# Plot the confusion matrix
plt.figure(figsize=(8, 6))
sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues')
plt.title('Confusion Matrix')
plt.ylabel('True Label')
plt.xlabel('Predicted Label')
wandb.log({"evaluation/confusion_matrix": wandb.Image(plt)})
plt.close()

# Plot the ROC curve
plt.figure(figsize=(8, 6))
plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (AUC = {roc_auc:.2f})')
plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Receiver Operating Characteristic (ROC) Curve')
plt.legend(loc="lower right")
wandb.log({"evaluation/roc_curve": wandb.Image(plt)})
plt.close()

# Plot the predicted-probability distribution per true class.
# The masks use y_test_arr: comparing the list y_test to an int yields a single
# False, which would select nothing from y_prob.
plt.figure(figsize=(10, 6))
plt.hist(y_prob[y_test_arr == 0], bins=50, alpha=0.5, label='Normal', density=True)
plt.hist(y_prob[y_test_arr == 1], bins=50, alpha=0.5, label='Spam', density=True)
plt.xlabel('Predicted Spam Probability')
plt.ylabel('Density')
plt.title('Distribution of Predicted Probabilities')
plt.legend()
wandb.log({"evaluation/probability_distribution": wandb.Image(plt)})
plt.close()

# Log metrics (confusion matrix rows are true labels, columns are predictions)
wandb.log({
    "metrics/accuracy": accuracy,
    "metrics/precision": precision,
    "metrics/recall": recall,
    "metrics/f1": f1,
    "metrics/roc_auc": roc_auc,
    "metrics/true_positives": int(conf_matrix[1][1]),
    "metrics/false_positives": int(conf_matrix[0][1]),
    "metrics/true_negatives": int(conf_matrix[0][0]),
    "metrics/false_negatives": int(conf_matrix[1][0])
})

# Close the experiment run
wandb.finish()

## 6. 새로운 이메일에 대한 예측

In [None]:
def predict_email(text, feature_extractor, classifier):
    """Classify a single email text as spam or normal.

    Args:
        text: Email text to classify.
        feature_extractor: Object whose ``extract_features(list_of_texts)``
            returns the feature rows for the classifier.
        classifier: Fitted classifier exposing ``predict`` and
            ``predict_proba``.

    Returns:
        A dict with the verdict string under ``'prediction'`` and the class
        probabilities under ``'spam_prob'`` (index 1) and ``'normal_prob'``
        (index 0).
    """
    # Encode the text as a one-element batch
    embedding = feature_extractor.extract_features([text])

    # Hard label first, then the per-class probabilities for the same row
    predicted_label = classifier.predict(embedding)[0]
    class_probs = classifier.predict_proba(embedding)[0]

    if predicted_label == 1:
        verdict = '스팸'
    else:
        verdict = '정상'

    result = {'prediction': verdict}
    result['spam_prob'] = class_probs[1]
    result['normal_prob'] = class_probs[0]
    return result

# Re-open a W&B run: the training run was closed with wandb.finish(), and
# feature extraction logs tokenization statistics to the active run.
run = wandb.init(
    project="text-spam-detection",
    name=f"BERT-SVM_{config['experiment_timestamp']}",
    config=config
)

# Sample raw email (headers + body) to classify
test_email = """
From: john.doe@company.com
Return-Path: <john.doe@company.com>
Delivered-To: recipient@localhost.com
Received: from localhost (localhost [127.0.0.1])
	by server.company.com (Postfix) with ESMTP id ABC123XYZ
	for <recipient@localhost.com>; Tue, 15 Feb 2024 10:30:15 +0900 (KST)
Received: from mail.company.com (mail.company.com [192.168.1.100])
	by server.company.com (Postfix) with ESMTP id DEF456UVW
	for <recipient@localhost.com>; Tue, 15 Feb 2024 10:30:14 +0900 (KST)
Message-Id: <20240215103014.DEF456UVW@mail.company.com>
Date: Tue, 15 Feb 2024 10:30:14 +0900 (KST)
To: team@company.com
From: "John Doe" <john.doe@company.com>
MIME-Version: 1.0
Content-Type: text/plain; charset="UTF-8"
Subject: Project Progress Report

Dear Team Members,

I would like to share this week's project progress update.

1. AI Model Development
- Completed implementation of BERT-based text classification model
- Achieved 95% test accuracy
- Hyperparameter optimization in progress

2. Data Preprocessing
- Data cleaning completed
- Labeling work 80% complete
- Additional data collection planned

3. Next Week's Plan
- Model performance improvement
- Start API development
- Documentation work

Please reply if you have any questions or comments.

Best regards,

John Doe
AI Development Team
Company Inc.
Tel: 02-123-4567
"""

# The classifier was trained on message bodies only (headers are stripped when
# the corpus is loaded), so extract the body before predicting. The leading
# newline of the literal is removed first; otherwise the parser would see an
# empty header section and treat the headers as body text.
# This sample is a single-part text/plain message, so get_payload() is a str.
test_body = email.message_from_string(test_email.lstrip()).get_payload().strip()

result = predict_email(test_body, feature_extractor, svm)

print("\n예측 결과:")
print(f"판정: {result['prediction']}")
print(f"스팸 확률: {result['spam_prob']:.4f}")
print(f"정상 확률: {result['normal_prob']:.4f}")

# Close the run opened for this prediction
wandb.finish()