# 멀티모달 스팸 탐지 모델 (ViT + BERT) 🔍

이 노트북은 Vision Transformer (ViT)와 BERT를 결합하여 이미지와 텍스트 기반의 멀티모달 스팸 탐지 모델을 구현합니다.

## 데이터셋
- SpamAssassin Public Corpus (텍스트)
- Kaggle Spam Image Dataset (이미지)

## 1. 환경 설정 및 라이브러리 설치

In [None]:
# Log this runtime in to Weights & Biases so that the wandb.init() calls
# further down can upload runs.
!wandb login

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize?ref=models
[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit: 
[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mwoochang4862[0m ([33mwoochang4862-university-of-suwon[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [None]:
!pip install torch torchvision transformers scikit-learn pandas numpy pillow tqdm kagglehub email
!nvidia-smi  # GPU 확인

Collecting email
  Using cached email-4.0.2.tar.gz (1.2 MB)
  [1;31merror[0m: [1msubprocess-exited-with-error[0m
  
  [31m×[0m [32mpython setup.py egg_info[0m did not run successfully.
  [31m│[0m exit code: [1;36m1[0m
  [31m╰─>[0m See above for output.
  
  [1;35mnote[0m: This error originates from a subprocess, and is likely not a problem with pip.
  Preparing metadata (setup.py) ... [?25l[?25herror
[1;31merror[0m: [1mmetadata-generation-failed[0m

[31m×[0m Encountered error while generating package metadata.
[31m╰─>[0m See above for output.

[1;35mnote[0m: This is an issue with the package mentioned above, not pip.
[1;36mhint[0m: See above for details.
Tue May 27 13:22:31 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from transformers import ViTFeatureExtractor, ViTModel
from transformers import BertTokenizer, BertModel
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix
import pandas as pd
import numpy as np
from PIL import Image
import os
import email
from tqdm import tqdm
import kagglehub
import glob
import random
import wandb
import seaborn as sns
import matplotlib.pyplot as plt

## 2. 데이터셋 다운로드

In [None]:
# Download the SpamAssassin public corpus (email text) via kagglehub
text_dataset_path = kagglehub.dataset_download("beatoa/spamassassin-public-corpus")
print("Text dataset path:", text_dataset_path)

# Download the spam image dataset via kagglehub
image_dataset_path = kagglehub.dataset_download("asifjamal123/spam-image-dataset")
print("Image dataset path:", image_dataset_path)

Text dataset path: /kaggle/input/spamassassin-public-corpus
Image dataset path: /kaggle/input/spam-image-dataset


## 3. 데이터셋 클래스 정의

In [None]:
def parse_email(file_path):
    """Parse an email file and return its plain-text body.

    For a multipart message the decoded payloads of all ``text/plain`` parts
    are concatenated and every other part is ignored. For a single-part
    message the decoded payload is returned regardless of its content type.

    Each payload is decoded with the charset declared on its part. When no
    charset is declared, the charset is unknown, or the bytes are not valid
    in the declared charset, latin-1 is used instead (it can decode any byte
    sequence).

    Args:
        file_path: Path of a file containing a single email message.

    Returns:
        The body text with leading and trailing whitespace removed. An empty
        string is returned when no text could be extracted.
    """
    # latin-1 maps every byte to a code point, so reading the raw file never
    # raises a decode error.
    with open(file_path, 'r', encoding='latin-1') as f:
        msg = email.message_from_file(f)

    def _decode_part(part):
        payload = part.get_payload(decode=True)
        if payload is None:
            return ""
        charset = part.get_content_charset() or 'latin-1'
        try:
            return payload.decode(charset)
        except (LookupError, ValueError):
            # Unknown codec name or bytes that do not match the declared
            # charset: fall back to the byte-transparent decoding.
            return payload.decode('latin-1')

    if msg.is_multipart():
        body = "".join(
            _decode_part(part)
            for part in msg.walk()
            if part.get_content_type() == "text/plain"
        )
    else:
        body = _decode_part(msg)

    return body.strip()

class MultimodalSpamDataset(Dataset):
    """Synthetic image + email-text pairs for multimodal spam classification.

    Every readable image found in the image dataset is paired with one email
    body drawn at random from the text corpus. A pair is labelled spam (1)
    when the image or the text (or both) is spam, otherwise ham (0).

    NOTE(review): texts are drawn independently for each image, so the same
    email can be attached to several images. A random split of this dataset
    can therefore place the same text in more than one split.

    Args:
        text_dir: Root of the SpamAssassin corpus; must contain
            ``spam_2/spam_2``, ``easy_ham/easy_ham`` and ``hard_ham/hard_ham``.
        image_dir: Root of the image dataset; must contain
            ``SPAM IMAGE dataset/SpamImages/SpamImages`` and
            ``SPAM IMAGE dataset/NaturalImages/NaturalImages``.
        transform: Optional callable applied to each RGB PIL image.
        max_length: Token length every text is padded/truncated to.
        seed: Optional seed for the image/text pairing. When ``None`` the
            module-level ``random`` generator is used.

    Raises:
        ValueError: If no valid image or no parsable email is found.
    """

    def __init__(self, text_dir, image_dir, transform=None, max_length=512, seed=None):
        self.transform = transform
        self.max_length = max_length
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        self.valid_extensions = ('.jpg', '.jpeg', '.png', '.gif')

        # Image paths and their labels (1 = spam, 0 = ham)
        self.images = []
        self.image_labels = []
        self._load_images(os.path.join(image_dir, 'SPAM IMAGE dataset/SpamImages/SpamImages'), 1)
        self._load_images(os.path.join(image_dir, 'SPAM IMAGE dataset/NaturalImages/NaturalImages'), 0)

        # Email bodies and their labels (1 = spam, 0 = ham)
        self.texts = []
        self.text_labels = []
        self._load_texts(os.path.join(text_dir, 'spam_2/spam_2'), 1)
        self._load_texts(os.path.join(text_dir, 'easy_ham/easy_ham'), 0)
        self._load_texts(os.path.join(text_dir, 'hard_ham/hard_ham'), 0)

        # Fail early with a clear message instead of an obscure error from
        # the random draw or the percentage computation below.
        if not self.images:
            raise ValueError(f"No valid images found under {image_dir}")
        if not self.texts:
            raise ValueError(f"No emails could be parsed under {text_dir}")

        # Pair every image with a randomly chosen text
        rng = random.Random(seed) if seed is not None else random
        self.combined_data = []
        for image_path, image_label in zip(self.images, self.image_labels):
            text_idx = rng.randrange(len(self.texts))
            text_label = self.text_labels[text_idx]
            self.combined_data.append({
                'image_path': image_path,
                'text': self.texts[text_idx],
                # Spam as soon as either modality is spam
                'label': 1 if (image_label == 1 or text_label == 1) else 0,
                'image_label': image_label,
                'text_label': text_label
            })

        self._print_statistics()

    def _load_images(self, directory, label):
        """Append every verifiable image in ``directory`` with the given label."""
        # Sorted so that the pairing is reproducible for a fixed seed.
        for img_name in sorted(os.listdir(directory)):
            if not img_name.lower().endswith(self.valid_extensions):
                continue
            img_path = os.path.join(directory, img_name)
            try:
                with Image.open(img_path) as img:
                    img.verify()
            except Exception as e:
                print(f"Warning: Skipping corrupted image {img_path}: {str(e)}")
                continue
            self.images.append(img_path)
            self.image_labels.append(label)

    def _load_texts(self, directory, label):
        """Append the parsed body of every email in ``directory`` with the given label."""
        for text_file in sorted(os.listdir(directory)):
            text_path = os.path.join(directory, text_file)
            try:
                text = parse_email(text_path)
            except Exception as e:
                print(f"Warning: Error parsing email {text_path}: {str(e)}")
                continue
            self.texts.append(text)
            self.text_labels.append(label)

    def _print_statistics(self):
        """Print label and image/text combination counts of the paired data."""
        total = len(self.combined_data)
        spam_count = sum(1 for item in self.combined_data if item['label'] == 1)
        ham_count = total - spam_count
        print("\nDataset Statistics:")
        print(f"Total samples: {total}")
        print(f"Spam samples: {spam_count} ({spam_count/total*100:.2f}%)")
        print(f"Ham samples: {ham_count} ({ham_count/total*100:.2f}%)")

        # Counts keyed by (image_label, text_label)
        combos = {(1, 1): 0, (1, 0): 0, (0, 1): 0, (0, 0): 0}
        for item in self.combined_data:
            combos[(item['image_label'], item['text_label'])] += 1

        print("\nCombination Statistics:")
        print(f"Spam image + Spam text: {combos[(1, 1)]}")
        print(f"Spam image + Ham text: {combos[(1, 0)]}")
        print(f"Ham image + Spam text: {combos[(0, 1)]}")
        print(f"Ham image + Ham text: {combos[(0, 0)]}")

    def __len__(self):
        """Return the number of image/text pairs."""
        return len(self.combined_data)

    def __getitem__(self, idx):
        """Return one sample as a dict with image, token tensors and label.

        If the image cannot be loaded or transformed, a warning is printed
        and a placeholder is returned instead: an all-zero ``(3, 224, 224)``
        tensor when a transform is set, otherwise a black 224x224 PIL image.
        """
        item = self.combined_data[idx]

        # Load the image; the context manager closes the file handle once the
        # pixel data has been copied by convert().
        try:
            with Image.open(item['image_path']) as img:
                image = img.convert('RGB')
            if self.transform:
                image = self.transform(image)
        except Exception as e:
            print(f"Error loading image {item['image_path']}: {str(e)}")
            image = torch.zeros((3, 224, 224)) if self.transform else Image.new('RGB', (224, 224), 'black')

        # Tokenize the text to a fixed length
        encoding = self.tokenizer(
            item['text'],
            truncation=True,
            max_length=self.max_length,
            padding='max_length',
            return_tensors='pt'
        )

        return {
            'image': image,
            # squeeze(0) removes only the leading batch axis added by
            # return_tensors='pt'; a bare squeeze() would also drop the
            # sequence axis when max_length == 1.
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0),
            'label': item['label']
        }

## 4. 멀티모달 모델 정의

In [None]:
class MultimodalSpamClassifier(nn.Module):
    """Late-fusion spam classifier built on a ViT image encoder and a BERT text encoder.

    The first-token ([CLS]) hidden state of each encoder is passed through
    dropout, the two vectors are concatenated, and a two-layer MLP produces
    the class logits.

    Args:
        num_classes: Number of output classes.
        dropout: Dropout probability applied to both encoder outputs and
            inside the classification head.
    """

    def __init__(self, num_classes=2, dropout=0.1):
        super().__init__()

        # ViT image encoder. The checkpoint has no pooler weights, so they are
        # reported as newly initialised when loading; that is harmless here
        # because forward() reads last_hidden_state and never uses the pooler.
        self.vit = ViTModel.from_pretrained('google/vit-base-patch16-224')
        self.vit_dropout = nn.Dropout(dropout)

        # BERT text encoder
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        self.bert_dropout = nn.Dropout(dropout)

        # Fusion + classification head over the concatenated features
        hidden_size = self.vit.config.hidden_size + self.bert.config.hidden_size
        self.classifier = nn.Sequential(
            nn.Linear(hidden_size, 512),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(512, num_classes)
        )

    def forward(self, image, input_ids, attention_mask):
        """Return class logits for a batch of image/text pairs.

        Args:
            image: Batch of preprocessed images passed to ViT as pixel values.
            input_ids: BERT token ids.
            attention_mask: BERT attention mask matching ``input_ids``.

        Returns:
            Tensor of logits with one row per sample and ``num_classes`` columns.
        """
        # Image features: hidden state of the first ([CLS]) token
        vit_outputs = self.vit(pixel_values=image)
        image_features = self.vit_dropout(vit_outputs.last_hidden_state[:, 0])

        # Text features: hidden state of the first ([CLS]) token
        bert_outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        text_features = self.bert_dropout(bert_outputs.last_hidden_state[:, 0])

        # Concatenate both modalities along the feature axis and classify
        combined_features = torch.cat([image_features, text_features], dim=1)
        return self.classifier(combined_features)

## 5. 학습 및 평가 함수 정의

In [None]:
def _compute_and_log_metrics(all_labels, all_preds, title, log_key):
    """Compute binary classification metrics and log a confusion-matrix plot to wandb.

    Args:
        all_labels: Ground-truth class ids collected over one pass.
        all_preds: Predicted class ids, aligned with ``all_labels``.
        title: Title drawn on the confusion-matrix figure.
        log_key: wandb key the figure is logged under.

    Returns:
        Tuple ``(accuracy, precision, recall, f1)``.
    """
    accuracy = accuracy_score(all_labels, all_preds)
    # zero_division=0 yields 0.0 when a score is undefined (e.g. no positive
    # predictions), which is the library's fallback value anyway, without
    # emitting a warning.
    precision, recall, f1, _ = precision_recall_fscore_support(
        all_labels, all_preds, average='binary', zero_division=0
    )
    # Fix the label set so the matrix is always 2x2, even when a class is
    # absent from this pass.
    conf_matrix = confusion_matrix(all_labels, all_preds, labels=[0, 1])

    fig = plt.figure(figsize=(8, 6))
    sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues')
    plt.title(title)
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    wandb.log({log_key: wandb.Image(fig)})
    plt.close(fig)

    return accuracy, precision, recall, f1


def train_epoch(model, dataloader, criterion, optimizer, device, epoch):
    """Train ``model`` for one pass over ``dataloader``.

    The batch loss is logged to wandb every 10 batches, and a confusion
    matrix of the training predictions is logged at the end of the epoch.

    Args:
        model: Model called as ``model(images, input_ids, attention_mask)``.
        dataloader: Yields dicts with keys ``image``, ``input_ids``,
            ``attention_mask`` and ``label``.
        criterion: Loss function applied to ``(outputs, labels)``.
        optimizer: Optimizer stepped once per batch.
        device: Device the batch tensors are moved to.
        epoch: Zero-based epoch index (used for logging only).

    Returns:
        Tuple ``(mean_loss, accuracy, precision, recall, f1)``.
    """
    model.train()
    total_loss = 0
    all_preds = []
    all_labels = []

    for batch_idx, batch in enumerate(tqdm(dataloader)):
        images = batch['image'].to(device)
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['label'].to(device)

        optimizer.zero_grad()
        outputs = model(images, input_ids, attention_mask)
        loss = criterion(outputs, labels)

        loss.backward()
        optimizer.step()

        total_loss += loss.item()

        preds = torch.argmax(outputs, dim=1)
        all_preds.extend(preds.cpu().tolist())
        all_labels.extend(labels.cpu().tolist())

        # Per-batch logging, every 10 batches
        if batch_idx % 10 == 0:
            wandb.log({
                "batch_loss": loss.item(),
                # Global batch counter across epochs
                "batch": batch_idx + epoch * len(dataloader)
            })

    accuracy, precision, recall, f1 = _compute_and_log_metrics(
        all_labels, all_preds,
        title=f'Confusion Matrix - Epoch {epoch+1}',
        log_key="confusion_matrix",
    )

    return total_loss / len(dataloader), accuracy, precision, recall, f1


def validate(model, dataloader, criterion, device, epoch):
    """Evaluate ``model`` on ``dataloader`` without updating its weights.

    A confusion matrix of the predictions is logged to wandb.

    Args:
        model: Model called as ``model(images, input_ids, attention_mask)``.
        dataloader: Yields dicts with keys ``image``, ``input_ids``,
            ``attention_mask`` and ``label``.
        criterion: Loss function applied to ``(outputs, labels)``.
        device: Device the batch tensors are moved to.
        epoch: Zero-based epoch index (used for the plot title only).

    Returns:
        Tuple ``(mean_loss, accuracy, precision, recall, f1)``.
    """
    model.eval()
    total_loss = 0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for batch in tqdm(dataloader):
            images = batch['image'].to(device)
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['label'].to(device)

            outputs = model(images, input_ids, attention_mask)
            loss = criterion(outputs, labels)

            total_loss += loss.item()

            preds = torch.argmax(outputs, dim=1)
            all_preds.extend(preds.cpu().tolist())
            all_labels.extend(labels.cpu().tolist())

    accuracy, precision, recall, f1 = _compute_and_log_metrics(
        all_labels, all_preds,
        title=f'Validation Confusion Matrix - Epoch {epoch+1}',
        log_key="val_confusion_matrix",
    )

    return total_loss / len(dataloader), accuracy, precision, recall, f1

## 6. 모델 학습

In [None]:
# Hyperparameters
config = {
    "architecture": "ViT-BERT-Multimodal",
    "dataset": "SpamAssassin + Spam Image Dataset",
    "batch_size": 8,
    "epochs": 10,
    "learning_rate": 1e-5,
    "optimizer": "AdamW",
    "scheduler": "linear",
    "weight_decay": 0.01,
    "dropout": 0.1,
    "image_size": 224,
    "max_text_length": 512,
    "seed": 42
}

# Seed every RNG involved so that the random image/text pairing, the
# train/validation split and the weight initialisation are repeatable.
SEED = config["seed"]
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Start the wandb run
wandb.init(
    project="multimodal-spam-detection",
    config=config,
    name=f"experiment_{wandb.util.generate_id()}"
)

BATCH_SIZE = config["batch_size"]
EPOCHS = config["epochs"]
LEARNING_RATE = config["learning_rate"]
IMAGE_SIZE = config["image_size"]

# Image preprocessing. google/vit-base-patch16-224 was trained on inputs
# normalised with mean 0.5 / std 0.5 per channel (not the ImageNet statistics
# used by torchvision CNNs), so the same normalisation is applied here.
transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

# Dataset and data loaders
# NOTE(review): the dataset pairs each image with a randomly drawn email, so
# the same email text can end up in both splits; validation scores may
# therefore be optimistic.
dataset = MultimodalSpamDataset(
    text_dataset_path,
    image_dataset_path,
    transform=transform,
    max_length=config["max_text_length"]
)
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SEED)
)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)

# Model, loss, optimizer and LR schedule
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = MultimodalSpamClassifier().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=config["weight_decay"])
scheduler = optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.01, total_iters=EPOCHS)

# Let wandb track gradients and parameters of the model
wandb.watch(model, criterion, log="all", log_freq=10)

best_val_f1 = 0
for epoch in range(EPOCHS):
    print(f'\nEpoch {epoch+1}/{EPOCHS}')

    # Train
    train_loss, train_acc, train_prec, train_rec, train_f1 = train_epoch(
        model, train_loader, criterion, optimizer, device, epoch
    )
    print(f'Train Loss: {train_loss:.4f}')
    print(f'Train Metrics - Acc: {train_acc:.4f}, Prec: {train_prec:.4f}, Rec: {train_rec:.4f}, F1: {train_f1:.4f}')

    # Validate
    val_loss, val_acc, val_prec, val_rec, val_f1 = validate(
        model, val_loader, criterion, device, epoch
    )
    print(f'Val Loss: {val_loss:.4f}')
    print(f'Val Metrics - Acc: {val_acc:.4f}, Prec: {val_prec:.4f}, Rec: {val_rec:.4f}, F1: {val_f1:.4f}')

    # Record the learning rate that was actually used during this epoch
    # BEFORE stepping the scheduler; reading it afterwards would log the
    # next epoch's value.
    current_lr = optimizer.param_groups[0]["lr"]
    scheduler.step()

    # Log the epoch metrics to wandb
    wandb.log({
        "epoch": epoch + 1,
        "train_loss": train_loss,
        "train_accuracy": train_acc,
        "train_precision": train_prec,
        "train_recall": train_rec,
        "train_f1": train_f1,
        "val_loss": val_loss,
        "val_accuracy": val_acc,
        "val_precision": val_prec,
        "val_recall": val_rec,
        "val_f1": val_f1,
        "learning_rate": current_lr
    })

    # Keep the checkpoint with the best validation F1
    if val_f1 > best_val_f1:
        best_val_f1 = val_f1
        torch.save(model.state_dict(), 'best_multimodal_spam_classifier.pth')
        wandb.save('best_multimodal_spam_classifier.pth')
        print('Model saved!')

# Close the wandb run
wandb.finish()


Dataset Statistics:
Total samples: 1739
Spam samples: 1185 (68.14%)
Ham samples: 554 (31.86%)

Combination Statistics:
Spam image + Spam text: 282
Spam image + Ham text: 647
Ham image + Spam text: 256
Ham image + Ham text: 554


Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.



Epoch 1/10


100%|██████████| 174/174 [04:46<00:00,  1.65s/it]


Train Loss: 0.2812
Train Metrics - Acc: 0.8871, Prec: 0.8928, Rec: 0.9483, F1: 0.9197


100%|██████████| 44/44 [00:37<00:00,  1.17it/s]


Val Loss: 0.1109
Val Metrics - Acc: 0.9741, Prec: 0.9790, Rec: 0.9831, F1: 0.9811
Model saved!

Epoch 2/10


100%|██████████| 174/174 [04:52<00:00,  1.68s/it]


Train Loss: 0.0941
Train Metrics - Acc: 0.9763, Prec: 0.9831, Rec: 0.9821, F1: 0.9826


100%|██████████| 44/44 [00:36<00:00,  1.21it/s]


Val Loss: 0.0885
Val Metrics - Acc: 0.9799, Prec: 0.9873, Rec: 0.9831, F1: 0.9852
Model saved!

Epoch 3/10


100%|██████████| 174/174 [04:52<00:00,  1.68s/it]


Train Loss: 0.0422
Train Metrics - Acc: 0.9899, Prec: 0.9895, Rec: 0.9958, F1: 0.9926


100%|██████████| 44/44 [00:37<00:00,  1.19it/s]


Val Loss: 0.0752
Val Metrics - Acc: 0.9741, Prec: 0.9790, Rec: 0.9831, F1: 0.9811

Epoch 4/10


100%|██████████| 174/174 [04:46<00:00,  1.65s/it]


Train Loss: 0.0219
Train Metrics - Acc: 0.9942, Prec: 0.9927, Rec: 0.9989, F1: 0.9958


100%|██████████| 44/44 [00:40<00:00,  1.08it/s]


Val Loss: 0.0596
Val Metrics - Acc: 0.9799, Prec: 0.9792, Rec: 0.9916, F1: 0.9853
Model saved!

Epoch 5/10


100%|██████████| 174/174 [04:48<00:00,  1.66s/it]


Train Loss: 0.0118
Train Metrics - Acc: 0.9993, Prec: 1.0000, Rec: 0.9989, F1: 0.9995


100%|██████████| 44/44 [00:41<00:00,  1.05it/s]


Val Loss: 0.0578
Val Metrics - Acc: 0.9770, Prec: 0.9791, Rec: 0.9873, F1: 0.9832

Epoch 6/10


100%|██████████| 174/174 [04:45<00:00,  1.64s/it]


Train Loss: 0.0072
Train Metrics - Acc: 1.0000, Prec: 1.0000, Rec: 1.0000, F1: 1.0000


100%|██████████| 44/44 [00:36<00:00,  1.20it/s]


Val Loss: 0.0590
Val Metrics - Acc: 0.9828, Prec: 0.9873, Rec: 0.9873, F1: 0.9873
Model saved!

Epoch 7/10


100%|██████████| 174/174 [04:51<00:00,  1.67s/it]


Train Loss: 0.0058
Train Metrics - Acc: 1.0000, Prec: 1.0000, Rec: 1.0000, F1: 1.0000


100%|██████████| 44/44 [00:36<00:00,  1.22it/s]


Val Loss: 0.0637
Val Metrics - Acc: 0.9770, Prec: 0.9791, Rec: 0.9873, F1: 0.9832

Epoch 8/10


100%|██████████| 174/174 [04:49<00:00,  1.66s/it]


Train Loss: 0.0044
Train Metrics - Acc: 1.0000, Prec: 1.0000, Rec: 1.0000, F1: 1.0000


100%|██████████| 44/44 [00:36<00:00,  1.20it/s]


Val Loss: 0.0599
Val Metrics - Acc: 0.9856, Prec: 0.9874, Rec: 0.9916, F1: 0.9895
Model saved!

Epoch 9/10


100%|██████████| 174/174 [04:45<00:00,  1.64s/it]


Train Loss: 0.0040
Train Metrics - Acc: 1.0000, Prec: 1.0000, Rec: 1.0000, F1: 1.0000


100%|██████████| 44/44 [00:40<00:00,  1.08it/s]


Val Loss: 0.0618
Val Metrics - Acc: 0.9828, Prec: 0.9833, Rec: 0.9916, F1: 0.9874

Epoch 10/10


100%|██████████| 174/174 [04:44<00:00,  1.63s/it]


Train Loss: 0.0033
Train Metrics - Acc: 1.0000, Prec: 1.0000, Rec: 1.0000, F1: 1.0000


100%|██████████| 44/44 [00:40<00:00,  1.09it/s]

Val Loss: 0.0622
Val Metrics - Acc: 0.9828, Prec: 0.9833, Rec: 0.9916, F1: 0.9874





0,1
batch,▁▁▁▁▁▂▂▂▂▂▂▂▃▃▃▄▄▄▄▄▄▅▅▅▅▅▆▆▆▇▇▇▇▇▇▇████
batch_loss,█▄▂▂▄▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
epoch,▁▂▃▃▄▅▆▆▇█
learning_rate,█▇▆▆▅▄▃▃▂▁
train_accuracy,▁▇▇███████
train_f1,▁▆▇███████
train_loss,█▃▂▁▁▁▁▁▁▁
train_precision,▁▇▇███████
train_recall,▁▆▇███████
val_accuracy,▁▅▁▅▃▆▃█▆▆

0,1
batch,1736.0
batch_loss,0.00429
epoch,10.0
learning_rate,0.0
train_accuracy,1.0
train_f1,1.0
train_loss,0.00335
train_precision,1.0
train_recall,1.0
val_accuracy,0.98276


## 7. 모델 평가 및 분석

In [None]:
# Load the best checkpoint. map_location makes the load work even when the
# checkpoint was saved on a different device than the current one.
best_model = MultimodalSpamClassifier().to(device)
best_model.load_state_dict(
    torch.load('best_multimodal_spam_classifier.pth', map_location=device)
)

# Detailed evaluation on the full validation set.
# NOTE(review): this is the same split that was used to select the
# checkpoint, so these numbers are not a held-out test estimate.
wandb.init(project="multimodal-spam-detection", name="final_evaluation")
# validate() titles its plot with epoch + 1, so pass the last zero-based
# epoch index to avoid a title that reads "Epoch EPOCHS + 1".
val_loss, val_acc, val_prec, val_rec, val_f1 = validate(best_model, val_loader, criterion, device, EPOCHS - 1)

print('\nFinal Evaluation Results:')
print(f'Accuracy: {val_acc:.4f}')
print(f'Precision: {val_prec:.4f}')
print(f'Recall: {val_rec:.4f}')
print(f'F1-Score: {val_f1:.4f}')

# Log the final results to wandb
wandb.log({
    "final_accuracy": val_acc,
    "final_precision": val_prec,
    "final_recall": val_rec,
    "final_f1": val_f1
})
wandb.finish()

Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


100%|██████████| 44/44 [00:19<00:00,  2.31it/s]



Final Evaluation Results:
Accuracy: 0.9856
Precision: 0.9874
Recall: 0.9916
F1-Score: 0.9895


0,1
final_accuracy,▁
final_f1,▁
final_precision,▁
final_recall,▁

0,1
final_accuracy,0.98563
final_f1,0.98947
final_precision,0.98739
final_recall,0.99156


## 8. 새로운 데이터에 대한 예측

In [None]:
def predict_sample(model, image_path, text, device, transform, tokenizer=None, max_length=512):
    """Classify a single image/text pair as spam or normal.

    Args:
        model: Model called as ``model(image, input_ids, attention_mask)``
            that returns logits with class 0 = normal and class 1 = spam.
        image_path: Path of the image file to classify.
        text: Email text to classify.
        device: Device the inputs are moved to.
        transform: Callable turning an RGB PIL image into a tensor.
        tokenizer: Optional tokenizer to reuse across calls. When ``None``
            the ``bert-base-uncased`` tokenizer is loaded on every call.
        max_length: Token length the text is padded/truncated to.

    Returns:
        Dict with ``prediction`` (label string), ``spam_prob`` and
        ``normal_prob`` (softmax probabilities as floats).
    """
    model.eval()
    if tokenizer is None:
        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    # Image preprocessing; the context manager closes the file handle once
    # the pixel data has been converted.
    with Image.open(image_path) as img:
        image = transform(img.convert('RGB')).unsqueeze(0).to(device)

    # Text preprocessing
    encoding = tokenizer(
        text,
        truncation=True,
        max_length=max_length,
        padding='max_length',
        return_tensors='pt'
    )

    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)

    with torch.no_grad():
        outputs = model(image, input_ids, attention_mask)
        probs = torch.softmax(outputs, dim=1)
        prediction = torch.argmax(outputs, dim=1).item()

    return {
        'prediction': '스팸' if prediction == 1 else '정상',
        'spam_prob': probs[0][1].item(),
        'normal_prob': probs[0][0].item()
    }

In [None]:
# 예시 데이터로 테스트
test_image_path = "/content/000a01c72c26$bec5c540$00000000@zv65vp7hlihg4p_storeblog_download000.jpg"  # 테스트할 이미지 경로
test_text = """
Subject: 🎄 CHRISTMAS SPECIAL OFFER - MEN'S POWER PACK! 🎁

HELLO!

CHRISTMAS MEN'S POWER CHARGE!
SPECIAL HOLIDAY COMBO PACK - CIALIS + VIAGRA

INCREDIBLE HOLIDAY SAVINGS:
✨ 10 + 10 = $129.95 (BEST VALUE!)
✨ 20 + 20 = $249.95
✨ 30 + 30 = $319.95

🎅 HOLIDAY BONUS: 20% EXTRA DISCOUNT! 🎄

Available Products:
- Cialis Soft
- Viagra Professional
- Viagra Soft
- Generic Viagra
- Valium
- Xanax
- Soma
- Ambien
And many more!

💊 100% GENUINE PRODUCTS
✈️ WORLDWIDE SHIPPING
🔒 SECURE PAYMENT

DON'T MISS OUT ON THIS LIMITED TIME OFFER!

CLICK HERE TO ORDER NOW!
[suspicious_link_removed]

To unsubscribe, reply with "STOP"
"""  # 테스트할 이메일 텍스트

# Bail out with a message when the test image is missing; otherwise run the
# multimodal prediction and print the verdict with both class probabilities.
if not os.path.exists(test_image_path):
    print("테스트 이미지 파일을 찾을 수 없습니다.")
else:
    result = predict_sample(best_model, test_image_path, test_text, device, transform)

    print("\n예측 결과:")
    print(f"판정: {result['prediction']}")
    print(f"스팸 확률: {result['spam_prob']:.4f}")
    print(f"정상 확률: {result['normal_prob']:.4f}")


예측 결과:
판정: 스팸
스팸 확률: 1.0000
정상 확률: 0.0000


In [None]:
# 예시 데이터로 테스트
test_image_path = "/content/IMG_2757.jpg"  # 테스트할 이미지 경로
test_text = """
Subject: 🌟 Join Late Night Study Group - SWU AI Security Department!

Hello from Suwon University's AI Security Department Late Night Study Group!

We are a dedicated study group that meets every Tuesday night for intensive learning sessions.
(Photo attached: Last week's study session - Our passionate team at 2:06 AM)

📚 Study Focus Areas:
- AI Security Project Labs
- Coding Test Preparation
- Team Project Collaboration
- Career Development Insights

✨ Requirements:
- Current student in SWU AI Security Department
- Passionate and committed mindset
- Available for late-night sessions (2-4 AM)
- Strong team player attitude

💝 Benefits:
- Network with department peers
- Hands-on project experience
- Career opportunity sharing
- Snacks and beverages provided

Interested? Join our open KakaoTalk chat!
[Link]

* This semester's recruitment ends this week!
* Limited to first 3 applicants - Don't miss out!

Best regards,
SWU AI Security Late Night Study Group
"""  # 테스트할 이메일 텍스트

# Bail out with a message when the test image is missing; otherwise run the
# multimodal prediction and print the verdict with both class probabilities.
if not os.path.exists(test_image_path):
    print("테스트 이미지 파일을 찾을 수 없습니다.")
else:
    result = predict_sample(best_model, test_image_path, test_text, device, transform)

    print("\n예측 결과:")
    print(f"판정: {result['prediction']}")
    print(f"스팸 확률: {result['spam_prob']:.4f}")
    print(f"정상 확률: {result['normal_prob']:.4f}")


예측 결과:
판정: 스팸
스팸 확률: 0.9836
정상 확률: 0.0164


In [None]:
# 예시 데이터로 테스트
test_image_path = "/content/IMG_2757.jpg"  # 테스트할 이미지 경로
test_text = """
Subject: 🎄 CHRISTMAS SPECIAL OFFER - MEN'S POWER PACK! 🎁

HELLO!

CHRISTMAS MEN'S POWER CHARGE!
SPECIAL HOLIDAY COMBO PACK - CIALIS + VIAGRA

INCREDIBLE HOLIDAY SAVINGS:
✨ 10 + 10 = $129.95 (BEST VALUE!)
✨ 20 + 20 = $249.95
✨ 30 + 30 = $319.95

🎅 HOLIDAY BONUS: 20% EXTRA DISCOUNT! 🎄

Available Products:
- Cialis Soft
- Viagra Professional
- Viagra Soft
- Generic Viagra
- Valium
- Xanax
- Soma
- Ambien
And many more!

💊 100% GENUINE PRODUCTS
✈️ WORLDWIDE SHIPPING
🔒 SECURE PAYMENT

DON'T MISS OUT ON THIS LIMITED TIME OFFER!

CLICK HERE TO ORDER NOW!
[suspicious_link_removed]

To unsubscribe, reply with "STOP"
"""  # 테스트할 이메일 텍스트

# Bail out with a message when the test image is missing; otherwise run the
# multimodal prediction and print the verdict with both class probabilities.
if not os.path.exists(test_image_path):
    print("테스트 이미지 파일을 찾을 수 없습니다.")
else:
    result = predict_sample(best_model, test_image_path, test_text, device, transform)

    print("\n예측 결과:")
    print(f"판정: {result['prediction']}")
    print(f"스팸 확률: {result['spam_prob']:.4f}")
    print(f"정상 확률: {result['normal_prob']:.4f}")


예측 결과:
판정: 스팸
스팸 확률: 0.9922
정상 확률: 0.0078


In [None]:
def predict_text_only(model, text, device, tokenizer=None, max_length=512):
    """Classify an email text without a real image.

    The image input is replaced by an all-zero placeholder tensor.

    NOTE(review): the model is presumably never trained on such placeholder
    images, so text-only scores may be less reliable than predictions on a
    real image/text pair — confirm before relying on them.

    Args:
        model: Model called as ``model(image, input_ids, attention_mask)``
            that returns logits with class 0 = normal and class 1 = spam.
        text: Email text to classify.
        device: Device the inputs are moved to.
        tokenizer: Optional tokenizer to reuse across calls. When ``None``
            the ``bert-base-uncased`` tokenizer is loaded on every call.
        max_length: Token length the text is padded/truncated to.

    Returns:
        Dict with ``prediction`` (label string), ``spam_prob`` and
        ``normal_prob`` (softmax probabilities as floats).
    """
    model.eval()
    if tokenizer is None:
        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    # Text preprocessing
    encoding = tokenizer(
        text,
        truncation=True,
        max_length=max_length,
        padding='max_length',
        return_tensors='pt'
    )

    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)

    # Placeholder image: an all-zero tensor in the model's input space. No
    # normalisation is applied, so this is not the same as a preprocessed
    # black picture.
    dummy_image = torch.zeros((1, 3, 224, 224)).to(device)

    with torch.no_grad():
        outputs = model(dummy_image, input_ids, attention_mask)
        probs = torch.softmax(outputs, dim=1)
        prediction = torch.argmax(outputs, dim=1).item()

    return {
        'prediction': '스팸' if prediction == 1 else '정상',
        'spam_prob': probs[0][1].item(),
        'normal_prob': probs[0][0].item()
    }

# 텍스트만으로 테스트
test_text = """
From ilug-admin@linux.ie  Tue Aug  6 11:51:02 2002
Return-Path: <ilug-admin@linux.ie>
Delivered-To: yyyy@localhost.netnoteinc.com
Received: from localhost (localhost [127.0.0.1])
	by phobos.labs.netnoteinc.com (Postfix) with ESMTP id 9E1F5441DD
	for <jm@localhost>; Tue,  6 Aug 2002 06:48:09 -0400 (EDT)
Received: from phobos [127.0.0.1]
	by localhost with IMAP (fetchmail-5.9.0)
	for jm@localhost (single-drop); Tue, 06 Aug 2002 11:48:09 +0100 (IST)
Received: from lugh.tuatha.org (root@lugh.tuatha.org [194.125.145.45]) by
    dogma.slashnull.org (8.11.6/8.11.6) with ESMTP id g72LqWv13294 for
    <jm-ilug@jmason.org>; Fri, 2 Aug 2002 22:52:32 +0100
Received: from lugh (root@localhost [127.0.0.1]) by lugh.tuatha.org
    (8.9.3/8.9.3) with ESMTP id WAA31224; Fri, 2 Aug 2002 22:50:17 +0100
Received: from bettyjagessar.com (w142.z064000057.nyc-ny.dsl.cnc.net
    [64.0.57.142]) by lugh.tuatha.org (8.9.3/8.9.3) with ESMTP id WAA31201 for
    <ilug@linux.ie>; Fri, 2 Aug 2002 22:50:11 +0100
X-Authentication-Warning: lugh.tuatha.org: Host w142.z064000057.nyc-ny.dsl.cnc.net
    [64.0.57.142] claimed to be bettyjagessar.com
Received: from 64.0.57.142 [202.63.165.34] by bettyjagessar.com
    (SMTPD32-7.06 EVAL) id A42A7FC01F2; Fri, 02 Aug 2002 02:18:18 -0400
Message-Id: <1028311679.886@0.57.142>
Date: Fri, 02 Aug 2002 23:37:59 0530
To: ilug@linux.ie
From: "Start Now" <startnow2002@hotmail.com>
MIME-Version: 1.0
Content-Type: text/plain; charset="US-ASCII"; format=flowed
Subject: [ILUG] STOP THE MLM INSANITY
Sender: ilug-admin@linux.ie
Errors-To: ilug-admin@linux.ie
X-Mailman-Version: 1.1
Precedence: bulk
List-Id: Irish Linux Users' Group <ilug.linux.ie>
X-Beenthere: ilug@linux.ie

Greetings!

You are receiving this letter because you have expressed an interest in
receiving information about online business opportunities. If this is
erroneous then please accept my most sincere apology. This is a one-time
mailing, so no removal is necessary.

If you've been burned, betrayed, and back-stabbed by multi-level marketing,
MLM, then please read this letter. It could be the most important one that
has ever landed in your Inbox.

MULTI-LEVEL MARKETING IS A HUGE MISTAKE FOR MOST PEOPLE

MLM has failed to deliver on its promises for the past 50 years. The pursuit
of the "MLM Dream" has cost hundreds of thousands of people their friends,
their fortunes and their sacred honor. The fact is that MLM is fatally
flawed, meaning that it CANNOT work for most people.

The companies and the few who earn the big money in MLM are NOT going to
tell you the real story. FINALLY, there is someone who has the courage to
cut through the hype and lies and tell the TRUTH about MLM.

HERE'S GOOD NEWS

There IS an alternative to MLM that WORKS, and works BIG! If you haven't yet
abandoned your dreams, then you need to see this. Earning the kind of income
you've dreamed about is easier than you think!

With your permission, I'd like to send you a brief letter that will tell you
WHY MLM doesn't work for most people and will then introduce you to
something so new and refreshing that you'll wonder why you haven't heard of
this before.

I promise that there will be NO unwanted follow up, NO sales pitch, no one
will call you, and your email address will only be used to send you the
information. Period.

To receive this free, life-changing information, simply click Reply, type
"Send Info" in the Subject box and hit Send. I'll get the information to you
within 24 hours. Just look for the words MLM WALL OF SHAME in your Inbox.

Cordially,

Siddhi

P.S. Someone recently sent the letter to me and it has been the most
eye-opening, financially beneficial information I have ever received. I
honestly believe that you will feel the same way once you've read it. And
it's FREE!


------------------------------------------------------------
This email is NEVER sent unsolicited.  THIS IS NOT "SPAM". You are receiving
this email because you EXPLICITLY signed yourself up to our list with our
online signup form or through use of our FFA Links Page and E-MailDOM
systems, which have EXPLICIT terms of use which state that through its use
you agree to receive our emailings.  You may also be a member of a Altra
Computer Systems list or one of many numerous FREE Marketing Services and as
such you agreed when you signed up for such list that you would also be
receiving this emailing.
Due to the above, this email message cannot be considered unsolicitated, or
spam.
-----------------------------------------------------------




--
Irish Linux Users' Group: ilug@linux.ie
http://www.linux.ie/mailman/listinfo/ilug for (un)subscription information.
List maintainer: listmaster@linux.ie
"""

# Classify the raw email above using the text branch plus a placeholder image
result = predict_text_only(best_model, test_text, device)

# Print the verdict followed by the spam / normal probabilities
print("\n텍스트 기반 예측 결과:")
print(f"판정: {result['prediction']}")
print(f"스팸 확률: {result['spam_prob']:.4f}")
print(f"정상 확률: {result['normal_prob']:.4f}")


텍스트 기반 예측 결과:
판정: 정상
스팸 확률: 0.0263
정상 확률: 0.9737
