# Spam Email Classification: Data Visualization & Live Inference

本 notebook 提供 spam email 資料視覺化、模型訓練、效能評估（含 threshold sweep），以及即時推論互動介面（資料來源為 SMS spam 資料集）。

In [None]:
# 1. Import Required Libraries
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.svm import LinearSVC
from sklearn.metrics import classification_report, accuracy_score, confusion_matrix, roc_curve, precision_recall_curve, auc, f1_score, precision_score, recall_score
import ipywidgets as widgets
from IPython.display import display, clear_output
import random
import warnings
warnings.filterwarnings('ignore')

In [None]:
# 2. Load and preprocess data
# The CSV is read with header=None, so the two column names are supplied here.
url = (
    'https://raw.githubusercontent.com/PacktPublishing/'
    'Hands-On-Artificial-Intelligence-for-Cybersecurity/'
    'refs/heads/master/Chapter03/datasets/'
    'sms_spam_no_header.csv'
)
df = pd.read_csv(url, header=None, names=['label', 'text'], encoding='latin-1')

# Integer target derived from the string label: ham -> 0, spam -> 1.
df['label_num'] = df['label'].map({'ham': 0, 'spam': 1})

# 80/20 train/test split, stratified on the numeric label with a fixed seed.
df_train, df_test = train_test_split(
    df,
    test_size=0.2,
    stratify=df['label_num'],
    random_state=42,
)
y_train, y_test = df_train['label_num'], df_test['label_num']

# The TF-IDF vectorizer is fitted on the training text only and then applied
# unchanged to the test text.
vectorizer = TfidfVectorizer()
X_train = vectorizer.fit_transform(df_train['text'])
X_test = vectorizer.transform(df_test['text'])

## 3. Visualize Data Overview
- 類別分布
- 訊息長度分布

In [None]:
# Class distribution: number of messages per label.
plt.figure(figsize=(5, 3))
sns.countplot(x='label', data=df)
plt.title('Class Distribution')
plt.show()

# Message length distribution: characters per message, coloured by label.
df['length'] = df['text'].map(len)
plt.figure(figsize=(8, 4))
sns.histplot(
    data=df,
    x='length',
    hue='label',
    bins=40,
    kde=True,
    element='step',
)
plt.title('Message Length Distribution by Class')
plt.show()

## 4. Visualize Top Tokens by Class (Adjustable N)
- 可互動調整 N，顯示每類別 top N token

In [None]:
# Top tokens by class (interactive)
def plot_top_tokens_by_class(N=20):
    """Draw side-by-side bar charts of the N most frequent tokens per class.

    Tokens come from joining a class's messages, lower-casing the result and
    splitting on whitespace, so punctuation stays attached to words and no
    stop-word filtering is applied.

    Args:
        N: Number of most frequent tokens to show for each class. If this
            yields no tokens (for example N <= 0), the panel is left empty
            instead of raising.
    """
    from collections import Counter

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    for ax, label in zip(axes, ['ham', 'spam']):
        texts = df.loc[df['label'] == label, 'text']
        counter = Counter(' '.join(texts).lower().split())
        most_common = counter.most_common(N)
        if most_common:
            tokens_, counts_ = zip(*most_common)
            sns.barplot(x=list(counts_), y=list(tokens_), ax=ax, orient='h')
        # When most_common is empty, zip(*[]) cannot be unpacked into two
        # names (ValueError), so the bar plot is skipped and only the title
        # is drawn.
        ax.set_title(f'Top {N} Tokens: {label}')
    plt.tight_layout()
    plt.show()

# Slider for N in the range 5..50 (default 20); it is bound to the plotting
# function's N argument through interact.
N_slider = widgets.IntSlider(
    description='Top N:',
    value=20,
    min=5,
    max=50,
    step=1,
)
# The trailing semicolon presumably keeps the notebook from echoing the
# value returned by interact.
widgets.interact(plot_top_tokens_by_class, N=N_slider);

## 5. Train Classification Model
- 使用線性 SVM（LinearSVC）搭配 TF-IDF 特徵訓練分類器

In [None]:
# Train a linear SVM classifier on the TF-IDF features.
# random_state is pinned to the same seed used for the train/test split so
# that re-running the notebook fits the same model.
clf = LinearSVC(random_state=42)
clf.fit(X_train, y_train)

y_pred = clf.predict(X_test)
# Despite the name, these are signed decision-function scores rather than
# probabilities; they are used for the ROC curve and the threshold sweep.
y_pred_proba = clf.decision_function(X_test)

## 6. Visualize Model Performance (Test)
- Confusion Matrix
- ROC Curve
- Threshold Sweep (precision/recall/f1)

In [None]:
# Confusion matrix on the test set: rows are actual classes, columns are
# predicted classes.
cm = confusion_matrix(y_test, y_pred)
cm_ax = sns.heatmap(
    cm,
    annot=True,
    fmt='d',
    cmap='Blues',
    xticklabels=['ham', 'spam'],
    yticklabels=['ham', 'spam'],
)
cm_ax.set_xlabel('Predicted')
cm_ax.set_ylabel('Actual')
cm_ax.set_title('Confusion Matrix')
plt.show()

