# URL Security Classifier — Training Pipeline

**XGBoost with 95 hand-crafted URL features + Platt calibration + SHAP explainability**

### Notebook overview
- Trains a URL security classifier using hand-crafted lexical, structural, homograph, and n-gram features.
- Uses a calibrated XGBoost model for probability outputs suitable for risk scoring.
- Produces SHAP-based explanations for both global and per-sample interpretability.

### Pipeline
1. Load and merge datasets, normalize URLs, and generate format-diverse safe URL augmentation
2. Extract 95 URL features (including homograph and n-gram signals)
3. Tune and train XGBoost with Optuna, then apply Platt calibration
4. Run SHAP feature-attribution analysis
5. Run OOD sanity checks on real-world URL formats
6. Export artifacts (`xgb_model.pkl` + `feature_names.json`) for server inference
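
To make step 3 more concrete, here is a minimal sketch of the idea behind Platt calibration: fit a sigmoid `p = sigmoid(a * score + b)` on held-out scores so that raw model outputs become usable probabilities. The notebook itself imports scikit-learn's `CalibratedClassifierCV`. The hand-rolled `fit_platt` and `calibrate` helpers and the toy scores below are illustrative assumptions, and the sketch omits the target smoothing used in Platt's original formulation.

```python
import math

def fit_platt(scores, labels, lr=0.5, epochs=5000):
    """Fit p = sigmoid(a * score + b) by gradient descent on log loss."""
    a, b = 0.0, 0.0
    n = len(scores)
    for _ in range(epochs):
        grad_a = grad_b = 0.0
        for s, y in zip(scores, labels):
            p = 1.0 / (1.0 + math.exp(-(a * s + b)))
            grad_a += (p - y) * s   # d(logloss)/da for this sample
            grad_b += p - y         # d(logloss)/db for this sample
        a -= lr * grad_a / n
        b -= lr * grad_b / n
    return a, b

def calibrate(score, a, b):
    """Map a raw score to a calibrated probability."""
    return 1.0 / (1.0 + math.exp(-(a * score + b)))

# Toy raw scores with imperfect class separation (values are made up).
raw_scores  = [-2.0, -1.0, -0.5, -0.2, 0.3, 0.5, 1.0, 2.0]
true_labels = [0, 0, 0, 1, 0, 1, 1, 1]

a, b = fit_platt(raw_scores, true_labels)
for s in (-2.0, 0.0, 2.0):
    print(f"score={s:+.1f} -> calibrated p={calibrate(s, a, b):.3f}")
```

The fitted slope `a` is positive when higher scores go with the positive class, so the calibrated probability stays monotonic in the raw score.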

In [1]:
# Cell 1: Install dependencies
!pip install -q kagglehub xgboost optuna scikit-learn tldextract shap tqdm
import warnings; warnings.filterwarnings("ignore")

---
## Section 1: Data Loading, Normalization, Augmentation & Splitting

The original dataset's benign URLs are too simple (e.g., `google.com`, `facebook.com`) and have inconsistent formats (some with schemes, some without, some with `www.`).

**Fixes applied:**
1. **Normalize** all URLs to have a scheme (`https://`), matching real QR code scanner input.
2. **Augment benign class** with ~50K realistic complex safe URLs — but with **format diversity**: randomly varying scheme (`http`/`https`), `www.` prefix (45%/55%), and path presence (bare/simple/complex). This prevents the model from learning spurious formatting rules instead of actual phishing patterns.
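
As a quick illustration of fix 1, the sketch below mirrors the `normalize_url` helper defined in the next cell: a scheme is prepended only when none is present, and existing schemes are left alone. The sample inputs are made up for illustration.

```python
def normalize_url(url):
    """Prepend https:// when the URL has no scheme; keep existing schemes."""
    url = str(url).strip()
    if not url or "://" in url[:12]:
        return url
    return "https://" + url

samples = [
    "google.com",                    # bare domain
    "www.example.org/login",         # www. prefix, no scheme
    "http://example.com/a",          # existing scheme is preserved
    "  ftp://files.example.net  ",   # surrounding whitespace is stripped
]
normalized = [normalize_url(u) for u in samples]
for before, after in zip(samples, normalized):
    print(f"{before!r:32} -> {after}")
```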

In [2]:
import os, glob, random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split

# ── 1.1 Download Dataset ─────────────────────────────────────
import kagglehub

print("Downloading dataset from Kaggle...")
path = kagglehub.dataset_download("sid321axn/malicious-urls-dataset")
print(f"   Downloaded to: {path}")

csv_files = glob.glob(os.path.join(path, "**", "*.csv"), recursive=True)
csv_file = csv_files[0] if csv_files else os.path.join(path, "malicious_phish.csv")

df = pd.read_csv(csv_file)
print(f"\nLoaded {df.shape[0]:,} URLs  |  Columns: {list(df.columns)}")

# ── 1.2 Explore ──────────────────────────────────────────────
print("\nOriginal class distribution:")
print(df['type'].value_counts().to_string())

# ── 1.3 Convert to Binary ────────────────────────────────────
df['label'] = (df['type'] != 'benign').astype(int)

# ── 1.3b Normalize Original URLs ─────────────────────────────
# The original dataset has inconsistent formats: bare domains ("google.com"),
# scheme-less hosts with a www. prefix ("www.google.com"), full http:// URLs, etc.
# URLs scanned from QR codes normally include a scheme, so we add one wherever
# it is missing. This lets the model learn CONTENT patterns, not formatting artifacts.

from urllib.parse import urlparse

def normalize_url(url):
    """Ensure every URL has a scheme. Preserve existing schemes."""
    url = str(url).strip()
    if not url:
        return url
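    # Already has a scheme. Only the first 12 characters are checked, so a
    # "://" that appears later (e.g. inside a query string) does not count.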
    if '://' in url[:12]:
        return url
    # Add https:// for bare domains / www. prefixes
    return 'https://' + url

df['url'] = df['url'].apply(normalize_url)
print(f"\nNormalized all URLs to have schemes")
print(f"   Sample: {df['url'].iloc[0]}")

# ── 1.4 Augment Benign URLs with Complex Realistic Examples ──
# The original benign set is mostly simple domains (google.com, facebook.com).
# Real QR codes point to complex URLs. We augment with synthetic safe URLs.
#
# CRITICAL: Augmented URLs must be FORMAT-DIVERSE to prevent the model from
# learning spurious formatting rules instead of phishing patterns.
# Each URL randomly varies: scheme, www prefix, path presence, and structure.

SAFE_DOMAINS = [
    "google.com", "youtube.com", "facebook.com", "amazon.com", "wikipedia.org",
    "twitter.com", "instagram.com", "linkedin.com", "reddit.com", "netflix.com",
    "microsoft.com", "apple.com", "github.com", "stackoverflow.com", "medium.com",
    "spotify.com", "twitch.tv", "ebay.com", "cnn.com", "bbc.com",
    "nytimes.com", "theguardian.com", "reuters.com", "walmart.com", "target.com",
    "bestbuy.com", "homedepot.com", "lowes.com", "costco.com", "macys.com",
    "airbnb.com", "booking.com", "expedia.com", "tripadvisor.com", "yelp.com",
    "zillow.com", "realtor.com", "indeed.com", "glassdoor.com", "monster.com",
    "coursera.org", "udemy.com", "edx.org", "khanacademy.org", "duolingo.com",
    "kaggle.com", "notion.so", "figma.com", "canva.com", "trello.com",
    "slack.com", "zoom.us", "dropbox.com", "drive.google.com", "docs.google.com",
    "outlook.com", "mail.google.com", "icloud.com", "proton.me", "adobe.com",
    "salesforce.com", "hubspot.com", "mailchimp.com", "stripe.com", "shopify.com",
    "squarespace.com", "wix.com", "wordpress.com", "blogger.com", "tumblr.com",
    "pinterest.com", "tiktok.com", "snapchat.com", "discord.com", "telegram.org",
    "whatsapp.com", "signal.org", "paypal.com", "venmo.com", "robinhood.com",
    "coinbase.com", "binance.com", "chase.com", "bankofamerica.com", "wellsfargo.com",
    "capitalone.com", "fidelity.com", "vanguard.com", "schwab.com", "etrade.com",
    "hulu.com", "disneyplus.com", "hbomax.com", "peacocktv.com", "crunchyroll.com",
    "imdb.com", "rottentomatoes.com", "goodreads.com", "archive.org", "quora.com",
    "pubmed.ncbi.nlm.nih.gov", "scholar.google.com", "researchgate.net", "jstor.org",
    "mit.edu", "stanford.edu", "harvard.edu", "ox.ac.uk", "cam.ac.uk",
]

SAFE_PATH_TEMPLATES = [
    "",  # ← bare domain (no path) — CRITICAL to include!
    "/", # ← root path
    "/home", "/about", "/contact", "/help", "/faq", "/terms", "/privacy",
    "/settings", "/profile", "/dashboard", "/account/settings",
    "/products/{id}", "/items/{id}/details", "/search?q={word}&page={n}",
    "/blog/{year}/{month}/{slug}", "/article/{id}", "/news/{slug}",
    "/user/{username}/posts", "/user/{username}/profile",
    "/docs/getting-started", "/docs/api/v2/reference", "/docs/{section}/{page}",
    "/en/help/article/{id}", "/support/ticket/{id}",
    "/category/{cat}/subcategory/{sub}", "/shop/{cat}?sort=price&order=asc",
    "/code/{username}/{project}/edit", "/code/{username}/{project}/blob/main/src/{file}.py",
    "/dp/{id}?ref=sr_1_{n}&tag={word}", "/gp/product/{id}/ref=ox_sc_act_title_1",
    "/watch?v={id}&list={id2}&index={n}", "/playlist?list={id}",
    "/r/{subreddit}/comments/{id}/{slug}", "/r/{subreddit}/wiki/{page}",
    "/status/{id}", "/i/web/status/{id}",
    "/maps/place/{place}/@{lat},{lon},{zoom}z",
    "/flights/results?from={code}&to={code2}&date={date}",
    "/checkout/cart?item={id}&qty={n}", "/order/confirmation/{id}",
    "/index.php?page={word}&ActiveViewID=tab_{word}",
    "/app/{hex32}", "/document/d/{hex32}/edit",
    "/spreadsheets/d/{hex32}/edit#gid=0",
    "/forms/d/{hex32}/viewform",
    "/meeting/join?meetingId={hex32}",
    "/{lang}/download/{product}/{version}",
    "/releases/tag/v{version}", "/issues/{n}", "/pull/{n}/files",
    "/datasets/{username}/{dataset}?resource=download",
    "/competitions/{slug}/leaderboard", "/notebooks/{username}/{slug}",
    "/events/{year}/{slug}/register",
    "/courses/{slug}/learn/lecture/{n}",
    "/recipe/{id}/{slug}",
    "/book/{isbn}", "/author/{slug}",
    "/jobs/{id}/{slug}?utm_source=search&utm_medium=web",
    "/compare/{productA}-vs-{productB}",
    "/api/v3/users/{username}/repos?per_page=100&sort=updated",
]

def _gen_val(key):
    """Generate a realistic value for a template placeholder."""
    if key == "id":    return ''.join(random.choices("ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789", k=random.randint(6, 12)))
    if key == "id2":   return ''.join(random.choices("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789_-", k=random.randint(20, 34)))
    if key == "hex32": return ''.join(random.choices("0123456789abcdef", k=32))
    if key == "n":     return str(random.randint(1, 500))
    if key in ("word", "slug", "cat", "sub", "section", "page", "subreddit",
               "place", "product", "productA", "productB", "dataset", "lang"):
        words = ["intro", "security", "machine-learning", "recipes", "travel",
                 "analysis", "weather", "python", "photography", "health",
                 "science", "technology", "sports", "music", "gaming",
                 "finance", "education", "cooking", "fashion", "design"]
        return random.choice(words)
    if key == "username": return ''.join(random.choices("abcdefghijklmnopqrstuvwxyz0123456789_", k=random.randint(5, 15)))
    if key == "file":   return random.choice(["main", "utils", "config", "app", "index", "helpers"])
    if key == "year":   return str(random.randint(2020, 2026))
    if key == "month":  return f"{random.randint(1,12):02d}"
    if key == "date":   return f"2025-{random.randint(1,12):02d}-{random.randint(1,28):02d}"
    if key == "lat":    return f"{random.uniform(-90, 90):.4f}"
    if key == "lon":    return f"{random.uniform(-180, 180):.4f}"
    if key == "zoom":   return str(random.randint(5, 18))
    if key in ("code", "code2"): return ''.join(random.choices("ABCDEFGHIJKLMNOPQRSTUVWXYZ", k=3))
    if key == "isbn":   return ''.join(random.choices("0123456789", k=13))
    if key == "version": return f"{random.randint(1,5)}.{random.randint(0,20)}.{random.randint(0,10)}"
    return key

def generate_safe_urls(n=50000, seed=42):
    """
    Generate n realistic complex safe URLs with FORMAT DIVERSITY.

    Each URL randomly varies across four orthogonal dimensions:
      1. Scheme: https:// (85%) vs http:// (15%)
      2. www prefix: with www. (45%) vs without (55%)
      3. Path: complex path (60%), simple path (25%), bare domain (15%)
      4. Query params & fragments: extra query string (25%, only if none yet), fragment (15%)

    This prevents the model from learning spurious formatting rules
    (e.g., "www. = malicious" or "no path = malicious").
    """
    random.seed(seed)
    import re as re_mod
    urls = []

    for _ in range(n):
        domain = random.choice(SAFE_DOMAINS)

        # ── Dimension 1: Scheme ──
        scheme = "https" if random.random() < 0.85 else "http"

        # ── Dimension 2: www. prefix ──
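        # Skip www. for hosts with two or more dots. This covers domains that already
        # have a subdomain (e.g., drive.google.com) and, as a side effect, hosts on
        # multi-part suffixes (e.g., ox.ac.uk).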
        has_subdomain = domain.count('.') >= 2
        if has_subdomain:
            host = domain
        elif random.random() < 0.45:
            host = f"www.{domain}"
        else:
            host = domain

        # ── Dimension 3: Path ──
        r = random.random()
        if r < 0.15:
            # 15% — bare domain (no path at all, or just /)
            path = random.choice(["", "/"])
        elif r < 0.40:
            # 25% — simple short paths
            simple_paths = ["/home", "/about", "/contact", "/help", "/faq",
                          "/terms", "/privacy", "/settings", "/profile",
                          "/dashboard", "/", ""]
            path = random.choice(simple_paths)
        else:
            # 60% — complex path from templates
            template = random.choice(SAFE_PATH_TEMPLATES)
            path = re_mod.sub(r'\{(\w+)\}', lambda m: _gen_val(m.group(1)), template)

        url = f"{scheme}://{host}{path}"

        # Randomly add extra query params (25% chance, only if no query yet)
        if random.random() < 0.25 and "?" not in url:
            params = ["utm_source=qr", "ref=homepage", "lang=en", "page=1",
                      "sort=newest", "filter=all", "view=grid", "tab=overview"]
            url += "?" + "&".join(random.sample(params, random.randint(1, 3)))

        # Randomly add fragment (15% chance)
        if random.random() < 0.15:
            fragments = ["top", "section-2", "overview", "comments", "reviews",
                         "pricing", "features", f"id-{random.randint(1,999)}"]
            url += "#" + random.choice(fragments)

        urls.append(url)
    return urls

print("Generating synthetic safe URLs (format-diverse)...")
synthetic_safe = generate_safe_urls(50000)
synthetic_df = pd.DataFrame({"url": synthetic_safe, "type": "benign", "label": 0})

# ── Verify format diversity ──
n_www = sum(1 for u in synthetic_safe if "://www." in u)
n_https = sum(1 for u in synthetic_safe if u.startswith("https://"))
# Count URLs with no path segment (only the two slashes of "://" remain once a trailing "/" is stripped).
n_bare = sum(1 for u in synthetic_safe if u.rstrip('/').count('/') <= 2)
print(f"   Format diversity check:")
print(f"     www. prefix:  {n_www:>6,} ({n_www/len(synthetic_safe):.0%})")
print(f"     https scheme: {n_https:>6,} ({n_https/len(synthetic_safe):.0%})")
print(f"     no path:      {n_bare:>6,} ({n_bare/len(synthetic_safe):.0%})")
print(f"   Samples:")
random.seed(99)
for u in random.sample(synthetic_safe, 8):
    print(f"     {u}")

# Merge
df = pd.concat([df, synthetic_df], ignore_index=True)
print(f"\n   Added {len(synthetic_df):,} synthetic safe URLs")

fig, axes = plt.subplots(1, 2, figsize=(14, 4))
df['type'].value_counts().plot.bar(ax=axes[0], color=['#34C759','#FF3B30','#FF9500','#007AFF'])
axes[0].set_title('Original 4-Class + Augmented Distribution'); axes[0].set_ylabel('Count')
df['label'].value_counts().rename({0:'Safe', 1:'Malicious'}).plot.bar(ax=axes[1], color=['#34C759','#FF3B30'])
axes[1].set_title('Binary (Safe vs Malicious)'); axes[1].set_ylabel('Count')
plt.tight_layout(); plt.show()

# ── 1.5 Clean ────────────────────────────────────────────────
before = len(df)
df = df.dropna(subset=['url','label']).drop_duplicates(subset=['url'])
print(f"\nRemoved {before - len(df):,} rows -> {len(df):,} URLs remaining")

# ── 1.6 Split: 70 / 15 / 15 ─────────────────────────────────
train_df, temp_df = train_test_split(df, test_size=0.3, random_state=42, stratify=df['label'])
val_df, test_df   = train_test_split(temp_df, test_size=0.5, random_state=42, stratify=temp_df['label'])

print(f"\nSplit:")
print(f"   Train:      {len(train_df):>8,}")
print(f"   Validation: {len(val_df):>8,}")
print(f"   Test:       {len(test_df):>8,}")
print(f"\n   Train label distribution:")
print(f"     Safe:      {(train_df['label']==0).sum():>8,}")
print(f"     Malicious: {(train_df['label']==1).sum():>8,}")
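
The two-stage split above first holds out 30% of the data, then halves that hold-out, which yields roughly 70/15/15 with the label ratio preserved in each part. The sketch below reproduces that arithmetic with a small pure-Python stratified splitter. `stratified_split` is a hypothetical helper written for illustration (it is not scikit-learn's implementation), and the toy rows are made up.

```python
import random
from collections import Counter

def stratified_split(rows, frac, seed=42):
    """Split (item, label) rows in two, moving `frac` of each label to the second part."""
    rng = random.Random(seed)
    by_label = {}
    for row in rows:
        by_label.setdefault(row[1], []).append(row)
    first, second = [], []
    for label_rows in by_label.values():
        rng.shuffle(label_rows)
        cut = round(len(label_rows) * frac)
        second.extend(label_rows[:cut])
        first.extend(label_rows[cut:])
    return first, second

# Toy data: 1000 rows, 25% positive.
rows = [(f"url{i}", int(i % 4 == 0)) for i in range(1000)]

train, temp = stratified_split(rows, 0.30)   # 70% train, 30% held out
val, test = stratified_split(temp, 0.50)     # hold-out halved: 15% / 15%

for name, part in [("train", train), ("val", val), ("test", test)]:
    pos = Counter(label for _, label in part)[1]
    print(f"{name}: {len(part)} rows, {pos / len(part):.0%} positive")
```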

---
## Section 2: XGBoost with Feature Engineering (95 features)

Hand-crafted **95 features** covering:
- **Length** (11) — URL, domain, path, query, subdomain, TLD lengths + averages
- **Counts** (24) — dots, hyphens, digits, special chars, subdomains, path depth
- **Ratios** (9) — digit/letter/special proportions, domain-to-URL ratio
- **Entropy** (5) — Shannon entropy of URL, domain, path, query, subdomain
- **Boolean** (11) — IP address, port, HTTPS, hex encoding, punycode, @ symbol
- **TLD** (6) — suspicious/trusted TLD classification
- **Character** (4) — consecutive runs, vowel ratio
- **Keywords** (8) — phishing, malware, brand impersonation
- **Structure** (9) — deep paths, embedded URLs, base64, shorteners
- **Homograph** (5) — mixed scripts, confusable chars, Levenshtein brand distance, char substitution
- **N-gram** (3) — bigram frequency scores for domain/subdomain/path randomness
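
To give a feel for the entropy and n-gram groups, the sketch below computes Shannon entropy (bits per character) and a bigram-hit fraction for a pronounceable name and a random-looking one. The `COMMON` set is a tiny illustrative subset, not the notebook's full `_COMMON_BIGRAMS`, and the function names are assumptions made for this example.

```python
import math
from collections import Counter

def shannon_entropy(text):
    """Shannon entropy in bits per character."""
    if not text:
        return 0.0
    n = len(text)
    return -sum((c / n) * math.log2(c / n) for c in Counter(text).values())

# Illustrative subset only; the real bigram table is much larger.
COMMON = {"go", "oo", "og", "gl", "le"}

def bigram_fraction(text, common=COMMON):
    """Fraction of adjacent letter pairs that appear in `common`."""
    letters = "".join(ch for ch in text.lower() if ch.isalpha())
    bigrams = [letters[i:i + 2] for i in range(len(letters) - 1)]
    if not bigrams:
        return 0.0
    return sum(b in common for b in bigrams) / len(bigrams)

for name in ["google", "xkqzvwpj"]:
    print(name, round(shannon_entropy(name), 3), round(bigram_fraction(name), 2))
```

Random-looking names tend to combine high entropy with a low bigram fraction, which is the signal these two feature groups are meant to capture.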

In [3]:
import ipaddress
import re, math
from urllib.parse import urlparse, parse_qs, unquote
from collections import Counter
from functools import lru_cache
import tldextract

# ── Keyword / Pattern Dictionaries ────────────────────────────

SUSPICIOUS_TLDS = frozenset({
    "tk","ml","ga","cf","gq","pw","top","xyz","club","work","click","link",
    "surf","buzz","fun","monster","quest","cam","icu","cc","ws","info","biz",
    "su","ru","cn","online","site","website","space","tech","store","stream",
    "download","win","review","racing","cricket","science","party","gdn",
    "loan","men","country","kim","date","faith","accountant","bid","trade","webcam",
})
TRUSTED_TLDS = frozenset({
    "edu","gov","mil","int","ac.uk","gov.uk","edu.au","gov.au",
})
BRAND_KEYWORDS = frozenset({
    "paypal","apple","google","microsoft","amazon","facebook","netflix",
    "instagram","whatsapp","twitter","linkedin","ebay","dropbox","icloud",
    "outlook","office365","yahoo","chase","wellsfargo","bankofamerica",
    "citibank","capitalone","steam","spotify","adobe","coinbase","binance","metamask",
})
PHISHING_KEYWORDS = frozenset({
    "login","signin","sign-in","logon","password","verify","verification",
    "confirm","update","secure","security","account","banking","wallet",
    "suspend","suspended","urgent","expire","unlock","restore","recover",
    "validate","authenticate","webscr","customer","support","helpdesk",
})
MALWARE_KEYWORDS = frozenset({
    "download","free","crack","keygen","patch","serial","warez","torrent",
    "nulled","hack","cheat","generator","install","setup","update","flash",
    "player","codec","driver",
})
URL_SHORTENERS = frozenset({
    "bit.ly","goo.gl","tinyurl.com","ow.ly","t.co","is.gd",
    "buff.ly","adf.ly","j.mp","rb.gy","cutt.ly","tiny.cc",
})
DANGEROUS_EXTS = frozenset({
    ".exe",".dll",".bat",".cmd",".msi",".scr",".pif",".vbs",
    ".js",".jar",".apk",".dmg",".zip",".rar",".7z",".iso",
})

# Common bigrams for domain randomness scoring.
# Includes standard English prose bigrams PLUS patterns common in
# legitimate domain names (e.g., "go", "oo", "ok", "bo", "ap", "eb").
# MUST match the server's url_features.py exactly.
_COMMON_BIGRAMS = frozenset({
    # Core English prose bigrams
    "th","he","in","er","an","re","on","at","en","nd",
    "ti","es","or","te","of","ed","is","it","al","ar",
    "st","to","nt","ng","se","ha","as","ou","io","le",
    "ve","co","me","de","hi","ri","ro","ic","ne","ea",
    "ra","ce","li","ch","ll","be","ma","si","om","ur",
    # Domain-typical bigrams (cover common brand names & tech words)
    "go", "oo", "og", "gl", "ok", "bo", "fa", "ac", "eb",
    "am", "az", "ap", "pl", "pp", "tw", "et", "fl", "ix",
    "pa", "sc", "ca", "op", "ub", "dr", "sp", "ot", "if",
    "so", "ft", "ab", "ad", "ob", "do", "ag", "gi", "ig",
    "po", "pi", "cr", "ct", "di", "mi", "mo", "no", "ov",
    "sh", "sk", "sl", "sn", "sw", "ta", "tr", "tu", "up",
    "ut", "wa", "wi", "wo", "zo",
})

# Homograph / confusable character mappings
# MUST match the server's homograph_detector.py
CONFUSABLES = {
    # Cyrillic -> Latin
    '\u0430': 'a', '\u0435': 'e', '\u043e': 'o', '\u0440': 'p',
    '\u0441': 'c', '\u0443': 'y', '\u0445': 'x', '\u044a': 'b',
    '\u0456': 'i', '\u0458': 'j', '\u04bb': 'h', '\u0501': 'd',
    # Greek -> Latin
    '\u03b1': 'a', '\u03b5': 'e', '\u03bf': 'o', '\u03c1': 'p',
    '\u03ba': 'k', '\u03bd': 'v', '\u03c4': 't', '\u03b9': 'i',
    # Common number/letter substitutions
    '0': 'o', '1': 'l', '!': 'i', '$': 's',
    '@': 'a', '3': 'e', '5': 's', '7': 't', '8': 'b',
}
BRAND_DOMAINS = {
    "google": "google.com", "facebook": "facebook.com", "amazon": "amazon.com",
    "apple": "apple.com", "microsoft": "microsoft.com", "paypal": "paypal.com",
    "netflix": "netflix.com", "instagram": "instagram.com", "twitter": "twitter.com",
    "linkedin": "linkedin.com", "ebay": "ebay.com", "dropbox": "dropbox.com",
    "spotify": "spotify.com", "adobe": "adobe.com", "yahoo": "yahoo.com",
    "chase": "chase.com", "wellsfargo": "wellsfargo.com", "coinbase": "coinbase.com",
    "binance": "binance.com", "steam": "steampowered.com", "outlook": "outlook.com",
    "icloud": "icloud.com", "whatsapp": "whatsapp.com", "capitalone": "capitalone.com",
    "bankofamerica": "bankofamerica.com", "citibank": "citibank.com",
    "metamask": "metamask.io", "slack": "slack.com", "zoom": "zoom.us",
    "github": "github.com",
}

# ── Helper Functions ──────────────────────────────────────────

def calc_entropy(text):
    if not text: return 0.0
    freq = Counter(text.lower())
    n = len(text)
    return -sum((c/n) * math.log2(c/n) for c in freq.values() if c > 0)

def max_run(text, cond):
    best = cur = 0
    for ch in text:
        if cond(ch): cur += 1; best = max(best, cur)
        else: cur = 0
    return best

def bigram_score(text):
    """Fraction of the text's adjacent letter pairs that appear in _COMMON_BIGRAMS."""
    text = text.lower()
    letters = "".join(c for c in text if c.isalpha())
    if len(letters) < 2: return 0.0
    bigrams = [letters[i:i+2] for i in range(len(letters) - 1)]
    if not bigrams: return 0.0
    return sum(1 for b in bigrams if b in _COMMON_BIGRAMS) / len(bigrams)

@lru_cache(maxsize=4096)
def levenshtein_distance(s1, s2):
    if len(s1) < len(s2): return levenshtein_distance(s2, s1)
    if len(s2) == 0: return len(s1)
    prev = list(range(len(s2) + 1))
    for i, c1 in enumerate(s1):
        curr = [i + 1]
        for j, c2 in enumerate(s2):
            curr.append(min(prev[j+1]+1, curr[j]+1, prev[j]+(c1 != c2)))
        prev = curr
    return prev[-1]

def normalize_confusables(text):
    return "".join(CONFUSABLES.get(c, c) for c in text)

def has_mixed_scripts(text):
    import unicodedata
    scripts = set()
    for ch in text:
        if ch in ".-_0123456789":
            continue
        cat = unicodedata.category(ch)
        if cat.startswith("L"):
            name = unicodedata.name(ch, "").upper()
            if "CYRILLIC" in name:
                scripts.add("cyrillic")
            elif "GREEK" in name:
                scripts.add("greek")
            elif "LATIN" in name or ch.isascii():
                scripts.add("latin")
            else:
                scripts.add("other")
    return int(len(scripts) > 1)

def count_confusable_chars(text):
    return sum(1 for c in text.lower() if c in CONFUSABLES and not c.isascii())

def min_brand_distance(domain):
    """Minimum Levenshtein distance from domain to any known brand."""
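    # Remove a leading "www." prefix. str.lstrip("www.") would strip any run of
    # leading 'w' and '.' characters instead (e.g. "wellsfargo.com" -> "ellsfargo.com").
    # Keep the server-side implementation in sync with this behavior.
    clean = domain.lower()
    if clean.startswith("www."):
        clean = clean[4:]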
    normalized = normalize_confusables(clean)

    # Use tldextract for accurate domain name extraction
    # (handles multi-part TLDs like .co.uk correctly)
    ext = tldextract.extract(clean)
    domain_name = ext.domain or clean
    norm_domain_name = normalize_confusables(domain_name)

    min_dist = 999
    for brand_key, brand_domain in BRAND_DOMAINS.items():
        d1 = levenshtein_distance(domain_name, brand_key)
        d2 = levenshtein_distance(norm_domain_name, brand_key)
        d3 = levenshtein_distance(clean, brand_domain)
        d4 = levenshtein_distance(normalized, brand_domain)
        dist = min(d1, d2, d3, d4)
        min_dist = min(min_dist, dist)
    return min_dist

def detect_char_substitution(domain):
    """Detect leet-speak / character substitution targeting a brand."""
    # Strip TLD and www using tldextract
    ext = tldextract.extract(domain.lower())
    name = ext.domain or domain.lower()

    # Check if normalizing confusables changes the string AND matches a brand
    normalized = normalize_confusables(name)
    if normalized != name:
        for brand_key in BRAND_DOMAINS:
            if brand_key in normalized and brand_key not in name:
                return 1
    return 0

def extract_homograph_features(domain):
    min_dist = min_brand_distance(domain)
    clean_domain = domain.lower().rstrip(".")
    # Exempt official brand domains across all TLDs using tldextract:
    # mail.google.co.il -> ext.domain='google' -> in BRAND_DOMAINS -> official
    ext = tldextract.extract(clean_domain)
    is_official_domain = ext.domain in BRAND_DOMAINS
    normalized = normalize_confusables(domain.lower())
    is_exact_match = (
        any(b in normalized for b in BRAND_DOMAINS)
        and not is_official_domain
    )
    return {
        "homograph_has_mixed_scripts": has_mixed_scripts(domain),
        "homograph_confusable_chars": count_confusable_chars(domain),
        "homograph_min_brand_distance": min_dist,
        "homograph_has_char_sub": detect_char_substitution(domain),
        "homograph_is_exact_brand": int(is_exact_match and min_dist <= 2),
    }

# ── Main Feature Extractor (95 features) ─────────────────────

def extract_features(url):
    """
    Extract 95 features from a single URL.
    MUST match the server's url_features.py exactly.
    """
    f = {}
    url = str(url).strip()

    try:
        parsed = urlparse(url if "://" in url else f"http://{url}")
    except Exception:
        return {k: 0 for k in FEATURE_NAMES}

    scheme   = parsed.scheme.lower()
    path     = parsed.path
    query    = parsed.query
    fragment = parsed.fragment

    # Use parsed.hostname to correctly handle userinfo URLs
    # (e.g. http://user:pass@example.com) where netloc.split(":")[0]
    # would incorrectly return "user" instead of "example.com".
    domain   = (parsed.hostname or "").lower()
    try:
        has_port = parsed.port is not None
    except ValueError:
        # Malformed port (non-numeric) — treat as no valid port
        has_port = False

    parts  = domain.split(".")
    path_parts = [p for p in path.split("/") if p]

    # Use tldextract for accurate subdomain / registered-domain / TLD
    # parsing (handles multi-part TLDs like .co.uk, .com.au correctly)
    ext = tldextract.extract(domain)
    subdomain = ext.subdomain
    tld = ext.suffix if ext.suffix else (parts[-1] if parts else "")

    url_lower  = url.lower()
    path_lower = path.lower()

    # ═══ LENGTH ═══
    f['url_length']         = len(url)
    f['domain_length']      = len(domain)
    f['path_length']        = len(path)
    f['query_length']       = len(query)
    f['fragment_length']    = len(fragment)
    f['subdomain_length']   = len(subdomain)
    f['tld_length']         = len(tld)
    f['longest_domain_part']= max((len(p) for p in parts), default=0)
    f['avg_domain_part_len']= float(np.mean([len(p) for p in parts])) if parts else 0.0
    f['longest_path_part']  = max((len(p) for p in path_parts), default=0)
    f['avg_path_part_len']  = float(np.mean([len(p) for p in path_parts])) if path_parts else 0.0

    # ═══ COUNTS ═══
    for ch, name in [(".",  "dot"),  ("-", "hyphen"), ("_", "underscore"),
                     ("/",  "slash"),("?", "question"),("=","equals"),
                     ("&",  "amp"),  ("@", "at"),     ("%", "percent"),
                     ("~",  "tilde"),("#", "hash"),   (":", "colon"),
                     (";",  "semicolon")]:
        f[f'{name}_count'] = url.count(ch)

    f['domain_dot_count']    = domain.count(".")
    f['domain_hyphen_count'] = domain.count("-")
    f['domain_digit_count']  = sum(c.isdigit() for c in domain)
    f['subdomain_count']     = subdomain.count(".") + 1 if subdomain else 0
    f['path_depth']          = len(path_parts)
    f['digit_count']         = sum(c.isdigit() for c in url)
    f['letter_count']        = sum(c.isalpha() for c in url)
    f['uppercase_count']     = sum(c.isupper() for c in url)
    f['special_char_count']  = sum(not c.isalnum() for c in url)

    try:
        qp = parse_qs(query)
        f['query_param_count']     = len(qp)
        f['query_value_total_len'] = sum(len(v) for vals in qp.values() for v in vals)
    except Exception:
        f['query_param_count'] = 0; f['query_value_total_len'] = 0

    # ═══ RATIOS ═══
    ul = max(len(url), 1); dl = max(len(domain), 1)
    f['digit_ratio']         = f['digit_count'] / ul
    f['letter_ratio']        = f['letter_count'] / ul
    f['special_char_ratio']  = f['special_char_count'] / ul
    f['uppercase_ratio']     = f['uppercase_count'] / max(f['letter_count'], 1)
    f['domain_digit_ratio']  = f['domain_digit_count'] / dl
    f['domain_hyphen_ratio'] = f['domain_hyphen_count'] / dl
    f['path_url_ratio']      = f['path_length'] / ul
    f['query_url_ratio']     = f['query_length'] / ul
    f['domain_url_ratio']    = f['domain_length'] / ul

    # ═══ ENTROPY ═══
    f['url_entropy']       = calc_entropy(url)
    f['domain_entropy']    = calc_entropy(domain.replace(".", ""))
    f['path_entropy']      = calc_entropy(path)
    f['query_entropy']     = calc_entropy(query)
    f['subdomain_entropy'] = calc_entropy(subdomain)

    # ═══ BOOLEAN ═══
    f['is_https']                = int(scheme == "https")
    f['is_http']                 = int(scheme == "http")
    f['has_www']                 = int(domain.startswith("www."))
    f['has_port']                = int(has_port)
    f['has_at_symbol']           = int("@" in url)
    f['has_double_slash_in_path']= int("//" in path)
    f['has_hex_encoding']        = int(unquote(url) != url)
    f['has_punycode']            = int("xn--" in domain)
    try:
        ipaddress.IPv4Address(domain)
        f['has_ip_address'] = 1
    except ValueError:
        f['has_ip_address'] = 0
    f['has_hex_ip']              = int(bool(re.match(r"^(0x[0-9a-f]+\.){3}0x[0-9a-f]+$", domain)))
    f['has_ip_like']             = int(domain.replace(".", "").isdigit() and len(domain) > 6)

    # ═══ TLD ═══
    f['is_suspicious_tld'] = int(tld in SUSPICIOUS_TLDS)
    f['is_trusted_tld']    = int(tld in TRUSTED_TLDS)
    f['is_com']            = int(tld == "com")
    f['is_org']            = int(tld == "org")
    f['is_net']            = int(tld == "net")
    f['is_country_tld']    = int(len(tld) == 2 and tld.isalpha())

    # ═══ CHARACTER DISTRIBUTION ═══
    f['max_consec_digits']  = max_run(url, str.isdigit)
    f['max_consec_letters'] = max_run(url, str.isalpha)
    f['max_consec_special'] = max_run(url, lambda c: not c.isalnum())
    vowels = set("aeiou")
    dom_letters = [c for c in domain if c.isalpha()]
    f['domain_vowel_ratio'] = sum(c in vowels for c in dom_letters) / max(len(dom_letters), 1)

    # ═══ KEYWORDS ═══
    f['brand_keyword_count']    = sum(1 for b in BRAND_KEYWORDS if b in url_lower)
    f['has_brand_in_subdomain'] = int(any(b in subdomain.lower() for b in BRAND_KEYWORDS))
    f['phishing_keyword_count'] = sum(1 for k in PHISHING_KEYWORDS if k in url_lower)
    f['malware_keyword_count']  = sum(1 for k in MALWARE_KEYWORDS if k in url_lower)
    f['is_url_shortener']       = int(ext.registered_domain in URL_SHORTENERS)
    f['has_dangerous_ext']      = int(any(path_lower.endswith(e) for e in DANGEROUS_EXTS))
    f['has_exe']                = int(path_lower.endswith(".exe"))
    f['has_php']                = int(".php" in path_lower)

    # ═══ STRUCTURAL PATTERNS ═══
    f['has_double_letters']  = int(bool(re.search(r"(.)\1", domain)))
    f['has_long_subdomain']  = int(len(subdomain) > 20)
    f['has_deep_path']       = int(len(path_parts) > 5)
    f['has_embedded_url']    = int("http" in path_lower or "www" in path_lower)
    f['has_data_uri']        = int(url_lower.startswith("data:"))
    f['has_javascript']      = int("javascript:" in url_lower)
    f['has_base64']          = int(bool(re.search(r"[A-Za-z0-9+/]{20,}={0,2}", url)))
    f['brand_in_domain']     = int(any(b in domain for b in BRAND_KEYWORDS))
    # Use ext.domain (registered name without TLD) to correctly
    # identify official brand domains across all TLDs.
    # e.g. mail.google.co.il -> ext.domain='google' -> not flagged
    f['brand_not_registered']= int(
        f['brand_in_domain'] == 1
        and ext.domain not in BRAND_KEYWORDS
    )

    # ═══ HOMOGRAPH / TYPOSQUATTING ═══
    homo = extract_homograph_features(domain)
    f['homograph_has_mixed_scripts']  = homo['homograph_has_mixed_scripts']
    f['homograph_confusable_chars']   = homo['homograph_confusable_chars']
    f['homograph_min_brand_distance'] = homo['homograph_min_brand_distance']
    f['homograph_has_char_sub']       = homo['homograph_has_char_sub']
    f['homograph_is_exact_brand']     = homo['homograph_is_exact_brand']

    # ═══ N-GRAM FEATURES ═══
    domain_name_only = ext.domain or domain
    f['domain_bigram_score']    = bigram_score(domain_name_only)
    f['subdomain_bigram_score'] = bigram_score(subdomain) if subdomain else 0.0
    f['path_bigram_score']      = bigram_score("".join(path_parts)) if path_parts else 0.0

    return f

# Build canonical feature name list
FEATURE_NAMES = list(extract_features("https://www.example.com/path?q=1").keys())
print(f"Feature engineering ready — {len(FEATURE_NAMES)} features per URL")
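
The homograph features combine edit distance with confusable-character folding. The sketch below shows why both are needed: a look-alike such as `paypa1` sits one edit away from `paypal` in raw form and zero edits away once confusables are folded. `SUBS` is a small illustrative subset of the `CONFUSABLES` table above, and the `levenshtein` and `fold` names are assumptions made for this example.

```python
def levenshtein(a, b):
    """Dynamic-programming edit distance (insert / delete / substitute)."""
    if len(a) < len(b):
        a, b = b, a
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, start=1):
        curr = [i]
        for j, cb in enumerate(b, start=1):
            curr.append(min(
                prev[j] + 1,               # deletion
                curr[j - 1] + 1,           # insertion
                prev[j - 1] + (ca != cb),  # substitution (free if equal)
            ))
        prev = curr
    return prev[-1]

# Illustrative subset; "\u0430" is the Cyrillic letter that looks like Latin "a".
SUBS = {"0": "o", "1": "l", "\u0430": "a"}

def fold(text):
    """Replace confusable characters with their Latin look-alikes."""
    return "".join(SUBS.get(ch, ch) for ch in text)

for candidate in ["paypa1", "p\u0430ypal", "paypall", "example"]:
    raw = levenshtein(candidate, "paypal")
    folded = levenshtein(fold(candidate), "paypal")
    print(f"{candidate!r}: raw={raw}, after folding={folded}")
```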

In [4]:
from tqdm import tqdm

# ── 2.1 Extract Features from All URLs ───────────────────────
print(f"Extracting {len(FEATURE_NAMES)} features per URL (this takes a few minutes)...\n")

def extract_batch(urls, desc="Extracting"):
    return pd.DataFrame(
        [extract_features(str(u)) for u in tqdm(urls, desc=desc)],
        columns=FEATURE_NAMES
    ).fillna(0).astype(np.float32)

X_train_feat = extract_batch(train_df['url'].tolist(), "  Train")
X_val_feat   = extract_batch(val_df['url'].tolist(),   "  Val")
X_test_feat  = extract_batch(test_df['url'].tolist(),  "  Test")

y_train = train_df['label'].values
y_val   = val_df['label'].values
y_test  = test_df['label'].values

# Class imbalance ratio (negatives / positives), passed to XGBoost as scale_pos_weight
scale_pos = float((y_train == 0).sum() / (y_train == 1).sum())

print(f"\nFeature matrices ready:")
print(f"   Train: {X_train_feat.shape}")
print(f"   Val:   {X_val_feat.shape}")
print(f"   Test:  {X_test_feat.shape}")
print(f"   Class ratio (safe/malicious): {scale_pos:.2f}")

In [5]:
import xgboost as xgb
import optuna
from sklearn.metrics import f1_score
from sklearn.calibration import CalibratedClassifierCV

# ── 2.2 Hyperparameter Tuning with Optuna (50 trials) ────────
print("Running Optuna hyperparameter search...\n")

# Detect GPU via xgboost — no torch dependency needed
try:
    _test = xgb.XGBClassifier(device='cuda', n_estimators=1)
    _test.fit(np.zeros((2, 1)), np.array([0, 1]))
    USE_GPU = True
except Exception:
    USE_GPU = False
print(f"   GPU available: {USE_GPU}")

def objective(trial):
    params = {
        'objective': 'binary:logistic',
        'eval_metric': 'logloss',
        'device': 'cuda' if USE_GPU else 'cpu',
        'tree_method': 'hist',
        'max_depth': trial.suggest_int('max_depth', 4, 10),
        'learning_rate': trial.suggest_float('learning_rate', 0.01, 0.3, log=True),
        'n_estimators': 1500,
        'early_stopping_rounds': 50,
        'min_child_weight': trial.suggest_int('min_child_weight', 1, 10),
        'subsample': trial.suggest_float('subsample', 0.6, 1.0),
        'colsample_bytree': trial.suggest_float('colsample_bytree', 0.5, 1.0),
        'reg_alpha': trial.suggest_float('reg_alpha', 1e-8, 10.0, log=True),
        'reg_lambda': trial.suggest_float('reg_lambda', 1e-8, 10.0, log=True),
        'gamma': trial.suggest_float('gamma', 0, 5.0),
        'scale_pos_weight': scale_pos,
        'random_state': 42,
        'verbosity': 0,
    }
    m = xgb.XGBClassifier(**params)
    m.fit(X_train_feat.values, y_train,
          eval_set=[(X_val_feat.values, y_val)], verbose=False)
    return f1_score(y_val, m.predict(X_val_feat.values))

optuna.logging.set_verbosity(optuna.logging.WARNING)
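# Seed the sampler so the search is repeatable, matching random_state=42 used elsewhere
study = optuna.create_study(direction='maximize',
                            sampler=optuna.samplers.TPESampler(seed=42))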
study.optimize(objective, n_trials=50, show_progress_bar=True)

print(f"\nBest trial F1: {study.best_value:.4f}")
print(f"   Best params:")
for k, v in study.best_params.items():
    print(f"     {k}: {v}")

# ── 2.3 Train Final XGBoost with Best Params ─────────────────
bp = study.best_params.copy()
bp.update({
    'objective': 'binary:logistic',
    'eval_metric': 'logloss',
    'device': 'cuda' if USE_GPU else 'cpu',
    'tree_method': 'hist',
    'n_estimators': 3000,
    'early_stopping_rounds': 100,
    'scale_pos_weight': scale_pos,
    'random_state': 42,
    'verbosity': 1,
})

xgb_model = xgb.XGBClassifier(**bp)
xgb_model.fit(X_train_feat.values, y_train,
              eval_set=[(X_val_feat.values, y_val)], verbose=True)

print(f"\n   Best iteration: {xgb_model.best_iteration}")

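# ── 2.4 Calibrate Probabilities (Platt Scaling) ──────────────
# method='sigmoid' fits a logistic curve to the already-trained model's validation scores.
# Caveat: this validation split was also used for early stopping and Optuna tuning,
# so calibration quality should be judged on the test set, not on validation data.
# Note: recent scikit-learn releases (1.6 and later) deprecate cv='prefit' in favor of
# wrapping the model in sklearn.frozen.FrozenEstimator.
xgb_calibrated = CalibratedClassifierCV(xgb_model, method='sigmoid', cv='prefit')
xgb_calibrated.fit(X_val_feat.values, y_val)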

print("XGBoost trained & calibrated!")
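
For readers unfamiliar with what `method='sigmoid'` does, the following is a minimal sketch of Platt scaling in pure Python. It fits `p = sigmoid(a * score + b)` on held-out scores by gradient descent on log loss. The names `fit_platt` and `apply_platt` and the toy scores are hypothetical, and the sketch omits details of scikit-learn's implementation (such as the target smoothing described by Platt), so treat it as an illustration rather than what `CalibratedClassifierCV` runs internally.

```python
import math

def fit_platt(scores, labels, lr=0.1, n_iter=5000):
    """Fit p = sigmoid(a * score + b) by gradient descent on mean log loss."""
    a, b = 1.0, 0.0
    n = len(scores)
    for _ in range(n_iter):
        grad_a = grad_b = 0.0
        for s, y in zip(scores, labels):
            p = 1.0 / (1.0 + math.exp(-(a * s + b)))
            grad_a += (p - y) * s   # d(log loss)/da for one sample
            grad_b += (p - y)       # d(log loss)/db for one sample
        a -= lr * grad_a / n
        b -= lr * grad_b / n
    return a, b

def apply_platt(score, a, b):
    """Map a raw score to a calibrated probability."""
    return 1.0 / (1.0 + math.exp(-(a * score + b)))

# Toy held-out scores (hypothetical raw margins) with noisy labels
val_scores = [-3.0, -2.0, -1.0, -0.5, 0.5, 1.0, 2.0, 3.0]
val_labels = [0, 0, 1, 0, 1, 0, 1, 1]

a, b = fit_platt(val_scores, val_labels)
calibrated = [apply_platt(s, a, b) for s in val_scores]
```

Because the fitted map is monotonic in the score, calibration changes the probability values but not the ranking of URLs, so ROC-AUC is unaffected.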

In [6]:
from sklearn.metrics import (
    classification_report, confusion_matrix,
    roc_curve, auc, roc_auc_score,
    precision_recall_curve, average_precision_score, accuracy_score
)

# ── 2.5 Evaluate XGBoost on Test Set ─────────────────────────
xgb_probs = xgb_calibrated.predict_proba(X_test_feat.values)[:, 1]
xgb_preds = (xgb_probs >= 0.5).astype(int)

print("=" * 55)
print("XGBoost — Test Set Results")
print("=" * 55)
print(classification_report(y_test, xgb_preds, target_names=['Safe', 'Malicious']))
print(f"ROC-AUC: {roc_auc_score(y_test, xgb_probs):.4f}")

fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# Confusion Matrix
sns.heatmap(confusion_matrix(y_test, xgb_preds), annot=True, fmt='d', cmap='Blues',
            xticklabels=['Safe','Malicious'], yticklabels=['Safe','Malicious'], ax=axes[0])
axes[0].set_title('XGBoost — Confusion Matrix')
axes[0].set_xlabel('Predicted'); axes[0].set_ylabel('Actual')

# ROC Curve
fpr, tpr, _ = roc_curve(y_test, xgb_probs)
axes[1].plot(fpr, tpr, 'b-', linewidth=2, label=f'AUC = {auc(fpr, tpr):.4f}')
axes[1].plot([0, 1], [0, 1], 'k--', alpha=0.5)
axes[1].set_title('ROC Curve'); axes[1].set_xlabel('False Positive Rate')
axes[1].set_ylabel('True Positive Rate'); axes[1].legend()

# Feature Importance (Top 20)
imp = xgb_model.feature_importances_
top20 = np.argsort(imp)[-20:]
axes[2].barh(range(20), imp[top20], color='steelblue')
axes[2].set_yticks(range(20))
axes[2].set_yticklabels([FEATURE_NAMES[i] for i in top20])
axes[2].set_title('Top 20 Feature Importances')

plt.tight_layout(); plt.show()

---
## Section 3: SHAP Explainability Analysis

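Use **SHAP (SHapley Additive exPlanations)** (Lundberg & Lee 2017) to understand which URL features
drive XGBoost predictions. `TreeExplainer` computes exact Shapley values in
polynomial time for tree ensembles (Lundberg et al. 2020). The attributions explain the
raw XGBoost model, so they are in log-odds space, before Platt calibration.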

Key visualizations:
- **Beeswarm plot** — global feature importance + direction of effect
- **Individual explanations** — per-URL feature attributions
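
To make the additivity behind these attributions concrete, here is a brute-force sketch that computes exact Shapley values for a three-feature toy score by enumerating every coalition. It is not the TreeExplainer algorithm: `shapley_values`, `toy_score`, the feature values, and the single-baseline value function are all assumptions chosen for illustration, and enumeration only scales to a handful of features.

```python
from itertools import combinations
from math import factorial

def shapley_values(value_fn, x, baseline):
    """Exact Shapley values of value_fn at x relative to a baseline point."""
    n = len(x)

    def v(subset):
        # Features in the coalition take their real value; the rest stay at baseline
        z = [x[i] if i in subset else baseline[i] for i in range(n)]
        return value_fn(z)

    phi = [0.0] * n
    for i in range(n):
        others = [j for j in range(n) if j != i]
        for size in range(n):
            weight = factorial(size) * factorial(n - size - 1) / factorial(n)
            for coalition in combinations(others, size):
                s = set(coalition)
                phi[i] += weight * (v(s | {i}) - v(s))
    return phi

def toy_score(z):
    # Hypothetical risk score with an interaction term
    has_ip, url_len, brand_mismatch = z
    return 2.0 * has_ip + 0.01 * url_len + 1.5 * has_ip * brand_mismatch

x = [1, 80, 1]          # URL being explained
baseline = [0, 20, 0]   # reference URL
phi = shapley_values(toy_score, x, baseline)

# Additivity: attributions sum to the gap between the two scores
gap = toy_score(x) - toy_score(baseline)
print(phi, sum(phi), gap)
```

The interaction term is split evenly between `has_ip` and `brand_mismatch`, which is the behavior to keep in mind when reading per-URL attributions for correlated features.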

In [None]:
# ══════════════════════════════════════════════════════════════
# 3.1 SHAP Explainability — XGBoost Feature Attribution
# ══════════════════════════════════════════════════════════════
import shap

# Unwrap calibrated model to get the raw XGBClassifier
raw_xgb = xgb_calibrated.calibrated_classifiers_[0].estimator

# Create TreeExplainer (exact Shapley values for tree models)
explainer = shap.TreeExplainer(raw_xgb)

# Compute SHAP values on a sample of the test set (200 samples for speed)
sample_idx = np.random.RandomState(42).choice(len(X_test_feat), size=min(200, len(X_test_feat)), replace=False)
X_sample = X_test_feat.iloc[sample_idx]
shap_values = explainer.shap_values(X_sample.values)

# If shap_values is a list (binary classification), take class 1
if isinstance(shap_values, list):
    shap_values = shap_values[1]

# ── Beeswarm Plot (Global Feature Importance) ────────────────
print("Top Features by Mean |SHAP Value| (Global Importance)")
print("=" * 55)
mean_abs = np.abs(shap_values).mean(axis=0)
top_idx = np.argsort(mean_abs)[::-1][:15]
for i, idx in enumerate(top_idx, 1):
    print(f"  {i:2d}. {FEATURE_NAMES[idx]:<35} |SHAP| = {mean_abs[idx]:.4f}")

fig = plt.figure(figsize=(10, 8))
shap.summary_plot(shap_values, X_sample, feature_names=FEATURE_NAMES, show=False)
plt.title("SHAP Beeswarm — XGBoost Feature Importance")
plt.tight_layout()
plt.show()

# ── Bar Plot (Feature Importance Ranking) ─────────────────────
fig = plt.figure(figsize=(10, 6))
shap.summary_plot(shap_values, X_sample, feature_names=FEATURE_NAMES,
                  plot_type="bar", show=False)
plt.title("SHAP Bar — Mean |SHAP Value| per Feature")
plt.tight_layout()
plt.show()

# ── Individual Explanation (one safe, one malicious) ──────────
print("\n" + "=" * 55)
print("Individual SHAP Explanation — Example URLs")
print("=" * 55)

safe_sample_idx = np.where(y_test[sample_idx] == 0)[0]
mal_sample_idx = np.where(y_test[sample_idx] == 1)[0]

if len(safe_sample_idx) > 0:
    idx = safe_sample_idx[0]
    sv = shap_values[idx]
    top_k = np.argsort(np.abs(sv))[::-1][:8]
    print(f"\nSafe URL (sample #{sample_idx[idx]}):")
    for k in top_k:
        direction = "risk ↑" if sv[k] > 0 else "safe ↓"
        print(f"  {FEATURE_NAMES[k]:<35} SHAP={sv[k]:+.4f}  ({direction})")

if len(mal_sample_idx) > 0:
    idx = mal_sample_idx[0]
    sv = shap_values[idx]
    top_k = np.argsort(np.abs(sv))[::-1][:8]
    print(f"\nMalicious URL (sample #{sample_idx[idx]}):")
    for k in top_k:
        direction = "risk ↑" if sv[k] > 0 else "safe ↓"
        print(f"  {FEATURE_NAMES[k]:<35} SHAP={sv[k]:+.4f}  ({direction})")

print("\nSHAP analysis complete. Feature attributions will be served via API.")

---
## Section 4: OOD Sanity Check

This section runs an out-of-distribution sanity check on real-world URLs.
It tests format diversity: bare domains, `www.`, `http://`, and complex paths.
The goal is to confirm that the model learned actual phishing patterns, not formatting artifacts.
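
One way to probe formatting bias more systematically is to score every scheme and `www.` variant of the same URL and look at the spread. The sketch below uses only the standard library; `format_variants`, `max_score_spread`, and the `score_fn` callback are hypothetical names. In this notebook, `score_fn` could wrap `extract_features` and `xgb_calibrated.predict_proba`.

```python
from urllib.parse import urlsplit, urlunsplit

def format_variants(url):
    """Return http/https and www./non-www. variants of the same URL."""
    parts = urlsplit(url)
    host = parts.netloc
    bare_host = host[4:] if host.startswith("www.") else host
    variants = set()
    for scheme in ("http", "https"):
        for h in (bare_host, "www." + bare_host):
            variants.add(urlunsplit((scheme, h, parts.path, parts.query, parts.fragment)))
    return sorted(variants)

def max_score_spread(url, score_fn):
    """Largest score difference across format variants; near 0 means format-invariant."""
    scores = [score_fn(v) for v in format_variants(url)]
    return max(scores) - min(scores)

variants = format_variants("https://www.google.com")

# Stand-in scorer that is deliberately biased against plain http
biased = lambda u: 0.9 if u.startswith("http://") else 0.1
spread = max_score_spread("https://www.google.com", biased)
```

A large spread on a known-safe URL points to the model keying on scheme or prefix. Note that prepending `www.` to a subdomain host such as `mail.google.com` yields a host that may not exist, which matters little for a purely lexical classifier but should be kept in mind when reading results.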

In [11]:
# ══════════════════════════════════════════════════════════════
# 4.1 Out-of-Distribution Sanity Check — XGBoost Only
# ══════════════════════════════════════════════════════════════
print("=" * 60)
print("OOD SANITY CHECK — Real-World URLs")
print("=" * 60)
print("Tests FORMAT DIVERSITY: bare domains, www., http://, complex paths\n")

ood_urls = [
    # ── SAFE: www. hosts with no path (scheme included, as QR scanners produce) ──
    ("https://www.google.com", "SAFE"),
    ("https://www.netflix.com", "SAFE"),
    ("https://www.wikipedia.org", "SAFE"),
    ("https://www.amazon.com", "SAFE"),
    ("https://www.youtube.com", "SAFE"),
    ("https://www.linkedin.com", "SAFE"),
    ("http://www.google.com", "SAFE"),
    # ── SAFE: without www ──
    ("https://google.com", "SAFE"),
    ("https://github.com", "SAFE"),
    ("https://reddit.com", "SAFE"),
    # ── SAFE: with complex paths ──
    ("https://www.kaggle.com/code/alexandrucalaras/model-training/edit", "SAFE"),
    ("https://x.com/home", "SAFE"),
    ("https://www.amazon.com/dp/B0D5B7TH89/ref=cm_sw_r_cp_ud_dp_abc123", "SAFE"),
    ("https://docs.google.com/spreadsheets/d/1aBcDeFgHiJkLmNoPqRsTuVwXyZ/edit#gid=0", "SAFE"),
    ("https://github.com/facebook/react/pull/28347/files", "SAFE"),
    ("https://www.youtube.com/watch?v=dQw4w9WgXcQ&list=PLrAXtmErZgOeiKm4sgNOknGvNjby9efdf", "SAFE"),
    ("https://stackoverflow.com/questions/12345678/how-to-parse-json-in-python", "SAFE"),
    ("https://en.wikipedia.org/wiki/Machine_learning#Supervised_learning", "SAFE"),
    ("https://studenti.pub.ro/index.php?page=Informatii&ActiveViewID=tab_infogen", "SAFE"),
    ("https://mail.google.com/mail/u/0/#inbox", "SAFE"),
    ("https://zoom.us/j/1234567890?pwd=aBcDeFgH", "SAFE"),
    ("https://www.reddit.com/r/MachineLearning/comments/abc123/new_paper_on_transformers/", "SAFE"),
    # ── MALICIOUS: phishing ──
    ("http://192.168.1.1/login.php?user=admin&redirect=http://evil.com", "MALICIOUS"),
    ("http://g00gle-com.tk/secure/login/verify-account", "MALICIOUS"),
    ("http://free-iphone15-winner-claim.xyz/prize.php?id=28374", "MALICIOUS"),
    ("http://paypa1-security.top/update/billing/confirm.php", "MALICIOUS"),
    ("http://secure-bankofamerica.ml/signin?session=expired", "MALICIOUS"),
    ("http://bit.ly/3xY9z2k", "MALICIOUS"),
]

print(f"{'URL':<75} {'Expected':>10} {'XGB':>7} {'Pred':>6}")
print("-" * 105)

ood_correct = 0
ood_failures = []
for url, expected in ood_urls:
    feats = pd.DataFrame([extract_features(url)], columns=FEATURE_NAMES).fillna(0).astype(np.float32)
    xgb_p = xgb_calibrated.predict_proba(feats.values)[:, 1][0]
    pred = "MAL" if xgb_p >= 0.5 else "SAFE"
    correct = (pred == "SAFE" and expected == "SAFE") or (pred == "MAL" and expected == "MALICIOUS")
    ood_correct += correct
    marker = "OK" if correct else "FAIL"
    if not correct:
        ood_failures.append((url, expected, xgb_p))

    display_url = url[:72] + "..." if len(url) > 72 else url
    print(f"{display_url:<75} {expected:>10} {xgb_p:>7.3f} {pred:>5} {marker}")

print(f"\nOOD accuracy: {ood_correct}/{len(ood_urls)} ({ood_correct/len(ood_urls):.0%})")

if ood_failures:
    print(f"\n⚠️  {len(ood_failures)} FAILURES detected — model may have formatting bias!")
    for url, expected, score in ood_failures:
        print(f"  {url[:60]:<60}  expected={expected}  score={score:.3f}")
    print("\nCheck that augmented data includes www., http://, & bare domains.")
else:
    print("\n✅ All OOD tests passed — no formatting bias detected.")

---
## Section 5: Export Model for Server Integration

Saves everything needed to run inference on the FastAPI server:
- `xgb_model.pkl` — Calibrated XGBoost model (CalibratedClassifierCV)
- `feature_names.json` — Ordered list of 95 feature names

Place these files in the server's `models/` directory:
```
qr-security-server/
  models/
    xgb_model.pkl
    feature_names.json
```
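
On the server side, the feature vector must follow the exact column order stored in `feature_names.json`. The sketch below shows one way to enforce that with the standard library only. `load_feature_names` and `to_feature_row` are hypothetical helper names, not part of the server code, and missing features default to 0.0 to mirror the `.fillna(0)` used during training.

```python
import json

def load_feature_names(path):
    """Read the ordered feature list exported by the notebook."""
    with open(path) as fh:
        return json.load(fh)

def to_feature_row(feature_dict, feature_names):
    """Order a feature dict by the exported names; missing or None becomes 0.0."""
    row = []
    for name in feature_names:
        value = feature_dict.get(name)
        row.append(0.0 if value is None else float(value))
    return row

# Illustrative subset of names; the real list comes from feature_names.json
names = ["has_php", "has_deep_path", "domain_bigram_score"]
row = to_feature_row({"domain_bigram_score": 0.42, "has_php": 1, "unused": 7}, names)
print(row)
```

The resulting row can then be wrapped in a 2-D array and passed to the model returned by `joblib.load("models/xgb_model.pkl")`. Pickled models are generally only safe to load with the same library versions used for training, so pinning scikit-learn and xgboost on the server is worth considering.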

In [None]:
import json
import joblib

# ── 5.1 Save Calibrated XGBoost Model ────────────────────────
joblib.dump(xgb_calibrated, "xgb_model.pkl")
print("XGBoost saved: xgb_model.pkl")

# ── 5.2 Save Feature Names ───────────────────────────────────
with open("feature_names.json", "w") as f:
    json.dump(FEATURE_NAMES, f)
print(f"Feature names saved: feature_names.json ({len(FEATURE_NAMES)} features)")

# ── 5.3 Summary ──────────────────────────────────────────────
print(f"\n{'='*55}")
print("MODEL EXPORT SUMMARY")
print(f"{'='*55}")
print(f"  Model:           CalibratedClassifierCV(XGBClassifier)")
print(f"  Features:        {len(FEATURE_NAMES)}")
print(f"  Test Accuracy:   {accuracy_score(y_test, xgb_preds):.4f}")
print(f"  Test F1:         {f1_score(y_test, xgb_preds):.4f}")
print(f"  Test ROC-AUC:    {roc_auc_score(y_test, xgb_probs):.4f}")
print(f"{'='*55}")

# ── 5.4 Package & Download ───────────────────────────────────
!zip -r url_classifier_models.zip xgb_model.pkl feature_names.json

print("\nAll models packaged: url_classifier_models.zip")
print("\nPlace these files in: qr-security-server/models/")
print("  models/xgb_model.pkl")
print("  models/feature_names.json")

# Auto-download (Colab)
try:
    from google.colab import files
    files.download("url_classifier_models.zip")
    print("Download started!")
except ImportError:
    print("Download the zip from the Output tab (Kaggle) or file browser.")