# Phishing Email Detection Model Training

This notebook trains a machine learning model to detect phishing emails, then converts the trained model to ONNX format for deployment in a web application.

## Environment Setup

First, let's install the required libraries:

In [None]:
# Install required libraries
!pip install numpy pandas scikit-learn matplotlib seaborn nltk skl2onnx onnxruntime

In [None]:
# Import necessary libraries
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import re
import nltk
import string
import email
import email.parser
from email import policy
from pathlib import Path

from sklearn.model_selection import train_test_split, cross_val_score, GridSearchCV
from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, classification_report
from sklearn.pipeline import Pipeline
from sklearn.base import BaseEstimator, TransformerMixin

# For ONNX conversion
from skl2onnx import convert_sklearn
from skl2onnx.common.data_types import FloatTensorType
import onnxruntime as rt

# Set random seed for reproducibility
np.random.seed(42)

## Download NLTK Resources

In [None]:
# Download necessary NLTK resources
nltk.download('punkt')
nltk.download('stopwords')
nltk.download('wordnet')

## Dataset Loading

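For this notebook, we'll use the phishing email dataset from Kaggle. Download it and upload it to your Google Colab session. The code below assumes a CSV file with an `email_content` column holding the raw email text and a `label` column where 1 marks phishing and 0 marks legitimate email.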

In [None]:
# Upload the dataset (Google Colab only; skip this cell if the file is already available locally)
from google.colab import files
uploaded = files.upload()  # Upload your phishing email dataset

In [None]:
# Load the dataset
# Note: Adjust the filename to match your uploaded file
df = pd.read_csv('phishing_email_dataset.csv')

# Display basic information about the dataset
print("Dataset shape:", df.shape)
df.head()

## Data Exploration

In [None]:
# Check for class distribution
print("Class distribution:")
print(df['label'].value_counts())

# Visualize class distribution
plt.figure(figsize=(6, 4))
sns.countplot(x='label', data=df)
plt.title('Class Distribution')
plt.xlabel('Email Type')
plt.ylabel('Count')
# Assumes labels are encoded as 0 = legitimate, 1 = phishing
plt.xticks([0, 1], ['Legitimate', 'Phishing'])
plt.show()

## Feature Extraction Functions

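We'll create functions to extract features from email content. The feature order and normalization must match the feature extraction in the web application, because the exported model only ever sees the resulting numeric vector.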
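
The extractor defined in the next cell has to cope with two kinds of input: a complete message with headers, and a bare body. As a minimal sketch of how the standard library parser treats each case (the sample strings and the `has_basic_headers` helper are made up for illustration):

```python
from email.parser import Parser

# A message with a header block, a blank line, and then the body
full_message = "From: alerts@example.com\nSubject: Notice\n\nPlease verify your account."
parsed_full = Parser().parsestr(full_message)

# A bare body with no header block at all
bare_body = "Please verify your account."
parsed_bare = Parser().parsestr(bare_body)


def has_basic_headers(message):
    """Return True if any of From, To or Subject was parsed (hypothetical helper)."""
    return any(message[name] for name in ("From", "To", "Subject"))


print(has_basic_headers(parsed_full))  # headers found, body is the payload
print(has_basic_headers(parsed_bare))  # no headers, so treat the input as the body
print(parsed_full.get_payload())
```

When none of the three headers is present, the extractor falls back to using the whole input string as the body.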

In [None]:
# Define phishing related keywords
PHISHING_KEYWORDS = [
    'urgent', 'verify', 'account', 'password', 'update', 'bank', 'security', 
    'alert', 'suspend', 'login', 'click', 'confirm', 'validate', 'immediately',
    'paypal', 'credit', 'debit', 'ssn', 'social security', 'limited time',
    'offer', 'prize', 'winner', 'lottery', 'inheritance', 'million', 'dollars',
    'fraud', 'secure', 'unauthorised', 'unauthorized', 'access', 'unusual',
    'activity', 'breach', 'verification', 'restricted', 'terminate',
    'expire', 'reset', 'cryptocurrency', 'bitcoin'
]

# Define suspicious URL patterns
SUSPICIOUS_URL_PATTERNS = [
    r'https?://(?:[a-zA-Z0-9-]+\.)+[a-zA-Z]{2,}(?:/[^/\s]*)*',  # URLs
    r'(?:https?://)?(?:www\.)?(?:[a-zA-Z0-9](?:[a-zA-Z0-9-]*[a-zA-Z0-9])?\.)+[a-zA-Z]{2,}(?:/[^/\s]*)*'  # URLs without protocol
]

# Feature names for readability
FEATURE_NAMES = [
    'contains_urgent_words', 
    'contains_finance_words',
    'contains_security_words',
    'num_urls', 
    'num_urls_mismatched_text',
    'has_html', 
    'has_attachments',
    'has_common_phishing_phrases',
    'email_length',
    'num_misspellings',
    'contains_ip_urls',
    'has_suspicious_sender',
    'request_for_credentials',
    'email_contains_javascript',
    'link_domain_age_fake'  # Placeholder feature
]

def extract_features(email_content):
    """
    Extract features from email content for phishing detection
    
    Args:
        email_content (str): Raw email content
        
    Returns:
        list: Numerical features for model input
    """
    # Coerce missing or non-string values (e.g. NaN from pandas) to an empty string
    if not isinstance(email_content, str):
        email_content = ""

    # Defined up front so the attachment check below works even if parsing fails
    email_message = None
    email_headers = {}
    try:
        # Try to parse as email
        email_message = email.parser.Parser().parsestr(email_content)
        
        # If can't parse headers properly, treat as just the body
        if not email_message['From'] and not email_message['To'] and not email_message['Subject']:
            email_body = email_content
        else:
            # Extract headers
            email_headers = {k: v for k, v in email_message.items()}
            
            # Extract body
            if email_message.is_multipart():
                email_body = ""
                for part in email_message.walk():
                    content_type = part.get_content_type()
                    if content_type == "text/plain" or content_type == "text/html":
                        try:
                            payload = part.get_payload(decode=True).decode('utf-8', errors='ignore')
                            email_body += payload
                        except Exception:
                            # Skip parts whose payload cannot be decoded
                            pass
            else:
                try:
                    email_body = email_message.get_payload(decode=True).decode('utf-8', errors='ignore')
                except Exception:
                    email_body = email_message.get_payload()
    except Exception:
        # If parsing fails, assume the input is just the email body
        email_body = email_content
        email_headers = {}
    
    # Feature extraction (as defined in the web application)
    lowercase_body = email_body.lower()
    
    # 1. Check for urgent words
    urgent_words = ['urgent', 'immediately', 'alert', 'attention', 'important', 'critical']
    contains_urgent_words = any(word in lowercase_body for word in urgent_words)
    
    # 2. Check for finance-related words
    finance_words = ['bank', 'account', 'credit', 'debit', 'payment', 'money', 'transfer', 'financial']
    contains_finance_words = any(word in lowercase_body for word in finance_words)
    
    # 3. Check for security-related words
    security_words = ['password', 'login', 'verify', 'secure', 'security', 'update', 'confirm']
    contains_security_words = any(word in lowercase_body for word in security_words)
    
    # 4. Count URLs in the email
    urls = []
    for pattern in SUSPICIOUS_URL_PATTERNS:
        urls.extend(re.findall(pattern, email_body))
    num_urls = len(urls)
    
    # 5. Check for URL text vs. href mismatches
    href_pattern = r'<a\s+(?:[^>]*?\s+)?href=(["\'])(.*?)\1'
    href_urls = re.findall(href_pattern, email_body)
    num_urls_mismatched_text = 0
    
    link_text_pattern = r'<a\s+(?:[^>]*?\s+)?href=["\'](.+?)["\'](?:[^>]*?)>(.+?)<\/a>'
    for match in re.finditer(link_text_pattern, email_body, re.IGNORECASE | re.DOTALL):
        href = match.group(1)
        text = match.group(2)
        
        # Remove HTML tags from link text
        clean_text = re.sub(r'<[^>]+>', '', text)
        
        # Check if text looks like a URL but doesn't match href
        if (re.search(r'https?://\S+', clean_text) or 
            re.search(r'www\.\S+', clean_text) or 
            re.search(r'\S+\.(com|org|net|edu|gov|co|io)\S*', clean_text)):
            
            # Simple domain comparison
            if href not in clean_text and clean_text not in href:
                num_urls_mismatched_text += 1
    
    # 6. Check if email contains HTML
    has_html = 1 if re.search(r'<html|<body|<div|<span|<table|<a\s+href', email_body, re.IGNORECASE) else 0
    
    # 7. Check for attachments
    has_attachments = 0
    if hasattr(email_message, 'is_multipart') and email_message.is_multipart():
        for part in email_message.walk():
            if part.get_content_disposition() == 'attachment':
                has_attachments = 1
                break
    
    # 8. Check for common phishing phrases
    common_phishing_phrases = [
        'verify your account', 
        'update your information',
        'confirm your details', 
        'unusual activity',
        'suspicious activity',
        'click here to',
        'your account will be suspended',
        'won a prize',
        'claim your reward',
        'access will be disabled'
    ]
    has_common_phishing_phrases = any(phrase in lowercase_body for phrase in common_phishing_phrases)
    
    # 9. Email length (normalized)
    email_length = min(len(email_body) / 5000, 1.0)  # Normalize to 0-1 range
    
    # 10. Check for potential misspellings
    words = re.findall(r'\b[a-zA-Z]{3,15}\b', email_body)
    misspelling_patterns = [
        r'[a-z]{2,}[0-9]+[a-z]*',         # Words with numbers mixed in
        r'([a-z])\1{2,}',                 # Characters repeated more than twice
        r'[aeiou]{4,}',                   # Too many consecutive vowels
        r'[bcdfghjklmnpqrstvwxyz]{5,}'    # Too many consecutive consonants
    ]
    num_misspellings = 0
    for word in words:
        if any(re.search(pattern, word.lower()) for pattern in misspelling_patterns):
            num_misspellings += 1
    num_misspellings = min(num_misspellings / 10, 1.0)  # Normalize to 0-1 range
    
    # 11. Check for IP-based URLs
    ip_url_pattern = r'https?://\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}'
    contains_ip_urls = 1 if re.search(ip_url_pattern, email_body) else 0
    
    # 12. Check for suspicious sender
    has_suspicious_sender = 0
    if 'From' in email_headers:
        sender = email_headers['From'].lower()
        suspicious_patterns = [
            r'@.*\..*\.[a-z]{2,}',        # Multiple subdomain levels
            r'@.*[0-9]{4,}',              # Numbers in domain
            r'@(?!gmail|yahoo|hotmail|outlook|aol|icloud|protonmail|mail)',  # Uncommon mail providers
            r'@.*\.(ru|cn|top|xyz|tk|ml|ga|cf)',  # Suspicious TLDs
        ]
        if any(re.search(pattern, sender) for pattern in suspicious_patterns):
            has_suspicious_sender = 1
    
    # 13. Check for requests for credentials
    credential_patterns = [
        r'enter.*password',
        r'update.*credentials',
        r'confirm.*account details',
        r'verify.*identity',
        r'login.*details',
        r'your.*pin',
        r'security.*code'
    ]
    request_for_credentials = any(re.search(pattern, lowercase_body) for pattern in credential_patterns)
    
    # 14. Email contains JavaScript
    email_contains_javascript = 1 if re.search(r'<script|javascript:', email_body, re.IGNORECASE) else 0
    
    # 15. Link domain age (placeholder feature)
    link_domain_age_fake = 1  # Constant placeholder (always "suspicious"); carries no signal until a real lookup replaces it
    
    # Combine features into a vector
    features = [
        int(contains_urgent_words),
        int(contains_finance_words),
        int(contains_security_words),
        min(num_urls / 10, 1.0),  # Normalize number of URLs
        min(num_urls_mismatched_text / 5, 1.0),  # Normalize mismatched URLs
        has_html,
        has_attachments,
        int(has_common_phishing_phrases),
        email_length,
        num_misspellings,
        contains_ip_urls,
        has_suspicious_sender,
        int(request_for_credentials),
        email_contains_javascript,
        link_domain_age_fake
    ]
    
    return features
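
One detail worth knowing about feature 4: the second pattern in `SUSPICIOUS_URL_PATTERNS` makes the scheme optional, so it also matches every URL the first pattern matches, and such URLs are counted twice. The sketch below copies both patterns and compares the raw count with a de-duplicated one. `count_urls_unique` is a hypothetical alternative, not part of the notebook; if you adopt it, the web application's extraction would need the same change.

```python
import re

# The two patterns from SUSPICIOUS_URL_PATTERNS, copied verbatim
WITH_SCHEME = r'https?://(?:[a-zA-Z0-9-]+\.)+[a-zA-Z]{2,}(?:/[^/\s]*)*'
SCHEME_OPTIONAL = r'(?:https?://)?(?:www\.)?(?:[a-zA-Z0-9](?:[a-zA-Z0-9-]*[a-zA-Z0-9])?\.)+[a-zA-Z]{2,}(?:/[^/\s]*)*'
PATTERNS = (WITH_SCHEME, SCHEME_OPTIONAL)


def count_urls_raw(text):
    """Count matches the way extract_features does: one findall per pattern."""
    return sum(len(re.findall(pattern, text)) for pattern in PATTERNS)


def count_urls_unique(text):
    """Hypothetical alternative: count each distinct matched string once."""
    found = set()
    for pattern in PATTERNS:
        found.update(re.findall(pattern, text))
    return len(found)


sample = "Visit http://example.com/login now"
print(count_urls_raw(sample), count_urls_unique(sample))
```

Because `num_urls` is later divided by 10 and clipped to 1.0, the double counting mainly means the feature saturates after about five scheme-prefixed URLs instead of ten.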

In [None]:
# Feature extraction class for the pipeline
class EmailFeatureExtractor(BaseEstimator, TransformerMixin):
    def __init__(self):
        pass
    
    def fit(self, X, y=None):
        return self
    
    def transform(self, X):
        # Extract features from each email
        # Loop variable named to avoid confusion with the imported `email` module
        features = [extract_features(content) for content in X]
        return np.array(features)

## Data Preprocessing

In [None]:
# Extract features for each email in the dataset
# Note: This process can take some time depending on the size of your dataset
print("Extracting features from emails...")
X_features = []

for email_content in df['email_content']:
    features = extract_features(email_content)
    X_features.append(features)

X = np.array(X_features)
y = df['label'].values

print(f"Features extracted. Feature matrix shape: {X.shape}")

In [None]:
# Split the data into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y)

print(f"Training set: {X_train.shape[0]} samples")
print(f"Testing set: {X_test.shape[0]} samples")
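
The `stratify=y` argument keeps the class ratio the same in the training and testing sets, which matters when phishing emails are the minority class. The sketch below re-implements the idea in plain Python purely for illustration. It is not scikit-learn's algorithm, and `stratified_split` is a hypothetical helper name.

```python
import random
from collections import defaultdict


def stratified_split(labels, test_size=0.2, seed=42):
    """Return (train_idx, test_idx) with the same class ratio in both parts."""
    rng = random.Random(seed)

    # Group sample indices by class label
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)

    # Shuffle each class separately and move the same fraction to the test set
    train_idx, test_idx = [], []
    for indices in by_class.values():
        rng.shuffle(indices)
        n_test = round(len(indices) * test_size)
        test_idx.extend(indices[:n_test])
        train_idx.extend(indices[n_test:])
    return sorted(train_idx), sorted(test_idx)


# Toy labels: 80 legitimate (0) and 20 phishing (1)
labels = [0] * 80 + [1] * 20
train_idx, test_idx = stratified_split(labels)
test_phishing_share = sum(labels[i] for i in test_idx) / len(test_idx)
print(len(train_idx), len(test_idx), test_phishing_share)
```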

## Model Training and Evaluation

In [None]:
# Train a Gradient Boosting Classifier
print("Training Gradient Boosting Classifier...")
gb_clf = GradientBoostingClassifier(n_estimators=100, learning_rate=0.1, max_depth=3, random_state=42)
gb_clf.fit(X_train, y_train)

# Make predictions on the test set
y_pred = gb_clf.predict(X_test)
y_pred_proba = gb_clf.predict_proba(X_test)

# Evaluate the model
print("\nModel Evaluation:")
print(f"Accuracy: {accuracy_score(y_test, y_pred):.4f}")
print(f"Precision: {precision_score(y_test, y_pred):.4f}")
print(f"Recall: {recall_score(y_test, y_pred):.4f}")
print(f"F1 Score: {f1_score(y_test, y_pred):.4f}")

# Display confusion matrix
cm = confusion_matrix(y_test, y_pred)
plt.figure(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=['Legitimate', 'Phishing'],
            yticklabels=['Legitimate', 'Phishing'])
plt.xlabel('Predicted')
plt.ylabel('Actual')
plt.title('Confusion Matrix')
plt.show()

# Display classification report
print("\nClassification Report:")
print(classification_report(y_test, y_pred, target_names=['Legitimate', 'Phishing']))
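
To make the reported metrics easier to read against the heatmap, here is a worked example that derives them by hand from a confusion matrix. The counts are made up, and phishing (label 1) is treated as the positive class, which is what `precision_score`, `recall_score` and `f1_score` do by default for binary labels.

```python
# Hypothetical confusion matrix in scikit-learn's layout:
# rows are actual classes, columns are predicted classes
cm = [
    [50, 10],  # actual legitimate: 50 true negatives, 10 false positives
    [5, 35],   # actual phishing:    5 false negatives, 35 true positives
]

tn, fp = cm[0]
fn, tp = cm[1]

accuracy = (tp + tn) / (tp + tn + fp + fn)  # share of all emails classified correctly
precision = tp / (tp + fp)                  # share of flagged emails that are phishing
recall = tp / (tp + fn)                     # share of phishing emails that were flagged
f1 = 2 * precision * recall / (precision + recall)

print(f"Accuracy:  {accuracy:.4f}")   # 0.8500
print(f"Precision: {precision:.4f}")  # 0.7778
print(f"Recall:    {recall:.4f}")     # 0.8750
print(f"F1 Score:  {f1:.4f}")         # 0.8235
```

For a phishing filter, low recall means phishing emails reach the inbox, while low precision means legitimate emails get flagged.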

## Feature Importance Analysis

In [None]:
# Analyze feature importance
feature_importance = gb_clf.feature_importances_

# Sort features by importance
sorted_idx = np.argsort(feature_importance)
plt.figure(figsize=(10, 8))
plt.barh(np.array(FEATURE_NAMES)[sorted_idx], feature_importance[sorted_idx])
plt.xlabel('Feature Importance')
plt.title('Feature Importance in Phishing Detection')
plt.tight_layout()
plt.show()
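
One thing to expect in this plot: `link_domain_age_fake` is always 1, and a tree cannot split on a column that holds a single value, so its importance comes out as zero. A minimal sketch of a check for such columns, using toy rows and a hypothetical `constant_columns` helper:

```python
def constant_columns(rows):
    """Return the indices of columns that hold a single distinct value."""
    if not rows:
        return []
    n_cols = len(rows[0])
    return [col for col in range(n_cols) if len({row[col] for row in rows}) == 1]


# Toy feature rows; the last column mimics the constant placeholder feature
toy_rows = [
    [1, 0.2, 1],
    [0, 0.7, 1],
    [1, 0.4, 1],
]
print(constant_columns(toy_rows))  # [2]
```

The same function can be applied to `X.tolist()` to list features that carry no information in a given dataset.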

## Hyperparameter Tuning

In [None]:
# Perform hyperparameter tuning using GridSearchCV
print("Performing hyperparameter tuning...")

# Define parameter grid
param_grid = {
    'n_estimators': [50, 100, 150],
    'learning_rate': [0.05, 0.1, 0.2],
    'max_depth': [2, 3, 4],
    'min_samples_split': [2, 5, 10],
    'min_samples_leaf': [1, 2, 4]
}

# Create GridSearchCV object
grid_search = GridSearchCV(GradientBoostingClassifier(random_state=42),
                           param_grid=param_grid,
                           cv=5,
                           scoring='f1',
                           n_jobs=-1,
                           verbose=1)

# Fit the grid search to the data
grid_search.fit(X_train, y_train)

# Print the best parameters and best score
print("\nBest parameters:")
print(grid_search.best_params_)
print(f"Best cross-validation score: {grid_search.best_score_:.4f}")

# Get the best model
best_model = grid_search.best_estimator_
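
Before running the search, it helps to estimate its cost. The grid above is exhaustive, so the number of model fits is the product of the option counts times the number of folds. A quick calculation using the same grid:

```python
from itertools import product

# Same grid as in the cell above
param_grid = {
    'n_estimators': [50, 100, 150],
    'learning_rate': [0.05, 0.1, 0.2],
    'max_depth': [2, 3, 4],
    'min_samples_split': [2, 5, 10],
    'min_samples_leaf': [1, 2, 4],
}
cv_folds = 5

# Every combination of one value per parameter is one candidate
n_candidates = len(list(product(*param_grid.values())))
total_fits = n_candidates * cv_folds

print(f"{n_candidates} candidates x {cv_folds} folds = {total_fits} fits")
```

That is 243 candidates and 1215 fits, plus one final refit of the best candidate on the full training set. If this takes too long, trimming one value from each list reduces the grid to 32 candidates.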

## Final Model Evaluation

In [None]:
# Evaluate the best model on the test set
y_pred_best = best_model.predict(X_test)
y_pred_proba_best = best_model.predict_proba(X_test)

# Print evaluation metrics
print("\nFinal Model Evaluation:")
print(f"Accuracy: {accuracy_score(y_test, y_pred_best):.4f}")
print(f"Precision: {precision_score(y_test, y_pred_best):.4f}")
print(f"Recall: {recall_score(y_test, y_pred_best):.4f}")
print(f"F1 Score: {f1_score(y_test, y_pred_best):.4f}")

# Display confusion matrix for the best model
cm_best = confusion_matrix(y_test, y_pred_best)
plt.figure(figsize=(8, 6))
sns.heatmap(cm_best, annot=True, fmt='d', cmap='Blues',
            xticklabels=['Legitimate', 'Phishing'],
            yticklabels=['Legitimate', 'Phishing'])
plt.xlabel('Predicted')
plt.ylabel('Actual')
plt.title('Confusion Matrix (Best Model)')
plt.show()

# Display classification report for the best model
print("\nClassification Report (Best Model):")
print(classification_report(y_test, y_pred_best, target_names=['Legitimate', 'Phishing']))

## Convert Model to ONNX Format

In [None]:
# Convert the best model to ONNX format
print("Converting model to ONNX format...")

# Define the input features type
initial_type = [('float_input', FloatTensorType([None, X.shape[1]]))]

# Convert the model
onnx_model = convert_sklearn(best_model, initial_types=initial_type, options={"zipmap": True})

# Save the ONNX model
onnx_model_path = 'phishing_model.onnx'
with open(onnx_model_path, 'wb') as f:
    f.write(onnx_model.SerializeToString())

print(f"ONNX model saved to {onnx_model_path}")

## Test ONNX Model

In [None]:
# Test the ONNX model with a sample input
print("Testing ONNX model...")

# Create an ONNX inference session
sess_options = rt.SessionOptions()
sess_options.graph_optimization_level = rt.GraphOptimizationLevel.ORT_ENABLE_ALL
onnx_session = rt.InferenceSession(onnx_model_path, sess_options=sess_options)

# Get input name
input_name = onnx_session.get_inputs()[0].name

# Prepare a sample input
sample_input = X_test[:5].astype(np.float32)

# Run inference with ONNX Runtime
onnx_pred = onnx_session.run(None, {input_name: sample_input})

# Compare original model predictions with ONNX model predictions
original_pred = best_model.predict(X_test[:5])
onnx_pred_labels = onnx_pred[0]

print("\nComparison of predictions:")
print("Original model predictions:", original_pred)
print("ONNX model predictions:", onnx_pred_labels)

# Check if the predictions match
is_match = np.array_equal(original_pred, onnx_pred_labels)
print(f"Predictions match: {is_match}")
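
The check above compares predicted labels only. With the `zipmap` option enabled, the second ONNX output is a list of dictionaries mapping each class label to its probability, so comparing probabilities needs a small conversion first. The sketch below uses made-up values standing in for `onnx_pred[1]` and `best_model.predict_proba(X_test[:5])`; `zipmap_to_rows` and `rows_close` are hypothetical helpers.

```python
import math


def zipmap_to_rows(zipmap_output, classes):
    """Turn [{label: prob, ...}, ...] into rows ordered like `classes`."""
    return [[sample[label] for label in classes] for sample in zipmap_output]


def rows_close(rows_a, rows_b, tol=1e-5):
    """Compare two probability tables element by element within a tolerance."""
    if len(rows_a) != len(rows_b):
        return False
    return all(
        len(row_a) == len(row_b)
        and all(math.isclose(a, b, abs_tol=tol) for a, b in zip(row_a, row_b))
        for row_a, row_b in zip(rows_a, rows_b)
    )


# Made-up stand-ins for the ONNX and scikit-learn probability outputs
onnx_probabilities = [{0: 0.9, 1: 0.1}, {0: 0.25, 1: 0.75}]
sklearn_probabilities = [[0.9000001, 0.0999999], [0.25, 0.75]]

rows = zipmap_to_rows(onnx_probabilities, classes=[0, 1])
print(rows_close(rows, sklearn_probabilities))
```

A tolerance is used rather than exact equality because the ONNX model receives `float32` inputs, so small numerical differences from the scikit-learn probabilities are expected.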

## Download the ONNX Model

In [None]:
# Download the ONNX model file
from google.colab import files
files.download(onnx_model_path)

## Conclusion

In this notebook, we've:

1. Loaded and preprocessed a phishing email dataset
2. Extracted meaningful features from emails
3. Trained a Gradient Boosting classifier
4. Performed hyperparameter tuning to optimize the model
5. Evaluated the model's performance
6. Converted the model to ONNX format for deployment
7. Tested the ONNX model to ensure it works correctly

The model is now ready to be integrated into the web application for phishing email detection.