### **Robust Hybrid Phishing Detection Trainer (Live Datasets) v7**

This notebook implements the complete training pipeline using dynamic, real-world data sources. It downloads the latest phishing URLs from the `Phishing.Database` project and uses the Tranco Top 1 Million list for benign examples.

**Key Features & Changes in this Version:**
1.  **Fully Unified Feature Extraction:** Feature extraction follows the same steps as the live backend: it strips tracking parameters (`fbclid` and `utm_*`), resolves known URL shorteners, and only then extracts features, so training and inference see URLs processed the same way.
2.  **Full Data Utilization:** All undersampling has been removed. The model now trains on all available public and user-reported data for maximum learning.
3.  **Live Data Sources:** Ingests data directly from the Phishing.Database project and the Tranco list, ensuring the model is trained on current threats.
4.  **Full Hybrid System:** Implements the complete AE + GCN training and evaluation pipeline using a robust cross-validation framework.

## 1. Setup and Installations ⚙️
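This first code block prepares the environment by installing the libraries the project needs: pandas and scikit-learn for data handling and evaluation, TensorFlow for the Autoencoder, PyTorch and PyTorch Geometric for the GCN, `tldextract` for domain parsing, `google-cloud-firestore` for fetching user reports, matplotlib and seaborn for plots, and `requests` for resolving URL shorteners.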
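Because pip names and import names differ for some of these packages (`scikit-learn` is imported as `sklearn`, for example), a quick check after installation can save a confusing `ImportError` later. The helper below is a hypothetical sketch, not part of the original notebook. It asks Python's import system whether each module can be located; for a dotted name such as `google.cloud.firestore`, Python does import the parent packages while looking.

```python
import importlib.util

# Maps the pip names from the install cell to the module each one provides.
PIP_TO_MODULE = {
    "pandas": "pandas",
    "scikit-learn": "sklearn",
    "tensorflow": "tensorflow",
    "tldextract": "tldextract",
    "google-cloud-firestore": "google.cloud.firestore",
    "matplotlib": "matplotlib",
    "seaborn": "seaborn",
    "requests": "requests",
    "torch": "torch",
    "torch_geometric": "torch_geometric",
}


def missing_packages(mapping: dict) -> list:
    """Return the pip names whose import module cannot be located."""
    missing = []
    for pip_name, module_name in mapping.items():
        try:
            found = importlib.util.find_spec(module_name) is not None
        except ModuleNotFoundError:
            # Raised for dotted names whose parent package (e.g. "google") is absent.
            found = False
        except ValueError:
            # Raised if the module is already in sys.modules with no __spec__.
            found = False
        if not found:
            missing.append(pip_name)
    return missing


print("Missing packages:", missing_packages(PIP_TO_MODULE) or "none")
```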

In [None]:
# @title
# Install core data science, deep learning, and cloud libraries
!pip -q install pandas scikit-learn tensorflow tldextract google-cloud-firestore matplotlib seaborn requests

# Install PyTorch and PyTorch Geometric for the GNN
!pip -q install torch torch_geometric

## 2. Imports and Initial Configuration 📚
This cell **imports** the functions and classes used throughout the notebook, including `urllib.parse` for cleaning URLs. It also sets a **random seed** for Python, NumPy, TensorFlow, and PyTorch so that reruns draw the same random numbers. Results can still vary between runs, because the downloaded datasets change over time and some GPU operations are nondeterministic.
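Each of the four seeding calls in the cell below covers a different random number generator. As a minimal sketch, a hypothetical `seed_everything` helper can collect such calls in one place; this version seeds only the standard library and, if it is installed, NumPy. Recent TensorFlow releases also provide `tf.keras.utils.set_random_seed`, which seeds Python, NumPy, and TensorFlow in one call, while PyTorch still needs its own `torch.manual_seed`.

```python
import importlib.util
import os
import random


def seed_everything(seed: int = 42) -> None:
    """Hypothetical helper: seed the stdlib RNG and, if available, NumPy."""
    random.seed(seed)
    # PYTHONHASHSEED only affects interpreters started after this point
    # (e.g. worker subprocesses); it cannot change hashing in this process.
    os.environ["PYTHONHASHSEED"] = str(seed)
    if importlib.util.find_spec("numpy") is not None:
        import numpy as np
        np.random.seed(seed)


seed_everything(42)
first = [random.random() for _ in range(3)]
seed_everything(42)
second = [random.random() for _ in range(3)]
print(first == second)  # True: same seed, same stdlib draws
```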

In [None]:
# @title
import os
import re
import json
import random
import pickle
import requests
import numpy as np
import pandas as pd
import tldextract
from glob import glob
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
from urllib.parse import urlparse, urlunparse, parse_qsl, urlencode

# TensorFlow for Autoencoder
import tensorflow as tf
from tensorflow import keras

# PyTorch for GCN
import torch
import torch.nn.functional as F
from torch import nn, optim
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv

# Scikit-learn for preprocessing, cross-validation, and metrics
from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.preprocessing import StandardScaler, RobustScaler
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    confusion_matrix, roc_auc_score, precision_recall_curve, auc
)

# Google Cloud for data fetching
from google.colab import files
from google.cloud import firestore
# Query supplies the sort-direction constants (e.g. Query.DESCENDING) used in section 7
from google.cloud.firestore_v1 import Query

# --- Reproducibility ---
SEED = 42
np.random.seed(SEED)
random.seed(SEED)
tf.random.set_seed(SEED)
torch.manual_seed(SEED)

print("Libraries imported and seed set.")

## 3. Fetch Live Phishing and Benign URL Datasets ☁️
This cell downloads the latest data files. For phishing examples, it gets a list of currently active phishing URLs from `Phishing.Database`. For safe examples, it downloads the Tranco Top 1 Million list, which contains the most popular (and generally safe) websites on the internet.
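The cell below relies on the `wget` and `unzip` shell commands. Where those are unavailable, the Tranco archive can also be read directly from memory. This is a minimal stdlib sketch, assuming the archive contains a single headerless `top-1m.csv` with `rank,domain` rows (the same layout the parsing cell relies on). `read_tranco_zip` is a hypothetical name, and the demo builds a tiny archive in memory instead of downloading the real list.

```python
import csv
import io
import zipfile


def read_tranco_zip(zip_bytes: bytes, member: str = "top-1m.csv") -> list:
    """Parse a Tranco-style archive: headerless 'rank,domain' rows."""
    with zipfile.ZipFile(io.BytesIO(zip_bytes)) as zf:
        with zf.open(member) as raw:
            text = io.TextIOWrapper(raw, encoding="utf-8", newline="")
            # Skip blank lines and any row that does not have exactly two columns.
            return [(int(row[0]), row[1]) for row in csv.reader(text) if len(row) == 2]


# Demo: build a tiny archive in memory rather than downloading the real list.
buf = io.BytesIO()
with zipfile.ZipFile(buf, "w") as zf:
    zf.writestr("top-1m.csv", "1,google.com\n2,facebook.com\n")
rows = read_tranco_zip(buf.getvalue())
print(rows)  # [(1, 'google.com'), (2, 'facebook.com')]
```

In the notebook, the bytes could come from `requests.get("https://tranco-list.eu/top-1m.csv.zip", timeout=60).content`.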

In [None]:
# @title
# Fetch Phishing URLs by direct download
print("Downloading latest active phishing links...")
!wget https://phish.co.za/latest/phishing-links-ACTIVE.txt -O phishing-links-ACTIVE.txt

# Fetch Benign URLs (Tranco Top 1 Million)
print("\nDownloading Tranco Top 1 Million list...")
!wget https://tranco-list.eu/top-1m.csv.zip -O top-1m.csv.zip
!unzip -o top-1m.csv.zip

print("\n✅ Datasets downloaded.")

## 4. Parse and Combine Datasets 📋
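This code reads the raw data from the downloaded files and combines it into a single labeled dataset (1 = phishing, 0 = benign). Benign examples are built as `http://<domain>` from the Tranco domains, so, unlike the phishing URLs, they carry no path or query string. All available data is used without undersampling, giving the model as much information as possible for pre-training.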
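The combining logic amounts to labeling, deduplicating, and shuffling. The pure-Python sketch below restates it on plain lists so the steps are easy to test. `build_labeled_dataset` is a hypothetical helper; like the notebook, it keeps a URL that appears in both sources under both labels.

```python
import random


def build_labeled_dataset(phish_lines, tranco_domains, seed=42):
    """Return shuffled (url, label) pairs, with 1 = phishing and 0 = benign."""
    # Strip whitespace, drop empty lines, and deduplicate the phishing URLs.
    phish = {line.strip() for line in phish_lines if line.strip()}
    # Benign URLs follow the notebook: scheme plus bare domain, no path.
    benign = {"http://" + domain for domain in tranco_domains}
    # Sort before shuffling so the result depends only on the inputs and the seed.
    rows = [(url, 1) for url in sorted(phish)] + [(url, 0) for url in sorted(benign)]
    random.Random(seed).shuffle(rows)
    return rows


rows = build_labeled_dataset(
    ["http://evil.example/login\n", "\n", "http://evil.example/login"],
    ["google.com", "wikipedia.org"],
)
print(len(rows), sum(label for _, label in rows))  # 3 1
```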

In [None]:
# @title
# Parse Phishing URLs from the downloaded text file
try:
    with open('phishing-links-ACTIVE.txt', 'r', encoding='utf-8') as f:
        phishing_urls = [line.strip() for line in f if line.strip()]
    df_phish = pd.DataFrame(phishing_urls, columns=['url'])
    df_phish['label'] = 1
    df_phish.drop_duplicates(inplace=True)
    print(f"Parsed {len(df_phish)} unique active phishing URLs.")
except FileNotFoundError:
    print("phishing-links-ACTIVE.txt not found. Creating an empty phishing dataframe.")
    df_phish = pd.DataFrame(columns=['url', 'label'])

# Parse Benign URLs from Tranco list
df_benign_full = pd.read_csv('top-1m.csv', names=['rank', 'domain'])
df_benign_full['url'] = 'http://' + df_benign_full['domain']
df_benign_full['label'] = 0
print(f"Parsed {len(df_benign_full)} benign domains from Tranco list.")

# Combine all available public data without undersampling
df_public_raw = pd.concat([df_phish, df_benign_full[['url', 'label']]], ignore_index=True)
df_public_raw = df_public_raw.sample(frac=1, random_state=SEED).reset_index(drop=True)

print(f"\nCreated a combined dataset with {len(df_public_raw)} total URLs.")
print(df_public_raw['label'].value_counts())

## 5. Feature Engineering and Scaling 📏
This computationally intensive step **turns each URL into a vector of numbers** that a model can use. It applies the same steps as the live backend: it first **strips tracking parameters** (`fbclid` and `utm_*`), then **resolves known URL shorteners** with an HTTP `HEAD` request (2-second timeout), and finally extracts the lexical features from the resulting URL. Only URLs on the shortener list trigger a network request.
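The first of those steps, and the `domain_in_ip` flag, can be sketched with the standard library alone. `normalize_url` below follows the same filtering rule as step 1 of `get_lexical_features` (drop `fbclid` and any `utm_*` key, keep everything else). `host_is_ip` is a hypothetical alternative to the notebook's IPv4 regex: it applies the `ipaddress` module to the parsed hostname, which also recognizes IPv6 literals. Swapping it in would change the feature slightly, so the live backend would need the same change to stay consistent.

```python
import ipaddress
from urllib.parse import parse_qsl, urlencode, urlparse, urlunparse


def normalize_url(url: str) -> str:
    """Drop fbclid and utm_* query parameters, keeping all others in order."""
    parts = urlparse(url)
    if not parts.query:
        return url
    kept = [
        (k, v)
        for k, v in parse_qsl(parts.query, keep_blank_values=True)
        if k.lower() != "fbclid" and not k.lower().startswith("utm_")
    ]
    # urlunparse omits the "?" entirely when the filtered query is empty.
    return urlunparse(parts._replace(query=urlencode(kept)))


def host_is_ip(url: str) -> bool:
    """Hypothetical check: True when the URL's host is a literal IPv4/IPv6 address."""
    host = urlparse(url).hostname or ""
    try:
        ipaddress.ip_address(host)
        return True
    except ValueError:
        return False


print(normalize_url("https://ex.com/p?a=1&utm_source=x&fbclid=y"))  # https://ex.com/p?a=1
print(host_is_ip("http://192.168.0.1:8080/login"))  # True
print(host_is_ip("http://bit.ly/abc"))  # False
```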

In [None]:
# @title
# Define the list of lexical features we will extract.
LEXICAL_FEATURE_COLUMNS = [
    'qty_dot_url', 'qty_hyphen_url', 'qty_underline_url', 'qty_slash_url',
    'qty_questionmark_url', 'qty_equal_url', 'qty_at_url', 'qty_and_url',
    'qty_exclamation_url', 'qty_space_url', 'qty_tilde_url', 'qty_comma_url',
    'qty_plus_url', 'qty_asterisk_url', 'qty_hashtag_url', 'qty_dollar_url',
    'qty_percent_url', 'qty_dot_domain', 'qty_hyphen_domain', 'qty_underline_domain',
    'qty_at_domain', 'qty_vowels_domain', 'domain_length', 'domain_in_ip',
    'server_client_domain', 'url_shortened'
]

# Using a session object for connection pooling is more efficient for many requests
session = requests.Session()

def get_lexical_features(url: str) -> dict:
    """
    Extracts lexical features from a URL, including cleaning tracking parameters 
    and resolving shorteners to match the backend logic.
    """
    features = {}
    raw_url = str(url)
    
    try:
        # 1. NORMALIZE URL: Remove common tracking query params (logic from app.py)
        normalized_url = raw_url
        try:
            p = urlparse(raw_url)
            if p.query:
                q = parse_qsl(p.query, keep_blank_values=True)
                q_filtered = [(k, v) for k, v in q if not (k.lower() == 'fbclid' or k.lower().startswith('utm_'))]
                new_query = urlencode(q_filtered)
                p = p._replace(query=new_query)
                normalized_url = urlunparse(p)
        except Exception:
            normalized_url = raw_url # Fallback to original if parsing fails

        # 2. RESOLVE SHORTENERS: Use the normalized URL for this step
        final_url = normalized_url
        shorteners = {"bit.ly","tinyurl.com","t.co","goo.gl","ow.ly","is.gd","cutt.ly","lnkd.in","buff.ly"}
        ext_original = tldextract.extract(normalized_url)
        original_domain_suffix = f"{ext_original.domain}.{ext_original.suffix}"
        is_shortened = 1 if original_domain_suffix in shorteners else 0
        features['url_shortened'] = is_shortened

        if is_shortened:
            try:
                res = session.head(normalized_url, allow_redirects=True, timeout=2)
                final_url = res.url
            except requests.exceptions.RequestException:
                final_url = normalized_url # Fallback if resolution fails

        # 3. EXTRACT FEATURES: Use the final, processed URL
        ext = tldextract.extract(final_url)
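        # tldextract puts an IP host in `domain` with an empty `suffix`; join only the
        # non-empty parts so IPs do not gain a trailing dot and fail the check below.
        domain = ".".join(part for part in (ext.domain, ext.suffix) if part)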

        features['qty_dot_url'] = final_url.count('.')
        features['qty_hyphen_url'] = final_url.count('-')
        features['qty_underline_url'] = final_url.count('_')
        features['qty_slash_url'] = final_url.count('/')
        features['qty_questionmark_url'] = final_url.count('?')
        features['qty_equal_url'] = final_url.count('=')
        features['qty_at_url'] = final_url.count('@')
        features['qty_and_url'] = final_url.count('&')
        features['qty_exclamation_url'] = final_url.count('!')
        features['qty_space_url'] = final_url.count(' ')
        features['qty_tilde_url'] = final_url.count('~')
        features['qty_comma_url'] = final_url.count(',')
        features['qty_plus_url'] = final_url.count('+')
        features['qty_asterisk_url'] = final_url.count('*')
        features['qty_hashtag_url'] = final_url.count('#')
        features['qty_dollar_url'] = final_url.count('$')
        features['qty_percent_url'] = final_url.count('%')
        features['qty_dot_domain'] = domain.count('.')
        features['qty_hyphen_domain'] = domain.count('-')
        features['qty_underline_domain'] = domain.count('_')
        features['qty_at_domain'] = domain.count('@')
        features['qty_vowels_domain'] = sum(1 for char in domain if char in 'aeiouAEIOU')
        features['domain_length'] = len(domain)
        features['domain_in_ip'] = 1 if re.fullmatch(r"\d+\.\d+\.\d+\.\d+", domain) else 0
        features['server_client_domain'] = 1 if 'server' in domain or 'client' in domain else 0

    except Exception:
        # Return a dictionary of zeros if any error occurs
        return {col: 0 for col in LEXICAL_FEATURE_COLUMNS}
    return features

def create_feature_vector(url: str, all_feature_columns: list) -> np.ndarray:
    """Creates a full feature vector for a URL, filling missing values with 0."""
    lexical_feats = get_lexical_features(url)
    feature_vector = np.array([lexical_feats.get(col, 0.0) for col in all_feature_columns], dtype=np.float32)
    return feature_vector

print(f"Extracting {len(LEXICAL_FEATURE_COLUMNS)} lexical features from {len(df_public_raw)} URLs...")
tqdm.pandas(desc="Feature Extraction")
feature_vectors = df_public_raw['url'].progress_apply(lambda url: create_feature_vector(url, LEXICAL_FEATURE_COLUMNS))

# Create the final feature matrix and label vector
X_public = np.vstack(feature_vectors.values)
y_public = df_public_raw['label'].values

print(f"\nFeature extraction complete. X_public shape: {X_public.shape}")

# --- Feature Scaling ---
print("\nFitting RobustScaler on benign public training data...")
X_public_train_df = pd.DataFrame(X_public, columns=LEXICAL_FEATURE_COLUMNS)
X_public_train, X_public_test, y_public_train, y_public_test = train_test_split(
    X_public_train_df, y_public, test_size=0.3, random_state=SEED, stratify=y_public
)
scaler = RobustScaler().fit(X_public_train[y_public_train == 0])

with open("scaler.pkl", "wb") as f:
    pickle.dump(scaler, f)
print("✅ Scaler fitted and saved.")
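The scaler is fitted on benign rows only, so the center and spread of every feature are defined by normal traffic and phishing URLs land far from zero. The sketch below is a hypothetical stdlib stand-in for `RobustScaler` that shows what gets learned: the per-column median (`center_`) and interquartile range (`scale_`), with a zero range replaced by 1.0 as scikit-learn does. It expects at least two benign rows.

```python
import statistics


class TinyRobustScaler:
    """Hypothetical stdlib stand-in for sklearn's RobustScaler (median / IQR)."""

    def fit(self, rows):
        columns = list(zip(*rows))
        self.center_ = [statistics.median(col) for col in columns]
        self.scale_ = []
        for col in columns:
            # "inclusive" matches NumPy's default linear percentile interpolation.
            q1, _, q3 = statistics.quantiles(col, n=4, method="inclusive")
            iqr = q3 - q1
            # Like scikit-learn, use 1.0 when a column has zero spread.
            self.scale_.append(iqr if iqr != 0 else 1.0)
        return self

    def transform(self, rows):
        return [
            [(x - c) / s for x, c, s in zip(row, self.center_, self.scale_)]
            for row in rows
        ]


X = [[2, 0], [3, 0], [4, 0], [5, 0], [40, 3]]
y = [0, 0, 0, 0, 1]
benign_rows = [row for row, label in zip(X, y) if label == 0]
scaler = TinyRobustScaler().fit(benign_rows)
print(scaler.center_, scaler.scale_)  # [3.5, 0.0] [1.5, 1.0]
print(scaler.transform([[40, 3]]))    # the phishing row lands far from zero
```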

## 6. Define and Pre-train the Autoencoder Model 🧠
This cell builds and trains the first of the two main models: the **Autoencoder (AE)**. It is trained to reconstruct its input using **only the benign (safe) links** from the public training split, so it learns what the features of a normal URL look like. URLs that deviate from that pattern are reconstructed poorly, and the resulting **reconstruction error** serves as the anomaly score.
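The anomaly score used later in section 9 is the per-row mean squared error between a scaled feature vector and its reconstruction, `np.mean(np.square(X - X_hat), axis=1)`. The sketch below computes the same quantity in plain Python. To stay self-contained, it replaces the trained network with a deliberately crude stand-in that "reconstructs" every input as the mean of the benign rows; only the scoring logic is the point here.

```python
def reconstruction_errors(rows, reconstructions):
    """Per-row mean squared error between inputs and their reconstructions."""
    return [
        sum((a - b) ** 2 for a, b in zip(row, rec)) / len(row)
        for row, rec in zip(rows, reconstructions)
    ]


# Crude stand-in for a trained AE: reconstruct every input as the benign mean.
benign = [[0.0, 0.1], [0.1, 0.0], [-0.1, 0.0]]
centroid = [sum(col) / len(col) for col in zip(*benign)]

candidates = [[0.05, 0.05], [3.0, 2.5]]  # a benign-looking row, then an outlier
errors = reconstruction_errors(candidates, [centroid] * len(candidates))
print(errors[1] > errors[0])  # True: the outlier gets the larger error
```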

In [None]:
# @title
# Autoencoder hyperparameters
AE_LAYER1 = 16 # Adjusted for smaller feature set
AE_LAYER2 = 8
AE_BOTTLENECK = 4
AE_DROPOUT = 0.1
LR_PRETRAIN = 1e-3
BATCH_PRETRAIN = 512
EPOCHS_PRETRAIN = 30

def build_autoencoder(input_shape):
    """Builds the Keras Autoencoder model."""
    inp = keras.Input(shape=(input_shape,))
    x = keras.layers.Dense(AE_LAYER1, activation='relu')(inp)
    x = keras.layers.Dropout(AE_DROPOUT)(x)
    x = keras.layers.Dense(AE_LAYER2, activation='relu')(x)
    z = keras.layers.Dense(AE_BOTTLENECK, activation='relu', name='bottleneck')(x)
    x = keras.layers.Dense(AE_LAYER2, activation='relu')(z)
    x = keras.layers.Dense(AE_LAYER1, activation='relu')(x)
    out = keras.layers.Dense(input_shape)(x)
    model = keras.Model(inp, out)
    return model

# Build and compile the model for pre-training
pretrain_ae_model = build_autoencoder(X_public.shape[1])
pretrain_ae_model.compile(optimizer=keras.optimizers.Adam(LR_PRETRAIN), loss='mse')

print('--- Pre-training Autoencoder on New Public Dataset ---')
X_public_train_scaled = scaler.transform(X_public_train)
X_benign_public_train = X_public_train_scaled[y_public_train == 0]

history = pretrain_ae_model.fit(
    X_benign_public_train, X_benign_public_train,
    epochs=EPOCHS_PRETRAIN,
    batch_size=BATCH_PRETRAIN,
    shuffle=True,
    verbose=1,
    validation_split=0.2,
    callbacks=[keras.callbacks.EarlyStopping(patience=5, monitor='val_loss', restore_best_weights=True)]
)

print('\n✅ Pre-training complete.')
pretrain_ae_model.save_weights('pretrained_ae_weights.weights.h5')

## 7. Fetch and Process User-Reported Data 📥
This cell connects to Firestore and **pulls down all the data** the browser extension has collected: user feedback (links marked as phishing or safe) and the graph data linking users, posts, and domains. It keeps only complete reports, labels `true_positive` and `false_negative` reports as phishing (1) and all other report types as benign (0), and removes duplicate URL-post pairs before organizing the result into a table for the next steps.
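The report-processing rules can be restated as a small pure-Python function, which makes them testable without a Firestore connection. `reports_to_rows` is a hypothetical helper operating on the dictionaries that `doc.to_dict()` returns. Like the notebook, it assumes reports arrive newest first, so the row kept for each URL-post pair is the most recent one; unlike the notebook, it also tolerates a `payload` that is `None`.

```python
PHISHING_TYPES = {"true_positive", "false_negative"}


def reports_to_rows(report_dicts):
    """Turn raw report dicts into deduplicated rows with a 0/1 label."""
    seen, rows = set(), []
    for d in report_dicts:
        payload = d.get("payload") or {}
        # Prefer 'url'; otherwise fall back to the first entry of 'links'.
        url = payload.get("url")
        if not url:
            links = payload.get("links")
            if isinstance(links, list) and links:
                url = links[0]
        post_id, report_type = payload.get("postId"), d.get("type")
        if not (url and post_id and report_type):
            continue  # incomplete report
        key = (url, post_id)
        if key in seen:
            continue  # keep only the first (newest) report per URL-post pair
        seen.add(key)
        rows.append({"url": url, "postId": post_id,
                     "label": int(report_type in PHISHING_TYPES)})
    return rows


sample = [
    {"type": "false_negative", "payload": {"links": ["http://x.test/a"], "postId": "p1"}},
    {"type": "false_positive", "payload": {"url": "http://x.test/a", "postId": "p1"}},
    {"type": "true_negative", "payload": {"url": "http://ok.test", "postId": "p2"}},
    {"type": "true_positive", "payload": {"postId": "p3"}},
]
print(reports_to_rows(sample))
```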

In [None]:
# @title
def fetch_user_data(db_client):
    """Fetches and processes user reports and graph data from Firestore."""
    if not db_client:
        print("Firestore client not initialized. Skipping data fetch.")
        return pd.DataFrame(), [], []

    APP_ID = "ads-phishing-link"
    REPORTS_PATH = f"artifacts/{APP_ID}/private_user_reports"
    NODES_PATH = f"artifacts/{APP_ID}/private/graph/nodes"
    EDGES_PATH = f"artifacts/{APP_ID}/private/graph/edges"

    # Fetch Reports
    print(f"Fetching user reports from: {REPORTS_PATH}...")
    try:
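        # Newest first, so drop_duplicates below keeps the most recent report for each URL-post pair.
        # Note that Firestore's order_by also excludes any document without a "timestamp" field.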
        reports_query = db_client.collection(REPORTS_PATH).order_by("timestamp", direction=Query.DESCENDING)
        report_docs = list(reports_query.stream())
        print(f"Found {len(report_docs)} total user reports.")
    except Exception as e:
        print(f"❌ Error fetching reports: {e}")
        report_docs = []

    # Fetch Graph Data
    print("Fetching graph data...")
    try:
        node_docs = list(db_client.collection(NODES_PATH).stream())
        edge_docs = list(db_client.collection(EDGES_PATH).stream())
        print(f"Found {len(node_docs)} nodes and {len(edge_docs)} edges.")
    except Exception as e:
        print(f"❌ Error fetching graph data: {e}")
        node_docs, edge_docs = [], []

    # Process reports into a DataFrame
    processed_reports = []
    for doc in report_docs:
        d = doc.to_dict()
        payload = d.get('payload', {})
        # Reports carry either a single 'url' or a 'links' list; fall back to the first link
        url = payload.get('url')
        if not url:
            links = payload.get('links')
            if isinstance(links, list) and len(links) > 0:
                url = links[0]

        report = {
            'url': url,
            'postId': payload.get('postId'),
            'type': d.get('type')
        }
        if not all(report.values()):
            continue

        report['label'] = 1 if report['type'] in ('true_positive', 'false_negative') else 0
        processed_reports.append(report)

    if not processed_reports:
        print("No valid reports found.")
        return pd.DataFrame(), [], []

    df_user = pd.DataFrame(processed_reports).drop_duplicates(subset=['url', 'postId']).reset_index(drop=True)
    print(f"Created a dataset of {len(df_user)} unique user-reported URL-post pairs.")
    print("User data label distribution:")
    print(df_user['label'].value_counts())

    return df_user, node_docs, edge_docs

# --- Firebase Authentication (if not already done) ---
if 'db' not in locals() or db is None:
    print("\nPlease upload your Firebase service account JSON key file.")
    try:
        uploaded = files.upload()
        sa_path = next(iter(uploaded.keys()))
        os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = sa_path
        db = firestore.Client()
        print("\n✅ Firebase authentication configured.")
    except Exception as e:
        print(f"\n❌ Firebase authentication failed: {e}")
        db = None

# Fetch all user data
df_user_reports, node_docs, edge_docs = fetch_user_data(db)

## 8. Prepare Data for Cross-Validation and Fine-Tuning 📋
This code takes the cleaned user data from the previous step and **prepares it for the main training phase**. It uses the same feature engineering logic from section 5 to convert all user-reported URLs into numerical lists, ensuring they are in the exact same format as the public data.

In [None]:
# @title
if not df_user_reports.empty:
    # Create feature vectors for all user-reported URLs
    # Shortened URLs are resolved over the network here, exactly as in section 5
    tqdm.pandas(desc="Feature Extraction (User Data)")
    user_feature_vectors = df_user_reports['url'].progress_apply(lambda url: create_feature_vector(url, LEXICAL_FEATURE_COLUMNS))
    X_user = np.vstack(user_feature_vectors.values)

    y_user = df_user_reports['label'].values
    post_ids_user = df_user_reports['postId'].values

    print(f"User data prepared for cross-validation: X_user shape {X_user.shape}")
else:
    print("User reports DataFrame is empty. Cannot proceed.")
    X_user, y_user, post_ids_user = [np.array([]) for _ in range(3)]
    node_docs, edge_docs = [], []

## 9. The Cross-Validation and Fine-Tuning Loop 🔁
This is the most important part of the notebook. It runs 5-fold stratified cross-validation over the user-reported data to **train and test the models in a robust way**. In each round:

- It **fine-tunes the Autoencoder (AE)** on the benign links in that round's user training data.
- It **trains the Graph (GCN) model** on that round's labeled training posts, using class weights to handle class imbalance.
- It scores the held-out data with both models and searches for the **fusion weight** that gives the highest F1-score.

Averaging over 5 different train/validation splits makes the reported scores less dependent on any single split.
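The fusion step can be illustrated in isolation. Following the notebook's rule, the AE error is mapped to a pseudo-probability `min(error / (2 * threshold), 1)`, so an error equal to the threshold maps to 0.5. This value is blended with the GCN probability as `w * ae_prob + (1 - w) * gcn_prob`, and `w` is chosen from 21 evenly spaced values by F1-score. This is a plain-Python sketch with hypothetical function names; the demo data are made up so that the AE separates the classes and the GCN does not.

```python
def f1(y_true, y_pred):
    """Binary F1 = 2TP / (2TP + FP + FN); 0.0 when there are no true positives."""
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 1)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 0)
    return 0.0 if tp == 0 else 2 * tp / (2 * tp + fp + fn)


def ae_to_prob(error, threshold):
    """Map an AE error onto [0, 1] so that error == threshold gives about 0.5."""
    return min(error / (2 * threshold + 1e-9), 1.0)


def best_fusion_weight(ae_errors, gcn_probs, labels, threshold, steps=21):
    """Grid-search w in [0, 1] for fused = w * AE prob + (1 - w) * GCN prob."""
    ae_probs = [ae_to_prob(e, threshold) for e in ae_errors]
    best_w, best_score = 0.5, -1.0
    for i in range(steps):
        w = i / (steps - 1)
        preds = [int(w * a + (1 - w) * g > 0.5) for a, g in zip(ae_probs, gcn_probs)]
        score = f1(labels, preds)
        if score > best_score:  # strict ">" keeps the smallest best w
            best_w, best_score = w, score
    return best_w, best_score


labels = [1, 1, 0, 0]
ae_errors = [0.9, 0.8, 0.1, 0.2]   # AE ranks the posts correctly
gcn_probs = [0.2, 0.3, 0.7, 0.6]   # GCN ranks them backwards
print(best_fusion_weight(ae_errors, gcn_probs, labels, threshold=0.4))  # (0.4, 1.0)
```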

In [None]:
# @title
# --- GCN Model Definition ---
class GCN(nn.Module):
    def __init__(self, num_features, num_classes=2):
        super().__init__()
        self.conv1 = GCNConv(num_features, 16)
        self.conv2 = GCNConv(16, num_classes)
    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = F.relu(self.conv1(x, edge_index))
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)

if len(X_user) > 0:
    # --- Cross-Validation Setup ---
    # Group folds by postId so all URLs from one post land in the same fold;
    # otherwise a post's label can leak from the training split into validation.
    from sklearn.model_selection import StratifiedGroupKFold
    N_SPLITS = 5
    skf = StratifiedGroupKFold(n_splits=N_SPLITS, shuffle=True, random_state=SEED)

    # --- Hyperparameters ---
    LR_FINETUNE = 1e-5
    BATCH_FINETUNE = 16
    EPOCHS_FINETUNE = 50
    EPOCHS_GCN = 50

    # --- Storage for results across folds ---
    fold_results = []
    best_thresholds = []
    best_fusion_weights = []
    all_confusion_matrices = []

    # --- Build the full graph structure once ---
    node_key_to_idx, idx_to_key, x_rows = {}, [], []
    def add_node(key, doc):
        if key in node_key_to_idx: return
        idx = len(idx_to_key)
        node_key_to_idx[key] = idx
        idx_to_key.append(key)
        node_type = (doc or {}).get("type")
        feats = [1.0, 0.0, 0.0] if node_type == "user" else [0.0, 1.0, 0.0] if node_type == "domain" else [0.0, 0.0, 1.0]
        x_rows.append(feats)

    for d in node_docs: add_node(d.id, d.to_dict())
    edges = []
    for d in edge_docs:
        e = d.to_dict()
        src, dst = e.get("src"), e.get("dst")
        if src and dst and src in node_key_to_idx and dst in node_key_to_idx:
            edges.append([node_key_to_idx[src], node_key_to_idx[dst]])
            edges.append([node_key_to_idx[dst], node_key_to_idx[src]]) # Undirected

    x_graph = torch.tensor(np.array(x_rows, dtype=np.float32)) if x_rows else torch.empty((0, 3))
    edge_index_graph = torch.tensor(np.array(edges, dtype=np.int64)).t().contiguous() if edges else torch.empty((2, 0), dtype=torch.long)

    # Map post IDs to their corresponding node indices
    post_id_to_node_idx = {pid: node_key_to_idx.get(f"post:{pid}") for pid in np.unique(post_ids_user)}

    # --- The Loop ---
    for fold_idx, (train_indices, val_indices) in enumerate(skf.split(X_user, y_user, groups=post_ids_user)):
        print(f"\n--- Starting Fold {fold_idx + 1}/{N_SPLITS} ---")

        # 1. Split data for this fold
        X_train_fold, y_train_fold = X_user[train_indices], y_user[train_indices]
        X_val_fold, y_val_fold = X_user[val_indices], y_user[val_indices]
        posts_train_fold, posts_val_fold = post_ids_user[train_indices], post_ids_user[val_indices]

        # ===================================
        #     AUTOENCODER TRAINING
        # ===================================

        # 2. Fine-tune the AE only on the benign rows of this fold's user training set
        X_benign_finetune = X_train_fold[y_train_fold == 0]
        X_benign_finetune_scaled = scaler.transform(X_benign_finetune)
        X_val_scaled = scaler.transform(X_val_fold)

        # 3. Build AE and load pre-trained weights
        fine_tune_ae = build_autoencoder(X_user.shape[1])
        fine_tune_ae.load_weights('pretrained_ae_weights.weights.h5')
        fine_tune_ae.compile(optimizer=keras.optimizers.Adam(LR_FINETUNE), loss='mse')

        # 4. Fine-tune AE
        if len(X_benign_finetune_scaled) > 0:
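            # No validation data is passed here, so early stopping must watch the training loss
            fine_tune_ae.fit(X_benign_finetune_scaled, X_benign_finetune_scaled, epochs=EPOCHS_FINETUNE, batch_size=BATCH_FINETUNE, verbose=0, callbacks=[keras.callbacks.EarlyStopping(monitor='loss', patience=5)])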

        # 5. Get AE reconstruction errors for the validation set
        val_reconstructed = fine_tune_ae.predict(X_val_scaled, verbose=0)
        val_errors_ae = np.mean(np.square(X_val_scaled - val_reconstructed), axis=1)

        # 6. Find best AE threshold on validation set
        precision, recall, thresholds = precision_recall_curve(y_val_fold, val_errors_ae)
        f1_scores = 2 * (precision * recall) / (precision + recall + 1e-9)
        # The last precision/recall pair has no matching threshold, so exclude it
        best_ae_threshold = thresholds[np.argmax(f1_scores[:-1])]
        best_thresholds.append(best_ae_threshold)
        print(f"Best AE Threshold found: {best_ae_threshold:.4f}")

        # ===================================
        #         GCN TRAINING
        # ===================================

        # 7. Prepare graph data for this fold
        y_nodes = torch.full((len(idx_to_key),), -1, dtype=torch.long)
        for post_id, label in zip(post_ids_user, y_user):
            node_idx = post_id_to_node_idx.get(post_id)
            if node_idx is not None:
                y_nodes[node_idx] = label

        train_node_indices = [post_id_to_node_idx.get(pid) for pid in posts_train_fold if post_id_to_node_idx.get(pid) is not None]

        train_mask = torch.zeros(len(idx_to_key), dtype=torch.bool)
        if train_node_indices:
            train_mask[torch.tensor(train_node_indices)] = True

        graph_data = Data(x=x_graph, edge_index=edge_index_graph, y=y_nodes, train_mask=train_mask)

        # 8. Train GCN
        gcn_model = GCN(num_features=graph_data.x.size(1))
        optimizer = optim.Adam(gcn_model.parameters(), lr=1e-2, weight_decay=5e-4)

        # Calculate class weights for imbalance
        train_labels = graph_data.y[train_mask]
        class_counts = torch.bincount(train_labels[train_labels >= 0], minlength=2)
        if torch.all(class_counts > 0):
            class_weights = 1. / class_counts.float()
            class_weights = class_weights / class_weights.sum()
            print(f"Training GCN with class weights: {class_weights.numpy()}")
        else:
            class_weights = None
            print("Training GCN without class weights (one class missing in this training fold).")

        if train_mask.sum() > 0:
            for epoch in range(EPOCHS_GCN):
                gcn_model.train()
                optimizer.zero_grad()
                out = gcn_model(graph_data)
                loss = F.nll_loss(out[train_mask], graph_data.y[train_mask], weight=class_weights)
                loss.backward()
                optimizer.step()

        # 9. Get GCN probabilities for all nodes
        gcn_model.eval()
        with torch.no_grad():
            all_node_logits = gcn_model(graph_data)
            all_node_probs = all_node_logits.exp()[:, 1].cpu().numpy()

        # ===================================
        #     HYBRID FUSION & EVALUATION
        # ===================================

        # 10. Get AE and GCN scores for the validation set posts
        df_val_fold = pd.DataFrame({'postId': posts_val_fold, 'ae_error': val_errors_ae, 'label': y_val_fold})
        post_level_ae_scores = df_val_fold.groupby('postId')['ae_error'].max()
        val_post_ids_unique = df_val_fold['postId'].unique()
        val_labels_post_level_df = df_val_fold.groupby('postId')['label'].first()

        # Align posts in validation set with posts in the graph
        common_posts = [pid for pid in val_post_ids_unique if post_id_to_node_idx.get(pid) is not None]

        if not common_posts:
            print("Warning: No validation posts found in the graph for this fold. Skipping GCN/Hybrid evaluation.")
            fold_results.append({'f1_ae': f1_score(y_val_fold, (val_errors_ae > best_ae_threshold).astype(int)), 'f1_gcn': 0, 'f1_hybrid': 0, 'precision_hybrid': 0, 'recall_hybrid': 0, 'accuracy_hybrid': 0})
            all_confusion_matrices.append(np.zeros((2,2), dtype=int))
            continue

        ae_scores_val = post_level_ae_scores.loc[common_posts].values
        gcn_scores_val = np.array([all_node_probs[post_id_to_node_idx[pid]] for pid in common_posts])
        val_labels_post_level = val_labels_post_level_df.loc[common_posts].values

        # 11. Optimize Fusion Weight 'w'
        best_w = 0.5
        best_f1_fusion = -1
        for w in np.linspace(0, 1, 21):
            ae_prob = np.minimum(ae_scores_val / (best_ae_threshold * 2 + 1e-9), 1.0)
            fused_scores = w * ae_prob + (1 - w) * gcn_scores_val
            fused_preds = (fused_scores > 0.5).astype(int)
            f1 = f1_score(val_labels_post_level, fused_preds, zero_division=0)
            if f1 > best_f1_fusion:
                best_f1_fusion = f1
                best_w = w
        best_fusion_weights.append(best_w)
        print(f"Best Fusion Weight (w for AE) found: {best_w:.2f}")

        # 12. Calculate and store final metrics for this fold
        ae_preds_post_level = (ae_scores_val > best_ae_threshold).astype(int)
        gcn_preds_post_level = (gcn_scores_val > 0.5).astype(int)
        ae_prob_final = np.minimum(ae_scores_val / (best_ae_threshold * 2 + 1e-9), 1.0)
        fused_scores_final = best_w * ae_prob_final + (1 - best_w) * gcn_scores_val
        hybrid_preds_post_level = (fused_scores_final > 0.5).astype(int)

        fold_metrics = {
            'f1_ae': f1_score(val_labels_post_level, ae_preds_post_level, zero_division=0),
            'f1_gcn': f1_score(val_labels_post_level, gcn_preds_post_level, zero_division=0),
            'f1_hybrid': f1_score(val_labels_post_level, hybrid_preds_post_level, zero_division=0),
            'precision_hybrid': precision_score(val_labels_post_level, hybrid_preds_post_level, zero_division=0),
            'recall_hybrid': recall_score(val_labels_post_level, hybrid_preds_post_level, zero_division=0),
            'accuracy_hybrid': accuracy_score(val_labels_post_level, hybrid_preds_post_level)
        }
        fold_results.append(fold_metrics)
        # labels=[0, 1] keeps the matrix 2x2 even when a fold's posts are all one class
        all_confusion_matrices.append(confusion_matrix(val_labels_post_level, hybrid_preds_post_level, labels=[0, 1]))
        print(f"Hybrid Validation Metrics for Fold {fold_idx + 1}: {fold_metrics}")
else:
    print("Cannot run cross-validation as there is no user data.")

## 10. Aggregate and Display Final Performance 📊
After the main loop finishes, this cell acts as the **reporter**. It gathers the results from all 5 folds and prints the per-fold F1-scores of the AE, GCN, and hybrid models, followed by the **mean and standard deviation** of the hybrid model's accuracy, precision, recall, and F1-score. It also plots a **confusion matrix** summed over all folds, which breaks down how many phishing links were caught, how many were missed, and how many benign links raised a false alarm. Finally, it averages the per-fold AE thresholds and fusion weights into the final values.
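The aggregation can be sketched without pandas. This hypothetical version reports the mean and sample standard deviation of each metric (pandas' `.std()` also divides by n - 1) and sums 2x2 confusion matrices laid out as `[[TN, FP], [FN, TP]]`, which is scikit-learn's layout when `labels=[0, 1]`. The fold values in the demo are illustrative only, not results from this notebook.

```python
import statistics


def summarize(fold_metrics, keys):
    """Mean and sample standard deviation of each metric across folds.

    statistics.stdev divides by n - 1 and needs at least two folds.
    """
    summary = {}
    for key in keys:
        values = [fold[key] for fold in fold_metrics]
        summary[key] = (statistics.mean(values), statistics.stdev(values))
    return summary


def sum_confusion(matrices):
    """Element-wise sum of 2x2 [[TN, FP], [FN, TP]] matrices."""
    total = [[0, 0], [0, 0]]
    for cm in matrices:
        for i in range(2):
            for j in range(2):
                total[i][j] += cm[i][j]
    return total


folds = [{"f1_hybrid": 0.8}, {"f1_hybrid": 0.9}, {"f1_hybrid": 1.0}]
print(summarize(folds, ["f1_hybrid"]))
print(sum_confusion([[[5, 1], [0, 4]], [[6, 0], [1, 3]]]))  # [[11, 1], [1, 7]]
```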

In [None]:
# @title
if fold_results:
    df_results = pd.DataFrame(fold_results)
    print("\n--- Cross-Validation Results Summary (F1-Scores) ---")
    print(df_results[['f1_ae', 'f1_gcn', 'f1_hybrid']])

    print("\n--- Average Performance Metrics (± Std Dev) --- ")
    # Focus on the hybrid model's performance as per the thesis
    hybrid_metrics = ['accuracy_hybrid', 'precision_hybrid', 'recall_hybrid', 'f1_hybrid']
    mean_metrics = df_results[hybrid_metrics].mean()
    std_metrics = df_results[hybrid_metrics].std()
    summary_df = pd.concat([mean_metrics, std_metrics], axis=1)
    summary_df.columns = ['Mean', 'Std Dev']
    summary_df.index = ['Accuracy', 'Precision', 'Recall', 'F1-Score'] # Rename for clarity
    print(summary_df)

    # Summed Confusion Matrix
    if all_confusion_matrices:
        # Ensure all matrices are 2x2, padding if necessary
        padded_matrices = []
        for cm in all_confusion_matrices:
            if cm.shape == (1, 1):
                # This can happen if a fold only has one class
                padded_cm = np.zeros((2, 2), dtype=int)
                # Assuming the single class is the majority (negative) class
                padded_cm[0, 0] = cm[0, 0]
                padded_matrices.append(padded_cm)
            elif cm.shape == (2,2):
                padded_matrices.append(cm)

        if padded_matrices:
            summed_cm = np.sum(padded_matrices, axis=0)
            print("\n--- Summed Confusion Matrix (Across All Folds) ---")
            print(summed_cm)
            # Optional: Plot the summed confusion matrix
            plt.figure(figsize=(6, 5))
            sns.heatmap(summed_cm, annot=True, fmt='d', cmap='Blues',
                        xticklabels=['Predicted Benign', 'Predicted Phishing'],
                        yticklabels=['True Benign', 'True Phishing'])
            plt.title('Summed Confusion Matrix')
            plt.xlabel('Predicted Label')
            plt.ylabel('True Label')
            plt.show()

    # Determine the final hyperparameters
    final_threshold = np.mean(best_thresholds)
    final_fusion_weight = np.mean(best_fusion_weights)
    print(f"\n✅ Final Optimized AE Threshold (Mean): {final_threshold:.6f}")
    print(f"✅ Final Optimized Fusion Weight (Mean): {final_fusion_weight:.2f}")
else:
    print("\nNo results to aggregate.")
    final_threshold = 0.5
    final_fusion_weight = 0.5

## 11. Retrain Final Models on All User Data 🎓
Now that the AE threshold and fusion weight have been chosen (as averages across the folds), this code performs **one final training run**. It trains the AE and GCN models on **all of the available user data**, so the exported models learn from every piece of user feedback instead of only the four folds used for training in each cross-validation round.
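Once the final models are exported, the averaged threshold and fusion weight are applied the same way as during validation. The hypothetical `classify_post` helper below sketches that inference rule for one post, assuming the caller already has the AE reconstruction error for each of the post's URLs and the GCN probability for the post's node. Like section 9, it scores a post by its largest URL error.

```python
def classify_post(url_errors, gcn_prob, ae_threshold, fusion_weight):
    """Hypothetical inference helper mirroring the notebook's fusion rule.

    url_errors: AE reconstruction errors for every URL in one post.
    gcn_prob:   the GCN's phishing probability for the post's node.
    Returns the fused score and the 0/1 prediction.
    """
    post_error = max(url_errors)  # a post is scored by its most anomalous URL
    ae_prob = min(post_error / (2 * ae_threshold + 1e-9), 1.0)
    fused = fusion_weight * ae_prob + (1 - fusion_weight) * gcn_prob
    return fused, int(fused > 0.5)


score, label = classify_post([0.01, 0.12], gcn_prob=0.4, ae_threshold=0.1, fusion_weight=0.6)
print(round(score, 3), label)  # 0.52 1
```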

In [None]:
# @title
if len(X_user) > 0:
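    # --- Final Autoencoder Training ---
    print("\n--- Training Final Autoencoder on ALL User Data ---")
    # The AE learns to reconstruct benign samples only, so keep the label-0 rows
    X_benign_user_final = X_user[y_user == 0]

    # Check for benign rows *before* scaling: scaler.transform raises on an empty array
    if len(X_benign_user_final) > 0:
        X_benign_user_final_scaled = scaler.transform(X_benign_user_final)

        # Start from the weights learned on the public dataset, then fine-tune
        final_ae_model = build_autoencoder(X_user.shape[1])
        final_ae_model.load_weights('pretrained_ae_weights.weights.h5')
        final_ae_model.compile(optimizer=keras.optimizers.Adam(LR_FINETUNE), loss='mse')

        # No validation split here, so early stopping must watch the training loss
        # (the default monitor, 'val_loss', would never be available)
        early_stop = keras.callbacks.EarlyStopping(monitor='loss', patience=5, restore_best_weights=True)
        final_ae_model.fit(X_benign_user_final_scaled, X_benign_user_final_scaled,
                           epochs=EPOCHS_FINETUNE, batch_size=BATCH_FINETUNE,
                           verbose=1, callbacks=[early_stop])
        print("\n✅ Final AE model training complete.")
    else:
        print("No benign user data for final AE training. Using pre-trained model.")
        final_ae_model = pretrain_ae_model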

    # --- Final GCN Training ---
    print("\n--- Training Final GCN on ALL User Data ---")
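    # Mark the graph nodes of all user posts as training nodes; posts missing from the graph are skipped
    final_train_mask = torch.zeros(len(idx_to_key), dtype=torch.bool)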
    final_train_node_indices = [post_id_to_node_idx.get(pid) for pid in post_ids_user if post_id_to_node_idx.get(pid) is not None]
    if final_train_node_indices:
        final_train_mask[torch.tensor(final_train_node_indices)] = True

    final_graph_data = Data(x=x_graph, edge_index=edge_index_graph, y=y_nodes, train_mask=final_train_mask)

    final_gcn_model = GCN(num_features=final_graph_data.x.size(1))
    optimizer = optim.Adam(final_gcn_model.parameters(), lr=1e-2, weight_decay=5e-4)

    # Inverse-frequency class weights so the minority class is not drowned out;
    # labels < 0 (if any) mark unlabeled nodes and are excluded from the counts
    final_train_labels = final_graph_data.y[final_train_mask]
    final_class_counts = torch.bincount(final_train_labels[final_train_labels >= 0], minlength=2)
    if torch.all(final_class_counts > 0):
        final_class_weights = 1. / final_class_counts.float()
        final_class_weights = final_class_weights / final_class_weights.sum()
    else:
        final_class_weights = None

    if final_train_mask.sum() > 0:
        for epoch in range(EPOCHS_GCN):
            final_gcn_model.train()
            optimizer.zero_grad()
            out = final_gcn_model(final_graph_data)  # log-probabilities for every node
            # Loss only on user-post nodes; ignore_index=-1 skips any unlabeled node in the mask
            loss = F.nll_loss(out[final_train_mask], final_graph_data.y[final_train_mask],
                              weight=final_class_weights, ignore_index=-1)
            loss.backward()
            optimizer.step()
        print("\n✅ Final GCN model training complete.")

        # Generate final probabilities for all nodes
        final_gcn_model.eval()
        with torch.no_grad():
            final_all_node_logits = final_gcn_model(final_graph_data)
            final_all_node_probs = final_all_node_logits.exp()[:, 1].cpu().numpy()
        print("Generated final GCN probabilities for all nodes.")
    else:
        print("No user data to train final GCN model.")
        final_gcn_model = None
        final_all_node_probs = np.array([])
else:
    print("\nNo user data. The final models will be the pre-trained public models.")
    final_ae_model = pretrain_ae_model
    final_gcn_model = None
    final_all_node_probs = np.array([])
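The masked, class-weighted loss used for the final GCN can be reproduced with NumPy on toy numbers. In this sketch the node labels, post IDs, mapping and probabilities are all invented; the last step mirrors what PyTorch's `F.nll_loss` computes with `weight=` and the default `reduction='mean'`, namely a weighted average divided by the sum of the weights of the selected targets.

```python
import numpy as np

# Hypothetical graph: 5 nodes, labels 0 = benign, 1 = phishing
y_nodes = np.array([0, 0, 0, 1, 0])
post_id_to_node_idx = {"p1": 0, "p2": 1, "p3": 3}  # illustrative mapping
post_ids_user = ["p1", "p2", "p3", "p_missing"]    # one post is not in the graph

# Boolean training mask over nodes; posts missing from the graph are skipped
train_mask = np.zeros(len(y_nodes), dtype=bool)
train_mask[[post_id_to_node_idx[p] for p in post_ids_user if p in post_id_to_node_idx]] = True

# Inverse-frequency class weights, normalised to sum to 1
counts = np.bincount(y_nodes[train_mask], minlength=2)
class_weights = (1.0 / counts) / (1.0 / counts).sum()

# Toy per-node probabilities, turned into log-probabilities (as a log_softmax layer would output)
probs = np.array([[0.9, 0.1], [0.6, 0.4], [0.5, 0.5], [0.2, 0.8], [0.7, 0.3]])
log_probs = np.log(probs)

# Weighted NLL with mean reduction: sum(w_y * -log p_y) / sum(w_y) over masked nodes
targets = y_nodes[train_mask]
picked = log_probs[train_mask, targets]
w = class_weights[targets]
loss = np.sum(-w * picked) / np.sum(w)
print(counts, class_weights, round(float(loss), 4))
```

Here the single phishing node carries twice the weight of each benign node, so its error counts as much as both benign errors together.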

## 12. Thesis Performance Evaluation 📝
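This section gives the models their final exam. It evaluates the initial **pre-trained autoencoder** (the baseline) and the **final fine-tuned hybrid model** on the public test split, the portion of the live dataset that was held back from training, and reports accuracy, precision, recall, F1-score, ROC-AUC (when scores are passed in) and a confusion matrix for each. Because the public test posts are not part of the user graph, the GCN contributes a neutral score of 0.5 to the hybrid model here. These are the scores and charts needed to report the results formally.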
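The baseline decision rule in the next cell boils down to two steps: score each row by its mean squared reconstruction error, then flag anything above the 99th percentile of the errors on a reference set of (ideally benign) training rows. A NumPy sketch with made-up inputs, where the "reconstructions" are simply the inputs plus noise standing in for an autoencoder's output:

```python
import numpy as np

rng = np.random.default_rng(0)

def reconstruction_errors(x, x_hat):
    """Per-row mean squared error, as computed for the autoencoder baseline."""
    return np.mean(np.square(x - x_hat), axis=1)

# Stand-in reference data: rows that reconstruct well (small noise)
x_ref = rng.normal(size=(1000, 8))
x_ref_hat = x_ref + rng.normal(scale=0.1, size=x_ref.shape)

# Stand-in test data: two perfectly reconstructed rows, two badly reconstructed ones
x_test = rng.normal(size=(4, 8))
x_test_hat = x_test.copy()
x_test_hat[2:] += rng.normal(scale=2.0, size=(2, 8))

ref_errors = reconstruction_errors(x_ref, x_ref_hat)
threshold = np.percentile(ref_errors, 99)  # ~1% of reference rows lie above it

test_errors = reconstruction_errors(x_test, x_test_hat)
preds = (test_errors > threshold).astype(int)  # 1 = flagged as phishing
print(round(float(threshold), 4), preds)
```

The percentile directly sets the false-positive rate on the reference rows (about 1% here), which is why a high percentile is described as a high-precision threshold.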

In [None]:
# @title
def print_thesis_stats(model_name, y_true, y_pred, y_prob=None):
    """Prints evaluation statistics formatted for the thesis."""
    accuracy = accuracy_score(y_true, y_pred)
    precision = precision_score(y_true, y_pred, zero_division=0)
    recall = recall_score(y_true, y_pred, zero_division=0)
    f1 = f1_score(y_true, y_pred, zero_division=0)
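    # labels=[0, 1] keeps the matrix 2x2 even if one class is absent from y_true/y_pred,
    # so it always matches the two tick labels used in the heatmap below
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])

    print(f"\n--- {model_name} Performance ---")
    print("| Metric    | Score    |")
    print("| :-------- | :------- |")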
    print(f"| Accuracy  | {accuracy:.4f}   |")
    print(f"| Precision | {precision:.4f}   |")
    print(f"| Recall    | {recall:.4f}   |")
    print(f"| F1-Score  | {f1:.4f}   |")
    if y_prob is not None and len(np.unique(y_true)) > 1:
        roc_auc = roc_auc_score(y_true, y_prob)
        print(f"| ROC-AUC   | {roc_auc:.4f}   |")

    print(f"\nConfusion Matrix for {model_name}:")
    print(cm)
    # Plotting the confusion matrix
    plt.figure(figsize=(6, 5))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=['Predicted Benign', 'Predicted Phishing'],
                yticklabels=['True Benign', 'True Phishing'])
    plt.title(f'Confusion Matrix: {model_name}')
    plt.xlabel('Predicted Label')
    plt.ylabel('True Label')
    plt.show()

# --- 1. Evaluate Pre-trained Autoencoder (Baseline) ---
X_public_test_scaled = scaler.transform(X_public_test)
baseline_reconstructed = pretrain_ae_model.predict(X_public_test_scaled, verbose=0)
baseline_errors = np.mean(np.square(X_public_test_scaled - baseline_reconstructed), axis=1)

# Baseline threshold, as in the original thesis approach: the 99th percentile of the
# reconstruction errors on *benign* public training data (~1% of benign training URLs exceed it)
X_benign_public_train_scaled = scaler.transform(X_public_train[y_public_train == 0])
baseline_train_reconstructed = pretrain_ae_model.predict(X_benign_public_train_scaled, verbose=0)
baseline_train_errors = np.mean(np.square(X_benign_public_train_scaled - baseline_train_reconstructed), axis=1)
baseline_threshold = np.percentile(baseline_train_errors, 99)  # high-precision threshold
baseline_preds = (baseline_errors > baseline_threshold).astype(int)

print_thesis_stats("Pre-trained Autoencoder (Baseline)", y_public_test, baseline_preds)

# --- 2. Evaluate Final Fine-Tuned Hybrid Model ---
if len(X_user) > 0:
    final_reconstructed = final_ae_model.predict(X_public_test_scaled, verbose=0)
    final_ae_errors = np.mean(np.square(X_public_test_scaled - final_reconstructed), axis=1)
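    # Normalise reconstruction error by the CV-tuned threshold (1.0 = exactly at the threshold).
    # This ratio is an anomaly score, not a calibrated probability, and can exceed 1.
    final_ae_probs = final_ae_errors / (final_threshold + 1e-9)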

    # For public data, GCN score is neutral as these posts are not in the user graph
    final_gcn_probs = np.full_like(final_ae_probs, 0.5)

    final_fused_scores = final_fusion_weight * final_ae_probs + (1 - final_fusion_weight) * final_gcn_probs
    final_hybrid_preds = (final_fused_scores > 0.5).astype(int)

    print_thesis_stats("Final Hybrid Model", y_public_test, final_hybrid_preds, y_prob=final_fused_scores)
else:
    print("\nSkipping final hybrid model evaluation as no user data was available for fine-tuning.")
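Because the GCN score is pinned at 0.5 for public test posts, the fused rule in the cell above simplifies. For a fusion weight w > 0 and normalised AE score r, the condition w·r + (1 - w)·0.5 > 0.5 holds exactly when r > 0.5, that is, when the reconstruction error exceeds half of `final_threshold`; for w = 0 nothing is flagged. Whether that half-threshold cut-off is intended depends on how the fusion weight was tuned during cross-validation, which this section does not show. The check below confirms the simplification numerically on made-up scores; it restates the cell's formula and does not call the notebook's models.

```python
import numpy as np

def fused_decision(ae_ratio, fusion_weight, gcn_prob=0.5, cutoff=0.5):
    """Same formula as the evaluation cell: weighted AE score plus weighted GCN score."""
    fused = fusion_weight * ae_ratio + (1 - fusion_weight) * gcn_prob
    return (fused > cutoff).astype(int)

rng = np.random.default_rng(42)
ae_ratio = rng.uniform(0.0, 3.0, size=10_000)       # error / threshold, can exceed 1
ae_ratio = ae_ratio[np.abs(ae_ratio - 0.5) > 1e-9]  # stay clear of the exact boundary

for w in (0.2, 0.5, 0.9):
    # With gcn_prob fixed at 0.5, the rule reduces to ae_ratio > 0.5 for any w > 0
    assert np.array_equal(fused_decision(ae_ratio, w), (ae_ratio > 0.5).astype(int))

# With w = 0 the fused score is always exactly 0.5, so nothing is flagged
assert fused_decision(ae_ratio, 0.0).sum() == 0
print("fused rule == (error > 0.5 * threshold) for w > 0")
```

A consequence worth keeping in mind: on this public test set the reported ROC-AUC of the hybrid model equals that of its AE component alone, since adding a constant GCN term does not change the ranking of the scores.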

## 13. Export Artifacts for Application 📦
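This cell takes the final, fully trained models and all the settings they depend on and **saves them into a collection of files**: the Keras autoencoder (`phishing_autoencoder_model.keras`), the fitted feature scaler (`scaler.pkl`), the AE threshold (`autoencoder_threshold.txt`) and the fusion settings (`fusion_config.json`), plus, when a GCN was trained, its weights (`gnn_model.pth`), the post-to-node mapping (`post_node_map.json`) and the per-node GCN probabilities (`gnn_probs.npy`). These files are the "brains" of the operation: the backend server loads them to make live predictions.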

In [None]:
# @title
print("\n--- Exporting Final Artifacts ---")

# 1. Save AE model, scaler, and threshold
final_ae_model.save("phishing_autoencoder_model.keras")
print("Saved phishing_autoencoder_model.keras")

with open("scaler.pkl", "wb") as f:
    pickle.dump(scaler, f)
print("Saved scaler.pkl")

with open("autoencoder_threshold.txt", "w") as f:
    f.write(str(final_threshold))
print("Saved autoencoder_threshold.txt")

# 2. Save GCN model and post-to-node mapping
if final_gcn_model is not None:  # a GCN is only trained when user posts map into the graph
    torch.save(final_gcn_model.state_dict(), "gnn_model.pth")
    print("Saved gnn_model.pth")
    with open("post_node_map.json", "w") as f:
        json.dump(post_id_to_node_idx, f)
    print("Saved post_node_map.json")
    # Save the final GCN probabilities
    if 'final_all_node_probs' in locals() and final_all_node_probs.size > 0:
        np.save("gnn_probs.npy", final_all_node_probs)
        print("Saved gnn_probs.npy")

# 3. Save fusion configuration
fusion_config = {
    'ae_threshold': float(final_threshold),
    'fusion_weight_ae': float(final_fusion_weight)
}
with open("fusion_config.json", "w") as f:
    json.dump(fusion_config, f)
print("Saved fusion_config.json")

print("\n✅ All artifacts exported successfully.")
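On the serving side, the small text and JSON artifacts can be read back with the standard library alone. The sketch below is hypothetical: the `load_fusion_config` helper and the temporary directory are illustrative, and only the file names and the `ae_threshold` / `fusion_weight_ae` keys come from the export cell above. It writes a sample config the same way the export cell does, reads it back, and checks that it agrees with `autoencoder_threshold.txt`.

```python
import json
import tempfile
from pathlib import Path

def load_fusion_config(artifact_dir):
    """Hypothetical loader: read fusion_config.json and autoencoder_threshold.txt,
    and check that the two copies of the AE threshold agree."""
    artifact_dir = Path(artifact_dir)
    with open(artifact_dir / "fusion_config.json") as f:
        config = json.load(f)
    threshold_txt = float((artifact_dir / "autoencoder_threshold.txt").read_text())
    if abs(config["ae_threshold"] - threshold_txt) > 1e-12:
        raise ValueError("fusion_config.json and autoencoder_threshold.txt disagree")
    return config["ae_threshold"], config["fusion_weight_ae"]

# Round trip with illustrative values, written the same way as the export cell
with tempfile.TemporaryDirectory() as tmp:
    threshold, weight = 0.0123, 0.7
    with open(Path(tmp) / "fusion_config.json", "w") as f:
        json.dump({"ae_threshold": float(threshold), "fusion_weight_ae": float(weight)}, f)
    with open(Path(tmp) / "autoencoder_threshold.txt", "w") as f:
        f.write(str(threshold))
    loaded = load_fusion_config(tmp)

print(loaded)
```

Since the threshold is stored twice, a consistency check like this catches a backend that was given files from two different training runs.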

## 14. Download Artifacts ⬇️
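This final cell is a simple convenience script. It **triggers a download prompt in the browser** for every artifact file created in the previous step and skips any that were not produced (for example, the GCN files when no user data was available). It relies on `google.colab.files`, so it only works inside Google Colab. This makes it easy to move the finished model files from the training environment to a local computer so they can be uploaded to the server.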

In [None]:
# @title
from google.colab import files  # Colab-only helper; files.download() opens a browser download

print("Preparing to download artifacts...")
artifacts_to_download = [
    "phishing_autoencoder_model.keras",
    "scaler.pkl",
    "autoencoder_threshold.txt",
    "gnn_model.pth",
    "post_node_map.json",
    "fusion_config.json",
    "gnn_probs.npy"
]

for artifact in artifacts_to_download:
    if os.path.exists(artifact):
        print(f"Downloading {artifact}...")
        files.download(artifact)
    else:
        print(f"Skipping {artifact} as it was not found.")
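Downloading seven files can mean seven separate prompts, and some browsers block repeated automatic downloads from the same page. An optional alternative, sketched here with Python's standard `zipfile` module, is to bundle whichever artifacts exist into one archive and download that single file. `bundle_artifacts` and `artifacts.zip` are hypothetical names, and the demo runs in a scratch directory with dummy files; the Colab-only `files.download` call is left as a comment.

```python
import os
import tempfile
import zipfile

def bundle_artifacts(paths, zip_path="artifacts.zip"):
    """Hypothetical helper: write every existing file in `paths` into one zip archive
    and return the list of names that were included."""
    included = []
    with zipfile.ZipFile(zip_path, "w", compression=zipfile.ZIP_DEFLATED) as zf:
        for path in paths:
            if os.path.exists(path):
                zf.write(path, arcname=os.path.basename(path))
                included.append(os.path.basename(path))
    return included

# Demo in a scratch directory with two dummy files and one missing file
with tempfile.TemporaryDirectory() as tmp:
    for name in ("scaler.pkl", "fusion_config.json"):
        with open(os.path.join(tmp, name), "w") as f:
            f.write("placeholder")
    wanted = [os.path.join(tmp, n) for n in ("scaler.pkl", "fusion_config.json", "gnn_probs.npy")]
    zip_path = os.path.join(tmp, "artifacts.zip")
    included = bundle_artifacts(wanted, zip_path)
    with zipfile.ZipFile(zip_path) as zf:
        names_in_zip = sorted(zf.namelist())

print(included, names_in_zip)
# In Colab: files.download("artifacts.zip")
```

Missing artifacts are skipped just as in the cell above, so the same list of file names can be passed in unchanged.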