In [None]:
# ==============================================================================
# 1. SETUP AND ENVIRONMENT
# ==============================================================================
# This script trains and evaluates the best-performing model, DistilBERT + Features,
# described in "Email Classification Using Deep and Structural Features".

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import re
import time
import os
import psutil
from bs4 import BeautifulSoup
from tqdm.auto import tqdm
import pickle

In [None]:
import nltk
# Fetch the NLTK corpora this notebook relies on: the English stop-word list
# (read via `stopwords.words('english')`) and the WordNet data for the lemmatizer.
# `quiet=True` suppresses the download log output.
nltk.download('stopwords', quiet=True)
nltk.download('wordnet', quiet=True)
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer

In [None]:
import tensorflow as tf
from transformers import DistilBertTokenizerFast, TFDistilBertModel
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import classification_report, confusion_matrix

In [None]:
# --- Configuration ---
# Everything a reader may want to tune lives here: the seed, the input CSV,
# the Hugging Face model identifier, and the output directory for artefacts.
RANDOM_STATE = 42
DATA_PATH = "../data/balanced_dataset.csv"
PRETRAINED_DISTILBERT_PATH = "distilbert-base-uncased" # Using public model for reproducibility
MODEL_SAVE_DIR = "../saved_models/"
os.makedirs(MODEL_SAVE_DIR, exist_ok=True)

# Seed the global NumPy and TensorFlow generators. Without this, Keras weight
# initialisation, dropout masks and batch shuffling differ on every run, so the
# reported metrics cannot be reproduced even though RANDOM_STATE is defined.
# NOTE: some GPU kernels are non-deterministic, so tiny run-to-run differences
# may remain on accelerator hardware.
np.random.seed(RANDOM_STATE)
tf.random.set_seed(RANDOM_STATE)

# Record library versions next to the results for provenance.
print(f"TensorFlow Version: {tf.__version__}")
from transformers import __version__ as transformers_version
print(f"Transformers Version: {transformers_version}")

In [None]:
# ==============================================================================
# 2. DATA LOADING AND PREPROCESSING
# ==============================================================================
print("\n--- Loading and Preprocessing Data ---")
# Read the CSV, discard rows missing either the message text or its label,
# and cast the label column to integers - expressed as a single chain so the
# frame is built in one step rather than mutated in place.
df = (
    pd.read_csv(DATA_PATH)
    .dropna(subset=['text', 'label'])
    .astype({'label': int})
)

print(f"Dataset loaded successfully. Total rows: {len(df)}")
print("Label distribution:\n", df['label'].value_counts(normalize=True))

# Shared NLP helpers used by the text-cleaning function.
lemmatizer = WordNetLemmatizer()
stop_words = set(stopwords.words('english'))

def clean_text_for_transformer(text):
    """Normalise a raw e-mail body into a lower-case, space-separated token string.

    Steps: strip HTML markup, drop every character that is neither an ASCII
    letter nor whitespace, lower-case, split on whitespace, discard stop words
    and single-character tokens, and lemmatize what remains.

    Relies on the module-level ``stop_words`` set and ``lemmatizer`` instance.

    Args:
        text: The raw message. Anything that is not a ``str`` is treated as empty.

    Returns:
        The cleaned tokens joined by single spaces; ``""`` for non-string input
        or when no token survives the filtering.
    """
    if not isinstance(text, str):
        return ""
    # separator=" " keeps text from adjacent tags apart; with the default (empty)
    # separator "<p>Hi</p><p>there</p>" collapses into the single token "Hithere".
    text = BeautifulSoup(text, "html.parser").get_text(separator=" ")
    # No re.ASCII flag here on purpose: `\s` must also match Unicode whitespace
    # such as the non-breaking space produced by "&nbsp;". Otherwise that
    # character is deleted and the words on either side of it are fused.
    # The letter class is still ASCII-only, so non-ASCII letters are removed as before.
    text = re.sub(r'[^a-zA-Z\s]', '', text)
    text = text.lower()
    tokens = text.split()
    cleaned_tokens = [lemmatizer.lemmatize(word) for word in tokens if word not in stop_words and len(word) > 1]
    return " ".join(cleaned_tokens)

# tqdm.pandas() makes `progress_apply` available on pandas objects; `desc` is the
# label shown on the progress bar.
tqdm.pandas(desc="Cleaning Text")
print("\nApplying text cleaning for Transformer model...")
# Store the result in a new column and leave the raw 'text' column untouched:
# the structural features further down are computed from the raw text.
df['cleaned_text'] = df['text'].progress_apply(clean_text_for_transformer)
print("Text cleaning complete.")

In [None]:
# ==============================================================================
# 3. FEATURE ENGINEERING
# ==============================================================================
print("\n--- Engineering Features ---")
print("Extracting structural features from raw text...")

# Two size features taken from the raw (uncleaned) text:
#   text_len   - number of characters
#   word_count - number of whitespace-delimited tokens
df = df.assign(
    text_len=df['text'].map(len),
    word_count=df['text'].map(lambda raw: len(raw.split())),
)