# ROC curve computed from the decision scores; the dashed diagonal is the
# y = x reference line.
fpr, tpr, thresholds = roc_curve(y_test, y_pred_proba)
roc_auc = auc(fpr, tpr)
roc_fig, roc_ax = plt.subplots(figsize=(6, 4))
roc_ax.plot(fpr, tpr, label=f'ROC curve (AUC = {roc_auc:.2f})')
roc_ax.plot([0, 1], [0, 1], 'k--')
roc_ax.set_xlabel('False Positive Rate')
roc_ax.set_ylabel('True Positive Rate')
roc_ax.set_title('ROC Curve')
roc_ax.legend()
plt.show()

# Threshold sweep (precision/recall/F1)
def plot_threshold_sweep():
    """Plot test-set precision, recall and F1 across decision thresholds.

    Sweeps 100 evenly spaced cut-offs between the smallest and largest score
    in ``y_pred_proba``. A sample is predicted positive (spam) when its score
    is strictly greater than the cut-off, and each metric is computed against
    ``y_test``.
    """
    # Use NumPy reductions rather than the builtin min/max, which iterate the
    # array element by element. The local is named differently from the
    # module-level `thresholds` returned by roc_curve to avoid shadowing it.
    cutoffs = np.linspace(np.min(y_pred_proba), np.max(y_pred_proba), 100)
    precisions, recalls, f1s = [], [], []
    for t in cutoffs:
        preds = (y_pred_proba > t).astype(int)
        # At the highest cut-off nothing is predicted positive, so precision
        # is undefined there. zero_division=0 reports it as 0 explicitly
        # instead of relying on a suppressed warning.
        precisions.append(precision_score(y_test, preds, zero_division=0))
        recalls.append(recall_score(y_test, preds, zero_division=0))
        f1s.append(f1_score(y_test, preds, zero_division=0))
    plt.figure(figsize=(8, 5))
    plt.plot(cutoffs, precisions, label='Precision')
    plt.plot(cutoffs, recalls, label='Recall')
    plt.plot(cutoffs, f1s, label='F1-score')
    plt.xlabel('Threshold')
    plt.ylabel('Score')
    plt.title('Threshold Sweep: Precision/Recall/F1')
    plt.legend()
    plt.show()

# Render the sweep for the test-set decision scores computed above.
plot_threshold_sweep()

## 7. Live Inference: Manual Input & Sample Generation
- 可手動輸入訊息，或按按鈕帶入 spam/ham 範例；按下「預測」後顯示預測類別與經 sigmoid 轉換的決策分數（非校準後的機率）

In [None]:
# Live inference widgets: a text field for the message and an output area
# for the prediction.
input_box = widgets.Text(
    description='訊息:',
    placeholder='請輸入訊息',
    value='',
    layout=widgets.Layout(width='80%'),
)
output_box = widgets.Output()

# Five seeded sample messages per class, used by the example buttons.
spam_examples = (
    df.loc[df['label'] == 'spam', 'text']
    .sample(5, random_state=42)
    .tolist()
)
ham_examples = (
    df.loc[df['label'] == 'ham', 'text']
    .sample(5, random_state=42)
    .tolist()
)

spam_btn = widgets.Button(description='產生 spam 範例')
ham_btn = widgets.Button(description='產生 ham 範例')
predict_btn = widgets.Button(description='預測')

# Running counter per class; the click handlers use it to pick the next
# example to show.
example_idx = {'spam': 0, 'ham': 0}

def on_spam_click(b):
    """Load the next spam example into the input box.

    The examples are cycled: the running counter is taken modulo the number
    of spam examples, then advanced by one.

    Args:
        b: The clicked button (unused).
    """
    position = example_idx['spam']
    input_box.value = spam_examples[position % len(spam_examples)]
    example_idx['spam'] = position + 1

def on_ham_click(b):
    """Load the next ham example into the input box.

    The examples are cycled: the running counter is taken modulo the number
    of ham examples, then advanced by one.

    Args:
        b: The clicked button (unused).
    """
    position = example_idx['ham']
    input_box.value = ham_examples[position % len(ham_examples)]
    example_idx['ham'] = position + 1

def on_predict_click(b=None):
    """Classify the text in the input box and print the result.

    Output is written inside ``output_box`` after clearing its previous
    content. Blank or whitespace-only input prints a prompt and skips
    prediction.

    Args:
        b: The clicked button; unused. Defaults to None, so the handler can
            also be called without an argument.
    """
    with output_box:
        clear_output()
        message = input_box.value
        if not message.strip():
            print('請輸入訊息')
            return
        features = vectorizer.transform([message])
        label = 'spam' if clf.predict(features)[0] == 1 else 'ham'
        margin = clf.decision_function(features)[0]
        # Logistic squashing of the decision score into (0, 1). This is a
        # display value derived from the raw score, not a probability
        # estimated by the model.
        squashed = 1 / (1 + np.exp(-margin))
        print(f'預測結果: {label}')
        print(f'預測機率 (sigmoid): {squashed:.3f}')

# Register each click handler on its button (same order as before).
for _button, _handler in (
    (spam_btn, on_spam_click),
    (ham_btn, on_ham_click),
    (predict_btn, on_predict_click),
):
    _button.on_click(_handler)

# One horizontal row of controls, displayed together with the output area.
ui = widgets.HBox([input_box, predict_btn, spam_btn, ham_btn])
display(ui, output_box)