### Test the model on 5 ham messages and 5 smish messages

In [12]:
# ================================
# Test custom messages on SAVED DistilBERT model
# Labels: ham=0, smish=1
# ================================

import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# ---- Load saved model + tokenizer ----
MODEL_DIR = "smish_detection_model/final_model"
# Truncation/padding length applied when tokenizing each message.
# NOTE(review): presumably this must match the max_length used during
# fine-tuning — confirm against the training notebook.
MAX_LEN = 95

# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
# nn.Module.to() and nn.Module.eval() both return the module itself, so the
# calls can be chained; eval() puts the model in inference mode (e.g. dropout
# disabled) so repeated predictions on the same text agree.
model = (
    AutoModelForSequenceClassification.from_pretrained(MODEL_DIR)
    .to(device)
    .eval()
)

# ---- Test messages ----
# Hand-written (message, expected_label) pairs: 5 ham messages followed by
# 5 smish messages. Expected labels must be keys of label_map below.
samples = [
    ("The total time for each presentation should not exceed 15 minutes.", "ham"),
    ("Your university portal has a new announcement. Log in at https://portal.squ.edu.om", "ham"),
    ("Welcome to Summer Training in Computer Science, Summer 25 https://www.squ.edu.om/Student/E-learning", "ham"),
    # NOTE(review): the leading "our" looks like a truncated "Your"; kept verbatim
    # so results stay comparable with earlier runs — confirm against the real message.
    ("our SQU E-learning System account was just signed in to from a new device. account: s129299@student.squ.edu.om", "ham"),
    ("HW2 grading table is available on Moodle (under week 12).", "ham"),

    ("Aramex: Your package has arrived at the warehouse but cannot be delivered due to an incomplete address. Please update your shipping details to facilitate delivery. Redelivery costs will be borne by you. Please click on the link below to make change: https://aramexexa.com/ii", "smish"),
    ("Your delivery has been stopped at our DHL depot. Trk#: R690382803147 Please resolve the issue here: https://warning-pages.entermypassword.com/join.html?id=", "smish"),
    ("Your tax refund of 349.14 is ready to be claimed. To access your refund, follow the steps required. https://warning-page.authsecurelogin.com/join.html?id=", "smish"),
    ("ALERT: Your bank account has been locked. Call now to avoid closure.", "smish"),
    ("We tried to deliver your package but failed. Reply with your full name and ID number to reschedule.", "smish"),
]

# Label name -> class id (ham=0, smish=1, as in the header comment).
label_map = {"ham": 0, "smish": 1}
# Inverse mapping, derived from label_map so the two can never drift apart.
id2label = {label_id: name for name, label_id in label_map.items()}

# ---- Prediction function ----
def predict_one(text):
    """Classify a single message with the loaded model.

    Tokenizes ``text`` with truncation and padding to ``MAX_LEN`` tokens,
    runs one forward pass without gradient tracking, and converts the
    logits to class probabilities.

    Args:
        text: The raw message string to classify.

    Returns:
        A ``(pred_id, probs)`` tuple. ``pred_id`` is the int index of the
        highest-scoring class (see ``id2label``: 0 = ham, 1 = smish).
        ``probs`` is a 1-D NumPy array of softmax probabilities, one entry
        per class, in class-index order.
    """
    enc = tokenizer(
        text,
        truncation=True,
        padding="max_length",
        max_length=MAX_LEN,
        return_tensors="pt",
    )
    enc = {k: v.to(device) for k, v in enc.items()}

    with torch.no_grad():
        # A single input yields a batch of one; take that row of logits.
        logits = model(**enc).logits[0]
        # Apply softmax to the model's output tensor directly, instead of
        # copying to NumPy and rebuilding a new tensor just to normalise it.
        probs = torch.softmax(logits, dim=-1).cpu().numpy()

    # Softmax is monotonic, so the argmax of the raw logits is the predicted class.
    pred_id = int(logits.argmax())
    return pred_id, probs

# ---- Run predictions ----
print("=== DistilBERT Saved Model: Manual Message Test ===\n")

# Number of samples whose predicted class id equals the expected one.
correct = 0
for message, expected in samples:
    expected_id = label_map[expected]
    guess_id, class_probs = predict_one(message)
    guess = id2label[guess_id]

    if guess_id == expected_id:
        correct += 1

    # Per-message report: text, expected vs predicted label, class probabilities.
    print("Message:", message, sep="\n")
    print("True Label:", expected)
    print("Predicted Label:", guess)
    print(f"[ham, smish] probs: [{class_probs[0]:.3f}, {class_probs[1]:.3f}]")
    print("-" * 43)

# Fraction of the hand-written samples classified correctly.
accuracy = correct / len(samples)
print(f"\nOverall accuracy on these {len(samples)} messages: {accuracy:.3f}")


=== DistilBERT Saved Model: Manual Message Test ===

Message:
The total time for each presentation should not exceed 15 minutes.
True Label: ham
Predicted Label: ham
[ham, smish] probs: [0.658, 0.342]
-------------------------------------------
Message:
Your university portal has a new announcement. Log in at https://portal.squ.edu.om
True Label: ham
Predicted Label: ham
[ham, smish] probs: [0.994, 0.006]
-------------------------------------------
Message:
Welcome to Summer Training in Computer Science, Summer 25 https://www.squ.edu.om/Student/E-learning
True Label: ham
Predicted Label: ham
[ham, smish] probs: [0.981, 0.019]
-------------------------------------------
Message:
our SQU E-learning System account was just signed in to from a new device. account: s129299@student.squ.edu.om
True Label: ham
Predicted Label: smish
[ham, smish] probs: [0.016, 0.984]
-------------------------------------------
Message:
HW2 grading table is available on Moodle (under week 12).
True Label: ham
P