def uppercase_ratio(text):
    """Return the fraction of characters in ``text`` that are uppercase.

    Non-string input and the empty string both yield ``0.0``.
    """
    if isinstance(text, str) and text:
        upper_total = len([ch for ch in text if ch.isupper()])
        return upper_total / len(text)
    return 0.0
# Share of uppercase characters in the raw (uncleaned) text, one value per row.
df['uppercase_ratio'] = df['text'].apply(uppercase_ratio)

def punctuation_count(text):
    """Count the ASCII punctuation characters in ``text``.

    The character class covers the usual ASCII punctuation marks; note that the
    backslash is not part of it and is therefore not counted. Non-string input
    yields ``0``.
    """
    if isinstance(text, str):
        hits = re.finditer(r'[!\"#$%&\'()*+,-./:;<=>?@\[\]^_`{|}~]', text)
        return sum(1 for _ in hits)
    return 0
df['punctuation_count'] = df['text'].apply(punctuation_count)

numerical_feature_cols = ['text_len', 'word_count', 'uppercase_ratio', 'punctuation_count']
numerical_features = df[numerical_feature_cols].values
labels = df['label'].values

# Decide which rows belong to train / validation / test BEFORE any statistic is
# learned from the data. Splitting row indices (instead of the finished feature
# matrix) with the same seed and stratification is intended to give the same
# partition as splitting the matrix itself, while letting the scaler below be
# fitted on training rows only.
row_indices = np.arange(len(df))
train_full_idx, test_idx = train_test_split(
    row_indices, test_size=0.20, random_state=RANDOM_STATE, stratify=labels
)
train_idx, val_idx = train_test_split(
    train_full_idx, test_size=0.10, random_state=RANDOM_STATE, stratify=labels[train_full_idx]
)

# Fit the scaler on the training rows only, then apply it to every row.
# Fitting on the full dataset would leak the mean/variance of the validation and
# test rows into preprocessing and make the held-out metrics optimistic.
scaler = StandardScaler().fit(numerical_features[train_idx])
numerical_features_scaled = scaler.transform(numerical_features)
# Persist the fitted scaler so the same transformation can be reused at inference time.
with open(os.path.join(MODEL_SAVE_DIR, 'scaler.pkl'), 'wb') as f:
    pickle.dump(scaler, f)
print("Structural features extracted and scaled. Scaler saved.")

print("\nGenerating contextual embeddings with DistilBERT...")
tokenizer = DistilBertTokenizerFast.from_pretrained(PRETRAINED_DISTILBERT_PATH)
distilbert_base = TFDistilBertModel.from_pretrained(PRETRAINED_DISTILBERT_PATH)

MAX_LEN = 256
BATCH_SIZE_EMBEDDING = 32

# The pretrained encoder is used as a fixed feature extractor: each batch is run
# through it once and only the hidden state at position 0 (the [CLS] token) is kept.
all_embeddings = []
texts = df['cleaned_text'].tolist()
for i in tqdm(range(0, len(texts), BATCH_SIZE_EMBEDDING), desc="Generating Embeddings"):
    batch = texts[i:i+BATCH_SIZE_EMBEDDING]
    inputs = tokenizer(batch, return_tensors='tf', truncation=True, padding='max_length', max_length=MAX_LEN)
    outputs = distilbert_base(inputs)
    cls_embeddings = outputs.last_hidden_state[:, 0, :].numpy()
    all_embeddings.append(cls_embeddings)
text_embeddings = np.vstack(all_embeddings)
print(f"Embeddings generated. Shape: {text_embeddings.shape}")

print("\nCombining features and splitting data...")
# Final feature vector per row: [CLS embedding | scaled structural features].
combined_features = np.concatenate([text_embeddings, numerical_features_scaled], axis=1)

# Materialise the partitions chosen above (80/20 train-full/test, then 90/10 train/val).
X_train_full, y_train_full = combined_features[train_full_idx], labels[train_full_idx]
X_test, y_test = combined_features[test_idx], labels[test_idx]
X_train, y_train = combined_features[train_idx], labels[train_idx]
X_val, y_val = combined_features[val_idx], labels[val_idx]
print(f"Data split complete: Train={X_train.shape}, Validation={X_val.shape}, Test={X_test.shape}")

