{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# BERT Spam Detection Demo\n",
    "\n",
    "This notebook demonstrates how to load the fine-tuned BERT model from our repository for:\n",
    "\n",
    "> **“Harnessing BERT for Advanced Email Filtering in Cybersecurity”**  \n",
    "> IEEE Xplore: https://ieeexplore.ieee.org/abstract/document/11058531\n",
    "\n",
    "and run predictions on custom messages (SMS/email-like text)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Setup and Imports\n",
    "\n",
    "Make sure you have run BERT fine-tuning first (e.g., `python -m scripts.run_bert`),\n",
    "which will save the model under `experiments/bert/`.\n",
    "\n",
    "If you haven't trained yet, you can still run this notebook by loading a base model\n",
    "such as `bert-base-uncased`, but the predictions will not match our reported results."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import os\n",
    "from typing import List\n",
    "\n",
    "import torch\n",
    "from transformers import AutoTokenizer, AutoModelForSequenceClassification\n",
    "\n",
    "DEVICE = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "DEVICE"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Load the Fine-tuned BERT Model\n",
    "\n",
    "By default, we first try to load the fine-tuned model from `experiments/bert/`.\n",
    "If that directory is not found (e.g., you haven't trained yet), we fall back\n",
    "to the base `bert-base-uncased` model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "MODEL_DIR = \"experiments/bert\"  # where train_bert.py saves the HF model\n",
    "BASE_MODEL_NAME = \"bert-base-uncased\"\n",
    "\n",
    "if os.path.isdir(MODEL_DIR) and any(f.endswith(\".bin\") or f.endswith(\".safetensors\") for f in os.listdir(MODEL_DIR)):\n",
    "    print(f\"Loading fine-tuned model from: {MODEL_DIR}\")\n",
    "    tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)\n",
    "    model = AutoModelForSequenceClassification.from_pretrained(MODEL_DIR)\n",
    "else:\n",
    "    print(f\"Fine-tuned model not found at '{MODEL_DIR}'. Loading base model: {BASE_MODEL_NAME}\")\n",
    "    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_NAME)\n",
    "    model = AutoModelForSequenceClassification.from_pretrained(BASE_MODEL_NAME, num_labels=2)\n",
    "\n",
    "model.to(DEVICE)\n",
    "model.eval()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Helper Function for Inference\n",
    "\n",
    "We define a small helper that:\n",
    "\n",
    "1. Tokenizes input texts.\n",
    "2. Runs them through BERT.\n",
    "3. Returns labels (`\"ham\"` / `\"spam\"`) and confidence scores.\n",
    "\n",
    "We assume label index mapping:\n",
    "\n",
    "- 0 → `ham`\n",
    "- 1 → `spam`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import torch.nn.functional as F\n",
    "\n",
    "ID2LABEL = {0: \"ham\", 1: \"spam\"}\n",
    "\n",
    "def predict_messages(texts: List[str]):\n",
    "    \"\"\"Run spam/ham prediction on a list of messages.\n",
    "\n",
    "    Returns a list of dicts: {\"text\", \"pred_label\", \"pred_index\", \"score\"}.\n",
    "    \"\"\"\n",
    "    if isinstance(texts, str):\n",
    "        texts = [texts]\n",
    "\n",
    "    encodings = tokenizer(\n",
    "        texts,\n",
    "        padding=True,\n",
    "        truncation=True,\n",
    "        max_length=128,\n",
    "        return_tensors=\"pt\",\n",
    "    )\n",
    "\n",
    "    encodings = {k: v.to(DEVICE) for k, v in encodings.items()}\n",
    "\n",
    "    with torch.no_grad():\n",
    "        outputs = model(**encodings)\n",
    "        logits = outputs.logits\n",
    "        probs = F.softmax(logits, dim=-1)\n",
    "        scores, preds = torch.max(probs, dim=-1)\n",
    "\n",
    "    results = []\n",
    "    for text, idx, score in zip(texts, preds.cpu().tolist(), scores.cpu().tolist()):\n",
    "        label = ID2LABEL.get(idx, str(idx))\n",
    "        results.append(\n",
    "            {\n",
    "                \"text\": text,\n",
    "                \"pred_label\": label,\n",
    "                \"pred_index\": idx,\n",
    "                \"score\": float(score),\n",
    "            }\n",
    "        )\n",
    "    return results\n",
    "\n",
    "def pretty_print_predictions(results):\n",
    "    for r in results:\n",
    "        print(\"------------------------------\")\n",
    "        print(f\"Text: {r['text']}\")\n",
    "        print(f\"Prediction: {r['pred_label']} (score={r['score']:.4f})\")\n",
    "\n",
    "print(\"Helper functions defined.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. Try Some Example Messages\n",
    "\n",
    "Below we test the model with a small batch of messages, mixing benign and spammy content.\n",
    "\n",
    "If you have fine-tuned the model as in our experiments, the predictions should be aligned\n",
    "with our reported performance. With a base, non-fine-tuned model, the predictions will\n",
    "be mostly random."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "sample_texts = [\n",
    "    \"Hey, are we still meeting for lunch tomorrow?\",\n",
    "    \"Congratulations! You have won a $500 gift card. Click here to claim now!\",\n",
    "    \"Reminder: Your verification code is 392018. Do not share this code with anyone.\",\n",
    "    \"URGENT!! Your bank account has been suspended. Visit http://fake-bank-login.com to reactivate.\",\n",
    "    \"Can you send me the project report by tonight?\",\n",
    "]\n",
    "\n",
    "results = predict_messages(sample_texts)\n",
    "pretty_print_predictions(results)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5. Using the Model in Your Own Code\n",
    "\n",
    "To reuse the fine-tuned model outside this notebook, you can follow the same pattern\n",
    "in any Python script:\n",
    "\n",
    "```python\n",
    "from transformers import AutoTokenizer, AutoModelForSequenceClassification\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"experiments/bert\")\n",
    "model = AutoModelForSequenceClassification.from_pretrained(\"experiments/bert\")\n",
    "model.eval()\n",
    "\n",
    "def predict_one(text: str):\n",
    "    enc = tokenizer(text, return_tensors=\"pt\", truncation=True, padding=True)\n",
    "    with torch.no_grad():\n",
    "        logits = model(**enc).logits\n",
    "        probs = F.softmax(logits, dim=-1)\n",
    "        score, pred = torch.max(probs, dim=-1)\n",
    "    return int(pred.item()), float(score.item())\n",
    "```\n",
    "\n",
    "You can embed this in a web service, API, or batch-scoring pipeline as needed."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
