# BERT Spam Detection Demo

This notebook shows how to load the fine-tuned BERT model from our repository for the paper

> **“Harnessing BERT for Advanced Email Filtering in Cybersecurity”**  
> IEEE Xplore: https://ieeexplore.ieee.org/abstract/document/11058531

and how to run predictions on custom messages (SMS- or email-like text).

## 1. Setup and Imports

Make sure you have run BERT fine-tuning first (e.g., `python -m scripts.run_bert`),
which will save the model under `experiments/bert/`.

If you haven't trained yet, you can still run this notebook by loading a base model
such as `bert-base-uncased`. Its classification head is randomly initialized, however,
so the predictions will not match our reported results.

In [None]:
import os
from typing import List

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DEVICE

## 2. Load the Fine-tuned BERT Model

By default, we first try to load the fine-tuned model from `experiments/bert/`.
If that directory is missing or contains no weight files (`.bin` or `.safetensors`),
for example because you haven't trained yet, we fall back to the base
`bert-base-uncased` model.

In [None]:
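MODEL_DIR = "experiments/bert"  # where the fine-tuning script saves the Hugging Face model
BASE_MODEL_NAME = "bert-base-uncased"
WEIGHT_SUFFIXES = (".bin", ".safetensors")  # weight-file extensions to look for

# Use the fine-tuned model only if the directory exists and contains weight files.
has_weights = os.path.isdir(MODEL_DIR) and any(
    f.endswith(WEIGHT_SUFFIXES) for f in os.listdir(MODEL_DIR)
)

if has_weights: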
    print(f"Loading fine-tuned model from: {MODEL_DIR}")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_DIR)
else:
    print(f"Fine-tuned model not found at '{MODEL_DIR}'. Loading base model: {BASE_MODEL_NAME}")
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_NAME)
    model = AutoModelForSequenceClassification.from_pretrained(BASE_MODEL_NAME, num_labels=2)

model.to(DEVICE)
model.eval()

## 3. Helper Function for Inference

We define a small helper that:

1. Tokenizes input texts.
2. Runs them through BERT.
3. Returns labels (`"ham"` / `"spam"`) and confidence scores.

We assume the following label index mapping:

- 0 → `ham`
- 1 → `spam`
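
The mapping above is an assumption about how the labels were encoded during fine-tuning. As a hedged sketch, the helper below (hypothetical name `resolve_id2label`) prefers a mapping stored in the model config, such as `model.config.id2label`, and falls back to the assumed default when the config holds only the generic `LABEL_0`/`LABEL_1` placeholders that Transformers assigns when no label names were saved. Whether the fine-tuning script stores label names is not stated here, so treat this as a sanity check rather than the repository's method.

```python
from typing import Dict, Optional

# Assumed default from this notebook: 0 -> ham, 1 -> spam.
DEFAULT_ID2LABEL: Dict[int, str] = {0: "ham", 1: "spam"}


def resolve_id2label(
    config_id2label: Optional[Dict[int, str]],
    default: Dict[int, str] = DEFAULT_ID2LABEL,
) -> Dict[int, str]:
    """Pick a label mapping, preferring meaningful names from the model config.

    Hypothetical helper: a missing mapping or generic placeholder names
    ("LABEL_0", "LABEL_1", ...) fall back to `default`.
    """
    if not config_id2label:
        return dict(default)
    names = [str(name) for name in config_id2label.values()]
    if all(name.upper().startswith("LABEL_") for name in names):
        return dict(default)
    # Normalize keys to int and names to lowercase for consistent lookups.
    return {int(idx): str(name).lower() for idx, name in config_id2label.items()}


# Usage in this notebook (after the model is loaded):
#   ID2LABEL = resolve_id2label(model.config.id2label)
```

If the resolved mapping differs from `{0: "ham", 1: "spam"}`, the labels printed later in this notebook would need to follow the config rather than the assumption.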

In [None]:
import torch.nn.functional as F

ID2LABEL = {0: "ham", 1: "spam"}

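def predict_messages(texts: List[str]):
    """Run spam/ham prediction on a list of messages (a single string is also accepted).

    Returns a list of dicts with keys "text", "pred_label", "pred_index", "score",
    where "score" is the softmax probability of the predicted class.
    """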
    if isinstance(texts, str):
        texts = [texts]

    encodings = tokenizer(
        texts,
        padding=True,
        truncation=True,
        max_length=128,
        return_tensors="pt",
    )

    encodings = {k: v.to(DEVICE) for k, v in encodings.items()}

    with torch.no_grad():
        outputs = model(**encodings)
        logits = outputs.logits
        probs = F.softmax(logits, dim=-1)
        scores, preds = torch.max(probs, dim=-1)

    results = []
    for text, idx, score in zip(texts, preds.cpu().tolist(), scores.cpu().tolist()):
        label = ID2LABEL.get(idx, str(idx))
        results.append(
            {
                "text": text,
                "pred_label": label,
                "pred_index": idx,
                "score": float(score),
            }
        )
    return results

def pretty_print_predictions(results):
    for r in results:
        print("------------------------------")
        print(f"Text: {r['text']}")
        print(f"Prediction: {r['pred_label']} (score={r['score']:.4f})")

print("Helper functions defined.")
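
The `score` returned by `predict_messages` is the largest softmax probability over the class logits. The pure-Python sketch below reproduces that computation for a single message without PyTorch, which can help when checking a score by hand. The function names `softmax` and `top_class` are illustrative and not part of the repository.

```python
import math
from typing import List, Sequence, Tuple


def softmax(logits: Sequence[float]) -> List[float]:
    """Convert raw logits to probabilities that sum to 1."""
    # Subtract the maximum first so math.exp cannot overflow on large logits.
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_class(logits: Sequence[float]) -> Tuple[int, float]:
    """Return (index, probability) of the most likely class."""
    probs = softmax(logits)
    idx = max(range(len(probs)), key=probs.__getitem__)
    return idx, probs[idx]
```

With two classes the score is always at least 0.5, so a value near 0.5 means the model barely prefers one label over the other.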

## 4. Try Some Example Messages

Below we test the model with a small batch of messages, mixing benign and spammy content.

If you have fine-tuned the model as in our experiments, the predictions should be consistent
with our reported performance. With the base, non-fine-tuned model, the classification head
is randomly initialized, so the predictions are essentially arbitrary.

In [None]:
sample_texts = [
    "Hey, are we still meeting for lunch tomorrow?",
    "Congratulations! You have won a $500 gift card. Click here to claim now!",
    "Reminder: Your verification code is 392018. Do not share this code with anyone.",
    "URGENT!! Your bank account has been suspended. Visit http://fake-bank-login.com to reactivate.",
    "Can you send me the project report by tonight?",
]

results = predict_messages(sample_texts)
pretty_print_predictions(results)
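
Because each result carries a confidence score, one option is to act only on spam predictions above a chosen confidence. The sketch below assumes the result format returned by `predict_messages`; the helper name `flag_confident_spam` and the `0.9` default are illustrative choices, not values from our experiments.

```python
from typing import Any, Dict, List


def flag_confident_spam(
    results: List[Dict[str, Any]], min_score: float = 0.9
) -> List[Dict[str, Any]]:
    """Keep only results labeled "spam" whose score is at least `min_score`.

    Expects dicts with the keys "pred_label" and "score", as produced by
    predict_messages in this notebook.
    """
    return [
        r
        for r in results
        if r["pred_label"] == "spam" and r["score"] >= min_score
    ]


# Usage in this notebook:
#   confident = flag_confident_spam(results, min_score=0.9)
#   pretty_print_predictions(confident)
```

A suitable threshold depends on how costly false positives are in your setting and is best chosen on held-out data.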

## 5. Using the Model in Your Own Code

To reuse the fine-tuned model outside this notebook, you can follow the same pattern
in any Python script:

```python
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import torch.nn.functional as F

tokenizer = AutoTokenizer.from_pretrained("experiments/bert")
model = AutoModelForSequenceClassification.from_pretrained("experiments/bert")
model.eval()

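def predict_one(text: str):
    """Return (label_index, confidence) for one message; 0 = ham, 1 = spam."""
    # A single text needs no padding; max_length=128 matches the notebook helper.
    enc = tokenizer(text, return_tensors="pt", truncation=True, max_length=128)
    with torch.no_grad():
        logits = model(**enc).logits
        probs = F.softmax(logits, dim=-1)
        score, pred = torch.max(probs, dim=-1)
    return int(pred.item()), float(score.item())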
```

You can embed this in a web service, API, or batch-scoring pipeline as needed.
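
For the batch-scoring case, a minimal sketch is to split a long list of messages into fixed-size chunks and pass each chunk to a prediction function. The names `iter_batches` and `score_in_batches` and the default batch size of 32 are assumptions for illustration; `predict_fn` stands for any callable that takes a list of strings and returns one result per string, such as the notebook's `predict_messages`.

```python
from typing import Any, Callable, Iterator, List, Sequence


def iter_batches(texts: Sequence[str], batch_size: int) -> Iterator[List[str]]:
    """Yield consecutive chunks of at most `batch_size` messages."""
    if batch_size <= 0:
        raise ValueError("batch_size must be a positive integer")
    for start in range(0, len(texts), batch_size):
        yield list(texts[start:start + batch_size])


def score_in_batches(
    texts: Sequence[str],
    predict_fn: Callable[[List[str]], List[Any]],
    batch_size: int = 32,
) -> List[Any]:
    """Apply `predict_fn` chunk by chunk and return results in input order."""
    results: List[Any] = []
    for batch in iter_batches(texts, batch_size):
        results.extend(predict_fn(batch))
    return results


# Usage with the notebook helper:
#   all_results = score_in_batches(all_texts, predict_messages, batch_size=32)
```

Chunking keeps the padded tensors for each forward pass small, which bounds memory use on large inputs.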