In [None]:
# ==============================================================================
# 4. MODEL TRAINING
# ==============================================================================
print("\n--- Model Training ---")
def build_classifier(input_shape, hidden_units=(128, 64), dropout_rate=0.4,
                     l2_strength=0.001, learning_rate=5e-5):
    """Build and compile the feed-forward binary classifier head.

    The network is a stack of ``Dense(relu, L2) -> BatchNormalization -> Dropout``
    groups, one per entry in ``hidden_units``, followed by a single sigmoid unit.
    With the default arguments the architecture and optimiser settings are
    exactly those previously hard-coded in this function.

    Args:
        input_shape: Shape of one input sample, e.g. ``(n_features,)``.
        hidden_units: Width of each hidden Dense layer, in order.
        dropout_rate: Dropout rate applied after every hidden group.
        l2_strength: L2 kernel-regularisation factor for the hidden Dense layers.
        learning_rate: Learning rate for the Adam optimiser.

    Returns:
        A compiled ``tf.keras.Sequential`` model using binary cross-entropy loss
        and reporting accuracy.
    """
    layers = [tf.keras.layers.Input(shape=input_shape)]
    for units in hidden_units:
        layers.extend([
            tf.keras.layers.Dense(units, activation='relu', kernel_regularizer=tf.keras.regularizers.l2(l2_strength)),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.Dropout(dropout_rate),
        ])
    # Single sigmoid unit: the output is read as the probability of class 1.
    layers.append(tf.keras.layers.Dense(1, activation='sigmoid'))

    model = tf.keras.Sequential(layers)
    optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
    model.compile(optimizer=optimizer, loss='binary_crossentropy', metrics=['accuracy'])
    return model

classifier = build_classifier((X_train.shape[1],))
classifier.summary()

print("\nStarting model training...")
proc = psutil.Process(os.getpid())
mem_before = proc.memory_info().rss / (1024 ** 2)

# Sample the process RSS (in MB) at the end of every epoch. Looking only at the
# values before and after fit() misses any memory that is allocated and released
# while training is running, so it cannot be reported as a "peak".
# This is still a sampled estimate: spikes between two samples are not seen.
rss_samples_mb = [mem_before]
memory_sampler = tf.keras.callbacks.LambdaCallback(
    on_epoch_end=lambda epoch, logs: rss_samples_mb.append(proc.memory_info().rss / (1024 ** 2))
)

# Stop once validation loss has not improved for 3 epochs and roll back to the best weights.
early_stopping = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=3, restore_best_weights=True)

# perf_counter() is monotonic, so the measured duration cannot be distorted by
# system clock adjustments during training.
start_time = time.perf_counter()
history = classifier.fit(
    X_train, y_train,
    validation_data=(X_val, y_val),
    epochs=15,
    batch_size=32,
    callbacks=[early_stopping, memory_sampler],
    verbose=1
)
end_time = time.perf_counter()
mem_after = proc.memory_info().rss / (1024 ** 2)
rss_samples_mb.append(mem_after)

training_time_seconds = end_time - start_time
peak_cpu_mem_mb = max(rss_samples_mb)
print(f"\nTraining complete in {training_time_seconds:.2f} seconds.")
print(f"Peak CPU Memory Usage: {peak_cpu_mem_mb:.2f} MB")

classifier.save(os.path.join(MODEL_SAVE_DIR, 'distilbert_features_classifier.h5'))
print("Classifier model saved.")

In [None]:
# ==============================================================================
# 5. EVALUATION
# ==============================================================================
print("\n--- Final Evaluation on Test Set ---")
# Loss and accuracy of the trained classifier on the held-out test partition.
loss, accuracy = classifier.evaluate(X_test, y_test, verbose=0)
print(f"  Test Loss: {loss:.4f}")
print(f"  Test Accuracy: {accuracy:.4f}")

# Turn the sigmoid outputs into hard 0/1 predictions using a 0.5 threshold.
y_pred_probs = classifier.predict(X_test)
y_pred = (y_pred_probs > 0.5).astype(int).ravel()

print("\nClassification Report:")
print(classification_report(y_test, y_pred, target_names=['Ham (0)', 'Spam (1)']))

print("\nConfusion Matrix:")
cm = confusion_matrix(y_test, y_pred)
class_names = ['Ham', 'Spam']
# Explicit figure/axes objects instead of the implicit pyplot "current axes".
fig_cm, ax_cm = plt.subplots(figsize=(6, 5))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names, ax=ax_cm)
ax_cm.set_xlabel('Predicted Label')
ax_cm.set_ylabel('True Label')
ax_cm.set_title('Confusion Matrix')
plt.show()

print("\nTraining History:")
# Side-by-side panels: accuracy on the left, loss on the right, each showing the
# per-epoch training and validation curves recorded by fit().
fig_hist, (ax_acc, ax_loss) = plt.subplots(1, 2, figsize=(12, 5))
curves = history.history

ax_acc.plot(curves['accuracy'], label='Training Accuracy')
ax_acc.plot(curves['val_accuracy'], label='Validation Accuracy')
ax_acc.set_title('Model Accuracy')
ax_acc.set_ylabel('Accuracy')
ax_acc.set_xlabel('Epoch')
ax_acc.legend(loc='lower right')

ax_loss.plot(curves['loss'], label='Training Loss')
ax_loss.plot(curves['val_loss'], label='Validation Loss')
ax_loss.set_title('Model Loss')
ax_loss.set_ylabel('Loss')
ax_loss.set_xlabel('Epoch')
ax_loss.legend(loc='upper right')

fig_hist.tight_layout()
plt.show()